├── CMakeLists.txt ├── LICENSE ├── README.md ├── docker ├── deploy │ └── Dockerfile └── develop │ └── Dockerfile ├── examples └── deepctr │ ├── data.txt │ └── deepctr.py ├── tef ├── CMakeLists.txt ├── core │ ├── CMakeLists.txt │ ├── build.sh │ ├── kernels │ │ ├── example_op.cc │ │ ├── example_op.cu │ │ ├── example_op.h │ │ ├── ps_client │ │ │ ├── ps_client.h │ │ │ ├── ps_client_dummy.cc │ │ │ ├── ps_client_dummy.h │ │ │ ├── ps_client_factory.cc │ │ │ └── ps_client_factory.h │ │ ├── ps_hash_pull_op.cc │ │ ├── ps_hash_pull_op.h │ │ ├── ps_hash_push_op.cc │ │ ├── ps_hash_push_op.h │ │ ├── ps_pull_op.cc │ │ ├── ps_pull_op.h │ │ ├── ps_push_op.cc │ │ ├── ps_push_op.h │ │ ├── ps_sparse_pull_op.cc │ │ ├── ps_sparse_pull_op.h │ │ ├── ps_sparse_push_op.cc │ │ ├── ps_sparse_push_op.h │ │ ├── zero_out_op.cc │ │ └── zero_out_op.h │ └── ops │ │ ├── example_ops.cc │ │ └── ps_ops.cc └── python │ ├── CMakeLists.txt │ ├── setup.py.in │ └── tef │ ├── __init__.py │ ├── ops │ ├── __init__.py │ ├── embedding.py │ └── variable.py │ ├── pywrap │ ├── __init__.py │ ├── tef_core.py │ └── tef_core_test.py │ ├── training │ ├── __init__.py │ └── optimizer.py │ └── utils │ ├── __init__.py │ └── collections.py └── third_party └── CMakeLists.txt /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(tef) 2 | cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR) 3 | 4 | set(PACKAGE_VERSION "1.0.0.0") 5 | 6 | add_subdirectory(tef) 7 | add_subdirectory(third_party) 8 | 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TEF(TensorflowExtendFramework) 2 | 3 | ## 概述 4 | Tensorflow是目前使用最广泛的深度学习解决方案,但是在高维稀疏数据场景(如广告、推荐、搜索等)下Tensorflow有很多不足之处: 5 | 6 | * 参数必须以固定维度的矩阵形式提前分配(训练开始前),不支持参数的实时(训练过程中)分配与淘汰 7 | * 不支持参数的增量形式的导出 8 | 9 | TensorflowExtendFramework(以下简称TEF)是笔者开源的针对高维稀疏数据场景(如广告,推荐,搜索等)的深度学习解决方案. 10 | 11 | TEF通过Operation扩展的机制,将Tensorflow的参数分配与更新任务交给自定义的参数服务器来承担,从而克服了以上的几点不足: 12 | 13 | * 通过TEF可以很方便的对接自定义参数服务器 14 | * 通过TEF可以很方便的实现参数的动态分配和淘汰,以及参数增量导出 15 | * TEF以单独的Python Package形式安装部署 16 | 17 | ## 编译与安装 18 | 19 | 1.构建开发docker镜像 20 | 21 | ``` 22 | cd docker/develop/ 23 | docker build -t tef_develop . 24 | ``` 25 | 26 | 2.启动docker开发环境 27 | 28 | ``` 29 | docker run -it --net=host tef_develop 30 | 31 | ``` 32 | 33 | 2.编译,生成Python Package安装包 34 | 35 | ``` 36 | git clone https://github.com/jony0917/tensorflow-extend-framework.git 37 | cd tensorflow-extend-framework 38 | mkdir build 39 | cd build 40 | cmake .. 
41 | make tef 42 | ``` 43 | 44 | 3.pip安装tef 45 | 46 | ``` 47 | pip install build/tef/python/dist/tf-x.x.x.x-py2-none-any.whl 48 | ``` 49 | 50 | 4.运行example,确认正确安装 51 | 52 | ``` 53 | cd examples/deepctr 54 | python deepctr.py 55 | ``` 56 | 57 | 看到类似以下输出,表明安装正确: 58 | 59 | ``` 60 | ... 61 | batch=9, loss=0.23234 62 | ... 63 | ``` 64 | 65 | ## 使用指南 66 | 67 | 1. 通过TEF,你可以通过以下简单两步构建自己的支持高维稀疏数据场景的深度学习解决方案: 68 | 69 | * 首先你需要有自己的参数服务器,或者使用第三方参数服务器,如pslite,ps_plus等 70 | * 然后为你的参数服务器实现接口 tef/core/kernels/ps\_client/ps\_client.h:PsClient, 参考样例:tef/core/kernels/ps\_client/ps\_client\_dummy.h 71 | 72 | 2. 主要API介绍: 73 | 74 | |方法或类名|功能| 75 | |---|---| 76 | |tef.ops.variable|分配稠密参数| 77 | |tef.ops.embedding|获取离散参数embedding| 78 | |tef.ops.embedding_sparse|获取离散参数embedding| 79 | |tef.training.GradientDescentOptimizer|训练优化器| 80 | 81 | 82 | 参考样例:examples/deepctr/deepctr.py 83 | 84 | ## 设计文档 85 | 86 | [TensorflowExtendFramework](https://blog.csdn.net/gaofeipaopaotang/article/details/104182284) 87 | -------------------------------------------------------------------------------- /docker/deploy/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM tensorflow/tensorflow:2.1.0-gpu 3 | -------------------------------------------------------------------------------- /docker/develop/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM tensorflow/tensorflow:2.1.0-gpu 3 | 4 | RUN apt-get update && apt-get -y upgrade; \ 5 | apt-get install -y build-essential cmake 6 | 7 | RUN mkdir -p /usr/local/lib/python2.7/dist-packages/tensorflow_core/include/third_party/gpus/cuda && \ 8 | ln -s /usr/local/cuda-10.1/targets/x86_64-linux/include/ /usr/local/lib/python2.7/dist-packages/tensorflow_core/include/third_party/gpus/cuda/include -------------------------------------------------------------------------------- /examples/deepctr/data.txt: 
-------------------------------------------------------------------------------- 1 | uid:9353999, age:97, interest:1859|9227|9665, aid:8758751, ad_kw:9484|3953|6297|8253|7989, label:1 2 | uid:6675311, age:80, interest:9523|4181|5077, aid:2084710, ad_kw:5321|1953|3896, label:1 3 | uid:7062075, age:43, interest:1286|8166|4152, aid:2880725, ad_kw:7605|364|8534, label:0 4 | uid:8159127, age:91, interest:5196|5379|3925, aid:1701079, ad_kw:1330|998|8374, label:1 5 | uid:2561610, age:76, interest:3678|3345|2616, aid:1609000, ad_kw:5940|632|1418, label:0 6 | uid:9574041, age:53, interest:8366|9934|8324|8761|7528, aid:1905099, ad_kw:2228|84|5615, label:1 7 | uid:5074168, age:69, interest:2837|9596|8514|2502|4047, aid:2043983, ad_kw:2504|95|2961|2667|9737, label:0 8 | uid:7834986, age:52, interest:3209|3056|8483|8207|5606, aid:5286996, ad_kw:7306|66|6489|8611|6998, label:1 9 | uid:4691151, age:31, interest:8924|5466|9302|5383|4085, aid:8553737, ad_kw:2972|37|3491|2131|3805, label:1 10 | uid:1337162, age:36, interest:7878|3348|9547|8674, aid:9301390, ad_kw:8113|89|3006|2381|3833, label:1 11 | uid:3640250, age:55, interest:7477|6216|9854|1357, aid:3763341, ad_kw:5410|322|3646|4924|4104, label:1 12 | uid:6611600, age:48, interest:2305|9569|3995|4423, aid:6132263, ad_kw:321|657|2632|7726|1704, label:1 13 | uid:7219447, age:31, interest:4786|1838|8556|3554, aid:3973624, ad_kw:656|6158|4350|7298|5171, label:0 14 | uid:6179065, age:62, interest:8512|7938|6812|5048, aid:1143872, ad_kw:646|8976|9003|7092, label:1 15 | uid:1347542, age:56, interest:9914|1419|5406|3172|7666, aid:9948332, ad_kw:568|8846|8060|8434, label:1 16 | uid:6615128, age:59, interest:9080|6113|9778|2309|2238, aid:2566511, ad_kw:116|2401|5527|6197, label:1 17 | uid:4640079, age:46, interest:8267|1192|2910|9380|4453, aid:8158332, ad_kw:8099|7842|7978|4161, label:0 18 | uid:6103692, age:32, interest:8045|3895|3021|1518|7672, aid:5155754, ad_kw:1155|8113|7798, label:1 19 | uid:7879810, age:65, 
interest:5565|2388|3868|9909|9696, aid:5930441, ad_kw:3442|2878|1127|9874, label:1 20 | uid:9818325, age:95, interest:561|7302|6824, aid:8629044, ad_kw:6821|7471|6531|7697, label:0 21 | uid:8502912, age:41, interest:925|6682|6233, aid:7785295, ad_kw:5854|6660|8882|8253, label:1 22 | uid:8716364, age:11, interest:197|9932|6465, aid:6724705, ad_kw:4790|6029|1450|3387, label:1 23 | uid:7789805, age:22, interest:520|4662|7606, aid:4075113, ad_kw:2534|9362|4080|5459, label:0 24 | uid:6133514, age:49, interest:457|4407|9848, aid:6336337, ad_kw:9376|62|6162|1922, label:1 25 | uid:5789729, age:87, interest:7414|9634|9663|5685|9561, aid:4154569, ad_kw:2406|70|1817|3940, label:0 26 | uid:7162897, age:81, interest:8099|8516|2179|4241|3014, aid:9356155, ad_kw:7415|89|6708|3093, label:1 27 | uid:4488119, age:80, interest:9106|9408|8810|7372|9791, aid:9914512, ad_kw:1100|31|8309|9798|5753, label:0 28 | uid:1522167, age:11, interest:1296|3436|5029|1192|5255, aid:1814794, ad_kw:8288|1510|7683|8544|4818, label:1 29 | uid:3730068, age:68, interest:7480|3010|4407|3938|1370, aid:1987204, ad_kw:1654|5179|1363, label:0 30 | uid:4396590, age:71, interest:2414|5432|5523|5454|1815, aid:1364600, ad_kw:3041|2539|4021, label:1 31 | uid:3108433, age:28, interest:3482|9942, aid:7980304, ad_kw:2736|3509|3640, label:0 32 | uid:4169021, age:12, interest:5715|4977, aid:1437304, ad_kw:9419|9361|9213, label:1 33 | uid:9475283, age:14, interest:5563|3147, aid:6482805, ad_kw:9694|3169|5843|5526|6291, label:0 34 | uid:7649589, age:20, interest:5871|2202, aid:5446346, ad_kw:8932|732|7926|1624|3460, label:1 35 | uid:8050478, age:36, interest:8159|6855, aid:3312014, ad_kw:3650|8407|4246|8194, label:0 36 | uid:2650721, age:57, interest:4041|3164|5486|9064|3671, aid:9405376, ad_kw:1918|7286|2501|5208, label:1 37 | uid:5682554, age:43, interest:9624|4277|8000|6566|9565, aid:6436832, ad_kw:6180|1562|7596|2582, label:0 38 | uid:3830106, age:40, interest:3583|6655|4219|7939|5910, aid:6431803, 
ad_kw:4143|28|5338|4699, label:1 39 | uid:5615369, age:85, interest:6360|2969|4692|8852|9317, aid:3624240, ad_kw:1230|12|9125|7866, label:0 40 | uid:6902261, age:31, interest:2341|9029|1474|5313|3415, aid:4916774, ad_kw:2031|28|1476|7524, label:1 41 | uid:1367074, age:50, interest:9945|1753|9082|7515|5584, aid:4252763, ad_kw:3313|32|6078|9329, label:0 42 | uid:4919107, age:50, interest:8738|1323|9612|5845|8266, aid:9535584, ad_kw:3935|32|9019|3408, label:1 43 | uid:9781314, age:58, interest:2653|2990|2158|3611|8051, aid:8358593, ad_kw:6434|771|5111|4351, label:0 44 | uid:2077723, age:96, interest:5039|6499|2543|5510|2611, aid:7455425, ad_kw:8629|88|1014|2152|1358, label:1 45 | uid:6020647, age:25, interest:5341|5014|6397|2434, aid:3303916, ad_kw:1528|64|9099|4170|3305, label:0 46 | uid:8125088, age:74, interest:7359|2041|9936|5326, aid:8431471, ad_kw:7049|25|6065|6687|5415, label:1 47 | uid:2001858, age:45, interest:4489|8324|8421|5029, aid:4029719, ad_kw:9270|7424|3311|2013|2245, label:1 48 | uid:5066618, age:27, interest:9739|5950|4695|8166, aid:2291458, ad_kw:3233|1059|8721|1678|6196, label:0 49 | uid:1843585, age:43, interest:1252|6205|3182, aid:8784191, ad_kw:3168|8170|7925, label:1 50 | uid:3747335, age:86, interest:4125|7467|2130|1737, aid:4375473, ad_kw:9917|4866|1309, label:0 51 | uid:9727766, age:96, interest:7754|5197|8373|6595, aid:3249062, ad_kw:44|6246|7571, label:1 52 | uid:7840482, age:19, interest:7917|4084|2022|9331, aid:6214805, ad_kw:11|7897|9457, label:0 53 | uid:8220846, age:46, interest:9362|5041|6188|4474, aid:7839329, ad_kw:19|1069|8223, label:1 54 | uid:8674390, age:61, interest:4847|9718|4334|3898|1665, aid:7181967, ad_kw:59|4423|5034, label:1 55 | uid:4461713, age:53, interest:8030|6751|7145|5872|5137, aid:6596598, ad_kw:80|9248|5255, label:1 56 | uid:4351758, age:94, interest:8329|8517|7109|9017|5570, aid:3125795, ad_kw:922|5001, label:0 57 | uid:1320158, age:14, interest:7619|6149|5483|2761|5218, aid:6232779, ad_kw:769|3536, label:1 58 
| uid:9844501, age:89, interest:3226|8408|3651, aid:2305977, ad_kw:575|1479|3707|5061, label:1 59 | uid:9214127, age:18, interest:1724|7354|5608, aid:8240138, ad_kw:911|3864, label:0 60 | uid:7531276, age:80, interest:3173|2491|7182, aid:6867092, ad_kw:302|6600, label:1 61 | uid:8515137, age:80, interest:9621|1749|5353, aid:3104976, ad_kw:7583|9555|6858, label:0 62 | uid:9432604, age:35, interest:4514|8088, aid:3697956, ad_kw:4424|1173|2472, label:1 63 | uid:9992290, age:64, interest:7624|8906|6441|3531, aid:5977519, ad_kw:9628|9922|4787, label:0 64 | uid:7539277, age:68, interest:3214|8001|4987|4519, aid:6999235, ad_kw:6688|9230|8077|8115|4506, label:1 65 | uid:9514097, age:11, interest:125|8771|7563, aid:6056630, ad_kw:3546|4688|2435|9881|4948, label:1 66 | uid:2473421, age:93, interest:729|7612|4172, aid:5064223, ad_kw:8616|4099|1637, label:1 67 | uid:6179713, age:81, interest:688|7621|7546|5879, aid:9126220, ad_kw:9702|2079|4703, label:1 68 | uid:9254005, age:84, interest:546|826|2102|5567, aid:8527769, ad_kw:8257|8157, label:1 69 | uid:7425878, age:58, interest:1576|791|2918|8282, aid:8246697, ad_kw:4149|5801, label:1 70 | uid:8052043, age:95, interest:3223|607|1317|1108|6788, aid:9095950, ad_kw:747|9481, label:0 71 | uid:5237922, age:71, interest:4453|825|7330|3216|3958, aid:8096047, ad_kw:545|144|9379|1621, label:1 72 | uid:4584512, age:33, interest:6983|5529|5775|9883|2591, aid:2796569, ad_kw:384|852|5656|9143, label:1 73 | uid:3574079, age:28, interest:9736|1475|4678, aid:6914321, ad_kw:801|443|8099|2869|4737, label:1 74 | uid:4209289, age:88, interest:7151|7068|6710, aid:3789727, ad_kw:744|277|8095|8473|7936, label:0 75 | uid:3882588, age:52, interest:3859|1195|4912, aid:7793564, ad_kw:219|3236|7311|2552|2523, label:1 76 | uid:9277663, age:34, interest:6711|5904, aid:2423034, ad_kw:5543|49|3783|4432|3588, label:0 77 | uid:6682862, age:49, interest:3936|1410, aid:5298903, ad_kw:6759|75|3433|6362|3794, label:1 78 | uid:4654917, age:35, 
interest:110|1856|4977|8932, aid:8819906, ad_kw:8156|56|7670|8746|1075, label:0 79 | uid:9598000, age:98, interest:715|3793|6600|1478, aid:3728436, ad_kw:204|99|4248|4155, label:1 80 | uid:6935068, age:79, interest:405|6719|7377|7132, aid:4730969, ad_kw:339|26|3856|1065, label:0 81 | uid:2104685, age:97, interest:85|4144|1908|1155|6368, aid:8420238, ad_kw:767|1243|1557, label:1 82 | uid:2935951, age:66, interest:199|3945|1243|1849|2157, aid:7815100, ad_kw:591|2799|9094, label:0 83 | uid:9864003, age:47, interest:283|8816|5860|9300|1410, aid:6256864, ad_kw:522|1760|1384, label:1 84 | uid:2894806, age:48, interest:898|8316|1393|6292|3095, aid:2191152, ad_kw:9401|8051|5909|7925, label:0 85 | uid:7290214, age:24, interest:7886|4187|3213|3313, aid:9272494, ad_kw:6180|7632|1104|6800, label:0 86 | uid:9381697, age:13, interest:8700|218|2460|2310, aid:9820827, ad_kw:5457|8899|8935|8477, label:0 87 | uid:9140512, age:41, interest:3801|49|6238|2466, aid:4561997, ad_kw:4456|9476|8251|6698, label:0 88 | uid:1227504, age:23, interest:9378|35|8256|1627, aid:6918989, ad_kw:614|6308|3883|4388, label:0 89 | uid:6328737, age:40, interest:6376|2|9364|2756|4967, aid:7965251, ad_kw:531|8662|5262|5784, label:1 90 | uid:2247776, age:17, interest:8204|48|9870|5487|8539, aid:3535906, ad_kw:449|6409|3563, label:1 91 | uid:9234562, age:63, interest:2666|388|5736, aid:8261819, ad_kw:745|9430|6327, label:0 92 | uid:6185915, age:73, interest:7039|891|4669, aid:9902313, ad_kw:848|2876, label:1 93 | uid:4870611, age:26, interest:5199|1179|7906, aid:3777557, ad_kw:6666|6087, label:1 94 | uid:9454297, age:76, interest:6753|3552|8672, aid:2322622, ad_kw:7975|5499, label:0 95 | uid:8993227, age:83, interest:258|2273, aid:5440483, ad_kw:8535|5408, label:0 96 | uid:2246873, age:28, interest:265|352|529|3244, aid:8571464, ad_kw:3807|3162|4969|8836, label:1 97 | uid:8874975, age:87, interest:600|667|182|9644, aid:2820098, ad_kw:3590|5103|8486|9060, label:0 98 | uid:5520261, age:98, 
interest:350|823|691|3424, aid:4660616, ad_kw:5434|5458|7102|1056|5691, label:1 99 | uid:9592540, age:56, interest:612|465|327|1845, aid:7186802, ad_kw:1486|5981|7101|9382|3736, label:1 100 | uid:5786293, age:83, interest:9179|264|200|3624|4947, aid:8701707, ad_kw:1256|2278|6896|2871|2106, label:0 101 | -------------------------------------------------------------------------------- /examples/deepctr/deepctr.py: -------------------------------------------------------------------------------- 1 | 2 | import multiprocessing 3 | import tensorflow as tf 4 | import tef 5 | import tef.ops 6 | import tef.training 7 | 8 | 9 | batch_queue = multiprocessing.Queue(maxsize=5000) 10 | 11 | def load_data(): 12 | global batch_queue 13 | with open("data.txt") as fp: 14 | for line in fp.readlines(): 15 | columns = line.split(",") 16 | assert len(columns) == 6 17 | 18 | kv = {} 19 | for i in range(len(columns)): 20 | column = columns[i].strip() 21 | items = column.split(":") 22 | assert len(items) == 2 23 | key = items[0] 24 | values = items[1].split("|") 25 | assert len(values) > 0 26 | for k in range(len(values)): 27 | values[k] = int(values[k]) 28 | if key == "interest" or key == "ad_kw": 29 | while len(values) < 5: 30 | values.append(0) 31 | kv[key] = values 32 | 33 | print kv 34 | batch_queue.put((kv["uid"][0], kv["age"][0], kv["interest"], kv["aid"][0], kv["ad_kw"], kv["label"])) 35 | 36 | 37 | def data_generator(): 38 | global batch_queue 39 | while True: 40 | yield batch_queue.get() 41 | 42 | 43 | def data_from_feed(): 44 | data_set = tf.data.Dataset.from_generator(data_generator, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64, tf.float32)) 45 | #data_set = data_set.padded_batch(4, padded_shapes=[None]) 46 | data_set = data_set.batch(5) 47 | iterator = tf.compat.v1.data.make_one_shot_iterator(data_set) 48 | return iterator.get_next() 49 | 50 | 51 | def full_connect(name, input, input_dim, output_dim): 52 | w = tef.ops.variable("%s_w_%dx%d" % (name, input_dim, output_dim), 
[input_dim, output_dim], tf.float32) 53 | b = tef.ops.variable("%s_b_%d" % (name, output_dim), [output_dim], tf.float32) 54 | return tf.sigmoid(tf.matmul(input, w) + b) 55 | 56 | 57 | def dense_to_sparse(dense, missing_element): 58 | indices = tf.where(tf.not_equal(dense, missing_element)) 59 | values = tf.gather_nd(dense, indices) 60 | shape = tf.shape(dense, out_type=tf.int64) 61 | return tf.SparseTensor(indices, values, shape) 62 | 63 | def deep_ctr(): 64 | graph = tf.Graph() 65 | with graph.as_default(): 66 | uid, age, interest, aid, ad_kw, label = data_from_feed() 67 | 68 | embs = [] 69 | uid_emb = tef.ops.embedding(uid, "uid", [20], tf.float32, id_type="hash") 70 | embs.append(uid_emb) 71 | 72 | age_emb = tef.ops.embedding(age, "age", [120, 20], tf.float32, id_type="index") 73 | embs.append(age_emb) 74 | 75 | 76 | sp_interest = dense_to_sparse(interest, 0) 77 | interest_emb = tef.ops.embedding_sparse(sp_interest, 78 | "interest", 79 | [20], 80 | tf.float32, 81 | id_type="hash", 82 | combiner="mean") 83 | embs.append(interest_emb) 84 | 85 | aid_emb = tef.ops.embedding(aid, "aid", [20], tf.float32, id_type="hash") 86 | embs.append(aid_emb) 87 | 88 | 89 | 90 | sp_ad_kw = dense_to_sparse(ad_kw, 0) 91 | ad_kw_emb = tef.ops.embedding_sparse(sp_ad_kw, 92 | "ad_kw", 93 | [20], 94 | tf.float32, 95 | id_type="hash", 96 | combiner="mean") 97 | embs.append(ad_kw_emb) 98 | 99 | x = tf.concat(embs, axis=1) 100 | x = full_connect("fc_1", x, 5 * 20, 100) 101 | x = full_connect("fc_2", x, 100, 100) 102 | y = full_connect("fc_3", x, 100, 1) 103 | 104 | loss = tf.nn.sigmoid_cross_entropy_with_logits(y, label) 105 | loss_mean = tf.reduce_mean(loss) 106 | sgd_optimizer = tef.training.GradientDescentOptimizer(0.002) 107 | gs, stubs = sgd_optimizer.compute_gradients(loss) 108 | train_op = sgd_optimizer.apply_gradients(gs, stubs) 109 | 110 | sess = tf.compat.v1.Session(graph = graph) 111 | batch = 0 112 | while batch < 10: 113 | loss_value, _ = sess.run([loss_mean, train_op]) 114 | 
print "batch=%d, loss=%f" % (batch, loss_value) 115 | batch += 1 116 | 117 | if __name__ == '__main__': 118 | data_load_process = multiprocessing.Process(target=load_data) 119 | data_load_process.daemon = True 120 | data_load_process.start() 121 | 122 | deep_ctr() 123 | -------------------------------------------------------------------------------- /tef/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(tef) 2 | cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR) 3 | 4 | add_subdirectory(core) 5 | add_subdirectory(python) 6 | 7 | -------------------------------------------------------------------------------- /tef/core/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(tef) 2 | cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR) 3 | 4 | execute_process(COMMAND python -c "import tensorflow as tf; print(\" \".join(tf.sysconfig.get_compile_flags()))" OUTPUT_VARIABLE TENSORFLOW_CFLAGS) 5 | string(REGEX REPLACE "\n$" "" TENSORFLOW_CFLAGS ${TENSORFLOW_CFLAGS}) 6 | message(STATUS "TENSORFLOW_CFLAGS=${TENSORFLOW_CFLAGS}") 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -D GOOGLE_CUDA=1 ${TENSORFLOW_CFLAGS} -O3") 8 | 9 | 10 | execute_process(COMMAND python -c "import tensorflow as tf; print(\" \".join(tf.sysconfig.get_link_flags()))" OUTPUT_VARIABLE TENSORFLOW_LFLAGS) 11 | string(REGEX REPLACE "\n$" "" TENSORFLOW_LFLAGS ${TENSORFLOW_LFLAGS}) 12 | message(STATUS "TENSORFLOW_LFLAGS=${TENSORFLOW_LFLAGS}") 13 | 14 | find_package(CUDA) 15 | 16 | ############################################################################ 17 | # Gencode arguments 18 | set(SMS 30 35 37 50 52 60 70) 19 | foreach(sm ${SMS}) 20 | set(GENCODE_FLAGS ${GENCODE_FLAGS} "-gencode arch=compute_${sm},code=sm_${sm}") 21 | endforeach() 22 | 23 | set(HIGHEST_SM 70) 24 | set(GENCODE_FLAGS ${GENCODE_FLAGS} "-gencode arch=compute_${HIGHEST_SM},code=compute_${HIGHEST_SM}") 25 | message(STATUS 
"GENCODE_FLAGS=${GENCODE_FLAGS}") 26 | 27 | 28 | ############################################################################ 29 | # compile targets 30 | set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} ${GENCODE_FLAGS} "-D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -DNDEBUG -std=c++11") 31 | cuda_compile(cuda_op_objects 32 | "kernels/example_op.cu") 33 | message(STATUS "cuda_op_objects=${cuda_op_objects}") 34 | 35 | 36 | set(SOURCE_FILES 37 | "ops/example_ops.cc" 38 | "ops/ps_ops.cc" 39 | "kernels/zero_out_op.h" 40 | "kernels/zero_out_op.cc" 41 | "kernels/example_op.h" 42 | "kernels/example_op.cc" 43 | "kernels/ps_hash_pull_op.h" 44 | "kernels/ps_hash_pull_op.cc" 45 | "kernels/ps_hash_push_op.h" 46 | "kernels/ps_hash_push_op.cc" 47 | "kernels/ps_pull_op.h" 48 | "kernels/ps_pull_op.cc" 49 | "kernels/ps_push_op.h" 50 | "kernels/ps_push_op.cc" 51 | "kernels/ps_sparse_pull_op.h" 52 | "kernels/ps_sparse_pull_op.cc" 53 | "kernels/ps_sparse_push_op.h" 54 | "kernels/ps_sparse_push_op.cc" 55 | "kernels/ps_client/ps_client.h" 56 | "kernels/ps_client/ps_client_dummy.h" 57 | "kernels/ps_client/ps_client_dummy.cc" 58 | "kernels/ps_client/ps_client_factory.h" 59 | "kernels/ps_client/ps_client_factory.cc") 60 | 61 | add_library(tef_core SHARED ${SOURCE_FILES} ${cuda_op_objects}) 62 | target_link_libraries(tef_core ${TENSORFLOW_LFLAGS}) 63 | add_custom_command(TARGET tef_core 64 | POST_BUILD 65 | COMMAND ${CMAKE_COMMAND} -E copy $ "${CMAKE_CURRENT_SOURCE_DIR}/../python/tef/pywrap/" 66 | COMMENT "coping output of tef_core task") -------------------------------------------------------------------------------- /tef/core/build.sh: -------------------------------------------------------------------------------- 1 | #/bin/dash 2 | 3 | TF_CFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 4 | TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') 5 | 6 | echo ${TF_CFLAGS[@]} 7 | echo 
${TF_LFLAGS[@]} 8 | 9 | /usr/local/cuda-10.1/bin/nvcc -std=c++11 -c -o kernels/example_op.cu.o kernels/example_op.cu.cc ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -DNDEBUG 10 | g++ -std=c++11 -shared ops/example_ops.cc kernels/zero_out_op.cc kernels/example_op.cc kernels/example_op.cu.o -o libtef_core.so -D GOOGLE_CUDA=1 -fPIC ${TF_CFLAGS[@]} -lcudart ${TF_LFLAGS[@]} -O2 -L/usr/local/cuda-10.1/targets/x86_64-linux/lib/ 11 | -------------------------------------------------------------------------------- /tef/core/kernels/example_op.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include "example_op.h" 5 | 6 | using namespace tensorflow; 7 | 8 | using CPUDevice = Eigen::ThreadPoolDevice; 9 | using GPUDevice = Eigen::GpuDevice; 10 | 11 | template 12 | struct ExampleFunctor { 13 | void operator()(const CPUDevice& device, int size, const T* in, T* out) { 14 | for(int i = 0; i < size; i++){ 15 | out[i] = 2 * in[i]; 16 | } 17 | } 18 | }; 19 | 20 | // OpKernel definition. 21 | // template parameter is the datatype of the tensors 22 | template 23 | class ExampleOp : public OpKernel { 24 | public: 25 | explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {} 26 | 27 | public: 28 | void Compute(OpKernelContext* context) override { 29 | // Grab the input tensor 30 | const Tensor& input_tensor = context->input(0); 31 | 32 | // Create an output tensor 33 | Tensor* output_tensor = NULL; 34 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 35 | &output_tensor)); 36 | // Do the computation. 
37 | OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max, 38 | errors::InvalidArgument("Too many elements in tensor")); 39 | 40 | ExampleFunctor()( 41 | context->eigen_device(), 42 | static_cast(input_tensor.NumElements()), 43 | input_tensor.flat().data(), 44 | output_tensor->flat().data()); 45 | } 46 | }; 47 | 48 | // Register the CPU kernels. 49 | #define REGISTER_CPU(T) \ 50 | REGISTER_KERNEL_BUILDER( \ 51 | Name("Example").Device(DEVICE_CPU).TypeConstraint("T"), \ 52 | ExampleOp); 53 | REGISTER_CPU(float); 54 | REGISTER_CPU(int32); 55 | 56 | 57 | // Register the GPU kernels. 58 | //#ifdef GOOGLE_CUDA 59 | #define REGISTER_GPU(T) \ 60 | REGISTER_KERNEL_BUILDER( \ 61 | Name("Example").Device(DEVICE_GPU).TypeConstraint("T"), \ 62 | ExampleOp); 63 | REGISTER_GPU(float); 64 | REGISTER_GPU(int32); 65 | //#endif //GOOGLE_CUDA 66 | 67 | 68 | -------------------------------------------------------------------------------- /tef/core/kernels/example_op.cu: -------------------------------------------------------------------------------- 1 | 2 | #ifdef GOOGLE_CUDA 3 | #define EIGEN_USE_GPU 4 | 5 | 6 | #include "tensorflow/core/util/gpu_kernel_helper.h" 7 | #include "example_op.h" 8 | 9 | using namespace tensorflow; 10 | using GPUDevice = Eigen::GpuDevice; 11 | 12 | // Define the CUDA kernel. 13 | template 14 | __global__ void ExampleCudaKernel(const int size, const T* in, T* out) { 15 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { 16 | out[i] = 2 * ldg(in + i); 17 | } 18 | 19 | } 20 | 21 | // Define the GPU implementation that launches the CUDA kernel. 22 | template 23 | void ExampleFunctor::operator()(const GPUDevice& d, int size, const T* in, T* out) { 24 | // Launch the cuda kernel. 25 | // 26 | // See core/util/gpu_kernel_helper.h for example of computing 27 | // block count and thread_per_block count. 
28 | int block_count = 1024; 29 | int thread_per_block = 20; 30 | ExampleCudaKernel<<>>(size, in, out); 31 | } 32 | 33 | template struct ExampleFunctor; 34 | template struct ExampleFunctor; 35 | 36 | #endif //GOOGLE_CUDA 37 | 38 | 39 | -------------------------------------------------------------------------------- /tef/core/kernels/example_op.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef KERNEL_EXAMPLE_H_ 4 | #define KERNEL_EXAMPLE_H_ 5 | 6 | 7 | template 8 | struct ExampleFunctor{ 9 | void operator()(const Device& d, int size, const T* in, T* out); 10 | }; 11 | 12 | 13 | #if GOOGLE_CUDA 14 | //Partially specialize functor for GpuDevice 15 | template 16 | struct ExampleFunctor{ 17 | void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out); 18 | }; 19 | 20 | #endif //GOOGLE_CUDA 21 | 22 | #endif //KERNEL_EXAMPLE_H_ -------------------------------------------------------------------------------- /tef/core/kernels/ps_client/ps_client.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef PS_CLIENT_H 3 | #define PS_CLIENT_H 4 | 5 | #include 6 | 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | using namespace tensorflow; 9 | 10 | 11 | class PsClient { 12 | public: 13 | virtual ~PsClient(){} 14 | 15 | enum VariableType{ 16 | VT_DENSE = 0, 17 | VT_HASH 18 | }; 19 | 20 | struct VariableInfo{ 21 | TensorShape shape_; 22 | DataType dtype_; 23 | string var_name_; 24 | VariableType var_type_; 25 | }; 26 | 27 | 28 | public: 29 | virtual void RegisterVariable(const VariableInfo& info, int &id) = 0; 30 | 31 | virtual void DensePull(int id, Tensor* data) = 0; 32 | 33 | virtual void DensePush(int id, 34 | const Tensor &data, 35 | const std::string& updater, 36 | float learning_rate) = 0; 37 | 38 | virtual void SparsePull(int id, 39 | const Tensor &index, 40 | Tensor* data) = 0; 41 | 42 | virtual void SparsePush(int id, 43 | const Tensor& index, 44 | const 
Tensor& data, 45 | const std::string& updater, 46 | float learning_rate) = 0; 47 | 48 | virtual void HashPull(int id, 49 | const Tensor& hash, 50 | Tensor* data) = 0; 51 | 52 | 53 | virtual void HashPush(int id, 54 | const Tensor& hash, 55 | const Tensor& data, 56 | const std::string& updater, 57 | float learning_rate) = 0; 58 | 59 | }; 60 | 61 | #endif //PS_CLIENT_H -------------------------------------------------------------------------------- /tef/core/kernels/ps_client/ps_client_dummy.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include 4 | #include 5 | 6 | #include "ps_client_dummy.h" 7 | #include "tensorflow/core/framework/tensor_util.h" 8 | 9 | 10 | namespace{ 11 | 12 | template 13 | void ZeroInit(Tensor * target){ 14 | auto flat = target->flat(); 15 | for(int i = 0; i < target->NumElements(); ++i){ 16 | flat(i) = static_cast(0); 17 | } 18 | } 19 | 20 | template 21 | void RandomInit(Tensor * target){ 22 | unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); 23 | std::default_random_engine generator(seed); 24 | std::uniform_int_distribution distribution(-2, 2); 25 | auto flat = target->flat(); 26 | for(int i = 0; i < target->NumElements(); ++i){ 27 | int dice_roll = distribution(generator); 28 | flat(i) = static_cast(dice_roll); 29 | } 30 | } 31 | 32 | template 33 | void SGDUpdate(float alpha, const Tensor& gradient, Tensor * target){ 34 | CHECK(target); 35 | CHECK(gradient.NumElements() == target->NumElements()); 36 | 37 | auto target_vec = target->flat(); 38 | auto gradient_vec = gradient.flat(); 39 | for(int i = 0; i < target->NumElements(); i++){ 40 | target_vec(i) -= alpha * gradient_vec(i) ; 41 | } 42 | } 43 | 44 | 45 | template 46 | void SGDUpdateSparse(float alpha, const Tensor& index, const Tensor& gradient, Tensor * target){ 47 | auto index_vec = index.vec(); 48 | auto target_matrix = target->flat_inner_dims(); 49 | auto gradient_matrix = gradient.flat_inner_dims(); 50 | 
CHECK(target_matrix.dimension(1) == gradient_matrix.dimension(1)); 51 | 52 | for(int i = 0; i < index.NumElements(); i++){ 53 | CHECK(index_vec(i) < target->dim_size(0)); 54 | for(int j = 0; j < target_matrix.dimension(1); j++){ 55 | target_matrix(index_vec(i), j) -= alpha * gradient_matrix(i, j); 56 | } 57 | } 58 | } 59 | 60 | template 61 | void SGDUpdateHash(float alpha, const Tensor& hash, const Tensor& gradient, std::unordered_map * target){ 62 | auto hash_vec = hash.vec(); 63 | auto gradient_matrix = gradient.flat_inner_dims(); 64 | for(int i = 0; i < hash.NumElements(); i++){ 65 | int64 key = hash_vec(i); 66 | auto it = target->find(key); 67 | CHECK(it != target->end()); 68 | 69 | Tensor slice = it->second; 70 | auto slice_flat = slice.flat(); 71 | for(int j = 0; j < slice.NumElements(); j++){ 72 | slice_flat(j) -= alpha * gradient_matrix(i, j); 73 | } 74 | } 75 | 76 | } 77 | 78 | template 79 | void LookUp(const Tensor& index, const Tensor& param, Tensor * out){ 80 | 81 | auto index_vec = index.vec(); 82 | auto out_matrix = out->flat_inner_dims(); 83 | auto param_matrix = param.flat_inner_dims(); 84 | CHECK(out_matrix.dimension(1) == param_matrix.dimension(1)); 85 | 86 | for(int i = 0; i < index.NumElements(); i++){ 87 | CHECK(index_vec(i) < param.dim_size(0)); 88 | 89 | for(int j = 0; j < out_matrix.dimension(1); j++){ 90 | out_matrix(i, j) = param_matrix(index_vec(i), j); 91 | } 92 | } 93 | } 94 | 95 | template 96 | void HashLookUp(const Tensor& hash, DataType dtype, const TensorShape shape, std::unordered_map * param, Tensor * out){ 97 | auto hash_vec = hash.vec(); 98 | auto out_matrix = out->flat_inner_dims(); 99 | for(int i = 0; i < hash.NumElements(); i++){ 100 | int64 key = hash_vec(i); 101 | auto it = param->find(key); 102 | if(it == param->end()){ 103 | Tensor missing(dtype, shape); 104 | RandomInit(&missing); 105 | (*param)[key] = missing; 106 | auto missing_flat = missing.flat(); 107 | for(int j = 0; j < missing.NumElements(); j++){ 108 | 
out_matrix(i, j) = missing_flat(j); 109 | } 110 | }else{ 111 | Tensor slice = it->second; 112 | auto slice_flat = slice.flat(); 113 | for(int j = 0; j < it->second.NumElements(); j++){ 114 | out_matrix(i, j) = slice_flat(j); 115 | } 116 | } 117 | } 118 | } 119 | 120 | 121 | } 122 | 123 | 124 | std::mutex PsClientDummy::s_instance_mutex_; 125 | 126 | //static 127 | PsClientDummy * PsClientDummy::GetInstance(){ 128 | static PsClientDummy * instance = nullptr; 129 | if (!instance){ 130 | s_instance_mutex_.lock(); 131 | if(!instance){ 132 | instance = new PsClientDummy(); 133 | } 134 | s_instance_mutex_.unlock(); 135 | } 136 | return instance; 137 | } 138 | 139 | void PsClientDummy::RegisterVariable(const VariableInfo& info, int &id) { 140 | variable_mutex_.lock(); 141 | auto it = variable_ids_.find(info.var_name_); 142 | if(it != variable_ids_.end()){ 143 | id = it->second; 144 | CHECK(id < variable_infos_.size()); 145 | CHECK(variable_infos_[id].shape_ == info.shape_); 146 | CHECK(variable_infos_[id].dtype_ == info.dtype_); 147 | CHECK(variable_infos_[id].var_type_ == info.var_type_); 148 | 149 | }else{ 150 | id = variables_.size(); 151 | variable_ids_[info.var_name_] = id; 152 | Variable var; 153 | if(info.var_type_ == VT_DENSE){ 154 | Tensor New(info.dtype_, info.shape_); 155 | switch(info.dtype_){ 156 | case DT_FLOAT: 157 | RandomInit(&New); 158 | break; 159 | case DT_DOUBLE: 160 | RandomInit(&New); 161 | break; 162 | case DT_INT32: 163 | RandomInit(&New); 164 | break; 165 | case DT_INT64: 166 | RandomInit(&New); 167 | break; 168 | default: 169 | CHECK(false); 170 | break; 171 | } 172 | var.dense_value_ = New; 173 | } 174 | variables_.push_back(var); 175 | variable_infos_.push_back(info); 176 | } 177 | 178 | variable_mutex_.unlock(); 179 | } 180 | 181 | void PsClientDummy::DensePull(int id, Tensor* data) { 182 | std::cout<<"DensePull variable_infos_[id].dtype_="<(learning_rate, data, &variables_[id].dense_value_); 206 | break; 207 | case DT_DOUBLE: 208 | 
SGDUpdate(learning_rate, data, &variables_[id].dense_value_); 209 | break; 210 | case DT_INT32: 211 | SGDUpdate(learning_rate, data, &variables_[id].dense_value_); 212 | break; 213 | case DT_INT64: 214 | SGDUpdate(learning_rate, data, &variables_[id].dense_value_); 215 | break; 216 | default: 217 | CHECK(false); 218 | break; 219 | } 220 | variable_mutex_.unlock(); 221 | } 222 | 223 | void PsClientDummy::SparsePull(int id, 224 | const Tensor &index, 225 | Tensor* data) { 226 | std::cout<<"SparsePull variable_infos_[id].dtype_="<(index, variables_[id].dense_value_, data); 235 | break; 236 | case DT_DOUBLE: 237 | LookUp(index, variables_[id].dense_value_, data); 238 | break; 239 | case DT_INT32: 240 | LookUp(index, variables_[id].dense_value_, data); 241 | break; 242 | case DT_INT64: 243 | LookUp(index, variables_[id].dense_value_, data); 244 | break; 245 | default: 246 | CHECK(false); 247 | break; 248 | } 249 | variable_mutex_.unlock(); 250 | 251 | 252 | } 253 | 254 | void PsClientDummy::SparsePush(int id, 255 | const Tensor& index, 256 | const Tensor& data, 257 | const std::string& updater, float learning_rate) { 258 | std::cout<<"SparsePush variable_infos_[id].dtype_="<(learning_rate, index, data, &variables_[id].dense_value_); 267 | break; 268 | case DT_DOUBLE: 269 | SGDUpdateSparse(learning_rate, index, data, &variables_[id].dense_value_); 270 | break; 271 | case DT_INT32: 272 | SGDUpdateSparse(learning_rate, index, data, &variables_[id].dense_value_); 273 | break; 274 | case DT_INT64: 275 | SGDUpdateSparse(learning_rate, index, data, &variables_[id].dense_value_); 276 | break; 277 | default: 278 | CHECK(false); 279 | break; 280 | } 281 | variable_mutex_.unlock(); 282 | } 283 | 284 | 285 | void PsClientDummy::HashPull(int id, 286 | const Tensor& hash, 287 | Tensor* data) { 288 | std::cout<<"HashPull variable_infos_[id].dtype_="<(learning_rate, hash, data, &variables_[id].hash_value_); 330 | break; 331 | case DT_DOUBLE: 332 | SGDUpdateHash(learning_rate, hash, 
data, &variables_[id].hash_value_); 333 | break; 334 | case DT_INT32: 335 | SGDUpdateHash(learning_rate, hash, data, &variables_[id].hash_value_); 336 | break; 337 | case DT_INT64: 338 | SGDUpdateHash(learning_rate, hash, data, &variables_[id].hash_value_); 339 | break; 340 | default: 341 | CHECK(false); 342 | break; 343 | } 344 | variable_mutex_.unlock(); 345 | } 346 | 347 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_client/ps_client_dummy.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef PS_CLIENT_DUMMY_H 4 | #define PS_CLIENT_DUMMY_H 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "ps_client.h" 11 | 12 | 13 | class PsClientDummy : public PsClient { 14 | public: 15 | static PsClientDummy * GetInstance(); 16 | 17 | public: 18 | virtual void RegisterVariable(const VariableInfo& info, int &id) override; 19 | 20 | virtual void DensePull(int id, Tensor* data) override; 21 | 22 | virtual void DensePush(int id, 23 | const Tensor &data, 24 | const std::string& updater, 25 | float learning_rate) override; 26 | 27 | virtual void SparsePull(int id, 28 | const Tensor &index, 29 | Tensor* data) override; 30 | 31 | virtual void SparsePush(int id, 32 | const Tensor& index, 33 | const Tensor& data, 34 | const std::string& updater, 35 | float learning_rate) override; 36 | 37 | virtual void HashPull(int id, 38 | const Tensor& hash, 39 | Tensor* data) override; 40 | 41 | 42 | virtual void HashPush(int id, 43 | const Tensor& hash, 44 | const Tensor& data, 45 | const std::string& updater, 46 | float learning_rate) override; 47 | 48 | private: 49 | std::mutex variable_mutex_; 50 | static std::mutex s_instance_mutex_; 51 | 52 | struct Variable{ 53 | Tensor dense_value_; 54 | std::unordered_map hash_value_; 55 | }; 56 | 57 | std::vector variables_; 58 | std::vector variable_infos_; 59 | std::unordered_map variable_ids_; 60 | 61 | }; 62 | 63 | #endif 
//PS_CLIENT_DUMMY_H -------------------------------------------------------------------------------- /tef/core/kernels/ps_client/ps_client_factory.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "ps_client_factory.h" 4 | #include "ps_client_dummy.h" 5 | 6 | 7 | PsClient * PsClientFactory::Build(){ 8 | return PsClientDummy::GetInstance(); 9 | } 10 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_client/ps_client_factory.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef PS_CLIENT_FACTORY_H 3 | #define PS_CLIENT_FACTORY_H 4 | 5 | #include "ps_client.h" 6 | 7 | class PsClientFactory { 8 | public: 9 | static PsClient * Build(); 10 | }; 11 | 12 | #endif //PS_CLIENT_FACTORY_H -------------------------------------------------------------------------------- /tef/core/kernels/ps_hash_pull_op.cc: -------------------------------------------------------------------------------- 1 | #include "ps_hash_pull_op.h" 2 | #include "ps_client/ps_client_factory.h" 3 | 4 | class PsHashPullOp : public OpKernel { 5 | public: 6 | explicit PsHashPullOp(OpKernelConstruction* context) : OpKernel(context){ 7 | OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); 8 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 9 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 10 | 11 | ps_client_ = PsClientFactory::Build(); 12 | PsClient::VariableInfo var_info; 13 | var_info.var_name_ = var_name_; 14 | var_info.shape_ = shape_; 15 | var_info.dtype_ = dtype_; 16 | var_info.var_type_ = PsClient::VT_HASH; 17 | ps_client_->RegisterVariable(var_info, var_id_); 18 | } 19 | 20 | 21 | public: 22 | void Compute(OpKernelContext* context) override { 23 | const Tensor &index = context->input(0); 24 | CHECK(index.dims() == 1)<<"index.dims="<< index.dims(); 25 | 26 | Tensor* output_tensor = nullptr; 27 | TensorShape 
output_tensor_shape(index.shape()); 28 | output_tensor_shape.AppendShape(shape_); 29 | OP_REQUIRES_OK(context, context->allocate_output(0, output_tensor_shape, &output_tensor)); 30 | ps_client_->HashPull(var_id_, index, output_tensor); 31 | } 32 | 33 | private: 34 | TensorShape shape_; 35 | DataType dtype_; 36 | string var_name_; 37 | int var_id_; 38 | PsClient * ps_client_; 39 | }; 40 | 41 | 42 | #define REGISTER_CPU_KERNEL(T) \ 43 | REGISTER_KERNEL_BUILDER(Name("PsHashPull").Device(DEVICE_CPU).TypeConstraint("dtype"), PsHashPullOp); 44 | REGISTER_CPU_KERNEL(bool) 45 | REGISTER_CPU_KERNEL(int) 46 | REGISTER_CPU_KERNEL(int64) 47 | REGISTER_CPU_KERNEL(float) 48 | REGISTER_CPU_KERNEL(double) 49 | 50 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_hash_pull_op.h: -------------------------------------------------------------------------------- 1 | #ifndef PS_HASH_PULL_OP_H 2 | #define PS_HASH_PULL_OP_H 3 | 4 | #include "tensorflow/core/framework/op_kernel.h" 5 | using namespace tensorflow; 6 | 7 | 8 | 9 | #endif //PS_HASH_PULL_OP_H -------------------------------------------------------------------------------- /tef/core/kernels/ps_hash_push_op.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "ps_hash_push_op.h" 4 | #include "ps_client/ps_client_factory.h" 5 | 6 | class PsHashPushOp : public OpKernel { 7 | public: 8 | explicit PsHashPushOp(OpKernelConstruction* context) : OpKernel(context){ 9 | OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); 10 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 11 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 12 | OP_REQUIRES_OK(context, context->GetAttr("updater", &updater_)); 13 | OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_)); 14 | 15 | 16 | ps_client_ = PsClientFactory::Build(); 17 | PsClient::VariableInfo var_info; 18 | var_info.var_name_ = 
var_name_; 19 | var_info.shape_ = shape_; 20 | var_info.dtype_ = dtype_; 21 | var_info.var_type_ = PsClient::VT_HASH; 22 | ps_client_->RegisterVariable(var_info, var_id_); 23 | } 24 | 25 | public: 26 | void Compute(OpKernelContext* context) override { 27 | const Tensor &index = context->input(0); 28 | const Tensor &data = context->input(1); 29 | ps_client_->HashPush(var_id_, index, data, updater_, learning_rate_); 30 | } 31 | 32 | 33 | private: 34 | TensorShape shape_; 35 | DataType dtype_; 36 | std::string var_name_; 37 | std::string updater_; 38 | float learning_rate_; 39 | 40 | int var_id_; 41 | PsClient * ps_client_; 42 | }; 43 | 44 | 45 | #define REGISTER_CPU_KERNEL(T) \ 46 | REGISTER_KERNEL_BUILDER(Name("PsHashPush").Device(DEVICE_CPU).TypeConstraint("dtype"), PsHashPushOp); 47 | REGISTER_CPU_KERNEL(bool) 48 | REGISTER_CPU_KERNEL(int) 49 | REGISTER_CPU_KERNEL(int64) 50 | REGISTER_CPU_KERNEL(float) 51 | REGISTER_CPU_KERNEL(double) 52 | 53 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_hash_push_op.h: -------------------------------------------------------------------------------- 1 | #ifndef PS_HASH_PUSH_OP_H 2 | #define PS_HASH_PUSH_OP_H 3 | 4 | #include "tensorflow/core/framework/op_kernel.h" 5 | using namespace tensorflow; 6 | 7 | 8 | 9 | #endif //PS_HASH_PUSH_OP_H 10 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_pull_op.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "ps_pull_op.h" 3 | #include "ps_client/ps_client_factory.h" 4 | 5 | 6 | class PsPullOp : public OpKernel { 7 | public: 8 | explicit PsPullOp(OpKernelConstruction* context) : OpKernel(context){ 9 | OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); 10 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 11 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 12 | 13 | ps_client_ = 
PsClientFactory::Build(); 14 | PsClient::VariableInfo var_info; 15 | var_info.var_name_ = var_name_; 16 | var_info.shape_ = shape_; 17 | var_info.dtype_ = dtype_; 18 | var_info.var_type_ = PsClient::VT_DENSE; 19 | ps_client_->RegisterVariable(var_info, var_id_); 20 | } 21 | 22 | 23 | public: 24 | void Compute(OpKernelContext* context) override { 25 | Tensor* output_tensor = nullptr; 26 | OP_REQUIRES_OK(context, context->allocate_output(0, shape_, &output_tensor)); 27 | ps_client_->DensePull(var_id_, output_tensor); 28 | } 29 | 30 | private: 31 | TensorShape shape_; 32 | DataType dtype_; 33 | string var_name_; 34 | int var_id_; 35 | PsClient * ps_client_; 36 | }; 37 | 38 | 39 | 40 | 41 | #define REGISTER_CPU_KERNEL(T) \ 42 | REGISTER_KERNEL_BUILDER(Name("PsPull").Device(DEVICE_CPU).TypeConstraint("dtype"), PsPullOp); 43 | REGISTER_CPU_KERNEL(bool) 44 | REGISTER_CPU_KERNEL(int) 45 | REGISTER_CPU_KERNEL(int64) 46 | REGISTER_CPU_KERNEL(float) 47 | REGISTER_CPU_KERNEL(double) 48 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_pull_op.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef PS_PULL_OP_H 4 | #define PS_PULL_OP_H 5 | 6 | #include "tensorflow/core/framework/op_kernel.h" 7 | using namespace tensorflow; 8 | 9 | 10 | 11 | #endif //PS_PULL_OP_H 12 | 13 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_push_op.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "ps_push_op.h" 3 | #include "ps_client/ps_client_factory.h" 4 | 5 | class PsPushOp : public OpKernel { 6 | public: 7 | explicit PsPushOp(OpKernelConstruction* context) : OpKernel(context){ 8 | OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); 9 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 10 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 11 | 
OP_REQUIRES_OK(context, context->GetAttr("updater", &updater_)); 12 | OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_)); 13 | 14 | ps_client_ = PsClientFactory::Build(); 15 | PsClient::VariableInfo var_info; 16 | var_info.var_name_ = var_name_; 17 | var_info.shape_ = shape_; 18 | var_info.dtype_ = dtype_; 19 | var_info.var_type_ = PsClient::VT_DENSE; 20 | ps_client_->RegisterVariable(var_info, var_id_); 21 | } 22 | 23 | 24 | public: 25 | void Compute(OpKernelContext* context) override { 26 | const Tensor &data = context->input(0); 27 | ps_client_->DensePush(var_id_, data, updater_, learning_rate_); 28 | } 29 | 30 | 31 | private: 32 | TensorShape shape_; 33 | DataType dtype_; 34 | string var_name_; 35 | string updater_; 36 | float learning_rate_; 37 | 38 | int var_id_; 39 | PsClient * ps_client_; 40 | }; 41 | 42 | 43 | 44 | 45 | #define REGISTER_CPU_KERNEL(T) \ 46 | REGISTER_KERNEL_BUILDER(Name("PsPush").Device(DEVICE_CPU).TypeConstraint("dtype"), PsPushOp); 47 | REGISTER_CPU_KERNEL(bool) 48 | REGISTER_CPU_KERNEL(int) 49 | REGISTER_CPU_KERNEL(int64) 50 | REGISTER_CPU_KERNEL(float) 51 | REGISTER_CPU_KERNEL(double) 52 | 53 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_push_op.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef PS_PUSH_OP_H 3 | #define PS_PUSH_OP_H 4 | 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | using namespace tensorflow; 7 | 8 | 9 | #endif //PS_PUSH_OP_H -------------------------------------------------------------------------------- /tef/core/kernels/ps_sparse_pull_op.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "ps_sparse_pull_op.h" 3 | #include "ps_client/ps_client_factory.h" 4 | 5 | class PsSparsePullOp : public OpKernel { 6 | public: 7 | explicit PsSparsePullOp(OpKernelConstruction* context) : OpKernel(context){ 8 | OP_REQUIRES_OK(context, 
context->GetAttr("var_name", &var_name_)); 9 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 10 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 11 | CHECK(shape_.dims() >= 2); 12 | 13 | ps_client_ = PsClientFactory::Build(); 14 | PsClient::VariableInfo var_info; 15 | var_info.var_name_ = var_name_; 16 | var_info.shape_ = shape_; 17 | var_info.dtype_ = dtype_; 18 | var_info.var_type_ = PsClient::VT_DENSE; 19 | ps_client_->RegisterVariable(var_info, var_id_); 20 | } 21 | 22 | 23 | public: 24 | void Compute(OpKernelContext* context) override { 25 | const Tensor &index = context->input(0); 26 | CHECK(index.dims() == 1)<<"index.dims="<< index.dims(); 27 | 28 | Tensor* output_tensor = nullptr; 29 | TensorShape output_tensor_shape(index.shape()); 30 | for(int i = 1; i < shape_.dims(); i++){ 31 | output_tensor_shape.AddDim(shape_.dim_size(i)); 32 | } 33 | OP_REQUIRES_OK(context, context->allocate_output(0, output_tensor_shape, &output_tensor)); 34 | ps_client_->SparsePull(var_id_, index, output_tensor); 35 | } 36 | 37 | private: 38 | TensorShape shape_; 39 | DataType dtype_; 40 | string var_name_; 41 | 42 | int var_id_; 43 | PsClient * ps_client_; 44 | }; 45 | 46 | 47 | #define REGISTER_CPU_KERNEL(T) \ 48 | REGISTER_KERNEL_BUILDER(Name("PsSparsePull").Device(DEVICE_CPU).TypeConstraint("dtype"), PsSparsePullOp); 49 | REGISTER_CPU_KERNEL(bool) 50 | REGISTER_CPU_KERNEL(int) 51 | REGISTER_CPU_KERNEL(int64) 52 | REGISTER_CPU_KERNEL(float) 53 | REGISTER_CPU_KERNEL(double) 54 | 55 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_sparse_pull_op.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef PS_SPARSE_PULL_OP_H 4 | #define PS_SPARSE_PULL_OP_H 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | using namespace tensorflow; 7 | 8 | 9 | #endif //PS_SPARSE_PULL_OP_H 10 | 
-------------------------------------------------------------------------------- /tef/core/kernels/ps_sparse_push_op.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "ps_sparse_push_op.h" 4 | #include "ps_client/ps_client_factory.h" 5 | 6 | class PsSparsePushOp : public OpKernel { 7 | public: 8 | explicit PsSparsePushOp(OpKernelConstruction* context) : OpKernel(context){ 9 | OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); 10 | OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 11 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 12 | OP_REQUIRES_OK(context, context->GetAttr("updater", &updater_)); 13 | OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_)); 14 | 15 | 16 | ps_client_ = PsClientFactory::Build(); 17 | PsClient::VariableInfo var_info; 18 | var_info.var_name_ = var_name_; 19 | var_info.shape_ = shape_; 20 | var_info.dtype_ = dtype_; 21 | var_info.var_type_ = PsClient::VT_DENSE; 22 | ps_client_->RegisterVariable(var_info, var_id_); 23 | } 24 | 25 | 26 | public: 27 | void Compute(OpKernelContext* context) override { 28 | const Tensor &index = context->input(0); 29 | const Tensor &data = context->input(1); 30 | ps_client_->SparsePush(var_id_, index, data, updater_, learning_rate_); 31 | } 32 | 33 | private: 34 | TensorShape shape_; 35 | DataType dtype_; 36 | std::string var_name_; 37 | std::string updater_; 38 | float learning_rate_; 39 | 40 | int var_id_; 41 | PsClient * ps_client_; 42 | }; 43 | 44 | 45 | 46 | #define REGISTER_CPU_KERNEL(T) \ 47 | REGISTER_KERNEL_BUILDER(Name("PsSparsePush").Device(DEVICE_CPU).TypeConstraint("dtype"), PsSparsePushOp); 48 | REGISTER_CPU_KERNEL(bool) 49 | REGISTER_CPU_KERNEL(int) 50 | REGISTER_CPU_KERNEL(int64) 51 | REGISTER_CPU_KERNEL(float) 52 | REGISTER_CPU_KERNEL(double) 53 | 54 | -------------------------------------------------------------------------------- /tef/core/kernels/ps_sparse_push_op.h: 
-------------------------------------------------------------------------------- 1 | 2 | #ifndef PS_SPARSE_PUSH_OP_H 3 | #define PS_SPARSE_PUSH_OP_H 4 | 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | using namespace tensorflow; 7 | 8 | #endif //PS_SPARSE_PUSH_OP_H -------------------------------------------------------------------------------- /tef/core/kernels/zero_out_op.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "zero_out_op.h" 3 | 4 | ZeroOutOp::ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) { 5 | } 6 | 7 | void ZeroOutOp::Compute(OpKernelContext* context) { 8 | // Grab the input tensor 9 | const Tensor& input_tensor = context->input(0); 10 | auto input = input_tensor.flat(); 11 | 12 | // Create an output tensor 13 | Tensor* output_tensor = NULL; 14 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 15 | &output_tensor)); 16 | auto output_flat = output_tensor->flat(); 17 | 18 | // Set all but the first element of the output tensor to 0. 19 | const int N = input.size(); 20 | for (int i = 1; i < N; i++) { 21 | output_flat(i) = 0; 22 | } 23 | 24 | // Preserve the first input value if possible. 
25 | if (N > 0) output_flat(0) = input(0); 26 | } 27 | 28 | 29 | REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp); 30 | -------------------------------------------------------------------------------- /tef/core/kernels/zero_out_op.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef ZERO_OUT_OP_H_ 3 | #define ZERO_OUT_OP_H_ 4 | 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | 7 | using namespace tensorflow; 8 | 9 | class ZeroOutOp : public OpKernel { 10 | public: 11 | explicit ZeroOutOp(OpKernelConstruction * context); 12 | 13 | public: 14 | void Compute(OpKernelContext* context) override; 15 | }; 16 | 17 | 18 | 19 | #endif //ZERO_OUT_OP_H_ 20 | 21 | -------------------------------------------------------------------------------- /tef/core/ops/example_ops.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "tensorflow/core/framework/op.h" 4 | #include "tensorflow/core/framework/shape_inference.h" 5 | 6 | using namespace tensorflow; 7 | 8 | 9 | REGISTER_OP("ZeroOut") 10 | .Input("to_zero: int32") 11 | .Output("zeroed: int32") 12 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 13 | c->set_output(0, c->input(0)); 14 | return Status::OK(); 15 | }); 16 | 17 | REGISTER_OP("Example") 18 | .Input("multiply_by_two: T") 19 | .Output("output: T") 20 | .Attr("T: {int32, float}") 21 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 22 | c->set_output(0, c->input(0)); 23 | return Status::OK(); 24 | }); 25 | 26 | -------------------------------------------------------------------------------- /tef/core/ops/ps_ops.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "tensorflow/core/framework/op.h" 3 | #include "tensorflow/core/framework/shape_inference.h" 4 | 5 | using namespace tensorflow; 6 | 7 | 8 | REGISTER_OP("PsPull") 9 | .Attr("var_name: string") 10 | 
// Dense push op: one float input `g` and no outputs. The attributes name the
// target variable and carry its shape/dtype, an updater name and a learning
// rate. The Python optimizer in this repository passes a gradient tensor as
// `g` and "SGD" as `updater`.
// NOTE(review): what the kernel does with these values is not visible here.
// The declaration order of inputs and attrs determines the argument order of
// the generated Python wrapper, so do not reorder them.
REGISTER_OP("PsPush")
    .Input("g: float")
    .Attr("var_name: string")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Attr("updater: string")
    .Attr("learning_rate: float");



// Sparse pull op: takes int64 `index` values and produces one output whose
// element type is given by the `dtype` attr. No shape function is registered,
// so the output shape is not inferred at graph-construction time.
REGISTER_OP("PsSparsePull")
    .Input("index: int64")
    .Attr("var_name: string")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Output("output: dtype");
// Sparse push op: int64 `index` values plus a float input `g`, no outputs.
// Attributes mirror those of "PsPush".
// NOTE(review): presumably `g` holds one gradient row per index - confirm
// against the kernel.
REGISTER_OP("PsSparsePush")
    .Input("index: int64")
    .Input("g: float")
    .Attr("var_name: string")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Attr("updater: string")
    .Attr("learning_rate: float");




// Hash pull op: same signature as "PsSparsePull", except that the int64 input
// is named `hash` instead of `index`. No shape function is registered.
REGISTER_OP("PsHashPull")
    .Input("hash: int64")
    .Attr("var_name: string")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Output("output: dtype");
// Hash push op: same signature as "PsSparsePush", except that the int64 input
// is named `hash` instead of `index`.
REGISTER_OP("PsHashPush")
    .Input("hash: int64")
    .Input("g: float")
    .Attr("var_name: string")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Attr("updater: string")
    .Attr("learning_rate: float");
-------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name = "tef", 5 | version = '${PACKAGE_VERSION}', 6 | package_dir = {'':'${CMAKE_CURRENT_SOURCE_DIR}'}, 7 | packages = setuptools.find_packages('${CMAKE_CURRENT_SOURCE_DIR}'), 8 | package_data = {'':['libtef_core.so']}) 9 | 10 | -------------------------------------------------------------------------------- /tef/python/tef/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jony0917/tensorflow-extend-framework/5db3a0ed373173fd799d292bfcc7e5544882e9d0/tef/python/tef/__init__.py -------------------------------------------------------------------------------- /tef/python/tef/ops/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from embedding import * 3 | from variable import * 4 | 5 | -------------------------------------------------------------------------------- /tef/python/tef/ops/embedding.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import tef.pywrap 4 | import tef.utils 5 | 6 | 7 | 8 | def embedding(ids, name, shape, dtype, id_type="index"): 9 | ids, idx = tf.unique(ids) 10 | if id_type == "index": 11 | emb = tef.pywrap.ps_sparse_pull(ids, name, shape, dtype) 12 | tef.utils.add_to_collection(tef.utils.TEF_TRAINABLE_COLLECTION, 13 | tef.utils.VariableSub(emb, name, shape, dtype, ids, "index")) 14 | elif id_type == "hash": 15 | emb = tef.pywrap.ps_hash_pull(ids, name, shape, dtype) 16 | tef.utils.add_to_collection(tef.utils.TEF_TRAINABLE_COLLECTION, 17 | tef.utils.VariableSub(emb, name, shape, dtype, ids, "hash")) 18 | else: 19 | assert False 20 | 21 | emb = tf.gather(emb, idx) 22 | return emb 23 | 24 | 25 | def embedding_sparse(sp_ids, name, shape, dtype, id_type="index", combiner="mean"): 26 | ids, idx = tf.unique(sp_ids.values) 
27 | if id_type == "index": 28 | emb = tef.pywrap.ps_sparse_pull(ids, name, shape, dtype) 29 | tef.utils.add_to_collection(tef.utils.TEF_TRAINABLE_COLLECTION, 30 | tef.utils.VariableSub(emb, name, shape, dtype, ids, "index")) 31 | elif id_type == "hash": 32 | emb = tef.pywrap.ps_hash_pull(ids, name, shape, dtype) 33 | tef.utils.add_to_collection(tef.utils.TEF_TRAINABLE_COLLECTION, 34 | tef.utils.VariableSub(emb, name, shape, dtype, ids, "hash")) 35 | 36 | emb = tf.gather(emb, idx) 37 | segment_ids = sp_ids.indices[:, 0] 38 | if combiner == "sum": 39 | emb = tf.math.segment_sum(emb, segment_ids) 40 | elif combiner == "mean": 41 | emb_sum = tf.math.segment_sum(emb, segment_ids) 42 | weight_sum = tf.math.segment_sum(tf.ones(tf.shape(emb)), segment_ids) 43 | emb = tf.math.divide(emb_sum, weight_sum) 44 | else: 45 | assert False 46 | 47 | return emb 48 | -------------------------------------------------------------------------------- /tef/python/tef/ops/variable.py: -------------------------------------------------------------------------------- 1 | 2 | import tef 3 | import tef.pywrap 4 | import tef.utils 5 | 6 | def variable(name, shape, dtype): 7 | v = tef.pywrap.ps_pull(name, shape, dtype) 8 | tef.utils.add_to_collection(tef.utils.TEF_TRAINABLE_COLLECTION, 9 | tef.utils.VariableSub(v, name, shape, dtype, None, "dense")) 10 | return v 11 | -------------------------------------------------------------------------------- /tef/python/tef/pywrap/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from tef_core import * 3 | 4 | -------------------------------------------------------------------------------- /tef/python/tef/pywrap/tef_core.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | tef_core_lib = tf.load_op_library(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libtef_core.so')) 5 | zero_out = tef_core_lib.zero_out 6 | 
class FMModelWarmerTest(unittest.TestCase):
    """Smoke tests for the custom ops exported by ``tef_core``."""

    def _run_in_session(self, op, values, device=None, config=None):
        """Build ``op`` on an int32 placeholder in a fresh graph and run it.

        :param op: callable that maps a tensor to the op output.
        :param values: python list fed into the placeholder.
        :param device: optional device string the op is pinned to.
        :param config: optional ``ConfigProto`` for the session.
        :return: the evaluated op output.
        """
        graph = tf.Graph()
        with graph.as_default():
            if device is None:
                feed = tf.compat.v1.placeholder(tf.int32)
                output = op(feed)
            else:
                with tf.device(device):
                    feed = tf.compat.v1.placeholder(tf.int32)
                    output = op(feed)
        # The context manager closes the session even when run() raises.
        with tf.compat.v1.Session(graph=graph, config=config) as sess:
            return sess.run(output, feed_dict={feed: values})

    def test_zero_out(self):
        # Eager / tf.function path.
        @tf.function
        def zero_out(x):
            return tef_core.zero_out(x)
        res = zero_out([1, 2, 3, 4, 5, 6])
        print(res)
        # ZeroOut keeps the first element and zeroes all the others.
        self.assertEqual(res.numpy().tolist(), [1, 0, 0, 0, 0, 0])

        # Graph / session path.
        res = self._run_in_session(tef_core.zero_out, [2, 2, 2, 2, 2])
        print(res)
        self.assertEqual(res.tolist(), [2, 0, 0, 0, 0])

    def test_example_op(self):
        config = tf.compat.v1.ConfigProto(allow_soft_placement=False,
                                          log_device_placement=True)
        res = self._run_in_session(tef_core.example, [3, 3, 3, 3, 3],
                                   config=config)
        print(res)
        # The op's shape function declares output shape == input shape.
        self.assertEqual(res.shape, (5,))

    def test_example_op_on_gpu(self):
        # allow_soft_placement=False makes this fail when no GPU kernel or no
        # GPU device is available, which is what the test is meant to detect.
        config = tf.compat.v1.ConfigProto(allow_soft_placement=False,
                                          log_device_placement=True)
        res = self._run_in_session(tef_core.example, [4, 4, 4, 4, 4],
                                   device="/GPU", config=config)
        print(res)
        self.assertEqual(res.shape, (5,))
class BaseOptimizer(object):
    """Base class for optimizers that train tef (parameter-server) variables.

    Subclasses implement :meth:`apply_gradients`; :meth:`minimize` combines
    gradient computation and application.
    """

    def __init__(self):
        pass

    def compute_gradients(self, loss):
        """Compute the gradient of ``loss`` for every registered tef variable.

        :param loss: tensor to differentiate.
        :return: tuple ``(gs, stubs)`` of equally long lists; ``gs[i]`` is the
            gradient for ``stubs[i]`` (``None`` when ``tf.gradients`` finds no
            path from the variable to ``loss``).
        """
        # get_collection returns None when nothing has been registered yet.
        tef_trainable = tef.utils.get_collection(
            tef.utils.TEF_TRAINABLE_COLLECTION) or []
        gs = []
        stubs = []
        for stub in tef_trainable:
            gradient = tf.gradients(loss, stub.var)
            if len(gradient) != 1:
                raise ValueError(
                    "expected exactly one gradient for variable %r, got %d"
                    % (stub.name, len(gradient)))
            gs.append(gradient[0])
            stubs.append(stub)
        return gs, stubs

    def apply_gradients(self, gs, stubs):
        """
        To be implemented in subclass.

        :param gs: gradients, list of tf.Tensor or tf.IndexedSlices object.
        :param stubs: tef variable stubs
        :return: train operation
        """
        # An explicit exception survives ``python -O``, unlike ``assert``.
        raise NotImplementedError(
            "apply_gradients must be implemented by a subclass")

    def minimize(self, loss):
        """Build the train operation that minimizes ``loss``.

        :param loss: tensor to minimize.
        :return: whatever :meth:`apply_gradients` returns.
        """
        gs, stubs = self.compute_gradients(loss)
        return self.apply_gradients(gs, stubs)


class GradientDescentOptimizer(BaseOptimizer):
    """Optimizer that pushes gradients with the "SGD" updater."""

    def __init__(self, learning_rate):
        """
        :param learning_rate: learning rate forwarded to every push op.
        """
        super(GradientDescentOptimizer, self).__init__()
        self.learning_rate = learning_rate

    def apply_gradients(self, gs, stubs):
        """Create one push op per gradient and group them.

        :param gs: gradients; tf.Tensor for "dense" stubs, tf.IndexedSlices
            for "index" and "hash" stubs. ``None`` entries are skipped.
        :param stubs: tef variable stubs, aligned with ``gs``.
        :return: an operation grouping all push ops.
        :raises ValueError: if the lists differ in length or a stub has an
            unknown category.
        :raises TypeError: if a gradient has the wrong type for its stub.
        """
        if len(gs) != len(stubs):
            raise ValueError(
                "gs and stubs must have the same length, got %d and %d"
                % (len(gs), len(stubs)))

        push_ops = []
        for gradient, stub in zip(gs, stubs):
            if gradient is None:
                # The loss does not depend on this variable: nothing to push.
                continue

            if stub.category == "dense":
                if not isinstance(gradient, tf.Tensor):
                    raise TypeError(
                        "gradient of dense variable %r must be a tf.Tensor"
                        % (stub.name,))
                push_op = tef.pywrap.ps_push(gradient,
                                             stub.name,
                                             stub.shape,
                                             stub.dtype,
                                             "SGD",
                                             self.learning_rate)
            elif stub.category in ("index", "hash"):
                if not isinstance(gradient, tf.IndexedSlices):
                    raise TypeError(
                        "gradient of %s variable %r must be a tf.IndexedSlices"
                        % (stub.category, stub.name))
                if stub.category == "index":
                    sparse_push = tef.pywrap.ps_sparse_push
                else:
                    sparse_push = tef.pywrap.ps_hash_push
                # gradient.indices address rows of the pulled tensor; map them
                # back to the ids those rows were pulled for.
                ids = tf.gather(stub.ids, gradient.indices)
                push_op = sparse_push(ids,
                                      gradient.values,
                                      stub.name,
                                      stub.shape,
                                      stub.dtype,
                                      "SGD",
                                      self.learning_rate)
            else:
                raise ValueError(
                    "unknown variable category %r for variable %r"
                    % (stub.category, stub.name))

            push_ops.append(push_op)

        return tf.group(*push_ops)
# Key of the collection that holds the stubs of all trainable tef variables.
TEF_TRAINABLE_COLLECTION = "tef_trainable_collection"


class VariableSub(object):
    """Stub describing one tef variable that was pulled into the graph.

    The attributes are stored exactly as passed in:

    * ``var`` - the tensor produced by the pull op.
    * ``name`` / ``shape`` / ``dtype`` - the variable's name, shape and dtype.
    * ``ids`` - ids the rows of ``var`` were pulled for (``None`` for dense).
    * ``category`` - callers in this package use "dense", "index" or "hash".
    """

    def __init__(self, var, var_name, var_shape, dtype, ids=None, category="dense"):
        self.var = var
        self.name = var_name
        self.shape = var_shape
        self.dtype = dtype
        self.ids = ids
        self.category = category


def add_to_collection(name, stub):
    """Append ``stub`` to the collection ``name``, creating it if needed.

    :param name: collection key.
    :param stub: object to store (normally a :class:`VariableSub`).
    """
    # dict.has_key() does not exist on Python 3; setdefault works on both
    # Python 2 and 3 and needs only a single lookup.
    GLOBAL_COLLECTIONS.setdefault(name, []).append(stub)


def get_collection(name):
    """Return the list stored under ``name``.

    :param name: collection key.
    :return: the live list of stubs (not a copy), or ``None`` if nothing was
        ever added under ``name``.
    """
    return GLOBAL_COLLECTIONS.get(name)