├── src ├── rating │ ├── RatingReader.java │ └── RatingRecord.java ├── personalrank │ ├── Item.java │ ├── GraphThread.java │ ├── ListenReader.java │ ├── IDManager.java │ └── Graph.java └── svd │ ├── sgd │ ├── Model.java │ └── Trainer.java │ ├── SvdModel.java │ ├── MovieLensDataReader.java │ └── bgd │ ├── Trainer.java │ └── Model.java └── README.md /src/rating/RatingReader.java: -------------------------------------------------------------------------------- 1 | package rating; 2 | 3 | 4 | public abstract class RatingReader { 5 | public abstract int getUserNum(); 6 | public abstract int getItemNum(); 7 | public abstract double getMinPref(); 8 | public abstract double getMaxPref(); 9 | 10 | public abstract boolean hasNext(); 11 | public abstract RatingRecord getNext(); 12 | public abstract void reset(); 13 | } 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # svd for recommendation 2 | reference: 3 | 4 | 《A Guide to Singular Value Decomposition for Collaborative Filtering》 5 | 6 | 《Collaborative Filtering for Netflix》 7 | 8 | 9 | #personal rank 10 | reference: 11 | 12 | http://blog.csdn.net/harryhuang1990/article/details/10048383 13 | 用PersonalRank实现基于图的推荐算法 14 | 15 | https://github.com/wangrn/RecommendationSystemStudy/blob/master/PersonalRank.py 16 | -------------------------------------------------------------------------------- /src/rating/RatingRecord.java: -------------------------------------------------------------------------------- 1 | package rating; 2 | 3 | public class RatingRecord { 4 | 5 | int user; 6 | int item; 7 | double pref; 8 | 9 | 10 | public RatingRecord(int u,int i,double p) 11 | { 12 | user=u; 13 | item=i; 14 | pref=p; 15 | } 16 | 17 | public int getUser() 18 | { 19 | return user; 20 | } 21 | public int getItem() 22 | { 23 | return item; 24 | } 25 | public double getPref() 26 | { 27 | return pref; 28 | } 29 | 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/personalrank/Item.java: -------------------------------------------------------------------------------- 1 | package personalrank; 2 | 3 | 4 | 5 | class Item implements Comparable 6 | { 7 | int id; 8 | float weight; 9 | public Item(int id,float weight) 10 | { 11 | this.id=id; 12 | this.weight=weight; 13 | } 14 | 15 | 16 | 17 | @Override 18 | public int compareTo(Item o) { 19 | if(weighto.weight) 24 | { 25 | return -1; 26 | } 27 | else return 0; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/svd/sgd/Model.java: -------------------------------------------------------------------------------- 1 | package svd.sgd; 2 | 3 | import svd.SvdModel; 4 | 5 | public class Model extends SvdModel{ 6 | 7 | public Model(int un, int in, int r, double minPref, double maxPref) { 8 | super(un, in, r, minPref, maxPref); 9 | // TODO Auto-generated constructor stub 10 | } 11 | 12 | public double update(int user,int item,double pref) 13 | { 14 | double predicted = predict(user, item); 15 | double error= pref - predicted; 16 | 17 | double[] userVecotr = new double[rank]; 18 | for(int r=0;r(maxPref-minPref)) 63 | { 64 | score = maxPref; 65 | } 66 | else if(score<0) 67 | { 68 | score = minPref; 69 | } 70 | else 71 | { 72 | score +=minPref; 73 | } 74 | 75 | return score; 76 | } 77 | 78 | 79 | 80 | 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/personalrank/GraphThread.java: -------------------------------------------------------------------------------- 1 | package personalrank; 2 | 3 | import java.io.IOException; 4 | 5 | public class GraphThread extends Thread{ 6 | 7 | Graph graph; 8 | int start; 9 | int end; 10 | public GraphThread(Graph graph,int start,int end) 11 | { 12 | this.graph = graph; 13 | this.start =start; 14 | this.end = end; 15 | } 16 | 17 | public void run() 18 | { 19 | graph.recommend(start, end); 20 | } 21 | 22 | public static void main(String[] args) throws IOException, InterruptedException { 23 | // TODO Auto-generated method stub 24 | 25 | ListenReader reader = new ListenReader("dataset/user_music.dat"); 26 | Graph graph = new Graph(reader); 27 | 28 | int totalUser = graph.idManager.getUserMaxIternalID() +1; 29 | 30 | int threadNum = 10; 31 | GraphThread[] threads = new GraphThread[threadNum]; 32 | int start=-1; 33 | int end=-1; 34 | long start_time = System.currentTimeMillis(); 35 | for(int i=0;itotalUser) break; 38 | end=start+totalUser/threadNum; if(end>totalUser) end = totalUser; 39 | System.out.println("thread:"+i); 40 | System.out.println("start:"+start); 41 | System.out.println("end:"+end); 42 | 43 | threads[i] = new GraphThread(graph,start,end); 44 | threads[i].start(); 45 | 46 | } 47 | 48 | for(int i=0;i records; 17 | Iterator iter; 18 | BufferedReader br; 19 | public ListenReader(String filePath) throws IOException 20 | { 21 | records = new LinkedList(); 22 | br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)))); 23 | String line; 24 | while(true) 25 | { 26 | line=br.readLine(); 27 | if(line==null) break; 28 | String[] fields=line.split(","); 29 | int u= Integer.parseInt(fields[0]); 30 | int i= Integer.parseInt(fields[1]); 31 | records.push(new RatingRecord(u,i,1)); 32 | } 33 | br.close(); 34 | 35 | } 36 | 37 | 38 | 39 | 40 | @Override 41 | public double getMinPref() { 42 | // TODO Auto-generated method stub 43 | return 1; 44 | } 45 | 46 | @Override 47 | public double getMaxPref() { 48 | // TODO Auto-generated method stub 49 | return 1; 50 | } 51 | 52 | @Override 53 | public boolean hasNext() { 54 | // TODO Auto-generated method stub 55 | return iter.hasNext(); 56 | } 57 | 58 | @Override 59 | public RatingRecord getNext() { 60 | // TODO Auto-generated method stub 61 | return iter.next(); 62 | } 63 | 64 | @Override 65 | public void reset() { 66 | // TODO Auto-generated method stub 67 | iter = records.iterator(); 68 | } 69 | 70 | //以下两个方法是无效方法 71 | @Override 72 | public int getUserNum() { 73 | // TODO Auto-generated method stub 74 | return 0; 75 | } 76 | 77 | @Override 78 | public int getItemNum() { 79 | // TODO Auto-generated method stub 80 | return 0; 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/svd/MovieLensDataReader.java: -------------------------------------------------------------------------------- 1 | package svd; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileInputStream; 6 | import java.io.IOException; 7 | import java.io.InputStreamReader; 8 | import java.util.Iterator; 9 | import java.util.LinkedList; 10 | 11 | import rating.RatingReader; 12 | import rating.RatingRecord; 13 | 14 | public class MovieLensDataReader extends RatingReader{ 15 | 16 | 17 | LinkedList records; 18 | Iterator iter; 19 | BufferedReader br; 20 | int maxUser; 21 | int maxItem; 22 | double minPref; 23 | double maxPref; 24 | 25 | public MovieLensDataReader(String filePath) throws IOException 26 | { 27 | minPref = Double.MAX_VALUE; 28 | maxPref = Double.MIN_VALUE; 29 | 30 | maxUser=-1; 31 | maxItem=-1; 32 | records = new LinkedList(); 33 | 34 | br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)))); 35 | String line; 36 | while(true) 37 | { 38 | line=br.readLine(); 39 | if(line==null) break; 40 | String[] fields=line.split("\t"); 41 | int u= Integer.parseInt(fields[0]) -1;//MovieLens数据集的id从1开始 42 | int i= Integer.parseInt(fields[1]) -1; 43 | double p= Double.parseDouble(fields[2]); 44 | records.push(new RatingRecord(u,i,p)); 45 | 46 | maxUser = maxUserlastRmse) break; 59 | if(count>maxIter) break; 60 | 61 | lastRmse = rmse; 62 | trainData.reset(); 63 | } 64 | 65 | } 66 | 67 | public void test(String testFile) throws IOException 68 | { 69 | RatingReader testData = new MovieLensDataReader(testFile); 70 | double sum=0.0; 71 | int n=0; 72 | while(testData.hasNext()) 73 | { 74 | RatingRecord rating = testData.getNext(); 75 | int u=rating.getUser(); 76 | int i=rating.getItem(); 77 | double p=rating.getPref(); 78 | double predicted = model.predict(u, i); 79 | double error=(p-predicted)*(p-predicted); 80 | sum+=error; 81 | n++; 82 | } 83 | double rmse=Math.sqrt(sum/n); 84 | System.out.println("test rmse:"+rmse); 85 | 86 | } 87 | 88 | public static void main(String[] args) throws IOException { 89 | // TODO Auto-generated method stub 90 | 91 | Trainer trainer = new Trainer("dataset/ua.base"); 92 | trainer.train(1000); 93 | trainer.test("dataset/ua.test"); 94 | 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /src/svd/bgd/Trainer.java: -------------------------------------------------------------------------------- 1 | package svd.bgd; 2 | 3 | import java.io.IOException; 4 | 5 | import rating.RatingReader; 6 | import rating.RatingRecord; 7 | import svd.MovieLensDataReader; 8 | 9 | 10 | 11 | public class Trainer { 12 | 13 | Model model; 14 | RatingReader trainData; 15 | String trainFile; 16 | 17 | 18 | int userNum; 19 | int itemNum; 20 | 21 | public Trainer(String trainFilePath) throws IOException 22 | { 23 | trainFile = trainFilePath; 24 | trainData = new MovieLensDataReader(trainFile); 25 | userNum = trainData.getUserNum(); 26 | itemNum = trainData.getItemNum(); 27 | double minPref=trainData.getMinPref(); 28 | double maxPref=trainData.getMaxPref(); 29 | 30 | int rank =150; 31 | 32 | double[][] ratingMatrix = new double[userNum][itemNum]; 33 | 34 | for(int u=0;u user2Internal; 14 | private int userMaxIternalID; 15 | private HashMap item2Internal; 16 | private int itemMaxIternalID; 17 | private int[] internalMapping; 18 | 19 | public int[] getInternalMapping() 20 | { 21 | return internalMapping; 22 | } 23 | 24 | public HashMap getAllUser() 25 | { 26 | return user2Internal; 27 | } 28 | 29 | public int totalUserItem() 30 | { 31 | return itemMaxIternalID+1; 32 | } 33 | 34 | public int getUserMaxIternalID() 35 | { 36 | return userMaxIternalID; 37 | } 38 | 39 | public boolean isUser(int internalID) 40 | { 41 | return internalID<=userMaxIternalID; 42 | } 43 | 44 | public boolean isItem(int internalID) 45 | { 46 | return internalID>userMaxIternalID; 47 | } 48 | 49 | public int getOriginID(int internalID) 50 | { 51 | if(internalID (); 73 | item2Internal = new HashMap (); 74 | 75 | reader.reset(); 76 | int count=0; 77 | while(reader.hasNext()) 78 | { 79 | RatingRecord record = reader.getNext(); 80 | int u = record.getUser(); 81 | if(!user2Internal.containsKey(u)) 82 | { 83 | user2Internal.put(u, count++); 84 | } 85 | } 86 | userMaxIternalID=count-1; 87 | 88 | reader.reset(); 89 | while(reader.hasNext()) 90 | { 91 | RatingRecord record = reader.getNext(); 92 | int i = record.getItem(); 93 | if(!item2Internal.containsKey(i)) 94 | { 95 | item2Internal.put(i, count++); 96 | } 97 | } 98 | 99 | assert(user2Internal.size()+item2Internal.size()==count); 100 | itemMaxIternalID = count-1; 101 | 102 | internalMapping = new int[count]; 103 | 104 | Iterator> iter = user2Internal.entrySet().iterator(); 105 | Entry entry; 106 | while(iter.hasNext()) 107 | { 108 | entry = iter.next(); 109 | internalMapping[entry.getValue()]=entry.getKey(); 110 | } 111 | 112 | iter = item2Internal.entrySet().iterator(); 113 | while(iter.hasNext()) 114 | { 115 | entry = iter.next(); 116 | internalMapping[entry.getValue()]=entry.getKey(); 117 | } 118 | 119 | 120 | } 121 | 122 | 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/personalrank/Graph.java: -------------------------------------------------------------------------------- 1 | package personalrank; 2 | 3 | import java.io.IOException; 4 | import java.util.ArrayList; 5 | import java.util.Collections; 6 | import java.util.HashMap; 7 | import java.util.Iterator; 8 | import java.util.Map.Entry; 9 | 10 | import rating.RatingRecord; 11 | 12 | public class Graph { 13 | 14 | 15 | 16 | ArrayList > graph; 17 | 18 | IDManager idManager; 19 | public Graph(ListenReader reader) 20 | { 21 | idManager = new IDManager(reader); 22 | 23 | int userItemNum = idManager.totalUserItem(); 24 | 25 | graph = new ArrayList >(); 26 | for(int i=0;i());//申请空间 29 | } 30 | 31 | reader.reset(); 32 | RatingRecord record; 33 | int u; 34 | int i; 35 | float p; 36 | while(reader.hasNext()) 37 | { 38 | record = reader.getNext(); 39 | u = idManager.getUserInternalID(record.getUser()); 40 | i = idManager.getItemInternalID(record.getItem()); 41 | p = (float) record.getPref(); 42 | graph.get(u).put(i, p); 43 | graph.get(i).put(u, p); 44 | } 45 | 46 | } 47 | 48 | 49 | public void recommend(int start,int end)//为了多线程调用 50 | { 51 | float alpha = (float) 0.8; 52 | int maxIters = 20; 53 | int k = 100; 54 | 55 | int[] internalMapping = idManager.getInternalMapping(); 56 | int userMaxIternalID = idManager.getUserMaxIternalID(); 57 | 58 | if(start>userMaxIternalID || end>userMaxIternalID) return; 59 | int originID; 60 | int internalID; 61 | for(int i=start;i sorted = personalRank(internalID, alpha, maxIters); 66 | 67 | for(int j=0;j user2Internal = idManager.getAllUser(); 82 | 83 | System.out.println(user2Internal.size());//10449 84 | //已推荐10400位 85 | // 总耗时为:2544秒 86 | 87 | float alpha = (float) 0.8; 88 | int maxIters = 20; 89 | int k = 100; 90 | 91 | Iterator> iter = user2Internal.entrySet().iterator(); 92 | int originID; 93 | int internalID; 94 | Entry entry; 95 | long startMili=System.currentTimeMillis();// 当前时间对应的毫秒数 96 | int count=0; 97 | while(iter.hasNext()) 98 | { 99 | entry = iter.next(); 100 | originID = entry.getKey(); 101 | internalID = entry.getValue(); 102 | ArrayList sorted = personalRank(internalID, alpha, maxIters); 103 | 104 | for(int i=0;i sorted = personalRank(userInternalID, alpha, 20); 134 | 135 | int k = 100; 136 | for(int i=0;i personalRank(int userInternalID,float alpha,int maxIters) 147 | { 148 | int userItemNum = idManager.totalUserItem(); 149 | float[] rank = new float[userItemNum]; 150 | for(int i=0;i edges; 158 | Iterator> edgeIter; 159 | Entry edge; 160 | int j; 161 | for(int iter=0;iter candidates = new ArrayList(); 192 | //过滤+排序 193 | for(int i=idManager.getUserMaxIternalID()+1;i