├── src ├── lib.rs ├── models │ ├── mod.rs │ ├── modelling_outputs.rs │ ├── model_utils.rs │ ├── roberta.rs │ └── xlm_roberta.rs ├── main.rs └── utils.rs ├── assets └── layer_norm.png ├── .gitignore ├── Cargo.toml ├── tests ├── test_xlm_roberta.rs └── test_roberta.rs ├── examples ├── tensor.rs └── roberta_tutorial.rs ├── python └── roberta_main.py └── README.md /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod models; 2 | pub mod utils; -------------------------------------------------------------------------------- /assets/layer_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToluClassics/candle-tutorial/HEAD/assets/layer_norm.png -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod roberta; 2 | pub mod xlm_roberta; 3 | 4 | pub mod model_utils; 5 | pub mod modelling_outputs; -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | .vscode/ 17 | .DS_Store 18 | .idea/ -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "candle-tutorial" 3 | version = "0.1.0" 4 | 
edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" } 10 | candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" } 11 | serde = "1.0" 12 | anyhow = "1.0" 13 | hf-hub = "0.3.2" 14 | tokenizers = "0.14.1" 15 | serde_json = "1.0.107" 16 | records = "0.2.0" 17 | 18 | 19 | [[example]] 20 | name = "tensor" 21 | 22 | -------------------------------------------------------------------------------- /src/models/modelling_outputs.rs: -------------------------------------------------------------------------------- 1 | use candle_core::Tensor; 2 | 3 | 4 | #[derive(Debug)] 5 | #[records::record] 6 | pub struct SequenceClassifierOutput { 7 | loss: Option, 8 | logits: Tensor, 9 | hidden_states: Option, 10 | attentions: Option 11 | } 12 | 13 | #[derive(Debug)] 14 | #[records::record] 15 | pub struct TokenClassifierOutput { 16 | loss: Option, 17 | logits: Tensor, 18 | hidden_states: Option, 19 | attentions: Option 20 | } 21 | 22 | #[derive(Debug)] 23 | #[records::record] 24 | pub struct QuestionAnsweringModelOutput { 25 | loss: Option, 26 | start_logits: Tensor, 27 | end_logits: Tensor, 28 | hidden_states: Option, 29 | attentions: Option 30 | } -------------------------------------------------------------------------------- /tests/test_xlm_roberta.rs: -------------------------------------------------------------------------------- 1 | mod tests { 2 | use candle_tutorial::models::xlm_roberta::XLMRobertaModel; 3 | use candle_tutorial::models::xlm_roberta::XLMRobertaForTokenClassification ; 4 | use candle_tutorial::utils::{build_roberta_model_and_tokenizer, ModelType}; 5 | 6 | use anyhow::Result; 7 | use candle_core::Tensor; 8 | 9 | 10 | // 
// Mirrors https://github.com/huggingface/transformers/blob/46092f763d26eb938a937c2a9cc69ce1cb6c44c2/tests/models/xlm_roberta/test_modeling_xlm_roberta.py#L32
/// Loads `xlm-roberta-base` as a bare encoder and checks that a 12-token
/// input yields a `[1, 12, 768]` hidden-state tensor.
///
/// Every fallible step is propagated with `?` (the test already returns
/// `Result`), so a download or load failure is reported as an error with its
/// cause instead of an `unwrap` panic.
#[test]
fn test_modeling_xlm_roberta_base() -> Result<()> {
    let model_type = "XLMRobertaModel";
    let (model, _tokenizer) =
        build_roberta_model_and_tokenizer("xlm-roberta-base", false, model_type)?;

    // The builder returns an enum; anything but the requested variant is a bug
    // in the builder itself, hence the panic rather than an error.
    let model: XLMRobertaModel = match model {
        ModelType::XLMRobertaModel { model } => model,
        _ => panic!("Invalid model_type"),
    };

    let input_ids = &[[
        0u32, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2,
    ]];
    let input_ids = Tensor::new(input_ids, &model.device)?;

    // Single-segment input: all token-type ids are zero.
    let token_ids = input_ids.zeros_like()?;
    let output = model.forward(&input_ids, &token_ids)?;

    let expected_shape = [1, 12, 768];
    assert_eq!(output.shape().dims(), &expected_shape);

    Ok(())
}

/// Runs the `Davlan/xlm-roberta-base-wikiann-ner` token-classification head on
/// an 11-token input, checks the batch and sequence dimensions of the logits,
/// and prints both the softmax probabilities and the raw logits.
///
/// The number of labels depends on the downloaded checkpoint, so only the
/// first two dimensions are asserted.
#[test]
fn test_inference_token_classification_head() -> Result<()> {
    let model_type = "XLMRobertaForTokenClassification";
    let (model, _tokenizer) = build_roberta_model_and_tokenizer(
        "Davlan/xlm-roberta-base-wikiann-ner",
        false,
        model_type,
    )?;

    let model: XLMRobertaForTokenClassification = match model {
        ModelType::XLMRobertaForTokenClassification { model } => model,
        _ => panic!("Invalid model_type"),
    };

    let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]];
    let input_ids = Tensor::new(input_ids, &model.device)?;

    let token_ids = input_ids.zeros_like()?;
    // `None`: no labels are supplied, so only logits are of interest here.
    let output = model.forward(&input_ids, &token_ids, None)?;

    // One logit vector per input token.
    let (batch, seq_len, _num_labels) = output.logits.dims3()?;
    assert_eq!((batch, seq_len), (1, 11));

    println!(
        "Output: {:?}",
        candle_nn::ops::softmax(&output.logits, candle_core::D::Minus1)?.to_vec3::<f32>()?
    );
    println!("Output: {:?}", output.logits.to_vec3::<f32>()?);

    Ok(())
}
/examples/tensor.rs: -------------------------------------------------------------------------------- 1 | // use anyhow::Result; 2 | use candle_core::{Device, Tensor, Result}; 3 | 4 | fn tensor_from_data() -> Result<()> { 5 | let data: [u32; 3] = [1u32, 2, 3]; 6 | let tensor = Tensor::new(&data, &Device::Cpu)?; 7 | println!("tensor: {:?}", tensor.to_vec1::()?); 8 | 9 | let nested_data: [[u32; 3]; 3] = [[1u32, 2, 3], [4, 5, 6], [7, 8, 9]]; 10 | let nested_tensor = Tensor::new(&nested_data, &Device::Cpu)?; 11 | println!("nested_tensor: {:?}", nested_tensor.to_vec2::()?); 12 | 13 | Ok(()) 14 | } 15 | 16 | fn tensor_from_another_tensor() -> Result<()> { 17 | let data: [u32; 3] = [1u32, 2, 3]; 18 | let tensor = Tensor::new(&data, &Device::Cpu)?; 19 | let zero_tensor = tensor.zeros_like()?; 20 | 21 | println!("zero_tensor: {:?}", zero_tensor.to_vec1::()?); 22 | 23 | let ones_tensor = tensor.ones_like()?; 24 | println!("ones_tensor: {:?}", ones_tensor.to_vec1::()?); 25 | 26 | let random_tensor = tensor.rand_like(0.0, 1.0)?; 27 | println!( 28 | "random_tensor: {:?}", 29 | random_tensor.to_vec1::().unwrap() 30 | ); 31 | 32 | Ok(()) 33 | } 34 | 35 | pub fn sigmoid(xs: &Tensor) -> Result { 36 | // TODO: Should we have a specialized op for this? 37 | (xs.neg()?.exp()? + 1.0)?.recip() 38 | } 39 | 40 | pub fn binary_cross_entropy(inp: &Tensor, target: &Tensor) -> Result { 41 | let inp = sigmoid(inp)?; 42 | 43 | let one_tensor = Tensor::new(1.0, &inp.device())?; 44 | 45 | let left_side = target * inp.log()?; 46 | let right_side = (one_tensor.broadcast_sub(&target)?) * (one_tensor.broadcast_sub(&inp)?.log()?); 47 | 48 | let loss = left_side? 
+ right_side?; 49 | let loss = loss?.neg()?.mean_all()?; 50 | 51 | Ok(loss) 52 | } 53 | 54 | fn main() { 55 | let _ = tensor_from_data(); 56 | 57 | let _ = tensor_from_another_tensor(); 58 | 59 | let inp = [[ 2.3611f64, -0.8813, -0.5006, -0.2178], 60 | [ 0.0419, 0.0763, -1.0457, -1.6692], 61 | [-1.0494, 0.8111, 1.5723, 1.2315], 62 | [ 1.3081, 0.6641, 1.1802, -0.2547], 63 | [ 0.5292, 0.7636, 0.3692, -0.8318], 64 | [ 0.5100, 0.9849, -1.2905, 0.2821], 65 | [ 1.4662, 0.4550, 0.9875, 0.3143], 66 | [-1.2121, 0.1262, 0.0598, -1.6363], 67 | [ 0.3214, -0.8689, 0.0689, -2.5094], 68 | [ 1.1320, -0.6824, 0.1657, -0.0687]]; 69 | 70 | let target = [[0.0f64, 1., 0., 0.], 71 | [0., 1., 0., 0.], 72 | [0., 0., 0., 1.], 73 | [1., 0., 0., 0.], 74 | [0., 0., 1., 0.], 75 | [1., 0., 0., 0.], 76 | [0., 0., 1., 0.], 77 | [0., 0., 1., 0.], 78 | [0., 1., 0., 0.], 79 | [0., 0., 1., 0.]]; 80 | 81 | let device = Device::Cpu; 82 | 83 | let inp = Tensor::new(&inp, &device).unwrap(); 84 | let target = Tensor::new(&target, &device).unwrap(); 85 | 86 | let loss = binary_cross_entropy(&inp, &target).unwrap(); 87 | 88 | println!("{:?}", loss) 89 | } 90 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Error as E, Result}; 2 | use candle_core::{Device, IndexOp, Tensor}; 3 | use candle_nn::VarBuilder; 4 | use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; 5 | use tokenizers::Tokenizer; 6 | 7 | use candle_tutorial::models::roberta::{RobertaConfig, RobertaModel, FLOATING_DTYPE}; 8 | 9 | fn build_model_and_tokenizer() -> Result<(RobertaModel, Tokenizer)> { 10 | let device = Device::Cpu; 11 | let default_model = "roberta-base".to_string(); 12 | let default_revision = "main".to_string(); 13 | let (model_id, revision) = (default_model, default_revision); 14 | let repo = Repo::with_revision(model_id, RepoType::Model, revision); 15 | let offline = false; 16 | 
17 | let (config_filename, tokenizer_filename, weights_filename) = if offline { 18 | let cache = Cache::default().repo(repo); 19 | ( 20 | cache 21 | .get("config.json") 22 | .ok_or(anyhow!("Missing config file in cache"))?, 23 | cache 24 | .get("tokenizer.json") 25 | .ok_or(anyhow!("Missing tokenizer file in cache"))?, 26 | cache 27 | .get("model.safetensors") 28 | .ok_or(anyhow!("Missing weights file in cache"))?, 29 | ) 30 | } else { 31 | let api = Api::new()?; 32 | let api = api.repo(repo); 33 | ( 34 | api.get("config.json")?, 35 | api.get("tokenizer.json")?, 36 | api.get("model.safetensors")?, 37 | ) 38 | }; 39 | 40 | println!("config_filename: {}", config_filename.display()); 41 | 42 | let config = std::fs::read_to_string(config_filename)?; 43 | let config: RobertaConfig = serde_json::from_str(&config)?; 44 | let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; 45 | 46 | let vb = unsafe { 47 | VarBuilder::from_mmaped_safetensors(&[weights_filename], FLOATING_DTYPE, &device)? 
48 | }; 49 | let model = RobertaModel::load(vb, &config)?; 50 | Ok((model, tokenizer)) 51 | } 52 | 53 | fn main() -> Result<()> { 54 | let (model, _tokenizer) = build_model_and_tokenizer()?; 55 | let device = &model.device; 56 | 57 | let input_ids = &[ 58 | [0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], 59 | [0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], 60 | ]; 61 | let input_ids = Tensor::new(input_ids, &device)?; 62 | 63 | let token_ids = input_ids.zeros_like()?; 64 | 65 | println!("token_ids: {:?}", token_ids.to_vec2::()?); 66 | println!("input_ids: {:?}", input_ids.to_vec2::()?); 67 | 68 | let output = model.forward(&input_ids, &token_ids)?; 69 | // let output = output.squeeze(0)?; 70 | 71 | println!("output: {:?}", output.i((.., 0))?.dims2()); 72 | 73 | let logits = &[[0.1_f32, 0.2], [0.5, 0.6]]; 74 | let logits = Tensor::new(logits, &device)?; 75 | 76 | println!("logits: {:?}", logits.i((.., 0))?.to_vec1::()?); 77 | println!("logits: {:?}", logits.i((.., 1))?.to_vec1::()?); 78 | 79 | 80 | 81 | Ok(()) 82 | } 83 | -------------------------------------------------------------------------------- /python/roberta_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEmbeddings, RobertaConfig 5 | from transformers import RobertaForSequenceClassification, RobertaForTokenClassification 6 | 7 | model = RobertaModel.from_pretrained('Davlan/xlm-roberta-base-wikiann-ner') 8 | 9 | # input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) 10 | # with torch.no_grad(): 11 | # output = model(input_ids) 12 | 13 | 14 | print(model) 15 | 16 | 17 | # # output = output.squeeze(0)[0][:10] 18 | 19 | # # print(model.embeddings(input_ids).squeeze(0)) 20 | # # 
print(model.embeddings(input_ids).shape) 21 | 22 | # a = output[0][:, 0] 23 | 24 | # print(a) 25 | # print(output[0].shape) 26 | 27 | # b = output[0][:, 0, :] 28 | 29 | # print("========================") 30 | 31 | # print(b) 32 | 33 | # # assert(== ) 34 | 35 | # print(output[0][:, 0].shape) 36 | 37 | # # print(model.embeddings.position_embeddings.weight[0]) 38 | 39 | # # print(model.embeddings.position_embeddings.weight[0]) 40 | 41 | # def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 42 | # """ 43 | # Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 44 | # are ignored. This is modified from fairseq's `utils.make_positions`. 45 | 46 | # Args: 47 | # x: torch.Tensor x: 48 | 49 | # Returns: torch.Tensor 50 | # """ 51 | # # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 52 | # mask = input_ids.ne(padding_idx).int() 53 | # incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 54 | # return incremental_indices.long() + padding_idx 55 | 56 | # # print(create_position_ids_from_input_ids(input_ids, 1)) 57 | 58 | 59 | # x = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178], 60 | # [ 0.0419, 0.0763, -1.0457, -1.6692], 61 | # [-1.0494, 0.8111, 1.5723, 1.2315], 62 | # [ 1.3081, 0.6641, 1.1802, -0.2547], 63 | # [ 0.5292, 0.7636, 0.3692, -0.8318]]) 64 | 65 | # # print(x) 66 | 67 | # y = torch.Tensor([[0., 1., 0., 0.], 68 | # [0., 1., 0., 0.], 69 | # [0., 0., 0., 1.], 70 | # [1., 0., 0., 0.], 71 | # [0., 0., 1., 0.]]) 72 | 73 | # def sigmoid(x): 74 | # return (1 + (-x).exp()).reciprocal() 75 | 76 | 77 | # def binary_cross_entropy(input, y): 78 | 79 | # print(pred.log()*y) 80 | # print((1-y)*(1-pred).log()) 81 | # return -(pred.log()*y + (1-y)*(1-pred).log()).mean() 82 | 83 | # pred = sigmoid(x) 84 | # print(pred) 85 | 86 | # print(binary_cross_entropy(pred, y)) 87 | 88 | 89 
| # import torch 90 | # import torch.nn.functional as F 91 | 92 | # inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178], 93 | # [ 0.0419, 0.0763, -1.0457, -1.6692], 94 | # [-1.0494, 0.8111, 1.5723, 1.2315], 95 | # [ 1.3081, 0.6641, 1.1802, -0.2547], 96 | # [ 0.5292, 0.7636, 0.3692, -0.8318]]) 97 | 98 | # target = torch.Tensor([[0., 1., 0., 0.], 99 | # [0., 1., 0., 0.], 100 | # [0., 0., 0., 1.], 101 | # [1., 0., 0., 0.], 102 | # [0., 0., 1., 0.]]) 103 | 104 | # print(F.binary_cross_entropy_with_logits(inp, target)) -------------------------------------------------------------------------------- /src/models/model_utils.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Result, Tensor}; 2 | use serde::Deserialize; 3 | 4 | 5 | 6 | pub fn sigmoid(xs: &Tensor) -> Result { 7 | // TODO: Should we have a specialized op for this? 8 | (xs.neg()?.exp()? + 1.0)?.recip() 9 | } 10 | 11 | pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result { 12 | let inp = sigmoid(inp)?; 13 | 14 | let left_side = target * inp.log()?; 15 | let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?; 16 | 17 | let loss = left_side? 
+ right_side?; 18 | let loss = loss?.neg()?.mean_all()?; 19 | 20 | Ok(loss) 21 | } 22 | 23 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] 24 | #[serde(rename_all = "lowercase")] 25 | pub enum HiddenAct { 26 | Gelu, 27 | Relu, 28 | Tanh 29 | } 30 | 31 | pub struct HiddenActLayer { 32 | act: HiddenAct, 33 | } 34 | 35 | impl HiddenActLayer { 36 | pub fn new(act: HiddenAct) -> Self { 37 | Self { act } 38 | } 39 | 40 | pub fn forward(&self, xs: &Tensor) -> candle_core::Result { 41 | match self.act { 42 | // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 43 | HiddenAct::Gelu => xs.gelu_erf(), 44 | HiddenAct::Relu => xs.relu(), 45 | HiddenAct::Tanh => xs.tanh() 46 | } 47 | } 48 | } 49 | 50 | #[derive(Debug)] 51 | pub struct Linear { 52 | weight: Tensor, 53 | bias: Option, 54 | } 55 | 56 | impl Linear { 57 | pub fn new(weight: Tensor, bias: Option) -> Self { 58 | Self { weight, bias } 59 | } 60 | 61 | pub fn forward(&self, x: &Tensor) -> candle_core::Result { 62 | let w = match x.dims() { 63 | &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, 64 | _ => self.weight.t()?, 65 | }; 66 | let x = x.matmul(&w)?; 67 | match &self.bias { 68 | None => Ok(x), 69 | Some(bias) => x.broadcast_add(bias), 70 | } 71 | } 72 | } 73 | 74 | pub struct Dropout { 75 | #[allow(dead_code)] 76 | pr: f64, 77 | } 78 | 79 | impl Dropout { 80 | pub fn new(pr: f64) -> Self { 81 | Self { pr } 82 | } 83 | 84 | pub fn forward(&self, x: &Tensor) -> Result { 85 | Ok(x.clone()) 86 | } 87 | } 88 | 89 | #[derive(Debug)] 90 | pub struct LayerNorm { 91 | weight: Tensor, 92 | bias: Tensor, 93 | eps: f64, 94 | } 95 | 96 | impl LayerNorm { 97 | pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { 98 | Self { weight, bias, eps } 99 | } 100 | 101 | pub fn forward(&self, x: &Tensor) -> Result { 102 | let x_dtype = x.dtype(); 103 | let internal_dtype = match x_dtype { 104 | DType::F16 | DType::BF16 => 
DType::F32, 105 | d => d, 106 | }; 107 | let (_bsize, _seq_len, hidden_size) = x.dims3()?; 108 | let x = x.to_dtype(internal_dtype)?; 109 | let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; 110 | let x = x.broadcast_sub(&mean_x)?; 111 | let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; 112 | let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; 113 | let x = x_normed 114 | .to_dtype(x_dtype)? 115 | .broadcast_mul(&self.weight)? 116 | .broadcast_add(&self.bias)?; 117 | Ok(x) 118 | } 119 | } 120 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] 121 | #[serde(rename_all = "lowercase")] 122 | pub enum PositionEmbeddingType { 123 | #[default] 124 | Absolute, 125 | } -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Error as E, Result}; 2 | use candle_core::Device; 3 | use tokenizers::Tokenizer; 4 | use candle_nn::VarBuilder; 5 | 6 | use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; 7 | 8 | 9 | use crate::models::roberta::{RobertaModel, RobertaConfig,FLOATING_DTYPE}; 10 | use crate::models::roberta::{RobertaForSequenceClassification, RobertaForTokenClassification, RobertaForQuestionAnswering}; 11 | use crate::models::xlm_roberta::XLMRobertaConfig; 12 | use crate::models::xlm_roberta::{XLMRobertaModel, XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification, XLMRobertaForQuestionAnswering}; 13 | 14 | pub enum ModelType { 15 | RobertaModel {model: RobertaModel}, 16 | RobertaForSequenceClassification {model: RobertaForSequenceClassification}, 17 | RobertaForTokenClassification {model: RobertaForTokenClassification}, 18 | RobertaForQuestionAnswering {model: RobertaForQuestionAnswering}, 19 | 20 | XLMRobertaModel {model: XLMRobertaModel}, 21 | XLMRobertaForSequenceClassification {model: XLMRobertaForSequenceClassification}, 22 | XLMRobertaForTokenClassification 
/// Rounds `n` to `places` digits after the decimal point.
///
/// The value is scaled by `10^places`, rounded to the nearest integer (ties
/// away from zero, as `f32::round` does) and scaled back.
///
/// When the requested precision is beyond what `f32` can represent — i.e.
/// `10^places` or the scaled value overflows to infinity — `n` is returned
/// unchanged instead of the NaN that `inf / inf` would produce. NaN and
/// infinite inputs are likewise returned as they are.
pub fn round_to_decimal_places(n: f32, places: u32) -> f32 {
    // `powi` takes an `i32`. Clamp before converting so that a huge `places`
    // cannot wrap around into a negative exponent; the cast is then lossless.
    let exponent = places.min(i32::MAX as u32) as i32;
    let multiplier: f32 = 10f32.powi(exponent);

    let scaled = n * multiplier;
    if !scaled.is_finite() {
        // Nothing to round at this precision (or `n` itself is NaN/inf).
        return n;
    }

    scaled.round() / multiplier
}
}; 69 | 70 | let model = match model_type { 71 | "RobertaModel" => { 72 | let config: RobertaConfig = serde_json::from_str(&config)?; 73 | let model = RobertaModel::load(vb, &config)?; 74 | ModelType::RobertaModel {model} 75 | } 76 | "RobertaForSequenceClassification" => { 77 | let config: RobertaConfig = serde_json::from_str(&config)?; 78 | let model = RobertaForSequenceClassification::load(vb, &config)?; 79 | ModelType::RobertaForSequenceClassification {model} 80 | } 81 | "RobertaForTokenClassification" => { 82 | let config: RobertaConfig = serde_json::from_str(&config)?; 83 | let model = RobertaForTokenClassification::load(vb, &config)?; 84 | ModelType::RobertaForTokenClassification {model} 85 | } 86 | "RobertaForQuestionAnswering" => { 87 | let config: RobertaConfig = serde_json::from_str(&config)?; 88 | let model = RobertaForQuestionAnswering::load(vb, &config)?; 89 | ModelType::RobertaForQuestionAnswering {model} 90 | } 91 | "XLMRobertaModel" => { 92 | let config: XLMRobertaConfig = serde_json::from_str(&config)?; 93 | println!("config: {:?}", config); 94 | let model = XLMRobertaModel::load(vb, &config)?; 95 | ModelType::XLMRobertaModel {model} 96 | } 97 | "XLMRobertaForSequenceClassification" => { 98 | let config: XLMRobertaConfig = serde_json::from_str(&config)?; 99 | let model = XLMRobertaForSequenceClassification::load(vb, &config)?; 100 | ModelType::XLMRobertaForSequenceClassification {model} 101 | } 102 | "XLMRobertaForTokenClassification" => { 103 | let config: XLMRobertaConfig = serde_json::from_str(&config)?; 104 | let model = XLMRobertaForTokenClassification::load(vb, &config)?; 105 | ModelType::XLMRobertaForTokenClassification {model} 106 | } 107 | "XLMRobertaForQuestionAnswering" => { 108 | let config: XLMRobertaConfig = serde_json::from_str(&config)?; 109 | let model = XLMRobertaForQuestionAnswering::load(vb, &config)?; 110 | ModelType::XLMRobertaForQuestionAnswering {model} 111 | } 112 | _ => panic!("Invalid model_type") 113 | }; 114 | 115 | 
Ok((model, tokenizer)) 116 | } -------------------------------------------------------------------------------- /tests/test_roberta.rs: -------------------------------------------------------------------------------- 1 | mod tests { 2 | use candle_tutorial::models::roberta::{RobertaEmbeddings,RobertaModel, RobertaConfig, create_position_ids_from_input_ids}; 3 | use candle_tutorial::models::roberta::{RobertaForSequenceClassification, RobertaForTokenClassification }; 4 | use candle_tutorial::utils::{build_roberta_model_and_tokenizer, ModelType, round_to_decimal_places}; 5 | 6 | use anyhow::Result; 7 | use candle_nn::VarBuilder; 8 | use candle_core::{DType, Device, Tensor}; 9 | 10 | // Regression_test = https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py#L496 11 | #[test] 12 | fn test_create_position_ids_from_input_embeds() -> Result<()> { 13 | 14 | let config = RobertaConfig::default(); 15 | let vb = VarBuilder::zeros(DType::F32, &Device::Cpu); 16 | let embeddings_module = RobertaEmbeddings::load(vb, &config).unwrap(); 17 | 18 | let input_embeds = Tensor::randn(0f32, 1f32, (2, 4, 30), &Device::Cpu).unwrap(); 19 | let position_ids = embeddings_module.create_position_ids_from_input_embeds(&input_embeds); 20 | 21 | let expected_tensor: &[[u32; 4]; 2] = &[ 22 | [0 + embeddings_module.padding_idx + 1, 1 + embeddings_module.padding_idx + 1, 2 + embeddings_module.padding_idx + 1, 3 + embeddings_module.padding_idx + 1,], 23 | [0 + embeddings_module.padding_idx + 1, 1 + embeddings_module.padding_idx + 1, 2 + embeddings_module.padding_idx + 1, 3 + embeddings_module.padding_idx + 1,] 24 | ]; 25 | 26 | assert_eq!(position_ids.unwrap().to_vec2::()?, expected_tensor); 27 | 28 | Ok(()) 29 | 30 | } 31 | 32 | #[test] 33 | fn test_create_position_ids_from_input_ids() -> Result<()> { 34 | 35 | let config = RobertaConfig::default(); 36 | 37 | let vb = VarBuilder::zeros(DType::F32, 
&Device::Cpu); 38 | let embeddings_module = RobertaEmbeddings::load(vb, &config).unwrap(); 39 | 40 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 41 | let input_ids = Tensor::new(input_ids, &Device::Cpu)?; 42 | 43 | let position_ids = create_position_ids_from_input_ids(&input_ids, embeddings_module.padding_idx, 1)?; 44 | 45 | let expected_tensor = &[[2u8, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]; 46 | 47 | assert_eq!(position_ids.to_vec2::()?, expected_tensor); 48 | 49 | 50 | 51 | Ok(()) 52 | 53 | 54 | } 55 | 56 | // https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/tests/models/roberta/test_modeling_roberta.py#L548 57 | #[test] 58 | fn test_modeling_roberta_base () -> Result<()> { 59 | let model_type = "RobertaModel"; 60 | let (model, _tokenizer) = build_roberta_model_and_tokenizer("roberta-base", false, model_type).unwrap(); 61 | 62 | let model: RobertaModel = match model { 63 | ModelType::RobertaModel {model} => model, 64 | _ => panic!("Invalid model_type") 65 | }; 66 | 67 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 68 | let input_ids = Tensor::new(input_ids, &model.device).unwrap(); 69 | 70 | let token_ids = input_ids.zeros_like().unwrap(); 71 | let output = model.forward(&input_ids, &token_ids)?; 72 | 73 | let expected_shape = [1, 11, 768]; 74 | 75 | assert_eq!(output.shape().dims(), &expected_shape); 76 | 77 | let output = output.squeeze(0)?; 78 | let output = output.to_vec2::()?; 79 | let output: Vec> = output.iter().take(3).map(|nested_vec| nested_vec.iter().take(3).map(|&x| round_to_decimal_places(x, 4)).collect()).collect(); 80 | 81 | let expected_output = [[-0.0231, 0.0782, 0.0074], [-0.1854, 0.0540, -0.0175], [0.0548, 0.0799, 0.1687]]; 82 | 83 | assert_eq!(output, expected_output); 84 | 85 | Ok(()) 86 | 87 | } 88 | 89 | 90 | // 
https://github.com/huggingface/transformers/blob/46092f763d26eb938a937c2a9cc69ce1cb6c44c2/tests/models/roberta/test_modeling_roberta.py#L567 91 | #[test] 92 | fn test_roberta_sequence_classification() -> Result<()> { 93 | 94 | let model_type = "RobertaForSequenceClassification"; 95 | let (model, _tokenizer) = build_roberta_model_and_tokenizer("roberta-large-mnli", false, model_type).unwrap(); 96 | 97 | let model: RobertaForSequenceClassification = match model { 98 | ModelType::RobertaForSequenceClassification {model} => model, 99 | _ => panic!("Invalid model_type") 100 | }; 101 | 102 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 103 | let input_ids = Tensor::new(input_ids, &model.device).unwrap(); 104 | 105 | let token_ids = input_ids.zeros_like().unwrap(); 106 | let output = model.forward(&input_ids, &token_ids, None)?; 107 | 108 | let expected_shape = [1, 3]; 109 | let expected_output = [[-0.9469, 0.3913, 0.5118]]; 110 | 111 | 112 | assert_eq!(output.logits.shape().dims(), &expected_shape); 113 | 114 | let output = output.logits.to_vec2::()?; 115 | let output: Vec> = output.iter().take(3).map(|nested_vec| nested_vec.iter().take(3).map(|&x| round_to_decimal_places(x, 4)).collect()).collect(); 116 | 117 | assert_eq!(output, expected_output); 118 | 119 | Ok(()) 120 | 121 | } 122 | 123 | #[test] 124 | fn test_roberta_token_classification() -> Result<()> { 125 | 126 | let model_type = "RobertaForTokenClassification"; 127 | let (model, _tokenizer) = build_roberta_model_and_tokenizer("Davlan/xlm-roberta-base-wikiann-ner", false, model_type).unwrap(); 128 | 129 | let model: RobertaForTokenClassification = match model { 130 | ModelType::RobertaForTokenClassification {model} => model, 131 | _ => panic!("Invalid model_type") 132 | }; 133 | 134 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 135 | let input_ids = Tensor::new(input_ids, &model.device).unwrap(); 136 | 137 | let token_ids = 
input_ids.zeros_like().unwrap(); 138 | let output = model.forward(&input_ids, &token_ids, None)?; 139 | 140 | println!("Output: {:?}",candle_nn::ops::softmax(&output.logits, candle_core::D::Minus1)?.to_vec3::()?); 141 | 142 | println!("Output: {:?}", output.logits.to_vec3::()?); 143 | 144 | Ok(()) 145 | 146 | } 147 | 148 | #[test] 149 | fn test_roberta_question_answering() -> Result<()> { 150 | 151 | let model_type = "RobertaForTokenClassification"; 152 | let (model, _tokenizer) = build_roberta_model_and_tokenizer("deepset/roberta-base-squad2", false, model_type).unwrap(); 153 | 154 | let model: RobertaForTokenClassification = match model { 155 | ModelType::RobertaForTokenClassification {model} => model, 156 | _ => panic!("Invalid model_type") 157 | }; 158 | 159 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 160 | let input_ids = Tensor::new(input_ids, &model.device).unwrap(); 161 | 162 | let token_ids = input_ids.zeros_like().unwrap(); 163 | let output = model.forward(&input_ids, &token_ids, None)?; 164 | 165 | println!("Output: {:?}",candle_nn::ops::softmax(&output.logits, candle_core::D::Minus1)?.to_vec3::()?); 166 | 167 | println!("Output: {:?}", output.logits.to_vec3::()?); 168 | 169 | Ok(()) 170 | 171 | } 172 | 173 | 174 | } -------------------------------------------------------------------------------- /examples/roberta_tutorial.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use candle_core::{DType, Device, IndexOp, Result, Tensor}; 4 | use candle_nn::{Embedding, Module, VarBuilder}; 5 | 6 | use serde::Deserialize; 7 | 8 | pub const FLOATING_DTYPE: DType = DType::F32; 9 | pub const LONG_DTYPE: DType = DType::I64; 10 | 11 | pub fn sigmoid(xs: &Tensor) -> Result { 12 | // TODO: Should we have a specialized op for this? 13 | (xs.neg()?.exp()? 
+ 1.0)?.recip() 14 | } 15 | 16 | pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result { 17 | let inp = sigmoid(inp)?; 18 | 19 | let left_side = target * inp.log()?; 20 | let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?; 21 | 22 | let loss = left_side? + right_side?; 23 | let loss = loss?.neg()?.mean_all()?; 24 | 25 | Ok(loss) 26 | } 27 | 28 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] 29 | #[serde(rename_all = "lowercase")] 30 | enum HiddenAct { 31 | Gelu, 32 | Relu, 33 | Tanh 34 | } 35 | 36 | struct HiddenActLayer { 37 | act: HiddenAct, 38 | } 39 | 40 | impl HiddenActLayer { 41 | fn new(act: HiddenAct) -> Self { 42 | Self { act } 43 | } 44 | 45 | fn forward(&self, xs: &Tensor) -> candle_core::Result { 46 | match self.act { 47 | // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 48 | HiddenAct::Gelu => xs.gelu_erf(), 49 | HiddenAct::Relu => xs.relu(), 50 | HiddenAct::Tanh => xs.tanh() 51 | } 52 | } 53 | } 54 | 55 | #[derive(Debug)] 56 | pub struct Linear { 57 | weight: Tensor, 58 | bias: Option, 59 | } 60 | 61 | impl Linear { 62 | pub fn new(weight: Tensor, bias: Option) -> Self { 63 | Self { weight, bias } 64 | } 65 | 66 | pub fn forward(&self, x: &Tensor) -> candle_core::Result { 67 | let w = match x.dims() { 68 | &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, 69 | _ => self.weight.t()?, 70 | }; 71 | let x = x.matmul(&w)?; 72 | match &self.bias { 73 | None => Ok(x), 74 | Some(bias) => x.broadcast_add(bias), 75 | } 76 | } 77 | } 78 | 79 | struct Dropout { 80 | #[allow(dead_code)] 81 | pr: f64, 82 | } 83 | 84 | impl Dropout { 85 | fn new(pr: f64) -> Self { 86 | Self { pr } 87 | } 88 | 89 | fn forward(&self, x: &Tensor) -> Result { 90 | Ok(x.clone()) 91 | } 92 | } 93 | 94 | #[derive(Debug)] 95 | pub struct LayerNorm { 96 | weight: Tensor, 97 | bias: Tensor, 98 | eps: f64, 99 | } 100 | 101 | impl LayerNorm { 102 
| pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { 103 | Self { weight, bias, eps } 104 | } 105 | 106 | pub fn forward(&self, x: &Tensor) -> Result { 107 | let x_dtype = x.dtype(); 108 | let internal_dtype = match x_dtype { 109 | DType::F16 | DType::BF16 => DType::F32, 110 | d => d, 111 | }; 112 | let (_bsize, _seq_len, hidden_size) = x.dims3()?; 113 | let x = x.to_dtype(internal_dtype)?; 114 | let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; 115 | let x = x.broadcast_sub(&mean_x)?; 116 | let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; 117 | let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; 118 | let x = x_normed 119 | .to_dtype(x_dtype)? 120 | .broadcast_mul(&self.weight)? 121 | .broadcast_add(&self.bias)?; 122 | Ok(x) 123 | } 124 | } 125 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] 126 | #[serde(rename_all = "lowercase")] 127 | enum PositionEmbeddingType { 128 | #[default] 129 | Absolute, 130 | } 131 | 132 | #[derive(Debug, Clone, PartialEq, Deserialize)] 133 | pub struct RobertaConfig { 134 | vocab_size: usize, 135 | hidden_size: usize, 136 | num_hidden_layers: usize, 137 | num_attention_heads: usize, 138 | intermediate_size: usize, 139 | hidden_act: HiddenAct, 140 | hidden_dropout_prob: f64, 141 | max_position_embeddings: usize, 142 | type_vocab_size: usize, 143 | initializer_range: f64, 144 | layer_norm_eps: f64, 145 | pad_token_id: usize, 146 | bos_token_id: usize, 147 | eos_token_id: usize, 148 | #[serde(default)] 149 | position_embedding_type: PositionEmbeddingType, 150 | #[serde(default)] 151 | use_cache: bool, 152 | classifier_dropout: Option, 153 | model_type: Option, 154 | problem_type: Option, 155 | _num_labels: Option, 156 | id2label: Option>, 157 | label2id: Option> 158 | } 159 | 160 | impl Default for RobertaConfig { 161 | fn default() -> Self { 162 | Self { 163 | vocab_size: 50265, 164 | hidden_size: 768, 165 | num_hidden_layers: 12, 166 | num_attention_heads: 12, 167 | 
intermediate_size: 3072, 168 | hidden_act: HiddenAct::Gelu, 169 | hidden_dropout_prob: 0.1, 170 | max_position_embeddings: 512, 171 | type_vocab_size: 2, 172 | initializer_range: 0.02, 173 | layer_norm_eps: 1e-12, 174 | pad_token_id: 1, 175 | bos_token_id: 0, 176 | eos_token_id: 2, 177 | position_embedding_type: PositionEmbeddingType::Absolute, 178 | use_cache: true, 179 | classifier_dropout: None, 180 | model_type: Some("roberta".to_string()), 181 | problem_type: None, 182 | _num_labels: Some(3), 183 | id2label: None, 184 | label2id: None 185 | } 186 | } 187 | } 188 | 189 | fn cumsum_2d(mask: &Tensor, dim: u8, device: &Device) -> Result { 190 | let mask = mask.to_vec2::()?; 191 | 192 | let rows = mask.len(); 193 | let cols = mask[0].len(); 194 | 195 | let mut result = mask.clone(); 196 | 197 | match dim { 198 | 0 => { 199 | // Cumulative sum along rows 200 | for i in 0..rows { 201 | for j in 1..cols { 202 | result[i][j] += result[i][j - 1]; 203 | } 204 | } 205 | } 206 | 1 => { 207 | // Cumulative sum along columns 208 | for j in 0..cols { 209 | for i in 1..rows { 210 | result[i][j] += result[i - 1][j]; 211 | } 212 | } 213 | } 214 | _ => panic!("Dimension not supported"), 215 | } 216 | 217 | let result = Tensor::new(result, &device)?; 218 | 219 | Ok(result) 220 | } 221 | 222 | pub fn create_position_ids_from_input_ids( 223 | input_ids: &Tensor, 224 | padding_idx: u32, 225 | past_key_values_length: u8, 226 | ) -> Result { 227 | let mask = input_ids.ne(padding_idx)?; 228 | let incremental_indices = cumsum_2d(&mask, 0, input_ids.device())?; 229 | 230 | let incremental_indices = incremental_indices 231 | .broadcast_add(&Tensor::new(&[past_key_values_length], input_ids.device())?)?; 232 | 233 | Ok(incremental_indices) 234 | } 235 | 236 | fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { 237 | let embeddings = vb.get((vocab_size, hidden_size), "weight")?; 238 | Ok(Embedding::new(embeddings, hidden_size)) 239 | } 240 | 241 | fn linear(size1: 
usize, size2: usize, vb: VarBuilder) -> Result { 242 | let weight = vb.get((size2, size1), "weight")?; 243 | let bias = vb.get(size2, "bias")?; 244 | Ok(Linear::new(weight, Some(bias))) 245 | } 246 | 247 | fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { 248 | let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { 249 | (Ok(weight), Ok(bias)) => (weight, bias), 250 | (Err(err), _) | (_, Err(err)) => { 251 | if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { 252 | (weight, bias) 253 | } else { 254 | return Err(err); 255 | } 256 | } 257 | }; 258 | Ok(LayerNorm::new(weight, bias, eps)) 259 | } 260 | 261 | pub struct RobertaEmbeddings { 262 | word_embeddings: Embedding, 263 | position_embeddings: Option, 264 | token_type_embeddings: Embedding, 265 | layer_norm: LayerNorm, 266 | dropout: Dropout, 267 | pub padding_idx: u32, 268 | } 269 | 270 | impl RobertaEmbeddings { 271 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 272 | let word_embeddings = embedding( 273 | config.vocab_size, 274 | config.hidden_size, 275 | vb.pp("word_embeddings"), 276 | )?; 277 | let position_embeddings = embedding( 278 | config.max_position_embeddings, 279 | config.hidden_size, 280 | vb.pp("position_embeddings"), 281 | )?; 282 | let token_type_embeddings = embedding( 283 | config.type_vocab_size, 284 | config.hidden_size, 285 | vb.pp("token_type_embeddings"), 286 | )?; 287 | let layer_norm = layer_norm( 288 | config.hidden_size, 289 | config.layer_norm_eps, 290 | vb.pp("LayerNorm"), 291 | )?; 292 | let padding_idx = config.pad_token_id as u32; 293 | 294 | Ok(Self { 295 | word_embeddings, 296 | position_embeddings: Some(position_embeddings), 297 | token_type_embeddings, 298 | layer_norm, 299 | dropout: Dropout::new(config.hidden_dropout_prob), 300 | padding_idx, 301 | }) 302 | } 303 | 304 | pub fn forward( 305 | &self, 306 | input_ids: &Tensor, 307 | token_type_ids: &Tensor, 308 | position_ids: Option<&Tensor>, 
309 | inputs_embeds: Option<&Tensor>, 310 | ) -> Result { 311 | let position_ids = match position_ids { 312 | Some(ids) => ids.to_owned(), 313 | None => { 314 | if Option::is_some(&inputs_embeds) { 315 | let position_ids = 316 | self.create_position_ids_from_input_embeds(inputs_embeds.unwrap())?; 317 | position_ids 318 | } else { 319 | let position_ids = 320 | create_position_ids_from_input_ids(input_ids, self.padding_idx, 1)?; 321 | position_ids 322 | } 323 | } 324 | }; 325 | 326 | let inputs_embeds: Tensor = match inputs_embeds { 327 | Some(embeds) => embeds.to_owned(), 328 | None => { 329 | let embeds = self.word_embeddings.forward(input_ids)?; 330 | embeds 331 | } 332 | }; 333 | 334 | let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; 335 | let mut embeddings = (inputs_embeds + token_type_embeddings)?; 336 | 337 | if let Some(position_embeddings) = &self.position_embeddings { 338 | embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? 339 | } 340 | 341 | let embeddings = self.layer_norm.forward(&embeddings)?; 342 | let embeddings = self.dropout.forward(&embeddings)?; 343 | 344 | Ok(embeddings) 345 | } 346 | 347 | pub fn create_position_ids_from_input_embeds(&self, input_embeds: &Tensor) -> Result { 348 | let input_shape = input_embeds.dims3()?; 349 | let seq_length = input_shape.1; 350 | 351 | println!("seq_length: {:?}", seq_length); 352 | let mut position_ids = Tensor::arange( 353 | self.padding_idx + 1, 354 | seq_length as u32 + self.padding_idx + 1, 355 | &Device::Cpu, 356 | )?; 357 | 358 | println!("position_ids: {:?}", position_ids); 359 | 360 | position_ids = position_ids 361 | .unsqueeze(0)? 
362 | .expand((input_shape.0, input_shape.1))?; 363 | Ok(position_ids) 364 | } 365 | } 366 | 367 | struct RobertaSelfAttention { 368 | query: Linear, 369 | key: Linear, 370 | value: Linear, 371 | dropout: Dropout, 372 | num_attention_heads: usize, 373 | attention_head_size: usize, 374 | } 375 | 376 | impl RobertaSelfAttention { 377 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 378 | let attention_head_size = config.hidden_size / config.num_attention_heads; 379 | let all_head_size = config.num_attention_heads * attention_head_size; 380 | let dropout = Dropout::new(config.hidden_dropout_prob); 381 | let hidden_size = config.hidden_size; 382 | let query = linear(hidden_size, all_head_size, vb.pp("query"))?; 383 | let value = linear(hidden_size, all_head_size, vb.pp("value"))?; 384 | let key = linear(hidden_size, all_head_size, vb.pp("key"))?; 385 | Ok(Self { 386 | query, 387 | key, 388 | value, 389 | dropout, 390 | num_attention_heads: config.num_attention_heads, 391 | attention_head_size, 392 | }) 393 | } 394 | 395 | fn transpose_for_scores(&self, xs: &Tensor) -> Result { 396 | let mut new_x_shape = xs.dims().to_vec(); 397 | new_x_shape.pop(); 398 | new_x_shape.push(self.num_attention_heads); 399 | new_x_shape.push(self.attention_head_size); 400 | let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; 401 | xs.contiguous() 402 | } 403 | 404 | fn forward(&self, hidden_states: &Tensor) -> Result { 405 | let query_layer = self.query.forward(hidden_states)?; 406 | let key_layer = self.key.forward(hidden_states)?; 407 | let value_layer = self.value.forward(hidden_states)?; 408 | 409 | let query_layer = self.transpose_for_scores(&query_layer)?; 410 | let key_layer = self.transpose_for_scores(&key_layer)?; 411 | let value_layer = self.transpose_for_scores(&value_layer)?; 412 | 413 | let attention_scores = query_layer.matmul(&key_layer.t()?)?; 414 | let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; 415 | let 
attention_probs = 416 | { candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)? }; 417 | let attention_probs = self.dropout.forward(&attention_probs)?; 418 | 419 | let context_layer = attention_probs.matmul(&value_layer)?; 420 | let context_layer = context_layer.transpose(1, 2)?.contiguous()?; 421 | let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?; 422 | Ok(context_layer) 423 | } 424 | } 425 | 426 | struct RobertaSelfOutput { 427 | dense: Linear, 428 | layer_norm: LayerNorm, 429 | dropout: Dropout, 430 | } 431 | 432 | impl RobertaSelfOutput { 433 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 434 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 435 | let layer_norm = layer_norm( 436 | config.hidden_size, 437 | config.layer_norm_eps, 438 | vb.pp("LayerNorm"), 439 | )?; 440 | let dropout = Dropout::new(config.hidden_dropout_prob); 441 | Ok(Self { 442 | dense, 443 | layer_norm, 444 | dropout, 445 | }) 446 | } 447 | 448 | fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { 449 | let hidden_states = self.dense.forward(hidden_states)?; 450 | let hidden_states = self.dropout.forward(&hidden_states)?; 451 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 
452 | } 453 | } 454 | 455 | struct RobertaAttention { 456 | self_attention: RobertaSelfAttention, 457 | self_output: RobertaSelfOutput, 458 | } 459 | 460 | impl RobertaAttention { 461 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 462 | let self_attention = RobertaSelfAttention::load(vb.pp("self"), config)?; 463 | let self_output = RobertaSelfOutput::load(vb.pp("output"), config)?; 464 | Ok(Self { 465 | self_attention, 466 | self_output, 467 | }) 468 | } 469 | 470 | fn forward(&self, hidden_states: &Tensor) -> Result { 471 | let self_outputs = self.self_attention.forward(hidden_states)?; 472 | let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; 473 | Ok(attention_output) 474 | } 475 | } 476 | 477 | struct RobertaIntermediate { 478 | dense: Linear, 479 | intermediate_act: HiddenActLayer, 480 | } 481 | 482 | impl RobertaIntermediate { 483 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 484 | let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; 485 | Ok(Self { 486 | dense, 487 | intermediate_act: HiddenActLayer::new(config.hidden_act), 488 | }) 489 | } 490 | 491 | fn forward(&self, hidden_states: &Tensor) -> Result { 492 | let hidden_states = self.dense.forward(hidden_states)?; 493 | let ys = self.intermediate_act.forward(&hidden_states)?; 494 | Ok(ys) 495 | } 496 | } 497 | 498 | struct RobertaOutput { 499 | dense: Linear, 500 | layer_norm: LayerNorm, 501 | dropout: Dropout, 502 | } 503 | 504 | impl RobertaOutput { 505 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 506 | let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; 507 | let layer_norm = layer_norm( 508 | config.hidden_size, 509 | config.layer_norm_eps, 510 | vb.pp("LayerNorm"), 511 | )?; 512 | let dropout = Dropout::new(config.hidden_dropout_prob); 513 | Ok(Self { 514 | dense, 515 | layer_norm, 516 | dropout, 517 | }) 518 | } 519 | 520 | fn forward(&self, hidden_states: &Tensor, 
input_tensor: &Tensor) -> Result { 521 | let hidden_states = self.dense.forward(hidden_states)?; 522 | let hidden_states = self.dropout.forward(&hidden_states)?; 523 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 524 | } 525 | } 526 | 527 | struct RobertaLayer { 528 | attention: RobertaAttention, 529 | intermediate: RobertaIntermediate, 530 | output: RobertaOutput, 531 | } 532 | 533 | impl RobertaLayer { 534 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 535 | let attention = RobertaAttention::load(vb.pp("attention"), config)?; 536 | let intermediate = RobertaIntermediate::load(vb.pp("intermediate"), config)?; 537 | let output = RobertaOutput::load(vb.pp("output"), config)?; 538 | Ok(Self { 539 | attention, 540 | intermediate, 541 | output, 542 | }) 543 | } 544 | 545 | fn forward(&self, hidden_states: &Tensor) -> Result { 546 | let attention_output = self.attention.forward(hidden_states)?; 547 | 548 | let intermediate_output = self.intermediate.forward(&attention_output)?; 549 | let layer_output = self 550 | .output 551 | .forward(&intermediate_output, &attention_output)?; 552 | Ok(layer_output) 553 | } 554 | } 555 | 556 | struct RobertaEncoder { 557 | layers: Vec, 558 | } 559 | 560 | impl RobertaEncoder { 561 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 562 | let layers = (0..config.num_hidden_layers) 563 | .map(|index| RobertaLayer::load(vb.pp(&format!("layer.{index}")), config)) 564 | .collect::>>()?; 565 | Ok(RobertaEncoder { layers }) 566 | } 567 | 568 | fn forward(&self, hidden_states: &Tensor) -> Result { 569 | let mut hidden_states = hidden_states.clone(); 570 | for layer in self.layers.iter() { 571 | hidden_states = layer.forward(&hidden_states)? 
572 | } 573 | Ok(hidden_states) 574 | } 575 | } 576 | 577 | pub struct RobertaPooler{ 578 | dense: Linear, 579 | activation: HiddenActLayer, 580 | } 581 | 582 | impl RobertaPooler{ 583 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 584 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 585 | Ok( Self { 586 | dense, 587 | activation: HiddenActLayer::new(HiddenAct::Tanh), 588 | }) 589 | 590 | } 591 | 592 | pub fn forward(&self, hidden_states: &Tensor) -> Result { 593 | // We "pool" the model by simply taking the hidden state corresponding 594 | // to the first token. 595 | 596 | let first_token_sensor = hidden_states.i((.., 0))?; 597 | let pooled_output = self.dense.forward(&first_token_sensor)?; 598 | let pooled_output = self.activation.forward(&pooled_output)?; 599 | 600 | Ok(pooled_output) 601 | } 602 | } 603 | 604 | pub struct RobertaModel { 605 | embeddings: RobertaEmbeddings, 606 | encoder: RobertaEncoder, 607 | pub device: Device, 608 | } 609 | 610 | impl RobertaModel { 611 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 612 | let (embeddings, encoder) = match ( 613 | RobertaEmbeddings::load(vb.pp("embeddings"), config), 614 | RobertaEncoder::load(vb.pp("encoder"), config), 615 | ) { 616 | (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), 617 | (Err(err), _) | (_, Err(err)) => { 618 | if let Some(model_type) = &config.model_type { 619 | if let (Ok(embeddings), Ok(encoder)) = ( 620 | RobertaEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), 621 | RobertaEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), 622 | ) { 623 | (embeddings, encoder) 624 | } else { 625 | return Err(err); 626 | } 627 | } else { 628 | return Err(err); 629 | } 630 | } 631 | }; 632 | Ok(Self { 633 | embeddings, 634 | encoder, 635 | device: vb.device().clone(), 636 | }) 637 | } 638 | 639 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { 640 | let 
embedding_output = self 641 | .embeddings 642 | .forward(input_ids, token_type_ids, None, None)?; 643 | let sequence_output = self.encoder.forward(&embedding_output)?; 644 | Ok(sequence_output) 645 | } 646 | } 647 | 648 | fn main(){ 649 | 650 | } 651 | -------------------------------------------------------------------------------- /src/models/roberta.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use candle_core::{DType, Device, IndexOp, Result, Tensor}; 4 | use candle_nn::{Embedding, Module, VarBuilder}; 5 | 6 | use crate::models::modelling_outputs::{SequenceClassifierOutput, TokenClassifierOutput, QuestionAnsweringModelOutput}; 7 | use crate::models::model_utils::{Dropout, HiddenAct, Linear, HiddenActLayer, LayerNorm, PositionEmbeddingType}; 8 | use crate::models::model_utils::binary_cross_entropy_with_logit; 9 | use serde::Deserialize; 10 | 11 | pub const FLOATING_DTYPE: DType = DType::F32; 12 | pub const LONG_DTYPE: DType = DType::I64; 13 | 14 | #[derive(Debug, Clone, PartialEq, Deserialize)] 15 | pub struct RobertaConfig { 16 | vocab_size: usize, 17 | hidden_size: usize, 18 | num_hidden_layers: usize, 19 | num_attention_heads: usize, 20 | intermediate_size: usize, 21 | hidden_act: HiddenAct, 22 | hidden_dropout_prob: f64, 23 | max_position_embeddings: usize, 24 | type_vocab_size: usize, 25 | initializer_range: f64, 26 | layer_norm_eps: f64, 27 | pad_token_id: usize, 28 | bos_token_id: usize, 29 | eos_token_id: usize, 30 | #[serde(default)] 31 | position_embedding_type: PositionEmbeddingType, 32 | #[serde(default)] 33 | use_cache: bool, 34 | classifier_dropout: Option, 35 | model_type: Option, 36 | problem_type: Option, 37 | _num_labels: Option, 38 | id2label: Option>, 39 | label2id: Option> 40 | } 41 | 42 | impl Default for RobertaConfig { 43 | fn default() -> Self { 44 | Self { 45 | vocab_size: 50265, 46 | hidden_size: 768, 47 | num_hidden_layers: 12, 48 | num_attention_heads: 
12, 49 | intermediate_size: 3072, 50 | hidden_act: HiddenAct::Gelu, 51 | hidden_dropout_prob: 0.1, 52 | max_position_embeddings: 512, 53 | type_vocab_size: 2, 54 | initializer_range: 0.02, 55 | layer_norm_eps: 1e-12, 56 | pad_token_id: 1, 57 | bos_token_id: 0, 58 | eos_token_id: 2, 59 | position_embedding_type: PositionEmbeddingType::Absolute, 60 | use_cache: true, 61 | classifier_dropout: None, 62 | model_type: Some("roberta".to_string()), 63 | problem_type: None, 64 | _num_labels: Some(3), 65 | id2label: None, 66 | label2id: None 67 | } 68 | } 69 | } 70 | 71 | fn cumsum_2d(mask: &Tensor, dim: u8, device: &Device) -> Result { 72 | let mask = mask.to_vec2::()?; 73 | 74 | let rows = mask.len(); 75 | let cols = mask[0].len(); 76 | 77 | let mut result = mask.clone(); 78 | 79 | match dim { 80 | 0 => { 81 | // Cumulative sum along rows 82 | for i in 0..rows { 83 | for j in 1..cols { 84 | result[i][j] += result[i][j - 1]; 85 | } 86 | } 87 | } 88 | 1 => { 89 | // Cumulative sum along columns 90 | for j in 0..cols { 91 | for i in 1..rows { 92 | result[i][j] += result[i - 1][j]; 93 | } 94 | } 95 | } 96 | _ => panic!("Dimension not supported"), 97 | } 98 | 99 | let result = Tensor::new(result, &device)?; 100 | 101 | Ok(result) 102 | } 103 | 104 | pub fn create_position_ids_from_input_ids( 105 | input_ids: &Tensor, 106 | padding_idx: u32, 107 | past_key_values_length: u8, 108 | ) -> Result { 109 | let mask = input_ids.ne(padding_idx)?; 110 | let incremental_indices = cumsum_2d(&mask, 0, input_ids.device())?; 111 | 112 | let incremental_indices = incremental_indices 113 | .broadcast_add(&Tensor::new(&[past_key_values_length], input_ids.device())?)?; 114 | 115 | Ok(incremental_indices) 116 | } 117 | 118 | fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { 119 | let embeddings = vb.get((vocab_size, hidden_size), "weight")?; 120 | Ok(Embedding::new(embeddings, hidden_size)) 121 | } 122 | 123 | fn linear(size1: usize, size2: usize, vb: VarBuilder) -> 
Result { 124 | let weight = vb.get((size2, size1), "weight")?; 125 | let bias = vb.get(size2, "bias")?; 126 | Ok(Linear::new(weight, Some(bias))) 127 | } 128 | 129 | fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { 130 | let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { 131 | (Ok(weight), Ok(bias)) => (weight, bias), 132 | (Err(err), _) | (_, Err(err)) => { 133 | if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { 134 | (weight, bias) 135 | } else { 136 | return Err(err); 137 | } 138 | } 139 | }; 140 | Ok(LayerNorm::new(weight, bias, eps)) 141 | } 142 | 143 | pub struct RobertaEmbeddings { 144 | word_embeddings: Embedding, 145 | position_embeddings: Option, 146 | token_type_embeddings: Embedding, 147 | layer_norm: LayerNorm, 148 | dropout: Dropout, 149 | pub padding_idx: u32, 150 | } 151 | 152 | impl RobertaEmbeddings { 153 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 154 | let word_embeddings = embedding( 155 | config.vocab_size, 156 | config.hidden_size, 157 | vb.pp("word_embeddings"), 158 | )?; 159 | let position_embeddings = embedding( 160 | config.max_position_embeddings, 161 | config.hidden_size, 162 | vb.pp("position_embeddings"), 163 | )?; 164 | let token_type_embeddings = embedding( 165 | config.type_vocab_size, 166 | config.hidden_size, 167 | vb.pp("token_type_embeddings"), 168 | )?; 169 | let layer_norm = layer_norm( 170 | config.hidden_size, 171 | config.layer_norm_eps, 172 | vb.pp("LayerNorm"), 173 | )?; 174 | let padding_idx = config.pad_token_id as u32; 175 | 176 | Ok(Self { 177 | word_embeddings, 178 | position_embeddings: Some(position_embeddings), 179 | token_type_embeddings, 180 | layer_norm, 181 | dropout: Dropout::new(config.hidden_dropout_prob), 182 | padding_idx, 183 | }) 184 | } 185 | 186 | pub fn forward( 187 | &self, 188 | input_ids: &Tensor, 189 | token_type_ids: &Tensor, 190 | position_ids: Option<&Tensor>, 191 | inputs_embeds: Option<&Tensor>, 192 
| ) -> Result { 193 | let position_ids = match position_ids { 194 | Some(ids) => ids.to_owned(), 195 | None => { 196 | if Option::is_some(&inputs_embeds) { 197 | let position_ids = 198 | self.create_position_ids_from_input_embeds(inputs_embeds.unwrap())?; 199 | position_ids 200 | } else { 201 | let position_ids = 202 | create_position_ids_from_input_ids(input_ids, self.padding_idx, 1)?; 203 | position_ids 204 | } 205 | } 206 | }; 207 | 208 | let inputs_embeds: Tensor = match inputs_embeds { 209 | Some(embeds) => embeds.to_owned(), 210 | None => { 211 | let embeds = self.word_embeddings.forward(input_ids)?; 212 | embeds 213 | } 214 | }; 215 | 216 | let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; 217 | let mut embeddings = (inputs_embeds + token_type_embeddings)?; 218 | 219 | if let Some(position_embeddings) = &self.position_embeddings { 220 | embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? 221 | } 222 | 223 | let embeddings = self.layer_norm.forward(&embeddings)?; 224 | let embeddings = self.dropout.forward(&embeddings)?; 225 | 226 | Ok(embeddings) 227 | } 228 | 229 | pub fn create_position_ids_from_input_embeds(&self, input_embeds: &Tensor) -> Result { 230 | let input_shape = input_embeds.dims3()?; 231 | let seq_length = input_shape.1; 232 | 233 | println!("seq_length: {:?}", seq_length); 234 | let mut position_ids = Tensor::arange( 235 | self.padding_idx + 1, 236 | seq_length as u32 + self.padding_idx + 1, 237 | &Device::Cpu, 238 | )?; 239 | 240 | println!("position_ids: {:?}", position_ids); 241 | 242 | position_ids = position_ids 243 | .unsqueeze(0)? 
244 | .expand((input_shape.0, input_shape.1))?; 245 | Ok(position_ids) 246 | } 247 | } 248 | 249 | struct RobertaSelfAttention { 250 | query: Linear, 251 | key: Linear, 252 | value: Linear, 253 | dropout: Dropout, 254 | num_attention_heads: usize, 255 | attention_head_size: usize, 256 | } 257 | 258 | impl RobertaSelfAttention { 259 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 260 | let attention_head_size = config.hidden_size / config.num_attention_heads; 261 | let all_head_size = config.num_attention_heads * attention_head_size; 262 | let dropout = Dropout::new(config.hidden_dropout_prob); 263 | let hidden_size = config.hidden_size; 264 | let query = linear(hidden_size, all_head_size, vb.pp("query"))?; 265 | let value = linear(hidden_size, all_head_size, vb.pp("value"))?; 266 | let key = linear(hidden_size, all_head_size, vb.pp("key"))?; 267 | Ok(Self { 268 | query, 269 | key, 270 | value, 271 | dropout, 272 | num_attention_heads: config.num_attention_heads, 273 | attention_head_size, 274 | }) 275 | } 276 | 277 | fn transpose_for_scores(&self, xs: &Tensor) -> Result { 278 | let mut new_x_shape = xs.dims().to_vec(); 279 | new_x_shape.pop(); 280 | new_x_shape.push(self.num_attention_heads); 281 | new_x_shape.push(self.attention_head_size); 282 | let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; 283 | xs.contiguous() 284 | } 285 | 286 | fn forward(&self, hidden_states: &Tensor) -> Result { 287 | let query_layer = self.query.forward(hidden_states)?; 288 | let key_layer = self.key.forward(hidden_states)?; 289 | let value_layer = self.value.forward(hidden_states)?; 290 | 291 | let query_layer = self.transpose_for_scores(&query_layer)?; 292 | let key_layer = self.transpose_for_scores(&key_layer)?; 293 | let value_layer = self.transpose_for_scores(&value_layer)?; 294 | 295 | let attention_scores = query_layer.matmul(&key_layer.t()?)?; 296 | let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; 297 | let 
attention_probs = 298 | { candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)? }; 299 | let attention_probs = self.dropout.forward(&attention_probs)?; 300 | 301 | let context_layer = attention_probs.matmul(&value_layer)?; 302 | let context_layer = context_layer.transpose(1, 2)?.contiguous()?; 303 | let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?; 304 | Ok(context_layer) 305 | } 306 | } 307 | 308 | struct RobertaSelfOutput { 309 | dense: Linear, 310 | layer_norm: LayerNorm, 311 | dropout: Dropout, 312 | } 313 | 314 | impl RobertaSelfOutput { 315 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 316 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 317 | let layer_norm = layer_norm( 318 | config.hidden_size, 319 | config.layer_norm_eps, 320 | vb.pp("LayerNorm"), 321 | )?; 322 | let dropout = Dropout::new(config.hidden_dropout_prob); 323 | Ok(Self { 324 | dense, 325 | layer_norm, 326 | dropout, 327 | }) 328 | } 329 | 330 | fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { 331 | let hidden_states = self.dense.forward(hidden_states)?; 332 | let hidden_states = self.dropout.forward(&hidden_states)?; 333 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 
334 | } 335 | } 336 | 337 | struct RobertaAttention { 338 | self_attention: RobertaSelfAttention, 339 | self_output: RobertaSelfOutput, 340 | } 341 | 342 | impl RobertaAttention { 343 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 344 | let self_attention = RobertaSelfAttention::load(vb.pp("self"), config)?; 345 | let self_output = RobertaSelfOutput::load(vb.pp("output"), config)?; 346 | Ok(Self { 347 | self_attention, 348 | self_output, 349 | }) 350 | } 351 | 352 | fn forward(&self, hidden_states: &Tensor) -> Result { 353 | let self_outputs = self.self_attention.forward(hidden_states)?; 354 | let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; 355 | Ok(attention_output) 356 | } 357 | } 358 | 359 | struct RobertaIntermediate { 360 | dense: Linear, 361 | intermediate_act: HiddenActLayer, 362 | } 363 | 364 | impl RobertaIntermediate { 365 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 366 | let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; 367 | Ok(Self { 368 | dense, 369 | intermediate_act: HiddenActLayer::new(config.hidden_act), 370 | }) 371 | } 372 | 373 | fn forward(&self, hidden_states: &Tensor) -> Result { 374 | let hidden_states = self.dense.forward(hidden_states)?; 375 | let ys = self.intermediate_act.forward(&hidden_states)?; 376 | Ok(ys) 377 | } 378 | } 379 | 380 | struct RobertaOutput { 381 | dense: Linear, 382 | layer_norm: LayerNorm, 383 | dropout: Dropout, 384 | } 385 | 386 | impl RobertaOutput { 387 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 388 | let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; 389 | let layer_norm = layer_norm( 390 | config.hidden_size, 391 | config.layer_norm_eps, 392 | vb.pp("LayerNorm"), 393 | )?; 394 | let dropout = Dropout::new(config.hidden_dropout_prob); 395 | Ok(Self { 396 | dense, 397 | layer_norm, 398 | dropout, 399 | }) 400 | } 401 | 402 | fn forward(&self, hidden_states: &Tensor, 
input_tensor: &Tensor) -> Result { 403 | let hidden_states = self.dense.forward(hidden_states)?; 404 | let hidden_states = self.dropout.forward(&hidden_states)?; 405 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 406 | } 407 | } 408 | 409 | struct RobertaLayer { 410 | attention: RobertaAttention, 411 | intermediate: RobertaIntermediate, 412 | output: RobertaOutput, 413 | } 414 | 415 | impl RobertaLayer { 416 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 417 | let attention = RobertaAttention::load(vb.pp("attention"), config)?; 418 | let intermediate = RobertaIntermediate::load(vb.pp("intermediate"), config)?; 419 | let output = RobertaOutput::load(vb.pp("output"), config)?; 420 | Ok(Self { 421 | attention, 422 | intermediate, 423 | output, 424 | }) 425 | } 426 | 427 | fn forward(&self, hidden_states: &Tensor) -> Result { 428 | let attention_output = self.attention.forward(hidden_states)?; 429 | 430 | let intermediate_output = self.intermediate.forward(&attention_output)?; 431 | let layer_output = self 432 | .output 433 | .forward(&intermediate_output, &attention_output)?; 434 | Ok(layer_output) 435 | } 436 | } 437 | 438 | struct RobertaEncoder { 439 | layers: Vec, 440 | } 441 | 442 | impl RobertaEncoder { 443 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 444 | let layers = (0..config.num_hidden_layers) 445 | .map(|index| RobertaLayer::load(vb.pp(&format!("layer.{index}")), config)) 446 | .collect::>>()?; 447 | Ok(RobertaEncoder { layers }) 448 | } 449 | 450 | fn forward(&self, hidden_states: &Tensor) -> Result { 451 | let mut hidden_states = hidden_states.clone(); 452 | for layer in self.layers.iter() { 453 | hidden_states = layer.forward(&hidden_states)? 
454 | } 455 | Ok(hidden_states) 456 | } 457 | } 458 | 459 | pub struct RobertaPooler{ 460 | dense: Linear, 461 | activation: HiddenActLayer, 462 | } 463 | 464 | impl RobertaPooler{ 465 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 466 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 467 | Ok( Self { 468 | dense, 469 | activation: HiddenActLayer::new(HiddenAct::Tanh), 470 | }) 471 | 472 | } 473 | 474 | pub fn forward(&self, hidden_states: &Tensor) -> Result { 475 | // We "pool" the model by simply taking the hidden state corresponding 476 | // to the first token. 477 | 478 | let first_token_sensor = hidden_states.i((.., 0))?; 479 | let pooled_output = self.dense.forward(&first_token_sensor)?; 480 | let pooled_output = self.activation.forward(&pooled_output)?; 481 | 482 | Ok(pooled_output) 483 | } 484 | } 485 | 486 | pub struct RobertaModel { 487 | embeddings: RobertaEmbeddings, 488 | encoder: RobertaEncoder, 489 | pub device: Device, 490 | } 491 | 492 | impl RobertaModel { 493 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 494 | let (embeddings, encoder) = match ( 495 | RobertaEmbeddings::load(vb.pp("embeddings"), config), 496 | RobertaEncoder::load(vb.pp("encoder"), config), 497 | ) { 498 | (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), 499 | (Err(err), _) | (_, Err(err)) => { 500 | if let Some(model_type) = &config.model_type { 501 | if let (Ok(embeddings), Ok(encoder)) = ( 502 | RobertaEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), 503 | RobertaEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), 504 | ) { 505 | (embeddings, encoder) 506 | } else { 507 | return Err(err); 508 | } 509 | } else { 510 | return Err(err); 511 | } 512 | } 513 | }; 514 | Ok(Self { 515 | embeddings, 516 | encoder, 517 | device: vb.device().clone(), 518 | }) 519 | } 520 | 521 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { 522 | let 
embedding_output = self 523 | .embeddings 524 | .forward(input_ids, token_type_ids, None, None)?; 525 | let sequence_output = self.encoder.forward(&embedding_output)?; 526 | Ok(sequence_output) 527 | } 528 | } 529 | 530 | pub struct RobertaModelWithPooler { 531 | embeddings: RobertaEmbeddings, 532 | encoder: RobertaEncoder, 533 | pooler: RobertaPooler, 534 | pub device: Device, 535 | } 536 | 537 | impl RobertaModelWithPooler { 538 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 539 | let (embeddings, encoder, pooler) = match ( 540 | RobertaEmbeddings::load(vb.pp("embeddings"), config), 541 | RobertaEncoder::load(vb.pp("encoder"), config), 542 | RobertaPooler::load(vb.pp("pooler"), config) 543 | ) { 544 | (Ok(embeddings), Ok(encoder), Ok(pooler)) => (embeddings, encoder, pooler), 545 | (Err(err), _, _) | (_, Err(err), _) | (_, _, Err(err)) => { 546 | if let Some(model_type) = &config.model_type { 547 | if let (Ok(embeddings), Ok(encoder), Ok(pooler)) = ( 548 | RobertaEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), 549 | RobertaEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), 550 | RobertaPooler::load(vb.pp(&format!("{model_type}.pooler")), config), 551 | ) { 552 | (embeddings, encoder, pooler) 553 | } else { 554 | return Err(err); 555 | } 556 | } else { 557 | return Err(err); 558 | } 559 | } 560 | }; 561 | Ok(Self { 562 | embeddings, 563 | encoder, 564 | pooler, 565 | device: vb.device().clone(), 566 | }) 567 | } 568 | 569 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { 570 | let embedding_output = self 571 | .embeddings 572 | .forward(input_ids, token_type_ids, None, None)?; 573 | let sequence_output = self.encoder.forward(&embedding_output)?; 574 | let pooled_output = self.pooler.forward(&sequence_output)?; 575 | Ok(pooled_output) 576 | } 577 | } 578 | 579 | struct RobertaClassificationHead{ 580 | dense: Linear, 581 | dropout: Dropout, 582 | out_proj: Linear 583 | } 584 | 585 | 
impl RobertaClassificationHead { 586 | 587 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 588 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 589 | let classifier_dropout = config.classifier_dropout; 590 | 591 | let classifier_dropout: f64 = match classifier_dropout { 592 | Some(classifier_dropout) => classifier_dropout, 593 | None => config.hidden_dropout_prob, 594 | }; 595 | let out_proj = linear(config.hidden_size, config._num_labels.unwrap(), vb.pp("out_proj"))?; 596 | 597 | Ok( Self { 598 | dense, 599 | dropout: Dropout::new(classifier_dropout), 600 | out_proj 601 | }) 602 | 603 | } 604 | 605 | fn forward(&self, features: &Tensor) -> Result { 606 | 607 | let x = features.i((.., 0))?; 608 | let x = self.dropout.forward(&x)?; 609 | let x = self.dense.forward(&x)?; 610 | let x = x.tanh()?; 611 | let x = self.dropout.forward(&x)?; 612 | let x = self.out_proj.forward(&x)?; 613 | 614 | Ok(x) 615 | } 616 | } 617 | 618 | pub struct RobertaForSequenceClassification { 619 | roberta: RobertaModel, 620 | classifier: RobertaClassificationHead, 621 | pub device: Device, 622 | config: RobertaConfig 623 | } 624 | 625 | impl RobertaForSequenceClassification { 626 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 627 | let (roberta, classifier) = match ( 628 | RobertaModel::load(vb.pp("roberta"), config), 629 | RobertaClassificationHead::load(vb.pp("classifier"), config), 630 | ) { 631 | (Ok(roberta), Ok(classifier)) => (roberta, classifier), 632 | (Err(err), _) | (_, Err(err)) => { 633 | return Err(err); 634 | } 635 | }; 636 | Ok(Self { 637 | roberta, 638 | classifier, 639 | device: vb.device().clone(), 640 | config: config.clone() 641 | }) 642 | } 643 | 644 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, labels: Option<&Tensor>) -> Result { 645 | let outputs = self 646 | .roberta 647 | .forward(input_ids, token_type_ids)?; 648 | let mut problem_type: String = String::from(""); 649 | 650 | let 
logits = self.classifier.forward(&outputs)?; 651 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 652 | 653 | match labels { 654 | Some(labels) => { 655 | let labels = labels.to_device(&input_ids.device())?; 656 | 657 | if self.config.problem_type == None { 658 | if self.config._num_labels == Some(1) { 659 | problem_type = String::from("regression"); 660 | } else if self.config._num_labels > Some(1) && (labels.dtype() == LONG_DTYPE || labels.dtype() == DType::U32) { 661 | problem_type = String::from("single_label_classification"); 662 | } else { 663 | problem_type = String::from("multi_label_classification"); 664 | } 665 | } 666 | 667 | if problem_type == String::from("single_label_classification") { 668 | loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &labels.flatten_to(1)?)?; 669 | } else if problem_type == String::from("multi_label_classification") { 670 | let labels_logits: Tensor = logits.zeros_like()?; 671 | let mut label_logits = labels_logits.to_vec2::()?; 672 | 673 | let label = vec![0, 1, 2, 3, 2]; 674 | 675 | for vec_i in 0..label_logits.len() { 676 | label_logits[vec_i][label[vec_i]] = 1.0; 677 | } 678 | 679 | let label_logits = Tensor::new(label_logits, &self.device)?; 680 | 681 | loss = binary_cross_entropy_with_logit(&logits, &label_logits)?; 682 | } 683 | 684 | } 685 | 686 | None => {} 687 | } 688 | 689 | Ok(SequenceClassifierOutput { 690 | loss :Some(loss), 691 | logits, 692 | hidden_states :None, 693 | attentions : None 694 | }) 695 | 696 | 697 | } 698 | 699 | } 700 | 701 | pub struct RobertaForTokenClassification { 702 | roberta: RobertaModel, 703 | dropout: Dropout, 704 | classifier: Linear, 705 | pub device: Device, 706 | } 707 | 708 | impl RobertaForTokenClassification { 709 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 710 | let classifier_dropout = config.classifier_dropout; 711 | 712 | println!("{:?}", config); 713 | 714 | let (roberta, classifier) = match ( 715 | 
RobertaModel::load(vb.pp("roberta"), config), 716 | 717 | if Option::is_some(&config._num_labels) { 718 | linear(config.hidden_size, config._num_labels.unwrap(), vb.pp("classifier")) 719 | } else if Option::is_some(&config.id2label) { 720 | let num_labels = &config.id2label.as_ref().unwrap().len(); 721 | linear(config.hidden_size, num_labels.clone(), vb.pp("classifier")) 722 | } else { 723 | candle_core::bail!("cannnot find the number of classes to map to") 724 | } 725 | 726 | ) { 727 | (Ok(roberta), Ok(classifier)) => (roberta, classifier), 728 | (Err(err), _) | (_, Err(err)) => { 729 | return Err(err); 730 | } 731 | }; 732 | Ok(Self { 733 | roberta, 734 | dropout: Dropout::new(classifier_dropout.unwrap_or_else(|| 0.2)), 735 | classifier, 736 | device: vb.device().clone(), 737 | }) 738 | } 739 | 740 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, labels: Option<&Tensor>) -> Result { 741 | let outputs = self 742 | .roberta 743 | .forward(input_ids, token_type_ids)?; 744 | let outputs = self.dropout.forward(&outputs)?; 745 | 746 | let logits = self.classifier.forward(&outputs)?; 747 | 748 | println!("{:?}", logits); 749 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 750 | 751 | match labels { 752 | Some(labels) => { 753 | loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &labels.flatten_to(1)?)?; 754 | } 755 | None => {} 756 | } 757 | 758 | Ok(TokenClassifierOutput { 759 | loss :Some(loss), 760 | logits, 761 | hidden_states :None, 762 | attentions : None 763 | }) 764 | 765 | 766 | } 767 | 768 | } 769 | 770 | pub struct RobertaForQuestionAnswering { 771 | roberta: RobertaModel, 772 | dropout: Dropout, 773 | qa_outputs: Linear, 774 | pub device: Device, 775 | } 776 | 777 | 778 | impl RobertaForQuestionAnswering { 779 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 780 | let classifier_dropout = config.classifier_dropout; 781 | 782 | println!("{:?}", config); 783 | 784 | let (roberta, qa_outputs) = 
match ( 785 | RobertaModel::load(vb.pp("roberta"), config), 786 | linear(config.hidden_size, 2, vb.pp("classifier")) 787 | 788 | ) { 789 | (Ok(roberta), Ok(qa_outputs)) => (roberta, qa_outputs), 790 | (Err(err), _) | (_, Err(err)) => { 791 | return Err(err); 792 | } 793 | }; 794 | Ok(Self { 795 | roberta, 796 | dropout: Dropout::new(classifier_dropout.unwrap_or_else(|| 0.2)), 797 | qa_outputs, 798 | device: vb.device().clone(), 799 | }) 800 | } 801 | 802 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, start_positions: Option<&Tensor>, end_positions: Option<&Tensor>) -> Result { 803 | let outputs = self 804 | .roberta 805 | .forward(input_ids, token_type_ids)?; 806 | let outputs = self.dropout.forward(&outputs)?; 807 | 808 | let logits = self.qa_outputs.forward(&outputs)?; 809 | 810 | let start_logits = logits.i((.., 0))?; 811 | let end_logits = logits.i((.., 1))?; 812 | 813 | println!("{:?}", logits); 814 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 815 | 816 | match (start_positions, end_positions) { 817 | (Some(start_positions), Some(end_positions)) => { 818 | let start_loss = candle_nn::loss::cross_entropy(&start_logits.flatten_to(1)?, &start_positions.flatten_to(1)?)?; 819 | let end_loss = candle_nn::loss::cross_entropy(&end_logits.flatten_to(1)?, &end_positions.flatten_to(1)?)?; 820 | 821 | loss = ((start_loss + end_loss)? 
/ 2.0)?; 822 | } 823 | _ => {} 824 | } 825 | 826 | Ok(QuestionAnsweringModelOutput { 827 | loss :Some(loss), 828 | start_logits, 829 | end_logits, 830 | hidden_states :None, 831 | attentions : None 832 | }) 833 | 834 | } 835 | 836 | } 837 | 838 | -------------------------------------------------------------------------------- /src/models/xlm_roberta.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use candle_core::{DType, Device, IndexOp, Result, Tensor}; 4 | use candle_nn::{Embedding, Module, VarBuilder}; 5 | 6 | use crate::models::modelling_outputs::{SequenceClassifierOutput, TokenClassifierOutput, QuestionAnsweringModelOutput}; 7 | use crate::models::model_utils::{Dropout, HiddenAct, Linear, HiddenActLayer, LayerNorm, PositionEmbeddingType}; 8 | use crate::models::model_utils::binary_cross_entropy_with_logit; 9 | use serde::Deserialize; 10 | 11 | pub const FLOATING_DTYPE: DType = DType::F32; 12 | pub const LONG_DTYPE: DType = DType::I64; 13 | 14 | #[derive(Debug, Clone, PartialEq, Deserialize)] 15 | pub struct XLMRobertaConfig { 16 | vocab_size: usize, 17 | hidden_size: usize, 18 | num_hidden_layers: usize, 19 | num_attention_heads: usize, 20 | intermediate_size: usize, 21 | hidden_act: HiddenAct, 22 | hidden_dropout_prob: f64, 23 | max_position_embeddings: usize, 24 | type_vocab_size: usize, 25 | initializer_range: f64, 26 | layer_norm_eps: f64, 27 | pad_token_id: usize, 28 | bos_token_id: usize, 29 | eos_token_id: usize, 30 | #[serde(default)] 31 | position_embedding_type: PositionEmbeddingType, 32 | #[serde(default)] 33 | use_cache: bool, 34 | classifier_dropout: Option, 35 | model_type: Option, 36 | problem_type: Option, 37 | _num_labels: Option, 38 | id2label: Option>, 39 | label2id: Option> 40 | } 41 | 42 | impl Default for XLMRobertaConfig { 43 | fn default() -> Self { 44 | Self { 45 | vocab_size: 50265, 46 | hidden_size: 768, 47 | num_hidden_layers: 12, 48 | 
num_attention_heads: 12, 49 | intermediate_size: 3072, 50 | hidden_act: HiddenAct::Gelu, 51 | hidden_dropout_prob: 0.1, 52 | max_position_embeddings: 512, 53 | type_vocab_size: 2, 54 | initializer_range: 0.02, 55 | layer_norm_eps: 1e-12, 56 | pad_token_id: 1, 57 | bos_token_id: 0, 58 | eos_token_id: 2, 59 | position_embedding_type: PositionEmbeddingType::Absolute, 60 | use_cache: true, 61 | classifier_dropout: None, 62 | model_type: Some("xlm-roberta".to_string()), 63 | problem_type: None, 64 | _num_labels: Some(3), 65 | id2label: None, 66 | label2id: None 67 | } 68 | } 69 | } 70 | 71 | fn cumsum_2d(mask: &Tensor, dim: u8, device: &Device) -> Result { 72 | let mask = mask.to_vec2::()?; 73 | 74 | let rows = mask.len(); 75 | let cols = mask[0].len(); 76 | 77 | let mut result = mask.clone(); 78 | 79 | match dim { 80 | 0 => { 81 | // Cumulative sum along rows 82 | for i in 0..rows { 83 | for j in 1..cols { 84 | result[i][j] += result[i][j - 1]; 85 | } 86 | } 87 | } 88 | 1 => { 89 | // Cumulative sum along columns 90 | for j in 0..cols { 91 | for i in 1..rows { 92 | result[i][j] += result[i - 1][j]; 93 | } 94 | } 95 | } 96 | _ => panic!("Dimension not supported"), 97 | } 98 | 99 | let result = Tensor::new(result, &device)?; 100 | 101 | Ok(result) 102 | } 103 | 104 | pub fn create_position_ids_from_input_ids( 105 | input_ids: &Tensor, 106 | padding_idx: u32, 107 | past_key_values_length: u8, 108 | ) -> Result { 109 | let mask = input_ids.ne(padding_idx)?; 110 | let incremental_indices = cumsum_2d(&mask, 0, input_ids.device())?; 111 | 112 | let incremental_indices = incremental_indices 113 | .broadcast_add(&Tensor::new(&[past_key_values_length], input_ids.device())?)?; 114 | 115 | Ok(incremental_indices) 116 | } 117 | 118 | fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { 119 | let embeddings = vb.get((vocab_size, hidden_size), "weight")?; 120 | Ok(Embedding::new(embeddings, hidden_size)) 121 | } 122 | 123 | fn linear(size1: usize, size2: 
usize, vb: VarBuilder) -> Result { 124 | let weight = vb.get((size2, size1), "weight")?; 125 | let bias = vb.get(size2, "bias")?; 126 | Ok(Linear::new(weight, Some(bias))) 127 | } 128 | 129 | fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { 130 | let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { 131 | (Ok(weight), Ok(bias)) => (weight, bias), 132 | (Err(err), _) | (_, Err(err)) => { 133 | if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { 134 | (weight, bias) 135 | } else { 136 | return Err(err); 137 | } 138 | } 139 | }; 140 | Ok(LayerNorm::new(weight, bias, eps)) 141 | } 142 | 143 | pub struct XLMRobertaEmbeddings { 144 | word_embeddings: Embedding, 145 | position_embeddings: Option, 146 | token_type_embeddings: Embedding, 147 | layer_norm: LayerNorm, 148 | dropout: Dropout, 149 | pub padding_idx: u32, 150 | } 151 | 152 | impl XLMRobertaEmbeddings { 153 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 154 | let word_embeddings = embedding( 155 | config.vocab_size, 156 | config.hidden_size, 157 | vb.pp("word_embeddings"), 158 | )?; 159 | let position_embeddings = embedding( 160 | config.max_position_embeddings, 161 | config.hidden_size, 162 | vb.pp("position_embeddings"), 163 | )?; 164 | let token_type_embeddings = embedding( 165 | config.type_vocab_size, 166 | config.hidden_size, 167 | vb.pp("token_type_embeddings"), 168 | )?; 169 | let layer_norm = layer_norm( 170 | config.hidden_size, 171 | config.layer_norm_eps, 172 | vb.pp("LayerNorm"), 173 | )?; 174 | let padding_idx = config.pad_token_id as u32; 175 | 176 | Ok(Self { 177 | word_embeddings, 178 | position_embeddings: Some(position_embeddings), 179 | token_type_embeddings, 180 | layer_norm, 181 | dropout: Dropout::new(config.hidden_dropout_prob), 182 | padding_idx, 183 | }) 184 | } 185 | 186 | pub fn forward( 187 | &self, 188 | input_ids: &Tensor, 189 | token_type_ids: &Tensor, 190 | position_ids: Option<&Tensor>, 191 | 
inputs_embeds: Option<&Tensor>, 192 | ) -> Result { 193 | let position_ids = match position_ids { 194 | Some(ids) => ids.to_owned(), 195 | None => { 196 | if Option::is_some(&inputs_embeds) { 197 | let position_ids = 198 | self.create_position_ids_from_input_embeds(inputs_embeds.unwrap())?; 199 | position_ids 200 | } else { 201 | let position_ids = 202 | create_position_ids_from_input_ids(input_ids, self.padding_idx, 1)?; 203 | position_ids 204 | } 205 | } 206 | }; 207 | 208 | let inputs_embeds: Tensor = match inputs_embeds { 209 | Some(embeds) => embeds.to_owned(), 210 | None => { 211 | let embeds = self.word_embeddings.forward(input_ids)?; 212 | embeds 213 | } 214 | }; 215 | 216 | let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; 217 | let mut embeddings = (inputs_embeds + token_type_embeddings)?; 218 | 219 | if let Some(position_embeddings) = &self.position_embeddings { 220 | embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? 221 | } 222 | 223 | let embeddings = self.layer_norm.forward(&embeddings)?; 224 | let embeddings = self.dropout.forward(&embeddings)?; 225 | 226 | Ok(embeddings) 227 | } 228 | 229 | pub fn create_position_ids_from_input_embeds(&self, input_embeds: &Tensor) -> Result { 230 | let input_shape = input_embeds.dims3()?; 231 | let seq_length = input_shape.1; 232 | 233 | println!("seq_length: {:?}", seq_length); 234 | let mut position_ids = Tensor::arange( 235 | self.padding_idx + 1, 236 | seq_length as u32 + self.padding_idx + 1, 237 | &Device::Cpu, 238 | )?; 239 | 240 | println!("position_ids: {:?}", position_ids); 241 | 242 | position_ids = position_ids 243 | .unsqueeze(0)? 
244 | .expand((input_shape.0, input_shape.1))?; 245 | Ok(position_ids) 246 | } 247 | } 248 | 249 | struct XLMRobertaSelfAttention { 250 | query: Linear, 251 | key: Linear, 252 | value: Linear, 253 | dropout: Dropout, 254 | num_attention_heads: usize, 255 | attention_head_size: usize, 256 | } 257 | 258 | impl XLMRobertaSelfAttention { 259 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 260 | let attention_head_size = config.hidden_size / config.num_attention_heads; 261 | let all_head_size = config.num_attention_heads * attention_head_size; 262 | let dropout = Dropout::new(config.hidden_dropout_prob); 263 | let hidden_size = config.hidden_size; 264 | let query = linear(hidden_size, all_head_size, vb.pp("query"))?; 265 | let value = linear(hidden_size, all_head_size, vb.pp("value"))?; 266 | let key = linear(hidden_size, all_head_size, vb.pp("key"))?; 267 | Ok(Self { 268 | query, 269 | key, 270 | value, 271 | dropout, 272 | num_attention_heads: config.num_attention_heads, 273 | attention_head_size, 274 | }) 275 | } 276 | 277 | fn transpose_for_scores(&self, xs: &Tensor) -> Result { 278 | let mut new_x_shape = xs.dims().to_vec(); 279 | new_x_shape.pop(); 280 | new_x_shape.push(self.num_attention_heads); 281 | new_x_shape.push(self.attention_head_size); 282 | let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; 283 | xs.contiguous() 284 | } 285 | 286 | fn forward(&self, hidden_states: &Tensor) -> Result { 287 | let query_layer = self.query.forward(hidden_states)?; 288 | let key_layer = self.key.forward(hidden_states)?; 289 | let value_layer = self.value.forward(hidden_states)?; 290 | 291 | let query_layer = self.transpose_for_scores(&query_layer)?; 292 | let key_layer = self.transpose_for_scores(&key_layer)?; 293 | let value_layer = self.transpose_for_scores(&value_layer)?; 294 | 295 | let attention_scores = query_layer.matmul(&key_layer.t()?)?; 296 | let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; 297 | 
let attention_probs = 298 | { candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)? }; 299 | let attention_probs = self.dropout.forward(&attention_probs)?; 300 | 301 | let context_layer = attention_probs.matmul(&value_layer)?; 302 | let context_layer = context_layer.transpose(1, 2)?.contiguous()?; 303 | let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?; 304 | Ok(context_layer) 305 | } 306 | } 307 | 308 | struct XLMRobertaSelfOutput { 309 | dense: Linear, 310 | layer_norm: LayerNorm, 311 | dropout: Dropout, 312 | } 313 | 314 | impl XLMRobertaSelfOutput { 315 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 316 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 317 | let layer_norm = layer_norm( 318 | config.hidden_size, 319 | config.layer_norm_eps, 320 | vb.pp("LayerNorm"), 321 | )?; 322 | let dropout = Dropout::new(config.hidden_dropout_prob); 323 | Ok(Self { 324 | dense, 325 | layer_norm, 326 | dropout, 327 | }) 328 | } 329 | 330 | fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { 331 | let hidden_states = self.dense.forward(hidden_states)?; 332 | let hidden_states = self.dropout.forward(&hidden_states)?; 333 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 
334 | } 335 | } 336 | 337 | struct XLMRobertaAttention { 338 | self_attention: XLMRobertaSelfAttention, 339 | self_output: XLMRobertaSelfOutput, 340 | } 341 | 342 | impl XLMRobertaAttention { 343 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 344 | let self_attention = XLMRobertaSelfAttention::load(vb.pp("self"), config)?; 345 | let self_output = XLMRobertaSelfOutput::load(vb.pp("output"), config)?; 346 | Ok(Self { 347 | self_attention, 348 | self_output, 349 | }) 350 | } 351 | 352 | fn forward(&self, hidden_states: &Tensor) -> Result { 353 | let self_outputs = self.self_attention.forward(hidden_states)?; 354 | let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; 355 | Ok(attention_output) 356 | } 357 | } 358 | 359 | struct XLMRobertaIntermediate { 360 | dense: Linear, 361 | intermediate_act: HiddenActLayer, 362 | } 363 | 364 | impl XLMRobertaIntermediate { 365 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 366 | let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; 367 | Ok(Self { 368 | dense, 369 | intermediate_act: HiddenActLayer::new(config.hidden_act), 370 | }) 371 | } 372 | 373 | fn forward(&self, hidden_states: &Tensor) -> Result { 374 | let hidden_states = self.dense.forward(hidden_states)?; 375 | let ys = self.intermediate_act.forward(&hidden_states)?; 376 | Ok(ys) 377 | } 378 | } 379 | 380 | struct XLMRobertaOutput { 381 | dense: Linear, 382 | layer_norm: LayerNorm, 383 | dropout: Dropout, 384 | } 385 | 386 | impl XLMRobertaOutput { 387 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 388 | let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; 389 | let layer_norm = layer_norm( 390 | config.hidden_size, 391 | config.layer_norm_eps, 392 | vb.pp("LayerNorm"), 393 | )?; 394 | let dropout = Dropout::new(config.hidden_dropout_prob); 395 | Ok(Self { 396 | dense, 397 | layer_norm, 398 | dropout, 399 | }) 400 | } 401 | 402 | fn 
forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { 403 | let hidden_states = self.dense.forward(hidden_states)?; 404 | let hidden_states = self.dropout.forward(&hidden_states)?; 405 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 406 | } 407 | } 408 | 409 | struct XLMRobertaLayer { 410 | attention: XLMRobertaAttention, 411 | intermediate: XLMRobertaIntermediate, 412 | output: XLMRobertaOutput, 413 | } 414 | 415 | impl XLMRobertaLayer { 416 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 417 | let attention = XLMRobertaAttention::load(vb.pp("attention"), config)?; 418 | let intermediate = XLMRobertaIntermediate::load(vb.pp("intermediate"), config)?; 419 | let output = XLMRobertaOutput::load(vb.pp("output"), config)?; 420 | Ok(Self { 421 | attention, 422 | intermediate, 423 | output, 424 | }) 425 | } 426 | 427 | fn forward(&self, hidden_states: &Tensor) -> Result { 428 | let attention_output = self.attention.forward(hidden_states)?; 429 | 430 | let intermediate_output = self.intermediate.forward(&attention_output)?; 431 | let layer_output = self 432 | .output 433 | .forward(&intermediate_output, &attention_output)?; 434 | Ok(layer_output) 435 | } 436 | } 437 | 438 | struct XLMRobertaEncoder { 439 | layers: Vec, 440 | } 441 | 442 | impl XLMRobertaEncoder { 443 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 444 | let layers = (0..config.num_hidden_layers) 445 | .map(|index| XLMRobertaLayer::load(vb.pp(&format!("layer.{index}")), config)) 446 | .collect::>>()?; 447 | Ok(XLMRobertaEncoder { layers }) 448 | } 449 | 450 | fn forward(&self, hidden_states: &Tensor) -> Result { 451 | let mut hidden_states = hidden_states.clone(); 452 | for layer in self.layers.iter() { 453 | hidden_states = layer.forward(&hidden_states)? 
454 | } 455 | Ok(hidden_states) 456 | } 457 | } 458 | 459 | pub struct XLMRobertaPooler{ 460 | dense: Linear, 461 | activation: HiddenActLayer, 462 | } 463 | 464 | impl XLMRobertaPooler{ 465 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 466 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 467 | Ok( Self { 468 | dense, 469 | activation: HiddenActLayer::new(HiddenAct::Tanh), 470 | }) 471 | 472 | } 473 | 474 | pub fn forward(&self, hidden_states: &Tensor) -> Result { 475 | // We "pool" the model by simply taking the hidden state corresponding 476 | // to the first token. 477 | 478 | let first_token_sensor = hidden_states.i((.., 0))?; 479 | let pooled_output = self.dense.forward(&first_token_sensor)?; 480 | let pooled_output = self.activation.forward(&pooled_output)?; 481 | 482 | Ok(pooled_output) 483 | } 484 | } 485 | 486 | pub struct XLMRobertaModel { 487 | embeddings: XLMRobertaEmbeddings, 488 | encoder: XLMRobertaEncoder, 489 | pub device: Device, 490 | } 491 | 492 | impl XLMRobertaModel { 493 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 494 | 495 | let (embeddings, encoder) = match ( 496 | XLMRobertaEmbeddings::load(vb.pp("embeddings"), config), 497 | XLMRobertaEncoder::load(vb.pp("encoder"), config), 498 | ) { 499 | (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), 500 | (Err(err), _) | (_, Err(err)) => { 501 | if let Some(_model_type) = &config.model_type { 502 | if let (Ok(embeddings), Ok(encoder)) = ( 503 | XLMRobertaEmbeddings::load(vb.pp(&format!("roberta.embeddings")), config), 504 | XLMRobertaEncoder::load(vb.pp(&format!("roberta.encoder")), config), 505 | ) { 506 | (embeddings, encoder) 507 | } else { 508 | return Err(err); 509 | } 510 | } else { 511 | return Err(err); 512 | } 513 | } 514 | }; 515 | Ok(Self { 516 | embeddings, 517 | encoder, 518 | device: vb.device().clone(), 519 | }) 520 | } 521 | 522 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> 
Result { 523 | let embedding_output = self 524 | .embeddings 525 | .forward(input_ids, token_type_ids, None, None)?; 526 | let sequence_output = self.encoder.forward(&embedding_output)?; 527 | Ok(sequence_output) 528 | } 529 | } 530 | 531 | pub struct XLMRobertaModelWithPooler { 532 | embeddings: XLMRobertaEmbeddings, 533 | encoder: XLMRobertaEncoder, 534 | pooler: XLMRobertaPooler, 535 | pub device: Device, 536 | } 537 | 538 | impl XLMRobertaModelWithPooler { 539 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 540 | let (embeddings, encoder, pooler) = match ( 541 | XLMRobertaEmbeddings::load(vb.pp("embeddings"), config), 542 | XLMRobertaEncoder::load(vb.pp("encoder"), config), 543 | XLMRobertaPooler::load(vb.pp("pooler"), config) 544 | ) { 545 | (Ok(embeddings), Ok(encoder), Ok(pooler)) => (embeddings, encoder, pooler), 546 | (Err(err), _, _) | (_, Err(err), _) | (_, _, Err(err)) => { 547 | if let Some(model_type) = &config.model_type { 548 | if let (Ok(embeddings), Ok(encoder), Ok(pooler)) = ( 549 | XLMRobertaEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), 550 | XLMRobertaEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), 551 | XLMRobertaPooler::load(vb.pp(&format!("{model_type}.pooler")), config), 552 | ) { 553 | (embeddings, encoder, pooler) 554 | } else { 555 | return Err(err); 556 | } 557 | } else { 558 | return Err(err); 559 | } 560 | } 561 | }; 562 | Ok(Self { 563 | embeddings, 564 | encoder, 565 | pooler, 566 | device: vb.device().clone(), 567 | }) 568 | } 569 | 570 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { 571 | let embedding_output = self 572 | .embeddings 573 | .forward(input_ids, token_type_ids, None, None)?; 574 | let sequence_output = self.encoder.forward(&embedding_output)?; 575 | let pooled_output = self.pooler.forward(&sequence_output)?; 576 | Ok(pooled_output) 577 | } 578 | } 579 | 580 | struct XLMRobertaClassificationHead{ 581 | dense: Linear, 582 | 
dropout: Dropout, 583 | out_proj: Linear 584 | } 585 | 586 | impl XLMRobertaClassificationHead { 587 | 588 | fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 589 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 590 | let classifier_dropout = config.classifier_dropout; 591 | 592 | let classifier_dropout: f64 = match classifier_dropout { 593 | Some(classifier_dropout) => classifier_dropout, 594 | None => config.hidden_dropout_prob, 595 | }; 596 | let out_proj = linear(config.hidden_size, config._num_labels.unwrap(), vb.pp("out_proj"))?; 597 | 598 | Ok( Self { 599 | dense, 600 | dropout: Dropout::new(classifier_dropout), 601 | out_proj 602 | }) 603 | 604 | } 605 | 606 | fn forward(&self, features: &Tensor) -> Result { 607 | 608 | let x = features.i((.., 0))?; 609 | let x = self.dropout.forward(&x)?; 610 | let x = self.dense.forward(&x)?; 611 | let x = x.tanh()?; 612 | let x = self.dropout.forward(&x)?; 613 | let x = self.out_proj.forward(&x)?; 614 | 615 | Ok(x) 616 | } 617 | } 618 | 619 | pub struct XLMRobertaForSequenceClassification { 620 | xlmroberta: XLMRobertaModel, 621 | classifier: XLMRobertaClassificationHead, 622 | pub device: Device, 623 | config: XLMRobertaConfig 624 | } 625 | 626 | impl XLMRobertaForSequenceClassification { 627 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 628 | let (xlmroberta, classifier) = match ( 629 | XLMRobertaModel::load(vb.pp("roberta"), config), 630 | XLMRobertaClassificationHead::load(vb.pp("classifier"), config), 631 | ) { 632 | (Ok(xlmroberta), Ok(classifier)) => (xlmroberta, classifier), 633 | (Err(err), _) | (_, Err(err)) => { 634 | return Err(err); 635 | } 636 | }; 637 | Ok(Self { 638 | xlmroberta, 639 | classifier, 640 | device: vb.device().clone(), 641 | config: config.clone() 642 | }) 643 | } 644 | 645 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, labels: Option<&Tensor>) -> Result { 646 | let outputs = self 647 | .xlmroberta 648 | 
.forward(input_ids, token_type_ids)?; 649 | let mut problem_type: String = String::from(""); 650 | 651 | let logits = self.classifier.forward(&outputs)?; 652 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 653 | 654 | match labels { 655 | Some(labels) => { 656 | let labels = labels.to_device(&input_ids.device())?; 657 | 658 | if self.config.problem_type == None { 659 | if self.config._num_labels == Some(1) { 660 | problem_type = String::from("regression"); 661 | } else if self.config._num_labels > Some(1) && (labels.dtype() == LONG_DTYPE || labels.dtype() == DType::U32) { 662 | problem_type = String::from("single_label_classification"); 663 | } else { 664 | problem_type = String::from("multi_label_classification"); 665 | } 666 | } 667 | 668 | if problem_type == String::from("single_label_classification") { 669 | loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &labels.flatten_to(1)?)?; 670 | } else if problem_type == String::from("multi_label_classification") { 671 | let labels_logits: Tensor = logits.zeros_like()?; 672 | let mut label_logits = labels_logits.to_vec2::()?; 673 | 674 | let label = vec![0, 1, 2, 3, 2]; 675 | 676 | for vec_i in 0..label_logits.len() { 677 | label_logits[vec_i][label[vec_i]] = 1.0; 678 | } 679 | 680 | let label_logits = Tensor::new(label_logits, &self.device)?; 681 | 682 | loss = binary_cross_entropy_with_logit(&logits, &label_logits)?; 683 | } 684 | 685 | } 686 | 687 | None => {} 688 | } 689 | 690 | Ok(SequenceClassifierOutput { 691 | loss :Some(loss), 692 | logits, 693 | hidden_states :None, 694 | attentions : None 695 | }) 696 | 697 | 698 | } 699 | 700 | } 701 | 702 | pub struct XLMRobertaForTokenClassification { 703 | xlmroberta: XLMRobertaModel, 704 | dropout: Dropout, 705 | classifier: Linear, 706 | pub device: Device, 707 | } 708 | 709 | impl XLMRobertaForTokenClassification { 710 | pub fn load(vb: VarBuilder, config: &XLMRobertaConfig) -> Result { 711 | let classifier_dropout = 
config.classifier_dropout; 712 | 713 | println!("{:?}", config); 714 | 715 | let (xlmroberta, classifier) = match ( 716 | XLMRobertaModel::load(vb.pp("roberta"), config), 717 | 718 | if Option::is_some(&config._num_labels) { 719 | linear(config.hidden_size, config._num_labels.unwrap(), vb.pp("classifier")) 720 | } else if Option::is_some(&config.id2label) { 721 | let num_labels = &config.id2label.as_ref().unwrap().len(); 722 | linear(config.hidden_size, num_labels.clone(), vb.pp("classifier")) 723 | } else { 724 | candle_core::bail!("cannnot find the number of classes to map to") 725 | } 726 | 727 | ) { 728 | (Ok(xlmroberta), Ok(classifier)) => (xlmroberta, classifier), 729 | (Err(err), _) | (_, Err(err)) => { 730 | return Err(err); 731 | } 732 | }; 733 | Ok(Self { 734 | xlmroberta, 735 | dropout: Dropout::new(classifier_dropout.unwrap_or_else(|| 0.2)), 736 | classifier, 737 | device: vb.device().clone(), 738 | }) 739 | } 740 | 741 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, labels: Option<&Tensor>) -> Result { 742 | let outputs = self 743 | .xlmroberta 744 | .forward(input_ids, token_type_ids)?; 745 | let outputs = self.dropout.forward(&outputs)?; 746 | 747 | let logits = self.classifier.forward(&outputs)?; 748 | 749 | println!("{:?}", logits); 750 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 751 | 752 | match labels { 753 | Some(labels) => { 754 | loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &labels.flatten_to(1)?)?; 755 | } 756 | None => {} 757 | } 758 | 759 | Ok(TokenClassifierOutput { 760 | loss :Some(loss), 761 | logits, 762 | hidden_states :None, 763 | attentions : None 764 | }) 765 | 766 | 767 | } 768 | 769 | } 770 | 771 | 772 | pub struct XLMRobertaForQuestionAnswering { 773 | xlmroberta: XLMRobertaModel, 774 | dropout: Dropout, 775 | qa_outputs: Linear, 776 | pub device: Device, 777 | } 778 | 779 | 780 | impl XLMRobertaForQuestionAnswering { 781 | pub fn load(vb: VarBuilder, config: 
&XLMRobertaConfig) -> Result { 782 | let classifier_dropout = config.classifier_dropout; 783 | 784 | println!("{:?}", config); 785 | 786 | let (xlmroberta, qa_outputs) = match ( 787 | XLMRobertaModel::load(vb.pp("roberta"), config), 788 | linear(config.hidden_size, 2, vb.pp("classifier")) 789 | 790 | ) { 791 | (Ok(xlmroberta), Ok(qa_outputs)) => (xlmroberta, qa_outputs), 792 | (Err(err), _) | (_, Err(err)) => { 793 | return Err(err); 794 | } 795 | }; 796 | Ok(Self { 797 | xlmroberta, 798 | dropout: Dropout::new(classifier_dropout.unwrap_or_else(|| 0.2)), 799 | qa_outputs, 800 | device: vb.device().clone(), 801 | }) 802 | } 803 | 804 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, start_positions: Option<&Tensor>, end_positions: Option<&Tensor>) -> Result { 805 | let outputs = self 806 | .xlmroberta 807 | .forward(input_ids, token_type_ids)?; 808 | let outputs = self.dropout.forward(&outputs)?; 809 | 810 | let logits = self.qa_outputs.forward(&outputs)?; 811 | 812 | let start_logits = logits.i((.., 0))?; 813 | let end_logits = logits.i((.., 1))?; 814 | 815 | println!("{:?}", logits); 816 | let mut loss: Tensor = Tensor::new(vec![0.0], &self.device)?; 817 | 818 | match (start_positions, end_positions) { 819 | (Some(start_positions), Some(end_positions)) => { 820 | let start_loss = candle_nn::loss::cross_entropy(&start_logits.flatten_to(1)?, &start_positions.flatten_to(1)?)?; 821 | let end_loss = candle_nn::loss::cross_entropy(&end_logits.flatten_to(1)?, &end_positions.flatten_to(1)?)?; 822 | 823 | loss = ((start_loss + end_loss)? 
/ 2.0)?; 824 | } 825 | _ => {} 826 | } 827 | 828 | Ok(QuestionAnsweringModelOutput { 829 | loss :Some(loss), 830 | start_logits, 831 | end_logits, 832 | hidden_states :None, 833 | attentions : None 834 | }) 835 | 836 | } 837 | 838 | } 839 | 840 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Candle Tutorial - Convert Pytorch Models to Candle 2 | 3 | [Candle](https://github.com/huggingface/candle) is an ML framework written in rust that takes advantage of the speed and memory safety Rust provides for writing machine workloads. It can be used as a drop in replacement for ML frameworks like PyTorch, it also has [python bindings](https://github.com/huggingface/candle/tree/main/candle-pyo3) so you can use it from python... 4 | 5 | This repo provides some guide for converting pytorch models from the transformers library to Candle by directly translating the pytorch code to Candle ... 6 | 7 | ❗️❗️: To make the code easily understandable, I have annotated each line of the Rust/Candle code with the equivalent PyTorch code. 8 | Tutorial Structure: 9 | - [Getting Started](#getting-started) 10 | - [0. Important things to note](#0-important-things-to-note) 11 | - [1. Start a new rust project](#1-start-a-new-rust-project) 12 | - [2. Install Candle & Other Packages](#2-install-candle--other-packages) 13 | 14 | - [Parallels between Pytorch and Candle](#3-parallels-between-pytorch-and-candle) 15 | - [Tensors](#tensors) 16 | - [Tensor Operations](#tensor-operations) 17 | - [Translating a PyTorch Transformer Model into Candle](#3-translating-a-pytorch-transformer-model-into-candle) 18 | - [RoBERTa](#31-roberta) 19 | - [a. Writing Building Blocks](#a-writing-building-blocks) 20 | - [b. Roberta Config](#b-roberta-config) 21 | - [c. RobertaEmbeddings](#c-robertaembeddings) 22 | - [d. RobertaSelfAttention](#d-robertaselfattention) 23 | - [e. 
RobertaSelfOutput](#e-robertaselfoutput) 24 | - [f. RobertaIntermediate](#f-robertaintermediate) 25 | - [g. RobertaOutput](#g-robertaoutput) 26 | - [h. RobertaLayer](#h-robertalayer) 27 | - [i. RobertaEncoder](#i-robertaencoder) 28 | - [j. RobertaModel](#j-robertamodel) 29 | - [Debugging/Testing the model](#debugging-the-model) 30 | 31 | ## Getting Started: 32 | 33 | ### 0. Important things to note 34 | 35 | - When porting an already trained checkpoint to Candle, a lot of the PyTorch code is not relevant; it is mostly included to handle different training scenarios. It's definitely beneficial to know which functions to bypass if the conversion effort is mostly geared towards loading an already trained model. 36 | 37 | - Python built-in methods: Unlike Python, where built-in methods like `__call__` let us call an instance like a function and `__init__` initializes a class, in Rust we have to explicitly define methods like `Class::new()` to initialize a class and `Class::forward()` to perform a forward pass. This is going to be a recurrent theme in most of the code shown below. 38 | 39 | - It is important to write [unit tests](tests/test_roberta.rs) after writing most or every module to ensure that input and output shapes in Candle are consistent with the same module in PyTorch. 40 | 41 | - In PyTorch, we can initialize module weights by creating a class method `_init_weights`, but in Candle it becomes a design decision: you can initialize a tensor using the shape of your weights/bias and hold it in a `VarBuilder`, which is then used to initialize the tensors in each module. 42 | 43 | 44 | ### 1. Start a new rust project 45 | The command below will create a new rust project called `candle-roberta` in the current directory with a `Cargo.toml` file and a `src` directory with a `main.rs` file in it. 46 | 47 | ```bash 48 | $ cargo new candle-roberta 49 | ``` 50 | 51 | 52 | ### 2.
Install Candle & Other Packages 53 | 54 | You can follow the instructions [here](https://huggingface.github.io/candle/guide/installation.html) to install candle or you can use the command below to install candle directly from github. 55 | 56 | For this tutorial, we will be using the `candle-core` and `candle-nn` crates. 57 | `candle-core` provides the core functionality of the candle framework. It provides an implementation of the basic blocks for building neural networks and also integrations with different backends like Cuda, MKL, CPU, etc., while `candle-nn` provides a high level API for building neural networks. 58 | 59 | ```bash 60 | - cargo add --git https://github.com/huggingface/candle.git candle-core # install candle-core 61 | - cargo add --git https://github.com/huggingface/candle.git candle-nn # install candle-nn 62 | ``` 63 | 64 | Other crates we will need for this tutorial are: 65 | - `anyhow` for error handling ==> `cargo add anyhow` 66 | - `serde` for serialization ==> `cargo add serde` 67 | - `serde_json` for json serialization ==> `cargo add serde_json` 68 | - `hf-hub` for integrating with the huggingface hub ==> `cargo add hf-hub` 69 | - `tokenizers` for tokenizing text ==> `cargo add tokenizers` 70 | 71 | ## 3. Parallels between Pytorch and Candle 72 | 73 | To convert a pytorch model to candle, it is important to understand the parallels between the two frameworks. 74 | - Candle is a rust framework, so it is statically typed, while pytorch is a python framework, so it is dynamically typed. This means that you need to be explicit about the types of your variables in candle, while in pytorch, you don't need to be explicit about the types of your variables.
75 | 76 | ### Tensors 77 | 78 | The examples shown below can be found [here](examples/tensor.rs); 79 | 80 | - Initializing a Tensor: Tensors can be directly created from an array in both frameworks 81 | 82 | - Pytorch: in pytorch, the data type is automatically inferred from the data; 83 | 84 | ```python 85 | import torch 86 | from typing import List 87 | 88 | data: List = [1, 2, 3] 89 | tensor = torch.tensor(data) 90 | print(tensor) 91 | 92 | nested_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 93 | nested_tensor = torch.tensor(nested_data) 94 | print(nested_tensor) 95 | ``` 96 | - Candle: in candle, the data type needs to be explicitly specified; 97 | 98 | ```rust 99 | use candle_core::{DType, Device, Tensor}; 100 | use anyhow::Result; 101 | 102 | let data: [u32; 3] = [1u32, 2, 3]; 103 | let tensor = Tensor::new(&data, &Device::Cpu)?; 104 | println!("tensor: {:?}", tensor.to_vec1::()?); 105 | 106 | let nested_data: [[u32; 3]; 3] = [[1u32, 2, 3], [4, 5, 6], [7, 8, 9]]; 107 | let nested_tensor = Tensor::new(&nested_data, &Device::Cpu)?; 108 | println!("nested_tensor: {:?}", nested_tensor.to_vec2::()?); 109 | ``` 110 | 111 | - Creating a tensor from another tensor 112 | 113 | - Pytorch: in pytorch, the data type is automatically inferred from the data; 114 | 115 | ```python 116 | zero_tensor = torch.zeros_like(tensor) 117 | ones_tensor = torch.ones_like(tensor) 118 | random_tensor = torch.rand_like(tensor) 119 | ``` 120 | 121 | - Candle: in candle, the data type needs to be explicitly specified; 122 | 123 | ```rust 124 | let data: [u32; 3] = [1u32, 2, 3]; 125 | let tensor = Tensor::new(&data, &Device::Cpu)?; 126 | 127 | let zero_tensor = tensor.zeros_like()?; 128 | println!("zero_tensor: {:?}", zero_tensor.to_vec1::()?); 129 | 130 | let ones_tensor = tensor.ones_like()?; 131 | println!("ones_tensor: {:?}", ones_tensor.to_vec1::()?); 132 | 133 | let random_tensor = tensor.rand_like(0.0, 1.0)?; 134 | println!("random_tensor: {:?}", random_tensor.to_vec1::()?); 135 | ``` 136 | 137 | -
Checking tensor dimensions: 138 | - PyTorch 139 | ```python 140 | print(tensor.shape) 141 | print(tensor.size()) 142 | ``` 143 | - Candle 144 | ```rust 145 | // 1 dimensional tensor 146 | println!("tensor shape: {:?}", tensor.shape().dims()); 147 | // 2 dimensional tensor 148 | println!("tensor shape: {:?}", tensor.shape().dims2()); 149 | // 3 dimensional tensor 150 | println!("tensor shape: {:?}", tensor.shape().dims3()); 151 | ``` 152 | 153 | ### Tensor Operations: 154 | 155 | Performing tensor operations is pretty similar across both frameworks 156 | Some examples can be found here:: [Candle CheatSheet](https://github.com/huggingface/candle/blob/main/README.md#how-to-use) 157 | 158 | 159 | 160 | 161 | ## 3. Translating a PyTorch Transformer Model into Candle 162 | 163 | Here's the fun part! In this section we are going to take a look at translating models from the transformers library to candle. We would be using the [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html) and [XLM-Roberta](https://huggingface.co/docs/transformers/model_doc/xlm-roberta) model for this tutorial. 164 | 165 | We would be translating the [Pytorch Source Code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py) into Candle Code and then load the pretrained checkpoint into Rust and compare the output from both frameworks. 166 | 167 | Note ❗️❗️: To make the code easily understandable, I have annotated each line of the Rust/Candle code with the equivalent PyTorch code. 168 | 169 | ### 3.1. RoBERTa 170 | 171 | RoBERTa is a variant of the BERT model. Although both models have different pretraining approaches, structurally both models are very similar and the major difference between both models is that in the RoBERTa layer, Position numbers begin at padding_idx+1, While in BERT, Position numbers begin at 0. 
172 | 173 | Following the transformers PyTorch implementation, RoBERTa Model can be divided into the 2 main parts (embeddings and encoder): 174 | 175 | ``` 176 | RobertaModel( 177 | (embeddings): RobertaEmbeddings( 178 | (word_embeddings): Embedding(50265, 768, padding_idx=1) 179 | (position_embeddings): Embedding(514, 768, padding_idx=1) 180 | (token_type_embeddings): Embedding(1, 768) 181 | (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) 182 | (dropout): Dropout(p=0.1, inplace=False) 183 | ) 184 | (encoder): RobertaEncoder( 185 | (layer): ModuleList( 186 | (0-11): 12 x RobertaLayer( 187 | (attention): RobertaAttention( 188 | (self): RobertaSelfAttention( 189 | (query): Linear(in_features=768, out_features=768, bias=True) 190 | (key): Linear(in_features=768, out_features=768, bias=True) 191 | (value): Linear(in_features=768, out_features=768, bias=True) 192 | (dropout): Dropout(p=0.1, inplace=False) 193 | ) 194 | (output): RobertaSelfOutput( 195 | (dense): Linear(in_features=768, out_features=768, bias=True) 196 | (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) 197 | (dropout): Dropout(p=0.1, inplace=False) 198 | ) 199 | ) 200 | (intermediate): RobertaIntermediate( 201 | (dense): Linear(in_features=768, out_features=3072, bias=True) 202 | (intermediate_act_fn): GELUActivation() 203 | ) 204 | (output): RobertaOutput( 205 | (dense): Linear(in_features=3072, out_features=768, bias=True) 206 | (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) 207 | (dropout): Dropout(p=0.1, inplace=False) 208 | ) 209 | ) 210 | ) 211 | ) 212 | ) 213 | ``` 214 | 215 | - Roberta Config: For Holding Model Configuration 216 | - Roberta Model (RobertaModel) : This is the main model class that contains the embedding and the encoder module. 
217 | - Embedding (RobertaEmbeddings) 218 | - Encoder (RobertaEncoder) 219 | 220 | - Embedding (RobertaEmbeddings): The Embedding module is a combination of the following: 221 | - Word Embedding --> PyTorch Embedding Module 222 | - Position Embedding --> PyTorch Embedding Module 223 | - Token Type Embedding --> PyTorch Embedding Module 224 | - Layer Norm 225 | 226 | - Encoder (RobertaEncoder): The Encoder is just made up of a number of Attention Layers stacked on one another: 227 | - x number of RobertaLayers: This is a PyTorch ModuleList of RobertaLayer 228 | 229 | - Roberta Layer (RobertaLayer): The RobertaLayer is made up of the following modules: 230 | - Attention Block (RobertaAttention) -> PyTorch Module (made up of Self Attention Layer and Self Attention Output Layer) 231 | - Self Attention Layer (RobertaSelfAttention) 232 | - Self Attention Output Layer (RobertaSelfOutput) 233 | 234 | - Intermediate Layer (RobertaIntermediate) --> PyTorch Linear Module 235 | - Output Layer (RobertaOutput) --> PyTorch Linear Module 236 | 237 | Listed above are the main components of the model. Other building blocks for implementing the model include: 238 | 239 | - Layer Norm --> PyTorch LayerNorm Module 240 | - Dropout --> PyTorch Dropout Module 241 | - Activation --> PyTorch Activation Function 242 | 243 | 244 | ### Translating Pytorch Modules into Candle 245 | 246 | #### Import necessary Modules: 247 | 248 | Import the necessary modules from candle and other crates: 249 | 250 | - DType: This is an enum that represents the data type of a tensor. 251 | - Device: This is an enum that represents the device a tensor is stored on. 252 | - Result: This is a type alias for `std::result::Result` for error handling 253 | - Tensor: This is a struct that represents a tensor. 254 | 255 | - Embedding: This is a prebuilt struct that represents an embedding layer similar to `nn.Embedding`. 256 | - Module: This is a trait that represents a neural network module similar to `nn.Module` in PyTorch.
257 | - VarBuilder: Module builder for creating variables similar to `nn.Parameter` in PyTorch. 258 | 259 | 260 | ```rust 261 | use candle_core::{DType, Device, Result, Tensor}; 262 | use candle_nn::{Embedding, Module, VarBuilder}; 263 | use serde::Deserialize; 264 | ``` 265 | 266 | ### a. Writing Building Blocks: 267 | 268 | - Linear/Embedding: These are helper functions for loading the weights of a linear/embedding layer using `VarBuilder` from a checkpoint file. We create these 2 helper functions because we will use them multiple times: 269 | 270 | ```rust 271 | fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { 272 | let embeddings = vb.get((vocab_size, hidden_size), "weight")?; 273 | Ok(Embedding::new(embeddings, hidden_size)) 274 | } 275 | 276 | fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { 277 | let weight = vb.get((size2, size1), "weight")?; 278 | let bias = vb.get(size2, "bias")?; 279 | Ok(Linear::new(weight, Some(bias))) 280 | } 281 | ``` 282 | 283 | Both of these functions already exist in `candle_nn` and can be imported with `candle_nn::{embedding,linear}` 284 | 285 | - Layer Norm (https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html): Used to normalize a tensor over a given axis. It is used in the embedding layer and the encoder layer. A good explanation of layer normalization can be [found here](https://www.pinecone.io/learn/batch-layer-normalization/#What-is-Layer-Normalization). This is required because we need to implement the low-level layer norm module in Candle.
286 | 287 | ![image info](./assets/layer_norm.png) 288 | *Layer Normalization from https://www.pinecone.io/learn/batch-layer-normalization/#What-is-Layer-Normalization* 289 | 290 | 291 | - PyTorch: In PyTorch, we can use LayerNorm by calling it as a module 292 | 293 | ```python 294 | from torch import nn 295 | 296 | LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 297 | ``` 298 | 299 | - Candle: In candle we can implement the layer normalization using the equation above or import it directly from `candle_nn` with `candle_nn::{LayerNorm,layer_norm}` Steps: 300 | - Since normalization is done over the last axis which is the hidden size, we can use the `sum_keepdim` method to sum over the last axis and divide by dimension size to get `mean_x`. 301 | - For each element in the tensor, we subtract the mean and square the result and divide by hidden dimension to get `norm_x`. 302 | - To get the normalized input, we subtract the mean from the input and divide by the square root of `norm_x + eps`. 303 | - To get the final output, we multiply the normalized input by the weight of the normalization layer and add the bias. 304 | 305 | ```rust 306 | pub struct LayerNorm { 307 | weight: Tensor, // Weight vector of the LayerNorm Layer 308 | bias: Tensor, // Bias vector of the LayerNorm Layer 309 | eps: f64, // Epsilon value for numerical stability 310 | } 311 | 312 | impl LayerNorm { 313 | // Constructor for LayerNorm 314 | pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { 315 | Self { weight, bias, eps } 316 | } 317 | 318 | pub fn forward(&self, x: &Tensor) -> Result { 319 | let x_dtype = x.dtype(); // Get the data type of the input tensor 320 | let internal_dtype = match x_dtype { 321 | DType::F16 | DType::BF16 => DType::F32, 322 | d => d, 323 | }; 324 | let (_bsize, _seq_len, hidden_size) = x.dims3()?; // Get the dimensions of the input tensor 325 | let x = x.to_dtype(internal_dtype)?; 326 | let mean_x = (x.sum_keepdim(2)? 
/ hidden_size as f64)?; // Get the mean of the input tensor and divide by the hidden size 327 | let x = x.broadcast_sub(&mean_x)?; // Subtract the mean from the input tensor 328 | let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; // Get the squared norm of the input tensor and divide by the hidden size 329 | let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; // Get the normalized input 330 | let x = x_normed 331 | .to_dtype(x_dtype)? 332 | .broadcast_mul(&self.weight)? 333 | .broadcast_add(&self.bias)?; 334 | Ok(x) 335 | } 336 | } 337 | ``` 338 | 339 | This struct can be used as follows: 340 | 341 | ```rust 342 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 343 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 344 | 345 | // initialize a layer norm layer 346 | let layer_norm = LayerNorm::new(w_gen, b_gen, 1f64); 347 | 348 | let data: [u32; 3] = [1u32, 2, 3]; 349 | let input_tensor = Tensor::new(&data, &Device::Cpu)?; 350 | let normalized_tensor = layer_norm.forward(&input_tensor)?; 351 | ``` 352 | 353 | - Dropout: Randomly zero out different parts of the input tensor using a probability value. 
This is only used during training; since we are translating a pretrained model, we can just write a struct that returns the passed input tensor 354 | 355 | - PyTorch: In PyTorch, we can use Dropout by calling it as a module 356 | 357 | ```python 358 | from torch import nn 359 | 360 | dropout = nn.Dropout(p=0.1) 361 | input = torch.randn(2) 362 | output = dropout(input) 363 | ``` 364 | - Candle: In candle we can implement the dropout layer by just returning the input tensor 365 | 366 | ```rust 367 | struct Dropout { 368 | #[allow(dead_code)] 369 | pr: f64, 370 | } 371 | 372 | impl Dropout { 373 | fn new(pr: f64) -> Self { 374 | Self { pr } 375 | } 376 | 377 | fn forward(&self, x: &Tensor) -> Result<Tensor> { 378 | Ok(x.clone()) 379 | } 380 | } 381 | ``` 382 | 383 | This struct can be used as follows: 384 | 385 | ```rust 386 | let dropout = Dropout::new(0.1); 387 | 388 | let data: [u32; 3] = [1u32, 2, 3]; 389 | let input_tensor = Tensor::new(&data, &Device::Cpu)?; 390 | let dropout_tensor = dropout.forward(&input_tensor)?; 391 | ``` 392 | 393 | - Activation: RoBERTa uses a GELU activation function. We can implement the GELU using a similar approach as dropout above with no input params. Candle tensors have a built-in method to perform this operation 394 | 395 | - PyTorch: In PyTorch, we can use GELU by calling it as a module 396 | 397 | ```python 398 | from torch import nn 399 | 400 | activation = nn.GELU() 401 | input = torch.randn(2) 402 | output = activation(input) 403 | ``` 404 | 405 | - Candle: In candle we can implement the activation layer by calling the tensor's built-in `gelu` method 406 | 407 | ```rust 408 | struct Activation {} 409 | 410 | impl Activation { 411 | fn new() -> Self { 412 | Self {} 413 | } 414 | 415 | fn forward(&self, x: &Tensor) -> Result<Tensor> { 416 | Ok(x.gelu()?)
417 | } 418 | } 419 | ``` 420 | 421 | This struct can be used as follows: 422 | 423 | ```rust 424 | let activation = Activation::new(); 425 | 426 | let data: [u32; 3] = [1u32, 2, 3]; 427 | let input_tensor = Tensor::new(&data, &Device::Cpu)?; 428 | let activation_tensor = activation.forward(&input_tensor)?; 429 | ``` 430 | 431 | 432 | ### b. Roberta Config: 433 | 434 | Up next is the Roberta Config. This is a struct that holds the configuration of the model. It is similar to the [RobertaConfig](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/configuration_roberta.py#L37) in the transformers library. For this Struct, We will initialize the default values for the config (We implement the `Default` trait for the `RobertaConfig` struct ) and then use the serde crate to deserialize the config from a json file. Alternatively we can create a `RobertaConfig::new()` method for creating a new instance of RobertaConfig 435 | 436 | ```rust 437 | pub struct RobertaConfig { 438 | vocab_size: usize, 439 | hidden_size: usize, 440 | num_hidden_layers: usize, 441 | num_attention_heads: usize, 442 | intermediate_size: usize, 443 | hidden_act: String, 444 | hidden_dropout_prob: f64, 445 | max_position_embeddings: usize, 446 | type_vocab_size: usize, 447 | initializer_range: f64, 448 | layer_norm_eps: f64, 449 | pad_token_id: usize, 450 | bos_token_id: usize, 451 | eos_token_id: usize, 452 | position_embedding_type: String, 453 | use_cache: bool, 454 | classifier_dropout: Option, 455 | model_type: Option, 456 | } 457 | 458 | impl Default for RobertaConfig { 459 | fn default() -> Self { 460 | Self { 461 | vocab_size: 50265, 462 | hidden_size: 768, 463 | num_hidden_layers: 12, 464 | num_attention_heads: 12, 465 | intermediate_size: 3072, 466 | hidden_act: "gelu".to_string(), 467 | hidden_dropout_prob: 0.1, 468 | max_position_embeddings: 512, 469 | type_vocab_size: 2, 470 | initializer_range: 0.02, 471 | 
layer_norm_eps: 1e-12, 472 | pad_token_id: 1, 473 | bos_token_id: 0, 474 | eos_token_id: 2, 475 | position_embedding_type: PositionEmbeddingType::Absolute, 476 | use_cache: true, 477 | classifier_dropout: None, 478 | model_type: Some("roberta".to_string()), 479 | } 480 | } 481 | } 482 | ``` 483 | 484 | ### c. RobertaEmbeddings: 485 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L65) 486 | 487 | In the `__init__` function of the embedding class, we have 3 linear layers for processing word_embeddings, position_embeddings and token_type_ids. Similar to the PyTorch implementation, there are two important class methods that we need to implement. 488 | 489 | - [create_position_ids_from_input_embeds](https://github.com/huggingface/transformers/blob/46092f763d26eb938a937c2a9cc69ce1cb6c44c2/src/transformers/models/roberta/modeling_roberta.py#L136): A function to generate position ids from embeddings. I have included the pytorch equivalent of each line as a comment. 490 | 491 | ```rust 492 | pub fn create_position_ids_from_input_embeds(&self, input_embeds: &Tensor) -> Result { 493 | // input_shape = inputs_embeds.size() 494 | // In candle, we use dims3() for getting the size of a 3 dimensional tensor 495 | let input_shape = input_embeds.dims3()?; 496 | // sequence_length = input_shape[1] 497 | let seq_length = input_shape.1; 498 | 499 | // position_ids = torch.arange( self.padding_idx + 1, sequence_length + self.padding_idx + 1, \ 500 | // dtype=torch.long, device=inputs_embeds.device) 501 | let mut position_ids = Tensor::arange( 502 | self.padding_idx + 1, 503 | seq_length as u32 + self.padding_idx + 1, 504 | &Device::Cpu, 505 | )?; 506 | 507 | // return position_ids.unsqueeze(0).expand(input_shape) 508 | position_ids = position_ids 509 | .unsqueeze(0)? 
510 | .expand((input_shape.0, input_shape.1))?; 511 | Ok(position_ids) 512 | } 513 | ``` 514 | - [create_position_ids_from_input_ids](https://github.com/huggingface/transformers/blob/46092f763d26eb938a937c2a9cc69ce1cb6c44c2/src/transformers/models/roberta/modeling_roberta.py#L1558): A function to generate position_ids from input_ids. 515 | 516 | ```rust 517 | pub fn create_position_ids_from_input_ids(input_ids: &Tensor, padding_idx: u32, past_key_values_length: u8) -> Result { 518 | // mask = input_ids.ne(padding_idx).int() 519 | let mask = input_ids.ne(padding_idx)?; 520 | // incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 521 | let incremental_indices = cumsum_2d(&mask, 0, input_ids.device())?; 522 | 523 | // incremental_indices.long() + padding_idx 524 | let incremental_indices = incremental_indices.broadcast_add(&Tensor::new(&[past_key_values_length], input_ids.device())?)?; 525 | 526 | Ok(incremental_indices) 527 | } 528 | ``` 529 | 530 | - [Embedding Layer] : The embedding layer is made up of 3 linear layers for processing word_embeddings, position_embeddings and token_type_ids. The output of the embedding layer is the sum of the word_embeddings, position_embeddings and token_type_embeddings. The output is then passed through a layer norm and dropout layer. A link to the pytorch implementation is shown above. 
531 | 532 | ```rust 533 | pub struct RobertaEmbeddings { 534 | word_embeddings: Embedding, 535 | position_embeddings: Option, 536 | token_type_embeddings: Embedding, 537 | layer_norm: LayerNorm, 538 | dropout: Dropout, 539 | pub padding_idx: u32, 540 | } 541 | 542 | impl RobertaEmbeddings { 543 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 544 | 545 | // nn.Embedding(config.vocab_size, config.hidden_size) 546 | let word_embeddings = embedding( 547 | config.vocab_size, 548 | config.hidden_size, 549 | vb.pp("word_embeddings"), 550 | )?; 551 | 552 | // nn.Embedding(config.max_position_embeddings, config.hidden_size) 553 | let position_embeddings = embedding( 554 | config.max_position_embeddings, 555 | config.hidden_size, 556 | vb.pp("position_embeddings"), 557 | )?; 558 | 559 | // nn.Embedding(config.type_vocab_size, config.hidden_size) 560 | let token_type_embeddings = embedding( 561 | config.type_vocab_size, 562 | config.hidden_size, 563 | vb.pp("token_type_embeddings"), 564 | )?; 565 | 566 | // nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 567 | let layer_norm = layer_norm( 568 | config.hidden_size, 569 | config.layer_norm_eps, 570 | vb.pp("LayerNorm"), 571 | )?; 572 | 573 | // nn.Dropout(config.hidden_dropout_prob) 574 | let dropout = Dropout::new(config.hidden_dropout_prob); 575 | 576 | let padding_idx = config.pad_token_id as u32; 577 | 578 | Ok(Self { 579 | word_embeddings, 580 | position_embeddings: Some(position_embeddings), 581 | token_type_embeddings, 582 | layer_norm, 583 | dropout, 584 | padding_idx, 585 | }) 586 | } 587 | 588 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor, position_ids: Option<&Tensor>, inputs_embeds: Option<&Tensor>) -> Result { 589 | 590 | let position_ids = match position_ids { 591 | Some(ids) => ids.to_owned(), 592 | None => { 593 | if Option::is_some(&inputs_embeds){ 594 | // self.create_position_ids_from_inputs_embeds(inputs_embeds) 595 | let position_ids = 
self.create_position_ids_from_input_embeds(inputs_embeds.unwrap())?; // 596 | position_ids 597 | } else { 598 | // create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) 599 | let position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, 1)?; 600 | position_ids 601 | } 602 | } 603 | }; 604 | 605 | 606 | let inputs_embeds : Tensor = match inputs_embeds { 607 | Some(embeds) => embeds.to_owned(), 608 | None => { 609 | // self.word_embeddings(input_ids) 610 | let embeds = self.word_embeddings.forward(input_ids)?; 611 | embeds 612 | } 613 | }; 614 | 615 | // self.token_type_embeddings(token_type_ids) 616 | let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; 617 | // inputs_embeds + token_type_embeddings 618 | let mut embeddings = (inputs_embeds + token_type_embeddings)?; 619 | 620 | if let Some(position_embeddings) = &self.position_embeddings { 621 | // embeddings + self.position_embeddings(position_ids) 622 | embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? 623 | } 624 | 625 | // self.LayerNorm(embeddings) 626 | let embeddings = self.layer_norm.forward(&embeddings)?; 627 | // self.dropout(embeddings) 628 | let embeddings = self.dropout.forward(&embeddings)?; 629 | 630 | Ok(embeddings) 631 | 632 | } 633 | } 634 | ``` 635 | 636 | ### d. RobertaSelfAttention: 637 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L155). The self attention layer is made up of 3 linear layers for processing the query, key and value. The output of the self attention layer is the dot product of the query and key. The output is then passed through a softmax layer and a dropout layer which is then multiplied by the value. 
638 | 639 | ```rust 640 | 641 | ```rust 642 | struct RobertaSelfAttention { 643 | query: Linear, 644 | key: Linear, 645 | value: Linear, 646 | dropout: Dropout, 647 | num_attention_heads: usize, 648 | attention_head_size: usize, 649 | } 650 | 651 | impl RobertaSelfAttention { 652 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 653 | // config.hidden_size / config.num_attention_heads 654 | let attention_head_size = config.hidden_size / config.num_attention_heads; 655 | // self.num_attention_heads * self.attention_head_size 656 | let all_head_size = config.num_attention_heads * attention_head_size; 657 | // nn.Dropout(config.attention_probs_dropout_prob) 658 | let dropout = Dropout::new(config.hidden_dropout_prob); 659 | let hidden_size = config.hidden_size; 660 | 661 | // nn.Linear(config.hidden_size, self.all_head_size) 662 | let query = linear(hidden_size, all_head_size, vb.pp("query"))?; 663 | // nn.Linear(config.hidden_size, self.all_head_size) 664 | let value = linear(hidden_size, all_head_size, vb.pp("value"))?; 665 | // nn.Linear(config.hidden_size, self.all_head_size) 666 | let key = linear(hidden_size, all_head_size, vb.pp("key"))?; 667 | Ok(Self { 668 | query, 669 | key, 670 | value, 671 | dropout, 672 | num_attention_heads: config.num_attention_heads, 673 | attention_head_size, 674 | }) 675 | } 676 | 677 | fn transpose_for_scores(&self, xs: &Tensor) -> Result { 678 | 679 | // x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 680 | let mut new_x_shape = xs.dims().to_vec(); 681 | new_x_shape.pop(); 682 | new_x_shape.push(self.num_attention_heads); 683 | new_x_shape.push(self.attention_head_size); 684 | 685 | // x = x.view(new_x_shape) || x.permute(0, 2, 1, 3) 686 | let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; 687 | xs.contiguous() 688 | } 689 | 690 | fn forward(&self, hidden_states: &Tensor) -> Result { 691 | // self.query(hidden_states) 692 | let query_layer = self.query.forward(hidden_states)?; 693 | // 
self.key(hidden_states) 694 | let key_layer = self.key.forward(hidden_states)?; 695 | // self.value(hidden_states) 696 | let value_layer = self.value.forward(hidden_states)?; 697 | 698 | // self.transpose_for_scores(query_layer) 699 | let query_layer = self.transpose_for_scores(&query_layer)?; 700 | // self.transpose_for_scores(key_layer) 701 | let key_layer = self.transpose_for_scores(&key_layer)?; 702 | // self.transpose_for_scores(value_layer) 703 | let value_layer = self.transpose_for_scores(&value_layer)?; 704 | 705 | // attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 706 | let attention_scores = query_layer.matmul(&key_layer.t()?)?; 707 | // attention_scores / math.sqrt(self.attention_head_size) 708 | let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; 709 | // attention_probs = nn.functional.softmax(attention_scores, dim=-1) 710 | let attention_probs = {candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?}; 711 | // attention_probs = self.dropout(attention_probs) 712 | let attention_probs = self.dropout.forward(&attention_probs)?; 713 | 714 | // torch.matmul(attention_probs, value_layer) 715 | let context_layer = attention_probs.matmul(&value_layer)?; 716 | // context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 717 | let context_layer = context_layer.transpose(1, 2)?.contiguous()?; 718 | 719 | // new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 720 | // context_layer = context_layer.view(new_context_layer_shape) 721 | let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?; // 722 | Ok(context_layer) 723 | } 724 | } 725 | ``` 726 | 727 | ### e. RobertaSelfOutput: 728 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L290). 
The output of the Self Attention Layer is passed through the Self Output layer which is made up of a linear layer, layer norm and dropout layer. 729 | 730 | ```rust 731 | struct RobertaSelfOutput { 732 | dense: Linear, 733 | layer_norm: LayerNorm, 734 | dropout: Dropout, 735 | } 736 | 737 | impl RobertaSelfOutput { 738 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 739 | // nn.Linear(config.hidden_size, config.hidden_size) 740 | let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; 741 | // nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 742 | let layer_norm = layer_norm( 743 | config.hidden_size, 744 | config.layer_norm_eps, 745 | vb.pp("LayerNorm"), 746 | )?; 747 | 748 | // nn.Dropout(config.hidden_dropout_prob) 749 | let dropout = Dropout::new(config.hidden_dropout_prob); 750 | Ok(Self { 751 | dense, 752 | layer_norm, 753 | dropout, 754 | }) 755 | } 756 | 757 | fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { 758 | // self.dense(hidden_states) 759 | let hidden_states = self.dense.forward(hidden_states)?; 760 | // self.dropout(hidden_states) 761 | let hidden_states = self.dropout.forward(&hidden_states)?; 762 | // self.LayerNorm(hidden_states + input_tensor) 763 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 764 | } 765 | } 766 | ``` 767 | 768 | ### f. RobertaAttention: 769 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L305). The Roberta Attention Layer is made up of the Self Attention Layer and the Self Output Layer implemented earlier. The output of the Self Attention Layer is passed through the Self Output Layer. 
770 | 771 | ```rust 772 | struct RobertaAttention { 773 | self_attention: RobertaSelfAttention, 774 | self_output: RobertaSelfOutput, 775 | } 776 | 777 | impl RobertaAttention { 778 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result<Self> { 779 | // RobertaSelfAttention(config, position_embedding_type=position_embedding_type) 780 | let self_attention = RobertaSelfAttention::load(vb.pp("self"), config)?; 781 | // RobertaSelfOutput(config) 782 | let self_output = RobertaSelfOutput::load(vb.pp("output"), config)?; 783 | 784 | Ok(Self { 785 | self_attention, 786 | self_output, 787 | }) 788 | } 789 | /// Runs self-attention on `hidden_states`, then the self-output layer with `hidden_states` as its second (input) tensor. 790 | fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { 791 | // self_outputs = self.self(hidden_states) 792 | let self_outputs = self.self_attention.forward(hidden_states)?; 793 | // attention_output = self.output(self_outputs[0], hidden_states) 794 | let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; 795 | 796 | Ok(attention_output) 797 | } 798 | } 799 | ``` 800 | 801 | ### g. RobertaIntermediate 802 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L355). The intermediate layer is made up of a linear layer and an activation function. Here we use the GELU activation function. This layer, combined with the Attention Layer and an Output layer, makes up each layer of the Encoder.
803 | 804 | ```rust 805 | struct RobertaIntermediate { 806 | dense: Linear, 807 | intermediate_act: HiddenActLayer, 808 | } 809 | 810 | impl RobertaIntermediate { 811 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result<Self> { 812 | // nn.Linear(config.hidden_size, config.intermediate_size) 813 | let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; 814 | Ok(Self { 815 | dense, 816 | intermediate_act: Activation::new(), // NOTE(review): the field is declared as `HiddenActLayer`; confirm `Activation::new()` yields that type 817 | }) 818 | } 819 | /// Applies the dense projection to `hidden_states`, then the activation layer. 820 | fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { 821 | // self.dense(hidden_states) 822 | let hidden_states = self.dense.forward(hidden_states)?; 823 | // self.intermediate_act_fn(hidden_states) 824 | let ys = self.intermediate_act.forward(&hidden_states)?; 825 | Ok(ys) 826 | } 827 | } 828 | ``` 829 | 830 | ### h. RobertaOutput 831 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L371). The output layer is made up of a linear layer, a dropout layer and a layer norm. This layer, combined with the Attention Layer and an Intermediate layer, makes up each layer of the Encoder.
832 | 833 | ```rust 834 | struct RobertaOutput { 835 | dense: Linear, 836 | layer_norm: LayerNorm, 837 | dropout: Dropout, 838 | } 839 | 840 | impl RobertaOutput { 841 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result<Self> { 842 | // nn.Linear(config.intermediate_size, config.hidden_size) 843 | let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; 844 | // nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 845 | let layer_norm = layer_norm( 846 | config.hidden_size, 847 | config.layer_norm_eps, 848 | vb.pp("LayerNorm"), 849 | )?; 850 | let dropout = Dropout::new(config.hidden_dropout_prob); 851 | Ok(Self { 852 | dense, 853 | layer_norm, 854 | dropout, 855 | }) 856 | } 857 | /// Applies dense -> dropout to `hidden_states`, adds `input_tensor` (residual) and layer-normalizes the sum. 858 | fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { 859 | // self.dense(hidden_states) 860 | let hidden_states = self.dense.forward(hidden_states)?; 861 | // self.dropout(hidden_states) 862 | let hidden_states = self.dropout.forward(&hidden_states)?; 863 | // self.LayerNorm(hidden_states + input_tensor) 864 | self.layer_norm.forward(&(hidden_states + input_tensor)?) 865 | } 866 | } 867 | ``` 868 | 869 | ### i. RobertaLayer 870 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L386): This does not include an implementation of cross-attention as in the PyTorch code. As mentioned in the previous sections, the RobertaLayer is made up of an Attention Layer, an Intermediate Layer and an Output Layer. A stack of these layers makes up the Encoder.
871 | 872 | ```rust 873 | struct RobertaLayer { 874 | attention: RobertaAttention, 875 | intermediate: RobertaIntermediate, 876 | output: RobertaOutput, 877 | } 878 | 879 | impl RobertaLayer { 880 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result<Self> { 881 | // RobertaAttention(config) 882 | let attention = RobertaAttention::load(vb.pp("attention"), config)?; 883 | // RobertaIntermediate(config) 884 | let intermediate = RobertaIntermediate::load(vb.pp("intermediate"), config)?; 885 | // RobertaOutput(config) 886 | let output = RobertaOutput::load(vb.pp("output"), config)?; 887 | Ok(Self { 888 | attention, 889 | intermediate, 890 | output, 891 | }) 892 | } 893 | /// Runs attention -> intermediate -> output; the attention output is also passed to the output layer as its second (input) tensor. 894 | fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { 895 | // self.attention(hidden_states) 896 | let attention_output = self.attention.forward(hidden_states)?; 897 | 898 | // self.intermediate(attention_output) 899 | let intermediate_output = self.intermediate.forward(&attention_output)?; 900 | // self.output(intermediate_output, attention_output) 901 | let layer_output = self 902 | .output 903 | .forward(&intermediate_output, &attention_output)?; 904 | Ok(layer_output) 905 | } 906 | } 907 | ``` 908 | 909 | ### j. RobertaEncoder 910 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L473). The Encoder is made up of a stack of RobertaLayers. The output of the Encoder is the output of the last RobertaLayer.
911 | The `RobertaEncoder` struct (its definition is not part of the snippet below) has a single field, `layers`, holding the loaded `RobertaLayer`s. 912 | ```rust 913 | impl RobertaEncoder { 914 | fn load(vb: VarBuilder, config: &RobertaConfig) -> Result<Self> { 915 | // nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) 916 | let layers = (0..config.num_hidden_layers) 917 | .map(|index| RobertaLayer::load(vb.pp(&format!("layer.{index}")), config)) 918 | .collect::<Result<Vec<_>>>()?; 919 | Ok(RobertaEncoder { layers }) 920 | } 921 | /// Feeds `hidden_states` through every layer in order and returns the output of the last one. 922 | fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { 923 | let mut hidden_states = hidden_states.clone(); 924 | 925 | //for i, layer_module in enumerate(self.layer): 926 | // layer_outputs = layer_module(hidden_states) 927 | 928 | for layer in &self.layers { 929 | hidden_states = layer.forward(&hidden_states)?; 930 | } 931 | Ok(hidden_states) 932 | } 933 | } 934 | ``` 935 | 936 | ### k. RobertaModel 937 | [HuggingFace PyTorch Implementation](https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/src/transformers/models/roberta/modeling_roberta.py#L691). Voilà! We have implemented all the components of the Roberta Model. The Roberta Model is made up of an Embedding Layer and an Encoder. The output of the Roberta Model is the output of the Encoder.
938 | 939 | ```rust 940 | pub struct RobertaModel { 941 | embeddings: RobertaEmbeddings, 942 | encoder: RobertaEncoder, 943 | pub device: Device, 944 | } 945 | 946 | impl RobertaModel { 947 | pub fn load(vb: VarBuilder, config: &RobertaConfig) -> Result { 948 | let (embeddings, encoder) = match ( 949 | RobertaEmbeddings::load(vb.pp("embeddings"), config), // RobertaEmbeddings(config) 950 | RobertaEncoder::load(vb.pp("encoder"), config), // RobertaEncoder(config) 951 | ) { 952 | (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), 953 | (Err(err), _) | (_, Err(err)) => { 954 | if let Some(model_type) = &config.model_type { 955 | if let (Ok(embeddings), Ok(encoder)) = ( 956 | RobertaEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), 957 | RobertaEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), 958 | ) { 959 | (embeddings, encoder) 960 | } else { 961 | return Err(err); 962 | } 963 | } else { 964 | return Err(err); 965 | } 966 | } 967 | }; 968 | Ok(Self { 969 | embeddings, 970 | encoder, 971 | device: vb.device().clone(), 972 | }) 973 | } 974 | 975 | pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { 976 | // self.embedding(input_ids=input_ids) 977 | let embedding_output = self.embeddings.forward(input_ids, token_type_ids, None, None)?; 978 | // self.encoder(embedding_output ) 979 | let sequence_output = self.encoder.forward(&embedding_output)?; 980 | Ok(sequence_output) 981 | } 982 | 983 | } 984 | ``` 985 | 986 | 987 | ### Debugging the Model 988 | 989 | #### Unit Tests for Different Components 990 | It is important to write unit tests for the different components of the model. This is to ensure that the model is working as expected. Unit tests sometime appear to be time-consuming but they can be very important in the long run. 
Here are some unit tests I wrote during the porting process: 991 | 992 | ```rust 993 | // Regression_test = https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py#L496 994 | #[test] 995 | fn test_create_position_ids_from_input_embeds() -> Result<()> { 996 | 997 | let config = RobertaConfig::default(); 998 | let vb = VarBuilder::zeros(DType::F32, &Device::Cpu); 999 | let embeddings_module = RobertaEmbeddings::load(vb, &config).unwrap(); 1000 | 1001 | let input_embeds = Tensor::randn(0f32, 1f32, (2, 4, 30), &Device::Cpu).unwrap(); 1002 | let position_ids = embeddings_module.create_position_ids_from_input_embeds(&input_embeds); 1003 | // Expected position ids start at padding_idx + 1 for every row. 1004 | let expected_tensor: &[[u32; 4]; 2] = &[ 1005 | [0 + embeddings_module.padding_idx + 1, 1 + embeddings_module.padding_idx + 1, 2 + embeddings_module.padding_idx + 1, 3 + embeddings_module.padding_idx + 1,], 1006 | [0 + embeddings_module.padding_idx + 1, 1 + embeddings_module.padding_idx + 1, 2 + embeddings_module.padding_idx + 1, 3 + embeddings_module.padding_idx + 1,] 1007 | ]; 1008 | 1009 | assert_eq!(position_ids.unwrap().to_vec2::<u32>()?, expected_tensor); 1010 | 1011 | Ok(()) 1012 | 1013 | } 1014 | ``` 1015 | 1016 | - Testing the Model :: [Full Test Code](tests/test_roberta.rs) 1017 | ```rust 1018 | // https://github.com/huggingface/transformers/blob/e1cec43415e72c9853288d4e9325b734d36dd617/tests/models/roberta/test_modeling_roberta.py#L548 1019 | #[test] 1020 | fn test_modeling_roberta_base() -> Result<()> { 1021 | // model = RobertaModel.from_pretrained("roberta-base") 1022 | let (model, _) = build_roberta_model_and_tokenizer("roberta-base", false).unwrap(); 1023 | 1024 | // input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) 1025 | let input_ids = &[[0u32, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]; 1026 | let input_ids = Tensor::new(input_ids, &model.device).unwrap(); 1027 | 1028 | 
let token_ids = input_ids.zeros_like().unwrap(); 1029 | let output = model.forward(&input_ids, &token_ids)?; 1030 | 1031 | let expected_shape = [1, 11, 768]; 1032 | assert_eq!(output.shape().dims(), &expected_shape); 1033 | 1034 | // expected_slice = torch.tensor([[[-0.0231, 0.0782, 0.0074], [-0.1854, 0.0540, -0.0175], [0.0548, 0.0799, 0.1687]]]) 1035 | let expected_output = [[-0.0231, 0.0782, 0.0074], [-0.1854, 0.0540, -0.0175], [0.0548, 0.0799, 0.1687]]; 1036 | 1037 | // self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) 1038 | let output = output.squeeze(0)?; 1039 | let output = output.to_vec2::<f32>()?; 1040 | let output: Vec<Vec<f32>> = output.iter().take(3).map(|nested_vec| nested_vec.iter().take(3).map(|&x| round_to_decimal_places(x, 4)).collect()).collect(); 1041 | assert_eq!(output, expected_output); 1042 | 1043 | Ok(()) 1044 | 1045 | } 1046 | ``` 1047 | --------------------------------------------------------------------------------