├── .DS_Store
├── .gitattributes
├── DBLPPreProcess
│   ├── .DS_Store
│   ├── .classpath
│   ├── .project
│   ├── .settings
│   │   └── org.eclipse.jdt.core.prefs
│   ├── bin
│   │   └── .DS_Store
│   └── src
│       ├── PreProcessStep1.java
│       ├── PreProcessStep2.java
│       └── PreProcessStep3.java
├── README.md
└── RandomWalk2Vec
    ├── RandomWalk2VecRun.sh
    ├── compile.sh
    └── randomWalk2vec.c
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daokunzhang/MetaGraph2Vec/7fe0fab3d0948aca69e81b5d4e230927d75676d8/.DS_Store
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/DBLPPreProcess/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daokunzhang/MetaGraph2Vec/7fe0fab3d0948aca69e81b5d4e230927d75676d8/DBLPPreProcess/.DS_Store
--------------------------------------------------------------------------------
/DBLPPreProcess/.classpath:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <classpath>
3 | <classpathentry kind="src" path="src"/>
4 | <classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8"/>
5 | <classpathentry kind="output" path="bin"/>
6 | </classpath>
7 |
--------------------------------------------------------------------------------
/DBLPPreProcess/.project:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <projectDescription>
3 | <name>DBLPPreProcess</name>
4 | <comment></comment>
5 | <projects>
6 | </projects>
7 | <buildSpec>
8 | <buildCommand>
9 | <name>org.eclipse.jdt.core.javabuilder</name>
10 | <arguments>
11 | </arguments>
12 | </buildCommand>
13 | </buildSpec>
14 | <natures>
15 | <nature>org.eclipse.jdt.core.javanature</nature>
16 | </natures>
17 | </projectDescription>
18 |
--------------------------------------------------------------------------------
/DBLPPreProcess/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
5 | org.eclipse.jdt.core.compiler.compliance=1.8
6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
11 | org.eclipse.jdt.core.compiler.source=1.8
12 |
--------------------------------------------------------------------------------
/DBLPPreProcess/bin/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daokunzhang/MetaGraph2Vec/7fe0fab3d0948aca69e81b5d4e230927d75676d8/DBLPPreProcess/bin/.DS_Store
--------------------------------------------------------------------------------
/DBLPPreProcess/src/PreProcessStep1.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.BufferedWriter;
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.InputStreamReader;
8 | import java.io.OutputStreamWriter;
9 | import java.util.ArrayList;
10 |
11 | public class PreProcessStep1 { // extract the subgraph from the DBLP-Citation-Network-V3.txt network
12 |
13 | public static int nonAlphabet(char ch)
14 | {
15 | if(ch>='A'&&ch<='Z')
16 | return 0;
17 | if(ch>='a'&&ch<='z')
18 | return 0;
19 | return 1;
20 | }
21 |
22 | public static int check(String venue)
23 | {
24 | String[] selectedVenues = {"SIGMOD", "ICDE", "VLDB", "EDBT",
25 | "PODS", "ICDT", "DASFAA", "SSDBM", "CIKM", "KDD",
26 | "ICDM", "SDM", "PKDD", "PAKDD", "IJCAI", "AAAI",
27 | "NIPS", "ICML", "ECML", "ACML", "IJCNN", "UAI",
28 | "ECAI", "COLT", "ACL", "KR", "CVPR", "ICCV",
29 | "ECCV", "ACCV", "MM", "ICPR", "ICIP", "ICME"};
30 | int res = 0;
31 | for(int i=0;i<selectedVenues.length;i++)
32 | {
33 | int index = venue.indexOf(selectedVenues[i]);
34 | if(index!=-1)
35 | {
36 | if(index==0)
37 | {
38 | if(index+selectedVenues[i].length()>=venue.length())
39 | res = 1;
40 | else
41 | {
42 | char endChar = venue.charAt(index+selectedVenues[i].length());
43 | if(nonAlphabet(endChar)==1)
44 | res = 1;
45 | }
46 | }
47 | else
48 | {
49 | char startChar = venue.charAt(index-1);
50 | if(nonAlphabet(startChar)==1)
51 | {
52 | if(index+selectedVenues[i].length()>=venue.length())
53 | res = 1;
54 | else
55 | {
56 | char endChar = venue.charAt(index+selectedVenues[i].length());
57 | if(nonAlphabet(endChar)==1)
58 | res = 1;
59 | }
60 | }
61 | }
62 | }
63 | if(res==1)
64 | break;
65 | }
66 | if(res==1)
67 | {
68 | if(venue.indexOf("SIGMOD")!=-1&&venue.indexOf("Data Mining")!=-1)
69 | res = 0;
70 | }
71 | return res;
72 | }
73 |
74 | public static void main(String[] args) throws IOException {
75 | // TODO Auto-generated method stub
76 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-Citation-Network-V3.txt"), "UTF-8"));
77 | BufferedWriter bwTitle = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.title.txt"), "UTF-8"));
78 | BufferedWriter bwAuthor = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8"));
79 | BufferedWriter bwVenue = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8"));
80 | BufferedWriter bwRef = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.ref.txt"), "UTF-8"));
81 | int numOfNode = 0;
82 | String line = null;
83 | br.readLine();
84 | String curIndex = "";
85 | String curTitle = "";
86 | String curAuthors = "";
87 | String[] curAuthorList = null;
88 | String curVenue = "";
89 | ArrayList<String> curRef = new ArrayList<String>();
90 | int writeStatus = 0;
91 | int maxRefSize = 0;
92 | String maxRefIndex = "";
93 | while((line=br.readLine())!=null)
94 | {
95 | if(line.indexOf("#*")==0)
96 | curTitle = line.substring(2).trim();
97 | if(line.indexOf("#c")==0)
98 | {
99 | curVenue = line.substring(2).trim();
100 | curVenue = curVenue.replaceAll(" ", "");
101 | if(check(curVenue)==1)
102 | writeStatus = 1;
103 | else
104 | writeStatus = 0;
105 | }
106 | if(line.indexOf("#%")==0)
107 | curRef.add(line.substring(2).trim());
108 | if(line.indexOf("#@")==0)
109 | {
110 | curAuthors = line.substring(2).trim();
111 | curAuthorList = curAuthors.split(",");
112 | }
113 | if(line.indexOf("#index")==0)
114 | curIndex = line.substring(6).trim();
115 | if(line.trim().length()==0)
116 | {
117 | if((!curTitle.equals(""))&&(!curVenue.equals(""))&&(!curAuthors.equals(""))&&(!curIndex.equals(""))&&writeStatus==1)
118 | {
119 | curIndex = curIndex.replaceAll(" ", "");
120 | bwTitle.write("p_"+curIndex+"\t"+curTitle+"\n");
121 | for(int i=0;i<curAuthorList.length;i++)
122 | {
123 | String curAuthor = curAuthorList[i].trim();
124 | if(curAuthor.equals(""))
125 | continue;
126 | bwAuthor.write("p_"+curIndex+"\t"+"a_"+curAuthor+"\n");
127 | }
128 | bwVenue.write("p_"+curIndex+"\t"+"v_"+curVenue+"\n");
129 | for(int i=0;i<curRef.size();i++)
130 | {
131 | String curRefIndex = curRef.get(i).replaceAll(" ", "");
132 | bwRef.write("p_"+curIndex+"\t"+"p_"+curRefIndex+"\n");
133 | }
134 |
135 |
136 | if(curRef.size()>maxRefSize)
137 | {
138 | maxRefSize = curRef.size();
139 | maxRefIndex = curIndex;
140 | }
141 | numOfNode++;
142 | }
143 | curTitle = "";
144 | curAuthors = "";
145 | curVenue = "";
146 | curRef.clear();
147 | curIndex = "";
148 | }
149 | }
150 | br.close();
151 | bwTitle.close();
152 | bwAuthor.close();
153 | bwVenue.close();
154 | bwRef.close();
155 | System.out.println("numOfNode = "+numOfNode);
156 | System.out.println("maxRefSize = "+maxRefSize);
157 | System.out.println("maxRefIndex = "+maxRefIndex);
158 | }
159 |
160 | }
161 |
--------------------------------------------------------------------------------
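A note on the venue filter above: the boundary test in PreProcessStep1.check() is what keeps a short acronym such as "MM" from matching inside a longer venue string. The following minimal Java sketch re-expresses that rule; Character.isLetter() stands in for the original nonAlphabet() helper, and unlike check() it scans every occurrence of the acronym rather than only the first, so it illustrates the idea rather than replacing the method.

public class VenueMatchDemo {
    // An acronym counts as a hit only when it is not flanked by letters.
    static boolean matchesWholeWord(String venue, String acronym) {
        int index = venue.indexOf(acronym);
        while (index != -1) {
            boolean startOk = index == 0 || !Character.isLetter(venue.charAt(index - 1));
            int end = index + acronym.length();
            boolean endOk = end >= venue.length() || !Character.isLetter(venue.charAt(end));
            if (startOk && endOk) return true;
            index = venue.indexOf(acronym, index + 1);
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(matchesWholeWord("MM2005", "MM"));  // true: a digit is a valid boundary
        System.out.println(matchesWholeWord("SIGCOMM", "MM")); // false: "MM" follows a letter
    }
}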
/DBLPPreProcess/src/PreProcessStep2.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.BufferedWriter;
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.InputStreamReader;
8 | import java.io.OutputStreamWriter;
9 | import java.util.ArrayList;
10 | import java.util.Collections;
11 | import java.util.Comparator;
12 | import java.util.HashMap;
13 | import java.util.Map;
14 | import java.util.Random;
15 |
16 | public class PreProcessStep2 { // obtain the class labels for authors;
17 |
18 | public static int getClass(String venue)
19 | {
20 | String[] venues1 = {"SIGMOD", "ICDE", "VLDB", "EDBT",
21 | "PODS", "ICDT", "DASFAA", "SSDBM", "CIKM"};
22 | String[] venues2 = {"KDD", "ICDM", "SDM", "PKDD", "PAKDD"};
23 | String[] venues3 = {"IJCAI", "AAAI", "NIPS", "ICML", "ECML",
24 | "ACML", "IJCNN", "UAI", "ECAI", "COLT", "ACL", "KR"};
25 | String[] venues4 = {"CVPR", "ICCV", "ECCV", "ACCV", "MM",
26 | "ICPR", "ICIP", "ICME"};
27 | ArrayList<ArrayList<String>> venueLists = new ArrayList<ArrayList<String>>();
28 | ArrayList venueList1 = new ArrayList();
29 | for(int i=0;i<venues1.length;i++)
30 | venueList1.add(venues1[i]);
31 | venueLists.add(venueList1);
32 | ArrayList<String> venueList2 = new ArrayList<String>();
33 | for(int i=0;i<venues2.length;i++)
34 | venueList2.add(venues2[i]);
35 | venueLists.add(venueList2);
36 | ArrayList<String> venueList3 = new ArrayList<String>();
37 | for(int i=0;i<venues3.length;i++)
38 | venueList3.add(venues3[i]);
39 | venueLists.add(venueList3);
40 | ArrayList<String> venueList4 = new ArrayList<String>();
41 | for(int i=0;i<venues4.length;i++)
42 | venueList4.add(venues4[i]);
43 | venueLists.add(venueList4);
44 |
45 | int res = -1;
46 | int[] arrClass = {0, 1, 2, 3};
47 | Random randomClass = new Random();
48 |
49 | for(int i=3;i>0;i--)
50 | {
51 | int randNumClass = randomClass.nextInt(i);
52 | int temp = arrClass[i];
53 | arrClass[i] = arrClass[randNumClass];
54 | arrClass[randNumClass] = temp;
55 | }
56 | for(int i=0;i<4;i++)
57 | {
58 | for(int j=0;j<venueLists.get(arrClass[i]).size();j++)
59 | {
60 | if(venue.indexOf(venueLists.get(arrClass[i]).get(j))!=-1)
61 | res = arrClass[i];
62 | }
63 | if(res!=-1)
64 | break;
65 | }
66 | return res;
67 | }
68 |
69 | public static void main(String[] args) throws IOException {
70 | BufferedReader brAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8"));
71 | BufferedReader brVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8"));
72 | BufferedReader brRef = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.ref.txt"), "UTF-8"));
73 | Map<String, Integer> author2Id = new HashMap<String, Integer>();
74 | Map<String, Integer> venue2Id = new HashMap<String, Integer>();
75 | Map<String, Integer> paper2Id = new HashMap<String, Integer>();
76 | Map<Integer, String> Id2Author = new HashMap<Integer, String>();
77 | Map<Integer, String> Id2Venue = new HashMap<Integer, String>();
78 | Map<Integer, String> Id2Paper = new HashMap<Integer, String>();
79 |
80 | String line = null;
81 | int paperIndex = 0, authorIndex = 0, venueIndex = 0;
82 | while((line=brAuthor.readLine())!=null)
83 | {
84 | int splitPos = line.indexOf("\t");
85 | String paper = line.substring(0, splitPos);
86 | String author = line.substring(splitPos+1);
87 | if(!paper2Id.containsKey(paper))
88 | {
89 | paper2Id.put(paper, paperIndex);
90 | Id2Paper.put(paperIndex, paper);
91 | paperIndex++;
92 | }
93 | if(!author2Id.containsKey(author))
94 | {
95 | author2Id.put(author, authorIndex);
96 | Id2Author.put(authorIndex, author);
97 | authorIndex++;
98 | }
99 | }
100 | brAuthor.close();
101 | while((line=brVenue.readLine())!=null)
102 | {
103 | int splitPos = line.indexOf("\t");
104 | String venue = line.substring(splitPos+1);
105 | if(!venue2Id.containsKey(venue))
106 | {
107 | venue2Id.put(venue, venueIndex);
108 | Id2Venue.put(venueIndex, venue);
109 | venueIndex++;
110 | }
111 | }
112 | brVenue.close();
113 | int numOfPaper = paperIndex;
114 | int numOfAuthor = authorIndex;
115 | int numOfVenue = venueIndex;
116 |
117 | BufferedWriter bwPaperPaper = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8"));
118 | while((line=brRef.readLine())!=null)
119 | {
120 | int splitPos = line.indexOf("\t");
121 | String paper1 = line.substring(0, splitPos);
122 | String paper2 = line.substring(splitPos+1);
123 | if(paper2Id.containsKey(paper2))
124 | bwPaperPaper.write(paper1+"\t"+paper2+"\n");
125 | }
126 | brRef.close();
127 | bwPaperPaper.close();
128 |
129 | ArrayList<ArrayList<Integer>> authorNeighbors = new ArrayList<ArrayList<Integer>>();
130 | ArrayList<ArrayList<Integer>> paperNeighborsVenue = new ArrayList<ArrayList<Integer>>();
131 | for(int i=0;i<numOfAuthor;i++)
132 | {
133 | ArrayList<Integer> neighbors = new ArrayList<Integer>();
134 | authorNeighbors.add(neighbors);
135 | }
136 | for(int i=0;i<numOfPaper;i++)
137 | {
138 | ArrayList<Integer> neighborsVenue = new ArrayList<Integer>();
139 | paperNeighborsVenue.add(neighborsVenue);
140 | }
141 | brAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8"));
142 | brVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8"));
143 | while((line=brAuthor.readLine())!=null)
144 | {
145 | int splitPos = line.indexOf("\t");
146 | String paper = line.substring(0, splitPos);
147 | String author = line.substring(splitPos+1);
148 | int paperId = paper2Id.get(paper);
149 | int authorId = author2Id.get(author);
150 | if(!authorNeighbors.get(authorId).contains(paperId))
151 | authorNeighbors.get(authorId).add(paperId);
152 | }
153 | brAuthor.close();
154 | while((line=brVenue.readLine())!=null)
155 | {
156 | int splitPos = line.indexOf("\t");
157 | String paper = line.substring(0, splitPos);
158 | String venue = line.substring(splitPos+1);
159 | int paperId = paper2Id.get(paper);
160 | int venueId = venue2Id.get(venue);
161 | if(!paperNeighborsVenue.get(paperId).contains(venueId))
162 | paperNeighborsVenue.get(paperId).add(venueId);
163 | }
164 | brVenue.close();
165 | for(int i=0;i<numOfAuthor;i++)
166 | {
167 | Collections.sort(authorNeighbors.get(i), new Comparator<Integer>() {
168 | public int compare(Integer a, Integer b) {
169 | return a.compareTo(b);
170 | }
171 | });
172 | }
173 | for(int i=0;i<numOfPaper;i++)
174 | {
175 | Collections.sort(paperNeighborsVenue.get(i), new Comparator<Integer>() {
176 | public int compare(Integer a, Integer b) {
177 | return a.compareTo(b);
178 | }
179 | });
180 | }
181 |
182 | int[][] venueClass = new int[numOfVenue][4];
183 | int[][] authorClass = new int[numOfAuthor][4];
184 | for(int i=0;i<numOfVenue;i++)
185 | {
186 | String venue = Id2Venue.get(i);
187 | for(int j=0;j<4;j++)
188 | venueClass[i][j] = 0;
189 | int label = getClass(venue);
190 | if(label!=-1)
191 | venueClass[i][label] = 1;
192 | }
193 | for(int i=0;i<numOfAuthor;i++)
194 | {
195 | for(int j=0;j<4;j++)
196 | authorClass[i][j] = 0;
197 | for(int j=0;j<authorNeighbors.get(i).size();j++)
198 | {
199 | int paperId = authorNeighbors.get(i).get(j);
200 | for(int k=0;k<paperNeighborsVenue.get(paperId).size();k++)
201 | {
202 | int venueId = paperNeighborsVenue.get(paperId).get(k);
203 | for(int l=0;l<4;l++)
204 | authorClass[i][l] += venueClass[venueId][l];
205 | }
206 | }
207 | }
208 |
209 | Random randomClass = new Random();
210 | for(int i=0;i<numOfAuthor;i++)
211 | {
212 | int[] arrClass = {0, 1, 2, 3};
213 |
214 | for(int j=3;j>0;j--)
215 | {
216 | int randNumClass = randomClass.nextInt(j);
217 | int temp = arrClass[j];
218 | arrClass[j] = arrClass[randNumClass];
219 | arrClass[randNumClass] = temp;
220 | }
221 | int label = -1;
222 | int maxNum = -1;
223 | for(int j=0;j<4;j++)
224 | {
225 | if(authorClass[i][arrClass[j]]>maxNum)
226 | {
227 | maxNum = authorClass[i][arrClass[j]];
228 | label = arrClass[j];
229 | }
230 | }
231 | if(label==-1)
232 | System.out.println("no class label for author "+i);
233 | else
234 | {
235 | for(int j=0;j<4;j++)
236 | authorClass[i][j] = 0;
237 | authorClass[i][label] = 1;
238 | }
239 | }
240 |
241 | BufferedWriter bwAuthorClass = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("group"+File.separator+"DBLP-V3.author.class.txt"), "UTF-8"));
242 | for(int i=0;i<numOfAuthor;i++)
243 | {
244 | for(int j=0;j<4;j++)
245 | {
246 | if(authorClass[i][j]==1)
247 | bwAuthorClass.write(Id2Author.get(i)+" "+j+"\n");
248 | }
249 | }
250 | bwAuthorClass.close();
251 | }
252 |
253 | }
254 |
--------------------------------------------------------------------------------
/DBLPPreProcess/src/PreProcessStep3.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.BufferedWriter;
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.InputStreamReader;
8 | import java.io.OutputStreamWriter;
9 | import java.util.ArrayList;
10 | import java.util.Collections;
11 | import java.util.Comparator;
12 | import java.util.HashMap;
13 | import java.util.Map;
14 | import java.util.Random;
15 |
16 | public class PreProcessStep3 { // generate random walks guided by the meta graph and meta paths
17 |
18 | public static Map<String, Integer> author2Id;
19 | public static Map<String, Integer> venue2Id;
20 | public static Map<String, Integer> paper2Id;
21 | public static Map<Integer, String> Id2Author;
22 | public static Map<Integer, String> Id2Venue;
23 | public static Map<Integer, String> Id2Paper;
24 |
25 | public static ArrayList<ArrayList<Integer>> authorNeighbors;
26 | public static ArrayList<ArrayList<Integer>> venueNeighbors;
27 | public static ArrayList<ArrayList<Integer>> paperNeighborsAuthor;
28 | public static ArrayList<ArrayList<Integer>> paperNeighborsPaper;
29 | public static ArrayList<ArrayList<Integer>> paperNeighborsVenue;
30 |
31 | public static String metaGraphRandomWalk(int startNodeId, String type, int pathLen)
32 | {
33 | int curNodeId = startNodeId;
34 | String curType = type;
35 | int curPos = 1;
36 | Random random = new Random();
37 | StringBuffer path = new StringBuffer("");
38 | for(int i=0;i<pathLen;i++)
39 | {
40 | if(curType.equals("P"))
41 | {
42 | String paper = Id2Paper.get(curNodeId);
43 | if(i>0)
44 | path.append(" ");
45 | path.append(paper);
46 | if(curPos==2)
47 | {
48 | if(paperNeighborsAuthor.get(curNodeId).size()>1&&paperNeighborsVenue.get(curNodeId).size()>0)
49 | {
50 | double option = Math.random();
51 | if(option<1.0/2.0)
52 | {
53 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size());
54 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
55 | curPos++;
56 | curType = "A";
57 | }
58 | else
59 | {
60 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size());
61 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex);
62 | curPos++;
63 | curType = "V";
64 | }
65 | }
66 | else if(paperNeighborsAuthor.get(curNodeId).size()>1)
67 | {
68 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size());
69 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
70 | curPos++;
71 | curType = "A";
72 | }
73 | else if(paperNeighborsVenue.get(curNodeId).size()>0)
74 | {
75 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size());
76 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex);
77 | curPos++;
78 | curType = "V";
79 | }
80 | else
81 | return path.toString();
82 | }
83 | else
84 | {
85 | if(paperNeighborsAuthor.get(curNodeId).size()==0)
86 | return path.toString();
87 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size());
88 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
89 | curPos++;
90 | if(curPos>=5)
91 | curPos = 1;
92 | curType = "A";
93 | }
94 | }
95 | else
96 | {
97 | if(curType.equals("A"))
98 | {
99 | String author = Id2Author.get(curNodeId);
100 | if(i>0)
101 | path.append(" ");
102 | path.append(author);
103 | if(authorNeighbors.get(curNodeId).size()==0)
104 | return path.toString();
105 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size());
106 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex);
107 | curPos++;
108 | curType = "P";
109 | }
110 | else if(curType.equals("V"))
111 | {
112 | String venue = Id2Venue.get(curNodeId);
113 | if(i>0)
114 | path.append(" ");
115 | path.append(venue);
116 | if(venueNeighbors.get(curNodeId).size()==0)
117 | return path.toString();
118 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size());
119 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex);
120 | curPos++;
121 | curType = "P";
122 | }
123 | }
124 | }
125 | return path.toString();
126 | }
127 |
128 | public static String metaPathAPVPARandomWalk(int startNodeId, String type, int pathLen)
129 | {
130 | int curNodeId = startNodeId;
131 | String curType = type;
132 | String lastType = "";
133 | Random random = new Random();
134 | StringBuffer path = new StringBuffer("");
135 | for(int i=0;i<pathLen;i++)
136 | {
137 | if(curType.equals("P"))
138 | {
139 | String paper = Id2Paper.get(curNodeId);
140 | if(i>0)
141 | path.append(" ");
142 | path.append(paper);
143 | if(lastType.equals("A"))
144 | {
145 | if(paperNeighborsVenue.get(curNodeId).size()==0)
146 | return path.toString();
147 | int nextIndex = random.nextInt(paperNeighborsVenue.get(curNodeId).size());
148 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex);
149 | lastType = curType;
150 | curType = "V";
151 | }
152 | else
153 | {
154 | if(paperNeighborsAuthor.get(curNodeId).size()==0)
155 | return path.toString();
156 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size());
157 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
158 | lastType = curType;
159 | curType = "A";
160 | }
161 | }
162 | else
163 | {
164 | if(curType.equals("A"))
165 | {
166 | String author = Id2Author.get(curNodeId);
167 | if(i>0)
168 | path.append(" ");
169 | path.append(author);
170 | if(authorNeighbors.get(curNodeId).size()==0)
171 | return path.toString();
172 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size());
173 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex);
174 | lastType = curType;
175 | curType = "P";
176 | }
177 | else
178 | {
179 | String venue = Id2Venue.get(curNodeId);
180 | if(i>0)
181 | path.append(" ");
182 | path.append(venue);
183 | if(venueNeighbors.get(curNodeId).size()==0)
184 | return path.toString();
185 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size());
186 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex);
187 | lastType = curType;
188 | curType = "P";
189 | }
190 | }
191 | }
192 | return path.toString();
193 | }
194 |
195 | public static String metaPathAPAPARandomWalk(int startNodeId, String type, int pathLen)
196 | {
197 | int curNodeId = startNodeId;
198 | String curType = type;
199 | Random random = new Random();
200 | StringBuffer path = new StringBuffer("");
201 | for(int i=0;i<pathLen;i++)
202 | {
203 | if(curType.equals("P"))
204 | {
205 | String paper = Id2Paper.get(curNodeId);
206 | if(i>0)
207 | path.append(" ");
208 | path.append(paper);
209 | if(paperNeighborsAuthor.get(curNodeId).size()<=1)
210 | return path.toString();
211 | int nextIndex = random.nextInt(paperNeighborsAuthor.get(curNodeId).size());
212 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
213 | curType = "A";
214 | }
215 | else
216 | {
217 | String author = Id2Author.get(curNodeId);
218 | if(i>0)
219 | path.append(" ");
220 | path.append(author);
221 | if(authorNeighbors.get(curNodeId).size()==0)
222 | return path.toString();
223 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size());
224 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex);
225 | curType = "P";
226 | }
227 | }
228 | return path.toString();
229 | }
230 |
231 | public static String uniformRandomWalk(int startNodeId, String type, int pathLen)
232 | {
233 | int curNodeId = startNodeId;
234 | String curType = type;
235 | Random random = new Random();
236 | StringBuffer path = new StringBuffer("");
237 | for(int i=0;i<pathLen;i++)
238 | {
239 | if(curType.equals("P"))
240 | {
241 | String paper = Id2Paper.get(curNodeId);
242 | if(i>0)
243 | path.append(" ");
244 | path.append(paper);
245 | int authorSize = paperNeighborsAuthor.get(curNodeId).size();
246 | int paperSize = paperNeighborsPaper.get(curNodeId).size();
247 | int venueSize = paperNeighborsVenue.get(curNodeId).size();
248 | if(authorSize+paperSize+venueSize==0)
249 | return path.toString();
250 | int nextIndex = random.nextInt(authorSize+paperSize+venueSize);
251 | if(nextIndex<authorSize)
252 | {
253 | curNodeId = paperNeighborsAuthor.get(curNodeId).get(nextIndex);
254 | curType = "A";
255 | }
256 | else if(nextIndex<authorSize+paperSize)
257 | {
258 | curNodeId = paperNeighborsPaper.get(curNodeId).get(nextIndex-authorSize);
259 | curType = "P";
260 | }
261 | else
262 | {
263 | curNodeId = paperNeighborsVenue.get(curNodeId).get(nextIndex-authorSize-paperSize);
264 | curType = "V";
265 | }
266 | }
267 | else
268 | {
269 | if(curType.equals("A"))
270 | {
271 | String author = Id2Author.get(curNodeId);
272 | if(i>0)
273 | path.append(" ");
274 | path.append(author);
275 | if(authorNeighbors.get(curNodeId).size()==0)
276 | return path.toString();
277 | int nextIndex = random.nextInt(authorNeighbors.get(curNodeId).size());
278 | curNodeId = authorNeighbors.get(curNodeId).get(nextIndex);
279 | curType = "P";
280 | }
281 | else if(curType.equals("V"))
282 | {
283 | String venue = Id2Venue.get(curNodeId);
284 | if(i>0)
285 | path.append(" ");
286 | path.append(venue);
287 | if(venueNeighbors.get(curNodeId).size()==0)
288 | return path.toString();
289 | int nextIndex = random.nextInt(venueNeighbors.get(curNodeId).size());
290 | curNodeId = venueNeighbors.get(curNodeId).get(nextIndex);
291 | curType = "P";
292 | }
293 | }
294 | }
295 | return path.toString();
296 | }
297 |
298 | public static void main(String[] args) throws IOException {
299 | // TODO Auto-generated method stub
300 | BufferedReader brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8"));
301 | BufferedReader brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8"));
302 | BufferedReader brPaperPaper = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.paper.txt"), "UTF-8"));
303 |
304 | author2Id = new HashMap<String, Integer>();
305 | venue2Id = new HashMap<String, Integer>();
306 | paper2Id = new HashMap<String, Integer>();
307 | Id2Author = new HashMap<Integer, String>();
308 | Id2Venue = new HashMap<Integer, String>();
309 | Id2Paper = new HashMap<Integer, String>();
310 |
311 | String line = null;
312 | int paperIndex = 0, authorIndex = 0, venueIndex = 0;
313 | while((line=brPaperAuthor.readLine())!=null)
314 | {
315 | int splitPos = line.indexOf("\t");
316 | String paper = line.substring(0, splitPos);
317 | String author = line.substring(splitPos+1);
318 | if(!paper2Id.containsKey(paper))
319 | {
320 | paper2Id.put(paper, paperIndex);
321 | Id2Paper.put(paperIndex, paper);
322 | paperIndex++;
323 | }
324 | if(!author2Id.containsKey(author))
325 | {
326 | author2Id.put(author, authorIndex);
327 | Id2Author.put(authorIndex, author);
328 | authorIndex++;
329 | }
330 | }
331 | brPaperAuthor.close();
332 | while((line=brPaperVenue.readLine())!=null)
333 | {
334 | int splitPos = line.indexOf("\t");
335 | String venue = line.substring(splitPos+1);
336 | if(!venue2Id.containsKey(venue))
337 | {
338 | venue2Id.put(venue, venueIndex);
339 | Id2Venue.put(venueIndex, venue);
340 | venueIndex++;
341 | }
342 | }
343 | brPaperVenue.close();
344 | int numOfPaper = paperIndex;
345 | int numOfAuthor = authorIndex;
346 | int numOfVenue = venueIndex;
347 |
348 | authorNeighbors = new ArrayList<ArrayList<Integer>>();
349 | venueNeighbors = new ArrayList<ArrayList<Integer>>();
350 | paperNeighborsAuthor = new ArrayList<ArrayList<Integer>>();
351 | paperNeighborsPaper = new ArrayList<ArrayList<Integer>>();
352 | paperNeighborsVenue = new ArrayList<ArrayList<Integer>>();
353 |
354 | for(int i=0;i<numOfAuthor;i++)
355 | {
356 | ArrayList<Integer> neighbors = new ArrayList<Integer>();
357 | authorNeighbors.add(neighbors);
358 | }
359 | for(int i=0;i<numOfVenue;i++)
360 | {
361 | ArrayList<Integer> neighbors = new ArrayList<Integer>();
362 | venueNeighbors.add(neighbors);
363 | }
364 | for(int i=0;i<numOfPaper;i++)
365 | {
366 | ArrayList<Integer> neighborsAuthor = new ArrayList<Integer>();
367 | paperNeighborsAuthor.add(neighborsAuthor);
368 | ArrayList<Integer> neighborsPaper = new ArrayList<Integer>();
369 | paperNeighborsPaper.add(neighborsPaper);
370 | ArrayList<Integer> neighborsVenue = new ArrayList<Integer>();
371 | paperNeighborsVenue.add(neighborsVenue);
372 | }
373 |
374 | brPaperAuthor = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.author.txt"), "UTF-8"));
375 | brPaperVenue = new BufferedReader(new InputStreamReader(new FileInputStream("dataset"+File.separator+"DBLP-V3.paper.venue.txt"), "UTF-8"));
376 | while((line=brPaperAuthor.readLine())!=null)
377 | {
378 | int splitPos = line.indexOf("\t");
379 | String paper = line.substring(0, splitPos);
380 | String author = line.substring(splitPos+1);
381 | int paperId = paper2Id.get(paper);
382 | int authorId = author2Id.get(author);
383 | if(!authorNeighbors.get(authorId).contains(paperId))
384 | authorNeighbors.get(authorId).add(paperId);
385 | if(!paperNeighborsAuthor.get(paperId).contains(authorId))
386 | paperNeighborsAuthor.get(paperId).add(authorId);
387 | }
388 | brPaperAuthor.close();
389 | while((line=brPaperPaper.readLine())!=null)
390 | {
391 | int splitPos = line.indexOf("\t");
392 | String paper1 = line.substring(0, splitPos);
393 | String paper2 = line.substring(splitPos+1);
394 | int paperId1 = paper2Id.get(paper1);
395 | int paperId2 = paper2Id.get(paper2);
396 | if(!paperNeighborsPaper.get(paperId1).contains(paperId2))
397 | paperNeighborsPaper.get(paperId1).add(paperId2);
398 | if(!paperNeighborsPaper.get(paperId2).contains(paperId1))
399 | paperNeighborsPaper.get(paperId2).add(paperId1);
400 | }
401 | brPaperPaper.close();
402 | while((line=brPaperVenue.readLine())!=null)
403 | {
404 | int splitPos = line.indexOf("\t");
405 | String paper = line.substring(0, splitPos);
406 | String venue = line.substring(splitPos+1);
407 | int paperId = paper2Id.get(paper);
408 | int venueId = venue2Id.get(venue);
409 | if(!venueNeighbors.get(venueId).contains(paperId))
410 | venueNeighbors.get(venueId).add(paperId);
411 | if(!paperNeighborsVenue.get(paperId).contains(venueId))
412 | paperNeighborsVenue.get(paperId).add(venueId);
413 | }
414 | brPaperVenue.close();
415 |
416 | int[] arrPaper1 = new int[numOfPaper];
417 | for(int i=0;i<numOfPaper;i++)
418 | arrPaper1[i] = i;
419 | Random randomPaper1 = new Random();
420 | for(int i=numOfPaper-1;i>0;i--) {
421 | int randNumPaper1 = randomPaper1.nextInt(i);
422 | int tempPaper1 = arrPaper1[i];
423 | arrPaper1[i] = arrPaper1[randNumPaper1];
424 | arrPaper1[randNumPaper1] = tempPaper1;
425 | }
426 |
427 | for(int i=0;i<numOfAuthor;i++)
428 | {
429 | if(authorNeighbors.get(i).size()==0)
430 | System.out.println("author "+i+" has no paper");
431 | }
432 |
433 | for(int i=0;i<numOfVenue;i++)
434 | {
435 | if(venueNeighbors.get(i).size()==0)
436 | System.out.println("venue "+i+" has no paper");
437 | }
438 |
439 | for(int i=0;i<numOfPaper;i++)
440 | {
441 | if(paperNeighborsAuthor.get(i).size()==0)
442 | System.out.println("paper "+i+" has no author");
443 | }
444 |
445 |
446 | for(int i=0;i<numOfAuthor;i++)
447 | {
448 | Collections.sort(authorNeighbors.get(i), new Comparator<Integer>() {
449 | public int compare(Integer a, Integer b) {
450 | return a.compareTo(b);
451 | }
452 | });
453 | }
454 | for(int i=0;i<numOfVenue;i++)
455 | {
456 | Collections.sort(venueNeighbors.get(i), new Comparator<Integer>() {
457 | public int compare(Integer a, Integer b) {
458 | return a.compareTo(b);
459 | }
460 | });
461 | }
462 |
463 | for(int i=0;i<numOfPaper;i++)
464 | {
465 | Collections.sort(paperNeighborsAuthor.get(i), new Comparator<Integer>() {
466 | public int compare(Integer a, Integer b) {
467 | return a.compareTo(b);
468 | }
469 | });
470 | Collections.sort(paperNeighborsPaper.get(i), new Comparator<Integer>() {
471 | public int compare(Integer a, Integer b) {
472 | return a.compareTo(b);
473 | }
474 | });
475 | Collections.sort(paperNeighborsVenue.get(i), new Comparator<Integer>() {
476 | public int compare(Integer a, Integer b) {
477 | return a.compareTo(b);
478 | }
479 | });
480 | }
481 |
482 | BufferedWriter bwMetaGraphRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.graph.random.walk.txt"), "UTF-8"));
483 | BufferedWriter bwMetaPathRandWalkAPVPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apvpa.txt"), "UTF-8"));
484 | BufferedWriter bwMetaPathRandWalkAPAPA = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.apapa.txt"), "UTF-8"));
485 | BufferedWriter bwMetaPathRandWalkMix = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.meta.path.random.walk.mix.txt"), "UTF-8"));
486 | BufferedWriter bwUniformRandWalk = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("randomwalk"+File.separator+"DBLP-V3.uniform.random.walk.txt"), "UTF-8"));
487 | int pathLen = 100;
488 | int numOfWalk = 80;
489 | for(int i=0;i<numOfWalk;i++)
490 | {
491 | for(int j=0;j<numOfPaper;j++)
492 | {
493 | int startNodeId = arrPaper1[j];
494 | String walk = metaGraphRandomWalk(startNodeId, "P", pathLen);
495 | bwMetaGraphRandWalk.write(walk+"\n");
496 | walk = metaPathAPVPARandomWalk(startNodeId, "P", pathLen);
497 | bwMetaPathRandWalkAPVPA.write(walk+"\n");
498 | walk = metaPathAPAPARandomWalk(startNodeId, "P", pathLen);
499 | bwMetaPathRandWalkAPAPA.write(walk+"\n");
500 | double option = Math.random();
501 | if(option<1.0/2.0)
502 | walk = metaPathAPVPARandomWalk(startNodeId, "P", pathLen);
503 | else
504 | walk = metaPathAPAPARandomWalk(startNodeId, "P", pathLen);
505 | bwMetaPathRandWalkMix.write(walk+"\n");
506 | walk = uniformRandomWalk(startNodeId, "P", pathLen);
507 | bwUniformRandWalk.write(walk+"\n");
508 | }
509 | }
510 | bwMetaGraphRandWalk.close();
511 | bwMetaPathRandWalkAPVPA.close();
512 | bwMetaPathRandWalkAPAPA.close();
513 | bwMetaPathRandWalkMix.close();
514 | bwUniformRandWalk.close();
515 |
516 | int[] arr1 = new int[numOfAuthor];
517 | for(int i=0;i<numOfAuthor;i++)
518 | arr1[i] = i;
519 | Random random1 = new Random();
520 |
521 | for(int i=numOfAuthor-1;i>0;i--) {
522 | int randNum1 = random1.nextInt(i);
523 | int temp1 = arr1[i];
524 | arr1[i] = arr1[randNum1];
525 | arr1[randNum1] = temp1;
526 | }
527 | for(int i=0;i<10;i++)
528 | System.out.println("arr1["+i+"] = "+arr1[i]);
529 | }
530 |
531 | }
532 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MetaGraph2Vec
2 | This is an implementation of MetaGraph2Vec, proposed in the PAKDD 2018 paper "MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding".
3 |
4 | # DBLPPreProcess
5 | The DBLPPreProcess project preprocesses the DBLP-Citation-Network-V3 data:
6 |
7 | PreProcessStep1.java extracts the subgraph formed by papers published in the selected venues;
8 |
9 | PreProcessStep2.java obtains the class labels for authors;
10 |
11 | PreProcessStep3.java generates the meta graph, meta path and uniform random walk sequences.
12 |
13 | # RandomWalk2Vec
14 | The RandomWalk2Vec project learns node representations from the generated random walk sequences.
15 |
16 | To compile the source code, run:
17 |
18 | bash compile.sh
19 |
20 | To learn node representations with different random walk strategies, run:
21 |
22 | bash RandomWalk2VecRun.sh
23 |
24 | The options of randomWalk2vec are listed as follows:
25 |
26 | Parameters for training:
27 | -train <file>
28 | Use random walk sequences from <file> to train the model
29 | -output <file>
30 | Use <file> to save the learned node representations in vector format
31 | -size <int>
32 | Set the dimension of learned node representations; default is 128
33 | -window <int>
34 | Set the window size for collecting node context pairs; default is 5
35 | -pp <int>
36 | Use the Heterogeneous Skip-Gram model (the ++ version) with 1 (default) or the Homogeneous Skip-Gram model with 0
37 | -prefixes <string>
38 | Prefixes of node ids for specifying node types, e.g., ap with a for author and p for paper
39 | -objtype <int>
40 | The index in the prefixes list of the objective node type, for which representations are learned
41 | -alpha <float>
42 | Set the starting learning rate; default is 0.025
43 | -negative <int>
44 | Number of negative examples; default is 5, common values are 3 - 10
45 | -samples <int>
46 | Set the number of stochastic gradient descent iterations, in millions; default is 100
47 |
48 | If you find this project useful, please cite the following paper:
49 |
50 | @inproceedings{zhang2018metagraph2vec,
51 | title={MetaGraph2Vec: Complex Semantic Path Augmented Heterogeneous Network Embedding},
52 | author={Zhang, Daokun and Yin, Jie and Zhu, Xingquan and Zhang, Chengqi},
53 | booktitle={Pacific-Asia Conference on Knowledge Discovery and Data Mining},
54 | pages={196--208},
55 | year={2018},
56 | organization={Springer}
57 | }
58 | # Note
59 | The randomWalk2vec.c code was compiled on a Linux system where RAND_MAX is 2147483647. If you compile randomWalk2vec.c on your own system, please check the value of RAND_MAX carefully to make sure it is large enough for alias
60 | table sampling to remain correct.
61 |
--------------------------------------------------------------------------------
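The RAND_MAX note above concerns the alias table built by InitAliasTable() and consumed by SampleAPair() in randomWalk2vec.c. Below is a hedged Java sketch of the same two-array alias method; the class and member names are invented for the example, and Random.nextDouble() supplies the fine-grained uniform draws that a small RAND_MAX would fail to provide.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;

// Walker's alias method: O(n) construction, O(1) sampling from a discrete
// distribution given by non-negative weights (here, co-occurrence counts).
public class AliasSampler {
    private final double[] prob;
    private final int[] alias;
    private final Random random = new Random();

    public AliasSampler(long[] counts) {
        int n = counts.length;
        prob = new double[n];
        alias = new int[n];
        double sum = 0;
        for (long c : counts) sum += c;
        double[] norm = new double[n];
        Deque<Integer> small = new ArrayDeque<Integer>();
        Deque<Integer> large = new ArrayDeque<Integer>();
        for (int k = 0; k < n; k++) {
            norm[k] = counts[k] * n / sum; // scaled so the average bucket mass is 1
            if (norm[k] < 1) small.push(k); else large.push(k);
        }
        while (!small.isEmpty() && !large.isEmpty()) {
            int s = small.pop(), l = large.pop();
            prob[s] = norm[s]; // keep s with probability norm[s] ...
            alias[s] = l;      // ... otherwise fall through to its alias l
            norm[l] += norm[s] - 1;
            if (norm[l] < 1) small.push(l); else large.push(l);
        }
        while (!large.isEmpty()) prob[large.pop()] = 1;
        while (!small.isEmpty()) prob[small.pop()] = 1;
    }

    public int sample() {
        int k = random.nextInt(prob.length);                 // pick a bucket
        return random.nextDouble() < prob[k] ? k : alias[k]; // keep it or take its alias
    }
}

Each draw costs one uniform index and one uniform double, which is why the node context pairs collected from the random walks can be sampled in constant time throughout training.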
/RandomWalk2Vec/RandomWalk2VecRun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
3 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.graph.random.walk.txt -output DBLP-V3.meta.graph.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
4 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
5 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apapa.txt -output DBLP-V3.meta.path.apapa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
6 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
7 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.apvpa.txt -output DBLP-V3.meta.path.apvpa.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
8 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
9 | ./randomWalk2vec -size 128 -train DBLP-V3.meta.path.random.walk.mix.txt -output DBLP-V3.meta.path.mix.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
10 | ./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp1.emb.txt -pp 1 -prefixes apv -objtype 0
11 | ./randomWalk2vec -size 128 -train DBLP-V3.uniform.random.walk.txt -output DBLP-V3.uniform.pp0.emb.txt -pp 0 -prefixes apv -objtype 0
12 |
--------------------------------------------------------------------------------
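The five -train files consumed by the script above are produced by PreProcessStep3. As a compact illustration of the A-P-V-P-A meta-path walk behind DBLP-V3.meta.path.random.walk.apvpa.txt, here is a hedged Java sketch over an invented toy graph; the adjacency lists mirror the fields of PreProcessStep3, but the data, start node and walk length are made up for the example.

import java.util.Arrays;
import java.util.List;
import java.util.Random;

// Toy APVPA meta-path walk: from an author, move author -> paper, then from a
// paper bounce to a venue after an author (and to an author after a venue),
// so the visited node types repeat the pattern A P V P A P V P A ...
public class MetaPathWalkDemo {
    static List<List<Integer>> authorPapers = Arrays.asList(
            Arrays.asList(0, 1), Arrays.asList(1)); // a_0 wrote p_0 and p_1; a_1 wrote p_1
    static List<List<Integer>> paperAuthors = Arrays.asList(
            Arrays.asList(0), Arrays.asList(0, 1));
    static List<List<Integer>> paperVenues = Arrays.asList(
            Arrays.asList(0), Arrays.asList(0));    // both papers appeared in v_0
    static List<List<Integer>> venuePapers = Arrays.asList(
            Arrays.asList(0, 1));
    static Random random = new Random();

    static int pick(List<Integer> neighbors) {
        return neighbors.get(random.nextInt(neighbors.size()));
    }

    public static void main(String[] args) {
        int cur = 0; // start from author a_0
        String curType = "A", lastType = "";
        StringBuilder walk = new StringBuilder();
        for (int i = 0; i < 9; i++) {
            if (i > 0) walk.append(' ');
            if (curType.equals("A")) {
                walk.append("a_").append(cur);
                cur = pick(authorPapers.get(cur)); lastType = "A"; curType = "P";
            } else if (curType.equals("V")) {
                walk.append("v_").append(cur);
                cur = pick(venuePapers.get(cur)); lastType = "V"; curType = "P";
            } else { // at a paper: go to a venue after an author, to an author after a venue
                walk.append("p_").append(cur);
                if (lastType.equals("A")) { cur = pick(paperVenues.get(cur)); curType = "V"; }
                else { cur = pick(paperAuthors.get(cur)); curType = "A"; }
            }
        }
        System.out.println(walk); // e.g. a_0 p_1 v_0 p_0 a_0 p_0 v_0 p_1 a_1
    }
}

The real generator additionally truncates a walk when a node has no neighbor of the required type, a case the toy graph here never triggers.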
/RandomWalk2Vec/compile.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | gcc -Wall -g -c "randomWalk2vec.c" -o randomWalk2vec.o
3 | g++ -o randomWalk2vec randomWalk2vec.o
4 |
5 |
--------------------------------------------------------------------------------
/RandomWalk2Vec/randomWalk2vec.c:
--------------------------------------------------------------------------------
1 | // The randomWalk2vec project was built upon the metapath2vec.cpp code from https://ericdongyx.github.io/metapath2vec/m2v.html
2 |
3 | // Modifications Copyright (C) 2018
4 | //
5 | // Licensed under the Apache License, Version 2.0 (the "License");
6 | // you may not use this file except in compliance with the License.
7 | // You may obtain a copy of the License at
8 | //
9 | // http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing, software
12 | // distributed under the License is distributed on an "AS IS" BASIS,
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | // See the License for the specific language governing permissions and
15 | // limitations under the License.
16 |
17 | // The metapath2vec.cpp code was built upon the word2vec.c from https://code.google.com/archive/p/word2vec/
18 |
19 | // Modifications Copyright (C) 2016
20 | //
21 | // Licensed under the Apache License, Version 2.0 (the "License");
22 | // you may not use this file except in compliance with the License.
23 | // You may obtain a copy of the License at
24 | //
25 | // http://www.apache.org/licenses/LICENSE-2.0
26 | //
27 | // Unless required by applicable law or agreed to in writing, software
28 | // distributed under the License is distributed on an "AS IS" BASIS,
29 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30 | // See the License for the specific language governing permissions and
31 | // limitations under the License.
32 |
33 | // Copyright 2013 Google Inc. All Rights Reserved.
34 | //
35 | // Licensed under the Apache License, Version 2.0 (the "License");
36 | // you may not use this file except in compliance with the License.
37 | // You may obtain a copy of the License at
38 | //
39 | // http://www.apache.org/licenses/LICENSE-2.0
40 | //
41 | // Unless required by applicable law or agreed to in writing, software
42 | // distributed under the License is distributed on an "AS IS" BASIS,
43 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44 | // See the License for the specific language governing permissions and
45 | // limitations under the License.
46 |
47 | #include <stdio.h>
48 | #include <stdlib.h>
49 | #include <string.h>
50 | #include <math.h>
51 | #include <time.h>
52 |
53 | #define MAX_STRING 100
54 | #define EXP_TABLE_SIZE 1000
55 | #define MAX_EXP 6
56 | #define MAX_WALK_LENGTH 1000
57 |
58 | const long long node_hash_size = 10000000; // Maximum 10M nodes
59 | const long long node_context_hash_size = 1000000000; // Maximum 1G node context pairs
60 |
61 | struct Node
62 | {
63 | long long cn;
64 | char *node_str;
65 | };
66 |
67 | struct Node_Context
68 | {
69 | long long cn;
70 | long long source, target;
71 | };
72 |
73 | char train_file[MAX_STRING], output_file[MAX_STRING];
74 | struct Node *node_list;
75 | long long window = 5;
76 | long long pp = 1; // pp = 1 with Heterogeneous Skip-Gram, pp = 0 with Homogeneous Skip-Gram
77 | long long *node_hash;
78 | long long node_list_max_size = 100000, node_list_size = 0, layer1_size = 128;
79 | long long node_occur_num = 0;
80 | double alpha = 0.025, starting_alpha;
81 | double *syn0, *syn1, *syn1neg, *expTable;
82 | double **type_syn1;
83 | clock_t start, finish;
84 |
85 | long long type_num;
86 | long long **type_tables, *type_counts, *type_indices;
87 | char prefixes[MAX_STRING];
88 | long long obj_type;
89 |
90 | long long **type_node_hash;
91 | long long *type_node_list_size, *type_node_index;
92 | struct Node **type_node_list;
93 |
94 | long long negative = 5;
95 | const long table_size = 1e8;
96 | long long *table;
97 |
98 | long long *node_context_hash;
99 | long long node_context_list_size, node_context_list_max_size;
100 | struct Node_Context *node_context_list;
101 |
102 | // Parameters for node context pair sampling
103 | long long *alias;
104 | double *prob;
105 |
106 | long long total_samples = 100;
107 |
108 | long long GetTypeId(char prefix)
109 | {
110 | long long i;
111 | for (i = 0; i < type_num; i++)
112 | {
113 | if (prefix == prefixes[i])
114 | break;
115 | }
116 | if (i >= type_num)
117 | return -1;
118 | else
119 | return i;
120 | }
121 |
122 | void NodeTypeNeg()
123 | {
124 | long long i;
125 | long long target, type_Id;
126 | type_tables = (long long **)malloc(type_num * sizeof(long long *));
127 | type_counts = (long long *)malloc(type_num * sizeof(long long));
128 | type_indices = (long long *)malloc(type_num * sizeof(long long));
129 | if (type_tables == NULL || type_counts == NULL || type_indices == NULL)
130 | {
131 | printf("Memory allocation failed\n");
132 | exit(1);
133 | }
134 | for (i = 0; i < type_num; i++)
135 | {
136 | type_counts[i] = 0;
137 | type_indices[i] = 0;
138 | }
139 | for (i = 0; i < table_size; i++)
140 | {
141 | target = table[i];
142 | type_Id = GetTypeId(node_list[target].node_str[0]);
143 | if (type_Id == -1)
144 | {
145 | printf("Unrecognised type\n");
146 | exit(1);
147 | }
148 | type_counts[type_Id]++;
149 | }
150 | for (i = 0; i< type_num; i++)
151 | {
152 | type_tables[i] = (long long *)malloc(type_counts[i] * sizeof(long long));
153 | if (type_tables[i] == NULL)
154 | {
155 | printf("Memory allocation failed\n");
156 | exit(1);
157 | }
158 | }
159 | for (i = 0; i < table_size; i++)
160 | {
161 | target = table[i];
162 | type_Id = GetTypeId(node_list[target].node_str[0]);
163 | type_tables[type_Id][type_indices[type_Id]] = target;
164 | type_indices[type_Id]++;
165 | }
166 | for (i = 0; i < type_num; i++)
167 | printf("type %c table size: %lld\n", prefixes[i], type_counts[i]);
168 | }
169 |
170 | void InitUnigramTable()
171 | {
172 | long long a, i;
173 | double train_nodes_pow = 0.0, d1, power = 0.75;
174 | table = (long long *)malloc(table_size * sizeof(long long));
175 | if (table == NULL)
176 | {
177 | printf("Memory allocation failed\n");
178 | exit(1);
179 | }
180 | for (a = 0; a < node_list_size; a++) train_nodes_pow += pow(node_list[a].cn, power);
181 | i = 0;
182 | d1 = pow(node_list[i].cn, power) / train_nodes_pow;
183 | for (a = 0; a < table_size; a++)
184 | {
185 | table[a] = i;
186 | if (a / (double)table_size > d1)
187 | {
188 | i++;
189 | d1 += pow(node_list[i].cn, power) / (double)train_nodes_pow;
190 | }
191 | if (i >= node_list_size) i = node_list_size - 1;
192 | }
193 | if (pp == 1)
194 | NodeTypeNeg();
195 | }
196 |
197 | // Reads a single node from a file, assuming space + tab + EOL to be node name boundaries
198 | long long ReadNode(char *node_str, FILE *fin) // return 1, if the boundary is '\n', otherwise 0
199 | {
200 | long long a = 0, ch, flag = 0;
201 | while (!feof(fin))
202 | {
203 | ch = fgetc(fin);
204 | if (ch == 13) continue;
205 | if ((ch == ' ') || (ch == '\t') || (ch == '\n'))
206 | {
207 | if (a > 0)
208 | {
209 | if (ch == '\n') flag = 1;
210 | break;
211 | }
212 | continue;
213 | }
214 | node_str[a] = ch;
215 | a++;
216 | if (a >= MAX_STRING - 1) a--;
217 | }
218 | node_str[a] = 0;
219 | return flag;
220 | }
221 |
222 | // Return hash value of a node
223 | long long GetNodeHash(char *node_str)
224 | {
225 | long long a, hash = 0;
226 | for (a = 0; a < strlen(node_str); a++)
227 | {
228 | hash = hash * 257 + node_str[a];
229 | hash = hash % node_hash_size;
230 | }
231 | return hash;
232 | }
233 |
234 | // Return hash value for a node context pair
235 | long long GetNodeContextHash(char *node_str, char *context_node_str)
236 | {
237 | long long a, hash = 0;
238 | for (a = 0; a < strlen(node_str); a++)
239 | {
240 | hash = hash * 257 + node_str[a];
241 | hash = hash % node_context_hash_size;
242 | }
243 | for (a = 0; a < strlen(context_node_str); a++)
244 | {
245 | hash = hash * 257 + context_node_str[a];
246 | hash = hash % node_context_hash_size;
247 | }
248 | return hash;
249 | }
250 |
251 | // Return position of a node in the node list; if the node is not found, returns -1
252 | long long SearchNode(char *node_str)
253 | {
254 | long long hash = GetNodeHash(node_str);
255 | while (1)
256 | {
257 | if (node_hash[hash] == -1) return -1;
258 | if (!strcmp(node_str, node_list[node_hash[hash]].node_str)) return node_hash[hash];
259 | hash = (hash + 1) % node_hash_size;
260 | }
261 | return -1;
262 | }
263 |
264 | // Return position of a node in the node list of the specific type; if the node is not found, returns -1
265 | long long TypeSearchNode(char *node_str, long long type_Id)
266 | {
267 | long long hash = GetNodeHash(node_str);
268 | while (1)
269 | {
270 | if (type_node_hash[type_Id][hash] == -1) return -1;
271 | if (!strcmp(node_str, type_node_list[type_Id][type_node_hash[type_Id][hash]].node_str)) return type_node_hash[type_Id][hash];
272 | hash = (hash + 1) % node_hash_size;
273 | }
274 | return -1;
275 | }
276 |
277 | // Return position of a node context pair in the node context list; if the node is not found, returns -1
278 | long long SearchNodeContextPair(long long node, long long context)
279 | {
280 | long long hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str);
281 | long long hash_iter = 0;
282 | long long cur_node, cur_context;
283 | while(1)
284 | {
285 | if (node_context_hash[hash] == -1) return -1;
286 | cur_node = node_context_list[node_context_hash[hash]].source;
287 | cur_context = node_context_list[node_context_hash[hash]].target;
288 | if (cur_node == node && cur_context == context) return node_context_hash[hash];
289 | hash = (hash + 1) % node_context_hash_size;
290 | hash_iter++;
291 | if (hash_iter >= node_context_hash_size)
292 | {
293 | printf("The node context hash table is full!\n");
294 | exit(1);
295 | }
296 | }
297 | return -1;
298 | }
299 |
300 | // Adds a node to the node list
301 | long long AddNodeToList(char *node_str)
302 | {
303 | long long hash;
304 | long long length = strlen(node_str) + 1;
305 | if (length > MAX_STRING) length = MAX_STRING;
306 | node_list[node_list_size].node_str = (char *)calloc(length, sizeof(char));
307 | if (node_list[node_list_size].node_str == NULL)
308 | {
309 | printf("Memory allocation failed\n");
310 | exit(1);
311 | }
312 | strcpy(node_list[node_list_size].node_str, node_str);
313 | node_list[node_list_size].cn = 1;
314 | node_list_size++;
315 | // Reallocate memory if needed
316 | if (node_list_size >= node_list_max_size)
317 | {
318 | node_list_max_size += 10000;
319 | node_list = (struct Node *)realloc(node_list, node_list_max_size * sizeof(struct Node));
320 | if (node_list == NULL)
321 | {
322 | printf("Memory allocation failed\n");
323 | exit(1);
324 | }
325 | }
326 | hash = GetNodeHash(node_str);
327 | while (node_hash[hash] != -1) hash = (hash + 1) % node_hash_size;
328 | node_hash[hash] = node_list_size - 1;
329 | return node_list_size - 1;
330 | }
331 |
332 | //Add node context pair to the node context list
333 | long long AddNodeContextToList(long long node, long long context)
334 | {
335 | long long hash;
336 | long long hash_iter = 0;
337 | node_context_list[node_context_list_size].source = node;
338 | node_context_list[node_context_list_size].target = context;
339 | node_context_list[node_context_list_size].cn = 1;
340 | node_context_list_size++;
341 | if (node_context_list_size >= node_context_list_max_size)
342 | {
343 | node_context_list_max_size += 100 * node_list_size;
344 | node_context_list = (struct Node_Context *)realloc(node_context_list, node_context_list_max_size * sizeof(struct Node_Context));
345 | if (node_context_list == NULL)
346 | {
347 | printf("Memory allocation failed\n");
348 | exit(1);
349 | }
350 | }
351 | hash = GetNodeContextHash(node_list[node].node_str, node_list[context].node_str);
352 | while (node_context_hash[hash] != -1)
353 | {
354 | hash = (hash + 1) % node_context_hash_size;
355 | hash_iter++;
356 | if (hash_iter >= node_context_hash_size)
357 | {
358 | printf("The node context hash table is full!\n");
359 | exit(1);
360 | }
361 | }
362 | node_context_hash[hash] = node_context_list_size - 1;
363 | return node_context_list_size - 1;
364 | }
365 |
366 | void LearnNodeListFromTrainFile()
367 | {
368 | char node_str[MAX_STRING];
369 | FILE *fin;
370 | long long a, i;
371 | for (a = 0; a < node_hash_size; a++) node_hash[a] = -1;
372 | fin = fopen(train_file, "r");
373 | if (fin == NULL)
374 | {
375 | printf("ERROR: training data file not found!\n");
376 | exit(1);
377 | }
378 | node_list_size = 0;
379 | printf("Reading nodes...\n");
380 | while (1)
381 | {
382 | ReadNode(node_str, fin);
383 | if (feof(fin)) break;
384 | node_occur_num++;
385 | if (node_occur_num % 100000 == 0)
386 | {
387 | printf("Read nodes: %lldK%c", node_occur_num / 1000, 13);
388 | fflush(stdout);
389 | }
390 | i = SearchNode(node_str);
391 | if (i == -1) AddNodeToList(node_str);
392 | else node_list[i].cn++;
393 | }
394 | printf("Node list size: %lld \n", node_list_size);
395 | printf("Node occurrence times in train file: %lld\n", node_occur_num);
396 | fclose(fin);
397 | }
398 |
399 | void GetNodeContextTable()
400 | {
401 | char node_str[MAX_STRING];
402 | long long a, b, c, node, context_node, walk_length = 0, walk_position = 0;
403 | long long walk[MAX_WALK_LENGTH + 1];
404 | long long line_num = 0;
405 | long long end_flag;
406 | FILE *fi = fopen(train_file, "r");
407 | node_context_hash = (long long *)malloc(node_context_hash_size * sizeof(long long));
408 | node_context_list_max_size = 1000 * node_list_size;
409 | node_context_list = (struct Node_Context *)malloc(node_context_list_max_size * sizeof(struct Node_Context));
410 | for (a = 0; a < node_context_hash_size; a++) node_context_hash[a] = -1;
411 | node_context_list_size = 0;
412 | printf("Collecting node context pair...\n");
413 | start = clock();
414 | while (1)
415 | {
416 | if (walk_length == 0)
417 | {
418 | while (1)
419 | {
420 | end_flag = ReadNode(node_str, fi);
421 | if (feof(fi)) break;
422 | node = SearchNode(node_str);
423 | //if (node == -1) continue;
424 | walk[walk_length] = node;
425 | walk_length++;
426 | if (end_flag || walk_length >= MAX_WALK_LENGTH) break;
427 | }
428 | walk_position = 0;
429 | line_num++;
430 | if (line_num % 1000 == 0)
431 | {
432 | printf("Processed lines: %lldK%c", line_num / 1000, 13);
433 | fflush(stdout);
434 | }
435 | }
436 | if (feof(fi) && walk_length == 0) break;
437 | //if (walk_length == 0) continue;
438 | node = walk[walk_position]; //current node
439 | for (a = 0; a < window * 2 + 1; a++)
440 | {
441 | if (a != window)
442 | {
443 | c = walk_position - window + a;
444 | if (c < 0) continue;
445 | if (c >= walk_length) continue;
446 | context_node = walk[c];
447 | b = SearchNodeContextPair(node, context_node);
448 | if (b == -1) AddNodeContextToList(node, context_node);
449 | else node_context_list[b].cn++;
450 | }
451 | }
452 | walk_position++;
453 | if (walk_position >= walk_length) walk_length = 0;
454 | }
455 | finish = clock();
456 | printf("Total time: %lf secs for collecting node context pairs\n", (double)(finish - start) / CLOCKS_PER_SEC);
457 | printf("Node context pair number: %lld\n", node_context_list_size);
458 | printf("----------------------------------------------------\n");
459 | }
460 |
461 | void InitNet()
462 | {
463 | long long a, b;
464 | unsigned long long next_random = 1;
465 | syn0 = (double *)malloc(node_list_size * layer1_size * sizeof(double));
466 | if (syn0 == NULL)
467 | {
468 | printf("Memory allocation failed\n");
469 | exit(1);
470 | }
471 | for (a = 0; a < node_list_size; a++)
472 | for (b = 0; b < layer1_size; b++)
473 | {
474 | next_random = next_random * (unsigned long long)25214903917 + 11;
475 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (double)65536) - 0.5) / layer1_size;
476 | }
477 | syn1neg = (double *)malloc(node_list_size * layer1_size * sizeof(double));
478 | if (syn1neg == NULL)
479 | {
480 | printf("Memory allocation failed\n");
481 | exit(1);
482 | }
483 | for (a = 0; a < node_list_size; a++)
484 | for (b = 0; b < layer1_size; b++)
485 | syn1neg[a * layer1_size + b] = 0;
486 | InitUnigramTable();
487 | }
488 |
489 | /* The alias sampling algorithm, which is used to sample a node context pair in O(1) time. */
490 | void InitAliasTable()
491 | {
492 | long long k;
493 | double sum = 0;
494 | long long cur_small_block, cur_large_block;
495 | long long num_small_block = 0, num_large_block = 0;
496 | double *norm_prob;
497 | long long *large_block;
498 | long long *small_block;
499 | alias = (long long *)malloc(node_context_list_size * sizeof(long long));
500 | prob = (double *)malloc(node_context_list_size * sizeof(double));
501 | if (alias == NULL || prob == NULL)
502 | {
503 | printf("Memory allocation failed\n");
504 | exit(1);
505 | }
506 | norm_prob = (double*)malloc(node_context_list_size * sizeof(double));
507 | large_block = (long long*)malloc(node_context_list_size * sizeof(long long));
508 | small_block = (long long*)malloc(node_context_list_size * sizeof(long long));
509 | if (norm_prob == NULL || large_block == NULL || small_block == NULL)
510 | {
511 | printf("Memory allocation failed\n");
512 | exit(1);
513 | }
514 | for (k = 0; k != node_context_list_size; k++) sum += node_context_list[k].cn;
515 | for (k = 0; k != node_context_list_size; k++) norm_prob[k] = (double)node_context_list[k].cn * node_context_list_size / sum;
516 | for (k = node_context_list_size - 1; k >= 0; k--)
517 | {
518 | if (norm_prob[k]<1)
519 | small_block[num_small_block++] = k;
520 | else
521 | large_block[num_large_block++] = k;
522 | }
523 | while (num_small_block && num_large_block)
524 | {
525 | cur_small_block = small_block[--num_small_block];
526 | cur_large_block = large_block[--num_large_block];
527 | prob[cur_small_block] = norm_prob[cur_small_block];
528 | alias[cur_small_block] = cur_large_block;
529 | norm_prob[cur_large_block] = norm_prob[cur_large_block] + norm_prob[cur_small_block] - 1;
530 | if (norm_prob[cur_large_block] < 1)
531 | small_block[num_small_block++] = cur_large_block;
532 | else
533 | large_block[num_large_block++] = cur_large_block;
534 | }
535 | while (num_large_block) prob[large_block[--num_large_block]] = 1;
536 | while (num_small_block) prob[small_block[--num_small_block]] = 1;
537 | free(norm_prob);
538 | free(small_block);
539 | free(large_block);
540 | }
541 |
542 | long long SampleAPair(double rand_value1, double rand_value2)
543 | {
544 | long long k = (long long)node_context_list_size * rand_value1;
545 | return rand_value2 < prob[k] ? k : alias[k];
546 | }
547 |
548 | void TrainModel()
549 | {
550 | long a, b, c, d;
551 | long long node, context_node, cur_pair;
552 | long long count = 0, last_count = 0;
553 | long long l1, l2, target, label;
554 | long long type_Id;
555 | unsigned long long next_random = 1;
556 | double rand_num1, rand_num2;
557 | double f, g;
558 | double *neu1e = (double *)calloc(layer1_size, sizeof(double));
559 | FILE *fp;
560 | InitNet();
561 | starting_alpha = alpha;
562 | printf("Skip-Gram model with Negative Sampling:");
563 | if (pp == 1) printf(" heterogeneous version\n");
564 | else printf(" homogeneous version\n");
565 | printf("Training file: %s\n", train_file);
566 | printf("Samples: %lldM\n", total_samples / 1000000);
567 | printf("Dimension: %lld\n", layer1_size);
568 | printf("Initial Alpha: %f\n", alpha);
569 | start = clock();
570 | srand((unsigned)time(NULL));
571 | while (1)
572 | {
573 | if (count >= total_samples) break;
574 | if (count - last_count > 10000)
575 | {
576 | last_count = count;
577 | printf("Alpha: %f Progress %.3lf%%%c", alpha, (double)count / (double)(total_samples + 1) * 100, 13);
578 | fflush(stdout);
579 | alpha = starting_alpha * (1 - (double)count / (double)(total_samples + 1));
580 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001;
581 | }
582 | rand_num1 = rand() / (RAND_MAX * 1.0 + 1);
583 | rand_num2 = rand() / (RAND_MAX * 1.0 + 1);
584 | cur_pair = SampleAPair(rand_num1, rand_num2);
585 | node = node_context_list[cur_pair].source;
586 | context_node = node_context_list[cur_pair].target;
587 | l1 = node * layer1_size;
588 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0;
589 |
590 | // NEGATIVE SAMPLING
591 | for (d = 0; d < negative + 1; d++)
592 | {
593 | if (d == 0)
594 | {
595 | target = context_node;
596 | label = 1;
597 | }
598 | else
599 | {
600 | next_random = next_random * (unsigned long long)25214903917 + 11;
601 | if (pp == 1)
602 | {
603 | type_Id = GetTypeId(node_list[context_node].node_str[0]);
604 | target = type_tables[type_Id][(next_random >> 16) % type_counts[type_Id]];
605 | }
606 | else
607 | target = table[(next_random >> 16) % table_size];
608 | if (target == context_node) continue;
609 | label = 0;
610 | }
611 | l2 = target * layer1_size;
612 | f = 0;
613 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2];
614 | if (f > MAX_EXP) f = 1;
615 | else if (f < -MAX_EXP) f = 0;
616 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
617 | g = (label - f) * alpha;
618 | // Propagate errors output -> hidden
619 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2];
620 | // Learn weights hidden -> output
621 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1];
622 | }
623 | // Learn weights input -> hidden
624 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c];
625 | count++;
626 | }
627 | finish = clock();
628 | printf("Total time: %lf secs for learning node embeddings\n", (double)(finish - start) / CLOCKS_PER_SEC);
629 | printf("----------------------------------------------------\n");
630 | fp = fopen(output_file, "w");
631 | // Save the learned node representations
632 | for (a = 0; a < node_list_size; a++)
633 | {
634 | if (node_list[a].node_str[0] != prefixes[obj_type]) continue;
635 | fprintf(fp, "%s", node_list[a].node_str);
636 | for (b = 0; b < layer1_size; b++) fprintf(fp, " %lf", syn0[a * layer1_size + b]);
637 | fprintf(fp, "\n");
638 | }
639 | free(neu1e);
640 | fclose(fp);
641 | }
642 |
643 | int ArgPos(char *str, int argc, char **argv)
644 | {
645 | int a;
646 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a]))
647 | {
648 | if (a == argc - 1)
649 | {
650 | printf("Argument missing for %s\n", str);
651 | exit(1);
652 | }
653 | return a;
654 | }
655 | return -1;
656 | }
657 |
658 | int main(int argc, char **argv)
659 | {
660 | int i;
661 | if (argc == 1)
662 | {
663 | printf("---randomWalk2vec---\n");
664 | printf("---The code and following instructions are built upon word2vec.c by Mikolov et al.---\n\n");
665 | printf("Options:\n");
666 | printf("Parameters for training:\n");
667 | printf("\t-train \n");
668 | printf("\t\tUse random walk sequences from to train the model\n");
669 | printf("\t-output \n");
670 | printf("\t\tUse to save the learned node vector-format representations\n");
671 | printf("\t-size \n");
672 | printf("\t\tSet the dimension of learned node representations; default is 128\n");
673 | printf("\t-window \n");
674 | printf("\t\tSet the window size for collecting node context pairs; default is 5\n");
675 | printf("\t-pp \n");
676 | printf("\t\tUse Heterogeneous Skip-Gram model (the ++ version) or not; default is 1 (++ version); otherwise, use 0 (Homogeneous Skip-Gram)\n");
677 | printf("\t-prefixes \n");
678 | printf("\t\tPrefixes of node Ids for specifying node types, e.g., ap with a for author and p for paper\n");
679 | printf("\t-objtype\n");
680 | printf("\t\tThe index of the objective node type in the prefixes list, for which representations to be learned\n");
681 | printf("\t-alpha \n");
682 | printf("\t\tSet the starting learning rate; default is 0.025\n");
683 | printf("\t-negative \n");
684 | printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10\n");
685 | printf("\t-samples \n");
686 | printf("\t\tSet the number of iterations for stochastic gradient descent as Million; default is 100\n");
687 | return 0;
688 | }
689 | if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
690 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
691 | if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
692 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
693 | if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
694 | if ((i = ArgPos((char *)"-pp", argc, argv)) > 0) pp = atoi(argv[i + 1]);
695 | if ((i = ArgPos((char *)"-prefixes", argc, argv)) > 0) strcpy(prefixes, argv[i + 1]);
696 | if ((i = ArgPos((char *)"-objtype", argc, argv)) > 0) obj_type = atoi(argv[i + 1]);
697 | if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
698 | if ((i = ArgPos((char *)"-samples", argc, argv)) >0) total_samples = atoi(argv[i + 1]);
699 | total_samples = total_samples * 1000000;
700 | type_num = strlen(prefixes);
701 | printf("Number of node types: %lld\n", type_num);
702 | node_list = (struct Node *)calloc(node_list_max_size, sizeof(struct Node));
703 | node_hash = (long long *)calloc(node_hash_size, sizeof(long long));
704 | expTable = (double *)malloc((EXP_TABLE_SIZE + 1) * sizeof(double));
705 | if (node_list == NULL || node_hash == NULL || expTable == NULL)
706 | {
707 | printf("Memory allocation failed\n");
708 | exit(1);
709 | }
710 | for (i = 0; i < EXP_TABLE_SIZE; i++)
711 | {
712 | expTable[i] = exp((i / (double)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
713 | expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1)
714 | }
715 | LearnNodeListFromTrainFile();
716 | GetNodeContextTable();
717 | InitAliasTable();
718 | TrainModel();
719 | return 0;
720 | }
721 |
--------------------------------------------------------------------------------
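One detail of randomWalk2vec.c worth drawing out: with -pp 1, NodeTypeNeg() splits the degree-weighted unigram table by node type, and the pp == 1 branch of TrainModel() draws every negative sample from the table matching the positive context node's type. The Java sketch below is a hedged re-expression of that idea; the class name, the count^0.75 slot allocation and the per-type map are assumptions standing in for the C arrays type_tables and type_counts.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Heterogeneous negative sampling: one unigram table per node type (keyed by
// the first character of the node id, e.g. 'a', 'p' or 'v'), so a negative
// sample always has the same type as the positive context node.
public class TypedNegativeSampler {
    private final Map<Character, List<String>> typeTables = new HashMap<Character, List<String>>();
    private final Random random = new Random();

    public TypedNegativeSampler(Map<String, Long> nodeCounts, int tableSize) {
        // Each node receives table slots proportional to count^0.75,
        // as in word2vec's InitUnigramTable().
        double sum = 0;
        for (long c : nodeCounts.values()) sum += Math.pow(c, 0.75);
        for (Map.Entry<String, Long> e : nodeCounts.entrySet()) {
            char type = e.getKey().charAt(0);
            List<String> table = typeTables.get(type);
            if (table == null) {
                table = new ArrayList<String>();
                typeTables.put(type, table);
            }
            int slots = (int) Math.max(1, Math.round(tableSize * Math.pow(e.getValue(), 0.75) / sum));
            for (int i = 0; i < slots; i++) table.add(e.getKey());
        }
    }

    // Draw a negative sample of the same type as the positive context node.
    public String sampleNegative(String contextNode) {
        List<String> table = typeTables.get(contextNode.charAt(0));
        return table.get(random.nextInt(table.size()));
    }
}

With -pp 0 the C code instead draws from one global table, so a paper context can be contrasted against author or venue negatives; keeping negatives type-matched is what makes the objective behave like a per-type softmax, i.e. the heterogeneous ("++") Skip-Gram.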