├── .gitignore ├── logos ├── UofT.jpg ├── vector.jpg └── layer6ai-logo.png ├── .settings ├── org.eclipse.m2e.core.prefs ├── org.eclipse.core.resources.prefs └── org.eclipse.jdt.core.prefs ├── src └── main │ └── java │ ├── main │ ├── Data.java │ ├── Track.java │ ├── ParsedData.java │ ├── Song.java │ ├── Playlist.java │ ├── Latents.java │ ├── DataLoader.java │ ├── SVD.java │ ├── Executor.java │ ├── SVDModel.java │ ├── RecSysSplitter.java │ └── ParsedDataLoader.java │ └── common │ ├── MutableFloat.java │ ├── ResultCF.java │ ├── MLRandomUtils.java │ ├── EvaluatorClicks.java │ ├── EvaluatorRPrecision.java │ ├── LowLevelRoutines.java │ ├── EvaluatorBinaryNDCG.java │ ├── EvaluatorRPrecisionArtist.java │ ├── MLTimer.java │ ├── MLConcurrentUtils.java │ ├── MLDenseVector.java │ ├── MLIOUtils.java │ ├── MLSparseMatrix.java │ ├── EvaluatorCF.java │ ├── MLXGBoost.java │ ├── MLTextTransform.java │ ├── MLMatrixElement.java │ ├── MLFeatureTransform.java │ ├── FloatElement.java │ ├── ALS.java │ ├── SplitterCF.java │ ├── MLDenseMatrix.java │ └── MLSparseVector.java ├── .project ├── .classpath ├── pom.xml ├── script └── svd_py.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | .settings/org.scala* 3 | *.log 4 | .cache* 5 | .idea 6 | *.iml -------------------------------------------------------------------------------- /logos/UofT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/RecSys2018/HEAD/logos/UofT.jpg -------------------------------------------------------------------------------- /logos/vector.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/RecSys2018/HEAD/logos/vector.jpg -------------------------------------------------------------------------------- /logos/layer6ai-logo.png: 
/**
 * Immutable (song, position) pair representing one track entry of a playlist.
 * Serializable so parsed playlists can be cached to disk.
 */
public class Track implements Serializable {

    private static final long serialVersionUID = -4185883780088342841L;

    // IMPROVED: fields made final -- the class exposes no setters and is a
    // pure value object, so immutability is now enforced by the compiler.
    private final int songIndex;
    // position of this track inside its playlist
    // NOTE(review): presumably 0-based as it comes from an array position --
    // confirm against the loader that constructs Track instances.
    private final int songPos;

    /**
     * @param songIndexP index of the song in the global song table
     * @param songPosP   position of this track within its playlist
     */
    public Track(final int songIndexP, final int songPosP) {
        this.songIndex = songIndexP;
        this.songPos = songPosP;
    }

    /** @return index of the song in the global song table */
    public int getSongIndex() {
        return this.songIndex;
    }

    /** @return position of this track within its playlist */
    public int getSongPos() {
        return this.songPos;
    }
}
/**
 * Static helpers around a caller-supplied {@link Random}: uniform float
 * draws and in-place Fisher-Yates shuffles for primitive and object arrays.
 */
public class MLRandomUtils {

    /** Uniform draw from [min, max) using the supplied generator. */
    public static float nextFloat(final float min, final float max,
            final Random rng) {
        final float span = max - min;
        return min + span * rng.nextFloat();
    }

    /** In-place Fisher-Yates shuffle of an int array. */
    public static void shuffle(int[] array, final Random rng) {
        for (int pos = array.length - 1; pos > 0; pos--) {
            final int pick = rng.nextInt(pos + 1);
            final int tmp = array[pick];
            array[pick] = array[pos];
            array[pos] = tmp;
        }
    }

    /**
     * In-place Fisher-Yates shuffle restricted to
     * array[startInclusive, endExclusive).
     */
    public static void shuffle(Object[] array, int startInclusive,
            int endExclusive, final Random rng) {
        final int len = endExclusive - startInclusive;

        for (int off = len - 1; off > 0; off--) {
            final int pick = rng.nextInt(off + 1) + startInclusive;
            final int pos = off + startInclusive;
            final Object tmp = array[pick];
            array[pick] = array[pos];
            array[pos] = tmp;
        }
    }

    /** In-place shuffle of the entire object array. */
    public static void shuffle(Object[] array, final Random rng) {
        shuffle(array, 0, array.length, rng);
    }

    /** Returns a shuffled copy; the input array is left untouched. */
    public static int[] shuffleCopy(int[] array, final Random rng) {
        final int[] shuffled = array.clone();
        shuffle(shuffled, rng);
        return shuffled;
    }
}
/**
 * Metadata of one track as parsed from the Million Playlist Dataset JSON;
 * field names deliberately mirror the JSON keys. Serializable so the
 * parsed dataset can be cached to disk.
 */
public class Song implements Serializable {

    private static final long serialVersionUID = 6265625137029257218L;
    // NOTE: field names match the JSON keys and are part of the serialized
    // layout (fixed serialVersionUID) -- do not rename.
    private String artist_name;
    private String track_uri;
    private String artist_uri;
    private String track_name;
    private String album_uri;
    private int duration_ms;
    private String album_name;

    /**
     * Extracts all song fields from one "tracks" entry of the MPD JSON.
     * NOTE(review): assumes "duration_ms" is always present and numeric;
     * getAsNumber(...) would NPE otherwise -- TODO confirm against the dump.
     */
    public Song(final JSONObject obj) {
        this.artist_name = obj.getAsString("artist_name");
        this.track_uri = obj.getAsString("track_uri");
        this.artist_uri = obj.getAsString("artist_uri");
        this.track_name = obj.getAsString("track_name");
        this.album_uri = obj.getAsString("album_uri");
        this.duration_ms = obj.getAsNumber("duration_ms").intValue();
        this.album_name = obj.getAsString("album_name");
    }

    public String get_artist_name() {
        return this.artist_name;
    }

    public String get_track_uri() {
        return this.track_uri;
    }

    public String get_artist_uri() {
        return this.artist_uri;
    }

    public String get_track_name() {
        return this.track_name;
    }

    public String get_album_uri() {
        return this.album_uri;
    }

    public int get_duration_ms() {
        return this.duration_ms;
    }

    public String get_album_name() {
        return this.album_name;
    }

}
-------------------------------------------------------------------------------- /src/main/java/common/EvaluatorClicks.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.util.Arrays; 4 | import java.util.concurrent.atomic.AtomicInteger; 5 | import java.util.stream.IntStream; 6 | 7 | public class EvaluatorClicks extends EvaluatorCF { 8 | 9 | public EvaluatorClicks(int[] evalThreshsP) { 10 | super(evalThreshsP); 11 | } 12 | 13 | @Override 14 | public ResultCF evaluate(final SplitterCF split, 15 | final String interactionType, final FloatElement[][] preds) { 16 | 17 | double[] clicks = new double[] { 0.0 }; 18 | int maxThresh = this.getMaxEvalThresh(); 19 | int[] validRowIndexes = split.getValidRowIndexes(); 20 | MLSparseMatrix validMatrix = split.getRsvalid().get(interactionType); 21 | AtomicInteger nTotal = new AtomicInteger(0); 22 | IntStream.range(0, validRowIndexes.length).parallel().forEach(index -> { 23 | 24 | int rowIndex = validRowIndexes[index]; 25 | MLSparseVector row = validMatrix.getRow(rowIndex); 26 | FloatElement[] rowPreds = preds[rowIndex]; 27 | 28 | if (row == null || rowPreds == null) { 29 | return; 30 | } 31 | nTotal.incrementAndGet(); 32 | int[] targetIndexes = row.getIndexes(); 33 | 34 | int matchIndex = (int) Math.floor(maxThresh / 10.0) + 1; 35 | 36 | for (int i = 0; i < maxThresh; i++) { 37 | if (Arrays.binarySearch(targetIndexes, 38 | rowPreds[i].getIndex()) >= 0) { 39 | matchIndex = (int) Math.floor(i / 10.0); 40 | break; 41 | } 42 | } 43 | 44 | synchronized (clicks) { 45 | clicks[0] += matchIndex; 46 | } 47 | }); 48 | 49 | clicks[0] = clicks[0] / nTotal.get(); 50 | return new ResultCF("clicks", clicks, nTotal.get()); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/common/EvaluatorRPrecision.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 
3 | import java.util.Arrays; 4 | import java.util.concurrent.atomic.AtomicInteger; 5 | import java.util.stream.IntStream; 6 | 7 | public class EvaluatorRPrecision extends EvaluatorCF { 8 | 9 | public EvaluatorRPrecision(int[] evalThreshsP) { 10 | super(evalThreshsP); 11 | } 12 | 13 | @Override 14 | public ResultCF evaluate(final SplitterCF split, 15 | final String interactionType, final FloatElement[][] preds) { 16 | 17 | double[] rPrecision = new double[] { 0.0 }; 18 | MLSparseMatrix validMatrix = split.getRsvalid().get(interactionType); 19 | AtomicInteger nTotal = new AtomicInteger(0); 20 | int[] validRowIndexes = split.getValidRowIndexes(); 21 | IntStream.range(0, validRowIndexes.length).parallel().forEach(index -> { 22 | 23 | int rowIndex = validRowIndexes[index]; 24 | MLSparseVector row = validMatrix.getRow(rowIndex); 25 | FloatElement[] rowPreds = preds[rowIndex]; 26 | 27 | if (row == null || rowPreds == null) { 28 | return; 29 | } 30 | nTotal.incrementAndGet(); 31 | 32 | double nMatched = 0; 33 | int[] targetIndexes = row.getIndexes(); 34 | for (int i = 0; i < Math.min(targetIndexes.length, 35 | rowPreds.length); i++) { 36 | if (Arrays.binarySearch(targetIndexes, 37 | rowPreds[i].getIndex()) >= 0) { 38 | nMatched++; 39 | } 40 | } 41 | synchronized (rPrecision) { 42 | rPrecision[0] += nMatched 43 | / Math.min(targetIndexes.length, rowPreds.length); 44 | } 45 | }); 46 | 47 | rPrecision[0] = rPrecision[0] / nTotal.get(); 48 | return new ResultCF("r-precision", rPrecision, nTotal.get()); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/common/LowLevelRoutines.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import org.netlib.util.intW; 4 | 5 | import com.github.fommil.netlib.BLAS; 6 | import com.github.fommil.netlib.LAPACK; 7 | 8 | public class LowLevelRoutines { 9 | 10 | public static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 
/**
 * Binary-relevance NDCG evaluator: every held-out item has gain 1, and the
 * NDCG is reported at each configured evaluation threshold in one pass over
 * the ranked predictions.
 */
public class EvaluatorBinaryNDCG extends EvaluatorCF {

    public EvaluatorBinaryNDCG(int[] evalThreshsP) {
        super(evalThreshsP);
    }

    @Override
    public ResultCF evaluate(final SplitterCF split,
            final String interactionType, final FloatElement[][] preds) {

        // one NDCG accumulator per requested threshold
        double[] ndcg = new double[this.threshs.length];

        int maxThresh = this.getMaxEvalThresh();
        // maps 0-based rank (thresh - 1) -> output slot, so the running
        // DCG/IDCG ratio can be snapshotted exactly when each cut-off rank
        // is reached inside the loop below
        Map<Integer, Integer> threshToIndex = new HashMap<>();
        for (int i = 0; i < this.threshs.length; i++) {
            threshToIndex.put(this.threshs[i] - 1, i);
        }

        int[] validRowIndexes = split.getValidRowIndexes();

        MLSparseMatrix validMatrix = split.getRsvalid().get(interactionType);
        AtomicInteger nTotal = new AtomicInteger(0);
        IntStream.range(0, validRowIndexes.length).parallel().forEach(index -> {

            int rowIndex = validRowIndexes[index];
            MLSparseVector row = validMatrix.getRow(rowIndex);
            FloatElement[] rowPreds = preds[rowIndex];

            if (row == null || rowPreds == null) {
                // no targets or no predictions for this row: skip entirely
                return;
            }
            nTotal.incrementAndGet();
            int[] targetIndexes = row.getIndexes();

            double dcg = 0;
            double idcg = 0;
            // single pass over the ranked list, accumulating achieved DCG
            // and ideal DCG together
            // NOTE(review): assumes rowPreds.length >= maxThresh -- TODO
            // confirm the prediction generator always fills that many slots
            for (int i = 0; i < maxThresh; i++) {
                if (Arrays.binarySearch(targetIndexes,
                        rowPreds[i].getIndex()) >= 0) {
                    // prediction DCG; rank 0 contributes 1.0, which equals
                    // 1/log2(0 + 2) anyway, so the branch is just explicit
                    if (i == 0) {
                        dcg += 1.0;
                    } else {
                        dcg += 1.0 / log2(i + 2.0);
                    }
                }
                if (i < targetIndexes.length) {
                    // ideal DCG: a perfect ranking puts all targets first
                    if (i == 0) {
                        idcg += 1.0;
                    } else {
                        idcg += 1.0 / log2(i + 2.0);
                    }
                }

                if (threshToIndex.containsKey(i) == true) {
                    // NOTE(review): assumes every evaluated row has at least
                    // one target; otherwise idcg is still 0 here and the
                    // ratio is NaN -- TODO confirm the splitter guarantees it
                    synchronized (ndcg) {
                        ndcg[threshToIndex.get(i)] += dcg / idcg;
                    }
                }
            }
        });

        int nEval = nTotal.get();
        for (int i = 0; i < ndcg.length; i++) {
            ndcg[i] /= nTotal.get();
        }
        return new ResultCF("b-ndcg", ndcg, nEval);
    }

    /** Base-2 logarithm. */
    private static double log2(final double in) {
        return Math.log(in) / Math.log(2.0);
    }

}
import argparse
from scipy.sparse import csr_matrix
import numpy as np
from sklearn.utils.extmath import randomized_svd

# Note: this script is only invoked from the Java pipeline via a command-line
# call; it is not part of the Java package itself.


def check_int_positive(value):
    """argparse type: strictly positive int (used for rank / iterations)."""
    ivalue = int(value)
    # BUG FIX: was `< 0`, which silently accepted 0 despite the error
    # message; 0 is meaningless for both rank and iteration count, and this
    # now matches check_float_positive below.
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
    return ivalue


def check_float_positive(value):
    """argparse type: strictly positive float."""
    ivalue = float(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive float value" % value)
    return ivalue


def shape(s):
    """argparse type: one dimension of the sparse-matrix shape."""
    try:
        return int(s)
    # narrowed from a bare `except:` so real errors are not swallowed
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError("Sparse matrix shape must be integer")


def load_csv(path, name, shape):
    """Load a (row, col, value) CSV into a CSR matrix of the given shape."""
    data = np.genfromtxt(path + name, delimiter=',')
    matrix = csr_matrix((data[:, 2], (data[:, 0].astype('int32'), data[:, 1].astype('int32'))), shape=shape)
    return matrix


def save_np(matrix, path, name):
    """Write a dense array as CSV with 5-decimal floats."""
    np.savetxt(path + name, matrix, delimiter=',', fmt='%.5f')


def main(args):
    print("Reading CSV")
    matrix_input = load_csv(path=args.path, name=args.train, shape=args.shape)
    print("Perform SVD")
    P, sigma, Qt = randomized_svd(matrix_input,
                                  n_components=args.rank,
                                  n_iter=args.iter,
                                  power_iteration_normalizer='LU',
                                  random_state=1)

    # fold the singular values into the row factors: U * diag(sigma)
    PtimesS = np.matmul(P, np.diag(sigma))
    # BUG FIX: the four prints below used Python 2 statement syntax
    # (`print "..."`), a SyntaxError under Python 3, while the rest of the
    # file already used print(...); normalized to the function form.
    print("computed P*S")

    save_np(PtimesS, args.path, args.user)
    print("saved P*S")

    save_np(Qt.T, args.path, args.item)
    print("saved Q")

    save_np(sigma, args.path, args.sigm)
    print("saved s")


if __name__ == "__main__":
    # Commandline arguments
    parser = argparse.ArgumentParser(description="SVD")
    parser.add_argument('-i', dest='iter', type=check_int_positive, default=4)
    parser.add_argument('-r', dest='rank', type=check_int_positive, default=100)
    parser.add_argument('-d', dest='path', default="/media/wuga/Storage/python_project/wlrec/IMPLEMENTATION_Projected_LRec/data/")
    parser.add_argument('-f', dest='train', default='matrix.csv')
    parser.add_argument('-u', dest='user', default='U.nd')
    parser.add_argument('-v', dest='item', default='V.nd')
    parser.add_argument('-s', dest='sigm', default='S.nd')
    parser.add_argument('--shape', help="CSR Shape", dest="shape", type=shape, nargs=2)
    args = parser.parse_args()

    main(args)
songArtist.getRow(songIndex); 45 | if (artist == null) { 46 | continue; 47 | } 48 | 49 | for (int artistIndex : artist.getIndexes()) { 50 | artistIndexes.add(artistIndex); 51 | } 52 | } 53 | 54 | // set of artist Indexes that's already matched, since it 55 | // only counts once 56 | Set artistIndexes_already_matched = new HashSet(); 57 | 58 | for (int i = 0; i < Math.min(targetIndexes.length, 59 | rowPreds.length); i++) { 60 | if (Arrays.binarySearch(targetIndexes, 61 | rowPreds[i].getIndex()) >= 0) { 62 | nMatched++; 63 | } else { 64 | int artistIndex = songArtist.getRow(rowPreds[i].getIndex()) 65 | .getIndexes()[0]; 66 | 67 | if (artistIndexes.contains(artistIndex) 68 | && (!artistIndexes_already_matched 69 | .contains(artistIndex))) { 70 | artistIndexes_already_matched.add(artistIndex); 71 | nMatched = nMatched + 0.25; 72 | } 73 | } 74 | } 75 | synchronized (rPrecision) { 76 | rPrecision[0] += nMatched 77 | / Math.min(targetIndexes.length, rowPreds.length); 78 | } 79 | }); 80 | 81 | rPrecision[0] = rPrecision[0] / nTotal.get(); 82 | return new ResultCF("r-precision-artist", rPrecision, nTotal.get()); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/main/Playlist.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import net.minidev.json.JSONObject; 4 | 5 | import java.io.Serializable; 6 | 7 | public class Playlist implements Serializable { 8 | 9 | private static final long serialVersionUID = 6071968443385525640L; 10 | public String name; 11 | public Boolean collaborative; 12 | public String pid; 13 | public Long modified_at; 14 | public Integer num_albums; 15 | public Integer num_tracks; 16 | public Integer num_followers; 17 | public Integer num_edits; 18 | public Integer duration_ms; 19 | public Integer num_artists; 20 | public Track[] tracks; 21 | 22 | public Playlist(final JSONObject obj) { 23 | this.pid = obj.getAsString("pid"); 24 | 
this.name = obj.getAsString("name"); 25 | 26 | if (obj.containsKey("collaborative") == true) { 27 | this.collaborative = obj.getAsString("collaborative").toLowerCase() 28 | .equals("true"); 29 | } 30 | if (obj.containsKey("modified_at") == true) { 31 | this.modified_at = obj.getAsNumber("modified_at").longValue(); 32 | } 33 | if (obj.containsKey("num_albums") == true) { 34 | this.num_albums = obj.getAsNumber("num_albums").intValue(); 35 | } 36 | if (obj.containsKey("num_tracks") == true) { 37 | this.num_tracks = obj.getAsNumber("num_tracks").intValue(); 38 | } 39 | if (obj.containsKey("num_followers") == true) { 40 | this.num_followers = obj.getAsNumber("num_followers").intValue(); 41 | } 42 | if (obj.containsKey("num_edits") == true) { 43 | this.num_edits = obj.getAsNumber("num_edits").intValue(); 44 | } 45 | if (obj.containsKey("duration_ms") == true) { 46 | this.duration_ms = obj.getAsNumber("duration_ms").intValue(); 47 | } 48 | if (obj.containsKey("num_artists") == true) { 49 | this.num_artists = obj.getAsNumber("num_artists").intValue(); 50 | } 51 | } 52 | 53 | public Boolean get_collaborative() { 54 | return this.collaborative; 55 | } 56 | 57 | public Integer get_duration_ms() { 58 | return this.duration_ms; 59 | } 60 | 61 | public Long get_modified_at() { 62 | return this.modified_at; 63 | } 64 | 65 | public String get_name() { 66 | return this.name; 67 | } 68 | 69 | public Integer get_num_albums() { 70 | return this.num_albums; 71 | } 72 | 73 | public Integer get_num_artists() { 74 | return this.num_artists; 75 | } 76 | 77 | public Integer get_num_edits() { 78 | return this.num_edits; 79 | } 80 | 81 | public Integer get_num_followers() { 82 | return this.num_followers; 83 | } 84 | 85 | public Integer get_num_tracks() { 86 | return this.num_tracks; 87 | } 88 | 89 | public String get_pid() { 90 | return this.pid; 91 | } 92 | 93 | public Track[] getTracks() { 94 | return this.tracks; 95 | } 96 | 97 | public void setTracks(final Track[] tracksP) { 98 | 
this.tracks = tracksP; 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /src/main/java/common/MLTimer.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.util.concurrent.TimeUnit; 4 | 5 | import com.google.common.base.Stopwatch; 6 | 7 | public class MLTimer { 8 | 9 | private String name; 10 | private long loopSize; 11 | private Stopwatch timer; 12 | 13 | public MLTimer(final String nameP) { 14 | this.name = nameP; 15 | this.loopSize = 0; 16 | this.timer = Stopwatch.createUnstarted(); 17 | } 18 | 19 | public MLTimer(final String nameP, final int loopSizeP) { 20 | this.name = nameP; 21 | this.loopSize = loopSizeP; 22 | this.timer = Stopwatch.createUnstarted(); 23 | } 24 | 25 | public synchronized void tic() { 26 | this.timer.reset().start(); 27 | } 28 | 29 | public synchronized void toc() { 30 | double elapsedTime = this.timer.elapsed(TimeUnit.MILLISECONDS) / 1000.0; 31 | System.out.printf("%s: elapsed [%s]\n", this.name, 32 | formatSeconds((float) elapsedTime)); 33 | } 34 | 35 | public synchronized void toc(final String message) { 36 | double elapsedTime = this.timer.elapsed(TimeUnit.MILLISECONDS) / 1000.0; 37 | System.out.printf("%s: %s elapsed [%s]\n", this.name, message, 38 | formatSeconds((float) elapsedTime)); 39 | } 40 | 41 | public synchronized void tocLoop(final int curLoop) { 42 | tocLoop("", curLoop); 43 | } 44 | 45 | public synchronized void tocLoop(String message, final int curLoop) { 46 | 47 | double elapsedTime = this.timer.elapsed(TimeUnit.MILLISECONDS) / 1000.0; 48 | double speed = curLoop / elapsedTime; 49 | 50 | if (this.loopSize > 0) { 51 | double remainTime = (this.loopSize - curLoop) / speed; 52 | 53 | System.out.printf( 54 | "%s: %s[%2.2f%%] elapsed [%s] cur_spd [%.0f/s] remain [%s]\n", 55 | this.name, message, (curLoop * 100f) / this.loopSize, 56 | formatSeconds(elapsedTime), speed, 57 | 
formatSeconds(remainTime)); 58 | } else { 59 | System.out.printf("%s: %s [%d] elapsed [%s] cur_spd [%.0f/s]\n", 60 | this.name, message, curLoop, formatSeconds(elapsedTime), 61 | speed); 62 | } 63 | } 64 | 65 | private static String formatSeconds(double secondsF) { 66 | if (secondsF < 0) { 67 | return Double.toString(secondsF); 68 | } 69 | TimeUnit base = TimeUnit.SECONDS; 70 | int s = (int) Math.floor(secondsF); 71 | // float remainder = (float) (secondsF - s); 72 | 73 | long days = base.toDays(s); 74 | s -= TimeUnit.DAYS.toSeconds(days); 75 | long hours = base.toHours(s); 76 | s -= TimeUnit.HOURS.toSeconds(hours); 77 | long minutes = base.toMinutes(s); 78 | s -= TimeUnit.MINUTES.toSeconds(minutes); 79 | long secondsL = base.toSeconds(s); 80 | 81 | StringBuilder sb = new StringBuilder(); 82 | if (days > 0) { 83 | sb.append(days); 84 | sb.append(" days "); 85 | } 86 | if (hours > 0 || days > 0) { 87 | sb.append(hours); 88 | sb.append(" hr "); 89 | } 90 | if (hours > 0 || days > 0 || minutes > 0) { 91 | sb.append(minutes); 92 | sb.append(" min "); 93 | } 94 | sb.append(secondsL + " sec"); 95 | 96 | return sb.toString(); 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/common/MLConcurrentUtils.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.util.HashMap; 4 | import java.util.Iterator; 5 | import java.util.concurrent.*; 6 | import java.util.concurrent.atomic.AtomicInteger; 7 | import java.util.function.*; 8 | 9 | public class MLConcurrentUtils { 10 | public static class Async extends ThreadLocal { 11 | private HashMap refs; 12 | private AtomicInteger threadCounter; 13 | 14 | private final Supplier constructor; 15 | private final Consumer cleaner; 16 | 17 | public Async(Supplier constructorP, Consumer cleanerP) { 18 | this.refs = new HashMap(); 19 | this.threadCounter = new AtomicInteger(0); 20 | this.constructor = 
/**
 * Concurrency helpers.
 *
 * <p>Fixes: restored the generic type parameters (lost in a previous
 * paste/extraction, leaving {@code T} unresolved), and synchronized the
 * {@code refs.put} in {@link Async#initialValue()} on the same lock used by
 * {@link Async#cleanAll()} — unsynchronized concurrent puts into a plain
 * HashMap can corrupt it when several threads first-access simultaneously.
 */
public class MLConcurrentUtils {

    /**
     * A ThreadLocal that remembers every per-thread value it has created so
     * an optional cleaner can later be applied to all of them (e.g. to
     * release native resources held by per-thread objects).
     */
    public static class Async<T> extends ThreadLocal<T> {
        // all values ever handed out, keyed by creation order
        private final HashMap<Integer, T> refs;
        private final AtomicInteger threadCounter;

        private final Supplier<T> constructor;
        private final Consumer<T> cleaner;

        /**
         * @param constructorP creates one value per thread on first access
         * @param cleanerP applied to every created value by cleanAll();
         *            may be null if no cleanup is needed
         */
        public Async(Supplier<T> constructorP, Consumer<T> cleanerP) {
            this.refs = new HashMap<>();
            this.threadCounter = new AtomicInteger(0);
            this.constructor = constructorP;
            this.cleaner = cleanerP;
        }

        /** Applies the cleaner to every value created so far. */
        public void cleanAll() {
            if (this.cleaner != null) {
                synchronized (this.refs) {
                    this.refs.values().forEach(this.cleaner::accept);
                }
            }
        }

        @Override
        protected T initialValue() {
            int localCount = this.threadCounter.getAndIncrement();
            T t = this.constructor.get();
            // BUGFIX: must hold the same lock as cleanAll(); an unguarded
            // HashMap.put from multiple threads is a data race
            synchronized (this.refs) {
                this.refs.put(localCount, t);
            }
            return t;
        }
    }

    /**
     * Bounded queue that refills itself from an iterator on a background
     * executor, so consumers rarely wait for the (possibly slow) source.
     */
    public static class PreloadingQueue<T> implements AutoCloseable {

        // background task that tops the queue up from the source iterator
        private static class QUpdater<T> implements Runnable {

            private final Iterator<T> src;
            private final ArrayBlockingQueue<T> raw;

            private QUpdater(Iterator<T> srcP, ArrayBlockingQueue<T> rawP) {
                src = srcP;
                raw = rawP;
            }

            // synchronously pull one element, used for warm-up
            public void addOneNow() {
                if (raw.remainingCapacity() > 0 && src.hasNext()) {
                    raw.add(src.next());
                }
            }

            public boolean hasMore() {
                synchronized (src) {
                    return src.hasNext();
                }
            }

            @Override
            public void run() {
                // refill until the queue is full or the source is exhausted
                while (raw.remainingCapacity() > 0 && src.hasNext()) {
                    raw.add(src.next());
                }
            }
        }

        private final ArrayBlockingQueue<T> raw;
        private final QUpdater<T> updater;

        private final ExecutorService updatePool;

        public PreloadingQueue(Iterator<T> src, int maxQueueSize) {
            this(src, maxQueueSize, Executors.newFixedThreadPool(1));
        }

        public PreloadingQueue(Iterator<T> src, int maxQueueSize,
                ExecutorService pool) {
            raw = new ArrayBlockingQueue<>(maxQueueSize);
            updater = new QUpdater<>(src, raw);
            updatePool = pool;
        }

        @Override
        public void close() throws Exception {
            if (this.updatePool != null) {
                this.updatePool.shutdownNow();
            }
        }

        /** True while the queue or the underlying source has elements left. */
        public boolean hasMore() {
            return raw.isEmpty() == false || updater.hasMore();
        }

        /**
         * Takes the next element, blocking if necessary, and schedules a
         * background refill.
         */
        public T pop() throws InterruptedException {
            T data = raw.take();
            requestUpdate();
            return data;
        }

        private void requestUpdate() {
            updatePool.submit(updater);
        }

        /** Synchronously preloads one element so the first pop() won't block. */
        public void warmupOneBlocking() {
            updater.addOneNow();
        }

    }
}
vector.getLength()) { 59 | throw new IllegalArgumentException("vectors must be same length"); 60 | } 61 | 62 | int[] indexesSparse = vector.getIndexes(); 63 | float[] valuesSparse = vector.getValues(); 64 | 65 | float product = 0f; 66 | for (int i = 0; i < indexesSparse.length; i++) { 67 | if (this.values[indexesSparse[i]] != 0) { 68 | product += valuesSparse[i] * this.values[indexesSparse[i]]; 69 | } 70 | } 71 | 72 | return product; 73 | } 74 | 75 | public void scalarDivide(final float f) { 76 | // divide by a scalar 77 | for (int i = 0; i < this.values.length; i++) { 78 | this.values[i] = this.values[i] / f; 79 | } 80 | } 81 | 82 | public void scalarMult(final float f) { 83 | // multiply by a scalar 84 | for (int i = 0; i < this.values.length; i++) { 85 | this.values[i] = this.values[i] * f; 86 | } 87 | } 88 | 89 | public void setValues(final float[] values) { 90 | this.values = values; 91 | } 92 | 93 | public float sum() { 94 | float sum = 0; 95 | for (float value : values) { 96 | sum += value; 97 | } 98 | 99 | return sum; 100 | } 101 | 102 | public MLSparseVector toSparse() { 103 | int nnz = 0; 104 | for (float value : this.values) { 105 | if (value != 0) { 106 | nnz++; 107 | } 108 | } 109 | if (nnz == 0) { 110 | return new MLSparseVector(null, null, null, this.getLength()); 111 | } 112 | 113 | int[] indexes = new int[nnz]; 114 | float[] values = new float[nnz]; 115 | int cur = 0; 116 | for (int i = 0; i < this.values.length; i++) { 117 | if (this.values[i] != 0) { 118 | indexes[cur] = i; 119 | values[cur] = this.values[i]; 120 | cur++; 121 | } 122 | } 123 | 124 | return new MLSparseVector(indexes, values, null, this.getLength()); 125 | } 126 | 127 | } 128 | -------------------------------------------------------------------------------- /src/main/java/main/Latents.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import common.MLDenseMatrix; 4 | import main.ParsedData.PlaylistFeature; 5 | import 
main.ParsedData.SongFeature; 6 | 7 | public class Latents { 8 | 9 | public MLDenseMatrix U; 10 | public MLDenseMatrix V; 11 | 12 | public MLDenseMatrix Ucnn; 13 | public MLDenseMatrix Vcnn; 14 | 15 | public MLDenseMatrix Uname; 16 | public MLDenseMatrix Vname; 17 | public MLDenseMatrix name; 18 | 19 | public MLDenseMatrix Uartist; 20 | public MLDenseMatrix Vartist; 21 | public MLDenseMatrix artist; 22 | 23 | public MLDenseMatrix Ualbum; 24 | public MLDenseMatrix Valbum; 25 | public MLDenseMatrix album; 26 | 27 | public Latents() { 28 | 29 | } 30 | 31 | public Latents(final ParsedData data) throws Exception { 32 | String dataPath = "/media/mvolkovs/external4TB/Data/recsys2018"; 33 | 34 | int rankWarm = 200; 35 | this.U = MLDenseMatrix 36 | .fromFile( 37 | dataPath + "/models/latent_song/matching_name_U_" 38 | + rankWarm + ".bin", 39 | data.interactions.getNRows(), rankWarm); 40 | 41 | this.V = MLDenseMatrix 42 | .fromFile( 43 | dataPath + "/models/latent_song/matching_name_V_" 44 | + rankWarm + ".bin", 45 | data.interactions.getNCols(), rankWarm); 46 | 47 | int rankWarmCNN = 200; 48 | this.Ucnn = MLDenseMatrix 49 | .fromFile( 50 | dataPath + "/models/latent_song/matching_CNN_v2_U_" 51 | + rankWarm + ".bin", 52 | data.interactions.getNRows(), rankWarmCNN); 53 | 54 | this.Vcnn = MLDenseMatrix 55 | .fromFile( 56 | dataPath + "/models/latent_song/matching_CNN_v2_V_" 57 | + rankWarm + ".bin", 58 | data.interactions.getNCols(), rankWarmCNN); 59 | 60 | int rankName = 200; 61 | this.Uname = MLDenseMatrix.fromFile( 62 | dataPath + "/models/latent_name/name_U_" + rankName + ".bin", 63 | data.interactions.getNRows(), rankName); 64 | this.Vname = MLDenseMatrix.fromFile( 65 | dataPath + "/models/latent_name/name_V_" + rankName + ".bin", 66 | data.interactions.getNCols(), rankName); 67 | this.name = MLDenseMatrix.fromFile( 68 | dataPath + "/models/latent_name/name_" + rankName + ".bin", 69 | data.playlistFeatures.get(PlaylistFeature.NAME_REGEXED) 70 | .getCatToIndex().size(), 71 
| rankName); 72 | 73 | int rankArtist = 200; 74 | this.Uartist = MLDenseMatrix 75 | .fromFile( 76 | dataPath + "/models/latent_artist/artist_U_" 77 | + rankArtist + ".bin", 78 | data.interactions.getNRows(), rankArtist); 79 | this.Vartist = MLDenseMatrix 80 | .fromFile( 81 | dataPath + "/models/latent_artist/artist_V_" 82 | + rankArtist + ".bin", 83 | data.interactions.getNCols(), rankArtist); 84 | this.artist = MLDenseMatrix.fromFile( 85 | dataPath + "/models/latent_artist/artist_" + rankArtist 86 | + ".bin", 87 | data.songFeatures.get(SongFeature.ARTIST_ID).getCatToIndex() 88 | .size(), 89 | rankArtist); 90 | 91 | int rankAlbum = 200; 92 | this.Ualbum = MLDenseMatrix.fromFile( 93 | dataPath + "/models/latent_album/album_U_" + rankAlbum + ".bin", 94 | data.interactions.getNRows(), rankAlbum); 95 | this.Valbum = MLDenseMatrix.fromFile( 96 | dataPath + "/models/latent_album/album_V_" + rankAlbum + ".bin", 97 | data.interactions.getNCols(), rankAlbum); 98 | this.album = MLDenseMatrix.fromFile( 99 | dataPath + "/models/latent_album/album_" + rankAlbum + ".bin", 100 | data.songFeatures.get(SongFeature.ALBUM_ID).getCatToIndex() 101 | .size(), 102 | rankAlbum); 103 | 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/main/java/common/MLIOUtils.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.io.BufferedInputStream; 4 | import java.io.BufferedOutputStream; 5 | import java.io.File; 6 | import java.io.FileInputStream; 7 | import java.io.FileOutputStream; 8 | import java.io.IOException; 9 | import java.io.ObjectInputStream; 10 | import java.io.ObjectOutputStream; 11 | import java.io.Serializable; 12 | import java.util.zip.GZIPInputStream; 13 | import java.util.zip.GZIPOutputStream; 14 | 15 | public class MLIOUtils { 16 | 17 | public static T readObjectFromFile( 18 | final String file, Class classType) throws Exception { 19 | if ((new 
/**
 * Java-serialization helpers for reading/writing objects to plain and
 * gzip-compressed files.
 *
 * <p>Fixes: restored the generic {@code <T>} type parameters (lost in a
 * previous paste/extraction), converted manual close-in-finally to
 * try-with-resources, and corrected the exception messages ("doesn't
 * exists", "inti class", and the GZ reader wrongly saying "serialize").
 *
 * <p>NOTE(review): Java native deserialization must never be fed untrusted
 * data; these helpers are intended for the project's own cache files only.
 */
public class MLIOUtils {

    /**
     * Deserializes an object of the given class from a file.
     *
     * @throws Exception if the file is missing, unreadable, or contains an
     *             object of a different class
     */
    public static <T> T readObjectFromFile(final String file,
            Class<T> classType) throws Exception {
        if ((new File(file)).exists() == false) {
            throw new Exception("file does not exist " + file);
        }

        try (ObjectInputStream objectInputStream = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(file)))) {
            Object o = objectInputStream.readObject();

            if (o.getClass().equals(classType) == true) {
                return classType.cast(o);
            } else {
                throw new Exception("failed to deserialize " + file
                        + " into class " + classType.getSimpleName());
            }
        }
    }

    /**
     * Deserializes an object of the given class from a gzip-compressed file.
     *
     * @throws Exception if the file is missing, unreadable, or contains an
     *             object of a different class
     */
    public static <T> T readObjectFromFileGZ(final String file,
            Class<T> classType) throws Exception {
        if ((new File(file)).exists() == false) {
            throw new Exception("file does not exist " + file);
        }

        try (ObjectInputStream objectInputStream = new ObjectInputStream(
                new GZIPInputStream(new BufferedInputStream(
                        new FileInputStream(file))))) {
            Object o = objectInputStream.readObject();

            if (o.getClass().equals(classType) == true) {
                return classType.cast(o);
            } else {
                // BUGFIX: message previously said "failed to serialize"
                throw new Exception("failed to deserialize " + file
                        + " into class " + classType.getSimpleName());
            }
        }
    }

    /** Serializes an object to a file, overwriting any existing content. */
    public static void writeObjectToFile(final Object object, final String file)
            throws IOException {
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(
                new BufferedOutputStream(new FileOutputStream(file)))) {
            objectOutputStream.writeObject(object);
        }
    }

    /** Serializes an object to a gzip-compressed file. */
    public static void writeObjectToFileGZ(final Object object,
            final String file) throws IOException {
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(
                new GZIPOutputStream(new BufferedOutputStream(
                        new FileOutputStream(file))))) {
            objectOutputStream.writeObject(object);
        }
    }

}
AtomicInteger parsedSongCounter = new AtomicInteger(0); 35 | 36 | List playlists = new ArrayList(); 37 | List songs = new ArrayList(); 38 | List testIndexes = new ArrayList(); 39 | MLTimer timer = new MLTimer("load"); 40 | timer.tic(); 41 | 42 | for (int f = 0; f <= listOfFiles.length; f++) { 43 | File file; 44 | if (f >= listOfFiles.length) { 45 | timer.toc("test file " + testFileName); 46 | file = testFile; 47 | } else { 48 | file = listOfFiles[f]; 49 | } 50 | 51 | try (BufferedReader reader = new BufferedReader( 52 | new FileReader(file))) { 53 | JSONObject obj = (JSONObject) JSONValue.parse(reader); 54 | JSONArray list = (JSONArray) obj.get("playlists"); 55 | for (int l = 0; l < list.size(); l++) { 56 | uniquePlaylistCounter.incrementAndGet(); 57 | if (f >= listOfFiles.length) { 58 | testIndexes.add(uniquePlaylistCounter.get() - 1); 59 | } 60 | 61 | Object data = list.get(l); 62 | Playlist playlist = new Playlist((JSONObject) data); 63 | 64 | Object tracksObj = ((JSONObject) data).get("tracks"); 65 | if (tracksObj != null && tracksObj instanceof JSONArray) { 66 | JSONArray tracksArray = (JSONArray) tracksObj; 67 | Track[] tracks = new Track[tracksArray.size()]; 68 | for (int i = 0; i < tracksArray.size(); i++) { 69 | JSONObject songObj = (JSONObject) tracksArray 70 | .get(i); 71 | Song song = new Song(songObj); 72 | Integer songIndex = songIdToIndex 73 | .get(song.get_track_uri()); 74 | if (songIndex == null) { 75 | songIndex = uniqueSongCounter.getAndIncrement(); 76 | songIdToIndex.put(song.get_track_uri(), 77 | songIndex); 78 | songs.add(song); 79 | } 80 | tracks[i] = new Track(songIndex, 81 | songObj.getAsNumber("pos").intValue()); 82 | parsedSongCounter.incrementAndGet(); 83 | } 84 | playlist.setTracks(tracks); 85 | } 86 | playlists.add(playlist); 87 | } 88 | 89 | if ((f + 1) % 10 == 0) { 90 | timer.tocLoop(String.format( 91 | "playlists[%d] unique songs[%d] total songs[%d]", 92 | playlists.size(), songs.size(), 93 | parsedSongCounter.get()), 
parsedSongCounter.get()); 94 | } 95 | } 96 | } 97 | 98 | System.out.printf( 99 | "FINISHED PARSING: playlists[%d] unique songs[%d] total songs[%d]", 100 | playlists.size(), songs.size(), parsedSongCounter.get()); 101 | Data data = new Data(); 102 | 103 | data.playlists = new Playlist[playlists.size()]; 104 | playlists.toArray(data.playlists); 105 | 106 | data.testIndexes = testIndexes.stream().mapToInt(i -> i).toArray(); 107 | 108 | data.songs = new Song[songs.size()]; 109 | songs.toArray(data.songs); 110 | 111 | return data; 112 | 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 |

6 | 7 | ## 2018 ACM RecSys Challenge 1'st Place Solution From Team vl6 8 | 9 | **Team members**: Maksims Volkovs (Layer 6), Himanshu Rai (Layer 6), Zhaoyue Cheng (Layer 6), Yichao Lu (University of Toronto), Ga Wu (University of Toronto, Vector Institute), Scott Sanner (University of Toronto, Vector Institute) 10 | [[paper](http://www.cs.toronto.edu/~mvolkovs/recsys2018_challenge.pdf)][[challenge](http://www.recsyschallenge.com/2018)] 11 | 12 | Contact: maks@layer6.ai 13 | 14 | 15 | 16 | ## Introduction 17 | This repository contains the Java implementation of our entries for both main and creative tracks. Our approach 18 | consists of a two-stage model where in the first stage a blend of collaborative filtering methods is used to 19 | quickly retrieve a set of candidate songs for each playlist with high recall. Then in the second stage a pairwise 20 | playlist-song gradient boosting model is used to re-rank the retrieved candidates and maximize precision at the 21 | top of the recommended list. 22 | 23 | 24 | 25 | ## Environment 26 | The model is implemented in Java and tested on the following environment: 27 | * Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz 28 | * 256GB RAM 29 | * Nvidia Titan V 30 | * Java Oracle 1.8.0_171 31 | * Python, Numpy 1.14.3, Sklearn 0.19.1, Scipy 1.1.0 32 | * Apache Maven 3.3.9 33 | * CUDA 8.0 and CUDNN 8.0 34 | * Intel MKL 2018.1.038 35 | * XGBoost and XGBoost4j 0.7 36 | 37 | 38 | 39 | ## Executing 40 | 41 | All models are executed from `src/main/java/main/Executor.java`, the main function has examples on 42 | how to do main and creative track model training, evaluation and submission. 
To run the model: 43 | 44 | * Set all paths: 45 | ``` 46 | //OAuth token for spotify creative api, if doing creative track submission 47 | String authToken = ""; 48 | 49 | // path to song audio feature file, if doing creative track submission 50 | String creativeTrackFile = "/home/recsys2018/data/song_audio_features.txt"; 51 | 52 | // path to MPD directory with the JSON files 53 | String trainPath = "/home/recsys2018/data/train/"; 54 | 55 | // path to challenge set JSON file 56 | String testFile = "/home/recsys2018/data/test/challenge_set.json"; 57 | 58 | // path to python SVD script included in the repo, default location: script/svd_py.py 59 | String pythonScriptPath = "/home/recsys2018/script/svd_py.py"; 60 | 61 | //path to cache folder for temp storage, at least 20GB should be available in this folder 62 | String cachePath = "/home/recsys2018/cache/"; 63 | ``` 64 | 65 | * Compile and execute with maven: 66 | ``` 67 | export MAVEN_OPTS="-Xms150g -Xmx150g" 68 | mvn clean compile 69 | mvn exec:java -Dexec.mainClass="main.Executor" 70 | ``` 71 | Note that by default the code is executing model for the main track, to run the creative track model set `xgbParams.doCreative = true`. For the creative track we extracted extra song features from the 72 | [Spotify Audio API](https://developer.spotify.com/documentation/web-api/reference/tracks/get-several-audio-features/). We were able to match most songs from the challenge Million Playlist Dataset, and used the following fields for further feature extraction: `[acousticness, danceability, energy, instrumentalness, key, liveness, loudness, mode, speechiness, tempo, time_signature, valence]`. In order to download the data for this track, you need to get the OAuth Token from 73 | [Spotify API page](https://developer.spotify.com/console/get-audio-features-several-tracks/?ids=4JpKVNYnVcJ8tuMKjAj50A,2NRANZE9UCmPAS5XVbXL40,24JygzOLM0EmRQeGtFcIcG) and 74 | assign it to the `authToken` variable in the `Executor.main` function. 
75 | 76 | We prioritized speed over memory for this project so you'll need at least 100GB of RAM to run model training and inference. The full end-to-end runtime takes approximately 1.5 days. 77 | 78 | -------------------------------------------------------------------------------- /src/main/java/main/SVD.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.BufferedWriter; 5 | import java.io.FileReader; 6 | import java.io.IOException; 7 | import java.io.PrintWriter; 8 | import java.lang.reflect.Field; 9 | 10 | import common.MLDenseMatrix; 11 | import common.MLDenseVector; 12 | import common.MLSparseMatrix; 13 | import common.MLSparseVector; 14 | import common.MLTimer; 15 | 16 | public class SVD { 17 | 18 | public static class SVDParams { 19 | public int svdIter = 4; 20 | public int rank = 200; 21 | public String scriptPath; 22 | public String cachePath; 23 | public String cacheName = "matrix.csv"; 24 | public String cachePName = "U.nd"; 25 | public String cacheQName = "V.nd"; 26 | public String cacheSName = "S.nd"; 27 | public int shapeRows; 28 | public int shapeCols; 29 | 30 | @Override 31 | public String toString() { 32 | StringBuilder result = new StringBuilder(); 33 | String newLine = System.getProperty("line.separator"); 34 | 35 | result.append(this.getClass().getName()); 36 | result.append(" {"); 37 | result.append(newLine); 38 | 39 | // determine fields declared in this class only (no fields of 40 | // superclass) 41 | Field[] fields = this.getClass().getDeclaredFields(); 42 | 43 | // print field names paired with their values 44 | for (Field field : fields) { 45 | result.append(" "); 46 | try { 47 | result.append(field.getName()); 48 | result.append(": "); 49 | // requires access to private field: 50 | result.append(field.get(this)); 51 | } catch (IllegalAccessException ex) { 52 | System.out.println(ex); 53 | } 54 | result.append(newLine); 55 | } 56 
| result.append("}"); 57 | 58 | return result.toString(); 59 | } 60 | 61 | } 62 | 63 | public SVDParams params; 64 | public MLDenseMatrix P; 65 | public MLDenseMatrix Q; 66 | public MLDenseVector s; 67 | 68 | public SVD(final SVDParams paramsP) { 69 | this.params = paramsP; 70 | } 71 | 72 | public void runPythonSVD(final MLSparseMatrix matrix) { 73 | 74 | MLTimer timer = new MLTimer("runPythonSVD"); 75 | timer.tic(); 76 | Process process = null; 77 | try { 78 | String command = String.format( 79 | "python %s -r %s -i %s -d %s -f %s --shape %s %s", 80 | this.params.scriptPath, this.params.rank, 81 | this.params.svdIter, this.params.cachePath, 82 | this.params.cacheName, this.params.shapeRows, 83 | this.params.shapeCols); 84 | System.out.println(command); 85 | 86 | toCSV(matrix, this.params.cachePath + this.params.cacheName); 87 | process = Runtime.getRuntime() 88 | .exec(new String[] { "bash", "-c", command }); 89 | process.waitFor(); 90 | 91 | this.P = fromCSV(this.params.cachePath + this.params.cachePName, 92 | this.params.shapeRows); 93 | this.Q = fromCSV(this.params.cachePath + this.params.cacheQName, 94 | this.params.shapeCols); 95 | } catch (Exception e) { 96 | e.printStackTrace(); 97 | 98 | } finally { 99 | if (process != null) { 100 | process.destroy(); 101 | } 102 | } 103 | } 104 | 105 | public static MLDenseMatrix fromCSV(final String inFile, final int nRows) { 106 | 107 | MLTimer timer = new MLTimer("fromCSV"); 108 | timer.tic(); 109 | 110 | MLDenseVector[] rows = new MLDenseVector[nRows]; 111 | int curRow = 0; 112 | try (BufferedReader reader = new BufferedReader( 113 | new FileReader(inFile))) { 114 | String line; 115 | while ((line = reader.readLine()) != null) { 116 | if (curRow % 500_000 == 0) { 117 | timer.tocLoop(curRow); 118 | } 119 | 120 | String[] split = line.split(","); 121 | float[] values = new float[split.length]; 122 | for (int i = 0; i < split.length; i++) { 123 | values[i] = Float.parseFloat(split[i]); 124 | } 125 | rows[curRow] = new 
MLDenseVector(values); 126 | curRow++; 127 | } 128 | 129 | if (curRow != nRows) { 130 | throw new Exception("urRow != nRows"); 131 | } 132 | 133 | } catch (Exception e) { 134 | e.printStackTrace(); 135 | } 136 | 137 | return new MLDenseMatrix(rows); 138 | } 139 | 140 | public static void toCSV(final MLSparseMatrix matrix, final String outFile) 141 | throws IOException { 142 | MLTimer timer = new MLTimer("toCSV"); 143 | timer.tic(); 144 | 145 | try (BufferedWriter writer = new BufferedWriter( 146 | new PrintWriter(outFile, "UTF-8"))) { 147 | int numRows = matrix.getNRows(); 148 | for (int i = 0; i < numRows; i++) { 149 | if (i % 500_000 == 0) { 150 | timer.tocLoop(i); 151 | } 152 | 153 | MLSparseVector row = matrix.getRow(i); 154 | if (row != null) { 155 | int[] indexes = row.getIndexes(); 156 | float[] values = row.getValues(); 157 | for (int j = 0; j < indexes.length; j++) { 158 | writer.write( 159 | i + "," + indexes[j] + "," + values[j] + "\n"); 160 | } 161 | } 162 | } 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /src/main/java/common/MLSparseMatrix.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.io.Serializable; 4 | import java.util.Arrays; 5 | import java.util.Map; 6 | import java.util.stream.IntStream; 7 | 8 | public interface MLSparseMatrix extends Serializable { 9 | 10 | public abstract void applyColNorm(final MLDenseVector colNorm); 11 | 12 | public abstract void applyColSelector( 13 | final Map selectedColMap, 14 | final int nColsSelected); 15 | 16 | public abstract void applyRowNorm(final MLDenseVector rowNorm); 17 | 18 | public abstract void binarizeValues(); 19 | 20 | public abstract MLSparseMatrix deepCopy(); 21 | 22 | public abstract MLDenseVector getColNNZ(); 23 | 24 | public abstract MLDenseVector getColNorm(final int p); 25 | 26 | public abstract MLDenseVector getColSum(); 27 | 28 | public abstract int 
getNCols(); 29 | 30 | public abstract long getNNZ(); 31 | 32 | public abstract int getNRows(); 33 | 34 | public abstract MLSparseVector getRow(final int rowIndex); 35 | 36 | public abstract MLSparseVector getRow(final int rowIndex, 37 | final boolean returnEmpty); 38 | 39 | public abstract MLDenseVector getRowNNZ(); 40 | 41 | public abstract MLDenseVector getRowNorm(final int p); 42 | 43 | public abstract MLDenseVector getRowSum(); 44 | 45 | public abstract boolean hasDates(); 46 | 47 | public abstract void inferAndSetNCols(); 48 | 49 | public abstract MLSparseMatrix mult(final MLSparseMatrix another); 50 | 51 | public abstract MLDenseVector multCol(final MLDenseVector vector); 52 | 53 | public abstract MLDenseVector multCol(final MLSparseVector vector); 54 | 55 | public abstract MLDenseVector multRow(final MLDenseVector vector); 56 | 57 | public abstract MLDenseVector multRow(final MLSparseVector vector); 58 | 59 | public abstract Map selectCols(final int nnzCutOff); 60 | 61 | public abstract void setNCols(int nCols); 62 | 63 | public abstract void setRow(final MLSparseVector row, final int rowIndex); 64 | 65 | public abstract void toBinFile(final String outFile) throws Exception; 66 | 67 | public abstract MLSparseMatrix transpose(); 68 | 69 | public static MLSparseMatrix concatHorizontal( 70 | final MLSparseMatrix... 
matrices) { 71 | int nRows = matrices[0].getNRows(); 72 | int nColsNew = 0; 73 | for (MLSparseMatrix matrix : matrices) { 74 | if (nRows != matrix.getNRows()) { 75 | throw new IllegalArgumentException( 76 | "input must have same number of rows"); 77 | } 78 | 79 | nColsNew += matrix.getNCols(); 80 | } 81 | 82 | MLSparseVector[] concat = new MLSparseVector[nRows]; 83 | IntStream.range(0, nRows).parallel().forEach(rowIndex -> { 84 | 85 | MLSparseVector[] rows = new MLSparseVector[matrices.length]; 86 | boolean allNull = true; 87 | for (int i = 0; i < matrices.length; i++) { 88 | MLSparseVector row = matrices[i].getRow(rowIndex); 89 | if (row != null) { 90 | allNull = false; 91 | } else { 92 | // nulls are not allowed in vector concat 93 | row = new MLSparseVector(null, null, null, 94 | matrices[i].getNCols()); 95 | } 96 | rows[i] = row; 97 | } 98 | if (allNull == true) { 99 | concat[rowIndex] = null; 100 | } else { 101 | concat[rowIndex] = MLSparseVector.concat(rows); 102 | } 103 | }); 104 | 105 | return new MLSparseMatrixAOO(concat, nColsNew); 106 | } 107 | 108 | public static MLSparseMatrix concatVertical( 109 | final MLSparseMatrix... 
matrices) { 110 | 111 | int nCols = matrices[0].getNCols(); 112 | int nRowsNew = 0; 113 | int[] offsets = new int[matrices.length]; 114 | boolean[] hasDates = new boolean[] { true }; 115 | for (int i = 0; i < offsets.length; i++) { 116 | if (nCols != matrices[i].getNCols()) { 117 | throw new IllegalArgumentException( 118 | "input must have same number of columns"); 119 | } 120 | nRowsNew += matrices[i].getNRows(); 121 | offsets[i] = nRowsNew; 122 | 123 | if (matrices[i].hasDates() == false) { 124 | hasDates[0] = false; 125 | } 126 | } 127 | 128 | MLSparseVector[] concatRows = new MLSparseVector[nRowsNew]; 129 | IntStream.range(0, nRowsNew).parallel().forEach(rowIndex -> { 130 | 131 | int offsetMatIndex = 0; 132 | int offsetRowIndex = 0; 133 | for (int i = 0; i < offsets.length; i++) { 134 | if (rowIndex < offsets[i]) { 135 | offsetMatIndex = i; 136 | if (i == 0) { 137 | offsetRowIndex = rowIndex; 138 | } else { 139 | offsetRowIndex = rowIndex - offsets[i - 1]; 140 | } 141 | break; 142 | } 143 | } 144 | 145 | MLSparseVector row = matrices[offsetMatIndex] 146 | .getRow(offsetRowIndex); 147 | if (row != null) { 148 | concatRows[rowIndex] = row.deepCopy(); 149 | if (hasDates[0] == false) { 150 | // NOTE: if at least one matrix doesn't have dates 151 | // then all dates must be removed 152 | concatRows[rowIndex].setDates(null); 153 | } 154 | } 155 | }); 156 | 157 | return new MLSparseMatrixAOO(concatRows, nCols); 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /src/main/java/common/EvaluatorCF.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.util.Arrays; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | import java.util.concurrent.atomic.AtomicInteger; 7 | import java.util.stream.IntStream; 8 | 9 | public abstract class EvaluatorCF { 10 | 11 | protected int[] threshs; 12 | 13 | public EvaluatorCF(final int[] threshsP) { 14 | 
this.threshs = threshsP; 15 | } 16 | 17 | public abstract ResultCF evaluate(final SplitterCF split, 18 | final String interactionType, final FloatElement[][] preds); 19 | 20 | public int[] getEvalThreshs() { 21 | return this.threshs; 22 | } 23 | 24 | public int getMaxEvalThresh() { 25 | return this.threshs[this.threshs.length - 1]; 26 | } 27 | 28 | public static FloatElement[][] getRankings(SplitterCF split, 29 | MLDenseMatrix U, MLDenseMatrix V, final int maxThresh, 30 | final String interactionType) { 31 | if (Math.floorDiv(LowLevelRoutines.MAX_ARRAY_SIZE, V.getNRows()) >= V 32 | .getNCols()) { 33 | return getRankingsNative(split.getRstrain().get(interactionType), 34 | split.getValidRowIndexes(), split.getValidColIndexes(), U, 35 | V, maxThresh, 100); 36 | } else { 37 | System.out.printf( 38 | "[WARNING] using non-native ranking can be very slow"); 39 | return getRankingsNonNative(split, U, V, maxThresh, 40 | interactionType); 41 | } 42 | } 43 | 44 | public static FloatElement[][] getRankingsNative( 45 | final MLSparseMatrix Rtrain, final int[] rowIndexes, 46 | int[] colIndexes, final MLDenseMatrix U, final MLDenseMatrix V, 47 | final int rankingSize, final int rowBatchSize) { 48 | 49 | // convenience function 50 | float[] Vflat = V.slice(colIndexes).toFlatArray(); 51 | return getRankingsNative(Rtrain, rowIndexes, colIndexes, U, Vflat, 52 | rankingSize, rowBatchSize); 53 | } 54 | 55 | public static FloatElement[][] getRankingsNative( 56 | final MLSparseMatrix Rtrain, final int[] rowIndexes, 57 | final int[] colIndexes, final MLDenseMatrix U, final float[] V, 58 | final int rankingSize, final int rowBatchSize) { 59 | 60 | FloatElement[][] rankings = new FloatElement[U.getNRows()][]; 61 | final int nRowsV = colIndexes.length; 62 | final int nCols = U.getNCols(); 63 | 64 | final Map colMap = new HashMap(); 65 | for (int i = 0; i < colIndexes.length; i++) { 66 | colMap.put(colIndexes[i], i); 67 | } 68 | 69 | final int uBatchSize = Math.min(rowBatchSize, 70 | 
Math.floorDiv(LowLevelRoutines.MAX_ARRAY_SIZE, nRowsV)); 71 | int nBatch = -Math.floorDiv(-rowIndexes.length, uBatchSize); 72 | 73 | for (int batch = 0; batch < nBatch; batch++) { 74 | final int start = batch * uBatchSize; 75 | final int end = Math.min(start + uBatchSize, rowIndexes.length); 76 | 77 | final float[] result = new float[(end - start) * nRowsV]; 78 | MLDenseMatrix uBatchRows = U.slice(rowIndexes, start, end); 79 | LowLevelRoutines.sgemm(uBatchRows.toFlatArray(), V, result, 80 | (end - start), nRowsV, nCols, true, false, 1, 0); 81 | 82 | IntStream.range(0, end - start).parallel().forEach(i -> { 83 | int rowIndex = rowIndexes[start + i]; 84 | MLSparseVector trainRow = null; 85 | if (Rtrain != null) { 86 | trainRow = Rtrain.getRow(rowIndex); 87 | } 88 | // map training index to relative index to match sliced V 89 | FloatElement[] preds; 90 | int[] excludes = null; 91 | if (trainRow != null) { 92 | excludes = Arrays.stream(trainRow.getIndexes()) 93 | .filter(colMap::containsKey).map(colMap::get) 94 | .toArray(); 95 | if (excludes.length == 0) { 96 | excludes = null; 97 | } 98 | } 99 | preds = FloatElement.topNSortOffset(result, rankingSize, 100 | excludes, i * nRowsV, nRowsV); 101 | 102 | if (preds != null) { 103 | // map back to full index 104 | for (int j = 0; j < preds.length; j++) { 105 | preds[j].setIndex(colIndexes[preds[j].getIndex()]); 106 | } 107 | } 108 | rankings[rowIndex] = preds; 109 | }); 110 | } 111 | return rankings; 112 | } 113 | 114 | private static FloatElement[][] getRankingsNonNative(final SplitterCF split, 115 | final MLDenseMatrix U, final MLDenseMatrix V, final int maxThresh, 116 | final String interactionType) { 117 | 118 | MLSparseMatrix R_train = split.getRstrain().get(interactionType); 119 | FloatElement[][] rankings = new FloatElement[R_train.getNRows()][]; 120 | int[] validRowIndexes = split.getValidRowIndexes(); 121 | int[] validColIndexes = split.getValidColIndexes(); 122 | AtomicInteger count = new AtomicInteger(0); 123 | 
MLTimer evalTimer = new MLTimer("ALS Eval", validRowIndexes.length); 124 | 125 | IntStream.range(0, validRowIndexes.length).parallel().forEach(index -> { 126 | final int countLocal = count.incrementAndGet(); 127 | if (countLocal % 1000 == 0) { 128 | evalTimer.tocLoop(countLocal); 129 | } 130 | int rowIndex = validRowIndexes[index]; 131 | 132 | MLDenseVector uRow = U.getRow(rowIndex); 133 | FloatElement[] rowScores = new FloatElement[validColIndexes.length]; 134 | int cur = 0; 135 | for (int colIndex : validColIndexes) { 136 | rowScores[cur] = new FloatElement(colIndex, 137 | uRow.mult(V.getRow(colIndex))); 138 | cur++; 139 | } 140 | 141 | MLSparseVector trainRow = R_train.getRow(rowIndex); 142 | if (trainRow != null) { 143 | rankings[rowIndex] = FloatElement.topNSortArr(rowScores, 144 | maxThresh, R_train.getRow(rowIndex).getIndexes()); 145 | } else { 146 | rankings[rowIndex] = FloatElement.topNSort(rowScores, maxThresh, 147 | null); 148 | } 149 | }); 150 | return rankings; 151 | } 152 | 153 | } 154 | -------------------------------------------------------------------------------- /src/main/java/common/MLXGBoost.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import ml.dmlc.xgboost4j.LabeledPoint; 4 | import ml.dmlc.xgboost4j.java.Booster; 5 | import ml.dmlc.xgboost4j.java.DMatrix; 6 | import ml.dmlc.xgboost4j.java.XGBoost; 7 | import ml.dmlc.xgboost4j.java.XGBoostError; 8 | 9 | import java.io.BufferedReader; 10 | import java.io.FileReader; 11 | import java.util.Arrays; 12 | import java.util.Comparator; 13 | import java.util.LinkedList; 14 | import java.util.List; 15 | import java.util.Map; 16 | 17 | import common.MLConcurrentUtils.Async; 18 | 19 | public class MLXGBoost { 20 | 21 | public static class MLXGBoostFeature { 22 | 23 | public static class ScoreComparator 24 | implements Comparator { 25 | 26 | private boolean decreasing; 27 | 28 | public ScoreComparator(final boolean decreasingP) { 29 | 
this.decreasing = decreasingP;
            }

            // orders features by importance score, direction set at
            // construction time
            @Override
            public int compare(final MLXGBoostFeature e1,
                    final MLXGBoostFeature e2) {
                if (this.decreasing == true) {
                    return Double.compare(e2.score, e1.score);
                } else {
                    return Double.compare(e1.score, e2.score);
                }
            }
        }

        private String name;
        private double score;

        public MLXGBoostFeature(final String nameP, final double scoreP) {
            this.name = nameP;
            this.score = scoreP;
        }

        public String getName() {
            return this.name;
        }

        public double getScore() {
            return this.score;
        }
    }

    /**
     * Load an xgboost model and pair every feature name from featureFile
     * (one name per line, ordered by feature index) with its importance
     * score, sorted by decreasing importance.
     */
    public static MLXGBoostFeature[] analyzeFeatures(final String modelFile,
            final String featureFile) throws Exception {

        Booster model = XGBoost.loadModel(modelFile);

        List temp = new LinkedList();
        try (BufferedReader reader = new BufferedReader(
                new FileReader(featureFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                temp.add(line);
            }
        }

        // get feature importance scores
        String[] featureNames = new String[temp.size()];
        temp.toArray(featureNames);
        int[] importances = MLXGBoost.getFeatureImportance(model, featureNames);

        // sort features by their importance
        MLXGBoostFeature[] sortedFeatures = new MLXGBoostFeature[featureNames.length];
        for (int i = 0; i < featureNames.length; i++) {
            sortedFeatures[i] = new MLXGBoostFeature(featureNames[i],
                    importances[i]);
        }
        Arrays.sort(sortedFeatures, new MLXGBoostFeature.ScoreComparator(true));

        return sortedFeatures;
    }

    // convenience overload: default thread count (0 = leave unchanged)
    public static Async asyncModel(final String modelFile) {
        return asyncModel(modelFile, 0);
    }

    /**
     * Wrap a lazily-loaded Booster in an Async holder so each thread gets
     * its own instance; Booster::dispose is registered as the cleanup hook.
     */
    public static Async asyncModel(final String modelFile,
            final int nthread) {
        // load xgboost model
        final Async modelAsync = new Async(() -> {
            try {
                Booster bst = XGBoost.loadModel(modelFile);
                if (nthread > 0) {
                    bst.setParam("nthread", nthread);
                }
                return bst;
            } catch (XGBoostError e) {
                // NOTE(review): load failure is swallowed and surfaces as a
                // null Booster to callers — confirm callers handle null
                e.printStackTrace();
                return null;
            }
        }, Booster::dispose);
        return modelAsync;
    }

    /**
     * Map xgboost's "fN" importance keys back onto the caller's feature
     * index space. Features the model never used keep importance 0.
     */
    public static int[] getFeatureImportance(final Booster model,
            final String[] featNames) throws XGBoostError {

        int[] importances = new int[featNames.length];
        // NOTE: not used feature are dropped here
        Map importanceMap = model.getFeatureScore(null);

        for (Map.Entry entry : importanceMap.entrySet()) {
            // get index from f0, f1 feature name output from xgboost
            int index = Integer.parseInt(entry.getKey().substring(1));
            importances[index] = entry.getValue();
        }

        return importances;
    }

    /**
     * Convert a sparse matrix to an xgboost DMatrix in CSR layout.
     * Null/empty rows become empty CSR rows (rowIndex repeats).
     */
    public static DMatrix toDMatrix(final MLSparseMatrix matrix)
            throws XGBoostError {

        final int nnz = (int) matrix.getNNZ();
        final int nRows = matrix.getNRows();
        final int nCols = matrix.getNCols();

        long[] rowIndex = new long[nRows + 1];
        int[] indexesFlat = new int[nnz];
        float[] valuesFlat = new float[nnz];

        int cur = 0;
        for (int i = 0; i < nRows; i++) {
            MLSparseVector row = matrix.getRow(i);
            if (row == null) {
                rowIndex[i] = cur;
                continue;
            }
            int[] indexes = row.getIndexes();
            int rowNNZ = indexes.length;
            if (rowNNZ == 0) {
                rowIndex[i] = cur;
                continue;
            }
            float[] values = row.getValues();
            rowIndex[i] = cur;

            for (int j = 0; j < rowNNZ; j++, cur++) {
                indexesFlat[cur] = indexes[j];
                valuesFlat[cur] = values[j];
            }
        }
        rowIndex[nRows] = cur;
        return new DMatrix(rowIndex, indexesFlat, valuesFlat,
                DMatrix.SparseType.CSR, nCols);
    }

    /**
     * Serialize a LabeledPoint in LIBSVM text format
     * ("label idx:val idx:val ..."), printing integral values without a
     * decimal point and others with 5 decimals.
     */
    public static String toLIBSVMString(final LabeledPoint vec) {
        float target = vec.label();
        StringBuilder builder = new StringBuilder();
        if (target == (int) target) {
            builder.append((int) target);
        } else {
            builder.append(String.format("%.5f", target));
        }
        for (int i = 0; i < vec.indices().length; i++) {
            float val = vec.values()[i];
            if (val == Math.round(val)) {
                // NOTE(review): (int) val can overflow for large integral
                // floats — presumably values are small feature scores
                builder.append(" " + (vec.indices()[i]) + ":" + ((int) val));
            } else {
                builder.append(" " + (vec.indices()[i]) + ":"
                        + String.format("%.5f", val));
            }
        }
        return builder.toString();
    }

}
--------------------------------------------------------------------------------
/src/main/java/common/MLTextTransform.java:
--------------------------------------------------------------------------------
package common;

import java.io.IOException;
import java.io.Serializable;
import java.io.StringReader;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.LowerCaseFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.miscellaneous.ASCIIFoldingFilter;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.miscellaneous.WordDelimiterGraphFilterFactory;
import org.apache.lucene.analysis.ngram.NGramTokenFilter;

public abstract class MLTextTransform implements Serializable {

    /**
     * Whitespace tokenization -> word-delimiter splitting -> lowercase ->
     * ASCII folding.
     */
    public static class DefaultAnalyzer extends Analyzer
            implements Serializable {
        private static final long serialVersionUID = 8662472589510750438L;

        @Override
        protected TokenStreamComponents createComponents(
                final String fieldName) {
            try {
                Tokenizer tokenizer = new
WhitespaceTokenizer();

                // word-delimiter settings: split on case change and
                // numerics, keep word/number parts, no catenation
                Map args = new HashMap();
                args.put("catenateAll", "0");
                args.put("generateNumberParts", "1");
                args.put("generateWordParts", "1");
                args.put("splitOnCaseChange", "1");
                args.put("splitOnNumerics", "1");

                WordDelimiterGraphFilterFactory factory = new WordDelimiterGraphFilterFactory(
                        args);
                TokenStream filter = factory.create(tokenizer);

                filter = new LowerCaseFilter(filter);

                filter = new ASCIIFoldingFilter(filter);

                return new TokenStreamComponents(tokenizer, filter);

            } catch (Exception e) {
                // NOTE(review): returning null on failure will NPE later in
                // the Lucene pipeline — confirm this is acceptable
                e.printStackTrace();
            }
            return null;
        }
    }

    /**
     * Same pipeline as DefaultAnalyzer plus a trailing character n-gram
     * filter with the configured [minGram, maxGram] range.
     */
    public static class DefaultNGRAMAnalyzer extends Analyzer
            implements Serializable {

        private static final long serialVersionUID = 2224016723762685329L;
        // https://lucene.apache.org/core/7_0_0/analyzers-common/org/apache/lucene/analysis/ngram/NGramTokenFilter.html
        private int minGram;
        private int maxGram;

        public DefaultNGRAMAnalyzer(final int minGraP, final int maxGramP) {
            this.minGram = minGraP;
            this.maxGram = maxGramP;
        }

        @Override
        protected TokenStreamComponents createComponents(
                final String fieldName) {
            try {
                Tokenizer tokenizer = new WhitespaceTokenizer();

                Map args = new HashMap();
                args.put("catenateAll", "0");
                args.put("generateNumberParts", "1");
                args.put("generateWordParts", "1");
                args.put("splitOnCaseChange", "1");
                args.put("splitOnNumerics", "1");

                WordDelimiterGraphFilterFactory factory = new WordDelimiterGraphFilterFactory(
                        args);
                TokenStream filter = factory.create(tokenizer);

                filter = new LowerCaseFilter(filter);

                filter = new ASCIIFoldingFilter(filter);

                filter = new NGramTokenFilter(filter, this.minGram,
                        this.maxGram);

                return new TokenStreamComponents(tokenizer, filter);

            } catch (Exception e) {
                e.printStackTrace();
            }
            return null;
        }
    }

    /**
     * MLTextTransform backed by a Lucene Analyzer: runs the analyzer over
     * the input text and stores the resulting token array on the input.
     */
    public static class LuceneAnalyzerTextTransform extends MLTextTransform {

        private static final long serialVersionUID = 1843607513745972795L;
        private Analyzer analyzer;

        public LuceneAnalyzerTextTransform(final Analyzer analyzerP) {
            this.analyzer = analyzerP;
        }

        @Override
        public void apply(final MLTextInput input) {
            try {
                List tokens = passThroughAnalyzer(input.text,
                        this.analyzer);
                String[] tokenized = new String[tokens.size()];
                int cur = 0;
                for (String token : tokens) {
                    tokenized[cur] = token;
                    cur++;
                }
                input.setTokenized(tokenized);

            } catch (Exception e) {
                // NOTE(review): original cause is dropped here; only the
                // message is propagated
                throw new RuntimeException(e.getMessage());
            }
        }

        /**
         * Drive the analyzer's token stream to completion and collect the
         * non-empty trimmed terms. Follows the standard Lucene consume
         * pattern: reset -> incrementToken loop -> end -> close.
         */
        public static List passThroughAnalyzer(final String input,
                final Analyzer analyzer) throws IOException {
            TokenStream tokenStream = null;
            try {
                tokenStream = analyzer.tokenStream(null,
                        new StringReader(input));
                CharTermAttribute termAtt = tokenStream
                        .addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                List tokens = new LinkedList();
                while (tokenStream.incrementToken()) {
                    String term = termAtt.toString().trim();
                    if (term.length() > 0) {
                        tokens.add(term);
                    }
                }
                tokenStream.end();

                return tokens;
            } finally {
                if (tokenStream != null) {
                    tokenStream.close();
                }
            }
        }
    }

    /**
     * Mutable holder for a raw text string and its tokenized form.
     */
    public static class MLTextInput {

        private String text;
        private String[] tokenized;

        public MLTextInput(final String textP) {
            this.text = textP;
        }

        public String getText() {
            return this.text;
        }

        public String[] getTokenized() {
            return this.tokenized;
        }

        public void setTokenized(final String[] tokenizedP) {
this.tokenized = tokenizedP; 176 | } 177 | } 178 | 179 | private static final long serialVersionUID = 3800020927323228525L; 180 | 181 | public abstract void apply(final MLTextInput input); 182 | } 183 | -------------------------------------------------------------------------------- /src/main/java/common/MLMatrixElement.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.io.Serializable; 4 | import java.util.Arrays; 5 | import java.util.Comparator; 6 | import java.util.HashSet; 7 | import java.util.PriorityQueue; 8 | import java.util.Set; 9 | 10 | public class MLMatrixElement implements Serializable { 11 | 12 | public static class ColIndexComparator 13 | implements Comparator { 14 | 15 | private boolean decreasing; 16 | 17 | public ColIndexComparator(final boolean decreasingP) { 18 | this.decreasing = decreasingP; 19 | } 20 | 21 | @Override 22 | public int compare(final MLMatrixElement e1, final MLMatrixElement e2) { 23 | if (this.decreasing == true) { 24 | return Integer.compare(e2.colIndex, e1.colIndex); 25 | } else { 26 | return Integer.compare(e1.colIndex, e2.colIndex); 27 | } 28 | } 29 | } 30 | 31 | public static class DateComparator implements Comparator { 32 | 33 | private boolean decreasing; 34 | 35 | public DateComparator(final boolean decreasingP) { 36 | this.decreasing = decreasingP; 37 | } 38 | 39 | @Override 40 | public int compare(final MLMatrixElement e1, final MLMatrixElement e2) { 41 | if (this.decreasing == true) { 42 | return Float.compare(e2.date, e1.date); 43 | } else { 44 | return Float.compare(e1.date, e2.date); 45 | } 46 | } 47 | } 48 | 49 | public static class RowIndexComparator 50 | implements Comparator { 51 | 52 | private boolean decreasing; 53 | 54 | public RowIndexComparator(final boolean decreasingP) { 55 | this.decreasing = decreasingP; 56 | } 57 | 58 | @Override 59 | public int compare(final MLMatrixElement e1, final MLMatrixElement e2) { 60 | if 
(this.decreasing == true) { 61 | return Integer.compare(e2.rowIndex, e1.rowIndex); 62 | } else { 63 | return Integer.compare(e1.rowIndex, e2.rowIndex); 64 | } 65 | } 66 | } 67 | 68 | public static class ValueComparator implements Comparator { 69 | 70 | private boolean decreasing; 71 | 72 | public ValueComparator(final boolean decreasingP) { 73 | this.decreasing = decreasingP; 74 | } 75 | 76 | @Override 77 | public int compare(final MLMatrixElement e1, final MLMatrixElement e2) { 78 | if (this.decreasing == true) { 79 | return Float.compare(e2.value, e1.value); 80 | } else { 81 | return Float.compare(e1.value, e2.value); 82 | } 83 | } 84 | } 85 | 86 | private static final long serialVersionUID = 1078736772506670L; 87 | private int rowIndex; 88 | private int colIndex; 89 | private float value; 90 | private long date; 91 | 92 | public MLMatrixElement(int rowIndexP, int colIndexP, float valueP, 93 | long dateP) { 94 | this.rowIndex = rowIndexP; 95 | this.colIndex = colIndexP; 96 | this.value = valueP; 97 | this.date = dateP; 98 | } 99 | 100 | public int getColIndex() { 101 | return this.colIndex; 102 | } 103 | 104 | public long getDate() { 105 | return this.date; 106 | } 107 | 108 | public int getRowIndex() { 109 | return this.rowIndex; 110 | } 111 | 112 | public float getValue() { 113 | return this.value; 114 | } 115 | 116 | public void setColIndex(final int colIndexP) { 117 | this.colIndex = colIndexP; 118 | } 119 | 120 | public void setDate(final long dateP) { 121 | this.date = dateP; 122 | } 123 | 124 | public void setRowIndex(final int rowIndexP) { 125 | this.rowIndex = rowIndexP; 126 | } 127 | 128 | public void setValue(final float valueP) { 129 | this.value = valueP; 130 | } 131 | 132 | public static MLMatrixElement[] topNSort(final int rowIndex, 133 | final float[] vec, final int topN, final Set exclusions) { 134 | 135 | final Comparator valAscending = new MLMatrixElement.ValueComparator( 136 | false); 137 | final Comparator valDescending = new 
MLMatrixElement.ValueComparator( 138 | true); 139 | 140 | PriorityQueue heap = new PriorityQueue( 141 | topN, valAscending); 142 | 143 | for (int i = 0; i < vec.length; i++) { 144 | if (exclusions != null && exclusions.contains(i) == true) { 145 | continue; 146 | } 147 | float val = vec[i]; 148 | if (heap.size() < topN) { 149 | heap.add(new MLMatrixElement(rowIndex, i, val, 0)); 150 | 151 | } else { 152 | if (heap.peek().value < val) { 153 | heap.poll(); 154 | heap.add(new MLMatrixElement(rowIndex, i, val, 0)); 155 | } 156 | } 157 | } 158 | 159 | MLMatrixElement[] heapArray = new MLMatrixElement[heap.size()]; 160 | heap.toArray(heapArray); 161 | Arrays.sort(heapArray, valDescending); 162 | 163 | return heapArray; 164 | } 165 | 166 | public static MLMatrixElement[] topNSort(final MLMatrixElement[] elements, 167 | final int topN, final Set exclusions) { 168 | 169 | final Comparator valAscending = new MLMatrixElement.ValueComparator( 170 | false); 171 | final Comparator valDescending = new MLMatrixElement.ValueComparator( 172 | true); 173 | 174 | PriorityQueue heap = new PriorityQueue( 175 | topN, valAscending); 176 | 177 | for (int i = 0; i < elements.length; i++) { 178 | if (exclusions != null && exclusions.contains(i) == true) { 179 | continue; 180 | } 181 | MLMatrixElement element = elements[i]; 182 | if (heap.size() < topN) { 183 | heap.add(element); 184 | 185 | } else { 186 | if (heap.peek().value < element.value) { 187 | heap.poll(); 188 | heap.add(element); 189 | } 190 | } 191 | } 192 | 193 | MLMatrixElement[] heapArray = new MLMatrixElement[heap.size()]; 194 | heap.toArray(heapArray); 195 | Arrays.sort(heapArray, valDescending); 196 | 197 | return heapArray; 198 | } 199 | 200 | public static MLMatrixElement[] topNSortArr(final int rowIndex, 201 | final float[] vec, final int topN, final int[] exclusions) { 202 | Set exclusionSet = new HashSet(exclusions.length); 203 | for (int exclusion : exclusions) { 204 | exclusionSet.add(exclusion); 205 | } 206 | return 
topNSort(rowIndex, vec, topN, exclusionSet);
    }

}
--------------------------------------------------------------------------------
/src/main/java/main/Executor.java:
--------------------------------------------------------------------------------
package main;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;

import common.ALS;
import common.ALS.ALSParams;
import common.MLTimer;
import common.SplitterCF;
import main.XGBModel.XGBModelParams;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;

/**
 * Pipeline entry point: downloads optional Spotify audio features, builds
 * the train/validation split, trains WMF/SVD latents and the second-stage
 * XGB model, and writes the submission file.
 */
public class Executor {

    /**
     * Download Spotify audio features for all songs in batches of 100 via
     * the /v1/audio-features endpoint, writing the concatenated JSON to
     * outFile. Exits the JVM if the auth token expires mid-run.
     */
    private static void downloadCreativeData(Data dataLoaded, String outFile,
            String authToken) throws IOException {
        // Please provide your own key here
        final String AUTH_TOKEN = "Bearer " + authToken;

        try (BufferedWriter bw = new BufferedWriter(new FileWriter(outFile))) {

            int nSongs = dataLoaded.songs.length;
            // NOTE(review): floorDiv drops the final partial batch when
            // nSongs is not a multiple of 100 — confirm this is intended
            int batchSize = Math.floorDiv(nSongs, 100);
            OkHttpClient client = new OkHttpClient();

            for (int batch = 0; batch < batchSize; batch++) {

                // uncomment and provide batch number from where to begin in
                // case the operation was terminated due to auth expiration
                /*
                 * if(batch <33207) continue;
                 */

                System.out.println("Doing batch " + batch);

                int batchStart = batch * 100;
                int batchEnd = Math.min(batchStart + 100, nSongs);
                // Build a comma-separated (%2C) list of up to 100 track ids
                // extracted from "spotify:track:<id>" URIs
                String url = "https://api.spotify.com/v1/audio-features?ids=";
                int firstTime = 1;
                for (int i = batchStart; i < batchEnd; i++) {
                    if (firstTime == 1) {
                        url = url + dataLoaded.songs[i].get_track_uri()
                                .split(":")[2];
                        firstTime = 0;
                    } else {
                        url = url + "%2C" + dataLoaded.songs[i].get_track_uri()
                                .split(":")[2];
                    }

                }

                Request request = new Request.Builder().url(url)
                        .addHeader("Authorization", AUTH_TOKEN).build();
                Response responses = null;
                // wrap the response in [ ... ] so it parses as a JSON array
                String append = "[";
                String last = "]";

                try {
                    responses = client.newCall(request).execute();
                } catch (IOException e) {
                    // NOTE(review): if this call fails, responses stays null
                    // and the next line NPEs — confirm acceptable
                    e.printStackTrace();
                }
                String jsonData = responses.body().string();
                jsonData = append + jsonData + last;
                org.json.JSONArray jsonarray = new org.json.JSONArray(jsonData);

                // rate-limited: back off once and retry the same request
                if (jsonarray.getJSONObject(0).has("error")) {
                    System.out.println("timed out pausing for a while.");
                    try {
                        Thread.sleep(4000 + 1000);

                        try {
                            responses = client.newCall(request).execute();
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                        jsonData = responses.body().string();
                        jsonData = append + jsonData + last;
                        jsonarray = new org.json.JSONArray(jsonData);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }

                }
                // still failing after retry: token has expired, so stop and
                // report the batch to resume from
                // NOTE(review): bw.close() is redundant inside
                // try-with-resources
                if (jsonarray.getJSONObject(0).has("error")) {
                    System.out.println(jsonarray.getJSONObject(0));
                    bw.close();
                    System.out.println(
                            "Please refresh your key as you timed out on batch: "
                                    + batch);
                    System.exit(1);
                }

                org.json.JSONArray jsonobject = (org.json.JSONArray) jsonarray
                        .getJSONObject(0).get("audio_features");
                String writeString = jsonobject.toString();
                // strip brackets so all batches concatenate into one big
                // JSON array in the output file
                if (batch == 0) {
                    writeString = writeString.replace("]", ",");
                } else if (batch == batchSize - 1) {
                    writeString = writeString.replace("[", "");
                } else {
                    writeString = writeString.replace("[", "");
                    writeString = writeString.replace("]", ",");
                }
                bw.write(writeString);

            }
            bw.close();

        } catch (IOException e) {

            e.printStackTrace();

        }
        System.out.println("Extraction complete.");
    }

    public static void main(final String[] args) {
        try {
            String authToken = "";
            String creativeTrackFile = "/media/mvolkovs/external4TB/Data/recsys2018/data/song_audio_features.txt";

            String trainPath = "/media/mvolkovs/external4TB/Data/recsys2018/data/train";
            String testFile = "/media/mvolkovs/external4TB/Data/recsys2018/data/test/challenge_set.json";
            String pythonScriptPath = "/home/mvolkovs/projects/vl6_recsys2018/script/svd_py.py";
            String cachePath = "/media/mvolkovs/external4TB/Data/recsys2018/models/svd/";

            MLTimer timer = new MLTimer("main");
            timer.tic();

            XGBModelParams xgbParams = new XGBModelParams();
            xgbParams.doCreative = false;
            xgbParams.xgbModel = cachePath + "xgb.model";

            // load data
            Data data = DataLoader.load(trainPath, testFile);
            timer.toc("data loaded");

            // download creative track features if not there
            if (xgbParams.doCreative == true
                    && new File(creativeTrackFile).exists() == false) {
                downloadCreativeData(data, creativeTrackFile, authToken);
            }

            ParsedDataLoader loader = new ParsedDataLoader(data);
            loader.loadPlaylists();
            loader.loadSongs();
            if (xgbParams.doCreative == true) {
                loader.loadSongExtraInfo(creativeTrackFile);
            }
            ParsedData dataParsed = loader.dataParsed;
            timer.toc("data parsed");

            // generate split
            SplitterCF split = RecSysSplitter.getSplitMatching(dataParsed);
            RecSysSplitter.removeName(dataParsed, split);
            timer.toc("split done");

            // get all latents
            Latents latents = new Latents();

            // WMF
            ALSParams alsParams = new ALSParams();
            alsParams.alpha = 100;
            alsParams.rank = 200;
            alsParams.lambda = 0.001f;
            alsParams.maxIter = 10;
            ALS als = new ALS(alsParams);
            als.optimize(split.getRstrain().get(ParsedData.INTERACTION_KEY),
                    null);
            latents.U = als.getU();
            latents.Ucnn = als.getU();
            latents.V = als.getV();
            latents.Vcnn = als.getV();

            // SVD on album, artist and name
            SVDModel svdModel = new SVDModel(dataParsed, split, latents);
            svdModel.factorizeAlbums(pythonScriptPath, cachePath);
            svdModel.factorizeArtists(pythonScriptPath, cachePath);
            svdModel.factorizeNames(pythonScriptPath, cachePath);
            timer.toc("latents computed");

            // train second stage model
            // Latents latents = new Latents(dataParsed);
            XGBModel model = new XGBModel(dataParsed, xgbParams, latents,
                    split);
            model.extractFeatures2Stage(cachePath);
            model.trainModel(cachePath);
            model.submission2Stage(cachePath + "submission.out");

        } catch (Exception e) {
            e.printStackTrace();
        }

    }
}
--------------------------------------------------------------------------------
/src/main/java/common/MLFeatureTransform.java:
--------------------------------------------------------------------------------
package common;

import java.io.Serializable;
import java.util.Map;
import
java.util.stream.IntStream;

/**
 * Base class for feature-matrix transforms with a fit-then-apply
 * lifecycle: apply() fits on training data and transforms it;
 * applyInference() re-applies the fitted transform to new data.
 */
public abstract class MLFeatureTransform implements Serializable {

    /**
     * Normalizes each column by a norm computed once on the training
     * matrix (norm type is project-defined via getColNorm).
     */
    public static class ColNormTransform extends MLFeatureTransform {

        private static final long serialVersionUID = -1777920290446866227L;
        private int normType;
        private MLDenseVector colNorm;

        public ColNormTransform(final int normTypeP) {
            this.normType = normTypeP;
        }

        @Override
        public void apply(final MLSparseFeature feature) {
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            // fit: remember the training column norms for inference
            this.colNorm = matrix.getColNorm(this.normType);
            this.applyInference(feature);
        }

        @Override
        public String[] applyFeatureName(final String[] featureNames) {
            // do nothing
            return featureNames;
        }

        @Override
        public void applyInference(final MLSparseFeature feature) {
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            matrix.applyColNorm(this.colNorm);
        }

        @Override
        public void applyInference(final MLSparseVector vector) {
            if (vector == null || vector.getIndexes() == null) {
                return;
            }
            vector.applyNorm(this.colNorm);
        }
    }

    /**
     * Drops columns whose nnz count is below a cutoff and remaps the
     * surviving columns to a dense index range.
     */
    public static class ColSelectorTransform extends MLFeatureTransform {

        private static final long serialVersionUID = -4544414355983339509L;
        private Map selectedColMap;
        private int nColsSelected;
        private int nnzCutOff;

        public ColSelectorTransform(final int nnzCutOffP) {
            this.nnzCutOff = nnzCutOffP;
        }

        @Override
        public void apply(final MLSparseFeature feature) {

            // select columns that pass nnz cutoff
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            this.selectedColMap = matrix.selectCols(this.nnzCutOff);

            // new ncols = max remapped index + 1
            this.nColsSelected = 0;
            for (Integer index : selectedColMap.values()) {
                if (this.nColsSelected < (index + 1)) {
                    this.nColsSelected = index + 1;
                }
            }

            // apply column selector to feature matrix
            this.applyInference(feature);
        }

        @Override
        public String[] applyFeatureName(final String[] featureNames) {
            // keep only names of surviving columns, in remapped order
            String[] selectedFeatNames = new String[this.selectedColMap.size()];
            for (int i = 0; i < featureNames.length; i++) {
                Integer newIndex = this.selectedColMap.get(i);
                if (newIndex != null) {
                    selectedFeatNames[newIndex] = featureNames[i];
                }
            }
            return selectedFeatNames;
        }

        @Override
        public void applyInference(final MLSparseFeature feature) {
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            matrix.applyColSelector(this.selectedColMap, this.nColsSelected);
        }

        @Override
        public void applyInference(final MLSparseVector vector) {
            vector.applyIndexSelector(this.selectedColMap, this.nColsSelected);
        }

    }

    /**
     * Normalizes each row independently; stateless, so fit == apply.
     */
    public static class RowNormTransform extends MLFeatureTransform {

        private static final long serialVersionUID = -1777920290446866227L;
        private int normType;

        public RowNormTransform(final int normTypeP) {
            this.normType = normTypeP;
        }

        @Override
        public void apply(final MLSparseFeature feature) {
            this.applyInference(feature);
        }

        @Override
        public String[] applyFeatureName(final String[] featureNames) {
            // do nothing
            return featureNames;
        }

        @Override
        public void applyInference(final MLSparseFeature feature) {
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            MLDenseVector norm = matrix.getRowNorm(this.normType);
            matrix.applyRowNorm(norm);
        }

        @Override
        public void applyInference(final MLSparseVector vector) {
            if (vector == null || vector.getIndexes() == null) {
                return;
            }
            vector.applyNorm(this.normType);
        }
    }

    /**
     * Per-column standardization ((x - mean) / std) with clipping at
     * +/- cutOff; mean/std are computed over nnz entries only.
     */
    public static class StandardizeTransform extends MLFeatureTransform {

        private static final long serialVersionUID = -2289862537575019481L;
        private float[] mean;
        private float[] std;
        private float cutOff;

        public StandardizeTransform(final float cutOffP) {
            this.cutOff = cutOffP;
        }

        @Override
        public void apply(final MLSparseFeature feature) {
            // compute mean over the nnz entries of each column
            // NOTE(review): this mutates the array returned by getColSum()
            // in place — assumes getColSum() returns a fresh copy
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            this.mean = matrix.getColSum().getValues();
            float[] colNNZ = matrix.getColNNZ().getValues();
            for (int i = 0; i < this.mean.length; i++) {
                if (colNNZ[i] > 0) {
                    this.mean[i] /= colNNZ[i];
                }
            }

            // compute std; parallel rows accumulate squared deviations
            // under a lock on the shared std array
            this.std = new float[this.mean.length];
            IntStream.range(0, matrix.getNRows()).parallel()
                    .forEach(rowIndex -> {
                        MLSparseVector row = matrix.getRow(rowIndex);
                        if (row == null) {
                            return;
                        }
                        int[] indexes = row.getIndexes();
                        float[] values = row.getValues();
                        for (int i = 0; i < indexes.length; i++) {
                            float diff = values[i] - this.mean[indexes[i]];
                            synchronized (this.std) {
                                this.std[indexes[i]] += diff * diff;
                            }
                        }

                    });
            for (int i = 0; i < this.std.length; i++) {
                if (colNNZ[i] > 0) {
                    this.std[i] = (float) Math.sqrt(this.std[i] / colNNZ[i]);
                }
            }

            // apply transform to this feature
            this.applyInference(feature);
        }

        @Override
        public String[] applyFeatureName(final String[] featureNames) {
            // do nothing
            return featureNames;
        }

        @Override
        public void applyInference(final MLSparseFeature feature) {
            MLSparseMatrix matrix = feature.getFeatMatrixTransformed();
            IntStream.range(0, matrix.getNRows()).parallel()
                    .forEach(rowIndex -> {
                        MLSparseVector row = matrix.getRow(rowIndex);
                        if (row == null) {
                            return;
                        }
                        this.applyInference(row);
                        matrix.setRow(row, rowIndex);
                    });
        }

        @Override
        public void applyInference(final MLSparseVector vector) {
            if (vector == null || vector.getIndexes() == null) {
                return;
            }

            int[] indexes = vector.getIndexes();
            float[] values = vector.getValues();
            int nnz = 0;
            for (int i = 0; i < indexes.length; i++) {
                int index = indexes[i];
                // values that (nearly) equal the mean become exact zeros
                // and are removed below to keep the vector sparse
                if (Math.abs(values[i] - this.mean[index]) < 1e-5) {
                    values[i] = 0;
                    continue;
                }
                nnz++;

                values[i] = (values[i] - this.mean[index]) / this.std[index];

                // clip standardized values
                if (values[i] > this.cutOff) {
                    values[i] = this.cutOff;

                } else if (values[i] < -this.cutOff) {
                    values[i] = -this.cutOff;
                }
            }
            if (nnz != values.length) {
                // remove zeros
                int[] newIndexes = new int[nnz];
                float[] newValues = new float[nnz];
                int cur = 0;
                for (int i = 0; i < indexes.length; i++) {
                    if (values[i] != 0) {
                        newIndexes[cur] = indexes[i];
                        newValues[cur] = values[i];
                        cur++;
                    }
                }
                vector.setIndexes(newIndexes);
                vector.setValues(newValues);
            }
        }

    }

    private static final long serialVersionUID = 3186575390529411219L;

    // fit on training data, then transform it
    public abstract void apply(final MLSparseFeature feature);

    // remap feature names after column-changing transforms
    public abstract String[] applyFeatureName(final String[] featureNames);

    // re-apply the fitted transform to new data
    public abstract void applyInference(final MLSparseFeature feature);

    public abstract void applyInference(final MLSparseVector vector);
}
--------------------------------------------------------------------------------
/src/main/java/common/FloatElement.java:
--------------------------------------------------------------------------------
package common;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
/**
 * A mutable (index, value) pair with an optional payload, plus static
 * helpers for top-N selection and score standardization.
 *
 * <p>FIX: generic type parameters were missing (raw {@code Comparator},
 * {@code Set}, {@code PriorityQueue}); with a raw {@code Comparator} the
 * {@code compare(FloatElement, FloatElement)} overrides do not implement
 * {@code compare(Object, Object)} and the class does not compile. Generics
 * restored throughout; binary/source compatible for callers via erasure.
 */
public class FloatElement implements Serializable {

    /** Orders elements by {@link #index}, ascending or descending. */
    public static class IndexComparator implements Comparator<FloatElement> {

        private boolean decreasing;

        public IndexComparator(final boolean decreasingP) {
            this.decreasing = decreasingP;
        }

        @Override
        public int compare(final FloatElement e1, final FloatElement e2) {
            if (this.decreasing) {
                return Integer.compare(e2.index, e1.index);
            } else {
                return Integer.compare(e1.index, e2.index);
            }
        }
    }

    /** Orders elements by {@link #value}, ascending or descending. */
    public static class ValueComparator implements Comparator<FloatElement> {

        private boolean decreasing;

        public ValueComparator(final boolean decreasingP) {
            this.decreasing = decreasingP;
        }

        @Override
        public int compare(final FloatElement e1, final FloatElement e2) {
            if (this.decreasing) {
                return Float.compare(e2.value, e1.value);
            } else {
                return Float.compare(e1.value, e2.value);
            }
        }
    }

    private static final long serialVersionUID = -4838379190571020403L;
    private int index;
    private float value;
    // optional arbitrary payload attached to this element
    private Object other;

    public FloatElement(final int indexP, final float valueP) {
        this.index = indexP;
        this.value = valueP;
    }

    public FloatElement(final int indexP, final float valueP,
            final Object otherP) {
        this.index = indexP;
        this.value = valueP;
        this.other = otherP;
    }

    public int getIndex() {
        return this.index;
    }

    public Object getOther() {
        return this.other;
    }

    public float getValue() {
        return this.value;
    }

    public void setIndex(final int indexP) {
        this.index = indexP;
    }

    public void setOther(final Object otherP) {
        this.other = otherP;
    }

    public void setValue(final float valueP) {
        this.value = valueP;
    }

    /**
     * In-place, row-parallel score standardization; null rows are skipped.
     */
    public static void standardize(final FloatElement[][] scores) {
        IntStream.range(0, scores.length).parallel().forEach(i -> {
            FloatElement[] row = scores[i];
            if (row == null) {
                return;
            }

            standardize(row);
        });
    }

    /**
     * In-place standardization: subtracts the mean and divides by the
     * (population) standard deviation. A no-op when std <= 1e-5 to avoid
     * blowing up near-constant score vectors.
     */
    public static void standardize(final FloatElement[] scores) {
        float mean = 0f;
        for (FloatElement element : scores) {
            mean += element.value;
        }
        mean = mean / scores.length;

        float std = 0f;
        for (FloatElement element : scores) {
            std += (element.value - mean) * (element.value - mean);
        }
        std = (float) Math.sqrt(std / scores.length);

        if (std > 1e-5) {
            for (FloatElement element : scores) {
                element.value = (element.value - mean) / std;
            }
        }
    }

    /**
     * Top-N by value with an array of excluded indexes (converted to a set
     * and delegated).
     */
    public static FloatElement[] topNSort(final float[] vec, final int topN,
            final int[] exclusions) {
        Set<Integer> exclusionSet = new HashSet<>(exclusions.length);
        for (int exclusion : exclusions) {
            exclusionSet.add(exclusion);
        }
        return topNSort(vec, topN, exclusionSet);
    }

    /**
     * Returns up to topN (index, value) pairs with the largest values in
     * vec, sorted by value descending. Uses a size-topN min-heap, so runs
     * in O(n log topN). Indexes in {@code exclusions} are skipped;
     * exclusions may be null.
     */
    public static FloatElement[] topNSort(final float[] vec, final int topN,
            final Set<Integer> exclusions) {

        final Comparator<FloatElement> valAscending =
                new FloatElement.ValueComparator(false);
        final Comparator<FloatElement> valDescending =
                new FloatElement.ValueComparator(true);

        // min-heap: the root is the smallest of the current top-N
        PriorityQueue<FloatElement> heap = new PriorityQueue<>(topN,
                valAscending);

        for (int i = 0; i < vec.length; i++) {
            if (exclusions != null && exclusions.contains(i)) {
                continue;
            }
            float val = vec[i];
            if (heap.size() < topN) {
                heap.add(new FloatElement(i, val));

            } else {
                if (heap.peek().value < val) {
                    heap.poll();
                    heap.add(new FloatElement(i, val));
                }
            }
        }

        FloatElement[] heapArray = new FloatElement[heap.size()];
        heap.toArray(heapArray);
        Arrays.sort(heapArray, valDescending);

        return heapArray;
    }

    /**
     * Same as {@link #topNSort(float[], int, Set)} but over an existing
     * element array; note that exclusions refer to POSITIONS in vec, not to
     * element indexes.
     */
    public static FloatElement[] topNSort(final FloatElement[] vec,
            final int topN, final Set<Integer> exclusions) {

        final Comparator<FloatElement> valAscending =
                new FloatElement.ValueComparator(false);
        final Comparator<FloatElement> valDescending =
                new FloatElement.ValueComparator(true);

        PriorityQueue<FloatElement> heap = new PriorityQueue<>(topN,
                valAscending);

        for (int i = 0; i < vec.length; i++) {
            if (exclusions != null && exclusions.contains(i)) {
                continue;
            }
            FloatElement element = vec[i];
            if (heap.size() < topN) {
                heap.add(element);

            } else {
                if (heap.peek().value < element.value) {
                    heap.poll();
                    heap.add(element);
                }
            }
        }

        FloatElement[] heapArray = new FloatElement[heap.size()];
        heap.toArray(heapArray);
        Arrays.sort(heapArray, valDescending);

        return heapArray;
    }

    /**
     * Array-exclusion convenience overload of
     * {@link #topNSort(FloatElement[], int, Set)}.
     */
    public static FloatElement[] topNSortArr(final FloatElement[] vec,
            final int topN, final int[] exclusions) {
        Set<Integer> exclusionSet = new HashSet<>(exclusions.length);
        for (int exclusion : exclusions) {
            exclusionSet.add(exclusion);
        }
        return topNSort(vec, topN, exclusionSet);
    }

    /**
     * Top-N over the window vec[offset, offset + length); returned indexes
     * and exclusion indexes are RELATIVE to offset.
     */
    public static FloatElement[] topNSortOffset(final float[] vec, int topN,
            final int offset, final int length, Set<Integer> exclusions) {

        final Comparator<FloatElement> valAscending =
                new FloatElement.ValueComparator(false);
        final Comparator<FloatElement> valDescending =
                new FloatElement.ValueComparator(true);
        PriorityQueue<FloatElement> heap = new PriorityQueue<>(topN,
                valAscending);

        for (int i = 0; i < length; i++) {
            if (exclusions != null && exclusions.contains(i)) {
                continue;
            }
            float val = vec[i + offset];
            if (heap.size() < topN) {
                heap.add(new FloatElement(i, val));

            } else {
                if (heap.peek().value < val) {
                    heap.poll();
                    heap.add(new FloatElement(i, val));
                }
            }
        }

        FloatElement[] heapArray = new FloatElement[heap.size()];
        heap.toArray(heapArray);
        Arrays.sort(heapArray, valDescending);

        return heapArray;
    }

    /**
     * Offset top-N with a SORTED exclusion list of relative indexes:
     * avoids set lookups by walking the exclusion array in step with the
     * scan. exclusionSorted may be null.
     */
    public static FloatElement[] topNSortOffset(final float[] vec, int topN,
            int[] exclusionSorted, final int offset, final int length) {

        final Comparator<FloatElement> valAscending =
                new FloatElement.ValueComparator(false);
        final Comparator<FloatElement> valDescending =
                new FloatElement.ValueComparator(true);
        PriorityQueue<FloatElement> heap = new PriorityQueue<>(topN,
                valAscending);

        // next relative index to skip, advanced as the scan passes it
        int skipping = exclusionSorted == null ? -1 : exclusionSorted[0];
        int skippingCur = 0;
        final int exclusionEnd = exclusionSorted == null ? 0
                : exclusionSorted.length;
        for (int i = 0; i < length; i++) {
            if (i == skipping) {
                skippingCur++;
                if (skippingCur < exclusionEnd) {
                    skipping = exclusionSorted[skippingCur];
                } else {
                    skipping = -1;
                }
                continue;
            }
            float val = vec[i + offset];
            if (heap.size() < topN) {
                heap.add(new FloatElement(i, val));

            } else {
                if (heap.peek().value < val) {
                    heap.poll();
                    heap.add(new FloatElement(i, val));
                }
            }
        }

        FloatElement[] heapArray = new FloatElement[heap.size()];
        heap.toArray(heapArray);
        Arrays.sort(heapArray, valDescending);

        return heapArray;
    }

}
java.util.concurrent.atomic.AtomicInteger; 4 | import java.util.stream.IntStream; 5 | 6 | import common.MLDenseVector; 7 | import common.MLSparseMatrix; 8 | import common.MLSparseMatrixAOO; 9 | import common.MLSparseVector; 10 | import common.MLTimer; 11 | import common.SplitterCF; 12 | import main.ParsedData.PlaylistFeature; 13 | import main.ParsedData.SongFeature; 14 | import main.SVD.SVDParams; 15 | 16 | public class SVDModel { 17 | 18 | private static MLTimer timer = new MLTimer("SVDModel"); 19 | 20 | private ParsedData data; 21 | private SplitterCF split; 22 | private Latents latents; 23 | 24 | public SVDModel(final ParsedData dataP, final SplitterCF splitP, 25 | final Latents latentsP) { 26 | this.data = dataP; 27 | this.split = splitP; 28 | this.latents = latentsP; 29 | } 30 | 31 | public void factorizeNames(final String scriptPath, final String cachePath) 32 | throws Exception { 33 | MLSparseMatrix playlistNames = this.data.playlistFeatures 34 | .get(PlaylistFeature.NAME_REGEXED).getFeatMatrix(); 35 | timer.toc("nNames " + playlistNames.getNCols()); 36 | 37 | // create name matrix 38 | MLSparseMatrix Rtrain = this.split.getRstrain() 39 | .get(ParsedData.INTERACTION_KEY); 40 | MLSparseMatrix RtrainT = Rtrain.transpose(); 41 | 42 | MLSparseVector[] rowsNames = new MLSparseVector[Rtrain.getNCols() 43 | + Rtrain.getNRows()]; 44 | AtomicInteger counter = new AtomicInteger(0); 45 | IntStream.range(0, Rtrain.getNCols()).parallel().forEach(songIndex -> { 46 | int count = counter.incrementAndGet(); 47 | if (count % 200_000 == 0) { 48 | timer.tocLoop("songs done", count); 49 | } 50 | 51 | MLSparseVector song = RtrainT.getRow(songIndex); 52 | if (song == null) { 53 | return; 54 | } 55 | 56 | MLDenseVector rowAvg = getRowAvg(playlistNames, song.getIndexes(), 57 | false); 58 | MLSparseVector rowAvgSparse = rowAvg.toSparse(); 59 | if (rowAvgSparse.getIndexes() != null) { 60 | rowsNames[songIndex] = rowAvgSparse; 61 | } 62 | }); 63 | 64 | counter.set(0); 65 | 
IntStream.range(0, Rtrain.getNRows()).parallel() 66 | .forEach(playlistIndex -> { 67 | int count = counter.incrementAndGet(); 68 | if (count % 200_000 == 0) { 69 | timer.tocLoop("playlists done", count); 70 | } 71 | 72 | MLSparseVector names = playlistNames.getRow(playlistIndex); 73 | if (names != null) { 74 | rowsNames[playlistIndex + Rtrain.getNCols()] = names 75 | .deepCopy(); 76 | } 77 | }); 78 | 79 | MLSparseMatrix nameMatrix = new MLSparseMatrixAOO(rowsNames, 80 | playlistNames.getNCols()); 81 | timer.toc("name matrix done " + nameMatrix.getNRows() + " " 82 | + nameMatrix.getNCols() + " " + nameMatrix.getNNZ()); 83 | 84 | SVDParams svdParams = new SVDParams(); 85 | svdParams.svdIter = 4; 86 | svdParams.rank = 200; 87 | svdParams.scriptPath = scriptPath; 88 | svdParams.cachePath = cachePath; 89 | svdParams.shapeRows = nameMatrix.getNRows(); 90 | svdParams.shapeCols = nameMatrix.getNCols(); 91 | SVD svd = new SVD(svdParams); 92 | svd.runPythonSVD(nameMatrix); 93 | 94 | this.latents.Vname = svd.P.slice(0, Rtrain.getNCols()); 95 | this.latents.Uname = svd.P.slice(Rtrain.getNCols(), 96 | Rtrain.getNCols() + Rtrain.getNRows()); 97 | this.latents.name = svd.Q; 98 | } 99 | 100 | public void factorizeAlbums(final String scriptPath, final String cachePath) 101 | throws Exception { 102 | MLSparseMatrix songAlbums = this.data.songFeatures 103 | .get(SongFeature.ALBUM_ID).getFeatMatrix(); 104 | timer.toc("nAlbums " + songAlbums.getNCols()); 105 | 106 | // create album matrix 107 | MLSparseMatrix Rtrain = this.split.getRstrain() 108 | .get(ParsedData.INTERACTION_KEY); 109 | 110 | MLSparseVector[] rowsAlbums = new MLSparseVector[Rtrain.getNCols() 111 | + Rtrain.getNRows()]; 112 | AtomicInteger counter = new AtomicInteger(0); 113 | IntStream.range(0, Rtrain.getNCols()).parallel().forEach(songIndex -> { 114 | int count = counter.incrementAndGet(); 115 | if (count % 500_000 == 0) { 116 | timer.tocLoop("songs done", count); 117 | } 118 | 119 | MLSparseVector albums = 
songAlbums.getRow(songIndex); 120 | if (albums != null) { 121 | rowsAlbums[songIndex] = albums.deepCopy(); 122 | } 123 | }); 124 | 125 | counter.set(0); 126 | IntStream.range(0, Rtrain.getNRows()).parallel() 127 | .forEach(playlistIndex -> { 128 | int count = counter.incrementAndGet(); 129 | if (count % 500_000 == 0) { 130 | timer.tocLoop("playlists done", count); 131 | } 132 | 133 | MLSparseVector playlist = Rtrain.getRow(playlistIndex); 134 | if (playlist == null) { 135 | return; 136 | } 137 | 138 | MLDenseVector rowAvg = getRowAvg(songAlbums, 139 | playlist.getIndexes(), false); 140 | MLSparseVector rowAvgSparse = rowAvg.toSparse(); 141 | if (rowAvgSparse.getIndexes() != null) { 142 | rowsAlbums[playlistIndex 143 | + Rtrain.getNCols()] = rowAvgSparse; 144 | } 145 | }); 146 | 147 | MLSparseMatrix albumMatrix = new MLSparseMatrixAOO(rowsAlbums, 148 | songAlbums.getNCols()); 149 | timer.toc("album matrix done " + albumMatrix.getNRows() + " " 150 | + albumMatrix.getNCols() + " " + albumMatrix.getNNZ()); 151 | 152 | SVDParams svdParams = new SVDParams(); 153 | svdParams.svdIter = 4; 154 | svdParams.rank = 200; 155 | svdParams.scriptPath = scriptPath; 156 | svdParams.cachePath = cachePath; 157 | svdParams.shapeRows = albumMatrix.getNRows(); 158 | svdParams.shapeCols = albumMatrix.getNCols(); 159 | SVD svd = new SVD(svdParams); 160 | svd.runPythonSVD(albumMatrix); 161 | 162 | this.latents.Valbum = svd.P.slice(0, Rtrain.getNCols()); 163 | this.latents.Ualbum = svd.P.slice(Rtrain.getNCols(), 164 | Rtrain.getNCols() + Rtrain.getNRows()); 165 | this.latents.album = svd.Q; 166 | } 167 | 168 | public void factorizeArtists(final String scriptPath, 169 | final String cachePath) throws Exception { 170 | MLSparseMatrix songArtists = this.data.songFeatures 171 | .get(SongFeature.ARTIST_ID).getFeatMatrix(); 172 | timer.toc("nArtists " + songArtists.getNCols()); 173 | 174 | // create artist matrix 175 | MLSparseMatrix Rtrain = this.split.getRstrain() 176 | 
.get(ParsedData.INTERACTION_KEY); 177 | 178 | MLSparseVector[] rowsArtist = new MLSparseVector[Rtrain.getNCols() 179 | + Rtrain.getNRows()]; 180 | AtomicInteger counter = new AtomicInteger(0); 181 | IntStream.range(0, Rtrain.getNCols()).parallel().forEach(songIndex -> { 182 | int count = counter.incrementAndGet(); 183 | if (count % 500_000 == 0) { 184 | timer.tocLoop("songs done", count); 185 | } 186 | 187 | MLSparseVector artists = songArtists.getRow(songIndex); 188 | if (artists != null) { 189 | rowsArtist[songIndex] = artists.deepCopy(); 190 | } 191 | }); 192 | 193 | counter.set(0); 194 | IntStream.range(0, Rtrain.getNRows()).parallel() 195 | .forEach(playlistIndex -> { 196 | int count = counter.incrementAndGet(); 197 | if (count % 500_000 == 0) { 198 | timer.tocLoop("playlists done", count); 199 | } 200 | 201 | MLSparseVector playlist = Rtrain.getRow(playlistIndex); 202 | if (playlist == null) { 203 | return; 204 | } 205 | 206 | MLDenseVector rowAvg = getRowAvg(songArtists, 207 | playlist.getIndexes(), false); 208 | MLSparseVector rowAvgSparse = rowAvg.toSparse(); 209 | if (rowAvgSparse.getIndexes() != null) { 210 | rowsArtist[playlistIndex 211 | + Rtrain.getNCols()] = rowAvgSparse; 212 | } 213 | }); 214 | 215 | MLSparseMatrix artistMatrix = new MLSparseMatrixAOO(rowsArtist, 216 | songArtists.getNCols()); 217 | timer.toc("artist matrix done " + artistMatrix.getNRows() + " " 218 | + artistMatrix.getNCols() + " " + artistMatrix.getNNZ()); 219 | 220 | SVDParams svdParams = new SVDParams(); 221 | svdParams.svdIter = 4; 222 | svdParams.rank = 200; 223 | svdParams.scriptPath = scriptPath; 224 | svdParams.cachePath = cachePath; 225 | svdParams.shapeRows = artistMatrix.getNRows(); 226 | svdParams.shapeCols = artistMatrix.getNCols(); 227 | SVD svd = new SVD(svdParams); 228 | svd.runPythonSVD(artistMatrix); 229 | 230 | this.latents.Vartist = svd.P.slice(0, Rtrain.getNCols()); 231 | this.latents.Uartist = svd.P.slice(Rtrain.getNCols(), 232 | Rtrain.getNCols() + 
Rtrain.getNRows()); 233 | this.latents.artist = svd.Q; 234 | } 235 | 236 | public static MLDenseVector getRowAvg(final MLSparseMatrix R, 237 | final int[] rowIndices, final boolean normalize) { 238 | float[] rowAvg = new float[R.getNCols()]; 239 | int count = 0; 240 | for (int colIndex : rowIndices) { 241 | MLSparseVector row = R.getRow(colIndex); 242 | // count++; 243 | if (row == null) { 244 | continue; 245 | } 246 | count++; 247 | 248 | int[] indexes = row.getIndexes(); 249 | float[] values = row.getValues(); 250 | 251 | for (int i = 0; i < indexes.length; i++) { 252 | rowAvg[indexes[i]] += values[i]; 253 | } 254 | } 255 | 256 | if (normalize == true && count > 1) { 257 | for (int i = 0; i < rowAvg.length; i++) { 258 | rowAvg[i] = rowAvg[i] / count; 259 | } 260 | } 261 | 262 | return new MLDenseVector(rowAvg); 263 | } 264 | 265 | } 266 | -------------------------------------------------------------------------------- /src/main/java/common/ALS.java: -------------------------------------------------------------------------------- 1 | package common; 2 | 3 | import java.lang.reflect.Field; 4 | import java.util.Arrays; 5 | import java.util.Random; 6 | import java.util.concurrent.atomic.AtomicInteger; 7 | import java.util.stream.IntStream; 8 | 9 | public class ALS { 10 | 11 | public static class ALSParams { 12 | public int maxIter = 10; 13 | public int rank = 200; 14 | public float alpha = 10f; 15 | public float lambda = 0.01f; 16 | public float init = 0.01f; 17 | public int seed = 1; 18 | public boolean evaluate = true; 19 | public boolean debug = true; 20 | public int printFrequency = 500_000; 21 | 22 | @Override 23 | public String toString() { 24 | StringBuilder result = new StringBuilder(); 25 | String newLine = System.getProperty("line.separator"); 26 | 27 | result.append(this.getClass().getName()); 28 | result.append(" {"); 29 | result.append(newLine); 30 | 31 | // determine fields declared in this class only (no fields of 32 | // superclass) 33 | Field[] 
fields = this.getClass().getDeclaredFields(); 34 | 35 | // print field names paired with their values 36 | for (Field field : fields) { 37 | result.append(" "); 38 | try { 39 | result.append(field.getName()); 40 | result.append(": "); 41 | // requires access to private field: 42 | result.append(field.get(this)); 43 | } catch (IllegalAccessException ex) { 44 | System.out.println(ex); 45 | } 46 | result.append(newLine); 47 | } 48 | result.append("}"); 49 | 50 | return result.toString(); 51 | } 52 | 53 | } 54 | 55 | private ALSParams params; 56 | private MLDenseMatrix U; 57 | private MLDenseMatrix V; 58 | private MLTimer timer; 59 | 60 | public ALS(final ALSParams paramsP) { 61 | this.params = paramsP; 62 | this.timer = new MLTimer("als", this.params.maxIter); 63 | } 64 | 65 | public MLDenseMatrix getU() { 66 | return this.U; 67 | } 68 | 69 | public MLDenseMatrix getV() { 70 | return this.V; 71 | } 72 | 73 | public void optimize(final MLSparseMatrix R_train, final String outPath) 74 | throws Exception { 75 | 76 | this.timer.tic(); 77 | 78 | MLSparseMatrix R_train_t = R_train.transpose(); 79 | this.timer.toc("obtained R Rt"); 80 | 81 | // randomly initialize U and V 82 | if (this.U == null) { 83 | this.U = MLDenseMatrix.initRandom(R_train.getNRows(), 84 | this.params.rank, this.params.init, this.params.seed); 85 | this.timer.toc("initialized U"); 86 | } 87 | 88 | if (this.V == null) { 89 | this.V = MLDenseMatrix.initRandom(R_train.getNCols(), 90 | this.params.rank, this.params.init, this.params.seed); 91 | this.timer.toc("initialized V"); 92 | } 93 | 94 | for (int i = 0; i < R_train.getNRows(); i++) { 95 | if (R_train.getRow(i) == null) { 96 | // zero out cold start users 97 | this.U.setRow(new MLDenseVector(new float[this.params.rank]), 98 | i); 99 | } 100 | } 101 | for (int i = 0; i < R_train_t.getNRows(); i++) { 102 | if (R_train_t.getRow(i) == null) { 103 | // zero out cold start items 104 | this.V.setRow(new MLDenseVector(new float[this.params.rank]), 105 | i); 
106 | } 107 | } 108 | 109 | for (int iter = 0; iter < this.params.maxIter; iter++) { 110 | this.solve(R_train, this.U, this.V); 111 | this.solve(R_train_t, this.V, this.U); 112 | this.timer.toc("solver done"); 113 | this.timer.toc(String.format("[iter %d] done", iter)); 114 | } 115 | 116 | if (outPath != null) { 117 | String uOutFile = outPath + "U_" + this.params.rank + ".bin"; 118 | String vOutFile = outPath + "V_" + this.params.rank + ".bin"; 119 | 120 | this.U.toFile(uOutFile); 121 | this.timer.toc("written U to " + uOutFile); 122 | 123 | this.V.toFile(vOutFile); 124 | this.timer.toc("written V to " + vOutFile); 125 | } 126 | } 127 | 128 | public void setU(final MLDenseMatrix Up) { 129 | this.U = Up; 130 | } 131 | 132 | public void setV(final MLDenseMatrix Vp) { 133 | this.V = Vp; 134 | } 135 | 136 | private MLDenseVector solve(final int targetIndex, 137 | final MLSparseMatrix data, final float[] H, final float[] HH, 138 | final float[] cache) { 139 | int[] rowIndexes = data.getRow(targetIndex).getIndexes(); 140 | float[] values = data.getRow(targetIndex).getValues(); 141 | 142 | float[] HC_minus_IH = new float[this.params.rank * this.params.rank]; 143 | for (int i = 0; i < this.params.rank; i++) { 144 | for (int j = i; j < this.params.rank; j++) { 145 | float total = 0; 146 | for (int k = 0; k < rowIndexes.length; k++) { 147 | int offset = rowIndexes[k] * this.params.rank; 148 | total += H[offset + i] * H[offset + j] * values[k]; 149 | } 150 | HC_minus_IH[i * this.params.rank + j] = total 151 | * this.params.alpha; 152 | HC_minus_IH[j * this.params.rank + i] = total 153 | * this.params.alpha; 154 | } 155 | } 156 | // create HCp in O(f|S_u|) 157 | float[] HCp = new float[this.params.rank]; 158 | for (int i = 0; i < this.params.rank; i++) { 159 | float total = 0; 160 | for (int k = 0; k < rowIndexes.length; k++) { 161 | total += H[rowIndexes[k] * this.params.rank + i] 162 | * (1 + this.params.alpha * values[k]); 163 | } 164 | HCp[i] = total; 165 | } 166 | // 
create temp = HH + HC_minus_IH + lambda*I 167 | // temp is symmetric 168 | // the inverse temp is symmetric 169 | float[] temp = new float[this.params.rank * this.params.rank]; 170 | for (int i = 0; i < this.params.rank; i++) { 171 | final int offset = i * this.params.rank; 172 | for (int j = i; j < this.params.rank; j++) { 173 | float total = HH[offset + j] + HC_minus_IH[offset + j]; 174 | if (i == j) { 175 | total += this.params.lambda; 176 | } 177 | temp[offset + j] = total; 178 | } 179 | } 180 | 181 | LowLevelRoutines.symmetricSolve(temp, this.params.rank, HCp, cache); 182 | 183 | // return optimal solution 184 | return new MLDenseVector(HCp); 185 | } 186 | 187 | private MLDenseVector solve(final int targetIndex, 188 | final MLSparseMatrix data, final MLDenseMatrix H, final float[] HH, 189 | final float[] cache) { 190 | int[] rowIndexes = data.getRow(targetIndex).getIndexes(); 191 | float[] values = data.getRow(targetIndex).getValues(); 192 | 193 | float[] HC_minus_IH = new float[this.params.rank * this.params.rank]; 194 | for (int i = 0; i < this.params.rank; i++) { 195 | for (int j = i; j < this.params.rank; j++) { 196 | float total = 0; 197 | for (int k = 0; k < rowIndexes.length; k++) { 198 | total += H.getValue(rowIndexes[k], i) 199 | * H.getValue(rowIndexes[k], j) * values[k]; 200 | } 201 | HC_minus_IH[i * this.params.rank + j] = total 202 | * this.params.alpha; 203 | HC_minus_IH[j * this.params.rank + i] = total 204 | * this.params.alpha; 205 | } 206 | } 207 | // create HCp in O(f|S_u|) 208 | float[] HCp = new float[this.params.rank]; 209 | for (int i = 0; i < this.params.rank; i++) { 210 | float total = 0; 211 | for (int k = 0; k < rowIndexes.length; k++) { 212 | total += H.getValue(rowIndexes[k], i) 213 | * (1 + this.params.alpha * values[k]); 214 | } 215 | HCp[i] = total; 216 | } 217 | // create temp = HH + HC_minus_IH + lambda*I 218 | // temp is symmetric 219 | // the inverse temp is symmetric 220 | float[] temp = new float[this.params.rank * 
this.params.rank]; 221 | for (int i = 0; i < this.params.rank; i++) { 222 | final int offset = i * this.params.rank; 223 | for (int j = i; j < this.params.rank; j++) { 224 | float total = HH[offset + j] + HC_minus_IH[offset + j]; 225 | if (i == j) { 226 | total += this.params.lambda; 227 | } 228 | temp[offset + j] = total; 229 | } 230 | } 231 | 232 | LowLevelRoutines.symmetricSolve(temp, this.params.rank, HCp, cache); 233 | 234 | // return optimal solution 235 | return new MLDenseVector(Arrays.copyOf(HCp, this.params.rank)); 236 | } 237 | 238 | private void solve(final MLSparseMatrix data, final MLDenseMatrix W, 239 | final MLDenseMatrix H) { 240 | 241 | int cacheSize = LowLevelRoutines.symmInverseCacheSize( 242 | new float[this.params.rank * this.params.rank], 243 | this.params.rank); 244 | // float[] cache = new float[cacheSize]; 245 | MLConcurrentUtils.Async cache = new MLConcurrentUtils.Async<>( 246 | () -> new float[cacheSize], null); 247 | MLTimer timer = new MLTimer("als", data.getNRows()); 248 | timer.tic(); 249 | 250 | // compute H_t * H 251 | MLDenseMatrix HH = H.transposeMultNative(); 252 | float[] HHflat = HH.toFlatArray(); 253 | if (this.params.debug) { 254 | timer.toc("HH done"); 255 | } 256 | 257 | boolean[] useFlat = new boolean[] { false }; 258 | float[][] Hflat = new float[1][]; 259 | if (H.getNRows() < LowLevelRoutines.MAX_ARRAY_SIZE / H.getNCols()) { 260 | // no overflow so use flat version 261 | useFlat[0] = true; 262 | Hflat[0] = H.toFlatArray(); 263 | if (this.params.debug) { 264 | timer.toc("H to flat done"); 265 | } 266 | } else { 267 | System.out.println("WARNING: not using flat H"); 268 | } 269 | 270 | int[] rowIndices = new int[data.getNRows()]; 271 | for (int i = 0; i < data.getNRows(); i++) { 272 | rowIndices[i] = i; 273 | } 274 | MLRandomUtils.shuffle(rowIndices, new Random(1)); 275 | AtomicInteger counter = new AtomicInteger(0); 276 | IntStream.range(0, rowIndices.length).parallel().forEach(i -> { 277 | int count = 
counter.incrementAndGet(); 278 | if (this.params.debug && count % this.params.printFrequency == 0) { 279 | timer.tocLoop(count); 280 | } 281 | int rowIndex = rowIndices[i]; 282 | if (data.getRow(rowIndex) == null) { 283 | return; 284 | } 285 | 286 | MLDenseVector solution; 287 | if (useFlat[0] == true) { 288 | solution = solve(rowIndex, data, Hflat[0], HHflat, cache.get()); 289 | } else { 290 | solution = solve(rowIndex, data, H, HHflat, cache.get()); 291 | } 292 | 293 | W.setRow(solution, rowIndex); 294 | }); 295 | if (this.params.debug) { 296 | timer.tocLoop(counter.get()); 297 | } 298 | } 299 | 300 | public static void main(String[] args) { 301 | MLDenseMatrix V = new MLDenseMatrix( 302 | new MLDenseVector[] { new MLDenseVector(new float[] { 1, 2 }), 303 | new MLDenseVector(new float[] { 3, 4 }), 304 | new MLDenseVector(new float[] { 5, 6 }) }); 305 | MLDenseMatrix U = new MLDenseMatrix( 306 | new MLDenseVector[] { new MLDenseVector(new float[] { 1, -2 }), 307 | new MLDenseVector(new float[] { 3, -4 }), 308 | new MLDenseVector(new float[] { 5, -6 }) }); 309 | 310 | MLSparseVector[] test = new MLSparseVector[3]; 311 | test[0] = new MLSparseVector(new int[] { 0, 1 }, new float[] { 1, 1 }, 312 | null, 3); 313 | test[1] = new MLSparseVector(new int[] { 0, 1, 2 }, 314 | new float[] { 1, 1, 1 }, null, 3); 315 | test[2] = new MLSparseVector(new int[] { 1, 2 }, new float[] { 1, 1 }, 316 | null, 3); 317 | MLSparseMatrix R = new MLSparseMatrixAOO(test, 3); 318 | MLSparseMatrix RT = new MLSparseMatrixAOO(test, 3); 319 | 320 | ALSParams params = new ALSParams(); 321 | params.maxIter = 1; 322 | params.rank = 2; 323 | params.lambda = 0f; 324 | 325 | ALS als = new ALS(params); 326 | als.solve(R, U, V); 327 | als.solve(RT, V, U); 328 | 329 | System.out.println("U"); 330 | System.out.println(Arrays.toString(U.getRow(0).getValues())); 331 | System.out.println(Arrays.toString(U.getRow(1).getValues())); 332 | System.out.println(Arrays.toString(U.getRow(2).getValues())); 333 | 
System.out.println("\nV"); 334 | System.out.println(Arrays.toString(V.getRow(0).getValues())); 335 | System.out.println(Arrays.toString(V.getRow(1).getValues())); 336 | System.out.println(Arrays.toString(V.getRow(2).getValues())); 337 | } 338 | } 339 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/main/java/main/RecSysSplitter.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import java.util.Arrays; 4 | import java.util.Collections; 5 | import java.util.HashMap; 6 | import java.util.HashSet; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | import java.util.Map; 10 | import java.util.Random; 11 | import java.util.Set; 12 | import java.util.concurrent.atomic.AtomicInteger; 13 | import java.util.stream.IntStream; 14 | 15 | import common.MLMatrixElement; 16 | import common.MLRandomUtils; 17 | import common.MLSparseFeature; 18 | import common.MLSparseMatrix; 19 | import common.MLSparseMatrixAOO; 20 | import common.MLSparseVector; 21 | import common.MLTimer; 22 | import common.SplitterCF; 23 | import main.ParsedData.PlaylistFeature; 24 | 25 | public class RecSysSplitter { 26 | 27 | // Predict tracks for a playlist given its title only 28 | // Predict tracks for a playlist given its title and the first track 29 | // Predict tracks for a playlist given its title 
and the first 5 tracks 30 | // Predict tracks for a playlist given its first 5 tracks (no title) 31 | // Predict tracks for a playlist given its title and the first 10 tracks 32 | // Predict tracks for a playlist given its first 10 tracks (no title) 33 | // Predict tracks for a playlist given its title and the first 25 tracks 34 | // Predict tracks for a playlist given its title and 25 random tracks 35 | // Predict tracks for a playlist given its title and the first 100 tracks 36 | // Predict tracks for a playlist given its title and 100 random tracks 37 | 38 | private static MLTimer timer = new MLTimer("RecSysSplitter"); 39 | static { 40 | timer.tic(); 41 | } 42 | 43 | public static SplitterCF getSplitMatching(final ParsedData data) { 44 | 45 | // init with full data 46 | MLSparseVector[] trainRows = new MLSparseVector[data.interactions 47 | .getNRows()]; 48 | MLSparseVector[] validRows = new MLSparseVector[data.interactions 49 | .getNRows()]; 50 | for (int i = 0; i < data.interactions.getNRows(); i++) { 51 | MLSparseVector row = data.interactions.getRow(i); 52 | if (row != null) { 53 | trainRows[i] = row.deepCopy(); 54 | } 55 | } 56 | 57 | List validIndexList = new LinkedList(); 58 | Set addedIndexes = new HashSet(); 59 | AtomicInteger nExact = new AtomicInteger(0); 60 | AtomicInteger nAtLeast = new AtomicInteger(0); 61 | AtomicInteger counter = new AtomicInteger(0); 62 | IntStream.range(0, data.testIndexes.length).parallel() 63 | .forEach(index -> { 64 | int count = counter.incrementAndGet(); 65 | if (count % 1000 == 0) { 66 | timer.tocLoop(count); 67 | } 68 | 69 | int testIndex = data.testIndexes[index]; 70 | if (data.interactions.getRow(testIndex) == null) { 71 | // skip cold start 72 | return; 73 | } 74 | 75 | int nTracksTotal = (int) data.playlistFeatures 76 | .get(PlaylistFeature.N_TRACKS) 77 | .getRow(testIndex, false).getValues()[0]; 78 | float[] valuesTest = data.interactions.getRow(testIndex) 79 | .getValues(); 80 | // int nTracksTrain = 
data.interactions.getRow(testIndex) 81 | // .getIndexes().length; 82 | int nTracksTrain = 0; 83 | for (float value : valuesTest) { 84 | nTracksTrain += value; 85 | } 86 | 87 | // find training playlists with nTracksTotal 88 | List exact = new LinkedList(); 89 | List atLeast = new LinkedList(); 90 | for (int i = 0; i < data.interactions.getNRows(); i++) { 91 | if (Arrays.binarySearch(data.testIndexes, i) >= 0) { 92 | // don't split test playlists 93 | continue; 94 | } 95 | 96 | MLSparseVector row = data.interactions.getRow(i); 97 | if (row.getIndexes().length == nTracksTotal) { 98 | exact.add(i); 99 | 100 | } else if (row.getIndexes().length > nTracksTotal) { 101 | atLeast.add(i); 102 | } 103 | } 104 | Collections.shuffle(exact, new Random(index)); 105 | Collections.shuffle(atLeast, new Random(index)); 106 | 107 | int repeat = 0; 108 | while (repeat < 10) { 109 | Integer validIndex = null; 110 | synchronized (addedIndexes) { 111 | if (validIndex == null && exact.size() > 0) { 112 | while (exact.size() > 0) { 113 | int candIndex = exact.remove(0); 114 | if (addedIndexes 115 | .contains(candIndex) == false) { 116 | validIndex = candIndex; 117 | nExact.incrementAndGet(); 118 | break; 119 | } 120 | } 121 | } 122 | 123 | if (validIndex == null && atLeast.size() > 0) { 124 | while (atLeast.size() > 0) { 125 | int candIndex = atLeast.remove(0); 126 | if (addedIndexes 127 | .contains(candIndex) == false) { 128 | validIndex = candIndex; 129 | nAtLeast.incrementAndGet(); 130 | break; 131 | } 132 | } 133 | } 134 | 135 | if (validIndex == null) { 136 | // no rows with nTracksTotal 137 | break; 138 | } 139 | addedIndexes.add(validIndex); 140 | if (repeat == 0) { 141 | validIndexList.add(validIndex); 142 | } 143 | repeat++; 144 | } 145 | 146 | // split this row 147 | MLSparseVector row = data.interactions 148 | .getRow(validIndex); 149 | int[] indexes = row.getIndexes(); 150 | float[] values = row.getValues(); 151 | long[] dates = row.getDates(); 152 | 153 | MLMatrixElement[] 
elements = new MLMatrixElement[indexes.length]; 154 | for (int i = 0; i < indexes.length; i++) { 155 | elements[i] = new MLMatrixElement(validIndex, 156 | indexes[i], values[i], dates[i]); 157 | } 158 | 159 | if (isInRandomOrder(testIndex) == false) { 160 | // split by position 161 | Arrays.sort(elements, 162 | new MLMatrixElement.DateComparator(true)); 163 | } else { 164 | // random order split 165 | MLRandomUtils.shuffle(elements, new Random(index)); 166 | } 167 | 168 | Set trainIndexes = new HashSet(); 169 | for (int i = 0; i < nTracksTrain; i++) { 170 | trainIndexes.add(elements[i].getColIndex()); 171 | } 172 | 173 | int jtrain = 0; 174 | int[] indexesTrain = new int[nTracksTrain]; 175 | float[] valuesTrain = new float[nTracksTrain]; 176 | long[] datesTrain = new long[nTracksTrain]; 177 | 178 | int jvalid = 0; 179 | int[] indexesValid = new int[indexes.length 180 | - nTracksTrain]; 181 | float[] valuesValid = new float[indexes.length 182 | - nTracksTrain]; 183 | long[] datesValid = new long[indexes.length 184 | - nTracksTrain]; 185 | 186 | for (int i = 0; i < indexes.length; i++) { 187 | if (trainIndexes.contains(indexes[i]) == true) { 188 | indexesTrain[jtrain] = indexes[i]; 189 | valuesTrain[jtrain] = values[i]; 190 | datesTrain[jtrain] = dates[i]; 191 | jtrain++; 192 | 193 | } else { 194 | indexesValid[jvalid] = indexes[i]; 195 | valuesValid[jvalid] = values[i]; 196 | datesValid[jvalid] = dates[i]; 197 | jvalid++; 198 | } 199 | } 200 | 201 | // update split rows 202 | trainRows[validIndex] = new MLSparseVector(indexesTrain, 203 | valuesTrain, datesTrain, 204 | data.interactions.getNCols()); 205 | validRows[validIndex] = new MLSparseVector(indexesValid, 206 | valuesValid, datesValid, 207 | data.interactions.getNCols()); 208 | } 209 | }); 210 | 211 | MLSparseMatrix Rtrain = new MLSparseMatrixAOO(trainRows, 212 | data.interactions.getNCols()); 213 | Map trainMap = new HashMap(); 214 | trainMap.put(ParsedData.INTERACTION_KEY, Rtrain); 215 | 216 | MLSparseMatrix 
Rvalid = new MLSparseMatrixAOO(validRows, 217 | data.interactions.getNCols()); 218 | Map validMap = new HashMap(); 219 | validMap.put(ParsedData.INTERACTION_KEY, Rvalid); 220 | 221 | SplitterCF newSplit = new SplitterCF(); 222 | newSplit.setRstrain(trainMap); 223 | newSplit.setRsvalid(validMap); 224 | 225 | int[] validRowIndexes = new int[validIndexList.size()]; 226 | int cur = 0; 227 | for (int index : validIndexList) { 228 | validRowIndexes[cur] = index; 229 | cur++; 230 | } 231 | Arrays.sort(validRowIndexes); 232 | 233 | int[] validColIndexes = new int[data.interactions.getNCols()]; 234 | for (int i = 0; i < data.interactions.getNCols(); i++) { 235 | validColIndexes[i] = i; 236 | } 237 | 238 | newSplit.setValidRowIndexes(validRowIndexes); 239 | newSplit.setValidColIndexes(validColIndexes); 240 | 241 | System.out.println("nExact: " + nExact + " nAtLeast: " + nAtLeast); 242 | System.out.println("validRowIndexes: " + validRowIndexes.length); 243 | System.out.println( 244 | "nnz full:" + data.interactions.getNNZ() + " nnz train: " 245 | + Rtrain.getNNZ() + " nnz valid:" + Rvalid.getNNZ()); 246 | 247 | return newSplit; 248 | 249 | } 250 | 251 | public static void removeName(final ParsedData data, 252 | final SplitterCF split) { 253 | // Predict tracks for a playlist given its title and the first 5 tracks 254 | // Predict tracks for a playlist given its first 5 tracks (no title) 255 | // Predict tracks for a playlist given its title and the first 10 tracks 256 | // Predict tracks for a playlist given its first 10 tracks (no title) 257 | 258 | List fiveTracksValid = new LinkedList(); 259 | List fiveTracks = new LinkedList(); 260 | 261 | List tenTracksValid = new LinkedList(); 262 | List tenTracks = new LinkedList(); 263 | 264 | int[] validRowsIndexes = split.getValidRowIndexes(); 265 | MLSparseMatrix Rtrain = split.getRstrain() 266 | .get(ParsedData.INTERACTION_KEY); 267 | MLSparseMatrix Rvalid = split.getRsvalid() 268 | .get(ParsedData.INTERACTION_KEY); 269 | for (int 
i = 0; i < Rtrain.getNRows(); i++) { 270 | if (Rvalid.getRow(i) == null) { 271 | continue; 272 | } 273 | 274 | MLSparseVector row = Rtrain.getRow(i); 275 | int[] indexes = row.getIndexes(); 276 | 277 | if (indexes.length == 5) { 278 | if (Arrays.binarySearch(validRowsIndexes, i) >= 0) { 279 | fiveTracksValid.add(i); 280 | } else { 281 | fiveTracks.add(i); 282 | } 283 | 284 | } else if (indexes.length == 10) { 285 | if (Arrays.binarySearch(validRowsIndexes, i) >= 0) { 286 | tenTracksValid.add(i); 287 | } else { 288 | tenTracks.add(i); 289 | } 290 | } 291 | } 292 | 293 | Collections.shuffle(fiveTracksValid, new Random(0)); 294 | Collections.shuffle(fiveTracks, new Random(1)); 295 | Collections.shuffle(tenTracksValid, new Random(2)); 296 | Collections.shuffle(tenTracks, new Random(3)); 297 | 298 | System.out.println("5valid:" + fiveTracksValid.size() + " 5other:" 299 | + fiveTracks.size() + " 10valid:" + tenTracksValid.size() 300 | + " 10other:" + tenTracks.size()); 301 | 302 | // remove names for half of target playlists 303 | MLSparseFeature nameFeature = data.playlistFeatures 304 | .get(PlaylistFeature.NAME_REGEXED); 305 | 306 | for (int i = 0; i < fiveTracksValid.size() / 2; i++) { 307 | nameFeature.getFeatMatrix().setRow(null, fiveTracksValid.get(i)); 308 | nameFeature.getFeatMatrixTransformed().setRow(null, 309 | fiveTracksValid.get(i)); 310 | } 311 | 312 | for (int i = 0; i < fiveTracks.size() / 2; i++) { 313 | nameFeature.getFeatMatrix().setRow(null, fiveTracks.get(i)); 314 | nameFeature.getFeatMatrixTransformed().setRow(null, 315 | fiveTracks.get(i)); 316 | } 317 | 318 | for (int i = 0; i < tenTracksValid.size() / 2; i++) { 319 | nameFeature.getFeatMatrix().setRow(null, tenTracksValid.get(i)); 320 | nameFeature.getFeatMatrixTransformed().setRow(null, 321 | tenTracksValid.get(i)); 322 | } 323 | 324 | for (int i = 0; i < tenTracks.size() / 2; i++) { 325 | nameFeature.getFeatMatrix().setRow(null, tenTracks.get(i)); 326 | 
nameFeature.getFeatMatrixTransformed().setRow(null, 327 | tenTracks.get(i)); 328 | } 329 | 330 | System.out.println("5valid:" + (fiveTracksValid.size() / 2) 331 | + " 5other:" + (fiveTracks.size() / 2) + " 10valid:" 332 | + (tenTracksValid.size() / 2) + " 10other:" 333 | + (tenTracks.size() / 2)); 334 | } 335 | 336 | public static boolean isInRandomOrder(final int playlistIndex) { 337 | if ((playlistIndex >= 1006000 && playlistIndex <= 1006999) 338 | || (playlistIndex >= 1008000 && playlistIndex <= 1008999)) { 339 | return true; 340 | } else { 341 | return false; 342 | } 343 | } 344 | 345 | } 346 | -------------------------------------------------------------------------------- /src/main/java/main/ParsedDataLoader.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.FileReader; 5 | import java.util.Arrays; 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | import java.util.concurrent.TimeUnit; 9 | import java.util.concurrent.atomic.AtomicInteger; 10 | 11 | import common.MLFeatureTransform; 12 | import common.MLMatrixElement; 13 | import common.MLSparseFeature; 14 | import common.MLSparseMatrixAOO; 15 | import common.MLSparseVector; 16 | import common.MLTextTransform; 17 | import common.MLTimer; 18 | import main.ParsedData.PlaylistFeature; 19 | import main.ParsedData.SongExtraInfoFeature; 20 | import main.ParsedData.SongFeature; 21 | import net.minidev.json.JSONArray; 22 | import net.minidev.json.JSONObject; 23 | import net.minidev.json.parser.JSONParser; 24 | 25 | public class ParsedDataLoader { 26 | 27 | private Data dataLoaded; 28 | public ParsedData dataParsed; 29 | 30 | public ParsedDataLoader(final Data dataLoadedP) { 31 | this.dataLoaded = dataLoadedP; 32 | this.dataParsed = new ParsedData(); 33 | } 34 | 35 | public ParsedDataLoader(final ParsedData dataParsedP) { 36 | this.dataParsed = dataParsedP; 37 | } 38 | 39 | public void 
loadPlaylists() { 40 | MLTimer timer = new MLTimer("loadPlaylists"); 41 | timer.tic(); 42 | 43 | int nPlaylists = this.dataLoaded.playlists.length; 44 | int nSongs = this.dataLoaded.songs.length; 45 | 46 | MLSparseVector[] rows = new MLSparseVector[nPlaylists]; 47 | this.dataParsed.interactions = new MLSparseMatrixAOO(rows, nSongs); 48 | 49 | // init playlist feature matrices 50 | this.dataParsed.playlistFeatures = new HashMap(); 51 | for (PlaylistFeature featureName : PlaylistFeature.values()) { 52 | MLFeatureTransform[] featTransforms = new MLFeatureTransform[] { 53 | new MLFeatureTransform.ColSelectorTransform(1_000) }; 54 | 55 | MLTextTransform[] textTransforms; 56 | switch (featureName) { 57 | case NAME_TOKENIZED: { 58 | // tokenize playlist name 59 | textTransforms = new MLTextTransform[] { 60 | new MLTextTransform.LuceneAnalyzerTextTransform( 61 | new MLTextTransform.DefaultAnalyzer()) }; 62 | break; 63 | } 64 | default: { 65 | textTransforms = null; 66 | break; 67 | } 68 | } 69 | 70 | MLSparseFeature feature = new MLSparseFeature(nPlaylists, 71 | textTransforms, featTransforms, MLSparseMatrixAOO.class); 72 | this.dataParsed.playlistFeatures.put(featureName, feature); 73 | } 74 | timer.toc("init done"); 75 | 76 | // load playlists 77 | 78 | AtomicInteger count = new AtomicInteger(0); 79 | this.dataParsed.testIndexes = this.dataLoaded.testIndexes; 80 | this.dataParsed.playlistIds = new String[nPlaylists]; 81 | // IntStream.range(0, nPlaylists).parallel()(i -> { 82 | for (int i = 0; i < nPlaylists; i++) { 83 | Playlist playlist = this.dataLoaded.playlists[i]; 84 | this.dataParsed.playlistIds[i] = playlist.get_pid(); 85 | 86 | Track[] tracks = playlist.getTracks(); 87 | 88 | // convert playlist to sparse matrix 89 | if (tracks != null && tracks.length > 0) { 90 | Map elementMap = new HashMap(); 91 | for (int j = 0; j < tracks.length; j++) { 92 | MLMatrixElement element = elementMap 93 | .get(tracks[j].getSongIndex()); 94 | if (element == null) { 95 | // set 
date to position in the playlist 96 | element = new MLMatrixElement(i, 97 | tracks[j].getSongIndex(), 1.0f, 98 | tracks[j].getSongPos()); 99 | elementMap.put(tracks[j].getSongIndex(), element); 100 | } else { 101 | // some playlists have duplicate songs 102 | element.setValue(element.getValue() + 1.0f); 103 | } 104 | } 105 | MLMatrixElement[] elements = new MLMatrixElement[elementMap 106 | .size()]; 107 | int curIndex = 0; 108 | for (MLMatrixElement element : elementMap.values()) { 109 | elements[curIndex] = element; 110 | curIndex++; 111 | } 112 | Arrays.sort(elements, 113 | new MLMatrixElement.ColIndexComparator(false)); 114 | 115 | int[] indexes = new int[elements.length]; 116 | float[] values = new float[elements.length]; 117 | long[] dates = new long[elements.length]; 118 | for (int j = 0; j < elements.length; j++) { 119 | indexes[j] = elements[j].getColIndex(); 120 | values[j] = elements[j].getValue(); 121 | dates[j] = elements[j].getDate(); 122 | } 123 | rows[i] = new MLSparseVector(indexes, values, dates, nSongs); 124 | } 125 | 126 | // add playlist features 127 | for (PlaylistFeature featureName : PlaylistFeature.values()) { 128 | switch (featureName) { 129 | case NAME_ORIGINAL: { 130 | if (playlist.get_name() != null) { 131 | this.dataParsed.playlistFeatures.get(featureName) 132 | .addRow(i, playlist.get_name()); 133 | } 134 | break; 135 | } 136 | 137 | case NAME_REGEXED: { 138 | if (playlist.get_name() != null) { 139 | String name = playlist.get_name(); 140 | name = name.toLowerCase(); 141 | name = name.replaceAll("\\p{Punct}", " "); 142 | name = name.replaceAll("\\s+", " ").trim(); 143 | this.dataParsed.playlistFeatures.get(featureName) 144 | .addRow(i, name); 145 | } 146 | break; 147 | } 148 | 149 | case NAME_TOKENIZED: { 150 | if (playlist.get_name() != null) { 151 | // convert emojis to string 152 | String name = playlist.get_name(); 153 | this.dataParsed.playlistFeatures.get(featureName) 154 | .addRow(i, name); 155 | } 156 | break; 157 | } 158 | 159 
| case N_TRACKS: { 160 | if (playlist.get_num_tracks() != null) { 161 | this.dataParsed.playlistFeatures.get(featureName) 162 | .addRow(i, new MLSparseVector( 163 | new int[] { 0 }, 164 | new float[] { 165 | playlist.get_num_tracks() }, 166 | null, 1)); 167 | } 168 | break; 169 | } 170 | 171 | // case IS_COLLABORATIVE: { 172 | // int collab = 0; 173 | // if (playlist.get_collaborative() == true) { 174 | // collab = 1; 175 | // } 176 | // this.dataParsed.playlistFeatures.get(featureName) 177 | // .addRow(i, new MLSparseVector(new int[] { 0 }, 178 | // new float[] { collab }, null, 1)); 179 | // break; 180 | // } 181 | // 182 | // case MODIFIED_AT: { 183 | // this.dataParsed.playlistFeatures.get(featureName) 184 | // .addRow(i, new MLSparseVector(new int[] { 0 }, 185 | // new float[] { TimeUnit.MILLISECONDS 186 | // .toHours(playlist 187 | // .get_modified_at()) }, 188 | // null, 1)); 189 | // break; 190 | // } 191 | // 192 | // case N_FOLLOWERS: { 193 | // this.dataParsed.playlistFeatures.get(featureName) 194 | // .addRow(i, new MLSparseVector(new int[] { 0 }, 195 | // new float[] { 196 | // playlist.get_num_followers() }, 197 | // null, 1)); 198 | // break; 199 | // } 200 | // 201 | // case N_EDITS: { 202 | // this.dataParsed.playlistFeatures.get(featureName) 203 | // .addRow(i, new MLSparseVector(new int[] { 0 }, 204 | // new float[] { 205 | // playlist.get_num_edits() }, 206 | // null, 1)); 207 | // break; 208 | // } 209 | } 210 | } 211 | 212 | int curCount = count.incrementAndGet(); 213 | if (curCount % 100_000 == 0) { 214 | timer.tocLoop(curCount); 215 | } 216 | // }); 217 | } 218 | timer.tocLoop(count.get()); 219 | 220 | for (PlaylistFeature featureName : PlaylistFeature.values()) { 221 | // finalize feature, apply transforms but preserve original data 222 | this.dataParsed.playlistFeatures.get(featureName) 223 | .finalizeFeature(true); 224 | } 225 | } 226 | 227 | public void loadSongs() { 228 | MLTimer timer = new MLTimer("loadSongs"); 229 | timer.tic(); 230 
| int nSongs = this.dataLoaded.songs.length; 231 | 232 | // init song feature matrices 233 | this.dataParsed.songFeatures = new HashMap(); 234 | for (SongFeature featureName : SongFeature.values()) { 235 | MLFeatureTransform[] featTransforms = new MLFeatureTransform[] { 236 | new MLFeatureTransform.ColSelectorTransform(1_000) }; 237 | 238 | MLTextTransform[] textTransforms; 239 | switch (featureName) { 240 | case TRACK_NAME: { 241 | // tokenize song name 242 | textTransforms = new MLTextTransform[] { 243 | new MLTextTransform.LuceneAnalyzerTextTransform( 244 | new MLTextTransform.DefaultAnalyzer()) }; 245 | break; 246 | } 247 | default: { 248 | textTransforms = null; 249 | break; 250 | } 251 | } 252 | 253 | MLSparseFeature feature = new MLSparseFeature(nSongs, 254 | textTransforms, featTransforms, MLSparseMatrixAOO.class); 255 | this.dataParsed.songFeatures.put(featureName, feature); 256 | } 257 | 258 | AtomicInteger count = new AtomicInteger(0); 259 | this.dataParsed.songIds = new String[nSongs]; 260 | // IntStream.range(0, nSongs).parallel()(i -> { 261 | for (int i = 0; i < nSongs; i++) { 262 | Song song = this.dataLoaded.songs[i]; 263 | this.dataParsed.songIds[i] = song.get_track_uri(); 264 | 265 | // add song features 266 | for (SongFeature featureName : SongFeature.values()) { 267 | switch (featureName) { 268 | case ARTIST_ID: { 269 | this.dataParsed.songFeatures.get(featureName).addRow(i, 270 | new String[] { song.get_artist_uri() }); 271 | break; 272 | } 273 | 274 | case ALBUM_ID: { 275 | this.dataParsed.songFeatures.get(featureName).addRow(i, 276 | new String[] { song.get_album_uri() }); 277 | break; 278 | } 279 | 280 | case TRACK_NAME: { 281 | this.dataParsed.songFeatures.get(featureName).addRow(i, 282 | song.get_track_name()); 283 | break; 284 | } 285 | 286 | case DURATION: { 287 | this.dataParsed.songFeatures.get(featureName).addRow(i, 288 | new MLSparseVector(new int[] { 0 }, 289 | new float[] { TimeUnit.MILLISECONDS 290 | .toSeconds(song 291 | 
.get_duration_ms()) }, 292 | null, 1)); 293 | break; 294 | } 295 | } 296 | } 297 | 298 | int cur = count.incrementAndGet(); 299 | if (cur % 100_000 == 0) { 300 | timer.tocLoop(cur); 301 | } 302 | } 303 | // }); 304 | timer.tocLoop(count.get()); 305 | 306 | for (SongFeature featureName : SongFeature.values()) { 307 | // finalize feature, apply transforms but preserve original data 308 | this.dataParsed.songFeatures.get(featureName).finalizeFeature(true); 309 | } 310 | 311 | } 312 | 313 | public void loadSongExtraInfo(final String inFile) throws Exception { 314 | MLTimer timer = new MLTimer("loadSongExtraInfo"); 315 | timer.tic(); 316 | 317 | Map songToIndexMap = new HashMap(); 318 | for (int i = 0; i < this.dataParsed.songIds.length; i++) { 319 | songToIndexMap.put(this.dataParsed.songIds[i], i); 320 | } 321 | 322 | this.dataParsed.songExtraInfoFeatures = new HashMap(); 323 | for (SongExtraInfoFeature featureName : SongExtraInfoFeature.values()) { 324 | MLFeatureTransform[] featTransforms = new MLFeatureTransform[] { 325 | new MLFeatureTransform.ColSelectorTransform(1_000) }; 326 | 327 | MLSparseFeature feature = new MLSparseFeature( 328 | this.dataParsed.songIds.length, null, featTransforms, 329 | MLSparseMatrixAOO.class); 330 | this.dataParsed.songExtraInfoFeatures.put(featureName, feature); 331 | } 332 | 333 | JSONParser parser = new JSONParser(JSONParser.USE_INTEGER_STORAGE); 334 | AtomicInteger count = new AtomicInteger(0); 335 | try (BufferedReader reader = new BufferedReader( 336 | new FileReader(inFile))) { 337 | JSONArray parsed = (JSONArray) parser.parse(reader); 338 | for (Object element : parsed) { 339 | if (element == null 340 | || ((JSONObject) element).containsKey("uri") == false) { 341 | continue; 342 | } 343 | 344 | String songId = ((JSONObject) element).getAsString("uri"); 345 | int songIndex = songToIndexMap.get(songId); 346 | 347 | int cur = count.incrementAndGet(); 348 | if (cur % 100_000 == 0) { 349 | timer.tocLoop(cur); 350 | } 351 | 352 | for 
package common;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

/**
 * Train/validation splitter for collaborative-filtering interaction
 * matrices. Supports a forward-in-time split at a date cut-off
 * ({@link #splitByDate}) and a per-row fractional split
 * ({@link #splitFrac}). All interaction matrices in the input map are
 * split consistently, and a subset of target rows/columns is selected
 * for validation. All random selection is seeded, so splits are
 * repeatable.
 */
public class SplitterCF implements Serializable {

    private static final long serialVersionUID = -3182298371988867241L;

    // training interactions, keyed by interaction type
    private Map<String, MLSparseMatrix> Rstrain;
    // validation interactions, keyed by interaction type
    private Map<String, MLSparseMatrix> Rsvalid;
    // row/column indexes selected as validation targets
    private int[] validRowIndexes;
    private int[] validColIndexes;

    public SplitterCF() {
    }

    public Map<String, MLSparseMatrix> getRstrain() {
        return this.Rstrain;
    }

    public Map<String, MLSparseMatrix> getRsvalid() {
        return this.Rsvalid;
    }

    public int[] getValidColIndexes() {
        return this.validColIndexes;
    }

    public int[] getValidRowIndexes() {
        return this.validRowIndexes;
    }

    public void setRstrain(Map<String, MLSparseMatrix> rstrain) {
        this.Rstrain = rstrain;
    }

    public void setRsvalid(Map<String, MLSparseMatrix> rsvalid) {
        this.Rsvalid = rsvalid;
    }

    public void setValidColIndexes(int[] validColIndexes) {
        this.validColIndexes = validColIndexes;
    }

    public void setValidRowIndexes(int[] validRowIndexes) {
        this.validRowIndexes = validRowIndexes;
    }

    /**
     * Split every matrix in Rs forward in time: interactions dated at or
     * before dateCutOff go to train, later ones to validation. Rows whose
     * interactions are all after the cut-off are dropped entirely (they
     * would have no training data). Assumes every row has dates.
     */
    private void split(final Map<String, MLSparseMatrix> Rs,
            final long dateCutOff) {
        this.Rstrain = new HashMap<>();
        this.Rsvalid = new HashMap<>();

        for (Map.Entry<String, MLSparseMatrix> entry : Rs.entrySet()) {
            MLSparseMatrix R = entry.getValue();

            MLSparseVector[] trainRows = new MLSparseVector[R.getNRows()];
            MLSparseVector[] validRows = new MLSparseVector[R.getNRows()];

            AtomicInteger nnzTrain = new AtomicInteger(0);
            AtomicInteger nnzValid = new AtomicInteger(0);
            IntStream.range(0, R.getNRows()).parallel().forEach(rowIndex -> {
                MLSparseVector row = R.getRow(rowIndex);
                if (row == null) {
                    return;
                }
                long[] dates = row.getDates();

                // count validation-side interactions for this row
                int nGreater = 0;
                for (long date : dates) {
                    if (date > dateCutOff) {
                        nGreater++;
                    }
                }
                if (nGreater == dates.length) {
                    // no training data for this row, drop it
                    return;
                }

                // partition this row forward in time
                int jtrain = 0;
                int[] indexesTrain = new int[dates.length - nGreater];
                float[] valuesTrain = new float[dates.length - nGreater];
                long[] datesTrain = new long[dates.length - nGreater];

                int jvalid = 0;
                int[] indexesValid = new int[nGreater];
                float[] valuesValid = new float[nGreater];
                long[] datesValid = new long[nGreater];

                int[] indexes = row.getIndexes();
                float[] values = row.getValues();
                for (int j = 0; j < dates.length; j++) {
                    if (dates[j] > dateCutOff) {
                        // interaction after dateCutOff -> validation
                        indexesValid[jvalid] = indexes[j];
                        valuesValid[jvalid] = values[j];
                        datesValid[jvalid] = dates[j];
                        jvalid++;

                    } else {
                        // interaction at or before dateCutOff -> train
                        indexesTrain[jtrain] = indexes[j];
                        valuesTrain[jtrain] = values[j];
                        datesTrain[jtrain] = dates[j];
                        jtrain++;
                    }
                }

                trainRows[rowIndex] = new MLSparseVector(indexesTrain,
                        valuesTrain, datesTrain, R.getNCols());
                nnzTrain.addAndGet(indexesTrain.length);

                if (indexesValid.length > 0) {
                    // avoid storing empty validation rows
                    validRows[rowIndex] = new MLSparseVector(indexesValid,
                            valuesValid, datesValid, R.getNCols());
                    nnzValid.addAndGet(indexesValid.length);
                }
            });

            this.Rstrain.put(entry.getKey(),
                    new MLSparseMatrixAOO(trainRows, R.getNCols()));
            this.Rsvalid.put(entry.getKey(),
                    new MLSparseMatrixAOO(validRows, R.getNCols()));
            System.out.println("split() valid interaction " + entry.getKey()
                    + " nnz train:" + nnzTrain.get() + " nnz valid:"
                    + nnzValid.get());
        }
    }

    /**
     * Forward-in-time split using every row and every column as
     * validation targets.
     */
    public void splitByDate(final Map<String, MLSparseMatrix> Rs,
            final long dateCutOff) {

        // use all rows and all cols for validation
        int nRows = Rs.entrySet().iterator().next().getValue().getNRows();
        int nCols = Rs.entrySet().iterator().next().getValue().getNCols();

        splitByDate(Rs, dateCutOff, null, nRows, nCols, false);
    }

    /**
     * Forward-in-time split with a selected subset of validation rows and
     * columns.
     *
     * @param interToSkip interaction types to ignore when choosing
     *                    validation targets (may be null)
     * @param coldStart   if true, the selected validation rows are removed
     *                    from the training matrices to simulate cold start
     */
    public void splitByDate(final Map<String, MLSparseMatrix> Rs,
            final long dateCutOff, final Set<String> interToSkip,
            final int nValidRows, final int nValidCols,
            final boolean coldStart) {

        // generate forward-in-time split
        split(Rs, dateCutOff);

        // select target row and column indexes
        this.validRowIndexes = getRowIndexes(interToSkip, nValidRows,
                this.Rsvalid);
        this.validColIndexes = getColIndexes(interToSkip, nValidCols,
                this.validRowIndexes, this.Rsvalid);

        if (coldStart == true) {
            // remove selected training rows to simulate cold start
            for (Map.Entry<String, MLSparseMatrix> entry : this.Rstrain
                    .entrySet()) {
                MLSparseMatrix R = entry.getValue();
                for (int index : this.validRowIndexes) {
                    R.setRow(null, index);
                }
            }
        }
    }

    /**
     * Per-row fractional split: in each row with at least minToSplit
     * interactions, move ceil(frac * nnz) interactions to validation.
     * With useDate=true the most recent interactions are held out (rows
     * must have dates in that case); otherwise a seeded random subset is
     * held out. Rows below minToSplit go entirely to train.
     */
    public void splitFrac(final Map<String, MLSparseMatrix> Rs,
            final float frac, final int minToSplit,
            final Set<String> interToSkip, final boolean useDate,
            final int nValidRows, final int nValidCols) {
        this.Rstrain = new HashMap<>();
        this.Rsvalid = new HashMap<>();

        for (Map.Entry<String, MLSparseMatrix> entry : Rs.entrySet()) {
            MLSparseMatrix R = entry.getValue();

            MLSparseVector[] trainRows = new MLSparseVector[R.getNRows()];
            MLSparseVector[] validRows = new MLSparseVector[R.getNRows()];

            AtomicInteger nnzTrain = new AtomicInteger(0);
            AtomicInteger nnzValid = new AtomicInteger(0);
            IntStream.range(0, R.getNRows()).parallel().forEach(rowIndex -> {
                MLSparseVector row = R.getRow(rowIndex);
                if (row == null) {
                    return;
                }
                int[] indexes = row.getIndexes();
                float[] values = row.getValues();
                long[] dates = row.getDates();

                int nTotal = indexes.length;
                if (nTotal < minToSplit) {
                    // not enough interactions to split, keep all in train
                    trainRows[rowIndex] = row.deepCopy();
                    return;
                }

                int nInValid = (int) Math.ceil(frac * nTotal);
                Set<Integer> validIndexes = new HashSet<>();
                if (useDate == false) {
                    // seeded with rowIndex so the selection is repeatable
                    Random random = new Random(rowIndex);
                    while (validIndexes.size() < nInValid) {
                        // set semantics make duplicate draws harmless
                        validIndexes.add(indexes[random.nextInt(nTotal)]);
                    }
                } else {
                    // sort by date and hold out the *most recent* frac
                    // (this branch requires dates to be present)
                    MLMatrixElement[] elements =
                            new MLMatrixElement[indexes.length];
                    for (int i = 0; i < indexes.length; i++) {
                        elements[i] = new MLMatrixElement(rowIndex, indexes[i],
                                values[i], dates[i]);
                    }
                    Arrays.sort(elements,
                            new MLMatrixElement.DateComparator(true));
                    for (int i = 0; i < nInValid; i++) {
                        validIndexes.add(elements[i].getColIndex());
                    }
                }

                // split the row using validIndexes
                int jtrain = 0;
                int[] indexesTrain = new int[nTotal - nInValid];
                float[] valuesTrain = new float[nTotal - nInValid];
                long[] datesTrain = (dates == null) ? null
                        : new long[nTotal - nInValid];

                int jvalid = 0;
                int[] indexesValid = new int[nInValid];
                float[] valuesValid = new float[nInValid];
                long[] datesValid = (dates == null) ? null
                        : new long[nInValid];

                // BUG FIX: iterate over indexes.length (== nTotal), not
                // dates.length -- the original dereferenced dates here and
                // threw NPE whenever dates == null, even though null dates
                // are explicitly supported in this method
                for (int i = 0; i < nTotal; i++) {
                    if (validIndexes.contains(indexes[i]) == true) {
                        indexesValid[jvalid] = indexes[i];
                        valuesValid[jvalid] = values[i];
                        if (dates != null) {
                            datesValid[jvalid] = dates[i];
                        }
                        jvalid++;

                    } else {
                        indexesTrain[jtrain] = indexes[i];
                        valuesTrain[jtrain] = values[i];
                        if (dates != null) {
                            datesTrain[jtrain] = dates[i];
                        }
                        jtrain++;
                    }
                }

                trainRows[rowIndex] = new MLSparseVector(indexesTrain,
                        valuesTrain, datesTrain, R.getNCols());
                nnzTrain.addAndGet(indexesTrain.length);

                if (indexesValid.length > 0) {
                    // avoid storing empty validation rows
                    validRows[rowIndex] = new MLSparseVector(indexesValid,
                            valuesValid, datesValid, R.getNCols());
                    nnzValid.addAndGet(indexesValid.length);
                }
            });

            this.Rstrain.put(entry.getKey(),
                    new MLSparseMatrixAOO(trainRows, R.getNCols()));
            this.Rsvalid.put(entry.getKey(),
                    new MLSparseMatrixAOO(validRows, R.getNCols()));

            System.out.println("split() valid interaction " + entry.getKey()
                    + " nnz train:" + nnzTrain.get() + " nnz valid:"
                    + nnzValid.get());
        }

        // select target row and column indexes once Rsvalid is complete
        // (the original recomputed these inside the loop on partially
        // built maps and kept only the last result; computing them here
        // once yields the same final state without the redundant work)
        this.validRowIndexes = getRowIndexes(interToSkip, nValidRows,
                this.Rsvalid);
        this.validColIndexes = getColIndexes(interToSkip, nValidCols,
                this.validRowIndexes, this.Rsvalid);
    }

    /**
     * Select up to nValidCols validation target columns. Columns that
     * appear in the validation rows are preferred; any shortfall is
     * back-filled by seeded random sampling from the remaining columns.
     * Returns a sorted array of column indexes.
     */
    private static int[] getColIndexes(final Set<String> interToSkip,
            final int nValidCols, final int[] validRowIndexes,
            final Map<String, MLSparseMatrix> Rs) {

        int nCols = Rs.entrySet().iterator().next().getValue().getNCols();
        if (nValidCols > nCols) {
            throw new IllegalArgumentException(
                    "nValidCols=" + nValidCols + " nCols=" + nCols);
        }

        if (nValidCols == nCols) {
            // use all columns
            int[] validColIndexes = new int[nCols];
            for (int i = 0; i < nCols; i++) {
                validColIndexes[i] = i;
            }
            return validColIndexes;
        }

        // candidate columns: every column appearing in a validation row
        Set<Integer> validCols = null;
        for (Map.Entry<String, MLSparseMatrix> entry : Rs.entrySet()) {
            if (interToSkip != null
                    && interToSkip.contains(entry.getKey()) == true) {
                // skip these interaction types
                continue;
            }
            MLSparseMatrix R = entry.getValue();
            if (validCols == null) {
                validCols = new HashSet<>(R.getNCols());
            }

            for (int rowIndex : validRowIndexes) {
                MLSparseVector row = R.getRow(rowIndex);
                if (row == null) {
                    continue;
                }

                for (int colIndex : row.getIndexes()) {
                    validCols.add(colIndex);
                }
            }
        }

        if (validCols.size() > nValidCols) {
            // randomly select nValidCols of the candidates (seeded)
            List<Integer> validIndexesPerm = new ArrayList<>(validCols);
            Collections.shuffle(validIndexesPerm, new Random(1));

            validCols = new HashSet<>(
                    validIndexesPerm.subList(0, nValidCols));

        } else {
            // back-fill with random sampling from the unused columns
            int[] colIndexesRemain = new int[nCols - validCols.size()];
            int cur = 0;
            for (int i = 0; i < nCols; i++) {
                if (validCols.contains(i) == false) {
                    colIndexesRemain[cur] = i;
                    cur++;
                }
            }
            MLRandomUtils.shuffle(colIndexesRemain, new Random(1));
            // BUG FIX: snapshot the shortfall before the loop -- the
            // original re-evaluated validCols.size() in the loop condition
            // while adding to validCols, so it back-filled only about half
            // of the required columns and returned fewer than nValidCols
            final int nToAdd = nValidCols - validCols.size();
            for (int i = 0; i < nToAdd; i++) {
                validCols.add(colIndexesRemain[i]);
            }
        }

        int[] validColIndexes = new int[validCols.size()];
        int cur = 0;
        for (int index : validCols) {
            validColIndexes[cur] = index;
            cur++;
        }
        Arrays.sort(validColIndexes);
        return validColIndexes;
    }

    /**
     * Select up to nValidRows validation target rows from the non-empty
     * rows of the validation matrices, by seeded random sampling when
     * there are more candidates than needed. Returns a sorted array of
     * row indexes.
     */
    private static int[] getRowIndexes(final Set<String> interToSkip,
            final int nValidRows, final Map<String, MLSparseMatrix> Rs) {

        int nRows = Rs.entrySet().iterator().next().getValue().getNRows();
        if (nValidRows > nRows) {
            throw new IllegalArgumentException(
                    "nValidRows=" + nValidRows + " nRows=" + nRows);
        }

        if (nValidRows == nRows) {
            // use all rows
            int[] validRowIndexes = new int[nRows];
            for (int i = 0; i < nRows; i++) {
                validRowIndexes[i] = i;
            }
            return validRowIndexes;
        }

        // candidate rows: every non-empty validation row
        Set<Integer> validRows = null;
        for (Map.Entry<String, MLSparseMatrix> entry : Rs.entrySet()) {
            if (interToSkip != null
                    && interToSkip.contains(entry.getKey()) == true) {
                // skip these interaction types
                continue;
            }

            MLSparseMatrix R = entry.getValue();
            if (validRows == null) {
                validRows = new HashSet<>(R.getNRows());
            }

            for (int i = 0; i < R.getNRows(); i++) {
                if (R.getRow(i) != null) {
                    validRows.add(i);
                }
            }
        }

        // shuffle candidates (seeded) and keep nValidRows of them
        if (validRows.size() > nValidRows) {
            List<Integer> validIndexesPerm = new ArrayList<>(validRows);
            Collections.shuffle(validIndexesPerm, new Random(1));

            validRows = new HashSet<>(
                    validIndexesPerm.subList(0, nValidRows));
        }
        int[] validRowIndexes = new int[validRows.size()];
        int cur = 0;
        for (int index : validRows) {
            validRowIndexes[cur] = index;
            cur++;
        }
        Arrays.sort(validRowIndexes);
        return validRowIndexes;
    }
}
package common;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.StringJoiner;
import java.util.function.Function;
import java.util.stream.IntStream;

import com.google.common.util.concurrent.AtomicDoubleArray;

/**
 * Dense row-major matrix backed by an array of {@link MLDenseVector}
 * rows. Rows may be null; most aggregate operations skip null rows.
 * Many operations are parallelized over rows with parallel streams.
 */
public class MLDenseMatrix implements Serializable {

    private static final long serialVersionUID = -8815753536628968271L;
    private MLDenseVector[] rows;

    public MLDenseMatrix(final MLDenseVector[] rowsP) {
        this.rows = rowsP;
    }

    /** Deep copy of this matrix; null rows stay null. */
    public MLDenseMatrix deepCopy() {
        MLDenseVector[] rowsCopy = new MLDenseVector[this.getNRows()];
        IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
            MLDenseVector row = this.rows[rowIndex];
            if (row == null) {
                return;
            }
            rowsCopy[rowIndex] = row.deepCopy();
        });

        return new MLDenseMatrix(rowsCopy);
    }

    /**
     * Per-column mean. NOTE: the divisor is the total row count, so null
     * rows contribute zero to the sum but still count in the denominator
     * (this matches the original behavior).
     */
    public MLDenseVector getColMean() {

        AtomicDoubleArray results = new AtomicDoubleArray(this.getNCols());
        float[] colMean = new float[this.getNCols()];

        IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {

            MLDenseVector row = this.getRow(rowIndex);

            if (row == null) {
                return;
            }

            float[] values = row.getValues();

            for (int i = 0; i < this.getNCols(); i++) {
                results.addAndGet(i, (double) values[i]);
            }
        });

        IntStream.range(0, this.getNCols()).parallel().forEach(k -> {
            colMean[k] = (float) results.get(k) / this.getNRows();
        });

        return new MLDenseVector(colMean);
    }

    /**
     * Per-column population standard deviation around the supplied mean.
     * Null rows are skipped in the sum but counted in the denominator,
     * mirroring {@link #getColMean()}.
     */
    public MLDenseVector getColStd(final MLDenseVector mean) {
        float[] colStd = new float[this.getNCols()];
        float[] colMean = mean.getValues();
        AtomicDoubleArray results = new AtomicDoubleArray(this.getNCols());

        IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
            MLDenseVector row = this.getRow(rowIndex);

            if (row == null) {
                return;
            }

            float[] values = row.getValues();

            for (int i = 0; i < this.getNCols(); i++) {
                double diff = (double) values[i] - colMean[i];
                results.addAndGet(i, diff * diff);
            }
        });

        IntStream.range(0, this.getNCols()).parallel().forEach(k -> {
            colStd[k] = (float) Math.sqrt(results.get(k) / this.getNRows());
        });

        return new MLDenseVector(colStd);
    }

    /** Column count, taken from row 0 (assumes row 0 is non-null). */
    public int getNCols() {
        return this.rows[0].getLength();
    }

    public int getNRows() {
        return this.rows.length;
    }

    public MLDenseVector getRow(final int rowIndex) {
        return this.rows[rowIndex];
    }

    public MLDenseVector[] getRows() {
        return this.rows;
    }

    public float getValue(final int rowIndex, final int colIndex) {
        return this.rows[rowIndex].getValue(colIndex);
    }

    /**
     * Multiply this matrix with an nCols x 1 dense vector, returning an
     * nRows x 1 vector.
     */
    public MLDenseVector multRow(final MLDenseVector vector,
            final boolean parallel) {

        float[] result = new float[this.getNRows()];
        if (parallel) {
            IntStream.range(0, this.getNRows()).parallel()
                    .forEach(i -> result[i] = vector.mult(this.rows[i]));
        } else {
            for (int i = 0; i < this.getNRows(); i++) {
                result[i] = vector.mult(this.rows[i]);
            }
        }

        return new MLDenseVector(result);
    }

    public void setRow(final MLDenseVector row, final int rowIndex) {
        this.rows[rowIndex] = row;
    }

    /** Row slice [fromInclusive, toExclusive); rows are shared, not copied. */
    public MLDenseMatrix slice(int fromInclusive, int toExclusive) {
        return new MLDenseMatrix(
                Arrays.copyOfRange(this.rows, fromInclusive, toExclusive));
    }

    /** Row slice selected by index array; rows are shared, not copied. */
    public MLDenseMatrix slice(int[] inds) {
        return slice(inds, 0, inds.length);
    }

    public MLDenseMatrix slice(int[] inds, int from, int to) {
        MLDenseVector[] slice = new MLDenseVector[to - from];
        IntStream.range(from, to)
                .forEach(i -> slice[i - from] = this.rows[inds[i]]);
        return new MLDenseMatrix(slice);
    }

    /**
     * In-place column standardization (subtract mean, divide by std) with
     * clipping to +/- 5. Columns with zero std are only mean-centered.
     * Method name capitalization kept for API compatibility.
     */
    public void Standardize() {
        // clip standardized values to +/- cutOff
        int cutOff = 5;

        MLDenseVector mean = this.getColMean();
        MLDenseVector std = this.getColStd(mean);

        float[] colMean = mean.getValues();
        float[] colStd = std.getValues();

        IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
            MLDenseVector row = this.getRow(rowIndex);

            if (row == null) {
                return;
            }

            float[] values = row.getValues();

            for (int i = 0; i < this.getNCols(); i++) {
                if (colStd[i] == 0) {
                    // degenerate column: center only
                    values[i] = values[i] - colMean[i];
                    continue;
                }

                values[i] = (values[i] - colMean[i]) / colStd[i];

                if (values[i] > cutOff) {
                    values[i] = cutOff;
                } else if (values[i] < -cutOff) {
                    values[i] = -cutOff;
                }
            }

            this.setRow(new MLDenseVector(values), rowIndex);
        });
    }

    /**
     * Write the matrix as big-endian float32, row-major, e.g. for MATLAB:
     *   fid = fopen('U.bin');
     *   U = fread(fid, nRows * nCols, 'float32', 'ieee-be');
     *   fclose(fid);
     *   U = reshape(U, nCols, nRows)';
     */
    public void toFile(final String outFile) throws IOException {
        try (DataOutputStream writer = new DataOutputStream(
                new BufferedOutputStream(new FileOutputStream(outFile)))) {
            for (MLDenseVector row : this.rows) {
                float[] values = row.getValues();
                for (float value : values) {
                    writer.writeFloat(value);
                }
            }
        }
    }

    /**
     * Flatten to a single row-major float array.
     *
     * @throws IllegalArgumentException if nRows * nCols exceeds
     *                                  LowLevelRoutines.MAX_ARRAY_SIZE
     */
    public float[] toFlatArray() {
        int nCols = this.getNCols();
        int nRows = this.getNRows();
        // BUG FIX: do the size check in long arithmetic -- the original
        // computed nCols * nRows as int, which can overflow to a negative
        // value for large matrices and silently bypass this guard
        if ((long) nCols * nRows > LowLevelRoutines.MAX_ARRAY_SIZE) {
            throw new IllegalArgumentException("nCols*nRows > MAX_ARRAY_SIZE");
        }

        float[] flat = new float[nCols * nRows];
        IntStream.range(0, nRows).parallel().forEach(rowIndex -> {
            float[] values = this.rows[rowIndex].getValues();
            int offset = rowIndex * nCols;
            System.arraycopy(values, 0, flat, offset, values.length);
        });
        return flat;
    }

    /**
     * Flatten to one or more row-major arrays, splitting by columns so no
     * single array holds more than maxFlatCols columns.
     */
    public float[][] toFlatArrayColSlices(final int maxFlatCols) {
        final int nRows = this.getNRows();
        final int nCols = this.getNCols();
        if (nCols > maxFlatCols) {
            // matrix too wide for one flat array: slice by columns
            int numBig = (int) Math.ceil((float) nCols / maxFlatCols);
            float[][] big = new float[numBig][];
            for (int i = 0; i < numBig; i++) {
                int start = i * maxFlatCols;
                int end = Math.min(start + maxFlatCols, nCols);
                final int nFlatCols = end - start;
                float[] flat = new float[nFlatCols * nRows];
                IntStream.range(0, nRows).parallel()
                        .forEach(row -> System.arraycopy(
                                this.rows[row].getValues(), start, flat,
                                row * nFlatCols, nFlatCols));
                big[i] = flat;
            }
            return big;
        } else {
            return new float[][] { toFlatArray() };
        }
    }

    /** Convert to a sparse matrix, dropping all-zero rows. */
    public MLSparseMatrix toSparse() {
        MLSparseVector[] rowsSparse = new MLSparseVector[this.getNRows()];
        IntStream.range(0, this.getNRows()).parallel().forEach(rowIndex -> {
            MLSparseVector rowSparse = this.rows[rowIndex].toSparse();
            if (rowSparse.getIndexes() != null) {
                rowsSparse[rowIndex] = rowSparse;
            }
        });

        return new MLSparseMatrixAOO(rowsSparse, this.getNCols());
    }

    /** Abbreviated preview: corners of the matrix with 4-decimal values. */
    @Override
    public String toString() {
        final int nrows = this.getNRows();
        final int ncols = this.getNCols();
        StringBuilder sb = new StringBuilder();
        String fmt = String.format("%%.%df", 4);
        String fmtLong = String.format("\t%s\t%s\t...\t%s\t%s\n", fmt, fmt, fmt,
                fmt);
        Function<float[], String> fmtRow = row -> {
            final int n = row.length;
            if (n > 4)
                return String.format(fmtLong, row[0], row[1], row[n - 2],
                        row[n - 1]);
            else {
                StringJoiner ret = new StringJoiner("\t", "\t", "\n");
                for (float val : row)
                    ret.add(String.format(fmt, val));
                return ret.toString();
            }
        };
        sb.append(String.format("[%d x %d]\n", nrows, ncols));
        if (nrows > 3) {
            for (int i = 0; i < 2; i++)
                sb.append(fmtRow.apply(this.getRow(i).getValues()));
            sb.append("\t...\n\t...\n");
            for (int i = nrows - 2; i < nrows; i++)
                sb.append(fmtRow.apply(this.getRow(i).getValues()));
        } else {
            for (int i = 0; i < nrows; i++)
                sb.append(fmtRow.apply(this.getRow(i).getValues()));
        }
        return sb.toString();
    }

    /** Compute A^T * A (ncols x ncols) with a naive parallel triple loop. */
    public MLDenseMatrix transposeMult() {
        final int nrows = this.getNRows();
        final int ncols = this.getNCols();

        MLDenseVector[] result = new MLDenseVector[ncols];
        IntStream.range(0, ncols).parallel().forEach(i -> {
            float[] resultRow = new float[ncols];
            for (int j = 0; j < ncols; j++) {
                for (int k = 0; k < nrows; k++) {
                    resultRow[j] += this.rows[k].getValue(i)
                            * this.rows[k].getValue(j);
                }
            }
            result[i] = new MLDenseVector(resultRow);
        });

        return new MLDenseMatrix(result);
    }

    public MLDenseMatrix transposeMultNative() {
        return transposeMultNative(LowLevelRoutines.MAX_ARRAY_SIZE);
    }

    /**
     * Compute A^T * A via native sgemm, blocking by columns when the flat
     * buffers would exceed maxArraySize elements.
     */
    public MLDenseMatrix transposeMultNative(final int maxArraySize) {
        final int nrows = this.getNRows();
        final int ncols = this.getNCols();
        // largest column slice of A that fits in one flat array
        int maxFlatCols = Math.floorDiv(maxArraySize, this.getNRows());
        if (maxFlatCols > ncols) {
            maxFlatCols = ncols;
        }
        // largest column count of the result that fits in one flat array
        int maxResultCols = Math.floorDiv(maxArraySize, maxFlatCols);
        // if neither A nor the result blocks fit, blocking degenerates;
        // in practice nrows is large enough that this does not happen
        if (ncols > maxFlatCols && maxResultCols < maxFlatCols) {
            throw new UnsupportedOperationException(
                    "we cannot handle extremely flat and wide matrix");
        }
        MLDenseVector[] result = new MLDenseVector[ncols];
        if (ncols > maxFlatCols) {
            // blocked path: A split into column slices
            float[][] bigFlat = toFlatArrayColSlices(maxFlatCols);
            IntStream.range(0, ncols).parallel().forEach(
                    row -> result[row] = new MLDenseVector(new float[ncols]));
            // multiply block pairs; only the upper triangle is computed
            for (int f1 = 0; f1 < bigFlat.length; f1++) {
                final int colStartA = f1 * maxFlatCols;
                final int colEndA = Math.min(colStartA + maxFlatCols, ncols);
                final int ncolsA = colEndA - colStartA;

                float[] flatA = bigFlat[f1];
                for (int f2 = f1; f2 < bigFlat.length; f2++) {
                    final int colStartB = f2 * maxFlatCols;
                    final int colEndB = Math.min(colStartB + maxFlatCols,
                            ncols);
                    final int ncolsB = colEndB - colStartB;
                    float[] flatB = bigFlat[f2];

                    float[] raw = new float[ncolsA * ncolsB];
                    LowLevelRoutines.sgemm(flatA, flatB, raw, ncolsA, ncolsB,
                            nrows, false, true, 1, 0);

                    // copy block into result
                    IntStream.range(0, ncolsA).parallel().forEach(i -> {
                        final int offset = i * ncolsB;
                        System.arraycopy(raw, offset,
                                result[i + colStartA].getValues(), colStartB,
                                ncolsB);
                    });
                    // mirror into the symmetric block
                    IntStream.range(0, ncolsB).parallel().forEach(i -> {
                        final float[] mirror = result[i + colStartB]
                                .getValues();
                        for (int j = colStartA, k = i; j < colStartA
                                + ncolsA; j++, k += ncolsB) {
                            mirror[j] = raw[k];
                        }
                    });
                }
            }
        } else {
            // A fits in one flat array
            float[] flat = toFlatArray();
            maxResultCols = Math.floorDiv(maxArraySize, ncols);
            if (ncols > maxResultCols) {
                // A fits but the result must be computed in row blocks
                float[][] bigFlat = this.toFlatArrayColSlices(maxResultCols);
                float[] raw = new float[maxResultCols * ncols];
                for (int f1 = 0; f1 < bigFlat.length; f1++) {
                    final int colStartA = f1 * maxResultCols;
                    final int colEndA = Math.min(colStartA + maxResultCols,
                            ncols);
                    final int ncolsA = colEndA - colStartA;
                    LowLevelRoutines.sgemm(bigFlat[f1], flat, raw, ncolsA,
                            ncols, nrows, false, true, 1, 0);
                    IntStream.range(0, ncolsA).parallel().forEach(i -> {
                        final int offset = i * ncols;
                        float[] resultRow = Arrays.copyOfRange(raw, offset,
                                offset + ncols);
                        result[colStartA + i] = new MLDenseVector(resultRow);
                    });
                }
            } else {
                // both A and A^T*A fit in single flat arrays
                float[] raw = new float[ncols * ncols];
                LowLevelRoutines.sgemm(flat, flat, raw, ncols, ncols, nrows,
                        false, true, 1, 0);
                IntStream.range(0, ncols).parallel().forEach(i -> {
                    final int offset = i * ncols;
                    float[] resultRow = Arrays.copyOfRange(raw, offset,
                            offset + ncols);
                    result[i] = new MLDenseVector(resultRow);
                });
            }
        }

        return new MLDenseMatrix(result);
    }

    /**
     * Read a matrix written by {@link #toFile(String)}.
     *
     * @throws IllegalArgumentException if the file holds more than
     *                                  nRows * nCols floats
     */
    public static MLDenseMatrix fromFile(final String inFile, final int nRows,
            final int nCols) throws IOException {

        MLDenseVector[] rows = new MLDenseVector[nRows];
        try (DataInputStream reader = new DataInputStream(
                new BufferedInputStream(new FileInputStream(inFile)))) {

            for (int i = 0; i < nRows; i++) {
                float[] values = new float[nCols];
                for (int j = 0; j < nCols; j++) {
                    values[j] = reader.readFloat();
                }
                rows[i] = new MLDenseVector(values);
            }

            if (reader.available() != 0) {
                throw new IllegalArgumentException(
                        "data left after reading nRows x nCols elements");
            }
        }

        return new MLDenseMatrix(rows);
    }

    /**
     * Random Gaussian init with per-row seeding (seed + rowIndex) so the
     * result is repeatable regardless of parallel scheduling.
     */
    public static MLDenseMatrix initRandom(final int nRows, final int nCols,
            final float initStd, final long seed) {

        MLDenseVector[] rows = new MLDenseVector[nRows];
        IntStream.range(0, nRows).parallel().forEach(i -> {

            float[] values = new float[nCols];
            // per-row seed keeps the init deterministic under parallelism
            Random random = new Random(i + seed);
            for (int j = 0; j < nCols; j++) {
                values[j] = initStd * ((float) random.nextGaussian());
            }
            rows[i] = new MLDenseVector(values);
        });

        return new MLDenseMatrix(rows);
    }
}
6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.TreeMap; 9 | 10 | public class MLSparseVector implements Serializable { 11 | 12 | private static final long serialVersionUID = -8319046980055965552L; 13 | private int[] indexes; 14 | private float[] values; 15 | private long[] dates; 16 | private int length; 17 | 18 | public MLSparseVector(final int[] indexesP, final float[] valuesP, 19 | final long[] datesP, final int lengthP) { 20 | this.indexes = indexesP; 21 | this.values = valuesP; 22 | this.dates = datesP; 23 | this.length = lengthP; 24 | } 25 | 26 | public void applyDateThresh(final long dateThresh, final boolean greater) { 27 | // count how many dates are over the threshold 28 | int nPass = 0; 29 | for (int i = 0; i < this.dates.length; i++) { 30 | if ((greater == true && this.dates[i] > dateThresh) 31 | || (greater == false && this.dates[i] <= dateThresh)) { 32 | nPass++; 33 | } 34 | } 35 | 36 | if (nPass == 0) { 37 | this.indexes = null; 38 | this.values = null; 39 | this.dates = null; 40 | return; 41 | } 42 | 43 | // apply date threshold 44 | int[] indexesThresh = new int[nPass]; 45 | float[] valuesThresh = new float[nPass]; 46 | long[] datesThresh = new long[nPass]; 47 | 48 | int curIndex = 0; 49 | for (int j = 0; j < this.dates.length; j++) { 50 | if ((greater == true && this.dates[j] > dateThresh) 51 | || (greater == false && this.dates[j] <= dateThresh)) { 52 | indexesThresh[curIndex] = this.indexes[j]; 53 | valuesThresh[curIndex] = this.values[j]; 54 | datesThresh[curIndex] = this.dates[j]; 55 | curIndex++; 56 | } 57 | } 58 | 59 | this.indexes = indexesThresh; 60 | this.values = valuesThresh; 61 | this.dates = datesThresh; 62 | } 63 | 64 | public void applyIndexSelector(final Map selectedIndexMap, 65 | final int nColsSelected) { 66 | if (nColsSelected == 0 || this.indexes == null) { 67 | this.indexes = null; 68 | this.values = null; 69 | this.dates = null; 70 | this.length = nColsSelected; 71 | return; 72 | } 73 | 74 | // 
apply column selector in place to this vector 75 | List reindexElms = new ArrayList( 76 | this.indexes.length); 77 | for (int i = 0; i < this.indexes.length; i++) { 78 | Integer newIndex = selectedIndexMap.get(this.indexes[i]); 79 | if (newIndex != null) { 80 | if (this.dates != null) { 81 | reindexElms.add(new MLMatrixElement(-1, newIndex, 82 | this.values[i], this.dates[i])); 83 | } else { 84 | reindexElms.add(new MLMatrixElement(-1, newIndex, 85 | this.values[i], -1)); 86 | } 87 | } 88 | } 89 | 90 | if (reindexElms.size() == 0) { 91 | // nothing selected 92 | this.indexes = null; 93 | this.values = null; 94 | this.dates = null; 95 | this.length = nColsSelected; 96 | return; 97 | } 98 | 99 | Collections.sort(reindexElms, 100 | new MLMatrixElement.ColIndexComparator(false)); 101 | int[] prunedIndexes = new int[reindexElms.size()]; 102 | float[] prunedValues = new float[reindexElms.size()]; 103 | long[] prunedDates = null; 104 | if (this.dates != null) { 105 | prunedDates = new long[reindexElms.size()]; 106 | } 107 | 108 | int cur = 0; 109 | for (MLMatrixElement element : reindexElms) { 110 | 111 | prunedIndexes[cur] = element.getColIndex(); 112 | prunedValues[cur] = element.getValue(); 113 | if (this.dates != null) { 114 | prunedDates[cur] = element.getDate(); 115 | } 116 | cur++; 117 | } 118 | 119 | this.indexes = prunedIndexes; 120 | this.values = prunedValues; 121 | this.dates = prunedDates; 122 | this.length = nColsSelected; 123 | } 124 | 125 | public void applyNorm(final int p) { 126 | float rowNorm = this.getNorm(p); 127 | if (rowNorm < 1e-5f) { 128 | return; 129 | } 130 | this.divide(rowNorm); 131 | } 132 | 133 | public void applyNorm(final MLDenseVector norm) { 134 | if (this.length != norm.getLength()) { 135 | throw new IllegalArgumentException("length != length"); 136 | } 137 | 138 | float[] normValues = norm.getValues(); 139 | for (int i = 0; i < this.indexes.length; i++) { 140 | if (normValues[this.indexes[i]] > 1e-10f) { 141 | this.values[i] /= 
normValues[this.indexes[i]]; 142 | } 143 | } 144 | } 145 | 146 | public MLSparseVector deepCopy() { 147 | long[] datesClone = null; 148 | if (this.dates != null) { 149 | datesClone = this.dates.clone(); 150 | } 151 | 152 | return new MLSparseVector(this.indexes.clone(), this.values.clone(), 153 | datesClone, this.length); 154 | } 155 | 156 | public void divide(final float constant) { 157 | for (int i = 0; i < this.values.length; i++) { 158 | this.values[i] /= constant; 159 | } 160 | } 161 | 162 | public long[] getDates() { 163 | return this.dates; 164 | } 165 | 166 | public int[] getIndexes() { 167 | return this.indexes; 168 | } 169 | 170 | public int getLength() { 171 | return this.length; 172 | } 173 | 174 | public float getNorm(final int p) { 175 | float rowNorm = 0f; 176 | for (int i = 0; i < this.values.length; i++) { 177 | if (p == 1) { 178 | rowNorm += Math.abs(this.values[i]); 179 | } else { 180 | rowNorm += Math.pow(this.values[i], p); 181 | } 182 | } 183 | if (p != 1) { 184 | rowNorm = (float) Math.pow(rowNorm, 1.0 / p); 185 | } 186 | 187 | return rowNorm; 188 | } 189 | 190 | public float[] getValues() { 191 | return this.values; 192 | } 193 | 194 | public int intersect(final MLSparseVector other) { 195 | if (this.length != other.length) { 196 | throw new IllegalArgumentException("length != length"); 197 | } 198 | 199 | int maxIndex = this.indexes[this.indexes.length - 1]; 200 | if (other.getIndexes()[0] > maxIndex) { 201 | // no overlap in indexes 202 | return 0; 203 | } 204 | 205 | int intersect = 0; 206 | int[] otherIndexes = other.getIndexes(); 207 | 208 | int cur = 0; 209 | int curOther = 0; 210 | while (true) { 211 | if (cur >= this.length || curOther >= otherIndexes.length) { 212 | break; 213 | } 214 | 215 | if (otherIndexes[curOther] > maxIndex) { 216 | // indexes are sorted so can exit here 217 | break; 218 | } 219 | 220 | if (this.indexes[cur] == otherIndexes[curOther]) { 221 | intersect++; 222 | cur++; 223 | curOther++; 224 | 225 | } else if 
(this.indexes[cur] > otherIndexes[curOther]) { 226 | curOther++; 227 | 228 | } else { 229 | cur++; 230 | } 231 | } 232 | 233 | return intersect; 234 | } 235 | 236 | public float max() { 237 | float max = Float.NEGATIVE_INFINITY; 238 | for (int i = 0; i < this.values.length; i++) { 239 | if (max < this.values[i]) { 240 | max = this.values[i]; 241 | } 242 | } 243 | return max; 244 | } 245 | 246 | public void merge(final MLSparseVector vecToMerge) { 247 | if (this.getLength() != vecToMerge.getLength()) { 248 | throw new IllegalArgumentException( 249 | "vector lengths must be the same to merge"); 250 | } 251 | 252 | boolean hasDates = this.dates != null; 253 | 254 | Map rowMap = new TreeMap(); 255 | for (int i = 0; i < vecToMerge.getIndexes().length; i++) { 256 | if (hasDates == true) { 257 | rowMap.put(vecToMerge.getIndexes()[i], 258 | new MLMatrixElement(1, vecToMerge.getIndexes()[i], 259 | vecToMerge.getValues()[i], 260 | vecToMerge.getDates()[i])); 261 | } else { 262 | rowMap.put(vecToMerge.getIndexes()[i], 263 | new MLMatrixElement(1, vecToMerge.getIndexes()[i], 264 | vecToMerge.getValues()[i], 0L)); 265 | } 266 | } 267 | 268 | for (int i = 0; i < this.indexes.length; i++) { 269 | MLMatrixElement element = rowMap.get(this.indexes[i]); 270 | if (element == null) { 271 | if (hasDates == true) { 272 | rowMap.put(this.indexes[i], new MLMatrixElement(1, 273 | this.indexes[i], this.values[i], this.dates[i])); 274 | } else { 275 | rowMap.put(this.getIndexes()[i], new MLMatrixElement(1, 276 | this.indexes[i], this.values[i], 0L)); 277 | } 278 | } else { 279 | if (hasDates == true) { 280 | if (this.dates[i] > element.getDate()) { 281 | // store most recent date 282 | element.setDate(this.dates[i]); 283 | } 284 | } 285 | // sum up values 286 | element.setValue(element.getValue() + this.getValues()[i]); 287 | } 288 | } 289 | 290 | int[] indexesMerged = new int[rowMap.size()]; 291 | float[] valuesMerged = new float[rowMap.size()]; 292 | long[] datesMerged = null; 293 | if 
(hasDates == true) { 294 | datesMerged = new long[rowMap.size()]; 295 | } 296 | 297 | int index = 0; 298 | for (Map.Entry entry : rowMap.entrySet()) { 299 | MLMatrixElement element = entry.getValue(); 300 | indexesMerged[index] = element.getColIndex(); 301 | valuesMerged[index] = element.getValue(); 302 | if (hasDates == true) { 303 | datesMerged[index] = element.getDate(); 304 | } 305 | index++; 306 | } 307 | 308 | this.indexes = indexesMerged; 309 | this.values = valuesMerged; 310 | this.dates = datesMerged; 311 | } 312 | 313 | public float min() { 314 | float min = Float.POSITIVE_INFINITY; 315 | for (int i = 0; i < this.values.length; i++) { 316 | if (min > this.values[i]) { 317 | min = this.values[i]; 318 | } 319 | } 320 | return min; 321 | } 322 | 323 | public float multiply(final MLSparseVector other) { 324 | if (this.length != other.length) { 325 | throw new IllegalArgumentException("length != length"); 326 | } 327 | 328 | int maxIndex = this.indexes[this.indexes.length - 1]; 329 | if (other.getIndexes()[0] > maxIndex) { 330 | // no overlap in indexes 331 | return 0f; 332 | } 333 | 334 | float product = 0f; 335 | int[] otherIndexes = other.getIndexes(); 336 | float[] otherValues = other.getValues(); 337 | 338 | int cur = 0; 339 | int curOther = 0; 340 | while (true) { 341 | if (cur >= this.length || curOther >= otherIndexes.length) { 342 | break; 343 | } 344 | 345 | if (otherIndexes[curOther] > maxIndex) { 346 | // indexes are sorted so can exit here 347 | break; 348 | } 349 | 350 | if (this.indexes[cur] == otherIndexes[curOther]) { 351 | product += this.values[cur] * otherValues[curOther]; 352 | cur++; 353 | curOther++; 354 | 355 | } else if (this.indexes[cur] > otherIndexes[curOther]) { 356 | curOther++; 357 | 358 | } else { 359 | cur++; 360 | } 361 | } 362 | 363 | return product; 364 | } 365 | 366 | public void setDates(long[] dates) { 367 | this.dates = dates; 368 | } 369 | 370 | public void setIndexes(int[] indexes) { 371 | this.indexes = indexes; 372 | 
} 373 | 374 | public void setLength(int dim) { 375 | this.length = dim; 376 | } 377 | 378 | public void setValues(float[] values) { 379 | this.values = values; 380 | } 381 | 382 | public MLSparseVector subtract(final MLSparseVector other) { 383 | if (this.getLength() != other.getLength()) { 384 | throw new IllegalArgumentException( 385 | "vectors must have equall lengths"); 386 | } 387 | 388 | float[] result = new float[this.getLength()]; 389 | 390 | for (int i = 0; i < this.indexes.length; i++) { 391 | result[this.indexes[i]] += this.values[i]; 392 | } 393 | 394 | int[] otherIndexes = other.getIndexes(); 395 | float[] othervalues = other.getValues(); 396 | for (int i = 0; i < otherIndexes.length; i++) { 397 | result[otherIndexes[i]] -= othervalues[i]; 398 | } 399 | return new MLDenseVector(result).toSparse(); 400 | } 401 | 402 | public MLDenseVector toDense() { 403 | float[] dense = new float[this.length]; 404 | for (int i = 0; i < this.indexes.length; i++) { 405 | dense[this.indexes[i]] = this.values[i]; 406 | } 407 | return new MLDenseVector(dense); 408 | } 409 | 410 | public String toLIBSVMString(int offset) { 411 | 412 | StringBuilder builder = new StringBuilder(); 413 | for (int i = 0; i < this.indexes.length; i++) { 414 | float val = this.values[i]; 415 | if (val == Math.round(val)) { 416 | builder.append( 417 | " " + (offset + this.indexes[i]) + ":" + ((int) val)); 418 | } else { 419 | builder.append(" " + (offset + this.indexes[i]) + ":" 420 | + String.format("%.5f", val)); 421 | } 422 | } 423 | return builder.toString(); 424 | } 425 | 426 | public static MLSparseVector concat(final MLSparseVector... 
vectors) { 427 | int length = 0; 428 | int nnz = 0; 429 | boolean copyDates = true; 430 | for (int i = 0; i < vectors.length; i++) { 431 | MLSparseVector vector = vectors[i]; 432 | if (vector.getIndexes() != null) { 433 | nnz += vector.getIndexes().length; 434 | } 435 | length += vectors[i].getLength(); 436 | if (vector.getIndexes() != null && vector.getDates() == null) { 437 | // all vectors must have dates to concat 438 | copyDates = false; 439 | } 440 | } 441 | int[] indexes = new int[nnz]; 442 | float[] values = new float[nnz]; 443 | long[] dates = null; 444 | if (copyDates == true) { 445 | dates = new long[nnz]; 446 | } 447 | int cur = 0; 448 | int offset = 0; 449 | for (int i = 0; i < vectors.length; i++) { 450 | MLSparseVector vector = vectors[i]; 451 | int[] vecInds = vector.getIndexes(); 452 | if (vecInds != null) { 453 | float[] vecVals = vector.getValues(); 454 | long[] vecDates = vector.getDates(); 455 | for (int j = 0; j < vecInds.length; j++) { 456 | indexes[cur] = offset + vecInds[j]; 457 | values[cur] = vecVals[j]; 458 | if (copyDates == true) { 459 | dates[cur] = vecDates[j]; 460 | } 461 | cur++; 462 | } 463 | } 464 | offset += vector.getLength(); 465 | } 466 | return new MLSparseVector(indexes, values, dates, length); 467 | } 468 | 469 | public static MLSparseVector fromDense(final MLDenseVector dense) { 470 | float[] denseVals = dense.getValues(); 471 | 472 | int nnz = 0; 473 | for (int i = 0; i < denseVals.length; i++) { 474 | if (denseVals[i] != 0) { 475 | nnz++; 476 | } 477 | } 478 | if (nnz == 0) { 479 | return new MLSparseVector(null, null, null, denseVals.length); 480 | } 481 | 482 | int[] indexes = new int[nnz]; 483 | float[] values = new float[nnz]; 484 | int cur = 0; 485 | for (int i = 0; i < denseVals.length; i++) { 486 | if (denseVals[i] != 0) { 487 | indexes[cur] = i; 488 | values[cur] = denseVals[i]; 489 | cur++; 490 | } 491 | } 492 | return new MLSparseVector(indexes, values, null, denseVals.length); 493 | } 494 | 495 | public 
static MLSparseVector mean(MLSparseVector... input) { 496 | int n = input.length; 497 | 498 | if (n == 0) { 499 | throw new IllegalArgumentException("Can't average over no vectors."); 500 | } 501 | 502 | int d = input[0].length; 503 | for (int i = 0; i < n; i++) { 504 | if (input[i].length != d) { 505 | throw new IllegalArgumentException("Vector at position " + i + " has length " + input[i].length 506 | + " but first vector had length " + d + "."); 507 | } 508 | } 509 | 510 | float[] result = new float[d]; 511 | 512 | for (int i = 0; i < n; i++) { 513 | for (int j = 0; j < input[i].indexes.length; j++) { 514 | int index = input[i].indexes[j]; 515 | float value = input[i].values[j]; 516 | result[index] += value; 517 | } 518 | } 519 | 520 | for (int j = 0; j < d; j++) { 521 | result[j] /= n; 522 | } 523 | 524 | return MLSparseVector.fromDense(new MLDenseVector(result)); 525 | } 526 | 527 | } 528 | --------------------------------------------------------------------------------