├── .github └── workflows │ └── rust.yml ├── .gitignore ├── CODE-OF-CONDUCT.md ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── _typos.toml ├── bert-burn ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── examples │ ├── infer-embedding.rs │ └── masked.rs └── src │ ├── data │ ├── batcher.rs │ ├── mod.rs │ └── tokenizer.rs │ ├── embedding.rs │ ├── fill_mask.rs │ ├── lib.rs │ ├── loader.rs │ ├── model.rs │ └── pooler.rs ├── llama-burn ├── Cargo.toml ├── NOTICES.md ├── README.md ├── assets │ └── llama-burn.jpeg ├── examples │ └── chat.rs └── src │ ├── cache.rs │ ├── lib.rs │ ├── llama.rs │ ├── pretrained.rs │ ├── sampling.rs │ ├── tokenizer │ ├── base.rs │ ├── mod.rs │ ├── sentence_piece.rs │ └── tiktoken.rs │ └── transformer.rs ├── mobilenetv2-burn ├── Cargo.toml ├── NOTICES.md ├── README.md ├── examples │ └── inference.rs ├── samples │ └── dog.jpg └── src │ ├── lib.rs │ └── model │ ├── conv_norm.rs │ ├── imagenet.rs │ ├── inverted_residual.rs │ ├── mobilenetv2.rs │ ├── mod.rs │ └── weights.rs ├── resnet-burn ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── NOTICES.md ├── README.md ├── examples │ ├── finetune │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── examples │ │ │ └── finetune.rs │ │ └── src │ │ │ ├── data.rs │ │ │ ├── dataset.rs │ │ │ ├── inference.rs │ │ │ ├── lib.rs │ │ │ └── training.rs │ └── inference │ │ ├── Cargo.toml │ │ ├── examples │ │ └── inference.rs │ │ └── src │ │ ├── imagenet.rs │ │ └── lib.rs ├── resnet │ ├── Cargo.toml │ └── src │ │ ├── block.rs │ │ ├── lib.rs │ │ ├── resnet.rs │ │ └── weights.rs └── samples │ ├── dataset.jpg │ └── dog.jpg ├── squeezenet-burn ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── NOTICES.md ├── README.md ├── build.rs ├── examples │ └── classify.rs ├── samples │ ├── bridge.jpg │ ├── cat.jpg │ ├── coyote.jpg │ ├── flamingo.jpg │ ├── pelican.jpg │ ├── table-lamp.jpg │ └── torch.jpg └── src │ ├── lib.rs │ └── model │ ├── label.rs │ ├── label.txt │ ├── mod.rs │ ├── normalizer.rs │ ├── squeezenet1.onnx │ └── squeezenet1.rs └── yolox-burn ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── NOTICES.md ├── README.md ├── examples └── inference.rs ├── samples └── dog_bike_man.jpg └── src ├── lib.rs └── model ├── blocks.rs ├── bottleneck.rs ├── boxes.rs ├── darknet.rs ├── head.rs ├── mod.rs ├── pafpn.rs ├── weights.rs └── yolox.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | strategy: 17 | matrix: 18 | rust: [stable, 1.85.0] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - name: install rust 24 | uses: dtolnay/rust-toolchain@master 25 | with: 26 | components: rustfmt, clippy 27 | toolchain: ${{ matrix.rust }} 28 | 29 | - name: Check Formatting 30 | run: | 31 | cd squeezenet-burn 32 | cargo fmt --all -- --check 33 | 34 | - name: Run Clippy 35 | run: | 36 | cd squeezenet-burn 37 | cargo clippy -- -D warnings 38 | 39 | - name: Build 40 | run: | 41 | cd squeezenet-burn 42 | cargo build --verbose 43 | 44 | - name: Run tests 45 | run: | 46 | cd squeezenet-burn 47 | cargo test --verbose 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and 
executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | .DS_Store 16 | 17 | # direnv files 18 | .envrc 19 | 20 | # Editor files 21 | .vscode 22 | -------------------------------------------------------------------------------- /CODE-OF-CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | nathaniel.simard.42@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Burn-rs/Models Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔥 Models 🔥 2 | 3 | Welcome to the Models repository! Here, you'll find a diverse collection of deep learning models and 4 | examples constructed using the [Burn](https://github.com/burn-rs/burn) deep learning framework. 5 | 6 | ## Collection of Official Models 7 | 8 | | Model | Description | Repository Link | 9 | |-------------------------------------------------|----------------------------------------------------------|---------------------------------------| 10 | | [Llama](https://github.com/meta-llama/llama3) | Llama 3 and TinyLlama large language models. | [llama-burn](llama-burn/) | 11 | | [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/) | 12 | | [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/) | 13 | | [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/) | 14 | | [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/) | 15 | | [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/) | 16 | 17 | ## Community Contributions 18 | 19 | Explore the curated list of models developed by the community ♥. 20 | 21 | | Model | Description | Repository Link | 22 | |--------------------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------------------------------------| 23 | | [Llama 2](https://arxiv.org/abs/2307.09288) | LLMs by Meta AI, ranging from 7 billion to 70 billion parameters. | [Gadersd/llama2-burn](https://github.com/Gadersd/llama2-burn) | 24 | | [Whisper](https://arxiv.org/abs/2212.04356) | A general-purpose speech recognition model by OpenAI. 
| [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) | 25 | | Stable Diffusion v1.4 | An image generation model developed by Stability AI. | [Gadersd/stable-diffusion-burn](https://github.com/Gadersd/stable-diffusion-burn) | 26 | | kord (music note predictor) | A music theory model that can detect notes in short audio clips. | [twitchax/kord](https://github.com/twitchax/kord) | 27 | | Whisper-Live | A fork of [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) which has been updated for Burn 13 and provides live transcription | [sudomonikers/whisper-burn](https://github.com/sudomonikers/whisper-burn) | 28 | | [Inception V3](https://arxiv.org/abs/1512.00567) | A CNN used for calculating FID scores. | [varonroy/inception-v3-burn](https://github.com/varonroy/inception-v3-burn/) | 29 | | [CRAFT](https://arxiv.org/abs/1904.01941) | A CNN for character-region aware text detection | [wingertge/craft-burn](https://github.com/wingertge/craft-burn) | 30 | | [RWKV v7](https://arxiv.org/abs/2503.14456) | A large language model architecture that can be used like transformer models (parallel processing of tokens) and like RNNs (sequential generation). | [dymat/rwkv-burn](https://github.com/dymat/rwkv-burn) | 31 | 32 | ## License Information 33 | 34 | Models implemented in this repository are distributed under the terms of both the MIT license and 35 | the Apache License (Version 2.0). See [LICENSE-APACHE](./LICENSE-APACHE) and 36 | [LICENSE-MIT](./LICENSE-MIT) for complete details. 37 | 38 | Please note that opening a pull request signals your agreement with these licensing terms. If you 39 | copy or adapt material from other resources or codebases, ensure that you include the original 40 | license information in the NOTICES.md file under the corresponding model directory. 41 | 42 | Community models linked in this repository may fall under different licenses, so please consult the 43 | respective repositories for specific license information. 
44 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | [default] 2 | extend-ignore-identifiers-re = ["NdArray*", "ND"] 3 | 4 | [files] 5 | extend-exclude = ["squeezenet-burn/src/model/label.txt"] 6 | -------------------------------------------------------------------------------- /bert-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Aasheesh Singh aasheeshdtu@gmail.com"] 3 | license = "MIT OR Apache-2.0" 4 | name = "bert-burn" 5 | version = "0.2.0" 6 | edition = "2021" 7 | 8 | [features] 9 | default = [] 10 | f16 = [] 11 | ndarray = ["burn/ndarray"] 12 | tch-cpu = ["burn/tch"] 13 | tch-gpu = ["burn/tch"] 14 | wgpu = ["burn/wgpu"] 15 | cuda = ["burn/cuda-jit"] 16 | fusion = ["burn/fusion"] 17 | # To be replaced by burn-safetensors once supported: https://github.com/tracel-ai/burn/issues/626 18 | safetensors = ["candle-core/default"] 19 | 20 | 21 | [dependencies] 22 | # Burn 23 | burn = { version = "0.16", default-features = false, features = ["dataset", "std"] } 24 | cubecl-runtime = { version = "0.3.0", features = ["channel-mpsc"] } # missing feature flag when burn default-features are off 25 | candle-core = { version = "0.3" } 26 | # Tokenizer 27 | tokenizers = { version = "0.15.0", default-features = false, features = [ 28 | "onig", 29 | "http", 30 | ] } 31 | derive-new = "0.6.0" 32 | hf-hub = { version = "0.3.2", features = ["tokio"] } 33 | 34 | # Utils 35 | serde = { version = "1.0.196", features = ["std", "derive"] } 36 | libm = "0.2.8" 37 | serde_json = "1.0.113" 38 | tokio = "1.35.1" 39 | -------------------------------------------------------------------------------- /bert-burn/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /bert-burn/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /bert-burn/README.md: -------------------------------------------------------------------------------- 1 | # Bert-Burn Model 2 | 3 | This project provides an example implementation for inference on the BERT family of models. The following compatible 4 | bert-variants: `roberta-base`(**default**)/`roberta-large`, `bert-base-uncased`/`bert-large-uncased`/`bert-base-cased`/`bert-large-cased` 5 | can be loaded as following. The pre-trained weights and config files are automatically downloaded 6 | from: [HuggingFace Model hub](https://huggingface.co/FacebookAI/roberta-base/tree/main) 7 | 8 | ### To include the model in your project 9 | 10 | Add this to your `Cargo.toml`: 11 | 12 | ```toml 13 | [dependencies] 14 | bert-burn = { git = "https://github.com/tracel-ai/models", package = "bert-burn", default-features = false } 15 | ``` 16 | 17 | ## Example Usage 18 | 19 | Example usage for getting sentence embedding from given input text. The model supports multiple backends from burn 20 | (e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`, `cuda`) which can be selected using the `--features` flag. An example with `wgpu` 21 | backend is shown below. The `fusion` flag is used to enable kernel fusion for the `wgpu` backend. It is not required 22 | with other backends. 
with other backends.
The `safetensors` flag is used to support loading weights in `safetensors` format via `candle-core` 23 | crate. 24 | 25 | ### WGPU backend 26 | 27 | ```bash 28 | cd bert-burn/ 29 | # Get sentence embeddings from the RobBERTa encoder (default) 30 | cargo run --example infer-embedding --release --features wgpu,fusion,safetensors 31 | 32 | # Using bert-base-uncased model 33 | cargo run --example infer-embedding --release --features wgpu,fusion,safetensors bert-base-uncased 34 | 35 | # Using roberta-large model 36 | cargo run --example infer-embedding --release --features wgpu,fusion,safetensors roberta-large 37 | ``` 38 | 39 | 40 | -------------------------------------------------------------------------------- /bert-burn/examples/infer-embedding.rs: -------------------------------------------------------------------------------- 1 | use bert_burn::data::{BertInputBatcher, BertTokenizer}; 2 | use bert_burn::loader::{download_hf_model, load_model_config}; 3 | use bert_burn::model::{BertModel, BertModelRecord}; 4 | use burn::data::dataloader::batcher::Batcher; 5 | use burn::module::Module; 6 | use burn::tensor::backend::Backend; 7 | use burn::tensor::Tensor; 8 | use std::env; 9 | use std::sync::Arc; 10 | 11 | #[cfg(not(feature = "f16"))] 12 | #[allow(dead_code)] 13 | type ElemType = f32; 14 | #[cfg(feature = "f16")] 15 | type ElemType = burn::tensor::f16; 16 | 17 | pub fn launch(device: B::Device) { 18 | let args: Vec = env::args().collect(); 19 | let default_model = "roberta-base".to_string(); 20 | let model_variant = if args.len() > 1 { 21 | // Use the argument provided by the user 22 | // Possible values: "bert-base-uncased", "roberta-large" etc. 23 | &args[1] 24 | } else { 25 | // Use the default value if no argument is provided 26 | &default_model 27 | }; 28 | 29 | println!("Model variant: {}", model_variant); 30 | 31 | let text_samples = vec![ 32 | "Jays power up to take finale Contrary to popular belief, the power never really \ 33 | snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \ 34 | took some extra time for the batting orders to provide some extra wattage." 35 | .to_string(), 36 | "Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \ 37 | man to death and 14 others to prison terms for a series of attacks and terrorist \ 38 | plots in 2002, including the bombing of a French oil tanker." 39 | .to_string(), 40 | "IBM puts grids to work at U.S. Open IBM will put a collection of its On \ 41 | Demand-related products and technologies to this test next week at the U.S. Open \ 42 | tennis championships, implementing a grid-based infrastructure capable of running \ 43 | multiple workloads including two not associated with the tournament." 
44 | .to_string(), 45 | ]; 46 | 47 | let (config_file, model_file) = download_hf_model(model_variant); 48 | let model_config = load_model_config(config_file); 49 | 50 | let model_record: BertModelRecord = 51 | BertModel::from_safetensors(model_file, &device, model_config.clone()); 52 | 53 | let model = model_config.init(&device).load_record(model_record); 54 | 55 | let tokenizer = Arc::new(BertTokenizer::new( 56 | model_variant.to_string(), 57 | model_config.pad_token_id, 58 | )); 59 | 60 | // Batch the input samples to max sequence length with padding 61 | let batcher = Arc::new(BertInputBatcher::::new( 62 | tokenizer.clone(), 63 | device.clone(), 64 | model_config.max_seq_len.unwrap(), 65 | )); 66 | 67 | // Batch input samples using the batcher Shape: [Batch size, Seq_len] 68 | let input = batcher.batch(text_samples.clone()); 69 | let [batch_size, _seq_len] = input.tokens.dims(); 70 | println!("Input: {}", input.tokens); 71 | 72 | let output = model.forward(input); 73 | 74 | // get sentence embedding from the first [CLS] token 75 | let cls_token_idx = 0; 76 | 77 | // Embedding size 78 | let d_model = model_config.hidden_size; 79 | let sentence_embedding = output.hidden_states.clone().slice([ 80 | 0..batch_size, 81 | cls_token_idx..cls_token_idx + 1, 82 | 0..d_model, 83 | ]); 84 | 85 | let sentence_embedding: Tensor = sentence_embedding.squeeze(1); 86 | println!( 87 | "Roberta Sentence embedding: {}", 88 | sentence_embedding 89 | ); 90 | } 91 | 92 | #[cfg(feature = "ndarray")] 93 | mod ndarray { 94 | use burn::backend::ndarray::{NdArray, NdArrayDevice}; 95 | 96 | use crate::{launch, ElemType}; 97 | 98 | pub fn run() { 99 | launch::>(NdArrayDevice::Cpu); 100 | } 101 | } 102 | 103 | #[cfg(feature = "tch-gpu")] 104 | mod tch_gpu { 105 | use crate::{launch, ElemType}; 106 | use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 107 | 108 | pub fn run() { 109 | #[cfg(not(target_os = "macos"))] 110 | let device = LibTorchDevice::Cuda(0); 111 | #[cfg(target_os = "macos")] 112 | let device = LibTorchDevice::Mps; 113 | 114 | launch::>(device); 115 | } 116 | } 117 | 118 | #[cfg(feature = "tch-cpu")] 119 | mod tch_cpu { 120 | use crate::{launch, ElemType}; 121 | use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 122 | 123 | pub fn run() { 124 | launch::>(LibTorchDevice::Cpu); 125 | } 126 | } 127 | 128 | #[cfg(feature = "wgpu")] 129 | mod wgpu { 130 | use crate::launch; 131 | use burn::backend::wgpu::{Wgpu, WgpuDevice}; 132 | 133 | pub fn run() { 134 | launch::(WgpuDevice::default()); 135 | } 136 | } 137 | 138 | fn main() { 139 | #[cfg(feature = "ndarray")] 140 | ndarray::run(); 141 | #[cfg(feature = "tch-gpu")] 142 | tch_gpu::run(); 143 | #[cfg(feature = "tch-cpu")] 144 | tch_cpu::run(); 145 | #[cfg(feature = "wgpu")] 146 | wgpu::run(); 147 | } 148 | -------------------------------------------------------------------------------- /bert-burn/examples/masked.rs: -------------------------------------------------------------------------------- 1 | use bert_burn::data::{BertInputBatcher, BertTokenizer}; 2 | use bert_burn::fill_mask::fill_mask; 3 | use bert_burn::loader::{download_hf_model, load_model_config}; 4 | use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord}; 5 | use burn::data::dataloader::batcher::Batcher; 6 | use burn::module::Module; 7 | use burn::tensor::backend::Backend; 8 | use std::env; 9 | use std::sync::Arc; 10 | 11 | #[cfg(not(feature = "f16"))] 12 | #[allow(dead_code)] 13 | type ElemType = f32; 14 | #[cfg(feature = "f16")] 15 | type ElemType = burn::tensor::f16; 16 | 17 | 
pub fn launch(device: B::Device) { 18 | let args: Vec = env::args().collect(); 19 | let default_model = "roberta-base".to_string(); 20 | let model_variant = if args.len() > 1 { 21 | // Use the argument provided by the user 22 | // Possible values: "bert-base-uncased", "roberta-large" etc. 23 | &args[1] 24 | } else { 25 | // Use the default value if no argument is provided 26 | &default_model 27 | }; 28 | 29 | println!("Model variant: {}", model_variant); 30 | 31 | let text_samples = vec![ 32 | "Paris is the of France.".to_string(), 33 | "The goal of life is .".to_string(), 34 | ]; 35 | 36 | let (config_file, model_file) = download_hf_model(model_variant); 37 | let model_config = load_model_config(config_file); 38 | 39 | let model_record: BertMaskedLMRecord = 40 | BertMaskedLM::from_safetensors(model_file, &device, model_config.clone()); 41 | 42 | let model = model_config 43 | .init_with_lm_head(&device) 44 | .load_record(model_record); 45 | 46 | let tokenizer = Arc::new(BertTokenizer::new( 47 | model_variant.to_string(), 48 | model_config.pad_token_id, 49 | )); 50 | 51 | // Batch the input samples to max sequence length with padding 52 | let batcher = Arc::new(BertInputBatcher::::new( 53 | tokenizer.clone(), 54 | device.clone(), 55 | model_config.max_seq_len.unwrap(), 56 | )); 57 | 58 | // Batch input samples using the batcher Shape: [Batch size, Seq_len] 59 | let input = batcher.batch(text_samples.clone()); 60 | let [batch_size, _seq_len] = input.tokens.dims(); 61 | println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape()); 62 | 63 | let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input); 64 | 65 | for i in 0..batch_size { 66 | let input = &text_samples[i]; 67 | let result = &output[i]; 68 | println!("Input: {}", input); 69 | for fill_mask_result in result.iter() { 70 | let mask_idx = fill_mask_result.mask_idx; 71 | let top_k = &fill_mask_result.top_k; 72 | for (k, (score, token)) in top_k.iter().enumerate() { 73 | println!( 74 | "Top {} Prediction for {}: {} (Score: {:.4})", 75 | k + 1, 76 | mask_idx, 77 | token, 78 | score 79 | ); 80 | } 81 | } 82 | } 83 | } 84 | 85 | #[cfg(feature = "ndarray")] 86 | mod ndarray { 87 | use burn::backend::ndarray::{NdArray, NdArrayDevice}; 88 | 89 | use crate::{launch, ElemType}; 90 | 91 | pub fn run() { 92 | launch::>(NdArrayDevice::Cpu); 93 | } 94 | } 95 | 96 | #[cfg(feature = "tch-gpu")] 97 | mod tch_gpu { 98 | use crate::{launch, ElemType}; 99 | use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 100 | 101 | pub fn run() { 102 | #[cfg(not(target_os = "macos"))] 103 | let device = LibTorchDevice::Cuda(0); 104 | #[cfg(target_os = "macos")] 105 | let device = LibTorchDevice::Mps; 106 | 107 | launch::>(device); 108 | } 109 | } 110 | 111 | #[cfg(feature = "tch-cpu")] 112 | mod tch_cpu { 113 | use crate::{launch, ElemType}; 114 | use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 115 | 116 | pub fn run() { 117 | launch::>(LibTorchDevice::Cpu); 118 | } 119 | } 120 | 121 | #[cfg(feature = "wgpu")] 122 | mod wgpu { 123 | use crate::launch; 124 | use burn::backend::wgpu::{Wgpu, WgpuDevice}; 125 | 126 | pub fn run() { 127 | launch::(WgpuDevice::default()); 128 | } 129 | } 130 | 131 | fn main() { 132 | #[cfg(feature = "ndarray")] 133 | ndarray::run(); 134 | #[cfg(feature = "tch-gpu")] 135 | tch_gpu::run(); 136 | #[cfg(feature = "tch-cpu")] 137 | tch_cpu::run(); 138 | #[cfg(feature = "wgpu")] 139 | wgpu::run(); 140 | } 141 | -------------------------------------------------------------------------------- 
/bert-burn/src/data/batcher.rs: -------------------------------------------------------------------------------- 1 | use super::tokenizer::Tokenizer; 2 | use burn::data::dataloader::batcher::Batcher; 3 | use burn::nn::attention::generate_padding_mask; 4 | use burn::tensor::backend::Backend; 5 | use burn::tensor::{Bool, Int, Tensor}; 6 | use std::sync::Arc; 7 | 8 | #[derive(new)] 9 | pub struct BertInputBatcher { 10 | /// Tokenizer for converting input text string to token IDs 11 | tokenizer: Arc, 12 | /// Device on which to perform computation (e.g., CPU or CUDA device) 13 | device: B::Device, 14 | /// Maximum sequence length for tokenized text 15 | max_seq_length: usize, 16 | } 17 | 18 | #[derive(Debug, Clone, new)] 19 | pub struct BertInferenceBatch { 20 | /// Tokenized text as 2D tensor: [batch_size, max_seq_length] 21 | pub tokens: Tensor, 22 | /// Padding mask for the tokenized text containing booleans for padding locations 23 | pub mask_pad: Tensor, 24 | } 25 | 26 | impl Batcher> for BertInputBatcher { 27 | /// Batches a vector of strings into an inference batch 28 | fn batch(&self, items: Vec) -> BertInferenceBatch { 29 | let mut tokens_list = Vec::with_capacity(items.len()); 30 | 31 | // Tokenize each string 32 | for item in items { 33 | tokens_list.push(self.tokenizer.encode(&item)); 34 | } 35 | 36 | // Generate padding mask for tokenized text 37 | let mask = generate_padding_mask( 38 | self.tokenizer.pad_token(), 39 | tokens_list, 40 | Some(self.max_seq_length), 41 | &self.device, 42 | ); 43 | 44 | // Create and return inference batch 45 | BertInferenceBatch { 46 | tokens: mask.tensor, 47 | mask_pad: mask.mask, 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /bert-burn/src/data/mod.rs: -------------------------------------------------------------------------------- 1 | mod batcher; 2 | mod tokenizer; 3 | 4 | pub use batcher::*; 5 | pub use tokenizer::*; 6 | -------------------------------------------------------------------------------- /bert-burn/src/data/tokenizer.rs: -------------------------------------------------------------------------------- 1 | pub trait Tokenizer: Send + Sync { 2 | /// Converts a text string into a sequence of tokens. 3 | fn encode(&self, value: &str) -> Vec; 4 | 5 | /// Converts a sequence of tokens back into a text string. 6 | fn decode(&self, tokens: &[usize]) -> String; 7 | 8 | /// Gets the size of the tokenizer's vocabulary. 9 | fn vocab_size(&self) -> usize; 10 | 11 | /// Gets the token used for padding sequences to a consistent length. 12 | fn pad_token(&self) -> usize; 13 | 14 | /// Gets the string representation of the padding token. 15 | /// The default implementation uses `decode` on the padding token. 16 | fn pad_token_value(&self) -> String { 17 | self.decode(&[self.pad_token()]) 18 | } 19 | } 20 | 21 | /// Struct represents a specific tokenizer using the Roberta BPE tokenization strategy. 22 | pub struct BertTokenizer { 23 | // The underlying tokenizer from the `tokenizers` library. 24 | tokenizer: tokenizers::Tokenizer, 25 | pad_token: usize, 26 | } 27 | 28 | // Default implementation for creating a new BertTokenizer. 29 | // Downloads tokenizer from given model_name (eg: "roberta-base"). 30 | // Pad_token_id is the id of the padding token used to convert sequences to a consistent length. 31 | // specified in the model's config.json. 
32 | impl BertTokenizer { 33 | pub fn new(model_name: String, pad_token_id: usize) -> Self { 34 | Self { 35 | tokenizer: tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap(), 36 | pad_token: pad_token_id, 37 | } 38 | } 39 | } 40 | 41 | // Implementation of the Tokenizer trait for BertTokenizer. 42 | impl Tokenizer for BertTokenizer { 43 | /// Convert a text string into a sequence of tokens using the BERT model's tokenization strategy. 44 | fn encode(&self, value: &str) -> Vec { 45 | let tokens = self.tokenizer.encode(value, true).unwrap(); 46 | tokens.get_ids().iter().map(|t| *t as usize).collect() 47 | } 48 | 49 | /// Converts a sequence of tokens back into a text string. 50 | fn decode(&self, tokens: &[usize]) -> String { 51 | let tokens = tokens.iter().map(|t| *t as u32).collect::>(); 52 | self.tokenizer.decode(&tokens, false).unwrap() 53 | } 54 | 55 | /// Gets the size of the BERT tokenizer's vocabulary. 56 | fn vocab_size(&self) -> usize { 57 | self.tokenizer.get_vocab_size(true) 58 | } 59 | 60 | /// Gets the token used for padding sequences to a consistent length. 61 | fn pad_token(&self) -> usize { 62 | self.pad_token 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /bert-burn/src/embedding.rs: -------------------------------------------------------------------------------- 1 | use crate::data::BertInferenceBatch; 2 | use burn::config::Config; 3 | use burn::module::Module; 4 | use burn::nn::{Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig}; 5 | use burn::tensor::backend::Backend; 6 | use burn::tensor::{Float, Int, Tensor}; 7 | 8 | #[derive(Config)] 9 | pub struct BertEmbeddingsConfig { 10 | pub vocab_size: usize, 11 | pub max_position_embeddings: usize, 12 | pub type_vocab_size: usize, 13 | pub hidden_size: usize, 14 | pub hidden_dropout_prob: f64, 15 | pub layer_norm_eps: f64, 16 | pub pad_token_idx: usize, 17 | } 18 | 19 | #[derive(Module, Debug)] 20 | pub struct BertEmbeddings { 21 | pub pad_token_idx: usize, 22 | word_embeddings: Embedding, 23 | position_embeddings: Embedding, 24 | token_type_embeddings: Embedding, 25 | layer_norm: LayerNorm, 26 | dropout: Dropout, 27 | max_position_embeddings: usize, 28 | } 29 | 30 | impl BertEmbeddingsConfig { 31 | /// Initializes BertEmbeddings with default weights 32 | pub fn init(&self, device: &B::Device) -> BertEmbeddings { 33 | let word_embeddings = EmbeddingConfig::new(self.vocab_size, self.hidden_size).init(device); 34 | let position_embeddings = 35 | EmbeddingConfig::new(self.max_position_embeddings, self.hidden_size).init(device); 36 | let token_type_embeddings = 37 | EmbeddingConfig::new(self.type_vocab_size, self.hidden_size).init(device); 38 | let layer_norm = LayerNormConfig::new(self.hidden_size) 39 | .with_epsilon(self.layer_norm_eps) 40 | .init(device); 41 | 42 | let dropout = DropoutConfig::new(self.hidden_dropout_prob).init(); 43 | 44 | BertEmbeddings { 45 | word_embeddings, 46 | position_embeddings, 47 | token_type_embeddings, 48 | layer_norm, 49 | dropout, 50 | max_position_embeddings: self.max_position_embeddings, 51 | pad_token_idx: self.pad_token_idx, 52 | } 53 | } 54 | } 55 | 56 | impl BertEmbeddings { 57 | pub fn forward(&self, item: BertInferenceBatch) -> Tensor { 58 | // Items batch contains the tokenized input and padding mask, each of dim: [batch_size, max_seq_length] 59 | let input_shape = item.tokens.shape(); 60 | let input_ids = item.tokens; 61 | 62 | // Embed tokens 63 | let inputs_embeds = 
self.word_embeddings.forward(input_ids); 64 | let mut embeddings = inputs_embeds; 65 | 66 | let device = &self.position_embeddings.devices()[0]; 67 | 68 | let token_type_ids = Tensor::::zeros(input_shape.clone(), device); 69 | let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids); 70 | 71 | embeddings = embeddings + token_type_embeddings; 72 | 73 | // Max position embeddings is 514 for roberta models as opposed to 512 for bert models 74 | // The position embeddings thus start from padding_idx + 1 to max_position_embeddings: [2 - 514) 75 | // https://github.com/facebookresearch/fairseq/issues/1187 76 | 77 | let seq_length = input_shape.dims[1]; 78 | let mut position_ids_tensor: Tensor = 79 | Tensor::arange(0..seq_length as i64, device) 80 | .reshape([1, seq_length]) 81 | .expand(input_shape.clone()); 82 | 83 | if self.max_position_embeddings != 512 { 84 | // RoBERTa use a different scheme than BERT to create position indexes where padding tokens are given 85 | // a fixed positional index. Check: create_position_ids_from_input_ids() in 86 | // https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py 87 | let position_ids = Tensor::arange( 88 | (self.pad_token_idx as i64) + 1 89 | ..(seq_length as i64) + (self.pad_token_idx as i64) + 1, 90 | device, 91 | ) 92 | .reshape([1, seq_length]) 93 | .expand(input_shape); 94 | position_ids_tensor = 95 | position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx as i32); 96 | } 97 | 98 | let position_embeddings = self.position_embeddings.forward(position_ids_tensor); 99 | embeddings = embeddings + position_embeddings; 100 | 101 | // Layer normalization and dropout 102 | let embeddings = self.layer_norm.forward(embeddings); 103 | let embeddings = self.dropout.forward(embeddings); 104 | 105 | embeddings 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /bert-burn/src/fill_mask.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | data::Tokenizer, 3 | data::{BertInferenceBatch, BertTokenizer}, 4 | model::BertMaskedLM, 5 | model::BertModelConfig, 6 | }; 7 | use burn::tensor::{activation::softmax, backend::Backend, Element, Tensor}; 8 | 9 | type TokenType = usize; 10 | const MASK_TOKEN_ID: TokenType = 50264; 11 | 12 | #[derive(Debug, Clone)] 13 | pub struct FillMaskResult { 14 | pub mask_idx: usize, 15 | pub top_k: Vec<(f32, String)>, 16 | } 17 | 18 | pub fn fill_mask( 19 | model: &BertMaskedLM, 20 | model_config: &BertModelConfig, 21 | tokenizer: &BertTokenizer, 22 | input: BertInferenceBatch, 23 | ) -> Vec> { 24 | let [batch_size, seq_len] = input.tokens.dims(); 25 | let output = model.forward(input.clone()); 26 | 27 | let mut results = vec![]; 28 | 29 | // Embedding size 30 | let d_model = model_config.vocab_size.clone(); 31 | for i in 0..batch_size { 32 | let mut batch_results = vec![]; 33 | let input_tokens = input 34 | .tokens 35 | .clone() 36 | .slice([i..i + 1, 0..seq_len]) 37 | .squeeze::<1>(0) 38 | .into_data(); 39 | // Find the mask tokens in the input, as a list of indices 40 | let masks = find_masks(input_tokens.as_slice::().unwrap(), MASK_TOKEN_ID); 41 | for mask in masks { 42 | let logits = output 43 | .clone() 44 | .slice([i..i + 1, mask..(mask + 1), 0..d_model]) 45 | .squeeze::<2>(0) 46 | .squeeze(0); 47 | // Find the top k tokens with the highest probabilities 48 | let top_k = top_k(5, logits); 49 | batch_results.push(FillMaskResult { 50 | mask_idx: mask, 51 
| top_k: top_k 52 | .iter() 53 | .map(|(k, prob)| (*prob, tokenizer.decode(&[*k]))) 54 | .collect(), 55 | }); 56 | } 57 | results.push(batch_results); 58 | } 59 | 60 | results 61 | } 62 | 63 | fn find_masks(tokens: &[T], mask_token_id: TokenType) -> Vec { 64 | let mut masks = Vec::new(); 65 | for (i, token) in tokens.iter().enumerate() { 66 | if token.to_usize() == mask_token_id { 67 | masks.push(i); 68 | } 69 | } 70 | masks 71 | } 72 | 73 | fn data_to_vec_f32(data: &[T]) -> Vec { 74 | data.iter().map(|x| x.to_f32()).collect() 75 | } 76 | 77 | fn data_to_vec_usize(data: &[T]) -> Vec { 78 | data.iter().map(|x| x.to_usize()).collect() 79 | } 80 | 81 | fn top_k(k: usize, logits: Tensor) -> Vec<(usize, f32)> { 82 | let (pre_soft_probs, indices) = logits.sort_with_indices(0); 83 | let (probabilities, indices) = ( 84 | data_to_vec_f32(&softmax(pre_soft_probs, 0).into_data().as_slice::().unwrap()), 85 | data_to_vec_usize(&indices.into_data().as_slice::().unwrap()), 86 | ); 87 | probabilities 88 | .iter() 89 | .enumerate() 90 | .rev() 91 | .take(k) 92 | .map(|(i, &p)| (indices[i], p)) 93 | .collect() 94 | } 95 | -------------------------------------------------------------------------------- /bert-burn/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate derive_new; 3 | 4 | pub mod data; 5 | mod embedding; 6 | pub mod fill_mask; 7 | pub mod loader; 8 | pub mod model; 9 | pub mod pooler; 10 | -------------------------------------------------------------------------------- /bert-burn/src/model.rs: -------------------------------------------------------------------------------- 1 | use crate::data::BertInferenceBatch; 2 | use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig}; 3 | use crate::loader::{ 4 | load_decoder_from_safetensors, load_embeddings_from_safetensors, load_encoder_from_safetensors, 5 | load_layer_norm_safetensor, load_linear_safetensor, load_pooler_from_safetensors, 6 | }; 7 | use crate::pooler::{Pooler, PoolerConfig}; 8 | use burn::config::Config; 9 | use burn::module::Module; 10 | use burn::nn::transformer::{ 11 | TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput, 12 | }; 13 | use burn::nn::Initializer::KaimingUniform; 14 | use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig}; 15 | use burn::tensor::activation::gelu; 16 | use burn::tensor::backend::Backend; 17 | use burn::tensor::Tensor; 18 | use candle_core::{safetensors, Device, Tensor as CandleTensor}; 19 | use std::collections::HashMap; 20 | use std::path::PathBuf; 21 | 22 | // Define the Bert model configuration 23 | #[derive(Config)] 24 | pub struct BertModelConfig { 25 | /// Number of attention heads in the multi-head attention 26 | pub num_attention_heads: usize, 27 | /// Number of transformer encoder layers/blocks 28 | pub num_hidden_layers: usize, 29 | /// Layer normalization epsilon 30 | pub layer_norm_eps: f64, 31 | /// Size of bert embedding (e.g., 768 for roberta-base) 32 | pub hidden_size: usize, 33 | /// Size of the intermediate position wise feedforward layer 34 | pub intermediate_size: usize, 35 | /// Size of the vocabulary 36 | pub vocab_size: usize, 37 | /// Max position embeddings, in RoBERTa equal to max_seq_len + 2 (514), for BERT equal to max_seq_len(512) 38 | pub max_position_embeddings: usize, 39 | /// Identifier for sentence type in input (e.g., 0 for single sentence, 1 for pair) 40 | pub type_vocab_size: usize, 41 | /// Dropout value across layers, typically 0.1 42 | pub hidden_dropout_prob: 
f64, 43 | /// BERT model name (roberta) 44 | pub model_type: String, 45 | /// Index of the padding token 46 | pub pad_token_id: usize, 47 | /// Maximum sequence length for the tokenizer 48 | pub max_seq_len: Option, 49 | /// Whether to add a pooling layer to the model 50 | pub with_pooling_layer: Option, 51 | } 52 | 53 | // Define the Bert model structure 54 | #[derive(Module, Debug)] 55 | pub struct BertModel { 56 | pub embeddings: BertEmbeddings, 57 | pub encoder: TransformerEncoder, 58 | pub pooler: Option>, 59 | } 60 | 61 | #[derive(Debug, Clone)] 62 | pub struct BertModelOutput { 63 | pub hidden_states: Tensor, 64 | pub pooled_output: Option>, 65 | } 66 | 67 | impl BertModelConfig { 68 | /// Initializes a Bert model with default weights 69 | pub fn init(&self, device: &B::Device) -> BertModel { 70 | let embeddings = self.get_embeddings_config().init(device); 71 | let encoder = self.get_encoder_config().init(device); 72 | 73 | let pooler = if self.with_pooling_layer.unwrap_or(false) { 74 | Some( 75 | PoolerConfig { 76 | hidden_size: self.hidden_size, 77 | } 78 | .init(device), 79 | ) 80 | } else { 81 | None 82 | }; 83 | 84 | BertModel { 85 | embeddings, 86 | encoder, 87 | pooler, 88 | } 89 | } 90 | 91 | pub fn init_with_lm_head(&self, device: &B::Device) -> BertMaskedLM { 92 | let bert = self.init(device); 93 | let lm_head = BertLMHead { 94 | dense: LinearConfig::new(self.hidden_size, self.hidden_size).init(device), 95 | layer_norm: LayerNormConfig::new(self.hidden_size) 96 | .with_epsilon(self.layer_norm_eps) 97 | .init(device), 98 | decoder: LinearConfig::new(self.hidden_size, self.vocab_size).init(device), 99 | }; 100 | 101 | BertMaskedLM { bert, lm_head } 102 | } 103 | 104 | fn get_embeddings_config(&self) -> BertEmbeddingsConfig { 105 | BertEmbeddingsConfig { 106 | vocab_size: self.vocab_size, 107 | max_position_embeddings: self.max_position_embeddings, 108 | type_vocab_size: self.type_vocab_size, 109 | hidden_size: self.hidden_size, 110 | hidden_dropout_prob: self.hidden_dropout_prob, 111 | layer_norm_eps: self.layer_norm_eps, 112 | pad_token_idx: self.pad_token_id, 113 | } 114 | } 115 | 116 | fn get_encoder_config(&self) -> TransformerEncoderConfig { 117 | TransformerEncoderConfig { 118 | n_heads: self.num_attention_heads, 119 | n_layers: self.num_hidden_layers, 120 | d_model: self.hidden_size, 121 | d_ff: self.intermediate_size, 122 | dropout: self.hidden_dropout_prob, 123 | norm_first: false, 124 | quiet_softmax: false, 125 | initializer: KaimingUniform { 126 | gain: 1.0 / libm::sqrt(3.0), 127 | fan_out_only: false, 128 | }, 129 | } 130 | } 131 | } 132 | 133 | impl BertModel { 134 | /// Defines forward pass 135 | pub fn forward(&self, input: BertInferenceBatch) -> BertModelOutput { 136 | let embedding = self.embeddings.forward(input.clone()); 137 | let device = &self.embeddings.devices()[0]; 138 | 139 | let mask_pad = input.mask_pad.to_device(device); 140 | 141 | let encoder_input = TransformerEncoderInput::new(embedding).mask_pad(mask_pad); 142 | let hidden_states = self.encoder.forward(encoder_input); 143 | 144 | let pooled_output = self 145 | .pooler 146 | .as_ref() 147 | .map(|pooler| pooler.forward(hidden_states.clone())); 148 | 149 | BertModelOutput { 150 | hidden_states, 151 | pooled_output, 152 | } 153 | } 154 | 155 | pub fn from_safetensors( 156 | file_path: PathBuf, 157 | device: &B::Device, 158 | config: BertModelConfig, 159 | ) -> BertModelRecord { 160 | let model_name = config.model_type.as_str(); 161 | let weight_result = safetensors::load::(file_path, 
&Device::Cpu); 162 | 163 | // Match on the result of loading the weights 164 | let weights = match weight_result { 165 | Ok(weights) => weights, 166 | Err(e) => panic!("Error loading weights: {:?}", e), 167 | }; 168 | 169 | // Weights are stored in a HashMap 170 | // For each layer, it will either be prefixed with "encoder.layer." or "embeddings." 171 | // We need to extract both. 172 | let mut encoder_layers: HashMap = HashMap::new(); 173 | let mut embeddings_layers: HashMap = HashMap::new(); 174 | let mut pooler_layers: HashMap = HashMap::new(); 175 | 176 | for (key, value) in weights.iter() { 177 | // If model name prefix present in keys, remove it to load keys consistently 178 | // across variants (bert-base, roberta-base etc.) 179 | 180 | let prefix = String::from(model_name) + "."; 181 | let key_without_prefix = key.replace(&prefix, ""); 182 | 183 | if key_without_prefix.starts_with("encoder.layer.") { 184 | encoder_layers.insert(key_without_prefix, value.clone()); 185 | } else if key_without_prefix.starts_with("embeddings.") { 186 | embeddings_layers.insert(key_without_prefix, value.clone()); 187 | } else if key_without_prefix.starts_with("pooler.") { 188 | pooler_layers.insert(key_without_prefix, value.clone()); 189 | } 190 | } 191 | 192 | let embeddings_record = load_embeddings_from_safetensors::(embeddings_layers, device); 193 | let encoder_record = load_encoder_from_safetensors::(encoder_layers, device); 194 | 195 | let pooler_record = if config.with_pooling_layer.unwrap_or(false) { 196 | Some(load_pooler_from_safetensors::(pooler_layers, device)) 197 | } else { 198 | None 199 | }; 200 | 201 | let model_record = BertModelRecord { 202 | embeddings: embeddings_record, 203 | encoder: encoder_record, 204 | pooler: pooler_record, 205 | }; 206 | model_record 207 | } 208 | } 209 | 210 | #[derive(Module, Debug)] 211 | pub struct BertMaskedLM { 212 | pub bert: BertModel, 213 | pub lm_head: BertLMHead, 214 | } 215 | 216 | #[derive(Module, Debug)] 217 | pub struct BertLMHead { 218 | pub dense: Linear, 219 | pub layer_norm: LayerNorm, 220 | pub decoder: Linear, 221 | } 222 | 223 | impl BertMaskedLM { 224 | pub fn forward(&self, input: BertInferenceBatch) -> Tensor { 225 | let output = self.bert.forward(BertInferenceBatch { 226 | tokens: input.tokens.clone(), 227 | mask_pad: input.mask_pad.clone(), 228 | }); 229 | let output = self.lm_head.forward(output.hidden_states); 230 | output 231 | } 232 | 233 | pub fn from_safetensors( 234 | file_path: PathBuf, 235 | device: &B::Device, 236 | config: BertModelConfig, 237 | ) -> BertMaskedLMRecord { 238 | let bert = BertModel::from_safetensors(file_path.clone(), device, config.clone()); 239 | let lm_head = BertLMHead::from_safetensors(file_path, device, config); 240 | 241 | BertMaskedLMRecord { bert, lm_head } 242 | } 243 | } 244 | 245 | impl BertLMHead { 246 | pub fn forward(&self, features: Tensor) -> Tensor { 247 | let output = self.dense.forward(features); 248 | let output = gelu(output); 249 | let output = self.layer_norm.forward(output); 250 | 251 | let output = self.decoder.forward(output); 252 | output 253 | } 254 | 255 | pub fn from_safetensors( 256 | file_path: PathBuf, 257 | device: &B::Device, 258 | _config: BertModelConfig, 259 | ) -> BertLMHeadRecord { 260 | let weight_result = safetensors::load::(file_path, &Device::Cpu); 261 | 262 | // Match on the result of loading the weights 263 | let weights = match weight_result { 264 | Ok(weights) => weights, 265 | Err(e) => panic!("Error loading weights: {:?}", e), 266 | }; 267 | 268 | let dense 
= load_linear_safetensor::( 269 | &weights["lm_head.dense.bias"], 270 | &weights["lm_head.dense.weight"], 271 | device, 272 | ); 273 | let layer_norm = load_layer_norm_safetensor::( 274 | &weights["lm_head.layer_norm.bias"], 275 | &weights["lm_head.layer_norm.weight"], 276 | device, 277 | ); 278 | let decoder = load_decoder_from_safetensors::( 279 | &weights["lm_head.bias"], 280 | &weights 281 | .iter() 282 | .find(|(k, _)| k.contains("word_embeddings.weight")) 283 | .unwrap() 284 | .1, 285 | device, 286 | ); 287 | 288 | BertLMHeadRecord { 289 | dense, 290 | layer_norm, 291 | decoder, 292 | } 293 | } 294 | } 295 | -------------------------------------------------------------------------------- /bert-burn/src/pooler.rs: -------------------------------------------------------------------------------- 1 | use burn::{ 2 | config::Config, 3 | module::Module, 4 | nn::{Linear, LinearConfig}, 5 | tensor::{backend::Backend, Tensor}, 6 | }; 7 | use derive_new::new; 8 | 9 | /// Pooler 10 | #[derive(Module, Debug, new)] 11 | pub struct Pooler { 12 | /// Linear output 13 | output: Linear, 14 | } 15 | 16 | impl Pooler { 17 | /// Forward pass 18 | pub fn forward(&self, encoder_output: Tensor) -> Tensor { 19 | let [batch_size, _, _] = encoder_output.dims(); 20 | 21 | self.output 22 | .forward(encoder_output.slice([0..batch_size, 0..1])) 23 | .tanh() 24 | } 25 | } 26 | 27 | /// Pooler Configuration 28 | #[derive(Config)] 29 | pub struct PoolerConfig { 30 | /// Hidden size 31 | pub hidden_size: usize, 32 | } 33 | 34 | impl PoolerConfig { 35 | /// Initialize a new Pooler module. 36 | pub fn init(&self, device: &B::Device) -> Pooler { 37 | let output = LinearConfig::new(self.hidden_size, self.hidden_size).init(device); 38 | 39 | Pooler::new(output) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /llama-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["guillaumelagrange "] 3 | license = "MIT OR Apache-2.0" 4 | name = "llama-burn" 5 | version = "0.1.0" 6 | edition = "2021" 7 | description = "Llama 3 large language model with Burn" 8 | 9 | [features] 10 | default = ["pretrained"] 11 | pretrained = ["burn/network", "dep:dirs"] 12 | 13 | llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"] 14 | tiny = ["dep:tokenizers"] 15 | 16 | # Example feature flags (backend selection) 17 | tch-cpu = ["burn/tch"] 18 | tch-gpu = ["burn/tch"] 19 | cuda = ["burn/cuda-jit"] 20 | wgpu = ["burn/wgpu"] 21 | 22 | # To import pytorch weights 23 | import = ["burn-import"] 24 | 25 | [dependencies] 26 | burn = { version = "0.16.0", default-features = false, features = ["std", "fusion"] } 27 | burn-import = { version = "0.16.0", optional = true } 28 | 29 | itertools = { version = "0.12.1", default-features = false, features = [ 30 | "use_alloc", 31 | ] } 32 | dirs = { version = "5.0.1", optional = true } 33 | serde = { version = "1.0.192", default-features = false, features = [ 34 | "derive", 35 | "alloc", 36 | ] } # alloc is for no_std, derive is needed 37 | 38 | # Tiktoken tokenizer (llama 3) 39 | tiktoken-rs = { version = "0.5.8", optional = true } 40 | base64 = { version = "0.22.1", optional = true } 41 | rustc-hash = { version = "1.1.0", optional = true } 42 | 43 | # SentencePiece tokenizer (tiny llama / llama 2) 44 | tokenizers = { version = "0.19.1", default-features = false, features = [ 45 | "onig", 46 | ], optional = true } 47 | 48 | rand = { version = "0.8.5", default-features = false, 
features = [ 49 | "std_rng", 50 | ] } # std_rng is for no_std 51 | 52 | [dev-dependencies] 53 | burn = { version = "0.16.0", default-features = false } 54 | clap = { version = "4.5.4", features = ["derive"] } 55 | -------------------------------------------------------------------------------- /llama-burn/NOTICES.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | This file contains notices and information required by libraries that this repository copied or 4 | derived from. The use of the following resources complies with the licenses provided. 5 | 6 | ## Implementation 7 | 8 | The model implementation was adapted from the original 9 | [Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the 10 | [Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE). 11 | The Llama 3.1 model is distributed under the 12 | [Llama 3.1 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE). 13 | The Llama 3.2 model is distributed under the 14 | [Llama 3.2 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/LICENSE). 15 | 16 | The TinyLlama implementation is derived from the same code, but its weights and tokenizers were 17 | adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under 18 | the [Apache 2.0](https://github.com/jzhang38/TinyLlama/blob/main/LICENSE) open source license. 19 | -------------------------------------------------------------------------------- /llama-burn/README.md: -------------------------------------------------------------------------------- 1 | # Llama Burn 2 | 3 | An image of a llama surrounded by fiery colors and a gust of fire 4 | 5 | The popular Llama LLM is here! 6 | 7 | This repository contains the 8 | [Llama 3.2, Llama 3.1, Llama 3](https://github.com/meta-llama/llama-models/), and 9 | [TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding 10 | tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama 11 | variants in [src/llama.rs](src/llama.rs). 12 | 13 | ## Usage 14 | 15 | ### `Cargo.toml` 16 | 17 | Add this to your `Cargo.toml`: 18 | 19 | ```toml 20 | [dependencies] 21 | llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false } 22 | ``` 23 | 24 | If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are 25 | active), enable the corresponding feature flag. 26 | 27 | > **Important:** these features require `std`. 28 | 29 | #### Llama 3 30 | 31 | ```toml 32 | [dependencies] 33 | llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] } 34 | ``` 35 | 36 | #### TinyLlama 37 | 38 | ```toml 39 | [dependencies] 40 | llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["tiny"] } 41 | ``` 42 | 43 | ### Example Usage 44 | 45 | The [chat completion example](examples/chat.rs) initializes a Llama model from the provided weights 46 | file and generates a sequence of text based on the input prompt. The instruction-tuned model is 47 | loaded for dialogue applications, so the prompt is automatically formatted for chat completion. 48 | 49 | The example can be executed on the `tch` backend (CUDA or CPU), `cuda` or `wgpu`. 
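If you want to drive the crate from your own code rather than through the example binary, the sketch below shows roughly how the pieces imported by [examples/chat.rs](examples/chat.rs) fit together. Treat it as an illustration only: it assumes the `llama3`, `pretrained` and `wgpu` features are enabled, and the constructor name (`llama3_2_1b_pretrained`), the `generate` signature and the output field (`text`) are assumptions to be checked against [src/llama.rs](src/llama.rs), [src/pretrained.rs](src/pretrained.rs) and [src/sampling.rs](src/sampling.rs).

```rust
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use llama_burn::{
    llama::LlamaConfig,
    sampling::{Sampler, TopP},
};

fn main() {
    let device = WgpuDevice::default();

    // Assumed pretrained constructor name: check `LlamaConfig` in src/llama.rs and
    // src/pretrained.rs for the exact loaders available for each Llama variant.
    let mut llama = LlamaConfig::llama3_2_1b_pretrained::<Wgpu>(128, &device)
        .expect("should download and load the pretrained weights");

    // Top-p (nucleus) sampling with the same defaults the chat example advertises.
    let mut sampler = Sampler::TopP(TopP::new(0.9, 42));

    // Assumed signature: (prompt, number of new tokens, temperature, sampler).
    // The chat example additionally wraps the prompt in the instruct chat template
    // before generating; this sketch passes the raw prompt.
    let generated = llama.generate(
        "How many helicopters can a human eat in one sitting?",
        50,
        0.6,
        &mut sampler,
    );
    println!("{}", generated.text);
}
```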
50 | 51 | | Argument | Description | 52 | | :-------------- | :------------------------------------------------------------------------------------------------------------- | 53 | | `-p` | The prompt or question to pass to the LLM (default: `"How many helicopters can a human eat in one sitting?"`). | 54 | | `-n` | The number of new tokens to generate (default: `50`). | 55 | | `--top-p` | Top-p probability threshold (default: `0.9`). | 56 | | `--temperature` | Temperature value for controlling randomness in sampling. (default: `0.6`). | 57 | | `--max-seq-len` | Maximum sequence length for input text. (default: `128`). | 58 | | `--seed` | The seed to use when generating random samples.. (default: `42`). | 59 | 60 | Any of the commands below can be used by appending any of the listed arguments by appending 61 | `[-- ]`. For example, you can provided your own prompt/question 62 | `-- -p "How many llamas does it take to change a lightbulb?"`. 63 | 64 | #### Llama 3 65 | 66 | Using the `tch` backend with CUDA: 67 | 68 | ```sh 69 | export TORCH_CUDA_VERSION=cu121 70 | cargo run --release --features llama3,tch-gpu --example chat 71 | ``` 72 | 73 | Using the `tch` backend with CPU: 74 | 75 | ```sh 76 | cargo run --release --features llama3,tch-cpu --example chat 77 | ``` 78 | 79 | Using the `wgpu` backend: 80 | 81 | ```sh 82 | cargo run --release --features llama3,wgpu --example chat 83 | ``` 84 | 85 | Using the `cuda` backend: 86 | 87 | ```sh 88 | cargo run --release --features llama3,cuda --example chat 89 | ``` 90 | 91 | **Built with Meta Llama 3.** This example uses the 92 | [Meta-Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) (default), 93 | [Meta-Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), 94 | [Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) and 95 | [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) 96 | instruction-tuned models. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is 97 | also available if you wish to use it in your application. 98 | 99 | #### TinyLlama 100 | 101 | Using the `tch` backend with CUDA: 102 | 103 | ```sh 104 | export TORCH_CUDA_VERSION=cu121 105 | cargo run --release --features tiny,tch-gpu --example chat 106 | ``` 107 | 108 | Using the `tch` backend with CPU: 109 | 110 | ```sh 111 | cargo run --release --features tiny,tch-cpu --example chat 112 | ``` 113 | 114 | Using the `wgpu` backend: 115 | 116 | ```sh 117 | cargo run --release --features tiny,wgpu --example chat 118 | ``` 119 | 120 | Using the `cuda` backend: 121 | 122 | ```sh 123 | cargo run --release --features tiny,cuda --example chat 124 | ``` 125 | 126 | This example uses the 127 | [TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) 128 | instruction-tuned model based on the Llama2 architecture and tokenizer. 129 | 130 | ## Known Issues 131 | 132 | Based on your hardware and the model selected, the `wgpu` backend might not be able to successfully 133 | run the model due to the current memory management strategy. With `cuda` selected, the precision is 134 | set to `f32` due to compilation errors with `f16`. 135 | 136 | ### Windows 137 | 138 | The `cuda` backend is [unable to find nvrtc lib](https://github.com/coreylowman/cudarc/issues/246): 139 | 140 | ``` 141 | Unable to find nvrtc lib under the names ["nvrtc", "nvrtc64", "nvrtc64_12", "nvrtc64_123", "nvrtc64_123_0", "nvrtc64_120_3", "nvrtc64_10"]. 
Please open GitHub issue. 142 | ``` 143 | 144 | This has been fixed in the latest `cudarc` release (used by our `cuda-jit` backend), which is 145 | currently used [on main](https://github.com/tracel-ai/burn). To circumvent the issue, feel free to 146 | modify the code and use the latest Burn dependency in your project instead of `0.14.0`. This should 147 | also allow you to use `f16` precision (compilation errors have been fixed since). 148 | -------------------------------------------------------------------------------- /llama-burn/assets/llama-burn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/llama-burn/assets/llama-burn.jpeg -------------------------------------------------------------------------------- /llama-burn/examples/chat.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use burn::tensor::{backend::Backend, Device}; 4 | use clap::Parser; 5 | use llama_burn::{ 6 | llama::{Llama, LlamaConfig}, 7 | sampling::{Sampler, TopP}, 8 | tokenizer::Tokenizer, 9 | }; 10 | 11 | #[cfg(feature = "llama3")] 12 | use clap::ValueEnum; 13 | 14 | const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?"; 15 | 16 | #[derive(Parser, Debug)] 17 | #[command(version, about, long_about = None)] 18 | pub struct Config { 19 | /// Top-p probability threshold. 20 | #[arg(long, default_value_t = 0.9)] 21 | top_p: f64, 22 | 23 | /// Temperature value for controlling randomness in sampling. 24 | #[arg(long, default_value_t = 0.6)] 25 | temperature: f64, 26 | 27 | /// Maximum sequence length for input text. 28 | #[arg(long, default_value_t = 128)] 29 | max_seq_len: usize, 30 | 31 | /// The number of new tokens to generate (i.e., the number of generation steps to take). 32 | #[arg(long, short = 'n', default_value_t = 65)] 33 | sample_len: usize, 34 | 35 | /// The seed to use when generating random samples. 36 | #[arg(long, default_value_t = 42)] 37 | seed: u64, 38 | 39 | /// The input prompt. 40 | #[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))] 41 | prompt: String, 42 | 43 | /// The Llama 3 model version. 44 | #[cfg(feature = "llama3")] 45 | #[arg(long, default_value = "llama-3.2-1b-instruct")] 46 | model_version: Llama3, 47 | } 48 | 49 | #[cfg(feature = "llama3")] 50 | #[derive(Clone, Debug, ValueEnum)] 51 | /// Llama-3 model variants to load. 52 | enum Llama3 { 53 | /// Llama-3-8B-Instruct. 54 | #[value(name = "llama-3-8b-instruct")] 55 | V3Instruct, 56 | /// Llama-3.1-8B-Instruct. 57 | #[value(name = "llama-3.1-8b-instruct")] 58 | V31Instruct, 59 | /// Llama-3.2-1B-Instruct. 60 | #[value(name = "llama-3.2-1b-instruct")] 61 | V321bInstruct, 62 | /// Llama-3.2-3B-Instruct. 
63 | #[value(name = "llama-3.2-3b-instruct")] 64 | V323bInstruct, 65 | } 66 | 67 | pub fn generate( 68 | llama: &mut Llama, 69 | prompt: &str, 70 | sample_len: usize, 71 | temperature: f64, 72 | sampler: &mut Sampler, 73 | ) { 74 | let now = Instant::now(); 75 | let generated = llama.generate(prompt, sample_len, temperature, sampler); 76 | let elapsed = now.elapsed().as_secs(); 77 | 78 | println!("> {}\n", generated.text); 79 | println!( 80 | "{} tokens generated ({:.4} tokens/s)\n", 81 | generated.tokens, 82 | generated.tokens as f64 / generated.time 83 | ); 84 | 85 | println!( 86 | "Generation completed in {}m{}s", 87 | (elapsed / 60), 88 | elapsed % 60 89 | ); 90 | } 91 | 92 | pub fn chat(args: Config, device: Device) { 93 | let mut prompt = args.prompt; 94 | 95 | // Sampling strategy 96 | let mut sampler = if args.temperature > 0.0 { 97 | Sampler::TopP(TopP::new(args.top_p, args.seed)) 98 | } else { 99 | Sampler::Argmax 100 | }; 101 | 102 | #[cfg(feature = "tiny")] 103 | { 104 | // TinyLlama-1.1B Chat v1.0 105 | let mut llama = LlamaConfig::tiny_llama_pretrained::(args.max_seq_len, &device).unwrap(); 106 | println!("Processing prompt: {}", prompt); 107 | 108 | // Prompt formatting for chat model 109 | prompt = format!( 110 | "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\n{prompt}\n<|assistant|>\n" 111 | ); 112 | 113 | generate( 114 | &mut llama, 115 | &prompt, 116 | args.sample_len, 117 | args.temperature, 118 | &mut sampler, 119 | ); 120 | } 121 | 122 | #[cfg(feature = "llama3")] 123 | { 124 | // Llama-3-8B-Instruct or Llama-3.1-8B-Instruct 125 | let mut llama = match args.model_version { 126 | Llama3::V3Instruct => { 127 | LlamaConfig::llama3_8b_pretrained::(args.max_seq_len, &device).unwrap() 128 | } 129 | Llama3::V31Instruct => { 130 | LlamaConfig::llama3_1_8b_pretrained::(args.max_seq_len, &device).unwrap() 131 | } 132 | Llama3::V321bInstruct => { 133 | LlamaConfig::llama3_2_1b_pretrained::(args.max_seq_len, &device).unwrap() 134 | } 135 | Llama3::V323bInstruct => { 136 | LlamaConfig::llama3_2_3b_pretrained::(args.max_seq_len, &device).unwrap() 137 | } 138 | }; 139 | println!("Processing prompt: {}", prompt); 140 | 141 | // Prompt formatting for chat model 142 | prompt = format!( 143 | "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 144 | ); 145 | 146 | generate( 147 | &mut llama, 148 | &prompt, 149 | args.sample_len, 150 | args.temperature, 151 | &mut sampler, 152 | ); 153 | } 154 | } 155 | 156 | #[cfg(feature = "tch-gpu")] 157 | mod tch_gpu { 158 | use super::*; 159 | use burn::{ 160 | backend::{libtorch::LibTorchDevice, LibTorch}, 161 | tensor::f16, 162 | }; 163 | 164 | pub fn run(args: Config) { 165 | #[cfg(not(target_os = "macos"))] 166 | let device = LibTorchDevice::Cuda(0); 167 | #[cfg(target_os = "macos")] 168 | let device = LibTorchDevice::Mps; 169 | 170 | chat::>(args, device); 171 | } 172 | } 173 | 174 | #[cfg(feature = "tch-cpu")] 175 | mod tch_cpu { 176 | use super::*; 177 | use burn::backend::{libtorch::LibTorchDevice, LibTorch}; 178 | 179 | pub fn run(args: Config) { 180 | let device = LibTorchDevice::Cpu; 181 | 182 | chat::(args, device); 183 | } 184 | } 185 | 186 | #[cfg(feature = "wgpu")] 187 | mod wgpu { 188 | use super::*; 189 | use burn::backend::wgpu::{Wgpu, WgpuDevice}; 190 | 191 | pub fn run(args: Config) { 192 | let device = WgpuDevice::default(); 193 | 194 | chat::(args, device); 195 | } 196 | } 197 | 198 | #[cfg(feature = "cuda")] 199 | mod cuda { 200 | use super::*; 201 | use burn::{ 202 | backend::{cuda_jit::CudaDevice, CudaJit}, 203 | tensor::f16, 204 | }; 205 | 206 | pub fn run(args: Config) { 207 | let device = CudaDevice::default(); 208 | 209 | chat::>(args, device); 210 | } 211 | } 212 | 213 | pub fn main() { 214 | // Parse arguments 215 | let args = Config::parse(); 216 | 217 | #[cfg(feature = "tch-gpu")] 218 | tch_gpu::run(args); 219 | #[cfg(feature = "tch-cpu")] 220 | tch_cpu::run(args); 221 | #[cfg(feature = "wgpu")] 222 | wgpu::run(args); 223 | #[cfg(feature = "cuda")] 224 | cuda::run(args); 225 | } 226 | -------------------------------------------------------------------------------- /llama-burn/src/cache.rs: -------------------------------------------------------------------------------- 1 | use burn::tensor::{backend::Backend, Device, Tensor}; 2 | 3 | pub(crate) struct AutoregressiveCache { 4 | /// Tensor cache with shape `[batch_size, num_heads, seq_len, d_model]` 5 | cache: Tensor, 6 | pub(crate) max_seq_len: usize, 7 | cur_seq_len: usize, 8 | } 9 | 10 | impl AutoregressiveCache { 11 | /// Creates a new empty cache. 12 | pub fn new( 13 | max_batch_size: usize, 14 | num_heads: usize, 15 | max_seq_len: usize, 16 | d_model: usize, 17 | device: &Device, 18 | ) -> Self { 19 | Self { 20 | cache: Tensor::empty([max_batch_size, num_heads, max_seq_len, d_model], device), 21 | max_seq_len, 22 | cur_seq_len: 0, 23 | } 24 | } 25 | 26 | /// Reset the cache state. 
27 | pub fn reset(&mut self) { 28 | self.cache = Tensor::empty(self.cache.shape(), &self.cache.device()); 29 | self.cur_seq_len = 0; 30 | } 31 | 32 | pub fn forward(&mut self, tensor: Tensor) -> Tensor { 33 | let [batch_size, num_heads, seq_len, d_model] = tensor.dims(); 34 | let mut new_seq_len = self.cur_seq_len + seq_len; 35 | 36 | if new_seq_len > self.max_seq_len { 37 | self.cur_seq_len = self.max_seq_len - seq_len; 38 | let prev_slice = self.cache.clone().slice([ 39 | 0..batch_size, 40 | 0..num_heads, 41 | seq_len..self.max_seq_len, 42 | 0..d_model, 43 | ]); 44 | self.cache = self.cache.clone().slice_assign( 45 | [0..batch_size, 0..num_heads, 0..self.cur_seq_len, 0..d_model], 46 | prev_slice, 47 | ); 48 | new_seq_len = self.max_seq_len; 49 | } 50 | 51 | self.cache = self.cache.clone().slice_assign( 52 | [ 53 | 0..batch_size, 54 | 0..num_heads, 55 | self.cur_seq_len..new_seq_len, 56 | 0..d_model, 57 | ], 58 | tensor, 59 | ); 60 | 61 | self.cur_seq_len += seq_len; 62 | 63 | self.cache 64 | .clone() 65 | .slice([0..batch_size, 0..num_heads, 0..self.cur_seq_len, 0..d_model]) 66 | } 67 | 68 | /// Returns the cached sequence length. 69 | pub fn len(&self) -> usize { 70 | self.cur_seq_len 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /llama-burn/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod cache; 2 | pub mod llama; 3 | pub mod pretrained; 4 | pub mod sampling; 5 | pub mod tokenizer; 6 | mod transformer; 7 | 8 | #[cfg(test)] 9 | mod tests { 10 | #[cfg(feature = "cuda")] 11 | use burn::{backend::CudaJit, tensor::f16}; 12 | #[cfg(feature = "cuda")] 13 | pub type TestBackend = CudaJit; 14 | 15 | // NOTE: no tests on tch cpu (f32) 16 | #[cfg(feature = "tch-gpu")] 17 | use burn::{backend::LibTorch, tensor::f16}; 18 | #[cfg(feature = "tch-gpu")] 19 | pub type TestBackend = LibTorch; 20 | 21 | #[cfg(any(feature = "cuda", feature = "tch-gpu"))] 22 | pub type TestTensor = burn::tensor::Tensor; 23 | } 24 | -------------------------------------------------------------------------------- /llama-burn/src/pretrained.rs: -------------------------------------------------------------------------------- 1 | /// Pre-trained model metadata. 2 | pub struct Pretrained { 3 | pub(super) name: &'static str, 4 | pub(super) model: &'static str, 5 | pub(super) tokenizer: &'static str, 6 | } 7 | 8 | #[cfg(feature = "pretrained")] 9 | mod downloader { 10 | use super::*; 11 | use burn::data::network::downloader; 12 | use std::fs::{create_dir_all, File}; 13 | use std::io::Write; 14 | use std::path::PathBuf; 15 | 16 | impl Pretrained { 17 | /// Download the file to the local cache directory. 
18 | fn download(&self, url: &str) -> Result { 19 | // Model cache directory 20 | let model_dir = dirs::home_dir() 21 | .expect("Should be able to get home directory") 22 | .join(".cache") 23 | .join("llama-burn") 24 | .join(self.name); 25 | 26 | if !model_dir.exists() { 27 | create_dir_all(&model_dir)?; 28 | } 29 | 30 | let file_base_name = url 31 | .rsplit_once('/') 32 | .unwrap() 33 | .1 34 | .replace("?download=true", ""); 35 | let file_name = model_dir.join(&file_base_name); 36 | if !file_name.exists() { 37 | // Download file content 38 | let bytes = downloader::download_file_as_bytes(url, &file_base_name); 39 | 40 | // Write content to file 41 | let mut output_file = File::create(&file_name)?; 42 | output_file.write_all(&bytes)?; // write_all is not OS limited (files over 2GB) 43 | } 44 | 45 | Ok(file_name) 46 | } 47 | 48 | /// Download the pre-trained model weights to the local cache directory. 49 | pub fn download_weights(&self) -> Result { 50 | self.download(self.model) 51 | } 52 | 53 | /// Download the tokenizer to the local cache directory. 54 | pub fn download_tokenizer(&self) -> Result { 55 | self.download(self.tokenizer) 56 | } 57 | } 58 | } 59 | 60 | pub trait ModelMeta { 61 | fn pretrained(&self) -> Pretrained; 62 | } 63 | 64 | /// Llama pre-trained weights. 65 | pub enum Llama { 66 | /// Llama-3-8B. 67 | Llama3, 68 | /// Llama-3-8B-Instruct. 69 | Llama3Instruct, 70 | /// Llama-3.1-8B-Instruct. 71 | Llama31Instruct, 72 | /// Llama-3.2-3B-Instruct. 73 | Llama323bInstruct, 74 | /// Llama-3.2-1B-Instruct. 75 | Llama321bInstruct, 76 | /// TinyLlama-1.1B Chat v1.0. 77 | TinyLlama, 78 | } 79 | 80 | impl ModelMeta for Llama { 81 | fn pretrained(&self) -> Pretrained { 82 | match self { 83 | Self::Llama3 => Pretrained { 84 | name: "Llama-3-8B", 85 | model: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/model.mpk?download=true", 86 | tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/tokenizer.model?download=true", 87 | }, 88 | Self::Llama3Instruct => Pretrained { 89 | name: "Llama-3-8B-Instruct", 90 | model: "https://huggingface.co/tracel-ai/llama-3-8b-instruct-burn/resolve/main/model.mpk?download=true", 91 | tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-instruct-burn/resolve/main/tokenizer.model?download=true", 92 | }, 93 | Self::Llama31Instruct => Pretrained { 94 | name: "Llama-3.1-8B-Instruct", 95 | model: "https://huggingface.co/tracel-ai/llama-3.1-8b-instruct-burn/resolve/main/model.mpk?download=true", 96 | tokenizer: "https://huggingface.co/tracel-ai/llama-3.1-8b-instruct-burn/resolve/main/tokenizer.model?download=true", 97 | }, 98 | Self::Llama323bInstruct => Pretrained { 99 | name: "Llama-3.2-3B-Instruct", 100 | model: "https://huggingface.co/tracel-ai/llama-3.2-3b-instruct-burn/resolve/main/model.mpk?download=true", 101 | tokenizer: "https://huggingface.co/tracel-ai/llama-3.2-3b-instruct-burn/resolve/main/tokenizer.model?download=true", 102 | }, 103 | Self::Llama321bInstruct => Pretrained { 104 | name: "Llama-3.2-1B-Instruct", 105 | model: "https://huggingface.co/tracel-ai/llama-3.2-1b-instruct-burn/resolve/main/model.mpk?download=true", 106 | tokenizer: "https://huggingface.co/tracel-ai/llama-3.2-1b-instruct-burn/resolve/main/tokenizer.model?download=true", 107 | }, 108 | Self::TinyLlama => Pretrained { 109 | name: "TinyLlama-1.1B", 110 | model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true", 111 | tokenizer: 
"https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.json?download=true", 112 | }, 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /llama-burn/src/sampling.rs: -------------------------------------------------------------------------------- 1 | use burn::tensor::{backend::Backend, Int, Tensor}; 2 | use rand::{ 3 | distributions::{Distribution, WeightedIndex}, 4 | rngs::StdRng, 5 | SeedableRng, 6 | }; 7 | 8 | pub enum Sampler { 9 | TopP(TopP), 10 | Argmax, 11 | } 12 | 13 | impl Sampler { 14 | pub fn sample(&mut self, logits: Tensor) -> Tensor { 15 | match self { 16 | Self::TopP(s) => s.sample(logits), 17 | Self::Argmax => logits.argmax(1), 18 | } 19 | } 20 | } 21 | 22 | pub trait Sampling { 23 | fn sample(&mut self, logits: Tensor) -> Tensor; 24 | } 25 | 26 | /// Top-p sampling (nucleus sampling) selects the smallest set of tokens whose cumulative 27 | /// probability mass exceed the threshold p. 28 | pub struct TopP { 29 | /// Probability threshold for sampling. 30 | p: f64, 31 | /// RNG. 32 | rng: StdRng, 33 | } 34 | 35 | impl TopP { 36 | pub fn new(p: f64, seed: u64) -> Self { 37 | let rng = StdRng::seed_from_u64(seed); 38 | Self { p, rng } 39 | } 40 | } 41 | 42 | impl Sampling for TopP { 43 | fn sample(&mut self, probs: Tensor) -> Tensor { 44 | assert_eq!( 45 | probs.dims()[0], 46 | 1, 47 | "Naive top-p sampling only supports single-batch tensors" 48 | ); 49 | let (probs_sort, probs_idx) = probs.sort_descending_with_indices(1); 50 | 51 | // TODO: cumsum + Distribution::Multinomial support 52 | 53 | let mut probs_sort = probs_sort.to_data().iter::().collect::>(); 54 | 55 | let mut cumsum = 0.; 56 | probs_sort.iter_mut().for_each(|x| { 57 | if cumsum >= self.p { 58 | *x = 0.0; 59 | } else { 60 | cumsum += *x; 61 | } 62 | }); 63 | 64 | let next_token_idx = WeightedIndex::new(probs_sort) 65 | .unwrap() 66 | .sample(&mut self.rng); 67 | 68 | probs_idx.slice([0..1, next_token_idx..next_token_idx + 1]) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /llama-burn/src/tokenizer/base.rs: -------------------------------------------------------------------------------- 1 | pub trait Tokenizer { 2 | /// Load the tokenizer from the provided path. 3 | fn new(tokenizer_path: &str) -> Result 4 | where 5 | Self: Sized; 6 | 7 | /// Encode a string into a list of token identifiers. 8 | fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec; 9 | 10 | /// Decode a list of token identifiers into a string. 11 | fn decode(&self, tokens: Vec) -> String; 12 | 13 | /// Beginning of sentence token. 14 | fn bos(&self) -> String { 15 | self.decode(vec![self.bos_id()]) 16 | } 17 | 18 | /// Beginning of sentence token identifier. 19 | fn bos_id(&self) -> u32; 20 | 21 | /// End of sentence token. 22 | fn eos(&self) -> String { 23 | self.decode(vec![self.eos_id()]) 24 | } 25 | 26 | /// End of sentence token identifier. 27 | fn eos_id(&self) -> u32; 28 | 29 | /// Stop token identifiers. 
30 | fn stop_ids(&self) -> Vec; 31 | } 32 | -------------------------------------------------------------------------------- /llama-burn/src/tokenizer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod base; 2 | pub use base::*; 3 | 4 | #[cfg(feature = "llama3")] 5 | pub mod tiktoken; 6 | #[cfg(feature = "llama3")] 7 | pub use tiktoken::*; 8 | 9 | #[cfg(feature = "tiny")] 10 | pub mod sentence_piece; 11 | #[cfg(feature = "tiny")] 12 | pub use sentence_piece::*; 13 | -------------------------------------------------------------------------------- /llama-burn/src/tokenizer/sentence_piece.rs: -------------------------------------------------------------------------------- 1 | use tokenizers::Tokenizer as BaseTokenizer; 2 | 3 | use super::Tokenizer; 4 | 5 | const BOS_TOKEN_ID: u32 = 1; 6 | const EOS_TOKEN_ID: u32 = 2; 7 | 8 | pub struct SentiencePieceTokenizer { 9 | bpe: BaseTokenizer, 10 | bos_token_id: u32, 11 | eos_token_id: u32, 12 | } 13 | 14 | impl Tokenizer for SentiencePieceTokenizer { 15 | /// Load the [SentenciePiece](https://github.com/google/sentencepiece) tokenizer. 16 | fn new(tokenizer_path: &str) -> Result { 17 | let bpe = BaseTokenizer::from_file(tokenizer_path).map_err(|e| e.to_string())?; 18 | 19 | Ok(Self { 20 | bpe, 21 | bos_token_id: BOS_TOKEN_ID, 22 | eos_token_id: EOS_TOKEN_ID, 23 | }) 24 | } 25 | 26 | fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec { 27 | let bos_token = if bos { vec![self.bos_token_id] } else { vec![] }; 28 | let eos_token = if eos { vec![self.eos_token_id] } else { vec![] }; 29 | 30 | let tokens = self.bpe.encode(text, false).unwrap().get_ids().to_vec(); 31 | 32 | [bos_token, tokens, eos_token] 33 | .into_iter() 34 | .flat_map(|t| t.into_iter()) 35 | .collect() 36 | } 37 | 38 | fn decode(&self, tokens: Vec) -> String { 39 | self.bpe 40 | .decode( 41 | &tokens.into_iter().map(|t| t as u32).collect::>(), 42 | true, 43 | ) 44 | .unwrap() 45 | } 46 | 47 | fn bos_id(&self) -> u32 { 48 | self.bos_token_id 49 | } 50 | 51 | fn eos_id(&self) -> u32 { 52 | self.eos_token_id 53 | } 54 | 55 | fn stop_ids(&self) -> Vec { 56 | vec![self.eos_id()] 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /llama-burn/src/tokenizer/tiktoken.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fs::File, 3 | io::{BufRead, BufReader}, 4 | }; 5 | 6 | use base64::{engine::general_purpose::STANDARD, Engine}; 7 | use rustc_hash::FxHashMap as HashMap; 8 | use tiktoken_rs::CoreBPE; 9 | 10 | use super::Tokenizer; 11 | 12 | const BOS_TOKEN: &str = "<|begin_of_text|>"; 13 | const EOS_TOKEN: &str = "<|end_of_text|>"; 14 | const EOT_TOKEN: &str = "<|eot_id|>"; 15 | const EOM_TOKEN: &str = "<|eom_id|>"; 16 | 17 | const NUM_RESERVED_SPECIAL_TOKENS: usize = 256; 18 | const SPECIAL_TOKENS: [&str; 11] = [ 19 | BOS_TOKEN, 20 | EOS_TOKEN, 21 | "<|reserved_special_token_0|>", 22 | "<|reserved_special_token_1|>", 23 | "<|finetune_right_pad_id|>", 24 | "<|step_id|>", 25 | "<|start_header_id|>", 26 | "<|end_header_id|>", 27 | EOM_TOKEN, // end of message 28 | EOT_TOKEN, // end of turn 29 | "<|python_tag|>", 30 | ]; 31 | const PATTERN: &str = r#"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"#; 32 | 33 | #[derive(Debug, Clone)] 34 | pub struct Tiktoken { 35 | bpe: CoreBPE, 36 | bos_token_id: usize, 37 | eos_token_id: usize, 38 | eot_token_id: usize, 39 | eom_token_id: usize, 
40 | } 41 | 42 | impl Tokenizer for Tiktoken { 43 | /// Load the [Tiktoken](https://github.com/openai/tiktoken) tokenizer. 44 | fn new(tiktoken_bpe_file: &str) -> Result { 45 | let file = File::open(tiktoken_bpe_file).map_err(|e| e.to_string())?; 46 | let mut mergeable_ranks: HashMap, usize> = HashMap::default(); 47 | 48 | for line in BufReader::new(file).lines().flatten() { 49 | let mut parts = line.split(' '); 50 | let token = STANDARD 51 | .decode(parts.next().ok_or("Missing token")?) 52 | .map_err(|e| e.to_string())?; 53 | let rank = parts 54 | .next() 55 | .ok_or("Missing rank")? 56 | .parse::() 57 | .map_err(|e| e.to_string())?; 58 | 59 | mergeable_ranks.insert(token, rank); 60 | } 61 | let num_base_tokens = mergeable_ranks.len(); 62 | 63 | let special_tokens = [ 64 | SPECIAL_TOKENS 65 | .iter() 66 | .map(|t| t.to_string()) 67 | .collect::>(), 68 | (0..NUM_RESERVED_SPECIAL_TOKENS - SPECIAL_TOKENS.len()) 69 | .into_iter() 70 | .map(|i| format!("<|reserved_special_token_{}|>", i + 2)) 71 | .collect::>(), 72 | ] 73 | .concat(); 74 | let special_tokens = special_tokens 75 | .into_iter() 76 | .enumerate() 77 | .map(|(i, s)| (s, i + num_base_tokens)) 78 | .collect::>(); 79 | 80 | let bos_token_id = special_tokens[BOS_TOKEN]; 81 | let eos_token_id = special_tokens[EOS_TOKEN]; 82 | let eot_token_id = special_tokens[EOT_TOKEN]; 83 | let eom_token_id = special_tokens[EOM_TOKEN]; 84 | 85 | let bpe = 86 | CoreBPE::new(mergeable_ranks, special_tokens, PATTERN).map_err(|e| e.to_string())?; 87 | Ok(Self { 88 | bpe, 89 | bos_token_id, 90 | eos_token_id, 91 | eot_token_id, 92 | eom_token_id, 93 | }) 94 | } 95 | 96 | fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec { 97 | let bos_token = if bos { vec![self.bos_token_id] } else { vec![] }; 98 | let eos_token = if eos { vec![self.eos_token_id] } else { vec![] }; 99 | 100 | let tokens = self.bpe.encode_with_special_tokens(text); 101 | 102 | [bos_token, tokens, eos_token] 103 | .into_iter() 104 | .flat_map(|t| t.into_iter()) 105 | .map(|t| t as u32) 106 | .collect() 107 | } 108 | 109 | fn decode(&self, tokens: Vec) -> String { 110 | self.bpe 111 | .decode(tokens.into_iter().map(|t| t as usize).collect()) 112 | .expect("Should decode tokens") 113 | } 114 | 115 | fn bos_id(&self) -> u32 { 116 | self.bos_token_id as u32 117 | } 118 | 119 | fn eos_id(&self) -> u32 { 120 | self.eos_token_id as u32 121 | } 122 | 123 | fn stop_ids(&self) -> Vec { 124 | vec![ 125 | self.eos_id(), 126 | self.eom_token_id as u32, 127 | self.eot_token_id as u32, 128 | ] 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /mobilenetv2-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Arjun31415", "guillaumelagrange "] 3 | license = "MIT OR Apache-2.0" 4 | name = "mobilenetv2-burn" 5 | version = "0.1.0" 6 | edition = "2021" 7 | 8 | [features] 9 | default = [] 10 | std = [] 11 | pretrained = ["burn/network", "std", "dep:dirs"] 12 | 13 | [dependencies] 14 | # Note: default-features = false is needed to disable std 15 | burn = { version = "0.16.0" } 16 | burn-import = { version = "0.16.0" } 17 | dirs = { version = "5.0.1", optional = true } 18 | serde = { version = "1.0.192", default-features = false, features = [ 19 | "derive", 20 | "alloc", 21 | ] } # alloc is for no_std, derive is needed 22 | 23 | [dev-dependencies] 24 | burn = { version = "0.16.0", features = ["ndarray"] } 25 | image = { version = "0.24.9", features = ["png", "jpeg"] } 26 | 
-------------------------------------------------------------------------------- /mobilenetv2-burn/NOTICES.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. 4 | 5 | ## Sample Image 6 | 7 | Image Title: Standing yellow Labrador Retriever dog. 8 | Author: Djmirko 9 | Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg 10 | License: https://creativecommons.org/licenses/by-sa/3.0/ 11 | 12 | ## Pre-trained Model 13 | 14 | The ImageNet pre-trained model was ported from [`torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2`](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights). 15 | 16 | As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)). 17 | -------------------------------------------------------------------------------- /mobilenetv2-burn/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 Burn 2 | 3 | [MobileNetV2](https://arxiv.org/abs/1801.04381) is a convolutional neural network architecture for 4 | classification tasks which seeks to perform well on mobile devices. You can find the 5 | [Burn](https://github.com/tracel-ai/burn) implementation for the MobileNetV2 in 6 | [src/model/mobilenetv2.rs](src/model/mobilenetv2.rs). 7 | 8 | The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). 9 | 10 | ## Usage 11 | 12 | ### `Cargo.toml` 13 | 14 | Add this to your `Cargo.toml`: 15 | 16 | ```toml 17 | [dependencies] 18 | mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", default-features = false } 19 | ``` 20 | 21 | If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag. 22 | 23 | ```toml 24 | [dependencies] 25 | mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", features = ["pretrained"] } 26 | ``` 27 | 28 | **Important:** this feature requires `std`. 29 | 30 | ### Example Usage 31 | 32 | The [inference example](examples/inference.rs) initializes a MobileNetV2 from the ImageNet 33 | [pre-trained weights](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights) 34 | with the `NdArray` backend and performs inference on the provided input image. 
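The snippet below condenses the core calls from that example (a sketch only; it assumes the `pretrained` feature and an input tensor `img` of shape `[1, 3, 224, 224]` with values scaled to `[0, 1]` — see [examples/inference.rs](examples/inference.rs) for the image loading and resizing code):

```rust
use burn::backend::NdArray;
use mobilenetv2_burn::model::{imagenet, mobilenetv2::MobileNetV2, weights};

fn classify(img: burn::tensor::Tensor<NdArray, 4>) -> usize {
    let device = Default::default();

    // Download (on first use) and load the ImageNet pre-trained weights.
    let model: MobileNetV2<NdArray> =
        MobileNetV2::pretrained(weights::MobileNetV2::ImageNet1kV2, &device).unwrap();

    // Normalize with the ImageNet statistics, then run the forward pass.
    let out = model.forward(imagenet::Normalizer::new(&device).normalize(img));

    // Index of the top-scoring class (see `imagenet::CLASSES` for the label).
    let (_score, idx) = out.max_dim_with_indices(1);
    idx.into_scalar() as usize
}
```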
35 | 36 | You can run the example with the following command: 37 | 38 | ```sh 39 | cargo run --release --features pretrained --example inference samples/dog.jpg 40 | ``` 41 | -------------------------------------------------------------------------------- /mobilenetv2-burn/examples/inference.rs: -------------------------------------------------------------------------------- 1 | use mobilenetv2_burn::model::{imagenet, mobilenetv2::MobileNetV2, weights}; 2 | 3 | use burn::{ 4 | backend::NdArray, 5 | tensor::{backend::Backend, Device, Element, Tensor, TensorData}, 6 | }; 7 | 8 | const HEIGHT: usize = 224; 9 | const WIDTH: usize = 224; 10 | 11 | fn to_tensor( 12 | data: Vec, 13 | shape: [usize; 3], 14 | device: &Device, 15 | ) -> Tensor { 16 | Tensor::::from_data( 17 | TensorData::new(data, shape).convert::(), 18 | device, 19 | ) 20 | // [H, W, C] -> [C, H, W] 21 | .permute([2, 0, 1]) 22 | / 255 // normalize between [0, 1] 23 | } 24 | 25 | pub fn main() { 26 | // Parse arguments 27 | let img_path = std::env::args().nth(1).expect("No image path provided"); 28 | 29 | // Create MobileNetV2 30 | let device = Default::default(); 31 | let model: MobileNetV2 = 32 | MobileNetV2::pretrained(weights::MobileNetV2::ImageNet1kV2, &device) 33 | .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}")) 34 | .unwrap(); 35 | 36 | // Load image 37 | let img = image::open(&img_path) 38 | .map_err(|err| format!("Failed to load image {img_path}.\nError: {err}")) 39 | .unwrap(); 40 | 41 | // Resize to 224x224 42 | let resized_img = img.resize_exact( 43 | WIDTH as u32, 44 | HEIGHT as u32, 45 | image::imageops::FilterType::Triangle, // also known as bilinear in 2D 46 | ); 47 | 48 | // Create tensor from image data 49 | let img_tensor = to_tensor( 50 | resized_img.into_rgb8().into_raw(), 51 | [HEIGHT, WIDTH, 3], 52 | &device, 53 | ) 54 | .unsqueeze::<4>(); // [B, C, H, W] 55 | 56 | // Normalize the image 57 | let x = imagenet::Normalizer::new(&device).normalize(img_tensor); 58 | 59 | // Forward pass 60 | let out = model.forward(x); 61 | 62 | // Output class index w/ score (raw) 63 | let (score, idx) = out.max_dim_with_indices(1); 64 | let idx = idx.into_scalar() as usize; 65 | 66 | println!( 67 | "Predicted: {}\nCategory Id: {}\nScore: {:.4}", 68 | imagenet::CLASSES[idx], 69 | idx, 70 | score.into_scalar() 71 | ); 72 | } 73 | -------------------------------------------------------------------------------- /mobilenetv2-burn/samples/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/mobilenetv2-burn/samples/dog.jpg -------------------------------------------------------------------------------- /mobilenetv2-burn/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(feature = "std"), no_std)] 2 | pub mod model; 3 | extern crate alloc; 4 | -------------------------------------------------------------------------------- /mobilenetv2-burn/src/model/conv_norm.rs: -------------------------------------------------------------------------------- 1 | use burn::{ 2 | config::Config, 3 | module::Module, 4 | nn::{ 5 | conv::{Conv2d, Conv2dConfig}, 6 | BatchNorm, BatchNormConfig, PaddingConfig2d, 7 | }, 8 | tensor::{self, backend::Backend, Tensor}, 9 | }; 10 | 11 | /// A rectified linear unit where the activation is limited to a maximum of 6. 
12 | #[derive(Module, Debug, Clone, Default)] 13 | pub struct ReLU6 {} 14 | impl ReLU6 { 15 | pub fn forward(&self, input: Tensor) -> Tensor { 16 | tensor::activation::relu(input).clamp_max(6) 17 | } 18 | } 19 | 20 | /// A Conv2d -> BatchNorm -> activation block. 21 | #[derive(Module, Debug)] 22 | pub struct Conv2dNormActivation { 23 | conv: Conv2d, 24 | norm: BatchNorm, 25 | activation: ReLU6, 26 | } 27 | 28 | /// [Conv2dNormActivation] configuration. 29 | #[derive(Config, Debug)] 30 | pub struct Conv2dNormActivationConfig { 31 | pub in_channels: usize, 32 | pub out_channels: usize, 33 | 34 | #[config(default = "3")] 35 | pub kernel_size: usize, 36 | 37 | #[config(default = "1")] 38 | pub stride: usize, 39 | 40 | #[config(default = "None")] 41 | pub padding: Option, 42 | 43 | #[config(default = "1")] 44 | pub groups: usize, 45 | 46 | #[config(default = "1")] 47 | pub dilation: usize, 48 | 49 | #[config(default = false)] 50 | pub bias: bool, 51 | } 52 | 53 | impl Conv2dNormActivationConfig { 54 | pub fn init(&self, device: &B::Device) -> Conv2dNormActivation { 55 | let padding = if let Some(padding) = self.padding { 56 | padding 57 | } else { 58 | (self.kernel_size - 1) / 2 * self.dilation 59 | }; 60 | 61 | Conv2dNormActivation { 62 | conv: Conv2dConfig::new( 63 | [self.in_channels, self.out_channels], 64 | [self.kernel_size, self.kernel_size], 65 | ) 66 | .with_padding(PaddingConfig2d::Explicit(padding, padding)) 67 | .with_stride([self.stride, self.stride]) 68 | .with_bias(self.bias) 69 | .with_dilation([self.dilation, self.dilation]) 70 | .with_groups(self.groups) 71 | .init(device), 72 | norm: BatchNormConfig::new(self.out_channels).init(device), 73 | activation: ReLU6 {}, 74 | } 75 | } 76 | } 77 | impl Conv2dNormActivation { 78 | pub fn forward(&self, input: Tensor) -> Tensor { 79 | let x = self.conv.forward(input); 80 | let x = self.norm.forward(x); 81 | self.activation.forward(x) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /mobilenetv2-burn/src/model/inverted_residual.rs: -------------------------------------------------------------------------------- 1 | use super::conv_norm::{Conv2dNormActivation, Conv2dNormActivationConfig}; 2 | use burn::config::Config; 3 | use burn::nn::conv::Conv2dConfig; 4 | use burn::nn::{BatchNorm, BatchNormConfig}; 5 | use burn::tensor::Tensor; 6 | use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend}; 7 | 8 | #[derive(Module, Debug)] 9 | pub struct PointWiseLinear { 10 | conv: Conv2d, 11 | norm: BatchNorm, 12 | } 13 | 14 | impl PointWiseLinear { 15 | pub fn forward(&self, x: Tensor) -> Tensor { 16 | self.norm.forward(self.conv.forward(x)) 17 | } 18 | } 19 | 20 | /// [Inverted Residual Block](https://paperswithcode.com/method/inverted-residual-block). 21 | #[derive(Module, Debug)] 22 | pub struct InvertedResidual { 23 | use_res_connect: bool, 24 | pw: Option>, // pointwise, only when expand ratio != 1 25 | dw: Conv2dNormActivation, 26 | pw_linear: PointWiseLinear, 27 | } 28 | 29 | /// [InvertedResidual](InvertedResidual) configuration. 30 | #[derive(Config, Debug)] 31 | pub struct InvertedResidualConfig { 32 | pub inp: usize, 33 | pub oup: usize, 34 | pub stride: usize, 35 | pub expand_ratio: usize, 36 | } 37 | 38 | impl InvertedResidualConfig { 39 | /// Initialize a new [InvertedResidual](InvertedResidual) module. 
40 | pub fn init(&self, device: &B::Device) -> InvertedResidual { 41 | let hidden_dim = self.inp * self.expand_ratio; 42 | let pw = if self.expand_ratio != 1 { 43 | Some( 44 | Conv2dNormActivationConfig::new(self.inp, hidden_dim) 45 | .with_kernel_size(1) 46 | .init(device), 47 | ) 48 | } else { 49 | None 50 | }; 51 | let dw = Conv2dNormActivationConfig::new(hidden_dim, hidden_dim) 52 | .with_stride(self.stride) 53 | .with_groups(hidden_dim) 54 | .init(device); 55 | let pw_linear = PointWiseLinear { 56 | conv: Conv2dConfig::new([hidden_dim, self.oup], [1, 1]) 57 | .with_stride([1, 1]) 58 | .with_padding(burn::nn::PaddingConfig2d::Explicit(0, 0)) 59 | .with_bias(false) 60 | .init(device), 61 | norm: BatchNormConfig::new(self.oup).init(device), 62 | }; 63 | InvertedResidual { 64 | use_res_connect: self.stride == 1 && self.inp == self.oup, 65 | pw_linear, 66 | dw, 67 | pw, 68 | } 69 | } 70 | } 71 | 72 | impl InvertedResidual { 73 | pub fn forward(&self, x: &Tensor) -> Tensor { 74 | let mut out = x.clone(); 75 | if let Some(pw) = &self.pw { 76 | out = pw.forward(out); 77 | } 78 | out = self.dw.forward(out); 79 | out = self.pw_linear.forward(out); 80 | 81 | if self.use_res_connect { 82 | out = out + x.clone(); 83 | } 84 | out 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /mobilenetv2-burn/src/model/mobilenetv2.rs: -------------------------------------------------------------------------------- 1 | use alloc::vec; 2 | use alloc::vec::Vec; 3 | use core::cmp::max; 4 | 5 | use burn::{ 6 | config::Config, 7 | module::Module, 8 | nn::{ 9 | pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, 10 | Dropout, DropoutConfig, Linear, LinearConfig, 11 | }, 12 | tensor::{backend::Backend, Tensor}, 13 | }; 14 | 15 | use super::{ 16 | conv_norm::{Conv2dNormActivation, Conv2dNormActivationConfig}, 17 | inverted_residual::{InvertedResidual, InvertedResidualConfig}, 18 | }; 19 | 20 | #[cfg(feature = "pretrained")] 21 | use { 22 | super::weights::{self, WeightsMeta}, 23 | burn::record::{FullPrecisionSettings, Recorder, RecorderError}, 24 | burn::tensor::Device, 25 | burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}, 26 | }; 27 | 28 | /// Network blocks structure 29 | const INVERTED_RESIDUAL_SETTINGS: [[usize; 4]; 7] = [ 30 | // (t = expansion factor; c = channels; n = num blocks; s = stride) 31 | // t, c, n, s 32 | [1, 16, 1, 1], 33 | [6, 24, 2, 2], 34 | [6, 32, 3, 2], 35 | [6, 64, 4, 2], 36 | [6, 96, 3, 1], 37 | [6, 160, 3, 2], 38 | [6, 320, 1, 1], 39 | ]; 40 | /// Round the number of channels in each layer to be a multiple of this number. 41 | const ROUND_NEAREST: usize = 8; 42 | 43 | #[derive(Debug, Module)] 44 | pub struct MobileNetV2 { 45 | features: Vec>, 46 | classifier: Classifier, 47 | avg_pool: AdaptiveAvgPool2d, 48 | } 49 | 50 | impl MobileNetV2 { 51 | pub fn forward(&self, input: Tensor) -> Tensor { 52 | let mut x = input; 53 | for layer in &self.features { 54 | match layer { 55 | ConvBlock::InvertedResidual(block) => { 56 | x = block.forward(&x); 57 | } 58 | ConvBlock::Conv(conv) => { 59 | x = conv.forward(x); 60 | } 61 | } 62 | } 63 | x = self.avg_pool.forward(x); 64 | // Reshape [B, C, 1, 1] -> [B, C] 65 | let x = x.flatten(1, 3); 66 | 67 | self.classifier.forward(x) 68 | } 69 | 70 | /// Load specified pre-trained PyTorch weights as a record. 
71 | #[cfg(feature = "pretrained")] 72 | fn load_weights_record( 73 | weights: &weights::Weights, 74 | device: &Device, 75 | ) -> Result, RecorderError> { 76 | // Download torch weights 77 | let torch_weights = weights.download().map_err(|err| { 78 | RecorderError::Unknown(format!("Could not download weights.\nError: {err}")) 79 | })?; 80 | 81 | // Load weights from torch state_dict 82 | let load_args = LoadArgs::new(torch_weights) 83 | // Map features.{0,18}.0.* -> features.{0,18}.conv.* 84 | .with_key_remap("features\\.(0|18)\\.0.(.+)", "features.$1.conv.$2") 85 | // Map features.{0,18}.1.* -> features.{0,18}.norm.* 86 | .with_key_remap("features\\.(0|18)\\.1.(.+)", "features.$1.norm.$2") 87 | // Map features.1.conv.0.0.* -> features.1.dw.conv.* 88 | .with_key_remap("features\\.1\\.conv.0.0.(.+)", "features.1.dw.conv.$1") 89 | // Map features.1.conv.0.1.* -> features.1.dw.conv.* 90 | .with_key_remap("features\\.1\\.conv.0.1.(.+)", "features.1.dw.norm.$1") 91 | // Map features.1.conv.1.* -> features.1.pw_linear.conv.* 92 | .with_key_remap("features\\.1\\.conv.1.(.+)", "features.1.pw_linear.conv.$1") 93 | // Map features.1.conv.2.* -> features.1.pw_linear.norm.* 94 | .with_key_remap("features\\.1\\.conv.2.(.+)", "features.1.pw_linear.norm.$1") 95 | // Map features.[i].conv.0.0.* -> features.[i].pw.conv.* 96 | .with_key_remap( 97 | "features\\.([2-9]|1[0-7])\\.conv.0.0.(.+)", // for i in [2, 17] 98 | "features.$1.pw.conv.$2", 99 | ) 100 | // Map features.[i].conv.0.1.* -> features.[i].pw.conv.* 101 | .with_key_remap( 102 | "features\\.([2-9]|1[0-7])\\.conv.0.1.(.+)", // for i in [2, 17] 103 | "features.$1.pw.norm.$2", 104 | ) 105 | // Map features.[i].conv.1.0.* -> features.[i].dw.conv.* 106 | .with_key_remap( 107 | "features\\.([2-9]|1[0-7])\\.conv.1.0.(.+)", // for i in [2, 17] 108 | "features.$1.dw.conv.$2", 109 | ) 110 | // Map features.[i].conv.1.1.* -> features.[i].dw.norm.* 111 | .with_key_remap( 112 | "features\\.([2-9]|1[0-7])\\.conv.1.1.(.+)", // for i in [2, 17] 113 | "features.$1.dw.norm.$2", 114 | ) 115 | // Map features.[i].conv.2.* -> features.[i].pw_linear.conv.* 116 | .with_key_remap( 117 | "features\\.([2-9]|1[0-7])\\.conv.2.(.+)", // for i in [2, 17] 118 | "features.$1.pw_linear.conv.$2", 119 | ) 120 | // Map features.[i].conv.3.* -> features.[i].pw_linear.norm.* 121 | .with_key_remap( 122 | "features\\.([2-9]|1[0-7])\\.conv.3.(.+)", // for i in [2, 17] 123 | "features.$1.pw_linear.norm.$2", 124 | ) 125 | // Map classifier.1.* -> classifier.linear.* 126 | .with_key_remap("classifier.1.(.+)", "classifier.linear.$1"); 127 | let record = PyTorchFileRecorder::::new().load(load_args, device)?; 128 | 129 | Ok(record) 130 | } 131 | 132 | /// MobileNetV2 from [`MobileNetV2: Inverted Residuals and Linear Bottlenecks`](https://arxiv.org/abs/1801.04381) 133 | /// with pre-trained weights. 134 | /// 135 | /// # Arguments 136 | /// 137 | /// * `weights`: Pre-trained weights to load. 138 | /// * `device` - Device to create the module on. 139 | /// 140 | /// # Returns 141 | /// 142 | /// A MobileNetV2 module with pre-trained weights. 
143 | #[cfg(feature = "pretrained")] 144 | pub fn pretrained( 145 | weights: weights::MobileNetV2, 146 | device: &Device, 147 | ) -> Result { 148 | let weights = weights.weights(); 149 | let record = Self::load_weights_record(&weights, device)?; 150 | let model = MobileNetV2Config::new() 151 | .with_num_classes(weights.num_classes) 152 | .init(device) 153 | .load_record(record); 154 | 155 | Ok(model) 156 | } 157 | } 158 | 159 | #[allow(clippy::large_enum_variant)] 160 | #[derive(Module, Debug)] 161 | enum ConvBlock { 162 | InvertedResidual(InvertedResidual), 163 | Conv(Conv2dNormActivation), 164 | } 165 | 166 | #[derive(Module, Debug)] 167 | struct Classifier { 168 | dropout: Dropout, 169 | linear: Linear, 170 | } 171 | impl Classifier { 172 | fn forward(&self, input: Tensor) -> Tensor { 173 | let x = self.dropout.forward(input); 174 | self.linear.forward(x) 175 | } 176 | } 177 | 178 | /// MobileNetV2 from [`MobileNetV2: Inverted Residuals and Linear Bottlenecks`](https://arxiv.org/abs/1801.04381). 179 | #[derive(Debug, Config)] 180 | pub struct MobileNetV2Config { 181 | #[config(default = "1000")] 182 | num_classes: usize, 183 | 184 | #[config(default = "1.0")] 185 | width_mult: f32, 186 | 187 | #[config(default = "0.2")] 188 | dropout: f64, 189 | } 190 | 191 | impl MobileNetV2Config { 192 | /// Initialize a MobileNetV2 from 193 | /// [`MobileNetV2: Inverted Residuals and Linear Bottlenecks`](https://arxiv.org/abs/1801.04381). 194 | /// 195 | /// # Arguments 196 | /// 197 | /// * `device` - Device to create the module on. 198 | /// 199 | /// # Returns 200 | /// 201 | /// A MobileNetV2 module. 202 | pub fn init(&self, device: &B::Device) -> MobileNetV2 { 203 | let input_channel = 32; 204 | let last_channel = 1280; 205 | 206 | let make_divisible = |v, divisor| { 207 | let new_v = (v + divisor as f32 / 2.0) as usize / divisor * divisor; 208 | let mut new_v = max(new_v, divisor); 209 | 210 | // Make sure that round down does not go down by more than 10% 211 | if (new_v as f32) < 0.9 * v { 212 | new_v += divisor; 213 | } 214 | 215 | new_v 216 | }; 217 | 218 | let mut input_channel = 219 | make_divisible(input_channel as f32 * self.width_mult, ROUND_NEAREST); 220 | let last_channel = make_divisible( 221 | last_channel as f32 * f32::max(1.0, self.width_mult), 222 | ROUND_NEAREST, 223 | ); 224 | 225 | // Feature extraction layers with inverted residual blocks 226 | let mut features = vec![ConvBlock::Conv( 227 | Conv2dNormActivationConfig::new(3, input_channel) 228 | .with_kernel_size(3) 229 | .with_stride(2) 230 | .init(device), 231 | )]; 232 | for [t, c, n, s] in INVERTED_RESIDUAL_SETTINGS.into_iter() { 233 | let output_channel = make_divisible(c as f32 * self.width_mult, ROUND_NEAREST); 234 | for i in 0..n { 235 | let stride = if i == 0 { s } else { 1 }; 236 | features.push(ConvBlock::InvertedResidual( 237 | InvertedResidualConfig::new(input_channel, output_channel, stride, t) 238 | .init(device), 239 | )); 240 | input_channel = output_channel; 241 | } 242 | } 243 | features.push(ConvBlock::Conv( 244 | Conv2dNormActivationConfig::new(input_channel, last_channel) 245 | .with_kernel_size(1) 246 | .init(device), 247 | )); 248 | 249 | let classifier = Classifier { 250 | dropout: DropoutConfig::new(self.dropout).init(), 251 | linear: LinearConfig::new(last_channel, self.num_classes).init(device), 252 | }; 253 | 254 | MobileNetV2 { 255 | features, 256 | classifier, 257 | avg_pool: AdaptiveAvgPool2dConfig::new([1, 1]).init(), 258 | } 259 | } 260 | } 261 | 
-------------------------------------------------------------------------------- /mobilenetv2-burn/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | mod conv_norm; 2 | pub mod imagenet; 3 | mod inverted_residual; 4 | pub mod mobilenetv2; 5 | #[cfg(feature = "pretrained")] 6 | pub mod weights; 7 | -------------------------------------------------------------------------------- /mobilenetv2-burn/src/model/weights.rs: -------------------------------------------------------------------------------- 1 | /// Pre-trained weights metadata. 2 | pub struct Weights { 3 | pub(super) url: &'static str, 4 | pub(super) num_classes: usize, 5 | } 6 | 7 | #[cfg(feature = "pretrained")] 8 | mod downloader { 9 | use super::*; 10 | use burn::data::network::downloader; 11 | use std::fs::{create_dir_all, File}; 12 | use std::io::Write; 13 | use std::path::PathBuf; 14 | 15 | impl Weights { 16 | /// Download the pre-trained weights to the local cache directory. 17 | pub fn download(&self) -> Result { 18 | // Model cache directory 19 | let model_dir = dirs::home_dir() 20 | .expect("Should be able to get home directory") 21 | .join(".cache") 22 | .join("mobilenetv2-burn"); 23 | 24 | if !model_dir.exists() { 25 | create_dir_all(&model_dir)?; 26 | } 27 | 28 | let file_base_name = self.url.rsplit_once('/').unwrap().1; 29 | let file_name = model_dir.join(file_base_name); 30 | if !file_name.exists() { 31 | // Download file content 32 | let bytes = downloader::download_file_as_bytes(self.url, file_base_name); 33 | 34 | // Write content to file 35 | let mut output_file = File::create(&file_name)?; 36 | let bytes_written = output_file.write(&bytes)?; 37 | 38 | if bytes_written != bytes.len() { 39 | return Err(std::io::Error::new( 40 | std::io::ErrorKind::InvalidData, 41 | "Failed to write the whole model weights file.", 42 | )); 43 | } 44 | } 45 | 46 | Ok(file_name) 47 | } 48 | } 49 | } 50 | 51 | pub trait WeightsMeta { 52 | fn weights(&self) -> Weights; 53 | } 54 | 55 | /// MobileNetV2 pre-trained weights. 56 | pub enum MobileNetV2 { 57 | /// These weights improve upon the results of the original paper with a new training 58 | /// [recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives). 59 | /// Top-1 accuracy: 72.154%. 60 | /// Top-5 accuracy: 90.822%. 
61 | ImageNet1kV2, 62 | } 63 | impl WeightsMeta for MobileNetV2 { 64 | fn weights(&self) -> Weights { 65 | let url = match *self { 66 | MobileNetV2::ImageNet1kV2 => { 67 | "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" 68 | } 69 | }; 70 | Weights { 71 | url, 72 | num_classes: 1000, 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /resnet-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | # Try 3 | # require version 2 to avoid "feature" additiveness for dev-dependencies 4 | # https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 5 | resolver = "2" 6 | 7 | members = [ 8 | "resnet", 9 | "examples/*", 10 | ] 11 | 12 | [workspace.package] 13 | edition = "2021" 14 | version = "0.2.0" 15 | readme = "README.md" 16 | license = "MIT OR Apache-2.0" 17 | 18 | [workspace.dependencies] 19 | # Note: default-features = false is needed to disable std 20 | burn = { version = "0.16.0", default-features = false } 21 | burn-import = "0.16.0" 22 | dirs = "5.0.1" 23 | serde = { version = "1.0.192", default-features = false, features = [ 24 | "derive", 25 | "alloc", 26 | ] } # alloc is for no_std, derive is needed 27 | -------------------------------------------------------------------------------- /resnet-burn/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /resnet-burn/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /resnet-burn/NOTICES.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. 4 | 5 | ## Sample Image 6 | 7 | Image Title: Standing yellow Labrador Retriever dog. 8 | Author: Djmirko 9 | Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg 10 | License: https://creativecommons.org/licenses/by-sa/3.0/ 11 | 12 | ## Pre-trained Model 13 | 14 | The ImageNet pre-trained model was ported from [`torchvision.models.ResNet18_Weights.IMAGENET1K_V1`](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights). 15 | 16 | As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)). 17 | -------------------------------------------------------------------------------- /resnet-burn/README.md: -------------------------------------------------------------------------------- 1 | # ResNet Burn 2 | 3 | To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image 4 | classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for 5 | the ResNet variants in [resnet.rs](resnet/src/resnet.rs). 
6 | 7 | The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). 8 | 9 | ## Usage 10 | 11 | ### `Cargo.toml` 12 | 13 | Add this to your `Cargo.toml`: 14 | 15 | ```toml 16 | [dependencies] 17 | resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", default-features = false } 18 | ``` 19 | 20 | If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag. 21 | 22 | ```toml 23 | [dependencies] 24 | resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", features = ["pretrained"] } 25 | ``` 26 | 27 | **Important:** this feature requires `std`. 28 | 29 | ### Example Usage 30 | 31 | #### Inference 32 | 33 | The [inference example](examples/inference/examples/inference.rs) initializes a ResNet-18 from the 34 | ImageNet 35 | [pre-trained weights](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights) 36 | with the `NdArray` backend and performs inference on the provided input image. 37 | 38 | You can run the example with the following command: 39 | 40 | ```sh 41 | cargo run --release --example inference samples/dog.jpg 42 | ``` 43 | 44 | #### Fine-tuning 45 | 46 | For this [multi-label image classification fine-tuning example](examples/finetune), a sample of the 47 | planets dataset from the Kaggle competition 48 | [Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space) 49 | is downloaded from a 50 | [fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). The 51 | sample dataset is a collection of satellite images with multiple labels describing the scene, as 52 | illustrated below. 53 | 54 | Planet dataset sample 55 | 56 | To achieve this task, a ResNet-18 pre-trained on the ImageNet dataset is fine-tuned on the target 57 | planets dataset. The training recipe used is fairly simple. The main objective is to demonstrate how to re-use a 58 | pre-trained model for a different downstream task. 59 | 60 | Without any bells and whistle, our model achieves over 90% multi-label accuracy (i.e., hamming 61 | score) on the validation set after just 5 epochs. 
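For clarity, one common definition of the hamming score — the one assumed in the standalone sketch below — is the fraction of individual label positions predicted correctly, averaged over the validation samples. This is only an illustration of the metric; the training example itself relies on Burn's built-in metrics.

```rust
/// Fraction of label positions predicted correctly, averaged over all samples
/// (one common definition of the multi-label "hamming score").
fn hamming_score(predictions: &[Vec<bool>], targets: &[Vec<bool>]) -> f64 {
    let (mut correct, mut total) = (0usize, 0usize);
    for (pred, target) in predictions.iter().zip(targets) {
        for (p, t) in pred.iter().zip(target) {
            correct += (p == t) as usize;
            total += 1;
        }
    }
    correct as f64 / total as f64
}
```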
62 | 63 | Run the example with the Torch GPU backend: 64 | 65 | ```sh 66 | export TORCH_CUDA_VERSION=cu121 67 | cargo run --release --example finetune --features tch-gpu 68 | ``` 69 | 70 | Run it with our WGPU backend: 71 | 72 | ```sh 73 | cargo run --release --example finetune --features wgpu 74 | ``` 75 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/.gitignore: -------------------------------------------------------------------------------- 1 | # Downloaded files 2 | planet_sample/ -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["guillaumelagrange "] 3 | name = "finetune" 4 | license.workspace = true 5 | version.workspace = true 6 | edition.workspace = true 7 | 8 | [features] 9 | default = ["burn/default"] 10 | tch-gpu = ["burn/tch"] 11 | wgpu = ["burn/wgpu"] 12 | 13 | [dependencies] 14 | resnet-burn = { path = "../../resnet", features = ["pretrained"] } 15 | burn = { workspace = true, features = ["train", "vision", "network"] } 16 | 17 | # Dataset files 18 | csv = "1.3.0" 19 | flate2 = "1.0.28" 20 | rand = { version = "0.8.5", default-features = false, features = [ 21 | "std_rng", 22 | ] } 23 | serde = { version = "1.0.192", features = ["std", "derive"] } 24 | tar = "0.4.40" -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/examples/finetune.rs: -------------------------------------------------------------------------------- 1 | use burn::{backend::Autodiff, tensor::backend::Backend}; 2 | use finetune::{inference::infer, training::train}; 3 | 4 | #[allow(dead_code)] 5 | const ARTIFACT_DIR: &str = "/tmp/resnet-finetune"; 6 | 7 | #[allow(dead_code)] 8 | fn run(device: B::Device) { 9 | train::>(ARTIFACT_DIR, device.clone()); 10 | infer::(ARTIFACT_DIR, device, 0.5); 11 | } 12 | 13 | #[cfg(feature = "tch-gpu")] 14 | mod tch_gpu { 15 | use burn::backend::libtorch::{LibTorch, LibTorchDevice}; 16 | 17 | pub fn run() { 18 | #[cfg(not(target_os = "macos"))] 19 | let device = LibTorchDevice::Cuda(0); 20 | #[cfg(target_os = "macos")] 21 | let device = LibTorchDevice::Mps; 22 | 23 | super::run::(device); 24 | } 25 | } 26 | 27 | #[cfg(feature = "wgpu")] 28 | mod wgpu { 29 | use burn::backend::wgpu::{Wgpu, WgpuDevice}; 30 | 31 | pub fn run() { 32 | super::run::(WgpuDevice::default()); 33 | } 34 | } 35 | 36 | fn main() { 37 | #[cfg(feature = "tch-gpu")] 38 | tch_gpu::run(); 39 | #[cfg(feature = "wgpu")] 40 | wgpu::run(); 41 | } 42 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/src/data.rs: -------------------------------------------------------------------------------- 1 | use burn::{ 2 | data::{ 3 | dataloader::batcher::Batcher, 4 | dataset::vision::{Annotation, ImageDatasetItem, PixelDepth}, 5 | }, 6 | prelude::*, 7 | }; 8 | 9 | use super::dataset::CLASSES; 10 | 11 | // ImageNet mean and std values 12 | const MEAN: [f32; 3] = [0.485, 0.456, 0.406]; 13 | const STD: [f32; 3] = [0.229, 0.224, 0.225]; 14 | 15 | // Planets patch size 16 | const WIDTH: usize = 256; 17 | const HEIGHT: usize = 256; 18 | 19 | /// Create a multi-hot encoded tensor. 
20 | /// 21 | /// # Example 22 | /// 23 | /// ```rust, ignore 24 | /// let multi_hot = multi_hot::(&[2, 5, 8], 10, &device); 25 | /// println!("{}", multi_hot.to_data()); 26 | /// // [0, 0, 1, 0, 0, 1, 0, 0, 1, 0] 27 | /// ``` 28 | pub fn multi_hot( 29 | indices: &[usize], 30 | num_classes: usize, 31 | device: &B::Device, 32 | ) -> Tensor { 33 | Tensor::zeros(Shape::new([num_classes]), device).scatter( 34 | 0, 35 | Tensor::from_ints( 36 | indices 37 | .iter() 38 | .map(|i| *i as i32) 39 | .collect::>() 40 | .as_slice(), 41 | device, 42 | ), 43 | Tensor::ones(Shape::new([indices.len()]), device), 44 | ) 45 | } 46 | 47 | /// Normalizer with ImageNet values as it helps accelerate training since we are fine-tuning from 48 | /// ImageNet pre-trained weights and the model expects the data to be in this normalized range. 49 | #[derive(Clone)] 50 | pub struct Normalizer { 51 | pub mean: Tensor, 52 | pub std: Tensor, 53 | } 54 | 55 | impl Normalizer { 56 | /// Creates a new normalizer. 57 | pub fn new(device: &Device) -> Self { 58 | let mean = Tensor::::from_floats(MEAN, device).reshape([1, 3, 1, 1]); 59 | let std = Tensor::::from_floats(STD, device).reshape([1, 3, 1, 1]); 60 | Self { mean, std } 61 | } 62 | 63 | /// Normalizes the input image according to the ImageNet dataset. 64 | /// 65 | /// The input image should be in the range [0, 1]. 66 | /// The output image will be in the range [-1, 1]. 67 | /// 68 | /// The normalization is done according to the following formula: 69 | /// `input = (input - mean) / std` 70 | pub fn normalize(&self, input: Tensor) -> Tensor { 71 | (input - self.mean.clone()) / self.std.clone() 72 | } 73 | } 74 | 75 | #[derive(Clone)] 76 | pub struct ClassificationBatcher { 77 | normalizer: Normalizer, 78 | device: B::Device, 79 | } 80 | 81 | #[derive(Clone, Debug)] 82 | pub struct ClassificationBatch { 83 | pub images: Tensor, 84 | pub targets: Tensor, 85 | } 86 | 87 | impl ClassificationBatcher { 88 | pub fn new(device: B::Device) -> Self { 89 | Self { 90 | normalizer: Normalizer::::new(&device), 91 | device, 92 | } 93 | } 94 | } 95 | 96 | impl Batcher> for ClassificationBatcher { 97 | fn batch(&self, items: Vec) -> ClassificationBatch { 98 | fn image_as_vec_u8(item: ImageDatasetItem) -> Vec { 99 | // Convert Vec to Vec (Planet images are u8) 100 | item.image 101 | .into_iter() 102 | .map(|p: PixelDepth| -> u8 { p.try_into().unwrap() }) 103 | .collect::>() 104 | } 105 | 106 | let targets = items 107 | .iter() 108 | .map(|item| { 109 | // Expect multi-hot encoded class labels as target (e.g., [0, 1, 0, 0, 1]) 110 | if let Annotation::MultiLabel(y) = &item.annotation { 111 | multi_hot(y, CLASSES.len(), &self.device) 112 | } else { 113 | panic!("Invalid target type") 114 | } 115 | }) 116 | .collect(); 117 | 118 | let images = items 119 | .into_iter() 120 | .map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3]))) 121 | .map(|data| Tensor::::from_data(data.convert::(), &self.device)) 122 | .map(|tensor| tensor.permute([2, 0, 1]) / 255) // normalize between [0, 1] 123 | .collect(); 124 | 125 | let images = Tensor::stack(images, 0); 126 | let targets = Tensor::stack(targets, 0); 127 | 128 | let images = self.normalizer.normalize(images); 129 | 130 | ClassificationBatch { images, targets } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/src/dataset.rs: -------------------------------------------------------------------------------- 1 | use flate2::read::GzDecoder; 2 | 
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; 3 | use serde::{Deserialize, Serialize}; 4 | use std::{ 5 | collections::HashSet, 6 | path::{Path, PathBuf}, 7 | }; 8 | use tar::Archive; 9 | 10 | use burn::data::{ 11 | dataset::vision::{ImageFolderDataset, ImageLoaderError}, 12 | network::downloader, 13 | }; 14 | 15 | /// Planets dataset sample mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). 16 | /// Licensed under the [Appache License](https://github.com/fastai/fastai/blob/master/LICENSE). 17 | const URL: &str = "https://s3.amazonaws.com/fast-ai-sample/planet_sample.tgz"; 18 | const LABELS: &str = "labels.csv"; 19 | pub const CLASSES: [&str; 17] = [ 20 | "agriculture", 21 | "artisinal_mine", 22 | "bare_ground", 23 | "blooming", 24 | "blow_down", 25 | "clear", 26 | "cloudy", 27 | "conventional_mine", 28 | "cultivation", 29 | "habitation", 30 | "haze", 31 | "partly_cloudy", 32 | "primary", 33 | "road", 34 | "selective_logging", 35 | "slash_burn", 36 | "water", 37 | ]; 38 | 39 | /// A sample of the planets dataset from the Kaggle competition 40 | /// [Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space). 41 | /// 42 | /// This version of the multi-label classification dataset contains 1,000 256x256 image patches 43 | /// with possibly multiple labels per patch. The labels can broadly be broken into three groups: 44 | /// atmospheric conditions, common land cover/land use phenomena, and rare land cover/land use 45 | /// phenomena. Each patch will have one and potentially more than one atmospheric label and zero 46 | /// or more common and rare labels. 47 | /// 48 | /// The data is downloaded from the web from the [fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). 49 | pub trait PlanetLoader: Sized { 50 | fn planet_train_val_split( 51 | train_percentage: u8, 52 | seed: u64, 53 | ) -> Result<(Self, Self), ImageLoaderError>; 54 | } 55 | 56 | #[derive(Deserialize, Serialize, Debug, Clone)] 57 | struct PlanetSample { 58 | image_name: String, 59 | tags: String, 60 | } 61 | 62 | impl PlanetLoader for ImageFolderDataset { 63 | /// Creates new Planet dataset for train and validation splits. 64 | /// 65 | /// # Arguments 66 | /// 67 | /// * `train_percentage` - Percentage of the training split. The remainder will be used for the validation split. 68 | /// * `seed` - Controls the shuffling applied to the data before applying the split. 
69 | /// 70 | fn planet_train_val_split( 71 | train_percentage: u8, 72 | seed: u64, 73 | ) -> Result<(Self, Self), ImageLoaderError> { 74 | assert!( 75 | train_percentage > 0 && train_percentage < 100, 76 | "Training split percentage must be between (0, 100)" 77 | ); 78 | let root = download(); 79 | 80 | // Load items from csv 81 | let mut rdr = csv::ReaderBuilder::new() 82 | .from_path(root.join(LABELS)) 83 | .map_err(|err| ImageLoaderError::Unknown(err.to_string()))?; 84 | 85 | // Collect items (image path, labels) 86 | let mut classes = HashSet::new(); 87 | let mut items = rdr 88 | .deserialize() 89 | .map(|result| { 90 | let item: PlanetSample = 91 | result.map_err(|err| ImageLoaderError::Unknown(err.to_string()))?; 92 | let tags = item 93 | .tags 94 | .split(' ') 95 | .map(|s| s.to_string()) 96 | .collect::>(); 97 | 98 | for tag in tags.iter() { 99 | classes.insert(tag.clone()); 100 | } 101 | 102 | Ok(( 103 | // Full path to image 104 | root.join("train") 105 | .join(item.image_name) 106 | .with_extension("jpg"), 107 | // Multiple labels per image (e.g., ["haze", "primary", "water"]) 108 | tags, 109 | )) 110 | }) 111 | .collect::, _>>()?; 112 | 113 | // Sort class names 114 | let mut classes = classes.iter().collect::>(); 115 | classes.sort(); 116 | assert_eq!(classes, CLASSES, "Invalid categories"); // just in case the labels unexpectedly change 117 | 118 | // Shuffle items 119 | items.shuffle(&mut StdRng::seed_from_u64(seed)); 120 | 121 | // Split train and validation 122 | let size = items.len(); 123 | let train_slice = (size as f32 * (train_percentage as f32 / 100.0)) as usize; 124 | 125 | let train = Self::new_multilabel_classification_with_items( 126 | items[..train_slice].to_vec(), 127 | &classes, 128 | )?; 129 | let valid = Self::new_multilabel_classification_with_items( 130 | items[train_slice..].to_vec(), 131 | &classes, 132 | )?; 133 | 134 | Ok((train, valid)) 135 | } 136 | } 137 | 138 | /// Download the Planet dataset from the web to the current example directory. 
139 | fn download() -> PathBuf { 140 | // Point to current example directory 141 | let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap(); 142 | let planet_dir = example_dir.join("planet_sample"); 143 | 144 | // Check for already downloaded content 145 | let labels_file = planet_dir.join(LABELS); 146 | if !labels_file.exists() { 147 | // Download gzip file 148 | let bytes = downloader::download_file_as_bytes(URL, "planet_sample.tgz"); 149 | 150 | // Decode gzip file content and unpack archive 151 | let gz_buffer = GzDecoder::new(&bytes[..]); 152 | let mut archive = Archive::new(gz_buffer); 153 | archive.unpack(example_dir).unwrap(); 154 | } 155 | 156 | planet_dir 157 | } 158 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/src/inference.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | data::ClassificationBatcher, 3 | dataset::{PlanetLoader, CLASSES}, 4 | training::TrainingConfig, 5 | }; 6 | use burn::{ 7 | data::{ 8 | dataloader::batcher::Batcher, 9 | dataset::{ 10 | vision::{Annotation, ImageFolderDataset}, 11 | Dataset, 12 | }, 13 | }, 14 | prelude::*, 15 | record::{CompactRecorder, Recorder}, 16 | tensor::activation::sigmoid, 17 | }; 18 | use resnet_burn::ResNet; 19 | 20 | pub fn infer(artifact_dir: &str, device: B::Device, threshold: f32) { 21 | // Load trained ResNet-18 22 | let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) 23 | .expect("Config should exist for the model"); 24 | let record = CompactRecorder::new() 25 | .load(format!("{artifact_dir}/model").into(), &device) 26 | .expect("Trained model should exist"); 27 | 28 | let model: ResNet = ResNet::resnet18(config.num_classes, &device).load_record(record); 29 | 30 | // Get an item from validation split with multiple labels 31 | let (_train, valid) = 32 | ImageFolderDataset::planet_train_val_split(config.train_percentage, config.seed).unwrap(); 33 | let item = valid.get(20).unwrap(); 34 | 35 | let label = if let Annotation::MultiLabel(ref categories) = item.annotation { 36 | categories.iter().map(|&i| CLASSES[i]).collect::>() 37 | } else { 38 | panic!("Annotation should be multilabel") 39 | }; 40 | 41 | // Forward pass with sigmoid activation function 42 | let batcher = ClassificationBatcher::new(device); 43 | let batch = batcher.batch(vec![item]); 44 | let output = sigmoid(model.forward(batch.images)); 45 | 46 | // Get predicted class names over the specified threshold 47 | let predicted = output.greater_equal_elem(threshold).nonzero()[1] 48 | .to_data() 49 | .iter::() 50 | .map(|i| CLASSES[i.elem::() as usize]) 51 | .collect::>(); 52 | 53 | println!("Predicted: {:?}\nExpected: {:?}", predicted, label); 54 | } 55 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod data; 2 | pub mod dataset; 3 | pub mod inference; 4 | pub mod training; 5 | -------------------------------------------------------------------------------- /resnet-burn/examples/finetune/src/training.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use crate::{ 4 | data::{ClassificationBatch, ClassificationBatcher}, 5 | dataset::{PlanetLoader, CLASSES}, 6 | }; 7 | use burn::{ 8 | data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset}, 9 | 
nn::loss::BinaryCrossEntropyLossConfig, 10 | optim::{decay::WeightDecayConfig, AdamConfig}, 11 | prelude::*, 12 | record::CompactRecorder, 13 | tensor::backend::AutodiffBackend, 14 | train::{ 15 | metric::{HammingScore, LossMetric}, 16 | LearnerBuilder, MultiLabelClassificationOutput, TrainOutput, TrainStep, ValidStep, 17 | }, 18 | }; 19 | use resnet_burn::{weights, ResNet}; 20 | 21 | const NUM_CLASSES: usize = CLASSES.len(); 22 | 23 | pub trait MultiLabelClassification { 24 | fn forward_classification( 25 | &self, 26 | images: Tensor, 27 | targets: Tensor, 28 | ) -> MultiLabelClassificationOutput; 29 | } 30 | 31 | impl MultiLabelClassification for ResNet { 32 | fn forward_classification( 33 | &self, 34 | images: Tensor, 35 | targets: Tensor, 36 | ) -> MultiLabelClassificationOutput { 37 | let output = self.forward(images); 38 | let loss = BinaryCrossEntropyLossConfig::new() 39 | .with_logits(true) 40 | .init(&output.device()) 41 | .forward(output.clone(), targets.clone()); 42 | 43 | MultiLabelClassificationOutput::new(loss, output, targets) 44 | } 45 | } 46 | 47 | impl TrainStep, MultiLabelClassificationOutput> 48 | for ResNet 49 | { 50 | fn step( 51 | &self, 52 | batch: ClassificationBatch, 53 | ) -> TrainOutput> { 54 | let item = self.forward_classification(batch.images, batch.targets); 55 | 56 | TrainOutput::new(self, item.loss.backward(), item) 57 | } 58 | } 59 | 60 | impl ValidStep, MultiLabelClassificationOutput> 61 | for ResNet 62 | { 63 | fn step(&self, batch: ClassificationBatch) -> MultiLabelClassificationOutput { 64 | self.forward_classification(batch.images, batch.targets) 65 | } 66 | } 67 | 68 | #[derive(Config)] 69 | pub struct TrainingConfig { 70 | #[config(default = 5)] 71 | pub num_epochs: usize, 72 | 73 | #[config(default = 128)] 74 | pub batch_size: usize, 75 | 76 | #[config(default = 4)] 77 | pub num_workers: usize, 78 | 79 | #[config(default = 42)] 80 | pub seed: u64, 81 | 82 | #[config(default = 1e-3)] 83 | pub learning_rate: f64, 84 | 85 | #[config(default = 5e-5)] 86 | pub weight_decay: f32, 87 | 88 | #[config(default = 70)] 89 | pub train_percentage: u8, 90 | 91 | pub num_classes: usize, 92 | } 93 | 94 | fn create_artifact_dir(artifact_dir: &str) { 95 | // Remove existing artifacts before to get an accurate learner summary 96 | std::fs::remove_dir_all(artifact_dir).ok(); 97 | std::fs::create_dir_all(artifact_dir).ok(); 98 | } 99 | 100 | pub fn train(artifact_dir: &str, device: B::Device) { 101 | create_artifact_dir(artifact_dir); 102 | 103 | // Config 104 | let config = TrainingConfig::new(NUM_CLASSES); 105 | let optimizer = AdamConfig::new() 106 | .with_weight_decay(Some(WeightDecayConfig::new(config.weight_decay))) 107 | .init(); 108 | 109 | config 110 | .save(format!("{artifact_dir}/config.json")) 111 | .expect("Config should be saved successfully"); 112 | 113 | B::seed(config.seed); 114 | 115 | // Dataloaders 116 | let batcher_train = ClassificationBatcher::::new(device.clone()); 117 | let batcher_valid = ClassificationBatcher::::new(device.clone()); 118 | 119 | let (train, valid) = 120 | ImageFolderDataset::planet_train_val_split(config.train_percentage, config.seed).unwrap(); 121 | 122 | let dataloader_train = DataLoaderBuilder::new(batcher_train) 123 | .batch_size(config.batch_size) 124 | .shuffle(config.seed) 125 | .num_workers(config.num_workers) 126 | .build(train); 127 | 128 | let dataloader_test = DataLoaderBuilder::new(batcher_valid) 129 | .batch_size(config.batch_size) 130 | .num_workers(config.num_workers) 131 | .build(valid); 132 | 133 | // 
Pre-trained ResNet-18 adapted for num_classes in this task 134 | let model = ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device) 135 | .unwrap() 136 | .with_classes(NUM_CLASSES); 137 | 138 | // Learner config 139 | let learner = LearnerBuilder::new(artifact_dir) 140 | .metric_train_numeric(HammingScore::new()) 141 | .metric_valid_numeric(HammingScore::new()) 142 | .metric_train_numeric(LossMetric::new()) 143 | .metric_valid_numeric(LossMetric::new()) 144 | .with_file_checkpointer(CompactRecorder::new()) 145 | .devices(vec![device.clone()]) 146 | .num_epochs(config.num_epochs) 147 | .summary() 148 | .build(model, optimizer, config.learning_rate); 149 | 150 | // Training 151 | let now = Instant::now(); 152 | let model_trained = learner.fit(dataloader_train, dataloader_test); 153 | let elapsed = now.elapsed().as_secs(); 154 | println!("Training completed in {}m{}s", (elapsed / 60), elapsed % 60); 155 | 156 | model_trained 157 | .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) 158 | .expect("Trained model should be saved successfully"); 159 | } 160 | -------------------------------------------------------------------------------- /resnet-burn/examples/inference/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["guillaumelagrange "] 3 | name = "inference" 4 | license.workspace = true 5 | version.workspace = true 6 | edition.workspace = true 7 | publish = false 8 | 9 | [dependencies] 10 | resnet-burn = { path = "../../resnet", features = ["pretrained"] } 11 | burn = { workspace = true, features = ["ndarray"] } 12 | image = { version = "0.24.9", features = ["png", "jpeg"] } 13 | 14 | -------------------------------------------------------------------------------- /resnet-burn/examples/inference/examples/inference.rs: -------------------------------------------------------------------------------- 1 | use inference::imagenet; 2 | use resnet_burn::{weights, ResNet}; 3 | 4 | use burn::{ 5 | backend::NdArray, 6 | module::Module, 7 | record::{FullPrecisionSettings, NamedMpkFileRecorder}, 8 | tensor::{backend::Backend, Device, Element, Tensor, TensorData}, 9 | }; 10 | 11 | const MODEL_PATH: &str = "resnet18-ImageNet1k"; 12 | const HEIGHT: usize = 224; 13 | const WIDTH: usize = 224; 14 | 15 | fn to_tensor( 16 | data: Vec, 17 | shape: [usize; 3], 18 | device: &Device, 19 | ) -> Tensor { 20 | Tensor::::from_data(TensorData::new(data, shape).convert::(), device) 21 | .permute([2, 0, 1]) // [C, H, W] 22 | / 255 // normalize between [0, 1] 23 | } 24 | 25 | pub fn main() { 26 | // Parse arguments 27 | let img_path = std::env::args().nth(1).expect("No image path provided"); 28 | 29 | // Create ResNet-18 30 | let device = Default::default(); 31 | let model: ResNet = 32 | ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device) 33 | .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}")) 34 | .unwrap(); 35 | 36 | // Save the model to a supported format and load it back 37 | let recorder = NamedMpkFileRecorder::::new(); 38 | model 39 | .clone() // `save_file` takes ownership but we want to load the file after 40 | .save_file(MODEL_PATH, &recorder) 41 | .map_err(|err| format!("Failed to save weights to file {MODEL_PATH}.\nError: {err}")) 42 | .unwrap(); 43 | let model = model 44 | .load_file(MODEL_PATH, &recorder, &device) 45 | .map_err(|err| format!("Failed to load weights from file {MODEL_PATH}.\nError: {err}")) 46 | .unwrap(); 47 | 48 | // Load image 49 | let img = 
image::open(&img_path) 50 | .map_err(|err| format!("Failed to load image {img_path}.\nError: {err}")) 51 | .unwrap(); 52 | 53 | // Resize to 224x224 54 | let resized_img = img.resize_exact( 55 | WIDTH as u32, 56 | HEIGHT as u32, 57 | image::imageops::FilterType::Triangle, // also known as bilinear in 2D 58 | ); 59 | 60 | // Create tensor from image data 61 | let img_tensor = to_tensor( 62 | resized_img.into_rgb8().into_raw(), 63 | [HEIGHT, WIDTH, 3], 64 | &device, 65 | ) 66 | .unsqueeze::<4>(); // [B, C, H, W] 67 | 68 | // Normalize the image 69 | let x = imagenet::Normalizer::new(&device).normalize(img_tensor); 70 | 71 | // Forward pass 72 | let out = model.forward(x); 73 | 74 | // Output class index w/ score (raw) 75 | let (score, idx) = out.max_dim_with_indices(1); 76 | let idx = idx.into_scalar() as usize; 77 | 78 | println!( 79 | "Predicted: {}\nCategory Id: {}\nScore: {:.4}", 80 | imagenet::CLASSES[idx], 81 | idx, 82 | score.into_scalar() 83 | ); 84 | } 85 | -------------------------------------------------------------------------------- /resnet-burn/examples/inference/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod imagenet; 2 | -------------------------------------------------------------------------------- /resnet-burn/resnet/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["guillaumelagrange "] 3 | name = "resnet-burn" 4 | version = "0.1.0" 5 | edition.workspace = true 6 | license.workspace = true 7 | 8 | [features] 9 | default = [] 10 | std = [] 11 | pretrained = ["burn/network", "std", "dep:dirs"] 12 | 13 | [dependencies] 14 | burn = { workspace = true } 15 | burn-import = { workspace = true } 16 | dirs = { workspace = true, optional = true } 17 | serde = { workspace = true } 18 | -------------------------------------------------------------------------------- /resnet-burn/resnet/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(feature = "std"), no_std)] 2 | mod block; 3 | pub mod resnet; 4 | pub mod weights; 5 | 6 | pub use resnet::*; 7 | pub use weights::*; 8 | 9 | extern crate alloc; 10 | -------------------------------------------------------------------------------- /resnet-burn/resnet/src/weights.rs: -------------------------------------------------------------------------------- 1 | /// Pre-trained weights metadata. 2 | pub struct Weights { 3 | pub(super) url: &'static str, 4 | pub(super) num_classes: usize, 5 | } 6 | 7 | #[cfg(feature = "pretrained")] 8 | mod downloader { 9 | use super::*; 10 | use burn::data::network::downloader; 11 | use std::fs::{create_dir_all, File}; 12 | use std::io::Write; 13 | use std::path::PathBuf; 14 | 15 | impl Weights { 16 | /// Download the pre-trained weights to the local cache directory. 
17 | pub fn download(&self) -> Result { 18 | // Model cache directory 19 | let model_dir = dirs::home_dir() 20 | .expect("Should be able to get home directory") 21 | .join(".cache") 22 | .join("resnet-burn"); 23 | 24 | if !model_dir.exists() { 25 | create_dir_all(&model_dir)?; 26 | } 27 | 28 | let file_base_name = self.url.rsplit_once('/').unwrap().1; 29 | let file_name = model_dir.join(file_base_name); 30 | if !file_name.exists() { 31 | // Download file content 32 | let bytes = downloader::download_file_as_bytes(self.url, file_base_name); 33 | 34 | // Write content to file 35 | let mut output_file = File::create(&file_name)?; 36 | let bytes_written = output_file.write(&bytes)?; 37 | 38 | if bytes_written != bytes.len() { 39 | return Err(std::io::Error::new( 40 | std::io::ErrorKind::InvalidData, 41 | "Failed to write the whole model weights file.", 42 | )); 43 | } 44 | } 45 | 46 | Ok(file_name) 47 | } 48 | } 49 | } 50 | 51 | pub trait WeightsMeta { 52 | fn weights(&self) -> Weights; 53 | } 54 | 55 | /// ResNet-18 pre-trained weights. 56 | pub enum ResNet18 { 57 | /// These weights reproduce closely the results of the original paper. 58 | /// Top-1 accuracy: 69.758%. 59 | /// Top-5 accuracy: 89.078%. 60 | ImageNet1kV1, 61 | } 62 | impl WeightsMeta for ResNet18 { 63 | fn weights(&self) -> Weights { 64 | Weights { 65 | url: "https://download.pytorch.org/models/resnet18-f37072fd.pth", 66 | num_classes: 1000, 67 | } 68 | } 69 | } 70 | 71 | /// ResNet-34 pre-trained weights. 72 | pub enum ResNet34 { 73 | /// These weights reproduce closely the results of the original paper. 74 | /// Top-1 accuracy: 73.314%. 75 | /// Top-5 accuracy: 91.420%. 76 | ImageNet1kV1, 77 | } 78 | impl WeightsMeta for ResNet34 { 79 | fn weights(&self) -> Weights { 80 | Weights { 81 | url: "https://download.pytorch.org/models/resnet34-b627a593.pth", 82 | num_classes: 1000, 83 | } 84 | } 85 | } 86 | 87 | /// ResNet-50 pre-trained weights. 88 | pub enum ResNet50 { 89 | /// These weights reproduce closely the results of the original paper. 90 | /// Top-1 accuracy: 76.130%. 91 | /// Top-5 accuracy: 92.862%. 92 | ImageNet1kV1, 93 | /// These weights improve upon the results of the original paper with a new training 94 | /// [recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives). 95 | /// Top-1 accuracy: 80.858%. 96 | /// Top-5 accuracy: 95.434%. 97 | ImageNet1kV2, 98 | } 99 | impl WeightsMeta for ResNet50 { 100 | fn weights(&self) -> Weights { 101 | let url = match *self { 102 | ResNet50::ImageNet1kV1 => "https://download.pytorch.org/models/resnet50-0676ba61.pth", 103 | ResNet50::ImageNet1kV2 => "https://download.pytorch.org/models/resnet50-11ad3fa6.pth", 104 | }; 105 | Weights { 106 | url, 107 | num_classes: 1000, 108 | } 109 | } 110 | } 111 | 112 | /// ResNet-101 pre-trained weights. 113 | pub enum ResNet101 { 114 | /// These weights reproduce closely the results of the original paper. 115 | /// Top-1 accuracy: 77.374%. 116 | /// Top-5 accuracy: 93.546%. 117 | ImageNet1kV1, 118 | /// These weights improve upon the results of the original paper with a new training 119 | /// [recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives). 120 | /// Top-1 accuracy: 81.886%. 121 | /// Top-5 accuracy: 95.780%. 
122 | ImageNet1kV2, 123 | } 124 | impl WeightsMeta for ResNet101 { 125 | fn weights(&self) -> Weights { 126 | let url = match *self { 127 | ResNet101::ImageNet1kV1 => "https://download.pytorch.org/models/resnet101-63fe2227.pth", 128 | ResNet101::ImageNet1kV2 => "https://download.pytorch.org/models/resnet101-cd907fc2.pth", 129 | }; 130 | Weights { 131 | url, 132 | num_classes: 1000, 133 | } 134 | } 135 | } 136 | 137 | /// ResNet-152 pre-trained weights. 138 | pub enum ResNet152 { 139 | /// These weights reproduce closely the results of the original paper. 140 | /// Top-1 accuracy: 78.312%. 141 | /// Top-5 accuracy: 94.046%. 142 | ImageNet1kV1, 143 | /// These weights improve upon the results of the original paper with a new training 144 | /// [recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives). 145 | /// Top-1 accuracy: 82.284%. 146 | /// Top-5 accuracy: 96.002%. 147 | ImageNet1kV2, 148 | } 149 | impl WeightsMeta for ResNet152 { 150 | fn weights(&self) -> Weights { 151 | let url = match *self { 152 | ResNet152::ImageNet1kV1 => "https://download.pytorch.org/models/resnet152-394f9c45.pth", 153 | ResNet152::ImageNet1kV2 => "https://download.pytorch.org/models/resnet152-f82ba261.pth", 154 | }; 155 | Weights { 156 | url, 157 | num_classes: 1000, 158 | } 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /resnet-burn/samples/dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/resnet-burn/samples/dataset.jpg -------------------------------------------------------------------------------- /resnet-burn/samples/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/resnet-burn/samples/dog.jpg -------------------------------------------------------------------------------- /squeezenet-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Dilshod Tadjibaev (@antimora)"] 3 | license = "MIT OR Apache-2.0" 4 | name = "squeezenet-burn" 5 | version = "0.1.0" 6 | edition = "2021" 7 | 8 | [features] 9 | default = ["weights_file", "weights_file_dump"] 10 | 11 | # Enables Half precision (f16) support 12 | weights_f16 = [] 13 | 14 | # Embed weights into the binary 15 | weights_embedded = [] 16 | 17 | # Use weights from a file 18 | weights_file = ["burn/default"] 19 | 20 | # Copy weights file to specif folder 21 | weights_file_dump = [] 22 | 23 | 24 | [dependencies] 25 | 26 | # Note: default-features = false is needed to disable std 27 | burn = { version = "0.16.0", default-features = false } 28 | 29 | # Used to load weights from a file 30 | serde = { version = "1.0.183", default-features = false, features = [ 31 | "derive", 32 | "alloc", 33 | ] } # alloc is for no_std, derive is needed 34 | 35 | [dev-dependencies] 36 | # Used by the classify example 37 | burn = { version = "0.16.0", features = ["ndarray"] } 38 | image = { version = "0.24.7", features = ["png", "jpeg"] } 39 | 40 | [build-dependencies] 41 | # Used to generate code from ONNX model 42 | burn-import = { version = "0.16.0", package = "burn-import" } 43 | -------------------------------------------------------------------------------- /squeezenet-burn/LICENSE-APACHE: 
-------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /squeezenet-burn/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /squeezenet-burn/NOTICES.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. 4 | 5 | ## Sample Images 6 | 7 | Image Title: Domestic cat, a ten month old female. 8 | Author: Von.grzanka 9 | Source: https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg 10 | License: https://creativecommons.org/licenses/by-sa/3.0/ 11 | 12 | Image Title: The George Washington Bridge over the Hudson River leading to New York City as seen from Fort Lee, New Jersey. 13 | Author: John O'Connell 14 | Source: https://commons.wikimedia.org/wiki/File:George_Washington_Bridge_from_New_Jersey-edit.jpg 15 | License: https://creativecommons.org/licenses/by/2.0/deed.en 16 | 17 | Image Title: Coyote from Yosemite National Park, California in snow 18 | Author: Yathin S Krishnappa 19 | Source https://commons.wikimedia.org/wiki/File:2009-Coyote-Yosemite.jpg 20 | License: https://creativecommons.org/licenses/by-sa/3.0/deed.en 21 | 22 | Image Title: Table lamp with a lampshade illuminated by sunlight. 23 | Author: LoMit 24 | Source: https://commons.wikimedia.org/wiki/File:Lamp_with_a_lampshade_illuminated_by_sunlight.jpg 25 | License: https://creativecommons.org/licenses/by-sa/4.0/deed.en 26 | 27 | Image Title: White Pelican Pelecanus onocrotalus at Walvis Bay, Namibia 28 | Author: Rui Ornelas 29 | Source: https://commons.wikimedia.org/wiki/File:Pelikan_Walvis_Bay.jpg 30 | License: https://creativecommons.org/licenses/by/2.0/deed.en 31 | 32 | Image Title: Photo of a traditional torch to be posted at gates 33 | Author: Faizul Latif Chowdhury 34 | Source: https://commons.wikimedia.org/wiki/File:Torch_traditional.jpg 35 | License: https://creativecommons.org/licenses/by-sa/3.0/deed.en 36 | 37 | Image Title: American Flamingo Phoenicopterus ruber at Gotomeer, Riscado, Bonaire 38 | Author: Paul Asman and Jill Lenoble 39 | Source: https://commons.wikimedia.org/wiki/File:Phoenicopterus_ruber_Bonaire_2.jpg 40 | License: https://creativecommons.org/licenses/by/2.0/deed.en 41 | 42 | ## ONNX Model 43 | 44 | SqueezeNet 1.1 model is licensed under Apache License 2.0. The model is downloaded from the [ONNX model zoo](https://github.com/onnx/models/tree/main). 45 | 46 | Source: https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx 47 | License: Apache License 2.0 48 | License URL: https://github.com/onnx/models/blob/main/LICENSE 49 | 50 | ## ONNX Labels 51 | 52 | The labels for the SqueezeNet 1.1 model are licensed under Apache License 2.0. 
The labels are downloaded from the [ONNX model zoo](https://github.com/onnx/models/blob/main/vision/classification/synset.txt) -------------------------------------------------------------------------------- /squeezenet-burn/README.md: -------------------------------------------------------------------------------- 1 | # SqueezeNet Burn - from ONNX to Rust 2 | 3 | SqueezeNet is a small CNN that can be used for image classification. It was trained on the ImageNet 4 | dataset and can classify images into 1000 different classes. The included ONNX model is copied from 5 | the [ONNX model zoo](https://github.com/onnx/models/tree/main/vision/classification/squeezenet), and 6 | the details of the model can be found in the [paper](https://arxiv.org/abs/1602.07360). 7 | 8 | The ONNX model is converted into a [Burn](https://github.com/burn-rs/burn/tree/main) model in Rust 9 | using the [burn-import](https://github.com/burn-rs/burn/tree/main/burn-import) crate during build 10 | time. The weights are saved in a binary file during build time in Burn compatible format, and the 11 | model is loaded at runtime. 12 | 13 | It is worth noting that the model can be fine-tuned to improve the accuracy, since the ONNX model is 14 | fully converted to a Burn model. The model is trained with the ImageNet dataset, which contains 1.2 15 | million images. The model can be fine-tuned with a smaller dataset to improve the accuracy for a 16 | specific use case. 17 | 18 | The labels for the classes are included in the crate and generated from the 19 | [`labels.txt`](src/model/label.txt) during build time. 20 | 21 | The data normalizer for the model is included in the crate. See 22 | [Normalizer](src/model/normalizer.rs). 23 | 24 | The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). 25 | 26 | See the [classify example](examples/classify.rs) for how to use the model. 27 | 28 | ## Usage 29 | 30 | ### To include the model in your project 31 | 32 | Add this to your `Cargo.toml`: 33 | 34 | ```toml 35 | [dependencies] 36 | squeezenet-burn = { git = "https://github.com/tracel-ai/models", package = "squeezenet-burn", features = ["weights_embedded"], default-features = false } 37 | ``` 38 | 39 | ### To run the example 40 | 41 | 1. Use the `weights_embedded` feature to embed the weights in the binary. 42 | 43 | ```shell 44 | cargo r --release --features weights_embedded --no-default-features --example classify samples/flamingo.jpg 45 | ``` 46 | 47 | 2. Use the `weights_file` feature to load the weights from a file. 48 | 49 | ```shell 50 | cargo r --release --features weights_file --example classify samples/flamingo.jpg 51 | ``` 52 | 53 | 3. Use the `weights_f16` feature to use 16-bit floating point numbers for the weights. 54 | 55 | ```shell 56 | cargo r --release --features "weights_embedded, weights_f16" --no-default-features --example classify samples/flamingo.jpg 57 | ``` 58 | 59 | Or 60 | 61 | ```shell 62 | cargo r --release --features "weights_file, weights_f16" --example classify samples/flamingo.jpg 63 | ``` 64 | 65 | ## Feature Flags 66 | 67 | - `weights_file`: Load the weights from a file (enabled by default). 68 | - `weights_embedded`: Embed the weights in the binary. 69 | - `weights_f16`: Use 16-bit floating point numbers for the weights. 
(by default 32-bit is used) 70 | -------------------------------------------------------------------------------- /squeezenet-burn/build.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::fs; 3 | use std::fs::File; 4 | use std::io::{BufRead, BufReader, Write}; 5 | use std::path::Path; 6 | 7 | use burn_import::burn::graph::RecordType; 8 | use burn_import::onnx::ModelGen; 9 | 10 | const LABEL_SOURCE_FILE: &str = "src/model/label.txt"; 11 | const LABEL_DEST_FILE: &str = "model/label.rs"; 12 | const GENERATED_MODEL_WEIGHTS_FILE: &str = "squeezenet1.mpk"; 13 | const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx"; 14 | const OUT_DIR: &str = "model/"; 15 | 16 | fn main() { 17 | // Re-run the build script if model files change. 18 | println!("cargo:rerun-if-changed=src/model"); 19 | 20 | // Make sure either weights_file or weights_embedded is enabled. 21 | if cfg!(feature = "weights_file") && cfg!(feature = "weights_embedded") { 22 | panic!("Only one of the features weights_file and weights_embedded can be enabled"); 23 | } 24 | 25 | // Make sure at least one of weights_file or weights_embedded is enabled. 26 | if !cfg!(feature = "weights_file") && !cfg!(feature = "weights_embedded") { 27 | panic!("One of the features weights_file and weights_embedded must be enabled"); 28 | } 29 | 30 | // Check if the weights are embedded. 31 | let (record_type, embed_states) = if cfg!(feature = "weights_embedded") { 32 | (RecordType::Bincode, true) 33 | } else { 34 | (RecordType::NamedMpk, false) 35 | }; 36 | 37 | // Check if half precision is enabled. 38 | let half_precision = cfg!(feature = "weights_f16"); 39 | 40 | // Generate the model code from the ONNX file. 41 | ModelGen::new() 42 | .input(INPUT_ONNX_FILE) 43 | .out_dir(OUT_DIR) 44 | .record_type(record_type) 45 | .embed_states(embed_states) 46 | .half_precision(half_precision) 47 | .run_from_script(); 48 | 49 | // Copy the weights next to the executable. 50 | if cfg!(feature = "weights_file") && cfg!(feature = "weights_file_dump") { 51 | copy_weights_next_to_executable(); 52 | } 53 | 54 | // Generate the labels from the synset.txt file. 55 | generate_labels_from_txt_file().unwrap(); 56 | } 57 | 58 | /// Read labels from synset.txt and store them in a vector of strings in a Rust file. 59 | fn generate_labels_from_txt_file() -> std::io::Result<()> { 60 | let out_dir = env::var("OUT_DIR").unwrap(); 61 | let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE); 62 | let mut f = File::create(dest_path)?; 63 | 64 | let file = File::open(LABEL_SOURCE_FILE)?; 65 | let reader = BufReader::new(file); 66 | 67 | writeln!(f, "pub static LABELS: &[&str] = &[")?; 68 | for line in reader.lines() { 69 | writeln!(f, " \"{}\",", line.unwrap())?; 70 | } 71 | writeln!(f, "];")?; 72 | 73 | Ok(()) 74 | } 75 | 76 | /// Copy the weights file next to the executable. 77 | fn copy_weights_next_to_executable() { 78 | // Obtain the OUT_DIR path from the environment variable. 79 | let out_dir = env::var("OUT_DIR").expect("OUT_DIR not defined"); 80 | 81 | // Weights file in OUT_DIR that you want to copy. 82 | let source_path = Path::new(&out_dir) 83 | .join("model") 84 | .join(GENERATED_MODEL_WEIGHTS_FILE); 85 | 86 | // Determine the profile (debug or release) to set the appropriate destination directory. 87 | let profile = env::var("PROFILE").expect("PROFILE not defined"); 88 | let target_dir = format!("target/{}", profile); 89 | 90 | // Specify the destination path. 
91 | let destination_path = Path::new(&target_dir) 92 | .join("examples") 93 | .join(GENERATED_MODEL_WEIGHTS_FILE); 94 | 95 | // Copy the file. 96 | fs::copy(source_path, destination_path).expect("Failed to copy generated file"); 97 | } 98 | -------------------------------------------------------------------------------- /squeezenet-burn/examples/classify.rs: -------------------------------------------------------------------------------- 1 | use squeezenet_burn::model::{label::LABELS, normalizer::Normalizer, squeezenet1::Model}; 2 | 3 | #[cfg(feature = "weights_embedded")] 4 | use burn::backend::ndarray::NdArrayDevice; 5 | 6 | use burn::backend::NdArray; 7 | use burn::tensor::Tensor; 8 | 9 | use image::{self, GenericImageView, Pixel}; 10 | 11 | const HEIGHT: usize = 224; 12 | const WIDTH: usize = 224; 13 | 14 | #[cfg(feature = "weights_file")] 15 | const RECORD_FILE: &str = "squeezenet1"; 16 | 17 | type Backend = NdArray; 18 | 19 | fn main() { 20 | // Path to the image from the main args 21 | let img_path = std::env::args().nth(1).expect("No image path provided"); 22 | 23 | // Load the image 24 | let img = image::open(&img_path).unwrap_or_else(|_| panic!("Failed to load image: {img_path}")); 25 | 26 | // Resize it to 224x224 27 | let resized_img = img.resize_exact( 28 | WIDTH as u32, 29 | HEIGHT as u32, 30 | image::imageops::FilterType::Lanczos3, 31 | ); 32 | 33 | // 3d array of 224x224x3 floats 34 | let mut img_array = [[[0.0; WIDTH]; HEIGHT]; 3]; 35 | 36 | // Iterate over the pixels and populate the array 37 | for y in 0..224usize { 38 | for x in 0..224usize { 39 | let pixel = resized_img.get_pixel(x as u32, y as u32); 40 | let rgb = pixel.to_rgb(); 41 | 42 | img_array[0][y][x] = rgb[0] as f32 / 255.0; 43 | img_array[1][y][x] = rgb[1] as f32 / 255.0; 44 | img_array[2][y][x] = rgb[2] as f32 / 255.0; 45 | } 46 | } 47 | 48 | let device = Default::default(); 49 | 50 | // Create a tensor from the array 51 | let image_input = 52 | Tensor::::from_data(img_array, &device).reshape([1, 3, HEIGHT, WIDTH]); 53 | // Normalize the image 54 | let normalizer = Normalizer::new(&device); 55 | let normalized_image = normalizer.normalize(image_input); 56 | 57 | // Create the model 58 | // Load the weights from the file next to the executable 59 | #[cfg(feature = "weights_file")] 60 | let weights_file = std::env::current_exe() 61 | .unwrap() 62 | .parent() 63 | .unwrap() 64 | .join(RECORD_FILE); 65 | 66 | #[cfg(feature = "weights_file")] 67 | let model = Model::::from_file(weights_file.to_str().unwrap(), &device); 68 | 69 | #[cfg(feature = "weights_embedded")] 70 | // Load model from embedded weights 71 | let model = Model::::from_embedded(&NdArrayDevice::Cpu); 72 | 73 | // Run the model 74 | let output = model.forward(normalized_image); 75 | 76 | // Get the argmax of the output 77 | let arg_max = output.argmax(1).into_scalar() as usize; 78 | 79 | // Get the label from the argmax 80 | let label = LABELS[arg_max]; 81 | 82 | println!("Predicted label: {}", label); 83 | } 84 | -------------------------------------------------------------------------------- /squeezenet-burn/samples/bridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/bridge.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/cat.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/cat.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/coyote.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/coyote.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/flamingo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/flamingo.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/pelican.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/pelican.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/table-lamp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/table-lamp.jpg -------------------------------------------------------------------------------- /squeezenet-burn/samples/torch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/samples/torch.jpg -------------------------------------------------------------------------------- /squeezenet-burn/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![no_std] 2 | pub mod model; 3 | -------------------------------------------------------------------------------- /squeezenet-burn/src/model/label.rs: -------------------------------------------------------------------------------- 1 | // Generated labels from labels.txt 2 | include!(concat!(env!("OUT_DIR"), "/model/label.rs")); 3 | -------------------------------------------------------------------------------- /squeezenet-burn/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod label; 2 | pub mod normalizer; 3 | pub mod squeezenet1; 4 | -------------------------------------------------------------------------------- /squeezenet-burn/src/model/normalizer.rs: -------------------------------------------------------------------------------- 1 | use burn::tensor::{backend::Backend, Tensor}; 2 | 3 | // Values are taken from the [ONNX SqueezeNet] 4 | // (https://github.com/onnx/models/tree/main/vision/classification/squeezenet#preprocessing) 5 | const MEAN: [f32; 3] = [0.485, 0.456, 0.406]; 6 | const STD: [f32; 3] = [0.229, 0.224, 0.225]; 7 | 8 | /// Normalizer for the imagenet dataset. 9 | pub struct Normalizer { 10 | pub mean: Tensor, 11 | pub std: Tensor, 12 | } 13 | 14 | impl Normalizer { 15 | /// Creates a new normalizer. 
16 | pub fn new(device: &B::Device) -> Self { 17 | let mean = Tensor::::from_floats(MEAN, device).reshape([1, 3, 1, 1]); 18 | let std = Tensor::::from_floats(STD, device).reshape([1, 3, 1, 1]); 19 | Self { mean, std } 20 | } 21 | 22 | /// Normalizes the input image according to the imagenet dataset. 23 | /// 24 | /// The input image should be in the range [0, 1]. 25 | /// The output image will be in the range [-1, 1]. 26 | /// 27 | /// The normalization is done according to the following formula: 28 | /// `input = (input - mean) / std` 29 | pub fn normalize(&self, input: Tensor) -> Tensor { 30 | (input - self.mean.clone()) / self.std.clone() 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /squeezenet-burn/src/model/squeezenet1.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/squeezenet-burn/src/model/squeezenet1.onnx -------------------------------------------------------------------------------- /squeezenet-burn/src/model/squeezenet1.rs: -------------------------------------------------------------------------------- 1 | // Generated model from squeezenet1.onnx 2 | mod model { 3 | include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs")); 4 | } 5 | 6 | pub use model::*; 7 | -------------------------------------------------------------------------------- /yolox-burn/.gitignore: -------------------------------------------------------------------------------- 1 | # Output image 2 | *.output.png 3 | -------------------------------------------------------------------------------- /yolox-burn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["guillaumelagrange "] 3 | license = "MIT OR Apache-2.0" 4 | name = "yolox-burn" 5 | version = "0.1.0" 6 | edition = "2021" 7 | 8 | [features] 9 | default = [] 10 | std = [] 11 | pretrained = ["burn/network", "std", "dep:dirs"] 12 | 13 | [dependencies] 14 | # Note: default-features = false is needed to disable std 15 | burn = { version = "0.16.0", default-features = false } 16 | burn-import = { version = "0.16.0" } 17 | itertools = { version = "0.12.1", default-features = false, features = [ 18 | "use_alloc", 19 | ] } 20 | dirs = { version = "5.0.1", optional = true } 21 | serde = { version = "1.0.192", default-features = false, features = [ 22 | "derive", 23 | "alloc", 24 | ] } # alloc is for no_std, derive is needed 25 | 26 | [dev-dependencies] 27 | burn = { version = "0.16.0", features = ["ndarray"] } 28 | image = { version = "0.24.9", features = ["png", "jpeg"] } 29 | -------------------------------------------------------------------------------- /yolox-burn/LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | ../LICENSE-APACHE -------------------------------------------------------------------------------- /yolox-burn/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | ../LICENSE-MIT -------------------------------------------------------------------------------- /yolox-burn/NOTICES.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. 
4 | 5 | ## Sample Image 6 | 7 | Image Title: Man with Bike and Pet Dog circa 1900 (archive ref DDX1319-2-3) 8 | Author: East Riding Archives 9 | Source: https://commons.wikimedia.org/wiki/File:Man_with_Bike_and_Pet_Dog_circa_1900_%28archive_ref_DDX1319-2-3%29_%2826507570321%29.jpg 10 | License: [Creative Commons](https://www.flickr.com/commons/usage/) 11 | 12 | ## Pre-trained Model 13 | 14 | The COCO pre-trained model was ported from the original [YOLOX implementation](https://github.com/Megvii-BaseDetection/YOLOX). 15 | 16 | As opposed to other YOLO variants (YOLOv8, YOLO-NAS, etc.), both the code and pre-trained weights are distributed under the [Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license. 17 | -------------------------------------------------------------------------------- /yolox-burn/README.md: -------------------------------------------------------------------------------- 1 | # YOLOX Burn 2 | 3 | There have been many different object detection models with the YOLO prefix released in recent 4 | years, though most of them carry a GPL or AGPL license which restricts their usage. For this reason, 5 | we selected [YOLOX](https://arxiv.org/abs/2107.08430) as the first object detection architecture 6 | since both the original code and pre-trained weights are released under the 7 | [Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license. 8 | 9 | You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the YOLOX variants in 10 | [src/model/yolox.rs](src/model/yolox.rs). 11 | 12 | The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). 13 | 14 | ## Usage 15 | 16 | ### `Cargo.toml` 17 | 18 | Add this to your `Cargo.toml`: 19 | 20 | ```toml 21 | [dependencies] 22 | yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", default-features = false } 23 | ``` 24 | 25 | If you want to get the COCO pre-trained weights, enable the `pretrained` feature flag. 26 | 27 | ```toml 28 | [dependencies] 29 | yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", features = ["pretrained"] } 30 | ``` 31 | 32 | **Important:** this feature requires `std`. 33 | 34 | ### Example Usage 35 | 36 | The [inference example](examples/inference.rs) initializes a YOLOX-Tiny from the COCO 37 | [pre-trained weights](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#standard-models) 38 | with the `NdArray` backend and performs inference on the provided input image. 
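Under the hood, the example boils down to three steps: create the pre-trained model, run a forward pass on a `[B, C, H, W]` image tensor resized to 640x640, and post-process the raw predictions with non-maximum suppression. The snippet below is a condensed sketch of those calls on a dummy all-zeros input (image loading, resizing and box drawing are left to the full example):

```rust
use burn::{backend::NdArray, tensor::Tensor};
use yolox_burn::model::{boxes::nms, weights, yolox::Yolox};

fn main() {
    let device = Default::default();

    // COCO pre-trained YOLOX-Tiny (requires the `pretrained` feature).
    let model: Yolox<NdArray> =
        Yolox::yolox_tiny_pretrained(weights::YoloxTiny::Coco, &device)
            .expect("pre-trained weights should download and load");

    // Stand-in for a resized 640x640 RGB image with values in [0, 255].
    let x = Tensor::<NdArray, 4>::zeros([1, 3, 640, 640], &device);

    // Raw predictions: [batch, boxes, 4 box coordinates + 1 objectness + class scores].
    let out = model.forward(x);
    let [_, num_boxes, num_outputs] = out.dims();

    // Combine objectness and class scores, then keep the best boxes per class
    // with non-maximum suppression (same thresholds as the full example).
    let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]);
    let scores = out.clone().slice([0..1, 0..num_boxes, 5..num_outputs])
        * out.slice([0..1, 0..num_boxes, 4..5]);
    let detections = nms(boxes, scores, 0.65, 0.5);

    println!(
        "Classes with at least one detection: {}",
        detections[0].iter().filter(|class| !class.is_empty()).count()
    );
}
```

In the full example, the predicted boxes are scaled back to the original image resolution, drawn on the input image and saved next to it with a `.output.png` extension.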
39 | 40 | You can run the example with the following command: 41 | 42 | ```sh 43 | cargo run --release --features pretrained --example inference samples/dog_bike_man.jpg 44 | ``` 45 | -------------------------------------------------------------------------------- /yolox-burn/examples/inference.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use image::{DynamicImage, ImageBuffer}; 4 | use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox}; 5 | 6 | use burn::{ 7 | backend::NdArray, 8 | tensor::{backend::Backend, Device, Element, Tensor, TensorData}, 9 | }; 10 | 11 | const HEIGHT: usize = 640; 12 | const WIDTH: usize = 640; 13 | 14 | fn to_tensor( 15 | data: Vec, 16 | shape: [usize; 3], 17 | device: &Device, 18 | ) -> Tensor { 19 | Tensor::::from_data( 20 | TensorData::new(data, shape).convert::(), 21 | device, 22 | ) 23 | // [H, W, C] -> [C, H, W] 24 | .permute([2, 0, 1]) 25 | } 26 | 27 | /// Draws bounding boxes on the given image. 28 | /// 29 | /// # Arguments 30 | /// 31 | /// * `image`: Original input image. 32 | /// * `boxes` - Bounding boxes, grouped per class. 33 | /// * `color` - [R, G, B] color values to draw the boxes. 34 | /// * `ratio` - [x, y] aspect ratio to scale the predicted boxes. 35 | /// 36 | /// # Returns 37 | /// 38 | /// The image annotated with bounding boxes. 39 | fn draw_boxes( 40 | image: DynamicImage, 41 | boxes: &[Vec], 42 | color: &[u8; 3], 43 | ratio: &[f32; 2], // (x, y) ratio 44 | ) -> DynamicImage { 45 | // Assumes x1 <= x2 and y1 <= y2 46 | fn draw_rect( 47 | image: &mut ImageBuffer, Vec>, 48 | x1: u32, 49 | x2: u32, 50 | y1: u32, 51 | y2: u32, 52 | color: &[u8; 3], 53 | ) { 54 | for x in x1..=x2 { 55 | let pixel = image.get_pixel_mut(x, y1); 56 | *pixel = image::Rgb(*color); 57 | let pixel = image.get_pixel_mut(x, y2); 58 | *pixel = image::Rgb(*color); 59 | } 60 | for y in y1..=y2 { 61 | let pixel = image.get_pixel_mut(x1, y); 62 | *pixel = image::Rgb(*color); 63 | let pixel = image.get_pixel_mut(x2, y); 64 | *pixel = image::Rgb(*color); 65 | } 66 | } 67 | 68 | // Annotate the original image and print boxes information. 
69 | let (image_h, image_w) = (image.height(), image.width()); 70 | let mut image = image.to_rgb8(); 71 | for (class_index, bboxes_for_class) in boxes.iter().enumerate() { 72 | for b in bboxes_for_class.iter() { 73 | let xmin = (b.xmin * ratio[0]).clamp(0., image_w as f32 - 1.); 74 | let ymin = (b.ymin * ratio[1]).clamp(0., image_h as f32 - 1.); 75 | let xmax = (b.xmax * ratio[0]).clamp(0., image_w as f32 - 1.); 76 | let ymax = (b.ymax * ratio[1]).clamp(0., image_h as f32 - 1.); 77 | 78 | println!( 79 | "Predicted {} ({:.2}) at [{:.2}, {:.2}, {:.2}, {:.2}]", 80 | class_index, b.confidence, xmin, ymin, xmax, ymax, 81 | ); 82 | 83 | draw_rect( 84 | &mut image, 85 | xmin as u32, 86 | xmax as u32, 87 | ymin as u32, 88 | ymax as u32, 89 | color, 90 | ); 91 | } 92 | } 93 | DynamicImage::ImageRgb8(image) 94 | } 95 | 96 | pub fn main() { 97 | // Parse arguments 98 | let img_path = std::env::args().nth(1).expect("No image path provided"); 99 | 100 | // Create YOLOX-Tiny 101 | let device = Default::default(); 102 | let model: Yolox = Yolox::yolox_tiny_pretrained(weights::YoloxTiny::Coco, &device) 103 | .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}")) 104 | .unwrap(); 105 | 106 | // Load image 107 | let img = image::open(&img_path) 108 | .map_err(|err| format!("Failed to load image {img_path}.\nError: {err}")) 109 | .unwrap(); 110 | 111 | // Resize to 640x640 112 | let resized_img = img.resize_exact( 113 | WIDTH as u32, 114 | HEIGHT as u32, 115 | image::imageops::FilterType::Triangle, // also known as bilinear in 2D 116 | ); 117 | 118 | // Create tensor from image data 119 | let x = to_tensor( 120 | resized_img.into_rgb8().into_raw(), 121 | [HEIGHT, WIDTH, 3], 122 | &device, 123 | ) 124 | .unsqueeze::<4>(); // [B, C, H, W] 125 | 126 | // Forward pass 127 | let out = model.forward(x); 128 | 129 | // Post-processing 130 | let [_, num_boxes, num_outputs] = out.dims(); 131 | let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]); 132 | let obj_scores = out.clone().slice([0..1, 0..num_boxes, 4..5]); 133 | let cls_scores = out.slice([0..1, 0..num_boxes, 5..num_outputs]); 134 | let scores = cls_scores * obj_scores; 135 | let boxes = nms(boxes, scores, 0.65, 0.5); 136 | 137 | // Draw outputs and save results 138 | let (h, w) = (img.height(), img.width()); 139 | let img_out = draw_boxes( 140 | img, 141 | &boxes[0], 142 | &[239u8, 62u8, 5u8], 143 | &[w as f32 / WIDTH as f32, h as f32 / HEIGHT as f32], 144 | ); 145 | 146 | let img_path = Path::new(&img_path); 147 | let _ = img_out.save(img_path.with_extension("output.png")); 148 | } 149 | -------------------------------------------------------------------------------- /yolox-burn/samples/dog_bike_man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tracel-ai/models/146da485484c6cec70b69065117329dea4604b73/yolox-burn/samples/dog_bike_man.jpg -------------------------------------------------------------------------------- /yolox-burn/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(feature = "std"), no_std)] 2 | pub mod model; 3 | extern crate alloc; 4 | -------------------------------------------------------------------------------- /yolox-burn/src/model/blocks.rs: -------------------------------------------------------------------------------- 1 | use alloc::vec; 2 | use burn::{ 3 | config::Config, 4 | module::Module, 5 | nn::{ 6 | conv::{Conv2d, Conv2dConfig}, 7 | BatchNorm, BatchNormConfig, 
PaddingConfig2d, 8 | }, 9 | tensor::{activation::silu, backend::Backend, Device, Tensor}, 10 | }; 11 | 12 | /// Compute the number of channels based on the provided factor. 13 | pub fn expand(num_channels: usize, factor: f64) -> usize { 14 | (num_channels as f64 * factor).floor() as usize 15 | } 16 | 17 | /// A base convolution block. 18 | /// Allows to switch between regular and depthwise separable convolution blocks based on the 19 | /// architecture. 20 | #[derive(Module, Debug)] 21 | pub enum Conv { 22 | /// Basic convolution block used for all variants. 23 | BaseConv(BaseConv), 24 | /// Depthwise separable convolution block, used for some blocks by YOLOX-Nano. 25 | DwsConv(DwsConv), 26 | } 27 | 28 | impl Conv { 29 | pub fn forward(&self, x: Tensor) -> Tensor { 30 | match self { 31 | Self::BaseConv(conv) => conv.forward(x), 32 | Self::DwsConv(conv) => conv.forward(x), 33 | } 34 | } 35 | } 36 | 37 | #[derive(Config)] 38 | pub struct ConvConfig { 39 | in_channels: usize, 40 | out_channels: usize, 41 | kernel_size: usize, 42 | stride: usize, 43 | depthwise: bool, 44 | } 45 | 46 | impl ConvConfig { 47 | /// Initialize a new [convolution block](Conv) module. 48 | pub fn init(&self, device: &Device) -> Conv { 49 | if self.depthwise { 50 | Conv::DwsConv( 51 | DwsConvConfig::new( 52 | self.in_channels, 53 | self.out_channels, 54 | self.kernel_size, 55 | self.stride, 56 | ) 57 | .init(device), 58 | ) 59 | } else { 60 | Conv::BaseConv( 61 | BaseConvConfig::new( 62 | self.in_channels, 63 | self.out_channels, 64 | self.kernel_size, 65 | self.stride, 66 | 1, 67 | ) 68 | .init(device), 69 | ) 70 | } 71 | } 72 | } 73 | 74 | /// A Conv2d -> BatchNorm -> activation block. 75 | #[derive(Module, Debug)] 76 | pub struct BaseConv { 77 | conv: Conv2d, 78 | bn: BatchNorm, 79 | } 80 | 81 | impl BaseConv { 82 | pub fn forward(&self, x: Tensor) -> Tensor { 83 | let x = self.conv.forward(x); 84 | let x = self.bn.forward(x); 85 | 86 | silu(x) 87 | } 88 | } 89 | 90 | /// [Base convolution block](BaseConv) configuration. 91 | pub struct BaseConvConfig { 92 | conv: Conv2dConfig, 93 | bn: BatchNormConfig, 94 | } 95 | 96 | impl BaseConvConfig { 97 | /// Create a new instance of the base convolution block [config](BaseConvConfig). 98 | pub fn new( 99 | in_channels: usize, 100 | out_channels: usize, 101 | kernel_size: usize, 102 | stride: usize, 103 | groups: usize, 104 | ) -> Self { 105 | // Same padding 106 | let pad = (kernel_size - 1) / 2; 107 | 108 | let conv = Conv2dConfig::new([in_channels, out_channels], [kernel_size, kernel_size]) 109 | .with_stride([stride, stride]) 110 | .with_padding(PaddingConfig2d::Explicit(pad, pad)) 111 | .with_groups(groups) 112 | .with_bias(false); 113 | let bn = BatchNormConfig::new(out_channels) 114 | .with_epsilon(1e-3) 115 | .with_momentum(0.03); 116 | 117 | Self { conv, bn } 118 | } 119 | 120 | /// Initialize a new [base convolution block](BaseConv) module. 121 | pub fn init(&self, device: &Device) -> BaseConv { 122 | BaseConv { 123 | conv: self.conv.init(device), 124 | bn: self.bn.init(device), 125 | } 126 | } 127 | } 128 | 129 | /// A [depthwise separable convolution](https://paperswithcode.com/method/depthwise-separable-convolution) 130 | /// block. Both depthwise and pointwise blocks consist of a Conv2d -> BatchNorm -> activation block. 
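/// For a `k`x`k` kernel mapping `c_in` to `c_out` channels, a standard convolution holds
/// `k*k*c_in*c_out` weights, while the depthwise (`k*k*c_in`) plus pointwise (`c_in*c_out`) pair
/// holds `k*k*c_in + c_in*c_out`, roughly a `k*k`-fold reduction when `c_out` is large; this is
/// what makes it attractive for the lightweight YOLOX-Nano variant.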
131 | #[derive(Module, Debug)] 132 | pub struct DwsConv { 133 | dconv: BaseConv, 134 | pconv: BaseConv, 135 | } 136 | 137 | impl DwsConv { 138 | pub fn forward(&self, x: Tensor) -> Tensor { 139 | let x = self.dconv.forward(x); 140 | self.pconv.forward(x) 141 | } 142 | } 143 | 144 | /// [Depthwise separable convolution block](DwsConv) configuration. 145 | pub struct DwsConvConfig { 146 | dconv: BaseConvConfig, 147 | pconv: BaseConvConfig, 148 | } 149 | 150 | impl DwsConvConfig { 151 | /// Create a new instance of the depthwise separable convolution block [config](DwsConvConfig). 152 | pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize, stride: usize) -> Self { 153 | // Depthwise conv 154 | let dconv = BaseConvConfig::new(in_channels, in_channels, kernel_size, stride, in_channels); 155 | // Pointwise conv 156 | let pconv = BaseConvConfig::new(in_channels, out_channels, 1, 1, 1); 157 | 158 | Self { dconv, pconv } 159 | } 160 | 161 | /// Initialize a new [depthwise separable convolution block](DwsConv) module. 162 | pub fn init(&self, device: &Device) -> DwsConv { 163 | DwsConv { 164 | dconv: self.dconv.init(device), 165 | pconv: self.pconv.init(device), 166 | } 167 | } 168 | } 169 | 170 | /// Focus width and height information into channel space. 171 | #[derive(Module, Debug)] 172 | pub struct Focus { 173 | conv: BaseConv, 174 | } 175 | 176 | impl Focus { 177 | pub fn forward(&self, x: Tensor) -> Tensor { 178 | let device = x.device(); 179 | let [_, _, h, w] = x.dims(); 180 | 181 | // Indexing 182 | let top_idx = Tensor::arange_step(0..h as i64, 2, &device); 183 | let bottom_idx = Tensor::arange_step(1..h as i64, 2, &device); 184 | let left_idx = Tensor::arange_step(0..w as i64, 2, &device); 185 | let right_idx = Tensor::arange_step(1..w as i64, 2, &device); 186 | 187 | // patch_top_left = x[..., ::2, ::2] 188 | let patch_top_left = x 189 | .clone() 190 | .select(2, top_idx.clone()) 191 | .select(3, left_idx.clone()); 192 | // patch_top_right = x[..., ::2, 1::2] 193 | let patch_top_right = x.clone().select(2, top_idx).select(3, right_idx.clone()); 194 | // patch_bot_left = x[..., 1::2, ::2] 195 | let patch_bottom_left = x.clone().select(2, bottom_idx.clone()).select(3, left_idx); 196 | // patch_bot_right = x[..., 1::2, 1::2] 197 | let patch_bottom_right = x.select(2, bottom_idx).select(3, right_idx); 198 | 199 | // Shape (b,c,w,h) -> y(b,4c,w/2,h/2) 200 | let x = Tensor::cat( 201 | vec![ 202 | patch_top_left, 203 | patch_bottom_left, 204 | patch_top_right, 205 | patch_bottom_right, 206 | ], 207 | 1, 208 | ); 209 | 210 | self.conv.forward(x) 211 | } 212 | } 213 | 214 | /// [Focus block](Focus) configuration. 215 | pub struct FocusConfig { 216 | conv: BaseConvConfig, 217 | } 218 | 219 | impl FocusConfig { 220 | /// Create a new instance of the focus block [config](FocusConfig). 221 | pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize, stride: usize) -> Self { 222 | let conv = BaseConvConfig::new(in_channels * 4, out_channels, kernel_size, stride, 1); 223 | 224 | Self { conv } 225 | } 226 | 227 | /// Initialize a new [focus block](Focus) module. 228 | pub fn init(&self, device: &Device) -> Focus { 229 | Focus { 230 | conv: self.conv.init(device), 231 | } 232 | } 233 | } 234 | 235 | /// Dual convolution block used for feature extraction in the prediction head. 
236 | #[derive(Module, Debug)] 237 | pub struct ConvBlock { 238 | conv0: Conv, 239 | conv1: Conv, 240 | } 241 | 242 | impl ConvBlock { 243 | pub fn forward(&self, x: Tensor) -> Tensor { 244 | let x = self.conv0.forward(x); 245 | self.conv1.forward(x) 246 | } 247 | } 248 | 249 | /// [Dual convolution block](ConvBlock) configuration. 250 | pub struct ConvBlockConfig { 251 | conv0: ConvConfig, 252 | conv1: ConvConfig, 253 | } 254 | 255 | impl ConvBlockConfig { 256 | /// Create a new instance of the dual convolution block [config](ConvBlockConfig). 257 | pub fn new(channels: usize, kernel_size: usize, stride: usize, depthwise: bool) -> Self { 258 | let conv0 = ConvConfig::new(channels, channels, kernel_size, stride, depthwise); 259 | let conv1 = ConvConfig::new(channels, channels, kernel_size, stride, depthwise); 260 | 261 | Self { conv0, conv1 } 262 | } 263 | 264 | /// Initialize a new [dual convolution block](ConvBlock) module. 265 | pub fn init(&self, device: &Device) -> ConvBlock { 266 | ConvBlock { 267 | conv0: self.conv0.init(device), 268 | conv1: self.conv1.init(device), 269 | } 270 | } 271 | } 272 | -------------------------------------------------------------------------------- /yolox-burn/src/model/bottleneck.rs: -------------------------------------------------------------------------------- 1 | use alloc::{vec, vec::Vec}; 2 | use burn::{ 3 | module::Module, 4 | nn::pool::{MaxPool2d, MaxPool2dConfig}, 5 | tensor::{backend::Backend, Device, Tensor}, 6 | }; 7 | 8 | use super::blocks::{expand, BaseConv, BaseConvConfig, Conv, ConvConfig}; 9 | 10 | pub(crate) const SPP_POOLING: [usize; 3] = [5, 9, 13]; 11 | 12 | /// Standard bottleneck block. 13 | #[derive(Module, Debug)] 14 | pub struct Bottleneck { 15 | conv1: BaseConv, 16 | conv2: Conv, 17 | shortcut: bool, 18 | } 19 | 20 | impl Bottleneck { 21 | pub fn forward(&self, x: Tensor) -> Tensor { 22 | let identity = x.clone(); 23 | 24 | let x = self.conv1.forward(x); 25 | let mut x = self.conv2.forward(x); 26 | 27 | if self.shortcut { 28 | x = x + identity; 29 | } 30 | 31 | x 32 | } 33 | } 34 | 35 | /// [Bottleneck block](Bottleneck) configuration. 36 | struct BottleneckConfig { 37 | conv1: BaseConvConfig, 38 | conv2: ConvConfig, 39 | shortcut: bool, 40 | } 41 | 42 | impl BottleneckConfig { 43 | /// Create a new instance of the bottleneck block [config](BottleneckConfig). 44 | pub fn new(in_channels: usize, out_channels: usize, shortcut: bool, depthwise: bool) -> Self { 45 | // In practice, expansion = 1.0 and no shortcut connection is used 46 | let hidden_channels = out_channels; 47 | 48 | let conv1 = BaseConvConfig::new(in_channels, hidden_channels, 1, 1, 1); 49 | let conv2 = ConvConfig::new(hidden_channels, out_channels, 3, 1, depthwise); 50 | 51 | Self { 52 | conv1, 53 | conv2, 54 | shortcut, 55 | } 56 | } 57 | 58 | /// Initialize a new [bottleneck block](Bottleneck) module. 59 | pub fn init(&self, device: &Device) -> Bottleneck { 60 | Bottleneck { 61 | conv1: self.conv1.init(device), 62 | conv2: self.conv2.init(device), 63 | shortcut: self.shortcut, 64 | } 65 | } 66 | } 67 | 68 | /// Spatial pyramid pooling layer used in YOLOv3-SPP. 
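/// The input is first reduced to half its channels by a 1x1 convolution, then max-pooled with
/// kernel sizes 5, 9 and 13 (stride 1, same padding). The input and the three pooled maps are
/// concatenated, giving four times the hidden channels, and fused by a final 1x1 convolution.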
69 | #[derive(Module, Debug)] 70 | pub struct SppBottleneck { 71 | conv1: BaseConv, 72 | conv2: BaseConv, 73 | m: Vec, 74 | } 75 | 76 | impl SppBottleneck { 77 | pub fn forward(&self, x: Tensor) -> Tensor { 78 | if self.m.is_empty() { 79 | panic!("No MaxPool2d modules found"); 80 | } 81 | 82 | let x = self.conv1.forward(x); 83 | 84 | let x: Vec<_> = vec![x.clone()] 85 | .into_iter() 86 | .chain(self.m.iter().map(|pool| pool.forward(x.clone()))) 87 | .collect(); 88 | let x = Tensor::cat(x, 1); 89 | 90 | self.conv2.forward(x) 91 | } 92 | } 93 | 94 | /// [SppBottleneck block](SppBottleneck) configuration. 95 | pub struct SppBottleneckConfig { 96 | conv1: BaseConvConfig, 97 | conv2: BaseConvConfig, 98 | m: Vec, 99 | } 100 | 101 | impl SppBottleneckConfig { 102 | /// Create a new instance of the bottleneck block [config](SppBottleneckConfig). 103 | pub fn new(in_channels: usize, out_channels: usize) -> Self { 104 | let hidden_channels = in_channels / 2; 105 | let conv2_channels = hidden_channels * 4; // conv1 output + maxpool (3x) 106 | 107 | let conv1 = BaseConvConfig::new(in_channels, hidden_channels, 1, 1, 1); 108 | let conv2 = BaseConvConfig::new(conv2_channels, out_channels, 1, 1, 1); 109 | let m: Vec<_> = SPP_POOLING 110 | .into_iter() 111 | .map(|k| { 112 | let pad = k / 2; 113 | MaxPool2dConfig::new([k, k]) 114 | .with_padding(burn::nn::PaddingConfig2d::Explicit(pad, pad)) 115 | }) 116 | .collect(); 117 | 118 | Self { conv1, conv2, m } 119 | } 120 | 121 | /// Initialize a new [bottleneck block](SppBottleneck) module. 122 | pub fn init(&self, device: &Device) -> SppBottleneck { 123 | SppBottleneck { 124 | conv1: self.conv1.init(device), 125 | conv2: self.conv2.init(device), 126 | m: self.m.iter().map(|m| m.init()).collect(), 127 | } 128 | } 129 | } 130 | 131 | /// Simplified Cross Stage Partial bottleneck with 3 convolutional layers. 132 | /// Equivalent to C3 in YOLOv5. 133 | #[derive(Module, Debug)] 134 | pub struct CspBottleneck { 135 | conv1: BaseConv, 136 | conv2: BaseConv, 137 | conv3: BaseConv, 138 | m: Vec>, 139 | } 140 | 141 | impl CspBottleneck { 142 | pub fn forward(&self, x: Tensor) -> Tensor { 143 | let x1 = self.conv1.forward(x.clone()); 144 | let x2 = self.conv2.forward(x); 145 | 146 | let x1 = self 147 | .m 148 | .iter() 149 | .fold(x1, |x_i, bottleneck| bottleneck.forward(x_i)); 150 | 151 | let x = Tensor::cat(vec![x1, x2], 1); 152 | 153 | self.conv3.forward(x) 154 | } 155 | } 156 | 157 | /// [CspBottleneck block](CspBottleneck) configuration. 158 | pub struct CspBottleneckConfig { 159 | conv1: BaseConvConfig, 160 | conv2: BaseConvConfig, 161 | conv3: BaseConvConfig, 162 | m: Vec, 163 | } 164 | 165 | impl CspBottleneckConfig { 166 | /// Create a new instance of the bottleneck block [config](CspBottleneckConfig). 
167 | pub fn new( 168 | in_channels: usize, 169 | out_channels: usize, 170 | num_blocks: usize, 171 | expansion: f64, 172 | shortcut: bool, 173 | depthwise: bool, 174 | ) -> Self { 175 | assert!( 176 | expansion > 0.0 && expansion <= 1.0, 177 | "expansion should be in range (0, 1]" 178 | ); 179 | 180 | let hidden_channels = expand(out_channels, expansion); 181 | 182 | let conv1 = BaseConvConfig::new(in_channels, hidden_channels, 1, 1, 1); 183 | let conv2 = BaseConvConfig::new(in_channels, hidden_channels, 1, 1, 1); 184 | let conv3 = BaseConvConfig::new(2 * hidden_channels, out_channels, 1, 1, 1); 185 | let m = (0..num_blocks) 186 | .map(|_| BottleneckConfig::new(hidden_channels, hidden_channels, shortcut, depthwise)) 187 | .collect(); 188 | 189 | Self { 190 | conv1, 191 | conv2, 192 | conv3, 193 | m, 194 | } 195 | } 196 | 197 | /// Initialize a new [bottleneck block](CspBottleneck) module. 198 | pub fn init(&self, device: &Device) -> CspBottleneck { 199 | CspBottleneck { 200 | conv1: self.conv1.init(device), 201 | conv2: self.conv2.init(device), 202 | conv3: self.conv3.init(device), 203 | m: self.m.iter().map(|b| b.init(device)).collect(), 204 | } 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /yolox-burn/src/model/boxes.rs: -------------------------------------------------------------------------------- 1 | use burn::tensor::{backend::Backend, ElementConversion, Tensor}; 2 | use itertools::Itertools; 3 | 4 | pub struct BoundingBox { 5 | pub xmin: f32, 6 | pub ymin: f32, 7 | pub xmax: f32, 8 | pub ymax: f32, 9 | pub confidence: f32, 10 | } 11 | 12 | /// Non-maximum suppression (NMS) filters overlapping bounding boxes that have an intersection-over- 13 | /// union (IoU) greater or equal than the specified `iou_threshold` with previously selected boxes. 14 | /// 15 | /// Boxes are filtered based on `score_threshold` and ranked based on their score. As such, lower 16 | /// scoring boxes are removed when overlapping with another (higher scoring) box. 17 | /// 18 | /// # Arguments 19 | /// 20 | /// * `boxes`: Bounding box coordinates. Shape: `[batch_size, num_boxes, 4]`. 21 | /// * `scores` - Classification scores for each box. Shape: `[batch_size, num_boxes, num_classes]`. 22 | /// * `iou_threshold` - Scalar threshold for IoU. 23 | /// * `score_threshold` - Scalar threshold for scores. 24 | /// 25 | /// # Returns 26 | /// 27 | /// Vector of bounding boxes grouped by class for each batch. The boxes are sorted in decreasing 28 | /// order of scores for each class. 
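/// # Example
///
/// A minimal sketch (mirroring `examples/inference.rs`, assuming `out` is the
/// `[batch, num_boxes, 5 + num_classes]` tensor produced by the model's forward pass):
///
/// ```ignore
/// let [_, num_boxes, num_outputs] = out.dims();
/// // Box coordinates, objectness and per-class scores are laid out along the last dimension.
/// let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]);
/// let obj_scores = out.clone().slice([0..1, 0..num_boxes, 4..5]);
/// let cls_scores = out.slice([0..1, 0..num_boxes, 5..num_outputs]);
/// // Keep boxes with score >= 0.5, suppressing overlaps with IoU > 0.65.
/// let detections = nms(boxes, cls_scores * obj_scores, 0.65, 0.5);
/// ```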
29 | pub fn nms( 30 | boxes: Tensor, 31 | scores: Tensor, 32 | iou_threshold: f32, 33 | score_threshold: f32, 34 | ) -> Vec>> { 35 | let [batch_size, num_boxes, num_classes] = scores.dims(); 36 | 37 | // Bounding boxes grouped by batch and by (maximum) class index 38 | let mut bboxes = boxes 39 | .iter_dim(0) 40 | .zip(scores.iter_dim(0)) 41 | .enumerate() 42 | // Per-batch 43 | .map(|(_, (candidate_boxes, candidate_scores))| { 44 | // Keep max scoring boxes only ([num_boxes, 1], [num_boxes, 1]) 45 | let (cls_score, cls_idx) = candidate_scores.squeeze::<2>(0).max_dim_with_indices(1); 46 | let cls_score: Vec<_> = cls_score 47 | .into_data() 48 | .iter::() 49 | .map(|v| v.elem::()) 50 | .collect(); 51 | let cls_idx: Vec<_> = cls_idx 52 | .into_data() 53 | .iter::() 54 | .map(|v| v.elem::() as usize) 55 | .collect(); 56 | 57 | // [num_boxes, 4] 58 | let candidate_boxes: Vec<_> = candidate_boxes 59 | .into_data() 60 | .iter::() 61 | .map(|v| v.elem::()) 62 | .collect(); 63 | 64 | // Per-class filtering based on score 65 | (0..num_classes) 66 | .map(|cls_id| { 67 | // [num_boxes, 1] 68 | (0..num_boxes) 69 | .filter_map(|box_idx| { 70 | let box_cls_idx = cls_idx[box_idx]; 71 | if box_cls_idx != cls_id { 72 | return None; 73 | } 74 | let box_cls_score = cls_score[box_idx]; 75 | if box_cls_score >= score_threshold { 76 | let bbox = &candidate_boxes[box_idx * 4..box_idx * 4 + 4]; 77 | Some(BoundingBox { 78 | xmin: bbox[0] - bbox[2] / 2., 79 | ymin: bbox[1] - bbox[3] / 2., 80 | xmax: bbox[0] + bbox[2] / 2., 81 | ymax: bbox[1] + bbox[3] / 2., 82 | confidence: box_cls_score, 83 | }) 84 | } else { 85 | None 86 | } 87 | }) 88 | .sorted_unstable_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap()) 89 | .collect::>() 90 | }) 91 | .collect::>() 92 | }) 93 | .collect::>(); 94 | 95 | for batch_bboxes in bboxes.iter_mut().take(batch_size) { 96 | non_maximum_suppression(batch_bboxes, iou_threshold); 97 | } 98 | 99 | bboxes 100 | } 101 | 102 | /// Intersection over union of two bounding boxes. 103 | pub fn iou(b1: &BoundingBox, b2: &BoundingBox) -> f32 { 104 | let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); 105 | let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); 106 | let i_xmin = b1.xmin.max(b2.xmin); 107 | let i_xmax = b1.xmax.min(b2.xmax); 108 | let i_ymin = b1.ymin.max(b2.ymin); 109 | let i_ymax = b1.ymax.min(b2.ymax); 110 | let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); 111 | i_area / (b1_area + b2_area - i_area) 112 | } 113 | 114 | /// Perform non-maximum suppression over boxes of the same class. 
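/// Boxes in each class group are sorted by decreasing confidence; a box is kept only if its IoU
/// with every previously kept box does not exceed `threshold`. Kept boxes are compacted to the
/// front of the vector and the remainder is truncated.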
115 | pub fn non_maximum_suppression(bboxes: &mut [Vec], threshold: f32) { 116 | for bboxes_for_class in bboxes.iter_mut() { 117 | bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); 118 | let mut current_index = 0; 119 | for index in 0..bboxes_for_class.len() { 120 | let mut drop = false; 121 | for prev_index in 0..current_index { 122 | let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]); 123 | if iou > threshold { 124 | drop = true; 125 | break; 126 | } 127 | } 128 | if !drop { 129 | bboxes_for_class.swap(current_index, index); 130 | current_index += 1; 131 | } 132 | } 133 | bboxes_for_class.truncate(current_index); 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /yolox-burn/src/model/darknet.rs: -------------------------------------------------------------------------------- 1 | use core::cmp::max; 2 | 3 | use crate::model::blocks::expand; 4 | 5 | use super::{ 6 | blocks::{Conv, ConvConfig, Focus, FocusConfig}, 7 | bottleneck::{CspBottleneck, CspBottleneckConfig, SppBottleneck, SppBottleneckConfig}, 8 | }; 9 | use burn::{ 10 | module::Module, 11 | tensor::{backend::Backend, Device, Tensor}, 12 | }; 13 | 14 | /// Darknet backbone feature maps. 15 | pub struct DarknetFeatures(pub Tensor, pub Tensor, pub Tensor); 16 | 17 | /// [CSPDarknet-53](https://paperswithcode.com/method/cspdarknet53) backbone. 18 | #[derive(Module, Debug)] 19 | pub struct CspDarknet { 20 | stem: Focus, 21 | dark2: CspBlock, 22 | dark3: CspBlock, 23 | dark4: CspBlock, 24 | dark5: CspBlock, 25 | } 26 | 27 | impl CspDarknet { 28 | pub fn forward(&self, x: Tensor) -> DarknetFeatures { 29 | let x = self.stem.forward(x); 30 | let x = self.dark2.forward(x); 31 | let f1 = self.dark3.forward(x); 32 | let f2 = self.dark4.forward(f1.clone()); 33 | let f3 = self.dark5.forward(f2.clone()); 34 | 35 | DarknetFeatures(f1, f2, f3) 36 | } 37 | } 38 | 39 | /// [CSPDarknet-53](CspDarknet) configuration. 40 | pub struct CspDarknetConfig { 41 | stem: FocusConfig, 42 | dark2: CspBlockConfig, 43 | dark3: CspBlockConfig, 44 | dark4: CspBlockConfig, 45 | dark5: CspBlockConfig, 46 | } 47 | 48 | impl CspDarknetConfig { 49 | /// Create a new instance of the CSPDarknet-53 [config](CspDarknetConfig). 50 | pub fn new(depth: f64, width: f64, depthwise: bool) -> Self { 51 | assert!( 52 | [0.33, 0.67, 1.0, 1.33].contains(&depth), 53 | "invalid depth value {depth}" 54 | ); 55 | 56 | assert!( 57 | [0.25, 0.375, 0.5, 0.75, 1.0, 1.25].contains(&width), 58 | "invalid width value {width}" 59 | ); 60 | 61 | let base_channels = expand(64, width); 62 | let base_depth = max((depth * 3_f64).round() as usize, 1); 63 | 64 | let stem = FocusConfig::new(3, base_channels, 3, 1); 65 | let dark2 = CspBlockConfig::new( 66 | base_channels, 67 | base_channels * 2, 68 | base_depth, 69 | false, 70 | depthwise, 71 | ); 72 | let dark3 = CspBlockConfig::new( 73 | base_channels * 2, 74 | base_channels * 4, 75 | base_depth * 3, 76 | false, 77 | depthwise, 78 | ); 79 | let dark4 = CspBlockConfig::new( 80 | base_channels * 4, 81 | base_channels * 8, 82 | base_depth * 3, 83 | false, 84 | depthwise, 85 | ); 86 | let dark5 = CspBlockConfig::new( 87 | base_channels * 8, 88 | base_channels * 16, 89 | base_depth, 90 | true, 91 | depthwise, 92 | ); 93 | 94 | Self { 95 | stem, 96 | dark2, 97 | dark3, 98 | dark4, 99 | dark5, 100 | } 101 | } 102 | 103 | /// Initialize a new [CspDarknet](CspDarknet) module. 
104 | pub fn init(&self, device: &Device) -> CspDarknet { 105 | CspDarknet { 106 | stem: self.stem.init(device), 107 | dark2: self.dark2.init(device), 108 | dark3: self.dark3.init(device), 109 | dark4: self.dark4.init(device), 110 | dark5: self.dark5.init(device), 111 | } 112 | } 113 | } 114 | 115 | /// A BaseConv -> CspBottleneck block. 116 | /// The SppBottleneck layer is only used in the last block of [CSPDarknet-53](CspDarknet). 117 | #[derive(Module, Debug)] 118 | pub struct CspBlock { 119 | conv: Conv, 120 | c3: CspBottleneck, 121 | spp: Option>, 122 | } 123 | 124 | impl CspBlock { 125 | pub fn forward(&self, x: Tensor) -> Tensor { 126 | let mut x = self.conv.forward(x); 127 | 128 | if let Some(spp) = &self.spp { 129 | x = spp.forward(x); 130 | } 131 | 132 | self.c3.forward(x) 133 | } 134 | } 135 | 136 | /// [CSP block](CspBlock) configuration. 137 | pub struct CspBlockConfig { 138 | conv: ConvConfig, 139 | c3: CspBottleneckConfig, 140 | spp: Option, 141 | } 142 | 143 | impl CspBlockConfig { 144 | /// Create a new instance of the CSP block [config](CspBlockConfig). 145 | pub fn new( 146 | in_channels: usize, 147 | out_channels: usize, 148 | depth: usize, 149 | spp: bool, 150 | depthwise: bool, 151 | ) -> Self { 152 | let conv = ConvConfig::new(in_channels, out_channels, 3, 2, depthwise); 153 | let c3 = CspBottleneckConfig::new(out_channels, out_channels, depth, 0.5, !spp, depthwise); 154 | 155 | let spp = if spp { 156 | Some(SppBottleneckConfig::new(out_channels, out_channels)) 157 | } else { 158 | None 159 | }; 160 | 161 | Self { conv, c3, spp } 162 | } 163 | 164 | /// Initialize a new [CSP block](CspBlock) module. 165 | pub fn init(&self, device: &Device) -> CspBlock { 166 | CspBlock { 167 | conv: self.conv.init(device), 168 | c3: self.c3.init(device), 169 | spp: self.spp.as_ref().map(|spp| spp.init(device)), 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /yolox-burn/src/model/head.rs: -------------------------------------------------------------------------------- 1 | use alloc::{vec, vec::Vec}; 2 | use burn::{ 3 | module::Module, 4 | nn::{ 5 | conv::{Conv2d, Conv2dConfig}, 6 | Initializer, PaddingConfig2d, 7 | }, 8 | tensor::{activation::sigmoid, backend::Backend, Device, Int, Shape, Tensor}, 9 | }; 10 | use itertools::{izip, multiunzip}; 11 | 12 | use super::{ 13 | blocks::{expand, BaseConv, BaseConvConfig, ConvBlock, ConvBlockConfig}, 14 | pafpn::FpnFeatures, 15 | }; 16 | 17 | const STRIDES: [usize; 3] = [8, 16, 32]; 18 | const IN_CHANNELS: [usize; 3] = [256, 512, 1024]; 19 | const PRIOR_PROB: f64 = 1e-2; 20 | 21 | /// Create a 2D coordinate grid for the specified dimensions. 22 | /// Similar to [`numpy.indices`](https://numpy.org/doc/stable/reference/generated/numpy.indices.html) 23 | /// but specific to two dimensions. 24 | fn create_2d_grid(x: usize, y: usize, device: &Device) -> Tensor { 25 | let y_idx = Tensor::arange(0..y as i64, device) 26 | .reshape::<2, _>(Shape::new([y, 1])) 27 | .repeat_dim(1, x) 28 | .reshape::<2, _>(Shape::new([y, x])); 29 | let x_idx = Tensor::arange(0..x as i64, device) 30 | .reshape::<2, _>(Shape::new([1, x])) // can only repeat with dim=1 31 | .repeat_dim(0, y) 32 | .reshape(Shape::new([y, x])); 33 | 34 | Tensor::stack(vec![x_idx, y_idx], 2) 35 | } 36 | 37 | /// YOLOX head. 
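/// The head is decoupled: for each of the three PAFPN feature maps (strides 8, 16 and 32), a 1x1
/// stem convolution feeds separate classification, box regression and objectness branches.
/// Per-level outputs are flattened and concatenated into a `[batch, num_anchors, 5 + num_classes]`
/// tensor, with box centers and sizes decoded to input-image pixel coordinates.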
38 | #[derive(Module, Debug)] 39 | pub struct Head { 40 | stems: Vec>, 41 | cls_convs: Vec>, 42 | reg_convs: Vec>, 43 | cls_preds: Vec>, 44 | reg_preds: Vec>, 45 | obj_preds: Vec>, 46 | } 47 | 48 | impl Head { 49 | pub fn forward(&self, x: FpnFeatures) -> Tensor { 50 | let features: [Tensor; 3] = [x.0, x.1, x.2]; 51 | 52 | // Outputs for each feature map 53 | let (outputs, shapes): (Vec>, Vec<(usize, usize)>) = izip!( 54 | features, 55 | &self.stems, 56 | &self.cls_convs, 57 | &self.cls_preds, 58 | &self.reg_convs, 59 | &self.reg_preds, 60 | &self.obj_preds, 61 | &STRIDES 62 | ) 63 | .map( 64 | |(feat, stem, cls_conv, cls_pred, reg_conv, reg_pred, obj_pred, _stride)| { 65 | let feat = stem.forward(feat); 66 | 67 | let cls_feat = cls_conv.forward(feat.clone()); 68 | let cls_out = cls_pred.forward(cls_feat); 69 | 70 | let reg_feat = reg_conv.forward(feat); 71 | let reg_out = reg_pred.forward(reg_feat.clone()); 72 | 73 | let obj_out = obj_pred.forward(reg_feat); 74 | 75 | // Output [B, 5 + num_classes, num_anchors] 76 | let out = Tensor::cat(vec![reg_out, sigmoid(obj_out), sigmoid(cls_out)], 1); 77 | let [_, _, h, w] = out.dims(); 78 | (out.flatten(2, 3), (h, w)) 79 | }, 80 | ) 81 | .unzip(); 82 | 83 | // 1. Concat all regression outputs 84 | // 2. Permute shape to [B, num_anchors_total, 5 + num_classes] 85 | // 3. Decode absolute bounding box values 86 | self.decode(Tensor::cat(outputs, 2).swap_dims(2, 1), shapes.as_ref()) 87 | } 88 | 89 | /// Decode bounding box absolute values from regression output offsets. 90 | fn decode(&self, outputs: Tensor, shapes: &[(usize, usize)]) -> Tensor { 91 | let device = outputs.device(); 92 | let [b, num_anchors, num_outputs] = outputs.dims(); 93 | 94 | let (grids, strides) = shapes 95 | .iter() 96 | .zip(STRIDES) 97 | .map(|((h, w), stride)| { 98 | // Grid (x, y) coordinates 99 | let num_anchors = w * h; 100 | let grid = 101 | create_2d_grid::(*w, *h, &device).reshape(Shape::new([1, num_anchors, 2])); 102 | let strides: Tensor = 103 | Tensor::full(Shape::new([1, num_anchors, 1]), stride as i64, &device); 104 | 105 | (grid, strides) 106 | }) 107 | .unzip(); 108 | 109 | let grids = Tensor::cat(grids, 1).float(); 110 | let strides = Tensor::cat(strides, 1).float(); 111 | 112 | Tensor::cat( 113 | vec![ 114 | // Add grid offset to center coordinates and scale to image dimensions 115 | (outputs.clone().slice([0..b, 0..num_anchors, 0..2]) + grids) * strides.clone(), 116 | // Decode `log` encoded boxes with `exp`and scale to image dimensions 117 | outputs.clone().slice([0..b, 0..num_anchors, 2..4]).exp() * strides, 118 | // Classification outputs 119 | outputs.slice([0..b, 0..num_anchors, 4..num_outputs]), 120 | ], 121 | 2, 122 | ) 123 | } 124 | } 125 | 126 | /// [YOLOX head](Head) configuration. 127 | pub struct HeadConfig { 128 | stems: Vec, 129 | cls_convs: Vec, 130 | reg_convs: Vec, 131 | cls_preds: Vec, 132 | reg_preds: Vec, 133 | obj_preds: Vec, 134 | } 135 | 136 | impl HeadConfig { 137 | /// Create a new instance of the YOLOX head [config](HeadConfig). 
138 | pub fn new(num_classes: usize, width: f64, depthwise: bool) -> Self { 139 | let hidden_channels: usize = 256; 140 | // Initialize conv2d biases for classification and objectness heads 141 | let bias = -f64::ln((1.0 - PRIOR_PROB) / PRIOR_PROB); 142 | 143 | let (stems, cls_convs, reg_convs, cls_preds, reg_preds, obj_preds) = 144 | multiunzip(IN_CHANNELS.into_iter().map(|in_channels| { 145 | let stem = BaseConvConfig::new( 146 | expand(in_channels, width), 147 | expand(hidden_channels, width), 148 | 1, 149 | 1, 150 | 1, 151 | ); 152 | 153 | let cls_conv = 154 | ConvBlockConfig::new(expand(hidden_channels, width), 3, 1, depthwise); 155 | let reg_conv = 156 | ConvBlockConfig::new(expand(hidden_channels, width), 3, 1, depthwise); 157 | 158 | let cls_pred = 159 | Conv2dConfig::new([expand(hidden_channels, width), num_classes], [1, 1]) 160 | .with_padding(PaddingConfig2d::Explicit(0, 0)) 161 | .with_initializer(Initializer::Constant { value: bias }); 162 | let reg_pred = Conv2dConfig::new([expand(hidden_channels, width), 4], [1, 1]) 163 | .with_padding(PaddingConfig2d::Explicit(0, 0)); 164 | let obj_pred = Conv2dConfig::new([expand(hidden_channels, width), 1], [1, 1]) 165 | .with_padding(PaddingConfig2d::Explicit(0, 0)) 166 | .with_initializer(Initializer::Constant { value: bias }); 167 | 168 | (stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred) 169 | })); 170 | 171 | Self { 172 | stems, 173 | cls_convs, 174 | reg_convs, 175 | cls_preds, 176 | reg_preds, 177 | obj_preds, 178 | } 179 | } 180 | 181 | /// Initialize a new [YOLOX head](Head) module. 182 | pub fn init(&self, device: &Device) -> Head { 183 | Head { 184 | stems: self.stems.iter().map(|m| m.init(device)).collect(), 185 | cls_convs: self.cls_convs.iter().map(|m| m.init(device)).collect(), 186 | reg_convs: self.reg_convs.iter().map(|m| m.init(device)).collect(), 187 | cls_preds: self.cls_preds.iter().map(|m| m.init(device)).collect(), 188 | reg_preds: self.reg_preds.iter().map(|m| m.init(device)).collect(), 189 | obj_preds: self.obj_preds.iter().map(|m| m.init(device)).collect(), 190 | } 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /yolox-burn/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | mod blocks; 2 | mod bottleneck; 3 | pub mod boxes; 4 | mod darknet; 5 | mod head; 6 | mod pafpn; 7 | pub mod weights; 8 | pub mod yolox; 9 | 10 | pub use boxes::BoundingBox; 11 | -------------------------------------------------------------------------------- /yolox-burn/src/model/pafpn.rs: -------------------------------------------------------------------------------- 1 | use alloc::vec; 2 | use burn::{ 3 | module::Module, 4 | tensor::{ 5 | backend::Backend, 6 | module::interpolate, 7 | ops::{InterpolateMode, InterpolateOptions}, 8 | Device, Tensor, 9 | }, 10 | }; 11 | 12 | use super::{ 13 | blocks::{expand, BaseConv, BaseConvConfig, Conv, ConvConfig}, 14 | bottleneck::{CspBottleneck, CspBottleneckConfig}, 15 | darknet::{CspDarknet, CspDarknetConfig}, 16 | }; 17 | 18 | pub struct FpnFeatures(pub Tensor, pub Tensor, pub Tensor); 19 | 20 | /// [PAFPN](https://paperswithcode.com/method/pafpn) is the feature pyramid module used in 21 | /// [Path Aggregation Network](https://arxiv.org/abs/1803.01534) that combines FPNs with 22 | /// bottom-up path augmentation. 
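/// Takes the three backbone feature maps from [CspDarknet] (strides 8, 16 and 32), fuses them
/// top-down with nearest-neighbor upsampling and lateral 1x1 convolutions, then bottom-up with
/// strided convolutions, and returns three refined maps at the same strides for the detection head.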
23 | #[derive(Module, Debug)] 24 | pub struct Pafpn { 25 | backbone: CspDarknet, 26 | lateral_conv0: BaseConv, 27 | c3_n3: CspBottleneck, 28 | c3_n4: CspBottleneck, 29 | c3_p3: CspBottleneck, 30 | c3_p4: CspBottleneck, 31 | reduce_conv1: BaseConv, 32 | bu_conv1: Conv, // bottom-up conv 33 | bu_conv2: Conv, // bottom-up conv 34 | } 35 | 36 | impl Pafpn { 37 | pub fn forward(&self, x: Tensor) -> FpnFeatures { 38 | fn upsample(x_in: Tensor, scale: usize) -> Tensor { 39 | let [_, _, h, w] = x_in.dims(); 40 | interpolate( 41 | x_in, 42 | [h * scale, w * scale], 43 | InterpolateOptions::new(InterpolateMode::Nearest), 44 | ) 45 | } 46 | 47 | // Backbone features 48 | let features = self.backbone.forward(x); 49 | 50 | let fpn_out0 = self.lateral_conv0.forward(features.2); 51 | let f_out0 = upsample(fpn_out0.clone(), 2); 52 | let f_out0 = Tensor::cat(vec![f_out0, features.1], 1); 53 | let f_out0 = self.c3_p4.forward(f_out0); 54 | 55 | let fpn_out1 = self.reduce_conv1.forward(f_out0); 56 | let f_out1 = upsample(fpn_out1.clone(), 2); 57 | let f_out1 = Tensor::cat(vec![f_out1, features.0], 1); 58 | let pan_out2 = self.c3_p3.forward(f_out1); 59 | 60 | let p_out1 = self.bu_conv2.forward(pan_out2.clone()); 61 | let p_out1 = Tensor::cat(vec![p_out1, fpn_out1], 1); 62 | let pan_out1 = self.c3_n3.forward(p_out1); 63 | 64 | let p_out0 = self.bu_conv1.forward(pan_out1.clone()); 65 | let p_out0 = Tensor::cat(vec![p_out0, fpn_out0], 1); 66 | let pan_out0 = self.c3_n4.forward(p_out0); 67 | 68 | FpnFeatures(pan_out2, pan_out1, pan_out0) 69 | } 70 | } 71 | 72 | /// [PAFPN block](Pafpn) configuration. 73 | pub struct PafpnConfig { 74 | backbone: CspDarknetConfig, 75 | lateral_conv0: BaseConvConfig, 76 | c3_n3: CspBottleneckConfig, 77 | c3_n4: CspBottleneckConfig, 78 | c3_p3: CspBottleneckConfig, 79 | c3_p4: CspBottleneckConfig, 80 | reduce_conv1: BaseConvConfig, 81 | bu_conv1: ConvConfig, // bottom-up conv 82 | bu_conv2: ConvConfig, // bottom-up conv 83 | } 84 | 85 | impl PafpnConfig { 86 | /// Create a new instance of the PAFPN [config](PafpnConfig). 
87 | pub fn new(depth: f64, width: f64, depthwise: bool) -> Self { 88 | assert!( 89 | [0.33, 0.67, 1.0, 1.33].contains(&depth), 90 | "invalid depth value {depth}" 91 | ); 92 | assert!( 93 | [0.25, 0.375, 0.5, 0.75, 1.0, 1.25].contains(&width), 94 | "invalid width value {width}" 95 | ); 96 | 97 | let in_channels: [usize; 3] = [256, 512, 1024]; 98 | let hidden_channels: [usize; 2] = [ 99 | expand(2 * in_channels[0], width), 100 | expand(2 * in_channels[1], width), 101 | ]; 102 | let in_channels: [usize; 3] = [ 103 | expand(in_channels[0], width), 104 | expand(in_channels[1], width), 105 | expand(in_channels[2], width), 106 | ]; 107 | let num_blocks = (3_f64 * depth).round() as usize; 108 | 109 | let backbone = CspDarknetConfig::new(depth, width, depthwise); 110 | let lateral_conv0 = BaseConvConfig::new(in_channels[2], in_channels[1], 1, 1, 1); 111 | let c3_p4 = CspBottleneckConfig::new( 112 | hidden_channels[1], 113 | in_channels[1], 114 | num_blocks, 115 | 0.5, 116 | false, 117 | depthwise, 118 | ); 119 | 120 | let reduce_conv1 = BaseConvConfig::new(in_channels[1], in_channels[0], 1, 1, 1); 121 | let c3_p3 = CspBottleneckConfig::new( 122 | hidden_channels[0], 123 | in_channels[0], 124 | num_blocks, 125 | 0.5, 126 | false, 127 | depthwise, 128 | ); 129 | 130 | let bu_conv2 = ConvConfig::new(in_channels[0], in_channels[0], 3, 2, depthwise); 131 | let c3_n3 = CspBottleneckConfig::new( 132 | hidden_channels[0], 133 | in_channels[1], 134 | num_blocks, 135 | 0.5, 136 | false, 137 | depthwise, 138 | ); 139 | 140 | let bu_conv1 = ConvConfig::new(in_channels[1], in_channels[1], 3, 2, depthwise); 141 | let c3_n4 = CspBottleneckConfig::new( 142 | hidden_channels[1], 143 | in_channels[2], 144 | num_blocks, 145 | 0.5, 146 | false, 147 | depthwise, 148 | ); 149 | 150 | Self { 151 | backbone, 152 | lateral_conv0, 153 | c3_n3, 154 | c3_n4, 155 | c3_p3, 156 | c3_p4, 157 | reduce_conv1, 158 | bu_conv1, 159 | bu_conv2, 160 | } 161 | } 162 | 163 | /// Initialize a new [PAFPN](Pafpn) module. 164 | pub fn init(&self, device: &Device) -> Pafpn { 165 | Pafpn { 166 | backbone: self.backbone.init(device), 167 | lateral_conv0: self.lateral_conv0.init(device), 168 | c3_n3: self.c3_n3.init(device), 169 | c3_n4: self.c3_n4.init(device), 170 | c3_p3: self.c3_p3.init(device), 171 | c3_p4: self.c3_p4.init(device), 172 | reduce_conv1: self.reduce_conv1.init(device), 173 | bu_conv1: self.bu_conv1.init(device), 174 | bu_conv2: self.bu_conv2.init(device), 175 | } 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /yolox-burn/src/model/weights.rs: -------------------------------------------------------------------------------- 1 | /// Pre-trained weights metadata. 2 | pub struct Weights { 3 | pub(super) url: &'static str, 4 | pub(super) num_classes: usize, 5 | } 6 | 7 | #[cfg(feature = "pretrained")] 8 | mod downloader { 9 | use super::*; 10 | use burn::data::network::downloader; 11 | use std::fs::{create_dir_all, File}; 12 | use std::io::Write; 13 | use std::path::PathBuf; 14 | 15 | impl Weights { 16 | /// Download the pre-trained weights to the local cache directory. 
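/// The file is stored under `$HOME/.cache/yolox-burn/` and downloaded only if it is not already
/// present; the returned path points to the cached weights file.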
17 | pub fn download(&self) -> Result { 18 | // Model cache directory 19 | let model_dir = dirs::home_dir() 20 | .expect("Should be able to get home directory") 21 | .join(".cache") 22 | .join("yolox-burn"); 23 | 24 | if !model_dir.exists() { 25 | create_dir_all(&model_dir)?; 26 | } 27 | 28 | let file_base_name = self.url.rsplit_once('/').unwrap().1; 29 | let file_name = model_dir.join(file_base_name); 30 | if !file_name.exists() { 31 | // Download file content 32 | let bytes = downloader::download_file_as_bytes(self.url, file_base_name); 33 | 34 | // Write content to file 35 | let mut output_file = File::create(&file_name)?; 36 | let bytes_written = output_file.write(&bytes)?; 37 | 38 | if bytes_written != bytes.len() { 39 | return Err(std::io::Error::new( 40 | std::io::ErrorKind::InvalidData, 41 | "Failed to write the whole model weights file.", 42 | )); 43 | } 44 | } 45 | 46 | Ok(file_name) 47 | } 48 | } 49 | } 50 | 51 | pub trait WeightsMeta { 52 | fn weights(&self) -> Weights; 53 | } 54 | 55 | /// YOLOX-Nano pre-trained weights. 56 | pub enum YoloxNano { 57 | /// These weights were released after the original paper implementation with slightly better results. 58 | /// mAP (val2017): 25.8 59 | Coco, 60 | } 61 | impl WeightsMeta for YoloxNano { 62 | fn weights(&self) -> Weights { 63 | Weights { 64 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano.pth", 65 | num_classes: 80, 66 | } 67 | } 68 | } 69 | 70 | /// YOLOX-Tiny pre-trained weights. 71 | pub enum YoloxTiny { 72 | /// These weights were released after the original paper implementation with slightly better results. 73 | /// mAP (val2017): 32.8 74 | Coco, 75 | } 76 | impl WeightsMeta for YoloxTiny { 77 | fn weights(&self) -> Weights { 78 | Weights { 79 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny.pth", 80 | num_classes: 80, 81 | } 82 | } 83 | } 84 | 85 | /// YOLOX-S pre-trained weights. 86 | pub enum YoloxS { 87 | /// These weights were released after the original paper implementation with slightly better results. 88 | /// mAP (test2017): 40.5 89 | Coco, 90 | } 91 | impl WeightsMeta for YoloxS { 92 | fn weights(&self) -> Weights { 93 | Weights { 94 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth", 95 | num_classes: 80, 96 | } 97 | } 98 | } 99 | 100 | /// YOLOX-M pre-trained weights. 101 | pub enum YoloxM { 102 | /// These weights were released after the original paper implementation with slightly better results. 103 | /// mAP (test2017): 47.2 104 | Coco, 105 | } 106 | impl WeightsMeta for YoloxM { 107 | fn weights(&self) -> Weights { 108 | Weights { 109 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m.pth", 110 | num_classes: 80, 111 | } 112 | } 113 | } 114 | 115 | /// YOLOX-L pre-trained weights. 116 | pub enum YoloxL { 117 | /// These weights were released after the original paper implementation with slightly better results. 118 | /// mAP (test2017): 50.1 119 | Coco, 120 | } 121 | impl WeightsMeta for YoloxL { 122 | fn weights(&self) -> Weights { 123 | Weights { 124 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l.pth", 125 | num_classes: 80, 126 | } 127 | } 128 | } 129 | 130 | /// YOLOX-X pre-trained weights. 131 | pub enum YoloxX { 132 | /// These weights were released after the original paper implementation with slightly better results. 
133 | /// mAP (test2017): 51.5 134 | Coco, 135 | } 136 | impl WeightsMeta for YoloxX { 137 | fn weights(&self) -> Weights { 138 | Weights { 139 | url: "https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth", 140 | num_classes: 80, 141 | } 142 | } 143 | } 144 | --------------------------------------------------------------------------------