├── jest-globals.js
├── .prettierignore
├── .eslintignore
├── .prettierrc
├── .gitignore
├── art
│   ├── fake-logo.png
│   └── mnist-demo-ani.gif
├── documentation.yml
├── src
│   ├── index.js
│   ├── events.js
│   ├── types
│   │   ├── pointer-tensor.js
│   │   ├── protocol.js
│   │   ├── state.js
│   │   ├── plan.js
│   │   ├── message.js
│   │   ├── placeholder.js
│   │   ├── torch.js
│   │   ├── role.js
│   │   └── computation-action.js
│   ├── logger.js
│   ├── object-registry.js
│   ├── data-channel-message-queue.js
│   ├── _errors.js
│   ├── _constants.js
│   ├── protobuf
│   │   ├── mapping.js
│   │   └── index.js
│   ├── syft-model.js
│   ├── sockets.js
│   ├── syft.js
│   ├── syft-webrtc.js
│   ├── data-channel-message.js
│   ├── speed-test.js
│   ├── job.js
│   └── grid-api-client.js
├── .npmignore
├── copy-examples.sh
├── .babelrc
├── .eslintrc.json
├── test
│   ├── types
│   │   ├── pointer-tensor.test.js
│   │   ├── placeholder.test.js
│   │   ├── protocol.test.js
│   │   ├── message.test.js
│   │   ├── torch.test.js
│   │   └── plan.test.js
│   ├── data
│   │   ├── dummy.tpl.js
│   │   └── generate-data.py
│   ├── events.test.js
│   ├── mocks
│   │   ├── webrtc.js
│   │   └── grid.js
│   ├── data_channel_message_queue.test.js
│   ├── logger.test.js
│   ├── data_channel_message.test.js
│   ├── sockets.test.js
│   └── protobuf.test.js
├── .github
│   └── workflows
│       └── run-tests.yml
├── examples
│   ├── mnist
│   │   ├── package.json
│   │   ├── README.md
│   │   ├── webpack.config.js
│   │   ├── index.html
│   │   ├── mnist.js
│   │   └── index.js
│   └── multi-armed-bandit
│       ├── index.html
│       ├── package.json
│       ├── webpack.config.js
│       ├── README.md
│       ├── index.js
│       ├── app.js
│       ├── logo-white.svg
│       └── logo-color.svg
├── rollup.config.js
├── .all-contributorsrc
├── package.json
├── CHANGELOG.md
├── API-REFERENCE.md
└── LICENSE
/jest-globals.js:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | API-REFERENCE.md
--------------------------------------------------------------------------------
/.eslintignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | dist
3 | examples
--------------------------------------------------------------------------------
/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "singleQuote": true
3 | }
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | coverage
3 | dist
4 | tmp
5 | yarn.lock
6 | .DS_Store
7 | .idea
--------------------------------------------------------------------------------
/art/fake-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shashigharti/syft.js/master/art/fake-logo.png
--------------------------------------------------------------------------------
/art/mnist-demo-ani.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shashigharti/syft.js/master/art/mnist-demo-ani.gif
--------------------------------------------------------------------------------
/documentation.yml:
--------------------------------------------------------------------------------
1 | toc:
2 | - Syft
3 | - Job
4 | - Job#accepted
5 | - Job#rejected
6 | - Job#error
7 | - SyftModel
8 | - Plan
9 |
--------------------------------------------------------------------------------
/src/index.js:
--------------------------------------------------------------------------------
1 | // Export all constants
2 | export * from './_constants';
3 |
4 | // Export the main class as default AND as named
5 | export { default as Syft } from './syft';
6 | export { default } from './syft';
7 |
--------------------------------------------------------------------------------
/.npmignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | coverage
3 | examples
4 | test
5 | tmp
6 | yarn.lock
7 | .DS_Store
8 | .idea
9 | .babelrc
10 | .prettierrc
11 | .travis.yml
12 | copy-examples.sh
13 | jest-globals.js
14 | rollup.config.js
--------------------------------------------------------------------------------
/copy-examples.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | rm -rf ./tmp
3 | mkdir tmp
4 |
5 | for directory in $(find ./examples -mindepth 1 -maxdepth 1 -type d);
6 | do
7 | directory="$directory/dist/"
8 |   name="$(cut -d'/' -f3 <<<"$directory")"
9 |
10 | rsync -avzh "$directory" "./tmp/$name"
11 | done
--------------------------------------------------------------------------------
/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": ["@babel/preset-env"],
3 | "plugins": [
4 | "@babel/plugin-proposal-class-properties"
5 | ],
6 | "env": {
7 | "test": {
8 | "plugins": [
9 | "@babel/plugin-transform-runtime",
10 | "@babel/plugin-proposal-class-properties"
11 | ]
12 | }
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/.eslintrc.json:
--------------------------------------------------------------------------------
1 | {
2 | "parser": "babel-eslint",
3 | "env": {
4 | "browser": true,
5 | "es6": true,
6 | "node": true,
7 | "jest": true
8 | },
9 | "extends": ["eslint:recommended"],
10 | "parserOptions": {
11 | "ecmaVersion": 2018,
12 | "sourceType": "module"
13 | },
14 | "rules": {
15 | "strict": ["error", "never"]
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/events.js:
--------------------------------------------------------------------------------
1 | export default class EventObserver {
2 | constructor() {
3 | this.observers = [];
4 | }
5 |
6 | subscribe(type, func) {
7 | this.observers.push({ type, func });
8 | }
9 |
10 | unsubscribe(eventType) {
11 | this.observers = this.observers.filter(({ type }) => eventType !== type);
12 | }
13 |
14 | broadcast(eventType, data) {
15 | this.observers.forEach(observer => {
16 | if (eventType === observer.type) {
17 | return observer.func(data);
18 | }
19 | });
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/src/types/pointer-tensor.js:
--------------------------------------------------------------------------------
1 | // Create a class to represent pointer tensors
2 | // Add all the attributes that are serialized, just as for range and slice
3 |
4 | export default class PointerTensor {
5 | constructor(
6 | id,
7 | idAtLocation,
8 | locationId,
9 | pointToAttr,
10 | shape,
11 | garbageCollectData
12 | ) {
13 | this.id = id;
14 | this.idAtLocation = idAtLocation;
15 | this.locationId = locationId;
16 | this.pointToAttr = pointToAttr;
17 | this.shape = shape;
18 | this.garbageCollectData = garbageCollectData;
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/test/types/pointer-tensor.test.js:
--------------------------------------------------------------------------------
1 | import PointerTensor from '../../src/types/pointer-tensor';
2 |
3 | describe('PointerTensor', () => {
4 | test('can be properly constructed', () => {
5 | const obj = new PointerTensor(123, 444, 'worker1', null, [2, 3], true);
6 | expect(obj.id).toStrictEqual(123);
7 | expect(obj.idAtLocation).toStrictEqual(444);
8 | expect(obj.locationId).toStrictEqual('worker1');
9 | expect(obj.pointToAttr).toStrictEqual(null);
10 | expect(obj.shape).toStrictEqual([2, 3]);
11 | expect(obj.garbageCollectData).toStrictEqual(true);
12 | });
13 | });
14 |
--------------------------------------------------------------------------------
/src/logger.js:
--------------------------------------------------------------------------------
1 | // A simple logging function
2 | export default class Logger {
3 | constructor(system, verbose) {
4 | if (!Logger.instance) {
5 | this.system = system;
6 | this.verbose = verbose;
7 | Logger.instance = this;
8 | }
9 | return Logger.instance;
10 | }
11 |
12 | log(message, data) {
13 | // Only log if verbose is turned on
14 | if (this.verbose) {
15 | const output = `${Date.now()}: ${this.system} - ${message}`;
16 |
17 |       // Has the caller passed additional data?
18 | if (data) {
19 | console.log(output, data);
20 | } else {
21 | console.log(output);
22 | }
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/test/types/placeholder.test.js:
--------------------------------------------------------------------------------
1 | import { Placeholder } from '../../src/types/placeholder';
2 |
3 | describe('Placeholder', () => {
4 | test('can be properly constructed', () => {
5 | const obj = new Placeholder(123, ['tag1', 'tag2'], 'desc');
6 | expect(obj.id).toStrictEqual(123);
7 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']);
8 | expect(obj.description).toStrictEqual('desc');
9 | });
10 |
11 | test('can get placeholder order', () => {
12 | const obj = new Placeholder(123, ['#input-1'], 'desc');
13 | expect(obj.getOrderFromTags('#input')).toStrictEqual(1);
14 | expect(() => obj.getOrderFromTags('#output')).toThrow();
15 | });
16 | });
17 |
--------------------------------------------------------------------------------
/test/types/protocol.test.js:
--------------------------------------------------------------------------------
1 | import Protocol from '../../src/types/protocol';
2 |
3 | describe('Protocol', () => {
4 | test('can be properly constructed', () => {
5 | const planAssignments = [
6 | ['worker1', '1234'],
7 | ['worker2', '3456']
8 | ];
9 | const obj = new Protocol(
10 | 123,
11 | ['tag1', 'tag2'],
12 | 'desc',
13 | planAssignments,
14 | true
15 | );
16 | expect(obj.id).toStrictEqual(123);
17 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']);
18 | expect(obj.description).toStrictEqual('desc');
19 | expect(obj.plans).toStrictEqual(planAssignments);
20 | expect(obj.workersResolved).toStrictEqual(true);
21 | });
22 | });
23 |
--------------------------------------------------------------------------------
/.github/workflows/run-tests.yml:
--------------------------------------------------------------------------------
1 | name: Run tests and coverage
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | types: [opened, synchronize, reopened]
9 |
10 | jobs:
11 |   build:
12 | 
13 |     runs-on: ubuntu-latest
14 |     strategy:
15 |       matrix:
16 |         node-version: [12.x] # assumed Node version; the steps below reference this matrix
17 | 
18 |     steps:
19 |     - uses: actions/checkout@v2
20 |     - name: Use Node.js ${{ matrix.node-version }}
21 |       uses: actions/setup-node@v1
22 |       with:
23 |         node-version: ${{ matrix.node-version }}
24 |     - name: Install dependencies
25 |       run: |
26 |         npm install
27 |         sudo npm install -g codecov
28 |     - name: Test with npm
29 |       run: |
30 |         npm run test
31 |     - name: Test code coverage
32 |       run: |
33 |         codecov
34 | 
--------------------------------------------------------------------------------
/examples/mnist/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "syftjs-mnist",
3 | "version": "1.0.0",
4 | "private": true,
5 | "description": "",
6 | "main": "index.js",
7 | "scripts": {
8 | "start": "webpack-dev-server --mode development",
9 | "build": "rm -rf dist && webpack --mode production",
10 | "test": "echo \"Error: no test specified\" && exit 1"
11 | },
12 | "author": "",
13 | "license": "",
14 | "dependencies": {
15 | "@openmined/syft.js": "github:openmined/syft.js"
16 | },
17 | "devDependencies": {
18 | "babel-loader": "^8.0.6",
19 | "html-webpack-plugin": "^3.2.0",
20 | "path": "^0.12.7",
21 | "regenerator-runtime": "^0.13.3",
22 | "webpack": "^4.39.1",
23 | "webpack-cli": "^3.3.7",
24 | "webpack-dev-server": "^3.7.2"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/types/protocol.js:
--------------------------------------------------------------------------------
1 | import { getPbId } from '../protobuf';
2 |
3 | export default class Protocol {
4 |   constructor(id, tags, description, planAssignments, workersResolved) {
5 |     this.id = id;
6 |     this.tags = tags;
7 |     this.description = description;
8 |     this.plans = planAssignments;
9 | this.workersResolved = workersResolved;
10 | }
11 |
12 | static unbufferize(worker, pb) {
13 | const planAssignments = [];
14 | if (pb.plan_assignments) {
15 | pb.plan_assignments.forEach(item => {
16 | planAssignments.push([getPbId(item.worker_id), getPbId(item.plan_id)]);
17 | });
18 | }
19 | return new Protocol(
20 | getPbId(pb.id),
21 | pb.tags,
22 | pb.description,
23 | planAssignments,
24 | pb.workers_resolved
25 | );
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/index.html:
--------------------------------------------------------------------------------
[HTML markup lost in extraction; only the page title survives: "Syft.js Multi-armed Bandit Example"]
--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # syft.js MNIST Example
2 |
3 | This is a demonstration of how to use [syft.js](https://github.com/openmined/syft.js)
4 | with [PyGrid](https://github.com/OpenMined/pygrid) to train a plan on local data in the browser.
5 |
6 | ## Quick Start
7 |
8 | 1. Install and start [PyGrid](https://github.com/OpenMined/pygrid)
9 | 2. Install [PySyft](https://github.com/OpenMined/PySyft) and [execute the "Part 01 - Create Plan" notebook](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/static-fl/Part%2001%20-%20Create%20Plan.ipynb) from the `examples/tutorials/static-fl` folder to seed the MNIST plan and model into PyGrid.
10 | 3. Now back in this folder, execute `npm install`
11 | 4. And then execute `npm start`
12 |
13 | This will launch a web browser running the file `index.js`. Every time you make changes to this file, the server will automatically re-compile your code and refresh the page. No need to start and stop. :)
14 |
--------------------------------------------------------------------------------
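
To ground the steps above, here is a minimal sketch of what the browser side of this example might look like with the syft.js API shown in `src/syft.js`. The plan key `'training_plan'`, the `loadLocalBatch` helper, and the argument order passed to `execute` are illustrative assumptions; the real values come from the plan seeded into PyGrid in step 2.

```js
import Syft from '@openmined/syft.js';

const client = new Syft({ url: 'ws://localhost:5000', verbose: true });

const runTraining = async () => {
  const job = await client.newJob({ modelName: 'mnist', modelVersion: '1.0.0' });

  job.on('accepted', async ({ model, clientConfig }) => {
    // Load a batch of local MNIST data (hypothetical helper).
    const { data, target } = await loadLocalBatch(clientConfig.batch_size);

    // Run the downloaded training plan on local data and current parameters.
    const [loss, acc, ...updatedParams] = await job.plans['training_plan'].execute(
      client,
      data,
      target,
      clientConfig.batch_size,
      clientConfig.lr,
      ...model.params
    );

    // Report only the serialized parameter diff back to PyGrid.
    const diff = await model.createSerializedDiff(updatedParams);
    await job.report(diff);
  });

  job.on('rejected', ({ timeout }) => {
    // Not selected for this cycle; optionally retry after `timeout`.
  });

  job.on('error', err => console.error(err));

  job.start();
};

runTraining();
```
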
/examples/multi-armed-bandit/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "syftjs-multi-armed-bandit",
3 | "version": "1.0.0",
4 | "private": true,
5 | "description": "",
6 | "main": "index.js",
7 | "scripts": {
8 | "start": "webpack-dev-server --mode development",
9 | "build": "rm -rf dist && webpack --mode production",
10 | "test": "echo \"Error: no test specified\" && exit 1"
11 | },
12 | "author": "",
13 | "license": "",
14 | "dependencies": {
15 | "@emotion/core": "^10.0.28",
16 | "@fortawesome/fontawesome-free": "^5.13.1",
17 | "@openmined/syft.js": "github:openmined/syft.js",
18 | "react": "^16.13.1",
19 | "react-dom": "^16.13.1"
20 | },
21 | "devDependencies": {
22 | "babel-loader": "^8.0.6",
23 | "html-webpack-plugin": "^3.2.0",
24 | "path": "^0.12.7",
25 | "regenerator-runtime": "^0.13.3",
26 | "svg-url-loader": "^6.0.0",
27 | "webpack": "^4.39.1",
28 | "webpack-cli": "^3.3.7",
29 | "webpack-dev-server": "^3.7.2"
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/test/data/dummy.tpl.js:
--------------------------------------------------------------------------------
1 | export const MNIST_BATCH_SIZE = parseInt('%MNIST_BATCH_SIZE%');
2 | export const MNIST_LR = parseFloat('%MNIST_LR%');
3 | export const MNIST_PLAN = '%MNIST_PLAN%';
4 | export const MNIST_BATCH_DATA = '%MNIST_BATCH_DATA%';
5 | export const MNIST_MODEL_PARAMS = '%MNIST_MODEL_PARAMS%';
6 | export const MNIST_UPD_MODEL_PARAMS = '%MNIST_UPD_MODEL_PARAMS%';
7 | export const MNIST_LOSS = parseFloat('%MNIST_LOSS%');
8 | export const MNIST_ACCURACY = parseFloat('%MNIST_ACCURACY%');
9 |
10 | export const PLAN_WITH_STATE = '%PLAN_WITH_STATE%';
11 |
12 | export const BANDIT_SIMPLE_PLAN = '%BANDIT_SIMPLE_PLAN%';
13 | export const BANDIT_SIMPLE_MODEL_PARAMS = '%BANDIT_SIMPLE_MODEL_PARAMS%';
14 | export const BANDIT_THOMPSON_PLAN = '%BANDIT_THOMPSON_PLAN%';
15 | export const BANDIT_THOMPSON_MODEL_PARAMS = '%BANDIT_THOMPSON_MODEL_PARAMS%';
16 |
17 | export const PROTOCOL =
18 | 'CgYIjcivoCUqEwoGCIHIr6AlEgkSB3dvcmtlcjEqEwoGCIXIr6AlEgkSB3dvcmtlcjIqEwoGCInIr6AlEgkSB3dvcmtlcjM=';
19 |
--------------------------------------------------------------------------------
/src/types/state.js:
--------------------------------------------------------------------------------
1 | import { protobuf, unbufferize } from '../protobuf';
2 |
3 | export class State {
4 | constructor(placeholders = null, tensors = null) {
5 | this.placeholders = placeholders;
6 | this.tensors = tensors;
7 | }
8 |
9 | getTfTensors() {
10 | return this.tensors.map(t => t.toTfTensor());
11 | }
12 |
13 | static unbufferize(worker, pb) {
14 | const tensors = pb.tensors.map(stateTensor => {
15 | // unwrap StateTensor
16 | return unbufferize(worker, stateTensor[stateTensor.tensor]);
17 | });
18 |
19 | return new State(unbufferize(worker, pb.placeholders), tensors);
20 | }
21 |
22 | bufferize(worker) {
23 | const tensorsPb = this.tensors.map(tensor =>
24 | protobuf.syft_proto.execution.v1.StateTensor.create({
25 | torch_tensor: tensor.bufferize(worker)
26 | })
27 | );
28 | const placeholdersPb = this.placeholders.map(ph => ph.bufferize());
29 | return protobuf.syft_proto.execution.v1.State.create({
30 | placeholders: placeholdersPb,
31 | tensors: tensorsPb
32 | });
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/object-registry.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs-core';
2 |
3 | export default class ObjectRegistry {
4 | constructor() {
5 | this.objects = {};
6 | this.gc = {};
7 | }
8 |
9 | set(id, obj, gc = false) {
10 | if (this.objects[id] instanceof tf.Tensor) {
11 | this.objects[id].dispose();
12 | delete this.objects[id];
13 | }
14 | this.objects[id] = obj;
15 | this.gc[id] = gc;
16 | }
17 |
18 | setGc(id, gc) {
19 | this.gc[id] = gc;
20 | }
21 |
22 | get(id) {
23 | return this.objects[id];
24 | }
25 |
26 | has(id) {
27 | return Object.hasOwnProperty.call(this.objects, id);
28 | }
29 |
30 | clear() {
31 | for (let key of Object.keys(this.objects)) {
32 | if (this.gc[key] && this.objects[key] instanceof tf.Tensor) {
33 | this.objects[key].dispose();
34 | }
35 | }
36 | this.objects = {};
37 | this.gc = {};
38 | }
39 |
40 | load(objectRegistry) {
41 | for (let key of Object.keys(objectRegistry.objects)) {
42 | this.set(key, objectRegistry.get(key));
43 | }
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/src/types/plan.js:
--------------------------------------------------------------------------------
1 | import { getPbId, unbufferize } from '../protobuf';
2 |
3 | /**
4 | * PySyft Plan.
5 | */
6 | export class Plan {
7 | /**
8 | * @hideconstructor
9 | */
10 | constructor(id, name, role = [], tags = [], description = null) {
11 | this.id = id;
12 | this.name = name;
13 | this.role = role;
14 | this.tags = tags;
15 | this.description = description;
16 | }
17 |
18 | /**
19 | * @private
20 | * @returns {Plan}
21 | */
22 | static unbufferize(worker, pb) {
23 | const id = getPbId(pb.id);
24 |
25 | return new Plan(
26 | id,
27 | pb.name,
28 | unbufferize(worker, pb.role),
29 | pb.tags,
30 | pb.description
31 | );
32 | }
33 |
34 | /**
35 | * Executes the Plan and returns its output.
36 | *
37 |    * The order, type and number of arguments must match the arguments defined in the PySyft Plan.
38 | *
39 | * @param {Syft} worker
40 | * @param {...(tf.Tensor|number)} data
41 |    * @returns {Promise<Array<tf.Tensor>>}
42 | */
43 | async execute(worker, ...data) {
44 | return this.role.execute(worker, ...data);
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
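
Since `Plan.execute` forwards its arguments straight to the underlying Role, the call shape mirrors the PySyft definition. A hedged sketch (the plan signature referenced in the comment is hypothetical):

```js
// If the plan were defined in PySyft as, e.g., def plan(x, lr, *params),
// the JS call must pass the worker first, then arguments in that same order.
const [output] = await plan.execute(worker, xTensor, lrScalar, ...modelParams);
```
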
/examples/mnist/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 | const HtmlWebpackPlugin = require('html-webpack-plugin');
3 |
4 | module.exports = (env, argv) => ({
5 | mode: argv.mode,
6 | entry: ['regenerator-runtime/runtime', './index.js'],
7 | output: {
8 | path: path.join(__dirname, '/dist'),
9 | filename: 'index.bundle.js'
10 | },
11 | devtool: argv.mode === 'development' ? 'eval-source-map' : 'source-map',
12 | devServer: {
13 | port: 8080,
14 | hot: true,
15 | open: true,
16 | stats: {
17 | children: false, // Hide children information
18 | maxModules: 0 // Set the maximum number of modules to be shown
19 | }
20 | },
21 | module: {
22 | rules: [
23 | {
24 | test: /\.js$/,
25 | exclude: /node_modules/,
26 | use: {
27 | loader: 'babel-loader',
28 | options: {
29 | presets: ['@babel/preset-env'],
30 | plugins: ['@babel/plugin-proposal-class-properties']
31 | }
32 | }
33 | }
34 | ]
35 | },
36 | plugins: [new HtmlWebpackPlugin({ template: './index.html' })],
37 | externals: {
38 | '@tensorflow/tfjs-core': 'tf'
39 | }
40 | });
41 |
--------------------------------------------------------------------------------
/src/data-channel-message-queue.js:
--------------------------------------------------------------------------------
1 | import EventObserver from './events';
2 |
3 | export default class DataChannelMessageQueue {
4 | constructor() {
5 | this.messages = new Map();
6 | this.observer = new EventObserver();
7 | }
8 |
9 | /**
10 | * Register new message
11 | * @param {DataChannelMessage} message
12 | */
13 | register(message) {
14 | if (this.isRegistered(message.id)) {
15 | return false;
16 | }
17 | this.messages.set(message.id, message);
18 | message.once('ready', this.onMessageReady.bind(this));
19 | }
20 |
21 | unregister(message) {
22 | this.messages.delete(message.id);
23 | }
24 |
25 | /**
26 | * Check if registered by message id
27 | * @param {number} id
28 | */
29 | isRegistered(id) {
30 | return this.messages.has(id);
31 | }
32 |
33 |   /**
34 |    * Get a registered message by its id
35 |    * @param {number} id
36 |    */
37 | getById(id) {
38 | return this.messages.get(id);
39 | }
40 |
41 |   /**
42 |    * Handles a message that has received all of its chunks.
43 |    * @param {DataChannelMessage} message
44 |    */
45 | onMessageReady(message) {
46 | this.unregister(message);
47 | this.observer.broadcast('message', message);
48 | }
49 |
50 | on(event, func) {
51 | this.observer.subscribe(event, func);
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/rollup.config.js:
--------------------------------------------------------------------------------
1 | import babel from '@rollup/plugin-babel';
2 | import builtins from '@joseph184/rollup-plugin-node-builtins';
3 | import json from '@rollup/plugin-json';
4 | import peerDepsExternal from 'rollup-plugin-peer-deps-external';
5 | import resolve from '@rollup/plugin-node-resolve';
6 | import commonjs from '@rollup/plugin-commonjs';
7 | import filesize from 'rollup-plugin-filesize';
8 |
9 | import pkg from './package.json';
10 |
11 | const sharedOutput = {
12 | name: 'syft',
13 | sourcemap: true,
14 | exports: 'named',
15 | globals: {
16 | '@tensorflow/tfjs-core': 'tf'
17 | }
18 | };
19 |
20 | export default {
21 | input: 'src/index.js',
22 | output: [
23 | {
24 | file: pkg.browser,
25 | format: 'umd',
26 | ...sharedOutput
27 | },
28 | {
29 | file: pkg.main,
30 | format: 'cjs',
31 | ...sharedOutput
32 | },
33 | {
34 | file: pkg.module,
35 | format: 'es',
36 | ...sharedOutput
37 | }
38 | ],
39 | plugins: [
40 | builtins(),
41 | json(),
42 | peerDepsExternal(),
43 | babel({
44 | babelHelpers: 'bundled',
45 | exclude: 'node_modules/**'
46 | }),
47 | resolve({
48 | preferBuiltins: true
49 | }),
50 | commonjs(),
51 | filesize()
52 | ]
53 | };
54 |
--------------------------------------------------------------------------------
/src/_errors.js:
--------------------------------------------------------------------------------
1 | export const NO_DETAILER = d =>
2 | `Serialized object contains type that may exist in PySyft, but is not currently supported in syft.js. Please file a feature request (https://github.com/OpenMined/syft.js/issues) for type ${d}.`;
3 |
4 | export const NOT_ENOUGH_ARGS = (passed, expected) =>
5 | `You have passed ${passed} argument(s) when the plan requires ${expected} argument(s).`;
6 |
7 | export const MISSING_VARIABLE = () =>
8 | `Command requires variable that is missing.`;
9 |
10 | export const NO_PLAN = `The operation you're attempting to run requires a plan before being called.`;
11 |
12 | export const PLAN_ALREADY_COMPLETED = (name, id) =>
13 | `You have already executed the plan named "${name}" with id "${id}".`;
14 |
15 | export const CANNOT_FIND_COMMAND = command =>
16 |   `Command ${command} not found in TensorFlow.js.`;
17 |
18 | export const GRID_UNKNOWN_CYCLE_STATUS = status =>
19 | `Unknown cycle status: ${status}`;
20 |
21 | export const GRID_ERROR = status => `Grid error: ${status}`;
22 |
23 | export const MODEL_LOAD_FAILED = status => `Failed to load Model: ${status}`;
24 |
25 | export const PLAN_LOAD_FAILED = (planName, status) =>
26 | `Failed to load '${planName}' Plan: ${status}`;
27 |
28 | export const PROTOBUF_UNSERIALIZE_FAILED = (pbType, status) =>
29 | `Failed to unserialize binary protobuf data into ${pbType}: ${status}`;
30 |
--------------------------------------------------------------------------------
/test/events.test.js:
--------------------------------------------------------------------------------
1 | import EventObserver from '../src/events';
2 |
3 | describe('Event Observer', () => {
4 | jest.spyOn(console, 'log');
5 |
6 | afterEach(() => {
7 | jest.clearAllMocks();
8 | });
9 |
10 | test('can subscribe to and broadcast an event', () => {
11 | const observer = new EventObserver();
12 | const name = 'my-thing';
13 | const func = data => {
14 | console.log('hello', data);
15 | };
16 | const myData = { awesome: true };
17 |
18 | expect(observer.observers.length).toBe(0);
19 |
20 | observer.subscribe(name, func);
21 | observer.subscribe(`${name}-other`, func);
22 |
23 | expect(observer.observers.length).toBe(2);
24 | expect(console.log.mock.calls.length).toBe(0);
25 |
26 | observer.broadcast(name, myData);
27 |
28 | expect(console.log.mock.calls.length).toBe(1);
29 | expect(console.log.mock.calls[0][0]).toBe('hello');
30 | expect(console.log.mock.calls[0][1]).toBe(myData);
31 | });
32 |
33 |   test('can unsubscribe from an event', () => {
34 | const observer = new EventObserver(),
35 | myType = 'hello';
36 |
37 | expect(observer.observers.length).toBe(0);
38 |
39 | observer.subscribe(myType, () => console.log('awesome'));
40 |
41 | expect(observer.observers.length).toBe(1);
42 |
43 | observer.unsubscribe(myType);
44 |
45 | expect(observer.observers.length).toBe(0);
46 | });
47 | });
48 |
--------------------------------------------------------------------------------
/src/_constants.js:
--------------------------------------------------------------------------------
1 | // Sockets
2 | export const SOCKET_STATUS = 'socket-status';
3 | export const SOCKET_PING = 'socket-ping';
4 |
5 | // Grid
6 | export const GET_PROTOCOL = 'get-protocol';
7 | export const CYCLE_STATUS_ACCEPTED = 'accepted';
8 | export const CYCLE_STATUS_REJECTED = 'rejected';
9 |
10 | // WebRTC
11 | export const WEBRTC_JOIN_ROOM = 'webrtc: join-room';
12 | export const WEBRTC_INTERNAL_MESSAGE = 'webrtc: internal-message';
13 | export const WEBRTC_PEER_LEFT = 'webrtc: peer-left';
14 |
15 | // WebRTC: Data Channel
16 | export const WEBRTC_DATACHANNEL_CHUNK_SIZE = 64 * 1024;
17 | export const WEBRTC_DATACHANNEL_MAX_BUFFER = 4 * 1024 * 1024;
18 | export const WEBRTC_DATACHANNEL_BUFFER_TIMEOUT = 2000;
19 | export const WEBRTC_DATACHANNEL_MAX_BUFFER_TIMEOUTS = 5;
20 |
21 | export const WEBRTC_PEER_CONFIG = {
22 | iceServers: [
23 | {
24 | urls: [
25 | 'stun:stun.l.google.com:19302',
26 | 'stun:stun1.l.google.com:19302',
27 | 'stun:stun2.l.google.com:19302' // FF says too many stuns are bad, don't send more than this
28 | ]
29 | }
30 | ]
31 | };
32 |
33 | export const WEBRTC_PEER_OPTIONS = {
34 | optional: [
35 | { DtlsSrtpKeyAgreement: true } // Required for connection between Chrome and Firefox
36 | // FF works w/o this option, but Chrome fails with it
37 | // { RtpDataChannels: true } // Required in Firefox to use the DataChannels API
38 | ]
39 | };
40 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 | const HtmlWebpackPlugin = require('html-webpack-plugin');
3 |
4 | module.exports = (env, argv) => ({
5 | mode: argv.mode,
6 | entry: ['regenerator-runtime/runtime', './index.js'],
7 | output: {
8 | path: path.join(__dirname, '/dist'),
9 | filename: 'index.bundle.js'
10 | },
11 | devtool: argv.mode === 'development' ? 'eval-source-map' : 'source-map',
12 | devServer: {
13 | port: 8080,
14 | hot: true,
15 | open: true,
16 | stats: {
17 | children: false, // Hide children information
18 | maxModules: 0 // Set the maximum number of modules to be shown
19 | }
20 | },
21 | module: {
22 | rules: [
23 | {
24 | test: /\.js$/,
25 | exclude: /node_modules/,
26 | use: {
27 | loader: 'babel-loader',
28 | options: {
29 | presets: ['@babel/preset-env', '@babel/preset-react'],
30 | plugins: ['@babel/plugin-proposal-class-properties']
31 | }
32 | }
33 | },
34 | {
35 | test: /\.svg$/,
36 | use: [
37 | {
38 | loader: 'svg-url-loader',
39 | options: {
40 | limit: 10000
41 | }
42 | }
43 | ]
44 | }
45 | ]
46 | },
47 | plugins: [new HtmlWebpackPlugin({ template: './index.html' })],
48 | externals: {
49 | '@tensorflow/tfjs-core': 'tf'
50 | }
51 | });
52 |
--------------------------------------------------------------------------------
/src/types/message.js:
--------------------------------------------------------------------------------
1 | import { unbufferize } from '../protobuf';
2 | import Logger from '../logger';
3 |
4 | export class Message {
5 | constructor(contents) {
6 | if (contents) {
7 | this.contents = contents;
8 | }
9 | this.logger = new Logger();
10 | }
11 | }
12 |
13 | export class ObjectMessage extends Message {
14 | constructor(contents) {
15 | super(contents);
16 | }
17 |
18 | static unbufferize(worker, pb) {
19 | const tensor = unbufferize(worker, pb.tensor);
20 | return new ObjectMessage(tensor);
21 | }
22 | }
23 |
24 | // TODO when types become available in protobuf
25 |
26 | /*
27 | export class ObjectRequestMessage extends Message {
28 | constructor(contents) {
29 | super(contents);
30 | }
31 | }
32 |
33 | export class IsNoneMessage extends Message {
34 | constructor(contents) {
35 | super(contents);
36 | }
37 | }
38 |
39 | export class GetShapeMessage extends Message {
40 | constructor(contents) {
41 | super(contents);
42 | }
43 | }
44 |
45 | export class ForceObjectDeleteMessage extends Message {
46 | constructor(contents) {
47 | super(contents);
48 | }
49 | }
50 |
51 | export class SearchMessage extends Message {
52 | constructor(contents) {
53 | super(contents);
54 | }
55 | }
56 |
57 | export class PlanCommandMessage extends Message {
58 | constructor(commandName, message) {
59 | super();
60 |
61 | this.commandName = commandName;
62 | this.message = message;
63 | }
64 | }
65 | */
66 |
--------------------------------------------------------------------------------
/src/protobuf/mapping.js:
--------------------------------------------------------------------------------
1 | import { protobuf } from 'syft-proto';
2 | import Protocol from '../types/protocol';
3 | import { Plan } from '../types/plan';
4 | import { Role } from '../types/role';
5 | import { State } from '../types/state';
6 | import { ObjectMessage } from '../types/message';
7 | import { TorchParameter, TorchTensor } from '../types/torch';
8 | import { Placeholder, PlaceholderId } from '../types/placeholder';
9 | import { ComputationAction } from '../types/computation-action';
10 |
11 | let PB_CLASS_MAP, PB_TO_UNBUFFERIZER;
12 |
13 | // Because of cyclic dependencies between the Protocol/etc. modules and the protobuf module,
14 | // the Protocol/etc. classes are undefined at the moment this module is imported.
15 | export const initMappings = () => {
16 | PB_CLASS_MAP = [
17 | [Protocol, protobuf.syft_proto.execution.v1.Protocol],
18 | [Plan, protobuf.syft_proto.execution.v1.Plan],
19 | [Role, protobuf.syft_proto.execution.v1.Role],
20 | [State, protobuf.syft_proto.execution.v1.State],
21 | [ComputationAction, protobuf.syft_proto.execution.v1.ComputationAction],
22 | [Placeholder, protobuf.syft_proto.execution.v1.Placeholder],
23 | [PlaceholderId, protobuf.syft_proto.execution.v1.PlaceholderId],
24 | [ObjectMessage, protobuf.syft_proto.messaging.v1.ObjectMessage],
25 | [TorchTensor, protobuf.syft_proto.types.torch.v1.TorchTensor],
26 | [TorchParameter, protobuf.syft_proto.types.torch.v1.Parameter]
27 | ];
28 |
29 | PB_TO_UNBUFFERIZER = PB_CLASS_MAP.reduce((map, item) => {
30 | map[item[1]] = item[0].unbufferize;
31 | return map;
32 | }, {});
33 | };
34 |
35 | export { PB_CLASS_MAP, PB_TO_UNBUFFERIZER };
36 |
--------------------------------------------------------------------------------
/test/types/message.test.js:
--------------------------------------------------------------------------------
1 | import { Message, ObjectMessage } from '../../src/types/message';
2 | import { TorchTensor } from '../../src/types/torch';
3 |
4 | describe('Message', () => {
5 | test('can be properly constructed', () => {
6 | const content = 'abc';
7 | const obj = new Message(content);
8 | expect(obj.contents).toBe(content);
9 | });
10 | });
11 |
12 | describe('ObjectMessage', () => {
13 | test('can be properly constructed', () => {
14 | const tensor = new TorchTensor(
15 | 123,
16 | new Float32Array([1, 2, 3, 4]),
17 | [2, 2],
18 | 'float32'
19 | );
20 | const obj = new ObjectMessage(tensor);
21 | expect(obj.contents).toBe(tensor);
22 | });
23 | });
24 |
25 | describe('ObjectRequestMessage', () => {
26 | test('can be properly constructed', () => {
27 | // TODO when type is available in protobuf
28 | });
29 | });
30 |
31 | describe('IsNoneMessage', () => {
32 | test('can be properly constructed', () => {
33 | // TODO when type is available in protobuf
34 | });
35 | });
36 |
37 | describe('GetShapeMessage', () => {
38 | test('can be properly constructed', () => {
39 | // TODO when type is available in protobuf
40 | });
41 | });
42 |
43 | describe('ForceObjectDeleteMessage', () => {
44 | test('can be properly constructed', () => {
45 | // TODO when type is available in protobuf
46 | });
47 | });
48 |
49 | describe('SearchMessage', () => {
50 | test('can be properly constructed', () => {
51 | // TODO when type is available in protobuf
52 | });
53 | });
54 |
55 | describe('PlanCommandMessage', () => {
56 | test('can be properly constructed', () => {
57 | // TODO when type is available in protobuf
58 | });
59 | });
60 |
--------------------------------------------------------------------------------
/test/mocks/webrtc.js:
--------------------------------------------------------------------------------
1 | // RTC classes mocks
2 | export class RTCPeerConnection {
3 | constructor(options, optional) {
4 | this.options = options;
5 | this.optional = optional;
6 | this.localDescription = null;
7 | this.remoteDescription = null;
8 | this.iceCandidates = [];
9 | this.sentMessages = [];
10 | this.ondatachannelListener = null;
11 | this.onicecandidateListener = null;
12 | }
13 |
14 | createDataChannel() {
15 | return {};
16 | }
17 |
18 | async createOffer() {
19 | return Promise.resolve({ type: 'offer', sdp: 'testOfferSdp' });
20 | }
21 |
22 | async createAnswer() {
23 | return Promise.resolve({ type: 'answer', sdp: 'testAnswerSdp' });
24 | }
25 |
26 | async setLocalDescription(sessionDescription) {
27 | this.localDescription = sessionDescription;
28 | return Promise.resolve();
29 | }
30 |
31 | async setRemoteDescription(sessionDescription) {
32 | this.remoteDescription = sessionDescription;
33 | return Promise.resolve();
34 | }
35 |
36 | async addIceCandidate(iceCandidate) {
37 | this.iceCandidates.push(iceCandidate);
38 | return Promise.resolve();
39 | }
40 |
41 | get onicecandidate() {
42 | return this.onicecandidateListener;
43 | }
44 |
45 | set onicecandidate(cb) {
46 | this.onicecandidateListener = cb;
47 | }
48 |
49 | get ondatachannel() {
50 | return this.ondatachannelListener;
51 | }
52 |
53 | set ondatachannel(cb) {
54 | this.ondatachannelListener = cb;
55 | }
56 | }
57 |
58 | export class RTCSessionDescription {
59 | constructor(sessionDescription) {
60 | Object.assign(this, sessionDescription);
61 | }
62 | }
63 |
64 | export class RTCIceCandidate {
65 | constructor(iceCandidate) {
66 | Object.assign(this, iceCandidate);
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/src/types/placeholder.js:
--------------------------------------------------------------------------------
1 | import { protobuf, getPbId, pbId } from '../protobuf';
2 |
3 | export class PlaceholderId {
4 | constructor(id) {
5 | this.id = id;
6 | }
7 |
8 | static unbufferize(worker, pb) {
9 | return new PlaceholderId(getPbId(pb.id));
10 | }
11 |
12 | bufferize(/* worker */) {
13 | return protobuf.syft_proto.execution.v1.PlaceholderId.create({
14 | id: pbId(this.id)
15 | });
16 | }
17 | }
18 |
19 | export class Placeholder {
20 | constructor(id, tags = [], description = null, expected_shape = null) {
21 | this.id = id;
22 | this.tags = tags;
23 | this.description = description;
24 | this.expected_shape = expected_shape;
25 | }
26 |
27 | static unbufferize(worker, pb) {
28 | let expected_shape = null;
29 | if (
30 | pb.expected_shape &&
31 | Array.isArray(pb.expected_shape.dims) &&
32 | pb.expected_shape.dims.length > 0
33 | ) {
34 | // Unwrap Shape
35 | expected_shape = pb.expected_shape.dims;
36 | }
37 |
38 | return new Placeholder(
39 | getPbId(pb.id),
40 | pb.tags || [],
41 | pb.description,
42 | expected_shape
43 | );
44 | }
45 |
46 | bufferize(/* worker */) {
47 | return protobuf.syft_proto.execution.v1.Placeholder.create({
48 | id: pbId(this.id),
49 | tags: this.tags,
50 | description: this.description,
51 | expected_shape: protobuf.syft_proto.types.syft.v1.Shape.create(
52 | this.expected_shape
53 | )
54 | });
55 | }
56 |
57 | getOrderFromTags(prefix) {
58 | const regExp = new RegExp(`^${prefix}-(\\d+)$`, 'i');
59 | for (let tag of this.tags) {
60 | let tagMatch = regExp[Symbol.match](tag);
61 | if (tagMatch) {
62 | return Number(tagMatch[1]);
63 | }
64 | }
65 | throw new Error(`Placeholder ${this.id} doesn't have order tag ${prefix}`);
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/README.md:
--------------------------------------------------------------------------------
1 | # syft.js Multi-armed Bandit Example
2 |
3 | This is a demonstration of how to use [syft.js](https://github.com/openmined/syft.js)
4 | with [PyGrid](https://github.com/OpenMined/pygrid) to train [a multi-armed bandit](https://vwo.com/blog/multi-armed-bandit-algorithm/) in the browser. A multi-armed bandit can be used to perform automated A/B testing on a website or application, while gradually converging your users to the ideal user-experience, given a goal of your choosing.
5 |
6 | In this demo, we automatically generate various website layouts that we want our users to view. Subtle changes are made to the site every time you load it, including changes in button size, color, or position on the page. In the background, syft.js tracks the layouts where the user does what we want (clicks a button) and reports a positive model diff for each such layout. For all other layouts, where the user doesn't click the button, nothing is reported. Over time, our model will slowly converge on a "preferred user experience" for the best layout, as chosen by user actions.
7 |
8 | While this demo is inherently simple, it's easy to see how one could extend it to a real-world application whereby website layouts are generated and tested by real users, slowly converging to the preferred UX. We're particularly excited to see derivations of this demo in real-world web and mobile development!
9 |
10 | ## Quick Start
11 |
12 | 1. Install and start [PyGrid](https://github.com/OpenMined/pygrid)
13 | 2. Install [PySyft](https://github.com/OpenMined/PySyft) and ... TODO: @maddie - fill this in here
14 | 3. Now back in this folder, execute `npm install`
15 | 4. And then execute `npm start`
16 |
17 | This will launch a web browser running the file `index.js`. Every time you make changes to this file, the server will automatically re-compile your code and refresh the page. No need to start and stop. :)
18 |
--------------------------------------------------------------------------------
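
For illustration, a hedged sketch of the reward loop this README describes; `chooseLayout`, `renderLayout`, and `rewardLayout` are hypothetical helpers, and the event wiring assumes the Job API shown in `src/syft.js`:

```js
job.on('accepted', async ({ model }) => {
  const layout = chooseLayout(model.params); // hypothetical: pick a layout from model params
  renderLayout(layout); // hypothetical: apply button size/color/position to the page

  document.querySelector('#cta-button').addEventListener('click', async () => {
    // Only layouts that receive a click ever produce a (positive) report.
    const updatedParams = rewardLayout(model.params, layout); // hypothetical local update
    const diff = await model.createSerializedDiff(updatedParams);
    await job.report(diff);
  });
});
```
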
/src/syft-model.js:
--------------------------------------------------------------------------------
1 | import { unserialize, protobuf, serialize } from './protobuf';
2 | import { State } from './types/state';
3 | import { TorchTensor } from './types/torch';
4 | import { Placeholder } from './types/placeholder';
5 | import { MODEL_LOAD_FAILED } from './_errors';
6 |
7 | /**
8 |  * Model parameters as stored in PyGrid.
9 | *
10 |  * @property {Array.<tf.Tensor>} params Array of Model parameters.
11 | */
12 | export default class SyftModel {
13 | /**
14 | * @hideconstructor
15 | * @param {Object} options
16 | * @param {Syft} options.worker Instance of Syft client.
17 | * @param {ArrayBuffer} options.modelData Serialized Model parameters as returned by PyGrid.
18 | */
19 | constructor({ worker, modelData }) {
20 | try {
21 | const state = unserialize(
22 | worker,
23 | modelData,
24 | protobuf.syft_proto.execution.v1.State
25 | );
26 | this.worker = worker;
27 | this.params = state.getTfTensors();
28 | } catch (e) {
29 | throw new Error(MODEL_LOAD_FAILED(e.message));
30 | }
31 | }
32 |
33 | /**
34 | * Calculates difference between 2 versions of the Model parameters
35 | * and returns serialized `diff` that can be submitted to PyGrid.
36 | *
37 |    * @param {Array.<tf.Tensor>} updatedModelParams Array of model parameters (tensors).
38 |    * @returns {Promise<ArrayBuffer>} Protobuf-serialized `diff`.
39 | */
40 | async createSerializedDiff(updatedModelParams) {
41 | const placeholders = [],
42 | tensors = [];
43 |
44 | for (let i = 0; i < updatedModelParams.length; i++) {
45 | let paramDiff = this.params[i].sub(updatedModelParams[i]);
46 | placeholders.push(new Placeholder(i, [`#${i}`, `#state-${i}`]));
47 | tensors.push(await TorchTensor.fromTfTensor(paramDiff));
48 | }
49 | const state = new State(placeholders, tensors);
50 | const bin = serialize(this.worker, state);
51 |
52 | // Free memory.
53 | tensors.forEach(t => t._tfTensor.dispose());
54 |
55 | return bin;
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
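
The `diff` built above is `original - updated` for each parameter (line 45), so PyGrid can apply the update by subtracting the diff from its stored model. A toy illustration of that per-parameter operation with tfjs:

```js
import * as tf from '@tensorflow/tfjs-core';

const original = tf.tensor([1, 2, 3]);
const updated = tf.tensor([0.9, 2.1, 3.0]);

// Same operation as createSerializedDiff performs per parameter:
const diff = original.sub(updated);
diff.print(); // ≈ [0.1, -0.1, 0]
```
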
/src/sockets.js:
--------------------------------------------------------------------------------
1 | import { SOCKET_PING } from './_constants';
2 | import Logger from './logger';
3 |
4 | export default class Sockets {
5 | constructor({
6 | url,
7 | workerId,
8 | onOpen,
9 | onClose,
10 | onMessage,
11 | keepAliveTimeout = 20000
12 | }) {
13 | this.logger = new Logger();
14 | const socket = new WebSocket(url);
15 |
16 | const keepAlive = () => {
17 | this.send(SOCKET_PING);
18 | this.timerId = setTimeout(keepAlive, keepAliveTimeout);
19 | };
20 |
21 | const cancelKeepAlive = () => {
22 | clearTimeout(this.timerId);
23 | this.timerId = null;
24 | };
25 |
26 | socket.onopen = event => {
27 | this.logger.log(
28 | `Opening socket connection at ${event.currentTarget.url}`,
29 | event
30 | );
31 |
32 | keepAlive();
33 |
34 | if (onOpen) onOpen(event);
35 | };
36 |
37 | socket.onclose = event => {
38 | this.logger.log(
39 | `Closing socket connection at ${event.currentTarget.url}`,
40 | event
41 | );
42 |
43 | cancelKeepAlive();
44 |
45 | if (onClose) onClose(event);
46 | };
47 |
48 | this.url = url;
49 | this.workerId = workerId;
50 | this.socket = socket;
51 | this.onMessage = onMessage;
52 | this.timerId = null;
53 | }
54 |
55 | send(type, data = {}) {
56 | return new Promise((resolve, reject) => {
57 | data.workerId = this.workerId;
58 |
59 | const message = { type, data };
60 |
61 | this.logger.log('Sending message', message);
62 |
63 | this.socket.send(JSON.stringify(message));
64 |
65 | this.socket.onmessage = event => {
66 | const data = JSON.parse(event.data);
67 |
68 | this.logger.log('Receiving message', data);
69 |
70 | resolve(this.onMessage(data));
71 | };
72 |
73 | this.socket.onerror = event => {
74 | this.logger.log('We have a socket error!', event);
75 |
76 | reject(event);
77 | };
78 | });
79 | }
80 |
81 | stop() {
82 | this.socket.close();
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/test/types/torch.test.js:
--------------------------------------------------------------------------------
1 | import { TorchParameter, TorchTensor, TorchSize } from '../../src/types/torch';
2 | import * as tf from '@tensorflow/tfjs-core';
3 |
4 | describe('TorchTensor', () => {
5 | test('can be properly constructed', () => {
6 | const chain = new TorchTensor(
7 | 555,
8 | new Float32Array([6, 6, 6, 6]),
9 | [2, 2],
10 | 'float32'
11 | );
12 | const grad = new TorchTensor(
13 | 666,
14 | new Float32Array([6, 6, 6, 6]),
15 | [2, 2],
16 | 'float32'
17 | );
18 | const obj = new TorchTensor(
19 | 123,
20 | new Float32Array([1, 2, 3.3, 4]),
21 | [2, 2],
22 | 'float32',
23 | chain,
24 | grad,
25 | ['tag1', 'tag2'],
26 | 'desc'
27 | );
28 | const tfTensor = tf.tensor([1, 2, 3.3, 4], [2, 2], 'float32');
29 |
30 | expect(obj.id).toStrictEqual(123);
31 | expect(obj.shape).toStrictEqual([2, 2]);
32 | expect(obj.dtype).toStrictEqual('float32');
33 | expect(obj.contents).toStrictEqual(new Float32Array([1, 2, 3.3, 4]));
34 |
35 | // resulting TF tensors are equal
36 | expect(
37 | tf
38 | .equal(obj.toTfTensor(), tfTensor)
39 | .all()
40 | .dataSync()[0]
41 | ).toBe(1);
42 |
43 | expect(obj.chain).toBe(chain);
44 | expect(obj.gradChain).toBe(grad);
45 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']);
46 | expect(obj.description).toStrictEqual('desc');
47 | });
48 | });
49 |
50 | describe('TorchSize', () => {
51 | test('can be properly constructed', () => {
52 | const obj = new TorchSize([2, 3]);
53 | expect(obj.size).toStrictEqual([2, 3]);
54 | });
55 | });
56 |
57 | describe('TorchParameter', () => {
58 | test('can be properly constructed', () => {
59 | const grad = new TorchTensor(
60 | 666,
61 | new Float32Array([6, 6, 6, 6]),
62 | [2, 2],
63 | 'float32'
64 | );
65 | const tensor = new TorchTensor(
66 | 123,
67 | new Float32Array([1, 2, 3.3, 4]),
68 | [2, 2],
69 | 'float32'
70 | );
71 | const obj = new TorchParameter(123, tensor, true, grad);
72 |
73 | expect(obj.id).toStrictEqual(123);
74 | expect(obj.tensor).toBe(tensor);
75 | expect(obj.requiresGrad).toStrictEqual(true);
76 | expect(obj.grad).toBe(grad);
77 | });
78 | });
79 |
--------------------------------------------------------------------------------
/test/data_channel_message_queue.test.js:
--------------------------------------------------------------------------------
1 | import DataChannelMessage from '../src/data-channel-message';
2 | import DataChannelMessageQueue from '../src/data-channel-message-queue';
3 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from '../src';
4 | import { randomFillSync } from 'crypto';
5 |
6 | describe('Data Channel Message Queue', () => {
7 | afterEach(() => {
8 | jest.clearAllMocks();
9 | });
10 |
11 | test('can register a message', () => {
12 | const queue = new DataChannelMessageQueue();
13 | const message = new DataChannelMessage({ id: 123 });
14 | queue.register(message);
15 | expect(queue.isRegistered(message.id)).toBe(true);
16 | expect(queue.register(message)).toBe(false);
17 | });
18 |
19 | test('can unregister a message', () => {
20 | const queue = new DataChannelMessageQueue();
21 | const message = new DataChannelMessage({ id: 123 });
22 | queue.register(message);
23 | expect(queue.isRegistered(message.id)).toBe(true);
24 | queue.unregister(message);
25 | expect(queue.isRegistered(message.id)).toBe(false);
26 | });
27 |
28 | test('can get a registered message', () => {
29 | const queue = new DataChannelMessageQueue();
30 | const message = new DataChannelMessage({ id: 123 });
31 | queue.register(message);
32 | expect(queue.getById(message.id)).toBe(message);
33 | queue.unregister(message);
34 | expect(queue.getById(message.id)).toBe(undefined);
35 | });
36 |
37 | test('emits "message" event when the message is ready', done => {
38 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100);
39 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength);
40 | const messageOrig = new DataChannelMessage({ data: buf });
41 | const chunk1 = messageOrig.getChunk(0);
42 | const chunk2 = messageOrig.getChunk(1);
43 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1);
44 |
45 | const messageAssembled = new DataChannelMessage({ id: info1.id });
46 | const queue = new DataChannelMessageQueue();
47 | queue.register(messageAssembled);
48 |
49 | queue.on('message', message => {
50 | expect(message.chunks).toBe(messageOrig.chunks);
51 | expect(message.size).toBe(messageOrig.size);
52 | const orig = new Uint8Array(messageOrig.data);
53 | const assembled = new Uint8Array(message.data);
54 | expect(orig.every((v, i) => assembled[i] === v)).toBe(true);
55 | done();
56 | });
57 |
58 | messageAssembled.addChunk(chunk1);
59 | messageAssembled.addChunk(chunk2);
60 | });
61 | });
62 |
--------------------------------------------------------------------------------
/.all-contributorsrc:
--------------------------------------------------------------------------------
1 | {
2 | "files": [
3 | "README.md"
4 | ],
5 | "imageSize": 100,
6 | "commit": false,
7 | "contributors": [
8 | {
9 | "login": "cereallarceny",
10 | "name": "Patrick Cason",
11 | "avatar_url": "https://avatars1.githubusercontent.com/u/1297930?v=4",
12 | "profile": "https://www.patrickcason.com",
13 | "contributions": [
14 | "ideas",
15 | "code",
16 | "design",
17 | "doc",
18 | "business"
19 | ]
20 | },
21 | {
22 | "login": "vvmnnnkv",
23 | "name": "Vova Manannikov",
24 | "avatar_url": "https://avatars2.githubusercontent.com/u/12518480?v=4",
25 | "profile": "https://www.linkedin.com/in/vova-manannikov",
26 | "contributions": [
27 | "code",
28 | "doc",
29 | "test"
30 | ]
31 | },
32 | {
33 | "login": "Nolski",
34 | "name": "Mike Nolan",
35 | "avatar_url": "https://avatars3.githubusercontent.com/u/2600677?v=4",
36 | "profile": "http://nolski.rocks",
37 | "contributions": [
38 | "code"
39 | ]
40 | },
41 | {
42 | "login": "IamRavikantSingh",
43 | "name": "Ravikant Singh",
44 | "avatar_url": "https://avatars3.githubusercontent.com/u/40258150?v=4",
45 | "profile": "http://ravikantsingh.com",
46 | "contributions": [
47 | "code",
48 | "test"
49 | ]
50 | },
51 | {
52 | "login": "vkkhare",
53 | "name": "varun khare",
54 | "avatar_url": "https://avatars1.githubusercontent.com/u/18126069?v=4",
55 | "profile": "http://vkkhare.github.io",
56 | "contributions": [
57 | "code"
58 | ]
59 | },
60 | {
61 | "login": "pedroespindula",
62 | "name": "Pedro Espíndula",
63 | "avatar_url": "https://avatars1.githubusercontent.com/u/38431219?v=4",
64 | "profile": "https://github.com/pedroespindula",
65 | "contributions": [
66 | "doc"
67 | ]
68 | },
69 | {
70 | "login": "Benardi",
71 | "name": "José Benardi de Souza Nunes",
72 | "avatar_url": "https://avatars0.githubusercontent.com/u/9937551?v=4",
73 | "profile": "https://benardi.github.io/myblog/",
74 | "contributions": [
75 | "test"
76 | ]
77 | },
78 | {
79 | "login": "tsingh2k15",
80 | "name": "Tajinder Singh",
81 | "avatar_url": "https://avatars1.githubusercontent.com/u/25232829?v=4",
82 | "profile": "http://www.linkedin.com/in/singh-taj",
83 | "contributions": [
84 | "code"
85 | ]
86 | }
87 | ],
88 | "contributorsPerLine": 7,
89 | "projectName": "syft.js",
90 | "projectOwner": "OpenMined",
91 | "repoType": "github",
92 | "repoHost": "https://github.com",
93 | "skipCi": true
94 | }
95 |
--------------------------------------------------------------------------------
/test/logger.test.js:
--------------------------------------------------------------------------------
1 | import Logger from '../src/logger';
2 |
3 | describe('Logger', () => {
4 | jest.spyOn(console, 'log');
5 |
6 | afterEach(() => {
7 | jest.clearAllMocks();
8 | Logger.instance = null;
9 | });
10 |
11 | test('can skip when verbose is false', () => {
12 | const testLogger = new Logger('syft.js', false),
13 | message = 'hello';
14 |
15 | expect(testLogger.verbose).toBe(false);
16 | expect(console.log.mock.calls.length).toBe(0);
17 |
18 | testLogger.log(message);
19 |
20 | expect(console.log.mock.calls.length).toBe(0);
21 | });
22 |
23 | test('singleton instance should be used with verbose false', () => {
24 | const testLogger = new Logger('syft.js', false);
25 | const testLogger_1 = new Logger('syft.js', true);
26 |
27 | expect(testLogger).toEqual(testLogger_1);
28 |
29 | expect(testLogger.verbose).toBe(false);
30 | expect(testLogger_1.verbose).toBe(false);
31 | expect(console.log.mock.calls.length).toBe(0);
32 |
33 | testLogger.log('hello singleton!!');
34 | testLogger_1.log('hello singleton!!');
35 |
36 | expect(console.log.mock.calls.length).toBe(0);
37 | });
38 |
39 | test('singleton instance should be used with verbose true', () => {
40 | const testLogger = new Logger('syft.js', true);
41 | const testLogger_1 = new Logger('syft.js', false);
42 |
43 | expect(testLogger).toEqual(testLogger_1);
44 |
45 | expect(testLogger.verbose).toBe(true);
46 | expect(testLogger_1.verbose).toBe(true);
47 | expect(console.log.mock.calls.length).toBe(0);
48 |
49 | testLogger.log('hello singleton!!');
50 | testLogger_1.log('hello singleton!!');
51 |
52 | expect(console.log.mock.calls.length).toBe(2);
53 | });
54 |
55 | test('can log under verbose mode', () => {
56 | const testLogger = new Logger('syft.js', true),
57 | message = 'hello';
58 |
59 | expect(testLogger.verbose).toBe(true);
60 | expect(console.log.mock.calls.length).toBe(0);
61 |
62 | const currentTime = Date.now();
63 | testLogger.log(message);
64 |
65 | expect(console.log.mock.calls.length).toBe(1);
66 | expect(console.log.mock.calls[0][0]).toBe(
67 | `${currentTime}: syft.js - ${message}`
68 | );
69 | });
70 |
71 | test('can log with data', () => {
72 | const testLogger = new Logger('syft.js', true),
73 | message = 'hello',
74 | myObj = { awesome: true };
75 |
76 | expect(testLogger.verbose).toBe(true);
77 | expect(console.log.mock.calls.length).toBe(0);
78 |
79 | testLogger.log(message, myObj);
80 |
81 | expect(console.log.mock.calls.length).toBe(1);
82 | expect(console.log.mock.calls[0][0]).toContain(`: syft.js - ${message}`);
83 | expect(console.log.mock.calls[0][1]).toStrictEqual(myObj);
84 | });
85 | });
86 |
--------------------------------------------------------------------------------
/src/syft.js:
--------------------------------------------------------------------------------
1 | import EventObserver from './events';
2 | import Logger from './logger';
3 | import GridAPIClient from './grid-api-client';
4 | import Job from './job';
5 | import ObjectRegistry from './object-registry';
6 |
7 | /**
8 | * Syft client for static federated learning.
9 | *
10 | * @param {Object} options
11 | * @param {string} options.url Full URL to PyGrid app (`ws` and `http` schemas supported).
12 | * @param {boolean} options.verbose Whether to enable logging and allow unsecured PyGrid connection.
13 | * @param {string} options.authToken PyGrid authentication token.
14 | * @param {Object} options.peerConfig [not implemented] WebRTC peer config used with RTCPeerConnection.
15 | *
16 | * @example
17 | *
18 | * const client = new Syft({url: "ws://localhost:5000", verbose: true})
19 | * const job = client.newJob({modelName: "mnist", modelVersion: "1.0.0"})
20 | * job.on('accepted', async ({model, clientConfig}) => {
21 | * // execute training
22 |  *   const [...newParams] = await job.plans['...'].execute(...)
23 |  *   const diff = await model.createSerializedDiff(newParams)
24 |  *   await job.report(diff)
25 | * })
26 | * job.on('rejected', ({timeout}) => {
27 | * // re-try later or stop
28 | * })
29 | * job.on('error', (err) => {
30 | * // handle errors
31 | * })
32 | * job.start()
33 | */
34 | export default class Syft {
35 | constructor({ url, verbose, authToken, peerConfig }) {
36 |     // For verbose logging, should the worker desire it
37 | this.logger = new Logger('syft.js', verbose);
38 |
39 |     // Connection is forced to be secure when verbose is false.
40 | this.verbose = verbose;
41 |
42 | this.gridClient = new GridAPIClient({ url, allowInsecureUrl: verbose });
43 |
44 | // objects registry
45 | this.objects = new ObjectRegistry();
46 |
47 | // For creating event listeners
48 | this.observer = new EventObserver();
49 |
50 | this.worker_id = null;
51 | this.peerConfig = peerConfig;
52 | this.authToken = authToken;
53 | }
54 |
55 | /**
56 | * Authenticates the client against PyGrid and instantiates new Job with given options.
57 | *
58 | * @throws Error
59 | * @param {Object} options
60 | * @param {string} options.modelName FL Model name.
61 | * @param {string} options.modelVersion FL Model version.
62 |    * @returns {Promise<Job>}
63 | */
64 | async newJob({ modelName, modelVersion }) {
65 | if (!this.worker_id) {
66 | // authenticate
67 | const authResponse = await this.gridClient.authenticate(
68 | modelName,
69 | modelVersion,
70 | this.authToken
71 | );
72 | this.worker_id = authResponse.worker_id;
73 | }
74 |
75 | return new Job({
76 | worker: this,
77 | modelName,
78 | modelVersion,
79 | gridClient: this.gridClient
80 | });
81 | }
82 | }
83 |
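Since `newJob` is documented to throw, a minimal error-handling sketch may help; the URL, model name, and the failure cases named in the comments are placeholders, not part of this file:

  import Syft from './syft';

  const client = new Syft({ url: 'ws://localhost:5000', verbose: true });
  try {
    // newJob() authenticates against PyGrid lazily on the first call and
    // caches worker_id for subsequent jobs.
    const job = await client.newJob({ modelName: 'mnist', modelVersion: '1.0.0' });
    job.start();
  } catch (err) {
    // e.g. a rejected authToken or an unknown model name/version
    console.error('Could not create job:', err);
  }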
--------------------------------------------------------------------------------
/test/mocks/grid.js:
--------------------------------------------------------------------------------
1 | import { WebSocket, Server } from 'mock-socket';
2 | import fetchMock from 'fetch-mock';
3 | global.WebSocket = WebSocket;
4 |
5 | export class GridMock {
6 | constructor(hostname = 'localhost', port = 8080) {
7 | this.hostname = hostname;
8 | this.port = port;
9 | this.ws = new Server(`ws://${hostname}:${port}`);
10 |
11 | this.wsConnections = [];
12 | this.ws.on('connection', socket => {
13 | this.wsConnections.push(socket);
14 | socket.on('message', message => this.messageHandler(socket, message));
15 | socket.on('error', err => console.log('WS ERROR', err));
16 | socket.on('close', err => console.log('WS CLOSE', err));
17 | });
18 |
19 | this.wsMessagesHistory = [];
20 | }
21 |
22 | setAuthenticationResponse(data) {
23 | this.authResponse = data;
24 | }
25 |
26 | setCycleResponse(data) {
27 | this.cycleResponse = data;
28 | }
29 |
30 | setReportResponse(data) {
31 | this.reportResponse = data;
32 | }
33 |
34 | _setHttpResponse(method, url, query, data, status) {
35 | let contentType = 'application/json';
36 | let json = true;
37 | if (data instanceof ArrayBuffer || data instanceof Buffer) {
38 | contentType = 'application/octet-stream';
39 | json = false;
40 | }
41 |
42 | fetchMock[method](
43 | { url, query },
44 | {
45 | body: data,
46 | status,
47 | headers: {
48 | 'Content-Type': contentType
49 | }
50 | },
51 | { sendAsJson: json }
52 | );
53 | }
54 |
55 | setModel(model_id, data, status = 200) {
56 | this._setHttpResponse(
57 | 'get',
58 | `http://${this.hostname}:${this.port}/federated/get-model`,
59 | { model_id },
60 | data,
61 | status
62 | );
63 | }
64 |
65 | setPlan(plan_id, data, status = 200) {
66 | this._setHttpResponse(
67 | 'get',
68 | `http://${this.hostname}:${this.port}/federated/get-plan`,
69 | { plan_id },
70 | data,
71 | status
72 | );
73 | }
74 |
75 | messageHandler(socket, message) {
76 | const data = JSON.parse(message);
77 | this.wsMessagesHistory.push(data);
78 | switch (data.type) {
79 | case 'federated/authenticate':
80 | socket.send(
81 | JSON.stringify({
82 | type: data.type,
83 | data: this.authResponse
84 | })
85 | );
86 | break;
87 |
88 | case 'federated/cycle-request':
89 | socket.send(
90 | JSON.stringify({
91 | type: data.type,
92 | data: this.cycleResponse
93 | })
94 | );
95 | break;
96 |
97 | case 'federated/report':
98 | socket.send(
99 | JSON.stringify({
100 | type: data.type,
101 | data: this.reportResponse
102 | })
103 | );
104 | break;
105 | }
106 | }
107 |
108 | stop() {
109 | this.ws.close();
110 | fetchMock.reset();
111 | }
112 | }
113 |
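A sketch of how this mock might be driven from a Jest test; the import paths, the payload shapes, and the assumption that GridAPIClient speaks the same socket message types the mock implements are all illustrative:

  import { GridMock } from './grid';
  import Syft from '../../src/syft';

  test('client authenticates against the mock grid', async () => {
    const grid = new GridMock('localhost', 8080);
    grid.setAuthenticationResponse({ worker_id: 'worker-1' });

    const client = new Syft({ url: 'ws://localhost:8080', verbose: true });
    await client.newJob({ modelName: 'mnist', modelVersion: '1.0.0' });

    // The mock records every message it receives, keep-alives included.
    const types = grid.wsMessagesHistory.map(m => m.type);
    expect(types).toContain('federated/authenticate');
    expect(client.worker_id).toBe('worker-1');

    grid.stop();
  });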
--------------------------------------------------------------------------------
/src/types/torch.js:
--------------------------------------------------------------------------------
1 | import { getPbId, unbufferize, protobuf, pbId } from '../protobuf';
2 | import * as tf from '@tensorflow/tfjs-core';
3 |
4 | export class TorchTensor {
5 | constructor(
6 | id,
7 | contents,
8 | shape,
9 | dtype,
10 | chain = null,
11 | gradChain = null,
12 | tags = [],
13 | description = null
14 | ) {
15 | this.id = id;
16 | this.shape = shape;
17 | this.dtype = dtype;
18 | this.contents = contents;
19 | this.chain = chain;
20 | this.gradChain = gradChain;
21 | this.tags = tags;
22 | this.description = description;
23 | this._tfTensor = null;
24 | }
25 |
26 | toTfTensor() {
27 | if (!this._tfTensor) {
28 | this._tfTensor = tf.tensor(this.contents, this.shape, this.dtype);
29 | }
30 | return this._tfTensor;
31 | }
32 |
33 | static unbufferize(worker, pb) {
34 | if (
35 | pb.serializer !==
36 | protobuf.syft_proto.types.torch.v1.TorchTensor.Serializer.SERIALIZER_ALL
37 | ) {
38 | throw new Error(
39 | `Tensor serializer ${pb.serializer} is not supported in syft.js`
40 | );
41 | }
42 |
43 | // unwrap TensorData
44 | const tensorData = pb.contents_data;
45 | const dtype = tensorData.dtype;
46 | const shape = tensorData.shape.dims;
47 | const contents = tensorData[`contents_${dtype}`];
48 |
49 | return new TorchTensor(
50 | getPbId(pb.id),
51 | contents,
52 | shape,
53 | dtype,
54 | unbufferize(worker, pb.chain),
55 | unbufferize(worker, pb.grad_chain),
56 | pb.tags,
57 | pb.description
58 | );
59 | }
60 |
61 | bufferize(/* worker */) {
62 | const tensorData = {
63 | shape: protobuf.syft_proto.types.torch.v1.Size.create({
64 | dims: this.shape
65 | }),
66 | dtype: this.dtype
67 | };
68 | tensorData[`contents_${this.dtype}`] = this.contents;
69 | const pbTensorData = protobuf.syft_proto.types.torch.v1.TensorData.create(
70 | tensorData
71 | );
72 | return protobuf.syft_proto.types.torch.v1.TorchTensor.create({
73 | id: pbId(this.id),
74 | serializer:
75 | protobuf.syft_proto.types.torch.v1.TorchTensor.Serializer
76 | .SERIALIZER_ALL,
77 | contents_data: pbTensorData,
78 | tags: this.tags,
79 | description: this.description
80 | });
81 | }
82 |
83 | static async fromTfTensor(tensor) {
84 | const flat = tensor.flatten();
85 | const array = await flat.array();
86 | flat.dispose();
87 | const t = new TorchTensor(tensor.id, array, tensor.shape, tensor.dtype);
88 | t._tfTensor = tensor;
89 | return t;
90 | }
91 | }
92 |
93 | export class TorchSize {
94 | constructor(size) {
95 | this.size = size;
96 | }
97 | }
98 |
99 | export class TorchParameter {
100 | constructor(id, tensor, requiresGrad, grad) {
101 | this.id = id;
102 | this.tensor = tensor;
103 | this.requiresGrad = requiresGrad;
104 | this.grad = grad;
105 | }
106 |
107 | static unbufferize(worker, pb) {
108 | return new TorchParameter(
109 | getPbId(pb.id),
110 | unbufferize(worker, pb.tensor),
111 | pb.requires_grad,
112 | unbufferize(worker, pb.grad)
113 | );
114 | }
115 |
116 | toTfTensor() {
117 | return this.tensor.toTfTensor();
118 | }
119 | }
120 |
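A short sketch of moving between tf.Tensor and the TorchTensor wrapper above; the import paths are assumptions:

  import * as tf from '@tensorflow/tfjs-core';
  import { TorchTensor } from './torch';

  async function wrapAndUnwrap() {
    const tfTensor = tf.tensor([[1.5, 2.5], [3.5, 4.5]]);
    // fromTfTensor flattens the data and caches the original TF handle.
    const torchTensor = await TorchTensor.fromTfTensor(tfTensor);
    // toTfTensor returns the cached handle (or builds one from contents/shape).
    const back = torchTensor.toTfTensor();
    return back.shape; // [2, 2]
  }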
--------------------------------------------------------------------------------
/src/protobuf/index.js:
--------------------------------------------------------------------------------
1 | import { NO_DETAILER, PROTOBUF_UNSERIALIZE_FAILED } from '../_errors';
2 | import { initMappings, PB_TO_UNBUFFERIZER } from './mapping';
3 | import { protobuf } from 'syft-proto';
4 | import Long from 'long';
5 | export { protobuf };
6 |
7 | export const unbufferize = (worker, pbObj) => {
8 | if (!PB_TO_UNBUFFERIZER) {
9 | initMappings();
10 | }
11 |
12 | if (
13 | pbObj === undefined ||
14 | pbObj === null ||
15 | ['number', 'string', 'boolean'].includes(typeof pbObj)
16 | ) {
17 | return pbObj;
18 | }
19 |
20 | const pbType = pbObj.constructor;
21 |
22 | // automatically unbufferize repeated fields
23 | if (Array.isArray(pbObj)) {
24 | return pbObj.map(item => unbufferize(worker, item));
25 | }
26 |
27 | // automatically unbufferize map fields
28 | if (pbType.name === 'Object') {
29 | let res = {};
30 | for (let key of Object.keys(pbObj)) {
31 | res[key] = unbufferize(worker, pbObj[key]);
32 | }
33 | return res;
34 | }
35 |
36 | // automatically unbufferize Id
37 | if (pbType === protobuf.syft_proto.types.syft.v1.Id) {
38 | return getPbId(pbObj);
39 | }
40 |
41 | // automatically unwrap Arg
42 | if (pbType === protobuf.syft_proto.types.syft.v1.Arg) {
43 | if (pbObj.arg === 'arg_int' && pbObj[pbObj.arg] instanceof Long) {
44 | // protobuf int64 is represented as Long
45 | return pbObj[pbObj.arg].toNumber();
46 | } else {
47 | return unbufferize(worker, pbObj[pbObj.arg]);
48 | }
49 | }
50 |
51 | // automatically unwrap ArgList
52 | if (pbType === protobuf.syft_proto.types.syft.v1.ArgList) {
53 | return unbufferize(worker, pbObj.args);
54 | }
55 |
56 | const unbufferizer = PB_TO_UNBUFFERIZER[pbType];
57 | if (typeof unbufferizer === 'undefined') {
58 | throw new Error(NO_DETAILER(pbType.name));
59 | }
60 | return unbufferizer(worker, pbObj);
61 | };
62 |
63 | /**
64 | * Converts binary in the form of ArrayBuffer or base64 string to syft class
65 | * @param worker
66 | * @param bin
67 | * @param pbType
68 | * @returns {Object}
69 | */
70 | export const unserialize = (worker, bin, pbType) => {
71 | const buff =
72 | typeof bin === 'string'
73 | ? Buffer.from(bin, 'base64')
74 | : bin instanceof ArrayBuffer
75 | ? new Uint8Array(bin)
76 | : bin;
77 | let pbObj;
78 | try {
79 | pbObj = pbType.decode(buff);
80 | } catch (e) {
81 | throw new Error(PROTOBUF_UNSERIALIZE_FAILED(pbType.name, e.message));
82 | }
83 | return unbufferize(worker, pbObj);
84 | };
85 |
86 | /**
87 | * Converts syft class to protobuf-serialized binary
88 | * @param worker
89 | * @param obj
90 | * @returns {ArrayBuffer}
91 | */
92 | export const serialize = (worker, obj) => {
93 | const pbObj = obj.bufferize(worker);
94 | const pbType = pbObj.constructor;
95 | const err = pbType.verify(pbObj);
96 | if (err) {
97 | throw new Error(err);
98 | }
99 | const bin = pbType.encode(pbObj).finish();
100 | return new Uint8Array(bin).buffer;
101 | };
102 |
103 | export const getPbId = field => {
104 |   // return whichever id variant is set (int64 or string) as a string
105 | return field[field.id].toString();
106 | };
107 |
108 | export const pbId = value => {
109 | if (typeof value === 'number') {
110 | return protobuf.syft_proto.types.syft.v1.Id.create({ id_int: value });
111 | } else if (typeof value === 'string') {
112 | return protobuf.syft_proto.types.syft.v1.Id.create({ id_str: value });
113 | }
114 | };
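A round-trip sketch of serialize/unserialize using TorchTensor, mirroring what the protobuf tests do; the import paths are assumptions:

  import * as tf from '@tensorflow/tfjs-core';
  import { serialize, unserialize, protobuf } from './index';
  import { TorchTensor } from '../types/torch';

  async function roundTrip() {
    const original = await TorchTensor.fromTfTensor(tf.tensor([1, 2, 3]));
    // bufferize -> verify -> encode; returns an ArrayBuffer
    const bin = serialize(null, original);
    // decode the binary and unbufferize it back into a TorchTensor
    const restored = unserialize(
      null,
      bin,
      protobuf.syft_proto.types.torch.v1.TorchTensor
    );
    return restored.contents; // matches original.contents
  }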
115 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@openmined/syft.js",
3 | "version": "0.0.1-1",
4 | "description": "A Javascript Syft worker in the browser",
5 | "main": "dist/index.cjs.js",
6 | "module": "dist/index.esm.js",
7 | "browser": "dist/index.js",
8 | "files": [
9 | "dist/*.js",
10 | "*.md"
11 | ],
12 | "repository": {
13 | "type": "git",
14 | "url": "git+https://github.com/OpenMined/syft.js.git"
15 | },
16 | "keywords": [
17 | "syft",
18 | "pysyft",
19 | "openmined",
20 | "open",
21 | "mined",
22 | "deep",
23 | "learning",
24 | "private",
25 | "javascript",
26 | "machine",
27 | "learning"
28 | ],
29 | "author": "OpenMined",
30 | "license": "Apache-2.0",
31 | "bugs": {
32 | "url": "https://github.com/OpenMined/syft.js/issues"
33 | },
34 | "homepage": "https://github.com/OpenMined/syft.js#readme",
35 | "scripts": {
36 | "start": "npm run lint && rollup -cw",
37 | "build": "npm run lint && rollup -c",
38 | "prepare": "npm run build",
39 | "test": "npm run lint && jest --coverage",
40 | "test:watch": "npm run lint && jest --watch",
41 | "version": "auto-changelog -p && git add CHANGELOG.md",
42 | "release": "np",
43 | "deploy": "./copy-examples.sh && gh-pages -d tmp && rm -rf tmp",
44 | "lint": "eslint .",
45 | "doc": "documentation build --config documentation.yml src/syft.js src/syft-model.js src/job.js src/types/plan.js --shallow -f md -o API-REFERENCE.md"
46 | },
47 | "browserslist": "> 0.25%, not dead",
48 | "husky": {
49 | "hooks": {
50 | "pre-commit": "npm run doc && pretty-quick --staged"
51 | }
52 | },
53 | "jest": {
54 | "testEnvironment": "node",
55 | "collectCoverageFrom": [
56 | "**/src/**/*.js"
57 | ],
58 | "setupFiles": [
59 |       "<rootDir>/jest-globals.js"
60 | ],
61 | "globals": {
62 | "window": {}
63 | }
64 | },
65 | "dependencies": {},
66 | "peerDependencies": {
67 | "@tensorflow/tfjs-core": "^1.2.5"
68 | },
69 | "devDependencies": {
70 | "@babel/core": "^7.10.2",
71 | "@babel/plugin-proposal-class-properties": "^7.10.1",
72 | "@babel/plugin-transform-runtime": "^7.10.1",
73 | "@babel/preset-env": "^7.10.2",
74 | "@babel/runtime": "^7.10.2",
75 | "@joseph184/rollup-plugin-node-builtins": "^2.1.4",
76 | "@rollup/plugin-babel": "^5.0.3",
77 | "@rollup/plugin-commonjs": "^13.0.0",
78 | "@rollup/plugin-json": "^4.1.0",
79 | "@rollup/plugin-node-resolve": "^8.0.1",
80 | "@tensorflow/tfjs-core": "^1.7.4",
81 | "auto-changelog": "^1.16.4",
82 | "babel-eslint": "^10.1.0",
83 | "babel-jest": "^24.9.0",
84 | "documentation": "^13.0.1",
85 | "eslint": "^6.8.0",
86 | "eslint-config-standard": "^14.1.1",
87 | "eslint-plugin-import": "^2.21.2",
88 | "eslint-plugin-node": "^11.1.0",
89 | "eslint-plugin-promise": "^4.2.1",
90 | "eslint-plugin-standard": "^4.0.1",
91 | "fetch-mock": "^9.10.1",
92 | "gh-pages": "^2.2.0",
93 | "husky": "^3.1.0",
94 | "jest": "^24.9.0",
95 | "long": "^4.0.0",
96 | "mock-socket": "^9.0.3",
97 | "np": "^5.2.1",
98 | "prettier": "^1.19.1",
99 | "pretty-quick": "^2.0.1",
100 | "randomfill": "^1.0.4",
101 | "regenerator-runtime": "^0.13.5",
102 | "rollup": "^2.16.1",
103 | "rollup-plugin-filesize": "^9.0.0",
104 | "rollup-plugin-peer-deps-external": "^2.2.2",
105 | "syft-proto": "github:openmined/syft-proto#v0.4.9"
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/test/data_channel_message.test.js:
--------------------------------------------------------------------------------
1 | import DataChannelMessage from '../src/data-channel-message';
2 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from '../src';
3 | import { randomFillSync } from 'crypto';
4 |
5 | describe('Data Channel Message', () => {
6 | afterEach(() => {
7 | jest.clearAllMocks();
8 | });
9 |
10 | test('can construct from ArrayBuffer', () => {
11 | const buf = new ArrayBuffer(100000);
12 | const message = new DataChannelMessage({ data: buf });
13 | expect(message.chunks).toBe(
14 | Math.ceil(100000 / WEBRTC_DATACHANNEL_CHUNK_SIZE)
15 | );
16 | expect(() => new DataChannelMessage({ data: {} })).toThrow();
17 | });
18 |
19 | test('can slice data to chunks', () => {
20 | const buf = new ArrayBuffer(100000);
21 | const message = new DataChannelMessage({ data: buf });
22 | const chunk = message.getChunk(0);
23 | expect(chunk.byteLength).toBe(WEBRTC_DATACHANNEL_CHUNK_SIZE);
24 | const info = DataChannelMessage.messageInfoFromBuf(chunk);
25 | expect(info.id).toBe(message.id);
26 | expect(info.chunks).toBe(message.chunks);
27 | expect(info.chunk).toBe(0);
28 | });
29 |
30 | test('can get info from chunk', () => {
31 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100);
32 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength);
33 | const messageOrig = new DataChannelMessage({ data: buf });
34 | const chunk1 = messageOrig.getChunk(0);
35 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1);
36 | expect(info1.id).toBe(messageOrig.id);
37 | expect(info1.chunks).toBe(messageOrig.chunks);
38 | expect(info1.chunk).toBe(0);
39 |
40 | new Uint8Array(chunk1)[0] = 123;
41 | const infoErr = DataChannelMessage.messageInfoFromBuf(chunk1);
42 | expect(infoErr).toBe(false);
43 | });
44 |
45 | test('can assemble full message from chunks', done => {
46 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100);
47 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength);
48 | const messageOrig = new DataChannelMessage({ data: buf });
49 | const chunk1 = messageOrig.getChunk(0);
50 | const chunk2 = messageOrig.getChunk(1);
51 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1);
52 |
53 | const messageAssembled = new DataChannelMessage({ id: info1.id });
54 | messageAssembled.once('ready', message => {
55 | expect(message.chunks).toBe(messageOrig.chunks);
56 | expect(message.size).toBe(messageOrig.size);
57 | const orig = new Uint8Array(messageOrig.data);
58 | const assembled = new Uint8Array(message.data);
59 | expect(orig.every((v, i) => assembled[i] === v)).toBe(true);
60 | done();
61 | });
62 | messageAssembled.addChunk(chunk1);
63 | messageAssembled.addChunk(chunk2);
64 | });
65 |
66 | test('should error on invalid chunks', () => {
67 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100);
68 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength);
69 | const messageOrig = new DataChannelMessage({ data: buf });
70 | const chunk1 = messageOrig.getChunk(0);
71 |
72 | // id doesn't match
73 | const messageAssembled = new DataChannelMessage({ id: 123 });
74 | expect(() => messageAssembled.addChunk(chunk1)).toThrow();
75 |
76 | // simply invalid chunk
77 | expect(() => messageAssembled.addChunk(new Uint8Array(3))).toThrow();
78 |
79 | // double chunk add
80 | const messageAssembled2 = new DataChannelMessage({ id: messageOrig.id });
81 | messageAssembled2.addChunk(chunk1);
82 | expect(() => messageAssembled2.addChunk(chunk1)).toThrow();
83 | });
84 | });
85 |
--------------------------------------------------------------------------------
/src/types/role.js:
--------------------------------------------------------------------------------
1 | import { getPbId, unbufferize } from '../protobuf';
2 | import ObjectRegistry from '../object-registry';
3 | import { NOT_ENOUGH_ARGS } from '../_errors';
4 |
5 | export class Role {
6 | constructor(
7 | id,
8 | actions = [],
9 | state = null,
10 | placeholders = {},
11 | input_placeholder_ids = [],
12 | output_placeholder_ids = [],
13 | tags = [],
14 | description = null
15 | ) {
16 | this.id = id;
17 | this.actions = actions;
18 | this.state = state;
19 | this.placeholders = placeholders;
20 | this.input_placeholder_ids = input_placeholder_ids;
21 | this.output_placeholder_ids = output_placeholder_ids;
22 | this.tags = tags;
23 | this.description = description;
24 | }
25 |
26 | static unbufferize(worker, pb) {
27 | let placeholdersArray = unbufferize(worker, pb.placeholders);
28 | let placeholders = {};
29 | for (let ph of placeholdersArray) {
30 | placeholders[ph.id] = ph;
31 | }
32 |
33 | return new Role(
34 | getPbId(pb.id),
35 | unbufferize(worker, pb.actions),
36 | unbufferize(worker, pb.state),
37 | placeholders,
38 | pb.input_placeholder_ids.map(getPbId),
39 | pb.output_placeholder_ids.map(getPbId),
40 | pb.tags,
41 | pb.description
42 | );
43 | }
44 |
45 | findPlaceholders(tagRegex) {
46 |     return Object.values(this.placeholders).filter(
47 | placeholder =>
48 | placeholder.tags && placeholder.tags.some(tag => tagRegex.test(tag))
49 | );
50 | }
51 |
52 | getInputPlaceholders() {
53 | return this.input_placeholder_ids.map(id => this.placeholders[id]);
54 | }
55 |
56 | getOutputPlaceholders() {
57 | return this.output_placeholder_ids.map(id => this.placeholders[id]);
58 | }
59 |
60 | /**
61 | * Execute the Role with given worker
62 | * @param {Syft} worker
63 | * @param data
64 | * @returns {Promise}
65 | */
66 | async execute(worker, ...data) {
67 | // Create local scope.
68 | const planScope = new ObjectRegistry();
69 | planScope.load(worker.objects);
70 |
71 | const inputPlaceholders = this.getInputPlaceholders(),
72 | outputPlaceholders = this.getOutputPlaceholders(),
73 | argsLength = inputPlaceholders.length;
74 |
75 | // If the number of arguments supplied does not match the number of arguments required...
76 | if (data.length !== argsLength)
77 | throw new Error(NOT_ENOUGH_ARGS(data.length, argsLength));
78 |
79 | // For each argument supplied, add them in scope
80 | data.forEach((datum, i) => {
81 | planScope.set(inputPlaceholders[i].id, datum);
82 | });
83 |
84 | // load state tensors to worker
85 | if (this.state && this.state.tensors) {
86 | this.state.placeholders.forEach((ph, idx) => {
87 | planScope.set(ph.id, this.state.tensors[idx]);
88 | });
89 | }
90 |
91 | // Execute the plan
92 | for (const action of this.actions) {
93 | // The result of the current operation
94 | const result = await action.execute(planScope);
95 |
96 |       // Place the result of the current operation into the plan scope under the first id in returnIds
97 | // All intermediate tensors will be garbage collected by default
98 | if (result) {
99 | if (action.returnIds.length > 0) {
100 | planScope.set(action.returnIds[0], result, true);
101 | } else if (action.returnPlaceholderIds.length > 0) {
102 | planScope.set(action.returnPlaceholderIds[0].id, result, true);
103 | }
104 | }
105 | }
106 |
107 |     // Resolve all of the requested result ids, as specified by the plan
108 | const resolvedResultingTensors = [];
109 | outputPlaceholders.forEach(placeholder => {
110 | resolvedResultingTensors.push(planScope.get(placeholder.id));
111 | // Do not gc output tensors
112 | planScope.setGc(placeholder.id, false);
113 | });
114 |
115 | // Cleanup intermediate plan variables.
116 | planScope.clear();
117 |
118 | // Return them to the worker
119 | return resolvedResultingTensors;
120 | }
121 | }
122 |
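A hypothetical execution sketch for the class above: `role` would come from an unbufferized Plan and `worker` is a Syft instance; the tensors and their shapes are illustrative only:

  import * as tf from '@tensorflow/tfjs-core';

  async function runRole(role, worker) {
    const xs = tf.tensor([[0, 1], [1, 0]]);
    const ys = tf.tensor([[1], [0]]);
    // The number of arguments must match role.input_placeholder_ids.length,
    // otherwise execute() throws NOT_ENOUGH_ARGS.
    const outputs = await role.execute(worker, xs, ys);
    return outputs; // resolved output placeholders, excluded from GC
  }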
--------------------------------------------------------------------------------
/examples/mnist/index.html:
--------------------------------------------------------------------------------
[The HTML markup of this file was stripped during extraction; only the text
content survives. Recoverable content: the page title and main heading are
"Syft.js MNIST Example", followed by this description: "This is a demo using
syft.js from OpenMined to train an MNIST model hosted on PyGrid using local
data and the Federated Learning approach. The data never leaves the browser.
Please open your Javascript console to see what's going on (instructions for
Chrome and Firefox)." The remaining stripped line ranges most likely held the
page's styles and script includes.]
--------------------------------------------------------------------------------
/src/types/computation-action.js:
--------------------------------------------------------------------------------
1 | import { unbufferize } from '../protobuf';
2 | import PointerTensor from './pointer-tensor';
3 | import { Placeholder, PlaceholderId } from './placeholder';
4 | import * as tf from '@tensorflow/tfjs-core';
5 | import { TorchParameter, TorchTensor } from './torch';
6 | import { CANNOT_FIND_COMMAND, MISSING_VARIABLE } from '../_errors';
7 |
8 | export class ComputationAction {
9 | constructor(command, target, args, kwargs, returnIds, returnPlaceholderIds) {
10 | this.command = command;
11 | this.target = target;
12 | this.args = args;
13 | this.kwargs = kwargs;
14 | this.returnIds = returnIds;
15 | this.returnPlaceholderIds = returnPlaceholderIds;
16 | }
17 |
18 | static unbufferize(worker, pb) {
19 | return new ComputationAction(
20 | pb.command,
21 | unbufferize(worker, pb[pb.target]),
22 | unbufferize(worker, pb.args),
23 | unbufferize(worker, pb.kwargs),
24 | unbufferize(worker, pb.return_ids),
25 | unbufferize(worker, pb.return_placeholder_ids)
26 | );
27 | }
28 |
29 | async execute(scope) {
30 |     // Helper that checks whether every PointerTensor/Placeholder inside "this.args" also exists as a tensor in scope
31 | const haveValuesForAllArgs = args => {
32 | let enoughInfo = true;
33 |
34 | args.forEach(arg => {
35 | if (
36 | (arg instanceof PointerTensor && !scope.has(arg.idAtLocation)) ||
37 | (arg instanceof Placeholder && !scope.has(arg.id)) ||
38 | (arg instanceof PlaceholderId && !scope.has(arg.id))
39 | ) {
40 | enoughInfo = false;
41 | }
42 | });
43 |
44 | return enoughInfo;
45 | };
46 |
47 | const toTFTensor = tensor => {
48 | if (tensor instanceof tf.Tensor) {
49 | return tensor;
50 | } else if (tensor instanceof TorchTensor) {
51 | return tensor.toTfTensor();
52 | } else if (tensor instanceof TorchParameter) {
53 | return tensor.tensor.toTfTensor();
54 | } else if (typeof tensor === 'number') {
55 | return tensor;
56 | }
57 | return null;
58 | };
59 |
60 | const getTensorByRef = reference => {
61 | let tensor = null;
62 | if (reference instanceof PlaceholderId) {
63 | tensor = scope.get(reference.id);
64 | } else if (reference instanceof Placeholder) {
65 | tensor = scope.get(reference.id);
66 | } else if (reference instanceof PointerTensor) {
67 | tensor = scope.get(reference.idAtLocation);
68 | }
69 | tensor = toTFTensor(tensor);
70 | return tensor;
71 | };
72 |
73 |     // Helper that resolves all operable tensors referenced from "this.args"
74 | const pullTensorsFromArgs = args => {
75 | const resolvedArgs = [];
76 |
77 | args.forEach(arg => {
78 | const tensorByRef = getTensorByRef(arg);
79 | if (tensorByRef) {
80 | resolvedArgs.push(toTFTensor(tensorByRef));
81 | } else {
82 | // Try to convert to tensor.
83 | const tensor = toTFTensor(arg);
84 | if (tensor !== null) {
85 | resolvedArgs.push(toTFTensor(arg));
86 | } else {
87 | // Keep as is.
88 | resolvedArgs.push(arg);
89 | }
90 | }
91 | });
92 |
93 | return resolvedArgs;
94 | };
95 |
96 | //worker.logger.log(`Given command: ${this.command}, converted command: ${command} + ${JSON.stringify(preArgs)} + ${JSON.stringify(postArgs)}`);
97 |
98 | const args = this.args;
99 | let self = null;
100 |
101 | if (this.target) {
102 | // resolve "self" if it's present
103 | self = getTensorByRef(this.target);
104 | if (!self) {
105 | throw new Error(MISSING_VARIABLE());
106 | }
107 | }
108 |
109 | if (!haveValuesForAllArgs(args)) {
110 | throw new Error(MISSING_VARIABLE());
111 | }
112 |
113 | const resolvedArgs = pullTensorsFromArgs(args);
114 | const functionName = this.command.split('.').pop();
115 |
116 | if (self) {
117 | if (!(functionName in self)) {
118 | throw new Error(CANNOT_FIND_COMMAND(`tensor.${functionName}`));
119 | } else {
120 | return self[functionName](...resolvedArgs);
121 | }
122 | }
123 |
124 | if (!(functionName in tf)) {
125 | throw new Error(CANNOT_FIND_COMMAND(functionName));
126 | } else {
127 | return tf[functionName](...resolvedArgs, ...Object.values(this.kwargs));
128 | }
129 | }
130 | }
131 |
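An illustration of the dispatch rule in execute(): the command keeps only the last segment of its dotted name, then resolves either as a method on the target tensor or as a tf.* function (the command name here is an example):

  const functionName = 'torch.abs'.split('.').pop(); // 'abs'
  // With a resolved `self` tensor the action runs as a method call:
  //   self.abs(...resolvedArgs)
  // Without a target it falls back to the tf namespace, with kwargs appended:
  //   tf.abs(...resolvedArgs, ...Object.values(kwargs))
  // CANNOT_FIND_COMMAND is thrown when neither defines the name.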
--------------------------------------------------------------------------------
/src/syft-webrtc.js:
--------------------------------------------------------------------------------
1 | /**
2 | * NOTE: This is temporary file to hold WebRTC code that needs to be refactored
3 | * during implementation of Protocols functionality in syft.js.
4 | * Do not use this file as it will be removed in the future.
5 | */
6 | /* istanbul ignore file */
7 | import {
8 | GET_PROTOCOL,
9 | SOCKET_STATUS,
10 | WEBRTC_INTERNAL_MESSAGE,
11 | WEBRTC_JOIN_ROOM,
12 | WEBRTC_PEER_CONFIG,
13 | WEBRTC_PEER_LEFT,
14 | WEBRTC_PEER_OPTIONS
15 | } from './_constants';
16 | import { protobuf, unserialize } from './protobuf';
17 | import Socket from './sockets';
18 | import WebRTCClient from './webrtc';
19 |
20 | export class SyftWebrtc {
21 | /* ----- SOCKET COMMUNICATION ----- */
22 | // TODO refactor into grid client class
23 |
24 | // To create a socket connection internally and externally
25 | createSocketConnection(url) {
26 | if (!url) return;
27 | if (!this.verbose) {
28 | url = url.replace('ws://', 'wss://');
29 | }
30 | // When a socket connection is opened...
31 | const onOpen = event => {
32 | this.observer.broadcast(SOCKET_STATUS, {
33 | connected: true,
34 | event
35 | });
36 | };
37 |
38 | // When a socket connection is closed...
39 | const onClose = event => {
40 | this.observer.broadcast(SOCKET_STATUS, {
41 | connected: false,
42 | event
43 | });
44 | };
45 |
46 | // When a socket message is received...
47 | const onMessage = event => {
48 | const { type, data } = event;
49 |
50 | if (type === GET_PROTOCOL) {
51 | if (data.error) {
52 | this.logger.log(
53 | 'There was an error getting the protocol you requested',
54 | data.error
55 | );
56 |
57 | return data;
58 | }
59 |
60 | // Save our workerId if we don't already have it (also for the socket connection)
61 | this.workerId = data.worker.workerId;
62 | this.socket.workerId = this.workerId;
63 |
64 | // Save our scopeId if we don't already have it
65 | this.scopeId = data.worker.scopeId;
66 |
67 | // Save our role
68 | this.role = data.worker.role;
69 |
70 | // Save the other participant workerId's
71 | this.participants = data.participants;
72 |
73 | // Save the protocol and plan assignment after having Serde detail them
74 | let detailedProtocol;
75 | let detailedPlan;
76 | detailedProtocol = unserialize(
77 | null,
78 | data.protocol,
79 | protobuf.syft_proto.execution.v1.Protocol
80 | );
81 | detailedPlan = unserialize(
82 | null,
83 | data.plan,
84 | protobuf.syft_proto.execution.v1.Plan
85 | );
86 |
87 | this.protocol = detailedProtocol;
88 | this.plan = detailedPlan;
89 |
90 | return this.plan;
91 | } else if (type === WEBRTC_INTERNAL_MESSAGE) {
92 | this.rtc.receiveInternalMessage(data);
93 | } else if (type === WEBRTC_JOIN_ROOM) {
94 | this.rtc.receiveNewPeer(data);
95 | } else if (type === WEBRTC_PEER_LEFT) {
96 | this.rtc.removePeer(data.workerId);
97 | }
98 | };
99 |
100 | this.socket = new Socket({
101 | url,
102 | workerId: this.workerId,
103 | onOpen,
104 | onClose,
105 | onMessage
106 | });
107 | }
108 |
109 | // To close the socket connection with the grid
110 | disconnectFromGrid() {
111 | this.socket.stop();
112 | }
113 |
114 | /* ----- WEBRTC ----- */
115 |
116 |   // To create a WebRTC client over the existing socket connection
117 | createWebRTCClient(peerConfig, peerOptions) {
118 |     // If we don't have a socket connection, we can't create the WebRTCClient
119 | if (!this.socket) return;
120 |
121 | // The default STUN/TURN servers to use for NAT traversal
122 | if (!peerConfig) peerConfig = WEBRTC_PEER_CONFIG;
123 |
124 | // Some standard options for establishing peer connections
125 | if (!peerOptions) peerOptions = WEBRTC_PEER_OPTIONS;
126 |
127 | this.rtc = new WebRTCClient({
128 | peerConfig,
129 | peerOptions,
130 | socket: this.socket
131 | });
132 |
133 | const onDataMessage = data => {
134 | this.logger.log(`Data message is received from ${data.worker_id}`, data);
135 | };
136 | this.rtc.on('message', onDataMessage);
137 | }
138 |
139 | connectToParticipants() {
140 | this.rtc.start(this.workerId, this.scopeId);
141 | }
142 |
143 | disconnectFromParticipants() {
144 | this.rtc.stop();
145 | }
146 |
147 | sendToParticipants(data, to) {
148 | this.rtc.sendMessage(data, to);
149 | }
150 | }
151 |
--------------------------------------------------------------------------------
/src/data-channel-message.js:
--------------------------------------------------------------------------------
1 | import EventObserver from './events';
2 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from './_constants';
3 |
4 | export default class DataChannelMessage {
5 | static chunkHeaderSign = 0xff00;
6 | static chunkHeaderLength = 10;
7 |
8 | constructor({ data, id, worker_id }) {
9 | if (data === undefined) {
10 | this.data = new ArrayBuffer(0);
11 | } else if (typeof data === 'string') {
12 | this.data = new TextEncoder().encode(data).buffer;
13 | } else if (data instanceof ArrayBuffer) {
14 | this.data = data;
15 | } else {
16 | throw new Error('Message type is not supported');
17 | }
18 |
19 | this.worker_id = worker_id;
20 | this.size = this.data.byteLength;
21 | this.dataChunks = [];
22 | this.id = id || Math.floor(Math.random() * 0xffffffff);
23 | this.observer = new EventObserver();
24 | this.chunks = Math.ceil(
25 | this.size /
26 | (WEBRTC_DATACHANNEL_CHUNK_SIZE - DataChannelMessage.chunkHeaderLength)
27 | );
28 | this.makeChunkHeader(0);
29 | }
30 |
31 | /**
32 | * Creates chunk header for given chunk index
33 | * @param {number} chunk Chunk index
34 | * @returns {ArrayBuffer}
35 | */
36 | makeChunkHeader(chunk) {
37 | if (this.chunkHeader === undefined) {
38 | this.chunkHeader = new ArrayBuffer(DataChannelMessage.chunkHeaderLength);
39 |
40 | const view = new DataView(this.chunkHeader);
41 |
42 | view.setUint16(0, DataChannelMessage.chunkHeaderSign);
43 | view.setUint32(2, this.id);
44 | view.setUint16(6, this.chunks);
45 | view.setUint16(8, chunk);
46 | } else {
47 | const view = new DataView(this.chunkHeader);
48 |
49 | view.setUint16(8, chunk);
50 | }
51 |
52 | return this.chunkHeader;
53 | }
54 |
55 | once(event, func) {
56 | this.observer.subscribe(event, data => {
57 | this.observer.unsubscribe(event);
58 |
59 | func(data);
60 | });
61 | }
62 |
63 | /**
64 | * Gets chunk info from header
65 | * @param {ArrayBuffer} buf
66 | */
67 | static messageInfoFromBuf(buf) {
68 | let view;
69 |
70 | try {
71 | view = new DataView(buf);
72 | if (view.getUint16(0) !== DataChannelMessage.chunkHeaderSign) {
73 | return false;
74 | }
75 | } catch (e) {
76 | return false;
77 | }
78 |
79 | return {
80 | id: view.getUint32(2),
81 | chunks: view.getUint16(6),
82 | chunk: view.getUint16(8)
83 | };
84 | }
85 |
86 | /**
87 | * Adds chunk for further assembly
88 | * @param {ArrayBuffer} buf
89 | */
90 | addChunk(buf) {
91 | const info = DataChannelMessage.messageInfoFromBuf(buf);
92 |
93 | if (info === false) {
94 |       throw new Error(`Not a valid chunk`);
95 | }
96 |
97 | if (this.id !== info.id) {
98 | throw new Error(
99 | `Trying to add chunk from different message: ${this.id} != ${info.id}`
100 | );
101 | }
102 |
103 | if (this.dataChunks[info.chunk] !== undefined) {
104 |       throw new Error(`Duplicate chunk ${info.chunk} in message ${this.id}`);
105 | }
106 |
107 | this.dataChunks[info.chunk] = buf.slice(
108 | DataChannelMessage.chunkHeaderLength
109 | );
110 |
111 | if (this.dataChunks.length === info.chunks) {
112 | this.assemble();
113 | }
114 | }
115 |
116 | /**
117 | * Concatenate data pieces
118 | */
119 | assemble() {
120 | let size = 0;
121 |
122 | for (let chunk of this.dataChunks) {
123 | size += chunk.byteLength;
124 | }
125 |
126 | const data = new Uint8Array(size);
127 | let offset = 0;
128 |
129 | for (let chunk of this.dataChunks) {
130 | data.set(new Uint8Array(chunk), offset);
131 | offset += chunk.byteLength;
132 | }
133 |
134 | this.chunks = this.dataChunks.length;
135 | this.size = size;
136 | this.data = data.buffer;
137 |
138 | // Clean up
139 | this.dataChunks = [];
140 |
141 | // Emit event when done
142 | this.observer.broadcast('ready', this);
143 | }
144 |
145 | /**
146 | * Slice a piece of message and add a header
147 | * @param {number} num
148 | * @returns {ArrayBuffer}
149 | */
150 | getChunk(num) {
151 | const start =
152 | num *
153 | (WEBRTC_DATACHANNEL_CHUNK_SIZE - DataChannelMessage.chunkHeaderLength);
154 | const end = Math.min(
155 | start +
156 | WEBRTC_DATACHANNEL_CHUNK_SIZE -
157 | DataChannelMessage.chunkHeaderLength,
158 | this.size
159 | );
160 | const chunk = new Uint8Array(
161 | DataChannelMessage.chunkHeaderLength + end - start
162 | );
163 | const header = this.makeChunkHeader(num);
164 |
165 | chunk.set(new Uint8Array(header), 0);
166 | chunk.set(
167 | new Uint8Array(this.data.slice(start, end)),
168 | DataChannelMessage.chunkHeaderLength
169 | );
170 |
171 | return chunk.buffer;
172 | }
173 | }
174 |
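For reference, the 10-byte chunk header that makeChunkHeader writes and messageInfoFromBuf reads decodes like this (`someChunk` is a hypothetical chunk ArrayBuffer):

  //   bytes 0-1  uint16  header sign (0xff00)
  //   bytes 2-5  uint32  message id
  //   bytes 6-7  uint16  total number of chunks
  //   bytes 8-9  uint16  this chunk's index
  const view = new DataView(someChunk);
  const info = {
    id: view.getUint32(2),
    chunks: view.getUint16(6),
    chunk: view.getUint16(8)
  };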
--------------------------------------------------------------------------------
/test/sockets.test.js:
--------------------------------------------------------------------------------
1 | import { SOCKET_PING } from '../src/_constants';
2 | import { WebSocket, Server } from 'mock-socket';
3 |
4 | import Socket from '../src/sockets';
5 |
6 | global.WebSocket = WebSocket;
7 |
8 | const url = 'ws://localhost:8080/';
9 |
10 | // Create a promise that is resolved when the event is triggered.
11 | const makeEventPromise = (emitter, event) => {
12 | let resolver;
13 |
14 | const promise = new Promise(resolve => (resolver = resolve));
15 | emitter.on(event, data => resolver(data));
16 |
17 | return promise;
18 | };
19 |
20 | describe('Sockets', () => {
21 | let mockServer;
22 |
23 | beforeEach(() => {
24 | mockServer = new Server(url);
25 | mockServer.connected = makeEventPromise(mockServer, 'connection');
26 | });
27 |
28 | afterEach(() => {
29 | mockServer.close();
30 | });
31 |
32 | test('sends keep-alive messages automatically', async () => {
33 | const keepAliveTimeout = 300,
34 | expectedMessagesCount = 3,
35 | messages = [],
36 | expectedTypes = [];
37 |
38 | // Creating a socket will open connection and start keep-alive pings.
39 | new Socket({ url, keepAliveTimeout });
40 |
41 | const serverSocket = await mockServer.connected;
42 |
43 | serverSocket.on('message', message => messages.push(JSON.parse(message)));
44 |
45 | await new Promise(done =>
46 | setTimeout(
47 | done,
48 | keepAliveTimeout * expectedMessagesCount + keepAliveTimeout / 2
49 | )
50 | );
51 |
52 | // One keep-alive message is sent right after connection, hence +1.
53 | expect(messages).toHaveLength(expectedMessagesCount + 1);
54 |
55 | for (let i = 0; i < expectedMessagesCount + 1; i++) {
56 | expectedTypes.push(SOCKET_PING);
57 | }
58 |
59 | expect(messages.map(message => message['type'])).toEqual(expectedTypes);
60 | });
61 |
62 | test('triggers onOpen event', async () => {
63 | const onOpen = jest.fn();
64 |
65 | new Socket({ url, onOpen });
66 |
67 | await mockServer.connected;
68 |
69 | expect(onOpen).toHaveBeenCalledTimes(1);
70 | });
71 |
72 | test('triggers onClose event', async () => {
73 | const closed = makeEventPromise(mockServer, 'close'),
74 | onClose = jest.fn(),
75 | mySocket = new Socket({
76 | url,
77 | onClose
78 | });
79 |
80 | await mockServer.connected;
81 |
82 | mySocket.stop();
83 |
84 | await closed;
85 |
86 | expect(onClose).toHaveBeenCalledTimes(1);
87 | expect(mySocket.timerId).toBeNull();
88 | });
89 |
90 | test('sends data correctly', async () => {
91 | const testReqType = 'test',
92 | testReqData = { blob: 1 },
93 | testResponse = { response: 'test' },
94 | testworkerId = 'test-worker',
95 | mySocket = new Socket({
96 | workerId: testworkerId,
97 | url,
98 | onMessage: data => data
99 | });
100 |
101 | const serverSocket = await mockServer.connected;
102 |
103 | // Skip first keep-alive message.
104 | await makeEventPromise(serverSocket, 'message');
105 |
106 | const responsePromise = mySocket.send(testReqType, testReqData);
107 | const message = await makeEventPromise(serverSocket, 'message');
108 |
109 | serverSocket.send(JSON.stringify(testResponse));
110 |
111 | const response = await responsePromise;
112 |
113 | expect(JSON.parse(message)).toEqual({
114 | type: testReqType,
115 | data: testReqData
116 | });
117 | expect(response).toEqual(testResponse);
118 | });
119 |
120 | test('returns error when .send() fails', async () => {
121 | const mySocket = new Socket({
122 | url,
123 | onMessage: data => data
124 | });
125 |
126 | const serverSocket = await mockServer.connected;
127 |
128 | // Skip first keep-alive message.
129 | await makeEventPromise(serverSocket, 'message');
130 |
131 | const responsePromise = mySocket.send('test', {});
132 |
133 | mockServer.simulate('error');
134 |
135 | expect.assertions(1);
136 |
137 | try {
138 | await responsePromise;
139 | } catch (e) {
140 | expect(e).toBeDefined();
141 | }
142 | });
143 |
144 | test('disconnects from server after .stop()', async () => {
145 | const mySocket = new Socket({
146 | url
147 | });
148 |
149 | await mockServer.connected;
150 |
151 | expect(mockServer.clients()).toHaveLength(1);
152 |
153 | mySocket.stop();
154 |
155 | await new Promise(done => setTimeout(done, 100));
156 |
157 | expect(mockServer.clients()).toHaveLength(0);
158 | });
159 |
160 | test('triggers onMessage event', async () => {
161 | const testResponse = { response: 'test' },
162 | testworkerId = 'test-worker',
163 | onMessage = jest.fn(message => message),
164 | mySocket = new Socket({
165 | workerId: testworkerId,
166 | url,
167 | onMessage: onMessage
168 | });
169 |
170 | const serverSocket = await mockServer.connected;
171 |
172 | // Skip first keep-alive message.
173 | await makeEventPromise(serverSocket, 'message');
174 |
175 | serverSocket.on('message', () => {
176 | serverSocket.send(JSON.stringify(testResponse));
177 | });
178 |
179 | await mySocket.send('test1', {});
180 | await mySocket.send('test2', {});
181 |
182 | expect(onMessage).toHaveBeenCalledTimes(2);
183 | expect(onMessage).toHaveBeenLastCalledWith(testResponse);
184 | });
185 | });
186 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ### Changelog
2 |
3 | All notable changes to this project will be documented in this file. Dates are displayed in UTC.
4 |
5 | Generated by [`auto-changelog`](https://github.com/CookPete/auto-changelog).
6 |
7 | #### [v1.5.0](https://github.com/OpenMined/syft.js/compare/v1.4.4...v1.5.0)
8 |
9 | > 25 July 2019
10 |
11 | - Major refactor to build and testing systems [`#42`](https://github.com/OpenMined/syft.js/pull/42)
12 | - Upgrading the project to a modern development workflow [`b584677`](https://github.com/OpenMined/syft.js/commit/b584677fe1530e8114463f577dc4e5e0e51432f5)
13 | - Upgrading tensorflow [`977ef41`](https://github.com/OpenMined/syft.js/commit/977ef41c929d1bbb9dab5f956c497d128b7d0ff7)
14 | - Generating autochangelog [`ffed4a2`](https://github.com/OpenMined/syft.js/commit/ffed4a2dcc1647d37ef2f3acc9734592db835b09)
15 |
16 | #### [v1.4.4](https://github.com/OpenMined/syft.js/compare/v1.4.3...v1.4.4)
17 |
18 | > 17 October 2018
19 |
20 | - Adding a result_id to operations [`a38af02`](https://github.com/OpenMined/syft.js/commit/a38af02a9b3897ef2172296a1755bf1e2a2626e0)
21 |
22 | #### [v1.4.3](https://github.com/OpenMined/syft.js/compare/v1.4.2...v1.4.3)
23 |
24 | > 14 October 2018
25 |
26 | - GitBook: [1.4.2] 5 pages modified [`#41`](https://github.com/OpenMined/syft.js/pull/41)
27 | - GitBook: [master] 4 pages modified [`522e928`](https://github.com/OpenMined/syft.js/commit/522e9282d8df5f65aaee7dc4dafe3d28c15f10fd)
28 | - Delete api-documentation.md [`625a623`](https://github.com/OpenMined/syft.js/commit/625a6232edcc81cd5acbc1e4cce3f569fcff3152)
29 | - Delete guide.md [`5982dbd`](https://github.com/OpenMined/syft.js/commit/5982dbde60aebe2afac6b263cfdfea1973bde4e5)
30 |
31 | #### v1.4.2
32 |
33 | > 14 October 2018
34 |
35 | - Tensor removal issue [`#40`](https://github.com/OpenMined/syft.js/pull/40)
36 | - TensorFlow bundling [`#38`](https://github.com/OpenMined/syft.js/pull/38)
37 | - Improve build system [`#35`](https://github.com/OpenMined/syft.js/pull/35)
38 | - Add tensorflow [`#31`](https://github.com/OpenMined/syft.js/pull/31)
39 | - Add tensorflow [`#29`](https://github.com/OpenMined/syft.js/pull/29)
40 | - added preliminary websocket support [`#22`](https://github.com/OpenMined/syft.js/pull/22)
41 | - Travis Support for Unit Tests [`#24`](https://github.com/OpenMined/syft.js/pull/24)
42 | - FloatTensor add test [`#23`](https://github.com/OpenMined/syft.js/pull/23)
43 | - created a node server [`#19`](https://github.com/OpenMined/syft.js/pull/19)
44 | - adding demo folder's primitives. [`#20`](https://github.com/OpenMined/syft.js/pull/20)
45 | - Unit Testing Suite [`#21`](https://github.com/OpenMined/syft.js/pull/21)
46 | - Rollup [`#18`](https://github.com/OpenMined/syft.js/pull/18)
47 | - A cleaned up version of examples to keep things organized [`#17`](https://github.com/OpenMined/syft.js/pull/17)
48 | - Add gpu.js as a dependency [`#16`](https://github.com/OpenMined/syft.js/pull/16)
49 | - added js to parse json and create a js object [`#15`](https://github.com/OpenMined/syft.js/pull/15)
50 | - Initial commit, got babel and prettier running [`#5`](https://github.com/OpenMined/syft.js/pull/5)
51 | - Create README.md [`#4`](https://github.com/OpenMined/syft.js/pull/4)
52 | - Master [`#3`](https://github.com/OpenMined/syft.js/pull/3)
53 | - Corrected 'Resources' links [`#1`](https://github.com/OpenMined/syft.js/pull/1)
54 | - Closes #39 [`#39`](https://github.com/OpenMined/syft.js/issues/39)
55 | - Adding a Webpack example [`#36`](https://github.com/OpenMined/syft.js/issues/36)
56 | - Closes #32 by adding sourcemaps and minification [`#32`](https://github.com/OpenMined/syft.js/issues/32)
57 | - clean up plus docs 1 [`b28f7a1`](https://github.com/OpenMined/syft.js/commit/b28f7a15a051ae6c15ccf523e6df8bb0fcbbd0ae)
58 | - re-init [`4936d3f`](https://github.com/OpenMined/syft.js/commit/4936d3ffd2a01da2432f758723df72a9853a7f33)
59 | - Docs [`5be80be`](https://github.com/OpenMined/syft.js/commit/5be80be0c45b79a45990f53f1b2bbd12ad7e2f49)
60 |
61 | #### [v0.0.1-1](https://github.com/OpenMined/syft.js/compare/v1.5.0...v0.0.1-1)
62 |
63 | > 29 December 2019
64 |
65 | - Update serde, fix with-node example [`#71`](https://github.com/OpenMined/syft.js/pull/71)
66 | - Revert merge conflict issues [`#69`](https://github.com/OpenMined/syft.js/pull/69)
67 | - Changing terminology from user to worker [`#68`](https://github.com/OpenMined/syft.js/pull/68)
68 | - TS type defs, closes #57 [`#67`](https://github.com/OpenMined/syft.js/pull/67)
69 | - Fix specific case in string serde [`#58`](https://github.com/OpenMined/syft.js/pull/58)
70 | - Closes #55 [`#56`](https://github.com/OpenMined/syft.js/pull/56)
71 | - Grid syft [`#54`](https://github.com/OpenMined/syft.js/pull/54)
72 | - Add unit tests for webrtc.js [`#52`](https://github.com/OpenMined/syft.js/pull/52)
73 | - Remove unused variables in socket.test.js [`#51`](https://github.com/OpenMined/syft.js/pull/51)
74 | - Add sockets.js tests and minor fixes [`#50`](https://github.com/OpenMined/syft.js/pull/50)
75 | - Adding Serde to syft.js [`#43`](https://github.com/OpenMined/syft.js/pull/43)
76 | - Merge pull request #67 from tisd/ts-type-defs [`#57`](https://github.com/OpenMined/syft.js/issues/57)
77 | - Merge pull request #56 from OpenMined/issue-55 [`#55`](https://github.com/OpenMined/syft.js/issues/55)
78 | - Closes #55 [`#55`](https://github.com/OpenMined/syft.js/issues/55)
79 | - Closes #53 [`#53`](https://github.com/OpenMined/syft.js/issues/53)
80 | - Working on with-grid watching issues [`c63ce1f`](https://github.com/OpenMined/syft.js/commit/c63ce1f8ff4c5eafe2c2862e359f6f3ad1164237)
81 | - Reverting back to NPM [`b1e8d39`](https://github.com/OpenMined/syft.js/commit/b1e8d390a851a21e4ec4d59ae80b3914f59915b9)
82 | - Switching back to NPM [`34f583e`](https://github.com/OpenMined/syft.js/commit/34f583e4ddc3cd46e61e34849bacc34bc5176a03)
83 |
--------------------------------------------------------------------------------
/test/protobuf.test.js:
--------------------------------------------------------------------------------
1 | import { protobuf, unserialize, getPbId, serialize } from '../src/protobuf';
2 | import { ObjectMessage } from '../src/types/message';
3 | import Protocol from '../src/types/protocol';
4 | import { Plan } from '../src/types/plan';
5 | import { State } from '../src/types/state';
6 | import { MNIST_PLAN, MNIST_MODEL_PARAMS, PROTOCOL } from './data/dummy';
7 | import { TorchTensor } from '../src/types/torch';
8 | import * as tf from '@tensorflow/tfjs-core';
9 | import { Placeholder } from '../src/types/placeholder';
10 |
11 | describe('Protobuf', () => {
12 | test('can unserialize an ObjectMessage', () => {
13 | const obj = unserialize(
14 | null,
15 | 'CjcKBwi91JDfnQISKgoECgICAxIHZmxvYXQzMrIBGAAAgD8AAABAAABAQAAAgEAAAKBAMzPDQEAE',
16 | protobuf.syft_proto.messaging.v1.ObjectMessage
17 | );
18 | expect(obj).toBeInstanceOf(ObjectMessage);
19 | });
20 |
21 | test('can unserialize a Protocol', () => {
22 | const protocol = unserialize(
23 | null,
24 | PROTOCOL,
25 | protobuf.syft_proto.execution.v1.Protocol
26 | );
27 | expect(protocol).toBeInstanceOf(Protocol);
28 | });
29 |
30 | test('can unserialize a Plan', () => {
31 | const plan = unserialize(
32 | null,
33 | MNIST_PLAN,
34 | protobuf.syft_proto.execution.v1.Plan
35 | );
36 | expect(plan).toBeInstanceOf(Plan);
37 | });
38 |
39 | test('can unserialize a State', () => {
40 | const state = unserialize(
41 | null,
42 | MNIST_MODEL_PARAMS,
43 | protobuf.syft_proto.execution.v1.State
44 | );
45 | expect(state).toBeInstanceOf(State);
46 | });
47 |
48 | test('can serialize a State', async () => {
49 | const placeholders = [
50 | new Placeholder('123', ['tag1', 'tag2'], 'placeholder')
51 | ];
52 | const tensors = [
53 | await TorchTensor.fromTfTensor(
54 | tf.tensor([
55 | [1.1, 2.2],
56 | [3.3, 4.4]
57 | ])
58 | )
59 | ];
60 | const state = new State(placeholders, tensors);
61 | const serialized = serialize(null, state);
62 |
63 | // unserialize back to check
64 | const unserState = unserialize(
65 | null,
66 | serialized,
67 | protobuf.syft_proto.execution.v1.State
68 | );
69 | expect(unserState).toBeInstanceOf(State);
70 | expect(unserState.id).toStrictEqual(state.id);
71 | expect(unserState.placeholders).toStrictEqual(placeholders);
72 | expect(
73 | tf
74 | .equal(unserState.tensors[0].toTfTensor(), tensors[0].toTfTensor())
75 | .all()
76 | .dataSync()[0]
77 | ).toBe(1);
78 | });
79 |
80 | test('can serialize TorchTensor', async () => {
81 | const tensor = tf.tensor([
82 | [1.1, 2.2],
83 | [3.3, 4.4]
84 | ]);
85 | const torchTensor = await TorchTensor.fromTfTensor(tensor);
86 | torchTensor.tags = ['tag1', 'tag2'];
87 | torchTensor.description = 'description of tensor';
88 |
89 | // serialize
90 | const bin = serialize(null, torchTensor);
91 | expect(bin).toBeInstanceOf(ArrayBuffer);
92 |
93 | // check unserialized matches to original
94 | const unserTorchTensor = unserialize(
95 | null,
96 | bin,
97 | protobuf.syft_proto.types.torch.v1.TorchTensor
98 | );
99 | expect(unserTorchTensor.shape).toStrictEqual(torchTensor.shape);
100 | expect(unserTorchTensor.dtype).toStrictEqual(torchTensor.dtype);
101 | expect(unserTorchTensor.contents).toStrictEqual(torchTensor.contents);
102 | // resulting TF tensors are equal
103 | expect(
104 | tf
105 | .equal(unserTorchTensor.toTfTensor(), tensor)
106 | .all()
107 | .dataSync()[0]
108 | ).toBe(1);
109 | expect(unserTorchTensor.tags).toStrictEqual(torchTensor.tags);
110 | expect(unserTorchTensor.description).toStrictEqual(torchTensor.description);
111 | });
145 |
146 | test('gets id from types.syft.Id', () => {
147 | const protocolWithIntId = protobuf.syft_proto.execution.v1.Protocol.fromObject(
148 | {
149 | id: {
150 | id_int: 123
151 | }
152 | }
153 | );
154 | const protocolWithStrId = protobuf.syft_proto.execution.v1.Protocol.fromObject(
155 | {
156 | id: {
157 | id_str: '321'
158 | }
159 | }
160 | );
161 | expect(getPbId(protocolWithIntId.id)).toBe('123');
162 | expect(getPbId(protocolWithStrId.id)).toBe('321');
163 | });
164 | });
165 |
--------------------------------------------------------------------------------
/examples/mnist/mnist.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs-core';
19 |
20 | export const IMAGE_H = 28;
21 | export const IMAGE_W = 28;
22 | const IMAGE_SIZE = IMAGE_H * IMAGE_W;
23 | const NUM_CLASSES = 10;
24 | const NUM_DATASET_ELEMENTS = 65000;
25 |
26 | const NUM_TRAIN_ELEMENTS = 55000;
27 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
28 |
29 | const MNIST_IMAGES_SPRITE_PATH =
30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
31 | const MNIST_LABELS_PATH =
32 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
33 |
34 | /**
35 |  * A class that fetches the sprited MNIST dataset and provides data as
36 | * tf.Tensors.
37 | */
38 | export class MnistData {
39 | constructor() {}
40 |
41 | async load() {
42 | // Make a request for the MNIST sprited image.
43 | const img = new Image();
44 | const canvas = document.createElement('canvas');
45 | const ctx = canvas.getContext('2d');
46 | const imgRequest = new Promise((resolve, reject) => {
47 | img.crossOrigin = '';
48 | img.onload = () => {
49 | img.width = img.naturalWidth;
50 | img.height = img.naturalHeight;
51 |
52 | const datasetBytesBuffer = new ArrayBuffer(
53 | NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4
54 | );
55 |
56 | const chunkSize = 5000;
57 | canvas.width = img.width;
58 | canvas.height = chunkSize;
59 |
60 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
61 | const datasetBytesView = new Float32Array(
62 | datasetBytesBuffer,
63 | i * IMAGE_SIZE * chunkSize * 4,
64 | IMAGE_SIZE * chunkSize
65 | );
66 | ctx.drawImage(
67 | img,
68 | 0,
69 | i * chunkSize,
70 | img.width,
71 | chunkSize,
72 | 0,
73 | 0,
74 | img.width,
75 | chunkSize
76 | );
77 |
78 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
79 |
80 | for (let j = 0; j < imageData.data.length / 4; j++) {
81 | // All channels hold an equal value since the image is grayscale, so
82 | // just read the red channel.
83 | datasetBytesView[j] = imageData.data[j * 4] / 255;
84 | }
85 | }
86 | this.datasetImages = new Float32Array(datasetBytesBuffer);
87 |
88 | resolve();
89 | };
90 | img.src = MNIST_IMAGES_SPRITE_PATH;
91 | });
92 |
93 | const labelsRequest = fetch(MNIST_LABELS_PATH);
94 | const [imgResponse, labelsResponse] = await Promise.all([
95 | imgRequest,
96 | labelsRequest
97 | ]);
98 |
99 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
100 |
101 |     // Slice the images and labels into train and test sets.
102 | this.trainImages = this.datasetImages.slice(
103 | 0,
104 | IMAGE_SIZE * NUM_TRAIN_ELEMENTS
105 | );
106 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
107 | this.trainLabels = this.datasetLabels.slice(
108 | 0,
109 | NUM_CLASSES * NUM_TRAIN_ELEMENTS
110 | );
111 | this.testLabels = this.datasetLabels.slice(
112 | NUM_CLASSES * NUM_TRAIN_ELEMENTS
113 | );
114 | }
115 |
116 | /**
117 | * Get all training data as a data tensor and a labels tensor.
118 | *
119 | * @returns
120 | * xs: The data tensor, of shape `[numTrainExamples, 784]`.
121 | * labels: The one-hot encoded labels tensor, of shape
122 | * `[numTrainExamples, 10]`.
123 | */
124 | getTrainData() {
125 | const xs = tf.tensor2d(this.trainImages, [
126 | this.trainImages.length / IMAGE_SIZE,
127 | IMAGE_H * IMAGE_W
128 | ]);
129 | const labels = tf.tensor2d(this.trainLabels, [
130 | this.trainLabels.length / NUM_CLASSES,
131 | NUM_CLASSES
132 | ]);
133 | return { xs, labels };
134 | }
135 |
136 | /**
137 |    * Get all test data as a data tensor and a labels tensor.
138 | *
139 | * @param {number} numExamples Optional number of examples to get. If not
140 | * provided,
141 | * all test examples will be returned.
142 | * @returns
143 | * xs: The data tensor, of shape `[numTestExamples, 784]`.
144 | * labels: The one-hot encoded labels tensor, of shape
145 | * `[numTestExamples, 10]`.
146 | */
147 | getTestData(numExamples = NUM_TEST_ELEMENTS) {
148 | let xs = tf.tensor2d(this.testImages, [
149 | this.testImages.length / IMAGE_SIZE,
150 | IMAGE_H * IMAGE_W
151 | ]);
152 | let labels = tf.tensor2d(this.testLabels, [
153 | this.testLabels.length / NUM_CLASSES,
154 | NUM_CLASSES
155 | ]);
156 |
157 | if (numExamples != null) {
158 | xs = xs.slice([0, 0], [numExamples, IMAGE_H * IMAGE_W]);
159 | labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
160 | }
161 | return { xs, labels };
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import { render } from 'react-dom';
3 | import * as tf from '@tensorflow/tfjs-core';
4 | import { Syft } from '@openmined/syft.js';
5 |
6 | import App from './app.js';
7 |
8 | // Define grid connection parameters
9 | const url = 'ws://localhost:5000';
10 | const modelName = 'bandit';
11 | const modelVersion = '1.0.0';
12 | const shouldRepeat = false;
13 |
14 | // Pick random values for the layout
15 | const pickValue = p => p[Math.floor(Math.random() * p.length)];
16 |
17 | // 2 * 2 * 3 * 2 = 24 possible configurations
18 | const appConfigPossibilities = {
19 | heroBackground: ['black', 'gradient'],
20 | buttonPosition: ['hero', 'vision'],
21 | buttonIcon: ['arrow', 'user', 'code'],
22 | buttonColor: ['blue', 'white']
23 | };
24 |
25 | // Final configuration for the app
26 | const appConfig = {
27 | heroBackground: pickValue(appConfigPossibilities.heroBackground),
28 | buttonPosition: pickValue(appConfigPossibilities.buttonPosition),
29 | buttonIcon: pickValue(appConfigPossibilities.buttonIcon),
30 | buttonColor: pickValue(appConfigPossibilities.buttonColor)
31 | };
32 |
33 | // Set up an event listener for the button when it's clicked
34 | // TODO: @maddie - Submit the diff for a positive button click here...
35 | const onButtonClick = () => {
36 | console.log(
37 | 'Clicked the button! Send a positive result for config',
38 | appConfig
39 | );
40 | };
41 |
42 | // Start React
43 | render(
44 | <App
45 | config={appConfig}
46 | onButtonClick={onButtonClick}
47 | start={() => startFL(url, modelName, modelVersion, shouldRepeat)}
48 | />,
49 | document.getElementById('root')
50 | );
51 |
52 | // Main start method
53 | const startFL = async (
54 | url,
55 | modelName,
56 | modelVersion,
57 | shouldRepeat,
58 | authToken = null
59 | ) => {
60 | const worker = new Syft({ url, authToken, verbose: true });
61 | const job = await worker.newJob({ modelName, modelVersion });
62 |
63 | job.start();
64 |
65 | job.on('accepted', async ({ model, clientConfig }) => {
66 | updateStatus('Accepted into cycle!');
67 |
68 | // TODO: @maddie - Replace all of this with the bandit code, but try to still use the same
69 | // updateAfterBatch and updateStatus calls... those are helpful for the user to see!
70 | // // Load MNIST data
71 | // await loadMnistDataset();
72 | // const trainDataset = mnist.getTrainData();
73 | // const data = trainDataset.xs;
74 | // const targets = trainDataset.labels;
75 |
76 | // // Prepare randomized indices for data batching
77 | // const indices = Array.from({ length: data.shape[0] }, (v, i) => i);
78 | // tf.util.shuffle(indices);
79 |
80 | // // Prepare train parameters
81 | // const batchSize = clientConfig.batch_size;
82 | // const lr = clientConfig.lr;
83 | // const numBatches = Math.ceil(data.shape[0] / batchSize);
84 |
85 | // // Calculate total number of model updates
86 | // // If neither option is specified, we fall back to one loop
87 | // // through all batches.
88 | // const maxEpochs = clientConfig.max_epochs || 1;
89 | // const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches;
90 | // const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches);
91 |
92 | // // Copy model to train it
93 | // let modelParams = [];
94 | // for (let param of model.params) {
95 | // modelParams.push(param.clone());
96 | // }
97 |
98 | // // Main training loop
99 | // for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
100 | // // Slice a batch
101 | // const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize);
102 | // const indicesBatch = indices.slice(
103 | // batch * batchSize,
104 | // batch * batchSize + chunkSize
105 | // );
106 | // const dataBatch = data.gather(indicesBatch);
107 | // const targetBatch = targets.gather(indicesBatch);
108 |
109 | // // Execute the plan and get updated model params back
110 | // let [loss, acc, ...updatedModelParams] = await job.plans[
111 | // 'training_plan'
112 | // ].execute(
113 | // job.worker,
114 | // dataBatch,
115 | // targetBatch,
116 | // chunkSize,
117 | // lr,
118 | // ...modelParams
119 | // );
120 |
121 | // // Use updated model params in the next cycle
122 | // for (let i = 0; i < modelParams.length; i++) {
123 | // modelParams[i].dispose();
124 | // modelParams[i] = updatedModelParams[i];
125 | // }
126 |
127 | // await updateAfterBatch({
128 | // epoch,
129 | // batch,
130 | // accuracy: await acc.array(),
131 | // loss: await loss.array()
132 | // });
133 |
134 | // batch++;
135 |
136 | // // Check if we're out of batches (end of epoch)
137 | // if (batch === numBatches) {
138 | // batch = 0;
139 | // epoch++;
140 | // }
141 |
142 | // // Free GPU memory
143 | // acc.dispose();
144 | // loss.dispose();
145 | // dataBatch.dispose();
146 | // targetBatch.dispose();
147 | // }
148 |
149 | // // Free GPU memory
150 | // data.dispose();
151 | // targets.dispose();
152 |
153 | // // TODO protocol execution
154 | // // job.protocols['secure_aggregation'].execute();
155 |
156 | // // Calc model diff
157 | // const modelDiff = await model.createSerializedDiff(modelParams);
158 |
159 | // // Report diff
160 | // await job.report(modelDiff);
161 | // updateStatus('Cycle is done!');
162 |
163 | // // Try again...
164 | // if (shouldRepeat) {
165 | // setTimeout(startFL, 1000, url, modelName, modelVersion, shouldRepeat, authToken);
166 | // }
167 | });
168 |
169 | job.on('rejected', ({ timeout }) => {
170 | // Handle the job rejection
171 | if (timeout) {
172 | const msUntilRetry = timeout * 1000;
173 |
174 | // Try to join the job again in "msUntilRetry" milliseconds
175 | updateStatus(`Rejected from cycle, retry in ${timeout}s`);
176 | setTimeout(job.start.bind(job), msUntilRetry);
177 | } else {
178 | updateStatus(
179 | `Rejected from cycle with no timeout, assuming model training is complete.`
180 | );
181 | }
182 | });
183 |
184 | job.on('error', err => {
185 | updateStatus(`Error: ${err.message}`);
186 | });
187 | };
188 |
189 | // Status update message
190 | const updateStatus = message => {
191 | console.log('STATUS', message);
192 | };
193 |
194 | // Log statistics after each batch
195 | const updateAfterBatch = async ({ epoch, batch, accuracy, loss }) => {
196 | console.log(
197 | `Epoch: ${epoch}`,
198 | `Batch: ${batch}`,
199 | `Accuracy: ${accuracy}`,
200 | `Loss: ${loss}`
201 | );
202 |
203 | await tf.nextFrame();
204 | };
205 |
--------------------------------------------------------------------------------
/src/speed-test.js:
--------------------------------------------------------------------------------
1 | import { randomFillSync } from 'randomfill';
2 |
3 | export class SpeedTest {
4 | constructor({
5 | downloadUrl,
6 | uploadUrl,
7 | pingUrl,
8 | maxUploadSizeMb = 64,
9 | maxTestTimeSec = 10
10 | }) {
11 | this.downloadUrl = downloadUrl;
12 | this.uploadUrl = uploadUrl;
13 | this.pingUrl = pingUrl;
14 | this.maxUploadSizeMb = maxUploadSizeMb;
15 | this.maxTestTimeSec = maxTestTimeSec;
16 |
17 | // Various settings to tune.
18 | this.bwAvgWindow = 5;
19 | this.bwLowJitterThreshold = 0.05;
20 | this.bwMaxLowJitterConsecutiveMeasures = 5;
21 | }
22 |
23 | async meterXhr(xhr, isUpload = false) {
24 | return new Promise((resolve, reject) => {
25 | let timeoutHandler = null,
26 | prevTime = 0,
27 | prevSize = 0,
28 | avgCollector = new AvgCollector({
29 | avgWindow: this.bwAvgWindow,
30 | lowJitterThreshold: this.bwLowJitterThreshold,
31 | maxLowJitterConsecutiveMeasures: this
32 | .bwMaxLowJitterConsecutiveMeasures
33 | });
34 |
35 | const req = isUpload ? xhr.upload : xhr;
36 |
37 | const finish = (error = null) => {
38 | if (timeoutHandler) {
39 | clearTimeout(timeoutHandler);
40 | }
41 |
42 | // clean up
43 | req.onprogress = null;
44 | req.onload = null;
45 | req.onerror = null;
46 | xhr.abort();
47 |
48 | if (!error) {
49 | resolve(avgCollector.getAvg());
50 | } else {
51 | reject(new Error(error));
52 | }
53 | };
54 |
55 | xhr.onreadystatechange = () => {
56 | if (xhr.readyState === 1) {
57 | // set speed test timeout
58 | timeoutHandler = setTimeout(finish, this.maxTestTimeSec * 1000);
59 | }
60 | };
61 |
62 | req.onprogress = e => {
63 | const // mbit
64 | size = (8 * e.loaded) / 1048576,
65 | // seconds
66 | time = Date.now() / 1000;
67 |
68 | if (!prevTime) {
69 | prevTime = time;
70 | prevSize = size;
71 | return;
72 | }
73 |
74 | let deltaSize = size - prevSize,
75 | deltaTime = time - prevTime,
76 | speed = deltaSize / deltaTime;
77 |
78 | if (deltaTime === 0 || !Number.isFinite(speed)) {
79 | prevTime = time;
80 | prevSize = size;
81 | return;
82 | }
83 |
84 | const canStop = avgCollector.collect(speed);
85 | if (canStop) {
86 | finish();
87 | }
88 |
89 | prevSize = size;
90 | prevTime = time;
91 | };
92 |
93 | req.onload = () => {
94 | finish();
95 | };
96 | req.onerror = e => {
97 | finish(e);
98 | };
99 | });
100 | }
101 |
102 | async getDownloadSpeed() {
103 | let xhr = new XMLHttpRequest();
104 | const result = this.meterXhr(xhr);
105 |
106 | xhr.open('GET', this.downloadUrl + '?' + Math.random(), true);
107 | xhr.send();
108 |
109 | return result;
110 | }
111 |
112 | async getUploadSpeed() {
113 | const xhr = new XMLHttpRequest();
114 | const result = this.meterXhr(xhr, true);
115 |
116 | // Create random bytes buffer.
117 | const buff = new Uint8Array(this.maxUploadSizeMb * 1024 * 1024);
118 | const maxRandomChunkSize = 65536;
119 | const chunkNum = Math.ceil(buff.byteLength / maxRandomChunkSize);
120 | for (
121 | let chunk = 0, offset = 0;
122 | chunk < chunkNum;
123 | chunk++, offset += maxRandomChunkSize
124 | ) {
125 | randomFillSync(
126 | buff,
127 | offset,
128 | Math.min(maxRandomChunkSize, buff.byteLength - offset)
129 | );
130 | }
131 |
132 | xhr.open('POST', this.uploadUrl, true);
133 | xhr.send(buff);
134 |
135 | return result;
136 | }
137 |
138 | async getPing() {
139 | return new Promise((resolve, reject) => {
140 | const avgCollector = new AvgCollector({});
141 | let currXhr;
142 | let timeoutHandler;
143 |
144 | const finish = (xhr, error = null) => {
145 | if (timeoutHandler) {
146 | clearTimeout(timeoutHandler);
147 | }
148 |
149 | // clean up
150 | xhr.onprogress = null;
151 | xhr.onload = null;
152 | xhr.onerror = null;
153 | xhr.abort();
154 |
155 | if (!error) {
156 | resolve(avgCollector.getAvg());
157 | } else {
158 | reject(new Error(error));
159 | }
160 | };
161 |
162 | const runPing = () => {
163 | const xhr = new XMLHttpRequest();
164 | currXhr = xhr;
165 | let startTime = Date.now();
166 |
167 | xhr.onload = () => {
168 | const ping = Date.now() - startTime;
169 | const canStop = avgCollector.collect(ping);
170 | if (canStop) {
171 | finish(xhr);
172 | } else {
173 | setTimeout(runPing, 0);
174 | }
175 | };
176 |
177 | xhr.onerror = e => {
178 | finish(xhr, e);
179 | };
180 |
181 | xhr.open('GET', this.pingUrl + '?' + Math.random(), true);
182 | xhr.send();
183 | };
184 |
185 | timeoutHandler = setTimeout(() => {
186 | finish(currXhr);
187 | }, this.maxTestTimeSec * 1000);
188 | runPing();
189 | });
190 | }
191 | }
192 |
193 | /**
194 | * Helper to average a series of values.
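*
* Usage sketch (hypothetical numbers): feed measurements into collect()
* until it returns true (i.e. readings have stabilized), then read the
* windowed average:
*
*   const collector = new AvgCollector({ avgWindow: 3 });
*   let stable = false;
*   for (const mbit of [10.1, 10.4, 10.2, 10.3, 10.2, 10.3]) {
*     stable = collector.collect(mbit);
*   }
*   const avgSpeed = collector.getAvg();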
195 | */
196 | class AvgCollector {
197 | constructor({
198 | avgWindow = 5,
199 | lowJitterThreshold = 0.05,
200 | maxLowJitterConsecutiveMeasures = 5
201 | }) {
202 | this.measuresCount = 0;
203 | this.prevAvg = 0;
204 | this.avg = 0;
205 | this.lowJitterConsecutiveMeasures = 0;
206 |
207 | this.avgWindow = avgWindow;
208 | this.lowJitterThreshold = lowJitterThreshold;
209 | this.maxLowJitterConsecutiveMeasures = maxLowJitterConsecutiveMeasures;
210 |
211 | }
212 |
213 | collect(value) {
214 | this.prevAvg = this.avg;
215 | const avgWindow = Math.min(this.measuresCount, this.avgWindow);
216 | this.avg = (this.avg * avgWindow + value) / (avgWindow + 1);
217 | this.measuresCount++;
218 |
219 | // Return true if measurements are stable.
220 | if (
221 | this.prevAvg > 0 &&
222 | this.avg < this.prevAvg * (1 + this.lowJitterThreshold) &&
223 | this.avg > this.prevAvg * (1 - this.lowJitterThreshold)
224 | ) {
225 | this.lowJitterConsecutiveMeasures++;
226 | } else {
227 | this.lowJitterConsecutiveMeasures = 0;
228 | }
229 |
230 | if (
231 | this.lowJitterConsecutiveMeasures >= this.maxLowJitterConsecutiveMeasures
232 | ) {
233 | return true;
234 | }
235 |
236 | return false;
237 | }
238 |
239 | getAvg() {
240 | return this.avg;
241 | }
242 | }
243 |
--------------------------------------------------------------------------------
/examples/mnist/index.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs-core';
2 | import { Syft } from '@openmined/syft.js';
3 | import { MnistData } from './mnist';
4 |
5 | const gridServer = document.getElementById('grid-server');
6 | const startButton = document.getElementById('start');
7 | let mnist = null;
8 |
9 | startButton.onclick = () => {
10 | setFLUI();
11 | const modelName = document.getElementById('model-id').value;
12 | const modelVersion = document.getElementById('model-version').value;
13 | const authToken = document.getElementById('auth-token').value;
14 | startFL(gridServer.value, modelName, modelVersion, authToken).catch(err => {
15 | updateStatus(`Error: ${err}`);
16 | });
17 | };
18 |
19 | /**
20 | * The main federated learning training routine
21 | * @param url PyGrid Url
22 | * @param modelName Federated learning model name hosted in PyGrid
23 | * @param modelVersion Federated learning model version
24 | * @returns {Promise<void>}
25 | */
26 | const startFL = async (url, modelName, modelVersion, authToken = null) => {
27 | const worker = new Syft({ url, authToken, verbose: true });
28 | const job = await worker.newJob({ modelName, modelVersion });
29 |
30 | job.start();
31 |
32 | job.on('accepted', async ({ model, clientConfig }) => {
33 | updateStatus('Accepted into cycle!');
34 |
35 | // Load MNIST data
36 | await loadMnistDataset();
37 | const trainDataset = mnist.getTrainData();
38 | const data = trainDataset.xs;
39 | const targets = trainDataset.labels;
40 |
41 | // Prepare randomized indices for data batching.
42 | const indices = Array.from({ length: data.shape[0] }, (v, i) => i);
43 | tf.util.shuffle(indices);
44 |
45 | // Prepare train parameters.
46 | const batchSize = clientConfig.batch_size;
47 | const lr = clientConfig.lr;
48 | const numBatches = Math.ceil(data.shape[0] / batchSize);
49 |
50 | // Calculate total number of model updates
51 | // If neither option is specified, we fall back to one loop
52 | // through all batches.
53 | const maxEpochs = clientConfig.max_epochs || 1;
54 | const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches;
55 | const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches);
56 |
57 | // Copy model to train it.
58 | let modelParams = [];
59 | for (let param of model.params) {
60 | modelParams.push(param.clone());
61 | }
62 |
63 | // Main training loop.
64 | for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
65 | // Slice a batch.
66 | const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize);
67 | const indicesBatch = indices.slice(
68 | batch * batchSize,
69 | batch * batchSize + chunkSize
70 | );
71 | const dataBatch = data.gather(indicesBatch);
72 | const targetBatch = targets.gather(indicesBatch);
73 |
74 | // Execute the plan and get updated model params back.
75 | let [loss, acc, ...updatedModelParams] = await job.plans[
76 | 'training_plan'
77 | ].execute(
78 | job.worker,
79 | dataBatch,
80 | targetBatch,
81 | chunkSize,
82 | lr,
83 | ...modelParams
84 | );
85 |
86 | // Use updated model params in the next cycle.
87 | for (let i = 0; i < modelParams.length; i++) {
88 | modelParams[i].dispose();
89 | modelParams[i] = updatedModelParams[i];
90 | }
91 |
92 | await updateUIAfterBatch({
93 | epoch,
94 | batch,
95 | accuracy: await acc.array(),
96 | loss: await loss.array()
97 | });
98 |
99 | batch++;
100 |
101 | // Check if we're out of batches (end of epoch).
102 | if (batch === numBatches) {
103 | batch = 0;
104 | epoch++;
105 | }
106 |
107 | // Free GPU memory.
108 | acc.dispose();
109 | loss.dispose();
110 | dataBatch.dispose();
111 | targetBatch.dispose();
112 | }
113 |
114 | // Free GPU memory.
115 | data.dispose();
116 | targets.dispose();
117 |
118 | // TODO protocol execution
119 | // job.protocols['secure_aggregation'].execute();
120 |
121 | // Calc model diff.
122 | const modelDiff = await model.createSerializedDiff(modelParams);
123 |
124 | // Report diff.
125 | await job.report(modelDiff);
126 | updateStatus('Cycle is done!');
127 |
128 | // Try again.
129 | if (doRepeat()) {
130 | setTimeout(startFL, 1000, url, modelName, modelVersion, authToken);
131 | }
132 | });
133 |
134 | job.on('rejected', ({ timeout }) => {
135 | // Handle the job rejection.
136 | if (timeout) {
137 | const msUntilRetry = timeout * 1000;
138 | // Try to join the job again in "msUntilRetry" milliseconds
139 | updateStatus(`Rejected from cycle, retry in ${timeout}s`);
140 | setTimeout(job.start.bind(job), msUntilRetry);
141 | } else {
142 | updateStatus(
143 | `Rejected from cycle with no timeout, assuming model training is complete.`
144 | );
145 | }
146 | });
147 |
148 | job.on('error', err => {
149 | updateStatus(`Error: ${err.message}`);
150 | });
151 | };
152 |
153 | /**
154 | * Loads MNIST dataset into global variable `mnist`.
155 | */
156 | const loadMnistDataset = async () => {
157 | if (!mnist) {
158 | updateStatus('Loading MNIST data...');
159 | mnist = new MnistData();
160 | await mnist.load();
161 | updateStatus('MNIST data loaded.');
162 | }
163 | };
164 |
165 | /**
166 | * Log message on the page.
167 | * @param message
168 | */
169 | const updateStatus = message => {
170 | const cont = document.getElementById('status');
171 | cont.innerHTML = message + ' ' + cont.innerHTML;
172 | };
173 |
174 | /**
175 | * Initializes loss & accuracy plots.
176 | */
177 | const setFLUI = () => {
178 | Plotly.newPlot(
179 | 'loss_graph',
180 | [{ y: [], mode: 'lines', line: { color: '#80CAF6' } }],
181 | { title: 'Train Loss', showlegend: false },
182 | { staticPlot: true }
183 | );
184 |
185 | Plotly.newPlot(
186 | 'acc_graph',
187 | [{ y: [], mode: 'lines', line: { color: '#80CAF6' } }],
188 | { title: 'Train Accuracy', showlegend: false },
189 | { staticPlot: true }
190 | );
191 |
192 | document.getElementById('fl-training').style.display = 'table';
193 | };
194 |
195 | /**
196 | * Updates graphs after each batch.
197 | * @param epoch
198 | * @param batch
199 | * @param accuracy
200 | * @param loss
201 | * @returns {Promise<void>}
202 | */
203 | const updateUIAfterBatch = async ({ epoch, batch, accuracy, loss }) => {
204 | console.log(
205 | `Epoch: ${epoch}, Batch: ${batch}, Accuracy: ${accuracy}, Loss: ${loss}`
206 | );
207 | Plotly.extendTraces('loss_graph', { y: [[loss]] }, [0]);
208 | Plotly.extendTraces('acc_graph', { y: [[accuracy]] }, [0]);
209 | await tf.nextFrame();
210 | };
211 |
212 | const doRepeat = () => document.getElementById('worker-repeat').checked;
213 |
--------------------------------------------------------------------------------
/API-REFERENCE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ### Table of Contents
4 |
5 | - [Syft][1]
6 | - [Parameters][2]
7 | - [Examples][3]
8 | - [newJob][4]
9 | - [Parameters][5]
10 | - [Job][6]
11 | - [Properties][7]
12 | - [on][8]
13 | - [Parameters][9]
14 | - [start][10]
15 | - [Parameters][11]
16 | - [report][12]
17 | - [Parameters][13]
18 | - [Job#accepted][14]
19 | - [Properties][15]
20 | - [Job#rejected][16]
21 | - [Properties][17]
22 | - [Job#error][18]
23 | - [SyftModel][19]
24 | - [Properties][20]
25 | - [createSerializedDiff][21]
26 | - [Parameters][22]
27 | - [Plan][23]
28 | - [execute][24]
29 | - [Parameters][25]
30 |
31 | ## Syft
32 |
33 | Syft client for static federated learning.
34 |
35 | ### Parameters
36 |
37 | - `options` **[Object][26]**
38 | - `options.url` **[string][27]** Full URL to PyGrid app (`ws` and `http` schemas supported).
39 | - `options.verbose` **[boolean][28]** Whether to enable logging and allow unsecured PyGrid connection.
40 | - `options.authToken` **[string][27]** PyGrid authentication token.
41 | - `options.peerConfig` **[Object][26]** [not implemented] WebRTC peer config used with RTCPeerConnection.
42 |
43 | ### Examples
44 |
45 | ```javascript
46 | const client = new Syft({url: "ws://localhost:5000", verbose: true})
47 | const job = client.newJob({modelName: "mnist", modelVersion: "1.0.0"})
48 | job.on('accepted', async ({model, clientConfig}) => {
49 | // execute training
50 | const [...newParams] = await job.plans['...'].execute(...)
51 | const diff = await model.createSerializedDiff(newParams)
52 | await job.report(diff)
53 | })
54 | job.on('rejected', ({timeout}) => {
55 | // re-try later or stop
56 | })
57 | job.on('error', (err) => {
58 | // handle errors
59 | })
60 | job.start()
61 | ```
62 |
63 | ### newJob
64 |
65 | Authenticates the client against PyGrid and instantiates a new Job with the given options.
66 |
67 | #### Parameters
68 |
69 | - `options` **[Object][26]**
70 | - `options.modelName` **[string][27]** FL Model name.
71 | - `options.modelVersion` **[string][27]** FL Model version.
72 |
73 |
74 | - Throws **any** Error
75 |
76 | Returns **[Promise][29]<[Job][30]>**
77 |
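Because `newJob` authenticates against PyGrid first and throws on failure, it is worth wrapping in a `try`/`catch` (a minimal sketch; `client` is a `Syft` instance):

```javascript
try {
  const job = await client.newJob({ modelName: 'mnist', modelVersion: '1.0.0' });
  job.start();
} catch (err) {
  // authentication or connection error
}
```
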
78 | ## Job
79 |
80 | Job represents a single training cycle done by the client.
81 |
82 | ### Properties
83 |
84 | - `plans` **[Object][26]<[string][27], [Plan][31]>** Plans dictionary.
85 | - `protocols` **[Object][26]<[string][27], Protocol>** [not implemented] Protocols dictionary.
86 | - `model` **[SyftModel][32]** Model.
87 |
88 | ### on
89 |
90 | Registers an event listener.
91 |
92 | Available events: `accepted`, `rejected`, `error`.
93 |
94 | #### Parameters
95 |
96 | - `event` **[string][27]** Event name.
97 | - `handler` **[function][33]** Event handler.
98 |
99 | ### start
100 |
101 | Starts the Job, executing the following actions:
102 |
103 | - Meters connection speed to PyGrid.
104 | - Registers into training cycle on PyGrid.
105 | - Retrieves cycle and client parameters.
106 | - Downloads Plans, Model, Protocols.
107 | - Fires `accepted` event on success.
108 |
109 | #### Parameters
110 |
111 | - `options` **[Object][26]** (optional, default `{}`)
112 | - `options.skipGridSpeedTest` **[boolean][28]** When true, skips the speed test before requesting a cycle. (optional, default `false`)
113 |
114 | Returns **[Promise][29]<void>**
115 |
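For example, a client can skip the speed test when metering the connection is unnecessary (a minimal sketch; `job` is a `Job` obtained from `newJob`):

```javascript
// Request a training cycle immediately, without measuring
// ping/download/upload against PyGrid first.
job.start({ skipGridSpeedTest: true });
```
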
116 | ### report
117 |
118 | Submits the model diff to PyGrid.
119 |
120 | #### Parameters
121 |
122 | - `diff` **[ArrayBuffer][34]** Serialized difference between original and trained model parameters.
123 |
124 | Returns **[Promise][29]<void>**
125 |
126 | ## Job#accepted
127 |
128 | `accepted` event.
129 | Triggered when PyGrid accepts the client into training cycle.
130 |
131 | Type: [Object][26]
132 |
133 | ### Properties
134 |
135 | - `model` **[SyftModel][32]** Instance of SyftModel.
136 | - `clientConfig` **[Object][26]** Client configuration returned by PyGrid.
137 |
138 | ## Job#rejected
139 |
140 | `rejected` event.
141 | Triggered when PyGrid rejects the client.
142 |
143 | Type: [Object][26]
144 |
145 | ### Properties
146 |
147 | - `timeout` **([number][35] | null)** Time in seconds to wait before retrying. Empty when the FL model is no longer trainable.
148 |
149 | ## Job#error
150 |
151 | `error` event.
152 | Triggered for a variety of error conditions.
153 |
154 | ## SyftModel
155 |
156 | Model parameters as stored in PyGrid.
157 |
158 | ### Properties
159 |
160 | - `params` **[Array][36]<tf.Tensor>** Array of Model parameters.
161 |
162 | ### createSerializedDiff
163 |
164 | Calculates difference between 2 versions of the Model parameters
165 | and returns serialized `diff` that can be submitted to PyGrid.
166 |
167 | #### Parameters
168 |
169 | - `updatedModelParams` **[Array][36]<tf.Tensor>** Array of model parameters (tensors).
170 |
171 | Returns **[Promise][29]<[ArrayBuffer][34]>** Protobuf-serialized `diff`.
172 |
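A minimal sketch of the reporting flow, assuming `model` is the original `SyftModel` from the `accepted` event and `updatedParams` holds the trained parameter tensors:

```javascript
// Serialize the difference between the original and trained parameters...
const diff = await model.createSerializedDiff(updatedParams);
// ...and submit it to PyGrid.
await job.report(diff);
```
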
173 | ## Plan
174 |
175 | PySyft Plan.
176 |
177 | ### execute
178 |
179 | Executes the Plan and returns its output.
180 |
181 | The order, type and number of arguments must match the arguments defined in the PySyft Plan.
182 |
183 | #### Parameters
184 |
185 | - `worker` **[Syft][37]**
186 | - `data` **...(tf.Tensor | [number][35])**
187 |
188 | Returns **[Promise][29]<[Array][36]<tf.Tensor>>**
189 |
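A minimal sketch, assuming the hosted model defines a plan named `training_plan` that takes a data batch, targets, batch size, learning rate and the model parameters (as in the MNIST example):

```javascript
const [loss, acc, ...updatedParams] = await job.plans['training_plan'].execute(
  job.worker, // Syft instance
  dataBatch, // tf.Tensor
  targetBatch, // tf.Tensor
  batchSize, // number or tf.Tensor
  lr, // number or tf.Tensor
  ...modelParams // Array<tf.Tensor>
);
```
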
190 | [1]: #syft
191 |
192 | [2]: #parameters
193 |
194 | [3]: #examples
195 |
196 | [4]: #newjob
197 |
198 | [5]: #parameters-1
199 |
200 | [6]: #job
201 |
202 | [7]: #properties
203 |
204 | [8]: #on
205 |
206 | [9]: #parameters-2
207 |
208 | [10]: #start
209 |
210 | [11]: #parameters-3
211 |
212 | [12]: #report
213 |
214 | [13]: #parameters-4
215 |
216 | [14]: #jobaccepted
217 |
218 | [15]: #properties-1
219 |
220 | [16]: #jobrejected
221 |
222 | [17]: #properties-2
223 |
224 | [18]: #joberror
225 |
226 | [19]: #syftmodel
227 |
228 | [20]: #properties-3
229 |
230 | [21]: #createserializeddiff
231 |
232 | [22]: #parameters-5
233 |
234 | [23]: #plan
235 |
236 | [24]: #execute
237 |
238 | [25]: #parameters-6
239 |
240 | [26]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Object
241 |
242 | [27]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String
243 |
244 | [28]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean
245 |
246 | [29]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise
247 |
248 | [30]: #job
249 |
250 | [31]: #plan
251 |
252 | [32]: #syftmodel
253 |
254 | [33]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Statements/function
255 |
256 | [34]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/ArrayBuffer
257 |
258 | [35]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number
259 |
260 | [36]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array
261 |
262 | [37]: #syft
263 |
--------------------------------------------------------------------------------
/test/types/plan.test.js:
--------------------------------------------------------------------------------
1 | import { Plan } from '../../src/types/plan';
2 | import { State } from '../../src/types/state';
3 | import { TorchTensor } from '../../src/types/torch';
4 | import { Placeholder, PlaceholderId } from '../../src/types/placeholder';
5 | import * as tf from '@tensorflow/tfjs-core';
6 | import { protobuf, unserialize } from '../../src/protobuf';
7 | import Syft from '../../src/syft';
8 | import {
9 | MNIST_BATCH_SIZE,
10 | MNIST_LR,
11 | MNIST_BATCH_DATA,
12 | MNIST_PLAN,
13 | MNIST_MODEL_PARAMS,
14 | MNIST_UPD_MODEL_PARAMS,
15 | MNIST_LOSS,
16 | MNIST_ACCURACY,
17 | PLAN_WITH_STATE,
18 | BANDIT_SIMPLE_MODEL_PARAMS,
19 | BANDIT_SIMPLE_PLAN,
20 | BANDIT_THOMPSON_PLAN,
21 | BANDIT_THOMPSON_MODEL_PARAMS
22 | } from '../data/dummy';
23 | import { ComputationAction } from '../../src/types/computation-action';
24 | import { Role } from '../../src/types/role';
25 |
26 | describe('State', () => {
27 | test('can be properly constructed', () => {
28 | const ph = new Placeholder(123);
29 | const ts = new TorchTensor(
30 | 234,
31 | new Float32Array([1, 2, 3, 4]),
32 | [2, 2],
33 | 'float32'
34 | );
35 | const obj = new State([ph], [ts]);
36 | expect(obj.placeholders).toStrictEqual([ph]);
37 | expect(obj.tensors).toStrictEqual([ts]);
38 | });
39 | });
40 |
41 | describe('Plan', () => {
42 | test('can be properly constructed', () => {
43 | const phId1 = new PlaceholderId(555);
44 | const phId2 = new PlaceholderId(666);
45 | const ph1 = new Placeholder(555);
46 | const ph2 = new Placeholder(666);
47 | const action = new ComputationAction(
48 | 'torch.abs',
49 | null,
50 | [phId1],
51 | null,
52 | null,
53 | [phId2]
54 | );
55 | const state = new State([], []);
56 | const role = new Role(777, [action], state, [ph1, ph2], [phId1], [phId2]);
57 | const obj = new Plan(123, 'plan', role, ['tag1', 'tag2'], 'desc');
58 | expect(obj.id).toBe(123);
59 | expect(obj.name).toBe('plan');
60 | expect(obj.role.actions).toStrictEqual([action]);
61 | expect(obj.role.state).toBe(state);
62 | expect(obj.role.placeholders).toStrictEqual([ph1, ph2]);
63 | expect(obj.role.input_placeholder_ids).toStrictEqual([phId1]);
64 | expect(obj.role.output_placeholder_ids).toStrictEqual([phId2]);
65 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']);
66 | expect(obj.description).toStrictEqual('desc');
67 | });
68 |
69 | test('can be executed (with state)', async () => {
70 | const input = tf.tensor([
71 | [1, 2],
72 | [-30, -40]
73 | ]);
74 | // this is what plan contains
75 | const state = tf.tensor([4.2, 7.3]);
76 | const expected = tf.abs(tf.add(input, state));
77 | const plan = unserialize(
78 | null,
79 | PLAN_WITH_STATE,
80 | protobuf.syft_proto.execution.v1.Plan
81 | );
82 | const worker = new Syft({ url: 'dummy' });
83 | const result = await plan.execute(worker, input);
84 | expect(result[0]).toBeInstanceOf(tf.Tensor);
85 | expect(
86 | tf
87 | .equal(result[0], expected)
88 | .all()
89 | .dataSync()[0]
90 | ).toBe(1);
91 | });
92 |
93 | test('invalid args shape throws corresponding error', async () => {
94 | // PLAN_WITH_STATE plan contains input + A(2,2)
95 | // this should error with tf error about incompatible shapes
96 | const input = tf.ones([3, 3]);
97 | const plan = unserialize(
98 | null,
99 | PLAN_WITH_STATE,
100 | protobuf.syft_proto.execution.v1.Plan
101 | );
102 | const worker = new Syft({ url: 'dummy' });
103 | await expect(plan.execute(worker, input)).rejects.toThrow(
104 | 'Operands could not be broadcast together with shapes 3,3 and 2.'
105 | );
106 | });
107 |
108 | test('can be executed (MNIST example)', async () => {
109 | const plan = unserialize(
110 | null,
111 | MNIST_PLAN,
112 | protobuf.syft_proto.execution.v1.Plan
113 | );
114 |
115 | const worker = new Syft({ url: 'dummy' });
116 | const dataState = unserialize(
117 | null,
118 | MNIST_BATCH_DATA,
119 | protobuf.syft_proto.execution.v1.State
120 | );
121 | const [data, labels] = dataState.tensors;
122 | const lr = tf.tensor(MNIST_LR);
123 | const batchSize = tf.tensor(MNIST_BATCH_SIZE);
124 | const modelState = unserialize(
125 | null,
126 | MNIST_MODEL_PARAMS,
127 | protobuf.syft_proto.execution.v1.State
128 | );
129 | const modelParams = modelState.tensors;
130 | const [loss, acc, ...updModelParams] = await plan.execute(
131 | worker,
132 | data,
133 | labels,
134 | batchSize,
135 | lr,
136 | ...modelParams
137 | );
138 |
139 | const refUpdModelParamsState = unserialize(
140 | null,
141 | MNIST_UPD_MODEL_PARAMS,
142 | protobuf.syft_proto.execution.v1.State
143 | );
144 | const refUpdModelParams = refUpdModelParamsState.tensors.map(i =>
145 | i.toTfTensor()
146 | );
147 |
148 | expect(loss).toBeInstanceOf(tf.Tensor);
149 | expect(acc).toBeInstanceOf(tf.Tensor);
150 | expect(updModelParams).toHaveLength(refUpdModelParams.length);
151 | expect(loss.arraySync()).toStrictEqual(MNIST_LOSS);
152 | expect(acc.arraySync()).toStrictEqual(MNIST_ACCURACY);
153 |
154 | for (let i = 0; i < refUpdModelParams.length; i++) {
155 | // Check that resulting model params are close to pysyft reference
156 | let diff = refUpdModelParams[i].sub(updModelParams[i]);
157 | expect(
158 | diff
159 | .abs()
160 | .sum()
161 | .arraySync()
162 | ).toBeLessThan(1e-7);
163 | }
164 | });
165 |
166 | test('bandit (simple) example can be executed', async () => {
167 | const plan = unserialize(
168 | null,
169 | BANDIT_SIMPLE_PLAN,
170 | protobuf.syft_proto.execution.v1.Plan
171 | );
172 |
173 | const worker = new Syft({ url: 'dummy' });
174 | const modelState = unserialize(
175 | null,
176 | BANDIT_SIMPLE_MODEL_PARAMS,
177 | protobuf.syft_proto.execution.v1.State
178 | );
179 | const [means] = modelState.tensors;
180 |
181 | const reward = tf.tensor([0, 1, 0]);
182 | const n_so_far = tf.tensor([1, 1, 1]);
183 | const [newMeans] = await plan.execute(worker, reward, n_so_far, means);
184 |
185 | newMeans.print();
186 | });
187 |
188 | test('bandit (thompson) example can be executed', async () => {
189 | const plan = unserialize(
190 | null,
191 | BANDIT_THOMPSON_PLAN,
192 | protobuf.syft_proto.execution.v1.Plan
193 | );
194 |
195 | const worker = new Syft({ url: 'dummy' });
196 | const modelState = unserialize(
197 | null,
198 | BANDIT_THOMPSON_MODEL_PARAMS,
199 | protobuf.syft_proto.execution.v1.State
200 | );
201 | const [alphas, betas] = modelState.tensors;
202 | const reward = tf.tensor([0, 0, 0]);
203 | const samples = tf.tensor([0, 0, 1]);
204 | const [newAlphas, newBetas] = await plan.execute(
205 | worker,
206 | reward,
207 | samples,
208 | alphas,
209 | betas
210 | );
211 |
212 | newAlphas.print();
213 | newBetas.print();
214 | });
215 | });
216 |
--------------------------------------------------------------------------------
/test/data/generate-data.py:
--------------------------------------------------------------------------------
1 | # This python script generates values in dummy.js, for MNIST plan unit tests.
2 | # Script should be executed with appropriate PySyft version installed.
3 |
4 | import os
5 | import base64
6 |
7 | import torch as th
8 | from torch import jit
9 | from torch import nn
10 | from torchvision import datasets, transforms
11 |
12 | import syft as sy
13 | from syft.serde import protobuf
14 | from syft.execution.state import State
15 | from syft.execution.placeholder import PlaceHolder
16 | from syft.execution.translation import TranslationTarget
17 |
18 | sy.make_hook(globals())
19 | # force protobuf serialization for tensors
20 | hook.local_worker.framework = None
21 | th.random.manual_seed(1)
22 |
23 | def serialize_to_b64_pb(worker, obj):
24 | pb = protobuf.serde._bufferize(worker, obj)
25 | bin = pb.SerializeToString()
26 | return base64.b64encode(bin).decode('ascii')
27 |
28 |
29 | def tensors_to_state(tensors):
30 | return State(
31 | state_placeholders=[
32 | PlaceHolder().instantiate(t)
33 | for t in tensors
34 | ]
35 | )
36 |
37 | def set_model_params(module, params_list, start_param_idx=0):
38 | """ Set params list into model recursively
39 | """
40 | param_idx = start_param_idx
41 |
42 | for name, param in module._parameters.items():
43 | module._parameters[name] = params_list[param_idx]
44 | param_idx += 1
45 |
46 | for name, child in module._modules.items():
47 | if child is not None:
48 | param_idx = set_model_params(child, params_list, param_idx)
49 |
50 | return param_idx
51 |
52 | # = MNIST =
53 |
54 | class Net(nn.Module):
55 | def __init__(self):
56 | super(Net, self).__init__()
57 | self.fc1 = nn.Linear(784, 392)
58 | self.fc2 = nn.Linear(392, 10)
59 |
60 | def forward(self, x):
61 | x = self.fc1(x)
62 | x = nn.functional.relu(x)
63 | x = self.fc2(x)
64 | return x
65 |
66 | def softmax_cross_entropy_with_logits(logits, targets, batch_size):
67 | """ Calculates softmax entropy
68 | Args:
69 | * logits: (NxC) outputs of dense layer
70 | * targets: (NxC) one-hot encoded labels
71 | * batch_size: value of N, temporarily required because Plan cannot trace .shape
72 | """
73 | # numerically stable log-softmax
74 | norm_logits = logits - logits.max()
75 | log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()
76 | # NLL, reduction = mean
77 | return -(targets * log_probs).sum() / batch_size
78 |
79 | def naive_sgd(param, **kwargs):
80 | return param - kwargs['lr'] * param.grad
81 |
82 | model = Net()
83 |
84 | @sy.func2plan()
85 | def training_plan(X, y, batch_size, lr, model_params):
86 | # inject params into model
87 | set_model_params(model, model_params)
88 |
89 | # forward pass
90 | logits = model.forward(X)
91 |
92 | # loss
93 | loss = softmax_cross_entropy_with_logits(logits, y, batch_size)
94 |
95 | # backprop
96 | loss.backward()
97 |
98 | # step
99 | updated_params = [
100 | naive_sgd(param, lr=lr)
101 | for param in model_params
102 | ]
103 |
104 | # accuracy
105 | pred = th.argmax(logits, dim=1)
106 | target = th.argmax(y, dim=1)
107 | acc = pred.eq(target).sum().float() / batch_size
108 |
109 | return (
110 | loss,
111 | acc,
112 | *updated_params
113 | )
114 |
115 | # Build the training plan.
116 | bs = 64
117 | lr = 0.005
118 | model_params = tuple(p.data for p in model.parameters())
119 | X = th.randn(bs, 28 * 28)
120 | y = nn.functional.one_hot(th.randint(0, 10, (bs,)), 10)
121 | training_plan.build(X, y, th.tensor([bs]), th.tensor([lr]), model_params, trace_autograd=True)
122 |
123 | # Produce training plan result after the first batch.
124 | mnist_dataset = th.utils.data.DataLoader(
125 | datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
126 | batch_size=bs,
127 | shuffle=False
128 | )
129 | X, y = next(iter(mnist_dataset))
130 | X = X.view(bs, -1)
131 | y_oh = th.nn.functional.one_hot(y, 10).float()
132 | training_plan.forward = None
133 | loss, acc, *upd_params = training_plan(X, y_oh, th.tensor([bs], dtype=th.float32), th.tensor([lr]), model_params)
134 |
135 | print("MNIST plan (torch): ")
136 | print(training_plan.code)
137 |
138 | training_plan.base_framework = TranslationTarget.TENSORFLOW_JS.value
139 | print("MNIST plan: ")
140 | print(training_plan.code)
141 |
142 | # = Plan with state =
143 | @sy.func2plan(args_shape=[(2,2)], state=(th.tensor([4.2, 7.3]),))
144 | def plan_with_state(x, state):
145 | (y,) = state.read()
146 | x = x + y
147 | x = th.abs(x)
148 | return x
149 |
150 | plan_with_state.base_framework = TranslationTarget.TENSORFLOW_JS.value
151 | print("Plan w/ state: ")
152 | print(plan_with_state.code)
153 |
154 | # = Bandit plans =
155 | # Simple
156 | reward = th.tensor([0.0, 0.0, 0.0])
157 | n_so_far = th.tensor([1.0, 1.0, 1.0])
158 | means = th.tensor([1.0, 2.0, 3.0])
159 | bandit_args = [reward, n_so_far, means]
160 | bandit_arg_shape = [arg.shape for arg in bandit_args]
161 | @sy.func2plan(args_shape=bandit_arg_shape)
162 | def bandit(reward, n_so_far, means):
163 | prev = means
164 | new = th.div(prev*(n_so_far-1),n_so_far) + th.div(reward,n_so_far)
165 | means=new
166 | return means
167 |
168 | bandit.base_framework = TranslationTarget.TENSORFLOW_JS.value
169 | print("Bandit simple plan: ")
170 | print(bandit.code)
171 |
172 | # Thompson
173 | alphas = th.tensor([1.0, 1.0, 1.0], requires_grad=False)
174 | betas = th.tensor([1.0, 1.0, 1.0], requires_grad=False)
175 | rwd = th.tensor([0.0, 0.0, 0.0])
176 | samples = th.tensor([0.0, 0.0, 0.0])
177 | bandit_args_th = [rwd, samples, alphas, betas]
178 | bandit_th_args_shape = [rwd.shape, samples.shape, alphas.shape, betas.shape]
179 | @sy.func2plan(args_shape=bandit_th_args_shape)
180 | def bandit_thompson(reward, sample_vector, alphas, betas):
181 | prev_alpha = alphas
182 | prev_beta = betas
183 | alphas = prev_alpha.add(reward)
184 | betas = prev_beta.add(sample_vector.sub(reward))
185 | return (alphas, betas)
186 |
187 | bandit_thompson.base_framework = TranslationTarget.TENSORFLOW_JS.value
188 | print("Bandit thompson plan: ")
189 | print(bandit_thompson.code)
190 |
191 | replacements = {
192 | 'MNIST_BATCH_SIZE': bs,
193 | 'MNIST_LR': lr,
194 | 'MNIST_PLAN': serialize_to_b64_pb(hook.local_worker, training_plan),
195 | 'MNIST_BATCH_DATA': serialize_to_b64_pb(hook.local_worker, tensors_to_state([X, y_oh])),
196 | 'MNIST_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state(model_params)),
197 | 'MNIST_UPD_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state(upd_params)),
198 | 'MNIST_LOSS': loss.item(),
199 | 'MNIST_ACCURACY': acc.item(),
200 | 'PLAN_WITH_STATE': serialize_to_b64_pb(hook.local_worker, plan_with_state),
201 | 'BANDIT_SIMPLE_PLAN': serialize_to_b64_pb(hook.local_worker, bandit),
202 | 'BANDIT_SIMPLE_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state([means])),
203 | 'BANDIT_THOMPSON_PLAN': serialize_to_b64_pb(hook.local_worker, bandit_thompson),
204 | 'BANDIT_THOMPSON_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state([alphas, betas])),
205 | }
206 |
207 | with open("dummy.tpl.js", "r") as tpl, open("dummy.js", "w") as output:
208 | js_tpl = tpl.read()
209 | for k, v in replacements.items():
210 | js_tpl = js_tpl.replace(f"%{k}%", str(v))
211 | output.write(js_tpl)
212 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/app.js:
--------------------------------------------------------------------------------
1 | /** @jsx jsx */
2 | import { jsx } from '@emotion/core';
3 | import React, { useEffect } from 'react';
4 |
5 | import backgroundGradient from './background-gradient.svg';
6 | import logoWhite from './logo-white.svg';
7 | import logoColor from './logo-color.svg';
8 |
9 | import '@fortawesome/fontawesome-free/js/all';
10 |
11 | const fontFamily =
12 | '"Rubik", -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"';
13 |
14 | const globalStyles = {
15 | paragraph: {
16 | fontFamily,
17 | lineHeight: 1.75,
18 | fontSize: 18
19 | },
20 | heading: {
21 | fontFamily,
22 | lineHeight: 1.5,
23 | textTransform: 'uppercase',
24 | fontWeight: 700,
25 | letterSpacing: 2,
26 | marginTop: 0,
27 | marginBottom: 0,
28 | fontSize: 24
29 | }
30 | };
31 |
32 | const Hero = ({ background, button }) => {
33 | const title = 'Answer Questions Using Data You Cannot See';
34 | const description = `OpenMined is an open-source community whose goal is to make the world more privacy-preserving by lowering the barrier-to-entry to private AI technologies.`;
35 |
36 | const styles = {
37 | container: {
38 | background:
39 | background === 'gradient' ? `url(${backgroundGradient})` : '#333',
40 | backgroundRepeat: 'no-repeat',
41 | backgroundSize: 'cover',
42 | padding: 40,
43 | display: 'flex',
44 | flexDirection: 'column',
45 | alignItems: 'center'
46 | },
47 | logo: {
48 | width: 200,
49 | height: 'auto'
50 | },
51 | title: {
52 | color: 'white',
53 | marginTop: 40,
54 | width: 480,
55 | textAlign: 'center'
56 | },
57 | description: {
58 | color: 'white',
59 | marginTop: 40,
60 | marginBottom: 0,
61 | width: 480,
62 | textAlign: 'center'
63 | }
64 | };
65 |
66 | return (
67 | <div css={styles.container}>
68 | <img src={logoWhite} css={styles.logo} alt="OpenMined" />
69 | <h1 css={{ ...globalStyles.heading, ...styles.title }}>{title}</h1>
70 | <p css={{ ...globalStyles.paragraph, ...styles.description }}>
71 | {description}
72 | </p>
73 | {button && button}
74 | </div>
75 | );
76 | };
77 |
78 | const Vision = ({ button }) => {
79 | const title = 'Vision & Mission';
80 | const description = `
81 | Industry standard tools for artificial intelligence have been designed with several assumptions: data is centralized into a single compute cluster, the cluster exists in a secure cloud, and the resulting models will be owned by a central authority. We envision a world in which we are not restricted to this scenario - a world in which AI tools treat privacy, security, and multi-owner governance as first class citizens.
82 | With OpenMined, an AI model can be governed by multiple owners and trained securely on an unseen, distributed dataset.
83 | The mission of the OpenMined community is to create an accessible ecosystem of tools for private, secure, multi-owner governed AI. We do this by extending popular libraries like TensorFlow and PyTorch with advanced techniques in cryptography and private machine learning.
84 | `;
85 |
86 | const styles = {
87 | container: {
88 | padding: 40,
89 | maxWidth: 960,
90 | width: '90%',
91 | margin: '0 auto'
92 | }
93 | };
94 |
95 | return (
96 | <div css={styles.container}>
97 | <h3 css={globalStyles.heading}>{title}</h3>
98 | <p
99 | css={{ ...globalStyles.paragraph, whiteSpace: 'pre-line' }}
100 | dangerouslySetInnerHTML={{ __html: description }}
101 | />
102 | {button && button}
103 | </div>
104 | );
105 | };
106 |
107 | const Button = ({ background, icon, onClick }) => (
108 |
136 | Sign Up
137 | {icon === 'arrow' && <i className="fas fa-arrow-right" />}
138 | {icon === 'user' && <i className="fas fa-user" />}
139 | {icon === 'code' && <i className="fas fa-code" />}
140 |
141 | );
142 |
143 | const Footer = () => {
144 | const styles = {
145 | wrapper: {
146 | background: '#333',
147 | padding: 40
148 | },
149 | container: {
150 | display: 'flex',
151 | justifyContent: 'space-between',
152 | width: 940,
153 | margin: '0 auto'
154 | },
155 | logo: {
156 | width: 200,
157 | height: 'auto'
158 | },
159 | social: {
160 | display: 'flex',
161 | alignItems: 'center'
162 | },
163 | socialIcon: {
164 | color: 'rgba(255, 255, 255, 0.5)',
165 | transition: 'color 0.2s ease-in-out',
166 | fontSize: 24,
167 | marginLeft: 20,
168 | '&:hover': {
169 | color: '#FFF'
170 | }
171 | }
172 | };
173 | return (
174 |
175 |
176 |
177 |
207 |
208 |
209 | );
210 | };
211 |
212 | export default ({ config, onButtonClick, start }) => {
213 | useEffect(() => {
214 | start();
215 | }, []);
216 |
217 | const button = (
218 | <Button
219 | background={config.buttonColor}
220 | icon={config.buttonIcon}
221 | onClick={onButtonClick}
222 | />
223 | );
224 |
225 | return (
226 | <React.Fragment>
227 | <Hero
228 | background={config.heroBackground}
229 | button={
230 | config.buttonPosition === 'hero' ? button : null
231 | }
232 | />
233 | <Vision
234 | button={
235 | config.buttonPosition === 'vision' ? button : null
236 | }
237 | />
238 | <Footer />
239 | </React.Fragment>
240 | );
241 | };
242 |
--------------------------------------------------------------------------------
/src/job.js:
--------------------------------------------------------------------------------
1 | import EventObserver from './events';
2 | import { protobuf, unserialize } from './protobuf';
3 | import { CYCLE_STATUS_ACCEPTED, CYCLE_STATUS_REJECTED } from './_constants';
4 | import { GRID_UNKNOWN_CYCLE_STATUS, PLAN_LOAD_FAILED } from './_errors';
5 | import SyftModel from './syft-model';
6 | import Logger from './logger';
7 |
8 | /**
9 | * Job represents a single training cycle done by the client.
10 | *
11 | * @property {Object.<string, Plan>} plans Plans dictionary.
12 | * @property {Object.<string, Protocol>} protocols [not implemented] Protocols dictionary.
13 | * @property {SyftModel} model Model.
14 | */
15 | export default class Job {
16 | /**
17 | * @hideconstructor
18 | * @param {object} options
19 | * @param {Syft} options.worker Instance of Syft client.
20 | * @param {string} options.modelName Model name.
21 | * @param {string} options.modelVersion Model version.
22 | * @param {GridAPIClient} options.gridClient Instance of GridAPIClient.
23 | */
24 | constructor({ worker, modelName, modelVersion, gridClient }) {
25 | this.worker = worker;
26 | this.modelName = modelName;
27 | this.modelVersion = modelVersion;
28 | this.grid = gridClient;
29 | this.logger = new Logger();
30 | this.observer = new EventObserver();
31 |
32 | // parameters loaded from grid
33 | this.model = null;
34 | this.plans = {};
35 | this.protocols = {};
36 | // holds request_key
37 | this.cycleParams = {};
38 | this.clientConfig = {};
39 | }
40 |
41 | /**
42 | * Registers an event listener.
43 | *
44 | * Available events: `accepted`, `rejected`, `error`.
45 | *
46 | * @param {string} event Event name.
47 | * @param {function} handler Event handler.
48 | */
49 | on(event, handler) {
50 | if (['accepted', 'rejected', 'error'].includes(event)) {
51 | this.observer.subscribe(event, handler.bind(this));
52 | }
53 | }
54 |
55 | /**
56 | * Initializes the Job with provided training cycle params.
57 | *
58 | * @private
59 | * @param {Object} cycleParams
60 | * @returns {Promise<void>}
61 | */
62 | async initCycle(cycleParams) {
63 | this.logger.log(
64 | `Cycle initialization with params: ${JSON.stringify(cycleParams)}`
65 | );
66 | this.cycleParams = cycleParams;
67 | this.clientConfig = cycleParams.client_config;
68 |
69 | // load the model
70 | const modelData = await this.grid.getModel(
71 | this.worker.worker_id,
72 | cycleParams.request_key,
73 | cycleParams.model_id
74 | );
75 | this.model = new SyftModel({
76 | worker: this.worker,
77 | modelData
78 | });
79 |
80 | // load all plans
81 | for (let planName of Object.keys(cycleParams.plans)) {
82 | const planId = cycleParams.plans[planName];
83 | const planBinary = await this.grid.getPlan(
84 | this.worker.worker_id,
85 | cycleParams.request_key,
86 | planId
87 | );
88 | try {
89 | this.plans[planName] = unserialize(
90 | this.worker,
91 | planBinary,
92 | protobuf.syft_proto.execution.v1.Plan
93 | );
94 | } catch (e) {
95 | throw new Error(PLAN_LOAD_FAILED(planName, e.message));
96 | }
97 | }
98 |
99 | // load all protocols
100 | for (let protocolName of Object.keys(cycleParams.protocols)) {
101 | const protocolId = cycleParams.protocols[protocolName];
102 | const protocolBinary = await this.grid.getProtocol(
103 | this.worker.worker_id,
104 | cycleParams.request_key,
105 | protocolId
106 | );
107 | this.protocols[protocolName] = unserialize(
108 | this.worker,
109 | protocolBinary,
110 | protobuf.syft_proto.execution.v1.Protocol
111 | );
112 | }
113 | }
114 |
115 | /**
116 | * Starts the Job, executing the following actions:
117 | * * Meters connection speed to PyGrid.
118 | * * Registers into training cycle on PyGrid.
119 | * * Retrieves cycle and client parameters.
120 | * * Downloads Plans, Model, Protocols.
121 | * * Fires `accepted` event on success.
122 | *
123 | * @fires Job#accepted
124 | * @fires Job#rejected
125 | * @fires Job#error
126 | * @param {Object} options
127 | * @param {boolean} options.skipGridSpeedTest When true, skips the speed test before requesting a cycle.
128 | * @returns {Promise<void>}
129 | */
130 | async start({ skipGridSpeedTest = false } = {}) {
131 | let cycleParams;
132 | try {
133 | let [ping, download, upload] = [0, 0, 0];
134 | if (!skipGridSpeedTest) {
135 | // speed test
136 | ({ ping, download, upload } = await this.grid.getConnectionSpeed(
137 | this.worker.worker_id
138 | ));
139 | }
140 |
141 | // request cycle
142 | cycleParams = await this.grid.requestCycle(
143 | this.worker.worker_id,
144 | this.modelName,
145 | this.modelVersion,
146 | ping,
147 | download,
148 | upload
149 | );
150 |
151 | if (cycleParams.status === CYCLE_STATUS_ACCEPTED) {
152 | // load model, plans, protocols, etc.
153 | this.logger.log(
154 | `Accepted into cycle with params: ${JSON.stringify(
155 | cycleParams,
156 | null,
157 | 2
158 | )}`
159 | );
160 | await this.initCycle(cycleParams);
161 | }
162 |
163 | if (
164 | ![CYCLE_STATUS_ACCEPTED, CYCLE_STATUS_REJECTED].includes(
165 | cycleParams.status
166 | )
167 | ) {
168 | throw new Error(GRID_UNKNOWN_CYCLE_STATUS(cycleParams.status));
169 | }
170 | } catch (error) {
171 | /**
172 | * `error` event.
173 | * Triggered for a variety of error conditions.
174 | *
175 | * @event Job#error
176 | */
177 | this.observer.broadcast('error', error);
178 | return;
179 | }
180 |
181 | // Trigger events outside of try/catch.
182 | switch (cycleParams.status) {
183 | case CYCLE_STATUS_ACCEPTED:
184 | /**
185 | * `accepted` event.
186 | * Triggered when PyGrid accepts the client into training cycle.
187 | *
188 | * @event Job#accepted
189 | * @type {Object}
190 | * @property {SyftModel} model Instance of SyftModel.
191 | * @property {Object} clientConfig Client configuration returned by PyGrid.
192 | */
193 | this.observer.broadcast('accepted', {
194 | model: this.model,
195 | clientConfig: this.clientConfig
196 | });
197 | break;
198 |
199 | case CYCLE_STATUS_REJECTED:
200 | this.logger.log(
201 | `Rejected from cycle with timeout: ${cycleParams.timeout}`
202 | );
203 |
204 | /**
205 | * `rejected` event.
206 | * Triggered when PyGrid rejects the client.
207 | *
208 | * @event Job#rejected
209 | * @type {Object}
210 | * @property {number|null} timeout Time in seconds to wait before retrying. Empty when the FL model is no longer trainable.
211 | */
212 | this.observer.broadcast('rejected', {
213 | timeout: cycleParams.timeout
214 | });
215 | break;
216 | }
217 | }
218 |
219 | /**
220 | * Submits the model diff to PyGrid.
221 | *
222 | * @param {ArrayBuffer} diff Serialized difference between original and trained model parameters.
223 | * @returns {Promise<void>}
224 | */
225 | async report(diff) {
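// PyGrid expects the serialized diff encoded as a base64 string.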
226 | await this.grid.submitReport(
227 | this.worker.worker_id,
228 | this.cycleParams.request_key,
229 | Buffer.from(diff).toString('base64')
230 | );
231 | }
232 | }
233 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/logo-white.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shashigharti/syft.js/master/examples/multi-armed-bandit/logo-white.svg
--------------------------------------------------------------------------------
/src/grid-api-client.js:
--------------------------------------------------------------------------------
1 | import Logger from './logger';
2 | import { SpeedTest } from './speed-test';
3 | import { GRID_ERROR } from './_errors';
4 | import EventObserver from './events';
5 |
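// HTTP verb used for each PyGrid endpoint; paths not listed here
// default to GET (see _sendHttp below).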
6 | const HTTP_PATH_VERB = {
7 | 'federated/get-plan': 'GET',
8 | 'federated/get-model': 'GET',
9 | 'federated/get-protocol': 'GET',
10 | 'federated/cycle-request': 'POST',
11 | 'federated/report': 'POST',
12 | 'federated/authenticate': 'POST'
13 | };
14 |
15 | export default class GridAPIClient {
16 | constructor({ url, allowInsecureUrl = false }) {
17 | this.transport = url.match(/^ws/i) ? 'ws' : 'http';
18 | if (this.transport === 'ws') {
19 | this.wsUrl = url;
20 | this.httpUrl = url.replace(/^ws(s)?/i, 'http$1');
21 | } else {
22 | this.httpUrl = url;
23 | this.wsUrl = url.replace(/^http(s)?/i, 'ws$1');
24 | }
25 | if (!allowInsecureUrl) {
26 | this.wsUrl = this.wsUrl.replace(/^ws:/i, 'wss:');
27 | this.httpUrl = this.httpUrl.replace(/^http:/i, 'https:');
28 | }
29 | this.ws = null;
30 | this.observer = new EventObserver();
31 | this.wsMessageQueue = [];
32 | this.logger = new Logger('grid', true);
33 | this.responseTimeout = 10000;
34 |
35 | this._handleWsMessage = this._handleWsMessage.bind(this);
36 | this._handleWsError = this._handleWsError.bind(this);
37 | this._handleWsClose = this._handleWsClose.bind(this);
38 | }
39 |
40 | async authenticate(modelName, modelVersion, authToken) {
41 | this.logger.log(
42 | `Authenticating against ${modelName} ${modelVersion} with ${authToken}...`
43 | );
44 |
45 | const response = await this._send('federated/authenticate', {
46 | model_name: modelName,
47 | model_version: modelVersion,
48 | auth_token: authToken
49 | });
50 |
51 | return response;
52 | }
53 |
54 | requestCycle(workerId, modelName, modelVersion, ping, download, upload) {
55 | this.logger.log(
56 | `[WID: ${workerId}] Requesting cycle for model ${modelName} v.${modelVersion} [${ping}, ${download}, ${upload}]...`
57 | );
58 |
59 | const response = this._send('federated/cycle-request', {
60 | worker_id: workerId,
61 | model: modelName,
62 | version: modelVersion,
63 | ping: ping,
64 | download: download,
65 | upload: upload
66 | });
67 |
68 | return response;
69 | }
70 |
71 | async getModel(workerId, requestKey, modelId) {
72 | this.logger.log(
73 | `[WID: ${workerId}, KEY: ${requestKey}] Requesting model ${modelId}...`
74 | );
75 |
76 | const response = await this._sendHttp(
77 | 'federated/get-model',
78 | {
79 | worker_id: workerId,
80 | request_key: requestKey,
81 | model_id: modelId
82 | },
83 | 'arrayBuffer'
84 | );
85 | return response;
86 | }
87 |
88 | async getPlan(workerId, requestKey, planId) {
89 | this.logger.log(
90 | `[WID: ${workerId}, KEY: ${requestKey}] Requesting plan ${planId}...`
91 | );
92 |
93 | const response = await this._sendHttp(
94 | 'federated/get-plan',
95 | {
96 | worker_id: workerId,
97 | request_key: requestKey,
98 | plan_id: planId,
99 | receive_operations_as: 'tfjs'
100 | },
101 | 'arrayBuffer'
102 | );
103 |
104 | return response;
105 | }
106 |
107 | getProtocol(workerId, requestKey, protocolId) {
108 | this.logger.log(
109 | `[WID: ${workerId}, KEY: ${requestKey}] Requesting protocol ${protocolId}...`
110 | );
111 | return Promise.resolve(
112 | 'CgYIjcivoCUqEwoGCIHIr6AlEgkSB3dvcmtlcjEqEwoGCIXIr6AlEgkSB3dvcmtlcjIqEwoGCInIr6AlEgkSB3dvcmtlcjM='
113 | );
114 | }
115 |
116 | async submitReport(workerId, requestKey, diff) {
117 | this.logger.log(
118 | `[WID: ${workerId}, KEY: ${requestKey}] Submitting report...`
119 | );
120 |
121 | const response = await this._send('federated/report', {
122 | worker_id: workerId,
123 | request_key: requestKey,
124 | diff
125 | });
126 |
127 | return response;
128 | }
129 |
130 | async getConnectionSpeed(workerId) {
131 | const speedTest = new SpeedTest({
132 | downloadUrl:
133 | this.httpUrl +
134 | '/federated/speed-test?worker_id=' +
135 | encodeURIComponent(workerId) +
136 | '&random=' +
137 | Math.random(),
138 | uploadUrl:
139 | this.httpUrl +
140 | '/federated/speed-test?worker_id=' +
141 | encodeURIComponent(workerId) +
142 | '&random=' +
143 | Math.random(),
144 | pingUrl:
145 | this.httpUrl +
146 | '/federated/speed-test?is_ping=1&worker_id=' +
147 | encodeURIComponent(workerId) +
148 | '&random=' +
149 | Math.random()
150 | });
151 |
152 | const ping = await speedTest.getPing();
153 | // Run the download and upload tests concurrently.
154 | const [download, upload] = await Promise.all([
155 | speedTest.getDownloadSpeed(),
156 | speedTest.getUploadSpeed()
157 | ]);
158 |
159 | return {
160 | ping,
161 | download,
162 | upload
163 | };
164 | }
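// Sketch of how these measurements feed a cycle request (model name and
// version are illustrative):
//
//   const { ping, download, upload } = await client.getConnectionSpeed(workerId);
//   const cycle = await client.requestCycle(
//     workerId, 'mnist', '1.0.0', ping, download, upload
//   );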
165 |
166 | async _send(path, data) {
167 | const response =
168 | this.transport === 'ws'
169 | ? await this._sendWs(path, data)
170 | : await this._sendHttp(path, data);
171 |
172 | if (response.error) {
173 | throw new Error(response.error);
174 | }
175 |
176 | return response;
177 | }
178 |
179 | async _sendHttp(path, data, type = 'json') {
180 | const method = HTTP_PATH_VERB[path] || 'GET';
181 | let response;
182 |
183 | if (method === 'GET') {
184 | const query = Object.keys(data)
185 | .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(data[k]))
186 | .join('&');
187 | response = await fetch(this.httpUrl + '/' + path + '?' + query, {
188 | method: 'GET',
189 | mode: 'cors'
190 | });
191 | } else {
192 | response = await fetch(this.httpUrl + '/' + path, {
193 | method: 'POST',
194 | mode: 'cors',
195 | headers: {
196 | 'Content-Type': 'application/json'
197 | },
198 | body: JSON.stringify(data)
199 | });
200 | }
201 |
202 | if (!response.ok) {
203 | let error = `${response.status} ${response.statusText}`;
204 | try {
205 | let res = await response.json();
206 | if (res.error) {
207 | error = res.error;
208 | }
209 | } catch (e) {
210 | // not JSON
211 | }
212 | throw new Error(GRID_ERROR(error));
213 | }
214 |
215 | return response[type]();
216 | }
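// For GET endpoints the payload above is flattened into a query string,
// e.g. (illustrative values):
//   { worker_id: 'w1', model_id: 7 }  ->  'worker_id=w1&model_id=7'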
217 |
218 | async _sendWs(type, data) {
219 | if (!this.ws) {
220 | await this._initWs();
221 | }
222 |
223 | const message = { type, data };
224 | this.logger.log('Sending WS message', type);
225 |
226 | return new Promise((resolve, reject) => {
227 | this.ws.send(JSON.stringify(message));
228 |
229 | const cleanUp = () => {
230 | // Remove all handlers related to message.
231 | this.wsMessageQueue = this.wsMessageQueue.filter(
232 | item => item !== onMessage
233 | );
234 | this.observer.unsubscribe('ws-error', onError);
235 | this.observer.unsubscribe('ws-close', onClose);
236 | clearTimeout(timeoutHandler);
237 | };
238 |
239 | const timeoutHandler = setTimeout(() => {
240 | cleanUp();
241 | reject(new Error('Response timeout'));
242 | }, this.responseTimeout);
243 |
244 | const onMessage = data => {
245 | if (data.type !== message.type) {
246 | this.logger.log('Received invalid response type, ignoring');
247 | return false;
248 | }
249 | cleanUp();
250 | resolve(data.data);
251 | };
252 |
253 | const onError = event => {
254 | cleanUp();
255 | reject(new Error('WS error while waiting for response')); // Error(event) would stringify to '[object Event]'
256 | };
257 |
258 | const onClose = () => {
259 | cleanUp();
260 | reject(new Error('WS connection closed'));
261 | };
262 |
263 | // We expect responses to arrive in the same order as requests were sent.
264 | this.wsMessageQueue.push(onMessage);
265 |
266 | // Other events while waiting for response.
267 | this.observer.subscribe('ws-error', onError);
268 | this.observer.subscribe('ws-close', onClose);
269 | });
270 | }
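// Note: responses are matched to requests only by arrival order and message
// type, so this assumes the server answers strictly in request order. Two
// concurrent requests of the same type resolve first-come, first-served:
//
//   const [a, b] = await Promise.all([
//     client._sendWs('federated/cycle-request', reqA), // reqA/reqB hypothetical
//     client._sendWs('federated/cycle-request', reqB)
//   ]); // `a` receives the first matching reply, `b` the second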
271 |
272 | async _initWs() {
273 | const ws = new WebSocket(this.wsUrl);
274 | return new Promise((resolve, reject) => {
275 | ws.onopen = () => {
276 | // setup handlers
277 | ws.onerror = this._handleWsError;
278 | ws.onclose = this._handleWsClose;
279 | ws.onmessage = this._handleWsMessage;
280 | this.ws = ws;
281 | resolve();
282 | };
283 | ws.onerror = event => {
284 | // couldn't connect
285 | this._handleWsError(event);
286 | reject(new Error('WS connection failed')); // Error(event) would stringify to '[object Event]'
287 | };
288 | ws.onclose = event => {
289 | // couldn't connect
290 | this._handleWsClose(event);
291 | reject(new Error('WS connection closed during connect'));
292 | };
293 | });
294 | }
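// Reconnection note: the error/close handlers below reset `this.ws` to null,
// so the next _sendWs call transparently opens a fresh connection.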
295 |
296 | _handleWsMessage(event) {
297 | this.logger.log('Received message', event.data);
298 | let data;
299 | try {
300 | data = JSON.parse(event.data);
301 | } catch (e) {
302 | this.logger.log('Message is not valid JSON!'); return; // don't pass undefined to the handlers below
303 | }
304 |
305 | // Call response handlers (in order of requests),
306 | // stopping at the first successful handler.
307 | for (let handler of this.wsMessageQueue) {
308 | if (handler(data) !== false) {
309 | break;
310 | }
311 | }
312 | }
313 |
314 | _handleWsError(event) {
315 | this.logger.log('WS connection error', event);
316 | this.observer.broadcast('ws-error', event);
317 | this.ws = null;
318 | }
319 |
320 | _handleWsClose(event) {
321 | this.logger.log('WS connection closed', event);
322 | this.observer.broadcast('ws-close', event);
323 | this.ws = null;
324 | }
325 | }
326 |
--------------------------------------------------------------------------------
/examples/multi-armed-bandit/logo-color.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shashigharti/syft.js/master/examples/multi-armed-bandit/logo-color.svg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 OpenMined Contributors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------