├── src
│   ├── scala
│   │   ├── config
│   │   │   ├── vars.prod.properties
│   │   │   └── vars.sit.properties
│   │   ├── examples
│   │   │   ├── Example1.scala
│   │   │   └── Example2.scala
│   │   ├── models
│   │   │   ├── FeatureBuilder.scala
│   │   │   └── ItemEmbedding.scala
│   │   ├── core
│   │   │   ├── BaseSparkOnline.scala
│   │   │   └── BaseSparkLocal.scala
│   │   ├── prediction
│   │   │   ├── ItemEmbeddingPredictor.scala
│   │   │   ├── PredictUserVector.scala
│   │   │   └── PredictUserVectorTwo.scala
│   │   └── data
│   │       ├── MakeDataOne.scala
│   │       └── MakeDataTwo.scala
│   └── python
│       ├── examples
│       │   ├── example2.py
│       │   └── example1.py
│       ├── data
│       │   ├── tfrecords_methods
│       │   │   ├── tfrecords
│       │   │   │   └── data1.tfrecords
│       │   │   ├── data2tfrecord1.py
│       │   │   ├── read_sparse_tfrecords_2.py
│       │   │   └── read_sparse_tfrecords_1.py
│       │   ├── read_tfrecords.py
│       │   └── data2tfrecords.py
│       ├── utils
│       │   └── tensor_board.py
│       ├── reference
│       │   ├── feature_column.py
│       │   └── self_defined_network_layer.py
│       └── models
│           ├── load_dnn_model.py
│           └── dnn.py
├── requirements.txt
├── data
│   ├── tfrecords
│   │   └── tfrecords
│   │       ├── train
│   │       │   └── train.tfrecords
│   │       └── evaluation
│   │           └── evaluation.tfrecords
│   └── checkpoints
│       └── ckpt
│           ├── events.out.tfevents.1575536459.CNHQ-18076444T
│           ├── eval
│           │   └── events.out.tfevents.1575536462.CNHQ-18076444T
│           └── checkpoint
├── docs
│   ├── Deep Neural Networks for YouTube Recommendations.pdf
│   └── architecture.drawio
├── .gitignore
├── README.md
└── pom.xml
/src/scala/config/vars.prod.properties:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/scala/config/vars.sit.properties:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/python/examples/example2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=1.13.1,<2.0.0  # the Python sources use TF 1.x APIs (tf.Session, tf.contrib, tf.app); pom.xml pins TensorFlow 1.13.1
2 | numpy>=1.19.0
3 | pandas>=1.3.0
4 | matplotlib>=3.3.0
5 | seaborn>=0.11.0
6 | scikit-learn>=1.0.0
7 | 
--------------------------------------------------------------------------------
/data/tfrecords/tfrecords/train/train.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/tfrecords/tfrecords/train/train.tfrecords
--------------------------------------------------------------------------------
/data/tfrecords/tfrecords/evaluation/evaluation.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/tfrecords/tfrecords/evaluation/evaluation.tfrecords
--------------------------------------------------------------------------------
/docs/Deep Neural Networks for YouTube Recommendations.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/docs/Deep Neural Networks for YouTube Recommendations.pdf
--------------------------------------------------------------------------------
/src/python/data/tfrecords_methods/tfrecords/data1.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/src/python/data/tfrecords_methods/tfrecords/data1.tfrecords -------------------------------------------------------------------------------- /data/checkpoints/ckpt/events.out.tfevents.1575536459.CNHQ-18076444T: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/checkpoints/ckpt/events.out.tfevents.1575536459.CNHQ-18076444T -------------------------------------------------------------------------------- /data/checkpoints/ckpt/eval/events.out.tfevents.1575536462.CNHQ-18076444T: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/checkpoints/ckpt/eval/events.out.tfevents.1575536462.CNHQ-18076444T -------------------------------------------------------------------------------- /data/checkpoints/ckpt/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-1600" 2 | all_model_checkpoint_paths: "model.ckpt-1200" 3 | all_model_checkpoint_paths: "model.ckpt-1300" 4 | all_model_checkpoint_paths: "model.ckpt-1400" 5 | all_model_checkpoint_paths: "model.ckpt-1500" 6 | all_model_checkpoint_paths: "model.ckpt-1600" 7 | -------------------------------------------------------------------------------- /src/scala/examples/Example1.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import sparkapplication.BaseSparkLocal 4 | import scala.collection.mutable 5 | 6 | object Example1 extends BaseSparkLocal { 7 | def main(args:Array[String]):Unit = { 8 | val spark = this.basicSpark 9 | import spark.implicits._ 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/python/examples/example1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import json 5 | import tensorflow as tf 6 | 7 | a = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) #3*2 8 | b = tf.constant([[1, 0], [2, 1], [0, 1]]) #3*2 9 | c = tf.nn.embedding_lookup(a, b) 10 | d = tf.reduce_mean(c, axis=1) 11 | e = tf.concat([d, a], 1) 12 | 13 | with tf.Session() as sess: 14 | print(c) 15 | print(sess.run(c)) 16 | print(d) 17 | print(sess.run(d)) 18 | print(e) 19 | print(sess.run(e)) 20 | 21 | -------------------------------------------------------------------------------- /src/scala/examples/Example2.scala: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import sparkapplication.BaseSparkLocal 4 | 5 | object Example2 extends BaseSparkLocal { 6 | def main(args:Array[String]):Unit = { 7 | // val spark = this.basicSpark 8 | // import spark.implicits._ 9 | 10 | val gdsVector = "[8534.033203125,-6634.611328125,-20669.0703125,-9483.734375,8790.3935546875,15647.646484375,-15543.39453125,34464.3203125,-1275.48974609375,28998.267578125,2446.0126953125,32628.033203125,1429.67431640625,37169.6640625,1902.3770751953125,-31038.359375]" 11 | val gds123 = gdsVector.replace("[", "").replace("]", "").split(",", -1).map(_.toDouble) 12 | gds123.foreach(println(_)) 13 | 14 | } 15 | } 16 | -------------------------------------------------------------------------------- 
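`example1.py` above demonstrates the input construction the model relies on: look up the embeddings of the visited items, average-pool them, and concatenate the result with per-sample continuous features (the same steps `models/dnn.py` performs before its dense layers). The following NumPy sketch reproduces the same operations and shapes for readers without a TF 1.x session; it is illustrative only and not a file in this repository.

```python
import numpy as np

a = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # embedding table: 3 items x 2 dims
b = np.array([[1, 0], [2, 1], [0, 1]])               # visited-item indices: 3 samples x 2 items

c = a[b]                             # embedding lookup             -> shape (3, 2, 2)
d = c.mean(axis=1)                   # average over visited items   -> shape (3, 2)
e = np.concatenate([d, a], axis=1)   # concat per-sample features   -> shape (3, 4)

print(c.shape, d.shape, e.shape)     # (3, 2, 2) (3, 2) (3, 4)
```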
/src/scala/models/FeatureBuilder.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import org.tensorflow.example._ 4 | import org.tensorflow.spark.shaded.com.google.protobuf.ByteString 5 | 6 | object FeatureBuilder { 7 | def s(strings: String*): Feature = { 8 | val b = BytesList.newBuilder 9 | for (s <- strings) { 10 | b.addValue(ByteString.copyFromUtf8(s)) 11 | } 12 | Feature.newBuilder.setBytesList(b).build 13 | } 14 | 15 | def f(values: Float*): Feature = { 16 | val b = FloatList.newBuilder 17 | for (v <- values) { 18 | b.addValue(v) 19 | } 20 | Feature.newBuilder.setFloatList(b).build 21 | } 22 | 23 | def i(values: Int*): Feature = { 24 | val b = Int64List.newBuilder 25 | for (v <- values) { 26 | b.addValue(v) 27 | } 28 | Feature.newBuilder.setInt64List(b).build 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/scala/core/BaseSparkOnline.scala: -------------------------------------------------------------------------------- 1 | package sparkapplication 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.sql.SparkSession 5 | 6 | trait BaseSparkOnline { 7 | def basicSpark: SparkSession = 8 | SparkSession 9 | .builder 10 | .config(getSparkConf) 11 | .enableHiveSupport() 12 | .getOrCreate() 13 | 14 | def getSparkConf: SparkConf = { 15 | val conf = new SparkConf() 16 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 17 | .set("spark.network.timeout", "6000") 18 | .set("spark.streaming.kafka.maxRatePerPartition", "200000") 19 | .set("spark.streaming.kafka.consumer.poll.ms", "5120") 20 | .set("spark.streaming.concurrentJobs", "5") 21 | .set("spark.sql.crossJoin.enabled", "true") 22 | .set("spark.driver.maxResultSize", "20g") 23 | .set("spark.rpc.message.maxSize", "1000") // 1024 max 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/scala/core/BaseSparkLocal.scala: -------------------------------------------------------------------------------- 1 | package sparkapplication 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.sql.SparkSession 5 | 6 | trait BaseSparkLocal { 7 | //本地 8 | def basicSpark: SparkSession = 9 | SparkSession 10 | .builder 11 | .config(getSparkConf) 12 | .master("local[1]") 13 | .getOrCreate() 14 | 15 | def getSparkConf: SparkConf = { 16 | val conf = new SparkConf() 17 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 18 | .set("spark.network.timeout", "600") 19 | .set("spark.streaming.kafka.maxRatePerPartition", "200000") 20 | .set("spark.streaming.kafka.consumer.poll.ms", "5120") 21 | .set("spark.streaming.concurrentJobs", "5") 22 | .set("spark.sql.crossJoin.enabled", "true") 23 | .set("spark.driver.maxResultSize", "1g") 24 | .set("spark.rpc.message.maxSize", "1000") // 1024 max 25 | conf 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/python/utils/tensor_board.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | import sys 5 | from tensorflow.python.platform import gfile 6 | from tensorflow.core.protobuf import saved_model_pb2 7 | from tensorflow.python.util import compat 8 | 9 | # 运行完后, tensorboard --logdir ./logdir, 然后在浏览器中输入地址: http://localhost:6006/ 10 | with tf.Session() as sess: 11 | model_filename 
='../../data/checkpoints/modelpath/1575536466/saved_model.pb' 12 | with gfile.FastGFile(model_filename, 'rb') as f: 13 | data = compat.as_bytes(f.read()) 14 | sm = saved_model_pb2.SavedModel() 15 | sm.ParseFromString(data) 16 | 17 | if 1 != len(sm.meta_graphs): 18 | print('More than one graph found. Not sure which to write') 19 | sys.exit(1) 20 | 21 | g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def) 22 | LOGDIR='../../data/checkpoints/logdir' 23 | train_writer = tf.summary.FileWriter(LOGDIR) 24 | train_writer.add_graph(sess.graph) 25 | train_writer.flush() 26 | train_writer.close() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Jupyter Notebook 25 | .ipynb_checkpoints 26 | 27 | # pyenv 28 | .python-version 29 | 30 | # Environments 31 | .env 32 | .venv 33 | env/ 34 | venv/ 35 | ENV/ 36 | env.bak/ 37 | venv.bak/ 38 | 39 | # IDE 40 | .vscode/ 41 | .idea/ 42 | *.swp 43 | *.swo 44 | *~ 45 | 46 | # OS 47 | .DS_Store 48 | .DS_Store? 49 | ._* 50 | .Spotlight-V100 51 | .Trashes 52 | ehthumbs.db 53 | Thumbs.db 54 | 55 | # TensorFlow 56 | *.ckpt 57 | *.meta 58 | *.data-00000-of-00001 59 | *.index 60 | *.pb 61 | *.pbtxt 62 | 63 | # Logs 64 | *.log 65 | logdir/ 66 | tensorboard_logs/ 67 | 68 | # Data 69 | 70 | 71 | 72 | # Scala/Java 73 | target/ 74 | *.class 75 | *.jar 76 | *.war 77 | *.ear 78 | *.zip 79 | *.tar.gz 80 | *.rar 81 | hs_err_pid* 82 | 83 | # Maven 84 | .mvn/ 85 | mvnw 86 | mvnw.cmd 87 | -------------------------------------------------------------------------------- /src/python/data/tfrecords_methods/data2tfrecord1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | if __name__ == "__main__": 5 | per_item_sample_num = 20 6 | item_num = 15 7 | embedding_size = 8 8 | filename = "../../data/tfrecords_methods/tfrecords/data1.tfrecords" 9 | writer = tf.python_io.TFRecordWriter(filename) 10 | for i in range(per_item_sample_num): 11 | for j in range(item_num): 12 | embedding_average = np.random.uniform(low=j, high=j + 1.0, size=[embedding_size]) 13 | index = j 14 | example = tf.train.Example(features=tf.train.Features(feature={ 15 | "embedding_average": tf.train.Feature(float_list=tf.train.FloatList(value=embedding_average)), 16 | "index": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 17 | "value": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])), 18 | "size": tf.train.Feature(int64_list=tf.train.Int64List(value=[item_num])) 19 | })) 20 | writer.write(example.SerializeToString()) 21 | writer.close() -------------------------------------------------------------------------------- /src/python/data/read_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def parse_fn(example): 4 | example_fmt = { 5 | "visit_items_index": tf.FixedLenFeature([5], tf.int64), 6 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32), 7 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64) 8 | } 9 | parsed = tf.parse_single_example(example, example_fmt) 10 | 
next_visit_item_index = parsed.pop("next_visit_item_index") 11 | return parsed, next_visit_item_index 12 | 13 | if __name__ == "__main__": 14 | files = tf.data.Dataset.list_files('../../data/tfrecords/train/train.tfrecords', shuffle=True) 15 | data_set = files.apply( 16 | tf.contrib.data.parallel_interleave( 17 | lambda filename: tf.data.TFRecordDataset(filename), 18 | cycle_length=16)) 19 | data_set = data_set.repeat(1) 20 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=16) 21 | data_set = data_set.prefetch(buffer_size=64) 22 | data_set = data_set.batch(batch_size=16) 23 | iterator = data_set.make_one_shot_iterator() 24 | res1, res2 = iterator.get_next() 25 | 26 | with tf.Session() as sess: 27 | for i in range(5): 28 | result1, result2 = sess.run([res1, res2]) 29 | print("第{}批:".format(i), end=" ") 30 | print("result1是:", result1) 31 | print("result2是:", result2) -------------------------------------------------------------------------------- /src/python/data/tfrecords_methods/read_sparse_tfrecords_2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def parse_fn(example): 4 | example_fmt = { 5 | "embedding_average": tf.FixedLenFeature([8], tf.float32), 6 | "one_hot": tf.SparseFeature(index_key=["index"], 7 | value_key="value", 8 | dtype=tf.float32, 9 | size=[15]) # size必须写死, 不能传超参 10 | } 11 | parsed = tf.parse_single_example(example, example_fmt) 12 | return parsed["embedding_average"], tf.sparse_tensor_to_dense(parsed["one_hot"]) 13 | 14 | if __name__ == "__main__": 15 | files = tf.data.Dataset.list_files('../../data/tfrecords_methods/tfrecords/data1.tfrecords', shuffle=True) 16 | data_set = files.apply( 17 | tf.contrib.data.parallel_interleave( 18 | lambda filename: tf.data.TFRecordDataset(filename), 19 | cycle_length=15)) 20 | data_set = data_set.repeat(1) 21 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=15) 22 | data_set = data_set.prefetch(buffer_size=30) 23 | data_set = data_set.batch(batch_size=15) 24 | iterator = data_set.make_one_shot_iterator() 25 | embedding, one_hot = iterator.get_next() 26 | 27 | with tf.Session() as sess: 28 | for i in range(5): 29 | embedding_result, one_hot_result = sess.run([embedding, one_hot]) 30 | print("第{}批:".format(i), end=" ") 31 | print("embedding是:", embedding_result, end=" ") 32 | print("one_hot是:", one_hot_result) -------------------------------------------------------------------------------- /src/python/data/tfrecords_methods/read_sparse_tfrecords_1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def parse_fn(example): 4 | example_fmt = { 5 | "embedding_average": tf.FixedLenFeature([8], tf.float32), 6 | "index": tf.FixedLenFeature([], tf.int64), 7 | "value": tf.FixedLenFeature([], tf.float32), 8 | "size": tf.FixedLenFeature([], tf.int64) 9 | } 10 | parsed = tf.parse_single_example(example, example_fmt) 11 | sparse_tensor = tf.SparseTensor([[parsed["index"]]], [parsed["value"]], [parsed["size"]]) # 这种方法读取稀疏向量在有的平台可能不行 12 | return parsed["embedding_average"], tf.sparse_tensor_to_dense(sparse_tensor) 13 | 14 | if __name__ == "__main__": 15 | files = tf.data.Dataset.list_files('../../data/tfrecords_methods/tfrecords/data1.tfrecords', shuffle=True) 16 | data_set = files.apply( 17 | tf.contrib.data.parallel_interleave( 18 | lambda filename: tf.data.TFRecordDataset(filename), 19 | cycle_length=15)) 20 | data_set = data_set.repeat(1) 21 | data_set = 
data_set.map(map_func=parse_fn, num_parallel_calls=15) 22 | data_set = data_set.prefetch(buffer_size=30) 23 | data_set = data_set.batch(batch_size=15) 24 | iterator = data_set.make_one_shot_iterator() 25 | embedding, one_hot = iterator.get_next() 26 | 27 | with tf.Session() as sess: 28 | for i in range(5): 29 | embedding_result, one_hot_result = sess.run([embedding, one_hot]) 30 | print("第{}批:".format(i), end=" ") 31 | print("embedding是:", embedding_result, end=" ") 32 | print("one_hot是:", one_hot_result) -------------------------------------------------------------------------------- /src/scala/models/ItemEmbedding.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import org.tensorflow._ 4 | import sparkapplication.BaseSparkOnline 5 | import scala.collection.JavaConverters._ 6 | 7 | object ItemEmbeddingMakeDataOne extends BaseSparkOnline { 8 | def main(args:Array[String]):Unit = { 9 | val spark = this.basicSpark 10 | import spark.implicits._ 11 | 12 | val modelHdfsPath = "hdfs路径" 13 | val modelTag = "serve" 14 | 15 | val embeddingAverageArray = Array(Array.fill[Float](8)(0.1F)) 16 | val model = SavedModelBundle.load(modelHdfsPath, modelTag) 17 | val sess = model.session() 18 | val embeddingAverageArrayTensor = Tensor.create(embeddingAverageArray, classOf[java.lang.Float]) 19 | val itemEmbeddingResult = getItemEmbedding(sess, embeddingAverageArrayTensor) 20 | 21 | val result = spark.sparkContext.parallelize(itemEmbeddingResult.map(k => k.mkString("@"))).toDF("item_embedding") 22 | result.show(10, false) 23 | } 24 | 25 | private def getItemEmbedding(sess: Session, embeddingAverageArrayTensor: Tensor[_], embeddingAverageArrayName: String = "Placeholder:0", itemEmbeddingName: String = "item_embedding:0") = { 26 | val resultBuffer = sess.runner 27 | .feed(embeddingAverageArrayName, embeddingAverageArrayTensor) 28 | .fetch(itemEmbeddingName) 29 | .run.asScala 30 | 31 | val itemEmbedding = resultBuffer.head 32 | val itemEmbeddingShape: Array[Int] = itemEmbedding.shape.map(_.toInt) 33 | val itemEmbeddingResult = Array.ofDim[Float](itemEmbeddingShape.head, itemEmbeddingShape(1)) 34 | itemEmbedding.copyTo(itemEmbeddingResult) 35 | 36 | itemEmbeddingResult 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/scala/prediction/ItemEmbeddingPredictor.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import org.tensorflow._ 4 | import org.tensorflow.example._ 5 | import explore.FeatureBuilder._ 6 | import sparkapplication.BaseSparkOnline 7 | import scala.collection.JavaConverters._ 8 | 9 | object ItemEmbeddingMakeDataTwo extends BaseSparkOnline { 10 | def main(args:Array[String]):Unit = { 11 | val spark = this.basicSpark 12 | import spark.implicits._ 13 | 14 | val modelHdfsPath = "hdfs路径" 15 | val modelTag = "serve" 16 | 17 | val embeddingAverage = Array.fill[Float](8)(0.1F) 18 | val gender = "male" 19 | val cityCd = "city_cd_100" 20 | val featuresBuilder = Features.newBuilder 21 | .putFeature("embedding_average", f(embeddingAverage:_*)) 22 | .putFeature("gender", s(gender)) 23 | .putFeature("city_cd", s(cityCd)) 24 | featuresBuilder.build() 25 | val features = Example.newBuilder.setFeatures(featuresBuilder).build.toByteArray 26 | 27 | val model = SavedModelBundle.load(modelHdfsPath, modelTag) 28 | val sess = model.session() 29 | val embeddingAverageArrayTensor = Tensor.create(Array(features)) 30 | val itemEmbeddingResult 
= getItemEmbedding(sess, embeddingAverageArrayTensor) 31 | 32 | val result = spark.sparkContext.parallelize(itemEmbeddingResult.map(k => k.mkString("@"))).toDF("item_embedding") 33 | result.show(50, false) 34 | } 35 | 36 | private def getItemEmbedding(sess: Session, featuresArrayTensor: Tensor[_], featuresArrayName: String = "input_example_tensor:0", itemEmbeddingName: String = "item_embedding:0") = { 37 | val resultBuffer = sess.runner 38 | .feed(featuresArrayName, featuresArrayTensor) 39 | .fetch(itemEmbeddingName) 40 | .run.asScala 41 | 42 | val itemEmbedding = resultBuffer.head 43 | val itemEmbeddingShape: Array[Int] = itemEmbedding.shape.map(_.toInt) 44 | val itemEmbeddingResult = Array.ofDim[Float](itemEmbeddingShape.head, itemEmbeddingShape(1)) 45 | itemEmbedding.copyTo(itemEmbeddingResult) 46 | 47 | itemEmbeddingResult 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/scala/prediction/PredictUserVector.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import org.tensorflow._ 4 | import sparkapplication.BaseSparkOnline 5 | import scala.collection.JavaConverters._ 6 | 7 | object PredictUserVectorMakeDataOne extends BaseSparkOnline { 8 | def main(args:Array[String]):Unit = { 9 | val spark = this.basicSpark 10 | import spark.implicits._ 11 | 12 | val modelHdfsPath = "hdfs路径" 13 | val modelTag = "serve" 14 | 15 | val dataValidation = spark.read.format("tfrecords") 16 | .option("recordType", "Example") 17 | .load("hdfs路径") 18 | .rdd.map{row => 19 | val embeddingAverage = row.getAs[scala.collection.mutable.WrappedArray[Float]]("embedding_average") 20 | embeddingAverage.toArray 21 | } 22 | println(s"验证集数据dataValidation总数为:${dataValidation.count},数据格式如下:") 23 | dataValidation.toDF("embedding_average").show(5, false) 24 | 25 | val userVectorAll = dataValidation.mapPartitions(lineIterator => { 26 | val embeddingAverageArray = lineIterator.toArray 27 | val model = SavedModelBundle.load(modelHdfsPath, modelTag) 28 | val sess = model.session() 29 | val embeddingAverageArrayTensor = Tensor.create(embeddingAverageArray, classOf[java.lang.Float]) 30 | val userVectorResult = predictUserVector(sess, embeddingAverageArrayTensor) 31 | userVectorResult.toIterator 32 | }) 33 | 34 | val result = userVectorAll.map(k => k.mkString("@")).toDF("user_vector") 35 | result.show(10, false) 36 | } 37 | 38 | private def predictUserVector(sess: Session, embeddingAverageArrayTensor: Tensor[_], embeddingAverageArrayName: String = "Placeholder:0", userVectorName: String = "user_vector/Relu:0") = { 39 | val resultBuffer = sess.runner 40 | .feed(embeddingAverageArrayName, embeddingAverageArrayTensor) 41 | .fetch(userVectorName) 42 | .run.asScala 43 | 44 | val userVector = resultBuffer.head 45 | val userVectorShape: Array[Int] = userVector.shape.map(_.toInt) 46 | val userVectorResult = Array.ofDim[Float](userVectorShape.head, userVectorShape(1)) 47 | userVector.copyTo(userVectorResult) 48 | 49 | userVectorResult 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /src/python/reference/feature_column.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | member_id = "member_id_{}".format(1) 5 | gds_cd = "gds_cd_{}".format(1) 6 | age = np.random.randint(18, 60) 7 | height = np.random.uniform(170.0, 190.0) 8 | example = 
tf.train.Example(features=tf.train.Features(feature={ 9 | "member_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(member_id)])), 10 | "gds_cd": tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(gds_cd)])), 11 | "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[age])), 12 | "height": tf.train.Feature(float_list=tf.train.FloatList(value=[height])) 13 | })) 14 | serialized_example = example.SerializeToString() 15 | 16 | example_fmt = { 17 | "member_id": tf.FixedLenFeature([1], tf.string), 18 | "gds_cd": tf.FixedLenFeature([1], tf.string), 19 | "age": tf.FixedLenFeature([1], tf.int64), 20 | "height": tf.FixedLenFeature([1], tf.float32) 21 | } 22 | parsed = tf.parse_single_example(serialized_example, example_fmt) 23 | 24 | member_id = tf.feature_column.categorical_column_with_hash_bucket("member_id", hash_bucket_size=3) 25 | gds_cd = tf.feature_column.categorical_column_with_hash_bucket("gds_cd", hash_bucket_size=3) 26 | age = tf.feature_column.categorical_column_with_vocabulary_list("age", [i for i in range(3)], dtype=tf.int64, 27 | default_value=0) 28 | height = tf.feature_column.numeric_column("height") 29 | member_id_indicator = tf.feature_column.indicator_column(member_id) 30 | gds_cd_indicator = tf.feature_column.indicator_column(gds_cd) 31 | age_indicator = tf.feature_column.indicator_column(age) 32 | feature_columns = [member_id_indicator, gds_cd_indicator, age_indicator, height] 33 | _result = tf.feature_column.input_layer(parsed, feature_columns) 34 | 35 | with tf.Session() as sess: 36 | sess.run(tf.global_variables_initializer()) 37 | sess.run(tf.tables_initializer()) 38 | parsed_result = sess.run([parsed]) 39 | print("parsed_result是:", parsed_result) 40 | result = sess.run([_result]) 41 | print("result是:", result) 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /src/python/data/data2tfrecords.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | if __name__ == "__main__": 5 | sample_num = 5000 6 | item_num = 500 7 | sample_set = [] 8 | for i in range(sample_num): 9 | visit_items_index = np.random.randint(low=0, high=item_num, size=[5]) 10 | continuous_features_value = np.random.uniform(low=-5.0, high=5.0, size=[16]) 11 | next_visit_item_index = np.random.randint(low=0, high=item_num) 12 | sample = [visit_items_index, continuous_features_value, next_visit_item_index] 13 | sample_set.append(sample) 14 | 15 | # 训练数据 16 | filename = "../../data/tfrecords/train/train.tfrecords" 17 | writer = tf.python_io.TFRecordWriter(filename) 18 | for sample in sample_set: 19 | visit_items_index = sample[0] 20 | continuous_features_value = sample[1] 21 | next_visit_item_index = sample[2] 22 | example = tf.train.Example(features=tf.train.Features(feature={ 23 | "visit_items_index": tf.train.Feature(int64_list=tf.train.Int64List(value=visit_items_index)), 24 | "continuous_features_value": tf.train.Feature( 25 | float_list=tf.train.FloatList(value=continuous_features_value)), 26 | "next_visit_item_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[next_visit_item_index])) 27 | })) 28 | writer.write(example.SerializeToString()) 29 | writer.close() 30 | 31 | # 评估数据, 由于数据是随机生成, 所以评估数据从训练数据中取 32 | filename = "../../data/tfrecords/evaluation/evaluation.tfrecords" 33 | writer = tf.python_io.TFRecordWriter(filename) 34 | i = 0 35 | for sample in sample_set: 36 | 
if i % 10 == 0: 37 | visit_items_index = sample[0] 38 | continuous_features_value = sample[1] 39 | next_visit_item_index = sample[2] 40 | example = tf.train.Example(features=tf.train.Features(feature={ 41 | "visit_items_index": tf.train.Feature(int64_list=tf.train.Int64List(value=visit_items_index)), 42 | "continuous_features_value": tf.train.Feature( 43 | float_list=tf.train.FloatList(value=continuous_features_value)), 44 | "next_visit_item_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[next_visit_item_index])) 45 | })) 46 | writer.write(example.SerializeToString()) 47 | i = i + 1 48 | writer.close() -------------------------------------------------------------------------------- /src/scala/data/MakeDataOne.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import java.util.Random 4 | import org.apache.spark.sql._ 5 | import sparkapplication.BaseSparkOnline 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object MakeDataOne extends BaseSparkOnline { 9 | def main(args:Array[String]):Unit = { 10 | val spark = this.basicSpark 11 | import spark.implicits._ 12 | 13 | // 训练数据 14 | var perItemSampleNum = 20 15 | var itemNum = 15 16 | var embeddingSize = 8 17 | val trainData = ArrayBuffer[(Array[Double], Long, Double, Long)]() 18 | for(i <- 0 until perItemSampleNum) { 19 | for(j <- 0 until itemNum){ 20 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble()) 21 | trainData.append((embeddingAverage, j.toLong, 1.0, itemNum.toLong)) 22 | } 23 | } 24 | val trainDataFrame = spark.sparkContext.parallelize(trainData, 10).toDF("embedding_average", "index", "value", "size") 25 | 26 | // Save DataFrame as TFRecords 27 | trainDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径") 28 | 29 | // Read TFRecords into DataFrame. 30 | val trainDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径") 31 | println("trainDataFrame重新加载tfrecords格式的数据,数据格式如下:") 32 | trainDataTfrecords.show(10, false) 33 | 34 | // 评估数据 35 | perItemSampleNum = 10 36 | itemNum = 15 37 | embeddingSize = 8 38 | val evaluationData = ArrayBuffer[(Array[Double], Long, Double, Long)]() 39 | for(i <- 0 until perItemSampleNum) { 40 | for(j <- 0 until itemNum){ 41 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble()) 42 | evaluationData.append((embeddingAverage, j.toLong, 1.0, itemNum.toLong)) 43 | } 44 | } 45 | val evaluationDataFrame = spark.sparkContext.parallelize(evaluationData, 10).toDF("embedding_average", "index", "value", "size") 46 | 47 | // Save DataFrame as TFRecords 48 | evaluationDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径") 49 | 50 | // Read TFRecords into DataFrame. 
51 | val evaluationDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径") 52 | println("evaluationData重新加载tfrecords格式的数据,数据格式如下:") 53 | evaluationDataTfrecords.show(10, false) 54 | 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/scala/prediction/PredictUserVectorTwo.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import org.tensorflow._ 4 | import org.tensorflow.example._ 5 | import explore.FeatureBuilder._ 6 | import sparkapplication.BaseSparkOnline 7 | import scala.collection.JavaConverters._ 8 | 9 | object PredictUserVectorMakeDataTwo extends BaseSparkOnline { 10 | def main(args:Array[String]):Unit = { 11 | val spark = this.basicSpark 12 | import spark.implicits._ 13 | 14 | val modelHdfsPath = "hdfs路径" 15 | val modelTag = "serve" 16 | 17 | val dataValidation = spark.read.format("tfrecords") 18 | .option("recordType", "Example") 19 | .load("hdfs路径") 20 | .rdd.map{row => 21 | val embeddingAverage = row.getAs[scala.collection.mutable.WrappedArray[Float]]("embedding_average").toArray 22 | val gender = row.getAs[String]("gender") 23 | val cityCd = row.getAs[String]("city_cd") 24 | val featuresBuilder = Features.newBuilder 25 | .putFeature("embedding_average", f(embeddingAverage:_*)) 26 | .putFeature("gender", s(gender)) 27 | .putFeature("city_cd", s(cityCd)) 28 | featuresBuilder.build() 29 | val features = Example.newBuilder.setFeatures(featuresBuilder).build.toByteArray 30 | features 31 | } 32 | println(s"验证集数据dataValidation总数为:${dataValidation.count},数据格式如下:") 33 | dataValidation.toDF("features").show(5, false) 34 | 35 | val userVectorAll = dataValidation.mapPartitions(lineIterator => { 36 | val featuresArray = lineIterator.toArray 37 | val model = SavedModelBundle.load(modelHdfsPath, modelTag) 38 | val sess = model.session() 39 | val featuresArrayTensor = Tensor.create(featuresArray) 40 | val userVectorResult = predictUserVector(sess, featuresArrayTensor) 41 | userVectorResult.toIterator 42 | }) 43 | 44 | val result = userVectorAll.map(k => k.mkString("@")).toDF("user_vector") 45 | result.show(10, false) 46 | } 47 | 48 | private def predictUserVector(sess: Session, featuresArrayTensor: Tensor[_], featuresArrayName: String = "input_example_tensor:0", userVectorName: String = "user_vector/Relu:0") = { 49 | val resultBuffer = sess.runner 50 | .feed(featuresArrayName, featuresArrayTensor) 51 | .fetch(userVectorName) 52 | .run.asScala 53 | 54 | val userVector = resultBuffer.head 55 | val userVectorShape: Array[Int] = userVector.shape.map(_.toInt) 56 | val userVectorResult = Array.ofDim[Float](userVectorShape.head, userVectorShape(1)) 57 | userVector.copyTo(userVectorResult) 58 | 59 | userVectorResult 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/scala/data/MakeDataTwo.scala: -------------------------------------------------------------------------------- 1 | package explore 2 | 3 | import java.util.Random 4 | import org.apache.spark.sql._ 5 | import sparkapplication.BaseSparkOnline 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object MakeDataTwo extends BaseSparkOnline { 9 | def main(args:Array[String]):Unit = { 10 | val spark = this.basicSpark 11 | import spark.implicits._ 12 | 13 | // 训练数据 14 | var perItemSampleNum = 20 15 | var itemNum = 15 16 | var embeddingSize = 8 17 | val trainData = ArrayBuffer[(Array[Double], String, String, Long)]() 18 | 
for(i <- 0 until perItemSampleNum) { 19 | for(j <- 0 until itemNum){ 20 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble()) 21 | val gender = if(j % 2 == 0) "male" else "female" 22 | val cityCd = "city_cd_" + (new Random).nextInt(200).toString 23 | trainData.append((embeddingAverage, gender, cityCd, j.toLong)) 24 | } 25 | } 26 | val trainDataFrame = spark.sparkContext.parallelize(trainData, 10).toDF("embedding_average", "gender", "city_cd", "index") 27 | 28 | // Save DataFrame as TFRecords 29 | trainDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径") 30 | 31 | // Read TFRecords into DataFrame. 32 | val trainDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径") 33 | println("trainDataFrame重新加载tfrecords格式的数据,数据格式如下:") 34 | trainDataTfrecords.show(10, false) 35 | 36 | // 评估数据 37 | perItemSampleNum = 10 38 | itemNum = 15 39 | embeddingSize = 8 40 | val evaluationData = ArrayBuffer[(Array[Double], String, String, Long)]() 41 | for(i <- 0 until perItemSampleNum) { 42 | for(j <- 0 until itemNum){ 43 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble()) 44 | val gender = if(j % 2 == 0) "male" else "female" 45 | val cityCd = "city_cd_" + (new Random).nextInt(200).toString 46 | evaluationData.append((embeddingAverage, gender, cityCd, j.toLong)) 47 | } 48 | } 49 | val evaluationDataFrame = spark.sparkContext.parallelize(evaluationData, 10).toDF("embedding_average", "gender", "city_cd", "index") 50 | 51 | // Save DataFrame as TFRecords 52 | evaluationDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径") 53 | 54 | // Read TFRecords into DataFrame. 
55 | val evaluationDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径") 56 | println("evaluationData重新加载tfrecords格式的数据,数据格式如下:") 57 | evaluationDataTfrecords.show(10, false) 58 | 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/python/models/load_dnn_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | class dataProcess(object): 6 | 7 | def parse_fn(self, example): 8 | example_fmt = { 9 | "visit_items_index": tf.FixedLenFeature([5], tf.int64), 10 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32), 11 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64) 12 | } 13 | parsed = tf.parse_single_example(example, example_fmt) 14 | parsed.pop("next_visit_item_index") 15 | return parsed 16 | 17 | def next_batch(self, batch_size): 18 | files = tf.data.Dataset.list_files( 19 | '../../data/tfrecords/train/train.tfrecords', shuffle=False 20 | ) 21 | data_set = files.apply( 22 | tf.contrib.data.parallel_interleave( 23 | lambda filename: tf.data.TFRecordDataset(filename), 24 | cycle_length=16)) 25 | data_set = data_set.map(map_func=self.parse_fn, num_parallel_calls=16) 26 | data_set = data_set.prefetch(buffer_size=256) 27 | data_set = data_set.batch(batch_size=batch_size) 28 | iterator = data_set.make_one_shot_iterator() 29 | features = iterator.get_next() 30 | return features 31 | 32 | if __name__ == "__main__": 33 | # 数据预处理# 34 | dataProcess = dataProcess() 35 | features = dataProcess.next_batch(batch_size=16) 36 | 37 | signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 38 | with tf.Session() as sess: 39 | meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 40 | "../../data/checkpoints/modelpath/1575536466") 41 | print(meta_graph_def) 42 | signature = meta_graph_def.signature_def 43 | visit_items_index_tensor_name = signature[signature_key].inputs["visit_items_index"].name 44 | visit_items_index_tensor = sess.graph.get_tensor_by_name(visit_items_index_tensor_name) 45 | continuous_features_value_tensor_name = signature[signature_key].inputs["continuous_features_value"].name 46 | continuous_features_value_tensor = sess.graph.get_tensor_by_name(continuous_features_value_tensor_name) 47 | user_vector_tensor_name = signature[signature_key].outputs["user_vector"].name 48 | user_vector_tensor = sess.graph.get_tensor_by_name(user_vector_tensor_name) 49 | index_tensor_name = signature[signature_key].outputs["index"].name 50 | index_tensor = sess.graph.get_tensor_by_name(index_tensor_name) 51 | 52 | features_result = sess.run(features) 53 | feed_dict = {visit_items_index_tensor: features_result["visit_items_index"], continuous_features_value_tensor: features_result["continuous_features_value"]} 54 | predict_outputs = sess.run([user_vector_tensor, index_tensor], feed_dict=feed_dict) 55 | print(predict_outputs[0]) 56 | print("==========") 57 | print(predict_outputs[1]) -------------------------------------------------------------------------------- /src/python/reference/self_defined_network_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.python.keras import initializers 6 | 7 | class SampleLayer(tf.keras.layers.Layer): 8 | def __init__(self, is_training, top_k, 
item_num, 9 | kernel_initializer=tf.initializers.random_uniform(minval=-0.1, maxval=0.1), **kwargs): 10 | self.is_training = is_training 11 | self.top_k = top_k 12 | self.item_num = item_num 13 | self.kernel_initializer = kernel_initializer 14 | super(SampleLayer, self).__init__(**kwargs) 15 | 16 | def build(self, input_shape): 17 | assert isinstance(input_shape, list) 18 | input_shape0 = input_shape[0] 19 | # 为该层创建一个可训练的权重 20 | partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=int(input_shape0[1])) 21 | self.kernel = self.add_weight(name="item_embedding", 22 | shape=(self.item_num, int(input_shape0[1])), 23 | initializer=self.kernel_initializer, 24 | trainable=True, 25 | partitioner=partitioner) 26 | # 一定要在最后调用它 27 | super(SampleLayer, self).build(input_shape) 28 | 29 | def train_output(self, inputs0, inputs1): 30 | output_embedding = tf.nn.embedding_lookup(self.kernel, inputs1) # num * embedding_size 31 | logits = tf.matmul(inputs0, output_embedding, transpose_a=False, transpose_b=True) # num * num 32 | yhat = tf.nn.softmax(logits) # num * num 33 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16)) 34 | return cross_entropy 35 | 36 | def predict_output(self, inputs0): 37 | logits_predict = tf.matmul(inputs0, self.kernel, transpose_a=False, transpose_b=True) # num * item_num 38 | yhat_predict = tf.nn.softmax(logits_predict) # num * item_num 39 | _, indices = tf.nn.top_k(yhat_predict, k=self.top_k, sorted=True) # indices是: num * top_k 40 | indices = tf.cast(indices, tf.float32) # tf.keras.backend.switch输出类型必须一样, 所以将int转为float 41 | return indices 42 | 43 | def func1(self, inputs): 44 | assert len(inputs) == 2 45 | inputs1 = tf.cast(inputs[1], tf.int32) 46 | return inputs1 47 | 48 | def call(self, inputs, **kwargs): 49 | assert isinstance(inputs, list) 50 | inputs0 = inputs[0] # 上一层的输出 51 | inputs1_default = tf.zeros([inputs0.shape[0]], dtype=tf.int32) # 另外一个输入, 这是默认值 52 | inputs1 = tf.cond(self.is_training, lambda: self.func1(inputs), lambda: inputs1_default) 53 | # 如果训练的话, 输出是损失值; 如果预测的话, 输出是相似的top_k索引 54 | train_predict_output = tf.cond(self.is_training, lambda: self.train_output(inputs0, inputs1), 55 | lambda: self.predict_output(inputs0)) 56 | return train_predict_output 57 | 58 | def func2(self, input_shape): 59 | input_shape0 = input_shape[0] 60 | return (input_shape0[0], self.top_k) 61 | 62 | def compute_output_shape(self, input_shape): 63 | output_shape = tf.cond(self.is_training, lambda: (), lambda: self.func2(input_shape)) 64 | return output_shape 65 | 66 | def get_config(self): 67 | config = { 68 | 'is_training': self.is_training, 69 | 'top_k': self.top_k, 70 | 'item_num': self.item_num, 71 | 'kernel_initializer': initializers.serialize(self.kernel_initializer) 72 | } 73 | base_config = super(SampleLayer, self).get_config() 74 | return dict(list(base_config.items()) + list(config.items())) 75 | 76 | if __name__ == "__main__": 77 | inputs0 = tf.constant([[0.1, 0.2, 0.6, 0.3, 0.5], [0.8, 0.6, 0.9, 0.3, 0.5]]) 78 | inputs1 = tf.constant([0, 3]) 79 | sample_layer = SampleLayer(tf.constant(True), 3, 10, name="abc") 80 | result = sample_layer([inputs0, inputs1]) 81 | print(result) 82 | print(sample_layer.trainable_weights) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YouTube-DNN-RecSys: Deep Neural Networks for YouTube Recommendations 2 | 3 | ## Deep Neural Networks for YouTube Recommendations 4 | 
[Paper](https://dl.acm.org/doi/pdf/10.1145/2959100.2959190)
5 | 
6 | A clean, well-organized implementation of Deep Neural Networks for YouTube Recommendations, featuring both Python (TensorFlow) and Scala (Spark) implementations.
7 | 
8 | [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/chenxingqiang/YouTube-DNN-RecSys)
9 | 
10 | ## Architecture
11 | ![](architecture.drawio.svg)
12 | 
13 | ## 🏗️ Project Structure
14 | 
15 | ```
16 | YouTube-DNN-RecSys/
17 | ├── src/
18 | │   ├── python/              # Python implementation using TensorFlow
19 | │   │   ├── models/          # Core DNN model and loading utilities
20 | │   │   ├── data/            # Data processing and TFRecords handling
21 | │   │   ├── utils/           # TensorBoard and utility functions
22 | │   │   ├── examples/        # Usage examples and tutorials
23 | │   │   └── reference/       # Custom layers and feature engineering
24 | │   └── scala/               # Scala implementation using Spark
25 | │       ├── models/          # Feature building and embedding models
26 | │       ├── data/            # Data generation scripts
27 | │       ├── prediction/      # User vector and item embedding prediction
28 | │       ├── core/            # Base Spark application classes
29 | │       ├── examples/        # Spark usage examples
30 | │       └── config/          # Environment configuration files
31 | ├── data/                    # Data storage and model artifacts
32 | │   ├── tfrecords/           # Training and evaluation data
33 | │   └── checkpoints/         # Model checkpoints and saved models
34 | ├── tests/                   # Test suites for both Python and Scala
35 | ├── docs/                    # Research paper and documentation
36 | ├── requirements.txt         # Python dependencies
37 | ├── pom.xml                  # Maven configuration for Scala
38 | └── .gitignore               # Git ignore patterns
39 | ```
40 | 
41 | ## 🚀 Quick Start
42 | 
43 | ### Python (TensorFlow) Implementation
44 | 
45 | 1. **Install dependencies:**
46 | ```bash
47 | pip install -r requirements.txt
48 | ```
49 | 
50 | 2. **Train the model** (from `src/python`, run `python data/data2tfrecords.py` first to generate the synthetic training TFRecords):
51 | ```bash
52 | cd src/python
53 | python models/dnn.py
54 | ```
55 | 
56 | 3. **Run examples:**
57 | ```bash
58 | python examples/example1.py
59 | python examples/example2.py
60 | ```
61 | 
62 | ### Scala (Spark) Implementation
63 | 
64 | 1. **Build the project:**
65 | ```bash
66 | mvn clean compile
67 | ```
68 | 
69 | 2.
**Run examples:** 70 | ```bash 71 | mvn exec:java -Dexec.mainClass="example.Example1" 72 | ``` 73 | 74 | ## 📚 Key Components 75 | 76 | ### Python Implementation 77 | - **`models/dnn.py`**: Core deep neural network model 78 | - **`models/load_dnn_model.py`**: Model loading and inference utilities 79 | - **`data/data2tfrecords.py`**: Data conversion to TFRecords format 80 | - **`utils/tensor_board.py`**: TensorBoard integration for training visualization 81 | 82 | ### Scala Implementation 83 | - **`models/FeatureBuilder.scala`**: Feature engineering utilities 84 | - **`prediction/PredictUserVector.scala`**: User vector prediction 85 | - **`prediction/ItemEmbeddingPredictor.scala`**: Item embedding generation 86 | - **`core/BaseSparkLocal.scala`**: Local Spark application base class 87 | 88 | ## 🔧 Configuration 89 | 90 | - **Python**: Configure via `requirements.txt` and environment variables 91 | - **Scala**: Configure via `src/scala/config/` properties files 92 | - **Data**: Store training data in `data/tfrecords/` directory 93 | - **Models**: Save checkpoints in `data/checkpoints/` directory 94 | 95 | ## 📖 Documentation 96 | 97 | - **Research Paper**: `docs/Deep Neural Networks for YouTube Recommendations.pdf` 98 | - **Code Examples**: See `src/python/examples/` and `src/scala/examples/` 99 | - **Reference Implementations**: Check `src/python/reference/` for custom components 100 | 101 | ## 🤝 Contributing 102 | 103 | 1. Follow the established directory structure 104 | 2. Add tests in the appropriate `tests/` subdirectory 105 | 3. Update documentation for any new features 106 | 4. Ensure both Python and Scala implementations remain consistent 107 | 108 | ## 📄 License 109 | 110 | This project implements the research described in "Deep Neural Networks for YouTube Recommendations" paper. Please refer to the original paper for academic citations and research context. 
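The two implementations meet at the serving-time retrieval step described in the paper: a user vector (produced by `models/dnn.py` or the Scala `PredictUserVector` jobs) is scored against the item-embedding matrix (extracted by `ItemEmbedding.scala`), and the highest-scoring items become the candidates. The sketch below shows that step with plain NumPy; it is illustrative only, is not a file in this repository, and uses random stand-in vectors in place of exported ones.

```python
import numpy as np

embedding_size, item_num, top_k = 16, 500, 10  # embedding_size and item_num match models/dnn.py; top_k is arbitrary here

user_vector = np.random.rand(embedding_size).astype(np.float32)                # stand-in user vector
item_embeddings = np.random.rand(item_num, embedding_size).astype(np.float32)  # stand-in item-embedding matrix

scores = item_embeddings @ user_vector     # one dot-product score per item
top_items = np.argsort(-scores)[:top_k]    # indices of the top_k highest-scoring items
print(top_items)
```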
111 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | algorithm.cedarmo 8 | deep_neural_networks_for_youtube_recommendations 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 2.11 13 | ${scala.binary.version}.8 14 | 2.1.0 15 | compile 16 | 17 | 18 | 19 | org.scala-lang 20 | scala-library 21 | ${scala.version} 22 | 23 | 24 | org.scala-lang 25 | scala-compiler 26 | ${scala.version} 27 | 28 | 29 | commons-lang 30 | commons-lang 31 | 2.5 32 | 33 | 34 | 35 | org.apache.hadoop 36 | hadoop-auth 37 | 2.4.0 38 | 39 | 40 | commons-configuration 41 | commons-configuration 42 | 1.9 43 | 44 | 45 | org.apache.spark 46 | spark-core_${scala.binary.version} 47 | ${spark.version} 48 | 49 | 50 | 51 | org.apache.spark 52 | spark-sql_${scala.binary.version} 53 | ${spark.version} 54 | 55 | 56 | 57 | org.apache.spark 58 | spark-mllib_2.11 59 | ${spark.version} 60 | 61 | 62 | 63 | org.scalanlp 64 | breeze_2.11 65 | 0.13.2 66 | 67 | 68 | org.tensorflow 69 | spark-tensorflow-connector_2.11 70 | 1.13.1 71 | 72 | 73 | org.tensorflow 74 | tensorflow 75 | 1.13.1 76 | 77 | 78 | 79 | 80 | 81 | sit 82 | 83 | true 84 | 85 | 86 | 87 | ../${project.artifactId}/vars/vars.sit.properties 88 | 89 | 90 | 91 | src/main/resources 92 | true 93 | 94 | 95 | 96 | 97 | 98 | prod 99 | 100 | 101 | ../${project.artifactId}/vars/vars.prod.properties 102 | 103 | 104 | 105 | src/main/resources 106 | true 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | org.codehaus.mojo 117 | build-helper-maven-plugin 118 | 1.8 119 | 120 | 121 | add-source 122 | generate-sources 123 | 124 | add-source 125 | 126 | 127 | 128 | src/main/scala 129 | src/test/scala 130 | 131 | 132 | 133 | 134 | add-test-source 135 | generate-sources 136 | 137 | add-test-source 138 | 139 | 140 | 141 | src/test/scala 142 | 143 | 144 | 145 | 146 | 147 | 148 | net.alchim31.maven 149 | scala-maven-plugin 150 | 3.1.5 151 | 152 | 153 | compile 154 | testCompile 155 | 156 | 157 | 158 | ${scala.version} 159 | 160 | 161 | 162 | org.apache.maven.plugins 163 | maven-compiler-plugin 164 | 165 | 1.7 166 | 1.7 167 | utf-8 168 | 169 | 170 | 171 | compile 172 | 173 | compile 174 | 175 | 176 | 177 | 178 | 179 | maven-assembly-plugin 180 | 181 | 182 | jar-with-dependencies 183 | 184 | 185 | 186 | example.Example1 187 | 188 | 189 | 190 | 191 | 192 | make-assembly 193 | package 194 | 195 | assembly 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /src/python/models/dnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import json 5 | import tensorflow as tf 6 | 7 | def parse_fn(example): 8 | example_fmt = { 9 | "visit_items_index": tf.FixedLenFeature([5], tf.int64), 10 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32), 11 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64) 12 | } 13 | parsed = tf.parse_single_example(example, example_fmt) 14 | next_visit_item_index = parsed.pop("next_visit_item_index") 15 | return parsed, next_visit_item_index 16 | 17 | def input_fn(path, parallel_num, epoch_num, batch_size): 18 | files = tf.data.Dataset.list_files(path, shuffle=True) 19 | data_set = files.apply( 20 | tf.contrib.data.parallel_interleave( 21 | map_func=lambda filename: tf.data.TFRecordDataset(filename), 22 | cycle_length=parallel_num)) 23 | data_set = 
data_set.repeat(epoch_num) 24 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=parallel_num) 25 | data_set = data_set.prefetch(buffer_size=256) 26 | data_set = data_set.batch(batch_size=batch_size) 27 | return data_set 28 | 29 | def model_fn(features, labels, mode, params, config): 30 | 31 | visit_items_index = features["visit_items_index"] # num * 5 32 | continuous_features_value = features["continuous_features_value"] # num * 16 33 | next_visit_item_index = labels # num 34 | keep_prob = params["keep_prob"] 35 | embedding_size = params["embedding_size"] 36 | item_num = params["item_num"] 37 | learning_rate = params["learning_rate"] 38 | top_k = params["top_k"] 39 | 40 | # items embedding 初始化 41 | initializer = tf.initializers.random_uniform(minval=-0.5 / embedding_size, maxval=0.5 / embedding_size) 42 | partitioner = tf.fixed_size_partitioner(num_shards=embedding_size) 43 | item_embedding = tf.get_variable("item_embedding", [item_num, embedding_size], 44 | tf.float32, initializer=initializer, partitioner=partitioner) 45 | 46 | visit_items_embedding = tf.nn.embedding_lookup(item_embedding, visit_items_index) # num * 5 * embedding_size 47 | visit_items_average_embedding = tf.reduce_mean(visit_items_embedding, axis=1) # num * embedding_size 48 | input_embedding = tf.concat([visit_items_average_embedding, continuous_features_value], 1) # num * (embedding_size + 16) 49 | kernel_initializer_1 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 50 | bias_initializer_1 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 51 | layer_1 = tf.layers.dense(input_embedding, 64, activation=tf.nn.relu, 52 | kernel_initializer=kernel_initializer_1, 53 | bias_initializer=bias_initializer_1, name="layer_1") 54 | layer_dropout_1 = tf.nn.dropout(layer_1, keep_prob=keep_prob, name="layer_dropout_1") 55 | kernel_initializer_2 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 56 | bias_initializer_2 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 57 | layer_2 = tf.layers.dense(layer_dropout_1, 32, activation=tf.nn.relu, 58 | kernel_initializer=kernel_initializer_2, 59 | bias_initializer=bias_initializer_2, name="layer_2") 60 | layer_dropout_2 = tf.nn.dropout(layer_2, keep_prob=keep_prob, name="layer_dropout_2") 61 | # user vector, num * embedding_size 62 | kernel_initializer_3 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 63 | bias_initializer_3 = tf.initializers.random_normal(mean=0.0, stddev=0.1) 64 | user_vector = tf.layers.dense(layer_dropout_2, embedding_size, activation=tf.nn.relu, 65 | kernel_initializer=kernel_initializer_3, 66 | bias_initializer=bias_initializer_3, name="user_vector") 67 | 68 | if mode == tf.estimator.ModeKeys.TRAIN: 69 | # 训练 70 | output_embedding = tf.nn.embedding_lookup(item_embedding, next_visit_item_index) # num * embedding_size 71 | logits = tf.matmul(user_vector, output_embedding, transpose_a=False, transpose_b=True) # num * num 72 | yhat = tf.nn.softmax(logits) # num * num 73 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16)) 74 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 75 | train = optimizer.minimize(cross_entropy, global_step=tf.train.get_global_step()) 76 | return tf.estimator.EstimatorSpec(mode, loss=cross_entropy, train_op=train) 77 | 78 | if mode == tf.estimator.ModeKeys.EVAL: 79 | # 评估 80 | output_embedding = tf.nn.embedding_lookup(item_embedding, next_visit_item_index) # num * embedding_size 81 | logits = tf.matmul(user_vector, output_embedding, transpose_a=False, transpose_b=True) # num * 
num 82 | yhat = tf.nn.softmax(logits) # num * num 83 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16)) 84 | return tf.estimator.EstimatorSpec(mode, loss=cross_entropy) 85 | 86 | if mode == tf.estimator.ModeKeys.PREDICT: 87 | logits_predict = tf.matmul(user_vector, item_embedding, transpose_a=False, transpose_b=True) # num * item_num 88 | yhat_predict = tf.nn.softmax(logits_predict) # num * item_num 89 | _, indices = tf.nn.top_k(yhat_predict, k=top_k, sorted=True) 90 | index = tf.identity(indices, name="index") # num * top_k 91 | # 预测 92 | predictions = { 93 | "user_vector": user_vector, 94 | "index": index 95 | } 96 | export_outputs = { 97 | "prediction": tf.estimator.export.PredictOutput(predictions) 98 | } 99 | return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs) 100 | 101 | def build_estimator(): 102 | params = {"keep_prob": 0.5, "embedding_size": 16, "item_num": 500, "learning_rate": 0.05, "top_k": 2} 103 | session_config = tf.ConfigProto(device_count={"CPU": 1}, allow_soft_placement=True, log_device_placement=False) 104 | session_config.gpu_options.allow_growth = True 105 | config = tf.estimator.RunConfig( 106 | model_dir="../../data/checkpoints/ckpt", 107 | tf_random_seed=2019, 108 | save_checkpoints_steps=100, 109 | session_config=session_config, 110 | keep_checkpoint_max=5, 111 | log_step_count_steps=100 112 | ) 113 | estimator = tf.estimator.Estimator(model_fn=model_fn, config=config, params=params) 114 | return estimator 115 | 116 | def set_dist_env(): 117 | if FLAGS.is_distributed: 118 | ps_hosts = FLAGS.strps_hosts.split(",") 119 | worker_hosts = FLAGS.strwork_hosts.split(",") 120 | job_name = FLAGS.job_name 121 | task_index = FLAGS.task_index 122 | chief_hosts = worker_hosts[0:1] # get first worker as chief 123 | worker_hosts = worker_hosts[2:] # the rest as worker 124 | 125 | # use #worker=0 as chief 126 | if job_name == "worker" and task_index == 0: 127 | job_name = "chief" 128 | # use #worker=1 as evaluator 129 | if job_name == "worker" and task_index == 1: 130 | job_name = 'evaluator' 131 | task_index = 0 132 | # the others as worker 133 | if job_name == "worker" and task_index > 1: 134 | task_index -= 2 135 | 136 | tf_config = {'cluster': {'chief': chief_hosts, 'worker': worker_hosts, 'ps': ps_hosts}, 137 | 'task': {'type': job_name, 'index': task_index}} 138 | os.environ['TF_CONFIG'] = json.dumps(tf_config) 139 | 140 | def train_eval_save(): 141 | 142 | set_dist_env() 143 | 144 | estimator = build_estimator() 145 | 146 | # 训练 147 | train_spec = tf.estimator.TrainSpec( 148 | input_fn=lambda: input_fn( 149 | path='../../data/tfrecords/train/train.tfrecords', 150 | parallel_num=32, 151 | epoch_num=11, 152 | batch_size=32), 153 | max_steps=1600 154 | ) 155 | # 评估 156 | eval_spec = tf.estimator.EvalSpec( 157 | input_fn=lambda: input_fn( 158 | path='../../data/tfrecords/evaluation/evaluation.tfrecords', 159 | parallel_num=32, 160 | epoch_num=1, 161 | batch_size=32), 162 | steps=15, # 验证集评估多少批数据 163 | start_delay_secs=1, # 在多少秒后 start_delay_secs=1, # 在多少秒后 164 | throttle_secs=20 # evaluate every 20seconds 165 | ) 166 | # 训练和评估 167 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 168 | 169 | # 模型保存 170 | features_spec = { 171 | "visit_items_index": tf.placeholder(tf.int64, shape=[None, 5], name="visit_items_index"), 172 | "continuous_features_value": tf.placeholder(tf.float32, shape=[None, 16], name="continuous_features_value") 173 | } 174 | serving_input_receiver_fn = 
tf.estimator.export.build_raw_serving_input_receiver_fn(features_spec)
175 |     estimator.export_savedmodel(
176 |         "../../data/checkpoints/modelpath",
177 |         serving_input_receiver_fn)
178 | 
179 | def main(_):
180 |     train_eval_save()
181 | 
182 | if __name__ == "__main__":
183 |     tf.logging.set_verbosity(tf.logging.INFO)
184 |     FLAGS = tf.app.flags.FLAGS
185 |     tf.app.flags.DEFINE_boolean("is_distributed", False, "whether to run distributed training")
186 |     tf.app.flags.DEFINE_string("strps_hosts", "localhost:2000", "parameter server hosts")
187 |     tf.app.flags.DEFINE_string("strwork_hosts", "localhost:2100,localhost:2200,localhost:2300,localhost:2400", "worker hosts")
188 |     tf.app.flags.DEFINE_string("job_name", "ps", "job type: ps or worker")
189 |     tf.app.flags.DEFINE_integer("task_index", 0, "task index within the job")
190 |     tf.app.run(main=main)
--------------------------------------------------------------------------------
/docs/architecture.drawio:
--------------------------------------------------------------------------------
1 | (draw.io diagram XML omitted in this export; see docs/architecture.drawio in the repository)
--------------------------------------------------------------------------------
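The least obvious piece of `models/dnn.py` is its training loss: each user vector in a batch is scored against the output embeddings of every label in the same batch, and the diagonal of the resulting softmax matrix holds the probabilities of the true next-visited items. The sketch below re-derives that in-batch softmax loss with NumPy; it is illustrative only and not part of the repository.

```python
import numpy as np

def in_batch_softmax_loss(user_vectors, label_embeddings):
    """Both arguments have shape (batch, embedding_size)."""
    logits = user_vectors @ label_embeddings.T              # (batch, batch) score matrix
    logits = logits - logits.max(axis=1, keepdims=True)     # stabilize the softmax
    yhat = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return float(np.mean(-np.log(np.diag(yhat) + 1e-16)))   # -log probability of the true item

batch, dim = 4, 16
print(in_batch_softmax_loss(np.random.rand(batch, dim), np.random.rand(batch, dim)))
```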