├── jest-globals.js ├── .prettierignore ├── .eslintignore ├── .prettierrc ├── .gitignore ├── art ├── fake-logo.png └── mnist-demo-ani.gif ├── documentation.yml ├── src ├── index.js ├── events.js ├── types │ ├── pointer-tensor.js │ ├── protocol.js │ ├── state.js │ ├── plan.js │ ├── message.js │ ├── placeholder.js │ ├── torch.js │ ├── role.js │ └── computation-action.js ├── logger.js ├── object-registry.js ├── data-channel-message-queue.js ├── _errors.js ├── _constants.js ├── protobuf │ ├── mapping.js │ └── index.js ├── syft-model.js ├── sockets.js ├── syft.js ├── syft-webrtc.js ├── data-channel-message.js ├── speed-test.js ├── job.js └── grid-api-client.js ├── .npmignore ├── copy-examples.sh ├── .babelrc ├── .eslintrc.json ├── test ├── types │ ├── pointer-tensor.test.js │ ├── placeholder.test.js │ ├── protocol.test.js │ ├── message.test.js │ ├── torch.test.js │ └── plan.test.js ├── data │ ├── dummy.tpl.js │ └── generate-data.py ├── events.test.js ├── mocks │ ├── webrtc.js │ └── grid.js ├── data_channel_message_queue.test.js ├── logger.test.js ├── data_channel_message.test.js ├── sockets.test.js └── protobuf.test.js ├── .github └── workflows │ └── run-tests.yml ├── examples ├── mnist │ ├── package.json │ ├── README.md │ ├── webpack.config.js │ ├── index.html │ ├── mnist.js │ └── index.js └── multi-armed-bandit │ ├── index.html │ ├── package.json │ ├── webpack.config.js │ ├── README.md │ ├── index.js │ ├── app.js │ ├── logo-white.svg │ └── logo-color.svg ├── rollup.config.js ├── .all-contributorsrc ├── package.json ├── CHANGELOG.md ├── API-REFERENCE.md └── LICENSE /jest-globals.js: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | API-REFERENCE.md -------------------------------------------------------------------------------- /.eslintignore: 
-------------------------------------------------------------------------------- 1 | node_modules 2 | dist 3 | examples -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": true 3 | } 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | coverage 3 | dist 4 | tmp 5 | yarn.lock 6 | .DS_Store 7 | .idea -------------------------------------------------------------------------------- /art/fake-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/syft.js/master/art/fake-logo.png -------------------------------------------------------------------------------- /art/mnist-demo-ani.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/syft.js/master/art/mnist-demo-ani.gif -------------------------------------------------------------------------------- /documentation.yml: -------------------------------------------------------------------------------- 1 | toc: 2 | - Syft 3 | - Job 4 | - Job#accepted 5 | - Job#rejected 6 | - Job#error 7 | - SyftModel 8 | - Plan 9 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | // Export all constants 2 | export * from './_constants'; 3 | 4 | // Export the main class as default AND as named 5 | export { default as Syft } from './syft'; 6 | export { default } from './syft'; 7 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | 
node_modules 2 | coverage 3 | examples 4 | test 5 | tmp 6 | yarn.lock 7 | .DS_Store 8 | .idea 9 | .babelrc 10 | .prettierrc 11 | .travis.yml 12 | copy-examples.sh 13 | jest-globals.js 14 | rollup.config.js -------------------------------------------------------------------------------- /copy-examples.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf ./tmp 3 | mkdir tmp 4 | 5 | for directory in $(find ./examples -type d -mindepth 1 -maxdepth 1); 6 | do 7 | directory="$directory/dist/" 8 | name="$(cut -d'/' -f3 <<<$directory)" 9 | 10 | rsync -avzh "$directory" "./tmp/$name" 11 | done -------------------------------------------------------------------------------- /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": ["@babel/preset-env"], 3 | "plugins": [ 4 | "@babel/plugin-proposal-class-properties" 5 | ], 6 | "env": { 7 | "test": { 8 | "plugins": [ 9 | "@babel/plugin-transform-runtime", 10 | "@babel/plugin-proposal-class-properties" 11 | ] 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "parser": "babel-eslint", 3 | "env": { 4 | "browser": true, 5 | "es6": true, 6 | "node": true, 7 | "jest": true 8 | }, 9 | "extends": ["eslint:recommended"], 10 | "parserOptions": { 11 | "ecmaVersion": 2018, 12 | "sourceType": "module" 13 | }, 14 | "rules": { 15 | "strict": ["error", "never"] 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/events.js: -------------------------------------------------------------------------------- 1 | export default class EventObserver { 2 | constructor() { 3 | this.observers = []; 4 | } 5 | 6 | subscribe(type, func) { 7 | this.observers.push({ type, func }); 8 | } 9 | 10 | unsubscribe(eventType) { 11 | 
this.observers = this.observers.filter(({ type }) => eventType !== type); 12 | } 13 | 14 | broadcast(eventType, data) { 15 | this.observers.forEach(observer => { 16 | if (eventType === observer.type) { 17 | return observer.func(data); 18 | } 19 | }); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/types/pointer-tensor.js: -------------------------------------------------------------------------------- 1 | // Create a class to represent pointer tensors 2 | // Add all the attributes that are serialized, just as for range and slice 3 | 4 | export default class PointerTensor { 5 | constructor( 6 | id, 7 | idAtLocation, 8 | locationId, 9 | pointToAttr, 10 | shape, 11 | garbageCollectData 12 | ) { 13 | this.id = id; 14 | this.idAtLocation = idAtLocation; 15 | this.locationId = locationId; 16 | this.pointToAttr = pointToAttr; 17 | this.shape = shape; 18 | this.garbageCollectData = garbageCollectData; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /test/types/pointer-tensor.test.js: -------------------------------------------------------------------------------- 1 | import PointerTensor from '../../src/types/pointer-tensor'; 2 | 3 | describe('PointerTensor', () => { 4 | test('can be properly constructed', () => { 5 | const obj = new PointerTensor(123, 444, 'worker1', null, [2, 3], true); 6 | expect(obj.id).toStrictEqual(123); 7 | expect(obj.idAtLocation).toStrictEqual(444); 8 | expect(obj.locationId).toStrictEqual('worker1'); 9 | expect(obj.pointToAttr).toStrictEqual(null); 10 | expect(obj.shape).toStrictEqual([2, 3]); 11 | expect(obj.garbageCollectData).toStrictEqual(true); 12 | }); 13 | }); 14 | -------------------------------------------------------------------------------- /src/logger.js: -------------------------------------------------------------------------------- 1 | // A simple logging function 2 | export default class Logger { 3 | constructor(system, 
verbose) { 4 | if (!Logger.instance) { 5 | this.system = system; 6 | this.verbose = verbose; 7 | Logger.instance = this; 8 | } 9 | return Logger.instance; 10 | } 11 | 12 | log(message, data) { 13 | // Only log if verbose is turned on 14 | if (this.verbose) { 15 | const output = `${Date.now()}: ${this.system} - ${message}`; 16 | 17 | // Have the passed additional data? 18 | if (data) { 19 | console.log(output, data); 20 | } else { 21 | console.log(output); 22 | } 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /test/types/placeholder.test.js: -------------------------------------------------------------------------------- 1 | import { Placeholder } from '../../src/types/placeholder'; 2 | 3 | describe('Placeholder', () => { 4 | test('can be properly constructed', () => { 5 | const obj = new Placeholder(123, ['tag1', 'tag2'], 'desc'); 6 | expect(obj.id).toStrictEqual(123); 7 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']); 8 | expect(obj.description).toStrictEqual('desc'); 9 | }); 10 | 11 | test('can get placeholder order', () => { 12 | const obj = new Placeholder(123, ['#input-1'], 'desc'); 13 | expect(obj.getOrderFromTags('#input')).toStrictEqual(1); 14 | expect(() => obj.getOrderFromTags('#output')).toThrow(); 15 | }); 16 | }); 17 | -------------------------------------------------------------------------------- /test/types/protocol.test.js: -------------------------------------------------------------------------------- 1 | import Protocol from '../../src/types/protocol'; 2 | 3 | describe('Protocol', () => { 4 | test('can be properly constructed', () => { 5 | const planAssignments = [ 6 | ['worker1', '1234'], 7 | ['worker2', '3456'] 8 | ]; 9 | const obj = new Protocol( 10 | 123, 11 | ['tag1', 'tag2'], 12 | 'desc', 13 | planAssignments, 14 | true 15 | ); 16 | expect(obj.id).toStrictEqual(123); 17 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']); 18 | expect(obj.description).toStrictEqual('desc'); 19 | 
expect(obj.plans).toStrictEqual(planAssignments); 20 | expect(obj.workersResolved).toStrictEqual(true); 21 | }); 22 | }); 23 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Run tests and coverage 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Use Node.js ${{ matrix.node-version }} 18 | uses: actions/setup-node@v1 19 | with: 20 | node-version: ${{ matrix.node-version }} 21 | - name: Install dependencies 22 | run: | 23 | npm install 24 | sudo npm install -g codecov 25 | - name: Test with npm 26 | run: | 27 | npm run test 28 | - name: Test code coverage 29 | run: | 30 | codecov 31 | -------------------------------------------------------------------------------- /examples/mnist/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "syftjs-mnist", 3 | "version": "1.0.0", 4 | "private": true, 5 | "description": "", 6 | "main": "index.js", 7 | "scripts": { 8 | "start": "webpack-dev-server --mode development", 9 | "build": "rm -rf dist && webpack --mode production", 10 | "test": "echo \"Error: no test specified\" && exit 1" 11 | }, 12 | "author": "", 13 | "license": "", 14 | "dependencies": { 15 | "@openmined/syft.js": "github:openmined/syft.js" 16 | }, 17 | "devDependencies": { 18 | "babel-loader": "^8.0.6", 19 | "html-webpack-plugin": "^3.2.0", 20 | "path": "^0.12.7", 21 | "regenerator-runtime": "^0.13.3", 22 | "webpack": "^4.39.1", 23 | "webpack-cli": "^3.3.7", 24 | "webpack-dev-server": "^3.7.2" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/types/protocol.js: 
-------------------------------------------------------------------------------- 1 | import { getPbId } from '../protobuf'; 2 | 3 | export default class Protocol { 4 | constructor(id, tags, description, planAssigments, workersResolved) { 5 | this.id = id; 6 | this.tags = tags; 7 | this.description = description; 8 | this.plans = planAssigments; 9 | this.workersResolved = workersResolved; 10 | } 11 | 12 | static unbufferize(worker, pb) { 13 | const planAssignments = []; 14 | if (pb.plan_assignments) { 15 | pb.plan_assignments.forEach(item => { 16 | planAssignments.push([getPbId(item.worker_id), getPbId(item.plan_id)]); 17 | }); 18 | } 19 | return new Protocol( 20 | getPbId(pb.id), 21 | pb.tags, 22 | pb.description, 23 | planAssignments, 24 | pb.workers_resolved 25 | ); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 14 | 15 | 21 | 22 | Syft.js Multi-armed Bandit Example 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | # syft.js MNIST Example 2 | 3 | This is a demonstration of how to use [syft.js](https://github.com/openmined/syft.js) 4 | with [PyGrid](https://github.com/OpenMined/pygrid) to train a plan on local data in the browser. 5 | 6 | ## Quick Start 7 | 8 | 1. Install and start [PyGrid](https://github.com/OpenMined/pygrid) 9 | 2. Install [PySyft](https://github.com/OpenMined/PySyft) and [execute the "Part 01 - Create Plan" notebook](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/static-fl/Part%2001%20-%20Create%20Plan.ipynb) from the `examples/tutorials/static-fl` folder to seed the MNIST plan and model into PyGrid. 10 | 3. Now, back in this folder, execute `npm install` 11 | 4. Then execute `npm start` 12 | 13 | This will launch a web browser running the file `index.js`. Every time you make changes to this file, the server will automatically recompile your code and refresh the page. No need to stop and restart it.
:) 14 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "syftjs-multi-armed-bandit", 3 | "version": "1.0.0", 4 | "private": true, 5 | "description": "", 6 | "main": "index.js", 7 | "scripts": { 8 | "start": "webpack-dev-server --mode development", 9 | "build": "rm -rf dist && webpack --mode production", 10 | "test": "echo \"Error: no test specified\" && exit 1" 11 | }, 12 | "author": "", 13 | "license": "", 14 | "dependencies": { 15 | "@emotion/core": "^10.0.28", 16 | "@fortawesome/fontawesome-free": "^5.13.1", 17 | "@openmined/syft.js": "github:openmined/syft.js", 18 | "react": "^16.13.1", 19 | "react-dom": "^16.13.1" 20 | }, 21 | "devDependencies": { 22 | "babel-loader": "^8.0.6", 23 | "html-webpack-plugin": "^3.2.0", 24 | "path": "^0.12.7", 25 | "regenerator-runtime": "^0.13.3", 26 | "svg-url-loader": "^6.0.0", 27 | "webpack": "^4.39.1", 28 | "webpack-cli": "^3.3.7", 29 | "webpack-dev-server": "^3.7.2" 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /test/data/dummy.tpl.js: -------------------------------------------------------------------------------- 1 | export const MNIST_BATCH_SIZE = parseInt('%MNIST_BATCH_SIZE%'); 2 | export const MNIST_LR = parseFloat('%MNIST_LR%'); 3 | export const MNIST_PLAN = '%MNIST_PLAN%'; 4 | export const MNIST_BATCH_DATA = '%MNIST_BATCH_DATA%'; 5 | export const MNIST_MODEL_PARAMS = '%MNIST_MODEL_PARAMS%'; 6 | export const MNIST_UPD_MODEL_PARAMS = '%MNIST_UPD_MODEL_PARAMS%'; 7 | export const MNIST_LOSS = parseFloat('%MNIST_LOSS%'); 8 | export const MNIST_ACCURACY = parseFloat('%MNIST_ACCURACY%'); 9 | 10 | export const PLAN_WITH_STATE = '%PLAN_WITH_STATE%'; 11 | 12 | export const BANDIT_SIMPLE_PLAN = '%BANDIT_SIMPLE_PLAN%'; 13 | export const BANDIT_SIMPLE_MODEL_PARAMS = '%BANDIT_SIMPLE_MODEL_PARAMS%'; 14 | 
export const BANDIT_THOMPSON_PLAN = '%BANDIT_THOMPSON_PLAN%'; 15 | export const BANDIT_THOMPSON_MODEL_PARAMS = '%BANDIT_THOMPSON_MODEL_PARAMS%'; 16 | 17 | export const PROTOCOL = 18 | 'CgYIjcivoCUqEwoGCIHIr6AlEgkSB3dvcmtlcjEqEwoGCIXIr6AlEgkSB3dvcmtlcjIqEwoGCInIr6AlEgkSB3dvcmtlcjM='; 19 | -------------------------------------------------------------------------------- /src/types/state.js: -------------------------------------------------------------------------------- 1 | import { protobuf, unbufferize } from '../protobuf'; 2 | 3 | export class State { 4 | constructor(placeholders = null, tensors = null) { 5 | this.placeholders = placeholders; 6 | this.tensors = tensors; 7 | } 8 | 9 | getTfTensors() { 10 | return this.tensors.map(t => t.toTfTensor()); 11 | } 12 | 13 | static unbufferize(worker, pb) { 14 | const tensors = pb.tensors.map(stateTensor => { 15 | // unwrap StateTensor 16 | return unbufferize(worker, stateTensor[stateTensor.tensor]); 17 | }); 18 | 19 | return new State(unbufferize(worker, pb.placeholders), tensors); 20 | } 21 | 22 | bufferize(worker) { 23 | const tensorsPb = this.tensors.map(tensor => 24 | protobuf.syft_proto.execution.v1.StateTensor.create({ 25 | torch_tensor: tensor.bufferize(worker) 26 | }) 27 | ); 28 | const placeholdersPb = this.placeholders.map(ph => ph.bufferize()); 29 | return protobuf.syft_proto.execution.v1.State.create({ 30 | placeholders: placeholdersPb, 31 | tensors: tensorsPb 32 | }); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/object-registry.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs-core'; 2 | 3 | export default class ObjectRegistry { 4 | constructor() { 5 | this.objects = {}; 6 | this.gc = {}; 7 | } 8 | 9 | set(id, obj, gc = false) { 10 | if (this.objects[id] instanceof tf.Tensor) { 11 | this.objects[id].dispose(); 12 | delete this.objects[id]; 13 | } 14 | this.objects[id] = obj; 
15 | this.gc[id] = gc; 16 | } 17 | 18 | setGc(id, gc) { 19 | this.gc[id] = gc; 20 | } 21 | 22 | get(id) { 23 | return this.objects[id]; 24 | } 25 | 26 | has(id) { 27 | return Object.hasOwnProperty.call(this.objects, id); 28 | } 29 | 30 | clear() { 31 | for (let key of Object.keys(this.objects)) { 32 | if (this.gc[key] && this.objects[key] instanceof tf.Tensor) { 33 | this.objects[key].dispose(); 34 | } 35 | } 36 | this.objects = {}; 37 | this.gc = {}; 38 | } 39 | 40 | load(objectRegistry) { 41 | for (let key of Object.keys(objectRegistry.objects)) { 42 | this.set(key, objectRegistry.get(key)); 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/types/plan.js: -------------------------------------------------------------------------------- 1 | import { getPbId, unbufferize } from '../protobuf'; 2 | 3 | /** 4 | * PySyft Plan. 5 | */ 6 | export class Plan { 7 | /** 8 | * @hideconstructor 9 | */ 10 | constructor(id, name, role = [], tags = [], description = null) { 11 | this.id = id; 12 | this.name = name; 13 | this.role = role; 14 | this.tags = tags; 15 | this.description = description; 16 | } 17 | 18 | /** 19 | * @private 20 | * @returns {Plan} 21 | */ 22 | static unbufferize(worker, pb) { 23 | const id = getPbId(pb.id); 24 | 25 | return new Plan( 26 | id, 27 | pb.name, 28 | unbufferize(worker, pb.role), 29 | pb.tags, 30 | pb.description 31 | ); 32 | } 33 | 34 | /** 35 | * Executes the Plan and returns its output. 36 | * 37 | * The order, type and number of arguments must match to arguments defined in the PySyft Plan. 
38 | * 39 | * @param {Syft} worker 40 | * @param {...(tf.Tensor|number)} data 41 | * @returns {Promise>} 42 | */ 43 | async execute(worker, ...data) { 44 | return this.role.execute(worker, ...data); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /examples/mnist/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | 4 | module.exports = (env, argv) => ({ 5 | mode: argv.mode, 6 | entry: ['regenerator-runtime/runtime', './index.js'], 7 | output: { 8 | path: path.join(__dirname, '/dist'), 9 | filename: 'index.bundle.js' 10 | }, 11 | devtool: argv.mode === 'development' ? 'eval-source-map' : 'source-map', 12 | devServer: { 13 | port: 8080, 14 | hot: true, 15 | open: true, 16 | stats: { 17 | children: false, // Hide children information 18 | maxModules: 0 // Set the maximum number of modules to be shown 19 | } 20 | }, 21 | module: { 22 | rules: [ 23 | { 24 | test: /\.js$/, 25 | exclude: /node_modules/, 26 | use: { 27 | loader: 'babel-loader', 28 | options: { 29 | presets: ['@babel/preset-env'], 30 | plugins: ['@babel/plugin-proposal-class-properties'] 31 | } 32 | } 33 | } 34 | ] 35 | }, 36 | plugins: [new HtmlWebpackPlugin({ template: './index.html' })], 37 | externals: { 38 | '@tensorflow/tfjs-core': 'tf' 39 | } 40 | }); 41 | -------------------------------------------------------------------------------- /src/data-channel-message-queue.js: -------------------------------------------------------------------------------- 1 | import EventObserver from './events'; 2 | 3 | export default class DataChannelMessageQueue { 4 | constructor() { 5 | this.messages = new Map(); 6 | this.observer = new EventObserver(); 7 | } 8 | 9 | /** 10 | * Register new message 11 | * @param {DataChannelMessage} message 12 | */ 13 | register(message) { 14 | if (this.isRegistered(message.id)) { 15 | 
return false; 16 | } 17 | this.messages.set(message.id, message); 18 | message.once('ready', this.onMessageReady.bind(this)); 19 | } 20 | 21 | unregister(message) { 22 | this.messages.delete(message.id); 23 | } 24 | 25 | /** 26 | * Check if registered by message id 27 | * @param {number} id 28 | */ 29 | isRegistered(id) { 30 | return this.messages.has(id); 31 | } 32 | 33 | /** 34 | * Check if registered by message id 35 | * @param {number} id 36 | */ 37 | getById(id) { 38 | return this.messages.get(id); 39 | } 40 | 41 | /** 42 | * 43 | * @param {DataChannelMessage} message 44 | */ 45 | onMessageReady(message) { 46 | this.unregister(message); 47 | this.observer.broadcast('message', message); 48 | } 49 | 50 | on(event, func) { 51 | this.observer.subscribe(event, func); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /rollup.config.js: -------------------------------------------------------------------------------- 1 | import babel from '@rollup/plugin-babel'; 2 | import builtins from '@joseph184/rollup-plugin-node-builtins'; 3 | import json from '@rollup/plugin-json'; 4 | import peerDepsExternal from 'rollup-plugin-peer-deps-external'; 5 | import resolve from '@rollup/plugin-node-resolve'; 6 | import commonjs from '@rollup/plugin-commonjs'; 7 | import filesize from 'rollup-plugin-filesize'; 8 | 9 | import pkg from './package.json'; 10 | 11 | const sharedOutput = { 12 | name: 'syft', 13 | sourcemap: true, 14 | exports: 'named', 15 | globals: { 16 | '@tensorflow/tfjs-core': 'tf' 17 | } 18 | }; 19 | 20 | export default { 21 | input: 'src/index.js', 22 | output: [ 23 | { 24 | file: pkg.browser, 25 | format: 'umd', 26 | ...sharedOutput 27 | }, 28 | { 29 | file: pkg.main, 30 | format: 'cjs', 31 | ...sharedOutput 32 | }, 33 | { 34 | file: pkg.module, 35 | format: 'es', 36 | ...sharedOutput 37 | } 38 | ], 39 | plugins: [ 40 | builtins(), 41 | json(), 42 | peerDepsExternal(), 43 | babel({ 44 | babelHelpers: 'bundled', 45 | 
exclude: 'node_modules/**' 46 | }), 47 | resolve({ 48 | preferBuiltins: true 49 | }), 50 | commonjs(), 51 | filesize() 52 | ] 53 | }; 54 | -------------------------------------------------------------------------------- /src/_errors.js: -------------------------------------------------------------------------------- 1 | export const NO_DETAILER = d => 2 | `Serialized object contains type that may exist in PySyft, but is not currently supported in syft.js. Please file a feature request (https://github.com/OpenMined/syft.js/issues) for type ${d}.`; 3 | 4 | export const NOT_ENOUGH_ARGS = (passed, expected) => 5 | `You have passed ${passed} argument(s) when the plan requires ${expected} argument(s).`; 6 | 7 | export const MISSING_VARIABLE = () => 8 | `Command requires variable that is missing.`; 9 | 10 | export const NO_PLAN = `The operation you're attempting to run requires a plan before being called.`; 11 | 12 | export const PLAN_ALREADY_COMPLETED = (name, id) => 13 | `You have already executed the plan named "${name}" with id "${id}".`; 14 | 15 | export const CANNOT_FIND_COMMAND = command => 16 | `Command ${command} not found in in TensorFlow.js.`; 17 | 18 | export const GRID_UNKNOWN_CYCLE_STATUS = status => 19 | `Unknown cycle status: ${status}`; 20 | 21 | export const GRID_ERROR = status => `Grid error: ${status}`; 22 | 23 | export const MODEL_LOAD_FAILED = status => `Failed to load Model: ${status}`; 24 | 25 | export const PLAN_LOAD_FAILED = (planName, status) => 26 | `Failed to load '${planName}' Plan: ${status}`; 27 | 28 | export const PROTOBUF_UNSERIALIZE_FAILED = (pbType, status) => 29 | `Failed to unserialize binary protobuf data into ${pbType}: ${status}`; 30 | -------------------------------------------------------------------------------- /test/events.test.js: -------------------------------------------------------------------------------- 1 | import EventObserver from '../src/events'; 2 | 3 | describe('Event Observer', () => { 4 | jest.spyOn(console, 
'log'); 5 | 6 | afterEach(() => { 7 | jest.clearAllMocks(); 8 | }); 9 | 10 | test('can subscribe to and broadcast an event', () => { 11 | const observer = new EventObserver(); 12 | const name = 'my-thing'; 13 | const func = data => { 14 | console.log('hello', data); 15 | }; 16 | const myData = { awesome: true }; 17 | 18 | expect(observer.observers.length).toBe(0); 19 | 20 | observer.subscribe(name, func); 21 | observer.subscribe(`${name}-other`, func); 22 | 23 | expect(observer.observers.length).toBe(2); 24 | expect(console.log.mock.calls.length).toBe(0); 25 | 26 | observer.broadcast(name, myData); 27 | 28 | expect(console.log.mock.calls.length).toBe(1); 29 | expect(console.log.mock.calls[0][0]).toBe('hello'); 30 | expect(console.log.mock.calls[0][1]).toBe(myData); 31 | }); 32 | 33 | test('can unsubscribe to an event', () => { 34 | const observer = new EventObserver(), 35 | myType = 'hello'; 36 | 37 | expect(observer.observers.length).toBe(0); 38 | 39 | observer.subscribe(myType, () => console.log('awesome')); 40 | 41 | expect(observer.observers.length).toBe(1); 42 | 43 | observer.unsubscribe(myType); 44 | 45 | expect(observer.observers.length).toBe(0); 46 | }); 47 | }); 48 | -------------------------------------------------------------------------------- /src/_constants.js: -------------------------------------------------------------------------------- 1 | // Sockets 2 | export const SOCKET_STATUS = 'socket-status'; 3 | export const SOCKET_PING = 'socket-ping'; 4 | 5 | // Grid 6 | export const GET_PROTOCOL = 'get-protocol'; 7 | export const CYCLE_STATUS_ACCEPTED = 'accepted'; 8 | export const CYCLE_STATUS_REJECTED = 'rejected'; 9 | 10 | // WebRTC 11 | export const WEBRTC_JOIN_ROOM = 'webrtc: join-room'; 12 | export const WEBRTC_INTERNAL_MESSAGE = 'webrtc: internal-message'; 13 | export const WEBRTC_PEER_LEFT = 'webrtc: peer-left'; 14 | 15 | // WebRTC: Data Channel 16 | export const WEBRTC_DATACHANNEL_CHUNK_SIZE = 64 * 1024; 17 | export const 
WEBRTC_DATACHANNEL_MAX_BUFFER = 4 * 1024 * 1024; 18 | export const WEBRTC_DATACHANNEL_BUFFER_TIMEOUT = 2000; 19 | export const WEBRTC_DATACHANNEL_MAX_BUFFER_TIMEOUTS = 5; 20 | 21 | export const WEBRTC_PEER_CONFIG = { 22 | iceServers: [ 23 | { 24 | urls: [ 25 | 'stun:stun.l.google.com:19302', 26 | 'stun:stun1.l.google.com:19302', 27 | 'stun:stun2.l.google.com:19302' // FF says too many stuns are bad, don't send more than this 28 | ] 29 | } 30 | ] 31 | }; 32 | 33 | export const WEBRTC_PEER_OPTIONS = { 34 | optional: [ 35 | { DtlsSrtpKeyAgreement: true } // Required for connection between Chrome and Firefox 36 | // FF works w/o this option, but Chrome fails with it 37 | // { RtpDataChannels: true } // Required in Firefox to use the DataChannels API 38 | ] 39 | }; 40 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | 4 | module.exports = (env, argv) => ({ 5 | mode: argv.mode, 6 | entry: ['regenerator-runtime/runtime', './index.js'], 7 | output: { 8 | path: path.join(__dirname, '/dist'), 9 | filename: 'index.bundle.js' 10 | }, 11 | devtool: argv.mode === 'development' ? 
'eval-source-map' : 'source-map', 12 | devServer: { 13 | port: 8080, 14 | hot: true, 15 | open: true, 16 | stats: { 17 | children: false, // Hide children information 18 | maxModules: 0 // Set the maximum number of modules to be shown 19 | } 20 | }, 21 | module: { 22 | rules: [ 23 | { 24 | test: /\.js$/, 25 | exclude: /node_modules/, 26 | use: { 27 | loader: 'babel-loader', 28 | options: { 29 | presets: ['@babel/preset-env', '@babel/preset-react'], 30 | plugins: ['@babel/plugin-proposal-class-properties'] 31 | } 32 | } 33 | }, 34 | { 35 | test: /\.svg$/, 36 | use: [ 37 | { 38 | loader: 'svg-url-loader', 39 | options: { 40 | limit: 10000 41 | } 42 | } 43 | ] 44 | } 45 | ] 46 | }, 47 | plugins: [new HtmlWebpackPlugin({ template: './index.html' })], 48 | externals: { 49 | '@tensorflow/tfjs-core': 'tf' 50 | } 51 | }); 52 | -------------------------------------------------------------------------------- /src/types/message.js: -------------------------------------------------------------------------------- 1 | import { unbufferize } from '../protobuf'; 2 | import Logger from '../logger'; 3 | 4 | export class Message { 5 | constructor(contents) { 6 | if (contents) { 7 | this.contents = contents; 8 | } 9 | this.logger = new Logger(); 10 | } 11 | } 12 | 13 | export class ObjectMessage extends Message { 14 | constructor(contents) { 15 | super(contents); 16 | } 17 | 18 | static unbufferize(worker, pb) { 19 | const tensor = unbufferize(worker, pb.tensor); 20 | return new ObjectMessage(tensor); 21 | } 22 | } 23 | 24 | // TODO when types will be availbale in protobuf 25 | 26 | /* 27 | export class ObjectRequestMessage extends Message { 28 | constructor(contents) { 29 | super(contents); 30 | } 31 | } 32 | 33 | export class IsNoneMessage extends Message { 34 | constructor(contents) { 35 | super(contents); 36 | } 37 | } 38 | 39 | export class GetShapeMessage extends Message { 40 | constructor(contents) { 41 | super(contents); 42 | } 43 | } 44 | 45 | export class 
ForceObjectDeleteMessage extends Message { 46 | constructor(contents) { 47 | super(contents); 48 | } 49 | } 50 | 51 | export class SearchMessage extends Message { 52 | constructor(contents) { 53 | super(contents); 54 | } 55 | } 56 | 57 | export class PlanCommandMessage extends Message { 58 | constructor(commandName, message) { 59 | super(); 60 | 61 | this.commandName = commandName; 62 | this.message = message; 63 | } 64 | } 65 | */ 66 | -------------------------------------------------------------------------------- /src/protobuf/mapping.js: -------------------------------------------------------------------------------- 1 | import { protobuf } from 'syft-proto'; 2 | import Protocol from '../types/protocol'; 3 | import { Plan } from '../types/plan'; 4 | import { Role } from '../types/role'; 5 | import { State } from '../types/state'; 6 | import { ObjectMessage } from '../types/message'; 7 | import { TorchParameter, TorchTensor } from '../types/torch'; 8 | import { Placeholder, PlaceholderId } from '../types/placeholder'; 9 | import { ComputationAction } from '../types/computation-action'; 10 | 11 | let PB_CLASS_MAP, PB_TO_UNBUFFERIZER; // assigned by initMappings(); the export below exposes them as live bindings 12 | 13 | // Because of cyclic imports between the type modules (Protocol, etc.) and the protobuf module, 14 | // those classes are still undefined when this module is evaluated, so the tables are built lazily in initMappings(). 15 | export const initMappings = () => { 16 | PB_CLASS_MAP = [ // [syft.js class, protobuf message type] pairs 17 | [Protocol, protobuf.syft_proto.execution.v1.Protocol], 18 | [Plan, protobuf.syft_proto.execution.v1.Plan], 19 | [Role, protobuf.syft_proto.execution.v1.Role], 20 | [State, protobuf.syft_proto.execution.v1.State], 21 | [ComputationAction, protobuf.syft_proto.execution.v1.ComputationAction], 22 | [Placeholder, protobuf.syft_proto.execution.v1.Placeholder], 23 | [PlaceholderId, protobuf.syft_proto.execution.v1.PlaceholderId], 24 | [ObjectMessage, protobuf.syft_proto.messaging.v1.ObjectMessage], 25 | [TorchTensor, protobuf.syft_proto.types.torch.v1.TorchTensor], 26 | [TorchParameter,
protobuf.syft_proto.types.torch.v1.Parameter] 27 | ]; 28 | 29 | PB_TO_UNBUFFERIZER = PB_CLASS_MAP.reduce((map, item) => { 30 | map[item[1]] = item[0].unbufferize; // NOTE: a function used as a plain-object key is converted to its source-text string, so lookups rely on those strings being distinct 31 | return map; 32 | }, {}); 33 | }; 34 | 35 | export { PB_CLASS_MAP, PB_TO_UNBUFFERIZER }; 36 | -------------------------------------------------------------------------------- /test/types/message.test.js: -------------------------------------------------------------------------------- 1 | import { Message, ObjectMessage } from '../../src/types/message'; 2 | import { TorchTensor } from '../../src/types/torch'; 3 | 4 | describe('Message', () => { 5 | test('can be properly constructed', () => { 6 | const content = 'abc'; 7 | const obj = new Message(content); 8 | expect(obj.contents).toBe(content); 9 | }); 10 | }); 11 | 12 | describe('ObjectMessage', () => { 13 | test('can be properly constructed', () => { 14 | const tensor = new TorchTensor( 15 | 123, 16 | new Float32Array([1, 2, 3, 4]), 17 | [2, 2], 18 | 'float32' 19 | ); 20 | const obj = new ObjectMessage(tensor); 21 | expect(obj.contents).toBe(tensor); 22 | }); 23 | }); 24 | 25 | describe('ObjectRequestMessage', () => { 26 | test('can be properly
constructed', () => { 27 | // TODO when type is available in protobuf 28 | }); 29 | }); 30 | 31 | describe('IsNoneMessage', () => { 32 | test('can be properly constructed', () => { 33 | // TODO when type is available in protobuf 34 | }); 35 | }); 36 | 37 | describe('GetShapeMessage', () => { 38 | test('can be properly constructed', () => { 39 | // TODO when type is available in protobuf 40 | }); 41 | }); 42 | 43 | describe('ForceObjectDeleteMessage', () => { 44 | test('can be properly constructed', () => { 45 | // TODO when type is available in protobuf 46 | }); 47 | }); 48 | 49 | describe('SearchMessage', () => { 50 | test('can be properly constructed', () => { 51 | // TODO when type is available in protobuf 52 | }); 53 | }); 54 | 55 | describe('PlanCommandMessage', () => { 56 | test('can be properly constructed', () => { 57 | // TODO when type is available in protobuf 58 | }); 59 | }); 60 | -------------------------------------------------------------------------------- /test/mocks/webrtc.js: -------------------------------------------------------------------------------- 1 | // Minimal test stand-ins for the browser WebRTC classes (RTCPeerConnection, RTCSessionDescription, RTCIceCandidate) 2 | export class RTCPeerConnection { 3 | constructor(options, optional) { 4 | this.options = options; 5 | this.optional = optional; 6 | this.localDescription = null; 7 | this.remoteDescription = null; 8 | this.iceCandidates = []; 9 | this.sentMessages = []; 10 | this.ondatachannelListener = null; 11 | this.onicecandidateListener = null; 12 | } 13 | 14 | createDataChannel() { 15 | return {}; 16 | } 17 | 18 | async createOffer() { 19 | return Promise.resolve({ type: 'offer', sdp: 'testOfferSdp' }); 20 | } 21 | 22 | async createAnswer() { 23 | return Promise.resolve({ type: 'answer', sdp: 'testAnswerSdp' }); 24 | } 25 | 26 | async setLocalDescription(sessionDescription) { 27 | this.localDescription = sessionDescription; 28 | return Promise.resolve(); 29 | } 30 | 31 | async setRemoteDescription(sessionDescription) { 32 | this.remoteDescription = sessionDescription; 33 | return Promise.resolve(); 34 | } 35 | 36 | async addIceCandidate(iceCandidate) { 37 | this.iceCandidates.push(iceCandidate); 38 | return Promise.resolve(); 39 | } 40 | 41 | get onicecandidate() { 42 | return this.onicecandidateListener; 43 | } 44 | 45 | set onicecandidate(cb) { 46 | this.onicecandidateListener = cb; 47 | } 48 | 49 | get ondatachannel() { 50 | return this.ondatachannelListener; 51 | } 52 | 53 | set ondatachannel(cb) { 54 | this.ondatachannelListener = cb; 55 | } 56 | } 57 | 58 | export class RTCSessionDescription { 59 | constructor(sessionDescription) { 60 | Object.assign(this, sessionDescription); 61 | } 62 | } 63 | 64 | export class RTCIceCandidate { 65 | constructor(iceCandidate) { 66 | Object.assign(this, iceCandidate); 67 | } 68 | } 69 |
-------------------------------------------------------------------------------- /src/types/placeholder.js: -------------------------------------------------------------------------------- 1 | import { protobuf, getPbId, pbId } from '../protobuf'; 2 | 3 | export class PlaceholderId { 4 | constructor(id) { 5 | this.id = id; 6 | } 7 | 8 | static unbufferize(worker, pb) { 9 | return new PlaceholderId(getPbId(pb.id)); 10 | } 11 | 12 | bufferize(/* worker */) { 13 | return protobuf.syft_proto.execution.v1.PlaceholderId.create({ 14 | id: pbId(this.id) 15 | }); 16 | } 17 | } 18 | 19 | export class Placeholder { 20 | constructor(id, tags = [], description = null, expected_shape = null) { 21 | this.id = id; 22 | this.tags = tags; 23 | this.description = description; 24 | this.expected_shape = expected_shape; 25 | } 26 | 27 | static unbufferize(worker, pb) { 28 | let expected_shape = null; 29 | if ( 30 | pb.expected_shape && 31 | Array.isArray(pb.expected_shape.dims) && 32 | pb.expected_shape.dims.length > 0 33 | ) { 34 | // Unwrap Shape 35 | expected_shape = pb.expected_shape.dims; 36 | } 37 | 38 | return new Placeholder( 39 | getPbId(pb.id), 40 | pb.tags || [], 41 | pb.description, 42 | expected_shape 43 | ); 44 | } 45 | 46 | bufferize(/* worker */) { 47 | return protobuf.syft_proto.execution.v1.Placeholder.create({ 48 | id: pbId(this.id), 49 | tags: this.tags, 50 | description: this.description, 51 | expected_shape: protobuf.syft_proto.types.syft.v1.Shape.create( 52 | this.expected_shape 53 | ) 54 | }); 55 | } 56 | 57 | getOrderFromTags(prefix) { 58 | const regExp = new RegExp(`^${prefix}-(\\d+)$`, 'i'); 59 | for (let tag of this.tags) { 60 | let tagMatch = regExp[Symbol.match](tag); 61 | if (tagMatch) { 62 | return Number(tagMatch[1]); 63 | } 64 | } 65 | throw new Error(`Placeholder ${this.id} doesn't have order tag ${prefix}`); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/README.md: 
-------------------------------------------------------------------------------- 1 | # syft.js Multi-armed Bandit Example 2 | 3 | This is a demonstration of how to use [syft.js](https://github.com/openmined/syft.js) 4 | with [PyGrid](https://github.com/OpenMined/pygrid) to train [a multi-armed bandit](https://vwo.com/blog/multi-armed-bandit-algorithm/) in the browser. A multi-armed bandit can be used to perform automated A/B testing on a website or application, while gradually converging your users to the ideal user experience, given a goal of your choosing. 5 | 6 | In this demo, we're automatically generating various website layouts that we want our users to view. Subtle changes are made to the website every time you reload it, such as changes in button size, color, or position on the page. In the background, syft.js tracks the layouts on which the user does what we want (clicks a button) and reports a positive model diff for that particular layout. For all other layouts, where the user doesn't click the button, we do not report anything. Over time, our model will slowly converge on a "preferred user experience": the best layout, as chosen by user actions. 7 | 8 | While this demo is inherently simple, it's easy to see how one could extend it to a real-world application in which website layouts are generated and tested by real users, slowly converging to the preferred UX. We're particularly excited to see derivations of this demo in real-world web and mobile development! 9 | 10 | ## Quick Start 11 | 12 | 1. Install and start [PyGrid](https://github.com/OpenMined/pygrid) 13 | 2. Install [PySyft](https://github.com/OpenMined/PySyft) and ... TODO: @maddie - fill this in here 14 | 3. Now back in this folder, execute `npm install` 15 | 4. And then execute `npm start` 16 | 17 | This will launch a web browser running the file `index.js`.
Every time you make changes to this file, the server will automatically re-compile your code and refresh the page. No need to start and stop. :) 18 | -------------------------------------------------------------------------------- /src/syft-model.js: -------------------------------------------------------------------------------- 1 | import { unserialize, protobuf, serialize } from './protobuf'; 2 | import { State } from './types/state'; 3 | import { TorchTensor } from './types/torch'; 4 | import { Placeholder } from './types/placeholder'; 5 | import { MODEL_LOAD_FAILED } from './_errors'; 6 | 7 | /** 8 | * Model parameters as stored in the PyGrid. 9 | * 10 | * @property {Array.} params Array of Model parameters. 11 | */ 12 | export default class SyftModel { 13 | /** 14 | * @hideconstructor 15 | * @param {Object} options 16 | * @param {Syft} options.worker Instance of Syft client. 17 | * @param {ArrayBuffer} options.modelData Serialized Model parameters as returned by PyGrid. 18 | */ 19 | constructor({ worker, modelData }) { 20 | try { 21 | const state = unserialize( 22 | worker, 23 | modelData, 24 | protobuf.syft_proto.execution.v1.State 25 | ); 26 | this.worker = worker; 27 | this.params = state.getTfTensors(); 28 | } catch (e) { 29 | throw new Error(MODEL_LOAD_FAILED(e.message)); 30 | } 31 | } 32 | 33 | /** 34 | * Calculates difference between 2 versions of the Model parameters 35 | * and returns serialized `diff` that can be submitted to PyGrid. 36 | * 37 | * @param {Array.} updatedModelParams Array of model parameters (tensors). 38 | * @returns {Promise} Protobuf-serialized `diff`. 
39 | */ 40 | async createSerializedDiff(updatedModelParams) { 41 | const placeholders = [], 42 | tensors = []; 43 | 44 | for (let i = 0; i < updatedModelParams.length; i++) { 45 | let paramDiff = this.params[i].sub(updatedModelParams[i]); 46 | placeholders.push(new Placeholder(i, [`#${i}`, `#state-${i}`])); 47 | tensors.push(await TorchTensor.fromTfTensor(paramDiff)); 48 | } 49 | const state = new State(placeholders, tensors); 50 | const bin = serialize(this.worker, state); 51 | 52 | // Free memory. 53 | tensors.forEach(t => t._tfTensor.dispose()); 54 | 55 | return bin; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/sockets.js: -------------------------------------------------------------------------------- 1 | import { SOCKET_PING } from './_constants'; 2 | import Logger from './logger'; 3 | 4 | export default class Sockets { 5 | constructor({ 6 | url, 7 | workerId, 8 | onOpen, 9 | onClose, 10 | onMessage, 11 | keepAliveTimeout = 20000 12 | }) { 13 | this.logger = new Logger(); 14 | const socket = new WebSocket(url); 15 | 16 | const keepAlive = () => { 17 | this.send(SOCKET_PING); 18 | this.timerId = setTimeout(keepAlive, keepAliveTimeout); 19 | }; 20 | 21 | const cancelKeepAlive = () => { 22 | clearTimeout(this.timerId); 23 | this.timerId = null; 24 | }; 25 | 26 | socket.onopen = event => { 27 | this.logger.log( 28 | `Opening socket connection at ${event.currentTarget.url}`, 29 | event 30 | ); 31 | 32 | keepAlive(); 33 | 34 | if (onOpen) onOpen(event); 35 | }; 36 | 37 | socket.onclose = event => { 38 | this.logger.log( 39 | `Closing socket connection at ${event.currentTarget.url}`, 40 | event 41 | ); 42 | 43 | cancelKeepAlive(); 44 | 45 | if (onClose) onClose(event); 46 | }; 47 | 48 | this.url = url; 49 | this.workerId = workerId; 50 | this.socket = socket; 51 | this.onMessage = onMessage; 52 | this.timerId = null; 53 | } 54 | 55 | send(type, data = {}) { 56 | return new Promise((resolve, reject) => { 57 | 
data.workerId = this.workerId; 58 | 59 | const message = { type, data }; 60 | 61 | this.logger.log('Sending message', message); 62 | 63 | this.socket.send(JSON.stringify(message)); 64 | 65 | this.socket.onmessage = event => { 66 | const data = JSON.parse(event.data); 67 | 68 | this.logger.log('Receiving message', data); 69 | 70 | resolve(this.onMessage(data)); 71 | }; 72 | 73 | this.socket.onerror = event => { 74 | this.logger.log('We have a socket error!', event); 75 | 76 | reject(event); 77 | }; 78 | }); 79 | } 80 | 81 | stop() { 82 | this.socket.close(); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /test/types/torch.test.js: -------------------------------------------------------------------------------- 1 | import { TorchParameter, TorchTensor, TorchSize } from '../../src/types/torch'; 2 | import * as tf from '@tensorflow/tfjs-core'; 3 | 4 | describe('TorchTensor', () => { 5 | test('can be properly constructed', () => { 6 | const chain = new TorchTensor( 7 | 555, 8 | new Float32Array([6, 6, 6, 6]), 9 | [2, 2], 10 | 'float32' 11 | ); 12 | const grad = new TorchTensor( 13 | 666, 14 | new Float32Array([6, 6, 6, 6]), 15 | [2, 2], 16 | 'float32' 17 | ); 18 | const obj = new TorchTensor( 19 | 123, 20 | new Float32Array([1, 2, 3.3, 4]), 21 | [2, 2], 22 | 'float32', 23 | chain, 24 | grad, 25 | ['tag1', 'tag2'], 26 | 'desc' 27 | ); 28 | const tfTensor = tf.tensor([1, 2, 3.3, 4], [2, 2], 'float32'); 29 | 30 | expect(obj.id).toStrictEqual(123); 31 | expect(obj.shape).toStrictEqual([2, 2]); 32 | expect(obj.dtype).toStrictEqual('float32'); 33 | expect(obj.contents).toStrictEqual(new Float32Array([1, 2, 3.3, 4])); 34 | 35 | // resulting TF tensors are equal 36 | expect( 37 | tf 38 | .equal(obj.toTfTensor(), tfTensor) 39 | .all() 40 | .dataSync()[0] 41 | ).toBe(1); 42 | 43 | expect(obj.chain).toBe(chain); 44 | expect(obj.gradChain).toBe(grad); 45 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']); 46 | 
expect(obj.description).toStrictEqual('desc'); 47 | }); 48 | }); 49 | 50 | describe('TorchSize', () => { 51 | test('can be properly constructed', () => { 52 | const obj = new TorchSize([2, 3]); 53 | expect(obj.size).toStrictEqual([2, 3]); 54 | }); 55 | }); 56 | 57 | describe('TorchParameter', () => { 58 | test('can be properly constructed', () => { 59 | const grad = new TorchTensor( 60 | 666, 61 | new Float32Array([6, 6, 6, 6]), 62 | [2, 2], 63 | 'float32' 64 | ); 65 | const tensor = new TorchTensor( 66 | 123, 67 | new Float32Array([1, 2, 3.3, 4]), 68 | [2, 2], 69 | 'float32' 70 | ); 71 | const obj = new TorchParameter(123, tensor, true, grad); 72 | 73 | expect(obj.id).toStrictEqual(123); 74 | expect(obj.tensor).toBe(tensor); 75 | expect(obj.requiresGrad).toStrictEqual(true); 76 | expect(obj.grad).toBe(grad); 77 | }); 78 | }); 79 | -------------------------------------------------------------------------------- /test/data_channel_message_queue.test.js: -------------------------------------------------------------------------------- 1 | import DataChannelMessage from '../src/data-channel-message'; 2 | import DataChannelMessageQueue from '../src/data-channel-message-queue'; 3 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from '../src'; 4 | import { randomFillSync } from 'crypto'; 5 | 6 | describe('Data Channel Message Queue', () => { 7 | afterEach(() => { 8 | jest.clearAllMocks(); 9 | }); 10 | 11 | test('can register a message', () => { 12 | const queue = new DataChannelMessageQueue(); 13 | const message = new DataChannelMessage({ id: 123 }); 14 | queue.register(message); 15 | expect(queue.isRegistered(message.id)).toBe(true); 16 | expect(queue.register(message)).toBe(false); 17 | }); 18 | 19 | test('can unregister a message', () => { 20 | const queue = new DataChannelMessageQueue(); 21 | const message = new DataChannelMessage({ id: 123 }); 22 | queue.register(message); 23 | expect(queue.isRegistered(message.id)).toBe(true); 24 | queue.unregister(message); 25 | 
expect(queue.isRegistered(message.id)).toBe(false); 26 | }); 27 | 28 | test('can get a registered message', () => { 29 | const queue = new DataChannelMessageQueue(); 30 | const message = new DataChannelMessage({ id: 123 }); 31 | queue.register(message); 32 | expect(queue.getById(message.id)).toBe(message); 33 | queue.unregister(message); 34 | expect(queue.getById(message.id)).toBe(undefined); 35 | }); 36 | 37 | test('emits "message" event when the message is ready', done => { 38 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100); 39 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength); 40 | const messageOrig = new DataChannelMessage({ data: buf }); 41 | const chunk1 = messageOrig.getChunk(0); 42 | const chunk2 = messageOrig.getChunk(1); 43 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1); 44 | 45 | const messageAssembled = new DataChannelMessage({ id: info1.id }); 46 | const queue = new DataChannelMessageQueue(); 47 | queue.register(messageAssembled); 48 | 49 | queue.on('message', message => { 50 | expect(message.chunks).toBe(messageOrig.chunks); 51 | expect(message.size).toBe(messageOrig.size); 52 | const orig = new Uint8Array(messageOrig.data); 53 | const assembled = new Uint8Array(message.data); 54 | expect(orig.every((v, i) => assembled[i] === v)).toBe(true); 55 | done(); 56 | }); 57 | 58 | messageAssembled.addChunk(chunk1); 59 | messageAssembled.addChunk(chunk2); 60 | }); 61 | }); 62 | -------------------------------------------------------------------------------- /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "contributors": [ 8 | { 9 | "login": "cereallarceny", 10 | "name": "Patrick Cason", 11 | "avatar_url": "https://avatars1.githubusercontent.com/u/1297930?v=4", 12 | "profile": "https://www.patrickcason.com", 13 | "contributions": [ 14 | "ideas", 15 | "code", 16 | "design", 17 | 
"doc", 18 | "business" 19 | ] 20 | }, 21 | { 22 | "login": "vvmnnnkv", 23 | "name": "Vova Manannikov", 24 | "avatar_url": "https://avatars2.githubusercontent.com/u/12518480?v=4", 25 | "profile": "https://www.linkedin.com/in/vova-manannikov", 26 | "contributions": [ 27 | "code", 28 | "doc", 29 | "test" 30 | ] 31 | }, 32 | { 33 | "login": "Nolski", 34 | "name": "Mike Nolan", 35 | "avatar_url": "https://avatars3.githubusercontent.com/u/2600677?v=4", 36 | "profile": "http://nolski.rocks", 37 | "contributions": [ 38 | "code" 39 | ] 40 | }, 41 | { 42 | "login": "IamRavikantSingh", 43 | "name": "Ravikant Singh", 44 | "avatar_url": "https://avatars3.githubusercontent.com/u/40258150?v=4", 45 | "profile": "http://ravikantsingh.com", 46 | "contributions": [ 47 | "code", 48 | "test" 49 | ] 50 | }, 51 | { 52 | "login": "vkkhare", 53 | "name": "varun khare", 54 | "avatar_url": "https://avatars1.githubusercontent.com/u/18126069?v=4", 55 | "profile": "http://vkkhare.github.io", 56 | "contributions": [ 57 | "code" 58 | ] 59 | }, 60 | { 61 | "login": "pedroespindula", 62 | "name": "Pedro Espíndula", 63 | "avatar_url": "https://avatars1.githubusercontent.com/u/38431219?v=4", 64 | "profile": "https://github.com/pedroespindula", 65 | "contributions": [ 66 | "doc" 67 | ] 68 | }, 69 | { 70 | "login": "Benardi", 71 | "name": "José Benardi de Souza Nunes", 72 | "avatar_url": "https://avatars0.githubusercontent.com/u/9937551?v=4", 73 | "profile": "https://benardi.github.io/myblog/", 74 | "contributions": [ 75 | "test" 76 | ] 77 | }, 78 | { 79 | "login": "tsingh2k15", 80 | "name": "Tajinder Singh", 81 | "avatar_url": "https://avatars1.githubusercontent.com/u/25232829?v=4", 82 | "profile": "http://www.linkedin.com/in/singh-taj", 83 | "contributions": [ 84 | "code" 85 | ] 86 | } 87 | ], 88 | "contributorsPerLine": 7, 89 | "projectName": "syft.js", 90 | "projectOwner": "OpenMined", 91 | "repoType": "github", 92 | "repoHost": "https://github.com", 93 | "skipCi": true 94 | } 95 | 
-------------------------------------------------------------------------------- /test/logger.test.js: -------------------------------------------------------------------------------- 1 | import Logger from '../src/logger'; 2 | 3 | describe('Logger', () => { 4 | jest.spyOn(console, 'log'); 5 | 6 | afterEach(() => { 7 | jest.clearAllMocks(); 8 | Logger.instance = null; 9 | }); 10 | 11 | test('can skip when verbose is false', () => { 12 | const testLogger = new Logger('syft.js', false), 13 | message = 'hello'; 14 | 15 | expect(testLogger.verbose).toBe(false); 16 | expect(console.log.mock.calls.length).toBe(0); 17 | 18 | testLogger.log(message); 19 | 20 | expect(console.log.mock.calls.length).toBe(0); 21 | }); 22 | 23 | test('singleton instance should be used with verbose false', () => { 24 | const testLogger = new Logger('syft.js', false); 25 | const testLogger_1 = new Logger('syft.js', true); 26 | 27 | expect(testLogger).toEqual(testLogger_1); 28 | 29 | expect(testLogger.verbose).toBe(false); 30 | expect(testLogger_1.verbose).toBe(false); 31 | expect(console.log.mock.calls.length).toBe(0); 32 | 33 | testLogger.log('hello singleton!!'); 34 | testLogger_1.log('hello singleton!!'); 35 | 36 | expect(console.log.mock.calls.length).toBe(0); 37 | }); 38 | 39 | test('singleton instance should be used with verbose true', () => { 40 | const testLogger = new Logger('syft.js', true); 41 | const testLogger_1 = new Logger('syft.js', false); 42 | 43 | expect(testLogger).toEqual(testLogger_1); 44 | 45 | expect(testLogger.verbose).toBe(true); 46 | expect(testLogger_1.verbose).toBe(true); 47 | expect(console.log.mock.calls.length).toBe(0); 48 | 49 | testLogger.log('hello singleton!!'); 50 | testLogger_1.log('hello singleton!!'); 51 | 52 | expect(console.log.mock.calls.length).toBe(2); 53 | }); 54 | 55 | test('can log under verbose mode', () => { 56 | const testLogger = new Logger('syft.js', true), 57 | message = 'hello'; 58 | 59 | expect(testLogger.verbose).toBe(true); 60 | 
expect(console.log.mock.calls.length).toBe(0); 61 | 62 | const currentTime = Date.now(); 63 | testLogger.log(message); 64 | 65 | expect(console.log.mock.calls.length).toBe(1); 66 | expect(console.log.mock.calls[0][0]).toBe( 67 | `${currentTime}: syft.js - ${message}` 68 | ); 69 | }); 70 | 71 | test('can log with data', () => { 72 | const testLogger = new Logger('syft.js', true), 73 | message = 'hello', 74 | myObj = { awesome: true }; 75 | 76 | expect(testLogger.verbose).toBe(true); 77 | expect(console.log.mock.calls.length).toBe(0); 78 | 79 | testLogger.log(message, myObj); 80 | 81 | expect(console.log.mock.calls.length).toBe(1); 82 | expect(console.log.mock.calls[0][0]).toContain(`: syft.js - ${message}`); 83 | expect(console.log.mock.calls[0][1]).toStrictEqual(myObj); 84 | }); 85 | }); 86 | -------------------------------------------------------------------------------- /src/syft.js: -------------------------------------------------------------------------------- 1 | import EventObserver from './events'; 2 | import Logger from './logger'; 3 | import GridAPIClient from './grid-api-client'; 4 | import Job from './job'; 5 | import ObjectRegistry from './object-registry'; 6 | 7 | /** 8 | * Syft client for static federated learning. 9 | * 10 | * @param {Object} options 11 | * @param {string} options.url Full URL to PyGrid app (`ws` and `http` schemas supported). 12 | * @param {boolean} options.verbose Whether to enable logging and allow unsecured PyGrid connection. 13 | * @param {string} options.authToken PyGrid authentication token. 14 | * @param {Object} options.peerConfig [not implemented] WebRTC peer config used with RTCPeerConnection. 
15 | * 16 | * @example 17 | * 18 | * const client = new Syft({url: "ws://localhost:5000", verbose: true}) 19 | * const job = client.newJob({modelName: "mnist", modelVersion: "1.0.0"}) 20 | * job.on('accepted', async ({model, clientConfig}) => { 21 | * // execute training 22 | * const [...newParams] = await this.plans['...'].execute(...) 23 | * const diff = await model.createSerializedDiff(newParams) 24 | * await this.report(diff) 25 | * }) 26 | * job.on('rejected', ({timeout}) => { 27 | * // re-try later or stop 28 | * }) 29 | * job.on('error', (err) => { 30 | * // handle errors 31 | * }) 32 | * job.start() 33 | */ 34 | export default class Syft { 35 | constructor({ url, verbose, authToken, peerConfig }) { 36 | // For creating verbose logging should the worker desire 37 | this.logger = new Logger('syft.js', verbose); 38 | 39 | // Forcing connection to be secure if verbose value is false. 40 | this.verbose = verbose; 41 | 42 | this.gridClient = new GridAPIClient({ url, allowInsecureUrl: verbose }); 43 | 44 | // objects registry 45 | this.objects = new ObjectRegistry(); 46 | 47 | // For creating event listeners 48 | this.observer = new EventObserver(); 49 | 50 | this.worker_id = null; 51 | this.peerConfig = peerConfig; 52 | this.authToken = authToken; 53 | } 54 | 55 | /** 56 | * Authenticates the client against PyGrid and instantiates new Job with given options. 57 | * 58 | * @throws Error 59 | * @param {Object} options 60 | * @param {string} options.modelName FL Model name. 61 | * @param {string} options.modelVersion FL Model version. 
62 | * @returns {Promise} 63 | */ 64 | async newJob({ modelName, modelVersion }) { 65 | if (!this.worker_id) { 66 | // authenticate 67 | const authResponse = await this.gridClient.authenticate( 68 | modelName, 69 | modelVersion, 70 | this.authToken 71 | ); 72 | this.worker_id = authResponse.worker_id; 73 | } 74 | 75 | return new Job({ 76 | worker: this, 77 | modelName, 78 | modelVersion, 79 | gridClient: this.gridClient 80 | }); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /test/mocks/grid.js: -------------------------------------------------------------------------------- 1 | import { WebSocket, Server } from 'mock-socket'; 2 | import fetchMock from 'fetch-mock'; 3 | global.WebSocket = WebSocket; 4 | 5 | export class GridMock { 6 | constructor(hostname = 'localhost', port = 8080) { 7 | this.hostname = hostname; 8 | this.port = port; 9 | this.ws = new Server(`ws://${hostname}:${port}`); 10 | 11 | this.wsConnections = []; 12 | this.ws.on('connection', socket => { 13 | this.wsConnections.push(socket); 14 | socket.on('message', message => this.messageHandler(socket, message)); 15 | socket.on('error', err => console.log('WS ERROR', err)); 16 | socket.on('close', err => console.log('WS CLOSE', err)); 17 | }); 18 | 19 | this.wsMessagesHistory = []; 20 | } 21 | 22 | setAuthenticationResponse(data) { 23 | this.authResponse = data; 24 | } 25 | 26 | setCycleResponse(data) { 27 | this.cycleResponse = data; 28 | } 29 | 30 | setReportResponse(data) { 31 | this.reportResponse = data; 32 | } 33 | 34 | _setHttpResponse(method, url, query, data, status) { 35 | let contentType = 'application/json'; 36 | let json = true; 37 | if (data instanceof ArrayBuffer || data instanceof Buffer) { 38 | contentType = 'application/octet-stream'; 39 | json = false; 40 | } 41 | 42 | fetchMock[method]( 43 | { url, query }, 44 | { 45 | body: data, 46 | status, 47 | headers: { 48 | 'Content-Type': contentType 49 | } 50 | }, 51 | { sendAsJson: json 
} 52 | ); 53 | } 54 | 55 | setModel(model_id, data, status = 200) { 56 | this._setHttpResponse( 57 | 'get', 58 | `http://${this.hostname}:${this.port}/federated/get-model`, 59 | { model_id }, 60 | data, 61 | status 62 | ); 63 | } 64 | 65 | setPlan(plan_id, data, status = 200) { 66 | this._setHttpResponse( 67 | 'get', 68 | `http://${this.hostname}:${this.port}/federated/get-plan`, 69 | { plan_id }, 70 | data, 71 | status 72 | ); 73 | } 74 | 75 | messageHandler(socket, message) { 76 | const data = JSON.parse(message); 77 | this.wsMessagesHistory.push(data); 78 | switch (data.type) { 79 | case 'federated/authenticate': 80 | socket.send( 81 | JSON.stringify({ 82 | type: data.type, 83 | data: this.authResponse 84 | }) 85 | ); 86 | break; 87 | 88 | case 'federated/cycle-request': 89 | socket.send( 90 | JSON.stringify({ 91 | type: data.type, 92 | data: this.cycleResponse 93 | }) 94 | ); 95 | break; 96 | 97 | case 'federated/report': 98 | socket.send( 99 | JSON.stringify({ 100 | type: data.type, 101 | data: this.reportResponse 102 | }) 103 | ); 104 | break; 105 | } 106 | } 107 | 108 | stop() { 109 | this.ws.close(); 110 | fetchMock.reset(); 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/types/torch.js: -------------------------------------------------------------------------------- 1 | import { getPbId, unbufferize, protobuf, pbId } from '../protobuf'; 2 | import * as tf from '@tensorflow/tfjs-core'; 3 | 4 | export class TorchTensor { 5 | constructor( 6 | id, 7 | contents, 8 | shape, 9 | dtype, 10 | chain = null, 11 | gradChain = null, 12 | tags = [], 13 | description = null 14 | ) { 15 | this.id = id; 16 | this.shape = shape; 17 | this.dtype = dtype; 18 | this.contents = contents; 19 | this.chain = chain; 20 | this.gradChain = gradChain; 21 | this.tags = tags; 22 | this.description = description; 23 | this._tfTensor = null; 24 | } 25 | 26 | toTfTensor() { 27 | if (!this._tfTensor) { 28 | this._tfTensor = 
tf.tensor(this.contents, this.shape, this.dtype); 29 | } 30 | return this._tfTensor; 31 | } 32 | 33 | static unbufferize(worker, pb) { 34 | if ( 35 | pb.serializer !== 36 | protobuf.syft_proto.types.torch.v1.TorchTensor.Serializer.SERIALIZER_ALL 37 | ) { 38 | throw new Error( 39 | `Tensor serializer ${pb.serializer} is not supported in syft.js` 40 | ); 41 | } 42 | 43 | // unwrap TensorData 44 | const tensorData = pb.contents_data; 45 | const dtype = tensorData.dtype; 46 | const shape = tensorData.shape.dims; 47 | const contents = tensorData[`contents_${dtype}`]; 48 | 49 | return new TorchTensor( 50 | getPbId(pb.id), 51 | contents, 52 | shape, 53 | dtype, 54 | unbufferize(worker, pb.chain), 55 | unbufferize(worker, pb.grad_chain), 56 | pb.tags, 57 | pb.description 58 | ); 59 | } 60 | 61 | bufferize(/* worker */) { 62 | const tensorData = { 63 | shape: protobuf.syft_proto.types.torch.v1.Size.create({ 64 | dims: this.shape 65 | }), 66 | dtype: this.dtype 67 | }; 68 | tensorData[`contents_${this.dtype}`] = this.contents; 69 | const pbTensorData = protobuf.syft_proto.types.torch.v1.TensorData.create( 70 | tensorData 71 | ); 72 | return protobuf.syft_proto.types.torch.v1.TorchTensor.create({ 73 | id: pbId(this.id), 74 | serializer: 75 | protobuf.syft_proto.types.torch.v1.TorchTensor.Serializer 76 | .SERIALIZER_ALL, 77 | contents_data: pbTensorData, 78 | tags: this.tags, 79 | description: this.description 80 | }); 81 | } 82 | 83 | static async fromTfTensor(tensor) { 84 | const flat = tensor.flatten(); 85 | const array = await flat.array(); 86 | flat.dispose(); 87 | const t = new TorchTensor(tensor.id, array, tensor.shape, tensor.dtype); 88 | t._tfTensor = tensor; 89 | return t; 90 | } 91 | } 92 | 93 | export class TorchSize { 94 | constructor(size) { 95 | this.size = size; 96 | } 97 | } 98 | 99 | export class TorchParameter { 100 | constructor(id, tensor, requiresGrad, grad) { 101 | this.id = id; 102 | this.tensor = tensor; 103 | this.requiresGrad = requiresGrad; 104 | 
this.grad = grad; 105 | } 106 | 107 | static unbufferize(worker, pb) { 108 | return new TorchParameter( 109 | getPbId(pb.id), 110 | unbufferize(worker, pb.tensor), 111 | pb.requires_grad, 112 | unbufferize(worker, pb.grad) 113 | ); 114 | } 115 | 116 | toTfTensor() { 117 | return this.tensor.toTfTensor(); 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/protobuf/index.js: -------------------------------------------------------------------------------- 1 | import { NO_DETAILER, PROTOBUF_UNSERIALIZE_FAILED } from '../_errors'; 2 | import { initMappings, PB_TO_UNBUFFERIZER } from './mapping'; 3 | import { protobuf } from 'syft-proto'; 4 | import Long from 'long'; 5 | export { protobuf }; 6 | 7 | export const unbufferize = (worker, pbObj) => { 8 | if (!PB_TO_UNBUFFERIZER) { 9 | initMappings(); 10 | } 11 | 12 | if ( 13 | pbObj === undefined || 14 | pbObj === null || 15 | ['number', 'string', 'boolean'].includes(typeof pbObj) 16 | ) { 17 | return pbObj; 18 | } 19 | 20 | const pbType = pbObj.constructor; 21 | 22 | // automatically unbufferize repeated fields 23 | if (Array.isArray(pbObj)) { 24 | return pbObj.map(item => unbufferize(worker, item)); 25 | } 26 | 27 | // automatically unbufferize map fields 28 | if (pbType.name === 'Object') { 29 | let res = {}; 30 | for (let key of Object.keys(pbObj)) { 31 | res[key] = unbufferize(worker, pbObj[key]); 32 | } 33 | return res; 34 | } 35 | 36 | // automatically unbufferize Id 37 | if (pbType === protobuf.syft_proto.types.syft.v1.Id) { 38 | return getPbId(pbObj); 39 | } 40 | 41 | // automatically unwrap Arg 42 | if (pbType === protobuf.syft_proto.types.syft.v1.Arg) { 43 | if (pbObj.arg === 'arg_int' && pbObj[pbObj.arg] instanceof Long) { 44 | // protobuf int64 is represented as Long 45 | return pbObj[pbObj.arg].toNumber(); 46 | } else { 47 | return unbufferize(worker, pbObj[pbObj.arg]); 48 | } 49 | } 50 | 51 | // automatically unwrap ArgList 52 | if (pbType === 
protobuf.syft_proto.types.syft.v1.ArgList) { 53 | return unbufferize(worker, pbObj.args); 54 | } 55 | 56 | const unbufferizer = PB_TO_UNBUFFERIZER[pbType]; 57 | if (typeof unbufferizer === 'undefined') { 58 | throw new Error(NO_DETAILER(pbType.name)); 59 | } 60 | return unbufferizer(worker, pbObj); 61 | }; 62 | 63 | /** 64 | * Converts binary in the form of ArrayBuffer or base64 string to syft class 65 | * @param worker 66 | * @param bin 67 | * @param pbType 68 | * @returns {Object} 69 | */ 70 | export const unserialize = (worker, bin, pbType) => { 71 | const buff = 72 | typeof bin === 'string' 73 | ? Buffer.from(bin, 'base64') 74 | : bin instanceof ArrayBuffer 75 | ? new Uint8Array(bin) 76 | : bin; 77 | let pbObj; 78 | try { 79 | pbObj = pbType.decode(buff); 80 | } catch (e) { 81 | throw new Error(PROTOBUF_UNSERIALIZE_FAILED(pbType.name, e.message)); 82 | } 83 | return unbufferize(worker, pbObj); 84 | }; 85 | 86 | /** 87 | * Converts syft class to protobuf-serialized binary 88 | * @param worker 89 | * @param obj 90 | * @returns {ArrayBuffer} 91 | */ 92 | export const serialize = (worker, obj) => { 93 | const pbObj = obj.bufferize(worker); 94 | const pbType = pbObj.constructor; 95 | const err = pbType.verify(pbObj); 96 | if (err) { 97 | throw new Error(err); 98 | } 99 | const bin = pbType.encode(pbObj).finish(); 100 | return new Uint8Array(bin).buffer; 101 | }; 102 | 103 | export const getPbId = field => { 104 | // convert int64 to string 105 | return field[field.id].toString(); 106 | }; 107 | 108 | export const pbId = value => { 109 | if (typeof value === 'number') { 110 | return protobuf.syft_proto.types.syft.v1.Id.create({ id_int: value }); 111 | } else if (typeof value === 'string') { 112 | return protobuf.syft_proto.types.syft.v1.Id.create({ id_str: value }); 113 | } 114 | }; 115 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "name": "@openmined/syft.js", 3 | "version": "0.0.1-1", 4 | "description": "A Javascript Syft worker in the browser", 5 | "main": "dist/index.cjs.js", 6 | "module": "dist/index.esm.js", 7 | "browser": "dist/index.js", 8 | "files": [ 9 | "dist/*.js", 10 | "*.md" 11 | ], 12 | "repository": { 13 | "type": "git", 14 | "url": "git+https://github.com/OpenMined/syft.js.git" 15 | }, 16 | "keywords": [ 17 | "syft", 18 | "pysyft", 19 | "openmined", 20 | "open", 21 | "mined", 22 | "deep", 23 | "learning", 24 | "private", 25 | "javascript", 26 | "machine", 27 | "learning" 28 | ], 29 | "author": "OpenMined", 30 | "license": "Apache-2.0", 31 | "bugs": { 32 | "url": "https://github.com/OpenMined/syft.js/issues" 33 | }, 34 | "homepage": "https://github.com/OpenMined/syft.js#readme", 35 | "scripts": { 36 | "start": "npm run lint && rollup -cw", 37 | "build": "npm run lint && rollup -c", 38 | "prepare": "npm run build", 39 | "test": "npm run lint && jest --coverage", 40 | "test:watch": "npm run lint && jest --watch", 41 | "version": "auto-changelog -p && git add CHANGELOG.md", 42 | "release": "np", 43 | "deploy": "./copy-examples.sh && gh-pages -d tmp && rm -rf tmp", 44 | "lint": "eslint .", 45 | "doc": "documentation build --config documentation.yml src/syft.js src/syft-model.js src/job.js src/types/plan.js --shallow -f md -o API-REFERENCE.md" 46 | }, 47 | "browserslist": "> 0.25%, not dead", 48 | "husky": { 49 | "hooks": { 50 | "pre-commit": "npm run doc && pretty-quick --staged" 51 | } 52 | }, 53 | "jest": { 54 | "testEnvironment": "node", 55 | "collectCoverageFrom": [ 56 | "**/src/**/*.js" 57 | ], 58 | "setupFiles": [ 59 | "/jest-globals.js" 60 | ], 61 | "globals": { 62 | "window": {} 63 | } 64 | }, 65 | "dependencies": {}, 66 | "peerDependencies": { 67 | "@tensorflow/tfjs-core": "^1.2.5" 68 | }, 69 | "devDependencies": { 70 | "@babel/core": "^7.10.2", 71 | "@babel/plugin-proposal-class-properties": "^7.10.1", 72 | "@babel/plugin-transform-runtime": "^7.10.1", 73 | 
"@babel/preset-env": "^7.10.2", 74 | "@babel/runtime": "^7.10.2", 75 | "@joseph184/rollup-plugin-node-builtins": "^2.1.4", 76 | "@rollup/plugin-babel": "^5.0.3", 77 | "@rollup/plugin-commonjs": "^13.0.0", 78 | "@rollup/plugin-json": "^4.1.0", 79 | "@rollup/plugin-node-resolve": "^8.0.1", 80 | "@tensorflow/tfjs-core": "^1.7.4", 81 | "auto-changelog": "^1.16.4", 82 | "babel-eslint": "^10.1.0", 83 | "babel-jest": "^24.9.0", 84 | "documentation": "^13.0.1", 85 | "eslint": "^6.8.0", 86 | "eslint-config-standard": "^14.1.1", 87 | "eslint-plugin-import": "^2.21.2", 88 | "eslint-plugin-node": "^11.1.0", 89 | "eslint-plugin-promise": "^4.2.1", 90 | "eslint-plugin-standard": "^4.0.1", 91 | "fetch-mock": "^9.10.1", 92 | "gh-pages": "^2.2.0", 93 | "husky": "^3.1.0", 94 | "jest": "^24.9.0", 95 | "long": "^4.0.0", 96 | "mock-socket": "^9.0.3", 97 | "np": "^5.2.1", 98 | "prettier": "^1.19.1", 99 | "pretty-quick": "^2.0.1", 100 | "randomfill": "^1.0.4", 101 | "regenerator-runtime": "^0.13.5", 102 | "rollup": "^2.16.1", 103 | "rollup-plugin-filesize": "^9.0.0", 104 | "rollup-plugin-peer-deps-external": "^2.2.2", 105 | "syft-proto": "github:openmined/syft-proto#v0.4.9" 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /test/data_channel_message.test.js: -------------------------------------------------------------------------------- 1 | import DataChannelMessage from '../src/data-channel-message'; 2 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from '../src'; 3 | import { randomFillSync } from 'crypto'; 4 | 5 | describe('Data Channel Message', () => { 6 | afterEach(() => { 7 | jest.clearAllMocks(); 8 | }); 9 | 10 | test('can construct from ArrayBuffer', () => { 11 | const buf = new ArrayBuffer(100000); 12 | const message = new DataChannelMessage({ data: buf }); 13 | expect(message.chunks).toBe( 14 | Math.ceil(100000 / WEBRTC_DATACHANNEL_CHUNK_SIZE) 15 | ); 16 | expect(() => new DataChannelMessage({ data: {} })).toThrow(); 17 | }); 18 
| 19 | test('can slice data to chunks', () => { 20 | const buf = new ArrayBuffer(100000); 21 | const message = new DataChannelMessage({ data: buf }); 22 | const chunk = message.getChunk(0); 23 | expect(chunk.byteLength).toBe(WEBRTC_DATACHANNEL_CHUNK_SIZE); 24 | const info = DataChannelMessage.messageInfoFromBuf(chunk); 25 | expect(info.id).toBe(message.id); 26 | expect(info.chunks).toBe(message.chunks); 27 | expect(info.chunk).toBe(0); 28 | }); 29 | 30 | test('can get info from chunk', () => { 31 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100); 32 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength); 33 | const messageOrig = new DataChannelMessage({ data: buf }); 34 | const chunk1 = messageOrig.getChunk(0); 35 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1); 36 | expect(info1.id).toBe(messageOrig.id); 37 | expect(info1.chunks).toBe(messageOrig.chunks); 38 | expect(info1.chunk).toBe(0); 39 | 40 | new Uint8Array(chunk1)[0] = 123; 41 | const infoErr = DataChannelMessage.messageInfoFromBuf(chunk1); 42 | expect(infoErr).toBe(false); 43 | }); 44 | 45 | test('can assemble full message from chunks', done => { 46 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100); 47 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength); 48 | const messageOrig = new DataChannelMessage({ data: buf }); 49 | const chunk1 = messageOrig.getChunk(0); 50 | const chunk2 = messageOrig.getChunk(1); 51 | const info1 = DataChannelMessage.messageInfoFromBuf(chunk1); 52 | 53 | const messageAssembled = new DataChannelMessage({ id: info1.id }); 54 | messageAssembled.once('ready', message => { 55 | expect(message.chunks).toBe(messageOrig.chunks); 56 | expect(message.size).toBe(messageOrig.size); 57 | const orig = new Uint8Array(messageOrig.data); 58 | const assembled = new Uint8Array(message.data); 59 | expect(orig.every((v, i) => assembled[i] === v)).toBe(true); 60 | done(); 61 | }); 62 | messageAssembled.addChunk(chunk1); 63 | 
messageAssembled.addChunk(chunk2); 64 | }); 65 | 66 | test('should error on invalid chunks', () => { 67 | const buf = new ArrayBuffer(WEBRTC_DATACHANNEL_CHUNK_SIZE + 100); 68 | randomFillSync(new Uint8Array(buf), 0, buf.byteLength); 69 | const messageOrig = new DataChannelMessage({ data: buf }); 70 | const chunk1 = messageOrig.getChunk(0); 71 | 72 | // id doesn't match 73 | const messageAssembled = new DataChannelMessage({ id: 123 }); 74 | expect(() => messageAssembled.addChunk(chunk1)).toThrow(); 75 | 76 | // simply invalid chunk 77 | expect(() => messageAssembled.addChunk(new Uint8Array(3))).toThrow(); 78 | 79 | // double chunk add 80 | const messageAssembled2 = new DataChannelMessage({ id: messageOrig.id }); 81 | messageAssembled2.addChunk(chunk1); 82 | expect(() => messageAssembled2.addChunk(chunk1)).toThrow(); 83 | }); 84 | }); 85 | -------------------------------------------------------------------------------- /src/types/role.js: -------------------------------------------------------------------------------- 1 | import { getPbId, unbufferize } from '../protobuf'; 2 | import ObjectRegistry from '../object-registry'; 3 | import { NOT_ENOUGH_ARGS } from '../_errors'; 4 | 5 | export class Role { 6 | constructor( 7 | id, 8 | actions = [], 9 | state = null, 10 | placeholders = {}, 11 | input_placeholder_ids = [], 12 | output_placeholder_ids = [], 13 | tags = [], 14 | description = null 15 | ) { 16 | this.id = id; 17 | this.actions = actions; 18 | this.state = state; 19 | this.placeholders = placeholders; 20 | this.input_placeholder_ids = input_placeholder_ids; 21 | this.output_placeholder_ids = output_placeholder_ids; 22 | this.tags = tags; 23 | this.description = description; 24 | } 25 | 26 | static unbufferize(worker, pb) { 27 | let placeholdersArray = unbufferize(worker, pb.placeholders); 28 | let placeholders = {}; 29 | for (let ph of placeholdersArray) { 30 | placeholders[ph.id] = ph; 31 | } 32 | 33 | return new Role( 34 | getPbId(pb.id), 35 | 
unbufferize(worker, pb.actions), 36 | unbufferize(worker, pb.state), 37 | placeholders, 38 | pb.input_placeholder_ids.map(getPbId), 39 | pb.output_placeholder_ids.map(getPbId), 40 | pb.tags, 41 | pb.description 42 | ); 43 | } 44 | 45 | findPlaceholders(tagRegex) { 46 | return this.placeholders.filter( 47 | placeholder => 48 | placeholder.tags && placeholder.tags.some(tag => tagRegex.test(tag)) 49 | ); 50 | } 51 | 52 | getInputPlaceholders() { 53 | return this.input_placeholder_ids.map(id => this.placeholders[id]); 54 | } 55 | 56 | getOutputPlaceholders() { 57 | return this.output_placeholder_ids.map(id => this.placeholders[id]); 58 | } 59 | 60 | /** 61 | * Execute the Role with given worker 62 | * @param {Syft} worker 63 | * @param data 64 | * @returns {Promise} 65 | */ 66 | async execute(worker, ...data) { 67 | // Create local scope. 68 | const planScope = new ObjectRegistry(); 69 | planScope.load(worker.objects); 70 | 71 | const inputPlaceholders = this.getInputPlaceholders(), 72 | outputPlaceholders = this.getOutputPlaceholders(), 73 | argsLength = inputPlaceholders.length; 74 | 75 | // If the number of arguments supplied does not match the number of arguments required... 
76 | if (data.length !== argsLength) 77 | throw new Error(NOT_ENOUGH_ARGS(data.length, argsLength)); 78 | 79 | // For each argument supplied, add them in scope 80 | data.forEach((datum, i) => { 81 | planScope.set(inputPlaceholders[i].id, datum); 82 | }); 83 | 84 | // load state tensors to worker 85 | if (this.state && this.state.tensors) { 86 | this.state.placeholders.forEach((ph, idx) => { 87 | planScope.set(ph.id, this.state.tensors[idx]); 88 | }); 89 | } 90 | 91 | // Execute the plan 92 | for (const action of this.actions) { 93 | // The result of the current operation 94 | const result = await action.execute(planScope); 95 | 96 | // Place the result of the current operation into this.objects at the 0th item in returnIds 97 | // All intermediate tensors will be garbage collected by default 98 | if (result) { 99 | if (action.returnIds.length > 0) { 100 | planScope.set(action.returnIds[0], result, true); 101 | } else if (action.returnPlaceholderIds.length > 0) { 102 | planScope.set(action.returnPlaceholderIds[0].id, result, true); 103 | } 104 | } 105 | } 106 | 107 | // Resolve all of the requested resultId's as specific by the plan 108 | const resolvedResultingTensors = []; 109 | outputPlaceholders.forEach(placeholder => { 110 | resolvedResultingTensors.push(planScope.get(placeholder.id)); 111 | // Do not gc output tensors 112 | planScope.setGc(placeholder.id, false); 113 | }); 114 | 115 | // Cleanup intermediate plan variables. 116 | planScope.clear(); 117 | 118 | // Return them to the worker 119 | return resolvedResultingTensors; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /examples/mnist/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 52 | 53 | Syft.js MNIST Example 54 | 55 | 56 | 57 | 58 | 59 | 60 | 65 |

Syft.js MNIST Example

66 |

67 | This is a demo using 69 | syft.js from OpenMined 72 | to train an MNIST model hosted on PyGrid using local data and a Federated 73 | Learning approach. The data never leaves the browser. Please open your 74 | JavaScript console to see what's going on (instructions for 75 | Chrome 80 | and 81 | Firefox). 87 |

88 | 89 |
90 |

91 | 92 | 93 |

94 | 95 |

96 | 97 | 98 |

99 | 100 |

101 | 102 | 103 |

104 | 105 |

106 | 107 | 110 |

111 | 112 |

113 | 114 |

115 | 116 |

117 | 121 |

122 |
123 | 124 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /src/types/computation-action.js: -------------------------------------------------------------------------------- 1 | import { unbufferize } from '../protobuf'; 2 | import PointerTensor from './pointer-tensor'; 3 | import { Placeholder, PlaceholderId } from './placeholder'; 4 | import * as tf from '@tensorflow/tfjs-core'; 5 | import { TorchParameter, TorchTensor } from './torch'; 6 | import { CANNOT_FIND_COMMAND, MISSING_VARIABLE } from '../_errors'; 7 | 8 | export class ComputationAction { 9 | constructor(command, target, args, kwargs, returnIds, returnPlaceholderIds) { 10 | this.command = command; 11 | this.target = target; 12 | this.args = args; 13 | this.kwargs = kwargs; 14 | this.returnIds = returnIds; 15 | this.returnPlaceholderIds = returnPlaceholderIds; 16 | } 17 | 18 | static unbufferize(worker, pb) { 19 | return new ComputationAction( 20 | pb.command, 21 | unbufferize(worker, pb[pb.target]), 22 | unbufferize(worker, pb.args), 23 | unbufferize(worker, pb.kwargs), 24 | unbufferize(worker, pb.return_ids), 25 | unbufferize(worker, pb.return_placeholder_ids) 26 | ); 27 | } 28 | 29 | async execute(scope) { 30 | // A helper function for helping us determine if all PointerTensors/Placeholders inside of "this.args" also exist as tensors inside of "objects" 31 | const haveValuesForAllArgs = args => { 32 | let enoughInfo = true; 33 | 34 | args.forEach(arg => { 35 | if ( 36 | (arg instanceof PointerTensor && !scope.has(arg.idAtLocation)) || 37 | (arg instanceof Placeholder && !scope.has(arg.id)) || 38 | (arg instanceof PlaceholderId && !scope.has(arg.id)) 39 | ) { 40 | enoughInfo = false; 41 | } 42 | }); 43 | 44 | return enoughInfo; 45 | }; 46 | 47 | const toTFTensor = tensor => { 48 | if (tensor instanceof tf.Tensor) { 49 | return tensor; 50 | } else if (tensor instanceof TorchTensor) { 51 | return tensor.toTfTensor(); 52 | } else if (tensor instanceof 
TorchParameter) { 53 | return tensor.tensor.toTfTensor(); 54 | } else if (typeof tensor === 'number') { 55 | return tensor; 56 | } 57 | return null; 58 | }; 59 | 60 | const getTensorByRef = reference => { 61 | let tensor = null; 62 | if (reference instanceof PlaceholderId) { 63 | tensor = scope.get(reference.id); 64 | } else if (reference instanceof Placeholder) { 65 | tensor = scope.get(reference.id); 66 | } else if (reference instanceof PointerTensor) { 67 | tensor = scope.get(reference.idAtLocation); 68 | } 69 | tensor = toTFTensor(tensor); 70 | return tensor; 71 | }; 72 | 73 | // A helper function for helping us get all operable tensors from PointerTensors inside of "this._args" 74 | const pullTensorsFromArgs = args => { 75 | const resolvedArgs = []; 76 | 77 | args.forEach(arg => { 78 | const tensorByRef = getTensorByRef(arg); 79 | if (tensorByRef) { 80 | resolvedArgs.push(toTFTensor(tensorByRef)); 81 | } else { 82 | // Try to convert to tensor. 83 | const tensor = toTFTensor(arg); 84 | if (tensor !== null) { 85 | resolvedArgs.push(toTFTensor(arg)); 86 | } else { 87 | // Keep as is. 
88 | resolvedArgs.push(arg); 89 | } 90 | } 91 | }); 92 | 93 | return resolvedArgs; 94 | }; 95 | 96 | //worker.logger.log(`Given command: ${this.command}, converted command: ${command} + ${JSON.stringify(preArgs)} + ${JSON.stringify(postArgs)}`); 97 | 98 | const args = this.args; 99 | let self = null; 100 | 101 | if (this.target) { 102 | // resolve "self" if it's present 103 | self = getTensorByRef(this.target); 104 | if (!self) { 105 | throw new Error(MISSING_VARIABLE()); 106 | } 107 | } 108 | 109 | if (!haveValuesForAllArgs(args)) { 110 | throw new Error(MISSING_VARIABLE()); 111 | } 112 | 113 | const resolvedArgs = pullTensorsFromArgs(args); 114 | const functionName = this.command.split('.').pop(); 115 | 116 | if (self) { 117 | if (!(functionName in self)) { 118 | throw new Error(CANNOT_FIND_COMMAND(`tensor.${functionName}`)); 119 | } else { 120 | return self[functionName](...resolvedArgs); 121 | } 122 | } 123 | 124 | if (!(functionName in tf)) { 125 | throw new Error(CANNOT_FIND_COMMAND(functionName)); 126 | } else { 127 | return tf[functionName](...resolvedArgs, ...Object.values(this.kwargs)); 128 | } 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/syft-webrtc.js: -------------------------------------------------------------------------------- 1 | /** 2 | * NOTE: This is temporary file to hold WebRTC code that needs to be refactored 3 | * during implementation of Protocols functionality in syft.js. 4 | * Do not use this file as it will be removed in the future. 
5 | */ 6 | /* istanbul ignore file */ 7 | import { 8 | GET_PROTOCOL, 9 | SOCKET_STATUS, 10 | WEBRTC_INTERNAL_MESSAGE, 11 | WEBRTC_JOIN_ROOM, 12 | WEBRTC_PEER_CONFIG, 13 | WEBRTC_PEER_LEFT, 14 | WEBRTC_PEER_OPTIONS 15 | } from './_constants'; 16 | import { protobuf, unserialize } from './protobuf'; 17 | import Socket from './sockets'; 18 | import WebRTCClient from './webrtc'; 19 | 20 | export class SyftWebrtc { 21 | /* ----- SOCKET COMMUNICATION ----- */ 22 | // TODO refactor into grid client class 23 | 24 | // To create a socket connection internally and externally 25 | createSocketConnection(url) { 26 | if (!url) return; 27 | if (!this.verbose) { 28 | url = url.replace('ws://', 'wss://'); 29 | } 30 | // When a socket connection is opened... 31 | const onOpen = event => { 32 | this.observer.broadcast(SOCKET_STATUS, { 33 | connected: true, 34 | event 35 | }); 36 | }; 37 | 38 | // When a socket connection is closed... 39 | const onClose = event => { 40 | this.observer.broadcast(SOCKET_STATUS, { 41 | connected: false, 42 | event 43 | }); 44 | }; 45 | 46 | // When a socket message is received... 
47 | const onMessage = event => { 48 | const { type, data } = event; 49 | 50 | if (type === GET_PROTOCOL) { 51 | if (data.error) { 52 | this.logger.log( 53 | 'There was an error getting the protocol you requested', 54 | data.error 55 | ); 56 | 57 | return data; 58 | } 59 | 60 | // Save our workerId if we don't already have it (also for the socket connection) 61 | this.workerId = data.worker.workerId; 62 | this.socket.workerId = this.workerId; 63 | 64 | // Save our scopeId if we don't already have it 65 | this.scopeId = data.worker.scopeId; 66 | 67 | // Save our role 68 | this.role = data.worker.role; 69 | 70 | // Save the other participant workerId's 71 | this.participants = data.participants; 72 | 73 | // Save the protocol and plan assignment after having Serde detail them 74 | let detailedProtocol; 75 | let detailedPlan; 76 | detailedProtocol = unserialize( 77 | null, 78 | data.protocol, 79 | protobuf.syft_proto.execution.v1.Protocol 80 | ); 81 | detailedPlan = unserialize( 82 | null, 83 | data.plan, 84 | protobuf.syft_proto.execution.v1.Plan 85 | ); 86 | 87 | this.protocol = detailedProtocol; 88 | this.plan = detailedPlan; 89 | 90 | return this.plan; 91 | } else if (type === WEBRTC_INTERNAL_MESSAGE) { 92 | this.rtc.receiveInternalMessage(data); 93 | } else if (type === WEBRTC_JOIN_ROOM) { 94 | this.rtc.receiveNewPeer(data); 95 | } else if (type === WEBRTC_PEER_LEFT) { 96 | this.rtc.removePeer(data.workerId); 97 | } 98 | }; 99 | 100 | this.socket = new Socket({ 101 | url, 102 | workerId: this.workerId, 103 | onOpen, 104 | onClose, 105 | onMessage 106 | }); 107 | } 108 | 109 | // To close the socket connection with the grid 110 | disconnectFromGrid() { 111 | this.socket.stop(); 112 | } 113 | 114 | /* ----- WEBRTC ----- */ 115 | 116 | // To create a socket connection internally and externally 117 | createWebRTCClient(peerConfig, peerOptions) { 118 | // If we don't have a socket sever, we can't create the WebRTCClient 119 | if (!this.socket) return; 120 | 121 | // 
The default STUN/TURN servers to use for NAT traversal 122 | if (!peerConfig) peerConfig = WEBRTC_PEER_CONFIG; 123 | 124 | // Some standard options for establishing peer connections 125 | if (!peerOptions) peerOptions = WEBRTC_PEER_OPTIONS; 126 | 127 | this.rtc = new WebRTCClient({ 128 | peerConfig, 129 | peerOptions, 130 | socket: this.socket 131 | }); 132 | 133 | const onDataMessage = data => { 134 | this.logger.log(`Data message is received from ${data.worker_id}`, data); 135 | }; 136 | this.rtc.on('message', onDataMessage); 137 | } 138 | 139 | connectToParticipants() { 140 | this.rtc.start(this.workerId, this.scopeId); 141 | } 142 | 143 | disconnectFromParticipants() { 144 | this.rtc.stop(); 145 | } 146 | 147 | sendToParticipants(data, to) { 148 | this.rtc.sendMessage(data, to); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/data-channel-message.js: -------------------------------------------------------------------------------- 1 | import EventObserver from './events'; 2 | import { WEBRTC_DATACHANNEL_CHUNK_SIZE } from './_constants'; 3 | 4 | export default class DataChannelMessage { 5 | static chunkHeaderSign = 0xff00; 6 | static chunkHeaderLength = 10; 7 | 8 | constructor({ data, id, worker_id }) { 9 | if (data === undefined) { 10 | this.data = new ArrayBuffer(0); 11 | } else if (typeof data === 'string') { 12 | this.data = new TextEncoder().encode(data).buffer; 13 | } else if (data instanceof ArrayBuffer) { 14 | this.data = data; 15 | } else { 16 | throw new Error('Message type is not supported'); 17 | } 18 | 19 | this.worker_id = worker_id; 20 | this.size = this.data.byteLength; 21 | this.dataChunks = []; 22 | this.id = id || Math.floor(Math.random() * 0xffffffff); 23 | this.observer = new EventObserver(); 24 | this.chunks = Math.ceil( 25 | this.size / 26 | (WEBRTC_DATACHANNEL_CHUNK_SIZE - DataChannelMessage.chunkHeaderLength) 27 | ); 28 | this.makeChunkHeader(0); 29 | } 30 | 31 | /** 32 | * 
Creates chunk header for given chunk index 33 | * @param {number} chunk Chunk index 34 | * @returns {ArrayBuffer} 35 | */ 36 | makeChunkHeader(chunk) { 37 | if (this.chunkHeader === undefined) { 38 | this.chunkHeader = new ArrayBuffer(DataChannelMessage.chunkHeaderLength); 39 | 40 | const view = new DataView(this.chunkHeader); 41 | 42 | view.setUint16(0, DataChannelMessage.chunkHeaderSign); 43 | view.setUint32(2, this.id); 44 | view.setUint16(6, this.chunks); 45 | view.setUint16(8, chunk); 46 | } else { 47 | const view = new DataView(this.chunkHeader); 48 | 49 | view.setUint16(8, chunk); 50 | } 51 | 52 | return this.chunkHeader; 53 | } 54 | 55 | once(event, func) { 56 | this.observer.subscribe(event, data => { 57 | this.observer.unsubscribe(event); 58 | 59 | func(data); 60 | }); 61 | } 62 | 63 | /** 64 | * Gets chunk info from header 65 | * @param {ArrayBuffer} buf 66 | */ 67 | static messageInfoFromBuf(buf) { 68 | let view; 69 | 70 | try { 71 | view = new DataView(buf); 72 | if (view.getUint16(0) !== DataChannelMessage.chunkHeaderSign) { 73 | return false; 74 | } 75 | } catch (e) { 76 | return false; 77 | } 78 | 79 | return { 80 | id: view.getUint32(2), 81 | chunks: view.getUint16(6), 82 | chunk: view.getUint16(8) 83 | }; 84 | } 85 | 86 | /** 87 | * Adds chunk for further assembly 88 | * @param {ArrayBuffer} buf 89 | */ 90 | addChunk(buf) { 91 | const info = DataChannelMessage.messageInfoFromBuf(buf); 92 | 93 | if (info === false) { 94 | throw new Error(`Is not a valid chunk`); 95 | } 96 | 97 | if (this.id !== info.id) { 98 | throw new Error( 99 | `Trying to add chunk from different message: ${this.id} != ${info.id}` 100 | ); 101 | } 102 | 103 | if (this.dataChunks[info.chunk] !== undefined) { 104 | throw new Error(`Duplicated chunk ${info.chunks} in message ${this.id}`); 105 | } 106 | 107 | this.dataChunks[info.chunk] = buf.slice( 108 | DataChannelMessage.chunkHeaderLength 109 | ); 110 | 111 | if (this.dataChunks.length === info.chunks) { 112 | this.assemble(); 
113 | } 114 | } 115 | 116 | /** 117 | * Concatenate data pieces 118 | */ 119 | assemble() { 120 | let size = 0; 121 | 122 | for (let chunk of this.dataChunks) { 123 | size += chunk.byteLength; 124 | } 125 | 126 | const data = new Uint8Array(size); 127 | let offset = 0; 128 | 129 | for (let chunk of this.dataChunks) { 130 | data.set(new Uint8Array(chunk), offset); 131 | offset += chunk.byteLength; 132 | } 133 | 134 | this.chunks = this.dataChunks.length; 135 | this.size = size; 136 | this.data = data.buffer; 137 | 138 | // Clean up 139 | this.dataChunks = []; 140 | 141 | // Emit event when done 142 | this.observer.broadcast('ready', this); 143 | } 144 | 145 | /** 146 | * Slice a piece of message and add a header 147 | * @param {number} num 148 | * @returns {ArrayBuffer} 149 | */ 150 | getChunk(num) { 151 | const start = 152 | num * 153 | (WEBRTC_DATACHANNEL_CHUNK_SIZE - DataChannelMessage.chunkHeaderLength); 154 | const end = Math.min( 155 | start + 156 | WEBRTC_DATACHANNEL_CHUNK_SIZE - 157 | DataChannelMessage.chunkHeaderLength, 158 | this.size 159 | ); 160 | const chunk = new Uint8Array( 161 | DataChannelMessage.chunkHeaderLength + end - start 162 | ); 163 | const header = this.makeChunkHeader(num); 164 | 165 | chunk.set(new Uint8Array(header), 0); 166 | chunk.set( 167 | new Uint8Array(this.data.slice(start, end)), 168 | DataChannelMessage.chunkHeaderLength 169 | ); 170 | 171 | return chunk.buffer; 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /test/sockets.test.js: -------------------------------------------------------------------------------- 1 | import { SOCKET_PING } from '../src/_constants'; 2 | import { WebSocket, Server } from 'mock-socket'; 3 | 4 | import Socket from '../src/sockets'; 5 | 6 | global.WebSocket = WebSocket; 7 | 8 | const url = 'ws://localhost:8080/'; 9 | 10 | // Create a promise that is resolved when the event is triggered. 
11 | const makeEventPromise = (emitter, event) => { 12 | let resolver; 13 | 14 | const promise = new Promise(resolve => (resolver = resolve)); 15 | emitter.on(event, data => resolver(data)); 16 | 17 | return promise; 18 | }; 19 | 20 | describe('Sockets', () => { 21 | let mockServer; 22 | 23 | beforeEach(() => { 24 | mockServer = new Server(url); 25 | mockServer.connected = makeEventPromise(mockServer, 'connection'); 26 | }); 27 | 28 | afterEach(() => { 29 | mockServer.close(); 30 | }); 31 | 32 | test('sends keep-alive messages automatically', async () => { 33 | const keepAliveTimeout = 300, 34 | expectedMessagesCount = 3, 35 | messages = [], 36 | expectedTypes = []; 37 | 38 | // Creating a socket will open connection and start keep-alive pings. 39 | new Socket({ url, keepAliveTimeout }); 40 | 41 | const serverSocket = await mockServer.connected; 42 | 43 | serverSocket.on('message', message => messages.push(JSON.parse(message))); 44 | 45 | await new Promise(done => 46 | setTimeout( 47 | done, 48 | keepAliveTimeout * expectedMessagesCount + keepAliveTimeout / 2 49 | ) 50 | ); 51 | 52 | // One keep-alive message is sent right after connection, hence +1. 
53 | expect(messages).toHaveLength(expectedMessagesCount + 1); 54 | 55 | for (let i = 0; i < expectedMessagesCount + 1; i++) { 56 | expectedTypes.push(SOCKET_PING); 57 | } 58 | 59 | expect(messages.map(message => message['type'])).toEqual(expectedTypes); 60 | }); 61 | 62 | test('triggers onOpen event', async () => { 63 | const onOpen = jest.fn(); 64 | 65 | new Socket({ url, onOpen }); 66 | 67 | await mockServer.connected; 68 | 69 | expect(onOpen).toHaveBeenCalledTimes(1); 70 | }); 71 | 72 | test('triggers onClose event', async () => { 73 | const closed = makeEventPromise(mockServer, 'close'), 74 | onClose = jest.fn(), 75 | mySocket = new Socket({ 76 | url, 77 | onClose 78 | }); 79 | 80 | await mockServer.connected; 81 | 82 | mySocket.stop(); 83 | 84 | await closed; 85 | 86 | expect(onClose).toHaveBeenCalledTimes(1); 87 | expect(mySocket.timerId).toBeNull(); 88 | }); 89 | 90 | test('sends data correctly', async () => { 91 | const testReqType = 'test', 92 | testReqData = { blob: 1 }, 93 | testResponse = { response: 'test' }, 94 | testworkerId = 'test-worker', 95 | mySocket = new Socket({ 96 | workerId: testworkerId, 97 | url, 98 | onMessage: data => data 99 | }); 100 | 101 | const serverSocket = await mockServer.connected; 102 | 103 | // Skip first keep-alive message. 
104 | await makeEventPromise(serverSocket, 'message'); 105 | 106 | const responsePromise = mySocket.send(testReqType, testReqData); 107 | const message = await makeEventPromise(serverSocket, 'message'); 108 | 109 | serverSocket.send(JSON.stringify(testResponse)); 110 | 111 | const response = await responsePromise; 112 | 113 | expect(JSON.parse(message)).toEqual({ 114 | type: testReqType, 115 | data: testReqData 116 | }); 117 | expect(response).toEqual(testResponse); 118 | }); 119 | 120 | test('returns error when .send() fails', async () => { 121 | const mySocket = new Socket({ 122 | url, 123 | onMessage: data => data 124 | }); 125 | 126 | const serverSocket = await mockServer.connected; 127 | 128 | // Skip first keep-alive message. 129 | await makeEventPromise(serverSocket, 'message'); 130 | 131 | const responsePromise = mySocket.send('test', {}); 132 | 133 | mockServer.simulate('error'); 134 | 135 | expect.assertions(1); 136 | 137 | try { 138 | await responsePromise; 139 | } catch (e) { 140 | expect(e).toBeDefined(); 141 | } 142 | }); 143 | 144 | test('disconnects from server after .stop()', async () => { 145 | const mySocket = new Socket({ 146 | url 147 | }); 148 | 149 | await mockServer.connected; 150 | 151 | expect(mockServer.clients()).toHaveLength(1); 152 | 153 | mySocket.stop(); 154 | 155 | await new Promise(done => setTimeout(done, 100)); 156 | 157 | expect(mockServer.clients()).toHaveLength(0); 158 | }); 159 | 160 | test('triggers onMessage event', async () => { 161 | const testResponse = { response: 'test' }, 162 | testworkerId = 'test-worker', 163 | onMessage = jest.fn(message => message), 164 | mySocket = new Socket({ 165 | workerId: testworkerId, 166 | url, 167 | onMessage: onMessage 168 | }); 169 | 170 | const serverSocket = await mockServer.connected; 171 | 172 | // Skip first keep-alive message. 
173 | await makeEventPromise(serverSocket, 'message'); 174 | 175 | serverSocket.on('message', () => { 176 | serverSocket.send(JSON.stringify(testResponse)); 177 | }); 178 | 179 | await mySocket.send('test1', {}); 180 | await mySocket.send('test2', {}); 181 | 182 | expect(onMessage).toHaveBeenCalledTimes(2); 183 | expect(onMessage).toHaveBeenLastCalledWith(testResponse); 184 | }); 185 | }); 186 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ### Changelog 2 | 3 | All notable changes to this project will be documented in this file. Dates are displayed in UTC. 4 | 5 | Generated by [`auto-changelog`](https://github.com/CookPete/auto-changelog). 6 | 7 | #### [v1.5.0](https://github.com/OpenMined/syft.js/compare/v1.4.4...v1.5.0) 8 | 9 | > 25 July 2019 10 | 11 | - Major refactor to build and testing systems [`#42`](https://github.com/OpenMined/syft.js/pull/42) 12 | - Upgrading the project to a modern development workflow [`b584677`](https://github.com/OpenMined/syft.js/commit/b584677fe1530e8114463f577dc4e5e0e51432f5) 13 | - Upgrading tensorflow [`977ef41`](https://github.com/OpenMined/syft.js/commit/977ef41c929d1bbb9dab5f956c497d128b7d0ff7) 14 | - Generating autochangelog [`ffed4a2`](https://github.com/OpenMined/syft.js/commit/ffed4a2dcc1647d37ef2f3acc9734592db835b09) 15 | 16 | #### [v1.4.4](https://github.com/OpenMined/syft.js/compare/v1.4.3...v1.4.4) 17 | 18 | > 17 October 2018 19 | 20 | - Adding a result_id to operations [`a38af02`](https://github.com/OpenMined/syft.js/commit/a38af02a9b3897ef2172296a1755bf1e2a2626e0) 21 | 22 | #### [v1.4.3](https://github.com/OpenMined/syft.js/compare/v1.4.2...v1.4.3) 23 | 24 | > 14 October 2018 25 | 26 | - GitBook: [1.4.2] 5 pages modified [`#41`](https://github.com/OpenMined/syft.js/pull/41) 27 | - GitBook: [master] 4 pages modified 
[`522e928`](https://github.com/OpenMined/syft.js/commit/522e9282d8df5f65aaee7dc4dafe3d28c15f10fd) 28 | - Delete api-documentation.md [`625a623`](https://github.com/OpenMined/syft.js/commit/625a6232edcc81cd5acbc1e4cce3f569fcff3152) 29 | - Delete guide.md [`5982dbd`](https://github.com/OpenMined/syft.js/commit/5982dbde60aebe2afac6b263cfdfea1973bde4e5) 30 | 31 | #### v1.4.2 32 | 33 | > 14 October 2018 34 | 35 | - Tensor removal issue [`#40`](https://github.com/OpenMined/syft.js/pull/40) 36 | - TensorFlow bundling [`#38`](https://github.com/OpenMined/syft.js/pull/38) 37 | - Improve build system [`#35`](https://github.com/OpenMined/syft.js/pull/35) 38 | - Add tensorflow [`#31`](https://github.com/OpenMined/syft.js/pull/31) 39 | - Add tensorflow [`#29`](https://github.com/OpenMined/syft.js/pull/29) 40 | - added preliminary websocket support [`#22`](https://github.com/OpenMined/syft.js/pull/22) 41 | - Travis Support for Unit Tests [`#24`](https://github.com/OpenMined/syft.js/pull/24) 42 | - FloatTensor add test [`#23`](https://github.com/OpenMined/syft.js/pull/23) 43 | - created a node server [`#19`](https://github.com/OpenMined/syft.js/pull/19) 44 | - adding demo folder's primitives. 
[`#20`](https://github.com/OpenMined/syft.js/pull/20) 45 | - Unit Testing Suite [`#21`](https://github.com/OpenMined/syft.js/pull/21) 46 | - Rollup [`#18`](https://github.com/OpenMined/syft.js/pull/18) 47 | - A cleaned up version of examples to keep things organized [`#17`](https://github.com/OpenMined/syft.js/pull/17) 48 | - Add gpu.js as a dependency [`#16`](https://github.com/OpenMined/syft.js/pull/16) 49 | - added js to parse json and create a js object [`#15`](https://github.com/OpenMined/syft.js/pull/15) 50 | - Initial commit, got babel and prettier running [`#5`](https://github.com/OpenMined/syft.js/pull/5) 51 | - Create README.md [`#4`](https://github.com/OpenMined/syft.js/pull/4) 52 | - Master [`#3`](https://github.com/OpenMined/syft.js/pull/3) 53 | - Corrected 'Resources' links [`#1`](https://github.com/OpenMined/syft.js/pull/1) 54 | - Closes #39 [`#39`](https://github.com/OpenMined/syft.js/issues/39) 55 | - Adding a Webpack example [`#36`](https://github.com/OpenMined/syft.js/issues/36) 56 | - Closes #32 by adding sourcemaps and minification [`#32`](https://github.com/OpenMined/syft.js/issues/32) 57 | - clean up plus docs 1 [`b28f7a1`](https://github.com/OpenMined/syft.js/commit/b28f7a15a051ae6c15ccf523e6df8bb0fcbbd0ae) 58 | - re-init [`4936d3f`](https://github.com/OpenMined/syft.js/commit/4936d3ffd2a01da2432f758723df72a9853a7f33) 59 | - Docs [`5be80be`](https://github.com/OpenMined/syft.js/commit/5be80be0c45b79a45990f53f1b2bbd12ad7e2f49) 60 | 61 | #### [v0.0.1-1](https://github.com/OpenMined/syft.js/compare/v1.5.0...v0.0.1-1) 62 | 63 | > 29 December 2019 64 | 65 | - Update serde, fix with-node example [`#71`](https://github.com/OpenMined/syft.js/pull/71) 66 | - Revert merge conflict issues [`#69`](https://github.com/OpenMined/syft.js/pull/69) 67 | - Changing terminology from user to worker [`#68`](https://github.com/OpenMined/syft.js/pull/68) 68 | - TS type defs, closes #57 [`#67`](https://github.com/OpenMined/syft.js/pull/67) 69 | - Fix specific case 
in string serde [`#58`](https://github.com/OpenMined/syft.js/pull/58) 70 | - Closes #55 [`#56`](https://github.com/OpenMined/syft.js/pull/56) 71 | - Grid syft [`#54`](https://github.com/OpenMined/syft.js/pull/54) 72 | - Add unit tests for webrtc.js [`#52`](https://github.com/OpenMined/syft.js/pull/52) 73 | - Remove unused variables in socket.test.js [`#51`](https://github.com/OpenMined/syft.js/pull/51) 74 | - Add sockets.js tests and minor fixes [`#50`](https://github.com/OpenMined/syft.js/pull/50) 75 | - Adding Serde to syft.js [`#43`](https://github.com/OpenMined/syft.js/pull/43) 76 | - Merge pull request #67 from tisd/ts-type-defs [`#57`](https://github.com/OpenMined/syft.js/issues/57) 77 | - Merge pull request #56 from OpenMined/issue-55 [`#55`](https://github.com/OpenMined/syft.js/issues/55) 78 | - Closes #55 [`#55`](https://github.com/OpenMined/syft.js/issues/55) 79 | - Closes #53 [`#53`](https://github.com/OpenMined/syft.js/issues/53) 80 | - Working on with-grid watching issues [`c63ce1f`](https://github.com/OpenMined/syft.js/commit/c63ce1f8ff4c5eafe2c2862e359f6f3ad1164237) 81 | - Reverting back to NPM [`b1e8d39`](https://github.com/OpenMined/syft.js/commit/b1e8d390a851a21e4ec4d59ae80b3914f59915b9) 82 | - Switching back to NPM [`34f583e`](https://github.com/OpenMined/syft.js/commit/34f583e4ddc3cd46e61e34849bacc34bc5176a03) 83 | -------------------------------------------------------------------------------- /test/protobuf.test.js: -------------------------------------------------------------------------------- 1 | import { protobuf, unserialize, getPbId, serialize } from '../src/protobuf'; 2 | import { ObjectMessage } from '../src/types/message'; 3 | import Protocol from '../src/types/protocol'; 4 | import { Plan } from '../src/types/plan'; 5 | import { State } from '../src/types/state'; 6 | import { MNIST_PLAN, MNIST_MODEL_PARAMS, PROTOCOL } from './data/dummy'; 7 | import { TorchTensor } from '../src/types/torch'; 8 | import * as tf from 
'@tensorflow/tfjs-core'; 9 | import { Placeholder } from '../src/types/placeholder'; 10 | 11 | describe('Protobuf', () => { 12 | test('can unserialize an ObjectMessage', () => { 13 | const obj = unserialize( 14 | null, 15 | 'CjcKBwi91JDfnQISKgoECgICAxIHZmxvYXQzMrIBGAAAgD8AAABAAABAQAAAgEAAAKBAMzPDQEAE', 16 | protobuf.syft_proto.messaging.v1.ObjectMessage 17 | ); 18 | expect(obj).toBeInstanceOf(ObjectMessage); 19 | }); 20 | 21 | test('can unserialize a Protocol', () => { 22 | const protocol = unserialize( 23 | null, 24 | PROTOCOL, 25 | protobuf.syft_proto.execution.v1.Protocol 26 | ); 27 | expect(protocol).toBeInstanceOf(Protocol); 28 | }); 29 | 30 | test('can unserialize a Plan', () => { 31 | const plan = unserialize( 32 | null, 33 | MNIST_PLAN, 34 | protobuf.syft_proto.execution.v1.Plan 35 | ); 36 | expect(plan).toBeInstanceOf(Plan); 37 | }); 38 | 39 | test('can unserialize a State', () => { 40 | const state = unserialize( 41 | null, 42 | MNIST_MODEL_PARAMS, 43 | protobuf.syft_proto.execution.v1.State 44 | ); 45 | expect(state).toBeInstanceOf(State); 46 | }); 47 | 48 | test('can serialize a State', async () => { 49 | const placeholders = [ 50 | new Placeholder('123', ['tag1', 'tag2'], 'placeholder') 51 | ]; 52 | const tensors = [ 53 | await TorchTensor.fromTfTensor( 54 | tf.tensor([ 55 | [1.1, 2.2], 56 | [3.3, 4.4] 57 | ]) 58 | ) 59 | ]; 60 | const state = new State(placeholders, tensors); 61 | const serialized = serialize(null, state); 62 | 63 | // unserialize back to check 64 | const unserState = unserialize( 65 | null, 66 | serialized, 67 | protobuf.syft_proto.execution.v1.State 68 | ); 69 | expect(unserState).toBeInstanceOf(State); 70 | expect(unserState.id).toStrictEqual(state.id); 71 | expect(unserState.placeholders).toStrictEqual(placeholders); 72 | expect( 73 | tf 74 | .equal(unserState.tensors[0].toTfTensor(), tensors[0].toTfTensor()) 75 | .all() 76 | .dataSync()[0] 77 | ).toBe(1); 78 | }); 79 | 80 | test('can serialize TorchTensor', async () => { 81 | 
const tensor = tf.tensor([ 82 | [1.1, 2.2], 83 | [3.3, 4.4] 84 | ]); 85 | const torchTensor = await TorchTensor.fromTfTensor(tensor); 86 | torchTensor.tags = ['tag1', 'tag2']; 87 | torchTensor.description = 'description of tensor'; 88 | 89 | // serialize 90 | const bin = serialize(null, torchTensor); 91 | expect(bin).toBeInstanceOf(ArrayBuffer); 92 | 93 | // check unserialized matches to original 94 | const unserTorchTensor = unserialize( 95 | null, 96 | bin, 97 | protobuf.syft_proto.types.torch.v1.TorchTensor 98 | ); 99 | expect(unserTorchTensor.shape).toStrictEqual(torchTensor.shape); 100 | expect(unserTorchTensor.dtype).toStrictEqual(torchTensor.dtype); 101 | expect(unserTorchTensor.contents).toStrictEqual(torchTensor.contents); 102 | // resulting TF tensors are equal 103 | expect( 104 | tf 105 | .equal(unserTorchTensor.toTfTensor(), tensor) 106 | .all() 107 | .dataSync()[0] 108 | ).toBe(1); 109 | expect(unserTorchTensor.tags).toStrictEqual(torchTensor.tags); 110 | expect(unserTorchTensor.description).toStrictEqual(torchTensor.description); 111 | }); 112 | 113 | test('can serialize TorchTensor', async () => { 114 | const tensor = tf.tensor([ 115 | [1.1, 2.2], 116 | [3.3, 4.4] 117 | ]); 118 | const torchTensor = await TorchTensor.fromTfTensor(tensor); 119 | torchTensor.tags = ['tag1', 'tag2']; 120 | torchTensor.description = 'description of tensor'; 121 | 122 | // serialize 123 | const bin = serialize(null, torchTensor); 124 | expect(bin).toBeInstanceOf(ArrayBuffer); 125 | 126 | // check unserialized matches to original 127 | const unserTorchTensor = unserialize( 128 | null, 129 | bin, 130 | protobuf.syft_proto.types.torch.v1.TorchTensor 131 | ); 132 | expect(unserTorchTensor.shape).toStrictEqual(torchTensor.shape); 133 | expect(unserTorchTensor.dtype).toStrictEqual(torchTensor.dtype); 134 | expect(unserTorchTensor.contents).toStrictEqual(torchTensor.contents); 135 | // resulting TF tensors are equal 136 | expect( 137 | tf 138 | 
.equal(unserTorchTensor.toTfTensor(), tensor) 139 | .all() 140 | .dataSync()[0] 141 | ).toBe(1); 142 | expect(unserTorchTensor.tags).toStrictEqual(torchTensor.tags); 143 | expect(unserTorchTensor.description).toStrictEqual(torchTensor.description); 144 | }); 145 | 146 | test('gets id from types.syft.Id', () => { 147 | const protocolWithIntId = protobuf.syft_proto.execution.v1.Protocol.fromObject( 148 | { 149 | id: { 150 | id_int: 123 151 | } 152 | } 153 | ); 154 | const protocolWithStrId = protobuf.syft_proto.execution.v1.Protocol.fromObject( 155 | { 156 | id: { 157 | id_str: '321' 158 | } 159 | } 160 | ); 161 | expect(getPbId(protocolWithIntId.id)).toBe('123'); 162 | expect(getPbId(protocolWithStrId.id)).toBe('321'); 163 | }); 164 | }); 165 | -------------------------------------------------------------------------------- /examples/mnist/mnist.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs-core'; 19 | 20 | export const IMAGE_H = 28; 21 | export const IMAGE_W = 28; 22 | const IMAGE_SIZE = IMAGE_H * IMAGE_W; 23 | const NUM_CLASSES = 10; 24 | const NUM_DATASET_ELEMENTS = 65000; 25 | 26 | const NUM_TRAIN_ELEMENTS = 55000; 27 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; 28 | 29 | const MNIST_IMAGES_SPRITE_PATH = 30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; 31 | const MNIST_LABELS_PATH = 32 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; 33 | 34 | /** 35 | * A class that fetches the sprited MNIST dataset and provide data as 36 | * tf.Tensors. 37 | */ 38 | export class MnistData { 39 | constructor() {} 40 | 41 | async load() { 42 | // Make a request for the MNIST sprited image. 43 | const img = new Image(); 44 | const canvas = document.createElement('canvas'); 45 | const ctx = canvas.getContext('2d'); 46 | const imgRequest = new Promise((resolve, reject) => { 47 | img.crossOrigin = ''; 48 | img.onload = () => { 49 | img.width = img.naturalWidth; 50 | img.height = img.naturalHeight; 51 | 52 | const datasetBytesBuffer = new ArrayBuffer( 53 | NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4 54 | ); 55 | 56 | const chunkSize = 5000; 57 | canvas.width = img.width; 58 | canvas.height = chunkSize; 59 | 60 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { 61 | const datasetBytesView = new Float32Array( 62 | datasetBytesBuffer, 63 | i * IMAGE_SIZE * chunkSize * 4, 64 | IMAGE_SIZE * chunkSize 65 | ); 66 | ctx.drawImage( 67 | img, 68 | 0, 69 | i * chunkSize, 70 | img.width, 71 | chunkSize, 72 | 0, 73 | 0, 74 | img.width, 75 | chunkSize 76 | ); 77 | 78 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); 79 | 80 | for (let j = 0; j < imageData.data.length / 4; j++) { 81 | // All channels hold an equal value since 
the image is grayscale, so 82 | // just read the red channel. 83 | datasetBytesView[j] = imageData.data[j * 4] / 255; 84 | } 85 | } 86 | this.datasetImages = new Float32Array(datasetBytesBuffer); 87 | 88 | resolve(); 89 | }; 90 | img.src = MNIST_IMAGES_SPRITE_PATH; 91 | }); 92 | 93 | const labelsRequest = fetch(MNIST_LABELS_PATH); 94 | const [imgResponse, labelsResponse] = await Promise.all([ 95 | imgRequest, 96 | labelsRequest 97 | ]); 98 | 99 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); 100 | 101 | // Slice the the images and labels into train and test sets. 102 | this.trainImages = this.datasetImages.slice( 103 | 0, 104 | IMAGE_SIZE * NUM_TRAIN_ELEMENTS 105 | ); 106 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 107 | this.trainLabels = this.datasetLabels.slice( 108 | 0, 109 | NUM_CLASSES * NUM_TRAIN_ELEMENTS 110 | ); 111 | this.testLabels = this.datasetLabels.slice( 112 | NUM_CLASSES * NUM_TRAIN_ELEMENTS 113 | ); 114 | } 115 | 116 | /** 117 | * Get all training data as a data tensor and a labels tensor. 118 | * 119 | * @returns 120 | * xs: The data tensor, of shape `[numTrainExamples, 784]`. 121 | * labels: The one-hot encoded labels tensor, of shape 122 | * `[numTrainExamples, 10]`. 123 | */ 124 | getTrainData() { 125 | const xs = tf.tensor2d(this.trainImages, [ 126 | this.trainImages.length / IMAGE_SIZE, 127 | IMAGE_H * IMAGE_W 128 | ]); 129 | const labels = tf.tensor2d(this.trainLabels, [ 130 | this.trainLabels.length / NUM_CLASSES, 131 | NUM_CLASSES 132 | ]); 133 | return { xs, labels }; 134 | } 135 | 136 | /** 137 | * Get all test data as a data tensor a a labels tensor. 138 | * 139 | * @param {number} numExamples Optional number of examples to get. If not 140 | * provided, 141 | * all test examples will be returned. 142 | * @returns 143 | * xs: The data tensor, of shape `[numTestExamples, 784]`. 144 | * labels: The one-hot encoded labels tensor, of shape 145 | * `[numTestExamples, 10]`. 
146 | */ 147 | getTestData(numExamples = NUM_TEST_ELEMENTS) { 148 | let xs = tf.tensor2d(this.testImages, [ 149 | this.testImages.length / IMAGE_SIZE, 150 | IMAGE_H * IMAGE_W 151 | ]); 152 | let labels = tf.tensor2d(this.testLabels, [ 153 | this.testLabels.length / NUM_CLASSES, 154 | NUM_CLASSES 155 | ]); 156 | 157 | if (numExamples != null) { 158 | xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H * IMAGE_W]); 159 | labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]); 160 | } 161 | return { xs, labels }; 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { render } from 'react-dom'; 3 | import * as tf from '@tensorflow/tfjs-core'; 4 | import { Syft } from '@openmined/syft.js'; 5 | 6 | import App from './app.js'; 7 | 8 | // Define grid connection parameters 9 | const url = 'ws://localhost:5000'; 10 | const modelName = 'bandit'; 11 | const modelVersion = '1.0.0'; 12 | const shouldRepeat = false; 13 | 14 | // Pick random values f.or the layout 15 | const pickValue = p => p[Math.floor(Math.random() * p.length)]; 16 | 17 | // 24 possible configurations 18 | const appConfigPossibilities = { 19 | heroBackground: ['black', 'gradient'], 20 | buttonPosition: ['hero', 'vision'], 21 | buttonIcon: ['arrow', 'user', 'code'], 22 | buttonColor: ['blue', 'white'] 23 | }; 24 | 25 | // Final configuration for the app 26 | const appConfig = { 27 | heroBackground: pickValue(appConfigPossibilities.heroBackground), 28 | buttonPosition: pickValue(appConfigPossibilities.buttonPosition), 29 | buttonIcon: pickValue(appConfigPossibilities.buttonIcon), 30 | buttonColor: pickValue(appConfigPossibilities.buttonColor) 31 | }; 32 | 33 | // Set up an event listener for the button when it's clicked 34 | // TODO: @maddie - Submit the diff for a positive button click here... 
35 | const onButtonClick = () => { 36 | console.log( 37 | 'Clicked the button! Send a positive result for config', 38 | appConfig 39 | ); 40 | }; 41 | 42 | // Start React 43 | render( 44 | startFL(url, modelName, modelVersion, shouldRepeat)} 48 | />, 49 | document.getElementById('root') 50 | ); 51 | 52 | // Main start method 53 | const startFL = async ( 54 | url, 55 | modelName, 56 | modelVersion, 57 | authToken = null, 58 | shouldRepeat 59 | ) => { 60 | const worker = new Syft({ url, authToken, verbose: true }); 61 | const job = await worker.newJob({ modelName, modelVersion }); 62 | 63 | job.start(); 64 | 65 | job.on('accepted', async ({ model, clientConfig }) => { 66 | updateStatus('Accepted into cycle!'); 67 | 68 | // TODO: @maddie - Replace all of this with the bandit code, but try to still use the same 69 | // updateAfterBatch and updateStatus calls... those are helpful for the user to see! 70 | // // Load MNIST data 71 | // await loadMnistDataset(); 72 | // const trainDataset = mnist.getTrainData(); 73 | // const data = trainDataset.xs; 74 | // const targets = trainDataset.labels; 75 | 76 | // // Prepare randomized indices for data batching 77 | // const indices = Array.from({ length: data.shape[0] }, (v, i) => i); 78 | // tf.util.shuffle(indices); 79 | 80 | // // Prepare train parameters 81 | // const batchSize = clientConfig.batch_size; 82 | // const lr = clientConfig.lr; 83 | // const numBatches = Math.ceil(data.shape[0] / batchSize); 84 | 85 | // // Calculate total number of model updates 86 | // // in case none of these options specified, we fallback to one loop 87 | // // though all batches. 
88 | // const maxEpochs = clientConfig.max_epochs || 1; 89 | // const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches; 90 | // const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches); 91 | 92 | // // Copy model to train it 93 | // let modelParams = []; 94 | // for (let param of model.params) { 95 | // modelParams.push(param.clone()); 96 | // } 97 | 98 | // // Main training loop 99 | // for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) { 100 | // // Slice a batch 101 | // const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize); 102 | // const indicesBatch = indices.slice( 103 | // batch * batchSize, 104 | // batch * batchSize + chunkSize 105 | // ); 106 | // const dataBatch = data.gather(indicesBatch); 107 | // const targetBatch = targets.gather(indicesBatch); 108 | 109 | // // Execute the plan and get updated model params back 110 | // let [loss, acc, ...updatedModelParams] = await job.plans[ 111 | // 'training_plan' 112 | // ].execute( 113 | // job.worker, 114 | // dataBatch, 115 | // targetBatch, 116 | // chunkSize, 117 | // lr, 118 | // ...modelParams 119 | // ); 120 | 121 | // // Use updated model params in the next cycle 122 | // for (let i = 0; i < modelParams.length; i++) { 123 | // modelParams[i].dispose(); 124 | // modelParams[i] = updatedModelParams[i]; 125 | // } 126 | 127 | // await updateAfterBatch({ 128 | // epoch, 129 | // batch, 130 | // accuracy: await acc.array(), 131 | // loss: await loss.array() 132 | // }); 133 | 134 | // batch++; 135 | 136 | // // Check if we're out of batches (end of epoch) 137 | // if (batch === numBatches) { 138 | // batch = 0; 139 | // epoch++; 140 | // } 141 | 142 | // // Free GPU memory 143 | // acc.dispose(); 144 | // loss.dispose(); 145 | // dataBatch.dispose(); 146 | // targetBatch.dispose(); 147 | // } 148 | 149 | // // Free GPU memory 150 | // data.dispose(); 151 | // targets.dispose(); 152 | 153 | // // TODO protocol execution 154 | // // 
job.protocols['secure_aggregation'].execute(); 155 | 156 | // // Calc model diff 157 | // const modelDiff = await model.createSerializedDiff(modelParams); 158 | 159 | // // Report diff 160 | // await job.report(modelDiff); 161 | // updateStatus('Cycle is done!'); 162 | 163 | // // Try again... 164 | // if (shouldRepeat) { 165 | // setTimeout(startFL, 1000, url, modelName, modelVersion, authToken); 166 | // } 167 | }); 168 | 169 | job.on('rejected', ({ timeout }) => { 170 | // Handle the job rejection 171 | if (timeout) { 172 | const msUntilRetry = timeout * 1000; 173 | 174 | // Try to join the job again in "msUntilRetry" milliseconds 175 | updateStatus(`Rejected from cycle, retry in ${timeout}`); 176 | setTimeout(job.start.bind(job), msUntilRetry); 177 | } else { 178 | updateStatus( 179 | `Rejected from cycle with no timeout, assuming Model training is complete.` 180 | ); 181 | } 182 | }); 183 | 184 | job.on('error', err => { 185 | updateStatus(`Error: ${err.message}`); 186 | }); 187 | }; 188 | 189 | // Status update message 190 | const updateStatus = message => { 191 | console.log('STATUS', message); 192 | }; 193 | 194 | // Log statistics after each batch 195 | const updateAfterBatch = async ({ epoch, batch, accuracy, loss }) => { 196 | console.log( 197 | `Epoch: ${epoch}`, 198 | `Batch: ${batch}`, 199 | `Accuracy: ${accuracy}`, 200 | `Loss: ${loss}` 201 | ); 202 | 203 | await tf.nextFrame(); 204 | }; 205 | -------------------------------------------------------------------------------- /src/speed-test.js: -------------------------------------------------------------------------------- 1 | import { randomFillSync } from 'randomfill'; 2 | 3 | export class SpeedTest { 4 | constructor({ 5 | downloadUrl, 6 | uploadUrl, 7 | pingUrl, 8 | maxUploadSizeMb = 64, 9 | maxTestTimeSec = 10 10 | }) { 11 | this.downloadUrl = downloadUrl; 12 | this.uploadUrl = uploadUrl; 13 | this.pingUrl = pingUrl; 14 | this.maxUploadSizeMb = maxUploadSizeMb; 15 | this.maxTestTimeSec = 
maxTestTimeSec; 16 | 17 | // Various settings to tune. 18 | this.bwAvgWindow = 5; 19 | this.bwLowJitterThreshold = 0.05; 20 | this.bwMaxLowJitterConsecutiveMeasures = 5; 21 | } 22 | 23 | async meterXhr(xhr, isUpload = false) { 24 | return new Promise((resolve, reject) => { 25 | let timeoutHandler = null, 26 | prevTime = 0, 27 | prevSize = 0, 28 | avgCollector = new AvgCollector({ 29 | avgWindow: this.bwAvgWindow, 30 | lowJitterThreshold: this.bwLowJitterThreshold, 31 | maxLowJitterConsecutiveMeasures: this 32 | .bwMaxLowJitterConsecutiveMeasures 33 | }); 34 | 35 | const req = isUpload ? xhr.upload : xhr; 36 | 37 | const finish = (error = null) => { 38 | if (timeoutHandler) { 39 | clearTimeout(timeoutHandler); 40 | } 41 | 42 | // clean up 43 | req.onprogress = null; 44 | req.onload = null; 45 | req.onerror = null; 46 | xhr.abort(); 47 | 48 | if (!error) { 49 | resolve(avgCollector.getAvg()); 50 | } else { 51 | reject(new Error(error)); 52 | } 53 | }; 54 | 55 | req.onreadystatechange = () => { 56 | if (xhr.readyState === 1) { 57 | // set speed test timeout 58 | timeoutHandler = setTimeout(finish, this.maxTestTimeSec * 1000); 59 | } 60 | }; 61 | 62 | req.onprogress = e => { 63 | const // mbit 64 | size = (8 * e.loaded) / 1048576, 65 | // seconds 66 | time = Date.now() / 1000; 67 | 68 | if (!prevTime) { 69 | prevTime = time; 70 | prevSize = size; 71 | return; 72 | } 73 | 74 | let deltaSize = size - prevSize, 75 | deltaTime = time - prevTime, 76 | speed = deltaSize / deltaTime; 77 | 78 | if (deltaTime === 0 || !Number.isFinite(speed)) { 79 | prevTime = time; 80 | prevSize = size; 81 | return; 82 | } 83 | 84 | const canStop = avgCollector.collect(speed); 85 | if (canStop) { 86 | finish(); 87 | } 88 | 89 | prevSize = size; 90 | prevTime = time; 91 | }; 92 | 93 | req.onload = () => { 94 | finish(); 95 | }; 96 | req.onerror = e => { 97 | finish(e); 98 | }; 99 | }); 100 | } 101 | 102 | async getDownloadSpeed() { 103 | let xhr = new XMLHttpRequest(); 104 | const result = 
this.meterXhr(xhr); 105 | 106 | xhr.open('GET', this.downloadUrl + '?' + Math.random(), true); 107 | xhr.send(); 108 | 109 | return result; 110 | } 111 | 112 | async getUploadSpeed() { 113 | const xhr = new XMLHttpRequest(); 114 | const result = this.meterXhr(xhr, true); 115 | 116 | // Create random bytes buffer. 117 | const buff = new Uint8Array(this.maxUploadSizeMb * 1024 * 1024); 118 | const maxRandomChunkSize = 65536; 119 | const chunkNum = Math.ceil(buff.byteLength / maxRandomChunkSize); 120 | for ( 121 | let chunk = 0, offset = 0; 122 | chunk < chunkNum; 123 | chunk++, offset += maxRandomChunkSize 124 | ) { 125 | randomFillSync( 126 | buff, 127 | offset, 128 | Math.min(maxRandomChunkSize, buff.byteLength - offset) 129 | ); 130 | } 131 | 132 | xhr.open('POST', this.uploadUrl, true); 133 | xhr.send(buff); 134 | 135 | return result; 136 | } 137 | 138 | async getPing() { 139 | return new Promise((resolve, reject) => { 140 | const avgCollector = new AvgCollector({}); 141 | let currXhr; 142 | let timeoutHandler; 143 | 144 | const finish = (xhr, error = null) => { 145 | if (timeoutHandler) { 146 | clearTimeout(timeoutHandler); 147 | } 148 | 149 | // clean up 150 | xhr.onprogress = null; 151 | xhr.onload = null; 152 | xhr.onerror = null; 153 | xhr.abort(); 154 | 155 | if (!error) { 156 | resolve(avgCollector.getAvg()); 157 | } else { 158 | reject(new Error(error)); 159 | } 160 | }; 161 | 162 | const runPing = () => { 163 | const xhr = new XMLHttpRequest(); 164 | currXhr = xhr; 165 | let startTime = Date.now(); 166 | 167 | xhr.onload = () => { 168 | const ping = Date.now() - startTime; 169 | const canStop = avgCollector.collect(ping); 170 | if (canStop) { 171 | finish(xhr); 172 | } else { 173 | setTimeout(runPing, 0); 174 | } 175 | }; 176 | 177 | xhr.onerror = e => { 178 | finish(xhr, e); 179 | }; 180 | 181 | xhr.open('GET', this.pingUrl + '?' 
+ Math.random(), true); 182 | xhr.send(); 183 | }; 184 | 185 | timeoutHandler = setTimeout(() => { 186 | finish(currXhr); 187 | }, this.maxTestTimeSec * 1000); 188 | runPing(); 189 | }); 190 | } 191 | } 192 | 193 | /** 194 | * Helper to average series of values 195 | */ 196 | class AvgCollector { 197 | constructor({ 198 | avgWindow = 5, 199 | lowJitterThreshold = 0.05, 200 | maxLowJitterConsecutiveMeasures = 5 201 | }) { 202 | this.measuresCount = 0; 203 | this.prevAvg = 0; 204 | this.avg = 0; 205 | this.lowJitterConsecutiveMeasures = 0; 206 | 207 | this.avgWindow = avgWindow; 208 | this.lowJitterThreshold = lowJitterThreshold; 209 | this.maxLowJitterConsecutiveMeasures = maxLowJitterConsecutiveMeasures; 210 | this.name = name; 211 | } 212 | 213 | collect(value) { 214 | this.prevAvg = this.avg; 215 | const avgWindow = Math.min(this.measuresCount, this.avgWindow); 216 | this.avg = (this.avg * avgWindow + value) / (avgWindow + 1); 217 | this.measuresCount++; 218 | 219 | // Return true if measurements are stable. 
220 | if ( 221 | this.prevAvg > 0 && 222 | this.avg < this.prevAvg * (1 + this.lowJitterThreshold) && 223 | this.avg > this.prevAvg * (1 - this.lowJitterThreshold) 224 | ) { 225 | this.lowJitterConsecutiveMeasures++; 226 | } else { 227 | this.lowJitterConsecutiveMeasures = 0; 228 | } 229 | 230 | if ( 231 | this.lowJitterConsecutiveMeasures >= this.maxLowJitterConsecutiveMeasures 232 | ) { 233 | return true; 234 | } 235 | 236 | return false; 237 | } 238 | 239 | getAvg() { 240 | return this.avg; 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /examples/mnist/index.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs-core'; 2 | import { Syft } from '@openmined/syft.js'; 3 | import { MnistData } from './mnist'; 4 | 5 | const gridServer = document.getElementById('grid-server'); 6 | const startButton = document.getElementById('start'); 7 | let mnist = null; 8 | 9 | startButton.onclick = () => { 10 | setFLUI(); 11 | const modelName = document.getElementById('model-id').value; 12 | const modelVersion = document.getElementById('model-version').value; 13 | const authToken = document.getElementById('auth-token').value; 14 | startFL(gridServer.value, modelName, modelVersion, authToken).catch(err => { 15 | updateStatus(`Error: ${err}`); 16 | }); 17 | }; 18 | 19 | /** 20 | * The main federated learning training routine 21 | * @param url PyGrid Url 22 | * @param modelName Federated learning model name hosted in PyGrid 23 | * @param modelVersion Federated learning model version 24 | * @returns {Promise} 25 | */ 26 | const startFL = async (url, modelName, modelVersion, authToken = null) => { 27 | const worker = new Syft({ url, authToken, verbose: true }); 28 | const job = await worker.newJob({ modelName, modelVersion }); 29 | 30 | job.start(); 31 | 32 | job.on('accepted', async ({ model, clientConfig }) => { 33 | updateStatus('Accepted into cycle!'); 34 | 
35 | // Load MNIST data 36 | await loadMnistDataset(); 37 | const trainDataset = mnist.getTrainData(); 38 | const data = trainDataset.xs; 39 | const targets = trainDataset.labels; 40 | 41 | // Prepare randomized indices for data batching. 42 | const indices = Array.from({ length: data.shape[0] }, (v, i) => i); 43 | tf.util.shuffle(indices); 44 | 45 | // Prepare train parameters. 46 | const batchSize = clientConfig.batch_size; 47 | const lr = clientConfig.lr; 48 | const numBatches = Math.ceil(data.shape[0] / batchSize); 49 | 50 | // Calculate total number of model updates 51 | // in case none of these options specified, we fallback to one loop 52 | // though all batches. 53 | const maxEpochs = clientConfig.max_epochs || 1; 54 | const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches; 55 | const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches); 56 | 57 | // Copy model to train it. 58 | let modelParams = []; 59 | for (let param of model.params) { 60 | modelParams.push(param.clone()); 61 | } 62 | 63 | // Main training loop. 64 | for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) { 65 | // Slice a batch. 66 | const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize); 67 | const indicesBatch = indices.slice( 68 | batch * batchSize, 69 | batch * batchSize + chunkSize 70 | ); 71 | const dataBatch = data.gather(indicesBatch); 72 | const targetBatch = targets.gather(indicesBatch); 73 | 74 | // Execute the plan and get updated model params back. 75 | let [loss, acc, ...updatedModelParams] = await job.plans[ 76 | 'training_plan' 77 | ].execute( 78 | job.worker, 79 | dataBatch, 80 | targetBatch, 81 | chunkSize, 82 | lr, 83 | ...modelParams 84 | ); 85 | 86 | // Use updated model params in the next cycle. 
87 | for (let i = 0; i < modelParams.length; i++) { 88 | modelParams[i].dispose(); 89 | modelParams[i] = updatedModelParams[i]; 90 | } 91 | 92 | await updateUIAfterBatch({ 93 | epoch, 94 | batch, 95 | accuracy: await acc.array(), 96 | loss: await loss.array() 97 | }); 98 | 99 | batch++; 100 | 101 | // Check if we're out of batches (end of epoch). 102 | if (batch === numBatches) { 103 | batch = 0; 104 | epoch++; 105 | } 106 | 107 | // Free GPU memory. 108 | acc.dispose(); 109 | loss.dispose(); 110 | dataBatch.dispose(); 111 | targetBatch.dispose(); 112 | } 113 | 114 | // Free GPU memory. 115 | data.dispose(); 116 | targets.dispose(); 117 | 118 | // TODO protocol execution 119 | // job.protocols['secure_aggregation'].execute(); 120 | 121 | // Calc model diff. 122 | const modelDiff = await model.createSerializedDiff(modelParams); 123 | 124 | // Report diff. 125 | await job.report(modelDiff); 126 | updateStatus('Cycle is done!'); 127 | 128 | // Try again. 129 | if (doRepeat()) { 130 | setTimeout(startFL, 1000, url, modelName, modelVersion, authToken); 131 | } 132 | }); 133 | 134 | job.on('rejected', ({ timeout }) => { 135 | // Handle the job rejection. 136 | if (timeout) { 137 | const msUntilRetry = timeout * 1000; 138 | // Try to join the job again in "msUntilRetry" milliseconds 139 | updateStatus(`Rejected from cycle, retry in ${timeout}`); 140 | setTimeout(job.start.bind(job), msUntilRetry); 141 | } else { 142 | updateStatus( 143 | `Rejected from cycle with no timeout, assuming Model training is complete.` 144 | ); 145 | } 146 | }); 147 | 148 | job.on('error', err => { 149 | updateStatus(`Error: ${err.message}`); 150 | }); 151 | }; 152 | 153 | /** 154 | * Loads MNIST dataset into global variable `mnist`. 
155 | */ 156 | const loadMnistDataset = async () => { 157 | if (!mnist) { 158 | updateStatus('Loading MNIST data...'); 159 | mnist = new MnistData(); 160 | await mnist.load(); 161 | updateStatus('MNIST data loaded.'); 162 | } 163 | }; 164 | 165 | /** 166 | * Log message on the page. 167 | * @param message 168 | */ 169 | const updateStatus = message => { 170 | const cont = document.getElementById('status'); 171 | cont.innerHTML = message + '
' + cont.innerHTML; 172 | }; 173 | 174 | /** 175 | * Initializes loss & accuracy plots. 176 | */ 177 | const setFLUI = () => { 178 | Plotly.newPlot( 179 | 'loss_graph', 180 | [{ y: [], mode: 'lines', line: { color: '#80CAF6' } }], 181 | { title: 'Train Loss', showlegend: false }, 182 | { staticPlot: true } 183 | ); 184 | 185 | Plotly.newPlot( 186 | 'acc_graph', 187 | [{ y: [], mode: 'lines', line: { color: '#80CAF6' } }], 188 | { title: 'Train Accuracy', showlegend: false }, 189 | { staticPlot: true } 190 | ); 191 | 192 | document.getElementById('fl-training').style.display = 'table'; 193 | }; 194 | 195 | /** 196 | * Updates graphs after each batch. 197 | * @param epoch 198 | * @param batch 199 | * @param accuracy 200 | * @param loss 201 | * @returns {Promise} 202 | */ 203 | const updateUIAfterBatch = async ({ epoch, batch, accuracy, loss }) => { 204 | console.log( 205 | `Epoch: ${epoch}, Batch: ${batch}, Accuracy: ${accuracy}, Loss: ${loss}` 206 | ); 207 | Plotly.extendTraces('loss_graph', { y: [[loss]] }, [0]); 208 | Plotly.extendTraces('acc_graph', { y: [[accuracy]] }, [0]); 209 | await tf.nextFrame(); 210 | }; 211 | 212 | const doRepeat = () => document.getElementById('worker-repeat').checked; 213 | -------------------------------------------------------------------------------- /API-REFERENCE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### Table of Contents 4 | 5 | - [Syft][1] 6 | - [Parameters][2] 7 | - [Examples][3] 8 | - [newJob][4] 9 | - [Parameters][5] 10 | - [Job][6] 11 | - [Properties][7] 12 | - [on][8] 13 | - [Parameters][9] 14 | - [start][10] 15 | - [Parameters][11] 16 | - [report][12] 17 | - [Parameters][13] 18 | - [Job#accepted][14] 19 | - [Properties][15] 20 | - [Job#rejected][16] 21 | - [Properties][17] 22 | - [Job#error][18] 23 | - [SyftModel][19] 24 | - [Properties][20] 25 | - [createSerializedDiff][21] 26 | - [Parameters][22] 27 | - [Plan][23] 28 | - [execute][24] 29 | - [Parameters][25] 30 | 
31 | ## Syft 32 | 33 | Syft client for static federated learning. 34 | 35 | ### Parameters 36 | 37 | - `options` **[Object][26]** 38 | - `options.url` **[string][27]** Full URL to PyGrid app (`ws` and `http` schemes supported). 39 | - `options.verbose` **[boolean][28]** Whether to enable logging and allow an unsecured PyGrid connection. 40 | - `options.authToken` **[string][27]** PyGrid authentication token. 41 | - `options.peerConfig` **[Object][26]** [not implemented] WebRTC peer config used with RTCPeerConnection. 42 | 43 | ### Examples 44 | 45 | ```javascript 46 | const client = new Syft({url: "ws://localhost:5000", verbose: true}) 47 | const job = await client.newJob({modelName: "mnist", modelVersion: "1.0.0"}) 48 | job.on('accepted', async ({model, clientConfig}) => { 49 | // execute training 50 | const [...newParams] = await job.plans['...'].execute(...) 51 | const diff = await model.createSerializedDiff(newParams) 52 | await job.report(diff) 53 | }) 54 | job.on('rejected', ({timeout}) => { 55 | // re-try later or stop 56 | }) 57 | job.on('error', (err) => { 58 | // handle errors 59 | }) 60 | job.start() 61 | ``` 62 | 63 | ### newJob 64 | 65 | Authenticates the client against PyGrid and instantiates a new Job with the given options. 66 | 67 | #### Parameters 68 | 69 | - `options` **[Object][26]** 70 | - `options.modelName` **[string][27]** FL Model name. 71 | - `options.modelVersion` **[string][27]** FL Model version. 72 | 73 | 74 | - Throws **any** Error 75 | 76 | Returns **[Promise][29]<[Job][30]>** 77 | 78 | ## Job 79 | 80 | Job represents a single training cycle done by the client. 81 | 82 | ### Properties 83 | 84 | - `plans` **[Object][26]<[string][27], [Plan][31]>** Plans dictionary. 85 | - `protocols` **[Object][26]<[string][27], Protocol>** [not implemented] Protocols dictionary. 86 | - `model` **[SyftModel][32]** Model. 87 | 88 | ### on 89 | 90 | Registers an event listener. 91 | 92 | Available events: `accepted`, `rejected`, `error`.
93 | 94 | #### Parameters 95 | 96 | - `event` **[string][27]** Event name. 97 | - `handler` **[function][33]** Event handler. 98 | 99 | ### start 100 | 101 | Starts the Job, executing the following actions: 102 | 103 | - Meters connection speed to PyGrid. 104 | - Registers for a training cycle on PyGrid. 105 | - Retrieves cycle and client parameters. 106 | - Downloads Plans, Model, Protocols. 107 | - Fires `accepted` event on success. 108 | 109 | #### Parameters 110 | 111 | - `options` **[Object][26]** (optional, default `{}`) 112 | - `options.skipGridSpeedTest` **[boolean][28]** When true, skips the speed test before requesting a cycle. (optional, default `false`) 113 | 114 | Returns **[Promise][29]<void>** 115 | 116 | ### report 117 | 118 | Submits the model diff to PyGrid. 119 | 120 | #### Parameters 121 | 122 | - `diff` **[ArrayBuffer][34]** Serialized difference between original and trained model parameters. 123 | 124 | Returns **[Promise][29]<void>** 125 | 126 | ## Job#accepted 127 | 128 | `accepted` event. 129 | Triggered when PyGrid accepts the client into a training cycle. 130 | 131 | Type: [Object][26] 132 | 133 | ### Properties 134 | 135 | - `model` **[SyftModel][32]** Instance of SyftModel. 136 | - `clientConfig` **[Object][26]** Client configuration returned by PyGrid. 137 | 138 | ## Job#rejected 139 | 140 | `rejected` event. 141 | Triggered when PyGrid rejects the client. 142 | 143 | Type: [Object][26] 144 | 145 | ### Properties 146 | 147 | - `timeout` **([number][35] | null)** Time in seconds to wait before retrying; `null` when the FL model is no longer trainable. 148 | 149 | ## Job#error 150 | 151 | `error` event. 152 | Triggered for a variety of error conditions. 153 | 154 | ## SyftModel 155 | 156 | Model parameters as stored in PyGrid. 157 | 158 | ### Properties 159 | 160 | - `params` **[Array][36]<tf.Tensor>** Array of Model parameters.
161 | 162 | ### createSerializedDiff 163 | 164 | Calculates the difference between two versions of the model parameters 165 | and returns a serialized `diff` that can be submitted to PyGrid. 166 | 167 | #### Parameters 168 | 169 | - `updatedModelParams` **[Array][36]<tf.Tensor>** Array of model parameters (tensors). 170 | 171 | Returns **[Promise][29]<[ArrayBuffer][34]>** Protobuf-serialized `diff`. 172 | 173 | ## Plan 174 | 175 | PySyft Plan. 176 | 177 | ### execute 178 | 179 | Executes the Plan and returns its output. 180 | 181 | The order, type, and number of arguments must match the arguments defined in the PySyft Plan. 182 | 183 | #### Parameters 184 | 185 | - `worker` **[Syft][37]** 186 | - `data` **...(tf.Tensor | [number][35])** 187 | 188 | Returns **[Promise][29]<[Array][36]<tf.Tensor>>** 189 | 190 | [1]: #syft 191 | 192 | [2]: #parameters 193 | 194 | [3]: #examples 195 | 196 | [4]: #newjob 197 | 198 | [5]: #parameters-1 199 | 200 | [6]: #job 201 | 202 | [7]: #properties 203 | 204 | [8]: #on 205 | 206 | [9]: #parameters-2 207 | 208 | [10]: #start 209 | 210 | [11]: #parameters-3 211 | 212 | [12]: #report 213 | 214 | [13]: #parameters-4 215 | 216 | [14]: #jobaccepted 217 | 218 | [15]: #properties-1 219 | 220 | [16]: #jobrejected 221 | 222 | [17]: #properties-2 223 | 224 | [18]: #joberror 225 | 226 | [19]: #syftmodel 227 | 228 | [20]: #properties-3 229 | 230 | [21]: #createserializeddiff 231 | 232 | [22]: #parameters-5 233 | 234 | [23]: #plan 235 | 236 | [24]: #execute 237 | 238 | [25]: #parameters-6 239 | 240 | [26]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Object 241 | 242 | [27]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String 243 | 244 | [28]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean 245 | 246 | [29]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise 247 | 248 | [30]: #job 249 | 250 | [31]: #plan 251 | 252 | [32]: #syftmodel
253 | 254 | [33]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Statements/function 255 | 256 | [34]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/ArrayBuffer 257 | 258 | [35]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number 259 | 260 | [36]: https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array 261 | 262 | [37]: #syft 263 | -------------------------------------------------------------------------------- /test/types/plan.test.js: -------------------------------------------------------------------------------- 1 | import { Plan } from '../../src/types/plan'; 2 | import { State } from '../../src/types/state'; 3 | import { TorchTensor } from '../../src/types/torch'; 4 | import { Placeholder, PlaceholderId } from '../../src/types/placeholder'; 5 | import * as tf from '@tensorflow/tfjs-core'; 6 | import { protobuf, unserialize } from '../../src/protobuf'; 7 | import Syft from '../../src/syft'; 8 | import { 9 | MNIST_BATCH_SIZE, 10 | MNIST_LR, 11 | MNIST_BATCH_DATA, 12 | MNIST_PLAN, 13 | MNIST_MODEL_PARAMS, 14 | MNIST_UPD_MODEL_PARAMS, 15 | MNIST_LOSS, 16 | MNIST_ACCURACY, 17 | PLAN_WITH_STATE, 18 | BANDIT_SIMPLE_MODEL_PARAMS, 19 | BANDIT_SIMPLE_PLAN, 20 | BANDIT_THOMPSON_PLAN, 21 | BANDIT_THOMPSON_MODEL_PARAMS 22 | } from '../data/dummy'; 23 | import { ComputationAction } from '../../src/types/computation-action'; 24 | import { Role } from '../../src/types/role'; 25 | 26 | describe('State', () => { 27 | test('can be properly constructed', () => { 28 | const ph = new Placeholder(123); 29 | const ts = new TorchTensor( 30 | 234, 31 | new Float32Array([1, 2, 3, 4]), 32 | [2, 2], 33 | 'float32' 34 | ); 35 | const obj = new State([ph], [ts]); 36 | expect(obj.placeholders).toStrictEqual([ph]); 37 | expect(obj.tensors).toStrictEqual([ts]); 38 | }); 39 | }); 40 | 41 | describe('Plan', () => { 42 | test('can be properly constructed', () => { 43 | const phId1 = new 
PlaceholderId(555); 44 | const phId2 = new PlaceholderId(666); 45 | const ph1 = new Placeholder(555); 46 | const ph2 = new Placeholder(666); 47 | const action = new ComputationAction( 48 | 'torch.abs', 49 | null, 50 | [phId1], 51 | null, 52 | null, 53 | [phId2] 54 | ); 55 | const state = new State([], []); 56 | const role = new Role(777, [action], state, [ph1, ph2], [phId1], [phId2]); 57 | const obj = new Plan(123, 'plan', role, ['tag1', 'tag2'], 'desc'); 58 | expect(obj.id).toBe(123); 59 | expect(obj.name).toBe('plan'); 60 | expect(obj.role.actions).toStrictEqual([action]); 61 | expect(obj.role.state).toBe(state); 62 | expect(obj.role.placeholders).toStrictEqual([ph1, ph2]); 63 | expect(obj.role.input_placeholder_ids).toStrictEqual([phId1]); 64 | expect(obj.role.output_placeholder_ids).toStrictEqual([phId2]); 65 | expect(obj.tags).toStrictEqual(['tag1', 'tag2']); 66 | expect(obj.description).toStrictEqual('desc'); 67 | }); 68 | 69 | test('can be executed (with state)', async () => { 70 | const input = tf.tensor([ 71 | [1, 2], 72 | [-30, -40] 73 | ]); 74 | // this is what plan contains 75 | const state = tf.tensor([4.2, 7.3]); 76 | const expected = tf.abs(tf.add(input, state)); 77 | const plan = unserialize( 78 | null, 79 | PLAN_WITH_STATE, 80 | protobuf.syft_proto.execution.v1.Plan 81 | ); 82 | const worker = new Syft({ url: 'dummy' }); 83 | const result = await plan.execute(worker, input); 84 | expect(result[0]).toBeInstanceOf(tf.Tensor); 85 | expect( 86 | tf 87 | .equal(result[0], expected) 88 | .all() 89 | .dataSync()[0] 90 | ).toBe(1); 91 | }); 92 | 93 | test('invalid args shape throws corresponding error', async () => { 94 | // PLAN_WITH_STATE plan contains input + A(2,2) 95 | // this should error with tf error about incompatible shapes 96 | const input = tf.ones([3, 3]); 97 | const plan = unserialize( 98 | null, 99 | PLAN_WITH_STATE, 100 | protobuf.syft_proto.execution.v1.Plan 101 | ); 102 | const worker = new Syft({ url: 'dummy' }); 103 | 
expect(plan.execute(worker, input)).rejects.toThrow( 104 | 'Operands could not be broadcast together with shapes 3,3 and 2.' 105 | ); 106 | }); 107 | 108 | test('can be executed (MNIST example)', async () => { 109 | const plan = unserialize( 110 | null, 111 | MNIST_PLAN, 112 | protobuf.syft_proto.execution.v1.Plan 113 | ); 114 | 115 | const worker = new Syft({ url: 'dummy' }); 116 | const dataState = unserialize( 117 | null, 118 | MNIST_BATCH_DATA, 119 | protobuf.syft_proto.execution.v1.State 120 | ); 121 | const [data, labels] = dataState.tensors; 122 | const lr = tf.tensor(MNIST_LR); 123 | const batchSize = tf.tensor(MNIST_BATCH_SIZE); 124 | const modelState = unserialize( 125 | null, 126 | MNIST_MODEL_PARAMS, 127 | protobuf.syft_proto.execution.v1.State 128 | ); 129 | const modelParams = modelState.tensors; 130 | const [loss, acc, ...updModelParams] = await plan.execute( 131 | worker, 132 | data, 133 | labels, 134 | batchSize, 135 | lr, 136 | ...modelParams 137 | ); 138 | 139 | const refUpdModelParamsState = unserialize( 140 | null, 141 | MNIST_UPD_MODEL_PARAMS, 142 | protobuf.syft_proto.execution.v1.State 143 | ); 144 | const refUpdModelParams = refUpdModelParamsState.tensors.map(i => 145 | i.toTfTensor() 146 | ); 147 | 148 | expect(loss).toBeInstanceOf(tf.Tensor); 149 | expect(acc).toBeInstanceOf(tf.Tensor); 150 | expect(updModelParams).toHaveLength(refUpdModelParams.length); 151 | expect(loss.arraySync()).toStrictEqual(MNIST_LOSS); 152 | expect(acc.arraySync()).toStrictEqual(MNIST_ACCURACY); 153 | 154 | for (let i = 0; i < refUpdModelParams.length; i++) { 155 | // Check that resulting model params are close to pysyft reference 156 | let diff = refUpdModelParams[i].sub(updModelParams[i]); 157 | expect( 158 | diff 159 | .abs() 160 | .sum() 161 | .arraySync() 162 | ).toBeLessThan(1e-7); 163 | } 164 | }); 165 | 166 | test('bandit (simple) example can be executed', async () => { 167 | const plan = unserialize( 168 | null, 169 | BANDIT_SIMPLE_PLAN, 170 | 
protobuf.syft_proto.execution.v1.Plan 171 | ); 172 | 173 | const worker = new Syft({ url: 'dummy' }); 174 | const modelState = unserialize( 175 | null, 176 | BANDIT_SIMPLE_MODEL_PARAMS, 177 | protobuf.syft_proto.execution.v1.State 178 | ); 179 | const [means] = modelState.tensors; 180 | 181 | const reward = tf.tensor([0, 1, 0]); 182 | const n_so_far = tf.tensor([1, 1, 1]); 183 | const [newMeans] = await plan.execute(worker, reward, n_so_far, means); 184 | 185 | newMeans.print(); 186 | }); 187 | 188 | test('bandit (thompson) example can be executed', async () => { 189 | const plan = unserialize( 190 | null, 191 | BANDIT_THOMPSON_PLAN, 192 | protobuf.syft_proto.execution.v1.Plan 193 | ); 194 | 195 | const worker = new Syft({ url: 'dummy' }); 196 | const modelState = unserialize( 197 | null, 198 | BANDIT_THOMPSON_MODEL_PARAMS, 199 | protobuf.syft_proto.execution.v1.State 200 | ); 201 | const [alphas, betas] = modelState.tensors; 202 | const reward = tf.tensor([0, 0, 0]); 203 | const samples = tf.tensor([0, 0, 1]); 204 | const [newAlphas, newBetas] = await plan.execute( 205 | worker, 206 | reward, 207 | samples, 208 | alphas, 209 | betas 210 | ); 211 | 212 | newAlphas.print(); 213 | newBetas.print(); 214 | }); 215 | }); 216 | -------------------------------------------------------------------------------- /test/data/generate-data.py: -------------------------------------------------------------------------------- 1 | # This python script generates values in dummy.js, for MNIST plan unit tests. 2 | # Script should be executed with appropriate PySyft version installed. 
3 | 4 | import os 5 | import base64 6 | 7 | import torch as th 8 | from torch import jit 9 | from torch import nn 10 | from torchvision import datasets, transforms 11 | 12 | import syft as sy 13 | from syft.serde import protobuf 14 | from syft.execution.state import State 15 | from syft.execution.placeholder import PlaceHolder 16 | from syft.execution.translation import TranslationTarget 17 | 18 | sy.make_hook(globals()) 19 | # force protobuf serialization for tensors 20 | hook.local_worker.framework = None 21 | th.random.manual_seed(1) 22 | 23 | def serialize_to_b64_pb(worker, obj): 24 | pb = protobuf.serde._bufferize(worker, obj) 25 | bin = pb.SerializeToString() 26 | return base64.b64encode(bin).decode('ascii') 27 | 28 | 29 | def tensors_to_state(tensors): 30 | return State( 31 | state_placeholders=[ 32 | PlaceHolder().instantiate(t) 33 | for t in tensors 34 | ] 35 | ) 36 | 37 | def set_model_params(module, params_list, start_param_idx=0): 38 | """ Set params list into model recursively 39 | """ 40 | param_idx = start_param_idx 41 | 42 | for name, param in module._parameters.items(): 43 | module._parameters[name] = params_list[param_idx] 44 | param_idx += 1 45 | 46 | for name, child in module._modules.items(): 47 | if child is not None: 48 | param_idx = set_model_params(child, params_list, param_idx) 49 | 50 | return param_idx 51 | 52 | # = MNIST = 53 | 54 | class Net(nn.Module): 55 | def __init__(self): 56 | super(Net, self).__init__() 57 | self.fc1 = nn.Linear(784, 392) 58 | self.fc2 = nn.Linear(392, 10) 59 | 60 | def forward(self, x): 61 | x = self.fc1(x) 62 | x = nn.functional.relu(x) 63 | x = self.fc2(x) 64 | return x 65 | 66 | def softmax_cross_entropy_with_logits(logits, targets, batch_size): 67 | """ Calculates softmax entropy 68 | Args: 69 | * logits: (NxC) outputs of dense layer 70 | * targets: (NxC) one-hot encoded labels 71 | * batch_size: value of N, temporarily required because Plan cannot trace .shape 72 | """ 73 | # numstable logsoftmax 74 | 
norm_logits = logits - logits.max() 75 | log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log() 76 | # NLL, reduction = mean 77 | return -(targets * log_probs).sum() / batch_size 78 | 79 | def naive_sgd(param, **kwargs): 80 | return param - kwargs['lr'] * param.grad 81 | 82 | model = Net() 83 | 84 | @sy.func2plan() 85 | def training_plan(X, y, batch_size, lr, model_params): 86 | # inject params into model 87 | set_model_params(model, model_params) 88 | 89 | # forward pass 90 | logits = model.forward(X) 91 | 92 | # loss 93 | loss = softmax_cross_entropy_with_logits(logits, y, batch_size) 94 | 95 | # backprop 96 | loss.backward() 97 | 98 | # step 99 | updated_params = [ 100 | naive_sgd(param, lr=lr) 101 | for param in model_params 102 | ] 103 | 104 | # accuracy 105 | pred = th.argmax(logits, dim=1) 106 | target = th.argmax(y, dim=1) 107 | acc = pred.eq(target).sum().float() / batch_size 108 | 109 | return ( 110 | loss, 111 | acc, 112 | *updated_params 113 | ) 114 | 115 | # Build the training plan. 116 | bs = 64 117 | lr = 0.005 118 | model_params = tuple(p.data for p in model.parameters()) 119 | X = th.randn(bs, 28 * 28) 120 | y = nn.functional.one_hot(th.randint(0, 10, (bs,)), 10) 121 | training_plan.build(X, y, th.tensor([bs]), th.tensor([lr]), model_params, trace_autograd=True) 122 | 123 | # Produce training plan result after the first batch. 
124 | mnist_dataset = th.utils.data.DataLoader( 125 | datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()), 126 | batch_size=bs, 127 | shuffle=False 128 | ) 129 | X, y = next(iter(mnist_dataset)) 130 | X = X.view(bs, -1) 131 | y_oh = th.nn.functional.one_hot(y, 10).float() 132 | training_plan.forward = None 133 | loss, acc, *upd_params = training_plan(X, y_oh, th.tensor([bs], dtype=th.float32), th.tensor([lr]), model_params) 134 | 135 | print("MNIST plan (torch): ") 136 | print(training_plan.code) 137 | 138 | training_plan.base_framework = TranslationTarget.TENSORFLOW_JS.value 139 | print("MNIST plan: ") 140 | print(training_plan.code) 141 | 142 | # = Plan with state = 143 | @sy.func2plan(args_shape=[(2,2)], state=(th.tensor([4.2, 7.3]),)) 144 | def plan_with_state(x, state): 145 | (y,) = state.read() 146 | x = x + y 147 | x = th.abs(x) 148 | return x 149 | 150 | plan_with_state.base_framework = TranslationTarget.TENSORFLOW_JS.value 151 | print("Plan w/ state: ") 152 | print(plan_with_state.code) 153 | 154 | # = Bandit plans = 155 | # Simple 156 | reward = th.tensor([0.0, 0.0, 0.0]) 157 | n_so_far = th.tensor([1.0, 1.0, 1.0]) 158 | means = th.tensor([1.0, 2.0, 3.0]) 159 | bandit_args = [reward, n_so_far, means] 160 | bandit_arg_shape = [arg.shape for arg in bandit_args] 161 | @sy.func2plan(args_shape=bandit_arg_shape) 162 | def bandit(reward, n_so_far, means): 163 | prev = means 164 | new = th.div(prev*(n_so_far-1),n_so_far) + th.div(reward,n_so_far) 165 | means=new 166 | return means 167 | 168 | bandit.base_framework = TranslationTarget.TENSORFLOW_JS.value 169 | print("Bandit simple plan: ") 170 | print(bandit.code) 171 | 172 | # Thompson 173 | alphas = th.tensor([1.0, 1.0, 1.0], requires_grad=False) 174 | betas = th.tensor([1.0, 1.0, 1.0], requires_grad=False) 175 | rwd = th.tensor([0.0, 0.0, 0.0]) 176 | samples = th.tensor([0.0, 0.0, 0.0]) 177 | bandit_args_th = [rwd, samples, alphas, betas] 178 | bandit_th_args_shape = [rwd.shape, 
samples.shape, alphas.shape, betas.shape] 179 | @sy.func2plan(args_shape=bandit_th_args_shape) 180 | def bandit_thompson(reward, sample_vector, alphas, betas): 181 | prev_alpha = alphas 182 | prev_beta = betas 183 | alphas = prev_alpha.add(reward) 184 | betas = prev_beta.add(sample_vector.sub(reward)) 185 | return (alphas, betas) 186 | 187 | bandit_thompson.base_framework = TranslationTarget.TENSORFLOW_JS.value 188 | print("Bandit thompson plan: ") 189 | print(bandit_thompson.code) 190 | 191 | replacements = { 192 | 'MNIST_BATCH_SIZE': bs, 193 | 'MNIST_LR': lr, 194 | 'MNIST_PLAN': serialize_to_b64_pb(hook.local_worker, training_plan), 195 | 'MNIST_BATCH_DATA': serialize_to_b64_pb(hook.local_worker, tensors_to_state([X, y_oh])), 196 | 'MNIST_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state(model_params)), 197 | 'MNIST_UPD_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state(upd_params)), 198 | 'MNIST_LOSS': loss.item(), 199 | 'MNIST_ACCURACY': acc.item(), 200 | 'PLAN_WITH_STATE': serialize_to_b64_pb(hook.local_worker, plan_with_state), 201 | 'BANDIT_SIMPLE_PLAN': serialize_to_b64_pb(hook.local_worker, bandit), 202 | 'BANDIT_SIMPLE_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state([means])), 203 | 'BANDIT_THOMPSON_PLAN': serialize_to_b64_pb(hook.local_worker, bandit_thompson), 204 | 'BANDIT_THOMPSON_MODEL_PARAMS': serialize_to_b64_pb(hook.local_worker, tensors_to_state([alphas, betas])), 205 | } 206 | 207 | with open("dummy.tpl.js", "r") as tpl, open("dummy.js", "w") as output: 208 | js_tpl = tpl.read() 209 | for k, v in replacements.items(): 210 | js_tpl = js_tpl.replace(f"%{k}%", str(v)) 211 | output.write(js_tpl) 212 | -------------------------------------------------------------------------------- /examples/multi-armed-bandit/app.js: -------------------------------------------------------------------------------- 1 | /** @jsx jsx */ 2 | import { jsx } from '@emotion/core'; 3 | import React, { useEffect 
} from 'react'; 4 | 5 | import backgroundGradient from './background-gradient.svg'; 6 | import logoWhite from './logo-white.svg'; 7 | import logoColor from './logo-color.svg'; 8 | 9 | import '@fortawesome/fontawesome-free/js/all'; 10 | 11 | const fontFamily = 12 | '"Rubik", -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"'; 13 | 14 | const globalStyles = { 15 | paragraph: { 16 | fontFamily, 17 | lineHeight: 1.75, 18 | fontSize: 18 19 | }, 20 | heading: { 21 | fontFamily, 22 | lineHeight: 1.5, 23 | textTransform: 'uppercase', 24 | fontWeight: 700, 25 | letterSpacing: 2, 26 | marginTop: 0, 27 | marginBottom: 0, 28 | fontSize: 24 29 | } 30 | }; 31 | 32 | const Hero = ({ background, button }) => { 33 | const title = 'Answer Questions Using Data You Cannot See'; 34 | const description = `OpenMined is an open-source community whose goal is to make the world more privacy-preserving by lowering the barrier-to-entry to private AI technologies.`; 35 | 36 | const styles = { 37 | container: { 38 | background: 39 | background === 'gradient' ? `url(${backgroundGradient})` : '#333', 40 | backgroundRepeat: 'no-repeat', 41 | backgroundSize: 'cover', 42 | padding: 40, 43 | display: 'flex', 44 | flexDirection: 'column', 45 | alignItems: 'center' 46 | }, 47 | logo: { 48 | width: 200, 49 | height: 'auto' 50 | }, 51 | title: { 52 | color: 'white', 53 | marginTop: 40, 54 | width: 480, 55 | textAlign: 'center' 56 | }, 57 | description: { 58 | color: 'white', 59 | marginTop: 40, 60 | marginBottom: 0, 61 | width: 480, 62 | textAlign: 'center' 63 | } 64 | }; 65 | 66 | return ( 67 |
68 | OpenMined 69 |

{title}

70 |

71 | {description} 72 |

73 | {button && button} 74 |
75 | ); 76 | }; 77 | 78 | const Vision = ({ button }) => { 79 | const title = 'Vision & Mission'; 80 | const description = ` 81 |

Industry standard tools for artificial intelligence have been designed with several assumptions: data is centralized into a single compute cluster, the cluster exists in a secure cloud, and the resulting models will be owned by a central authority. We envision a world in which we are not restricted to this scenario - a world in which AI tools treat privacy, security, and multi-owner governance as first class citizens.

82 |

With OpenMined, an AI model can be governed by multiple owners and trained securely on an unseen, distributed dataset.

83 |

The mission of the OpenMined community is to create an accessible ecosystem of tools for private, secure, multi-owner governed AI. We do this by extending popular libraries like TensorFlow and PyTorch with advanced techniques in cryptography and private machine learning.

84 | `; 85 | 86 | const styles = { 87 | container: { 88 | padding: 40, 89 | maxWidth: 960, 90 | width: '90%', 91 | margin: '0 auto' 92 | } 93 | }; 94 | 95 | return ( 96 |
97 |

{title}

98 |
102 | {button && button} 103 |
104 | ); 105 | }; 106 | 107 | const Button = ({ background, icon, onClick }) => ( 108 | 141 | ); 142 | 143 | const Footer = () => { 144 | const styles = { 145 | wrapper: { 146 | background: '#333', 147 | padding: 40 148 | }, 149 | container: { 150 | display: 'flex', 151 | justifyContent: 'space-between', 152 | width: 940, 153 | margin: '0 auto' 154 | }, 155 | logo: { 156 | width: 200, 157 | height: 'auto' 158 | }, 159 | social: { 160 | display: 'flex', 161 | alignItems: 'center' 162 | }, 163 | socialIcon: { 164 | color: 'rgba(255, 255, 255, 0.5)', 165 | transition: 'color 0.2s ease-in-out', 166 | fontSize: 24, 167 | marginLeft: 20, 168 | '&:hover': { 169 | color: '#FFF' 170 | } 171 | } 172 | }; 173 | return ( 174 |
175 |
176 | OpenMined 177 | 207 |
208 |
209 | ); 210 | }; 211 | 212 | export default ({ config, onButtonClick, start }) => { 213 | useEffect(() => { 214 | start(); 215 | }, []); 216 | 217 | const button = ( 218 |