├── README.md
├── doc
├── BinaryClassify.md
├── CityHash.md
├── Example.md
├── FFMProcessor.md
├── FeatureAnalysis.md
├── FeatureProcess.md
├── FieldawareFactorizationMachine.md
├── JobConfigure.md
├── LibSVMProcessor.md
├── LogisticRegression.md
├── LogisticRegressionWithDCASGD.md
├── LogisticRegressionWithFTRL.md
├── LogisticRegressionWithMomentum.md
├── Model.md
├── PS.md
├── PSClient.md
├── PSConfiguration.md
├── PSMapStore.md
├── Updater.md
├── faq_cn.md
└── img
│ ├── logo.jpg
│ ├── qq.jpg
│ └── xdml.png
├── pom.xml
└── src
├── main
└── scala
│ ├── net
│ └── qihoo
│ │ └── xitong
│ │ └── xdml
│ │ ├── conf
│ │ ├── JobConfiguration.scala
│ │ ├── JobType.scala
│ │ ├── PSConfiguration.scala
│ │ └── PSDataType.scala
│ │ ├── dataProcess
│ │ ├── FFMProcessor.scala
│ │ └── LibSVMProcessor.scala
│ │ ├── example
│ │ ├── analysis
│ │ │ ├── feature
│ │ │ │ ├── analysis
│ │ │ │ │ ├── runUniversalAnalyzerDense.scala
│ │ │ │ │ ├── runUniversalAnalyzerDenseGrouped.scala
│ │ │ │ │ ├── runUniversalAnalyzerDenseKS.scala
│ │ │ │ │ └── runUniversalAnalyzerSparse.scala
│ │ │ │ └── process
│ │ │ │ │ ├── runCategoryEncoder.scala
│ │ │ │ │ ├── runFeatureProcess.scala
│ │ │ │ │ ├── runMultiCategoryEncoder.scala
│ │ │ │ │ ├── runNumericBucketer.scala
│ │ │ │ │ └── runNumericStandardizer.scala
│ │ │ └── model
│ │ │ │ ├── runFromDenseDataToXDMLH2ODRF.scala
│ │ │ │ ├── runFromDenseDataToXDMLH2OGBM.scala
│ │ │ │ ├── runFromDenseDataToXDMLH2OGLM.scala
│ │ │ │ ├── runFromDenseDataToXDMLH2OMLP.scala
│ │ │ │ ├── runFromLibSVMDataToXDMLLinearScopeModel.scala
│ │ │ │ ├── runFromLibSVMDataToXDMLOVR.scala
│ │ │ │ └── runFromLibSVMDataToXDMLSR.scala
│ │ └── ml
│ │ │ ├── DCASGDTest.scala
│ │ │ ├── FFMTest.scala
│ │ │ ├── FTRLTest.scala
│ │ │ ├── LRTest.scala
│ │ │ └── MomentumTest.scala
│ │ ├── feature
│ │ ├── analysis
│ │ │ ├── KolmogorovSmirnovTest.java
│ │ │ └── UniversalAnalyzer.scala
│ │ └── process
│ │ │ ├── CategoryEncoder.scala
│ │ │ ├── ColumnEliminator.scala
│ │ │ ├── ColumnRenamer.scala
│ │ │ ├── FeatureProcessor.scala
│ │ │ ├── MultiCategoryEncoder.scala
│ │ │ ├── NumericBucketer.scala
│ │ │ ├── NumericStandardizer.scala
│ │ │ └── PipelinePatch.scala
│ │ ├── linalg
│ │ └── BLAS.scala
│ │ ├── mapstore
│ │ ├── KuduTableOp.scala
│ │ └── PSMapStore.scala
│ │ ├── ml
│ │ ├── FieldawareFactorizationMachine.scala
│ │ ├── LogisticRegression.scala
│ │ ├── LogisticRegressionWithDCASGD.scala
│ │ ├── LogisticRegressionWithFTRL.scala
│ │ └── LogisticRegressionWithMomentum.scala
│ │ ├── model
│ │ ├── data
│ │ │ ├── DataHandler.scala
│ │ │ ├── DataProcessor.scala
│ │ │ ├── LogHandler.scala
│ │ │ └── SchemaHandler.scala
│ │ ├── linalg
│ │ │ └── BLAS.scala
│ │ ├── loss
│ │ │ ├── HingeLossFunc.scala
│ │ │ ├── L2LossFunc.scala
│ │ │ ├── LogitLossFunc.scala
│ │ │ ├── LossFunc.scala
│ │ │ ├── MultiHingeLossFunc.scala
│ │ │ ├── MultiLogitLossFunc.scala
│ │ │ ├── MultiLossFunc.scala
│ │ │ ├── MultiSmoothHingeLossFunc.scala
│ │ │ ├── PoissonLossFunc.scala
│ │ │ ├── SmoothHingeLossFunc.scala
│ │ │ ├── UPULogitLossFunc.scala
│ │ │ ├── WeightedHingeLossFunc.scala
│ │ │ ├── WeightedLogitLossFunc.scala
│ │ │ └── WeightedSmoothHingeLossFunc.scala
│ │ ├── supervised
│ │ │ ├── GLM
│ │ │ │ ├── LinearScope.scala
│ │ │ │ ├── MultiLinearScope.scala
│ │ │ │ └── OVRLinearScope.scala
│ │ │ └── H2O
│ │ │ │ ├── H2ODRF.scala
│ │ │ │ ├── H2OGBM.scala
│ │ │ │ ├── H2OGLM.scala
│ │ │ │ ├── H2OMLP.scala
│ │ │ │ └── H2OParams.scala
│ │ └── util
│ │ │ └── MLUtils.scala
│ │ ├── optimization
│ │ ├── BinaryClassify.scala
│ │ ├── FFM.scala
│ │ ├── FM.scala
│ │ └── Optimizer.scala
│ │ ├── ps
│ │ ├── PS.scala
│ │ └── PSClient.scala
│ │ ├── task
│ │ ├── PullTask.scala
│ │ ├── PushTask.scala
│ │ └── Task.scala
│ │ ├── updater
│ │ ├── DCASGDUpdater.scala
│ │ ├── FFMUpdater.scala
│ │ ├── LRFTRLUpdater.scala
│ │ ├── LRUpdater.scala
│ │ ├── MomLRUpdater.scala
│ │ └── Updater.scala
│ │ └── utils
│ │ ├── CityHash.java
│ │ ├── ExitCodeResolver.java
│ │ ├── JHex.java
│ │ └── XDMLException.java
│ └── org
│ └── apache
│ └── kudu
│ └── client
│ └── SessionConfiguration.java
└── test
└── scala
├── DataColFilter.scala
├── DataFilter.scala
├── DataRowFilter.scala
└── TestFilter.scala
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
7 |
8 | [](./LICENSE)
9 | []()
10 | []()
11 |
12 |
13 | **XDML**是一款基于参数服务器(Parameter Server),采用专门缓存机制的分布式机器学习平台。
14 | XDML内化了学界最新研究成果,在效果保持稳定的同时,能大幅加速收敛进程,显著提升模型与算法的性能。同时,XDML还对接了一些优秀的开源成果和360公司自研成果,站在巨人的肩膀上,博采众长。 XDML还兼容hadoop生态,提供更好的大数据框架使用体验,将开发者从繁杂的工作中解脱出来。XDML已经在360内部海量规模数据上进行了大量测试和调优,在大规模数据量和超高维特征的机器学习任务上,具有良好的稳定性,扩展性和兼容性。
15 |
16 | 欢迎对机器学习或分布式有兴趣的同仁一起贡献代码,提交Issues或者Pull Requests。
17 |
18 | ## 架构设计
19 | 
20 |
21 | 针对超大规模机器学习的场景,奇虎360开源了内部的超大规模机器学习计算框架XDML。XDML是一款基于参数服务器(Parameter Server),采用专门缓存机制的分布式机器学习平台。它在360内部海量规模数据上进行了测试和调优,在大规模数据量和超高维特征的机器学习任务上,具有良好的稳定性,扩展性和兼容性。
22 |
23 | ## 功能特性
24 | #### 1.提供特征预处理/分析,离线训练,模型管理等功能模块
25 | #### 2.实现常用的大规模数据量场景下的机器学习算法
26 | #### 3.充分利用现有的成熟技术,保证整个框架的高效稳定
27 | #### 4.完全兼容hadoop生态,和现有的大数据工具实现无缝对接,提升处理海量数据的能力
28 | #### 5.在系统架构和算法层面实现深度的工程优化,在不损失精度的前提下,大幅提高性能
29 |
30 |
31 | ## 代码结构
32 |
33 | ### 1.ps
34 | XDML的核心参数服务器架构,包括以下组件:
35 |
36 | - [PS](./doc/PS.md)
37 | - [PSClient](./doc/PSClient.md)
38 |
39 | ### 2.conf
40 | XDML的配置包,包括对参数服务器的配置和对作业及模型相关的配置。包括以下组件:
41 |
42 | - [JobConfiguration](./doc/JobConfigure.md)
43 | - [PSConfiguration](./doc/PSConfiguration.md)
44 | - ...
45 |
46 | ### 3.task
47 | XDML向PS提交的作业,包括拉取和推送。包括以下任务:
48 |
49 | - Task
50 | - PullTask
51 | - PushTask
52 |
53 | ### 4.optimization
54 | XDML模型的优化算法包。包括以下优化算法:
55 |
56 | - [BinaryClassify](./doc/BinaryClassify.md)
57 | - [FFM](./doc/FFMProcessor.md)
58 | - ...
59 |
60 | ### 5.ml
61 | XDML中已经实现的部分机器学习模型。包括以下模型:
62 |
63 | - [LogisticRegression](./doc/LogisticRegression.md)
64 | - [LogisticRegressionWithDCASGD](./doc/LogisticRegressionWithDCASGD.md)
65 | - [LogisticRegressionWithFTRL](./doc/LogisticRegressionWithFTRL.md)
66 | - [LogisticRegressionWithMomentum](./doc/LogisticRegressionWithMomentum.md)
67 | - [FieldawareFactorizationMachine](./doc/FieldawareFactorizationMachine.md)
68 | - ...
69 |
70 | ### 6.feature
71 | XDML中特征分析和特征处理模块。
72 |
73 | - [特征分析](./doc/FeatureAnalysis.md)
74 |
75 | 特征分析覆盖常见的分析指标,如数值型特征的偏度、峰度、分位数,与label相关的auc、ndcg、互信息、相关系数等指标。
76 |
77 | - [特征处理](./doc/FeatureProcess.md)
78 |
79 | 特征处理覆盖常见的数值型、类别型特征预处理方法。包括以下算子:
80 | - CategoryEncoder
81 | - MultiCategoryEncoder
82 | - NumericBucketer
83 | - NumericStandardizer
84 |
85 | ### 7.model
86 | XDML中包含用南京大学李武军老师提出的[Scope](https://arxiv.org/pdf/1602.00133.pdf)优化算法进行训练的线性模型,以及部分[H2O](https://www.h2o.ai/)模型的spark pipeline封装。具体包括以下模型:
87 |
88 | [Model:](./doc/Model.md)
89 |
90 | - LinearScope
91 | - MultiLinearScope
92 | - OVRLinearScope
93 | - H2ODRF
94 | - H2OGBM
95 | - H2OGLM
96 | - H2OMLP
97 |
98 | ### 8.example
99 | XDML中作业提交实例,可以参考[Example](./doc/Example.md).
100 |
101 | ## 编译&部署指南
102 |
103 | XDML是基于Kudu、HazelCast以及Hadoop生态圈的一款基于参数服务器的,采用专门缓存机制的分布式机器学习平台。
104 |
105 | ### 环境依赖
106 | - centos >= 7
107 | - Jdk >= 1.8
108 | - Maven >= 3.5.4
109 | - scala >= 2.11
110 | - hadoop >= 2.7.3
111 | - spark >= 2.3.0
112 | - sparkling-water-core >= 2.3.0
113 | - kudu >= 1.9
114 | - HazelCast >= 3.9.3
115 |
116 | ### Kudu安装部署
117 | XDML基于Kudu,请首先部署Kudu。Kudu的安装部署请参考[Kudu](https://github.com/apache/kudu/tree/1.7.0)。
118 |
119 | ### 源码下载
120 | ```git clone https://github.com/Qihoo360/XLearning-XDML```
121 |
122 | ### 编译
123 | ```mvn clean package -Dmaven.test.skip=true```
124 | 编译完成后,在源码根目录的`target`目录下会生成:`xdml-1.0.jar`、`xdml-1.0-jar-with-dependencies.jar`等多个文件,`xdml-1.0.jar`为未加spark、kudu等第三方依赖,`xdml-1.0-jar-with-dependencies.jar`添加了spark、kudu等依赖包。
125 |
126 | ## 运行示例
127 |
128 | ### 提交参数
129 | * **算法参数**
130 | * spark.xdml.learningRate:学习率
131 | * **训练参数**
132 | * spark.xdml.job.type:作业类型
133 | * spark.xdml.train.data.path:训练数据路径
134 | * spark.xdml.train.data.partitionNum:训练数据分区
135 | * spark.xdml.model.path:模型存储路径
136 | * spark.xdml.train.iter:训练迭代次数
137 | * spark.xdml.train.batchsize:训练数据batch大小
138 | * **PS相关参数**
139 | * spark.xdml.hz.clusterNum:hazelcast集群机器数目
140 | * spark.xdml.table.name:kudu表名称
141 |
142 | ### 提交命令
143 | 可以通过以下命令提交示例训练作业:
144 |
145 | ```
146 | $SPARK_HOME/bin/spark-submit \
147 | --master yarn-cluster \
148 | --class net.qihoo.xitong.xdml.example.LRTest \
149 | --num-executors 50 \
150 | --executor-memory 40g \
151 | --executor-cores 2 \
152 | --driver-memory 4g \
153 | --conf "spark.xdml.table.name=lrtest" \
154 | --conf "spark.xdml.job.type=train" \
155 | --conf "spark.xdml.train.data.path=$trainpath" \
156 | --conf "spark.xdml.train.data.partitionNum=50" \
157 | --conf "spark.xdml.hz.clusterNum=50" \
158 | --conf "spark.xdml.model.path=$modelpath" \
159 | --conf "spark.xdml.train.iter=5" \
160 | --conf "spark.xdml.train.batchsize=10000" \
161 | --conf "spark.xdml.learningRate=0.1" \
162 | --jars xdml-1.0-jar-with-dependencies.jar \
163 | xdml-1.0-jar-with-dependencies.jar
164 |
165 | ```
166 |
167 | 注:提交命令中的设置有`$SPARK_HOME`、`$trainpath`、`$modelpath` 分别代表spark客户端路径、训练数据HDFS路径、模型存储HDFS路径
168 |
169 | ## FAQ
170 | [**XDML常见问题**](./doc/faq_cn.md)
171 |
172 | ## 参考文献
173 | XDML参考了学界及工业界诸多优秀成果,对此表示感谢!
174 |
175 | - Shen-Yi Zhao, Ru Xiang, Ying-Hao Shi, Peng Gao, Wu-Jun Li, [SCOPE: Scalable Composite Optimization for Learning on Spark](https://arxiv.org/pdf/1602.00133.pdf). AAAI 2017: 2928-2934.
176 | - Shen-Yi Zhao, Gong-Duo Zhang, Ming-Wei Li, Wu-Jun Li.[Proximal SCOPE for Distributed Sparse Learning](https://arxiv.org/pdf/1803.05621.pdf).Proceedings of the Annual Conference on Neural Information Processing Systems (NIPS), 2018.
177 | - Shuxin Zheng, Qi Meng, Taifeng Wang, Wei Chen, Zhi-Ming Ma and Tie-Yan Liu, [Asynchronous Stochastic Gradient Descent with Delay Compensation](https://arxiv.org/pdf/1609.08326.pdf), ICML 2017.
178 |
179 | ## 联系我们
180 |
181 | Mail:
182 | QQ群:874050710
183 | 
--------------------------------------------------------------------------------
/doc/BinaryClassify.md:
--------------------------------------------------------------------------------
1 | # BinaryClassify
2 |
3 | ---
4 |
5 |
6 | > 二分类模型中计算梯度和损失的工具类
7 |
8 | ## 功能
9 |
10 | * 计算二分类模型中训练数据的梯度和损失值,并且可以用predict函数进行预测。
11 |
12 | ## 核心接口
13 |
14 | 1. **train**
15 | - 定义:
16 | ```
17 | train(label: Double,
18 | feature: (Array[Long], Array[Float]),
19 | localW: Map[Long, Float],
20 | subsampling_rate: Double = 1.0,
21 | subsampling_label: Double = 0.0
22 | ): (Map[Long, Float], Double)
23 | ```
24 | - 功能描述:根据本地特征权重计算单条训练样本的梯度和损失值
25 | - 参数:
26 | - label:样本的标签label
27 | - feature:样本的特征和对应的value值
28 | - localW:本地的特征权重值Map
29 | - subsampling_rate:负采样率(未启用)
30 | - subsampling_label:负采样标签(未启用)
31 | - 返回值:
32 | - (Map[Long, Float], Double) :计算得到的样本的梯度和损失值
33 | * 第一列为样本特征对应的梯度,
34 | * 第二列为样本的损失
35 |
36 | 1. **predict**
37 | - 定义:```predict(feature: (Array[Long], Array[Float]),weight: Map[Long, Float]): Double ```
38 | - 功能描述:对二分类模型进行预测
39 | - 参数:
40 | - feature:样本的特征id和value值
41 | - weight:特征weight权重Map
42 | - 返回值:Double:该条样本预测的标签
--------------------------------------------------------------------------------
/doc/CityHash.md:
--------------------------------------------------------------------------------
1 | # CityHash
2 |
3 | ---
4 |
5 |
6 | > 计算hash值的工具类,可用于将字符串特征进行hash化得到数值型id,是google开源的cityhash的java版本,具有效率高、碰撞率极低的特点
7 |
8 | ## 功能
9 |
10 | * google的cityhash算法的java版本,提供32位,64位和128位哈希值运算。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **stringCityHash32**
16 | - 定义:```int stringCityHash32(String str) ```
17 | - 功能描述:计算指定字符串的32位的hash值
18 | - 参数:
19 | - str:需要计算的字符串
20 | - 返回值:int:32位hash值
21 |
22 | 1. **stringCityHash64**
23 | - 定义:```long stringCityHash64(String str) ```
24 | - 功能描述:计算指定字符串的64位的hash值
25 | - 参数:
26 | - str:需要计算的字符串
27 | - 返回值:long:64位hash值
28 |
29 | 1. **stringCityHash128**
30 | - 定义:```long[] stringCityHash128(String str) ```
31 | - 功能描述:计算指定字符串的128位的hash值
32 | - 参数:
33 | - str:需要计算的字符串
34 | - 返回值:long[]:128位hash值
35 |
--------------------------------------------------------------------------------
/doc/FFMProcessor.md:
--------------------------------------------------------------------------------
1 | # FFMProcessor
2 |
3 | ---
4 |
5 |
6 | > 处理标准libFFM格式数据的工具类
7 |
8 | ## 功能
9 |
10 | * 处理标准的文本libFFM数据,得到样本的RDD
11 |
12 | ## 核心接口
13 |
14 | 1. **processData**
15 | - 定义:```processData(data:RDD[String],separator:String = " "):RDD[(Double,Array[Int],Array[Long],Array[Float])] ```
16 | - 功能描述:处理原始的文本类型的RDD数据,返回新的数值化的RDD
17 | - 参数:
18 | - data:原始的文本数据RDD
19 | - separator:特征数据之间的分隔符,默认为单个空格
20 | - 返回值:
21 | - RDD[(Double,Array[Int],Array[Long],Array[Float])] :返回数值化的RDD,
22 | * 第一列为样本标签,
23 | * 第二列为field数组
24 | * 第三列为featureId数组
25 | * 第四列为feature对应的value数组
--------------------------------------------------------------------------------
/doc/FieldawareFactorizationMachine.md:
--------------------------------------------------------------------------------
1 | # FieldawareFactorizationMachine
2 |
3 | ---
4 |
5 |
6 | > FFM模型
7 |
8 | ## 功能
9 |
10 | * 实现了FFM,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **fit**
16 | - 定义:```fit(data: RDD[(Double, Array[Int], Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ```
17 | - 功能描述:模型的训练方法
18 | - 参数:
19 | - data:训练的输入数据RDD,1为label标签,2为特征的field域,3为特征id数组,4为相应特征值数组
20 | - 返回值:
21 | - (Map[Int, Double], (Long, Long)):
22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失
23 | - 第二列为训练数据的正负样本数
24 |
25 | 2. **predict**
26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ```
27 | - 功能描述:模型的预测方法
28 | - 参数:
29 | - data:预测的输入数据RDD,1为label标签,2为特征的field域,3为特征id数组,4为相应特征值数组
30 | - 返回值:
31 | - RDD[(Double, Double)]:样本label和预测的label
32 |
33 | 3. **setIterNum**
34 | - 定义:```setIterNum(iterNum: Int): this.type ```
35 | - 功能描述:设置数据迭代次数
36 | - 参数:
37 | - iterNum: 迭代次数
38 | - 返回值:this
39 |
40 | 4. **setBatchSize**
41 | - 定义:```setBatchSize(batchSize: Int): this.type ```
42 | - 功能描述:设置训练batch的大小
43 | - 参数:
44 | - batchSize:一组batch的大小
45 | - 返回值:this
46 |
47 | 5. **setLearningRate**
48 | - 定义:```setLearningRate(lr: Float): this.type ```
49 | - 功能描述:设置模型的学习率
50 | - 参数:
51 | - lr:学习率
52 | - 返回值:this
53 |
54 | 6. **setRank**
55 | - 定义:```setRank(rank:Int):this.type ```
56 | - 功能描述:设置FFM的单个域的向量长度
57 | - 参数:
58 | - rank:域向量长度
59 | - 返回值:this
60 |
61 | 7. **setField**
62 | - 定义:```setField(field:Int):this.type ```
63 | - 功能描述:设置FFM的field域的个数
64 | - 参数:
65 | - field:设置field域的个数
66 | - 返回值:this
--------------------------------------------------------------------------------
/doc/JobConfigure.md:
--------------------------------------------------------------------------------
1 | ## 作业相关配置
2 |
3 | 配置项名称 | 默认值 | 配置项含义
4 | ---------------- | --------------- | ---------------
5 | spark.xdml.job.type | "train" | 设置作业类型:train,predict,increment_train
6 | spark.xdml.table.name | "table" | 设置kudu存储的表名称
7 | spark.xdml.train.data.path | 无 | 训练数据路径
8 | spark.xdml.train.data.partitionNum | | 训练数据分区数目
9 | spark.xdml.data.split | " " | 数据分隔符
10 | spark.xdml.model.path | 无 | 模型存储路径
11 | spark.xdml.predict.result.path | 无 | 预测结果输出路径
12 | spark.xdml.kudu.master | 无 | kudu集群master地址
13 | spark.xdml.hz.clusterNum | 50 | hazelcast集群节点数目
14 | spark.xdml.hz.partitionNum | 127 | hazelcast分区数目
15 | spark.xdml.hz.maxcachesize | 50000000 | hazelcast分区最大缓存大小
16 | spark.xdml.train.iter |1 | 训练迭代次数
17 | spark.xdml.train.batchsize | 50 | 训练数据batch大小
18 | spark.xdml.learningRate | 0.01f| 设置学习率
19 | spark.xdml.momentumCoff | 0.1f | 动量阻力参数
20 | spark.xdml.train.alpha | 1f | 设置ftrl等算法训练所需参数alpha
21 | spark.xdml.train.beta | 1f | 设置ftrl等算法训练所需参数beta
22 | spark.xdml.train.lambda1 | 1f | 设置ftrl算法训练所需参数lambda1
23 | spark.xdml.train.lambda2 | 1f | 设置ftrl算法训练所需参数lambda2
24 | spark.xdml.train.forcesparse | false | 设置ftrl算法是否保持强稀疏性,即w为0时,z、n对应值也为0
25 | spark.xdml.model.ffm.rank | 1 | 设置ffm算法中的k值
26 | spark.xdml.model.ffm.field | 1 | 设置ffm算法中的field域个数
27 |
28 |
--------------------------------------------------------------------------------
/doc/LibSVMProcessor.md:
--------------------------------------------------------------------------------
1 | # LibSVMProcessor
2 |
3 | ---
4 |
5 |
6 | > 处理标准libSVM格式数据的工具类
7 | ## 功能
8 |
9 | * 处理标准的文本libSVM数据,得到样本的RDD
10 |
11 | ## 核心接口
12 |
13 | 1. **processData**
14 | - 定义:```processData(data:RDD[String],separator:String = " "):RDD[(Double,Array[Long],Array[Float])] ```
15 | - 功能描述:处理原始的文本类型的RDD数据,返回新的数值化的RDD
16 | - 参数:
17 | - data:原始的文本数据RDD
18 | - separator:特征数据之间的分隔符,默认为单个空格
19 | - 返回值:
20 | - RDD[(Double,Array[Long],Array[Float])] :返回数值化的RDD,
21 | * 第一列为样本标签,
22 | * 第二列为featureId数组
23 | * 第三列为feature对应的value值数组
--------------------------------------------------------------------------------
/doc/LogisticRegression.md:
--------------------------------------------------------------------------------
1 | # LogisticRegression
2 |
3 | ---
4 |
5 |
6 | > 逻辑回归模型
7 |
8 | ## 功能
9 |
10 | * 实现了逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **fit**
16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ```
17 | - 功能描述:模型的训练方法
18 | - 参数:
19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
20 | - 返回值:
21 | - (Map[Int, Double], (Long, Long)):
22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失
23 | - 第二列为训练数据的正负样本数
24 |
25 | 2. **predict**
26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ```
27 | - 功能描述:模型的预测方法
28 | - 参数:
29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
30 | - 返回值:
31 | - RDD[(Double, Double)]:样本label和预测的label
32 |
33 | 3. **setIterNum**
34 | - 定义:```setIterNum(iterNum: Int): this.type ```
35 | - 功能描述:设置数据迭代次数
36 | - 参数:
37 | - iterNum: 迭代次数
38 | - 返回值:this
39 |
40 | 4. **setBatchSize**
41 | - 定义:```setBatchSize(batchSize: Int): this.type ```
42 | - 功能描述:设置训练batch的大小
43 | - 参数:
44 | - batchSize:一组batch的大小
45 | - 返回值:this
46 |
47 | 5. **setLearningRate**
48 | - 定义:```setLearningRate(lr: Float): this.type ```
49 | - 功能描述:设置模型的学习率
50 | - 参数:
51 | - lr:学习率
52 | - 返回值:this
53 |
54 |
--------------------------------------------------------------------------------
/doc/LogisticRegressionWithDCASGD.md:
--------------------------------------------------------------------------------
1 | # LogisticRegressionWithDCASGD
2 |
3 | ---
4 |
5 |
6 | > 带有延迟梯度更新的逻辑回归模型
7 |
8 | ## 功能
9 |
10 | * 实现了基于延迟梯度更新(DC-ASGD)的逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **fit**
16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ```
17 | - 功能描述:模型的训练方法
18 | - 参数:
19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
20 | - 返回值:
21 | - (Map[Int, Double], (Long, Long)):
22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失
23 | - 第二列为训练数据的正负样本数
24 |
25 | 2. **predict**
26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ```
27 | - 功能描述:模型的预测方法
28 | - 参数:
29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
30 | - 返回值:
31 | - RDD[(Double, Double)]:样本label和预测的label
32 |
33 | 3. **setIterNum**
34 | - 定义:```setIterNum(iterNum: Int): this.type ```
35 | - 功能描述:设置数据迭代次数
36 | - 参数:
37 | - iterNum: 迭代次数
38 | - 返回值:this
39 |
40 | 4. **setBatchSize**
41 | - 定义:```setBatchSize(batchSize: Int): this.type ```
42 | - 功能描述:设置训练batch的大小
43 | - 参数:
44 | - batchSize:一组batch的大小
45 | - 返回值:this
46 |
47 | 5. **setLearningRate**
48 | - 定义:```setLearningRate(lr: Float): this.type ```
49 | - 功能描述:设置模型的学习率
50 | - 参数:
51 | - lr:学习率
52 | - 返回值:this
53 |
54 | 6. **setDcAsgdCoff**
55 | - 定义:```setDcAsgdCoff(coff:Float):this.type ```
56 | - 功能描述:设置延迟梯度系数
57 | - 参数:
58 | - coff:延迟梯度系数
59 | - 返回值:this
60 |
--------------------------------------------------------------------------------
/doc/LogisticRegressionWithFTRL.md:
--------------------------------------------------------------------------------
1 | # LogisticRegressionWithFTRL
2 |
3 | ---
4 |
5 |
6 | > 使用FTRL优化方法的逻辑回归模型
7 |
8 | ## 功能
9 |
10 | * 实现了带有FTRL优化的逻辑回归模型。提供训练的fit方法和预测的predict方法。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **fit**
16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ```
17 | - 功能描述:模型的训练方法
18 | - 参数:
19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
20 | - 返回值:
21 | - (Map[Int, Double], (Long, Long)):
22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失
23 | - 第二列为训练数据的正负样本数
24 |
25 | 2. **predict**
26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ```
27 | - 功能描述:模型的预测方法
28 | - 参数:
29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
30 | - 返回值:
31 | - RDD[(Double, Double)]:样本label和预测的label
32 |
33 | 3. **setIterNum**
34 | - 定义:```setIterNum(iterNum: Int): this.type ```
35 | - 功能描述:设置数据迭代次数
36 | - 参数:
37 | - iterNum: 迭代次数
38 | - 返回值:this
39 |
40 | 4. **setBatchSize**
41 | - 定义:```setBatchSize(batchSize: Int): this.type ```
42 | - 功能描述:设置训练batch的大小
43 | - 参数:
44 | - batchSize:一组batch的大小
45 | - 返回值:this
46 |
47 | 5. **setAlpha**
48 | - 定义:```setAlpha(alpha: Float): this.type ```
49 | - 功能描述:设置FTRL中的参数alpha
50 | - 参数:
51 | - alpha:FTRL中的参数alpha
52 | - 返回值:this
53 |
54 | 6. **setBeta**
55 | - 定义:```setBeta(beta: Float): this.type ```
56 | - 功能描述:设置FTRL中的参数beta
57 | - 参数:
58 | - beta:FTRL中的参数beta
59 | - 返回值:this
60 |
61 | 7. **setLambda1**
62 | - 定义:```setLambda1(lambda1: Float): this.type ```
63 | - 功能描述:设置FTRL中的参数lambda1
64 | - 参数:
65 | - lambda1:FTRL中的参数lambda1
66 | - 返回值:this
67 |
68 | 8. **setLambda2**
69 | - 定义:```setLambda2(lambda2: Float): this.type ```
70 | - 功能描述:设置FTRL中的参数lambda2
71 | - 参数:
72 | - lambda2:FTRL中的参数lambda2
73 | - 返回值:this
74 |
--------------------------------------------------------------------------------
/doc/LogisticRegressionWithMomentum.md:
--------------------------------------------------------------------------------
1 | # LogisticRegressionWithMomentum
2 |
3 | ---
4 |
5 |
6 | > 带有冲量的逻辑回归模型
7 |
8 | ## 功能
9 |
10 | * 实现了带有冲量的逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **fit**
16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ```
17 | - 功能描述:模型的训练方法
18 | - 参数:
19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
20 | - 返回值:
21 | - (Map[Int, Double], (Long, Long)):
22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失
23 | - 第二列为训练数据的正负样本数
24 |
25 | 2. **predict**
26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ```
27 | - 功能描述:模型的预测方法
28 | - 参数:
29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组
30 | - 返回值:
31 | - RDD[(Double, Double)]:样本label和预测的label
32 |
33 | 3. **setIterNum**
34 | - 定义:```setIterNum(iterNum: Int): this.type ```
35 | - 功能描述:设置数据迭代次数
36 | - 参数:
37 | - iterNum: 迭代次数
38 | - 返回值:this
39 |
40 | 4. **setBatchSize**
41 | - 定义:```setBatchSize(batchSize: Int): this.type ```
42 | - 功能描述:设置训练batch的大小
43 | - 参数:
44 | - batchSize:一组batch的大小
45 | - 返回值:this
46 |
47 | 5. **setLearningRate**
48 | - 定义:```setLearningRate(lr: Float): this.type ```
49 | - 功能描述:设置模型的学习率
50 | - 参数:
51 | - lr:学习率
52 | - 返回值:this
53 |
54 | 6. **setMomemtumCoff**
55 | - 定义:```setMomemtumCoff(coff:Float):this.type ```
56 | - 功能描述:设置冲量系数
57 | - 参数:
58 | - coff:冲量系数
59 | - 返回值:this
60 |
--------------------------------------------------------------------------------
/doc/PS.md:
--------------------------------------------------------------------------------
1 | # PS
2 |
3 | ---
4 |
5 |
6 | > PS是XDML中抽象出来的参数服务器类。
7 |
8 | ## 功能
9 |
10 | * 包括获取参数服务器单例,保存和载入模型,load参数,启动和销毁参数服务器等功能
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **getInstance**
16 | - 定义:```PS.getInstance(sc: SparkContext, psConf: PSConfiguration): PS ```
17 | - 功能描述:根据传入的PSConfiguration获取参数服务器PS的单例。
18 | - 参数:
19 | - sc:Spark作业的SparkContext实例
20 | - psConf:PSConfiguration的实例,包含了启动参数服务器的所有参数。
21 | - 返回值:
22 | - PS:参数服务器的单例
23 |
24 | 2. **getIPList**
25 | - 定义:```getIPList: Array[String] ```
26 | - 功能描述:获取hazelcast server集群启动后的节点IP列表
27 | - 参数:无
28 | - 返回值:
29 | - Array[String]:hazelcast集群节点的ip组成的数组
30 |
31 | 3. **start**
32 | - 定义:```start(sc: SparkContext): Unit ```
33 | - 功能描述:根据已配置的参数和作业类型启动参数服务器
34 | - 参数:
35 | - sc:Spark作业的SparkContext实例
36 | - 返回值:无
37 |
38 | 4. **createKuduTable**
39 | - 定义:```createKuduTable(): Unit ```
40 | - 功能描述:在kudu集群中创建一张新的表
41 | - 参数:无
42 | - 返回值:无
43 |
44 | 5. **startHzCluster**
45 | - 定义:```startHzCluster(sc: SparkContext) ```
46 | - 功能描述:根据传入的PSConfiguration的实例参数启动hazelcast集群
47 | - 参数:
48 | - sc:Spark作业的SparkContext实例
49 | - 返回值:无
50 |
51 | 6. **createHzMapConfig**
52 | - 定义:```createHzMapConfig(tableName: String): MapConfig ```
53 | - 功能描述:创建启动hz节点需要的一个MapConfig的实例
54 | - 参数:
55 | - tableName:需要创建的表名
56 | - 返回值:MapConfig:启动hz所需的MapConfig的实例
57 |
58 | 7. **getAllParametersReady**
59 | - 定义:
60 | ```getAllParametersReady(sc: SparkContext, tableName: String = psConf.getPsTableName(), kuduContext: KuduContext): Long ```
61 | - 功能描述:将kudu指定表中所有的参数载入到hz中
62 | - 参数:
63 | - sc:Spark作业的SparkContext实例
64 | - tableName: kudu中指定的表名,默认是PsConf中配置的第一个表名
65 | - kuduContext:KuduContext实例,类属性
66 | - 返回值:Long:载入的参数的总个数
67 |
68 | 8. **getRangeModelWeightFromBytes**
69 | - 定义:```getRangeModelWeightFromBytes(bytes: Array[Byte], startRange: Int, endRange: Int): Any ```
70 | - 功能描述:从Byte数组中获取指定范围的泛型数据类型的数据
71 | - 参数:
72 | - bytes: 源byte数组数据
73 | - startRange: 起始索引(包含)
74 | - endRange: 结束索引(不包含)
75 | - 返回值:Any:泛型指定的数据类型数据
76 |
77 | 9. **getRangeBytesFromWeight**
78 | - 定义:```getRangeBytesFromWeight(values: Array[String], startRange: Int, endRange: Int): Array[Byte] ```
79 | - 功能描述:从指定的字符串数组中按照ps设置的数据类型将指定范围的数据转换成Byte数组
80 | - 参数:
81 | - values:需要处理的字符串数组
82 | - startRange: 起始索引(包含)
83 | - endRange: 结束索引(不包含)
84 | - 返回值:Array[Byte]:转化后的Byte类型的数组
85 |
86 | 10. **saveModel**
87 | - 定义:
88 |
89 | saveModel(sc: SparkContext,
90 | outputPath: String,
91 | tableName: String = psConf.getPsTableName(),
92 | split: String = "\t",
93 | weightStartPos: Int = psConf.getWeightIndexRangeStart,
94 | weightEndPos: Int = psConf.getWeightIndexRangeEnd,
95 | kuduContext: KuduContext = PS.kuduContext): Unit
96 |
97 | - 功能描述:保存训练的模型到HDFS
98 | - 参数:
99 | - sc:Spark作业的SparkContext实例
100 | - outputPath: 模型输出的HDFS路径
101 | - tableName: 指定模型参数所在kudu指定的表名,
102 | - split: 模型参数文本的分隔符,默认为\t
103 | - weightStartPos: 指定需要保存模型参数的起始索引(包含),默认为0,
104 | - weightEndPos:指定需要保存模型参数的结束索引(不包含),默认为ps中设置的数据长度
105 | - kuduContext:KuduContext实例,类属性
106 | - 返回值:无
107 |
108 | 11. **loadModel**
109 | - 定义:
110 |
111 | loadModel(sc: SparkContext,
112 | inputPath: String,
113 | kuduContext: KuduContext,
114 | tableName: String = psConf.getPsTableName(),
115 | split: String = "\t",
116 | weightStartPos: Int = psConf.getWeightIndexRangeStart,
117 | weightEndPos: Int = psConf.getWeightIndexRangeEnd
118 | ): Unit
119 |
120 | - 功能描述:从HDFS中load模型到kudu集群的指定表中
121 | - 参数:
122 | - sc:Spark作业的SparkContext实例
123 | - inputPath: 需要载入的模型的HDFS路径
124 | - kuduContext:KuduContext实例,类属性
125 | - tableName:模型load到kudu指定的表名中
126 | - split: 模型参数文本的分隔符,默认为\t
127 | - weightStartPos: 指定需要载入模型参数的起始索引(包含),默认为0,
128 | - weightEndPos:指定需要载入模型参数的结束索引(不包含),默认为ps中设置的数据长度
129 | - 返回值:无
130 |
--------------------------------------------------------------------------------
/doc/PSClient.md:
--------------------------------------------------------------------------------
1 | # PSClient
2 |
3 | ---
4 |
5 |
6 | > PSClient封装了用户模型与参数服务器之间参数交互关键的pull和push方法。
7 |
8 | ## 功能
9 |
10 | * 在参数服务器启动后,PSClient相当于一个代理:用户模型与参数服务器交互时,每一个RDD的partition中都需要一个PSClient的实例来与参数服务器进行参数的拉取和推送。其中pull和push的过程中数据类型对用户是无感的,会根据ps启动时设置的数据类型自动匹配,用户只需要通过PSClient进行pull和push操作即可,无需关心数据类型的转化。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **pull**
16 | - 定义:```pull(featureId: util.HashSet[Long], tableName: String = ps.psConf.getPsTableName()): util.Map[Long, V] ```
17 | - 功能描述:用户从参数服务器中拉取所需要的参数。
18 | - 参数:
19 | - featureId: 用户需要拉取的feature的id
20 | - tableName: 指定需要从哪个表中拉取参数,默认是ps启动时配置的第一张表名
21 | - 返回值:
22 | - util.Map[Long, V]: 用户拉取的参数Map
23 |
24 | 2. **push**
25 | - 定义:```push(weightMap: util.Map[Long, V], tableName: String = ps.psConf.getPsTableName()): Unit ```
26 | - 功能描述:用户向参数服务器推送保存参数。
27 | - 参数:
28 | - weightMap: 需要推送的参数id和值
29 | - tableName: 指定需要向哪张表推送参数,默认是ps启动时配置的第一张表名
30 | - 返回值:无
31 |
32 | 3. **init**
33 | - 定义:```init(): Unit ```
34 | - 功能描述:初始化本实例的相关参数
35 | - 参数:无
36 | - 返回值:无
37 |
38 | 4. **validDataType**
39 | - 定义:```validDataType(): Unit ```
40 | - 功能描述:验证PS中配置的PSDataType和PSClient实例中的泛型是否一致,若不一致则抛出异常
41 | - 参数:无
42 | - 返回值:无
43 | - 抛出异常:XDMLException
44 |
45 | 5. **setUpdater**
46 | - 定义:```setUpdater(updater: Updater[Long, V]): this.type ```
47 | - 功能描述:设置本实例需要使用的Updater,在进行push之前会先调用updater实例中的update方法处理参数后再push到参数服务器。
48 | - 参数:
49 | - updater:该PSClient实例设定的updater
50 | - 返回值:this
51 |
52 | 6. **setPullRange**
53 | - 定义:```setPullRange(start: Int, end: Int): this.type ```
54 | - 功能描述:设置拉取参数的索引范围
55 | - 参数:
56 | - start:索引起始值(包含)
57 | - end:索引结束值(不包含)
58 | - 返回值:this
59 |
60 | 7. **setPullRange**
61 | - 定义:```setPullRange(index: Int): this.type ```
62 | - 功能描述:设置只拉取某个索引的参数
63 | - 参数:
64 | - index:参数索引值
65 | - 返回值:this
66 |
67 | 8. **getRangeVFromBytes**
68 | - 定义:```getRangeVFromBytes(bytes: Array[Byte]): V ```
69 | - 功能描述:从byte数组中获取指定范围的泛型V指定类型的数据
70 | - 参数:
71 | - bytes:数据源Byte数组
72 | - 返回值:V:泛型数据
73 |
74 | 9. **shutDown**
75 | - 定义:```shutDown(): Unit ```
76 | - 功能描述:关闭本PSClient实例
77 | - 参数:无
78 | - 返回值:无
79 |
80 |
--------------------------------------------------------------------------------
/doc/PSConfiguration.md:
--------------------------------------------------------------------------------
1 | # PSConfiguration
2 |
3 | ---
4 |
5 |
6 | > PSConfiguration封装了所有XDML中参数服务器中使用的两个组件Kudu和Hazelcast所有的配置参数和PS相关的参数。
7 |
8 | ## 功能
9 |
10 | * 使用get和set方法获取和配置XDML中参数服务器的相关参数。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **setJobType**
16 | - 定义:```setJobType(jobType:PSJobType) :this.type ```
17 | - 功能描述:设置作业参数服务器的作业类型,目前有train,predict和increment_train三种作业类型。
18 | - 参数:
19 | - jobType: 枚举类JobType的实例
20 | - 返回值:this
21 |
22 | 2. **setKuduMaster**
23 | - 定义:```setKuduMaster(master: String): this.type ```
24 | - 功能描述:设置kudu集群的master地址
25 | - 参数:
26 | - master: Kudu集群的master节点地址
27 | - 返回值:this
28 |
29 | 3. **setKuduReplicaNum**
30 | - 定义:```setKuduReplicaNum(kuduReplicaNum: Int) ```
31 | - 功能描述:设置kudu数据备份数,默认为3
32 | - 参数:
33 | - kuduReplicaNum:kudu数据备份数
34 | - 返回值:this
35 |
36 | 4. **setKuduPartitionNum**
37 | - 定义:```setKuduPartitionNum(kuduPartitionNum: Int) ```
38 | - 功能描述:设置kudu的分区数,默认为50
39 | - 参数:
40 | - kuduPartitionNum:kudu存储的分区数
41 | - 返回值:this
42 |
43 | 5. **setKuduTableForceOverride**
44 | - 定义:```setKuduTableForceOverride(kuduTableForceOverride: Boolean): this.type ```
45 | - 功能描述:设置当kudu中新建表时有重名的表出现是否强制覆盖旧表,默认为true
46 | - 参数:
47 | - kuduTableForceOverride:是否强制覆盖,true为覆盖,false为使用旧表
48 | - 返回值:this
49 |
50 | 6. **setKuduKeyColName**
51 | - 定义:```setKuduKeyColName(kuduKeyColName: String): this.type ```
52 | - 功能描述:设置kudu表的键所在列的名字
53 | - 参数:
54 | - kuduKeyColName:键列的名字
55 | - 返回值:this
56 |
57 | 7. **setKuduValueColName**
58 | - 定义:```setKuduValueColName(kuduValueColName: String): this.type ```
59 | - 功能描述:设置kudu表的值所在列的名字
60 | - 参数:
61 | - kuduValueColName:值列的名字
62 | - 返回值:this
63 |
64 | 8. **setKuduValueNullable**
65 | - 定义:```setKuduValueNullable(kuduValueNullable: Boolean) ```
66 | - 功能描述:设置kudu表的值列是否可以为null,默认为false
67 | - 参数:
68 | - kuduValueNullable:值是否可以为null
69 | - 返回值:this
70 |
71 | 9. **setKuduRandomInit**
72 | - 定义:```setKuduRandomInit(kuduRandomInit: Boolean) ```
73 | - 功能描述:设置是否需要参数的随机初始化,默认为false
74 | - 参数:
75 | - kuduRandomInit:布尔类型是否需要随机初始化值,默认范围是0-1之间随机值,可以通过setKuduRandomInit方法设置随机范围。
76 | - 返回值:this
77 |
78 | 10. **setHzClusterNum**
79 | - 定义:```setHzClusterNum(hzClusterNum:Int):this.type ```
80 | - 功能描述:配置在参数服务器启动的时候需要启动的hz的节点数,默认为50
81 | - 参数:
82 | - hzClusterNum:参数服务器中需要启动的hz节点数
83 | - 返回值:this
84 |
85 | 11. **setHzPartitionNum**
86 | - 定义:```setHzPartitionNum(hzPartitionNum: Int): this.type ```
87 | - 功能描述:设置hz的分区数,默认为271
88 | - 参数:
89 | - hzPartitionNum:Hazelcast的分区数
90 | - 返回值:this
91 |
92 | 12. **setHzClientThreadCount**
93 | - 定义:```setHzClientThreadCount(hzClientThreadCount: Int): this.type ```
94 | - 功能描述:设置hazelcast的client最大处理线程数,默认为20
95 | - 参数:
96 | - hzClientThreadCount:hzclient的线程数
97 | - 返回值:this
98 |
99 | 13. **setHzOperationThreadCount**
100 | - 定义:```setHzOperationThreadCount(hzOperationThreadCount:Int) :this.type ```
101 | - 功能描述:设置hz的server最大操作线程数,默认是20
102 | - 参数:
103 | - hzOperationThreadCount:hz的server最大操作线程数
104 | - 返回值:this
105 |
106 | 14. **setHzMaxCacheSizePerPartition**
107 | - 定义:```setHzMaxCacheSizePerPartition(hzMaxCacheSizePerPartition: Int): this.type ```
108 | - 功能描述:设置Hz的单个分区最大缓存数据量,默认是50000000
109 | - 参数:
110 | - hzMaxCacheSizePerPartition:Hz的单个分区最大缓存数据量
111 | - 返回值:this
112 |
113 | 15. **setHzWriteDelaySeconds**
114 | - 定义:```setHzWriteDelaySeconds(hzWriteDelaySeconds: Int): this.type ```
115 | - 功能描述:设置hz异步向kudu中持久化数据的时间间隔,单位为s,默认为20
116 | - 参数:
117 | - hzWriteDelaySeconds:持久化数据的时间间隔
118 | - 返回值:this
119 |
120 | 16. **setHzWriteBatchSize**
121 | - 定义:```setHzWriteBatchSize(hzWriteBatchSize: Int): this.type ```
122 | - 功能描述:设置hz写入的batchsize
123 | - 参数:
124 | - hzWriteBatchSize:hz写入的batchsize
125 | - 返回值:this
126 |
127 | 17. **setHzHeartbeatTime**
128 | - 定义:```setHzHeartbeatTime(hzHeartbeatTime: Int): this.type ```
129 | - 功能描述:设置hz server的心跳时间,单位为ms,默认为3000
130 | - 参数:
131 | - hzHeartbeatTime:hz心跳时间
132 | - 返回值:this
133 |
134 | 18. **setHzClientHeartbeatTime**
135 | - 定义:```setHzClientHeartbeatTime(hzClientHeartbeatTime: Int): this.type ```
136 | - 功能描述:设置hz client的心跳时间
137 | - 参数:
138 | - hzClientHeartbeatTime:hz client的心跳时间
139 | - 返回值:this
140 |
141 | 19. **setHzIOThreadCount**
142 | - 定义:```setHzIOThreadCount(hzIOThreadCount:Int):this.type ```
143 | - 功能描述:设置hazelcast的IO操作的线程数
144 | - 参数:
145 | - hzIOThreadCount:IO操作线程数
146 | - 返回值:this
147 |
148 | 20. **setPsTableName**
149 | - 定义:```setPsTableName(psTableName: String*): this.type ```
150 | - 功能描述:设置本次作业使用的PS中的表名
151 | - 参数:
152 | - psTableName:为可变参数,参数服务器中使用的表名,如果使用多张表,可以设置多个。
153 | - 返回值:this
154 |
155 | 21. **setPsDataType**
156 | - 定义:```setPsDataType(psDataType: PSDataType): this.type ```
157 | - 功能描述:设置参数服务器中使用的value的数据类型
158 | - 参数:
159 | - psDataType:PSDataType的枚举类型,目前包括FLOAT、DOUBLE、FLOAT_ARRAY和DOUBLE_ARRAY
160 | - 返回值:this
161 |
162 | 22. **setPsDataLength**
163 | - 定义:```setPsDataLength(psDataLength: Int): this.type ```
164 | - 功能描述:设置参数服务器中使用的value的数据长度
165 | - 参数:
166 | - psDataLength:参数服务器中value的数据长度
167 | - 返回值:this
168 |
169 | 23. **setPredictModelPath**
170 | - 定义:```setPredictModelPath(predictModelPath:String):this.type ```
171 | - 功能描述:设置预测作业中需要使用的模型路径
172 | - 参数:
173 | - predictModelPath:需要载入模型的HDFS路径
174 | - 返回值:this
175 |
--------------------------------------------------------------------------------
/doc/PSMapStore.md:
--------------------------------------------------------------------------------
1 | # PSMapStore
2 |
3 | ---
4 |
5 |
6 | > PSMapStore实现了hazelcast中的MapStore接口。用于实现hazelcast缓存中数据向Kudu中的持久化,参数服务器会优先从hazelcast中拉取参数,如果没有命中则调用该接口从kudu中拉取或者初始化。
7 |
8 | ## 功能
9 |
10 | * 实现了hazelcast和kudu交互的接口规范,是实现两级式参数服务器的关键核心接口。数据在kudu中的存储统一为Array[Byte]字节数组的格式
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **init**
16 | - 定义:```init(): Unit ```
17 | - 功能描述:用于本实例属性的初始化函数,在构造实例时自动调用
18 | - 参数:无
19 | - 返回值:无
20 |
21 | 2. **store**
22 | - 定义:```store(key: Long, value: Array[Byte]): Unit ```
23 | - 功能描述:将hazelcast中的k-v数据存储到kudu中
24 | - 参数:
25 | - key: 存储数据的键
26 | - value: 存储数据的值
27 | - 返回值:无
28 |
29 | 3. **storeAll**
30 | - 定义:```storeAll(map: util.Map[Long, Array[Byte]]): Unit ```
31 | - 功能描述:将一个键值对集合Map存入kudu中
32 | - 参数:
33 | - map:键值对集合Map
34 | - 返回值:无
35 |
36 | 4. **load**
37 | - 定义:```load(key: Long): Array[Byte] ```
38 | - 功能描述:从kudu中载入一个键值对
39 | - 参数:
40 | - key:需要载入的键
41 | - 返回值:Array[Byte]:返回该键对应的Byte数组的值
42 |
43 | 5. **loadAll**
44 | - 定义:```loadAll(keys: util.Collection[Long]): util.Map[Long, Array[Byte]] ```
45 | - 功能描述:从kudu中拉取一个集合的键值对
46 | - 参数:
47 | - keys:需要从kudu中拉取的键值集合
48 | - 返回值:util.Map[Long, Array[Byte]]:返回拉取得到的键值对集合Map
49 |
--------------------------------------------------------------------------------
/doc/Updater.md:
--------------------------------------------------------------------------------
1 | # Updater
2 |
3 | ---
4 |
5 |
6 | > Updater是PSClient中用到的一个接口,用户可以自定义实现其中的update方法,PSClient会在push数据到参数服务器之前先用update方法处理传入的参数和梯度数据。
7 |
8 | ## 功能
9 |
10 | * 定义参数服务器中更新参数的方法,在执行push操作时自动调用。
11 |
12 |
13 | ## 核心接口
14 |
15 | 1. **update**
16 | - 定义:```update(originalWeightMap: java.util.Map[K, V], gradientMap: java.util.Map[K, V]):java.util.Map[K, V] ```
17 | - 功能描述:定义用户参数服务器更新参数方式的接口规范
18 | - 参数:
19 | - originalWeightMap:参数服务器上的原始参数
20 | - gradientMap:一般是梯度列表参数
21 | - 返回值:java.util.Map[K, V] :用户自定义更新后的Map数据,也就是真正更新到参数服务器的Map数据
--------------------------------------------------------------------------------
/doc/faq_cn.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/faq_cn.md
--------------------------------------------------------------------------------
/doc/img/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/logo.jpg
--------------------------------------------------------------------------------
/doc/img/qq.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/qq.jpg
--------------------------------------------------------------------------------
/doc/img/xdml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/xdml.png
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/conf/JobConfiguration.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.conf
2 |
3 | import org.apache.spark.SparkContext
4 |
5 | import scala.collection.mutable
6 |
7 |
8 | /**
9 |  * Keys of the Spark configuration entries that XDML jobs understand.
10 |  *
11 |  * Each constant holds the literal key string. `jobConfSet` gathers all of them and is
12 |  * the set that `readJobConfig` of the companion class iterates over.
13 |  */
14 | object JobConfiguration {
15 |
16 |   // ---- settings shared by all jobs ----
17 |   val DATA_SPLIT: String = "spark.xdml.data.split"
18 |   val KUDU_MASTER: String = "spark.xdml.kudu.master"
19 |   val LEARNING_RATE: String = "spark.xdml.learningRate"
20 |   val TABLE_NAME: String = "spark.xdml.table.name"
21 |   val MODEL_PATH: String = "spark.xdml.model.path"
22 |   val DATA_PATH: String = "spark.xdml.data.path"
23 |   val TRAIN_DATA_PARTITION: String = "spark.xdml.train.data.partitionNum"
24 |   val TRAIN_ITER_NUM: String = "spark.xdml.train.iter"
25 |   val BATCH_SIZE: String = "spark.xdml.train.batchsize"
26 |   val SUBSAMPLE_RATE: String = "spark.xdml.subsample.rate"
27 |   val SUBSAMPLE_LABEL: String = "spark.xdml.train.subsample.label"
28 |   val HZ_CLUSTER_NUM: String = "spark.xdml.hz.clusterNum"
29 |   val HZ_PARTITION_NUM: String = "spark.xdml.hz.partitionNum"
30 |   val HZ_MAXCACHESIZE_PER_PARTITION: String = "spark.xdml.hz.maxcachesize"
31 |   val SAVE_ALL_WEIGHTS: String = "spark.xdml.model.save.allweights"
32 |   val SAVE_FEATURE: String = "spark.xdml.model.save.feature"
33 |   val FEATURE_FILE: String = "spark.xdml.model.feature.path"
34 |   val PREDICT_RESULT_PATH: String = "spark.xdml.predict.result.path"
35 |   val JOB_TYPE: String = "spark.xdml.job.type"
36 |
37 |   // ---- DC-ASGD ----
38 |   val DC_ASGD_COFF: String = "spark.xdml.dcasgdCoff"
39 |
40 |   // ---- Momentum ----
41 |   val MOMENTUM_COFF: String = "spark.xdml.momentumCoff"
42 |
43 |   // ---- FTRL ----
44 |   val FTRL_ALPHA: String = "spark.xdml.train.alpha"
45 |   val FTRL_BETA: String = "spark.xdml.train.beta"
46 |   val FTRL_LAMBDA1: String = "spark.xdml.train.lambda1"
47 |   val FTRL_LAMBDA2: String = "spark.xdml.train.lambda2"
48 |   // NOTE: the identifier is spelled "FROCE" while the key reads "forcesparse"; the name is left as is.
49 |   val FTRL_FROCESPARSE: String = "spark.xdml.train.forcesparse"
50 |
51 |   // ---- FM ----
52 |   val FM_RANK: String = "spark.xdml.model.fm.rank"
53 |
54 |   // ---- FFM ----
55 |   val FFM_RANK: String = "spark.xdml.model.ffm.rank"
56 |   val FFM_FIELD: String = "spark.xdml.model.ffm.field"
57 |
58 |   // Every key declared above; only these keys are looked up when a job configuration is read.
59 |   private val jobConfSet: Set[String] = Set(
60 |     // common
61 |     DATA_SPLIT, KUDU_MASTER, LEARNING_RATE, TABLE_NAME, MODEL_PATH, DATA_PATH,
62 |     TRAIN_DATA_PARTITION, TRAIN_ITER_NUM, BATCH_SIZE, SUBSAMPLE_RATE, SUBSAMPLE_LABEL,
63 |     HZ_CLUSTER_NUM, HZ_PARTITION_NUM, HZ_MAXCACHESIZE_PER_PARTITION,
64 |     SAVE_ALL_WEIGHTS, SAVE_FEATURE, FEATURE_FILE, PREDICT_RESULT_PATH, JOB_TYPE,
65 |     // Momentum, DC-ASGD
66 |     MOMENTUM_COFF, DC_ASGD_COFF,
67 |     // FTRL
68 |     FTRL_ALPHA, FTRL_BETA, FTRL_LAMBDA1, FTRL_LAMBDA2, FTRL_FROCESPARSE,
69 |     // FM, FFM
70 |     FM_RANK, FFM_RANK, FFM_FIELD
71 |   )
72 |
73 | }
100 |
101 | /** Reads the XDML-specific entries out of a Spark application's configuration. */
102 | class JobConfiguration() {
103 |
104 |   /**
105 |    * Collects every key listed in `JobConfiguration.jobConfSet` that is set on `sc`.
106 |    *
107 |    * @param sc Spark context whose configuration is inspected
108 |    * @return immutable map from configuration key to its value; keys that are not set are omitted
109 |    */
110 |   def readJobConfig(sc: SparkContext): Map[String, String] = {
111 |     // SparkContext.getConf hands back a copy of the configuration, so fetch it once, not once per key.
112 |     val conf = sc.getConf
113 |     JobConfiguration.jobConfSet.iterator
114 |       .flatMap(key => conf.getOption(key).map(value => key -> value))
115 |       .toMap
116 |   }
117 | }
113 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/conf/JobType.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.conf
2 |
3 | /** Enumeration of the job kinds declared for XDML: TRAIN, PREDICT and INCREMENT_TRAIN. */
4 | object JobType extends Enumeration {
5 |   type PSJobType = Value
6 |
7 |   // Ids are handed out in declaration order starting at 0: TRAIN = 0, PREDICT = 1, INCREMENT_TRAIN = 2.
8 |   val TRAIN, PREDICT, INCREMENT_TRAIN = Value
9 |
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/conf/PSDataType.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.conf
2 |
3 | /**
4 |  * Data types of the values kept in the parameter server, with helpers to query them.
5 |  */
6 | object PSDataType extends Enumeration {
7 |   type PSDataType = Value
8 |
9 |   // Ids follow declaration order, from 0 (BYTE) to 7 (DOUBLE_ARRAY).
10 |   val BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, FLOAT_ARRAY, DOUBLE_ARRAY = Value
11 |
12 |   /**
13 |    * Width in bytes associated with `dataType`.
14 |    *
15 |    * The array types map to 4 and 8, the same numbers as FLOAT and DOUBLE
16 |    * (presumably the per-element width -- confirm against callers).
17 |    *
18 |    * @throws MatchError if `dataType` is not one of the values declared here
19 |    */
20 |   def sizeOf(dataType: PSDataType): Int = dataType match {
21 |     case BYTE => 1
22 |     case SHORT => 2
23 |     case INT | FLOAT | FLOAT_ARRAY => 4
24 |     case LONG | DOUBLE | DOUBLE_ARRAY => 8
25 |   }
26 |
27 |   /** True only for the two array types, FLOAT_ARRAY and DOUBLE_ARRAY. */
28 |   def isArrayType(tpe: PSDataType): Boolean = tpe match {
29 |     case FLOAT_ARRAY | DOUBLE_ARRAY => true
30 |     case _ => false
31 |   }
32 | }
32 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/dataProcess/FFMProcessor.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.dataProcess
2 | import org.apache.spark.rdd.RDD
3 |
4 | /** Parser for text records in the libffm layout: a label followed by `a:b:c` feature tokens. */
5 | object FFMProcessor {
6 |
7 |   /**
8 |    * Turns each input line into `(label, ints, longs, floats)`.
9 |    *
10 |    * The first token is the label, binarised to 1.0 when it parses to a value > 0 and to 0.0 otherwise.
11 |    * Every following token is split on ":"; its first part is read as Int, the second as Long and the
12 |    * third as Float (in libffm these are field, feature index and value). Empty tokens, which appear
13 |    * when two separators follow each other, are skipped instead of failing the whole record.
14 |    *
15 |    * @param data      lines to parse
16 |    * @param separator token separator; it is passed to `String.split` and therefore treated as a regex
17 |    * @return one tuple per line; the three arrays are aligned by position
18 |    */
19 |   def processData(data: RDD[String], separator: String = " "): RDD[(Double, Array[Int], Array[Long], Array[Float])] = {
20 |     data.map { line =>
21 |       val tokens = line.split(separator)
22 |       val label = if (tokens(0).toDouble > 0) 1.0D else 0.0D
23 |       // Split every feature token once and reuse the parts for all three arrays.
24 |       val parts = tokens.iterator.drop(1).filter(_.nonEmpty).map(_.split(":")).toArray
25 |       (label, parts.map(p => p(0).toInt), parts.map(p => p(1).toLong), parts.map(p => p(2).toFloat))
26 |     }
27 |   }
28 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/dataProcess/LibSVMProcessor.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.dataProcess
2 | import org.apache.spark.rdd.RDD
3 |
4 | /** Parser for text records in the libsvm layout: a label followed by `index:value` feature tokens. */
5 | object LibSVMProcessor {
6 |
7 |   /**
8 |    * Turns each input line into `(label, indices, values)`.
9 |    *
10 |    * The first token is the label, binarised to 1.0 when it parses to a value > 0 and to 0.0 otherwise.
11 |    * Every following token is split on ":"; its first part is read as Long and the second as Float.
12 |    * Empty tokens, which appear when two separators follow each other, are skipped instead of
13 |    * failing the whole record.
14 |    *
15 |    * @param data      lines to parse
16 |    * @param separator token separator; it is passed to `String.split` and therefore treated as a regex
17 |    * @return one tuple per line; the two arrays are aligned by position
18 |    */
19 |   def processData(data: RDD[String], separator: String = " "): RDD[(Double, Array[Long], Array[Float])] = {
20 |     data.map { line =>
21 |       val tokens = line.split(separator)
22 |       val label = if (tokens(0).toDouble > 0) 1.0D else 0.0D
23 |       // Split every feature token once and reuse the parts for both arrays.
24 |       val parts = tokens.iterator.drop(1).filter(_.nonEmpty).map(_.split(":")).toArray
25 |       (label, parts.map(p => p(0).toLong), parts.map(p => p(1).toFloat))
26 |     }
27 |   }
28 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDense.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.analysis
2 |
3 | import net.qihoo.xitong.xdml.feature.analysis._
4 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
5 | import org.apache.spark.sql.SparkSession
6 |
7 | /**
8 |  * Example driver: analyses a dense, delimited data set with `UniversalAnalyzer.fitDense` and prints
9 |  * one summary per categorical feature and one per numeric feature.
10 |  *
11 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, hasLabel,
12 |  * numPartition, nullValue, fineness.
13 |  */
14 | object runUniversalAnalyzerDense {
15 |
16 |   def main(args: Array[String]): Unit = {
17 |
18 |     LogHandler.avoidLog()
19 |
20 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
21 |     require(args.length >= 8,
22 |       "usage: runUniversalAnalyzerDense <dataPath> <dataDelimiter> <schemaPath> <schemaDelimiter> " +
23 |         "<hasLabel> <numPartition> <nullValue> <fineness>")
24 |     val dataPath = args(0)
25 |     val dataDelimiter = args(1)
26 |     val schemaPath = args(2)
27 |     val schemaDelimiter = args(3)
28 |     val hasLabel = args(4).toBoolean
29 |     val numPartition = args(5).toInt
30 |     val nullValue = args(6)
31 |     val fineness = args(7).toInt
32 |     // fineness is the divisor of the grid built below; 0 would turn every grid point into NaN.
33 |     require(fineness > 0, s"fineness must be positive, got $fineness")
34 |
35 |     val spark = SparkSession.builder()
36 |       .appName("runUniversalAnalyzerDense").getOrCreate()
37 |
38 |     try {
39 |       // data input
40 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
41 |       val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
42 |       orgDataDF.show()
43 |
44 |       // Without a label column an empty name is handed to the analyzer.
45 |       val labelColName = if (hasLabel) schemaHandler.labelColName(0) else ""
46 |
47 |       // Evenly spaced grid 0, 1/fineness, ..., 1 (printed below under "quantiles").
48 |       val grid = Array.tabulate(fineness + 1)(i => 1.0 * i / fineness)
49 |
50 |       val (catFeatSummaryArray, numFeatSummaryArray) = UniversalAnalyzer.fitDense(orgDataDF, hasLabel, labelColName,
51 |         schemaHandler.catFeatColNames, schemaHandler.numFeatColNames, grid)
52 |
53 |       catFeatSummaryArray.foreach { summary =>
54 |         println("\n===========================================")
55 |         println(summary.name)
56 |         println(s"countAll: ${summary.countAll}")
57 |         println(s"countNotNull: ${summary.countNotNull}")
58 |         println(summary.concernedCategories.map(tup => tup._1).mkString(","))
59 |         println(s"mi: ${summary.mi}")
60 |         println(s"auc: ${summary.auc}")
61 |       }
62 |
63 |       numFeatSummaryArray.foreach { summary =>
64 |         println("\n===========================================")
65 |         println(summary.name)
66 |         println(s"countAll: ${summary.countAll}")
67 |         println(s"countNotNull: ${summary.countNotNull}")
68 |         println(s"mean: ${summary.mean}")
69 |         println(s"std: ${summary.std}")
70 |         println(s"skewness: ${summary.skewness}")
71 |         println(s"kurtosis: ${summary.kurtosis}")
72 |         println(s"min: ${summary.min}")
73 |         println(s"max: ${summary.max}")
74 |         println(s"quantiles: ${summary.quantileArray.mkString(",")}")
75 |         println(s"corr: ${summary.corr}")
76 |         println(s"auc: ${summary.auc}")
77 |       }
78 |     } finally {
79 |       // Stop the session even when reading or analysing fails.
80 |       spark.stop()
81 |     }
82 |
83 |   }
84 |
85 | }
72 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDenseGrouped.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.analysis
2 |
3 | import net.qihoo.xitong.xdml.feature.analysis.UniversalAnalyzer
4 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
5 | import org.apache.spark.sql.SparkSession
6 |
7 | /**
8 |  * Example driver: runs `UniversalAnalyzer.fitDenseGrouped` on a delimited data set and prints,
9 |  * per returned summary, its name, reverse pair rate and NDCG map.
10 |  *
11 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, numPartition,
12 |  * nullValue, groupColName, topKStr, topKDelimiter.
13 |  */
14 | object runUniversalAnalyzerDenseGrouped {
15 |
16 |   def main(args: Array[String]): Unit = {
17 |
18 |     LogHandler.avoidLog()
19 |
20 |     val session = SparkSession.builder()
21 |       .appName("runUniversalAnalyzerDenseGrouped").getOrCreate()
22 |
23 |     val inputPath = args(0).toString
24 |     val inputDelimiter = args(1).toString
25 |     val schemaFile = args(2).toString
26 |     val schemaSep = args(3).toString
27 |     val partitions = args(4).toInt
28 |     val nullToken = args(5).toString
29 |
30 |     val groupColumn = args(6).toString
31 |     val topKText = args(7).toString
32 |     val topKSep = args(8).toString
33 |
34 |     // The cut-offs arrive as one delimited string, e.g. several integers joined by topKDelimiter.
35 |     val cutOffs = topKText.split(topKSep).map(piece => piece.toInt)
36 |
37 |     val schema = SchemaHandler.readSchema(session.sparkContext, schemaFile, schemaSep)
38 |     val inputDF = DataHandler.readData(session, inputPath, inputDelimiter, schema.schema, nullToken, partitions)
39 |     inputDF.show()
40 |
41 |     val summaries = UniversalAnalyzer.fitDenseGrouped(inputDF,
42 |       schema.labelColName(0),
43 |       groupColumn,
44 |       schema.numFeatColNames,
45 |       cutOffs,
46 |       gainFunc,
47 |       discountFunc)
48 |     println("-------------------metric result----------------------")
49 |     for (entry <- summaries) {
50 |       println("\n=====================")
51 |       println(entry.name)
52 |       println(s"reversePairRate: ${entry.reversePairRate}")
53 |       println(s"ndcgs: ${entry.ndcgMap.toString()}")
54 |     }
55 |     session.stop()
56 |   }
57 |
58 |   /**
59 |    * Gain of a rating: looked up in the table (0, 1, 3, 7, 15) for ratings 0 to 4,
60 |    * and 0 for every rating outside that range (negative or above 4).
61 |    */
62 |   def gainFunc(rating: Int): Double = {
63 |     val table = Array(0, 1, 3, 7, 15)
64 |     if (rating >= 0 && rating < table.length) table(rating) else 0
65 |   }
66 |
67 |   /** Discount of a zero-based position: the base-2 logarithm of `index + 2`. */
68 |   def discountFunc(index: Int): Double =
69 |     math.log(index + 2) / math.log(2)
70 |
71 | }
69 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDenseKS.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.analysis
2 |
3 | import net.qihoo.xitong.xdml.feature.analysis.UniversalAnalyzer
4 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
5 | import org.apache.spark.sql.SparkSession
6 |
7 | /**
8 |  * Example driver: reads an original and a comparison data set, runs
9 |  * `UniversalAnalyzer.fitDenseKSForNum` on their numeric feature columns and saves the returned
10 |  * values, one per line, as a single-partition text file.
11 |  *
12 |  * Positional arguments: orgDataPath, orgDataDelimiter, orgSchemaPath, orgSchemaDelimiter,
13 |  * cmpDataPath, cmpDataDelimiter, cmpSchemaPath, cmpSchemaDelimiter, numPartition, nullValue,
14 |  * withHeader, savePath.
15 |  */
16 | object runUniversalAnalyzerDenseKS {
17 |
18 |   def main(args: Array[String]): Unit = {
19 |
20 |     LogHandler.avoidLog()
21 |
22 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
23 |     require(args.length >= 12,
24 |       "usage: runUniversalAnalyzerDenseKS <orgDataPath> <orgDataDelimiter> <orgSchemaPath> " +
25 |         "<orgSchemaDelimiter> <cmpDataPath> <cmpDataDelimiter> <cmpSchemaPath> <cmpSchemaDelimiter> " +
26 |         "<numPartition> <nullValue> <withHeader> <savePath>")
27 |     val orgDataPath = args(0)
28 |     val orgDataDelimiter = args(1)
29 |     val orgSchemaPath = args(2)
30 |     val orgSchemaDelimiter = args(3)
31 |     val cmpDataPath = args(4)
32 |     val cmpDataDelimiter = args(5)
33 |     val cmpSchemaPath = args(6)
34 |     val cmpSchemaDelimiter = args(7)
35 |     val numPartition = args(8).toInt
36 |     val nullValue = args(9)
37 |     val withHeader = args(10).toBoolean
38 |     val savePath = args(11)
39 |
40 |     val spark = SparkSession.builder()
41 |       .appName("runUniversalAnalyzerDenseKS").getOrCreate()
42 |
43 |     try {
44 |       val orgSchemaHandler = SchemaHandler.readSchema(spark.sparkContext, orgSchemaPath, orgSchemaDelimiter)
45 |       val cmpSchemaHandler = SchemaHandler.readSchema(spark.sparkContext, cmpSchemaPath, cmpSchemaDelimiter)
46 |       val orgDataDF = DataHandler.readData(spark, orgDataPath, orgDataDelimiter, orgSchemaHandler.schema, nullValue, numPartition, withHeader)
47 |       val cmpDataDF = DataHandler.readData(spark, cmpDataPath, cmpDataDelimiter, cmpSchemaHandler.schema, nullValue, numPartition, withHeader)
48 |
49 |       val time1 = System.currentTimeMillis()
50 |       val ksValues = UniversalAnalyzer.fitDenseKSForNum(orgDataDF, orgSchemaHandler.numFeatColNames, cmpDataDF, cmpSchemaHandler.numFeatColNames)
51 |       val time2 = System.currentTimeMillis()
52 |       // A two-argument println gets auto-tupled and prints "(spent time,123)"; format the message instead.
53 |       println(s"spent time: ${time2 - time1} ms")
54 |       spark.sparkContext.parallelize(Seq(ksValues.mkString("\n")))
55 |         .repartition(1).saveAsTextFile(savePath)
56 |     } finally {
57 |       // Stop the session even when reading, analysing or saving fails.
58 |       spark.stop()
59 |     }
60 |   }
61 |
62 | }
45 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerSparse.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.analysis
2 |
3 | import net.qihoo.xitong.xdml.feature.analysis._
4 | import net.qihoo.xitong.xdml.model.data.{LogHandler, SchemaHandler}
5 | import org.apache.spark.sql.SparkSession
6 | import scala.collection.mutable.HashMap
7 |
8 | /**
9 |  * Example driver: parses a sparse key/value text data set, runs `UniversalAnalyzer.fitSparse` on it
10 |  * and prints one summary per categorical feature and one per numeric feature.
11 |  *
12 |  * Positional arguments: dataPath, parseType, schemaPath, schemaDelimiter, hasLabel, numPartitions,
13 |  * sparseType, fineness. The only parseType accepted is "standard".
14 |  */
15 | object runUniversalAnalyzerSparse {
16 |
17 |   /** Turns one text record into a map from key to raw string value. */
18 |   trait ParseBase extends Serializable {
19 |     def parse(str: String): HashMap[String, String]
20 |   }
21 |
22 |   /**
23 |    * Parses records of the form `k1:v1,k2:v2,...`.
24 |    *
25 |    * Both delimiters are passed to `String.split` and are therefore treated as regular expressions.
26 |    * When a key occurs more than once the last occurrence wins. An element without the pair
27 |    * delimiter makes `parse` throw, because its value part is missing.
28 |    */
29 |   class StandardParse(stringDelimiter: String = ",", pairDelimiter: String = ":") extends ParseBase {
30 |     def parse(str: String): HashMap[String, String] = {
31 |       val pairs = str.split(stringDelimiter).map { elem =>
32 |         val kv = elem.split(pairDelimiter)
33 |         kv(0) -> kv(1)
34 |       }
35 |       HashMap(pairs: _*)
36 |     }
37 |   }
38 |
39 |   def main(args: Array[String]): Unit = {
40 |
41 |     LogHandler.avoidLog()
42 |
43 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
44 |     require(args.length >= 8,
45 |       "usage: runUniversalAnalyzerSparse <dataPath> <parseType> <schemaPath> <schemaDelimiter> " +
46 |         "<hasLabel> <numPartitions> <sparseType> <fineness>")
47 |     val dataPath = args(0)
48 |     val parseType = args(1)
49 |     val schemaPath = args(2)
50 |     val schemaDelimiter = args(3)
51 |     val hasLabel = args(4).toBoolean
52 |     val numPartitions = args(5).toInt
53 |     val sparseType = args(6)
54 |     val fineness = args(7).toInt
55 |     // fineness is the divisor of the grid built below; 0 would turn every grid point into NaN.
56 |     require(fineness > 0, s"fineness must be positive, got $fineness")
57 |
58 |     val spark = SparkSession.builder()
59 |       .appName("runUniversalAnalyzerSparse").getOrCreate()
60 |
61 |     try {
62 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
63 |       val tmpDataRDD = spark.sparkContext.textFile(dataPath, numPartitions)
64 |       val orgDataRDD = parseType match {
65 |         case "standard" =>
66 |           val standardParse = new StandardParse()
67 |           tmpDataRDD.map { str => standardParse.parse(str) }
68 |         case _ => throw new IllegalArgumentException("parse type unknown")
69 |       }
70 |
71 |       // Without a label column an empty name is handed to the analyzer.
72 |       val labelColName = if (hasLabel) schemaHandler.labelColName(0) else ""
73 |
74 |       // Evenly spaced grid 0, 1/fineness, ..., 1 (printed below under "quantiles").
75 |       val grid = Array.tabulate(fineness + 1)(i => 1.0 * i / fineness)
76 |
77 |       val (catFeatSummaryArray, numFeatSummaryArray) = UniversalAnalyzer.fitSparse(orgDataRDD, hasLabel,
78 |         labelColName, schemaHandler.catFeatColNames, schemaHandler.numFeatColNames,
79 |         sparseType, grid)
80 |
81 |       catFeatSummaryArray.foreach { summary =>
82 |         println("\n===========================================")
83 |         println(summary.name)
84 |         println(summary.concernedCategories.map(tup => tup._1).mkString(","))
85 |         println(s"mi: ${summary.mi}")
86 |         println(s"auc: ${summary.auc}")
87 |       }
88 |
89 |       numFeatSummaryArray.foreach { summary =>
90 |         println("\n===========================================")
91 |         println(summary.name)
92 |         println(s"countNotNull: ${summary.countNotNull}")
93 |         println(s"mean: ${summary.mean}")
94 |         println(s"std: ${summary.std}")
95 |         println(s"skewness: ${summary.skewness}")
96 |         println(s"kurtosis: ${summary.kurtosis}")
97 |         println(s"min: ${summary.min}")
98 |         println(s"max: ${summary.max}")
99 |         println(s"quantiles: ${summary.quantileArray.mkString(",")}")
100 |         println(s"corr: ${summary.corr}")
101 |         println(s"auc: ${summary.auc}")
102 |       }
103 |     } finally {
104 |       // Stop the session even when parsing or analysing fails.
105 |       spark.stop()
106 |     }
107 |
108 |   }
109 |
110 | }
88 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runCategoryEncoder.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.process
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.feature.CategoryEncoder
5 | import org.apache.spark.sql.SparkSession
6 |
7 |
8 | /**
9 |  * Example driver: fits a `CategoryEncoder` on the categorical columns of a delimited data set,
10 |  * transforms the same data with the fitted model and prints both tables plus timings.
11 |  *
12 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, numPartition, nullValue.
13 |  */
14 | object runCategoryEncoder {
15 |
16 |   def main(args: Array[String]): Unit = {
17 |
18 |     LogHandler.avoidLog()
19 |
20 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
21 |     require(args.length >= 6,
22 |       "usage: runCategoryEncoder <dataPath> <dataDelimiter> <schemaPath> <schemaDelimiter> " +
23 |         "<numPartition> <nullValue>")
24 |     val dataPath = args(0)
25 |     val dataDelimiter = args(1)
26 |     val schemaPath = args(2)
27 |     val schemaDelimiter = args(3)
28 |     val numPartition = args(4).toInt
29 |     val nullValue = args(5)
30 |
31 |     val spark = SparkSession.builder()
32 |       .appName("testCategoryEncoder").getOrCreate()
33 |
34 |     try {
35 |       // data input
36 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
37 |       val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
38 |       println("input table")
39 |       orgDataDF.show(false)
40 |
41 |       // Each output column is the input column name with the suffix "xdmlIndexed".
42 |       // Further setters left switched off here: setDropInputCols(true), setCategoriesReserved(5).
43 |       val indexer = new CategoryEncoder()
44 |         .setInputCols(schemaHandler.catFeatColNames)
45 |         .setOutputCols(schemaHandler.catFeatColNames.map(name => name + "xdmlIndexed"))
46 |         .setIndexOnly(true)
47 |
48 |       val startTime = System.currentTimeMillis()
49 |       val indexerModel = indexer.fit(orgDataDF)
50 |       val fitTime = System.currentTimeMillis()
51 |       val processedDataDF = indexerModel.transform(orgDataDF)
52 |       println("output table")
53 |       processedDataDF.show(false)
54 |       println(s"dataCount: ${processedDataDF.count()}")
55 |
56 |       // Taken after show() and count(), so the transform timing includes both actions.
57 |       val transTime = System.currentTimeMillis()
58 |       println(s"time for fitting: ${(fitTime - startTime) / 1000.0}")
59 |       println(s"time for transforming: ${(transTime - fitTime) / 1000.0}")
60 |     } finally {
61 |       // Stop the session even when reading, fitting or transforming fails.
62 |       spark.stop()
63 |     }
64 |
65 |   }
66 |
67 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runFeatureProcess.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.process
2 |
3 | import net.qihoo.xitong.xdml.feature.process.FeatureProcessor
4 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
5 | import org.apache.spark.ml._
6 | import org.apache.spark.sql.SparkSession
7 |
8 |
9 | /**
10 |  * Example driver for the feature-processing pipeline.
11 |  *
12 |  * With job type "fit_transform" it calls `FeatureProcessor.pipelineFitTransform` and saves the
13 |  * resulting pipeline model (overwriting any model at that path); with "transform" it loads a saved
14 |  * `PipelineModel` and applies it. Any other job type is rejected with an IllegalArgumentException.
15 |  * The processed table is written through `DataHandler.writeLibSVMData`.
16 |  *
17 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, numPartition,
18 |  * nullValue, jobType, methodForNum, ifOnehotForNum, ifOnehotForCat, ifOnehotForMultiCat,
19 |  * numBuckets, categoriesReservedForCat, categoriesReservedForMultiCat, multiCatDelimiter,
20 |  * pipelineModelPath, processedDataPath.
21 |  */
22 | object runFeatureProcess {
23 |
24 |   def main(args: Array[String]): Unit = {
25 |
26 |     LogHandler.avoidLog()
27 |
28 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
29 |     require(args.length >= 17,
30 |       "usage: runFeatureProcess <dataPath> <dataDelimiter> <schemaPath> <schemaDelimiter> " +
31 |         "<numPartition> <nullValue> <jobType> <methodForNum> <ifOnehotForNum> <ifOnehotForCat> " +
32 |         "<ifOnehotForMultiCat> <numBuckets> <categoriesReservedForCat> " +
33 |         "<categoriesReservedForMultiCat> <multiCatDelimiter> <pipelineModelPath> <processedDataPath>")
34 |
35 |     // data
36 |     val dataPath = args(0)
37 |     val dataDelimiter = args(1)
38 |     val schemaPath = args(2)
39 |     val schemaDelimiter = args(3)
40 |     val numPartition = args(4).toInt
41 |     val nullValue = args(5)
42 |
43 |     // job type
44 |     val jobType = args(6)
45 |
46 |     val methodForNum = args(7)
47 |     val ifOnehotForNum = args(8).toBoolean
48 |     val ifOnehotForCat = args(9).toBoolean
49 |     val ifOnehotForMultiCat = args(10).toBoolean
50 |
51 |     val numBuckets = args(11).toInt
52 |     val categoriesReservedForCat = args(12).toInt
53 |     val categoriesReservedForMultiCat = args(13).toInt
54 |     val multiCatDelimiter = args(14)
55 |
56 |     // pipeline model path
57 |     val pipelineModelPath = args(15)
58 |     // processed data path
59 |     val processedDataPath = args(16)
60 |
61 |     val spark = SparkSession.builder()
62 |       .appName("runFeatureProcess").getOrCreate()
63 |
64 |     try {
65 |       // data input
66 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
67 |       val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
68 |       println("input table")
69 |       orgDataDF.show(false)
70 |
71 |       val resDF = jobType match {
72 |         case "fit_transform" =>
73 |           val (resDFTmp, pipelineModelTmp) = FeatureProcessor.pipelineFitTransform(orgDataDF, schemaHandler,
74 |             methodForNum, ifOnehotForNum, ifOnehotForCat, ifOnehotForMultiCat, true,
75 |             numBuckets, categoriesReservedForCat, categoriesReservedForMultiCat, multiCatDelimiter)
76 |           pipelineModelTmp.write.overwrite().save(pipelineModelPath)
77 |           resDFTmp
78 |         case "transform" =>
79 |           PipelineModel.load(pipelineModelPath).transform(orgDataDF)
80 |         case _ => throw new IllegalArgumentException(s"Does not support job type: $jobType")
81 |       }
82 |
83 |       println("output table")
84 |       resDF.show(false)
85 |       DataHandler.writeLibSVMData(resDF, FeatureProcessor.labelProcessedColName,
86 |         FeatureProcessor.featsProcessedColName, processedDataPath)
87 |     } finally {
88 |       // Stop the session even when reading, fitting, transforming or writing fails.
89 |       spark.stop()
90 |     }
91 |
92 |   }
93 |
94 | }
76 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runMultiCategoryEncoder.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.process
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.feature.MultiCategoryEncoder
5 | import org.apache.spark.sql.SparkSession
6 |
7 |
8 | /**
9 |  * Example driver: fits a `MultiCategoryEncoder` on the multi-category columns of a delimited data
10 |  * set, transforms the same data with the fitted model and prints both tables plus timings.
11 |  *
12 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, numPartition,
13 |  * nullValue, multiCatDelimiter.
14 |  */
15 | object runMultiCategoryEncoder {
16 |
17 |   def main(args: Array[String]): Unit = {
18 |
19 |     LogHandler.avoidLog()
20 |
21 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
22 |     require(args.length >= 7,
23 |       "usage: runMultiCategoryEncoder <dataPath> <dataDelimiter> <schemaPath> <schemaDelimiter> " +
24 |         "<numPartition> <nullValue> <multiCatDelimiter>")
25 |     val dataPath = args(0)
26 |     val dataDelimiter = args(1)
27 |     val schemaPath = args(2)
28 |     val schemaDelimiter = args(3)
29 |     val numPartition = args(4).toInt
30 |     val nullValue = args(5)
31 |     val multiCatDelimiter = args(6)
32 |
33 |     val spark = SparkSession.builder()
34 |       .appName("testMultiCategoryEncoder").getOrCreate()
35 |
36 |     try {
37 |       // data input
38 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
39 |       val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
40 |       println("input table")
41 |       orgDataDF.show(false)
42 |
43 |       // Each output column is the input column name with the suffix "xdmlEncoded".
44 |       // A further setter left switched off here: setDropInputCols(true).
45 |       val encoder = new MultiCategoryEncoder()
46 |         .setInputCols(schemaHandler.multiCatFeatColNames)
47 |         .setOutputCols(schemaHandler.multiCatFeatColNames.map(name => name + "xdmlEncoded"))
48 |         .setDelimiter(multiCatDelimiter)
49 |         .setIndexOnly(true)
50 |
51 |       val startTime = System.currentTimeMillis()
52 |       val encoderModel = encoder.fit(orgDataDF)
53 |       val fitTime = System.currentTimeMillis()
54 |       val processedDataDF = encoderModel.transform(orgDataDF)
55 |       println("output table")
56 |       processedDataDF.show(false)
57 |       println(s"dataCount: ${processedDataDF.count()}")
58 |
59 |       // Taken after show() and count(), so the transform timing includes both actions.
60 |       val transTime = System.currentTimeMillis()
61 |       println(s"time for fitting: ${(fitTime - startTime) / 1000.0}")
62 |       println(s"time for transforming: ${(transTime - fitTime) / 1000.0}")
63 |     } finally {
64 |       // Stop the session even when reading, fitting or transforming fails.
65 |       spark.stop()
66 |     }
67 |
68 |   }
69 |
70 | }
56 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runNumericBucketer.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.process
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.feature.NumericBucketer
5 | import org.apache.spark.sql.SparkSession
6 |
7 | /**
8 |  * Example driver: fits a `NumericBucketer` on the numeric columns of a delimited data set,
9 |  * transforms the same data with the fitted model and prints both tables plus timings.
10 |  *
11 |  * Positional arguments: dataPath, dataDelimiter, schemaPath, schemaDelimiter, numPartition,
12 |  * nullValue, numBuckets. The last one is a comma-separated list of integers, and the list must
13 |  * have exactly one entry per numeric feature column of the schema.
14 |  */
15 | object runNumericBucketer {
16 |
17 |   def main(args: Array[String]): Unit = {
18 |
19 |     LogHandler.avoidLog()
20 |
21 |     // Check the command line before a session exists, so a bad invocation cannot leave one running.
22 |     require(args.length >= 7,
23 |       "usage: runNumericBucketer <dataPath> <dataDelimiter> <schemaPath> <schemaDelimiter> " +
24 |         "<numPartition> <nullValue> <numBuckets,comma-separated>")
25 |     val dataPath = args(0)
26 |     val dataDelimiter = args(1)
27 |     val schemaPath = args(2)
28 |     val schemaDelimiter = args(3)
29 |     val numPartition = args(4).toInt
30 |     val nullValue = args(5)
31 |
32 |     val numBucketsArray = args(6).split(",").map(_.toInt)
33 |
34 |     val spark = SparkSession.builder()
35 |       .appName("runNumericBucketer").getOrCreate()
36 |
37 |     try {
38 |       // data input
39 |       val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
40 |       val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
41 |       println("input table")
42 |       orgDataDF.show(false)
43 |
44 |       require(numBucketsArray.length == schemaHandler.numFeatColNames.length, "invalid numBucketsArray size")
45 |
46 |       // Each output column is the input column name with the suffix "xdmlBucketed".
47 |       val numericBucketer = new NumericBucketer()
48 |         .setInputCols(schemaHandler.numFeatColNames)
49 |         .setOutputCols(schemaHandler.numFeatColNames.map(name => name + "xdmlBucketed"))
50 |         .setNumBucketsArray(numBucketsArray)
51 |         .setIndexOnly(false)
52 |         .setOutputSparse(true)
53 |         .setDropInputCols(true)
54 |
55 |       val startTime = System.currentTimeMillis()
56 |       val numericBucketerModel = numericBucketer.fit(orgDataDF)
57 |       val fitTime = System.currentTimeMillis()
58 |
59 |       val processedDataDF = numericBucketerModel.transform(orgDataDF)
60 |       println("output table")
61 |       processedDataDF.show(false)
62 |       println(s"dataCount: ${processedDataDF.count()}")
63 |
64 |       // Taken after show() and count(), so the transform timing includes both actions.
65 |       val transTime = System.currentTimeMillis()
66 |       println(s"time for fitting: ${(fitTime - startTime) / 1000.0}")
67 |       println(s"time for transforming: ${(transTime - fitTime) / 1000.0}")
68 |     } finally {
69 |       // Stop the session even when reading, fitting or transforming fails.
70 |       spark.stop()
71 |     }
72 |   }
73 |
74 | }
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runNumericStandardizer.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.feature.process
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.feature.NumericStandardizer
5 | import org.apache.spark.sql.SparkSession
6 |
7 |
/**
  * Example driver for NumericStandardizer.
  *
  * Reads a delimited data set with the given schema, fits a NumericStandardizer on the numeric
  * feature columns, then prints the transformed table, its row count and the time (in seconds)
  * spent in the fit and transform phases.
  *
  * Positional arguments:
  *   0 - data path
  *   1 - data delimiter
  *   2 - schema path
  *   3 - schema delimiter
  *   4 - number of partitions (Int)
  *   5 - null value marker
  */
object runNumericStandardizer {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val session = SparkSession.builder().appName("runNumericStandardizer").getOrCreate()

    // command-line arguments
    val inputPath = args(0)
    val inputDelimiter = args(1)
    val schemaFile = args(2)
    val schemaSep = args(3)
    val partitions = args(4).toInt
    val nullMarker = args(5)

    // load schema and data
    val schemaInfo = SchemaHandler.readSchema(session.sparkContext, schemaFile, schemaSep)
    val inputDF =
      DataHandler.readData(session, inputPath, inputDelimiter, schemaInfo.schema, nullMarker, partitions)
    println("input table")
    inputDF.show(false)

    val numericCols = schemaInfo.numFeatColNames
    val standardizer = new NumericStandardizer()
      .setInputCols(numericCols)
      .setOutputCols(numericCols.map(_ + "XDMLStandardized"))

    val fitStart = System.currentTimeMillis()
    val standardizerModel = standardizer.fit(inputDF)
    val fitEnd = System.currentTimeMillis()

    val outputDF = standardizerModel.transform(inputDF)
    println("output table")
    outputDF.show(false)
    println(s"dataCount: ${outputDF.count()}")

    // the transform timing includes the show() and count() actions above
    val transformEnd = System.currentTimeMillis()
    println(s"time for fitting: ${(fitEnd - fitStart) / 1000.0}")
    println(s"time for transforming: ${(transformEnd - fitEnd) / 1000.0}")

    session.stop()

  }

}
53 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2ODRF.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.model.supervised.{H2ODRF, H2ODRFModel}
5 | import org.apache.spark.sql.SparkSession
6 |
/**
  * Example driver for H2ODRF: trains on dense delimited data, saves the model, reloads it and,
  * when a validation path is supplied, prints the output schema and predictions for that data.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip validation scoring), 2 dataDelimiter, 3 schemaPath,
  *   4 schemaDelimiter, 5 numPartition (Int), 6 nullValue, 7 modelPath,
  *   8 maxDepth (Int), 9 numTrees (Int), 10 maxBinsForCat (Int), 11 maxBinsForNum (Int),
  *   12 minInstancesPerNode (Int), 13 categoricalEncodingScheme, 14 histogramType,
  *   15 distribution, 16 scoreTreeInterval (Int)
  */
object runFromDenseDataToXDMLH2ODRF {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2ODRF")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val maxDepth = args(8).toInt
    val numTrees = args(9).toInt
    val maxBinsForCat = args(10).toInt
    val maxBinsForNum = args(11).toInt
    val minInstancesPerNode = args(12).toInt
    val categoricalEncodingScheme = args(13)
    val histogramType = args(14)
    val distribution = args(15)
    val scoreTreeInterval = args(16).toInt

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF =
      DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // The validation set is optional; model it as an Option instead of a null reference so the
    // scoring step below cannot dereference a missing DataFrame.
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    val h2oDRF = new H2ODRF()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMaxDepth(maxDepth)
      .setNumTrees(numTrees)
      .setMaxBinsForCat(maxBinsForCat)
      .setMaxBinsForNum(maxBinsForNum)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setCategoricalEncodingScheme(categoricalEncodingScheme.trim())
      .setHistogramType(histogramType.trim())
      .setDistribution(distribution.trim())
      .setScoreTreeInterval(scoreTreeInterval)

    // train, persist, then reload to exercise the save/load round trip
    val drfModel = h2oDRF.fit(orgTrainDataDF)
    drfModel.write.overwrite().save(modelPath)
    val drfModel2 = H2ODRFModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        drfModel2.transformSchema(validDF.schema).printTreeString()
        val predDataDF = drfModel2.transform(validDF)
        predDataDF.show(false)
      case None =>
        println("validPath is empty, skipping prediction on validation data")
    }

    spark.stop()
  }
}
74 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OGBM.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.model.supervised.{H2OGBM, H2OGBMModel}
5 | import org.apache.spark.sql.SparkSession
6 |
/**
  * Example driver for H2OGBM: trains on dense delimited data, saves the model, reloads it and,
  * when a validation path is supplied, prints the output schema and predictions for that data.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip validation scoring), 2 dataDelimiter, 3 schemaPath,
  *   4 schemaDelimiter, 5 numPartition (Int), 6 nullValue, 7 modelPath,
  *   8 maxDepth (Int), 9 numTrees (Int), 10 maxBinsForCat (Int), 11 maxBinsForNum (Int),
  *   12 minInstancesPerNode (Int), 13 categoricalEncodingScheme, 14 histogramType,
  *   15 learnRate (Double), 16 learnRateAnnealing (Double), 17 distribution,
  *   18 scoreTreeInterval (Int)
  */
object runFromDenseDataToXDMLH2OGBM {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OGBM")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)

    // model path
    val modelPath = args(7)

    // hyperparameter
    val maxDepth = args(8).toInt
    val numTrees = args(9).toInt
    val maxBinsForCat = args(10).toInt
    val maxBinsForNum = args(11).toInt
    val minInstancesPerNode = args(12).toInt
    val categoricalEncodingScheme = args(13)
    val histogramType = args(14)
    val learnRate = args(15).toDouble
    val learnRateAnnealing = args(16).toDouble
    val distribution = args(17)
    val scoreTreeInterval = args(18).toInt

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF =
      DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // The validation set is optional; model it as an Option instead of a null reference so the
    // scoring step below cannot dereference a missing DataFrame.
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    val h2oGBM = new H2OGBM()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMaxDepth(maxDepth)
      .setNumTrees(numTrees)
      .setMaxBinsForCat(maxBinsForCat)
      .setMaxBinsForNum(maxBinsForNum)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setCategoricalEncodingScheme(categoricalEncodingScheme)
      .setHistogramType(histogramType)
      .setLearnRate(learnRate)
      .setLearnRateAnnealing(learnRateAnnealing)
      .setDistribution(distribution)
      .setScoreTreeInterval(scoreTreeInterval)

    // train, persist, then reload to exercise the save/load round trip
    val gbmModel = h2oGBM.fit(orgTrainDataDF)
    gbmModel.write.overwrite().save(modelPath)
    val gbmModel2 = H2OGBMModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        gbmModel2.transformSchema(validDF.schema).printTreeString()
        val predDataDF = gbmModel2.transform(validDF)
        predDataDF.show(false)
      case None =>
        println("validPath is empty, skipping prediction on validation data")
    }

    spark.stop()
  }

}
81 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OGLM.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.model.supervised.{H2OGLM, H2OGLMModel}
5 | import org.apache.spark.sql.SparkSession
6 |
7 |
/**
  * Example driver for H2OGLM: trains on dense delimited data, saves the model, reloads it and,
  * when a validation path is supplied, prints predictions for that data.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip validation scoring), 2 dataDelimiter, 3 schemaPath,
  *   4 schemaDelimiter, 5 numPartition (Int), 6 nullValue, 7 modelPath,
  *   8 family, 9 maxIter (Int), 10 alpha (Double), 11 lambda (Double),
  *   12 missingValuesHandling, 13 solver (currently not forwarded to the estimator),
  *   14 standardization (Boolean), 15 fitIntercept (Boolean)
  */
object runFromDenseDataToXDMLH2OGLM {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OGLM")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val family = args(8)
    val maxIter = args(9).toInt
    val alpha = args(10).toDouble
    val lambda = args(11).toDouble
    val missingValuesHandling = args(12)
    // args(13) is the solver slot; it is kept in the positional layout but this example never
    // passes it to the estimator, so it is not bound to a local.
    val standardization = args(14).toBoolean
    val fitIntercept = args(15).toBoolean

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF =
      DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // The validation set is optional; model it as an Option instead of a null reference so the
    // scoring step below cannot be handed a missing DataFrame.
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    val h2oGLM = new H2OGLM()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setFamily(family)
      .setMaxIter(maxIter)
      .setAlpha(alpha)
      .setLambda(lambda)
      .setFitIntercept(fitIntercept)
      .setStandardization(standardization)
      .setMissingValueHandling(missingValuesHandling)

    // train, persist, then reload to exercise the save/load round trip
    val glmModel = h2oGLM.fit(orgTrainDataDF)
    glmModel.write.overwrite().save(modelPath)
    val glmModel2 = H2OGLMModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        val predDF = glmModel2.transform(validDF)
        predDF.show(false)
      case None =>
        println("validPath is empty, skipping prediction on validation data")
    }

    spark.stop()
  }
}
70 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OMLP.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
4 | import org.apache.spark.ml.model.supervised.{H2OMLP, H2OMLPModel}
5 | import org.apache.spark.sql.SparkSession
6 |
/**
  * Example driver for H2OMLP: trains on dense delimited data, saves the model, reloads it and,
  * when a validation path is supplied, prints the output schema and predictions for that data.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip validation scoring), 2 dataDelimiter, 3 schemaPath,
  *   4 schemaDelimiter, 5 numPartition (Int), 6 nullValue, 7 modelPath,
  *   8 missingValuesHandling, 9 categoricalEncodingScheme, 10 distribution,
  *   11 hidden (comma-separated Ints), 12 activation, 13 epochs (Double),
  *   14 learningRate (Double), 15 momentumStart (Double), 16 momentumStable (Double),
  *   17 l1 (Double), 18 l2 (Double), 19 hiddenDropoutRatios (comma-separated Doubles, may be blank),
  *   20 elasticAveraging (Boolean), 21 standardization (Boolean)
  */
object runFromDenseDataToXDMLH2OMLP {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OMLP")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val missingValuesHandling = args(8)
    val categoricalEncodingScheme = args(9)
    val distribution = args(10)
    val hidden = args(11).split(",").map(_.toInt)
    val activation = args(12)
    val epochs = args(13).toDouble
    val learningRate = args(14).toDouble
    val momentumStart = args(15).toDouble
    val momentumStable = args(16).toDouble
    val l1 = args(17).toDouble
    val l2 = args(18).toDouble
    val hiddenDropoutRatiosText = args(19)
    val elasticAveraging = args(20).toBoolean
    val standardization = args(21).toBoolean

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF =
      DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // The validation set is optional; model it as an Option instead of a null reference so the
    // scoring step below cannot dereference a missing DataFrame.
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    // a blank ratios argument yields an empty array
    val hiddenDropoutRatios =
      if (hiddenDropoutRatiosText.trim.nonEmpty) hiddenDropoutRatiosText.split(",").map(_.toDouble)
      else Array[Double]()

    val h2oMLP = new H2OMLP()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMissingValueHandling(missingValuesHandling)
      .setCategoricalEncodingScheme(categoricalEncodingScheme)
      .setDistribution(distribution)
      .setHidden(hidden)
      .setActivation(activation)
      .setEpochs(epochs)
      .setLearnRate(learningRate)
      .setMomentumStart(momentumStart)
      .setMomentumStable(momentumStable)
      .setL1(l1)
      .setL2(l2)
      .setHiddenDropoutRatios(hiddenDropoutRatios)
      .setElasticAveraging(elasticAveraging)
      .setStandardization(standardization)

    // train (second fit argument is passed as null, as in the original example), persist,
    // then reload to exercise the save/load round trip
    val mlpModel = h2oMLP.fit(orgTrainDataDF, null)
    mlpModel.write.overwrite().save(modelPath)
    val mlpModel2 = H2OMLPModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        mlpModel2.transformSchema(validDF.schema).printTreeString()
        val predDataDF = mlpModel2.transform(validDF)
        predDataDF.show(false)
      case None =>
        println("validPath is empty, skipping prediction on validation data")
    }

    spark.stop()
  }
}
85 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLLinearScopeModel.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.LogHandler
4 | import org.apache.spark.ml.model.supervised.LinearScope
5 | import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
6 | import org.apache.spark.sql.data.DataProcessor
7 | import org.apache.spark.sql.functions._
8 | import org.apache.spark.sql.types.DoubleType
9 | import org.apache.spark.sql.{Row, SparkSession}
10 |
/**
  * Example driver for LinearScope on LibSVM-formatted data.
  *
  * Fits the model on the training data and reports AUROC on the validation data when a
  * validation path is given, otherwise on the training data itself.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to evaluate on the training data), 2 numFeatures (Int),
  *   3 numPartitions (Int), 4 oneBased (Boolean), 5 needSort (Boolean),
  *   6 rescaleBinaryLabel (Boolean), 7 stepSize (Double), 8 numIterations (Int),
  *   9 regParam (Double), 10 elasticNetParam (Double), 11 factor (Double),
  *   12 fitIntercept (Boolean), 13 lossType, 14 posWeight (Double)
  */
object runFromLibSVMDataToXDMLLinearScopeModel {
  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()
    val session = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLLinearScopeModel")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val regParam = args(9).toDouble
    val elasticNetParam = args(10).toDouble
    val factor = args(11).toDouble
    val fitIntercept = args(12).toBoolean
    val lossType = args(13)
    val posWeight = args(14).toDouble

    val hasValidation = validPath.trim.length > 0

    val trainData = DataProcessor.readLibSVMDataAsDF(
      session, trainPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
    trainData.show(2)

    // stays null when no validation path was given; it is only used when hasValidation is true
    val validData =
      if (hasValidation)
        DataProcessor.readLibSVMDataAsDF(
          session, validPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
      else
        null
    Option(validData).foreach(_.show(2))

    val convergenceTol = 0.2
    val estimator = new LinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setMaxIter(numIterations)
      .setStepSize(stepSize)
      .setLossFunc(lossType)
      .setRegParam(regParam)
      .setElasticNetParam(elasticNetParam)
      .setFactor(factor)
      .setFitIntercept(fitIntercept)
      .setPosWeight(posWeight)
      .setConvergenceTol(convergenceTol)
      .setNumPartitions(numPartitions)

    val model = estimator.fit(trainData)
    // Initial weights can also be supplied:
    //   val initialWeights = Vectors.zeros(numFeatures)
    //   val model = linearScope.fit(trainData, initialWeights)
    model.setRawPredictionCol("prediction")
      .setProbabilityCol("probability")

    // pick the data set to evaluate on and the matching label for the report line
    val (evalData, phase) =
      if (hasValidation) (validData, "Validating") else (trainData, "Training")

    val scored = model.transform(evalData)
    scored.show(5)
    val scoreAndLabel = scored
      .select(col("prediction"), col("label").cast(DoubleType))
      .rdd
      .map { case Row(pred: Double, label: Double) => (pred, label) }
    val metrics = new BinaryClassificationMetrics(scoreAndLabel)
    println(s"Model $phase AUROC: ${metrics.areaUnderROC()}")

    session.stop()
  }

}
89 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLOVR.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.LogHandler
4 | import org.apache.spark.sql.SparkSession
5 | import org.apache.spark.sql.data.DataProcessor
6 | import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
7 | import org.apache.spark.ml.model.supervised.OVRLinearScope
8 |
/**
  * Example driver for OVRLinearScope on LibSVM-formatted data.
  *
  * Fits the model on the training data. When a validation path is given, the validation data is
  * scored and evaluated: accuracy when numClasses is greater than 2, AUROC otherwise.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip evaluation), 2 numFeatures (Int),
  *   3 numPartitions (Int), 4 oneBased (Boolean), 5 needSort (Boolean),
  *   6 rescaleBinaryLabel (Boolean), 7 stepSize (Double), 8 numIterations (Int),
  *   9 numClasses (Int), 10 lossType
  */
object runFromLibSVMDataToXDMLOVR {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val session = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLOVR")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val numClasses = args(9).toInt
    val lossType = args(10)

    val hasValidation = validPath.trim.length > 0

    val trainDF = DataProcessor.readLibSVMDataAsDF(
      session, trainPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
    // stays null when no validation path was given; it is only used when hasValidation is true
    val validDF =
      if (hasValidation)
        DataProcessor.readLibSVMDataAsDF(
          session, validPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
      else
        null

    val estimator = new OVRLinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setStepSize(stepSize)
      .setMaxIter(numIterations)
      .setNumClasses(numClasses)
      .setNumPartitions(numPartitions)
      .setLossFunc(lossType)

    val fittedModel = estimator.fit(trainDF)

    if (hasValidation) {

      val scored = fittedModel.transform(validDF)
      scored.show(false)

      if (numClasses <= 2) {
        // binary case: area under the ROC curve on the raw prediction column
        val aurocEvaluator = new BinaryClassificationEvaluator()
          .setLabelCol("label")
          .setRawPredictionCol("rawPrediction")
          .setMetricName("areaUnderROC")
        println(s"Model Validating AUROC: ${aurocEvaluator.evaluate(scored)}")
      } else {
        // multiclass case: plain accuracy on the prediction column
        val accuracyEvaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        println(s"Model Validating accuracy: ${accuracyEvaluator.evaluate(scored)}")
      }

    }

    session.stop()

  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLSR.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.analysis.model
2 |
3 | import net.qihoo.xitong.xdml.model.data.LogHandler
4 | import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
5 | import org.apache.spark.ml.model.supervised.MultiLinearScope
6 | import org.apache.spark.sql.SparkSession
7 | import org.apache.spark.sql.data.DataProcessor
8 |
/**
  * Example driver for MultiLinearScope on LibSVM-formatted data.
  *
  * Fits the model on the training data. When a validation path is given, the validation data is
  * scored and evaluated: accuracy when numClasses is greater than 2, AUROC otherwise.
  *
  * Positional arguments:
  *   0 trainPath, 1 validPath (blank to skip evaluation), 2 numFeatures (Int),
  *   3 numPartitions (Int), 4 oneBased (Boolean), 5 needSort (Boolean),
  *   6 rescaleBinaryLabel (Boolean), 7 stepSize (Double), 8 numIterations (Int),
  *   9 numClasses (Int)
  */
object runFromLibSVMDataToXDMLSR {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val session = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLSR")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val numClasses = args(9).toInt

    val hasValidation = validPath.trim.length > 0

    val trainDF = DataProcessor.readLibSVMDataAsDF(
      session, trainPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
    // stays null when no validation path was given; it is only used when hasValidation is true
    val validDF =
      if (hasValidation)
        DataProcessor.readLibSVMDataAsDF(
          session, validPath, numFeatures, numPartitions, oneBased, needSort, rescaleBinaryLabel)
      else
        null

    val estimator = new MultiLinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setStepSize(stepSize)
      .setMaxIter(numIterations)
      .setNumClasses(numClasses)
      .setNumPartitions(numPartitions)

    val fittedModel = estimator.fit(trainDF)

    if (hasValidation) {

      val scored = fittedModel.transform(validDF)
      scored.show(false)

      if (numClasses <= 2) {
        // binary case: area under the ROC curve on the raw prediction column
        val aurocEvaluator = new BinaryClassificationEvaluator()
          .setLabelCol("label")
          .setRawPredictionCol("rawPrediction")
          .setMetricName("areaUnderROC")
        println(s"Model Validating AUROC: ${aurocEvaluator.evaluate(scored)}")
      } else {
        // multiclass case: plain accuracy on the prediction column
        val accuracyEvaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        println(s"Model Validating accuracy: ${accuracyEvaluator.evaluate(scored)}")
      }

    }

    session.stop()

  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/DCASGDTest.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.ml
2 |
3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
5 | import net.qihoo.xitong.xdml.ml.LogisticRegressionWithDCASGD
6 | import net.qihoo.xitong.xdml.ps.PS
7 | import net.qihoo.xitong.xdml.utils.XDMLException
8 | import org.apache.hadoop.fs.Path
9 | import org.apache.spark.{SparkConf, SparkContext}
10 |
/**
  * Example driver for LogisticRegressionWithDCASGD backed by the parameter server.
  *
  * All settings are read from the job configuration (see the JobConfiguration keys used below).
  * Depending on the configured job type the driver trains and saves a model, predicts and writes
  * the results, or continues training from an existing model.
  */
object DCASGDTest {
  def main(args: Array[String]): Unit = {
    // The application name previously read "Momentum-Test" (copied from the Momentum example),
    // which mislabelled this job in the Spark UI and logs.
    val conf = new SparkConf()
      .setAppName("DCASGD-Test")
    val sc = new SparkContext(conf)
    //read spark config
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val coff = jobConf.getOrElse(JobConfiguration.DC_ASGD_COFF, "0.1").toFloat
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    // the PS table name is prefixed with the submitting user's name
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "DC-ASGD")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    //read data
    val rawData = sc.textFile(dataPath)
    val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    val psConf = new PSConfiguration()
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(2)
      .setKuduMaster(kuduMaster)
    // job type is matched case-insensitively; PREDICT and INCREMENT_TRAIN need an existing model
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        if (modelPath.equals(""))
          throw new XDMLException("Predict job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.equals(""))
          throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegressionWithDCASGD(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)
      .setDcAsgdCoff(coff)
    psConf.getJobType match {
      case JobType.TRAIN => {
        //start train
        println("Start Train...")
        model.fit(data)
        println(s"Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
      case JobType.PREDICT => {
        //start predict
        println("Start Predict...")
        val result = model.predict(data)
        // remove a pre-existing result directory so saveAsTextFile does not fail on it
        val path = new Path(resultPath)
        val fs = path.getFileSystem(sc.hadoopConfiguration)
        if (fs.exists(path)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fs.delete(path, true)
        }
        result.saveAsTextFile(resultPath)
        // share of rows whose first element, thresholded at 0.5, equals the second element
        val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble
        println("right rate is : " + rightRate)
      }
      case JobType.INCREMENT_TRAIN => {
        //start train
        println("Start Increment Train...")
        model.fit(data)
        println(s"Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
    }
    sc.stop()
  }
}
101 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/FFMTest.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.ml
2 |
3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
4 | import net.qihoo.xitong.xdml.dataProcess.FFMProcessor
5 | import net.qihoo.xitong.xdml.ml.FieldawareFactorizationMachine
6 | import net.qihoo.xitong.xdml.ps.PS
7 | import net.qihoo.xitong.xdml.utils.XDMLException
8 | import org.apache.spark.{SparkConf, SparkContext}
9 |
/**
  * Example driver for FieldawareFactorizationMachine backed by the parameter server.
  *
  * All settings are read from the job configuration (see the JobConfiguration keys used below).
  * TRAIN and INCREMENT_TRAIN fit the model and save it to the model path; the PREDICT branch
  * currently only prints a start message.
  */
object FFMTest {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("FFM-Test")
    val sc = new SparkContext(sparkConf)

    // job settings, each with a fallback default
    val settings = new JobConfiguration().readJobConfig(sc)
    val dataPath = settings.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = settings.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = settings.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = settings.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = settings.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = settings.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = settings.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val hzClusterNum = settings.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = settings.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val rank = settings.getOrElse(JobConfiguration.FFM_RANK, "1").toInt
    val field = settings.getOrElse(JobConfiguration.FFM_FIELD, "1").toInt
    // the PS table name is prefixed with the submitting user's name
    val tableName = System.getProperty("user.name") + "_" + settings.getOrElse(JobConfiguration.TABLE_NAME, "FFM")
    val kuduMaster = settings.getOrElse(JobConfiguration.KUDU_MASTER, "")
    // read for parity with the other examples; the predict branch below does not use it
    val resultPath = settings.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = settings.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.isEmpty) {
      throw new XDMLException("kudu master must be set!")
    }

    // load and parse the input
    val lines = sc.textFile(dataPath)
    val data = FFMProcessor.processData(lines, split).coalesce(dataPartitionNum)

    val psConf = new PSConfiguration()
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(rank * field)
      .setKuduMaster(kuduMaster)

    // job type is matched case-insensitively; PREDICT and INCREMENT_TRAIN need an existing model
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        if (modelPath.isEmpty) throw new XDMLException("Predict job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.isEmpty) throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }

    val ps = PS.getInstance(sc, psConf)
    val model = new FieldawareFactorizationMachine(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)
      .setField(field)
      .setRank(rank)

    psConf.getJobType match {
      case JobType.TRAIN =>
        println("Start Train...")
        val info = model.fit(data)
        println(s"Save the model to path :$modelPath")
        ps.saveModel(sc, modelPath)
      case JobType.PREDICT =>
        println("Start Predict...")
      case JobType.INCREMENT_TRAIN =>
        println("Start Increment Train...")
        val info = model.fit(data)
        println(s"Save the new model to path :$modelPath")
        ps.saveModel(sc, modelPath)
    }
    sc.stop()
  }
}
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/FTRLTest.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.ml
2 |
3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
5 | import net.qihoo.xitong.xdml.ml.LogisticRegressionWithFTRL
6 | import net.qihoo.xitong.xdml.ps.PS
7 | import net.qihoo.xitong.xdml.utils.XDMLException
8 | import org.apache.hadoop.fs.Path
9 | import org.apache.spark.{SparkConf, SparkContext}
10 |
/**
 * Example Spark driver that runs [[LogisticRegressionWithFTRL]] against the parameter server.
 *
 * All settings are read through `JobConfiguration.readJobConfig`; the job type
 * ("train", "predict" or "increment_train", case-insensitive) selects what `main` does:
 *  - TRAIN / INCREMENT_TRAIN: fit the model on the input data and save it to the model path.
 *  - PREDICT: score the input data, write the result to the result path (replacing an
 *    existing directory) and print the fraction of rows whose thresholded first element
 *    equals the second element.
 */
object FTRLTest {
  /**
   * Entry point.
   *
   * @param args unused; configuration comes from the job configuration read via the SparkContext
   * @throws XDMLException if the kudu master is not set, the job type is unknown, or a
   *                       required model/result path is missing for the chosen job type
   */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
      .setAppName("FTRL-Test")
    val sc = new SparkContext(conf)
    // read the job settings; the second argument of each getOrElse is the fallback value
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val alpha = jobConf.getOrElse(JobConfiguration.FTRL_ALPHA, "1.0").toFloat
    val beta = jobConf.getOrElse(JobConfiguration.FTRL_BETA, "1.0").toFloat
    val lambda1 = jobConf.getOrElse(JobConfiguration.FTRL_LAMBDA1, "1.0").toFloat
    val lambda2 = jobConf.getOrElse(JobConfiguration.FTRL_LAMBDA2, "1.0").toFloat
    val forceSparse = jobConf.getOrElse(JobConfiguration.FTRL_FROCESPARSE, "false").toBoolean
    // the PS table name is prefixed with the submitting user's name
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "FTRL")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    // read data
    val rawData = sc.textFile(dataPath)
    val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    // NOTE(review): presumably 3 floats per key hold the FTRL state — confirm against the updater
    val psConf = new PSConfiguration()
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(3)
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setForceSparse(forceSparse)
      .setKuduMaster(kuduMaster)
    if (jobType.toUpperCase.equals("TRAIN")) {
      psConf.setJobType(JobType.TRAIN)
    } else if (jobType.toUpperCase.equals("PREDICT")) {
      // Fail fast: the predict branch below builds `new Path(resultPath)`, which cannot be
      // created from an empty string, so reject a missing result path up front (as LRTest does).
      if (modelPath.equals("") || resultPath.equals(""))
        throw new XDMLException("Predict job must have a model path and result path!")
      else
        psConf.setPredictModelPath(modelPath)
      psConf.setJobType(JobType.PREDICT)
    } else if (jobType.toUpperCase.equals("INCREMENT_TRAIN")) {
      if (modelPath.equals(""))
        throw new XDMLException("INCREMENT_TRAIN job must have a model path")
      else
        psConf.setPredictModelPath(modelPath)
      psConf.setJobType(JobType.INCREMENT_TRAIN)
    } else {
      throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegressionWithFTRL(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setAlpha(alpha)
      .setBeta(beta)
      .setLambda1(lambda1)
      .setLambda2(lambda2)
    psConf.getJobType match {
      case JobType.TRAIN => {
        // start train
        println("Start Train...")
        model.fit(data)
        println(s"Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
      case JobType.PREDICT => {
        // start predict
        println("Start Predict...")
        val result = model.predict(data)
        val path = new Path(resultPath)
        val fs = path.getFileSystem(sc.hadoopConfiguration)
        // saveAsTextFile refuses to overwrite, so remove a previous result first
        if (fs.exists(path)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fs.delete(path, true)
        }
        result.saveAsTextFile(resultPath)
        // share of rows where (first element > 0.5 ? 1 : 0) equals the second element
        val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble
        println("right rate is : " + rightRate)
      }
      case JobType.INCREMENT_TRAIN => {
        // continue training from the configured model, then save it back
        println("Start Increment Train...")
        model.fit(data)
        println(s"Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
    }
    sc.stop()
  }
}
108 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/LRTest.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.ml
2 |
3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
5 | import net.qihoo.xitong.xdml.ml.LogisticRegression
6 | import net.qihoo.xitong.xdml.ps.PS
7 | import net.qihoo.xitong.xdml.utils.XDMLException
8 | import org.apache.hadoop.fs.Path
9 | import org.apache.spark.{SparkConf, SparkContext}
10 |
/**
 * Example Spark driver for [[LogisticRegression]] on the parameter server.
 *
 * The job type read from the job configuration ("train", "predict" or
 * "increment_train", case-insensitive) decides whether the model is fitted and
 * saved, or used to score the input and write the result.
 */
object LRTest {
  /**
   * Entry point.
   *
   * @param args unused; every setting is read from the job configuration
   * @throws XDMLException when the kudu master is missing, the job type is unknown, or a
   *                       path required by the selected job type is empty
   */
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LR-Test"))

    // Job settings, each with the fallback used when the key is absent.
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "LR")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")

    if (kuduMaster.isEmpty) {
      throw new XDMLException("kudu master must be set!")
    }

    // Input rows, parsed by LibSVMProcessor and coalesced to the configured partition count.
    val data = LibSVMProcessor
      .processData(sc.textFile(dataPath), split)
      .coalesce(dataPartitionNum)

    // Parameter-server settings: one FLOAT value per key.
    val psConf = new PSConfiguration()
      .setPsDataType(PSDataType.FLOAT)
      .setPsDataLength(1)
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setKuduMaster(kuduMaster)

    // Translate the textual job type, validating the paths each mode needs.
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        if (modelPath.isEmpty || resultPath.isEmpty) {
          throw new XDMLException("Predict job must have a model path and result path!")
        }
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.isEmpty) {
          throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        }
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }

    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegression(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)

    // Shared by both training modes: announce, fit, announce the target path, save.
    def fitAndSave(startMessage: String, saveMessagePrefix: String): Unit = {
      println(startMessage)
      model.fit(data)
      println(saveMessagePrefix + modelPath)
      ps.saveModel(sc, modelPath)
    }

    psConf.getJobType match {
      case JobType.TRAIN =>
        fitAndSave("Start Train...", "Save the model to path :")

      case JobType.PREDICT =>
        println("Start Predict...")
        val result = model.predict(data)
        val outputDir = new Path(resultPath)
        val fileSystem = outputDir.getFileSystem(sc.hadoopConfiguration)
        // drop a previous result directory before writing the new one
        if (fileSystem.exists(outputDir)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fileSystem.delete(outputDir, true)
        }
        result.saveAsTextFile(resultPath)
        // rows whose first element, thresholded at 0.5, equals the second element
        val matching = result
          .filter { case (first, second) => (if (first > 0.5) 1 else 0) == second }
          .count()
          .toDouble
        val total = result.count().toDouble
        val rightRate = matching / total
        println("right rate is : " + rightRate)

      case JobType.INCREMENT_TRAIN =>
        fitAndSave("Start Increment Train...", "Save the new model to path :")
    }

    sc.stop()
  }
}
99 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/MomentumTest.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.example.ml
2 |
3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
5 | import net.qihoo.xitong.xdml.ml.LogisticRegressionWithMomentum
6 | import net.qihoo.xitong.xdml.ps.PS
7 | import net.qihoo.xitong.xdml.utils.XDMLException
8 | import org.apache.hadoop.fs.Path
9 | import org.apache.spark.{SparkConf, SparkContext}
10 |
/**
 * Example Spark driver that runs [[LogisticRegressionWithMomentum]] against the parameter server.
 *
 * Settings come from `JobConfiguration.readJobConfig`; the job type ("train", "predict"
 * or "increment_train", case-insensitive) selects between fitting and saving the model,
 * or scoring the input and writing the result to the result path.
 */
object MomentumTest {
  /**
   * Entry point.
   *
   * @param args unused; configuration comes from the job configuration read via the SparkContext
   * @throws XDMLException if the kudu master is not set, the job type is unknown, or a
   *                       required model/result path is missing for the chosen job type
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("Momentum-Test")
    val sc = new SparkContext(conf)
    // read the job settings; the second argument of each getOrElse is the fallback value
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val momCoff = jobConf.getOrElse(JobConfiguration.MOMENTUM_COFF, "0.1").toFloat
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    // the PS table name is prefixed with the submitting user's name
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "Momentum")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    // read data
    val rawData = sc.textFile(dataPath)
    val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    // NOTE(review): presumably 2 floats per key (weight and momentum term) — confirm against the updater
    val psConf = new PSConfiguration()
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(2)
      .setKuduMaster(kuduMaster)
    if (jobType.toUpperCase.equals("TRAIN")) {
      psConf.setJobType(JobType.TRAIN)
    } else if (jobType.toUpperCase.equals("PREDICT")) {
      // Fail fast: the predict branch below builds `new Path(resultPath)`, which cannot be
      // created from an empty string, so reject a missing result path up front (as LRTest does).
      if (modelPath.equals("") || resultPath.equals(""))
        throw new XDMLException("Predict job must have a model path and result path!")
      else
        psConf.setPredictModelPath(modelPath)
      psConf.setJobType(JobType.PREDICT)
    } else if (jobType.toUpperCase.equals("INCREMENT_TRAIN")) {
      if (modelPath.equals(""))
        throw new XDMLException("INCREMENT_TRAIN job must have a model path")
      else
        psConf.setPredictModelPath(modelPath)
      psConf.setJobType(JobType.INCREMENT_TRAIN)
    } else {
      throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegressionWithMomentum(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)
      .setMomemtumCoff(momCoff)
    psConf.getJobType match {
      case JobType.TRAIN => {
        // start train
        println("Start Train...")
        model.fit(data)
        println(s"Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
      case JobType.PREDICT => {
        // start predict
        println("Start Predict...")
        val result = model.predict(data)
        val path = new Path(resultPath)
        val fs = path.getFileSystem(sc.hadoopConfiguration)
        // saveAsTextFile refuses to overwrite, so remove a previous result first
        if (fs.exists(path)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fs.delete(path, true)
        }
        result.saveAsTextFile(resultPath)
        // share of rows where (first element > 0.5 ? 1 : 0) equals the second element
        val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble
        println("right rate is : " + rightRate)
      }
      case JobType.INCREMENT_TRAIN => {
        // continue training from the configured model, then save it back
        println("Start Increment Train...")
        model.fit(data)
        println(s"Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      }
    }
    sc.stop()
  }
}
101 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/feature/process/ColumnEliminator.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.feature.process
2 |
3 | import org.apache.spark.ml.Transformer
4 | import org.apache.spark.ml.param.ParamMap
5 | import org.apache.spark.ml.param.shared.HasInputCols
6 | import org.apache.spark.ml.util._
7 | import org.apache.spark.sql.types.{StructField, StructType}
8 | import org.apache.spark.sql.{DataFrame, Dataset}
9 |
10 | import scala.collection.mutable.ArrayBuffer
11 |
/**
 * Transformer that removes the columns named in `inputCols` from a dataset.
 *
 * Names that are not present in the input schema are only reported with a warning;
 * they do not cause a failure.
 *
 * @param uid unique identifier of this stage
 */
class ColumnEliminator(override val uid: String)
  extends Transformer with HasInputCols with DefaultParamsWritable {

  /** Creates an instance with a random uid prefixed by "ColumnEliminator". */
  def this() = this(Identifiable.randomUID("ColumnEliminator"))

  /** Sets the names of the columns to remove. */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /**
   * Drops the configured columns.
   *
   * @param dataset input dataset
   * @return the dataset, as a DataFrame, without the configured columns
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // run the schema pass first so missing columns get logged
    transformSchema(dataset.schema, logging = true)
    val columnsToDrop = $(inputCols)
    dataset.drop(columnsToDrop: _*).toDF()
  }

  /**
   * Computes the output schema: the input fields, in order, minus the configured columns.
   *
   * @param schema input schema
   * @return schema without the fields whose names are listed in `inputCols`
   */
  override def transformSchema(schema: StructType): StructType = {
    val columnsToDrop = $(inputCols)
    val availableNames = schema.fieldNames
    columnsToDrop
      .filterNot(name => availableNames.contains(name))
      .foreach(name => logWarning("missing column " + name))
    StructType(schema.fields.filterNot(field => columnsToDrop.contains(field.name)))
  }

  override def copy(extra: ParamMap): ColumnEliminator = defaultCopy(extra)

}


/** Companion providing the default params-based loader for [[ColumnEliminator]]. */
object ColumnEliminator extends DefaultParamsReadable[ColumnEliminator] {

  /** Loads a saved [[ColumnEliminator]] from `path` via the default reader. */
  override def load(path: String): ColumnEliminator = super.load(path)
}
49 |
50 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/feature/process/ColumnRenamer.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.feature.process
2 |
3 | import org.apache.spark.ml.Transformer
4 | import org.apache.spark.ml.param.ParamMap
5 | import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols}
6 | import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
7 | import org.apache.spark.sql._
8 | import org.apache.spark.sql.types.{StructField, StructType}
9 |
/**
 * Transformer that renames columns: the i-th name in `inputCols` becomes the i-th name
 * in `outputCols`. Columns not listed in `inputCols` are left untouched.
 *
 * @param uid unique identifier of this stage
 */
class ColumnRenamer(override val uid: String)
  extends Transformer with HasInputCols with HasOutputCols with DefaultParamsWritable {

  /** Creates an instance with a random uid prefixed by "ColumnRenamer". */
  def this() = this(Identifiable.randomUID("ColumnRenamer"))

  /** Sets the existing column names to rename. */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** Sets the new column names, positionally matched with `inputCols`. */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /**
   * Applies the renames one after another, in `inputCols` order.
   *
   * @param dataset input dataset
   * @return the dataset as a DataFrame with the configured columns renamed
   * @throws IllegalArgumentException if `inputCols` and `outputCols` differ in length
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // validates the configuration (see transformSchema) before touching the data
    transformSchema(dataset.schema, logging = true)
    $(inputCols).zip($(outputCols)).foldLeft(dataset.toDF()) {
      case (df, (existingName, newName)) => df.withColumnRenamed(existingName, newName)
    }
  }

  /**
   * Computes the output schema by renaming the matching fields.
   *
   * Only the field name changes: data type, nullability and metadata of a renamed
   * field are carried over from the input field.
   *
   * @param schema input schema
   * @return schema with every field named in `inputCols` renamed to its `outputCols` counterpart
   * @throws IllegalArgumentException if `inputCols` and `outputCols` differ in length
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputColNames = $(inputCols)
    val outputColNames = $(outputCols)
    // zip would silently ignore the surplus names, hiding a configuration mistake
    require(inputColNames.length == outputColNames.length,
      s"inputCols (${inputColNames.length}) and outputCols (${outputColNames.length}) " +
        "must have the same length")
    val renames = inputColNames.zip(outputColNames).toMap
    StructType(schema.fields.map { structField =>
      // copy keeps dataType, nullable and metadata; only the name is replaced
      renames.get(structField.name).fold(structField)(newName => structField.copy(name = newName))
    })
  }

  override def copy(extra: ParamMap): ColumnRenamer = defaultCopy(extra)

}

/** Companion providing the default params-based loader for [[ColumnRenamer]]. */
object ColumnRenamer extends DefaultParamsReadable[ColumnRenamer] {

  /** Loads a saved [[ColumnRenamer]] from `path` via the default reader. */
  override def load(path: String): ColumnRenamer = super.load(path)
}
49 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/feature/process/PipelinePatch.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.ml
2 |
3 | import org.apache.spark.sql.DataFrame
4 |
5 | import scala.collection.mutable.ArrayBuffer
6 |
/**
 * Helper that fits a [[Pipeline]] and transforms the data in the same pass.
 */
object PipelinePatch {

  /**
   * Walks the pipeline stages in order: an estimator is fitted on the current data and
   * its fitted model is used from then on; a transformer is used as is. Every stage,
   * including the last one, is applied to the data before moving to the next stage.
   *
   * @param df       input data
   * @param pipeline pipeline whose stages are fitted/applied
   * @return the data after all stages, and a [[PipelineModel]] built from the resulting
   *         transformers with `pipeline` as its parent
   * @throws IllegalArgumentException if a stage is neither an Estimator nor a Transformer
   */
  def pipelineFitTransform(df: DataFrame, pipeline: Pipeline): (DataFrame, PipelineModel) = {
    // check validity of schema
    pipeline.transformSchema(df.schema)

    var current = df
    val fittedStages = new ArrayBuffer[Transformer]()
    for (stage <- pipeline.getStages) {
      // resolve the stage to the transformer that will be part of the model
      val transformer = stage match {
        case estimator: Estimator[_] =>
          estimator.fit(current)
        case ready: Transformer =>
          ready
        case _ =>
          throw new IllegalArgumentException(
            s"Does not support stage $stage of type ${stage.getClass}")
      }
      current = transformer.transform(current)
      fittedStages += transformer
    }

    val model = new PipelineModel(pipeline.uid, fittedStages.toArray).setParent(pipeline)
    (current, model)
  }

}
35 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/linalg/BLAS.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.linalg
2 |
3 | import scala.collection.mutable
4 |
/**
 * Small collection of sparse linear-algebra helpers working on
 * `Long`-indexed `Float` vectors.
 */
object BLAS extends Serializable {

  /**
   * Dot product of two sparse vectors given as index -> value maps.
   *
   * @param x left operand; only its keys are iterated
   * @param y right operand; indices missing from `y` count as 0
   * @return sum over the keys of `x` of `x(k) * y(k)`
   */
  def dot(x: Map[Long, Float], y: Map[Long, Float]): Float = {
    var result = 0f
    x.foreach { case (index, value) => result += value * y.getOrElse(index, 0f) }
    result
  }

  /**
   * Dot product of a sparse vector in (indices, values) form with an index -> value map.
   *
   * @param x pair of parallel arrays: indices and the value at each index
   * @param y right operand; indices missing from `y` count as 0
   * @return sum over positions `i` of `x._2(i) * y(x._1(i))`
   */
  def dot(x: (Array[Long], Array[Float]), y: Map[Long, Float]): Float = {
    val (indices, values) = x
    var result = 0f
    var i = 0
    while (i < indices.length) {
      result += values(i) * y.getOrElse(indices(i), 0f)
      i += 1
    }
    result
  }

  /**
   * In-place `y += a * x` for a sparse `x` in (indices, values) form.
   *
   * Indices absent from `y` are inserted. When `a` is within 1e-6 of 1 the values of
   * `x` are added unscaled.
   *
   * @param a scale factor applied to `x`
   * @param x pair of parallel arrays: indices and the value at each index
   * @param y accumulator, updated in place
   */
  def axpy(a: Double, x: (Array[Long], Array[Float]), y: mutable.Map[Long, Float]): Unit = {
    val xIndices = x._1
    val xValues = x._2
    val nnz = xIndices.length
    // treat a ~ 1 as exactly 1 so the values are added without a multiply/round-trip
    val unscaled = Math.abs(a - 1) < 1e-6
    var k = 0
    while (k < nnz) {
      val index = xIndices(k)
      val delta = if (unscaled) xValues(k) else (a * xValues(k)).toFloat
      y.get(index) match {
        case Some(current) => y(index) = current + delta
        case None => y += (index -> delta)
      }
      k += 1
    }
  }

  /**
   * Logistic function `e^value / (1 + e^value)` with saturation outside [-30, 30].
   *
   * @param value input
   * @return 1.0 when `value > 30`, 1e-6 when `value < -30`, otherwise the logistic value
   */
  def sigmoid(value: Float): Double = {
    if (value > 30)
      1.0
    else if (value < (-30))
      // NOTE(review): this floor is larger than the exact value just inside the range
      // (sigmoid(-30) is about 9.4e-14); kept as is because callers may rely on it — confirm.
      1e-6
    else {
      // Math.exp instead of pow with a truncated constant for e
      val ex = Math.exp(value)
      ex / (1.0 + ex)
    }
  }

}
63 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/mapstore/PSMapStore.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.mapstore
2 |
3 | import java.nio.ByteBuffer
4 | import java.{lang, util}
5 |
6 | import com.hazelcast.core.MapStore
7 | import net.qihoo.xitong.xdml.conf.PSDataType.PSDataType
8 | import net.qihoo.xitong.xdml.conf.{PSConfiguration, PSDataType}
9 | import org.apache.kudu.Schema
10 | import org.apache.kudu.client.KuduScanner.KuduScannerBuilder
11 | import org.apache.kudu.client.SessionConfiguration.FlushMode
12 | import org.apache.kudu.client._
13 |
14 | import scala.collection.JavaConversions._
15 | import scala.util.Random
16 |
/**
 * Hazelcast `MapStore` that persists parameter-server entries (`Long` key -> raw value
 * bytes) in a Kudu table.
 *
 * Reads are answered with a shared initial value until `count` reaches
 * `maxHzCacheSizePerPartition - 2`; after that they are looked up in Kudu. Deletes are
 * no-ops and no keys are pre-loaded.
 *
 * @param psConf    parameter-server configuration (Kudu master, column names, value type/length, ...)
 * @param tableName name of the Kudu table to open
 */
class PSMapStore(psConf: PSConfiguration, tableName: String) extends MapStore[Long, Array[Byte]] {
  var client: KuduClient = _
  var table: KuduTable = _
  var session: KuduSession = _
  var schema: Schema = _
  // number of keys answered with the initial value so far (see load / loadAll)
  var count: Long = 0L
  var maxHzCacheSizePerPartition: Long = _
  // random number generator used for the optional random initial value
  var random: Random = _
  // projected columns for scans: key column first, value column second
  var kuduCols: util.ArrayList[String] = _
  // builder created in init(); kept for compatibility, scans below create their own builder
  var scanBuilder: KuduScannerBuilder = _
  // size in bytes of one stored value
  var byteArraySize: Int = _
  // initial value returned while count is below the cache threshold, or when a stored
  // value has an unexpected size
  var initValue: ByteBuffer = _
  // ps conf data type
  var valueDataType: PSDataType = _
  var valueDataLength: Int = _
  // whether the initial value is filled with random numbers
  var needRandInit: Boolean = _
  var randomMin:Int = _
  var randomMax:Int = _
  var kuduForceSparse:Boolean = _

  init()

  /**
   * Opens the Kudu client, session and table and initialises every field, including
   * the shared initial value (zero-filled, or random in [randomMin, randomMax) per
   * element when random initialisation is requested).
   */
  def init(): Unit = {
    client = new KuduClient.KuduClientBuilder(psConf.getKuduMaster).build()
    session = client.newSession()
    table = client.openTable(tableName)
    scanBuilder = client.newScannerBuilder(table)
    schema = table.getSchema
    session.setFlushMode(FlushMode.AUTO_FLUSH_BACKGROUND)
    session.setIgnoreAllDuplicateRows(false)
    session.setMutationBufferSpace(5000)
    kuduCols = new util.ArrayList[String]
    kuduCols.add(psConf.getKuduKeyColName)
    kuduCols.add(psConf.getKuduValueColName)
    random = new Random()
    valueDataType = psConf.getPsDataType
    valueDataLength = psConf.getPsDataLength
    byteArraySize = PSDataType.sizeOf(valueDataType) * valueDataLength
    println("kudu bytes size: " + byteArraySize)
    needRandInit = psConf.getKuduRandomInit
    randomMin = psConf.getKuduRandomMin
    randomMax = psConf.getKuduRandomMax
    initValue = ByteBuffer.allocate(byteArraySize)
    maxHzCacheSizePerPartition = psConf.getHzMaxCacheSizePerPartition
    kuduForceSparse = psConf.getForceSparse
    // need random initial value
    if (needRandInit) {
      valueDataType match {
        case PSDataType.FLOAT => initValue.putFloat(random.nextFloat() * (randomMax - randomMin) + randomMin)
        case PSDataType.DOUBLE => initValue.putDouble(random.nextDouble() * (randomMax - randomMin) + randomMin)
        case PSDataType.FLOAT_ARRAY => {
          for (index <- 0 until valueDataLength) {
            // absolute put at byte offset index * 4
            initValue.putFloat(index << 2,random.nextFloat()*(randomMax - randomMin) + randomMin)
          }
        }
        case PSDataType.DOUBLE_ARRAY => {
          for (index <- 0 until valueDataLength) {
            // absolute put at byte offset index * 8
            initValue.putDouble(index << 3,random.nextDouble()*(randomMax - randomMin) + randomMin)
          }
        }
      }
    }
  }

  /** Upserts one key/value pair through the Kudu session (column 0 = key, column 1 = value). */
  override def store(key: Long, value: Array[Byte]): Unit = {
    val op = table.newUpsert()
    op.getRow.addLong(0, key)
    op.getRow.addBinary(1, value)
    session.apply(op)
  }

  /** Stores every entry whose value has exactly `byteArraySize` bytes; other entries are skipped. */
  override def storeAll(map: util.Map[Long, Array[Byte]]): Unit = {
    for ((k: Long, v: Array[Byte]) <- map if v.length == byteArraySize) {
      this.store(k, v)
    }
  }

  /**
   * Loads the value for one key.
   *
   * @param key key to look up
   * @return the initial value while `count` is below the threshold (the counter is then
   *         incremented); otherwise the first matching Kudu row's value, or the initial
   *         value when no row matches or the stored value has an unexpected size
   */
  override def load(key: Long): Array[Byte] = {
    if (count < (maxHzCacheSizePerPartition - 2)) {
      count += 1
      initValue.array()
    } else {
      val idsPredicate = KuduPredicate.newComparisonPredicate(schema.getColumn(psConf.getKuduKeyColName), KuduPredicate.ComparisonOp.EQUAL, key)
      // A fresh builder per lookup: predicates added to a builder stay on it, so reusing
      // the shared one would combine this key's predicate with those of earlier lookups.
      val scanner = client.newScannerBuilder(table)
        .setProjectedColumnNames(kuduCols)
        .addPredicate(idsPredicate)
        .build()
      try {
        var loaded: Option[Array[Byte]] = None
        // stop at the first row returned
        while (loaded.isEmpty && scanner.hasMoreRows) {
          val results = scanner.nextRows()
          if (results.hasNext) {
            val value = results.next().getBinaryCopy(1)
            loaded = Some(if (value.length == byteArraySize) value else initValue.array())
          }
        }
        loaded.getOrElse(initValue.array())
      } finally {
        // always release the scanner, also when a row was found or the scan failed
        scanner.close()
      }
    }
  }

  /**
   * Loads the values for a collection of keys.
   *
   * @param keys keys to look up
   * @return while `count` is below the threshold, every key mapped to the initial value
   *         (the counter grows by `keys.size()`); otherwise one entry per row found in
   *         Kudu, with the initial value substituted for values of unexpected size.
   *         Keys without a Kudu row are absent from the result.
   */
  override def loadAll(keys: util.Collection[Long]): util.Map[Long, Array[Byte]] = {
    // size the map so that keys.size() entries fit under the 0.75 load factor
    val initCapacity: Int = (keys.size() / 0.75f + 1).toInt
    val loadResultMap = new util.HashMap[Long, Array[Byte]](initCapacity)
    if (count < (maxHzCacheSizePerPartition - 2)) {
      for (key <- keys)
        loadResultMap.put(key, initValue.array())
      count += keys.size()
    } else {
      val idsPredicate = KuduPredicate.newInListPredicate(schema.getColumn(psConf.getKuduKeyColName), keys.toList)
      val scanner = client.newScannerBuilder(table).setProjectedColumnNames(kuduCols).addPredicate(idsPredicate).build()
      try {
        while (scanner.hasMoreRows) {
          val results = scanner.nextRows()
          while (results.hasNext) {
            val result = results.next()
            val id = result.getLong(0)
            val weight = result.getBinaryCopy(1)
            if (weight.length == byteArraySize) {
              loadResultMap.put(id, weight)
            } else {
              loadResultMap.put(id, initValue.array())
            }
          }
        }
      } finally {
        // release the scanner even if the scan throws
        scanner.close()
      }
    }
    loadResultMap
  }

  /** Returns null: no key set is provided for pre-loading. */
  override def loadAllKeys(): lang.Iterable[Long] = {null}

  /** No-op: entries are never deleted from Kudu by this store. */
  override def delete(key: Long): Unit = {}

  /** No-op: entries are never deleted from Kudu by this store. */
  override def deleteAll(keys: util.Collection[Long]): Unit = {}
}
155 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/FieldawareFactorizationMachine.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.ml
2 |
3 | import java.io.Serializable
4 | import java.util
5 |
6 | import bloomfilter.mutable.BloomFilter
7 | import net.qihoo.xitong.xdml.optimization.FFM
8 | import net.qihoo.xitong.xdml.ps.{PS, PSClient}
9 | import net.qihoo.xitong.xdml.updater.FFMUpdater
10 | import org.apache.spark.rdd.RDD
11 |
12 | import scala.collection.JavaConversions._
13 |
/**
 * Configurable front end for field-aware factorization machine training on the
 * parameter server. Setters return `this` so calls can be chained; `fit` delegates
 * to the companion's mini-batch SGD routine with the current settings.
 *
 * @param ps parameter server handle passed through to the training routine
 */
class FieldawareFactorizationMachine(ps:PS) extends Serializable {
  // training settings with their defaults
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate:Float = 0.01F
  private var rank = 1
  private var field = 1

  /** Runs the given assignment and returns this instance, for setter chaining. */
  private def configured(assign: => Unit): this.type = {
    assign
    this
  }

  /**
   * Trains on `data` using the configured settings.
   *
   * @param data rows of (label, Int array, Long array, Float array)
   * @return whatever the companion's `runMiniBatchSGD` returns for these settings
   */
  def fit(data: RDD[(Double, Array[Int], Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) =
    FieldawareFactorizationMachine.runMiniBatchSGD(
      ps = ps,
      data = data,
      iter = iterNum,
      batchSize = batchSize,
      rank = rank,
      fieldNum = field,
      learningRate = learningRate
    )

  /** Sets the number of training iterations (default 1). */
  def setIterNum(iterNum:Int):this.type = configured { this.iterNum = iterNum }

  /** Sets the mini-batch size (default 50). */
  def setBatchSize(batchSize:Int):this.type = configured { this.batchSize = batchSize }

  /** Sets the learning rate (default 0.01). */
  def setLearningRate(lr:Float):this.type = configured { this.learningRate = lr }

  /** Sets the rank (default 1). */
  def setRank(rank:Int):this.type = configured { this.rank = rank }

  /** Sets the number of fields (default 1). */
  def setField(field:Int):this.type = configured { this.field = field }
}
48 |
49 |
/** Training routine for [[FieldawareFactorizationMachine]]. */
object FieldawareFactorizationMachine{
  /**
   * Mini-batch SGD over the partitions of `data`, exchanging weights with the
   * parameter server through a per-partition `PSClient`.
   *
   * For each epoch, every partition buffers rows into batches of `batchSize` (the last
   * batch may be smaller), pulls the weights for the batch's candidate indices, sums the
   * per-row gradients returned by `FFM.train`, and pushes the summed gradient.
   *
   * @param ps           parameter server handle
   * @param data         rows of (label, Int array, Long feature indices, Float array)
   * @param iter         number of epochs
   * @param batchSize    number of rows per batch within a partition
   * @param rank         passed to `FFM.train`; also sizes the gradient arrays
   * @param fieldNum     passed to `FFM.train`; also sizes the gradient arrays
   * @param learningRate learning rate given to the `FFMUpdater`
   * @return (epoch number starting at 1 -> total loss / row count for that epoch,
   *         (rows with label > 0, other rows) counted during the first epoch)
   */
  def runMiniBatchSGD(ps:PS,
                      data: RDD[(Double, Array[Int], Array[Long], Array[Float])],
                      iter: Int,
                      batchSize: Int,
                      rank: Int,
                      fieldNum: Int,
                      learningRate: Float
                     ): (Map[Int, Double], (Long, Long)) = {
    println(s"Start the hz cluster ...")
    var trainInfo: Map[Int, Double] = Map()
    var posNum = 0L
    var negNum = 0L
    for (iterNum <- 0 until iter) {
      val result = data.mapPartitions { rows =>
        val updater = new FFMUpdater()
          .setLearningRate(learningRate)
        val client = new PSClient[Long, Array[Float]](ps)
          .setUpdater(updater)
        var count = 0
        val weightIndex = new util.HashSet[Long]()
        val localData = new util.ArrayList[(Double, Array[Int], Array[Long], Array[Float])]()
        var batchId = 0
        var flag = false
        var totalLoss = 0.0
        var totalCount = 0L
        var partPosNum = 0L
        var partNegNum = 0L
        var batchLoss = 0.0
        val expectedElements = 100000000
        val falsePositiveRate = 0.1
        val bf = BloomFilter[Long](expectedElements, falsePositiveRate)

        try {
          while (rows.hasNext) {
            val dataLine = rows.next()
            localData.add(dataLine)
            // An index is pulled only once the filter reports it as already seen;
            // the first sighting just records it in the filter.
            dataLine._3.foreach(x =>
              if (bf.mightContain(x)) {
                weightIndex.add(x)
              } else {
                bf.add(x)
              }
            )
            // label statistics are collected during the first epoch only
            if (iterNum == 0) {
              if (dataLine._1 > 0.0)
                partPosNum += 1
              else
                partNegNum += 1
            }
            count += 1
            // a batch is complete when it is full or the partition is exhausted
            if (count == batchSize || !rows.hasNext) {
              println(s"weightIndex.size: ${weightIndex.size()}")

              val localV = client.pull(weightIndex)
              println("localV: " + localV.map{
                case(k,v) => k + "-->" + v.mkString(",")
              })

              var totalGrad: Map[Long, Array[Float]] = Map()
              localData.foreach { dataIter =>
                // train the data, get the grad and loss
                val (grad, loss) = FFM.train(
                  fieldNum,
                  rank,
                  dataIter._1,
                  (dataIter._2, dataIter._3, dataIter._4),
                  localV
                )
                println("loss: " + loss)
                // update the inc
                grad.foreach { case (k, v) =>
                  // NOTE(review): the array has fieldNum more slots than the loop below
                  // fills; the extra slots stay 0 — confirm this is what the updater expects.
                  val gradArray = totalGrad.getOrElse(k, Array.fill(rank * fieldNum + fieldNum)(0.0f))
                  for(index <- 0 until rank*fieldNum) {
                    // accumulate the gradients over one batch
                    gradArray(index) = gradArray(index) + v(index)
                  }
                  totalGrad += (k -> gradArray)
                }
                batchLoss += loss
              }
              println(s"TotalLoss: ${batchLoss} . SampleNum: ${localData.size()}. Average Loss: ${batchLoss / count}")
              val us = System.nanoTime()
              client.push(totalGrad)

              val ue = System.nanoTime()
              println(s"update weight time consume is: ${(ue - us) / 1e9}")
              batchId += 1
              flag = true
              totalLoss += batchLoss
              totalCount += count
              // reset the per-batch state
              count = 0
              batchLoss = 0.0
              localData.clear()
              weightIndex.clear()
            }

            // after every 50th batch, pause once for two seconds
            if (batchId % 50 == 0 && flag) {
              val s = System.nanoTime()
              println(s"epoch ${iterNum} sync begin.")
              Thread.sleep(2000)
              val e = System.nanoTime()
              println(s"epoch ${iterNum} sync end, consume is: ${(e - s) / 1e9}")
              flag = false
            }
          }
          Iterator((totalLoss, totalCount, partPosNum, partNegNum))
        } finally {
          // The filter is sized for 100M elements and is created once per partition per
          // epoch; release its storage as soon as the partition is done.
          bf.dispose()
        }
      }
      if (iterNum == 0) {
        val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
        posNum = totalResult._3
        negNum = totalResult._4
        trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
      } else {
        val totalResult = result.map(x => (x._1, x._2)).reduce((x, y) => (x._1 + y._1, x._2 + y._2))
        trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
      }
    }
    // NOTE(review): fixed one-minute wait after training; presumably gives pending
    // parameter-server writes time to complete — confirm.
    Thread.sleep(60000)
    (trainInfo, (posNum, negNum))
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegression.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.ml
2 |
3 | import java.io.Serializable
4 | import java.util
5 |
6 | import net.qihoo.xitong.xdml.optimization.BinaryClassify
7 | import net.qihoo.xitong.xdml.ps.{PS, PSClient}
8 | import net.qihoo.xitong.xdml.updater.LRUpdater
9 | import org.apache.spark.rdd.RDD
10 |
11 | import scala.collection.JavaConversions._
12 |
/**
  * Chainable configuration wrapper for logistic regression trained against a
  * parameter server.
  *
  * The class only stores hyper-parameters; training and prediction are delegated
  * to the companion object's `runMiniBatchSGD` and `predict`.
  *
  * @param ps parameter-server handle forwarded to the companion object
  */
class LogisticRegression(ps: PS) extends Serializable {
  // Defaults that apply until the matching setter is called.
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F

  /**
    * Trains on `data` with the currently configured iteration count, batch size
    * and learning rate.
    *
    * @param data records as (label, feature ids, array passed along with the ids --
    *             presumably the feature values)
    * @return whatever `LogisticRegression.runMiniBatchSGD` reports: a per-pass
    *         training summary and a (positive, negative) label-count pair
    */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) =
    LogisticRegression.runMiniBatchSGD(ps, data, iterNum, batchSize, learningRate)

  /**
    * Scores `data` with the weights currently held by the parameter server,
    * pulling them in batches of the configured batch size.
    *
    * @return the pairs produced by `LogisticRegression.predict`
    */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] =
    LogisticRegression.predict(ps, data, batchSize)

  /** Sets the number of passes over the training data; returns this instance for chaining. */
  def setIterNum(iterNum: Int): this.type = { this.iterNum = iterNum; this }

  /** Sets the mini-batch size used by both `fit` and `predict`; returns this instance for chaining. */
  def setBatchSize(batchSize: Int): this.type = { this.batchSize = batchSize; this }

  /** Sets the learning rate handed to the updater; returns this instance for chaining. */
  def setLearningRate(lr: Float): this.type = { this.learningRate = lr; this }
}
41 |
object LogisticRegression {

  /**
    * Trains a logistic-regression model with mini-batch SGD, keeping the weights
    * on the parameter server.
    *
    * On every pass each partition buffers up to `batchSize` records, pulls the
    * weights for the feature ids seen in the buffer, sums the per-record gradients
    * returned by `BinaryClassify.train`, and pushes the batch-averaged gradient
    * back through an `LRUpdater`.
    *
    * @param ps            parameter-server handle; one `PSClient` is created per partition and pass
    * @param data          records as (label, feature ids, array passed along with the ids to
    *                      `BinaryClassify` -- presumably the feature values)
    * @param numIterations number of full passes over `data`
    * @param batchSize     records per mini-batch; the last batch of a partition may be smaller
    * @param learningRate  learning rate handed to the `LRUpdater`
    * @return (pass number starting at 1 -> summed loss divided by record count for that pass,
    *         (number of records with label > 0.0, number of remaining records))
    */
  def runMiniBatchSGD(ps: PS,
                      data: RDD[(Double, Array[Long], Array[Float])],
                      numIterations: Int,
                      batchSize: Int,
                      learningRate: Float
                     ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Float](ps)
        // try/finally so the client is released even when a batch fails.
        try {
          val updater = new LRUpdater()
            .setLearningRate(learningRate)
          client.setUpdater(updater)
          var (count, totalCount, partitionPos, partitionNeg, totalLoss) = (0L, 0L, 0L, 0L, 0.0D)
          val weightIndex = new util.HashSet[Long]()
          val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Label statistics are gathered on the first pass only.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partitionPos += 1 else partitionNeg += 1
            }
            count += 1
            // Flush when the batch is full or the partition is exhausted.
            if (count == batchSize || !iter.hasNext) {
              val localMap = client.pull(weightIndex).toMap
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localMap)
                // accumulate the gradient of this record into the batch total
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              // push the gradient averaged over the records of this batch
              val pushMap = totalGrad.map { case (k, v) => (k, v / count) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
          Iterator((totalLoss, totalCount, partitionPos, partitionNeg))
        } finally {
          client.shutDown()
        }
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Partitions only count labels on the first pass; later passes report zeros,
      // so the counts must not be overwritten after pass 0.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }
    (trainInfo, (posNum, negNum))
  }

  /**
    * Scores every record of `data` with the weights currently stored on the
    * parameter server.
    *
    * Records are buffered in batches of `batchSize`; for each batch the weights
    * of the feature ids it contains are pulled once and used for all records of
    * the batch.
    *
    * NOTE: results are prepended to a list, so within a partition the output is
    * in reverse input order (kept as in the previous implementation).
    *
    * @param ps        parameter-server handle; one `PSClient` is created per partition
    * @param data      records as (label, feature ids, array passed along with the ids)
    * @param batchSize number of records scored per weight pull
    * @return pairs of (value returned by `BinaryClassify.predict`, label of the record)
    */
  def predict(ps: PS,
              data: RDD[(Double, Array[Long], Array[Float])],
              batchSize: Int
             ): RDD[(Double, Double)] = {

    data.mapPartitions { iter =>
      val client = new PSClient[Long, Float](ps)
      // try/finally so the client is released even when scoring fails.
      try {
        var preList = List[(Double, Double)]()
        var count = 0
        val weightIndex = new util.HashSet[Long]()
        val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
        while (iter.hasNext) {
          val dataLine = iter.next()
          localData.add(dataLine)
          dataLine._2.foreach(weightIndex.add)
          count += 1
          if (count == batchSize || !iter.hasNext) {
            val localWeight = client.pull(weightIndex)
            println("localWeight: " + localWeight.mkString(","))
            // Convert once per batch instead of once per record.
            val weightMap = localWeight.toMap
            // predict the label of every buffered record
            localData.foreach { x =>
              preList +:= (BinaryClassify.predict((x._2, x._3), weightMap), x._1)
            }
            count = 0
            localData.clear()
            weightIndex.clear()
          }
        }
        preList.iterator
      } finally {
        client.shutDown()
      }
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegressionWithDCASGD.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.ml
2 |
3 | import java.io.Serializable
4 | import java.util
5 |
6 | import net.qihoo.xitong.xdml.optimization.BinaryClassify
7 | import net.qihoo.xitong.xdml.ps.{PS, PSClient}
8 | import net.qihoo.xitong.xdml.updater.DCASGDUpdater
9 | import org.apache.spark.rdd.RDD
10 |
11 | import scala.collection.JavaConversions._
12 |
/**
  * Chainable configuration wrapper for logistic regression whose parameter-server
  * updates go through a `DCASGDUpdater`.
  *
  * Training is delegated to `LogisticRegressionWithDCASGD.train`; prediction
  * reuses `LogisticRegression.predict`.
  *
  * @param ps parameter-server handle forwarded to the companion objects
  */
class LogisticRegressionWithDCASGD(ps: PS) extends Serializable {
  // Defaults that apply until the matching setter is called.
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F
  private var dcAsgdCoff: Float = 0.1F

  /**
    * Trains on `data` with the currently configured hyper-parameters.
    *
    * @param data records as (label, feature ids, array passed along with the ids --
    *             presumably the feature values)
    * @return the result of `LogisticRegressionWithDCASGD.train`: a per-pass training
    *         summary and a (positive, negative) label-count pair
    */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) =
    LogisticRegressionWithDCASGD.train(ps, data, iterNum, batchSize, learningRate, dcAsgdCoff)

  /**
    * Scores `data` through the plain logistic-regression predictor.
    *
    * NOTE(review): training uses a `PSClient[Long, Array[Float]]` while
    * `LogisticRegression.predict` pulls through a `PSClient[Long, Float]` --
    * confirm the parameter server serves both value shapes.
    */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] =
    LogisticRegression.predict(ps, data, batchSize)

  /** Sets the number of passes over the training data; returns this instance for chaining. */
  def setIterNum(iterNum: Int): this.type = { this.iterNum = iterNum; this }

  /** Sets the mini-batch size used by both `fit` and `predict`; returns this instance for chaining. */
  def setBatchSize(batchSize: Int): this.type = { this.batchSize = batchSize; this }

  /** Sets the learning rate handed to the updater; returns this instance for chaining. */
  def setLearningRate(lr: Float): this.type = { this.learningRate = lr; this }

  /** Sets the coefficient passed to `DCASGDUpdater.setCoff`; returns this instance for chaining. */
  def setDcAsgdCoff(coff: Float): this.type = { this.dcAsgdCoff = coff; this }
}
47 |
object LogisticRegressionWithDCASGD {

  /**
    * Trains logistic regression with mini-batches, applying the pushed gradients
    * on the parameter server through a `DCASGDUpdater`.
    *
    * Parameter-server values are arrays; only element 0 of each pulled array is
    * used as the weight for the gradient computation, and every pushed gradient
    * is wrapped in a one-element array.
    *
    * @param ps            parameter-server handle; one `PSClient` is created per partition and pass
    * @param data          records as (label, feature ids, array passed along with the ids to
    *                      `BinaryClassify` -- presumably the feature values)
    * @param numIterations number of full passes over `data`
    * @param batchSize     records per mini-batch; the last batch of a partition may be smaller
    * @param learningRate  learning rate handed to the updater
    * @param coff          coefficient handed to `DCASGDUpdater.setCoff`
    * @return (pass number starting at 1 -> summed loss divided by record count for that pass,
    *         (number of records with label > 0.0, number of remaining records))
    */
  def train(ps: PS,
            data: RDD[(Double, Array[Long], Array[Float])],
            numIterations: Int,
            batchSize: Int,
            learningRate: Float,
            coff: Float
           ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Array[Float]](ps)
        // try/finally so the client is released even when a batch fails.
        try {
          val updater = new DCASGDUpdater()
            .setLearningRate(learningRate)
            .setCoff(coff)
          client.setUpdater(updater)
          var (count, totalCount, partitionPos, partitionNeg, totalLoss) = (0L, 0L, 0L, 0L, 0.0)
          val weightIndex = new util.HashSet[Long]()
          val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Label statistics are gathered on the first pass only.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partitionPos += 1 else partitionNeg += 1
            }
            count += 1
            // Flush when the batch is full or the partition is exhausted.
            if (count == batchSize || !iter.hasNext) {
              // Element 0 of each pulled array is the weight; extract it once per
              // batch rather than once per record.
              val localWeights = client.pull(weightIndex).toMap.map { case (k, v) => (k, v(0)) }
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localWeights)
                // accumulate the gradient of this record into the batch total
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              // push the gradient averaged over the records of this batch
              val pushMap = totalGrad.map { case (k, v) => (k, Array(v / count)) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
          Iterator((totalLoss, totalCount, partitionPos, partitionNeg))
        } finally {
          client.shutDown()
        }
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Partitions only count labels on the first pass; later passes report zeros,
      // so the counts must not be overwritten after pass 0.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }
    (trainInfo, (posNum, negNum))
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegressionWithMomentum.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.ml
2 |
3 | import java.io.Serializable
4 | import java.util
5 |
6 | import net.qihoo.xitong.xdml.ps.{PS, PSClient}
7 | import net.qihoo.xitong.xdml.updater.MomLRUpdater
8 | import net.qihoo.xitong.xdml.optimization.BinaryClassify
9 | import org.apache.spark.rdd.RDD
10 |
11 | import scala.collection.JavaConversions._
12 |
/**
  * Chainable configuration wrapper for logistic regression whose parameter-server
  * updates go through a `MomLRUpdater`.
  *
  * Training is delegated to `LogisticRegressionWithMomentum.runMomentum`;
  * prediction reuses `LogisticRegression.predict`.
  *
  * @param ps parameter-server handle forwarded to the companion objects
  */
class LogisticRegressionWithMomentum(ps: PS) extends Serializable {
  // Defaults that apply until the matching setter is called.
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F
  private var momentumCoff: Float = 0.1F

  /**
    * Trains on `data` with the currently configured hyper-parameters.
    *
    * @param data records as (label, feature ids, array passed along with the ids --
    *             presumably the feature values)
    * @return the result of `LogisticRegressionWithMomentum.runMomentum`: a per-pass
    *         training summary and a (positive, negative) label-count pair
    */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) = {
    LogisticRegressionWithMomentum.runMomentum(ps, data, iterNum, batchSize, learningRate, momentumCoff)
  }

  /**
    * Scores `data` through the plain logistic-regression predictor.
    *
    * NOTE(review): training uses a `PSClient[Long, Array[Float]]` while
    * `LogisticRegression.predict` pulls through a `PSClient[Long, Float]` --
    * confirm the parameter server serves both value shapes.
    */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] = {
    LogisticRegression.predict(ps, data, batchSize)
  }

  /** Sets the number of passes over the training data; returns this instance for chaining. */
  def setIterNum(iterNum: Int): this.type = {
    this.iterNum = iterNum
    this
  }

  /** Sets the mini-batch size used by both `fit` and `predict`; returns this instance for chaining. */
  def setBatchSize(batchSize: Int): this.type = {
    this.batchSize = batchSize
    this
  }

  /** Sets the learning rate handed to the updater; returns this instance for chaining. */
  def setLearningRate(lr: Float): this.type = {
    this.learningRate = lr
    this
  }

  /**
    * Sets the coefficient passed to `MomLRUpdater.setCoff`; returns this instance
    * for chaining.
    *
    * The name is misspelled; it is kept so existing callers continue to compile.
    * New code should use `setMomentumCoff`.
    */
  def setMomemtumCoff(coff: Float): this.type = {
    this.momentumCoff = coff
    this
  }

  /** Correctly spelled alias of `setMomemtumCoff`; identical behaviour. */
  def setMomentumCoff(coff: Float): this.type = setMomemtumCoff(coff)
}
45 |
object LogisticRegressionWithMomentum {

  /**
    * Trains logistic regression with mini-batches, applying the pushed gradients
    * on the parameter server through a `MomLRUpdater`.
    *
    * Parameter-server values are arrays; only element 0 of each pulled array is
    * used as the weight for the gradient computation, and every pushed gradient
    * is wrapped in a one-element array.
    *
    * @param ps            parameter-server handle; one `PSClient` is created per partition and pass
    * @param data          records as (label, feature ids, array passed along with the ids to
    *                      `BinaryClassify` -- presumably the feature values)
    * @param numIterations number of full passes over `data`
    * @param batchSize     records per mini-batch; the last batch of a partition may be smaller
    * @param learningRate  learning rate handed to the updater
    * @param coff          coefficient handed to `MomLRUpdater.setCoff`
    * @return (pass number starting at 1 -> summed loss divided by record count for that pass,
    *         (number of records with label > 0.0, number of remaining records))
    */
  def runMomentum(ps: PS,
                  data: RDD[(Double, Array[Long], Array[Float])],
                  numIterations: Int,
                  batchSize: Int,
                  learningRate: Float,
                  coff: Float
                 ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Array[Float]](ps)
        // try/finally so the client is released even when a batch fails.
        try {
          val update = new MomLRUpdater()
            .setLearningRate(learningRate)
            .setCoff(coff)
          client.setUpdater(update)
          var (count, totalCount, partitionPos, partitionNeg, totalLoss) = (0L, 0L, 0L, 0L, 0.0)
          val weightIndex = new util.HashSet[Long]()
          val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Label statistics are gathered on the first pass only.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partitionPos += 1 else partitionNeg += 1
            }
            count += 1
            // Flush when the batch is full or the partition is exhausted.
            if (count == batchSize || !iter.hasNext) {
              // Element 0 of each pulled array is the weight; extract it once per
              // batch rather than once per record.
              val localWeights = client.pull(weightIndex).toMap.map { case (k, v) => (k, v(0)) }
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localWeights)
                // accumulate the gradient of this record into the batch total
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              // push the gradient averaged over the records of this batch
              val pushMap = totalGrad.map { case (k, v) => (k, Array(v / count)) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
          Iterator((totalLoss, totalCount, partitionPos, partitionNeg))
        } finally {
          client.shutDown()
        }
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Partitions only count labels on the first pass; later passes report zeros,
      // so the counts must not be overwritten after pass 0.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }

    (trainInfo, (posNum, negNum))
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/DataHandler.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.data
2 |
3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
4 | import org.apache.spark.sql.functions._
5 | import org.apache.spark.sql.types.{DoubleType, StructType}
6 | import org.apache.spark.sql.{DataFrame, Row, SparkSession}
7 |
8 |
/** Readers and writers for delimited text, LibSVM-style text and Hive tables. */
object DataHandler extends Serializable {

  /**
    * Repartitioning rule shared by every reader and writer below: the frame is
    * repartitioned only when more than one partition is requested, otherwise it
    * is returned untouched.
    */
  private def repartitioned(df: DataFrame, numPartitions: Int): DataFrame =
    if (numPartitions > 1) df.repartition(numPartitions) else df

  // WARNING: there could be a problem if the number of partitions is very low
  // TODO: test big file with only one part
  // TODO: test nullValue
  /**
    * Reads delimited text through Spark's CSV reader with an explicit schema.
    * Leading and trailing whitespace of values is ignored.
    *
    * @param spark         session used for reading
    * @param dataPath      path handed to the CSV reader
    * @param dataDelimiter value of the reader's "sep" option
    * @param structType    schema applied to the data
    * @param nullValue     value of the reader's "nullValue" option
    * @param numPartitions target partition count; values <= 1 keep the reader's partitioning
    * @param header        value of the reader's "header" option
    */
  def readData(spark: SparkSession,
               dataPath: String,
               dataDelimiter: String,
               structType: StructType,
               nullValue: String,
               numPartitions: Int = -1,
               header: Boolean = false): DataFrame = {
    val reader = spark.read
      .option("sep", dataDelimiter)
      .option("ignoreLeadingWhiteSpace", true)
      .option("ignoreTrailingWhiteSpace", true)
      .option("nullValue", nullValue)
      .option("header", header)
      .schema(structType)
    repartitioned(reader.csv(dataPath), numPartitions)
  }

  /**
    * Writes `df` as delimited text through Spark's CSV writer.
    *
    * @param numPartitions target partition count; values <= 1 keep the current partitioning
    */
  def writeData(df: DataFrame,
                path: String,
                delimiter: String,
                numPartitions: Int = -1,
                header: Boolean = false): Unit = {
    repartitioned(df, numPartitions)
      .write
      .option("sep", delimiter)
      .option("header", header)
      .csv(path)
  }

  // WARNING: there could be a problem if the number of partitions is very low
  /**
    * Loads a file through Spark's "libsvm" data source and renames the resulting
    * "label" and "features" columns.
    *
    * @param numFeatures   value of the reader's "numFeatures" option
    * @param numPartitions target partition count; values <= 1 keep the reader's partitioning
    */
  def readLibSVMData(spark: SparkSession,
                     dataPath: String,
                     numFeatures: Int = -1,
                     numPartitions: Int = -1,
                     labelColName: String = "labelLibSVM",
                     featuresColName: String = "featuresLibSVM"): DataFrame = {
    val loaded = spark.read
      .format("libsvm")
      .option("numFeatures", numFeatures)
      .load(dataPath)
    val renamed = loaded
      .withColumnRenamed("label", labelColName)
      .withColumnRenamed("features", featuresColName)
    repartitioned(renamed, numPartitions)
  }

  /**
    * Writes one text line per row in the form "label index:value index:value ...".
    *
    * The label column is cast to double. Sparse vectors contribute their stored
    * entries only; dense vectors contribute every position.
    *
    * NOTE(review): indices are written exactly as held by the vector; Spark's own
    * libsvm reader expects one-based indices -- confirm what the consumer of these
    * files expects.
    */
  def writeLibSVMData(df: DataFrame,
                      labelColName: String,
                      featuresColName: String,
                      path: String,
                      numPartitions: Int = -1): Unit = {
    val selected = repartitioned(df, numPartitions)
      .select(col(labelColName).cast(DoubleType), col(featuresColName))
    val lines = selected.rdd.map {
      case Row(label: Double, features: Vector) =>
        // (index, value) pairs in the order they are written out
        val entries = features match {
          case sv: SparseVector => sv.indices.zip(sv.values)
          case dv: DenseVector => dv.toArray.zipWithIndex.map(_.swap)
        }
        label.toString + " " + entries.map { case (k, v) => k.toString + ":" + v.toString }.mkString(" ")
    }
    lines.saveAsTextFile(path)
  }

  // TODO: to be checked
  // WARNING: there could be a problem if the number of partitions is very low
  /**
    * Reads a table by name through `spark.read.table`.
    *
    * @param numPartitions target partition count; values <= 1 keep the table's partitioning
    */
  def readHiveData(spark: SparkSession,
                   tableName: String,
                   numPartitions: Int = -1): DataFrame =
    repartitioned(spark.read.table(tableName), numPartitions)

}
98 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/LogHandler.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.data
2 |
3 | import org.apache.log4j.{Level, Logger}
4 |
/** Logging helpers shared by the XDML jobs. */
object LogHandler extends Serializable {

  /** Logger named "XDML"; transient and lazy, so it is looked up again after deserialization. */
  @transient lazy val logger: Logger = Logger.getLogger("XDML")

  /** Turns off all log output of the "org" and "akka" logger hierarchies. */
  def avoidLog(): Unit =
    Seq("org", "akka").foreach(name => Logger.getLogger(name).setLevel(Level.OFF))

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/SchemaHandler.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.data
2 |
3 | import org.apache.spark.SparkContext
4 | import org.apache.spark.sql.types._
5 |
6 | import scala.io.Source
7 |
8 |
/**
  * Column names of a data set grouped by their declared role.
  *
  * @param schema        Spark schema built from the declarations
  * @param namesAndTypes one entry per column: element 0 is the column name, element 1 one of
  *                      the type tags defined in the companion object
  * @param checkLabel    when true, construction asserts that exactly one column is tagged as label
  */
class SchemaHandler(val schema: StructType, val namesAndTypes: Array[Array[String]], checkLabel: Boolean = false) extends Serializable {

  /** Names of the columns whose tag equals any of `tags`, in declaration order. */
  private def columnsTagged(tags: String*): Array[String] =
    namesAndTypes.collect { case entry if tags.exists(tag => entry(1).equals(tag)) => entry(0) }

  val keyColName: Array[String] = columnsTagged(SchemaHandler.Key)
  val labelColName: Array[String] = columnsTagged(SchemaHandler.Label)
  // Feature columns are everything tagged categorical, numeric, text or multi-categorical.
  val featColNames: Array[String] =
    columnsTagged(SchemaHandler.Cat, SchemaHandler.Num, SchemaHandler.Text, SchemaHandler.MultiCat)
  val catFeatColNames: Array[String] = columnsTagged(SchemaHandler.Cat)
  val multiCatFeatColNames: Array[String] = columnsTagged(SchemaHandler.MultiCat)
  val numFeatColNames: Array[String] = columnsTagged(SchemaHandler.Num)
  val textFeatColNames: Array[String] = columnsTagged(SchemaHandler.Text)
  val otherColNames: Array[String] = columnsTagged(SchemaHandler.Other)
  if (checkLabel) assert(labelColName.length == 1, "one and only one label needs to be provided")

}
25 |
26 |
/** Type tags accepted in schema files and the loader that turns such a file into a [[SchemaHandler]]. */
object SchemaHandler extends Serializable {

  // Tags allowed in the second field of a schema line.
  val Text = "Text"
  val Num = "Num"
  val Cat = "Cat"
  val MultiCat = "MultiCat"
  val Key = "Key"
  val Label = "Label"
  val Other = "Other"

  /**
    * Reads a schema file with one "name<delimiter>tag" declaration per line.
    *
    * Paths starting with "jar://" are loaded as classpath resources; every other
    * path is read with `sc.textFile`. Columns tagged `Num` become nullable double
    * fields, all other known tags become nullable string fields.
    *
    * @param sc         Spark context used for non-classpath paths
    * @param schemaPath location of the schema file
    * @param delimiter  separator between name and tag; passed to `String.split`, so it is
    *                   interpreted as a regular expression
    * @return handler built from the declarations (label check disabled)
    * @throws IllegalArgumentException if a classpath resource is missing or a tag is unknown
    */
  def readSchema(sc: SparkContext, schemaPath: String, delimiter: String): SchemaHandler = {
    val jarPrefix = "jar://"
    val namesAndTypes: Array[Array[String]] =
      if (schemaPath.startsWith(jarPrefix)) {
        val resourcePath = schemaPath.substring(jarPrefix.length)
        // getResourceAsStream returns null for a missing resource; fail with a clear message.
        val stream = Option(getClass.getClassLoader.getResourceAsStream(resourcePath)).getOrElse(
          throw new IllegalArgumentException(s"schema resource not found on classpath: $resourcePath"))
        val source = Source.fromInputStream(stream)
        // Materialize the lines before closing; closing the source closes the stream.
        try {
          source.getLines().toArray.map(line => line.split(delimiter))
        } finally {
          source.close()
        }
      } else {
        sc.textFile(schemaPath, 1).collect().map(line => line.split(delimiter))
      }
    val sfArr = namesAndTypes.map { arr =>
      val dataType = arr(1) match {
        case Num => DoubleType
        case Cat | MultiCat | Text | Key | Label | Other => StringType
        case _ => throw new IllegalArgumentException("data type unknown")
      }
      StructField(arr(0), dataType, true)
    }
    new SchemaHandler(StructType(sfArr), namesAndTypes)
  }

}
63 |
64 |
65 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/HingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 | import net.qihoo.xitong.xdml.model.linalg.BLAS
5 |
6 |
/**
  * Hinge loss on a linear score.
  *
  * Labels are rescaled with `2 * label - 1.0` (0 becomes -1, 1 becomes +1). A
  * sample contributes only while `1.0 > scaledLabel * score`; otherwise both the
  * gradient factor and the loss are 0.0.
  */
class HingeLossFunc extends LossFunc {

  /** Gradient factor for a given score: the negated scaled label inside the margin, else 0.0. */
  private def factorAt(score: Double, label: Double): Double = {
    val scaled = 2 * label - 1.0
    if (1.0 > scaled * score) -scaled else 0.0
  }

  /** Loss for a given score: `1.0 - scaledLabel * score` inside the margin, else 0.0. */
  private def lossAt(score: Double, label: Double): Double = {
    val scaled = 2 * label - 1.0
    if (1.0 > scaled * score) 1.0 - scaled * score else 0.0
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** @return (gradient factor, loss) for the score `weights . data` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val score = BLAS.dot(weights, data)
    (factorAt(score, label), lossAt(score, label))
  }

  /** @return gradient factor for the score `weights . data` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    factorAt(BLAS.dot(weights, data), label)

  /** @return loss for the score `weights . data` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** @return gradient factor for an already computed score `dot` */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    factorAt(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  /** @return (gradient factor, loss) for the score `weights . data + intercept` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val score = BLAS.dot(weights, data) + intercept
    (factorAt(score, label), lossAt(score, label))
  }

  /** @return gradient factor for the score `weights . data + intercept` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    factorAt(BLAS.dot(weights, data) + intercept, label)

  /** @return loss for the score `weights . data + intercept` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/L2LossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import org.apache.spark.mllib.linalg.Vector
5 |
6 |
/**
  * Squared-error loss on a linear score: the loss is `(score - label)^2 / 2` and
  * the gradient factor is `score - label`.
  */
class L2LossFunc extends LossFunc {

  /** Half of the squared difference between score and label. */
  private def halfSquaredError(score: Double, label: Double): Double =
    math.pow(score - label, 2) / 2

  ////////////////////////// without intercept //////////////////////////////////

  /** @return loss for the score `weights . data` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    halfSquaredError(BLAS.dot(weights, data), label)

  /** @return gradient factor for the score `weights . data` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    BLAS.dot(weights, data) - label

  /** @return (gradient factor, loss) for the score `weights . data` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val score = BLAS.dot(weights, data)
    (score - label, halfSquaredError(score, label))
  }

  /** @return gradient factor for an already computed score `dot` */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    dot - label

  ////////////////////////// with intercept //////////////////////////////////

  /** @return loss for the score `weights . data + intercept` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    halfSquaredError(BLAS.dot(weights, data) + intercept, label)

  /** @return gradient factor for the score `weights . data + intercept` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    BLAS.dot(weights, data) + intercept - label

  /** @return (gradient factor, loss) for the score `weights . data + intercept` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val score = BLAS.dot(weights, data) + intercept
    (score - label, halfSquaredError(score, label))
  }

}
64 |
65 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/LogitLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 | import net.qihoo.xitong.xdml.model.util.MLUtils
5 | import net.qihoo.xitong.xdml.model.linalg.BLAS
6 |
7 |
/**
  * Logistic loss on a linear score.
  *
  * Internally every method works on the negated score (`negScore`); labels above
  * 0.5 are treated as the positive class when computing the loss.
  */
class LogitLossFunc extends LossFunc {

  /** Logistic function `1 / (1 + exp(-z))`. */
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /** Gradient factor `1 / (1 + exp(negScore)) - label`. */
  private def factorAt(negScore: Double, label: Double): Double =
    (1.0 / (1.0 + math.exp(negScore))) - label

  /**
    * Loss for a negated score. `MLUtils.log1pExp(negScore)` is equivalent to
    * `log(1 + exp(negScore))` but more numerically stable; for labels not above
    * 0.5 the negated score is subtracted from it.
    */
  private def lossAt(negScore: Double, label: Double): Double = {
    val softplus = MLUtils.log1pExp(negScore)
    if (label > 0.5) softplus else softplus - negScore
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** @return loss for the score `weights . data` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(- BLAS.dot(weights, data), label)

  /** @return gradient factor for the score `weights . data` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    factorAt(- BLAS.dot(weights, data), label)

  /** @return (gradient factor, loss) for the score `weights . data` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val negScore = - BLAS.dot(weights, data)
    (factorAt(negScore, label), lossAt(negScore, label))
  }

  /** @return gradient factor `sigmoid(dot) - label` for an already computed score */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    sigmoid(dot) - label

  ////////////////////////// with intercept //////////////////////////////////

  /** @return loss for the score `weights . data + intercept` */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(- BLAS.dot(weights, data) - intercept, label)

  /** @return gradient factor for the score `weights . data + intercept` */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    factorAt(- BLAS.dot(weights, data) - intercept, label)

  /** @return (gradient factor, loss) for the score `weights . data + intercept` */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val negScore = - BLAS.dot(weights, data) - intercept
    (factorAt(negScore, label), lossAt(negScore, label))
  }

}
98 |
99 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/LossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 |
5 |
/**
  * Contract for per-sample loss functions evaluated on a feature vector, a label
  * and a weight vector, with or without an intercept term.
  *
  * The "gradient" methods return a single Double. In the implementations shipped
  * in this package (hinge, L2, logit) that value is a scalar factor computed from
  * the score and the label, not a full gradient vector.
  */
abstract class LossFunc extends Serializable {

  ////////////////////////// without intercept //////////////////////////////////

  /** Loss of one sample. */
  def loss(data: Vector,
           label: Double,
           weights: Vector): Double

  /** Gradient factor of one sample. */
  def gradient(data: Vector,
               label: Double,
               weights: Vector): Double

  /** Gradient factor and loss of one sample, as (gradient factor, loss). */
  def gradientWithLoss(data: Vector,
                       label: Double,
                       weights: Vector): (Double, Double)

  /** Gradient factor for a score `dot` that the caller has already computed. */
  def gradientFromDot(dot: Double,
                      label: Double): Double

  ////////////////////////// with intercept //////////////////////////////////

  /** Loss of one sample, with an intercept term. */
  def loss(data: Vector,
           label: Double,
           weights: Vector,
           intercept: Double): Double

  /** Gradient factor of one sample, with an intercept term. */
  def gradient(data: Vector,
               label: Double,
               weights: Vector,
               intercept: Double): Double

  /** Gradient factor and loss of one sample with an intercept term, as (gradient factor, loss). */
  def gradientWithLoss(data: Vector,
                       label: Double,
                       weights: Vector,
                       intercept: Double): (Double, Double)

}
27 |
28 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiHingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import org.apache.spark.mllib.linalg.Vector
5 |
6 |
/**
  * Hinge loss applied independently to every class score: the class at index
  * `label.toInt` gets target +1.0, every other class gets -1.0.
  *
  * @param numClasses not read by any method of this class
  */
class MultiHingeLossFunc(numClasses: Int) extends MultiLossFunc {

  /**
    * Per-class gradient factors and summed loss for already computed margins.
    * A class contributes only while `1.0 > target * margin`; otherwise its factor
    * and loss stay 0.0.
    */
  private def hinge(margins: Array[Double], label: Double): (Array[Double], Double) = {
    val targets = Array.fill(margins.length)(-1.0)
    targets(label.toInt) = 1.0
    val factors = new Array[Double](margins.length)
    val losses = new Array[Double](margins.length)
    var k = 0
    while (k < margins.length) {
      val agreement = targets(k) * margins(k)
      if (1.0 > agreement) {
        factors(k) = -targets(k)
        losses(k) = 1.0 - agreement
      }
      k += 1
    }
    (factors, losses.sum)
  }

  /** @return one gradient factor per entry of `margins`; `margins` itself is not modified */
  def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] =
    hinge(margins, label)._1

  ////////////////////////// with intercept //////////////////////////////////

  /**
    * @return (per-class gradient factors, summed loss) for the margins
    *         `weightsArr(i) . data + interceptArr(i)`
    */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector],
                       interceptArr: Array[Double]): (Array[Double], Double) = {
    val margins = Array.tabulate(weightsArr.length) { i =>
      BLAS.dot(weightsArr(i), data) + interceptArr(i)
    }
    hinge(margins, label)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** @return (per-class gradient factors, summed loss) for the margins `weightsArr(i) . data` */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector]): (Array[Double], Double) = {
    val margins = Array.tabulate(weightsArr.length)(i => BLAS.dot(weightsArr(i), data))
    hinge(margins, label)
  }

}
74 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiLogitLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 |
5 |
/**
  * Multinomial logistic loss over `numClasses` classes.
  *
  * Only `numClasses - 1` margins are handled per example: label `k >= 1` owns
  * the margin at index `k - 1`, while label 0 has no margin of its own.
  */
class MultiLogitLossFunc(numClasses: Int) extends MultiLossFunc {

  /** Number of margins processed per example. */
  private def marginCount: Int = numClasses - 1

  /**
    * Single pass over the margins.
    *
    * @return (margin owned by `label` or 0.0 if none, largest margin,
    *         index of the first largest margin)
    */
  private def scanMargins(margins: Array[Double], label: Double): (Double, Double, Int) = {
    val labelSlot = label.toInt - 1
    var ownMargin = 0.0
    var peak = Double.NegativeInfinity
    var peakSlot = 0
    var k = 0
    while (k < marginCount) {
      val current = margins(k)
      if (k == labelSlot) ownMargin = current
      if (current > peak) {
        peak = current
        peakSlot = k
      }
      k += 1
    }
    (ownMargin, peak, peakSlot)
  }

  /**
    * Overwrites `margins` with the gradient multipliers and returns the
    * exponential sum used by the loss.
    *
    * When the largest margin is positive all margins are shifted down by it
    * before exponentiation; the slot holding the largest margin then
    * contributes `exp(-peak)` to the sum instead of `exp(0)`.
    */
  private def toMultipliers(margins: Array[Double],
                            label: Double,
                            peak: Double,
                            peakSlot: Int): Double = {
    val shifted = peak > 0
    var expSum = 0.0
    var k = 0
    while (k < marginCount) {
      if (shifted) {
        margins(k) -= peak
        expSum += (if (k == peakSlot) math.exp(-peak) else math.exp(margins(k)))
      } else {
        expSum += math.exp(margins(k))
      }
      k += 1
    }

    val denominator = expSum + 1.0
    k = 0
    while (k < marginCount) {
      val isLabelSlot = if (label != 0.0 && label == k + 1) 1.0 else 0.0
      margins(k) = math.exp(margins(k)) / denominator - isLabelSlot
      k += 1
    }
    expSum
  }

  /** Adds, for every margin slot, the contribution of the non-zero active entries of `data`. */
  private def addWeightedFeatures(margins: Array[Double],
                                  data: Vector,
                                  weightsArr: Array[Vector]): Unit = {
    (0 until marginCount).foreach { k =>
      data.foreachActive { (index, value) =>
        if (value != 0.0) margins(k) += value * weightsArr(k)(index)
      }
    }
  }

  /** Turns raw margins into (multipliers, loss); `margins` is overwritten and returned. */
  private def multipliersAndLoss(margins: Array[Double], label: Double): (Array[Double], Double) = {
    val (ownMargin, peak, peakSlot) = scanMargins(margins, label)
    val expSum = toMultipliers(margins, label, peak, peakSlot)
    val unshifted =
      if (label > 0.0) math.log1p(expSum) - ownMargin else math.log1p(expSum)
    // Undo the shift that was applied before exponentiation.
    val loss = if (peak > 0) unshifted + peak else unshifted
    (margins, loss)
  }

  /**
    * Converts precomputed margins into gradient multipliers.
    *
    * The first `numClasses - 1` entries of `margins` are overwritten in place
    * and the very same array instance is returned.
    *
    * @param margins raw margins, one per class except class 0
    * @param label   class label of the example
    * @return `margins`, now holding the multipliers
    */
  def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] = {
    val (_, peak, peakSlot) = scanMargins(margins, label)
    toMultipliers(margins, label, peak, peakSlot)
    margins
  }

  ////////////////////////// with intercept //////////////////////////////////

  /**
    * Gradient multipliers and loss for one example, with intercepts.
    *
    * Margins start from a copy of `interceptArr`, so the caller's array is
    * left untouched.
    *
    * @param data         feature vector of the example
    * @param label        class label of the example
    * @param weightsArr   one weight vector per margin slot
    * @param interceptArr one intercept per margin slot
    * @return (multipliers, loss)
    */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector],
                       interceptArr: Array[Double]): (Array[Double], Double) = {
    val margins = interceptArr.clone()
    addWeightedFeatures(margins, data, weightsArr)
    multipliersAndLoss(margins, label)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /**
    * Gradient multipliers and loss for one example, without intercepts.
    *
    * @param data       feature vector of the example
    * @param label      class label of the example
    * @param weightsArr one weight vector per margin slot
    * @return (multipliers, loss)
    */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector]): (Array[Double], Double) = {
    val margins = Array.fill(marginCount)(0.0)
    addWeightedFeatures(margins, data, weightsArr)
    multipliersAndLoss(margins, label)
  }

}
150 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 |
5 |
/**
  * Contract for loss functions that yield one gradient factor per output
  * (one output per weight vector / margin) for a single labelled example.
  */
abstract class MultiLossFunc extends Serializable {

  /**
    * Derives the gradient factors from already computed margins.
    *
    * NOTE: an implementation may overwrite and return `margins` itself
    * (`MultiLogitLossFunc` does); callers that still need the raw margins
    * should pass a copy.
    *
    * @param margins precomputed margins, one per output
    * @param label   label of the example
    * @return one gradient factor per output
    */
  def gradientFromMargins(margins: Array[Double], label: Double): Array[Double]

  ////////////////////////// with intercept //////////////////////////////////

  /**
    * Gradient factors and loss for one example, using intercepts.
    *
    * @param data         feature vector of the example
    * @param label        label of the example
    * @param weightsArr   one weight vector per output
    * @param interceptArr one intercept per output
    * @return (gradient factors, loss)
    */
  def gradientWithLoss(data: Vector,
                       label: Double,
                       weightsArr: Array[Vector],
                       interceptArr: Array[Double]): (Array[Double], Double)

  ////////////////////////// without intercept //////////////////////////////////

  /**
    * Gradient factors and loss for one example, without intercepts.
    *
    * @param data       feature vector of the example
    * @param label      label of the example
    * @param weightsArr one weight vector per output
    * @return (gradient factors, loss)
    */
  def gradientWithLoss(data: Vector,
                       label: Double,
                       weightsArr: Array[Vector]): (Array[Double], Double)

}
22 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiSmoothHingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import org.apache.spark.mllib.linalg.Vector
5 |
6 |
/**
  * Smoothed hinge loss applied independently to every output.
  *
  * The output at index `label.toInt` gets target +1.0, all others -1.0. With
  * `s = target * margin` and `slack = max(0, 1 - s)`:
  *  - `s < 0`: gradient factor `-target`, loss `slack - 0.5`
  *  - otherwise: gradient factor `-target * slack`, loss `slack * slack / 2`
  */
class MultiSmoothHingeLossFunc(numClasses: Int) extends MultiLossFunc {

  /** Builds the +1/-1 target array; fails if `label.toInt` is not a valid index. */
  private def signedTargets(size: Int, label: Double): Array[Double] = {
    val targets = Array.fill(size)(-1.0)
    targets(label.toInt) = 1.0
    targets
  }

  /** Gradient factor of a single output. */
  private def factorFor(target: Double, margin: Double): Double = {
    val signed = target * margin
    if (0.0 > signed) -target
    else -target * math.max(0.0, 1.0 - signed)
  }

  /** Loss of a single output. */
  private def lossFor(target: Double, margin: Double): Double = {
    val signed = target * margin
    val slack = math.max(0.0, 1.0 - signed)
    if (0.0 > signed) slack - 0.5 else slack * slack / 2
  }

  /** Shared tail of both `gradientWithLoss` overloads. */
  private def factorsAndLoss(margins: Array[Double], label: Double): (Array[Double], Double) = {
    val targets = signedTargets(margins.length, label)
    val factors = new Array[Double](margins.length)
    val losses = new Array[Double](margins.length)
    for (k <- margins.indices) {
      factors(k) = factorFor(targets(k), margins(k))
      losses(k) = lossFor(targets(k), margins(k))
    }
    (factors, losses.sum)
  }

  /**
    * Gradient factors from precomputed margins; `margins` is not modified.
    *
    * @param margins one margin per output
    * @param label   index of the positive output (truncated with `toInt`)
    * @return a new array with one gradient factor per output
    */
  def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] = {
    val targets = signedTargets(margins.length, label)
    Array.tabulate(margins.length)(k => factorFor(targets(k), margins(k)))
  }

  ////////////////////////// with intercept //////////////////////////////////

  /**
    * Gradient factors and summed loss, margins being `dot(weights, data) + intercept`.
    *
    * @param data         feature vector of the example
    * @param label        index of the positive output (truncated with `toInt`)
    * @param weightsArr   one weight vector per output
    * @param interceptArr one intercept per output
    * @return (gradient factors, sum of per-output losses)
    */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector],
                       interceptArr: Array[Double]): (Array[Double], Double) = {
    val margins = Array.tabulate(weightsArr.length) { k =>
      BLAS.dot(weightsArr(k), data) + interceptArr(k)
    }
    factorsAndLoss(margins, label)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /**
    * Gradient factors and summed loss, margins being `dot(weights, data)`.
    *
    * @param data       feature vector of the example
    * @param label      index of the positive output (truncated with `toInt`)
    * @param weightsArr one weight vector per output
    * @return (gradient factors, sum of per-output losses)
    */
  def gradientWithLoss(data: Vector, label: Double,
                       weightsArr: Array[Vector]): (Array[Double], Double) = {
    val margins = weightsArr.map(weights => BLAS.dot(weights, data))
    factorsAndLoss(margins, label)
  }

}
77 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/PoissonLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import org.apache.spark.mllib.linalg.Vector
5 |
6 |
/**
  * Poisson-style loss on a linear margin `m`:
  * loss `exp(m) - label * m`, gradient factor `exp(m) - label`.
  */
class PoissonLossFunc extends LossFunc {

  /** Loss for an already computed margin. */
  private def lossAt(margin: Double, label: Double): Double =
    math.exp(margin) - label * margin

  /** Gradient factor for an already computed margin. */
  private def gradientAt(margin: Double, label: Double): Double =
    math.exp(margin) - label

  /** (gradient factor, loss) sharing a single `exp` evaluation. */
  private def gradientAndLossAt(margin: Double, label: Double): (Double, Double) = {
    val rate = math.exp(margin)
    (rate - label, rate - label * margin)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** Loss with margin `dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** Gradient factor with margin `dot(weights, data)`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    gradientAt(BLAS.dot(weights, data), label)

  /** (gradient factor, loss) with margin `dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) =
    gradientAndLossAt(BLAS.dot(weights, data), label)

  /** Gradient factor when the dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    gradientAt(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  /** Loss with margin `dot(weights, data) + intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

  /** Gradient factor with margin `dot(weights, data) + intercept`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    gradientAt(BLAS.dot(weights, data) + intercept, label)

  /** (gradient factor, loss) with margin `dot(weights, data) + intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) =
    gradientAndLossAt(BLAS.dot(weights, data) + intercept, label)

}
66 |
67 |
68 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/SmoothHingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 | import net.qihoo.xitong.xdml.model.linalg.BLAS
5 |
6 |
/**
  * Smoothed hinge loss for a single linear margin.
  *
  * The label is rescaled as `y = 2 * label - 1.0`. With `s = y * margin` and
  * `slack = max(0, 1 - s)`:
  *  - `s < 0`: gradient factor `-y`, loss `slack - 0.5`
  *  - otherwise: gradient factor `-y * slack`, loss `slack * slack / 2`
  */
class SmoothHingeLossFunc extends LossFunc {

  /** Gradient factor for an already computed margin. */
  private def gradientAt(margin: Double, label: Double): Double = {
    val sign = 2 * label - 1.0
    val signed = sign * margin
    if (signed < 0) -sign
    else -sign * math.max(0.0, 1.0 - signed)
  }

  /** Loss for an already computed margin. */
  private def lossAt(margin: Double, label: Double): Double = {
    val signed = (2 * label - 1.0) * margin
    val slack = math.max(0.0, 1.0 - signed)
    if (signed < 0) slack - 0.5 else slack * slack / 2
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val margin = BLAS.dot(weights, data)
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data)`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    gradientAt(BLAS.dot(weights, data), label)

  /** Loss with margin `dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** Gradient factor when the dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    gradientAt(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data) + intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val margin = BLAS.dot(weights, data) + intercept
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data) + intercept`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    gradientAt(BLAS.dot(weights, data) + intercept, label)

  /** Loss with margin `dot(weights, data) + intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

}
109 |
110 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/UPULogitLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import net.qihoo.xitong.xdml.model.util.MLUtils
5 | import org.apache.spark.mllib.linalg.Vector
6 |
7 |
/**
  * Logit-based loss that treats examples with `label > 0.5` differently from
  * all other examples.
  *
  *  - `label > 0.5`: loss `posRatioPriorFrac * margin`, constant gradient
  *    factor `posRatioPriorFrac`.
  *  - otherwise: loss `MLUtils.log1pExp(margin) / unlabeledRatio`, gradient
  *    factor `1 / (1 + exp(-margin)) / unlabeledRatio`.
  *
  * The three constructor arguments are printed to standard output when an
  * instance is created.
  *
  * @param posRatioPrior  numerator of `posRatioPriorFrac`
  * @param posRatio       denominator of `posRatioPriorFrac`
  * @param unlabeledRatio divisor applied to the `label <= 0.5` branch
  */
class UPULogitLossFunc(posRatioPrior: Double, posRatio: Double, unlabeledRatio: Double) extends LossFunc {

  /** Constant factor `-posRatioPrior / posRatio` used for `label > 0.5`. */
  val posRatioPriorFrac = - posRatioPrior / posRatio
  println("unlabeledRatio: " + unlabeledRatio)
  println("posRatio: " + posRatio)
  println("posRatioPrior: " + posRatioPrior)

  private def isPositive(label: Double): Boolean = label > 0.5

  /** Gradient factor of the `label <= 0.5` branch. */
  private def otherGradient(margin: Double): Double =
    1.0 / (1.0 + math.exp(-margin)) / unlabeledRatio

  /** Loss of the `label <= 0.5` branch. */
  private def otherLoss(margin: Double): Double =
    MLUtils.log1pExp(margin) / unlabeledRatio

  private def lossAt(margin: Double, label: Double): Double =
    if (isPositive(label)) posRatioPriorFrac * margin else otherLoss(margin)

  private def gradientAndLossAt(margin: Double, label: Double): (Double, Double) =
    if (isPositive(label)) (posRatioPriorFrac, posRatioPriorFrac * margin)
    else (otherGradient(margin), otherLoss(margin))

  ////////////////////////// without intercept //////////////////////////////////

  /** Loss with margin `dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** Gradient factor; the dot product is only evaluated when `label <= 0.5`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    if (isPositive(label)) posRatioPriorFrac
    else otherGradient(BLAS.dot(weights, data))

  /** (gradient factor, loss) with margin `dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) =
    gradientAndLossAt(BLAS.dot(weights, data), label)

  /** Gradient factor when the dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    if (isPositive(label)) posRatioPriorFrac else otherGradient(dot)

  ////////////////////////// with intercept //////////////////////////////////

  /** Loss with margin `dot(weights, data) + intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

  /** Gradient factor; the margin is only evaluated when `label <= 0.5`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    if (isPositive(label)) posRatioPriorFrac
    else otherGradient(BLAS.dot(weights, data) + intercept)

  /** (gradient factor, loss) with margin `dot(weights, data) + intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) =
    gradientAndLossAt(BLAS.dot(weights, data) + intercept, label)

}
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedHingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 | import net.qihoo.xitong.xdml.model.linalg.BLAS
5 |
6 |
/**
  * Hinge loss whose gradient factor and loss are multiplied by a
  * label-dependent weight `label * (posWeight - 1) + 1`
  * (equal to `posWeight` for label 1.0 and to 1 for label 0.0).
  *
  * The label is rescaled as `y = 2 * label - 1.0`. While `y * margin < 1.0`
  * the gradient factor is `-y * weight` and the loss is
  * `(1 - y * margin) * weight`; otherwise both are 0.0.
  *
  * @param posWeight weight applied when `label` is 1.0
  */
class WeightedHingeLossFunc(posWeight: Double) extends LossFunc {

  /** Label-dependent multiplier. */
  private def classWeight(label: Double): Double = label * (posWeight - 1) + 1

  /** Gradient factor for an already computed margin. */
  private def gradientAt(margin: Double, label: Double): Double = {
    val sign = 2 * label - 1.0
    if (1.0 > sign * margin) -sign * classWeight(label) else 0.0
  }

  /** Loss for an already computed margin. */
  private def lossAt(margin: Double, label: Double): Double = {
    val signed = (2 * label - 1.0) * margin
    if (1.0 > signed) (1.0 - signed) * classWeight(label) else 0.0
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val margin = BLAS.dot(weights, data)
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data)`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    gradientAt(BLAS.dot(weights, data), label)

  /** Loss with margin `dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** Gradient factor when the dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    gradientAt(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data) + intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val margin = BLAS.dot(weights, data) + intercept
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data) + intercept`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    gradientAt(BLAS.dot(weights, data) + intercept, label)

  /** Loss with margin `dot(weights, data) + intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedLogitLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import org.apache.spark.mllib.linalg.Vector
4 | import net.qihoo.xitong.xdml.model.util.MLUtils
5 | import net.qihoo.xitong.xdml.model.linalg.BLAS
6 |
7 |
/**
  * Logistic loss in which examples with `label > 0.5` have their loss
  * multiplied by `posWeight`, and every gradient factor is multiplied by
  * `label * (posWeight - 1) + 1`.
  *
  * Internally the loss works on the negated margin `n = -(w . x [+ b])`.
  *
  * @param posWeight weight applied to positive examples
  */
class WeightedLogitLossFunc(posWeight: Double) extends LossFunc {

  /** Logistic function `1 / (1 + exp(-z))`. */
  def sigmoid(z: Double): Double = {
    1.0 / (1.0 + math.exp(-z))
  }

  /** Label-dependent multiplier of the gradient factor. */
  private def classWeight(label: Double): Double = label * (posWeight - 1) + 1

  /** Loss as a function of the negated margin. */
  private def lossAt(negMargin: Double, label: Double): Double = {
    // MLUtils.log1pExp is, per the original author's note, a numerically
    // stable form of log(1 + exp(x)).
    val softplus = MLUtils.log1pExp(negMargin)
    if (label > 0.5) softplus * posWeight else softplus - negMargin
  }

  /** Weighted gradient factor as a function of the negated margin. */
  private def gradientAt(negMargin: Double, label: Double): Double = {
    val residual = (1.0 / (1.0 + math.exp(negMargin))) - label
    residual * classWeight(label)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** Loss with negated margin `-dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(- BLAS.dot(weights, data), label)

  /** Gradient factor with negated margin `-dot(weights, data)`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    gradientAt(- BLAS.dot(weights, data), label)

  /** (gradient factor, loss) with negated margin `-dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val negMargin = - BLAS.dot(weights, data)
    (gradientAt(negMargin, label), lossAt(negMargin, label))
  }

  /** Gradient factor when the (non-negated) dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    (sigmoid(dot) - label) * classWeight(label)

  ////////////////////////// with intercept //////////////////////////////////

  /** Loss with negated margin `-dot(weights, data) - intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(- BLAS.dot(weights, data) - intercept, label)

  /** Gradient factor with negated margin `-dot(weights, data) - intercept`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    gradientAt(- BLAS.dot(weights, data) - intercept, label)

  /** (gradient factor, loss) with negated margin `-dot(weights, data) - intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val negMargin = - BLAS.dot(weights, data) - intercept
    (gradientAt(negMargin, label), lossAt(negMargin, label))
  }

}
98 |
99 |
100 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedSmoothHingeLossFunc.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.loss
2 |
3 | import net.qihoo.xitong.xdml.model.linalg.BLAS
4 | import org.apache.spark.mllib.linalg.Vector
5 |
6 |
/**
  * Smoothed hinge loss whose gradient factor and loss are multiplied by a
  * label-dependent weight `label * (posWeight - 1) + 1`.
  *
  * The label is rescaled as `y = 2 * label - 1.0`. With `s = y * margin` and
  * `slack = max(0, 1 - s)`:
  *  - `s < 0`: gradient factor `-y * weight`, loss `(slack - 0.5) * weight`
  *  - otherwise: gradient factor `-y * slack * weight`,
  *    loss `slack * slack / 2 * weight`
  *
  * @param posWeight weight applied when `label` is 1.0
  */
class WeightedSmoothHingeLossFunc(posWeight: Double) extends LossFunc {

  /** Label-dependent multiplier. */
  private def classWeight(label: Double): Double = label * (posWeight - 1) + 1

  /** Gradient factor for an already computed margin. */
  private def gradientAt(margin: Double, label: Double): Double = {
    val sign = 2 * label - 1.0
    val signed = sign * margin
    if (signed < 0) -sign * classWeight(label)
    else -sign * math.max(0.0, 1.0 - signed) * classWeight(label)
  }

  /** Loss for an already computed margin. */
  private def lossAt(margin: Double, label: Double): Double = {
    val signed = (2 * label - 1.0) * margin
    val slack = math.max(0.0, 1.0 - signed)
    if (signed < 0) (slack - 0.5) * classWeight(label)
    else slack * slack / 2 * classWeight(label)
  }

  ////////////////////////// without intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data)`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val margin = BLAS.dot(weights, data)
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data)`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double =
    gradientAt(BLAS.dot(weights, data), label)

  /** Loss with margin `dot(weights, data)`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double =
    lossAt(BLAS.dot(weights, data), label)

  /** Gradient factor when the dot product is already known. */
  override def gradientFromDot(dot: Double,
                               label: Double): Double =
    gradientAt(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  /** (gradient factor, loss) with margin `dot(weights, data) + intercept`. */
  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val margin = BLAS.dot(weights, data) + intercept
    (gradientAt(margin, label), lossAt(margin, label))
  }

  /** Gradient factor with margin `dot(weights, data) + intercept`. */
  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double =
    gradientAt(BLAS.dot(weights, data) + intercept, label)

  /** Loss with margin `dot(weights, data) + intercept`. */
  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double =
    lossAt(BLAS.dot(weights, data) + intercept, label)

}
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/supervised/H2O/H2OParams.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.ml.model.supervised
2 |
3 | import org.apache.spark.ml.param._
4 | import org.apache.spark.sql.types.{StringType, StructType}
5 |
6 |
/** Column-related parameters shared by the H2O-backed estimators and models. */
trait H2OParams extends Params {

  /** Name of the label column. */
  final val labelCol: Param[String] =
    new Param[String](this, "labelCol", "Label column name")

  /** Names of the categorical feature columns. */
  final val catFeatColNames: StringArrayParam =
    new StringArrayParam(this, "catFeatColNames", "Categorical feature column names")

  /** Names of the feature columns to ignore. */
  final val ignoredFeatColNames: StringArrayParam =
    new StringArrayParam(this, "ignoredFeatColNames", "Ignored feature column names")

  /**
    * Checks that every configured categorical feature column exists in
    * `schema` and is of string type.
    *
    * All existence checks run before any type check, so a missing column is
    * always reported ahead of a wrongly typed one.
    *
    * @param schema schema to validate
    * @return `schema`, unchanged
    * @throws IllegalArgumentException if a column is missing or not a string column
    */
  protected def validateSchema(schema: StructType): StructType = {
    val categorical = $(catFeatColNames)
    val knownColumns = schema.fieldNames
    for (name <- categorical) {
      require(knownColumns.contains(name), s"Column $name cannot find in dataFrame.")
    }
    for (name <- categorical) {
      require(schema(name).dataType == StringType,
        s"The categorical feature column $name must be string type.")
    }
    schema
  }
}
21 |
/** Parameters of the H2O tree-based models. */
trait H2OTreeParams extends H2OParams {

  /** Maximum tree depth; must be >= 0. */
  final val maxDepth: IntParam = new IntParam(
    this,
    "maxDepth",
    "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node;" +
      " depth 1 means 1 internal node + 2 leaf nodes.",
    ParamValidators.gtEq(0))

  /** Number of trees to train; must be >= 1. */
  final val numTrees: IntParam = new IntParam(
    this,
    "numTrees",
    "Number of trees to train (>= 1)",
    ParamValidators.gtEq(1))

  /** Maximum number of bins for categorical features; must be >= 2. */
  final val maxBinsForCat: IntParam = new IntParam(
    this,
    "maxBinsForCat",
    "Max number of bins for category features. Must be >=2 and >= number of" +
      " categories for any categorical feature.",
    ParamValidators.gtEq(2))

  /** Maximum number of bins for numeric features; must be >= 2. */
  final val maxBinsForNum: IntParam = new IntParam(
    this,
    "maxBinsForNum",
    "Max number of bins for numeric features. Must be >=2 and >= number of" +
      " category for any numeric feature.",
    ParamValidators.gtEq(2))

  /** Minimum number of instances per child after a split; must be >= 1. */
  final val minInstancesPerNode: IntParam = new IntParam(
    this,
    "minInstancesPerNode",
    "Minimum number of instances each child must have after split." +
      " If a split causes the left or right child to have fewer than" +
      " minInstancesPerNode, the split will be discarded as invalid." +
      " Should be >= 1.",
    ParamValidators.gtEq(1))

  /** Encoding scheme for categorical features; one of `H2OTreeParams.supportedCategoricalEncoding`. */
  final val categoricalEncodingScheme: Param[String] = new Param[String](
    this,
    "categoricalEncodingScheme",
    "Encoding scheme for categorical features. Supported options: " +
      H2OTreeParams.supportedCategoricalEncoding.mkString(", "),
    ParamValidators.inArray(H2OTreeParams.supportedCategoricalEncoding))

  /** Histogram type; one of `H2OTreeParams.supportedHistogramType`. */
  final val histogramType: Param[String] = new Param[String](
    this,
    "histogramType",
    " Type of histogram. Supported options: " +
      H2OTreeParams.supportedHistogramType.mkString(", "),
    ParamValidators.inArray(H2OTreeParams.supportedHistogramType))

  /** Distribution of the data set; one of `H2OTreeParams.supportedDistribution`. */
  final val distribution: Param[String] = new Param[String](
    this,
    "distribution",
    "Distribution for dataSet. Supported options: " +
      H2OTreeParams.supportedDistribution.mkString(", "),
    ParamValidators.inArray(H2OTreeParams.supportedDistribution))

  /** Scoring interval in trees; 0 disables it, must be >= 0. */
  final val scoreTreeInterval: IntParam = new IntParam(
    this,
    "scoreTreeInterval",
    "Score the model after every so many trees. Disabled if set to 0.",
    ParamValidators.gtEq(0))
}
58 |
/** Allowed string values for the enumerated parameters of [[H2OTreeParams]]. */
object H2OTreeParams {

  /** Values accepted by `categoricalEncodingScheme`. */
  final val supportedCategoricalEncoding: Array[String] = Array(
    "AUTO",
    "Enum",
    "LabelEncoder",
    "OneHotExplicit",
    "SortByResponse")

  /** Values accepted by `histogramType`. */
  final val supportedHistogramType: Array[String] = Array(
    "AUTO",
    "UniformAdaptive",
    "QuantilesGlobal")

  /** Values accepted by `distribution`. */
  final val supportedDistribution: Array[String] = Array(
    "bernoulli",
    "multinomial",
    "gaussian")
}
67 |
/**
  * Fluent setters for the tree parameters, intended to be mixed into
  * estimators. Every setter stores the value and returns `this`.
  *
  * `catFeatColNames` and `ignoredFeatColNames` default to empty arrays.
  */
trait H2OTreeEstimatorParams extends H2OTreeParams {

  // ---- column selection ----

  /** Sets the label column name. */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** Sets the categorical feature column names. */
  def setCatFeatColNames(values: Array[String]): this.type = set(catFeatColNames, values)

  /** Sets the feature column names to ignore. */
  def setIgnoreFeatColNames(values: Array[String]): this.type = set(ignoredFeatColNames, values)

  // ---- tree shape ----

  /** Sets the maximum tree depth. */
  def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  /** Sets the number of trees to train. */
  def setNumTrees(value: Int): this.type = set(numTrees, value)

  /** Sets the maximum number of bins for categorical features. */
  def setMaxBinsForCat(value: Int): this.type = set(maxBinsForCat, value)

  /** Sets the maximum number of bins for numeric features. */
  def setMaxBinsForNum(value: Int): this.type = set(maxBinsForNum, value)

  /** Sets the minimum number of instances per child after a split. */
  def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

  // ---- enumerated options and scoring ----

  /** Sets the categorical encoding scheme. */
  def setCategoricalEncodingScheme(value: String): this.type = set(categoricalEncodingScheme, value)

  /** Sets the histogram type. */
  def setHistogramType(value: String): this.type = set(histogramType, value)

  /** Sets the distribution. */
  def setDistribution(value: String): this.type = set(distribution, value)

  /** Sets the scoring interval, in trees. */
  def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

  setDefault(
    catFeatColNames -> Array.empty[String],
    ignoredFeatColNames -> Array.empty[String])
}
96 |
97 |
98 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/util/MLUtils.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.model.util
2 |
3 | import org.apache.spark.internal.Logging
4 | import org.apache.spark.mllib.linalg._
5 | import net.qihoo.xitong.xdml.model.linalg.BLAS
6 |
7 |
/** Numerical helpers: machine epsilon, a fast squared Euclidean distance and a stable log(1 + e^x). */
object MLUtils extends Logging {

  /**
   * Machine epsilon for `Double`: the smallest power of two `eps` for which
   * `1.0 + eps / 2.0` is no longer distinguishable from `1.0`.
   */
  lazy val EPSILON = {
    var eps = 1.0
    while ((1.0 + (eps / 2.0)) != 1.0) {
      eps /= 2.0
    }
    eps
  }

  /**
   * Squared Euclidean distance between two vectors of equal size.
   *
   * Uses the expansion
   *
   *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b
   *
   * whenever its estimated relative error stays below `precision`, and otherwise falls back to
   * `Vectors.sqdist`.
   *
   * @param v1 the first vector
   * @param norm1 the norm of the first vector, non-negative
   * @param v2 the second vector
   * @param norm2 the norm of the second vector, non-negative
   * @param precision desired relative precision for the squared distance
   * @return squared distance between v1 and v2 within the specified precision
   */
  def fastSquaredDistance(v1: Vector,
                          norm1: Double,
                          v2: Vector,
                          norm2: Double,
                          precision: Double = 1e-6): Double = {
    val n = v1.size
    require(v2.size == n)
    require(norm1 >= 0.0 && norm2 >= 0.0)
    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
    val normDiff = norm1 - norm2
    /*
     * The relative error of the expansion is
     *
     *   EPSILON * ( \|a\|_2^2 + \|b\|_2^2 + 2 |a^T b| ) / ( \|a - b\|_2^2 ),
     *
     * which is bounded by
     *
     *   2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
     *
     * The bound needs no inner product, so it is a cheap sufficient condition for the
     * inner-product formula being accurate enough.
     */
    val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
    if (precisionBound1 < precision) {
      sumSquaredNorm - 2.0 * BLAS.dot(v1, v2)
    } else {
      (v1, v2) match {
        case (_: SparseVector, _) | (_, _: SparseVector) =>
          // With a sparse operand the dot product is cheap: try it, then verify the exact error.
          val dotValue = BLAS.dot(v1, v2)
          val viaDot = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
          val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
            (viaDot + EPSILON)
          if (precisionBound2 > precision) Vectors.sqdist(v1, v2) else viaDot
        case _ =>
          Vectors.sqdist(v1, v2)
      }
    }
  }

  /**
   * Computes `math.log(1 + math.exp(x))` without overflowing for large positive `x`
   * (the naive form overflows once `x > 709.78`) by using
   * `x + math.log1p(math.exp(-x))` when `x > 0`.
   *
   * @param x a floating-point value as input.
   * @return the result of `math.log(1 + math.exp(x))`.
   */
  def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x))
    else math.log1p(math.exp(x))

}
89 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/optimization/BinaryClassify.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.optimization
2 |
3 | import breeze.numerics.sigmoid
4 | import net.qihoo.xitong.xdml.linalg.BLAS._
5 |
/**
 * Logistic-regression helpers working on sparse samples given as parallel
 * (feature ids, feature values) arrays and a sparse weight map.
 */
object BinaryClassify extends Optimizer {

  /**
   * Computes the per-feature gradient and the loss of one sample.
   *
   * @param label             sample label; values > 0 are treated as positive by the loss
   * @param feature           (feature ids, feature values), index-aligned
   * @param localW            current weights keyed by feature id
   * @param subsampling_rate  sampling rate of the subsampled class
   * @param subsampling_label label value of the subsampled class
   * @return (gradient keyed by feature id, loss)
   */
  def train(label: Double,
            feature: (Array[Long], Array[Float]),
            localW: Map[Long, Float],
            subsampling_rate: Double = 1.0,
            subsampling_label: Double = 0.0
           ): (Map[Long, Float], Double) = {
    val (pre, loss) = computePreAndLoss(label, feature, localW)
    val newGradient = getGradLoss(feature, label, pre, subsampling_rate, subsampling_label)
    (newGradient, loss)
  }

  /** Returns the sigmoid of the dot product between the sample and the weights. */
  def predict(feature: (Array[Long], Array[Float]),
              weight: Map[Long, Float]
             ): Double = {
    sigmoid(dot(feature, weight))
  }

  /**
   * Builds the gradient map `id -> (sigmoid(pre) - label) * value * w`.
   *
   * `w` is `1 / subsampling_rate` when `label` is within 0.001 of `subsampling_label`
   * and 1.0 otherwise. If an id occurs more than once, the last occurrence wins.
   *
   * @param pre raw (pre-sigmoid) score of the sample
   */
  def getGradLoss(feature: (Array[Long], Array[Float]),
                  label: Double,
                  pre: Float,
                  subsampling_rate: Double = 1.0,
                  subsampling_label: Double = 0.0): Map[Long, Float] = {
    val p = sigmoid(pre)
    val w = if (Math.abs(label - subsampling_label) < 0.001) 1 / subsampling_rate else 1.0
    val (ids, values) = feature
    var gradLossMap: Map[Long, Float] = Map()
    ids.indices.foreach { i =>
      gradLossMap += (ids(i) -> ((p - label) * values(i) * w).toFloat)
    }
    gradLossMap
  }

  /**
   * Computes the raw score and the logistic loss of one sample.
   *
   * Outside [-13, 13] the loss is replaced by its asymptote: a small constant (2e-6) when the
   * score strongly agrees with the label, and the magnitude of the score when it strongly
   * disagrees, so the loss is never negative.
   *
   * @return (raw score, loss)
   */
  def computePreAndLoss(label: Double,
                        feature: (Array[Long], Array[Float]),
                        weight: Map[Long, Float]
                       ): (Float, Double) = {
    val dotVal = dot(feature, weight)
    val loss = if (label > 0) {
      // Positive sample: loss = log(1 + exp(-dotVal)).
      if (dotVal > 13)
        2e-6
      else if (dotVal < -13)
        (-dotVal)
      else
        math.log1p(math.exp(-dotVal))
    } else {
      // Negative sample: loss = log(1 + exp(dotVal)) = log1p(exp(-dotVal)) + dotVal.
      if (dotVal > 13)
        dotVal // asymptote is +dotVal; the previous -dotVal produced a negative loss
      else if (dotVal < -13)
        2e-6
      else
        math.log1p(math.exp(-dotVal)) + dotVal
    }
    (dotVal, loss)
  }

}
65 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/optimization/FFM.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.optimization
2 |
3 | import scala.collection.mutable
4 |
/**
 * Field-aware factorization machine helpers: pairwise interaction score, logistic loss and
 * per-feature gradient blocks.
 *
 * A latent array is read as consecutive blocks of `rank` floats; the block used against
 * field `f` starts at index `rank * f`.
 */
object FFM extends Optimizer {
  /** L2 regularisation coefficient applied inside [[calcGrad]]. */
  val lamda = 0.1F
  val lr = 0.1F

  /**
   * Computes gradient blocks and loss for one sample; delegates to [[calcGradAndLoss]].
   *
   * @param fieldNum number of fields
   * @param rank     latent dimension per field
   * @param label    sample label (assumes -1/+1 because the loss is log(1 + exp(-label * phi)) - TODO confirm)
   * @param feature  (field indices, feature ids, feature values), index-aligned
   * @param VMap     latent arrays keyed by feature id
   */
  def train(fieldNum: Int,
            rank: Int,
            label: Double,
            feature: (Array[Int], Array[Long], Array[Float]),
            VMap: mutable.Map[Long, Array[Float]]): (mutable.Map[Long, Array[Float]], Double) = {
    calcGradAndLoss(fieldNum, rank, label, feature, VMap)
  }

  /** Placeholder: always returns (0.0, 0.0). */
  def predict(): (Double, Double) = {
    (0.0, 0.0)
  }

  /**
   * Dot product of the `field1` block of `vector1` with the `field2` block of `vector2`.
   */
  def vectorDotWithField(rank: Int,
                         vector1: Array[Float],
                         field1: Int,
                         vector2: Array[Float],
                         field2: Int): Float = {
    var res = 0.0F
    val idx1 = rank * field1
    val idx2 = rank * field2
    var i = 0
    while (i < rank) {
      res = res + vector1(idx1 + i) * vector2(idx2 + i)
      i += 1
    }
    res
  }

  /**
   * Sums, over every unordered pair (i, j) of entries in the sample,
   * `vectorDotWithField(rank, w_i, field_i, w_j, field_j) * value_i * value_j`.
   *
   * Ids missing from `VMap` use an array of `rank * fieldNum` entries filled with 0.01.
   */
  def calcPhi(fieldNum: Int,
              rank: Int,
              oneData: (Array[Int], Array[Long], Array[Float]),
              VMap: mutable.Map[Long, Array[Float]]): Float = {
    val (field, id, value) = oneData
    val size = field.length
    val defaultWeights = Array.fill(rank * fieldNum)(0.01f)
    var resPhi = 0.0F
    for (i <- 0 until size) {
      val w1 = VMap.getOrElse(id(i), defaultWeights)
      for (j <- i + 1 until size) {
        val w2 = VMap.getOrElse(id(j), defaultWeights)
        resPhi = resPhi + vectorDotWithField(rank, w1, field(i), w2, field(j)) * value(i) * value(j)
      }
    }
    resPhi
  }

  /**
   * Computes the logistic loss `log(1 + exp(-label * phi))` and, for every pair of entries,
   * the gradient blocks produced by [[calcGrad]].
   *
   * The weight arrays in `weightMap` are only read. Each returned array starts as a copy of
   * the id's weight array (or of a 0.01-filled array of `rank * fieldNum + fieldNum` entries
   * when the id is absent); the block for the partner's field is then overwritten with the
   * gradient.
   * NOTE(review): slots that no pair touches therefore still hold weight values, not zeros -
   * confirm that the consumer of this map expects that.
   *
   * @return (array per feature id, loss)
   */
  def calcGradAndLoss(fieldNum: Int,
                      rank: Int,
                      label: Double,
                      oneData: (Array[Int], Array[Long], Array[Float]),
                      weightMap: mutable.Map[Long, Array[Float]]): (mutable.Map[Long, Array[Float]], Double) = {
    val pred = calcPhi(fieldNum, rank, oneData, weightMap)
    val margin = label * pred
    // Stable log(1 + exp(-margin)): never exponentiates a positive number.
    val loss =
      if (margin < 0) -margin + math.log1p(math.exp(margin))
      else math.log1p(math.exp(-margin))
    // d/dphi log(1 + exp(-label * phi)) = -label / (1 + exp(label * phi)).
    val k = -label / (1.0 + math.exp(margin))

    val (field, id, value) = oneData
    val size = field.length
    val gradMap = new mutable.HashMap[Long, Array[Float]]
    val defaultWeights = Array.fill(rank * fieldNum + fieldNum)(0.01f)

    // Read-only view of the current weights of a feature.
    def weightsOf(featureId: Long): Array[Float] =
      weightMap.getOrElse(featureId, defaultWeights)

    // Private output array of a feature, created on first use so inputs are never mutated.
    def outputFor(featureId: Long): Array[Float] =
      gradMap.getOrElseUpdate(featureId, weightsOf(featureId).clone())

    for (i <- 0 until size; j <- i + 1 until size) {
      val (f1, f2) = (field(i), field(j))
      val grad = calcGrad(fieldNum, rank, weightsOf(id(i)), f1, value(i), weightsOf(id(j)), f2, value(j), k)
      val out1 = outputFor(id(i))
      val out2 = outputFor(id(j))
      val start1 = f2 * rank
      val start2 = f1 * rank
      for (index <- 0 until rank) {
        out1(start1 + index) = grad._1(index)
        out2(start2 + index) = grad._2(index)
      }
    }
    (gradMap, loss)
  }

  /**
   * Gradient of one pair: for each latent index `i`,
   * `lamda * w1(block w2Field)(i) + k * w2(block w1Field)(i) * x1 * x2` and the symmetric term.
   *
   * @param k derivative of the loss with respect to the interaction score
   * @return (gradient for w1's `w2Field` block, gradient for w2's `w1Field` block), each of length `rank`
   */
  def calcGrad(fieldNum: Int,
               rank: Int,
               w1: Array[Float],
               w1Field: Int,
               x1: Float,
               w2: Array[Float],
               w2Field: Int,
               x2: Float,
               k: Double): (Array[Float], Array[Float]) = {
    val w1f2Grad = new Array[Float](rank)
    val w2f1Grad = new Array[Float](rank)
    val v1f2Start = rank * w2Field
    val v2f1Start = rank * w1Field
    val kappa = k.toFloat
    for (i <- 0 until rank) {
      w1f2Grad(i) = lamda * w1(v1f2Start + i) + kappa * w2(v2f1Start + i) * x1 * x2
      w2f1Grad(i) = lamda * w2(v2f1Start + i) + kappa * w1(v1f2Start + i) * x1 * x2
    }
    (w1f2Grad, w2f1Grad)
  }

  /**
   * Applies a pair of gradient blocks to the arrays stored in `targetGradMap` for `id1` and
   * `id2`, scaling each by `-1 / sqrt(1 + acc)` and storing `acc + squared norm` back at index
   * `rank * fieldNum - 1 + field`.
   *
   * NOTE(review): logic kept exactly as found. The fallback array has only `rank` entries and
   * is shared by both ids, and the accumulators are read from the `grad` arrays at index
   * `rank * fieldNum - 1 + field`, which does not match the `rank`-length arrays returned by
   * `calcGrad` - the intended layout is not visible here, so verify before using this method.
   */
  def updateGradMap(fieldNum: Int,
                    rank: Int,
                    targetGradMap: mutable.Map[Long, Array[Float]],
                    grad: (Array[Float], Array[Float]),
                    field1: Int,
                    field2: Int,
                    id1: Long,
                    id2: Long): Unit = {
    val zeros = Array.fill(rank)(0.0f)
    val gradId1 = targetGradMap.getOrElse(id1, zeros)
    val gradId2 = targetGradMap.getOrElse(id2, zeros)
    val (w1f2Grad, w2f1Grad) = (grad._1, grad._2)
    val (start1, start2) = (field2 * rank, field1 * rank)
    val (lastGradAcc1, lastGradAcc2) = (w1f2Grad(rank * fieldNum - 1 + field1), w2f1Grad(rank * fieldNum - 1 + field2))
    val lr1 = -1 / math.sqrt(1 + lastGradAcc1).toFloat
    val lr2 = -1 / math.sqrt(1 + lastGradAcc2).toFloat
    var grad2norm1 = 0.0F
    var grad2norm2 = 0.0F
    for (i <- 0 until rank) {
      gradId1(start1 + i) = gradId1(start1 + i) + lr1 * w1f2Grad(i)
      gradId2(start2 + i) = gradId2(start2 + i) + lr2 * w2f1Grad(i)
      grad2norm1 += math.pow(w1f2Grad(i), 2).toFloat
      grad2norm2 += math.pow(w2f1Grad(i), 2).toFloat
    }
    gradId1(rank * fieldNum - 1 + field1) = grad2norm1 + lastGradAcc1
    gradId2(rank * fieldNum - 1 + field2) = grad2norm2 + lastGradAcc2
    targetGradMap.put(id1, gradId1)
    targetGradMap.put(id2, gradId2)
  }

  /** Sum of squares of the `field` block (`rank` entries starting at `rank * field`) of `vector`. */
  def calc2norm(fieldNum: Int, rank: Int, vector: Array[Float], field: Int): Float = {
    var res = 0.0F
    val start = rank * field
    for (i <- 0 until rank) {
      res += math.pow(vector(start + i), 2).toFloat
    }
    res
  }
}
161 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/optimization/FM.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.optimization
2 |
3 | import net.qihoo.xitong.xdml.linalg.BLAS._
4 |
5 |
/**
 * Factorization machine helpers: second-order score, logistic loss and gradients for one
 * sparse sample given as parallel (feature ids, feature values) arrays.
 */
object FM extends Optimizer {

  /**
   * Computes the gradients and the loss of one sample.
   *
   * The returned map holds arrays of length `rank + 1`:
   *  - key `0L`: index 0 is the gradient of the bias term, the rest is zero;
   *  - any other feature id: index 0 is the gradient of the linear weight and indices
   *    `1..rank` are the gradients of the latent factors.
   * A feature whose id is `0L` is skipped because that key is used for the bias.
   *
   * @param label             sample label (assumes -1/+1 because the loss uses label * score - TODO confirm)
   * @param feature           (feature ids, feature values), index-aligned
   * @param W0                bias
   * @param W                 linear weights keyed by feature id
   * @param V                 latent factors keyed by feature id; missing ids count as zero vectors
   * @param rank              latent dimension
   * @param subsampling_rate  sampling rate of the subsampled class
   * @param subsampling_label label value of the subsampled class
   * @return (gradient map, loss)
   */
  def train(label: Double,
            feature: (Array[Long], Array[Float]),
            W0: Float,
            W: Map[Long, Float],
            V: Map[Long, Array[Float]],
            rank: Int,
            subsampling_rate: Double = 1.0,
            subsampling_label: Double = 0.0): (Map[Long, Array[Float]], Double) = {
    // pred is the raw score: start with the bias and the first-order term.
    var pred = W0 + dot(feature, W)
    val zeroFactors = Array.fill(rank)(0.0f)
    val sumArray = Array.fill(rank)(0.0f)
    // res1 = sum_i v_{i,f} * x_i
    // res2 = sum_i (v_{i,f} * x_i)^2
    (0 until rank).foreach { f =>
      var (res1, res2) = (0f, 0f)
      feature._1.indices.foreach { id =>
        val fId = feature._1(id)
        val fValue = feature._2(id)
        val tmp = fValue * V.getOrElse(fId, zeroFactors)(f)
        res1 += tmp
        res2 += tmp * tmp
      }
      sumArray(f) = res1
      pred += 0.5f * (res1 * res1 - res2)
    }

    val z = label * pred
    // Logistic loss log(1 + exp(-z)), replaced by its asymptotes outside [-18, 18].
    val loss =
      if (z > 18) math.exp(-z)
      else if (z < -18) -z
      else math.log1p(math.exp(-z))

    val p = deltaGrad(label, pred).toFloat
    // Re-weight samples of the subsampled class by the inverse sampling rate.
    val w = if (math.abs(label - subsampling_label) < 0.001) (1 / subsampling_rate).toFloat else 1.0f

    var gradLossMap: Map[Long, Array[Float]] = Map()
    val biasGrad = Array.fill(rank + 1)(0.0f)
    biasGrad(0) = p * w
    gradLossMap += (0L -> biasGrad)
    feature._1.indices.foreach { id =>
      val fId = feature._1(id)
      val fValue = feature._2(id)
      if (fId != 0L) {
        val wArray = Array.fill(rank + 1)(0.0f)
        wArray(0) = fValue * p * w
        (0 until rank).foreach { f =>
          val vGrad = fValue * sumArray(f) - V.getOrElse(fId, zeroFactors)(f) * fValue * fValue
          wArray(f + 1) = vGrad * p * w
        }
        gradLossMap += (fId -> wArray)
      }
    }
    (gradLossMap, loss)
  }

  /**
   * Returns the sigmoid of the factorization-machine score (bias + linear term +
   * second-order interactions) of one sample.
   */
  def predict(feature: (Array[Long], Array[Float]),
              W0: Float,
              W: Map[Long, Float],
              V: Map[Long, Array[Float]],
              rank: Int): Double = {
    var pred = W0 + dot(feature, W)
    val zeroFactors = Array.fill(rank)(0.0F)
    for (f <- 0 until rank) {
      var (res1, res2) = (0f, 0f)
      feature._1.indices.foreach { id =>
        val tmp = feature._2(id) * V.getOrElse(feature._1(id), zeroFactors)(f)
        res1 += tmp
        res2 += tmp * tmp
      }
      pred += 0.5f * (res1 * res1 - res2)
    }
    sigmoid(pred)
  }

  /**
   * Derivative of log(1 + exp(-label * pred)) with respect to `pred`:
   * `-label * exp(-label * pred) / (1 + exp(-label * pred)) = -label / (1 + exp(label * pred))`,
   * replaced by its asymptotes when `|label * pred| > 18`.
   */
  def deltaGrad(label: Double,
                pred: Float): Double = {
    val sigma = label * pred
    if (sigma > 18)
      -label * math.exp(-sigma)
    else if (sigma < -18)
      -label
    else
      -label / (1 + math.exp(sigma)) // exact exp instead of a power of a truncated e constant
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/optimization/Optimizer.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.optimization
2 |
/** Marker trait with no members; mixed in by the optimizer objects (e.g. `BinaryClassify`, `FM`, `FFM`). */
trait Optimizer
6 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/PullTask.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.task
2 |
3 | import java.util.concurrent.Callable
4 |
/**
 * Callable that reads a set of keys from the Hazelcast map named `tableName`.
 *
 * `hazelcastIns` and `tableName` are inherited from [[Task]].
 */
class PullTask[K, V] extends Task with Callable[java.util.Map[Long, Array[Byte]]] {

  /** Keys to fetch; must be assigned before `call()` is invoked. */
  var pullSet: java.util.Set[Long] = _

  /** Stores the keys that `call()` will fetch. */
  def setPullSet(pullSet: java.util.Set[Long]): Unit = {
    this.pullSet = pullSet
  }

  /** Returns the raw byte values stored for every key in `pullSet`. */
  override def call(): java.util.Map[Long, Array[Byte]] =
    hazelcastIns.getMap[Long, Array[Byte]](tableName).getAll(pullSet)
}
17 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/PushTask.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.task
2 |
3 | import java.nio.ByteBuffer
4 | import java.util.concurrent.Callable
5 |
6 | import net.qihoo.xitong.xdml.conf.PSDataType
7 | import net.qihoo.xitong.xdml.conf.PSDataType.PSDataType
8 | import net.qihoo.xitong.xdml.updater.Updater
9 | import net.qihoo.xitong.xdml.utils.XDMLException
10 |
11 | import scala.collection.JavaConversions._
12 |
/**
 * Callable that writes values into the Hazelcast map named `tableName`, optionally combining
 * them with the stored values through an [[Updater]] first.
 *
 * Values are stored as byte arrays; `vClazz` and `psDatalength` describe how one value is
 * encoded.
 */
class PushTask[K, V] extends Task with Callable[Unit] {
  // PSDataType of V
  private var vClazz: PSDataType = _
  private var pushMap: java.util.Map[Long, V] = _
  private var psDatalength: Int = _
  private var updater: Updater[Long, V] = _
  private var hasUpdater = false

  /** Sets the number of elements in one value (the array length for array types). */
  def setDataLength(length: Int): Unit = {
    this.psDatalength = length
  }

  /** Sets the data type used to encode and decode values. */
  def setVClazz(dataType: PSDataType): Unit = {
    this.vClazz = dataType
  }

  /** Installs the updater applied in `call()` and marks it as present. */
  def setUpdater(updater: Updater[Long, V]): Unit = {
    this.updater = updater
    this.hasUpdater = true
  }

  /** Sets the values to push, keyed by parameter id. */
  def setPushMap(pushMap: java.util.Map[Long, V]): Unit = {
    this.pushMap = pushMap
  }

  /**
   * Decodes one stored value according to `vClazz`.
   *
   * Array types read `psDatalength` elements at absolute offsets (4 bytes per float,
   * 8 bytes per double).
   *
   * @throws XDMLException if `vClazz` is not a supported type
   */
  def getValueFromBytes(bytes: Array[Byte]): V = {
    val valueBuff = ByteBuffer.wrap(bytes)
    vClazz match {
      case PSDataType.INT => valueBuff.getInt().asInstanceOf[V]
      case PSDataType.LONG => valueBuff.getLong().asInstanceOf[V]
      case PSDataType.FLOAT => valueBuff.getFloat().asInstanceOf[V]
      case PSDataType.DOUBLE => valueBuff.getDouble().asInstanceOf[V]
      case PSDataType.FLOAT_ARRAY =>
        val arr = new Array[Float](psDatalength)
        for (index <- arr.indices) {
          arr(index) = valueBuff.getFloat(index << 2)
        }
        arr.asInstanceOf[V]
      case PSDataType.DOUBLE_ARRAY =>
        val arr = new Array[Double](psDatalength)
        for (index <- arr.indices) {
          arr(index) = valueBuff.getDouble(index << 3)
        }
        arr.asInstanceOf[V]
      case _ => throw new XDMLException("data type error!")
    }
  }

  /**
   * Encodes one value into a buffer of `PSDataType.sizeOf(vClazz) * psDatalength` bytes.
   *
   * @throws IllegalArgumentException if `vClazz` is not a supported type
   */
  def getBytesFromValue(value: V): Array[Byte] = {
    val byteSize = PSDataType.sizeOf(vClazz) * psDatalength
    val byteBuff = ByteBuffer.allocate(byteSize)
    vClazz match {
      case PSDataType.INT =>
        byteBuff.putInt(value.asInstanceOf[Int])
      case PSDataType.LONG =>
        byteBuff.putLong(value.asInstanceOf[Long])
      case PSDataType.FLOAT =>
        byteBuff.putFloat(value.asInstanceOf[Float])
      case PSDataType.DOUBLE =>
        byteBuff.putDouble(value.asInstanceOf[Double])
      case PSDataType.FLOAT_ARRAY =>
        val arr = value.asInstanceOf[Array[Float]]
        // foreach: the writes are pure side effects, no result collection is needed.
        arr.indices.foreach { index =>
          byteBuff.putFloat(index << 2, arr(index))
        }
      case PSDataType.DOUBLE_ARRAY =>
        val arr = value.asInstanceOf[Array[Double]]
        arr.indices.foreach { index =>
          byteBuff.putDouble(index << 3, arr(index))
        }
      case _ => throw new IllegalArgumentException("data type error!")
    }
    byteBuff.array()
  }

  /**
   * Reads the stored values for the pushed keys, lets the updater (if any) turn them and the
   * pushed values into the values to store, then writes the encoded result back.
   */
  override def call(): Unit = {
    val weightMap = hazelcastIns.getMap[Long, Array[Byte]](tableName)
    val localMap = weightMap.getAll(pushMap.keySet).map {
      case (k, bytes) => (k, getValueFromBytes(bytes))
    }
    if (hasUpdater)
      pushMap = updater.update(localMap, pushMap)
    val bytesMap = pushMap.map {
      case (k, v) => (k, getBytesFromValue(v))
    }
    weightMap.putAll(bytesMap)
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/Task.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.task
2 |
3 | import java.io.Serializable
4 |
5 | import com.hazelcast.core.{HazelcastInstance, HazelcastInstanceAware}
6 |
/**
 * Base class of the pull and push tasks: a serializable object that receives a Hazelcast
 * instance through `HazelcastInstanceAware` and carries the name of the map to work on.
 */
abstract class Task extends Serializable with HazelcastInstanceAware {

  /** Hazelcast instance supplied via `setHazelcastInstance`; not serialized. */
  @transient
  var hazelcastIns: HazelcastInstance = _

  /** Name of the map the pull or push task operates on. */
  protected var tableName: String = _

  /** Sets the name of the map to operate on. */
  def setTableName(name: String): Unit = {
    tableName = name
  }

  /** Stores the Hazelcast instance handed in through `HazelcastInstanceAware`. */
  override def setHazelcastInstance(hz: HazelcastInstance): Unit = {
    hazelcastIns = hz
  }

}
23 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/DCASGDUpdater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 | import java.util
3 | import scala.collection.JavaConversions._
4 |
/**
 * Updater whose step is `-learningRate * g - coff * g * g * (v(0) - v(1))`, where `g` is
 * element 0 of the pushed gradient array.
 *
 * Each stored array is updated in place: element 1 receives the previous value of element 0
 * and element 0 receives that previous value plus the step.
 */
class DCASGDUpdater extends Updater[Long, Array[Float]] {
  private var coff = 0.1F
  private var learningRate = 0.01F

  /**
   * Applies the update to every entry of `originalWeightMap`.
   *
   * @param originalWeightMap stored arrays keyed by parameter id (mutated in place)
   * @param gradientMap       pushed arrays; only element 0 is read
   * @return the entries of `originalWeightMap` after the update
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (key, state) =>
      val g = gradientMap(key)(0)
      val step = -learningRate * g - coff * g * g * (state(0) - state(1))
      val current = state(0)
      state(1) = current
      state(0) = current + step
      (key, state)
    }
    updateMap
  }

  /** Sets the learning rate and returns this updater. */
  def setLearningRate(lr: Float): this.type = {
    learningRate = lr
    this
  }

  /** Sets the coefficient of the squared-gradient term and returns this updater. */
  def setCoff(coff: Float): this.type = {
    this.coff = coff
    this
  }
}
31 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/FFMUpdater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 | import java.util
3 | import scala.collection.JavaConversions._
4 |
/**
 * Plain gradient-descent updater for array-valued parameters:
 * `v(i) = v(i) - learningRate * gradient(i)`.
 */
class FFMUpdater extends Updater[Long, Array[Float]] {

  private var learningRate = 0.01F

  /**
   * Applies one gradient step, in place, to every entry of `originalWeightMap`.
   *
   * The number of elements updated per entry is the length of the first array in
   * `gradientMap`. An empty `gradientMap` leaves all weights untouched instead of failing.
   *
   * @param originalWeightMap stored arrays keyed by parameter id (mutated in place)
   * @param gradientMap       gradient arrays keyed by parameter id
   * @return the entries of `originalWeightMap` after the update
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    // headOption: an empty push must not throw; with size 0 no gradient is looked up at all.
    val size = gradientMap.headOption.fold(0)(_._2.length)
    val updateMap = originalWeightMap.map { case (k, v) =>
      if (size > 0) {
        // Look the gradient array up once per key rather than once per element.
        val gradient = gradientMap(k)
        for (index <- 0 until size) {
          v(index) = v(index) - learningRate * gradient(index)
        }
      }
      (k, v)
    }
    updateMap
  }

  /** Sets the learning rate and returns this updater. */
  def setLearningRate(lr: Float): this.type = {
    this.learningRate = lr
    this
  }
}
26 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/LRFTRLUpdater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 |
3 | import java.util
4 | import scala.collection.JavaConversions._
5 |
/**
 * Updater working on three-element state arrays, with `g` being element 0 of the pushed
 * gradient array:
 *  - element 0 accumulates `g - sigma * element2`, where
 *    `sigma = (sqrt(element1 + g * g) - sqrt(element1)) / alpha`;
 *  - element 1 accumulates `g * g`;
 *  - element 2 is recomputed from elements 0 and 1, and is 0 while `|element0| <= lambda1`.
 *
 * NOTE(review): this looks like the z / n / w state of FTRL-Proximal - confirm against the
 * model code that creates these arrays.
 */
class LRFTRLUpdater extends Updater[Long, Array[Float]] {
  private var alpha = 1F
  private var beta = 1F
  private var lambda1 = 1F
  private var lambda2 = 1F

  /**
   * Applies the update, in place, to every entry of `originalWeightMap`.
   *
   * @param originalWeightMap stored state arrays keyed by parameter id (mutated in place)
   * @param gradientMap       pushed arrays; only element 0 is read
   * @return the entries of `originalWeightMap` after the update
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (key, state) =>
      val g = gradientMap(key)(0)
      val gSquared = g * g
      val sigma = 1f / alpha * (Math.sqrt(state(1) + gSquared) - Math.sqrt(state(1)))
      state(0) = (state(0) + g - sigma * state(2)).toFloat
      state(1) = state(1) + gSquared
      state(2) =
        if (Math.abs(state(0)) > lambda1)
          ((-1) * (1.0 / (lambda2 + (beta + Math.sqrt(state(1))) / alpha)) * (state(0) - Math.signum(state(0)).toInt * lambda1)).toFloat
        else
          0F
      (key, state)
    }
    updateMap
  }

  /** Sets `alpha` and returns this updater. */
  def setAlpha(alpha: Float): this.type = {
    this.alpha = alpha
    this
  }

  /** Sets `beta` and returns this updater. */
  def setBeta(beta: Float): this.type = {
    this.beta = beta
    this
  }

  /** Sets `lambda1`, the threshold below which element 2 is zeroed, and returns this updater. */
  def setLambda1(lambda1: Float): this.type = {
    this.lambda1 = lambda1
    this
  }

  /** Sets `lambda2` and returns this updater. */
  def setLambda2(lambda2: Float): this.type = {
    this.lambda2 = lambda2
    this
  }
}
50 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/LRUpdater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 |
3 | import scala.collection.JavaConversions._
4 |
/** Plain gradient-descent updater for scalar weights: `w + (-learningRate * gradient)`. */
class LRUpdater extends Updater[Long, Float] {

  private var learningRate: Float = 0.01F

  /**
   * Builds the new weight for every key of `lastWeightMap`.
   *
   * @param lastWeightMap stored weights keyed by parameter id
   * @param gradientMap   gradients keyed by parameter id
   * @return the updated weights
   */
  override def update(lastWeightMap: java.util.Map[Long, Float], gradientMap: java.util.Map[Long, Float]): java.util.Map[Long, Float] = {
    val updateMap = lastWeightMap.map { case (key, weight) =>
      val step = -learningRate * gradientMap(key)
      key -> (weight + step)
    }
    updateMap
  }

  /** Sets the learning rate and returns this updater. */
  def setLearningRate(lr: Float): this.type = {
    learningRate = lr
    this
  }
}
22 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/MomLRUpdater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 |
3 | import java.util
4 | import scala.collection.JavaConversions._
5 |
/**
 * Momentum updater working on two-element state arrays: element 0 is moved by
 * `step = -learningRate * g + coff * element1` and element 1 then stores that step, where `g`
 * is element 0 of the pushed gradient array.
 */
class MomLRUpdater extends Updater[Long, Array[Float]] {
  private var coff = 0.1F
  private var learningRate = 0.01F

  /**
   * Applies the update, in place, to every entry of `originalWeightMap`.
   *
   * @param originalWeightMap stored state arrays keyed by parameter id (mutated in place)
   * @param gradientMap       pushed arrays; only element 0 is read
   * @return the entries of `originalWeightMap` after the update
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (key, state) =>
      val step = -learningRate * gradientMap(key)(0) + coff * state(1)
      state(0) = state(0) + step
      state(1) = step
      (key, state)
    }
    updateMap
  }

  /** Sets the learning rate and returns this updater. */
  def setLearningRate(lr: Float): this.type = {
    learningRate = lr
    this
  }

  /** Sets the coefficient applied to the previous step and returns this updater. */
  def setCoff(coff: Float): this.type = {
    this.coff = coff
    this
  }
}
32 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/Updater.scala:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.updater
2 |
/**
 * Serializable strategy that combines stored values with pushed values.
 *
 * @tparam K key type
 * @tparam V value type
 */
trait Updater[K, V] extends Serializable {

  /**
   * @param originalWeightMap values currently stored, keyed by parameter id
   * @param gradientMap       values being pushed, keyed by parameter id
   * @return the values to store
   */
  def update(originalWeightMap: java.util.Map[K, V],
             gradientMap: java.util.Map[K, V]): java.util.Map[K, V]
}
7 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/utils/ExitCodeResolver.java:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.utils;
2 |
/**
 * Validates process exit codes: -1 through -7 are recognised, anything else is rejected.
 */
public class ExitCodeResolver {

    /**
     * Accepts the exit codes -1 to -7 without doing anything and rejects every other value.
     *
     * @param exitCode the exit code to inspect
     * @throws IllegalArgumentException if {@code exitCode} is not one of -1 .. -7
     */
    static void analysis(int exitCode) {
        switch (exitCode) {
            case -1:
            case -2:
            case -3:
            case -4:
            case -5:
            case -6:
            case -7:
                // Recognised code: no handling is implemented for it.
                return;
            default:
                throw new IllegalArgumentException("Unknown exit code.");
        }
    }

}
34 |
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/utils/XDMLException.java:
--------------------------------------------------------------------------------
1 | package net.qihoo.xitong.xdml.utils;
2 |
/**
 * Unchecked exception type of the project; mirrors the four standard
 * {@link RuntimeException} constructors.
 */
public class XDMLException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    /** Creates an exception without message or cause. */
    public XDMLException() {
        super();
    }

    /**
     * @param message detail message
     */
    public XDMLException(String message) {
        super(message);
    }

    /**
     * @param message detail message
     * @param cause   underlying cause
     */
    public XDMLException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * @param cause underlying cause
     */
    public XDMLException(Throwable cause) {
        super(cause);
    }
}
21 |
--------------------------------------------------------------------------------
/src/test/scala/DataColFilter.scala:
--------------------------------------------------------------------------------
1 | import scala.collection.mutable
2 | import scala.io.Source
3 | import scala.util.control.Breaks._
4 |
5 | /**
6 | * @author wangxingda
7 | * 2018/08/16
8 | */
9 |
/**
 * Column filter: collects feature names from the `[col]` section of a config file and drops
 * the matching features from a sample.
 */
class DataColFilter extends DataFilter {
  /** Names of the features to remove. */
  val featureFilterSet = new mutable.HashSet[String]

  /**
   * Reads the `[col]` section of the config file.
   *
   * Blank lines are ignored. A line containing `[col]` starts the section and a line
   * containing `[row]` ends it; inside the section every line is parsed as `key = value`
   * and keys whose value is 1 are added to `featureFilterSet`. The file is closed even when
   * parsing fails.
   *
   * @param filePath path of the config file
   */
  def readFilterConfig(filePath: String): Unit = {
    val file = Source.fromFile(filePath)
    try {
      var isStarted: Boolean = false
      for (line <- file.getLines() if !line.trim.equals("")) {
        if (line.contains("[col]")) {
          isStarted = true
        } else if (isStarted) {
          val entry = line.trim
          if (entry.contains("[row]")) {
            isStarted = false
          } else {
            val kv = getKeyAndValue(entry)
            if (kv._2 == 1)
              featureFilterSet.add(kv._1)
          }
        }
      }
    } finally {
      file.close()
    }
    println("colFilter:" + featureFilterSet.mkString(","))
  }

  /**
   * Keeps the features whose name (the part before the first `splitC`) is not in
   * `featureFilterSet`.
   */
  def colFilter(feature: Array[String]): Array[String] = {
    feature.filterNot(x => featureFilterSet.contains(x.split(splitC)(0)))
  }

  /**
   * Parses `key = value`.
   *
   * A value that is not an integer is reported on stdout and replaced by -1.
   *
   * @throws IllegalArgumentException if `str` is empty or does not split into exactly two
   *                                  parts around `=`
   */
  def getKeyAndValue(str: String): (String, Long) = {
    if (str.equals(""))
      throw new IllegalArgumentException
    val splits = str.split("=")
    if (splits.length != 2) {
      throw new IllegalArgumentException("Col filter getKeyAndValue Error: cause by splits's length is not equals 2: " + str)
    }
    val value: Int =
      try {
        splits(1).trim.toInt
      } catch {
        case ex: NumberFormatException =>
          println("Col filter getKeyAndValue Error: cause by value is not a number value: " + str)
          ex.printStackTrace()
          -1
      }
    (splits(0).trim, value.toLong)
  }
}
55 |
--------------------------------------------------------------------------------
/src/test/scala/DataFilter.scala:
--------------------------------------------------------------------------------
1 | import java.io.Serializable
2 |
3 | /**
4 | * @author wangxingda
5 | * 2018/08/16
6 | */
7 |
/**
 * Base trait of the row and column filters: shared separator constants plus the
 * config-loading hook.
 *
 * The control-character separators are written as Unicode escapes because octal escapes are
 * deprecated and rejected by newer Scala compilers; the string values are unchanged.
 */
trait DataFilter extends Serializable {
  /** Separator: control character with code 3. */
  val splitC: String = "\u0003"
  /** Separator: control character with code 2. */
  val splitB: String = "\u0002"
  /** Separator: control character with code 1. */
  val splitA: String = "\u0001"
  /** Separator: tab. */
  val splitT: String = "\t"

  /** Reads the filter configuration (row and column sections) from `filePath`. */
  def readFilterConfig(filePath: String): Unit
}
16 |
--------------------------------------------------------------------------------
/src/test/scala/TestFilter.scala:
--------------------------------------------------------------------------------
1 | import net.qihoo.xitong.xdml.utils.CityHash
2 | import org.apache.spark.{SparkConf, SparkContext}
3 |
4 | import scala.collection.mutable.ArrayBuffer
5 |
/**
 * Manual benchmark: applies the row and column filters to a text data set, parses the kept
 * lines and prints how long that took and how many samples were produced.
 */
object TestFilter {
  val rowFilterConfigFile = "ftrl_filter_demo.txt"
  val colFilterConfigFile = "ftrl_filter_demo.txt"
  val splitA = "\u0001"
  val splitB = "\u0002"
  val splitC = "\u0003"


  /**
   * Loads both filter configs, reads the data set named by the Spark setting
   * `spark.ndml.data.path`, and runs the filter-and-parse job.
   */
  def main(args: Array[String]): Unit = {
    // Load the filters, then configure Spark.
    val rowF = new DataRowFilter()
    rowF.readFilterConfig(rowFilterConfigFile)
    val colF = new DataColFilter()
    colF.readFilterConfig(colFilterConfigFile)
    val sc = new SparkContext(new SparkConf().setAppName("TestFilter"))
    val trainInputPath = sc.getConf.get("spark.ndml.data.path", "").trim
    if (trainInputPath == "") {
      println(s"\'spark.ndml.data.path\' must be set of input data.")
      System.exit(-1)
    }
    val rawData = sc.textFile(trainInputPath).repartition(1000)

    // New approach: a single pass over each partition, collecting results in an ArrayBuffer.
    val s = System.nanoTime()
    val number = rawData.mapPartitions(iter => {
      val res = new ArrayBuffer[(Double, Array[Long], Array[Float])]
      while (iter.hasNext) {
        val line = iter.next()
        if (rowF.rowFilter(line)) {
          val splits = line.split(splitA)
          val list = splits.slice(1, splits.length)
          val featureList = colF.colFilter(list)
          // Split every feature once and reuse the parts for both the id and the value.
          val parts = featureList.map(x => x.split(splitB))
          res.append((splits(0).toDouble, parts.map(p => CityHash.stringCityHash64(p(0))), parts.map(p => p(1).toFloat)))
        }
      }
      res.iterator
    }).count() // mapPartitions is lazy: without an action nothing runs and nothing is timed
    val e = System.nanoTime()

    // Traditional approach (filter, then map), kept for comparison:
    // val s = System.nanoTime()
    // val number = rawData.filter(line => rowF.rowFilter(line)).map(line => {
    //   val splits = line.split(splitA)
    //   val list = splits.slice(1, splits.length)
    //   val featureList = colF.colFilter(list)
    //   (splits(0).toDouble, featureList.map(x=>CityHash.stringCityHash64(x.split(splitB)(0))), featureList.map(x=>x.split(splitB)(1).toFloat))
    // }).count()
    // val e = System.nanoTime()

    println("time:" + (e - s) / 1e9)
    println("number:" + number)
    sc.stop()
  }
}
60 |
--------------------------------------------------------------------------------