├── README.md
├── enhanced_graph_embedding_side_information
│   ├── .idea
│   │   ├── compiler.xml
│   │   ├── hydra.xml
│   │   ├── libraries
│   │   │   └── scala_sdk_2_11_8.xml
│   │   ├── misc.xml
│   │   ├── scala_compiler.xml
│   │   ├── uiDesigner.xml
│   │   └── workspace.xml
│   ├── pom.xml
│   └── src
│       └── main
│           └── scala
│               ├── eges
│               │   ├── embedding
│               │   │   ├── WeightedSkipGram.scala
│               │   │   └── WeightedSkipGramBatch.scala
│               │   └── random
│               │       ├── sample
│               │       │   └── Alias.scala
│               │       └── walk
│               │           ├── Attributes.scala
│               │           └── RandomWalk.scala
│               ├── example
│               │   ├── Example1.scala
│               │   ├── Example10.scala
│               │   ├── Example11.scala
│               │   ├── Example2.scala
│               │   ├── Example3.scala
│               │   ├── Example4.scala
│               │   ├── Example5.scala
│               │   ├── Example6.scala
│               │   ├── Example7.scala
│               │   ├── Example8.scala
│               │   └── Example9.scala
│               ├── main
│               │   ├── Main1.scala
│               │   ├── Main2.scala
│               │   └── Main3.scala
│               └── sparkapplication
│                   ├── BaseSparkLocal.scala
│                   └── BaseSparkOnline.scala
└── paper
    ├── [Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf
    ├── [Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf
    └── [Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf

/README.md:
--------------------------------------------------------------------------------
# deepwalk_node2vector_eges
A Scala/Spark implementation of DeepWalk, node2vec and the Alibaba paper "Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba" (EGES). Random-walk generation and embedding training are kept as two separate stages. In the code:
1. The driver Main1 generates random walks, with the walk bias controlled by the parameters p and q.
2. The driver Main2 (full-data synchronous training) combines word2vec with the Alibaba paper: when the side-information length is 1 it reduces to plain word2vec; when it is greater than 1 it is the EGES model from the paper.
3. The driver Main3 (multi-batch synchronous training) is the same as Main2, except that the data is split into several batches for training.
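The Main1/Main2/Main3 drivers referenced above are not part of this dump, so the sketch below only illustrates how the two stages shipped in this repository (eges.random.walk.RandomWalk and eges.embedding.WeightedSkipGram) can be chained together. The object name, the toy edge list and the local SparkSession setup are assumptions (the repository's own entry points extend BaseSparkLocal/BaseSparkOnline instead); with one-element side information, as here, the second stage reduces to plain word2vec.

```scala
import org.apache.spark.sql.SparkSession
import eges.random.walk.RandomWalk
import eges.embedding.WeightedSkipGram

// Hypothetical driver, not part of the repository.
object PipelineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("eges-sketch").master("local[*]").getOrCreate()
    val sc = spark.sparkContext

    // Weighted edge list (srcNode, dstNode, weight); a real run would build this from co-occurrence data.
    val edges = sc.parallelize(Seq(
      ("gds1", "gds2", 0.5), ("gds2", "gds3", 1.0), ("gds3", "gds1", 0.8)))

    // Stage 1: node2vec-style biased random walks (p = return parameter, q = in-out parameter).
    val walks = new RandomWalk()
      .setNodeIterationNum(5)       // walks started per node
      .setNodeWalkLength(10)        // maximum walk length
      .setReturnParameter(1.0)      // p
      .setInOutParameter(1.0)       // q
      .setNumPartitions(4)
      .fit(edges)                   // RDD[(startNode, "n1,n2,...")]

    // Stage 2: weighted skip-gram over the walks. Each node is wrapped in a one-element
    // array (no extra side information), which is the word2vec special case.
    val corpus = walks.map { case (_, path) => path.split(",").map(node => Array(node)) }
    val embeddings = new WeightedSkipGram()
      .setVectorSize(16)
      .setWindowSize(2)
      .setNegativeSampleNum(2)
      .setIterationNum(5)
      .setNumPartitions(4)
      .setSampleTableNum(1000000)   // shrink the negative-sampling table for a toy run
      .fit(corpus)                  // RDD[(key, "v1@v2@...")]

    embeddings.take(5).foreach(println)
    spark.stop()
  }
}
```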
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/pom.xml:
--------------------------------------------------------------------------------
[XML markup stripped in this dump; recoverable content:] Maven coordinates embedding.eges:enhanced_graph_embedding_side_information:1.0-SNAPSHOT; properties pinning the Scala binary version to 2.11, the Scala version to ${scala.binary.version}.8 and the Spark version to 2.1.0, plus a compile scope value; dependencies on scala-library, scala-compiler, spark-core_${scala.binary.version}, spark-sql_${scala.binary.version}, spark-mllib_2.11, commons-configuration 1.9, commons-lang 2.5 and breeze_2.11 0.13.2; sit (active by default) and prod profiles filtering src/main/resources with ../${project.artifactId}/vars/vars.sit.properties and vars.prod.properties; build plugins build-helper-maven-plugin 1.8 (adding src/main/scala and src/test/scala source roots at generate-sources), scala-maven-plugin 3.1.5 (compile/testCompile), maven-compiler-plugin (source/target 1.7, utf-8) and maven-assembly-plugin (jar-with-dependencies, main class example.Example5).
--------------------------------------------------------------------------------
/enhanced_graph_embedding_side_information/src/main/scala/eges/embedding/WeightedSkipGram.scala:
--------------------------------------------------------------------------------
1 | package eges.embedding 2 | 3 | import java.util.Random 4 | import scala.collection.mutable 5 | import org.apache.spark.rdd.RDD 6 | import breeze.linalg.DenseVector 7 | import scala.collection.mutable.ArrayBuffer 8 | import org.apache.spark.storage.StorageLevel 9 | 10 | class WeightedSkipGram extends Serializable { 11 | 12 | var vectorSize = 4 13 | var windowSize = 2 14 | var negativeSampleNum = 1 15 | var perSampleMaxNum = 10 16 | var subSample: Double = 0.0 17 | var learningRate: Double = 0.02 18 | var iterationNum = 5 19 | var isNoShowLoss = false 20 | var numPartitions = 200 21 | var nodeSideCount: Int = 0 22 | var nodeSideTable: Array[(String, Int)] = _ 23 | var nodeSideHash =
mutable.HashMap.empty[String, Int] 24 | var sampleTableNum: Int = 1e8.toInt 25 | lazy val sampleTable: Array[Int] = new Array[Int](sampleTableNum) 26 | 27 | /** 28 | * 每个节点embedding长度 29 | * 30 | * */ 31 | def setVectorSize(value:Int):this.type = { 32 | require(value > 0, s"vectorSize must be more than 0, but it is $value") 33 | vectorSize = value 34 | this 35 | } 36 | 37 | /** 38 | * 窗口大小 39 | * 40 | * */ 41 | def setWindowSize(value:Int):this.type = { 42 | require(value > 0, s"windowSize must be more than 0, but it is $value") 43 | windowSize = value 44 | this 45 | } 46 | 47 | /** 48 | * 负采样的样本个数 49 | * 50 | * */ 51 | def setNegativeSampleNum(value:Int):this.type = { 52 | require(value > 0, s"negativeSampleNum must be more than 0, but it is $value") 53 | negativeSampleNum = value 54 | this 55 | } 56 | 57 | /** 58 | * 负采样时, 每个负样本最多采样的次数 59 | * 60 | * */ 61 | def setPerSampleMaxNum(value:Int):this.type = { 62 | require(value > 0, s"perSampleMaxNum must be more than 0, but it is $value") 63 | perSampleMaxNum = value 64 | this 65 | } 66 | 67 | /** 68 | * 高频节点的下采样率 69 | * 70 | * */ 71 | def setSubSample(value:Double):this.type = { 72 | require(value >= 0.0 && value <= 1.0, s"subSample must be not less than 0.0 and not more than 1.0, but it is $value") 73 | subSample = value 74 | this 75 | } 76 | 77 | /** 78 | * 初始化学习率 79 | * 80 | * */ 81 | def setLearningRate(value:Double):this.type = { 82 | require(value > 0.0, s"learningRate must be more than 0.0, but it is $value") 83 | learningRate = value 84 | this 85 | } 86 | 87 | /** 88 | * 迭代次数 89 | * 90 | * */ 91 | def setIterationNum(value:Int):this.type = { 92 | require(value > 0, s"iterationNum must be more than 0, but it is $value") 93 | iterationNum = value 94 | this 95 | } 96 | 97 | /** 98 | * 是否显示损失, 不建议显示损失值, 这样增加计算量 99 | * 100 | * */ 101 | def setIsNoShowLoss(value:Boolean):this.type = { 102 | isNoShowLoss = value 103 | this 104 | } 105 | 106 | /** 107 | * 分区数量 108 | * 109 | * */ 110 | def setNumPartitions(value:Int):this.type = { 111 | require(value > 0, s"numPartitions must be more than 0, but it is $value") 112 | numPartitions = value 113 | this 114 | } 115 | 116 | /** 117 | * 采样数组大小 118 | * 119 | * */ 120 | def setSampleTableNum(value:Int):this.type = { 121 | require(value > 0, s"sampleTableNum must be more than 0, but it is $value") 122 | sampleTableNum = value 123 | this 124 | } 125 | 126 | /** 127 | * 训练 128 | * 129 | * */ 130 | def fit(dataSet: RDD[Array[Array[String]]]): RDD[(String, String)] = { 131 | // 重新分区 132 | val dataSetRepartition = dataSet.map(k => ((new Random).nextInt(numPartitions), k)) 133 | .repartition(numPartitions).map(k => k._2) 134 | dataSetRepartition.persist(StorageLevel.MEMORY_AND_DISK) 135 | val sc = dataSetRepartition.context 136 | // 初始化 WeightedSkipGram 137 | initWeightedSkipGram(dataSetRepartition) 138 | // 边和embedding的长度 139 | val sideInfoNum = dataSetRepartition.first().head.length 140 | val sideInfoNumBroadcast = sc.broadcast(sideInfoNum) 141 | val vectorSizeBroadcast = sc.broadcast(vectorSize) 142 | // 节点权重初始化 143 | val nodeWeight = dataSetRepartition.flatMap(x => x) 144 | .map(nodeInfo => nodeInfo.head + "#Weight").distinct() 145 | .map(nodeWeight => (nodeWeight, Array.fill[Double](sideInfoNumBroadcast.value)((new Random).nextGaussian()))) 146 | // 节点Zu初始化 147 | val nodeZu = dataSetRepartition.flatMap(x => x) 148 | .map(nodeInfo => nodeInfo.head + "#Zu").distinct() 149 | .map(nodeZu => (nodeZu, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) 
)) 150 | // 节点和边信息初始化 151 | val nodeAndSideVector = dataSetRepartition.flatMap(x => x).flatMap(x => x).distinct() 152 | .map(nodeAndSide => (nodeAndSide, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) )) 153 | // 节点权重、节点和边信息初始化 154 | val embeddingHashInit = mutable.HashMap(nodeWeight.union(nodeZu).union(nodeAndSideVector).collectAsMap().toList:_*) 155 | // println("embedding初始化:") 156 | // embeddingHashInit.map(k => (k._1, k._2.mkString("@"))).foreach(println) 157 | 158 | // 广播部分变量 159 | val nodeSideTableBroadcast = sc.broadcast(nodeSideTable) 160 | val nodeSideHashBroadcast = sc.broadcast(nodeSideHash) 161 | val sampleTableBroadcast = sc.broadcast(sampleTable) 162 | val subSampleBroadcast = sc.broadcast(subSample) 163 | val windowSizeBroadcast = sc.broadcast(windowSize) 164 | val negativeSampleNumBroadcast = sc.broadcast(negativeSampleNum) 165 | val perSampleMaxNumBroadcast = sc.broadcast(perSampleMaxNum) 166 | val sampleTableNumBroadcast = sc.broadcast(sampleTableNum) 167 | val nodeSideCountBroadcast = sc.broadcast(nodeSideCount) 168 | val isNoShowLossBroadcast = sc.broadcast(isNoShowLoss) 169 | // val lambdaBroadcast = sc.broadcast(learningRate) 170 | 171 | // 训练 172 | var embeddingHashBroadcast = sc.broadcast(embeddingHashInit) 173 | for(iter <- 0 until iterationNum){ 174 | val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= learningRate/5) learningRate/(1+iter) else learningRate/5) 175 | // val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= 0.000001) learningRate/(1+iter) else 0.000001) 176 | // val lambdaBroadcast = sc.broadcast(Array(0.5, 0.4, 0.2, 0.1, 0.05, 0.05)(iter)) 177 | // 训练结果 178 | val trainResult = dataSetRepartition.mapPartitions(dataIterator => 179 | trainWeightedSkipGram(dataIterator, embeddingHashBroadcast.value, nodeSideTableBroadcast.value, 180 | nodeSideHashBroadcast.value, sampleTableBroadcast.value, subSampleBroadcast.value, vectorSizeBroadcast.value, 181 | sideInfoNumBroadcast.value, windowSizeBroadcast.value, negativeSampleNumBroadcast.value, perSampleMaxNumBroadcast.value, 182 | sampleTableNumBroadcast.value, nodeSideCountBroadcast.value, isNoShowLossBroadcast.value)) 183 | trainResult.persist(StorageLevel.MEMORY_AND_DISK) 184 | // 节点权重的梯度聚类 185 | val trainWeightResult = trainResult.filter(k => k._1.endsWith("#Weight") && !k._1.equals("LossValue")) 186 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](sideInfoNumBroadcast.value)(0.0)), 0L))( 187 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 188 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) ) 189 | trainWeightResult.persist(StorageLevel.MEMORY_AND_DISK) 190 | // 节点和边embedding的梯度聚类 191 | val trainVectorResult = trainResult.filter(k => !k._1.endsWith("#Weight") && !k._1.equals("LossValue")) 192 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](vectorSizeBroadcast.value)(0.0)), 0L))( 193 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 194 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) ) 195 | trainVectorResult.persist(StorageLevel.MEMORY_AND_DISK) 196 | // 计算平均损失 197 | if(isNoShowLoss){ 198 | val trainLossResult = trainResult.filter(k => k._1.equals("LossValue")) 199 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](1)(0.0)), 0L))( 200 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 201 | (vector1, vector2) => 
(vector1._1 + vector2._1, vector1._2 + vector2._2) 202 | ).map(k => (k._1, k._2._1/k._2._2.toDouble)) 203 | .map(k => (k._1, k._2.toArray.head)).first()._2 204 | println(s"====第${iter+1}轮====平均损失:$trainLossResult====") 205 | } 206 | val trainNodeSideResultUnion = trainWeightResult.union(trainVectorResult) 207 | .map{ case (key, (gradientVectorSum, num)) => 208 | val embeddingRaw = new DenseVector[Double](embeddingHashBroadcast.value(key)) 209 | val embeddingUpdate = embeddingRaw - 100 * lambdaBroadcast.value * gradientVectorSum / num.toDouble 210 | (key, embeddingUpdate.toArray) 211 | } 212 | var trainSideVectorResultMap = mutable.HashMap(trainNodeSideResultUnion.collectAsMap().toList:_*) 213 | trainSideVectorResultMap = embeddingHashBroadcast.value.++(trainSideVectorResultMap) 214 | embeddingHashBroadcast = sc.broadcast(trainSideVectorResultMap) 215 | trainResult.unpersist() 216 | trainWeightResult.unpersist() 217 | trainVectorResult.unpersist() 218 | } 219 | dataSetRepartition.unpersist() 220 | 221 | // 每个节点加权向量 222 | var embeddingHashResult = embeddingHashBroadcast.value 223 | val nodeHvArray = nodeSideTable.map(k => k._1.split("@")) 224 | .map(nodeInfo => { 225 | val node = nodeInfo.head 226 | var eavSum = 0.0 227 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0)) 228 | val keyEA = node + "#Weight" 229 | val valueEA = embeddingHashResult.getOrElse(keyEA, Array.fill[Double](sideInfoNum)(0.0)) 230 | for(n <- 0 until sideInfoNum) { 231 | val keyNodeOrSide = nodeInfo(n) 232 | val valueNodeOrSide = embeddingHashResult.getOrElse(keyNodeOrSide, Array.fill[Double](vectorSize)(0.0)) 233 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(n)) * new DenseVector[Double](valueNodeOrSide) 234 | eavSum = eavSum + math.exp(valueEA(n)) 235 | } 236 | val Hv = (eavWvAggSum/eavSum).toArray 237 | (nodeInfo.mkString("#"), Hv) 238 | }) 239 | val nodeHvMap = mutable.Map(nodeHvArray:_*) 240 | embeddingHashResult = embeddingHashResult.++(nodeHvMap) 241 | val embeddingResult = sc.parallelize(embeddingHashResult.map(k => (k._1, k._2.mkString("@"))).toList, numPartitions) 242 | embeddingResult 243 | } 244 | 245 | /** 246 | * 初始化 247 | * 248 | * */ 249 | def initWeightedSkipGram(dataSet: RDD[Array[Array[String]]]): Unit = { 250 | // 节点和边信息出现的频次数组 251 | nodeSideTable = dataSet.flatMap(x => x) 252 | .map(nodeInfo => (nodeInfo.mkString("@"), 1)) 253 | .reduceByKey(_ + _) 254 | .collect().sortBy(-_._2) 255 | // 节点和边信息总数量 256 | nodeSideCount = nodeSideTable.map(_._2).par.sum 257 | //节点和边信息出现的频次哈希map 258 | nodeSideHash = mutable.HashMap(nodeSideTable:_*) 259 | // 节点和边信息尺寸 260 | val nodeSideSize = nodeSideTable.length 261 | // Negative Sampling 负采样初始化 262 | var a = 0 263 | val power = 0.75 264 | var nodesPow = 0.0 265 | while (a < nodeSideSize) { 266 | nodesPow += Math.pow(nodeSideTable(a)._2, power) 267 | a = a + 1 268 | } 269 | var b = 0 270 | var freq = Math.pow(nodeSideTable(b)._2, power) / nodesPow 271 | var c = 0 272 | while (c < sampleTableNum) { 273 | sampleTable(c) = b 274 | if ((c.toDouble + 1.0) / sampleTableNum >= freq) { 275 | b = b + 1 276 | if (b >= nodeSideSize) { 277 | b = nodeSideSize - 1 278 | } 279 | freq += Math.pow(nodeSideTable(b)._2, power) / nodesPow 280 | } 281 | c = c + 1 282 | } 283 | } 284 | 285 | /** 286 | * 每个分区训练 287 | * 288 | * */ 289 | def trainWeightedSkipGram(dataSet: Iterator[Array[Array[String]]], 290 | embeddingHash: mutable.HashMap[String, Array[Double]], 291 | nodeSideTable: Array[(String, Int)], 292 | nodeSideHash: mutable.HashMap[String, Int], 293 | 
sampleTable: Array[Int], 294 | subSample: Double, 295 | vectorSize: Int, 296 | sideInfoNum: Int, 297 | windowSize: Int, 298 | negativeSampleNum: Int, 299 | perSampleMaxNum: Int, 300 | sampleTableNum: Int, 301 | nodeSideCount: Int, 302 | isNoShowLoss: Boolean): Iterator[(String, (Array[Double], Long))] = { 303 | val gradientUpdateHash = new mutable.HashMap[String, (Array[Double], Long)]() 304 | var lossSum = 0.0 305 | var lossNum = 0L 306 | for(data <- dataSet) { 307 | // 高频节点的下采样 308 | val dataSubSample = ArrayBuffer[Array[String]]() 309 | for( nodeOrSideInfoArray <- data){ 310 | if(subSample > 0.0) { 311 | val nodeOrSideInfoFrequency = nodeSideHash(nodeOrSideInfoArray.mkString("@")) 312 | val keepProbability = (Math.sqrt(nodeOrSideInfoFrequency/ (subSample * nodeSideCount)) + 1.0) * (subSample * nodeSideCount) / nodeOrSideInfoFrequency 313 | if (keepProbability >= (new Random).nextDouble()) { 314 | dataSubSample.append(nodeOrSideInfoArray) 315 | } 316 | } else { 317 | dataSubSample.append(nodeOrSideInfoArray) 318 | } 319 | } 320 | val dssl = dataSubSample.length 321 | for(i <- 0 until dssl){ 322 | val mainNodeInfo = dataSubSample(i) 323 | val mainNode = mainNodeInfo.head 324 | var eavSum = 0.0 325 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0)) 326 | val keyEA = mainNode + "#Weight" 327 | val valueEA = embeddingHash(keyEA) 328 | for(k <- 0 until sideInfoNum) { 329 | val keyNodeOrSide = mainNodeInfo(k) 330 | val valueNodeOrSide = embeddingHash(keyNodeOrSide) 331 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(k)) * new DenseVector[Double](valueNodeOrSide) 332 | eavSum = eavSum + math.exp(valueEA(k)) 333 | } 334 | val Hv = eavWvAggSum/eavSum 335 | // 主节点对应窗口内的正样本集合 336 | val mainSlaveNodeSet = dataSubSample.slice(math.max(0, i-windowSize), math.min(i+windowSize, dssl-1) + 1) 337 | .map(array => array.head).toSet 338 | for(j <- math.max(0, i-windowSize) to math.min(i+windowSize, dssl-1)){ 339 | if(j != i) { 340 | // 正样本训练 341 | val slaveNodeInfo = dataSubSample(j) 342 | val slaveNode = slaveNodeInfo.head 343 | embeddingUpdate(mainNodeInfo, slaveNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 1.0, vectorSize, sideInfoNum) 344 | // 正样本计算损失值 345 | if(isNoShowLoss){ 346 | val keySlaveNode = slaveNode + "#Zu" 347 | val valueSlaveNode = embeddingHash(keySlaveNode) 348 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode) 349 | val logarithmRaw = math.log(1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0)))) 350 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0) 351 | lossSum = lossSum - logarithm 352 | lossNum = lossNum + 1L 353 | } 354 | // 负样本采样和训练 355 | for(_ <- 0 until negativeSampleNum){ 356 | // 负样本采样 357 | var sampleNodeInfo = slaveNodeInfo 358 | var sampleNode = slaveNode 359 | var sampleNum = 0 360 | while(mainSlaveNodeSet.contains(sampleNode) && sampleNum < perSampleMaxNum) { 361 | val index = sampleTable((new Random).nextInt(sampleTableNum)) 362 | sampleNodeInfo = nodeSideTable(index)._1.split("@") 363 | sampleNode = sampleNodeInfo.head 364 | sampleNum = sampleNum + 1 365 | } 366 | // 负样本训练 367 | embeddingUpdate(mainNodeInfo, sampleNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 0.0, vectorSize, sideInfoNum) 368 | // 负样本计算损失值 369 | if(isNoShowLoss){ 370 | val keySampleNode = sampleNodeInfo.head + "#Zu" 371 | val valueSampleNode = embeddingHash(keySampleNode) 372 | val HvZu = Hv.t * new DenseVector[Double](valueSampleNode) 373 | val logarithmRaw = math.log(1.0 - 1.0 / 
(1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0)))) 374 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0) 375 | lossSum = lossSum - logarithm 376 | lossNum = lossNum + 1L 377 | } 378 | } 379 | } 380 | } 381 | } 382 | } 383 | if(isNoShowLoss){ 384 | gradientUpdateHash.put("LossValue", (Array(lossSum), lossNum)) 385 | } 386 | gradientUpdateHash.toIterator 387 | } 388 | 389 | /** 390 | * 梯度更新 391 | * 392 | * */ 393 | def embeddingUpdate(mainNodeInfo: Array[String], 394 | slaveNodeInfo: Array[String], 395 | embeddingHash: mutable.HashMap[String, Array[Double]], 396 | gradientUpdateHash: mutable.HashMap[String, (Array[Double], Long)], 397 | mainNodeWeightRaw: Array[Double], 398 | eavWvAggSum: DenseVector[Double], 399 | eavSum: Double, 400 | Hv: DenseVector[Double], 401 | label: Double, 402 | vectorSize: Int, 403 | sideInfoNum: Int): Unit = { 404 | val keySlaveNode = slaveNodeInfo.head + "#Zu" 405 | val valueSlaveNode = embeddingHash(keySlaveNode) 406 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode) 407 | // 更新从节点Zu的梯度 408 | val gradientZu = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * Hv 409 | var gradientZuSum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._1 410 | gradientZuSum = (new DenseVector[Double](gradientZuSum) + gradientZu).toArray 411 | var gradientZuNum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._2 412 | gradientZuNum = gradientZuNum + 1L 413 | gradientUpdateHash.put(keySlaveNode, (gradientZuSum, gradientZuNum)) 414 | // 更新主节点边权重的梯度和边信息的梯度 415 | val gradientHv = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * new DenseVector[Double](valueSlaveNode) 416 | val mainNodeWeightGradientSum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._1 417 | var mainNodeWeightGradientNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._2 418 | for(m <- 0 until sideInfoNum) { 419 | val wsvVectorRaw = new DenseVector[Double](embeddingHash(mainNodeInfo(m))) 420 | // 更新主节点边权重的梯度 421 | val gradientAsv = gradientHv.t * (eavSum * math.exp(mainNodeWeightRaw(m)) * wsvVectorRaw - math.exp(mainNodeWeightRaw(m)) * eavWvAggSum)/(eavSum * eavSum) 422 | mainNodeWeightGradientSum(m) = mainNodeWeightGradientSum(m) + gradientAsv 423 | // 更新主节点边信息的梯度 424 | var gradientWsvSum = new DenseVector(gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._1) 425 | var gradientWsvNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._2 426 | gradientWsvSum = gradientWsvSum + math.exp(mainNodeWeightRaw(m)) * gradientHv / eavSum 427 | gradientWsvNum = gradientWsvNum + 1L 428 | gradientUpdateHash.put(mainNodeInfo(m), (gradientWsvSum.toArray, gradientWsvNum)) 429 | } 430 | mainNodeWeightGradientNum = mainNodeWeightGradientNum + 1L 431 | gradientUpdateHash.put(mainNodeInfo.head + "#Weight", (mainNodeWeightGradientSum, mainNodeWeightGradientNum)) 432 | } 433 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/eges/embedding/WeightedSkipGramBatch.scala: -------------------------------------------------------------------------------- 1 | package eges.embedding 2 | 3 | import java.util.Random 4 | import scala.collection.mutable 5 | import 
org.apache.spark.rdd.RDD 6 | import breeze.linalg.DenseVector 7 | import scala.collection.mutable.ArrayBuffer 8 | import org.apache.spark.storage.StorageLevel 9 | 10 | class WeightedSkipGramBatch extends Serializable { 11 | 12 | var vectorSize = 4 13 | var windowSize = 2 14 | var negativeSampleNum = 1 15 | var perSampleMaxNum = 10 16 | var subSample: Double = 0.0 17 | var learningRate: Double = 0.02 18 | var batchNum = 10 19 | var iterationNum = 5 20 | var isNoShowLoss = false 21 | var numPartitions = 200 22 | var nodeSideCount: Int = 0 23 | var nodeSideTable: Array[(String, Int)] = _ 24 | var nodeSideHash = mutable.HashMap.empty[String, Int] 25 | var sampleTableNum: Int = 1e8.toInt 26 | lazy val sampleTable: Array[Int] = new Array[Int](sampleTableNum) 27 | 28 | /** 29 | * 每个节点embedding长度 30 | * 31 | * */ 32 | def setVectorSize(value:Int):this.type = { 33 | require(value > 0, s"vectorSize must be more than 0, but it is $value") 34 | vectorSize = value 35 | this 36 | } 37 | 38 | /** 39 | * 窗口大小 40 | * 41 | * */ 42 | def setWindowSize(value:Int):this.type = { 43 | require(value > 0, s"windowSize must be more than 0, but it is $value") 44 | windowSize = value 45 | this 46 | } 47 | 48 | /** 49 | * 负采样的样本个数 50 | * 51 | * */ 52 | def setNegativeSampleNum(value:Int):this.type = { 53 | require(value > 0, s"negativeSampleNum must be more than 0, but it is $value") 54 | negativeSampleNum = value 55 | this 56 | } 57 | 58 | /** 59 | * 负采样时, 每个负样本最多采样的次数 60 | * 61 | * */ 62 | def setPerSampleMaxNum(value:Int):this.type = { 63 | require(value > 0, s"perSampleMaxNum must be more than 0, but it is $value") 64 | perSampleMaxNum = value 65 | this 66 | } 67 | 68 | /** 69 | * 高频节点的下采样率 70 | * 71 | * */ 72 | def setSubSample(value:Double):this.type = { 73 | require(value >= 0.0 && value <= 1.0, s"subSample must be not less than 0.0 and not more than 1.0, but it is $value") 74 | subSample = value 75 | this 76 | } 77 | 78 | /** 79 | * 初始化学习率 80 | * 81 | * */ 82 | def setLearningRate(value:Double):this.type = { 83 | require(value > 0.0, s"learningRate must be more than 0.0, but it is $value") 84 | learningRate = value 85 | this 86 | } 87 | 88 | /** 89 | * 数据集批数 90 | * 91 | * */ 92 | def setBatchNum(value:Int):this.type = { 93 | require(value > 0, s"batchNum must be more than 0, but it is $value") 94 | batchNum = value 95 | this 96 | } 97 | 98 | /** 99 | * 迭代次数 100 | * 101 | * */ 102 | def setIterationNum(value:Int):this.type = { 103 | require(value > 0, s"iterationNum must be more than 0, but it is $value") 104 | iterationNum = value 105 | this 106 | } 107 | 108 | /** 109 | * 是否显示损失, 不建议显示损失值, 这样增加计算量 110 | * 111 | * */ 112 | def setIsNoShowLoss(value:Boolean):this.type = { 113 | isNoShowLoss = value 114 | this 115 | } 116 | 117 | /** 118 | * 分区数量 119 | * 120 | * */ 121 | def setNumPartitions(value:Int):this.type = { 122 | require(value > 0, s"numPartitions must be more than 0, but it is $value") 123 | numPartitions = value 124 | this 125 | } 126 | 127 | /** 128 | * 采样数组大小 129 | * 130 | * */ 131 | def setSampleTableNum(value:Int):this.type = { 132 | require(value > 0, s"sampleTableNum must be more than 0, but it is $value") 133 | sampleTableNum = value 134 | this 135 | } 136 | 137 | /** 138 | * 训练 139 | * 140 | * */ 141 | def fit(dataSet: RDD[Array[Array[String]]]): RDD[(String, String)] = { 142 | // 重新分区 143 | val dataSetRepartition = dataSet.map(k => ((new Random).nextInt(numPartitions), k)) 144 | .repartition(numPartitions).map(k => k._2) 145 | dataSetRepartition.persist(StorageLevel.MEMORY_AND_DISK) 146 | val 
sc = dataSetRepartition.context 147 | // 初始化 WeightedSkipGram 148 | initWeightedSkipGram(dataSetRepartition) 149 | // 边和embedding的长度 150 | val sideInfoNum = dataSetRepartition.first().head.length 151 | val sideInfoNumBroadcast = sc.broadcast(sideInfoNum) 152 | val vectorSizeBroadcast = sc.broadcast(vectorSize) 153 | // 节点权重初始化 154 | val nodeWeight = dataSetRepartition.flatMap(x => x) 155 | .map(nodeInfo => nodeInfo.head + "#Weight").distinct() 156 | .map(nodeWeight => (nodeWeight, Array.fill[Double](sideInfoNumBroadcast.value)((new Random).nextGaussian()))) 157 | // 节点Zu初始化 158 | val nodeZu = dataSetRepartition.flatMap(x => x) 159 | .map(nodeInfo => nodeInfo.head + "#Zu").distinct() 160 | .map(nodeZu => (nodeZu, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) )) 161 | // 节点和边信息初始化 162 | val nodeAndSideVector = dataSetRepartition.flatMap(x => x).flatMap(x => x).distinct() 163 | .map(nodeAndSide => (nodeAndSide, Array.fill[Double](vectorSizeBroadcast.value)( ((new Random).nextDouble()-0.5)/vectorSizeBroadcast.value ) )) 164 | // 节点权重、节点和边信息初始化 165 | val embeddingHashInit = mutable.HashMap(nodeWeight.union(nodeZu).union(nodeAndSideVector).collectAsMap().toList:_*) 166 | // println("embedding初始化:") 167 | // embeddingHashInit.map(k => (k._1, k._2.mkString("@"))).foreach(println) 168 | 169 | // 广播部分变量 170 | val nodeSideTableBroadcast = sc.broadcast(nodeSideTable) 171 | val nodeSideHashBroadcast = sc.broadcast(nodeSideHash) 172 | val sampleTableBroadcast = sc.broadcast(sampleTable) 173 | val subSampleBroadcast = sc.broadcast(subSample) 174 | val windowSizeBroadcast = sc.broadcast(windowSize) 175 | val negativeSampleNumBroadcast = sc.broadcast(negativeSampleNum) 176 | val perSampleMaxNumBroadcast = sc.broadcast(perSampleMaxNum) 177 | val sampleTableNumBroadcast = sc.broadcast(sampleTableNum) 178 | val nodeSideCountBroadcast = sc.broadcast(nodeSideCount) 179 | val isNoShowLossBroadcast = sc.broadcast(isNoShowLoss) 180 | // val lambdaBroadcast = sc.broadcast(learningRate) 181 | 182 | // 数据分组 183 | val batchWeightArray = Array.fill(batchNum)(1.0/batchNum) 184 | val dataSetBatch = dataSetRepartition.randomSplit(batchWeightArray) 185 | val perBatchCountArray = Array.fill[Int](batchNum)(0) 186 | for(batchA <- 0 until batchNum){ 187 | val batchCount = dataSetBatch(batchA).count().toInt 188 | perBatchCountArray(batchA) = batchCount 189 | if(batchCount > 0) dataSetBatch(batchA).persist(StorageLevel.MEMORY_AND_DISK) 190 | } 191 | dataSetRepartition.unpersist() 192 | // 训练 193 | var embeddingHashBroadcast = sc.broadcast(embeddingHashInit) 194 | for(iter <- 0 until iterationNum){ 195 | val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= learningRate/5) learningRate/(1+iter) else learningRate/5) 196 | // val lambdaBroadcast = sc.broadcast(if(learningRate/(1+iter) >= 0.000001) learningRate/(1+iter) else 0.000001) 197 | // val lambdaBroadcast = sc.broadcast(Array(0.5, 0.4, 0.2, 0.1, 0.05, 0.05)(iter)) 198 | val perBatchLossArray = ArrayBuffer[Double]() 199 | for(batch <- 0 until batchNum){ 200 | if(perBatchCountArray(batch) > 0) { 201 | // 训练结果 202 | val trainResult = dataSetBatch(batch).map(k => ((new Random).nextInt(numPartitions), k)) 203 | .repartition(numPartitions).map(k => k._2).mapPartitions(dataIterator => 204 | trainWeightedSkipGram(dataIterator, embeddingHashBroadcast.value, nodeSideTableBroadcast.value, 205 | nodeSideHashBroadcast.value, sampleTableBroadcast.value, subSampleBroadcast.value, vectorSizeBroadcast.value, 206 | 
sideInfoNumBroadcast.value, windowSizeBroadcast.value, negativeSampleNumBroadcast.value, perSampleMaxNumBroadcast.value, 207 | sampleTableNumBroadcast.value, nodeSideCountBroadcast.value, isNoShowLossBroadcast.value)) 208 | trainResult.persist(StorageLevel.MEMORY_AND_DISK) 209 | // 节点权重的梯度聚类 210 | val trainWeightResult = trainResult.filter(k => k._1.endsWith("#Weight") && !k._1.equals("LossValue")) 211 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](sideInfoNumBroadcast.value)(0.0)), 0L))( 212 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 213 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) ) 214 | trainWeightResult.persist(StorageLevel.MEMORY_AND_DISK) 215 | // 节点和边embedding的梯度聚类 216 | val trainVectorResult = trainResult.filter(k => !k._1.endsWith("#Weight") && !k._1.equals("LossValue")) 217 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](vectorSizeBroadcast.value)(0.0)), 0L))( 218 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 219 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) ) 220 | trainVectorResult.persist(StorageLevel.MEMORY_AND_DISK) 221 | // 计算本轮本批平均损失 222 | if(isNoShowLoss){ 223 | val trainLossResult = trainResult.filter(k => k._1.equals("LossValue")) 224 | .aggregateByKey((new DenseVector[Double](Array.fill[Double](1)(0.0)), 0L))( 225 | (vector, array) => (vector._1 + new DenseVector[Double](array._1), vector._2 + array._2), 226 | (vector1, vector2) => (vector1._1 + vector2._1, vector1._2 + vector2._2) 227 | ).map(k => (k._1, k._2._1/k._2._2.toDouble)) 228 | .map(k => (k._1, k._2.toArray.head)).first()._2 229 | perBatchLossArray.append(trainLossResult) 230 | println(s"====第${iter+1}轮====第${batch+1}批====平均损失:$trainLossResult====") 231 | } 232 | val trainNodeSideResultUnion = trainWeightResult.union(trainVectorResult) 233 | .map{ case (key, (gradientVectorSum, num)) => 234 | val embeddingRaw = new DenseVector[Double](embeddingHashBroadcast.value(key)) 235 | val embeddingUpdate = embeddingRaw - 10 * lambdaBroadcast.value * gradientVectorSum / num.toDouble 236 | (key, embeddingUpdate.toArray) 237 | } 238 | var trainSideVectorResultMap = mutable.HashMap(trainNodeSideResultUnion.collectAsMap().toList:_*) 239 | trainSideVectorResultMap = embeddingHashBroadcast.value.++(trainSideVectorResultMap) 240 | embeddingHashBroadcast = sc.broadcast(trainSideVectorResultMap) 241 | trainResult.unpersist() 242 | trainWeightResult.unpersist() 243 | trainVectorResult.unpersist() 244 | } 245 | } 246 | // 计算每轮平均损失 247 | val batchLossLength = perBatchLossArray.length 248 | if(isNoShowLoss && batchLossLength > 0){ 249 | println(s"====第${iter+1}轮====平均损失:${perBatchLossArray.sum/batchLossLength}====") 250 | } 251 | } 252 | for(batchB <- 0 until batchNum){ 253 | if(perBatchCountArray(batchB) > 0) dataSetBatch(batchB).unpersist() 254 | } 255 | 256 | // 每个节点加权向量 257 | var embeddingHashResult = embeddingHashBroadcast.value 258 | val nodeHvArray = nodeSideTable.map(k => k._1.split("@")) 259 | .map(nodeInfo => { 260 | val node = nodeInfo.head 261 | var eavSum = 0.0 262 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0)) 263 | val keyEA = node + "#Weight" 264 | val valueEA = embeddingHashResult.getOrElse(keyEA, Array.fill[Double](sideInfoNum)(0.0)) 265 | for(n <- 0 until sideInfoNum) { 266 | val keyNodeOrSide = nodeInfo(n) 267 | val valueNodeOrSide = embeddingHashResult.getOrElse(keyNodeOrSide, 
Array.fill[Double](vectorSize)(0.0)) 268 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(n)) * new DenseVector[Double](valueNodeOrSide) 269 | eavSum = eavSum + math.exp(valueEA(n)) 270 | } 271 | val Hv = (eavWvAggSum/eavSum).toArray 272 | (nodeInfo.mkString("#"), Hv) 273 | }) 274 | val nodeHvMap = mutable.Map(nodeHvArray:_*) 275 | embeddingHashResult = embeddingHashResult.++(nodeHvMap) 276 | val embeddingResult = sc.parallelize(embeddingHashResult.map(k => (k._1, k._2.mkString("@"))).toList, numPartitions) 277 | embeddingResult 278 | } 279 | 280 | /** 281 | * 初始化 282 | * 283 | * */ 284 | def initWeightedSkipGram(dataSet: RDD[Array[Array[String]]]): Unit = { 285 | // 节点和边信息出现的频次数组 286 | nodeSideTable = dataSet.flatMap(x => x) 287 | .map(nodeInfo => (nodeInfo.mkString("@"), 1)) 288 | .reduceByKey(_ + _) 289 | .collect().sortBy(-_._2) 290 | // 节点和边信息总数量 291 | nodeSideCount = nodeSideTable.map(_._2).par.sum 292 | //节点和边信息出现的频次哈希map 293 | nodeSideHash = mutable.HashMap(nodeSideTable:_*) 294 | // 节点和边信息尺寸 295 | val nodeSideSize = nodeSideTable.length 296 | // Negative Sampling 负采样初始化 297 | var a = 0 298 | val power = 0.75 299 | var nodesPow = 0.0 300 | while (a < nodeSideSize) { 301 | nodesPow += Math.pow(nodeSideTable(a)._2, power) 302 | a = a + 1 303 | } 304 | var b = 0 305 | var freq = Math.pow(nodeSideTable(b)._2, power) / nodesPow 306 | var c = 0 307 | while (c < sampleTableNum) { 308 | sampleTable(c) = b 309 | if ((c.toDouble + 1.0) / sampleTableNum >= freq) { 310 | b = b + 1 311 | if (b >= nodeSideSize) { 312 | b = nodeSideSize - 1 313 | } 314 | freq += Math.pow(nodeSideTable(b)._2, power) / nodesPow 315 | } 316 | c = c + 1 317 | } 318 | } 319 | 320 | /** 321 | * 每个分区训练 322 | * 323 | * */ 324 | def trainWeightedSkipGram(dataSet: Iterator[Array[Array[String]]], 325 | embeddingHash: mutable.HashMap[String, Array[Double]], 326 | nodeSideTable: Array[(String, Int)], 327 | nodeSideHash: mutable.HashMap[String, Int], 328 | sampleTable: Array[Int], 329 | subSample: Double, 330 | vectorSize: Int, 331 | sideInfoNum: Int, 332 | windowSize: Int, 333 | negativeSampleNum: Int, 334 | perSampleMaxNum: Int, 335 | sampleTableNum: Int, 336 | nodeSideCount: Int, 337 | isNoShowLoss: Boolean): Iterator[(String, (Array[Double], Long))] = { 338 | val gradientUpdateHash = new mutable.HashMap[String, (Array[Double], Long)]() 339 | var lossSum = 0.0 340 | var lossNum = 0L 341 | for(data <- dataSet) { 342 | // 高频节点的下采样 343 | val dataSubSample = ArrayBuffer[Array[String]]() 344 | for( nodeOrSideInfoArray <- data){ 345 | if(subSample > 0.0) { 346 | val nodeOrSideInfoFrequency = nodeSideHash(nodeOrSideInfoArray.mkString("@")) 347 | val keepProbability = (Math.sqrt(nodeOrSideInfoFrequency/ (subSample * nodeSideCount)) + 1.0) * (subSample * nodeSideCount) / nodeOrSideInfoFrequency 348 | if (keepProbability >= (new Random).nextDouble()) { 349 | dataSubSample.append(nodeOrSideInfoArray) 350 | } 351 | } else { 352 | dataSubSample.append(nodeOrSideInfoArray) 353 | } 354 | } 355 | val dssl = dataSubSample.length 356 | for(i <- 0 until dssl){ 357 | val mainNodeInfo = dataSubSample(i) 358 | val mainNode = mainNodeInfo.head 359 | var eavSum = 0.0 360 | var eavWvAggSum = new DenseVector[Double](Array.fill[Double](vectorSize)(0.0)) 361 | val keyEA = mainNode + "#Weight" 362 | val valueEA = embeddingHash(keyEA) 363 | for(k <- 0 until sideInfoNum) { 364 | val keyNodeOrSide = mainNodeInfo(k) 365 | val valueNodeOrSide = embeddingHash(keyNodeOrSide) 366 | eavWvAggSum = eavWvAggSum + math.exp(valueEA(k)) * new 
DenseVector[Double](valueNodeOrSide) 367 | eavSum = eavSum + math.exp(valueEA(k)) 368 | } 369 | val Hv = eavWvAggSum/eavSum 370 | // 主节点对应窗口内的正样本集合 371 | val mainSlaveNodeSet = dataSubSample.slice(math.max(0, i-windowSize), math.min(i+windowSize, dssl-1) + 1) 372 | .map(array => array.head).toSet 373 | for(j <- math.max(0, i-windowSize) to math.min(i+windowSize, dssl-1)){ 374 | if(j != i) { 375 | // 正样本训练 376 | val slaveNodeInfo = dataSubSample(j) 377 | val slaveNode = slaveNodeInfo.head 378 | embeddingUpdate(mainNodeInfo, slaveNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 1.0, vectorSize, sideInfoNum) 379 | // 正样本计算损失值 380 | if(isNoShowLoss){ 381 | val keySlaveNode = slaveNode + "#Zu" 382 | val valueSlaveNode = embeddingHash(keySlaveNode) 383 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode) 384 | val logarithmRaw = math.log(1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0)))) 385 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0) 386 | lossSum = lossSum - logarithm 387 | lossNum = lossNum + 1L 388 | } 389 | // 负样本采样和训练 390 | for(_ <- 0 until negativeSampleNum){ 391 | // 负样本采样 392 | var sampleNodeInfo = slaveNodeInfo 393 | var sampleNode = slaveNode 394 | var sampleNum = 0 395 | while(mainSlaveNodeSet.contains(sampleNode) && sampleNum < perSampleMaxNum) { 396 | val index = sampleTable((new Random).nextInt(sampleTableNum)) 397 | sampleNodeInfo = nodeSideTable(index)._1.split("@") 398 | sampleNode = sampleNodeInfo.head 399 | sampleNum = sampleNum + 1 400 | } 401 | // 负样本训练 402 | embeddingUpdate(mainNodeInfo, sampleNodeInfo, embeddingHash, gradientUpdateHash, valueEA, eavWvAggSum, eavSum, Hv, 0.0, vectorSize, sideInfoNum) 403 | // 负样本计算损失值 404 | if(isNoShowLoss){ 405 | val keySampleNode = sampleNodeInfo.head + "#Zu" 406 | val valueSampleNode = embeddingHash(keySampleNode) 407 | val HvZu = Hv.t * new DenseVector[Double](valueSampleNode) 408 | val logarithmRaw = math.log(1.0 - 1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0)))) 409 | val logarithm = math.max(math.min(logarithmRaw, 0.0), -20.0) 410 | lossSum = lossSum - logarithm 411 | lossNum = lossNum + 1L 412 | } 413 | } 414 | } 415 | } 416 | } 417 | } 418 | if(isNoShowLoss){ 419 | gradientUpdateHash.put("LossValue", (Array(lossSum), lossNum)) 420 | } 421 | gradientUpdateHash.toIterator 422 | } 423 | 424 | /** 425 | * 梯度更新 426 | * 427 | * */ 428 | def embeddingUpdate(mainNodeInfo: Array[String], 429 | slaveNodeInfo: Array[String], 430 | embeddingHash: mutable.HashMap[String, Array[Double]], 431 | gradientUpdateHash: mutable.HashMap[String, (Array[Double], Long)], 432 | mainNodeWeightRaw: Array[Double], 433 | eavWvAggSum: DenseVector[Double], 434 | eavSum: Double, 435 | Hv: DenseVector[Double], 436 | label: Double, 437 | vectorSize: Int, 438 | sideInfoNum: Int): Unit = { 439 | val keySlaveNode = slaveNodeInfo.head + "#Zu" 440 | val valueSlaveNode = embeddingHash(keySlaveNode) 441 | val HvZu = Hv.t * new DenseVector[Double](valueSlaveNode) 442 | // 更新从节点Zu的梯度 443 | val gradientZu = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * Hv 444 | var gradientZuSum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._1 445 | gradientZuSum = (new DenseVector[Double](gradientZuSum) + gradientZu).toArray 446 | var gradientZuNum = gradientUpdateHash.getOrElseUpdate(keySlaveNode, (Array.fill[Double](vectorSize)(0.0), 0L))._2 447 | gradientZuNum = gradientZuNum + 1L 448 | gradientUpdateHash.put(keySlaveNode, 
(gradientZuSum, gradientZuNum)) 449 | // 更新主节点边权重的梯度和边信息的梯度 450 | val gradientHv = (1.0 / (1.0 + math.exp(math.max(math.min(-HvZu, 20.0), -20.0))) - label) * new DenseVector[Double](valueSlaveNode) 451 | val mainNodeWeightGradientSum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._1 452 | var mainNodeWeightGradientNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo.head + "#Weight", (Array.fill[Double](sideInfoNum)(0.0), 0L))._2 453 | for(m <- 0 until sideInfoNum) { 454 | val wsvVectorRaw = new DenseVector[Double](embeddingHash(mainNodeInfo(m))) 455 | // 更新主节点边权重的梯度 456 | val gradientAsv = gradientHv.t * (eavSum * math.exp(mainNodeWeightRaw(m)) * wsvVectorRaw - math.exp(mainNodeWeightRaw(m)) * eavWvAggSum)/(eavSum * eavSum) 457 | mainNodeWeightGradientSum(m) = mainNodeWeightGradientSum(m) + gradientAsv 458 | // 更新主节点边信息的梯度 459 | var gradientWsvSum = new DenseVector(gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._1) 460 | var gradientWsvNum = gradientUpdateHash.getOrElseUpdate(mainNodeInfo(m), (Array.fill[Double](vectorSize)(0.0), 0L))._2 461 | gradientWsvSum = gradientWsvSum + math.exp(mainNodeWeightRaw(m)) * gradientHv / eavSum 462 | gradientWsvNum = gradientWsvNum + 1L 463 | gradientUpdateHash.put(mainNodeInfo(m), (gradientWsvSum.toArray, gradientWsvNum)) 464 | } 465 | mainNodeWeightGradientNum = mainNodeWeightGradientNum + 1L 466 | gradientUpdateHash.put(mainNodeInfo.head + "#Weight", (mainNodeWeightGradientSum, mainNodeWeightGradientNum)) 467 | } 468 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/eges/random/sample/Alias.scala: -------------------------------------------------------------------------------- 1 | package eges.random.sample 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | 5 | object Alias { 6 | 7 | /** 8 | * 根据节点和节点的权重创建Alias抽样方法 9 | * 10 | * */ 11 | def setupAlias(nodeWeights: Array[(Long, Double)]): (Array[Int], Array[Double]) = { 12 | val K = nodeWeights.length 13 | val J = Array.fill(K)(0) 14 | val q = Array.fill(K)(0.0) 15 | 16 | val smaller = new ArrayBuffer[Int]() 17 | val larger = new ArrayBuffer[Int]() 18 | 19 | val sum = nodeWeights.map(_._2).sum 20 | nodeWeights.zipWithIndex.foreach { case ((nodeId, weight), i) => 21 | q(i) = K * weight / sum 22 | if (q(i) < 1.0) { 23 | smaller.append(i) 24 | } else { 25 | larger.append(i) 26 | } 27 | } 28 | 29 | while (smaller.nonEmpty && larger.nonEmpty) { 30 | val small = smaller.remove(smaller.length - 1) 31 | val large = larger.remove(larger.length - 1) 32 | 33 | J(small) = large 34 | q(large) = q(large) + q(small) - 1.0 35 | if (q(large) < 1.0) smaller.append(large) 36 | else larger.append(large) 37 | } 38 | 39 | (J, q) 40 | } 41 | 42 | /** 43 | * 44 | * 根据源节点、源节点的邻居和目的节点的邻居来创建边的Alias抽样方法 45 | * 46 | * */ 47 | def setupEdgeAlias(p: Double = 1.0, q: Double = 1.0)(srcId: Long, srcNeighbors: Array[(Long, Double)], dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = { 48 | val neighbors_ = dstNeighbors.map { case (dstNeighborId, weight) => 49 | var unnormProb = weight / q 50 | if (srcId == dstNeighborId) unnormProb = weight / p 51 | else if (srcNeighbors.exists(_._1 == dstNeighborId)) unnormProb = weight 52 | 53 | (dstNeighborId, unnormProb) 54 | } 55 | 56 | setupAlias(neighbors_) 57 | } 58 | 59 | /** 60 | * 用Alias进行抽样 61 | * 62 | * */ 63 | def drawAlias(J: Array[Int], q: Array[Double]): Int = { 
64 | val K = J.length 65 | val kk = math.floor(math.random * K).toInt 66 | 67 | if (math.random < q(kk)) kk 68 | else J(kk) 69 | } 70 | 71 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/eges/random/walk/Attributes.scala: -------------------------------------------------------------------------------- 1 | package eges.random.walk 2 | 3 | import java.io.Serializable 4 | 5 | object Attributes { 6 | 7 | /** 8 | * 节点属性 9 | * @param neighbors 节点对应的邻居节点和权重 10 | * @param path 以该节点开头的路径的前两个节点 11 | * 12 | * */ 13 | case class NodeAttr(var neighbors: Array[(Long, Double)] = Array.empty[(Long, Double)], 14 | var path: Array[Array[Long]] = Array.empty[Array[Long]]) extends Serializable 15 | 16 | /** 17 | * 边属性 18 | * @param dstNeighbors 目的节点对应的邻居节点 19 | * @param J Alias抽样方法返回值 20 | * @param q Alias抽样方法返回值 21 | * 22 | * */ 23 | case class EdgeAttr(var dstNeighbors: Array[Long] = Array.empty[Long], 24 | var J: Array[Int] = Array.empty[Int], 25 | var q: Array[Double] = Array.empty[Double]) extends Serializable 26 | 27 | } 28 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/eges/random/walk/RandomWalk.scala: -------------------------------------------------------------------------------- 1 | package eges.random.walk 2 | 3 | import scala.util.Try 4 | import java.util.Random 5 | import scala.collection.mutable 6 | import org.apache.spark.rdd.RDD 7 | import eges.random.sample.Alias 8 | import org.apache.spark.graphx._ 9 | import eges.random.walk.Attributes._ 10 | import org.apache.spark.SparkContext 11 | import scala.collection.mutable.ArrayBuffer 12 | import org.apache.spark.broadcast.Broadcast 13 | 14 | class RandomWalk extends Serializable { 15 | 16 | var nodeIterationNum = 10 17 | var nodeWalkLength = 78 18 | var returnParameter = 1.0 19 | var inOutParameter = 1.0 20 | var nodeMaxDegree = 200 21 | var numPartitions = 500 22 | var nodeIndex: Broadcast[mutable.HashMap[String, Long]] = _ 23 | var indexNode: Broadcast[mutable.HashMap[Long, String]] = _ 24 | var indexedNodes: RDD[(VertexId, NodeAttr)] = _ 25 | var indexedEdges: RDD[Edge[EdgeAttr]] = _ 26 | var graph: Graph[NodeAttr, EdgeAttr] = _ 27 | var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = _ 28 | 29 | /** 30 | * 每个节点产生路径的数量 31 | * 32 | * */ 33 | def setNodeIterationNum(value:Int):this.type = { 34 | require(value > 0, s"nodeIterationNum must be more than 0, but it is $value") 35 | nodeIterationNum = value 36 | this 37 | } 38 | 39 | /** 40 | * 每个路径的最大长度 41 | * 42 | * */ 43 | def setNodeWalkLength(value:Int):this.type = { 44 | require(value >= 2, s"nodeWalkLength must be not less than 2, but it is $value") 45 | nodeWalkLength = value - 2 46 | this 47 | } 48 | 49 | /** 50 | * 往回走的参数 51 | * 52 | * */ 53 | def setReturnParameter(value:Double):this.type = { 54 | require(value > 0, s"returnParameter must be more than 0, but it is $value") 55 | returnParameter = value 56 | this 57 | } 58 | 59 | /** 60 | * 往外走的参数 61 | * 62 | * */ 63 | def setInOutParameter(value:Double):this.type = { 64 | require(value > 0, s"inOutParameter must be more than 0, but it is $value") 65 | inOutParameter = value 66 | this 67 | } 68 | 69 | /** 70 | * 每个节点最多的邻居数量 71 | * 72 | * */ 73 | def setNodeMaxDegree(value:Int):this.type = { 74 | require(value > 0, s"nodeMaxDegree must be more than 0, but it is $value") 75 | nodeMaxDegree = value 76 | this 77 | } 78 | 79 | /** 80 | * 分区数量, 并行度 81 | * 82 | 
* */ 83 | def setNumPartitions(value:Int):this.type = { 84 | require(value > 0, s"numPartitions must be more than 0, but it is $value") 85 | numPartitions = value 86 | this 87 | } 88 | 89 | /** 90 | * 每个节点游走的路径 91 | * 92 | * */ 93 | def fit(node2Weight: RDD[(String, String, Double)]): RDD[(String, String)] = { 94 | 95 | // 广播部分变量 96 | val sc = node2Weight.context 97 | val nodeIterationNumBroadcast = sc.broadcast(nodeIterationNum) 98 | val returnParameterBroadcast = sc.broadcast(returnParameter) 99 | val inOutParameterBroadcast = sc.broadcast(inOutParameter) 100 | val nodeMaxDegreeBroadcast = sc.broadcast(nodeMaxDegree) 101 | 102 | // 将字符串节点与节点的权重转化为长整型与长整型的权重 103 | val inputTriplets = processData(node2Weight, sc) 104 | inputTriplets.cache() 105 | inputTriplets.first() // 行动操作 106 | 107 | // 节点对应的节点属性 108 | indexedNodes = inputTriplets.map{ case (node1, node2, weight) => (node1, (node2, weight)) } 109 | .combineByKey(dstIdWeight =>{ 110 | implicit object ord extends Ordering[(Long, Double)]{ 111 | override def compare(p1:(Long, Double), p2:(Long, Double)):Int = { 112 | p2._2.compareTo(p1._2) 113 | } 114 | } 115 | val priorityQueue = new mutable.PriorityQueue[(Long, Double)]() 116 | priorityQueue.enqueue(dstIdWeight) 117 | priorityQueue 118 | },(priorityQueue:mutable.PriorityQueue[(Long,Double)], dstIdWeight)=>{ 119 | if(priorityQueue.size < nodeMaxDegreeBroadcast.value){ 120 | priorityQueue.enqueue(dstIdWeight) 121 | }else{ 122 | if(priorityQueue.head._2 < dstIdWeight._2){ 123 | priorityQueue.dequeue() 124 | priorityQueue.enqueue(dstIdWeight) 125 | } 126 | } 127 | priorityQueue 128 | },(priorityQueuePre:mutable.PriorityQueue[(Long,Double)], 129 | priorityQueueLast:mutable.PriorityQueue[(Long,Double)])=>{ 130 | while(priorityQueueLast.nonEmpty){ 131 | val dstIdWeight = priorityQueueLast.dequeue() 132 | if(priorityQueuePre.size < nodeMaxDegreeBroadcast.value){ 133 | priorityQueuePre.enqueue(dstIdWeight) 134 | }else{ 135 | if(priorityQueuePre.head._2 < dstIdWeight._2){ 136 | priorityQueuePre.dequeue() 137 | priorityQueuePre.enqueue(dstIdWeight) 138 | } 139 | } 140 | } 141 | priorityQueuePre 142 | }).map{ case (srcId, dstIdWeightPriorityQueue) => 143 | (srcId, NodeAttr(neighbors = dstIdWeightPriorityQueue.toArray)) 144 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2) 145 | indexedNodes.cache() 146 | indexedNodes.first() // 行动操作 147 | inputTriplets.unpersist() 148 | 149 | //边的属性 150 | indexedEdges = indexedNodes.flatMap { case (srcId, srcIdAttr) => 151 | srcIdAttr.neighbors.map { case (dstId, weight) => 152 | Edge(srcId, dstId, EdgeAttr()) 153 | } 154 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2) 155 | indexedEdges.cache() 156 | indexedEdges.first() // 行动操作 157 | 158 | // 通过节点属性和边的属性构建图 159 | graph = Graph(indexedNodes, indexedEdges) 160 | .mapVertices[NodeAttr]{ case (vertexId, vertexAttr) => 161 | try{ 162 | val (j, q) = Alias.setupAlias(vertexAttr.neighbors) 163 | val pathArray = ArrayBuffer[Array[Long]]() 164 | for(_ <- 0 until nodeIterationNumBroadcast.value){ 165 | val nextNodeIndex = Alias.drawAlias(j, q) 166 | val path = Array(vertexId, vertexAttr.neighbors(nextNodeIndex)._1) 167 | pathArray.append(path) 168 | } 169 | vertexAttr.path = pathArray.toArray 170 | vertexAttr 171 | } catch { 172 | case _:Exception => Attributes.NodeAttr() 173 | } 174 | }.mapTriplets{ edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] => 175 | val dstAttrNeighbors = 
Try(edgeTriplet.dstAttr.neighbors).getOrElse(Array.empty[(Long, Double)]) 176 | val (j, q) = Alias.setupEdgeAlias(returnParameterBroadcast.value, inOutParameterBroadcast.value)(edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, dstAttrNeighbors) 177 | edgeTriplet.attr.J = j 178 | edgeTriplet.attr.q = q 179 | edgeTriplet.attr.dstNeighbors = dstAttrNeighbors.map(_._1) 180 | edgeTriplet.attr 181 | } 182 | graph.cache() 183 | 184 | // 所有边的源节点、目的节点和边的属性 185 | val edgeAttr = graph.triplets.map{ edgeTriplet => 186 | (s"${edgeTriplet.srcId}${edgeTriplet.dstId}", edgeTriplet.attr) 187 | }.map(k => ((new Random).nextInt(numPartitions), k)).repartition(numPartitions).map(k => k._2) 188 | edgeAttr.cache() 189 | edgeAttr.first() // 行动操作 190 | indexedNodes.unpersist() 191 | indexedEdges.unpersist() 192 | 193 | // 随机游走产生序列数据 194 | for (iter <- 0 until nodeIterationNum) { 195 | var randomWalk = graph.vertices.filter{ case (vertexId, vertexAttr) => 196 | val vertexNeighborsLength = Try(vertexAttr.neighbors.length).getOrElse(0) 197 | vertexNeighborsLength > 0 198 | }.map { case (vertexId, vertexAttr) => 199 | val pathBuffer = new ArrayBuffer[Long]() 200 | pathBuffer.append(vertexAttr.path(iter):_*) 201 | (vertexId, pathBuffer) 202 | } 203 | randomWalk.cache() 204 | randomWalk.first() // 行动操作 205 | 206 | for (_ <- 0 until nodeWalkLength) { 207 | randomWalk = randomWalk.map { case (srcNodeId, pathBuffer) => 208 | val prevNodeId = pathBuffer(pathBuffer.length - 2) 209 | val currentNodeId = pathBuffer.last 210 | 211 | (s"$prevNodeId$currentNodeId", (srcNodeId, pathBuffer)) 212 | }.join(edgeAttr).map { case (edge, ((srcNodeId, pathBuffer), attr)) => 213 | try { 214 | val nextNodeIndex = Alias.drawAlias(attr.J, attr.q) 215 | val nextNodeId = attr.dstNeighbors(nextNodeIndex) 216 | pathBuffer.append(nextNodeId) 217 | 218 | (srcNodeId, pathBuffer) 219 | } catch { 220 | case _: Exception => (srcNodeId, pathBuffer) 221 | } 222 | } 223 | randomWalk.cache() 224 | randomWalk.first() // 行动操作 225 | } 226 | 227 | if (randomWalkPaths != null) { 228 | randomWalkPaths = randomWalkPaths.union(randomWalk) 229 | } else { 230 | randomWalkPaths = randomWalk 231 | } 232 | randomWalkPaths.cache() 233 | randomWalkPaths.first() // 行动操作 234 | randomWalk.unpersist() 235 | } 236 | graph.unpersist() 237 | edgeAttr.unpersist() 238 | 239 | // 将长整型节点与路径转化为字符串节点与路径, 路径用","分开 240 | val paths = longToString(randomWalkPaths) 241 | paths.cache() 242 | paths.first() // 行动操作 243 | randomWalkPaths.unpersist() 244 | paths 245 | 246 | } 247 | 248 | /** 249 | * 将字符串节点与节点的权重转化为长整型与长整型的权重 250 | * 251 | * */ 252 | def processData(node2Weight: RDD[(String, String, Double)], sc:SparkContext): RDD[(Long, Long, Double)] = { 253 | val nodeIndexArray = node2Weight.flatMap(node => Array(node._1, node._2)) 254 | .distinct().zipWithIndex().collect() 255 | val indexNodeArray = nodeIndexArray.map{case (node, index) => (index, node)} 256 | indexNode = sc.broadcast(mutable.HashMap(indexNodeArray:_*)) 257 | nodeIndex = sc.broadcast(mutable.HashMap(nodeIndexArray:_*)) 258 | val inputTriplets = node2Weight.map{ case (node1, node2, weight) => 259 | val node1Index = nodeIndex.value(node1) 260 | val node2Index = nodeIndex.value(node2) 261 | (node1Index, node2Index, weight) 262 | } 263 | inputTriplets 264 | } 265 | 266 | /** 267 | * 将长整型节点与路径转化为字符串节点与路径, 路径用","分开 268 | * 269 | * */ 270 | def longToString(srcNodeIdPaths: RDD[(Long, ArrayBuffer[Long])]): RDD[(String, String)] = { 271 | val pathsResult = srcNodeIdPaths.map{case (srcNodeId, paths) => 272 | val srcNodeIdString = 
indexNode.value(srcNodeId) 273 | val pathsString = paths.map(index => indexNode.value(index)).mkString(",") 274 | (srcNodeIdString, pathsString) 275 | } 276 | pathsResult 277 | } 278 | 279 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example1.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import scala.collection.mutable 4 | import eges.random.walk.RandomWalk 5 | import sparkapplication.BaseSparkLocal 6 | 7 | object Example1 extends BaseSparkLocal { 8 | def main(args:Array[String]):Unit = { 9 | val spark = this.basicSpark 10 | import spark.implicits._ 11 | 12 | val list1 = List((("gds1", "group1"), ("gds2", 0.1)), (("gds1", "group1"), ("gds3", 0.2)), (("gds1", "group1"), ("gds4", 0.3)), 13 | (("gds2", "group2"), ("gds3", 0.1)), (("gds2", "group2"), ("gds4", 0.3)), (("gds2", "group2"), ("gds5", 0.6)), 14 | (("gds3", "group4"), ("gds2", 0.1)), (("gds3", "group4"), ("gds5", 0.3)), (("gds3", "group4"), ("gds6", 0.5))) 15 | val list2 = spark.sparkContext.parallelize(list1) 16 | val list3 = list2.combineByKey(gds2WithSimilarity=>{ 17 | implicit object ord extends Ordering[(String,Double)] { 18 | override def compare(p1: (String,Double), p2: (String,Double)): Int = { 19 | p2._2.compareTo(p1._2) 20 | } 21 | } 22 | 23 | val priorityQueue = new mutable.PriorityQueue[(String,Double)]() 24 | priorityQueue.enqueue(gds2WithSimilarity) 25 | priorityQueue 26 | },(priorityQueue:mutable.PriorityQueue[(String,Double)],gds2WithSimilarity)=>{ 27 | 28 | if(priorityQueue.size < 2){ 29 | priorityQueue.enqueue(gds2WithSimilarity) 30 | }else{ 31 | if(priorityQueue.head._2 < gds2WithSimilarity._2){ 32 | priorityQueue.dequeue() 33 | priorityQueue.enqueue(gds2WithSimilarity) 34 | } 35 | } 36 | 37 | priorityQueue 38 | },(priorityQueuePre:mutable.PriorityQueue[(String,Double)], 39 | priorityQueueLast:mutable.PriorityQueue[(String,Double)])=>{ 40 | while(!priorityQueueLast.isEmpty){ 41 | 42 | val gds2WithSimilarity = priorityQueueLast.dequeue() 43 | if(priorityQueuePre.size < 2){ 44 | priorityQueuePre.enqueue(gds2WithSimilarity) 45 | }else{ 46 | if(priorityQueuePre.head._2 < gds2WithSimilarity._2){ 47 | priorityQueuePre.dequeue() 48 | priorityQueuePre.enqueue(gds2WithSimilarity) 49 | } 50 | } 51 | } 52 | 53 | priorityQueuePre}).flatMap({case ((gdsCd1, l4GroupCd), gdsCd2AndSimilarityPriorityQueue) => { 54 | 55 | gdsCd2AndSimilarityPriorityQueue.toList.map({case(gdsCd2, similarity) => (gdsCd1, gdsCd2, l4GroupCd, similarity)}) 56 | 57 | }}) 58 | 59 | list3.foreach(println) 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | } 74 | 75 | } 76 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example10.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import breeze.linalg.DenseVector 4 | import sparkapplication.BaseSparkLocal 5 | 6 | object Example10 extends BaseSparkLocal { 7 | def main(args:Array[String]):Unit = { 8 | val spark = this.basicSpark 9 | import spark.implicits._ 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | } 56 | 57 | } 58 | 
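A minimal, standalone sketch (plain Scala, no Spark; the object name TopKNeighbors and the sample weights are made up for illustration) of the bounded top-K pattern used by RandomWalk.fit and Example1 above, and again by Example2 below: the Ordering on (id, weight) is reversed so that mutable.PriorityQueue behaves as a min-heap, its head is always the currently smallest weight, and that element is the one evicted once the queue already holds the allowed number of neighbours.

import scala.collection.mutable

object TopKNeighbors {
  // Keep only the k highest-weight neighbours, mirroring the combineByKey logic in the files above.
  def topK(neighbors: Seq[(Long, Double)], k: Int): Array[(Long, Double)] = {
    // Reversed ordering: the smallest weight sits at `head`, so PriorityQueue acts as a min-heap.
    implicit val minHeap: Ordering[(Long, Double)] = Ordering.by[(Long, Double), Double](_._2).reverse
    val queue = mutable.PriorityQueue.empty[(Long, Double)]
    neighbors.foreach { candidate =>
      if (queue.size < k) queue.enqueue(candidate)
      else if (queue.head._2 < candidate._2) { queue.dequeue(); queue.enqueue(candidate) }
    }
    queue.dequeueAll.reverse.toArray // the min-heap dequeues ascending; reverse for descending weight
  }

  def main(args: Array[String]): Unit =
    // Prints the two heaviest neighbours: (2,3.2), (3,2.2)
    println(topK(Seq((2L, 3.2), (3L, 2.2), (4L, 1.2), (5L, 0.2)), 2).mkString(", "))
}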
-------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example11.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import java.util.Random 4 | import breeze.linalg.DenseVector 5 | import sparkapplication.BaseSparkLocal 6 | 7 | object Example11 extends BaseSparkLocal { 8 | def main(args:Array[String]):Unit = { 9 | // val spark = this.basicSpark 10 | // import spark.implicits._ 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | } 46 | 47 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example2.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import breeze.linalg.operators 4 | import breeze.linalg.{DenseVector, SparseVector, Vector} 5 | import breeze.numerics._ 6 | import java.util.{Date, Random} 7 | import sparkapplication.BaseSparkLocal 8 | import org.apache.spark.sql.functions._ 9 | import scala.collection.mutable 10 | 11 | object Example2 extends BaseSparkLocal { 12 | def main(args:Array[String]):Unit = { 13 | val spark = this.basicSpark 14 | import spark.implicits._ 15 | 16 | val nodeMaxDegree = 2 17 | val list1 = List((1:Long, (2:Long, 3.2:Double)), (1:Long, (3:Long, 2.2:Double)), (1:Long, (4:Long, 1.2:Double)), 18 | (1:Long, (5:Long, 0.2:Double)), (6:Long, (7:Long, 6.2:Double)), (6:Long, (8:Long, 8.2:Double))) 19 | val list2 = spark.sparkContext.parallelize(list1) 20 | val list3 = list2.combineByKey(dstIdWeight =>{ 21 | implicit object ord extends Ordering[(Long, Double)]{ 22 | override def compare(p1:(Long, Double), p2:(Long, Double)):Int = { 23 | p2._2.compareTo(p1._2) 24 | } 25 | } 26 | val priorityQueue = new mutable.PriorityQueue[(Long, Double)]() 27 | priorityQueue.enqueue(dstIdWeight) 28 | priorityQueue 29 | },(priorityQueue:mutable.PriorityQueue[(Long,Double)], dstIdWeight)=>{ 30 | if(priorityQueue.size < nodeMaxDegree){ 31 | priorityQueue.enqueue(dstIdWeight) 32 | }else{ 33 | if(priorityQueue.head._2 < dstIdWeight._2){ 34 | priorityQueue.dequeue() 35 | priorityQueue.enqueue(dstIdWeight) 36 | } 37 | } 38 | priorityQueue 39 | },(priorityQueuePre:mutable.PriorityQueue[(Long,Double)], 40 | priorityQueueLast:mutable.PriorityQueue[(Long,Double)])=>{ 41 | while(priorityQueueLast.nonEmpty){ 42 | val dstIdWeight = priorityQueueLast.dequeue() 43 | if(priorityQueuePre.size < nodeMaxDegree){ 44 | priorityQueuePre.enqueue(dstIdWeight) 45 | }else{ 46 | if(priorityQueuePre.head._2 < dstIdWeight._2){ 47 | priorityQueuePre.dequeue() 48 | priorityQueuePre.enqueue(dstIdWeight) 49 | } 50 | } 51 | } 52 | priorityQueuePre 53 | }).map{case (srcId, dstIdWeightPriorityQueue) => (srcId, dstIdWeightPriorityQueue.toArray.sortBy(-_._2).mkString("@"))} 54 | 55 | 56 | 57 | 58 | 59 | list3.foreach(println) 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | } 119 | } -------------------------------------------------------------------------------- 
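Before Example3 exercises the walker with setReturnParameter (p) and setInOutParameter (q), here is a hedged sketch of the second-order bias that Alias.setupEdgeAlias presumably applies for an edge src -> dst, followed by the standard Walker alias tables that setupAlias/drawAlias correspond to. Alias.scala itself is not reproduced in this dump, so the code below is illustrative rather than the project's exact implementation: stepping from dst back to src is scaled by 1/p, staying inside src's neighbourhood keeps the raw edge weight, and moving further away is scaled by 1/q.

import java.util.Random
import scala.collection.mutable

object Node2vecAliasSketch {
  // Biased transition weights out of dst, given that the walk arrived there from srcId.
  def biasedWeights(p: Double, q: Double,
                    srcId: Long,
                    srcNeighbors: Array[(Long, Double)],
                    dstNeighbors: Array[(Long, Double)]): Array[Double] = {
    val srcNeighborIds = srcNeighbors.map(_._1).toSet
    dstNeighbors.map { case (nextId, weight) =>
      if (nextId == srcId) weight / p                  // return to the previous node
      else if (srcNeighborIds.contains(nextId)) weight // stays at distance 1 from srcId
      else weight / q                                  // moves to distance 2 from srcId
    }
  }

  // Walker's alias method over unnormalised weights: O(n) setup, O(1) draw per step.
  def setupAlias(weights: Array[Double]): (Array[Int], Array[Double]) = {
    val n = weights.length
    val total = weights.sum
    val q = weights.map(_ * n / total)
    val j = Array.fill(n)(0)
    val small = mutable.Queue[Int]()
    val large = mutable.Queue[Int]()
    (0 until n).foreach(i => if (q(i) < 1.0) small.enqueue(i) else large.enqueue(i))
    while (small.nonEmpty && large.nonEmpty) {
      val s = small.dequeue(); val l = large.dequeue()
      j(s) = l
      q(l) = q(l) + q(s) - 1.0
      if (q(l) < 1.0) small.enqueue(l) else large.enqueue(l)
    }
    (j, q)
  }

  def drawAlias(j: Array[Int], q: Array[Double], rnd: Random): Int = {
    val bucket = rnd.nextInt(j.length)
    if (rnd.nextDouble() < q(bucket)) bucket else j(bucket)
  }
}

Each walk step then costs a constant-time drawAlias call instead of a linear scan over the weights, which is what keeps long walks over a large graph affordable.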
/enhanced_graph_embedding_side_information/src/main/scala/example/Example3.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import eges.random.walk.RandomWalk 4 | import sparkapplication.BaseSparkLocal 5 | 6 | object Example3 extends BaseSparkLocal { 7 | def main(args:Array[String]):Unit = { 8 | val spark = this.basicSpark 9 | import spark.implicits._ 10 | 11 | val data = List(("y", "a", 0.8), ("y", "b", 0.6), ("y", "c", 0.4), ("q", "b", 0.8), ("q", "c", 0.6), ("q", "d", 0.4), 12 | ("s", "a", 0.8), ("s", "c", 0.6), ("s", "d", 0.4), ("a", "e", 0.5), ("b", "f", 0.5), ("c", "e", 0.5), ("c", "f", 0.5)) 13 | val dataRDD = spark.sparkContext.parallelize(data) 14 | val randomWalk = new RandomWalk() 15 | .setNodeIterationNum(5) 16 | .setNodeWalkLength(6) 17 | .setReturnParameter(1.0) 18 | .setInOutParameter(1.0) 19 | .setNodeMaxDegree(200) 20 | .setNumPartitions(2) 21 | 22 | val paths = randomWalk.fit(dataRDD) 23 | paths.cache() 24 | println("总路径数: ", paths.count()) 25 | println(paths.map(k => k._2).collect().mkString("#")) 26 | paths.unpersist() 27 | 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example4.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import java.util.Random 4 | import eges.embedding.WeightedSkipGram 5 | import org.apache.spark.sql.functions._ 6 | import sparkapplication.BaseSparkLocal 7 | import breeze.linalg.{norm, DenseVector} 8 | import scala.collection.mutable 9 | import scala.collection.mutable.ArrayBuffer 10 | 11 | object Example4 extends BaseSparkLocal { 12 | def main(args:Array[String]):Unit = { 13 | val spark = this.basicSpark 14 | import spark.implicits._ 15 | 16 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2") 17 | val dataRDD = spark.sparkContext.parallelize(dataList) 18 | .map(k => k.split(",").map(v => v.split("@"))) 19 | val weightedSkipGram = new WeightedSkipGram() 20 | .setVectorSize(4) 21 | .setWindowSize(1) 22 | .setNegativeSampleNum(1) 23 | .setPerSampleMaxNum(100) 24 | // .setSubSample(0.1) 25 | .setSubSample(0.0) 26 | .setLearningRate(0.25) 27 | .setIterationNum(6) 28 | .setIsNoShowLoss(true) 29 | .setNumPartitions(2) 30 | .setSampleTableNum(200) 31 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap() 32 | println("nodeEmbedding结果如下:") 33 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println) 34 | 35 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") 36 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight") 37 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2") 38 | val i2iSimilar = ArrayBuffer[(String, String, Double)]() 39 | for(mainNodeInfo <- nodeAgg){ 40 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 41 | for(slaveNodeInfo <- nodeAgg){ 42 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo) 43 | if(!mainNodeInfo.equals(slaveNodeInfo)) { 44 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) ) 45 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar)) 46 | } 47 | } 48 | } 
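    // Aside (added illustration, not part of the original Example4): a worked instance of the
    // cosine similarity computed in the loop above (dot product divided by the product of the
    // L2 norms), using made-up vectors; norm and DenseVector are already imported in this file.
    val asideU = DenseVector(1.0, 0.0, 1.0)
    val asideV = DenseVector(0.5, 0.5, 0.5)
    val asideCosine = (asideU.t * asideV) / ( norm(asideU, 2) * norm(asideV, 2) )
    println("cosine(asideU, asideV) = " + asideCosine) // 1.0 / (1.4142 * 0.8660), about 0.8165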
49 | println("i2iSimilarHv结果如下:") 50 | i2iSimilar.foreach(println) 51 | 52 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]() 53 | for(mainNodeInfo <- nodeAgg){ 54 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 55 | for(slaveNode <- node){ 56 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu") 57 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) { 58 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector 59 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0)) 60 | val probability = 1.0 / (1.0 + eHx) 61 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability)) 62 | } 63 | } 64 | } 65 | println("i2iSimilarSequence结果如下:") 66 | i2iSimilarSequence.foreach(println) 67 | 68 | } 69 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example5.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import java.util.Random 4 | import breeze.linalg.DenseVector 5 | import eges.embedding.WeightedSkipGramBatch 6 | import org.apache.spark.sql.functions._ 7 | import sparkapplication.BaseSparkLocal 8 | import scala.collection.mutable 9 | import breeze.linalg.{norm, DenseVector} 10 | import scala.collection.mutable.ArrayBuffer 11 | 12 | object Example5 extends BaseSparkLocal { 13 | def main(args:Array[String]):Unit = { 14 | val spark = this.basicSpark 15 | import spark.implicits._ 16 | 17 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2") 18 | val dataRDD = spark.sparkContext.parallelize(dataList) 19 | .map(k => k.split(",").map(v => v.split("@"))) 20 | val weightedSkipGram = new WeightedSkipGramBatch() 21 | .setVectorSize(4) 22 | .setWindowSize(1) 23 | .setNegativeSampleNum(1) 24 | .setPerSampleMaxNum(100) 25 | // .setSubSample(0.1) 26 | .setSubSample(0.0) 27 | .setLearningRate(0.5) 28 | .setBatchNum(2) 29 | .setIterationNum(6) 30 | .setIsNoShowLoss(true) 31 | .setNumPartitions(2) 32 | .setSampleTableNum(200) 33 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap() 34 | println("nodeEmbedding结果如下:") 35 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println) 36 | 37 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") 38 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight") 39 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2") 40 | val i2iSimilar = ArrayBuffer[(String, String, Double)]() 41 | for(mainNodeInfo <- nodeAgg){ 42 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 43 | for(slaveNodeInfo <- nodeAgg){ 44 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo) 45 | if(!mainNodeInfo.equals(slaveNodeInfo)) { 46 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) ) 47 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar)) 48 | } 49 | } 50 | } 51 | println("i2iSimilarHv结果如下:") 52 | i2iSimilar.foreach(println) 53 | 54 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]() 55 | for(mainNodeInfo <- nodeAgg){ 56 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 57 | for(slaveNode <- node){ 58 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu") 59 | 
if(!mainNodeInfo.split("#").head.equals(slaveNode)) { 60 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector 61 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0)) 62 | val probability = 1.0 / (1.0 + eHx) 63 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability)) 64 | } 65 | } 66 | } 67 | println("i2iSimilarSequence结果如下:") 68 | i2iSimilarSequence.foreach(println) 69 | 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example6.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import java.util.Date 4 | import java.util.Random 5 | import breeze.linalg.DenseVector 6 | import utils.miscellaneous.DateUtils 7 | import org.apache.spark.sql.functions._ 8 | import org.apache.spark.util.random.XORShiftRandom 9 | import sparkapplication.BaseSparkLocal 10 | import breeze.linalg.DenseVector 11 | import scala.collection.mutable 12 | import scala.collection.mutable.ArrayBuffer 13 | 14 | object Example6 extends BaseSparkLocal { 15 | def main(args:Array[String]):Unit = { 16 | val spark = this.basicSpark 17 | import spark.implicits._ 18 | 19 | 20 | val a = "-0.30614856@0.16460776@0.19909532@-0.48152933@-0.7925998" 21 | val aVector = new DenseVector[Double](a.split("@").map(_.toDouble)) 22 | val b = "-0.7356615@0.6408239@-0.4380877@-1.920213@-2.628376" 23 | val bVector = new DenseVector[Double](b.split("@").map(_.toDouble)) 24 | val similar = aVector.t * bVector 25 | val eHx = math.exp(math.max(math.min(-similar, 35.0), -35.0)).toFloat 26 | val probability = 1.toFloat / (1.toFloat + eHx) 27 | println(probability) 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example7.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import breeze.linalg.{DenseVector, norm} 4 | import org.apache.spark.rdd.RDD 5 | import org.apache.spark.sql.expressions.Window 6 | import org.apache.spark.sql.functions._ 7 | import sparkapplication.BaseSparkLocal 8 | 9 | import scala.collection.mutable 10 | import scala.collection.mutable.ArrayBuffer 11 | 12 | object Example7 extends BaseSparkLocal { 13 | def main(args:Array[String]):Unit = { 14 | val spark = this.basicSpark 15 | import spark.implicits._ 16 | 17 | val list = List(("y", "2019"), ("q", "2018")) 18 | val data = spark.createDataFrame(list).toDF("member_id", "visit_time") 19 | .withColumn("index", row_number().over(Window.partitionBy("member_id").orderBy(desc("visit_time")))) 20 | .select("index") 21 | .rdd.map(k => k.getAs[Int]("index")) 22 | 23 | 24 | data.foreach(println) 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example8.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import scala.collection.mutable 4 | import sparkapplication.BaseSparkLocal 5 | 6 | object Example8 extends 
BaseSparkLocal { 7 | def main(args:Array[String]):Unit = { 8 | val spark = this.basicSpark 9 | import spark.implicits._ 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/example/Example9.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import java.util.Random 4 | 5 | import breeze.linalg.DenseVector 6 | import sparkapplication.BaseSparkLocal 7 | import org.apache.spark.util.Utils 8 | import org.apache.spark.internal.Logging 9 | import org.apache.spark.util.random.XORShiftRandom 10 | 11 | object Example9 extends BaseSparkLocal { 12 | def main(args:Array[String]):Unit = { 13 | val spark = this.basicSpark 14 | import spark.implicits._ 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | } 42 | 43 | 44 | } 45 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/main/Main1.scala: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import eges.random.walk.RandomWalk 4 | import sparkapplication.BaseSparkLocal 5 | 6 | object Main1 extends BaseSparkLocal { 7 | def main(args:Array[String]):Unit = { 8 | val spark = this.basicSpark 9 | 10 | val data = List(("y", "a", 0.8), ("y", "b", 0.6), ("y", "c", 0.4), ("q", "b", 0.8), ("q", "c", 0.6), ("q", "d", 0.4), 11 | ("s", "a", 0.8), ("s", "c", 0.6), ("s", "d", 0.4), ("a", "e", 0.5), ("b", "f", 0.5), ("c", "e", 0.5), ("c", "f", 0.5)) 12 | val dataRDD = spark.sparkContext.parallelize(data) 13 | val randomWalk = new RandomWalk() 14 | .setNodeIterationNum(5) 15 | .setNodeWalkLength(6) 16 | .setReturnParameter(1.0) 17 | .setInOutParameter(1.0) 18 | .setNodeMaxDegree(200) 19 | .setNumPartitions(2) 20 | 21 | val paths = randomWalk.fit(dataRDD) 22 | paths.cache() 23 | println("总路径数: ", paths.count()) 24 | println(paths.map(k => k._2).collect().mkString("#")) 25 | paths.unpersist() 26 | 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/main/Main2.scala: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import eges.embedding.WeightedSkipGram 4 | import sparkapplication.BaseSparkLocal 5 | import breeze.linalg.{norm, DenseVector} 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object Main2 extends BaseSparkLocal { 9 | def main(args:Array[String]):Unit = { 10 | val spark = this.basicSpark 11 | 12 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2") 13 | val dataRDD = spark.sparkContext.parallelize(dataList) 14 | .map(k => k.split(",").map(v => v.split("@"))) 15 | val weightedSkipGram = new WeightedSkipGram() 16 | .setVectorSize(4) 17 | .setWindowSize(1) 18 | .setNegativeSampleNum(1) 19 | .setPerSampleMaxNum(100) 20 | // .setSubSample(0.1) 21 | .setSubSample(0.0) 22 | .setLearningRate(0.25) 23 | .setIterationNum(6) 24 | .setIsNoShowLoss(true) 25 | .setNumPartitions(2) 26 | .setSampleTableNum(200) 27 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new 
DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap() 28 | println("nodeEmbedding结果如下:") 29 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println) 30 | 31 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") 32 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight") 33 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", "a#s2") 34 | val i2iSimilar = ArrayBuffer[(String, String, Double)]() 35 | for(mainNodeInfo <- nodeAgg){ 36 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 37 | for(slaveNodeInfo <- nodeAgg){ 38 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo) 39 | if(!mainNodeInfo.equals(slaveNodeInfo)) { 40 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) ) 41 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar)) 42 | } 43 | } 44 | } 45 | println("i2iSimilarHv结果如下:") 46 | i2iSimilar.foreach(println) 47 | 48 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]() 49 | for(mainNodeInfo <- nodeAgg){ 50 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 51 | for(slaveNode <- node){ 52 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu") 53 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) { 54 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector 55 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0)) 56 | val probability = 1.0 / (1.0 + eHx) 57 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability)) 58 | } 59 | } 60 | } 61 | println("i2iSimilarSequence结果如下:") 62 | i2iSimilarSequence.foreach(println) 63 | 64 | } 65 | } -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/main/Main3.scala: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import eges.embedding.WeightedSkipGramBatch 4 | import sparkapplication.BaseSparkLocal 5 | import breeze.linalg.{norm, DenseVector} 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object Main3 extends BaseSparkLocal { 9 | def main(args:Array[String]):Unit = { 10 | val spark = this.basicSpark 11 | 12 | val dataList = List("a@s1,b@s2,c@s3,d@s4,e@s5,f@s6,g@s7", "e@s5,f@s6,e@s5,a@s1,c@s3,h@s2,i@s6,j@s7,a@s2") 13 | val dataRDD = spark.sparkContext.parallelize(dataList) 14 | .map(k => k.split(",").map(v => v.split("@"))) 15 | val weightedSkipGram = new WeightedSkipGramBatch() 16 | .setVectorSize(4) 17 | .setWindowSize(1) 18 | .setNegativeSampleNum(1) 19 | .setPerSampleMaxNum(100) 20 | // .setSubSample(0.1) 21 | .setSubSample(0.0) 22 | .setLearningRate(0.5) 23 | .setBatchNum(2) 24 | .setIterationNum(6) 25 | .setIsNoShowLoss(true) 26 | .setNumPartitions(2) 27 | .setSampleTableNum(200) 28 | val nodeEmbedding = weightedSkipGram.fit(dataRDD).map(k => (k._1, new DenseVector[Double](k._2.split("@").map(_.toDouble)))).collectAsMap() 29 | println("nodeEmbedding结果如下:") 30 | nodeEmbedding.map(k => (k._1, k._2.toArray.mkString("@"))).foreach(println) 31 | 32 | val node = Array("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") 33 | // val nodeWeight = Array("a#Weight", "b#Weight", "c#Weight", "d#Weight", "e#Weight", "f#Weight", "g#Weight", "h#Weight", "i#Weight", "j#Weight") 34 | val nodeAgg = Array("a#s1", "b#s2", "c#s3", "d#s4", "e#s5", "f#s6", "g#s7", "h#s2", "i#s6", "j#s7", 
"a#s2") 35 | val i2iSimilar = ArrayBuffer[(String, String, Double)]() 36 | for(mainNodeInfo <- nodeAgg){ 37 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 38 | for(slaveNodeInfo <- nodeAgg){ 39 | val slaveNodeInfoVector = nodeEmbedding(slaveNodeInfo) 40 | if(!mainNodeInfo.equals(slaveNodeInfo)) { 41 | val cosineSimilar = mainNodeInfoVector.t * slaveNodeInfoVector / ( norm(mainNodeInfoVector, 2) * norm(slaveNodeInfoVector, 2) ) 42 | i2iSimilar.append((mainNodeInfo, slaveNodeInfo, cosineSimilar)) 43 | } 44 | } 45 | } 46 | println("i2iSimilarHv结果如下:") 47 | i2iSimilar.foreach(println) 48 | 49 | val i2iSimilarSequence = ArrayBuffer[(String, String, Double)]() 50 | for(mainNodeInfo <- nodeAgg){ 51 | val mainNodeInfoVector = nodeEmbedding(mainNodeInfo) 52 | for(slaveNode <- node){ 53 | val slaveNodeInfoVector = nodeEmbedding(slaveNode+"#Zu") 54 | if(!mainNodeInfo.split("#").head.equals(slaveNode)) { 55 | val similar = mainNodeInfoVector.t * slaveNodeInfoVector 56 | val eHx = math.exp(math.max(math.min(-similar, 20.0), -20.0)) 57 | val probability = 1.0 / (1.0 + eHx) 58 | i2iSimilarSequence.append((mainNodeInfo, slaveNode, probability)) 59 | } 60 | } 61 | } 62 | println("i2iSimilarSequence结果如下:") 63 | i2iSimilarSequence.foreach(println) 64 | 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/sparkapplication/BaseSparkLocal.scala: -------------------------------------------------------------------------------- 1 | package sparkapplication 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.sql.SparkSession 5 | 6 | trait BaseSparkLocal { 7 | //本地 8 | def basicSpark: SparkSession = 9 | SparkSession 10 | .builder 11 | .config(getSparkConf) 12 | .master("local[1]") 13 | .getOrCreate() 14 | 15 | def getSparkConf: SparkConf = { 16 | val conf = new SparkConf() 17 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 18 | .set("spark.network.timeout", "600") 19 | .set("spark.streaming.kafka.maxRatePerPartition", "200000") 20 | .set("spark.streaming.kafka.consumer.poll.ms", "5120") 21 | .set("spark.streaming.concurrentJobs", "5") 22 | .set("spark.sql.crossJoin.enabled", "true") 23 | .set("spark.driver.maxResultSize", "1g") 24 | .set("spark.rpc.message.maxSize", "1000") // 1024 max 25 | conf 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /enhanced_graph_embedding_side_information/src/main/scala/sparkapplication/BaseSparkOnline.scala: -------------------------------------------------------------------------------- 1 | package sparkapplication 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.sql.SparkSession 5 | 6 | trait BaseSparkOnline { 7 | def basicSpark: SparkSession = 8 | SparkSession 9 | .builder 10 | .config(getSparkConf) 11 | .enableHiveSupport() 12 | .getOrCreate() 13 | 14 | def getSparkConf: SparkConf = { 15 | val conf = new SparkConf() 16 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 17 | .set("spark.network.timeout", "6000") 18 | .set("spark.streaming.kafka.maxRatePerPartition", "200000") 19 | .set("spark.streaming.kafka.consumer.poll.ms", "5120") 20 | .set("spark.streaming.concurrentJobs", "5") 21 | .set("spark.sql.crossJoin.enabled", "true") 22 | .set("spark.driver.maxResultSize", "20g") 23 | .set("spark.rpc.message.maxSize", "1000") // 1024 max 24 | } 25 | 26 | } 27 | 
-------------------------------------------------------------------------------- /paper/[Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Alibaba Embedding] Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba (Alibaba 2018).pdf -------------------------------------------------------------------------------- /paper/[Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Graph Embedding] DeepWalk- Online Learning of Social Representations (SBU 2014).pdf -------------------------------------------------------------------------------- /paper/[Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryCatLeung/deepwalk_node2vector_eges/b817c583459e27f07fabae59fcdb83a51b0fa81e/paper/[Node2vec] Node2vec - Scalable Feature Learning for Networks (Stanford 2016).pdf --------------------------------------------------------------------------------
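A closing, hedged sketch (not a file from this repository) of how the two stages above could be wired together end to end: RandomWalk produces the walks and WeightedSkipGram consumes node@side-information sequences, so the glue below, in particular the hypothetical sideInfoOf lookup that attaches side information to every walked node, is an assumption about the wiring rather than project code; only setters and fit signatures that already appear in the sources above are used.

import eges.embedding.WeightedSkipGram
import eges.random.walk.RandomWalk
import sparkapplication.BaseSparkLocal

object PipelineSketch extends BaseSparkLocal {
  def main(args: Array[String]): Unit = {
    val spark = this.basicSpark

    // Weighted edges (srcNode, dstNode, weight) and an assumed side-information lookup.
    val edges = List(("y", "a", 0.8), ("y", "b", 0.6), ("a", "e", 0.5), ("b", "f", 0.5))
    val sideInfoOf = Map("y" -> "s1", "a" -> "s2", "b" -> "s2", "e" -> "s3", "f" -> "s3")

    // Stage 1: biased random walks, as in Main1.
    val walks = new RandomWalk()
      .setNodeIterationNum(2).setNodeWalkLength(4)
      .setReturnParameter(1.0).setInOutParameter(1.0)
      .setNodeMaxDegree(200).setNumPartitions(2)
      .fit(spark.sparkContext.parallelize(edges))

    // "n1,n2,n3" -> Array(Array(n1, side(n1)), Array(n2, side(n2)), ...), the shape Main2 feeds to fit.
    val sequences = walks.map { case (_, path) =>
      path.split(",").map(node => Array(node, sideInfoOf.getOrElse(node, "unknown")))
    }

    // Stage 2: weighted skip-gram with side information, as in Main2.
    val embeddings = new WeightedSkipGram()
      .setVectorSize(4).setWindowSize(1).setNegativeSampleNum(1)
      .setPerSampleMaxNum(100).setSubSample(0.0).setLearningRate(0.25)
      .setIterationNum(2).setIsNoShowLoss(true).setNumPartitions(2).setSampleTableNum(200)
      .fit(sequences)

    embeddings.take(5).foreach(println)
  }
}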