├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── WORKSPACE ├── configure.sh ├── examples └── simple.py ├── requirements-dev.txt ├── setup.py ├── tf_pjc ├── BUILD ├── __init__.py ├── cc │ ├── kernels.cc │ └── ops.cc └── python │ ├── ops.py │ ├── ops_test.py │ ├── protocol.py │ └── protocol_test.py └── third_party ├── glog └── BUILD └── tf ├── BUILD ├── BUILD.tpl └── tf_configure.bzl /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | commands: 4 | 5 | bootstrap-macos: 6 | steps: 7 | - run: 8 | name: Bootstrap macOS 9 | command: | 10 | HOMEBREW_NO_AUTO_UPDATE=1 brew tap bazelbuild/tap >> build.log 11 | HOMEBREW_NO_AUTO_UPDATE=1 brew install \ 12 | bazelbuild/tap/bazel mmv tree >> build.log 13 | 14 | curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh --silent 15 | bash Miniconda3-latest-MacOSX-x86_64.sh -b -f >> build.log 16 | ~/miniconda3/bin/conda create -n py3.5 python=3.5 -y 17 | ln -s ~/miniconda3/envs/py3.5/bin/python ~/python3.5 18 | ~/miniconda3/bin/conda create -n py3.6 python=3.6 -y 19 | ln -s ~/miniconda3/envs/py3.6/bin/python ~/python3.6 20 | 21 | create-pyenv: 22 | # Create new Python virtual environment 23 | parameters: 24 | python-version: 25 | type: string 26 | python-environment: 27 | type: string 28 | steps: 29 | - run: 30 | name: Create Python << parameters.python-version >> environment '<< parameters.python-environment >>' 31 | command: | 32 | ~/python<< parameters.python-version >> -m venv << parameters.python-environment >> 33 | 34 | build: 35 | parameters: 36 | python-version: 37 | type: string 38 | tensorflow-version: 39 | type: string 40 | # next parameter should be derived 41 | python-environment: 42 | type: string 43 | steps: 44 | - create-pyenv: 45 | python-environment: << parameters.python-environment >> 46 | python-version: << parameters.python-version >> 47 | - run: 48 | name: Install 
requirements-dev.txt in '<< parameters.python-environment >>' 49 | command: | 50 | . << parameters.python-environment >>/bin/activate 51 | pip install -q -U -r requirements-dev.txt 52 | pip freeze 53 | - run: 54 | name: Install TensorFlow << parameters.tensorflow-version >> in '<< parameters.python-environment >>' 55 | command: | 56 | . << parameters.python-environment >>/bin/activate 57 | make clean 58 | pip install -q -U tensorflow==<< parameters.tensorflow-version >> 59 | make .bazelrc 60 | # reduce Bazel output to logs 61 | echo 'test --noshow_progress --noshow_loading_progress' >> .bazelrc 62 | echo 'build --noshow_progress --noshow_loading_progress' >> .bazelrc 63 | - run: 64 | name: Test in '<< parameters.python-environment >>' 65 | command: | 66 | . << parameters.python-environment >>/bin/activate 67 | python --version 68 | pip freeze 69 | make test 70 | - run: 71 | name: Build in '<< parameters.python-environment >>' 72 | command: | 73 | . << parameters.python-environment >>/bin/activate 74 | python --version 75 | pip freeze 76 | DIR_TAGGED=./out/builds/py<< parameters.python-version >>-tf<< parameters.tensorflow-version >> make build 77 | - persist_to_workspace: 78 | root: ./out 79 | paths: 80 | - builds/py<< parameters.python-version >>-tf<< parameters.tensorflow-version >> 81 | 82 | bundle: 83 | parameters: 84 | python-version: 85 | type: string 86 | # next parameter should be derived 87 | python-environment: 88 | type: string 89 | steps: 90 | - create-pyenv: 91 | python-environment: << parameters.python-environment >> 92 | python-version: << parameters.python-version >> 93 | - run: 94 | name: Install requirements-dev.txt in '<< parameters.python-environment >>' 95 | command: | 96 | . 
<< parameters.python-environment >>/bin/activate 97 | pip install -q -U -r requirements-dev.txt 98 | pip freeze 99 | - attach_workspace: 100 | at: ./out 101 | - run: 102 | name: Merge builds 103 | command: | 104 | tree ./out 105 | rsync -avm ./out/builds/*/ ./out/merged 106 | tree ./out 107 | - run: 108 | name: Bundle package in '<< parameters.python-environment >>' 109 | command: | 110 | . << parameters.python-environment >>/bin/activate 111 | python --version 112 | pip freeze 113 | DIR_TAGGED=./out/merged DIR_WHEEL=./out/wheelhouse make bundle 114 | tree ./out 115 | - persist_to_workspace: 116 | root: ./out 117 | paths: 118 | - 'wheelhouse' 119 | 120 | whltest: 121 | parameters: 122 | python-version: 123 | type: string 124 | tensorflow-version: 125 | type: string 126 | python-environment: 127 | type: string 128 | steps: 129 | - create-pyenv: 130 | python-version: << parameters.python-version >> 131 | python-environment: << parameters.python-environment >> 132 | - attach_workspace: 133 | at: ./out 134 | - run: 135 | name: Configure '<< parameters.python-environment >>' to use TensorFlow << parameters.tensorflow-version >> 136 | command: | 137 | set -e 138 | set -x 139 | tree ./out/wheelhouse 140 | . << parameters.python-environment >>/bin/activate 141 | # we want to make sure that tests are run against whatever is 142 | # in the wheelhouse; for this we'd like to use --no-index but 143 | # that will also block dependencies from being installed. 
144 | # as a result we first install dependencies by installing the 145 | # package and then immediately remove it again 146 | pip install -q -U tf-pjc --find-links ./out/wheelhouse 147 | pip uninstall tf-pjc -y 148 | # install the package, but forced to only use the wheelhouse 149 | pip install -U tf-pjc --no-deps --no-cache-dir --no-index --find-links ./out/wheelhouse 150 | # make sure we are testing against the right version of TensorFlow 151 | pip install -q -U tensorflow==<< parameters.tensorflow-version >> 152 | - run: 153 | name: Test wheel in '<< parameters.python-environment >>' 154 | command: | 155 | . << parameters.python-environment >>/bin/activate 156 | python --version 157 | pip freeze 158 | make pytest 159 | 160 | jobs: 161 | 162 | build-linux: 163 | parameters: 164 | python-version: 165 | type: string 166 | tensorflow-version: 167 | type: string 168 | docker: 169 | - image: tfencrypted/tf-big:build 170 | working_directory: ~/repo 171 | steps: 172 | - checkout 173 | - build: 174 | python-version: << parameters.python-version >> 175 | tensorflow-version: << parameters.tensorflow-version >> 176 | python-environment: build-linux-py<< parameters.python-version >>-tf<< parameters.tensorflow-version >> 177 | 178 | build-macos: 179 | parameters: 180 | python-version: 181 | type: string 182 | tensorflow-version: 183 | type: string 184 | macos: 185 | xcode: "10.0.0" 186 | working_directory: ~/repo 187 | steps: 188 | - checkout 189 | - bootstrap-macos 190 | - build: 191 | python-version: << parameters.python-version >> 192 | tensorflow-version: << parameters.tensorflow-version >> 193 | python-environment: build-macos-py<< parameters.python-version >>-tf<< parameters.tensorflow-version >> 194 | 195 | bundle-linux: 196 | parameters: 197 | python-version: 198 | type: string 199 | docker: 200 | - image: tfencrypted/tf-big:build 201 | working_directory: ~/repo 202 | steps: 203 | - checkout 204 | - bundle: 205 | python-version: << parameters.python-version >> 206 | 
python-environment: bundle-linux-py<< parameters.python-version >> 207 | 208 | bundle-macos: 209 | parameters: 210 | python-version: 211 | type: string 212 | macos: 213 | xcode: "10.0.0" 214 | working_directory: ~/repo 215 | steps: 216 | - checkout 217 | - bootstrap-macos 218 | - bundle: 219 | python-version: << parameters.python-version >> 220 | python-environment: bundle-macos-py<< parameters.python-version >> 221 | 222 | whltest-linux: 223 | parameters: 224 | python-version: 225 | type: string 226 | tensorflow-version: 227 | type: string 228 | docker: 229 | - image: tfencrypted/tf-big:whltest 230 | working_directory: ~/repo 231 | steps: 232 | - checkout 233 | - whltest: 234 | python-version: << parameters.python-version >> 235 | python-environment: test-linux-py<< parameters.python-version >> 236 | tensorflow-version: << parameters.tensorflow-version >> 237 | 238 | whltest-macos: 239 | parameters: 240 | python-version: 241 | type: string 242 | tensorflow-version: 243 | type: string 244 | macos: 245 | xcode: "10.0.0" 246 | working_directory: ~/repo 247 | steps: 248 | - checkout 249 | - bootstrap-macos 250 | - whltest: 251 | python-version: << parameters.python-version >> 252 | python-environment: test-macos-py<< parameters.python-version >> 253 | tensorflow-version: << parameters.tensorflow-version >> 254 | 255 | store: 256 | docker: 257 | - image: tfencrypted/tf-big:deploy 258 | working_directory: ~/repo 259 | steps: 260 | - checkout 261 | - attach_workspace: 262 | at: ./out 263 | - run: 264 | name: List content to be stored 265 | command: | 266 | tree ./out/wheelhouse 267 | - store_artifacts: 268 | path: ./out/wheelhouse 269 | destination: wheelhouse 270 | 271 | deploy: 272 | docker: 273 | - image: tfencrypted/tf-big:deploy 274 | working_directory: ~/repo 275 | steps: 276 | - checkout 277 | - create-pyenv: 278 | python-version: "3.6" 279 | python-environment: "deploy-py3.6" 280 | - attach_workspace: 281 | at: ./out 282 | - run: 283 | name: Configure 
'deploy-3.6' 284 | command: | 285 | . deploy-py3.6/bin/activate 286 | pip install -q -U -r requirements-dev.txt 287 | - run: 288 | name: Upload to PyPI 289 | command: | 290 | tree ./out/wheelhouse 291 | . deploy-py3.6/bin/activate 292 | DIR_WHEEL=./out/wheelhouse make push-wheels 293 | 294 | workflows: 295 | version: 2 296 | 297 | # these workflows implement the following logic: 298 | # - non-master branch: run quick tests 299 | # - master branch: build, test, and store wheels 300 | # - non-semver tag: build, test, and store wheels 301 | # - semver tag: build, test, store, and deploy wheels 302 | 303 | quicktest: 304 | jobs: 305 | - build-linux: 306 | name: build-linux-py3.6-tf1.14.0 307 | python-version: "3.6" 308 | tensorflow-version: "1.14.0" 309 | filters: 310 | branches: 311 | ignore: master 312 | tags: 313 | ignore: /.*/ 314 | 315 | - bundle-linux: 316 | name: bundle-linux-py3.6 317 | python-version: "3.6" 318 | requires: 319 | - build-linux-py3.6-tf1.14.0 320 | filters: 321 | branches: 322 | ignore: master 323 | tags: 324 | ignore: /.*/ 325 | 326 | - whltest-linux: 327 | name: whltest-linux-py3.6-tf1.14.0 328 | python-version: "3.6" 329 | tensorflow-version: "1.14.0" 330 | requires: 331 | - bundle-linux-py3.6 332 | filters: 333 | branches: 334 | ignore: master 335 | tags: 336 | ignore: /.*/ 337 | 338 | linux-py3.5: 339 | jobs: 340 | - build-linux: 341 | name: build-linux-py3.5-tf1.13.1 342 | python-version: "3.5" 343 | tensorflow-version: "1.13.1" 344 | filters: 345 | branches: 346 | only: master 347 | tags: 348 | only: /.*/ 349 | - build-linux: 350 | name: build-linux-py3.5-tf1.13.2 351 | python-version: "3.5" 352 | tensorflow-version: "1.13.2" 353 | filters: 354 | branches: 355 | only: master 356 | tags: 357 | only: /.*/ 358 | - build-linux: 359 | name: build-linux-py3.5-tf1.14.0 360 | python-version: "3.5" 361 | tensorflow-version: "1.14.0" 362 | filters: 363 | branches: 364 | only: master 365 | tags: 366 | only: /.*/ 367 | 368 | - bundle-linux: 369 | 
name: bundle-linux-py3.5 370 | python-version: "3.5" 371 | requires: 372 | - build-linux-py3.5-tf1.13.1 373 | - build-linux-py3.5-tf1.13.2 374 | - build-linux-py3.5-tf1.14.0 375 | filters: 376 | branches: 377 | only: master 378 | tags: 379 | only: /.*/ 380 | 381 | - whltest-linux: 382 | name: whltest-linux-py3.5-tf1.13.1 383 | python-version: "3.5" 384 | tensorflow-version: "1.13.1" 385 | requires: 386 | - bundle-linux-py3.5 387 | filters: 388 | branches: 389 | only: master 390 | tags: 391 | only: /.*/ 392 | - whltest-linux: 393 | name: whltest-linux-py3.5-tf1.13.2 394 | python-version: "3.5" 395 | tensorflow-version: "1.13.2" 396 | requires: 397 | - bundle-linux-py3.5 398 | filters: 399 | branches: 400 | only: master 401 | tags: 402 | only: /.*/ 403 | - whltest-linux: 404 | name: whltest-linux-py3.5-tf1.14.0 405 | python-version: "3.5" 406 | tensorflow-version: "1.14.0" 407 | requires: 408 | - bundle-linux-py3.5 409 | filters: 410 | branches: 411 | only: master 412 | tags: 413 | only: /.*/ 414 | 415 | - store: 416 | name: store-linux-py3.5 417 | requires: 418 | - whltest-linux-py3.5-tf1.13.1 419 | - whltest-linux-py3.5-tf1.13.2 420 | - whltest-linux-py3.5-tf1.14.0 421 | filters: 422 | branches: 423 | only: master 424 | tags: 425 | only: /.*/ 426 | 427 | - hold: 428 | type: approval 429 | name: hold-linux-py3.5 430 | requires: 431 | - store-linux-py3.5 432 | filters: 433 | branches: 434 | ignore: /.*/ 435 | tags: 436 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 437 | 438 | - deploy: 439 | name: deploy-linux-py3.5 440 | requires: 441 | - hold-linux-py3.5 442 | filters: 443 | branches: 444 | ignore: /.*/ 445 | tags: 446 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 447 | 448 | linux-py3.6: 449 | jobs: 450 | - build-linux: 451 | name: build-linux-py3.6-tf1.13.1 452 | python-version: "3.6" 453 | tensorflow-version: "1.13.1" 454 | filters: 455 | branches: 456 | only: master 457 | tags: 458 | only: /.*/ 459 | - build-linux: 460 | 
name: build-linux-py3.6-tf1.13.2 461 | python-version: "3.6" 462 | tensorflow-version: "1.13.2" 463 | filters: 464 | branches: 465 | only: master 466 | tags: 467 | only: /.*/ 468 | - build-linux: 469 | name: build-linux-py3.6-tf1.14.0 470 | python-version: "3.6" 471 | tensorflow-version: "1.14.0" 472 | filters: 473 | branches: 474 | only: master 475 | tags: 476 | only: /.*/ 477 | 478 | - bundle-linux: 479 | name: bundle-linux-py3.6 480 | python-version: "3.6" 481 | requires: 482 | - build-linux-py3.6-tf1.13.1 483 | - build-linux-py3.6-tf1.13.2 484 | - build-linux-py3.6-tf1.14.0 485 | filters: 486 | branches: 487 | only: master 488 | tags: 489 | only: /.*/ 490 | 491 | - whltest-linux: 492 | name: whltest-linux-py3.6-tf1.13.1 493 | python-version: "3.6" 494 | tensorflow-version: "1.13.1" 495 | requires: 496 | - bundle-linux-py3.6 497 | filters: 498 | branches: 499 | only: master 500 | tags: 501 | only: /.*/ 502 | - whltest-linux: 503 | name: whltest-linux-py3.6-tf1.13.2 504 | python-version: "3.6" 505 | tensorflow-version: "1.13.2" 506 | requires: 507 | - bundle-linux-py3.6 508 | filters: 509 | branches: 510 | only: master 511 | tags: 512 | only: /.*/ 513 | - whltest-linux: 514 | name: whltest-linux-py3.6-tf1.14.0 515 | python-version: "3.6" 516 | tensorflow-version: "1.14.0" 517 | requires: 518 | - bundle-linux-py3.6 519 | filters: 520 | branches: 521 | only: master 522 | tags: 523 | only: /.*/ 524 | 525 | - store: 526 | name: store-linux-py3.6 527 | requires: 528 | - whltest-linux-py3.6-tf1.13.1 529 | - whltest-linux-py3.6-tf1.13.2 530 | - whltest-linux-py3.6-tf1.14.0 531 | filters: 532 | branches: 533 | only: master 534 | tags: 535 | only: /.*/ 536 | 537 | - hold: 538 | type: approval 539 | name: hold-linux-py3.6 540 | requires: 541 | - store-linux-py3.6 542 | filters: 543 | branches: 544 | ignore: /.*/ 545 | tags: 546 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 547 | 548 | - deploy: 549 | name: deploy-linux-py3.6 550 | requires: 551 | - 
hold-linux-py3.6 552 | filters: 553 | branches: 554 | ignore: /.*/ 555 | tags: 556 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 557 | 558 | macos-py3.6: 559 | jobs: 560 | - build-macos: 561 | name: build-macos-py3.6-tf1.13.1 562 | python-version: "3.6" 563 | tensorflow-version: "1.13.1" 564 | filters: 565 | branches: 566 | only: master 567 | tags: 568 | only: /.*/ 569 | - build-macos: 570 | name: build-macos-py3.6-tf1.13.2 571 | python-version: "3.6" 572 | tensorflow-version: "1.13.2" 573 | filters: 574 | branches: 575 | only: master 576 | tags: 577 | only: /.*/ 578 | - build-macos: 579 | name: build-macos-py3.6-tf1.14.0 580 | python-version: "3.6" 581 | tensorflow-version: "1.14.0" 582 | filters: 583 | branches: 584 | only: master 585 | tags: 586 | only: /.*/ 587 | 588 | - bundle-macos: 589 | name: bundle-macos-py3.6 590 | python-version: "3.6" 591 | requires: 592 | - build-macos-py3.6-tf1.13.1 593 | - build-macos-py3.6-tf1.13.2 594 | - build-macos-py3.6-tf1.14.0 595 | filters: 596 | branches: 597 | only: master 598 | tags: 599 | only: /.*/ 600 | 601 | - whltest-macos: 602 | name: whltest-macos-py3.6-tf1.13.1 603 | python-version: "3.6" 604 | tensorflow-version: "1.13.1" 605 | requires: 606 | - bundle-macos-py3.6 607 | filters: 608 | branches: 609 | only: master 610 | tags: 611 | only: /.*/ 612 | - whltest-macos: 613 | name: whltest-macos-py3.6-tf1.13.2 614 | python-version: "3.6" 615 | tensorflow-version: "1.13.2" 616 | requires: 617 | - bundle-macos-py3.6 618 | filters: 619 | branches: 620 | only: master 621 | tags: 622 | only: /.*/ 623 | - whltest-macos: 624 | name: whltest-macos-py3.6-tf1.14.0 625 | python-version: "3.6" 626 | tensorflow-version: "1.14.0" 627 | requires: 628 | - bundle-macos-py3.6 629 | filters: 630 | branches: 631 | only: master 632 | tags: 633 | only: /.*/ 634 | 635 | - store: 636 | name: store-macos-py3.6 637 | requires: 638 | - whltest-macos-py3.6-tf1.13.1 639 | - whltest-macos-py3.6-tf1.13.2 640 | - 
whltest-macos-py3.6-tf1.14.0 641 | filters: 642 | branches: 643 | only: master 644 | tags: 645 | only: /.*/ 646 | 647 | - hold: 648 | type: approval 649 | name: hold-macos-py3.6 650 | requires: 651 | - store-macos-py3.6 652 | filters: 653 | branches: 654 | ignore: /.*/ 655 | tags: 656 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 657 | 658 | - deploy: 659 | name: deploy-macos-py3.6 660 | requires: 661 | - hold-macos-py3.6 662 | filters: 663 | branches: 664 | ignore: /.*/ 665 | tags: 666 | only: /^(?:[0-9]+)\.(?:[0-9]+)\.(?:[0-9]+)(?:(\-rc[0-9]+)?)$/ 667 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **.DS_Store 2 | 3 | .vscode 4 | 5 | .bazelrc 6 | bazel-* 7 | 8 | .mypy_cache/** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .bazelrc: 2 | TF_NEED_CUDA=0 ./configure.sh 3 | 4 | clean: 5 | bazel clean 6 | rm -f .bazelrc 7 | 8 | test: .bazelrc 9 | bazel test ... --test_output=all 10 | 11 | build: .bazelrc 12 | bazel build "//tf_pjc:pjc_ops_py" 13 | 14 | .PHONY: clean test build 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TF PJC 2 | 3 | TF PJC provides a bridge between TensorFlow and Google's [Private Join and Compute](https://github.com/google/private-join-and-compute) library. This allows two parties to privately compute the intersection of two sets and the sum of associated values as described in [IKNP+'19](https://eprint.iacr.org/2019/723). 
4 | 5 | 8 | 9 | ## Usage 10 | 11 | The library may be used as shown in the following example: 12 | 13 | ```python 14 | import tensorflow as tf 15 | import tf_pjc 16 | 17 | # device strings of the two players involved 18 | client_device = "/job:localhost/task:0/device:CPU:0" 19 | server_device = "/job:localhost/task:1/device:CPU:0" 20 | 21 | # construct private input of server 22 | with tf.device(server_device): 23 | server_elements = tf.constant(["a", "b", "c"]) 24 | 25 | # construct private inputs of client 26 | with tf.device(client_device): 27 | client_elements = tf.constant(["a", "b", "c", "d"]) 28 | client_values = tf.constant([100, 200, 400, 800]) 29 | 30 | # use protocol to securely compute intersection size and sum 31 | protocol_instance = tf_pjc.PrivateIntersectionSum(client_device, server_device) 32 | client_result_op, server_wait_op = protocol_instance(client_elements, client_values, server_elements) 33 | 34 | # print private result (which is local to the client) 35 | with tf.device(client_device): 36 | intersection_size, intersection_sum = client_result_op 37 | print_size_op = tf.print("Intersection size: ", intersection_size) 38 | print_sum_op = tf.print("Intersection sum: ", intersection_sum) 39 | print_op = tf.group(print_size_op, print_sum_op) 40 | 41 | # run in TensorFlow session 42 | with tf.Session() as sess: 43 | sess.run([print_op, server_wait_op]) 44 | ``` 45 | 46 | Future releases will also include the possibility of using TF PJC in conjunction with [TF Encrypted](https://github.com/tf-encrypted/tf-encrypted) as a kernel for `tfe.sets.intersection_sum`. 
47 | 48 | ## Installation 49 | 50 | Python 3 packages are available from [PyPI](https://pypi.org/project/tf-pjc/): 51 | 52 | ``` 53 | pip install tf-pjc 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") 3 | 4 | http_archive( 5 | name = "com_google_googletest", 6 | url = "https://github.com/google/googletest/archive/release-1.8.1.zip", 7 | strip_prefix = "googletest-release-1.8.1", 8 | sha256 = "927827c183d01734cc5cfef85e0ff3f5a92ffe6188e0d18e909c5efebf28a0c7", 9 | ) 10 | 11 | # 12 | # Google Private Join and Compute (and transitive dependencies) 13 | # 14 | 15 | # git_repository( 16 | # name = "com_github_google_private_join_and_compute", 17 | # remote = "https://github.com/google/private-join-and-compute.git", 18 | # branch = "master", 19 | # ) 20 | 21 | git_repository( 22 | name = "com_github_google_private_join_and_compute", 23 | remote = "https://github.com/mortendahl/private-join-and-compute.git", 24 | branch = "wrapper", 25 | ) 26 | 27 | http_archive( 28 | name = "com_github_glog_glog", 29 | build_file = "@//third_party/glog:BUILD", 30 | urls = ["https://github.com/google/glog/archive/v0.3.5.tar.gz"], 31 | strip_prefix = "glog-0.3.5", 32 | ) 33 | 34 | # com_google_absl used by TensorFlow 35 | http_archive( 36 | name = "com_google_absl", 37 | # build_file = "//external:com_google_absl.BUILD", 38 | sha256 = "56cd3fbbbd94468a5fff58f5df2b6f9de7a0272870c61f6ca05b869934f4802a", 39 | strip_prefix = "abseil-cpp-daf381e8535a1f1f1b8a75966a74e7cca63dee89", 40 | urls = [ 41 | "http://mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/daf381e8535a1f1f1b8a75966a74e7cca63dee89.tar.gz", 42 | 
"https://github.com/abseil/abseil-cpp/archive/daf381e8535a1f1f1b8a75966a74e7cca63dee89.tar.gz", 43 | ], 44 | ) 45 | 46 | git_repository( 47 | name = "com_github_grpc_grpc", 48 | remote = "https://github.com/grpc/grpc.git", 49 | tag = "v1.19.0", 50 | ) 51 | 52 | load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") 53 | grpc_deps() 54 | 55 | # 56 | # TensorFlow 57 | # 58 | 59 | load("//third_party/tf:tf_configure.bzl", "tf_configure") 60 | tf_configure(name = "local_config_tf") 61 | -------------------------------------------------------------------------------- /configure.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | function write_to_bazelrc() { 17 | echo "$1" >> .bazelrc 18 | } 19 | 20 | function write_action_env_to_bazelrc() { 21 | write_to_bazelrc "build --action_env $1=\"$2\"" 22 | } 23 | 24 | # Remove .bazelrc if it already exist 25 | [ -e .bazelrc ] && rm .bazelrc 26 | 27 | # Check if we are building GPU or CPU ops, default CPU 28 | while [[ "$TF_NEED_CUDA" == "" ]]; do 29 | read -p "Do you want to build ops again TensorFlow CPU pip package?"\ 30 | " Y or enter for CPU (tensorflow), N for GPU (tensorflow-gpu). 
[Y/n] " INPUT 31 | case $INPUT in 32 | [Yy]* ) echo "Build with CPU pip package."; TF_NEED_CUDA=0;; 33 | [Nn]* ) echo "Build with GPU pip package."; TF_NEED_CUDA=1;; 34 | "" ) echo "Build with CPU pip package."; TF_NEED_CUDA=0;; 35 | * ) echo "Invalid selection: " $INPUT;; 36 | esac 37 | done 38 | 39 | 40 | 41 | # CPU 42 | if [[ "$TF_NEED_CUDA" == "0" ]]; then 43 | 44 | # Check if it's installed 45 | if [[ $(pip show tensorflow) == *tensorflow* ]] || [[ $(pip show tf-nightly) == *tf-nightly* ]] ; then 46 | echo 'Using installed tensorflow' 47 | else 48 | # Uninstall GPU version if it is installed. 49 | if [[ $(pip show tensorflow-gpu) == *tensorflow-gpu* ]]; then 50 | echo 'Already have gpu version of tensorflow installed. Uninstalling......\n' 51 | pip uninstall tensorflow-gpu 52 | elif [[ $(pip show tf-nightly-gpu) == *tf-nightly-gpu* ]]; then 53 | echo 'Already have gpu version of tensorflow installed. Uninstalling......\n' 54 | pip uninstall tf-nightly-gpu 55 | fi 56 | # Install CPU version 57 | echo 'Installing tensorflow......\n' 58 | pip install tensorflow==1.13.1 59 | fi 60 | 61 | else 62 | 63 | # Check if it's installed 64 | if [[ $(pip show tensorflow-gpu) == *tensorflow-gpu* ]] || [[ $(pip show tf-nightly-gpu) == *tf-nightly-gpu* ]]; then 65 | echo 'Using installed tensorflow-gpu' 66 | else 67 | # Uninstall CPU version if it is installed. 68 | if [[ $(pip show tensorflow) == *tensorflow* ]]; then 69 | echo 'Already have tensorflow non-gpu installed. Uninstalling......\n' 70 | pip uninstall tensorflow 71 | elif [[ $(pip show tf-nightly) == *tf-nightly* ]]; then 72 | echo 'Already have tensorflow non-gpu installed. 
Uninstalling......\n' 73 | pip uninstall tf-nightly 74 | fi 75 | # Install CPU version 76 | echo 'Installing tensorflow-gpu .....\n' 77 | pip install tensorflow-gpu 78 | fi 79 | fi 80 | 81 | 82 | TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 83 | TF_LFLAGS="$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')" 84 | 85 | write_to_bazelrc "build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true" 86 | write_to_bazelrc "build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain" 87 | write_to_bazelrc "build --spawn_strategy=standalone" 88 | write_to_bazelrc "build --strategy=Genrule=standalone" 89 | write_to_bazelrc "build -c opt" 90 | 91 | 92 | write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2} 93 | SHARED_LIBRARY_DIR=${TF_LFLAGS:2} 94 | SHARED_LIBRARY_NAME=$(echo $TF_LFLAGS | rev | cut -d":" -f1 | rev) 95 | if ! [[ $TF_LFLAGS =~ .*:.* ]]; then 96 | if [[ "$(uname)" == "Darwin" ]]; then 97 | SHARED_LIBRARY_NAME="libtensorflow_framework.so" 98 | else 99 | SHARED_LIBRARY_NAME="libtensorflow_framework.so" 100 | fi 101 | fi 102 | write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${SHARED_LIBRARY_DIR} 103 | write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${SHARED_LIBRARY_NAME} 104 | write_action_env_to_bazelrc "TF_NEED_CUDA" ${TF_NEED_CUDA} 105 | 106 | # TODO(yifeif): do not hardcode path 107 | if [[ "$TF_NEED_CUDA" == "1" ]]; then 108 | write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "/usr/lib/x86_64-linux-gnu" 109 | write_action_env_to_bazelrc "TF_CUDA_VERSION" "10.0" 110 | write_action_env_to_bazelrc "TF_CUDNN_VERSION" "7" 111 | write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "/usr/local/cuda" 112 | write_to_bazelrc "build --config=cuda" 113 | write_to_bazelrc "test --config=cuda" 114 | fi 115 | -------------------------------------------------------------------------------- /examples/simple.py: 
"""Minimal example: two parties privately compute intersection size and sum."""
import tensorflow as tf
import tf_pjc

# device strings of the two players involved
client_device = "/job:localhost/replica:0/task:0/device:CPU:0"
server_device = "/job:localhost/replica:0/task:0/device:CPU:0"

# each party constructs its private inputs on its own device
with tf.device(server_device):
    server_elements = tf.constant(["a", "b", "c"])

with tf.device(client_device):
    client_elements = tf.constant(["a", "b", "c", "d"])
    client_values = tf.constant([10, 20, 40, 80])

# build the protocol graph; the client result op yields (size, sum)
protocol = tf_pjc.PrivateIntersectionSum(client_device, server_device)
client_result_op, server_wait_op = protocol(
    client_elements, client_values, server_elements)

# the result is private to the client, so print it on the client device
with tf.device(client_device):
    intersection_size, intersection_sum = client_result_op
    print_op = tf.group(
        tf.print("Intersection size: ", intersection_size),
        tf.print("Intersection sum: ", intersection_sum),
    )

# evaluate; also run the wait op so the server shuts down cleanly
with tf.Session() as sess:
    sess.run([print_op, server_wait_op])
"""Installing with setuptools."""
import setuptools

from setuptools.dist import Distribution


class BinaryDistribution(Distribution):
    """This class is needed in order to create OS specific wheels."""

    def has_ext_modules(self):
        # Report that we ship native code so pip builds platform-specific
        # wheels, even though the .so is bundled as package data.
        return True


with open("README.md", "r") as readme:
    LONG_DESCRIPTION = readme.read()

setuptools.setup(
    name="tf-pjc",
    version="0.1.0",
    packages=setuptools.find_packages(),
    package_data={'tf_pjc': []},
    python_requires="==3.7.*",
    install_requires=[],
    extras_require={},
    license="Apache License 2.0",
    url="https://github.com/tf-encrypted/tf-pjc",
    description="Bridge between TensorFlow and Google's Private Join and Compute library",
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    author="The TF Encrypted Authors",
    author_email="contact@tf-encrypted.io",
    include_package_data=True,
    zip_safe=False,
    distclass=BinaryDistribution,
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Development Status :: 2 - Pre-Alpha",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Security :: Cryptography",
    ]
)
py_library( 25 | name = "pjc_ops_py", 26 | srcs = ([ 27 | "python/ops.py", 28 | "python/protocol.py", 29 | ]), 30 | data = [ 31 | ":python/_pjc_ops.so" 32 | ], 33 | srcs_version = "PY3", 34 | ) 35 | 36 | py_test( 37 | name = "pjc_ops_py_test", 38 | srcs = [ 39 | "python/ops_test.py", 40 | ], 41 | main = "python/ops_test.py", 42 | deps = [ 43 | ":pjc_ops_py", 44 | ], 45 | srcs_version = "PY3", 46 | ) 47 | 48 | py_test( 49 | name = "pjc_protocol_py_test", 50 | srcs = [ 51 | "python/protocol_test.py", 52 | ], 53 | main = "python/protocol_test.py", 54 | deps = [ 55 | ":pjc_ops_py", 56 | ], 57 | srcs_version = "PY3", 58 | ) 59 | -------------------------------------------------------------------------------- /tf_pjc/__init__.py: -------------------------------------------------------------------------------- 1 | from tf_pjc.python.protocol import PrivateIntersectionSum 2 | 3 | __all__ = [ 4 | 'PrivateIntersectionSum', 5 | ] 6 | -------------------------------------------------------------------------------- /tf_pjc/cc/kernels.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "tensorflow/core/framework/op.h" 6 | #include "tensorflow/core/framework/op_kernel.h" 7 | #include "tensorflow/core/framework/resource_mgr.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/variant.h" 10 | #include "tensorflow/core/framework/variant_encode_decode.h" 11 | #include "tensorflow/core/framework/variant_op_registry.h" 12 | #include "tensorflow/core/framework/variant_tensor_data.h" 13 | #include "tensorflow/core/lib/core/errors.h" 14 | 15 | #include "private_join_and_compute.h" 16 | 17 | using namespace tensorflow; // NOLINT 18 | 19 | template 20 | std::string convert_to_string(E x) { 21 | return std::to_string(x); 22 | } 23 | 24 | template <> 25 | std::string convert_to_string(std::string x) { 26 | return x; 27 | } 28 | 29 | template 30 | class 
PjcRunClientOp : public OpKernel { 31 | public: 32 | explicit PjcRunClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 33 | 34 | void Compute(OpKernelContext* ctx) override { 35 | // TODO(Morten) avoid these extra copies; pass iterator directly into PJC? 36 | 37 | // read elements input 38 | const Tensor& elements_tensor = ctx->input(0); 39 | auto elements_flat = elements_tensor.flat(); 40 | std::vector elements; 41 | elements.reserve(elements_flat.size()); 42 | for (int i = 0; i < elements_flat.size(); i++) { 43 | std::string element = convert_to_string(elements_flat.data()[i]); 44 | elements.push_back(element); 45 | } 46 | 47 | // read values input 48 | const Tensor& values_tensor = ctx->input(1); 49 | auto values_flat = values_tensor.flat(); 50 | std::vector values; 51 | values.reserve(values_flat.size()); 52 | for (int i = 0; i < values_flat.size(); i++) { 53 | int64_t value = static_cast(values_flat.data()[i]); // TODO(Morten) explicit cast 54 | values.push_back(value); 55 | } 56 | 57 | // run 58 | ::private_join_and_compute::ClientResult result; 59 | ::private_join_and_compute::ClientSession session; 60 | int res = session.Run(1536, 61 | "0.0.0.0:10501", 62 | std::move(elements), 63 | std::move(values), 64 | &result); 65 | OP_REQUIRES(ctx, res == 0, errors::Unknown("Session failed")); 66 | 67 | // write size output 68 | Tensor* size_tensor = nullptr; 69 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &size_tensor)); 70 | size_tensor->scalar()() = result.intersection_size; 71 | 72 | // write sum output 73 | Tensor* sum_tensor = nullptr; 74 | OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape{}, &sum_tensor)); 75 | sum_tensor->scalar()() = static_cast(result.intersection_sum); // TODO(Morten) check bounds? 
76 | } 77 | }; 78 | 79 | struct ServerResource : ResourceBase { 80 | ::private_join_and_compute::ServerSession session; 81 | 82 | std::string DebugString() const override { 83 | return ""; 84 | } 85 | }; 86 | 87 | template 88 | class PjcRunServerOp : public OpKernel { 89 | public: 90 | explicit PjcRunServerOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 91 | 92 | void Compute(OpKernelContext* ctx) override { 93 | auto resource_mgr = ctx->resource_manager(); 94 | 95 | // read input 96 | const Tensor& elements_tensor = ctx->input(0); 97 | auto elements_flat = elements_tensor.flat(); 98 | std::vector elements; 99 | elements.reserve(elements_flat.size()); 100 | for (int i = 0; i < elements_flat.size(); i++) { 101 | std::string element = convert_to_string(elements_flat.data()[i]); 102 | elements.push_back(element); 103 | } 104 | 105 | // run 106 | ServerResource* server_resource = new ServerResource; 107 | OP_REQUIRES_OK(ctx, resource_mgr->Create("pjc", "server", server_resource)); 108 | server_resource->session.Run("0.0.0.0:10501", std::move(elements)); 109 | 110 | // cleanup 111 | OP_REQUIRES_OK(ctx, resource_mgr->Delete("pjc", "server")); 112 | server_resource = nullptr; 113 | } 114 | }; 115 | 116 | template 117 | class PjcRunAsyncServerOp : public OpKernel { 118 | public: 119 | explicit PjcRunAsyncServerOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 120 | 121 | void Compute(OpKernelContext* ctx) override { 122 | auto resource_mgr = ctx->resource_manager(); 123 | 124 | // read input 125 | const Tensor& elements_tensor = ctx->input(0); 126 | auto elements_flat = elements_tensor.flat(); 127 | std::vector elements; 128 | elements.reserve(elements_flat.size()); 129 | for (int i = 0; i < elements_flat.size(); i++) { 130 | std::string element = convert_to_string(elements_flat.data()[i]); 131 | elements.push_back(element); 132 | } 133 | 134 | // run 135 | ServerResource* server_resource = new ServerResource; 136 | OP_REQUIRES_OK(ctx, resource_mgr->Create("pjc", "server", 
server_resource)); 137 | server_resource->session.RunAsync("0.0.0.0:10501", std::move(elements)); 138 | } 139 | }; 140 | 141 | class PjcWaitServerOp : public OpKernel { 142 | public: 143 | explicit PjcWaitServerOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 144 | 145 | void Compute(OpKernelContext* ctx) override { 146 | auto resource_mgr = ctx->resource_manager(); 147 | 148 | // wait until server is shut down 149 | ServerResource* server_resource = nullptr; 150 | OP_REQUIRES_OK(ctx, resource_mgr->Lookup("pjc", "server", &server_resource)); 151 | server_resource->session.Wait(); 152 | 153 | // cleanup 154 | OP_REQUIRES_OK(ctx, resource_mgr->Delete("pjc", "server")); 155 | server_resource = nullptr; 156 | } 157 | }; 158 | 159 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 160 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 161 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 162 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 163 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 164 | REGISTER_KERNEL_BUILDER(Name("PjcRunClient").Device(DEVICE_CPU).TypeConstraint("elements_dtype").TypeConstraint("values_dtype"), PjcRunClientOp); 165 | 166 | REGISTER_KERNEL_BUILDER(Name("PjcRunServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), PjcRunServerOp); 167 | REGISTER_KERNEL_BUILDER(Name("PjcRunServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), PjcRunServerOp); 168 | REGISTER_KERNEL_BUILDER(Name("PjcRunServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), 
PjcRunServerOp); 169 | 170 | REGISTER_KERNEL_BUILDER(Name("PjcRunAsyncServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), PjcRunAsyncServerOp); 171 | REGISTER_KERNEL_BUILDER(Name("PjcRunAsyncServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), PjcRunAsyncServerOp); 172 | REGISTER_KERNEL_BUILDER(Name("PjcRunAsyncServer").Device(DEVICE_CPU).TypeConstraint("elements_dtype"), PjcRunAsyncServerOp); 173 | 174 | REGISTER_KERNEL_BUILDER(Name("PjcWaitServer").Device(DEVICE_CPU), PjcWaitServerOp); 175 | -------------------------------------------------------------------------------- /tf_pjc/cc/ops.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | 4 | REGISTER_OP("PjcRunClient") 5 | .Attr("elements_dtype: {string, int64, int32}") 6 | .Attr("values_dtype: {int64, int32}") 7 | .Input("elements: elements_dtype") 8 | .Input("values: values_dtype") 9 | .Output("size: int64") 10 | .Output("sum: int64") 11 | .SetIsStateful() 12 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 13 | ::tensorflow::shape_inference::ShapeHandle elements = c->input(0); 14 | ::tensorflow::shape_inference::ShapeHandle values = c->input(1); 15 | ::tensorflow::shape_inference::ShapeHandle unused; 16 | TF_RETURN_IF_ERROR(c->Merge(elements, values, &unused)); 17 | 18 | ::tensorflow::shape_inference::ShapeHandle size = c->Scalar(); 19 | ::tensorflow::shape_inference::ShapeHandle sum = c->Scalar(); 20 | c->set_output(0, size); 21 | c->set_output(1, sum); 22 | 23 | return ::tensorflow::Status::OK(); 24 | }); 25 | 26 | REGISTER_OP("PjcRunServer") 27 | .Attr("elements_dtype: {string, int64, int32}") 28 | .Input("elements: elements_dtype") 29 | .SetIsStateful(); 30 | 31 | REGISTER_OP("PjcRunAsyncServer") 32 | .Attr("elements_dtype: {string, int64, int32}") 33 | .Input("elements: elements_dtype") 34 | .SetIsStateful(); 35 | 36 | 
"""Loads the compiled PJC custom ops and re-exports them as Python callables."""
import tensorflow
from tensorflow.python.framework import load_library
from tensorflow.python.framework.errors import NotFoundError
from tensorflow.python.platform import resource_loader


# Resolve the shared library relative to this file, then let TensorFlow load
# it; the registered ops become attributes of the returned module object.
# NOTE(review): `tensorflow` and `NotFoundError` appear unused here --
# presumably the bare `import tensorflow` forces libtensorflow_framework to
# be loaded before the custom .so; confirm before removing either import.
op_lib_file = resource_loader.get_path_to_datafile('_pjc_ops.so')
pjc_ops = load_library.load_op_library(op_lib_file)

# re-export under the snake_case names callers use
pjc_run_client = pjc_ops.pjc_run_client
pjc_run_server = pjc_ops.pjc_run_server
pjc_run_async_server = pjc_ops.pjc_run_async_server
pjc_wait_server = pjc_ops.pjc_wait_server
class PrivateIntersectionSum:
    """Two-party private intersection-sum protocol (PJC).

    The client holds a set of elements with one integer value per element;
    the server holds its own set of elements. Running the protocol reveals,
    to the client only, the number of common elements and the sum of the
    client's values over those common elements.
    """

    def __init__(self, client_device, server_device):
        # TensorFlow device strings identifying where each party's ops run.
        self.client_device = client_device
        self.server_device = server_device

    def __call__(self, client_elements, client_values, server_elements,
                 wait_for_server=True):
        """Builds the graph ops that run the protocol.

        Args:
          client_elements: `tf.Tensor` with the client's elements.
          client_values: `tf.Tensor` of values, same shape as
            `client_elements`.
          server_elements: `tf.Tensor` with the server's elements.
          wait_for_server: if True (default), the second returned op blocks
            until the server has shut down; otherwise it is the op that
            merely launches the server.

        Returns:
          A pair `(client_run_op, server_op)` where running `client_run_op`
          yields `(intersection_size, intersection_sum)` on the client
          device.

        Raises:
          TypeError: if any input is not a `tf.Tensor`.
          ValueError: if element and value shapes differ.
        """
        # Explicit validation instead of `assert`: asserts are silently
        # stripped when Python runs with -O, which would skip these checks.
        for name, tensor in (("client_elements", client_elements),
                             ("client_values", client_values),
                             ("server_elements", server_elements)):
            if not isinstance(tensor, tf.Tensor):
                raise TypeError("{} must be a tf.Tensor, got {}".format(
                    name, type(tensor).__name__))
        if client_elements.shape != client_values.shape:
            raise ValueError(
                "client_elements and client_values must have the same "
                "shape: {} vs {}".format(client_elements.shape,
                                         client_values.shape))

        # launch server asynchronously; the wait op gives callers a handle
        # to block until the server has finished serving
        with tf.device(self.server_device):
            server_run_async_op = pjc_run_async_server(elements=server_elements)

            with tf.control_dependencies([server_run_async_op]):
                server_wait_op = pjc_wait_server()

        # once server is launched we can run the client,
        # using TensorFlow for the synchronization
        with tf.device(self.client_device):
            with tf.control_dependencies([server_run_async_op]):
                client_run_op = pjc_run_client(elements=client_elements,
                                               values=client_values)

        if not wait_for_server:
            return client_run_op, server_run_async_op

        return client_run_op, server_wait_op
intersection_sum) 32 | print_op = tf.group(print_size_op, print_sum_op) 33 | 34 | # run in TensorFlow session 35 | with tf.Session() as sess: 36 | sess.run([print_op, server_wait_op]) 37 | 38 | # def test_stress(self): 39 | 40 | # # players 41 | # client_device = "/job:localhost/replica:0/task:0/device:CPU:0" 42 | # server_device = "/job:localhost/replica:0/task:0/device:CPU:0" 43 | 44 | # N = 1000 45 | 46 | # # inputs 47 | # server_elements = tf.range(N) 48 | # client_elements = tf.range(N) 49 | # client_values = tf.range(N) 50 | 51 | # # protocol 52 | # protocol = PrivateIntersectionSum(client_device, server_device) 53 | # client_op, server_op = protocol(client_elements, client_values, server_elements) 54 | 55 | # # run 56 | # with tf.Session() as sess: 57 | # (intersection_size, intersection_sum), _ = sess.run([client_op, server_op]) 58 | 59 | # assert intersection_size == N, (intersection_size, intersection_sum) 60 | # assert intersection_sum == (N * (N - 1)) // 2, intersection_sum 61 | 62 | if __name__ == '__main__': 63 | test.main() 64 | -------------------------------------------------------------------------------- /third_party/glog/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ============================================================================
# third_party/glog/BUILD
# (Reconstructed from the extraction dump. File lines 1-13 — presumably a
# license header — fall outside this chunk and are not reproduced here.)
# ============================================================================

package(default_visibility = ["//visibility:public"])

# Build glog straight from its sources. The generated headers listed in hdrs
# are produced by the :run_configure genrule further down.
cc_library(
    name = "glog",
    srcs = [
        "config.h",
        "src/base/commandlineflags.h",
        "src/base/googleinit.h",
        "src/base/mutex.h",
        "src/demangle.cc",
        "src/demangle.h",
        "src/logging.cc",
        "src/raw_logging.cc",
        "src/signalhandler.cc",
        "src/symbolize.cc",
        "src/symbolize.h",
        "src/utilities.cc",
        "src/utilities.h",
        "src/vlog_is_on.cc",
    ] + glob(["src/stacktrace*.h"]),
    hdrs = [
        "src/glog/log_severity.h",
        "src/glog/logging.h",
        "src/glog/raw_logging.h",
        "src/glog/stl_logging.h",
        "src/glog/vlog_is_on.h",
    ],
    copts = [
        "-Wno-sign-compare",
        "-U_XOPEN_SOURCE",
    ],
    includes = ["./src"],
    # libunwind is linked only when the build sets --define libunwind=true
    # (see the config_setting below).
    linkopts = ["-lpthread"] + select({
        ":libunwind": ["-lunwind"],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        "@com_github_gflags_gflags//:gflags",
    ],
)

# Flag gate for the optional libunwind dependency above.
config_setting(
    name = "libunwind",
    values = {
        "define": "libunwind=true",
    },
)

# Run glog's autoconf configure script and copy the generated config.h and
# public headers into the locations Bazel expects as declared outputs.
genrule(
    name = "run_configure",
    srcs = [
        "README",
        "Makefile.in",
        "config.guess",
        "config.sub",
        "install-sh",
        "ltmain.sh",
        "missing",
        "libglog.pc.in",
        "src/config.h.in",
        "src/glog/logging.h.in",
        "src/glog/raw_logging.h.in",
        "src/glog/stl_logging.h.in",
        "src/glog/vlog_is_on.h.in",
    ],
    outs = [
        "config.h",
        "src/glog/logging.h",
        "src/glog/raw_logging.h",
        "src/glog/stl_logging.h",
        "src/glog/vlog_is_on.h",
    ],
    cmd = "$(location :configure)" +
          "&& cp -v src/config.h $(location config.h) " +
          "&& cp -v src/glog/logging.h $(location src/glog/logging.h) " +
          "&& cp -v src/glog/raw_logging.h $(location src/glog/raw_logging.h) " +
          "&& cp -v src/glog/stl_logging.h $(location src/glog/stl_logging.h) " +
          "&& cp -v src/glog/vlog_is_on.h $(location src/glog/vlog_is_on.h) ",
    tools = [
        "configure",
    ],
)

# ============================================================================
# third_party/tf/BUILD — in the dump this entry is only a pointer URL:
# https://raw.githubusercontent.com/tf-encrypted/tf-pjc/72cf1767c3426b76a27f5bf72363ed295734ce9c/third_party/tf/BUILD
# ============================================================================

# ============================================================================
# third_party/tf/BUILD.tpl — BUILD template expanded by tf_configure (below);
# the %{...} placeholders are substituted with generated genrule text.
# ============================================================================

package(default_visibility = ["//visibility:public"])

# Headers of the pip-installed TensorFlow, symlinked in by the
# tf_header_include genrule substituted below.
cc_library(
    name = "tf_header_lib",
    hdrs = [":tf_header_include"],
    includes = ["include"],
    visibility = ["//visibility:public"],
)

# TensorFlow's framework shared library, also materialized by a generated
# genrule.
cc_library(
    name = "libtensorflow_framework",
    srcs = [":libtensorflow_framework.so"],
    #data = ["lib/libtensorflow_framework.so"],
    visibility = ["//visibility:public"],
)

%{TF_HEADER_GENRULE}
%{TF_SHARED_LIBRARY_GENRULE}

# ============================================================================
# third_party/tf/tf_configure.bzl
# ============================================================================

"""Setup TensorFlow as external dependency"""

# Names of the environment variables (exported by configure.sh) that point at
# the pip-installed TensorFlow's headers and shared library.
_TF_HEADER_DIR = "TF_HEADER_DIR"
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expand //third_party/tf:<tpl>.tpl into the repository.

    Writes to *out* when given, otherwise to a file named after the template.
    """
    repository_ctx.template(
        out if out else tpl,
        Label("//third_party/tf:%s.tpl" % tpl),
        substitutions,
    )

def _fail(msg):
    """Abort configuration, printing *msg* highlighted in red."""
    fail("%sPython Configuration Error:%s %s\n" % ("\033[0;31m", "\033[0m", msg))

def _is_windows(repository_ctx):
    """Return True when the host operating system is Windows."""
    return repository_ctx.os.name.lower().find("windows") != -1

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Execute *cmdline* and fail configuration unless it succeeds cleanly.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True an empty stdout result is fine,
        otherwise it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)

    # Success = nothing on stderr AND (stdout allowed to be empty, or stdout
    # non-empty). Note any stderr output — including warnings — is an error.
    succeeded = not result.stderr and (empty_stdout_fine or result.stdout)
    if not succeeded:
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            result.stderr.strip(),
            error_details if error_details else "",
        ]))
    return result

def _read_dir(repository_ctx, src_dir):
    """List every file under *src_dir*, one full path per line.

    Traverses subfolders and follows symlinks; returns a newline-separated
    string of paths (forward slashes on every platform).
    """
    if _is_windows(repository_ctx):
        src_dir = src_dir.replace("/", "\\")
        listing = _execute(
            repository_ctx,
            ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
            empty_stdout_fine = True,
        )

        # Genrule outs require forward slashes.
        return listing.stdout.replace("\\", "/")

    listing = _execute(
        repository_ctx,
        ["find", src_dir, "-follow", "-type", "f"],
        empty_stdout_fine = True,
    )
    return listing.stdout

def _genrule(genrule_name, command, outs):
    """Render the text of a genrule target running *command* with *outs*.

    Args:
      genrule_name: A unique name for the genrule target.
      command: The command to run.
      outs: A list of files generated by this rule, pre-rendered as text.

    Returns:
      A string containing the genrule target definition.
    """
    return (
        "genrule(\n" +
        '    name = "' + genrule_name + '",\n' +
        "    outs = [\n" +
        outs +
        "\n    ],\n" +
        '    cmd = """\n' +
        command +
        '\n    """,\n' +
        ")\n"
    )

def _norm_path(path):
    """Return *path* with forward slashes and no trailing slash."""
    path = path.replace("\\", "/")
    return path[:-1] if path[-1] == "/" else path

def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = []):
    """Build a genrule that copies a set of files into the repository.

    When *src_dir* is given, the file list is read from that directory;
    otherwise the explicit *src_files*/*dest_files* pairs are used.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory, or None.
      dest_dir: directory to create the copies in.
      genrule_name: genrule name.
      src_files: list of source files instead of src_dir.
      dest_files: list of corresponding destination files.

    Returns:
      The text of a genrule target that creates the copies.
    """
    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))

        # Derive destination paths by stripping the src_dir prefix.
        dest_files = files.replace(src_dir, "").splitlines()
        src_files = files.splitlines()

    copy_cmds = []
    out_lines = []
    single_file = len(dest_files) == 1
    for idx in range(len(dest_files)):
        if dest_files[idx] == "":
            continue

        # For a single file skip dest_dir: $(@D) already expands to the full
        # path of the one output file's directory.
        if single_file:
            target = "$(@D)/" + dest_files[idx]
        else:
            target = "$(@D)/" + dest_dir + dest_files[idx]

        # Copy (rather than symlink) so the setup stays sandboxable.
        copy_cmds.append('cp -f "%s" "%s"' % (src_files[idx], target))
        out_lines.append('        "' + dest_dir + dest_files[idx] + '",')

    return _genrule(
        genrule_name,
        " && ".join(copy_cmds),
        "\n".join(out_lines),
    )

def _tf_pip_impl(repository_ctx):
    """Repository rule implementation: wire up the pip TensorFlow install."""
    env = repository_ctx.os.environ

    # Genrule that copies all TF headers under include/.
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        env[_TF_HEADER_DIR],
        "include",
        "tf_header_include",
    )

    # Genrule that copies the single framework shared library.
    lib_dir = env[_TF_SHARED_LIBRARY_DIR]
    lib_name = env[_TF_SHARED_LIBRARY_NAME]
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        ["%s/%s" % (lib_dir, lib_name)],
        [lib_name],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })

# Repository rule that materializes headers and the shared library from the
# pip-installed TensorFlow, re-run whenever one of the listed env vars changes.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
)