├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── python-publish.yml
│       └── unit-tests.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── assets
│   └── logo.jpg
├── docs
│   ├── api.md
│   ├── examples
│   │   ├── diffusion_example.ipynb
│   │   ├── gat_example.ipynb
│   │   ├── gpt4_example.ipynb
│   │   ├── llama_example.ipynb
│   │   ├── mistral_copy_example.ipynb
│   │   ├── sklearn_gpu_example.ipynb
│   │   ├── utils_examples.ipynb
│   │   └── whisper_example.ipynb
│   ├── index.md
│   ├── requirements.in
│   ├── requirements.txt
│   └── usage.md
├── mkdocs.yml
├── nanodl
│   ├── __init__.py
│   └── __src
│       ├── __init__.py
│       ├── classical
│       │   ├── __init__.py
│       │   ├── bayes.py
│       │   ├── clustering.py
│       │   ├── dimensionality_reduction.py
│       │   ├── dsp.py
│       │   └── regression.py
│       ├── experimental
│       │   ├── __init__.py
│       │   ├── bitlinear.py
│       │   ├── gat.py
│       │   ├── mamba.py
│       │   ├── rlhf.py
│       │   └── tokenizer.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── attention.py
│       │   ├── clip.py
│       │   ├── diffusion.py
│       │   ├── gemma.py
│       │   ├── gpt.py
│       │   ├── ijepa.py
│       │   ├── kan.py
│       │   ├── lamda.py
│       │   ├── llama.py
│       │   ├── mistral.py
│       │   ├── mixer.py
│       │   ├── reward.py
│       │   ├── t5.py
│       │   ├── transformer.py
│       │   ├── vit.py
│       │   └── whisper.py
│       └── utils
│           ├── __init__.py
│           ├── data.py
│           ├── ml.py
│           ├── nlp.py
│           ├── random.py
│           └── vision.py
├── requirements.txt
├── setup.py
└── tests
    ├── files
    │   └── sample.txt
    ├── test_classic.py
    ├── test_kan.py
    ├── test_models.py
    ├── test_random.py
    └── test_utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [HMUNACHI]
2 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Build and publish
21 | env:
22 | TWINE_USERNAME: __token__
23 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
24 | run: |
25 | echo "Setting version to: ${GITHUB_REF#refs/tags/}"
26 | sed -i "s/version='.*',/version='${GITHUB_REF#refs/tags/}',/" setup.py
27 | sed -i "s/__version__ = \".*\"/__version__ = \"${GITHUB_REF#refs/tags/}\"/" nanodl/__init__.py
28 | python setup.py sdist bdist_wheel
29 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/unit-tests.yml:
--------------------------------------------------------------------------------
1 | name: Python Unit Tests
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | test:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: '3.10'
16 | - name: Install dependencies
17 | run: |
18 | pip install -r requirements.txt
19 | - name: Run tests
20 | run: |
21 | python -m unittest discover -s tests
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore compiled object files, libraries, and executables
2 | *.o
3 | *.a
4 | *.exe
5 |
6 | # Ignore system-specific files
7 | .DS_Store
8 | Thumbs.db
9 |
10 | # Ignore logs and temporary files
11 | *.log
12 | *.tmp
13 |
14 | # Ignore build directories
15 | /build/
16 | /dist/
17 | /node_modules/
18 | /.venv/
19 | /.vscode/
20 | /nanodl.egg-info/
21 | __pycache__/
22 |
23 | # Ignore configuration files with sensitive information
24 | config.ini
25 | secrets.yaml
26 | params.pkl
27 | base_params.pkl
28 | reward_params.pkl
29 |
30 | # Ignore user-specific files
31 | /userdata/
32 | /user-settings.json
33 |
34 | # ignore specific files
35 | /archive/
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.10"
7 |
8 | mkdocs:
9 | configuration: mkdocs.yml
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Henry Ndubuaku
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![NanoDL Logo](assets/logo.jpg)
2 |
3 |
4 |
5 | # A Jax-based library for designing and training transformer models from scratch.
6 |
7 | [Documentation](https://nanodl.readthedocs.io) | [Discord](https://discord.gg/3u9vumJEmz) | [LinkedIn](https://www.linkedin.com//company/80434055) | [Twitter](https://twitter.com/hmunachii)
8 |
9 | Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)
10 |
11 | N.B.: The code is written pedagogically, at the expense of some repetition.
12 | Each model is purposefully contained in a file without inter-file dependencies.
13 |
14 | ## Overview
15 | Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks and abstracts distributed training, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:
16 |
17 | - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
18 | - An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc.
19 | - Data-parallel distributed trainers for training models on multiple GPUs or TPUs, without the need for manual training loops.
20 | - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
21 | - Layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
22 | - GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc.
23 | - True random number generators in Jax which do not need verbose key-handling code (see the short example after this list).
24 | - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc.
25 | - Each model is contained in a single file with no external dependencies, so the source code can also be easily used.
27 |
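Creating random arrays in plain JAX requires explicit key creation and splitting; NanoDL wraps this in single-call helpers. Below is a minimal sketch based on the calls used in the examples later in this README (`nanodl.normal`, `nanodl.uniform`, `nanodl.time_rng_key`):

```py
import jax
import nanodl

# Plain JAX: explicit key management
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3, 3))

# NanoDL: the key is handled internally
y = nanodl.normal(shape=(3, 3))                        # standard-normal samples
z = nanodl.uniform(shape=(3, 3), minval=0, maxval=10)  # uniform samples
dropout_key = nanodl.time_rng_key()                    # fresh PRNG key, e.g. for rngs={'dropout': ...}
```
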
28 | There are experimental and/or unfinished features (like MAMBA, KAN, BitNet, GAT and RLHF)
29 | in the repo which are not yet available via the package, but they can be copied from this repo.
30 | Feedback on any of our discussion, issue and pull request threads is welcome!
31 | Please report any feature requests, issues, questions or concerns in the [Discord](https://discord.gg/3u9vumJEmz),
32 | or just let us know what you're working on!
33 |
34 | ## Quick install
35 |
36 | You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md),
37 | [FLAX](https://github.com/google/flax/blob/main/README.md), and
38 | [OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md)
39 | installations (with GPU support for running training; without it, you can only build models).
40 | Models can be designed and tested on CPUs, but the trainers are all distributed data-parallel,
41 | which requires 1 to N GPUs/TPUs. For a CPU-only version of JAX:
42 |
43 | ```
44 | pip install --upgrade pip # To support manylinux2010 wheels.
45 | pip install jax flax optax
46 | ```
47 |
48 | Then, install nanodl from PyPi:
49 |
50 | ```
51 | pip install nanodl
52 | ```
53 |
54 | ## What does nanodl look like?
55 |
56 | We provide various example usages of the nanodl API.
57 |
58 | ```py
59 | import jax
60 | import nanodl
61 | import jax.numpy as jnp
62 | from nanodl import ArrayDataset, DataLoader
63 | from nanodl import GPT4, GPTDataParallelTrainer
64 |
65 | # Preparing your dataset
66 | batch_size = 8
67 | max_length = 50
68 | vocab_size = 1000
69 |
70 | # Create random data
71 | data = nanodl.uniform(
72 | shape=(batch_size, max_length),
73 | minval=0, maxval=vocab_size-1
74 | ).astype(jnp.int32)
75 |
76 | # Shift to create next-token prediction dataset
77 | dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
78 |
79 | # Create dataset and dataloader
80 | dataset = ArrayDataset(dummy_inputs, dummy_targets)
81 | dataloader = DataLoader(
82 | dataset, batch_size=batch_size, shuffle=True, drop_last=False
83 | )
84 |
85 | # model parameters
86 | hyperparams = {
87 | 'num_layers': 1,
88 | 'hidden_dim': 256,
89 | 'num_heads': 2,
90 | 'feedforward_dim': 256,
91 | 'dropout': 0.1,
92 | 'vocab_size': vocab_size,
93 | 'embed_dim': 256,
94 | 'max_length': max_length,
95 | 'start_token': 0,
96 | 'end_token': 50,
97 | }
98 |
99 | # Inferred GPT4 model
100 | model = GPT4(**hyperparams)
101 |
102 | trainer = GPTDataParallelTrainer(
103 | model, dummy_inputs.shape, 'params.pkl'
104 | )
105 |
106 | trainer.train(
107 | train_loader=dataloader, num_epochs=100, val_loader=dataloader
108 | ) # use actual val data
109 |
110 | # Generating from a start token
111 | start_tokens = jnp.array([[123, 456]])
112 |
113 | # Remember to load the trained parameters
114 | params = trainer.load_params('params.pkl')
115 |
116 | outputs = model.apply(
117 | {'params': params},
118 | start_tokens,
119 | rngs={'dropout': nanodl.time_rng_key()},
120 | method=model.generate
121 | )
122 | ```
123 |
124 | Vision example
125 |
126 | ```py
127 | import nanodl
128 | import jax.numpy as jnp
129 | from nanodl import ArrayDataset, DataLoader
130 | from nanodl import DiffusionModel, DiffusionDataParallelTrainer
131 |
132 | image_size = 32
133 | block_depth = 2
134 | batch_size = 8
135 | widths = [32, 64, 128]
136 | input_shape = (101, image_size, image_size, 3)
137 | images = nanodl.normal(shape=input_shape)
138 |
139 | # Use your own images
140 | dataset = ArrayDataset(images)
141 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
142 |
143 | # Create diffusion model
144 | diffusion_model = DiffusionModel(image_size, widths, block_depth)
145 |
146 | # Training on your data
147 | trainer = DiffusionDataParallelTrainer(diffusion_model,
148 | input_shape=images.shape,
149 | weights_filename='params.pkl',
150 | learning_rate=1e-4)
151 |
152 | trainer.train(dataloader, 10)
153 |
154 | # Generate some samples: Each model is a Flax.linen module
155 | # Use as you normally would
156 | params = trainer.load_params('params.pkl')
157 | generated_images = diffusion_model.apply({'params': params},
158 | num_images=5,
159 | diffusion_steps=5,
160 | method=diffusion_model.generate)
161 | ```
162 |
163 | Audio example
164 |
165 | ```py
166 | import jax
167 | import jax.numpy as jnp
168 | from nanodl import ArrayDataset, DataLoader
169 | from nanodl import Whisper, WhisperDataParallelTrainer
170 |
171 | # Dummy data parameters
172 | batch_size = 8
173 | max_length = 50
174 | embed_dim = 256
175 | vocab_size = 1000
176 |
177 | # Generate data: replace with actual tokenised/quantised data
178 | dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
179 | dummy_inputs = jnp.ones((101, max_length, embed_dim))
180 |
181 | dataset = ArrayDataset(dummy_inputs, dummy_targets)
182 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
183 |
184 | # model parameters
185 | hyperparams = {
186 | 'num_layers': 1,
187 | 'hidden_dim': 256,
188 | 'num_heads': 2,
189 | 'feedforward_dim': 256,
190 | 'dropout': 0.1,
191 | 'vocab_size': 1000,
192 | 'embed_dim': embed_dim,
193 | 'max_length': max_length,
194 | 'start_token': 0,
195 | 'end_token': 50,
196 | }
197 |
198 | # Initialize model
199 | model = Whisper(**hyperparams)
200 |
201 | # Training on your data
202 | trainer = WhisperDataParallelTrainer(model,
203 | dummy_inputs.shape,
204 | dummy_targets.shape,
205 | 'params.pkl')
206 |
207 | trainer.train(dataloader, 2, dataloader)
208 |
209 | # Sample inference
210 | params = trainer.load_params('params.pkl')
211 |
212 | # for more than one sample, use model.generate_batch
213 | transcripts = model.apply({'params': params},
214 | dummy_inputs[:1],
215 | method=model.generate)
216 | ```
217 |
218 | Reward Model example for RLHF
219 |
220 | ```py
221 | import nanodl
222 | import jax.numpy as jnp
223 | from nanodl import ArrayDataset, DataLoader
224 | from nanodl import Mistral, RewardModel, RewardDataParallelTrainer
225 |
226 | # Generate dummy data
227 | batch_size = 8
228 | max_length = 10
229 |
230 | # Replace with actual tokenised data
231 | dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
232 | dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
233 |
234 | # Create dataset and dataloader
235 | dataset = ArrayDataset(dummy_chosen, dummy_rejected)
236 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
237 |
238 | # model parameters
239 | hyperparams = {
240 | 'num_layers': 1,
241 | 'hidden_dim': 256,
242 | 'num_heads': 2,
243 | 'feedforward_dim': 256,
244 | 'dropout': 0.1,
245 | 'vocab_size': 1000,
246 | 'embed_dim': 256,
247 | 'max_length': max_length,
248 | 'start_token': 0,
249 | 'end_token': 50,
250 | 'num_groups': 2,
251 | 'window_size': 5,
252 | 'shift_size': 2
253 | }
254 |
255 | # Initialize reward model from Mistral
256 | model = Mistral(**hyperparams)
257 | reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)
258 |
259 | # Train the reward model
260 | trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
261 | trainer.train(dataloader, 5, dataloader)
262 | params = trainer.load_params('reward_model_weights.pkl')
263 |
264 | # Call as you would a regular Flax model
265 | rewards = reward_model.apply({'params': params},
266 | dummy_chosen,
267 | rngs={'dropout': nanodl.time_rng_key()})
268 | ```
269 |
270 | PCA example
271 |
272 | ```py
273 | import nanodl
274 | from nanodl import PCA
275 |
276 | # Use actual data
277 | data = nanodl.normal(shape=(1000, 10))
278 |
279 | # Initialise and train PCA model
280 | pca = PCA(n_components=2)
281 | pca.fit(data)
282 |
283 | # Get PCA transforms
284 | transformed_data = pca.transform(data)
285 |
286 | # Get reverse transforms
287 | original_data = pca.inverse_transform(transformed_data)
288 |
289 | # Sample from the distribution
290 | X_sampled = pca.sample(n_samples=1000, key=None)
291 | ```
292 |
293 | This is still in development; it works well, but rough edges are expected, so contributions are highly encouraged! To contribute:
294 |
295 | - Make your changes without changing the design patterns.
296 | - Write tests for your changes if necessary.
297 | - Install locally with `pip3 install -e .`.
298 | - Run tests with `python3 -m unittest discover -s tests`.
299 | - Then submit a pull request.
300 |
301 | Contributions can be made in various forms:
302 |
303 | - Writing documentation.
304 | - Fixing bugs.
305 | - Implementing papers.
306 | - Writing high-coverage tests.
307 | - Optimizing existing code.
308 | - Experimenting and submitting real-world examples to the examples section.
309 | - Reporting bugs.
310 | - Responding to reported issues.
311 |
312 | Join the [Discord Server](https://discord.gg/3u9vumJEmz) for more.
313 |
314 | ## Sponsorships
315 |
316 | The name "NanoDL" stands for Nano Deep Learning. Models are exploding in size, which gate-keeps
317 | experts and companies with limited resources from building flexible models without prohibitive costs.
318 | Following the success of the Phi models, the long-term goal is to build and train nano versions of all available models,
319 | while ensuring they compete with the original models in performance, with the total
320 | number of parameters not exceeding 1B. Trained weights will be made available via this library.
321 | Any form of sponsorship or funding will help with training resources.
322 | You can either sponsor via GitHub [here](https://github.com/sponsors/HMUNACHI) or reach out via ndubuakuhenry@gmail.com.
323 |
324 | ## Citing nanodl
325 |
326 | To cite this repository:
327 |
328 | ```
329 | @software{nanodl2024github,
330 | author = {Henry Ndubuaku},
331 | title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.},
332 | url = {http://github.com/hmunachi/nanodl},
333 | year = {2024},
334 | }
335 | ```
336 |
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/assets/logo.jpg
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # Welcome to NanoDL Documentation
2 |
3 | ## API Reference
4 |
5 | ::: nanodl.GAT
6 | ::: nanodl.GraphAttentionLayer
7 | ::: nanodl.T5
8 | ::: nanodl.T5DataParallelTrainer
9 | ::: nanodl.T5Encoder
10 | ::: nanodl.T5Decoder
11 | ::: nanodl.T5EncoderBlock
12 | ::: nanodl.T5DecoderBlock
13 | ::: nanodl.ViT
14 | ::: nanodl.ViTDataParallelTrainer
15 | ::: nanodl.ViTBlock
16 | ::: nanodl.ViTEncoder
17 | ::: nanodl.PatchEmbedding
18 | ::: nanodl.Transformer
19 | ::: nanodl.TransformerDataParallelTrainer
20 | ::: nanodl.TransformerEncoder
21 | ::: nanodl.TransformerDecoderBlock
22 | ::: nanodl.PositionalEncoding
23 | ::: nanodl.TokenAndPositionEmbedding
24 | ::: nanodl.MultiHeadAttention
25 | ::: nanodl.AddNorm
26 | ::: nanodl.CLIP
27 | ::: nanodl.CLIPDataParallelTrainer
28 | ::: nanodl.ImageEncoder
29 | ::: nanodl.TextEncoder
30 | ::: nanodl.SelfMultiHeadAttention
31 | ::: nanodl.LaMDA
32 | ::: nanodl.LaMDADataParallelTrainer
33 | ::: nanodl.LaMDABlock
34 | ::: nanodl.LaMDADecoder
35 | ::: nanodl.RelativeMultiHeadAttention
36 | ::: nanodl.DiffusionModel
37 | ::: nanodl.DiffusionDataParallelTrainer
38 | ::: nanodl.UNet
39 | ::: nanodl.UNetDownBlock
40 | ::: nanodl.UNetUpBlock
41 | ::: nanodl.UNetResidualBlock
42 | ::: nanodl.GPT3
43 | ::: nanodl.GPT4
44 | ::: nanodl.GPTDataParallelTrainer
45 | ::: nanodl.GPT3Block
46 | ::: nanodl.GPT4Block
47 | ::: nanodl.GPT3Decoder
48 | ::: nanodl.GPT4Decoder
49 | ::: nanodl.PositionWiseFFN
50 | ::: nanodl.LlaMA2
51 | ::: nanodl.LlaMADataParallelTrainer
52 | ::: nanodl.RotaryPositionalEncoding
53 | ::: nanodl.LlaMA2Decoder
54 | ::: nanodl.LlaMA2DecoderBlock
55 | ::: nanodl.GroupedRotaryMultiHeadAttention
56 | ::: nanodl.Mistral
57 | ::: nanodl.MistralDataParallelTrainer
58 | ::: nanodl.MistralDecoder
59 | ::: nanodl.MistralDecoderBlock
60 | ::: nanodl.GroupedRotaryShiftedWindowMultiHeadAttention
61 | ::: nanodl.Mixtral
62 | ::: nanodl.MixtralDecoder
63 | ::: nanodl.MixtralDecoderBlock
64 | ::: nanodl.Whisper
65 | ::: nanodl.WhisperDataParallelTrainer
66 | ::: nanodl.WhisperSpeechEncoder
67 | ::: nanodl.WhisperSpeechEncoderBlock
70 | ::: nanodl.NaiveBayesClassifier
71 | ::: nanodl.LinearRegression
72 | ::: nanodl.LogisticRegression
73 | ::: nanodl.GaussianProcess
74 | ::: nanodl.KMeans
75 | ::: nanodl.GaussianMixtureModel
76 | ::: nanodl.PCA
77 | ::: nanodl.Dataset
78 | ::: nanodl.ArrayDataset
79 | ::: nanodl.DataLoader
80 | ::: nanodl.batch_cosine_similarities
81 | ::: nanodl.batch_pearsonr
82 | ::: nanodl.classification_scores
83 | ::: nanodl.count_parameters
84 | ::: nanodl.entropy
85 | ::: nanodl.gini_impurity
86 | ::: nanodl.hamming
87 | ::: nanodl.jaccard
88 | ::: nanodl.kl_divergence
89 | ::: nanodl.mean_reciprocal_rank
90 | ::: nanodl.zero_pad_sequences
91 | ::: nanodl.bleu
92 | ::: nanodl.cider_score
93 | ::: nanodl.meteor
94 | ::: nanodl.perplexity
95 | ::: nanodl.rouge
96 | ::: nanodl.word_error_rate
97 | ::: nanodl.adjust_brightness
98 | ::: nanodl.adjust_contrast
99 | ::: nanodl.flip_image
100 | ::: nanodl.gaussian_blur
101 | ::: nanodl.normalize_images
102 | ::: nanodl.random_crop
103 | ::: nanodl.random_flip_image
104 | ::: nanodl.sobel_edge_detection
105 |
--------------------------------------------------------------------------------
/docs/examples/diffusion_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import jax\n",
10 | "import jax.numpy as jnp\n",
11 | "from nanodl import ArrayDataset, DataLoader\n",
12 | "from nanodl import DiffusionModel, DiffusionDataParallelTrainer\n",
13 | "\n",
14 | "image_size = 32\n",
15 | "block_depth = 2\n",
16 | "batch_size = 8\n",
17 | "widths = [32, 64, 128]\n",
18 | "key = jax.random.PRNGKey(0)"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "# Use actual images\n",
28 | "images = jnp.ones((101, image_size, image_size, 3))\n",
29 | "dataset = ArrayDataset(images) \n",
30 | "dataloader = DataLoader(dataset, \n",
31 | " batch_size=batch_size, \n",
32 | " shuffle=True, \n",
33 | " drop_last=False) "
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "(101, 32, 32, 3) (101, 32, 32, 3)\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "# Create diffusion model\n",
51 | "diffusion_model = DiffusionModel(image_size, widths, block_depth)\n",
52 | "params = diffusion_model.init(key, images)\n",
53 | "pred_noises, pred_images = diffusion_model.apply(params, images)\n",
54 | "print(pred_noises.shape, pred_images.shape)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 4,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Number of parameters: 1007395\n",
67 | "Number of accelerators: 1\n",
68 | "Epoch 1, Train Loss: 7.979574203491211\n",
69 | "Epoch 1, Val Loss: 24.317951202392578\n",
70 | "New best validation score achieved, saving model...\n",
71 | "Epoch 2, Train Loss: 7.75728702545166\n",
72 | "Epoch 2, Val Loss: 23.518024444580078\n",
73 | "New best validation score achieved, saving model...\n",
74 | "Epoch 3, Train Loss: 7.392527103424072\n",
75 | "Epoch 3, Val Loss: 22.308382034301758\n",
76 | "New best validation score achieved, saving model...\n",
77 | "Epoch 4, Train Loss: 6.846263408660889\n",
78 | "Epoch 4, Val Loss: 20.62131690979004\n",
79 | "New best validation score achieved, saving model...\n",
80 | "Epoch 5, Train Loss: 6.1358747482299805\n",
81 | "Epoch 5, Val Loss: 18.36245346069336\n",
82 | "New best validation score achieved, saving model...\n",
83 | "Epoch 6, Train Loss: 5.278435230255127\n",
84 | "Epoch 6, Val Loss: 15.812017440795898\n",
85 | "New best validation score achieved, saving model...\n",
86 | "Epoch 7, Train Loss: 4.328006267547607\n",
87 | "Epoch 7, Val Loss: 13.123092651367188\n",
88 | "New best validation score achieved, saving model...\n",
89 | "Epoch 8, Train Loss: 3.3344056606292725\n",
90 | "Epoch 8, Val Loss: 10.264131546020508\n",
91 | "New best validation score achieved, saving model...\n",
92 | "Epoch 9, Train Loss: 2.401970386505127\n",
93 | "Epoch 9, Val Loss: 7.67496919631958\n",
94 | "New best validation score achieved, saving model...\n",
95 | "Epoch 10, Train Loss: 1.6279072761535645\n",
96 | "Epoch 10, Val Loss: 5.5517578125\n",
97 | "New best validation score achieved, saving model...\n",
98 | "5.551758\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "# Training on your data\n",
104 | "# Note: saved params are often different from training weights, use the saved params for generation\n",
105 | "trainer = DiffusionDataParallelTrainer(diffusion_model, \n",
106 | " input_shape=images.shape, \n",
107 | " weights_filename='params.pkl', \n",
108 | " learning_rate=1e-4)\n",
109 | "trainer.train(dataloader, 10, dataloader)\n",
110 | "print(trainer.evaluate(dataloader))"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 6,
116 | "metadata": {},
117 | "outputs": [
118 | {
119 | "name": "stdout",
120 | "output_type": "stream",
121 | "text": [
122 | "(5, 32, 32, 3)\n"
123 | ]
124 | }
125 | ],
126 | "source": [
127 | "# Generate some samples\n",
128 | "params = trainer.load_params('params.pkl')\n",
129 | "generated_images = diffusion_model.apply({'params': params}, \n",
130 | " num_images=5, \n",
131 | " diffusion_steps=5, \n",
132 | " method=diffusion_model.generate)\n",
133 | "print(generated_images.shape)"
134 | ]
135 | }
136 | ],
137 | "metadata": {
138 | "kernelspec": {
139 | "display_name": "Python 3",
140 | "language": "python",
141 | "name": "python3"
142 | },
143 | "language_info": {
144 | "codemirror_mode": {
145 | "name": "ipython",
146 | "version": 3
147 | },
148 | "file_extension": ".py",
149 | "mimetype": "text/x-python",
150 | "name": "python",
151 | "nbconvert_exporter": "python",
152 | "pygments_lexer": "ipython3",
153 | "version": "3.10.0"
154 | }
155 | },
156 | "nbformat": 4,
157 | "nbformat_minor": 2
158 | }
159 |
--------------------------------------------------------------------------------
/docs/examples/gat_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Output shape: (10, 3)\n",
13 | "Output sample: [[-0.77854604 0.63543594 0.19899184]\n",
14 | " [-0.3826082 0.5630409 0.01140817]]\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "import jax\n",
20 | "from nanodl import GAT\n",
21 | "\n",
22 | "# Generate a random key for Jax\n",
23 | "key = jax.random.PRNGKey(0)\n",
24 | "\n",
25 | "# Create dummy input data\n",
26 | "num_nodes = 10\n",
27 | "num_features = 5\n",
28 | "x = jax.random.normal(key, (num_nodes, num_features)) # Features for each node\n",
29 | "adj = jax.random.bernoulli(key, 0.3, (num_nodes, num_nodes)) # Random adjacency matrix\n",
30 | "\n",
31 | "# Initialize the GAT model\n",
32 | "model = GAT(nfeat=num_features,\n",
33 | " nhid=8, \n",
34 | " nclass=3, \n",
35 | " dropout_rate=0.5, \n",
36 | " alpha=0.2, \n",
37 | " nheads=3)\n",
38 | "\n",
39 | "# Initialize the model parameters\n",
40 | "params = model.init(key, x, adj)\n",
41 | "output = model.apply(params, x, adj)\n",
42 | "\n",
43 | "# Print the output shape and a sample of the output\n",
44 | "print(\"Output shape:\", output.shape)\n",
45 | "print(\"Output sample:\", output[:2])"
46 | ]
47 | }
48 | ],
49 | "metadata": {
50 | "kernelspec": {
51 | "display_name": "Python 3",
52 | "language": "python",
53 | "name": "python3"
54 | },
55 | "language_info": {
56 | "codemirror_mode": {
57 | "name": "ipython",
58 | "version": 3
59 | },
60 | "file_extension": ".py",
61 | "mimetype": "text/x-python",
62 | "name": "python",
63 | "nbconvert_exporter": "python",
64 | "pygments_lexer": "ipython3",
65 | "version": "3.10.0"
66 | }
67 | },
68 | "nbformat": 4,
69 | "nbformat_minor": 2
70 | }
71 |
--------------------------------------------------------------------------------
/docs/examples/gpt4_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import jax\n",
10 | "import jax.numpy as jnp\n",
11 | "from nanodl import ArrayDataset, DataLoader\n",
12 | "from nanodl import GPT4, GPTDataParallelTrainer\n",
13 | "\n",
14 | "# Generate dummy data\n",
15 | "batch_size = 8\n",
16 | "max_length = 9"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "(8, 9) (8, 9)\n"
29 | ]
30 | }
31 | ],
32 | "source": [
33 | "# Replace with actual tokenised data\n",
34 | "data = jnp.ones((101, max_length+1), dtype=jnp.int32)\n",
35 | "\n",
36 | "# Shift to create next-token prediction dataset\n",
37 | "dummy_inputs = data[:, :-1]\n",
38 | "dummy_targets = data[:, 1:]\n",
39 | "\n",
40 | "# Create dataset and dataloader\n",
41 | "dataset = ArrayDataset(dummy_inputs, dummy_targets)\n",
42 | "dataloader = DataLoader(dataset, \n",
43 | " batch_size=batch_size, \n",
44 | " shuffle=True, \n",
45 | " drop_last=False)\n",
46 | "\n",
47 | "# How to loop through dataloader\n",
48 | "for batch in dataloader:\n",
49 | " x, y = batch\n",
50 | " print(x.shape, y.shape)\n",
51 | " break"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "(101, 9, 1000)\n"
64 | ]
65 | }
66 | ],
67 | "source": [
68 | "# model parameters\n",
69 | "hyperparams = {\n",
70 | " 'num_layers': 1,\n",
71 | " 'hidden_dim': 256,\n",
72 | " 'num_heads': 2,\n",
73 | " 'feedforward_dim': 256,\n",
74 | " 'dropout': 0.1,\n",
75 | " 'vocab_size': 1000,\n",
76 | " 'embed_dim': 256,\n",
77 | " 'max_length': max_length,\n",
78 | " 'start_token': 0,\n",
79 | " 'end_token': 50,\n",
80 | "}\n",
81 | "\n",
82 | "# Initialize model\n",
83 | "model = GPT4(**hyperparams)\n",
84 | "rngs = jax.random.PRNGKey(0)\n",
85 | "rngs, dropout_rng = jax.random.split(rngs)\n",
86 | "params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']\n",
87 | "\n",
88 | "# Call as you would a Jax/Flax model\n",
89 | "outputs = model.apply({'params': params}, \n",
90 | " dummy_inputs, \n",
91 | " rngs={'dropout': dropout_rng})\n",
92 | "print(outputs.shape)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "Number of parameters: 3740914\n",
105 | "Number of accelerators: 1\n",
106 | "Epoch 1, Train Loss: 6.438497066497803\n",
107 | "Epoch 1, Val Loss: 5.959759712219238\n",
108 | "New best validation score achieved, saving model...\n",
109 | "5.9597597\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "# Training on data\n",
115 | "trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')\n",
116 | "trainer.train(train_loader=dataloader, \n",
117 | " num_epochs=1, \n",
118 | " val_loader=dataloader) # Use actual validation data\n",
119 | "\n",
120 | "print(trainer.evaluate(dataloader))"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 6,
126 | "metadata": {},
127 | "outputs": [
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "[639 742 45 840 381 555 251 814 478 261]\n"
133 | ]
134 | }
135 | ],
136 | "source": [
137 | "# Generating from a start token\n",
138 | "start_tokens = jnp.array([[123, 456]])\n",
139 | "\n",
140 | "# Remember to load the trained parameters \n",
141 | "params = trainer.load_params('params.pkl')\n",
142 | "outputs = model.apply({'params': params},\n",
143 | " start_tokens,\n",
144 | " rngs={'dropout': jax.random.PRNGKey(2)}, \n",
145 | " method=model.generate)\n",
146 | "print(outputs) "
147 | ]
148 | }
149 | ],
150 | "metadata": {
151 | "kernelspec": {
152 | "display_name": "Python 3",
153 | "language": "python",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 3
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython3",
166 | "version": "3.10.0"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 2
171 | }
172 |
--------------------------------------------------------------------------------
/docs/examples/llama_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import jax\n",
10 | "import jax.numpy as jnp\n",
11 | "from nanodl import ArrayDataset, DataLoader\n",
12 | "from nanodl import LlaMA2, LlaMADataParallelTrainer\n",
13 | "\n",
14 | "# Generate dummy data\n",
15 | "batch_size = 8\n",
16 | "max_length = 10"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 3,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "(8, 10) (8, 10)\n"
29 | ]
30 | }
31 | ],
32 | "source": [
33 | "# Replace with actual tokenised data\n",
34 | "data = jnp.ones((101, max_length+1), dtype=jnp.int32)\n",
35 | "\n",
36 | "# Shift to create next-token prediction dataset\n",
37 | "dummy_inputs = data[:, :-1]\n",
38 | "dummy_targets = data[:, 1:]\n",
39 | "\n",
40 | "# Create dataset and dataloader\n",
41 | "dataset = ArrayDataset(dummy_inputs, dummy_targets)\n",
42 | "dataloader = DataLoader(dataset, \n",
43 | " batch_size=batch_size, \n",
44 | " shuffle=True, \n",
45 | " drop_last=False)\n",
46 | "\n",
47 | "# How to loop through dataloader\n",
48 | "for batch in dataloader:\n",
49 | " x, y = batch\n",
50 | " print(x.shape, y.shape)\n",
51 | " break"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 5,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "(101, 10, 1000)\n"
64 | ]
65 | }
66 | ],
67 | "source": [
68 | "# model parameters\n",
69 | "hyperparams = {\n",
70 | " 'num_layers': 1,\n",
71 | " 'hidden_dim': 256,\n",
72 | " 'num_heads': 2,\n",
73 | " 'feedforward_dim': 256,\n",
74 | " 'dropout': 0.1,\n",
75 | " 'vocab_size': 1000,\n",
76 | " 'embed_dim': 256,\n",
77 | " 'max_length': max_length,\n",
78 | " 'start_token': 0,\n",
79 | " 'end_token': 50,\n",
80 | " 'num_groups': 2,\n",
81 | "}\n",
82 | "\n",
83 | "# Initialize model\n",
84 | "model = LlaMA2(**hyperparams)\n",
85 | "rngs = jax.random.PRNGKey(0)\n",
86 | "rngs, dropout_rng = jax.random.split(rngs)\n",
87 | "params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']\n",
88 | "\n",
89 | "# Call as you would a Jax/Flax model\n",
90 | "outputs = model.apply({'params': params}, \n",
91 | " dummy_inputs, \n",
92 | " rngs={'dropout': dropout_rng})\n",
93 | "print(outputs.shape)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 6,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "Number of parameters: 974312\n",
106 | "Number of accelerators: 1\n",
107 | "Epoch 1, Train Loss: 7.56395959854126\n",
108 | "Epoch 1, Val Loss: 6.7949442863464355\n",
109 | "New best validation score achieved, saving model...\n",
110 | "Epoch 2, Train Loss: 6.494354248046875\n",
111 | "Epoch 2, Val Loss: 5.658466815948486\n",
112 | "New best validation score achieved, saving model...\n",
113 | "5.658467\n"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "# Training on data\n",
119 | "trainer = LlaMADataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')\n",
120 | "trainer.train(train_loader=dataloader, \n",
121 | " num_epochs=2, \n",
122 | " val_loader=dataloader) # Use actual validation data\n",
123 | "\n",
124 | "print(trainer.evaluate(dataloader))"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 7,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "[157 157 932 932 949 406 748 523 171 364]\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "# Generating from a start token\n",
142 | "start_tokens = jnp.array([[123, 456]])\n",
143 | "\n",
144 | "# Remember to load the trained parameters \n",
145 | "params = trainer.load_params('params.pkl')\n",
146 | "outputs = model.apply({'params': params},\n",
147 | " start_tokens,\n",
148 | " rngs={'dropout': jax.random.PRNGKey(2)}, \n",
149 | " method=model.generate)\n",
150 | "print(outputs)"
151 | ]
152 | }
153 | ],
154 | "metadata": {
155 | "kernelspec": {
156 | "display_name": "Python 3",
157 | "language": "python",
158 | "name": "python3"
159 | },
160 | "language_info": {
161 | "codemirror_mode": {
162 | "name": "ipython",
163 | "version": 3
164 | },
165 | "file_extension": ".py",
166 | "mimetype": "text/x-python",
167 | "name": "python",
168 | "nbconvert_exporter": "python",
169 | "pygments_lexer": "ipython3",
170 | "version": "3.10.0"
171 | }
172 | },
173 | "nbformat": 4,
174 | "nbformat_minor": 2
175 | }
176 |
--------------------------------------------------------------------------------
/docs/examples/whisper_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import jax\n",
10 | "import jax.numpy as jnp\n",
11 | "from nanodl import ArrayDataset, DataLoader\n",
12 | "from nanodl import Whisper, WhisperDataParallelTrainer\n",
13 | "\n",
14 | "# Dummy data parameters\n",
15 | "batch_size = 8\n",
16 | "max_length = 50\n",
17 | "embed_dim = 256 \n",
18 | "vocab_size = 1000 "
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "(8, 50, 256) (8, 50)\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "# Generate data: replace with actual tokenised/quantised data\n",
36 | "dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)\n",
37 | "dummy_inputs = jnp.ones((101, max_length, embed_dim))\n",
38 | "\n",
39 | "dataset = ArrayDataset(dummy_inputs, \n",
40 | " dummy_targets)\n",
41 | "\n",
42 | "dataloader = DataLoader(dataset, \n",
43 | " batch_size=batch_size, \n",
44 | " shuffle=True, \n",
45 | " drop_last=False)\n",
46 | "\n",
47 | "# How to loop through dataloader\n",
48 | "for batch in dataloader:\n",
49 | " x, y = batch\n",
50 | " print(x.shape, y.shape)\n",
51 | " break"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "(101, 50, 1000)\n"
64 | ]
65 | }
66 | ],
67 | "source": [
68 | "# model parameters\n",
69 | "hyperparams = {\n",
70 | " 'num_layers': 1,\n",
71 | " 'hidden_dim': 256,\n",
72 | " 'num_heads': 2,\n",
73 | " 'feedforward_dim': 256,\n",
74 | " 'dropout': 0.1,\n",
75 | " 'vocab_size': 1000,\n",
76 | " 'embed_dim': embed_dim,\n",
77 | " 'max_length': max_length,\n",
78 | " 'start_token': 0,\n",
79 | " 'end_token': 50,\n",
80 | "}\n",
81 | "\n",
82 | "# Initialize model\n",
83 | "model = Whisper(**hyperparams)\n",
84 | "rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}\n",
85 | "params = model.init(rngs, dummy_inputs, dummy_targets)['params']\n",
86 | "outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)\n",
87 | "print(outputs.shape)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 4,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "name": "stdout",
97 | "output_type": "stream",
98 | "text": [
99 | "Number of parameters: 1974760\n",
100 | "Number of accelerators: 1\n",
101 | "Epoch 1, Train Loss: 8.127946853637695\n",
102 | "Epoch 1, Val Loss: 7.081634521484375\n",
103 | "New best validation score achieved, saving model...\n",
104 | "Epoch 2, Train Loss: 6.22011137008667\n",
105 | "Epoch 2, Val Loss: 5.051723957061768\n",
106 | "New best validation score achieved, saving model...\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "# Training on your data\n",
112 | "trainer = WhisperDataParallelTrainer(model, \n",
113 | " dummy_inputs.shape, \n",
114 | " dummy_targets.shape, \n",
115 | " 'params.pkl')\n",
116 | "trainer.train(dataloader, 2, dataloader)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 5,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "[930 471 500 450 936 936 143 851 716 695 275 389 246 79 7 494 562 314\n",
129 | " 276 583 788 525 302 150 694 694 161 741 902 77 946 294 210 945 272 266\n",
130 | " 493 553 533 619 703 330 330 154 438 797 334 322 31 649]\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "# Sample inference\n",
136 | "params = trainer.load_params('params.pkl')\n",
137 | "\n",
138 | "# for more than one sample, use model.generate_batch\n",
139 | "transcripts = model.apply({'params': params}, \n",
140 | " dummy_inputs[:1], \n",
141 | " rngs=rngs, \n",
142 | " method=model.generate)\n",
143 | "\n",
144 | "print(transcripts)"
145 | ]
146 | }
147 | ],
148 | "metadata": {
149 | "kernelspec": {
150 | "display_name": "Python 3",
151 | "language": "python",
152 | "name": "python3"
153 | },
154 | "language_info": {
155 | "codemirror_mode": {
156 | "name": "ipython",
157 | "version": 3
158 | },
159 | "file_extension": ".py",
160 | "mimetype": "text/x-python",
161 | "name": "python",
162 | "nbconvert_exporter": "python",
163 | "pygments_lexer": "ipython3",
164 | "version": "3.10.0"
165 | }
166 | },
167 | "nbformat": 4,
168 | "nbformat_minor": 2
169 | }
170 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ## Overview
2 |
3 | Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:
4 |
5 | - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
6 | - An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
7 | - Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
8 | - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
9 | - Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
10 | - GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
11 | - Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
12 | - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
13 | - Each model is contained in a single file with no external dependencies, so the source code can also be easily used.
14 |
15 | Feedback on any of our discussion, issue and pull request threads is welcome! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at ndubuakuhenry@gmail.com.
16 |
17 | ## Contribution
18 |
19 | This is the first iteration of this project, so roughness is expected; contributions are therefore highly encouraged! Follow the recommended steps:
20 |
21 | - Raise the issue/discussion to get second opinions
22 | - Fork the repository
23 | - Create a branch
24 | - Make your changes without ruining the design patterns
25 | - Write tests for your changes if necessary
26 | - Install locally with `pip install -e .`
27 | - Run tests with `python -m unittest discover -s tests`
28 | - Then submit a pull request from your branch.
29 |
30 | Contributions can be made in various forms:
31 |
32 | - Writing documentation.
33 | - Fixing bugs.
34 | - Implementing papers.
35 | - Writing high-coverage tests.
36 | - Optimizing existing code.
37 | - Experimenting and submitting real-world examples to the examples section.
38 | - Reporting bugs.
39 | - Responding to reported issues.
40 |
41 | Coming features include:
42 | - Reinforcement Learning With Human Feedback (RLHF).
43 | - Tokenizers.
44 | - Code optimisations.
45 |
46 | To follow up or share thoughts, fill out [this form](https://forms.gle/vwveb9SKdPYywHx9A).
47 |
48 | ## Sponsorships
49 |
50 | The name "NanoDL" stands for Nano Deep Learning. Models are exploding in size, which gate-keeps
51 | experts and companies with limited resources from building flexible models without prohibitive costs.
52 | Following the success of the Phi models, the long-term goal is to build and train nano versions of all available models,
53 | while ensuring they compete with the original models in performance, with the total
54 | number of parameters not exceeding 1B. Trained weights will be made available via this library.
55 | Any form of sponsorship, funding, grants or contribution will help with training resources.
56 | You can sponsor via the provided button, or reach out via ndubuakuhenry@gmail.com.
--------------------------------------------------------------------------------
/docs/requirements.in:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocstrings[python]
3 | markdown-include
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with python 3.10
3 | # To update, run:
4 | #
5 | # pip-compile docs/requirements.in
6 | #
7 | click==8.1.3
8 | # via mkdocs
9 | ghp-import==2.1.0
10 | # via mkdocs
11 | griffe==0.22.0
12 | # via mkdocstrings-python
13 | importlib-metadata==4.12.0
14 | # via mkdocs
15 | jinja2==3.1.2
16 | # via
17 | # mkdocs
18 | # mkdocstrings
19 | markdown==3.3.7
20 | # via
21 | # markdown-include
22 | # mkdocs
23 | # mkdocs-autorefs
24 | # mkdocstrings
25 | # pymdown-extensions
26 | markdown-include==0.6.0
27 | # via -r docs/requirements.in
28 | markupsafe==2.1.1
29 | # via
30 | # jinja2
31 | # mkdocstrings
32 | mergedeep==1.3.4
33 | # via mkdocs
34 | mkdocs==1.3.0
35 | # via
36 | # -r docs/requirements.in
37 | # mkdocs-autorefs
38 | # mkdocstrings
39 | mkdocs-autorefs==0.4.1
40 | # via mkdocstrings
41 | mkdocstrings[python]==0.19.0
42 | # via
43 | # -r docs/requirements.in
44 | # mkdocstrings-python
45 | mkdocstrings-python==0.7.1
46 | # via mkdocstrings
47 | packaging==21.3
48 | # via mkdocs
49 | pymdown-extensions==9.5
50 | # via mkdocstrings
51 | pyparsing==3.0.9
52 | # via packaging
53 | python-dateutil==2.8.2
54 | # via ghp-import
55 | pyyaml==6.0
56 | # via
57 | # mkdocs
58 | # pyyaml-env-tag
59 | pyyaml-env-tag==0.1
60 | # via mkdocs
61 | six==1.16.0
62 | # via python-dateutil
63 | watchdog==2.1.9
64 | # via mkdocs
65 | zipp==3.8.0
66 | # via importlib-metadata
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 | Usage
2 | =====
3 |
4 | Installation
5 | ------------
6 |
7 | You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md),
8 | [FLAX](https://github.com/google/flax/blob/main/README.md), and
9 | [OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md)
10 | installations (with GPU support for running training; without it, you can only build models).
11 | Models can be designed and tested on CPUs, but the trainers are all distributed data-parallel, which requires 1 to N GPUs/TPUs. For a CPU-only version of JAX:
12 |
13 | ```
14 | pip install --upgrade pip # To support manylinux2010 wheels.
15 | pip install jax flax optax
16 | ```
17 |
18 | Then, install nanodl from PyPi:
19 |
20 | ```
21 | pip install nanodl
22 | ```
23 |
24 | Creating a GPT Model
25 | ----------------
26 |
27 | ```py
28 | import jax
29 | import jax.numpy as jnp
30 | from nanodl import ArrayDataset, DataLoader
31 | from nanodl import GPT4, GPTDataParallelTrainer
32 |
33 | # Generate dummy data
34 | batch_size = 8
35 | max_length = 10
36 |
37 | # Replace with actual tokenised data
38 | data = jnp.ones((101, max_length), dtype=jnp.int32)
39 |
40 | # Shift to create next-token prediction dataset
41 | dummy_inputs = data[:, :-1]
42 | dummy_targets = data[:, 1:]
43 |
44 | # Create dataset and dataloader
45 | dataset = ArrayDataset(dummy_inputs, dummy_targets)
46 | dataloader = DataLoader(dataset,
47 | batch_size=batch_size,
48 | shuffle=True,
49 | drop_last=False)
50 |
51 | # How to loop through dataloader
52 | for batch in dataloader:
53 | x, y = batch
54 | print(x.shape, y.shape)
55 | break
56 |
57 | # model parameters
58 | hyperparams = {
59 | 'num_layers': 1,
60 | 'hidden_dim': 256,
61 | 'num_heads': 2,
62 | 'feedforward_dim': 256,
63 | 'dropout': 0.1,
64 | 'vocab_size': 1000,
65 | 'embed_dim': 256,
66 | 'max_length': max_length,
67 | 'start_token': 0,
68 | 'end_token': 50,
69 | }
70 |
71 | # Initialize model
72 | model = GPT4(**hyperparams)
73 | rngs = jax.random.PRNGKey(0)
74 | rngs, dropout_rng = jax.random.split(rngs)
75 | params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']
76 |
77 | # Call as you would a Jax/Flax model
78 | outputs = model.apply({'params': params},
79 | dummy_inputs,
80 | rngs={'dropout': dropout_rng})
81 | print(outputs.shape)
82 |
83 | # Training on data
84 | trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
85 | trainer.train(train_loader=dataloader,
86 | num_epochs=2,
87 | val_loader=dataloader)
88 |
89 | print(trainer.evaluate(dataloader))
90 |
91 | # Generating from a start token
92 | start_tokens = jnp.array([[123, 456]])
93 |
94 | # Remember to load the trained parameters
95 | params = trainer.load_params('params.pkl')
96 | outputs = model.apply({'params': params},
97 | start_tokens,
98 | rngs={'dropout': jax.random.PRNGKey(2)},
99 | method=model.generate)
100 | print(outputs)
101 | ```
102 |
103 | Creating a Diffusion model
104 | ----------------
105 |
106 | ```py
107 | import jax
108 | import jax.numpy as jnp
109 | from nanodl import ArrayDataset, DataLoader
110 | from nanodl import DiffusionModel, DiffusionDataParallelTrainer
111 |
112 | image_size = 32
113 | block_depth = 2
114 | batch_size = 8
115 | widths = [32, 64, 128]
116 | key = jax.random.PRNGKey(0)
117 | input_shape = (101, image_size, image_size, 3)
118 | images = jax.random.normal(key, input_shape)
119 |
120 | # Use your own images
121 | dataset = ArrayDataset(images)
122 | dataloader = DataLoader(dataset,
123 | batch_size=batch_size,
124 | shuffle=True,
125 | drop_last=False)
126 |
127 | # Create diffusion model
128 | diffusion_model = DiffusionModel(image_size, widths, block_depth)
129 | params = diffusion_model.init(key, images)
130 | pred_noises, pred_images = diffusion_model.apply(params, images)
131 | print(pred_noises.shape, pred_images.shape)
132 |
133 | # Training on your data
134 | # Note: saved params are often different from training weights, use the saved params for generation
135 | trainer = DiffusionDataParallelTrainer(diffusion_model,
136 | input_shape=images.shape,
137 | weights_filename='params.pkl',
138 | learning_rate=1e-4)
139 | trainer.train(dataloader, 10, dataloader)
140 | print(trainer.evaluate(dataloader))
141 |
142 | # Generate some samples
143 | params = trainer.load_params('params.pkl')
144 | generated_images = diffusion_model.apply({'params': params},
145 | num_images=5,
146 | diffusion_steps=5,
147 | method=diffusion_model.generate)
148 | print(generated_images.shape)
149 | ```
150 |
151 | Creating a Whisper ASR (speech-to-text) model
152 | ----------------
153 |
154 | ```py
155 | import jax
156 | import jax.numpy as jnp
157 | from nanodl import ArrayDataset, DataLoader
158 | from nanodl import Whisper, WhisperDataParallelTrainer
159 |
160 | # Dummy data parameters
161 | batch_size = 8
162 | max_length = 50
163 | embed_dim = 256
164 | vocab_size = 1000
165 |
166 | # Generate data: replace with actual tokenised/quantised data
167 | dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
168 | dummy_inputs = jnp.ones((101, max_length, embed_dim))
169 |
170 | dataset = ArrayDataset(dummy_inputs,
171 | dummy_targets)
172 |
173 | dataloader = DataLoader(dataset,
174 | batch_size=batch_size,
175 | shuffle=True,
176 | drop_last=False)
177 |
178 | # How to loop through dataloader
179 | for batch in dataloader:
180 | x, y = batch
181 | print(x.shape, y.shape)
182 | break
183 |
184 | # model parameters
185 | hyperparams = {
186 | 'num_layers': 1,
187 | 'hidden_dim': 256,
188 | 'num_heads': 2,
189 | 'feedforward_dim': 256,
190 | 'dropout': 0.1,
191 | 'vocab_size': 1000,
192 | 'embed_dim': embed_dim,
193 | 'max_length': max_length,
194 | 'start_token': 0,
195 | 'end_token': 50,
196 | }
197 |
198 | # Initialize model
199 | model = Whisper(**hyperparams)
200 | rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
201 | params = model.init(rngs, dummy_inputs, dummy_targets)['params']
202 | outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
203 | print(outputs.shape)
204 |
205 | # Training on your data
206 | trainer = WhisperDataParallelTrainer(model,
207 | dummy_inputs.shape,
208 | dummy_targets.shape,
209 | 'params.pkl')
210 | trainer.train(dataloader, 2, dataloader)
211 |
212 | # Sample inference
213 | params = trainer.load_params('params.pkl')
214 |
215 | # for more than one sample, use model.generate_batch
216 | transcripts = model.apply({'params': params},
217 | dummy_inputs[:1],
218 | rngs=rngs,
219 | method=model.generate)
220 |
221 | print(transcripts)
222 | ```
223 |
224 | Creating an Accelerated PCA
225 | ----------------
226 |
227 | ```py
228 | import jax
229 | from nanodl import PCA
230 |
231 | data = jax.random.normal(jax.random.key(0), (1000, 10))
232 | pca = PCA(n_components=2)
233 | pca.fit(data)
234 | transformed_data = pca.transform(data)
235 | original_data = pca.inverse_transform(transformed_data)
236 | X_sampled = pca.sample(n_samples=1000, key=None)
237 | print(X_sampled.shape, original_data.shape, transformed_data.shape)
238 | ```
239 | GPU/TPU-accelerated versions of many scikit-learn-style models, such as NaiveBayesClassifier, LinearRegression and KMeans, are also available in NanoDL.
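
Beyond PCA, the other classical models are used in a similar way. The snippet below is a hypothetical sketch that assumes a scikit-learn-style interface; the constructor argument and the `fit`/`predict` methods shown here are assumptions, so check the API reference (`nanodl.KMeans`, `nanodl.LinearRegression`, `nanodl.NaiveBayesClassifier`) for the exact signatures.

```py
import jax
from nanodl import KMeans

# Dummy data; replace with your own feature matrix
data = jax.random.normal(jax.random.key(0), (1000, 10))

# Hypothetical usage: the argument name (k) and the fit/predict
# methods are assumptions, not the confirmed NanoDL API.
kmeans = KMeans(k=3)
kmeans.fit(data)
labels = kmeans.predict(data)
print(labels.shape)
```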
240 |
241 | Using an individual module
242 | ----------------
243 |
244 | Each constituent layer can be used in your own model.
245 |
246 | ```py
247 | from nanodl import GraphAttentionLayer  # also requires `import jax.numpy as jnp` and `from flax import linen as nn`
248 |
249 | class GAT(nn.Module):
250 | nfeat: int
251 | nhid: int
252 | nclass: int
253 | dropout_rate: float
254 | alpha: float
255 | nheads: int
256 |
257 | @nn.compact
258 | def __call__(self,
259 | x: jnp.ndarray,
260 | adj: jnp.ndarray,
261 | training: bool = False) -> jnp.ndarray:
262 | heads = [GraphAttentionLayer(self.nfeat,
263 | self.nhid,
264 | dropout_rate=self.dropout_rate,
265 | alpha=self.alpha, concat=True) for _ in range(self.nheads)]
266 |
267 | x = jnp.concatenate([head(x, adj, training) for head in heads], axis=1)
268 | x = nn.Dropout(rate=self.dropout_rate,
269 | deterministic=not training)(x)
270 |
271 | out_att = GraphAttentionLayer(self.nhid * self.nheads,
272 | self.nclass,
273 | dropout_rate=self.dropout_rate,
274 | alpha=self.alpha, concat=False)
275 |
276 | return out_att(x, adj, training)
277 | ```
278 |
279 | With this, you could for example create a transformer model with T5 Encoder and LlaMa2 Decoder!
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: NanoDL Documentation
2 | theme:
3 | name: readthedocs
4 | highlightjs: true
5 | plugins:
6 | - search
7 | - mkdocstrings:
8 | handlers:
9 | python:
10 | options:
11 | docstring_section_style: "list"
12 | members_order: "source"
13 | show_root_heading: true
14 | show_source: false
15 | show_signature_annotations: false
16 | selection:
17 | inheritance: true
18 | filters:
19 | - "!^_[^_]" # Exclude members starting with a single underscore
20 | rendering:
21 | show_category_heading: true
22 | heading_level: 3
23 | docstring_styles: ["google", "numpy", "restructuredtext", "plain", "markdown"]
24 | cross_references:
25 | use_short_names: true
26 | fail_on_missing_reference: false
27 | introspection:
28 | modules: ["nanodl"]
29 | markdown_extensions:
30 | - markdown_include.include:
31 | base_path: .
32 | - admonition
--------------------------------------------------------------------------------
/nanodl/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.2.5.dev1"
2 |
3 | from nanodl.__src.classical.bayes import NaiveBayesClassifier
4 | from nanodl.__src.classical.clustering import GaussianMixtureModel, KMeans
5 | from nanodl.__src.classical.dimensionality_reduction import PCA
6 | from nanodl.__src.classical.regression import (
7 | GaussianProcess,
8 | LinearRegression,
9 | LogisticRegression,
10 | )
11 | from nanodl.__src.experimental.gat import GAT, GraphAttentionLayer
12 | from nanodl.__src.models.attention import (
13 | GatedMultiHeadAttention,
14 | HierarchicalMultiHeadAttention,
15 | LocalMultiHeadAttention,
16 | MultiQueryAttention,
17 | RotaryMultiHeadAttention,
18 | )
19 | from nanodl.__src.models.clip import (
20 | CLIP,
21 | CLIPDataParallelTrainer,
22 | ImageEncoder,
23 | SelfMultiHeadAttention,
24 | TextEncoder,
25 | )
26 | from nanodl.__src.models.diffusion import (
27 | DiffusionDataParallelTrainer,
28 | DiffusionModel,
29 | UNet,
30 | UNetDownBlock,
31 | UNetResidualBlock,
32 | UNetUpBlock,
33 | )
34 | from nanodl.__src.models.gemma import (
35 | Gemma,
36 | GemmaDataParallelTrainer,
37 | GemmaDecoder,
38 | GemmaDecoderBlock,
39 | )
40 | from nanodl.__src.models.gpt import (
41 | GPT3,
42 | GPT4,
43 | GPT3Block,
44 | GPT3Decoder,
45 | GPT4Block,
46 | GPT4Decoder,
47 | GPTDataParallelTrainer,
48 | PositionWiseFFN,
49 | )
50 | from nanodl.__src.models.ijepa import IJEPA, IJEPADataParallelTrainer, IJEPADataSampler
51 | from nanodl.__src.models.lamda import (
52 | LaMDA,
53 | LaMDABlock,
54 | LaMDADataParallelTrainer,
55 | LaMDADecoder,
56 | RelativeMultiHeadAttention,
57 | )
58 | from nanodl.__src.models.llama import (
59 | GroupedRotaryMultiHeadAttention,
60 | Llama3,
61 | Llama3Decoder,
62 | Llama3DecoderBlock,
63 | LlamaDataParallelTrainer,
64 | RotaryPositionalEncoding,
65 | )
66 | from nanodl.__src.models.kan import (
67 | KANLinear,
68 | ChebyKANLinear,
69 | LegendreKANLinear,
70 | MonomialKANLinear,
71 | FourierKANLinear,
72 | HermiteKANLinear,
73 | )
74 | from nanodl.__src.models.mistral import (
75 | GroupedRotaryShiftedWindowMultiHeadAttention,
76 | Mistral,
77 | MistralDataParallelTrainer,
78 | MistralDecoder,
79 | MistralDecoderBlock,
80 | Mixtral,
81 | MixtralDecoder,
82 | MixtralDecoderBlock,
83 | )
84 | from nanodl.__src.models.mixer import (
85 | Mixer,
86 | MixerBlock,
87 | MixerDataParallelTrainer,
88 | MixerEncoder,
89 | )
90 | from nanodl.__src.models.reward import RewardDataParallelTrainer, RewardModel
91 | from nanodl.__src.models.t5 import (
92 | T5,
93 | T5DataParallelTrainer,
94 | T5Decoder,
95 | T5DecoderBlock,
96 | T5Encoder,
97 | T5EncoderBlock,
98 | )
99 | from nanodl.__src.models.transformer import (
100 | AddNorm,
101 | MultiHeadAttention,
102 | PositionalEncoding,
103 | PositionWiseFFN,
104 | TokenAndPositionEmbedding,
105 | Transformer,
106 | TransformerDataParallelTrainer,
107 | TransformerDecoderBlock,
108 | TransformerEncoder,
109 | )
110 | from nanodl.__src.models.vit import (
111 | PatchEmbedding,
112 | ViT,
113 | ViTBlock,
114 | ViTDataParallelTrainer,
115 | ViTEncoder,
116 | )
117 | from nanodl.__src.models.whisper import (
118 | Whisper,
119 | WhisperDataParallelTrainer,
120 | WhisperSpeechEncoder,
121 | WhisperSpeechEncoderBlock,
122 | )
123 | from nanodl.__src.utils.data import ArrayDataset, DataLoader, Dataset
124 | from nanodl.__src.utils.ml import (
125 | batch_cosine_similarities,
126 | batch_pearsonr,
127 | classification_scores,
128 | count_parameters,
129 | entropy,
130 | gini_impurity,
131 | hamming,
132 | jaccard,
133 | kl_divergence,
134 | mean_reciprocal_rank,
135 | zero_pad_sequences,
136 | )
137 | from nanodl.__src.utils.nlp import (
138 | bleu,
139 | cider_score,
140 | meteor,
141 | perplexity,
142 | rouge,
143 | word_error_rate,
144 | )
145 | from nanodl.__src.utils.random import *
146 | from nanodl.__src.utils.vision import (
147 | adjust_brightness,
148 | adjust_contrast,
149 | flip_image,
150 | gaussian_blur,
151 | normalize_images,
152 | random_crop,
153 | random_flip_image,
154 | sobel_edge_detection,
155 | )
156 |
157 | __all__ = [
158 | # Classical
159 | "NaiveBayesClassifier",
160 | "PCA",
161 | "KMeans",
162 | "GaussianMixtureModel",
163 | "LinearRegression",
164 | "LogisticRegression",
165 | "GaussianProcess",
166 | # Models
167 | "IJEPA",
168 | "IJEPADataParallelTrainer",
169 | "IJEPADataSampler",
170 | "Gemma",
171 | "GemmaDataParallelTrainer",
172 | "GemmaDecoder",
173 | "GemmaDecoderBlock",
174 | "GAT",
175 | "GraphAttentionLayer",
176 | "T5",
177 | "T5DataParallelTrainer",
178 | "T5Encoder",
179 | "T5Decoder",
180 | "T5EncoderBlock",
181 | "T5DecoderBlock",
182 | "ViT",
183 | "ViTDataParallelTrainer",
184 | "ViTBlock",
185 | "ViTEncoder",
186 | "PatchEmbedding",
187 | "CLIP",
188 | "CLIPDataParallelTrainer",
189 | "ImageEncoder",
190 | "TextEncoder",
191 | "SelfMultiHeadAttention",
192 | "LaMDA",
193 | "LaMDADataParallelTrainer",
194 | "LaMDABlock",
195 | "LaMDADecoder",
196 | "RelativeMultiHeadAttention",
197 | "Mixer",
198 | "MixerDataParallelTrainer",
199 | "MixerBlock",
200 | "MixerEncoder",
201 | "Llama3",
202 | "LlamaDataParallelTrainer",
203 | "RotaryPositionalEncoding",
204 | "Llama3Decoder",
205 | "Llama3DecoderBlock",
206 | "GroupedRotaryMultiHeadAttention",
207 | "GPT3",
208 | "GPT4",
209 | "GPTDataParallelTrainer",
210 | "GPT3Block",
211 | "GPT4Block",
212 | "GPT3Decoder",
213 | "GPT4Decoder",
214 | "PositionWiseFFN",
215 | "Mistral",
216 | "MistralDataParallelTrainer",
217 | "MistralDecoder",
218 | "MistralDecoderBlock",
219 | "GroupedRotaryShiftedWindowMultiHeadAttention",
220 | "Mixtral",
221 | "MixtralDecoder",
222 | "MixtralDecoderBlock",
223 | "Whisper",
224 | "WhisperDataParallelTrainer",
225 | "WhisperSpeechEncoder",
226 | "WhisperSpeechEncoderBlock",
227 | "RewardModel",
228 | "RewardDataParallelTrainer",
229 | "DiffusionModel",
230 | "DiffusionDataParallelTrainer",
231 | "UNet",
232 | "UNetDownBlock",
233 | "UNetUpBlock",
234 | "UNetResidualBlock",
235 | "Transformer",
236 | "TransformerDataParallelTrainer",
237 | "TransformerEncoder",
238 | "TransformerDecoderBlock",
239 | "PositionalEncoding",
240 | "PositionWiseFFN",
241 | "TokenAndPositionEmbedding",
242 | "MultiHeadAttention",
243 | "AddNorm",
244 | # Utilities
245 | "Dataset",
246 | "ArrayDataset",
247 | "DataLoader",
248 | "batch_cosine_similarities",
249 | "batch_pearsonr",
250 | "classification_scores",
251 | "count_parameters",
252 | "entropy",
253 | "gini_impurity",
254 | "hamming",
255 | "jaccard",
256 | "kl_divergence",
257 | "mean_reciprocal_rank",
258 | "zero_pad_sequences",
259 | "bleu",
260 | "cider_score",
261 | "meteor",
262 | "perplexity",
263 | "rouge",
264 | "word_error_rate",
265 | "adjust_brightness",
266 | "adjust_contrast",
267 | "flip_image",
268 | "gaussian_blur",
269 | "normalize_images",
270 | "random_crop",
271 | "random_flip_image",
272 | "sobel_edge_detection",
273 | "MultiQueryAttention",
274 | "LocalMultiHeadAttention",
275 | "HierarchicalMultiHeadAttention",
276 | "GatedMultiHeadAttention",
277 | "RotaryMultiHeadAttention",
278 | # Random
279 | "time_rng_key",
280 | "uniform",
281 | "normal",
282 | "bernoulli",
283 | "categorical",
284 | "randint",
285 | "permutation",
286 | "gumbel",
287 | "choice",
288 | "bits",
289 | "exponential",
290 | "triangular",
291 | "truncated_normal",
292 | "poisson",
293 | "geometric",
294 | "gamma",
295 | "chisquare",
296 | "KANLinear",
297 | "ChebyKANLinear",
298 | "LegendreKANLinear",
299 | "MonomialKANLinear",
300 | "FourierKANLinear",
301 | "HermiteKANLinear",
302 | ]
303 |
304 | import importlib
305 | import sys
306 |
307 |
308 | def check_library_installed(lib_name):
309 | try:
310 | return importlib.import_module(lib_name)
311 | except ImportError:
312 | raise ImportError(f"{lib_name} is not installed or improperly installed.")
313 |
314 |
315 | def test_flax(flax):
316 | model = flax.linen.Dense(features=10)
317 |
318 |
319 | def test_jax(jax):
320 | arr = jax.numpy.array([1, 2, 3])
321 | result = jax.numpy.sum(arr)
322 |
323 |
324 | def test_optax(optax):
325 | optimizer = optax.sgd(learning_rate=0.1)
326 |
327 |
328 | def test_einops(einops):
329 |     import numpy as np  # einops needs an actual array type; a plain Python list of ints is not supported
330 |     arr = einops.rearrange(np.ones((1, 2, 3)), "a b c -> b a c")
330 |
331 |
332 | def main():
333 | try:
334 | flax = check_library_installed("flax")
335 | jax = check_library_installed("jax")
336 | optax = check_library_installed("optax")
337 | einops = check_library_installed("einops")
338 |
339 | test_flax(flax)
340 | test_jax(jax)
341 |         test_optax(optax)
342 |         test_einops(einops)
342 |
343 | except ImportError as e:
344 | print(e)
345 | sys.exit(1)
346 | except Exception as e:
347 | print(f"An error occurred while verifying Jax/Flax/Optax installation: {e}")
348 | sys.exit(1)
349 |
350 |
351 | if __name__ == "__main__":
352 | main()
353 |
--------------------------------------------------------------------------------
/nanodl/__src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/__init__.py
--------------------------------------------------------------------------------
/nanodl/__src/classical/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/classical/__init__.py
--------------------------------------------------------------------------------
/nanodl/__src/classical/bayes.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 |
7 | def fit_naive_bayes(
8 | X: jnp.ndarray, y: jnp.ndarray, num_classes: int
9 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
10 | class_priors = jnp.zeros(num_classes)
11 | feature_probs = jnp.zeros((num_classes, X.shape[1]))
12 |
13 | for i in range(num_classes):
14 | class_mask = y == i
15 | class_count = jnp.sum(class_mask)
16 | class_priors = class_priors.at[i].set(class_count / X.shape[0])
17 | feature_count = jnp.sum(X[class_mask], axis=0)
18 | feature_probs = feature_probs.at[i].set(feature_count / class_count)
19 |
20 | return class_priors, feature_probs
21 |
22 |
23 | @jax.jit
24 | def predict_naive_bayes(
25 | X: jnp.ndarray, class_priors: jnp.ndarray, feature_probs: jnp.ndarray
26 | ) -> jnp.ndarray:
27 | # Calculate log probabilities for features
28 | log_feature_probs = jnp.log(feature_probs)
29 | log_feature_probs_neg = jnp.log(1 - feature_probs)
30 |
31 | # Expand dimensions for broadcasting
32 | expanded_log_feature_probs = log_feature_probs[:, None, :]
33 | expanded_log_feature_probs_neg = log_feature_probs_neg[:, None, :]
34 |
35 | # Calculate log probabilities for each sample and class
36 | log_probs = jnp.sum(
37 | expanded_log_feature_probs * X + expanded_log_feature_probs_neg * (1 - X),
38 | axis=2,
39 | )
40 | log_probs += jnp.log(class_priors)[:, None]
41 | return jnp.argmax(log_probs, axis=0)
42 |
43 |
44 | def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float:
45 | return jnp.mean(y_true == y_pred)
46 |
47 |
48 | class NaiveBayesClassifier:
49 | """
50 | Naive Bayes classifier using JAX.
51 |
52 | Example usage:
53 | ```
54 | classifier = NaiveBayesClassifier(num_classes=2)
55 | X = jnp.array([[0, 1], [1, 0], [1, 1]])
56 | y = jnp.array([0, 1, 0])
57 | classifier.fit(X, y)
58 | predictions = classifier.predict(X)
59 | acc = accuracy(y, predictions)
60 | print(f"Accuracy: {acc}")
61 | ```
62 |
63 | Attributes:
64 | num_classes (int): Number of classes.
65 | class_priors (jnp.ndarray): Class priors, shape (num_classes,).
66 | feature_probs (jnp.ndarray): Feature probabilities, shape (num_classes, num_features).
67 | """
68 |
69 | def __init__(self, num_classes: int):
70 | self.num_classes = num_classes
71 | self.class_priors = None
72 | self.feature_probs = None
73 |
74 | def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None:
75 | self.class_priors, self.feature_probs = fit_naive_bayes(X, y, self.num_classes)
76 |
77 | def predict(self, X: jnp.ndarray) -> jnp.ndarray:
78 | return predict_naive_bayes(X, self.class_priors, self.feature_probs)
79 |
--------------------------------------------------------------------------------
/nanodl/__src/classical/clustering.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 |
5 | class KMeans:
6 | """
7 | KMeans clustering using JAX for GPU/TPU acceleration.
8 |
9 | Attributes:
10 | k (int): Number of clusters.
11 | num_iters (int): Maximum number of iterations.
12 | random_seed (int): Random seed for initialization.
13 | centroids (Optional[jnp.ndarray]): Centroids of the clusters.
14 | clusters (Optional[jnp.ndarray]): Cluster assignments of the data points.
15 |
16 | Example usage:
17 | ```
18 | kmeans = KMeans(k=4)
19 | X, _ = make_blobs(n_samples=300, centers=4, n_features=2, random_state=0)
20 | kmeans.fit(X)
21 | clusters = kmeans.predict(X)
22 | print("Centroids:", kmeans.centroids)
23 | print("Cluster assignments:", clusters)
24 | ```
25 | """
26 |
27 | def __init__(self, k: int, num_iters: int = 100, random_seed: int = 0) -> None:
28 | self.k = k
29 | self.num_iters = num_iters
30 | self.random_seed = random_seed
31 | self.centroids = None
32 | self.clusters = None
33 |
34 | def initialize_centroids(self, X: jnp.ndarray) -> jnp.ndarray:
35 |
36 | indices = jnp.arange(X.shape[0])
37 | selected = jax.random.choice(
38 | jax.random.PRNGKey(self.random_seed),
39 | indices,
40 | shape=(self.k,),
41 | replace=False,
42 | )
43 | return X[selected]
44 |
45 | def assign_clusters(self, X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
46 |
47 | distances = jnp.sqrt(
48 | ((X[:, jnp.newaxis, :] - centroids[jnp.newaxis, :, :]) ** 2).sum(axis=2)
49 | )
50 | return jnp.argmin(distances, axis=1)
51 |
52 | def update_centroids(self, X: jnp.ndarray, clusters: jnp.ndarray) -> jnp.ndarray:
53 |
54 | return jnp.array([X[clusters == i].mean(axis=0) for i in range(self.k)])
55 |
56 | def fit(self, X: jnp.ndarray) -> None:
57 |
58 | self.centroids = self.initialize_centroids(X)
59 | for _ in range(self.num_iters):
60 | self.clusters = self.assign_clusters(X, self.centroids)
61 | new_centroids = self.update_centroids(X, self.clusters)
62 | if jnp.allclose(self.centroids, new_centroids):
63 | break
64 | self.centroids = new_centroids
65 |
66 | def predict(self, X: jnp.ndarray) -> jnp.ndarray:
67 | if self.centroids is None:
68 | raise ValueError("Model not yet trained. Call 'fit' with training data.")
69 | return self.assign_clusters(X, self.centroids)
70 |
71 |
72 | class GaussianMixtureModel:
73 | """
74 | Gaussian Mixture Model implemented in JAX.
75 |
76 | This class represents a Gaussian Mixture Model (GMM) for clustering and density estimation.
77 | It uses the Expectation-Maximization (EM) algorithm for fitting the model to data.
78 |
79 | Attributes:
80 | n_components (int): Number of mixture components.
81 | tol (float): Tolerance for convergence.
82 | max_iter (int): Maximum number of iterations for the EM algorithm.
83 | means (jnp.ndarray): Means of the Gaussian components.
84 | covariances (jnp.ndarray): Covariances of the Gaussian components.
85 | weights (jnp.ndarray): Weights of the Gaussian components.
86 | seed (int): Random seed for initialization.
87 |
88 | Example:
89 | ```
90 | >>> import jax.numpy as jnp
91 |     >>> from nanodl import GaussianMixtureModel
92 | >>> X = jnp.array([[1, 2], [1, 4], [1, 0],
93 | ... [10, 2], [10, 4], [10, 0]])
94 |     >>> gmm = GaussianMixtureModel(n_components=2, seed=42)
95 | >>> gmm.fit(X)
96 | >>> print(gmm.means)
97 | >>> labels = gmm.predict(X)
98 | >>> print(labels)
99 | ```
100 | """
101 |
102 | def __init__(
103 | self, n_components: int, tol: float = 1e-3, max_iter: int = 100, seed: int = 0
104 | ) -> None:
105 | self.n_components = n_components
106 | self.tol = tol
107 | self.max_iter = max_iter
108 | self.means = None
109 | self.covariances = None
110 | self.weights = None
111 | self.seed = seed
112 |
113 | def fit(self, X: jnp.ndarray) -> None:
114 | _, n_features = X.shape
115 | rng = jax.random.PRNGKey(self.seed)
116 |
117 | self.means = jax.random.normal(rng, (self.n_components, n_features))
118 | self.covariances = jnp.array([jnp.eye(n_features)] * self.n_components)
119 | self.weights = jnp.ones(self.n_components) / self.n_components
120 |
121 | log_likelihood = 0
122 | for _ in range(self.max_iter):
123 | responsibilities = self._e_step(X)
124 | self._m_step(X, responsibilities)
125 |
126 | new_log_likelihood = self._compute_log_likelihood(X)
127 | if jnp.abs(new_log_likelihood - log_likelihood) < self.tol:
128 | break
129 | log_likelihood = new_log_likelihood
130 |
131 | def _e_step(self, X: jnp.ndarray) -> jnp.ndarray:
132 | responsibilities = jnp.zeros((X.shape[0], self.n_components))
133 | for k in range(self.n_components):
134 | responsibilities = responsibilities.at[:, k].set(
135 | self.weights[k]
136 | * self._multivariate_gaussian(X, self.means[k], self.covariances[k])
137 | )
138 | responsibilities /= responsibilities.sum(axis=1, keepdims=True)
139 | return responsibilities
140 |
141 | def _m_step(self, X: jnp.ndarray, responsibilities: jnp.ndarray) -> None:
142 | n_samples = X.shape[0]
143 | for k in range(self.n_components):
144 | Nk = responsibilities[:, k].sum()
145 | self.means = self.means.at[k].set(
146 | (1 / Nk) * jnp.dot(responsibilities[:, k], X)
147 | )
148 | diff = X - self.means[k]
149 | self.covariances = self.covariances.at[k].set(
150 | (1 / Nk) * jnp.dot(responsibilities[:, k] * diff.T, diff)
151 | )
152 | self.weights = self.weights.at[k].set(Nk / n_samples)
153 |
154 | def _multivariate_gaussian(
155 | self, X: jnp.ndarray, mean: jnp.ndarray, cov: jnp.ndarray
156 | ) -> jnp.ndarray:
157 | n = X.shape[1]
158 | diff = X - mean
159 | return jnp.exp(
160 | -0.5 * jnp.sum(jnp.dot(diff, jnp.linalg.inv(cov)) * diff, axis=1)
161 | ) / (jnp.sqrt((2 * jnp.pi) ** n * jnp.linalg.det(cov)))
162 |
163 | def _compute_log_likelihood(self, X: jnp.ndarray) -> float:
164 | log_likelihood = 0
165 | for k in range(self.n_components):
166 | log_likelihood += jnp.sum(
167 | jnp.log(
168 | self.weights[k]
169 | * self._multivariate_gaussian(X, self.means[k], self.covariances[k])
170 | )
171 | )
172 | return log_likelihood
173 |
174 | def predict(self, X: jnp.ndarray) -> jnp.ndarray:
175 | responsibilities = self._e_step(X)
176 | return jnp.argmax(responsibilities, axis=1)
177 |
--------------------------------------------------------------------------------
/nanodl/__src/classical/dimensionality_reduction.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 |
7 | class PCA:
8 | """
9 | A class for performing Principal Component Analysis (PCA) on data.
10 |
11 | Attributes:
12 | n_components (int): Number of principal components to retain.
13 | components (ndarray): Principal components learned from the data.
14 | mean (ndarray): Mean of the data used for centering.
15 |
16 | Methods:
17 | fit(X):
18 | Fit the PCA model on the input data.
19 |
20 | transform(X):
21 | Transform the input data into the PCA space.
22 |
23 | inverse_transform(X_transformed):
24 | Inverse transform PCA-transformed data back to the original space.
25 |
26 | sample(n_samples=1, key=None):
27 | Generate synthetic data samples from the learned PCA distribution.
28 |
29 | Example Usage:
30 | # Create an instance of the PCA class
31 | data = jax.random.normal(jax.random.key(0), (1000, 10))
32 | pca = PCA(n_components=2)
33 | pca.fit(data)
34 | transformed_data = pca.transform(data)
35 | original_data = pca.inverse_transform(transformed_data)
36 | X_sampled = pca.sample(n_samples=1000, key=None)
37 | print(X_sampled.shape, original_data.shape, transformed_data.shape)
38 | """
39 |
40 | def __init__(self, n_components: int):
41 |
42 | self.n_components = n_components
43 | self.components = None
44 | self.mean = None
45 |
46 | def fit(self, X: jnp.ndarray) -> None:
47 |
48 | self.mean = jnp.mean(X, axis=0)
49 | X_centered = X - self.mean
50 | cov_matrix = jnp.cov(X_centered, rowvar=False)
51 | eigvals, eigvecs = jnp.linalg.eigh(cov_matrix)
52 | sorted_indices = jnp.argsort(eigvals)[::-1]
53 | sorted_eigvecs = eigvecs[:, sorted_indices]
54 | self.components = sorted_eigvecs[:, : self.n_components]
55 |
56 | def transform(self, X: jnp.ndarray) -> jnp.ndarray:
57 | X_centered = X - self.mean
58 | return jnp.dot(X_centered, self.components)
59 |
60 | def inverse_transform(self, X_transformed: jnp.ndarray) -> jnp.ndarray:
61 | return jnp.dot(X_transformed, self.components.T) + self.mean
62 |
63 | def sample(
64 | self, n_samples: int = 1, key: Optional[jnp.ndarray] = None
65 | ) -> jnp.ndarray:
66 |
67 | if key is None:
68 | key = jax.random.PRNGKey(0)
69 |
70 | z = jax.random.normal(key, (n_samples, self.n_components))
71 | X_sampled = self.inverse_transform(z)
72 | return X_sampled
73 |
--------------------------------------------------------------------------------
/nanodl/__src/classical/dsp.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax.numpy as jnp
4 | from jax import random
5 |
6 |
7 | def fastica(
8 |     X: jnp.ndarray, n_components: int, max_iter: int = 1000, tol: float = 1e-4
9 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
10 | """
11 | Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm.
12 |
13 | Parameters:
14 | X : jax.numpy.ndarray
15 | The input data matrix, where each row represents a data point, and each column represents a different signal.
16 | The input data should be a 2D jax.numpy array with shape (n_samples, n_features).
17 | n_components : int
18 | The number of independent components to extract. This should be less than or equal to the number of features in the input data.
19 | max_iter : int, optional
20 | The maximum number of iterations for the optimization process. The default value is 1000 iterations.
21 | tol : float, optional
22 | The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the
23 | unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4.
24 |
25 | Returns:
26 | S : jax.numpy.ndarray
27 | The separated independent components. This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents
28 | a different independent component, and each column represents a data point.
29 | W : jax.numpy.ndarray
30 | The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the
31 | mixing matrix. It is used to transform the input data back into the independent components.
32 | whitening_matrix : jax.numpy.ndarray
33 | The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate
34 | the input data and make its covariance matrix the identity matrix.
35 |
36 | Description:
37 | The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input
38 | data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find
39 | the optimal unmixing matrix that maximizes the independence of the source signals.
40 |
41 | The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied
42 | to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the
43 | maximum number of iterations (`max_iter`) is reached.
44 |
45 | Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components.
46 |
47 | Example usage:
48 | # Set random seed
49 | jax.random.PRNGKey(42)
50 |
51 | # Generate synthetic source signals
52 | n_samples = 2000
53 | time = jnp.linspace(0, 8, n_samples)
54 | s1 = jnp.sin(2 * time)
55 | s2 = jnp.sign(jnp.sin(3 * time))
56 |
57 | # Combine the sources with a mixing matrix
58 | A = jnp.array([[1, 1], [0.5, 2]])
59 | X = jnp.dot(A, jnp.array([s1, s2]))
60 |
61 | # Perform ICA
62 | n_components = 2
63 | S, W, whitening_matrix = fastica(X.T, n_components)
64 |
65 | # Plot the results
66 | plt.figure(figsize=(12, 8))
67 |
68 | plt.subplot(3, 1, 1)
69 | plt.title('Original Source Signals')
70 | plt.plot(time, s1, label='Source 1 (Sine Wave)')
71 | plt.plot(time, s2, label='Source 2 (Square Wave)')
72 | plt.legend()
73 |
74 | plt.subplot(3, 1, 2)
75 | plt.title('Mixed Signals')
76 | plt.plot(time, X[0], label='Mixed Signal 1')
77 | plt.plot(time, X[1], label='Mixed Signal 2')
78 | plt.legend()
79 |
80 | plt.subplot(3, 1, 3)
81 | plt.title('Separated Signals (Using ICA)')
82 | plt.plot(time, S[0], label='Separated Signal 1')
83 | plt.plot(time, S[1], label='Separated Signal 2')
84 | plt.legend()
85 |
86 | plt.tight_layout()
87 | plt.show()
88 | """
89 | # Calculate the covariance matrix and perform eigenvalue decomposition
90 | cov_matrix = jnp.cov(X, rowvar=False)
91 | eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix)
92 |
93 | # Sort the eigenvalues and eigenvectors
94 | idx = jnp.argsort(eigenvalues)[::-1]
95 | eigenvalues = eigenvalues[idx]
96 | eigenvectors = eigenvectors[:, idx]
97 |
98 | # Create the whitening matrix
99 | D = jnp.diag(1.0 / jnp.sqrt(eigenvalues))
100 | whitening_matrix = jnp.dot(eigenvectors, D)
101 | X_whitened = jnp.dot(X, whitening_matrix)
102 |
103 | # Initialize unmixing matrix with random values
104 | rng = random.PRNGKey(0) # Set a seed for reproducibility
105 | W = random.normal(rng, (n_components, n_components))
106 |
107 | # Perform FastICA algorithm
108 | for _ in range(max_iter):
109 | WX = jnp.dot(X_whitened, W.T)
110 | g = jnp.tanh(WX)
111 | g_prime = 1 - g**2
112 | W_new = (jnp.dot(X_whitened.T, g) / X.shape[0]) - jnp.diag(
113 | g_prime.mean(axis=0)
114 | ).dot(W)
115 |
116 | # Orthogonalize the unmixing matrix
117 | W_new, _ = jnp.linalg.qr(W_new)
118 |
119 | # Check for convergence
120 | if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol:
121 | W = W_new
122 | break
123 |
124 | W = W_new
125 |
126 | # Calculate the separated independent components
127 | S = jnp.dot(W, X_whitened.T)
128 |
129 | return S, W, whitening_matrix
130 |
--------------------------------------------------------------------------------
/nanodl/__src/classical/regression.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 |
7 | class LinearRegression:
8 | """
9 | Linear Regression model implemented using JAX.
10 |
11 | Parameters:
12 | - input_dim (int): Dimension of the input feature.
13 | - output_dim (int): Dimension of the output target.
14 |
15 | Methods:
16 | - linear_regression(params, x): Linear regression prediction function.
17 | - loss(params, x, y): Mean squared error loss function.
18 | - fit(x_data, y_data, learning_rate=0.1, num_epochs=100): Training function.
19 | - get_params(): Get the learned weights and bias.
20 |
21 | Example usage:
22 | ```
23 | num_samples = 100
24 | input_dim = 1
25 | output_dim = 1
26 |
27 |     x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
28 | y_data = jnp.dot(x_data, jnp.array([[2.0]])) - jnp.array([[-1.0]])
29 |
30 | lr_model = LinearRegression(input_dim, output_dim)
31 | lr_model.fit(x_data, y_data)
32 |
33 | learned_weights, learned_bias = lr_model.get_params()
34 | print("Learned Weights:", learned_weights)
35 | print("Learned Bias:", learned_bias)
36 | ```
37 | """
38 |
39 | def __init__(self, input_dim, output_dim):
40 | self.input_dim = input_dim
41 | self.output_dim = output_dim
42 | self.key = jax.random.PRNGKey(0)
43 | self.params = (jnp.zeros((input_dim, output_dim)), jnp.zeros((output_dim,)))
44 |
45 | def linear_regression(self, params, x):
46 | weights, bias = params
47 | return jnp.dot(x, weights) + bias
48 |
49 | def loss(self, params, x, y):
50 | predictions = self.linear_regression(params, x)
51 | return jnp.mean((predictions - y) ** 2)
52 |
53 | def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100):
54 | grad_loss = jax.grad(self.loss)
55 | for epoch in range(num_epochs):
56 | grads = grad_loss(self.params, x_data, y_data)
57 | weights_grad, bias_grad = grads
58 | weights, bias = self.params
59 | weights -= learning_rate * weights_grad
60 | bias -= learning_rate * bias_grad
61 | self.params = (weights, bias)
62 | epoch_loss = self.loss(self.params, x_data, y_data)
63 | print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
64 |
65 | print("Training completed.")
66 |
67 | def get_params(self):
68 | return self.params
69 |
70 |
71 | class LogisticRegression:
72 | """
73 | Logistic Regression model implemented using JAX.
74 |
75 | Parameters:
76 | - input_dim (int): Dimension of the input feature.
77 |
78 | Methods:
79 | - sigmoid(x): Sigmoid activation function.
80 | - logistic_regression(params, x): Logistic regression prediction function.
81 | - loss(params, x, y): Binary cross-entropy loss function.
82 | - fit(x_data, y_data, learning_rate=0.1, num_epochs=100): Training function.
83 | - predict(x_data): Predict probabilities using the trained model.
84 |
85 | Example usage:
86 | ```
87 | num_samples = 100
88 | input_dim = 2
89 |
90 |     x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
91 | logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1
92 | y_data = (logits > 0).astype(jnp.float32)
93 |
94 | lr_model = LogisticRegression(input_dim)
95 | lr_model.fit(x_data, y_data)
96 |
97 |     test_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
98 | predictions = lr_model.predict(test_data)
99 | print("Predictions:", predictions)
100 | ```
101 | """
102 |
103 | def __init__(self, input_dim):
104 | self.input_dim = input_dim
105 | self.key = jax.random.PRNGKey(0)
106 | self.params = (jnp.zeros((input_dim,)), jnp.zeros(()))
107 |
108 | def sigmoid(self, x):
109 | return 1.0 / (1.0 + jnp.exp(-x))
110 |
111 | def logistic_regression(self, params, x):
112 | weights, bias = params
113 | return self.sigmoid(jnp.dot(x, weights) + bias)
114 |
115 | def loss(self, params, x, y):
116 | predictions = self.logistic_regression(params, x)
117 | return -jnp.mean(y * jnp.log(predictions) + (1 - y) * jnp.log(1 - predictions))
118 |
119 | def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100):
120 | grad_loss = jax.grad(self.loss)
121 | for epoch in range(num_epochs):
122 | grads = grad_loss(self.params, x_data, y_data)
123 | weights_grad, bias_grad = grads
124 | weights, bias = self.params
125 | weights -= learning_rate * weights_grad
126 | bias -= learning_rate * bias_grad
127 | self.params = (weights, bias)
128 | epoch_loss = self.loss(self.params, x_data, y_data)
129 | print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
130 |
131 | print("Training completed.")
132 |
133 | def predict(self, x_data):
134 | return self.logistic_regression(self.params, x_data)
135 |
136 |
137 | class GaussianProcess:
138 | """
139 | A basic implementation of Gaussian Process regression using JAX.
140 |
141 | Attributes:
142 | kernel (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The kernel function to measure similarity between data points.
143 | noise (float): Measurement noise added to the diagonal of the kernel matrix.
144 | X (jnp.ndarray): Training inputs.
145 | y (jnp.ndarray): Training outputs.
146 | K (jnp.ndarray): Kernel matrix incorporating training inputs and noise.
147 |
148 | Methods:
149 | fit(X, y):
150 | Fits the Gaussian Process model to the training data.
151 |
152 | predict(X_new):
153 | Makes predictions for new input points using the trained model.
154 |
155 | Example Usage:
156 | # Define a kernel function, e.g., Radial Basis Function (RBF) kernel
157 | def rbf_kernel(x1, x2, length_scale=1.0):
158 | diff = x1[:, None] - x2
159 | return jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1) / length_scale**2)
160 |
161 | # Create an instance of the GaussianProcess class
162 | gp = GaussianProcess(kernel=rbf_kernel, noise=1e-3)
163 |
164 | # Fit the model on the training data
165 | gp.fit(X_train, y_train)
166 |
167 | # Make predictions on new data
168 | mean, covariance = gp.predict(X_new)
169 | """
170 |
171 | def __init__(
172 | self,
173 | kernel: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
174 | noise: float = 1e-3,
175 | ):
176 |
177 | self.kernel = kernel
178 | self.noise = noise
179 | self.X = None
180 | self.y = None
181 | self.K = None
182 |
183 | def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None:
184 |
185 | self.X = X
186 | self.y = y
187 | self.K = self.kernel(self.X, self.X) + jnp.eye(len(X)) * self.noise
188 |
189 | def predict(self, X_new: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
190 | K_inv = jnp.linalg.inv(self.K)
191 | K_s = self.kernel(self.X, X_new)
192 | K_ss = self.kernel(X_new, X_new)
193 |
194 | mu_s = jnp.dot(K_s.T, jnp.dot(K_inv, self.y))
195 | cov_s = K_ss - jnp.dot(K_s.T, jnp.dot(K_inv, K_s))
196 | return mu_s, cov_s
197 |
--------------------------------------------------------------------------------
/nanodl/__src/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/experimental/__init__.py
--------------------------------------------------------------------------------
/nanodl/__src/experimental/bitlinear.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from flax import linen as nn
4 |
5 |
6 | class BitLinear(nn.Module):
7 | """
8 | Implements a linear transformation layer with quantization for both activations and weights,
9 | optimized for low-bit inference. The layer is designed to operate in two modes: training and inference.
10 | During training, the activations and weights are quantized using separate quantization functions,
11 | aiming to simulate low-bit operations and reduce the quantization error. For inference, a more
12 | aggressive quantization scheme is applied to both activations and weights, potentially different
13 | from the training quantization, to maximize performance and efficiency on low-bit hardware.
14 |
15 | Attributes:
16 | output_features (int): The number of output features.
17 | kernel_init (callable): A function to initialize the weights. Default is LeCun normal initializer.
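18 | 
19 |     Example usage (a minimal sketch; shapes are illustrative and standard
20 |     Flax init/apply usage is assumed):
21 |     ```
22 |     import jax
23 |     import jax.numpy as jnp
24 | 
25 |     layer = BitLinear(output_features=16)
26 |     x = jnp.ones((2, 8))
27 |     params = layer.init(jax.random.PRNGKey(0), x)
28 |     y = layer.apply(params, x)  # default training=False: quantised inference path
29 |     print(y.shape)  # (2, 16)
30 |     ```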
18 | """
19 |
20 | output_features: int
21 | kernel_init: callable = nn.initializers.lecun_normal()
22 |
23 | @nn.compact
24 | def __call__(self, x, training=False):
25 | w = self.param("kernel", self.kernel_init, (x.shape[-1], self.output_features))
26 |
27 | if not training:
28 | x_quant, x_scale = self.fused_activation_norm_quant(x)
29 |
30 |             # HELP: How can this be run externally on the params dict, all at once, for efficiency?
31 |             # Quantising weights all over again on each call is repeated work.
32 |             # This could be done on the params dict using jax tree utils,
33 |             # although the weight scale for quantisation still needs to be available at inference.
34 |             # It is easy to bypass this by passing the weight scale during a call.
35 |             # This will be a module in various transformer models in my project (NanoDL).
36 |             # Is there a way to achieve this without complicating my existing codebase?
37 | w, w_scale = self.inference_weight_quant(w)
38 |
39 | return self.inference_lowbit_matmul(x_quant, w) / w_scale / x_scale
40 |
41 | x_norm = self.rmsnorm(x)
42 | x_quant = x_norm + jax.lax.stop_gradient(self.activation_quant(x_norm) - x_norm)
43 | w_quant = w + jax.lax.stop_gradient(self.weight_quant(w) - w)
44 | return jnp.dot(x_quant, w_quant)
45 |
46 | def rmsnorm(self, x):
47 | return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-5)
48 |
49 | def activation_quant(self, x):
50 | scale = 127.0 / jnp.max(jnp.abs(x), axis=-1, keepdims=True).clip(min=1e-5)
51 | y = jnp.round(x * scale).clip(-128, 127) / scale
52 | return y
53 |
54 | def weight_quant(self, w):
55 | scale = 1.0 / jnp.mean(jnp.abs(w)).clip(min=1e-5)
56 | u = jnp.round(w * scale).clip(-1, 1) / scale
57 | return u
58 |
59 | def fused_activation_norm_quant(self, x):
60 | x_norm = self.rmsnorm(x)
61 | scale = 127.0 / jnp.max(jnp.abs(x_norm), axis=-1, keepdims=True).clip(min=1e-5)
62 | x_quant = jnp.round(x_norm * scale).clip(-128, 127) / scale
63 | return x_quant, scale
64 |
65 | def inference_weight_quant(self, w):
66 | scale = jnp.abs(w).mean().clip(min=1e-5)
67 | u = jnp.sign(w - w.mean()) * scale
68 | return u, scale
69 |
70 |     # HELP: how to implement a low-bit matmul kernel for efficiency that can be integrated into a Flax model?
71 | def inference_lowbit_matmul(self, x, w):
72 | return jnp.dot(x, w)
73 |
--------------------------------------------------------------------------------
/nanodl/__src/experimental/gat.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from flax import linen as nn
4 |
5 |
6 | class GraphAttentionLayer(nn.Module):
7 | """
8 | A single graph attention layer as part of a Graph Attention Network (GAT).
9 |
10 | This layer applies a self-attention mechanism on the nodes of a graph. Each node's features are transformed through a learned linear transformation, and attention coefficients are computed to determine the importance of every other node's features. This allows the model to dynamically adjust which nodes contribute most to the next layer's input for each node.
11 |
12 | Attributes:
13 | in_features (int): Number of input features per node.
14 | out_features (int): Number of output features per node.
15 | dropout_rate (float): Dropout rate applied to features and attention coefficients for regularization.
16 | alpha (float): Negative slope coefficient for the LeakyReLU activation function used in computing attention scores.
17 | concat (bool, optional): Whether to concatenate the output of attention heads in a multi-head attention mechanism. Default is True.
18 |
19 | Methods:
20 | __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool) -> jnp.ndarray:
21 | Forward pass of the graph attention layer.
22 |
23 | Args:
24 | x (jnp.ndarray): The input node features, shape (N, in_features), where N is the number of nodes.
25 | adj (jnp.ndarray): The adjacency matrix of the graph, shape (N, N), indicating node connections.
26 | training (bool): Whether the layer is being used in training mode. Affects dropout behavior.
27 |
28 | Returns:
29 | jnp.ndarray: The output node features after the attention mechanism. If `concat` is True, applies a non-linearity (LeakyReLU); otherwise, returns the linear combination of features directly. Shape is (N, out_features).
30 | """
31 |
32 | in_features: int
33 | out_features: int
34 | dropout_rate: float
35 | alpha: float
36 | concat: bool = True
37 |
38 | @nn.compact
39 | def __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool) -> jnp.ndarray:
40 |
41 | W = self.param(
42 | "W",
43 | jax.nn.initializers.glorot_uniform(),
44 | (self.in_features, self.out_features),
45 | )
46 |
47 | a = self.param(
48 | "a", jax.nn.initializers.glorot_uniform(), (2 * self.out_features, 1)
49 | )
50 |
51 | h = jnp.dot(x, W)
52 | h = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(h)
53 |
54 | N = h.shape[0]
55 | a_input = jnp.concatenate(
56 | [h[:, None, :].repeat(N, axis=1), h[None, :, :].repeat(N, axis=0)], axis=2
57 | )
58 |
59 | e = nn.leaky_relu(jnp.dot(a_input, a).squeeze(-1), negative_slope=self.alpha)
60 |
61 | zero_vec = -9e15 * jnp.ones_like(e)
62 | attention = jnp.where(adj > 0, e, zero_vec)
63 | attention = nn.softmax(attention, axis=1)
64 |
65 | attention = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(
66 | attention
67 | )
68 |
69 | h_prime = jnp.matmul(attention, h)
70 |
71 | if self.concat:
72 | return nn.leaky_relu(h_prime)
73 | else:
74 | return h_prime
75 |
76 |
77 | class GAT(nn.Module):
78 | """
79 | Graph Attention Networks (GATs) are a type of neural network designed for graph-structured data.
80 | The key feature of GATs is the use of attention mechanisms to weigh the importance of nodes' neighbors.
81 | This allows GATs to focus on the most relevant parts of the graph structure when learning node representations.
82 | In GATs, each node aggregates information from its neighbors, but not all neighbors contribute equally.
83 | The attention mechanism computes weights that determine the importance of each neighbor's features to the target node.
84 | These weights are learned during training and are based on the features of the nodes involved.
85 | GATs can handle graphs with varying sizes and connectivity patterns, making them suitable for a wide range of applications,
86 | including social network analysis, recommendation systems, and molecular structure analysis.
87 |
88 | Example usage:
89 | ```
90 | import jax
91 | import jax.numpy as jnp
92 | from nanodl import ArrayDataset, DataLoader
93 | from nanodl import GAT
94 |
95 | # Generate dummy data
96 | batch_size = 8
97 | max_length = 10
98 | nclass = 3
99 |
100 | # Replace with actual tokenised data
101 | # Generate a random key for Jax
102 | key = jax.random.PRNGKey(0)
103 | num_nodes = 10
104 | num_features = 5
105 | x = jax.random.normal(key, (num_nodes, num_features)) # Features for each node
106 | adj = jax.random.bernoulli(key, 0.3, (num_nodes, num_nodes)) # Random adjacency matrix
107 |
108 | # Initialize the GAT model
109 | model = GAT(nfeat=num_features,
110 | nhid=8,
111 | nclass=nclass,
112 | dropout_rate=0.5,
113 | alpha=0.2,
114 | nheads=3)
115 |
116 | # Initialize the model parameters
117 | params = model.init(key, x, adj)
118 | output = model.apply(params, x, adj)
119 | print("Output shape:", output.shape)
120 | ```
121 |
122 | Attributes:
123 | nfeat (int): Number of features for each node in the input graph.
124 | nhid (int): Number of hidden units in each graph attention layer.
125 | nclass (int): Number of classes for the node classification task.
126 | dropout_rate (float): Dropout rate for regularization. Applied to the inputs of each graph attention layer and the final output.
127 | alpha (float): LeakyReLU angle of negative slope used in the attention mechanism.
128 | nheads (int): Number of attention heads. Each head computes a separate attention mechanism over the input, and their results are concatenated.
129 |
130 | Methods:
131 | __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool = False) -> jnp.ndarray:
132 | Forward pass of the GAT model.
133 |
134 | Args:
135 | x (jnp.ndarray): Node features matrix with shape (N, nfeat), where N is the number of nodes in the graph.
136 | adj (jnp.ndarray): Adjacency matrix of the graph with shape (N, N). It should represent the graph structure.
137 | training (bool, optional): Flag to indicate whether the model is being used for training. Affects dropout behavior. Defaults to False.
138 |
139 | Returns:
140 | jnp.ndarray: The output node features after passing through the GAT model. Shape is (N, nclass), representing the class scores for each node.
141 | """
142 |
143 | nfeat: int
144 | nhid: int
145 | nclass: int
146 | dropout_rate: float
147 | alpha: float
148 | nheads: int
149 |
150 | @nn.compact
151 | def __call__(
152 | self, x: jnp.ndarray, adj: jnp.ndarray, training: bool = False
153 | ) -> jnp.ndarray:
154 |
155 | heads = [
156 | GraphAttentionLayer(
157 | self.nfeat,
158 | self.nhid,
159 | dropout_rate=self.dropout_rate,
160 | alpha=self.alpha,
161 | concat=True,
162 | )
163 | for _ in range(self.nheads)
164 | ]
165 |
166 | x = jnp.concatenate([head(x, adj, training) for head in heads], axis=1)
167 |
168 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
169 |
170 | out_att = GraphAttentionLayer(
171 | self.nhid * self.nheads,
172 | self.nclass,
173 | dropout_rate=self.dropout_rate,
174 | alpha=self.alpha,
175 | concat=False,
176 | )
177 |
178 | return out_att(x, adj, training)
179 |
--------------------------------------------------------------------------------
/nanodl/__src/experimental/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Optional
3 |
4 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
5 |
6 |
7 | class Tokenizer:
8 | """
9 | A tokenizer class that utilizes SentencePiece to encode and decode text.
10 |
11 | This class can be initialized with either an existing SentencePiece model
12 | or a dataset to train a new model. It provides methods to encode a string
13 | to a list of token ids and decode a list of token ids back to a string.
14 |
15 | Attributes:
16 | sp_model (SentencePieceProcessor): The SentencePiece processor.
17 | n_words (int): Number of words in the vocabulary.
18 | bos_id (int): Token id for the beginning of a sentence.
19 | eos_id (int): Token id for the end of a sentence.
20 | pad_id (int): Token id for padding.
21 |
22 | Example usage:
23 |
24 | Training a new model and encoding/decoding a string:
25 |
26 | ```python
27 | # Initialize tokenizer with training data and train a new model.
28 | text_paths = ['/Users/mac1/Desktop/nanodl/nanodl/__src/utils/sample.txt']
29 |
30 | tokenizer = Tokenizer(training_data=text_paths,
31 | vocab_size=100,
32 | model_type='bpe',
33 | max_sentence_length=50)
34 |
35 | # Encode a sentence.
36 | encoded_sentence = tokenizer.encode('Hello, world!')
37 | print(f'Encoded: {encoded_sentence}')
38 |
39 | # Decode the encoded sentence.
40 | decoded_sentence = tokenizer.decode(encoded_sentence)
41 | print(f'Decoded: {decoded_sentence}')
42 | ```
43 |
44 | Loading an existing model and encoding/decoding a string:
45 |
46 | ```python
47 | # Initialize tokenizer with a pre-trained model.
48 | tokenizer = Tokenizer(model_path='path/to/model.model')
49 |
50 | # Encode a sentence.
51 | encoded_sentence = tokenizer.encode('Hello, world!')
52 | print(f'Encoded: {encoded_sentence}')
53 |
54 | # Decode the encoded sentence.
55 | decoded_sentence = tokenizer.decode(encoded_sentence)
56 | print(f'Decoded: {decoded_sentence}')
57 | ```
58 | """
59 |
60 | def __init__(
61 | self,
62 | training_data: List[str] = None,
63 | vocab_size: int = None,
64 | model_type: str = "bpe",
65 | max_sentence_length: int = 512,
66 | model_path: Optional[str] = None,
67 | ):
68 |
69 | if model_path and os.path.isfile(model_path):
70 | # Load an existing model
71 | self.sp_model = SentencePieceProcessor(model_file=model_path)
72 | elif training_data and all(os.path.isfile(f) for f in training_data):
73 | # Train a new model using a list of data files
74 | input_files = ",".join(training_data)
75 | model_prefix = "trained_model"
76 | SentencePieceTrainer.train(
77 | input=input_files,
78 | model_prefix=model_prefix,
79 | vocab_size=vocab_size,
80 | model_type=model_type,
81 | max_sentence_length=max_sentence_length,
82 | )
83 |
84 | self.sp_model = SentencePieceProcessor(model_file=f"{model_prefix}.model")
85 | else:
86 | raise ValueError(
87 | "Must provide either a model_path or a non-empty training_data list"
88 | )
89 |
90 | # Initialize token IDs
91 | self.n_words: int = self.sp_model.vocab_size()
92 | self.bos_id: int = self.sp_model.bos_id()
93 | self.eos_id: int = self.sp_model.eos_id()
94 | self.pad_id: int = self.sp_model.pad_id()
95 |
96 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
97 |
98 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
99 | """Converts a string into a list of tokens."""
100 | assert isinstance(s, str)
101 | t = self.sp_model.encode(s)
102 | if bos:
103 | t = [self.bos_id] + t
104 | if eos:
105 | t = t + [self.eos_id]
106 | return t
107 |
108 | def decode(self, t: List[int]) -> str:
109 | """Converts a list of tokens back into a string."""
110 | return self.sp_model.decode(t)
111 |
--------------------------------------------------------------------------------
/nanodl/__src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/models/__init__.py
--------------------------------------------------------------------------------
/nanodl/__src/models/kan.py:
--------------------------------------------------------------------------------
1 | from flax import linen as nn
2 | import jax.numpy as jnp
3 | import jax
4 | from jax import random
5 | from typing import Any
6 |
7 |
8 | class KANLinear(nn.Module):
9 | """
10 | A Flax module implementing a B-spline Neural Network layer, where the basis functions are B-splines.
11 |
12 | Attributes:
13 | in_features (int): Number of input features.
14 | out_features (int): Number of output features.
15 | degree (int): Degree of the B-splines.
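16 | 
17 |     Example usage (a minimal sketch; feature sizes and degree are illustrative,
18 |     and standard Flax init/apply usage is assumed):
19 |     ```
20 |     import jax
21 |     import jax.numpy as jnp
22 | 
23 |     layer = KANLinear(in_features=4, out_features=3, degree=3)
24 |     x = jnp.ones((2, 4))
25 |     params = layer.init(jax.random.PRNGKey(0), x)
26 |     y = layer.apply(params, x)
27 |     print(y.shape)  # (2, 3)
28 |     ```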
16 | """
17 | in_features: int
18 | out_features: int
19 | degree: int
20 |
21 | def setup(self) -> None:
22 | assert self.degree > 0, "Degree of the B-splines must be greater than 0"
23 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
24 | self.coefficients = self.param(
25 | "coefficients",
26 | lambda key, shape: mean + std * random.normal(key, shape),
27 | (self.in_features, self.out_features, self.degree + 1)
28 | )
29 |
30 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
31 | x = jnp.tanh(x)
32 |
33 |         knots = jnp.linspace(-1, 1, 2 * self.degree + 2)  # degree + 1 basis functions, matching the coefficients' last axis
34 |
35 | b_spline_values = jnp.array([
36 | self.bspline_basis(x[:, i], self.degree, knots) for i in range(self.in_features)
37 | ])
38 |
39 | b_spline_values = b_spline_values.transpose((1, 0, 2))
40 |
41 | output = jnp.einsum('bid,ijd->bj', b_spline_values, self.coefficients)
42 | return output
43 |
44 | def bspline_basis(self, x: jnp.ndarray, degree: int, knots: jnp.ndarray) -> jnp.ndarray:
45 |
46 | def cox_de_boor(x, k, i, t):
47 | if k == 0:
48 | return jnp.where((t[i] <= x) & (x < t[i + 1]), 1.0, 0.0)
49 | else:
50 | denom1 = t[i + k] - t[i]
51 | denom2 = t[i + k + 1] - t[i + 1]
52 |
53 | term1 = jnp.where(denom1 != 0, (x - t[i]) / denom1, 0.0) * cox_de_boor(x, k - 1, i, t)
54 | term2 = jnp.where(denom2 != 0, (t[i + k + 1] - x) / denom2, 0.0) * cox_de_boor(x, k - 1, i + 1, t)
55 |
56 | return term1 + term2
57 |
58 | n_basis = len(knots) - degree - 1
59 | basis = jnp.array([cox_de_boor(x, degree, i, knots) for i in range(n_basis)])
60 | return jnp.transpose(basis)
61 |
62 |
63 | class ChebyKANLinear(nn.Module):
64 | """
65 | A Flax module implementing a Chebyshev Neural Network layer, where the basis functions are Chebyshev polynomials.
66 |
67 | Inspired by https://github.com/CG80499/KAN-GPT-2/blob/master/chebykan_layer.py
68 |
69 | Attributes:
70 | in_features (int): Number of input features.
71 | out_features (int): Number of output features.
72 | degree (int): Degree of the Chebyshev polynomials.
73 | """
74 | in_features: int
75 | out_features: int
76 | degree: int
77 |
78 | def setup(self) -> None:
79 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
80 | self.coefficients = self.param(
81 | "coefficients",
82 | lambda key, shape: mean + std * random.normal(key, shape),
83 | (self.in_features, self.out_features, self.degree + 1)
84 | )
85 |
86 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
87 |
88 | x = jnp.tanh(x)
89 |
90 | cheby_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
91 | cheby_values = cheby_values.at[:, :, 1].set(x)
92 |
93 | for i in range(2, self.degree + 1):
94 | next_value = 2 * x * cheby_values[:, :, i - 1] - cheby_values[:, :, i - 2]
95 | cheby_values = cheby_values.at[:, :, i].set(next_value)
96 |
97 | output = jnp.einsum('bid,ijd->bj', cheby_values, self.coefficients)
98 | return output
99 |
100 |
101 | class LegendreKANLinear(nn.Module):
102 | """
103 | A Flax module implementing a Legendre Neural Network layer, where the basis functions are Legendre polynomials.
104 |
105 | Attributes:
106 | in_features (int): Number of input features.
107 | out_features (int): Number of output features.
108 | degree (int): Degree of the Legendre polynomials.
109 | """
110 | in_features: int
111 | out_features: int
112 | degree: int
113 |
114 | def setup(self) -> None:
115 | assert self.degree > 0, "Degree of the Legendre polynomials must be greater than 0"
116 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
117 | self.coefficients = self.param(
118 | "coefficients",
119 | lambda key, shape: mean + std * random.normal(key, shape),
120 | (self.in_features, self.out_features, self.degree + 1)
121 | )
122 |
123 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
124 | x = jnp.tanh(x)
125 |
126 | legendre_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
127 | legendre_values = legendre_values.at[:, :, 1].set(x)
128 |
129 | for i in range(2, self.degree + 1):
130 | next_value = ((2 * i - 1) * x * legendre_values[:, :, i - 1] - (i - 1) * legendre_values[:, :, i - 2]) / i
131 | legendre_values = legendre_values.at[:, :, i].set(next_value)
132 |
133 | output = jnp.einsum('bid,ijd->bj', legendre_values, self.coefficients)
134 | return output
135 |
136 |
137 | class MonomialKANLinear(nn.Module):
138 | """
139 | A Flax module implementing a Monomial Neural Network layer, where the basis functions are monomials.
140 |
141 | Attributes:
142 | in_features (int): Number of input features.
143 | out_features (int): Number of output features.
144 | degree (int): Degree of the monomial basis functions.
145 | """
146 | in_features: int
147 | out_features: int
148 | degree: int
149 |
150 | def setup(self) -> None:
151 | assert self.degree > 0, "Degree of the monomial basis functions must be greater than 0"
152 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
153 | self.coefficients = self.param(
154 | "coefficients",
155 | lambda key, shape: mean + std * random.normal(key, shape),
156 | (self.in_features, self.out_features, self.degree + 1)
157 | )
158 |
159 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
160 | x = jnp.tanh(x)
161 | monomial_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
162 |
163 | for i in range(1, self.degree + 1):
164 | monomial_values = monomial_values.at[:, :, i].set(x ** i)
165 |
166 | output = jnp.einsum('bid,ijd->bj', monomial_values, self.coefficients)
167 | return output
168 |
169 |
170 | class FourierKANLinear(nn.Module):
171 | """
172 | A Flax module implementing a Fourier Neural Network layer, where the basis functions are sine and cosine functions.
173 |
174 | Attributes:
175 | in_features (int): Number of input features.
176 | out_features (int): Number of output features.
177 | degree (int): Degree of the Fourier series (i.e., number of harmonics).
178 | """
179 | in_features: int
180 | out_features: int
181 | degree: int
182 |
183 | def setup(self) -> None:
184 | assert self.degree > 0, "Degree of the Fourier series must be greater than 0"
185 | mean, std = 0.0, 1 / (self.in_features * (2 * self.degree + 1))
186 | self.sine_coefficients = self.param(
187 | "sine_coefficients",
188 | lambda key, shape: mean + std * random.normal(key, shape),
189 | (self.in_features, self.out_features, self.degree)
190 | )
191 | self.cosine_coefficients = self.param(
192 | "cosine_coefficients",
193 | lambda key, shape: mean + std * random.normal(key, shape),
194 | (self.in_features, self.out_features, self.degree)
195 | )
196 |
197 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
198 | x = jnp.tanh(x)
199 | sine_values = jnp.sin(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None])
200 | cosine_values = jnp.cos(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None])
201 |
202 | output = (jnp.einsum('bid,ijd->bj', sine_values, self.sine_coefficients) +
203 | jnp.einsum('bid,ijd->bj', cosine_values, self.cosine_coefficients))
204 | return output
205 |
206 |
207 | class HermiteKANLinear(nn.Module):
208 | """
209 | A Flax module implementing a Hermite Neural Network layer, where the basis functions are Hermite polynomials.
210 |
211 | Attributes:
212 | in_features (int): Number of input features.
213 | out_features (int): Number of output features.
214 | degree (int): Degree of the Hermite polynomials.
215 | """
216 | in_features: int
217 | out_features: int
218 | degree: int
219 |
220 | def setup(self) -> None:
221 | assert self.degree > 0, "Degree of the Hermite polynomials must be greater than 0"
222 | mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
223 | self.coefficients = self.param(
224 | "coefficients",
225 | lambda key, shape: mean + std * random.normal(key, shape),
226 | (self.in_features, self.out_features, self.degree + 1)
227 | )
228 |
229 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
230 | x = jnp.tanh(x)
231 |
232 | hermite_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
233 |
234 | if self.degree >= 1:
235 | hermite_values = hermite_values.at[:, :, 1].set(2 * x)
236 | for i in range(2, self.degree + 1):
237 | hermite_values = hermite_values.at[:, :, i].set(
238 | 2 * x * hermite_values[:, :, i - 1] - 2 * (i - 1) * hermite_values[:, :, i - 2]
239 | )
240 |
241 | output = jnp.einsum('bid,ijd->bj', hermite_values, self.coefficients)
242 | return output
243 |
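The layers above all expose the same `(in_features, out_features, degree)` interface and are used like any other Flax module. A minimal usage sketch, assuming it is run with the classes defined in this file in scope (the public export path is not shown here):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy input: a batch of 2 vectors with 8 features each.
layer = LegendreKANLinear(in_features=8, out_features=4, degree=3)
x = jnp.ones((2, 8))

params = layer.init(jax.random.PRNGKey(0), x)  # creates the 'coefficients' parameter
y = layer.apply(params, x)                     # Legendre basis expansion + einsum projection
print(y.shape)                                 # (2, 4)
```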
--------------------------------------------------------------------------------
/nanodl/__src/models/reward.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any, Iterable, Optional, Tuple
3 |
4 | import flax
5 | import flax.linen as nn
6 | import jax
7 | import jax.numpy as jnp
8 | import optax
9 | from flax.training import train_state
10 |
11 |
12 | class RewardModel(nn.Module):
13 | """
14 | The RewardModel estimates the reward or value of a given input sequence,
15 | typically used in reinforcement learning frameworks for natural language processing tasks.
16 | It uses the last hidden state of a transformer-based model to generate a scalar reward prediction,
17 | guiding the agent's behavior by evaluating the desirability or utility of its generated outputs.
18 |
19 | Args:
20 | model (nn.Module): The neural network model to be used.
21 | dim (int): The dimension of the input data.
22 | dropout (float): The dropout rate for the model, a value between 0 and 1.
23 |
24 | Example:
25 | ```python
26 |     import jax, jax.numpy as jnp
27 |     from nanodl import ArrayDataset, DataLoader, Gemma, RewardModel, RewardDataParallelTrainer
28 |
29 | # Generate dummy data
30 | batch_size = 8
31 | max_length = 10
32 |
33 | # Replace with actual tokenised data
34 | dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
35 | dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
36 |
37 | # Create dataset and dataloader
38 | dataset = ArrayDataset(dummy_chosen, dummy_rejected)
39 | dataloader = DataLoader(dataset,
40 | batch_size=batch_size,
41 | shuffle=True,
42 | drop_last=False)
43 |
44 | # model parameters
45 | hyperparams = {
46 | 'num_layers': 1,
47 | 'hidden_dim': 256,
48 | 'num_heads': 2,
49 | 'feedforward_dim': 256,
50 | 'dropout': 0.1,
51 | 'vocab_size': 1000,
52 | 'embed_dim': 256,
53 | 'max_length': max_length,
54 | 'start_token': 0,
55 | 'end_token': 50,
56 | 'num_groups': 2,
57 | }
58 |
59 | # Initialize reward model from Gemma
60 | model = Gemma(**hyperparams)
61 | reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)
62 |
63 | # Train the reward model
64 | trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
65 | trainer.train(dataloader, 5, dataloader)
66 | params = trainer.load_params('reward_model_weights.pkl')
67 |
68 | # Call as you would a regular Flax model
69 | rngs = jax.random.PRNGKey(0)
70 | rngs, dropout_rng = jax.random.split(rngs)
71 | rewards = reward_model.apply({'params': params},
72 | dummy_chosen,
73 | rngs={'dropout': dropout_rng})
74 |
75 | print(rewards.shape)
76 | ```
77 | """
78 |
79 | model: nn.Module
80 | dim: int
81 | dropout: float
82 |
83 | @nn.compact
84 | def __call__(self, x: jnp.ndarray, training: bool = False):
85 |
86 | x = self.model(x, training=training, drop_last_layer=True)
87 | x = nn.Dropout(rate=self.dropout)(x, deterministic=not training)
88 | x = nn.Dense(1)(x)
89 | return nn.sigmoid(x)[:, -1, 0]
90 |
91 |
92 | class RewardDataParallelTrainer:
93 | """
94 | Trainer class using data parallelism with JAX.
95 | This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs).
96 | It handles the model training loop, including gradient computation, parameter updates, and evaluation.
97 |
98 | Attributes:
99 | model (Any): The model to be trained.
100 | input_shape (Tuple[int, ...]): The shape of the input tensor.
101 | weights_filename (str): Filename where the trained model weights will be saved.
102 | learning_rate (float): Learning rate for the optimizer.
103 | params_path (Optional[str]): Path to pre-trained reward model parameters for initializing the REWARD model, if available.
104 | model_params_path (Optional[str]): Path to pre-trained backbone model parameters for initializing the BACKBONE model, if available.
105 |
106 | Methods:
107 |         create_train_state(learning_rate, input_shape): Initializes the training state, including parameters and optimizer.
108 |         train_step(state, chosen, rejected): Performs a single training step, including forward pass, loss computation, and gradient update.
109 |         train(train_loader, num_epochs, val_loader): Runs the training loop over the specified number of epochs, using the provided data loaders for training and validation.
110 |         evaluation_step(state, chosen, rejected): Performs an evaluation step, computing the forward pass and loss without updating model parameters.
111 | evaluate(test_loader): Evaluates the model performance on a test dataset.
112 | save_params(): Saves the model parameters to a file.
113 | load_params(filename): Loads model parameters from a file.
114 | """
115 |
116 | def __init__(
117 | self,
118 | model: Any,
119 | input_shape: Tuple[int, ...],
120 | weights_filename: str,
121 | learning_rate: float = 1e-5,
122 | params_path: Optional[str] = None,
123 | model_params_path: Optional[str] = None,
124 | ) -> None:
125 |
126 | self.model = model
127 | self.params = None
128 | self.params_path = params_path
129 | self.model_params_path = model_params_path
130 | self.num_parameters = None
131 | self.best_val_loss = float("inf")
132 | self.weights_filename = weights_filename
133 | self.num_devices = jax.local_device_count()
134 | self.train_step = jax.pmap(
135 | RewardDataParallelTrainer.train_step, axis_name="devices"
136 | )
137 | self.evaluation_step = jax.pmap(
138 | RewardDataParallelTrainer.evaluation_step, axis_name="devices"
139 | )
140 | self.state = self.create_train_state(learning_rate, input_shape)
141 | print(f"Number of accelerators: {self.num_devices}")
142 |
143 | def create_train_state(
144 | self, learning_rate: float, input_shape: Tuple[int, ...]
145 | ) -> Any:
146 |
147 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)}
148 | params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"]
149 |
150 | if self.params_path is not None:
151 | params = self.load_params(self.params_path)
152 |
153 | if self.model_params_path is not None:
154 | model_params = self.load_params(self.model_params_path)
155 | params = self.merge_params(model_params, params)
156 |
157 | self.num_parameters = sum(
158 | param.size for param in jax.tree_util.tree_leaves(params)
159 | )
160 | print(f"Number of parameters: {self.num_parameters}")
161 | state = train_state.TrainState.create(
162 | apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate)
163 | )
164 | return jax.device_put_replicated(state, jax.local_devices())
165 |
166 | @staticmethod
167 | def train_step(
168 | state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray
169 | ) -> Tuple[Any, jnp.ndarray]:
170 |
171 | def loss_fn(params):
172 | chosen_rewards = state.apply_fn(
173 | {"params": params},
174 | chosen,
175 | training=True,
176 | rngs={"dropout": jax.random.PRNGKey(int(time.time()))},
177 | )
178 |
179 | rejected_rewards = state.apply_fn(
180 | {"params": params},
181 | rejected,
182 | training=True,
183 | rngs={"dropout": jax.random.PRNGKey(int(time.time()))},
184 | )
185 |
186 | return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean()
187 |
188 | loss, grads = jax.value_and_grad(loss_fn)(state.params)
189 | state = state.apply_gradients(grads=grads)
190 | return state, loss
191 |
192 | def train(
193 | self,
194 | train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]],
195 | num_epochs: int,
196 | val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None,
197 | ) -> None:
198 |
199 | for epoch in range(num_epochs):
200 | total_loss = 0.0
201 | count = 0
202 | for chosen, rejected in train_loader:
203 | batch_size = chosen.shape[0]
204 | batch_size_per_device = batch_size // self.num_devices
205 | chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1))
206 | rejected = rejected.reshape(
207 | (self.num_devices, batch_size_per_device, -1)
208 | )
209 | self.state, loss = self.train_step(
210 | state=self.state, chosen=chosen, rejected=rejected
211 | )
212 | total_loss += jnp.mean(loss)
213 | count += 1
214 |
215 | mean_loss = total_loss / count
216 | print(f"Epoch {epoch+1}, Train Loss: {mean_loss}")
217 |
218 | if val_loader is not None:
219 | val_loss = self.evaluate(val_loader)
220 | print(f"Epoch {epoch+1}, Val Loss: {val_loss}")
221 | if val_loss < self.best_val_loss:
222 | self.best_val_loss = val_loss
223 |                     print("New best validation score achieved, saving model...")
224 |                     self.save_params()
225 |         return
226 |
227 | @staticmethod
228 | def evaluation_step(
229 | state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray
230 | ) -> Tuple[Any, jnp.ndarray]:
231 | chosen_rewards = state.apply_fn(
232 | {"params": state.params}, chosen, rngs={"dropout": jax.random.PRNGKey(2)}
233 | )
234 | rejected_rewards = state.apply_fn(
235 | {"params": state.params}, rejected, rngs={"dropout": jax.random.PRNGKey(2)}
236 | )
237 | return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean()
238 |
239 |     def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> float:
240 |
241 | total_loss = 0.0
242 | count = 0
243 | for chosen, rejected in test_loader:
244 | batch_size = chosen.shape[0]
245 | batch_size_per_device = batch_size // self.num_devices
246 | chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1))
247 | rejected = rejected.reshape((self.num_devices, batch_size_per_device, -1))
248 | loss = self.evaluation_step(self.state, chosen, rejected)
249 | total_loss += jnp.mean(loss)
250 | count += 1
251 |
252 | mean_loss = total_loss / count
253 | return mean_loss
254 |
255 |     def merge_params(self, untrained_params, trained_params):
256 |         updated_untrained_params = jax.tree_util.tree_map(
257 | lambda untrained, trained: (
258 | trained if untrained.shape == trained.shape else untrained
259 | ),
260 | untrained_params,
261 | trained_params,
262 | )
263 | return updated_untrained_params
264 |
265 | def save_params(self) -> None:
266 | self.params = flax.jax_utils.unreplicate(self.state.params)
267 | with open(self.weights_filename, "wb") as f:
268 | f.write(flax.serialization.to_bytes(self.params))
269 |
270 | def load_params(self, filename: str):
271 | with open(filename, "rb") as f:
272 | self.params = flax.serialization.from_bytes(self.params, f.read())
273 | return self.params
274 |
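Both train_step and evaluation_step optimise the same pairwise preference objective, -log σ(r_chosen − r_rejected). A standalone sketch of that computation, outside the trainer:

```python
import jax
import jax.numpy as jnp

def pairwise_reward_loss(chosen_rewards: jnp.ndarray, rejected_rewards: jnp.ndarray) -> jnp.ndarray:
    # Loss shrinks as the chosen sequences out-score the rejected ones.
    return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean()

chosen = jnp.array([1.2, 0.8])     # hypothetical rewards for preferred sequences
rejected = jnp.array([0.3, 0.9])   # hypothetical rewards for rejected sequences
print(pairwise_reward_loss(chosen, rejected))  # small positive scalar
```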
--------------------------------------------------------------------------------
/nanodl/__src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMUNACHI/nanodl/e52861a5b2c9bf76e4e79e0bf88a07420497579d/nanodl/__src/utils/__init__.py
--------------------------------------------------------------------------------
/nanodl/__src/utils/data.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from dataclasses import dataclass
3 | from typing import Iterator
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | # This script modifies the JAX DataLoader from the following repository:
9 | # JAX DataLoader by Birkhoff G. (https://birkhoffg.github.io/jax-dataloader/)
10 | # Adapted here for use within nanodl.
11 | # This DataLoader implementation is used for efficient data loading in JAX-based machine learning projects.
12 |
13 |
14 | class Dataset:
15 | """
16 | A PyTorch-like Dataset class for JAX.
17 |
18 | This is a base class for creating datasets in JAX. Subclasses should implement
19 | the `__len__` method to return the size of the dataset and the `__getitem__`
20 | method to return a data item at a given index.
21 |
22 | Example usage:
23 | ```
24 | >>> class MyDataset(Dataset):
25 | ... def __init__(self, data):
26 | ... self.data = data
27 | ... def __len__(self):
28 | ... return len(self.data)
29 | ... def __getitem__(self, index):
30 | ... return self.data[index]
31 | >>> dataset = MyDataset(jnp.arange(10))
32 | >>> print(len(dataset))
33 | >>> print(dataset[5])
34 | ```
35 | """
36 |
37 | def __len__(self):
38 | raise NotImplementedError
39 |
40 | def __getitem__(self, index):
41 | raise NotImplementedError
42 |
43 |
44 | class ArrayDataset(Dataset):
45 | """
46 | Dataset wrapping JAX numpy arrays.
47 |
48 | This class wraps multiple JAX numpy arrays into a dataset. Each array represents
49 | a different modality of the data (e.g., features and labels). All arrays must
50 | have the same first dimension (number of samples).
51 |
52 | Args:
53 | *arrays (jnp.array): Variable number of JAX numpy arrays to include in the dataset.
54 |
55 | Example usage:
56 | ```
57 | >>> dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
58 | >>> print(len(dataset))
59 | >>> print(dataset[1])
60 | ```
61 | """
62 |
63 | def __init__(self, *arrays: jnp.array):
64 | assert all(
65 | arrays[0].shape[0] == arr.shape[0] for arr in arrays
66 | ), "All arrays must have the same first dimension."
67 | self.arrays = arrays
68 |
69 | def __len__(self):
70 | return self.arrays[0].shape[0]
71 |
72 | def __getitem__(self, index):
73 | return tuple(arr[index] for arr in self.arrays)
74 |
75 |
76 | class DataLoader:
77 | """
78 | DataLoader in Vanilla Jax.
79 |
80 | This class provides a way to iterate over batches of data from a given dataset.
81 | It supports batch processing, shuffling, and dropping the last batch if it's
82 | smaller than the specified batch size.
83 |
84 | Args:
85 | dataset (Dataset): The dataset from which to load the data.
86 | batch_size (int, optional): Number of samples per batch. Default is 1.
87 | shuffle (bool, optional): Whether to shuffle the data. Default is False.
88 | drop_last (bool, optional): Whether to drop the last incomplete batch.
89 | Default is False.
90 |
91 | Example usage:
92 | ```
93 | >>> dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256)))
94 | >>> dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False)
95 |     >>> for x, y in dataloader:
96 |     ...     print(x.shape, y.shape)
97 | ```
98 | """
99 |
100 | def __init__(
101 | self,
102 | dataset: Dataset,
103 | batch_size: int = 1,
104 | shuffle: bool = False,
105 | drop_last: bool = False,
106 | **kwargs
107 | ):
108 | self.dataset = dataset
109 | self.batch_size = batch_size
110 | self.shuffle = shuffle
111 | self.drop_last = drop_last
112 |
113 | self.keys = _PRNGSequence(seed=Config.default().global_seed)
114 | self.data_len = len(dataset) # Length of the dataset
115 | self.indices = jnp.arange(self.data_len) # available indices in the dataset
116 | self.pose = 0 # record the current position in the dataset
117 | self._shuffle()
118 |
119 | def _shuffle(self):
120 | if self.shuffle:
121 | self.indices = jax.random.permutation(next(self.keys), self.indices)
122 |
123 | def _stop_iteration(self):
124 | self.pose = 0
125 | self._shuffle()
126 | raise StopIteration
127 |
128 | def __len__(self):
129 | if self.drop_last:
130 | batches = len(self.dataset) // self.batch_size # get the floor of division
131 | else:
132 | batches = -(
133 | len(self.dataset) // -self.batch_size
134 | ) # get the ceil of division
135 | return batches
136 |
137 | def __next__(self):
138 | if self.pose + self.batch_size <= self.data_len:
139 | batch_indices = self.indices[self.pose : self.pose + self.batch_size]
140 | batch_data = self.dataset[batch_indices]
141 | self.pose += self.batch_size
142 | return batch_data
143 | elif self.pose < self.data_len and not self.drop_last:
144 | batch_indices = self.indices[self.pose :]
145 | batch_data = self.dataset[batch_indices]
146 | self.pose += self.batch_size
147 | return batch_data
148 | else:
149 | self._stop_iteration()
150 |
151 | def __iter__(self):
152 | return self
153 |
154 |
155 | @dataclass
156 | class Config:
157 | rng_reserve_size: int
158 | global_seed: int
159 |
160 | @classmethod
161 | def default(cls):
162 | return cls(rng_reserve_size=1, global_seed=42)
163 |
164 |
165 | class _PRNGSequence(Iterator[jax.random.PRNGKey]):
166 | """
167 | An Iterator of Jax PRNGKey (minimal version of `haiku.PRNGSequence`).
168 |
169 | This class provides an iterator over PRNG keys generated from a seed. It is useful
170 | for generating random numbers in a reproducible way.
171 |
172 | Args:
173 | seed (int): Seed for generating the initial PRNG key.
174 |
175 | Example usage:
176 | ```
177 |     >>> prng_seq = _PRNGSequence(42)
178 | >>> key = next(prng_seq)
179 | ```
180 | """
181 |
182 | def __init__(self, seed: int):
183 | self._key = jax.random.PRNGKey(seed)
184 | self._subkeys = collections.deque()
185 |
186 | def reserve(self, num):
187 | if num > 0:
188 | new_keys = tuple(jax.random.split(self._key, num + 1))
189 | self._key = new_keys[0]
190 | self._subkeys.extend(new_keys[1:])
191 |
192 | def __next__(self):
193 | if not self._subkeys:
194 | self.reserve(Config.default().rng_reserve_size)
195 | return self._subkeys.popleft()
196 |
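Because `_stop_iteration` resets the cursor and reshuffles the indices, a single DataLoader can be reused across epochs. A short sketch with hypothetical toy arrays:

```python
import jax.numpy as jnp

dataset = ArrayDataset(jnp.arange(10), jnp.arange(10) * 2)   # 10 samples, two modalities
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=False)

for epoch in range(2):               # the loader reshuffles itself between passes
    for features, labels in loader:  # batches of shape (4,), (4,), then (2,) per epoch
        print(epoch, features.shape, labels.shape)
```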
--------------------------------------------------------------------------------
/nanodl/__src/utils/ml.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 |
7 | @jax.jit
8 | def batch_cosine_similarities(
9 | source: jnp.ndarray, candidates: jnp.ndarray
10 | ) -> jnp.ndarray:
11 | """
12 | Calculate cosine similarities between a source vector and a batch of candidate vectors.
13 |
14 | Args:
15 | source (jnp.ndarray): Source vector of shape (D,).
16 | candidates (jnp.ndarray): Batch of candidate vectors of shape (N, D), where N is the number of candidates.
17 |
18 | Returns:
19 | jnp.ndarray: Array of cosine similarity scores of shape (N,).
20 |
21 | Example usage:
22 | ```
23 | >>> source = jnp.array([1, 0, 0])
24 | >>> candidates = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
25 | >>> similarities = batch_cosine_similarities(source, candidates)
26 | >>> print(similarities)
27 | ```
28 | """
29 | dot_products = jnp.einsum("ij,j->i", candidates, source)
30 | norm_source = jnp.sqrt(jnp.einsum("i,i->", source, source))
31 | norm_candidates = jnp.sqrt(jnp.einsum("ij,ij->i", candidates, candidates))
32 | return dot_products / (norm_source * norm_candidates)
33 |
34 |
35 | @jax.jit
36 | def batch_pearsonr(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
37 | """
38 | Calculate batch Pearson correlation coefficient between two sets of vectors.
39 |
40 | Args:
41 | x (jnp.ndarray): First set of vectors of shape (N, D), where N is the number of vectors.
42 | y (jnp.ndarray): Second set of vectors of shape (N, D).
43 |
44 | Returns:
45 | jnp.ndarray: Array of Pearson correlation coefficients of shape (N,).
46 |
47 | Example usage:
48 | ```
49 | >>> x = jnp.array([[1, 2, 3], [4, 5, 6]])
50 | >>> y = jnp.array([[1, 5, 7], [2, 6, 8]])
51 | >>> correlations = batch_pearsonr(x, y)
52 | >>> print(correlations)
53 | ```
54 | """
55 |     x = jnp.asarray(x)
56 |     y = jnp.asarray(y)
57 | x = x - jnp.expand_dims(x.mean(axis=1), axis=-1)
58 | y = y - jnp.expand_dims(y.mean(axis=1), axis=-1)
59 | numerator = jnp.sum(x * y, axis=-1)
60 | sum_of_squares_x = jnp.einsum("ij,ij -> i", x, x)
61 | sum_of_squares_y = jnp.einsum("ij,ij -> i", y, y)
62 | denominator = jnp.sqrt(sum_of_squares_x * sum_of_squares_y)
63 | return numerator / denominator
64 |
65 |
66 | @jax.jit
67 | def classification_scores(labels: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarray:
68 | """
69 | Calculate classification evaluation scores using JAX.
70 |
71 | Args:
72 | labels (jnp.ndarray): Array of true labels.
73 | preds (jnp.ndarray): Array of predicted labels.
74 |
75 | Returns:
76 | jnp.ndarray: Array containing accuracy, precision, recall, and F1-score.
77 |
78 | Example usage:
79 | ```
80 | >>> labels = jnp.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
81 | >>> preds = jnp.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0])
82 | >>> print(classification_scores(labels, preds))
83 | ```
84 | """
85 | true_positives = jnp.sum(jnp.logical_and(preds == 1, labels == 1))
86 | true_negatives = jnp.sum(jnp.logical_and(preds == 0, labels == 0))
87 | false_positives = jnp.sum(jnp.logical_and(preds == 1, labels == 0))
88 | false_negatives = jnp.sum(jnp.logical_and(preds == 0, labels == 1))
89 |
90 | accuracy = (true_positives + true_negatives) / len(preds)
91 | precision = true_positives / (true_positives + false_positives)
92 | recall = true_positives / (true_positives + false_negatives)
93 | f1 = 2 * (precision * recall) / (precision + recall)
94 | return jnp.array([accuracy, precision, recall, f1])
95 |
96 |
97 | @jax.jit
98 | def mean_reciprocal_rank(predictions: jnp.ndarray) -> float:
99 | """
100 | Calculate the Mean Reciprocal Rank (MRR) for a list of ranked predictions using JAX.
101 |
102 | Example usage:
103 | ```
104 | predictions = jnp.array([
105 | [0, 1, 2], # "correct" prediction at index 0
106 | [1, 0, 2], # "correct" prediction at index 1
107 | [2, 1, 0] # "correct" prediction at index 2
108 | ])
109 | mrr_score = mean_reciprocal_rank(predictions)
110 | ```
111 |
112 | Args:
113 |         predictions (jnp.ndarray): 2D array where each row contains ranked predictions
114 |                                     and the "correct" prediction is the entry with the smallest value (its position gives the rank).
115 |
116 | Returns:
117 | float: Mean Reciprocal Rank (MRR) score.
118 | """
119 | correct_indices = jnp.argmin(predictions, axis=1)
120 | ranks = correct_indices + 1
121 | reciprocal_ranks = 1.0 / ranks
122 | mean_mrr = jnp.mean(reciprocal_ranks)
123 | return mean_mrr
124 |
125 |
126 | def jaccard(sequence1: List, sequence2: List) -> float:
127 | """
128 | Calculate Jaccard similarity between two sequences.
129 |
130 | Args:
131 | sequence1 (List): First input sequence.
132 | sequence2 (List): Second input sequence.
133 |
134 | Returns:
135 | float: Jaccard similarity score.
136 |
137 | Example usage:
138 | ```py
139 | >>> sequence1 = [1, 2, 3]
140 | >>> sequence2 = [2, 3, 4]
141 | >>> similarity = jaccard(sequence1, sequence2)
142 | >>> print(similarity)
143 | ```
144 | """
145 | numerator = len(set(sequence1).intersection(sequence2))
146 | denominator = len(set(sequence1).union(sequence2))
147 | return numerator / denominator
148 |
149 |
150 | @jax.jit
151 | def hamming(sequence1: jnp.ndarray, sequence2: jnp.ndarray) -> int:
152 | """
153 | Calculate Hamming similarity between two sequences using JAX.
154 |
155 | Args:
156 | sequence1 (jnp.ndarray): First input sequence.
157 | sequence2 (jnp.ndarray): Second input sequence.
158 |
159 | Returns:
160 | int: Hamming similarity score.
161 |
162 | Example usage:
163 | ```py
164 | >>> sequence1 = jnp.array([1, 2, 3, 4])
165 | >>> sequence2 = jnp.array([1, 2, 4, 4])
166 |     >>> similarity = hamming(sequence1, sequence2)
167 | >>> print(similarity)
168 | ```
169 | """
170 | return jnp.sum(sequence1 == sequence2)
171 |
172 |
173 | def zero_pad_sequences(arr: jnp.array, max_length: int) -> jnp.array:
174 | """
175 | Zero-pad the given array to the specified maximum length along axis=1.
176 |
177 | This function pads the input array with zeros along the second dimension (axis=1)
178 | until it reaches the specified maximum length. If the array is already longer
179 | than the maximum length, it is returned as is.
180 |
181 | Args:
182 | arr (jax.numpy.ndarray): The array to be padded. Must be 2-dimensional.
183 | max_length (int): The maximum length to pad the array to along axis=1.
184 |
185 | Returns:
186 | jax.numpy.ndarray: The zero-padded array.
187 |
188 | Example usage:
189 | ```py
190 | >>> arr = jnp.array([[1, 2, 3], [4, 5, 6]])
191 | >>> max_length = 5
192 | >>> padded_arr = zero_pad_sequences(arr, max_length)
193 | >>> print(padded_arr)
194 | [[1 2 3 0 0]
195 | [4 5 6 0 0]]
196 | ```
197 | """
198 | current_length = arr.shape[1]
199 | num_zeros = max_length - current_length
200 |
201 | if num_zeros > 0:
202 | zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype)
203 | padded_array = jnp.concatenate([arr, zeros], axis=1)
204 | else:
205 | padded_array = arr
206 |
207 | return padded_array
208 |
209 |
210 | @jax.jit
211 | def entropy(probabilities: jnp.ndarray) -> float:
212 | """
213 | Calculate the entropy of a probability distribution using JAX.
214 |
215 | Example usage:
216 | ```
217 | probabilities = jnp.array([0.25, 0.75])
218 | entropy_value = entropy(probabilities)
219 | ```
220 |
221 | Args:
222 | probabilities (jnp.ndarray): Array of probability values.
223 |
224 | Returns:
225 | float: Entropy value.
226 | """
227 | log_probs = jnp.log2(probabilities)
228 | entropy_value = -jnp.sum(probabilities * log_probs)
229 | return entropy_value
230 |
231 |
232 | @jax.jit
233 | def gini_impurity(probabilities: jnp.ndarray) -> float:
234 | """
235 | Calculate the Gini impurity of a probability distribution using JAX.
236 |
237 | Example usage:
238 | ```
239 | probabilities = jnp.array([0.25, 0.75])
240 | gini_value = gini_impurity(probabilities)
241 | ```
242 |
243 | Args:
244 | probabilities (jnp.ndarray): Array of probability values.
245 |
246 | Returns:
247 | float: Gini impurity value.
248 | """
249 | gini_value = 1 - jnp.sum(probabilities**2)
250 | return gini_value
251 |
252 |
253 | @jax.jit
254 | def kl_divergence(p: jnp.ndarray, q: jnp.ndarray) -> float:
255 | """
256 | Calculate the Kullback-Leibler (KL) divergence between two probability distributions using JAX.
257 |
258 | Example usage:
259 | ```
260 | p = jnp.array([0.25, 0.75])
261 | q = jnp.array([0.5, 0.5])
262 | kl_value = kl_divergence(p, q)
263 | ```
264 |
265 | Args:
266 | p (jnp.ndarray): Array of probability values for distribution p.
267 | q (jnp.ndarray): Array of probability values for distribution q.
268 |
269 | Returns:
270 | float: KL divergence value.
271 | """
272 | kl_value = jnp.sum(p * jnp.log2(p / q))
273 | return kl_value
274 |
275 |
276 | @jax.jit
277 | def count_parameters(params: Any) -> int:
278 | """
279 | Count the total number of parameters in a model's parameter dictionary using JAX.
280 |
281 | Example usage:
282 | ```
283 | model = MyModel()
284 | params = model.init(jax.random.PRNGKey(0), jnp.ones(input_shape))
285 | total_params = count_parameters(params)
286 | ```
287 |
288 | Args:
289 | params (Any): Model's parameter dictionary.
290 |
291 | Returns:
292 | int: Total number of parameters.
293 | """
294 |     return sum(x.size for x in jax.tree_util.tree_leaves(params))
295 |
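A concrete version of the `count_parameters` example above, using a small Flax module (flax is already a dependency of this repository):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(count_parameters(params))  # 8 * 4 kernel weights + 4 biases = 36
```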
--------------------------------------------------------------------------------
/nanodl/__src/utils/nlp.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import Counter
3 | from typing import List
4 |
5 | import numpy as np
6 |
7 |
8 | def rouge(
9 | hypotheses: List[str], references: List[str], ngram_sizes: List[int] = [1, 2]
10 | ) -> dict:
11 | """
12 | Calculate the ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric.
13 |     ROUGE-F1 = 2⋅Precision⋅Recall / (Precision + Recall)
14 |
15 | Args:
16 | hypotheses (List[str]): List of hypothesis sentences.
17 | references (List[str]): List of reference sentences.
18 | ngram_sizes (List[int], optional): List of n-gram sizes. Default is [1, 2].
19 |
20 | Returns:
21 | dict: Dictionary containing precision, recall, and F1-score for each n-gram size.
22 |
23 | Example usage:
24 | ```
25 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"]
26 | >>> references = ["the cat is on the mat", "the cat sits on the mat"]
27 | >>> rouge_scores = rouge(hypotheses, references, [1, 2])
28 | >>> print(rouge_scores)
29 | ```
30 | """
31 |
32 | def ngrams(sequence: List[str], n: int) -> List[str]:
33 | return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]
34 |
35 | def precision_recall_f1(hypothesis_tokens, reference_tokens, n):
36 | hypothesis_ngrams = set(ngrams(hypothesis_tokens, n))
37 | reference_ngrams = set(ngrams(reference_tokens, n))
38 |
39 | common_ngrams = hypothesis_ngrams.intersection(reference_ngrams)
40 |
41 | precision = (
42 | len(common_ngrams) / len(hypothesis_ngrams)
43 | if len(hypothesis_ngrams) > 0
44 | else 0.0
45 | )
46 | recall = (
47 | len(common_ngrams) / len(reference_ngrams)
48 | if len(reference_ngrams) > 0
49 | else 0.0
50 | )
51 |
52 | f1 = 2 * (precision * recall) / (precision + recall + 1e-12)
53 | return precision, recall, f1
54 |
55 | rouge_scores = {}
56 | for n in ngram_sizes:
57 | total_precision = 0.0
58 | total_recall = 0.0
59 | total_f1 = 0.0
60 | for hypothesis, reference in zip(hypotheses, references):
61 | hypothesis_tokens = hypothesis.split()
62 | reference_tokens = reference.split()
63 |
64 | precision, recall, f1 = precision_recall_f1(
65 | hypothesis_tokens, reference_tokens, n
66 | )
67 | total_precision += precision
68 | total_recall += recall
69 | total_f1 += f1
70 |
71 | average_precision = total_precision / len(hypotheses)
72 | average_recall = total_recall / len(hypotheses)
73 | average_f1 = total_f1 / len(hypotheses)
74 |
75 | rouge_scores[f"ROUGE-{n}"] = {
76 | "precision": average_precision,
77 | "recall": average_recall,
78 | "f1": average_f1,
79 | }
80 |
81 | return rouge_scores
82 |
83 |
84 | def bleu(hypotheses: List[str], references: List[str], max_ngram: int = 4) -> float:
85 | """
86 | Calculate the BLEU (Bilingual Evaluation Understudy) metric.
87 | BLEU = (BP) * (exp(sum(wn * log(pn))))
88 | where BP = brevity penalty, wn = weight for n-gram precision, and pn = n-gram precision
89 |
90 | Args:
91 | hypotheses (List[str]): List of hypothesis sentences.
92 | references (List[str]): List of reference sentences.
93 | max_ngram (int, optional): Maximum n-gram size to consider. Default is 4.
94 |
95 | Returns:
96 | float: BLEU score.
97 |
98 | Example usage:
99 | ```
100 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"]
101 | >>> references = ["the cat is on the mat", "the cat sits on the mat"]
102 | >>> bleu_score = bleu(hypotheses, references)
103 | >>> print(bleu_score)
104 | ```
105 | """
106 |
107 | def ngrams(sequence: List[str], n: int) -> List[str]:
108 | return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]
109 |
110 | def modified_precision(hypothesis_tokens, reference_tokens, n):
111 | hypothesis_ngrams = ngrams(hypothesis_tokens, n)
112 | reference_ngrams = ngrams(reference_tokens, n)
113 |
114 | hypothesis_ngram_counts = Counter(hypothesis_ngrams)
115 | reference_ngram_counts = Counter(reference_ngrams)
116 |
117 | common_ngrams = hypothesis_ngram_counts & reference_ngram_counts
118 | common_count = sum(common_ngrams.values())
119 |
120 | if len(hypothesis_ngrams) == 0:
121 | return 0.0
122 | else:
123 | precision = common_count / len(hypothesis_ngrams)
124 | return precision
125 |
126 | brevity_penalty = np.exp(min(0, 1 - len(hypotheses) / len(references)))
127 | bleu_scores = []
128 |
129 | for n in range(1, max_ngram + 1):
130 | ngram_precisions = []
131 | for hypothesis, reference in zip(hypotheses, references):
132 | hypothesis_tokens = hypothesis.split()
133 | reference_tokens = reference.split()
134 |
135 | precision = modified_precision(hypothesis_tokens, reference_tokens, n)
136 | ngram_precisions.append(precision)
137 |
138 | geometric_mean = np.exp(np.mean(np.log(np.clip(ngram_precisions, 1e-10, None))))
139 | bleu_scores.append(geometric_mean)
140 |
141 | final_bleu = brevity_penalty * np.exp(
142 | np.mean(np.log(np.clip(bleu_scores, 1e-10, None)))
143 | )
144 | return final_bleu
145 |
146 |
147 | def meteor(hypothesis: str, reference: str) -> float:
148 | """
149 | Calculates the METEOR score between a reference and hypothesis sentence.
150 |
151 | Args:
152 |         hypothesis (str): The hypothesis sentence.
153 |         reference (str): The reference sentence.
154 |
155 | Returns:
156 | float: METEOR score.
157 |
158 | Example usage:
159 | ```
160 | >>> hypothesis = "the cat is on the mat"
161 | >>> reference = "the cat sits on the mat"
162 | >>> meteor_score = meteor(hypothesis, reference)
163 | >>> print(meteor_score)
164 | ```
165 | """
166 |
167 | def tokenize(sentence):
168 | return re.findall(r"\w+", sentence.lower())
169 |
170 | def stemming(token):
171 | return token.lower()
172 |
173 | def exact_matching(reference_tokens, hypothesis_tokens):
174 | return sum(1 for token in hypothesis_tokens if token in reference_tokens)
175 |
176 | def stemmed_matching(reference_tokens, hypothesis_tokens):
177 | stemmed_reference = [stemming(token) for token in reference_tokens]
178 | stemmed_hypothesis = [stemming(token) for token in hypothesis_tokens]
179 | return sum(1 for token in stemmed_hypothesis if token in stemmed_reference)
180 |
181 | def precision_recall_f1(match_count, hypothesis_length, reference_length):
182 | precision = match_count / hypothesis_length if hypothesis_length > 0 else 0
183 | recall = match_count / reference_length if reference_length > 0 else 0
184 | f1 = (
185 | 2 * precision * recall / (precision + recall)
186 | if precision + recall > 0
187 | else 0
188 | )
189 | return precision, recall, f1
190 |
191 | reference_tokens = tokenize(reference)
192 | hypothesis_tokens = tokenize(hypothesis)
193 |
194 | exact_matches = exact_matching(reference_tokens, hypothesis_tokens)
195 | stemmed_matches = stemmed_matching(reference_tokens, hypothesis_tokens)
196 |
197 | _, _, f1_exact = precision_recall_f1(
198 | exact_matches, len(hypothesis_tokens), len(reference_tokens)
199 | )
200 |
201 | precision_stemmed, recall_stemmed, f1_stemmed = precision_recall_f1(
202 | stemmed_matches, len(hypothesis_tokens), len(reference_tokens)
203 | )
204 |
205 | alpha = 0.5
206 | meteor_score = (1 - alpha) * f1_exact + alpha * precision_stemmed * recall_stemmed
207 | return meteor_score
208 |
209 |
210 | def cider_score(hypothesis: str, reference: str) -> float:
211 | """
212 | Calculates the CIDEr score between a reference and hypothesis sentence.
213 |
214 | Args:
215 |         hypothesis (str): The hypothesis sentence.
216 |         reference (str): The reference sentence.
217 |
218 | Returns:
219 | float: CIDEr score.
220 |
221 | Example usage:
222 | ```
223 | >>> hypothesis = "the cat is on the mat"
224 | >>> reference = "the cat sits on the mat"
225 | >>> score = cider_score(hypothesis, reference)
226 | >>> print(score)
227 | ```
228 | """
229 |
230 | def tokenize(sentence):
231 | return re.findall(r"\w+", sentence.lower())
232 |
233 | def ngrams(tokens, n):
234 | return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
235 |
236 | reference_tokens = tokenize(reference)
237 | hypothesis_tokens = tokenize(hypothesis)
238 |
239 | max_n = 4 # Maximum n-gram size
240 | weights = [1.0] * max_n # Weights for different n-gram sizes
241 |
242 | cider_scores = []
243 | for n in range(1, max_n + 1):
244 | ref_ngrams = ngrams(reference_tokens, n)
245 | hyp_ngrams = ngrams(hypothesis_tokens, n)
246 |
247 | ref_ngram_freq = Counter(ref_ngrams)
248 | hyp_ngram_freq = Counter(hyp_ngrams)
249 |
250 | common_ngrams = set(ref_ngrams) & set(hyp_ngrams)
251 |
252 | if len(common_ngrams) == 0:
253 | cider_scores.append(0)
254 | continue
255 |
256 | precision = sum(
257 | min(ref_ngram_freq[ngram], hyp_ngram_freq[ngram]) for ngram in common_ngrams
258 | ) / len(hyp_ngrams)
259 | ref_ngram_freq_sum = sum(ref_ngram_freq[ngram] for ngram in common_ngrams)
260 | hyp_ngram_freq_sum = sum(hyp_ngram_freq[ngram] for ngram in common_ngrams)
261 | recall = ref_ngram_freq_sum / len(ref_ngrams)
262 |
263 | cider_scores.append((precision * recall) / (precision + recall) * 2)
264 |
265 | avg_cider_score = np.average(cider_scores, weights=weights)
266 |
267 | return avg_cider_score
268 |
269 |
270 | def perplexity(log_probs: List[float]) -> float:
271 | """
272 | Calculate the perplexity of a sequence using a list of log probabilities.
273 | Perplexity = 2^(-average log likelihood)
274 | where average log likelihood = total log likelihood / total word count
275 |
276 | Args:
277 | log_probs (List[float]): List of log probabilities for each predicted word.
278 |
279 | Returns:
280 | float: Perplexity score.
281 |
282 | Example usage:
283 | ```
284 | >>> log_probs = [-2.3, -1.7, -0.4] # Example log probabilities
285 | >>> perplexity_score = perplexity(log_probs)
286 | >>> print(perplexity_score)
287 | ```
288 | """
289 | log_likelihood = 0.0
290 | word_count = 0
291 |
292 |     for i in range(len(log_probs)):
293 |         predicted_log_prob = log_probs[
294 |             i
295 |         ]  # log probability assigned by the language model to the i-th word
296 |         log_likelihood += predicted_log_prob
297 |         word_count += 1
298 |
299 | average_log_likelihood = log_likelihood / word_count
300 | perplexity_score = 2 ** (-average_log_likelihood)
301 | return perplexity_score
302 |
303 |
304 | def word_error_rate(hypotheses: List[str], references: List[str]) -> float:
305 | """
306 | Calculate the Word Error Rate (WER) metric.
307 |
308 | Args:
309 |         hypotheses (List[str]): List of hypothesis sentences.
310 |         references (List[str]): List of reference sentences.
311 |
312 | Returns:
313 | float: Word Error Rate score.
314 |
315 | Example usage:
316 | ```
317 | >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"]
318 | >>> references = ["the cat is on the mat", "the cat sits on the mat"]
319 | >>> wer_score = word_error_rate(hypotheses, references)
320 | >>> print(wer_score)
321 | ```
322 | """
323 |
324 | def edit_distance(str1, str2):
325 | len_str1 = len(str1)
326 | len_str2 = len(str2)
327 |
328 | dp = [[0] * (len_str2 + 1) for _ in range(len_str1 + 1)]
329 |
330 | for i in range(len_str1 + 1):
331 | dp[i][0] = i
332 |
333 | for j in range(len_str2 + 1):
334 | dp[0][j] = j
335 |
336 | for i in range(1, len_str1 + 1):
337 | for j in range(1, len_str2 + 1):
338 | cost = 0 if str1[i - 1] == str2[j - 1] else 1
339 | dp[i][j] = min(
340 | dp[i - 1][j] + 1, # Deletion
341 | dp[i][j - 1] + 1, # Insertion
342 | dp[i - 1][j - 1] + cost, # Substitution or no operation
343 | )
344 | return dp[len_str1][len_str2]
345 |
346 | total_edit_distance = 0
347 | total_reference_length = 0
348 |
349 | for hyp, ref in zip(hypotheses, references):
350 | edit_dist = edit_distance(hyp.split(), ref.split())
351 | total_edit_distance += edit_dist
352 | total_reference_length += len(ref.split())
353 |
354 | wer_score = total_edit_distance / total_reference_length
355 | return wer_score
356 |
--------------------------------------------------------------------------------
/nanodl/__src/utils/random.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any, Tuple, Union
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | from jax import random
7 |
8 |
9 | def time_rng_key(seed=None) -> jnp.ndarray:
10 |     """Generate a JAX random key, seeded from the current UNIX timestamp unless an explicit seed is given.
11 |
12 | Returns:
13 | jnp.ndarray: A JAX random key.
14 | """
15 | key = int(time.time()) if seed is None else seed
16 | return random.PRNGKey(key)
17 |
18 |
19 | def uniform(
20 | shape: Tuple[int, ...],
21 | minval: Any = 0.0,
22 | maxval: Any = 1.0,
23 | seed=None,
24 | dtype: Any = jnp.float32,
25 | ) -> jnp.ndarray:
26 | """Generate a tensor of uniform random values.
27 |
28 | Args:
29 | shape (Tuple[int, ...]): The shape of the output tensor.
30 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
31 | minval (Any, optional): The lower bound of the uniform distribution. Defaults to 0.0.
32 | maxval (Any, optional): The upper bound of the uniform distribution. Defaults to 1.0.
33 |
34 | Returns:
35 | jnp.ndarray: A tensor of uniform random values.
36 | """
37 | return random.uniform(
38 | time_rng_key(seed), shape, dtype=dtype, minval=minval, maxval=maxval
39 | )
40 |
41 |
42 | def normal(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray:
43 | """Generate a tensor of normal random values.
44 |
45 | Args:
46 | shape (Tuple[int, ...]): The shape of the output tensor.
47 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
48 |
49 | Returns:
50 | jnp.ndarray: A tensor of normal random values.
51 | """
52 | return random.normal(time_rng_key(seed), shape, dtype=dtype)
53 |
54 |
55 | def bernoulli(p: float, shape: Tuple[int, ...] = (), seed=None) -> jnp.ndarray:
56 | """Generate random boolean values with a given probability.
57 |
58 | Args:
59 | p (float): Probability of sampling a True value.
60 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
61 |
62 | Returns:
63 | jnp.ndarray: A tensor of boolean values.
64 | """
65 | return random.bernoulli(time_rng_key(seed), p, shape)
66 |
67 |
68 | def categorical(
69 | logits: jnp.ndarray, axis: int = -1, shape: Tuple[int, ...] = (), seed=None
70 | ) -> jnp.ndarray:
71 | """Draw samples from a categorical distribution.
72 |
73 | Args:
74 | logits (jnp.ndarray): The unnormalized log probabilities of the categories.
75 | axis (int, optional): The axis along which the categorical distribution is applied. Defaults to -1.
76 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
77 |
78 | Returns:
79 | jnp.ndarray: The sampled indices with the specified shape.
80 | """
81 | return random.categorical(time_rng_key(seed), logits, axis=axis, shape=shape)
82 |
83 |
84 | def randint(
85 | shape: Tuple[int, ...], minval: int, maxval: int, dtype: str = "int32", seed=None
86 | ) -> jnp.ndarray:
87 | """Generate random integers between minval (inclusive) and maxval (exclusive).
88 |
89 | Args:
90 | shape (Tuple[int, ...]): The shape of the output tensor.
91 | minval (int): The lower bound of the random integers, inclusive.
92 | maxval (int): The upper bound of the random integers, exclusive.
93 | dtype (str, optional): The data type of the output tensor. Defaults to 'int32'.
94 |
95 | Returns:
96 | jnp.ndarray: A tensor of random integers.
97 | """
98 | return random.randint(time_rng_key(seed), shape, minval, maxval, dtype=dtype)
99 |
100 |
101 | def permutation(x: Union[int, jnp.ndarray], axis: int = 0, seed=None) -> jnp.ndarray:
102 | """Randomly permute a sequence, or return a permuted range.
103 |
104 | Args:
105 | x (Union[int, jnp.ndarray]): If x is an integer, permute range(x). If x is an array, permute its elements.
106 | axis (int, optional): The axis along which to permute if x is an array. Defaults to 0.
107 |
108 | Returns:
109 | jnp.ndarray: The permuted sequence or array.
110 | """
111 | if isinstance(x, int):
112 | arr = jax.numpy.arange(x)
113 | return random.permutation(time_rng_key(seed), arr, axis=axis)
114 | else:
115 | return random.permutation(time_rng_key(seed), x, axis=axis)
116 |
117 |
118 | def gumbel(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray:
119 | """Draw samples from a Gumbel distribution.
120 |
121 | Args:
122 | shape (Tuple[int, ...]): The shape of the output tensor.
123 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
124 |
125 | Returns:
126 | jnp.ndarray: A tensor of samples from a Gumbel distribution.
127 | """
128 | return random.gumbel(time_rng_key(seed), shape, dtype=dtype)
129 |
130 |
131 | def choice(
132 | a: Union[int, jnp.ndarray],
133 | shape: Tuple[int, ...] = (),
134 | replace: bool = True,
135 | p: Union[None, jnp.ndarray] = None,
136 | axis: int = 0,
137 | seed=None,
138 | ) -> jnp.ndarray:
139 | """Randomly choose elements from a given 1-D array.
140 |
141 | Args:
142 | a (Union[int, jnp.ndarray]): If an int, the random sample is generated as if a were jnp.arange(a).
143 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
144 | replace (bool, optional): Whether the sample is with or without replacement. Defaults to True.
145 | p (Union[None, jnp.ndarray], optional): The probabilities associated with each entry in a. Defaults to None.
146 | axis (int, optional): The axis along which to choose if a is an array. Defaults to 0.
147 |
148 | Returns:
149 | jnp.ndarray: The randomly chosen elements.
150 | """
151 | if isinstance(a, int):
152 | a = jnp.arange(a)
153 | return random.choice(
154 | time_rng_key(seed), a, shape=shape, replace=replace, p=p, axis=axis
155 | )
156 |
157 |
158 | def bits(shape: Tuple[int, ...], dtype: Any = jnp.uint32, seed=None) -> jnp.ndarray:
159 | """Generate random bits.
160 |
161 | Args:
162 | shape (Tuple[int, ...]): The shape of the output tensor.
163 | dtype (Any, optional): The data type of the output tensor, typically an unsigned integer type. Defaults to jnp.uint32.
164 |
165 | Returns:
166 | jnp.ndarray: A tensor of random bits.
167 | """
168 | return random.bits(time_rng_key(seed), shape, dtype=dtype)
169 |
170 |
171 | def exponential(
172 | shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None
173 | ) -> jnp.ndarray:
174 | """Draw samples from an exponential distribution.
175 |
176 | Args:
177 | shape (Tuple[int, ...]): The shape of the output tensor.
178 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
179 |
180 | Returns:
181 | jnp.ndarray: A tensor of samples from an exponential distribution.
182 | """
183 | return random.exponential(time_rng_key(seed), shape, dtype=dtype)
184 |
185 |
186 | def triangular(
187 | left: float, right: float, mode: float, shape: Tuple[int, ...] = (), seed=None
188 | ) -> jnp.ndarray:
189 | """Draw samples from a triangular distribution.
190 |
191 | Args:
192 | left (float): The lower limit of the distribution.
193 | right (float): The upper limit of the distribution.
194 | mode (float): The mode (peak) of the distribution.
195 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
196 |
197 | Returns:
198 | jnp.ndarray: A tensor of samples from a triangular distribution.
199 | """
200 | return random.triangular(time_rng_key(seed), left, right, mode, shape)
201 |
202 |
203 | def truncated_normal(
204 | lower: float,
205 | upper: float,
206 | shape: Tuple[int, ...] = (),
207 | dtype: Any = jnp.float32,
208 | seed=None,
209 | ) -> jnp.ndarray:
210 | """Draw samples from a truncated normal distribution.
211 |
212 | Args:
213 | lower (float): The lower bound of the distribution.
214 | upper (float): The upper bound of the distribution.
215 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
216 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
217 |
218 | Returns:
219 | jnp.ndarray: A tensor of samples from a truncated normal distribution.
220 | """
221 | return random.truncated_normal(time_rng_key(seed), lower, upper, shape, dtype)
222 |
223 |
224 | def poisson(
225 | lam: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None
226 | ) -> jnp.ndarray:
227 | """Draw samples from a Poisson distribution.
228 |
229 | Args:
230 | lam (float): The expectation of interval (lambda parameter).
231 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
232 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32.
233 |
234 | Returns:
235 | jnp.ndarray: A tensor of samples from a Poisson distribution.
236 | """
237 | return random.poisson(time_rng_key(seed), lam, shape=shape, dtype=dtype)
238 |
239 |
240 | def geometric(
241 | p: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None
242 | ) -> jnp.ndarray:
243 | """Draw samples from a geometric distribution.
244 |
245 | Args:
246 | p (float): The probability of success of an individual trial.
247 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
248 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32.
249 |
250 | Returns:
251 | jnp.ndarray: A tensor of samples from a geometric distribution.
252 | """
253 | return random.geometric(time_rng_key(seed), p, shape=shape, dtype=dtype)
254 |
255 |
256 | def gamma(
257 | a: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None
258 | ) -> jnp.ndarray:
259 | """Draw samples from a gamma distribution.
260 |
261 | Args:
262 | a (float): The shape parameter of the gamma distribution.
263 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
264 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
265 |
266 | Returns:
267 | jnp.ndarray: A tensor of samples from a gamma distribution.
268 | """
269 | return random.gamma(time_rng_key(seed), a, shape=shape, dtype=dtype)
270 |
271 |
272 | def chisquare(
273 | df: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None
274 | ) -> jnp.ndarray:
275 | """Draw samples from a chi-square distribution.
276 |
277 | Args:
278 | df (float): The degrees of freedom.
279 | shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
280 | dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
281 |
282 | Returns:
283 | jnp.ndarray: A tensor of samples from a chi-square distribution.
284 | """
285 | return random.chisquare(time_rng_key(seed), df, shape=shape, dtype=dtype)
286 |
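Every helper in this module accepts an optional `seed`; passing the same seed reproduces the draw, while omitting it falls back to the UNIX-timestamp key from `time_rng_key`. A short sketch using the functions above:

```python
import jax.numpy as jnp

a = uniform((2, 3), seed=42)
b = uniform((2, 3), seed=42)
assert jnp.allclose(a, b)          # identical draws for identical seeds

print(normal((4,), seed=0))
print(randint((3,), minval=0, maxval=10, seed=0))
```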
--------------------------------------------------------------------------------
/nanodl/__src/utils/vision.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 |
7 | @jax.jit
8 | def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
9 | """
10 | Normalize images to have zero mean and unit variance.
11 |
12 | Args:
13 | images (jnp.ndarray): Input images of shape (N, H, W, C), where N is the number of images,
14 | H is height, W is width, and C is the number of channels.
15 |
16 | Returns:
17 | jnp.ndarray: Normalized images of the same shape as the input.
18 |
19 | Example usage:
20 | ```
21 | >>> images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) # One image of shape (1, 2, 2, 1)
22 | >>> normalized_images = normalize_images(images)
23 | >>> print(normalized_images)
24 | ```
25 | """
26 | mean = images.mean(axis=(1, 2, 3), keepdims=True)
27 | std = images.std(axis=(1, 2, 3), keepdims=True)
28 | return (images - mean) / (std + 1e-5)
29 |
30 |
31 | def random_crop(images: jnp.ndarray, crop_size: int) -> jnp.ndarray:
32 | """
33 | Randomly crop a batch of images to a specified size using JAX.
34 |
35 | This function takes a batch of images and randomly crops each image to the specified size.
36 | It uses JAX for random number generation to determine the starting coordinates of the crop.
37 |
38 | Args:
39 | images (jax.numpy.ndarray): A 4D array of shape (batch_size, height, width, channels),
40 | representing a batch of images.
41 | crop_size (int): The size to which each image will be cropped. Both the height and width
42 | of the crop will be equal to `crop_size`.
43 |
44 | Returns:
45 | jax.numpy.ndarray: The cropped images, with shape (batch_size, crop_size, crop_size, channels).
46 |
47 | Example usage:
48 | ```
49 | >>> images = jnp.ones((10, 100, 100, 3)) # Batch of 10 images of size 100x100 with 3 channels
50 | >>> crop_size = 64
51 | >>> cropped_images = random_crop(images, crop_size)
52 | >>> print(cropped_images.shape)
53 | ```
54 | """
55 |     height_key, width_key = jax.random.split(jax.random.PRNGKey(int(time.time())))
56 |     _, height, width, _ = images.shape
57 |     height_start = jax.random.randint(height_key, (), 0, height - crop_size + 1)
58 |     width_start = jax.random.randint(width_key, (), 0, width - crop_size + 1)
59 | height_end = height_start + crop_size
60 | width_end = width_start + crop_size
61 | crops = images[:, height_start:height_end, width_start:width_end, :]
62 | return crops
63 |
64 |
65 | def gaussian_blur(image: jnp.ndarray, kernel_size: int, sigma: float) -> jnp.ndarray:
66 | """
67 | Apply Gaussian blur to a multi-channel image.
68 |
69 | Args:
70 | image (jnp.ndarray): Input image of shape (H, W, C).
71 | kernel_size (int): Size of the Gaussian kernel (must be odd).
72 | sigma (float): Standard deviation of the Gaussian kernel.
73 |
74 | Returns:
75 | jnp.ndarray: Blurred image of the same shape as the input.
76 |
77 | Example usage:
78 | ```
79 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
80 | >>> blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0)
81 | >>> print(blurred_image.shape)
82 | ```
83 | """
84 | assert kernel_size % 2 == 1, "Kernel size must be odd."
85 | ax = jnp.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
86 | xx, yy = jnp.meshgrid(ax, ax)
87 | kernel = jnp.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
88 | kernel = kernel / jnp.sum(kernel)
89 |
90 | # Apply convolution to each channel
91 | blurred_image = jnp.stack(
92 | [
93 | jax.scipy.signal.convolve2d(image[:, :, i], kernel, mode="same")
94 | for i in range(image.shape[2])
95 | ],
96 | axis=-1,
97 | )
98 | return blurred_image
99 |
100 |
101 | @jax.jit
102 | def sobel_edge_detection(image: jnp.ndarray) -> jnp.ndarray:
103 | """
104 | Apply Sobel edge detection to a multi-channel image.
105 |
106 | Args:
107 | image (jnp.ndarray): Input image of shape (H, W, C).
108 |
109 | Returns:
110 | jnp.ndarray: Image representing the edges, of the same shape as the input.
111 |
112 | Example usage:
113 | ```
114 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
115 | >>> edges = sobel_edge_detection(image)
116 | >>> print(edges.shape)
117 | ```
118 | """
119 | sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)
120 | sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32)
121 |
122 | def apply_sobel(channel):
123 | gx = jax.scipy.signal.convolve2d(channel, sobel_x, mode="same")
124 | gy = jax.scipy.signal.convolve2d(channel, sobel_y, mode="same")
125 | return jnp.sqrt(gx**2 + gy**2)
126 |
127 | # Apply Sobel filter to each channel and sum the results
128 | edges = jnp.sum(
129 | jnp.stack(
130 | [apply_sobel(image[:, :, i]) for i in range(image.shape[2])], axis=-1
131 | ),
132 | axis=-1,
133 | )
134 | return edges
135 |
136 |
137 | @jax.jit
138 | def adjust_brightness(image: jnp.ndarray, factor: float) -> jnp.ndarray:
139 | """
140 | Adjust the brightness of an image.
141 |
142 | Args:
143 | image (jnp.ndarray): Input image of shape (H, W, C).
144 | factor (float): Factor to adjust brightness. Values > 1 increase brightness,
145 | values < 1 decrease brightness.
146 |
147 | Returns:
148 | jnp.ndarray: Brightness-adjusted image of the same shape as the input.
149 |
150 | Example usage:
151 | ```
152 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
153 | >>> adjusted_image = adjust_brightness(image, factor=1.5)
154 | >>> print(adjusted_image.shape)
155 | ```
156 | """
157 | return jnp.clip(image * factor, 0, 1)
158 |
159 |
160 | @jax.jit
161 | def adjust_contrast(image: jnp.ndarray, factor: float) -> jnp.ndarray:
162 | """
163 | Adjust the contrast of an image.
164 |
165 | Args:
166 | image (jnp.ndarray): Input image of shape (H, W, C).
167 | factor (float): Factor to adjust contrast. Values > 1 increase contrast,
168 | values < 1 decrease contrast.
169 |
170 | Returns:
171 | jnp.ndarray: Contrast-adjusted image of the same shape as the input.
172 |
173 | Example usage:
174 | ```
175 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
176 | >>> adjusted_image = adjust_contrast(image, factor=1.5)
177 | >>> print(adjusted_image.shape)
178 | ```
179 | """
180 | mean = jnp.mean(image, axis=(0, 1), keepdims=True)
181 | return jnp.clip((image - mean) * factor + mean, 0, 1)
182 |
183 |
184 | @jax.jit
185 | def flip_image(image: jnp.ndarray, horizontal: jnp.ndarray) -> jnp.ndarray:
186 | """
187 | Flip an image horizontally or vertically.
188 |
189 | Args:
190 | image (jnp.ndarray): Input image of shape (H, W, C).
191 |         horizontal (jnp.ndarray): Boolean array with a single element; if True, flip horizontally,
192 |                                   otherwise flip vertically.
193 |
194 | Returns:
195 | jnp.ndarray: Flipped image of the same shape as the input.
196 |
197 | Example usage:
198 | ```
199 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
200 | >>> flipped_image_horizontally = flip_image(image, jnp.array([True]))
201 | >>> flipped_image_vertically = flip_image(image, jnp.array([False]))
202 | >>> print(flipped_image_horizontally.shape, flipped_image_vertically.shape)
203 | ```
204 | """
205 | return jnp.where(horizontal, image[:, ::-1, :], image[::-1, :, :])
206 |
207 |
208 | @jax.jit
209 | def random_flip_image(
210 | image: jnp.ndarray, key: jax.random.PRNGKey, horizontal: jnp.ndarray
211 | ) -> jnp.ndarray:
212 | """
213 | Randomly flip an image horizontally or vertically using JAX.
214 |
215 | Args:
216 | image (jnp.ndarray): Input image of shape (H, W, C).
217 | key (jax.random.PRNGKey): A PRNG key used for random number generation.
218 |         horizontal (jnp.ndarray): Boolean array with a single element giving the flip direction.
219 |                                   With probability 0.5 the image is flipped in that direction
220 |                                   (horizontal if True, vertical otherwise); else it is returned unchanged.
221 |
222 | Returns:
223 | jnp.ndarray: Randomly flipped image of the same shape as the input.
224 |
225 | Example usage:
226 | ```
227 | >>> key = jax.random.PRNGKey(0)
228 | >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels
229 | >>> flipped_image = random_flip_image(image, key, jnp.array([True]))
230 | >>> print(flipped_image.shape)
231 | ```
232 | """
233 | flip = jax.random.uniform(key) > 0.5
234 | flip_horizontal = jnp.where(horizontal, image[:, ::-1, :], image)
235 | flip_vertical = jnp.where(horizontal, image, image[::-1, :, :])
236 | return jnp.where(flip, flip_horizontal, flip_vertical)
237 |
--------------------------------------------------------------------------------
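A minimal sketch (not part of the repository) of how the augmentation utilities above might be composed into a simple pipeline; it assumes they are importable directly from `nanodl`, as exercised in `tests/test_utils.py` further below.

```python
import jax
import jax.numpy as jnp
from nanodl import adjust_brightness, gaussian_blur, random_crop, random_flip_image

# Hypothetical augmentation pipeline built from the vision utilities above.
images = jnp.ones((4, 100, 100, 3))        # batch of 4 RGB images
crops = random_crop(images, 64)            # -> (4, 64, 64, 3)

key = jax.random.PRNGKey(0)
augmented = []
for i, img in enumerate(crops):
    img = gaussian_blur(img, kernel_size=3, sigma=1.0)  # per-image blur, (H, W, C)
    img = adjust_brightness(img, 1.2)                   # brighten by 20%
    img = random_flip_image(img, jax.random.fold_in(key, i), jnp.array([True]))
    augmented.append(img)

print(jnp.stack(augmented).shape)          # (4, 64, 64, 3)
```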
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax
2 | jaxlib
3 | flax
4 | optax
5 | einops
6 | sentencepiece
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="nanodl",
5 | version="0.0.0",
6 | author="Henry Ndubuaku",
7 | author_email="ndubuakuhenry@gmail.com",
8 | description="A Jax-based library for designing and training transformer models from scratch.",
9 | long_description=open("README.md").read(),
10 | long_description_content_type="text/markdown",
11 | url="https://github.com/hmunachi/nanodl",
12 | packages=find_packages(),
13 | install_requires=[
14 | "flax",
15 | "jax",
16 | "jaxlib",
17 | "optax",
18 | "einops",
19 | "sentencepiece",
20 | ],
21 | classifiers=[
22 | "Development Status :: 3 - Alpha",
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Science/Research",
25 | "Intended Audience :: Education",
26 | "Topic :: Software Development :: Build Tools",
27 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
28 | "License :: OSI Approved :: MIT License",
29 | "Programming Language :: Python :: 3",
30 | "Programming Language :: Python :: 3.7",
31 | "Programming Language :: Python :: 3.8",
32 | "Programming Language :: Python :: 3.9",
33 | ],
34 | keywords="transformers jax machine learning deep learning pytorch tensorflow",
35 | python_requires=">=3.7",
36 | )
37 |
--------------------------------------------------------------------------------
/tests/files/sample.txt:
--------------------------------------------------------------------------------
1 | Hello, world! This is a test of the Tokenizer.
2 | Let's see how it tokenizes this file.
3 | Another sentence to check the tokenization process.
--------------------------------------------------------------------------------
/tests/test_classic.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from nanodl import *
7 |
8 |
9 | class TestNaiveBayesFunctions(unittest.TestCase):
10 | def setUp(self):
11 | self.num_samples = 3
12 | self.num_features = 2
13 | self.num_classes = 2
14 | self.X = jnp.array([[0, 1], [1, 0], [1, 1]])
15 | self.y = jnp.array([0, 1, 0])
16 |
17 | def test_naive_bayes_classifier(self):
18 | classifier = NaiveBayesClassifier(num_classes=self.num_classes)
19 | classifier.fit(self.X, self.y)
20 | predictions = classifier.predict(self.X)
21 | self.assertEqual(predictions.shape, (self.num_samples,))
22 | self.assertTrue(
23 | jnp.all(predictions >= 0) and jnp.all(predictions < self.num_classes)
24 | )
25 |
26 |
27 | class TestKClustering(unittest.TestCase):
28 | def setUp(self):
29 | self.k = 3
30 | self.num_samples = 300
31 | self.num_features = 2
32 | self.X = jax.random.normal(
33 | jax.random.PRNGKey(0), (self.num_samples, self.num_features)
34 | )
35 |
36 | def test_kmeans_fit_predict(self):
37 | kmeans = KMeans(k=self.k)
38 | kmeans.fit(self.X)
39 | clusters = kmeans.predict(self.X)
40 | self.assertEqual(len(set(clusters.tolist())), self.k)
41 |
42 | def test_gmm_fit_predict(self):
43 | gmm = GaussianMixtureModel(n_components=self.k)
44 | gmm.fit(self.X)
45 | labels = gmm.predict(self.X)
46 | self.assertEqual(len(set(labels.tolist())), self.k)
47 |
48 |
49 | class TestPCA(unittest.TestCase):
50 | def test_pca_fit_transform(self):
51 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))
52 | pca = PCA(n_components=2)
53 | pca.fit(data)
54 | transformed_data = pca.transform(data)
55 | self.assertEqual(transformed_data.shape, (1000, 2))
56 |
57 | def test_pca_inverse_transform(self):
58 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))
59 | pca = PCA(n_components=2)
60 | pca.fit(data)
61 | transformed_data = pca.transform(data)
62 | inverse_data = pca.inverse_transform(transformed_data)
63 | self.assertEqual(inverse_data.shape, data.shape)
64 |
65 | def test_pca_sample(self):
66 | data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))
67 | pca = PCA(n_components=2)
68 | pca.fit(data)
69 | synthetic_samples = pca.sample(n_samples=100)
70 | self.assertEqual(synthetic_samples.shape, (100, 10))
71 |
72 |
73 | class TestRegression(unittest.TestCase):
74 | def test_linear_regression(self):
75 | num_samples = 100
76 | input_dim = 1
77 | output_dim = 1
78 | x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
79 |         y_data = jnp.dot(x_data, jnp.array([[2.0]])) + jnp.array([[1.0]])
80 | lr_model = LinearRegression(input_dim, output_dim)
81 | lr_model.fit(x_data, y_data)
82 | learned_weights, learned_bias = lr_model.get_params()
83 | self.assertTrue(jnp.allclose(learned_weights, jnp.array([[2.0]]), atol=1e-1))
84 | self.assertTrue(jnp.allclose(learned_bias, jnp.array([[1.0]]), atol=1e-1))
85 |
86 | def test_logistic_regression(self):
87 | num_samples = 100
88 | input_dim = 2
89 | x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
90 | logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1
91 | y_data = (logits > 0).astype(jnp.float32)
92 | lr_model = LogisticRegression(input_dim)
93 | lr_model.fit(x_data, y_data)
94 | test_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
95 | predictions = lr_model.predict(test_data)
96 | self.assertTrue(jnp.all(predictions >= 0) and jnp.all(predictions <= 1))
97 |
98 | def test_gaussian_process(self):
99 | def rbf_kernel(x1, x2, length_scale=1.0):
100 | diff = x1[:, None] - x2
101 | return jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1) / length_scale**2)
102 |
103 | num_samples = 100
104 | input_dim = 1
105 | X_train = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
106 | y_train = (
107 | jnp.sin(X_train)
108 | + jax.random.normal(jax.random.PRNGKey(0), (num_samples, 1)) * 0.1
109 | )
110 | gp = GaussianProcess(kernel=rbf_kernel, noise=1e-3)
111 | gp.fit(X_train, y_train)
112 | X_new = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
113 | mean, covariance = gp.predict(X_new)
114 | self.assertEqual(mean.shape, (num_samples, 1))
115 | self.assertEqual(covariance.shape, (num_samples, num_samples))
116 |
117 |
118 | if __name__ == "__main__":
119 | unittest.main()
120 |
--------------------------------------------------------------------------------
/tests/test_kan.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import jax.numpy as jnp
3 | from jax import random
4 | from nanodl import *
5 |
6 | class TestKANLinearVariants(unittest.TestCase):
7 | def setUp(self):
8 | self.in_features = 4
9 | self.out_features = 3
10 | self.degree = 2
11 |
12 | self.key = random.PRNGKey(0)
13 | self.x = random.normal(self.key, (10, self.in_features))
14 |
15 | self.models = {
16 | # "BSplineKANLinear": KANLinear(self.in_features, self.out_features, self.degree),
17 | "ChebyKANLinear": ChebyKANLinear(self.in_features, self.out_features, self.degree),
18 | "LegendreKANLinear": LegendreKANLinear(self.in_features, self.out_features, self.degree),
19 | "MonomialKANLinear": MonomialKANLinear(self.in_features, self.out_features, self.degree),
20 | "FourierKANLinear": FourierKANLinear(self.in_features, self.out_features, self.degree),
21 | "HermiteKANLinear": HermiteKANLinear(self.in_features, self.out_features, self.degree),
22 | }
23 |
24 | def test_model_outputs(self):
25 | for model_name, model in self.models.items():
26 | with self.subTest(model=model_name):
27 | variables = model.init(self.key, self.x)
28 | output = model.apply(variables, self.x)
29 | self.assertEqual(output.shape, (10, self.out_features))
30 | self.assertTrue(jnp.all(jnp.isfinite(output)))
31 |
32 |
33 | if __name__ == "__main__":
34 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from nanodl import *
7 |
8 |
9 | class TestTextBasedModels(unittest.TestCase):
10 | def setUp(self):
11 | self.batch_size = 8
12 | self.max_length = 51
13 | self.vocab_size = 1000
14 | self.embed_dim = 256
15 |
16 | self.data = jnp.arange(
17 | self.batch_size * self.max_length, dtype=jnp.int32
18 | ).reshape((self.batch_size, self.max_length))
19 |
20 | self.dummy_inputs = self.data[:, :-1]
21 | self.dummy_targets = self.data[:, 1:]
22 |
23 | self.hyperparams = {
24 | "num_layers": 1,
25 | "hidden_dim": self.embed_dim,
26 | "num_heads": 2,
27 | "feedforward_dim": self.embed_dim,
28 | "dropout": 0.1,
29 | "vocab_size": self.vocab_size,
30 | "embed_dim": self.embed_dim,
31 | "max_length": self.max_length,
32 | "start_token": 0,
33 | "end_token": 50,
34 | }
35 |
36 | def test_t5_model(self):
37 | model = T5(**self.hyperparams)
38 | self._test_encoder_decoder_model(model)
39 |
40 | def test_transformer_model(self):
41 | model = Transformer(**self.hyperparams)
42 | self._test_encoder_decoder_model(model)
43 |
44 | def test_lamda_model(self):
45 | model = LaMDA(**self.hyperparams)
46 | self._test_decoder_only_model(model)
47 |
48 | def test_gpt3_model(self):
49 | model = GPT4(**self.hyperparams)
50 | self._test_decoder_only_model(model)
51 |
52 |     def test_gpt4_model(self):
53 | model = GPT4(**self.hyperparams)
54 | self._test_decoder_only_model(model)
55 |
56 | def test_mistral_model(self):
57 | model = Mistral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2)
58 | self._test_decoder_only_model(model)
59 |
60 | def test_mixtral_model(self):
61 | model = Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2)
62 | self._test_decoder_only_model(model)
63 |
64 | def test_llama_model(self):
65 | model = Llama3(**self.hyperparams, num_groups=2)
66 | self._test_decoder_only_model(model)
67 |
68 | def test_gemma_model(self):
69 | model = Gemma(**self.hyperparams, num_groups=2)
70 | self._test_decoder_only_model(model)
71 |
72 | def _test_encoder_decoder_model(self, model):
73 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)}
74 | params = model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"]
75 |
76 | outputs = model.apply(
77 | {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs
78 | )
79 |
80 | self.assertEqual(
81 | outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size)
82 | )
83 |
84 | def _test_decoder_only_model(self, model):
85 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)}
86 | params = model.init(rngs, self.dummy_inputs)["params"]
87 |
88 | outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs)
89 |
90 | self.assertEqual(
91 | outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size)
92 | )
93 |
94 | def test_reward_model(self):
95 | model = RewardModel(
96 | Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2),
97 | dim=self.hyperparams["hidden_dim"],
98 | dropout=0.1,
99 | )
100 | rngs = jax.random.PRNGKey(0)
101 | rngs, dropout_rng = jax.random.split(rngs)
102 | params = model.init(
103 | {"params": rngs, "dropout": dropout_rng}, self.dummy_inputs
104 | )["params"]
105 | rewards = model.apply(
106 | {"params": params}, self.dummy_inputs, rngs={"dropout": dropout_rng}
107 | )
108 | assert rewards.shape == (self.batch_size,)
109 |
110 |
111 | class TestVisionBasedModels(unittest.TestCase):
112 | def setUp(self):
113 | self.batch_size = 8
114 | self.n_outputs = 5
115 | self.embed_dim = 256
116 | self.patch_size = (16, 16)
117 | self.dummy_inputs = jnp.ones((self.batch_size, 224, 224, 3))
118 | key = jax.random.PRNGKey(10)
119 |
120 | self.dummy_labels = jax.random.randint(
121 | key, shape=(self.batch_size,), minval=0, maxval=self.n_outputs - 1
122 | )
123 |
124 | self.hyperparams = {
125 | "dropout": 0.1,
126 | "num_heads": 2,
127 | "feedforward_dim": self.embed_dim,
128 | "patch_size": self.patch_size,
129 | "hidden_dim": self.embed_dim,
130 | "num_layers": 4,
131 | "n_outputs": self.n_outputs,
132 | }
133 |
134 | def test_vit_model(self):
135 | model = ViT(**self.hyperparams)
136 | self._test_model(model)
137 |
138 | def test_mixer_model(self):
139 | model = Mixer(**self.hyperparams)
140 | self._test_model(model)
141 |
142 | def _test_model(self, model):
143 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)}
144 |
145 | params = model.init(rngs, self.dummy_inputs)["params"]
146 |
147 | outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs)[0]
148 |
149 | self.assertEqual(outputs.shape, (self.batch_size, self.n_outputs))
150 |
151 |
152 | class TestCLIPModel(unittest.TestCase):
153 | def setUp(self):
154 | self.batch_size = 8
155 | self.max_length = 50
156 | self.vocab_size = 1000
157 | self.embed_dim = 256
158 | self.dummy_texts = jnp.ones((self.batch_size, self.max_length), dtype=jnp.int32)
159 | self.dummy_images = jnp.ones((self.batch_size, 224, 224, 3))
160 |
161 | self.clip_params = {
162 | "dropout": 0.1,
163 | "num_heads": 8,
164 | "feedforward_dim": self.embed_dim,
165 | "num_layers_text": 4,
166 | "hidden_dim_text": self.embed_dim,
167 | "image_patch_size": (16, 16),
168 | "hidden_dim_image": self.embed_dim,
169 | "num_layers_images": 4,
170 | "max_len": self.max_length,
171 | "vocab_size": self.vocab_size,
172 | "embed_dim": self.embed_dim,
173 | }
174 |
175 | self.model = CLIP(**self.clip_params)
176 |
177 | def test_clip_model_initialization_and_processing(self):
178 | rng = jax.random.PRNGKey(0)
179 | params = self.model.init(rng, self.dummy_texts, self.dummy_images)["params"]
180 |
181 | loss = self.model.apply({"params": params}, self.dummy_texts, self.dummy_images)
182 |
183 | self.assertIsNotNone(loss)
184 |
185 |
186 | class TestWhisperModel(unittest.TestCase):
187 | def setUp(self):
188 | self.batch_size = 8
189 | self.max_length = 50
190 | self.embed_dim = 256
191 | self.vocab_size = 1000
192 |
193 | self.dummy_targets = jnp.arange(
194 | self.batch_size * self.max_length, dtype=jnp.int32
195 | ).reshape((self.batch_size, self.max_length))
196 |
197 | self.dummy_inputs = jnp.ones((self.batch_size, self.max_length, self.embed_dim))
198 |
199 | self.hyperparams = {
200 | "num_layers": 1,
201 | "hidden_dim": self.embed_dim,
202 | "num_heads": 2,
203 | "feedforward_dim": self.embed_dim,
204 | "dropout": 0.1,
205 | "vocab_size": self.vocab_size,
206 | "embed_dim": self.embed_dim,
207 | "max_length": self.max_length,
208 | "start_token": 0,
209 | "end_token": 50,
210 | }
211 |
212 | self.model = Whisper(**self.hyperparams)
213 |
214 | def test_whisper_model_initialization_and_processing(self):
215 | rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)}
216 |
217 | params = self.model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"]
218 |
219 | outputs = self.model.apply(
220 | {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs
221 | )
222 |
223 | self.assertEqual(
224 | outputs.shape, (self.batch_size, self.max_length, self.vocab_size)
225 | )
226 |
227 |
228 | class TestDiffusionModel(unittest.TestCase):
229 | def setUp(self):
230 | self.image_size = 32
231 | self.widths = [32, 64, 128]
232 | self.block_depth = 2
233 | self.input_shape = (3, self.image_size, self.image_size, 3)
234 | self.images = jax.random.normal(jax.random.PRNGKey(0), self.input_shape)
235 |
236 | self.model = DiffusionModel(self.image_size, self.widths, self.block_depth)
237 |
238 | def test_diffusion_model_initialization_and_processing(self):
239 | params = self.model.init(jax.random.PRNGKey(0), self.images)
240 | pred_noises, pred_images = self.model.apply(params, self.images)
241 | self.assertEqual(pred_noises.shape, self.input_shape)
242 | self.assertEqual(pred_images.shape, self.input_shape)
243 |
244 |
245 | class TestGATModel(unittest.TestCase):
246 | def setUp(self):
247 | self.num_nodes = 10
248 | self.num_features = 5
249 | self.nclass = 3
250 |
251 | self.x = jax.random.normal(
252 | jax.random.PRNGKey(0), (self.num_nodes, self.num_features)
253 | )
254 |
255 | self.adj = jax.random.bernoulli(
256 | jax.random.PRNGKey(0), 0.3, (self.num_nodes, self.num_nodes)
257 | )
258 |
259 | self.model = GAT(
260 | nfeat=self.num_features,
261 | nhid=8,
262 | nclass=self.nclass,
263 | dropout_rate=0.5,
264 | alpha=0.2,
265 | nheads=3,
266 | )
267 |
268 | def test_gat_model_initialization_and_processing(self):
269 | params = self.model.init(jax.random.key(0), self.x, self.adj, training=False)
270 |
271 | output = self.model.apply(params, self.x, self.adj, training=False)
272 |
273 | self.assertEqual(output.shape, (self.num_nodes, self.nclass))
274 |
275 |
276 | class TestIJEPAModel(unittest.TestCase):
277 | def setUp(self):
278 | self.image_size = 128
279 | self.num_channels = 3
280 | self.patch_size = 16
281 | self.embed_dim = 32
282 | self.predictor_bottleneck = 16
283 | self.num_heads = 4
284 | self.predictor_num_heads = 4
285 | self.num_layers = 2
286 | self.predictor_num_layers = 1
287 | self.dropout_p = 0
288 |         self.num_patches = (self.image_size**2) // (self.patch_size**2)
289 |
290 | self.x = jax.random.normal(
291 | jax.random.PRNGKey(0),
292 | (1, self.image_size, self.image_size, self.num_channels),
293 | )
294 |
295 | self.model = IJEPA(
296 | image_size=self.image_size,
297 | num_channels=self.num_channels,
298 | patch_size=self.patch_size,
299 | embed_dim=self.embed_dim,
300 | predictor_bottleneck=self.predictor_bottleneck,
301 | num_heads=self.num_heads,
302 | predictor_num_heads=self.predictor_num_heads,
303 | num_layers=self.num_layers,
304 | predictor_num_layers=self.predictor_num_layers,
305 | dropout_p=self.dropout_p,
306 | )
307 |
308 | self.data_sampler = IJEPADataSampler(
309 | image_size=self.image_size, M=4, patch_size=self.patch_size
310 | )
311 |
312 | def test_ijepa_data_sampling(self):
313 | context_mask, target_mask = self.data_sampler()
314 | self.assertEqual(context_mask.shape, (4, self.num_patches))
315 | self.assertEqual(target_mask.shape, (4, self.num_patches))
316 |
317 | def test_ijepa_model_initialization_and_processing(self):
318 | context_mask, target_mask = self.data_sampler()
319 |
320 | params = self.model.init(
321 | jax.random.key(0),
322 | self.x,
323 | context_mask[jnp.newaxis],
324 | target_mask[jnp.newaxis],
325 | training=False,
326 | )
327 |
328 | outputs, _ = self.model.apply(
329 | params,
330 | self.x,
331 | context_mask[jnp.newaxis],
332 | target_mask[jnp.newaxis],
333 | training=False,
334 | )
335 |
336 | self.assertEqual(len(outputs), 4)
337 | self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim))
338 | self.assertEqual(outputs[0][0].shape, outputs[0][1].shape)
339 |
340 |
341 | if __name__ == "__main__":
342 | unittest.main()
343 |
--------------------------------------------------------------------------------
/tests/test_random.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 |
5 | from nanodl import (
6 | bernoulli,
7 | bits,
8 | categorical,
9 | chisquare,
10 | choice,
11 | exponential,
12 | gamma,
13 | geometric,
14 | gumbel,
15 | normal,
16 | permutation,
17 | poisson,
18 | randint,
19 | time_rng_key,
20 | triangular,
21 | truncated_normal,
22 | uniform,
23 | )
24 |
25 |
26 | class TestRandomFunctions(unittest.TestCase):
27 |
28 | def test_time_rng_key(self):
29 | key1 = time_rng_key(seed=42)
30 | key2 = time_rng_key(seed=42)
31 | self.assertTrue(
32 | jnp.array_equal(key1, key2), "Keys should be equal for the same seed"
33 | )
34 |
35 | def test_uniform(self):
36 | result = uniform((2, 3))
37 | self.assertEqual(result.shape, (2, 3))
38 | self.assertEqual(result.dtype, jnp.float32)
39 |
40 | def test_normal(self):
41 | result = normal((4, 5), seed=42)
42 | self.assertEqual(result.shape, (4, 5))
43 | self.assertEqual(result.dtype, jnp.float32)
44 |
45 | def test_bernoulli(self):
46 | result = bernoulli(0.5, (10,), seed=42)
47 | self.assertEqual(result.shape, (10,))
48 | self.assertEqual(result.dtype, jnp.bool_)
49 |
50 | def test_categorical(self):
51 | logits = jnp.array([0.1, 0.2, 0.7])
52 | result = categorical(logits, shape=(5,), seed=42)
53 | self.assertEqual(result.shape, (5,))
54 |
55 | def test_randint(self):
56 | result = randint((3, 3), 0, 10, seed=42)
57 | self.assertEqual(result.shape, (3, 3))
58 | self.assertEqual(result.dtype, jnp.int32)
59 |
60 | def test_permutation(self):
61 | arr = jnp.arange(10)
62 | result = permutation(arr, seed=42)
63 | self.assertEqual(result.shape, arr.shape)
64 | self.assertNotEqual(jnp.all(result == arr), True)
65 |
66 | def test_gumbel(self):
67 | result = gumbel((2, 2), seed=42)
68 | self.assertEqual(result.shape, (2, 2))
69 | self.assertEqual(result.dtype, jnp.float32)
70 |
71 | def test_choice(self):
72 | result = choice(5, shape=(3,), seed=42)
73 | self.assertEqual(result.shape, (3,))
74 |
75 | def test_bits(self):
76 | result = bits((2, 2), seed=42)
77 | self.assertEqual(result.shape, (2, 2))
78 | self.assertEqual(result.dtype, jnp.uint32)
79 |
80 | def test_exponential(self):
81 | result = exponential((2, 2), seed=42)
82 | self.assertEqual(result.shape, (2, 2))
83 | self.assertEqual(result.dtype, jnp.float32)
84 |
85 | def test_triangular(self):
86 | result = triangular(0, 1, 0.5, (2, 2), seed=42)
87 | self.assertEqual(result.shape, (2, 2))
88 | self.assertEqual(result.dtype, jnp.float32)
89 |
90 | def test_truncated_normal(self):
91 | result = truncated_normal(0, 1, (2, 2), seed=42)
92 | self.assertEqual(result.shape, (2, 2))
93 | self.assertEqual(result.dtype, jnp.float32)
94 |
95 | def test_poisson(self):
96 | result = poisson(3, (2, 2), seed=42)
97 | self.assertEqual(result.shape, (2, 2))
98 | self.assertEqual(result.dtype, jnp.int32)
99 |
100 | def test_geometric(self):
101 | result = geometric(0.5, (2, 2), seed=42)
102 | self.assertEqual(result.shape, (2, 2))
103 | self.assertEqual(result.dtype, jnp.int32)
104 |
105 | def test_gamma(self):
106 | result = gamma(2, (2, 2), seed=42)
107 | self.assertEqual(result.shape, (2, 2))
108 | self.assertEqual(result.dtype, jnp.float32)
109 |
110 | def test_chisquare(self):
111 | result = chisquare(2, (2, 2), seed=42)
112 | self.assertEqual(result.shape, (2, 2))
113 | self.assertEqual(result.dtype, jnp.float32)
114 |
115 |
116 | if __name__ == "__main__":
117 | unittest.main()
118 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from nanodl import *
7 |
8 |
9 | class TestDataset(unittest.TestCase):
10 | def test_dataset_length(self):
11 | class DummyDataset(Dataset):
12 | def __init__(self, data):
13 | self.data = data
14 |
15 | def __len__(self):
16 | return len(self.data)
17 |
18 | def __getitem__(self, index):
19 | return self.data[index]
20 |
21 | dataset = DummyDataset(jnp.arange(10))
22 | self.assertEqual(len(dataset), 10)
23 |
24 | def test_dataset_getitem(self):
25 | class DummyDataset(Dataset):
26 | def __init__(self, data):
27 | self.data = data
28 |
29 | def __len__(self):
30 | return len(self.data)
31 |
32 | def __getitem__(self, index):
33 | return self.data[index]
34 |
35 | dataset = DummyDataset(jnp.arange(10))
36 | item = dataset[5]
37 | self.assertEqual(item, 5)
38 |
39 |
40 | class TestArrayDataset(unittest.TestCase):
41 | def test_array_dataset_length(self):
42 | dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
43 | self.assertEqual(len(dataset), 3)
44 |
45 | def test_array_dataset_getitem(self):
46 | dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
47 | item = dataset[1]
48 | self.assertEqual(item, (2, 5))
49 |
50 |
51 | class TestDataLoader(unittest.TestCase):
52 | def test_data_loader_length(self):
53 | dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256)))
54 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False)
55 | self.assertEqual(len(dataloader), 101)
56 |
57 | def test_data_loader_iteration(self):
58 | dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256)))
59 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=True)
60 | for a, b in dataloader:
61 | self.assertEqual(a.shape, (10, 256, 256))
62 | self.assertEqual(b.shape, (10, 256, 256))
63 |
64 |
65 | class TestMLFunctions(unittest.TestCase):
66 | def test_batch_cosine_similarities(self):
67 | source = jnp.array([1, 0, 0])
68 | candidates = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
69 | similarities = batch_cosine_similarities(source, candidates)
70 | expected_results = jnp.array([1.0, 0.0, 0.0])
71 | self.assertTrue(jnp.allclose(similarities, expected_results))
72 |
73 | def test_batch_pearsonr(self):
74 | x = jnp.array([[1, 2, 3], [4, 5, 6]])
75 | y = jnp.array([[6, 5, 4], [2, 6, 8]])
76 | correlations = batch_pearsonr(x, y)
77 | expected_results = jnp.array([-1.0, 1.0, 1.0])
78 | self.assertTrue(jnp.allclose(correlations, expected_results))
79 |
80 | def test_classification_scores(self):
81 | labels = jnp.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
82 | preds = jnp.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0])
83 | scores = classification_scores(labels, preds)
84 | expected_results = jnp.array([0.8, 0.8, 0.8, 0.8000001])
85 | self.assertTrue(jnp.allclose(scores, expected_results))
86 |
87 | def test_mean_reciprocal_rank(self):
88 | predictions = jnp.array(
89 | [
90 | [0, 1, 2], # "correct" prediction at index 0
91 | [1, 0, 2], # "correct" prediction at index 1
92 | [2, 1, 0], # "correct" prediction at index 2
93 | ]
94 | )
95 | mrr_score = mean_reciprocal_rank(predictions)
96 | self.assertAlmostEqual(mrr_score, 0.61111116)
97 |
98 | def test_jaccard(self):
99 | sequence1 = [1, 2, 3]
100 | sequence2 = [2, 3, 4]
101 | similarity = jaccard(sequence1, sequence2)
102 | self.assertAlmostEqual(similarity, 0.5)
103 |
104 | def test_hamming(self):
105 | sequence1 = jnp.array([1, 2, 3, 4])
106 | sequence2 = jnp.array([1, 2, 4, 4])
107 | similarity = hamming(sequence1, sequence2)
108 | self.assertEqual(similarity, 3)
109 |
110 | def test_zero_pad_sequences(self):
111 | arr = jnp.array([[1, 2, 3], [4, 5, 6]])
112 | max_length = 5
113 | padded_arr = zero_pad_sequences(arr, max_length)
114 | expected_padded_arr = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 0, 0]])
115 | self.assertTrue(jnp.array_equal(padded_arr, expected_padded_arr))
116 |
117 | def test_entropy(self):
118 | probabilities = jnp.array([0.25, 0.75])
119 | entropy_value = entropy(probabilities)
120 | self.assertAlmostEqual(entropy_value, 0.8112781)
121 |
122 | def test_gini_impurity(self):
123 | probabilities = jnp.array([0.25, 0.75])
124 | gini_value = gini_impurity(probabilities)
125 | self.assertAlmostEqual(gini_value, 0.375)
126 |
127 | def test_kl_divergence(self):
128 | p = jnp.array([0.25, 0.75])
129 | q = jnp.array([0.5, 0.5])
130 | kl_value = kl_divergence(p, q)
131 | self.assertAlmostEqual(kl_value, 0.18872187)
132 |
133 | def test_count_parameters(self):
134 | class MyModel:
135 | def __init__(self):
136 | self.layer1 = jnp.ones((10, 20))
137 | self.layer2 = jnp.ones((5, 5))
138 |
139 | model = MyModel()
140 | params = model.__dict__
141 | total_params = count_parameters(params)
142 | self.assertEqual(total_params, 225)
143 |
144 |
145 | class TestNLPFunctions(unittest.TestCase):
146 | def setUp(self):
147 | self.hypotheses = [
148 | "the cat is on the mat",
149 | "there is a cat on the mat",
150 | ]
151 | self.references = [
152 | "the cat is on the mat",
153 | "the cat sits on the mat",
154 | ]
155 |
156 | def test_rouge(self):
157 | rouge_scores = rouge(self.hypotheses, self.references, [1, 2])
158 | expected_scores = {
159 | "ROUGE-1": {
160 | "precision": 0.7857142857142857,
161 | "recall": 0.9,
162 | "f1": 0.8333333333328402,
163 | },
164 | "ROUGE-2": {
165 | "precision": 0.6666666666666666,
166 | "recall": 0.7,
167 | "f1": 0.6818181818176838,
168 | },
169 | }
170 |
171 | def assert_nested_dicts_equal(dict1, dict2):
172 | for key in dict1.keys():
173 | if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
174 | assert_nested_dicts_equal(dict1[key], dict2[key])
175 | elif dict1[key] != dict2[key]:
176 | raise AssertionError(
177 | f"Values for key '{key}' are not equal: {dict1[key]} != {dict2[key]}"
178 | )
179 |
180 | assert_nested_dicts_equal(rouge_scores, expected_scores)
181 |
182 | def test_bleu(self):
183 | bleu_score = bleu(self.hypotheses, self.references)
184 | self.assertAlmostEqual(bleu_score, 0.03737737833658239, places=2)
185 |
186 | def test_meteor(self):
187 | meteor_score = meteor(self.hypotheses[0], self.references[0])
188 | self.assertAlmostEqual(meteor_score, 1.0, places=3) # Perfect match
189 | meteor_score = meteor(self.hypotheses[1], self.references[1])
190 | self.assertAlmostEqual(meteor_score, 0.4981684981684981, places=3)
191 |
192 | def test_cider_score(self):
193 | score = cider_score(self.hypotheses[0], self.references[0])
194 | self.assertAlmostEqual(score, 1.0, places=2) # Perfect match
195 | score = cider_score(self.hypotheses[1], self.references[1])
196 | self.assertAlmostEqual(score, 0.31595617188837527, places=2)
197 |
198 | def test_perplexity(self):
199 | log_probs = [-2.3, -1.7, -0.4]
200 | perplexity_score = perplexity(log_probs)
201 | self.assertAlmostEqual(perplexity_score, 4.0, places=3)
202 |
203 | def test_word_error_rate(self):
204 | wer_score = word_error_rate(self.hypotheses, self.references)
205 | self.assertAlmostEqual(wer_score, 0.3333333333333333, places=2)
206 |
207 |
208 | class TestVisionFunctions(unittest.TestCase):
209 | def test_normalize_images(self):
210 | images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]])
211 | normalized_images = normalize_images(images)
212 | self.assertAlmostEqual(normalized_images.mean(), 0.0, places=3)
213 | self.assertAlmostEqual(normalized_images.std(), 1.0, places=3)
214 |
215 | def test_random_crop(self):
216 | images = jnp.ones((10, 100, 100, 3))
217 | crop_size = 64
218 | cropped_images = random_crop(images, crop_size)
219 | self.assertEqual(cropped_images.shape, (10, crop_size, crop_size, 3))
220 |
221 | def test_gaussian_blur(self):
222 | image = jnp.ones((5, 5, 3))
223 | blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0)
224 | self.assertEqual(blurred_image.shape, (5, 5, 3))
225 |
226 | def test_sobel_edge_detection(self):
227 | image = jnp.ones((5, 5, 3))
228 | edges = sobel_edge_detection(image)
229 | self.assertEqual(edges.shape, (5, 5))
230 |
231 | def test_adjust_brightness(self):
232 | image = jnp.ones((5, 5, 3))
233 | adjusted_image = adjust_brightness(image, factor=1.5)
234 | self.assertEqual(adjusted_image.shape, (5, 5, 3))
235 |
236 | def test_adjust_contrast(self):
237 | image = jnp.ones((5, 5, 3))
238 | adjusted_image = adjust_contrast(image, factor=1.5)
239 | self.assertEqual(adjusted_image.shape, (5, 5, 3))
240 |
241 | def test_flip_image(self):
242 | image = jnp.ones((5, 5, 3))
243 | flipped_image_horizontally = flip_image(image, jnp.array([True]))
244 | flipped_image_vertically = flip_image(image, jnp.array([False]))
245 | self.assertEqual(flipped_image_horizontally.shape, (5, 5, 3))
246 | self.assertEqual(flipped_image_vertically.shape, (5, 5, 3))
247 |
248 | def test_random_flip_image(self):
249 | key = jax.random.PRNGKey(0)
250 | image = jnp.ones((5, 5, 3))
251 | flipped_image = random_flip_image(image, key, jnp.array([True]))
252 | self.assertEqual(flipped_image.shape, (5, 5, 3))
253 |
254 |
255 | if __name__ == "__main__":
256 | unittest.main()
257 |
--------------------------------------------------------------------------------
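For reference, a minimal sketch of running one of the test classes above programmatically with the standard-library `unittest` loader (equivalent to `python -m unittest tests.test_utils.TestVisionFunctions`); this is generic unittest usage, not a repository-specific entry point.

```python
import unittest

# Load and run only the vision-utility tests defined in tests/test_utils.py.
suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_utils.TestVisionFunctions")
unittest.TextTestRunner(verbosity=2).run(suite)
```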