├── README.md ├── doc ├── BinaryClassify.md ├── CityHash.md ├── Example.md ├── FFMProcessor.md ├── FeatureAnalysis.md ├── FeatureProcess.md ├── FieldawareFactorizationMachine.md ├── JobConfigure.md ├── LibSVMProcessor.md ├── LogisticRegression.md ├── LogisticRegressionWithDCASGD.md ├── LogisticRegressionWithFTRL.md ├── LogisticRegressionWithMomentum.md ├── Model.md ├── PS.md ├── PSClient.md ├── PSConfiguration.md ├── PSMapStore.md ├── Updater.md ├── faq_cn.md └── img │ ├── logo.jpg │ ├── qq.jpg │ └── xdml.png ├── pom.xml └── src ├── main └── scala │ ├── net │ └── qihoo │ │ └── xitong │ │ └── xdml │ │ ├── conf │ │ ├── JobConfiguration.scala │ │ ├── JobType.scala │ │ ├── PSConfiguration.scala │ │ └── PSDataType.scala │ │ ├── dataProcess │ │ ├── FFMProcessor.scala │ │ └── LibSVMProcessor.scala │ │ ├── example │ │ ├── analysis │ │ │ ├── feature │ │ │ │ ├── analysis │ │ │ │ │ ├── runUniversalAnalyzerDense.scala │ │ │ │ │ ├── runUniversalAnalyzerDenseGrouped.scala │ │ │ │ │ ├── runUniversalAnalyzerDenseKS.scala │ │ │ │ │ └── runUniversalAnalyzerSparse.scala │ │ │ │ └── process │ │ │ │ │ ├── runCategoryEncoder.scala │ │ │ │ │ ├── runFeatureProcess.scala │ │ │ │ │ ├── runMultiCategoryEncoder.scala │ │ │ │ │ ├── runNumericBucketer.scala │ │ │ │ │ └── runNumericStandardizer.scala │ │ │ └── model │ │ │ │ ├── runFromDenseDataToXDMLH2ODRF.scala │ │ │ │ ├── runFromDenseDataToXDMLH2OGBM.scala │ │ │ │ ├── runFromDenseDataToXDMLH2OGLM.scala │ │ │ │ ├── runFromDenseDataToXDMLH2OMLP.scala │ │ │ │ ├── runFromLibSVMDataToXDMLLinearScopeModel.scala │ │ │ │ ├── runFromLibSVMDataToXDMLOVR.scala │ │ │ │ └── runFromLibSVMDataToXDMLSR.scala │ │ └── ml │ │ │ ├── DCASGDTest.scala │ │ │ ├── FFMTest.scala │ │ │ ├── FTRLTest.scala │ │ │ ├── LRTest.scala │ │ │ └── MomentumTest.scala │ │ ├── feature │ │ ├── analysis │ │ │ ├── KolmogorovSmirnovTest.java │ │ │ └── UniversalAnalyzer.scala │ │ └── process │ │ │ ├── CategoryEncoder.scala │ │ │ ├── ColumnEliminator.scala │ │ │ ├── ColumnRenamer.scala │ │ 
│ ├── FeatureProcessor.scala │ │ │ ├── MultiCategoryEncoder.scala │ │ │ ├── NumericBucketer.scala │ │ │ ├── NumericStandardizer.scala │ │ │ └── PipelinePatch.scala │ │ ├── linalg │ │ └── BLAS.scala │ │ ├── mapstore │ │ ├── KuduTableOp.scala │ │ └── PSMapStore.scala │ │ ├── ml │ │ ├── FieldawareFactorizationMachine.scala │ │ ├── LogisticRegression.scala │ │ ├── LogisticRegressionWithDCASGD.scala │ │ ├── LogisticRegressionWithFTRL.scala │ │ └── LogisticRegressionWithMomentum.scala │ │ ├── model │ │ ├── data │ │ │ ├── DataHandler.scala │ │ │ ├── DataProcessor.scala │ │ │ ├── LogHandler.scala │ │ │ └── SchemaHandler.scala │ │ ├── linalg │ │ │ └── BLAS.scala │ │ ├── loss │ │ │ ├── HingeLossFunc.scala │ │ │ ├── L2LossFunc.scala │ │ │ ├── LogitLossFunc.scala │ │ │ ├── LossFunc.scala │ │ │ ├── MultiHingeLossFunc.scala │ │ │ ├── MultiLogitLossFunc.scala │ │ │ ├── MultiLossFunc.scala │ │ │ ├── MultiSmoothHingeLossFunc.scala │ │ │ ├── PoissonLossFunc.scala │ │ │ ├── SmoothHingeLossFunc.scala │ │ │ ├── UPULogitLossFunc.scala │ │ │ ├── WeightedHingeLossFunc.scala │ │ │ ├── WeightedLogitLossFunc.scala │ │ │ └── WeightedSmoothHingeLossFunc.scala │ │ ├── supervised │ │ │ ├── GLM │ │ │ │ ├── LinearScope.scala │ │ │ │ ├── MultiLinearScope.scala │ │ │ │ └── OVRLinearScope.scala │ │ │ └── H2O │ │ │ │ ├── H2ODRF.scala │ │ │ │ ├── H2OGBM.scala │ │ │ │ ├── H2OGLM.scala │ │ │ │ ├── H2OMLP.scala │ │ │ │ └── H2OParams.scala │ │ └── util │ │ │ └── MLUtils.scala │ │ ├── optimization │ │ ├── BinaryClassify.scala │ │ ├── FFM.scala │ │ ├── FM.scala │ │ └── Optimizer.scala │ │ ├── ps │ │ ├── PS.scala │ │ └── PSClient.scala │ │ ├── task │ │ ├── PullTask.scala │ │ ├── PushTask.scala │ │ └── Task.scala │ │ ├── updater │ │ ├── DCASGDUpdater.scala │ │ ├── FFMUpdater.scala │ │ ├── LRFTRLUpdater.scala │ │ ├── LRUpdater.scala │ │ ├── MomLRUpdater.scala │ │ └── Updater.scala │ │ └── utils │ │ ├── CityHash.java │ │ ├── ExitCodeResolver.java │ │ ├── JHex.java │ │ └── XDMLException.java │ └── org │ └── 
apache │ └── kudu │ └── client │ └── SessionConfiguration.java └── test └── scala ├── DataColFilter.scala ├── DataFilter.scala ├── DataRowFilter.scala └── TestFilter.scala /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | 4 | 5 | 6 |
7 | 8 | [![license](https://img.shields.io/badge/license-Apache2.0-blue.svg?style=flat)](./LICENSE) 9 | [![Release Version](https://img.shields.io/badge/release-1.0-red.svg)]() 10 | [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)]() 11 | 12 | 13 | **XDML**是一款基于参数服务器(Parameter Server),采用专门缓存机制的分布式机器学习平台。 14 | XDML内化了学界最新研究成果,在效果保持稳定的同时,能大幅加速收敛进程,显著提升模型与算法的性能。同时,XDML还对接了一些优秀的开源成果和360公司自研成果,站在巨人的肩膀上,博采众长。 XDML还兼容hadoop生态,提供更好的大数据框架使用体验,将开发者从繁杂的工作中解脱出来。XDML已经在360内部海量规模数据上进行了大量测试和调优,在大规模数据量和超高维特征的机器学习任务上,具有良好的稳定性,扩展性和兼容性。 15 | 16 | 欢迎对机器学习或分布式有兴趣的同仁一起贡献代码,提交Issues或者Pull Requests。 17 | 18 | ## 架构设计 19 | ![architecture](./doc/img/xdml.png) 20 | 21 | 针对超大规模机器学习的场景,奇虎360开源了内部的超大规模机器学习计算框架XDML。XDML是一款基于参数服务器(Parameter Server),采用专门缓存机制的分布式机器学习平台。它在360内部海量规模数据上进行了测试和调优,在大规模数据量和超高维特征的机器学习任务上,具有良好的稳定性,扩展性和兼容性。 22 | 23 | ## 功能特性 24 | #### 1.提供特征预处理/分析,离线训练,模型管理等功能模块 25 | #### 2.实现常用的大规模数据量场景下的机器学习算法 26 | #### 3.充分利用现有的成熟技术,保证整个框架的高效稳定 27 | #### 4.完全兼容hadoop生态,和现有的大数据工具实现无缝对接,提升处理海量数据的能力 28 | #### 5.在系统架构和算法层面实现深度的工程优化,在不损失精度的前提下,大幅提高性能 29 | 30 | 31 | ## 代码结构 32 | 33 | ### 1.ps 34 | XDML的核心参数服务器架构,包括以下组件: 35 | 36 | - [PS](./doc/PS.md) 37 | - [PSClient](./doc/PSClient.md) 38 | 39 | ### 2.conf 40 | XDML的配置包,包括对参数服务器的配置和对作业及模型相关的配置。包括以下组件: 41 | 42 | - [JobConfiguration](./doc/JobConfigure.md) 43 | - [PSConfiguration](./doc/PSConfiguration.md) 44 | - ... 45 | 46 | ### 3.task 47 | XDML向PS提交的作业,包括拉取和推送。包括以下任务: 48 | 49 | - Task 50 | - PullTask 51 | - PushTask 52 | 53 | ### 4.optimization 54 | XDML模型的优化算法包。包括以下优化算法: 55 | 56 | - [BinaryClassify](./doc/BinaryClassify.md) 57 | - [FFM](./doc/FFMProcessor.md) 58 | - ... 
59 | 60 | ### 5.ml 61 | XDML中已经实现的部分机器学习模型。包括以下模型: 62 | 63 | - [LogisticRegression](./doc/LogisticRegression.md) 64 | - [LogisticRegressionWithDCASGD](./doc/LogisticRegressionWithDCASGD.md) 65 | - [LogisticRegressionWithFTRL](./doc/LogisticRegressionWithFTRL.md) 66 | - [LogisticRegressionWithMomentum](./doc/LogisticRegressionWithMomentum.md) 67 | - [FieldawareFactorizationMachine](./doc/FieldawareFactorizationMachine.md) 68 | - ... 69 | 70 | ### 6.feature 71 | XDML中特征分析和特征处理模块。 72 | 73 | - [特征分析](./doc/FeatureAnalysis.md) 74 | 75 | 特征分析覆盖常见的分析指标,如数值型特征的偏度、峰度、分位数,与label相关的auc、ndcg、互信息、相关系数等指标。 76 | 77 | - [特征处理](./doc/FeatureProcess.md) 78 | 79 | 特征处理覆盖常见的数值型、类别型特征预处理方法。包括以下算子: 80 | - CategoryEncoder 81 | - MultiCategoryEncoder 82 | - NumericBucketer 83 | - NumericStandardizer 84 | 85 | ### 7.model 86 | XDML中包含用南京大学李武军老师提出的[Scope](https://arxiv.org/pdf/1602.00133.pdf)优化算法进行训练的线性模型,以及部分[H2O](https://www.h2o.ai/)模型的spark pipeline封装。具体包括以下模型: 87 | 88 | [Model:](./doc/Model.md) 89 | 90 | - LinearScope 91 | - MultiLinearScope 92 | - OVRLinearScope 93 | - H2ODRF 94 | - H2OGBM 95 | - H2OGLM 96 | - H2OMLP 97 | 98 | ### 8.example 99 | XDML中作业提交实例,可以参考[Example](./doc/Example.md).
100 | 101 | ## 编译&部署指南 102 | 103 | XDML是基于Kudu、HazelCast以及Hadoop生态圈的一款基于参数服务器的,采用专门缓存机制的分布式机器学习平台。 104 | 105 | ### 环境依赖 106 | - centos >= 7 107 | - Jdk >= 1.8 108 | - Maven >= 3.5.4 109 | - scala >= 2.11 110 | - hadoop >= 2.7.3 111 | - spark >= 2.3.0 112 | - sparkling-water-core >= 2.3.0 113 | - kudu >= 1.9 114 | - HazelCast >= 3.9.3 115 | 116 | ### Kudu安装部署 117 | XDML基于Kudu,请首先部署Kudu。Kudu的安装部署请参考[Kudu](https://github.com/apache/kudu/tree/1.7.0)。 118 | 119 | ### 源码下载 120 | ```git clone https://github.com/Qihoo360/XLearning-XDML``` 121 | 122 | ### 编译 123 | ```mvn clean package -Dmaven.test.skip=true``` 124 | 编译完成后,在源码根目录的`target`目录下会生成:`xdml-1.0.jar`、`xdml-1.0-jar-with-dependencies.jar`等多个文件,其中`xdml-1.0.jar`不包含spark、kudu等第三方依赖,`xdml-1.0-jar-with-dependencies.jar`则打包了spark、kudu等依赖包。 125 | 126 | ## 运行示例 127 | 128 | ### 提交参数 129 | * **算法参数** 130 | * spark.xdml.learningRate:学习率 131 | * **训练参数** 132 | * spark.xdml.job.type:作业类型 133 | * spark.xdml.train.data.path:训练数据路径 134 | * spark.xdml.train.data.partitionNum:训练数据分区 135 | * spark.xdml.model.path:模型存储路径 136 | * spark.xdml.train.iter:训练迭代次数 137 | * spark.xdml.train.batchsize:训练数据batch大小 138 | * **PS相关参数** 139 | * spark.xdml.hz.clusterNum:hazelcast集群机器数目 140 | * spark.xdml.table.name:kudu表名称 141 | 142 | ### 提交命令 143 | 可以通过以下命令提交示例训练作业: 144 | 145 | ``` 146 | $SPARK_HOME/bin/spark-submit \ 147 | --master yarn --deploy-mode cluster \ 148 | --class net.qihoo.xitong.xdml.example.LRTest \ 149 | --num-executors 50 \ 150 | --executor-memory 40g \ 151 | --executor-cores 2 \ 152 | --driver-memory 4g \ 153 | --conf "spark.xdml.table.name=lrtest" \ 154 | --conf "spark.xdml.job.type=train" \ 155 | --conf "spark.xdml.train.data.path=$trainpath" \ 156 | --conf "spark.xdml.train.data.partitionNum=50" \ 157 | --conf "spark.xdml.hz.clusterNum=50" \ 158 | --conf "spark.xdml.model.path=$modelpath" \ 159 | --conf "spark.xdml.train.iter=5" \ 160 | --conf "spark.xdml.train.batchsize=10000" \ 161 | --conf "spark.xdml.learningRate=0.1" \ 162 | --jars
xdml-1.0-jar-with-dependencies.jar \ 163 | xdml-1.0-jar-with-dependencies.jar 164 | 165 | ``` 166 | 167 | 注:提交命令中的设置有`$SPARK_HOME`、`$trainpath`、`$modelpath` 分别代表spark客户端路径、训练数据HDFS路径、模型存储HDFS路径 168 | 169 | ## FAQ 170 | [**XDML常见问题**](./doc/faq_cn.md) 171 | 172 | ## 参考文献 173 | XDML参考了学界及工业界诸多优秀成果,对此表示感谢! 174 | 175 | - Shen-Yi Zhao, Ru Xiang, Ying-Hao Shi, Peng Gao, Wu-Jun Li, [SCOPE: Scalable Composite Optimization for Learning on Spark](https://arxiv.org/pdf/1602.00133.pdf). AAAI 2017: 2928-2934. 176 | - Shen-Yi Zhao, Gong-Duo Zhang, Ming-Wei Li, Wu-Jun Li. [Proximal SCOPE for Distributed Sparse Learning](https://arxiv.org/pdf/1803.05621.pdf). Proceedings of the Annual Conference on Neural Information Processing Systems (NIPS), 2018. 177 | - Shuxin Zheng, Qi Meng, Taifeng Wang, Wei Chen, Zhi-Ming Ma and Tie-Yan Liu, [Asynchronous Stochastic Gradient Descent with Delay Compensation](https://arxiv.org/pdf/1609.08326.pdf), ICML 2017. 178 | 179 | ## 联系我们 180 | 181 | Mail: 182 | QQ群:874050710 183 | ![qq](./doc/img/qq.jpg) -------------------------------------------------------------------------------- /doc/BinaryClassify.md: -------------------------------------------------------------------------------- 1 | # BinaryClassify 2 | 3 | --- 4 | 5 | 6 | > 二分类模型中计算梯度和损失的工具类 7 | 8 | ## 功能 9 | 10 | * 计算二分类模型中训练数据的梯度和损失值,并且可以用predict函数进行预测。 11 | 12 | ## 核心接口 13 | 14 | 1. **train** 15 | - 定义: 16 | ``` 17 | train(label: Double, 18 | feature: (Array[Long], Array[Float]), 19 | localW: Map[Long, Float], 20 | subsampling_rate: Double = 1.0, 21 | subsampling_label: Double = 0.0 22 | ): (Map[Long, Float], Double) 23 | ``` 24 | - 功能描述:根据本地的特征权重localW,计算单条样本的梯度和损失值 25 | - 参数: 26 | - label:样本的标签label 27 | - feature:样本的特征和对应的value值 28 | - localW:本地的特征权重值Map 29 | - subsampling_rate:负采样率(未启用) 30 | - subsampling_label:负采样标签(未启用) 31 | - 返回值: 32 | - (Map[Long, Float], Double) :计算得到的样本的梯度和损失值 33 | * 第一列为样本特征对应的梯度, 34 | * 第二列为样本的损失 35 | 36 | 1.
**predict** 37 | - 定义:```predict(feature: (Array[Long], Array[Float]),weight: Map[Long, Float]): Double ``` 38 | - 功能描述:对二分类模型进行预测 39 | - 参数: 40 | - feature:样本的特征id和value值 41 | - weight:特征weight权重Map 42 | - 返回值:Double:该条样本预测的标签 -------------------------------------------------------------------------------- /doc/CityHash.md: -------------------------------------------------------------------------------- 1 | # CityHash 2 | 3 | --- 4 | 5 | 6 | > 计算hash值的工具类,可用于将字符串特征进行hash化得到数值型id,是google开源的cityhash的java版本,具有计算效率高、碰撞率低的特点 7 | 8 | ## 功能 9 | 10 | * google的cityhash算法的java版本,提供32位,64位和128位哈希值运算。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **stringCityHash32** 16 | - 定义:```int stringCityHash32(String str) ``` 17 | - 功能描述:计算指定字符串的32位的hash值 18 | - 参数: 19 | - str:需要计算的字符串 20 | - 返回值:int:32位hash值 21 | 22 | 1. **stringCityHash64** 23 | - 定义:```long stringCityHash64(String str) ``` 24 | - 功能描述:计算指定字符串的64位的hash值 25 | - 参数: 26 | - str:需要计算的字符串 27 | - 返回值:long:64位hash值 28 | 29 | 1. **stringCityHash128** 30 | - 定义:```long[] stringCityHash128(String str) ``` 31 | - 功能描述:计算指定字符串的128位的hash值 32 | - 参数: 33 | - str:需要计算的字符串 34 | - 返回值:long[]:128位hash值 35 | -------------------------------------------------------------------------------- /doc/FFMProcessor.md: -------------------------------------------------------------------------------- 1 | # FFMProcessor 2 | 3 | --- 4 | 5 | 6 | > 处理标准libFFM格式数据的工具类 7 | 8 | ## 功能 9 | 10 | * 处理标准的文本libFFM数据,得到样本的RDD 11 | 12 | ## 核心接口 13 | 14 | 1.
**processData** 15 | - 定义:```processData(data:RDD[String],separator:String = " "):RDD[(Double,Array[Int],Array[Long],Array[Float])] ``` 16 | - 功能描述:处理原始的文本类型的RDD数据,返回新的数值化的RDD 17 | - 参数: 18 | - data:原始的文本数据RDD 19 | - separator:特征数据之间的分隔符,默认为单个空格 20 | - 返回值: 21 | - RDD[(Double,Array[Int],Array[Long],Array[Float])] :返回数值化的RDD, 22 | * 第一列为样本标签, 23 | * 第二列为field数组 24 | * 第三列为featureId数组 25 | * 第四列为feature对应的value数组 -------------------------------------------------------------------------------- /doc/FieldawareFactorizationMachine.md: -------------------------------------------------------------------------------- 1 | # FieldawareFactorizationMachine 2 | 3 | --- 4 | 5 | 6 | > FFM模型 7 | 8 | ## 功能 9 | 10 | * 实现了FFM,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **fit** 16 | - 定义:```fit(data: RDD[(Double, Array[Int], Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ``` 17 | - 功能描述:模型的训练方法 18 | - 参数: 19 | - data:训练的输入数据RDD,1为label标签,2为特征的field域,3为特征id数组,4为相应特征值数组 20 | - 返回值: 21 | - (Map[Int, Double], (Long, Long)): 22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失 23 | - 第二列为训练数据的正负样本数 24 | 25 | 2. **predict** 26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ``` 27 | - 功能描述:模型的预测方法 28 | - 参数: 29 | - data:预测的输入数据RDD,1为label标签,2为特征的field域,3为特征id数组,4为相应特征值数组 30 | - 返回值: 31 | - RDD[(Double, Double)]:样本label和预测的label 32 | 33 | 3. **setIterNum** 34 | - 定义:```setIterNum(iterNum: Int): this.type ``` 35 | - 功能描述:设置数据迭代次数 36 | - 参数: 37 | - iterNum: 迭代次数 38 | - 返回值:this 39 | 40 | 4. **setBatchSize** 41 | - 定义:```setBatchSize(batchSize: Int): this.type ``` 42 | - 功能描述:设置训练batch的大小 43 | - 参数: 44 | - batchSize:一组batch的大小 45 | - 返回值:this 46 | 47 | 5. **setLearningRate** 48 | - 定义:```setLearningRate(lr: Float): this.type ``` 49 | - 功能描述:设置模型的学习率 50 | - 参数: 51 | - lr:学习率 52 | - 返回值:this 53 | 54 | 6.
**setRank** 55 | - 定义:```setRank(rank:Int):this.type ``` 56 | - 功能描述:设置FFM的单个域的向量长度 57 | - 参数: 58 | - rank:域向量长度 59 | - 返回值:this 60 | 61 | 7. **setField** 62 | - 定义:```setField(field:Int):this.type ``` 63 | - 功能描述:设置FFM中field域的个数 64 | - 参数: 65 | - field:设置field域的个数 66 | - 返回值:this -------------------------------------------------------------------------------- /doc/JobConfigure.md: -------------------------------------------------------------------------------- 1 | ## 作业相关配置 2 | 3 | 配置项名称 | 默认值 | 配置项含义 4 | ---------------- | --------------- | --------------- 5 | spark.xdml.job.type | "train" | 设置作业类型:train,predict,increment_train 6 | spark.xdml.table.name | "table" | 设置kudu存储的表名称 7 | spark.xdml.train.data.path | 无 | 训练数据路径 8 | spark.xdml.train.data.partitionNum | | 训练数据分区数目 9 | spark.xdml.data.split | " " | 数据分隔符 10 | spark.xdml.model.path | 无 | 模型存储路径 11 | spark.xdml.predict.result.path | 无 | 预测结果输出路径 12 | spark.xdml.kudu.master | 无 | kudu集群master地址 13 | spark.xdml.hz.clusterNum | 50 | hazelcast集群节点数目 14 | spark.xdml.hz.partitionNum | 127 | hazelcast分区数目 15 | spark.xdml.hz.maxcachesize | 50000000 | hazelcast分区最大缓存大小 16 | spark.xdml.train.iter |1 | 训练迭代次数 17 | spark.xdml.train.batchsize | 50 | 训练数据batch大小 18 | spark.xdml.learningRate | 0.01f| 设置学习率 19 | spark.xdml.momentumCoff | 0.1f | 动量阻力参数 20 | spark.xdml.train.alpha | 1f | 设置ftrl等算法训练所需参数alpha 21 | spark.xdml.train.beta | 1f | 设置ftrl等算法训练所需参数beta 22 | spark.xdml.train.lambda1 | 1f | 设置ftrl算法训练所需参数lambda1 23 | spark.xdml.train.lambda2 | 1f | 设置ftrl算法训练所需参数lambda2 24 | spark.xdml.train.forcesparse | false | 设置ftrl算法是否保持强稀疏性,即w为0时,z、n对应值也为0 25 | spark.xdml.model.ffm.rank | 1 | 设置ffm算法中的k值 26 | spark.xdml.model.ffm.field | 1 | 设置ffm算法中field域的个数 27 | 28 | -------------------------------------------------------------------------------- /doc/LibSVMProcessor.md: -------------------------------------------------------------------------------- 1 | # LibSVMProcessor 2 | 3 | --- 4 | 5 | 6 | > 处理标准libSVM格式数据的工具类 7 | ## 功能 8 | 9 | *
处理标准的文本libSVM数据,得到样本的RDD 10 | 11 | ## 核心接口 12 | 13 | 1. **processData** 14 | - 定义:```processData(data:RDD[String],separator:String = " "):RDD[(Double,Array[Long],Array[Float])] ``` 15 | - 功能描述:处理原始的文本类型的RDD数据,返回新的数值化的RDD 16 | - 参数: 17 | - data:原始的文本数据RDD 18 | - separator:特征数据之间的分隔符,默认为单个空格 19 | - 返回值: 20 | - RDD[(Double,Array[Long],Array[Float])] :返回数值化的RDD, 21 | * 第一列为样本标签, 22 | * 第二列为featureId数组 23 | * 第三列为feature对应的value值数组 -------------------------------------------------------------------------------- /doc/LogisticRegression.md: -------------------------------------------------------------------------------- 1 | # LogisticRegression 2 | 3 | --- 4 | 5 | 6 | > 逻辑回归模型 7 | 8 | ## 功能 9 | 10 | * 实现了逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **fit** 16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ``` 17 | - 功能描述:模型的训练方法 18 | - 参数: 19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 20 | - 返回值: 21 | - (Map[Int, Double], (Long, Long)): 22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失 23 | - 第二列为训练数据的正负样本数 24 | 25 | 2. **predict** 26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ``` 27 | - 功能描述:模型的预测方法 28 | - 参数: 29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 30 | - 返回值: 31 | - RDD[(Double, Double)]:样本label和预测的label 32 | 33 | 3. **setIterNum** 34 | - 定义:```setIterNum(iterNum: Int): this.type ``` 35 | - 功能描述:设置数据迭代次数 36 | - 参数: 37 | - iterNum: 迭代次数 38 | - 返回值:this 39 | 40 | 4. **setBatchSize** 41 | - 定义:```setBatchSize(batchSize: Int): this.type ``` 42 | - 功能描述:设置训练batch的大小 43 | - 参数: 44 | - batchSize:一组batch的大小 45 | - 返回值:this 46 | 47 | 5.
**setLearningRate** 48 | - 定义:```setLearningRate(lr: Float): this.type ``` 49 | - 功能描述:设置模型的学习率 50 | - 参数: 51 | - lr:学习率 52 | - 返回值:this 53 | 54 | -------------------------------------------------------------------------------- /doc/LogisticRegressionWithDCASGD.md: -------------------------------------------------------------------------------- 1 | # LogisticRegressionWithDCASGD 2 | 3 | --- 4 | 5 | 6 | > 带有延迟梯度更新的逻辑回归模型 7 | 8 | ## 功能 9 | 10 | * 实现了基于延迟梯度更新(DC-ASGD)的逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **fit** 16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ``` 17 | - 功能描述:模型的训练方法 18 | - 参数: 19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 20 | - 返回值: 21 | - (Map[Int, Double], (Long, Long)): 22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失 23 | - 第二列为训练数据的正负样本数 24 | 25 | 2. **predict** 26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ``` 27 | - 功能描述:模型的预测方法 28 | - 参数: 29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 30 | - 返回值: 31 | - RDD[(Double, Double)]:样本label和预测的label 32 | 33 | 3. **setIterNum** 34 | - 定义:```setIterNum(iterNum: Int): this.type ``` 35 | - 功能描述:设置数据迭代次数 36 | - 参数: 37 | - iterNum: 迭代次数 38 | - 返回值:this 39 | 40 | 4. **setBatchSize** 41 | - 定义:```setBatchSize(batchSize: Int): this.type ``` 42 | - 功能描述:设置训练batch的大小 43 | - 参数: 44 | - batchSize:一组batch的大小 45 | - 返回值:this 46 | 47 | 5. **setLearningRate** 48 | - 定义:```setLearningRate(lr: Float): this.type ``` 49 | - 功能描述:设置模型的学习率 50 | - 参数: 51 | - lr:学习率 52 | - 返回值:this 53 | 54 | 6. 
**setDcAsgdCoff** 55 | - 定义:```setDcAsgdCoff(coff:Float):this.type ``` 56 | - 功能描述:设置延迟梯度系数 57 | - 参数: 58 | - coff:延迟梯度系数 59 | - 返回值:this 60 | -------------------------------------------------------------------------------- /doc/LogisticRegressionWithFTRL.md: -------------------------------------------------------------------------------- 1 | # LogisticRegressionWithFTRL 2 | 3 | --- 4 | 5 | 6 | > 使用FTRL优化方法的逻辑回归模型 7 | 8 | ## 功能 9 | 10 | * 实现了带有FTRL优化的逻辑回归模型。提供训练的fit方法和预测的predict方法。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **fit** 16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ``` 17 | - 功能描述:模型的训练方法 18 | - 参数: 19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 20 | - 返回值: 21 | - (Map[Int, Double], (Long, Long)): 22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失 23 | - 第二列为训练数据的正负样本数 24 | 25 | 2. **predict** 26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ``` 27 | - 功能描述:模型的预测方法 28 | - 参数: 29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 30 | - 返回值: 31 | - RDD[(Double, Double)]:样本label和预测的label 32 | 33 | 3. **setIterNum** 34 | - 定义:```setIterNum(iterNum: Int): this.type ``` 35 | - 功能描述:设置数据迭代次数 36 | - 参数: 37 | - iterNum: 迭代次数 38 | - 返回值:this 39 | 40 | 4. **setBatchSize** 41 | - 定义:```setBatchSize(batchSize: Int): this.type ``` 42 | - 功能描述:设置训练batch的大小 43 | - 参数: 44 | - batchSize:一组batch的大小 45 | - 返回值:this 46 | 47 | 5. **setAlpha** 48 | - 定义:```setAlpha(alpha: Float): this.type ``` 49 | - 功能描述:设置FTRL中的参数alpha 50 | - 参数: 51 | - alpha:FTRL中的参数alpha 52 | - 返回值:this 53 | 54 | 6. **setBeta** 55 | - 定义:```setBeta(beta: Float): this.type ``` 56 | - 功能描述:设置FTRL中的参数beta 57 | - 参数: 58 | - beta:FTRL中的参数beta 59 | - 返回值:this 60 | 61 | 7. **setLambda1** 62 | - 定义:```setLambda1(lambda1: Float): this.type ``` 63 | - 功能描述:设置FTRL中的参数lambda1 64 | - 参数: 65 | - lambda1:FTRL中的参数lambda1 66 | - 返回值:this 67 | 68 | 8. 
**setLambda2** 69 | - 定义:```setLambda2(lambda2: Float): this.type ``` 70 | - 功能描述:设置FTRL中的参数lambda2 71 | - 参数: 72 | - lambda2:FTRL中的参数lambda2 73 | - 返回值:this 74 | -------------------------------------------------------------------------------- /doc/LogisticRegressionWithMomentum.md: -------------------------------------------------------------------------------- 1 | # LogisticRegressionWithMomentum 2 | 3 | --- 4 | 5 | 6 | > 带有冲量的逻辑回归模型 7 | 8 | ## 功能 9 | 10 | * 实现了带有冲量的逻辑回归模型,使用SGD的优化方法。提供训练的fit方法和预测的predict方法。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **fit** 16 | - 定义:```fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) ``` 17 | - 功能描述:模型的训练方法 18 | - 参数: 19 | - data:训练的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 20 | - 返回值: 21 | - (Map[Int, Double], (Long, Long)): 22 | - 第一列为训练信息Map,key为迭代轮次,value为每轮的loss损失 23 | - 第二列为训练数据的正负样本数 24 | 25 | 2. **predict** 26 | - 定义:```def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] ``` 27 | - 功能描述:模型的预测方法 28 | - 参数: 29 | - data:预测的输入数据RDD,1为label标签,2为特征id数组,3为相应特征值数组 30 | - 返回值: 31 | - RDD[(Double, Double)]:样本label和预测的label 32 | 33 | 3. **setIterNum** 34 | - 定义:```setIterNum(iterNum: Int): this.type ``` 35 | - 功能描述:设置数据迭代次数 36 | - 参数: 37 | - iterNum: 迭代次数 38 | - 返回值:this 39 | 40 | 4. **setBatchSize** 41 | - 定义:```setBatchSize(batchSize: Int): this.type ``` 42 | - 功能描述:设置训练batch的大小 43 | - 参数: 44 | - batchSize:一组batch的大小 45 | - 返回值:this 46 | 47 | 5. **setLearningRate** 48 | - 定义:```setLearningRate(lr: Float): this.type ``` 49 | - 功能描述:设置模型的学习率 50 | - 参数: 51 | - lr:学习率 52 | - 返回值:this 53 | 54 | 6.
**setMomemtumCoff** 55 | - 定义:```setMomemtumCoff(coff:Float):this.type ``` 56 | - 功能描述:设置冲量系数 57 | - 参数: 58 | - coff:冲量系数 59 | - 返回值:this 60 | -------------------------------------------------------------------------------- /doc/PS.md: -------------------------------------------------------------------------------- 1 | # PS 2 | 3 | --- 4 | 5 | 6 | > PS是XDML中抽象出来的参数服务器类。 7 | 8 | ## 功能 9 | 10 | * 包括获取参数服务器单例,保存和载入模型,load参数,启动和销毁参数服务器等功能 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **getInstance** 16 | - 定义:```PS.getInstance(sc: SparkContext, psConf: PSConfiguration): PS ``` 17 | - 功能描述:根据SparkContext和PSConfiguration获取参数服务器PS的单例。 18 | - 参数: 19 | - sc:Spark作业的SparkContext实例 20 | - psConf:PSConfiguration的实例,包含了启动参数服务器的所有参数。 21 | - 返回值: 22 | - PS:参数服务器的单例 23 | 24 | 2. **getIPList** 25 | - 定义:```getIPList: Array[String] ``` 26 | - 功能描述:获取hazelcast server集群启动后的节点IP列表 27 | - 参数:无 28 | - 返回值: 29 | - Array[String]:hazelcast集群节点的ip组成的数组 30 | 31 | 3. **start** 32 | - 定义:```start(sc: SparkContext): Unit ``` 33 | - 功能描述:根据已配置的参数和作业类型启动参数服务器 34 | - 参数: 35 | - sc:Spark作业的SparkContext实例 36 | - 返回值:无 37 | 38 | 4. **createKuduTable** 39 | - 定义:```createKuduTable(): Unit ``` 40 | - 功能描述:在kudu集群中创建一张新的表 41 | - 参数:无 42 | - 返回值:无 43 | 44 | 5. **startHzCluster** 45 | - 定义:```startHzCluster(sc: SparkContext) ``` 46 | - 功能描述:根据传入的PSConfiguration的实例参数启动hazelcast集群 47 | - 参数: 48 | - sc:Spark作业的SparkContext实例 49 | - 返回值:无 50 | 51 | 6. **createHzMapConfig** 52 | - 定义:```createHzMapConfig(tableName: String): MapConfig ``` 53 | - 功能描述:创建启动hz节点需要的一个MapConfig的实例 54 | - 参数: 55 | - tableName:需要创建的表名 56 | - 返回值:MapConfig:启动hz所需的MapConfig的实例 57 | 58 | 7.
**getAllParametersReady** 59 | - 定义: 60 | ```getAllParametersReady(sc: SparkContext, tableName: String = psConf.getPsTableName(), kuduContext: KuduContext): Long ``` 61 | - 功能描述:将kudu指定表中所有的参数载入到hz中 62 | - 参数: 63 | - sc:Spark作业的SparkContext实例 64 | - tableName: kudu中指定的表名,默认是PsConf中配置的第一个表名 65 | - kuduContext:KuduContext实例,类属性 66 | - 返回值:Long:载入的参数的总个数 67 | 68 | 8. **getRangeModelWeightFromBytes** 69 | - 定义:```getRangeModelWeightFromBytes(bytes: Array[Byte], startRange: Int, endRange: Int): Any ``` 70 | - 功能描述:从Byte数组中获取指定范围的泛型数据类型的数据 71 | - 参数: 72 | - bytes: 源byte数组数据 73 | - startRange: 起始索引(包含) 74 | - endRange: 结束索引(不包含) 75 | - 返回值:Any:泛型指定的数据类型数据 76 | 77 | 9. **getRangeBytesFromWeight** 78 | - 定义:```getRangeBytesFromWeight(values: Array[String], startRange: Int, endRange: Int): Array[Byte] ``` 79 | - 功能描述:从指定的字符串数组中按照ps设置的数据类型将指定范围的数据转换成Byte数组 80 | - 参数: 81 | - values:需要处理的字符串数组 82 | - startRange: 起始索引(包含) 83 | - endRange: 结束索引(不包含) 84 | - 返回值:Array[Byte]:转化后的Byte类型的数组 85 | 86 | 10. **saveModel** 87 | - 定义: 88 | 89 | saveModel(sc: SparkContext, 90 | outputPath: String, 91 | tableName: String = psConf.getPsTableName(), 92 | split: String = "\t", 93 | weightStartPos: Int = psConf.getWeightIndexRangeStart, 94 | weightEndPos: Int = psConf.getWeightIndexRangeEnd, 95 | kuduContext: KuduContext = PS.kuduContext): Unit 96 | 97 | - 功能描述:保存训练的模型到HDFS 98 | - 参数: 99 | - sc:Spark作业的SparkContext实例 100 | - outputPath: 模型输出的HDFS路径 101 | - tableName: 模型参数所在的kudu表名 102 | - split: 模型参数文本的分隔符,默认为\t 103 | - weightStartPos: 指定需要保存模型参数的起始索引(包含),默认为0, 104 | - weightEndPos:指定需要保存模型参数的结束索引(不包含),默认为ps中设置的数据长度 105 | - kuduContext:KuduContext实例,类属性 106 | - 返回值:无 107 | 108 | 11.
**loadModel** 109 | - 定义: 110 | 111 | loadModel(sc: SparkContext, 112 | inputPath: String, 113 | kuduContext: KuduContext, 114 | tableName: String = psConf.getPsTableName(), 115 | split: String = "\t", 116 | weightStartPos: Int = psConf.getWeightIndexRangeStart, 117 | weightEndPos: Int = psConf.getWeightIndexRangeEnd 118 | ): Unit 119 | 120 | - 功能描述:从HDFS中load模型到kudu集群的指定表中 121 | - 参数: 122 | - sc:Spark作业的SparkContext实例 123 | - inputPath: 需要载入的模型的HDFS路径 124 | - kuduContext:KuduContext实例,类属性 125 | - tableName:模型载入的目标kudu表名 126 | - split: 模型参数文本的分隔符,默认为\t 127 | - weightStartPos: 指定需要载入模型参数的起始索引(包含),默认为0, 128 | - weightEndPos:指定需要载入模型参数的结束索引(不包含),默认为ps中设置的数据长度 129 | - 返回值:无 130 | -------------------------------------------------------------------------------- /doc/PSClient.md: -------------------------------------------------------------------------------- 1 | # PSClient 2 | 3 | --- 4 | 5 | 6 | > PSClient封装了用户模型与参数服务器之间参数交互关键的pull和push方法。 7 | 8 | ## 功能 9 | 10 | * 在参数服务器启动后,PSClient充当用户模型与参数服务器之间的代理:每一个RDD的partition中都需要一个PSClient的实例,负责与参数服务器进行参数的拉取和推送。pull和push过程中的数据类型对用户透明,会根据ps启动时设置的数据类型自动匹配,用户只需要通过PSClient进行pull和push操作即可,无需关心数据类型的转化。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **pull** 16 | - 定义:```pull(featureId: util.HashSet[Long], tableName: String = ps.psConf.getPsTableName()): util.Map[Long, V] ``` 17 | - 功能描述:用户从参数服务器中拉取所需要的参数。 18 | - 参数: 19 | - featureId: 用户需要拉取的feature的id 20 | - tableName: 指定需要从哪个表中拉取参数,默认是ps启动时配置的第一张表名 21 | - 返回值: 22 | - util.Map[Long, V]: 用户拉取的参数Map 23 | 24 | 2. **push** 25 | - 定义:```push(weightMap: util.Map[Long, V], tableName: String = ps.psConf.getPsTableName()): Unit ``` 26 | - 功能描述:用户向参数服务器推送保存参数。 27 | - 参数: 28 | - weightMap: 需要推送的参数id和值 29 | - tableName: 指定需要向哪张表推送参数,默认是ps启动时配置的第一张表名 30 | - 返回值:无 31 | 32 | 3. **init** 33 | - 定义:```init(): Unit ``` 34 | - 功能描述:初始化本实例的相关参数 35 | - 参数:无 36 | - 返回值:无 37 | 38 | 4.
**validDataType** 39 | - 定义:```validDataType(): Unit ``` 40 | - 功能描述:验证PS中配置的PSDataType和PSClient实例中的泛型是否一致,若不一致则抛出异常 41 | - 参数:无 42 | - 返回值:无 43 | - 抛出异常:XDMLException 44 | 45 | 5. **setUpdater** 46 | - 定义:```setUpdater(updater: Updater[Long, V]): this.type ``` 47 | - 功能描述:设置本实例需要使用的Updater,在进行push之前会先调用updater实例中的update方法处理参数后再push到参数服务器。 48 | - 参数: 49 | - updater:该PSClient实例设定的updater 50 | - 返回值:this 51 | 52 | 6. **setPullRange** 53 | - 定义:```setPullRange(start: Int, end: Int): this.type ``` 54 | - 功能描述:设置拉取参数的索引范围 55 | - 参数: 56 | - start:索引起始值(包含) 57 | - end:索引结束值(不包含) 58 | - 返回值:this 59 | 60 | 7. **setPullRange** 61 | - 定义:```setPullRange(index: Int): this.type ``` 62 | - 功能描述:设置只拉取某个索引的参数 63 | - 参数: 64 | - index:参数索引值 65 | - 返回值:this 66 | 67 | 8. **getRangeVFromBytes** 68 | - 定义:```getRangeVFromBytes(bytes: Array[Byte]): V ``` 69 | - 功能描述:从byte数组中获取指定范围的泛型V指定类型的数据 70 | - 参数: 71 | - bytes:数据源Byte数组 72 | - 返回值:V:泛型数据 73 | 74 | 9. **shutDown** 75 | - 定义:```shutDown(): Unit ``` 76 | - 功能描述:关闭本PSClient实例 77 | - 参数:无 78 | - 返回值:无 79 | 80 | -------------------------------------------------------------------------------- /doc/PSConfiguration.md: -------------------------------------------------------------------------------- 1 | # PSConfiguration 2 | 3 | --- 4 | 5 | 6 | > PSConfiguration封装了XDML参数服务器所使用的两个组件(Kudu和Hazelcast)的全部配置参数,以及PS相关的参数。 7 | 8 | ## 功能 9 | 10 | * 使用get和set方法获取和配置XDML中参数服务器的相关参数。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **setJobType** 16 | - 定义:```setJobType(jobType:PSJobType) :this.type ``` 17 | - 功能描述:设置参数服务器的作业类型,目前有train,predict和increment_train三种作业类型。 18 | - 参数: 19 | - jobType: 枚举类JobType的实例 20 | - 返回值:this 21 | 22 | 2. **setKuduMaster** 23 | - 定义:```setKuduMaster(master: String): this.type ``` 24 | - 功能描述:设置kudu集群的master地址 25 | - 参数: 26 | - master: Kudu集群的master节点地址 27 | - 返回值:this 28 | 29 | 3.
**setKuduReplicaNum** 30 | - 定义:```setKuduReplicaNum(kuduReplicaNum: Int) ``` 31 | - 功能描述:设置kudu数据备份数,默认为3 32 | - 参数: 33 | - kuduReplicaNum:kudu数据备份数 34 | - 返回值:this 35 | 36 | 4. **setKuduPartitionNum** 37 | - 定义:```setKuduPartitionNum(kuduPartitionNum: Int) ``` 38 | - 功能描述:设置kudu的分区数,默认为50 39 | - 参数: 40 | - kuduPartitionNum:kudu存储的分区数 41 | - 返回值:this 42 | 43 | 5. **setKuduTableForceOverride** 44 | - 定义:```setKuduTableForceOverride(kuduTableForceOverride: Boolean): this.type ``` 45 | - 功能描述:设置当kudu中新建表时有重名的表出现是否强制覆盖旧表,默认为true 46 | - 参数: 47 | - kuduTableForceOverride:是否强制覆盖,true为覆盖,false为使用旧表 48 | - 返回值:this 49 | 50 | 6. **setKuduKeyColName** 51 | - 定义:```setKuduKeyColName(kuduKeyColName: String): this.type ``` 52 | - 功能描述:设置kudu表的键所在列的名字 53 | - 参数: 54 | - kuduKeyColName:键列的名字 55 | - 返回值:this 56 | 57 | 7. **setKuduValueColName** 58 | - 定义:```setKuduValueColName(kuduValueColName: String): this.type ``` 59 | - 功能描述:设置kudu表的值所在列的名字 60 | - 参数: 61 | - kuduValueColName:值列的名字 62 | - 返回值:this 63 | 64 | 8. **setKuduValueNullable** 65 | - 定义:```setKuduValueNullable(kuduValueNullable: Boolean) ``` 66 | - 功能描述:设置kudu表的值列是否可以为null,默认为false 67 | - 参数: 68 | - kuduValueNullable:值是否可以为null 69 | - 返回值:this 70 | 71 | 9. **setKuduRandomInit** 72 | - 定义:```setKuduRandomInit(kuduRandomInit: Boolean) ``` 73 | - 功能描述:设置是否需要参数的随机初始化,默认为false 74 | - 参数: 75 | - kuduRandomInit:布尔类型是否需要随机初始化值,默认范围是0-1之间随机值,可以通过setKuduRandomInit方法设置随机范围。 76 | - 返回值:this 77 | 78 | 10. **setHzClusterNum** 79 | - 定义:```setHzClusterNum(hzClusterNum:Int):this.type ``` 80 | - 功能描述:配置在参数服务器启动的时候需要启动的hz的节点数,默认为50 81 | - 参数: 82 | - hzClusterNum:参数服务器中需要启动的hz节点数 83 | - 返回值:this 84 | 85 | 11. **setHzPartitionNum** 86 | - 定义:```setHzPartitionNum(hzPartitionNum: Int): this.type ``` 87 | - 功能描述:设置hz的分区数,默认为271 88 | - 参数: 89 | - hzPartitionNum:Hazelcast的分区数 90 | - 返回值:this 91 | 92 | 12. 
**setHzClientThreadCount** 93 | - 定义:```setHzClientThreadCount(hzClientThreadCount: Int): this.type ``` 94 | - 功能描述:设置hazelcast的client最大处理线程数,默认为20 95 | - 参数: 96 | - hzClientThreadCount:hz client的线程数 97 | - 返回值:this 98 | 99 | 13. **setHzOperationThreadCount** 100 | - 定义:```setHzOperationThreadCount(hzOperationThreadCount:Int) :this.type ``` 101 | - 功能描述:设置hz的server最大操作线程数,默认是20 102 | - 参数: 103 | - hzOperationThreadCount:hz的server最大操作线程数 104 | - 返回值:this 105 | 106 | 14. **setHzMaxCacheSizePerPartition** 107 | - 定义:```setHzMaxCacheSizePerPartition(hzMaxCacheSizePerPartition: Int): this.type ``` 108 | - 功能描述:设置Hz的单个分区最大缓存数据量,默认是50000000 109 | - 参数: 110 | - hzMaxCacheSizePerPartition:Hz的单个分区最大缓存数据量 111 | - 返回值:this 112 | 113 | 15. **setHzWriteDelaySeconds** 114 | - 定义:```setHzWriteDelaySeconds(hzWriteDelaySeconds: Int): this.type ``` 115 | - 功能描述:设置hz异步向kudu中持久化数据的时间间隔,单位为s,默认为20 116 | - 参数: 117 | - hzWriteDelaySeconds:向kudu持久化数据的时间间隔,单位为s 118 | - 返回值:this 119 | 120 | 16. **setHzWriteBatchSize** 121 | - 定义:```setHzWriteBatchSize(hzWriteBatchSize: Int): this.type ``` 122 | - 功能描述:设置hz写入的batchsize 123 | - 参数: 124 | - hzWriteBatchSize:hz写入的batchsize 125 | - 返回值:this 126 | 127 | 17. **setHzHeartbeatTime** 128 | - 定义:```setHzHeartbeatTime(hzHeartbeatTime: Int): this.type ``` 129 | - 功能描述:设置hz server的心跳时间,单位为ms,默认为3000 130 | - 参数: 131 | - hzHeartbeatTime:hz心跳时间 132 | - 返回值:this 133 | 134 | 18. **setHzClientHeartbeatTime** 135 | - 定义:```setHzClientHeartbeatTime(hzClientHeartbeatTime: Int): this.type ``` 136 | - 功能描述:设置hz client的心跳时间 137 | - 参数: 138 | - hzClientHeartbeatTime:hz client的心跳时间 139 | - 返回值:this 140 | 141 | 19. **setHzIOThreadCount** 142 | - 定义:```setHzIOThreadCount(hzIOThreadCount:Int):this.type ``` 143 | - 功能描述:设置hazelcast的IO操作的线程数 144 | - 参数: 145 | - hzIOThreadCount:IO操作线程数 146 | - 返回值:this 147 | 148 | 20.
**setPsTableName** 149 | - 定义:```setPsTableName(psTableName: String*): this.type ``` 150 | - 功能描述:设置本次作业使用的PS中的表名 151 | - 参数: 152 | - psTableName:为可变参数,参数服务器中使用的表名,如果使用多张表,可以设置多个。 153 | - 返回值:this 154 | 155 | 21. **setPsDataType** 156 | - 定义:```setPsDataType(psDataType: PSDataType): this.type ``` 157 | - 功能描述:设置参数服务器中使用的value的数据类型 158 | - 参数: 159 | - psDataType:PSDataType的枚举类型,目前包括FLOAT,DOUBLE,FLOAT_ARRAY和DOUBLE_ARRAY 160 | - 返回值:this 161 | 162 | 22. **setPsDataLength** 163 | - 定义:```setPsDataLength(psDataLength: Int): this.type ``` 164 | - 功能描述:设置参数服务器中使用的value的数据长度 165 | - 参数: 166 | - psDataLength:参数服务器中value的数据长度 167 | - 返回值:this 168 | 169 | 23. **setPredictModelPath** 170 | - 定义:```setPredictModelPath(predictModelPath:String):this.type ``` 171 | - 功能描述:设置预测作业中需要使用的模型路径 172 | - 参数: 173 | - predictModelPath:需要载入模型的HDFS路径 174 | - 返回值:this 175 | -------------------------------------------------------------------------------- /doc/PSMapStore.md: -------------------------------------------------------------------------------- 1 | # PSMapStore 2 | 3 | --- 4 | 5 | 6 | > PSMapStore实现了hazelcast中的MapStore接口,用于将hazelcast缓存中的数据持久化到Kudu中。参数服务器会优先从hazelcast中拉取参数,如果没有命中,则调用该接口从kudu中拉取或者初始化。 7 | 8 | ## 功能 9 | 10 | * 实现了hazelcast和kudu交互的接口规范,是实现两级式参数服务器的关键核心接口。数据在kudu中统一以Array[Byte]字节数组的格式存储 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **init** 16 | - 定义:```init(): Unit ``` 17 | - 功能描述:用于本实例属性的初始化函数,在构造实例时自动调用 18 | - 参数:无 19 | - 返回值:无 20 | 21 | 2. **store** 22 | - 定义:```store(key: Long, value: Array[Byte]): Unit ``` 23 | - 功能描述:将hazelcast中的k-v数据存储到kudu中 24 | - 参数: 25 | - key: 存储数据的键 26 | - value: 存储数据的值 27 | - 返回值:无 28 | 29 | 3. **storeAll** 30 | - 定义:```storeAll(map: util.Map[Long, Array[Byte]]): Unit ``` 31 | - 功能描述:将一个键值对集合Map存入kudu中 32 | - 参数: 33 | - map:键值对集合Map 34 | - 返回值:无 35 | 36 | 4. **load** 37 | - 定义:```load(key: Long): Array[Byte] ``` 38 | - 功能描述:从kudu中载入一个键值对 39 | - 参数: 40 | - key:需要载入的键 41 | - 返回值:Array[Byte]:返回该键对应的Byte数组的值 42 | 43 | 5.
**loadAll** 44 | - 定义:```loadAll(keys: util.Collection[Long]): util.Map[Long, Array[Byte]] ``` 45 | - 功能描述:从kudu中批量拉取一组键对应的键值对 46 | - 参数: 47 | - keys:需要从kudu中拉取的键的集合 48 | - 返回值:util.Map[Long, Array[Byte]]:返回拉取得到的键值对集合Map 49 | -------------------------------------------------------------------------------- /doc/Updater.md: -------------------------------------------------------------------------------- 1 | # Updater 2 | 3 | --- 4 | 5 | 6 | > Updater是PSClient中用到的一个接口,用户可以自定义实现其中的update方法,PSClient会在push数据到参数服务器之前先调用update方法处理传入的参数和梯度数据。 7 | 8 | ## 功能 9 | 10 | * 定义参数服务器中更新参数的方法,在执行push操作时自动调用。 11 | 12 | 13 | ## 核心接口 14 | 15 | 1. **update** 16 | - 定义:```update(originalWeightMap: java.util.Map[K, V], gradientMap: java.util.Map[K, V]):java.util.Map[K, V] ``` 17 | - 功能描述:用户自定义参数服务器中参数更新方式的接口规范 18 | - 参数: 19 | - originalWeightMap:参数服务器上的原始参数 20 | - gradientMap:一般是梯度列表参数 21 | - 返回值:java.util.Map[K, V] :用户自定义更新后的Map数据,也就是真正更新到参数服务器的Map数据 -------------------------------------------------------------------------------- /doc/faq_cn.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/faq_cn.md -------------------------------------------------------------------------------- /doc/img/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/logo.jpg -------------------------------------------------------------------------------- /doc/img/qq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/qq.jpg -------------------------------------------------------------------------------- /doc/img/xdml.png:
// --------------------------------------------------------------------------------
// https://raw.githubusercontent.com/Qihoo360/XLearning-XDML/1256c0ce3a6757ff03fe54503a9f6d9416de24b9/doc/img/xdml.png
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/conf/JobConfiguration.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.conf

import org.apache.spark.SparkContext

import scala.collection.mutable


/**
 * Keys of the `spark.xdml.*` settings that XDML jobs read from the Spark configuration.
 */
object JobConfiguration {

  val DATA_SPLIT = "spark.xdml.data.split"
  val KUDU_MASTER = "spark.xdml.kudu.master"
  val LEARNING_RATE = "spark.xdml.learningRate"
  val TABLE_NAME = "spark.xdml.table.name"
  val MODEL_PATH = "spark.xdml.model.path"
  val DATA_PATH = "spark.xdml.data.path"
  val TRAIN_DATA_PARTITION = "spark.xdml.train.data.partitionNum"
  val TRAIN_ITER_NUM = "spark.xdml.train.iter"
  val BATCH_SIZE = "spark.xdml.train.batchsize"
  val SUBSAMPLE_RATE = "spark.xdml.subsample.rate"
  val SUBSAMPLE_LABEL = "spark.xdml.train.subsample.label"
  val HZ_CLUSTER_NUM = "spark.xdml.hz.clusterNum"
  val HZ_PARTITION_NUM = "spark.xdml.hz.partitionNum"
  val HZ_MAXCACHESIZE_PER_PARTITION = "spark.xdml.hz.maxcachesize"
  val SAVE_ALL_WEIGHTS = "spark.xdml.model.save.allweights"
  val SAVE_FEATURE = "spark.xdml.model.save.feature"
  val FEATURE_FILE = "spark.xdml.model.feature.path"
  val PREDICT_RESULT_PATH = "spark.xdml.predict.result.path"
  val JOB_TYPE = "spark.xdml.job.type"

  val DC_ASGD_COFF = "spark.xdml.dcasgdCoff"

  val MOMENTUM_COFF = "spark.xdml.momentumCoff"

  val FTRL_ALPHA = "spark.xdml.train.alpha"
  val FTRL_BETA = "spark.xdml.train.beta"
  val FTRL_LAMBDA1 = "spark.xdml.train.lambda1"
  val FTRL_LAMBDA2 = "spark.xdml.train.lambda2"
  // NOTE: the identifier is misspelled ("FROCE"); it is kept because callers may reference it.
  val FTRL_FROCESPARSE = "spark.xdml.train.forcesparse"

  val FM_RANK = "spark.xdml.model.fm.rank"

  val FFM_RANK = "spark.xdml.model.ffm.rank"
  val FFM_FIELD = "spark.xdml.model.ffm.field"

  /** Every key that `JobConfiguration#readJobConfig` looks up in the Spark configuration. */
  private val jobConfSet = Set(
    // Common configuration
    DATA_SPLIT,
    KUDU_MASTER,
    LEARNING_RATE,
    TABLE_NAME,
    MODEL_PATH,
    DATA_PATH,
    TRAIN_DATA_PARTITION,
    TRAIN_ITER_NUM,
    BATCH_SIZE,
    SUBSAMPLE_RATE,
    SUBSAMPLE_LABEL,
    HZ_CLUSTER_NUM,
    HZ_PARTITION_NUM,
    HZ_MAXCACHESIZE_PER_PARTITION,
    SAVE_ALL_WEIGHTS,
    SAVE_FEATURE,
    FEATURE_FILE,
    PREDICT_RESULT_PATH,
    JOB_TYPE,

    // Momentum
    MOMENTUM_COFF,

    // DC-ASGD
    DC_ASGD_COFF,

    // FTRL configuration
    FTRL_ALPHA,
    FTRL_BETA,
    FTRL_LAMBDA1,
    FTRL_LAMBDA2,
    FTRL_FROCESPARSE,

    // FM configuration
    FM_RANK,

    // FFM configuration
    FFM_RANK,
    FFM_FIELD
  )

}

/** Collects the XDML job settings that are present in a Spark application's configuration. */
class JobConfiguration() {

  /**
   * Reads every known XDML job setting that has been set in the Spark configuration.
   *
   * @param sc the running SparkContext whose configuration is inspected
   * @return setting key -> raw string value for each key of `JobConfiguration.jobConfSet`
   *         that is present; keys that are not set are omitted
   */
  def readJobConfig(sc: SparkContext): Map[String, String] = {
    // SparkContext.getConf returns a copy of the configuration, so take it once
    // instead of cloning the whole SparkConf again for every key.
    val sparkConf = sc.getConf
    JobConfiguration.jobConfSet.iterator
      .flatMap(confName => sparkConf.getOption(confName).map(value => confName -> value))
      .toMap
  }
}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/conf/JobType.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.conf

/** Kinds of parameter-server jobs; the integer ids are fixed by the `Value(n)` calls. */
object JobType extends Enumeration {
  type PSJobType = Value

  val TRAIN = Value(0)
  val PREDICT = Value(1)
  val INCREMENT_TRAIN = Value(2)

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/conf/PSDataType.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.conf

/** Value types a parameter-server table can hold. */
object PSDataType extends Enumeration {
  type PSDataType = Value

  val BYTE = Value(0)
  val SHORT = Value(1)
  val INT = Value(2)
  val LONG = Value(3)
  val FLOAT = Value(4)
  val DOUBLE = Value(5)
  val FLOAT_ARRAY = Value(6)
  val DOUBLE_ARRAY = Value(7)

  /**
   * Size in bytes of one element of `dataType`.
   * For the array types this is the size of a single array element, not of the whole array.
   */
  def sizeOf(dataType: PSDataType): Int = {
    dataType match {
      case BYTE => 1
      case SHORT => 2
      case INT => 4
      case LONG => 8
      case FLOAT => 4
      case DOUBLE => 8
      case FLOAT_ARRAY => 4
      case DOUBLE_ARRAY => 8
    }
  }

  /** True for FLOAT_ARRAY and DOUBLE_ARRAY, false for every scalar type. */
  def isArrayType(tpe: PSDataType): Boolean = {
    tpe == FLOAT_ARRAY || tpe == DOUBLE_ARRAY
  }
}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/dataProcess/FFMProcessor.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.dataProcess
import org.apache.spark.rdd.RDD

object FFMProcessor {
  /**
   * Parses text lines in libffm format: `label field:feature:value field:feature:value ...`.
   *
   * @param data      raw input lines
   * @param separator regular expression (as used by `String.split`) between the label and
   *                  the feature triples; defaults to a single space
   * @return one tuple per line: (label, field ids, feature ids, feature values), where a
   *         label greater than 0 becomes 1.0 and any other label becomes 0.0
   */
  def processData(data: RDD[String], separator: String = " "): RDD[(Double, Array[Int], Array[Long], Array[Float])] = {
    data.map { line =>
      // Leading or repeated separators produce empty tokens, which would otherwise be
      // parsed as malformed triples and fail; drop them.
      val tokens = line.split(separator).filter(_.nonEmpty)
      // Split each "field:feature:value" token once, not once per output column.
      val triples = tokens.drop(1).map(_.split(":"))
      val label = if (tokens(0).toDouble > 0) 1.0D else 0.0D
      (label, triples.map(_(0).toInt), triples.map(_(1).toLong), triples.map(_(2).toFloat))
    }
  }
}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/dataProcess/LibSVMProcessor.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.dataProcess
import org.apache.spark.rdd.RDD

object LibSVMProcessor {
  /**
   * Parses text lines in libsvm format: `label feature:value feature:value ...`.
   *
   * @param data      raw input lines
   * @param separator regular expression (as used by `String.split`) between the label and
   *                  the feature pairs; defaults to a single space
   * @return one tuple per line: (label, feature ids, feature values), where a label greater
   *         than 0 becomes 1.0 and any other label becomes 0.0
   */
  def processData(data: RDD[String], separator: String = " "): RDD[(Double, Array[Long], Array[Float])] = {
    data.map { line =>
      // Drop empty tokens produced by leading or repeated separators.
      val tokens = line.split(separator).filter(_.nonEmpty)
      // Split each "feature:value" token once, not once per output column.
      val pairs = tokens.drop(1).map(_.split(":"))
      val label = if (tokens(0).toDouble > 0) 1.0D else 0.0D
      (label, pairs.map(_(0).toLong), pairs.map(_(1).toFloat))
    }
  }
}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDense.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.analysis

import net.qihoo.xitong.xdml.feature.analysis._
import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.sql.SparkSession

/**
 * Example job: summarizes the categorical and numeric columns of a dense delimited dataset
 * with `UniversalAnalyzer.fitDense` and prints the per-column statistics.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter hasLabel numPartition
 * nullValue fineness
 */
object runUniversalAnalyzerDense {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runUniversalAnalyzerDense").getOrCreate()

    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val hasLabel = args(4).toBoolean
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val fineness = args(7).toInt

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgDataDF.show()

    val labelColName = if (hasLabel) schemaHandler.labelColName(0) else ""

    // fineness + 1 evenly spaced fractions 0, 1/fineness, ..., 1; presumably the
    // probabilities at which the analyzer computes quantiles — TODO confirm.
    val fractions = Array.tabulate(fineness + 1)(d => 1.0 * d / fineness)

    val (catFeatSummaryArray, numFeatSummaryArray) = UniversalAnalyzer.fitDense(orgDataDF, hasLabel, labelColName,
      schemaHandler.catFeatColNames, schemaHandler.numFeatColNames, fractions)

    catFeatSummaryArray.foreach { summary =>
      println("\n===========================================")
      println(summary.name)
      println("countAll: " + summary.countAll)
      println("countNotNull: " + summary.countNotNull)
      println(summary.concernedCategories.map(tup => tup._1).mkString(","))
      println("mi: " + summary.mi)
      println("auc: " + summary.auc)
    }

    numFeatSummaryArray.foreach { summary =>
      println("\n===========================================")
      println(summary.name)
      println("countAll: " + summary.countAll)
      println("countNotNull: " + summary.countNotNull)
      println("mean: " + summary.mean)
      println("std: " + summary.std)
      println("skewness: " + summary.skewness)
      println("kurtosis: " + summary.kurtosis)
      println("min: " + summary.min)
      println("max: " + summary.max)
      println("quantiles: " + summary.quantileArray.mkString(","))
      println("corr: " + summary.corr)
      println("auc: " + summary.auc)
    }

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDenseGrouped.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.analysis

import net.qihoo.xitong.xdml.feature.analysis.UniversalAnalyzer
import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.sql.SparkSession

/**
 * Example job: runs `UniversalAnalyzer.fitDenseGrouped` on the numeric columns of a dense
 * dataset, grouped by `groupColName`, and prints the reverse-pair rate and NDCG values.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 * groupColName topKStr topKDelimiter
 */
object runUniversalAnalyzerDenseGrouped {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runUniversalAnalyzerDenseGrouped").getOrCreate()

    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)

    val groupColName = args(6)
    val topKStr = args(7)
    val topKDelimiter = args(8)

    // topKDelimiter is used as a regular expression by String.split.
    val topKs = topKStr.split(topKDelimiter).map(_.toInt)

    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgDataDF.show()

    val groupFeatSummaryArray = UniversalAnalyzer.fitDenseGrouped(orgDataDF,
      schemaHandler.labelColName(0),
      groupColName,
      schemaHandler.numFeatColNames,
      topKs,
      gainFunc,
      discountFunc)
    println("-------------------metric result----------------------")
    groupFeatSummaryArray.foreach { summary =>
      println("\n=====================")
      println(summary.name)
      println("reversePairRate: " + summary.reversePairRate)
      println("ndcgs: " + summary.ndcgMap.toString())
    }
    spark.stop()
  }

  /**
   * Gain for a relevance rating: ratings 0..4 map to 0, 1, 3, 7, 15; negative or larger
   * ratings get 0.
   */
  def gainFunc(rating: Int): Double = {
    val gains = Array(0, 1, 3, 7, 15)
    // lift returns None for any index outside [0, gains.length).
    // (An alternative for large ratings would be gains.max + (rating - gains.length).)
    gains.lift(rating).getOrElse(0).toDouble
  }

  /** Discount for 0-based rank `index`: log2(index + 2). */
  def discountFunc(index: Int): Double = {
    val log2 = math.log(2)
    math.log(index + 2) / log2
  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerDenseKS.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.analysis

import net.qihoo.xitong.xdml.feature.analysis.UniversalAnalyzer
import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.sql.SparkSession

/**
 * Example job: computes Kolmogorov–Smirnov values between the numeric columns of two dense
 * datasets with `UniversalAnalyzer.fitDenseKSForNum`, and saves them one per line as a
 * single text file under `savePath`.
 *
 * Arguments: orgDataPath orgDataDelimiter orgSchemaPath orgSchemaDelimiter cmpDataPath
 * cmpDataDelimiter cmpSchemaPath cmpSchemaDelimiter numPartition nullValue withHeader savePath
 */
object runUniversalAnalyzerDenseKS {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runUniversalAnalyzerDenseKS").getOrCreate()

    val orgDataPath = args(0)
    val orgDataDelimiter = args(1)
    val orgSchemaPath = args(2)
    val orgSchemaDelimiter = args(3)
    val cmpDataPath = args(4)
    val cmpDataDelimiter = args(5)
    val cmpSchemaPath = args(6)
    val cmpSchemaDelimiter = args(7)
    val numPartition = args(8).toInt
    val nullValue = args(9)
    val withHeader = args(10).toBoolean
    val savePath = args(11)

    val orgSchemaHandler = SchemaHandler.readSchema(spark.sparkContext, orgSchemaPath, orgSchemaDelimiter)
    val cmpSchemaHandler = SchemaHandler.readSchema(spark.sparkContext, cmpSchemaPath, cmpSchemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, orgDataPath, orgDataDelimiter, orgSchemaHandler.schema, nullValue, numPartition, withHeader)
    val cmpDataDF = DataHandler.readData(spark, cmpDataPath, cmpDataDelimiter, cmpSchemaHandler.schema, nullValue, numPartition, withHeader)

    val time1 = System.currentTimeMillis()
    val ksValues = UniversalAnalyzer.fitDenseKSForNum(orgDataDF, orgSchemaHandler.numFeatColNames, cmpDataDF, cmpSchemaHandler.numFeatColNames)
    val time2 = System.currentTimeMillis()
    // An explicit message instead of println(a, b), which relied on deprecated auto-tupling.
    println(s"spent time: ${time2 - time1} ms")
    // A single slice already yields a single output file; no repartition shuffle is needed.
    spark.sparkContext.parallelize(Seq(ksValues.mkString("\n")), 1)
      .saveAsTextFile(savePath)

    spark.stop()
  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/analysis/runUniversalAnalyzerSparse.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.analysis

import net.qihoo.xitong.xdml.feature.analysis._
import net.qihoo.xitong.xdml.model.data.{LogHandler, SchemaHandler}
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.HashMap

/**
 * Example job: summarizes sparse key:value records with `UniversalAnalyzer.fitSparse`
 * and prints the per-column statistics.
 *
 * Arguments: dataPath parseType schemaPath schemaDelimiter hasLabel numPartitions
 * sparseType fineness. Only parseType "standard" is supported.
 */
object runUniversalAnalyzerSparse {

  /** Turns one raw input line into a column-name -> value map. */
  trait ParseBase extends Serializable {
    def parse(str: String): HashMap[String, String]
  }

  /**
   * Parses lines like `name1:value1,name2:value2`. Both delimiters are regular
   * expressions, as used by `String.split`.
   */
  class StandardParse(stringDelimiter: String = ",", pairDelimiter: String = ":") extends ParseBase {
    def parse(str: String): HashMap[String, String] = {
      val arrOfTuple2 = str.split(stringDelimiter).map { elem =>
        val elemSplit = elem.split(pairDelimiter)
        (elemSplit(0), elemSplit(1))
      }
      HashMap[String, String]() ++= arrOfTuple2
    }
  }

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runUniversalAnalyzerSparse").getOrCreate()

    val dataPath = args(0)
    val parseType = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val hasLabel = args(4).toBoolean
    val numPartitions = args(5).toInt
    val sparseType = args(6)
    val fineness = args(7).toInt

    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val tmpDataRDD = spark.sparkContext.textFile(dataPath, numPartitions)
    val orgDataRDD = parseType match {
      case "standard" =>
        val standardParse = new StandardParse()
        tmpDataRDD.map { str => standardParse.parse(str) }
      case _ => throw new IllegalArgumentException("parse type unknown")
    }

    val labelColName = if (hasLabel) schemaHandler.labelColName(0) else ""

    // fineness + 1 evenly spaced fractions 0, 1/fineness, ..., 1; presumably quantile
    // probabilities — TODO confirm against UniversalAnalyzer.
    val fractions = Array.tabulate(fineness + 1)(d => 1.0 * d / fineness)

    val (catFeatSummaryArray, numFeatSummaryArray) = UniversalAnalyzer.fitSparse(orgDataRDD, hasLabel,
      labelColName, schemaHandler.catFeatColNames, schemaHandler.numFeatColNames,
      sparseType, fractions)

    catFeatSummaryArray.foreach { summary =>
      println("\n===========================================")
      println(summary.name)
      println(summary.concernedCategories.map(tup => tup._1).mkString(","))
      println("mi: " + summary.mi)
      println("auc: " + summary.auc)
    }

    numFeatSummaryArray.foreach { summary =>
      println("\n===========================================")
      println(summary.name)
      println("countNotNull: " + summary.countNotNull)
      println("mean: " + summary.mean)
      println("std: " + summary.std)
      println("skewness: " + summary.skewness)
      println("kurtosis: " + summary.kurtosis)
      println("min: " + summary.min)
      println("max: " + summary.max)
      println("quantiles: " + summary.quantileArray.mkString(","))
      println("corr: " + summary.corr)
      println("auc: " + summary.auc)
    }

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runCategoryEncoder.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.process

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.feature.CategoryEncoder
import org.apache.spark.sql.SparkSession


/**
 * Example job: fits a `CategoryEncoder` (index-only, output columns suffixed "xdmlIndexed")
 * on the categorical columns, applies it, and prints the tables and timings in seconds.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 */
object runCategoryEncoder {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("testCategoryEncoder").getOrCreate()

    // data
    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    println("input table")
    orgDataDF.show(false)

    val indexer = new CategoryEncoder()
      .setInputCols(schemaHandler.catFeatColNames)
      .setOutputCols(schemaHandler.catFeatColNames.map(name => name + "xdmlIndexed"))
      .setIndexOnly(true)
      // Other available settings:
      // .setDropInputCols(true)
      // .setCategoriesReserved(5)

    val startTime = System.currentTimeMillis()
    val indexerModel = indexer.fit(orgDataDF)
    val fitTime = System.currentTimeMillis()
    // To inspect the learned categories:
    // indexerModel.labels.foreach { seq => println(seq.toArray.mkString(", ")) }
    val processedDataDF = indexerModel.transform(orgDataDF)
    println("output table")
    processedDataDF.show(false)
    println("dataCount: " + processedDataDF.count())

    val transTime = System.currentTimeMillis()
    println("time for fitting: " + (fitTime - startTime) / 1000.0)
    println("time for transforming: " + (transTime - fitTime) / 1000.0)

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runFeatureProcess.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.process

import net.qihoo.xitong.xdml.feature.process.FeatureProcessor
import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml._
import org.apache.spark.sql.SparkSession


/**
 * Example job: either fits the XDML feature-processing pipeline and saves it
 * (jobType "fit_transform") or loads a saved pipeline (jobType "transform"). It then writes
 * the processed label/feature columns in LibSVM format to `processedDataPath`.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 * jobType methodForNum ifOnehotForNum ifOnehotForCat ifOnehotForMultiCat numBuckets
 * categoriesReservedForCat categoriesReservedForMultiCat multiCatDelimiter
 * pipelineModelPath processedDataPath
 *
 * @throws IllegalArgumentException if jobType is neither "fit_transform" nor "transform"
 */
object runFeatureProcess {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runFeatureProcess").getOrCreate()

    // data
    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)

    // job type
    val jobType = args(6)

    val methodForNum = args(7)
    val ifOnehotForNum = args(8).toBoolean
    val ifOnehotForCat = args(9).toBoolean
    val ifOnehotForMultiCat = args(10).toBoolean

    val numBuckets = args(11).toInt
    val categoriesReservedForCat = args(12).toInt
    val categoriesReservedForMultiCat = args(13).toInt
    val multiCatDelimiter = args(14)

    // pipeline model path
    val pipelineModelPath = args(15)
    // processed data path
    val processedDataPath = args(16)

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    println("input table")
    orgDataDF.show(false)

    val resDF = jobType match {
      case "fit_transform" =>
        val (resDFTmp, pipelineModelTmp) = FeatureProcessor.pipelineFitTransform(orgDataDF, schemaHandler,
          methodForNum, ifOnehotForNum, ifOnehotForCat, ifOnehotForMultiCat, true,
          numBuckets, categoriesReservedForCat, categoriesReservedForMultiCat, multiCatDelimiter)
        pipelineModelTmp.write.overwrite().save(pipelineModelPath)
        resDFTmp
      case "transform" =>
        PipelineModel.load(pipelineModelPath).transform(orgDataDF)
      case _ => throw new IllegalArgumentException(s"Does not support job type: $jobType")
    }

    println("output table")
    resDF.show(false)
    DataHandler.writeLibSVMData(resDF, FeatureProcessor.labelProcessedColName,
      FeatureProcessor.featsProcessedColName, processedDataPath)

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runMultiCategoryEncoder.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.process

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.feature.MultiCategoryEncoder
import org.apache.spark.sql.SparkSession


/**
 * Example job: fits a `MultiCategoryEncoder` (index-only, output columns suffixed
 * "xdmlEncoded") on the multi-category columns, applies it, and prints the tables and
 * timings in seconds.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 * multiCatDelimiter
 */
object runMultiCategoryEncoder {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("testMultiCategoryEncoder").getOrCreate()

    // data
    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)
    val multiCatDelimiter = args(6)

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    println("input table")
    orgDataDF.show(false)

    val encoder = new MultiCategoryEncoder()
      .setInputCols(schemaHandler.multiCatFeatColNames)
      .setOutputCols(schemaHandler.multiCatFeatColNames.map(name => name + "xdmlEncoded"))
      .setDelimiter(multiCatDelimiter)
      .setIndexOnly(true)
      // Other available settings:
      // .setDropInputCols(true)

    val startTime = System.currentTimeMillis()
    val encoderModel = encoder.fit(orgDataDF)
    val fitTime = System.currentTimeMillis()
    val processedDataDF = encoderModel.transform(orgDataDF)
    println("output table")
    processedDataDF.show(false)
    println("dataCount: " + processedDataDF.count())

    val transTime = System.currentTimeMillis()
    println("time for fitting: " + (fitTime - startTime) / 1000.0)
    println("time for transforming: " + (transTime - fitTime) / 1000.0)

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runNumericBucketer.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.process

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.feature.NumericBucketer
import org.apache.spark.sql.SparkSession

/**
 * Example job: fits a `NumericBucketer` on the numeric columns with one bucket count per
 * column, applies it, and prints the tables and timings in seconds.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 * numBucketsArray (comma-separated, one entry per numeric column)
 */
object runNumericBucketer {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runNumericBucketer").getOrCreate()

    // data
    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)

    val numBucketsArray = args(6).split(",").map(_.toInt)

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    println("input table")
    orgDataDF.show(false)

    require(numBucketsArray.length == schemaHandler.numFeatColNames.length, "invalid numBucketsArray size")

    val numericBucketer = new NumericBucketer()
      .setInputCols(schemaHandler.numFeatColNames)
      .setOutputCols(schemaHandler.numFeatColNames.map(name => name + "xdmlBucketed"))
      .setNumBucketsArray(numBucketsArray)
      .setIndexOnly(false)
      .setOutputSparse(true)
      .setDropInputCols(true)

    val startTime = System.currentTimeMillis()
    val numericBucketerModel = numericBucketer.fit(orgDataDF)
    val fitTime = System.currentTimeMillis()

    val processedDataDF = numericBucketerModel.transform(orgDataDF)
    println("output table")
    processedDataDF.show(false)
    println("dataCount: " + processedDataDF.count())

    val transTime = System.currentTimeMillis()
    println("time for fitting: " + (fitTime - startTime) / 1000.0)
    println("time for transforming: " + (transTime - fitTime) / 1000.0)

    spark.stop()
  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/feature/process/runNumericStandardizer.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.feature.process

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.feature.NumericStandardizer
import org.apache.spark.sql.SparkSession


/**
 * Example job: fits a `NumericStandardizer` on the numeric columns (output columns suffixed
 * "XDMLStandardized"), applies it, and prints the tables and timings in seconds.
 *
 * Arguments: dataPath dataDelimiter schemaPath schemaDelimiter numPartition nullValue
 */
object runNumericStandardizer {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession.builder()
      .appName("runNumericStandardizer").getOrCreate()

    // data
    val dataPath = args(0)
    val dataDelimiter = args(1)
    val schemaPath = args(2)
    val schemaDelimiter = args(3)
    val numPartition = args(4).toInt
    val nullValue = args(5)

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgDataDF = DataHandler.readData(spark, dataPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    println("input table")
    orgDataDF.show(false)

    val numericStandardizer = new NumericStandardizer()
      .setInputCols(schemaHandler.numFeatColNames)
      .setOutputCols(schemaHandler.numFeatColNames.map(name => name + "XDMLStandardized"))

    val startTime = System.currentTimeMillis()
    val numericStandardizerModel = numericStandardizer.fit(orgDataDF)
    val fitTime = System.currentTimeMillis()

    val processedDataDF = numericStandardizerModel.transform(orgDataDF)
    println("output table")
    processedDataDF.show(false)
    println("dataCount: " + processedDataDF.count())

    val transTime = System.currentTimeMillis()
    println("time for fitting: " + (fitTime - startTime) / 1000.0)
    println("time for transforming: " + (transTime - fitTime) / 1000.0)

    spark.stop()

  }

}
// --------------------------------------------------------------------------------
// /src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2ODRF.scala:
// --------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.model.supervised.{H2ODRF, H2ODRFModel}
import org.apache.spark.sql.SparkSession

/**
 * Example job: trains an `H2ODRF` model on dense data and saves it to `modelPath`.
 * It then reloads the model and, when a validation path was supplied, prints the
 * transformed schema and the predictions on the validation data.
 *
 * Arguments: trainPath validPath (blank = no validation set) dataDelimiter schemaPath
 * schemaDelimiter numPartition nullValue modelPath maxDepth numTrees maxBinsForCat
 * maxBinsForNum minInstancesPerNode categoricalEncodingScheme histogramType distribution
 * scoreTreeInterval
 */
object runFromDenseDataToXDMLH2ODRF {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2ODRF")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val maxDepth = args(8).toInt
    val numTrees = args(9).toInt
    val maxBinsForCat = args(10).toInt
    val maxBinsForNum = args(11).toInt
    val minInstancesPerNode = args(12).toInt
    val categoricalEncodingScheme = args(13)
    val histogramType = args(14)
    val distribution = args(15)
    val scoreTreeInterval = args(16).toInt

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF = DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // The validation set is optional: a blank validPath means there is none.
    val orgValidDataDF =
      if (validPath.trim.length > 0) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    val h2oDRF = new H2ODRF()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMaxDepth(maxDepth)
      .setNumTrees(numTrees)
      .setMaxBinsForCat(maxBinsForCat)
      .setMaxBinsForNum(maxBinsForNum)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setCategoricalEncodingScheme(categoricalEncodingScheme.trim())
      .setHistogramType(histogramType.trim())
      .setDistribution(distribution.trim())
      .setScoreTreeInterval(scoreTreeInterval)

    val drfModel = h2oDRF.fit(orgTrainDataDF)
    drfModel.write.overwrite().save(modelPath)
    // Reload the saved model from modelPath; the reloaded copy is used for prediction.
    val drfModel2 = H2ODRFModel.load(modelPath)

    // Only score when a validation set exists; otherwise there is no DataFrame to use
    // (the original code dereferenced a null DataFrame here).
    orgValidDataDF.foreach { validDataDF =>
      drfModel2.transformSchema(validDataDF.schema).printTreeString()
      val predDataDF = drfModel2.transform(validDataDF)
      predDataDF.show(false)
    }

    spark.stop()
  }
}
// --------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OGBM.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.model.supervised.{H2OGBM, H2OGBMModel}
import org.apache.spark.sql.SparkSession

/**
 * Example job: trains an [[H2OGBM]] (gradient boosting machine) on dense delimited data, saves it
 * to `modelPath`, reloads it and, when a validation path is given, scores the validation set.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = no validation), 2 data delimiter,
 *   3 schema path, 4 schema delimiter, 5 number of partitions, 6 null-value token, 7 model path,
 *   8 maxDepth, 9 numTrees, 10 maxBinsForCat, 11 maxBinsForNum, 12 minInstancesPerNode,
 *   13 categoricalEncodingScheme, 14 histogramType, 15 learnRate, 16 learnRateAnnealing,
 *   17 distribution, 18 scoreTreeInterval.
 */
object runFromDenseDataToXDMLH2OGBM {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OGBM")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)

    // model path
    val modelPath = args(7)

    // hyperparameter
    val maxDepth = args(8).toInt
    val numTrees = args(9).toInt
    val maxBinsForCat = args(10).toInt
    val maxBinsForNum = args(11).toInt
    val minInstancesPerNode = args(12).toInt
    val categoricalEncodingScheme = args(13)
    val histogramType = args(14)
    val learnRate = args(15).toDouble
    val learnRateAnnealing = args(16).toDouble
    val distribution = args(17)
    val scoreTreeInterval = args(18).toInt

    // data input
    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF = DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // Option instead of null: an empty validation path must not lead to a NullPointerException below
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    // string options are trimmed, as in the DRF example, so stray whitespace in CLI args is harmless
    val h2oGBM = new H2OGBM()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMaxDepth(maxDepth)
      .setNumTrees(numTrees)
      .setMaxBinsForCat(maxBinsForCat)
      .setMaxBinsForNum(maxBinsForNum)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setCategoricalEncodingScheme(categoricalEncodingScheme.trim())
      .setHistogramType(histogramType.trim())
      .setLearnRate(learnRate)
      .setLearnRateAnnealing(learnRateAnnealing)
      .setDistribution(distribution.trim())
      .setScoreTreeInterval(scoreTreeInterval)

    val gbmModel = h2oGBM.fit(orgTrainDataDF)
    gbmModel.write.overwrite().save(modelPath)
    // reload to demonstrate that the persisted model round-trips
    val gbmModel2 = H2OGBMModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        gbmModel2.transformSchema(validDF.schema).printTreeString()
        val predDataDF = gbmModel2.transform(validDF)
        predDataDF.show(false)
      case None =>
        println("no validation data given, skip prediction")
    }

    spark.stop()
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OGLM.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.model.supervised.{H2OGLM, H2OGLMModel}
import org.apache.spark.sql.SparkSession


/**
 * Example job: trains an [[H2OGLM]] (generalized linear model) on dense delimited data, saves it
 * to `modelPath`, reloads it and, when a validation path is given, scores the validation set.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = no validation), 2 data delimiter,
 *   3 schema path, 4 schema delimiter, 5 number of partitions, 6 null-value token, 7 model path,
 *   8 family, 9 maxIter, 10 alpha, 11 lambda, 12 missingValuesHandling, 13 solver (see note),
 *   14 standardization, 15 fitIntercept.
 */
object runFromDenseDataToXDMLH2OGLM {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OGLM")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val family = args(8)
    val maxIter = args(9).toInt
    val alpha = args(10).toDouble
    val lambda = args(11).toDouble
    val missingValuesHandling = args(12)
    // args(13) is the solver name. NOTE(review): it was parsed but never passed to H2OGLM;
    // confirm whether H2OGLM exposes a solver setter and wire it up here.
    val standardization = args(14).toBoolean
    val fitIntercept = args(15).toBoolean


    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF = DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // Option instead of null: an empty validation path must not lead to a NullPointerException below
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    val h2oGLM = new H2OGLM()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setFamily(family)
      .setMaxIter(maxIter)
      .setAlpha(alpha)
      .setLambda(lambda)
      .setFitIntercept(fitIntercept)
      .setStandardization(standardization)
      .setMissingValueHandling(missingValuesHandling)

    val glmModel = h2oGLM.fit(orgTrainDataDF)
    glmModel.write.overwrite().save(modelPath)
    // reload to demonstrate that the persisted model round-trips
    val glmModel2 = H2OGLMModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        val predDF = glmModel2.transform(validDF)
        predDF.show(false)
      case None =>
        println("no validation data given, skip prediction")
    }
    spark.stop()
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromDenseDataToXDMLH2OMLP.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.{DataHandler, LogHandler, SchemaHandler}
import org.apache.spark.ml.model.supervised.{H2OMLP, H2OMLPModel}
import org.apache.spark.sql.SparkSession

/**
 * Example job: trains an [[H2OMLP]] (multi-layer perceptron) on dense delimited data, saves it to
 * `modelPath`, reloads it and, when a validation path is given, scores the validation set.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = no validation), 2 data delimiter,
 *   3 schema path, 4 schema delimiter, 5 number of partitions, 6 null-value token, 7 model path,
 *   8 missingValuesHandling, 9 categoricalEncodingScheme, 10 distribution,
 *   11 comma-separated hidden layer sizes, 12 activation, 13 epochs, 14 learning rate,
 *   15 momentumStart, 16 momentumStable, 17 l1, 18 l2,
 *   19 comma-separated hidden dropout ratios (may be empty), 20 elasticAveraging, 21 standardization.
 */
object runFromDenseDataToXDMLH2OMLP {

  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromDenseDataToXDMLH2OMLP")
      .getOrCreate()

    // data
    val trainPath = args(0)
    val validPath = args(1)
    val dataDelimiter = args(2)
    val schemaPath = args(3)
    val schemaDelimiter = args(4)
    val numPartition = args(5).toInt
    val nullValue = args(6)
    val modelPath = args(7)

    // hyperparameter
    val missingValuesHandling = args(8)
    val categoricalEncodingScheme = args(9)
    val distribution = args(10)
    val hidden = args(11).split(",").map(_.toInt)
    val activation = args(12)
    val epochs = args(13).toDouble
    val learningRate = args(14).toDouble
    val momentumStart = args(15).toDouble
    val momentumStable = args(16).toDouble
    val l1 = args(17).toDouble
    val l2 = args(18).toDouble
    val hiddenDropoutRatiosText = args(19)
    val elasticAveraging = args(20).toBoolean
    val standardization = args(21).toBoolean

    val schemaHandler = SchemaHandler.readSchema(spark.sparkContext, schemaPath, schemaDelimiter)
    val orgTrainDataDF = DataHandler.readData(spark, trainPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition)
    orgTrainDataDF.show()
    // Option instead of null: an empty validation path must not lead to a NullPointerException below
    val orgValidDataDF =
      if (validPath.trim.nonEmpty) {
        Some(DataHandler.readData(spark, validPath, dataDelimiter, schemaHandler.schema, nullValue, numPartition))
      } else {
        None
      }

    // an empty/blank argument means "no dropout ratios"
    val hiddenDropoutRatios =
      if (hiddenDropoutRatiosText.trim.nonEmpty) hiddenDropoutRatiosText.split(",").map(_.toDouble)
      else Array.empty[Double]

    val h2oMLP = new H2OMLP()
      .setLabelCol(schemaHandler.labelColName(0))
      .setCatFeatColNames(schemaHandler.catFeatColNames)
      .setIgnoreFeatColNames(schemaHandler.otherColNames)
      .setMissingValueHandling(missingValuesHandling)
      .setCategoricalEncodingScheme(categoricalEncodingScheme)
      .setDistribution(distribution)
      .setHidden(hidden)
      .setActivation(activation)
      .setEpochs(epochs)
      .setLearnRate(learningRate)
      .setMomentumStart(momentumStart)
      .setMomentumStable(momentumStable)
      .setL1(l1)
      .setL2(l2)
      .setHiddenDropoutRatios(hiddenDropoutRatios)
      .setElasticAveraging(elasticAveraging)
      .setStandardization(standardization)

    // NOTE(review): the second argument is presumably an optional validation frame for H2O's
    // in-training scoring; null keeps the existing behaviour — confirm against H2OMLP.fit.
    val mlpModel = h2oMLP.fit(orgTrainDataDF, null)
    mlpModel.write.overwrite().save(modelPath)
    // reload to demonstrate that the persisted model round-trips
    val mlpModel2 = H2OMLPModel.load(modelPath)

    orgValidDataDF match {
      case Some(validDF) =>
        mlpModel2.transformSchema(validDF.schema).printTreeString()
        val predDataDF = mlpModel2.transform(validDF)
        predDataDF.show(false)
      case None =>
        println("no validation data given, skip prediction")
    }

    spark.stop()
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLLinearScopeModel.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.LogHandler
import org.apache.spark.ml.model.supervised.LinearScope
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.data.DataProcessor
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Example job: trains a [[LinearScope]] binary classifier on LibSVM data and prints the AUROC on
 * the validation set, or on the training set when no validation path is given.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = evaluate on train), 2 numFeatures,
 *   3 numPartitions, 4 oneBased, 5 needSort, 6 rescaleBinaryLabel, 7 stepSize, 8 numIterations,
 *   9 regParam, 10 elasticNetParam, 11 factor, 12 fitIntercept, 13 lossType, 14 posWeight.
 */
object runFromLibSVMDataToXDMLLinearScopeModel {
  def main(args: Array[String]): Unit = {
    LogHandler.avoidLog()
    val spark = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLLinearScopeModel")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val regParam = args(9).toDouble
    val elasticNetParam = args(10).toDouble
    val factor = args(11).toDouble
    val fitIntercept = args(12).toBoolean
    val lossType = args(13)
    val posWeight = args(14).toDouble

    val trainData = DataProcessor.readLibSVMDataAsDF(spark, trainPath, numFeatures, numPartitions,
      oneBased, needSort, rescaleBinaryLabel)
    trainData.show(2)

    val validData =
      if (validPath.trim.nonEmpty) {
        Some(DataProcessor.readLibSVMDataAsDF(spark, validPath, numFeatures, numPartitions,
          oneBased, needSort, rescaleBinaryLabel))
      } else {
        None
      }
    validData.foreach(_.show(2))

    val convergenceTol = 0.2
    val linearScope = new LinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setMaxIter(numIterations)
      .setStepSize(stepSize)
      .setLossFunc(lossType)
      .setRegParam(regParam)
      .setElasticNetParam(elasticNetParam)
      .setFactor(factor)
      .setFitIntercept(fitIntercept)
      .setPosWeight(posWeight)
      .setConvergenceTol(convergenceTol)
      .setNumPartitions(numPartitions)

    val model = linearScope.fit(trainData)
    /***
     * also can set initial weights of w
     val initialWeights = Vectors.zeros(numFeatures)
     val model = linearScope.fit(trainData,initialWeights)*/
    // the raw score is written to "prediction" and is what the AUROC below is computed from
    model.setRawPredictionCol("prediction")
      .setProbabilityCol("probability")

    // evaluate on the validation set when one was given, otherwise on the training set
    val (evalData, phase) = validData match {
      case Some(df) => (df, "Validating")
      case None => (trainData, "Training")
    }
    val df = model.transform(evalData)
    df.show(5)
    val predRDD = df.select(col("prediction"), col("label").cast(DoubleType)).rdd.map {
      case Row(pred: Double, label: Double) => (pred, label)
    }
    val metrics = new BinaryClassificationMetrics(predRDD)
    println("Model " + phase + " AUROC: " + metrics.areaUnderROC())
    spark.stop()
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLOVR.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.LogHandler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.data.DataProcessor
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.model.supervised.OVRLinearScope

/**
 * Example job: trains a one-vs-rest [[OVRLinearScope]] on LibSVM data and, when a validation path
 * is given, reports accuracy (more than two classes) or AUROC (two classes) on it.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = no evaluation), 2 numFeatures,
 *   3 numPartitions, 4 oneBased, 5 needSort, 6 rescaleBinaryLabel, 7 stepSize,
 *   8 numIterations, 9 numClasses, 10 lossType.
 */
object runFromLibSVMDataToXDMLOVR {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLOVR")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val numClasses = args(9).toInt
    val lossType = args(10)

    val trainDF = DataProcessor.readLibSVMDataAsDF(spark, trainPath, numFeatures, numPartitions,
      oneBased, needSort, rescaleBinaryLabel)
    val validDF =
      if (validPath.trim.nonEmpty) {
        Some(DataProcessor.readLibSVMDataAsDF(spark, validPath, numFeatures, numPartitions,
          oneBased, needSort, rescaleBinaryLabel))
      } else {
        None
      }

    val ovrls = new OVRLinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setStepSize(stepSize)
      .setMaxIter(numIterations)
      .setNumClasses(numClasses)
      .setNumPartitions(numPartitions)
      .setLossFunc(lossType)

    val ovrlsModel = ovrls.fit(trainDF)

    validDF.foreach { valid =>

      val df = ovrlsModel.transform(valid)
      df.show(false)

      if (numClasses > 2) {
        /*********** MulticlassClassificationEvaluator *************/
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val acc = evaluator.evaluate(df)
        println("Model Validating accuracy: " + acc)
      } else {
        /*********** BinaryClassificationEvaluator *************/
        val evaluator = new BinaryClassificationEvaluator()
          .setLabelCol("label")
          .setRawPredictionCol("rawPrediction")
          .setMetricName("areaUnderROC")
        val auroc = evaluator.evaluate(df)
        println("Model Validating AUROC: " + auroc)
      }

    }

    spark.stop()

  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/analysis/model/runFromLibSVMDataToXDMLSR.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.analysis.model

import net.qihoo.xitong.xdml.model.data.LogHandler
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.model.supervised.MultiLinearScope
import org.apache.spark.sql.SparkSession
import
org.apache.spark.sql.data.DataProcessor

/**
 * Example job: trains a multi-class [[MultiLinearScope]] on LibSVM data and, when a validation
 * path is given, reports accuracy (more than two classes) or AUROC (two classes) on it.
 *
 * Positional arguments:
 *   0 train path, 1 validation path (empty string = no evaluation), 2 numFeatures,
 *   3 numPartitions, 4 oneBased, 5 needSort, 6 rescaleBinaryLabel, 7 stepSize,
 *   8 numIterations, 9 numClasses.
 */
object runFromLibSVMDataToXDMLSR {

  def main(args: Array[String]): Unit = {

    LogHandler.avoidLog()

    val spark = SparkSession
      .builder()
      .appName("runFromLibSVMDataToXDMLSR")
      .getOrCreate()

    val trainPath = args(0)
    val validPath = args(1)
    val numFeatures = args(2).toInt
    val numPartitions = args(3).toInt
    val oneBased = args(4).toBoolean
    val needSort = args(5).toBoolean
    val rescaleBinaryLabel = args(6).toBoolean
    val stepSize = args(7).toDouble
    val numIterations = args(8).toInt
    val numClasses = args(9).toInt

    val trainDF = DataProcessor.readLibSVMDataAsDF(spark, trainPath, numFeatures, numPartitions,
      oneBased, needSort, rescaleBinaryLabel)
    val validDF =
      if (validPath.trim.nonEmpty) {
        Some(DataProcessor.readLibSVMDataAsDF(spark, validPath, numFeatures, numPartitions,
          oneBased, needSort, rescaleBinaryLabel))
      } else {
        None
      }

    val mls = new MultiLinearScope()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setStepSize(stepSize)
      .setMaxIter(numIterations)
      .setNumClasses(numClasses)
      .setNumPartitions(numPartitions)

    val mlsModel = mls.fit(trainDF)

    validDF.foreach { valid =>

      val df = mlsModel.transform(valid)
      df.show(false)

      if (numClasses > 2) {
        /*********** MulticlassClassificationEvaluator *************/
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val acc = evaluator.evaluate(df)
        println("Model Validating accuracy: " + acc)
      } else {
        /*********** BinaryClassificationEvaluator *************/
        val evaluator = new BinaryClassificationEvaluator()
          .setLabelCol("label")
          .setRawPredictionCol("rawPrediction")
          .setMetricName("areaUnderROC")
        val auroc = evaluator.evaluate(df)
        println("Model Validating AUROC: " + auroc)
      }

    }

    spark.stop()

  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/DCASGDTest.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.ml

import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
import net.qihoo.xitong.xdml.ml.LogisticRegressionWithDCASGD
import net.qihoo.xitong.xdml.ps.PS
import net.qihoo.xitong.xdml.utils.XDMLException
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Example job: logistic regression trained with DC-ASGD on the parameter server.
 *
 * All settings come from [[JobConfiguration]] (read from the SparkContext configuration). The job
 * type is one of TRAIN, PREDICT or INCREMENT_TRAIN (case-insensitive).
 *
 * @throws XDMLException if the kudu master is unset, the job type is unknown, or a required
 *                       model / result path is missing for the chosen job type
 */
object DCASGDTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("DCASGD-Test")
    val sc = new SparkContext(conf)
    //read spark config
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val coff = jobConf.getOrElse(JobConfiguration.DC_ASGD_COFF, "0.1").toFloat
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "DC-ASGD")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    //read data
    val rawData = sc.textFile(dataPath)
    val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    val psConf = new PSConfiguration()
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(2)
      .setKuduMaster(kuduMaster)
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        // the result path is required too: new Path("") would fail only after prediction ran
        if (modelPath.equals("") || resultPath.equals(""))
          throw new XDMLException("Predict job must have a model path and result path!")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.equals(""))
          throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegressionWithDCASGD(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)
      .setDcAsgdCoff(coff)
    psConf.getJobType match {
      case JobType.TRAIN =>
        //start train
        println("Start Train...")
        model.fit(data)
        println("Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      case JobType.PREDICT =>
        //start predict
        println("Start Predict...")
        // cached: the predictions are consumed three times below (save + two counts)
        val result = model.predict(data).cache()
        val path = new Path(resultPath)
        val fs = path.getFileSystem(sc.hadoopConfiguration)
        if (fs.exists(path)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fs.delete(path, true)
        }
        result.saveAsTextFile(resultPath)
        // a score above 0.5 is counted as predicting label 1
        val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble
        println("right rate is : " + rightRate)
        result.unpersist()
      case JobType.INCREMENT_TRAIN =>
        //start train
        println("Start Increment Train...")
        model.fit(data)
        println("Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
    }
    sc.stop()
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/FFMTest.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.ml

import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
import net.qihoo.xitong.xdml.dataProcess.FFMProcessor
import net.qihoo.xitong.xdml.ml.FieldawareFactorizationMachine
import net.qihoo.xitong.xdml.ps.PS
import net.qihoo.xitong.xdml.utils.XDMLException
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Example job: field-aware factorization machine trained on the parameter server. Each PS entry
 * holds `rank * field` floats.
 *
 * All settings come from [[JobConfiguration]] (read from the SparkContext configuration). The job
 * type is one of TRAIN, PREDICT or INCREMENT_TRAIN (case-insensitive).
 *
 * @throws XDMLException if the kudu master is unset, the job type is unknown, or a model path is
 *                       missing for PREDICT / INCREMENT_TRAIN
 */
object FFMTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("FFM-Test")
    val sc = new SparkContext(conf)
    //read spark config
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val rank = jobConf.getOrElse(JobConfiguration.FFM_RANK, "1").toInt
    val field = jobConf.getOrElse(JobConfiguration.FFM_FIELD, "1").toInt
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "FFM")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    //read data
    val rawData = sc.textFile(dataPath)
    val data = FFMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    val psConf = new PSConfiguration()
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(rank * field)
      .setKuduMaster(kuduMaster)
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        if (modelPath.equals(""))
          throw new XDMLException("Predict job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.equals(""))
          throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new FieldawareFactorizationMachine(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setLearningRate(learningRate)
      .setField(field)
      .setRank(rank)
    psConf.getJobType match {
      case JobType.TRAIN =>
        //start train
        println("Start Train...")
        model.fit(data)
        println("Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      case JobType.PREDICT =>
        //start predict
        println("Start Predict...")
        // NOTE(review): prediction is not implemented in this example — this branch only logs.
        // Wire up FieldawareFactorizationMachine's predict API (if it has one) and write the
        // output to JobConfiguration.PREDICT_RESULT_PATH, as the LR/FTRL/DC-ASGD examples do.
      case JobType.INCREMENT_TRAIN =>
        //start train
        println("Start Increment Train...")
        model.fit(data)
        println("Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
    }
    sc.stop()
  }
}



--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/FTRLTest.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.ml

import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType}
import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor
import net.qihoo.xitong.xdml.ml.LogisticRegressionWithFTRL
import net.qihoo.xitong.xdml.ps.PS
import net.qihoo.xitong.xdml.utils.XDMLException
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Example job: logistic regression trained with FTRL on the parameter server. Each PS entry holds
 * three floats.
 *
 * All settings come from [[JobConfiguration]] (read from the SparkContext configuration). The job
 * type is one of TRAIN, PREDICT or INCREMENT_TRAIN (case-insensitive).
 *
 * @throws XDMLException if the kudu master is unset, the job type is unknown, or a required
 *                       model / result path is missing for the chosen job type
 */
object FTRLTest {
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
      .setAppName("FTRL-Test")
    val sc = new SparkContext(conf)
    //read spark config
    val jobConf = new JobConfiguration().readJobConfig(sc)
    val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "")
    val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt
    val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt
    val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt
    val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train")
    val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "")
    val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt
    val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt
    val alpha = jobConf.getOrElse(JobConfiguration.FTRL_ALPHA, "1.0").toFloat
    val beta = jobConf.getOrElse(JobConfiguration.FTRL_BETA, "1.0").toFloat
    val lambda1 = jobConf.getOrElse(JobConfiguration.FTRL_LAMBDA1, "1.0").toFloat
    val lambda2 = jobConf.getOrElse(JobConfiguration.FTRL_LAMBDA2, "1.0").toFloat
    val forceSparse = jobConf.getOrElse(JobConfiguration.FTRL_FROCESPARSE, "false").toBoolean
    val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "FTRL")
    val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "")
    val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "")
    val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ")
    if (kuduMaster.equals("")) {
      throw new XDMLException("kudu master must be set!")
    }
    //read data
    val rawData = sc.textFile(dataPath)
    val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum)
    val psConf = new PSConfiguration()
      .setPsDataType(PSDataType.FLOAT_ARRAY)
      .setPsDataLength(3)
      .setPsTableName(tableName)
      .setHzClusterNum(hzClusterNum)
      .setHzPartitionNum(hzPartitionNum)
      .setForceSparse(forceSparse)
      .setKuduMaster(kuduMaster)
    jobType.toUpperCase match {
      case "TRAIN" =>
        psConf.setJobType(JobType.TRAIN)
      case "PREDICT" =>
        // the result path is required too: new Path("") would fail only after prediction ran
        if (modelPath.equals("") || resultPath.equals(""))
          throw new XDMLException("Predict job must have a model path and result path!")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.PREDICT)
      case "INCREMENT_TRAIN" =>
        if (modelPath.equals(""))
          throw new XDMLException("INCREMENT_TRAIN job must have a model path")
        psConf.setPredictModelPath(modelPath)
        psConf.setJobType(JobType.INCREMENT_TRAIN)
      case _ =>
        throw new XDMLException("Wrong job type")
    }
    val ps = PS.getInstance(sc, psConf)
    val model = new LogisticRegressionWithFTRL(ps)
      .setIterNum(iterNum)
      .setBatchSize(batchSize)
      .setAlpha(alpha)
      .setBeta(beta)
      .setLambda1(lambda1)
      .setLambda2(lambda2)
    psConf.getJobType match {
      case JobType.TRAIN =>
        //start train
        println("Start Train...")
        model.fit(data)
        println("Save the model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
      case JobType.PREDICT =>
        //start predict
        println("Start Predict...")
        // cached: the predictions are consumed three times below (save + two counts)
        val result = model.predict(data).cache()
        val path = new Path(resultPath)
        val fs = path.getFileSystem(sc.hadoopConfiguration)
        if (fs.exists(path)) {
          println(s"Result Path ${resultPath} existed, delete.")
          fs.delete(path, true)
        }
        result.saveAsTextFile(resultPath)
        // a score above 0.5 is counted as predicting label 1
        val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble
        println("right rate is : " + rightRate)
        result.unpersist()
      case JobType.INCREMENT_TRAIN =>
        //start train
        println("Start Increment Train...")
        model.fit(data)
        println("Save the new model to path :" + modelPath)
        ps.saveModel(sc, modelPath)
    }
    sc.stop()
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/example/ml/LRTest.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.example.ml 2 | 3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType} 4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor 5 | import net.qihoo.xitong.xdml.ml.LogisticRegression 6 | import net.qihoo.xitong.xdml.ps.PS 7 | import net.qihoo.xitong.xdml.utils.XDMLException 8 | import org.apache.hadoop.fs.Path 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | 11 | object LRTest { 12 | def main(args: Array[String]): Unit = { 13 | val conf = new SparkConf() 14 | .setAppName("LR-Test") 15 | val sc = new SparkContext(conf) 16 | //read spark config 17 | val jobConf = new JobConfiguration().readJobConfig(sc) 18 | val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "") 19 | val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt 20 | val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt 21 | val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt 22 | val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train") 23 | val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "") 24 | val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat 25 | val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt 26 | val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, "271").toInt 27 | val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "LR") 28 | val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "") 29 | val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "") 30 | val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ") 31 | if (kuduMaster.equals("")) { 32 | throw new XDMLException("kudu master must be set!") 33 | } 34 | //read data 35 | val rawData = sc.textFile(dataPath) 36 | val data = LibSVMProcessor.processData(rawData, 
split).coalesce(dataPartitionNum) 37 | val psConf = new PSConfiguration() 38 | .setPsDataType(PSDataType.FLOAT) 39 | .setPsDataLength(1) 40 | .setPsTableName(tableName) 41 | .setHzClusterNum(hzClusterNum) 42 | .setHzPartitionNum(hzPartitionNum) 43 | .setKuduMaster(kuduMaster) 44 | if (jobType.toUpperCase.equals("TRAIN")) { 45 | psConf.setJobType(JobType.TRAIN) 46 | } else if (jobType.toUpperCase.equals("PREDICT")) { 47 | if (modelPath.equals("") || resultPath.equals("")) 48 | throw new XDMLException("Predict job must have a model path and result path!") 49 | else 50 | psConf.setPredictModelPath(modelPath) 51 | psConf.setJobType(JobType.PREDICT) 52 | } else if (jobType.toUpperCase.equals("INCREMENT_TRAIN")) { 53 | if (modelPath.equals("")) 54 | throw new XDMLException("INCREMENT_TRAIN job must have a model path") 55 | else 56 | psConf.setPredictModelPath(modelPath) 57 | psConf.setJobType(JobType.INCREMENT_TRAIN) 58 | } else { 59 | throw new XDMLException("Wrong job type") 60 | } 61 | val ps = PS.getInstance(sc, psConf) 62 | val model = new LogisticRegression(ps) 63 | .setIterNum(iterNum) 64 | .setBatchSize(batchSize) 65 | .setLearningRate(learningRate) 66 | psConf.getJobType match { 67 | case JobType.TRAIN => { 68 | //start train 69 | println("Start Train...") 70 | val info = model.fit(data) 71 | println("Save the model to path :" + modelPath) 72 | ps.saveModel(sc, modelPath) 73 | } 74 | case JobType.PREDICT => { 75 | //start predict 76 | println("Start Predict...") 77 | val result = model.predict(data) 78 | val path = new Path(resultPath) 79 | val fs = path.getFileSystem(sc.hadoopConfiguration) 80 | if (fs.exists(path)) { 81 | println(s"Result Path ${resultPath} existed, delete.") 82 | fs.delete(path, true) 83 | } 84 | result.saveAsTextFile(resultPath) 85 | val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble 86 | println("right rate is : " + rightRate) 87 | } 88 | case JobType.INCREMENT_TRAIN => { 89 | 
//start train 90 | println("Start Increment Train...") 91 | val info = model.fit(data) 92 | println(s"Save the new model to path :" + modelPath) 93 | ps.saveModel(sc, modelPath) 94 | } 95 | } 96 | sc.stop() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/example/ml/MomentumTest.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.example.ml 2 | 3 | import net.qihoo.xitong.xdml.conf.{JobConfiguration, JobType, PSConfiguration, PSDataType} 4 | import net.qihoo.xitong.xdml.dataProcess.LibSVMProcessor 5 | import net.qihoo.xitong.xdml.ml.LogisticRegressionWithMomentum 6 | import net.qihoo.xitong.xdml.ps.PS 7 | import net.qihoo.xitong.xdml.utils.XDMLException 8 | import org.apache.hadoop.fs.Path 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | 11 | object MomentumTest { 12 | def main(args: Array[String]): Unit = { 13 | val conf = new SparkConf() 14 | .setAppName("Momentum-Test") 15 | val sc = new SparkContext(conf) 16 | //read spark config 17 | val jobConf = new JobConfiguration().readJobConfig(sc) 18 | val dataPath = jobConf.getOrElse(JobConfiguration.DATA_PATH, "") 19 | val dataPartitionNum = jobConf.getOrElse(JobConfiguration.TRAIN_DATA_PARTITION, "50").toInt 20 | val iterNum = jobConf.getOrElse(JobConfiguration.TRAIN_ITER_NUM, "1").toInt 21 | val batchSize = jobConf.getOrElse(JobConfiguration.BATCH_SIZE, "5000").toInt 22 | val jobType = jobConf.getOrElse(JobConfiguration.JOB_TYPE, "train") 23 | val modelPath = jobConf.getOrElse(JobConfiguration.MODEL_PATH, "") 24 | val learningRate = jobConf.getOrElse(JobConfiguration.LEARNING_RATE, "0.01").toFloat 25 | val momCoff = jobConf.getOrElse(JobConfiguration.MOMENTUM_COFF, "0.1").toFloat 26 | val hzClusterNum = jobConf.getOrElse(JobConfiguration.HZ_CLUSTER_NUM, "50").toInt 27 | val hzPartitionNum = jobConf.getOrElse(JobConfiguration.HZ_PARTITION_NUM, 
"271").toInt 28 | val tableName = System.getProperty("user.name") + "_" + jobConf.getOrElse(JobConfiguration.TABLE_NAME, "Momentum") 29 | val kuduMaster = jobConf.getOrElse(JobConfiguration.KUDU_MASTER, "") 30 | val resultPath = jobConf.getOrElse(JobConfiguration.PREDICT_RESULT_PATH, "") 31 | val split = jobConf.getOrElse(JobConfiguration.DATA_SPLIT, " ") 32 | if (kuduMaster.equals("")) { 33 | throw new XDMLException("kudu master must be set!") 34 | } 35 | //read data 36 | val rawData = sc.textFile(dataPath) 37 | val data = LibSVMProcessor.processData(rawData, split).coalesce(dataPartitionNum) 38 | val psConf = new PSConfiguration() 39 | .setPsTableName(tableName) 40 | .setHzClusterNum(hzClusterNum) 41 | .setHzPartitionNum(hzPartitionNum) 42 | .setPsDataType(PSDataType.FLOAT_ARRAY) 43 | .setPsDataLength(2) 44 | .setKuduMaster(kuduMaster) 45 | if (jobType.toUpperCase.equals("TRAIN")) { 46 | psConf.setJobType(JobType.TRAIN) 47 | } else if (jobType.toUpperCase.equals("PREDICT")) { 48 | if (modelPath.equals("")) 49 | throw new XDMLException("Predict job must have a model path") 50 | else 51 | psConf.setPredictModelPath(modelPath) 52 | psConf.setJobType(JobType.PREDICT) 53 | } else if (jobType.toUpperCase.equals("INCREMENT_TRAIN")) { 54 | if (modelPath.equals("")) 55 | throw new XDMLException("INCREMENT_TRAIN job must have a model path") 56 | else 57 | psConf.setPredictModelPath(modelPath) 58 | psConf.setJobType(JobType.INCREMENT_TRAIN) 59 | } else { 60 | throw new XDMLException("Wrong job type") 61 | } 62 | val ps = PS.getInstance(sc, psConf) 63 | val model = new LogisticRegressionWithMomentum(ps) 64 | .setIterNum(iterNum) 65 | .setBatchSize(batchSize) 66 | .setLearningRate(learningRate) 67 | .setMomemtumCoff(momCoff) 68 | psConf.getJobType match { 69 | case JobType.TRAIN => { 70 | //start train 71 | println("Start Train...") 72 | val info = model.fit(data) 73 | println(s"Save the model to path :" + modelPath) 74 | ps.saveModel(sc, modelPath) 75 | } 76 | case 
JobType.PREDICT => { 77 | //start predict 78 | println("Start Predict...") 79 | val result = model.predict(data) 80 | val path = new Path(resultPath) 81 | val fs = path.getFileSystem(sc.hadoopConfiguration) 82 | if (fs.exists(path)) { 83 | println(s"Result Path ${resultPath} existed, delete.") 84 | fs.delete(path, true) 85 | } 86 | result.saveAsTextFile(resultPath) 87 | val rightRate = result.filter(x => (if (x._1 > 0.5) 1 else 0) == x._2).count().toDouble / result.count().toDouble 88 | println("right rate is : " + rightRate) 89 | } 90 | case JobType.INCREMENT_TRAIN => { 91 | //start train 92 | println("Start Increment Train...") 93 | val info = model.fit(data) 94 | println(s"Save the new model to path :" + modelPath) 95 | ps.saveModel(sc, modelPath) 96 | } 97 | } 98 | sc.stop() 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/feature/process/ColumnEliminator.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.feature.process 2 | 3 | import org.apache.spark.ml.Transformer 4 | import org.apache.spark.ml.param.ParamMap 5 | import org.apache.spark.ml.param.shared.HasInputCols 6 | import org.apache.spark.ml.util._ 7 | import org.apache.spark.sql.types.{StructField, StructType} 8 | import org.apache.spark.sql.{DataFrame, Dataset} 9 | 10 | import scala.collection.mutable.ArrayBuffer 11 | 12 | class ColumnEliminator(override val uid: String) 13 | extends Transformer with HasInputCols with DefaultParamsWritable { 14 | 15 | def this() = this(Identifiable.randomUID("ColumnEliminator")) 16 | 17 | def setInputCols(value: Array[String]): this.type = set(inputCols, value) 18 | 19 | override def transform(dataset: Dataset[_]): DataFrame = { 20 | transformSchema(dataset.schema, logging = true) 21 | dataset.drop($(inputCols):_*).toDF() 22 | } 23 | 24 | override def transformSchema(schema: StructType): StructType = { 25 | 
val inputColNames = $(inputCols) 26 | for (colName <- inputColNames) { 27 | if (!schema.fieldNames.contains(colName)) { 28 | logWarning("missing column " + colName) 29 | } 30 | } 31 | val fieldsEliminated = new ArrayBuffer[StructField]() 32 | for (structField <- schema.fields) { 33 | if (!inputColNames.contains(structField.name)) { 34 | fieldsEliminated += structField 35 | } 36 | } 37 | StructType(fieldsEliminated.toArray) 38 | } 39 | 40 | override def copy(extra: ParamMap): ColumnEliminator = defaultCopy(extra) 41 | 42 | } 43 | 44 | 45 | object ColumnEliminator extends DefaultParamsReadable[ColumnEliminator] { 46 | 47 | override def load(path: String): ColumnEliminator = super.load(path) 48 | } 49 | 50 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/feature/process/ColumnRenamer.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.feature.process 2 | 3 | import org.apache.spark.ml.Transformer 4 | import org.apache.spark.ml.param.ParamMap 5 | import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols} 6 | import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} 7 | import org.apache.spark.sql._ 8 | import org.apache.spark.sql.types.{StructField, StructType} 9 | 10 | class ColumnRenamer(override val uid: String) 11 | extends Transformer with HasInputCols with HasOutputCols with DefaultParamsWritable { 12 | 13 | def this() = this(Identifiable.randomUID("ColumnRenamer")) 14 | 15 | def setInputCols(value: Array[String]): this.type = set(inputCols, value) 16 | 17 | def setOutputCols(value: Array[String]): this.type = set(outputCols, value) 18 | 19 | override def transform(dataset: Dataset[_]): DataFrame = { 20 | transformSchema(dataset.schema, logging = true) 21 | var df = dataset 22 | val existingZipNewColNames = $(inputCols).zip($(outputCols)) 23 | for(en <- existingZipNewColNames) { 
24 | df = df.withColumnRenamed(en._1, en._2) 25 | } 26 | df.toDF() 27 | } 28 | 29 | override def transformSchema(schema: StructType): StructType = { 30 | val inputColNames = $(inputCols) 31 | val inputZipOutputColNames = $(inputCols).zip($(outputCols)).toMap 32 | StructType(schema.fields.map{ structField => 33 | if (inputColNames.contains(structField.name)) { 34 | StructField(inputZipOutputColNames(structField.name), structField.dataType, structField.nullable) 35 | } else { 36 | structField 37 | } 38 | }) 39 | } 40 | 41 | override def copy(extra: ParamMap): ColumnRenamer = defaultCopy(extra) 42 | 43 | } 44 | 45 | object ColumnRenamer extends DefaultParamsReadable[ColumnRenamer] { 46 | 47 | override def load(path: String): ColumnRenamer = super.load(path) 48 | } 49 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/feature/process/PipelinePatch.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml 2 | 3 | import org.apache.spark.sql.DataFrame 4 | 5 | import scala.collection.mutable.ArrayBuffer 6 | 7 | object PipelinePatch { 8 | 9 | def pipelineFitTransform(df: DataFrame, pipeline: Pipeline): (DataFrame, PipelineModel) = { 10 | // check validity of schema 11 | pipeline.transformSchema(df.schema) 12 | val stagesClone = pipeline.getStages 13 | var dfTarget = df 14 | val transformers = new ArrayBuffer[Transformer]() 15 | stagesClone.foreach{ stage => { 16 | stage match { 17 | case estimator: Estimator[_] => { 18 | val e2t = estimator.fit(dfTarget) 19 | dfTarget = e2t.transform(dfTarget) 20 | transformers += e2t 21 | } 22 | case transformer: Transformer => { 23 | dfTarget = transformer.transform(dfTarget) 24 | transformers += transformer 25 | } 26 | case _ => 27 | throw new IllegalArgumentException( 28 | s"Does not support stage $stage of type ${stage.getClass}") 29 | } 30 | }} 31 | (dfTarget, new PipelineModel(pipeline.uid, 
transformers.toArray).setParent(pipeline)) 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/linalg/BLAS.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.linalg 2 | 3 | import scala.collection.mutable 4 | 5 | object BLAS extends Serializable { 6 | 7 | /** 8 | * dot(x, y) 9 | */ 10 | 11 | def dot(x: Map[Long, Float], y: Map[Long, Float]): Float = { 12 | var result = 0f 13 | x.keys.foreach(k => result += x(k) * y.getOrElse(k, 0f)) 14 | result 15 | } 16 | 17 | def dot(x: (Array[Long], Array[Float]), y: Map[Long, Float]): Float = { 18 | var result = 0f 19 | (0 until x._1.size).foreach(id => result += x._2(id) * y.getOrElse(x._1(id), 0f)) 20 | result 21 | } 22 | 23 | def axpy(a: Double, x: (Array[Long], Array[Float]), y: mutable.Map[Long, Float]): Unit = { 24 | val xIndices = x._1 25 | val xValues = x._2 26 | val nnz = xIndices.length 27 | if (Math.abs(a - 1) < 1e-6) { 28 | var k = 0 29 | while (k < nnz) { 30 | if (y.contains(xIndices(k))) { 31 | y(xIndices(k)) += xValues(k) 32 | } else { 33 | y += (xIndices(k) -> xValues(k)) 34 | } 35 | k += 1 36 | } 37 | } else { 38 | var k = 0 39 | while (k < nnz) { 40 | if (y.contains(xIndices(k))) { 41 | y(xIndices(k)) += (a * xValues(k)).toFloat 42 | } else { 43 | y += (xIndices(k) -> (a * xValues(k)).toFloat) 44 | } 45 | k += 1 46 | } 47 | } 48 | } 49 | 50 | 51 | def sigmoid(value: Float): Double = { 52 | if (value > 30) 53 | 1f 54 | else if (value < (-30)) 55 | 1e-6 56 | else { 57 | val ex = Math.pow(2.718281828, value) 58 | (ex / (1.0 + ex)) 59 | } 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/mapstore/PSMapStore.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.mapstore 2 | 3 | import 
java.nio.ByteBuffer 4 | import java.{lang, util} 5 | 6 | import com.hazelcast.core.MapStore 7 | import net.qihoo.xitong.xdml.conf.PSDataType.PSDataType 8 | import net.qihoo.xitong.xdml.conf.{PSConfiguration, PSDataType} 9 | import org.apache.kudu.Schema 10 | import org.apache.kudu.client.KuduScanner.KuduScannerBuilder 11 | import org.apache.kudu.client.SessionConfiguration.FlushMode 12 | import org.apache.kudu.client._ 13 | 14 | import scala.collection.JavaConversions._ 15 | import scala.util.Random 16 | 17 | class PSMapStore(psConf: PSConfiguration, tableName: String) extends MapStore[Long, Array[Byte]] { 18 | var client: KuduClient = _ 19 | var table: KuduTable = _ 20 | var session: KuduSession = _ 21 | var schema: Schema = _ 22 | //pull num count 23 | var count: Long = 0L 24 | var maxHzCacheSizePerPartition: Long = _ 25 | //random number generator 26 | var random: Random = _ 27 | var kuduCols: util.ArrayList[String] = _ 28 | var scanBuilder: KuduScannerBuilder = _ 29 | //V's bytes array size 30 | var byteArraySize: Int = _ 31 | //initial value when pull count < cache size or kudu table doesn't have the key 32 | var initValue: ByteBuffer = _ 33 | //ps conf data type 34 | var valueDataType: PSDataType = _ 35 | var valueDataLength: Int = _ 36 | // whether need random initial values 37 | var needRandInit: Boolean = _ 38 | var randomMin:Int = _ 39 | var randomMax:Int = _ 40 | var kuduForceSparse:Boolean = _ 41 | 42 | init() 43 | //initialize all variables and init value 44 | def init(): Unit = { 45 | client = new KuduClient.KuduClientBuilder(psConf.getKuduMaster).build() 46 | session = client.newSession() 47 | table = client.openTable(tableName) 48 | scanBuilder = client.newScannerBuilder(table) 49 | schema = table.getSchema 50 | session.setFlushMode(FlushMode.AUTO_FLUSH_BACKGROUND) 51 | session.setIgnoreAllDuplicateRows(false) 52 | session.setMutationBufferSpace(5000) 53 | kuduCols = new util.ArrayList[String] 54 | kuduCols.add(psConf.getKuduKeyColName) 55 | 
kuduCols.add(psConf.getKuduValueColName) 56 | random = new Random() 57 | valueDataType = psConf.getPsDataType 58 | valueDataLength = psConf.getPsDataLength 59 | byteArraySize = PSDataType.sizeOf(valueDataType) * valueDataLength 60 | println("kudu bytes size: " + byteArraySize) 61 | needRandInit = psConf.getKuduRandomInit 62 | randomMin = psConf.getKuduRandomMin 63 | randomMax = psConf.getKuduRandomMax 64 | initValue = ByteBuffer.allocate(byteArraySize) 65 | maxHzCacheSizePerPartition = psConf.getHzMaxCacheSizePerPartition 66 | kuduForceSparse = psConf.getForceSparse 67 | //need random initial value 68 | if (needRandInit) { 69 | valueDataType match { 70 | case PSDataType.FLOAT => initValue.putFloat(random.nextFloat() * (randomMax - randomMin) + randomMin) 71 | case PSDataType.DOUBLE => initValue.putDouble(random.nextDouble() * (randomMax - randomMin) + randomMin) 72 | case PSDataType.FLOAT_ARRAY => { 73 | for (index <- 0 until valueDataLength) { 74 | initValue.putFloat(index << 2,random.nextFloat()*(randomMax - randomMin) + randomMin) 75 | } 76 | } 77 | case PSDataType.DOUBLE_ARRAY => { 78 | for (index <- 0 until valueDataLength) { 79 | initValue.putDouble(index << 3,random.nextDouble()*(randomMax - randomMin) + randomMin) 80 | } 81 | } 82 | } 83 | } 84 | } 85 | 86 | override def store(key: Long, value: Array[Byte]): Unit = { 87 | val op = table.newUpsert() 88 | op.getRow.addLong(0, key) 89 | op.getRow.addBinary(1, value) 90 | session.apply(op) 91 | } 92 | 93 | override def storeAll(map: util.Map[Long, Array[Byte]]): Unit = { 94 | for ((k: Long, v: Array[Byte]) <- map if v.length == byteArraySize) { 95 | this.store(k, v) 96 | } 97 | } 98 | 99 | override def load(key: Long): Array[Byte] = { 100 | if (count < (maxHzCacheSizePerPartition - 2)) { 101 | count += 1 102 | } else { 103 | val idsPredicate = KuduPredicate.newComparisonPredicate(schema.getColumn(psConf.getKuduKeyColName), KuduPredicate.ComparisonOp.EQUAL, key) 104 | val scanner = 
scanBuilder.setProjectedColumnNames(kuduCols).addPredicate(idsPredicate).build() 105 | while (scanner.hasMoreRows) { 106 | val results = scanner.nextRows() 107 | while (results.hasNext) { 108 | val value = results.next().getBinaryCopy(1) 109 | if (value.length == byteArraySize) { 110 | return value 111 | } else { 112 | return initValue.array() 113 | } 114 | } 115 | } 116 | scanner.close() 117 | } 118 | initValue.array() 119 | } 120 | 121 | override def loadAll(keys: util.Collection[Long]): util.Map[Long, Array[Byte]] = { 122 | val initCapacity: Int = (keys.size() / 0.75f + 1).toInt 123 | val loadResultMap = new util.HashMap[Long, Array[Byte]](initCapacity) 124 | if (count < (maxHzCacheSizePerPartition - 2)) { 125 | for (key <- keys) 126 | loadResultMap.put(key, initValue.array()) 127 | count += keys.size() 128 | } else { 129 | val idsPredicate = KuduPredicate.newInListPredicate(schema.getColumn(psConf.getKuduKeyColName), keys.toList) 130 | val scanner = client.newScannerBuilder(table).setProjectedColumnNames(kuduCols).addPredicate(idsPredicate).build() 131 | while (scanner.hasMoreRows) { 132 | val results = scanner.nextRows() 133 | while (results.hasNext) { 134 | val result = results.next() 135 | val id = result.getLong(0) 136 | val weight = result.getBinaryCopy(1) 137 | if (weight.length == byteArraySize) { 138 | loadResultMap.put(id, weight) 139 | } else { 140 | loadResultMap.put(id, initValue.array()) 141 | } 142 | } 143 | } 144 | scanner.close() 145 | } 146 | loadResultMap 147 | } 148 | 149 | override def loadAllKeys(): lang.Iterable[Long] = {null} 150 | 151 | override def delete(key: Long): Unit = {} 152 | 153 | override def deleteAll(keys: util.Collection[Long]): Unit = {} 154 | } 155 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/ml/FieldawareFactorizationMachine.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.ml 
2 | 3 | import java.io.Serializable 4 | import java.util 5 | 6 | import bloomfilter.mutable.BloomFilter 7 | import net.qihoo.xitong.xdml.optimization.FFM 8 | import net.qihoo.xitong.xdml.ps.{PS, PSClient} 9 | import net.qihoo.xitong.xdml.updater.FFMUpdater 10 | import org.apache.spark.rdd.RDD 11 | 12 | import scala.collection.JavaConversions._ 13 | 14 | class FieldawareFactorizationMachine(ps:PS) extends Serializable { 15 | private var iterNum = 1 16 | private var batchSize = 50 17 | private var learningRate:Float = 0.01F 18 | private var rank = 1 19 | private var field = 1 20 | def fit(data: RDD[(Double, Array[Int], Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) = { 21 | FieldawareFactorizationMachine.runMiniBatchSGD(ps,data, iterNum, batchSize, rank, field, learningRate) 22 | } 23 | def setIterNum(iterNum:Int):this.type = { 24 | this.iterNum = iterNum 25 | this 26 | } 27 | 28 | def setBatchSize(batchSize:Int):this.type = { 29 | this.batchSize = batchSize 30 | this 31 | } 32 | 33 | def setLearningRate(lr:Float):this.type ={ 34 | this.learningRate = lr 35 | this 36 | } 37 | 38 | def setRank(rank:Int):this.type ={ 39 | this.rank = rank 40 | this 41 | } 42 | 43 | def setField(field:Int):this.type ={ 44 | this.field = field 45 | this 46 | } 47 | } 48 | 49 | 50 | object FieldawareFactorizationMachine{ 51 | def runMiniBatchSGD(ps:PS, 52 | data: RDD[(Double, Array[Int], Array[Long], Array[Float])], 53 | iter: Int, 54 | batchSize: Int, 55 | rank: Int, 56 | fieldNum: Int, 57 | learningRate: Float 58 | ): (Map[Int, Double], (Long, Long)) = { 59 | println(s"Start the hz cluster ...") 60 | var trainInfo: Map[Int, Double] = Map() 61 | var posNum = 0L 62 | var negNum = 0L 63 | for (iterNum <- 0 until iter) { 64 | val result = data.mapPartitions { iter => 65 | val updater = new FFMUpdater() 66 | .setLearningRate(learningRate) 67 | val client = new PSClient[Long, Array[Float]](ps) 68 | .setUpdater(updater) 69 | var count = 0 70 | val weightIndex = new 
util.HashSet[Long]() 71 | val localData = new util.ArrayList[(Double, Array[Int], Array[Long], Array[Float])]() 72 | var batchId = 0 73 | var flag = false 74 | var totalLoss = 0.0 75 | var totalCount = 0L 76 | var posNum = 0L 77 | var negNum = 0L 78 | var batchLoss = 0.0 79 | val expectedElements = 100000000 80 | val falsePositiveRate = 0.1 81 | val bf = BloomFilter[Long](expectedElements, falsePositiveRate) 82 | var partBatchSize = batchSize 83 | 84 | while (iter.hasNext) { 85 | val dataLine = iter.next() 86 | localData.add(dataLine) 87 | dataLine._3.foreach(x => 88 | if (bf.mightContain(x)) { 89 | weightIndex.add(x) 90 | } else { 91 | bf.add(x) 92 | } 93 | ) 94 | if (iterNum == 0) { 95 | if (dataLine._1 > 0.0) 96 | posNum += 1 97 | else 98 | negNum += 1 99 | } 100 | count += 1 101 | if (count == partBatchSize || !iter.hasNext) { 102 | println(s"weightIndex.size: ${weightIndex.size()}") 103 | 104 | val localV = client.pull(weightIndex) 105 | println("localV: " + localV.map{ 106 | case(k,v) => k + "-->" + v.mkString(",") 107 | }) 108 | 109 | var totalGrad: Map[Long, Array[Float]] = Map() 110 | localData.foreach { dataIter => 111 | // train the data, get the grad and loss 112 | val (grad, loss) = FFM.train( 113 | fieldNum, 114 | rank, 115 | dataIter._1, 116 | (dataIter._2, dataIter._3, dataIter._4), 117 | localV 118 | ) 119 | println("loss: " + loss) 120 | // update the inc 121 | grad.foreach { case (k, v) => 122 | val gradArray = totalGrad.getOrElse(k, Array.fill(rank * fieldNum + fieldNum)(0.0f)) 123 | for(index <- 0 until rank*fieldNum) { 124 | //一组batch的梯度累加 125 | gradArray(index) = gradArray(index) + v(index) 126 | } 127 | totalGrad += (k -> gradArray) 128 | } 129 | batchLoss += loss 130 | } 131 | println(s"TotalLoss: ${batchLoss} . SampleNum: ${localData.size()}. 
Average Loss: ${batchLoss / count}") 132 | val us = System.nanoTime() 133 | client.push(totalGrad) 134 | 135 | val ue = System.nanoTime() 136 | println(s"update weight time consume is: ${(ue - us) / 1e9}") 137 | batchId += 1 138 | flag = true 139 | totalLoss += batchLoss 140 | totalCount += count 141 | count = 0 142 | batchLoss = 0.0 143 | localData.clear() 144 | weightIndex.clear() 145 | } 146 | 147 | if (batchId % 50 == 0 && flag) { 148 | val s = System.nanoTime() 149 | println(s"epoch ${iterNum} sync begin.") 150 | Thread.sleep(2000) 151 | val e = System.nanoTime() 152 | println(s"epoch ${iterNum} sync end, consume is: ${(e - s) / 1e9}") 153 | flag = false 154 | } 155 | } 156 | Iterator((totalLoss, totalCount, posNum, negNum)) 157 | } 158 | if (iterNum == 0) { 159 | val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4)) 160 | posNum = totalResult._3 161 | negNum = totalResult._4 162 | trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2)) 163 | } else { 164 | val totalResult = result.map(x => (x._1, x._2)).reduce((x, y) => (x._1 + y._1, x._2 + y._2)) 165 | trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2)) 166 | } 167 | } 168 | Thread.sleep(60000) 169 | (trainInfo, (posNum, negNum)) 170 | } 171 | }
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegression.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.ml

import java.io.Serializable
import java.util

import net.qihoo.xitong.xdml.optimization.BinaryClassify
import net.qihoo.xitong.xdml.ps.{PS, PSClient}
import net.qihoo.xitong.xdml.updater.LRUpdater
import org.apache.spark.rdd.RDD

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
 * Logistic regression whose weights live on a parameter server and are trained with
 * mini-batch SGD (server-side update rule: [[LRUpdater]]).
 *
 * Samples are `(label, featureIds, featureValues)`; a label `> 0.0` is counted as positive.
 *
 * @param ps parameter server holding the weights
 */
class LogisticRegression(ps: PS) extends Serializable {
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F

  /**
   * Trains on `data`.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count))
   */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) = {
    LogisticRegression.runMiniBatchSGD(ps, data, iterNum, batchSize, learningRate)
  }

  /** Scores `data` with the current weights; returns `(prediction, label)` pairs. */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] = {
    LogisticRegression.predict(ps, data, batchSize)
  }

  /** Number of passes [[fit]] makes over the data (default 1). */
  def setIterNum(iterNum: Int): this.type = {
    this.iterNum = iterNum
    this
  }

  /** Samples per pull/push round trip to the parameter server (default 50). */
  def setBatchSize(batchSize: Int): this.type = {
    this.batchSize = batchSize
    this
  }

  /** Learning rate passed to the [[LRUpdater]] (default 0.01). */
  def setLearningRate(lr: Float): this.type = {
    this.learningRate = lr
    this
  }
}

object LogisticRegression {

  /**
   * Runs `numIterations` passes of mini-batch SGD over `data`.
   *
   * In each partition samples are buffered until `batchSize` of them (or the end of the
   * partition) is reached; the weights of every feature in the buffer are pulled once,
   * the per-sample gradients are summed, and the batch-averaged gradient is pushed.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count) as counted during the first pass)
   */
  def runMiniBatchSGD(ps: PS,
                      data: RDD[(Double, Array[Long], Array[Float])],
                      numIterations: Int,
                      batchSize: Int,
                      learningRate: Float
                     ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Float](ps)
        val updater = new LRUpdater()
          .setLearningRate(learningRate)
        client.setUpdater(updater)
        var (count, totalCount, partPosNum, partNegNum, totalLoss) = (0L, 0L, 0L, 0L, 0.0D)
        val weightIndex = new util.HashSet[Long]()
        val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
        try {
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Labels are only counted during the first pass.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partPosNum += 1 else partNegNum += 1
            }
            count += 1
            if (count == batchSize || !iter.hasNext) {
              val localMap = client.pull(weightIndex).toMap
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localMap)
                // Sum the per-sample gradients of this batch.
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              // Push the batch-averaged gradient.
              val pushMap = totalGrad.map { case (k, v) => (k, v / count) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
        } finally {
          // Release the client even when training this partition fails.
          client.shutDown()
        }
        Iterator((totalLoss, totalCount, partPosNum, partNegNum))
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Later passes report zero label counts; keep the first-pass totals instead of
      // overwriting them with zeros.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }
    (trainInfo, (posNum, negNum))
  }

  /**
   * Scores `data` with weights pulled from `ps` in batches of `batchSize` samples.
   *
   * @return `(prediction, label)` for every sample, in input order within each partition
   */
  def predict(ps: PS,
              data: RDD[(Double, Array[Long], Array[Float])],
              batchSize: Int
             ): RDD[(Double, Double)] = {
    data.mapPartitions { iter =>
      val client = new PSClient[Long, Float](ps)
      val predictions = mutable.ArrayBuffer[(Double, Double)]()
      var count = 0
      val weightIndex = new util.HashSet[Long]()
      val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
      try {
        while (iter.hasNext) {
          val dataLine = iter.next()
          localData.add(dataLine)
          dataLine._2.foreach(weightIndex.add)
          count += 1
          if (count == batchSize || !iter.hasNext) {
            // Convert the pulled weights once per batch, not once per sample.
            val localWeight = client.pull(weightIndex).toMap
            // Predict the label of every buffered sample.
            localData.foreach { x =>
              predictions += ((BinaryClassify.predict((x._2, x._3), localWeight), x._1))
            }
            count = 0
            localData.clear()
            weightIndex.clear()
          }
        }
      } finally {
        client.shutDown()
      }
      predictions.iterator
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegressionWithDCASGD.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.ml

import java.io.Serializable
import java.util

import net.qihoo.xitong.xdml.optimization.BinaryClassify
import net.qihoo.xitong.xdml.ps.{PS, PSClient}
import net.qihoo.xitong.xdml.updater.DCASGDUpdater
import org.apache.spark.rdd.RDD

import scala.collection.JavaConversions._

/**
 * Logistic regression trained on a parameter server with a [[DCASGDUpdater]].
 *
 * @param ps parameter server holding the weights
 */
class LogisticRegressionWithDCASGD(ps: PS) extends Serializable {
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F
  private var dcAsgdCoff: Float = 0.1F

  /**
   * Trains on `data`.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count))
   */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) = {
    LogisticRegressionWithDCASGD.train(ps, data, iterNum, batchSize, learningRate, dcAsgdCoff)
  }

  /**
   * Scores `data`; returns `(prediction, label)` pairs.
   *
   * NOTE(review): training here uses `PSClient[Long, Array[Float]]`, while
   * `LogisticRegression.predict` pulls through `PSClient[Long, Float]`; confirm the
   * parameter server converts between the two value types.
   */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] = {
    LogisticRegression.predict(ps, data, batchSize)
  }

  /** Number of passes [[fit]] makes over the data (default 1). */
  def setIterNum(iterNum: Int): this.type = {
    this.iterNum = iterNum
    this
  }

  /** Samples per pull/push round trip to the parameter server (default 50). */
  def setBatchSize(batchSize: Int): this.type = {
    this.batchSize = batchSize
    this
  }

  /** Learning rate passed to the [[DCASGDUpdater]] (default 0.01). */
  def setLearningRate(lr: Float): this.type = {
    this.learningRate = lr
    this
  }

  /** Coefficient passed to `DCASGDUpdater.setCoff` (default 0.1). */
  def setDcAsgdCoff(coff: Float): this.type = {
    this.dcAsgdCoff = coff
    this
  }
}

object LogisticRegressionWithDCASGD {

  /**
   * Runs `numIterations` passes of mini-batch training over `data`; the server-side
   * update rule is a [[DCASGDUpdater]] configured with `learningRate` and `coff`.
   *
   * Each parameter-server value is an array whose element 0 is used as the weight; the
   * batch-averaged gradient is pushed as a one-element array.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count) as counted during the first pass)
   */
  def train(ps: PS,
            data: RDD[(Double, Array[Long], Array[Float])],
            numIterations: Int,
            batchSize: Int,
            learningRate: Float,
            coff: Float
           ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Array[Float]](ps)
        val updater = new DCASGDUpdater()
          .setLearningRate(learningRate)
          .setCoff(coff)
        client.setUpdater(updater)
        var (count, totalCount, partPosNum, partNegNum, totalLoss) = (0L, 0L, 0L, 0L, 0.0)
        val weightIndex = new util.HashSet[Long]()
        val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
        try {
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Labels are only counted during the first pass.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partPosNum += 1 else partNegNum += 1
            }
            count += 1
            if (count == batchSize || !iter.hasNext) {
              // Extract the scalar weights once per batch rather than once per sample.
              val localWeights = client.pull(weightIndex).toMap.map { case (k, v) => (k, v(0)) }
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localWeights)
                // Sum the per-sample gradients of this batch.
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              val pushMap = totalGrad.map { case (k, v) => (k, Array(v / count)) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
        } finally {
          // Release the client even when training this partition fails.
          client.shutDown()
        }
        Iterator((totalLoss, totalCount, partPosNum, partNegNum))
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Later passes report zero label counts; keep the first-pass totals.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }
    (trainInfo, (posNum, negNum))
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/ml/LogisticRegressionWithMomentum.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.ml

import java.io.Serializable
import java.util

import net.qihoo.xitong.xdml.ps.{PS, PSClient}
import net.qihoo.xitong.xdml.updater.MomLRUpdater
import net.qihoo.xitong.xdml.optimization.BinaryClassify
import org.apache.spark.rdd.RDD

import scala.collection.JavaConversions._

/**
 * Logistic regression trained on a parameter server with a [[MomLRUpdater]].
 *
 * @param ps parameter server holding the weights
 */
class LogisticRegressionWithMomentum(ps: PS) extends Serializable {
  private var iterNum = 1
  private var batchSize = 50
  private var learningRate: Float = 0.01F
  private var momentumCoff: Float = 0.1F

  /**
   * Trains on `data`.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count))
   */
  def fit(data: RDD[(Double, Array[Long], Array[Float])]): (Map[Int, Double], (Long, Long)) = {
    LogisticRegressionWithMomentum.runMomentum(ps, data, iterNum, batchSize, learningRate, momentumCoff)
  }

  /**
   * Scores `data`; returns `(prediction, label)` pairs.
   *
   * NOTE(review): training here uses `PSClient[Long, Array[Float]]`, while
   * `LogisticRegression.predict` pulls through `PSClient[Long, Float]`; confirm the
   * parameter server converts between the two value types.
   */
  def predict(data: RDD[(Double, Array[Long], Array[Float])]): RDD[(Double, Double)] = {
    LogisticRegression.predict(ps, data, batchSize)
  }

  /** Number of passes [[fit]] makes over the data (default 1). */
  def setIterNum(iterNum: Int): this.type = {
    this.iterNum = iterNum
    this
  }

  /** Samples per pull/push round trip to the parameter server (default 50). */
  def setBatchSize(batchSize: Int): this.type = {
    this.batchSize = batchSize
    this
  }

  /** Learning rate passed to the [[MomLRUpdater]] (default 0.01). */
  def setLearningRate(lr: Float): this.type = {
    this.learningRate = lr
    this
  }

  /** Coefficient passed to `MomLRUpdater.setCoff` (default 0.1). Name kept for compatibility. */
  def setMomemtumCoff(coff: Float): this.type = {
    this.momentumCoff = coff
    this
  }

  /** Correctly spelled alias of [[setMomemtumCoff]]. */
  def setMomentumCoff(coff: Float): this.type = setMomemtumCoff(coff)
}

object LogisticRegressionWithMomentum {

  /**
   * Runs `numIterations` passes of mini-batch training over `data`; the server-side
   * update rule is a [[MomLRUpdater]] configured with `learningRate` and `coff`.
   *
   * Each parameter-server value is an array whose element 0 is used as the weight; the
   * batch-averaged gradient is pushed as a one-element array.
   *
   * @return (1-based pass number -> total loss / sample count of that pass,
   *         (positive count, negative count) as counted during the first pass)
   */
  def runMomentum(ps: PS,
                  data: RDD[(Double, Array[Long], Array[Float])],
                  numIterations: Int,
                  batchSize: Int,
                  learningRate: Float,
                  coff: Float
                 ): (Map[Int, Double], (Long, Long)) = {
    var trainInfo: Map[Int, Double] = Map()
    var posNum: Long = 0
    var negNum: Long = 0
    for (iterNum <- 0 until numIterations) {
      val result = data.mapPartitions { iter =>
        val client = new PSClient[Long, Array[Float]](ps)
        val update = new MomLRUpdater()
          .setLearningRate(learningRate)
          .setCoff(coff)
        client.setUpdater(update)
        var (count, totalCount, partPosNum, partNegNum, totalLoss) = (0L, 0L, 0L, 0L, 0.0)
        val weightIndex = new util.HashSet[Long]()
        val localData = new util.ArrayList[(Double, Array[Long], Array[Float])]()
        try {
          while (iter.hasNext) {
            val dataLine = iter.next()
            localData.add(dataLine)
            dataLine._2.foreach(weightIndex.add)
            // Labels are only counted during the first pass.
            if (iterNum == 0) {
              if (dataLine._1 > 0.0) partPosNum += 1 else partNegNum += 1
            }
            count += 1
            if (count == batchSize || !iter.hasNext) {
              // Extract the scalar weights once per batch rather than once per sample.
              val localWeights = client.pull(weightIndex).toMap.map { case (k, v) => (k, v(0)) }
              val totalGrad = new util.HashMap[Long, Float]((weightIndex.size / 0.75f + 1).toInt)
              localData.foreach { dataIter =>
                val (grad, loss) = BinaryClassify.train(
                  dataIter._1,
                  (dataIter._2, dataIter._3),
                  localWeights)
                // Sum the per-sample gradients of this batch.
                grad.foreach { case (id, v) =>
                  totalGrad.put(id, totalGrad.getOrElse(id, 0F) + v.toFloat)
                }
                totalLoss += loss
              }
              val pushMap = totalGrad.map { case (k, v) => (k, Array(v / count)) }.toMap
              client.push(pushMap)
              totalCount += count
              count = 0
              localData.clear()
              weightIndex.clear()
            }
          }
          println("loss:" + totalLoss)
        } finally {
          // Release the client even when training this partition fails.
          client.shutDown()
        }
        Iterator((totalLoss, totalCount, partPosNum, partNegNum))
      }
      val totalResult = result.reduce((x, y) => (x._1 + y._1, x._2 + y._2, x._3 + y._3, x._4 + y._4))
      // Later passes report zero label counts; keep the first-pass totals.
      if (iterNum == 0) {
        posNum = totalResult._3
        negNum = totalResult._4
      }
      trainInfo += ((iterNum + 1) -> (totalResult._1 / totalResult._2))
    }

    (trainInfo, (posNum, negNum))
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/DataHandler.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.model.data

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


/** Helpers to read and write delimited-text, LibSVM and Hive data as DataFrames. */
object DataHandler extends Serializable {

  // WARNING: there could be a problem if the number of partitions is very low
  // TODO: test big file with only one part
  // TODO: test nullValue
  /**
   * Reads delimited text (via Spark's CSV reader) with an explicit schema.
   *
   * Leading/trailing whitespace of each value is ignored and `nullValue` is read as null.
   *
   * @param numPartitions repartition to this many partitions when > 1; otherwise keep the
   *                      input partitioning
   */
  def readData(spark: SparkSession,
               dataPath: String,
               dataDelimiter: String,
               structType: StructType,
               nullValue: String,
               numPartitions: Int = -1,
               header: Boolean = false): DataFrame = {
    val df = spark.read
      .option("sep", dataDelimiter)
      .option("ignoreLeadingWhiteSpace", true)
      .option("ignoreTrailingWhiteSpace", true)
      .option("nullValue", nullValue)
      .option("header", header)
      .schema(structType)
      .csv(dataPath)
    if (numPartitions > 1) {
      df.repartition(numPartitions)
    } else {
      df
    }
  }

  /**
   * Writes `df` as delimited text (via Spark's CSV writer).
   *
   * @param numPartitions repartition to this many partitions when > 1 before writing
   */
  def writeData(df: DataFrame,
                path: String,
                delimiter: String,
                numPartitions: Int = -1,
                header: Boolean = false): Unit = {
    val dfTarget = if (numPartitions > 1) df.repartition(numPartitions) else df
    dfTarget.write.option("sep", delimiter).option("header", header).csv(path)
  }

  // WARNING: there could be a problem if the number of partitions is very low
  /**
   * Reads LibSVM data with Spark's "libsvm" source and renames its `label`/`features`
   * columns to `labelColName`/`featuresColName`.
   *
   * @param numPartitions repartition to this many partitions when > 1
   */
  def readLibSVMData(spark: SparkSession,
                     dataPath: String,
                     numFeatures: Int = -1,
                     numPartitions: Int = -1,
                     labelColName: String = "labelLibSVM",
                     featuresColName: String = "featuresLibSVM"): DataFrame = {
    val df = spark.read.format("libsvm").option("numFeatures", numFeatures).load(dataPath)
      .withColumnRenamed("label", labelColName)
      .withColumnRenamed("features", featuresColName)
    if (numPartitions > 1) {
      df.repartition(numPartitions)
    } else {
      df
    }
  }

  /**
   * Writes `label index:value ...` lines, one per row, as text files under `path`.
   *
   * Indices are written one-based, as the LibSVM format requires and as Spark's "libsvm"
   * source (used by [[readLibSVMData]]) expects; Spark vectors are zero-based. Sparse
   * vectors write their stored entries; dense vectors write every entry, zeros included.
   * Rows with a null label or a features value that is not an ml `Vector` fail with a
   * `MatchError`.
   *
   * @param numPartitions repartition to this many partitions when > 1 before writing
   */
  def writeLibSVMData(df: DataFrame,
                      labelColName: String,
                      featuresColName: String,
                      path: String,
                      numPartitions: Int = -1): Unit = {
    val dfTarget = if (numPartitions > 1) df.repartition(numPartitions) else df
    val resRDD = dfTarget.select(col(labelColName).cast(DoubleType), col(featuresColName)).rdd.map {
      case Row(label: Double, features: Vector) =>
        // (zero-based index, value) pairs to emit.
        val pairs: Array[(Int, Double)] = features match {
          case sv: SparseVector => sv.indices.zip(sv.values)
          case dv: DenseVector => dv.toArray.zipWithIndex.map { case (v, k) => (k, v) }
        }
        label.toString + " " + pairs.map { case (k, v) => (k + 1).toString + ":" + v.toString }.mkString(" ")
    }
    resRDD.saveAsTextFile(path)
  }

  // TODO: to be checked
  // WARNING: there could be a problem if the number of partitions is very low
  /**
   * Reads the Hive/catalog table `tableName`.
   *
   * @param numPartitions repartition to this many partitions when > 1
   */
  def readHiveData(spark: SparkSession,
                   tableName: String,
                   numPartitions: Int = -1): DataFrame = {
    val df = spark.read.table(tableName)
    if (numPartitions > 1) {
      df.repartition(numPartitions)
    } else {
      df
    }
  }

}

--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/LogHandler.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.model.data

import org.apache.log4j.{Level, Logger}

/** Shared log4j logger for the project and a switch to silence framework logging. */
object LogHandler extends Serializable {

  // Transient and lazy so the logger is not serialized and is re-created where it is used.
  @transient lazy val logger = Logger.getLogger("XDML")

  /** Turns off all logging from the "org" and "akka" logger hierarchies. */
  def avoidLog(): Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("akka").setLevel(Level.OFF)
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/data/SchemaHandler.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.model.data

import org.apache.spark.SparkContext
import org.apache.spark.sql.types._

import scala.io.Source


/**
 * Column roles parsed from a schema file.
 *
 * @param schema        Spark schema built from the schema file
 * @param namesAndTypes one `Array(columnName, role, ...)` per schema line
 * @param checkLabel    when true, asserts that exactly one column has the `Label` role
 */
class SchemaHandler(val schema: StructType, val namesAndTypes: Array[Array[String]], checkLabel: Boolean = false) extends Serializable {

  val keyColName: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Key)).map(x => x(0))
  val labelColName: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Label)).map(x => x(0))
  // Feature columns are those with a Cat, Num, Text or MultiCat role.
  val featColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Cat)
    || x(1).equals(SchemaHandler.Num)
    || x(1).equals(SchemaHandler.Text)
    || x(1).equals(SchemaHandler.MultiCat)).map(x => x(0))
  val catFeatColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Cat)).map(x => x(0))
  val multiCatFeatColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.MultiCat)).map(x => x(0))
  val numFeatColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Num)).map(x => x(0))
  val textFeatColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Text)).map(x => x(0))
  val otherColNames: Array[String] = namesAndTypes.filter(x => x(1).equals(SchemaHandler.Other)).map(x => x(0))
  if (checkLabel) assert(labelColName.length == 1, "one and only one label needs to be provided")

}


object SchemaHandler extends Serializable {

  val Text = "Text"
  val Num = "Num"
  val Cat = "Cat"
  val MultiCat = "MultiCat"
  val Key = "Key"
  val Label = "Label"
  val Other = "Other"

  /**
   * Reads a schema file with one `name<delimiter>role` entry per line.
   *
   * `schemaPath` is either a path readable by `sc.textFile`, or `jar://<resource>` to load
   * a classpath resource (decoded as UTF-8, like Hadoop text input). `delimiter` is a regular
   * expression (it is passed to `String.split`). Blank lines are skipped. `Num` columns
   * become nullable `DoubleType` fields; every other role becomes a nullable `StringType`.
   *
   * @throws IllegalArgumentException if the classpath resource is missing, a line has fewer
   *                                  than two fields, or a role is unknown
   */
  def readSchema(sc: SparkContext, schemaPath: String, delimiter: String): SchemaHandler = {
    val rawLines: Array[String] =
      if (schemaPath.startsWith("jar://")) {
        val jarPath = schemaPath.substring("jar://".length())
        val stream = getClass().getClassLoader().getResourceAsStream(jarPath)
        require(stream != null, s"schema resource not found on the classpath: $jarPath")
        val source = Source.fromInputStream(stream, "UTF-8")
        // Close the resource stream once all lines have been read.
        try source.getLines().toArray finally source.close()
      } else {
        sc.textFile(schemaPath, 1).collect()
      }
    // Blank lines (e.g. a trailing newline) would otherwise fail on arr(1).
    val namesAndTypes = rawLines.filter(_.trim.nonEmpty).map(line => line.split(delimiter))
    val sfArr = namesAndTypes.map { arr =>
      require(arr.length >= 2, s"malformed schema line, expected a name and a type: ${arr.mkString(",")}")
      arr(1) match {
        case Num => StructField(arr(0), DoubleType, true)
        case Cat => StructField(arr(0), StringType, true)
        case MultiCat => StructField(arr(0), StringType, true)
        case Text => StructField(arr(0), StringType, true)
        case Key => StructField(arr(0), StringType, true)
        case Label => StructField(arr(0), StringType, true)
        case Other => StructField(arr(0), StringType, true)
        case _ => throw new IllegalArgumentException("data type unknown")
      }
    }
    val schema = StructType(sfArr)
    new SchemaHandler(schema, namesAndTypes)
  }

}


--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/HingeLossFunc.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.model.loss

import org.apache.spark.mllib.linalg.Vector
import net.qihoo.xitong.xdml.model.linalg.BLAS


/**
 * Hinge loss for labels in {0, 1}, mapped to {-1, +1} via `2 * label - 1`:
 * loss = max(0, 1 - y * margin), where margin = w . x (+ intercept).
 * The returned gradient factor is `-y` inside the margin and 0 otherwise.
 */
class HingeLossFunc extends LossFunc {

  ////////////////////////// without intercept //////////////////////////////////

  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector): (Double, Double) = {
    val dotProduct = BLAS.dot(weights, data)
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      (-labelScaled, 1.0 - labelScaled * dotProduct)
    } else {
      (0.0, 0.0)
    }
  }

  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector): Double = {
    val dotProduct = BLAS.dot(weights, data)
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      -labelScaled
    } else {
      0.0
    }
  }

  override def loss(data: Vector,
                    label: Double,
                    weights: Vector): Double = {
    val dotProduct = BLAS.dot(weights, data)
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      1.0 - labelScaled * dotProduct
    } else {
      0.0
    }
  }

  override def gradientFromDot(dot: Double,
                               label: Double): Double = {
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dot) {
      -labelScaled
    } else {
      0.0
    }
  }

  ////////////////////////// with intercept //////////////////////////////////

  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val dotProduct = BLAS.dot(weights, data) + intercept
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      (-labelScaled, 1.0 - labelScaled * dotProduct)
    } else {
      (0.0, 0.0)
    }
  }

  override def gradient(data: Vector,
                        label: Double,
                        weights: Vector,
                        intercept: Double): Double = {
    val dotProduct = BLAS.dot(weights, data) + intercept
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      -labelScaled
    } else {
      0.0
    }
  }

  override def loss(data: Vector,
                    label: Double,
                    weights: Vector,
                    intercept: Double): Double = {
    val dotProduct = BLAS.dot(weights, data) + intercept
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      1.0 - labelScaled * dotProduct
    } else {
      0.0
    }
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/model/loss/L2LossFunc.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.model.loss

import net.qihoo.xitong.xdml.model.linalg.BLAS
import org.apache.spark.mllib.linalg.Vector


/**
 * Squared-error loss: loss = (margin - label)^2 / 2, where margin = w . x (+ intercept).
 * The gradient factor returned is d(loss)/d(margin) = margin - label.
 */
class L2LossFunc extends LossFunc {

  /** Prediction error for a given margin. */
  private def residual(margin: Double, label: Double): Double = margin - label

  /** Loss value for a given residual. */
  private def halfSquare(r: Double): Double = math.pow(r, 2) / 2

  ////////////////////////// without intercept //////////////////////////////////

  override def loss(data: Vector, label: Double, weights: Vector): Double =
    halfSquare(residual(BLAS.dot(weights, data), label))

  override def gradient(data: Vector, label: Double, weights: Vector): Double =
    residual(BLAS.dot(weights, data), label)

  override def gradientWithLoss(data: Vector, label: Double, weights: Vector): (Double, Double) = {
    val r = residual(BLAS.dot(weights, data), label)
    (r, halfSquare(r))
  }

  override def gradientFromDot(dot: Double, label: Double): Double =
    residual(dot, label)

  ////////////////////////// with intercept //////////////////////////////////

  override def loss(data: Vector, label: Double, weights: Vector, intercept: Double): Double =
    halfSquare(residual(BLAS.dot(weights, data) + intercept, label))

  override def gradient(data: Vector, label: Double, weights: Vector, intercept: Double): Double =
    residual(BLAS.dot(weights, data) + intercept, label)

  override def gradientWithLoss(data: Vector,
                                label: Double,
                                weights: Vector,
                                intercept: Double): (Double, Double) = {
    val r = residual(BLAS.dot(weights, data) + intercept, label)
    (r, halfSquare(r))
  }

}

--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/LogitLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | import net.qihoo.xitong.xdml.model.util.MLUtils 5 | import net.qihoo.xitong.xdml.model.linalg.BLAS 6 | 7 | 8 | class LogitLossFunc extends LossFunc { 9 | 10 | def sigmoid(z: Double): Double = { 11 | 1.0 / (1.0 + math.exp(-z)) 12 | } 13 | 14 | ////////////////////////// without intercept ////////////////////////////////// 15 | 16 | override def loss(data: Vector, 17 | label: Double, 18 | weights: Vector): Double = { 19 | val margin = - BLAS.dot(weights, data) 20 | if (label > 0.5) { 21 | // The following is equivalent to log(1 + exp(margin)) but more numerically stable. 22 | MLUtils.log1pExp(margin) 23 | } else { 24 | MLUtils.log1pExp(margin) - margin 25 | } 26 | } 27 | 28 | override def gradient(data: Vector, 29 | label: Double, 30 | weights: Vector): Double = { 31 | val margin = - BLAS.dot(weights, data) 32 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 33 | factor 34 | } 35 | 36 | override def gradientWithLoss(data: Vector, 37 | label: Double, 38 | weights: Vector): (Double, Double) = { 39 | val margin = - BLAS.dot(weights, data) 40 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 41 | val loss = if (label > 0.5) { 42 | // The following is equivalent to log(1 + exp(margin)) but more numerically stable. 
43 | MLUtils.log1pExp(margin) 44 | } else { 45 | MLUtils.log1pExp(margin) - margin 46 | } 47 | (factor, loss) 48 | } 49 | 50 | override def gradientFromDot(dot: Double, 51 | label: Double): Double = { 52 | sigmoid(dot) - label 53 | } 54 | 55 | ////////////////////////// with intercept ////////////////////////////////// 56 | 57 | override def loss(data: Vector, 58 | label: Double, 59 | weights: Vector, 60 | intercept: Double): Double = { 61 | // val margin = - BLAS.dot(weights, data) 62 | val margin = - BLAS.dot(weights, data) - intercept 63 | if (label > 0.5) { 64 | // The following is equivalent to log(1 + exp(margin)) but more numerically stable. 65 | MLUtils.log1pExp(margin) 66 | } else { 67 | MLUtils.log1pExp(margin) - margin 68 | } 69 | } 70 | 71 | override def gradient(data: Vector, 72 | label: Double, 73 | weights: Vector, 74 | intercept: Double): Double = { 75 | // val margin = - BLAS.dot(weights, data) 76 | val margin = - BLAS.dot(weights, data) - intercept 77 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 78 | factor 79 | } 80 | 81 | override def gradientWithLoss(data: Vector, 82 | label: Double, 83 | weights: Vector, 84 | intercept: Double): (Double, Double) = { 85 | // val margin = - BLAS.dot(weights, data) 86 | val margin = - BLAS.dot(weights, data) - intercept 87 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 88 | val loss = if (label > 0.5) { 89 | // The following is equivalent to log(1 + exp(margin)) but more numerically stable. 
90 | MLUtils.log1pExp(margin) 91 | } else { 92 | MLUtils.log1pExp(margin) - margin 93 | } 94 | (factor, loss) 95 | } 96 | 97 | } 98 | 99 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/LossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | 5 | 6 | abstract class LossFunc extends Serializable { 7 | 8 | ////////////////////////// without intercept ////////////////////////////////// 9 | 10 | def gradient(data: Vector, label: Double, weights: Vector): Double 11 | 12 | def loss(data: Vector, label: Double, weights: Vector): Double 13 | 14 | def gradientWithLoss(data: Vector, label: Double, weights: Vector): (Double, Double) 15 | 16 | def gradientFromDot(dot: Double, label: Double): Double 17 | 18 | ////////////////////////// with intercept ////////////////////////////////// 19 | 20 | def gradient(data: Vector, label: Double, weights: Vector, intercept: Double): Double 21 | 22 | def loss(data: Vector, label: Double, weights: Vector, intercept: Double): Double 23 | 24 | def gradientWithLoss(data: Vector, label: Double, weights: Vector, intercept: Double): (Double, Double) 25 | 26 | } 27 | 28 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiHingeLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import net.qihoo.xitong.xdml.model.linalg.BLAS 4 | import org.apache.spark.mllib.linalg.Vector 5 | 6 | 7 | class MultiHingeLossFunc(numClasses: Int) extends MultiLossFunc { 8 | 9 | def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] = { 10 | val labelMap = Array.fill(margins.length)(-1.0) 11 | labelMap(label.toInt) = 1.0 12 | val gradientFactors = 
Array.fill(margins.length)(0.0) 13 | for(ind <- margins.indices) { 14 | gradientFactors(ind) = 15 | if (1.0 > labelMap(ind) * margins(ind)) { 16 | -labelMap(ind) 17 | } else { 18 | 0.0 19 | } 20 | } 21 | gradientFactors 22 | } 23 | 24 | ////////////////////////// with intercept ////////////////////////////////// 25 | 26 | def gradientWithLoss(data: Vector, label: Double, 27 | weightsArr: Array[Vector], 28 | interceptArr: Array[Double]): (Array[Double], Double) = { 29 | val margins = Array.fill(weightsArr.length)(0.0) 30 | for(ind <- margins.indices) { 31 | margins(ind) = BLAS.dot(weightsArr(ind), data) + interceptArr(ind) 32 | } 33 | val labelMap = Array.fill(margins.length)(-1.0) 34 | labelMap(label.toInt) = 1.0 35 | val gradientFactors = Array.fill(margins.length)(0.0) 36 | val losses = Array.fill(margins.length)(0.0) 37 | for(ind <- margins.indices) { 38 | if (1.0 > labelMap(ind) * margins(ind)) { 39 | gradientFactors(ind) = -labelMap(ind) 40 | losses(ind) = 1.0 - labelMap(ind) * margins(ind) 41 | } else { 42 | gradientFactors(ind) = 0.0 43 | losses(ind) = 0.0 44 | } 45 | } 46 | (gradientFactors, losses.sum) 47 | } 48 | 49 | ////////////////////////// without intercept ////////////////////////////////// 50 | 51 | def gradientWithLoss(data: Vector, label: Double, 52 | weightsArr: Array[Vector]): (Array[Double], Double) = { 53 | val margins = Array.fill(weightsArr.length)(0.0) 54 | for(ind <- margins.indices) { 55 | margins(ind) = BLAS.dot(weightsArr(ind), data) 56 | } 57 | val labelMap = Array.fill(margins.length)(-1.0) 58 | labelMap(label.toInt) = 1.0 59 | val gradientFactors = Array.fill(margins.length)(0.0) 60 | val losses = Array.fill(margins.length)(0.0) 61 | for(ind <- margins.indices) { 62 | if (1.0 > labelMap(ind) * margins(ind)) { 63 | gradientFactors(ind) = -labelMap(ind) 64 | losses(ind) = 1.0 - labelMap(ind) * margins(ind) 65 | } else { 66 | gradientFactors(ind) = 0.0 67 | losses(ind) = 0.0 68 | } 69 | } 70 | (gradientFactors, losses.sum) 71 | } 72 | 
73 | } 74 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiLogitLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | 5 | /** Multinomial logistic loss with class 0 as the reference class: only the first numClasses - 1 margins are used, margin index i corresponds to label i + 1, and class 0 has an implicit margin of 0. NOTE(review): the structure appears to follow Spark MLlib's multinomial LogisticGradient — confirm if porting fixes from there. */ 6 | class MultiLogitLossFunc(numClasses: Int) extends MultiLossFunc { 7 | /** Turns precomputed margins into per-class gradient factors exp(m_i) / (1 + sum_j exp(m_j)) - [label == i + 1]. NOTE: the margins array is overwritten in place and returned, and marginY is computed but never read in this method. */ 8 | def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] = { 9 | // marginY is margins(label - 1) in the formula. 10 | var marginY = 0.0 11 | var maxMargin = Double.NegativeInfinity 12 | var maxMarginIndex = 0 13 | for(i <- 0 until (numClasses-1)) { 14 | if (i == label.toInt - 1) marginY = margins(i) 15 | if (margins(i) > maxMargin) { 16 | maxMargin = margins(i) 17 | maxMarginIndex = i 18 | } 19 | } 20 | /* Overflow guard: when maxMargin > 0 every margin is shifted by -maxMargin, and the max class's term (now exp(0)) is replaced by exp(-maxMargin), the shifted class-0 term, so (sum + 1.0) is the shifted softmax denominator. */ 21 | val sum = { 22 | var temp = 0.0 23 | if (maxMargin > 0) { 24 | for (i <- 0 until numClasses - 1) { 25 | margins(i) -= maxMargin 26 | if (i == maxMarginIndex) { 27 | temp += math.exp(-maxMargin) 28 | } else { 29 | temp += math.exp(margins(i)) 30 | } 31 | } 32 | } else { 33 | for (i <- 0 until numClasses - 1) { 34 | temp += math.exp(margins(i)) 35 | } 36 | } 37 | temp 38 | } 39 | 40 | for (i <- 0 until numClasses - 1) { 41 | margins(i) = math.exp(margins(i)) / (sum + 1.0) - { if (label != 0.0 && label == i + 1) 1.0 else 0.0 } 42 | } 43 | 44 | margins 45 | } 46 | 47 | ////////////////////////// with intercept ////////////////////////////////// 48 | /** Computes margins m_i = interceptArr(i) + sum over the active entries of data of value * weightsArr(i)(index); interceptArr is cloned, not modified. Returns (gradient factors as in gradientFromMargins, loss), where loss = log(1 + sum_i exp(m_i)) - m_(label - 1), and the margin term is dropped for label 0. */ 49 | def gradientWithLoss(data: Vector, label: Double, 50 | weightsArr: Array[Vector], 51 | interceptArr: Array[Double]): (Array[Double], Double) = { 52 | // marginY is margins(label - 1) in the formula.
53 | var marginY = 0.0 54 | var maxMargin = Double.NegativeInfinity 55 | var maxMarginIndex = 0 56 | val margins = interceptArr.clone() 57 | for(i <- 0 until (numClasses-1)) { 58 | data.foreachActive { (index, value) => if (value != 0.0) margins(i) += value * weightsArr(i)(index) } 59 | if (i == label.toInt - 1) marginY = margins(i) 60 | if (margins(i) > maxMargin) { 61 | maxMargin = margins(i) 62 | maxMarginIndex = i 63 | } 64 | } 65 | /* Same overflow-safe shift as in gradientFromMargins. */ 66 | val sum = { 67 | var temp = 0.0 68 | if (maxMargin > 0) { 69 | for (i <- 0 until numClasses - 1) { 70 | margins(i) -= maxMargin 71 | if (i == maxMarginIndex) { 72 | temp += math.exp(-maxMargin) 73 | } else { 74 | temp += math.exp(margins(i)) 75 | } 76 | } 77 | } else { 78 | for (i <- 0 until numClasses - 1) { 79 | temp += math.exp(margins(i)) 80 | } 81 | } 82 | temp 83 | } 84 | 85 | for (i <- 0 until numClasses - 1) { 86 | margins(i) = math.exp(margins(i)) / (sum + 1.0) - { if (label != 0.0 && label == i + 1) 1.0 else 0.0 } 87 | } 88 | /* When maxMargin > 0, log1p(sum) is shifted by -maxMargin; adding maxMargin back restores log(1 + sum_i exp(m_i)). marginY was captured before the shift. */ 89 | var loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum) 90 | if (maxMargin > 0) { 91 | loss += maxMargin 92 | } 93 | 94 | (margins, loss) 95 | } 96 | 97 | 98 | ////////////////////////// without intercept ////////////////////////////////// 99 | /** Same as the intercept overload, with all margins starting at 0.0. */ 100 | def gradientWithLoss(data: Vector, label: Double, 101 | weightsArr: Array[Vector]): (Array[Double], Double) = { 102 | // marginY is margins(label - 1) in the formula.
103 | var marginY = 0.0 104 | var maxMargin = Double.NegativeInfinity 105 | var maxMarginIndex = 0 106 | val margins = Array.fill(numClasses-1)(0.0) 107 | for(i <- 0 until (numClasses-1)) { 108 | data.foreachActive { (index, value) => if (value != 0.0) margins(i) += value * weightsArr(i)(index) } 109 | if (i == label.toInt - 1) marginY = margins(i) 110 | if (margins(i) > maxMargin) { 111 | maxMargin = margins(i) 112 | maxMarginIndex = i 113 | } 114 | } 115 | /* Overflow-safe shift, as in gradientFromMargins. */ 116 | val sum = { 117 | var temp = 0.0 118 | if (maxMargin > 0) { 119 | for (i <- 0 until numClasses - 1) { 120 | margins(i) -= maxMargin 121 | if (i == maxMarginIndex) { 122 | temp += math.exp(-maxMargin) 123 | } else { 124 | temp += math.exp(margins(i)) 125 | } 126 | } 127 | } else { 128 | for (i <- 0 until numClasses - 1) { 129 | temp += math.exp(margins(i)) 130 | } 131 | } 132 | temp 133 | } 134 | 135 | for (i <- 0 until numClasses - 1) { 136 | margins(i) = math.exp(margins(i)) / (sum + 1.0) - { if (label != 0.0 && label == i + 1) 1.0 else 0.0 } 137 | } 138 | /* Adding maxMargin back undoes the shift applied to sum. */ 139 | var loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum) 140 | if (maxMargin > 0) { 141 | loss += maxMargin 142 | } 143 | 144 | (margins, loss) 145 | } 146 | 147 | 148 | 149 | } 150 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | 5 | /** Contract for multiclass losses: each method returns one value per class (the multiclass implementations seen here return gradient factors), and gradientWithLoss also returns the scalar loss for one example. */ 6 | abstract class MultiLossFunc extends Serializable { 7 | 8 | def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] 9 | 10 | ////////////////////////// with intercept ////////////////////////////////// 11 | 12 | def gradientWithLoss(data: Vector, label: Double, 13 | weightsArr: Array[Vector], 14 | interceptArr: Array[Double]): (Array[Double], Double) 15 | 16 | //////////////////////////
without intercept ////////////////////////////////// 17 | 18 | def gradientWithLoss(data: Vector, label: Double, 19 | weightsArr: Array[Vector]): (Array[Double], Double) 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/MultiSmoothHingeLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import net.qihoo.xitong.xdml.model.linalg.BLAS 4 | import org.apache.spark.mllib.linalg.Vector 5 | 6 | /** Multiclass one-vs-rest smoothed hinge loss. For class k the target t_k is +1 if k == label.toInt, else -1. With z_k = t_k * margin_k, the loss is 0.5 - z_k for z_k < 0 and max(0, 1 - z_k)^2 / 2 otherwise; the gradient factor is -t_k or -t_k * max(0, 1 - z_k) respectively. The numClasses argument is not read in this class; the class count comes from the margins / weightsArr length. */ 7 | class MultiSmoothHingeLossFunc(numClasses: Int) extends MultiLossFunc { 8 | /** Gradient factors from precomputed margins; returns a new array and leaves margins unchanged. */ 9 | def gradientFromMargins(margins: Array[Double], label: Double): Array[Double] = { 10 | val labelMap = Array.fill(margins.length)(-1.0) 11 | labelMap(label.toInt) = 1.0 12 | val gradientFactors = Array.fill(margins.length)(0.0) 13 | for(ind <- margins.indices) { 14 | val max = math.max(0.0, 1.0 - labelMap(ind) * margins(ind)) 15 | gradientFactors(ind) = 16 | if (0.0 > labelMap(ind) * margins(ind)) { 17 | - labelMap(ind) 18 | } else { 19 | - labelMap(ind) * max 20 | } 21 | } 22 | gradientFactors 23 | } 24 | 25 | ////////////////////////// with intercept ////////////////////////////////// 26 | /** Intercept variant: margin_k = BLAS.dot(weightsArr(k), data) + interceptArr(k); returns (gradient factors, summed loss). */ 27 | def gradientWithLoss(data: Vector, label: Double, 28 | weightsArr: Array[Vector], 29 | interceptArr: Array[Double]): (Array[Double], Double) = { 30 | val margins = Array.fill(weightsArr.length)(0.0) 31 | for(ind <- margins.indices) { 32 | margins(ind) = BLAS.dot(weightsArr(ind), data) + interceptArr(ind) 33 | } 34 | val labelMap = Array.fill(margins.length)(-1.0) 35 | labelMap(label.toInt) = 1.0 36 | val gradientFactors = Array.fill(margins.length)(0.0) 37 | val losses = Array.fill(margins.length)(0.0) 38 | for(ind <- margins.indices) { 39 | val max = math.max(0.0, 1.0 - labelMap(ind) * margins(ind)) 40 | if (0.0 > labelMap(ind) * margins(ind)) { 41 | gradientFactors(ind) = - labelMap(ind) 42 | losses(ind) = max -
0.5 43 | } else { 44 | gradientFactors(ind) = - labelMap(ind) * max 45 | losses(ind) = max * max / 2 46 | } 47 | } 48 | (gradientFactors, losses.sum) 49 | } 50 | 51 | ////////////////////////// without intercept ////////////////////////////////// 52 | /** No-intercept variant: margin_k = BLAS.dot(weightsArr(k), data); returns (gradient factors, summed loss). */ 53 | def gradientWithLoss(data: Vector, label: Double, 54 | weightsArr: Array[Vector]): (Array[Double], Double) = { 55 | val margins = Array.fill(weightsArr.length)(0.0) 56 | for(ind <- margins.indices) { 57 | margins(ind) = BLAS.dot(weightsArr(ind), data) 58 | } 59 | val labelMap = Array.fill(margins.length)(-1.0) 60 | labelMap(label.toInt) = 1.0 61 | val gradientFactors = Array.fill(margins.length)(0.0) 62 | val losses = Array.fill(margins.length)(0.0) 63 | for(ind <- margins.indices) { 64 | val max = math.max(0.0, 1.0 - labelMap(ind) * margins(ind)) 65 | if (0.0 > labelMap(ind) * margins(ind)) { 66 | gradientFactors(ind) = - labelMap(ind) 67 | losses(ind) = max - 0.5 68 | } else { 69 | gradientFactors(ind) = - labelMap(ind) * max 70 | losses(ind) = max * max / 2 71 | } 72 | } 73 | (gradientFactors, losses.sum) 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/PoissonLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import net.qihoo.xitong.xdml.model.linalg.BLAS 4 | import org.apache.spark.mllib.linalg.Vector 5 | 6 | /** Poisson regression loss with a log link: loss = exp(margin) - label * margin and gradient factor = exp(margin) - label, where margin = BLAS.dot(weights, data), plus intercept in the intercept overloads. */ 7 | class PoissonLossFunc extends LossFunc { 8 | 9 | ////////////////////////// without intercept ////////////////////////////////// 10 | 11 | override def loss(data: Vector, 12 | label: Double, 13 | weights: Vector): Double = { 14 | val margin = BLAS.dot(weights, data) 15 | math.exp(margin) - label * margin 16 | } 17 | 18 | override def gradient(data: Vector, 19 | label: Double, 20 | weights: Vector): Double = { 21 | val margin = BLAS.dot(weights, data) 22 | math.exp(margin) - label 23 | } 24 | 25 |
/** Returns (gradient factor, loss), computing exp(margin) once. */ override def gradientWithLoss(data: Vector, 26 | label: Double, 27 | weights: Vector): (Double, Double) = { 28 | val margin = BLAS.dot(weights, data) 29 | val expMargin = math.exp(margin) 30 | (expMargin - label, expMargin - label * margin) 31 | } 32 | /** Gradient factor from a precomputed margin dot. */ 33 | override def gradientFromDot(dot: Double, 34 | label: Double): Double = { 35 | math.exp(dot) - label 36 | } 37 | 38 | ////////////////////////// with intercept ////////////////////////////////// 39 | /** Intercept overloads: margin = BLAS.dot(weights, data) + intercept. */ 40 | override def loss(data: Vector, 41 | label: Double, 42 | weights: Vector, 43 | intercept: Double): Double = { 44 | val margin = BLAS.dot(weights, data) + intercept 45 | math.exp(margin) - label * margin 46 | } 47 | 48 | override def gradient(data: Vector, 49 | label: Double, 50 | weights: Vector, 51 | intercept: Double): Double = { 52 | val margin = BLAS.dot(weights, data) + intercept 53 | math.exp(margin) - label 54 | } 55 | 56 | override def gradientWithLoss(data: Vector, 57 | label: Double, 58 | weights: Vector, 59 | intercept: Double): (Double, Double) = { 60 | val margin = BLAS.dot(weights, data) + intercept 61 | val expMargin = math.exp(margin) 62 | (expMargin - label, expMargin - label * margin) 63 | } 64 | 65 | } 66 | 67 | 68 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/SmoothHingeLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | import net.qihoo.xitong.xdml.model.linalg.BLAS 5 | 6 | /** Binary smoothed hinge loss. Labels 0/1 are mapped to labelScaled = 2 * label - 1. With z = labelScaled * margin, the loss is 0.5 - z for z < 0 and max(0, 1 - z)^2 / 2 otherwise; the gradient factor is -labelScaled or -labelScaled * max(0, 1 - z) respectively. */ 7 | class SmoothHingeLossFunc extends LossFunc { 8 | 9 | ////////////////////////// without intercept ////////////////////////////////// 10 | 11 | override def gradientWithLoss(data: Vector, 12 | label: Double, 13 | weights: Vector): (Double, Double) = { 14 | val dotProduct = BLAS.dot(weights, data) 15 | val labelScaled = 2 * label - 1.0 16 | val max = math.max(0.0, 1.0 - labelScaled *
dotProduct) 17 | if (labelScaled * dotProduct < 0) { 18 | (- labelScaled, max - 0.5) 19 | } else { 20 | (- labelScaled * max, max * max / 2) 21 | } 22 | } 23 | 24 | override def gradient(data: Vector, 25 | label: Double, 26 | weights: Vector): Double = { 27 | val dotProduct = BLAS.dot(weights, data) 28 | val labelScaled = 2 * label - 1.0 29 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 30 | if (labelScaled * dotProduct < 0) { 31 | - labelScaled 32 | } else { 33 | - labelScaled * max 34 | } 35 | } 36 | 37 | override def loss(data: Vector, 38 | label: Double, 39 | weights: Vector): Double = { 40 | val dotProduct = BLAS.dot(weights, data) 41 | val labelScaled = 2 * label - 1.0 42 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 43 | if (labelScaled * dotProduct < 0) { 44 | max - 0.5 45 | } else { 46 | max * max / 2 47 | } 48 | } 49 | /** Gradient factor from a precomputed margin dot. */ 50 | override def gradientFromDot(dot: Double, 51 | label: Double): Double = { 52 | val labelScaled = 2 * label - 1.0 53 | val max = math.max(0.0, 1.0 - labelScaled * dot) 54 | if (labelScaled * dot < 0) { 55 | - labelScaled 56 | } else { 57 | - labelScaled * max 58 | } 59 | } 60 | 61 | ////////////////////////// with intercept ////////////////////////////////// 62 | /** Intercept overloads: same formulas with dotProduct = BLAS.dot(weights, data) + intercept. */ 63 | override def gradientWithLoss(data: Vector, 64 | label: Double, 65 | weights: Vector, 66 | intercept: Double): (Double, Double) = { 67 | // val dotProduct = BLAS.dot(weights, data) 68 | val dotProduct = BLAS.dot(weights, data) + intercept 69 | val labelScaled = 2 * label - 1.0 70 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 71 | if (labelScaled * dotProduct < 0) { 72 | (- labelScaled, max - 0.5) 73 | } else { 74 | (- labelScaled * max, max * max / 2) 75 | } 76 | } 77 | 78 | override def gradient(data: Vector, 79 | label: Double, 80 | weights: Vector, 81 | intercept: Double): Double = { 82 | // val dotProduct = BLAS.dot(weights, data) 83 | val dotProduct = BLAS.dot(weights, data) + intercept 84 | val labelScaled = 2 * label - 1.0 85 | val
max = math.max(0.0, 1.0 - labelScaled * dotProduct) 86 | if (labelScaled * dotProduct < 0) { 87 | - labelScaled 88 | } else { 89 | - labelScaled * max 90 | } 91 | } 92 | 93 | override def loss(data: Vector, 94 | label: Double, 95 | weights: Vector, 96 | intercept: Double): Double = { 97 | // val dotProduct = BLAS.dot(weights, data) 98 | val dotProduct = BLAS.dot(weights, data) + intercept 99 | val labelScaled = 2 * label - 1.0 100 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 101 | if (labelScaled * dotProduct < 0) { 102 | max - 0.5 103 | } else { 104 | max * max / 2 105 | } 106 | } 107 | 108 | } 109 | 110 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/UPULogitLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import net.qihoo.xitong.xdml.model.linalg.BLAS 4 | import net.qihoo.xitong.xdml.model.util.MLUtils 5 | import org.apache.spark.mllib.linalg.Vector 6 | 7 | /** Logistic loss for positive-unlabeled (PU) data. Positive samples (label > 0.5) contribute posRatioPriorFrac * margin, with posRatioPriorFrac = -posRatioPrior / posRatio and a constant gradient factor. All other samples contribute log(1 + exp(margin)) / unlabeledRatio, with gradient factor sigmoid(margin) / unlabeledRatio. NOTE: the constructor prints the three ratios to stdout. NOTE(review): presumably an unbiased PU risk estimator — confirm the weighting against its reference. */ 8 | class UPULogitLossFunc(posRatioPrior: Double, posRatio: Double, unlabeledRatio: Double) extends LossFunc { 9 | 10 | val posRatioPriorFrac = - posRatioPrior / posRatio 11 | println("unlabeledRatio: "+unlabeledRatio) 12 | println("posRatio: "+posRatio) 13 | println("posRatioPrior: "+posRatioPrior) 14 | 15 | ////////////////////////// without intercept ////////////////////////////////// 16 | 17 | override def loss(data: Vector, 18 | label: Double, 19 | weights: Vector): Double = { 20 | val margin = BLAS.dot(weights, data) 21 | if (label > 0.5) { 22 | posRatioPriorFrac * margin 23 | } else { 24 | MLUtils.log1pExp(margin) / unlabeledRatio 25 | } 26 | } 27 | 28 | override def gradient(data: Vector, 29 | label: Double, 30 | weights: Vector): Double = { 31 | if (label > 0.5) { 32 | posRatioPriorFrac 33 | } else { 34 | val margin = BLAS.dot(weights, data) 35 | 1.0 / (1.0 + math.exp(-margin)) / unlabeledRatio 36 | } 37 |
} 38 | 39 | override def gradientWithLoss(data: Vector, 40 | label: Double, 41 | weights: Vector): (Double, Double) = { 42 | val margin = BLAS.dot(weights, data) 43 | if (label > 0.5) { 44 | (posRatioPriorFrac, posRatioPriorFrac * margin) 45 | } else { 46 | (1.0 / (1.0 + math.exp(-margin)) / unlabeledRatio, MLUtils.log1pExp(margin) / unlabeledRatio) 47 | } 48 | } 49 | /** Gradient factor from a precomputed margin dot; same branches as gradient. */ 50 | override def gradientFromDot(dot: Double, 51 | label: Double): Double = { 52 | if (label > 0.5) { 53 | posRatioPriorFrac 54 | } else { 55 | 1.0 / (1.0 + math.exp(-dot)) / unlabeledRatio 56 | } 57 | } 58 | 59 | ////////////////////////// with intercept ////////////////////////////////// 60 | /** Intercept overloads: margin = BLAS.dot(weights, data) + intercept. */ 61 | override def loss(data: Vector, 62 | label: Double, 63 | weights: Vector, 64 | intercept: Double): Double = { 65 | val margin = BLAS.dot(weights, data) + intercept 66 | if (label > 0.5) { 67 | posRatioPriorFrac * margin 68 | } else { 69 | MLUtils.log1pExp(margin) / unlabeledRatio 70 | } 71 | } 72 | 73 | override def gradient(data: Vector, 74 | label: Double, 75 | weights: Vector, 76 | intercept: Double): Double = { 77 | if (label > 0.5) { 78 | posRatioPriorFrac 79 | } else { 80 | val margin = BLAS.dot(weights, data) + intercept 81 | 1.0 / (1.0 + math.exp(-margin)) / unlabeledRatio 82 | } 83 | } 84 | 85 | override def gradientWithLoss(data: Vector, 86 | label: Double, 87 | weights: Vector, 88 | intercept: Double): (Double, Double) = { 89 | val margin = BLAS.dot(weights, data) + intercept 90 | if (label > 0.5) { 91 | (posRatioPriorFrac, posRatioPriorFrac * margin) 92 | } else { 93 | (1.0 / (1.0 + math.exp(-margin)) / unlabeledRatio, MLUtils.log1pExp(margin) / unlabeledRatio) 94 | } 95 | } 96 | 97 | } 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedHingeLossFunc.scala: -------------------------------------------------------------------------------- 1 | package
net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | import net.qihoo.xitong.xdml.model.linalg.BLAS 5 | 6 | /** Binary hinge loss with a positive-class weight. Labels are mapped to labelScaled = 2 * label - 1. Both loss and gradient factor are scaled by label * (posWeight - 1) + 1, which is posWeight for label 1 and 1 for label 0 (assumes labels are exactly 0 or 1). */ 7 | class WeightedHingeLossFunc(posWeight: Double) extends LossFunc { 8 | 9 | ////////////////////////// without intercept ////////////////////////////////// 10 | 11 | override def gradientWithLoss(data: Vector, 12 | label: Double, 13 | weights: Vector): (Double, Double) = { 14 | val dotProduct = BLAS.dot(weights, data) 15 | val labelScaled = 2 * label - 1.0 16 | if (1.0 > labelScaled * dotProduct) { 17 | (- labelScaled * (label * (posWeight - 1) + 1), (1.0 - labelScaled * dotProduct) * (label * (posWeight - 1) + 1)) 18 | } else { 19 | (0.0, 0.0) 20 | } 21 | } 22 | 23 | override def gradient(data: Vector, 24 | label: Double, 25 | weights: Vector): Double = { 26 | val dotProduct = BLAS.dot(weights, data) 27 | val labelScaled = 2 * label - 1.0 28 | if (1.0 > labelScaled * dotProduct) { 29 | - labelScaled * (label * (posWeight - 1) + 1) 30 | } else { 31 | 0.0 32 | } 33 | } 34 | 35 | override def loss(data: Vector, 36 | label: Double, 37 | weights: Vector): Double = { 38 | val dotProduct = BLAS.dot(weights, data) 39 | val labelScaled = 2 * label - 1.0 40 | if (1.0 > labelScaled * dotProduct) { 41 | (1.0 - labelScaled * dotProduct) * (label * (posWeight - 1) + 1) 42 | } else { 43 | 0.0 44 | } 45 | } 46 | /** Gradient factor from a precomputed margin dot. */ 47 | override def gradientFromDot(dot: Double, 48 | label: Double): Double = { 49 | val labelScaled = 2 * label - 1.0 50 | if (1.0 > labelScaled * dot) { 51 | - labelScaled * (label * (posWeight - 1) + 1) 52 | } else { 53 | 0.0 54 | } 55 | } 56 | 57 | ////////////////////////// with intercept ////////////////////////////////// 58 | /** Intercept overloads: dotProduct = BLAS.dot(weights, data) + intercept. */ 59 | override def gradientWithLoss(data: Vector, 60 | label: Double, 61 | weights: Vector, 62 | intercept: Double): (Double, Double) = { 63 | // val dotProduct = BLAS.dot(weights, data) 64 | val dotProduct = BLAS.dot(weights, data) + intercept 65 | val labelScaled = 2 * label - 1.0 66 | if (1.0 >
labelScaled * dotProduct) { 67 | (- labelScaled * (label * (posWeight - 1) + 1), (1.0 - labelScaled * dotProduct) * (label * (posWeight - 1) + 1)) 68 | } else { 69 | (0.0, 0.0) 70 | } 71 | } 72 | 73 | override def gradient(data: Vector, 74 | label: Double, 75 | weights: Vector, 76 | intercept: Double): Double = { 77 | // val dotProduct = BLAS.dot(weights, data) 78 | val dotProduct = BLAS.dot(weights, data) + intercept 79 | val labelScaled = 2 * label - 1.0 80 | if (1.0 > labelScaled * dotProduct) { 81 | - labelScaled * (label * (posWeight - 1) + 1) 82 | } else { 83 | 0.0 84 | } 85 | } 86 | 87 | override def loss(data: Vector, 88 | label: Double, 89 | weights: Vector, 90 | intercept: Double): Double = { 91 | // val dotProduct = BLAS.dot(weights, data) 92 | val dotProduct = BLAS.dot(weights, data) + intercept 93 | val labelScaled = 2 * label - 1.0 94 | if (1.0 > labelScaled * dotProduct) { 95 | (1.0 - labelScaled * dotProduct) * (label * (posWeight - 1) + 1) 96 | } else { 97 | 0.0 98 | } 99 | } 100 | 101 | } -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedLogitLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import org.apache.spark.mllib.linalg.Vector 4 | import net.qihoo.xitong.xdml.model.util.MLUtils 5 | import net.qihoo.xitong.xdml.model.linalg.BLAS 6 | 7 | /** Binary logistic loss with a positive-class weight. With margin = -(BLAS.dot(weights, data) + intercept), where intercept is 0 in the no-intercept overloads: the loss is posWeight * log(1 + exp(margin)) for label > 0.5, and log(1 + exp(margin)) - margin = log(1 + exp(-margin)) otherwise. The gradient factor is (sigmoid(dot) - label) * (label * (posWeight - 1) + 1). The loss branches on label > 0.5 while the gradient weight uses label arithmetically, so the two agree only for labels exactly 0 or 1. */ 8 | class WeightedLogitLossFunc(posWeight: Double) extends LossFunc { 9 | /** Logistic function 1 / (1 + exp(-z)). */ 10 | def sigmoid(z: Double): Double = { 11 | 1.0 / (1.0 + math.exp(-z)) 12 | } 13 | 14 | ////////////////////////// without intercept ////////////////////////////////// 15 | 16 | override def loss(data: Vector, 17 | label: Double, 18 | weights: Vector): Double = { 19 | val margin = - BLAS.dot(weights, data) 20 | if (label > 0.5) { 21 | // MLUtils.log1pExp(margin) is log(1 + exp(margin)) computed without overflow for large margins.
22 | MLUtils.log1pExp(margin) * posWeight 23 | } else { 24 | MLUtils.log1pExp(margin) - margin 25 | } 26 | } 27 | /* margin is the negated dot product, so 1 / (1 + exp(margin)) equals sigmoid(dot). */ 28 | override def gradient(data: Vector, 29 | label: Double, 30 | weights: Vector): Double = { 31 | val margin = - BLAS.dot(weights, data) 32 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 33 | factor * (label * (posWeight - 1) + 1) 34 | } 35 | 36 | override def gradientWithLoss(data: Vector, 37 | label: Double, 38 | weights: Vector): (Double, Double) = { 39 | val margin = - BLAS.dot(weights, data) 40 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 41 | val loss = if (label > 0.5) { 42 | // MLUtils.log1pExp(margin) is log(1 + exp(margin)) computed without overflow for large margins. 43 | MLUtils.log1pExp(margin) * posWeight 44 | } else { 45 | MLUtils.log1pExp(margin) - margin 46 | } 47 | (factor * (label * (posWeight - 1) + 1), loss) 48 | } 49 | /** Gradient factor from a precomputed (non-negated) margin dot. */ 50 | override def gradientFromDot(dot: Double, 51 | label: Double): Double = { 52 | (sigmoid(dot) - label) * (label * (posWeight - 1) + 1) 53 | } 54 | 55 | ////////////////////////// with intercept ////////////////////////////////// 56 | /** Intercept overloads: margin = -(BLAS.dot(weights, data) + intercept). */ 57 | override def loss(data: Vector, 58 | label: Double, 59 | weights: Vector, 60 | intercept: Double): Double = { 61 | // val margin = - BLAS.dot(weights, data) 62 | val margin = - BLAS.dot(weights, data) - intercept 63 | if (label > 0.5) { 64 | // MLUtils.log1pExp(margin) is log(1 + exp(margin)) computed without overflow for large margins.
65 | MLUtils.log1pExp(margin) * posWeight 66 | } else { 67 | MLUtils.log1pExp(margin) - margin 68 | } 69 | } 70 | 71 | override def gradient(data: Vector, 72 | label: Double, 73 | weights: Vector, 74 | intercept: Double): Double = { 75 | // val margin = - BLAS.dot(weights, data) 76 | val margin = - BLAS.dot(weights, data) - intercept 77 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 78 | factor * (label * (posWeight - 1) + 1) 79 | } 80 | 81 | override def gradientWithLoss(data: Vector, 82 | label: Double, 83 | weights: Vector, 84 | intercept: Double): (Double, Double) = { 85 | // val margin = - BLAS.dot(weights, data) 86 | val margin = - BLAS.dot(weights, data) - intercept 87 | val factor = (1.0 / (1.0 + math.exp(margin))) - label 88 | val loss = if (label > 0.5) { 89 | // MLUtils.log1pExp(margin) is log(1 + exp(margin)) computed without overflow for large margins. 90 | MLUtils.log1pExp(margin) * posWeight 91 | } else { 92 | MLUtils.log1pExp(margin) - margin 93 | } 94 | (factor * (label * (posWeight - 1) + 1), loss) 95 | } 96 | 97 | } 98 | 99 | 100 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/loss/WeightedSmoothHingeLossFunc.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.loss 2 | 3 | import net.qihoo.xitong.xdml.model.linalg.BLAS 4 | import org.apache.spark.mllib.linalg.Vector 5 | 6 | /** Binary smoothed hinge loss with a positive-class weight. With labelScaled = 2 * label - 1 and z = labelScaled * margin, the unweighted loss is 0.5 - z for z < 0 and max(0, 1 - z)^2 / 2 otherwise; the unweighted gradient factor is -labelScaled or -labelScaled * max(0, 1 - z) respectively. Loss and gradient are then scaled by label * (posWeight - 1) + 1. */ 7 | class WeightedSmoothHingeLossFunc(posWeight: Double) extends LossFunc { 8 | 9 | ////////////////////////// without intercept ////////////////////////////////// 10 | 11 | override def gradientWithLoss(data: Vector, 12 | label: Double, 13 | weights: Vector): (Double, Double) = { 14 | val dotProduct = BLAS.dot(weights, data) 15 | val labelScaled = 2 * label - 1.0 16 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 17 | if (labelScaled * dotProduct < 0) { 18 | (- labelScaled * (label * (posWeight - 1) + 1), (max - 0.5)
* (label * (posWeight - 1) + 1)) 19 | } else { 20 | (- labelScaled * max * (label * (posWeight - 1) + 1), max * max / 2 * (label * (posWeight - 1) + 1)) 21 | } 22 | } 23 | 24 | override def gradient(data: Vector, 25 | label: Double, 26 | weights: Vector): Double = { 27 | val dotProduct = BLAS.dot(weights, data) 28 | val labelScaled = 2 * label - 1.0 29 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 30 | if (labelScaled * dotProduct < 0) { 31 | - labelScaled * (label * (posWeight - 1) + 1) 32 | } else { 33 | - labelScaled * max * (label * (posWeight - 1) + 1) 34 | } 35 | } 36 | 37 | override def loss(data: Vector, 38 | label: Double, 39 | weights: Vector): Double = { 40 | val dotProduct = BLAS.dot(weights, data) 41 | val labelScaled = 2 * label - 1.0 42 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 43 | if (labelScaled * dotProduct < 0) { 44 | (max - 0.5) * (label * (posWeight - 1) + 1) 45 | } else { 46 | max * max / 2 * (label * (posWeight - 1) + 1) 47 | } 48 | } 49 | /** Gradient factor from a precomputed margin dot. */ 50 | override def gradientFromDot(dot: Double, 51 | label: Double): Double = { 52 | val labelScaled = 2 * label - 1.0 53 | val max = math.max(0.0, 1.0 - labelScaled * dot) 54 | if (labelScaled * dot < 0) { 55 | - labelScaled * (label * (posWeight - 1) + 1) 56 | } else { 57 | - labelScaled * max * (label * (posWeight - 1) + 1) 58 | } 59 | } 60 | 61 | ////////////////////////// with intercept ////////////////////////////////// 62 | /** Intercept overloads: dotProduct = BLAS.dot(weights, data) + intercept. */ 63 | override def gradientWithLoss(data: Vector, 64 | label: Double, 65 | weights: Vector, 66 | intercept: Double): (Double, Double) = { 67 | // val dotProduct = BLAS.dot(weights, data) 68 | val dotProduct = BLAS.dot(weights, data) + intercept 69 | val labelScaled = 2 * label - 1.0 70 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 71 | if (labelScaled * dotProduct < 0) { 72 | (- labelScaled * (label * (posWeight - 1) + 1), (max - 0.5) * (label * (posWeight - 1) + 1)) 73 | } else { 74 | (- labelScaled * max * (label * (posWeight - 1) + 1), max
* max / 2 * (label * (posWeight - 1) + 1)) 75 | } 76 | } 77 | 78 | override def gradient(data: Vector, 79 | label: Double, 80 | weights: Vector, 81 | intercept: Double): Double = { 82 | // val dotProduct = BLAS.dot(weights, data) 83 | val dotProduct = BLAS.dot(weights, data) + intercept 84 | val labelScaled = 2 * label - 1.0 85 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 86 | if (labelScaled * dotProduct < 0) { 87 | - labelScaled * (label * (posWeight - 1) + 1) 88 | } else { 89 | - labelScaled * max * (label * (posWeight - 1) + 1) 90 | } 91 | } 92 | 93 | override def loss(data: Vector, 94 | label: Double, 95 | weights: Vector, 96 | intercept: Double): Double = { 97 | // val dotProduct = BLAS.dot(weights, data) 98 | val dotProduct = BLAS.dot(weights, data) + intercept 99 | val labelScaled = 2 * label - 1.0 100 | val max = math.max(0.0, 1.0 - labelScaled * dotProduct) 101 | if (labelScaled * dotProduct < 0) { 102 | (max - 0.5) * (label * (posWeight - 1) + 1) 103 | } else { 104 | max * max / 2 * (label * (posWeight - 1) + 1) 105 | } 106 | } 107 | 108 | } 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/supervised/H2O/H2OParams.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.model.supervised 2 | /* NOTE(review): the package is org.apache.spark.ml.model.supervised although the file lives under net/qihoo/xitong/xdml/...; presumably deliberate (e.g. to reach Spark package-private members) — confirm before moving the file. */ 3 | import org.apache.spark.ml.param._ 4 | import org.apache.spark.sql.types.{StringType, StructType} 5 | 6 | /** Params shared by the H2O model wrappers (per the trait name): the label column plus the categorical and ignored feature column names. */ 7 | trait H2OParams extends Params { 8 | 9 | final val labelCol: Param[String] = new Param[String](this, "labelCol", "Label column name") 10 | 11 | final val catFeatColNames: StringArrayParam = new StringArrayParam(this, "catFeatColNames","Categorical feature column names") 12 | 13 | final val ignoredFeatColNames: StringArrayParam = new StringArrayParam(this, "ignoredFeatColNames", "Ignored feature column names") 14 | /** Requires every catFeatColNames entry to be present in schema and of StringType, then returns schema unchanged; a violation fails via require with IllegalArgumentException. labelCol and ignoredFeatColNames are not checked here. The dollar-brace form is a call to Params.$, i.e. getOrDefault. */ 15 | protected def validateSchema(schema: StructType): StructType={ 16 |
${catFeatColNames}.foreach { catFeatName => require(schema.fieldNames.contains(catFeatName), s"Column $catFeatName cannot find in dataFrame.") } 17 | ${catFeatColNames}.foreach { catFeatName => require(schema(catFeatName).dataType == StringType, s"The categorical feature column $catFeatName must be string type.")} 18 | schema 19 | } 20 | } 21 | /** Tree-model hyperparameters (depth, tree count, bin counts, minimum instances per node, categorical encoding, histogram type, distribution, scoring interval), each guarded by a ParamValidators constraint. */ 22 | trait H2OTreeParams extends H2OParams{ 23 | 24 | final val maxDepth: IntParam = 25 | new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" + 26 | " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", 27 | ParamValidators.gtEq(0)) 28 | 29 | final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", 30 | ParamValidators.gtEq(1)) 31 | 32 | final val maxBinsForCat: IntParam = new IntParam(this, "maxBinsForCat", "Max number of bins for" + 33 | " category features. Must be >=2 and >= number of categories for any" + 34 | " categorical feature.", ParamValidators.gtEq(2)) 35 | /* NOTE(review): the maxBinsForNum description below reuses the categorical wording (number of category for any numeric feature); the runtime text is left unchanged here. */ 36 | final val maxBinsForNum: IntParam = new IntParam(this, "maxBinsForNum", "Max number of bins for" + 37 | " numeric features. Must be >=2 and >= number of category for any" + 38 | " numeric feature.", ParamValidators.gtEq(2)) 39 | 40 | final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + 41 | " number of instances each child must have after split. If a split causes the left or right" + 42 | " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + 43 | " Should be >= 1.", ParamValidators.gtEq(1)) 44 | 45 | final val categoricalEncodingScheme: Param[String] = new Param[String](this, "categoricalEncodingScheme", "Encoding scheme " + 46 | "for categorical features.
Supported options: " + H2OTreeParams.supportedCategoricalEncoding.mkString(", "), 47 | ParamValidators.inArray(H2OTreeParams.supportedCategoricalEncoding)) 48 | 49 | final val histogramType: Param[String] = new Param[String](this, "histogramType"," Type of histogram. Supported options: " + 50 | H2OTreeParams.supportedHistogramType.mkString(", "), ParamValidators.inArray(H2OTreeParams.supportedHistogramType)) 51 | 52 | final val distribution: Param[String] = new Param[String](this, "distribution", "Distribution for dataSet. Supported options: " + 53 | H2OTreeParams.supportedDistribution.mkString(", "), ParamValidators.inArray(H2OTreeParams.supportedDistribution)) 54 | 55 | final val scoreTreeInterval: IntParam = new IntParam (this, "scoreTreeInterval", "Score the model after every so many trees." + 56 | " Disabled if set to 0.", ParamValidators.gtEq(0)) 57 | } 58 | /** Allowed values for the string-valued tree params, used both in their descriptions and in their ParamValidators.inArray checks. */ 59 | object H2OTreeParams{ 60 | 61 | final val supportedCategoricalEncoding: Array[String] = Array("AUTO", "Enum", "LabelEncoder", "OneHotExplicit", "SortByResponse") 62 | 63 | final val supportedHistogramType: Array[String] = Array("AUTO", "UniformAdaptive", "QuantilesGlobal") 64 | 65 | final val supportedDistribution: Array[String] = Array("bernoulli", "multinomial", "gaussian") 66 | } 67 | /** Setter mixin for the H2O tree params; sets empty-array defaults for catFeatColNames and ignoredFeatColNames. */ 68 | trait H2OTreeEstimatorParams extends H2OTreeParams{ 69 | 70 | def setLabelCol(value: String): this.type = set(labelCol, value) 71 | 72 | def setCatFeatColNames(values: Array[String]): this.type = set(catFeatColNames, values) 73 | 74 | def setIgnoreFeatColNames(values: Array[String]): this.type = set(ignoredFeatColNames, values) 75 | 76 | def setMaxDepth(value: Int): this.type = set(maxDepth, value) 77 | 78 | def setNumTrees(value: Int): this.type = set(numTrees, value) 79 | 80 | def setMaxBinsForCat(value: Int): this.type = set(maxBinsForCat, value) 81 | 82 | def setMaxBinsForNum(value: Int): this.type = set(maxBinsForNum, value) 83 | 84 | def setMinInstancesPerNode(value: Int): this.type =
set(minInstancesPerNode, value) 85 | 86 | def setCategoricalEncodingScheme(value: String): this.type = set(categoricalEncodingScheme, value) 87 | 88 | def setHistogramType(value: String): this.type = set(histogramType, value) 89 | 90 | def setDistribution(value: String): this.type = set(distribution, value) 91 | 92 | def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value) 93 | 94 | setDefault(catFeatColNames -> Array(), ignoredFeatColNames -> Array()) 95 | } 96 | 97 | 98 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/model/util/MLUtils.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.model.util 2 | 3 | import org.apache.spark.internal.Logging 4 | import org.apache.spark.mllib.linalg._ 5 | import net.qihoo.xitong.xdml.model.linalg.BLAS 6 | 7 | /** Numeric helpers: Double machine epsilon, a fast squared Euclidean distance, and an overflow-safe log(1 + exp(x)). */ 8 | object MLUtils extends Logging { 9 | /** Machine epsilon for Double: the smallest power of two eps with 1.0 + eps != 1.0, found by repeated halving. */ 10 | lazy val EPSILON = { 11 | var eps = 1.0 12 | while ((1.0 + (eps / 2.0)) != 1.0) { 13 | eps /= 2.0 14 | } 15 | eps 16 | } 17 | 18 | /** 19 | * Returns the squared Euclidean distance between two vectors. The following formula will be used 20 | * if it does not introduce too much numerical error: 21 | *
22 |     *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
23 |     * 
24 | * When both vector norms are given, this is faster than computing the squared distance directly, 25 | * especially when one of the vectors is a sparse vector. 26 | * @param v1 the first vector 27 | * @param norm1 the norm of the first vector, non-negative 28 | * @param v2 the second vector 29 | * @param norm2 the norm of the second vector, non-negative 30 | * @param precision desired relative precision for the squared distance 31 | * @return squared distance between v1 and v2 within the specified precision; require() throws IllegalArgumentException if v1 and v2 differ in size or either norm is negative 32 | */ 33 | def fastSquaredDistance(v1: Vector, 34 | norm1: Double, 35 | v2: Vector, 36 | norm2: Double, 37 | precision: Double = 1e-6): Double = { 38 | val n = v1.size 39 | require(v2.size == n) 40 | require(norm1 >= 0.0 && norm2 >= 0.0) 41 | val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 42 | val normDiff = norm1 - norm2 43 | var sqDist = 0.0 44 | /* 45 | * The relative error is 46 | *
47 |      * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
48 |      * 
49 | * which is bounded by 50 | *
51 |      * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
52 |      * 
53 | * The bound doesn't need the inner product, so we can use it as a sufficient condition to 54 | * check quickly whether the inner product approach is accurate. 55 | */ 56 | /* Adding EPSILON to the denominator avoids division by zero when the two norms are equal. */ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) 57 | if (precisionBound1 < precision) { 58 | sqDist = sumSquaredNorm - 2.0 * BLAS.dot(v1, v2) 59 | } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) { 60 | /* Sparse input: try the dot-product formula first, and fall back to the exact Vectors.sqdist only when its error bound exceeds precision. */ val dotValue = BLAS.dot(v1, v2) 61 | sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0) 62 | val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / 63 | (sqDist + EPSILON) 64 | if (precisionBound2 > precision) { 65 | sqDist = Vectors.sqdist(v1, v2) 66 | } 67 | } else { 68 | sqDist = Vectors.sqdist(v1, v2) 69 | } 70 | sqDist 71 | } 72 | 73 | /** 74 | * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic 75 | * overflow. This will happen when `x > 709.78` which is not a very large number. 76 | * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`. 77 | * @param x a floating-point value as input. 78 | * @return the result of `math.log(1 + math.exp(x))`.
79 | */ 80 | def log1pExp(x: Double): Double = { 81 | if (x > 0) { 82 | x + math.log1p(math.exp(-x)) 83 | } else { 84 | math.log1p(math.exp(x)) 85 | } 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/optimization/BinaryClassify.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.optimization 2 | 3 | import breeze.numerics.sigmoid 4 | import net.qihoo.xitong.xdml.linalg.BLAS._ 5 | 6 | object BinaryClassify extends Optimizer { 7 | 8 | def train(label: Double, 9 | feature: (Array[Long], Array[Float]), 10 | localW: Map[Long, Float], 11 | subsampling_rate: Double = 1.0, 12 | subsampling_label: Double = 0.0 13 | ): (Map[Long, Float], Double) = { 14 | 15 | val (pre, loss) = computePreAndLoss(label, feature, localW) 16 | val newGradient = getGradLoss(feature, label, pre, subsampling_rate, subsampling_label) 17 | (newGradient, loss) 18 | } 19 | 20 | def predict(feature: (Array[Long], Array[Float]), 21 | weight: Map[Long, Float] 22 | ): Double = { 23 | sigmoid(dot(feature, weight)) 24 | } 25 | 26 | def getGradLoss(feature: (Array[Long], Array[Float]), 27 | label: Double, 28 | pre: Float, 29 | subsampling_rate: Double = 1.0, 30 | subsampling_label: Double = 0.0): Map[Long, Float] = { 31 | val p = sigmoid(pre) 32 | var w = 1.0 33 | if (Math.abs(label - subsampling_label) < 0.001) { 34 | w = 1 / subsampling_rate 35 | } 36 | var gradLossMap: Map[Long, Float] = Map() 37 | (0 until feature._1.size).foreach(id => gradLossMap += (feature._1(id) -> ((p - label) * feature._2(id) * w).toFloat)) 38 | gradLossMap 39 | } 40 | 41 | def computePreAndLoss(label: Double, 42 | feature: (Array[Long], Array[Float]), 43 | weight: Map[Long, Float] 44 | ): (Float, Double) = { 45 | val dotVal = dot(feature, weight) 46 | val loss = if (label > 0) { 47 | if (dotVal > 13) 48 | 2e-6 49 | else if (dotVal < -13) 50 | (-dotVal) 51 | else 52 | 
math.log1p(math.exp(-dotVal)) 53 | } else { 54 | if (dotVal > 13) 55 | (-dotVal) 56 | else if (dotVal < -13) 57 | 2e-6 58 | else 59 | math.log1p(math.exp(-dotVal)) + dotVal 60 | } 61 | (dotVal, loss) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/net/qihoo/xitong/xdml/optimization/FFM.scala: -------------------------------------------------------------------------------- 1 | package net.qihoo.xitong.xdml.optimization 2 | 3 | import scala.collection.mutable 4 | 5 | object FFM extends Optimizer { 6 | val lamda = 0.1F 7 | val lr = 0.1F 8 | 9 | def train(fieldNum: Int, 10 | rank: Int, 11 | label: Double, 12 | feature: (Array[Int], Array[Long], Array[Float]), 13 | VMap: mutable.Map[Long, Array[Float]]): (mutable.Map[Long, Array[Float]], Double) = { 14 | calcGradAndLoss(fieldNum, rank, label, feature, VMap) 15 | } 16 | 17 | def predict(): (Double,Double) ={ 18 | (0.0,0.0) 19 | } 20 | 21 | //calculate dot 22 | def vectorDotWithField(rank: Int, 23 | vector1: Array[Float], 24 | field1: Int, 25 | vector2: Array[Float], 26 | field2: Int): Float = { 27 | var res = 0.0F 28 | val idx1 = rank * field1 29 | val idx2 = rank * field2 30 | for (i <- 0 until rank) { 31 | res = res + vector1(idx1 + i) * vector2(idx2 + i) 32 | } 33 | println("vec1: " + vector1.mkString(",")) 34 | println("vec2: " + vector2.mkString(",")) 35 | println("dot res: " + res) 36 | res 37 | } 38 | 39 | //calculate phi 40 | def calcPhi(fieldNum: Int, 41 | rank: Int, 42 | oneData: (Array[Int], Array[Long], Array[Float]), 43 | VMap: mutable.Map[Long, Array[Float]]): Float = { 44 | val size = oneData._1.length 45 | val field = oneData._1 46 | val id = oneData._2 47 | val value = oneData._3 48 | var resPhi = 0.0F 49 | val zeros = Array.fill(rank * fieldNum)(0.01f) 50 | for (i <- 0 until size) { 51 | for (j <- i + 1 until size) { 52 | val w1 = VMap.getOrElse(id(i), zeros) 53 | val w2 = VMap.getOrElse(id(j), zeros) 54 | println("w1: " + 
w1.mkString(",") + " w2: " + w2.mkString(",")) 55 | resPhi = resPhi + vectorDotWithField(rank, w1, field(i), w2, field(j)) * value(i) * value(j) 56 | } 57 | } 58 | println("phi: " + resPhi) 59 | resPhi 60 | } 61 | 62 | //calculate grad and loss 63 | def calcGradAndLoss(fieldNum: Int, 64 | rank: Int, 65 | label: Double, 66 | oneData: (Array[Int], Array[Long], Array[Float]), 67 | weightMap: mutable.Map[Long, Array[Float]]): (mutable.Map[Long, Array[Float]], Double) = { 68 | val pred = calcPhi(fieldNum, rank, oneData, weightMap) 69 | val loss = math.log1p(math.exp(-pred * label)) 70 | println("pred: " + pred + " label: " + label) 71 | val size = oneData._1.length 72 | val field = oneData._1 73 | val id = oneData._2 74 | val value = oneData._3 75 | val gradMap = new mutable.HashMap[Long, Array[Float]] 76 | val zeros = Array.fill(rank * fieldNum + fieldNum)(0.01f) 77 | val k = -label / math.log1p(label * pred) 78 | for (i <- 0 until size) { 79 | for (j <- i + 1 until size) { 80 | val w1 = weightMap.getOrElse(id(i), zeros) 81 | val w2 = weightMap.getOrElse(id(j), zeros) 82 | val (id1, id2, f1, f2, v1, v2) = (id(i), id(j), field(i), field(j), value(i), value(j)) 83 | val grad = calcGrad(fieldNum, rank, w1, f1, v1, w2, f2, v2, k) 84 | val start1 = f2 * rank 85 | val start2 = f1 * rank 86 | for (index <- 0 until rank) { 87 | w1(start1 + index) = grad._1(index) 88 | w2(start2 + index) = grad._2(index) 89 | } 90 | gradMap.put(id1, w1) 91 | gradMap.put(id2, w2) 92 | } 93 | } 94 | (gradMap, loss) 95 | } 96 | 97 | //calculate grad 98 | def calcGrad(fieldNum: Int, 99 | rank: Int, 100 | w1: Array[Float], 101 | w1Field: Int, 102 | x1: Float, 103 | w2: Array[Float], 104 | w2Field: Int, 105 | x2: Float, 106 | k: Double): (Array[Float], Array[Float]) = { 107 | val w1f2Grad = Array.fill(rank)(0.01f) 108 | val w2f1Grad = Array.fill(rank)(0.01f) 109 | val v1f2Start = rank * w2Field 110 | val v2f1Start = rank * w1Field 111 | for (i <- 0 until rank) { 112 | w1f2Grad(i) = lamda * 
w1(v1f2Start + i) + k.toFloat * w2(v2f1Start + i) * x1 * x2 113 | w2f1Grad(i) = lamda * w2(v2f1Start + i) + k.toFloat * w1(v1f2Start + i) * x1 * x2 114 | } 115 | println("grad1: " + w1f2Grad.mkString(",")) 116 | println("grad2: " + w2f1Grad.mkString(",")) 117 | (w1f2Grad, w2f1Grad) 118 | } 119 | //update map 120 | def updateGradMap(fieldNum: Int, 121 | rank: Int, 122 | targetGradMap: mutable.Map[Long, Array[Float]], 123 | grad: (Array[Float], Array[Float]), 124 | field1: Int, 125 | field2: Int, 126 | id1: Long, 127 | id2: Long): Unit = { 128 | val zeros = Array.fill(rank)(0.0f) 129 | val gradId1 = targetGradMap.getOrElse(id1, zeros) 130 | val gradId2 = targetGradMap.getOrElse(id2, zeros) 131 | val (w1f2Grad, w2f1Grad) = (grad._1, grad._2) 132 | val (start1, start2) = (field2 * rank, field1 * rank) 133 | val (lastGradAcc1, lastGradAcc2) = (w1f2Grad(rank * fieldNum - 1 + field1), w2f1Grad(rank * fieldNum - 1 + field2)) 134 | val lr1 = -1 / math.sqrt(1 + lastGradAcc1).toFloat 135 | val lr2 = -1 / math.sqrt(1 + lastGradAcc2).toFloat 136 | var grad2norm1 = 0.0F 137 | var grad2norm2 = 0.0F 138 | for (i <- 0 until rank) { 139 | gradId1(start1 + i) = gradId1(start1 + i) + lr1 * w1f2Grad(i) 140 | gradId2(start2 + i) = gradId2(start2 + i) + lr2 * w2f1Grad(i) 141 | grad2norm1 += math.pow(w1f2Grad(i), 2).toFloat 142 | grad2norm2 += math.pow(w2f1Grad(i), 2).toFloat 143 | 144 | } 145 | gradId1(rank * fieldNum - 1 + field1) = grad2norm1 + lastGradAcc1 146 | gradId2(rank * fieldNum - 1 + field2) = grad2norm2 + lastGradAcc2 147 | targetGradMap.put(id1, gradId1) 148 | targetGradMap.put(id2, gradId2) 149 | } 150 | 151 | //calc vec2 152 | def calc2norm(fieldNum: Int, rank: Int, vector: Array[Float], field: Int): Float = { 153 | var res = 0.0F 154 | val start = rank * field 155 | for (i <- 0 until rank) { 156 | res += math.pow(vector(start + i), 2).toFloat 157 | } 158 | res 159 | } 160 | } 161 | -------------------------------------------------------------------------------- 
/src/main/scala/net/qihoo/xitong/xdml/optimization/FM.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.optimization

import net.qihoo.xitong.xdml.linalg.BLAS._


/**
 * Second-order factorization machine (FM) primitives for binary classification.
 *
 * A feature vector is a pair `(ids, values)` of parallel arrays. `W0` is the bias, `W` holds the
 * linear weights and `V` the latent factors (length `rank`), all looked up by feature id. Ids
 * missing from `V` contribute zero factors. The loss and `deltaGrad` use `label * pred` as the
 * margin, which presumably means labels are encoded as -1/+1. TODO confirm against the caller.
 */
object FM extends Optimizer {

  /**
   * Computes the gradients and the logistic loss of one example.
   *
   * @return a map from feature id to a `rank + 1` array. Slot 0 holds the linear-weight gradient
   *         and slots `1..rank` hold the latent-factor gradients. The bias gradient is stored in
   *         slot 0 under the reserved key `0L`, so a feature whose id is 0 is skipped. The map is
   *         paired with the loss `log(1 + exp(-label * pred))`.
   */
  def train(label: Double,
            feature: (Array[Long], Array[Float]),
            W0: Float,
            W: Map[Long, Float],
            V: Map[Long, Array[Float]],
            rank: Int,
            subsampling_rate: Double = 1.0,
            subsampling_label: Double = 0.0): (Map[Long, Array[Float]], Double) = {
    val (ids, values) = feature
    val noFactors = Array.fill(rank)(0.0f)
    // pred is the predicted score: start with the bias and linear terms.
    var pred = W0 + dot(feature, W)
    // sumArray(f) = sum_i v_{i,f} * x_i, reused by every latent-factor gradient below.
    val sumArray = Array.fill(rank)(0.0f)
    (0 until rank).foreach { f =>
      var (res1, res2) = (0f, 0f)
      ids.indices.foreach { i =>
        val tmp = values(i) * V.getOrElse(ids(i), noFactors)(f)
        res1 += tmp
        res2 += tmp * tmp
      }
      sumArray(f) = res1
      pred += 0.5f * (res1 * res1 - res2)
    }

    // Logistic loss, approximated outside |z| <= 18 to avoid exp overflow.
    val z = label * pred
    val loss =
      if (z > 18) math.exp(-z)
      else if (z < -18) -z
      else math.log1p(math.exp(-z))

    val p = deltaGrad(label, pred).toFloat
    // Examples of the down-sampled class are re-weighted by 1 / subsampling_rate.
    val w = if (math.abs(label - subsampling_label) < 0.001) (1 / subsampling_rate).toFloat else 1.0f

    val biasGrad = Array.fill(rank + 1)(0.0f)
    biasGrad(0) = p * w
    var gradLossMap: Map[Long, Array[Float]] = Map(0L -> biasGrad)
    ids.indices.foreach { i =>
      val fId = ids(i)
      val fValue = values(i)
      if (fId != 0L) {
        val grad = Array.fill(rank + 1)(0.0f)
        grad(0) = fValue * p * w
        val factors = V.getOrElse(fId, noFactors)
        (0 until rank).foreach { f =>
          val vGrad = fValue * sumArray(f) - factors(f) * fValue * fValue
          grad(f + 1) = vGrad * p * w
        }
        gradLossMap += (fId -> grad)
      }
    }
    (gradLossMap, loss)
  }

  /** Returns `sigmoid` of the FM score (bias + linear + pairwise terms). */
  def predict(feature: (Array[Long], Array[Float]),
              W0: Float,
              W: Map[Long, Float],
              V: Map[Long, Array[Float]],
              rank: Int): Double = {
    var pred = W0 + dot(feature, W)
    val array = Array.fill(rank)(0.0F)
    for (f <- 0 until rank) {
      var (res1, res2) = (0f, 0f)
      feature._1.indices.foreach { id =>
        val tmp = feature._2(id) * V.getOrElse(feature._1(id), array)(f)
        res1 += tmp
        res2 += tmp * tmp
      }
      pred += 0.5f * (res1 * res1 - res2)
    }
    sigmoid(pred)
  }

  /**
   * Derivative of `log(1 + exp(-label * pred))` with respect to `pred`:
   * `-label / (1 + exp(label * pred))`, approximated for `|label * pred| > 18`.
   */
  def deltaGrad(label: Double,
                pred: Float): Double = {
    val sigma = label * pred
    if (sigma > 18)
      -label * math.exp(-sigma)
    else if (sigma < -18)
      -label
    else
      -label / (1 + math.exp(sigma)) // was Math.pow(2.718281828, sigma): a truncated e
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/optimization/Optimizer.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.optimization

/** Marker trait extended by the optimization objects; it declares no members. */
trait Optimizer {

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/PullTask.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.task

import java.util.concurrent.Callable

/** Hazelcast task that reads the serialized values of `pullSet` from table `tableName`. */
class PullTask[K, V] extends Task with Callable[java.util.Map[Long, Array[Byte]]]{
  var pullSet: java.util.Set[Long] = _

  /** Keys to fetch; must be set before the task is executed. */
  def setPullSet(pullSet: java.util.Set[Long]): Unit = {
    this.pullSet = pullSet
  }

  /** Returns the raw bytes of every requested key found by `IMap.getAll`. */
  override def call(): java.util.Map[Long, Array[Byte]] = {
    val psMap = hazelcastIns.getMap[Long, Array[Byte]](tableName)
    psMap.getAll(pullSet)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/PushTask.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.task

import java.nio.ByteBuffer
import java.util.concurrent.Callable

import net.qihoo.xitong.xdml.conf.PSDataType
import net.qihoo.xitong.xdml.conf.PSDataType.PSDataType
import net.qihoo.xitong.xdml.updater.Updater
import net.qihoo.xitong.xdml.utils.XDMLException

import scala.collection.JavaConversions._

/**
 * Hazelcast task that merges `pushMap` into table `tableName`.
 *
 * Values are stored as byte arrays written with `ByteBuffer`'s default byte order.
 */
class PushTask[K, V] extends Task with Callable[Unit] {
  // PSDataType describing V
  private var vClazz: PSDataType = _
  private var pushMap: java.util.Map[Long, V] = _
  // Element count of V (array length for the *_ARRAY types)
  private var psDatalength:Int = _
  private var updater:Updater[Long, V] = _
  private var hasUpdater = false

  def setDataLength(length: Int): Unit ={
    this.psDatalength = length
  }

  def setVClazz(dataType:PSDataType) : Unit = {
    this.vClazz = dataType
  }

  /** Installs an updater that merges pushed values with the stored ones instead of overwriting. */
  def setUpdater(updater:Updater[Long, V]): Unit ={
    this.updater = updater
    this.hasUpdater = true
  }

  def setPushMap(pushMap: java.util.Map[Long, V]): Unit = {
    this.pushMap = pushMap
  }

  /**
   * Decodes a stored byte array into V according to `vClazz`.
   *
   * @throws XDMLException for an unsupported data type
   */
  def getValueFromBytes(bytes:Array[Byte]): V = {
    val valueBuff = ByteBuffer.wrap(bytes)
    vClazz match {
      case PSDataType.INT => valueBuff.getInt().asInstanceOf[V]
      case PSDataType.LONG => valueBuff.getLong().asInstanceOf[V]
      case PSDataType.FLOAT => valueBuff.getFloat().asInstanceOf[V]
      case PSDataType.DOUBLE => valueBuff.getDouble().asInstanceOf[V]
      case PSDataType.FLOAT_ARRAY =>
        Array.tabulate(psDatalength)(index => valueBuff.getFloat(index << 2)).asInstanceOf[V]
      case PSDataType.DOUBLE_ARRAY =>
        Array.tabulate(psDatalength)(index => valueBuff.getDouble(index << 3)).asInstanceOf[V]
      case _ => throw new XDMLException("data type error!")
    }
  }

  /**
   * Encodes V into a buffer of `sizeOf(vClazz) * psDatalength` bytes.
   *
   * Scalar types presumably rely on a data length of 1; with the default of 0 the put overflows.
   * TODO confirm.
   *
   * @throws IllegalArgumentException for an unsupported data type
   */
  def getBytesFromValue(value: V): Array[Byte] = {
    val byteSize = PSDataType.sizeOf(vClazz) * psDatalength
    val byteBuff = ByteBuffer.allocate(byteSize)
    vClazz match {
      case PSDataType.INT => byteBuff.putInt(value.asInstanceOf[Int])
      case PSDataType.LONG => byteBuff.putLong(value.asInstanceOf[Long])
      case PSDataType.FLOAT => byteBuff.putFloat(value.asInstanceOf[Float])
      case PSDataType.DOUBLE => byteBuff.putDouble(value.asInstanceOf[Double])
      case PSDataType.FLOAT_ARRAY =>
        val arr = value.asInstanceOf[Array[Float]]
        arr.indices.foreach(index => byteBuff.putFloat(index << 2, arr(index)))
      case PSDataType.DOUBLE_ARRAY =>
        val arr = value.asInstanceOf[Array[Double]]
        arr.indices.foreach(index => byteBuff.putDouble(index << 3, arr(index)))
      case _ => throw new IllegalArgumentException("data type error!")
    }
    byteBuff.array()
  }

  /**
   * Reads the stored values of the pushed keys and merges them with `pushMap` through the
   * updater, if one is set; otherwise the pushed values replace the stored ones. The result is
   * then written back.
   *
   * NOTE(review): the read and the write are separate, unlocked calls, so concurrent pushes of
   * the same keys can presumably lose updates.
   */
  override def call(): Unit = {
    val weightMap = hazelcastIns.getMap[Long, Array[Byte]](tableName)
    val localMap = weightMap.getAll(pushMap.keySet).map {
      case (k, bytes) => (k, getValueFromBytes(bytes))
    }
    if (hasUpdater)
      pushMap = updater.update(localMap, pushMap)
    val bytesMap = pushMap.map {
      case (k, v) => (k, getBytesFromValue(v))
    }
    weightMap.putAll(bytesMap)
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/task/Task.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.task

import java.io.Serializable

import com.hazelcast.core.{HazelcastInstance, HazelcastInstanceAware}

/** Base class of the parameter-server tasks: holds the target table name and the Hazelcast instance. */
abstract class Task extends Serializable with HazelcastInstanceAware {
  // Injected by Hazelcast on the executing member; not serialized with the task.
  @transient
  var hazelcastIns: HazelcastInstance = _

  // Name of the map the pull/push task operates on
  protected var tableName:String = _

  def setTableName(name: String): Unit = {
    this.tableName = name
  }

  /** Stores the Hazelcast server instance the task runs on. */
  override def setHazelcastInstance(hz: HazelcastInstance): Unit = {
    this.hazelcastIns = hz
  }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/DCASGDUpdater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater
import java.util
import scala.collection.JavaConversions._

/**
 * Delay-compensated SGD updater. Each weight array holds the current weight in `v(0)` and the
 * previous weight in `v(1)`. Only `gradientMap(k)(0)` is used as the gradient.
 */
class DCASGDUpdater extends Updater[Long,Array[Float]]{
  private var coff = 0.1F
  private var learningRate = 0.01F

  /**
   * In place: `delta = -learningRate * g - coff * g^2 * (v(0) - v(1))`; `v(1) := v(0)`;
   * `v(0) := v(0) + delta`.
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (k, v) =>
      val grad = gradientMap(k)(0)
      val delta = -learningRate * grad - coff * grad * grad * (v(0) - v(1))
      v(1) = v(0)
      v(0) = v(1) + delta
      (k, v)
    }
    updateMap
  }

  def setLearningRate(lr:Float):this.type ={
    this.learningRate = lr
    this
  }

  def setCoff(coff:Float):this.type ={
    this.coff = coff
    this
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/FFMUpdater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater
import java.util
import scala.collection.JavaConversions._

/** Plain SGD updater for FFM weight arrays. */
class FFMUpdater extends Updater[Long,Array[Float]]{

  private var learningRate = 0.01F

  /**
   * Applies `w := w - learningRate * grad` element-wise and in place to every weight array in
   * `originalWeightMap`, using the gradient stored under the same key. Each key uses its own
   * gradient length, so an empty batch yields an empty map; previously the first gradient's
   * length was used for all keys.
   *
   * @throws NoSuchElementException if a key of `originalWeightMap` has no gradient
   */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (k, v) =>
      val grad = gradientMap(k)
      for (index <- grad.indices) {
        v(index) = v(index) - learningRate * grad(index)
      }
      (k, v)
    }
    updateMap
  }

  // set learning rate
  def setLearningRate(lr:Float):this.type ={
    this.learningRate = lr
    this
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/LRFTRLUpdater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater

import java.util
import scala.collection.JavaConversions._

/**
 * FTRL-Proximal updater for logistic regression. Each weight array holds three slots: `v(0)` is
 * z (accumulated adjusted gradient), `v(1)` is n (sum of squared gradients) and `v(2)` is the
 * weight w. Only `gradientMap(k)(0)` is used as the gradient.
 */
class LRFTRLUpdater extends Updater[Long, Array[Float]] {
  private var alpha = 1F
  private var beta = 1F
  private var lambda1 = 1F
  private var lambda2 = 1F

  /** Updates z and n in place, then recomputes w in closed form (0 when |z| <= lambda1). */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map {
      case (k, v) =>
        val grad = gradientMap(k)(0)
        // sigma = (sqrt(n + g^2) - sqrt(n)) / alpha
        val pValue = 1f / alpha * (Math.sqrt(v(1) + grad * grad) - Math.sqrt(v(1)))
        v(0) = (v(0) + grad - pValue * v(2)).toFloat
        v(1) = v(1) + grad * grad
        if (Math.abs(v(0)) > lambda1)
          v(2) = ((-1) * (1.0 / (lambda2 + (beta + Math.sqrt(v(1))) / alpha)) * (v(0) - Math.signum(v(0)).toInt * lambda1)).toFloat
        else
          v(2) = 0F
        (k, v)
    }
    updateMap
  }

  def setAlpha(alpha: Float): this.type = {
    this.alpha = alpha
    this
  }

  def setBeta(beta: Float): this.type = {
    this.beta = beta
    this
  }

  def setLambda1(lambda1: Float): this.type = {
    this.lambda1 = lambda1
    this
  }

  def setLambda2(lambda2: Float): this.type = {
    this.lambda2 = lambda2
    this
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/LRUpdater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater

import scala.collection.JavaConversions._

/** Plain SGD updater for scalar logistic-regression weights. */
class LRUpdater extends Updater[Long, Float]{

  private var learningRate:Float = 0.01F

  /** Returns `w - learningRate * grad` for every key of `lastWeightMap`. */
  override def update(lastWeightMap:java.util.Map[Long,Float], gradientMap:java.util.Map[Long,Float]): java.util.Map[Long,Float] = {
    val updateMap = lastWeightMap.map { case (k, v) =>
      val delta = -learningRate * gradientMap(k)
      (k, v + delta)
    }
    updateMap
  }

  // set learning rate
  def setLearningRate(lr:Float):this.type ={
    this.learningRate = lr
    this
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/MomLRUpdater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater

import java.util
import scala.collection.JavaConversions._

/**
 * Momentum SGD updater. Each weight array holds the weight in `v(0)` and the previous step in
 * `v(1)`. Only `gradientMap(k)(0)` is used as the gradient.
 */
class MomLRUpdater extends Updater[Long, Array[Float]] {
  private var coff = 0.1F
  private var learningRate = 0.01F

  /** In place: `delta = -learningRate * g + coff * v(1)`; `v(0) += delta`; `v(1) := delta`. */
  override def update(originalWeightMap: util.Map[Long, Array[Float]], gradientMap: util.Map[Long, Array[Float]]): util.Map[Long, Array[Float]] = {
    val updateMap = originalWeightMap.map { case (k, v) =>
      val delta = -learningRate * gradientMap(k)(0) + coff * v(1)
      v(0) = v(0) + delta
      v(1) = delta
      (k, v)
    }
    updateMap
  }

  // set learning rate
  def setLearningRate(lr:Float):this.type ={
    this.learningRate = lr
    this
  }

  def setCoff(coff:Float):this.type ={
    this.coff = coff
    this
  }
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/updater/Updater.scala:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.updater

/** Merges stored parameter values with pushed gradients into the values to write back. */
trait Updater[K, V] extends Serializable{
  // update function
  def update(originalWeightMap: java.util.Map[K, V], gradientMap: java.util.Map[K, V]):java.util.Map[K, V]
}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/utils/ExitCodeResolver.java:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.utils;

/** Placeholder for turning process exit codes into diagnostics. */
public class ExitCodeResolver {

    /**
     * Accepts the exit codes -1 through -7. Their handling is not implemented yet, since every
     * case is empty.
     *
     * @param exitCode the exit code to analyse
     * @throws IllegalArgumentException for any other exit code
     */
    static void analysis(int exitCode){
        switch(exitCode){
            case -1: {

            }break;
            case -2: {

            }break;
            case -3: {

            }break;
            case -4: {

            }break;
            case -5: {

            }break;
            case -6: {

            }break;
            case -7: {

            }break;
            default:
                throw new IllegalArgumentException("Unknown exit code.");
        }
    }

}
--------------------------------------------------------------------------------
/src/main/scala/net/qihoo/xitong/xdml/utils/XDMLException.java:
--------------------------------------------------------------------------------
package net.qihoo.xitong.xdml.utils;

/** Unchecked exception raised by XDML components. */
public class XDMLException extends RuntimeException{
    private static final long serialVersionUID = 1L;

    public XDMLException() {
    }

    public XDMLException(String message) {
        super(message);
    }

    public XDMLException(String message, Throwable cause) {
        super(message, cause);
    }

    public XDMLException(Throwable cause) {
        super(cause);
    }
}
--------------------------------------------------------------------------------
/src/test/scala/DataColFilter.scala:
--------------------------------------------------------------------------------
import scala.collection.mutable
import scala.io.Source
import scala.util.control.Breaks._

/**
  * @author wangxingda
  * 2018/08/16
  */

/**
 * Column filter: feature names listed with value 1 in the `[col]` section of the config file
 * (up to a `[row]` line) are dropped from every record.
 */
class DataColFilter extends DataFilter {
  val featureFilterSet = new mutable.HashSet[String]

  /** Parses the `[col]` section of `filePath`; the file is closed even if parsing fails. */
  def readFilterConfig(filePath: String): Unit = {
    val file = Source.fromFile(filePath)
    try {
      var isStarted: Boolean = false
      for (line <- file.getLines() if !line.trim.equals("")) {
        if (line.contains("[col]")) {
          isStarted = true
        } else if (isStarted) {
          if (line.trim.contains("[row]")) {
            isStarted = false
          } else {
            val kv = getKeyAndValue(line.trim)
            if (kv._2 == 1)
              featureFilterSet.add(kv._1)
          }
        }
      }
    } finally {
      file.close()
    }
    println("colFilter:" + featureFilterSet.mkString(","))
  }

  /** Keeps the features whose name (the text before `splitC`) is not in the filter set. */
  def colFilter(feature: Array[String]): Array[String] = {
    feature.filter(x => !featureFilterSet.contains(x.split(splitC)(0)))
  }

  /**
   * Parses a `key = value` line. A non-numeric value is reported and mapped to -1.
   *
   * @throws IllegalArgumentException if the line is empty or does not contain exactly one `=`
   */
  def getKeyAndValue(str: String): (String, Long) = {
    if (str.equals(""))
      throw new IllegalArgumentException
    val splits = str.split("=")
    if (splits.length != 2) {
      throw new IllegalArgumentException("Col filter getKeyAndValue Error: cause by splits's length is not equals 2: " + str)
    }
    var value = -1
    try {
      value = splits(1).trim.toInt
    } catch {
      case ex: NumberFormatException =>
        println("Col filter getKeyAndValue Error: cause by value is not a number value: " + str)
        ex.printStackTrace()
    }
    (splits(0).trim, value)
  }
}
--------------------------------------------------------------------------------
/src/test/scala/DataFilter.scala:
--------------------------------------------------------------------------------
import java.io.Serializable

/**
  * @author wangxingda
  * 2018/08/16
  */

/** Common separators and config loading for the row/column data filters. */
trait DataFilter extends Serializable{
  // Control-character separators (same values as the former octal escapes, which are deprecated).
  val splitC: String = "\u0003"
  val splitB: String = "\u0002"
  val splitA: String = "\u0001"
  val splitT: String = "\t"

  /** Reads the filter config file, which contains both the row and the column sections. */
  def readFilterConfig(filePath: String): Unit
}
--------------------------------------------------------------------------------
/src/test/scala/TestFilter.scala:
--------------------------------------------------------------------------------
import net.qihoo.xitong.xdml.utils.CityHash
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer

/** Times the row/column filtering and feature hashing of the input set by `spark.ndml.data.path`. */
object TestFilter {
  val rowFilterConfigFile = "ftrl_filter_demo.txt"
  val colFilterConfigFile = "ftrl_filter_demo.txt"
  val splitA = "\u0001"
  val splitB = "\u0002"
  val splitC = "\u0003"


  def main(args: Array[String]): Unit = {
    // Load the filters, then configure Spark.
    val rowF = new DataRowFilter()
    rowF.readFilterConfig(rowFilterConfigFile)
    val colF = new DataColFilter()
    colF.readFilterConfig(colFilterConfigFile)
    val sc = new SparkContext(new SparkConf().setAppName("TestFilter"))
    val trainInputPath = sc.getConf.get("spark.ndml.data.path", "").trim
    if (trainInputPath == "") {
      println(s"\'spark.ndml.data.path\' must be set of input data.")
      System.exit(-1)
    }
    val rawData = sc.textFile(trainInputPath).repartition(1000)

    // New approach: filter and parse in a single pass per partition using an ArrayBuffer.
    // count() is the action that actually runs the job. Without it the timer would measure only
    // the lazy mapPartitions definition, and `number` would be the RDD, not a record count.
    val s = System.nanoTime()
    val number = rawData.mapPartitions(iter => {
      val res = new ArrayBuffer[(Double, Array[Long], Array[Float])]
      while (iter.hasNext) {
        val line = iter.next()
        if (rowF.rowFilter(line)) {
          val splits = line.split(splitA)
          val list = splits.slice(1, splits.length)
          val featureList = colF.colFilter(list)
          res.append((splits(0).toDouble, featureList.map(x => CityHash.stringCityHash64(x.split(splitB)(0))), featureList.map(x => x.split(splitB)(1).toFloat)))
        }
      }
      res.iterator
    }).count()
    val e = System.nanoTime()

    // Traditional approach, kept for comparison: filter, then map.
    //    val s = System.nanoTime()
    //    val number = rawData.filter(line => rowF.rowFilter(line)).map(line => {
    //      val splits = line.split(splitA)
    //      val list = splits.slice(1, splits.length)
    //      val featureList = colF.colFilter(list)
    //      (splits(0).toDouble, featureList.map(x=>CityHash.stringCityHash64(x.split(splitB)(0))), featureList.map(x=>x.split(splitB)(1).toFloat))
    //    }).count()
    //    val e = System.nanoTime()

    println("time:" + (e - s)/1e9)
    println("number:" + number)
    sc.stop()
  }
}
--------------------------------------------------------------------------------