├── .gitattributes
├── DBLPPreProcess
│   ├── .classpath
│   ├── .project
│   ├── .settings
│   │   └── org.eclipse.jdt.core.prefs
│   └── src
│       ├── PreProcessStep1.java
│       ├── PreProcessStep2.java
│       └── PreProcessStep3.java
├── RandomWalk2Vec
│   ├── compile.sh
│   ├── RandomWalk2VecRun.sh
│   └── randomWalk2vec.c
└── README.md

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/RandomWalk2Vec/compile.sh:
--------------------------------------------------------------------------------
#!/bin/bash
gcc -Wall -g -c "randomWalk2vec.c" -o randomWalk2vec.o
g++ -o randomWalk2vec randomWalk2vec.o

--------------------------------------------------------------------------------
/DBLPPreProcess/.classpath:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
	<classpathentry kind="src" path="src"/>
	<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8"/>
	<classpathentry kind="output" path="bin"/>
</classpath>

--------------------------------------------------------------------------------
/DBLPPreProcess/.project:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
	<name>DBLPPreProcess</name>
	<comment></comment>
	<projects>
	</projects>
	<buildSpec>
		<buildCommand>
			<name>org.eclipse.jdt.core.javabuilder</name>
			<arguments>
			</arguments>
		</buildCommand>
	</buildSpec>
	<natures>
		<nature>org.eclipse.jdt.core.javanature</nature>
	</natures>
</projectDescription>

--------------------------------------------------------------------------------
/DBLPPreProcess/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.source=1.8

--------------------------------------------------------------------------------
/RandomWalk2Vec/RandomWalk2VecRun.sh:
--------------------------------------------------------------------------------
#!/bin/bash
./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
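# Note: each walk file above is trained twice, once with -pp 1
# (Heterogeneous Skip-Gram) and once with -pp 0 (Homogeneous Skip-Gram).
# With -prefixes apv and -objtype 0, only author ('a'-prefixed) node
# embeddings are written to the -output file; see the README for details.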
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MetaGraph2Vec

Code for the PAKDD-2018 paper "MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding"

Authors: Daokun Zhang, Jie Yin, Xingquan Zhu and Chengqi Zhang

Contact: Daokun Zhang (daokunzhang2015@gmail.com)

This project contains two folders:

1) The folder "DBLPPreProcess" is the Java project for data preprocessing. It contains three files:

"PreProcessStep1", used to extract a subgraph from the "DBLP-Citation-Network-V3.txt" network;

"PreProcessStep2", used to obtain the class labels for authors;

"PreProcessStep3", used to generate random walks from the extracted subgraph, including the metagraph-guided random walks, the metapath-guided random walks, and the uniform random walks that ignore the heterogeneity of nodes.

To run this project, first download the "DBLP-Citation-Network-V3.txt" file from https://aminer.org/citation and move it to the "DBLPPreProcess/dataset" folder, then run "PreProcessStep1", "PreProcessStep2" and "PreProcessStep3" sequentially. The class label file will be output to the folder "DBLPPreProcess/group", and the random walk files will be output to the folder "DBLPPreProcess/randomwalk".

2) The folder "RandomWalk2Vec" contains the program for learning node embeddings from the generated random walk sequences. Please run "compile.sh" to compile the "randomWalk2vec.c" file into its corresponding executable. To run this program, move the generated random walk files from "DBLPPreProcess/randomwalk" to this folder and run the "RandomWalk2VecRun.sh" file.

In this folder, the "randomWalk2vec" program is used to learn node embeddings from random walk sequences with either Heterogeneous Skip-Gram or Homogeneous Skip-Gram; the sketch below shows how the two variants differ.
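The two variants differ only in where negative samples come from: Homogeneous Skip-Gram draws negatives from a single unigram table over all nodes, whereas Heterogeneous Skip-Gram restricts negatives to nodes of the same type as the current context node. Below is a minimal sketch of that branch, condensed from the negative-sampling loop in randomWalk2vec.c; it reuses the globals and helpers defined there, and draw_negative itself is an illustrative wrapper, not a function in the file.

```c
// Sketch: pick one negative sample for a (node, context) pair.
long long draw_negative(long long context_node, unsigned long long *next_random)
{
    long long target, type_Id;
    // Linear congruential update, as in word2vec.c.
    *next_random = *next_random * (unsigned long long)25214903917 + 11;
    if (pp == 1)
    {
        // Heterogeneous (++) version: sample from the table that holds only
        // nodes whose Id prefix (node type) matches the true context node.
        type_Id = GetTypeId(node_list[context_node].node_str[0]);
        target = type_tables[type_Id][(*next_random >> 16) % type_counts[type_Id]];
    }
    else
    {
        // Homogeneous version: sample from the global unigram table.
        target = table[(*next_random >> 16) % table_size];
    }
    return target;
}
```

Restricting negatives to the context node's type is what makes the ++ version's objective type-specific: authors compete only with authors, papers with papers, and venues with venues.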
The options of randomWalk2vec are as follows:

	-train <file>
		Use random walk sequences from <file> to train the model
	-output <file>
		Use <file> to save the learned node vector-format representations
	-size <int>
		Set the dimension of learned node representations; default is 128
	-window <int>
		Set the window size for collecting node context pairs; default is 5
	-pp <int>
		Use the Heterogeneous Skip-Gram model (the ++ version) or not; default is 1 (++ version); use 0 for Homogeneous Skip-Gram
	-prefixes <string>
		Prefixes of node Ids for specifying node types, e.g., ap with a for author and p for paper
	-objtype
		The index of the objective node type in the prefixes list, for which representations are to be learned
	-alpha <float>
		Set the starting learning rate; default is 0.025
	-negative <int>
		Number of negative examples; default is 5, common values are 3 - 10
	-samples <int>
		Set the number of stochastic gradient descent iterations, in millions; default is 100

If you find this project useful, please cite this paper:

@inproceedings{zhang2018metagraph2vec,
  title={MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding},
  author={Zhang, Daokun and Yin, Jie and Zhu, Xingquan and Zhang, Chengqi},
  booktitle={Pacific-Asia Conference on Knowledge Discovery and Data Mining},
  pages={196--208},
  year={2018},
  organization={Springer}
}

# Note
The randomWalk2vec.c program was compiled on a Linux system where RAND_MAX takes the value 2147483647. If you compile randomWalk2vec.c on your own system, please check the value of RAND_MAX carefully to make sure it is large enough for alias table sampling to be correct.

--------------------------------------------------------------------------------
/DBLPPreProcess/src/PreProcessStep1.java:
--------------------------------------------------------------------------------
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;

public class PreProcessStep1 { // extract a subgraph from the DBLP-Citation-Network-V3.txt network

	public static int nonAlphabet(char ch)
	{
		if(ch>='A'&&ch<='Z')
			return 0;
		if(ch>='a'&&ch<='z')
			return 0;
		return 1;
	}

	public static int check(String venue)
	{
		String[] selectedVenues = {"SIGMOD", "ICDE", "VLDB", "EDBT",
			"PODS", "ICDT", "DASFAA", "SSDBM", "CIKM", "KDD",
			"ICDM", "SDM", "PKDD", "PAKDD", "IJCAI", "AAAI",
			"NIPS", "ICML", "ECML", "ACML", "IJCNN", "UAI",
			"ECAI", "COLT", "ACL", "KR", "CVPR", "ICCV",
			"ECCV", "ACCV", "MM", "ICPR", "ICIP", "ICME"};
		int res = 0;
		for(int i=0;i<selectedVenues.length;i++)
		{
			int index = venue.indexOf(selectedVenues[i]);
			if(index!=-1)
			{
				if(index==0)
				{
					if(index+selectedVenues[i].length()>=venue.length())
						res = 1;
					else
					{
						char endChar = venue.charAt(index+selectedVenues[i].length());
						if(nonAlphabet(endChar)==1)
							res = 1;
					}
				}
				else
				{
					char startChar = venue.charAt(index-1);
					if(nonAlphabet(startChar)==1)
					{
						if(index+selectedVenues[i].length()>=venue.length())
							res = 1;
						else
						{
							char endChar = venue.charAt(index+selectedVenues[i].length());
							if(nonAlphabet(endChar)==1)
								res = 1;
						}
					}
				}
			}
			if(res==1)
				break;
		}
		if(res==1)
		{
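			// Filter out false matches whose venue string contains both "SIGMOD" and "Data Mining"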
if(venue.indexOf("SIGMOD")!=-1&&venue.indexOf("Data Mining")!=-1) 69 | res = 0; 70 | } 71 | return res; 72 | } 73 | 74 | public static void main(String[] args) throws IOException { 75 | // TODO Auto-generated method stub 76 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-Citation-Network-V3.txt"), "UTF-8")); 77 | BufferedWriter bwTitle = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.title.txt"), "UTF-8")); 78 | BufferedWriter bwAuthor = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 79 | BufferedWriter bwVenue = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 80 | BufferedWriter bwRef = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.ref.txt"), "UTF-8")); 81 | int numOfNode = 0; 82 | String line = null; 83 | br.readLine(); 84 | String curIndex = ""; 85 | String curTitle = ""; 86 | String curAuthors = ""; 87 | String[] curAuthorList = null; 88 | String curVenue = ""; 89 | ArrayList curRef = new ArrayList(); 90 | int writeStatus = 0; 91 | int maxRefSize = 0; 92 | String maxRefIndex = ""; 93 | while((line=br.readLine())!=null) 94 | { 95 | if(line.indexOf("#*")==0) 96 | curTitle = line.substring(2).trim(); 97 | if(line.indexOf("#c")==0) 98 | { 99 | curVenue = line.substring(2).trim(); 100 | curVenue = curVenue.replaceAll(" ", ""); 101 | if(check(curVenue)==1) 102 | writeStatus = 1; 103 | else 104 | writeStatus = 0; 105 | } 106 | if(line.indexOf("#%")==0) 107 | curRef.add(line.substring(2).trim()); 108 | if(line.indexOf("#@")==0) 109 | { 110 | curAuthors = line.substring(2).trim(); 111 | curAuthorList = curAuthors.split(","); 112 | } 113 | if(line.indexOf("#index")==0) 114 | curIndex = line.substring(6).trim(); 115 | if(line.trim().length()==0) 116 | { 117 | if((!curTitle.equals(""))&&(!curVenue.equals(""))&&(!curAuthors.equals(""))&&(!curIndex.equals(""))&&writeStatus==1) 118 | { 119 | curIndex = curIndex.replaceAll(" ", ""); 120 | bwTitle.write("p_"+curIndex+"\t"+curTitle+"\n"); 121 | for(int i=0;imaxRefSize) 137 | { 138 | maxRefSize = curRef.size(); 139 | maxRefIndex = curIndex; 140 | } 141 | numOfNode++; 142 | } 143 | curTitle = ""; 144 | curAuthors = ""; 145 | curVenue = ""; 146 | curRef.clear(); 147 | curIndex = ""; 148 | } 149 | } 150 | br.close(); 151 | bwTitle.close(); 152 | bwAuthor.close(); 153 | bwVenue.close(); 154 | bwRef.close(); 155 | System.out.println("numOfNode = "+numOfNode); 156 | System.out.println("maxRefSize = "+maxRefSize); 157 | System.out.println("maxRefIndex = "+maxRefIndex); 158 | } 159 | 160 | } 161 | -------------------------------------------------------------------------------- /DBLPPreProcess/src/PreProcessStep2.java: -------------------------------------------------------------------------------- 1 | import java.io.BufferedReader; 2 | import java.io.BufferedWriter; 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.io.InputStreamReader; 8 | import java.io.OutputStreamWriter; 9 | import java.util.ArrayList; 10 | import java.util.Collections; 11 | import java.util.Comparator; 12 | import java.util.HashMap; 13 | import java.util.Map; 14 | import java.util.Random; 15 | 16 | public class PreProcessStep2 { // obtain the class labels for 
authors; 17 | 18 | public static int getClass(String venue) 19 | { 20 | String[] venues1 = {"SIGMOD", "ICDE", "VLDB", "EDBT", 21 | "PODS", "ICDT", "DASFAA", "SSDBM", "CIKM"}; 22 | String[] venues2 = {"KDD", "ICDM", "SDM", "PKDD", "PAKDD"}; 23 | String[] venues3 = {"IJCAI", "AAAI", "NIPS", "ICML", "ECML", 24 | "ACML", "IJCNN", "UAI", "ECAI", "COLT", "ACL", "KR"}; 25 | String[] venues4 = {"CVPR", "ICCV", "ECCV", "ACCV", "MM", 26 | "ICPR", "ICIP", "ICME"}; 27 | ArrayList> venueLists = new ArrayList>(); 28 | ArrayList venueList1 = new ArrayList(); 29 | for(int i=0;i venueList2 = new ArrayList(); 33 | for(int i=0;i venueList3 = new ArrayList(); 37 | for(int i=0;i venueList4 = new ArrayList(); 41 | for(int i=0;i0;i--) 50 | { 51 | int randNumClass = randomClass.nextInt(i); 52 | int temp = arrClass[i]; 53 | arrClass[i] = arrClass[randNumClass]; 54 | arrClass[randNumClass] = temp; 55 | } 56 | for(int i=0;i<4;i++) 57 | { 58 | for(int j=0;j author2Id = new HashMap(); 74 | Map venue2Id = new HashMap(); 75 | Map paper2Id = new HashMap(); 76 | Map Id2Author = new HashMap(); 77 | Map Id2Venue = new HashMap(); 78 | Map Id2Paper = new HashMap(); 79 | 80 | String line = null; 81 | int paperIndex = 0, authorIndex = 0, venueIndex = 0; 82 | while((line=brAuthor.readLine())!=null) 83 | { 84 | int splitPos = line.indexOf("\t"); 85 | String paper = line.substring(0, splitPos); 86 | String author = line.substring(splitPos+1); 87 | if(!paper2Id.containsKey(paper)) 88 | { 89 | paper2Id.put(paper, paperIndex); 90 | Id2Paper.put(paperIndex, paper); 91 | paperIndex++; 92 | } 93 | if(!author2Id.containsKey(author)) 94 | { 95 | author2Id.put(author, authorIndex); 96 | Id2Author.put(authorIndex, author); 97 | authorIndex++; 98 | } 99 | } 100 | brAuthor.close(); 101 | while((line=brVenue.readLine())!=null) 102 | { 103 | int splitPos = line.indexOf("\t"); 104 | String venue = line.substring(splitPos+1); 105 | if(!venue2Id.containsKey(venue)) 106 | { 107 | venue2Id.put(venue, venueIndex); 108 | Id2Venue.put(venueIndex, venue); 109 | venueIndex++; 110 | } 111 | } 112 | brVenue.close(); 113 | int numOfPaper = paperIndex; 114 | int numOfAuthor = authorIndex; 115 | int numOfVenue = venueIndex; 116 | 117 | BufferedWriter bwPaperPaper = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8")); 118 | while((line=brRef.readLine())!=null) 119 | { 120 | int splitPos = line.indexOf("\t"); 121 | String paper1 = line.substring(0, splitPos); 122 | String paper2 = line.substring(splitPos+1); 123 | if(paper2Id.containsKey(paper2)) 124 | bwPaperPaper.write(paper1+"\t"+paper2+"\n"); 125 | } 126 | brRef.close(); 127 | bwPaperPaper.close(); 128 | 129 | ArrayList> authorNeighbors = new ArrayList>(); 130 | ArrayList> paperNeighborsVenue = new ArrayList>(); 131 | for(int i=0;i neighbors = new ArrayList(); 134 | authorNeighbors.add(neighbors); 135 | } 136 | for(int i=0;i neighborsVenue = new ArrayList(); 139 | paperNeighborsVenue.add(neighborsVenue); 140 | } 141 | brAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 142 | brVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 143 | while((line=brAuthor.readLine())!=null) 144 | { 145 | int splitPos = line.indexOf("\t"); 146 | String paper = line.substring(0, splitPos); 147 | String author = line.substring(splitPos+1); 148 | int paperId = paper2Id.get(paper); 149 | int 
authorId = author2Id.get(author); 150 | if(!authorNeighbors.get(authorId).contains(paperId)) 151 | authorNeighbors.get(authorId).add(paperId); 152 | } 153 | brAuthor.close(); 154 | while((line=brVenue.readLine())!=null) 155 | { 156 | int splitPos = line.indexOf("\t"); 157 | String paper = line.substring(0, splitPos); 158 | String venue = line.substring(splitPos+1); 159 | int paperId = paper2Id.get(paper); 160 | int venueId = venue2Id.get(venue); 161 | if(!paperNeighborsVenue.get(paperId).contains(venueId)) 162 | paperNeighborsVenue.get(paperId).add(venueId); 163 | } 164 | brVenue.close(); 165 | for(int i=0;i() { 168 | public int compare(Integer a, Integer b) { 169 | return a.compareTo(b); 170 | } 171 | }); 172 | } 173 | for(int i=0;i() { 176 | public int compare(Integer a, Integer b) { 177 | return a.compareTo(b); 178 | } 179 | }); 180 | } 181 | 182 | int[][] venueClass = new int[numOfVenue][4]; 183 | int[][] authorClass = new int[numOfAuthor][4]; 184 | for(int i=0;i0;j--) 215 | { 216 | int randNumClass = randomClass.nextInt(j); 217 | int temp = arrClass[j]; 218 | arrClass[j] = arrClass[randNumClass]; 219 | arrClass[randNumClass] = temp; 220 | } 221 | int label = -1; 222 | int maxNum = -1; 223 | for(int j=0;j<4;j++) 224 | { 225 | if(authorClass[i][arrClass[j]]>maxNum) 226 | { 227 | maxNum = authorClass[i][arrClass[j]]; 228 | label = arrClass[j]; 229 | } 230 | } 231 | if(label==-1) 232 | System.out.println("no class label for author "+i); 233 | else 234 | { 235 | for(int j=0;j<4;j++) 236 | authorClass[i][j] = 0; 237 | authorClass[i][label] = 1; 238 | } 239 | } 240 | 241 | BufferedWriter bwAuthorClass = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("group"+File.separator+"DBLP-V3.author.class.txt"), "UTF-8")); 242 | for(int i=0;i author2Id; 19 | public static Map venue2Id; 20 | public static Map paper2Id; 21 | public static Map Id2Author; 22 | public static Map Id2Venue; 23 | public static Map Id2Paper; 24 | 25 | public static ArrayList> authorNeighbors; 26 | public static ArrayList> venueNeighbors; 27 | public static ArrayList> paperNeighborsAuthor; 28 | public static ArrayList> paperNeighborsPaper; 29 | public static ArrayList> paperNeighborsVenue; 30 | 31 | public static String metaGraphRandomWalk(int startNodeId, String type, int pathLen) 32 | { 33 | int curNodeId = startNodeId; 34 | String curType = type; 35 | int curPos = 1; 36 | Random random = new Random(); 37 | StringBuffer path = new StringBuffer(""); 38 | for(int i=0;i0) 44 | path.append(" "); 45 | path.append(paper); 46 | if(curPos==2) 47 | { 48 | if(paperNeighborsAuthor.get(curNodeId).size()>1&&paperNeighborsVenue.get(curNodeId).size()>0) 49 | { 50 | double option = Math.random(); 51 | if(option<1.0/2.0) 52 | { 53 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 54 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 55 | curPos++; 56 | curType = "A"; 57 | } 58 | else 59 | { 60 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 61 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 62 | curPos++; 63 | curType = "V"; 64 | } 65 | } 66 | else if(paperNeighborsAuthor.get(curNodeId).size()>1) 67 | { 68 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 69 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 70 | curPos++; 71 | curType = "A"; 72 | } 73 | else if(paperNeighborsVenue.get(curNodeId).size()>0) 74 | { 75 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 76 | 
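// Fewer than two author neighbours here, so the walk can only continue paper -> venue.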
curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 77 | curPos++; 78 | curType = "V"; 79 | } 80 | else 81 | return path.toString(); 82 | } 83 | else 84 | { 85 | if(paperNeighborsAuthor.get(curNodeId).size()==0) 86 | return path.toString(); 87 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 88 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 89 | curPos++; 90 | if(curPos>=5) 91 | curPos = 1; 92 | curType = "A"; 93 | } 94 | } 95 | else 96 | { 97 | if(curType.equals("A")) 98 | { 99 | String author = Id2Author.get(curNodeId); 100 | if(i>0) 101 | path.append(" "); 102 | path.append(author); 103 | if(authorNeighbors.get(curNodeId).size()==0) 104 | return path.toString(); 105 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 106 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 107 | curPos++; 108 | curType = "P"; 109 | } 110 | else if(curType.equals("V")) 111 | { 112 | String venue = Id2Venue.get(curNodeId); 113 | if(i>0) 114 | path.append(" "); 115 | path.append(venue); 116 | if(venueNeighbors.get(curNodeId).size()==0) 117 | return path.toString(); 118 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 119 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 120 | curPos++; 121 | curType = "P"; 122 | } 123 | } 124 | } 125 | return path.toString(); 126 | } 127 | 128 | public static String metaPathAPVPARandomWalk(int startNodeId, String type, int pathLen) 129 | { 130 | int curNodeId = startNodeId; 131 | String curType = type; 132 | String lastType = ""; 133 | Random random = new Random(); 134 | StringBuffer path = new StringBuffer(""); 135 | for(int i=0;i0) 141 | path.append(" "); 142 | path.append(paper); 143 | if(lastType.equals("A")) 144 | { 145 | if(paperNeighborsVenue.get(curNodeId).size()==0) 146 | return path.toString(); 147 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 148 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 149 | lastType = curType; 150 | curType = "V"; 151 | } 152 | else 153 | { 154 | if(paperNeighborsAuthor.get(curNodeId).size()==0) 155 | return path.toString(); 156 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 157 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 158 | lastType = curType; 159 | curType = "A"; 160 | } 161 | } 162 | else 163 | { 164 | if(curType.equals("A")) 165 | { 166 | String author = Id2Author.get(curNodeId); 167 | if(i>0) 168 | path.append(" "); 169 | path.append(author); 170 | if(authorNeighbors.get(curNodeId).size()==0) 171 | return path.toString(); 172 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 173 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 174 | lastType = curType; 175 | curType = "P"; 176 | } 177 | else 178 | { 179 | String venue = Id2Venue.get(curNodeId); 180 | if(i>0) 181 | path.append(" "); 182 | path.append(venue); 183 | if(venueNeighbors.get(curNodeId).size()==0) 184 | return path.toString(); 185 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 186 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 187 | lastType = curType; 188 | curType = "P"; 189 | } 190 | } 191 | } 192 | return path.toString(); 193 | } 194 | 195 | public static String metaPathAPAPARandomWalk(int startNodeId, String type, int pathLen) 196 | { 197 | int curNodeId = startNodeId; 198 | String curType = type; 199 | Random random = new Random(); 200 | StringBuffer path = new 
StringBuffer(""); 201 | for(int i=0;i0) 207 | path.append(" "); 208 | path.append(paper); 209 | if(paperNeighborsAuthor.get(curNodeId).size()<=1) 210 | return path.toString(); 211 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 212 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 213 | curType = "A"; 214 | } 215 | else 216 | { 217 | String author = Id2Author.get(curNodeId); 218 | if(i>0) 219 | path.append(" "); 220 | path.append(author); 221 | if(authorNeighbors.get(curNodeId).size()==0) 222 | return path.toString(); 223 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 224 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 225 | curType = "P"; 226 | } 227 | } 228 | return path.toString(); 229 | } 230 | 231 | public static String uniformRandomWalk(int startNodeId, String type, int pathLen) 232 | { 233 | int curNodeId = startNodeId; 234 | String curType = type; 235 | Random random = new Random(); 236 | StringBuffer path = new StringBuffer(""); 237 | for(int i=0;i0) 243 | path.append(" "); 244 | path.append(paper); 245 | int authorSize = paperNeighborsAuthor.get(curNodeId).size(); 246 | int paperSize = paperNeighborsPaper.get(curNodeId).size(); 247 | int venueSize = paperNeighborsVenue.get(curNodeId).size(); 248 | if(authorSize+paperSize+venueSize==0) 249 | return path.toString(); 250 | int nextIndex = random.nextInt(authorSize+paperSize+venueSize); 251 | if(nextIndex0) 273 | path.append(" "); 274 | path.append(author); 275 | if(authorNeighbors.get(curNodeId).size()==0) 276 | return path.toString(); 277 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 278 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 279 | curType = "P"; 280 | } 281 | else if(curType.equals("V")) 282 | { 283 | String venue = Id2Venue.get(curNodeId); 284 | if(i>0) 285 | path.append(" "); 286 | path.append(venue); 287 | if(venueNeighbors.get(curNodeId).size()==0) 288 | return path.toString(); 289 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 290 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 291 | curType = "P"; 292 | } 293 | } 294 | } 295 | return path.toString(); 296 | } 297 | 298 | public static void main(String[] args) throws IOException { 299 | // TODO Auto-generated method stub 300 | BufferedReader brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 301 | BufferedReader brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 302 | BufferedReader brPaperPaper = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8")); 303 | 304 | author2Id = new HashMap(); 305 | venue2Id = new HashMap(); 306 | paper2Id = new HashMap(); 307 | Id2Author = new HashMap(); 308 | Id2Venue = new HashMap(); 309 | Id2Paper = new HashMap(); 310 | 311 | String line = null; 312 | int paperIndex = 0, authorIndex = 0, venueIndex = 0; 313 | while((line=brPaperAuthor.readLine())!=null) 314 | { 315 | int splitPos = line.indexOf("\t"); 316 | String paper = line.substring(0, splitPos); 317 | String author = line.substring(splitPos+1); 318 | if(!paper2Id.containsKey(paper)) 319 | { 320 | paper2Id.put(paper, paperIndex); 321 | Id2Paper.put(paperIndex, paper); 322 | paperIndex++; 323 | } 324 | if(!author2Id.containsKey(author)) 325 | { 326 | author2Id.put(author, 
authorIndex); 327 | Id2Author.put(authorIndex, author); 328 | authorIndex++; 329 | } 330 | } 331 | brPaperAuthor.close(); 332 | while((line=brPaperVenue.readLine())!=null) 333 | { 334 | int splitPos = line.indexOf("\t"); 335 | String venue = line.substring(splitPos+1); 336 | if(!venue2Id.containsKey(venue)) 337 | { 338 | venue2Id.put(venue, venueIndex); 339 | Id2Venue.put(venueIndex, venue); 340 | venueIndex++; 341 | } 342 | } 343 | brPaperVenue.close(); 344 | int numOfPaper = paperIndex; 345 | int numOfAuthor = authorIndex; 346 | int numOfVenue = venueIndex; 347 | 348 | authorNeighbors = new ArrayList>(); 349 | venueNeighbors = new ArrayList>(); 350 | paperNeighborsAuthor = new ArrayList>(); 351 | paperNeighborsPaper = new ArrayList>(); 352 | paperNeighborsVenue = new ArrayList>(); 353 | 354 | for(int i=0;i neighbors = new ArrayList(); 357 | authorNeighbors.add(neighbors); 358 | } 359 | for(int i=0;i neighbors = new ArrayList(); 362 | venueNeighbors.add(neighbors); 363 | } 364 | for(int i=0;i neighborsAuthor = new ArrayList(); 367 | paperNeighborsAuthor.add(neighborsAuthor); 368 | ArrayList neighborsPaper = new ArrayList(); 369 | paperNeighborsPaper.add(neighborsPaper); 370 | ArrayList neighborsVenue = new ArrayList(); 371 | paperNeighborsVenue.add(neighborsVenue); 372 | } 373 | 374 | brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 375 | brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 376 | while((line=brPaperAuthor.readLine())!=null) 377 | { 378 | int splitPos = line.indexOf("\t"); 379 | String paper = line.substring(0, splitPos); 380 | String author = line.substring(splitPos+1); 381 | int paperId = paper2Id.get(paper); 382 | int authorId = author2Id.get(author); 383 | if(!authorNeighbors.get(authorId).contains(paperId)) 384 | authorNeighbors.get(authorId).add(paperId); 385 | if(!paperNeighborsAuthor.get(paperId).contains(authorId)) 386 | paperNeighborsAuthor.get(paperId).add(authorId); 387 | } 388 | brPaperAuthor.close(); 389 | while((line=brPaperPaper.readLine())!=null) 390 | { 391 | int splitPos = line.indexOf("\t"); 392 | String paper1 = line.substring(0, splitPos); 393 | String paper2 = line.substring(splitPos+1); 394 | int paperId1 = paper2Id.get(paper1); 395 | int paperId2 = paper2Id.get(paper2); 396 | if(!paperNeighborsPaper.get(paperId1).contains(paperId2)) 397 | paperNeighborsPaper.get(paperId1).add(paperId2); 398 | if(!paperNeighborsPaper.get(paperId2).contains(paperId1)) 399 | paperNeighborsPaper.get(paperId2).add(paperId1); 400 | } 401 | brPaperPaper.close(); 402 | while((line=brPaperVenue.readLine())!=null) 403 | { 404 | int splitPos = line.indexOf("\t"); 405 | String paper = line.substring(0, splitPos); 406 | String venue = line.substring(splitPos+1); 407 | int paperId = paper2Id.get(paper); 408 | int venueId = venue2Id.get(venue); 409 | if(!venueNeighbors.get(venueId).contains(paperId)) 410 | venueNeighbors.get(venueId).add(paperId); 411 | if(!paperNeighborsVenue.get(paperId).contains(venueId)) 412 | paperNeighborsVenue.get(paperId).add(venueId); 413 | } 414 | brPaperVenue.close(); 415 | 416 | int[] arrPaper1 = new int[numOfPaper]; 417 | for(int i=0;i0;i--) { 421 | int randNumPaper1 = randomPaper1.nextInt(i); 422 | int tempPaper1 = arrPaper1[i]; 423 | arrPaper1[i] = arrPaper1[randNumPaper1]; 424 | arrPaper1[randNumPaper1] = tempPaper1; 425 | } 426 | 427 | for(int i=0;i() { 449 | public int 
compare(Integer a, Integer b) { 450 | return a.compareTo(b); 451 | } 452 | }); 453 | } 454 | for(int i=0;i() { 457 | public int compare(Integer a, Integer b) { 458 | return a.compareTo(b); 459 | } 460 | }); 461 | } 462 | 463 | for(int i=0;i() { 466 | public int compare(Integer a, Integer b) { 467 | return a.compareTo(b); 468 | } 469 | }); 470 | Collections.sort(paperNeighborsPaper.get(i), new Comparator() { 471 | public int compare(Integer a, Integer b) { 472 | return a.compareTo(b); 473 | } 474 | }); 475 | Collections.sort(paperNeighborsVenue.get(i), new Comparator() { 476 | public int compare(Integer a, Integer b) { 477 | return a.compareTo(b); 478 | } 479 | }); 480 | } 481 | 482 | BufferedWriter bwMetaGraphRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.graph.random.walk.txt"), "UTF-8")); 483 | BufferedWriter bwMetaPathRandWalkAPVPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apvpa.txt"), "UTF-8")); 484 | BufferedWriter bwMetaPathRandWalkAPAPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apapa.txt"), "UTF-8")); 485 | BufferedWriter bwMetaPathRandWalkMix = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.mix.txt"), "UTF-8")); 486 | BufferedWriter bwUniformRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.uniform.random.walk.txt"), "UTF-8")); 487 | int pathLen = 100; 488 | int numOfWalk = 80; 489 | for(int i=0;i0;i--) { 522 | int randNum1 = random1.nextInt(i); 523 | int temp1 = arr1[i]; 524 | arr1[i] = arr1[randNum1]; 525 | arr1[randNum1] = temp1; 526 | } 527 | for(int i=0;i 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | // The metapath2vec.cpp code was built upon the word2vec.c from https://code.google.com/archive/p/word2vec/ 18 | 19 | // Modifications Copyright (C) 2016 20 | // 21 | // Licensed under the Apache License, Version 2.0 (the "License"); 22 | // you may not use this file except in compliance with the License. 23 | // You may obtain a copy of the License at 24 | // 25 | // http://www.apache.org/licenses/LICENSE-2.0 26 | // 27 | // Unless required by applicable law or agreed to in writing, software 28 | // distributed under the License is distributed on an "AS IS" BASIS, 29 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | // See the License for the specific language governing permissions and 31 | // limitations under the License. 32 | 33 | // Copyright 2013 Google Inc. All Rights Reserved. 34 | // 35 | // Licensed under the Apache License, Version 2.0 (the "License"); 36 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>

#define MAX_STRING 100
#define EXP_TABLE_SIZE 1000
#define MAX_EXP 6
#define MAX_WALK_LENGTH 1000

const long long node_hash_size = 10000000; // Maximum 10M nodes
const long long node_context_hash_size = 1000000000; // Maximum 1G node context pairs

struct Node
{
	long long cn;
	char *node_str;
};

struct Node_Context
{
	long long cn;
	long long source, target;
};

char train_file[MAX_STRING], output_file[MAX_STRING];
struct Node *node_list;
long long window = 5;
long long pp = 1; // pp = 1 with Heterogeneous Skip-Gram, pp = 0 with Homogeneous Skip-Gram
long long *node_hash;
long long node_list_max_size = 100000, node_list_size = 0, layer1_size = 128;
long long node_occur_num = 0;
double alpha = 0.025, starting_alpha;
double *syn0, *syn1, *syn1neg, *expTable;
double **type_syn1;
clock_t start, finish;

long long type_num;
long long **type_tables, *type_counts, *type_indices;
char prefixes[MAX_STRING];
long long obj_type;

long long **type_node_hash;
long long *type_node_list_size, *type_node_index;
struct Node **type_node_list;

long long negative = 5;
const long table_size = 1e8;
long long *table;

long long *node_context_hash;
long long node_context_list_size, node_context_list_max_size;
struct Node_Context *node_context_list;

// Parameters for node context pair sampling
long long *alias;
double *prob;

long long total_samples = 100;

long long GetTypeId(char prefix)
{
	long long i;
	for (i = 0; i < type_num; i++)
	{
		if (prefix == prefixes[i])
			break;
	}
	if (i >= type_num)
		return -1;
	else
		return i;
}

void NodeTypeNeg()
{
	long long i;
	long long target, type_Id;
	type_tables = (long long **)malloc(type_num * sizeof(long long *));
	type_counts = (long long *)malloc(type_num * sizeof(long long));
	type_indices = (long long *)malloc(type_num * sizeof(long long));
	if (type_tables == NULL || type_counts == NULL || type_indices == NULL)
	{
		printf("Memory allocation failed\n");
		exit(1);
	}
	for (i = 0; i < type_num; i++)
	{
		type_counts[i] = 0;
		type_indices[i] = 0;
	}
	for (i = 0; i < table_size; i++)
	{
		target = table[i];
		type_Id = GetTypeId(node_list[target].node_str[0]);
		if (type_Id == -1)
		{
			printf("Unrecognised type\n");
			exit(1);
		}
		type_counts[type_Id]++;
	}
	for (i = 0; i < type_num; i++)
	{
		type_tables[i] = (long long *)malloc(type_counts[i] * sizeof(long long));
		if (type_tables[i] == NULL)
		{
			printf("Memory allocation failed\n");
			exit(1);
		}
	}
	for (i = 0; i < table_size; i++)
	{
		target = table[i];
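		// Second pass over the unigram table: scatter each entry into its
		// type-specific table, so negatives can later be drawn per node type.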
162 | type_Id = GetTypeId(node_list[target].node_str[0]); 163 | type_tables[type_Id][type_indices[type_Id]] = target; 164 | type_indices[type_Id]++; 165 | } 166 | for (i = 0; i < type_num; i++) 167 | printf("type %c table size: %lld\n", prefixes[i], type_counts[i]); 168 | } 169 | 170 | void InitUnigramTable() 171 | { 172 | long long a, i; 173 | double train_nodes_pow = 0.0, d1, power = 0.75; 174 | table = (long long *)malloc(table_size * sizeof(long long)); 175 | if (table == NULL) 176 | { 177 | printf("Memory allocation failed\n"); 178 | exit(1); 179 | } 180 | for (a = 0; a < node_list_size; a++) train_nodes_pow += pow(node_list[a].cn, power); 181 | i = 0; 182 | d1 = pow(node_list[i].cn, power) / train_nodes_pow; 183 | for (a = 0; a < table_size; a++) 184 | { 185 | table[a] = i; 186 | if (a / (double)table_size > d1) 187 | { 188 | i++; 189 | d1 += pow(node_list[i].cn, power) / (double)train_nodes_pow; 190 | } 191 | if (i >= node_list_size) i = node_list_size - 1; 192 | } 193 | if (pp == 1) 194 | NodeTypeNeg(); 195 | } 196 | 197 | // Reads a single node from a file, assuming space + tab + EOL to be node name boundaries 198 | long long ReadNode(char *node_str, FILE *fin) // return 1, if the boundary is '\n', otherwise 0 199 | { 200 | long long a = 0, ch, flag = 0; 201 | while (!feof(fin)) 202 | { 203 | ch = fgetc(fin); 204 | if (ch == 13) continue; 205 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) 206 | { 207 | if (a > 0) 208 | { 209 | if (ch == '\n') flag = 1; 210 | break; 211 | } 212 | continue; 213 | } 214 | node_str[a] = ch; 215 | a++; 216 | if (a >= MAX_STRING - 1) a--; 217 | } 218 | node_str[a] = 0; 219 | return flag; 220 | } 221 | 222 | // Return hash value of a node 223 | long long GetNodeHash(char *node_str) 224 | { 225 | long long a, hash = 0; 226 | for (a = 0; a < strlen(node_str); a++) 227 | { 228 | hash = hash * 257 + node_str[a]; 229 | hash = hash % node_hash_size; 230 | } 231 | return hash; 232 | } 233 | 234 | // Return hash value for a node context pair 235 | long long GetNodeContextHash(char *node_str, char *context_node_str) 236 | { 237 | long long a, hash = 0; 238 | for (a = 0; a < strlen(node_str); a++) 239 | { 240 | hash = hash * 257 + node_str[a]; 241 | hash = hash % node_context_hash_size; 242 | } 243 | for (a = 0; a < strlen(context_node_str); a++) 244 | { 245 | hash = hash * 257 + context_node_str[a]; 246 | hash = hash % node_context_hash_size; 247 | } 248 | return hash; 249 | } 250 | 251 | // Return position of a node in the node list; if the node is not found, returns -1 252 | long long SearchNode(char *node_str) 253 | { 254 | long long hash = GetNodeHash(node_str); 255 | while (1) 256 | { 257 | if (node_hash[hash] == -1) return -1; 258 | if (!strcmp(node_str, node_list[node_hash[hash]].node_str)) return node_hash[hash]; 259 | hash = (hash + 1) % node_hash_size; 260 | } 261 | return -1; 262 | } 263 | 264 | // Return position of a node in the node list of the specific type; if the node is not found, returns -1 265 | long long TypeSearchNode(char *node_str, long long type_Id) 266 | { 267 | long long hash = GetNodeHash(node_str); 268 | while (1) 269 | { 270 | if (type_node_hash[type_Id][hash] == -1) return -1; 271 | if (!strcmp(node_str, type_node_list[type_Id][type_node_hash[type_Id][hash]].node_str)) return type_node_hash[type_Id][hash]; 272 | hash = (hash + 1) % node_hash_size; 273 | } 274 | return -1; 275 | } 276 | 277 | // Return position of a node context pair in the node context list; if the node is not found, returns -1 278 | long long 
SearchNodeContextPair(long long node, long long context) 279 | { 280 | long long hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str); 281 | long long hash_iter = 0; 282 | long long cur_node, cur_context; 283 | while(1) 284 | { 285 | if (node_context_hash[hash] == -1) return -1; 286 | cur_node = node_context_list[node_context_hash[hash]].source; 287 | cur_context = node_context_list[node_context_hash[hash]].target; 288 | if (cur_node == node && cur_context == context) return node_context_hash[hash]; 289 | hash = (hash + 1) % node_context_hash_size; 290 | hash_iter++; 291 | if (hash_iter >= node_context_hash_size) 292 | { 293 | printf("The node context hash table is full!\n"); 294 | exit(1); 295 | } 296 | } 297 | return -1; 298 | } 299 | 300 | // Adds a node to the node list 301 | long long AddNodeToList(char *node_str) 302 | { 303 | long long hash; 304 | long long length = strlen(node_str) + 1; 305 | if (length > MAX_STRING) length = MAX_STRING; 306 | node_list[node_list_size].node_str = (char *)calloc(length, sizeof(char)); 307 | if (node_list[node_list_size].node_str == NULL) 308 | { 309 | printf("Memory allocation failed\n"); 310 | exit(1); 311 | } 312 | strcpy(node_list[node_list_size].node_str, node_str); 313 | node_list[node_list_size].cn = 1; 314 | node_list_size++; 315 | // Reallocate memory if needed 316 | if (node_list_size >= node_list_max_size) 317 | { 318 | node_list_max_size += 10000; 319 | node_list = (struct Node *)realloc(node_list, node_list_max_size * sizeof(struct Node)); 320 | if (node_list == NULL) 321 | { 322 | printf("Memory allocation failed\n"); 323 | exit(1); 324 | } 325 | } 326 | hash = GetNodeHash(node_str); 327 | while (node_hash[hash] != -1) hash = (hash + 1) % node_hash_size; 328 | node_hash[hash] = node_list_size - 1; 329 | return node_list_size - 1; 330 | } 331 | 332 | //Add node context pair to the node context list 333 | long long AddNodeContextToList(long long node, long long context) 334 | { 335 | long long hash; 336 | long long hash_iter = 0; 337 | node_context_list[node_context_list_size].source = node; 338 | node_context_list[node_context_list_size].target = context; 339 | node_context_list[node_context_list_size].cn = 1; 340 | node_context_list_size++; 341 | if (node_context_list_size >= node_context_list_max_size) 342 | { 343 | node_context_list_max_size += 100 * node_list_size; 344 | node_context_list = (struct Node_Context *)realloc(node_context_list, node_context_list_max_size * sizeof(struct Node_Context)); 345 | if (node_context_list == NULL) 346 | { 347 | printf("Memory allocation failed\n"); 348 | exit(1); 349 | } 350 | } 351 | hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str); 352 | while (node_context_hash[hash] != -1) 353 | { 354 | hash = (hash + 1) % node_context_hash_size; 355 | hash_iter++; 356 | if (hash_iter >= node_context_hash_size) 357 | { 358 | printf("The node context hash table is full!\n"); 359 | exit(1); 360 | } 361 | } 362 | node_context_hash[hash] = node_context_list_size - 1; 363 | return node_context_list_size - 1; 364 | } 365 | 366 | void LearnNodeListFromTrainFile() 367 | { 368 | char node_str[MAX_STRING]; 369 | FILE *fin; 370 | long long a, i; 371 | for (a = 0; a < node_hash_size; a++) node_hash[a] = -1; 372 | fin = fopen(train_file, "r"); 373 | if (fin == NULL) 374 | { 375 | printf("ERROR: training data file not found!\n"); 376 | exit(1); 377 | } 378 | node_list_size = 0; 379 | printf("Reading nodes...\n"); 380 | while (1) 381 | { 382 | ReadNode(node_str, 
fin); 383 | if (feof(fin)) break; 384 | node_occur_num++; 385 | if (node_occur_num % 100000 == 0) 386 | { 387 | printf("Read nodes: %lldK%c", node_occur_num / 1000, 13); 388 | fflush(stdout); 389 | } 390 | i = SearchNode(node_str); 391 | if (i == -1) AddNodeToList(node_str); 392 | else node_list[i].cn++; 393 | } 394 | printf("Node list size: %lld \n", node_list_size); 395 | printf("Node occurrence times in train file: %lld\n", node_occur_num); 396 | fclose(fin); 397 | } 398 | 399 | void GetNodeContextTable() 400 | { 401 | char node_str[MAX_STRING]; 402 | long long a, b, c, node, context_node, walk_length = 0, walk_position = 0; 403 | long long walk[MAX_WALK_LENGTH + 1]; 404 | long long line_num = 0; 405 | long long end_flag; 406 | FILE *fi = fopen(train_file, "r"); 407 | node_context_hash = (long long *)malloc(node_context_hash_size * sizeof(long long)); 408 | node_context_list_max_size = 1000 * node_list_size; 409 | node_context_list = (struct Node_Context *)malloc(node_context_list_max_size * sizeof(struct Node_Context)); 410 | for (a = 0; a < node_context_hash_size; a++) node_context_hash[a] = -1; 411 | node_context_list_size = 0; 412 | printf("Collecting node context pair...\n"); 413 | start = clock(); 414 | while (1) 415 | { 416 | if (walk_length == 0) 417 | { 418 | while (1) 419 | { 420 | end_flag = ReadNode(node_str, fi); 421 | if (feof(fi)) break; 422 | node = SearchNode(node_str); 423 | //if (node == -1) continue; 424 | walk[walk_length] = node; 425 | walk_length++; 426 | if (end_flag || walk_length >= MAX_WALK_LENGTH) break; 427 | } 428 | walk_position = 0; 429 | line_num++; 430 | if (line_num % 1000 == 0) 431 | { 432 | printf("Processed lines: %lldK%c", line_num / 1000, 13); 433 | fflush(stdout); 434 | } 435 | } 436 | if (feof(fi) && walk_length == 0) break; 437 | //if (walk_length == 0) continue; 438 | node = walk[walk_position]; //current node 439 | for (a = 0; a < window * 2 + 1; a++) 440 | { 441 | if (a != window) 442 | { 443 | c = walk_position - window + a; 444 | if (c < 0) continue; 445 | if (c >= walk_length) continue; 446 | context_node = walk[c]; 447 | b = SearchNodeContextPair(node, context_node); 448 | if (b == -1) AddNodeContextToList(node, context_node); 449 | else node_context_list[b].cn++; 450 | } 451 | } 452 | walk_position++; 453 | if (walk_position >= walk_length) walk_length = 0; 454 | } 455 | finish = clock(); 456 | printf("Total time: %lf secs for collecting node context pairs\n", (double)(finish - start) / CLOCKS_PER_SEC); 457 | printf("Node context pair number: %lld\n", node_context_list_size); 458 | printf("----------------------------------------------------\n"); 459 | } 460 | 461 | void InitNet() 462 | { 463 | long long a, b; 464 | unsigned long long next_random = 1; 465 | syn0 = (double *)malloc(node_list_size * layer1_size * sizeof(double)); 466 | if (syn0 == NULL) 467 | { 468 | printf("Memory allocation failed\n"); 469 | exit(1); 470 | } 471 | for (a = 0; a < node_list_size; a++) 472 | for (b = 0; b < layer1_size; b++) 473 | { 474 | next_random = next_random * (unsigned long long)25214903917 + 11; 475 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (double)65536) - 0.5) / layer1_size; 476 | } 477 | syn1neg = (double *)malloc(node_list_size * layer1_size * sizeof(double)); 478 | if (syn1neg == NULL) 479 | { 480 | printf("Memory allocation failed\n"); 481 | exit(1); 482 | } 483 | for (a = 0; a < node_list_size; a++) 484 | for (b = 0; b < layer1_size; b++) 485 | syn1neg[a * layer1_size + b] = 0; 486 | InitUnigramTable(); 487 | } 488 | 489 | /* 
The alias sampling algorithm, which is used to sample an node context pair in O(1) time. */ 490 | void InitAliasTable() 491 | { 492 | long long k; 493 | double sum = 0; 494 | long long cur_small_block, cur_large_block; 495 | long long num_small_block = 0, num_large_block = 0; 496 | double *norm_prob; 497 | long long *large_block; 498 | long long *small_block; 499 | alias = (long long *)malloc(node_context_list_size * sizeof(long long)); 500 | prob = (double *)malloc(node_context_list_size * sizeof(double)); 501 | if (alias == NULL || prob == NULL) 502 | { 503 | printf("Memory allocation failed\n"); 504 | exit(1); 505 | } 506 | norm_prob = (double*)malloc(node_context_list_size * sizeof(double)); 507 | large_block = (long long*)malloc(node_context_list_size * sizeof(long long)); 508 | small_block = (long long*)malloc(node_context_list_size * sizeof(long long)); 509 | if (norm_prob == NULL || large_block == NULL || small_block == NULL) 510 | { 511 | printf("Memory allocation failed\n"); 512 | exit(1); 513 | } 514 | for (k = 0; k != node_context_list_size; k++) sum += node_context_list[k].cn; 515 | for (k = 0; k != node_context_list_size; k++) norm_prob[k] = (double)node_context_list[k].cn * node_context_list_size / sum; 516 | for (k = node_context_list_size - 1; k >= 0; k--) 517 | { 518 | if (norm_prob[k]<1) 519 | small_block[num_small_block++] = k; 520 | else 521 | large_block[num_large_block++] = k; 522 | } 523 | while (num_small_block && num_large_block) 524 | { 525 | cur_small_block = small_block[--num_small_block]; 526 | cur_large_block = large_block[--num_large_block]; 527 | prob[cur_small_block] = norm_prob[cur_small_block]; 528 | alias[cur_small_block] = cur_large_block; 529 | norm_prob[cur_large_block] = norm_prob[cur_large_block] + norm_prob[cur_small_block] - 1; 530 | if (norm_prob[cur_large_block] < 1) 531 | small_block[num_small_block++] = cur_large_block; 532 | else 533 | large_block[num_large_block++] = cur_large_block; 534 | } 535 | while (num_large_block) prob[large_block[--num_large_block]] = 1; 536 | while (num_small_block) prob[small_block[--num_small_block]] = 1; 537 | free(norm_prob); 538 | free(small_block); 539 | free(large_block); 540 | } 541 | 542 | long long SampleAPair(double rand_value1, double rand_value2) 543 | { 544 | long long k = (long long)node_context_list_size * rand_value1; 545 | return rand_value2 < prob[k] ? 
k : alias[k]; 546 | } 547 | 548 | void TrainModel() 549 | { 550 | long a, b, c, d; 551 | long long node, context_node, cur_pair; 552 | long long count = 0, last_count = 0; 553 | long long l1, l2, target, label; 554 | long long type_Id; 555 | unsigned long long next_random = 1; 556 | double rand_num1, rand_num2; 557 | double f, g; 558 | double *neu1e = (double *)calloc(layer1_size, sizeof(double)); 559 | FILE *fp; 560 | InitNet(); 561 | starting_alpha = alpha; 562 | printf("Skip-Gram model with Negative Sampling:"); 563 | if (pp == 1) printf(" heterogeneous version\n"); 564 | else printf(" homogeneous version\n"); 565 | printf("Training file: %s\n", train_file); 566 | printf("Samples: %lldM\n", total_samples / 1000000); 567 | printf("Dimension: %lld\n", layer1_size); 568 | printf("Initial Alpha: %f\n", alpha); 569 | start = clock(); 570 | srand((unsigned)time(NULL)); 571 | while (1) 572 | { 573 | if (count >= total_samples) break; 574 | if (count - last_count > 10000) 575 | { 576 | last_count = count; 577 | printf("Alpha: %f Progress %.3lf%%%c", alpha, (double)count / (double)(total_samples + 1) * 100, 13); 578 | fflush(stdout); 579 | alpha = starting_alpha * (1 - (double)count / (double)(total_samples + 1)); 580 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001; 581 | } 582 | rand_num1 = rand() / (RAND_MAX * 1.0 + 1); 583 | rand_num2 = rand() / (RAND_MAX * 1.0 + 1); 584 | cur_pair = SampleAPair(rand_num1, rand_num2); 585 | node = node_context_list[cur_pair].source; 586 | context_node = node_context_list[cur_pair].target; 587 | l1 = node * layer1_size; 588 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 589 | 590 | // NEGATIVE SAMPLING 591 | for (d = 0; d < negative + 1; d++) 592 | { 593 | if (d == 0) 594 | { 595 | target = context_node; 596 | label = 1; 597 | } 598 | else 599 | { 600 | next_random = next_random * (unsigned long long)25214903917 + 11; 601 | if (pp == 1) 602 | { 603 | type_Id = GetTypeId(node_list[context_node].node_str[0]); 604 | target = type_tables[type_Id][(next_random >> 16) % type_counts[type_Id]]; 605 | } 606 | else 607 | target = table[(next_random >> 16) % table_size]; 608 | if (target == context_node) continue; 609 | label = 0; 610 | } 611 | l2 = target * layer1_size; 612 | f = 0; 613 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2]; 614 | if (f > MAX_EXP) f = 1; 615 | else if (f < -MAX_EXP) f = 0; 616 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 617 | g = (label - f) * alpha; 618 | // Propagate errors output -> hidden 619 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 620 | // Learn weights hidden -> output 621 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1]; 622 | } 623 | // Learn weights input -> hidden 624 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c]; 625 | count++; 626 | } 627 | finish = clock(); 628 | printf("Total time: %lf secs for learning node embeddings\n", (double)(finish - start) / CLOCKS_PER_SEC); 629 | printf("----------------------------------------------------\n"); 630 | fp = fopen(output_file, "w"); 631 | // Save the learned node representations 632 | for (a = 0; a < node_list_size; a++) 633 | { 634 | if (node_list[a].node_str[0] != prefixes[obj_type]) continue; 635 | fprintf(fp, "%s", node_list[a].node_str); 636 | for (b = 0; b < layer1_size; b++) fprintf(fp, " %lf", syn0[a * layer1_size + b]); 637 | fprintf(fp, "\n"); 638 | } 639 | free(neu1e); 640 | fclose(fp); 641 | } 642 | 643 | int ArgPos(char *str, int 
argc, char **argv)
{
	int a;
	for (a = 1; a < argc; a++) if (!strcmp(str, argv[a]))
	{
		if (a == argc - 1)
		{
			printf("Argument missing for %s\n", str);
			exit(1);
		}
		return a;
	}
	return -1;
}

int main(int argc, char **argv)
{
	int i;
	if (argc == 1)
	{
		printf("---randomWalk2vec---\n");
		printf("---The code and following instructions are built upon word2vec.c by Mikolov et al.---\n\n");
		printf("Options:\n");
		printf("Parameters for training:\n");
		printf("\t-train <file>\n");
		printf("\t\tUse random walk sequences from <file> to train the model\n");
		printf("\t-output <file>\n");
		printf("\t\tUse <file> to save the learned node vector-format representations\n");
		printf("\t-size <int>\n");
		printf("\t\tSet the dimension of learned node representations; default is 128\n");
		printf("\t-window <int>\n");
		printf("\t\tSet the window size for collecting node context pairs; default is 5\n");
		printf("\t-pp <int>\n");
		printf("\t\tUse the Heterogeneous Skip-Gram model (the ++ version) or not; default is 1 (++ version); use 0 for Homogeneous Skip-Gram\n");
		printf("\t-prefixes <string>\n");
		printf("\t\tPrefixes of node Ids for specifying node types, e.g., ap with a for author and p for paper\n");
		printf("\t-objtype\n");
		printf("\t\tThe index of the objective node type in the prefixes list, for which representations are to be learned\n");
		printf("\t-alpha <float>\n");
		printf("\t\tSet the starting learning rate; default is 0.025\n");
		printf("\t-negative <int>\n");
		printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10\n");
		printf("\t-samples <int>\n");
		printf("\t\tSet the number of stochastic gradient descent iterations, in millions; default is 100\n");
		return 0;
	}
	if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
	if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
	if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
	if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-pp", argc, argv)) > 0) pp = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-prefixes", argc, argv)) > 0) strcpy(prefixes, argv[i + 1]);
	if ((i = ArgPos((char *)"-objtype", argc, argv)) > 0) obj_type = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-samples", argc, argv)) > 0) total_samples = atoi(argv[i + 1]);
	total_samples = total_samples * 1000000;
	type_num = strlen(prefixes);
	printf("Number of node types: %lld\n", type_num);
	node_list = (struct Node *)calloc(node_list_max_size, sizeof(struct Node));
	node_hash = (long long *)calloc(node_hash_size, sizeof(long long));
	expTable = (double *)malloc((EXP_TABLE_SIZE + 1) * sizeof(double));
	if (node_list == NULL || node_hash == NULL || expTable == NULL)
	{
		printf("Memory allocation failed\n");
		exit(1);
	}
	for (i = 0; i < EXP_TABLE_SIZE; i++)
	{
		expTable[i] = exp((i / (double)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
		expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1)
	}
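	// Training pipeline: build the node vocabulary from the walk file, count
	// (node, context) pairs within the window, build the alias table over
	// those pairs, then run negative-sampling stochastic gradient descent.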
LearnNodeListFromTrainFile(); 716 | GetNodeContextTable(); 717 | InitAliasTable(); 718 | TrainModel(); 719 | return 0; 720 | } 721 | --------------------------------------------------------------------------------