├── .gitignore
├── README
├── commands.txt
├── final-draft.pdf
├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   ├── collabstream
    │   │   │   ├── App.java
    │   │   │   ├── lc
    │   │   │   │   ├── CounterBolt.java
    │   │   │   │   ├── FileReaderSpout.java
    │   │   │   │   ├── LineCountTopology.java
    │   │   │   │   ├── MsgForwarder.java
    │   │   │   │   └── MsgTracker.java
    │   │   │   └── streaming
    │   │   │       ├── BlockPair.java
    │   │   │       ├── Configuration.java
    │   │   │       ├── Master.java
    │   │   │       ├── MatrixSerialization.java
    │   │   │       ├── MatrixStore.java
    │   │   │       ├── MatrixUtils.java
    │   │   │       ├── MsgType.java
    │   │   │       ├── PermutationUtils.java
    │   │   │       ├── RatingsSource.java
    │   │   │       ├── StreamingDSGD.java
    │   │   │       ├── TestPredictions.java
    │   │   │       ├── TrainingExample.java
    │   │   │       ├── Worker.java
    │   │   │       └── WorkingBlock.java
    │   │   └── comparison
    │   │       └── dsgd
    │   │           ├── DSGDMain.java
    │   │           ├── IntMatrixItemPair.java
    │   │           ├── LongPair.java
    │   │           ├── MatrixItem.java
    │   │           ├── MatrixUtils.java
    │   │           ├── mapper
    │   │           │   ├── DSGDFinalMapper.java
    │   │           │   ├── DSGDIntermediateMapper.java
    │   │           │   ├── DSGDOutputFactorsMapper.java
    │   │           │   ├── DSGDPreprocFactorMapper.java
    │   │           │   ├── DSGDPreprocMapper.java
    │   │           │   ├── DSGDPreprocRatingsMapper.java
    │   │           │   └── DSGDRmseMapper.java
    │   │           └── reducer
    │   │               ├── DSGDFinalReducer.java
    │   │               ├── DSGDIntermediateReducer.java
    │   │               ├── DSGDOutputFactorsReducer.java
    │   │               ├── DSGDPreprocFactorReducer.java
    │   │               ├── DSGDPreprocRatingsReducer.java
    │   │               ├── DSGDPreprocReducer.java
    │   │               └── DSGDRmseReducer.java
    │   └── resources
    │       ├── log4j.properties
    │       └── lotr.txt
    └── test
        └── java
            └── collabstream
                └── AppTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | #Compiled source #
2 | ###################
3 | *.com
4 | *.class
5 | *.dll
6 | *.exe
7 | *.o
8 | *.so
9 |
10 | # Packages #
11 | ############
12 | # it's better to unpack these files and commit the raw source
13 | # git has its own built in compression methods
14 | *.7z
15 | *.dmg
16 | *.gz
17 | *.iso
18 | *.jar
19 | *.rar
20 | *.tar
21 | *.zip
22 |
23 | # Logs and databases #
24 | ######################
25 | *.log
26 | *.sql
27 | *.sqlite
28 |
29 | # OS generated files #
30 | ######################
31 | .DS_Store*
32 | ehthumbs.db
33 | Icon?
34 | Thumbs.db
35 |
36 | #directories
37 | classes/
38 | lib/
39 | data/
40 |
41 | # Eclipse settings
42 | .settings/
43 | .project
44 | .classpath
45 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | Authors: Chris Johnson, Alex Tang, Muqeet Ali
2 |
3 | CollabStream implements parallelized online matrix factorization using stochastic gradient descent over streaming data. CollabStream is built on the Storm parallelization framework developed by Nathan Marz, available at https://github.com/nathanmarz/storm. For more information on the algorithm used by CollabStream, see the full report at https://github.com/MrChrisJohnson/CollabStream.
4 |
5 | To run:
6 | 1. Go to the directory where you cloned the project.
7 | 2. mvn compile
8 | 3. mvn package
9 | 4. Take the jar with dependencies from the target folder (e.g., name.jar)
10 | 5. Run on a Hadoop cluster using:
11 | hadoop jar name.jar comparison.dsgd.DSGDMain inputPath outputPath numReducers
12 | where inputPath is the path of the training and test data, and outputPath is where the RMSE and the factor matrices are written.
13 |
14 |
--------------------------------------------------------------------------------
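
For example, with the assembly settings in pom.xml the built jar is named
CollabStream-1.0-jar-with-dependencies.jar, so a run might look like this
(the input and output paths are placeholders, not from this repository):

hadoop jar CollabStream-1.0-jar-with-dependencies.jar comparison.dsgd.DSGDMain /data/ml1m /output/ml1m 10
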
/commands.txt:
--------------------------------------------------------------------------------
1 | # Some useful commands
2 |
3 | mvn exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile
4 | mvn exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile | grep '########'
5 | mvn compile exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile
6 | mvn compile exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile | grep '########'
7 |
8 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
9 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
10 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
11 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
12 |
13 | mvn exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='4 5 3 data/input/predtest data/output/predtest.user data/output/predtest.item' -Dexec.classpathScope=compile
14 |
15 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_tr_rand.txt data/output/MovieLens/ml.user data/output/MovieLens/ml.item' -Dexec.classpathScope=compile
16 |
17 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_100k data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
18 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='6041 3953 10 data/input/MovieLens/ml1m_te_rb.dat data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
--------------------------------------------------------------------------------
/final-draft.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MrChrisJohnson/CollabStream/b54dd4d5f0a098bca63a8de9275417a2ec892b6d/final-draft.pdf
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0">
3 |   <modelVersion>4.0.0</modelVersion>
4 |
5 |   <groupId>collabstream</groupId>
6 |   <artifactId>CollabStream</artifactId>
7 |   <version>1.0</version>
8 |   <packaging>jar</packaging>
9 |
10 |   <name>CollabStream</name>
11 |   <url>https://github.com/MrChrisJohnson/CollabStream</url>
12 |
13 |   <properties>
14 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
15 |   </properties>
16 |
17 |   <repositories>
18 |     <repository>
19 |       <id>clojars.org</id>
20 |       <url>http://clojars.org/repo</url>
21 |     </repository>
22 |   </repositories>
23 |
24 |   <dependencies>
25 |     <dependency>
26 |       <groupId>junit</groupId>
27 |       <artifactId>junit</artifactId>
28 |       <version>3.8.1</version>
29 |       <scope>compile</scope>
30 |     </dependency>
31 |     <dependency>
32 |       <groupId>storm</groupId>
33 |       <artifactId>storm</artifactId>
34 |       <version>0.5.4</version>
35 |       <scope>provided</scope>
36 |     </dependency>
37 |     <dependency>
38 |       <groupId>log4j</groupId>
39 |       <artifactId>log4j</artifactId>
40 |       <version>1.2.14</version>
41 |     </dependency>
42 |     <dependency>
43 |       <groupId>org.apache.hadoop</groupId>
44 |       <artifactId>hadoop-core</artifactId>
45 |       <version>0.20.2</version>
46 |     </dependency>
47 |   </dependencies>
48 |
49 |   <build>
50 |     <plugins>
51 |       <plugin>
52 |         <artifactId>maven-assembly-plugin</artifactId>
53 |         <version>2.2.1</version>
54 |         <configuration>
55 |           <descriptorRefs>
56 |             <descriptorRef>jar-with-dependencies</descriptorRef>
57 |           </descriptorRefs>
58 |         </configuration>
59 |         <executions>
60 |           <execution>
61 |             <id>make-assembly</id>
62 |             <phase>package</phase>
63 |             <goals>
64 |               <goal>single</goal>
65 |             </goals>
66 |           </execution>
67 |         </executions>
68 |       </plugin>
69 |     </plugins>
70 |   </build>
71 | </project>
--------------------------------------------------------------------------------
/src/main/java/collabstream/App.java:
--------------------------------------------------------------------------------
1 | package collabstream;
2 |
3 | /**
4 | * Hello world!
5 | *
6 | */
7 | public class App
8 | {
9 | public static void main( String[] args )
10 | {
11 | System.out.println( "Hello World!" );
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/CounterBolt.java:
--------------------------------------------------------------------------------
1 | package collabstream.lc;
2 |
3 | import java.util.HashSet;
4 | import java.util.Map;
5 |
6 | import backtype.storm.task.OutputCollector;
7 | import backtype.storm.task.TopologyContext;
8 | import backtype.storm.topology.IRichBolt;
9 | import backtype.storm.topology.OutputFieldsDeclarer;
10 | import backtype.storm.tuple.Fields;
11 | import backtype.storm.tuple.Tuple;
12 | import backtype.storm.tuple.Values;
13 |
14 | public class CounterBolt implements IRichBolt {
15 | private OutputCollector collector;
16 | private HashSet<Integer> lineNumSet = new HashSet<Integer>();
17 |
18 |
19 | public void prepare(Map config, TopologyContext context, OutputCollector collector) {
20 | this.collector = collector;
21 | }
22 |
23 | public void cleanup() {
24 | System.out.println("######## CounterBolt.cleanup: counted " + lineNumSet.size() + " lines");
25 | }
26 |
27 | public void execute(Tuple tuple) {
28 | String line = tuple.getStringByField("line");
29 | Integer lineNum = tuple.getIntegerByField("lineNum");
30 | System.out.println("######## CounterBolt.execute: " + lineNum + ". " + line);
31 | lineNumSet.add(lineNum);
32 | collector.emit(tuple, new Values(lineNum));
33 | collector.ack(tuple);
34 | }
35 |
36 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
37 | declarer.declare(new Fields("lineNum"));
38 | }
39 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/FileReaderSpout.java:
--------------------------------------------------------------------------------
1 | package collabstream.lc;
2 |
3 | import java.io.FileNotFoundException;
4 | import java.io.FileReader;
5 | import java.io.IOException;
6 | import java.io.LineNumberReader;
7 | import java.util.Map;
8 |
9 | import backtype.storm.spout.SpoutOutputCollector;
10 | import backtype.storm.task.TopologyContext;
11 | import backtype.storm.topology.IRichSpout;
12 | import backtype.storm.topology.OutputFieldsDeclarer;
13 | import backtype.storm.tuple.Fields;
14 | import backtype.storm.tuple.Values;
15 |
16 | public class FileReaderSpout implements IRichSpout {
17 | private SpoutOutputCollector collector;
18 | private String fileName;
19 | private LineNumberReader in;
20 |
21 | public FileReaderSpout(String fileName) {
22 | this.fileName = fileName;
23 | }
24 |
25 | public boolean isDistributed() {
26 | return false;
27 | }
28 |
29 | public void open(Map config, TopologyContext context, SpoutOutputCollector collector) {
30 | this.collector = collector;
31 | try {
32 | in = new LineNumberReader(new FileReader(fileName));
33 | } catch (FileNotFoundException e){
34 | System.err.println("######## FileReaderSpout.open: " + e);
35 | }
36 | }
37 |
38 | public void close() {
39 | if (in == null) return;
40 | try {
41 | in.close();
42 | } catch (IOException e) {
43 | System.err.println("######## FileReaderSpout.close: " + e);
44 | }
45 | }
46 |
47 | public void ack(Object msgId) {
48 | System.out.println("######## FileReaderSpout.ack: msgId=" + msgId);
49 | }
50 |
51 | public void fail(Object msgId) {
52 | System.out.println("######## FileReaderSpout.fail: msgId=" + msgId);
53 | }
54 |
55 | public void nextTuple() {
56 | if (in == null) return;
57 | String line = null;
58 |
59 | try {
60 | line = in.readLine();
61 | } catch (IOException e) {
62 | System.err.println("######## FileReaderSpout.nextTuple: " + e);
63 | }
64 |
65 | if (line != null) {
66 | int lineNum = in.getLineNumber();
67 | System.out.println("######## FileReaderSpout.nextTuple: " + lineNum + ". " + line);
68 | collector.emit(new Values(lineNum, line), lineNum);
69 | }
70 | }
71 |
72 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
73 | declarer.declare(new Fields("lineNum", "line"));
74 | }
75 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/LineCountTopology.java:
--------------------------------------------------------------------------------
1 | package collabstream.lc;
2 |
3 | import backtype.storm.Config;
4 | import backtype.storm.LocalCluster;
5 | import backtype.storm.StormSubmitter;
6 | import backtype.storm.topology.TopologyBuilder;
7 | import backtype.storm.utils.Utils;
8 |
9 | /**
10 | * Simple line-counting program to test the functionality of Storm. Instantiates four classes: FileReaderSpout,
11 | * CounterBolt, MsgForwarder, and MsgTracker. FileReaderSpout reads lines from a file and emits them. CounterBolt
12 | * counts the lines received from FileReaderSpout and emits their corresponding line numbers. MsgForwarder just forwards
13 | * the lines received from FileReaderSpout and the line numbers received from CounterBolt. MsgTracker receives the
14 | * messages forwarded by MsgForwarder and keeps track of which lines were read and which lines were counted.
15 | */
16 | public class LineCountTopology {
17 | public static void main(String[] args) throws Exception {
18 | if (args.length < 2) {
19 | System.err.println("######## Wrong number of arguments");
20 | System.err.println("######## required args: local|production fileName");
21 | return;
22 | }
23 |
24 | Config config = new Config();
25 | TopologyBuilder builder = new TopologyBuilder();
26 | builder.setSpout(1, new FileReaderSpout(args[1]));
27 | builder.setBolt(2, new CounterBolt()).shuffleGrouping(1);
28 | builder.setBolt(3, new MsgForwarder(1,2)).shuffleGrouping(1).shuffleGrouping(2);
29 | builder.setBolt(4, new MsgTracker(1,2)).shuffleGrouping(3,1).shuffleGrouping(3,2);
30 |
31 | System.out.println("######## LineCountTopology.main: submitting topology");
32 |
33 | if ("local".equals(args[0])) {
34 | LocalCluster cluster = new LocalCluster();
35 | cluster.submitTopology("line-count", config, builder.createTopology());
36 | System.out.println("######## LineCountTopology.main: sleeping for 10 secs");
37 | Utils.sleep(10000);
38 | cluster.shutdown();
39 | } else {
40 | StormSubmitter.submitTopology("line-count", config, builder.createTopology());
41 | }
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/MsgForwarder.java:
--------------------------------------------------------------------------------
1 | package collabstream.lc;
2 |
3 | import java.util.Map;
4 |
5 | import backtype.storm.task.OutputCollector;
6 | import backtype.storm.task.TopologyContext;
7 | import backtype.storm.topology.IRichBolt;
8 | import backtype.storm.topology.OutputFieldsDeclarer;
9 | import backtype.storm.tuple.Fields;
10 | import backtype.storm.tuple.Tuple;
11 |
12 | public class MsgForwarder implements IRichBolt {
13 | private OutputCollector collector;
14 | public int fileReaderId, counterId;
15 |
16 | public MsgForwarder(int fileReaderId, int counterId) {
17 | this.fileReaderId = fileReaderId;
18 | this.counterId = counterId;
19 | }
20 |
21 | public void prepare(Map config, TopologyContext context, OutputCollector collector) {
22 | this.collector = collector;
23 | }
24 |
25 | public void cleanup() {
26 | }
27 |
28 | public void execute(Tuple tuple) {
29 | // The ID of the component that emitted the tuple is mapped directly to the ID of the forwarding stream.
30 | collector.emit(tuple.getSourceComponent(), tuple.getValues());
31 | collector.ack(tuple);
32 | }
33 |
34 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
35 | declarer.declareStream(fileReaderId, new Fields("lineNum", "line"));
36 | declarer.declareStream(counterId, new Fields("lineNum"));
37 | }
38 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/MsgTracker.java:
--------------------------------------------------------------------------------
1 | package collabstream.lc;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.HashMap;
6 | import java.util.Map;
7 |
8 | import backtype.storm.task.OutputCollector;
9 | import backtype.storm.task.TopologyContext;
10 | import backtype.storm.topology.IRichBolt;
11 | import backtype.storm.topology.OutputFieldsDeclarer;
12 | import backtype.storm.tuple.Tuple;
13 |
14 | public class MsgTracker implements IRichBolt {
15 | private OutputCollector collector;
16 | private int fileReaderId, counterId;
17 | private HashMap<Integer, Record> lnToRec = new HashMap<Integer, Record>();
18 |
19 | private static class Record {
20 | boolean read = false;
21 | boolean counted = false;
22 |
23 | Record() {}
24 |
25 | public String toString() {
26 | return "read=" + (read ? 1 : 0) + ", counted=" + (counted ? 1 : 0);
27 | }
28 | }
29 |
30 | public MsgTracker(int fileReaderId, int counterId) {
31 | this.fileReaderId = fileReaderId;
32 | this.counterId = counterId;
33 | }
34 |
35 | public void prepare(Map config, TopologyContext context, OutputCollector collector) {
36 | this.collector = collector;
37 | }
38 |
39 | public void cleanup() {
40 | ArrayList<Integer> keys = new ArrayList<Integer>(lnToRec.keySet());
41 | Collections.sort(keys);
42 | for (Integer lineNum : keys) {
43 | System.out.println("######## MsgTracker.cleanup: " + lineNum + ". " + lnToRec.get(lineNum));
44 | }
45 | }
46 |
47 | public void execute(Tuple tuple) {
48 | Integer lineNum = tuple.getIntegerByField("lineNum");
49 | Record rec = lnToRec.get(lineNum);
50 | if (rec == null) {
51 | rec = new Record();
52 | lnToRec.put(lineNum, rec);
53 | }
54 |
55 | int streamId = tuple.getSourceStreamId();
56 | if (streamId == fileReaderId) {
57 | rec.read = true;
58 | } else if (streamId == counterId) {
59 | rec.counted = true;
60 | }
61 | collector.ack(tuple);
62 | }
63 |
64 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
65 | }
66 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/BlockPair.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.io.Serializable;
7 |
8 | import backtype.storm.serialization.ISerialization;
9 |
10 | public class BlockPair implements Serializable {
11 | public final int userBlockIdx, itemBlockIdx;
12 |
13 | public BlockPair(int userBlockIdx, int itemBlockIdx) {
14 | this.userBlockIdx = userBlockIdx;
15 | this.itemBlockIdx = itemBlockIdx;
16 | }
17 |
18 | public String toString() {
19 | return "(" + userBlockIdx + "," + itemBlockIdx + ")";
20 | }
21 |
22 | public boolean equals(Object obj) {
23 | if (obj instanceof BlockPair) {
24 | BlockPair bp = (BlockPair)obj;
25 | return this.userBlockIdx == bp.userBlockIdx && this.itemBlockIdx == bp.itemBlockIdx;
26 | } else {
27 | return false;
28 | }
29 | }
30 |
31 | public int hashCode() {
32 | int userHi = userBlockIdx >>> 16;
33 | int userLo = userBlockIdx & 0xFFFF;
34 | int itemHi = itemBlockIdx >>> 16;
35 | int itemLo = itemBlockIdx & 0xFFFF;
36 |
37 | return ((userLo ^ itemHi) << 16) | (userHi ^ itemLo);
38 | }
39 |
40 | public static class Serialization implements ISerialization<BlockPair> {
41 | public boolean accept(Class c) {
42 | return BlockPair.class.equals(c);
43 | }
44 |
45 | public void serialize(BlockPair bp, DataOutputStream out) throws IOException {
46 | out.writeInt(bp.userBlockIdx);
47 | out.writeInt(bp.itemBlockIdx);
48 | }
49 |
50 | public BlockPair deserialize(DataInputStream in) throws IOException {
51 | return new BlockPair(in.readInt(), in.readInt());
52 | }
53 | }
54 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/Configuration.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.Serializable;
4 |
5 | public class Configuration implements Serializable {
6 | public final int numUsers, numItems, numLatent;
7 | public final int numUserBlocks, numItemBlocks;
8 | public final float userPenalty, itemPenalty, initialStepSize;
9 | public final int maxTrainingIters;
10 | public final String inputFilename, userOutputFilename, itemOutputFilename;
11 | public final long inputDelay; // delay between sending two training examples in milliseconds
12 | public final boolean debug;
13 |
14 | private final int smallUserBlockSize, smallItemBlockSize;
15 | private final int bigUserBlockSize, bigItemBlockSize;
16 | private final int numBigUserBlocks, numBigItemBlocks;
17 | private final int userBlockThreshold, itemBlockThreshold;
18 |
19 | public Configuration(int numUsers, int numItems, int numLatent, int numUserBlocks, int numItemBlocks,
20 | float userPenalty, float itemPenalty, float initialStepSize, int maxTrainingIters,
21 | String inputFilename, String userOutputFilename, String itemOutputFilename,
22 | long inputDelay, boolean debug) {
23 | this.numUsers = numUsers;
24 | this.numItems = numItems;
25 | this.numLatent = numLatent;
26 | this.numUserBlocks = numUserBlocks;
27 | this.numItemBlocks = numItemBlocks;
28 | this.userPenalty = userPenalty;
29 | this.itemPenalty = itemPenalty;
30 | this.initialStepSize = initialStepSize;
31 | this.maxTrainingIters = maxTrainingIters;
32 | this.inputFilename = inputFilename;
33 | this.userOutputFilename = userOutputFilename;
34 | this.itemOutputFilename = itemOutputFilename;
35 | this.inputDelay = inputDelay;
36 | this.debug = debug;
37 |
38 | smallUserBlockSize = numUsers / numUserBlocks;
39 | bigUserBlockSize = smallUserBlockSize + 1;
40 | numBigUserBlocks = numUsers % numUserBlocks;
41 | userBlockThreshold = bigUserBlockSize * numBigUserBlocks;
42 |
43 | smallItemBlockSize = numItems / numItemBlocks;
44 | bigItemBlockSize = smallItemBlockSize + 1;
45 | numBigItemBlocks = numItems % numItemBlocks;
46 | itemBlockThreshold = bigItemBlockSize * numBigItemBlocks;
47 | }
48 |
49 | public int getNumWorkers() {
50 | return numUserBlocks;
51 | }
52 |
53 | public int getNumProcesses() {
54 | // numUserBlocks (Worker) + numUserBlocks (MatrixStore) + numItemBlocks (MatrixStore) + 1 (Master) + 1 (RatingsSource)
55 | return 2*numUserBlocks + numItemBlocks + 2;
56 | }
57 |
58 | public int getUserBlockIdx(int userId) {
59 | if (userId < userBlockThreshold) {
60 | return userId / bigUserBlockSize;
61 | } else {
62 | return numBigUserBlocks + (userId - userBlockThreshold) / smallUserBlockSize;
63 | }
64 | }
65 |
66 | public int getItemBlockIdx(int itemId) {
67 | if (itemId < itemBlockThreshold) {
68 | return itemId / bigItemBlockSize;
69 | } else {
70 | return numBigItemBlocks + (itemId - itemBlockThreshold) / smallItemBlockSize;
71 | }
72 | }
73 |
74 | public int getUserBlockLength(int userBlockIdx) {
75 | return (userBlockIdx < numBigUserBlocks) ? bigUserBlockSize : smallUserBlockSize;
76 | }
77 |
78 | public int getItemBlockLength(int itemBlockIdx) {
79 | return (itemBlockIdx < numBigItemBlocks) ? bigItemBlockSize : smallItemBlockSize;
80 | }
81 |
82 | public int getUserBlockStart(int userBlockIdx) {
83 | if (userBlockIdx < numBigUserBlocks) {
84 | return bigUserBlockSize * userBlockIdx;
85 | } else {
86 | return userBlockThreshold + (userBlockIdx - numBigUserBlocks) * smallUserBlockSize;
87 | }
88 | }
89 |
90 | public int getItemBlockStart(int itemBlockIdx) {
91 | if (itemBlockIdx < numBigItemBlocks) {
92 | return bigItemBlockSize * itemBlockIdx;
93 | } else {
94 | return itemBlockThreshold + (itemBlockIdx - numBigItemBlocks) * smallItemBlockSize;
95 | }
96 | }
97 | }
--------------------------------------------------------------------------------
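
The block arithmetic in Configuration splits numUsers into numUserBlocks contiguous
blocks whose sizes differ by at most one: the first numBigUserBlocks blocks hold
bigUserBlockSize users each, the rest hold smallUserBlockSize. A minimal sketch of
the arithmetic, assuming the 17 users from commands.txt and the default 10 user
blocks (illustration only, not repository code):

int numUsers = 17, numUserBlocks = 10;
int small = numUsers / numUserBlocks;  // 1
int big = small + 1;                   // 2
int numBig = numUsers % numUserBlocks; // 7
int threshold = big * numBig;          // 14: users 0..13 fill seven blocks of 2
// getUserBlockIdx(13) == 13 / 2 == 6
// getUserBlockIdx(14) == 7 + (14 - 14) / 1 == 7; users 14..16 get blocks of 1
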
/src/main/java/collabstream/streaming/Master.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.BufferedWriter;
4 | import java.io.FileWriter;
5 | import java.io.IOException;
6 | import java.io.PrintWriter;
7 | import java.util.ArrayList;
8 | import java.util.HashSet;
9 | import java.util.LinkedList;
10 | import java.util.Map;
11 | import java.util.Queue;
12 | import java.util.Random;
13 | import java.util.Set;
14 |
15 | import org.apache.commons.lang.time.DurationFormatUtils;
16 |
17 | import backtype.storm.task.OutputCollector;
18 | import backtype.storm.task.TopologyContext;
19 | import backtype.storm.topology.IRichBolt;
20 | import backtype.storm.topology.OutputFieldsDeclarer;
21 | import backtype.storm.tuple.Fields;
22 | import backtype.storm.tuple.Tuple;
23 | import backtype.storm.tuple.Values;
24 |
25 | import static collabstream.streaming.MsgType.*;
26 |
27 | public class Master implements IRichBolt {
28 | private OutputCollector collector;
29 | private final Configuration config;
30 | private PrintWriter userOutput, itemOutput;
31 | private BlockPair[][] blockPair;
32 | private TrainingExample[][] latestExample;
33 | private Set<BlockPair> unfinished = new HashSet<BlockPair>();
34 | private Set<BlockPair> freeSet = new HashSet<BlockPair>();
35 | private Queue<BlockPair> userBlockQueue = new LinkedList<BlockPair>();
36 | private Queue<BlockPair> itemBlockQueue = new LinkedList<BlockPair>();
37 | private boolean endOfData = false;
38 | private long startTime, outputStartTime = 0;
39 | private final Random random = new Random();
40 |
41 | public Master(Configuration config) {
42 | this.config = config;
43 | blockPair = new BlockPair[config.numUserBlocks][config.numItemBlocks];
44 | latestExample = new TrainingExample[config.numUserBlocks][config.numItemBlocks];
45 | }
46 |
47 | public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) {
48 | this.collector = collector;
49 | startTime = System.currentTimeMillis();
50 | System.out.printf("######## Training started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
51 | }
52 |
53 | public void cleanup() {
54 | }
55 |
56 | public void execute(Tuple tuple) {
57 | MsgType msgType = (MsgType)tuple.getValue(0);
58 | if (config.debug && msgType != END_OF_DATA && msgType != PROCESS_BLOCK_FIN) {
59 | System.out.println("######## Master.execute: " + msgType + " " + tuple.getValue(1));
60 | }
61 | TrainingExample ex, latest;
62 | BlockPair bp, head;
63 |
64 | switch (msgType) {
65 | case END_OF_DATA:
66 | if (config.debug) {
67 | System.out.println("######## Master.execute: " + msgType);
68 | }
69 | endOfData = true;
70 | collector.ack(tuple);
71 | distributeWork();
72 | break;
73 | case TRAINING_EXAMPLE:
74 | ex = (TrainingExample)tuple.getValue(1);
75 | int userBlockIdx = config.getUserBlockIdx(ex.userId);
76 | int itemBlockIdx = config.getItemBlockIdx(ex.itemId);
77 |
78 | latest = latestExample[userBlockIdx][itemBlockIdx];
79 | if (latest == null || latest.timestamp < ex.timestamp) {
80 | latestExample[userBlockIdx][itemBlockIdx] = ex;
81 | }
82 |
83 | bp = blockPair[userBlockIdx][itemBlockIdx];
84 | if (bp == null) {
85 | bp = blockPair[userBlockIdx][itemBlockIdx] = new BlockPair(userBlockIdx, itemBlockIdx);
86 | unfinished.add(bp);
87 | freeSet.add(bp);
88 | }
89 |
90 | collector.emit(tuple, new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
91 | collector.ack(tuple);
92 | distributeWork();
93 | break;
94 | case PROCESS_BLOCK_FIN:
95 | bp = (BlockPair)tuple.getValue(1);
96 | ex = (TrainingExample)tuple.getValue(2);
97 | if (config.debug) {
98 | System.out.println("######## Master.execute: " + msgType + " " + bp + " " + ex);
99 | }
100 |
101 | latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
102 | if (latest.timestamp == ex.timestamp) {
103 | latest.numTrainingIters = ex.numTrainingIters;
104 | if (endOfData && latest.numTrainingIters >= config.maxTrainingIters) {
105 | unfinished.remove(bp);
106 | }
107 | }
108 |
109 | // numTrainingIters must be updated before free() is called to prevent finished blocks
110 | // from being added to the freeSet
111 | free(bp.userBlockIdx, bp.itemBlockIdx);
112 |
113 | if (unfinished.isEmpty()) {
114 | startOutput();
115 | } else {
116 | distributeWork();
117 | }
118 | break;
119 | case USER_BLOCK:
120 | bp = (BlockPair)tuple.getValue(1);
121 | float[][] userBlock = (float[][])tuple.getValue(2);
122 | head = userBlockQueue.remove();
123 | if (!head.equals(bp)) {
124 | throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + USER_BLOCK);
125 | }
126 | writeUserBlock(bp.userBlockIdx, userBlock);
127 | requestNextUserBlock();
128 | break;
129 | case ITEM_BLOCK:
130 | bp = (BlockPair)tuple.getValue(1);
131 | float[][] itemBlock = (float[][])tuple.getValue(2);
132 | head = itemBlockQueue.remove();
133 | if (!head.equals(bp)) {
134 | throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + ITEM_BLOCK);
135 | }
136 | writeItemBlock(bp.itemBlockIdx, itemBlock);
137 | requestNextItemBlock();
138 | break;
139 | }
140 | }
141 |
142 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
143 | // Field "userBlockIdx" used solely for grouping
144 | declarer.declare(new Fields("msgType", "blockPair", "example", "userBlockIdx"));
145 | }
146 |
147 | private void lock(int userBlockIdx, int itemBlockIdx) {
148 | for (int j = 0; j < config.numItemBlocks; ++j) {
149 | BlockPair bp = blockPair[userBlockIdx][j];
150 | if (bp != null) {
151 | freeSet.remove(bp);
152 | }
153 | }
154 | for (int i = 0; i < config.numUserBlocks; ++i) {
155 | BlockPair bp = blockPair[i][itemBlockIdx];
156 | if (bp != null) {
157 | freeSet.remove(bp);
158 | }
159 | }
160 | }
161 |
162 | private void free(int userBlockIdx, int itemBlockIdx) {
163 | for (int j = 0; j < config.numItemBlocks; ++j) {
164 | BlockPair bp = blockPair[userBlockIdx][j];
165 | if (bp != null) {
166 | if (latestExample[userBlockIdx][j].numTrainingIters < config.maxTrainingIters) {
167 | freeSet.add(bp);
168 | }
169 | }
170 | }
171 | for (int i = 0; i < config.numUserBlocks; ++i) {
172 | BlockPair bp = blockPair[i][itemBlockIdx];
173 | if (bp != null) {
174 | if (latestExample[i][itemBlockIdx].numTrainingIters < config.maxTrainingIters) {
175 | freeSet.add(bp);
176 | }
177 | }
178 | }
179 | }
180 |
181 | private void distributeWork() {
182 | ArrayList<BlockPair> freeList = new ArrayList<BlockPair>(freeSet.size());
183 | while (!freeSet.isEmpty()) {
184 | freeList.addAll(freeSet);
185 | int i = random.nextInt(freeList.size());
186 | BlockPair bp = freeList.get(i);
187 | lock(bp.userBlockIdx, bp.itemBlockIdx);
188 | collector.emit(new Values(PROCESS_BLOCK_REQ, bp, null, bp.userBlockIdx));
189 | freeList.clear();
190 | }
191 | }
192 |
193 | private void startOutput() {
194 | try {
195 | if (outputStartTime > 0) return;
196 | outputStartTime = System.currentTimeMillis();
197 | System.out.printf("######## Training finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
198 | System.out.println("######## Elapsed training time: "
199 | + DurationFormatUtils.formatPeriod(startTime, outputStartTime, "H:m:s") + " (h:m:s)");
200 | System.out.printf("######## Output started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
201 |
202 | userOutput = new PrintWriter(new BufferedWriter(new FileWriter(config.userOutputFilename)));
203 | itemOutput = new PrintWriter(new BufferedWriter(new FileWriter(config.itemOutputFilename)));
204 |
205 | for (int i = 0; i < config.numUserBlocks; ++i) {
206 | // Add the first block in row i to userBlockQueue
207 | for (int j = 0; j < config.numItemBlocks; ++j) {
208 | BlockPair bp = blockPair[i][j];
209 | if (bp != null) {
210 | userBlockQueue.add(bp);
211 | break;
212 | }
213 | }
214 | }
215 | for (int j = 0; j < config.numItemBlocks; ++j) {
216 | // Add the first block in column j to itemBlockQueue
217 | for (int i = 0; i < config.numUserBlocks; ++i) {
218 | BlockPair bp = blockPair[i][j];
219 | if (bp != null) {
220 | itemBlockQueue.add(bp);
221 | break;
222 | }
223 | }
224 | }
225 |
226 | requestNextUserBlock();
227 | requestNextItemBlock();
228 | } catch (IOException e) {
229 | System.err.println("######## Master.startOutput: " + e);
230 | }
231 | }
232 |
233 | private void writeUserBlock(int userBlockIdx, float[][] userBlock) {
234 | int userBlockStart = config.getUserBlockStart(userBlockIdx);
235 | for (int i = 0; i < userBlock.length; ++i) {
236 | userOutput.print(userBlockStart + i);
237 | for (int k = 0; k < config.numLatent; ++k) {
238 | userOutput.print(' ');
239 | userOutput.print(userBlock[i][k]);
240 | }
241 | userOutput.println();
242 | }
243 | }
244 |
245 | private void writeItemBlock(int itemBlockIdx, float[][] itemBlock) {
246 | int itemBlockStart = config.getItemBlockStart(itemBlockIdx);
247 | for (int j = 0; j < itemBlock.length; ++j) {
248 | itemOutput.print(itemBlockStart + j);
249 | for (int k = 0; k < config.numLatent; ++k) {
250 | itemOutput.print(' ');
251 | itemOutput.print(itemBlock[j][k]);
252 | }
253 | itemOutput.println();
254 | }
255 | }
256 |
257 | private void endOutput() {
258 | long endTime = System.currentTimeMillis();
259 | System.out.printf("######## Output finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
260 | System.out.println("######## Elapsed output time: "
261 | + DurationFormatUtils.formatPeriod(outputStartTime, endTime, "H:m:s") + " (h:m:s)");
262 | System.out.println("######## Total elapsed time: "
263 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
264 | }
265 |
266 | private void requestNextUserBlock() {
267 | if (userBlockQueue.isEmpty()) {
268 | if (userOutput != null) {
269 | userOutput.close();
270 | }
271 | if (itemBlockQueue.isEmpty()) {
272 | endOutput();
273 | }
274 | } else {
275 | BlockPair bp = userBlockQueue.peek();
276 | collector.emit(new Values(USER_BLOCK_REQ, bp, null, bp.userBlockIdx));
277 | }
278 | }
279 |
280 | private void requestNextItemBlock() {
281 | if (itemBlockQueue.isEmpty()) {
282 | if (itemOutput != null) {
283 | itemOutput.close();
284 | }
285 | if (userBlockQueue.isEmpty()) {
286 | endOutput();
287 | }
288 | } else {
289 | BlockPair bp = itemBlockQueue.peek();
290 | collector.emit(new Values(ITEM_BLOCK_REQ, bp, null, bp.userBlockIdx));
291 | }
292 | }
293 | }
--------------------------------------------------------------------------------
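
lock() and free() above enforce the DSGD interchangeability constraint: while a
block pair is being processed, every pair sharing its user-block row or its
item-block column is kept out of freeSet, so distributeWork() only dispatches
mutually independent pairs. A hypothetical predicate (not in the repository)
stating that invariant:

// Two block pairs may train concurrently only if they share neither a user
// block nor an item block.
static boolean independent(BlockPair a, BlockPair b) {
    return a.userBlockIdx != b.userBlockIdx && a.itemBlockIdx != b.itemBlockIdx;
}
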
/src/main/java/collabstream/streaming/MatrixSerialization.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.nio.ByteBuffer;
7 | import java.nio.FloatBuffer;
8 |
9 | import backtype.storm.serialization.ISerialization;
10 |
11 | public class MatrixSerialization implements ISerialization<float[][]> {
12 | public boolean accept(Class c) {
13 | return float[][].class.equals(c);
14 | }
15 |
16 | public void serialize(float[][] matrix, DataOutputStream out) throws IOException {
17 | if (matrix == null) return;
18 | int numRows = matrix.length;
19 | int numCols = (numRows == 0) ? 0 : matrix[0].length;
20 |
21 | out.writeInt(numRows);
22 | out.writeInt(numCols);
23 |
24 | byte[] arr = new byte[4*numCols];
25 | ByteBuffer bbuf = ByteBuffer.wrap(arr);
26 | FloatBuffer fbuf = bbuf.asFloatBuffer();
27 |
28 | for (int i = 0; i < numRows; ++i) {
29 | if (matrix[i].length != numCols) {
30 | throw new Error("Rows of matrix have different sizes");
31 | }
32 | fbuf.put(matrix[i]);
33 | out.write(arr);
34 | fbuf.clear();
35 | }
36 | }
37 |
38 | public float[][] deserialize(DataInputStream in) throws IOException {
39 | int numRows = in.readInt();
40 | int numCols = in.readInt();
41 | float[][] matrix = new float[numRows][numCols];
42 |
43 | byte[] arr = new byte[4*numCols];
44 | ByteBuffer bbuf = ByteBuffer.wrap(arr);
45 | FloatBuffer fbuf = bbuf.asFloatBuffer();
46 |
47 | for (int i = 0; i < numRows; ++i) {
48 | in.readFully(arr);
49 | fbuf.get(matrix[i]);
50 | fbuf.clear();
51 | }
52 |
53 | return matrix;
54 | }
55 | }
--------------------------------------------------------------------------------
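
A minimal round-trip sketch of the serialization above, assuming plain java.io
byte-array streams (test harness only, not repository code):

MatrixSerialization ser = new MatrixSerialization();
ByteArrayOutputStream bytes = new ByteArrayOutputStream();
ser.serialize(new float[][] {{1f, 2f}, {3f, 4f}}, new DataOutputStream(bytes));
float[][] copy = ser.deserialize(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
// copy[1][0] == 3f; each row travels as 4*numCols bytes through the FloatBuffer view
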
/src/main/java/collabstream/streaming/MatrixStore.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.util.Map;
4 | import java.util.concurrent.ConcurrentHashMap;
5 |
6 | import backtype.storm.task.OutputCollector;
7 | import backtype.storm.task.TopologyContext;
8 | import backtype.storm.topology.IRichBolt;
9 | import backtype.storm.topology.OutputFieldsDeclarer;
10 | import backtype.storm.tuple.Fields;
11 | import backtype.storm.tuple.Tuple;
12 | import backtype.storm.tuple.Values;
13 |
14 | import static collabstream.streaming.MsgType.*;
15 |
16 | public class MatrixStore implements IRichBolt {
17 | private OutputCollector collector;
18 | private final Configuration config;
19 | private final Map<Integer, float[][]> userBlockMap = new ConcurrentHashMap<Integer, float[][]>();
20 | private final Map<Integer, float[][]> itemBlockMap = new ConcurrentHashMap<Integer, float[][]>();
21 |
22 | public MatrixStore(Configuration config) {
23 | this.config = config;
24 | }
25 |
26 | public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) {
27 | this.collector = collector;
28 | }
29 |
30 | public void cleanup() {
31 | }
32 |
33 | public void execute(Tuple tuple) {
34 | MsgType msgType = (MsgType)tuple.getValue(0);
35 | BlockPair bp = (BlockPair)tuple.getValue(1);
36 | Integer taskIdObj = (Integer)tuple.getValue(3);
37 | int taskId = (taskIdObj != null) ? taskIdObj.intValue() : tuple.getSourceTask();
38 | if (config.debug) {
39 | System.out.println("######## MatrixStore.execute: " + msgType + " " + bp + " [" + taskId + "]");
40 | }
41 |
42 | switch (msgType) {
43 | case USER_BLOCK_REQ:
44 | // In general, this pattern of access is not thread-safe. But since requests with the same userBlockIdx
45 | // are sent to the same thread, it should be safe in our case.
46 | float[][] userBlock = userBlockMap.get(bp.userBlockIdx);
47 | if (userBlock == null) {
48 | userBlock = MatrixUtils.generateRandomMatrix(config.getUserBlockLength(bp.userBlockIdx), config.numLatent);
49 | userBlockMap.put(bp.userBlockIdx, userBlock);
50 | }
51 | collector.emitDirect(taskId, new Values(USER_BLOCK, bp, (Object)userBlock));
52 | break;
53 | case ITEM_BLOCK_REQ:
54 | // In general, this pattern of access is not thread-safe. But since requests with the same itemBlockIdx
55 | // are sent to the same thread, it should be safe in our case.
56 | float[][] itemBlock = itemBlockMap.get(bp.itemBlockIdx);
57 | if (itemBlock == null) {
58 | itemBlock = MatrixUtils.generateRandomMatrix(config.getItemBlockLength(bp.itemBlockIdx), config.numLatent);
59 | itemBlockMap.put(bp.itemBlockIdx, itemBlock);
60 | }
61 | collector.emitDirect(taskId, new Values(ITEM_BLOCK, bp, (Object)itemBlock));
62 | break;
63 | case USER_BLOCK:
64 | userBlockMap.put(bp.userBlockIdx, (float[][])tuple.getValue(2));
65 | collector.emitDirect(tuple.getSourceTask(), new Values(USER_BLOCK_SAVED, bp, null));
66 | break;
67 | case ITEM_BLOCK:
68 | itemBlockMap.put(bp.itemBlockIdx, (float[][])tuple.getValue(2));
69 | collector.emitDirect(taskId, new Values(ITEM_BLOCK_SAVED, bp, null));
70 | break;
71 | }
72 | }
73 |
74 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
75 | declarer.declare(true, new Fields("msgType", "blockPair", "block"));
76 | }
77 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MatrixUtils.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.util.Random;
4 |
5 | public class MatrixUtils {
6 | private static final ThreadLocal<Random> localRandom = new ThreadLocal<Random>() {
7 | protected Random initialValue() { return new Random(); }
8 | };
9 |
10 | public static String toString(float[][] matrix) {
11 | if (matrix == null) return "";
12 | int numRows = matrix.length;
13 | if (numRows == 0) return "[]";
14 | int numCols = matrix[0].length;
15 |
16 | StringBuilder b = new StringBuilder(numRows*(9*numCols + 2) + 1);
17 |
18 | b.append('[');
19 | for (int i = 0; i < numRows-1; ++i) {
20 | numCols = matrix[i].length;
21 | for (int j = 0; j < numCols; ++j) {
22 | b.append(String.format(" %8.3f", matrix[i][j]));
23 | }
24 | b.append("\n ");
25 | }
26 | numCols = matrix[numRows-1].length;
27 | for (int j = 0; j < numCols; ++j) {
28 | b.append(String.format(" %8.3f", matrix[numRows-1][j]));
29 | }
30 | b.append(" ]");
31 |
32 | return b.toString();
33 | }
34 |
35 | public static float[][] generateRandomMatrix(int numRows, int numCols) {
36 | Random random = localRandom.get();
37 | float[][] matrix = new float[numRows][numCols];
38 |
39 | for (int i = 0; i < numRows; ++i) {
40 | for (int j = 0; j < numCols; ++j) {
41 | matrix[i][j] = random.nextFloat();
42 | }
43 | }
44 | return matrix;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MsgType.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | public enum MsgType {
4 | END_OF_DATA,
5 | TRAINING_EXAMPLE,
6 | PROCESS_BLOCK_REQ,
7 | PROCESS_BLOCK_FIN,
8 | USER_BLOCK_REQ,
9 | ITEM_BLOCK_REQ,
10 | USER_BLOCK,
11 | ITEM_BLOCK,
12 | USER_BLOCK_SAVED,
13 | ITEM_BLOCK_SAVED
14 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/PermutationUtils.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Random;
5 |
6 | import org.apache.commons.lang.time.DurationFormatUtils;
7 |
8 | public class PermutationUtils {
9 | private static final ThreadLocal<Random> localRandom = new ThreadLocal<Random>() {
10 | protected Random initialValue() { return new Random(); }
11 | };
12 |
13 | public static <T> T[] permute(T[] arr) {
14 | if (arr == null) return null;
15 | Random random = localRandom.get();
16 | for (int n = arr.length; n > 1; --n) {
17 | int i = random.nextInt(n);
18 | T temp = arr[i];
19 | arr[i] = arr[n-1];
20 | arr[n-1] = temp;
21 | }
22 | return arr;
23 | }
24 |
25 | public static <T> ArrayList<T> permute(ArrayList<T> list) {
26 | if (list == null) return null;
27 | Random random = localRandom.get();
28 | for (int n = list.size(); n > 1; --n) {
29 | int i = random.nextInt(n);
30 | T temp = list.get(i);
31 | list.set(i, list.get(n-1));
32 | list.set(n-1, temp);
33 | }
34 | return list;
35 | }
36 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/RatingsSource.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.FileReader;
5 | import java.io.IOException;
6 | import java.util.Map;
7 |
8 | import org.apache.commons.lang.StringUtils;
9 | import org.apache.commons.lang.time.DurationFormatUtils;
10 |
11 | import backtype.storm.spout.SpoutOutputCollector;
12 | import backtype.storm.task.TopologyContext;
13 | import backtype.storm.topology.IRichSpout;
14 | import backtype.storm.topology.OutputFieldsDeclarer;
15 | import backtype.storm.tuple.Fields;
16 | import backtype.storm.tuple.Values;
17 | import backtype.storm.utils.Utils;
18 |
19 | import static collabstream.streaming.MsgType.*;
20 |
21 | public class RatingsSource implements IRichSpout {
22 | protected SpoutOutputCollector collector;
23 | private final Configuration config;
24 | private BufferedReader input;
25 | private int sequenceNum = 0;
26 | private long inputStartTime;
27 |
28 | public RatingsSource(Configuration config) {
29 | this.config = config;
30 | }
31 |
32 | public boolean isDistributed() {
33 | return false;
34 | }
35 |
36 | public void open(Map stormConfig, TopologyContext context, SpoutOutputCollector collector) {
37 | this.collector = collector;
38 | inputStartTime = System.currentTimeMillis();
39 | System.out.printf("######## Input started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", inputStartTime);
40 | try {
41 | input = new BufferedReader(new FileReader(config.inputFilename));
42 | } catch (IOException e) {
43 | throw new RuntimeException(e);
44 | }
45 | }
46 |
47 | public void close() {
48 | }
49 |
50 | public void ack(Object msgId) {
51 | }
52 |
53 | public void fail(Object msgId) {
54 | if (config.debug) {
55 | System.err.println("######## RatingsSource.fail: Resending " + msgId);
56 | }
57 | if (msgId == END_OF_DATA) {
58 | collector.emit(new Values(END_OF_DATA, null), END_OF_DATA);
59 | } else {
60 | TrainingExample ex = (TrainingExample)msgId;
61 | collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
62 | }
63 | }
64 |
65 | public void nextTuple() {
66 | if (input == null) return;
67 | try {
68 | String line = input.readLine();
69 | if (line == null) {
70 | long inputEndTime = System.currentTimeMillis();
71 | System.out.printf("######## Input finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", inputEndTime);
72 | System.out.println("######## Elapsed input time: "
73 | + DurationFormatUtils.formatPeriod(inputStartTime, inputEndTime, "H:m:s") + " (h:m:s)");
74 |
75 | input.close();
76 | input = null;
77 | collector.emit(new Values(END_OF_DATA, null), END_OF_DATA);
78 | } else {
79 | try {
80 | String[] token = StringUtils.split(line, ' ');
81 | int userId = Integer.parseInt(token[0]);
82 | int itemId = Integer.parseInt(token[1]);
83 | float rating = Float.parseFloat(token[2]);
84 |
85 | TrainingExample ex = new TrainingExample(sequenceNum++, userId, itemId, rating);
86 | collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
87 |
88 | if (config.inputDelay > 0) {
89 | Utils.sleep(config.inputDelay);
90 | }
91 | } catch (Exception e) {
92 | System.err.println("######## RatingsSource.nextTuple: Could not parse line: " + line + "\n" + e);
93 | }
94 | }
95 | } catch (IOException e) {
96 | throw new RuntimeException(e);
97 | }
98 | }
99 |
100 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
101 | declarer.declare(new Fields("msgType", "example"));
102 | }
103 | }
--------------------------------------------------------------------------------
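
RatingsSource.nextTuple expects one space-separated "userId itemId rating" triple
per line. A hypothetical three-line input file in that format:

0 5 4.0
0 12 3.5
1 5 2.0
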
/src/main/java/collabstream/streaming/StreamingDSGD.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.File;
4 | import java.io.FileReader;
5 | import java.util.Properties;
6 |
7 | import backtype.storm.Config;
8 | import backtype.storm.LocalCluster;
9 | import backtype.storm.StormSubmitter;
10 | import backtype.storm.topology.TopologyBuilder;
11 | import backtype.storm.tuple.Fields;
12 |
13 | public class StreamingDSGD {
14 | public static void main(String[] args) throws Exception {
15 | if (args.length < 6) {
16 | System.err.println("######## Wrong number of arguments");
17 | System.err.println("######## required args: local|production numUsers numItems"
18 | + " inputFilename userOutputFilename itemOutputFilename");
19 | return;
20 | }
21 |
22 | Properties props = new Properties();
23 | File propFile = new File("data/collabstream.properties");
24 | if (propFile.exists()) {
25 | FileReader in = new FileReader(propFile);
26 | props.load(in);
27 | in.close();
28 | }
29 |
30 | int numUsers = Integer.parseInt(args[1]);
31 | int numItems = Integer.parseInt(args[2]);
32 | int numLatent = Integer.parseInt(props.getProperty("numLatent", "10"));
33 | int numUserBlocks = Integer.parseInt(props.getProperty("numUserBlocks", "10"));
34 | int numItemBlocks = Integer.parseInt(props.getProperty("numItemBlocks", "10"));
35 | float userPenalty = Float.parseFloat(props.getProperty("userPenalty", "0.1"));
36 | float itemPenalty = Float.parseFloat(props.getProperty("itemPenalty", "0.1"));
37 | float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "0.1"));
38 | int maxTrainingIters = Integer.parseInt(props.getProperty("maxTrainingIters", "30"));
39 | String inputFilename = args[3];
40 | String userOutputFilename = args[4];
41 | String itemOutputFilename = args[5];
42 | long inputDelay = Long.parseLong(props.getProperty("inputDelay", "0"));
43 | boolean debug = Boolean.parseBoolean(props.getProperty("debug", "false"));
44 |
45 | Configuration config = new Configuration(numUsers, numItems, numLatent, numUserBlocks, numItemBlocks,
46 | userPenalty, itemPenalty, initialStepSize, maxTrainingIters,
47 | inputFilename, userOutputFilename, itemOutputFilename,
48 | inputDelay, debug);
49 |
50 | Config stormConfig = new Config();
51 | stormConfig.addSerialization(TrainingExample.Serialization.class);
52 | stormConfig.addSerialization(BlockPair.Serialization.class);
53 | stormConfig.addSerialization(MatrixSerialization.class);
54 | stormConfig.setNumWorkers(config.getNumProcesses());
55 | stormConfig.setNumAckers(config.getNumWorkers()); // our notion of a worker is different from Storm's
56 |
57 | TopologyBuilder builder = new TopologyBuilder();
58 | builder.setSpout(1, new RatingsSource(config));
59 | builder.setBolt(2, new Master(config))
60 | .globalGrouping(1)
61 | .globalGrouping(3, Worker.TO_MASTER_STREAM_ID)
62 | .directGrouping(4)
63 | .directGrouping(5);
64 | builder.setBolt(3, new Worker(config), config.getNumWorkers())
65 | .fieldsGrouping(2, new Fields("userBlockIdx"))
66 | .directGrouping(4)
67 | .directGrouping(5);
68 | builder.setBolt(4, new MatrixStore(config), config.numUserBlocks)
69 | .fieldsGrouping(3, Worker.USER_BLOCK_STREAM_ID, new Fields("userBlockIdx"));
70 | builder.setBolt(5, new MatrixStore(config), config.numItemBlocks)
71 | .fieldsGrouping(3, Worker.ITEM_BLOCK_STREAM_ID, new Fields("itemBlockIdx"));
72 |
73 | System.out.println("######## StreamingDSGD.main: submitting topology");
74 |
75 | if ("local".equals(args[0])) {
76 | LocalCluster cluster = new LocalCluster();
77 | cluster.submitTopology("StreamingDSGD", stormConfig, builder.createTopology());
78 | } else {
79 | StormSubmitter.submitTopology("StreamingDSGD", stormConfig, builder.createTopology());
80 | }
81 | }
82 | }
--------------------------------------------------------------------------------
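
StreamingDSGD.main reads optional overrides from data/collabstream.properties.
A sketch of that file listing every key the code looks up, with the defaults
shown above (the file itself is hypothetical):

numLatent=10
numUserBlocks=10
numItemBlocks=10
userPenalty=0.1
itemPenalty=0.1
initialStepSize=0.1
maxTrainingIters=30
inputDelay=0
debug=false
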
/src/main/java/collabstream/streaming/TestPredictions.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.FileReader;
4 | import java.io.LineNumberReader;
5 | import java.util.HashMap;
6 | import java.util.Map;
7 |
8 | import org.apache.commons.lang.StringUtils;
9 | import org.apache.commons.lang.time.DurationFormatUtils;
10 |
11 | public class TestPredictions {
12 | public static void main(String[] args) throws Exception {
13 | if (args.length < 7) {
14 | System.err.println("######## Wrong number of arguments");
15 | System.err.println("######## required args: numUsers numItems numLatent"
16 | + " trainingFilename testFilename userFilename itemFilename");
17 | return;
18 | }
19 |
20 | long testStartTime = System.currentTimeMillis();
21 | System.out.printf("######## Testing started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", testStartTime);
22 |
23 | int numUsers = Integer.parseInt(args[0]);
24 | int numItems = Integer.parseInt(args[1]);
25 | int numLatent = Integer.parseInt(args[2]);
26 | String trainingFilename = args[3];
27 | String testFilename = args[4];
28 | String userFilename = args[5];
29 | String itemFilename = args[6];
30 |
31 | float trainingTotal = 0.0f;
32 | int trainingCount = 0;
33 |
34 | Map<Integer, Integer> userCount = new HashMap<Integer, Integer>();
35 | Map<Integer, Integer> itemCount = new HashMap<Integer, Integer>();
36 | Map<Integer, Float> userTotal = new HashMap<Integer, Float>();
37 | Map<Integer, Float> itemTotal = new HashMap<Integer, Float>();
38 |
39 | long startTime = System.currentTimeMillis();
40 | System.out.printf("######## Started reading training file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
41 |
42 | String line;
43 | LineNumberReader in = new LineNumberReader(new FileReader(trainingFilename));
44 | while ((line = in.readLine()) != null) {
45 | try {
46 | String[] token = StringUtils.split(line, ' ');
47 | int i = Integer.parseInt(token[0]);
48 | int j = Integer.parseInt(token[1]);
49 | float rating = Float.parseFloat(token[2]);
50 |
51 | trainingTotal += rating;
52 | ++trainingCount;
53 |
54 | if (userCount.containsKey(i)) {
55 | userCount.put(i, userCount.get(i) + 1);
56 | userTotal.put(i, userTotal.get(i) + rating);
57 | } else {
58 | userCount.put(i, 1);
59 | userTotal.put(i, rating);
60 | }
61 |
62 | if (itemCount.containsKey(j)) {
63 | itemCount.put(j, itemCount.get(j) + 1);
64 | itemTotal.put(j, itemTotal.get(j) + rating);
65 | } else {
66 | itemCount.put(j, 1);
67 | itemTotal.put(j, rating);
68 | }
69 | } catch (Exception e) {
70 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), trainingFilename, e);
71 | }
72 | }
73 | in.close();
74 |
75 | float trainingAvg = trainingTotal / trainingCount;
76 |
77 | long endTime = System.currentTimeMillis();
78 | System.out.printf("######## Finished reading training file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
79 | System.out.println("######## Time elapsed reading training file: "
80 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
81 |
82 | float[][] userMatrix = new float[numUsers][numLatent];
83 | for (int i = 0; i < numUsers; ++i) {
84 | for (int k = 0; k < numLatent; ++k) {
85 | userMatrix[i][k] = 0.0f;
86 | }
87 | }
88 |
89 | float[][] itemMatrix = new float[numItems][numLatent];
90 | for (int i = 0; i < numItems; ++i) {
91 | for (int k = 0; k < numLatent; ++k) {
92 | itemMatrix[i][k] = 0.0f;
93 | }
94 | }
95 |
96 | startTime = System.currentTimeMillis();
97 | System.out.printf("######## Started reading user file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
98 |
99 | in = new LineNumberReader(new FileReader(userFilename));
100 | while ((line = in.readLine()) != null) {
101 | try {
102 | String[] token = StringUtils.split(line, ' ');
103 | int i = Integer.parseInt(token[0]);
104 | for (int k = 0; k < numLatent; ++k) {
105 | userMatrix[i][k] = Float.parseFloat(token[k+1]);
106 | }
107 | } catch (Exception e) {
108 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), userFilename, e);
109 | }
110 | }
111 | in.close();
112 |
113 | endTime = System.currentTimeMillis();
114 | System.out.printf("######## Finished reading user file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
115 | System.out.println("######## Time elapsed reading user file: "
116 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
117 |
118 | startTime = System.currentTimeMillis();
119 | System.out.printf("######## Started reading item file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
120 |
121 | in = new LineNumberReader(new FileReader(itemFilename));
122 | while ((line = in.readLine()) != null) {
123 | try {
124 | String[] token = StringUtils.split(line, ' ');
125 | int j = Integer.parseInt(token[0]);
126 | for (int k = 0; k < numLatent; ++k) {
127 | itemMatrix[j][k] = Float.parseFloat(token[k+1]);
128 | }
129 | } catch (Exception e) {
130 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), itemFilename, e);
131 | }
132 | }
133 | in.close();
134 |
135 | endTime = System.currentTimeMillis();
136 | System.out.printf("######## Finished reading item file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
137 | System.out.println("######## Time elapsed reading item file: "
138 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
139 |
140 | startTime = System.currentTimeMillis();
141 | System.out.printf("######## Started reading test file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
142 |
143 | float totalSqErr = 0.0f;
144 | int numRatings = 0;
145 |
146 | in = new LineNumberReader(new FileReader(testFilename));
147 | while ((line = in.readLine()) != null) {
148 | try {
149 | String[] token = StringUtils.split(line, ' ');
150 | int i = Integer.parseInt(token[0]);
151 | int j = Integer.parseInt(token[1]);
152 | float rating = Float.parseFloat(token[2]);
153 | float prediction;
154 |
155 | boolean userKnown = userCount.containsKey(i);
156 | boolean itemKnown = itemCount.containsKey(j);
157 |
158 | if (userKnown && itemKnown) {
159 | prediction = 0.0f;
160 | for (int k = 0; k < numLatent; ++k) {
161 | prediction += userMatrix[i][k] * itemMatrix[j][k];
162 | }
163 | } else if (userKnown) {
164 | prediction = userTotal.get(i) / userCount.get(i);
165 | } else if (itemKnown) {
166 | prediction = itemTotal.get(j) / itemCount.get(j);
167 | } else {
168 | prediction = trainingAvg;
169 | }
170 |
171 | float diff = prediction - rating;
172 | totalSqErr += diff*diff;
173 | ++numRatings;
174 | } catch (Exception e) {
175 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), testFilename, e);
176 | }
177 | }
178 |
179 | double rmse = Math.sqrt(totalSqErr / numRatings);
180 |
181 | endTime = System.currentTimeMillis();
182 | System.out.printf("######## Finished reading test file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
183 | System.out.println("######## Time elapsed reading test file: "
184 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
185 | System.out.println("######## Total elapsed testing time: "
186 | + DurationFormatUtils.formatPeriod(testStartTime, endTime, "H:m:s") + " (h:m:s)");
187 | System.out.println("######## Number of ratings used: " + numRatings);
188 | System.out.println("######## RMSE: " + rmse);
189 | }
190 | }
--------------------------------------------------------------------------------
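
The test loop above predicts with four fallbacks: the factor dot product when
both the user and the item appeared in training, else the user's mean rating,
else the item's mean rating, else the global training mean. The same logic as a
standalone method (a hypothetical refactoring, not repository code):

static float predict(int i, int j, float[][] userMatrix, float[][] itemMatrix,
                     int numLatent, Float userMean, Float itemMean, float trainingAvg) {
    if (userMean != null && itemMean != null) {
        float p = 0.0f;                     // both seen in training: use the factors
        for (int k = 0; k < numLatent; ++k) p += userMatrix[i][k] * itemMatrix[j][k];
        return p;
    }
    if (userMean != null) return userMean;  // unseen item: fall back to the user mean
    if (itemMean != null) return itemMean;  // unseen user: fall back to the item mean
    return trainingAvg;                     // both unseen: global mean
}
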
/src/main/java/collabstream/streaming/TrainingExample.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.io.Serializable;
7 |
8 | import backtype.storm.serialization.ISerialization;
9 |
10 | public class TrainingExample implements Serializable {
11 | // Not a system timestamp; just a sequence number. Using type int to save memory; as a consequence,
12 | // cannot handle more than 2^31 training examples.
13 | public final int timestamp;
14 | public final int userId, itemId;
15 | public final float rating;
16 | public int numTrainingIters = 0;
17 |
18 | public TrainingExample(int timestamp, int userId, int itemId, float rating) {
19 | this.timestamp = timestamp;
20 | this.userId = userId;
21 | this.itemId = itemId;
22 | this.rating = rating;
23 | }
24 |
25 | public String toString() {
26 | return "(<" + timestamp + ">," + userId + "," + itemId + "," + rating + "," + numTrainingIters + ")";
27 | }
28 |
29 | public boolean equals(Object obj) {
30 | if (obj instanceof TrainingExample) {
31 | TrainingExample ex = (TrainingExample)obj;
32 | return this.userId == ex.userId && this.itemId == ex.itemId;
33 | } else {
34 | return false;
35 | }
36 | }
37 |
38 | public int hashCode() {
39 | int userHi = userId >>> 16;
40 | int userLo = userId & 0xFFFF;
41 | int itemHi = itemId >>> 16;
42 | int itemLo = itemId & 0xFFFF;
43 |
44 | return ((userLo ^ itemHi) << 16) | (userHi ^ itemLo);
45 | }
46 |
47 | 	public static class Serialization implements ISerialization<TrainingExample> {
48 | public boolean accept(Class c) {
49 | return TrainingExample.class.equals(c);
50 | }
51 |
52 | public void serialize(TrainingExample ex, DataOutputStream out) throws IOException {
53 | out.writeInt(ex.timestamp);
54 | out.writeInt(ex.userId);
55 | out.writeInt(ex.itemId);
56 | out.writeFloat(ex.rating);
57 | out.writeInt(ex.numTrainingIters);
58 | }
59 |
60 | public TrainingExample deserialize(DataInputStream in) throws IOException {
61 | TrainingExample ex = new TrainingExample(in.readInt(), in.readInt(), in.readInt(), in.readFloat());
62 | ex.numTrainingIters = in.readInt();
63 | return ex;
64 | }
65 | }
66 | }
--------------------------------------------------------------------------------
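
The serializer above writes each TrainingExample as five fixed-width fields (four ints and one float, 20 bytes). A minimal round-trip sketch using plain java.io streams outside any topology; the SerializationCheck class name is invented for illustration:

    import java.io.*;

    import collabstream.streaming.TrainingExample;

    public class SerializationCheck {
        public static void main(String[] args) throws IOException {
            TrainingExample original = new TrainingExample(42, 7, 13, 3.5f);
            original.numTrainingIters = 2;

            // Write timestamp, userId, itemId, rating, numTrainingIters to a buffer.
            ByteArrayOutputStream buf = new ByteArrayOutputStream();
            new TrainingExample.Serialization().serialize(original, new DataOutputStream(buf));

            // Read the fields back in the same order.
            TrainingExample copy = new TrainingExample.Serialization()
                    .deserialize(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())));

            // equals()/hashCode() consider only (userId, itemId), so this prints true.
            System.out.println(copy.equals(original));
        }
    }
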
/src/main/java/collabstream/streaming/Worker.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.util.Map;
4 | import java.util.concurrent.ConcurrentHashMap;
5 |
6 | import backtype.storm.task.OutputCollector;
7 | import backtype.storm.task.TopologyContext;
8 | import backtype.storm.topology.IRichBolt;
9 | import backtype.storm.topology.OutputFieldsDeclarer;
10 | import backtype.storm.tuple.Fields;
11 | import backtype.storm.tuple.Tuple;
12 | import backtype.storm.tuple.Values;
13 |
14 | import static collabstream.streaming.MsgType.*;
15 |
16 | public class Worker implements IRichBolt {
17 | public static final int TO_MASTER_STREAM_ID = 1;
18 | public static final int USER_BLOCK_STREAM_ID = 2;
19 | public static final int ITEM_BLOCK_STREAM_ID = 3;
20 |
21 | private OutputCollector collector;
22 | private final Configuration config;
23 | 	private final Map<BlockPair, WorkingBlock> workingBlockMap = new ConcurrentHashMap<BlockPair, WorkingBlock>();
24 |
25 | public Worker(Configuration config) {
26 | this.config = config;
27 | }
28 |
29 | public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) {
30 | this.collector = collector;
31 | }
32 |
33 | public void cleanup() {
34 | }
35 |
36 | public void execute(Tuple tuple) {
37 | MsgType msgType = (MsgType)tuple.getValue(0);
38 | if (config.debug && msgType != TRAINING_EXAMPLE) {
39 | System.out.println("######## Worker.execute: " + msgType + " " + tuple.getValue(1));
40 | }
41 | BlockPair bp;
42 | WorkingBlock workingBlock;
43 |
44 | switch (msgType) {
45 | case TRAINING_EXAMPLE:
46 | TrainingExample ex = (TrainingExample)tuple.getValue(2);
47 | if (config.debug) {
48 | System.out.println("######## Worker.execute: " + msgType + " " + ex);
49 | }
50 | bp = new BlockPair(config.getUserBlockIdx(ex.userId), config.getItemBlockIdx(ex.itemId));
51 | workingBlock = getWorkingBlock(bp);
52 | workingBlock.examples.add(ex);
53 | collector.ack(tuple);
54 | break;
55 | case PROCESS_BLOCK_REQ:
56 | bp = (BlockPair)tuple.getValue(1);
57 | workingBlock = getWorkingBlock(bp);
58 | workingBlock.waitingForBlocks = true;
59 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK_REQ, bp, null, null, bp.userBlockIdx));
60 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK_REQ, bp, null, null, bp.itemBlockIdx));
61 | break;
62 | case USER_BLOCK_REQ:
63 | // Forward request from master
64 | bp = (BlockPair)tuple.getValue(1);
65 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK_REQ, bp, null, tuple.getSourceTask(), bp.userBlockIdx));
66 | break;
67 | case ITEM_BLOCK_REQ:
68 | // Forward request from master
69 | bp = (BlockPair)tuple.getValue(1);
70 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK_REQ, bp, null, tuple.getSourceTask(), bp.itemBlockIdx));
71 | break;
72 | case USER_BLOCK:
73 | bp = (BlockPair)tuple.getValue(1);
74 | float[][] userBlock = (float[][])tuple.getValue(2);
75 | workingBlock = getWorkingBlock(bp);
76 |
77 | if (workingBlock.waitingForBlocks) {
78 | workingBlock.userBlock = userBlock;
79 | if (workingBlock.itemBlock != null) {
80 | update(bp, workingBlock);
81 | }
82 | }
83 | break;
84 | case ITEM_BLOCK:
85 | bp = (BlockPair)tuple.getValue(1);
86 | float[][] itemBlock = (float[][])tuple.getValue(2);
87 | workingBlock = getWorkingBlock(bp);
88 |
89 | if (workingBlock.waitingForBlocks) {
90 | workingBlock.itemBlock = itemBlock;
91 | if (workingBlock.userBlock != null) {
92 | update(bp, workingBlock);
93 | }
94 | }
95 | break;
96 | case USER_BLOCK_SAVED:
97 | bp = (BlockPair)tuple.getValue(1);
98 | workingBlock = getWorkingBlock(bp);
99 |
100 | if (workingBlock.waitingForStorage) {
101 | workingBlock.userBlock = null;
102 | if (workingBlock.itemBlock == null) {
103 | workingBlock.waitingForStorage = false;
104 | collector.emit(TO_MASTER_STREAM_ID, new Values(PROCESS_BLOCK_FIN, bp, workingBlock.getLatestExample()));
105 | }
106 | }
107 | break;
108 | case ITEM_BLOCK_SAVED:
109 | bp = (BlockPair)tuple.getValue(1);
110 | workingBlock = getWorkingBlock(bp);
111 |
112 | if (workingBlock.waitingForStorage) {
113 | workingBlock.itemBlock = null;
114 | if (workingBlock.userBlock == null) {
115 | workingBlock.waitingForStorage = false;
116 | collector.emit(TO_MASTER_STREAM_ID, new Values(PROCESS_BLOCK_FIN, bp, workingBlock.getLatestExample()));
117 | }
118 | }
119 | break;
120 | }
121 | }
122 |
123 | public void declareOutputFields(OutputFieldsDeclarer declarer) {
124 | // Fields "userBlockIdx" and "itemBlockIdx" used solely for grouping
125 | declarer.declareStream(TO_MASTER_STREAM_ID, new Fields("msgType", "blockPair", "latestExample"));
126 | declarer.declareStream(USER_BLOCK_STREAM_ID, new Fields("msgType", "blockPair", "block", "taskId", "userBlockIdx"));
127 | declarer.declareStream(ITEM_BLOCK_STREAM_ID, new Fields("msgType", "blockPair", "block", "taskId", "itemBlockIdx"));
128 | }
129 |
130 | private WorkingBlock getWorkingBlock(BlockPair bp) {
131 | // In general, this pattern of access is not thread-safe. But since requests with the same userBlockIdx
132 | // are sent to the same thread, it should be safe in our case.
133 | WorkingBlock workingBlock = workingBlockMap.get(bp);
134 | if (workingBlock == null) {
135 | workingBlock = new WorkingBlock();
136 | workingBlockMap.put(bp, workingBlock);
137 | }
138 | return workingBlock;
139 | }
140 |
141 | private void update(BlockPair bp, WorkingBlock workingBlock) {
142 | int userBlockStart = config.getUserBlockStart(bp.userBlockIdx);
143 | int itemBlockStart = config.getItemBlockStart(bp.itemBlockIdx);
144 | float[][] userBlock = workingBlock.userBlock;
145 | float[][] itemBlock = workingBlock.itemBlock;
146 |
147 | TrainingExample[] examples = workingBlock.examples.toArray(new TrainingExample[workingBlock.examples.size()]);
148 | PermutationUtils.permute(examples);
149 |
150 | for (TrainingExample ex : examples) {
151 | if (ex.numTrainingIters >= config.maxTrainingIters) continue;
152 | int i = ex.userId - userBlockStart;
153 | int j = ex.itemId - itemBlockStart;
154 |
155 | float dotProduct = 0.0f;
156 | for (int k = 0; k < config.numLatent; ++k) {
157 | dotProduct += userBlock[i][k] * itemBlock[j][k];
158 | }
159 | float ratingDiff = dotProduct - ex.rating;
160 |
161 | ++ex.numTrainingIters;
162 | float stepSize = 2 * config.initialStepSize / ex.numTrainingIters;
163 |
164 | for (int k = 0; k < config.numLatent; ++k) {
165 | float oldUserWeight = userBlock[i][k];
166 | float oldItemWeight = itemBlock[j][k];
167 | userBlock[i][k] -= stepSize*(ratingDiff * oldItemWeight + config.userPenalty * oldUserWeight);
168 | itemBlock[j][k] -= stepSize*(ratingDiff * oldUserWeight + config.itemPenalty * oldItemWeight);
169 | }
170 | }
171 | workingBlock.waitingForBlocks = false;
172 | workingBlock.waitingForStorage = true;
173 |
174 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK, bp, userBlock, null, bp.userBlockIdx));
175 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK, bp, itemBlock, null, bp.itemBlockIdx));
176 | }
177 | }
--------------------------------------------------------------------------------
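
The heart of Worker.update() above is one stochastic-gradient step per training example, with a step size decaying as 1/numTrainingIters. A self-contained sketch of that per-example step, with hypothetical constants standing in for the Configuration fields (numLatent, initialStepSize, userPenalty, itemPenalty):

    // Minimal sketch of the per-example SGD step from Worker.update().
    public class SgdStepSketch {
        static final int NUM_LATENT = 10;             // config.numLatent (assumed)
        static final float INITIAL_STEP_SIZE = 0.01f; // config.initialStepSize (assumed)
        static final float USER_PENALTY = 0.1f, ITEM_PENALTY = 0.1f; // penalties (assumed)

        // Updates userVec and itemVec in place for one observed rating.
        static void sgdStep(float[] userVec, float[] itemVec, float rating, int numIters) {
            float dot = 0.0f;
            for (int k = 0; k < NUM_LATENT; ++k) dot += userVec[k] * itemVec[k];
            float ratingDiff = dot - rating;               // prediction error
            float step = 2 * INITIAL_STEP_SIZE / numIters; // 1/t decay, as in Worker.update()

            for (int k = 0; k < NUM_LATENT; ++k) {
                float u = userVec[k], m = itemVec[k]; // read old values before either write
                userVec[k] -= step * (ratingDiff * m + USER_PENALTY * u);
                itemVec[k] -= step * (ratingDiff * u + ITEM_PENALTY * m);
            }
        }
    }
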
/src/main/java/collabstream/streaming/WorkingBlock.java:
--------------------------------------------------------------------------------
1 | package collabstream.streaming;
2 |
3 | import java.io.Serializable;
4 | import java.util.HashSet;
5 | import java.util.Set;
6 |
7 | public class WorkingBlock implements Serializable {
8 | 	public final Set<TrainingExample> examples = new HashSet<TrainingExample>();
9 | public float[][] userBlock = null;
10 | public float[][] itemBlock = null;
11 | public boolean waitingForBlocks = false;
12 | public boolean waitingForStorage = false;
13 |
14 | public String toString() {
15 | StringBuilder b = new StringBuilder(24*examples.size() + 72);
16 |
17 | b.append("examples={");
18 | boolean first = true;
19 | for (TrainingExample ex : examples) {
20 | if (first) {
21 | first = false;
22 | } else {
23 | b.append(", ");
24 | }
25 | b.append(ex.toString());
26 | }
27 | b.append('}');
28 |
29 | b.append("\nuserBlock=\n").append(MatrixUtils.toString(userBlock));
30 | b.append("\nitemBlock=\n").append(MatrixUtils.toString(itemBlock));
31 | b.append("\nwaitingForBlocks=").append(waitingForBlocks);
32 | b.append("\nwaitingForStorage=").append(waitingForStorage);
33 |
34 | return b.toString();
35 | }
36 |
37 | public TrainingExample getLatestExample() {
38 | TrainingExample latest = null;
39 | for (TrainingExample ex : examples) {
40 | if (latest == null || latest.timestamp < ex.timestamp) {
41 | latest = ex;
42 | }
43 | }
44 | return latest;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/DSGDMain.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 |
7 | import org.apache.hadoop.conf.Configuration;
8 | import org.apache.hadoop.conf.Configured;
9 | import org.apache.hadoop.fs.FileSystem;
10 | import org.apache.hadoop.fs.Path;
11 | import org.apache.hadoop.io.DoubleWritable;
12 | import org.apache.hadoop.io.NullWritable;
13 | import org.apache.hadoop.io.Text;
14 | import org.apache.hadoop.mapreduce.Job;
15 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
16 | import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
17 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
18 | import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
19 | import org.apache.log4j.Logger;
20 |
21 | import comparison.dsgd.mapper.DSGDOutputFactorsMapper;
22 | import comparison.dsgd.mapper.DSGDIntermediateMapper;
23 | import comparison.dsgd.mapper.DSGDPreprocFactorMapper;
24 | import comparison.dsgd.mapper.DSGDPreprocRatingsMapper;
25 | import comparison.dsgd.mapper.DSGDRmseMapper;
26 | import comparison.dsgd.reducer.DSGDOutputFactorsReducer;
27 | import comparison.dsgd.reducer.DSGDIntermediateReducer;
28 | import comparison.dsgd.reducer.DSGDPreprocFactorReducer;
29 | import comparison.dsgd.reducer.DSGDPreprocRatingsReducer;
30 | import comparison.dsgd.reducer.DSGDRmseReducer;
31 |
32 |
33 | public class DSGDMain extends Configured {
34 |
35 | private static final Logger sLogger = Logger.getLogger(DSGDMain.class);
36 |
37 | public static int getBlockRow(int row, int numUsers, int numReducers) {
38 | int usersPerBlock = (numUsers / numReducers) + 1;
39 | return (int)Math.floor(((double)row / (double)usersPerBlock));
40 | }
41 |
42 | public static int getBlockColumn(int column, int numItems, int numReducers) {
43 | int itemsPerBlock = (numItems / numReducers) + 1;
44 | return (int)Math.floor(((double)column / (double)itemsPerBlock));
45 | }
46 |
47 | public static void main(String[] args) throws Exception {
48 | if (args.length != 3) {
49 | System.out.println("usage: [input-path] [output-path]");
50 | return;
51 | }
52 | long startTime = System.currentTimeMillis();
53 | long timeElapsed = 0;
54 |
55 | String input = args[0];
56 | String output = args[1];
57 | int numReducers = Integer.parseInt(args[2]);
58 |
59 | // Fixed parameters
60 | // Edit these for individual experiment
61 | String numUsers = "573";
62 | String numItems = "2649430";
63 | int kValue = 10;
64 | int numIterations = 10;
65 | int numBlocks = numReducers;
66 | double stepSize = 0.1; // \tau (step size for SGD updates)
67 | double lambda = 0.1;
68 |
69 | String trainInput = input + "/train";
70 | String testInput = input + "/test";
71 | String tempFactorsInput = output + "/temp/input";
72 | String tempFactorsOutput = output + "/temp/output/";
73 | String tempTrainRatings = output + "/temp/train";
74 | String tempTestRatings = output + "/temp/test";
75 | String emptyInput = input + "/empty";
76 | String finalFactorsOutput = output + "/final_factors";
77 | String rmseOutput = output + "/rmse";
78 |
79 | /**
80 | * jobPreprocRatings
81 | * write train and test ratings to RatingsItem format
82 | */
83 | Configuration confPreprocRatings = new Configuration();
84 | confPreprocRatings.set("numUsers", numUsers);
85 | confPreprocRatings.set("numReducers", Integer.toString(numBlocks));
86 | confPreprocRatings.set("numItems", numItems);
87 | confPreprocRatings.set("kValue", Integer.toString(kValue));
88 |
89 | Job jobPreprocRatings = new Job(confPreprocRatings, "train_ratings");
90 |
91 | FileInputFormat.setInputPaths(jobPreprocRatings, new Path(trainInput));
92 | FileOutputFormat.setOutputPath(jobPreprocRatings, new Path(tempTrainRatings));
93 | // Delete the output directory if it exists already
94 | FileSystem.get(confPreprocRatings).delete(new Path(tempTrainRatings), true);
95 |
96 | jobPreprocRatings.setOutputFormatClass(SequenceFileOutputFormat.class);
97 |
98 | jobPreprocRatings.setJarByClass(DSGDMain.class);
99 | jobPreprocRatings.setNumReduceTasks(numReducers);
100 | jobPreprocRatings.setOutputKeyClass(MatrixItem.class);
101 | jobPreprocRatings.setOutputValueClass(NullWritable.class);
102 | jobPreprocRatings.setMapOutputKeyClass(MatrixItem.class);
103 | jobPreprocRatings.setMapOutputValueClass(NullWritable.class);
104 | jobPreprocRatings.setMapperClass(DSGDPreprocRatingsMapper.class);
105 | jobPreprocRatings.setReducerClass(DSGDPreprocRatingsReducer.class);
106 | jobPreprocRatings.setCombinerClass(DSGDPreprocRatingsReducer.class);
107 |
108 | // train RatingsItems
109 | jobPreprocRatings.waitForCompletion(true);
110 |
111 | // test RatingsItems
112 | 		jobPreprocRatings = new Job(confPreprocRatings, "test_ratings");
113 |
114 | FileInputFormat.setInputPaths(jobPreprocRatings, new Path(testInput));
115 | FileOutputFormat.setOutputPath(jobPreprocRatings, new Path(tempTestRatings));
116 | // Delete the output directory if it exists already
117 | FileSystem.get(confPreprocRatings).delete(new Path(tempTestRatings), true);
118 |
119 | jobPreprocRatings.setOutputFormatClass(SequenceFileOutputFormat.class);
120 | jobPreprocRatings.setJarByClass(DSGDMain.class);
121 | jobPreprocRatings.setNumReduceTasks(numReducers);
122 | jobPreprocRatings.setOutputKeyClass(MatrixItem.class);
123 | jobPreprocRatings.setOutputValueClass(NullWritable.class);
124 | jobPreprocRatings.setMapOutputKeyClass(MatrixItem.class);
125 | jobPreprocRatings.setMapOutputValueClass(NullWritable.class);
126 | jobPreprocRatings.setMapperClass(DSGDPreprocRatingsMapper.class);
127 | jobPreprocRatings.setReducerClass(DSGDPreprocRatingsReducer.class);
128 | jobPreprocRatings.setCombinerClass(DSGDPreprocRatingsReducer.class);
129 |
130 | jobPreprocRatings.waitForCompletion(true);
131 |
132 | /**
133 | * jobPreprocFactors
134 | * Create initial U and M and output to FactorItem format
135 | */
136 |
137 | Configuration confPreprocFactors = new Configuration();
138 | confPreprocFactors.set("numUsers", numUsers);
139 | confPreprocFactors.set("numItems", numItems);
140 | confPreprocFactors.set("kValue", Integer.toString(kValue));
141 |
142 | Job jobPreprocFactors = new Job(confPreprocFactors, "factors");
143 |
144 | FileInputFormat.setInputPaths(jobPreprocFactors, new Path(emptyInput));
145 | FileOutputFormat.setOutputPath(jobPreprocFactors, new Path(tempFactorsInput));
146 | // Delete the output directory if it exists already
147 | FileSystem.get(confPreprocRatings).delete(new Path(tempFactorsInput), true);
148 |
149 | jobPreprocFactors.setJarByClass(DSGDMain.class);
150 | jobPreprocFactors.setNumReduceTasks(1);
151 | jobPreprocFactors.setOutputKeyClass(MatrixItem.class);
152 | jobPreprocFactors.setOutputValueClass(NullWritable.class);
153 | jobPreprocFactors.setMapOutputKeyClass(MatrixItem.class);
154 | jobPreprocFactors.setMapOutputValueClass(NullWritable.class);
155 | jobPreprocFactors.setMapperClass(DSGDPreprocFactorMapper.class);
156 | jobPreprocFactors.setReducerClass(DSGDPreprocFactorReducer.class);
157 | jobPreprocFactors.setCombinerClass(DSGDPreprocFactorReducer.class);
158 | jobPreprocFactors.setOutputFormatClass(SequenceFileOutputFormat.class);
159 |
160 | jobPreprocFactors.waitForCompletion(true);
161 |
162 | // if(true){
163 | // return;
164 | // }
165 |
166 | 		for(int a = 0; a < numIterations; ++a){
167 | sLogger.info("Running iteration " + a);
168 | // Build stratum
169 | 			List<Integer> stratum = new ArrayList<Integer>();
170 | for(int i = 0; i < numBlocks; ++i ){
171 | stratum.add(i);
172 | }
173 | Collections.shuffle(stratum); //choose random initial stratum
174 | String stratumString;
175 |
176 | for(int i=0; i < numReducers; ++i) {
177 |
178 | sLogger.info("Running stratum " + i + " on iteration " + a );
179 | // Choose next stratum
180 | int current1 = stratum.get(0);
181 | int current2;
182 | for(int j = 1; j < numBlocks; ++j){
183 | current2 = stratum.get(j);
184 | stratum.set(j, current1);
185 | current1 = current2;
186 | }
187 | stratum.set(0, current1);
188 |
189 | // Build string version of stratum
190 | stratumString = "";
191 | for(int j : stratum){
192 | stratumString += " " + j;
193 | }
194 |
195 | /**
196 | * jobIntermediate
197 | * Perform DSGD updates on stratum
198 | */
199 | Configuration confIntermediate = new Configuration();
200 | confIntermediate.set("stratum", stratumString);
201 | confIntermediate.set("numUsers", numUsers);
202 | confIntermediate.set("numReducers", Integer.toString(numBlocks));
203 | confIntermediate.set("numItems", numItems);
204 | confIntermediate.set("kValue", Integer.toString(kValue));
205 | confIntermediate.set("stepSize", Double.toString(stepSize));
206 | confIntermediate.set("lambda", Double.toString(lambda));
207 | FileSystem fs = FileSystem.get(confIntermediate);
208 |
209 | Job jobIntermediate = new Job(confIntermediate, "DSGD: Intermediate");
210 |
211 | FileInputFormat.setInputPaths(jobIntermediate, tempFactorsInput + "," + tempTrainRatings);
212 | FileOutputFormat.setOutputPath(jobIntermediate, new Path(tempFactorsOutput));
213 | // Delete the output directory if it exists already
214 | FileSystem.get(confPreprocRatings).delete(new Path(tempFactorsOutput), true);
215 | jobIntermediate.setInputFormatClass(SequenceFileInputFormat.class);
216 | jobIntermediate.setOutputFormatClass(SequenceFileOutputFormat.class);
217 |
218 | jobIntermediate.setJarByClass(DSGDMain.class);
219 | jobIntermediate.setNumReduceTasks(numReducers);
220 | jobIntermediate.setOutputKeyClass(MatrixItem.class);
221 | jobIntermediate.setOutputValueClass(NullWritable.class);
222 | jobIntermediate.setMapOutputKeyClass(IntMatrixItemPair.class);
223 | jobIntermediate.setMapOutputValueClass(NullWritable.class);
224 |
225 | jobIntermediate.setMapperClass(DSGDIntermediateMapper.class);
226 | // job.setCombinerClass(IntermediateReducer.class);
227 | jobIntermediate.setReducerClass(DSGDIntermediateReducer.class);
228 |
229 | jobIntermediate.waitForCompletion(true);
230 |
231 | fs.delete(new Path(tempFactorsInput), true);
232 | fs.rename(new Path(tempFactorsOutput), new Path(tempFactorsInput));
233 | }
234 | /**
235 | * jobRmse
236 | * Calculate RMSE on test data using current factors
237 | */
238 | timeElapsed += (System.currentTimeMillis() - startTime);
239 |
240 | Configuration confRmse = new Configuration();
241 | confRmse.set("numUsers", numUsers);
242 | confRmse.set("numReducers", Integer.toString(numBlocks));
243 | confRmse.set("numItems", numItems);
244 | confRmse.set("timeElapsed", Long.toString(timeElapsed));
245 | confRmse.set("kValue", Integer.toString(kValue));
246 | Job jobRmse = new Job(confRmse, "DSGD: RMSE");
247 | FileInputFormat.setInputPaths(jobRmse, tempFactorsInput + "," + tempTestRatings);
248 | FileOutputFormat.setOutputPath(jobRmse, new Path(rmseOutput + "/iter_" + a));
249 | // Delete the output directory if it exists already
250 | FileSystem.get(confPreprocRatings).delete(new Path(rmseOutput + "/iter_" + a), true);
251 | jobRmse.setInputFormatClass(SequenceFileInputFormat.class);
252 |
253 | jobRmse.setJarByClass(DSGDMain.class);
254 | jobRmse.setNumReduceTasks(1);
255 | jobRmse.setOutputKeyClass(DoubleWritable.class);
256 | jobRmse.setOutputValueClass(DoubleWritable.class);
257 | jobRmse.setMapOutputKeyClass(MatrixItem.class);
258 | jobRmse.setMapOutputValueClass(NullWritable.class);
259 | jobRmse.setMapperClass(DSGDRmseMapper.class);
260 | jobRmse.setReducerClass(DSGDRmseReducer.class);
261 |
262 | jobRmse.waitForCompletion(true);
263 |
264 | startTime = System.currentTimeMillis();
265 | }
266 | // Emit final factor matrices
267 | Configuration confFinal = new Configuration();
268 | confFinal.set("kValue", Integer.toString(kValue));
269 | confFinal.set("numUsers", numUsers);
270 | confFinal.set("numItems", numItems);
271 |
272 | 		Job jobFinal = new Job(confFinal, "DSGD: Output Factors");
273 | FileInputFormat.setInputPaths(jobFinal, new Path(tempFactorsInput));
274 | FileOutputFormat.setOutputPath(jobFinal, new Path(finalFactorsOutput));
275 | // Delete the output directory if it exists already
276 | FileSystem.get(confPreprocRatings).delete(new Path(finalFactorsOutput), true);
277 |
278 | jobFinal.setInputFormatClass(SequenceFileInputFormat.class);
279 | jobFinal.setJarByClass(DSGDMain.class);
280 | jobFinal.setNumReduceTasks(1); // Must use 1 reducer so that both U and M
281 | // get sent to same reducer
282 |
283 | jobFinal.setOutputKeyClass(Text.class);
284 | jobFinal.setOutputValueClass(NullWritable.class);
285 |
286 | jobFinal.setMapOutputKeyClass(MatrixItem.class);
287 | jobFinal.setMapOutputValueClass(NullWritable.class);
288 |
289 | jobFinal.setMapperClass(DSGDOutputFactorsMapper.class);
290 | // jobFinal.setCombinerClass(FinalReducer.class);
291 | jobFinal.setReducerClass(DSGDOutputFactorsReducer.class);
292 |
293 | jobFinal.waitForCompletion(true);
294 | // fsFinal.delete(tempInputDir, true);
295 | }
296 |
297 | }
298 |
--------------------------------------------------------------------------------
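
The hand-rolled shift inside the stratum loop above is a rotate-by-one: starting from a random permutation, numReducers successive rotations pair every column block with every row block exactly once per outer iteration (DSGDIntermediateMapper keeps a ratings block exactly when stratumArray[blockColumn] == blockRow). A sketch of the same schedule using Collections.rotate, which has the identical effect:

    import java.util.*;

    public class StratumScheduleSketch {
        public static void main(String[] args) {
            int numBlocks = 4; // hypothetical; equals numReducers in DSGDMain
            List<Integer> stratum = new ArrayList<Integer>();
            for (int i = 0; i < numBlocks; ++i) stratum.add(i);
            Collections.shuffle(stratum); // random initial stratum, as in DSGDMain

            for (int s = 0; s < numBlocks; ++s) {
                Collections.rotate(stratum, 1); // same effect as DSGDMain's shift loop
                // stratum.get(col) is the row block paired with column block col
                System.out.println("stratum " + s + ": " + stratum);
            }
        }
    }
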
/src/main/java/comparison/dsgd/IntMatrixItemPair.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd;
2 |
3 | import java.io.DataInput;
4 | import java.io.DataOutput;
5 | import java.io.IOException;
6 |
7 | import org.apache.hadoop.io.IntWritable;
8 | import org.apache.hadoop.io.WritableComparable;
9 |
10 | /**
11 | * Pair consisting of an IntWritable (the reducer number) and a MatrixItem Used
12 | * for 2 way sorting
13 | *
14 | * @author christopherjohnson
15 | *
16 | */
17 | public class IntMatrixItemPair implements WritableComparable<IntMatrixItemPair> {
18 |
19 | private IntWritable reducerNum = new IntWritable();
20 | private MatrixItem matItem = new MatrixItem();
21 |
22 | public IntMatrixItemPair() {
23 |
24 | }
25 |
26 | public void set(IntWritable r, MatrixItem m) {
27 | this.reducerNum = new IntWritable(r.get());
28 | this.matItem.set(m.getRow().get(), m.getColumn().get(), m
29 | .getValue().get(), m.getMatrixType().toString());
30 | }
31 |
32 | public void readFields(DataInput arg0) throws IOException {
33 | reducerNum.readFields(arg0);
34 | matItem.readFields(arg0);
35 | }
36 |
37 | public void write(DataOutput arg0) throws IOException {
38 | reducerNum.write(arg0);
39 | matItem.write(arg0);
40 | }
41 |
42 | public int compareTo(IntMatrixItemPair o) {
43 | int cmp = reducerNum.compareTo(o.reducerNum);
44 | if (cmp != 0) {
45 | return cmp;
46 | }
47 | return matItem.compareTo(o.getMatItem());
48 | }
49 |
50 | public IntWritable getReducerNum() {
51 | return reducerNum;
52 | }
53 |
54 | public MatrixItem getMatItem() {
55 | return matItem;
56 | }
57 |
58 | public void setReducerNum(IntWritable reducerNum) {
59 | this.reducerNum = reducerNum;
60 | }
61 |
62 | public void setMatItem(MatrixItem matItem) {
63 | this.matItem = matItem;
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
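
IntMatrixItemPair is the usual composite-key device: the pair sorts by reducerNum first and by MatrixItem second, which is what lets DSGDIntermediateReducer see a block's factor items before its ratings. The routing half of the pattern would normally be a partitioner on reducerNum alone; no partitioner appears in this repository, so the sketch below is purely an illustrative assumption:

    // Hypothetical partitioner for the composite-key pattern; not present in this
    // repository, shown only to illustrate how IntMatrixItemPair would be routed.
    import org.apache.hadoop.io.NullWritable;
    import org.apache.hadoop.mapreduce.Partitioner;

    import comparison.dsgd.IntMatrixItemPair;

    public class ReducerNumPartitioner extends Partitioner<IntMatrixItemPair, NullWritable> {
        @Override
        public int getPartition(IntMatrixItemPair key, NullWritable value, int numPartitions) {
            // Route solely on the reducer number; the MatrixItem half of the key
            // then orders records within the partition (factors before ratings).
            return key.getReducerNum().get() % numPartitions;
        }
    }
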
/src/main/java/comparison/dsgd/LongPair.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd;
2 |
3 | public class LongPair implements Comparable<LongPair> {
4 |
5 | public long first;
6 | public long second;
7 |
8 | public LongPair(long first, long second) {
9 | this.first = first;
10 | this.second = second;
11 | }
12 |
13 | public int compareTo(LongPair o) {
14 | if (this.first != o.first) {
15 | if (this.first < o.first) {
16 | return -1;
17 | } else {
18 | return 1;
19 | }
20 | } else if (this.second != o.second) {
21 | if (this.second < o.second) {
22 | return -1;
23 | } else {
24 | return 1;
25 | }
26 | }
27 | return 0;
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/MatrixItem.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd;
2 |
3 | import java.io.DataInput;
4 | import java.io.DataOutput;
5 | import java.io.IOException;
6 |
7 | import org.apache.hadoop.io.DoubleWritable;
8 | import org.apache.hadoop.io.IntWritable;
9 | import org.apache.hadoop.io.Text;
10 | import org.apache.hadoop.io.WritableComparable;
11 |
12 | public class MatrixItem implements WritableComparable<MatrixItem>, Cloneable {
13 |
14 | private IntWritable row = new IntWritable();
15 | private IntWritable column = new IntWritable();
16 | private DoubleWritable value = new DoubleWritable();
17 | private Text matrixType = new Text();
18 |
19 | public static final Text U_MATRIX = new Text("U_MATRIX");
20 | public static final Text M_MATRIX = new Text("M_MATRIX");
21 | public static final Text R_MATRIX = new Text("R_MATRIX");
22 |
23 | public boolean isRatingsItem(){
24 | return matrixType.equals(R_MATRIX);
25 | }
26 |
27 | public boolean isFactorItem(){
28 | return !matrixType.equals(R_MATRIX);
29 | }
30 |
31 | public void set(int r, int c, double v, String m) {
32 | setRow(new IntWritable(r));
33 | setColumn(new IntWritable(c));
34 | setValue(new DoubleWritable(v));
35 | setMatrixType(new Text(m));
36 | }
37 |
38 | public void readFields(DataInput arg0) throws IOException {
39 | row.readFields(arg0);
40 | column.readFields(arg0);
41 | value.readFields(arg0);
42 | matrixType.readFields(arg0);
43 | }
44 |
45 | public void write(DataOutput arg0) throws IOException {
46 | row.write(arg0);
47 | column.write(arg0);
48 | value.write(arg0);
49 | matrixType.write(arg0);
50 | }
51 |
52 | public MatrixItem clone() {
53 | MatrixItem matItem = new MatrixItem();
54 | 		matItem.set(this.row.get(), this.column.get(), this.value.get(),
55 | 				this.matrixType.toString());
56 | return matItem;
57 | }
58 |
59 | public int compareTo(MatrixItem o) {
60 | int compare = this.getMatrixType().compareTo(o.getMatrixType());
61 | if(compare != 0){
62 | if(this.isRatingsItem()){
63 | return 1;
64 | } else if(o.isRatingsItem()){
65 | return -1;
66 | }
67 | return compare;
68 | }
69 | // compare = this.getRow().compareTo(o.getRow());
70 | // if(compare != 0){
71 | // return compare;
72 | // }
73 | // compare = this.getColumn().compareTo(o.getColumn());
74 | // if(compare != 0){
75 | // return compare;
76 | // }
77 | // return this.getValue().compareTo(o.getValue());
78 | 		double rand = Math.random(); // deliberate random tie-break to shuffle same-type items; violates the compareTo contract (see note below)
79 | if(rand > 0.5){
80 | return 1;
81 | } else{
82 | return -1;
83 | }
84 | }
85 |
86 | public IntWritable getRow() {
87 | return row;
88 | }
89 |
90 | public IntWritable getColumn() {
91 | return column;
92 | }
93 |
94 | public DoubleWritable getValue() {
95 | return value;
96 | }
97 |
98 | public void setRow(IntWritable row) {
99 | this.row = row;
100 | }
101 |
102 | public void setColumn(IntWritable column) {
103 | this.column = column;
104 | }
105 |
106 | public void setValue(DoubleWritable value) {
107 | this.value = value;
108 | }
109 |
110 | public Text getMatrixType() {
111 | return matrixType;
112 | }
113 |
114 | public void setMatrixType(Text matrixType) {
115 | this.matrixType = matrixType;
116 | }
117 |
121 |
122 | }
123 |
--------------------------------------------------------------------------------
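
On the Math.random() tie-break in compareTo above: it shuffles same-type items, but it makes the comparator inconsistent and intransitive, and sorts built on TimSort (Java 7+) can reject such comparators with "Comparison method violates its general contract!". A contract-respecting alternative, assuming a hashed (row, column) order is an acceptable stand-in for the shuffle:

    import comparison.dsgd.MatrixItem;

    // Sketch: deterministic replacement for the random tie-break. Hashing
    // (row, column) scatters same-type items pseudo-randomly while keeping
    // the comparator consistent and transitive.
    public class ShuffledOrder {
        public static int compareSameType(MatrixItem a, MatrixItem b) {
            int ha = Integer.rotateLeft(31 * a.getRow().get() + a.getColumn().get(), 13);
            int hb = Integer.rotateLeft(31 * b.getRow().get() + b.getColumn().get(), 13);
            if (ha != hb) {
                return ha < hb ? -1 : 1;
            }
            // Fall back to a total order so distinct items never compare as equal.
            int cmp = a.getRow().compareTo(b.getRow());
            return cmp != 0 ? cmp : a.getColumn().compareTo(b.getColumn());
        }
    }
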
/src/main/java/comparison/dsgd/MatrixUtils.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd;
2 |
3 | public class MatrixUtils {
4 |
5 | public static double dotProduct(double[] v1, double[] v2){
6 | if(v1.length != v2.length){
7 | System.out.println("lengths do not match!!");
8 | return 0;
9 | }
10 |
11 | double total = 0;
12 | for(int i = 0; i < v1.length; ++i){
13 | // System.out.println("v1{" + i + "]: " + v1[i]);
14 | // System.out.println("v2{" + i + "]: " + v2[i]);
15 | total += v1[i]*v2[i];
16 | }
17 |
18 | return total;
19 | }
20 |
21 | public static class MatrixException extends Exception{
22 |
23 | private static final long serialVersionUID = 6678191521290159097L;
24 |
25 |
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDFinalMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | public class DSGDFinalMapper {
4 |
5 | }
6 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDIntermediateMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 | import java.util.StringTokenizer;
5 |
6 | import org.apache.hadoop.io.IntWritable;
7 | import org.apache.hadoop.io.NullWritable;
8 | import org.apache.hadoop.mapreduce.Mapper;
9 |
10 | import comparison.dsgd.DSGDMain;
11 | import comparison.dsgd.IntMatrixItemPair;
12 | import comparison.dsgd.MatrixItem;
13 |
14 | public class DSGDIntermediateMapper extends
15 | 		Mapper<MatrixItem, NullWritable, IntMatrixItemPair, NullWritable> {
16 |
17 | int[] stratumArray;
18 | int numItems;
19 | int numUsers;
20 | int numReducers;
21 | IntMatrixItemPair intMatPair = new IntMatrixItemPair();
22 | IntWritable reducerNum = new IntWritable();
23 | MatrixItem matItem = new MatrixItem();
24 | NullWritable nw = NullWritable.get();
25 |
26 | @Override
27 | protected void setup(Context context) throws IOException,
28 | InterruptedException {
29 | String stratumString = context.getConfiguration().get("stratum");
30 | StringTokenizer st = new StringTokenizer(stratumString);
31 | numReducers = Integer.parseInt(context.getConfiguration().get(
32 | "numReducers"));
33 | stratumArray = new int[numReducers];
34 | for (int i = 0; i < numReducers; ++i) {
35 | stratumArray[i] = Integer.parseInt(st.nextToken());
36 | }
37 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers"));
38 | numItems = Integer.parseInt(context.getConfiguration().get("numItems"));
39 | super.setup(context);
40 | }
41 |
42 | public void map(MatrixItem key, NullWritable value, Context context)
43 | throws IOException, InterruptedException {
44 | int row = key.getRow().get();
45 | int column = key.getColumn().get();
46 | if (key.isRatingsItem()) {
47 | 			int blockRow = DSGDMain.getBlockRow(row, numUsers, numReducers);
48 | int blockColumn = DSGDMain.getBlockColumn(column, numItems,
49 | numReducers);
50 | // Emit if in current stratum
51 | if (stratumArray[blockColumn] == blockRow) {
52 | reducerNum.set(blockColumn);
53 | matItem = new MatrixItem();
54 | matItem.set(key.getRow().get(), key.getColumn().get(), key
55 | .getValue().get(), key.getMatrixType().toString());
56 | intMatPair.set(reducerNum, matItem);
57 | context.write(intMatPair, nw);
58 | }
59 | } else if (key.isFactorItem()) {
60 | if (key.getMatrixType().equals(MatrixItem.U_MATRIX)) {
61 | // System.out.println("UMatrix["
62 | // + key.getRow()
63 | // + "]["
64 | // + key.getColumn().get()
65 | // + "]: "
66 | // + key.getValue());
67 | 				int blockRow = DSGDMain.getBlockRow(row, numUsers,
68 | 						numReducers);
69 | reducerNum.set(stratumArray[blockRow]);
70 | matItem = new MatrixItem();
71 | matItem.set(key.getRow().get(), key.getColumn().get(), key
72 | .getValue().get(), key.getMatrixType().toString());
73 | intMatPair.set(reducerNum, matItem);
74 | context.write(intMatPair, nw);
75 | } else {
82 | int blockColumn = DSGDMain.getBlockColumn(row, numItems,
83 | numReducers);
84 | reducerNum.set(blockColumn);
85 | matItem = new MatrixItem();
86 | matItem.set(key.getRow().get(), key.getColumn().get(), key
87 | .getValue().get(), key.getMatrixType().toString());
88 | intMatPair.set(reducerNum, matItem);
89 | context.write(intMatPair, nw);
90 | }
91 | }
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDOutputFactorsMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.io.Text;
7 | import org.apache.hadoop.mapreduce.Mapper;
8 |
9 | import comparison.dsgd.MatrixItem;
10 |
11 | public class DSGDOutputFactorsMapper extends Mapper<MatrixItem, NullWritable, MatrixItem, NullWritable> {
12 |
13 | public void map(MatrixItem key, NullWritable value, Context context)
14 | throws IOException, InterruptedException {
15 |
16 | context.write(key, value);
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDPreprocFactorMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.LongWritable;
6 | import org.apache.hadoop.io.NullWritable;
7 | import org.apache.hadoop.io.Text;
8 | import org.apache.hadoop.mapreduce.Mapper;
9 |
10 | import comparison.dsgd.MatrixItem;
11 |
12 | public class DSGDPreprocFactorMapper extends
13 | 		Mapper<LongWritable, Text, MatrixItem, NullWritable> {
14 |
15 | MatrixItem matItem = new MatrixItem(); // reuse writable objects for efficiency
16 | //RatingsItem ratItem = new RatingsItem();
17 | NullWritable nw = NullWritable.get();
18 |
19 | @Override
20 | protected void setup(Context context)
21 | throws IOException, InterruptedException {
22 | 		context.write(new MatrixItem(), nw); // emit one placeholder item; the reducer's setup() generates the actual random factors
23 | super.setup(context);
24 | }
25 |
26 | @Override
27 | public void map(LongWritable key, Text value, Context context)
28 | throws IOException, InterruptedException {
29 | }
30 | }
31 |
32 |
33 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDPreprocMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 | import java.util.HashMap;
5 | import java.util.Set;
6 | import java.util.StringTokenizer;
7 |
8 | import org.apache.hadoop.io.LongWritable;
9 | import org.apache.hadoop.io.Text;
10 | import org.apache.hadoop.mapreduce.Mapper;
11 | import org.apache.hadoop.mapreduce.Mapper.Context;
12 |
13 | public class DSGDPreprocMapper extends
14 | 		Mapper<LongWritable, Text, LongWritable, Text> { // output types assumed; this unfinished stub is not referenced by DSGDMain
15 | 
16 | 	private HashMap<Long, Set<Long>> gramMap; // element types assumed from the surrounding imports
17 |
18 | @Override
19 | public void map(LongWritable key, Text value, Context context)
20 | throws IOException, InterruptedException {
21 |
22 | int numReducers = Integer.parseInt(context.getConfiguration().get(
23 | "numReducers"));
24 | long numUsers = Long.parseLong(context.getConfiguration().get(
25 | "numUsers"));
26 | long numItems = Long.parseLong(context.getConfiguration().get(
27 | "numItems"));
28 | long user = key.get();
29 | long movie = 0;
30 | String line = ((Text) value).toString();
31 | StringTokenizer itr = new StringTokenizer(line);
32 |
33 |
34 | // build gramMap
35 | 		while (itr.hasMoreTokens()) {
36 | 			itr.nextToken(); // consume the token; without this the loop never terminates
37 | 			user++;
38 | }
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDPreprocRatingsMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 | import java.util.StringTokenizer;
5 |
6 | import org.apache.hadoop.io.LongWritable;
7 | import org.apache.hadoop.io.NullWritable;
8 | import org.apache.hadoop.io.Text;
9 | import org.apache.hadoop.mapreduce.Mapper;
10 |
11 | import comparison.dsgd.MatrixItem;
12 |
13 | public class DSGDPreprocRatingsMapper extends
14 | 		Mapper<LongWritable, Text, MatrixItem, NullWritable> {
15 |
16 | // reuse writable objects for efficiency
17 | MatrixItem matItem = new MatrixItem();
18 | NullWritable nw = NullWritable.get();
19 |
20 | @Override
21 | public void map(LongWritable key, Text value, Context context)
22 | throws IOException, InterruptedException {
23 |
24 | String line = ((Text) value).toString();
25 | StringTokenizer itr = new StringTokenizer(line);
26 | double dUser = Double.parseDouble(itr.nextToken()); //use double due to data's floating point format
27 | int user = (int)dUser;
28 | double dItem = Double.parseDouble(itr.nextToken());
29 | int item = (int)dItem;
30 | double rating = Double.parseDouble(itr.nextToken());
31 |
32 | if(rating != 0){ // don't do anything with un-rated items
33 | matItem.set(user, item, rating, MatrixItem.R_MATRIX.toString());
34 | context.write(matItem, nw);
35 | }
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/mapper/DSGDRmseMapper.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.mapper;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.mapreduce.Mapper;
7 |
8 | import comparison.dsgd.MatrixItem;
9 |
10 | public class DSGDRmseMapper extends Mapper<MatrixItem, NullWritable, MatrixItem, NullWritable> {
11 |
12 | NullWritable nw = NullWritable.get();
13 |
14 | public void map(MatrixItem key, NullWritable value, Context context)
15 | throws IOException, InterruptedException {
16 |
17 | context.write(key, nw);
18 |
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDFinalReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | public class DSGDFinalReducer {
4 |
5 | }
6 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDIntermediateReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.mapreduce.Reducer;
7 |
8 | import comparison.dsgd.IntMatrixItemPair;
9 | import comparison.dsgd.MatrixItem;
10 | import comparison.dsgd.MatrixUtils;
11 |
12 | public class DSGDIntermediateReducer extends
13 | 		Reducer<IntMatrixItemPair, NullWritable, MatrixItem, NullWritable> {
14 |
15 | double[][] UMatrix;
16 | double[][] MMatrix;
17 | boolean[][] UMatrixSet;
18 | boolean[][] MMatrixSet;
19 | 	MatrixItem matItem = new MatrixItem(); // initialized so cleanup() cannot hit a null reference if reduce() never runs
20 | NullWritable nw = NullWritable.get();
21 | int kValue;
22 | double tau;
23 | double lambda;
24 | int numUsers;
25 | int numItems;
26 |
27 | @Override
28 | protected void cleanup(Context context) throws IOException,
29 | InterruptedException {
30 |
31 | // Emit UMatrix
32 | for (int i = 0; i < numUsers; ++i) {
33 | for (int j = 0; j < kValue; ++j) {
34 | if (UMatrixSet[i][j]) {
35 | if (Double.isNaN(UMatrix[i][j])
36 | || Double.isInfinite(UMatrix[i][j])) {
37 | System.out.println("cool");
38 | }
39 | matItem.set(i, j, UMatrix[i][j],
40 | MatrixItem.U_MATRIX.toString());
41 | context.write(matItem, nw);
42 | }
43 | }
44 | }
45 |
46 | // Emit MMatrix
47 | for (int i = 0; i < numItems; ++i) {
48 | for (int j = 0; j < kValue; ++j) {
49 | if (MMatrixSet[i][j]) {
50 | if (Double.isNaN(MMatrix[i][j])
51 | || Double.isInfinite(MMatrix[i][j])) {
52 | System.out.println("cool");
53 | }
54 | matItem.set(i, j, MMatrix[i][j],
55 | MatrixItem.M_MATRIX.toString());
56 | context.write(matItem, nw);
57 | }
58 | }
59 | }
60 | super.cleanup(context);
61 | }
62 |
63 | @Override
64 | protected void setup(Context context) throws IOException,
65 | InterruptedException {
66 | kValue = Integer.parseInt(context.getConfiguration().get("kValue"));
67 | tau = Double.parseDouble(context.getConfiguration().get("stepSize"));
68 | lambda = Double.parseDouble(context.getConfiguration().get("lambda"));
69 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers"));
70 | numItems = Integer.parseInt(context.getConfiguration().get("numItems"));
71 |
72 | UMatrix = new double[numUsers][kValue];
73 | MMatrix = new double[numItems][kValue];
74 | UMatrixSet = new boolean[numUsers][kValue];
75 | MMatrixSet = new boolean[numItems][kValue];
76 |
77 | super.setup(context);
78 | }
79 |
80 | 	protected void reduce(IntMatrixItemPair key, Iterable<NullWritable> values,
81 | Context context) throws IOException, InterruptedException {
82 |
83 | matItem = key.getMatItem();
84 |
85 | if (matItem.isFactorItem()) {
86 | if (matItem.getMatrixType().equals(MatrixItem.U_MATRIX)) {
87 | UMatrix[matItem.getRow().get()][matItem.getColumn().get()] = matItem
88 | .getValue().get();
89 | UMatrixSet[matItem.getRow().get()][matItem.getColumn().get()] = true;
96 | } else {
97 | MMatrix[matItem.getRow().get()][matItem.getColumn().get()] = matItem
98 | .getValue().get();
99 | MMatrixSet[matItem.getRow().get()][matItem.getColumn().get()] = true;
106 | }
107 | } else { // If we are now processing Ratings Items then we must have the
108 | // full Factor Matrices
109 | // Perform update
110 | 			// for all l, where r' is the predicted rating:
111 | 			// U_il <-- U_il - 2*\tau*(M_jl*(r' - r) + \lambda*U_il)
112 | 			// M_jl <-- M_jl - 2*\tau*(U_il*(r' - r) + \lambda*M_jl)
113 | int i = matItem.getRow().get();
114 | int j = matItem.getColumn().get();
115 | double oldUval;
116 | double oldMval;
117 | double newUval;
118 | double newMval;
119 | // Get predicted value
120 | // build U_i vector
121 | double[] UVector = UMatrix[i];
124 |
125 | // build M_j vector
126 | double[] MVector = MMatrix[j];
129 |
130 | // Compute ratings prediction using dot product of factor vectors
131 | double prediction = MatrixUtils.dotProduct(UVector, MVector);
132 |
133 | // Update Factor Matrices
134 | for (int l = 0; l < kValue; ++l) {
135 | oldUval = UMatrix[i][l];
136 | oldMval = MMatrix[j][l];
137 | newUval = oldUval - (2*tau* (oldMval
138 | * (prediction - matItem.getValue().get()) + lambda
139 | * oldUval));
140 | newMval = oldMval - (2 * tau * (oldUval
141 | * (prediction - matItem.getValue().get()) + lambda
142 | * oldMval));
143 | UMatrix[i][l] = newUval;
144 | MMatrix[j][l] = newMval;
153 | if(Math.abs(newUval) > 100 || Math.abs(newMval) > 100){
154 | System.out.println("cool");
155 | }
156 | }
157 | }
158 | }
159 |
160 | }
161 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDOutputFactorsReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.io.Text;
7 | import org.apache.hadoop.mapreduce.Reducer;
8 |
9 | import comparison.dsgd.MatrixItem;
10 |
11 | /**
12 |  * Note: DSGDOutputFactorsReducer requires a single reducer so that both U and M reach the same task!
13 | *
14 | * @author christopherjohnson
15 | *
16 | */
17 | public class DSGDOutputFactorsReducer extends
18 | 		Reducer<MatrixItem, NullWritable, Text, NullWritable> {
19 |
20 | double[][] UMatrix;
21 | double[][] MMatrix;
22 | MatrixItem matItem;
23 | NullWritable nw = NullWritable.get();
24 | int kValue;
25 | double tau;
26 | double lambda;
27 | int numUsers;
28 | int numItems;
29 | Text text = new Text();
30 |
31 | /**
32 | * Output factor matrices on cleanup
33 | */
34 | @Override
35 | protected void cleanup(Context context) throws IOException,
36 | InterruptedException {
37 | text.set("U_Matrix");
38 | context.write(text, nw);
39 | for(double[] UVector : UMatrix){
40 | 			StringBuilder current = new StringBuilder();
41 | 			for(double d : UVector){
42 | 				current.append(d).append(' ');
43 | 			}
44 | 			text.set(current.toString());
45 | context.write(text, nw);
46 | }
47 | text.set("M_Matrix");
48 | context.write(text, nw);
49 |
50 | for(double[] MVector : MMatrix){
51 | 			StringBuilder current = new StringBuilder();
52 | 			for(double d : MVector){
53 | 				current.append(d).append(' ');
54 | 			}
55 | 			text.set(current.toString());
56 | context.write(text, nw);
57 | }
58 | super.cleanup(context);
59 | }
60 |
61 | @Override
62 | protected void setup(Context context) throws IOException,
63 | InterruptedException {
64 | kValue = Integer.parseInt(context.getConfiguration().get("kValue"));
65 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers"));
66 | numItems = Integer.parseInt(context.getConfiguration().get("numItems"));
67 |
68 | UMatrix = new double[numUsers][kValue];
69 | MMatrix = new double[numItems][kValue];
70 |
71 |
72 | super.setup(context);
73 | }
74 |
75 | @Override
76 | 	protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context)
77 | throws IOException, InterruptedException {
78 | if(key.getMatrixType().equals(MatrixItem.U_MATRIX)){
79 | UMatrix[key.getRow().get()][key.getColumn().get()] = key.getValue().get();
80 | } else{
81 | MMatrix[key.getRow().get()][key.getColumn().get()] = key.getValue().get();
82 | }
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDPreprocFactorReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.mapreduce.Reducer;
7 |
8 | import comparison.dsgd.MatrixItem;
9 |
10 | public class DSGDPreprocFactorReducer extends Reducer<MatrixItem, NullWritable, MatrixItem, NullWritable> {
11 |
12 | NullWritable nw = NullWritable.get();
13 | MatrixItem matItem = new MatrixItem();
14 |
15 | @Override
16 | protected void setup(Context context) throws IOException,
17 | InterruptedException {
18 | // Emit Factor Items (on the reduce side be sure to ignore duplicates
19 | // that may occur due to multiple mappers)
20 | int numUsers = Integer.parseInt(context.getConfiguration().get(
21 | "numUsers"));
22 | int numItems = Integer.parseInt(context.getConfiguration().get(
23 | "numItems"));
24 | int kValue = Integer.parseInt(context.getConfiguration().get("kValue"));
25 | for (int i = 0; i < numUsers; ++i) {
26 | for (int j = 0; j < kValue; ++j) {
27 | matItem.set(i, j, Math.random(), MatrixItem.U_MATRIX.toString());
28 | context.write(matItem, nw);
29 | }
30 | }
31 | for (int i = 0; i < numItems; ++i) {
32 | for (int j = 0; j < kValue; ++j) {
33 | matItem.set(i, j, Math.random(), MatrixItem.M_MATRIX.toString());
34 | context.write(matItem, nw);
35 | }
36 | }
37 | super.setup(context);
38 | }
39 |
40 | 	protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context)
41 | throws IOException, InterruptedException {
42 |
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDPreprocRatingsReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.NullWritable;
6 | import org.apache.hadoop.mapreduce.Reducer;
7 |
8 | import comparison.dsgd.MatrixItem;
9 |
10 | public class DSGDPreprocRatingsReducer extends Reducer<MatrixItem, NullWritable, MatrixItem, NullWritable> {
11 |
12 | NullWritable nw = NullWritable.get();
13 |
14 | 	protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context)
15 | throws IOException, InterruptedException {
16 |
17 | context.write(key, nw);
18 |
19 | }
20 |
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDPreprocReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | public class DSGDPreprocReducer {
4 |
5 | }
6 |
--------------------------------------------------------------------------------
/src/main/java/comparison/dsgd/reducer/DSGDRmseReducer.java:
--------------------------------------------------------------------------------
1 | package comparison.dsgd.reducer;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.DoubleWritable;
6 | import org.apache.hadoop.io.NullWritable;
7 | import org.apache.hadoop.mapreduce.Reducer;
8 |
9 | import comparison.dsgd.MatrixItem;
10 | import comparison.dsgd.MatrixUtils;
11 | import comparison.dsgd.MatrixUtils.MatrixException;
12 |
13 | public class DSGDRmseReducer extends
14 | 		Reducer<MatrixItem, NullWritable, DoubleWritable, DoubleWritable> {
15 |
16 | double[][] UMatrix;
17 | double[][] MMatrix;
18 | int kValue;
19 | double timeElapsed;
20 | int numUsers;
21 | int numItems;
22 | double sqError = 0;
23 | long numRatings = 0;
24 |
25 | @Override
26 | protected void cleanup(Context context) throws IOException,
27 | InterruptedException {
28 | // Calculate RMSE from Sq Error and emit along with time elapsed
29 | 		double rmse = Math.sqrt(sqError / (double)numRatings);
30 | context.write(new DoubleWritable(timeElapsed), new DoubleWritable(rmse));
31 |
32 | super.cleanup(context);
33 | }
34 |
35 | @Override
36 | protected void setup(Context context) throws IOException,
37 | InterruptedException {
40 | kValue = Integer.parseInt(context.getConfiguration().get("kValue"));
41 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers"));
42 | numItems = Integer.parseInt(context.getConfiguration().get("numItems"));
43 | timeElapsed = Double.parseDouble(context.getConfiguration().get("timeElapsed"));
44 |
45 | UMatrix = new double[numUsers][kValue];
46 | MMatrix = new double[numItems][kValue];
47 |
48 | super.setup(context);
49 | }
50 |
51 | @Override
52 | 	protected void reduce(MatrixItem key, Iterable<NullWritable> values,
53 | Context context) throws IOException, InterruptedException {
54 |
55 | if (key.isFactorItem()) {
56 | if (key.getMatrixType().equals(MatrixItem.U_MATRIX)) {
57 | UMatrix[key.getRow().get()][key.getColumn().get()] = key
58 | .getValue().get();
59 | } else {
62 | MMatrix[key.getRow().get()][key.getColumn().get()] = key
63 | .getValue().get();
64 | }
65 | } else { // If we are now processing RatingsItems then we must have the
66 | // full Factor Matrices
67 | int i = key.getRow().get();
68 | int j = key.getColumn().get();
69 |
70 | double[] UVector = UMatrix[i];
71 | double[] MVector = MMatrix[j];
72 |
73 | double prediction = MatrixUtils.dotProduct(UVector, MVector);
74 |
75 | sqError += ((prediction - key.getValue().get()) * (prediction - key
76 | .getValue().get()));
77 | numRatings++;
78 | }
79 | }
80 |
81 | }
82 |
--------------------------------------------------------------------------------
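
The cleanup() above computes the standard RMSE = sqrt(sum of squared residuals / numRatings), matching the formula in TestPredictions. A worked check on three toy residuals:

    // Worked check of the RMSE formula used in DSGDRmseReducer.cleanup().
    public class RmseCheck {
        public static void main(String[] args) {
            double[] residuals = {0.5, -1.0, 0.25};  // hypothetical prediction - rating values
            double sqError = 0;
            for (double r : residuals) sqError += r * r;         // 0.25 + 1.0 + 0.0625 = 1.3125
            double rmse = Math.sqrt(sqError / residuals.length); // sqrt(0.4375) ~= 0.6614
            System.out.println("RMSE = " + rmse);
        }
    }
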
/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootLogger=INFO, A1
2 | log4j.appender.A1=org.apache.log4j.ConsoleAppender
3 | log4j.appender.A1.layout=org.apache.log4j.PatternLayout
4 | log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
5 |
6 | log4j.logger.backtype.storm.daemon=WARN
7 | log4j.logger.backtype.storm.serialization=WARN
8 | log4j.logger.backtype.storm.zookeeper=WARN
9 | log4j.logger.org.apache.zookeeper=ERROR
10 |
--------------------------------------------------------------------------------
/src/main/resources/lotr.txt:
--------------------------------------------------------------------------------
1 | Three Rings for the Elven-kings under the sky,
2 | Seven for the Dwarf-lords in their halls of stone,
3 | Nine for Mortal Men doomed to die,
4 | One for the Dark Lord on his dark throne
5 | In the Land of Mordor where the Shadows lie.
6 | One Ring to rule them all, One Ring to find them,
7 | One Ring to bring them all and in the darkness bind them
8 | In the Land of Mordor where the Shadows lie.
--------------------------------------------------------------------------------
/src/test/java/collabstream/AppTest.java:
--------------------------------------------------------------------------------
1 | package collabstream;
2 |
3 | import junit.framework.Test;
4 | import junit.framework.TestCase;
5 | import junit.framework.TestSuite;
6 |
7 | /**
8 | * Unit test for simple App.
9 | */
10 | public class AppTest
11 | extends TestCase
12 | {
13 | /**
14 | * Create the test case
15 | *
16 | * @param testName name of the test case
17 | */
18 | public AppTest( String testName )
19 | {
20 | super( testName );
21 | }
22 |
23 | /**
24 | * @return the suite of tests being tested
25 | */
26 | public static Test suite()
27 | {
28 | return new TestSuite( AppTest.class );
29 | }
30 |
31 | /**
32 |     * Rigorous Test :-)
33 | */
34 | public void testApp()
35 | {
36 | assertTrue( true );
37 | }
38 | }
39 |
--------------------------------------------------------------------------------