├── README.md
├── enhanced_graph_embedding_side_information
│   ├── .idea
│   │   ├── compiler.xml
│   │   ├── hydra.xml
│   │   ├── libraries
│   │   │   └── scala_sdk_2_11_8.xml
│   │   ├── misc.xml
│   │   ├── scala_compiler.xml
│   │   ├── uiDesigner.xml
│   │   └── workspace.xml
│   ├── pom.xml
│   └── src
│       └── main
│           └── scala
│               ├── eges
│               │   ├── embedding
│               │   │   ├── WeightedSkipGram.scala
│               │   │   └── WeightedSkipGramBatch.scala
│               │   └── random
│               │       ├── sample
│               │       │   └── Alias.scala
│               │       └── walk
│               │           ├── Attributes.scala
│               │           └── RandomWalk.scala
│               ├── example
│               │   ├── Example1.scala
│               │   ├── Example10.scala
│               │   ├── Example11.scala
│               │   ├── Example2.scala
│               │   ├── Example3.scala
│               │   ├── Example4.scala
│               │   ├── Example5.scala
│               │   ├── Example6.scala
│               │   ├── Example7.scala
│               │   ├── Example8.scala
│               │   └── Example9.scala
│               ├── main
│               │   ├── Main1.scala
│               │   ├── Main2.scala
│               │   └── Main3.scala
│               └── sparkapplication
│                   ├── BaseSparkLocal.scala
│                   └── BaseSparkOnline.scala
└── paper
    ├── [Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf
    ├── [Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf
    └── [Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf
/README.md:
--------------------------------------------------------------------------------
1 | # deepwalk_node2vector_eges
2 | Code implementations (Scala on Spark) of DeepWalk, node2vec, and Alibaba's paper "Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba" (EGES).
3 | The random walk and the embedding training are implemented as separate stages. In the code (a usage sketch follows below):
4 | 1. Main1 generates the random walks; the node2vec parameters p and q are configurable.
5 | 2. Main2 (full-data synchronous training) combines word2vec with the Alibaba paper: when the side-information length is 1 it is plain word2vec, and when it is greater than 1 it is the EGES model from the paper.
6 | 3. Main3 (multi-batch synchronous training) is the same as Main2, except that training is split into multiple batches.
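7 |
8 | ## Usage sketch
9 | A minimal, hedged sketch of how the two stages fit together (the real wiring lives in Main1/Main2/Main3; the toy data and parameter values below are illustrative, not taken from this repo):
10 | ```scala
11 | import org.apache.spark.sql.SparkSession
12 | import eges.random.walk.RandomWalk
13 | import eges.embedding.WeightedSkipGram
14 |
15 | val spark = SparkSession.builder().master("local[*]").appName("eges-demo").getOrCreate()
16 | val sc = spark.sparkContext
17 |
18 | // Stage 1 (Main1): weighted edges (srcNode, dstNode, weight) -> node2vec-style random walks.
19 | val edges = sc.parallelize(Seq(("A", "B", 1.0), ("B", "C", 2.0), ("C", "A", 1.0)))
20 | val walks = new RandomWalk()
21 |   .setNodeIterationNum(5)    // walks started per node
22 |   .setNodeWalkLength(10)     // maximum walk length
23 |   .setReturnParameter(1.0)   // node2vec p
24 |   .setInOutParameter(1.0)    // node2vec q
25 |   .setNumPartitions(4)
26 |   .fit(edges)                // RDD[(String, String)]
27 |
28 | // Stage 2 (Main2): walks joined with side information -> EGES embeddings.
29 | // Each sample is one walk; each node is an Array(nodeId, sideInfo1, ...). With
30 | // side-information length 1 this is plain word2vec. The tiny hand-made walks below
31 | // stand in for the join of the Stage-1 output with real side information.
32 | val walksWithSideInfo = sc.parallelize(Seq(
33 |   Array(Array("A", "cat1"), Array("B", "cat1"), Array("C", "cat2")),
34 |   Array(Array("B", "cat1"), Array("C", "cat2"), Array("A", "cat1"))
35 | ))
36 | val embeddings = new WeightedSkipGram()
37 |   .setVectorSize(8)
38 |   .setWindowSize(2)
39 |   .setNegativeSampleNum(2)
40 |   .setIterationNum(2)
41 |   .setNumPartitions(4)
42 |   .setSampleTableNum(100000) // keep the negative-sampling table small for a toy run
43 |   .fit(walksWithSideInfo)    // RDD[(String, String)]: key -> "@"-joined vector
44 | ```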
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>embedding.eges</groupId>
    <artifactId>enhanced_graph_embedding_side_information</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <scala.binary.version>2.11</scala.binary.version>
        <scala.version>${scala.binary.version}.8</scala.version>
        <spark.version>2.1.0</spark.version>
        <spark.scope>compile</spark.scope>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-compiler</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
            <scope>${spark.scope}</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
            <scope>${spark.scope}</scope>
        </dependency>
        <dependency>
            <groupId>commons-configuration</groupId>
            <artifactId>commons-configuration</artifactId>
            <version>1.9</version>
        </dependency>
        <dependency>
            <groupId>commons-lang</groupId>
            <artifactId>commons-lang</artifactId>
            <version>2.5</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
            <scope>${spark.scope}</scope>
        </dependency>
        <dependency>
            <groupId>org.scalanlp</groupId>
            <artifactId>breeze_2.11</artifactId>
            <version>0.13.2</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>sit</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <build>
                <filters>
                    <filter>../${project.artifactId}/vars/vars.sit.properties</filter>
                </filters>
                <resources>
                    <resource>
                        <directory>src/main/resources</directory>
                        <filtering>true</filtering>
                    </resource>
                </resources>
            </build>
        </profile>
        <profile>
            <id>prod</id>
            <build>
                <filters>
                    <filter>../${project.artifactId}/vars/vars.prod.properties</filter>
                </filters>
                <resources>
                    <resource>
                        <directory>src/main/resources</directory>
                        <filtering>true</filtering>
                    </resource>
                </resources>
            </build>
        </profile>
    </profiles>

    <build>
        <plugins>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>build-helper-maven-plugin</artifactId>
                <version>1.8</version>
                <executions>
                    <execution>
                        <id>add-source</id>
                        <phase>generate-sources</phase>
                        <goals>
                            <goal>add-source</goal>
                        </goals>
                        <configuration>
                            <sources>
                                <source>src/main/scala</source>
                                <source>src/test/scala</source>
                            </sources>
                        </configuration>
                    </execution>
                    <execution>
                        <id>add-test-source</id>
                        <phase>generate-sources</phase>
                        <goals>
                            <goal>add-test-source</goal>
                        </goals>
                        <configuration>
                            <sources>
                                <source>src/test/scala</source>
                            </sources>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.1.5</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <scalaVersion>${scala.version}</scalaVersion>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.7</source>
                    <target>1.7</target>
                    <encoding>utf-8</encoding>
                </configuration>
                <executions>
                    <execution>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                    <archive>
                        <manifest>
                            <mainClass>example.Example5</mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/eges/embedding/WeightedSkipGram.scala:
--------------------------------------------------------------------------------
1 | package eges.embedding
2 |
3 | import java.util.Random
4 | import scala.collection.mutable
5 | import org.apache.spark.rdd.RDD
6 | import breeze.linalg.DenseVector
7 | import scala.collection.mutable.ArrayBuffer
8 | import org.apache.spark.storage.StorageLevel
9 |
10 | class WeightedSkipGram extends Serializable {
11 |
12 | var vectorSize = 4
13 | var windowSize = 2
14 | var negativeSampleNum = 1
15 | var perSampleMaxNum = 10
16 | var subSample: Double = 0.0
17 | var learningRate: Double = 0.02
18 | var iterationNum = 5
19 | var isNoShowLoss = false
20 | var numPartitions = 200
21 | var nodeSideCount: Int = 0
22 | var nodeSideTable: Array[(String, Int)] = _
23 | var nodeSideHash = mutable.HashMap.empty[String, Int]
24 | var sampleTableNum: Int = 1e8.toInt
25 | lazy val sampleTable: Array[Int] = new Array[Int](sampleTableNum)
26 |
27 | /**
28 | * Embedding dimension of each node
29 | *
30 | * */
31 | def setVectorSize(value:Int):this.type = {
32 | require(value > 0, s"vectorSize must be more than 0, but it is $value")
33 | vectorSize = value
34 | this
35 | }
36 |
37 | /**
38 | * Window size
39 | *
40 | * */
41 | def setWindowSize(value:Int):this.type = {
42 | require(value > 0, s"windowSize must be more than 0, but it is $value")
43 | windowSize = value
44 | this
45 | }
46 |
47 | /**
48 | * Number of negative samples
49 | *
50 | * */
51 | def setNegativeSampleNum(value:Int):this.type = {
52 | require(value > 0, s"negativeSampleNum must be more than 0, but it is $value")
53 | negativeSampleNum = value
54 | this
55 | }
56 |
57 | /**
58 | * Maximum number of draws allowed when sampling one negative example
59 | *
60 | * */
61 | def setPerSampleMaxNum(value:Int):this.type = {
62 | require(value > 0, s"perSampleMaxNum must be more than 0, but it is $value")
63 | perSampleMaxNum = value
64 | this
65 | }
66 |
67 | /**
68 | * Sub-sampling rate for high-frequency nodes
69 | *
70 | * */
71 | def setSubSample(value:Double):this.type = {
72 | require(value >= 0.0 && value <= 1.0, s"subSample must be not less than 0.0 and not more than 1.0, but it is $value")
73 | subSample = value
74 | this
75 | }
76 |
77 | /**
78 | * Initial learning rate
79 | *
80 | * */
81 | def setLearningRate(value:Double):this.type = {
82 | require(value > 0.0, s"learningRate must be more than 0.0, but it is $value")
83 | learningRate = value
84 | this
85 | }
86 |
87 | /**
88 | * Number of iterations
89 | *
90 | * */
91 | def setIterationNum(value:Int):this.type = {
92 | require(value > 0, s"iterationNum must be more than 0, but it is $value")
93 | iterationNum = value
94 | this
95 | }
96 |
97 | /**
98 | * Whether to compute and print the loss; not recommended, since it adds extra computation
99 | *
100 | * */
101 | def setIsNoShowLoss(value:Boolean):this.type = {
102 | isNoShowLoss = value
103 | this
104 | }
105 |
106 | /**
108 | * Number of partitions
108 | *
109 | * */
110 | def setNumPartitions(value:Int):this.type = {
111 | require(value > 0, s"numPartitions must be more than 0, but it is $value")
112 | numPartitions = value
113 | this
114 | }
115 |
116 | /**
118 | * Size of the negative-sampling table
118 | *
119 | * */
120 | def setSampleTableNum(value:Int):this.type = {
121 | require(value > 0, s"sampleTableNum must be more than 0, but it is $value")
122 | sampleTableNum = value
123 | this
124 | }
125 |
126 | /**
128 | * Training
128 | *
129 | * */
130 | def fit(dataSet: RDD[Array[Array[String]]]): RDD[(String, String)] = {
131 | // Repartition the data
132 | val dataSetRepartition = dataSet.map(k => ((new Random).nextInt(numPartitions), k))
133 | .repartition(numPartitions).map(k => k._2)
134 | dataSetRepartition.persist(StorageLevel.MEMORY_AND_DISK)
135 | val sc = dataSetRepartition.context
136 | // Initialise WeightedSkipGram
137 | initWeightedSkipGram(dataSetRepartition)
138 | // Side-information length and embedding size
139 | val sideInfoNum = dataSetRepartition.first().head.length
140 | val sideInfoNumBroadcast = sc.broadcast(sideInfoNum)
141 | val vectorSizeBroadcast = sc.broadcast(vectorSize)
142 | // Initialise the node weights
143 | val nodeWeight = dataSetRepartition.flatMap(x => x)
144 | .map(nodeInfo => nodeInfo.head + "#Weight").distinct()
145 | .map(nodeWeight => (nodeWeight, Array.fill[Double](sideInfoNumBroadcast.value)((new Random).nextGaussian())))
146 | // Initialise the node Zu (context) vectors
147 | val nodeZu = dataSetRepartition.flatMap(x => x)
148 | .map(nodeInfo => nodeInfo.head + "#Zu").distinct()
149 | .map(nodeZu => (nodeZu, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) ))
150 | // Initialise the node and side-information vectors
151 | val nodeAndSideVector = dataSetRepartition.flatMap(x => x).flatMap(x => x).distinct()
152 | .map(nodeAndSide => (nodeAndSide, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) ))
153 | // Combined initial table of node weights, node and side-information vectors
154 | val embeddingHashInit = mutable.HashMap(nodeWeight.union(nodeZu).union(nodeAndSideVector).collectAsMap().toList:_*)
155 | // println("embedding initialisation:")
156 | // embeddingHashInit.map(k => (k._1, k._2.mkString("@"))).foreach(println)
157 |
158 | // Broadcast some of the variables
159 | val nodeSideTableBroadcast = sc.broadcast(nodeSideTable)
160 | val nodeSideHashBroadcast = sc.broadcast(nodeSideHash)
161 | val sampleTableBroadcast = sc.broadcast(sampleTable)
162 | val subSampleBroadcast = sc.broadcast(subSample)
163 | val windowSizeBroadcast = sc.broadcast(windowSize)
164 | val negativeSampleNumBroadcast = sc.broadcast(negativeSampleNum)
165 | val perSampleMaxNumBroadcast = sc.broadcast(perSampleMaxNum)
166 | val sampleTableNumBroadcast = sc.broadcast(sampleTableNum)
167 | val nodeSideCountBroadcast = sc.broadcast(nodeSideCount)
168 | val isNoShowLossBroadcast = sc.broadcast(isNoShowLoss)
169 | // val lambdaBroadcast = sc.broadcast(learningRate)
170 |
171 | // Training loop
172 | var embeddingHashBroadcast = sc.broadcast(embeddingHashInit)
173 | for(iter <- 0 until iterationNum){
174 | val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= learningRate/5) learningRate/(1+iter) else learningRate/5)
175 | // val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= 0.000001) learningRate/(1+iter) else 0.000001)
176 | // val lambdaBroadcast = sc.broadcast(Array(0.5, 0.4, 0.2, 0.1, 0.05, 0.05)(iter))
177 | // Training result (per-partition gradient sums)
178 | val trainResult = dataSetRepartition.mapPartitions(dataIterator =>
179 | trainWeightedSkipGram(dataIterator, embeddingHashBroadcast.value, nodeSideTableBroadcast.value,
180 | nodeSideHashBroadcast.value, sampleTableBroadcast.value, subSampleBroadcast.value, vectorSizeBroadcast.value,
181 | sideInfoNumBroadcast.value, windowSizeBroadcast.value, negativeSampleNumBroadcast.value, perSampleMaxNumBroadcast.value,
182 | sampleTableNumBroadcast.value, nodeSideCountBroadcast.value, isNoShowLossBroadcast.value))
183 | trainResult.persist(StorageLevel.MEMORY_AND_DISK)
184 | // Aggregate the gradients of the node weights
185 | val trainWeightResult = trainResult.filter(k => k._1.endsWith("#Weight") && !k._1.equals("LossValue"))
186 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](sideInfoNumBroadcast.value)(0.0)), 0L))(
187 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
188 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) )
189 | trainWeightResult.persist(StorageLevel.MEMORY_AND_DISK)
190 | // Aggregate the gradients of the node and side-information embeddings
191 | val trainVectorResult = trainResult.filter(k => !k._1.endsWith("#Weight") && !k._1.equals("LossValue"))
192 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](vectorSizeBroadcast.value)(0.0)), 0L))(
193 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
194 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) )
195 | trainVectorResult.persist(StorageLevel.MEMORY_AND_DISK)
196 | // Compute the average loss
197 | if(isNoShowLoss){
198 | val trainLossResult = trainResult.filter(k => k._1.equals("LossValue"))
199 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](1)(0.0)), 0L))(
200 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
201 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2)
202 | ).map(k => (k._1, k._2._1/k._2._2.toDouble))
203 | .map(k => (k._1, k._2.toArray.head)).first()._2
204 | println(s"====Iteration ${iter+1}====Average loss: $trainLossResult====")
205 | }
206 | val trainNodeSideResultUnion = trainWeightResult.union(trainVectorResult)
207 | .map{ case (key, (gradientVectorSum, num)) =>
208 | val embeddingRaw = new DenseVector[Double](embeddingHashBroadcast.value(key))
209 | val embeddingUpdate = embeddingRaw - 100 * lambdaBroadcast.value * gradientVectorSum / num.toDouble
210 | (key, embeddingUpdate.toArray)
211 | }
212 | var trainSideVectorResultMap = mutable.HashMap(trainNodeSideResultUnion.collectAsMap().toList:_*)
213 | trainSideVectorResultMap = embeddingHashBroadcast.value.++(trainSideVectorResultMap)
214 | embeddingHashBroadcast = sc.broadcast(trainSideVectorResultMap)
215 | trainResult.unpersist()
216 | trainWeightResult.unpersist()
217 | trainVectorResult.unpersist()
218 | }
219 | dataSetRepartition.unpersist()
220 |
221 | // Weighted aggregated vector of each node
222 | var embeddingHashResult = embeddingHashBroadcast.value
223 | val nodeHvArray = nodeSideTable.map(k => k._1.split("@"))
224 | .map(nodeInfo => {
225 | val node = nodeInfo.head
226 | var eavSum = 0.0
227 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0))
228 | val keyEA = node + "#Weight"
229 | val valueEA = embeddingHashResult.getOrElse(keyEA, Array.fill[Double](sideInfoNum)(0.0))
230 | for(n <- 0 until sideInfoNum) {
231 | val keyNodeOrSide = nodeInfo(n)
232 | val valueNodeOrSide = embeddingHashResult.getOrElse(keyNodeOrSide, Array.fill[Double](vectorSize)(0.0))
233 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(n)) * new DenseVector[Double](valueNodeOrSide)
234 | eavSum = eavSum + math.exp(valueEA(n))
235 | }
236 | val Hv = (eavWvAggSum/eavSum).toArray
237 | (nodeInfo.mkString("#"), Hv)
238 | })
239 | val nodeHvMap = mutable.Map(nodeHvArray:_*)
240 | embeddingHashResult = embeddingHashResult.++(nodeHvMap)
241 | val embeddingResult = sc.parallelize(embeddingHashResult.map(k => (k._1, k._2.mkString("@"))).toList, numPartitions)
242 | embeddingResult
243 | }
244 |
245 | /**
246 | * Initialisation
247 | *
248 | * */
249 | def initWeightedSkipGram(dataSet: RDD[Array[Array[String]]]): Unit = {
250 | // Frequency table of nodes and side information
251 | nodeSideTable = dataSet.flatMap(x => x)
252 | .map(nodeInfo => (nodeInfo.mkString("@"), 1))
253 | .reduceByKey(_ + _)
254 | .collect().sortBy(-_._2)
255 | // Total count of nodes and side information
256 | nodeSideCount = nodeSideTable.map(_._2).par.sum
257 | // Frequency hash map of nodes and side information
258 | nodeSideHash = mutable.HashMap(nodeSideTable:_*)
259 | // Number of distinct node / side-information entries
260 | val nodeSideSize = nodeSideTable.length
261 | // Negative-sampling table initialisation
262 | var a = 0
263 | val power = 0.75
264 | var nodesPow = 0.0
265 | while (a < nodeSideSize) {
266 | nodesPow += Math.pow(nodeSideTable(a)._2, power)
267 | a = a + 1
268 | }
269 | var b = 0
270 | var freq = Math.pow(nodeSideTable(b)._2, power) / nodesPow
271 | var c = 0
272 | while (c < sampleTableNum) {
273 | sampleTable(c) = b
274 | if ((c.toDouble + 1.0) / sampleTableNum >= freq) {
275 | b = b + 1
276 | if (b >= nodeSideSize) {
277 | b = nodeSideSize - 1
278 | }
279 | freq += Math.pow(nodeSideTable(b)._2, power) / nodesPow
280 | }
281 | c = c + 1
282 | }
283 | }
284 |
285 | /**
286 | * Training inside one partition
287 | *
288 | * */
289 | def trainWeightedSkipGram(dataSet: Iterator[Array[Array[String]]],
290 | embeddingHash: mutable.HashMap[String, Array[Double]],
291 | nodeSideTable: Array[(String, Int)],
292 | nodeSideHash: mutable.HashMap[String, Int],
293 | sampleTable: Array[Int],
294 | subSample: Double,
295 | vectorSize: Int,
296 | sideInfoNum: Int,
297 | windowSize: Int,
298 | negativeSampleNum: Int,
299 | perSampleMaxNum: Int,
300 | sampleTableNum: Int,
301 | nodeSideCount: Int,
302 | isNoShowLoss: Boolean): Iterator[(String, (Array[Double], Long))] = {
303 | val gradientUpdateHash = new mutable.HashMap[String, (Array[Double], Long)]()
304 | var lossSum = 0.0
305 | var lossNum = 0L
306 | for(data <- dataSet) {
307 | // Sub-sample high-frequency nodes
308 | val dataSubSample = ArrayBuffer[Array[String]]()
309 | for( nodeOrSideInfoArray <- data){
310 | if(subSample > 0.0) {
311 | val nodeOrSideInfoFrequency = nodeSideHash(nodeOrSideInfoArray.mkString("@"))
312 | val keepProbability = (Math.sqrt(nodeOrSideInfoFrequency/ (subSample * nodeSideCount)) + 1.0) * (subSample * nodeSideCount) / nodeOrSideInfoFrequency
313 | if (keepProbability >= (new Random).nextDouble()) {
314 | dataSubSample.append(nodeOrSideInfoArray)
315 | }
316 | } else {
317 | dataSubSample.append(nodeOrSideInfoArray)
318 | }
319 | }
320 | val dssl = dataSubSample.length
321 | for(i <- 0 until dssl){
322 | val mainNodeInfo = dataSubSample(i)
323 | val mainNode = mainNodeInfo.head
324 | var eavSum = 0.0
325 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0))
326 | val keyEA = mainNode + "#Weight"
327 | val valueEA = embeddingHash(keyEA)
328 | for(k <- 0 until sideInfoNum) {
329 | val keyNodeOrSide = mainNodeInfo(k)
330 | val valueNodeOrSide = embeddingHash(keyNodeOrSide)
331 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(k)) * new DenseVector[Double](valueNodeOrSide)
332 | eavSum = eavSum + math.exp(valueEA(k))
333 | }
334 | val Hv = eavWvAggSum/eavSum
335 | // Set of positive samples inside the window of the centre node
336 | val mainSlaveNodeSet = dataSubSample.slice(math.max(0, i-windowSize), math.min(i+windowSize, dssl-1) + 1)
337 | .map(array => array.head).toSet
338 | for(j <- math.max(0, i-windowSize) to math.min(i+windowSize, dssl-1)){
339 | if(j != i) {
340 | // Train on the positive sample
341 | val slaveNodeInfo = dataSubSample(j)
342 | val slaveNode = slaveNodeInfo.head
343 | embeddingUpdate(mainNodeInfo, slaveNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 1.0, vectorSize, sideInfoNum)
344 | // Loss of the positive sample
345 | if(isNoShowLoss){
346 | val keySlaveNode = slaveNode + "#Zu"
347 | val valueSlaveNode = embeddingHash(keySlaveNode)
348 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode)
349 | val logarithmRaw = math.log(1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))))
350 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0)
351 | lossSum = lossSum - logarithm
352 | lossNum = lossNum + 1L
353 | }
354 | // Negative sampling and training
355 | for(_ <- 0 until negativeSampleNum){
356 | // Draw a negative sample
357 | var sampleNodeInfo = slaveNodeInfo
358 | var sampleNode = slaveNode
359 | var sampleNum = 0
360 | while(mainSlaveNodeSet.contains(sampleNode) && sampleNum < perSampleMaxNum) {
361 | val index = sampleTable((new Random).nextInt(sampleTableNum))
362 | sampleNodeInfo = nodeSideTable(index)._1.split("@")
363 | sampleNode = sampleNodeInfo.head
364 | sampleNum = sampleNum + 1
365 | }
366 | // Train on the negative sample
367 | embeddingUpdate(mainNodeInfo, sampleNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 0.0, vectorSize, sideInfoNum)
368 | // Loss of the negative sample
369 | if(isNoShowLoss){
370 | val keySampleNode = sampleNodeInfo.head + "#Zu"
371 | val valueSampleNode = embeddingHash(keySampleNode)
372 | val HvZu = Hv.t * new DenseVector[Double](valueSampleNode)
373 | val logarithmRaw = math.log(1.0 - 1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))))
374 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0)
375 | lossSum = lossSum - logarithm
376 | lossNum = lossNum + 1L
377 | }
378 | }
379 | }
380 | }
381 | }
382 | }
383 | if(isNoShowLoss){
384 | gradientUpdateHash.put("LossValue", (Array(lossSum), lossNum))
385 | }
386 | gradientUpdateHash.toIterator
387 | }
388 |
389 | /**
390 | * Gradient update
391 | *
392 | * */
393 | def embeddingUpdate(mainNodeInfo: Array[String],
394 | slaveNodeInfo: Array[String],
395 | embeddingHash: mutable.HashMap[String, Array[Double]],
396 | gradientUpdateHash: mutable.HashMap[String, (Array[Double], Long)],
397 | mainNodeWeightRaw: Array[Double],
398 | eavWvAggSum: DenseVector[Double],
399 | eavSum: Double,
400 | Hv: DenseVector[Double],
401 | label: Double,
402 | vectorSize: Int,
403 | sideInfoNum: Int): Unit = {
404 | val keySlaveNode = slaveNodeInfo.head + "#Zu"
405 | val valueSlaveNode = embeddingHash(keySlaveNode)
406 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode)
407 | // Gradient of the context node's Zu vector
408 | val gradientZu = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * Hv
409 | var gradientZuSum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._1
410 | gradientZuSum = (new DenseVector[Double](gradientZuSum) + gradientZu).toArray
411 | var gradientZuNum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._2
412 | gradientZuNum = gradientZuNum + 1L
413 | gradientUpdateHash.put(keySlaveNode, (gradientZuSum, gradientZuNum))
414 | // Gradients of the centre node's side-information weights and vectors
415 | val gradientHv = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * new DenseVector[Double](valueSlaveNode)
416 | val mainNodeWeightGradientSum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._1
417 | var mainNodeWeightGradientNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._2
418 | for(m <- 0 until sideInfoNum) {
419 | val wsvVectorRaw = new DenseVector[Double](embeddingHash(mainNodeInfo(m)))
420 | // Gradient of the centre node's side-information weight
421 | val gradientAsv = gradientHv.t * (eavSum * math.exp(mainNodeWeightRaw(m)) * wsvVectorRaw - math.exp(mainNodeWeightRaw(m)) * eavWvAggSum)/(eavSum * eavSum)
422 | mainNodeWeightGradientSum(m) = mainNodeWeightGradientSum(m) + gradientAsv
423 | // Gradient of the centre node's side-information vector
424 | var gradientWsvSum = new DenseVector(gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._1)
425 | var gradientWsvNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._2
426 | gradientWsvSum = gradientWsvSum + math.exp(mainNodeWeightRaw(m)) * gradientHv / eavSum
427 | gradientWsvNum = gradientWsvNum + 1L
428 | gradientUpdateHash.put(mainNodeInfo(m), (gradientWsvSum.toArray, gradientWsvNum))
429 | }
430 | mainNodeWeightGradientNum = mainNodeWeightGradientNum + 1L
431 | gradientUpdateHash.put(mainNodeInfo.head + "#Weight", (mainNodeWeightGradientSum, mainNodeWeightGradientNum))
432 | }
433 | }
--------------------------------------------------------------------------------
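For reference, the per-node aggregation computed in `fit` and `trainWeightedSkipGram` above is the EGES weighted average H_v = (sum_j e^{a_v^j} * W_v^j) / (sum_j e^{a_v^j}), i.e. a softmax over the trainable weights stored under the `#Weight` keys. A minimal local sketch of just that step (plain Scala with Breeze, no Spark; the numbers are illustrative):

```scala
import breeze.linalg.DenseVector

object HvSketch {
  def main(args: Array[String]): Unit = {
    // a_v: one trainable weight per information source (the node itself, category, brand, ...).
    val weights = Array(0.2, -0.1, 0.4)
    // W_v: the corresponding embedding vectors (vectorSize = 4 here).
    val vectors = Array(
      DenseVector(0.1, 0.2, 0.3, 0.4),
      DenseVector(0.0, 0.1, 0.0, 0.1),
      DenseVector(0.3, 0.0, 0.2, 0.1)
    )
    val eav         = weights.map(math.exp)                                        // e^{a_v}, as in valueEA
    val eavSum      = eav.sum                                                      // eavSum in fit()
    val eavWvAggSum = vectors.zip(eav).map { case (w, e) => w * e }.reduce(_ + _)  // eavWvAggSum in fit()
    val hv          = eavWvAggSum / eavSum                                         // H_v, the aggregated embedding
    println(hv.toArray.mkString("@"))                                              // "@"-joined, like the fit() output
  }
}
```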
/enhanced_graph_embedding_side_information/src/main/scala/eges/embedding/WeightedSkipGramBatch.scala:
--------------------------------------------------------------------------------
1 | package eges.embedding
2 |
3 | import java.util.Random
4 | import scala.collection.mutable
5 | import org.apache.spark.rdd.RDD
6 | import breeze.linalg.DenseVector
7 | import scala.collection.mutable.ArrayBuffer
8 | import org.apache.spark.storage.StorageLevel
9 |
10 | class WeightedSkipGramBatch extends Serializable {
11 |
12 | var vectorSize = 4
13 | var windowSize = 2
14 | var negativeSampleNum = 1
15 | var perSampleMaxNum = 10
16 | var subSample: Double = 0.0
17 | var learningRate: Double = 0.02
18 | var batchNum = 10
19 | var iterationNum = 5
20 | var isNoShowLoss = false
21 | var numPartitions = 200
22 | var nodeSideCount: Int = 0
23 | var nodeSideTable: Array[(String, Int)] = _
24 | var nodeSideHash = mutable.HashMap.empty[String, Int]
25 | var sampleTableNum: Int = 1e8.toInt
26 | lazy val sampleTable: Array[Int] = new Array[Int](sampleTableNum)
27 |
28 | /**
29 | * Embedding dimension of each node
30 | *
31 | * */
32 | def setVectorSize(value:Int):this.type = {
33 | require(value > 0, s"vectorSize must be more than 0, but it is $value")
34 | vectorSize = value
35 | this
36 | }
37 |
38 | /**
39 | * Window size
40 | *
41 | * */
42 | def setWindowSize(value:Int):this.type = {
43 | require(value > 0, s"windowSize must be more than 0, but it is $value")
44 | windowSize = value
45 | this
46 | }
47 |
48 | /**
49 | * Number of negative samples
50 | *
51 | * */
52 | def setNegativeSampleNum(value:Int):this.type = {
53 | require(value > 0, s"negativeSampleNum must be more than 0, but it is $value")
54 | negativeSampleNum = value
55 | this
56 | }
57 |
58 | /**
59 | * Maximum number of draws allowed when sampling one negative example
60 | *
61 | * */
62 | def setPerSampleMaxNum(value:Int):this.type = {
63 | require(value > 0, s"perSampleMaxNum must be more than 0, but it is $value")
64 | perSampleMaxNum = value
65 | this
66 | }
67 |
68 | /**
69 | * Sub-sampling rate for high-frequency nodes
70 | *
71 | * */
72 | def setSubSample(value:Double):this.type = {
73 | require(value >= 0.0 && value <= 1.0, s"subSample must be not less than 0.0 and not more than 1.0, but it is $value")
74 | subSample = value
75 | this
76 | }
77 |
78 | /**
79 | * Initial learning rate
80 | *
81 | * */
82 | def setLearningRate(value:Double):this.type = {
83 | require(value > 0.0, s"learningRate must be more than 0.0, but it is $value")
84 | learningRate = value
85 | this
86 | }
87 |
88 | /**
89 | * Number of batches of the data set
90 | *
91 | * */
92 | def setBatchNum(value:Int):this.type = {
93 | require(value > 0, s"batchNum must be more than 0, but it is $value")
94 | batchNum = value
95 | this
96 | }
97 |
98 | /**
99 | * Number of iterations
100 | *
101 | * */
102 | def setIterationNum(value:Int):this.type = {
103 | require(value > 0, s"iterationNum must be more than 0, but it is $value")
104 | iterationNum = value
105 | this
106 | }
107 |
108 | /**
109 | * Whether to compute and print the loss; not recommended, since it adds extra computation
110 | *
111 | * */
112 | def setIsNoShowLoss(value:Boolean):this.type = {
113 | isNoShowLoss = value
114 | this
115 | }
116 |
117 | /**
118 | * Number of partitions
119 | *
120 | * */
121 | def setNumPartitions(value:Int):this.type = {
122 | require(value > 0, s"numPartitions must be more than 0, but it is $value")
123 | numPartitions = value
124 | this
125 | }
126 |
127 | /**
128 | * Size of the negative-sampling table
129 | *
130 | * */
131 | def setSampleTableNum(value:Int):this.type = {
132 | require(value > 0, s"sampleTableNum must be more than 0, but it is $value")
133 | sampleTableNum = value
134 | this
135 | }
136 |
137 | /**
138 | * Training
139 | *
140 | * */
141 | def fit(dataSet: RDD[Array[Array[String]]]): RDD[(String, String)] = {
142 | // Repartition the data
143 | val dataSetRepartition = dataSet.map(k => ((new Random).nextInt(numPartitions), k))
144 | .repartition(numPartitions).map(k => k._2)
145 | dataSetRepartition.persist(StorageLevel.MEMORY_AND_DISK)
146 | val sc = dataSetRepartition.context
147 | // Initialise WeightedSkipGram
148 | initWeightedSkipGram(dataSetRepartition)
149 | // Side-information length and embedding size
150 | val sideInfoNum = dataSetRepartition.first().head.length
151 | val sideInfoNumBroadcast = sc.broadcast(sideInfoNum)
152 | val vectorSizeBroadcast = sc.broadcast(vectorSize)
153 | // Initialise the node weights
154 | val nodeWeight = dataSetRepartition.flatMap(x => x)
155 | .map(nodeInfo => nodeInfo.head + "#Weight").distinct()
156 | .map(nodeWeight => (nodeWeight, Array.fill[Double](sideInfoNumBroadcast.value)((new Random).nextGaussian())))
157 | // Initialise the node Zu (context) vectors
158 | val nodeZu = dataSetRepartition.flatMap(x => x)
159 | .map(nodeInfo => nodeInfo.head + "#Zu").distinct()
160 | .map(nodeZu => (nodeZu, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) ))
161 | // Initialise the node and side-information vectors
162 | val nodeAndSideVector = dataSetRepartition.flatMap(x => x).flatMap(x => x).distinct()
163 | .map(nodeAndSide => (nodeAndSide, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) ))
164 | // Combined initial table of node weights, node and side-information vectors
165 | val embeddingHashInit = mutable.HashMap(nodeWeight.union(nodeZu).union(nodeAndSideVector).collectAsMap().toList:_*)
166 | // println("embedding initialisation:")
167 | // embeddingHashInit.map(k => (k._1, k._2.mkString("@"))).foreach(println)
168 |
169 | // Broadcast some of the variables
170 | val nodeSideTableBroadcast = sc.broadcast(nodeSideTable)
171 | val nodeSideHashBroadcast = sc.broadcast(nodeSideHash)
172 | val sampleTableBroadcast = sc.broadcast(sampleTable)
173 | val subSampleBroadcast = sc.broadcast(subSample)
174 | val windowSizeBroadcast = sc.broadcast(windowSize)
175 | val negativeSampleNumBroadcast = sc.broadcast(negativeSampleNum)
176 | val perSampleMaxNumBroadcast = sc.broadcast(perSampleMaxNum)
177 | val sampleTableNumBroadcast = sc.broadcast(sampleTableNum)
178 | val nodeSideCountBroadcast = sc.broadcast(nodeSideCount)
179 | val isNoShowLossBroadcast = sc.broadcast(isNoShowLoss)
180 | // val lambdaBroadcast = sc.broadcast(learningRate)
181 |
182 | // Split the data into batches
183 | val batchWeightArray = Array.fill(batchNum)(1.0/batchNum)
184 | val dataSetBatch = dataSetRepartition.randomSplit(batchWeightArray)
185 | val perBatchCountArray = Array.fill[Int](batchNum)(0)
186 | for(batchA <- 0 until batchNum){
187 | val batchCount = dataSetBatch(batchA).count().toInt
188 | perBatchCountArray(batchA) = batchCount
189 | if(batchCount > 0) dataSetBatch(batchA).persist(StorageLevel.MEMORY_AND_DISK)
190 | }
191 | dataSetRepartition.unpersist()
192 | // Training loop
193 | var embeddingHashBroadcast = sc.broadcast(embeddingHashInit)
194 | for(iter <- 0 until iterationNum){
195 | val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= learningRate/5) learningRate/(1+iter) else learningRate/5)
196 | // val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= 0.000001) learningRate/(1+iter) else 0.000001)
197 | // val lambdaBroadcast = sc.broadcast(Array(0.5, 0.4, 0.2, 0.1, 0.05, 0.05)(iter))
198 | val perBatchLossArray = ArrayBuffer[Double]()
199 | for(batch <- 0 until batchNum){
200 | if(perBatchCountArray(batch) > 0) {
201 | // Training result (per-partition gradient sums)
202 | val trainResult = dataSetBatch(batch).map(k => ((new Random).nextInt(numPartitions), k))
203 | .repartition(numPartitions).map(k => k._2).mapPartitions(dataIterator =>
204 | trainWeightedSkipGram(dataIterator, embeddingHashBroadcast.value, nodeSideTableBroadcast.value,
205 | nodeSideHashBroadcast.value, sampleTableBroadcast.value, subSampleBroadcast.value, vectorSizeBroadcast.value,
206 | sideInfoNumBroadcast.value, windowSizeBroadcast.value, negativeSampleNumBroadcast.value, perSampleMaxNumBroadcast.value,
207 | sampleTableNumBroadcast.value, nodeSideCountBroadcast.value, isNoShowLossBroadcast.value))
208 | trainResult.persist(StorageLevel.MEMORY_AND_DISK)
209 | // Aggregate the gradients of the node weights
210 | val trainWeightResult = trainResult.filter(k => k._1.endsWith("#Weight") && !k._1.equals("LossValue"))
211 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](sideInfoNumBroadcast.value)(0.0)), 0L))(
212 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
213 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) )
214 | trainWeightResult.persist(StorageLevel.MEMORY_AND_DISK)
215 | // Aggregate the gradients of the node and side-information embeddings
216 | val trainVectorResult = trainResult.filter(k => !k._1.endsWith("#Weight") && !k._1.equals("LossValue"))
217 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](vectorSizeBroadcast.value)(0.0)), 0L))(
218 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
219 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) )
220 | trainVectorResult.persist(StorageLevel.MEMORY_AND_DISK)
221 | // Compute the average loss of this batch in this iteration
222 | if(isNoShowLoss){
223 | val trainLossResult = trainResult.filter(k => k._1.equals("LossValue"))
224 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](1)(0.0)), 0L))(
225 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2),
226 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2)
227 | ).map(k => (k._1, k._2._1/k._2._2.toDouble))
228 | .map(k => (k._1, k._2.toArray.head)).first()._2
229 | perBatchLossArray.append(trainLossResult)
230 | println(s"====Iteration ${iter+1}====Batch ${batch+1}====Average loss: $trainLossResult====")
231 | }
232 | val trainNodeSideResultUnion = trainWeightResult.union(trainVectorResult)
233 | .map{ case (key, (gradientVectorSum, num)) =>
234 | val embeddingRaw = new DenseVector[Double](embeddingHashBroadcast.value(key))
235 | val embeddingUpdate = embeddingRaw - 10 * lambdaBroadcast.value * gradientVectorSum / num.toDouble
236 | (key, embeddingUpdate.toArray)
237 | }
238 | var trainSideVectorResultMap = mutable.HashMap(trainNodeSideResultUnion.collectAsMap().toList:_*)
239 | trainSideVectorResultMap = embeddingHashBroadcast.value.++(trainSideVectorResultMap)
240 | embeddingHashBroadcast = sc.broadcast(trainSideVectorResultMap)
241 | trainResult.unpersist()
242 | trainWeightResult.unpersist()
243 | trainVectorResult.unpersist()
244 | }
245 | }
246 | // Compute the average loss of this iteration
247 | val batchLossLength = perBatchLossArray.length
248 | if(isNoShowLoss && batchLossLength > 0){
249 | println(s"====Iteration ${iter+1}====Average loss: ${perBatchLossArray.sum/batchLossLength}====")
250 | }
251 | }
252 | for(batchB <- 0 until batchNum){
253 | if(perBatchCountArray(batchB) > 0) dataSetBatch(batchB).unpersist()
254 | }
255 |
256 | // Weighted aggregated vector of each node
257 | var embeddingHashResult = embeddingHashBroadcast.value
258 | val nodeHvArray = nodeSideTable.map(k => k._1.split("@"))
259 | .map(nodeInfo => {
260 | val node = nodeInfo.head
261 | var eavSum = 0.0
262 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0))
263 | val keyEA = node + "#Weight"
264 | val valueEA = embeddingHashResult.getOrElse(keyEA, Array.fill[Double](sideInfoNum)(0.0))
265 | for(n <- 0 until sideInfoNum) {
266 | val keyNodeOrSide = nodeInfo(n)
267 | val valueNodeOrSide = embeddingHashResult.getOrElse(keyNodeOrSide, Array.fill[Double](vectorSize)(0.0))
268 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(n)) * new DenseVector[Double](valueNodeOrSide)
269 | eavSum = eavSum + math.exp(valueEA(n))
270 | }
271 | val Hv = (eavWvAggSum/eavSum).toArray
272 | (nodeInfo.mkString("#"), Hv)
273 | })
274 | val nodeHvMap = mutable.Map(nodeHvArray:_*)
275 | embeddingHashResult = embeddingHashResult.++(nodeHvMap)
276 | val embeddingResult = sc.parallelize(embeddingHashResult.map(k => (k._1, k._2.mkString("@"))).toList, numPartitions)
277 | embeddingResult
278 | }
279 |
280 | /**
281 | * Initialisation
282 | *
283 | * */
284 | def initWeightedSkipGram(dataSet: RDD[Array[Array[String]]]): Unit = {
285 | // Frequency table of nodes and side information
286 | nodeSideTable = dataSet.flatMap(x => x)
287 | .map(nodeInfo => (nodeInfo.mkString("@"), 1))
288 | .reduceByKey(_ + _)
289 | .collect().sortBy(-_._2)
290 | // Total count of nodes and side information
291 | nodeSideCount = nodeSideTable.map(_._2).par.sum
292 | // Frequency hash map of nodes and side information
293 | nodeSideHash = mutable.HashMap(nodeSideTable:_*)
294 | // Number of distinct node / side-information entries
295 | val nodeSideSize = nodeSideTable.length
296 | // Negative-sampling table initialisation
297 | var a = 0
298 | val power = 0.75
299 | var nodesPow = 0.0
300 | while (a < nodeSideSize) {
301 | nodesPow += Math.pow(nodeSideTable(a)._2, power)
302 | a = a + 1
303 | }
304 | var b = 0
305 | var freq = Math.pow(nodeSideTable(b)._2, power) / nodesPow
306 | var c = 0
307 | while (c < sampleTableNum) {
308 | sampleTable(c) = b
309 | if ((c.toDouble + 1.0) / sampleTableNum >= freq) {
310 | b = b + 1
311 | if (b >= nodeSideSize) {
312 | b = nodeSideSize - 1
313 | }
314 | freq += Math.pow(nodeSideTable(b)._2, power) / nodesPow
315 | }
316 | c = c + 1
317 | }
318 | }
319 |
320 | /**
321 | * Training inside one partition
322 | *
323 | * */
324 | def trainWeightedSkipGram(dataSet: Iterator[Array[Array[String]]],
325 | embeddingHash: mutable.HashMap[String, Array[Double]],
326 | nodeSideTable: Array[(String, Int)],
327 | nodeSideHash: mutable.HashMap[String, Int],
328 | sampleTable: Array[Int],
329 | subSample: Double,
330 | vectorSize: Int,
331 | sideInfoNum: Int,
332 | windowSize: Int,
333 | negativeSampleNum: Int,
334 | perSampleMaxNum: Int,
335 | sampleTableNum: Int,
336 | nodeSideCount: Int,
337 | isNoShowLoss: Boolean): Iterator[(String, (Array[Double], Long))] = {
338 | val gradientUpdateHash = new mutable.HashMap[String, (Array[Double], Long)]()
339 | var lossSum = 0.0
340 | var lossNum = 0L
341 | for(data <- dataSet) {
342 | // Sub-sample high-frequency nodes
343 | val dataSubSample = ArrayBuffer[Array[String]]()
344 | for( nodeOrSideInfoArray <- data){
345 | if(subSample > 0.0) {
346 | val nodeOrSideInfoFrequency = nodeSideHash(nodeOrSideInfoArray.mkString("@"))
347 | val keepProbability = (Math.sqrt(nodeOrSideInfoFrequency/ (subSample * nodeSideCount)) + 1.0) * (subSample * nodeSideCount) / nodeOrSideInfoFrequency
348 | if (keepProbability >= (new Random).nextDouble()) {
349 | dataSubSample.append(nodeOrSideInfoArray)
350 | }
351 | } else {
352 | dataSubSample.append(nodeOrSideInfoArray)
353 | }
354 | }
355 | val dssl = dataSubSample.length
356 | for(i <- 0 until dssl){
357 | val mainNodeInfo = dataSubSample(i)
358 | val mainNode = mainNodeInfo.head
359 | var eavSum = 0.0
360 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0))
361 | val keyEA = mainNode + "#Weight"
362 | val valueEA = embeddingHash(keyEA)
363 | for(k <- 0 until sideInfoNum) {
364 | val keyNodeOrSide = mainNodeInfo(k)
365 | val valueNodeOrSide = embeddingHash(keyNodeOrSide)
366 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(k)) * new DenseVector[Double](valueNodeOrSide)
367 | eavSum = eavSum + math.exp(valueEA(k))
368 | }
369 | val Hv = eavWvAggSum/eavSum
370 | // Set of positive samples inside the window of the centre node
371 | val mainSlaveNodeSet = dataSubSample.slice(math.max(0, i-windowSize), math.min(i+windowSize, dssl-1) + 1)
372 | .map(array => array.head).toSet
373 | for(j <- math.max(0, i-windowSize) to math.min(i+windowSize, dssl-1)){
374 | if(j != i) {
375 | // Train on the positive sample
376 | val slaveNodeInfo = dataSubSample(j)
377 | val slaveNode = slaveNodeInfo.head
378 | embeddingUpdate(mainNodeInfo, slaveNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 1.0, vectorSize, sideInfoNum)
379 | // Loss of the positive sample
380 | if(isNoShowLoss){
381 | val keySlaveNode = slaveNode + "#Zu"
382 | val valueSlaveNode = embeddingHash(keySlaveNode)
383 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode)
384 | val logarithmRaw = math.log(1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))))
385 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0)
386 | lossSum = lossSum - logarithm
387 | lossNum = lossNum + 1L
388 | }
389 | // Negative sampling and training
390 | for(_ <- 0 until negativeSampleNum){
391 | // Draw a negative sample
392 | var sampleNodeInfo = slaveNodeInfo
393 | var sampleNode = slaveNode
394 | var sampleNum = 0
395 | while(mainSlaveNodeSet.contains(sampleNode) && sampleNum < perSampleMaxNum) {
396 | val index = sampleTable((new Random).nextInt(sampleTableNum))
397 | sampleNodeInfo = nodeSideTable(index)._1.split("@")
398 | sampleNode = sampleNodeInfo.head
399 | sampleNum = sampleNum + 1
400 | }
401 | // Train on the negative sample
402 | embeddingUpdate(mainNodeInfo, sampleNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 0.0, vectorSize, sideInfoNum)
403 | // Loss of the negative sample
404 | if(isNoShowLoss){
405 | val keySampleNode = sampleNodeInfo.head + "#Zu"
406 | val valueSampleNode = embeddingHash(keySampleNode)
407 | val HvZu = Hv.t * new DenseVector[Double](valueSampleNode)
408 | val logarithmRaw = math.log(1.0 - 1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))))
409 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0)
410 | lossSum = lossSum - logarithm
411 | lossNum = lossNum + 1L
412 | }
413 | }
414 | }
415 | }
416 | }
417 | }
418 | if(isNoShowLoss){
419 | gradientUpdateHash.put("LossValue", (Array(lossSum), lossNum))
420 | }
421 | gradientUpdateHash.toIterator
422 | }
423 |
424 | /**
425 | * Gradient update
426 | *
427 | * */
428 | def embeddingUpdate(mainNodeInfo: Array[String],
429 | slaveNodeInfo: Array[String],
430 | embeddingHash: mutable.HashMap[String, Array[Double]],
431 | gradientUpdateHash: mutable.HashMap[String, (Array[Double], Long)],
432 | mainNodeWeightRaw: Array[Double],
433 | eavWvAggSum: DenseVector[Double],
434 | eavSum: Double,
435 | Hv: DenseVector[Double],
436 | label: Double,
437 | vectorSize: Int,
438 | sideInfoNum: Int): Unit = {
439 | val keySlaveNode = slaveNodeInfo.head + "#Zu"
440 | val valueSlaveNode = embeddingHash(keySlaveNode)
441 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode)
442 | // Gradient of the context node's Zu vector
443 | val gradientZu = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * Hv
444 | var gradientZuSum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._1
445 | gradientZuSum = (new DenseVector[Double](gradientZuSum) + gradientZu).toArray
446 | var gradientZuNum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._2
447 | gradientZuNum = gradientZuNum + 1L
448 | gradientUpdateHash.put(keySlaveNode, (gradientZuSum, gradientZuNum))
449 | // Gradients of the centre node's side-information weights and vectors
450 | val gradientHv = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * new DenseVector[Double](valueSlaveNode)
451 | val mainNodeWeightGradientSum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._1
452 | var mainNodeWeightGradientNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._2
453 | for(m <- 0 until sideInfoNum) {
454 | val wsvVectorRaw = new DenseVector[Double](embeddingHash(mainNodeInfo(m)))
455 | // Gradient of the centre node's side-information weight
456 | val gradientAsv = gradientHv.t * (eavSum * math.exp(mainNodeWeightRaw(m)) * wsvVectorRaw - math.exp(mainNodeWeightRaw(m)) * eavWvAggSum)/(eavSum * eavSum)
457 | mainNodeWeightGradientSum(m) = mainNodeWeightGradientSum(m) + gradientAsv
458 | // Gradient of the centre node's side-information vector
459 | var gradientWsvSum = new DenseVector(gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._1)
460 | var gradientWsvNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._2
461 | gradientWsvSum = gradientWsvSum + math.exp(mainNodeWeightRaw(m)) * gradientHv / eavSum
462 | gradientWsvNum = gradientWsvNum + 1L
463 | gradientUpdateHash.put(mainNodeInfo(m), (gradientWsvSum.toArray, gradientWsvNum))
464 | }
465 | mainNodeWeightGradientNum = mainNodeWeightGradientNum + 1L
466 | gradientUpdateHash.put(mainNodeInfo.head + "#Weight", (mainNodeWeightGradientSum, mainNodeWeightGradientNum))
467 | }
468 | }
--------------------------------------------------------------------------------
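`WeightedSkipGramBatch` trains the same model as `WeightedSkipGram`; the difference is that `fit` splits the walks into `batchNum` parts with `randomSplit` and refreshes the broadcast embedding table after every batch rather than once per iteration. A hedged configuration sketch (it assumes an existing `SparkContext` named `sc`; the toy input mirrors the README example):

```scala
import eges.embedding.WeightedSkipGramBatch

// Toy walks with one piece of side information per node (illustrative values).
val walksWithSideInfo = sc.parallelize(Seq(
  Array(Array("A", "cat1"), Array("B", "cat1"), Array("C", "cat2")),
  Array(Array("B", "cat1"), Array("C", "cat2"), Array("A", "cat1"))
))

val embeddings = new WeightedSkipGramBatch()
  .setVectorSize(8)
  .setWindowSize(2)
  .setNegativeSampleNum(2)
  .setBatchNum(2)            // extra knob versus WeightedSkipGram: mini-batches per iteration
  .setIterationNum(2)
  .setNumPartitions(4)
  .setSampleTableNum(100000) // keep the negative-sampling table small for a toy run
  .fit(walksWithSideInfo)    // RDD[(String, String)]: key -> "@"-joined vector
```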
/enhanced_graph_embedding_side_information/src/main/scala/eges/random/sample/Alias.scala:
--------------------------------------------------------------------------------
1 | package eges.random.sample
2 |
3 | import scala.collection.mutable.ArrayBuffer
4 |
5 | object Alias {
6 |
7 | /**
8 | * Build the alias sampling tables from nodes and their weights
9 | *
10 | * */
11 | def setupAlias(nodeWeights: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
12 | val K = nodeWeights.length
13 | val J = Array.fill(K)(0)
14 | val q = Array.fill(K)(0.0)
15 |
16 | val smaller = new ArrayBuffer[Int]()
17 | val larger = new ArrayBuffer[Int]()
18 |
19 | val sum = nodeWeights.map(_._2).sum
20 | nodeWeights.zipWithIndex.foreach { case ((nodeId, weight), i) =>
21 | q(i) = K * weight / sum
22 | if (q(i) < 1.0) {
23 | smaller.append(i)
24 | } else {
25 | larger.append(i)
26 | }
27 | }
28 |
29 | while (smaller.nonEmpty && larger.nonEmpty) {
30 | val small = smaller.remove(smaller.length - 1)
31 | val large = larger.remove(larger.length - 1)
32 |
33 | J(small) = large
34 | q(large) = q(large) + q(small) - 1.0
35 | if (q(large) < 1.0) smaller.append(large)
36 | else larger.append(large)
37 | }
38 |
39 | (J, q)
40 | }
41 |
42 | /**
43 | *
44 | * Build the alias sampling tables of an edge from the source node, its neighbours, and the destination node's neighbours
45 | *
46 | * */
47 | def setupEdgeAlias(p: Double = 1.0, q: Double = 1.0)(srcId: Long, srcNeighbors: Array[(Long, Double)], dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
48 | val neighbors_ = dstNeighbors.map { case (dstNeighborId, weight) =>
49 | var unnormProb = weight / q
50 | if (srcId == dstNeighborId) unnormProb = weight / p
51 | else if (srcNeighbors.exists(_._1 == dstNeighborId)) unnormProb = weight
52 |
53 | (dstNeighborId, unnormProb)
54 | }
55 |
56 | setupAlias(neighbors_)
57 | }
58 |
59 | /**
60 | * Draw a sample using the alias tables
61 | *
62 | * */
63 | def drawAlias(J: Array[Int], q: Array[Double]): Int = {
64 | val K = J.length
65 | val kk = math.floor(math.random * K).toInt
66 |
67 | if (math.random < q(kk)) kk
68 | else J(kk)
69 | }
70 |
71 | }
--------------------------------------------------------------------------------
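A small, self-contained check (weights are illustrative) that the tables produced by `setupAlias` reproduce the input distribution when sampled with `drawAlias`:

```scala
import eges.random.sample.Alias

object AliasSketch {
  def main(args: Array[String]): Unit = {
    // Three nodes with unnormalised weights 1 : 2 : 7.
    val nodeWeights = Array((101L, 1.0), (102L, 2.0), (103L, 7.0))
    val (j, q) = Alias.setupAlias(nodeWeights)

    val draws  = 100000
    val counts = Array.fill(nodeWeights.length)(0)
    for (_ <- 0 until draws) counts(Alias.drawAlias(j, q)) += 1

    // Empirical frequencies should come out near 0.1 / 0.2 / 0.7.
    counts.zip(nodeWeights).foreach { case (c, (id, w)) =>
      println(s"node $id, weight $w -> ${c.toDouble / draws}")
    }
  }
}
```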
/enhanced_graph_embedding_side_information/src/main/scala/eges/random/walk/Attributes.scala:
--------------------------------------------------------------------------------
1 | package eges.random.walk
2 |
3 | import java.io.Serializable
4 |
5 | object Attributes {
6 |
7 | /**
8 | * Node attributes
9 | * @param neighbors neighbour nodes of this node, together with their weights
10 | * @param path the first two nodes of each walk that starts at this node
11 | *
12 | * */
13 | case class NodeAttr(var neighbors: Array[(Long, Double)] = Array.empty[(Long, Double)],
14 | var path: Array[Array[Long]] = Array.empty[Array[Long]]) extends Serializable
15 |
16 | /**
17 | * Edge attributes
18 | * @param dstNeighbors neighbour nodes of the destination node
19 | * @param J alias table returned by the alias sampling setup
20 | * @param q probability table returned by the alias sampling setup
21 | *
22 | * */
23 | case class EdgeAttr(var dstNeighbors: Array[Long] = Array.empty[Long],
24 | var J: Array[Int] = Array.empty[Int],
25 | var q: Array[Double] = Array.empty[Double]) extends Serializable
26 |
27 | }
28 |
--------------------------------------------------------------------------------
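For orientation, a toy construction of the two attribute types (all values illustrative): `NodeAttr.path` holds, for each of the `nodeIterationNum` walks started at the node, its first two vertex ids, and `EdgeAttr` carries the alias tables used to extend a walk across that edge.

```scala
import eges.random.walk.Attributes.{EdgeAttr, NodeAttr}

// Node 1 with two weighted neighbours and two pre-drawn walk starts.
val nodeAttr = NodeAttr(
  neighbors = Array((2L, 1.0), (3L, 0.5)),
  path      = Array(Array(1L, 2L), Array(1L, 3L))
)

// Edge 1 -> 2: the destination's neighbours plus a (here trivially uniform) alias table.
val edgeAttr = EdgeAttr(dstNeighbors = Array(1L, 4L), J = Array(0, 0), q = Array(1.0, 1.0))
```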
/enhanced_graph_embedding_side_information/src/main/scala/eges/random/walk/RandomWalk.scala:
--------------------------------------------------------------------------------
1 | package eges.random.walk
2 |
3 | import scala.util.Try
4 | import java.util.Random
5 | import scala.collection.mutable
6 | import org.apache.spark.rdd.RDD
7 | import eges.random.sample.Alias
8 | import org.apache.spark.graphx._
9 | import eges.random.walk.Attributes._
10 | import org.apache.spark.SparkContext
11 | import scala.collection.mutable.ArrayBuffer
12 | import org.apache.spark.broadcast.Broadcast
13 |
14 | class RandomWalk extends Serializable {
15 |
16 | var nodeIterationNum = 10
17 | var nodeWalkLength = 78
18 | var returnParameter = 1.0
19 | var inOutParameter = 1.0
20 | var nodeMaxDegree = 200
21 | var numPartitions = 500
22 | var nodeIndex: Broadcast[mutable.HashMap[String, Long]] = _
23 | var indexNode: Broadcast[mutable.HashMap[Long, String]] = _
24 | var indexedNodes: RDD[(VertexId, NodeAttr)] = _
25 | var indexedEdges: RDD[Edge[EdgeAttr]] = _
26 | var graph: Graph[NodeAttr, EdgeAttr] = _
27 | var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = _
28 |
29 | /**
30 | * Number of walks generated per node
31 | *
32 | * */
33 | def setNodeIterationNum(value:Int):this.type = {
34 | require(value > 0, s"nodeIterationNum must be more than 0, but it is $value")
35 | nodeIterationNum = value
36 | this
37 | }
38 |
39 | /**
40 | * Maximum length of each walk
41 | *
42 | * */
43 | def setNodeWalkLength(value:Int):this.type = {
44 | require(value >= 2, s"nodeWalkLength must be not less than 2, but it is $value")
45 | nodeWalkLength = value - 2
46 | this
47 | }
48 |
49 | /**
50 | * Return parameter p (controls walking back)
51 | *
52 | * */
53 | def setReturnParameter(value:Double):this.type = {
54 | require(value > 0, s"returnParameter must be more than 0, but it is $value")
55 | returnParameter = value
56 | this
57 | }
58 |
59 | /**
60 | * In-out parameter q (controls walking outward)
61 | *
62 | * */
63 | def setInOutParameter(value:Double):this.type = {
64 | require(value > 0, s"inOutParameter must be more than 0, but it is $value")
65 | inOutParameter = value
66 | this
67 | }
68 |
69 | /**
70 | * Maximum number of neighbours kept per node
71 | *
72 | * */
73 | def setNodeMaxDegree(value:Int):this.type = {
74 | require(value > 0, s"nodeMaxDegree must be more than 0, but it is $value")
75 | nodeMaxDegree = value
76 | this
77 | }
78 |
79 | /**
80 | * Number of partitions (degree of parallelism)
81 | *
82 | * */
83 | def setNumPartitions(value:Int):this.type = {
84 | require(value > 0, s"numPartitions must be more than 0, but it is $value")
85 | numPartitions = value
86 | this
87 | }
88 |
89 | /**
90 | * Random walk paths for every node
91 | *
92 | * */
93 | def fit(node2Weight: RDD[(String, String, Double)]): RDD[(String, String)] = {
94 |
95 | // Broadcast some of the variables
96 | val sc = node2Weight.context
97 | val nodeIterationNumBroadcast = sc.broadcast(nodeIterationNum)
98 | val returnParameterBroadcast = sc.broadcast(returnParameter)
99 | val inOutParameterBroadcast = sc.broadcast(inOutParameter)
100 | val nodeMaxDegreeBroadcast = sc.broadcast(nodeMaxDegree)
101 |
102 | // Convert string node pairs with weights into Long ids with weights
103 | val inputTriplets = processData(node2Weight, sc)
104 | inputTriplets.cache()
105 | inputTriplets.first() // action, materialises the cache
106 |
107 | // Node attributes of each node
108 | indexedNodes = inputTriplets.map{ case (node1, node2, weight) => (node1, (node2, weight)) }
109 | .combineByKey(dstIdWeight =>{
110 | implicit object ord extends Ordering[(Long, Double)]{
111 | override def compare(p1:(Long, Double), p2:(Long, Double)):Int = {
112 | p2._2.compareTo(p1._2)
113 | }
114 | }
115 | val priorityQueue = new mutable.PriorityQueue[(Long, Double)]()
116 | priorityQueue.enqueue(dstIdWeight)
117 | priorityQueue
118 | },(priorityQueue:mutable.PriorityQueue[(Long,Double)], dstIdWeight)=>{
119 | if(priorityQueue.size < nodeMaxDegreeBroadcast.value){
120 | priorityQueue.enqueue(dstIdWeight)
121 | }else{
122 | if(priorityQueue.head._2 < dstIdWeight._2){
123 | priorityQueue.dequeue()
124 | priorityQueue.enqueue(dstIdWeight)
125 | }
126 | }
127 | priorityQueue
128 | },(priorityQueuePre:mutable.PriorityQueue[(Long,Double)],
129 | priorityQueueLast:mutable.PriorityQueue[(Long,Double)])=>{
130 | while(priorityQueueLast.nonEmpty){
131 | val dstIdWeight = priorityQueueLast.dequeue()
132 | if(priorityQueuePre.size < nodeMaxDegreeBroadcast.value){
133 | priorityQueuePre.enqueue(dstIdWeight)
134 | }else{
135 | if(priorityQueuePre.head._2 < dstIdWeight._2){
136 | priorityQueuePre.dequeue()
137 | priorityQueuePre.enqueue(dstIdWeight)
138 | }
139 | }
140 | }
141 | priorityQueuePre
142 | }).map{ case (srcId, dstIdWeightPriorityQueue) =>
143 | (srcId, NodeAttr(neighbors = dstIdWeightPriorityQueue.toArray))
144 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2)
145 | indexedNodes.cache()
146 | indexedNodes.first() // action to materialize the cache
147 | inputTriplets.unpersist()
148 |
149 | // Build the edge attributes
150 | indexedEdges = indexedNodes.flatMap { case (srcId, srcIdAttr) =>
151 | srcIdAttr.neighbors.map { case (dstId, weight) =>
152 | Edge(srcId, dstId, EdgeAttr())
153 | }
154 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2)
155 | indexedEdges.cache()
156 | indexedEdges.first() // action to materialize the cache
157 |
158 | // Build the graph from the node and edge attributes
159 | graph = Graph(indexedNodes, indexedEdges)
160 | .mapVertices[NodeAttr]{ case (vertexId, vertexAttr) =>
161 | try{
162 | val (j, q) = Alias.setupAlias(vertexAttr.neighbors)
163 | val pathArray = ArrayBuffer[Array[Long]]()
164 | for(_ <- 0 until nodeIterationNumBroadcast.value){
165 | val nextNodeIndex = Alias.drawAlias(j, q)
166 | val path = Array(vertexId, vertexAttr.neighbors(nextNodeIndex)._1)
167 | pathArray.append(path)
168 | }
169 | vertexAttr.path = pathArray.toArray
170 | vertexAttr
171 | } catch {
172 | case _:Exception => Attributes.NodeAttr()
173 | }
174 | }.mapTriplets{ edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] =>
175 | val dstAttrNeighbors = Try(edgeTriplet.dstAttr.neighbors).getOrElse(Array.empty[(Long, Double)])
176 | val (j, q) = Alias.setupEdgeAlias(returnParameterBroadcast.value, inOutParameterBroadcast.value)(edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, dstAttrNeighbors)
177 | edgeTriplet.attr.J = j
178 | edgeTriplet.attr.q = q
179 | edgeTriplet.attr.dstNeighbors = dstAttrNeighbors.map(_._1)
180 | edgeTriplet.attr
181 | }
182 | graph.cache()
183 |
184 | // Key every edge by "srcId_dstId" together with its attributes (the "_" separator avoids key collisions such as 1_23 vs 12_3)
185 | val edgeAttr = graph.triplets.map{ edgeTriplet =>
186 | (s"${edgeTriplet.srcId}_${edgeTriplet.dstId}", edgeTriplet.attr)
187 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2)
188 | edgeAttr.cache()
189 | edgeAttr.first() // action to materialize the cache
190 | indexedNodes.unpersist()
191 | indexedEdges.unpersist()
192 |
193 | // Random walks: generate the node sequences
194 | for (iter <- 0 until nodeIterationNum) {
195 | var randomWalk = graph.vertices.filter{ case (vertexId, vertexAttr) =>
196 | val vertexNeighborsLength = Try(vertexAttr.neighbors.length).getOrElse(0)
197 | vertexNeighborsLength > 0
198 | }.map { case (vertexId, vertexAttr) =>
199 | val pathBuffer = new ArrayBuffer[Long]()
200 | pathBuffer.append(vertexAttr.path(iter):_*)
201 | (vertexId, pathBuffer)
202 | }
203 | randomWalk.cache()
204 | randomWalk.first() // action to materialize the cache
205 |
206 | for (_ <- 0 until nodeWalkLength) {
207 | randomWalk = randomWalk.map { case (srcNodeId, pathBuffer) =>
208 | val prevNodeId = pathBuffer(pathBuffer.length - 2)
209 | val currentNodeId = pathBuffer.last
210 |
211 | (s"${prevNodeId}_$currentNodeId", (srcNodeId, pathBuffer)) // key format must match edgeAttr
212 | }.join(edgeAttr).map { case (edge, ((srcNodeId, pathBuffer), attr)) =>
213 | try {
214 | val nextNodeIndex = Alias.drawAlias(attr.J, attr.q)
215 | val nextNodeId = attr.dstNeighbors(nextNodeIndex)
216 | pathBuffer.append(nextNodeId)
217 |
218 | (srcNodeId, pathBuffer)
219 | } catch {
220 | case _: Exception => (srcNodeId, pathBuffer)
221 | }
222 | }
223 | randomWalk.cache()
224 | randomWalk.first() // action to materialize the cache
225 | }
226 |
227 | if (randomWalkPaths != null) {
228 | randomWalkPaths = randomWalkPaths.union(randomWalk)
229 | } else {
230 | randomWalkPaths = randomWalk
231 | }
232 | randomWalkPaths.cache()
233 | randomWalkPaths.first() // action to materialize the cache
234 | randomWalk.unpersist()
235 | }
236 | graph.unpersist()
237 | edgeAttr.unpersist()
238 |
239 | // Convert Long node IDs back to strings; each path is joined with ","
240 | val paths = longToString(randomWalkPaths)
241 | paths.cache()
242 | paths.first() // action to materialize the cache
243 | randomWalkPaths.unpersist()
244 | paths
245 |
246 | }
247 |
248 | /**
249 | * Convert (srcNode, dstNode, weight) triplets with string node IDs into triplets with Long IDs
250 | *
251 | * */
252 | def processData(node2Weight: RDD[(String, String, Double)], sc:SparkContext): RDD[(Long, Long, Double)] = {
253 | val nodeIndexArray = node2Weight.flatMap(node => Array(node._1, node._2))
254 | .distinct().zipWithIndex().collect()
255 | val indexNodeArray = nodeIndexArray.map{case (node, index) => (index, node)}
256 | indexNode = sc.broadcast(mutable.HashMap(indexNodeArray:_*))
257 | nodeIndex = sc.broadcast(mutable.HashMap(nodeIndexArray:_*))
258 | val inputTriplets = node2Weight.map{ case (node1, node2, weight) =>
259 | val node1Index = nodeIndex.value(node1)
260 | val node2Index = nodeIndex.value(node2)
261 | (node1Index, node2Index, weight)
262 | }
263 | inputTriplets
264 | }
265 |
266 | /**
267 | * Convert Long node IDs and paths back to string IDs; each path is joined with ","
268 | *
269 | * */
270 | def longToString(srcNodeIdPaths: RDD[(Long, ArrayBuffer[Long])]): RDD[(String, String)] = {
271 | val pathsResult = srcNodeIdPaths.map{case (srcNodeId, paths) =>
272 | val srcNodeIdString = indexNode.value(srcNodeId)
273 | val pathsString = paths.map(index => indexNode.value(index)).mkString(",")
274 | (srcNodeIdString, pathsString)
275 | }
276 | pathsResult
277 | }
278 |
279 | }
--------------------------------------------------------------------------------
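Note: `RandomWalk.fit` above draws its weighted transitions through `Alias.setupAlias`, `Alias.setupEdgeAlias` and `Alias.drawAlias` (defined in `eges/random/sample/Alias.scala`, which is not reproduced in this dump). For orientation, here is a minimal, self-contained sketch of the alias method using the same call shapes as the code above; the repository's actual implementation may differ in detail, so treat this purely as an illustration.

    import scala.collection.mutable
    import scala.util.Random

    // Alias-method sketch: O(n) table setup, O(1) sampling from a weighted neighbor list.
    object AliasSketch {

      // Build the alias tables (J, q) for the discrete distribution given by neighbor weights.
      def setupAlias(neighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
        val n = neighbors.length
        val j = Array.fill(n)(0)
        val q = Array.fill(n)(0.0)
        val weightSum = neighbors.map(_._2).sum
        val smaller = mutable.Stack[Int]()
        val larger = mutable.Stack[Int]()
        neighbors.zipWithIndex.foreach { case ((_, weight), i) =>
          q(i) = weight * n / weightSum            // scaled probability of bucket i
          if (q(i) < 1.0) smaller.push(i) else larger.push(i)
        }
        while (smaller.nonEmpty && larger.nonEmpty) {
          val small = smaller.pop()
          val large = larger.pop()
          j(small) = large                          // alias that tops up the small bucket
          q(large) = q(large) + q(small) - 1.0
          if (q(large) < 1.0) smaller.push(large) else larger.push(large)
        }
        (j, q)
      }

      // Bias the destination node's neighbors with the node2vec return (p) and in-out (q)
      // parameters, then reuse setupAlias; this mirrors how setupEdgeAlias is invoked above.
      def setupEdgeAlias(p: Double, q: Double)(
          srcId: Long,
          srcNeighbors: Array[(Long, Double)],
          dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
        val srcNeighborIds = srcNeighbors.map(_._1).toSet
        val biased = dstNeighbors.map { case (dstNeighborId, weight) =>
          val bias =
            if (dstNeighborId == srcId) weight / p                  // return to the previous node
            else if (srcNeighborIds.contains(dstNeighborId)) weight // stay at distance 1
            else weight / q                                         // move outward
          (dstNeighborId, bias)
        }
        setupAlias(biased)
      }

      // Draw one neighbor index in O(1): pick a bucket uniformly, keep it or jump to its alias.
      def drawAlias(j: Array[Int], q: Array[Double]): Int = {
        val bucket = Random.nextInt(j.length)
        if (Random.nextDouble() < q(bucket)) bucket else j(bucket)
      }
    }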
/enhanced_graph_embedding_side_information/src/main/scala/example/Example1.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import scala.collection.mutable
4 | import eges.random.walk.RandomWalk
5 | import sparkapplication.BaseSparkLocal
6 |
7 | object Example1 extends BaseSparkLocal {
8 | def main(args:Array[String]):Unit = {
9 | val spark = this.basicSpark
10 | import spark.implicits._
11 |
12 | val list1 = List((("gds1", "group1"), ("gds2", 0.1)), (("gds1", "group1"), ("gds3", 0.2)), (("gds1", "group1"), ("gds4", 0.3)),
13 | (("gds2", "group2"), ("gds3", 0.1)), (("gds2", "group2"), ("gds4", 0.3)), (("gds2", "group2"), ("gds5", 0.6)),
14 | (("gds3", "group4"), ("gds2", 0.1)), (("gds3", "group4"), ("gds5", 0.3)), (("gds3", "group4"), ("gds6", 0.5)))
15 | val list2 = spark.sparkContext.parallelize(list1)
16 | val list3 = list2.combineByKey(gds2WithSimilarity=>{
17 | implicit object ord extends Ordering[(String,Double)] {
18 | override def compare(p1: (String,Double), p2: (String,Double)): Int = {
19 | p2._2.compareTo(p1._2)
20 | }
21 | }
22 |
23 | val priorityQueue = new mutable.PriorityQueue[(String,Double)]()
24 | priorityQueue.enqueue(gds2WithSimilarity)
25 | priorityQueue
26 | },(priorityQueue:mutable.PriorityQueue[(String,Double)],gds2WithSimilarity)=>{
27 |
28 | if(priorityQueue.size < 2){
29 | priorityQueue.enqueue(gds2WithSimilarity)
30 | }else{
31 | if(priorityQueue.head._2 < gds2WithSimilarity._2){
32 | priorityQueue.dequeue()
33 | priorityQueue.enqueue(gds2WithSimilarity)
34 | }
35 | }
36 |
37 | priorityQueue
38 | },(priorityQueuePre:mutable.PriorityQueue[(String,Double)],
39 | priorityQueueLast:mutable.PriorityQueue[(String,Double)])=>{
40 | while(!priorityQueueLast.isEmpty){
41 |
42 | val gds2WithSimilarity = priorityQueueLast.dequeue()
43 | if(priorityQueuePre.size < 2){
44 | priorityQueuePre.enqueue(gds2WithSimilarity)
45 | }else{
46 | if(priorityQueuePre.head._2 < gds2WithSimilarity._2){
47 | priorityQueuePre.dequeue()
48 | priorityQueuePre.enqueue(gds2WithSimilarity)
49 | }
50 | }
51 | }
52 |
53 | priorityQueuePre}).flatMap({case ((gdsCd1, l4GroupCd), gdsCd2AndSimilarityPriorityQueue) => {
54 |
55 | gdsCd2AndSimilarityPriorityQueue.toList.map({case(gdsCd2, similarity) => (gdsCd1, gdsCd2, l4GroupCd, similarity)})
56 |
57 | }})
58 |
59 | list3.foreach(println)
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 | }
74 |
75 | }
76 |
--------------------------------------------------------------------------------
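A note on the reversed `Ordering` used in Example1 (and again in `RandomWalk` and Example2): Scala's `PriorityQueue` keeps the element that is largest under the supplied ordering at its head, so flipping the comparison turns it into a bounded min-heap whose head is the smallest weight currently retained, which is exactly what the top-k-per-key logic compares new candidates against. A tiny plain-Scala illustration (no Spark involved):

    import scala.collection.mutable

    object MinHeapDemo {
      def main(args: Array[String]): Unit = {
        // Reverse the natural ordering by weight, so the queue's head is the
        // *smallest* element it currently retains.
        implicit val byWeightAscending: Ordering[(String, Double)] =
          Ordering.by[(String, Double), Double](_._2).reverse

        val queue = mutable.PriorityQueue(("a", 0.3), ("b", 0.1), ("c", 0.5))
        println(queue.head) // (b,0.1) -- a candidate only replaces it if it weighs more
      }
    }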
/enhanced_graph_embedding_side_information/src/main/scala/example/Example10.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import breeze.linalg.DenseVector
4 | import sparkapplication.BaseSparkLocal
5 |
6 | object Example10 extends BaseSparkLocal {
7 | def main(args:Array[String]):Unit = {
8 | val spark = this.basicSpark
9 | import spark.implicits._
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example11.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import java.util.Random
4 | import breeze.linalg.DenseVector
5 | import sparkapplication.BaseSparkLocal
6 |
7 | object Example11 extends BaseSparkLocal {
8 | def main(args:Array[String]):Unit = {
9 | // val spark = this.basicSpark
10 | // import spark.implicits._
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | }
46 |
47 | }
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example2.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import breeze.linalg.operators
4 | import breeze.linalg.{DenseVector, SparseVector, Vector}
5 | import breeze.numerics._
6 | import java.util.{Date, Random}
7 | import sparkapplication.BaseSparkLocal
8 | import org.apache.spark.sql.functions._
9 | import scala.collection.mutable
10 |
11 | object Example2 extends BaseSparkLocal {
12 | def main(args:Array[String]):Unit = {
13 | val spark = this.basicSpark
14 | import spark.implicits._
15 |
16 | val nodeMaxDegree = 2
17 | val list1 = List((1:Long, (2:Long, 3.2:Double)), (1:Long, (3:Long, 2.2:Double)), (1:Long, (4:Long, 1.2:Double)),
18 | (1:Long, (5:Long, 0.2:Double)), (6:Long, (7:Long, 6.2:Double)), (6:Long, (8:Long, 8.2:Double)))
19 | val list2 = spark.sparkContext.parallelize(list1)
20 | val list3 = list2.combineByKey(dstIdWeight =>{
21 | implicit object ord extends Ordering[(Long, Double)]{
22 | override def compare(p1:(Long, Double), p2:(Long, Double)):Int = {
23 | p2._2.compareTo(p1._2)
24 | }
25 | }
26 | val priorityQueue = new mutable.PriorityQueue[(Long, Double)]()
27 | priorityQueue.enqueue(dstIdWeight)
28 | priorityQueue
29 | },(priorityQueue:mutable.PriorityQueue[(Long,Double)], dstIdWeight)=>{
30 | if(priorityQueue.size < nodeMaxDegree){
31 | priorityQueue.enqueue(dstIdWeight)
32 | }else{
33 | if(priorityQueue.head._2 < dstIdWeight._2){
34 | priorityQueue.dequeue()
35 | priorityQueue.enqueue(dstIdWeight)
36 | }
37 | }
38 | priorityQueue
39 | },(priorityQueuePre:mutable.PriorityQueue[(Long,Double)],
40 | priorityQueueLast:mutable.PriorityQueue[(Long,Double)])=>{
41 | while(priorityQueueLast.nonEmpty){
42 | val dstIdWeight = priorityQueueLast.dequeue()
43 | if(priorityQueuePre.size < nodeMaxDegree){
44 | priorityQueuePre.enqueue(dstIdWeight)
45 | }else{
46 | if(priorityQueuePre.head._2 < dstIdWeight._2){
47 | priorityQueuePre.dequeue()
48 | priorityQueuePre.enqueue(dstIdWeight)
49 | }
50 | }
51 | }
52 | priorityQueuePre
53 | }).map{case (srcId, dstIdWeightPriorityQueue) => (srcId, dstIdWeightPriorityQueue.toArray.sortBy(-_._2).mkString("@"))}
54 |
55 |
56 |
57 |
58 |
59 | list3.foreach(println)
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 | }
119 | }
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example3.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import eges.random.walk.RandomWalk
4 | import sparkapplication.BaseSparkLocal
5 |
6 | object Example3 extends BaseSparkLocal {
7 | def main(args:Array[String]):Unit = {
8 | val spark = this.basicSpark
9 | import spark.implicits._
10 |
11 | val data = List(("y", "a", 0.8), ("y", "b", 0.6), ("y", "c", 0.4), ("q", "b", 0.8), ("q", "c", 0.6), ("q", "d", 0.4),
12 | ("s", "a", 0.8), ("s", "c", 0.6), ("s", "d", 0.4), ("a", "e", 0.5), ("b", "f", 0.5), ("c", "e", 0.5), ("c", "f", 0.5))
13 | val dataRDD = spark.sparkContext.parallelize(data)
14 | val randomWalk = new RandomWalk()
15 | .setNodeIterationNum(5)
16 | .setNodeWalkLength(6)
17 | .setReturnParameter(1.0)
18 | .setInOutParameter(1.0)
19 | .setNodeMaxDegree(200)
20 | .setNumPartitions(2)
21 |
22 | val paths = randomWalk.fit(dataRDD)
23 | paths.cache()
24 | println(s"Total number of paths: ${paths.count()}")
25 | println(paths.map(k => k._2).collect().mkString("#"))
26 | paths.unpersist()
27 |
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example4.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import java.util.Random
4 | import eges.embedding.WeightedSkipGram
5 | import org.apache.spark.sql.functions._
6 | import sparkapplication.BaseSparkLocal
7 | import breeze.linalg.{norm, DenseVector}
8 | import scala.collection.mutable
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 | object Example4 extends BaseSparkLocal {
12 | def main(args:Array[String]):Unit = {
13 | val spark = this.basicSpark
14 | import spark.implicits._
15 |
16 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2")
17 | val dataRDD = spark.sparkContext.parallelize(dataList)
18 | .map(k => k.split(",").map(v => v.split("@")))
19 | val weightedSkipGram = new WeightedSkipGram()
20 | .setVectorSize(4)
21 | .setWindowSize(1)
22 | .setNegativeSampleNum(1)
23 | .setPerSampleMaxNum(100)
24 | // .setSubSample(0.1)
25 | .setSubSample(0.0)
26 | .setLearningRate(0.25)
27 | .setIterationNum(6)
28 | .setIsNoShowLoss(true)
29 | .setNumPartitions(2)
30 | .setSampleTableNum(200)
31 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap()
32 | println("nodeEmbedding results:")
33 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println)
34 |
35 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
36 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight")
37 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2")
38 | val i2iSimilar = ArrayBuffer[(String, String, Double)]()
39 | for(mainNodeInfo <- nodeAgg){
40 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
41 | for(slaveNodeInfo <- nodeAgg){
42 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo)
43 | if(!mainNodeInfo.equals(slaveNodeInfo)) {
44 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) )
45 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar))
46 | }
47 | }
48 | }
49 | println("i2iSimilarHv results:")
50 | i2iSimilar.foreach(println)
51 |
52 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]()
53 | for(mainNodeInfo <- nodeAgg){
54 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
55 | for(slaveNode <- node){
56 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu")
57 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) {
58 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector
59 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0))
60 | val probability = 1.0 / (1.0 + eHx)
61 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability))
62 | }
63 | }
64 | }
65 | println("i2iSimilarSequence results:")
66 | i2iSimilarSequence.foreach(println)
67 |
68 | }
69 | }
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example5.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import java.util.Random
4 | import breeze.linalg.{norm, DenseVector}
5 | import eges.embedding.WeightedSkipGramBatch
6 | import org.apache.spark.sql.functions._
7 | import sparkapplication.BaseSparkLocal
8 | import scala.collection.mutable
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 |
12 | object Example5 extends BaseSparkLocal {
13 | def main(args:Array[String]):Unit = {
14 | val spark = this.basicSpark
15 | import spark.implicits._
16 |
17 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2")
18 | val dataRDD = spark.sparkContext.parallelize(dataList)
19 | .map(k => k.split(",").map(v => v.split("@")))
20 | val weightedSkipGram = new WeightedSkipGramBatch()
21 | .setVectorSize(4)
22 | .setWindowSize(1)
23 | .setNegativeSampleNum(1)
24 | .setPerSampleMaxNum(100)
25 | // .setSubSample(0.1)
26 | .setSubSample(0.0)
27 | .setLearningRate(0.5)
28 | .setBatchNum(2)
29 | .setIterationNum(6)
30 | .setIsNoShowLoss(true)
31 | .setNumPartitions(2)
32 | .setSampleTableNum(200)
33 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap()
34 | println("nodeEmbedding results:")
35 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println)
36 |
37 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
38 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight")
39 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2")
40 | val i2iSimilar = ArrayBuffer[(String, String, Double)]()
41 | for(mainNodeInfo <- nodeAgg){
42 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
43 | for(slaveNodeInfo <- nodeAgg){
44 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo)
45 | if(!mainNodeInfo.equals(slaveNodeInfo)) {
46 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) )
47 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar))
48 | }
49 | }
50 | }
51 | println("i2iSimilarHv results:")
52 | i2iSimilar.foreach(println)
53 |
54 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]()
55 | for(mainNodeInfo <- nodeAgg){
56 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
57 | for(slaveNode <- node){
58 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu")
59 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) {
60 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector
61 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0))
62 | val probability = 1.0 / (1.0 + eHx)
63 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability))
64 | }
65 | }
66 | }
67 | println("i2iSimilarSequence results:")
68 | i2iSimilarSequence.foreach(println)
69 |
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example6.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import java.util.Date
4 | import java.util.Random
5 | import breeze.linalg.DenseVector
6 | import org.apache.spark.sql.functions._
7 | import sparkapplication.BaseSparkLocal
8 | import scala.collection.mutable
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 |
12 |
13 |
14 | object Example6 extends BaseSparkLocal {
15 | def main(args:Array[String]):Unit = {
16 | val spark = this.basicSpark
17 | import spark.implicits._
18 |
19 |
20 | val a = "-0.30614856@0.16460776@0.19909532@-0.48152933@-0.7925998"
21 | val aVector = new DenseVector[Double](a.split("@").map(_.toDouble))
22 | val b = "-0.7356615@0.6408239@-0.4380877@-1.920213@-2.628376"
23 | val bVector = new DenseVector[Double](b.split("@").map(_.toDouble))
24 | val similar = aVector.t * bVector
25 | val eHx = math.exp(math.max(math.min(-similar, 35.0), -35.0)).toFloat
26 | val probability = 1.toFloat / (1.toFloat + eHx)
27 | println(probability)
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
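Example6 clamps the exponent to ±35 before calling `math.exp`; without such a clamp the exponential can overflow and the sigmoid collapses to exactly 0 or 1. A quick plain-Scala illustration of why the clamp matters (the value 1000.0 is made up for the demo):

    object ClampDemo {
      def main(args: Array[String]): Unit = {
        val huge = 1000.0
        // Unclamped: exp overflows to Infinity and the probability degenerates to exactly 0.0.
        println(1.0 / (1.0 + math.exp(huge)))                 // 0.0
        // Clamped to 35, as in Example6: still tiny, but finite and well-behaved.
        println(1.0 / (1.0 + math.exp(math.min(huge, 35.0)))) // ~6.3e-16
      }
    }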
/enhanced_graph_embedding_side_information/src/main/scala/example/Example7.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import breeze.linalg.{DenseVector, norm}
4 | import org.apache.spark.rdd.RDD
5 | import org.apache.spark.sql.expressions.Window
6 | import org.apache.spark.sql.functions._
7 | import sparkapplication.BaseSparkLocal
8 |
9 | import scala.collection.mutable
10 | import scala.collection.mutable.ArrayBuffer
11 |
12 | object Example7 extends BaseSparkLocal {
13 | def main(args:Array[String]):Unit = {
14 | val spark = this.basicSpark
15 | import spark.implicits._
16 |
17 | val list = List(("y", "2019"), ("q", "2018"))
18 | val data = spark.createDataFrame(list).toDF("member_id", "visit_time")
19 | .withColumn("index", row_number().over(Window.partitionBy("member_id").orderBy(desc("visit_time"))))
20 | .select("index")
21 | .rdd.map(k => k.getAs[Int]("index"))
22 |
23 |
24 | data.foreach(println)
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example8.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import scala.collection.mutable
4 | import sparkapplication.BaseSparkLocal
5 |
6 | object Example8 extends BaseSparkLocal {
7 | def main(args:Array[String]):Unit = {
8 | val spark = this.basicSpark
9 | import spark.implicits._
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/example/Example9.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import java.util.Random
4 |
5 | import breeze.linalg.DenseVector
6 | import sparkapplication.BaseSparkLocal
7 |
8 |
9 |
10 |
11 | object Example9 extends BaseSparkLocal {
12 | def main(args:Array[String]):Unit = {
13 | val spark = this.basicSpark
14 | import spark.implicits._
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | }
42 |
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/main/Main1.scala:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import eges.random.walk.RandomWalk
4 | import sparkapplication.BaseSparkLocal
5 |
6 | object Main1 extends BaseSparkLocal {
7 | def main(args:Array[String]):Unit = {
8 | val spark = this.basicSpark
9 |
10 | val data = List(("y", "a", 0.8), ("y", "b", 0.6), ("y", "c", 0.4), ("q", "b", 0.8), ("q", "c", 0.6), ("q", "d", 0.4),
11 | ("s", "a", 0.8), ("s", "c", 0.6), ("s", "d", 0.4), ("a", "e", 0.5), ("b", "f", 0.5), ("c", "e", 0.5), ("c", "f", 0.5))
12 | val dataRDD = spark.sparkContext.parallelize(data)
13 | val randomWalk = new RandomWalk()
14 | .setNodeIterationNum(5)
15 | .setNodeWalkLength(6)
16 | .setReturnParameter(1.0)
17 | .setInOutParameter(1.0)
18 | .setNodeMaxDegree(200)
19 | .setNumPartitions(2)
20 |
21 | val paths = randomWalk.fit(dataRDD)
22 | paths.cache()
23 | println(s"Total number of paths: ${paths.count()}")
24 | println(paths.map(k => k._2).collect().mkString("#"))
25 | paths.unpersist()
26 |
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/main/Main2.scala:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import eges.embedding.WeightedSkipGram
4 | import sparkapplication.BaseSparkLocal
5 | import breeze.linalg.{norm, DenseVector}
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | object Main2 extends BaseSparkLocal {
9 | def main(args:Array[String]):Unit = {
10 | val spark = this.basicSpark
11 |
12 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2")
13 | val dataRDD = spark.sparkContext.parallelize(dataList)
14 | .map(k => k.split(",").map(v => v.split("@")))
15 | val weightedSkipGram = new WeightedSkipGram()
16 | .setVectorSize(4)
17 | .setWindowSize(1)
18 | .setNegativeSampleNum(1)
19 | .setPerSampleMaxNum(100)
20 | // .setSubSample(0.1)
21 | .setSubSample(0.0)
22 | .setLearningRate(0.25)
23 | .setIterationNum(6)
24 | .setIsNoShowLoss(true)
25 | .setNumPartitions(2)
26 | .setSampleTableNum(200)
27 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap()
28 | println("nodeEmbedding results:")
29 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println)
30 |
31 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
32 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight")
33 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2")
34 | val i2iSimilar = ArrayBuffer[(String, String, Double)]()
35 | for(mainNodeInfo <- nodeAgg){
36 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
37 | for(slaveNodeInfo <- nodeAgg){
38 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo)
39 | if(!mainNodeInfo.equals(slaveNodeInfo)) {
40 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) )
41 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar))
42 | }
43 | }
44 | }
45 | println("i2iSimilarHv results:")
46 | i2iSimilar.foreach(println)
47 |
48 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]()
49 | for(mainNodeInfo <- nodeAgg){
50 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
51 | for(slaveNode <- node){
52 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu")
53 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) {
54 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector
55 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0))
56 | val probability = 1.0 / (1.0 + eHx)
57 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability))
58 | }
59 | }
60 | }
61 | println("i2iSimilarSequence results:")
62 | i2iSimilarSequence.foreach(println)
63 |
64 | }
65 | }
--------------------------------------------------------------------------------
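As the README notes, Main2 reduces to plain word2vec when each token carries a single piece of side information and to the EGES model from the Alibaba paper when there are several. In EGES, the item embedding and its side-information embeddings are combined with learned, exponent-normalised weights before the skip-gram objective is applied. Below is a minimal Breeze sketch of that aggregation; the learning of the weights happens inside `WeightedSkipGram`, whose internals are not shown in this dump, so treat the function as illustrative only.

    import breeze.linalg.DenseVector

    object EgesAggregationSketch {
      // H_v = sum_j exp(a_j) * W_j / sum_j exp(a_j)
      // where W_0 is the item embedding, W_1..W_n the side-information embeddings and
      // a_0..a_n the learnable importance weights described in the Alibaba paper.
      def aggregate(vectors: Array[DenseVector[Double]], a: Array[Double]): DenseVector[Double] = {
        require(vectors.nonEmpty && vectors.length == a.length, "one weight per embedding")
        val expA = a.map(math.exp)
        val weightedSum = vectors.zip(expA).map { case (v, w) => v * w }.reduce(_ + _)
        weightedSum / expA.sum
      }

      def main(args: Array[String]): Unit = {
        val itemVec = DenseVector(0.1, 0.2, 0.3, 0.4) // W_0: the item itself
        val sideVec = DenseVector(0.5, 0.1, 0.0, 0.2) // W_1: one piece of side information
        println(aggregate(Array(itemVec, sideVec), Array(0.0, 0.0))) // equal weights -> plain average
      }
    }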
/enhanced_graph_embedding_side_information/src/main/scala/main/Main3.scala:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import eges.embedding.WeightedSkipGramBatch
4 | import sparkapplication.BaseSparkLocal
5 | import breeze.linalg.{norm, DenseVector}
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | object Main3 extends BaseSparkLocal {
9 | def main(args:Array[String]):Unit = {
10 | val spark = this.basicSpark
11 |
12 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2")
13 | val dataRDD = spark.sparkContext.parallelize(dataList)
14 | .map(k => k.split(",").map(v => v.split("@")))
15 | val weightedSkipGram = new WeightedSkipGramBatch()
16 | .setVectorSize(4)
17 | .setWindowSize(1)
18 | .setNegativeSampleNum(1)
19 | .setPerSampleMaxNum(100)
20 | // .setSubSample(0.1)
21 | .setSubSample(0.0)
22 | .setLearningRate(0.5)
23 | .setBatchNum(2)
24 | .setIterationNum(6)
25 | .setIsNoShowLoss(true)
26 | .setNumPartitions(2)
27 | .setSampleTableNum(200)
28 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap()
29 | println("nodeEmbedding结果如下:")
30 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println)
31 |
32 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
33 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight")
34 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2")
35 | val i2iSimilar = ArrayBuffer[(String, String, Double)]()
36 | for(mainNodeInfo <- nodeAgg){
37 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
38 | for(slaveNodeInfo <- nodeAgg){
39 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo)
40 | if(!mainNodeInfo.equals(slaveNodeInfo)) {
41 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) )
42 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar))
43 | }
44 | }
45 | }
46 | println("i2iSimilarHv结果如下:")
47 | i2iSimilar.foreach(println)
48 |
49 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]()
50 | for(mainNodeInfo <- nodeAgg){
51 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo)
52 | for(slaveNode <- node){
53 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu")
54 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) {
55 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector
56 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0))
57 | val probability = 1.0 / (1.0 + eHx)
58 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability))
59 | }
60 | }
61 | }
62 | println("i2iSimilarSequence结果如下:")
63 | i2iSimilarSequence.foreach(println)
64 |
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/sparkapplication/BaseSparkLocal.scala:
--------------------------------------------------------------------------------
1 | package sparkapplication
2 |
3 | import org.apache.spark.SparkConf
4 | import org.apache.spark.sql.SparkSession
5 |
6 | trait BaseSparkLocal {
7 | // Local mode
8 | def basicSpark: SparkSession =
9 | SparkSession
10 | .builder
11 | .config(getSparkConf)
12 | .master("local[1]")
13 | .getOrCreate()
14 |
15 | def getSparkConf: SparkConf = {
16 | val conf = new SparkConf()
17 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
18 | .set("spark.network.timeout", "600")
19 | .set("spark.streaming.kafka.maxRatePerPartition", "200000")
20 | .set("spark.streaming.kafka.consumer.poll.ms", "5120")
21 | .set("spark.streaming.concurrentJobs", "5")
22 | .set("spark.sql.crossJoin.enabled", "true")
23 | .set("spark.driver.maxResultSize", "1g")
24 | .set("spark.rpc.message.maxSize", "1000") // 1024 max
25 | conf
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
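Since both base traits switch the serializer to Kryo, the attribute classes shuffled during the walk can optionally be registered so Kryo does not fall back to writing full class names. A hedged sketch follows; it assumes `NodeAttr` and `EdgeAttr` are case classes nested in `eges.random.walk.Attributes`, as their usage in `RandomWalk` suggests.

    import org.apache.spark.SparkConf
    // Assumption: NodeAttr/EdgeAttr live in eges.random.walk.Attributes (see RandomWalk's usage).
    import eges.random.walk.Attributes.{EdgeAttr, NodeAttr}

    object KryoRegistration {
      // Registering the shuffled attribute classes keeps Kryo payloads compact;
      // unregistered classes are still serialized, just less efficiently.
      def withGraphClasses(conf: SparkConf): SparkConf =
        conf.registerKryoClasses(Array(classOf[NodeAttr], classOf[EdgeAttr]))
    }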
/enhanced_graph_embedding_side_information/src/main/scala/sparkapplication/BaseSparkOnline.scala:
--------------------------------------------------------------------------------
1 | package sparkapplication
2 |
3 | import org.apache.spark.SparkConf
4 | import org.apache.spark.sql.SparkSession
5 |
6 | trait BaseSparkOnline {
7 | def basicSpark: SparkSession =
8 | SparkSession
9 | .builder
10 | .config(getSparkConf)
11 | .enableHiveSupport()
12 | .getOrCreate()
13 |
14 | def getSparkConf: SparkConf = {
15 | val conf = new SparkConf()
16 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
17 | .set("spark.network.timeout", "6000")
18 | .set("spark.streaming.kafka.maxRatePerPartition", "200000")
19 | .set("spark.streaming.kafka.consumer.poll.ms", "5120")
20 | .set("spark.streaming.concurrentJobs", "5")
21 | .set("spark.sql.crossJoin.enabled", "true")
22 | .set("spark.driver.maxResultSize", "20g")
23 | .set("spark.rpc.message.maxSize", "1000") // 1024 max
24 | }
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/paper/[Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf
--------------------------------------------------------------------------------
/paper/[Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf
--------------------------------------------------------------------------------
/paper/[Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf
--------------------------------------------------------------------------------