├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── app ├── .prettierrc ├── .watchmanconfig ├── .yarnclean ├── app.json ├── babel.config.js ├── eas.json ├── google-services.json ├── index.js ├── metro.config.js ├── package.json ├── src │ ├── App.tsx │ ├── communications │ │ ├── Sockets.ts │ │ └── Supabase.ts │ ├── components │ │ ├── Home.tsx │ │ ├── Login.tsx │ │ ├── Outh.ts │ │ ├── Signup.tsx │ │ ├── auth │ │ │ └── PlatformAuth.native.tsx │ │ └── primary │ │ │ ├── Button.tsx │ │ │ ├── SectionBreak.tsx │ │ │ ├── SocialAuthButton.tsx │ │ │ ├── SocialAuthButtons.tsx │ │ │ ├── TextField.tsx │ │ │ ├── TrainToggle.tsx │ │ │ └── Typography.tsx │ ├── ml │ │ ├── Config.ts │ │ ├── Diagnostics.ts │ │ ├── Evaluation.ts │ │ ├── Losses.ts │ │ ├── ModelHandler.ts │ │ ├── Optimizers.ts │ │ ├── Prediction.ts │ │ ├── TensorflowHandler.ts │ │ ├── Training.ts │ │ └── index.ts │ ├── styles │ │ ├── Style.ts │ │ └── theme.ts │ ├── types │ │ └── declarations.d.ts │ └── utils │ │ └── CurrentDevice.ts └── tsconfig.json ├── assets ├── .DS_Store └── logo.png ├── example └── job.py └── framework ├── mfl ├── __init__.py ├── common.py ├── data.py ├── federated.py ├── keras_h5_conversion.py ├── quantization.py ├── read_weights.py ├── trainer.py ├── worker.py └── write_weights.py ├── requirements.txt └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/FederatedPhoneML/0ecab2cb0bef83e1c049c854d8a98c9cee661e29/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | yarn.lock 29 | package-lock.json 30 | node_modules/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | *.ipynb* 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | .idea/ 166 | 167 | app/.expo 168 | 169 | # iOS Pods that are autogenerated in dozens 170 | app/ios/ 171 | app/android/ 172 | 173 | *.ipa 174 | *.apk 175 | 176 | *.DS_STORE* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 FederatedPhoneML 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mobile Federated Learning 2 | 3 |

4 | FederatedPhoneML Logo 5 |

6 | 7 | A mobile-first framework for distributed machine learning on phones, enabling on-device training via federated learning using a Python SDK. 8 | 9 | ## Table of Contents 10 | 11 | 1. [Technical Design, Benefits, and Limitations](#technical-design-benefits-and-limitations) 12 | 2. [Features](#features) 13 | 3. [Prerequisites](#prerequisites) 14 | 4. [Installation](#installation) 15 | - [Supabase Setup](#supabase-setup) 16 | - [Python SDK](#python-sdk) 17 | - [Mobile App (Expo)](#mobile-app-expo) 18 | 5. [Usage](#usage) 19 | - [Python SDK](#python-sdk-usage) 20 | - [React Native App](#react-native-app-usage) 21 | 6. [API Reference](#api-reference) 22 | - [ReceiveConfig](#receiveconfig) 23 | - [SendConfig](#sendconfig) 24 | 7. [Technical Considerations](#technical-considerations) 25 | 8. [Performance Notes](#performance-notes) 26 | 9. [Contributing](#contributing) 27 | 10. [License](#license) 28 | 29 | ## Technical Design, Benefits, and Limitations 30 | 31 | ### Core Design 32 | 33 | FederatedPhoneML is a distributed machine learning framework enabling on-device training via federated learning on mobile phones. The architecture consists of: 34 | 35 | 1. **Python SDK**: Server-side coordination layer using Keras for model definition 36 | 2. **Mobile Client**: React Native app with TensorFlow.js that executes training locally 37 | 3. **Coordination Backend**: Supabase for task distribution and weight aggregation 38 | 39 | Training occurs directly on users' devices rather than centralizing data collection. The system sends model architecture and initial weights to phones, which perform local training iterations on device-specific data, then return only the updated weights. 40 | 41 | ### Technical Benefits 42 | 43 | 1. **Data Privacy**: Raw training data never leaves user devices; only model weight updates are transmitted 44 | 2. **Bandwidth Efficiency**: Transmitting model weights requires significantly less bandwidth than raw data transfer 45 | 3. **Distributed Computation**: Leverages computational resources across many devices rather than centralized servers 46 | 4. **Heterogeneous Data Utilization**: Captures diverse training signals from varied user environments and behaviors 47 | 5. **Battery-Aware Processing**: Implements batch-wise processing to manage memory constraints on mobile devices 48 | 49 | ### Technical Limitations 50 | 51 | 1. **Resource Constraints**: Mobile devices have limited RAM and computational power, restricting model complexity 52 | 2. **Battery Consumption**: On-device training significantly increases energy usage, requiring careful optimization 53 | 3. **Training Latency**: Federation rounds are limited by slowest devices and network conditions 54 | 4. **Model Size Limitations**: TensorFlow.js imposes practical limits on model architecture complexity 55 | 5. **Heterogeneous Performance**: Training behavior varies across device hardware, potentially introducing bias 56 | 6. **WebGL Limitations**: Backend acceleration varies significantly across mobile GPUs 57 | 7. **Synchronization Challenges**: Coordination of model versions across devices with intermittent connectivity 58 | 8. **Security Vulnerabilities**: Weight updates can potentially leak information about training data 59 | 60 | ### Implementation Constraints 61 | 62 | 1. Limited to Keras/TensorFlow models with specific serialization requirements 63 | 2. TensorFlow.js performance varies significantly across browsers and devices 64 | 3. 
Tensor cleanup requires explicit memory management to prevent leaks 65 | 4. No differential privacy implementation yet to fully protect against inference attacks 66 | 5. Limited support for asynchronous federated learning when devices have variable availability 67 | 68 | ## Features 69 | 70 | - Distributed federated training with gradient accumulation 71 | - Memory-efficient batch processing 72 | - Real-time model evaluation and inference on mobile 73 | - Automatic resource management and tensor cleanup 74 | - Comprehensive error handling and fallback support 75 | 76 | ## Prerequisites 77 | 78 | - Node.js (>= 14.x) 79 | - Yarn (>= 1.x) 80 | - Python (>= 3.7) 81 | - expo-cli (for React Native) 82 | - CocoaPods (for iOS dependencies) 83 | - [Supabase](https://supabase.com) account (for backend services) 84 | 85 | ## Installation 86 | 87 | ### Supabase Setup 88 | 89 | 1. Create a new Supabase project at [supabase.com](https://supabase.com) 90 | 2. Get your project URL and anon key from Settings → API 91 | 3. Create a `.env` file in the app root with: 92 | ```bash 93 | EXPO_PUBLIC_SUPABASE_URL=your-supabase-url 94 | EXPO_PUBLIC_SUPABASE_ANON_KEY=your-supabase-anon-key 95 | ``` 96 | 4. Create these tables in your Supabase database: 97 | ```sql 98 | -- Devices table 99 | CREATE TABLE devices ( 100 | id SERIAL PRIMARY KEY, 101 | user_id UUID REFERENCES auth.users NOT NULL, 102 | status VARCHAR(20) NOT NULL, 103 | last_updated TIMESTAMPTZ 104 | ); 105 | 106 | -- Task Requests table 107 | CREATE TABLE task_requests ( 108 | id SERIAL PRIMARY KEY, 109 | device_id INTEGER REFERENCES devices(id) NOT NULL, 110 | request_type VARCHAR(20) NOT NULL, 111 | data JSONB NOT NULL, 112 | created_at TIMESTAMPTZ DEFAULT NOW() 113 | ); 114 | 115 | -- Task Responses table 116 | CREATE TABLE task_responses ( 117 | id SERIAL PRIMARY KEY, 118 | task_id INTEGER REFERENCES task_requests(id) NOT NULL, 119 | data JSONB NOT NULL, 120 | created_at TIMESTAMPTZ DEFAULT NOW() 121 | ); 122 | ``` 123 | 124 | ### Python SDK 125 | 126 | 1. Navigate to the SDK folder: 127 | ```bash 128 | cd python 129 | ```2. Install in editable mode: 130 | ```bash 131 | pip install -e . 132 | ``` 133 | 134 | ### Mobile App (Expo) 135 | 136 | 1. Install dependencies: 137 | ```bash 138 | yarn 139 | ``` 140 | 2. Start the Expo development server: 141 | ```bash 142 | yarn start 143 | ``` 144 | 145 | #### iOS Setup 146 | 147 | ```bash 148 | cd ios 149 | pod install 150 | cd .. 
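# pod install links the autogenerated native iOS project; back at the app root,
# `yarn` reinstalls the JS dependencies and `yarn start` launches the Expo dev server.
# Pressing `i` in that interactive CLI opens the app in the iOS Simulator
# (the Android block below uses `a` the same way).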
151 | yarn 152 | yarn start 153 | i 154 | ``` 155 | 156 | #### Android Setup 157 | 158 | ```bash 159 | yarn start 160 | a 161 | ``` 162 | 163 | ## Usage 164 | 165 | ### Python SDK Usage 166 | 167 | ```python 168 | from mfl import keras, Trainer 169 | 170 | # Define a simple model 171 | model = keras.Sequential([ 172 | keras.layers.Input(shape=(1,)), 173 | keras.layers.Dense(units=1) 174 | ]) 175 | model.compile(optimizer='sgd', loss='mean_squared_error') 176 | 177 | # Prepare training data 178 | inputs = [[...]] 179 | outputs = [[...]] 180 | 181 | # Initialize federated Trainer 182 | trainer = Trainer( 183 | model, 184 | inputs, 185 | outputs, 186 | batch_size=2 187 | ) 188 | 189 | # Train for 10 epochs 190 | trainer.fit(epochs=10) 191 | ``` 192 | 193 | ### React Native App Usage 194 | 195 | ```typescript 196 | import { isAvailable, train, evaluate, predict } from './ml'; 197 | 198 | // Example training call 199 | type ReceiveConfig = { 200 | modelJson: string; 201 | weights: string[]; 202 | batchSize: number; 203 | inputs: number[][]; 204 | inputShape: number[]; 205 | outputs?: number[][]; 206 | outputShape?: number[]; 207 | epochs?: number; 208 | }; 209 | 210 | async function runFederatedTraining(config: ReceiveConfig) { 211 | if (await isAvailable()) { 212 | const sendConfig = await train(config); 213 | console.log('Updated weights:', sendConfig.weights); 214 | console.log('Loss:', sendConfig.loss); 215 | } else { 216 | console.warn('Federated learning not available on this device.'); 217 | } 218 | } 219 | ``` 220 | 221 | ## API Reference 222 | 223 | ### ReceiveConfig 224 | 225 | ```typescript 226 | interface ReceiveConfig { 227 | modelJson: string; // URL to the model JSON manifest 228 | weights: string[]; // Array of weight shard file names 229 | batchSize: number; // Micro-batch size for local training 230 | inputs: number[][]; // Local input data array 231 | inputShape: number[]; // Shape of the input tensor 232 | outputs?: number[][]; // Local output data (for training/evaluation) 233 | outputShape?: number[]; // Shape of the output tensor 234 | epochs?: number; // Number of local training epochs 235 | datasetsPerDevice?: number; // Number of batches per device 236 | } 237 | ``` 238 | 239 | ### SendConfig 240 | 241 | ```typescript 242 | interface SendConfig { 243 | weights: Float32Array[]; // Updated model weights after operation 244 | outputs?: Float32Array[]; // Predictions (for evaluate/predict) 245 | loss: number; // Loss value after training 246 | } 247 | ``` 248 | 249 | ## Technical Considerations 250 | 251 | - **Automatic Tensor Disposal:** Prevents memory leaks by disposing of unused tensors. 252 | - **Batch-wise Processing:** Optimizes memory usage on resource-constrained devices. 253 | - **Error Handling:** Graceful cleanup and retry strategies on failures. 254 | - **Background Limits:** Adheres to mobile OS restrictions for background tasks. 255 | - **Network Optimization:** Minimizes bandwidth usage during federated communication. 256 | 257 | ## Performance Notes 258 | 259 | - WebGL performance and model size can impact load/inference times. 260 | - Battery consumption rises during on-device training; monitor usage. 261 | - Network latency affects federated rounds; consider compression. 262 | - iOS and Android hardware differences may yield variable performance. 263 | 264 | ## Contributing 265 | 266 | Contributions are welcome! To contribute: 267 | 268 | 1. Fork the repository. 269 | 2. Create a feature branch: `git checkout -b feature/my-feature`. 270 | 3. 
Implement your changes and commit. 271 | 4. Push to your fork: `git push origin feature/my-feature`. 272 | 5. Open a pull request describing your changes. 273 | 274 | ## License 275 | 276 | This project is licensed under the MIT License. See [LICENSE](LICENSE) for details. 277 | -------------------------------------------------------------------------------- /app/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "semi": true, 3 | "singleQuote": true, 4 | "trailingComma": "es5", 5 | "tabWidth": 2, 6 | "printWidth": 80 7 | } 8 | -------------------------------------------------------------------------------- /app/.watchmanconfig: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /app/.yarnclean: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.md 3 | test 4 | tests 5 | example 6 | examples 7 | docs 8 | -------------------------------------------------------------------------------- /app/app.json: -------------------------------------------------------------------------------- 1 | { 2 | "expo": { 3 | "name": "mfl", 4 | "icon": "src/assets/images/appstore1024.png", 5 | "slug": "mfl", 6 | "owner": "rshemet", 7 | "version": "1.0.0", 8 | "sdkVersion": "51.0.0", 9 | "splash": { 10 | "image": "src/assets/images/appstore1024.png", 11 | "backgroundColor": "#0D0D0D" 12 | }, 13 | "assetBundlePatterns": [ 14 | "**/*" 15 | ], 16 | "extra": { 17 | "eas": { 18 | "projectId": "7e086f19-c37d-4a4e-ba43-7e59fbc312ee" 19 | } 20 | }, 21 | "ios": { 22 | "bundleIdentifier": "com.rshemet.mfl" 23 | }, 24 | "android": { 25 | "package": "com.rshemet.mfl", 26 | "googleServicesFile": "./google-services.json" 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /app/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = function(api) { 2 | api.cache(true); 3 | return { 4 | presets: ['babel-preset-expo'], 5 | plugins: [ 6 | [ 7 | 'module:react-native-dotenv', 8 | { 9 | moduleName: '@env', 10 | path: '.env', 11 | blocklist: null, 12 | allowlist: null, 13 | safe: false, 14 | allowUndefined: true, 15 | verbose: false, 16 | } 17 | ], 18 | ], 19 | }; 20 | }; 21 | -------------------------------------------------------------------------------- /app/eas.json: -------------------------------------------------------------------------------- 1 | { 2 | "cli": { 3 | "version": ">= 13.4.2", 4 | "appVersionSource": "remote" 5 | }, 6 | "build": { 7 | "development": { 8 | "developmentClient": true, 9 | "distribution": "internal" 10 | }, 11 | "preview": { 12 | "distribution": "internal", 13 | "android": { 14 | "buildType": "apk" 15 | } 16 | }, 17 | "production": { 18 | "autoIncrement": true 19 | } 20 | }, 21 | "submit": { 22 | "production": {} 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /app/google-services.json: -------------------------------------------------------------------------------- 1 | "create yours" -------------------------------------------------------------------------------- /app/index.js: -------------------------------------------------------------------------------- 1 | import 'react-native-url-polyfill/auto'; 2 | import { registerRootComponent } from 'expo'; 3 | 4 | import App from './src/App'; 5 | 6 | // registerRootComponent 
calls AppRegistry.registerComponent('main', () => App); 7 | // It also ensures that whether you load the app in Expo Go or in a native build, 8 | // the environment is set up appropriately 9 | registerRootComponent(App); 10 | -------------------------------------------------------------------------------- /app/metro.config.js: -------------------------------------------------------------------------------- 1 | const { getDefaultConfig } = require('expo/metro-config'); 2 | 3 | const config = getDefaultConfig(__dirname); 4 | 5 | module.exports = config; -------------------------------------------------------------------------------- /app/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mfl", 3 | "main": "index.js", 4 | "version": "0.0.1", 5 | "license": "Apache-2.0", 6 | "scripts": { 7 | "reset": "rm -rf node_modules && rm yarn.lock", 8 | "start": "expo start", 9 | "build-ios-local": "EXPO_NO_CAPABILITY_SYNC=1 eas build --platform ios --local --profile preview", 10 | "build-android-local": "eas build --platform android --profile preview --local" 11 | }, 12 | "dependencies": { 13 | "@react-native-async-storage/async-storage": "1.23.1", 14 | "@supabase/supabase-js": "^2.46.1", 15 | "@tensorflow/tfjs": "^4.22.0", 16 | "@tensorflow/tfjs-react-native": "^1.0.0", 17 | "expo": "^51.0.0", 18 | "expo-apple-authentication": "~6.4.2", 19 | "expo-build-properties": "~0.12.5", 20 | "expo-camera": "~15.0.16", 21 | "expo-file-system": "~17.0.1", 22 | "expo-gl": "~14.0.2", 23 | "expo-keep-awake": "~13.0.2", 24 | "react": "18.2.0", 25 | "react-native": "0.74.5", 26 | "react-native-dotenv": "^3.4.11", 27 | "react-native-fs": "^2.20.0", 28 | "react-native-paper": "^5.12.5", 29 | "react-native-safe-area-context": "4.10.5", 30 | "react-native-svg": "15.2.0", 31 | "react-native-url-polyfill": "^2.0.0", 32 | "styled-components": "^6.1.13" 33 | }, 34 | "devDependencies": { 35 | "@babel/core": "^7.25.2", 36 | "@types/react": "~18.2.79", 37 | "typescript": "~5.3.3" 38 | }, 39 | "jest": { 40 | "preset": "react-native" 41 | }, 42 | "private": true 43 | } 44 | -------------------------------------------------------------------------------- /app/src/App.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from 'react'; 2 | import { SafeAreaView } from 'react-native'; 3 | import { styledTheme } from './styles/theme'; 4 | import { ThemeProvider } from 'styled-components/native'; 5 | import LoginScreen from './components/Login'; 6 | import HomePage from './components/Home'; 7 | import { restoreSession, saveSession, clearSession } from './components/Outh'; 8 | import { checkSession, onAuthStateChange } from './components/Outh'; 9 | import { useKeepAwake } from 'expo-keep-awake'; 10 | import theme from './styles/theme'; 11 | 12 | const App = () => { 13 | const [isLoggedIn, setIsLoggedIn] = useState(false); 14 | const [loginInProgress, setLoginInProgress] = useState(false); 15 | 16 | useKeepAwake(); 17 | 18 | useEffect(() => { 19 | const initSession = async () => { 20 | const restored = await restoreSession(); 21 | if (!restored) { 22 | const isLoggedInA = await checkSession(); 23 | setIsLoggedIn(isLoggedInA); 24 | } else { 25 | setIsLoggedIn(true); 26 | } 27 | }; 28 | 29 | initSession(); 30 | 31 | const cleanup = onAuthStateChange(async (isLoggedIn, session) => { 32 | setIsLoggedIn(isLoggedIn); 33 | 34 | if (isLoggedIn && session) { 35 | await saveSession(session); 36 | 
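// Login/Signup set loginInProgress=true when they mount; once the session has been
// persisted we clear the flag so the HomePage branch rendered below can take over.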
setLoginInProgress(false); 37 | } else if (!isLoggedIn) { 38 | await clearSession(); 39 | } 40 | }); 41 | 42 | return cleanup; 43 | }, []); 44 | 45 | return ( 46 | 47 | 48 | {isLoggedIn && !loginInProgress ? ( 49 | 50 | ) : ( 51 | 52 | )} 53 | 54 | 55 | ); 56 | }; 57 | 58 | export default App; 59 | -------------------------------------------------------------------------------- /app/src/communications/Sockets.ts: -------------------------------------------------------------------------------- 1 | import { supabase, insertRow, setDeviceAvailability } from "./Supabase"; 2 | import {isAvailable, train, evaluate, predict} from '../ml' 3 | import { ReceiveConfig } from "../ml/Config"; 4 | 5 | import { getCurrentDeviceID } from "../utils/CurrentDevice"; 6 | 7 | const activeSubscriptions: Record = {}; // a basic dictionary to store active Supabase channels like {'tasks_device_1': channel} 8 | 9 | const createRealtimeChannelName = (table:string, deviceId:string) => { 10 | // simple function used to generate standardized generic names for subscriptions to realtime tables 11 | return `${table}_device_${deviceId}` 12 | } 13 | 14 | async function subscribeToRealtimeTable( 15 | table: string, 16 | eventType: 'INSERT' | 'UPDATE' | 'DELETE', 17 | callback: Function 18 | ): Promise { 19 | getCurrentDeviceID().then((deviceId => { 20 | const channelName = createRealtimeChannelName(table, deviceId) 21 | const channel = supabase 22 | .channel(channelName) 23 | .on( 24 | 'postgres_changes', 25 | { 26 | event: eventType, 27 | schema: 'public', 28 | table: table, 29 | filter: `device_id=eq.${deviceId}` 30 | }, 31 | (payload: Object) => { 32 | callback(payload) 33 | } 34 | ).subscribe((status) => { 35 | if (status === 'SUBSCRIBED') { 36 | console.log(`Successfully subscribed to ${table} for device_id: ${deviceId}`); 37 | activeSubscriptions[channelName] = channel 38 | } 39 | }); 40 | })) 41 | } 42 | 43 | export function joinNetwork() { 44 | subscribeToRealtimeTable( 45 | 'task_requests', 46 | 'INSERT', 47 | async (payload: object) => { 48 | const taskId = payload.new.id 49 | const task_type = payload.new.request_type 50 | const requestConfig = payload.new.data 51 | const responseData = await handleNewMLTask(task_type, requestConfig); 52 | await setDeviceAvailability('available'); 53 | insertRow( 54 | 'task_responses', 55 | {id: taskId, data: responseData} 56 | ) 57 | } 58 | ) 59 | } 60 | 61 | export async function leaveNetwork() { 62 | getCurrentDeviceID().then((deviceId) => { 63 | const tableName = 'task_requests'; 64 | const channelName = createRealtimeChannelName(tableName, deviceId) 65 | const channel = activeSubscriptions[channelName]; 66 | if (!channel) { 67 | console.log(`No active subscription found for ${tableName} and device_id: ${deviceId}`); 68 | return; 69 | } 70 | supabase.removeChannel(channel); 71 | delete activeSubscriptions[channelName]; // Clean up the subscription reference 72 | console.log(`Unsubscribed from ${tableName} for device_id: ${deviceId}`); 73 | }) 74 | } 75 | 76 | enum TaskType { 77 | train = 'train', 78 | evaluate = 'evaluate', 79 | predict = 'predict', 80 | } 81 | 82 | async function handleNewMLTask( 83 | taskType: TaskType, 84 | requestData: ReceiveConfig 85 | ): Promise { 86 | await setDeviceAvailability('busy') 87 | switch (taskType){ 88 | case TaskType.train: 89 | return await train(requestData) 90 | case TaskType.evaluate: 91 | return await evaluate(requestData) 92 | case TaskType.predict: 93 | return await predict(requestData) 94 | default: 95 | console.warn(`Unhandled task 
type: ${taskType}`); 96 | return 97 | } 98 | } -------------------------------------------------------------------------------- /app/src/communications/Supabase.ts: -------------------------------------------------------------------------------- 1 | import { useEffect } from 'react'; 2 | import { createClient } from '@supabase/supabase-js'; 3 | import { setCurrentDeviceID } from 'src/utils/CurrentDevice'; 4 | import AsyncStorage from '@react-native-async-storage/async-storage'; 5 | 6 | export const supabase = createClient(process.env.EXPO_PUBLIC_SUPABASE_URL, process.env.EXPO_PUBLIC_SUPABASE_ANON_KEY); 7 | 8 | export async function insertRow(table: string, data: Record): Promise { 9 | // Inserts any data object into a specified supabase table 10 | try { 11 | const { error } = await supabase.from(table).insert([data]); 12 | if (error) {throw error;} 13 | console.log(`Row inserted successfully into ${table}`); 14 | } catch (err) { 15 | console.error('Error inserting row:', err); 16 | } 17 | } 18 | 19 | export async function updateTableRows( 20 | table: string, 21 | criteria: Record, 22 | updates: Record 23 | ): Promise { 24 | // Updates any row(s) in a specified supabase table based on criteria 25 | // criteria is an object that specifies row conditions. Example: {id: 1, user_id: 2} 26 | // updates is an object that will update the row(s) matching the condition. Example: {status: "unavailable"} 27 | try { 28 | const { error } = await supabase.from(table).update(updates).match(criteria); 29 | if (error) {throw error;} 30 | console.log(`Row(s) updated successfully in ${table}`); 31 | } catch (err) { 32 | console.error('Error updating row(s):', err); 33 | } 34 | } 35 | 36 | export async function fetchDeviceAvailability(): Promise{ 37 | // This function is called after initial login, to know in what state the toggle was last left 38 | const { data: sessionData, error: sessionError } = await supabase.auth.getSession(); 39 | if (!sessionError){ 40 | const { data: latestAvailability, error: availabilityLoadError } = await supabase 41 | .from('devices') 42 | .select('status') 43 | .eq('user_id', sessionData.session?.user.id) 44 | .limit(1); 45 | 46 | if (!availabilityLoadError){ 47 | return latestAvailability[0].status 48 | } 49 | } 50 | } 51 | 52 | export async function setDeviceAvailability ( deviceAvailability:string ): Promise{ 53 | // returns success boolean 54 | const session = await AsyncStorage.getItem('user_session'); 55 | if (session){ 56 | const parsedSession = JSON.parse(session); 57 | const { data, error } = await supabase 58 | .from('devices') 59 | .update({ status: deviceAvailability, last_updated: new Date() }) 60 | .eq('user_id', parsedSession?.user.id); 61 | if (error){ 62 | error; 63 | } 64 | return true; 65 | } 66 | return false; 67 | } 68 | 69 | export const Heartbeat = ({ deviceActive }: { deviceActive: boolean }) => { 70 | useEffect(() => { 71 | let intervalId: NodeJS.Timeout | null = null; 72 | 73 | if (deviceActive) { 74 | setDeviceAvailability('available'); // set availability immediately 75 | intervalId = setInterval(() => { // ..and set a recurring heartbeat each 45sec 76 | setDeviceAvailability('available'); 77 | }, 45000); 78 | } 79 | 80 | return () => { 81 | if (intervalId) { 82 | clearInterval(intervalId); 83 | } 84 | }; 85 | }, [deviceActive]); 86 | 87 | return null; // we essentially treat this as a component but render nothing 88 | }; 89 | 90 | export const registerOrRetrieveDeviceFromSupabase = async () => { 91 | // Gets called on successful login. 
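// (For reference, an illustrative `devices` row this works against; values are
// placeholders, the columns follow the SQL schema in the README:
//   { id: 7, user_id: '<auth.users uuid>', status: 'available', last_updated: '2025-01-01T00:00:00Z' }
// The `id` we persist via setCurrentDeviceID() is the device_id that Sockets.ts
// later uses to filter realtime task_requests events.)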
92 | // The goal is to either register the device for the user (if they're logging in for the first time) or 93 | // retrieve device ID and set it to local storage 94 | 95 | // one BIG assumption we're making is that each user can only have one device ID 96 | 97 | try { 98 | // Step 1: Get the user session 99 | const { data: sessionData, error: sessionError } = await supabase.auth.getSession(); 100 | if (sessionError) { 101 | throw sessionError; 102 | } 103 | const user_id: string | undefined = sessionData.session?.user.id; 104 | if (!user_id) { 105 | throw new Error('User session is invalid or user ID is missing.'); 106 | } 107 | 108 | // Step 2: Check whether this user already has a registered device 109 | async function fetchDeviceId(): Promise { 110 | const { data: deviceData, error: tableLoadError } = await supabase 111 | .from('devices') 112 | .select('id') 113 | .eq('user_id', user_id) 114 | .limit(1); 115 | 116 | if (tableLoadError) { 117 | throw tableLoadError; 118 | } 119 | return deviceData[0]?.id 120 | } 121 | 122 | // Step 3. either set device ID or register this device if there's no device record for this user yet 123 | const deviceId = await fetchDeviceId() 124 | if (deviceId) { 125 | console.log(`Device id loaded: ${deviceId}, saving to async storage`) 126 | await setCurrentDeviceID(deviceId) 127 | } else { 128 | console.log(`Registering new device ID`) 129 | await insertRow('devices', { 130 | user_id: user_id, 131 | status: 'available', 132 | last_updated: new Date(), 133 | } 134 | ) 135 | const deviceId = await fetchDeviceId(); 136 | await setCurrentDeviceID(deviceId); 137 | } 138 | } catch (err) { 139 | console.error('Error registering device with Supabase:', err); 140 | } 141 | } -------------------------------------------------------------------------------- /app/src/components/Home.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from "react" 2 | import theme from "../styles/theme"; 3 | 4 | import { Image } from 'react-native'; 5 | import { View } from "react-native"; 6 | import { Text } from "react-native"; 7 | import { ActivityIndicator } from "react-native-paper"; 8 | 9 | import { MainContent } from "../styles/Style"; 10 | import TrainToggle from "./primary/TrainToggle"; 11 | import CustomButton from "./primary/Button"; 12 | import mflLogo from "../assets/images/logo_light_grey.png" 13 | 14 | import { joinNetwork, leaveNetwork } from "../communications/Sockets"; 15 | import { handleLogout } from "./Outh"; 16 | import { fetchDeviceAvailability, setDeviceAvailability, Heartbeat } from "../communications/Supabase"; 17 | 18 | 19 | const HomePage: React.FC = () => { 20 | const [loadingAvailability, setLoadingAvailability] = useState(true); 21 | const [deviceActive, setDeviceActive] = useState(false); 22 | 23 | const toggleDeviceStatus = () => { 24 | const newAvailability = deviceActive ? 
'unavailable' : 'available' 25 | setDeviceAvailability(newAvailability).then((success) => { 26 | if (success){ 27 | setDeviceActive(newAvailability === 'available') 28 | switch (newAvailability) { 29 | case 'available': 30 | joinNetwork(); 31 | case 'unavailable': 32 | leaveNetwork(); 33 | } 34 | } 35 | else{console.log('unable to update device availability!')} 36 | }) 37 | }; 38 | 39 | useEffect(() => { 40 | fetchDeviceAvailability().then((availability) => { 41 | setLoadingAvailability(false) 42 | setDeviceActive(availability === 'available') 43 | if (availability === 'available'){joinNetwork()} 44 | }) 45 | }, []) 46 | 47 | return ( 48 | 49 | 50 | 51 | 52 | {loadingAvailability ? : } 53 | 54 | 55 | Sign out 56 | 57 | ) 58 | } 59 | export default HomePage -------------------------------------------------------------------------------- /app/src/components/Login.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | import { handleLogin } from './Outh'; 3 | import SectionBreak from './primary/SectionBreak'; 4 | import Typography from './primary/Typography'; 5 | import { Image } from 'react-native'; 6 | import mflLogo from "../assets/images/logo_light_grey.png" 7 | import StyledTextInput from './primary/TextField'; 8 | import CustomButton from './primary/Button'; 9 | 10 | import theme from '../styles/theme'; 11 | import styled from 'styled-components/native'; 12 | 13 | const LoginScreen = ({setLoginInProgress}) => { 14 | const [email, setEmail] = useState(''); 15 | const [password, setPassword] = useState(''); 16 | 17 | const loadingAuth = false; // TODO: update for actual loading auth 18 | 19 | const onLoginPress = async () => { 20 | try { 21 | await handleLogin(email, password); 22 | setLoginInProgress(false) 23 | } catch (error: any) { 24 | console.log(error.message) 25 | } 26 | }; 27 | 28 | useEffect(() => { 29 | setLoginInProgress(true) 30 | }, []) 31 | 32 | return ( 33 | 34 | 35 | Sign in 36 | 43 | 49 | Sign in 50 | or 51 | 52 | ); 53 | }; 54 | 55 | const StyledAuthView = styled.View` 56 | flex: 1; 57 | justify-content: top; 58 | align-items: center; 59 | padding: 40px 30px; /* Top and Bottom 40px; Left and Right = 20px */ 60 | background-color: ${theme.colors.background}; 61 | `; 62 | 63 | export default LoginScreen; 64 | -------------------------------------------------------------------------------- /app/src/components/Outh.ts: -------------------------------------------------------------------------------- 1 | import { supabase } from '../communications/Supabase'; 2 | import { Session } from '@supabase/supabase-js'; 3 | import { registerOrRetrieveDeviceFromSupabase } from '../communications/Supabase'; 4 | import AsyncStorage from '@react-native-async-storage/async-storage'; 5 | 6 | export const checkSession = async (): Promise => { 7 | const { data } = await supabase.auth.getSession(); 8 | return !!data.session; 9 | }; 10 | 11 | export const saveSession = async (session: object) => { 12 | // this function saves the user session to an encrypted async storage for later retrieval. 
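// (What gets stored is the Supabase Session object, roughly:
//   { access_token: '...', refresh_token: '...', expires_at: 1234567890, user: { id: '...', email: '...' } }
// with placeholder values shown here; restoreSession() below parses it and hands it
// back to supabase.auth.setSession() to rehydrate the client.)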
13 | // this is implemented to avoid having the user need to log in each time the app is quit & re-opened 14 | try { 15 | await AsyncStorage.setItem('user_session', JSON.stringify(session)); 16 | } catch (error) { 17 | console.error('Error saving session:', error); 18 | } 19 | }; 20 | 21 | export const restoreSession = async () => { 22 | // this function retrieves the user session from an encrypted async storage (see saveSession method) 23 | // this is implemented to avoid having the user need to log in each time the app is quit & re-opened 24 | try { 25 | const session = await AsyncStorage.getItem('user_session'); 26 | if (session) { 27 | const parsedSession = JSON.parse(session); 28 | await supabase.auth.setSession(parsedSession); 29 | return true; 30 | } 31 | return false; 32 | } catch (error) { 33 | console.error('Error restoring session:', error); 34 | return false; 35 | } 36 | }; 37 | 38 | export const clearSession = async () => { 39 | // this function clears the user session from an encrypted async storage 40 | // this should be called each time the user is logged out 41 | try { 42 | await AsyncStorage.removeItem('user_session'); 43 | } catch (error) { 44 | console.error('Error clearing session:', error); 45 | } 46 | }; 47 | 48 | export const onAuthStateChange = ( 49 | callback: (isLoggedIn: boolean, session: Session | null) => void 50 | ) => { 51 | const { data: authListener } = supabase.auth.onAuthStateChange( 52 | (_event, session) => { 53 | callback(!!session, session); 54 | } 55 | ); 56 | 57 | return () => { 58 | authListener.subscription.unsubscribe(); 59 | }; 60 | }; 61 | 62 | export const handleLogout = async () => { 63 | try { 64 | await supabase.auth.signOut(); 65 | console.log('User logged out successfully.'); 66 | } catch (error) { 67 | console.error('Error logging out:', error); 68 | } 69 | }; 70 | 71 | export const handleLogin = async ( 72 | email: string, 73 | password: string 74 | ): Promise => { 75 | try { 76 | const { error } = await supabase.auth.signInWithPassword({ 77 | email, 78 | password, 79 | }); 80 | if (error) { 81 | throw error; 82 | } else { 83 | // we should await, otherwise realtime subscription is faster than the device ID fetch and we try to listen to an undefied device ID 84 | await registerOrRetrieveDeviceFromSupabase() 85 | } 86 | } catch (error: any) { 87 | throw new Error(error.message || 'Login failed'); 88 | } 89 | }; 90 | -------------------------------------------------------------------------------- /app/src/components/Signup.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | import { handleSignup } from './Outh'; 3 | import Typography from './primary/Typography'; 4 | import { Image, View, TouchableOpacity, Platform } from 'react-native'; 5 | import mflLogo from "../assets/images/logo_light_grey.png" 6 | import StyledTextInput from './primary/TextField'; 7 | import CustomButton from './primary/Button'; 8 | 9 | import theme from '../styles/theme'; 10 | import styled from 'styled-components/native'; 11 | 12 | const SignupScreen = ({ setLoginInProgress, switchToLogin }) => { 13 | const [ email, setEmail ] = useState(''); 14 | const [ errorMessage, setErrorMessage ] = useState('') 15 | const [ password, setPassword ] = useState(''); 16 | const [ passwordConfirmation, setPasswordConfirmation ] = useState(''); 17 | const [ buttonLoading, setButtonLoading ] = useState(false); 18 | 19 | const loadingAuth = false; // TODO: update for actual loading auth 20 
| 21 | const onSignupPress = async () => { 22 | setButtonLoading(true) 23 | if (password !== passwordConfirmation){setErrorMessage('Passwords must match')} 24 | try { 25 | const error = await handleSignup(email, password) 26 | if (error){ 27 | setErrorMessage(error) 28 | }else{ 29 | setLoginInProgress(false) 30 | } 31 | } catch (error: any) { 32 | console.log(error.message) 33 | } 34 | setButtonLoading(false) 35 | }; 36 | 37 | useEffect(() => { 38 | setLoginInProgress(true) 39 | }, []) 40 | 41 | return ( 42 | 43 | 44 | Sign up 45 | {setEmail(value); setErrorMessage('')}} 49 | keyboardType="email-address" 50 | autoCapitalize="none" 51 | /> 52 | {setPassword(value); setErrorMessage('')}} 57 | /> 58 | {setPasswordConfirmation(value); setErrorMessage('')}} 63 | /> 64 | {errorMessage ? 65 | {errorMessage} 66 | : null} 67 | Sign up 68 | 69 | 70 | Already have an account? 71 | 72 | 73 | Click here to log in 74 | 75 | 76 | 77 | 78 | ); 79 | }; 80 | 81 | const StyledAuthView = styled.KeyboardAvoidingView` 82 | flex: 1; 83 | justify-content: top; 84 | align-items: center; 85 | padding: 40px 30px; /* Top and Bottom 40px; Left and Right = 20px */ 86 | background-color: ${theme.colors.background}; 87 | `; 88 | 89 | export default SignupScreen; -------------------------------------------------------------------------------- /app/src/components/auth/PlatformAuth.native.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import { Platform } from 'react-native' 3 | import SocialAuthButton from '../primary/SocialAuthButton' 4 | import * as AppleAuthentication from 'expo-apple-authentication' 5 | import { registerOrRetrieveDeviceFromSupabase, supabase } from 'src/communications/Supabase' 6 | 7 | export const PlatformAuth: React.FC = () => { 8 | if (Platform.OS === 'ios') { 9 | return ( 10 | { 15 | try { 16 | const credential = await AppleAuthentication.signInAsync({ 17 | requestedScopes: [ 18 | AppleAuthentication.AppleAuthenticationScope.FULL_NAME, 19 | AppleAuthentication.AppleAuthenticationScope.EMAIL, 20 | ], 21 | }) 22 | // Sign in via Supabase Auth. 23 | if (credential.identityToken) { 24 | const { 25 | error, 26 | data: { user }, 27 | } = await supabase.auth.signInWithIdToken({ 28 | provider: 'apple', 29 | token: credential.identityToken, 30 | }) 31 | // if (!error) {} 32 | if (error){console.log(JSON.stringify({ error, user }, null, 2))} 33 | } else { 34 | throw new Error('No identityToken.') 35 | } 36 | } catch (e) { 37 | if (e.code === 'ERR_REQUEST_CANCELED') { 38 | // handle that the user canceled the sign-in flow 39 | } else { 40 | // handle other errors 41 | } 42 | } 43 | }} 44 | /> 45 | ) 46 | }else{ 47 | return null //add android? maybe 48 | } 49 | } 50 | 51 | {/* { 57 | try { 58 | const credential = await AppleAuthentication.signInAsync({ 59 | requestedScopes: [ 60 | AppleAuthentication.AppleAuthenticationScope.FULL_NAME, 61 | AppleAuthentication.AppleAuthenticationScope.EMAIL, 62 | ], 63 | }) 64 | // Sign in via Supabase Auth. 65 | if (credential.identityToken) { 66 | const { 67 | error, 68 | data: { user }, 69 | } = await supabase.auth.signInWithIdToken({ 70 | provider: 'apple', 71 | token: credential.identityToken, 72 | }) 73 | console.log(JSON.stringify({ error, user }, null, 2)) 74 | if (!error) { 75 | // User is signed in. 
76 | } 77 | } else { 78 | throw new Error('No identityToken.') 79 | } 80 | } catch (e) { 81 | if (e.code === 'ERR_REQUEST_CANCELED') { 82 | // handle that the user canceled the sign-in flow 83 | } else { 84 | // handle other errors 85 | } 86 | } 87 | }} 88 | /> */} -------------------------------------------------------------------------------- /app/src/components/primary/Button.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import Typography from './Typography'; 3 | import theme from '../../styles/theme'; 4 | import { ActivityIndicator } from 'react-native'; 5 | import { Pressable, PressableProps } from 'react-native'; // Pressable is a lower-level button component and supports more customization 6 | 7 | interface CustomButtonProps extends PressableProps { 8 | customVariant?: 'primary' | 'secondary'; // Custom variant for styling 9 | loading?: boolean; 10 | } 11 | 12 | const CustomButton: React.FC = ({ children, customVariant='primary', loading=false, ...props }) => { 13 | 14 | const buttonColor = customVariant === 'primary' ? theme.colors.primary : theme.colors.surface 15 | const fontColor = customVariant === 'primary' ? theme.colors.background : theme.colors.textDefault 16 | 17 | const buttonShadow = customVariant === 'primary' 18 | ? { shadowColor: theme.colors.primary, shadowOffset: { width: 0, height: 0 }, shadowOpacity: 1, shadowRadius: 5, } 19 | : {}; 20 | 21 | return ( 22 | [ 26 | { 27 | backgroundColor: buttonColor, 28 | paddingLeft: 40, 29 | paddingRight: 40, 30 | paddingTop: 20, 31 | paddingBottom: 20, 32 | borderRadius: 50, 33 | borderWidth: customVariant === 'primary' ? 0 : 1, 34 | borderColor: customVariant === 'primary' ? null : theme.colors.textSecondary, 35 | alignItems: 'center', 36 | }, 37 | pressed && buttonShadow, 38 | ]} 39 | > 40 | { loading ? 
: {children} } 41 | 42 | ); 43 | }; 44 | 45 | export default CustomButton; -------------------------------------------------------------------------------- /app/src/components/primary/SectionBreak.tsx: -------------------------------------------------------------------------------- 1 | // src/components/SectionBreak.tsx 2 | import React from 'react'; 3 | import theme from '../../styles/theme'; 4 | import Typography from './Typography'; 5 | import styled from 'styled-components/native'; 6 | 7 | interface SectionBreakProps { 8 | children: React.ReactNode; // Children (basically plain text) to be displayed as the label 9 | } 10 | 11 | const DividerContainer = styled.View` 12 | width: 100%; 13 | flex-direction: row; 14 | align-items: center; 15 | `; 16 | 17 | const HorizontalLine = styled.View` 18 | flex: 1; 19 | height: 1px; 20 | background-color: ${theme.colors.primary}; 21 | opacity: 0.3; 22 | `; 23 | 24 | const SectionBreakContainer = styled.View` 25 | flex-direction: row; 26 | align-items: center; 27 | justify-content: center; 28 | margin: 10px; /* Add margin around the text block */ 29 | `; 30 | 31 | const SectionBreak: React.FC = ({ children }) => { 32 | return ( 33 | 34 | 35 | 36 | 37 | {children} 38 | 39 | 40 | 41 | 42 | ); 43 | }; 44 | 45 | export default SectionBreak; -------------------------------------------------------------------------------- /app/src/components/primary/SocialAuthButton.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { AntDesign } from '@expo/vector-icons'; 3 | import theme from '../../styles/theme'; 4 | import { ActivityIndicator } from 'react-native'; 5 | import { Pressable, PressableProps } from 'react-native'; // Pressable is a lower-level button component and supports more customization 6 | 7 | interface SocialAuthButtonProps extends PressableProps { 8 | customVariant?: 'primary' | 'secondary'; // Custom variant for styling 9 | provider: 'apple' | 'google' | 'android' | 'facebook'; 10 | loading?: boolean; 11 | } 12 | const getIconName = (provider: string) => { 13 | switch (provider) { 14 | case 'apple': 15 | return 'apple1'; 16 | case 'android': 17 | return 'android1'; 18 | case 'google': 19 | return 'google'; 20 | case 'facebook': 21 | return 'facebook-square'; 22 | default: 23 | return 'questioncircle'; 24 | } 25 | }; 26 | 27 | const SocialAuthButton: React.FC = ({ provider, customVariant='primary', loading=false, ...props }) => { 28 | 29 | const buttonColor = customVariant === 'primary' ? theme.colors.primary : theme.colors.surface 30 | const iconColor = customVariant === 'primary' ? theme.colors.background : theme.colors.textDefault 31 | const iconName = getIconName(provider) 32 | 33 | const buttonShadow = customVariant === 'primary' 34 | ? { shadowColor: theme.colors.primary, shadowOffset: { width: 0, height: 0 }, shadowOpacity: 1, shadowRadius: 5, } 35 | : {}; 36 | 37 | return ( 38 | [ 42 | { 43 | backgroundColor: buttonColor, 44 | padding: 20, 45 | borderRadius: 50, 46 | borderWidth: customVariant === 'primary' ? 0 : 1, 47 | borderColor: customVariant === 'primary' ? null : theme.colors.textSecondary, 48 | alignItems: 'center', 49 | }, 50 | pressed && buttonShadow, 51 | ]} 52 | > 53 | { loading ? 
( 54 | 55 | ) : ( 56 | 57 | )} 58 | 59 | ); 60 | }; 61 | 62 | export default SocialAuthButton; -------------------------------------------------------------------------------- /app/src/components/primary/SocialAuthButtons.tsx: -------------------------------------------------------------------------------- 1 | import React, { Dispatch, SetStateAction } from "react" 2 | import CustomButton from "../primary/Button" 3 | import GitHubIcon from '@mui/icons-material/GitHub'; // Import GitHub icon 4 | import { supabase } from "services/supabaseClient"; 5 | 6 | interface SocialButtonGroupProps { 7 | setLoading: Dispatch>; 8 | setError: Dispatch>; 9 | } 10 | 11 | const SocialButtonGroup: React.FC = ({ setLoading, setError }) => { 12 | const handleSocialLogin = async (provider: 'github') => { 13 | setLoading(true); 14 | setError(null); 15 | 16 | const { error } = await supabase.auth.signInWithOAuth({ 17 | provider, 18 | options: { 19 | redirectTo: `${window.location.origin}/dashboard`, // Adjust as needed 20 | }, 21 | }); 22 | 23 | if (error) { 24 | setError(error.message); 25 | setLoading(false); 26 | } 27 | }; 28 | 29 | return ( 30 | handleSocialLogin('github')} 33 | icon={} 34 | > 35 | Continue with GitHub 36 | 37 | ) 38 | } 39 | 40 | export default SocialButtonGroup -------------------------------------------------------------------------------- /app/src/components/primary/TextField.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState } from 'react'; 2 | import styled from 'styled-components/native'; 3 | import theme from '../../styles/theme'; 4 | 5 | // Styled Container 6 | const StyledContainer = styled.View` 7 | width: 100%; 8 | height: auto; 9 | margin: ${theme.spacing * 2}px 0; 10 | `; 11 | 12 | // Styled Input 13 | const StyledInput = styled.TextInput<{ isFocused: boolean }>` 14 | width: 100%; 15 | border-color: ${({ isFocused }) => 16 | isFocused ? theme.colors.primary : theme.colors.textDefault}; 17 | border-width: ${({ isFocused }) => (isFocused ? '1px' : '0.5px')}; 18 | border-radius: ${theme.roundness}px; 19 | padding: 25px; 20 | color: ${theme.colors.textDefault}; 21 | background-color: ${theme.colors.surface}; 22 | text-transform: none; 23 | `; 24 | 25 | const StyledTextInput = ({ value, onChangeText, placeholder='placeholder', secureTextEntry=false }: any) => { 26 | const [isFocused, setIsFocused] = useState(false); 27 | 28 | return ( 29 | 30 | setIsFocused(true)} 38 | onBlur={() => setIsFocused(false)} 39 | autoCapitalize="none" 40 | /> 41 | 42 | ); 43 | }; 44 | 45 | export default StyledTextInput; -------------------------------------------------------------------------------- /app/src/components/primary/TrainToggle.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | import { View, Text, Switch } from "react-native" 3 | import { useTheme } from "styled-components" 4 | 5 | interface TrainToggleProps { 6 | isTraining: boolean; 7 | handleTrainToggle: () => void; 8 | } 9 | 10 | const TrainToggle: React.FC = ({isTraining, handleTrainToggle }) => { 11 | const theme = useTheme(); 12 | return ( 13 | 14 | 30 | 38 | {isTraining ? 'Connected to the grid!' 
: 'Offline'} 39 | 40 | 41 | ) 42 | } 43 | 44 | export default TrainToggle -------------------------------------------------------------------------------- /app/src/components/primary/Typography.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { Text } from "react-native"; 3 | 4 | import theme from '../../styles/theme'; 5 | 6 | type TypographyVariant = 'h1' | 'h2' | 'h3' | 'h4' | 'h5' | 'h6' | 'body1' | 'body2'; 7 | 8 | interface TypographyProps { 9 | variant: TypographyVariant; // h1, h2, body1, etc 10 | style?: object; // lets us pass in typography customizations if needed 11 | children: React.ReactNode; 12 | } 13 | 14 | const Typography: React.FC = ({ variant='body1', style, children }) => { 15 | const textStyles = [ 16 | {...theme.typography[variant], color: theme.colors.textDefault}, 17 | style 18 | ]; 19 | return {children}; 20 | }; 21 | 22 | export default Typography -------------------------------------------------------------------------------- /app/src/ml/Config.ts: -------------------------------------------------------------------------------- 1 | // Config.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | import '@tensorflow/tfjs-react-native'; 5 | 6 | export interface ReceiveConfig { 7 | modelJson: JSON; 8 | weights: number[][][]; 9 | batchSize: number; 10 | inputs: number[][][]; 11 | inputShape: number[][][]; 12 | outputs?: number[][][]; 13 | outputShape?: number[][][]; 14 | epochs?: number; 15 | datasetsPerDevice?: number; 16 | } 17 | 18 | export interface SendConfig { 19 | weights: number[][][]; 20 | outputs?: number[][][]; 21 | loss: number; 22 | } 23 | 24 | export async function processSendConfig( 25 | model: tf.LayersModel, 26 | loss: number, 27 | modelOutputs?: tf.Tensor[] 28 | ): Promise<{ 29 | weights: number[][][]; // Changed to support up to 3D arrays 30 | outputs?: number[][]; 31 | loss: number; 32 | }> { 33 | try { 34 | const weights = model.getWeights(); 35 | const weightData: number[][][] = []; // Changed to 3D array 36 | const outputData: number[][] = []; 37 | 38 | // Process weights 39 | for (const tensor of weights) { 40 | try { 41 | const data = await tensor.array(); // Use array() instead of data() 42 | weightData.push(data); 43 | } catch (tensorError) { 44 | console.error('Error processing weight tensor:', tensorError); 45 | throw tensorError; 46 | } finally { 47 | tensor.dispose(); 48 | } 49 | } 50 | 51 | // Initialize sendConfig with weights 52 | const sendConfig: { 53 | weights: number[][][]; // Changed to 3D array 54 | outputs?: number[][]; 55 | loss: number; 56 | } = { 57 | weights: weightData, 58 | loss: loss 59 | }; 60 | 61 | // Process outputs only if modelOutputs is provided 62 | if (modelOutputs && modelOutputs.length > 0) { 63 | for (const tensor of modelOutputs) { 64 | try { 65 | const data = await tensor.array(); // Use array() instead of data() 66 | outputData.push(data); 67 | } catch (tensorError) { 68 | console.error('Error processing output tensor:', tensorError); 69 | throw tensorError; 70 | } finally { 71 | tensor.dispose(); 72 | } 73 | } 74 | 75 | // Add outputs to sendConfig 76 | sendConfig.outputs = outputData; 77 | } 78 | 79 | return sendConfig; 80 | } catch (error) { 81 | console.error('Error in processSendConfig:', error); 82 | throw new Error(`Failed to process send config: ${error}`); 83 | } 84 | } -------------------------------------------------------------------------------- /app/src/ml/Diagnostics.ts: 
-------------------------------------------------------------------------------- 1 | // Diagnostics.ts 2 | 3 | import { initializeTf } from './TensorflowHandler'; 4 | 5 | export async function runDiagnositics() { 6 | try { 7 | await initializeTf(); 8 | return true; 9 | } catch (error) { 10 | return false; 11 | throw error; 12 | } 13 | } -------------------------------------------------------------------------------- /app/src/ml/Evaluation.ts: -------------------------------------------------------------------------------- 1 | // Evaluator.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | import { loadModel } from './ModelHandler'; 5 | import { ReceiveConfig, processSendConfig, SendConfig } from './Config'; 6 | import { initializeTf } from './TensorflowHandler'; 7 | 8 | export const runEvaluation = async ( 9 | receiveConfig: ReceiveConfig, 10 | ): Promise => { 11 | 12 | await initializeTf(); 13 | const model = await loadModel(receiveConfig); 14 | const inputTensor = tf.tensor2d(receiveConfig.inputs, receiveConfig.inputShape); 15 | const outputTensor = tf.tensor2d(receiveConfig.outputs, receiveConfig.outputShape); 16 | 17 | try { 18 | if (!model.loss) { 19 | throw new Error('Please ensure the model is loaded and compiled correctly.'); 20 | } 21 | 22 | const batchSize = receiveConfig.batchSize; 23 | const numSamples = inputTensor.shape[0]; 24 | const numBatches = Math.ceil(numSamples / batchSize); 25 | let totalLoss = 0; 26 | 27 | // Evaluate in batches to manage memory 28 | for (let batch = 0; batch < numBatches; batch++) { 29 | const start = batch * batchSize; 30 | const end = Math.min(start + batchSize, numSamples); 31 | 32 | const batchInputs = inputTensor.slice([start, 0], [end - start, -1]); 33 | const batchOutputs = outputTensor.slice([start, 0], [end - start, -1]); 34 | 35 | // Calculate loss for this batch 36 | const lossValue = tf.tidy(() => { 37 | const preds = model.predict(batchInputs) as tf.Tensor; 38 | const lossVal = model.loss(batchOutputs, preds); 39 | return lossVal; 40 | }); 41 | 42 | totalLoss += lossValue.dataSync()[0] * (end - start); 43 | lossValue.dispose(); 44 | 45 | batchInputs.dispose(); 46 | batchOutputs.dispose(); 47 | 48 | } 49 | 50 | // Calculate average loss 51 | const averageLoss = totalLoss / numSamples; 52 | console.log(`Evaluation Loss: ${averageLoss.toFixed(4)}`); 53 | const sendConfig = await processSendConfig(model, averageLoss); 54 | 55 | return sendConfig; 56 | } catch (error) { 57 | console.error('Error during evaluation:', error); 58 | throw error; 59 | } 60 | }; -------------------------------------------------------------------------------- /app/src/ml/Losses.ts: -------------------------------------------------------------------------------- 1 | // Losses.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | 5 | export function createLossFunction(modelJson) { 6 | const loss = modelJson.modelTopology?.training_config?.loss; 7 | 8 | if (!loss) { 9 | throw new Error("Loss function not found in the model JSON."); 10 | } 11 | 12 | console.log(`Loss function: ${loss}`); 13 | 14 | const lossMapping = { 15 | 'mean_squared_error': tf.losses.meanSquaredError, 16 | 'mse': tf.losses.meanSquaredError, 17 | 'mean_absolute_error': tf.losses.absoluteDifference, 18 | 'mae': tf.losses.absoluteDifference, 19 | 'categorical_crossentropy': tf.losses.softmaxCrossEntropy, 20 | 'binary_crossentropy': tf.losses.sigmoidCrossEntropy, 21 | 'sparse_categorical_crossentropy': tf.losses.softmaxCrossEntropy, 22 | 'hinge': tf.losses.hingeLoss, 23 | 'huber_loss': 
tf.losses.huberLoss, 24 | 'kl_divergence': tf.losses.kullbackLeiblerDivergence, 25 | 'cosine_similarity': tf.losses.cosineDistance, 26 | }; 27 | 28 | const tfjsLoss = lossMapping[loss.toLowerCase()]; 29 | 30 | if (!tfjsLoss) { 31 | throw new Error(`Unsupported loss function: ${loss}`); 32 | } 33 | 34 | switch (loss.toLowerCase()) { 35 | case 'categorical_crossentropy': 36 | case 'binary_crossentropy': 37 | return (yTrue, yPred) => tfjsLoss(yTrue, yPred, {from_logits: false}); 38 | case 'sparse_categorical_crossentropy': 39 | return (yTrue, yPred) => tfjsLoss(yTrue, yPred, {from_logits: false, axis: -1}); 40 | case 'huber_loss': 41 | return (yTrue, yPred) => tfjsLoss(yTrue, yPred, 1.0); 42 | case 'cosine_similarity': 43 | return (yTrue, yPred) => tfjsLoss(yTrue, yPred, -1); 44 | default: 45 | return tfjsLoss; 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /app/src/ml/ModelHandler.ts: -------------------------------------------------------------------------------- 1 | // ModelHandler.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | import { createLossFunction } from './Losses'; 5 | import { createOptimizer } from './Optimizers'; 6 | import { ReceiveConfig } from './Config'; 7 | 8 | export const loadModel = async (receiveConfig: ReceiveConfig): Promise => { 9 | try { 10 | const customIOHandler = { 11 | load: async () => { 12 | return { 13 | modelTopology: receiveConfig.modelJson.modelTopology, 14 | format: receiveConfig.modelJson.format || 'layers-model', 15 | generatedBy: receiveConfig.modelJson.generatedBy, 16 | convertedBy: receiveConfig.modelJson.convertedBy, 17 | // Intentionally omit weightsManifest to prevent automatic weight loading 18 | }; 19 | } 20 | }; 21 | 22 | // Load the model architecture using the custom IOHandler 23 | const loadedModel = await tf.loadLayersModel(customIOHandler); 24 | 25 | // Load the weights from the received config 26 | const weightTensors = receiveConfig.weights.map(data => tf.tensor(data)); 27 | loadedModel.setWeights(weightTensors); 28 | weightTensors.forEach(tensor => tensor.dispose()); 29 | 30 | const optimizer = createOptimizer(receiveConfig.modelJson); 31 | const lossFunction = createLossFunction(receiveConfig.modelJson); 32 | 33 | if (optimizer && lossFunction) { 34 | loadedModel.compile({ 35 | optimizer: optimizer, 36 | loss: lossFunction, 37 | }); 38 | console.log('Model compiled with extracted optimizer and loss function.'); 39 | } else { 40 | console.warn('Optimizer or loss function information not found in model'); 41 | } 42 | 43 | console.log('Model loaded from JSON and compiled.'); 44 | return loadedModel; 45 | 46 | } catch (error) { 47 | console.error('Error loading the model:', error); 48 | throw error; 49 | } 50 | }; 51 | -------------------------------------------------------------------------------- /app/src/ml/Optimizers.ts: -------------------------------------------------------------------------------- 1 | // Optimizers.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | 5 | export function createOptimizer(modelJson) { 6 | const optimizerConfig = modelJson.modelTopology?.training_config?.optimizer_config; 7 | 8 | if (!optimizerConfig) { 9 | throw new Error("Optimizer configuration not found in the model JSON."); 10 | } 11 | 12 | const className = optimizerConfig.class_name; 13 | const config = optimizerConfig.config; 14 | let optimizer; 15 | 16 | switch (className.toLowerCase()) { 17 | case 'sgd': 18 | const sgdLearningRate = config.learning_rate !== undefined ? 
config.learning_rate : 0.01; 19 | const sgdMomentum = config.momentum !== undefined ? config.momentum : 0.9; // Added momentum 20 | optimizer = tf.train.sgd(sgdLearningRate, sgdMomentum); 21 | break; 22 | 23 | case 'adam': 24 | const adamLearningRate = config.learning_rate !== undefined ? config.learning_rate : 0.001; // Reduced learning rate 25 | const adamBeta1 = config.beta_1 !== undefined ? config.beta_1 : 0.9; 26 | const adamBeta2 = config.beta_2 !== undefined ? config.beta_2 : 0.999; 27 | const adamEpsilon = config.epsilon !== undefined ? config.epsilon : 1e-8; // Adjusted epsilon 28 | optimizer = tf.train.adam(adamLearningRate, adamBeta1, adamBeta2, adamEpsilon); 29 | break; 30 | 31 | case 'rmsprop': 32 | const rmsLearningRate = config.learning_rate !== undefined ? config.learning_rate : 0.01; 33 | const rmsRho = config.rho !== undefined ? config.rho : 0.9; 34 | const rmsMomentum = config.momentum !== undefined ? config.momentum : 0.0; 35 | const rmsEpsilon = config.epsilon !== undefined ? config.epsilon : 1e-7; 36 | const rmsCentered = config.centered !== undefined ? config.centered : false; 37 | 38 | optimizer = tf.train.rmsprop(rmsLearningRate, rmsRho, rmsMomentum, rmsEpsilon, rmsCentered); 39 | break; 40 | 41 | default: 42 | throw new Error(`Unsupported optimizer class name: ${className}`); 43 | } 44 | 45 | return optimizer; 46 | } 47 | -------------------------------------------------------------------------------- /app/src/ml/Prediction.ts: -------------------------------------------------------------------------------- 1 | // Prediction.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | import { loadModel } from './ModelHandler'; 5 | import { ReceiveConfig, processSendConfig, SendConfig } from './Config'; 6 | import { initializeTf } from './TensorflowHandler'; 7 | 8 | export const runPrediction = async ( 9 | receiveConfig: ReceiveConfig, 10 | ): Promise => { 11 | 12 | await initializeTf(); 13 | const model = await loadModel(receiveConfig); 14 | 15 | const inputTensor = tf.tensor2d(receiveConfig.inputs, receiveConfig.inputShape); 16 | const allPredictions: tf.Tensor[] = []; 17 | 18 | try { 19 | if (typeof model.predict !== 'function') { 20 | throw new Error('The loaded model does not have a predict method.'); 21 | } 22 | 23 | const batchSize = receiveConfig.batchSize; 24 | const numSamples = inputTensor.shape[0]; 25 | const numBatches = Math.ceil(numSamples / batchSize); 26 | 27 | // Predict in batches to manage memory 28 | for (let batch = 0; batch < numBatches; batch++) { 29 | const start = batch * batchSize; 30 | const end = Math.min(start + batchSize, numSamples); 31 | 32 | const batchInputs = inputTensor.slice([start, 0], [end - start, -1]); 33 | 34 | // Make prediction on the current batch 35 | const batchPreds = model.predict(batchInputs) as tf.Tensor; 36 | allPredictions.push(batchPreds); // Keep the tensor for later processing 37 | batchInputs.dispose(); 38 | // Do NOT dispose of batchPreds here; it will be handled in processSendConfig 39 | } 40 | 41 | 42 | // Create SendConfig with model weights and predictions 43 | const sendConfig: SendConfig = await processSendConfig(model, 0, allPredictions); 44 | 45 | // Dispose of the input tensor and the model after processing 46 | inputTensor.dispose(); 47 | model.dispose(); 48 | 49 | console.log('Success', 'Model predictions processed and SendConfig created successfully.'); 50 | return sendConfig; 51 | } catch (error) { 52 | console.error('Error during prediction:', error); 53 | 54 | inputTensor.dispose(); 55 | 
model.dispose(); 56 | 57 | throw error; // Re-throw the error after logging 58 | } 59 | }; 60 | -------------------------------------------------------------------------------- /app/src/ml/TensorflowHandler.ts: -------------------------------------------------------------------------------- 1 | // tensorflowInitializer.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | 5 | export const initializeTf = async (): Promise => { 6 | try { 7 | await tf.ready(); 8 | console.log('TensorFlow.js is ready.'); 9 | } catch (error) { 10 | console.error('Error initializing TensorFlow.js:', error); 11 | throw error; 12 | } 13 | }; -------------------------------------------------------------------------------- /app/src/ml/Training.ts: -------------------------------------------------------------------------------- 1 | // Training.ts 2 | 3 | import * as tf from '@tensorflow/tfjs'; 4 | import { loadModel } from './ModelHandler'; 5 | import { ReceiveConfig, processSendConfig, SendConfig } from './Config'; 6 | import { initializeTf } from './TensorflowHandler'; 7 | 8 | export const runTraining = async ( 9 | receiveConfig: ReceiveConfig, 10 | ): Promise => { 11 | 12 | try { 13 | await initializeTf(); 14 | const model = await loadModel(receiveConfig); 15 | 16 | // Prepare input and output tensors 17 | const inputTensor = tf.tensor2d(receiveConfig.inputs, receiveConfig.inputShape); 18 | const outputTensor = tf.tensor2d(receiveConfig.outputs, receiveConfig.outputShape); 19 | 20 | if (!model.optimizer || !model.loss) { 21 | throw new Error('Model is not compiled. Please ensure the model is loaded and compiled correctly.'); 22 | } 23 | 24 | // Extract training configurations 25 | const effectiveBatchSize = receiveConfig.batchSize; 26 | const microBatchSize = 1; 27 | 28 | const accumulationSteps = effectiveBatchSize / microBatchSize; 29 | const numSamples = inputTensor.shape[0]; 30 | const numBatches = Math.ceil(numSamples / microBatchSize); 31 | 32 | let finalLoss = 0; 33 | 34 | // Initialize accumulated gradients as separate tensors 35 | const accumulatedGradients: tf.NamedTensorMap = {}; 36 | model.trainableWeights.forEach(weight => { 37 | // Clone the tensors to ensure they are separate from model variables 38 | accumulatedGradients[weight.name] = tf.zerosLike(weight.read()).clone(); 39 | }); 40 | 41 | // Calculate the number of iterations based on datasetsPerDevice 42 | const totalIterations = effectiveBatchSize; 43 | 44 | let iteration = 0; 45 | let epochLoss = 0; 46 | let step = 0; 47 | 48 | while (iteration < totalIterations) { 49 | for (let batch = 0; batch < numBatches && iteration < totalIterations; batch++, iteration++) { 50 | const start = batch * microBatchSize; 51 | const end = Math.min(start + microBatchSize, numSamples); 52 | 53 | // Slice the input and output tensors for the current batch 54 | const batchInputs = inputTensor.slice([start, 0], [end - start, -1]); 55 | const batchOutputs = outputTensor.slice([start, 0], [end - start, -1]); 56 | 57 | // Compute gradients and loss inside tf.tidy to manage temporary tensors 58 | const { lossValue, grads } = tf.tidy(() => { 59 | const lossFunction = () => { 60 | const preds = model.predict(batchInputs) as tf.Tensor; 61 | const loss = (model.loss as any)(batchOutputs, preds); 62 | preds.dispose(); // Dispose predictions to free memory 63 | return loss; 64 | }; 65 | // Compute gradients with respect to model variables 66 | const { value, grads } = tf.variableGrads(lossFunction); 67 | return { lossValue: value, grads }; 68 | }); 69 | 70 | // Accumulate 
loss 71 | const batchLoss = lossValue.dataSync()[0] * (end - start); 72 | epochLoss += batchLoss; 73 | lossValue.dispose(); // Dispose loss tensor 74 | 75 | // Accumulate gradients 76 | model.trainableWeights.forEach(weight => { 77 | const weightName = weight.name; 78 | if (grads[weightName]) { 79 | // Accumulate gradients by adding them to the existing accumulated gradients 80 | accumulatedGradients[weightName] = tf.add(accumulatedGradients[weightName], grads[weightName]); 81 | 82 | // Dispose the current batch gradient tensor to free memory 83 | grads[weightName].dispose(); 84 | } 85 | }); 86 | 87 | step++; 88 | 89 | // Apply gradients when accumulation steps are met 90 | if (step % accumulationSteps === 0) { 91 | // Compute averaged gradients 92 | const averagedGradients: tf.NamedTensorMap = {}; 93 | model.trainableWeights.forEach(weight => { 94 | const weightName = weight.name; 95 | averagedGradients[weightName] = tf.div(accumulatedGradients[weightName], accumulationSteps); 96 | }); 97 | 98 | // Apply the averaged gradients to update model weights 99 | (model.optimizer as tf.Optimizer).applyGradients(averagedGradients); 100 | 101 | // Dispose the averaged gradients tensors 102 | model.trainableWeights.forEach(weight => { 103 | const weightName = weight.name; 104 | averagedGradients[weightName].dispose(); 105 | }); 106 | 107 | // Reset accumulated gradients for the next accumulation cycle 108 | model.trainableWeights.forEach(weight => { 109 | const weightName = weight.name; 110 | accumulatedGradients[weightName].dispose(); // Dispose previous accumulation 111 | accumulatedGradients[weightName] = tf.zerosLike(weight.read()).clone(); 112 | }); 113 | } 114 | 115 | // Dispose batch tensors to free memory 116 | batchInputs.dispose(); 117 | batchOutputs.dispose(); 118 | } 119 | } 120 | 121 | // Apply remaining gradients if any (when total iterations are not perfectly divisible) 122 | if (step % accumulationSteps !== 0) { 123 | const remainingSteps = step % accumulationSteps; 124 | 125 | // Compute averaged gradients for the remaining steps 126 | const averagedGradients: tf.NamedTensorMap = {}; 127 | model.trainableWeights.forEach(weight => { 128 | const weightName = weight.name; 129 | averagedGradients[weightName] = tf.div(accumulatedGradients[weightName], remainingSteps); 130 | }); 131 | 132 | // Apply the averaged gradients 133 | (model.optimizer as tf.Optimizer).applyGradients(averagedGradients); 134 | 135 | // Dispose the averaged gradients tensors 136 | model.trainableWeights.forEach(weight => { 137 | const weightName = weight.name; 138 | averagedGradients[weightName].dispose(); 139 | }); 140 | 141 | // Reset accumulated gradients 142 | model.trainableWeights.forEach(weight => { 143 | const weightName = weight.name; 144 | accumulatedGradients[weightName].dispose(); // Dispose previous accumulation 145 | accumulatedGradients[weightName] = tf.zerosLike(weight.read()).clone(); 146 | }); 147 | } 148 | 149 | // Calculate average loss for the entire training 150 | const averageLoss = epochLoss / numSamples; 151 | console.log(`Training Iterations: ${totalIterations}, Loss = ${averageLoss.toFixed(4)}`); 152 | finalLoss = averageLoss; // Update final loss 153 | 154 | // After training, process SendConfig without model outputs 155 | const sendConfig: SendConfig = await processSendConfig(model, finalLoss); 156 | 157 | console.log('Success', 'Model trained and SendConfig created successfully.'); 158 | return sendConfig; 159 | 160 | } catch (error) { 161 | console.error('Error during training:', 
error); 162 | throw error; // Re-throw the error after logging 163 | } 164 | }; 165 | -------------------------------------------------------------------------------- /app/src/ml/index.ts: -------------------------------------------------------------------------------- 1 | // index.ts 2 | 3 | import { runTraining } from './Training'; 4 | import { runEvaluation } from './Evaluation'; 5 | import { runPrediction } from './Prediction'; 6 | import { runDiagnositics } from './Diagnostics'; 7 | 8 | export async function train(config: any) { 9 | return await runTraining(config); 10 | } 11 | 12 | export async function evaluate(config: any) { 13 | return await runEvaluation(config); 14 | } 15 | 16 | export async function predict(config: any) { 17 | return await runPrediction(config); 18 | } 19 | 20 | export async function isAvailable() { 21 | return await runDiagnositics(); 22 | } -------------------------------------------------------------------------------- /app/src/styles/Style.ts: -------------------------------------------------------------------------------- 1 | import { StyleSheet } from 'react-native'; 2 | import styled from 'styled-components/native'; 3 | import theme from './theme'; 4 | 5 | export const HomeStyles = StyleSheet.create({ 6 | container: { 7 | flex: 1, 8 | justifyContent: 'center', 9 | alignItems: 'center', 10 | backgroundColor: '#f5f5f5', 11 | }, 12 | text: { 13 | fontSize: 24, 14 | color: '#333', 15 | }, 16 | }); 17 | 18 | export const AuthStyles = StyleSheet.create({ 19 | container: { 20 | flex: 1, 21 | justifyContent: 'center', 22 | padding: 16, 23 | backgroundColor: '#f5f5f5', 24 | }, 25 | title: { 26 | fontSize: 24, 27 | fontWeight: 'bold', 28 | marginBottom: 16, 29 | textAlign: 'center', 30 | }, 31 | input: { 32 | height: 40, 33 | borderColor: '#ccc', 34 | borderWidth: 1, 35 | borderRadius: 8, 36 | marginBottom: 16, 37 | paddingHorizontal: 8, 38 | }, 39 | }); 40 | 41 | 42 | export const MainContent = styled.View` 43 | flex: 1; 44 | justify-content: space-between; 45 | align-items: center; 46 | padding: 10px; 47 | margin-top: 0px; 48 | background: ${theme.colors.background}; 49 | `; -------------------------------------------------------------------------------- /app/src/styles/theme.ts: -------------------------------------------------------------------------------- 1 | // src/theme.ts 2 | import { DefaultTheme } from 'react-native-paper'; 3 | import { DefaultTheme as StyledDefaultTheme } from 'styled-components/native'; 4 | 5 | const headingBaseStyle = { 6 | // fontFamily: 'Poppins, sans-serif', 7 | fontWeight: '300', 8 | }; 9 | 10 | const theme = { 11 | ...DefaultTheme, 12 | colors: { 13 | ...DefaultTheme.colors, 14 | primary: '#DDFF00', // mfl neon 15 | background: '#0D0D0D', // default black on the main page 16 | success: '#7BEA6D', // green color similar to the one on the chart 17 | surface: '#161616', // dark color for cards 18 | textStandout: '#FFFFFF', // white text 19 | textDefault: '#CCCCCC', // light grey 20 | textSecondary: '#7D7F78', // darker grey for secondary text 21 | }, 22 | roundness: 50, 23 | spacing: 8, 24 | typography: { 25 | fontFamily: 'Inter, sans-serif', 26 | h1: { 27 | ...headingBaseStyle, 28 | fontSize: 42, 29 | }, 30 | h2: { 31 | ...headingBaseStyle, 32 | fontSize: 34, 33 | }, 34 | h3: { 35 | ...headingBaseStyle, 36 | fontSize: 28, 37 | }, 38 | h4: { 39 | ...headingBaseStyle, 40 | fontSize: 24, 41 | }, 42 | h5: { 43 | ...headingBaseStyle, 44 | fontSize: 20, 45 | }, 46 | h6: { 47 | ...headingBaseStyle, 48 | fontSize: 16, 49 | }, 50 | body1: 
{ 51 | fontFamily: 'Inter, sans-serif', 52 | fontWeight: '300', 53 | fontSize: 14, 54 | // lineHeight: 1.5, 55 | }, 56 | body2: { 57 | fontFamily: 'Inter, sans-serif', 58 | fontSize: 12, 59 | fontWeight: '300', 60 | // lineHeight: 1.5, 61 | }, 62 | button: { 63 | // fontFamily: 'Poppins, sans-serif', 64 | fontSize: 14, 65 | fontWeight: '400', 66 | textTransform: 'none', // Ensures button text is not capitalized 67 | }, 68 | }, 69 | buttonStyles: { 70 | primary: { 71 | backgroundColor: '#DDFF00', 72 | boxShadow: '0px 20px 35px rgba(221, 255, 0, 0.25)', 73 | color: '#000000', 74 | }, 75 | secondary: { 76 | backgroundColor: '#171717', 77 | color: '#FFFFFF', 78 | borderColor: '#262626', 79 | }, 80 | }, 81 | }; 82 | 83 | export default theme; 84 | export const styledTheme = { ...StyledDefaultTheme, ...theme }; -------------------------------------------------------------------------------- /app/src/types/declarations.d.ts: -------------------------------------------------------------------------------- 1 | declare module '@env' { 2 | export const SUPABASE_URL: string; 3 | export const SUPABASE_ANON_KEY: string; 4 | } 5 | -------------------------------------------------------------------------------- /app/src/utils/CurrentDevice.ts: -------------------------------------------------------------------------------- 1 | import AsyncStorage from '@react-native-async-storage/async-storage'; 2 | 3 | export async function setCurrentDeviceID(deviceId: number): Promise { 4 | AsyncStorage.setItem('deviceId', deviceId.toString()) 5 | } 6 | 7 | export async function getCurrentDeviceID(): Promise { 8 | const deviceId = await AsyncStorage.getItem('deviceId') 9 | return parseInt(deviceId, 10) 10 | } -------------------------------------------------------------------------------- /app/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "expo/tsconfig.base", 3 | "compilerOptions": { 4 | "strict": true, 5 | "jsx": "react", 6 | "resolveJsonModule": true, 7 | "esModuleInterop": true, 8 | "allowJs": true, 9 | "baseUrl": ".", 10 | "paths": { 11 | "@env": ["declarations.d.ts"] 12 | } 13 | }, 14 | "include": ["declarations.d.ts", "src/**/*"] 15 | } 16 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/FederatedPhoneML/0ecab2cb0bef83e1c049c854d8a98c9cee661e29/assets/.DS_Store -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/FederatedPhoneML/0ecab2cb0bef83e1c049c854d8a98c9cee661e29/assets/logo.png -------------------------------------------------------------------------------- /example/job.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from mfl import Trainer, keras 4 | 5 | num_samples = 40000 6 | num_features = 10 7 | hidden_dim = 20 8 | batch_size = 1 9 | epochs = 3 10 | 11 | # Define the model 12 | model = keras.Sequential([ 13 | keras.layers.Input(shape=(num_features,)), 14 | keras.layers.Dense(units=hidden_dim), 15 | keras.layers.Dense(units=num_features), 16 | ]) 17 | 18 | # Compile the model 19 | model.compile(optimizer="sgd", loss="mean_squared_error") 20 | 21 | # Generate synthetic training data 22 | inputs = 
np.random.randint(1, 6, size=(num_samples, num_features)) 23 | outputs = inputs * 2 + 1 24 | 25 | trainer = Trainer( 26 | model, 27 | inputs, 28 | outputs, 29 | batch_size=batch_size) 30 | 31 | start = time.time() 32 | trainer.fit(epochs=epochs) 33 | print(f"Training took {time.time() - start:.2f} seconds") 34 | -------------------------------------------------------------------------------- /framework/mfl/__init__.py: -------------------------------------------------------------------------------- 1 | import tf_keras as keras 2 | 3 | from .trainer import Trainer 4 | -------------------------------------------------------------------------------- /framework/mfl/common.py: -------------------------------------------------------------------------------- 1 | version = "1.7.0" 2 | 3 | 4 | # File name for the indexing JSON file in an artifact directory. 5 | ARTIFACT_MODEL_JSON_FILE_NAME = "model.json" 6 | ASSETS_DIRECTORY_NAME = "assets" 7 | 8 | # JSON string keys for fields of the indexing JSON. 9 | ARTIFACT_MODEL_TOPOLOGY_KEY = "modelTopology" 10 | ARTIFACT_MODEL_INITIALIZER = "modelInitializer" 11 | ARTIFACT_WEIGHTS_MANIFEST_KEY = "weightsManifest" 12 | 13 | FORMAT_KEY = "format" 14 | TFJS_GRAPH_MODEL_FORMAT = "graph-model" 15 | TFJS_LAYERS_MODEL_FORMAT = "layers-model" 16 | 17 | GENERATED_BY_KEY = "generatedBy" 18 | CONVERTED_BY_KEY = "convertedBy" 19 | 20 | SIGNATURE_KEY = "signature" 21 | INITIALIZER_SIGNATURE_KEY = "initializerSignature" 22 | USER_DEFINED_METADATA_KEY = "userDefinedMetadata" 23 | STRUCTURED_OUTPUTS_KEYS_KEY = "structuredOutputKeys" 24 | RESOURCE_ID_KEY = "resourceId" 25 | 26 | # Model formats. 27 | KERAS_SAVED_MODEL = "keras_saved_model" 28 | KERAS_MODEL = "keras" 29 | KERAS_KERAS_MODEL = "keras_keras" 30 | TF_SAVED_MODEL = "tf_saved_model" 31 | TF_HUB_MODEL = "tf_hub" 32 | TFJS_GRAPH_MODEL = "tfjs_graph_model" 33 | TFJS_LAYERS_MODEL = "tfjs_layers_model" 34 | TF_FROZEN_MODEL = "tf_frozen_model" 35 | 36 | # CLI argument strings. 
37 | INPUT_PATH = "input_path" 38 | OUTPUT_PATH = "output_path" 39 | INPUT_FORMAT = "input_format" 40 | OUTPUT_FORMAT = "output_format" 41 | OUTPUT_NODE = "output_node_names" 42 | SIGNATURE_NAME = "signature_name" 43 | SAVED_MODEL_TAGS = "saved_model_tags" 44 | QUANTIZATION_BYTES = "quantization_bytes" 45 | QUANTIZATION_TYPE_FLOAT16 = "quantize_float16" 46 | QUANTIZATION_TYPE_UINT8 = "quantize_uint8" 47 | QUANTIZATION_TYPE_UINT16 = "quantize_uint16" 48 | SPLIT_WEIGHTS_BY_LAYER = "split_weights_by_layer" 49 | VERSION = "version" 50 | SKIP_OP_CHECK = "skip_op_check" 51 | STRIP_DEBUG_OPS = "strip_debug_ops" 52 | USE_STRUCTURED_OUTPUTS_NAMES = "use_structured_outputs_names" 53 | WEIGHT_SHARD_SIZE_BYTES = "weight_shard_size_bytes" 54 | CONTROL_FLOW_V2 = "control_flow_v2" 55 | EXPERIMENTS = "experiments" 56 | METADATA = "metadata" 57 | 58 | # Federated 59 | DATASET_MINIMUM_REPEAT = 2 60 | 61 | 62 | def get_converted_by(): 63 | """Get the convertedBy string for storage in model artifacts.""" 64 | return "TensorFlow.js Converter v%s" % version 65 | -------------------------------------------------------------------------------- /framework/mfl/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple 3 | 4 | import numpy as np 5 | 6 | # from .common import DATASET_MINIMUM_REPEAT # Removed since repetition is no longer needed 7 | 8 | 9 | def validate_dataset(inputs: np.ndarray, outputs: np.ndarray) -> None: 10 | """Validate that inputs and outputs have the same number of samples.""" 11 | if outputs is not None and len(inputs) != len(outputs): 12 | raise ValueError("Input and output shapes do not match") 13 | 14 | 15 | def repeat_and_shuffle( 16 | inputs: np.ndarray, 17 | outputs: Optional[np.ndarray] = None, 18 | num_devices: int = 1, 19 | batch_size: int = 1, 20 | ) -> Tuple[np.ndarray, Optional[np.ndarray]]: 21 | """ 22 | Shuffle the dataset without repeating. 23 | 24 | Args: 25 | inputs (np.ndarray): Input data. 26 | outputs (Optional[np.ndarray]): Output data. Can be None. 27 | num_devices (int): Number of devices to split the data across. (Unused in shuffling) 28 | batch_size (int): Size of each batch. (Unused in shuffling) 29 | 30 | Returns: 31 | Tuple[np.ndarray, Optional[np.ndarray]]: Shuffled inputs and outputs. 
32 | """ 33 | validate_dataset(inputs, outputs) 34 | 35 | if not isinstance(inputs, np.ndarray): 36 | inputs = np.array(inputs) 37 | if outputs is not None and not isinstance(outputs, np.ndarray): 38 | outputs = np.array(outputs) 39 | 40 | new_size = len(inputs) 41 | indices = np.random.permutation(new_size) 42 | 43 | if not isinstance(indices, np.ndarray) or indices.dtype.kind not in {'i', 'u'} or indices.ndim != 1: 44 | raise TypeError("Indices must be a one-dimensional array of integers.") 45 | 46 | try: 47 | shuffled_inputs = inputs[indices] 48 | except Exception as e: 49 | raise TypeError(f"Error shuffling inputs: {e}") 50 | 51 | if outputs is not None: 52 | try: 53 | shuffled_outputs = outputs[indices] 54 | except Exception as e: 55 | raise TypeError(f"Error shuffling outputs: {e}") 56 | else: 57 | shuffled_outputs = None 58 | 59 | return shuffled_inputs, shuffled_outputs 60 | 61 | 62 | def split_datasets( 63 | inputs: np.ndarray, 64 | devices: List[int], 65 | outputs: Optional[np.ndarray] = None, 66 | include_outputs: bool = False, 67 | ) -> List[Tuple[int, np.ndarray, Optional[np.ndarray]]]: 68 | """ 69 | Shuffle and split inputs and outputs into roughly equal parts for each device without truncating data points. 70 | 71 | Args: 72 | inputs (np.ndarray): Input data. 73 | devices (List[int]): List of device identifiers. 74 | batch_size (int): Size of each batch. (Unused in splitting) 75 | outputs (Optional[np.ndarray], optional): Output data. Defaults to None. 76 | include_outputs (bool, optional): Whether to include outputs in the split. Defaults to False. 77 | 78 | Returns: 79 | List[Tuple[int, np.ndarray, Optional[np.ndarray]]]: 80 | A list where each element is a tuple containing: 81 | - Device ID 82 | - Input subset for the device 83 | - Output subset for the device (or None if outputs are not provided) 84 | """ 85 | num_devices = len(devices) 86 | if num_devices == 0: 87 | raise ValueError("No devices available for training.") 88 | 89 | validate_dataset(inputs, outputs) 90 | 91 | total_samples = len(inputs) 92 | samples_per_device = total_samples // num_devices 93 | remainder = total_samples % num_devices 94 | 95 | datasets = [] 96 | start_idx = 0 97 | 98 | for i, device in enumerate(devices): 99 | 100 | current_samples = samples_per_device + (1 if i < remainder else 0) 101 | end_idx = start_idx + current_samples 102 | device_inputs = np.asarray(inputs[start_idx:end_idx]) 103 | device_outputs = np.asarray(outputs[start_idx:end_idx]) if include_outputs else None 104 | 105 | datasets.append( 106 | ( 107 | device, 108 | device_inputs, 109 | device_outputs, 110 | ) 111 | ) 112 | start_idx = end_idx 113 | 114 | assert start_idx == total_samples, "Not all samples were assigned to devices." 
115 | 116 | return datasets 117 | -------------------------------------------------------------------------------- /framework/mfl/federated.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | def average_model_weights(all_weights: List[List[np.ndarray]]) -> List[np.ndarray]: 7 | """Compute average of model weights""" 8 | averaged_weights = [] 9 | for layer_weights in zip(*all_weights): 10 | avg_layer_weights = np.mean(layer_weights, axis=0) 11 | averaged_weights.append(avg_layer_weights) 12 | return averaged_weights 13 | 14 | 15 | def average_epoch_loss(losses: List[Tuple[float, int]]) -> float: 16 | """Compute weighted average of losses""" 17 | total_samples = sum(samples for _, samples in losses) 18 | average_loss = sum(loss * samples for loss, samples in losses) / total_samples 19 | return average_loss 20 | -------------------------------------------------------------------------------- /framework/mfl/keras_h5_conversion.py: -------------------------------------------------------------------------------- 1 | """Library for converting from hdf5 to json + binary weights. 2 | 3 | Used primarily to convert saved weights, or saved_models from their 4 | hdf5 format to a JSON + binary weights format that the TS codebase can use. 5 | """ 6 | 7 | import json 8 | import os 9 | import tempfile 10 | 11 | import h5py 12 | import numpy as np 13 | import six 14 | 15 | from .common import * 16 | from .write_weights import write_weights 17 | 18 | import warnings 19 | from urllib3.exceptions import NotOpenSSLWarning 20 | 21 | # Suppress the specific warning 22 | warnings.filterwarnings('ignore', category=UserWarning, message='You are saving your model as an HDF5 file.*') 23 | warnings.filterwarnings('ignore', category=NotOpenSSLWarning, message='urllib3 v2 only supports OpenSSL 1.1.1+.*') 24 | 25 | 26 | def normalize_weight_name(weight_name): 27 | """Remove suffix ":0" (if present) from weight name.""" 28 | name = as_text(weight_name) 29 | if name.endswith(":0"): 30 | # Python TensorFlow weight names ends with the output slot, which is 31 | # not applicable to TensorFlow.js. 32 | name = name[:-2] 33 | return name 34 | 35 | 36 | def as_text(bytes_or_text, encoding="utf-8"): 37 | if isinstance(bytes_or_text, six.text_type): 38 | return bytes_or_text 39 | elif isinstance(bytes_or_text, bytes): 40 | return bytes_or_text.decode(encoding) 41 | else: 42 | raise TypeError("Expected binary or unicode string, got %r" % bytes_or_text) 43 | 44 | 45 | def _convert_h5_group(group): 46 | """Construct a weights group entry. 47 | 48 | Args: 49 | group: The HDF5 group data, possibly nested. 50 | 51 | Returns: 52 | An array of weight groups (see `write_weights` in TensorFlow.js). 53 | """ 54 | group_out = [] 55 | if "weight_names" in group.attrs: 56 | # This is a leaf node in namespace (e.g., 'Dense' in 'foo/bar/Dense'). 57 | names = group.attrs["weight_names"].tolist() 58 | 59 | if not names: 60 | return group_out 61 | 62 | names = [as_text(name) for name in names] 63 | weight_values = [np.array(group[weight_name]) for weight_name in names] 64 | group_out += [ 65 | {"name": normalize_weight_name(weight_name), "data": weight_value} 66 | for (weight_name, weight_value) in zip(names, weight_values) 67 | ] 68 | else: 69 | # This is *not* a leaf level in the namespace (e.g., 'foo' in 70 | # 'foo/bar/Dense'). 71 | for key in group.keys(): 72 | # Call this method recursively. 
73 | group_out += _convert_h5_group(group[key]) 74 | 75 | return group_out 76 | 77 | 78 | def _convert_v3_group(group, actual_layer_name): 79 | """Construct a weights group entry. 80 | 81 | Args: 82 | group: The HDF5 group data, possibly nested. 83 | 84 | Returns: 85 | An array of weight groups (see `write_weights` in TensorFlow.js). 86 | """ 87 | group_out = [] 88 | list_of_folder = [as_text(name) for name in group] 89 | if "vars" in list_of_folder: 90 | names = group["vars"] 91 | if not names: 92 | return group_out 93 | name_list = [as_text(name) for name in names] 94 | weight_values = [np.array(names[weight_name]) for weight_name in name_list] 95 | name_list = [os.path.join(actual_layer_name, item) for item in name_list] 96 | group_out += [ 97 | {"name": normalize_weight_name(weight_name), "data": weight_value} 98 | for (weight_name, weight_value) in zip(name_list, weight_values) 99 | ] 100 | else: 101 | for key in list_of_folder: 102 | group_out += _convert_v3_group(group[key], actual_layer_name) 103 | return group_out 104 | 105 | 106 | def _check_version(h5file): 107 | """Check version compatiility. 108 | 109 | Args: 110 | h5file: An h5file object. 111 | 112 | Raises: 113 | ValueError: if the KerasVersion of the HDF5 file is unsupported. 114 | """ 115 | keras_version = as_text(h5file.attrs["keras_version"]) 116 | if keras_version.split(".")[0] != "2": 117 | raise ValueError( 118 | "Expected Keras version 2; got Keras version %s" % keras_version 119 | ) 120 | 121 | 122 | def _initialize_output_dictionary(h5file): 123 | """Prepopulate required fields for all data foramts. 124 | 125 | Args: 126 | h5file: Valid h5file object. 127 | 128 | Returns: 129 | A dictionary with common fields sets, shared across formats. 130 | """ 131 | out = dict() 132 | out["keras_version"] = as_text(h5file.attrs["keras_version"]) 133 | out["backend"] = as_text(h5file.attrs["backend"]) 134 | return out 135 | 136 | 137 | def _ensure_h5file(h5file): 138 | if not isinstance(h5file, h5py.File): 139 | return h5py.File(h5file, "r") 140 | else: 141 | return h5file 142 | 143 | 144 | def _ensure_json_dict(item): 145 | return item if isinstance(item, dict) else json.loads(as_text(item)) 146 | 147 | 148 | def _discard_v3_keys(json_dict, keys_to_delete): 149 | if isinstance(json_dict, dict): 150 | keys = list(json_dict.keys()) 151 | for key in keys: 152 | if key in keys_to_delete: 153 | del json_dict[key] 154 | else: 155 | _discard_v3_keys(json_dict[key], keys_to_delete) 156 | elif isinstance(json_dict, list): 157 | for item in json_dict: 158 | _discard_v3_keys(item, keys_to_delete) 159 | 160 | 161 | # https://github.com/tensorflow/tfjs/issues/1255, b/124791387 162 | # In tensorflow version 1.13 and some alpha and nightly-preview versions, 163 | # the following layers have different class names in their serialization. 164 | # This issue should be fixed in later releases. But we include the logic 165 | # to translate them anyway, for users who use those versions of tensorflow. 166 | _CLASS_NAME_MAP = { 167 | "BatchNormalizationV1": "BatchNormalization", 168 | "UnifiedGRU": "GRU", 169 | "UnifiedLSTM": "LSTM", 170 | } 171 | 172 | 173 | def translate_class_names(input_object): 174 | """Perform class name replacement. 175 | 176 | Beware that this method modifies the input object in-place. 
177 | """ 178 | if not isinstance(input_object, dict): 179 | return 180 | for key in input_object: 181 | value = input_object[key] 182 | if key == "class_name" and value in _CLASS_NAME_MAP: 183 | input_object[key] = _CLASS_NAME_MAP[value] 184 | elif isinstance(value, dict): 185 | translate_class_names(value) 186 | elif isinstance(value, (tuple, list)): 187 | for item in value: 188 | translate_class_names(item) 189 | 190 | 191 | def h5_merged_saved_model_to_tfjs_format(h5file, split_by_layer=False): 192 | """Load topology & weight values from HDF5 file and convert. 193 | 194 | The HDF5 file is one generated by Keras' save_model method or model.save() 195 | 196 | N.B.: 197 | 1) This function works only on HDF5 values from Keras version 2. 198 | 2) This function does not perform conversion for special weights including 199 | ConvLSTM2D and CuDNNLSTM. 200 | 201 | Args: 202 | h5file: An instance of h5py.File, or the path to an h5py file. 203 | split_by_layer: (Optional) whether the weights of different layers are 204 | to be stored in separate weight groups (Default: `False`). 205 | 206 | Returns: 207 | (model_json, groups) 208 | model_json: a JSON dictionary holding topology and system metadata. 209 | group: an array of group_weights as defined in tfjs write_weights. 210 | 211 | Raises: 212 | ValueError: If the Keras version of the HDF5 file is not supported. 213 | """ 214 | h5file = _ensure_h5file(h5file) 215 | try: 216 | _check_version(h5file) 217 | except ValueError: 218 | print( 219 | """failed to lookup keras version from the file, 220 | this is likely a weight only file""" 221 | ) 222 | model_json = _initialize_output_dictionary(h5file) 223 | 224 | model_json["model_config"] = _ensure_json_dict(h5file.attrs["model_config"]) 225 | translate_class_names(model_json["model_config"]) 226 | if "training_config" in h5file.attrs: 227 | model_json["training_config"] = _ensure_json_dict( 228 | h5file.attrs["training_config"] 229 | ) 230 | 231 | groups = [] if split_by_layer else [[]] 232 | 233 | model_weights = h5file["model_weights"] 234 | layer_names = [as_text(n) for n in model_weights] 235 | for layer_name in layer_names: 236 | layer = model_weights[layer_name] 237 | group = _convert_h5_group(layer) 238 | if group: 239 | if split_by_layer: 240 | groups.append(group) 241 | else: 242 | groups[0] += group 243 | return model_json, groups 244 | 245 | 246 | def h5_v3_merged_saved_model_to_tfjs_format( 247 | h5file, meta_file, config_file, split_by_layer=False 248 | ): 249 | """Load topology & weight values from HDF5 file and convert. 250 | 251 | The HDF5 weights file is one generated by Keras's save_model method or model.save() 252 | 253 | N.B.: 254 | 1) This function works only on HDF5 values from Keras version 3. 255 | 2) This function does not perform conversion for special weights including 256 | ConvLSTM2D and CuDNNLSTM. 257 | 258 | Args: 259 | h5file: An instance of h5py.File, or the path to an h5py file. 260 | split_by_layer: (Optional) whether the weights of different layers are 261 | to be stored in separate weight groups (Default: `False`). 262 | 263 | Returns: 264 | (model_json, groups) 265 | model_json: a JSON dictionary holding topology and system metadata. 266 | group: an array of group_weights as defined in tfjs write_weights. 267 | 268 | Raises: 269 | ValueError: If the Keras version of the HDF5 file is not supported. 
270 | """ 271 | h5file = _ensure_h5file(h5file) 272 | model_json = dict() 273 | model_json["keras_version"] = meta_file["keras_version"] 274 | 275 | keys_to_remove = ["module", "registered_name", "date_saved"] 276 | config = _ensure_json_dict(config_file) 277 | _discard_v3_keys(config, keys_to_remove) 278 | model_json["model_config"] = config 279 | translate_class_names(model_json["model_config"]) 280 | if "training_config" in h5file.attrs: 281 | model_json["training_config"] = _ensure_json_dict( 282 | h5file.attrs["training_config"] 283 | ) 284 | 285 | groups = [] if split_by_layer else [[]] 286 | 287 | _convert_v3_group_structure_to_weights( 288 | groups=groups, group=h5file, split_by_layer=split_by_layer 289 | ) 290 | return model_json, groups 291 | 292 | 293 | def _convert_v3_group_structure_to_weights(groups, group, split_by_layer, indent=""): 294 | for key in group.keys(): 295 | if isinstance(group[key], h5py.Group): 296 | _convert_v3_group_structure_to_weights( 297 | groups, group[key], split_by_layer, indent + key + "/" 298 | ) 299 | elif isinstance(group[key], h5py.Dataset): 300 | group_of_weights = dict() 301 | for key in group.keys(): 302 | group_of_weights[str(indent + key)] = group[key] 303 | group_out = _convert_group(group_of_weights) 304 | if split_by_layer: 305 | groups.append(group_out) 306 | else: 307 | groups[0] += group_out 308 | break 309 | 310 | 311 | def _convert_group(group_dict): 312 | group_out = [] 313 | for key in group_dict.keys(): 314 | name = key 315 | weights_value = np.array(group_dict[key]) 316 | group_out += [{"name": name, "data": weights_value}] 317 | 318 | return group_out 319 | 320 | 321 | def h5_weights_to_tfjs_format(h5file, split_by_layer=False): 322 | """Load weight values from a Keras HDF5 file and to a binary format. 323 | 324 | The HDF5 file is one generated by Keras' Model.save_weights() method. 325 | 326 | N.B.: 327 | 1) This function works only on HDF5 values from Keras version 2. 328 | 2) This function does not perform conversion for special weights including 329 | ConvLSTM2D and CuDNNLSTM. 330 | 331 | Args: 332 | h5file: An instance of h5py.File, or the path to an h5py file. 333 | split_by_layer: (Optional) whether the weights of different layers are 334 | to be stored in separate weight groups (Default: `False`). 335 | 336 | Returns: 337 | An array of group_weights as defined in tfjs write_weights. 
338 | 339 | Raises: 340 | ValueError: If the Keras version of the HDF5 file is not supported 341 | """ 342 | h5file = _ensure_h5file(h5file) 343 | try: 344 | _check_version(h5file) 345 | except ValueError: 346 | print( 347 | """failed to lookup keras version from the file, 348 | this is likely a weight only file""" 349 | ) 350 | 351 | groups = [] if split_by_layer else [[]] 352 | 353 | # pylint: disable=not-an-iterable 354 | layer_names = [as_text(n) for n in h5file.attrs["layer_names"]] 355 | # pylint: enable=not-an-iterable 356 | for layer_name in layer_names: 357 | layer = h5file[layer_name] 358 | group = _convert_h5_group(layer) 359 | if group: 360 | if split_by_layer: 361 | groups.append(group) 362 | else: 363 | groups[0] += group 364 | return groups 365 | 366 | 367 | def _get_generated_by(topology): 368 | if topology is None: 369 | return None 370 | elif "keras_version" in topology: 371 | return "keras v%s" % topology["keras_version"] 372 | else: 373 | return None 374 | 375 | 376 | def write_artifacts( 377 | topology, 378 | weights, 379 | output_dir, 380 | quantization_dtype_map=None, 381 | weight_shard_size_bytes=1024 * 1024 * 4, 382 | metadata=None, 383 | ): 384 | """Writes weights and topology to the output_dir. 385 | 386 | If `topology` is Falsy (e.g., `None`), only emit weights to output_dir. 387 | 388 | Args: 389 | topology: a JSON dictionary, representing the Keras config. 390 | weights: an array of weight groups (as defined in tfjs write_weights). 391 | output_dir: the directory to hold all the contents. 392 | quantization_dtype_map: (Optional) A mapping from dtype 393 | (`uint8`, `uint16`, `float16`) to weights names. The weight mapping 394 | supports wildcard substitution. 395 | weight_shard_size_bytes: Shard size (in bytes) of the weight files. 396 | The size of each weight file will be <= this value. 397 | metadata: User defined metadata map. 398 | """ 399 | # TODO(cais, nielsene): This method should allow optional arguments of 400 | # `write_weights.write_weights` (e.g., shard size) and forward them. 401 | # We write the topology after since write_weights makes no promises about 402 | # preserving directory contents. 403 | if not (isinstance(weight_shard_size_bytes, int) and weight_shard_size_bytes > 0): 404 | raise ValueError( 405 | "Expected weight_shard_size_bytes to be a positive integer, " 406 | "but got %s" % weight_shard_size_bytes 407 | ) 408 | 409 | if os.path.isfile(output_dir): 410 | raise ValueError( 411 | 'Path "%d" already exists as a file (not a directory).' % output_dir 412 | ) 413 | 414 | model_json = { 415 | FORMAT_KEY: TFJS_LAYERS_MODEL_FORMAT, 416 | GENERATED_BY_KEY: _get_generated_by(topology), 417 | CONVERTED_BY_KEY: get_converted_by(), 418 | } 419 | 420 | if metadata: 421 | model_json[USER_DEFINED_METADATA_KEY] = metadata 422 | 423 | model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None 424 | weights_manifest = write_weights( 425 | weights, 426 | output_dir, 427 | write_manifest=False, 428 | quantization_dtype_map=quantization_dtype_map, 429 | shard_size_bytes=weight_shard_size_bytes, 430 | ) 431 | assert isinstance(weights_manifest, list) 432 | model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest 433 | return model_json 434 | 435 | 436 | def get_keras_model_graph( 437 | model, 438 | artifacts_dir=tempfile.tempdir, 439 | quantization_dtype_map=None, 440 | weight_shard_size_bytes=1024 * 1024 * 4, 441 | metadata=None, 442 | ): 443 | r"""Save a Keras model and its weights in TensorFlow.js format. 
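    Illustrative usage (a sketch, not from the original source; the model variable name is assumed): model_json = get_keras_model_graph(my_keras_model, artifacts_dir="/tmp/tfjs_artifacts")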
444 | 445 | Args: 446 | model: An instance of `keras.Model`. 447 | artifacts_dir: The directory in which the artifacts will be saved. 448 | The artifacts to be saved include: 449 | - model.json: A JSON representing the model. It has the following 450 | fields: 451 | - 'modelTopology': A JSON object describing the topology of the model, 452 | along with additional information such as training. It is obtained 453 | through calling `model.save()`. 454 | - 'weightsManifest': A TensorFlow.js-format JSON manifest for the 455 | model's weights. 456 | - files containing weight values in groups, with the file name pattern 457 | group(\d+)-shard(\d+)of(\d+). 458 | If the directory does not exist, this function will attempt to create it. 459 | quantization_dtype_map: (Optional) A mapping from dtype 460 | (`uint8`, `uint16`, `float16`) to weights names. The weight mapping 461 | supports wildcard substitution. 462 | weight_shard_size_bytes: Shard size (in bytes) of the weight files. 463 | The size of each weight file will be <= this value. 464 | metadata: User defined metadata map. 465 | 466 | Raises: 467 | ValueError: If `artifacts_dir` already exists as a file (not a directory). 468 | """ 469 | temp_h5_path = tempfile.mktemp() + ".h5" 470 | model.save(temp_h5_path, save_format="h5") 471 | topology_json, weight_groups = h5_merged_saved_model_to_tfjs_format(temp_h5_path) 472 | if os.path.isfile(artifacts_dir): 473 | raise ValueError('Path "%s" already exists as a file.' % artifacts_dir) 474 | if not os.path.isdir(artifacts_dir): 475 | os.makedirs(artifacts_dir) 476 | return write_artifacts( 477 | topology_json, 478 | weight_groups, 479 | artifacts_dir, 480 | quantization_dtype_map=quantization_dtype_map, 481 | weight_shard_size_bytes=weight_shard_size_bytes, 482 | metadata=metadata, 483 | ) 484 | -------------------------------------------------------------------------------- /framework/mfl/quantization.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | 3 | import numpy as np 4 | 5 | QUANTIZATION_DTYPE_FLOAT16 = "float16" 6 | QUANTIZATION_DTYPE_UINT8 = "uint8" 7 | QUANTIZATION_DTYPE_UINT16 = "uint16" 8 | 9 | QUANTIZATION_BYTES_TO_DTYPES = { 10 | 1: QUANTIZATION_DTYPE_UINT8, 11 | 2: QUANTIZATION_DTYPE_UINT16, 12 | } 13 | QUANTIZATION_OPTION_TO_DTYPES = { 14 | QUANTIZATION_DTYPE_UINT8: np.uint8, 15 | QUANTIZATION_DTYPE_UINT16: np.uint16, 16 | QUANTIZATION_DTYPE_FLOAT16: np.float16, 17 | } 18 | 19 | 20 | def map_layers_to_quantization_dtype(names, quantization_dtype_map): 21 | """Maps node names to their quantization dtypes. 22 | 23 | Given a quantization_dtype_map which maps dtypes `uint8`, `uint16`, `float16` 24 | to node patterns, e.g., conv/*/weights we construct a new mapping for each 25 | individual node name to its dtype, e.g., conv/1/weight -> `uint8`. 26 | A dtype in the map can also be a boolean, signaling a fallthrough dtype. 27 | There can only be one fallthrough dtype in the map. A fallthrough dtype 28 | will convert all weights that don't match any pattern to the provided dtype. 29 | 30 | Args: 31 | names: Array of node names. 32 | quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) 33 | to weights. The weight mapping supports wildcard substitution. 34 | 35 | Returns: 36 | quantization_dtype: A mapping from each node name which matches 37 | an entry in quantization_dtype_map to its corresponding dtype. 
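        Illustrative example (not from the original source): with names = ['conv/1/weights', 'dense/kernel'] and quantization_dtype_map = {'uint8': 'conv/*', 'float16': True}, 'conv/1/weights' maps to np.uint8 and 'dense/kernel' falls through to np.float16.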
38 | 39 | Raises: 40 | ValueError: - If multiple dtypes match the same node name 41 | - If more than one fallthrough is provided 42 | """ 43 | if quantization_dtype_map is None: 44 | return {} 45 | 46 | fallthrough = None 47 | quantization_dtype = {} 48 | for dtype_name, patterns in quantization_dtype_map.items(): 49 | # Record fallthrough if there is one 50 | if isinstance(patterns, bool) and patterns: 51 | # Only one fallthrough is supported 52 | if fallthrough is not None: 53 | raise ValueError( 54 | "More than one quantization fallthrough provided, " 55 | "exactly one is supported" 56 | ) 57 | fallthrough = dtype_name 58 | continue 59 | if isinstance(patterns, str): 60 | patterns = list([patterns]) 61 | 62 | # Record matched weights for dtype 63 | for pattern in patterns: 64 | for match in fnmatch.filter(names, pattern): 65 | dtype = QUANTIZATION_OPTION_TO_DTYPES[dtype_name] 66 | if match in quantization_dtype and quantization_dtype[match] != dtype: 67 | raise ValueError( 68 | "Two quantization values %s, %s match the same node %s" 69 | % (dtype, quantization_dtype[match], match) 70 | ) 71 | quantization_dtype[match] = dtype 72 | 73 | # Catch all remaining names with fallthrough 74 | if fallthrough is not None: 75 | nameset = set(names) 76 | fallthrough_names = nameset - set(quantization_dtype.keys()) 77 | for name in fallthrough_names: 78 | quantization_dtype[name] = QUANTIZATION_OPTION_TO_DTYPES[fallthrough] 79 | 80 | return quantization_dtype 81 | 82 | 83 | def quantize_weights(data, quantization_dtype): 84 | """Quantizes the weights by linearly re-scaling across available bits. 85 | 86 | The weights are quantized by linearly re-scaling the values between the 87 | minimum and maximum value, and representing them with the number of bits 88 | provided by the `quantization_dtype`. 89 | 90 | In order to guarantee that 0 is perfectly represented by one of the quantized 91 | values, the range is "nudged" in the same manner as in TF-Lite. 92 | 93 | Weights can be de-quantized by multiplying by the returned `scale` and adding 94 | `min`. 95 | 96 | Args: 97 | data: A numpy array of dtype 'float32' or 'int32'. 98 | quantization_dtype: A numpy dtype to quantize weights to. Only np.float16, 99 | np.uint8, and np.uint16 are supported. 100 | 101 | Returns: 102 | quantized_data: The quantized weights as a numpy array with dtype 103 | `quantization_dtype`. 104 | metadata: A dictionary with the corresponding metadata for the quantization 105 | type. There is no metadata associated with float16. 106 | For affine quantization there are two associated metadata values: 107 | scale: The linearly scaling constant used for quantization. 108 | min_val: The minimum value of the linear range. 109 | Raises: 110 | ValueError: if `quantization_dtype` is not a valid type. 111 | """ 112 | if quantization_dtype in [np.uint8, np.uint16]: 113 | # Compute the min and max for the group. 114 | min_val = data.min().astype(np.float64) 115 | max_val = data.max().astype(np.float64) 116 | if min_val == max_val: 117 | # If there is only a single value, we can represent everything as zeros. 118 | quantized_data = np.zeros_like(data, dtype=quantization_dtype) 119 | scale = 1.0 120 | else: 121 | # Quantize data. 
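            # Worked example (illustrative, not from the original source): with min_val = 0.0, max_val = 2.55 and uint8, scale = 2.55 / 255 = 0.01, so a value of 1.28 quantizes to round(1.28 / 0.01) = 128 and de-quantizes back to 128 * 0.01 + 0.0 = 1.28.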
122 | scale, min_val, max_val = _get_affine_quantization_range( 123 | min_val, max_val, quantization_dtype 124 | ) 125 | quantized_data = np.round( 126 | (data.clip(min_val, max_val) - min_val) / scale 127 | ).astype(quantization_dtype) 128 | 129 | return quantized_data, {"min": min_val, "scale": scale} 130 | elif quantization_dtype == np.float16: 131 | if data.dtype != np.float32: 132 | raise ValueError( 133 | "Invalid data dtype %r\n" 134 | "float16 quantization only supports float32 dtype" % data.dtype 135 | ) 136 | quantized_data = data.astype(np.float16) 137 | return quantized_data, {} 138 | else: 139 | raise ValueError("Invalid `quantization_dtype`: %r" % quantization_dtype) 140 | 141 | 142 | def dequantize_weights(data, metadata, original_dtype=np.float32): 143 | dtype = data.dtype 144 | 145 | if dtype in [np.uint8, np.uint16]: 146 | if not ("scale" in metadata and "min" in metadata): 147 | raise ValueError("Missing metadata min or scale for dtype %s" % dtype.name) 148 | scale = metadata["scale"] 149 | min_val = metadata["min"] 150 | if original_dtype == np.int32: 151 | return np.round(data * scale + min_val).astype(original_dtype) 152 | else: 153 | return (data * scale + min_val).astype(original_dtype) 154 | elif dtype == np.float16: 155 | if original_dtype != np.float32: 156 | raise ValueError( 157 | "Invalid data dtype %r\n" 158 | "float16 quantization only supports float32 dtype" % data.dtype 159 | ) 160 | return data.astype(original_dtype) 161 | else: 162 | raise ValueError( 163 | "Invalid dtype %s for dequantization\n" 164 | "Supported dtypes are uint8, uint16, float16" % dtype.name 165 | ) 166 | 167 | 168 | def _get_affine_quantization_range(min_val, max_val, quantization_dtype): 169 | """Computes quantization range to ensure that zero is represented if covered. 170 | 171 | Gymnastics with nudged zero point is to ensure that real zero maps to an 172 | integer, which is required for e.g. zero-padding in convolutional layers. 173 | 174 | Based on `NudgeQuantizationRange` in 175 | tensorflow/contrib/lite/kernels/internal/quantization_util.h, except we do not 176 | nudge if 0 is not in the range. 177 | 178 | Args: 179 | min_val: The actual minimum value of the data. 180 | max_val: The actual maximum value of the data. 181 | quantization_dtype: A numpy dtype to quantize weights to. Only np.uint8 and 182 | np.uint16 are supported. 183 | 184 | Returns: 185 | scale: The linear scaling constant used for quantization. 186 | nudged_min: The adjusted minimum value to ensure zero is represented, if 187 | covered. 188 | nudged_max: The adjusted maximum value to ensure zero is represented, if 189 | covered. 190 | Raises: 191 | ValueError: if `quantization_dtype` is not a valid type. 192 | """ 193 | if quantization_dtype not in [np.uint8, np.uint16]: 194 | raise ValueError("Invalid `quantization_dtype`: %r" % quantization_dtype) 195 | 196 | quant_max = np.iinfo(quantization_dtype).max 197 | scale = (max_val - min_val) / quant_max 198 | 199 | if min_val <= 0 <= max_val: 200 | quantized_zero_point = (0 - min_val) / scale 201 | nudged_zero_point = np.round(quantized_zero_point) 202 | 203 | # Solve `0 = nudged_zero_point * scale + nudged_min` for `nudged_min`. 
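            # Illustrative: with min_val = -1.0, max_val = 1.0 and uint8, scale = 2 / 255 ≈ 0.00784 and the zero point rounds to 128, so nudged_min ≈ -1.004 rather than -1.0, which guarantees that a real value of 0.0 maps exactly to the integer 128.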
204 | nudged_min = -nudged_zero_point * scale 205 | nudged_max = quant_max * scale + nudged_min 206 | else: 207 | nudged_min, nudged_max = min_val, max_val 208 | 209 | return scale, nudged_min, nudged_max 210 | -------------------------------------------------------------------------------- /framework/mfl/read_weights.py: -------------------------------------------------------------------------------- 1 | """Read weights stored in TensorFlow.js-format binary files.""" 2 | 3 | import io 4 | import os 5 | 6 | import numpy as np 7 | 8 | from .quantization import dequantize_weights 9 | 10 | _INPUT_DTYPES = [ 11 | np.float16, 12 | np.float32, 13 | np.int32, 14 | np.complex64, 15 | np.uint8, 16 | np.uint16, 17 | object, 18 | bool, 19 | ] 20 | 21 | # Number of bytes used to encode the length of a string in a string tensor. 22 | STRING_LENGTH_NUM_BYTES = 4 23 | # The data type used to encode the length of a string in a string tensor. 24 | STRING_LENGTH_DTYPE = np.dtype("uint32").newbyteorder("<") 25 | 26 | 27 | def read_weights(weights_manifest, base_path, flatten=False): 28 | """Load weight values according to a TensorFlow.js weights manifest. 29 | 30 | Args: 31 | weights_manifest: A TensorFlow.js-format weights manifest (a JSON array). 32 | base_path: Base path prefix for the weights files. 33 | flatten: Whether all the weight groups in the return value are to be 34 | flattened as a single weights group. Default: `False`. 35 | 36 | Returns: 37 | If `flatten` is `False`, a `list` of weight groups. Each group is an array 38 | of weight entries. Each entry is a dict that maps a unique name to a numpy 39 | array, for example: 40 | entry = { 41 | 'name': 'weight1', 42 | 'data': np.array([1, 2, 3], 'float32') 43 | } 44 | 45 | Weights groups would then look like: 46 | weight_groups = [ 47 | [group_0_entry1, group_0_entry2], 48 | [group_1_entry1, group_1_entry2], 49 | ] 50 | If `flatten` is `True`, returns a single weight group. 51 | """ 52 | if not isinstance(weights_manifest, list): 53 | raise ValueError( 54 | "weights_manifest should be a `list`, but received %s" 55 | % type(weights_manifest) 56 | ) 57 | 58 | data_buffers = [] 59 | for group in weights_manifest: 60 | buff = io.BytesIO() 61 | buff_writer = io.BufferedWriter(buff) 62 | for path in group["paths"]: 63 | with open(os.path.join(base_path, path), "rb") as f: 64 | buff_writer.write(f.read()) 65 | buff_writer.flush() 66 | buff_writer.seek(0) 67 | data_buffers.append(buff.read()) 68 | return decode_weights(weights_manifest, data_buffers, flatten=flatten) 69 | 70 | 71 | def _deserialize_string_array(data_buffer, offset, shape): 72 | """Deserializes bytes into np.array of dtype `object` which holds strings. 73 | 74 | Each string value is preceded by 4 bytes which denote a 32-bit unsigned 75 | integer in little endian that specifies the byte length of the following 76 | string. This is followed by the actual string bytes. If the tensor has no 77 | strings there will be no bytes reserved. Empty strings will still take 4 bytes 78 | for the length. 79 | 80 | For example, a tensor that has 2 strings will be encoded as 81 | [byte length of s1][bytes of s1...][byte length of s2][bytes of s2...] 82 | 83 | where byte length always takes 4 bytes. 84 | 85 | Args: 86 | data_buffer: A buffer of bytes containing the serialized data. 87 | offset: The byte offset in that buffer that denotes the start of the tensor. 88 | shape: The logical shape of the tensor. 
89 | 90 | Returns: 91 | A tuple of (np.array, offset) where the np.array contains the encoded 92 | strings, and the offset contains the new offset (the byte position in the 93 | buffer at the end of the string data). 94 | """ 95 | size = int(np.prod(shape)) 96 | if size == 0: 97 | return (np.array([], "object").reshape(shape), offset + STRING_LENGTH_NUM_BYTES) 98 | vals = [] 99 | for _ in range(size): 100 | byte_length = np.frombuffer( 101 | data_buffer[offset : offset + STRING_LENGTH_NUM_BYTES], STRING_LENGTH_DTYPE 102 | )[0] 103 | offset += STRING_LENGTH_NUM_BYTES 104 | string = data_buffer[offset : offset + byte_length] 105 | vals.append(string) 106 | offset += byte_length 107 | return np.array(vals, "object").reshape(shape), offset 108 | 109 | 110 | def _deserialize_numeric_array(data_buffer, offset, dtype, shape): 111 | weight_numel = 1 112 | for dim in shape: 113 | weight_numel *= dim 114 | return np.frombuffer( 115 | data_buffer, dtype=dtype, count=weight_numel, offset=offset 116 | ).reshape(shape) 117 | 118 | 119 | def decode_weights(weights_manifest, data_buffers, flatten=False): 120 | """Load weight values from buffer(s) according to a weights manifest. 121 | 122 | Args: 123 | weights_manifest: A TensorFlow.js-format weights manifest (a JSON array). 124 | data_buffers: A buffer or a `list` of buffers containing the weights values 125 | in binary format, concatenated in the order specified in 126 | `weights_manifest`. If a `list` of buffers, the length of the `list` 127 | must match the length of `weights_manifest`. A single buffer is 128 | interpreted as a `list` of one buffer and is valid only if the length of 129 | `weights_manifest` is `1`. 130 | flatten: Whether all the weight groups in the return value are to be 131 | flattened as a single weight groups. Default: `False`. 132 | 133 | Returns: 134 | If `flatten` is `False`, a `list` of weight groups. Each group is an array 135 | of weight entries. Each entry is a dict that maps a unique name to a numpy 136 | array, for example: 137 | entry = { 138 | 'name': 'weight1', 139 | 'data': np.array([1, 2, 3], 'float32') 140 | } 141 | 142 | Weights groups would then look like: 143 | weight_groups = [ 144 | [group_0_entry1, group_0_entry2], 145 | [group_1_entry1, group_1_entry2], 146 | ] 147 | If `flatten` is `True`, returns a single weight group. 148 | 149 | Raises: 150 | ValueError: if the lengths of `weights_manifest` and `data_buffers` do not 151 | match. 152 | """ 153 | if not isinstance(data_buffers, list): 154 | data_buffers = [data_buffers] 155 | if len(weights_manifest) != len(data_buffers): 156 | raise ValueError( 157 | "Mismatch in the length of weights_manifest (%d) and the length of " 158 | "data buffers (%d)" % (len(weights_manifest), len(data_buffers)) 159 | ) 160 | 161 | out = [] 162 | for group, data_buffer in zip(weights_manifest, data_buffers): 163 | offset = 0 164 | out_group = [] 165 | 166 | for weight in group["weights"]: 167 | quant_info = weight.get("quantization", None) 168 | name = weight["name"] 169 | if weight["dtype"] == "string": 170 | # String array. 171 | dtype = object 172 | elif quant_info: 173 | # Quantized array. 174 | dtype = np.dtype(quant_info["dtype"]) 175 | else: 176 | # Regular numeric array. 
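                # e.g. a (hypothetical) manifest entry {"name": "dense/kernel",
                # "shape": [784, 10], "dtype": "float32"} resolves here to
                # np.dtype("float32") and is decoded by _deserialize_numeric_array.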
177 | dtype = np.dtype(weight["dtype"]) 178 | shape = weight["shape"] 179 | if dtype not in _INPUT_DTYPES: 180 | raise NotImplementedError("Unsupported data type: %s" % dtype) 181 | if weight["dtype"] == "string": 182 | value, offset = _deserialize_string_array(data_buffer, offset, shape) 183 | else: 184 | value = _deserialize_numeric_array(data_buffer, offset, dtype, shape) 185 | offset += dtype.itemsize * value.size 186 | if quant_info: 187 | value = dequantize_weights(value, quant_info, np.dtype(weight["dtype"])) 188 | out_group.append({"name": name, "data": value}) 189 | 190 | if flatten: 191 | out += out_group 192 | else: 193 | out.append(out_group) 194 | 195 | return out 196 | -------------------------------------------------------------------------------- /framework/mfl/trainer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections import defaultdict 3 | from typing import List, Optional, Tuple 4 | 5 | import numpy as np 6 | import tf_keras as keras 7 | 8 | from .data import split_datasets 9 | from .federated import average_epoch_loss, average_model_weights 10 | from .keras_h5_conversion import get_keras_model_graph 11 | from .worker import RequestConfig, Worker 12 | 13 | 14 | class Trainer: 15 | """Distributed training by using federated training class""" 16 | def __init__( 17 | self, 18 | model: keras.Model, 19 | inputs: np.ndarray, 20 | outputs: np.ndarray, 21 | batch_size: int, 22 | validation_inputs: Optional[np.ndarray] = None, 23 | validation_outputs: Optional[np.ndarray] = None, 24 | ): 25 | 26 | self.model = model 27 | self.modelJson = get_keras_model_graph(self.model) 28 | self.device_urls = None 29 | self.batch_size = batch_size 30 | worker_id = np.random.randint(0, 100000) 31 | self.worker = Worker(_id=worker_id) 32 | self.inputs = np.asarray(inputs) 33 | self.outputs = np.asarray(outputs) 34 | self.validation_inputs = validation_inputs 35 | self.validation_outputs = validation_outputs 36 | self.history = defaultdict(list) 37 | self.device_epochs = 1 38 | 39 | def _create_base_request_config(self, epochs=None) -> RequestConfig: 40 | """Create base request configuration""" 41 | return RequestConfig( 42 | modelJson=self.modelJson, 43 | weights=self._get_weights(), 44 | batchSize=self.batch_size, 45 | epochs=self.device_epochs, 46 | ) 47 | 48 | def _reset(self): 49 | """Reset training job data""" 50 | self.history = defaultdict(list) 51 | 52 | def _get_weights(self) -> List: 53 | """Convert numpy arrays to nested lists for JSON serialization""" 54 | return [w.tolist() for w in self.model.get_weights()] 55 | 56 | def _deserialize_weights(self, weights_data: List) -> List[np.ndarray]: 57 | """Convert nested lists back to numpy arrays""" 58 | return [np.array(w, dtype=np.float32) for w in weights_data] 59 | 60 | def _to_validate(self): 61 | """Check if validation data is available""" 62 | return ( 63 | self.validation_inputs is not None and self.validation_outputs is not None 64 | ) 65 | 66 | async def _dispatch( 67 | self, 68 | request_config: RequestConfig, 69 | datasets: List[Tuple[int, np.ndarray, np.ndarray]], 70 | request_type: str, 71 | ) -> None: 72 | """Dispatch tasks to all available devices""" 73 | request_configs = [] 74 | 75 | for device, device_inputs, device_outputs in datasets: 76 | 77 | request_config.inputs = device_inputs.tolist() 78 | request_config.outputs = ( 79 | device_outputs.tolist() if device_outputs is not None else None 80 | ) 81 | request_config.inputShape = 
list(device_inputs.shape)
82 | 
83 |             if device_outputs is not None:
84 |                 request_config.outputShape = list(device_outputs.shape)
85 | 
86 |             request_config.datasetsPerDevice = len(device_inputs)
87 | 
88 |             request_configs.append(RequestConfig(**vars(request_config)))  # copy, since the base config object is reused across devices
89 | 
90 |         await self.worker.run(
91 |             request_type=request_type, request_configs=request_configs
92 |         )
93 | 
94 |     def _gather(
95 |         self, request_type: str
96 |     ) -> List:
97 |         """Gather results from all devices, update model weights, and compute loss"""
98 |         all_weights = []
99 |         epoch_device_losses = []
100 |         outputs = []
101 | 
102 |         results = self.worker.task_manager.completed_tasks.items()
103 | 
104 |         for task_id, task in list(results):
105 |             if task.response_data.outputs is not None:
106 |                 outputs.append(task.response_data.outputs)
107 |             if task.response_data.weights is not None:
108 |                 deserialized_weights = self._deserialize_weights(
109 |                     task.response_data.weights
110 |                 )
111 |                 all_weights.append(deserialized_weights)
112 |             if task.response_data.loss is not None:
113 |                 loss = task.response_data.loss
114 |                 num_samples = len(results)
115 |                 epoch_device_losses.append((loss, num_samples))
116 |             del self.worker.task_manager.tasks[task_id]
117 | 
118 |         if all_weights:
119 |             averaged_weights = average_model_weights(all_weights)
120 |             self.model.set_weights(averaged_weights)
121 | 
122 |         if epoch_device_losses:
123 |             average_loss = average_epoch_loss(epoch_device_losses)
124 |             self.history[f"{request_type}_loss"].append(average_loss)
125 | 
126 |         return outputs
127 | 
128 |     async def _dispatch_gather(self, request_config, datasets, request_type):
129 |         await self._dispatch(request_config, datasets, request_type)
130 |         return self._gather(request_type)
131 | 
132 |     def _print_progress(self, epoch, epochs):
133 |         """Print progress of training"""
134 |         if "train_loss" not in self.history:
135 |             return
136 | 
137 |         log = f"Epoch {epoch + 1}/{epochs} - Loss: {self.history['train_loss'][-1]}"
138 |         if self._to_validate():
139 |             log += f" - Validation Loss: {self.history['evaluate_loss'][-1]}"
140 |         print(log)
141 | 
142 |     async def _fit(self, epochs):
143 |         """Run federated training process"""
144 |         available_devices = self.worker.load_available_devices()
145 |         print(f"Training on {len(available_devices)} devices")
146 | 
147 |         async def fit_epoch(epoch):
148 |             request_config = self._create_base_request_config(epochs)
149 | 
150 |             datasets = split_datasets(
151 |                 self.inputs,
152 |                 available_devices,
153 |                 self.outputs,
154 |                 include_outputs=True,
155 |             )
156 | 
157 |             await self._dispatch_gather(request_config, datasets, "train")
158 | 
159 |             if self._to_validate():
160 |                 await self._evaluate()
161 | 
162 |             self._print_progress(epoch, epochs)
163 | 
164 |         for epoch in range(epochs):
165 |             await fit_epoch(epoch)
166 | 
167 |     def fit(self, epochs: int) -> None:
168 |         """Run federated training process"""
169 |         asyncio.run(self._fit(epochs))
170 | 
171 |     async def _evaluate(self) -> None:
172 |         """Run distributed evaluation across all devices"""
173 |         request_config = self._create_base_request_config()
174 | 
175 |         datasets = split_datasets(
176 |             self.validation_inputs,
177 |             self.worker.load_available_devices(),
178 |             self.validation_outputs,
179 |             include_outputs=True,
180 |         )
181 | 
182 |         await self._dispatch_gather(request_config, datasets, "evaluate")
183 | 
184 |     def evaluate(self) -> None:
185 |         """Run distributed evaluation across all devices"""
186 |         asyncio.run(self._evaluate())
187 | 
188 |     async def _predict(self, inputs: 
np.ndarray) -> Tuple[np.ndarray, Optional[float]]: 189 | """Run distributed prediction across all devices""" 190 | request_config = self._create_base_request_config() 191 | datasets = split_datasets(inputs, self.worker.load_available_devices()) 192 | return await self._dispatch_gather(request_config, datasets, "predict") 193 | 194 | def predict(self, inputs: np.ndarray) -> Tuple[np.ndarray, Optional[float]]: 195 | """Run distributed prediction across all devices""" 196 | return asyncio.run(self._predict(inputs)) 197 | -------------------------------------------------------------------------------- /framework/mfl/worker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from datetime import datetime, timedelta, timezone 5 | from typing import Dict, List, Optional 6 | 7 | from dateutil.parser import parse 8 | from dotenv import load_dotenv 9 | from realtime._async.client import AsyncRealtimeClient 10 | from supabase import Client, create_client 11 | 12 | load_dotenv() 13 | SUPABASE_URL = os.getenv("SUPABASE_URL") 14 | SUPABASE_ANON_KEY = os.getenv("SUPABASE_ANON_KEY") 15 | TASK_TIMEOUT = 10 # seconds 16 | supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY) 17 | 18 | 19 | @dataclass 20 | class RequestConfig: 21 | modelJson: str 22 | weights: List[float] 23 | batchSize: int 24 | inputs: List[List[float]] = None 25 | inputShape: List[int] = None 26 | outputs: Optional[List[List[float]]] = None 27 | outputShape: Optional[List[int]] = None 28 | epochs: Optional[int] = None 29 | datasetsPerDevice: Optional[int] = None 30 | 31 | 32 | @dataclass 33 | class ResponseConfig: 34 | weights: List[List[float]] 35 | outputs: Optional[List[List[float]]] = None 36 | loss: Optional[float] = None 37 | 38 | 39 | @dataclass 40 | class Task: 41 | """ 42 | A class to manage tasks broadcasted to the mfl network. 
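    A task is created when a request row is inserted into the `task_requests`
    table, counts as completed once a matching `task_responses` row arrives and
    `response_data` is set, and is treated as expired if no response arrives
    within TASK_TIMEOUT seconds.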
43 | """ 44 | 45 | request_data: RequestConfig 46 | sent_at: datetime 47 | response_data: Optional[ResponseConfig] = None 48 | 49 | @property 50 | def is_completed(self) -> bool: 51 | return self.response_data is not None 52 | 53 | @property 54 | def is_expired(self) -> bool: 55 | return (not self.is_completed) and ( 56 | (datetime.now(timezone.utc) - self.sent_at).total_seconds() > TASK_TIMEOUT 57 | ) 58 | 59 | 60 | class TaskManager: 61 | """ 62 | A class to manage each consumer's collection of tasks 63 | """ 64 | 65 | def __init__(self): 66 | self.tasks: Dict[int, Task] = {} 67 | 68 | def __repr__(self) -> str: 69 | task_list = [] 70 | for task_id, task in self.tasks.items(): 71 | status = "complete" if task.is_completed else "incomplete" 72 | task_list.append(f"Task {task_id}: {status}") 73 | return "\n".join(task_list) 74 | 75 | def create_task( 76 | self, task_id: int, request_data: RequestConfig, sent_at: datetime 77 | ) -> None: 78 | if task_id in self.tasks: 79 | raise ValueError(f"Task {task_id} already exists.") 80 | self.tasks[task_id] = Task(request_data=request_data, sent_at=sent_at) 81 | 82 | def discard_task(self, task_id: int) -> None: 83 | del self.tasks[task_id] 84 | 85 | def log_completion(self, task_id: int, response_data: ResponseConfig) -> None: 86 | if task_id not in self.tasks: 87 | raise KeyError(f"Task {task_id} does not exist.") 88 | self.tasks[task_id].response_data = response_data 89 | 90 | @property 91 | def expired_tasks(self): 92 | return { 93 | task_id: task for task_id, task in self.tasks.items() if task.is_expired 94 | } 95 | 96 | @property 97 | def completed_tasks(self): 98 | return { 99 | task_id: task for task_id, task in self.tasks.items() if task.is_completed 100 | } 101 | 102 | @property 103 | def incomplete_tasks(self): 104 | return { 105 | task_id: task 106 | for task_id, task in self.tasks.items() 107 | if not task.is_completed 108 | } 109 | 110 | 111 | class Worker: 112 | """ 113 | A class to manage the worker's interactions with the mfl network 114 | """ 115 | 116 | def __init__(self, _id: int) -> None: 117 | self.id = _id 118 | self.task_manager = TaskManager() 119 | self.timeout = False 120 | 121 | def send_task( 122 | self, device_id: int, request_type: str, request_data: RequestConfig 123 | ) -> bool: 124 | """ 125 | Main method that sends a request to a given device 126 | """ 127 | try: 128 | response = ( 129 | supabase.table("task_requests") 130 | .insert( 131 | { 132 | "device_id": device_id, 133 | "request_type": request_type, 134 | "data": asdict(request_data), 135 | "consumer_id": self.id, 136 | } 137 | ) 138 | .execute() 139 | ) 140 | 141 | self.task_manager.create_task( 142 | task_id=response.data[0]["id"], 143 | request_data=request_data, 144 | sent_at=parse(response.data[0]["created_at"]), 145 | ) 146 | # print(f"Sent task {response.data[0]['id']}") 147 | return True 148 | except Exception as e: 149 | print(f"Error sending job request: {e}") 150 | return False 151 | 152 | async def _connect_to_realtime(self): 153 | """ 154 | We subscribe to realtime updates to ALL supabase tables and 155 | pass them to the centralized callback function 156 | """ 157 | client = AsyncRealtimeClient( 158 | f"{SUPABASE_URL}/realtime/v1", SUPABASE_ANON_KEY, auto_reconnect=False 159 | ) 160 | await client.connect() 161 | 162 | device_channel = client.channel(f"devices") 163 | task_completion_channel = client.channel(f"task_responses") 164 | 165 | self.available_devices = self.load_available_devices() 166 | 167 | await 
device_channel.on_postgres_changes( 168 | "UPDATE", 169 | schema="public", 170 | table='devices', 171 | callback=self._device_update_callback, 172 | ).subscribe() 173 | 174 | await task_completion_channel.on_postgres_changes( 175 | "INSERT", 176 | schema="public", 177 | table='task_responses', 178 | callback=self._task_update_callback, 179 | ).subscribe() 180 | 181 | self.listener = asyncio.create_task(client.listen()) 182 | 183 | async def run( 184 | self, request_configs: List[RequestConfig], request_type: str 185 | ) -> None: 186 | """Multi-device federated learning process""" 187 | assert request_type in ( 188 | "train", 189 | "evaluate", 190 | "predict", 191 | ), "Unsupported request type!" 192 | self.request_type = request_type 193 | self.request_configs = request_configs 194 | 195 | await self._connect_to_realtime() 196 | 197 | try: 198 | for device_id in self.available_devices: 199 | if self.request_configs: 200 | self.send_task( 201 | device_id=device_id, 202 | request_type=self.request_type, 203 | request_data=self.request_configs[0], 204 | ) 205 | self.request_configs.pop(0) 206 | 207 | while not self.timeout and self.task_manager.incomplete_tasks: 208 | await asyncio.sleep(0.01) 209 | await self._check_task_timeouts() 210 | 211 | except Exception as e: 212 | print(e) 213 | finally: 214 | # Cancel the listening task and disconnect cleanly 215 | self.listener.cancel() 216 | 217 | def load_available_devices(self) -> List[int]: 218 | "Retrieve devices available at any given point and store them in self.available_devices" 219 | try: 220 | response = ( 221 | supabase.table("devices") 222 | .select("id") 223 | .eq("status", "available") 224 | .gte( 225 | "last_updated", 226 | (datetime.now(timezone.utc) - timedelta(minutes=1)).isoformat(), 227 | ) 228 | .execute() 229 | ) 230 | return [device["id"] for device in response.data] 231 | except Exception as e: 232 | print("Unable to load devices (unexpected error): ", e) 233 | return [] 234 | 235 | async def _check_task_timeouts(self): 236 | """ 237 | Checks whether we have hit timeout 238 | """ 239 | if self.task_manager.expired_tasks: 240 | self.timeout = True 241 | 242 | def _task_update_callback(self, payload: Dict) -> None: 243 | """ 244 | Callback to handle responses from the tasks table. Here, we: 245 | 246 | - parse the incoming payload 247 | - log completion of the task 248 | - if there are request_configs remaining (happens when the number of requested configs 249 | was higher than the initial number of available devices), we send the device that 250 | completed a task its next request config. 251 | """ 252 | record = payload.get("data", {}).get("record") 253 | task_id = record.get("id") 254 | if task_id in self.task_manager.tasks: 255 | self.task_manager.log_completion( 256 | task_id=task_id, 257 | response_data=ResponseConfig(**record["data"]), 258 | ) 259 | else: 260 | print(f"Received task ID not found in my tasks: {task_id}") 261 | 262 | def _device_update_callback(self, payload: Dict) -> None: 263 | """ 264 | Callback to handle responses from the devices table. 
Here, we: 265 | 266 | - extract device ID and new status 267 | - add it to the list of available devices 268 | """ 269 | record = payload.get("data", {}).get("record") 270 | device_id = record.get("id") 271 | status = record.get("status") 272 | if status == "available" and device_id not in self.available_devices: 273 | self.available_devices.append(device_id) 274 | elif status == 'unavailable': 275 | if device_id in self.available_devices: 276 | self.available_devices.remove(device_id) 277 | else: 278 | print(f"received unavailability update for device id {device_id}??") -------------------------------------------------------------------------------- /framework/mfl/write_weights.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import math 4 | import os 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from .quantization import map_layers_to_quantization_dtype, quantize_weights 10 | from .read_weights import STRING_LENGTH_DTYPE 11 | 12 | _OUTPUT_DTYPES = [ 13 | np.float16, 14 | np.float32, 15 | np.int32, 16 | np.complex64, 17 | np.uint8, 18 | np.uint16, 19 | bool, 20 | object, 21 | ] 22 | _AUTO_DTYPE_CONVERSION = { 23 | np.dtype(np.float16): np.float32, 24 | np.dtype(np.float64): np.float32, 25 | np.dtype(np.int64): np.int32, 26 | np.dtype(np.complex128): np.complex64, 27 | } 28 | 29 | 30 | def write_weights( 31 | weight_groups, 32 | write_dir, 33 | shard_size_bytes=1024 * 1024 * 4, 34 | write_manifest=True, 35 | quantization_dtype_map=None, 36 | ): 37 | """Writes weights to a binary format on disk for ingestion by JavaScript. 38 | 39 | Weights are organized into groups. When writing to disk, the bytes from all 40 | weights in each group are concatenated together and then split into shards 41 | (default is 4MB). This means that large weights (> shard_size) get sharded 42 | and small weights (< shard_size) will be packed. If the bytes can't be split 43 | evenly into shards, there will be a leftover shard that is smaller than the 44 | shard size. 45 | 46 | Weights are optionally quantized to either 8 or 16 bits for compression, 47 | which is enabled via the `quantization_dtype_map`. 48 | 49 | Args: 50 | weight_groups: An list of groups. Each group is an array of weight 51 | entries. Each entry is a dict that maps a unique name to a numpy array, 52 | for example: 53 | entry = { 54 | 'name': 'weight1', 55 | 'data': np.array([1, 2, 3], 'float32') 56 | } 57 | 58 | Weights groups would then look like: 59 | weight_groups = [ 60 | [group_0_entry1, group_0_entry2], 61 | [group_1_entry1, group_1_entry2], 62 | ] 63 | 64 | The 'name' must be unique across all groups and all entries. The 'data' 65 | field must be a numpy ndarray. 66 | write_dir: A directory to write the files to. 67 | shard_size_bytes: The size of shards in bytes. Defaults to 4MB, which is 68 | the max file size for caching for all major browsers. 69 | write_manifest: Whether to write the manifest JSON to disk. Defaults to 70 | True. 71 | quantization_dtype_map: (Optional) A mapping from dtype 72 | (`uint8`, `uint16`, `float16`) to weights names. The weight mapping 73 | supports wildcard substitution. 74 | Returns: 75 | The weights manifest JSON dict. 
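      (When `write_manifest` is True, the same dict is also serialized to
      weights_manifest.json inside `write_dir`.)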
76 | 77 | An example manifest with 2 groups, 2 weights, and each weight sharded 78 | into 2: 79 | 80 | The manifest JSON looks like the following: 81 | [{ 82 | 'paths': ['group1-shard1of2', 'group1-shard2of2'], 83 | 'weights': [{ 84 | 'name': 'weight1', 85 | 'shape': [1000, 1000], 86 | 'dtype': 'float32' 87 | }] 88 | }, { 89 | 'paths': ['group2-shard1of2', 'group2-shard2of2'], 90 | 'weights': [{ 91 | 'name': 'weight2', 92 | 'shape': [2000, 2000], 93 | 'dtype': 'float32' 94 | }] 95 | }] 96 | or, if quantization is used: 97 | [{ 98 | 'paths': ['group1-shard1of2', 'group1-shard2of2'], 99 | 'weights': [{ 100 | 'name': 'weight1', 101 | 'shape': [1000, 1000], 102 | 'dtype': 'float32' 103 | 'quantization': {'min': -0.1, 'scale': 0.01, 'dtype': 'uint8'} 104 | }] 105 | }, { 106 | 'paths': ['group2-shard1of2', 'group2-shard2of2'], 107 | 'weights': [{ 108 | 'name': 'weight2', 109 | 'shape': [2000, 2000], 110 | 'dtype': 'float32', 111 | 'quantization': {'dtype': 'float16'} 112 | }] 113 | }] 114 | """ 115 | _assert_weight_groups_valid(weight_groups) 116 | _assert_shard_size_bytes_valid(shard_size_bytes) 117 | _assert_no_duplicate_weight_names(weight_groups) 118 | 119 | manifest = [] 120 | 121 | for group_index, group in enumerate(weight_groups): 122 | for e in group: 123 | _auto_convert_weight_entry(e) 124 | names = [entry["name"] for entry in group] 125 | quantization_dtype = map_layers_to_quantization_dtype( 126 | names, quantization_dtype_map 127 | ) 128 | 129 | group = [ 130 | ( 131 | _quantize_entry(e, quantization_dtype[e["name"]]) 132 | if e["name"] in quantization_dtype 133 | else e 134 | ) 135 | for e in group 136 | ] 137 | group_bytes, total_bytes, _ = _stack_group_bytes(group) 138 | 139 | shard_filenames = _shard_group_bytes_to_disk( 140 | write_dir, group_index, group_bytes, total_bytes, shard_size_bytes 141 | ) 142 | 143 | weights_entries = _get_weights_manifest_for_group(group) 144 | manifest_entry = {"paths": shard_filenames, "weights": weights_entries} 145 | manifest.append(manifest_entry) 146 | 147 | if write_manifest: 148 | manifest_path = os.path.join(write_dir, "weights_manifest.json") 149 | with tf.io.gfile.GFile(manifest_path, "wb") as f: 150 | f.write(json.dumps(manifest).encode()) 151 | 152 | return manifest 153 | 154 | 155 | def _quantize_entry(entry, quantization_dtype): 156 | """Quantizes the weights in the entry, returning a new entry. 157 | 158 | The weights are quantized by linearly re-scaling the values between the 159 | minimum and maximum value, and representing them with the number of bits 160 | provided by the `quantization_dtype`. 161 | 162 | In order to guarantee that 0 is perfectly represented by one of the quanzitzed 163 | values, the range is "nudged" in the same manner as in TF-Lite. 164 | 165 | Args: 166 | entry: A weight entries to quantize. 167 | quantization_dtype: An numpy dtype to quantize weights to. 168 | Only np.uint8, np.uint16, and np.float16 are supported. 169 | 170 | Returns: 171 | A new entry containing the quantized data and additional quantization info, 172 | for example: 173 | original_entry = { 174 | 'name': 'weight1', 175 | 'data': np.array([0, -0.1, 1.2], 'float32') 176 | } 177 | quantized_entry = { 178 | 'name': 'weight1', 179 | 'data': np.array([20, 0, 255], 'uint8') 180 | 'quantization': {'min': -0.10196078817, 'scale': 0.00509803940852, 181 | 'dtype': 'uint8', 'original_dtype': 'float32'} 182 | } 183 | """ 184 | data = entry["data"] 185 | # Only float32 tensors are quantized. 
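    # Entries with any other dtype (e.g. int32, bool or string tensors) are
    # returned unchanged and written out at full precision.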
186 | if data.dtype != "float32": 187 | return entry 188 | quantized_data, metadata = quantize_weights(data, quantization_dtype) 189 | metadata.update({"original_dtype": data.dtype.name}) 190 | quantized_entry = entry.copy() 191 | quantized_entry["data"] = quantized_data 192 | quantized_entry["quantization"] = metadata 193 | return quantized_entry 194 | 195 | 196 | def _serialize_string_array(data): 197 | """Serializes a numpy array of dtype `string` into bytes. 198 | 199 | Each string value is preceded by 4 bytes which denote a 32-bit unsigned 200 | integer in little endian that specifies the byte length of the following 201 | string. This is followed by the actual string bytes. If the tensor has no 202 | strings there will be no bytes reserved. Empty strings will still take 4 bytes 203 | for the length. 204 | 205 | For example, a tensor that has 2 strings will be encoded as 206 | [byte length of s1][bytes of s1...][byte length of s2][bytes of s2...] 207 | 208 | where byte length always takes 4 bytes. 209 | 210 | Args: 211 | data: A numpy array of dtype `string`. 212 | 213 | Returns: 214 | bytes of the entire string tensor to be serialized on disk. 215 | """ 216 | strings = data.flatten().tolist() 217 | 218 | string_bytes = io.BytesIO() 219 | bytes_writer = io.BufferedWriter(string_bytes) 220 | 221 | for x in strings: 222 | encoded = x if isinstance(x, bytes) else x.encode("utf-8") 223 | length_as_bytes = np.array(len(encoded), STRING_LENGTH_DTYPE).tobytes() 224 | bytes_writer.write(length_as_bytes) 225 | bytes_writer.write(encoded) 226 | bytes_writer.flush() 227 | string_bytes.seek(0) 228 | return string_bytes.read() 229 | 230 | 231 | def _serialize_numeric_array(data): 232 | """Serializes a numeric numpy array into bytes. 233 | 234 | Args: 235 | data: A numeric numpy array. 236 | 237 | Returns: 238 | bytes of the array to be serialized on disk. 239 | """ 240 | return data.tobytes() 241 | 242 | 243 | def _stack_group_bytes(group): 244 | """Stacks the bytes for a weight group into a flat byte array. 245 | 246 | Args: 247 | group: A list of weight entries. 248 | Returns: 249 | A type: (group_bytes, total_bytes, weights_entries, group_bytes_writer) 250 | group_bytes: The stacked bytes for the group, as a BytesIO() stream. 251 | total_bytes: A number representing the total size of the byte buffer. 252 | groups_bytes_writer: The io.BufferedWriter object. Returned so that 253 | group_bytes does not get garbage collected and closed. 254 | 255 | """ 256 | group_bytes = io.BytesIO() 257 | group_bytes_writer = io.BufferedWriter(group_bytes) 258 | total_bytes = 0 259 | 260 | for entry in group: 261 | _assert_valid_weight_entry(entry) 262 | data = entry["data"] 263 | 264 | if data.dtype == object: 265 | data_bytes = _serialize_string_array(data) 266 | else: 267 | data_bytes = _serialize_numeric_array(data) 268 | group_bytes_writer.write(data_bytes) 269 | total_bytes += len(data_bytes) 270 | 271 | group_bytes_writer.flush() 272 | group_bytes.seek(0) 273 | 274 | # NOTE: We must return the bytes writer here, otherwise it goes out of scope 275 | # and python closes the IO operation. 276 | return (group_bytes, total_bytes, group_bytes_writer) 277 | 278 | 279 | def _shard_group_bytes_to_disk( 280 | write_dir, group_index, group_bytes, total_bytes, shard_size_bytes 281 | ): 282 | """Shards the concatenated bytes for a group to disk. 283 | 284 | Args: 285 | write_dir: The directory to write the files to. 286 | group_index: The index for the group. 
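        (Shard filenames are 1-based, i.e. "group%d" below is formatted with
        group_index + 1, so group 0 produces files named group1-shard*of*.bin.)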
287 | group_bytes: An io.BytesIO() object representing the byte array. 288 | total_bytes: The total number of bytes of the stream. 289 | shard_size_bytes: The size of shards in bytes. If None, the whole byte 290 | array will be written as one shard. 291 | Returns: 292 | A list of filenames that were written to disk. 293 | """ 294 | if shard_size_bytes is None: 295 | shard_size_bytes = total_bytes 296 | 297 | num_shards = int(math.ceil(float(total_bytes) / shard_size_bytes)) 298 | 299 | filenames = [] 300 | for i in range(num_shards): 301 | shard = group_bytes.read(shard_size_bytes) 302 | 303 | filename = "group%d-shard%dof%d.bin" % (group_index + 1, i + 1, num_shards) 304 | filenames.append(filename) 305 | filepath = os.path.join(write_dir, filename) 306 | 307 | # Write the shard to disk. 308 | with tf.io.gfile.GFile(filepath, "wb") as f: 309 | f.write(shard) 310 | 311 | return filenames 312 | 313 | 314 | def _get_weights_manifest_for_group(group): 315 | """Gets the weights entries manifest JSON for a group. 316 | 317 | Args: 318 | group: A list of weight entries. 319 | Returns: 320 | An list of manifest entries (dicts) to be written in the weights manifest. 321 | """ 322 | weights_entries = [] 323 | for entry in group: 324 | is_quantized = "quantization" in entry 325 | dtype = ( 326 | entry["quantization"]["original_dtype"] 327 | if is_quantized 328 | else entry["data"].dtype.name 329 | ) 330 | var_manifest = { 331 | "name": entry["name"], 332 | "shape": list(entry["data"].shape), 333 | "dtype": dtype, 334 | } 335 | # String arrays have dtype 'object' and need extra metadata to parse. 336 | if dtype == "object": 337 | var_manifest["dtype"] = "string" 338 | if is_quantized: 339 | manifest = {"dtype": entry["data"].dtype.name} 340 | manifest.update(entry["quantization"]) 341 | var_manifest["quantization"] = manifest 342 | weights_entries.append(var_manifest) 343 | return weights_entries 344 | 345 | 346 | def _assert_no_duplicate_weight_names(weight_groups): 347 | weight_names = set() 348 | for group in weight_groups: 349 | for entry in group: 350 | name = entry["name"] 351 | if name in weight_names: 352 | raise Exception("Error dumping weights, duplicate weight name " + name) 353 | weight_names.add(name) 354 | 355 | 356 | def _auto_convert_weight_entry(entry): 357 | data = entry["data"] 358 | if data.dtype in _AUTO_DTYPE_CONVERSION: 359 | entry["data"] = data.astype(_AUTO_DTYPE_CONVERSION[data.dtype]) 360 | print( 361 | "weight " 362 | + entry["name"] 363 | + " with shape " 364 | + str(data.shape) 365 | + " and dtype " 366 | + data.dtype.name 367 | + " was auto converted to the type " 368 | + np.dtype(_AUTO_DTYPE_CONVERSION[data.dtype]).name 369 | ) 370 | 371 | 372 | def _assert_valid_weight_entry(entry): 373 | if "name" not in entry: 374 | raise ValueError("Error dumping weight, no name field found.") 375 | if "data" not in entry: 376 | raise ValueError("Error dumping weight, no data field found.") 377 | 378 | name = entry["name"] 379 | data = entry["data"] 380 | 381 | # String tensors can be backed by different numpy dtypes, thus we consolidate 382 | # to a single 'object' dtype. 383 | if data.dtype.name.startswith("str") or data.dtype.name.startswith("bytes"): 384 | data = data.astype(object) 385 | entry["data"] = data 386 | 387 | if not (data.dtype in _OUTPUT_DTYPES or data.dtype in _AUTO_DTYPE_CONVERSION): 388 | raise ValueError( 389 | "Error dumping weight " 390 | + name 391 | + ", dtype " 392 | + data.dtype.name 393 | + " not supported." 
394 | ) 395 | 396 | if not isinstance(data, np.ndarray): 397 | raise ValueError( 398 | "Error dumping weight " + name + ", data " + "must be a numpy ndarray." 399 | ) 400 | 401 | 402 | def _assert_weight_groups_valid(weight_groups): 403 | if not isinstance(weight_groups, list): 404 | raise Exception("weight_groups must be a list of groups") 405 | if not weight_groups: 406 | raise ValueError("weight_groups must have more than one list element") 407 | for i, weight_group in enumerate(weight_groups): 408 | if not isinstance(weight_group, list): 409 | raise ValueError( 410 | "weight_groups[" + i + "] must be a list of weight entries" 411 | ) 412 | for j, weights in enumerate(weight_group): 413 | if "name" not in weights: 414 | raise ValueError( 415 | "weight_groups[" + i + "][" + j + "] has no string field 'name'" 416 | ) 417 | if "data" not in weights: 418 | raise ValueError( 419 | "weight_groups[" 420 | + i 421 | + "][" 422 | + j 423 | + "] has no numpy " 424 | + "array field 'data'" 425 | ) 426 | if not isinstance(weights["data"], np.ndarray): 427 | raise ValueError( 428 | "weight_groups[" 429 | + i 430 | + "][" 431 | + j 432 | + "]['data'] is not a numpy " 433 | + "array" 434 | ) 435 | 436 | 437 | def _assert_shard_size_bytes_valid(shard_size_bytes): 438 | if shard_size_bytes <= 0: 439 | raise ValueError( 440 | "shard_size_bytes must be greater than 0, but got %s" % shard_size_bytes 441 | ) 442 | if not isinstance(shard_size_bytes, int): 443 | raise ValueError( 444 | "shard_size_bytes must be an integer, but got %s" % shard_size_bytes 445 | ) 446 | -------------------------------------------------------------------------------- /framework/requirements.txt: -------------------------------------------------------------------------------- 1 | importlib_resources>=5.9.0 2 | tensorflow>=2.13.0,<3 3 | tf-keras>=2.16.0 4 | tensorflow-decision-forests>=1.9.0 5 | six>=1.16.0,<2 6 | tensorflow-hub>=0.16.1 7 | packaging~=23.1 8 | supabase==2.10.0 9 | realtime==2.0.0 10 | python-dotenv -------------------------------------------------------------------------------- /framework/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="mfl", 5 | version="0.0.0", 6 | author="Henry Ndubuaku, Roman Shemet, and James Unsworth", 7 | description="A framework for distributing ML to mobile devices", 8 | long_description=open("./README.md").read(), 9 | long_description_content_type="text/markdown", 10 | url="https://github.com/HMUNACHI", 11 | packages=find_packages(), 12 | install_requires=[ 13 | "importlib_resources>=5.9.0", 14 | "tensorflow>=2.13.0,<3", 15 | "tf-keras>=2.16.0", 16 | "tensorflow-decision-forests>=1.9.0", 17 | "six>=1.16.0,<2", 18 | "tensorflow-hub>=0.16.1", 19 | "packaging~=23.1", 20 | "supabase==2.10.0", 21 | "realtime==2.0.0", 22 | ], 23 | classifiers=[ 24 | "Development Status :: 3 - Alpha", 25 | "Intended Audience :: Developers", 26 | "Intended Audience :: Science/Research", 27 | "Topic :: Software Development :: Build Tools", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Programming Language :: Python :: 3.9", 30 | ], 31 | python_requires=">=3.9", 32 | ) 33 | --------------------------------------------------------------------------------
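A minimal consumer-side sketch of the Trainer API from framework/mfl/trainer.py above (hypothetical model and data; see also example/job.py in the repository):

    import numpy as np
    import tf_keras as keras

    from mfl.trainer import Trainer

    # Assumes SUPABASE_URL / SUPABASE_ANON_KEY are set (see framework/mfl/worker.py)
    # and that at least one phone is registered as "available" in the devices table.
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="sgd", loss="mse")

    inputs = np.random.rand(64, 4).astype("float32")
    outputs = np.random.rand(64, 1).astype("float32")

    trainer = Trainer(model, inputs, outputs, batch_size=8)
    trainer.fit(epochs=3)              # federated fit; device weights are averaged each epoch
    predictions = trainer.predict(inputs)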