├── .github ├── FUNDING.yml └── workflows │ ├── python-publish.yml │ └── unit-tests.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── assets └── logo.jpg ├── docs ├── api.md ├── examples │ ├── diffusion_example.ipynb │ ├── gat_example.ipynb │ ├── gpt4_example.ipynb │ ├── llama_example.ipynb │ ├── mistral_copy_example.ipynb │ ├── sklearn_gpu_example.ipynb │ ├── utils_examples.ipynb │ └── whisper_example.ipynb ├── index.md ├── requirements.in ├── requirements.txt └── usage.md ├── mkdocs.yml ├── nanodl ├── __init__.py └── __src │ ├── __init__.py │ ├── classical │ ├── __init__.py │ ├── bayes.py │ ├── clustering.py │ ├── dimensionality_reduction.py │ ├── dsp.py │ └── regression.py │ ├── experimental │ ├── __init__.py │ ├── bitlinear.py │ ├── gat.py │ ├── mamba.py │ ├── rlhf.py │ └── tokenizer.py │ ├── models │ ├── __init__.py │ ├── attention.py │ ├── clip.py │ ├── diffusion.py │ ├── gemma.py │ ├── gpt.py │ ├── ijepa.py │ ├── kan.py │ ├── lamda.py │ ├── llama.py │ ├── mistral.py │ ├── mixer.py │ ├── reward.py │ ├── t5.py │ ├── transformer.py │ ├── vit.py │ └── whisper.py │ └── utils │ ├── __init__.py │ ├── data.py │ ├── ml.py │ ├── nlp.py │ ├── random.py │ └── vision.py ├── requirements.txt ├── setup.py └── tests ├── files └── sample.txt ├── test_classic.py ├── test_kan.py ├── test_models.py ├── test_random.py └── test_utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [HMUNACHI] 2 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: __token__ 23 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 24 | run: | 25 | echo "Setting version to: ${GITHUB_REF#refs/tags/}" 26 | sed -i "s/version='.*',/version='${GITHUB_REF#refs/tags/}',/" setup.py 27 | sed -i "s/__version__ = '.*'/__version__ = '${GITHUB_REF#refs/tags/}'/" nanodl/__init__.py 28 | python setup.py sdist bdist_wheel 29 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/unit-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python Unit Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.10' 16 | - name: Install dependencies 17 | run: | 18 | pip install -r requirements.txt 19 | - name: Run tests 20 | run: | 21 | python -m unittest discover -s tests 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore compiled object files, libraries, and executables 2 | *.o 3 | *.a 4 | *.exe 5 | 6 | # Ignore system-specific files 7 | .DS_Store 8 | Thumbs.db 9 | 10 | # Ignore logs and temporary files 11 | *.log 12 | *.tmp 
13 | 14 | # Ignore build directories 15 | /build/ 16 | /dist/ 17 | /node_modules/ 18 | /.venv/ 19 | /.vscode/ 20 | /nanodl.egg-info/ 21 | __pycache__/ 22 | 23 | # Ignore configuration files with sensitive information 24 | config.ini 25 | secrets.yaml 26 | params.pkl 27 | base_params.pkl 28 | reward_params.pkl 29 | 30 | # Ignore user-specific files 31 | /userdata/ 32 | /user-settings.json 33 | 34 | # ignore specific files 35 | /archive/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.10" 7 | 8 | mkdocs: 9 | configuration: mkdocs.yml 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Henry Ndubuaku 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | <img src="assets/logo.jpg" alt="Alt text"> 3 |

4 | 5 | # A Jax-based library for designing and training transformer models from scratch. 6 | 7 | ![License](https://img.shields.io/github/license/hmunachi/nanodl?style=flat-square) [![Read the Docs](https://img.shields.io/readthedocs/nanodl?labelColor=blue&color=white)](https://nanodl.readthedocs.io) [![Discord](https://img.shields.io/discord/1222217369816928286?style=social&logo=discord&label=Discord&color=white)](https://discord.gg/3u9vumJEmz) [![LinkedIn](https://img.shields.io/badge/-LinkedIn-blue?style=flat-square&logo=linkedin&logoColor=white)](https://www.linkedin.com//company/80434055) [![Twitter](https://img.shields.io/twitter/follow/hmunachii?style=social)](https://twitter.com/hmunachii) 8 | 9 | Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/) (Discord & Docs badges are clickable) 10 | 11 | N.B.: Code is implemented pedagogically at the expense of repetition. 12 | Each model is purposefully contained in a single file without inter-file dependencies. 13 | 14 | ## Overview 15 | Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks and abstracts distributed training, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: 16 | 17 | - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. 18 | - An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc. 19 | - Data-parallel distributed trainers for training models on multiple GPUs or TPUs, without the need for manual training loops. 20 | - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective. 21 | - Layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. 22 | - GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc. 23 | - True random number generators in Jax which do not need the verbose code (examples are shown in the sections below). 24 | - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc. 25 | - Each model is contained in a single file with no external dependencies, so the source code can also be easily used. 26 | 27 | 28 | There are experimental and/or unfinished features (like MAMBA, KAN, BitNet, GAT and RLHF) 29 | which are not yet available via the package, but can be copied from this repo. 30 | Feedback on any of our discussion, issue and pull request threads is welcome! 31 | Please report any feature requests, issues, questions or concerns in the [Discord](https://discord.gg/3u9vumJEmz), 32 | or just let us know what you're working on! 33 | 34 | ## Quick install 35 | 36 | You will need Python 3.9 or later, plus working [JAX](https://github.com/google/jax/blob/main/README.md), 37 | [FLAX](https://github.com/google/flax/blob/main/README.md) 38 | and [OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md) 39 | installations (with GPU support if you want to run training; without it, you can still create and test models).
40 | Models can be designed and tested on CPUs but trainers are all Distributed Data-Parallel 41 | which would require a GPU with 1 to N GPUS/TPUS. For CPU-only version of JAX: 42 | 43 | ``` 44 | pip install --upgrade pip # To support manylinux2010 wheels. 45 | pip install jax flax optax 46 | ``` 47 | 48 | Then, install nanodl from PyPi: 49 | 50 | ``` 51 | pip install nanodl 52 | ``` 53 | 54 | ## What does nanodl look like? 55 | 56 | We provide various example usages of the nanodl API. 57 | 58 | ```py 59 | import jax 60 | import nanodl 61 | import jax.numpy as jnp 62 | from nanodl import ArrayDataset, DataLoader 63 | from nanodl import GPT4, GPTDataParallelTrainer 64 | 65 | # Preparing your dataset 66 | batch_size = 8 67 | max_length = 50 68 | vocab_size = 1000 69 | 70 | # Create random data 71 | data = nanodl.uniform( 72 | shape=(batch_size, max_length), 73 | minval=0, maxval=vocab_size-1 74 | ).astype(jnp.int32) 75 | 76 | # Shift to create next-token prediction dataset 77 | dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:] 78 | 79 | # Create dataset and dataloader 80 | dataset = ArrayDataset(dummy_inputs, dummy_targets) 81 | dataloader = DataLoader( 82 | dataset, batch_size=batch_size, shuffle=True, drop_last=False 83 | ) 84 | 85 | # model parameters 86 | hyperparams = { 87 | 'num_layers': 1, 88 | 'hidden_dim': 256, 89 | 'num_heads': 2, 90 | 'feedforward_dim': 256, 91 | 'dropout': 0.1, 92 | 'vocab_size': vocab_size, 93 | 'embed_dim': 256, 94 | 'max_length': max_length, 95 | 'start_token': 0, 96 | 'end_token': 50, 97 | } 98 | 99 | # Inferred GPT4 model 100 | model = GPT4(**hyperparams) 101 | 102 | trainer = GPTDataParallelTrainer( 103 | model, dummy_inputs.shape, 'params.pkl' 104 | ) 105 | 106 | trainer.train( 107 | train_loader=dataloader, num_epochs=100, val_loader=dataloader 108 | ) # use actual val data 109 | 110 | # Generating from a start token 111 | start_tokens = jnp.array([[123, 456]]) 112 | 113 | # Remember to load the trained parameters 114 | params = trainer.load_params('params.pkl') 115 | 116 | outputs = model.apply( 117 | {'params': params}, 118 | start_tokens, 119 | rngs={'dropout': nanodl.time_rng_key()}, 120 | method=model.generate 121 | ) 122 | ``` 123 | 124 | Vision example 125 | 126 | ```py 127 | import nanodl 128 | import jax.numpy as jnp 129 | from nanodl import ArrayDataset, DataLoader 130 | from nanodl import DiffusionModel, DiffusionDataParallelTrainer 131 | 132 | image_size = 32 133 | block_depth = 2 134 | batch_size = 8 135 | widths = [32, 64, 128] 136 | input_shape = (101, image_size, image_size, 3) 137 | images = nanodl.normal(shape=input_shape) 138 | 139 | # Use your own images 140 | dataset = ArrayDataset(images) 141 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) 142 | 143 | # Create diffusion model 144 | diffusion_model = DiffusionModel(image_size, widths, block_depth) 145 | 146 | # Training on your data 147 | trainer = DiffusionDataParallelTrainer(diffusion_model, 148 | input_shape=images.shape, 149 | weights_filename='params.pkl', 150 | learning_rate=1e-4) 151 | 152 | trainer.train(dataloader, 10) 153 | 154 | # Generate some samples: Each model is a Flax.linen module 155 | # Use as you normally would 156 | params = trainer.load_params('params.pkl') 157 | generated_images = diffusion_model.apply({'params': params}, 158 | num_images=5, 159 | diffusion_steps=5, 160 | method=diffusion_model.generate) 161 | ``` 162 | 163 | Audio example 164 | 165 | ```py 166 | import jax 167 | import jax.numpy as jnp 168 | from 
nanodl import ArrayDataset, DataLoader 169 | from nanodl import Whisper, WhisperDataParallelTrainer 170 | 171 | # Dummy data parameters 172 | batch_size = 8 173 | max_length = 50 174 | embed_dim = 256 175 | vocab_size = 1000 176 | 177 | # Generate data: replace with actual tokenised/quantised data 178 | dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32) 179 | dummy_inputs = jnp.ones((101, max_length, embed_dim)) 180 | 181 | dataset = ArrayDataset(dummy_inputs, dummy_targets) 182 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) 183 | 184 | # model parameters 185 | hyperparams = { 186 | 'num_layers': 1, 187 | 'hidden_dim': 256, 188 | 'num_heads': 2, 189 | 'feedforward_dim': 256, 190 | 'dropout': 0.1, 191 | 'vocab_size': 1000, 192 | 'embed_dim': embed_dim, 193 | 'max_length': max_length, 194 | 'start_token': 0, 195 | 'end_token': 50, 196 | } 197 | 198 | # Initialize model 199 | model = Whisper(**hyperparams) 200 | 201 | # Training on your data 202 | trainer = WhisperDataParallelTrainer(model, 203 | dummy_inputs.shape, 204 | dummy_targets.shape, 205 | 'params.pkl') 206 | 207 | trainer.train(dataloader, 2, dataloader) 208 | 209 | # Sample inference 210 | params = trainer.load_params('params.pkl') 211 | 212 | # for more than one sample, often use model.generate_batch 213 | transcripts = model.apply({'params': params}, 214 | dummy_inputs[:1], 215 | method=model.generate) 216 | ``` 217 | 218 | Reward Model example for RLHF 219 | 220 | ```py 221 | import nanodl 222 | import jax.numpy as jnp 223 | from nanodl import ArrayDataset, DataLoader 224 | from nanodl import Mistral, RewardModel, RewardDataParallelTrainer 225 | 226 | # Generate dummy data 227 | batch_size = 8 228 | max_length = 10 229 | 230 | # Replace with actual tokenised data 231 | dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) 232 | dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) 233 | 234 | # Create dataset and dataloader 235 | dataset = ArrayDataset(dummy_chosen, dummy_rejected) 236 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) 237 | 238 | # model parameters 239 | hyperparams = { 240 | 'num_layers': 1, 241 | 'hidden_dim': 256, 242 | 'num_heads': 2, 243 | 'feedforward_dim': 256, 244 | 'dropout': 0.1, 245 | 'vocab_size': 1000, 246 | 'embed_dim': 256, 247 | 'max_length': max_length, 248 | 'start_token': 0, 249 | 'end_token': 50, 250 | 'num_groups': 2, 251 | 'window_size': 5, 252 | 'shift_size': 2 253 | } 254 | 255 | # Initialize reward model from Mistral 256 | model = Mistral(**hyperparams) 257 | reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) 258 | 259 | # Train the reward model 260 | trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') 261 | trainer.train(dataloader, 5, dataloader) 262 | params = trainer.load_params('reward_model_weights.pkl') 263 | 264 | # Call as you would a regular Flax model 265 | rewards = reward_model.apply({'params': params}, 266 | dummy_chosen, 267 | rngs={'dropout': nanodl.time_rng_key()}) 268 | ``` 269 | 270 | PCA example 271 | 272 | ```py 273 | import nanodl 274 | from nanodl import PCA 275 | 276 | # Use actual data 277 | data = nanodl.normal(shape=(1000, 10)) 278 | 279 | # Initialise and train PCA model 280 | pca = PCA(n_components=2) 281 | pca.fit(data) 282 | 283 | # Get PCA transforms 284 | transformed_data = pca.transform(data) 285 | 286 | # Get reverse transforms 287 | original_data = 
pca.inverse_transform(transformed_data) 288 | 289 | # Sample from the distribution 290 | X_sampled = pca.sample(n_samples=1000, key=None) 291 | ``` 292 | 293 | This is still in dev, works great but roughness is expected, and contributions are therefore highly encouraged! 294 | 295 | - Make your changes without changing the design patterns. 296 | - Write tests for your changes if necessary. 297 | - Install locally with `pip3 install -e .`. 298 | - Run tests with `python3 -m unittest discover -s tests`. 299 | - Then submit a pull request. 300 | 301 | Contributions can be made in various forms: 302 | 303 | - Writing documentation. 304 | - Fixing bugs. 305 | - Implementing papers. 306 | - Writing high-coverage tests. 307 | - Optimizing existing codes. 308 | - Experimenting and submitting real-world examples to the examples section. 309 | - Reporting bugs. 310 | - Responding to reported issues. 311 | 312 | Join the [Discord Server](https://discord.gg/3u9vumJEmz) for more. 313 | 314 | ## Sponsorships 315 | 316 | The name "NanoDL" stands for Nano Deep Learning. Models are exploding in size, therefore gate-keeping 317 | experts and companies with limited resources from building flexible models without prohibitive costs. 318 | Following the success of Phi models, the long-term goal is to build and train nano versions of all available models, 319 | while ensuring they compete with the original models in performance, with total 320 | number of parameters not exceeding 1B. Trained weights will be made available via this library. 321 | Any form of sponsorship, funding will help with training resources. 322 | You can either sponsor via GitHub [here](https://github.com/sponsors/HMUNACHI) or reach out via ndubuakuhenry@gmail.com. 323 | 324 | ## Citing nanodl 325 | 326 | To cite this repository: 327 | 328 | ``` 329 | @software{nanodl2024github, 330 | author = {Henry Ndubuaku}, 331 | title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, 332 | url = {http://github.com/hmunachi/nanodl}, 333 | year = {2024}, 334 | } 335 | ``` 336 | -------------------------------------------------------------------------------- /assets/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/assets/logo.jpg -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # Welcome to NanoDL Documentation 2 | 3 | ## API Reference 4 | 5 | ::: nanodl.GAT 6 | ::: nanodl.GraphAttentionLayer 7 | ::: nanodl.T5 8 | ::: nanodl.T5DataParallelTrainer 9 | ::: nanodl.T5Encoder 10 | ::: nanodl.T5Decoder 11 | ::: nanodl.T5EncoderBlock 12 | ::: nanodl.T5DecoderBlock 13 | ::: nanodl.ViT 14 | ::: nanodl.ViTDataParallelTrainer 15 | ::: nanodl.ViTBlock 16 | ::: nanodl.ViTEncoder 17 | ::: nanodl.PatchEmbedding 18 | ::: nanodl.Transformer 19 | ::: nanodl.TransformerDataParallelTrainer 20 | ::: nanodl.TransformerEncoder 21 | ::: nanodl.TransformerDecoderBlock 22 | ::: nanodl.PositionalEncoding 23 | ::: nanodl.TokenAndPositionEmbedding 24 | ::: nanodl.MultiHeadAttention 25 | ::: nanodl.AddNorm 26 | ::: nanodl.CLIP 27 | ::: nanodl.CLIPDataParallelTrainer 28 | ::: nanodl.ImageEncoder 29 | ::: nanodl.TextEncoder 30 | ::: nanodl.SelfMultiHeadAttention 31 | ::: nanodl.LaMDA 32 | ::: nanodl.LaMDADataParallelTrainer 33 | ::: nanodl.LaMDABlock 34 | ::: nanodl.LaMDADecoder 35 | 
::: nanodl.RelativeMultiHeadAttention 36 | ::: nanodl.DiffusionModel 37 | ::: nanodl.DiffusionDataParallelTrainer 38 | ::: nanodl.UNet 39 | ::: nanodl.UNetDownBlock 40 | ::: nanodl.UNetUpBlock 41 | ::: nanodl.UNetResidualBlock 42 | ::: nanodl.GPT3 43 | ::: nanodl.GPT4 44 | ::: nanodl.GPTDataParallelTrainer 45 | ::: nanodl.GPT3Block 46 | ::: nanodl.GPT4Block 47 | ::: nanodl.GPT3Decoder 48 | ::: nanodl.GPT4Decoder 49 | ::: nanodl.PositionWiseFFN 50 | ::: nanodl.LlaMA2 51 | ::: nanodl.LlaMADataParallelTrainer 52 | ::: nanodl.RotaryPositionalEncoding 53 | ::: nanodl.LlaMA2Decoder 54 | ::: nanodl.LlaMA2DecoderBlock 55 | ::: nanodl.GroupedRotaryMultiHeadAttention 56 | ::: nanodl.Mistral 57 | ::: nanodl.MistralDataParallelTrainer 58 | ::: nanodl.MistralDecoder 59 | ::: nanodl.MistralDecoderBlock 60 | ::: nanodl.GroupedRotaryShiftedWindowMultiHeadAttention 61 | ::: nanodl.Mixtral 62 | ::: nanodl.MixtralDecoder 63 | ::: nanodl.MixtralDecoderBlock 64 | ::: nanodl.Whisper 65 | ::: nanodl.WhisperDataParallelTrainer 66 | ::: nanodl.WhisperSpeechEncoder 67 | ::: nanodl.WhisperSpeechEncoderBlock 68 | ::: nanodl.GAT 69 | ::: nanodl.GraphAttentionLayer 70 | ::: nanodl.NaiveBayesClassifier 71 | ::: nanodl.LinearRegression 72 | ::: nanodl.LogisticRegression 73 | ::: nanodl.GaussianProcess 74 | ::: nanodl.KMeans 75 | ::: nanodl.GaussianMixtureModel 76 | ::: nanodl.PCA 77 | ::: nanodl.Dataset 78 | ::: nanodl.ArrayDataset 79 | ::: nanodl.DataLoader 80 | ::: nanodl.batch_cosine_similarities 81 | ::: nanodl.batch_pearsonr 82 | ::: nanodl.classification_scores 83 | ::: nanodl.count_parameters 84 | ::: nanodl.entropy 85 | ::: nanodl.gini_impurity 86 | ::: nanodl.hamming 87 | ::: nanodl.jaccard 88 | ::: nanodl.kl_divergence 89 | ::: nanodl.mean_reciprocal_rank 90 | ::: nanodl.zero_pad_sequences 91 | ::: nanodl.bleu 92 | ::: nanodl.cider_score 93 | ::: nanodl.meteor 94 | ::: nanodl.perplexity 95 | ::: nanodl.rouge 96 | ::: nanodl.word_error_rate 97 | ::: nanodl.adjust_brightness 98 | ::: nanodl.adjust_contrast 99 | ::: nanodl.flip_image 100 | ::: nanodl.gaussian_blur 101 | ::: nanodl.normalize_images 102 | ::: nanodl.random_crop 103 | ::: nanodl.random_flip_image 104 | ::: nanodl.sobel_edge_detection 105 | -------------------------------------------------------------------------------- /docs/examples/diffusion_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "from nanodl import ArrayDataset, DataLoader\n", 12 | "from nanodl import DiffusionModel, DiffusionDataParallelTrainer\n", 13 | "\n", 14 | "image_size = 32\n", 15 | "block_depth = 2\n", 16 | "batch_size = 8\n", 17 | "widths = [32, 64, 128]\n", 18 | "key = jax.random.PRNGKey(0)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Use actual images\n", 28 | "images = jnp.ones((101, image_size, image_size, 3))\n", 29 | "dataset = ArrayDataset(images) \n", 30 | "dataloader = DataLoader(dataset, \n", 31 | " batch_size=batch_size, \n", 32 | " shuffle=True, \n", 33 | " drop_last=False) " 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "(101, 32, 32, 3) (101, 32, 32, 3)\n" 46 | ] 47 | } 48 | ], 49 | 
"source": [ 50 | "# Create diffusion model\n", 51 | "diffusion_model = DiffusionModel(image_size, widths, block_depth)\n", 52 | "params = diffusion_model.init(key, images)\n", 53 | "pred_noises, pred_images = diffusion_model.apply(params, images)\n", 54 | "print(pred_noises.shape, pred_images.shape)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Number of parameters: 1007395\n", 67 | "Number of accelerators: 1\n", 68 | "Epoch 1, Train Loss: 7.979574203491211\n", 69 | "Epoch 1, Val Loss: 24.317951202392578\n", 70 | "New best validation score achieved, saving model...\n", 71 | "Epoch 2, Train Loss: 7.75728702545166\n", 72 | "Epoch 2, Val Loss: 23.518024444580078\n", 73 | "New best validation score achieved, saving model...\n", 74 | "Epoch 3, Train Loss: 7.392527103424072\n", 75 | "Epoch 3, Val Loss: 22.308382034301758\n", 76 | "New best validation score achieved, saving model...\n", 77 | "Epoch 4, Train Loss: 6.846263408660889\n", 78 | "Epoch 4, Val Loss: 20.62131690979004\n", 79 | "New best validation score achieved, saving model...\n", 80 | "Epoch 5, Train Loss: 6.1358747482299805\n", 81 | "Epoch 5, Val Loss: 18.36245346069336\n", 82 | "New best validation score achieved, saving model...\n", 83 | "Epoch 6, Train Loss: 5.278435230255127\n", 84 | "Epoch 6, Val Loss: 15.812017440795898\n", 85 | "New best validation score achieved, saving model...\n", 86 | "Epoch 7, Train Loss: 4.328006267547607\n", 87 | "Epoch 7, Val Loss: 13.123092651367188\n", 88 | "New best validation score achieved, saving model...\n", 89 | "Epoch 8, Train Loss: 3.3344056606292725\n", 90 | "Epoch 8, Val Loss: 10.264131546020508\n", 91 | "New best validation score achieved, saving model...\n", 92 | "Epoch 9, Train Loss: 2.401970386505127\n", 93 | "Epoch 9, Val Loss: 7.67496919631958\n", 94 | "New best validation score achieved, saving model...\n", 95 | "Epoch 10, Train Loss: 1.6279072761535645\n", 96 | "Epoch 10, Val Loss: 5.5517578125\n", 97 | "New best validation score achieved, saving model...\n", 98 | "5.551758\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "# Training on your data\n", 104 | "# Note: saved params are often different from training weights, use the saved params for generation\n", 105 | "trainer = DiffusionDataParallelTrainer(diffusion_model, \n", 106 | " input_shape=images.shape, \n", 107 | " weights_filename='params.pkl', \n", 108 | " learning_rate=1e-4)\n", 109 | "trainer.train(dataloader, 10, dataloader)\n", 110 | "print(trainer.evaluate(dataloader))" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "(5, 32, 32, 3)\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "# Generate some samples\n", 128 | "params = trainer.load_params('params.pkl')\n", 129 | "generated_images = diffusion_model.apply({'params': params}, \n", 130 | " num_images=5, \n", 131 | " diffusion_steps=5, \n", 132 | " method=diffusion_model.generate)\n", 133 | "print(generated_images.shape)" 134 | ] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | 
"name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.10.0" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 2 158 | } 159 | -------------------------------------------------------------------------------- /docs/examples/gat_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Output shape: (10, 3)\n", 13 | "Output sample: [[-0.77854604 0.63543594 0.19899184]\n", 14 | " [-0.3826082 0.5630409 0.01140817]]\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import jax\n", 20 | "from nanodl import GAT\n", 21 | "\n", 22 | "# Generate a random key for Jax\n", 23 | "key = jax.random.PRNGKey(0)\n", 24 | "\n", 25 | "# Create dummy input data\n", 26 | "num_nodes = 10\n", 27 | "num_features = 5\n", 28 | "x = jax.random.normal(key, (num_nodes, num_features)) # Features for each node\n", 29 | "adj = jax.random.bernoulli(key, 0.3, (num_nodes, num_nodes)) # Random adjacency matrix\n", 30 | "\n", 31 | "# Initialize the GAT model\n", 32 | "model = GAT(nfeat=num_features,\n", 33 | " nhid=8, \n", 34 | " nclass=3, \n", 35 | " dropout_rate=0.5, \n", 36 | " alpha=0.2, \n", 37 | " nheads=3)\n", 38 | "\n", 39 | "# Initialize the model parameters\n", 40 | "params = model.init(key, x, adj)\n", 41 | "output = model.apply(params, x, adj)\n", 42 | "\n", 43 | "# Print the output shape and a sample of the output\n", 44 | "print(\"Output shape:\", output.shape)\n", 45 | "print(\"Output sample:\", output[:2])" 46 | ] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 3 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython3", 65 | "version": "3.10.0" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 2 70 | } 71 | -------------------------------------------------------------------------------- /docs/examples/gpt4_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "from nanodl import ArrayDataset, DataLoader\n", 12 | "from nanodl import GPT4, GPTDataParallelTrainer\n", 13 | "\n", 14 | "# Generate dummy data\n", 15 | "batch_size = 8\n", 16 | "max_length = 9" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "(8, 9) (8, 9)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "# Replace with actual tokenised data\n", 34 | "data = jnp.ones((101, max_length+1), dtype=jnp.int32)\n", 35 | "\n", 36 | "# Shift to create next-token prediction dataset\n", 37 | "dummy_inputs = data[:, :-1]\n", 38 | "dummy_targets = data[:, 1:]\n", 39 | "\n", 40 | "# Create dataset and dataloader\n", 41 | "dataset = ArrayDataset(dummy_inputs, dummy_targets)\n", 42 | "dataloader = DataLoader(dataset, \n", 43 | " batch_size=batch_size, \n", 44 | " shuffle=True, \n", 45 | " 
drop_last=False)\n", 46 | "\n", 47 | "# How to loop through dataloader\n", 48 | "for batch in dataloader:\n", 49 | " x, y = batch\n", 50 | " print(x.shape, y.shape)\n", 51 | " break" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "(101, 9, 1000)\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# model parameters\n", 69 | "hyperparams = {\n", 70 | " 'num_layers': 1,\n", 71 | " 'hidden_dim': 256,\n", 72 | " 'num_heads': 2,\n", 73 | " 'feedforward_dim': 256,\n", 74 | " 'dropout': 0.1,\n", 75 | " 'vocab_size': 1000,\n", 76 | " 'embed_dim': 256,\n", 77 | " 'max_length': max_length,\n", 78 | " 'start_token': 0,\n", 79 | " 'end_token': 50,\n", 80 | "}\n", 81 | "\n", 82 | "# Initialize model\n", 83 | "model = GPT4(**hyperparams)\n", 84 | "rngs = jax.random.PRNGKey(0)\n", 85 | "rngs, dropout_rng = jax.random.split(rngs)\n", 86 | "params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']\n", 87 | "\n", 88 | "# Call as you would a Jax/Flax model\n", 89 | "outputs = model.apply({'params': params}, \n", 90 | " dummy_inputs, \n", 91 | " rngs={'dropout': dropout_rng})\n", 92 | "print(outputs.shape)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "Number of parameters: 3740914\n", 105 | "Number of accelerators: 1\n", 106 | "Epoch 1, Train Loss: 6.438497066497803\n", 107 | "Epoch 1, Val Loss: 5.959759712219238\n", 108 | "New best validation score achieved, saving model...\n", 109 | "5.9597597\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# Training on data\n", 115 | "trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')\n", 116 | "trainer.train(train_loader=dataloader, \n", 117 | " num_epochs=1, \n", 118 | " val_loader=dataloader) # Use actual validation data\n", 119 | "\n", 120 | "print(trainer.evaluate(dataloader))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "[639 742 45 840 381 555 251 814 478 261]\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "# Generating from a start token\n", 138 | "start_tokens = jnp.array([[123, 456]])\n", 139 | "\n", 140 | "# Remember to load the trained parameters \n", 141 | "params = trainer.load_params('params.pkl')\n", 142 | "outputs = model.apply({'params': params},\n", 143 | " start_tokens,\n", 144 | " rngs={'dropout': jax.random.PRNGKey(2)}, \n", 145 | " method=model.generate)\n", 146 | "print(outputs) " 147 | ] 148 | } 149 | ], 150 | "metadata": { 151 | "kernelspec": { 152 | "display_name": "Python 3", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | "file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.10.0" 167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 2 171 | } 172 | -------------------------------------------------------------------------------- /docs/examples/llama_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "from nanodl import ArrayDataset, DataLoader\n", 12 | "from nanodl import LlaMA2, LlaMADataParallelTrainer\n", 13 | "\n", 14 | "# Generate dummy data\n", 15 | "batch_size = 8\n", 16 | "max_length = 10" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "(8, 10) (8, 10)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "# Replace with actual tokenised data\n", 34 | "data = jnp.ones((101, max_length+1), dtype=jnp.int32)\n", 35 | "\n", 36 | "# Shift to create next-token prediction dataset\n", 37 | "dummy_inputs = data[:, :-1]\n", 38 | "dummy_targets = data[:, 1:]\n", 39 | "\n", 40 | "# Create dataset and dataloader\n", 41 | "dataset = ArrayDataset(dummy_inputs, dummy_targets)\n", 42 | "dataloader = DataLoader(dataset, \n", 43 | " batch_size=batch_size, \n", 44 | " shuffle=True, \n", 45 | " drop_last=False)\n", 46 | "\n", 47 | "# How to loop through dataloader\n", 48 | "for batch in dataloader:\n", 49 | " x, y = batch\n", 50 | " print(x.shape, y.shape)\n", 51 | " break" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "(101, 10, 1000)\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# model parameters\n", 69 | "hyperparams = {\n", 70 | " 'num_layers': 1,\n", 71 | " 'hidden_dim': 256,\n", 72 | " 'num_heads': 2,\n", 73 | " 'feedforward_dim': 256,\n", 74 | " 'dropout': 0.1,\n", 75 | " 'vocab_size': 1000,\n", 76 | " 'embed_dim': 256,\n", 77 | " 'max_length': max_length,\n", 78 | " 'start_token': 0,\n", 79 | " 'end_token': 50,\n", 80 | " 'num_groups': 2,\n", 81 | "}\n", 82 | "\n", 83 | "# Initialize model\n", 84 | "model = LlaMA2(**hyperparams)\n", 85 | "rngs = jax.random.PRNGKey(0)\n", 86 | "rngs, dropout_rng = jax.random.split(rngs)\n", 87 | "params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']\n", 88 | "\n", 89 | "# Call as you would a Jax/Flax model\n", 90 | "outputs = model.apply({'params': params}, \n", 91 | " dummy_inputs, \n", 92 | " rngs={'dropout': dropout_rng})\n", 93 | "print(outputs.shape)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Number of parameters: 974312\n", 106 | "Number of accelerators: 1\n", 107 | "Epoch 1, Train Loss: 7.56395959854126\n", 108 | "Epoch 1, Val Loss: 6.7949442863464355\n", 109 | "New best validation score achieved, saving model...\n", 110 | "Epoch 2, Train Loss: 6.494354248046875\n", 111 | "Epoch 2, Val Loss: 5.658466815948486\n", 112 | "New best validation score achieved, saving model...\n", 113 | "5.658467\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# Training on data\n", 119 | "trainer = LlaMADataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')\n", 120 | "trainer.train(train_loader=dataloader, \n", 121 | " num_epochs=2, \n", 122 | " val_loader=dataloader) # Use actual validation data\n", 123 | "\n", 124 | "print(trainer.evaluate(dataloader))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 7, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": 
"stream", 135 | "text": [ 136 | "[157 157 932 932 949 406 748 523 171 364]\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "# Generating from a start token\n", 142 | "start_tokens = jnp.array([[123, 456]])\n", 143 | "\n", 144 | "# Remember to load the trained parameters \n", 145 | "params = trainer.load_params('params.pkl')\n", 146 | "outputs = model.apply({'params': params},\n", 147 | " start_tokens,\n", 148 | " rngs={'dropout': jax.random.PRNGKey(2)}, \n", 149 | " method=model.generate)\n", 150 | "print(outputs)" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.10.0" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /docs/examples/whisper_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "from nanodl import ArrayDataset, DataLoader\n", 12 | "from nanodl import Whisper, WhisperDataParallelTrainer\n", 13 | "\n", 14 | "# Dummy data parameters\n", 15 | "batch_size = 8\n", 16 | "max_length = 50\n", 17 | "embed_dim = 256 \n", 18 | "vocab_size = 1000 " 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "(8, 50, 256) (8, 50)\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "# Generate data: replace with actual tokenised/quantised data\n", 36 | "dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)\n", 37 | "dummy_inputs = jnp.ones((101, max_length, embed_dim))\n", 38 | "\n", 39 | "dataset = ArrayDataset(dummy_inputs, \n", 40 | " dummy_targets)\n", 41 | "\n", 42 | "dataloader = DataLoader(dataset, \n", 43 | " batch_size=batch_size, \n", 44 | " shuffle=True, \n", 45 | " drop_last=False)\n", 46 | "\n", 47 | "# How to loop through dataloader\n", 48 | "for batch in dataloader:\n", 49 | " x, y = batch\n", 50 | " print(x.shape, y.shape)\n", 51 | " break" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "(101, 50, 1000)\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# model parameters\n", 69 | "hyperparams = {\n", 70 | " 'num_layers': 1,\n", 71 | " 'hidden_dim': 256,\n", 72 | " 'num_heads': 2,\n", 73 | " 'feedforward_dim': 256,\n", 74 | " 'dropout': 0.1,\n", 75 | " 'vocab_size': 1000,\n", 76 | " 'embed_dim': embed_dim,\n", 77 | " 'max_length': max_length,\n", 78 | " 'start_token': 0,\n", 79 | " 'end_token': 50,\n", 80 | "}\n", 81 | "\n", 82 | "# Initialize model\n", 83 | "model = Whisper(**hyperparams)\n", 84 | "rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}\n", 85 | "params = model.init(rngs, dummy_inputs, dummy_targets)['params']\n", 86 | "outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)\n", 87 | "print(outputs.shape)" 88 
| ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Number of parameters: 1974760\n", 100 | "Number of accelerators: 1\n", 101 | "Epoch 1, Train Loss: 8.127946853637695\n", 102 | "Epoch 1, Val Loss: 7.081634521484375\n", 103 | "New best validation score achieved, saving model...\n", 104 | "Epoch 2, Train Loss: 6.22011137008667\n", 105 | "Epoch 2, Val Loss: 5.051723957061768\n", 106 | "New best validation score achieved, saving model...\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# Training on your data\n", 112 | "trainer = WhisperDataParallelTrainer(model, \n", 113 | " dummy_inputs.shape, \n", 114 | " dummy_targets.shape, \n", 115 | " 'params.pkl')\n", 116 | "trainer.train(dataloader, 2, dataloader)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "[930 471 500 450 936 936 143 851 716 695 275 389 246 79 7 494 562 314\n", 129 | " 276 583 788 525 302 150 694 694 161 741 902 77 946 294 210 945 272 266\n", 130 | " 493 553 533 619 703 330 330 154 438 797 334 322 31 649]\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# Sample inference\n", 136 | "params = trainer.load_params('params.pkl')\n", 137 | "\n", 138 | "# for more than one sample, use model.generate_batch\n", 139 | "transcripts = model.apply({'params': params}, \n", 140 | " dummy_inputs[:1], \n", 141 | " rngs=rngs, \n", 142 | " method=model.generate)\n", 143 | "\n", 144 | "print(transcripts)" 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.10.0" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: 4 | 5 | - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. 6 | - An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. 7 | - Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. 8 | - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective. 
9 | - Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. 10 | - GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to scikit-learn on GPU. 11 | - Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models. 12 | - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc. 13 | - Each model is contained in a single file with no external dependencies, so the source code can also be easily used. 14 | 15 | Feedback on any of our discussion, issue and pull request threads is welcome! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at ndubuakuhenry@gmail.com. 16 | 17 | ## Contribution 18 | 19 | This is the first iteration of this project, so roughness is expected and contributions are therefore highly encouraged! Follow the recommended steps: 20 | 21 | - Raise an issue/discussion to get second opinions 22 | - Fork the repository 23 | - Create a branch 24 | - Make your changes without breaking the design patterns 25 | - Write tests for your changes if necessary 26 | - Install locally with `pip install -e .` 27 | - Run tests with `python -m unittest discover -s tests` 28 | - Then submit a pull request from your branch. 29 | 30 | Contributions can be made in various forms: 31 | 32 | - Writing documentation. 33 | - Fixing bugs. 34 | - Implementing papers. 35 | - Writing high-coverage tests. 36 | - Optimizing existing code. 37 | - Experimenting and submitting real-world examples to the examples section. 38 | - Reporting bugs. 39 | - Responding to reported issues. 40 | 41 | Coming features include: 42 | - Reinforcement Learning With Human Feedback (RLHF). 43 | - Tokenizers. 44 | - Code optimisations. 45 | 46 | To follow up or share thoughts, use [this form](https://forms.gle/vwveb9SKdPYywHx9A). 47 | 48 | ## Sponsorships 49 | 50 | The name "NanoDL" stands for Nano Deep Learning. Models are exploding in size, thereby gate-keeping 51 | experts and companies with limited resources from building flexible models without prohibitive costs. 52 | Following the success of Phi models, the long-term goal is to build and train nano versions of all available models, 53 | while ensuring they compete with the original models in performance, with total 54 | number of parameters not exceeding 1B. Trained weights will be made available via this library. 55 | Any form of sponsorship, funding, grants or contribution will help with training resources. 56 | You can sponsor via the provided button, or reach out via ndubuakuhenry@gmail.com.
-------------------------------------------------------------------------------- /docs/requirements.in: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocstrings[python] 3 | markdown-include -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.10 3 | # To update, run: 4 | # 5 | # pip-compile docs/requirements.in 6 | # 7 | click==8.1.3 8 | # via mkdocs 9 | ghp-import==2.1.0 10 | # via mkdocs 11 | griffe==0.22.0 12 | # via mkdocstrings-python 13 | importlib-metadata==4.12.0 14 | # via mkdocs 15 | jinja2==3.1.2 16 | # via 17 | # mkdocs 18 | # mkdocstrings 19 | markdown==3.3.7 20 | # via 21 | # markdown-include 22 | # mkdocs 23 | # mkdocs-autorefs 24 | # mkdocstrings 25 | # pymdown-extensions 26 | markdown-include==0.6.0 27 | # via -r docs/requirements.in 28 | markupsafe==2.1.1 29 | # via 30 | # jinja2 31 | # mkdocstrings 32 | mergedeep==1.3.4 33 | # via mkdocs 34 | mkdocs==1.3.0 35 | # via 36 | # -r docs/requirements.in 37 | # mkdocs-autorefs 38 | # mkdocstrings 39 | mkdocs-autorefs==0.4.1 40 | # via mkdocstrings 41 | mkdocstrings[python]==0.19.0 42 | # via 43 | # -r docs/requirements.in 44 | # mkdocstrings-python 45 | mkdocstrings-python==0.7.1 46 | # via mkdocstrings 47 | packaging==21.3 48 | # via mkdocs 49 | pymdown-extensions==9.5 50 | # via mkdocstrings 51 | pyparsing==3.0.9 52 | # via packaging 53 | python-dateutil==2.8.2 54 | # via ghp-import 55 | pyyaml==6.0 56 | # via 57 | # mkdocs 58 | # pyyaml-env-tag 59 | pyyaml-env-tag==0.1 60 | # via mkdocs 61 | six==1.16.0 62 | # via python-dateutil 63 | watchdog==2.1.9 64 | # via mkdocs 65 | zipp==3.8.0 66 | # via importlib-metadata -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | Installation 5 | ------------ 6 | 7 | You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md) 8 | installation, [FLAX](https://github.com/google/flax/blob/main/README.md) 9 | installation, [OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md) 10 | installation (with GPU support for running training, without can only support creations). 11 | Models can be designed and tested on CPUs but trainers are all Distributed Data-Parallel which would require a GPU with 1 to N GPUS/TPUS. For CPU-only version of JAX: 12 | 13 | ``` 14 | pip install --upgrade pip # To support manylinux2010 wheels. 
15 | pip install jax flax optax 16 | ``` 17 | 18 | Then, install nanodl from PyPi: 19 | 20 | ``` 21 | pip install nanodl 22 | ``` 23 | 24 | Creating a GPT Model 25 | ---------------- 26 | 27 | ```py 28 | import jax 29 | import jax.numpy as jnp 30 | from nanodl import ArrayDataset, DataLoader 31 | from nanodl import GPT4, GPTDataParallelTrainer 32 | 33 | # Generate dummy data 34 | batch_size = 8 35 | max_length = 10 36 | 37 | # Replace with actual tokenised data 38 | data = jnp.ones((101, max_length), dtype=jnp.int32) 39 | 40 | # Shift to create next-token prediction dataset 41 | dummy_inputs = data[:, :-1] 42 | dummy_targets = data[:, 1:] 43 | 44 | # Create dataset and dataloader 45 | dataset = ArrayDataset(dummy_inputs, dummy_targets) 46 | dataloader = DataLoader(dataset, 47 | batch_size=batch_size, 48 | shuffle=True, 49 | drop_last=False) 50 | 51 | # How to loop through dataloader 52 | for batch in dataloader: 53 | x, y = batch 54 | print(x.shape, y.shape) 55 | break 56 | 57 | # model parameters 58 | hyperparams = { 59 | 'num_layers': 1, 60 | 'hidden_dim': 256, 61 | 'num_heads': 2, 62 | 'feedforward_dim': 256, 63 | 'dropout': 0.1, 64 | 'vocab_size': 1000, 65 | 'embed_dim': 256, 66 | 'max_length': max_length, 67 | 'start_token': 0, 68 | 'end_token': 50, 69 | } 70 | 71 | # Initialize model 72 | model = GPT4(**hyperparams) 73 | rngs = jax.random.PRNGKey(0) 74 | rngs, dropout_rng = jax.random.split(rngs) 75 | params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] 76 | 77 | # Call as you would a Jax/Flax model 78 | outputs = model.apply({'params': params}, 79 | dummy_inputs, 80 | rngs={'dropout': dropout_rng}) 81 | print(outputs.shape) 82 | 83 | # Training on data 84 | trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') 85 | trainer.train(train_loader=dataloader, 86 | num_epochs=2, 87 | val_loader=dataloader) 88 | 89 | print(trainer.evaluate(dataloader)) 90 | 91 | # Generating from a start token 92 | start_tokens = jnp.array([[123, 456]]) 93 | 94 | # Remember to load the trained parameters 95 | params = trainer.load_params('params.pkl') 96 | outputs = model.apply({'params': params}, 97 | start_tokens, 98 | rngs={'dropout': jax.random.PRNGKey(2)}, 99 | method=model.generate) 100 | print(outputs) 101 | ``` 102 | 103 | Creating a Diffusion model 104 | ---------------- 105 | 106 | ```py 107 | import jax 108 | import jax.numpy as jnp 109 | from nanodl import ArrayDataset, DataLoader 110 | from nanodl import DiffusionModel, DiffusionDataParallelTrainer 111 | 112 | image_size = 32 113 | block_depth = 2 114 | batch_size = 8 115 | widths = [32, 64, 128] 116 | key = jax.random.PRNGKey(0) 117 | input_shape = (101, image_size, image_size, 3) 118 | images = jax.random.normal(key, input_shape) 119 | 120 | # Use your own images 121 | dataset = ArrayDataset(images) 122 | dataloader = DataLoader(dataset, 123 | batch_size=batch_size, 124 | shuffle=True, 125 | drop_last=False) 126 | 127 | # Create diffusion model 128 | diffusion_model = DiffusionModel(image_size, widths, block_depth) 129 | params = diffusion_model.init(key, images) 130 | pred_noises, pred_images = diffusion_model.apply(params, images) 131 | print(pred_noises.shape, pred_images.shape) 132 | 133 | # Training on your data 134 | # Note: saved params are often different from training weights, use the saved params for generation 135 | trainer = DiffusionDataParallelTrainer(diffusion_model, 136 | input_shape=images.shape, 137 | weights_filename='params.pkl', 138 | learning_rate=1e-4)
139 | trainer.train(dataloader, 10, dataloader) 140 | print(trainer.evaluate(dataloader)) 141 | 142 | # Generate some samples 143 | params = trainer.load_params('params.pkl') 144 | generated_images = diffusion_model.apply({'params': params}, 145 | num_images=5, 146 | diffusion_steps=5, 147 | method=diffusion_model.generate) 148 | print(generated_images.shape) 149 | ``` 150 | 151 | Creating a Whisper Speech-to-Text model 152 | ---------------- 153 | 154 | ```py 155 | import jax 156 | import jax.numpy as jnp 157 | from nanodl import ArrayDataset, DataLoader 158 | from nanodl import Whisper, WhisperDataParallelTrainer 159 | 160 | # Dummy data parameters 161 | batch_size = 8 162 | max_length = 50 163 | embed_dim = 256 164 | vocab_size = 1000 165 | 166 | # Generate data: replace with actual tokenised/quantised data 167 | dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32) 168 | dummy_inputs = jnp.ones((101, max_length, embed_dim)) 169 | 170 | dataset = ArrayDataset(dummy_inputs, 171 | dummy_targets) 172 | 173 | dataloader = DataLoader(dataset, 174 | batch_size=batch_size, 175 | shuffle=True, 176 | drop_last=False) 177 | 178 | # How to loop through dataloader 179 | for batch in dataloader: 180 | x, y = batch 181 | print(x.shape, y.shape) 182 | break 183 | 184 | # model parameters 185 | hyperparams = { 186 | 'num_layers': 1, 187 | 'hidden_dim': 256, 188 | 'num_heads': 2, 189 | 'feedforward_dim': 256, 190 | 'dropout': 0.1, 191 | 'vocab_size': 1000, 192 | 'embed_dim': embed_dim, 193 | 'max_length': max_length, 194 | 'start_token': 0, 195 | 'end_token': 50, 196 | } 197 | 198 | # Initialize model 199 | model = Whisper(**hyperparams) 200 | rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} 201 | params = model.init(rngs, dummy_inputs, dummy_targets)['params'] 202 | outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs) 203 | print(outputs.shape) 204 | 205 | # Training on your data 206 | trainer = WhisperDataParallelTrainer(model, 207 | dummy_inputs.shape, 208 | dummy_targets.shape, 209 | 'params.pkl') 210 | trainer.train(dataloader, 2, dataloader) 211 | 212 | # Sample inference 213 | params = trainer.load_params('params.pkl') 214 | 215 | # for more than one sample, use model.generate_batch 216 | transcripts = model.apply({'params': params}, 217 | dummy_inputs[:1], 218 | rngs=rngs, 219 | method=model.generate) 220 | 221 | print(transcripts) 222 | ``` 223 | 224 | Creating an Accelerated PCA 225 | ---------------- 226 | 227 | ```py 228 | import jax 229 | from nanodl import PCA 230 | 231 | data = jax.random.normal(jax.random.key(0), (1000, 10)) 232 | pca = PCA(n_components=2) 233 | pca.fit(data) 234 | transformed_data = pca.transform(data) 235 | original_data = pca.inverse_transform(transformed_data) 236 | X_sampled = pca.sample(n_samples=1000, key=None) 237 | print(X_sampled.shape, original_data.shape, transformed_data.shape) 238 | ``` 239 | GPU/TPU-accelerated versions of many scikit-learn-style models, like NaiveBayesClassifier, LinearRegression and KMeans, are available in NanoDL. 240 | 241 | Using an individual module 242 | ---------------- 243 | 244 | Each constituent layer can be used in your own model.
245 | 246 | ```py 247 | from nanodl import GraphAttentionLayer 248 | 249 | class GAT(nn.Module): 250 | nfeat: int 251 | nhid: int 252 | nclass: int 253 | dropout_rate: float 254 | alpha: float 255 | nheads: int 256 | 257 | @nn.compact 258 | def __call__(self, 259 | x: jnp.ndarray, 260 | adj: jnp.ndarray, 261 | training: bool = False) -> jnp.ndarray: 262 | heads = [GraphAttentionLayer(self.nfeat, 263 | self.nhid, 264 | dropout_rate=self.dropout_rate, 265 | alpha=self.alpha, concat=True) for _ in range(self.nheads)] 266 | 267 | x = jnp.concatenate([head(x, adj, training) for head in heads], axis=1) 268 | x = nn.Dropout(rate=self.dropout_rate, 269 | deterministic=not training)(x) 270 | 271 | out_att = GraphAttentionLayer(self.nhid * self.nheads, 272 | self.nclass, 273 | dropout_rate=self.dropout_rate, 274 | alpha=self.alpha, concat=False) 275 | 276 | return out_att(x, adj, training) 277 | ``` 278 | 279 | With this, you could, for example, create a transformer model with a T5 Encoder and a Llama2 Decoder! -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: NanoDL Documentation 2 | theme: 3 | name: readthedocs 4 | highlightjs: true 5 | plugins: 6 | - search 7 | - mkdocstrings: 8 | handlers: 9 | python: 10 | options: 11 | docstring_section_style: "list" 12 | members_order: "source" 13 | show_root_heading: true 14 | show_source: false 15 | show_signature_annotations: false 16 | selection: 17 | inheritance: true 18 | filters: 19 | - "!^_[^_]" # Exclude members starting with a single underscore 20 | rendering: 21 | show_category_heading: true 22 | heading_level: 3 23 | docstring_styles: ["google", "numpy", "restructuredtext", "plain", "markdown"] 24 | cross_references: 25 | use_short_names: true 26 | fail_on_missing_reference: false 27 | introspection: 28 | modules: ["nanodl"] 29 | markdown_extensions: 30 | - markdown_include.include: 31 | base_path: .
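# The admonition extension below enables !!! note and !!! warning callout blocks in the rendered docs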
32 | - admonition -------------------------------------------------------------------------------- /nanodl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.2.5.dev1" 2 | 3 | from nanodl.__src.classical.bayes import NaiveBayesClassifier 4 | from nanodl.__src.classical.clustering import GaussianMixtureModel, KMeans 5 | from nanodl.__src.classical.dimensionality_reduction import PCA 6 | from nanodl.__src.classical.regression import ( 7 | GaussianProcess, 8 | LinearRegression, 9 | LogisticRegression, 10 | ) 11 | from nanodl.__src.experimental.gat import GAT, GraphAttentionLayer 12 | from nanodl.__src.models.attention import ( 13 | GatedMultiHeadAttention, 14 | HierarchicalMultiHeadAttention, 15 | LocalMultiHeadAttention, 16 | MultiQueryAttention, 17 | RotaryMultiHeadAttention, 18 | ) 19 | from nanodl.__src.models.clip import ( 20 | CLIP, 21 | CLIPDataParallelTrainer, 22 | ImageEncoder, 23 | SelfMultiHeadAttention, 24 | TextEncoder, 25 | ) 26 | from nanodl.__src.models.diffusion import ( 27 | DiffusionDataParallelTrainer, 28 | DiffusionModel, 29 | UNet, 30 | UNetDownBlock, 31 | UNetResidualBlock, 32 | UNetUpBlock, 33 | ) 34 | from nanodl.__src.models.gemma import ( 35 | Gemma, 36 | GemmaDataParallelTrainer, 37 | GemmaDecoder, 38 | GemmaDecoderBlock, 39 | ) 40 | from nanodl.__src.models.gpt import ( 41 | GPT3, 42 | GPT4, 43 | GPT3Block, 44 | GPT3Decoder, 45 | GPT4Block, 46 | GPT4Decoder, 47 | GPTDataParallelTrainer, 48 | PositionWiseFFN, 49 | ) 50 | from nanodl.__src.models.ijepa import IJEPA, IJEPADataParallelTrainer, IJEPADataSampler 51 | from nanodl.__src.models.lamda import ( 52 | LaMDA, 53 | LaMDABlock, 54 | LaMDADataParallelTrainer, 55 | LaMDADecoder, 56 | RelativeMultiHeadAttention, 57 | ) 58 | from nanodl.__src.models.llama import ( 59 | GroupedRotaryMultiHeadAttention, 60 | Llama3, 61 | Llama3Decoder, 62 | Llama3DecoderBlock, 63 | LlamaDataParallelTrainer, 64 | RotaryPositionalEncoding, 65 | ) 66 | from nanodl.__src.models.kan import( 67 | KANLinear, 68 | ChebyKANLinear, 69 | LegendreKANLinear, 70 | MonomialKANLinear, 71 | FourierKANLinear, 72 | HermiteKANLinear, 73 | ) 74 | from nanodl.__src.models.mistral import ( 75 | GroupedRotaryShiftedWindowMultiHeadAttention, 76 | Mistral, 77 | MistralDataParallelTrainer, 78 | MistralDecoder, 79 | MistralDecoderBlock, 80 | Mixtral, 81 | MixtralDecoder, 82 | MixtralDecoderBlock, 83 | ) 84 | from nanodl.__src.models.mixer import ( 85 | Mixer, 86 | MixerBlock, 87 | MixerDataParallelTrainer, 88 | MixerEncoder, 89 | ) 90 | from nanodl.__src.models.reward import RewardDataParallelTrainer, RewardModel 91 | from nanodl.__src.models.t5 import ( 92 | T5, 93 | T5DataParallelTrainer, 94 | T5Decoder, 95 | T5DecoderBlock, 96 | T5Encoder, 97 | T5EncoderBlock, 98 | ) 99 | from nanodl.__src.models.transformer import ( 100 | AddNorm, 101 | MultiHeadAttention, 102 | PositionalEncoding, 103 | PositionWiseFFN, 104 | TokenAndPositionEmbedding, 105 | Transformer, 106 | TransformerDataParallelTrainer, 107 | TransformerDecoderBlock, 108 | TransformerEncoder, 109 | ) 110 | from nanodl.__src.models.vit import ( 111 | PatchEmbedding, 112 | ViT, 113 | ViTBlock, 114 | ViTDataParallelTrainer, 115 | ViTEncoder, 116 | ) 117 | from nanodl.__src.models.whisper import ( 118 | Whisper, 119 | WhisperDataParallelTrainer, 120 | WhisperSpeechEncoder, 121 | WhisperSpeechEncoderBlock, 122 | ) 123 | from nanodl.__src.utils.data import ArrayDataset, DataLoader, Dataset 124 | from nanodl.__src.utils.ml import ( 125 | 
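# Evaluation metrics and array helpers implemented with JAX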
batch_cosine_similarities, 126 | batch_pearsonr, 127 | classification_scores, 128 | count_parameters, 129 | entropy, 130 | gini_impurity, 131 | hamming, 132 | jaccard, 133 | kl_divergence, 134 | mean_reciprocal_rank, 135 | zero_pad_sequences, 136 | ) 137 | from nanodl.__src.utils.nlp import ( 138 | bleu, 139 | cider_score, 140 | meteor, 141 | perplexity, 142 | rouge, 143 | word_error_rate, 144 | ) 145 | from nanodl.__src.utils.random import * 146 | from nanodl.__src.utils.vision import ( 147 | adjust_brightness, 148 | adjust_contrast, 149 | flip_image, 150 | gaussian_blur, 151 | normalize_images, 152 | random_crop, 153 | random_flip_image, 154 | sobel_edge_detection, 155 | ) 156 | 157 | __all__ = [ 158 | # Classical 159 | "NaiveBayesClassifier", 160 | "PCA", 161 | "KMeans", 162 | "GaussianMixtureModel", 163 | "LinearRegression", 164 | "LogisticRegression", 165 | "GaussianProcess", 166 | # Models 167 | "IJEPA", 168 | "IJEPADataParallelTrainer", 169 | "IJEPADataSampler", 170 | "Gemma", 171 | "GemmaDataParallelTrainer", 172 | "GemmaDecoder", 173 | "GemmaDecoderBlock", 174 | "GAT", 175 | "GraphAttentionLayer", 176 | "T5", 177 | "T5DataParallelTrainer", 178 | "T5Encoder", 179 | "T5Decoder", 180 | "T5EncoderBlock", 181 | "T5DecoderBlock", 182 | "ViT", 183 | "ViTDataParallelTrainer", 184 | "ViTBlock", 185 | "ViTEncoder", 186 | "PatchEmbedding", 187 | "CLIP", 188 | "CLIPDataParallelTrainer", 189 | "ImageEncoder", 190 | "TextEncoder", 191 | "SelfMultiHeadAttention", 192 | "LaMDA", 193 | "LaMDADataParallelTrainer", 194 | "LaMDABlock", 195 | "LaMDADecoder", 196 | "RelativeMultiHeadAttention", 197 | "Mixer", 198 | "MixerDataParallelTrainer", 199 | "MixerBlock", 200 | "MixerEncoder", 201 | "Llama3", 202 | "LlamaDataParallelTrainer", 203 | "RotaryPositionalEncoding", 204 | "Llama3Decoder", 205 | "Llama3DecoderBlock", 206 | "GroupedRotaryMultiHeadAttention", 207 | "GPT3", 208 | "GPT4", 209 | "GPTDataParallelTrainer", 210 | "GPT3Block", 211 | "GPT4Block", 212 | "GPT3Decoder", 213 | "GPT4Decoder", 214 | "PositionWiseFFN", 215 | "Mistral", 216 | "MistralDataParallelTrainer", 217 | "MistralDecoder", 218 | "MistralDecoderBlock", 219 | "GroupedRotaryShiftedWindowMultiHeadAttention", 220 | "Mixtral", 221 | "MixtralDecoder", 222 | "MixtralDecoderBlock", 223 | "Whisper", 224 | "WhisperDataParallelTrainer", 225 | "WhisperSpeechEncoder", 226 | "WhisperSpeechEncoderBlock", 227 | "RewardModel", 228 | "RewardDataParallelTrainer", 229 | "DiffusionModel", 230 | "DiffusionDataParallelTrainer", 231 | "UNet", 232 | "UNetDownBlock", 233 | "UNetUpBlock", 234 | "UNetResidualBlock", 235 | "Transformer", 236 | "TransformerDataParallelTrainer", 237 | "TransformerEncoder", 238 | "TransformerDecoderBlock", 239 | "PositionalEncoding", 240 | "PositionWiseFFN", 241 | "TokenAndPositionEmbedding", 242 | "MultiHeadAttention", 243 | "AddNorm", 244 | # Utilities 245 | "Dataset", 246 | "ArrayDataset", 247 | "DataLoader", 248 | "batch_cosine_similarities", 249 | "batch_pearsonr", 250 | "classification_scores", 251 | "count_parameters", 252 | "entropy", 253 | "gini_impurity", 254 | "hamming", 255 | "jaccard", 256 | "kl_divergence", 257 | "mean_reciprocal_rank", 258 | "zero_pad_sequences", 259 | "bleu", 260 | "cider_score", 261 | "meteor", 262 | "perplexity", 263 | "rouge", 264 | "word_error_rate", 265 | "adjust_brightness", 266 | "adjust_contrast", 267 | "flip_image", 268 | "gaussian_blur", 269 | "normalize_images", 270 | "random_crop", 271 | "random_flip_image", 272 | "sobel_edge_detection", 273 | "MultiQueryAttention", 274 | 
"LocalMultiHeadAttention", 275 | "HierarchicalMultiHeadAttention", 276 | "GatedMultiHeadAttention", 277 | "RotaryMultiHeadAttention", 278 | # Random 279 | "time_rng_key", 280 | "uniform", 281 | "normal", 282 | "bernoulli", 283 | "categorical", 284 | "randint", 285 | "permutation", 286 | "gumbel", 287 | "choice", 288 | "bits", 289 | "exponential", 290 | "triangular", 291 | "truncated_normal", 292 | "poisson", 293 | "geometric", 294 | "gamma", 295 | "chisquare", 296 | "KANLinear", 297 | "ChebyKANLinear", 298 | "LegendreKANLinear", 299 | "MonomialKANLinear", 300 | "FourierKANLinear", 301 | "HermiteKANLinear", 302 | ] 303 | 304 | import importlib 305 | import sys 306 | 307 | 308 | def check_library_installed(lib_name): 309 | try: 310 | return importlib.import_module(lib_name) 311 | except ImportError: 312 | raise ImportError(f"{lib_name} is not installed or improperly installed.") 313 | 314 | 315 | def test_flax(flax): 316 | model = flax.linen.Dense(features=10) 317 | 318 | 319 | def test_jax(jax): 320 | arr = jax.numpy.array([1, 2, 3]) 321 | result = jax.numpy.sum(arr) 322 | 323 | 324 | def test_optax(optax): 325 | optimizer = optax.sgd(learning_rate=0.1) 326 | 327 | 328 | def test_einops(einops): 329 | arr = einops.rearrange([1, 2, 3], "a b c -> b a c") 330 | 331 | 332 | def main(): 333 | try: 334 | flax = check_library_installed("flax") 335 | jax = check_library_installed("jax") 336 | optax = check_library_installed("optax") 337 | einops = check_library_installed("einops") 338 | 339 | test_flax(flax) 340 | test_jax(jax) 341 | test_optax(optax) 342 | 343 | except ImportError as e: 344 | print(e) 345 | sys.exit(1) 346 | except Exception as e: 347 | print(f"An error occurred while verifying Jax/Flax/Optax installation: {e}") 348 | sys.exit(1) 349 | 350 | 351 | if __name__ == "__main__": 352 | main() 353 | -------------------------------------------------------------------------------- /nanodl/__src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/__init__.py -------------------------------------------------------------------------------- /nanodl/__src/classical/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/classical/__init__.py -------------------------------------------------------------------------------- /nanodl/__src/classical/bayes.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def fit_naive_bayes( 8 | X: jnp.ndarray, y: jnp.ndarray, num_classes: int 9 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 10 | class_priors = jnp.zeros(num_classes) 11 | feature_probs = jnp.zeros((num_classes, X.shape[1])) 12 | 13 | for i in range(num_classes): 14 | class_mask = y == i 15 | class_count = jnp.sum(class_mask) 16 | class_priors = class_priors.at[i].set(class_count / X.shape[0]) 17 | feature_count = jnp.sum(X[class_mask], axis=0) 18 | feature_probs = feature_probs.at[i].set(feature_count / class_count) 19 | 20 | return class_priors, feature_probs 21 | 22 | 23 | @jax.jit 24 | def predict_naive_bayes( 25 | X: jnp.ndarray, class_priors: jnp.ndarray, feature_probs: jnp.ndarray 26 | ) -> jnp.ndarray: 27 | # Calculate log probabilities for features 28 | log_feature_probs = jnp.log(feature_probs) 
29 | log_feature_probs_neg = jnp.log(1 - feature_probs) 30 | 31 | # Expand dimensions for broadcasting 32 | expanded_log_feature_probs = log_feature_probs[:, None, :] 33 | expanded_log_feature_probs_neg = log_feature_probs_neg[:, None, :] 34 | 35 | # Calculate log probabilities for each sample and class 36 | log_probs = jnp.sum( 37 | expanded_log_feature_probs * X + expanded_log_feature_probs_neg * (1 - X), 38 | axis=2, 39 | ) 40 | log_probs += jnp.log(class_priors)[:, None] 41 | return jnp.argmax(log_probs, axis=0) 42 | 43 | 44 | def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float: 45 | return jnp.mean(y_true == y_pred) 46 | 47 | 48 | class NaiveBayesClassifier: 49 | """ 50 | Naive Bayes classifier using JAX. 51 | 52 | Example usage: 53 | ``` 54 | classifier = NaiveBayesClassifier(num_classes=2) 55 | X = jnp.array([[0, 1], [1, 0], [1, 1]]) 56 | y = jnp.array([0, 1, 0]) 57 | classifier.fit(X, y) 58 | predictions = classifier.predict(X) 59 | acc = accuracy(y, predictions) 60 | print(f"Accuracy: {acc}") 61 | ``` 62 | 63 | Attributes: 64 | num_classes (int): Number of classes. 65 | class_priors (jnp.ndarray): Class priors, shape (num_classes,). 66 | feature_probs (jnp.ndarray): Feature probabilities, shape (num_classes, num_features). 67 | """ 68 | 69 | def __init__(self, num_classes: int): 70 | self.num_classes = num_classes 71 | self.class_priors = None 72 | self.feature_probs = None 73 | 74 | def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None: 75 | self.class_priors, self.feature_probs = fit_naive_bayes(X, y, self.num_classes) 76 | 77 | def predict(self, X: jnp.ndarray) -> jnp.ndarray: 78 | return predict_naive_bayes(X, self.class_priors, self.feature_probs) 79 | -------------------------------------------------------------------------------- /nanodl/__src/classical/clustering.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | class KMeans: 6 | """ 7 | KMeans clustering using JAX for GPU/TPU acceleration. 8 | 9 | Attributes: 10 | k (int): Number of clusters. 11 | num_iters (int): Maximum number of iterations. 12 | random_seed (int): Random seed for initialization. 13 | centroids (Optional[jnp.ndarray]): Centroids of the clusters. 14 | clusters (Optional[jnp.ndarray]): Cluster assignments of the data points. 
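Note: make_blobs in the example below is assumed to come from sklearn.datasets and is used only to generate demo data.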
15 | 16 | Example usage: 17 | ``` 18 | kmeans = KMeans(k=4) 19 | X, _ = make_blobs(n_samples=300, centers=4, n_features=2, random_state=0) 20 | kmeans.fit(X) 21 | clusters = kmeans.predict(X) 22 | print("Centroids:", kmeans.centroids) 23 | print("Cluster assignments:", clusters) 24 | ``` 25 | """ 26 | 27 | def __init__(self, k: int, num_iters: int = 100, random_seed: int = 0) -> None: 28 | self.k = k 29 | self.num_iters = num_iters 30 | self.random_seed = random_seed 31 | self.centroids = None 32 | self.clusters = None 33 | 34 | def initialize_centroids(self, X: jnp.ndarray) -> jnp.ndarray: 35 | 36 | indices = jnp.arange(X.shape[0]) 37 | selected = jax.random.choice( 38 | jax.random.PRNGKey(self.random_seed), 39 | indices, 40 | shape=(self.k,), 41 | replace=False, 42 | ) 43 | return X[selected] 44 | 45 | def assign_clusters(self, X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray: 46 | 47 | distances = jnp.sqrt( 48 | ((X[:, jnp.newaxis, :] - centroids[jnp.newaxis, :, :]) ** 2).sum(axis=2) 49 | ) 50 | return jnp.argmin(distances, axis=1) 51 | 52 | def update_centroids(self, X: jnp.ndarray, clusters: jnp.ndarray) -> jnp.ndarray: 53 | 54 | return jnp.array([X[clusters == i].mean(axis=0) for i in range(self.k)]) 55 | 56 | def fit(self, X: jnp.ndarray) -> None: 57 | 58 | self.centroids = self.initialize_centroids(X) 59 | for _ in range(self.num_iters): 60 | self.clusters = self.assign_clusters(X, self.centroids) 61 | new_centroids = self.update_centroids(X, self.clusters) 62 | if jnp.allclose(self.centroids, new_centroids): 63 | break 64 | self.centroids = new_centroids 65 | 66 | def predict(self, X: jnp.ndarray) -> jnp.ndarray: 67 | if self.centroids is None: 68 | raise ValueError("Model not yet trained. Call 'fit' with training data.") 69 | return self.assign_clusters(X, self.centroids) 70 | 71 | 72 | class GaussianMixtureModel: 73 | """ 74 | Gaussian Mixture Model implemented in JAX. 75 | 76 | This class represents a Gaussian Mixture Model (GMM) for clustering and density estimation. 77 | It uses the Expectation-Maximization (EM) algorithm for fitting the model to data. 78 | 79 | Attributes: 80 | n_components (int): Number of mixture components. 81 | tol (float): Tolerance for convergence. 82 | max_iter (int): Maximum number of iterations for the EM algorithm. 83 | means (jnp.ndarray): Means of the Gaussian components. 84 | covariances (jnp.ndarray): Covariances of the Gaussian components. 85 | weights (jnp.ndarray): Weights of the Gaussian components. 86 | seed (int): Random seed for initialization. 87 | 88 | Example: 89 | ``` 90 | >>> import jax.numpy as jnp 91 | >>> from gaussian_mixture_model_jax import GaussianMixtureModelJAX 92 | >>> X = jnp.array([[1, 2], [1, 4], [1, 0], 93 | ... 
[10, 2], [10, 4], [10, 0]]) 94 | >>> gmm = GaussianMixtureModelJAX(n_components=2, seed=42) 95 | >>> gmm.fit(X) 96 | >>> print(gmm.means) 97 | >>> labels = gmm.predict(X) 98 | >>> print(labels) 99 | ``` 100 | """ 101 | 102 | def __init__( 103 | self, n_components: int, tol: float = 1e-3, max_iter: int = 100, seed: int = 0 104 | ) -> None: 105 | self.n_components = n_components 106 | self.tol = tol 107 | self.max_iter = max_iter 108 | self.means = None 109 | self.covariances = None 110 | self.weights = None 111 | self.seed = seed 112 | 113 | def fit(self, X: jnp.ndarray) -> None: 114 | _, n_features = X.shape 115 | rng = jax.random.PRNGKey(self.seed) 116 | 117 | self.means = jax.random.normal(rng, (self.n_components, n_features)) 118 | self.covariances = jnp.array([jnp.eye(n_features)] * self.n_components) 119 | self.weights = jnp.ones(self.n_components) / self.n_components 120 | 121 | log_likelihood = 0 122 | for _ in range(self.max_iter): 123 | responsibilities = self._e_step(X) 124 | self._m_step(X, responsibilities) 125 | 126 | new_log_likelihood = self._compute_log_likelihood(X) 127 | if jnp.abs(new_log_likelihood - log_likelihood) < self.tol: 128 | break 129 | log_likelihood = new_log_likelihood 130 | 131 | def _e_step(self, X: jnp.ndarray) -> jnp.ndarray: 132 | responsibilities = jnp.zeros((X.shape[0], self.n_components)) 133 | for k in range(self.n_components): 134 | responsibilities = responsibilities.at[:, k].set( 135 | self.weights[k] 136 | * self._multivariate_gaussian(X, self.means[k], self.covariances[k]) 137 | ) 138 | responsibilities /= responsibilities.sum(axis=1, keepdims=True) 139 | return responsibilities 140 | 141 | def _m_step(self, X: jnp.ndarray, responsibilities: jnp.ndarray) -> None: 142 | n_samples = X.shape[0] 143 | for k in range(self.n_components): 144 | Nk = responsibilities[:, k].sum() 145 | self.means = self.means.at[k].set( 146 | (1 / Nk) * jnp.dot(responsibilities[:, k], X) 147 | ) 148 | diff = X - self.means[k] 149 | self.covariances = self.covariances.at[k].set( 150 | (1 / Nk) * jnp.dot(responsibilities[:, k] * diff.T, diff) 151 | ) 152 | self.weights = self.weights.at[k].set(Nk / n_samples) 153 | 154 | def _multivariate_gaussian( 155 | self, X: jnp.ndarray, mean: jnp.ndarray, cov: jnp.ndarray 156 | ) -> jnp.ndarray: 157 | n = X.shape[1] 158 | diff = X - mean 159 | return jnp.exp( 160 | -0.5 * jnp.sum(jnp.dot(diff, jnp.linalg.inv(cov)) * diff, axis=1) 161 | ) / (jnp.sqrt((2 * jnp.pi) ** n * jnp.linalg.det(cov))) 162 | 163 | def _compute_log_likelihood(self, X: jnp.ndarray) -> float: 164 | log_likelihood = 0 165 | for k in range(self.n_components): 166 | log_likelihood += jnp.sum( 167 | jnp.log( 168 | self.weights[k] 169 | * self._multivariate_gaussian(X, self.means[k], self.covariances[k]) 170 | ) 171 | ) 172 | return log_likelihood 173 | 174 | def predict(self, X: jnp.ndarray) -> jnp.ndarray: 175 | responsibilities = self._e_step(X) 176 | return jnp.argmax(responsibilities, axis=1) 177 | -------------------------------------------------------------------------------- /nanodl/__src/classical/dimensionality_reduction.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | class PCA: 8 | """ 9 | A class for performing Principal Component Analysis (PCA) on data. 10 | 11 | Attributes: 12 | n_components (int): Number of principal components to retain. 13 | components (ndarray): Principal components learned from the data. 
14 | mean (ndarray): Mean of the data used for centering. 15 | 16 | Methods: 17 | fit(X): 18 | Fit the PCA model on the input data. 19 | 20 | transform(X): 21 | Transform the input data into the PCA space. 22 | 23 | inverse_transform(X_transformed): 24 | Inverse transform PCA-transformed data back to the original space. 25 | 26 | sample(n_samples=1, key=None): 27 | Generate synthetic data samples from the learned PCA distribution. 28 | 29 | Example Usage: 30 | # Create an instance of the PCA class 31 | data = jax.random.normal(jax.random.key(0), (1000, 10)) 32 | pca = PCA(n_components=2) 33 | pca.fit(data) 34 | transformed_data = pca.transform(data) 35 | original_data = pca.inverse_transform(transformed_data) 36 | X_sampled = pca.sample(n_samples=1000, key=None) 37 | print(X_sampled.shape, original_data.shape, transformed_data.shape) 38 | """ 39 | 40 | def __init__(self, n_components: int): 41 | 42 | self.n_components = n_components 43 | self.components = None 44 | self.mean = None 45 | 46 | def fit(self, X: jnp.ndarray) -> None: 47 | 48 | self.mean = jnp.mean(X, axis=0) 49 | X_centered = X - self.mean 50 | cov_matrix = jnp.cov(X_centered, rowvar=False) 51 | eigvals, eigvecs = jnp.linalg.eigh(cov_matrix) 52 | sorted_indices = jnp.argsort(eigvals)[::-1] 53 | sorted_eigvecs = eigvecs[:, sorted_indices] 54 | self.components = sorted_eigvecs[:, : self.n_components] 55 | 56 | def transform(self, X: jnp.ndarray) -> jnp.ndarray: 57 | X_centered = X - self.mean 58 | return jnp.dot(X_centered, self.components) 59 | 60 | def inverse_transform(self, X_transformed: jnp.ndarray) -> jnp.ndarray: 61 | return jnp.dot(X_transformed, self.components.T) + self.mean 62 | 63 | def sample( 64 | self, n_samples: int = 1, key: Optional[jnp.ndarray] = None 65 | ) -> jnp.ndarray: 66 | 67 | if key is None: 68 | key = jax.random.PRNGKey(0) 69 | 70 | z = jax.random.normal(key, (n_samples, self.n_components)) 71 | X_sampled = self.inverse_transform(z) 72 | return X_sampled 73 | -------------------------------------------------------------------------------- /nanodl/__src/classical/dsp.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | from jax import random 5 | 6 | 7 | def fastica( 8 | X: jnp.ndarray, n_components: jnp.ndarray, max_iter: int = 1000, tol: float = 1e-4 9 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 10 | """ 11 | Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm. 12 | 13 | Parameters: 14 | X : jax.numpy.ndarray 15 | The input data matrix, where each row represents a data point, and each column represents a different signal. 16 | The input data should be a 2D jax.numpy array with shape (n_samples, n_features). 17 | n_components : int 18 | The number of independent components to extract. This should be less than or equal to the number of features in the input data. 19 | max_iter : int, optional 20 | The maximum number of iterations for the optimization process. The default value is 1000 iterations. 21 | tol : float, optional 22 | The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the 23 | unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4. 24 | 25 | Returns: 26 | S : jax.numpy.ndarray 27 | The separated independent components. 
This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents 28 | a different independent component, and each column represents a data point. 29 | W : jax.numpy.ndarray 30 | The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the 31 | mixing matrix. It is used to transform the input data back into the independent components. 32 | whitening_matrix : jax.numpy.ndarray 33 | The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate 34 | the input data and make its covariance matrix the identity matrix. 35 | 36 | Description: 37 | The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input 38 | data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find 39 | the optimal unmixing matrix that maximizes the independence of the source signals. 40 | 41 | The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied 42 | to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the 43 | maximum number of iterations (`max_iter`) is reached. 44 | 45 | Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components. 46 | 47 | Example usage: 48 | # Set random seed 49 | jax.random.PRNGKey(42) 50 | 51 | # Generate synthetic source signals 52 | n_samples = 2000 53 | time = jnp.linspace(0, 8, n_samples) 54 | s1 = jnp.sin(2 * time) 55 | s2 = jnp.sign(jnp.sin(3 * time)) 56 | 57 | # Combine the sources with a mixing matrix 58 | A = jnp.array([[1, 1], [0.5, 2]]) 59 | X = jnp.dot(A, jnp.array([s1, s2])) 60 | 61 | # Perform ICA 62 | n_components = 2 63 | S, W, whitening_matrix = fastica(X.T, n_components) 64 | 65 | # Plot the results 66 | plt.figure(figsize=(12, 8)) 67 | 68 | plt.subplot(3, 1, 1) 69 | plt.title('Original Source Signals') 70 | plt.plot(time, s1, label='Source 1 (Sine Wave)') 71 | plt.plot(time, s2, label='Source 2 (Square Wave)') 72 | plt.legend() 73 | 74 | plt.subplot(3, 1, 2) 75 | plt.title('Mixed Signals') 76 | plt.plot(time, X[0], label='Mixed Signal 1') 77 | plt.plot(time, X[1], label='Mixed Signal 2') 78 | plt.legend() 79 | 80 | plt.subplot(3, 1, 3) 81 | plt.title('Separated Signals (Using ICA)') 82 | plt.plot(time, S[0], label='Separated Signal 1') 83 | plt.plot(time, S[1], label='Separated Signal 2') 84 | plt.legend() 85 | s 86 | plt.tight_layout() 87 | plt.show() 88 | """ 89 | # Calculate the covariance matrix and perform eigenvalue decomposition 90 | cov_matrix = jnp.cov(X, rowvar=False) 91 | eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) 92 | 93 | # Sort the eigenvalues and eigenvectors 94 | idx = jnp.argsort(eigenvalues)[::-1] 95 | eigenvalues = eigenvalues[idx] 96 | eigenvectors = eigenvectors[:, idx] 97 | 98 | # Create the whitening matrix 99 | D = jnp.diag(1.0 / jnp.sqrt(eigenvalues)) 100 | whitening_matrix = jnp.dot(eigenvectors, D) 101 | X_whitened = jnp.dot(X, whitening_matrix) 102 | 103 | # Initialize unmixing matrix with random values 104 | rng = random.PRNGKey(0) # Set a seed for reproducibility 105 | W = random.normal(rng, (n_components, n_components)) 106 | 107 | # Perform FastICA algorithm 108 | for _ 
in range(max_iter): 109 | WX = jnp.dot(X_whitened, W.T) 110 | g = jnp.tanh(WX) 111 | g_prime = 1 - g**2 112 | W_new = (jnp.dot(X_whitened.T, g) / X.shape[0]) - jnp.diag( 113 | g_prime.mean(axis=0) 114 | ).dot(W) 115 | 116 | # Orthogonalize the unmixing matrix 117 | W_new, _ = jnp.linalg.qr(W_new) 118 | 119 | # Check for convergence 120 | if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol: 121 | W = W_new 122 | break 123 | 124 | W = W_new 125 | 126 | # Calculate the separated independent components 127 | S = jnp.dot(W, X_whitened.T) 128 | 129 | return S, W, whitening_matrix 130 | -------------------------------------------------------------------------------- /nanodl/__src/classical/regression.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | class LinearRegression: 8 | """ 9 | Linear Regression model implemented using JAX. 10 | 11 | Parameters: 12 | - input_dim (int): Dimension of the input feature. 13 | - output_dim (int): Dimension of the output target. 14 | 15 | Methods: 16 | - linear_regression(params, x): Linear regression prediction function. 17 | - loss(params, x, y): Mean squared error loss function. 18 | - fit(x_data, y_data, learning_rate=0.1, num_epochs=100): Training function. 19 | - get_params(): Get the learned weights and bias. 20 | 21 | Example usage: 22 | ``` 23 | num_samples = 100 24 | input_dim = 1 25 | output_dim = 1 26 | 27 | x_data = jax.random.normal(random.PRNGKey(0), (num_samples, input_dim)) 28 | y_data = jnp.dot(x_data, jnp.array([[2.0]])) - jnp.array([[-1.0]]) 29 | 30 | lr_model = LinearRegression(input_dim, output_dim) 31 | lr_model.fit(x_data, y_data) 32 | 33 | learned_weights, learned_bias = lr_model.get_params() 34 | print("Learned Weights:", learned_weights) 35 | print("Learned Bias:", learned_bias) 36 | ``` 37 | """ 38 | 39 | def __init__(self, input_dim, output_dim): 40 | self.input_dim = input_dim 41 | self.output_dim = output_dim 42 | self.key = jax.random.PRNGKey(0) 43 | self.params = (jnp.zeros((input_dim, output_dim)), jnp.zeros((output_dim,))) 44 | 45 | def linear_regression(self, params, x): 46 | weights, bias = params 47 | return jnp.dot(x, weights) + bias 48 | 49 | def loss(self, params, x, y): 50 | predictions = self.linear_regression(params, x) 51 | return jnp.mean((predictions - y) ** 2) 52 | 53 | def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100): 54 | grad_loss = jax.grad(self.loss) 55 | for epoch in range(num_epochs): 56 | grads = grad_loss(self.params, x_data, y_data) 57 | weights_grad, bias_grad = grads 58 | weights, bias = self.params 59 | weights -= learning_rate * weights_grad 60 | bias -= learning_rate * bias_grad 61 | self.params = (weights, bias) 62 | epoch_loss = self.loss(self.params, x_data, y_data) 63 | print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}") 64 | 65 | print("Training completed.") 66 | 67 | def get_params(self): 68 | return self.params 69 | 70 | 71 | class LogisticRegression: 72 | """ 73 | Logistic Regression model implemented using JAX. 74 | 75 | Parameters: 76 | - input_dim (int): Dimension of the input feature. 77 | 78 | Methods: 79 | - sigmoid(x): Sigmoid activation function. 80 | - logistic_regression(params, x): Logistic regression prediction function. 81 | - loss(params, x, y): Binary cross-entropy loss function. 82 | - fit(x_data, y_data, learning_rate=0.1, num_epochs=100): Training function. 
83 | - predict(x_data): Predict probabilities using the trained model. 84 | 85 | Example usage: 86 | ``` 87 | num_samples = 100 88 | input_dim = 2 89 | 90 | x_data = jax.random.normal(random.PRNGKey(0), (num_samples, input_dim)) 91 | logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1 92 | y_data = (logits > 0).astype(jnp.float32) 93 | 94 | lr_model = LogisticRegression(input_dim) 95 | lr_model.fit(x_data, y_data) 96 | 97 | test_data = jax.random.normal(random.PRNGKey(0), (num_samples, input_dim)) 98 | predictions = lr_model.predict(test_data) 99 | print("Predictions:", predictions) 100 | ``` 101 | """ 102 | 103 | def __init__(self, input_dim): 104 | self.input_dim = input_dim 105 | self.key = jax.random.PRNGKey(0) 106 | self.params = (jnp.zeros((input_dim,)), jnp.zeros(())) 107 | 108 | def sigmoid(self, x): 109 | return 1.0 / (1.0 + jnp.exp(-x)) 110 | 111 | def logistic_regression(self, params, x): 112 | weights, bias = params 113 | return self.sigmoid(jnp.dot(x, weights) + bias) 114 | 115 | def loss(self, params, x, y): 116 | predictions = self.logistic_regression(params, x) 117 | return -jnp.mean(y * jnp.log(predictions) + (1 - y) * jnp.log(1 - predictions)) 118 | 119 | def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100): 120 | grad_loss = jax.grad(self.loss) 121 | for epoch in range(num_epochs): 122 | grads = grad_loss(self.params, x_data, y_data) 123 | weights_grad, bias_grad = grads 124 | weights, bias = self.params 125 | weights -= learning_rate * weights_grad 126 | bias -= learning_rate * bias_grad 127 | self.params = (weights, bias) 128 | epoch_loss = self.loss(self.params, x_data, y_data) 129 | print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}") 130 | 131 | print("Training completed.") 132 | 133 | def predict(self, x_data): 134 | return self.logistic_regression(self.params, x_data) 135 | 136 | 137 | class GaussianProcess: 138 | """ 139 | A basic implementation of Gaussian Process regression using JAX. 140 | 141 | Attributes: 142 | kernel (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The kernel function to measure similarity between data points. 143 | noise (float): Measurement noise added to the diagonal of the kernel matrix. 144 | X (jnp.ndarray): Training inputs. 145 | y (jnp.ndarray): Training outputs. 146 | K (jnp.ndarray): Kernel matrix incorporating training inputs and noise. 147 | 148 | Methods: 149 | fit(X, y): 150 | Fits the Gaussian Process model to the training data. 151 | 152 | predict(X_new): 153 | Makes predictions for new input points using the trained model. 
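The prediction step follows standard GP regression: with K = kernel(X, X) + noise * I, the posterior mean is K_s.T @ inv(K) @ y and the posterior covariance is K_ss - K_s.T @ inv(K) @ K_s, where K_s = kernel(X, X_new) and K_ss = kernel(X_new, X_new).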
154 | 155 | Example Usage: 156 | # Define a kernel function, e.g., Radial Basis Function (RBF) kernel 157 | def rbf_kernel(x1, x2, length_scale=1.0): 158 | diff = x1[:, None] - x2 159 | return jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1) / length_scale**2) 160 | 161 | # Create an instance of the GaussianProcess class 162 | gp = GaussianProcess(kernel=rbf_kernel, noise=1e-3) 163 | 164 | # Fit the model on the training data 165 | gp.fit(X_train, y_train) 166 | 167 | # Make predictions on new data 168 | mean, covariance = gp.predict(X_new) 169 | """ 170 | 171 | def __init__( 172 | self, 173 | kernel: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 174 | noise: float = 1e-3, 175 | ): 176 | 177 | self.kernel = kernel 178 | self.noise = noise 179 | self.X = None 180 | self.y = None 181 | self.K = None 182 | 183 | def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None: 184 | 185 | self.X = X 186 | self.y = y 187 | self.K = self.kernel(self.X, self.X) + jnp.eye(len(X)) * self.noise 188 | 189 | def predict(self, X_new: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 190 | K_inv = jnp.linalg.inv(self.K) 191 | K_s = self.kernel(self.X, X_new) 192 | K_ss = self.kernel(X_new, X_new) 193 | 194 | mu_s = jnp.dot(K_s.T, jnp.dot(K_inv, self.y)) 195 | cov_s = K_ss - jnp.dot(K_s.T, jnp.dot(K_inv, K_s)) 196 | return mu_s, cov_s 197 | -------------------------------------------------------------------------------- /nanodl/__src/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/experimental/__init__.py -------------------------------------------------------------------------------- /nanodl/__src/experimental/bitlinear.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import linen as nn 4 | 5 | 6 | class BitLinear(nn.Module): 7 | """ 8 | Implements a linear transformation layer with quantization for both activations and weights, 9 | optimized for low-bit inference. The layer is designed to operate in two modes: training and inference. 10 | During training, the activations and weights are quantized using separate quantization functions, 11 | aiming to simulate low-bit operations and reduce the quantization error. For inference, a more 12 | aggressive quantization scheme is applied to both activations and weights, potentially different 13 | from the training quantization, to maximize performance and efficiency on low-bit hardware. 14 | 15 | Attributes: 16 | output_features (int): The number of output features. 17 | kernel_init (callable): A function to initialize the weights. Default is LeCun normal initializer. 18 | """ 19 | 20 | output_features: int 21 | kernel_init: callable = nn.initializers.lecun_normal() 22 | 23 | @nn.compact 24 | def __call__(self, x, training=False): 25 | w = self.param("kernel", self.kernel_init, (x.shape[-1], self.output_features)) 26 | 27 | if not training: 28 | x_quant, x_scale = self.fused_activation_norm_quant(x) 29 | 30 | # HELP: How run externally on params at once for efficiency 31 | # Quantising weigts all over again each call is repeated work 32 | # This can be done on params dict using jax tree utils. 
33 | # Albeit the weight scale for quantisation needs to be utilised at inference 34 | # Its easy to bypassed on its own by passing the weight scale during a call 35 | # This will be a module in various transformer models in my project (NanoDL) 36 | # Is there a way to achieve this without complication my existing codebase? 37 | w, w_scale = self.inference_weight_quant(w) 38 | 39 | return self.inference_lowbit_matmul(x_quant, w) / w_scale / x_scale 40 | 41 | x_norm = self.rmsnorm(x) 42 | x_quant = x_norm + jax.lax.stop_gradient(self.activation_quant(x_norm) - x_norm) 43 | w_quant = w + jax.lax.stop_gradient(self.weight_quant(w) - w) 44 | return jnp.dot(x_quant, w_quant) 45 | 46 | def rmsnorm(self, x): 47 | return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-5) 48 | 49 | def activation_quant(self, x): 50 | scale = 127.0 / jnp.max(jnp.abs(x), axis=-1, keepdims=True).clip(min=1e-5) 51 | y = jnp.round(x * scale).clip(-128, 127) / scale 52 | return y 53 | 54 | def weight_quant(self, w): 55 | scale = 1.0 / jnp.mean(jnp.abs(w)).clip(min=1e-5) 56 | u = jnp.round(w * scale).clip(-1, 1) / scale 57 | return u 58 | 59 | def fused_activation_norm_quant(self, x): 60 | x_norm = self.rmsnorm(x) 61 | scale = 127.0 / jnp.max(jnp.abs(x_norm), axis=-1, keepdims=True).clip(min=1e-5) 62 | x_quant = jnp.round(x_norm * scale).clip(-128, 127) / scale 63 | return x_quant, scale 64 | 65 | def inference_weight_quant(self, w): 66 | scale = jnp.abs(w).mean().clip(min=1e-5) 67 | u = jnp.sign(w - w.mean()) * scale 68 | return u, scale 69 | 70 | # Help: how to implement lowbit matmul kernel for efficiency that can be integrated into Flax model 71 | def inference_lowbit_matmul(self, x, w): 72 | return jnp.dot(x, w) 73 | -------------------------------------------------------------------------------- /nanodl/__src/experimental/gat.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import linen as nn 4 | 5 | 6 | class GraphAttentionLayer(nn.Module): 7 | """ 8 | A single graph attention layer as part of a Graph Attention Network (GAT). 9 | 10 | This layer applies a self-attention mechanism on the nodes of a graph. Each node's features are transformed through a learned linear transformation, and attention coefficients are computed to determine the importance of every other node's features. This allows the model to dynamically adjust which nodes contribute most to the next layer's input for each node. 11 | 12 | Attributes: 13 | in_features (int): Number of input features per node. 14 | out_features (int): Number of output features per node. 15 | dropout_rate (float): Dropout rate applied to features and attention coefficients for regularization. 16 | alpha (float): Negative slope coefficient for the LeakyReLU activation function used in computing attention scores. 17 | concat (bool, optional): Whether to concatenate the output of attention heads in a multi-head attention mechanism. Default is True. 18 | 19 | Methods: 20 | __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool) -> jnp.ndarray: 21 | Forward pass of the graph attention layer. 22 | 23 | Args: 24 | x (jnp.ndarray): The input node features, shape (N, in_features), where N is the number of nodes. 25 | adj (jnp.ndarray): The adjacency matrix of the graph, shape (N, N), indicating node connections. 26 | training (bool): Whether the layer is being used in training mode. Affects dropout behavior. 
27 | 28 | Returns: 29 | jnp.ndarray: The output node features after the attention mechanism. If `concat` is True, applies a non-linearity (LeakyReLU); otherwise, returns the linear combination of features directly. Shape is (N, out_features). 30 | """ 31 | 32 | in_features: int 33 | out_features: int 34 | dropout_rate: float 35 | alpha: float 36 | concat: bool = True 37 | 38 | @nn.compact 39 | def __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool) -> jnp.ndarray: 40 | 41 | W = self.param( 42 | "W", 43 | jax.nn.initializers.glorot_uniform(), 44 | (self.in_features, self.out_features), 45 | ) 46 | 47 | a = self.param( 48 | "a", jax.nn.initializers.glorot_uniform(), (2 * self.out_features, 1) 49 | ) 50 | 51 | h = jnp.dot(x, W) 52 | h = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(h) 53 | 54 | N = h.shape[0] 55 | a_input = jnp.concatenate( 56 | [h[:, None, :].repeat(N, axis=1), h[None, :, :].repeat(N, axis=0)], axis=2 57 | ) 58 | 59 | e = nn.leaky_relu(jnp.dot(a_input, a).squeeze(-1), negative_slope=self.alpha) 60 | 61 | zero_vec = -9e15 * jnp.ones_like(e) 62 | attention = jnp.where(adj > 0, e, zero_vec) 63 | attention = nn.softmax(attention, axis=1) 64 | 65 | attention = nn.Dropout(rate=self.dropout_rate, deterministic=not training)( 66 | attention 67 | ) 68 | 69 | h_prime = jnp.matmul(attention, h) 70 | 71 | if self.concat: 72 | return nn.leaky_relu(h_prime) 73 | else: 74 | return h_prime 75 | 76 | 77 | class GAT(nn.Module): 78 | """ 79 | Graph Attention Networks (GATs) are a type of neural network designed for graph-structured data. 80 | The key feature of GATs is the use of attention mechanisms to weigh the importance of nodes' neighbors. 81 | This allows GATs to focus on the most relevant parts of the graph structure when learning node representations. 82 | In GATs, each node aggregates information from its neighbors, but not all neighbors contribute equally. 83 | The attention mechanism computes weights that determine the importance of each neighbor's features to the target node. 84 | These weights are learned during training and are based on the features of the nodes involved. 85 | GATs can handle graphs with varying sizes and connectivity patterns, making them suitable for a wide range of applications, 86 | including social network analysis, recommendation systems, and molecular structure analysis. 87 | 88 | Example usage: 89 | ``` 90 | import jax 91 | import jax.numpy as jnp 92 | from nanodl import ArrayDataset, DataLoader 93 | from nanodl import GAT 94 | 95 | # Generate dummy data 96 | batch_size = 8 97 | max_length = 10 98 | nclass = 3 99 | 100 | # Replace with actual tokenised data 101 | # Generate a random key for Jax 102 | key = jax.random.PRNGKey(0) 103 | num_nodes = 10 104 | num_features = 5 105 | x = jax.random.normal(key, (num_nodes, num_features)) # Features for each node 106 | adj = jax.random.bernoulli(key, 0.3, (num_nodes, num_nodes)) # Random adjacency matrix 107 | 108 | # Initialize the GAT model 109 | model = GAT(nfeat=num_features, 110 | nhid=8, 111 | nclass=nclass, 112 | dropout_rate=0.5, 113 | alpha=0.2, 114 | nheads=3) 115 | 116 | # Initialize the model parameters 117 | params = model.init(key, x, adj) 118 | output = model.apply(params, x, adj) 119 | print("Output shape:", output.shape) 120 | ``` 121 | 122 | Attributes: 123 | nfeat (int): Number of features for each node in the input graph. 124 | nhid (int): Number of hidden units in each graph attention layer. 125 | nclass (int): Number of classes for the node classification task. 
126 | dropout_rate (float): Dropout rate for regularization. Applied to the inputs of each graph attention layer and the final output. 127 | alpha (float): LeakyReLU angle of negative slope used in the attention mechanism. 128 | nheads (int): Number of attention heads. Each head computes a separate attention mechanism over the input, and their results are concatenated. 129 | 130 | Methods: 131 | __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool = False) -> jnp.ndarray: 132 | Forward pass of the GAT model. 133 | 134 | Args: 135 | x (jnp.ndarray): Node features matrix with shape (N, nfeat), where N is the number of nodes in the graph. 136 | adj (jnp.ndarray): Adjacency matrix of the graph with shape (N, N). It should represent the graph structure. 137 | training (bool, optional): Flag to indicate whether the model is being used for training. Affects dropout behavior. Defaults to False. 138 | 139 | Returns: 140 | jnp.ndarray: The output node features after passing through the GAT model. Shape is (N, nclass), representing the class scores for each node. 141 | """ 142 | 143 | nfeat: int 144 | nhid: int 145 | nclass: int 146 | dropout_rate: float 147 | alpha: float 148 | nheads: int 149 | 150 | @nn.compact 151 | def __call__( 152 | self, x: jnp.ndarray, adj: jnp.ndarray, training: bool = False 153 | ) -> jnp.ndarray: 154 | 155 | heads = [ 156 | GraphAttentionLayer( 157 | self.nfeat, 158 | self.nhid, 159 | dropout_rate=self.dropout_rate, 160 | alpha=self.alpha, 161 | concat=True, 162 | ) 163 | for _ in range(self.nheads) 164 | ] 165 | 166 | x = jnp.concatenate([head(x, adj, training) for head in heads], axis=1) 167 | 168 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) 169 | 170 | out_att = GraphAttentionLayer( 171 | self.nhid * self.nheads, 172 | self.nclass, 173 | dropout_rate=self.dropout_rate, 174 | alpha=self.alpha, 175 | concat=False, 176 | ) 177 | 178 | return out_att(x, adj, training) 179 | -------------------------------------------------------------------------------- /nanodl/__src/experimental/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer 5 | 6 | 7 | class Tokenizer: 8 | """ 9 | A tokenizer class that utilizes SentencePiece to encode and decode text. 10 | 11 | This class can be initialized with either an existing SentencePiece model 12 | or a dataset to train a new model. It provides methods to encode a string 13 | to a list of token ids and decode a list of token ids back to a string. 14 | 15 | Attributes: 16 | sp_model (SentencePieceProcessor): The SentencePiece processor. 17 | n_words (int): Number of words in the vocabulary. 18 | bos_id (int): Token id for the beginning of a sentence. 19 | eos_id (int): Token id for the end of a sentence. 20 | pad_id (int): Token id for padding. 21 | 22 | Example usage: 23 | 24 | Training a new model and encoding/decoding a string: 25 | 26 | ```python 27 | # Initialize tokenizer with training data and train a new model. 28 | text_paths = ['/Users/mac1/Desktop/nanodl/nanodl/__src/utils/sample.txt'] 29 | 30 | tokenizer = Tokenizer(training_data=text_paths, 31 | vocab_size=100, 32 | model_type='bpe', 33 | max_sentence_length=50) 34 | 35 | # Encode a sentence. 36 | encoded_sentence = tokenizer.encode('Hello, world!') 37 | print(f'Encoded: {encoded_sentence}') 38 | 39 | # Decode the encoded sentence. 
40 | decoded_sentence = tokenizer.decode(encoded_sentence) 41 | print(f'Decoded: {decoded_sentence}') 42 | ``` 43 | 44 | Loading an existing model and encoding/decoding a string: 45 | 46 | ```python 47 | # Initialize tokenizer with a pre-trained model. 48 | tokenizer = Tokenizer(model_path='path/to/model.model') 49 | 50 | # Encode a sentence. 51 | encoded_sentence = tokenizer.encode('Hello, world!') 52 | print(f'Encoded: {encoded_sentence}') 53 | 54 | # Decode the encoded sentence. 55 | decoded_sentence = tokenizer.decode(encoded_sentence) 56 | print(f'Decoded: {decoded_sentence}') 57 | ``` 58 | """ 59 | 60 | def __init__( 61 | self, 62 | training_data: List[str] = None, 63 | vocab_size: int = None, 64 | model_type: str = "bpe", 65 | max_sentence_length: int = 512, 66 | model_path: Optional[str] = None, 67 | ): 68 | 69 | if model_path and os.path.isfile(model_path): 70 | # Load an existing model 71 | self.sp_model = SentencePieceProcessor(model_file=model_path) 72 | elif training_data and all(os.path.isfile(f) for f in training_data): 73 | # Train a new model using a list of data files 74 | input_files = ",".join(training_data) 75 | model_prefix = "trained_model" 76 | SentencePieceTrainer.train( 77 | input=input_files, 78 | model_prefix=model_prefix, 79 | vocab_size=vocab_size, 80 | model_type=model_type, 81 | max_sentence_length=max_sentence_length, 82 | ) 83 | 84 | self.sp_model = SentencePieceProcessor(model_file=f"{model_prefix}.model") 85 | else: 86 | raise ValueError( 87 | "Must provide either a model_path or a non-empty training_data list" 88 | ) 89 | 90 | # Initialize token IDs 91 | self.n_words: int = self.sp_model.vocab_size() 92 | self.bos_id: int = self.sp_model.bos_id() 93 | self.eos_id: int = self.sp_model.eos_id() 94 | self.pad_id: int = self.sp_model.pad_id() 95 | 96 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 97 | 98 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: 99 | """Converts a string into a list of tokens.""" 100 | assert isinstance(s, str) 101 | t = self.sp_model.encode(s) 102 | if bos: 103 | t = [self.bos_id] + t 104 | if eos: 105 | t = t + [self.eos_id] 106 | return t 107 | 108 | def decode(self, t: List[int]) -> str: 109 | """Converts a list of tokens back into a string.""" 110 | return self.sp_model.decode(t) 111 | -------------------------------------------------------------------------------- /nanodl/__src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/models/__init__.py -------------------------------------------------------------------------------- /nanodl/__src/models/kan.py: -------------------------------------------------------------------------------- 1 | from flax import linen as nn 2 | import jax.numpy as jnp 3 | import jax 4 | from jax import random 5 | from typing import Any 6 | 7 | 8 | class KANLinear(nn.Module): 9 | """ 10 | A Flax module implementing a B-spline Neural Network layer, where the basis functions are B-splines. 11 | 12 | Attributes: 13 | in_features (int): Number of input features. 14 | out_features (int): Number of output features. 15 | degree (int): Degree of the B-splines. 
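Example usage (a minimal sketch; the batch of 8 four-dimensional inputs and degree=3 are illustrative, and in_features is kept equal to degree + 1 so the number of B-spline basis functions matches the coefficient shape):
    import jax
    layer = KANLinear(in_features=4, out_features=3, degree=3)
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
    params = layer.init(jax.random.PRNGKey(1), x)
    y = layer.apply(params, x)  # shape (8, 3)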
16 | """ 17 | in_features: int 18 | out_features: int 19 | degree: int 20 | 21 | def setup(self) -> None: 22 | assert self.degree > 0, "Degree of the B-splines must be greater than 0" 23 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) 24 | self.coefficients = self.param( 25 | "coefficients", 26 | lambda key, shape: mean + std * random.normal(key, shape), 27 | (self.in_features, self.out_features, self.degree + 1) 28 | ) 29 | 30 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 31 | x = jnp.tanh(x) 32 | 33 | knots = jnp.linspace(-1, 1, self.degree + self.in_features + 1) 34 | 35 | b_spline_values = jnp.array([ 36 | self.bspline_basis(x[:, i], self.degree, knots) for i in range(self.in_features) 37 | ]) 38 | 39 | b_spline_values = b_spline_values.transpose((1, 0, 2)) 40 | 41 | output = jnp.einsum('bid,ijd->bj', b_spline_values, self.coefficients) 42 | return output 43 | 44 | def bspline_basis(self, x: jnp.ndarray, degree: int, knots: jnp.ndarray) -> jnp.ndarray: 45 | 46 | def cox_de_boor(x, k, i, t): 47 | if k == 0: 48 | return jnp.where((t[i] <= x) & (x < t[i + 1]), 1.0, 0.0) 49 | else: 50 | denom1 = t[i + k] - t[i] 51 | denom2 = t[i + k + 1] - t[i + 1] 52 | 53 | term1 = jnp.where(denom1 != 0, (x - t[i]) / denom1, 0.0) * cox_de_boor(x, k - 1, i, t) 54 | term2 = jnp.where(denom2 != 0, (t[i + k + 1] - x) / denom2, 0.0) * cox_de_boor(x, k - 1, i + 1, t) 55 | 56 | return term1 + term2 57 | 58 | n_basis = len(knots) - degree - 1 59 | basis = jnp.array([cox_de_boor(x, degree, i, knots) for i in range(n_basis)]) 60 | return jnp.transpose(basis) 61 | 62 | 63 | class ChebyKANLinear(nn.Module): 64 | """ 65 | A Flax module implementing a Chebyshev Neural Network layer, where the basis functions are Chebyshev polynomials. 66 | 67 | Inspired by https://github.com/CG80499/KAN-GPT-2/blob/master/chebykan_layer.py 68 | 69 | Attributes: 70 | in_features (int): Number of input features. 71 | out_features (int): Number of output features. 72 | degree (int): Degree of the Chebyshev polynomials. 73 | """ 74 | in_features: int 75 | out_features: int 76 | degree: int 77 | 78 | def setup(self) -> None: 79 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) 80 | self.coefficients = self.param( 81 | "coefficients", 82 | lambda key, shape: mean + std * random.normal(key, shape), 83 | (self.in_features, self.out_features, self.degree + 1) 84 | ) 85 | 86 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 87 | 88 | x = jnp.tanh(x) 89 | 90 | cheby_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) 91 | cheby_values = cheby_values.at[:, :, 1].set(x) 92 | 93 | for i in range(2, self.degree + 1): 94 | next_value = 2 * x * cheby_values[:, :, i - 1] - cheby_values[:, :, i - 2] 95 | cheby_values = cheby_values.at[:, :, i].set(next_value) 96 | 97 | output = jnp.einsum('bid,ijd->bj', cheby_values, self.coefficients) 98 | return output 99 | 100 | 101 | class LegendreKANLinear(nn.Module): 102 | """ 103 | A Flax module implementing a Legendre Neural Network layer, where the basis functions are Legendre polynomials. 104 | 105 | Attributes: 106 | in_features (int): Number of input features. 107 | out_features (int): Number of output features. 108 | degree (int): Degree of the Legendre polynomials. 
109 | """ 110 | in_features: int 111 | out_features: int 112 | degree: int 113 | 114 | def setup(self) -> None: 115 | assert self.degree > 0, "Degree of the Legendre polynomials must be greater than 0" 116 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) 117 | self.coefficients = self.param( 118 | "coefficients", 119 | lambda key, shape: mean + std * random.normal(key, shape), 120 | (self.in_features, self.out_features, self.degree + 1) 121 | ) 122 | 123 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 124 | x = jnp.tanh(x) 125 | 126 | legendre_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) 127 | legendre_values = legendre_values.at[:, :, 1].set(x) 128 | 129 | for i in range(2, self.degree + 1): 130 | next_value = ((2 * i - 1) * x * legendre_values[:, :, i - 1] - (i - 1) * legendre_values[:, :, i - 2]) / i 131 | legendre_values = legendre_values.at[:, :, i].set(next_value) 132 | 133 | output = jnp.einsum('bid,ijd->bj', legendre_values, self.coefficients) 134 | return output 135 | 136 | 137 | class MonomialKANLinear(nn.Module): 138 | """ 139 | A Flax module implementing a Monomial Neural Network layer, where the basis functions are monomials. 140 | 141 | Attributes: 142 | in_features (int): Number of input features. 143 | out_features (int): Number of output features. 144 | degree (int): Degree of the monomial basis functions. 145 | """ 146 | in_features: int 147 | out_features: int 148 | degree: int 149 | 150 | def setup(self) -> None: 151 | assert self.degree > 0, "Degree of the monomial basis functions must be greater than 0" 152 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) 153 | self.coefficients = self.param( 154 | "coefficients", 155 | lambda key, shape: mean + std * random.normal(key, shape), 156 | (self.in_features, self.out_features, self.degree + 1) 157 | ) 158 | 159 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 160 | x = jnp.tanh(x) 161 | monomial_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) 162 | 163 | for i in range(1, self.degree + 1): 164 | monomial_values = monomial_values.at[:, :, i].set(x ** i) 165 | 166 | output = jnp.einsum('bid,ijd->bj', monomial_values, self.coefficients) 167 | return output 168 | 169 | 170 | class FourierKANLinear(nn.Module): 171 | """ 172 | A Flax module implementing a Fourier Neural Network layer, where the basis functions are sine and cosine functions. 173 | 174 | Attributes: 175 | in_features (int): Number of input features. 176 | out_features (int): Number of output features. 177 | degree (int): Degree of the Fourier series (i.e., number of harmonics). 
178 | """ 179 | in_features: int 180 | out_features: int 181 | degree: int 182 | 183 | def setup(self) -> None: 184 | assert self.degree > 0, "Degree of the Fourier series must be greater than 0" 185 | mean, std = 0.0, 1 / (self.in_features * (2 * self.degree + 1)) 186 | self.sine_coefficients = self.param( 187 | "sine_coefficients", 188 | lambda key, shape: mean + std * random.normal(key, shape), 189 | (self.in_features, self.out_features, self.degree) 190 | ) 191 | self.cosine_coefficients = self.param( 192 | "cosine_coefficients", 193 | lambda key, shape: mean + std * random.normal(key, shape), 194 | (self.in_features, self.out_features, self.degree) 195 | ) 196 | 197 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 198 | x = jnp.tanh(x) 199 | sine_values = jnp.sin(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None]) 200 | cosine_values = jnp.cos(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None]) 201 | 202 | output = (jnp.einsum('bid,ijd->bj', sine_values, self.sine_coefficients) + 203 | jnp.einsum('bid,ijd->bj', cosine_values, self.cosine_coefficients)) 204 | return output 205 | 206 | 207 | class HermiteKANLinear(nn.Module): 208 | """ 209 | A Flax module implementing a Hermite Neural Network layer, where the basis functions are Hermite polynomials. 210 | 211 | Attributes: 212 | in_features (int): Number of input features. 213 | out_features (int): Number of output features. 214 | degree (int): Degree of the Hermite polynomials. 215 | """ 216 | in_features: int 217 | out_features: int 218 | degree: int 219 | 220 | def setup(self) -> None: 221 | assert self.degree > 0, "Degree of the Hermite polynomials must be greater than 0" 222 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) 223 | self.coefficients = self.param( 224 | "coefficients", 225 | lambda key, shape: mean + std * random.normal(key, shape), 226 | (self.in_features, self.out_features, self.degree + 1) 227 | ) 228 | 229 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 230 | x = jnp.tanh(x) 231 | 232 | hermite_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) 233 | 234 | if self.degree >= 1: 235 | hermite_values = hermite_values.at[:, :, 1].set(2 * x) 236 | for i in range(2, self.degree + 1): 237 | hermite_values = hermite_values.at[:, :, i].set( 238 | 2 * x * hermite_values[:, :, i - 1] - 2 * (i - 1) * hermite_values[:, :, i - 2] 239 | ) 240 | 241 | output = jnp.einsum('bid,ijd->bj', hermite_values, self.coefficients) 242 | return output 243 | -------------------------------------------------------------------------------- /nanodl/__src/models/reward.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Iterable, Optional, Tuple 3 | 4 | import flax 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import optax 9 | from flax.training import train_state 10 | 11 | 12 | class RewardModel(nn.Module): 13 | """ 14 | The RewardModel estimates the reward or value of a given input sequence, 15 | typically used in reinforcement learning frameworks for natural language processing tasks. 16 | It uses the last hidden state of a transformer-based model to generate a scalar reward prediction, 17 | guiding the agent's behavior by evaluating the desirability or utility of its generated outputs. 18 | 19 | Args: 20 | model (nn.Module): The neural network model to be used. 21 | dim (int): The dimension of the input data. 22 | dropout (float): The dropout rate for the model, a value between 0 and 1. 
23 | 24 | Example: 25 | ```python 26 | from nanodl import ArrayDataset, DataLoader 27 | from nanodl import Gemma, RewardModel, RewardDataParallelTrainer 28 | 29 | # Generate dummy data 30 | batch_size = 8 31 | max_length = 10 32 | 33 | # Replace with actual tokenised data 34 | dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) 35 | dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) 36 | 37 | # Create dataset and dataloader 38 | dataset = ArrayDataset(dummy_chosen, dummy_rejected) 39 | dataloader = DataLoader(dataset, 40 | batch_size=batch_size, 41 | shuffle=True, 42 | drop_last=False) 43 | 44 | # model parameters 45 | hyperparams = { 46 | 'num_layers': 1, 47 | 'hidden_dim': 256, 48 | 'num_heads': 2, 49 | 'feedforward_dim': 256, 50 | 'dropout': 0.1, 51 | 'vocab_size': 1000, 52 | 'embed_dim': 256, 53 | 'max_length': max_length, 54 | 'start_token': 0, 55 | 'end_token': 50, 56 | 'num_groups': 2, 57 | } 58 | 59 | # Initialize reward model from Gemma 60 | model = Gemma(**hyperparams) 61 | reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) 62 | 63 | # Train the reward model 64 | trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') 65 | trainer.train(dataloader, 5, dataloader) 66 | params = trainer.load_params('reward_model_weights.pkl') 67 | 68 | # Call as you would a regular Flax model 69 | rngs = jax.random.PRNGKey(0) 70 | rngs, dropout_rng = jax.random.split(rngs) 71 | rewards = reward_model.apply({'params': params}, 72 | dummy_chosen, 73 | rngs={'dropout': dropout_rng}) 74 | 75 | print(rewards.shape) 76 | ``` 77 | """ 78 | 79 | model: nn.Module 80 | dim: int 81 | dropout: float 82 | 83 | @nn.compact 84 | def __call__(self, x: jnp.ndarray, training: bool = False): 85 | 86 | x = self.model(x, training=training, drop_last_layer=True) 87 | x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) 88 | x = nn.Dense(1)(x) 89 | return nn.sigmoid(x)[:, -1, 0] 90 | 91 | 92 | class RewardDataParallelTrainer: 93 | """ 94 | Trainer class using data parallelism with JAX. 95 | This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). 96 | It handles the model training loop, including gradient computation, parameter updates, and evaluation. 97 | 98 | Attributes: 99 | model (Any): The model to be trained. 100 | input_shape (Tuple[int, ...]): The shape of the input tensor. 101 | weights_filename (str): Filename where the trained model weights will be saved. 102 | learning_rate (float): Learning rate for the optimizer. 103 | params_path (Optional[str]): Path to pre-trained reward model parameters for initializing the REWARD model, if available. 104 | model_params_path (Optional[str]): Path to pre-trained backbone model parameters for initializing the BACKBONE model, if available. 105 | 106 | Methods: 107 | create_train_state(learning_rate, text_input_shape, image_input_shape): Initializes the training state, including parameters and optimizer. 108 | train_step(state, texts, images): Performs a single training step, including forward pass, loss computation, and gradients update. 109 | train(train_loader, num_epochs, val_loader): Runs the training loop over the specified number of epochs, using the provided data loaders for training and validation. 110 | evaluation_step(state, texts, images): Performs an evaluation step, computing forward pass and loss without updating model parameters. 111 | evaluate(test_loader): Evaluates the model performance on a test dataset. 
112 | save_params(): Saves the model parameters to a file. 113 | load_params(filename): Loads model parameters from a file. 114 | """ 115 | 116 | def __init__( 117 | self, 118 | model: Any, 119 | input_shape: Tuple[int, ...], 120 | weights_filename: str, 121 | learning_rate: float = 1e-5, 122 | params_path: Optional[str] = None, 123 | model_params_path: Optional[str] = None, 124 | ) -> None: 125 | 126 | self.model = model 127 | self.params = None 128 | self.params_path = params_path 129 | self.model_params_path = model_params_path 130 | self.num_parameters = None 131 | self.best_val_loss = float("inf") 132 | self.weights_filename = weights_filename 133 | self.num_devices = jax.local_device_count() 134 | self.train_step = jax.pmap( 135 | RewardDataParallelTrainer.train_step, axis_name="devices" 136 | ) 137 | self.evaluation_step = jax.pmap( 138 | RewardDataParallelTrainer.evaluation_step, axis_name="devices" 139 | ) 140 | self.state = self.create_train_state(learning_rate, input_shape) 141 | print(f"Number of accelerators: {self.num_devices}") 142 | 143 | def create_train_state( 144 | self, learning_rate: float, input_shape: Tuple[int, ...] 145 | ) -> Any: 146 | 147 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} 148 | params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] 149 | 150 | if self.params_path is not None: 151 | params = self.load_params(self.params_path) 152 | 153 | if self.model_params_path is not None: 154 | model_params = self.load_params(self.model_params_path) 155 | params = self.merge_params(model_params, params) 156 | 157 | self.num_parameters = sum( 158 | param.size for param in jax.tree_util.tree_leaves(params) 159 | ) 160 | print(f"Number of parameters: {self.num_parameters}") 161 | state = train_state.TrainState.create( 162 | apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) 163 | ) 164 | return jax.device_put_replicated(state, jax.local_devices()) 165 | 166 | @staticmethod 167 | def train_step( 168 | state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray 169 | ) -> Tuple[Any, jnp.ndarray]: 170 | 171 | def loss_fn(params): 172 | chosen_rewards = state.apply_fn( 173 | {"params": params}, 174 | chosen, 175 | training=True, 176 | rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, 177 | ) 178 | 179 | rejected_rewards = state.apply_fn( 180 | {"params": params}, 181 | rejected, 182 | training=True, 183 | rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, 184 | ) 185 | 186 | return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() 187 | 188 | loss, grads = jax.value_and_grad(loss_fn)(state.params) 189 | state = state.apply_gradients(grads=grads) 190 | return state, loss 191 | 192 | def train( 193 | self, 194 | train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], 195 | num_epochs: int, 196 | val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, 197 | ) -> None: 198 | 199 | for epoch in range(num_epochs): 200 | total_loss = 0.0 201 | count = 0 202 | for chosen, rejected in train_loader: 203 | batch_size = chosen.shape[0] 204 | batch_size_per_device = batch_size // self.num_devices 205 | chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1)) 206 | rejected = rejected.reshape( 207 | (self.num_devices, batch_size_per_device, -1) 208 | ) 209 | self.state, loss = self.train_step( 210 | state=self.state, chosen=chosen, rejected=rejected 211 | ) 212 | total_loss += jnp.mean(loss) 213 | count += 1 214 | 215 | mean_loss = total_loss / count 216 | 
print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") 217 | 218 | if val_loader is not None: 219 | val_loss = self.evaluate(val_loader) 220 | print(f"Epoch {epoch+1}, Val Loss: {val_loss}") 221 | if val_loss < self.best_val_loss: 222 | self.best_val_loss = val_loss 223 | print("New best validation score achieved, saving model...") 224 | self.save_params() 225 | return 226 | 227 | @staticmethod 228 | def evaluation_step( 229 | state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray 230 | ) -> Tuple[Any, jnp.ndarray]: 231 | chosen_rewards = state.apply_fn( 232 | {"params": state.params}, chosen, rngs={"dropout": jax.random.PRNGKey(2)} 233 | ) 234 | rejected_rewards = state.apply_fn( 235 | {"params": state.params}, rejected, rngs={"dropout": jax.random.PRNGKey(2)} 236 | ) 237 | return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() 238 | 239 | def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: 240 | 241 | total_loss = 0.0 242 | count = 0 243 | for chosen, rejected in test_loader: 244 | batch_size = chosen.shape[0] 245 | batch_size_per_device = batch_size // self.num_devices 246 | chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1)) 247 | rejected = rejected.reshape((self.num_devices, batch_size_per_device, -1)) 248 | loss = self.evaluation_step(self.state, chosen, rejected) 249 | total_loss += jnp.mean(loss) 250 | count += 1 251 | 252 | mean_loss = total_loss / count 253 | return mean_loss 254 | 255 | def merge_params(untrained_params, trained_params): 256 | updated_untrained_params = jax.tree_map( 257 | lambda untrained, trained: ( 258 | trained if untrained.shape == trained.shape else untrained 259 | ), 260 | untrained_params, 261 | trained_params, 262 | ) 263 | return updated_untrained_params 264 | 265 | def save_params(self) -> None: 266 | self.params = flax.jax_utils.unreplicate(self.state.params) 267 | with open(self.weights_filename, "wb") as f: 268 | f.write(flax.serialization.to_bytes(self.params)) 269 | 270 | def load_params(self, filename: str): 271 | with open(filename, "rb") as f: 272 | self.params = flax.serialization.from_bytes(self.params, f.read()) 273 | return self.params 274 | -------------------------------------------------------------------------------- /nanodl/__src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/utils/__init__.py -------------------------------------------------------------------------------- /nanodl/__src/utils/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from dataclasses import dataclass 3 | from typing import Iterator 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | # This script modifies the JAX DataLoader from the following repository: 9 | # JAX DataLoader by Birkhoff G. (https://birkhoffg.github.io/jax-dataloader/) 10 | # Accessed on [Date you accessed the repository, e.g., February 4, 2024] 11 | # This DataLoader implementation is used for efficient data loading in JAX-based machine learning projects. 12 | 13 | 14 | class Dataset: 15 | """ 16 | A PyTorch-like Dataset class for JAX. 17 | 18 | This is a base class for creating datasets in JAX. Subclasses should implement 19 | the `__len__` method to return the size of the dataset and the `__getitem__` 20 | method to return a data item at a given index. 
21 | 22 | Example usage: 23 | ``` 24 | >>> class MyDataset(Dataset): 25 | ... def __init__(self, data): 26 | ... self.data = data 27 | ... def __len__(self): 28 | ... return len(self.data) 29 | ... def __getitem__(self, index): 30 | ... return self.data[index] 31 | >>> dataset = MyDataset(jnp.arange(10)) 32 | >>> print(len(dataset)) 33 | >>> print(dataset[5]) 34 | ``` 35 | """ 36 | 37 | def __len__(self): 38 | raise NotImplementedError 39 | 40 | def __getitem__(self, index): 41 | raise NotImplementedError 42 | 43 | 44 | class ArrayDataset(Dataset): 45 | """ 46 | Dataset wrapping JAX numpy arrays. 47 | 48 | This class wraps multiple JAX numpy arrays into a dataset. Each array represents 49 | a different modality of the data (e.g., features and labels). All arrays must 50 | have the same first dimension (number of samples). 51 | 52 | Args: 53 | *arrays (jnp.array): Variable number of JAX numpy arrays to include in the dataset. 54 | 55 | Example usage: 56 | ``` 57 | >>> dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) 58 | >>> print(len(dataset)) 59 | >>> print(dataset[1]) 60 | ``` 61 | """ 62 | 63 | def __init__(self, *arrays: jnp.array): 64 | assert all( 65 | arrays[0].shape[0] == arr.shape[0] for arr in arrays 66 | ), "All arrays must have the same first dimension." 67 | self.arrays = arrays 68 | 69 | def __len__(self): 70 | return self.arrays[0].shape[0] 71 | 72 | def __getitem__(self, index): 73 | return tuple(arr[index] for arr in self.arrays) 74 | 75 | 76 | class DataLoader: 77 | """ 78 | DataLoader in Vanilla Jax. 79 | 80 | This class provides a way to iterate over batches of data from a given dataset. 81 | It supports batch processing, shuffling, and dropping the last batch if it's 82 | smaller than the specified batch size. 83 | 84 | Args: 85 | dataset (Dataset): The dataset from which to load the data. 86 | batch_size (int, optional): Number of samples per batch. Default is 1. 87 | shuffle (bool, optional): Whether to shuffle the data. Default is False. 88 | drop_last (bool, optional): Whether to drop the last incomplete batch. 89 | Default is False. 90 | 91 | Example usage: 92 | ``` 93 | >>> dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256))) 94 | >>> dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False) 95 | >>> for batch in dataloader: 96 | ... 
print(batch.shape) 97 | ``` 98 | """ 99 | 100 | def __init__( 101 | self, 102 | dataset: Dataset, 103 | batch_size: int = 1, 104 | shuffle: bool = False, 105 | drop_last: bool = False, 106 | **kwargs 107 | ): 108 | self.dataset = dataset 109 | self.batch_size = batch_size 110 | self.shuffle = shuffle 111 | self.drop_last = drop_last 112 | 113 | self.keys = _PRNGSequence(seed=Config.default().global_seed) 114 | self.data_len = len(dataset) # Length of the dataset 115 | self.indices = jnp.arange(self.data_len) # available indices in the dataset 116 | self.pose = 0 # record the current position in the dataset 117 | self._shuffle() 118 | 119 | def _shuffle(self): 120 | if self.shuffle: 121 | self.indices = jax.random.permutation(next(self.keys), self.indices) 122 | 123 | def _stop_iteration(self): 124 | self.pose = 0 125 | self._shuffle() 126 | raise StopIteration 127 | 128 | def __len__(self): 129 | if self.drop_last: 130 | batches = len(self.dataset) // self.batch_size # get the floor of division 131 | else: 132 | batches = -( 133 | len(self.dataset) // -self.batch_size 134 | ) # get the ceil of division 135 | return batches 136 | 137 | def __next__(self): 138 | if self.pose + self.batch_size <= self.data_len: 139 | batch_indices = self.indices[self.pose : self.pose + self.batch_size] 140 | batch_data = self.dataset[batch_indices] 141 | self.pose += self.batch_size 142 | return batch_data 143 | elif self.pose < self.data_len and not self.drop_last: 144 | batch_indices = self.indices[self.pose :] 145 | batch_data = self.dataset[batch_indices] 146 | self.pose += self.batch_size 147 | return batch_data 148 | else: 149 | self._stop_iteration() 150 | 151 | def __iter__(self): 152 | return self 153 | 154 | 155 | @dataclass 156 | class Config: 157 | rng_reserve_size: int 158 | global_seed: int 159 | 160 | @classmethod 161 | def default(cls): 162 | return cls(rng_reserve_size=1, global_seed=42) 163 | 164 | 165 | class _PRNGSequence(Iterator[jax.random.PRNGKey]): 166 | """ 167 | An Iterator of Jax PRNGKey (minimal version of `haiku.PRNGSequence`). 168 | 169 | This class provides an iterator over PRNG keys generated from a seed. It is useful 170 | for generating random numbers in a reproducible way. 171 | 172 | Args: 173 | seed (int): Seed for generating the initial PRNG key. 174 | 175 | Example usage: 176 | ``` 177 | >>> prng_seq = PRNGSequence(42) 178 | >>> key = next(prng_seq) 179 | ``` 180 | """ 181 | 182 | def __init__(self, seed: int): 183 | self._key = jax.random.PRNGKey(seed) 184 | self._subkeys = collections.deque() 185 | 186 | def reserve(self, num): 187 | if num > 0: 188 | new_keys = tuple(jax.random.split(self._key, num + 1)) 189 | self._key = new_keys[0] 190 | self._subkeys.extend(new_keys[1:]) 191 | 192 | def __next__(self): 193 | if not self._subkeys: 194 | self.reserve(Config.default().rng_reserve_size) 195 | return self._subkeys.popleft() 196 | -------------------------------------------------------------------------------- /nanodl/__src/utils/ml.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | @jax.jit 8 | def batch_cosine_similarities( 9 | source: jnp.ndarray, candidates: jnp.ndarray 10 | ) -> jnp.ndarray: 11 | """ 12 | Calculate cosine similarities between a source vector and a batch of candidate vectors. 13 | 14 | Args: 15 | source (jnp.ndarray): Source vector of shape (D,). 
16 | candidates (jnp.ndarray): Batch of candidate vectors of shape (N, D), where N is the number of candidates. 17 | 18 | Returns: 19 | jnp.ndarray: Array of cosine similarity scores of shape (N,). 20 | 21 | Example usage: 22 | ``` 23 | >>> source = jnp.array([1, 0, 0]) 24 | >>> candidates = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 25 | >>> similarities = batch_cosine_similarities(source, candidates) 26 | >>> print(similarities) 27 | ``` 28 | """ 29 | dot_products = jnp.einsum("ij,j->i", candidates, source) 30 | norm_source = jnp.sqrt(jnp.einsum("i,i->", source, source)) 31 | norm_candidates = jnp.sqrt(jnp.einsum("ij,ij->i", candidates, candidates)) 32 | return dot_products / (norm_source * norm_candidates) 33 | 34 | 35 | @jax.jit 36 | def batch_pearsonr(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: 37 | """ 38 | Calculate batch Pearson correlation coefficient between two sets of vectors. 39 | 40 | Args: 41 | x (jnp.ndarray): First set of vectors of shape (N, D), where N is the number of vectors. 42 | y (jnp.ndarray): Second set of vectors of shape (N, D). 43 | 44 | Returns: 45 | jnp.ndarray: Array of Pearson correlation coefficients of shape (N,). 46 | 47 | Example usage: 48 | ``` 49 | >>> x = jnp.array([[1, 2, 3], [4, 5, 6]]) 50 | >>> y = jnp.array([[1, 5, 7], [2, 6, 8]]) 51 | >>> correlations = batch_pearsonr(x, y) 52 | >>> print(correlations) 53 | ``` 54 | """ 55 | x = jnp.asarray(x).T 56 | y = jnp.asarray(y).T 57 | x = x - jnp.expand_dims(x.mean(axis=1), axis=-1) 58 | y = y - jnp.expand_dims(y.mean(axis=1), axis=-1) 59 | numerator = jnp.sum(x * y, axis=-1) 60 | sum_of_squares_x = jnp.einsum("ij,ij -> i", x, x) 61 | sum_of_squares_y = jnp.einsum("ij,ij -> i", y, y) 62 | denominator = jnp.sqrt(sum_of_squares_x * sum_of_squares_y) 63 | return numerator / denominator 64 | 65 | 66 | @jax.jit 67 | def classification_scores(labels: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarray: 68 | """ 69 | Calculate classification evaluation scores using JAX. 70 | 71 | Args: 72 | labels (jnp.ndarray): Array of true labels. 73 | preds (jnp.ndarray): Array of predicted labels. 74 | 75 | Returns: 76 | jnp.ndarray: Array containing accuracy, precision, recall, and F1-score. 77 | 78 | Example usage: 79 | ``` 80 | >>> labels = jnp.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) 81 | >>> preds = jnp.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0]) 82 | >>> print(classification_scores(labels, preds)) 83 | ``` 84 | """ 85 | true_positives = jnp.sum(jnp.logical_and(preds == 1, labels == 1)) 86 | true_negatives = jnp.sum(jnp.logical_and(preds == 0, labels == 0)) 87 | false_positives = jnp.sum(jnp.logical_and(preds == 1, labels == 0)) 88 | false_negatives = jnp.sum(jnp.logical_and(preds == 0, labels == 1)) 89 | 90 | accuracy = (true_positives + true_negatives) / len(preds) 91 | precision = true_positives / (true_positives + false_positives) 92 | recall = true_positives / (true_positives + false_negatives) 93 | f1 = 2 * (precision * recall) / (precision + recall) 94 | return jnp.array([accuracy, precision, recall, f1]) 95 | 96 | 97 | @jax.jit 98 | def mean_reciprocal_rank(predictions: jnp.ndarray) -> float: 99 | """ 100 | Calculate the Mean Reciprocal Rank (MRR) for a list of ranked predictions using JAX. 
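MRR = (1/N) * sum(1 / rank_i), where rank_i is the 1-based position of the correct prediction in row i and the correct item is the entry holding the smallest value in that row.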
101 | 102 | Example usage: 103 | ``` 104 | predictions = jnp.array([ 105 | [0, 1, 2], # "correct" prediction at index 0 106 | [1, 0, 2], # "correct" prediction at index 1 107 | [2, 1, 0] # "correct" prediction at index 2 108 | ]) 109 | mrr_score = mean_reciprocal_rank(predictions) 110 | ``` 111 | 112 | Args: 113 | predictions (jnp.ndarray): 2D array where each row contains ranked predictions 114 | and the "correct" prediction is indicated by a specific index. 115 | 116 | Returns: 117 | float: Mean Reciprocal Rank (MRR) score. 118 | """ 119 | correct_indices = jnp.argmin(predictions, axis=1) 120 | ranks = correct_indices + 1 121 | reciprocal_ranks = 1.0 / ranks 122 | mean_mrr = jnp.mean(reciprocal_ranks) 123 | return mean_mrr 124 | 125 | 126 | def jaccard(sequence1: List, sequence2: List) -> float: 127 | """ 128 | Calculate Jaccard similarity between two sequences. 129 | 130 | Args: 131 | sequence1 (List): First input sequence. 132 | sequence2 (List): Second input sequence. 133 | 134 | Returns: 135 | float: Jaccard similarity score. 136 | 137 | Example usage: 138 | ```py 139 | >>> sequence1 = [1, 2, 3] 140 | >>> sequence2 = [2, 3, 4] 141 | >>> similarity = jaccard(sequence1, sequence2) 142 | >>> print(similarity) 143 | ``` 144 | """ 145 | numerator = len(set(sequence1).intersection(sequence2)) 146 | denominator = len(set(sequence1).union(sequence2)) 147 | return numerator / denominator 148 | 149 | 150 | @jax.jit 151 | def hamming(sequence1: jnp.ndarray, sequence2: jnp.ndarray) -> int: 152 | """ 153 | Calculate Hamming similarity between two sequences using JAX. 154 | 155 | Args: 156 | sequence1 (jnp.ndarray): First input sequence. 157 | sequence2 (jnp.ndarray): Second input sequence. 158 | 159 | Returns: 160 | int: Hamming similarity score. 161 | 162 | Example usage: 163 | ```py 164 | >>> sequence1 = jnp.array([1, 2, 3, 4]) 165 | >>> sequence2 = jnp.array([1, 2, 4, 4]) 166 | >>> similarity = hamming_jax(sequence1, sequence2) 167 | >>> print(similarity) 168 | ``` 169 | """ 170 | return jnp.sum(sequence1 == sequence2) 171 | 172 | 173 | def zero_pad_sequences(arr: jnp.array, max_length: int) -> jnp.array: 174 | """ 175 | Zero-pad the given array to the specified maximum length along axis=1. 176 | 177 | This function pads the input array with zeros along the second dimension (axis=1) 178 | until it reaches the specified maximum length. If the array is already longer 179 | than the maximum length, it is returned as is. 180 | 181 | Args: 182 | arr (jax.numpy.ndarray): The array to be padded. Must be 2-dimensional. 183 | max_length (int): The maximum length to pad the array to along axis=1. 184 | 185 | Returns: 186 | jax.numpy.ndarray: The zero-padded array. 187 | 188 | Example usage: 189 | ```py 190 | >>> arr = jnp.array([[1, 2, 3], [4, 5, 6]]) 191 | >>> max_length = 5 192 | >>> padded_arr = zero_pad_sequences(arr, max_length) 193 | >>> print(padded_arr) 194 | [[1 2 3 0 0] 195 | [4 5 6 0 0]] 196 | ``` 197 | """ 198 | current_length = arr.shape[1] 199 | num_zeros = max_length - current_length 200 | 201 | if num_zeros > 0: 202 | zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype) 203 | padded_array = jnp.concatenate([arr, zeros], axis=1) 204 | else: 205 | padded_array = arr 206 | 207 | return padded_array 208 | 209 | 210 | @jax.jit 211 | def entropy(probabilities: jnp.ndarray) -> float: 212 | """ 213 | Calculate the entropy of a probability distribution using JAX. 
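Entropy = -sum(p_i * log2(p_i)); base-2 logarithms are used below, so the result is in bits.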
214 | 215 | Example usage: 216 | ``` 217 | probabilities = jnp.array([0.25, 0.75]) 218 | entropy_value = entropy(probabilities) 219 | ``` 220 | 221 | Args: 222 | probabilities (jnp.ndarray): Array of probability values. 223 | 224 | Returns: 225 | float: Entropy value. 226 | """ 227 | log_probs = jnp.log2(probabilities) 228 | entropy_value = -jnp.sum(probabilities * log_probs) 229 | return entropy_value 230 | 231 | 232 | @jax.jit 233 | def gini_impurity(probabilities: jnp.ndarray) -> float: 234 | """ 235 | Calculate the Gini impurity of a probability distribution using JAX. 236 | 237 | Example usage: 238 | ``` 239 | probabilities = jnp.array([0.25, 0.75]) 240 | gini_value = gini_impurity(probabilities) 241 | ``` 242 | 243 | Args: 244 | probabilities (jnp.ndarray): Array of probability values. 245 | 246 | Returns: 247 | float: Gini impurity value. 248 | """ 249 | gini_value = 1 - jnp.sum(probabilities**2) 250 | return gini_value 251 | 252 | 253 | @jax.jit 254 | def kl_divergence(p: jnp.ndarray, q: jnp.ndarray) -> float: 255 | """ 256 | Calculate the Kullback-Leibler (KL) divergence between two probability distributions using JAX. 257 | 258 | Example usage: 259 | ``` 260 | p = jnp.array([0.25, 0.75]) 261 | q = jnp.array([0.5, 0.5]) 262 | kl_value = kl_divergence(p, q) 263 | ``` 264 | 265 | Args: 266 | p (jnp.ndarray): Array of probability values for distribution p. 267 | q (jnp.ndarray): Array of probability values for distribution q. 268 | 269 | Returns: 270 | float: KL divergence value. 271 | """ 272 | kl_value = jnp.sum(p * jnp.log2(p / q)) 273 | return kl_value 274 | 275 | 276 | @jax.jit 277 | def count_parameters(params: Any) -> int: 278 | """ 279 | Count the total number of parameters in a model's parameter dictionary using JAX. 280 | 281 | Example usage: 282 | ``` 283 | model = MyModel() 284 | params = model.init(jax.random.PRNGKey(0), jnp.ones(input_shape)) 285 | total_params = count_parameters(params) 286 | ``` 287 | 288 | Args: 289 | params (Any): Model's parameter dictionary. 290 | 291 | Returns: 292 | int: Total number of parameters. 293 | """ 294 | return sum(x.size for x in jax.tree_util.tree_leaves(params)) 295 | -------------------------------------------------------------------------------- /nanodl/__src/utils/nlp.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import Counter 3 | from typing import List 4 | 5 | import numpy as np 6 | 7 | 8 | def rouge( 9 | hypotheses: List[str], references: List[str], ngram_sizes: List[int] = [1, 2] 10 | ) -> dict: 11 | """ 12 | Calculate the ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric. 13 | ROUGE-F1 = (2⋅Precision⋅Recall) / (Precision + Recall) 14 | 15 | Args: 16 | hypotheses (List[str]): List of hypothesis sentences. 17 | references (List[str]): List of reference sentences. 18 | ngram_sizes (List[int], optional): List of n-gram sizes. Default is [1, 2]. 19 | 20 | Returns: 21 | dict: Dictionary containing precision, recall, and F1-score for each n-gram size.
22 | 23 | Example usage: 24 | ``` 25 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] 26 | >>> references = ["the cat is on the mat", "the cat sits on the mat"] 27 | >>> rouge_scores = rouge(hypotheses, references, [1, 2]) 28 | >>> print(rouge_scores) 29 | ``` 30 | """ 31 | 32 | def ngrams(sequence: List[str], n: int) -> List[str]: 33 | return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)] 34 | 35 | def precision_recall_f1(hypothesis_tokens, reference_tokens, n): 36 | hypothesis_ngrams = set(ngrams(hypothesis_tokens, n)) 37 | reference_ngrams = set(ngrams(reference_tokens, n)) 38 | 39 | common_ngrams = hypothesis_ngrams.intersection(reference_ngrams) 40 | 41 | precision = ( 42 | len(common_ngrams) / len(hypothesis_ngrams) 43 | if len(hypothesis_ngrams) > 0 44 | else 0.0 45 | ) 46 | recall = ( 47 | len(common_ngrams) / len(reference_ngrams) 48 | if len(reference_ngrams) > 0 49 | else 0.0 50 | ) 51 | 52 | f1 = 2 * (precision * recall) / (precision + recall + 1e-12) 53 | return precision, recall, f1 54 | 55 | rouge_scores = {} 56 | for n in ngram_sizes: 57 | total_precision = 0.0 58 | total_recall = 0.0 59 | total_f1 = 0.0 60 | for hypothesis, reference in zip(hypotheses, references): 61 | hypothesis_tokens = hypothesis.split() 62 | reference_tokens = reference.split() 63 | 64 | precision, recall, f1 = precision_recall_f1( 65 | hypothesis_tokens, reference_tokens, n 66 | ) 67 | total_precision += precision 68 | total_recall += recall 69 | total_f1 += f1 70 | 71 | average_precision = total_precision / len(hypotheses) 72 | average_recall = total_recall / len(hypotheses) 73 | average_f1 = total_f1 / len(hypotheses) 74 | 75 | rouge_scores[f"ROUGE-{n}"] = { 76 | "precision": average_precision, 77 | "recall": average_recall, 78 | "f1": average_f1, 79 | } 80 | 81 | return rouge_scores 82 | 83 | 84 | def bleu(hypotheses: List[str], references: List[str], max_ngram: int = 4) -> float: 85 | """ 86 | Calculate the BLEU (Bilingual Evaluation Understudy) metric. 87 | BLEU = (BP) * (exp(sum(wn * log(pn)))) 88 | where BP = brevity penalty, wn = weight for n-gram precision, and pn = n-gram precision 89 | 90 | Args: 91 | hypotheses (List[str]): List of hypothesis sentences. 92 | references (List[str]): List of reference sentences. 93 | max_ngram (int, optional): Maximum n-gram size to consider. Default is 4. 94 | 95 | Returns: 96 | float: BLEU score. 
97 | 98 | Example usage: 99 | ``` 100 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] 101 | >>> references = ["the cat is on the mat", "the cat sits on the mat"] 102 | >>> bleu_score = bleu(hypotheses, references) 103 | >>> print(bleu_score) 104 | ``` 105 | """ 106 | 107 | def ngrams(sequence: List[str], n: int) -> List[str]: 108 | return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)] 109 | 110 | def modified_precision(hypothesis_tokens, reference_tokens, n): 111 | hypothesis_ngrams = ngrams(hypothesis_tokens, n) 112 | reference_ngrams = ngrams(reference_tokens, n) 113 | 114 | hypothesis_ngram_counts = Counter(hypothesis_ngrams) 115 | reference_ngram_counts = Counter(reference_ngrams) 116 | 117 | common_ngrams = hypothesis_ngram_counts & reference_ngram_counts 118 | common_count = sum(common_ngrams.values()) 119 | 120 | if len(hypothesis_ngrams) == 0: 121 | return 0.0 122 | else: 123 | precision = common_count / len(hypothesis_ngrams) 124 | return precision 125 | 126 | brevity_penalty = np.exp(min(0, 1 - len(hypotheses) / len(references))) 127 | bleu_scores = [] 128 | 129 | for n in range(1, max_ngram + 1): 130 | ngram_precisions = [] 131 | for hypothesis, reference in zip(hypotheses, references): 132 | hypothesis_tokens = hypothesis.split() 133 | reference_tokens = reference.split() 134 | 135 | precision = modified_precision(hypothesis_tokens, reference_tokens, n) 136 | ngram_precisions.append(precision) 137 | 138 | geometric_mean = np.exp(np.mean(np.log(np.clip(ngram_precisions, 1e-10, None)))) 139 | bleu_scores.append(geometric_mean) 140 | 141 | final_bleu = brevity_penalty * np.exp( 142 | np.mean(np.log(np.clip(bleu_scores, 1e-10, None))) 143 | ) 144 | return final_bleu 145 | 146 | 147 | def meteor(hypothesis: str, reference: str) -> float: 148 | """ 149 | Calculates the METEOR score between a reference and hypothesis sentence. 150 | 151 | Args: 152 | reference (str): The reference sentence. 153 | hypothesis (str): The hypothesis sentence. 154 | 155 | Returns: 156 | float: METEOR score. 
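Note: this is a simplified formulation. As implemented below, METEOR = (1 - alpha) * F1_exact + alpha * (precision_stemmed * recall_stemmed) with alpha = 0.5, and "stemming" is plain lowercasing; the alignment, chunk penalty, and Porter stemming of the original metric are not applied.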
157 | 158 | Example usage: 159 | ``` 160 | >>> hypothesis = "the cat is on the mat" 161 | >>> reference = "the cat sits on the mat" 162 | >>> meteor_score = meteor(hypothesis, reference) 163 | >>> print(meteor_score) 164 | ``` 165 | """ 166 | 167 | def tokenize(sentence): 168 | return re.findall(r"\w+", sentence.lower()) 169 | 170 | def stemming(token): 171 | return token.lower() 172 | 173 | def exact_matching(reference_tokens, hypothesis_tokens): 174 | return sum(1 for token in hypothesis_tokens if token in reference_tokens) 175 | 176 | def stemmed_matching(reference_tokens, hypothesis_tokens): 177 | stemmed_reference = [stemming(token) for token in reference_tokens] 178 | stemmed_hypothesis = [stemming(token) for token in hypothesis_tokens] 179 | return sum(1 for token in stemmed_hypothesis if token in stemmed_reference) 180 | 181 | def precision_recall_f1(match_count, hypothesis_length, reference_length): 182 | precision = match_count / hypothesis_length if hypothesis_length > 0 else 0 183 | recall = match_count / reference_length if reference_length > 0 else 0 184 | f1 = ( 185 | 2 * precision * recall / (precision + recall) 186 | if precision + recall > 0 187 | else 0 188 | ) 189 | return precision, recall, f1 190 | 191 | reference_tokens = tokenize(reference) 192 | hypothesis_tokens = tokenize(hypothesis) 193 | 194 | exact_matches = exact_matching(reference_tokens, hypothesis_tokens) 195 | stemmed_matches = stemmed_matching(reference_tokens, hypothesis_tokens) 196 | 197 | _, _, f1_exact = precision_recall_f1( 198 | exact_matches, len(hypothesis_tokens), len(reference_tokens) 199 | ) 200 | 201 | precision_stemmed, recall_stemmed, f1_stemmed = precision_recall_f1( 202 | stemmed_matches, len(hypothesis_tokens), len(reference_tokens) 203 | ) 204 | 205 | alpha = 0.5 206 | meteor_score = (1 - alpha) * f1_exact + alpha * precision_stemmed * recall_stemmed 207 | return meteor_score 208 | 209 | 210 | def cider_score(hypothesis: str, reference: str) -> float: 211 | """ 212 | Calculates the CIDEr score between a reference and hypothesis sentence. 213 | 214 | Args: 215 | reference (str): The reference sentence. 216 | hypothesis (str): The hypothesis sentence. 217 | 218 | Returns: 219 | float: CIDEr score. 
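Note: this is a simplified formulation. For each n-gram size n = 1..4 the implementation below computes 2 * precision * recall / (precision + recall) over clipped n-gram counts and averages the four values with equal weights; the TF-IDF weighting of the original CIDEr metric is not applied.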
220 | 221 | Example usage: 222 | ``` 223 | >>> hypothesis = "the cat is on the mat" 224 | >>> reference = "the cat sits on the mat" 225 | >>> score = cider_score(hypothesis, reference) 226 | >>> print(score) 227 | ``` 228 | """ 229 | 230 | def tokenize(sentence): 231 | return re.findall(r"\w+", sentence.lower()) 232 | 233 | def ngrams(tokens, n): 234 | return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)] 235 | 236 | reference_tokens = tokenize(reference) 237 | hypothesis_tokens = tokenize(hypothesis) 238 | 239 | max_n = 4 # Maximum n-gram size 240 | weights = [1.0] * max_n # Weights for different n-gram sizes 241 | 242 | cider_scores = [] 243 | for n in range(1, max_n + 1): 244 | ref_ngrams = ngrams(reference_tokens, n) 245 | hyp_ngrams = ngrams(hypothesis_tokens, n) 246 | 247 | ref_ngram_freq = Counter(ref_ngrams) 248 | hyp_ngram_freq = Counter(hyp_ngrams) 249 | 250 | common_ngrams = set(ref_ngrams) & set(hyp_ngrams) 251 | 252 | if len(common_ngrams) == 0: 253 | cider_scores.append(0) 254 | continue 255 | 256 | precision = sum( 257 | min(ref_ngram_freq[ngram], hyp_ngram_freq[ngram]) for ngram in common_ngrams 258 | ) / len(hyp_ngrams) 259 | ref_ngram_freq_sum = sum(ref_ngram_freq[ngram] for ngram in common_ngrams) 260 | hyp_ngram_freq_sum = sum(hyp_ngram_freq[ngram] for ngram in common_ngrams) 261 | recall = ref_ngram_freq_sum / len(ref_ngrams) 262 | 263 | cider_scores.append((precision * recall) / (precision + recall) * 2) 264 | 265 | avg_cider_score = np.average(cider_scores, weights=weights) 266 | 267 | return avg_cider_score 268 | 269 | 270 | def perplexity(log_probs: List[float]) -> float: 271 | """ 272 | Calculate the perplexity of a sequence using a list of log probabilities. 273 | Perplexity = 2^(-average log likelihood) 274 | where average log likelihood = total log likelihood / total word count 275 | 276 | Args: 277 | log_probs (List[float]): List of log probabilities for each predicted word. 278 | 279 | Returns: 280 | float: Perplexity score. 281 | 282 | Example usage: 283 | ``` 284 | >>> log_probs = [-2.3, -1.7, -0.4] # Example log probabilities 285 | >>> perplexity_score = perplexity(log_probs) 286 | >>> print(perplexity_score) 287 | ``` 288 | """ 289 | log_likelihood = 0.0 290 | word_count = 0 291 | 292 | for i in range(len(log_probs) - 1): 293 | predicted_log_prob = log_probs[ 294 | i 295 | ] # Replace this with your language model's log probability 296 | log_likelihood += predicted_log_prob 297 | word_count += 1 298 | 299 | average_log_likelihood = log_likelihood / word_count 300 | perplexity_score = 2 ** (-average_log_likelihood) 301 | return perplexity_score 302 | 303 | 304 | def word_error_rate(hypotheses: List[int], references: List[int]) -> float: 305 | """ 306 | Calculate the Word Error Rate (WER) metric. 307 | 308 | Args: 309 | hypotheses (List[str]): List of hypothesis words. 310 | references (List[str]): List of reference words. 311 | 312 | Returns: 313 | float: Word Error Rate score. 
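WER = (sum of word-level edit distances between each hypothesis and its reference) / (total number of reference words), where the edit distance counts insertions, deletions, and substitutions.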
314 | 315 | Example usage: 316 | ``` 317 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] 318 | >>> references = ["the cat is on the mat", "the cat sits on the mat"] 319 | >>> wer_score = word_error_rate(hypotheses, references) 320 | >>> print(wer_score) 321 | ``` 322 | """ 323 | 324 | def edit_distance(str1, str2): 325 | len_str1 = len(str1) 326 | len_str2 = len(str2) 327 | 328 | dp = [[0] * (len_str2 + 1) for _ in range(len_str1 + 1)] 329 | 330 | for i in range(len_str1 + 1): 331 | dp[i][0] = i 332 | 333 | for j in range(len_str2 + 1): 334 | dp[0][j] = j 335 | 336 | for i in range(1, len_str1 + 1): 337 | for j in range(1, len_str2 + 1): 338 | cost = 0 if str1[i - 1] == str2[j - 1] else 1 339 | dp[i][j] = min( 340 | dp[i - 1][j] + 1, # Deletion 341 | dp[i][j - 1] + 1, # Insertion 342 | dp[i - 1][j - 1] + cost, # Substitution or no operation 343 | ) 344 | return dp[len_str1][len_str2] 345 | 346 | total_edit_distance = 0 347 | total_reference_length = 0 348 | 349 | for hyp, ref in zip(hypotheses, references): 350 | edit_dist = edit_distance(hyp.split(), ref.split()) 351 | total_edit_distance += edit_dist 352 | total_reference_length += len(ref.split()) 353 | 354 | wer_score = total_edit_distance / total_reference_length 355 | return wer_score 356 | -------------------------------------------------------------------------------- /nanodl/__src/utils/random.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Tuple, Union 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax import random 7 | 8 | 9 | def time_rng_key(seed=None) -> jnp.ndarray: 10 | """Generate a JAX random key based on the current UNIX timestamp. 11 | 12 | Returns: 13 | jnp.ndarray: A JAX random key. 14 | """ 15 | key = int(time.time()) if seed is None else seed 16 | return random.PRNGKey(key) 17 | 18 | 19 | def uniform( 20 | shape: Tuple[int, ...], 21 | minval: Any = 0.0, 22 | maxval: Any = 1.0, 23 | seed=None, 24 | dtype: Any = jnp.float32, 25 | ) -> jnp.ndarray: 26 | """Generate a tensor of uniform random values. 27 | 28 | Args: 29 | shape (Tuple[int, ...]): The shape of the output tensor. 30 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 31 | minval (Any, optional): The lower bound of the uniform distribution. Defaults to 0.0. 32 | maxval (Any, optional): The upper bound of the uniform distribution. Defaults to 1.0. 33 | 34 | Returns: 35 | jnp.ndarray: A tensor of uniform random values. 36 | """ 37 | return random.uniform( 38 | time_rng_key(seed), shape, dtype=dtype, minval=minval, maxval=maxval 39 | ) 40 | 41 | 42 | def normal(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray: 43 | """Generate a tensor of normal random values. 44 | 45 | Args: 46 | shape (Tuple[int, ...]): The shape of the output tensor. 47 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 48 | 49 | Returns: 50 | jnp.ndarray: A tensor of normal random values. 51 | """ 52 | return random.normal(time_rng_key(seed), shape, dtype=dtype) 53 | 54 | 55 | def bernoulli(p: float, shape: Tuple[int, ...] = (), seed=None) -> jnp.ndarray: 56 | """Generate random boolean values with a given probability. 57 | 58 | Args: 59 | p (float): Probability of sampling a True value. 60 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 61 | 62 | Returns: 63 | jnp.ndarray: A tensor of boolean values. 
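Example usage (a minimal sketch; the probability, shape, and seed below are arbitrary):
```
>>> mask = bernoulli(p=0.3, shape=(2, 5), seed=0)
>>> print(mask.shape, mask.dtype)  # (2, 5) bool
```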
64 | """ 65 | return random.bernoulli(time_rng_key(seed), p, shape) 66 | 67 | 68 | def categorical( 69 | logits: jnp.ndarray, axis: int = -1, shape: Tuple[int, ...] = (), seed=None 70 | ) -> jnp.ndarray: 71 | """Draw samples from a categorical distribution. 72 | 73 | Args: 74 | logits (jnp.ndarray): The unnormalized log probabilities of the categories. 75 | axis (int, optional): The axis along which the categorical distribution is applied. Defaults to -1. 76 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 77 | 78 | Returns: 79 | jnp.ndarray: The sampled indices with the specified shape. 80 | """ 81 | return random.categorical(time_rng_key(seed), logits, axis=axis, shape=shape) 82 | 83 | 84 | def randint( 85 | shape: Tuple[int, ...], minval: int, maxval: int, dtype: str = "int32", seed=None 86 | ) -> jnp.ndarray: 87 | """Generate random integers between minval (inclusive) and maxval (exclusive). 88 | 89 | Args: 90 | shape (Tuple[int, ...]): The shape of the output tensor. 91 | minval (int): The lower bound of the random integers, inclusive. 92 | maxval (int): The upper bound of the random integers, exclusive. 93 | dtype (str, optional): The data type of the output tensor. Defaults to 'int32'. 94 | 95 | Returns: 96 | jnp.ndarray: A tensor of random integers. 97 | """ 98 | return random.randint(time_rng_key(seed), shape, minval, maxval, dtype=dtype) 99 | 100 | 101 | def permutation(x: Union[int, jnp.ndarray], axis: int = 0, seed=None) -> jnp.ndarray: 102 | """Randomly permute a sequence, or return a permuted range. 103 | 104 | Args: 105 | x (Union[int, jnp.ndarray]): If x is an integer, permute range(x). If x is an array, permute its elements. 106 | axis (int, optional): The axis along which to permute if x is an array. Defaults to 0. 107 | 108 | Returns: 109 | jnp.ndarray: The permuted sequence or array. 110 | """ 111 | if isinstance(x, int): 112 | arr = jax.numpy.arange(x) 113 | return random.permutation(time_rng_key(seed), arr, axis=axis) 114 | else: 115 | return random.permutation(time_rng_key(seed), x, axis=axis) 116 | 117 | 118 | def gumbel(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray: 119 | """Draw samples from a Gumbel distribution. 120 | 121 | Args: 122 | shape (Tuple[int, ...]): The shape of the output tensor. 123 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 124 | 125 | Returns: 126 | jnp.ndarray: A tensor of samples from a Gumbel distribution. 127 | """ 128 | return random.gumbel(time_rng_key(seed), shape, dtype=dtype) 129 | 130 | 131 | def choice( 132 | a: Union[int, jnp.ndarray], 133 | shape: Tuple[int, ...] = (), 134 | replace: bool = True, 135 | p: Union[None, jnp.ndarray] = None, 136 | axis: int = 0, 137 | seed=None, 138 | ) -> jnp.ndarray: 139 | """Randomly choose elements from a given 1-D array. 140 | 141 | Args: 142 | a (Union[int, jnp.ndarray]): If an int, the random sample is generated as if a were jnp.arange(a). 143 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 144 | replace (bool, optional): Whether the sample is with or without replacement. Defaults to True. 145 | p (Union[None, jnp.ndarray], optional): The probabilities associated with each entry in a. Defaults to None. 146 | axis (int, optional): The axis along which to choose if a is an array. Defaults to 0. 147 | 148 | Returns: 149 | jnp.ndarray: The randomly chosen elements. 
150 | """ 151 | if isinstance(a, int): 152 | a = jnp.arange(a) 153 | return random.choice( 154 | time_rng_key(seed), a, shape=shape, replace=replace, p=p, axis=axis 155 | ) 156 | 157 | 158 | def bits(shape: Tuple[int, ...], dtype: Any = jnp.uint32, seed=None) -> jnp.ndarray: 159 | """Generate random bits. 160 | 161 | Args: 162 | shape (Tuple[int, ...]): The shape of the output tensor. 163 | dtype (Any, optional): The data type of the output tensor, typically an unsigned integer type. Defaults to jnp.uint32. 164 | 165 | Returns: 166 | jnp.ndarray: A tensor of random bits. 167 | """ 168 | return random.bits(time_rng_key(seed), shape, dtype=dtype) 169 | 170 | 171 | def exponential( 172 | shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None 173 | ) -> jnp.ndarray: 174 | """Draw samples from an exponential distribution. 175 | 176 | Args: 177 | shape (Tuple[int, ...]): The shape of the output tensor. 178 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 179 | 180 | Returns: 181 | jnp.ndarray: A tensor of samples from an exponential distribution. 182 | """ 183 | return random.exponential(time_rng_key(seed), shape, dtype=dtype) 184 | 185 | 186 | def triangular( 187 | left: float, right: float, mode: float, shape: Tuple[int, ...] = (), seed=None 188 | ) -> jnp.ndarray: 189 | """Draw samples from a triangular distribution. 190 | 191 | Args: 192 | left (float): The lower limit of the distribution. 193 | right (float): The upper limit of the distribution. 194 | mode (float): The mode (peak) of the distribution. 195 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 196 | 197 | Returns: 198 | jnp.ndarray: A tensor of samples from a triangular distribution. 199 | """ 200 | return random.triangular(time_rng_key(seed), left, right, mode, shape) 201 | 202 | 203 | def truncated_normal( 204 | lower: float, 205 | upper: float, 206 | shape: Tuple[int, ...] = (), 207 | dtype: Any = jnp.float32, 208 | seed=None, 209 | ) -> jnp.ndarray: 210 | """Draw samples from a truncated normal distribution. 211 | 212 | Args: 213 | lower (float): The lower bound of the distribution. 214 | upper (float): The upper bound of the distribution. 215 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 216 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 217 | 218 | Returns: 219 | jnp.ndarray: A tensor of samples from a truncated normal distribution. 220 | """ 221 | return random.truncated_normal(time_rng_key(seed), lower, upper, shape, dtype) 222 | 223 | 224 | def poisson( 225 | lam: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None 226 | ) -> jnp.ndarray: 227 | """Draw samples from a Poisson distribution. 228 | 229 | Args: 230 | lam (float): The expectation of interval (lambda parameter). 231 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 232 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32. 233 | 234 | Returns: 235 | jnp.ndarray: A tensor of samples from a Poisson distribution. 236 | """ 237 | return random.poisson(time_rng_key(seed), lam, shape=shape, dtype=dtype) 238 | 239 | 240 | def geometric( 241 | p: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None 242 | ) -> jnp.ndarray: 243 | """Draw samples from a geometric distribution. 244 | 245 | Args: 246 | p (float): The probability of success of an individual trial. 
247 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 248 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32. 249 | 250 | Returns: 251 | jnp.ndarray: A tensor of samples from a geometric distribution. 252 | """ 253 | return random.geometric(time_rng_key(seed), p, shape=shape, dtype=dtype) 254 | 255 | 256 | def gamma( 257 | a: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None 258 | ) -> jnp.ndarray: 259 | """Draw samples from a gamma distribution. 260 | 261 | Args: 262 | a (float): The shape parameter of the gamma distribution. 263 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 264 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 265 | 266 | Returns: 267 | jnp.ndarray: A tensor of samples from a gamma distribution. 268 | """ 269 | return random.gamma(time_rng_key(seed), a, shape=shape, dtype=dtype) 270 | 271 | 272 | def chisquare( 273 | df: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None 274 | ) -> jnp.ndarray: 275 | """Draw samples from a chi-square distribution. 276 | 277 | Args: 278 | df (float): The degrees of freedom. 279 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). 280 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. 281 | 282 | Returns: 283 | jnp.ndarray: A tensor of samples from a chi-square distribution. 284 | """ 285 | return random.chisquare(time_rng_key(seed), df, shape=shape, dtype=dtype) 286 | -------------------------------------------------------------------------------- /nanodl/__src/utils/vision.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | @jax.jit 8 | def normalize_images(images: jnp.ndarray) -> jnp.ndarray: 9 | """ 10 | Normalize images to have zero mean and unit variance. 11 | 12 | Args: 13 | images (jnp.ndarray): Input images of shape (N, H, W, C), where N is the number of images, 14 | H is height, W is width, and C is the number of channels. 15 | 16 | Returns: 17 | jnp.ndarray: Normalized images of the same shape as the input. 18 | 19 | Example usage: 20 | ``` 21 | >>> images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) # One image of shape (1, 2, 2, 1) 22 | >>> normalized_images = normalize_images(images) 23 | >>> print(normalized_images) 24 | ``` 25 | """ 26 | mean = images.mean(axis=(1, 2, 3), keepdims=True) 27 | std = images.std(axis=(1, 2, 3), keepdims=True) 28 | return (images - mean) / (std + 1e-5) 29 | 30 | 31 | def random_crop(images: jnp.ndarray, crop_size: int) -> jnp.ndarray: 32 | """ 33 | Randomly crop a batch of images to a specified size using JAX. 34 | 35 | This function takes a batch of images and randomly crops each image to the specified size. 36 | It uses JAX for random number generation to determine the starting coordinates of the crop. 37 | 38 | Args: 39 | images (jax.numpy.ndarray): A 4D array of shape (batch_size, height, width, channels), 40 | representing a batch of images. 41 | crop_size (int): The size to which each image will be cropped. Both the height and width 42 | of the crop will be equal to `crop_size`. 43 | 44 | Returns: 45 | jax.numpy.ndarray: The cropped images, with shape (batch_size, crop_size, crop_size, channels). 
46 | 47 | Example usage: 48 | ``` 49 | >>> images = jnp.ones((10, 100, 100, 3)) # Batch of 10 images of size 100x100 with 3 channels 50 | >>> crop_size = 64 51 | >>> cropped_images = random_crop(images, crop_size) 52 | >>> print(cropped_images.shape) 53 | ``` 54 | """ 55 | key = jax.random.PRNGKey(int(time.time())) 56 | _, height, width, _ = images.shape 57 | height_start = jax.random.randint(key, (), 0, height - crop_size + 1) 58 | width_start = jax.random.randint(key, (), 0, width - crop_size + 1) 59 | height_end = height_start + crop_size 60 | width_end = width_start + crop_size 61 | crops = images[:, height_start:height_end, width_start:width_end, :] 62 | return crops 63 | 64 | 65 | def gaussian_blur(image: jnp.ndarray, kernel_size: int, sigma: float) -> jnp.ndarray: 66 | """ 67 | Apply Gaussian blur to a multi-channel image. 68 | 69 | Args: 70 | image (jnp.ndarray): Input image of shape (H, W, C). 71 | kernel_size (int): Size of the Gaussian kernel (must be odd). 72 | sigma (float): Standard deviation of the Gaussian kernel. 73 | 74 | Returns: 75 | jnp.ndarray: Blurred image of the same shape as the input. 76 | 77 | Example usage: 78 | ``` 79 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 80 | >>> blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0) 81 | >>> print(blurred_image.shape) 82 | ``` 83 | """ 84 | assert kernel_size % 2 == 1, "Kernel size must be odd." 85 | ax = jnp.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) 86 | xx, yy = jnp.meshgrid(ax, ax) 87 | kernel = jnp.exp(-(xx**2 + yy**2) / (2.0 * sigma**2)) 88 | kernel = kernel / jnp.sum(kernel) 89 | 90 | # Apply convolution to each channel 91 | blurred_image = jnp.stack( 92 | [ 93 | jax.scipy.signal.convolve2d(image[:, :, i], kernel, mode="same") 94 | for i in range(image.shape[2]) 95 | ], 96 | axis=-1, 97 | ) 98 | return blurred_image 99 | 100 | 101 | @jax.jit 102 | def sobel_edge_detection(image: jnp.ndarray) -> jnp.ndarray: 103 | """ 104 | Apply Sobel edge detection to a multi-channel image. 105 | 106 | Args: 107 | image (jnp.ndarray): Input image of shape (H, W, C). 108 | 109 | Returns: 110 | jnp.ndarray: Image representing the edges, of the same shape as the input. 111 | 112 | Example usage: 113 | ``` 114 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 115 | >>> edges = sobel_edge_detection(image) 116 | >>> print(edges.shape) 117 | ``` 118 | """ 119 | sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) 120 | sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) 121 | 122 | def apply_sobel(channel): 123 | gx = jax.scipy.signal.convolve2d(channel, sobel_x, mode="same") 124 | gy = jax.scipy.signal.convolve2d(channel, sobel_y, mode="same") 125 | return jnp.sqrt(gx**2 + gy**2) 126 | 127 | # Apply Sobel filter to each channel and sum the results 128 | edges = jnp.sum( 129 | jnp.stack( 130 | [apply_sobel(image[:, :, i]) for i in range(image.shape[2])], axis=-1 131 | ), 132 | axis=-1, 133 | ) 134 | return edges 135 | 136 | 137 | @jax.jit 138 | def adjust_brightness(image: jnp.ndarray, factor: float) -> jnp.ndarray: 139 | """ 140 | Adjust the brightness of an image. 141 | 142 | Args: 143 | image (jnp.ndarray): Input image of shape (H, W, C). 144 | factor (float): Factor to adjust brightness. Values > 1 increase brightness, 145 | values < 1 decrease brightness. 146 | 147 | Returns: 148 | jnp.ndarray: Brightness-adjusted image of the same shape as the input. 
149 | 150 | Example usage: 151 | ``` 152 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 153 | >>> adjusted_image = adjust_brightness(image, factor=1.5) 154 | >>> print(adjusted_image.shape) 155 | ``` 156 | """ 157 | return jnp.clip(image * factor, 0, 1) 158 | 159 | 160 | @jax.jit 161 | def adjust_contrast(image: jnp.ndarray, factor: float) -> jnp.ndarray: 162 | """ 163 | Adjust the contrast of an image. 164 | 165 | Args: 166 | image (jnp.ndarray): Input image of shape (H, W, C). 167 | factor (float): Factor to adjust contrast. Values > 1 increase contrast, 168 | values < 1 decrease contrast. 169 | 170 | Returns: 171 | jnp.ndarray: Contrast-adjusted image of the same shape as the input. 172 | 173 | Example usage: 174 | ``` 175 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 176 | >>> adjusted_image = adjust_contrast(image, factor=1.5) 177 | >>> print(adjusted_image.shape) 178 | ``` 179 | """ 180 | mean = jnp.mean(image, axis=(0, 1), keepdims=True) 181 | return jnp.clip((image - mean) * factor + mean, 0, 1) 182 | 183 | 184 | @jax.jit 185 | def flip_image(image: jnp.ndarray, horizontal: jnp.ndarray) -> jnp.ndarray: 186 | """ 187 | Flip an image horizontally or vertically. 188 | 189 | Args: 190 | image (jnp.ndarray): Input image of shape (H, W, C). 191 | horizontal (jnp.ndarray): If True (jax.numpy.array with a single True value), flip horizontally; 192 | otherwise, flip vertically. 193 | 194 | Returns: 195 | jnp.ndarray: Flipped image of the same shape as the input. 196 | 197 | Example usage: 198 | ``` 199 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 200 | >>> flipped_image_horizontally = flip_image(image, jnp.array([True])) 201 | >>> flipped_image_vertically = flip_image(image, jnp.array([False])) 202 | >>> print(flipped_image_horizontally.shape, flipped_image_vertically.shape) 203 | ``` 204 | """ 205 | return jnp.where(horizontal, image[:, ::-1, :], image[::-1, :, :]) 206 | 207 | 208 | @jax.jit 209 | def random_flip_image( 210 | image: jnp.ndarray, key: jax.random.PRNGKey, horizontal: jnp.ndarray 211 | ) -> jnp.ndarray: 212 | """ 213 | Randomly flip an image horizontally or vertically using JAX. 214 | 215 | Args: 216 | image (jnp.ndarray): Input image of shape (H, W, C). 217 | key (jax.random.PRNGKey): A PRNG key used for random number generation. 218 | horizontal (jnp.ndarray): JAX array with a single boolean value indicating the flip direction. 219 | If True (jax.numpy.array with a single True value), flip horizontally; 220 | otherwise, flip vertically. 221 | 222 | Returns: 223 | jnp.ndarray: Randomly flipped image of the same shape as the input. 
224 | 225 | Example usage: 226 | ``` 227 | >>> key = jax.random.PRNGKey(0) 228 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels 229 | >>> flipped_image = random_flip_image(image, key, jnp.array([True])) 230 | >>> print(flipped_image.shape) 231 | ``` 232 | """ 233 | flip = jax.random.uniform(key) > 0.5 234 | flip_horizontal = jnp.where(horizontal, image[:, ::-1, :], image) 235 | flip_vertical = jnp.where(horizontal, image, image[::-1, :, :]) 236 | return jnp.where(flip, flip_horizontal, flip_vertical) 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | flax 4 | optax 5 | einops 6 | sentencepiece -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="nanodl", 5 | version="0.0.0", 6 | author="Henry Ndubuaku", 7 | author_email="ndubuakuhenry@gmail.com", 8 | description="A Jax-based library for designing and training transformer models from scratch.", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/hmunachi/nanodl", 12 | packages=find_packages(), 13 | install_requires=[ 14 | "flax", 15 | "jax", 16 | "jaxlib", 17 | "optax", 18 | "einops", 19 | "sentencepiece", 20 | ], 21 | classifiers=[ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Education", 26 | "Topic :: Software Development :: Build Tools", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | "License :: OSI Approved :: MIT License", 29 | "Programming Language :: Python :: 3", 30 | "Programming Language :: Python :: 3.7", 31 | "Programming Language :: Python :: 3.8", 32 | "Programming Language :: Python :: 3.9", 33 | ], 34 | keywords="transformers jax machine learning deep learning pytorch tensorflow", 35 | python_requires=">=3.7", 36 | ) 37 | -------------------------------------------------------------------------------- /tests/files/sample.txt: -------------------------------------------------------------------------------- 1 | Hello, world! This is a test of the Tokenizer. 2 | Let's see how it tokenizes this file. 3 | Another sentence to check the tokenization process. 
-------------------------------------------------------------------------------- /tests/test_classic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from nanodl import * 7 | 8 | 9 | class TestNaiveBayesFunctions(unittest.TestCase): 10 | def setUp(self): 11 | self.num_samples = 3 12 | self.num_features = 2 13 | self.num_classes = 2 14 | self.X = jnp.array([[0, 1], [1, 0], [1, 1]]) 15 | self.y = jnp.array([0, 1, 0]) 16 | 17 | def test_naive_bayes_classifier(self): 18 | classifier = NaiveBayesClassifier(num_classes=self.num_classes) 19 | classifier.fit(self.X, self.y) 20 | predictions = classifier.predict(self.X) 21 | self.assertEqual(predictions.shape, (self.num_samples,)) 22 | self.assertTrue( 23 | jnp.all(predictions >= 0) and jnp.all(predictions < self.num_classes) 24 | ) 25 | 26 | 27 | class TestKClustering(unittest.TestCase): 28 | def setUp(self): 29 | self.k = 3 30 | self.num_samples = 300 31 | self.num_features = 2 32 | self.X = jax.random.normal( 33 | jax.random.PRNGKey(0), (self.num_samples, self.num_features) 34 | ) 35 | 36 | def test_kmeans_fit_predict(self): 37 | kmeans = KMeans(k=self.k) 38 | kmeans.fit(self.X) 39 | clusters = kmeans.predict(self.X) 40 | self.assertEqual(len(set(clusters.tolist())), self.k) 41 | 42 | def test_gmm_fit_predict(self): 43 | gmm = GaussianMixtureModel(n_components=self.k) 44 | gmm.fit(self.X) 45 | labels = gmm.predict(self.X) 46 | self.assertEqual(len(set(labels.tolist())), self.k) 47 | 48 | 49 | class TestPCA(unittest.TestCase): 50 | def test_pca_fit_transform(self): 51 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10)) 52 | pca = PCA(n_components=2) 53 | pca.fit(data) 54 | transformed_data = pca.transform(data) 55 | self.assertEqual(transformed_data.shape, (1000, 2)) 56 | 57 | def test_pca_inverse_transform(self): 58 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10)) 59 | pca = PCA(n_components=2) 60 | pca.fit(data) 61 | transformed_data = pca.transform(data) 62 | inverse_data = pca.inverse_transform(transformed_data) 63 | self.assertEqual(inverse_data.shape, data.shape) 64 | 65 | def test_pca_sample(self): 66 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10)) 67 | pca = PCA(n_components=2) 68 | pca.fit(data) 69 | synthetic_samples = pca.sample(n_samples=100) 70 | self.assertEqual(synthetic_samples.shape, (100, 10)) 71 | 72 | 73 | class TestRegression(unittest.TestCase): 74 | def test_linear_regression(self): 75 | num_samples = 100 76 | input_dim = 1 77 | output_dim = 1 78 | x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) 79 | y_data = jnp.dot(x_data, jnp.array([[2.0]])) - jnp.array([[-1.0]]) 80 | lr_model = LinearRegression(input_dim, output_dim) 81 | lr_model.fit(x_data, y_data) 82 | learned_weights, learned_bias = lr_model.get_params() 83 | self.assertTrue(jnp.allclose(learned_weights, jnp.array([[2.0]]), atol=1e-1)) 84 | self.assertTrue(jnp.allclose(learned_bias, jnp.array([[1.0]]), atol=1e-1)) 85 | 86 | def test_logistic_regression(self): 87 | num_samples = 100 88 | input_dim = 2 89 | x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) 90 | logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1 91 | y_data = (logits > 0).astype(jnp.float32) 92 | lr_model = LogisticRegression(input_dim) 93 | lr_model.fit(x_data, y_data) 94 | test_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) 95 | predictions = lr_model.predict(test_data) 96 
| self.assertTrue(jnp.all(predictions >= 0) and jnp.all(predictions <= 1)) 97 | 98 | def test_gaussian_process(self): 99 | def rbf_kernel(x1, x2, length_scale=1.0): 100 | diff = x1[:, None] - x2 101 | return jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1) / length_scale**2) 102 | 103 | num_samples = 100 104 | input_dim = 1 105 | X_train = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) 106 | y_train = ( 107 | jnp.sin(X_train) 108 | + jax.random.normal(jax.random.PRNGKey(0), (num_samples, 1)) * 0.1 109 | ) 110 | gp = GaussianProcess(kernel=rbf_kernel, noise=1e-3) 111 | gp.fit(X_train, y_train) 112 | X_new = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) 113 | mean, covariance = gp.predict(X_new) 114 | self.assertEqual(mean.shape, (num_samples, 1)) 115 | self.assertEqual(covariance.shape, (num_samples, num_samples)) 116 | 117 | 118 | if __name__ == "__main__": 119 | unittest.main() 120 | -------------------------------------------------------------------------------- /tests/test_kan.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import jax.numpy as jnp 3 | from jax import random 4 | from nanodl import * 5 | 6 | class TestKANLinearVariants(unittest.TestCase): 7 | def setUp(self): 8 | self.in_features = 4 9 | self.out_features = 3 10 | self.degree = 2 11 | 12 | self.key = random.PRNGKey(0) 13 | self.x = random.normal(self.key, (10, self.in_features)) 14 | 15 | self.models = { 16 | # "BSplineKANLinear": KANLinear(self.in_features, self.out_features, self.degree), 17 | "ChebyKANLinear": ChebyKANLinear(self.in_features, self.out_features, self.degree), 18 | "LegendreKANLinear": LegendreKANLinear(self.in_features, self.out_features, self.degree), 19 | "MonomialKANLinear": MonomialKANLinear(self.in_features, self.out_features, self.degree), 20 | "FourierKANLinear": FourierKANLinear(self.in_features, self.out_features, self.degree), 21 | "HermiteKANLinear": HermiteKANLinear(self.in_features, self.out_features, self.degree), 22 | } 23 | 24 | def test_model_outputs(self): 25 | for model_name, model in self.models.items(): 26 | with self.subTest(model=model_name): 27 | variables = model.init(self.key, self.x) 28 | output = model.apply(variables, self.x) 29 | self.assertEqual(output.shape, (10, self.out_features)) 30 | self.assertTrue(jnp.all(jnp.isfinite(output))) 31 | 32 | 33 | if __name__ == "__main__": 34 | unittest.main() -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from nanodl import * 7 | 8 | 9 | class TestTextBasedModels(unittest.TestCase): 10 | def setUp(self): 11 | self.batch_size = 8 12 | self.max_length = 51 13 | self.vocab_size = 1000 14 | self.embed_dim = 256 15 | 16 | self.data = jnp.arange( 17 | self.batch_size * self.max_length, dtype=jnp.int32 18 | ).reshape((self.batch_size, self.max_length)) 19 | 20 | self.dummy_inputs = self.data[:, :-1] 21 | self.dummy_targets = self.data[:, 1:] 22 | 23 | self.hyperparams = { 24 | "num_layers": 1, 25 | "hidden_dim": self.embed_dim, 26 | "num_heads": 2, 27 | "feedforward_dim": self.embed_dim, 28 | "dropout": 0.1, 29 | "vocab_size": self.vocab_size, 30 | "embed_dim": self.embed_dim, 31 | "max_length": self.max_length, 32 | "start_token": 0, 33 | "end_token": 50, 34 | } 35 | 36 | def test_t5_model(self): 37 | model = T5(**self.hyperparams) 38 | 
self._test_encoder_decoder_model(model) 39 | 40 | def test_transformer_model(self): 41 | model = Transformer(**self.hyperparams) 42 | self._test_encoder_decoder_model(model) 43 | 44 | def test_lamda_model(self): 45 | model = LaMDA(**self.hyperparams) 46 | self._test_decoder_only_model(model) 47 | 48 | def test_gpt3_model(self): 49 | model = GPT4(**self.hyperparams) 50 | self._test_decoder_only_model(model) 51 | 52 | def test_gpt4_model(self): 53 | model = GPT4(**self.hyperparams) 54 | self._test_decoder_only_model(model) 55 | 56 | def test_mistral_model(self): 57 | model = Mistral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2) 58 | self._test_decoder_only_model(model) 59 | 60 | def test_mixtral_model(self): 61 | model = Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2) 62 | self._test_decoder_only_model(model) 63 | 64 | def test_llama_model(self): 65 | model = Llama3(**self.hyperparams, num_groups=2) 66 | self._test_decoder_only_model(model) 67 | 68 | def test_gemma_model(self): 69 | model = Gemma(**self.hyperparams, num_groups=2) 70 | self._test_decoder_only_model(model) 71 | 72 | def _test_encoder_decoder_model(self, model): 73 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} 74 | params = model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"] 75 | 76 | outputs = model.apply( 77 | {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs 78 | ) 79 | 80 | self.assertEqual( 81 | outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size) 82 | ) 83 | 84 | def _test_decoder_only_model(self, model): 85 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} 86 | params = model.init(rngs, self.dummy_inputs)["params"] 87 | 88 | outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs) 89 | 90 | self.assertEqual( 91 | outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size) 92 | ) 93 | 94 | def test_reward_model(self): 95 | model = RewardModel( 96 | Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2), 97 | dim=self.hyperparams["hidden_dim"], 98 | dropout=0.1, 99 | ) 100 | rngs = jax.random.PRNGKey(0) 101 | rngs, dropout_rng = jax.random.split(rngs) 102 | params = model.init( 103 | {"params": rngs, "dropout": dropout_rng}, self.dummy_inputs 104 | )["params"] 105 | rewards = model.apply( 106 | {"params": params}, self.dummy_inputs, rngs={"dropout": dropout_rng} 107 | ) 108 | assert rewards.shape == (self.batch_size,) 109 | 110 | 111 | class TestVisionBasedModels(unittest.TestCase): 112 | def setUp(self): 113 | self.batch_size = 8 114 | self.n_outputs = 5 115 | self.embed_dim = 256 116 | self.patch_size = (16, 16) 117 | self.dummy_inputs = jnp.ones((self.batch_size, 224, 224, 3)) 118 | key = jax.random.PRNGKey(10) 119 | 120 | self.dummy_labels = jax.random.randint( 121 | key, shape=(self.batch_size,), minval=0, maxval=self.n_outputs - 1 122 | ) 123 | 124 | self.hyperparams = { 125 | "dropout": 0.1, 126 | "num_heads": 2, 127 | "feedforward_dim": self.embed_dim, 128 | "patch_size": self.patch_size, 129 | "hidden_dim": self.embed_dim, 130 | "num_layers": 4, 131 | "n_outputs": self.n_outputs, 132 | } 133 | 134 | def test_vit_model(self): 135 | model = ViT(**self.hyperparams) 136 | self._test_model(model) 137 | 138 | def test_mixer_model(self): 139 | model = Mixer(**self.hyperparams) 140 | self._test_model(model) 141 | 142 | def _test_model(self, model): 143 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} 144 | 145 | params = 
model.init(rngs, self.dummy_inputs)["params"] 146 | 147 | outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs)[0] 148 | 149 | self.assertEqual(outputs.shape, (self.batch_size, self.n_outputs)) 150 | 151 | 152 | class TestCLIPModel(unittest.TestCase): 153 | def setUp(self): 154 | self.batch_size = 8 155 | self.max_length = 50 156 | self.vocab_size = 1000 157 | self.embed_dim = 256 158 | self.dummy_texts = jnp.ones((self.batch_size, self.max_length), dtype=jnp.int32) 159 | self.dummy_images = jnp.ones((self.batch_size, 224, 224, 3)) 160 | 161 | self.clip_params = { 162 | "dropout": 0.1, 163 | "num_heads": 8, 164 | "feedforward_dim": self.embed_dim, 165 | "num_layers_text": 4, 166 | "hidden_dim_text": self.embed_dim, 167 | "image_patch_size": (16, 16), 168 | "hidden_dim_image": self.embed_dim, 169 | "num_layers_images": 4, 170 | "max_len": self.max_length, 171 | "vocab_size": self.vocab_size, 172 | "embed_dim": self.embed_dim, 173 | } 174 | 175 | self.model = CLIP(**self.clip_params) 176 | 177 | def test_clip_model_initialization_and_processing(self): 178 | rng = jax.random.PRNGKey(0) 179 | params = self.model.init(rng, self.dummy_texts, self.dummy_images)["params"] 180 | 181 | loss = self.model.apply({"params": params}, self.dummy_texts, self.dummy_images) 182 | 183 | self.assertIsNotNone(loss) 184 | 185 | 186 | class TestWhisperModel(unittest.TestCase): 187 | def setUp(self): 188 | self.batch_size = 8 189 | self.max_length = 50 190 | self.embed_dim = 256 191 | self.vocab_size = 1000 192 | 193 | self.dummy_targets = jnp.arange( 194 | self.batch_size * self.max_length, dtype=jnp.int32 195 | ).reshape((self.batch_size, self.max_length)) 196 | 197 | self.dummy_inputs = jnp.ones((self.batch_size, self.max_length, self.embed_dim)) 198 | 199 | self.hyperparams = { 200 | "num_layers": 1, 201 | "hidden_dim": self.embed_dim, 202 | "num_heads": 2, 203 | "feedforward_dim": self.embed_dim, 204 | "dropout": 0.1, 205 | "vocab_size": self.vocab_size, 206 | "embed_dim": self.embed_dim, 207 | "max_length": self.max_length, 208 | "start_token": 0, 209 | "end_token": 50, 210 | } 211 | 212 | self.model = Whisper(**self.hyperparams) 213 | 214 | def test_whisper_model_initialization_and_processing(self): 215 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} 216 | 217 | params = self.model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"] 218 | 219 | outputs = self.model.apply( 220 | {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs 221 | ) 222 | 223 | self.assertEqual( 224 | outputs.shape, (self.batch_size, self.max_length, self.vocab_size) 225 | ) 226 | 227 | 228 | class TestDiffusionModel(unittest.TestCase): 229 | def setUp(self): 230 | self.image_size = 32 231 | self.widths = [32, 64, 128] 232 | self.block_depth = 2 233 | self.input_shape = (3, self.image_size, self.image_size, 3) 234 | self.images = jax.random.normal(jax.random.PRNGKey(0), self.input_shape) 235 | 236 | self.model = DiffusionModel(self.image_size, self.widths, self.block_depth) 237 | 238 | def test_diffusion_model_initialization_and_processing(self): 239 | params = self.model.init(jax.random.PRNGKey(0), self.images) 240 | pred_noises, pred_images = self.model.apply(params, self.images) 241 | self.assertEqual(pred_noises.shape, self.input_shape) 242 | self.assertEqual(pred_images.shape, self.input_shape) 243 | 244 | 245 | class TestGATModel(unittest.TestCase): 246 | def setUp(self): 247 | self.num_nodes = 10 248 | self.num_features = 5 249 | self.nclass = 3 250 | 251 | self.x = 
jax.random.normal( 252 | jax.random.PRNGKey(0), (self.num_nodes, self.num_features) 253 | ) 254 | 255 | self.adj = jax.random.bernoulli( 256 | jax.random.PRNGKey(0), 0.3, (self.num_nodes, self.num_nodes) 257 | ) 258 | 259 | self.model = GAT( 260 | nfeat=self.num_features, 261 | nhid=8, 262 | nclass=self.nclass, 263 | dropout_rate=0.5, 264 | alpha=0.2, 265 | nheads=3, 266 | ) 267 | 268 | def test_gat_model_initialization_and_processing(self): 269 | params = self.model.init(jax.random.key(0), self.x, self.adj, training=False) 270 | 271 | output = self.model.apply(params, self.x, self.adj, training=False) 272 | 273 | self.assertEqual(output.shape, (self.num_nodes, self.nclass)) 274 | 275 | 276 | class TestIJEPAModel(unittest.TestCase): 277 | def setUp(self): 278 | self.image_size = 128 279 | self.num_channels = 3 280 | self.patch_size = 16 281 | self.embed_dim = 32 282 | self.predictor_bottleneck = 16 283 | self.num_heads = 4 284 | self.predictor_num_heads = 4 285 | self.num_layers = 2 286 | self.predictor_num_layers = 1 287 | self.dropout_p = 0 288 | self.num_patches = (self.image_size**2) / (self.patch_size**2) 289 | 290 | self.x = jax.random.normal( 291 | jax.random.PRNGKey(0), 292 | (1, self.image_size, self.image_size, self.num_channels), 293 | ) 294 | 295 | self.model = IJEPA( 296 | image_size=self.image_size, 297 | num_channels=self.num_channels, 298 | patch_size=self.patch_size, 299 | embed_dim=self.embed_dim, 300 | predictor_bottleneck=self.predictor_bottleneck, 301 | num_heads=self.num_heads, 302 | predictor_num_heads=self.predictor_num_heads, 303 | num_layers=self.num_layers, 304 | predictor_num_layers=self.predictor_num_layers, 305 | dropout_p=self.dropout_p, 306 | ) 307 | 308 | self.data_sampler = IJEPADataSampler( 309 | image_size=self.image_size, M=4, patch_size=self.patch_size 310 | ) 311 | 312 | def test_ijepa_data_sampling(self): 313 | context_mask, target_mask = self.data_sampler() 314 | self.assertEqual(context_mask.shape, (4, self.num_patches)) 315 | self.assertEqual(target_mask.shape, (4, self.num_patches)) 316 | 317 | def test_ijepa_model_initialization_and_processing(self): 318 | context_mask, target_mask = self.data_sampler() 319 | 320 | params = self.model.init( 321 | jax.random.key(0), 322 | self.x, 323 | context_mask[jnp.newaxis], 324 | target_mask[jnp.newaxis], 325 | training=False, 326 | ) 327 | 328 | outputs, _ = self.model.apply( 329 | params, 330 | self.x, 331 | context_mask[jnp.newaxis], 332 | target_mask[jnp.newaxis], 333 | training=False, 334 | ) 335 | 336 | self.assertEqual(len(outputs), 4) 337 | self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim)) 338 | self.assertEqual(outputs[0][0].shape, outputs[0][1].shape) 339 | 340 | 341 | if __name__ == "__main__": 342 | unittest.main() 343 | -------------------------------------------------------------------------------- /tests/test_random.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | 5 | from nanodl import ( 6 | bernoulli, 7 | bits, 8 | categorical, 9 | chisquare, 10 | choice, 11 | exponential, 12 | gamma, 13 | geometric, 14 | gumbel, 15 | normal, 16 | permutation, 17 | poisson, 18 | randint, 19 | time_rng_key, 20 | triangular, 21 | truncated_normal, 22 | uniform, 23 | ) 24 | 25 | 26 | class TestRandomFunctions(unittest.TestCase): 27 | 28 | def test_time_rng_key(self): 29 | key1 = time_rng_key(seed=42) 30 | key2 = time_rng_key(seed=42) 31 | self.assertTrue( 32 | jnp.array_equal(key1, key2), "Keys 
should be equal for the same seed" 33 | ) 34 | 35 | def test_uniform(self): 36 | result = uniform((2, 3)) 37 | self.assertEqual(result.shape, (2, 3)) 38 | self.assertEqual(result.dtype, jnp.float32) 39 | 40 | def test_normal(self): 41 | result = normal((4, 5), seed=42) 42 | self.assertEqual(result.shape, (4, 5)) 43 | self.assertEqual(result.dtype, jnp.float32) 44 | 45 | def test_bernoulli(self): 46 | result = bernoulli(0.5, (10,), seed=42) 47 | self.assertEqual(result.shape, (10,)) 48 | self.assertEqual(result.dtype, jnp.bool_) 49 | 50 | def test_categorical(self): 51 | logits = jnp.array([0.1, 0.2, 0.7]) 52 | result = categorical(logits, shape=(5,), seed=42) 53 | self.assertEqual(result.shape, (5,)) 54 | 55 | def test_randint(self): 56 | result = randint((3, 3), 0, 10, seed=42) 57 | self.assertEqual(result.shape, (3, 3)) 58 | self.assertEqual(result.dtype, jnp.int32) 59 | 60 | def test_permutation(self): 61 | arr = jnp.arange(10) 62 | result = permutation(arr, seed=42) 63 | self.assertEqual(result.shape, arr.shape) 64 | self.assertNotEqual(jnp.all(result == arr), True) 65 | 66 | def test_gumbel(self): 67 | result = gumbel((2, 2), seed=42) 68 | self.assertEqual(result.shape, (2, 2)) 69 | self.assertEqual(result.dtype, jnp.float32) 70 | 71 | def test_choice(self): 72 | result = choice(5, shape=(3,), seed=42) 73 | self.assertEqual(result.shape, (3,)) 74 | 75 | def test_bits(self): 76 | result = bits((2, 2), seed=42) 77 | self.assertEqual(result.shape, (2, 2)) 78 | self.assertEqual(result.dtype, jnp.uint32) 79 | 80 | def test_exponential(self): 81 | result = exponential((2, 2), seed=42) 82 | self.assertEqual(result.shape, (2, 2)) 83 | self.assertEqual(result.dtype, jnp.float32) 84 | 85 | def test_triangular(self): 86 | result = triangular(0, 1, 0.5, (2, 2), seed=42) 87 | self.assertEqual(result.shape, (2, 2)) 88 | self.assertEqual(result.dtype, jnp.float32) 89 | 90 | def test_truncated_normal(self): 91 | result = truncated_normal(0, 1, (2, 2), seed=42) 92 | self.assertEqual(result.shape, (2, 2)) 93 | self.assertEqual(result.dtype, jnp.float32) 94 | 95 | def test_poisson(self): 96 | result = poisson(3, (2, 2), seed=42) 97 | self.assertEqual(result.shape, (2, 2)) 98 | self.assertEqual(result.dtype, jnp.int32) 99 | 100 | def test_geometric(self): 101 | result = geometric(0.5, (2, 2), seed=42) 102 | self.assertEqual(result.shape, (2, 2)) 103 | self.assertEqual(result.dtype, jnp.int32) 104 | 105 | def test_gamma(self): 106 | result = gamma(2, (2, 2), seed=42) 107 | self.assertEqual(result.shape, (2, 2)) 108 | self.assertEqual(result.dtype, jnp.float32) 109 | 110 | def test_chisquare(self): 111 | result = chisquare(2, (2, 2), seed=42) 112 | self.assertEqual(result.shape, (2, 2)) 113 | self.assertEqual(result.dtype, jnp.float32) 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from nanodl import * 7 | 8 | 9 | class TestDataset(unittest.TestCase): 10 | def test_dataset_length(self): 11 | class DummyDataset(Dataset): 12 | def __init__(self, data): 13 | self.data = data 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, index): 19 | return self.data[index] 20 | 21 | dataset = DummyDataset(jnp.arange(10)) 22 | self.assertEqual(len(dataset), 10) 23 | 24 | def test_dataset_getitem(self): 25 
| class DummyDataset(Dataset): 26 | def __init__(self, data): 27 | self.data = data 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | def __getitem__(self, index): 33 | return self.data[index] 34 | 35 | dataset = DummyDataset(jnp.arange(10)) 36 | item = dataset[5] 37 | self.assertEqual(item, 5) 38 | 39 | 40 | class TestArrayDataset(unittest.TestCase): 41 | def test_array_dataset_length(self): 42 | dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) 43 | self.assertEqual(len(dataset), 3) 44 | 45 | def test_array_dataset_getitem(self): 46 | dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) 47 | item = dataset[1] 48 | self.assertEqual(item, (2, 5)) 49 | 50 | 51 | class TestDataLoader(unittest.TestCase): 52 | def test_data_loader_length(self): 53 | dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256))) 54 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False) 55 | self.assertEqual(len(dataloader), 101) 56 | 57 | def test_data_loader_iteration(self): 58 | dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256))) 59 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=True) 60 | for a, b in dataloader: 61 | self.assertEqual(a.shape, (10, 256, 256)) 62 | self.assertEqual(b.shape, (10, 256, 256)) 63 | 64 | 65 | class TestMLFunctions(unittest.TestCase): 66 | def test_batch_cosine_similarities(self): 67 | source = jnp.array([1, 0, 0]) 68 | candidates = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 69 | similarities = batch_cosine_similarities(source, candidates) 70 | expected_results = jnp.array([1.0, 0.0, 0.0]) 71 | self.assertTrue(jnp.allclose(similarities, expected_results)) 72 | 73 | def test_batch_pearsonr(self): 74 | x = jnp.array([[1, 2, 3], [4, 5, 6]]) 75 | y = jnp.array([[6, 5, 4], [2, 6, 8]]) 76 | correlations = batch_pearsonr(x, y) 77 | expected_results = jnp.array([-1.0, 1.0, 1.0]) 78 | self.assertTrue(jnp.allclose(correlations, expected_results)) 79 | 80 | def test_classification_scores(self): 81 | labels = jnp.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) 82 | preds = jnp.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0]) 83 | scores = classification_scores(labels, preds) 84 | expected_results = jnp.array([0.8, 0.8, 0.8, 0.8000001]) 85 | self.assertTrue(jnp.allclose(scores, expected_results)) 86 | 87 | def test_mean_reciprocal_rank(self): 88 | predictions = jnp.array( 89 | [ 90 | [0, 1, 2], # "correct" prediction at index 0 91 | [1, 0, 2], # "correct" prediction at index 1 92 | [2, 1, 0], # "correct" prediction at index 2 93 | ] 94 | ) 95 | mrr_score = mean_reciprocal_rank(predictions) 96 | self.assertAlmostEqual(mrr_score, 0.61111116) 97 | 98 | def test_jaccard(self): 99 | sequence1 = [1, 2, 3] 100 | sequence2 = [2, 3, 4] 101 | similarity = jaccard(sequence1, sequence2) 102 | self.assertAlmostEqual(similarity, 0.5) 103 | 104 | def test_hamming(self): 105 | sequence1 = jnp.array([1, 2, 3, 4]) 106 | sequence2 = jnp.array([1, 2, 4, 4]) 107 | similarity = hamming(sequence1, sequence2) 108 | self.assertEqual(similarity, 3) 109 | 110 | def test_zero_pad_sequences(self): 111 | arr = jnp.array([[1, 2, 3], [4, 5, 6]]) 112 | max_length = 5 113 | padded_arr = zero_pad_sequences(arr, max_length) 114 | expected_padded_arr = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 0, 0]]) 115 | self.assertTrue(jnp.array_equal(padded_arr, expected_padded_arr)) 116 | 117 | def test_entropy(self): 118 | probabilities = jnp.array([0.25, 0.75]) 119 | entropy_value = entropy(probabilities) 120 | 
self.assertAlmostEqual(entropy_value, 0.8112781) 121 | 122 | def test_gini_impurity(self): 123 | probabilities = jnp.array([0.25, 0.75]) 124 | gini_value = gini_impurity(probabilities) 125 | self.assertAlmostEqual(gini_value, 0.375) 126 | 127 | def test_kl_divergence(self): 128 | p = jnp.array([0.25, 0.75]) 129 | q = jnp.array([0.5, 0.5]) 130 | kl_value = kl_divergence(p, q) 131 | self.assertAlmostEqual(kl_value, 0.18872187) 132 | 133 | def test_count_parameters(self): 134 | class MyModel: 135 | def __init__(self): 136 | self.layer1 = jnp.ones((10, 20)) 137 | self.layer2 = jnp.ones((5, 5)) 138 | 139 | model = MyModel() 140 | params = model.__dict__ 141 | total_params = count_parameters(params) 142 | self.assertEqual(total_params, 225) 143 | 144 | 145 | class TestNLPFunctions(unittest.TestCase): 146 | def setUp(self): 147 | self.hypotheses = [ 148 | "the cat is on the mat", 149 | "there is a cat on the mat", 150 | ] 151 | self.references = [ 152 | "the cat is on the mat", 153 | "the cat sits on the mat", 154 | ] 155 | 156 | def test_rouge(self): 157 | rouge_scores = rouge(self.hypotheses, self.references, [1, 2]) 158 | expected_scores = { 159 | "ROUGE-1": { 160 | "precision": 0.7857142857142857, 161 | "recall": 0.9, 162 | "f1": 0.8333333333328402, 163 | }, 164 | "ROUGE-2": { 165 | "precision": 0.6666666666666666, 166 | "recall": 0.7, 167 | "f1": 0.6818181818176838, 168 | }, 169 | } 170 | 171 | def assert_nested_dicts_equal(dict1, dict2): 172 | for key in dict1.keys(): 173 | if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): 174 | assert_nested_dicts_equal(dict1[key], dict2[key]) 175 | elif dict1[key] != dict2[key]: 176 | raise AssertionError( 177 | f"Values for key '{key}' are not equal: {dict1[key]} != {dict2[key]}" 178 | ) 179 | 180 | assert_nested_dicts_equal(rouge_scores, expected_scores) 181 | 182 | def test_bleu(self): 183 | bleu_score = bleu(self.hypotheses, self.references) 184 | self.assertAlmostEqual(bleu_score, 0.03737737833658239, places=2) 185 | 186 | def test_meteor(self): 187 | meteor_score = meteor(self.hypotheses[0], self.references[0]) 188 | self.assertAlmostEqual(meteor_score, 1.0, places=3) # Perfect match 189 | meteor_score = meteor(self.hypotheses[1], self.references[1]) 190 | self.assertAlmostEqual(meteor_score, 0.4981684981684981, places=3) 191 | 192 | def test_cider_score(self): 193 | score = cider_score(self.hypotheses[0], self.references[0]) 194 | self.assertAlmostEqual(score, 1.0, places=2) # Perfect match 195 | score = cider_score(self.hypotheses[1], self.references[1]) 196 | self.assertAlmostEqual(score, 0.31595617188837527, places=2) 197 | 198 | def test_perplexity(self): 199 | log_probs = [-2.3, -1.7, -0.4] 200 | perplexity_score = perplexity(log_probs) 201 | self.assertAlmostEqual(perplexity_score, 4.0, places=3) 202 | 203 | def test_word_error_rate(self): 204 | wer_score = word_error_rate(self.hypotheses, self.references) 205 | self.assertAlmostEqual(wer_score, 0.3333333333333333, places=2) 206 | 207 | 208 | class TestVisionFunctions(unittest.TestCase): 209 | def test_normalize_images(self): 210 | images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) 211 | normalized_images = normalize_images(images) 212 | self.assertAlmostEqual(normalized_images.mean(), 0.0, places=3) 213 | self.assertAlmostEqual(normalized_images.std(), 1.0, places=3) 214 | 215 | def test_random_crop(self): 216 | images = jnp.ones((10, 100, 100, 3)) 217 | crop_size = 64 218 | cropped_images = random_crop(images, crop_size) 219 | self.assertEqual(cropped_images.shape, (10, 
crop_size, crop_size, 3)) 220 | 221 | def test_gaussian_blur(self): 222 | image = jnp.ones((5, 5, 3)) 223 | blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0) 224 | self.assertEqual(blurred_image.shape, (5, 5, 3)) 225 | 226 | def test_sobel_edge_detection(self): 227 | image = jnp.ones((5, 5, 3)) 228 | edges = sobel_edge_detection(image) 229 | self.assertEqual(edges.shape, (5, 5)) 230 | 231 | def test_adjust_brightness(self): 232 | image = jnp.ones((5, 5, 3)) 233 | adjusted_image = adjust_brightness(image, factor=1.5) 234 | self.assertEqual(adjusted_image.shape, (5, 5, 3)) 235 | 236 | def test_adjust_contrast(self): 237 | image = jnp.ones((5, 5, 3)) 238 | adjusted_image = adjust_contrast(image, factor=1.5) 239 | self.assertEqual(adjusted_image.shape, (5, 5, 3)) 240 | 241 | def test_flip_image(self): 242 | image = jnp.ones((5, 5, 3)) 243 | flipped_image_horizontally = flip_image(image, jnp.array([True])) 244 | flipped_image_vertically = flip_image(image, jnp.array([False])) 245 | self.assertEqual(flipped_image_horizontally.shape, (5, 5, 3)) 246 | self.assertEqual(flipped_image_vertically.shape, (5, 5, 3)) 247 | 248 | def test_random_flip_image(self): 249 | key = jax.random.PRNGKey(0) 250 | image = jnp.ones((5, 5, 3)) 251 | flipped_image = random_flip_image(image, key, jnp.array([True])) 252 | self.assertEqual(flipped_image.shape, (5, 5, 3)) 253 | 254 | 255 | if __name__ == "__main__": 256 | unittest.main() 257 | --------------------------------------------------------------------------------
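The vision utilities in nanodl/__src/utils/vision.py are exercised one at a time by tests/test_utils.py; the sketch below shows how they might be composed into a single augmentation pass. It is illustrative only: the `augment` helper and its parameter choices are not part of the library, and it assumes `random_crop`, `gaussian_blur`, `adjust_brightness`, and `flip_image` are importable from the top-level `nanodl` package, as the tests' `from nanodl import *` usage suggests.

```python
# Illustrative augmentation pipeline built from nanodl's vision utilities
# (the augment() helper itself is hypothetical, not part of the library).
import jax.numpy as jnp

from nanodl import adjust_brightness, flip_image, gaussian_blur, random_crop


def augment(images: jnp.ndarray, crop_size: int = 64) -> jnp.ndarray:
    """Crop a batch, then blur, brighten, and horizontally flip each image."""
    cropped = random_crop(images, crop_size)  # (N, crop_size, crop_size, C)
    blurred = jnp.stack(
        [gaussian_blur(img, kernel_size=3, sigma=1.0) for img in cropped]
    )
    brightened = adjust_brightness(blurred, factor=1.2)  # broadcasts over the batch
    return jnp.stack(
        [flip_image(img, jnp.array([True])) for img in brightened]
    )


if __name__ == "__main__":
    batch = jnp.ones((4, 100, 100, 3))
    print(augment(batch).shape)  # expected: (4, 64, 64, 3)
```

Because `gaussian_blur` and `flip_image` operate on single (H, W, C) images while `random_crop` and `adjust_brightness` accept a batch, the sketch maps over the leading axis with a plain Python loop.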