├── .gitattributes
├── DBLPPreProcess
│   ├── .classpath
│   ├── .project
│   ├── .settings
│   │   └── org.eclipse.jdt.core.prefs
│   └── src
│       ├── PreProcessStep1.java
│       ├── PreProcessStep2.java
│       └── PreProcessStep3.java
├── RandomWalk2Vec
│   ├── compile.sh
│   ├── RandomWalk2VecRun.sh
│   └── randomWalk2vec.c
└── README.md

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/RandomWalk2Vec/compile.sh:
--------------------------------------------------------------------------------
#!/bin/bash
gcc -Wall -g -c "randomWalk2vec.c" -o randomWalk2vec.o
g++ -o randomWalk2vec randomWalk2vec.o

--------------------------------------------------------------------------------
/DBLPPreProcess/.classpath:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
	<classpathentry kind="src" path="src"/>
	<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8"/>
	<classpathentry kind="output" path="bin"/>
</classpath>

--------------------------------------------------------------------------------
/DBLPPreProcess/.project:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
	<name>DBLPPreProcess</name>
	<comment></comment>
	<projects>
	</projects>
	<buildSpec>
		<buildCommand>
			<name>org.eclipse.jdt.core.javabuilder</name>
			<arguments>
			</arguments>
		</buildCommand>
	</buildSpec>
	<natures>
		<nature>org.eclipse.jdt.core.javanature</nature>
	</natures>
</projectDescription>

--------------------------------------------------------------------------------
/DBLPPreProcess/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.source=1.8

--------------------------------------------------------------------------------
/RandomWalk2Vec/RandomWalk2VecRun.sh:
--------------------------------------------------------------------------------
#!/bin/bash
./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
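# Note: each walk file above is trained twice, once with -pp 1
# (Heterogeneous Skip-Gram) and once with -pp 0 (Homogeneous Skip-Gram).
# With -prefixes apv and -objtype 0, only author ('a'-prefixed) node
# embeddings are written to the -output file; see the README for details.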
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MetaGraph2Vec

Code for the PAKDD-2018 paper "MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding"

Authors: Daokun Zhang, Jie Yin, Xingquan Zhu and Chengqi Zhang

Contact: Daokun Zhang (daokunzhang2015@gmail.com)

This project contains two folders:

1) The folder "DBLPPreProcess" is the Java project for data preprocessing. It contains three files:

"PreProcessStep1", used to extract a subgraph from the "DBLP-Citation-Network-V3.txt" network;

"PreProcessStep2", used to obtain the class labels for authors;

"PreProcessStep3", used to generate random walks from the extracted subgraph, including the metagraph-guided random walks, the metapath-guided random walks, and the uniform random walks that ignore the heterogeneity of nodes.

To run this project, first download the "DBLP-Citation-Network-V3.txt" file from https://aminer.org/citation and move it to the "DBLPPreProcess/dataset" folder, then run "PreProcessStep1", "PreProcessStep2" and "PreProcessStep3" sequentially. The class label file will be output to the folder "DBLPPreProcess/group", and the random walk files will be output to the folder "DBLPPreProcess/randomwalk".

2) The folder "RandomWalk2Vec" contains the program for learning node embeddings from the generated random walk sequences. Please run "compile.sh" to compile the "randomWalk2vec.c" file into its corresponding executable. To run this program, move the generated random walk files from "DBLPPreProcess/randomwalk" to this folder and run the "RandomWalk2VecRun.sh" file.

In this folder, the "randomWalk2vec" program is used to learn node embeddings from random walk sequences with either Heterogeneous Skip-Gram or Homogeneous Skip-Gram; the sketch below shows how the two variants differ.
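The two variants differ only in where negative samples come from: Homogeneous Skip-Gram draws negatives from a single unigram table over all nodes, whereas Heterogeneous Skip-Gram restricts negatives to nodes of the same type as the current context node. Below is a minimal sketch of that branch, condensed from the negative-sampling loop in randomWalk2vec.c; it reuses the globals and helpers defined there, and draw_negative itself is an illustrative wrapper, not a function in the file.

```c
// Sketch: pick one negative sample for a (node, context) pair.
long long draw_negative(long long context_node, unsigned long long *next_random)
{
    long long target, type_Id;
    // Linear congruential update, as in word2vec.c.
    *next_random = *next_random * (unsigned long long)25214903917 + 11;
    if (pp == 1)
    {
        // Heterogeneous (++) version: sample from the table that holds only
        // nodes whose Id prefix (node type) matches the true context node.
        type_Id = GetTypeId(node_list[context_node].node_str[0]);
        target = type_tables[type_Id][(*next_random >> 16) % type_counts[type_Id]];
    }
    else
    {
        // Homogeneous version: sample from the global unigram table.
        target = table[(*next_random >> 16) % table_size];
    }
    return target;
}
```

Restricting negatives to the context node's type is what makes the ++ version's objective type-specific: authors compete only with authors, papers with papers, and venues with venues.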
The options of randomWalk2vec are as follows:

	-train <file>
		Use random walk sequences from <file> to train the model
	-output <file>
		Use <file> to save the learned node vector-format representations
	-size <int>
		Set the dimension of learned node representations; default is 128
	-window <int>
		Set the window size for collecting node context pairs; default is 5
	-pp <int>
		Use the Heterogeneous Skip-Gram model (the ++ version) or not; default is 1 (++ version); use 0 for Homogeneous Skip-Gram
	-prefixes <string>
		Prefixes of node Ids for specifying node types, e.g., ap with a for author and p for paper
	-objtype
		The index of the objective node type in the prefixes list, for which representations are to be learned
	-alpha <float>
		Set the starting learning rate; default is 0.025
	-negative <int>
		Number of negative examples; default is 5, common values are 3 - 10
	-samples <int>
		Set the number of stochastic gradient descent iterations, in millions; default is 100

If you find this project useful, please cite this paper:

@inproceedings{zhang2018metagraph2vec,
  title={MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding},
  author={Zhang, Daokun and Yin, Jie and Zhu, Xingquan and Zhang, Chengqi},
  booktitle={Pacific-Asia Conference on Knowledge Discovery and Data Mining},
  pages={196--208},
  year={2018},
  organization={Springer}
}

# Note
The randomWalk2vec.c program was compiled on a Linux system where RAND_MAX takes the value 2147483647. If you compile randomWalk2vec.c on your own system, please check the value of RAND_MAX carefully to make sure it is large enough for alias table sampling to be correct.

--------------------------------------------------------------------------------
/DBLPPreProcess/src/PreProcessStep1.java:
--------------------------------------------------------------------------------
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;

public class PreProcessStep1 { // extract a subgraph from the DBLP-Citation-Network-V3.txt network

	public static int nonAlphabet(char ch)
	{
		if(ch>='A'&&ch<='Z')
			return 0;
		if(ch>='a'&&ch<='z')
			return 0;
		return 1;
	}

	public static int check(String venue)
	{
		String[] selectedVenues = {"SIGMOD", "ICDE", "VLDB", "EDBT",
			"PODS", "ICDT", "DASFAA", "SSDBM", "CIKM", "KDD",
			"ICDM", "SDM", "PKDD", "PAKDD", "IJCAI", "AAAI",
			"NIPS", "ICML", "ECML", "ACML", "IJCNN", "UAI",
			"ECAI", "COLT", "ACL", "KR", "CVPR", "ICCV",
			"ECCV", "ACCV", "MM", "ICPR", "ICIP", "ICME"};
		int res = 0;
		for(int i=0;i<selectedVenues.length;i++)
		{
			int index = venue.indexOf(selectedVenues[i]);
			if(index!=-1)
			{
				if(index==0)
				{
					if(index+selectedVenues[i].length()>=venue.length())
						res = 1;
					else
					{
						char endChar = venue.charAt(index+selectedVenues[i].length());
						if(nonAlphabet(endChar)==1)
							res = 1;
					}
				}
				else
				{
					char startChar = venue.charAt(index-1);
					if(nonAlphabet(startChar)==1)
					{
						if(index+selectedVenues[i].length()>=venue.length())
							res = 1;
						else
						{
							char endChar = venue.charAt(index+selectedVenues[i].length());
							if(nonAlphabet(endChar)==1)
								res = 1;
						}
					}
				}
			}
			if(res==1)
				break;
		}
		if(res==1)
		{
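			// Filter out false matches whose venue string contains both "SIGMOD" and "Data Mining"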
if(venue.indexOf("SIGMOD")!=-1&&venue.indexOf("Data Mining")!=-1) 69 | res = 0; 70 | } 71 | return res; 72 | } 73 | 74 | public static void main(String[] args) throws IOException { 75 | // TODO Auto-generated method stub 76 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-Citation-Network-V3.txt"), "UTF-8")); 77 | BufferedWriter bwTitle = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.title.txt"), "UTF-8")); 78 | BufferedWriter bwAuthor = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 79 | BufferedWriter bwVenue = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 80 | BufferedWriter bwRef = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.ref.txt"), "UTF-8")); 81 | int numOfNode = 0; 82 | String line = null; 83 | br.readLine(); 84 | String curIndex = ""; 85 | String curTitle = ""; 86 | String curAuthors = ""; 87 | String[] curAuthorList = null; 88 | String curVenue = ""; 89 | ArrayList curRef = new ArrayList(); 90 | int writeStatus = 0; 91 | int maxRefSize = 0; 92 | String maxRefIndex = ""; 93 | while((line=br.readLine())!=null) 94 | { 95 | if(line.indexOf("#*")==0) 96 | curTitle = line.substring(2).trim(); 97 | if(line.indexOf("#c")==0) 98 | { 99 | curVenue = line.substring(2).trim(); 100 | curVenue = curVenue.replaceAll(" ", ""); 101 | if(check(curVenue)==1) 102 | writeStatus = 1; 103 | else 104 | writeStatus = 0; 105 | } 106 | if(line.indexOf("#%")==0) 107 | curRef.add(line.substring(2).trim()); 108 | if(line.indexOf("#@")==0) 109 | { 110 | curAuthors = line.substring(2).trim(); 111 | curAuthorList = curAuthors.split(","); 112 | } 113 | if(line.indexOf("#index")==0) 114 | curIndex = line.substring(6).trim(); 115 | if(line.trim().length()==0) 116 | { 117 | if((!curTitle.equals(""))&&(!curVenue.equals(""))&&(!curAuthors.equals(""))&&(!curIndex.equals(""))&&writeStatus==1) 118 | { 119 | curIndex = curIndex.replaceAll(" ", ""); 120 | bwTitle.write("p_"+curIndex+"\t"+curTitle+"\n"); 121 | for(int i=0;imaxRefSize) 137 | { 138 | maxRefSize = curRef.size(); 139 | maxRefIndex = curIndex; 140 | } 141 | numOfNode++; 142 | } 143 | curTitle = ""; 144 | curAuthors = ""; 145 | curVenue = ""; 146 | curRef.clear(); 147 | curIndex = ""; 148 | } 149 | } 150 | br.close(); 151 | bwTitle.close(); 152 | bwAuthor.close(); 153 | bwVenue.close(); 154 | bwRef.close(); 155 | System.out.println("numOfNode = "+numOfNode); 156 | System.out.println("maxRefSize = "+maxRefSize); 157 | System.out.println("maxRefIndex = "+maxRefIndex); 158 | } 159 | 160 | } 161 | -------------------------------------------------------------------------------- /DBLPPreProcess/src/PreProcessStep2.java: -------------------------------------------------------------------------------- 1 | import java.io.BufferedReader; 2 | import java.io.BufferedWriter; 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.io.InputStreamReader; 8 | import java.io.OutputStreamWriter; 9 | import java.util.ArrayList; 10 | import java.util.Collections; 11 | import java.util.Comparator; 12 | import java.util.HashMap; 13 | import java.util.Map; 14 | import java.util.Random; 15 | 16 | public class PreProcessStep2 { // obtain the class labels for 
authors; 17 | 18 | public static int getClass(String venue) 19 | { 20 | String[] venues1 = {"SIGMOD", "ICDE", "VLDB", "EDBT", 21 | "PODS", "ICDT", "DASFAA", "SSDBM", "CIKM"}; 22 | String[] venues2 = {"KDD", "ICDM", "SDM", "PKDD", "PAKDD"}; 23 | String[] venues3 = {"IJCAI", "AAAI", "NIPS", "ICML", "ECML", 24 | "ACML", "IJCNN", "UAI", "ECAI", "COLT", "ACL", "KR"}; 25 | String[] venues4 = {"CVPR", "ICCV", "ECCV", "ACCV", "MM", 26 | "ICPR", "ICIP", "ICME"}; 27 | ArrayList> venueLists = new ArrayList>(); 28 | ArrayList venueList1 = new ArrayList(); 29 | for(int i=0;i venueList2 = new ArrayList(); 33 | for(int i=0;i venueList3 = new ArrayList(); 37 | for(int i=0;i venueList4 = new ArrayList(); 41 | for(int i=0;i0;i--) 50 | { 51 | int randNumClass = randomClass.nextInt(i); 52 | int temp = arrClass[i]; 53 | arrClass[i] = arrClass[randNumClass]; 54 | arrClass[randNumClass] = temp; 55 | } 56 | for(int i=0;i<4;i++) 57 | { 58 | for(int j=0;j author2Id = new HashMap(); 74 | Map venue2Id = new HashMap(); 75 | Map paper2Id = new HashMap(); 76 | Map Id2Author = new HashMap(); 77 | Map Id2Venue = new HashMap(); 78 | Map Id2Paper = new HashMap(); 79 | 80 | String line = null; 81 | int paperIndex = 0, authorIndex = 0, venueIndex = 0; 82 | while((line=brAuthor.readLine())!=null) 83 | { 84 | int splitPos = line.indexOf("\t"); 85 | String paper = line.substring(0, splitPos); 86 | String author = line.substring(splitPos+1); 87 | if(!paper2Id.containsKey(paper)) 88 | { 89 | paper2Id.put(paper, paperIndex); 90 | Id2Paper.put(paperIndex, paper); 91 | paperIndex++; 92 | } 93 | if(!author2Id.containsKey(author)) 94 | { 95 | author2Id.put(author, authorIndex); 96 | Id2Author.put(authorIndex, author); 97 | authorIndex++; 98 | } 99 | } 100 | brAuthor.close(); 101 | while((line=brVenue.readLine())!=null) 102 | { 103 | int splitPos = line.indexOf("\t"); 104 | String venue = line.substring(splitPos+1); 105 | if(!venue2Id.containsKey(venue)) 106 | { 107 | venue2Id.put(venue, venueIndex); 108 | Id2Venue.put(venueIndex, venue); 109 | venueIndex++; 110 | } 111 | } 112 | brVenue.close(); 113 | int numOfPaper = paperIndex; 114 | int numOfAuthor = authorIndex; 115 | int numOfVenue = venueIndex; 116 | 117 | BufferedWriter bwPaperPaper = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8")); 118 | while((line=brRef.readLine())!=null) 119 | { 120 | int splitPos = line.indexOf("\t"); 121 | String paper1 = line.substring(0, splitPos); 122 | String paper2 = line.substring(splitPos+1); 123 | if(paper2Id.containsKey(paper2)) 124 | bwPaperPaper.write(paper1+"\t"+paper2+"\n"); 125 | } 126 | brRef.close(); 127 | bwPaperPaper.close(); 128 | 129 | ArrayList> authorNeighbors = new ArrayList>(); 130 | ArrayList> paperNeighborsVenue = new ArrayList>(); 131 | for(int i=0;i neighbors = new ArrayList(); 134 | authorNeighbors.add(neighbors); 135 | } 136 | for(int i=0;i neighborsVenue = new ArrayList(); 139 | paperNeighborsVenue.add(neighborsVenue); 140 | } 141 | brAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 142 | brVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 143 | while((line=brAuthor.readLine())!=null) 144 | { 145 | int splitPos = line.indexOf("\t"); 146 | String paper = line.substring(0, splitPos); 147 | String author = line.substring(splitPos+1); 148 | int paperId = paper2Id.get(paper); 149 | int 
authorId = author2Id.get(author); 150 | if(!authorNeighbors.get(authorId).contains(paperId)) 151 | authorNeighbors.get(authorId).add(paperId); 152 | } 153 | brAuthor.close(); 154 | while((line=brVenue.readLine())!=null) 155 | { 156 | int splitPos = line.indexOf("\t"); 157 | String paper = line.substring(0, splitPos); 158 | String venue = line.substring(splitPos+1); 159 | int paperId = paper2Id.get(paper); 160 | int venueId = venue2Id.get(venue); 161 | if(!paperNeighborsVenue.get(paperId).contains(venueId)) 162 | paperNeighborsVenue.get(paperId).add(venueId); 163 | } 164 | brVenue.close(); 165 | for(int i=0;i() { 168 | public int compare(Integer a, Integer b) { 169 | return a.compareTo(b); 170 | } 171 | }); 172 | } 173 | for(int i=0;i() { 176 | public int compare(Integer a, Integer b) { 177 | return a.compareTo(b); 178 | } 179 | }); 180 | } 181 | 182 | int[][] venueClass = new int[numOfVenue][4]; 183 | int[][] authorClass = new int[numOfAuthor][4]; 184 | for(int i=0;i0;j--) 215 | { 216 | int randNumClass = randomClass.nextInt(j); 217 | int temp = arrClass[j]; 218 | arrClass[j] = arrClass[randNumClass]; 219 | arrClass[randNumClass] = temp; 220 | } 221 | int label = -1; 222 | int maxNum = -1; 223 | for(int j=0;j<4;j++) 224 | { 225 | if(authorClass[i][arrClass[j]]>maxNum) 226 | { 227 | maxNum = authorClass[i][arrClass[j]]; 228 | label = arrClass[j]; 229 | } 230 | } 231 | if(label==-1) 232 | System.out.println("no class label for author "+i); 233 | else 234 | { 235 | for(int j=0;j<4;j++) 236 | authorClass[i][j] = 0; 237 | authorClass[i][label] = 1; 238 | } 239 | } 240 | 241 | BufferedWriter bwAuthorClass = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("group"+File.separator+"DBLP-V3.author.class.txt"), "UTF-8")); 242 | for(int i=0;i author2Id; 19 | public static Map venue2Id; 20 | public static Map paper2Id; 21 | public static Map Id2Author; 22 | public static Map Id2Venue; 23 | public static Map Id2Paper; 24 | 25 | public static ArrayList> authorNeighbors; 26 | public static ArrayList> venueNeighbors; 27 | public static ArrayList> paperNeighborsAuthor; 28 | public static ArrayList> paperNeighborsPaper; 29 | public static ArrayList> paperNeighborsVenue; 30 | 31 | public static String metaGraphRandomWalk(int startNodeId, String type, int pathLen) 32 | { 33 | int curNodeId = startNodeId; 34 | String curType = type; 35 | int curPos = 1; 36 | Random random = new Random(); 37 | StringBuffer path = new StringBuffer(""); 38 | for(int i=0;i0) 44 | path.append(" "); 45 | path.append(paper); 46 | if(curPos==2) 47 | { 48 | if(paperNeighborsAuthor.get(curNodeId).size()>1&&paperNeighborsVenue.get(curNodeId).size()>0) 49 | { 50 | double option = Math.random(); 51 | if(option<1.0/2.0) 52 | { 53 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 54 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 55 | curPos++; 56 | curType = "A"; 57 | } 58 | else 59 | { 60 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 61 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 62 | curPos++; 63 | curType = "V"; 64 | } 65 | } 66 | else if(paperNeighborsAuthor.get(curNodeId).size()>1) 67 | { 68 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 69 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 70 | curPos++; 71 | curType = "A"; 72 | } 73 | else if(paperNeighborsVenue.get(curNodeId).size()>0) 74 | { 75 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 76 | 
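// Fewer than two author neighbours here, so the walk can only continue paper -> venue.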
curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 77 | curPos++; 78 | curType = "V"; 79 | } 80 | else 81 | return path.toString(); 82 | } 83 | else 84 | { 85 | if(paperNeighborsAuthor.get(curNodeId).size()==0) 86 | return path.toString(); 87 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 88 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 89 | curPos++; 90 | if(curPos>=5) 91 | curPos = 1; 92 | curType = "A"; 93 | } 94 | } 95 | else 96 | { 97 | if(curType.equals("A")) 98 | { 99 | String author = Id2Author.get(curNodeId); 100 | if(i>0) 101 | path.append(" "); 102 | path.append(author); 103 | if(authorNeighbors.get(curNodeId).size()==0) 104 | return path.toString(); 105 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 106 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 107 | curPos++; 108 | curType = "P"; 109 | } 110 | else if(curType.equals("V")) 111 | { 112 | String venue = Id2Venue.get(curNodeId); 113 | if(i>0) 114 | path.append(" "); 115 | path.append(venue); 116 | if(venueNeighbors.get(curNodeId).size()==0) 117 | return path.toString(); 118 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 119 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 120 | curPos++; 121 | curType = "P"; 122 | } 123 | } 124 | } 125 | return path.toString(); 126 | } 127 | 128 | public static String metaPathAPVPARandomWalk(int startNodeId, String type, int pathLen) 129 | { 130 | int curNodeId = startNodeId; 131 | String curType = type; 132 | String lastType = ""; 133 | Random random = new Random(); 134 | StringBuffer path = new StringBuffer(""); 135 | for(int i=0;i0) 141 | path.append(" "); 142 | path.append(paper); 143 | if(lastType.equals("A")) 144 | { 145 | if(paperNeighborsVenue.get(curNodeId).size()==0) 146 | return path.toString(); 147 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size()); 148 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex); 149 | lastType = curType; 150 | curType = "V"; 151 | } 152 | else 153 | { 154 | if(paperNeighborsAuthor.get(curNodeId).size()==0) 155 | return path.toString(); 156 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 157 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 158 | lastType = curType; 159 | curType = "A"; 160 | } 161 | } 162 | else 163 | { 164 | if(curType.equals("A")) 165 | { 166 | String author = Id2Author.get(curNodeId); 167 | if(i>0) 168 | path.append(" "); 169 | path.append(author); 170 | if(authorNeighbors.get(curNodeId).size()==0) 171 | return path.toString(); 172 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 173 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 174 | lastType = curType; 175 | curType = "P"; 176 | } 177 | else 178 | { 179 | String venue = Id2Venue.get(curNodeId); 180 | if(i>0) 181 | path.append(" "); 182 | path.append(venue); 183 | if(venueNeighbors.get(curNodeId).size()==0) 184 | return path.toString(); 185 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 186 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 187 | lastType = curType; 188 | curType = "P"; 189 | } 190 | } 191 | } 192 | return path.toString(); 193 | } 194 | 195 | public static String metaPathAPAPARandomWalk(int startNodeId, String type, int pathLen) 196 | { 197 | int curNodeId = startNodeId; 198 | String curType = type; 199 | Random random = new Random(); 200 | StringBuffer path = new 
StringBuffer(""); 201 | for(int i=0;i0) 207 | path.append(" "); 208 | path.append(paper); 209 | if(paperNeighborsAuthor.get(curNodeId).size()<=1) 210 | return path.toString(); 211 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size()); 212 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex); 213 | curType = "A"; 214 | } 215 | else 216 | { 217 | String author = Id2Author.get(curNodeId); 218 | if(i>0) 219 | path.append(" "); 220 | path.append(author); 221 | if(authorNeighbors.get(curNodeId).size()==0) 222 | return path.toString(); 223 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 224 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 225 | curType = "P"; 226 | } 227 | } 228 | return path.toString(); 229 | } 230 | 231 | public static String uniformRandomWalk(int startNodeId, String type, int pathLen) 232 | { 233 | int curNodeId = startNodeId; 234 | String curType = type; 235 | Random random = new Random(); 236 | StringBuffer path = new StringBuffer(""); 237 | for(int i=0;i0) 243 | path.append(" "); 244 | path.append(paper); 245 | int authorSize = paperNeighborsAuthor.get(curNodeId).size(); 246 | int paperSize = paperNeighborsPaper.get(curNodeId).size(); 247 | int venueSize = paperNeighborsVenue.get(curNodeId).size(); 248 | if(authorSize+paperSize+venueSize==0) 249 | return path.toString(); 250 | int nextIndex = random.nextInt(authorSize+paperSize+venueSize); 251 | if(nextIndex0) 273 | path.append(" "); 274 | path.append(author); 275 | if(authorNeighbors.get(curNodeId).size()==0) 276 | return path.toString(); 277 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size()); 278 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex); 279 | curType = "P"; 280 | } 281 | else if(curType.equals("V")) 282 | { 283 | String venue = Id2Venue.get(curNodeId); 284 | if(i>0) 285 | path.append(" "); 286 | path.append(venue); 287 | if(venueNeighbors.get(curNodeId).size()==0) 288 | return path.toString(); 289 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size()); 290 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex); 291 | curType = "P"; 292 | } 293 | } 294 | } 295 | return path.toString(); 296 | } 297 | 298 | public static void main(String[] args) throws IOException { 299 | // TODO Auto-generated method stub 300 | BufferedReader brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 301 | BufferedReader brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 302 | BufferedReader brPaperPaper = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8")); 303 | 304 | author2Id = new HashMap(); 305 | venue2Id = new HashMap(); 306 | paper2Id = new HashMap(); 307 | Id2Author = new HashMap(); 308 | Id2Venue = new HashMap(); 309 | Id2Paper = new HashMap(); 310 | 311 | String line = null; 312 | int paperIndex = 0, authorIndex = 0, venueIndex = 0; 313 | while((line=brPaperAuthor.readLine())!=null) 314 | { 315 | int splitPos = line.indexOf("\t"); 316 | String paper = line.substring(0, splitPos); 317 | String author = line.substring(splitPos+1); 318 | if(!paper2Id.containsKey(paper)) 319 | { 320 | paper2Id.put(paper, paperIndex); 321 | Id2Paper.put(paperIndex, paper); 322 | paperIndex++; 323 | } 324 | if(!author2Id.containsKey(author)) 325 | { 326 | author2Id.put(author, 
authorIndex); 327 | Id2Author.put(authorIndex, author); 328 | authorIndex++; 329 | } 330 | } 331 | brPaperAuthor.close(); 332 | while((line=brPaperVenue.readLine())!=null) 333 | { 334 | int splitPos = line.indexOf("\t"); 335 | String venue = line.substring(splitPos+1); 336 | if(!venue2Id.containsKey(venue)) 337 | { 338 | venue2Id.put(venue, venueIndex); 339 | Id2Venue.put(venueIndex, venue); 340 | venueIndex++; 341 | } 342 | } 343 | brPaperVenue.close(); 344 | int numOfPaper = paperIndex; 345 | int numOfAuthor = authorIndex; 346 | int numOfVenue = venueIndex; 347 | 348 | authorNeighbors = new ArrayList>(); 349 | venueNeighbors = new ArrayList>(); 350 | paperNeighborsAuthor = new ArrayList>(); 351 | paperNeighborsPaper = new ArrayList>(); 352 | paperNeighborsVenue = new ArrayList>(); 353 | 354 | for(int i=0;i neighbors = new ArrayList(); 357 | authorNeighbors.add(neighbors); 358 | } 359 | for(int i=0;i neighbors = new ArrayList(); 362 | venueNeighbors.add(neighbors); 363 | } 364 | for(int i=0;i neighborsAuthor = new ArrayList(); 367 | paperNeighborsAuthor.add(neighborsAuthor); 368 | ArrayList neighborsPaper = new ArrayList(); 369 | paperNeighborsPaper.add(neighborsPaper); 370 | ArrayList neighborsVenue = new ArrayList(); 371 | paperNeighborsVenue.add(neighborsVenue); 372 | } 373 | 374 | brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8")); 375 | brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8")); 376 | while((line=brPaperAuthor.readLine())!=null) 377 | { 378 | int splitPos = line.indexOf("\t"); 379 | String paper = line.substring(0, splitPos); 380 | String author = line.substring(splitPos+1); 381 | int paperId = paper2Id.get(paper); 382 | int authorId = author2Id.get(author); 383 | if(!authorNeighbors.get(authorId).contains(paperId)) 384 | authorNeighbors.get(authorId).add(paperId); 385 | if(!paperNeighborsAuthor.get(paperId).contains(authorId)) 386 | paperNeighborsAuthor.get(paperId).add(authorId); 387 | } 388 | brPaperAuthor.close(); 389 | while((line=brPaperPaper.readLine())!=null) 390 | { 391 | int splitPos = line.indexOf("\t"); 392 | String paper1 = line.substring(0, splitPos); 393 | String paper2 = line.substring(splitPos+1); 394 | int paperId1 = paper2Id.get(paper1); 395 | int paperId2 = paper2Id.get(paper2); 396 | if(!paperNeighborsPaper.get(paperId1).contains(paperId2)) 397 | paperNeighborsPaper.get(paperId1).add(paperId2); 398 | if(!paperNeighborsPaper.get(paperId2).contains(paperId1)) 399 | paperNeighborsPaper.get(paperId2).add(paperId1); 400 | } 401 | brPaperPaper.close(); 402 | while((line=brPaperVenue.readLine())!=null) 403 | { 404 | int splitPos = line.indexOf("\t"); 405 | String paper = line.substring(0, splitPos); 406 | String venue = line.substring(splitPos+1); 407 | int paperId = paper2Id.get(paper); 408 | int venueId = venue2Id.get(venue); 409 | if(!venueNeighbors.get(venueId).contains(paperId)) 410 | venueNeighbors.get(venueId).add(paperId); 411 | if(!paperNeighborsVenue.get(paperId).contains(venueId)) 412 | paperNeighborsVenue.get(paperId).add(venueId); 413 | } 414 | brPaperVenue.close(); 415 | 416 | int[] arrPaper1 = new int[numOfPaper]; 417 | for(int i=0;i0;i--) { 421 | int randNumPaper1 = randomPaper1.nextInt(i); 422 | int tempPaper1 = arrPaper1[i]; 423 | arrPaper1[i] = arrPaper1[randNumPaper1]; 424 | arrPaper1[randNumPaper1] = tempPaper1; 425 | } 426 | 427 | for(int i=0;i() { 449 | public int 
compare(Integer a, Integer b) { 450 | return a.compareTo(b); 451 | } 452 | }); 453 | } 454 | for(int i=0;i() { 457 | public int compare(Integer a, Integer b) { 458 | return a.compareTo(b); 459 | } 460 | }); 461 | } 462 | 463 | for(int i=0;i() { 466 | public int compare(Integer a, Integer b) { 467 | return a.compareTo(b); 468 | } 469 | }); 470 | Collections.sort(paperNeighborsPaper.get(i), new Comparator() { 471 | public int compare(Integer a, Integer b) { 472 | return a.compareTo(b); 473 | } 474 | }); 475 | Collections.sort(paperNeighborsVenue.get(i), new Comparator() { 476 | public int compare(Integer a, Integer b) { 477 | return a.compareTo(b); 478 | } 479 | }); 480 | } 481 | 482 | BufferedWriter bwMetaGraphRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.graph.random.walk.txt"), "UTF-8")); 483 | BufferedWriter bwMetaPathRandWalkAPVPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apvpa.txt"), "UTF-8")); 484 | BufferedWriter bwMetaPathRandWalkAPAPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apapa.txt"), "UTF-8")); 485 | BufferedWriter bwMetaPathRandWalkMix = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.mix.txt"), "UTF-8")); 486 | BufferedWriter bwUniformRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.uniform.random.walk.txt"), "UTF-8")); 487 | int pathLen = 100; 488 | int numOfWalk = 80; 489 | for(int i=0;i0;i--) { 522 | int randNum1 = random1.nextInt(i); 523 | int temp1 = arr1[i]; 524 | arr1[i] = arr1[randNum1]; 525 | arr1[randNum1] = temp1; 526 | } 527 | for(int i=0;i 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | // The metapath2vec.cpp code was built upon the word2vec.c from https://code.google.com/archive/p/word2vec/ 18 | 19 | // Modifications Copyright (C) 2016 20 | // 21 | // Licensed under the Apache License, Version 2.0 (the "License"); 22 | // you may not use this file except in compliance with the License. 23 | // You may obtain a copy of the License at 24 | // 25 | // http://www.apache.org/licenses/LICENSE-2.0 26 | // 27 | // Unless required by applicable law or agreed to in writing, software 28 | // distributed under the License is distributed on an "AS IS" BASIS, 29 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | // See the License for the specific language governing permissions and 31 | // limitations under the License. 32 | 33 | // Copyright 2013 Google Inc. All Rights Reserved. 34 | // 35 | // Licensed under the Apache License, Version 2.0 (the "License"); 36 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>

#define MAX_STRING 100
#define EXP_TABLE_SIZE 1000
#define MAX_EXP 6
#define MAX_WALK_LENGTH 1000

const long long node_hash_size = 10000000; // Maximum 10M nodes
const long long node_context_hash_size = 1000000000; // Maximum 1G node context pairs

struct Node
{
	long long cn;
	char *node_str;
};

struct Node_Context
{
	long long cn;
	long long source, target;
};

char train_file[MAX_STRING], output_file[MAX_STRING];
struct Node *node_list;
long long window = 5;
long long pp = 1; // pp = 1 with Heterogeneous Skip-Gram, pp = 0 with Homogeneous Skip-Gram
long long *node_hash;
long long node_list_max_size = 100000, node_list_size = 0, layer1_size = 128;
long long node_occur_num = 0;
double alpha = 0.025, starting_alpha;
double *syn0, *syn1, *syn1neg, *expTable;
double **type_syn1;
clock_t start, finish;

long long type_num;
long long **type_tables, *type_counts, *type_indices;
char prefixes[MAX_STRING];
long long obj_type;

long long **type_node_hash;
long long *type_node_list_size, *type_node_index;
struct Node **type_node_list;

long long negative = 5;
const long table_size = 1e8;
long long *table;

long long *node_context_hash;
long long node_context_list_size, node_context_list_max_size;
struct Node_Context *node_context_list;

// Parameters for node context pair sampling
long long *alias;
double *prob;

long long total_samples = 100;

long long GetTypeId(char prefix)
{
	long long i;
	for (i = 0; i < type_num; i++)
	{
		if (prefix == prefixes[i])
			break;
	}
	if (i >= type_num)
		return -1;
	else
		return i;
}

void NodeTypeNeg()
{
	long long i;
	long long target, type_Id;
	type_tables = (long long **)malloc(type_num * sizeof(long long *));
	type_counts = (long long *)malloc(type_num * sizeof(long long));
	type_indices = (long long *)malloc(type_num * sizeof(long long));
	if (type_tables == NULL || type_counts == NULL || type_indices == NULL)
	{
		printf("Memory allocation failed\n");
		exit(1);
	}
	for (i = 0; i < type_num; i++)
	{
		type_counts[i] = 0;
		type_indices[i] = 0;
	}
	for (i = 0; i < table_size; i++)
	{
		target = table[i];
		type_Id = GetTypeId(node_list[target].node_str[0]);
		if (type_Id == -1)
		{
			printf("Unrecognised type\n");
			exit(1);
		}
		type_counts[type_Id]++;
	}
	for (i = 0; i < type_num; i++)
	{
		type_tables[i] = (long long *)malloc(type_counts[i] * sizeof(long long));
		if (type_tables[i] == NULL)
		{
			printf("Memory allocation failed\n");
			exit(1);
		}
	}
	for (i = 0; i < table_size; i++)
	{
		target = table[i];
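		// Second pass over the unigram table: scatter each entry into its
		// type-specific table, so negatives can later be drawn per node type.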
162 | type_Id = GetTypeId(node_list[target].node_str[0]); 163 | type_tables[type_Id][type_indices[type_Id]] = target; 164 | type_indices[type_Id]++; 165 | } 166 | for (i = 0; i < type_num; i++) 167 | printf("type %c table size: %lld\n", prefixes[i], type_counts[i]); 168 | } 169 | 170 | void InitUnigramTable() 171 | { 172 | long long a, i; 173 | double train_nodes_pow = 0.0, d1, power = 0.75; 174 | table = (long long *)malloc(table_size * sizeof(long long)); 175 | if (table == NULL) 176 | { 177 | printf("Memory allocation failed\n"); 178 | exit(1); 179 | } 180 | for (a = 0; a < node_list_size; a++) train_nodes_pow += pow(node_list[a].cn, power); 181 | i = 0; 182 | d1 = pow(node_list[i].cn, power) / train_nodes_pow; 183 | for (a = 0; a < table_size; a++) 184 | { 185 | table[a] = i; 186 | if (a / (double)table_size > d1) 187 | { 188 | i++; 189 | d1 += pow(node_list[i].cn, power) / (double)train_nodes_pow; 190 | } 191 | if (i >= node_list_size) i = node_list_size - 1; 192 | } 193 | if (pp == 1) 194 | NodeTypeNeg(); 195 | } 196 | 197 | // Reads a single node from a file, assuming space + tab + EOL to be node name boundaries 198 | long long ReadNode(char *node_str, FILE *fin) // return 1, if the boundary is '\n', otherwise 0 199 | { 200 | long long a = 0, ch, flag = 0; 201 | while (!feof(fin)) 202 | { 203 | ch = fgetc(fin); 204 | if (ch == 13) continue; 205 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) 206 | { 207 | if (a > 0) 208 | { 209 | if (ch == '\n') flag = 1; 210 | break; 211 | } 212 | continue; 213 | } 214 | node_str[a] = ch; 215 | a++; 216 | if (a >= MAX_STRING - 1) a--; 217 | } 218 | node_str[a] = 0; 219 | return flag; 220 | } 221 | 222 | // Return hash value of a node 223 | long long GetNodeHash(char *node_str) 224 | { 225 | long long a, hash = 0; 226 | for (a = 0; a < strlen(node_str); a++) 227 | { 228 | hash = hash * 257 + node_str[a]; 229 | hash = hash % node_hash_size; 230 | } 231 | return hash; 232 | } 233 | 234 | // Return hash value for a node context pair 235 | long long GetNodeContextHash(char *node_str, char *context_node_str) 236 | { 237 | long long a, hash = 0; 238 | for (a = 0; a < strlen(node_str); a++) 239 | { 240 | hash = hash * 257 + node_str[a]; 241 | hash = hash % node_context_hash_size; 242 | } 243 | for (a = 0; a < strlen(context_node_str); a++) 244 | { 245 | hash = hash * 257 + context_node_str[a]; 246 | hash = hash % node_context_hash_size; 247 | } 248 | return hash; 249 | } 250 | 251 | // Return position of a node in the node list; if the node is not found, returns -1 252 | long long SearchNode(char *node_str) 253 | { 254 | long long hash = GetNodeHash(node_str); 255 | while (1) 256 | { 257 | if (node_hash[hash] == -1) return -1; 258 | if (!strcmp(node_str, node_list[node_hash[hash]].node_str)) return node_hash[hash]; 259 | hash = (hash + 1) % node_hash_size; 260 | } 261 | return -1; 262 | } 263 | 264 | // Return position of a node in the node list of the specific type; if the node is not found, returns -1 265 | long long TypeSearchNode(char *node_str, long long type_Id) 266 | { 267 | long long hash = GetNodeHash(node_str); 268 | while (1) 269 | { 270 | if (type_node_hash[type_Id][hash] == -1) return -1; 271 | if (!strcmp(node_str, type_node_list[type_Id][type_node_hash[type_Id][hash]].node_str)) return type_node_hash[type_Id][hash]; 272 | hash = (hash + 1) % node_hash_size; 273 | } 274 | return -1; 275 | } 276 | 277 | // Return position of a node context pair in the node context list; if the node is not found, returns -1 278 | long long 
SearchNodeContextPair(long long node, long long context) 279 | { 280 | long long hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str); 281 | long long hash_iter = 0; 282 | long long cur_node, cur_context; 283 | while(1) 284 | { 285 | if (node_context_hash[hash] == -1) return -1; 286 | cur_node = node_context_list[node_context_hash[hash]].source; 287 | cur_context = node_context_list[node_context_hash[hash]].target; 288 | if (cur_node == node && cur_context == context) return node_context_hash[hash]; 289 | hash = (hash + 1) % node_context_hash_size; 290 | hash_iter++; 291 | if (hash_iter >= node_context_hash_size) 292 | { 293 | printf("The node context hash table is full!\n"); 294 | exit(1); 295 | } 296 | } 297 | return -1; 298 | } 299 | 300 | // Adds a node to the node list 301 | long long AddNodeToList(char *node_str) 302 | { 303 | long long hash; 304 | long long length = strlen(node_str) + 1; 305 | if (length > MAX_STRING) length = MAX_STRING; 306 | node_list[node_list_size].node_str = (char *)calloc(length, sizeof(char)); 307 | if (node_list[node_list_size].node_str == NULL) 308 | { 309 | printf("Memory allocation failed\n"); 310 | exit(1); 311 | } 312 | strcpy(node_list[node_list_size].node_str, node_str); 313 | node_list[node_list_size].cn = 1; 314 | node_list_size++; 315 | // Reallocate memory if needed 316 | if (node_list_size >= node_list_max_size) 317 | { 318 | node_list_max_size += 10000; 319 | node_list = (struct Node *)realloc(node_list, node_list_max_size * sizeof(struct Node)); 320 | if (node_list == NULL) 321 | { 322 | printf("Memory allocation failed\n"); 323 | exit(1); 324 | } 325 | } 326 | hash = GetNodeHash(node_str); 327 | while (node_hash[hash] != -1) hash = (hash + 1) % node_hash_size; 328 | node_hash[hash] = node_list_size - 1; 329 | return node_list_size - 1; 330 | } 331 | 332 | //Add node context pair to the node context list 333 | long long AddNodeContextToList(long long node, long long context) 334 | { 335 | long long hash; 336 | long long hash_iter = 0; 337 | node_context_list[node_context_list_size].source = node; 338 | node_context_list[node_context_list_size].target = context; 339 | node_context_list[node_context_list_size].cn = 1; 340 | node_context_list_size++; 341 | if (node_context_list_size >= node_context_list_max_size) 342 | { 343 | node_context_list_max_size += 100 * node_list_size; 344 | node_context_list = (struct Node_Context *)realloc(node_context_list, node_context_list_max_size * sizeof(struct Node_Context)); 345 | if (node_context_list == NULL) 346 | { 347 | printf("Memory allocation failed\n"); 348 | exit(1); 349 | } 350 | } 351 | hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str); 352 | while (node_context_hash[hash] != -1) 353 | { 354 | hash = (hash + 1) % node_context_hash_size; 355 | hash_iter++; 356 | if (hash_iter >= node_context_hash_size) 357 | { 358 | printf("The node context hash table is full!\n"); 359 | exit(1); 360 | } 361 | } 362 | node_context_hash[hash] = node_context_list_size - 1; 363 | return node_context_list_size - 1; 364 | } 365 | 366 | void LearnNodeListFromTrainFile() 367 | { 368 | char node_str[MAX_STRING]; 369 | FILE *fin; 370 | long long a, i; 371 | for (a = 0; a < node_hash_size; a++) node_hash[a] = -1; 372 | fin = fopen(train_file, "r"); 373 | if (fin == NULL) 374 | { 375 | printf("ERROR: training data file not found!\n"); 376 | exit(1); 377 | } 378 | node_list_size = 0; 379 | printf("Reading nodes...\n"); 380 | while (1) 381 | { 382 | ReadNode(node_str, 
fin); 383 | if (feof(fin)) break; 384 | node_occur_num++; 385 | if (node_occur_num % 100000 == 0) 386 | { 387 | printf("Read nodes: %lldK%c", node_occur_num / 1000, 13); 388 | fflush(stdout); 389 | } 390 | i = SearchNode(node_str); 391 | if (i == -1) AddNodeToList(node_str); 392 | else node_list[i].cn++; 393 | } 394 | printf("Node list size: %lld \n", node_list_size); 395 | printf("Node occurrence times in train file: %lld\n", node_occur_num); 396 | fclose(fin); 397 | } 398 | 399 | void GetNodeContextTable() 400 | { 401 | char node_str[MAX_STRING]; 402 | long long a, b, c, node, context_node, walk_length = 0, walk_position = 0; 403 | long long walk[MAX_WALK_LENGTH + 1]; 404 | long long line_num = 0; 405 | long long end_flag; 406 | FILE *fi = fopen(train_file, "r"); 407 | node_context_hash = (long long *)malloc(node_context_hash_size * sizeof(long long)); 408 | node_context_list_max_size = 1000 * node_list_size; 409 | node_context_list = (struct Node_Context *)malloc(node_context_list_max_size * sizeof(struct Node_Context)); 410 | for (a = 0; a < node_context_hash_size; a++) node_context_hash[a] = -1; 411 | node_context_list_size = 0; 412 | printf("Collecting node context pair...\n"); 413 | start = clock(); 414 | while (1) 415 | { 416 | if (walk_length == 0) 417 | { 418 | while (1) 419 | { 420 | end_flag = ReadNode(node_str, fi); 421 | if (feof(fi)) break; 422 | node = SearchNode(node_str); 423 | //if (node == -1) continue; 424 | walk[walk_length] = node; 425 | walk_length++; 426 | if (end_flag || walk_length >= MAX_WALK_LENGTH) break; 427 | } 428 | walk_position = 0; 429 | line_num++; 430 | if (line_num % 1000 == 0) 431 | { 432 | printf("Processed lines: %lldK%c", line_num / 1000, 13); 433 | fflush(stdout); 434 | } 435 | } 436 | if (feof(fi) && walk_length == 0) break; 437 | //if (walk_length == 0) continue; 438 | node = walk[walk_position]; //current node 439 | for (a = 0; a < window * 2 + 1; a++) 440 | { 441 | if (a != window) 442 | { 443 | c = walk_position - window + a; 444 | if (c < 0) continue; 445 | if (c >= walk_length) continue; 446 | context_node = walk[c]; 447 | b = SearchNodeContextPair(node, context_node); 448 | if (b == -1) AddNodeContextToList(node, context_node); 449 | else node_context_list[b].cn++; 450 | } 451 | } 452 | walk_position++; 453 | if (walk_position >= walk_length) walk_length = 0; 454 | } 455 | finish = clock(); 456 | printf("Total time: %lf secs for collecting node context pairs\n", (double)(finish - start) / CLOCKS_PER_SEC); 457 | printf("Node context pair number: %lld\n", node_context_list_size); 458 | printf("----------------------------------------------------\n"); 459 | } 460 | 461 | void InitNet() 462 | { 463 | long long a, b; 464 | unsigned long long next_random = 1; 465 | syn0 = (double *)malloc(node_list_size * layer1_size * sizeof(double)); 466 | if (syn0 == NULL) 467 | { 468 | printf("Memory allocation failed\n"); 469 | exit(1); 470 | } 471 | for (a = 0; a < node_list_size; a++) 472 | for (b = 0; b < layer1_size; b++) 473 | { 474 | next_random = next_random * (unsigned long long)25214903917 + 11; 475 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (double)65536) - 0.5) / layer1_size; 476 | } 477 | syn1neg = (double *)malloc(node_list_size * layer1_size * sizeof(double)); 478 | if (syn1neg == NULL) 479 | { 480 | printf("Memory allocation failed\n"); 481 | exit(1); 482 | } 483 | for (a = 0; a < node_list_size; a++) 484 | for (b = 0; b < layer1_size; b++) 485 | syn1neg[a * layer1_size + b] = 0; 486 | InitUnigramTable(); 487 | } 488 | 489 | /* 
The alias sampling algorithm, which is used to sample an node context pair in O(1) time. */ 490 | void InitAliasTable() 491 | { 492 | long long k; 493 | double sum = 0; 494 | long long cur_small_block, cur_large_block; 495 | long long num_small_block = 0, num_large_block = 0; 496 | double *norm_prob; 497 | long long *large_block; 498 | long long *small_block; 499 | alias = (long long *)malloc(node_context_list_size * sizeof(long long)); 500 | prob = (double *)malloc(node_context_list_size * sizeof(double)); 501 | if (alias == NULL || prob == NULL) 502 | { 503 | printf("Memory allocation failed\n"); 504 | exit(1); 505 | } 506 | norm_prob = (double*)malloc(node_context_list_size * sizeof(double)); 507 | large_block = (long long*)malloc(node_context_list_size * sizeof(long long)); 508 | small_block = (long long*)malloc(node_context_list_size * sizeof(long long)); 509 | if (norm_prob == NULL || large_block == NULL || small_block == NULL) 510 | { 511 | printf("Memory allocation failed\n"); 512 | exit(1); 513 | } 514 | for (k = 0; k != node_context_list_size; k++) sum += node_context_list[k].cn; 515 | for (k = 0; k != node_context_list_size; k++) norm_prob[k] = (double)node_context_list[k].cn * node_context_list_size / sum; 516 | for (k = node_context_list_size - 1; k >= 0; k--) 517 | { 518 | if (norm_prob[k]<1) 519 | small_block[num_small_block++] = k; 520 | else 521 | large_block[num_large_block++] = k; 522 | } 523 | while (num_small_block && num_large_block) 524 | { 525 | cur_small_block = small_block[--num_small_block]; 526 | cur_large_block = large_block[--num_large_block]; 527 | prob[cur_small_block] = norm_prob[cur_small_block]; 528 | alias[cur_small_block] = cur_large_block; 529 | norm_prob[cur_large_block] = norm_prob[cur_large_block] + norm_prob[cur_small_block] - 1; 530 | if (norm_prob[cur_large_block] < 1) 531 | small_block[num_small_block++] = cur_large_block; 532 | else 533 | large_block[num_large_block++] = cur_large_block; 534 | } 535 | while (num_large_block) prob[large_block[--num_large_block]] = 1; 536 | while (num_small_block) prob[small_block[--num_small_block]] = 1; 537 | free(norm_prob); 538 | free(small_block); 539 | free(large_block); 540 | } 541 | 542 | long long SampleAPair(double rand_value1, double rand_value2) 543 | { 544 | long long k = (long long)node_context_list_size * rand_value1; 545 | return rand_value2 < prob[k] ? 
k : alias[k]; 546 | } 547 | 548 | void TrainModel() 549 | { 550 | long a, b, c, d; 551 | long long node, context_node, cur_pair; 552 | long long count = 0, last_count = 0; 553 | long long l1, l2, target, label; 554 | long long type_Id; 555 | unsigned long long next_random = 1; 556 | double rand_num1, rand_num2; 557 | double f, g; 558 | double *neu1e = (double *)calloc(layer1_size, sizeof(double)); 559 | FILE *fp; 560 | InitNet(); 561 | starting_alpha = alpha; 562 | printf("Skip-Gram model with Negative Sampling:"); 563 | if (pp == 1) printf(" heterogeneous version\n"); 564 | else printf(" homogeneous version\n"); 565 | printf("Training file: %s\n", train_file); 566 | printf("Samples: %lldM\n", total_samples / 1000000); 567 | printf("Dimension: %lld\n", layer1_size); 568 | printf("Initial Alpha: %f\n", alpha); 569 | start = clock(); 570 | srand((unsigned)time(NULL)); 571 | while (1) 572 | { 573 | if (count >= total_samples) break; 574 | if (count - last_count > 10000) 575 | { 576 | last_count = count; 577 | printf("Alpha: %f Progress %.3lf%%%c", alpha, (double)count / (double)(total_samples + 1) * 100, 13); 578 | fflush(stdout); 579 | alpha = starting_alpha * (1 - (double)count / (double)(total_samples + 1)); 580 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001; 581 | } 582 | rand_num1 = rand() / (RAND_MAX * 1.0 + 1); 583 | rand_num2 = rand() / (RAND_MAX * 1.0 + 1); 584 | cur_pair = SampleAPair(rand_num1, rand_num2); 585 | node = node_context_list[cur_pair].source; 586 | context_node = node_context_list[cur_pair].target; 587 | l1 = node * layer1_size; 588 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 589 | 590 | // NEGATIVE SAMPLING 591 | for (d = 0; d < negative + 1; d++) 592 | { 593 | if (d == 0) 594 | { 595 | target = context_node; 596 | label = 1; 597 | } 598 | else 599 | { 600 | next_random = next_random * (unsigned long long)25214903917 + 11; 601 | if (pp == 1) 602 | { 603 | type_Id = GetTypeId(node_list[context_node].node_str[0]); 604 | target = type_tables[type_Id][(next_random >> 16) % type_counts[type_Id]]; 605 | } 606 | else 607 | target = table[(next_random >> 16) % table_size]; 608 | if (target == context_node) continue; 609 | label = 0; 610 | } 611 | l2 = target * layer1_size; 612 | f = 0; 613 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2]; 614 | if (f > MAX_EXP) f = 1; 615 | else if (f < -MAX_EXP) f = 0; 616 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 617 | g = (label - f) * alpha; 618 | // Propagate errors output -> hidden 619 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 620 | // Learn weights hidden -> output 621 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1]; 622 | } 623 | // Learn weights input -> hidden 624 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c]; 625 | count++; 626 | } 627 | finish = clock(); 628 | printf("Total time: %lf secs for learning node embeddings\n", (double)(finish - start) / CLOCKS_PER_SEC); 629 | printf("----------------------------------------------------\n"); 630 | fp = fopen(output_file, "w"); 631 | // Save the learned node representations 632 | for (a = 0; a < node_list_size; a++) 633 | { 634 | if (node_list[a].node_str[0] != prefixes[obj_type]) continue; 635 | fprintf(fp, "%s", node_list[a].node_str); 636 | for (b = 0; b < layer1_size; b++) fprintf(fp, " %lf", syn0[a * layer1_size + b]); 637 | fprintf(fp, "\n"); 638 | } 639 | free(neu1e); 640 | fclose(fp); 641 | } 642 | 643 | int ArgPos(char *str, int 
argc, char **argv)
{
	int a;
	for (a = 1; a < argc; a++) if (!strcmp(str, argv[a]))
	{
		if (a == argc - 1)
		{
			printf("Argument missing for %s\n", str);
			exit(1);
		}
		return a;
	}
	return -1;
}

int main(int argc, char **argv)
{
	int i;
	if (argc == 1)
	{
		printf("---randomWalk2vec---\n");
		printf("---The code and following instructions are built upon word2vec.c by Mikolov et al.---\n\n");
		printf("Options:\n");
		printf("Parameters for training:\n");
		printf("\t-train <file>\n");
		printf("\t\tUse random walk sequences from <file> to train the model\n");
		printf("\t-output <file>\n");
		printf("\t\tUse <file> to save the learned node vector-format representations\n");
		printf("\t-size <int>\n");
		printf("\t\tSet the dimension of learned node representations; default is 128\n");
		printf("\t-window <int>\n");
		printf("\t\tSet the window size for collecting node context pairs; default is 5\n");
		printf("\t-pp <int>\n");
		printf("\t\tUse the Heterogeneous Skip-Gram model (the ++ version) or not; default is 1 (++ version); use 0 for Homogeneous Skip-Gram\n");
		printf("\t-prefixes <string>\n");
		printf("\t\tPrefixes of node Ids for specifying node types, e.g., ap with a for author and p for paper\n");
		printf("\t-objtype\n");
		printf("\t\tThe index of the objective node type in the prefixes list, for which representations are to be learned\n");
		printf("\t-alpha <float>\n");
		printf("\t\tSet the starting learning rate; default is 0.025\n");
		printf("\t-negative <int>\n");
		printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10\n");
		printf("\t-samples <int>\n");
		printf("\t\tSet the number of stochastic gradient descent iterations, in millions; default is 100\n");
		return 0;
	}
	if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
	if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
	if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
	if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-pp", argc, argv)) > 0) pp = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-prefixes", argc, argv)) > 0) strcpy(prefixes, argv[i + 1]);
	if ((i = ArgPos((char *)"-objtype", argc, argv)) > 0) obj_type = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
	if ((i = ArgPos((char *)"-samples", argc, argv)) > 0) total_samples = atoi(argv[i + 1]);
	total_samples = total_samples * 1000000;
	type_num = strlen(prefixes);
	printf("Number of node types: %lld\n", type_num);
	node_list = (struct Node *)calloc(node_list_max_size, sizeof(struct Node));
	node_hash = (long long *)calloc(node_hash_size, sizeof(long long));
	expTable = (double *)malloc((EXP_TABLE_SIZE + 1) * sizeof(double));
	if (node_list == NULL || node_hash == NULL || expTable == NULL)
	{
		printf("Memory allocation failed\n");
		exit(1);
	}
	for (i = 0; i < EXP_TABLE_SIZE; i++)
	{
		expTable[i] = exp((i / (double)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
		expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1)
	}
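	// Training pipeline: build the node vocabulary from the walk file, count
	// (node, context) pairs within the window, build the alias table over
	// those pairs, then run negative-sampling stochastic gradient descent.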
LearnNodeListFromTrainFile(); 716 | GetNodeContextTable(); 717 | InitAliasTable(); 718 | TrainModel(); 719 | return 0; 720 | } 721 | --------------------------------------------------------------------------------