├── src
│   ├── scala
│   │   ├── config
│   │   │   ├── vars.prod.properties
│   │   │   └── vars.sit.properties
│   │   ├── examples
│   │   │   ├── Example1.scala
│   │   │   └── Example2.scala
│   │   ├── models
│   │   │   ├── FeatureBuilder.scala
│   │   │   └── ItemEmbedding.scala
│   │   ├── core
│   │   │   ├── BaseSparkOnline.scala
│   │   │   └── BaseSparkLocal.scala
│   │   ├── prediction
│   │   │   ├── ItemEmbeddingPredictor.scala
│   │   │   ├── PredictUserVector.scala
│   │   │   └── PredictUserVectorTwo.scala
│   │   └── data
│   │       ├── MakeDataOne.scala
│   │       └── MakeDataTwo.scala
│   └── python
│       ├── examples
│       │   ├── example2.py
│       │   └── example1.py
│       ├── data
│       │   ├── tfrecords_methods
│       │   │   ├── tfrecords
│       │   │   │   └── data1.tfrecords
│       │   │   ├── data2tfrecord1.py
│       │   │   ├── read_sparse_tfrecords_2.py
│       │   │   └── read_sparse_tfrecords_1.py
│       │   ├── read_tfrecords.py
│       │   └── data2tfrecords.py
│       ├── utils
│       │   └── tensor_board.py
│       ├── reference
│       │   ├── feature_column.py
│       │   └── self_defined_network_layer.py
│       └── models
│           ├── load_dnn_model.py
│           └── dnn.py
├── requirements.txt
├── data
│   ├── tfrecords
│   │   └── tfrecords
│   │       ├── train
│   │       │   └── train.tfrecords
│   │       └── evaluation
│   │           └── evaluation.tfrecords
│   └── checkpoints
│       └── ckpt
│           ├── events.out.tfevents.1575536459.CNHQ-18076444T
│           ├── eval
│           │   └── events.out.tfevents.1575536462.CNHQ-18076444T
│           └── checkpoint
├── docs
│   ├── Deep Neural Networks for YouTube Recommendations.pdf
│   └── architecture.drawio
├── .gitignore
├── README.md
└── pom.xml
/src/scala/config/vars.prod.properties:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/scala/config/vars.sit.properties:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/python/examples/example2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=2.0.0
2 | numpy>=1.19.0
3 | pandas>=1.3.0
4 | matplotlib>=3.3.0
5 | seaborn>=0.11.0
6 | scikit-learn>=1.0.0
7 |
--------------------------------------------------------------------------------
/data/tfrecords/tfrecords/train/train.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/tfrecords/tfrecords/train/train.tfrecords
--------------------------------------------------------------------------------
/data/tfrecords/tfrecords/evaluation/evaluation.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/tfrecords/tfrecords/evaluation/evaluation.tfrecords
--------------------------------------------------------------------------------
/docs/Deep Neural Networks for YouTube Recommendations.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/docs/Deep Neural Networks for YouTube Recommendations.pdf
--------------------------------------------------------------------------------
/src/python/data/tfrecords_methods/tfrecords/data1.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/src/python/data/tfrecords_methods/tfrecords/data1.tfrecords
--------------------------------------------------------------------------------
/data/checkpoints/ckpt/events.out.tfevents.1575536459.CNHQ-18076444T:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/checkpoints/ckpt/events.out.tfevents.1575536459.CNHQ-18076444T
--------------------------------------------------------------------------------
/data/checkpoints/ckpt/eval/events.out.tfevents.1575536462.CNHQ-18076444T:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/data/checkpoints/ckpt/eval/events.out.tfevents.1575536462.CNHQ-18076444T
--------------------------------------------------------------------------------
/data/checkpoints/ckpt/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt-1600"
2 | all_model_checkpoint_paths: "model.ckpt-1200"
3 | all_model_checkpoint_paths: "model.ckpt-1300"
4 | all_model_checkpoint_paths: "model.ckpt-1400"
5 | all_model_checkpoint_paths: "model.ckpt-1500"
6 | all_model_checkpoint_paths: "model.ckpt-1600"
7 |
--------------------------------------------------------------------------------
/src/scala/examples/Example1.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import sparkapplication.BaseSparkLocal
4 | import scala.collection.mutable
5 |
6 | object Example1 extends BaseSparkLocal {
7 | def main(args:Array[String]):Unit = {
8 | val spark = this.basicSpark
9 | import spark.implicits._
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/python/examples/example1.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import json
5 | import tensorflow as tf
6 |
7 | a = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) #3*2
8 | b = tf.constant([[1, 0], [2, 1], [0, 1]]) #3*2
9 | c = tf.nn.embedding_lookup(a, b)
10 | d = tf.reduce_mean(c, axis=1)
11 | e = tf.concat([d, a], 1)
12 |
13 | with tf.Session() as sess:
14 | print(c)
15 | print(sess.run(c))
16 | print(d)
17 | print(sess.run(d))
18 | print(e)
19 | print(sess.run(e))
20 |
21 |
--------------------------------------------------------------------------------
/src/scala/examples/Example2.scala:
--------------------------------------------------------------------------------
1 | package example
2 |
3 | import sparkapplication.BaseSparkLocal
4 |
5 | object Example2 extends BaseSparkLocal {
6 | def main(args:Array[String]):Unit = {
7 | // val spark = this.basicSpark
8 | // import spark.implicits._
9 |
10 | val gdsVector = "[8534.033203125,-6634.611328125,-20669.0703125,-9483.734375,8790.3935546875,15647.646484375,-15543.39453125,34464.3203125,-1275.48974609375,28998.267578125,2446.0126953125,32628.033203125,1429.67431640625,37169.6640625,1902.3770751953125,-31038.359375]"
11 | val gds123 = gdsVector.replace("[", "").replace("]", "").split(",", -1).map(_.toDouble)
12 | gds123.foreach(println(_))
13 |
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/scala/models/FeatureBuilder.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import org.tensorflow.example._
4 | import org.tensorflow.spark.shaded.com.google.protobuf.ByteString
5 |
6 | object FeatureBuilder {
7 | def s(strings: String*): Feature = {
8 | val b = BytesList.newBuilder
9 | for (s <- strings) {
10 | b.addValue(ByteString.copyFromUtf8(s))
11 | }
12 | Feature.newBuilder.setBytesList(b).build
13 | }
14 |
15 | def f(values: Float*): Feature = {
16 | val b = FloatList.newBuilder
17 | for (v <- values) {
18 | b.addValue(v)
19 | }
20 | Feature.newBuilder.setFloatList(b).build
21 | }
22 |
23 | def i(values: Int*): Feature = {
24 | val b = Int64List.newBuilder
25 | for (v <- values) {
26 | b.addValue(v)
27 | }
28 | Feature.newBuilder.setInt64List(b).build
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/scala/core/BaseSparkOnline.scala:
--------------------------------------------------------------------------------
1 | package sparkapplication
2 |
3 | import org.apache.spark.SparkConf
4 | import org.apache.spark.sql.SparkSession
5 |
6 | trait BaseSparkOnline {
7 | def basicSpark: SparkSession =
8 | SparkSession
9 | .builder
10 | .config(getSparkConf)
11 | .enableHiveSupport()
12 | .getOrCreate()
13 |
14 | def getSparkConf: SparkConf = {
15 | val conf = new SparkConf()
16 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
17 | .set("spark.network.timeout", "6000")
18 | .set("spark.streaming.kafka.maxRatePerPartition", "200000")
19 | .set("spark.streaming.kafka.consumer.poll.ms", "5120")
20 | .set("spark.streaming.concurrentJobs", "5")
21 | .set("spark.sql.crossJoin.enabled", "true")
22 | .set("spark.driver.maxResultSize", "20g")
23 | .set("spark.rpc.message.maxSize", "1000") // 1024 max
24 | }
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/src/scala/core/BaseSparkLocal.scala:
--------------------------------------------------------------------------------
1 | package sparkapplication
2 |
3 | import org.apache.spark.SparkConf
4 | import org.apache.spark.sql.SparkSession
5 |
6 | trait BaseSparkLocal {
7 | // Local mode
8 | def basicSpark: SparkSession =
9 | SparkSession
10 | .builder
11 | .config(getSparkConf)
12 | .master("local[1]")
13 | .getOrCreate()
14 |
15 | def getSparkConf: SparkConf = {
16 | val conf = new SparkConf()
17 | conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
18 | .set("spark.network.timeout", "600")
19 | .set("spark.streaming.kafka.maxRatePerPartition", "200000")
20 | .set("spark.streaming.kafka.consumer.poll.ms", "5120")
21 | .set("spark.streaming.concurrentJobs", "5")
22 | .set("spark.sql.crossJoin.enabled", "true")
23 | .set("spark.driver.maxResultSize", "1g")
24 | .set("spark.rpc.message.maxSize", "1000") // 1024 max
25 | conf
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/python/utils/tensor_board.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import tensorflow as tf
4 | import sys
5 | from tensorflow.python.platform import gfile
6 | from tensorflow.core.protobuf import saved_model_pb2
7 | from tensorflow.python.util import compat
8 |
9 | # After this script finishes, run: tensorboard --logdir ./logdir, then open http://localhost:6006/ in a browser
10 | with tf.Session() as sess:
11 | model_filename ='../../data/checkpoints/modelpath/1575536466/saved_model.pb'
12 | with gfile.FastGFile(model_filename, 'rb') as f:
13 | data = compat.as_bytes(f.read())
14 | sm = saved_model_pb2.SavedModel()
15 | sm.ParseFromString(data)
16 |
17 | if 1 != len(sm.meta_graphs):
18 | print('More than one graph found. Not sure which to write')
19 | sys.exit(1)
20 |
21 | g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
22 | LOGDIR='../../data/checkpoints/logdir'
23 | train_writer = tf.summary.FileWriter(LOGDIR)
24 | train_writer.add_graph(sess.graph)
25 | train_writer.flush()
26 | train_writer.close()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | MANIFEST
23 |
24 | # Jupyter Notebook
25 | .ipynb_checkpoints
26 |
27 | # pyenv
28 | .python-version
29 |
30 | # Environments
31 | .env
32 | .venv
33 | env/
34 | venv/
35 | ENV/
36 | env.bak/
37 | venv.bak/
38 |
39 | # IDE
40 | .vscode/
41 | .idea/
42 | *.swp
43 | *.swo
44 | *~
45 |
46 | # OS
47 | .DS_Store
48 | .DS_Store?
49 | ._*
50 | .Spotlight-V100
51 | .Trashes
52 | ehthumbs.db
53 | Thumbs.db
54 |
55 | # TensorFlow
56 | *.ckpt
57 | *.meta
58 | *.data-00000-of-00001
59 | *.index
60 | *.pb
61 | *.pbtxt
62 |
63 | # Logs
64 | *.log
65 | logdir/
66 | tensorboard_logs/
67 |
68 | # Data
69 |
70 |
71 |
72 | # Scala/Java
73 | target/
74 | *.class
75 | *.jar
76 | *.war
77 | *.ear
78 | *.zip
79 | *.tar.gz
80 | *.rar
81 | hs_err_pid*
82 |
83 | # Maven
84 | .mvn/
85 | mvnw
86 | mvnw.cmd
87 |
--------------------------------------------------------------------------------
/src/python/data/tfrecords_methods/data2tfrecord1.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | if __name__ == "__main__":
5 | per_item_sample_num = 20
6 | item_num = 15
7 | embedding_size = 8
8 | filename = "../../data/tfrecords_methods/tfrecords/data1.tfrecords"
9 | writer = tf.python_io.TFRecordWriter(filename)
10 | for i in range(per_item_sample_num):
11 | for j in range(item_num):
12 | embedding_average = np.random.uniform(low=j, high=j + 1.0, size=[embedding_size])
13 | index = j
14 | example = tf.train.Example(features=tf.train.Features(feature={
15 | "embedding_average": tf.train.Feature(float_list=tf.train.FloatList(value=embedding_average)),
16 | "index": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
17 | "value": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
18 | "size": tf.train.Feature(int64_list=tf.train.Int64List(value=[item_num]))
19 | }))
20 | writer.write(example.SerializeToString())
21 | writer.close()
--------------------------------------------------------------------------------
/src/python/data/read_tfrecords.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def parse_fn(example):
4 | example_fmt = {
5 | "visit_items_index": tf.FixedLenFeature([5], tf.int64),
6 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32),
7 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64)
8 | }
9 | parsed = tf.parse_single_example(example, example_fmt)
10 | next_visit_item_index = parsed.pop("next_visit_item_index")
11 | return parsed, next_visit_item_index
12 |
13 | if __name__ == "__main__":
14 | files = tf.data.Dataset.list_files('../../data/tfrecords/train/train.tfrecords', shuffle=True)
15 | data_set = files.apply(
16 | tf.contrib.data.parallel_interleave(
17 | lambda filename: tf.data.TFRecordDataset(filename),
18 | cycle_length=16))
19 | data_set = data_set.repeat(1)
20 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=16)
21 | data_set = data_set.prefetch(buffer_size=64)
22 | data_set = data_set.batch(batch_size=16)
23 | iterator = data_set.make_one_shot_iterator()
24 | res1, res2 = iterator.get_next()
25 |
26 | with tf.Session() as sess:
27 | for i in range(5):
28 | result1, result2 = sess.run([res1, res2])
29 | print("Batch {}:".format(i), end=" ")
30 | print("result1:", result1)
31 | print("result2:", result2)
--------------------------------------------------------------------------------
/src/python/data/tfrecords_methods/read_sparse_tfrecords_2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def parse_fn(example):
4 | example_fmt = {
5 | "embedding_average": tf.FixedLenFeature([8], tf.float32),
6 | "one_hot": tf.SparseFeature(index_key=["index"],
7 | value_key="value",
8 | dtype=tf.float32,
9 | size=[15]) # size must be hard-coded here; it cannot be passed in as a hyperparameter
10 | }
11 | parsed = tf.parse_single_example(example, example_fmt)
12 | return parsed["embedding_average"], tf.sparse_tensor_to_dense(parsed["one_hot"])
13 |
14 | if __name__ == "__main__":
15 | files = tf.data.Dataset.list_files('../../data/tfrecords_methods/tfrecords/data1.tfrecords', shuffle=True)
16 | data_set = files.apply(
17 | tf.contrib.data.parallel_interleave(
18 | lambda filename: tf.data.TFRecordDataset(filename),
19 | cycle_length=15))
20 | data_set = data_set.repeat(1)
21 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=15)
22 | data_set = data_set.prefetch(buffer_size=30)
23 | data_set = data_set.batch(batch_size=15)
24 | iterator = data_set.make_one_shot_iterator()
25 | embedding, one_hot = iterator.get_next()
26 |
27 | with tf.Session() as sess:
28 | for i in range(5):
29 | embedding_result, one_hot_result = sess.run([embedding, one_hot])
30 | print("Batch {}:".format(i), end=" ")
31 | print("embedding:", embedding_result, end=" ")
32 | print("one_hot:", one_hot_result)
--------------------------------------------------------------------------------
/src/python/data/tfrecords_methods/read_sparse_tfrecords_1.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def parse_fn(example):
4 | example_fmt = {
5 | "embedding_average": tf.FixedLenFeature([8], tf.float32),
6 | "index": tf.FixedLenFeature([], tf.int64),
7 | "value": tf.FixedLenFeature([], tf.float32),
8 | "size": tf.FixedLenFeature([], tf.int64)
9 | }
10 | parsed = tf.parse_single_example(example, example_fmt)
11 | sparse_tensor = tf.SparseTensor([[parsed["index"]]], [parsed["value"]], [parsed["size"]]) # reading a sparse vector this way may not work on every platform
12 | return parsed["embedding_average"], tf.sparse_tensor_to_dense(sparse_tensor)
13 |
14 | if __name__ == "__main__":
15 | files = tf.data.Dataset.list_files('../../data/tfrecords_methods/tfrecords/data1.tfrecords', shuffle=True)
16 | data_set = files.apply(
17 | tf.contrib.data.parallel_interleave(
18 | lambda filename: tf.data.TFRecordDataset(filename),
19 | cycle_length=15))
20 | data_set = data_set.repeat(1)
21 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=15)
22 | data_set = data_set.prefetch(buffer_size=30)
23 | data_set = data_set.batch(batch_size=15)
24 | iterator = data_set.make_one_shot_iterator()
25 | embedding, one_hot = iterator.get_next()
26 |
27 | with tf.Session() as sess:
28 | for i in range(5):
29 | embedding_result, one_hot_result = sess.run([embedding, one_hot])
30 | print("Batch {}:".format(i), end=" ")
31 | print("embedding:", embedding_result, end=" ")
32 | print("one_hot:", one_hot_result)
--------------------------------------------------------------------------------
/src/scala/models/ItemEmbedding.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import org.tensorflow._
4 | import sparkapplication.BaseSparkOnline
5 | import scala.collection.JavaConverters._
6 |
7 | object ItemEmbeddingMakeDataOne extends BaseSparkOnline {
8 | def main(args:Array[String]):Unit = {
9 | val spark = this.basicSpark
10 | import spark.implicits._
11 |
12 | val modelHdfsPath = "hdfs路径"
13 | val modelTag = "serve"
14 |
15 | val embeddingAverageArray = Array(Array.fill[Float](8)(0.1F))
16 | val model = SavedModelBundle.load(modelHdfsPath, modelTag)
17 | val sess = model.session()
18 | val embeddingAverageArrayTensor = Tensor.create(embeddingAverageArray, classOf[java.lang.Float])
19 | val itemEmbeddingResult = getItemEmbedding(sess, embeddingAverageArrayTensor)
20 |
21 | val result = spark.sparkContext.parallelize(itemEmbeddingResult.map(k => k.mkString("@"))).toDF("item_embedding")
22 | result.show(10, false)
23 | }
24 |
25 | private def getItemEmbedding(sess: Session, embeddingAverageArrayTensor: Tensor[_], embeddingAverageArrayName: String = "Placeholder:0", itemEmbeddingName: String = "item_embedding:0") = {
26 | val resultBuffer = sess.runner
27 | .feed(embeddingAverageArrayName, embeddingAverageArrayTensor)
28 | .fetch(itemEmbeddingName)
29 | .run.asScala
30 |
31 | val itemEmbedding = resultBuffer.head
32 | val itemEmbeddingShape: Array[Int] = itemEmbedding.shape.map(_.toInt)
33 | val itemEmbeddingResult = Array.ofDim[Float](itemEmbeddingShape.head, itemEmbeddingShape(1))
34 | itemEmbedding.copyTo(itemEmbeddingResult)
35 |
36 | itemEmbeddingResult
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/scala/prediction/ItemEmbeddingPredictor.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import org.tensorflow._
4 | import org.tensorflow.example._
5 | import explore.FeatureBuilder._
6 | import sparkapplication.BaseSparkOnline
7 | import scala.collection.JavaConverters._
8 |
9 | object ItemEmbeddingMakeDataTwo extends BaseSparkOnline {
10 | def main(args:Array[String]):Unit = {
11 | val spark = this.basicSpark
12 | import spark.implicits._
13 |
14 | val modelHdfsPath = "hdfs路径"
15 | val modelTag = "serve"
16 |
17 | val embeddingAverage = Array.fill[Float](8)(0.1F)
18 | val gender = "male"
19 | val cityCd = "city_cd_100"
20 | val featuresBuilder = Features.newBuilder
21 | .putFeature("embedding_average", f(embeddingAverage:_*))
22 | .putFeature("gender", s(gender))
23 | .putFeature("city_cd", s(cityCd))
24 | featuresBuilder.build()
25 | val features = Example.newBuilder.setFeatures(featuresBuilder).build.toByteArray
26 |
27 | val model = SavedModelBundle.load(modelHdfsPath, modelTag)
28 | val sess = model.session()
29 | val embeddingAverageArrayTensor = Tensor.create(Array(features))
30 | val itemEmbeddingResult = getItemEmbedding(sess, embeddingAverageArrayTensor)
31 |
32 | val result = spark.sparkContext.parallelize(itemEmbeddingResult.map(k => k.mkString("@"))).toDF("item_embedding")
33 | result.show(50, false)
34 | }
35 |
36 | private def getItemEmbedding(sess: Session, featuresArrayTensor: Tensor[_], featuresArrayName: String = "input_example_tensor:0", itemEmbeddingName: String = "item_embedding:0") = {
37 | val resultBuffer = sess.runner
38 | .feed(featuresArrayName, featuresArrayTensor)
39 | .fetch(itemEmbeddingName)
40 | .run.asScala
41 |
42 | val itemEmbedding = resultBuffer.head
43 | val itemEmbeddingShape: Array[Int] = itemEmbedding.shape.map(_.toInt)
44 | val itemEmbeddingResult = Array.ofDim[Float](itemEmbeddingShape.head, itemEmbeddingShape(1))
45 | itemEmbedding.copyTo(itemEmbeddingResult)
46 |
47 | itemEmbeddingResult
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/src/scala/prediction/PredictUserVector.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import org.tensorflow._
4 | import sparkapplication.BaseSparkOnline
5 | import scala.collection.JavaConverters._
6 |
7 | object PredictUserVectorMakeDataOne extends BaseSparkOnline {
8 | def main(args:Array[String]):Unit = {
9 | val spark = this.basicSpark
10 | import spark.implicits._
11 |
12 | val modelHdfsPath = "hdfs路径"
13 | val modelTag = "serve"
14 |
15 | val dataValidation = spark.read.format("tfrecords")
16 | .option("recordType", "Example")
17 | .load("hdfs路径")
18 | .rdd.map{row =>
19 | val embeddingAverage = row.getAs[scala.collection.mutable.WrappedArray[Float]]("embedding_average")
20 | embeddingAverage.toArray
21 | }
22 | println(s"Validation data dataValidation count: ${dataValidation.count}, sample rows:")
23 | dataValidation.toDF("embedding_average").show(5, false)
24 |
25 | val userVectorAll = dataValidation.mapPartitions(lineIterator => {
26 | val embeddingAverageArray = lineIterator.toArray
27 | val model = SavedModelBundle.load(modelHdfsPath, modelTag)
28 | val sess = model.session()
29 | val embeddingAverageArrayTensor = Tensor.create(embeddingAverageArray, classOf[java.lang.Float])
30 | val userVectorResult = predictUserVector(sess, embeddingAverageArrayTensor)
31 | userVectorResult.toIterator
32 | })
33 |
34 | val result = userVectorAll.map(k => k.mkString("@")).toDF("user_vector")
35 | result.show(10, false)
36 | }
37 |
38 | private def predictUserVector(sess: Session, embeddingAverageArrayTensor: Tensor[_], embeddingAverageArrayName: String = "Placeholder:0", userVectorName: String = "user_vector/Relu:0") = {
39 | val resultBuffer = sess.runner
40 | .feed(embeddingAverageArrayName, embeddingAverageArrayTensor)
41 | .fetch(userVectorName)
42 | .run.asScala
43 |
44 | val userVector = resultBuffer.head
45 | val userVectorShape: Array[Int] = userVector.shape.map(_.toInt)
46 | val userVectorResult = Array.ofDim[Float](userVectorShape.head, userVectorShape(1))
47 | userVector.copyTo(userVectorResult)
48 |
49 | userVectorResult
50 | }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/src/python/reference/feature_column.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | member_id = "member_id_{}".format(1)
5 | gds_cd = "gds_cd_{}".format(1)
6 | age = np.random.randint(18, 60)
7 | height = np.random.uniform(170.0, 190.0)
8 | example = tf.train.Example(features=tf.train.Features(feature={
9 | "member_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(member_id)])),
10 | "gds_cd": tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(gds_cd)])),
11 | "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[age])),
12 | "height": tf.train.Feature(float_list=tf.train.FloatList(value=[height]))
13 | }))
14 | serialized_example = example.SerializeToString()
15 |
16 | example_fmt = {
17 | "member_id": tf.FixedLenFeature([1], tf.string),
18 | "gds_cd": tf.FixedLenFeature([1], tf.string),
19 | "age": tf.FixedLenFeature([1], tf.int64),
20 | "height": tf.FixedLenFeature([1], tf.float32)
21 | }
22 | parsed = tf.parse_single_example(serialized_example, example_fmt)
23 |
24 | member_id = tf.feature_column.categorical_column_with_hash_bucket("member_id", hash_bucket_size=3)
25 | gds_cd = tf.feature_column.categorical_column_with_hash_bucket("gds_cd", hash_bucket_size=3)
26 | age = tf.feature_column.categorical_column_with_vocabulary_list("age", [i for i in range(3)], dtype=tf.int64,
27 | default_value=0)
28 | height = tf.feature_column.numeric_column("height")
29 | member_id_indicator = tf.feature_column.indicator_column(member_id)
30 | gds_cd_indicator = tf.feature_column.indicator_column(gds_cd)
31 | age_indicator = tf.feature_column.indicator_column(age)
32 | feature_columns = [member_id_indicator, gds_cd_indicator, age_indicator, height]
33 | _result = tf.feature_column.input_layer(parsed, feature_columns)
34 |
35 | with tf.Session() as sess:
36 | sess.run(tf.global_variables_initializer())
37 | sess.run(tf.tables_initializer())
38 | parsed_result = sess.run([parsed])
39 | print("parsed_result:", parsed_result)
40 | result = sess.run([_result])
41 | print("result:", result)
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/src/python/data/data2tfrecords.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | if __name__ == "__main__":
5 | sample_num = 5000
6 | item_num = 500
7 | sample_set = []
8 | for i in range(sample_num):
9 | visit_items_index = np.random.randint(low=0, high=item_num, size=[5])
10 | continuous_features_value = np.random.uniform(low=-5.0, high=5.0, size=[16])
11 | next_visit_item_index = np.random.randint(low=0, high=item_num)
12 | sample = [visit_items_index, continuous_features_value, next_visit_item_index]
13 | sample_set.append(sample)
14 |
15 | # Training data
16 | filename = "../../data/tfrecords/train/train.tfrecords"
17 | writer = tf.python_io.TFRecordWriter(filename)
18 | for sample in sample_set:
19 | visit_items_index = sample[0]
20 | continuous_features_value = sample[1]
21 | next_visit_item_index = sample[2]
22 | example = tf.train.Example(features=tf.train.Features(feature={
23 | "visit_items_index": tf.train.Feature(int64_list=tf.train.Int64List(value=visit_items_index)),
24 | "continuous_features_value": tf.train.Feature(
25 | float_list=tf.train.FloatList(value=continuous_features_value)),
26 | "next_visit_item_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[next_visit_item_index]))
27 | }))
28 | writer.write(example.SerializeToString())
29 | writer.close()
30 |
31 | # Evaluation data: the samples are generated randomly, so the evaluation set is taken from the training samples
32 | filename = "../../data/tfrecords/evaluation/evaluation.tfrecords"
33 | writer = tf.python_io.TFRecordWriter(filename)
34 | i = 0
35 | for sample in sample_set:
36 | if i % 10 == 0:
37 | visit_items_index = sample[0]
38 | continuous_features_value = sample[1]
39 | next_visit_item_index = sample[2]
40 | example = tf.train.Example(features=tf.train.Features(feature={
41 | "visit_items_index": tf.train.Feature(int64_list=tf.train.Int64List(value=visit_items_index)),
42 | "continuous_features_value": tf.train.Feature(
43 | float_list=tf.train.FloatList(value=continuous_features_value)),
44 | "next_visit_item_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[next_visit_item_index]))
45 | }))
46 | writer.write(example.SerializeToString())
47 | i = i + 1
48 | writer.close()
--------------------------------------------------------------------------------
/src/scala/data/MakeDataOne.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import java.util.Random
4 | import org.apache.spark.sql._
5 | import sparkapplication.BaseSparkOnline
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | object MakeDataOne extends BaseSparkOnline {
9 | def main(args:Array[String]):Unit = {
10 | val spark = this.basicSpark
11 | import spark.implicits._
12 |
13 | // Training data
14 | var perItemSampleNum = 20
15 | var itemNum = 15
16 | var embeddingSize = 8
17 | val trainData = ArrayBuffer[(Array[Double], Long, Double, Long)]()
18 | for(i <- 0 until perItemSampleNum) {
19 | for(j <- 0 until itemNum){
20 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble())
21 | trainData.append((embeddingAverage, j.toLong, 1.0, itemNum.toLong))
22 | }
23 | }
24 | val trainDataFrame = spark.sparkContext.parallelize(trainData, 10).toDF("embedding_average", "index", "value", "size")
25 |
26 | // Save DataFrame as TFRecords
27 | trainDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径")
28 |
29 | // Read TFRecords into DataFrame.
30 | val trainDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径")
31 | println("trainDataFrame reloaded from TFRecords; sample rows:")
32 | trainDataTfrecords.show(10, false)
33 |
34 | // Evaluation data
35 | perItemSampleNum = 10
36 | itemNum = 15
37 | embeddingSize = 8
38 | val evaluationData = ArrayBuffer[(Array[Double], Long, Double, Long)]()
39 | for(i <- 0 until perItemSampleNum) {
40 | for(j <- 0 until itemNum){
41 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble())
42 | evaluationData.append((embeddingAverage, j.toLong, 1.0, itemNum.toLong))
43 | }
44 | }
45 | val evaluationDataFrame = spark.sparkContext.parallelize(evaluationData, 10).toDF("embedding_average", "index", "value", "size")
46 |
47 | // Save DataFrame as TFRecords
48 | evaluationDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径")
49 |
50 | // Read TFRecords into DataFrame.
51 | val evaluationDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径")
52 | println("evaluationData reloaded from TFRecords; sample rows:")
53 | evaluationDataTfrecords.show(10, false)
54 |
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/scala/prediction/PredictUserVectorTwo.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import org.tensorflow._
4 | import org.tensorflow.example._
5 | import explore.FeatureBuilder._
6 | import sparkapplication.BaseSparkOnline
7 | import scala.collection.JavaConverters._
8 |
9 | object PredictUserVectorMakeDataTwo extends BaseSparkOnline {
10 | def main(args:Array[String]):Unit = {
11 | val spark = this.basicSpark
12 | import spark.implicits._
13 |
14 | val modelHdfsPath = "hdfs路径"
15 | val modelTag = "serve"
16 |
17 | val dataValidation = spark.read.format("tfrecords")
18 | .option("recordType", "Example")
19 | .load("hdfs路径")
20 | .rdd.map{row =>
21 | val embeddingAverage = row.getAs[scala.collection.mutable.WrappedArray[Float]]("embedding_average").toArray
22 | val gender = row.getAs[String]("gender")
23 | val cityCd = row.getAs[String]("city_cd")
24 | val featuresBuilder = Features.newBuilder
25 | .putFeature("embedding_average", f(embeddingAverage:_*))
26 | .putFeature("gender", s(gender))
27 | .putFeature("city_cd", s(cityCd))
28 | featuresBuilder.build()
29 | val features = Example.newBuilder.setFeatures(featuresBuilder).build.toByteArray
30 | features
31 | }
32 | println(s"Validation data dataValidation count: ${dataValidation.count}, sample rows:")
33 | dataValidation.toDF("features").show(5, false)
34 |
35 | val userVectorAll = dataValidation.mapPartitions(lineIterator => {
36 | val featuresArray = lineIterator.toArray
37 | val model = SavedModelBundle.load(modelHdfsPath, modelTag)
38 | val sess = model.session()
39 | val featuresArrayTensor = Tensor.create(featuresArray)
40 | val userVectorResult = predictUserVector(sess, featuresArrayTensor)
41 | userVectorResult.toIterator
42 | })
43 |
44 | val result = userVectorAll.map(k => k.mkString("@")).toDF("user_vector")
45 | result.show(10, false)
46 | }
47 |
48 | private def predictUserVector(sess: Session, featuresArrayTensor: Tensor[_], featuresArrayName: String = "input_example_tensor:0", userVectorName: String = "user_vector/Relu:0") = {
49 | val resultBuffer = sess.runner
50 | .feed(featuresArrayName, featuresArrayTensor)
51 | .fetch(userVectorName)
52 | .run.asScala
53 |
54 | val userVector = resultBuffer.head
55 | val userVectorShape: Array[Int] = userVector.shape.map(_.toInt)
56 | val userVectorResult = Array.ofDim[Float](userVectorShape.head, userVectorShape(1))
57 | userVector.copyTo(userVectorResult)
58 |
59 | userVectorResult
60 | }
61 |
62 | }
63 |
--------------------------------------------------------------------------------
/src/scala/data/MakeDataTwo.scala:
--------------------------------------------------------------------------------
1 | package explore
2 |
3 | import java.util.Random
4 | import org.apache.spark.sql._
5 | import sparkapplication.BaseSparkOnline
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | object MakeDataTwo extends BaseSparkOnline {
9 | def main(args:Array[String]):Unit = {
10 | val spark = this.basicSpark
11 | import spark.implicits._
12 |
13 | // Training data
14 | var perItemSampleNum = 20
15 | var itemNum = 15
16 | var embeddingSize = 8
17 | val trainData = ArrayBuffer[(Array[Double], String, String, Long)]()
18 | for(i <- 0 until perItemSampleNum) {
19 | for(j <- 0 until itemNum){
20 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble())
21 | val gender = if(j % 2 == 0) "male" else "female"
22 | val cityCd = "city_cd_" + (new Random).nextInt(200).toString
23 | trainData.append((embeddingAverage, gender, cityCd, j.toLong))
24 | }
25 | }
26 | val trainDataFrame = spark.sparkContext.parallelize(trainData, 10).toDF("embedding_average", "gender", "city_cd", "index")
27 |
28 | // Save DataFrame as TFRecords
29 | trainDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径")
30 |
31 | // Read TFRecords into DataFrame.
32 | val trainDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径")
33 | println("trainDataFrame reloaded from TFRecords; sample rows:")
34 | trainDataTfrecords.show(10, false)
35 |
36 | // Evaluation data
37 | perItemSampleNum = 10
38 | itemNum = 15
39 | embeddingSize = 8
40 | val evaluationData = ArrayBuffer[(Array[Double], String, String, Long)]()
41 | for(i <- 0 until perItemSampleNum) {
42 | for(j <- 0 until itemNum){
43 | val embeddingAverage = Array.fill[Double](embeddingSize)(1.0*j + (new Random).nextDouble())
44 | val gender = if(j % 2 == 0) "male" else "female"
45 | val cityCd = "city_cd_" + (new Random).nextInt(200).toString
46 | evaluationData.append((embeddingAverage, gender, cityCd, j.toLong))
47 | }
48 | }
49 | val evaluationDataFrame = spark.sparkContext.parallelize(evaluationData, 10).toDF("embedding_average", "gender", "city_cd", "index")
50 |
51 | // Save DataFrame as TFRecords
52 | evaluationDataFrame.write.mode(SaveMode.Overwrite).format("tfrecords").option("recordType", "Example").save("hdfs路径")
53 |
54 | // Read TFRecords into DataFrame.
55 | val evaluationDataTfrecords: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load("hdfs路径")
56 | println("evaluationData reloaded from TFRecords; sample rows:")
57 | evaluationDataTfrecords.show(10, false)
58 |
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/src/python/models/load_dnn_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import tensorflow as tf
4 |
5 | class dataProcess(object):
6 |
7 | def parse_fn(self, example):
8 | example_fmt = {
9 | "visit_items_index": tf.FixedLenFeature([5], tf.int64),
10 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32),
11 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64)
12 | }
13 | parsed = tf.parse_single_example(example, example_fmt)
14 | parsed.pop("next_visit_item_index")
15 | return parsed
16 |
17 | def next_batch(self, batch_size):
18 | files = tf.data.Dataset.list_files(
19 | '../../data/tfrecords/train/train.tfrecords', shuffle=False
20 | )
21 | data_set = files.apply(
22 | tf.contrib.data.parallel_interleave(
23 | lambda filename: tf.data.TFRecordDataset(filename),
24 | cycle_length=16))
25 | data_set = data_set.map(map_func=self.parse_fn, num_parallel_calls=16)
26 | data_set = data_set.prefetch(buffer_size=256)
27 | data_set = data_set.batch(batch_size=batch_size)
28 | iterator = data_set.make_one_shot_iterator()
29 | features = iterator.get_next()
30 | return features
31 |
32 | if __name__ == "__main__":
33 | # Data preprocessing
34 | dataProcess = dataProcess()
35 | features = dataProcess.next_batch(batch_size=16)
36 |
37 | signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
38 | with tf.Session() as sess:
39 | meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
40 | "../../data/checkpoints/modelpath/1575536466")
41 | print(meta_graph_def)
42 | signature = meta_graph_def.signature_def
43 | visit_items_index_tensor_name = signature[signature_key].inputs["visit_items_index"].name
44 | visit_items_index_tensor = sess.graph.get_tensor_by_name(visit_items_index_tensor_name)
45 | continuous_features_value_tensor_name = signature[signature_key].inputs["continuous_features_value"].name
46 | continuous_features_value_tensor = sess.graph.get_tensor_by_name(continuous_features_value_tensor_name)
47 | user_vector_tensor_name = signature[signature_key].outputs["user_vector"].name
48 | user_vector_tensor = sess.graph.get_tensor_by_name(user_vector_tensor_name)
49 | index_tensor_name = signature[signature_key].outputs["index"].name
50 | index_tensor = sess.graph.get_tensor_by_name(index_tensor_name)
51 |
52 | features_result = sess.run(features)
53 | feed_dict = {visit_items_index_tensor: features_result["visit_items_index"], continuous_features_value_tensor: features_result["continuous_features_value"]}
54 | predict_outputs = sess.run([user_vector_tensor, index_tensor], feed_dict=feed_dict)
55 | print(predict_outputs[0])
56 | print("==========")
57 | print(predict_outputs[1])
--------------------------------------------------------------------------------
/src/python/reference/self_defined_network_layer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.python.keras import initializers
6 |
7 | class SampleLayer(tf.keras.layers.Layer):
8 | def __init__(self, is_training, top_k, item_num,
9 | kernel_initializer=tf.initializers.random_uniform(minval=-0.1, maxval=0.1), **kwargs):
10 | self.is_training = is_training
11 | self.top_k = top_k
12 | self.item_num = item_num
13 | self.kernel_initializer = kernel_initializer
14 | super(SampleLayer, self).__init__(**kwargs)
15 |
16 | def build(self, input_shape):
17 | assert isinstance(input_shape, list)
18 | input_shape0 = input_shape[0]
19 | # Create a trainable weight for this layer
20 | partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=int(input_shape0[1]))
21 | self.kernel = self.add_weight(name="item_embedding",
22 | shape=(self.item_num, int(input_shape0[1])),
23 | initializer=self.kernel_initializer,
24 | trainable=True,
25 | partitioner=partitioner)
26 | # Be sure to call this at the end
27 | super(SampleLayer, self).build(input_shape)
28 |
29 | def train_output(self, inputs0, inputs1):
30 | output_embedding = tf.nn.embedding_lookup(self.kernel, inputs1) # num * embedding_size
31 | logits = tf.matmul(inputs0, output_embedding, transpose_a=False, transpose_b=True) # num * num
32 | yhat = tf.nn.softmax(logits) # num * num
33 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16))
34 | return cross_entropy
35 |
36 | def predict_output(self, inputs0):
37 | logits_predict = tf.matmul(inputs0, self.kernel, transpose_a=False, transpose_b=True) # num * item_num
38 | yhat_predict = tf.nn.softmax(logits_predict) # num * item_num
39 | _, indices = tf.nn.top_k(yhat_predict, k=self.top_k, sorted=True) # indices: num * top_k
40 | indices = tf.cast(indices, tf.float32) # tf.keras.backend.switch requires both outputs to have the same dtype, so cast int to float
41 | return indices
42 |
43 | def func1(self, inputs):
44 | assert len(inputs) == 2
45 | inputs1 = tf.cast(inputs[1], tf.int32)
46 | return inputs1
47 |
48 | def call(self, inputs, **kwargs):
49 | assert isinstance(inputs, list)
50 | inputs0 = inputs[0] # output of the previous layer
51 | inputs1_default = tf.zeros([inputs0.shape[0]], dtype=tf.int32) # the other input; this is its default value
52 | inputs1 = tf.cond(self.is_training, lambda: self.func1(inputs), lambda: inputs1_default)
53 | # In training mode the output is the loss; in prediction mode it is the indices of the top_k most similar items
54 | train_predict_output = tf.cond(self.is_training, lambda: self.train_output(inputs0, inputs1),
55 | lambda: self.predict_output(inputs0))
56 | return train_predict_output
57 |
58 | def func2(self, input_shape):
59 | input_shape0 = input_shape[0]
60 | return (input_shape0[0], self.top_k)
61 |
62 | def compute_output_shape(self, input_shape):
63 | output_shape = tf.cond(self.is_training, lambda: (), lambda: self.func2(input_shape))
64 | return output_shape
65 |
66 | def get_config(self):
67 | config = {
68 | 'is_training': self.is_training,
69 | 'top_k': self.top_k,
70 | 'item_num': self.item_num,
71 | 'kernel_initializer': initializers.serialize(self.kernel_initializer)
72 | }
73 | base_config = super(SampleLayer, self).get_config()
74 | return dict(list(base_config.items()) + list(config.items()))
75 |
76 | if __name__ == "__main__":
77 | inputs0 = tf.constant([[0.1, 0.2, 0.6, 0.3, 0.5], [0.8, 0.6, 0.9, 0.3, 0.5]])
78 | inputs1 = tf.constant([0, 3])
79 | sample_layer = SampleLayer(tf.constant(True), 3, 10, name="abc")
80 | result = sample_layer([inputs0, inputs1])
81 | print(result)
82 | print(sample_layer.trainable_weights)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YouTube-DNN-RecSys: Deep Neural Networks for YouTube Recommendations
2 |
3 | ## Deep Neural Networks for YouTube Recommendations
4 | [Paper](https://dl.acm.org/doi/pdf/10.1145/2959100.2959190)
5 |
6 | A clean, well-organized implementation of Deep Neural Networks for YouTube Recommendations, featuring both Python (TensorFlow) and Scala (Spark) implementations.
7 |
8 | [DeepWiki](https://deepwiki.com/chenxingqiang/YouTube-DNN-RecSys)
9 |
10 | ## Architecture
11 | The architecture diagram is available in `docs/architecture.drawio`.
12 |
13 | ## 🏗️ Project Structure
14 |
15 | ```
16 | YouTube-DNN-RecSys/
17 | ├── src/
18 | │ ├── python/ # Python implementation using TensorFlow
19 | │ │ ├── models/ # Core DNN model and loading utilities
20 | │ │ ├── data/ # Data processing and TFRecords handling
21 | │ │ ├── utils/ # TensorBoard and utility functions
22 | │ │ ├── examples/ # Usage examples and tutorials
23 | │ │ └── reference/ # Custom layers and feature engineering
24 | │ └── scala/ # Scala implementation using Spark
25 | │ ├── models/ # Feature building and embedding models
26 | │ ├── data/ # Data generation scripts
27 | │ ├── prediction/ # User vector and item embedding prediction
28 | │ ├── core/ # Base Spark application classes
29 | │ ├── examples/ # Spark usage examples
30 | │ └── config/ # Environment configuration files
31 | ├── data/ # Data storage and model artifacts
32 | │ ├── tfrecords/ # Training and evaluation data
33 | │ └── checkpoints/ # Model checkpoints and saved models
34 | ├── tests/ # Test suites for both Python and Scala
35 | ├── docs/ # Research paper and documentation
36 | ├── requirements.txt # Python dependencies
37 | ├── pom.xml # Maven configuration for Scala
38 | └── .gitignore # Git ignore patterns
39 | ```
40 |
41 | ## 🚀 Quick Start
42 |
43 | ### Python (TensorFlow) Implementation
44 |
45 | 1. **Install dependencies:**
46 | ```bash
47 | pip install -r requirements.txt
48 | ```
49 |
50 | 2. **Train the model** (see the training data note after these steps):
51 | ```bash
52 | cd src/python
53 | python models/dnn.py
54 | ```
55 |
56 | 3. **Run examples:**
57 | ```bash
58 | python examples/example1.py
59 | python examples/example2.py
60 | ```
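
**Training data note** (see step 2): `models/dnn.py` expects `data/tfrecords/train/train.tfrecords` and `data/tfrecords/evaluation/evaluation.tfrecords` at the repository root; they can be regenerated with `python data/data2tfrecords.py` run from `src/python`. As a quick sanity check, here is a minimal sketch using the TF 1.x API that the `src/python` code targets (adjust the path if your checkout stores the data under `data/tfrecords/tfrecords/`):

```python
import tensorflow as tf

# Print the feature keys of the first record in the training file.
record_path = "data/tfrecords/train/train.tfrecords"
for serialized in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example.FromString(serialized)
    print(sorted(example.features.feature.keys()))
    # Expected: ['continuous_features_value', 'next_visit_item_index', 'visit_items_index']
    break
```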
61 |
62 | ### Scala (Spark) Implementation
63 |
64 | 1. **Build the project:**
65 | ```bash
66 | mvn clean compile
67 | ```
68 |
69 | 2. **Run examples:**
70 | ```bash
71 | mvn exec:java -Dexec.mainClass="example.Example1"
72 | ```
73 |
74 | ## 📚 Key Components
75 |
76 | ### Python Implementation
77 | - **`models/dnn.py`**: Core deep neural network model
78 | - **`models/load_dnn_model.py`**: Model loading and inference utilities (see the sketch below)
79 | - **`data/data2tfrecords.py`**: Data conversion to TFRecords format
80 | - **`utils/tensor_board.py`**: TensorBoard integration for training visualization
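
As a quick orientation, here is a minimal inference sketch in the spirit of `models/load_dnn_model.py` (TF 1.x API; the export directory name is a timestamp written by `dnn.py`, so the value below is only an example):

```python
import numpy as np
import tensorflow as tf

export_dir = "data/checkpoints/modelpath/1575536466"  # example timestamped export from dnn.py
sig_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

with tf.Session() as sess:
    # Load the exported SavedModel and resolve tensors through its serving signature.
    meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    signature = meta_graph.signature_def[sig_key]
    items = sess.graph.get_tensor_by_name(signature.inputs["visit_items_index"].name)
    cont = sess.graph.get_tensor_by_name(signature.inputs["continuous_features_value"].name)
    user_vector = sess.graph.get_tensor_by_name(signature.outputs["user_vector"].name)
    top_k_index = sess.graph.get_tensor_by_name(signature.outputs["index"].name)

    # One dummy user: 5 previously visited item indices and 16 continuous features.
    feed = {items: np.zeros([1, 5], dtype=np.int64),
            cont: np.zeros([1, 16], dtype=np.float32)}
    vec, idx = sess.run([user_vector, top_k_index], feed_dict=feed)
    print(vec.shape, idx)  # user vector and the top_k recommended item indices
```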
81 |
82 | ### Scala Implementation
83 | - **`models/FeatureBuilder.scala`**: Feature engineering utilities
84 | - **`prediction/PredictUserVector.scala`**: User vector prediction
85 | - **`prediction/ItemEmbeddingPredictor.scala`**: Item embedding generation
86 | - **`core/BaseSparkLocal.scala`**: Local Spark application base class
87 |
88 | ## 🔧 Configuration
89 |
90 | - **Python**: Configure dependencies via `requirements.txt` and environment variables; model hyperparameters are set in code (see below)
91 | - **Scala**: Configure via `src/scala/config/` properties files
92 | - **Data**: Store training data in `data/tfrecords/` directory
93 | - **Models**: Save checkpoints in `data/checkpoints/` directory
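
Model hyperparameters for the Python implementation are not read from a configuration file; they are set in `build_estimator()` inside `src/python/models/dnn.py`. The values currently used there are:

```python
# Hyperparameters from build_estimator() in src/python/models/dnn.py; edit that function to change them.
params = {
    "keep_prob": 0.5,       # dropout keep probability
    "embedding_size": 16,   # dimensionality of the item embeddings and user vector
    "item_num": 500,        # size of the item vocabulary
    "learning_rate": 0.05,  # SGD learning rate
    "top_k": 2,             # number of recommended items returned at predict time
}
```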
94 |
95 | ## 📖 Documentation
96 |
97 | - **Research Paper**: `docs/Deep Neural Networks for YouTube Recommendations.pdf`
98 | - **Code Examples**: See `src/python/examples/` and `src/scala/examples/`
99 | - **Reference Implementations**: Check `src/python/reference/` for custom components
100 |
101 | ## 🤝 Contributing
102 |
103 | 1. Follow the established directory structure
104 | 2. Add tests in the appropriate `tests/` subdirectory
105 | 3. Update documentation for any new features
106 | 4. Ensure both Python and Scala implementations remain consistent
107 |
108 | ## 📄 License
109 |
110 | This project implements the approach described in the paper "Deep Neural Networks for YouTube Recommendations". Please refer to the original paper for academic citation and research context.
111 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 | 
7 |     <groupId>algorithm.cedarmo</groupId>
8 |     <artifactId>deep_neural_networks_for_youtube_recommendations</artifactId>
9 |     <version>1.0-SNAPSHOT</version>
10 | 
11 |     <properties>
12 |         <scala.binary.version>2.11</scala.binary.version>
13 |         <scala.version>${scala.binary.version}.8</scala.version>
14 |         <spark.version>2.1.0</spark.version>
15 |         <!-- scope: compile -->
16 |     </properties>
17 | 
18 |     <dependencies>
19 |         <dependency>
20 |             <groupId>org.scala-lang</groupId>
21 |             <artifactId>scala-library</artifactId>
22 |             <version>${scala.version}</version>
23 |         </dependency>
24 |         <dependency>
25 |             <groupId>org.scala-lang</groupId>
26 |             <artifactId>scala-compiler</artifactId>
27 |             <version>${scala.version}</version>
28 |         </dependency>
29 |         <dependency>
30 |             <groupId>commons-lang</groupId>
31 |             <artifactId>commons-lang</artifactId>
32 |             <version>2.5</version>
33 |         </dependency>
34 |         <dependency>
35 |             <groupId>org.apache.hadoop</groupId>
36 |             <artifactId>hadoop-auth</artifactId>
37 |             <version>2.4.0</version>
38 |         </dependency>
39 |         <dependency>
40 |             <groupId>commons-configuration</groupId>
41 |             <artifactId>commons-configuration</artifactId>
42 |             <version>1.9</version>
43 |         </dependency>
44 |         <dependency>
45 |             <groupId>org.apache.spark</groupId>
46 |             <artifactId>spark-core_${scala.binary.version}</artifactId>
47 |             <version>${spark.version}</version>
48 |         </dependency>
49 |         <dependency>
50 |             <groupId>org.apache.spark</groupId>
51 |             <artifactId>spark-sql_${scala.binary.version}</artifactId>
52 |             <version>${spark.version}</version>
53 |         </dependency>
54 |         <dependency>
55 |             <groupId>org.apache.spark</groupId>
56 |             <artifactId>spark-mllib_2.11</artifactId>
57 |             <version>${spark.version}</version>
58 |         </dependency>
59 |         <dependency>
60 |             <groupId>org.scalanlp</groupId>
61 |             <artifactId>breeze_2.11</artifactId>
62 |             <version>0.13.2</version>
63 |         </dependency>
64 |         <dependency>
65 |             <groupId>org.tensorflow</groupId>
66 |             <artifactId>spark-tensorflow-connector_2.11</artifactId>
67 |             <version>1.13.1</version>
68 |         </dependency>
69 |         <dependency>
70 |             <groupId>org.tensorflow</groupId>
71 |             <artifactId>tensorflow</artifactId>
72 |             <version>1.13.1</version>
73 |         </dependency>
74 |     </dependencies>
75 | 
76 |     <profiles>
77 |         <profile>
78 |             <id>sit</id>
79 |             <activation>
80 |                 <activeByDefault>true</activeByDefault>
81 |             </activation>
82 |             <build>
83 |                 <filters>
84 |                     <filter>../${project.artifactId}/vars/vars.sit.properties</filter>
85 |                 </filters>
86 |                 <resources>
87 |                     <resource>
88 |                         <directory>src/main/resources</directory>
89 |                         <filtering>true</filtering>
90 |                     </resource>
91 |                 </resources>
92 |             </build>
93 |         </profile>
94 |         <profile>
95 |             <id>prod</id>
96 |             <build>
97 |                 <filters>
98 |                     <filter>../${project.artifactId}/vars/vars.prod.properties</filter>
99 |                 </filters>
100 |                 <resources>
101 |                     <resource>
102 |                         <directory>src/main/resources</directory>
103 |                         <filtering>true</filtering>
104 |                     </resource>
105 |                 </resources>
106 |             </build>
107 |         </profile>
108 |     </profiles>
109 | 
110 |     <build>
111 |         <plugins>
112 |             <plugin>
113 |                 <groupId>org.codehaus.mojo</groupId>
114 |                 <artifactId>build-helper-maven-plugin</artifactId>
115 |                 <version>1.8</version>
116 |                 <executions>
117 |                     <execution>
118 |                         <id>add-source</id>
119 |                         <phase>generate-sources</phase>
120 |                         <goals>
121 |                             <goal>add-source</goal>
122 |                         </goals>
123 |                         <configuration>
124 |                             <sources>
125 |                                 <source>src/main/scala</source>
126 |                                 <source>src/test/scala</source>
127 |                             </sources>
128 |                         </configuration>
129 |                     </execution>
130 |                     <execution>
131 |                         <id>add-test-source</id>
132 |                         <phase>generate-sources</phase>
133 |                         <goals>
134 |                             <goal>add-test-source</goal>
135 |                         </goals>
136 |                         <configuration>
137 |                             <sources>
138 |                                 <source>src/test/scala</source>
139 |                             </sources>
140 |                         </configuration>
141 |                     </execution>
142 |                 </executions>
143 |             </plugin>
144 |             <plugin>
145 |                 <groupId>net.alchim31.maven</groupId>
146 |                 <artifactId>scala-maven-plugin</artifactId>
147 |                 <version>3.1.5</version>
148 |                 <executions>
149 |                     <execution>
150 |                         <goals>
151 |                             <goal>compile</goal>
152 |                             <goal>testCompile</goal>
153 |                         </goals>
154 |                     </execution>
155 |                 </executions>
156 |                 <configuration>
157 |                     <scalaVersion>${scala.version}</scalaVersion>
158 |                 </configuration>
159 |             </plugin>
160 |             <plugin>
161 |                 <groupId>org.apache.maven.plugins</groupId>
162 |                 <artifactId>maven-compiler-plugin</artifactId>
163 |                 <configuration>
164 |                     <source>1.7</source>
165 |                     <target>1.7</target>
166 |                     <encoding>utf-8</encoding>
167 |                 </configuration>
168 |                 <executions>
169 |                     <execution>
170 |                         <phase>compile</phase>
171 |                         <goals>
172 |                             <goal>compile</goal>
173 |                         </goals>
174 |                     </execution>
175 |                 </executions>
176 |             </plugin>
177 |             <plugin>
178 |                 <artifactId>maven-assembly-plugin</artifactId>
179 |                 <configuration>
180 |                     <descriptorRefs>
181 |                         <descriptorRef>jar-with-dependencies</descriptorRef>
182 |                     </descriptorRefs>
183 |                     <archive>
184 |                         <manifest>
185 |                             <mainClass>example.Example1</mainClass>
186 |                         </manifest>
187 |                     </archive>
188 |                 </configuration>
189 |                 <executions>
190 |                     <execution>
191 |                         <id>make-assembly</id>
192 |                         <phase>package</phase>
193 |                         <goals>
194 |                             <goal>assembly</goal>
195 |                         </goals>
196 |                     </execution>
197 |                 </executions>
198 |             </plugin>
199 |         </plugins>
200 |     </build>
201 | </project>
--------------------------------------------------------------------------------
/src/python/models/dnn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import json
5 | import tensorflow as tf
6 |
7 | def parse_fn(example):
8 | example_fmt = {
9 | "visit_items_index": tf.FixedLenFeature([5], tf.int64),
10 | "continuous_features_value": tf.FixedLenFeature([16], tf.float32),
11 | "next_visit_item_index": tf.FixedLenFeature([], tf.int64)
12 | }
13 | parsed = tf.parse_single_example(example, example_fmt)
14 | next_visit_item_index = parsed.pop("next_visit_item_index")
15 | return parsed, next_visit_item_index
16 |
17 | def input_fn(path, parallel_num, epoch_num, batch_size):
18 | files = tf.data.Dataset.list_files(path, shuffle=True)
19 | data_set = files.apply(
20 | tf.contrib.data.parallel_interleave(
21 | map_func=lambda filename: tf.data.TFRecordDataset(filename),
22 | cycle_length=parallel_num))
23 | data_set = data_set.repeat(epoch_num)
24 | data_set = data_set.map(map_func=parse_fn, num_parallel_calls=parallel_num)
25 | data_set = data_set.prefetch(buffer_size=256)
26 | data_set = data_set.batch(batch_size=batch_size)
27 | return data_set
28 |
29 | def model_fn(features, labels, mode, params, config):
30 |
31 | visit_items_index = features["visit_items_index"] # num * 5
32 | continuous_features_value = features["continuous_features_value"] # num * 16
33 | next_visit_item_index = labels # num
34 | keep_prob = params["keep_prob"]
35 | embedding_size = params["embedding_size"]
36 | item_num = params["item_num"]
37 | learning_rate = params["learning_rate"]
38 | top_k = params["top_k"]
39 |
40 | # Item embedding initialization
41 | initializer = tf.initializers.random_uniform(minval=-0.5 / embedding_size, maxval=0.5 / embedding_size)
42 | partitioner = tf.fixed_size_partitioner(num_shards=embedding_size)
43 | item_embedding = tf.get_variable("item_embedding", [item_num, embedding_size],
44 | tf.float32, initializer=initializer, partitioner=partitioner)
45 |
46 | visit_items_embedding = tf.nn.embedding_lookup(item_embedding, visit_items_index) # num * 5 * embedding_size
47 | visit_items_average_embedding = tf.reduce_mean(visit_items_embedding, axis=1) # num * embedding_size
48 | input_embedding = tf.concat([visit_items_average_embedding, continuous_features_value], 1) # num * (embedding_size + 16)
49 | kernel_initializer_1 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
50 | bias_initializer_1 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
51 | layer_1 = tf.layers.dense(input_embedding, 64, activation=tf.nn.relu,
52 | kernel_initializer=kernel_initializer_1,
53 | bias_initializer=bias_initializer_1, name="layer_1")
54 | layer_dropout_1 = tf.nn.dropout(layer_1, keep_prob=keep_prob, name="layer_dropout_1")
55 | kernel_initializer_2 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
56 | bias_initializer_2 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
57 | layer_2 = tf.layers.dense(layer_dropout_1, 32, activation=tf.nn.relu,
58 | kernel_initializer=kernel_initializer_2,
59 | bias_initializer=bias_initializer_2, name="layer_2")
60 | layer_dropout_2 = tf.nn.dropout(layer_2, keep_prob=keep_prob, name="layer_dropout_2")
61 | # user vector, num * embedding_size
62 | kernel_initializer_3 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
63 | bias_initializer_3 = tf.initializers.random_normal(mean=0.0, stddev=0.1)
64 | user_vector = tf.layers.dense(layer_dropout_2, embedding_size, activation=tf.nn.relu,
65 | kernel_initializer=kernel_initializer_3,
66 | bias_initializer=bias_initializer_3, name="user_vector")
67 |
68 | if mode == tf.estimator.ModeKeys.TRAIN:
69 | # Training
70 | output_embedding = tf.nn.embedding_lookup(item_embedding, next_visit_item_index) # num * embedding_size
71 | logits = tf.matmul(user_vector, output_embedding, transpose_a=False, transpose_b=True) # num * num
72 | yhat = tf.nn.softmax(logits) # num * num
73 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16))
74 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
75 | train = optimizer.minimize(cross_entropy, global_step=tf.train.get_global_step())
76 | return tf.estimator.EstimatorSpec(mode, loss=cross_entropy, train_op=train)
77 |
78 | if mode == tf.estimator.ModeKeys.EVAL:
79 | # Evaluation
80 | output_embedding = tf.nn.embedding_lookup(item_embedding, next_visit_item_index) # num * embedding_size
81 | logits = tf.matmul(user_vector, output_embedding, transpose_a=False, transpose_b=True) # num * num
82 | yhat = tf.nn.softmax(logits) # num * num
83 | cross_entropy = tf.reduce_mean(-tf.log(tf.matrix_diag_part(yhat) + 1e-16))
84 | return tf.estimator.EstimatorSpec(mode, loss=cross_entropy)
85 |
86 | if mode == tf.estimator.ModeKeys.PREDICT:
87 | logits_predict = tf.matmul(user_vector, item_embedding, transpose_a=False, transpose_b=True) # num * item_num
88 | yhat_predict = tf.nn.softmax(logits_predict) # num * item_num
89 | _, indices = tf.nn.top_k(yhat_predict, k=top_k, sorted=True)
90 | index = tf.identity(indices, name="index") # num * top_k
91 | # Prediction
92 | predictions = {
93 | "user_vector": user_vector,
94 | "index": index
95 | }
96 | export_outputs = {
97 | "prediction": tf.estimator.export.PredictOutput(predictions)
98 | }
99 | return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs)
100 |
101 | def build_estimator():
102 | params = {"keep_prob": 0.5, "embedding_size": 16, "item_num": 500, "learning_rate": 0.05, "top_k": 2}
103 | session_config = tf.ConfigProto(device_count={"CPU": 1}, allow_soft_placement=True, log_device_placement=False)
104 | session_config.gpu_options.allow_growth = True
105 | config = tf.estimator.RunConfig(
106 | model_dir="../../data/checkpoints/ckpt",
107 | tf_random_seed=2019,
108 | save_checkpoints_steps=100,
109 | session_config=session_config,
110 | keep_checkpoint_max=5,
111 | log_step_count_steps=100
112 | )
113 | estimator = tf.estimator.Estimator(model_fn=model_fn, config=config, params=params)
114 | return estimator
115 |
116 | def set_dist_env():
117 | if FLAGS.is_distributed:
118 | ps_hosts = FLAGS.strps_hosts.split(",")
119 | worker_hosts = FLAGS.strwork_hosts.split(",")
120 | job_name = FLAGS.job_name
121 | task_index = FLAGS.task_index
122 | chief_hosts = worker_hosts[0:1] # get first worker as chief
123 | worker_hosts = worker_hosts[2:] # the rest as worker
124 |
125 | # use #worker=0 as chief
126 | if job_name == "worker" and task_index == 0:
127 | job_name = "chief"
128 | # use #worker=1 as evaluator
129 | if job_name == "worker" and task_index == 1:
130 | job_name = 'evaluator'
131 | task_index = 0
132 | # the others as worker
133 | if job_name == "worker" and task_index > 1:
134 | task_index -= 2
135 |
136 | tf_config = {'cluster': {'chief': chief_hosts, 'worker': worker_hosts, 'ps': ps_hosts},
137 | 'task': {'type': job_name, 'index': task_index}}
138 | os.environ['TF_CONFIG'] = json.dumps(tf_config)
139 |
140 | def train_eval_save():
141 |
142 | set_dist_env()
143 |
144 | estimator = build_estimator()
145 |
146 | # Training spec
147 | train_spec = tf.estimator.TrainSpec(
148 | input_fn=lambda: input_fn(
149 | path='../../data/tfrecords/train/train.tfrecords',
150 | parallel_num=32,
151 | epoch_num=11,
152 | batch_size=32),
153 | max_steps=1600
154 | )
155 | # Evaluation spec
156 | eval_spec = tf.estimator.EvalSpec(
157 | input_fn=lambda: input_fn(
158 | path='../../data/tfrecords/evaluation/evaluation.tfrecords',
159 | parallel_num=32,
160 | epoch_num=1,
161 | batch_size=32),
162 | steps=15, # number of evaluation batches per evaluation run
163 | start_delay_secs=1, # delay in seconds before the first evaluation
164 | throttle_secs=20 # evaluate at most once every 20 seconds
165 | )
166 | # Train and evaluate
167 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
168 |
169 | # Save (export) the model
170 | features_spec = {
171 | "visit_items_index": tf.placeholder(tf.int64, shape=[None, 5], name="visit_items_index"),
172 | "continuous_features_value": tf.placeholder(tf.float32, shape=[None, 16], name="continuous_features_value")
173 | }
174 | serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(features_spec)
175 | estimator.export_savedmodel(
176 | "../../data/checkpoints/modelpath",
177 | serving_input_receiver_fn)
178 |
179 | def main(_):
180 | train_eval_save()
181 |
182 | if __name__ == "__main__":
183 | tf.logging.set_verbosity(tf.logging.INFO)
184 | FLAGS = tf.app.flags.FLAGS
185 | tf.app.flags.DEFINE_boolean("is_distributed", False, "whether to train in distributed mode")
186 | tf.app.flags.DEFINE_string("strps_hosts", "localhost:2000", "parameter server hosts")
187 | tf.app.flags.DEFINE_string("strwork_hosts", "localhost:2100,localhost:2200,localhost:2300,localhost:2400", "worker hosts")
188 | tf.app.flags.DEFINE_string("job_name", "ps", "parameter server (ps) or worker")
189 | tf.app.flags.DEFINE_integer("task_index", 0, "task index within the job")
190 | tf.app.run(main=main)
--------------------------------------------------------------------------------
/docs/architecture.drawio:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxingqiang/YouTube-DNN-RecSys/HEAD/docs/architecture.drawio
--------------------------------------------------------------------------------