├── syntaxdot-encoders ├── src │ ├── lang │ │ ├── de │ │ │ ├── mod.rs │ │ │ └── tdz │ │ │ │ ├── mod.rs │ │ │ │ └── lemma │ │ │ │ ├── error.rs │ │ │ │ ├── constants.rs │ │ │ │ ├── automaton.rs │ │ │ │ └── transform │ │ │ │ ├── mod.rs │ │ │ │ ├── delemmatization.rs │ │ │ │ ├── svp.rs │ │ │ │ ├── test_helpers.rs │ │ │ │ └── named_entity.rs │ │ └── mod.rs │ ├── dependency │ │ └── mod.rs │ ├── categorical │ │ ├── mod.rs │ │ └── number.rs │ ├── layer │ │ └── error.rs │ ├── lemma │ │ └── mod.rs │ ├── lib.rs │ └── depseq │ │ ├── mod.rs │ │ └── error.rs ├── testdata │ └── lang │ │ └── de │ │ └── tdz │ │ └── lemma │ │ ├── remove-sep-verb-prefix.test │ │ ├── remove-trunc-marker.test │ │ ├── form-as-lemma.test │ │ ├── simplify-article-lemma.test │ │ ├── add-separated-verb-prefix.test │ │ ├── simplify-personal-pronoun.test │ │ ├── restore-case.test │ │ ├── mark-verb-prefix.test │ │ ├── simplify-possesive-pronoun-lemma.test │ │ ├── simplify-piat-lemma.test │ │ ├── simplify-pidat-lemma.test │ │ └── simplify-pis-lemma.test ├── Cargo.toml └── data │ └── lang │ └── de │ └── tdz │ └── tdz11-separable-prefixes.txt ├── syntaxdot-summary ├── src │ ├── crc32_table.rs │ ├── lib.rs │ ├── record_writer.rs │ ├── event_writer.rs │ └── summary_writer.rs ├── Cargo.toml └── build.rs ├── .gitignore ├── syntaxdot ├── src │ ├── model │ │ └── mod.rs │ ├── encoders │ │ ├── mod.rs │ │ └── config.rs │ ├── lib.rs │ ├── optimizers │ │ ├── grad.rs │ │ ├── mod.rs │ │ └── grad_scale.rs │ ├── error.rs │ ├── dataset │ │ ├── mod.rs │ │ ├── conll.rs │ │ ├── plaintext.rs │ │ └── sentence_itertools.rs │ └── util.rs ├── testdata │ └── sticker.conf └── Cargo.toml ├── release.toml ├── Cargo.toml ├── syntaxdot-cli ├── src │ ├── summary │ │ ├── noop.rs │ │ ├── tensorboard.rs │ │ └── mod.rs │ ├── subcommands │ │ ├── mod.rs │ │ ├── filter_len.rs │ │ └── prepare.rs │ ├── util.rs │ ├── traits.rs │ ├── main.rs │ ├── progress.rs │ └── save.rs └── Cargo.toml ├── syntaxdot-transformers ├── src │ ├── models │ │ ├── traits.rs │ │ ├── albert │ │ │ ├── mod.rs │ │ │ ├── config.rs │ │ │ └── embeddings.rs │ │ ├── bert │ │ │ ├── mod.rs │ │ │ └── config.rs │ │ ├── mod.rs │ │ ├── squeeze_bert │ │ │ ├── mod.rs │ │ │ ├── config.rs │ │ │ └── embeddings.rs │ │ ├── encoder.rs │ │ └── layer_output.rs │ ├── cow.rs │ ├── lib.rs │ ├── error.rs │ ├── module.rs │ └── activations.rs └── Cargo.toml ├── scripts ├── tensor_module.py ├── update-syntaxdot-0.2-model.py ├── pytorch-squeezebert-to-syntaxdot.py ├── pytorch-roberta-to-syntaxdot.py ├── test-all.sh ├── pytorch-bert-to-syntaxdot.py └── plot-layer-weights ├── COPYRIGHT.md ├── syntaxdot-tokenizers ├── Cargo.toml └── src │ ├── error.rs │ ├── lib.rs │ ├── bert.rs │ ├── albert.rs │ └── xlm_roberta.rs ├── syntaxdot-tch-ext ├── Cargo.toml └── src │ ├── tensor.rs │ └── lib.rs ├── LICENSE-MIT ├── .github └── workflows │ ├── rust.yml │ └── release.yml ├── doc ├── models.md └── install.md └── README.md /syntaxdot-encoders/src/lang/de/mod.rs: -------------------------------------------------------------------------------- 1 | //! German-specific encoders/decoders. 2 | 3 | pub mod tdz; 4 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/mod.rs: -------------------------------------------------------------------------------- 1 | //! Language-specific encoders/decoders. 
2 | 3 | pub mod de; 4 | -------------------------------------------------------------------------------- /syntaxdot-summary/src/crc32_table.rs: -------------------------------------------------------------------------------- 1 | // Generated by build.rs. 2 | include!(concat!(env!("OUT_DIR"), "/crc32_table.rs")); 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Editor files 2 | .* 3 | *~ 4 | 5 | # Rust files 6 | /target 7 | **/*.rs.bk 8 | 9 | # Nix output 10 | result* 11 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/mod.rs: -------------------------------------------------------------------------------- 1 | //! TüBa-D/Z-specific encoders/decoders. 2 | 3 | mod lemma; 4 | 5 | pub use lemma::TdzLemmaEncoder; 6 | -------------------------------------------------------------------------------- /syntaxdot/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod biaffine_dependency_layer; 2 | 3 | pub mod bert; 4 | 5 | pub(crate) mod pooling; 6 | 7 | pub mod seq_classifiers; 8 | -------------------------------------------------------------------------------- /release.toml: -------------------------------------------------------------------------------- 1 | consolidate-commits = true 2 | shared-version = true 3 | pre-release-commit-message = "Bump version to {{version}}" 4 | tag-message = "" 5 | tag-prefix = "" 6 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/dependency/mod.rs: -------------------------------------------------------------------------------- 1 | //! Dependency encoding/decoding for biaffine parsing. 2 | 3 | mod encoder; 4 | pub use encoder::{ 5 | DependencyEncoding, EncodeError, ImmutableDependencyEncoder, MutableDependencyEncoder, 6 | }; 7 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "syntaxdot-encoders", 4 | "syntaxdot-tokenizers", 5 | "syntaxdot-summary", 6 | "syntaxdot-transformers", 7 | "syntaxdot", 8 | "syntaxdot-cli", 9 | "syntaxdot-tch-ext" 10 | ] 11 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/categorical/mod.rs: -------------------------------------------------------------------------------- 1 | //! Categorical variable encoder 2 | 3 | mod encoder; 4 | pub use encoder::{CategoricalEncoder, ImmutableCategoricalEncoder, MutableCategoricalEncoder}; 5 | 6 | mod number; 7 | pub use number::{ImmutableNumberer, MutableNumberer, Number}; 8 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/layer/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Clone, Debug, Eq, Error, PartialEq)] 4 | pub enum EncodeError { 5 | /// The token does not have a label. 
6 | #[error("token without a label: '{form:?}'")] 7 | MissingLabel { form: String }, 8 | } 9 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/summary/noop.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::summary::ScalarWriter; 4 | 5 | pub struct NoopWriter; 6 | 7 | impl ScalarWriter for NoopWriter { 8 | fn write_scalar(&self, _tag: &str, _step: i64, _value: f32) -> Result<()> { 9 | Ok(()) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/traits.rs: -------------------------------------------------------------------------------- 1 | pub trait WordEmbeddingsConfig { 2 | fn dims(&self) -> i64; 3 | 4 | fn dropout(&self) -> f64; 5 | 6 | fn initializer_range(&self) -> f64; 7 | 8 | fn layer_norm_eps(&self) -> f64; 9 | 10 | fn vocab_size(&self) -> i64; 11 | } 12 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/error.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use thiserror::Error; 4 | 5 | #[derive(Debug, Error)] 6 | pub enum LemmatizationError { 7 | #[error(transparent)] 8 | IO(#[from] io::Error), 9 | 10 | #[error(transparent)] 11 | Fst(#[from] fst::Error), 12 | } 13 | -------------------------------------------------------------------------------- /scripts/tensor_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class TensorModule(nn.Module): 4 | def __init__(self, tensors): 5 | super(TensorModule, self).__init__() 6 | 7 | for tensor_name, tensor in tensors.items(): 8 | setattr(self, tensor_name, nn.Parameter(tensor)) 9 | 10 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/albert/mod.rs: -------------------------------------------------------------------------------- 1 | //! ALBERT (Lan et al., 2020) 2 | 3 | mod config; 4 | pub use config::AlbertConfig; 5 | 6 | mod embeddings; 7 | pub(crate) use embeddings::AlbertEmbeddingProjection; 8 | pub use embeddings::AlbertEmbeddings; 9 | 10 | mod encoder; 11 | pub use encoder::AlbertEncoder; 12 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/subcommands/mod.rs: -------------------------------------------------------------------------------- 1 | mod annotate; 2 | pub use annotate::AnnotateApp; 3 | 4 | mod distill; 5 | pub use distill::DistillApp; 6 | 7 | mod filter_len; 8 | pub use filter_len::FilterLenApp; 9 | 10 | mod finetune; 11 | pub use finetune::FinetuneApp; 12 | 13 | mod prepare; 14 | pub use prepare::PrepareApp; 15 | -------------------------------------------------------------------------------- /syntaxdot/src/encoders/mod.rs: -------------------------------------------------------------------------------- 1 | //! Encoder configuration and construction. 
2 | 3 | mod config; 4 | pub use config::{DependencyEncoder, EncoderType, EncodersConfig, NamedEncoderConfig}; 5 | 6 | #[allow(clippy::module_inception)] 7 | mod encoders; 8 | pub use encoders::{DecoderError, Encoder, EncoderError, Encoders, NamedEncoder}; 9 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/bert/mod.rs: -------------------------------------------------------------------------------- 1 | //! BERT (Devlin et al., 2018) 2 | 3 | mod config; 4 | pub use config::BertConfig; 5 | 6 | mod embeddings; 7 | pub use embeddings::BertEmbeddings; 8 | 9 | mod encoder; 10 | pub use encoder::BertEncoder; 11 | 12 | mod layer; 13 | pub(crate) use layer::bert_linear; 14 | pub use layer::BertLayer; 15 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/remove-sep-verb-prefix.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed 2 | 3 | # Should be stripped 4 | _ hinein#ziehen _ VVPP ziehen 5 | _ zusammen#hängen _ VVFIN hängen 6 | 7 | # Multiple prefixes 8 | _ wieder#auf#bauen _ VVINF bauen 9 | 10 | # No changes for non-verbs 11 | _ CD#1 _ NN CD#1 12 | -------------------------------------------------------------------------------- /syntaxdot/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod config; 2 | 3 | pub mod dataset; 4 | 5 | pub mod error; 6 | 7 | pub mod encoders; 8 | 9 | pub mod lr; 10 | 11 | pub mod model; 12 | 13 | pub mod optimizers; 14 | 15 | pub mod tagger; 16 | 17 | pub mod tensor; 18 | 19 | pub mod util; 20 | 21 | /// The syntaxdot version. 22 | pub const VERSION: &str = env!("CARGO_PKG_VERSION"); 23 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | //! Transformer models. 2 | 3 | pub mod albert; 4 | 5 | pub mod bert; 6 | 7 | mod encoder; 8 | pub use encoder::Encoder; 9 | 10 | mod layer_output; 11 | pub use layer_output::{HiddenLayer, LayerOutput}; 12 | 13 | pub mod roberta; 14 | 15 | pub mod sinusoidal; 16 | 17 | pub mod squeeze_albert; 18 | 19 | pub mod squeeze_bert; 20 | 21 | mod traits; 22 | -------------------------------------------------------------------------------- /syntaxdot/src/optimizers/grad.rs: -------------------------------------------------------------------------------- 1 | use tch::Tensor; 2 | 3 | pub trait ZeroGrad { 4 | /// Zero out gradients. 5 | fn zero_grad(&mut self); 6 | } 7 | 8 | impl ZeroGrad for Vec<Tensor> { 9 | fn zero_grad(&mut self) { 10 | for tensor in self { 11 | if tensor.requires_grad() { 12 | tensor.zero_grad() 13 | } 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /COPYRIGHT.md: -------------------------------------------------------------------------------- 1 | ## SyntaxDot 2 | 3 | Copyright 2020-2021 TensorDot 4 | Copyright 2018-2020 The sticker contributors 5 | 6 | Licensed under the [Apache License, Version 7 | 2.0](http://www.apache.org/licenses/LICENSE-2.0) or the [MIT 8 | license](http://opensource.org/licenses/MIT), at your option.
9 | 10 | Contributors: 11 | 12 | * Daniël de Kok 13 | * Tobias Pütz 14 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/remove-trunc-marker.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | Bau- Bauplanung NOUN TRUNC Bau- 5 | jahre- jahrelang ADJ TRUNC jahre- 6 | Jahre- jahrelang ADJ TRUNC jahre- 7 | hin- hin#schieben VERB TRUNC hin- 8 | 9 | # Should not fire for other tags 10 | Bau- foo NOUN XYZ foo 11 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/cow.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Deref; 2 | 3 | use tch::Tensor; 4 | 5 | pub enum CowTensor<'a> { 6 | Owned(Tensor), 7 | Borrowed(&'a Tensor), 8 | } 9 | 10 | impl<'a> Deref for CowTensor<'a> { 11 | type Target = Tensor; 12 | 13 | fn deref(&self) -> &Self::Target { 14 | match self { 15 | CowTensor::Owned(ref tensor) => tensor, 16 | CowTensor::Borrowed(tensor) => tensor, 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /syntaxdot-summary/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! TensorBoard summary writer 2 | //! 3 | //! This crate implements just enough functionality to write scalars 4 | //! in the TensorFlow/TensorBoard summary format. This replaces the 5 | //! far more extensive `tfrecord` crate, which has **many** 6 | //! dependencies. 7 | 8 | mod crc32; 9 | 10 | mod crc32_table; 11 | 12 | pub(crate) mod event_writer; 13 | 14 | pub(crate) mod record_writer; 15 | 16 | mod summary_writer; 17 | pub use summary_writer::SummaryWriter; 18 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/form-as-lemma.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # For applicable tags, the lemma is the lowercase form. 5 | und _ _ KON und 6 | Und _ _ KON und 7 | 8 | # Foreign words should use the form, retaining case. 9 | Brasil _ _ FM Brasil 10 | improve _ _ FM improve 11 | 12 | # For other tags, nothing changes. 13 | Auto Foobar _ NN Foobar 14 | NBA Quux _ NE Quux -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/squeeze_bert/mod.rs: -------------------------------------------------------------------------------- 1 | //! SqueezeBERT (Iandola et al., 2020) 2 | //! 3 | //! SqueezeBERT follows the same architecture as BERT, but replaces most 4 | //! matrix multiplications by grouped convolutions. This reduces the 5 | //! number of parameters and speeds up inference. 
6 | 7 | mod config; 8 | pub use config::SqueezeBertConfig; 9 | 10 | mod embeddings; 11 | 12 | mod encoder; 13 | pub use encoder::SqueezeBertEncoder; 14 | 15 | mod layer; 16 | pub(crate) use layer::SqueezeBertLayer; 17 | -------------------------------------------------------------------------------- /syntaxdot-summary/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-summary" 3 | version = "0.5.0" 4 | authors = ["Daniël de Kok "] 5 | edition = "2018" 6 | description = "TensorBoard summary writer" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot-transformers/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | 14 | [dependencies] 15 | hostname = "0.3" 16 | prost = { version = "0.12", features = ["prost-derive"] } 17 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-tokenizers" 3 | version = "0.5.0" 4 | authors = ["Daniël de Kok "] 5 | edition = "2018" 6 | description = "Subword tokenizers" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot-tokenizers/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [dependencies] 14 | ndarray = "0.15" 15 | sentencepiece = "0.11" 16 | thiserror = "1" 17 | udgraph = "0.8" 18 | wordpieces = "0.6" 19 | 20 | [features] 21 | model-tests = [] 22 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Transformer models (Vaswani et al., 2017) 2 | //! 3 | //! This crate implements various transformer models, provided through 4 | //! the [`models`] module. The implementations are more restricted than 5 | //! e.g. their Huggingface counterparts, focusing only on the parts 6 | //! necessary for sequence labeling. 7 | 8 | pub mod activations; 9 | 10 | pub(crate) mod cow; 11 | 12 | mod error; 13 | pub use error::TransformerError; 14 | 15 | pub mod layers; 16 | 17 | pub mod loss; 18 | 19 | pub mod models; 20 | 21 | pub mod module; 22 | 23 | pub mod scalar_weighting; 24 | 25 | pub(crate) mod util; 26 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lemma/mod.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | mod encoder; 4 | pub use self::encoder::{BackoffStrategy, EditTreeEncoder}; 5 | 6 | pub(crate) mod edit_tree; 7 | pub use edit_tree::EditTree; 8 | 9 | /// Lemma encoding error. 10 | #[derive(Clone, Debug, Eq, Error, PartialEq)] 11 | pub enum EncodeError { 12 | /// The token does not have a lemma. 13 | #[error("token without a lemma: '{form:?}'")] 14 | MissingLemma { form: String }, 15 | 16 | /// No edit tree can be constructed. 
17 | #[error("cannot find an edit tree that rewrites '{form:?}' into '{lemma:?}'")] 18 | NoEditTree { form: String, lemma: String }, 19 | } 20 | -------------------------------------------------------------------------------- /syntaxdot-tch-ext/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-tch-ext" 3 | version = "0.5.0" 4 | authors = ["Daniël de Kok "] 5 | edition = "2018" 6 | description = "tch path extension for partitioning parameters in groups" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot-tch-ext/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [dependencies] 14 | itertools = "0.11" 15 | tch = { version = "0.14", default-features = false } 16 | 17 | [features] 18 | doc-only = ["tch/doc-only"] 19 | 20 | [package.metadata.docs.rs] 21 | features = [ "doc-only" ] 22 | -------------------------------------------------------------------------------- /syntaxdot/testdata/sticker.conf: -------------------------------------------------------------------------------- 1 | [input] 2 | tokenizer = { bert = { vocab = "bert-base-german-cased-vocab.txt" } } 3 | 4 | [biaffine] 5 | labels = "sticker.biaffine_labels" 6 | head = { dims = 50, head_bias = true, dependent_bias = false } 7 | relation = { dims = 25, head_bias = true, dependent_bias = true } 8 | 9 | [labeler] 10 | labels = "sticker.labels" 11 | encoders = [ 12 | { name = "dep", encoder = { dependency = { encoder = { relativepos = "xpos" }, root_relation = "root" } } }, 13 | { name = "lemma", encoder = { lemma = "form" } }, 14 | { name = "pos", encoder = { sequence = "xpos" } }, 15 | ] 16 | 17 | [model] 18 | parameters = "epoch-99" 19 | pooler = "discard" 20 | position_embeddings = "model" 21 | pretrain_config = "bert_config.json" 22 | pretrain_type = "bert" -------------------------------------------------------------------------------- /syntaxdot-cli/src/summary/tensorboard.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::fs::File; 3 | use std::io::BufWriter; 4 | 5 | use anyhow::Result; 6 | use syntaxdot_summary::SummaryWriter; 7 | 8 | use super::ScalarWriter; 9 | 10 | pub struct TensorBoardWriter { 11 | writer: RefCell<SummaryWriter<BufWriter<File>>>, 12 | } 13 | 14 | impl TensorBoardWriter { 15 | pub fn new(prefix: impl AsRef<str>) -> Result<Self> { 16 | let writer = SummaryWriter::from_prefix(prefix.as_ref())?; 17 | Ok(TensorBoardWriter { 18 | writer: RefCell::new(writer), 19 | }) 20 | } 21 | } 22 | 23 | impl ScalarWriter for TensorBoardWriter { 24 | fn write_scalar(&self, tag: &str, step: i64, value: f32) -> Result<()> { 25 | Ok(self.writer.borrow_mut().write_scalar(tag, step, value)?)
26 | } 27 | } 28 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use sentencepiece::SentencePieceError; 4 | use thiserror::Error; 5 | use wordpieces::WordPiecesError; 6 | 7 | #[derive(Debug, Error)] 8 | pub enum TokenizerError { 9 | #[error("Cannot open tokenizer model `{model_path:?}`: {inner:?}")] 10 | OpenError { 11 | model_path: String, 12 | inner: io::Error, 13 | }, 14 | 15 | #[error(transparent)] 16 | SentencePiece(#[from] SentencePieceError), 17 | 18 | #[error("Cannot process word pieces: {0}")] 19 | WordPieces(#[from] WordPiecesError), 20 | } 21 | 22 | impl TokenizerError { 23 | pub fn open_error(model_path: impl Into<String>, inner: io::Error) -> Self { 24 | TokenizerError::OpenError { 25 | model_path: model_path.into(), 26 | inner, 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/encoder.rs: -------------------------------------------------------------------------------- 1 | use tch::Tensor; 2 | 3 | use crate::models::layer_output::LayerOutput; 4 | use crate::TransformerError; 5 | 6 | /// Encoder networks. 7 | pub trait Encoder { 8 | /// Apply the encoder. 9 | /// 10 | /// Returns the output and attention per layer. The (optional) 11 | /// attention mask of shape `[batch_size, time_steps]` indicates 12 | /// which tokens should be included (`true`) and excluded (`false`) from 13 | /// attention. This can be used to mask inactive timesteps. 14 | fn encode( 15 | &self, 16 | input: &Tensor, 17 | attention_mask: Option<&Tensor>, 18 | train: bool, 19 | ) -> Result<Vec<LayerOutput>, TransformerError>; 20 | 21 | /// Get the number of layers that is returned by the encoder. 22 | fn n_layers(&self) -> i64; 23 | } 24 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-article-lemma.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Definite articles are lemmatized to 'd' 5 | der _ _ ART d 6 | des _ _ ART d 7 | dem _ _ ART d 8 | den _ _ ART d 9 | die _ _ ART d 10 | das _ _ ART d 11 | 12 | # Indefinite articles are lemmatized to 'e' 13 | ein _ _ ART e 14 | eines _ _ ART e 15 | einem _ _ ART e 16 | einen _ _ ART e 17 | eine _ _ ART e 18 | einer _ _ ART e 19 | 20 | # Definite articles are lemmatized to 'd' 21 | der _ _ ART d 22 | 23 | # Substituting relative pronoun 24 | der _ _ PRELS d 25 | 26 | # Attributive relative pronoun 27 | dessen _ _ PRELAT d 28 | 29 | # Do not attempt to lemmatize words with an irrelevant tag.
30 | der foo _ XY foo 31 | ein bar _ XY bar -------------------------------------------------------------------------------- /syntaxdot-encoders/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-encoders" 3 | version = "0.5.0" 4 | authors = ["Daniël de Kok "] 5 | edition = "2018" 6 | description = "Encoders for linguistic features" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot-encoders/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [dependencies] 14 | caseless = "0.2" 15 | conllu = "0.8" 16 | fst = "0.4" 17 | itertools = "0.11" 18 | numberer = "0.2" 19 | lazy_static = "1" 20 | maplit = "1" 21 | ndarray = "0.15" 22 | ordered-float = "4" 23 | petgraph = "0.6" 24 | seqalign = "0.2" 25 | serde = { version = "1", features = ["derive"] } 26 | serde_derive = "1" 27 | thiserror = "1" 28 | udgraph = "0.8" 29 | unicode-normalization = "0.1" 30 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/error.rs: -------------------------------------------------------------------------------- 1 | use tch::TchError; 2 | use thiserror::Error; 3 | 4 | /// Transformer errors. 5 | #[derive(Debug, Error)] 6 | #[non_exhaustive] 7 | pub enum TransformerError { 8 | /// The hidden size is not a multiple of the number of attention heads. 9 | #[error("hidden size ({hidden_size:?}) is not a multiple of attention heads ({num_attention_heads:?})")] 10 | IncorrectHiddenSize { 11 | /// The hidden size. 12 | hidden_size: i64, 13 | 14 | /// The number of attention heads. 15 | num_attention_heads: i64, 16 | }, 17 | 18 | /// Torch error. 19 | #[error(transparent)] 20 | Tch(#[from] TchError), 21 | 22 | /// The activation function is unknown. 23 | #[error("unknown activation function: {activation:?}")] 24 | UnknownActivationFunction { activation: String }, 25 | } 26 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/add-separated-verb-prefix.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Single separated prefix 5 | _ zeichnen _ VVFIN ab#zeichnen _ _ _ _ _ AVZ ab _ _ PTKVZ 6 | _ zeichnen _ VVFIN ab#zeichnen _ _ _ _ _ AVZ Ab _ _ PTKVZ 7 | _ stellen _ VVFIN vor#stellen _ _ _ _ _ AVZ vor _ _ PTKVZ 8 | _ müssen _ VVFIN rein#müssen _ _ _ _ _ AVZ rein _ _ PTKVZ 9 | _ werden _ VVFIN los#werden _ _ _ _ _ AVZ los _ _ PTKVZ 10 | 11 | # Multiple separated prefixes 12 | _ nehmen _ VVFIN zu#nehmen|ab#nehmen _ _ _ _ _ AVZ ab _ _ PTKVZ KON zu _ _ PTKVZ 13 | 14 | # No splitting necessary 15 | kommt kommen _ VVFIN kommen 16 | 17 | # Not all tags can have separated prefixes. 18 | _ zeichnen _ VVINF zeichnen _ _ _ _ _ AVZ ab _ _ PTKVZ 19 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/module.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use tch::Tensor; 4 | 5 | /// Module for which a computation can fail. 6 | pub trait FallibleModule: Debug + Send { 7 | /// The error type. 8 | type Error; 9 | 10 | /// Apply the module. 
11 | fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>; 12 | } 13 | 14 | /// Module for which a computation can fail. 15 | pub trait FallibleModuleT: Debug + Send { 16 | /// The error type. 17 | type Error; 18 | 19 | /// Apply the module. 20 | fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>; 21 | } 22 | 23 | impl<M> FallibleModuleT for M 24 | where 25 | M: FallibleModule, 26 | { 27 | type Error = M::Error; 28 | 29 | fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> { 30 | self.forward(input) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-personal-pronoun.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Personal pronouns 5 | ich _ _ PPER ich 6 | mir _ _ PPER ich 7 | mich _ _ PPER ich 8 | meiner _ _ PPER ich 9 | du _ _ PPER du 10 | dir _ _ PPER du 11 | dich _ _ PPER du 12 | deiner _ _ PPER du 13 | er _ _ PPER er 14 | ihn _ _ PPER er 15 | ihm _ _ PPER er 16 | seiner _ _ PPER er 17 | sie _ _ PPER sie 18 | ihr _ _ PPER sie 19 | ihnen _ _ PPER sie 20 | ihrer _ _ PPER sie 21 | es _ _ PPER es 22 | 's _ _ PPER es 23 | wir _ _ PPER wir 24 | uns _ _ PPER wir 25 | unser _ _ PPER wir 26 | euch _ _ PPER ihr 27 | 28 | # Use the lemma when the form is unknown 29 | ic ich _ PPER ich 30 | 31 | # Exclude other tags 32 | ich foo _ XY foo 33 | mir bar _ XY bar 34 | mich baz _ XY baz -------------------------------------------------------------------------------- /syntaxdot-tch-ext/src/tensor.rs: -------------------------------------------------------------------------------- 1 | //! Convenience functions for `Tensor`. 2 | //! 3 | //! The `Tensor` API can be a bit unwieldy since it is partly 4 | //! autogenerated. This module provides some additional methods 5 | //! that are more convenient to use. 6 | 7 | use tch::{Kind, TchError, Tensor}; 8 | 9 | pub trait SumDim { 10 | /// Sum over a dimension (fallible). 11 | fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError>; 12 | 13 | /// Sum over a dimension. 14 | fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor; 15 | } 16 | 17 | impl SumDim for Tensor { 18 | fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError> { 19 | self.f_sum_dim_intlist(Some([dim].as_slice()), keep_dim, kind) 20 | } 21 | 22 | fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor { 23 | self.f_sum_dim(dim, keep_dim, kind).unwrap() 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/restore-case.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Nouns should have uppercased initial letters. 5 | _ bett _ NN Bett 6 | _ Bett _ NN Bett 7 | 8 | # Named entities 9 | Apple apple _ NE Apple 10 | Apple's apple _ NE Apple 11 | Italiens italien _ NE Italien 12 | Liga-Chef liga-chef _ NE Liga-Chef 13 | Liga-Chef's liga-chef _ NE Liga-Chef 14 | LigaChef's liga-chef _ NE Liga-Chef 15 | D'Alema d'alema _ NE D'Alema 16 | CDU cdu _ NE CDU 17 | CDU's cdu _ NE CDU 18 | foobar cdu _ NE cdu 19 | 20 | # Check that strings are normalized.
The form contains the decomposed 21 | # characters u0065 u0308, rather than u00eb. 22 | Ëee ëee _ NE Ëee 23 | 24 | # For other tags, nothing changes. 25 | _ laufen _ VVFIN laufen 26 | Laufen laufen _ VVFIN laufen -------------------------------------------------------------------------------- /syntaxdot-transformers/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-transformers" 3 | version = "0.5.0" 4 | authors = ["Daniël de Kok "] 5 | edition = "2018" 6 | description = "Transformer architectures, such as BERT" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot-transformers/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [dependencies] 14 | serde = { version = "1", features = ["derive"] } 15 | syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" } 16 | tch = { version = "0.14", default-features = false } 17 | thiserror = "1" 18 | 19 | [dev-dependencies] 20 | approx = "0.5" 21 | maplit = "1" 22 | ndarray = { version = "0.15", features = ["approx-0_5"] } 23 | 24 | [features] 25 | model-tests = [] 26 | doc-only = ["syntaxdot-tch-ext/doc-only", "tch/doc-only"] 27 | 28 | [package.metadata.docs.rs] 29 | features = [ "doc-only" ] 30 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/mark-verb-prefix.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Lookup table 5 | _ abbestellen _ VVFIN ab#bestellen 6 | 7 | # Derive from form 8 | dazugefügt fügen _ VVFIN dazu#fügen 9 | wiederaufgebaut bauen _ VVFIN wieder#auf#bauen 10 | abzuarbeiten arbeiten _ VVIZU ab#arbeiten 11 | abgefangen fangen _ VVPP ab#fangen 12 | abhing hängen _ VVFIN ab#hängen 13 | 14 | # Multiple prefixes 15 | mitabgedruckt drucken _ VVFIN mit#ab#drucken 16 | 17 | # 'zustande' is a particle, so here we do not want to match the 18 | # longest common prefix. 19 | zustanden stehen _ VVFIN zu#stehen 20 | 21 | # 'hinzu' is a particle, check that we don't analyze as 'hinzu#bewegen' 22 | hinzubewegen bewegen _ VVIZU hin#bewegen 23 | zuzuspitzen spitzen _ VVIZU zu#spitzen 24 | hinzuwirken wirken _ VVIZU hin#wirken 25 | hinzuzufügen fügen _ VVIZU hinzu#fügen 26 | mitaufzunehmen nehmen _ VVIZU mit#auf#nehmen -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/summary/mod.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use clap::{Arg, ArgMatches, Command}; 3 | 4 | use crate::traits::SyntaxDotOption; 5 | 6 | mod noop; 7 | use noop::NoopWriter; 8 | 9 | mod tensorboard; 10 | 11 | const LOG_PREFIX: &str = "LOG_PREFIX"; 12 | 13 | pub trait ScalarWriter { 14 | fn write_scalar(&self, tag: &str, step: i64, value: f32) -> Result<()>; 15 | } 16 | 17 | pub struct SummaryOption; 18 | 19 | impl SyntaxDotOption for SummaryOption { 20 | type Value = Box<dyn ScalarWriter>; 21 | 22 | fn add_to_app(app: Command) -> Command { 23 | app.arg( 24 | Arg::new(LOG_PREFIX) 25 | .long("log-prefix") 26 | .value_name("PREFIX") 27 | .num_args(1) 28 | .help("Prefix for Tensorboard logs"), 29 | ) 30 | } 31 | 32 | fn parse(matches: &ArgMatches) -> Result<Self::Value> { 33 | Ok(match matches.get_one::<String>(LOG_PREFIX) { 34 | Some(prefix) => { 35 | Box::new(tensorboard::TensorBoardWriter::new(prefix)?) as Box<dyn ScalarWriter> 36 | } 37 | None => Box::new(NoopWriter), 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/util.rs: -------------------------------------------------------------------------------- 1 | use std::io::BufRead; 2 | use std::os::raw::c_int; 3 | 4 | use anyhow::Result; 5 | 6 | pub fn count_sentences(mut buf_read: impl BufRead) -> Result<usize> { 7 | let mut n_sents = 0; 8 | 9 | loop { 10 | let buf = buf_read.fill_buf()?; 11 | 12 | if buf.is_empty() { 13 | break; 14 | } 15 | 16 | n_sents += bytecount::count(buf, b'\n'); 17 | 18 | // Satisfy the borrow checker. 19 | let buf_len = buf.len(); 20 | buf_read.consume(buf_len); 21 | } 22 | 23 | Ok(n_sents) 24 | } 25 | 26 | #[allow(dead_code)] 27 | #[no_mangle] 28 | extern "C" fn mkl_serv_intel_cpu_true() -> c_int { 29 | 1 30 | } 31 | 32 | /// Runs a closure with autocast. 33 | /// 34 | /// This function runs a closure with `autocast` if enabled 35 | /// is set to `true`. Otherwise, the closure is run without 36 | /// autocast *iff* the calling function is **not** autocast. 37 | /// 38 | /// This function can be used to avoid autocasting overhead. 39 | pub fn autocast_or_preserve<F, T>(enabled: bool, f: F) -> T 40 | where 41 | F: FnOnce() -> T, 42 | { 43 | if enabled { 44 | tch::autocast(true, f) 45 | } else { 46 | f() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/src/lib.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array1; 2 | use udgraph::graph::Sentence; 3 | 4 | mod albert; 5 | pub use albert::AlbertTokenizer; 6 | 7 | mod bert; 8 | pub use bert::BertTokenizer; 9 | 10 | mod error; 11 | pub use error::TokenizerError; 12 | 13 | mod xlm_roberta; 14 | pub use xlm_roberta::XlmRobertaTokenizer; 15 | 16 | /// Trait for wordpiece tokenizers.
17 | pub trait Tokenize: Send + Sync { 18 | /// Tokenize the tokens in a sentence into word pieces. 19 | /// 20 | /// Implementations **must** prefix the first piece corresponding to a 21 | /// token by one or more special pieces marking the beginning of the 22 | /// sentence. The representation of this piece can be used for special 23 | /// purposes, such as classification or acting as the pseudo-root in 24 | /// dependency parsing. 25 | fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces; 26 | } 27 | 28 | /// A sentence and its word pieces. 29 | #[derive(Debug, Eq, PartialEq)] 30 | pub struct SentenceWithPieces { 31 | /// Word pieces in a sentence. 32 | pub pieces: Array1<i64>, 33 | 34 | /// Sentence graph. 35 | pub sentence: Sentence, 36 | 37 | /// The offsets of tokens in `pieces`. 38 | pub token_offsets: Vec<usize>, 39 | } 40 | -------------------------------------------------------------------------------- /syntaxdot-summary/src/record_writer.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Write}; 2 | 3 | use crate::crc32::CheckSummer; 4 | 5 | /// Write data in TFRecord format. 6 | pub struct TfRecordWriter<W> { 7 | checksummer: CheckSummer, 8 | write: W, 9 | } 10 | 11 | impl<W> From<W> for TfRecordWriter<W> 12 | where 13 | W: Write, 14 | { 15 | fn from(write: W) -> Self { 16 | TfRecordWriter { 17 | checksummer: CheckSummer::new(), 18 | write, 19 | } 20 | } 21 | } 22 | 23 | impl<W> TfRecordWriter<W> 24 | where 25 | W: Write, 26 | { 27 | pub fn flush(&mut self) -> io::Result<()> { 28 | self.write.flush() 29 | } 30 | 31 | pub fn write(&mut self, data: &[u8]) -> io::Result<()> { 32 | let len = (data.len() as u64).to_le_bytes(); 33 | self.write.write_all(&len)?; 34 | self.write 35 | .write_all(&self.checksummer.crc32c_masked(&len).to_le_bytes())?; 36 | self.write.write_all(data)?; 37 | self.write 38 | .write_all(&self.checksummer.crc32c_masked(data).to_le_bytes())?; 39 | 40 | Ok(()) 41 | } 42 | } 43 | 44 | // TFRecord format: 45 | // 46 | // uint64 length 47 | // uint32 masked_crc32_of_length 48 | // byte data[length] 49 | // uint32 masked_crc32_of_data 50 | -------------------------------------------------------------------------------- /scripts/update-syntaxdot-0.2-model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | import torch 9 | 10 | from tensor_module import TensorModule 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Update a SyntaxDot 0.2 BERT/RoBERTa-based model.') 14 | parser.add_argument( 15 | 'model', 16 | metavar='MODEL', 17 | help='The model path') 18 | parser.add_argument('converted_model', metavar='CONVERTED_MODEL', help='The converted model') 19 | 20 | if __name__ == "__main__": 21 | args = parser.parse_args() 22 | 23 | model = torch.jit.load(args.model, map_location=torch.device('cpu')) 24 | 25 | tensors = {} 26 | 27 | embeddings_re = re.compile("^encoder\|([^_]+_embeddings)") 28 | embeddings_layernorm_re = re.compile("^encoder\|layer_norm") 29 | for var, tensor in model.named_parameters(): 30 | var = embeddings_re.sub(r"embeddings|\1", var) 31 | var = embeddings_layernorm_re.sub(r"embeddings|layer_norm", var) 32 | var = var.replace("encoder|token_type_embeddings", "embeddings|token_type_embeddings") 33 | 34 | tensors[var] = tensor 35 | 36 | 37 | wrapper = TensorModule(tensors) 38 | script = torch.jit.script(wrapper) 39 | script.save(args.converted_model) 40 |
-------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-possesive-pronoun-lemma.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # Substituting possesive pronouns 5 | Ihre _ _ PPOSS ihr 6 | ihre _ _ PPOSS ihr 7 | meine _ _ PPOSS mein 8 | meinen _ _ PPOSS mein 9 | deines _ _ PPOSS dein 10 | seine _ _ PPOSS sein 11 | seinen _ _ PPOSS sein 12 | unsere _ _ PPOSS unser 13 | unsern _ _ PPOSS unser 14 | unsrige _ _ PPOSS unsrig 15 | unsrigen _ _ PPOSS unsrig 16 | 17 | # Attributive possesive pronouns 18 | dein _ _ PPOSAT dein 19 | deine _ _ PPOSAT dein 20 | euer _ _ PPOSAT euer 21 | eure _ _ PPOSAT euer 22 | euren _ _ PPOSAT euer 23 | ihr _ _ PPOSAT ihr 24 | ihrem _ _ PPOSAT ihr 25 | Ihre _ _ PPOSAT ihr 26 | Ihrem _ _ PPOSAT ihr 27 | mein _ _ PPOSAT mein 28 | meine _ _ PPOSAT mein 29 | meinem _ _ PPOSAT mein 30 | sein _ _ PPOSAT sein 31 | seine _ _ PPOSAT sein 32 | seinem _ _ PPOSAT sein 33 | seinen _ _ PPOSAT sein 34 | unser _ _ PPOSAT unser 35 | unsere _ _ PPOSAT unser 36 | unserem _ _ PPOSAT unser 37 | 38 | # Use the lemma when the prefix is unknown 39 | unsre unsere _ PPOSAT unsere 40 | 41 | # Exclude other tags 42 | Ihre foo _ XY foo 43 | ihre bar _ XY bar 44 | unserem baz _ XY baz -------------------------------------------------------------------------------- /syntaxdot-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot-cli" 3 | version = "0.5.0" 4 | edition = "2018" 5 | authors = ["Daniël de Kok "] 6 | description = "Neural sequence labeler" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = "https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://github.com/tensordot/syntaxdot" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [[bin]] 14 | name = "syntaxdot" 15 | path = "src/main.rs" 16 | 17 | [dependencies] 18 | anyhow = "1" 19 | bytecount = "0.6" 20 | clap = { version = "4", features = ["cargo"] } 21 | clap_complete = "4" 22 | conllu = "0.8" 23 | env_logger = "0.10" 24 | indicatif = "0.17" 25 | itertools = "0.11" 26 | log = "0.4" 27 | ndarray = "0.15" 28 | ordered-float = { version = "4", features = ["serde"] } 29 | rayon = "1" 30 | serde_yaml = "0.8" 31 | stdinout = "0.4" 32 | syntaxdot = { path = "../syntaxdot", version = "0.5.0", default-features = false } 33 | syntaxdot-encoders = { path = "../syntaxdot-encoders", version = "0.5.0" } 34 | syntaxdot-summary = { path = "../syntaxdot-summary", version = "0.5.0" } 35 | syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" } 36 | syntaxdot-tokenizers = { path = "../syntaxdot-tokenizers", version = "0.5.0" } 37 | syntaxdot-transformers = { path = "../syntaxdot-transformers", version = "0.5.0", default-features = false } 38 | tch = { version = "0.14", default-features = false } 39 | udgraph = "0.8" 40 | -------------------------------------------------------------------------------- /syntaxdot/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "syntaxdot" 3 | version = "0.5.0" 4 | edition = "2018" 5 | authors = ["Daniël de Kok "] 6 | description = "Neural sequence labeler" 7 | homepage = "https://github.com/tensordot/syntaxdot" 8 | repository = 
"https://github.com/tensordot/syntaxdot.git" 9 | documentation = "https://docs.rs/syntaxdot/" 10 | license = "MIT OR Apache-2.0" 11 | rust-version = "1.70.0" 12 | 13 | [dependencies] 14 | chu-liu-edmonds = "0.1" 15 | conllu = "0.8" 16 | ndarray = "0.15" 17 | numberer = "0.2" 18 | log = "0.4" 19 | ordered-float = "4" 20 | rand = "0.8" 21 | rand_xorshift = "0.3" 22 | serde = { version = "1", features = [ "derive" ] } 23 | serde_json = "1" 24 | syntaxdot-encoders = { path = "../syntaxdot-encoders", version = "0.5.0" } 25 | syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" } 26 | syntaxdot-tokenizers = { path = "../syntaxdot-tokenizers", default-features = false, version = "0.5.0" } 27 | syntaxdot-transformers = { path = "../syntaxdot-transformers", default-features = false, version = "0.5.0" } 28 | tch = { version = "0.14", default-features = false } 29 | thiserror = "1" 30 | toml = "0.8" 31 | udgraph = "0.8" 32 | 33 | [dev-dependencies] 34 | approx = "0.5" 35 | lazy_static = "1" 36 | maplit = "1" 37 | wordpieces = "0.6" 38 | 39 | [features] 40 | model-tests = [] 41 | doc-only = ["syntaxdot-tch-ext/doc-only", "syntaxdot-transformers/doc-only", "tch/doc-only"] 42 | 43 | [package.metadata.docs.rs] 44 | features = [ "doc-only" ] 45 | -------------------------------------------------------------------------------- /scripts/pytorch-squeezebert-to-syntaxdot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import re 5 | import sys 6 | 7 | import torch 8 | 9 | from tensor_module import TensorModule 10 | 11 | parser = argparse.ArgumentParser( 12 | description='Convert a PyTorch SqueezeBERT checkpoint to SyntaxDot tensors.') 13 | parser.add_argument( 14 | 'model', 15 | metavar='MODEL', 16 | help='The model path') 17 | parser.add_argument('tensors', metavar='TENSORS', help='SyntaxDot tensors') 18 | 19 | if __name__ == "__main__": 20 | args = parser.parse_args() 21 | 22 | model = torch.load(args.model) 23 | 24 | tensors = {} 25 | 26 | ignore = re.compile("cls|pooler") 27 | for var, tensor in model.items(): 28 | # Skip unneeded layers 29 | if ignore.search(var): 30 | continue 31 | 32 | var = var.replace("transformer.", "") 33 | var = var.replace("embeddings.weight", "embeddings.embeddings") 34 | var = var.replace("encoder.layers.", "encoder.layer_") 35 | var = var.replace("LayerNorm", "layer_norm") 36 | var = var.replace("layernorm", "layer_norm") 37 | 38 | # Finally, Rust VarStore replaces periods by vertical bars 39 | # during saving. 40 | var = var.replace(".", "|") 41 | 42 | print("Adding %s..." 
% var, file=sys.stderr) 43 | 44 | tensors[var] = tensor 45 | 46 | wrapper = TensorModule(tensors) 47 | script = torch.jit.script(wrapper) 48 | script.save(args.tensors) 49 | -------------------------------------------------------------------------------- /syntaxdot-encoders/data/lang/de/tdz/tdz11-separable-prefixes.txt: -------------------------------------------------------------------------------- 1 | ab 2 | an 3 | aneinander 4 | auf 5 | aufeinander 6 | aus 7 | auseinander 8 | bei 9 | beisammen 10 | beiseite 11 | bekannt 12 | bereit 13 | bevor 14 | breit 15 | da 16 | dagegen 17 | daher 18 | dahin 19 | dahinter 20 | daneben 21 | dar 22 | daraufhin 23 | davon 24 | dazu 25 | dazwischen 26 | dicht 27 | dran 28 | drauf 29 | drein 30 | drin 31 | durch 32 | durcheinander 33 | ein 34 | einher 35 | entgegen 36 | entlang 37 | fehl 38 | feil 39 | fern 40 | fertig 41 | fest 42 | fort 43 | frei 44 | gegenüber 45 | gleich 46 | halt 47 | heim 48 | her 49 | herab 50 | heran 51 | herauf 52 | heraus 53 | herbei 54 | herein 55 | herum 56 | herunter 57 | hervor 58 | herüber 59 | hin 60 | hinab 61 | hinauf 62 | hinaus 63 | hinein 64 | hinterher 65 | hinunter 66 | hinweg 67 | hinzu 68 | hoch 69 | inne 70 | kaputt 71 | kehrt 72 | kennen 73 | klar 74 | kund 75 | lahm 76 | leer 77 | leicht 78 | leid 79 | locker 80 | los 81 | mit 82 | nach 83 | nahe 84 | nieder 85 | näher 86 | offen 87 | preis 88 | quer 89 | ran 90 | raus 91 | rein 92 | rum 93 | runter 94 | rüber 95 | schief 96 | schwer 97 | sicher 98 | stand 99 | statt 100 | still 101 | teil 102 | tot 103 | um 104 | umher 105 | umhin 106 | unter 107 | voll 108 | vor 109 | voran 110 | voraus 111 | vorbei 112 | vorweg 113 | vorüber 114 | wahr 115 | weg 116 | weiter 117 | wett 118 | wider 119 | wieder 120 | zu 121 | zueinander 122 | zufrieden 123 | zugrunde 124 | zurecht 125 | zurück 126 | zusammen 127 | zustande 128 | zuvor 129 | übel 130 | über 131 | überein 132 | übereinander 133 | übrig 134 | -------------------------------------------------------------------------------- /syntaxdot/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use ndarray::ShapeError; 4 | use syntaxdot_encoders::dependency; 5 | use syntaxdot_tokenizers::TokenizerError; 6 | use syntaxdot_transformers::TransformerError; 7 | use tch::TchError; 8 | use thiserror::Error; 9 | 10 | use crate::encoders::{DecoderError, EncoderError}; 11 | 12 | #[non_exhaustive] 13 | #[derive(Debug, Error)] 14 | pub enum SyntaxDotError { 15 | #[error(transparent)] 16 | BertError(#[from] TransformerError), 17 | 18 | #[error(transparent)] 19 | ConlluError(#[from] conllu::Error), 20 | 21 | #[error(transparent)] 22 | DecoderError(#[from] DecoderError), 23 | 24 | #[error(transparent)] 25 | DependencyEncodeError(#[from] dependency::EncodeError), 26 | 27 | #[error(transparent)] 28 | EncoderError(#[from] EncoderError), 29 | 30 | #[error("Illegal configuration: {0}")] 31 | IllegalConfigurationError(String), 32 | 33 | #[error(transparent)] 34 | IoError(#[from] io::Error), 35 | 36 | #[error("The optimizer does not have any associated trainable variables")] 37 | NoTrainableVariables, 38 | 39 | #[error("{0}: {1}")] 40 | JSonSerialization(String, serde_json::Error), 41 | 42 | #[error("Cannot relativize path: {0}")] 43 | RelativizePathError(String), 44 | 45 | #[error(transparent)] 46 | ShapeError(#[from] ShapeError), 47 | 48 | #[error(transparent)] 49 | Tch(#[from] TchError), 50 | 51 | #[error(transparent)] 52 | TomlDeserializationError(#[from] toml::de::Error), 
53 | 54 | #[error(transparent)] 55 | TokenizerError(#[from] TokenizerError), 56 | 57 | #[error(transparent)] 58 | UdgraphError(#[from] udgraph::Error), 59 | } 60 | -------------------------------------------------------------------------------- /scripts/pytorch-roberta-to-syntaxdot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | import torch 9 | 10 | from tensor_module import TensorModule 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Convert a PyTorch RoBERTa checkpoint to SyntaxDot tensors.') 14 | parser.add_argument( 15 | 'model', 16 | metavar='MODEL', 17 | help='The model path') 18 | parser.add_argument('tensors', metavar='TENSORS', help='SyntaxDot tensors') 19 | 20 | if __name__ == "__main__": 21 | args = parser.parse_args() 22 | 23 | model = torch.load(args.model) 24 | 25 | tensors = {} 26 | 27 | ignore = re.compile("adam_v|adam_m|global_step|lm_head|pooler") 28 | kernel = re.compile("(key|query|value|dense)/weight") 29 | for var, tensor in model.items(): 30 | # Skip unneeded layers 31 | if ignore.search(var): 32 | continue 33 | 34 | var = var.replace("roberta.", "") 35 | var = var.replace("embeddings.weight", "embeddings.embeddings") 36 | var = var.replace("encoder.layer.", "encoder.layer_") 37 | var = var.replace("LayerNorm", "layer_norm"); 38 | 39 | # Attention weight matrices are transposed, compared to BERT. 40 | if kernel.search(var): 41 | tensor = tensor.t() 42 | 43 | # Finally, Rust VarStore replaces periods by vertical bars 44 | # during saving. 45 | var = var.replace(".", "|") 46 | 47 | print("Adding %s..." % var, file=sys.stderr) 48 | 49 | tensors[var] = tensor 50 | 51 | 52 | wrapper = TensorModule(tensors) 53 | script = torch.jit.script(wrapper) 54 | script.save(args.tensors) 55 | -------------------------------------------------------------------------------- /syntaxdot/src/optimizers/mod.rs: -------------------------------------------------------------------------------- 1 | use tch::nn::{self}; 2 | use tch::Tensor; 3 | 4 | mod grad; 5 | pub use grad::ZeroGrad; 6 | 7 | mod grad_scale; 8 | use crate::error::SyntaxDotError; 9 | pub use grad_scale::GradScaler; 10 | 11 | pub trait Optimizer { 12 | /// Perform a backward step on the given loss. 13 | fn backward_step(&mut self, loss: &Tensor) -> Result<(), SyntaxDotError>; 14 | 15 | /// Set the learning rate for a parameter group. 16 | fn set_lr_group(&mut self, group: usize, learning_rate: f64); 17 | 18 | /// Set the weight decay for a parameter group. 19 | fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64); 20 | 21 | /// Perform an update step. 22 | /// 23 | /// It is generally recommended to use `backward_step`, since it 24 | /// computes the gradients and performs any loss scaling (if 25 | /// necessary). 26 | fn step(&mut self); 27 | 28 | /// Get the trainable variables associated with the optimizer. 
29 | fn trainable_variables(&self) -> Vec<Tensor>; 30 | } 31 | 32 | impl Optimizer for nn::Optimizer { 33 | fn backward_step(&mut self, loss: &Tensor) -> Result<(), SyntaxDotError> { 34 | nn::Optimizer::backward_step(self, loss); 35 | Ok(()) 36 | } 37 | 38 | fn set_lr_group(&mut self, group: usize, learning_rate: f64) { 39 | nn::Optimizer::set_lr_group(self, group, learning_rate) 40 | } 41 | 42 | fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) { 43 | nn::Optimizer::set_weight_decay_group(self, group, weight_decay) 44 | } 45 | 46 | fn step(&mut self) { 47 | nn::Optimizer::step(self) 48 | } 49 | 50 | fn trainable_variables(&self) -> Vec<Tensor> { 51 | nn::Optimizer::trainable_variables(self) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/bert/config.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | 3 | use crate::activations::Activation; 4 | use crate::models::traits::WordEmbeddingsConfig; 5 | 6 | /// Bert model configuration. 7 | #[derive(Clone, Debug, Deserialize)] 8 | #[serde(default)] 9 | pub struct BertConfig { 10 | pub attention_probs_dropout_prob: f64, 11 | pub hidden_act: Activation, 12 | pub hidden_dropout_prob: f64, 13 | pub hidden_size: i64, 14 | pub initializer_range: f64, 15 | pub intermediate_size: i64, 16 | pub layer_norm_eps: f64, 17 | pub max_position_embeddings: i64, 18 | pub num_attention_heads: i64, 19 | pub num_hidden_layers: i64, 20 | pub type_vocab_size: i64, 21 | pub vocab_size: i64, 22 | } 23 | 24 | impl Default for BertConfig { 25 | fn default() -> Self { 26 | BertConfig { 27 | attention_probs_dropout_prob: 0.1, 28 | hidden_act: Activation::Gelu, 29 | hidden_dropout_prob: 0.1, 30 | hidden_size: 768, 31 | initializer_range: 0.02, 32 | intermediate_size: 3072, 33 | layer_norm_eps: 1e-12, 34 | max_position_embeddings: 512, 35 | num_attention_heads: 12, 36 | num_hidden_layers: 12, 37 | type_vocab_size: 2, 38 | vocab_size: 30000, 39 | } 40 | } 41 | } 42 | 43 | impl WordEmbeddingsConfig for BertConfig { 44 | fn dims(&self) -> i64 { 45 | self.hidden_size 46 | } 47 | 48 | fn dropout(&self) -> f64 { 49 | self.hidden_dropout_prob 50 | } 51 | 52 | fn initializer_range(&self) -> f64 { 53 | self.initializer_range 54 | } 55 | 56 | fn layer_norm_eps(&self) -> f64 { 57 | self.layer_norm_eps 58 | } 59 | 60 | fn vocab_size(&self) -> i64 { 61 | self.vocab_size 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | fmt: 7 | name: Rustfmt 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - uses: actions-rs/toolchain@v1 12 | with: 13 | profile: minimal 14 | toolchain: stable 15 | override: true 16 | components: rustfmt 17 | - uses: actions-rs/cargo@v1 18 | with: 19 | command: fmt 20 | args: --all -- --check 21 | 22 | clippy: 23 | name: Clippy 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v1 27 | - uses: actions/cache@v2 28 | with: 29 | path: | 30 | ~/.cargo/registry 31 | ~/.cargo/git 32 | target 33 | key: ${{ runner.os }}-cargo-clippy-${{ hashFiles('**/Cargo.lock') }} 34 | - uses: actions-rs/toolchain@v1 35 | with: 36 | profile: minimal 37 | toolchain: stable 38 | override: true 39 | components: clippy 40 | - uses: tensordot/libtorch-action@v2.1.0 41 | - uses:
actions-rs/cargo@v1 42 | with: 43 | command: clippy 44 | args: -- -D warnings 45 | 46 | test: 47 | name: Test 48 | runs-on: ubuntu-latest 49 | steps: 50 | - uses: actions/checkout@v1 51 | - uses: actions/cache@v2 52 | with: 53 | path: | 54 | ~/.cargo/registry 55 | ~/.cargo/git 56 | target 57 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 58 | - uses: actions-rs/toolchain@v1 59 | with: 60 | profile: minimal 61 | toolchain: stable 62 | override: true 63 | - uses: tensordot/libtorch-action@v2.1.0 64 | - uses: actions-rs/cargo@v1 65 | with: 66 | command: test 67 | -------------------------------------------------------------------------------- /syntaxdot/src/encoders/config.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Deref; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use syntaxdot_encoders::depseq::PosLayer; 5 | use syntaxdot_encoders::layer::Layer; 6 | use syntaxdot_encoders::lemma::BackoffStrategy; 7 | 8 | /// Configuration of a set of encoders. 9 | /// 10 | /// The configuration is a mapping from encoder name to 11 | /// encoder configuration. 12 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 13 | #[serde(deny_unknown_fields)] 14 | pub struct EncodersConfig(pub Vec); 15 | 16 | impl Deref for EncodersConfig { 17 | type Target = [NamedEncoderConfig]; 18 | 19 | fn deref(&self) -> &Self::Target { 20 | &self.0 21 | } 22 | } 23 | 24 | /// The type of encoder. 25 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 26 | #[serde(rename_all = "lowercase")] 27 | pub enum EncoderType { 28 | /// Encoder for syntactical dependencies. 29 | Dependency { 30 | encoder: DependencyEncoder, 31 | root_relation: String, 32 | }, 33 | 34 | /// Lemma encoder using edit trees. 35 | Lemma(BackoffStrategy), 36 | 37 | /// Encoder for plain sequence labels. 38 | Sequence(Layer), 39 | 40 | /// Lemma encoder using edit trees, with TüBa-D/Z-specific 41 | /// transformations. 42 | TdzLemma(BackoffStrategy), 43 | } 44 | 45 | /// The type of dependency encoder. 46 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 47 | #[serde(rename_all = "lowercase")] 48 | pub enum DependencyEncoder { 49 | /// Encode a token's head by relative position. 50 | RelativePosition, 51 | 52 | /// Encode a token's head by relative position of the POS tag. 53 | RelativePos(PosLayer), 54 | } 55 | 56 | /// Configuration of an encoder with a name. 57 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 58 | #[serde(deny_unknown_fields)] 59 | pub struct NamedEncoderConfig { 60 | pub encoder: EncoderType, 61 | pub name: String, 62 | } 63 | -------------------------------------------------------------------------------- /scripts/test-all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script runs all tests, including model tests. In order to do 4 | # so, models are downloaded and stored in $XDG_CACHE_HOME/syntaxdot. 5 | 6 | set -euo pipefail 7 | IFS=$'\n\t' 8 | 9 | if ! 
[ -x "$(command -v curl)" ] ; then 10 | >&2 echo "'curl' is required for downloading test data" 11 | exit 1 12 | fi 13 | 14 | cache_dir="${XDG_CACHE_HOME:-$HOME/.cache}/syntaxdot" 15 | 16 | declare -A models=( 17 | ["ALBERT_BASE_V2"]="https://f001.backblazeb2.com/file/danieldk-blob/syntaxdot/albert-base-v2.pt" 18 | ["BERT_BASE_GERMAN_CASED"]="https://f001.backblazeb2.com/file/danieldk-blob/syntaxdot/bert-base-german-cased.pt" 19 | ["SQUEEZEBERT_UNCASED"]="https://f001.backblazeb2.com/file/danieldk-blob/syntaxdot/squeezebert-base-uncased.pt" 20 | ["XLM_ROBERTA_BASE"]="https://f001.backblazeb2.com/file/danieldk-blob/syntaxdot/xlm-roberta-base.pt" 21 | 22 | ["ALBERT_BASE_V2_SENTENCEPIECE"]="https://huggingface.co/albert-base-v2/resolve/main/spiece.model" 23 | ["BERT_BASE_GERMAN_CASED_VOCAB"]="https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt" 24 | ["XLM_ROBERTA_BASE_SENTENCEPIECE"]="https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model" 25 | ) 26 | 27 | if [ ! -d "$cache_dir" ]; then 28 | mkdir -p "$cache_dir" 29 | fi 30 | 31 | for var in "${!models[@]}"; do 32 | url="${models[$var]}" 33 | data="${cache_dir}/$(basename "${url}")" 34 | 35 | # Since these checkpoints were generated, an assumption was added that 36 | # .pt files are created from Python code. Rename to .ot to avoid loading 37 | # issues. 38 | data=${data/%.pt/.ot} 39 | 40 | if [ ! -e "${data}" ]; then 41 | curl -fo "${data}" "${url}" 42 | fi 43 | 44 | declare -x "${var}"="${data}" 45 | done 46 | 47 | # Regular tests for all crates 48 | cargo test 49 | 50 | # Regular tests + model tests for transformers and tokenizers 51 | ( cd syntaxdot-tokenizers ; cargo test --features model-tests ) 52 | ( cd syntaxdot-transformers ; cargo test --features model-tests ) 53 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Label encoders. 2 | 3 | use std::error::Error; 4 | 5 | use udgraph::graph::Sentence; 6 | 7 | pub mod categorical; 8 | 9 | pub mod dependency; 10 | 11 | pub mod depseq; 12 | 13 | pub mod layer; 14 | 15 | pub mod lang; 16 | 17 | pub mod lemma; 18 | 19 | /// An encoding with its probability. 20 | #[derive(Debug)] 21 | pub struct EncodingProb { 22 | encoding: E, 23 | prob: f32, 24 | } 25 | 26 | impl EncodingProb 27 | where 28 | E: ToOwned, 29 | { 30 | /// Create an encoding with its probability. 31 | /// 32 | /// This constructor takes an owned encoding. 33 | pub fn new(encoding: E, prob: f32) -> Self { 34 | EncodingProb { encoding, prob } 35 | } 36 | 37 | /// Get the encoding. 38 | pub fn encoding(&self) -> &E { 39 | &self.encoding 40 | } 41 | 42 | /// Get the probability of the encoding. 43 | pub fn prob(&self) -> f32 { 44 | self.prob 45 | } 46 | } 47 | 48 | impl From> for (String, f32) 49 | where 50 | E: Clone + ToString, 51 | { 52 | fn from(prob: EncodingProb) -> Self { 53 | (prob.encoding().to_string(), prob.prob()) 54 | } 55 | } 56 | 57 | /// Trait for sentence decoders. 58 | /// 59 | /// A sentence decoder adds a representation to each token in a 60 | /// sentence, such as a part-of-speech tag or a topological field. 61 | pub trait SentenceDecoder { 62 | type Encoding: ToOwned; 63 | 64 | /// The decoding error type. 65 | type Error: Error; 66 | 67 | fn decode(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error> 68 | where 69 | S: AsRef<[EncodingProb]>; 70 | } 71 | 72 | /// Trait for sentence encoders. 
73 | /// 74 | /// A sentence encoder extracts a representation of each token in a 75 | /// sentence, such as a part-of-speech tag or a topological field. 76 | pub trait SentenceEncoder { 77 | type Encoding; 78 | 79 | /// The encoding error type. 80 | type Error: Error; 81 | 82 | /// Encode the given sentence. 83 | fn encode(&self, sentence: &Sentence) -> Result, Self::Error>; 84 | } 85 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/constants.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | 3 | use lazy_static::lazy_static; 4 | use maplit::hashset; 5 | 6 | pub(crate) static REFLEXIVE_PERSONAL_PRONOUN_LEMMA: &str = "#refl"; 7 | 8 | pub(crate) static SEPARABLE_PARTICLE_POS: &str = "PTKVZ"; 9 | 10 | pub(crate) static PUNCTUATION_PREFIX: &str = "$"; 11 | 12 | pub(crate) static ARTICLE_TAG: &str = "ART"; 13 | pub(crate) static ATTRIBUTIVE_POSSESIVE_PRONOUN_TAG: &str = "PPOSAT"; 14 | pub(crate) static SUBST_POSSESIVE_PRONOUN_TAG: &str = "PPOSS"; 15 | pub(crate) static FOREIGN_WORD_TAG: &str = "FM"; 16 | pub(crate) static NAMED_ENTITY_TAG: &str = "NE"; 17 | pub(crate) static NON_WORD_TAG: &str = "XY"; 18 | pub(crate) static NOUN_TAG: &str = "NN"; 19 | pub(crate) static PERSONAL_PRONOUN_TAG: &str = "PPER"; 20 | pub(crate) static REFLEXIVE_PERSONAL_PRONOUN_TAG: &str = "PRF"; 21 | pub(crate) static SUBST_REL_PRONOUN: &str = "PRELS"; 22 | pub(crate) static ATTR_REL_PRONOUN: &str = "PRELAT"; 23 | pub(crate) static TRUNCATED_TAG: &str = "TRUNC"; 24 | pub(crate) static ZU_INFINITIVE_VERB: &str = "VVIZU"; 25 | 26 | pub(crate) static SUBSTITUTING_INDEF_PRONOUN: &str = "PIS"; 27 | pub(crate) static ATTRIBUTING_INDEF_PRONOUN_WITHOUT_DET: &str = "PIAT"; 28 | pub(crate) static ATTRIBUTING_INDEF_PRONOUN_WITH_DET: &str = "PIDAT"; 29 | 30 | lazy_static! { 31 | pub(crate) static ref LEMMA_IS_FORM_TAGS: HashSet<&'static str> = hashset! { 32 | "$,", 33 | "$.", 34 | "$(", 35 | "ADV", 36 | "APPR", 37 | "APPO", 38 | "APZR", 39 | "ITJ", 40 | "KOUI", 41 | "KOUS", 42 | "KON", 43 | "KOKOM", 44 | "ADJD", 45 | "CARD", 46 | "PTKZU", 47 | "PTKA", 48 | "PTKNEG", 49 | }; 50 | pub(crate) static ref LEMMA_IS_FORM_PRESERVE_CASE_TAGS: HashSet<&'static str> = hashset! 
{ 51 | FOREIGN_WORD_TAG, 52 | }; 53 | } 54 | 55 | pub(crate) fn is_verb(tag: S) -> bool 56 | where 57 | S: AsRef, 58 | { 59 | tag.as_ref().starts_with('V') 60 | } 61 | 62 | pub(crate) fn is_separable_verb(tag: S) -> bool 63 | where 64 | S: AsRef, 65 | { 66 | let tag = tag.as_ref(); 67 | tag == "VVFIN" || tag == "VVPP" || tag == "VVIMP" || tag == "VMFIN" || tag == "VAFIN" 68 | } 69 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/traits.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use clap::{ArgMatches, Command}; 3 | use syntaxdot::optimizers::{GradScaler, Optimizer}; 4 | use tch::nn::{adamw, Optimizer as TchOptimizer, OptimizerConfig, VarStore}; 5 | 6 | pub enum ParameterGroup { 7 | Encoder = 0, 8 | Classifier = 1, 9 | EncoderNoWeightDecay = 2, 10 | ClassifierNoWeightDecay = 3, 11 | } 12 | 13 | pub trait SyntaxDotApp 14 | where 15 | Self: Sized, 16 | { 17 | fn app() -> Command; 18 | 19 | fn parse(matches: &ArgMatches) -> Result; 20 | 21 | fn run(&self) -> Result<()>; 22 | } 23 | 24 | pub trait SyntaxDotTrainApp: SyntaxDotApp { 25 | fn build_parameter_group_fun() -> fn(&str) -> usize { 26 | |name: &str| { 27 | if name.starts_with("classifiers") || name.starts_with("biaffine") { 28 | if name.contains("layer_norm") || name.contains("bias") { 29 | ParameterGroup::ClassifierNoWeightDecay as usize 30 | } else { 31 | ParameterGroup::Classifier as usize 32 | } 33 | } else if name.starts_with("encoder") || name.starts_with("embeddings") { 34 | if name.contains("layer_norm") || name.contains("bias") { 35 | ParameterGroup::EncoderNoWeightDecay as usize 36 | } else { 37 | ParameterGroup::Encoder as usize 38 | } 39 | } else { 40 | unreachable!(); 41 | } 42 | } 43 | } 44 | 45 | fn build_optimizer(&self, var_store: &VarStore) -> Result> { 46 | let opt = adamw(0.9, 0.999, self.weight_decay()).build(var_store, 1e-3)?; 47 | let mut grad_scaler = GradScaler::new_with_defaults(self.mixed_precision(), opt)?; 48 | grad_scaler.set_weight_decay_group(ParameterGroup::EncoderNoWeightDecay as usize, 0.); 49 | grad_scaler.set_weight_decay_group(ParameterGroup::ClassifierNoWeightDecay as usize, 0.); 50 | Ok(grad_scaler) 51 | } 52 | 53 | fn mixed_precision(&self) -> bool; 54 | 55 | fn weight_decay(&self) -> f64; 56 | } 57 | 58 | pub trait SyntaxDotOption { 59 | type Value; 60 | 61 | fn add_to_app(app: Command) -> Command; 62 | 63 | fn parse(matches: &ArgMatches) -> Result; 64 | } 65 | -------------------------------------------------------------------------------- /syntaxdot/src/dataset/mod.rs: -------------------------------------------------------------------------------- 1 | //! Iterators over data sets. 2 | 3 | use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize}; 4 | 5 | use crate::error::SyntaxDotError; 6 | 7 | mod conll; 8 | pub use conll::ConlluDataSet; 9 | 10 | mod plaintext; 11 | pub use plaintext::PlainTextDataSet; 12 | 13 | pub(crate) mod tensor_iter; 14 | pub use tensor_iter::BatchedTensors; 15 | 16 | mod sentence_itertools; 17 | pub use sentence_itertools::{SentenceIterTools, SequenceLength}; 18 | 19 | /// A data set consisting of annotated or unannotated sentences. 20 | /// 21 | /// A `DataSet` provides an iterator over the sentences (and their 22 | /// pieces) in a data set. 23 | pub trait DataSet<'a> { 24 | type Iter: Iterator>; 25 | 26 | /// Get an iterator over the sentences and pieces in a dataset. 27 | /// 28 | /// The tokens are split in pieces with the given `tokenizer`. 
29 | fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result<Self::Iter, SyntaxDotError>; 30 | } 31 | 32 | #[cfg(test)] 33 | pub(crate) mod tests { 34 | use std::io::{BufReader, Cursor}; 35 | 36 | use lazy_static::lazy_static; 37 | use ndarray::{array, Array1}; 38 | use syntaxdot_tokenizers::{BertTokenizer, SentenceWithPieces, Tokenize}; 39 | use wordpieces::WordPieces; 40 | 41 | use crate::dataset::DataSet; 42 | use crate::error::SyntaxDotError; 43 | 44 | const PIECES: &str = r#"[CLS] 45 | [UNK] 46 | Dit 47 | is 48 | de 49 | eerste 50 | zin 51 | . 52 | tweede 53 | laatste 54 | nu"#; 55 | 56 | lazy_static! { 57 | pub static ref CORRECT_PIECE_IDS: Vec<Array1<i64>> = vec![ 58 | array![0, 2, 3, 4, 5, 6, 7], 59 | array![0, 2, 4, 8, 6, 7], 60 | array![0, 1, 10, 4, 9, 6, 7] 61 | ]; 62 | } 63 | 64 | pub fn dataset_to_pieces<'a, D, I>( 65 | dataset: D, 66 | tokenizer: &'a dyn Tokenize, 67 | ) -> Result<Vec<Array1<i64>>, SyntaxDotError> 68 | where 69 | D: DataSet<'a, Iter = I>, 70 | I: Iterator<Item = Result<SentenceWithPieces, SyntaxDotError>>, 71 | { 72 | dataset 73 | .sentences(tokenizer)? 74 | .map(|s| s.map(|s| s.pieces)) 75 | .collect::<Result<Vec<_>, _>>() 76 | } 77 | 78 | pub fn wordpiece_tokenizer() -> BertTokenizer { 79 | let pieces = WordPieces::from_buf_read(BufReader::new(Cursor::new(PIECES))).unwrap(); 80 | BertTokenizer::new(pieces, "[UNK]") 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/categorical/number.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::hash::Hash; 3 | 4 | use numberer::Numberer; 5 | use serde_derive::{Deserialize, Serialize}; 6 | 7 | /// Number a categorical variable. 8 | #[allow(clippy::len_without_is_empty)] 9 | pub trait Number<V> 10 | where 11 | V: Clone + Eq + Hash, 12 | { 13 | /// Construct a numberer for categorical variables. 14 | fn new(numberer: Numberer<V>) -> Self; 15 | 16 | /// Get the number of possible values in the categorical variable. 17 | /// 18 | /// This includes reserved numerical representations that do 19 | /// not correspond to values in the categorical variable. 20 | fn len(&self) -> usize; 21 | 22 | /// Get the number of a value from a categorical variable. 23 | /// 24 | /// Mutable implementations of this trait must add the value if it 25 | /// is unknown and always return [`Option::Some`]. 26 | fn number(&self, value: V) -> Option<usize>; 27 | 28 | /// Get the value corresponding to a number. 29 | /// 30 | /// Returns [`Option::None`] if the number is unknown *or* a 31 | /// reserved number. 32 | fn value(&self, number: usize) -> Option<V>; 33 | } 34 | 35 | /// An immutable categorical variable numberer. 36 | #[derive(Deserialize, Serialize)] 37 | pub struct ImmutableNumberer<V>(Numberer<V>) 38 | where 39 | V: Clone + Eq + Hash; 40 | 41 | impl<V> Number<V> for ImmutableNumberer<V> 42 | where 43 | V: Clone + Eq + Hash, 44 | { 45 | fn new(numberer: Numberer<V>) -> Self { 46 | ImmutableNumberer(numberer) 47 | } 48 | 49 | fn len(&self) -> usize { 50 | self.0.len() 51 | } 52 | 53 | fn number(&self, value: V) -> Option<usize> { 54 | self.0.number(&value) 55 | } 56 | 57 | fn value(&self, number: usize) -> Option<V> { 58 | self.0.value(number).cloned() 59 | } 60 | } 61 | 62 | /// A mutable categorical variable numberer using interior mutability.
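///
/// Unlike [`ImmutableNumberer`], the mutable variant assigns a fresh number
/// to values it has not seen before, so `number` always returns `Some`. A
/// small sketch (marked `ignore`; it assumes that `Numberer::new(0)`
/// constructs an empty numberer that starts counting at 0):
///
/// ```ignore
/// let numberer: MutableNumberer<String> = MutableNumberer::new(Numberer::new(0));
///
/// // Unknown values are added on the fly.
/// assert_eq!(numberer.number("NN".to_string()), Some(0));
/// assert_eq!(numberer.number("VVFIN".to_string()), Some(1));
///
/// // Known values map back to the same number.
/// assert_eq!(numberer.number("NN".to_string()), Some(0));
/// assert_eq!(numberer.value(1), Some("VVFIN".to_string()));
/// ```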
63 | #[derive(Deserialize, Serialize)] 64 | pub struct MutableNumberer<V>(RefCell<Numberer<V>>) 65 | where 66 | V: Clone + Eq + Hash; 67 | 68 | impl<V> Number<V> for MutableNumberer<V> 69 | where 70 | V: Clone + Eq + Hash, 71 | { 72 | fn new(numberer: Numberer<V>) -> Self { 73 | MutableNumberer(RefCell::new(numberer)) 74 | } 75 | 76 | fn len(&self) -> usize { 77 | self.0.borrow().len() 78 | } 79 | 80 | fn number(&self, value: V) -> Option<usize> { 81 | Some(self.0.borrow_mut().add(value)) 82 | } 83 | 84 | fn value(&self, number: usize) -> Option<V> { 85 | self.0.borrow().value(number).cloned() 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::io::stdout; 2 | 3 | use anyhow::Result; 4 | use clap::{builder::EnumValueParser, crate_version, Arg, Command}; 5 | use clap_complete::{generate, Shell}; 6 | 7 | pub mod io; 8 | 9 | pub mod progress; 10 | 11 | pub mod save; 12 | 13 | pub mod sent_proc; 14 | 15 | pub mod summary; 16 | 17 | mod subcommands; 18 | 19 | pub mod traits; 20 | use traits::SyntaxDotApp; 21 | 22 | pub mod util; 23 | 24 | fn main() -> Result<()> { 25 | // Known subapplications. 26 | let apps = vec![ 27 | subcommands::AnnotateApp::app(), 28 | subcommands::DistillApp::app(), 29 | subcommands::FilterLenApp::app(), 30 | subcommands::FinetuneApp::app(), 31 | subcommands::PrepareApp::app(), 32 | ]; 33 | 34 | env_logger::init(); 35 | 36 | let cli = Command::new("syntaxdot") 37 | .arg_required_else_help(true) 38 | .about("A neural sequence labeler") 39 | .version(crate_version!()) 40 | .subcommands(apps) 41 | .subcommand( 42 | Command::new("completions") 43 | .about("Generate completion scripts for your shell") 44 | .arg_required_else_help(true) 45 | .arg(Arg::new("shell").value_parser(EnumValueParser::<Shell>::new())), 46 | ); 47 | let matches = cli.clone().get_matches(); 48 | 49 | match matches.subcommand_name().unwrap() { 50 | "annotate" => { 51 | subcommands::AnnotateApp::parse(matches.subcommand_matches("annotate").unwrap())?.run() 52 | } 53 | "completions" => { 54 | let shell = matches 55 | .subcommand_matches("completions") 56 | .unwrap() 57 | .get_one::<Shell>("shell") 58 | .unwrap(); 59 | write_completion_script(cli, *shell); 60 | Ok(()) 61 | } 62 | "distill" => { 63 | subcommands::DistillApp::parse(matches.subcommand_matches("distill").unwrap())?.run() 64 | } 65 | "finetune" => { 66 | subcommands::FinetuneApp::parse(matches.subcommand_matches("finetune").unwrap())?.run() 67 | } 68 | "filter-len" => { 69 | subcommands::FilterLenApp::parse(matches.subcommand_matches("filter-len").unwrap())?
70 | .run() 71 | } 72 | "prepare" => { 73 | subcommands::PrepareApp::parse(matches.subcommand_matches("prepare").unwrap())?.run() 74 | } 75 | _unknown => unreachable!(), 76 | } 77 | } 78 | 79 | fn write_completion_script(mut cli: Command, shell: Shell) { 80 | generate(shell, &mut cli, "syntaxdot", &mut stdout()); 81 | } 82 | -------------------------------------------------------------------------------- /scripts/pytorch-bert-to-syntaxdot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | import torch 9 | 10 | from tensor_module import TensorModule 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Convert PyTorch BERT checkpoint to SyntaxDot tensors.') 14 | parser.add_argument( 15 | 'model', 16 | metavar='MODEL', 17 | help='The model path') 18 | parser.add_argument('tensors', metavar='TENSORS', help='SyntaxDot tensors') 19 | parser.add_argument('--albert', action='store_true', default=False, help="Convert an ALBERT model") 20 | 21 | if __name__ == "__main__": 22 | args = parser.parse_args() 23 | 24 | model = torch.load(args.model) 25 | 26 | tensors = {} 27 | 28 | ignore = re.compile("adam_v|adam_m|global_step|cls|pooler|position_ids") 29 | self_attention = re.compile("attention\.(key|query|value)") 30 | for var, tensor in model.items(): 31 | # Skip unneeded layers 32 | if ignore.search(var): 33 | continue 34 | 35 | # Remove prefix 36 | if args.albert: 37 | var = var.replace("albert.", "") 38 | var = var.replace("albert_", "") 39 | else: 40 | var = var.replace("bert.", "") 41 | 42 | # Rewrite some variable names 43 | var = var.replace("embeddings.weight", "embeddings.embeddings") 44 | var = var.replace("kernel", "weight") 45 | var = var.replace("gamma", "weight") 46 | var = var.replace("beta", "bias") 47 | var = var.replace("LayerNorm", "layer_norm") 48 | var = var.replace("layer.", "layer_") 49 | 50 | if args.albert: 51 | var = var.replace("layer_groups.", "group_") 52 | var = var.replace("layers.", "inner_group_") 53 | var = var.replace("embedding_hidden_mapping_in", "embedding_projection") 54 | var = self_attention.sub(r"attention.self.\1", var) 55 | var = var.replace("attention.dense", "attention.output.dense") 56 | var = var.replace("attention.layer_norm", "attention.output.layer_norm") 57 | var = var.replace("ffn.", "intermediate.dense.") 58 | var = var.replace("ffn_output", "output.dense") 59 | var = var.replace("full_layer_layer_norm", "output.layer_norm") 60 | 61 | # Finally, Rust VarStore replaces periods by vertical bars 62 | # during saving. 63 | var = var.replace(".", "|") 64 | 65 | print("Adding %s..." % var, file=sys.stderr) 66 | 67 | tensors[var] = tensor 68 | 69 | wrapper = TensorModule(tensors) 70 | script = torch.jit.script(wrapper) 71 | script.save(args.tensors) 72 | -------------------------------------------------------------------------------- /doc/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | The following ready-to-use models are available. For each language, 4 | treebanks were shuffled and split 7/1/2 in train/development/held-out 5 | partitions. Reported accuracies are from evaluation on held-out data. 6 | 7 | Performance in sentences per second is measured using: 8 | 9 | * CPU: Ryzen 3700X CPU, using 4 threads. 
10 | * GPU: NVIDIA RTX 2060 11 | 12 | Models can be used by unpacking the model archive and running 13 | 14 | ```bash 15 | $ syntaxdot annotate model-name/syntaxdot.conf 16 | ``` 17 | 18 | ## Dutch 19 | 20 | | Model | UD POS | Lemma | UD morph | LAS | Size (MiB) | CPU sent/s | GPU sents/s | 21 | |:----------------------------------------------------------------------------------------------------------|-------:|------:|---------:|------:|-----------:|-----------:|------------:| 22 | | [Finetuned XLM-R base](https://s3.tensordot.com/syntaxdot/models/nl-ud-huge-20210301.tar.gz) | 98.90 | 99.03 | 98.87 | 94.37 | 1087 | 61 | 755 | 23 | | [Distilled, 12 layers, 384 hidden](https://s3.tensordot.com/syntaxdot/models/nl-ud-large-20210324.tar.gz) | 98.83 | 99.03 | 98.80 | 93.91 | 200 | 135 | 1450 | 24 | | [Distilled, 6 layers, 384 hidden](https://s3.tensordot.com/syntaxdot/models/nl-ud-medium-20210312.tar.gz) | 98.80 | 99.05 | 98.79 | 93.42 | 133 | 240 | 2359 | 25 | 26 | ## German 27 | 28 | | Model | UD POS | STTS POS | Lemma | UD morph | TDZ morph | LAS | Topo Field | Size (MiB) | CPU sent/s | GPU sent/s | 29 | |:----------------------------------------------------------------------------------------------------------|-------:|---------:|------:|---------:|----------:|------:|-----------:|-----------:|-----------:|-----------:| 30 | | [Finetuned XLM-R base](https://github.com/tensordot/syntaxdot-models/releases/download/de-ud-2021/de-ud-huge-20210307.tar.gz) | 99.54 | 99.48 | 99.34 | 98.38 | 98.43 | 96.59 | 98.17 | 1087 | 45 | 614 | 31 | | [Distilled, 12 layers, 384 hidden](https://github.com/tensordot/syntaxdot-models/releases/download/de-ud-2021/de-ud-large-20210326.tar.gz) | 99.50 | 99.44 | 99.31 | 98.31 | 98.36 | 96.17 | 98.12 | 208 | 105 | 1131 | 32 | | [Distilled, 6 layers, 384 hidden](https://github.com/tensordot/syntaxdot-models/releases/download/de-ud-2021/de-ud-medium-20210326.tar.gz) | 99.46 | 99.40 | 99.29 | 98.20 | 98.27 | 95.48 | 97.97 | 140 | 180 | 1748 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/plot-layer-weights: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import csv 5 | 6 | import torch 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | matplotlib.use("svg") 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description="Plot layer weights.") 15 | parser.add_argument("model", metavar="MODEL", help="model") 16 | parser.add_argument("output", metavar="OUTPUT", help="output image") 17 | parser.add_argument( 18 | "-c", 19 | "--cumulative", 20 | action="store_true", 21 | help="plot cumulative weights", 22 | ) 23 | parser.add_argument( 24 | "-o", "--csv", default="none", help="write layer weights as CSV" 25 | ) 26 | parser.add_argument( 27 | "-n", 28 | "--normalized", 29 | action="store_true", 30 | help="plot normalized layer weights", 31 | ) 32 | parser.add_argument( 33 | "-f", "--format", default="svg", help="format as matplotlib backend" 34 | ) 35 | parser.add_argument("-t", "--title", default=None, help="plot title") 36 | 37 | args = parser.parse_args() 38 | 39 | model = torch.jit.load(args.model, map_location="cpu") 40 | 41 | parameter_names = [] 42 | for name, _ in model.named_parameters(): 43 | if "layer_weights" in name: 44 | parameter_names.append(name) 45 | 46 | parameter_names.sort() 47 | 48 | names = [] 49 | weights = [] 50 | for parameter_name in 
parameter_names: 51 | if "_classifier" in parameter_name: 52 | # Sequence classifier 53 | layer = parameter_name.split("|")[1][: -len("_classifier")] 54 | elif parameter_name.startswith("biaffine|"): 55 | layer = "biaffine" 56 | else: 57 | raise NotImplementedError( 58 | f"Unknown parameter type: {parameter_name}" 59 | ) 60 | 61 | names.append(layer) 62 | 63 | tensor = getattr(model, parameter_name) 64 | tensor = tensor.softmax(-1) if args.normalized else tensor 65 | tensor = tensor.cumsum(-1) if args.cumulative else tensor 66 | weights.append(tensor.detach().numpy()) 67 | 68 | plt.plot(np.array(weights).transpose()) 69 | plt.legend(names) 70 | if args.title: 71 | plt.title(args.title) 72 | plt.xlabel("Layer") 73 | plt.ylabel("Layer weight") 74 | plt.savefig(args.output, format=args.format) 75 | 76 | if args.csv: 77 | with open(args.csv, "w", newline="") as csvfile: 78 | layer_writer = csv.writer(csvfile) 79 | for (name, weights) in zip(names, weights): 80 | layer_writer.writerow([name] + list(weights)) 81 | -------------------------------------------------------------------------------- /syntaxdot/src/dataset/conll.rs: -------------------------------------------------------------------------------- 1 | use std::io::{BufRead, Seek, SeekFrom}; 2 | 3 | use conllu::io::{ReadSentence, Reader, Sentences}; 4 | use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize}; 5 | 6 | use crate::dataset::DataSet; 7 | use crate::error::SyntaxDotError; 8 | 9 | /// A CoNLL-X data set. 10 | pub struct ConlluDataSet(R); 11 | 12 | impl ConlluDataSet { 13 | /// Construct a CoNLL-X dataset. 14 | pub fn new(read: R) -> Self { 15 | ConlluDataSet(read) 16 | } 17 | } 18 | 19 | impl<'a, R> DataSet<'a> for &'a mut ConlluDataSet 20 | where 21 | R: BufRead + Seek, 22 | { 23 | type Iter = ConllIter<'a, Reader<&'a mut R>>; 24 | 25 | fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result { 26 | // Rewind to the beginning of the dataset (if necessary). 27 | self.0.seek(SeekFrom::Start(0))?; 28 | 29 | let reader = Reader::new(&mut self.0); 30 | 31 | Ok(ConllIter { 32 | sentences: reader.sentences(), 33 | tokenizer, 34 | }) 35 | } 36 | } 37 | 38 | pub struct ConllIter<'a, R> 39 | where 40 | R: ReadSentence, 41 | { 42 | sentences: Sentences, 43 | tokenizer: &'a dyn Tokenize, 44 | } 45 | 46 | impl<'a, R> Iterator for ConllIter<'a, R> 47 | where 48 | R: ReadSentence, 49 | { 50 | type Item = Result; 51 | 52 | fn next(&mut self) -> Option { 53 | self.sentences.next().map(|s| { 54 | s.map(|s| self.tokenizer.tokenize(s)) 55 | .map_err(SyntaxDotError::ConlluError) 56 | }) 57 | } 58 | } 59 | 60 | #[cfg(test)] 61 | mod tests { 62 | use std::io::{BufReader, Cursor}; 63 | 64 | use crate::dataset::tests::{dataset_to_pieces, wordpiece_tokenizer, CORRECT_PIECE_IDS}; 65 | use crate::dataset::ConlluDataSet; 66 | 67 | const SENTENCES: &str = r#" 68 | 1 Dit 69 | 2 is 70 | 3 de 71 | 4 eerste 72 | 5 zin 73 | 6 . 74 | 75 | 1 Dit 76 | 2 de 77 | 3 tweede 78 | 4 zin 79 | 5 . 80 | 81 | 1 En 82 | 2 nu 83 | 3 de 84 | 4 laatste 85 | 5 zin 86 | 6 ."#; 87 | 88 | #[test] 89 | fn plain_text_dataset_works() { 90 | let tokenizer = wordpiece_tokenizer(); 91 | let mut cursor = Cursor::new(SENTENCES); 92 | let mut dataset = ConlluDataSet::new(BufReader::new(&mut cursor)); 93 | 94 | let pieces = dataset_to_pieces(&mut dataset, &tokenizer).unwrap(); 95 | assert_eq!(pieces, *CORRECT_PIECE_IDS); 96 | 97 | // Verify that the data set is correctly read again. 
98 | let more_pieces = dataset_to_pieces(&mut dataset, &tokenizer).unwrap(); 99 | assert_eq!(more_pieces, *CORRECT_PIECE_IDS); 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | create-release: 10 | runs-on: ubuntu-latest 11 | outputs: 12 | upload_url: ${{ steps.create_release.outputs.upload_url }} 13 | steps: 14 | - uses: actions/create-release@v1.0.0 15 | id: create_release 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | with: 19 | tag_name: ${{ github.ref }} 20 | release_name: Release ${{ github.ref }} 21 | draft: true 22 | prerelease: false 23 | 24 | build_release: 25 | strategy: 26 | matrix: 27 | device: [cpu, cuda] 28 | needs: ['create-release'] 29 | runs-on: ubuntu-20.04 30 | steps: 31 | - uses: actions/checkout@v2 32 | - name: Get release version 33 | run: | 34 | echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV 35 | echo "tag: ${{ env.TAG_NAME }}" 36 | - name: Install dependencies 37 | run: sudo apt-get install build-essential bzip2 cmake pkg-config libssl-dev 38 | # Patchelf in Ubuntu 16.04 cannot patch the syntaxdot binary. 39 | - name: Build patchelf 40 | run: | 41 | wget "https://github.com/NixOS/patchelf/releases/download/0.12/patchelf-0.12.tar.bz2" 42 | echo "699a31cf52211cf5ad6e35a8801eb637bc7f3c43117140426400d67b7babd792 patchelf-0.12.tar.bz2" | sha256sum -c - 43 | tar jxf patchelf-0.12.tar.bz2 44 | ( cd patchelf-0.12.20200827.8d3a16e && ./configure && make -j4 ) 45 | - uses: tensordot/libtorch-action@v2.1.0 46 | with: 47 | device: ${{matrix.device}} 48 | - uses: actions-rs/toolchain@v1 49 | with: 50 | profile: minimal 51 | toolchain: stable 52 | override: true 53 | - uses: actions-rs/cargo@v1 54 | with: 55 | command: build 56 | args: --release 57 | - name: Create release archive 58 | id: create_archive 59 | run: | 60 | DIST=syntaxdot-${{ env.TAG_NAME }}-${{ matrix.device }}-x86_64-linux-gnu-gcc 61 | ARCHIVE=${DIST}.tar.zst 62 | install -Dm755 -t ${DIST} target/release/syntaxdot 63 | install -Dm755 -t ${DIST} ${LIBTORCH}/lib/*.so* 64 | patchelf-0.12.20200827.8d3a16e/src/patchelf --set-rpath '$ORIGIN' ${DIST}/*.so* ${DIST}/syntaxdot 65 | tar --zstd -cvf ${ARCHIVE} ${DIST} 66 | echo ::set-output name=ASSET::$ARCHIVE 67 | - uses: actions/upload-release-asset@v1.0.1 68 | env: 69 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 70 | with: 71 | upload_url: ${{ needs.create-release.outputs.upload_url }} 72 | asset_path: ${{ steps.create_archive.outputs.ASSET }} 73 | asset_name: ${{ steps.create_archive.outputs.ASSET }} 74 | asset_content_type: application/zstd 75 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/squeeze_bert/config.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | 3 | use crate::activations::Activation; 4 | use crate::models::bert::BertConfig; 5 | 6 | /// SqueezeBert model configuration.
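///
/// The struct derives `Deserialize` with `#[serde(default)]`, so a
/// configuration can be read from JSON (e.g. a checkpoint's `config.json`)
/// and any missing field falls back to the default value below. A minimal
/// sketch (marked `ignore`; it assumes that `serde_json` is available):
///
/// ```ignore
/// // Fields that are omitted take their default values.
/// let json = r#"{"hidden_size": 768, "num_hidden_layers": 12}"#;
/// let config: SqueezeBertConfig = serde_json::from_str(json).unwrap();
/// assert_eq!(config.vocab_size, 30528);
/// ```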
7 | #[derive(Debug, Deserialize)] 8 | #[serde(default)] 9 | pub struct SqueezeBertConfig { 10 | pub attention_probs_dropout_prob: f64, 11 | pub embedding_size: i64, 12 | pub hidden_act: Activation, 13 | pub hidden_dropout_prob: f64, 14 | pub hidden_size: i64, 15 | pub initializer_range: f64, 16 | pub intermediate_size: i64, 17 | pub layer_norm_eps: f64, 18 | pub max_position_embeddings: i64, 19 | pub num_attention_heads: i64, 20 | pub num_hidden_layers: i64, 21 | pub type_vocab_size: i64, 22 | pub vocab_size: i64, 23 | pub q_groups: i64, 24 | pub k_groups: i64, 25 | pub v_groups: i64, 26 | pub post_attention_groups: i64, 27 | pub intermediate_groups: i64, 28 | pub output_groups: i64, 29 | } 30 | 31 | impl Default for SqueezeBertConfig { 32 | fn default() -> Self { 33 | SqueezeBertConfig { 34 | attention_probs_dropout_prob: 0.1, 35 | embedding_size: 768, 36 | hidden_act: Activation::Gelu, 37 | hidden_dropout_prob: 0.1, 38 | hidden_size: 768, 39 | initializer_range: 0.02, 40 | intermediate_size: 3072, 41 | layer_norm_eps: 1e-12, 42 | max_position_embeddings: 512, 43 | num_attention_heads: 12, 44 | num_hidden_layers: 12, 45 | type_vocab_size: 2, 46 | vocab_size: 30528, 47 | q_groups: 4, 48 | k_groups: 4, 49 | v_groups: 4, 50 | post_attention_groups: 1, 51 | intermediate_groups: 4, 52 | output_groups: 4, 53 | } 54 | } 55 | } 56 | 57 | impl From<&SqueezeBertConfig> for BertConfig { 58 | fn from(squeeze_bert_config: &SqueezeBertConfig) -> Self { 59 | BertConfig { 60 | attention_probs_dropout_prob: squeeze_bert_config.attention_probs_dropout_prob, 61 | hidden_act: squeeze_bert_config.hidden_act, 62 | hidden_dropout_prob: squeeze_bert_config.hidden_dropout_prob, 63 | hidden_size: squeeze_bert_config.hidden_size, 64 | initializer_range: squeeze_bert_config.initializer_range, 65 | intermediate_size: squeeze_bert_config.intermediate_size, 66 | layer_norm_eps: squeeze_bert_config.layer_norm_eps, 67 | max_position_embeddings: squeeze_bert_config.max_position_embeddings, 68 | num_attention_heads: squeeze_bert_config.num_attention_heads, 69 | num_hidden_layers: squeeze_bert_config.num_hidden_layers, 70 | type_vocab_size: squeeze_bert_config.type_vocab_size, 71 | vocab_size: squeeze_bert_config.vocab_size, 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /doc/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Dependencies 4 | 5 | SyntaxDot has the following base requirements: 6 | 7 | * A [Rust Toolchain](https://rustup.rs/) 8 | * A C++ compiler 9 | * libtorch 10 | * cmake 11 | * OpenSSL 12 | * pkg-config 13 | 14 | Additionally, compiling a SyntaxDot with training functionality requires: 15 | 16 | * CUDA 17 | 18 | ## Linux/macOS 19 | 20 | ### Dependencies 21 | 22 | #### macOS 23 | 24 | Install `cmake`, for instance through Homebrew: 25 | 26 | ```shell 27 | brew install cmake 28 | ``` 29 | 30 | #### Fedora 31 | 32 | Most of the dependencies can be installed in Fedora using the following 33 | command: 34 | 35 | ```shell 36 | $ sudo dnf install -y cmake gcc-c++ openssl-devel pkg-config 37 | ``` 38 | 39 | Follow the RPM Fusion [instructions for installing 40 | CUDA](https://rpmfusion.org/Howto/NVIDIA#CUDA). Besides these dependencies, 41 | you also need a Rust toolchain and libtorch (see below). 
42 | 43 | #### Debian/Ubuntu 44 | 45 | Install the following dependencies through APT: 46 | 47 | ```shell 48 | $ apt-get install -y build-essential cmake libssl-dev pkg-config 49 | ``` 50 | 51 | For installing CUDA, please refer to your distribution's instructions. 52 | Besides these dependencies, you also need a Rust toolchain and libtorch 53 | (see below). 54 | 55 | ### Rust toolchain 56 | 57 | A Rust stable toolchain can be installed through [rustup](https://rustup.rs/): 58 | 59 | ```shell 60 | $ rustup default stable 61 | ``` 62 | ### libtorch 63 | 64 | [Download libtorch](https://pytorch.org/get-started/locally/) with the CXX11 65 | ABI. A CUDA build is necessary for training, otherwise the CPU build can be 66 | used. After unpacking the libtorch archive, you should set the following 67 | environment variables: 68 | 69 | ```shell 70 | $ export LIBTORCH=/path/to/libtorch 71 | # Linux: 72 | $ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}${LIBTORCH}/lib 73 | # macOS 74 | $ export DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH:+$DYLD_LIBRARY_PATH:}${LIBTORCH}/lib 75 | ``` 76 | 77 | There are currently no libtorch releases available for macOS ARM64. However, 78 | you can use libtorch from the `torch` Python package: 79 | 80 | ```shell 81 | $ pip install torch 82 | $ export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)') 83 | $ export DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH:+$DYLD_LIBRARY_PATH:}${LIBTORCH}/lib 84 | ``` 85 | 86 | ## Building SyntaxDot 87 | 88 | You can build SyntaxDot with support for training enabled using: 89 | 90 | ```shell 91 | $ cargo install syntaxdot-cli 92 | ``` 93 | 94 | To build SyntaxDot without training features, use: 95 | 96 | ```shell 97 | $ cargo install --no-default-features syntaxdot-cli 98 | ``` 99 | 100 | The SyntaxDot binary will then be available in: ```~/.cargo/bin/syntaxdot``` 101 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "model-tests")] 2 | #[cfg(test)] 3 | mod tests { 4 | use std::convert::TryInto; 5 | 6 | use approx::assert_abs_diff_eq; 7 | use ndarray::{array, ArrayD}; 8 | use syntaxdot_tch_ext::tensor::SumDim; 9 | use syntaxdot_tch_ext::RootExt; 10 | use tch::nn::VarStore; 11 | use tch::{Device, Kind, Tensor}; 12 | 13 | use crate::activations::Activation; 14 | use crate::models::bert::{BertConfig, BertEmbeddings}; 15 | use crate::models::squeeze_bert::SqueezeBertConfig; 16 | use crate::module::FallibleModuleT; 17 | 18 | const SQUEEZEBERT_UNCASED: &str = env!("SQUEEZEBERT_UNCASED"); 19 | 20 | fn squeezebert_uncased_config() -> SqueezeBertConfig { 21 | SqueezeBertConfig { 22 | attention_probs_dropout_prob: 0.1, 23 | embedding_size: 768, 24 | hidden_act: Activation::Gelu, 25 | hidden_dropout_prob: 0.1, 26 | hidden_size: 768, 27 | initializer_range: 0.02, 28 | intermediate_size: 3072, 29 | layer_norm_eps: 1e-12, 30 | max_position_embeddings: 512, 31 | num_attention_heads: 12, 32 | num_hidden_layers: 12, 33 | type_vocab_size: 2, 34 | vocab_size: 30528, 35 | q_groups: 4, 36 | k_groups: 4, 37 | v_groups: 4, 38 | post_attention_groups: 1, 39 | intermediate_groups: 4, 40 | output_groups: 4, 41 | } 42 | } 43 | 44 | #[test] 45 | fn squeeze_bert_embeddings() { 46 | let config = squeezebert_uncased_config(); 47 | let bert_config: BertConfig = (&config).into(); 48 | 49 | let mut vs = 
VarStore::new(Device::Cpu); 50 | let root = vs.root_ext(|_| 0); 51 | 52 | let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap(); 53 | 54 | vs.load(SQUEEZEBERT_UNCASED).unwrap(); 55 | vs.float(); 56 | 57 | // Word pieces of: Did the AWO embezzle donations ? 58 | let pieces = 59 | Tensor::from_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029]) 60 | .reshape(&[1, 9]); 61 | 62 | let summed_embeddings = 63 | embeddings 64 | .forward_t(&pieces, false) 65 | .unwrap() 66 | .sum_dim(-1, false, Kind::Float); 67 | 68 | let sums: ArrayD = (&summed_embeddings).try_into().unwrap(); 69 | 70 | // Verify output against Hugging Face transformers Python 71 | // implementation. 72 | assert_abs_diff_eq!( 73 | sums, 74 | (array![[ 75 | 39.4658, 35.4720, -2.2577, 11.3962, -1.6288, -9.8682, -18.4578, -12.0717, 11.7386 76 | ]]) 77 | .into_dyn(), 78 | epsilon = 1e-4 79 | ); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/albert/config.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | 3 | use crate::activations::Activation; 4 | use crate::models::bert::BertConfig; 5 | use crate::models::traits::WordEmbeddingsConfig; 6 | 7 | /// ALBERT model configuration. 8 | #[derive(Debug, Deserialize)] 9 | #[serde(default)] 10 | pub struct AlbertConfig { 11 | pub attention_probs_dropout_prob: f64, 12 | pub embedding_size: i64, 13 | pub hidden_act: Activation, 14 | pub hidden_dropout_prob: f64, 15 | pub hidden_size: i64, 16 | pub initializer_range: f64, 17 | pub inner_group_num: i64, 18 | pub intermediate_size: i64, 19 | pub max_position_embeddings: i64, 20 | pub num_attention_heads: i64, 21 | pub num_hidden_groups: i64, 22 | pub num_hidden_layers: i64, 23 | pub type_vocab_size: i64, 24 | pub vocab_size: i64, 25 | } 26 | 27 | impl Default for AlbertConfig { 28 | fn default() -> Self { 29 | AlbertConfig { 30 | attention_probs_dropout_prob: 0., 31 | embedding_size: 128, 32 | hidden_act: Activation::GeluNew, 33 | hidden_dropout_prob: 0., 34 | hidden_size: 768, 35 | initializer_range: 0.02, 36 | inner_group_num: 1, 37 | intermediate_size: 3072, 38 | max_position_embeddings: 512, 39 | num_attention_heads: 12, 40 | num_hidden_groups: 1, 41 | num_hidden_layers: 12, 42 | type_vocab_size: 2, 43 | vocab_size: 30000, 44 | } 45 | } 46 | } 47 | 48 | impl From<&AlbertConfig> for BertConfig { 49 | fn from(albert_config: &AlbertConfig) -> Self { 50 | BertConfig { 51 | attention_probs_dropout_prob: albert_config.attention_probs_dropout_prob, 52 | hidden_act: albert_config.hidden_act, 53 | hidden_dropout_prob: albert_config.hidden_dropout_prob, 54 | hidden_size: albert_config.hidden_size, 55 | initializer_range: albert_config.initializer_range, 56 | intermediate_size: albert_config.intermediate_size, 57 | layer_norm_eps: 1e-12, 58 | max_position_embeddings: albert_config.max_position_embeddings, 59 | num_attention_heads: albert_config.num_attention_heads, 60 | num_hidden_layers: albert_config.num_hidden_layers, 61 | type_vocab_size: albert_config.type_vocab_size, 62 | vocab_size: albert_config.vocab_size, 63 | } 64 | } 65 | } 66 | 67 | impl WordEmbeddingsConfig for AlbertConfig { 68 | fn dims(&self) -> i64 { 69 | self.embedding_size 70 | } 71 | 72 | fn dropout(&self) -> f64 { 73 | self.hidden_dropout_prob 74 | } 75 | 76 | fn initializer_range(&self) -> f64 { 77 | self.initializer_range 78 | } 79 | 80 | fn layer_norm_eps(&self) -> f64 { 81 | 1e-12 82 
| } 83 | 84 | fn vocab_size(&self) -> i64 { 85 | self.vocab_size 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/layer_output.rs: -------------------------------------------------------------------------------- 1 | use crate::TransformerError; 2 | use tch::Tensor; 3 | 4 | /// Hidden layer output and attention. 5 | #[derive(Debug)] 6 | pub struct HiddenLayer { 7 | /// The output of the layer. 8 | pub output: Tensor, 9 | 10 | /// The layer attention scores (unnormalized). 11 | pub attention: Tensor, 12 | } 13 | 14 | /// Output of a BERT layer. 15 | #[derive(Debug)] 16 | pub enum LayerOutput { 17 | /// Embedding layer output. 18 | Embedding(Tensor), 19 | 20 | /// Encoder layer output. 21 | EncoderWithAttention(HiddenLayer), 22 | } 23 | 24 | impl LayerOutput { 25 | /// Get the layer attention. 26 | /// 27 | /// Return a `Some` value if the layer output is from an encoder layer, 28 | /// or `None` otherwise. 29 | pub fn attention(&self) -> Option<&Tensor> { 30 | match self { 31 | LayerOutput::Embedding(_) => None, 32 | LayerOutput::EncoderWithAttention(hidden) => Some(&hidden.attention), 33 | } 34 | } 35 | 36 | /// Get the embedding. 37 | /// 38 | /// Returns `Some` if the layer output is an embedding or `None` 39 | /// otherwise. 40 | pub fn embedding(&self) -> Option<&Tensor> { 41 | match self { 42 | LayerOutput::Embedding(embedding) => Some(embedding), 43 | LayerOutput::EncoderWithAttention(_) => None, 44 | } 45 | } 46 | 47 | /// Map the output representation of this layer. 48 | pub fn map_output(&self, f: F) -> Result 49 | where 50 | F: Fn(&Tensor) -> Result, 51 | { 52 | let layer = match self { 53 | LayerOutput::Embedding(embedding) => LayerOutput::Embedding(f(embedding)?), 54 | LayerOutput::EncoderWithAttention(HiddenLayer { output, attention }) => { 55 | LayerOutput::EncoderWithAttention(HiddenLayer { 56 | output: f(output)?, 57 | attention: attention.shallow_clone(), 58 | }) 59 | } 60 | }; 61 | 62 | Ok(layer) 63 | } 64 | 65 | /// Get the layer output. 66 | pub fn output(&self) -> &Tensor { 67 | match self { 68 | LayerOutput::Embedding(embedding) => embedding, 69 | LayerOutput::EncoderWithAttention(hidden) => &hidden.output, 70 | } 71 | } 72 | 73 | /// Get the layer output mutably. 74 | pub fn output_mut(&mut self) -> &mut Tensor { 75 | match self { 76 | LayerOutput::Embedding(embedding) => embedding, 77 | LayerOutput::EncoderWithAttention(hidden) => &mut hidden.output, 78 | } 79 | } 80 | 81 | /// Get the output of an encoder layer. 82 | /// 83 | /// Return a `Some` value if the layer output is from an encoder layer, 84 | /// or `None` otherwise. 85 | pub fn encoder_with_attention(&self) -> Option<&HiddenLayer> { 86 | match self { 87 | LayerOutput::Embedding(_) => None, 88 | LayerOutput::EncoderWithAttention(hidden) => Some(hidden), 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/activations.rs: -------------------------------------------------------------------------------- 1 | //! Activation functions 2 | 3 | use std::convert::TryFrom; 4 | use std::f64; 5 | 6 | use serde::Deserialize; 7 | use tch::Tensor; 8 | 9 | use crate::module::FallibleModule; 10 | use crate::TransformerError; 11 | 12 | #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)] 13 | #[serde(try_from = "String")] 14 | pub enum Activation { 15 | /// GELU activation function. 
16 | /// 17 | /// GELU(x)=x Φ(x) 18 | /// 19 | /// where Φ(x) is the CDF for the Gaussian distribution. 20 | Gelu, 21 | 22 | /// GELU activation function (Google/OpenAI flavor). 23 | /// 24 | /// GELU(x)=x Φ(x) 25 | /// 26 | /// where Φ(x) is the CDF for the Gaussian distribution. 27 | GeluNew, 28 | 29 | /// ReLU activation function 30 | /// 31 | /// ReLU(x)=max(0,x) 32 | Relu, 33 | } 34 | 35 | impl TryFrom<&str> for Activation { 36 | type Error = TransformerError; 37 | 38 | fn try_from(value: &str) -> Result { 39 | match value { 40 | "gelu" => Ok(Activation::Gelu), 41 | "gelu_new" => Ok(Activation::GeluNew), 42 | "relu" => Ok(Activation::Relu), 43 | unknown => Err(TransformerError::UnknownActivationFunction { 44 | activation: unknown.to_string(), 45 | }), 46 | } 47 | } 48 | } 49 | 50 | impl TryFrom for Activation { 51 | type Error = TransformerError; 52 | 53 | fn try_from(value: String) -> Result { 54 | Self::try_from(value.as_str()) 55 | } 56 | } 57 | 58 | impl FallibleModule for Activation { 59 | type Error = TransformerError; 60 | 61 | fn forward(&self, input: &Tensor) -> Result { 62 | match self { 63 | Self::Gelu => Ok(input.f_gelu("none")?), 64 | Self::GeluNew => Ok(0.5 65 | * input 66 | * (1.0 67 | + Tensor::f_tanh( 68 | &((2. / f64::consts::PI).sqrt() 69 | * (input + 0.044715 * input.pow_tensor_scalar(3.0))), 70 | )?)), 71 | Self::Relu => Ok(input.f_relu()?), 72 | } 73 | } 74 | } 75 | 76 | #[cfg(test)] 77 | mod tests { 78 | use std::convert::TryInto; 79 | 80 | use approx::assert_abs_diff_eq; 81 | use ndarray::{array, ArrayD}; 82 | use tch::Tensor; 83 | 84 | use crate::activations::Activation; 85 | use crate::module::FallibleModule; 86 | 87 | #[test] 88 | fn gelu_new_returns_correct_values() { 89 | let gelu_new = Activation::GeluNew; 90 | let activations: ArrayD = (&gelu_new 91 | .forward(&Tensor::from_slice(&[-1., -0.5, 0., 0.5, 1.])) 92 | .unwrap()) 93 | .try_into() 94 | .unwrap(); 95 | assert_abs_diff_eq!( 96 | activations, 97 | array![-0.1588, -0.1543, 0.0000, 0.3457, 0.8412].into_dyn(), 98 | epsilon = 1e-4 99 | ); 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-piat-lemma.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | keinen kein _ PIAT kein 5 | keine kein _ PIAT kein 6 | mehrere mehrere _ PIAT mehrere 7 | mehr mehr _ PIAT mehr 8 | kein kein _ PIAT kein 9 | einige einiges _ PIAT einig 10 | soviel soviel _ PIAT soviel 11 | mehreren mehrere _ PIAT mehrere 12 | etliche etliche _ PIAT etlich 13 | weniger weniger _ PIAT wenig 14 | keiner keine _ PIAT kein 15 | manche mancher _ PIAT manch 16 | viel viel _ PIAT viel 17 | wenig wenig _ PIAT wenig 18 | keinerlei keinerlei _ PIAT keinerlei 19 | irgendwelchen irgendwelches _ PIAT irgendwelch 20 | einigen einiger _ PIAT einig 21 | derlei derlei _ PIAT derlei 22 | genausoviel genausoviel _ PIAT genausoviel 23 | etwas etwas _ PIAT etwas 24 | keinem kein _ PIAT kein 25 | #köön kein_ _ PIAT kein 26 | einigem einiger _ PIAT einig 27 | manchen mancher _ PIAT manch 28 | manches manches _ PIAT manch 29 | sowas sowas _ PIAT sowas 30 | einiger einige _ PIAT einig 31 | ebensoviel ebensoviel _ PIAT ebensoviel 32 | allerhand allerhand _ PIAT allerhand 33 | genug genug _ PIAT genug 34 | irgendein irgendein _ PIAT irgendein 35 | reichlich reichlich _ PIAT 
reichlich 36 | genügend genügend _ PIAT genügend 37 | irgendeiner irgendeine _ PIAT irgendein 38 | lauter lauter _ PIAT lauter 39 | was etwas _ PIAT etwas 40 | mancher mancher _ PIAT manch 41 | zuviel zuviel _ PIAT zuviel 42 | allerlei allerlei _ PIAT allerlei 43 | mehrerer mehrere _ PIAT mehrere 44 | #kain kein_ _ PIAT kein 45 | zuwenig zuwenig _ PIAT zuwenig 46 | allzuviel allzuviel _ PIAT allzuviel 47 | ausreichend ausreichend _ PIAT ausreichend 48 | irgendeine irgendeine _ PIAT irgendein 49 | keines kein _ PIAT kein 50 | dergleichen dergleichen _ PIAT dergleichen 51 | #ken kein _ PIAT kein 52 | etliches etliches _ PIAT etlich 53 | einiges einiges _ PIAT einig 54 | irgendwelche irgendwelche _ PIAT irgendwelch 55 | gleichviel gleichviel _ PIAT gleichviel 56 | etlicher etliche _ PIAT etlich 57 | irgendeinem irgendein _ PIAT irgendein 58 | zweierlei zweierlei _ PIAT zweierlei 59 | etlichen etliche _ PIAT etlich 60 | vielerlei vielerlei _ PIAT vielerlei 61 | beiderlei beiderlei _ PIAT beiderlei 62 | manchem manches _ PIAT manch 63 | intertransmultibiviele intertransmultibiviel _ PIAT intertransmultibiviel 64 | #mache manche _ PIAT manch 65 | mehrer mehrere _ PIAT mehrere 66 | irgendwelcher irgendwelcher _ PIAT irgendwelch 67 | sowenig sowenig _ PIAT sowenig 68 | irgendeinen irgendein _ PIAT irgendein 69 | allelei allerlei _ PIAT allerlei 70 | genausowenig genausowenig _ PIAT genausowenig 71 | solcherlei solcherlei _ PIAT solcherlei 72 | irgendeines irgendein _ PIAT irgendein 73 | tausenderlei tausenderlei _ PIAT tausenderlei 74 | jedwedem jedweder _ PIAT jedwed 75 | etwelchen etwelcher _ PIAT etwelcher 76 | #keene keine_ _ PIAT kein 77 | genügnd genügend_ _ PIAT genügend_ 78 | jedwedes jedwedes _ PIAT jedwed 79 | jedweder jedwede _ PIAT jedwed 80 | jedweden jedwedes _ PIAT jedwed 81 | diese diese _ PIAT diese 82 | siebenerlei siebenerlei _ PIAT siebenerlei 83 | kein(e) kein _ PIAT kein 84 | -------------------------------------------------------------------------------- /syntaxdot/src/dataset/plaintext.rs: -------------------------------------------------------------------------------- 1 | use std::io::{BufRead, Lines, Seek, SeekFrom}; 2 | 3 | use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize}; 4 | use udgraph::graph::Sentence; 5 | use udgraph::token::Token; 6 | 7 | use crate::dataset::DataSet; 8 | use crate::error::SyntaxDotError; 9 | 10 | /// A CoNLL-X data set. 11 | pub struct PlainTextDataSet(R); 12 | 13 | impl PlainTextDataSet { 14 | /// Construct a plain-text dataset. 15 | pub fn new(read: R) -> Self { 16 | Self(read) 17 | } 18 | } 19 | 20 | impl<'a, R> DataSet<'a> for &'a mut PlainTextDataSet 21 | where 22 | R: BufRead + Seek, 23 | { 24 | type Iter = PlainTextIter<'a, &'a mut R>; 25 | 26 | fn sentences(self, tokenizer: &'a dyn Tokenize) -> Result { 27 | // Rewind to the beginning of the dataset (if necessary). 28 | self.0.seek(SeekFrom::Start(0))?; 29 | 30 | Ok(PlainTextIter { 31 | lines: (&mut self.0).lines(), 32 | tokenizer, 33 | }) 34 | } 35 | } 36 | 37 | pub struct PlainTextIter<'a, R> { 38 | lines: Lines, 39 | tokenizer: &'a dyn Tokenize, 40 | } 41 | 42 | impl<'a, R> Iterator for PlainTextIter<'a, R> 43 | where 44 | R: BufRead, 45 | { 46 | type Item = Result; 47 | 48 | fn next(&mut self) -> Option { 49 | for line in &mut self.lines { 50 | // Bubble up read errors. 
51 | let line = match line { 52 | Ok(line) => line, 53 | Err(err) => return Some(Err(SyntaxDotError::IoError(err))), 54 | }; 55 | 56 | let line_trimmed = line.trim(); 57 | 58 | // Skip empty lines 59 | if line_trimmed.is_empty() { 60 | continue; 61 | } 62 | 63 | let sentence = line_trimmed 64 | .split_terminator(' ') 65 | .map(ToString::to_string) 66 | .map(Token::new) 67 | .collect::<Sentence>(); 68 | 69 | return Some(Ok(self.tokenizer.tokenize(sentence))); 70 | } 71 | 72 | None 73 | } 74 | } 75 | 76 | #[cfg(test)] 77 | mod tests { 78 | use std::io::{BufReader, Cursor}; 79 | 80 | use crate::dataset::tests::{dataset_to_pieces, wordpiece_tokenizer, CORRECT_PIECE_IDS}; 81 | use crate::dataset::PlainTextDataSet; 82 | 83 | const SENTENCES: &str = r#" 84 | Dit is de eerste zin . 85 | Dit de tweede zin . 86 | 87 | En nu de laatste zin ."#; 88 | 89 | #[test] 90 | fn plain_text_dataset_works() { 91 | let tokenizer = wordpiece_tokenizer(); 92 | let mut cursor = Cursor::new(SENTENCES); 93 | let mut dataset = PlainTextDataSet::new(BufReader::new(&mut cursor)); 94 | 95 | let pieces = dataset_to_pieces(&mut dataset, &tokenizer).unwrap(); 96 | assert_eq!(pieces, *CORRECT_PIECE_IDS); 97 | 98 | // Verify that the data set is correctly read again. 99 | let more_pieces = dataset_to_pieces(&mut dataset, &tokenizer).unwrap(); 100 | assert_eq!(more_pieces, *CORRECT_PIECE_IDS); 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/automaton.rs: -------------------------------------------------------------------------------- 1 | use fst::raw::{Fst, Node}; 2 | use fst::Set; 3 | 4 | /// Search prefixes of a string in a set. 5 | pub trait Prefixes<D> { 6 | /// Get an iterator over the prefixes of a string that are in a set. 7 | fn prefixes<'a, 'b>(&'a self, word: &'b str) -> PrefixIter<'a, 'b, D>; 8 | } 9 | 10 | impl<D> Prefixes<D> for Set<D> 11 | where 12 | D: AsRef<[u8]>, 13 | { 14 | fn prefixes<'a, 'b>(&'a self, word: &'b str) -> PrefixIter<'a, 'b, D> { 15 | PrefixIter { 16 | fst: self.as_fst(), 17 | node: self.as_fst().root(), 18 | prefix_len: 0, 19 | word, 20 | } 21 | } 22 | } 23 | 24 | /// Prefix iterator. 25 | pub struct PrefixIter<'a, 'b, D> { 26 | fst: &'a Fst<D>, 27 | node: Node<'a>, 28 | prefix_len: usize, 29 | word: &'b str, 30 | } 31 | 32 | impl<'a, 'b, D> Iterator for PrefixIter<'a, 'b, D> 33 | where 34 | D: AsRef<[u8]>, 35 | { 36 | type Item = &'b str; 37 | 38 | fn next(&mut self) -> Option<Self::Item> { 39 | while self.prefix_len < self.word.len() { 40 | match self.node.find_input(self.word.as_bytes()[self.prefix_len]) { 41 | Some(trans_idx) => { 42 | let trans = self.node.transition(trans_idx); 43 | self.node = self.fst.node(trans.addr); 44 | self.prefix_len += 1; 45 | } 46 | None => return None, 47 | }; 48 | 49 | if self.node.is_final() { 50 | return Some(&self.word[..self.prefix_len]); 51 | } 52 | } 53 | 54 | None 55 | } 56 | } 57 | 58 | /// Search the longest prefix of a string in a set. 59 | pub trait LongestPrefix { 60 | /// Search the longest prefix of a string in a set.
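///
/// This is simply the last item yielded by [`Prefixes::prefixes`]. A small
/// illustrative sketch, using the same set as the tests below (marked
/// `ignore` since this module is not part of the public API):
///
/// ```ignore
/// let mut builder = SetBuilder::memory();
/// builder.extend_iter(["p", "pre", "pref", "prefix"]).unwrap();
/// let set = Set::new(builder.into_inner().unwrap()).unwrap();
///
/// assert_eq!(set.longest_prefix("prefixes"), Some("prefix"));
/// assert_eq!(set.longest_prefix("pre"), Some("pre"));
/// assert_eq!(set.longest_prefix("fix"), None);
/// ```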
61 | fn longest_prefix<'a>(&self, word: &'a str) -> Option<&'a str>; 62 | } 63 | 64 | impl LongestPrefix for fst::Set 65 | where 66 | D: AsRef<[u8]>, 67 | { 68 | fn longest_prefix<'a>(&self, word: &'a str) -> Option<&'a str> { 69 | self.prefixes(word).last() 70 | } 71 | } 72 | 73 | #[cfg(test)] 74 | mod tests { 75 | use fst::{Set, SetBuilder}; 76 | 77 | use super::Prefixes; 78 | 79 | fn test_set() -> Set> { 80 | let mut builder = SetBuilder::memory(); 81 | builder.extend_iter(["p", "pre", "pref", "prefix"]).unwrap(); 82 | let bytes = builder.into_inner().unwrap(); 83 | Set::new(bytes).unwrap() 84 | } 85 | 86 | #[test] 87 | fn finds_prefixes() { 88 | let set = test_set(); 89 | 90 | let mut iter = set.prefixes("prefixes"); 91 | assert_eq!(iter.next(), Some("p")); 92 | assert_eq!(iter.next(), Some("pre")); 93 | assert_eq!(iter.next(), Some("pref")); 94 | assert_eq!(iter.next(), Some("prefix")); 95 | assert!(iter.next().is_none()); 96 | 97 | let mut iter = set.prefixes("pre"); 98 | assert_eq!(iter.next(), Some("p")); 99 | assert_eq!(iter.next(), Some("pre")); 100 | assert!(iter.next().is_none()); 101 | 102 | assert!(set.prefixes("fix").next().is_none()); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/subcommands/filter_len.rs: -------------------------------------------------------------------------------- 1 | use std::io::BufWriter; 2 | 3 | use anyhow::{Context, Result}; 4 | use clap::{Arg, ArgMatches, Command}; 5 | use conllu::io::{ReadSentence, Reader, WriteSentence, Writer}; 6 | use stdinout::{Input, Output}; 7 | 8 | use crate::io::{load_config, load_tokenizer}; 9 | use crate::traits::SyntaxDotApp; 10 | 11 | const CONFIG: &str = "CONFIG"; 12 | const MAX_LEN: &str = "MAX_LEN"; 13 | const INPUT: &str = "INPUT"; 14 | const OUTPUT: &str = "OUTPUT"; 15 | 16 | pub struct FilterLenApp { 17 | config: String, 18 | input: Option, 19 | max_len: usize, 20 | output: Option, 21 | } 22 | 23 | impl SyntaxDotApp for FilterLenApp { 24 | fn app() -> Command { 25 | Command::new("filter-len") 26 | .arg_required_else_help(true) 27 | .about("Filter corpus by the sentence length in pieces") 28 | .arg( 29 | Arg::new(CONFIG) 30 | .help("SyntaxDot configuration file") 31 | .index(1) 32 | .required(true), 33 | ) 34 | .arg( 35 | Arg::new(MAX_LEN) 36 | .help("Maximum sentence length") 37 | .index(2) 38 | .required(true), 39 | ) 40 | .arg(Arg::new(INPUT).help("Input corpus").index(3)) 41 | .arg(Arg::new(OUTPUT).help("Output corpus").index(4)) 42 | } 43 | 44 | fn parse(matches: &ArgMatches) -> Result { 45 | let config = matches.get_one::(CONFIG).unwrap().into(); 46 | let max_len = matches 47 | .get_one::(MAX_LEN) 48 | .unwrap() 49 | .parse() 50 | .context("Cannot parse maximum sentence length")?; 51 | let input = matches.get_one::(INPUT).map(ToOwned::to_owned); 52 | let output = matches.get_one::(OUTPUT).map(ToOwned::to_owned); 53 | 54 | Ok(FilterLenApp { 55 | config, 56 | input, 57 | max_len, 58 | output, 59 | }) 60 | } 61 | 62 | fn run(&self) -> Result<()> { 63 | let config = load_config(&self.config)?; 64 | 65 | let tokenizer = load_tokenizer(&config)?; 66 | 67 | let input = Input::from(self.input.as_ref()); 68 | let output = Output::from(self.output.as_ref()); 69 | 70 | let treebank_reader = Reader::new( 71 | input 72 | .buf_read() 73 | .context("Cannot open treebank for reading")?, 74 | ); 75 | 76 | let mut treebank_writer = Writer::new(BufWriter::new( 77 | output.write().context("Cannot open treebank for writing")?, 78 | )); 79 | 80 | for 
sentence in treebank_reader.sentences() { 81 | let sentence = sentence.context("Cannot read sentence from treebank")?; 82 | 83 | let sentence_with_pieces = tokenizer.tokenize(sentence); 84 | 85 | if sentence_with_pieces.pieces.len() <= self.max_len { 86 | treebank_writer 87 | .write_sentence(&sentence_with_pieces.sentence) 88 | .context("Cannot write sentence")?; 89 | } 90 | } 91 | 92 | Ok(()) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/transform/mod.rs: -------------------------------------------------------------------------------- 1 | use udgraph::graph::Sentence; 2 | 3 | #[allow(clippy::len_without_is_empty)] 4 | pub trait DependencyGraph { 5 | fn dependents<'a>(&'a self, idx: usize) -> Box + 'a>; 6 | 7 | fn token(&self, idx: usize) -> &dyn Token; 8 | 9 | fn token_mut(&mut self, idx: usize) -> &mut dyn TokenMut; 10 | 11 | fn len(&self) -> usize; 12 | } 13 | 14 | impl DependencyGraph for Sentence { 15 | fn dependents<'a>(&'a self, idx: usize) -> Box + 'a> { 16 | Box::new(self.dep_graph().dependents(idx).map(|triple| { 17 | ( 18 | triple.dependent(), 19 | triple 20 | .relation() 21 | .expect("Edge without a dependency relation") 22 | .to_owned(), 23 | ) 24 | })) 25 | } 26 | 27 | fn token(&self, idx: usize) -> &dyn Token { 28 | self[idx] 29 | .token() 30 | .expect("The root node was used as a token") 31 | } 32 | 33 | fn token_mut(&mut self, idx: usize) -> &mut dyn TokenMut { 34 | self[idx] 35 | .token_mut() 36 | .expect("The root node was used as a token") 37 | } 38 | 39 | fn len(&self) -> usize { 40 | self.len() 41 | } 42 | } 43 | 44 | pub trait TokenMut: Token { 45 | fn set_lemma(&mut self, lemma: Option); 46 | } 47 | 48 | pub trait Token { 49 | fn form(&self) -> &str; 50 | fn lemma(&self) -> &str; 51 | fn upos(&self) -> &str; 52 | fn xpos(&self) -> &str; 53 | } 54 | 55 | impl Token for udgraph::token::Token { 56 | fn form(&self) -> &str { 57 | self.form() 58 | } 59 | 60 | fn lemma(&self) -> &str { 61 | self.lemma().unwrap_or("_") 62 | } 63 | 64 | fn upos(&self) -> &str { 65 | self.upos().unwrap() 66 | } 67 | 68 | fn xpos(&self) -> &str { 69 | self.xpos().unwrap() 70 | } 71 | } 72 | 73 | impl TokenMut for udgraph::token::Token { 74 | fn set_lemma(&mut self, lemma: Option) { 75 | self.set_lemma(lemma); 76 | } 77 | } 78 | 79 | pub trait Transform: Sync { 80 | fn transform(&self, graph: &dyn DependencyGraph, node: usize) -> String; 81 | } 82 | 83 | /// A list of `Transform`s. 84 | pub struct Transforms(pub Vec>); 85 | 86 | impl Transforms { 87 | /// Transform a graph using the transformation list. 88 | /// 89 | /// This method applies the transformations to the given graph. Each 90 | /// transform is fully applied to the graph before the next transform, 91 | /// to ensure that dependencies between transforms are correctly handled. 
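    ///
    /// A hypothetical usage sketch (assumes a `udgraph` `Sentence` named
    /// `sentence` and the `RemoveAlternatives` transform from the
    /// `delemmatization` module):
    ///
    /// ```ignore
    /// let transforms = Transforms(vec![Box::new(RemoveAlternatives)]);
    /// transforms.transform(&mut sentence);
    /// ```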
92 | pub fn transform(&self, graph: &mut dyn DependencyGraph) { 93 | for t in &self.0 { 94 | for idx in 1..graph.len() { 95 | let lemma = t.as_ref().transform(graph, idx); 96 | graph.token_mut(idx).set_lemma(Some(lemma)); 97 | } 98 | } 99 | } 100 | } 101 | 102 | pub mod delemmatization; 103 | 104 | pub mod lemmatization; 105 | 106 | pub mod misc; 107 | 108 | mod named_entity; 109 | 110 | mod svp; 111 | 112 | #[cfg(test)] 113 | pub(crate) mod test_helpers; 114 | -------------------------------------------------------------------------------- /syntaxdot/src/util.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | pub struct RandomRemoveVec { 4 | inner: Vec, 5 | rng: R, 6 | } 7 | 8 | impl RandomRemoveVec 9 | where 10 | R: Rng, 11 | { 12 | /// Create a shuffler with the given capacity. 13 | pub fn with_capacity(capacity: usize, rng: R) -> Self { 14 | RandomRemoveVec { 15 | inner: Vec::with_capacity(capacity + 1), 16 | rng, 17 | } 18 | } 19 | 20 | /// Check whether the shuffler is empty. 21 | pub fn is_empty(&self) -> bool { 22 | self.inner.is_empty() 23 | } 24 | 25 | /// Push an element into the shuffler. 26 | pub fn push(&mut self, value: T) { 27 | self.inner.push(value); 28 | } 29 | 30 | /// Get the number of elements in the shuffler. 31 | pub fn len(&self) -> usize { 32 | self.inner.len() 33 | } 34 | } 35 | 36 | impl RandomRemoveVec 37 | where 38 | R: Rng, 39 | { 40 | /// Randomly remove an element from the shuffler. 41 | pub fn remove_random(&mut self) -> Option { 42 | if self.inner.is_empty() { 43 | None 44 | } else { 45 | Some( 46 | self.inner 47 | .swap_remove(self.rng.gen_range(0..self.inner.len())), 48 | ) 49 | } 50 | } 51 | 52 | /// Add `replacement` to the inner and randomly remove an element. 53 | /// 54 | /// `replacement` could also be drawn randomly. 
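    ///
    /// A small behavioral sketch (hypothetical values):
    ///
    /// ```ignore
    /// let mut buffer = RandomRemoveVec::with_capacity(2, rand::thread_rng());
    /// buffer.push(1);
    /// // Inserts 2 and removes either 1 or 2, chosen at random.
    /// let removed = buffer.push_and_remove_random(2);
    /// assert!(removed == 1 || removed == 2);
    /// ```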
55 | pub fn push_and_remove_random(&mut self, replacement: T) -> T { 56 | self.inner.push(replacement); 57 | self.inner 58 | .swap_remove(self.rng.gen_range(0..self.inner.len())) 59 | } 60 | } 61 | 62 | #[cfg(test)] 63 | mod tests { 64 | use rand::{Rng, SeedableRng}; 65 | use rand_xorshift::XorShiftRng; 66 | 67 | use super::RandomRemoveVec; 68 | 69 | #[test] 70 | fn random_remove_vec() { 71 | let mut rng = XorShiftRng::seed_from_u64(42); 72 | let mut elems = RandomRemoveVec::with_capacity(3, XorShiftRng::seed_from_u64(42)); 73 | elems.push(1); 74 | elems.push(2); 75 | elems.push(3); 76 | 77 | // Before: [1 2 3] 78 | assert_eq!(rng.gen_range(0..4_usize), 1); 79 | assert_eq!(elems.push_and_remove_random(4), 2); 80 | 81 | // Before: [1 4 3] 82 | assert_eq!(rng.gen_range(0..4_usize), 2); 83 | assert_eq!(elems.push_and_remove_random(5), 3); 84 | 85 | // Before: [1 4 5] 86 | assert_eq!(rng.gen_range(0..4_usize), 1); 87 | assert_eq!(elems.push_and_remove_random(6), 4); 88 | 89 | // Before: [1 6 5] 90 | assert_eq!(rng.gen_range(0..3_usize), 1); 91 | assert_eq!(elems.remove_random().unwrap(), 6); 92 | 93 | // Before: [1 5] 94 | assert_eq!(rng.gen_range(0..2_usize), 0); 95 | assert_eq!(elems.remove_random().unwrap(), 1); 96 | 97 | // Before: [5] 98 | assert_eq!(rng.gen_range(0..1_usize), 0); 99 | assert_eq!(elems.remove_random().unwrap(), 5); 100 | 101 | // Exhausted 102 | assert_eq!(elems.remove_random(), None); 103 | 104 | // The buffer is empty, so always return the next number 105 | assert_eq!(elems.push_and_remove_random(7), 7); 106 | assert_eq!(elems.push_and_remove_random(8), 8); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /syntaxdot-summary/src/event_writer.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Write}; 2 | 3 | use prost::Message; 4 | 5 | use crate::event_writer::event::What; 6 | use crate::record_writer::TfRecordWriter; 7 | use std::time::SystemTime; 8 | 9 | pub struct EventWriter { 10 | writer: TfRecordWriter, 11 | } 12 | 13 | impl EventWriter { 14 | fn wall_time() -> f64 { 15 | SystemTime::now() 16 | .duration_since(SystemTime::UNIX_EPOCH) 17 | .unwrap() 18 | .as_nanos() as f64 19 | / 1e9 20 | } 21 | } 22 | 23 | impl EventWriter 24 | where 25 | W: Write, 26 | { 27 | pub fn new(write: W) -> io::Result { 28 | Self::new_with_wall_time(write, Self::wall_time()) 29 | } 30 | 31 | pub fn new_with_wall_time(write: W, wall_time: f64) -> io::Result { 32 | let mut writer = EventWriter { 33 | writer: TfRecordWriter::from(write), 34 | }; 35 | 36 | writer.write_event_with_wall_time( 37 | wall_time, 38 | 0, 39 | What::FileVersion("brain.Event:2".to_string()), 40 | )?; 41 | 42 | Ok(writer) 43 | } 44 | 45 | pub fn write_event(&mut self, step: i64, what: What) -> io::Result<()> { 46 | self.write_event_with_wall_time(Self::wall_time(), step, what) 47 | } 48 | 49 | pub fn write_event_with_wall_time( 50 | &mut self, 51 | wall_time: f64, 52 | step: i64, 53 | what: What, 54 | ) -> io::Result<()> { 55 | let event = Event { 56 | wall_time, 57 | step, 58 | what: Some(what), 59 | }; 60 | 61 | let mut event_bytes = vec![]; 62 | event.encode(&mut event_bytes)?; 63 | 64 | self.writer.write(&event_bytes)?; 65 | 66 | self.writer.flush() 67 | } 68 | } 69 | 70 | #[derive(Clone, PartialEq, Message)] 71 | pub struct Event { 72 | #[prost(double, tag = "1")] 73 | wall_time: f64, 74 | 75 | #[prost(int64, tag = "2")] 76 | step: i64, 77 | 78 | #[prost(oneof = "event::What", tags = "3, 4, 5, 6, 7, 8, 
9")] 79 | what: Option, 80 | } 81 | 82 | pub mod event { 83 | use prost::Oneof; 84 | 85 | #[derive(Clone, PartialEq, Oneof)] 86 | pub enum What { 87 | #[prost(string, tag = "3")] 88 | FileVersion(std::string::String), 89 | 90 | #[prost(message, tag = "5")] 91 | Summary(super::Summary), 92 | } 93 | } 94 | 95 | #[derive(Clone, PartialEq, Message)] 96 | pub struct Summary { 97 | #[prost(message, repeated, tag = "1")] 98 | pub value: ::std::vec::Vec, 99 | } 100 | 101 | pub mod summary { 102 | use prost::Message; 103 | 104 | #[derive(Clone, PartialEq, Message)] 105 | pub struct Value { 106 | #[prost(string, tag = "7")] 107 | pub node_name: std::string::String, 108 | 109 | #[prost(string, tag = "1")] 110 | pub tag: std::string::String, 111 | 112 | #[prost(oneof = "value::Value", tags = "2, 3, 4, 5, 6, 8")] 113 | pub value: ::std::option::Option, 114 | } 115 | 116 | pub mod value { 117 | use prost::Oneof; 118 | 119 | #[derive(Clone, PartialEq, Oneof)] 120 | pub enum Value { 121 | #[prost(float, tag = "2")] 122 | SimpleValue(f32), 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/depseq/mod.rs: -------------------------------------------------------------------------------- 1 | //! Dependency parsing as sequence labeling (Spoustová & Spousta, 2010). 2 | 3 | use serde_derive::{Deserialize, Serialize}; 4 | 5 | mod error; 6 | pub use self::error::*; 7 | 8 | mod post_processing; 9 | pub(crate) use self::post_processing::*; 10 | 11 | mod relative_position; 12 | pub use self::relative_position::*; 13 | 14 | mod relative_pos; 15 | pub use self::relative_pos::*; 16 | 17 | /// Encoding of a dependency relation as a token label. 18 | #[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] 19 | pub struct DependencyEncoding { 20 | head: H, 21 | label: String, 22 | } 23 | 24 | impl DependencyEncoding { 25 | pub fn new(head: H, label: impl Into) -> Self { 26 | DependencyEncoding { 27 | head, 28 | label: label.into(), 29 | } 30 | } 31 | 32 | /// Get the head representation. 33 | pub fn head(&self) -> &H { 34 | &self.head 35 | } 36 | 37 | /// Get the dependency label. 
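    ///
    /// For example (hypothetical head representation and label):
    ///
    /// ```ignore
    /// let encoding = DependencyEncoding::new(2, "nsubj");
    /// assert_eq!(encoding.label(), "nsubj");
    /// ```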
38 | pub fn label(&self) -> &str { 39 | &self.label 40 | } 41 | } 42 | 43 | #[cfg(test)] 44 | mod tests { 45 | use std::fs::File; 46 | use std::io::BufReader; 47 | use std::path::Path; 48 | 49 | use conllu::io::Reader; 50 | use udgraph::graph::{Node, Sentence}; 51 | 52 | use super::{PosLayer, RelativePosEncoder, RelativePositionEncoder}; 53 | use crate::{EncodingProb, SentenceDecoder, SentenceEncoder}; 54 | 55 | const NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu"; 56 | 57 | const ROOT_RELATION: &str = "root"; 58 | 59 | fn copy_sentence_without_deprels(sentence: &Sentence) -> Sentence { 60 | let mut copy = Sentence::new(); 61 | 62 | copy.set_comments(sentence.comments().to_owned()); 63 | 64 | for token in sentence.iter().filter_map(Node::token) { 65 | copy.push(token.clone()); 66 | } 67 | 68 | copy 69 | } 70 | 71 | fn test_encoding(path: P, encoder_decoder: E) 72 | where 73 | P: AsRef, 74 | E: SentenceEncoder + SentenceDecoder, 75 | C: 'static + Clone, 76 | { 77 | let f = File::open(path).unwrap(); 78 | let reader = Reader::new(BufReader::new(f)); 79 | 80 | for sentence in reader { 81 | let sentence = sentence.unwrap(); 82 | 83 | // Encode 84 | let encodings = encoder_decoder 85 | .encode(&sentence) 86 | .unwrap() 87 | .into_iter() 88 | .map(|e| [EncodingProb::new(e, 1.)]) 89 | .collect::>(); 90 | 91 | // Decode 92 | let mut test_sentence = copy_sentence_without_deprels(&sentence); 93 | encoder_decoder 94 | .decode(&encodings, &mut test_sentence) 95 | .unwrap(); 96 | 97 | assert_eq!(sentence, test_sentence); 98 | } 99 | } 100 | 101 | #[test] 102 | fn relative_pos_position() { 103 | let encoder = RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION); 104 | test_encoding(NON_PROJECTIVE_DATA, encoder); 105 | } 106 | 107 | #[test] 108 | fn relative_position() { 109 | let encoder = RelativePositionEncoder::new(ROOT_RELATION); 110 | test_encoding(NON_PROJECTIVE_DATA, encoder); 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /syntaxdot-tch-ext/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Div; 2 | use std::rc::Rc; 3 | 4 | use itertools::Itertools; 5 | use tch::nn::{Init, Path, VarStore}; 6 | use tch::{TchError, Tensor}; 7 | 8 | pub mod tensor; 9 | 10 | /// Trait that provides the root of a variable store. 11 | pub trait RootExt { 12 | /// Get the root of a variable store. 13 | /// 14 | /// In contrast to the regular `root` method, `root_ext` allows 15 | /// you to provide a function that maps a variable name to a 16 | /// parameter group. This is particularly useful for use cases 17 | /// where one wants to put parameters in separate groups, to 18 | /// give each group its own hyper-parameters. 19 | fn root_ext(&self, parameter_group_fun: F) -> PathExt 20 | where 21 | F: 'static + Fn(&str) -> usize; 22 | } 23 | 24 | impl RootExt for VarStore { 25 | fn root_ext(&self, parameter_group_fun: F) -> PathExt 26 | where 27 | F: 'static + Fn(&str) -> usize, 28 | { 29 | PathExt { 30 | inner: self.root(), 31 | parameter_group_fun: Rc::new(parameter_group_fun), 32 | } 33 | } 34 | } 35 | 36 | pub struct PathExt<'a> { 37 | inner: Path<'a>, 38 | parameter_group_fun: Rc usize>, 39 | } 40 | 41 | impl<'a> PathExt<'a> { 42 | /// Create a tensor variable initialized with ones. 
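    ///
    /// A hypothetical sketch (assumes a fresh CPU `VarStore` and a single
    /// parameter group):
    ///
    /// ```ignore
    /// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
    /// let root = vs.root_ext(|_name| 0);
    /// let bias = root.ones("bias", &[10]);
    /// ```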
43 | pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor { 44 | let group = self.name_group(name); 45 | let path = self.inner.set_group(group); 46 | path.ones(name, dims) 47 | } 48 | 49 | /// Get a sub-path of the current path. 50 | pub fn sub(&'a self, s: T) -> PathExt<'a> { 51 | PathExt { 52 | inner: self.inner.sub(s), 53 | parameter_group_fun: self.parameter_group_fun.clone(), 54 | } 55 | } 56 | 57 | /// Create a tensor variable initialized with the given initializer. 58 | pub fn var(&self, name: &str, dims: &[i64], init: Init) -> Result { 59 | let group = self.name_group(name); 60 | let path = self.inner.set_group(group); 61 | path.f_var(name, dims, init) 62 | } 63 | 64 | /// Create a tensor variable initialized with the values from another tensor. 65 | pub fn var_copy(&self, name: &str, t: &Tensor) -> Tensor { 66 | let group = self.name_group(name); 67 | let path = self.inner.set_group(group); 68 | path.var_copy(name, t) 69 | } 70 | 71 | /// Get the full name of `name` and return its group. 72 | fn name_group(&self, name: &str) -> usize { 73 | let fullname = format!("{}.{}", self.inner.components().join("."), name); 74 | (self.parameter_group_fun)(&fullname) 75 | } 76 | 77 | /// Create a tensor variable initialized with zeros. 78 | pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor { 79 | let group = self.name_group(name); 80 | let path = self.inner.set_group(group); 81 | path.zeros(name, dims) 82 | } 83 | } 84 | 85 | impl<'a, T> Div for &'a mut PathExt<'a> 86 | where 87 | T: std::string::ToString, 88 | { 89 | type Output = PathExt<'a>; 90 | 91 | fn div(self, rhs: T) -> Self::Output { 92 | self.sub(rhs.to_string()) 93 | } 94 | } 95 | 96 | impl<'a, T> Div for &'a PathExt<'a> 97 | where 98 | T: std::string::ToString, 99 | { 100 | type Output = PathExt<'a>; 101 | 102 | fn div(self, rhs: T) -> Self::Output { 103 | self.sub(rhs.to_string()) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-pidat-lemma.test: -------------------------------------------------------------------------------- 1 | # Format: form lemma upos xpos transformed [rel head_form head_lemma head_upos head_xpos] 2 | # [rel dep_form dep_lemma dep_upos dep_xpos]* 3 | 4 | # all* is lemmatized to 'all' 5 | alle _ _ PIDAT all 6 | aller _ _ PIDAT all 7 | all _ _ PIDAT all 8 | alle aller _ PIDAT all 9 | 10 | # ebensolch* is lemmatized to 'ebensolch' 11 | ebensolcher _ _ PIDAT ebensolch 12 | ebensolches _ _ PIDAT ebensolch 13 | ebensolchen _ _ PIDAT ebensolch 14 | ebensolch _ _ PIDAT ebensolch 15 | 16 | # ebensoviel* is lemmatized to 'ebensoviel' 17 | ebensoviele _ _ PIDAT ebensoviel 18 | ebensovielen _ _ PIDAT ebensoviel 19 | ebensoviel _ _ PIDAT ebensoviel 20 | 21 | # don't touch words with different tag 22 | alle foo _ PIRAT foo 23 | 24 | # don't touch words that are lemmatized to the same anyways 25 | sämtlich sämtlich _ PIDAT sämtlich 26 | 27 | 28 | # all unique forms for _ PIDAT 29 | solche solche _ PIDAT solch 30 | viele viele _ PIDAT viel 31 | paar paar _ PIDAT paar 32 | jeden jeder _ PIDAT jed 33 | bißchen bißchen _ PIDAT bißchen 34 | vielen viele _ PIDAT viel 35 | alle alle _ PIDAT all 36 | solch solch _ PIDAT solch 37 | all all _ PIDAT all 38 | aller alles _ PIDAT all 39 | wenigen wenige _ PIDAT wenig 40 | jede jede _ PIDAT jed 41 | beiden beide _ PIDAT beide 42 | jedes jedes _ PIDAT jed 43 | jeder jeder _ PIDAT jed 44 | beide beide _ PIDAT beide 45 | jegliche jegliche _ PIDAT jeglich 46 | solcher solcher 
_ PIDAT solch 47 | allen alle _ PIDAT all 48 | beider beide _ PIDAT beide 49 | vieler viele _ PIDAT viel 50 | solchen solcher _ PIDAT solch 51 | jedem jedes _ PIDAT jed 52 | wenige wenige _ PIDAT wenig 53 | alles alles _ PIDAT all 54 | meisten meisten _ PIDAT meist 55 | solches solches _ PIDAT solch 56 | ebensoviele ebensovieles _ PIDAT ebensoviel 57 | wenig wenig _ PIDAT wenig 58 | sämtliche sämtlich _ PIDAT sämtlich 59 | weniger wenige _ PIDAT wenig 60 | allem alles _ PIDAT all 61 | jede/r jeder _ PIDAT jed 62 | jeglichem jeglicher _ PIDAT jeglich 63 | jedwede jedwede _ PIDAT jedwed 64 | manch manch _ PIDAT manch 65 | soviele sovieles _ PIDAT soviel 66 | jeglicher jegliche _ PIDAT jeglich 67 | -zig zig _ PIDAT zig 68 | sämtlicher sämtlich _ PIDAT sämtlich 69 | jedweder jedwede _ PIDAT jedwed 70 | sämtliches sämtlich _ PIDAT sämtlich 71 | wenigsten wenigste _ PIDAT wenigst 72 | zuviele zuvieler _ PIDAT zuviel 73 | jegliches jegliches _ PIDAT jeglich 74 | jeglichen jeglicher _ PIDAT jeglich 75 | solchem solches _ PIDAT solch 76 | meiste meistes _ PIDAT meist 77 | viel vieles _ PIDAT viel 78 | zig zig _ PIDAT zig 79 | sämtlichen sämtlich _ PIDAT sämtlich 80 | allermeisten allermeisten _ PIDAT allermeisten 81 | bei-den beide _ PIDAT beide 82 | soviel sovieler _ PIDAT soviel 83 | wenigste wenigste _ PIDAT wenigst 84 | ebensolches ebensolche _ PIDAT ebensolch 85 | solchenen solche_ _ PIDAT solch 86 | #arrel alles _ PIDAT all 87 | ebensolchem ebensolcher _ PIDAT ebensolch 88 | jedwedem jedweder _ PIDAT jedwed 89 | vielem vieles _ PIDAT viel 90 | ebensolche ebensolche _ PIDAT ebensolch 91 | ebensolchen ebensolche _ PIDAT ebensolch 92 | jedem/r jeder _ PIDAT jed 93 | solch` solch _ PIDAT solch 94 | zuviel zuviele _ PIDAT zuviel 95 | ebensovielen ebensoviele _ PIDAT ebensoviel 96 | paare paar _ PIDAT paar 97 | ebensolcher ebensolch _ PIDAT ebensolch 98 | jeglich'm jeglicher_ _ PIDAT jeglich 99 | gleichviel gleichvieler _ PIDAT gleichvieler 100 | jeglich jegliche _ PIDAT jeglich 101 | bischen bißchen _ PIDAT bißchen 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SyntaxDot 2 | 3 | ## Introduction 4 | 5 | **SyntaxDot** is a sequence labeler and dependency parser using 6 | [Transformer](https://arxiv.org/abs/1706.03762) networks. SyntaxDot models can 7 | be trained from scratch or using pretrained models, such as 8 | [BERT](https://arxiv.org/abs/1810.04805v2) or 9 | [XLM-RoBERTa](https://arxiv.org/abs/1911.02116). 10 | 11 | In principle, SyntaxDot can be used to perform any sequence labeling 12 | task, but so far the focus has been on: 13 | 14 | * Part-of-speech tagging 15 | * Morphological tagging 16 | * Topological field tagging 17 | * Lemmatization 18 | * Named entity recognition 19 | 20 | The easiest way to get started with SyntaxDot is to [use a pretrained 21 | sticker2 22 | model](https://github.com/stickeritis/sticker2/blob/master/doc/pretrained.md) 23 | (SyntaxDot is currently compatbile with sticker2 models). 24 | 25 | ## Features 26 | 27 | * Input representations: 28 | - Word pieces 29 | - Sentence pieces 30 | * Flexible sequence encoder/decoder architecture, which supports: 31 | * Simple sequence labels (e.g. POS, morphology, named entities) 32 | * Lemmatization, based on edit trees 33 | * Simple API to extend to other tasks 34 | * Dependency parsing as sequence labeling 35 | * Dependency parsing using deep biaffine attention and MST decoding. 
36 | * Multi-task training and classification using scalar weighting. 37 | * Encoder models: 38 | * Transformers 39 | * Finetuning of BERT, XLM-RoBERTa, ALBERT, and SqueezeBERT models 40 | * Model distillation 41 | * Deployment: 42 | * Standalone binary that links against PyTorch's `libtorch` 43 | * Very liberal [license](LICENSE.md) 44 | 45 | ## Documentation 46 | 47 | * [Installation](doc/install.md) 48 | * [Finetuning](doc/finetune.md) (training) 49 | * [Ready-to-use models](doc/models.md) 50 | 51 | ## References 52 | 53 | SyntaxDot uses techniques from or was inspired by the following papers: 54 | 55 | * The biaffine dependency parsing layer is based on [Deep biaffine attention for 56 | neural dependency parsing](https://arxiv.org/pdf/1611.01734.pdf). 57 | Timothy Dozat and Christopher Manning, ICLR 2017. 58 | * The model architecture and training regime was largely based on [75 59 | Languages, 1 Model: Parsing Universal Dependencies 60 | Universally](https://www.aclweb.org/anthology/D19-1279.pdf). Dan 61 | Kondratyuk and Milan Straka, 2019, Proceedings of the EMNLP 2019 and 62 | the 9th IJCNLP. 63 | * The tagging as sequence labeling scheme was proposed by [Dependency 64 | Parsing as a Sequence Labeling 65 | Task](https://www.degruyter.com/downloadpdf/j/pralin.2010.94.issue--1/v10108-010-0017-3/v10108-010-0017-3.pdf). Drahomíra 66 | Spoustová, Miroslav Spousta, 2010, The Prague Bulletin of 67 | Mathematical Linguistics, Volume 94. 68 | * The idea to combine this scheme with neural networks comes from 69 | [Viable Dependency Parsing as Sequence 70 | Labeling](https://www.aclweb.org/anthology/papers/N/N19/N19-1077/). Michalina 71 | Strzyz, David Vilares, Carlos Gómez-Rodríguez, 2019, Proceedings of 72 | the 2019 Conference of the North American Chapter of the Association 73 | for Computational Linguistics: Human Language Technologies 74 | * The encoding of lemmatization as edit trees was proposed in [Towards 75 | a Machine-Learning Architecture for Lexical Functional Grammar 76 | Parsing](http://grzegorz.chrupala.me/papers/phd-single.pdf). 77 | Grzegorz Chrupała, 2008, PhD dissertation, Dublin City University. 78 | 79 | ## Issues 80 | 81 | You can report bugs and feature requests in the [SyntaxDot issue 82 | tracker](https://github.com/tensordot/syntaxdot/issues). 83 | 84 | ## License 85 | 86 | For licensing information, see [COPYRIGHT.md](COPYRIGHT.md). 87 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/progress.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Seek, SeekFrom}; 2 | use std::time::Instant; 3 | 4 | use indicatif::{ProgressBar, ProgressStyle}; 5 | use syntaxdot_tokenizers::SentenceWithPieces; 6 | 7 | pub struct ReadProgress { 8 | inner: R, 9 | progress_bar: ProgressBar, 10 | } 11 | 12 | /// A progress bar that implements the `Read` and `Seek` traits. 13 | /// 14 | /// This wrapper of `indicatif`'s `ProgressBar` updates progress based on the 15 | /// current offset within the file. 16 | impl ReadProgress 17 | where 18 | R: Seek, 19 | { 20 | pub fn new(mut read: R) -> io::Result { 21 | let len = read.seek(SeekFrom::End(0))? 
+ 1; 22 | read.seek(SeekFrom::Start(0))?; 23 | let progress_bar = ProgressBar::new(len); 24 | progress_bar.set_style( 25 | ProgressStyle::default_bar() 26 | .template("{bar} {bytes}/{total_bytes}") 27 | .expect("Invalid progress style"), 28 | ); 29 | 30 | Ok(ReadProgress { 31 | inner: read, 32 | progress_bar, 33 | }) 34 | } 35 | 36 | pub fn progress_bar(&self) -> &ProgressBar { 37 | &self.progress_bar 38 | } 39 | } 40 | 41 | impl Read for ReadProgress 42 | where 43 | R: Read + Seek, 44 | { 45 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 46 | let n_read = self.inner.read(buf)?; 47 | let pos = self.inner.stream_position()?; 48 | self.progress_bar.set_position(pos); 49 | Ok(n_read) 50 | } 51 | } 52 | 53 | impl Seek for ReadProgress 54 | where 55 | R: Seek, 56 | { 57 | fn seek(&mut self, pos: SeekFrom) -> io::Result { 58 | let pos = self.inner.seek(pos)?; 59 | self.progress_bar.set_position(pos); 60 | Ok(pos) 61 | } 62 | } 63 | 64 | impl Drop for ReadProgress { 65 | fn drop(&mut self) { 66 | self.progress_bar.finish(); 67 | } 68 | } 69 | 70 | /// Measure the number of sentences processed per second. 71 | /// 72 | /// When an instance of `TaggerSpeed` is constructed, it takes the 73 | /// current time. `count_sentence` should be called for each sentence 74 | /// that was processed. A `TaggerSpeed` instance will print the 75 | /// processing speed to *stderr* when the instance is dropped. 76 | pub struct TaggerSpeed { 77 | start: Instant, 78 | n_pieces: usize, 79 | n_sentences: usize, 80 | } 81 | 82 | impl TaggerSpeed { 83 | /// Construct a new instance. 84 | pub fn new() -> Self { 85 | TaggerSpeed { 86 | start: Instant::now(), 87 | n_pieces: 0, 88 | n_sentences: 0, 89 | } 90 | } 91 | 92 | /// Count a processed sentences. 93 | pub fn count_sentence(&mut self, sentence: &SentenceWithPieces) { 94 | self.n_pieces += sentence.pieces.len(); 95 | self.n_sentences += 1; 96 | } 97 | } 98 | 99 | impl Default for TaggerSpeed { 100 | fn default() -> Self { 101 | TaggerSpeed::new() 102 | } 103 | } 104 | 105 | impl Drop for TaggerSpeed { 106 | fn drop(&mut self) { 107 | let elapsed = self.start.elapsed(); 108 | // From nightly-only as_secs_f32. 109 | let elapsed_secs = elapsed.as_secs_f32(); 110 | log::info!("Annotation took {:.1}s", elapsed_secs); 111 | log::info!( 112 | "Processed {} sentences, {:.0} sents/s, {} pieces, {:.0} pieces/s", 113 | self.n_sentences, 114 | self.n_sentences as f32 / elapsed_secs, 115 | self.n_pieces, 116 | self.n_pieces as f32 / elapsed_secs, 117 | ); 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/save.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::fs; 3 | 4 | use anyhow::{Context, Result}; 5 | use tch::nn::VarStore; 6 | 7 | #[derive(Copy, Clone, Eq, PartialEq)] 8 | pub enum CompletedUnit
<P>
{ 9 | /// A batch is completed with the given performance. 10 | Batch(P), 11 | 12 | /// An epoch is completed with the given performance. 13 | /// 14 | /// The performance of an epoch is typically evaluated against 15 | /// a validation set. 16 | Epoch(P), 17 | } 18 | 19 | /// Trait for model savers. 20 | pub trait Save
<P>
{ 21 | /// Save a model 22 | /// 23 | /// Calling this method amounts to a request to save a 24 | /// model. Whether an actual model is saved depends on the 25 | /// implementor. E.g. `EpochSaver` only saves a model for 26 | /// each epoch, so requests to save at a completed batch 27 | /// are ignored. 28 | /// 29 | /// The performance should be such that a better performance compares 30 | /// as larger. If smaller is better in a performance measure, the 31 | /// actual measure can be wrapped in `std::cmp::Reverse` to 32 | /// reverse the ordering. 33 | fn save(&mut self, vs: &VarStore, completed: CompletedUnit
<P>
) -> Result<()>; 34 | } 35 | 36 | /// Save best epochs with the best performance so far. 37 | #[derive(Clone)] 38 | pub struct BestEpochSaver
<P>
{ 39 | best_epoch_performance: Option
<P>
, 40 | best_epoch_paths: Option<VecDeque<String>>, 41 | epoch: usize, 42 | keep_best_epochs: Option<usize>, 43 | prefix: String, 44 | } 45 | 46 | impl
<P>
BestEpochSaver
<P>
{ 47 | pub fn new(prefix: impl Into, keep_best_epochs: Option) -> Self { 48 | BestEpochSaver { 49 | best_epoch_performance: None, 50 | best_epoch_paths: keep_best_epochs.map(VecDeque::with_capacity), 51 | epoch: 0, 52 | keep_best_epochs, 53 | prefix: prefix.into(), 54 | } 55 | } 56 | 57 | fn cleanup_old_best_steps(&mut self, step_path: String) { 58 | if let Some(best_epoch_paths) = &mut self.best_epoch_paths { 59 | if best_epoch_paths.len() == self.keep_best_epochs.unwrap() { 60 | let cleanup_step = best_epoch_paths.pop_front().expect("No steps?"); 61 | if let Err(err) = fs::remove_file(&cleanup_step) { 62 | log::error!("Cannot remove step parameters {}: {}", cleanup_step, err); 63 | } 64 | } 65 | 66 | best_epoch_paths.push_back(step_path); 67 | } 68 | } 69 | } 70 | 71 | impl
<P>
Save
<P>
for BestEpochSaver
<P>
72 | where 73 | P: PartialOrd, 74 | { 75 | fn save(&mut self, vs: &VarStore, completed: CompletedUnit
<P>
) -> Result<()> { 76 | if let CompletedUnit::Epoch(perf) = completed { 77 | let improvement = match self.best_epoch_performance { 78 | Some(ref mut best) => { 79 | if perf > *best { 80 | *best = perf; 81 | true 82 | } else { 83 | false 84 | } 85 | } 86 | None => { 87 | self.best_epoch_performance = Some(perf); 88 | true 89 | } 90 | }; 91 | 92 | if improvement { 93 | let path = format!("{}epoch-{}", self.prefix, self.epoch); 94 | vs.save(&path).context(format!( 95 | "Cannot save variable store for epoch {}", 96 | self.epoch 97 | ))?; 98 | 99 | self.cleanup_old_best_steps(path) 100 | } 101 | 102 | self.epoch += 1; 103 | } 104 | 105 | Ok(()) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/depseq/error.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | use thiserror::Error; 4 | use udgraph::graph::{Node, Sentence}; 5 | use udgraph::token::Token; 6 | 7 | /// Encoder errors. 8 | #[derive(Clone, Debug, Eq, Error, PartialEq)] 9 | pub enum EncodeError { 10 | /// The token does not have a head. 11 | MissingHead { token: usize, sent: Vec }, 12 | 13 | /// The token's head does not have a part-of-speech. 14 | MissingPos { sent: Vec, token: usize }, 15 | 16 | /// The token does not have a dependency relation. 17 | MissingRelation { token: usize, sent: Vec }, 18 | } 19 | 20 | impl EncodeError { 21 | /// Construct `EncodeError::MissingHead` from a CoNLL-X graph. 22 | /// 23 | /// Construct an error. `token` is the node index for which the 24 | /// error applies in `sentence`. 25 | pub fn missing_head(token: usize, sentence: &Sentence) -> EncodeError { 26 | EncodeError::MissingHead { 27 | sent: Self::sentence_to_forms(sentence), 28 | token: token - 1, 29 | } 30 | } 31 | 32 | /// Construct `EncodeError::MissingPOS` from a CoNLL-X graph. 33 | /// 34 | /// Construct an error. `token` is the node index for which the 35 | /// error applies in `sentence`. 36 | pub fn missing_pos(token: usize, sentence: &Sentence) -> EncodeError { 37 | EncodeError::MissingPos { 38 | sent: Self::sentence_to_forms(sentence), 39 | token: token - 1, 40 | } 41 | } 42 | 43 | /// Construct `EncodeError::MissingRelation` from a CoNLL-X graph. 44 | /// 45 | /// Construct an error. `token` is the node index for which the 46 | /// error applies in `sentence`. 
47 | pub fn missing_relation(token: usize, sentence: &Sentence) -> EncodeError { 48 | EncodeError::MissingRelation { 49 | sent: Self::sentence_to_forms(sentence), 50 | token: token - 1, 51 | } 52 | } 53 | 54 | fn format_bracketed(bracket_idx: usize, tokens: &[String]) -> String { 55 | let mut tokens = tokens.to_owned(); 56 | tokens.insert(bracket_idx + 1, "]".to_string()); 57 | tokens.insert(bracket_idx, "[".to_string()); 58 | 59 | tokens.join(" ") 60 | } 61 | 62 | fn sentence_to_forms(sentence: &Sentence) -> Vec { 63 | sentence 64 | .iter() 65 | .filter_map(Node::token) 66 | .map(Token::form) 67 | .map(ToOwned::to_owned) 68 | .collect() 69 | } 70 | } 71 | 72 | impl fmt::Display for EncodeError { 73 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 74 | use EncodeError::*; 75 | 76 | match self { 77 | MissingHead { token, sent } => write!( 78 | f, 79 | "Token does not have a head:\n\n{}\n", 80 | Self::format_bracketed(*token, sent), 81 | ), 82 | MissingPos { token, sent } => write!( 83 | f, 84 | "Head of token '{}' does not have a part-of-speech:\n\n{}\n", 85 | sent[*token], 86 | Self::format_bracketed(*token, sent), 87 | ), 88 | MissingRelation { token, sent } => write!( 89 | f, 90 | "Token does not have a dependency relation:\n\n{}\n", 91 | Self::format_bracketed(*token, sent), 92 | ), 93 | } 94 | } 95 | } 96 | 97 | /// Decoder errors. 98 | #[derive(Clone, Copy, Debug, Eq, Error, PartialEq)] 99 | pub(crate) enum DecodeError { 100 | /// The head position is out of bounds. 101 | #[error("position out of bounds")] 102 | PositionOutOfBounds, 103 | 104 | /// The head part-of-speech tag does not occur in the sentence. 105 | #[error("unknown part-of-speech tag")] 106 | InvalidPos, 107 | } 108 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/transform/delemmatization.rs: -------------------------------------------------------------------------------- 1 | //! Delemmatization transformations. 2 | //! 3 | //! This module provides transformations that converts TüBa-D/Z-style lemmas 4 | //! to `regular' lemmas. 5 | 6 | use super::{DependencyGraph, Transform}; 7 | use crate::lang::de::tdz::lemma::constants::*; 8 | 9 | /// Remove alternative lemma analyses. 10 | /// 11 | /// TüBa-D/Z sometimes provides multiple lemma analyses for a form. This 12 | /// transformation removes all but the first analysis. 13 | pub struct RemoveAlternatives; 14 | 15 | impl Transform for RemoveAlternatives { 16 | fn transform(&self, graph: &dyn DependencyGraph, node: usize) -> String { 17 | let token = graph.token(node); 18 | let mut lemma = token.lemma(); 19 | 20 | if token.xpos().starts_with(PUNCTUATION_PREFIX) 21 | || token.xpos() == NON_WORD_TAG 22 | || token.xpos() == FOREIGN_WORD_TAG 23 | { 24 | return lemma.to_owned(); 25 | } 26 | 27 | if let Some(idx) = lemma.find('|') { 28 | lemma = &lemma[..idx]; 29 | } 30 | 31 | lemma.to_owned() 32 | } 33 | } 34 | 35 | /// Replace reflexive tag. 36 | /// 37 | /// Reflexives use the special *#refl* lemma in TüBa-D/Z. This transformation 38 | /// replaces this pseudo-lemma by the lowercased form. 
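///
/// For example, a reflexive pronoun token with the form *sich* and the
/// pseudo-lemma *#refl* would be assigned the lemma *sich*.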
39 | pub struct RemoveReflexiveTag; 40 | 41 | impl Transform for RemoveReflexiveTag { 42 | fn transform(&self, graph: &dyn DependencyGraph, node: usize) -> String { 43 | let token = graph.token(node); 44 | let lemma = token.lemma(); 45 | 46 | if token.xpos() == REFLEXIVE_PERSONAL_PRONOUN_TAG { 47 | return token.form().to_lowercase(); 48 | } 49 | 50 | lemma.to_owned() 51 | } 52 | } 53 | 54 | /// Remove separable prefixes from verbs. 55 | /// 56 | /// TüBa-D/Z marks separable verb prefixes in the verb lemma. E.g. *ab#zeichnen*, 57 | /// where *ab* is the separable prefix. This transformation handles removes 58 | /// separable prefixes from verbs. For example *ab#zeichnen* is transformed to 59 | /// *zeichnen*. 60 | pub struct RemoveSepVerbPrefix; 61 | 62 | impl Transform for RemoveSepVerbPrefix { 63 | fn transform(&self, graph: &dyn DependencyGraph, node: usize) -> String { 64 | let token = graph.token(node); 65 | let mut lemma = token.lemma(); 66 | 67 | if is_verb(token.xpos()) { 68 | if let Some(idx) = lemma.rfind('#') { 69 | lemma = &lemma[idx + 1..]; 70 | } 71 | } 72 | 73 | lemma.to_owned() 74 | } 75 | } 76 | 77 | /// Remove truncation markers. 78 | /// 79 | /// TüBa-D/Z uses special marking for truncations. For example, *Bau-* in 80 | /// 81 | /// *Bau- und Verkehrsplanungen* 82 | /// 83 | /// is lemmatized as *Bauplanung%n*, recovering the full lemma and adding 84 | /// a simplified part of speech tag of the word (since the form is tagged 85 | /// as *TRUNC*). 86 | /// 87 | /// This transformation replaces the TüBa-D/Z lemma by the word form, such 88 | /// as *Bau-* in this example. If the simplified part of speech tag is not 89 | /// *n*, the lemma is also lowercased. 90 | pub struct RemoveTruncMarker; 91 | 92 | impl Transform for RemoveTruncMarker { 93 | fn transform(&self, graph: &dyn DependencyGraph, node: usize) -> String { 94 | let token = graph.token(node); 95 | let lemma = token.lemma(); 96 | 97 | if token.xpos() != TRUNCATED_TAG { 98 | return lemma.to_owned(); 99 | } 100 | 101 | if token.upos() == "NOUN" { 102 | token.form().to_owned() 103 | } else { 104 | token.form().to_lowercase() 105 | } 106 | } 107 | } 108 | 109 | #[cfg(test)] 110 | mod tests { 111 | use crate::lang::de::tdz::lemma::transform::test_helpers::run_test_cases; 112 | 113 | use super::{RemoveSepVerbPrefix, RemoveTruncMarker}; 114 | 115 | #[test] 116 | pub fn remove_sep_verb_prefix() { 117 | run_test_cases( 118 | "testdata/lang/de/tdz/lemma/remove-sep-verb-prefix.test", 119 | RemoveSepVerbPrefix, 120 | ); 121 | } 122 | 123 | #[test] 124 | pub fn remove_trunc_marker() { 125 | run_test_cases( 126 | "testdata/lang/de/tdz/lemma/remove-trunc-marker.test", 127 | RemoveTruncMarker, 128 | ); 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /syntaxdot/src/dataset/sentence_itertools.rs: -------------------------------------------------------------------------------- 1 | use rand::SeedableRng; 2 | use rand_xorshift::XorShiftRng; 3 | use syntaxdot_tokenizers::SentenceWithPieces; 4 | 5 | use crate::error::SyntaxDotError; 6 | use crate::util::RandomRemoveVec; 7 | 8 | /// The length of a sequence. 9 | /// 10 | /// This enum can be used to express the (maximum) length of a 11 | /// sentence in tokens or in pieces. 12 | #[derive(Debug, Clone, Copy, Eq, PartialEq)] 13 | pub enum SequenceLength { 14 | Tokens(usize), 15 | Pieces(usize), 16 | Unbounded, 17 | } 18 | 19 | /// Trait providing adapters for `SentenceWithPieces` iterators. 
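///
/// A hypothetical pipeline sketch (assumes `sentences` is an iterator over
/// `Result<SentenceWithPieces, SyntaxDotError>` and this trait is in scope):
///
/// ```ignore
/// let prepared = sentences
///     .filter_by_len(SequenceLength::Pieces(100))
///     .shuffle(1000);
/// ```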
20 | pub trait SentenceIterTools<'a>: Sized { 21 | /// Filter sentences by their length. 22 | /// 23 | /// If `max_len` is `None`, then the sentences will not be 24 | /// filtered by length. 25 | fn filter_by_len(self, max_len: SequenceLength) -> LengthFilter; 26 | 27 | /// Shuffle sentences. 28 | /// 29 | /// `buffer_size` is the size of the shuffle buffer that should be 30 | /// used. If `buffer_size` is `None`, then the sentences will not 31 | /// be shuffled. 32 | fn shuffle(self, buffer_size: usize) -> Shuffled; 33 | } 34 | 35 | impl<'a, I> SentenceIterTools<'a> for I 36 | where 37 | I: 'a + Iterator>, 38 | { 39 | fn filter_by_len(self, max_len: SequenceLength) -> LengthFilter { 40 | LengthFilter { 41 | inner: self, 42 | max_len, 43 | } 44 | } 45 | 46 | fn shuffle(self, buffer_size: usize) -> Shuffled { 47 | Shuffled { 48 | inner: self, 49 | buffer: RandomRemoveVec::with_capacity(buffer_size, XorShiftRng::from_entropy()), 50 | buffer_size, 51 | } 52 | } 53 | } 54 | 55 | /// An Iterator adapter filtering sentences by maximum length. 56 | pub struct LengthFilter { 57 | inner: I, 58 | max_len: SequenceLength, 59 | } 60 | 61 | impl Iterator for LengthFilter 62 | where 63 | I: Iterator>, 64 | { 65 | type Item = Result; 66 | 67 | fn next(&mut self) -> Option { 68 | for sent in &mut self.inner { 69 | // Treat Err as length 0 to keep our type as Result. The iterator 70 | // will properly return the Error at a later point. 71 | let too_long = match self.max_len { 72 | SequenceLength::Pieces(max_len) => { 73 | sent.as_ref().map(|s| s.pieces.len()).unwrap_or(0) > max_len 74 | } 75 | SequenceLength::Tokens(max_len) => { 76 | sent.as_ref().map(|s| s.token_offsets.len()).unwrap_or(0) > max_len 77 | } 78 | SequenceLength::Unbounded => false, 79 | }; 80 | 81 | if too_long { 82 | continue; 83 | } 84 | 85 | return Some(sent); 86 | } 87 | None 88 | } 89 | } 90 | 91 | /// An Iterator adapter performing local shuffling. 92 | /// 93 | /// Fills a buffer with size `buffer_size` on the first 94 | /// call. Subsequent calls add the next incoming item to the buffer 95 | /// and pick a random element from the buffer. 96 | pub struct Shuffled { 97 | inner: I, 98 | buffer: RandomRemoveVec, 99 | buffer_size: usize, 100 | } 101 | 102 | impl Iterator for Shuffled 103 | where 104 | I: Iterator>, 105 | { 106 | type Item = Result; 107 | 108 | fn next(&mut self) -> Option { 109 | if self.buffer.is_empty() { 110 | for sent in &mut self.inner { 111 | match sent { 112 | Ok(sent) => self.buffer.push(sent), 113 | Err(err) => return Some(Err(err)), 114 | } 115 | 116 | if self.buffer.len() == self.buffer_size { 117 | break; 118 | } 119 | } 120 | } 121 | 122 | match self.inner.next() { 123 | Some(sent) => match sent { 124 | Ok(sent) => Some(Ok(self.buffer.push_and_remove_random(sent))), 125 | Err(err) => Some(Err(err)), 126 | }, 127 | None => self.buffer.remove_random().map(Result::Ok), 128 | } 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/transform/svp.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::collections::VecDeque; 3 | 4 | use fst::Set; 5 | 6 | use crate::lang::de::tdz::lemma::automaton::Prefixes; 7 | use crate::lang::de::tdz::lemma::constants::*; 8 | 9 | /// Candidate list of prefixes and the corresponding stripped form. 
10 | struct PrefixesCandidate<'a> { 11 | stripped_form: &'a str, 12 | prefixes: Vec, 13 | } 14 | 15 | /// Look for all matches of (prefix)* in the given form. Ideally, 16 | /// we'd construct a Kleene star automaton of the prefix automaton. 17 | /// Unfortunately, this functionality is not (yet) provided by the 18 | /// fst crate. Instead, we repeatedly search prefixes in the set. 19 | fn prefix_star<'a, D>(prefix_set: &Set, s: &'a str) -> Vec> 20 | where 21 | D: AsRef<[u8]>, 22 | { 23 | let mut result = Vec::new(); 24 | 25 | let mut q = VecDeque::new(); 26 | q.push_back(PrefixesCandidate { 27 | stripped_form: s, 28 | prefixes: Vec::new(), 29 | }); 30 | 31 | while let Some(PrefixesCandidate { 32 | stripped_form, 33 | prefixes, 34 | }) = q.pop_front() 35 | { 36 | result.push(PrefixesCandidate { 37 | stripped_form, 38 | prefixes: prefixes.clone(), 39 | }); 40 | 41 | for prefix in prefix_set.prefixes(stripped_form) { 42 | let mut prefixes = prefixes.clone(); 43 | let prefix_len = prefix.len(); 44 | prefixes.push(prefix.to_owned()); 45 | q.push_back(PrefixesCandidate { 46 | stripped_form: &stripped_form[prefix_len..], 47 | prefixes, 48 | }); 49 | } 50 | } 51 | 52 | result 53 | } 54 | 55 | pub fn longest_prefixes(prefix_set: &Set, form: F, lemma: L, tag: T) -> Vec 56 | where 57 | D: AsRef<[u8]>, 58 | F: AsRef, 59 | L: AsRef, 60 | T: AsRef, 61 | { 62 | let lemma = lemma.as_ref(); 63 | let form = form.as_ref(); 64 | let tag = tag.as_ref(); 65 | 66 | let all_prefixes = prefix_star(prefix_set, form); 67 | 68 | FilterPrefixes { 69 | inner: all_prefixes.into_iter(), 70 | lemma, 71 | tag, 72 | } 73 | .max_by(|l, r| { 74 | match l.stripped_form.len().cmp(&r.stripped_form.len()) { 75 | Ordering::Less => return Ordering::Greater, 76 | Ordering::Greater => return Ordering::Less, 77 | Ordering::Equal => (), 78 | } 79 | 80 | l.prefixes.len().cmp(&r.prefixes.len()).reverse() 81 | }) 82 | .map(|t| t.prefixes) 83 | .unwrap_or_else(Vec::new) 84 | } 85 | 86 | fn is_verb(verb: S) -> bool 87 | where 88 | S: AsRef, 89 | { 90 | // A separable verb with a length shorter than 3 is unlikely. 91 | verb.as_ref().len() > 2 92 | } 93 | 94 | struct FilterPrefixes<'a, I> 95 | where 96 | I: Iterator>, 97 | { 98 | lemma: &'a str, 99 | tag: &'a str, 100 | inner: I, 101 | } 102 | 103 | impl<'a, I> Iterator for FilterPrefixes<'a, I> 104 | where 105 | I: Iterator>, 106 | { 107 | type Item = PrefixesCandidate<'a>; 108 | 109 | fn next(&mut self) -> Option { 110 | while let Some(candidate) = self.inner.next() { 111 | if candidate.prefixes.is_empty() { 112 | return Some(candidate); 113 | } 114 | 115 | // I don't like the to_owned() here, but as of 1.14-nightly, the 116 | // borrows checker is not happy about moving candidate otherwise. 117 | let last_prefix = candidate.prefixes.last().unwrap().to_owned(); 118 | 119 | // Avoid e.g. 'dazu' as a valid prefix for a zu-infinitive. 120 | if self.tag == ZU_INFINITIVE_VERB 121 | && last_prefix.ends_with("zu") 122 | && !candidate.stripped_form.starts_with("zu") 123 | { 124 | continue; 125 | } 126 | 127 | // 1. Do not start stripping parts of the lemma 128 | // 2. Prefix should not end with lemma. 
E.g.: 129 | // abgefangen fangen -> ab#fangen, not: ab#gefangen#fangen 130 | if candidate.prefixes.iter().any(|p| self.lemma.starts_with(p)) 131 | || last_prefix.ends_with(&self.lemma) 132 | || !is_verb(candidate.stripped_form) 133 | { 134 | continue; 135 | } 136 | 137 | return Some(candidate); 138 | } 139 | 140 | None 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /syntaxdot-cli/src/subcommands/prepare.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufReader, Write}; 3 | 4 | use anyhow::{Context, Result}; 5 | use clap::{Arg, ArgMatches, Command}; 6 | use conllu::io::{ReadSentence, Reader}; 7 | use indicatif::ProgressStyle; 8 | use syntaxdot::config::{BiaffineParserConfig, Config}; 9 | use syntaxdot::encoders::Encoders; 10 | use syntaxdot_encoders::SentenceEncoder; 11 | 12 | use crate::io::load_config; 13 | use crate::progress::ReadProgress; 14 | use crate::traits::SyntaxDotApp; 15 | use syntaxdot_encoders::dependency::MutableDependencyEncoder; 16 | 17 | const CONFIG: &str = "CONFIG"; 18 | static TRAIN_DATA: &str = "TRAIN_DATA"; 19 | 20 | pub struct PrepareApp { 21 | config: String, 22 | train_data: String, 23 | } 24 | 25 | impl PrepareApp { 26 | fn write_dependency_labels( 27 | config: &BiaffineParserConfig, 28 | encoder: &MutableDependencyEncoder, 29 | ) -> Result<()> { 30 | let mut f = File::create(&config.labels).context(format!( 31 | "Cannot create dependency label file: {}", 32 | config.labels 33 | ))?; 34 | let serialized_labels = 35 | serde_yaml::to_string(&encoder).context("Cannot serialize labels")?; 36 | f.write_all(serialized_labels.as_bytes()) 37 | .context("Cannot write labels") 38 | } 39 | 40 | fn write_labels(config: &Config, encoders: &Encoders) -> Result<()> { 41 | let mut f = File::create(&config.labeler.labels).context(format!( 42 | "Cannot create label file: {}", 43 | config.labeler.labels 44 | ))?; 45 | let serialized_labels = 46 | serde_yaml::to_string(&encoders).context("Cannot serialize labels")?; 47 | f.write_all(serialized_labels.as_bytes()) 48 | .context("Cannot write labels") 49 | } 50 | } 51 | 52 | impl SyntaxDotApp for PrepareApp { 53 | fn app() -> Command { 54 | Command::new("prepare") 55 | .arg_required_else_help(true) 56 | .about("Prepare shape and label files for training") 57 | .arg( 58 | Arg::new(CONFIG) 59 | .help("SyntaxDot configuration file") 60 | .index(1) 61 | .required(true), 62 | ) 63 | .arg( 64 | Arg::new(TRAIN_DATA) 65 | .help("Training data") 66 | .index(2) 67 | .required(true), 68 | ) 69 | } 70 | 71 | fn parse(matches: &ArgMatches) -> Result { 72 | let config = matches.get_one::(CONFIG).unwrap().into(); 73 | let train_data = matches.get_one::(TRAIN_DATA).unwrap().into(); 74 | 75 | Ok(PrepareApp { config, train_data }) 76 | } 77 | 78 | fn run(&self) -> Result<()> { 79 | let config = load_config(&self.config)?; 80 | 81 | let mut biaffine_decoder = config.biaffine.as_ref().map(MutableDependencyEncoder::from); 82 | 83 | let encoders: Encoders = (&config.labeler.encoders).into(); 84 | 85 | let train_file = File::open(&self.train_data) 86 | .context(format!("Cannot open train data file: {}", self.train_data))?; 87 | let read_progress = ReadProgress::new(train_file).context("Cannot create progress bar")?; 88 | let progress_bar = read_progress.progress_bar().clone(); 89 | progress_bar 90 | .set_style(ProgressStyle::default_bar().template( 91 | "[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}% {msg}", 92 
| )?); 93 | 94 | let treebank_reader = Reader::new(BufReader::new(read_progress)); 95 | 96 | for sentence in treebank_reader.sentences() { 97 | let sentence = sentence.context("Cannot read sentence from treebank")?; 98 | 99 | for encoder in &*encoders { 100 | encoder.encoder().encode(&sentence).context(format!( 101 | "Cannot encode sentence with encoder {}", 102 | encoder.name() 103 | ))?; 104 | } 105 | 106 | if let Some(biaffine_decoder) = biaffine_decoder.as_mut() { 107 | biaffine_decoder.encode(&sentence)?; 108 | } 109 | } 110 | 111 | Self::write_labels(&config, &encoders)?; 112 | 113 | if let Some(biaffine_decoder) = biaffine_decoder.as_ref() { 114 | Self::write_dependency_labels( 115 | config 116 | .biaffine 117 | .as_ref() 118 | .expect("Biaffine parser without configuration?"), 119 | biaffine_decoder, 120 | )?; 121 | } 122 | 123 | Ok(()) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/src/bert.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufRead, BufReader}; 3 | 4 | use udgraph::graph::{Node, Sentence}; 5 | use wordpieces::WordPieces; 6 | 7 | use super::{SentenceWithPieces, Tokenize}; 8 | use crate::TokenizerError; 9 | use std::path::Path; 10 | 11 | /// BERT word piece tokenizer. 12 | /// 13 | /// This tokenizer splits CoNLL-X tokens into word pieces. For 14 | /// example, a sentence such as: 15 | /// 16 | /// > Veruntreute die AWO Spendengeld ? 17 | /// 18 | /// Could be split (depending on the vocabulary) into the following 19 | /// word pieces: 20 | /// 21 | /// > Ver ##unt ##reute die A ##W ##O Spenden ##geld [UNK] 22 | /// 23 | /// Then vocabulary index of each such piece is returned. 24 | /// 25 | /// The unknown token (here `[UNK]`) can be specified while 26 | /// constructing a tokenizer. 27 | pub struct BertTokenizer { 28 | word_pieces: WordPieces, 29 | unknown_piece: String, 30 | } 31 | 32 | impl BertTokenizer { 33 | /// Construct a tokenizer from wordpieces and the unknown piece. 34 | pub fn new(word_pieces: WordPieces, unknown_piece: impl Into) -> Self { 35 | BertTokenizer { 36 | word_pieces, 37 | unknown_piece: unknown_piece.into(), 38 | } 39 | } 40 | 41 | pub fn open
<P>
(model_path: P, unknown_piece: impl Into) -> Result 42 | where 43 | P: AsRef, 44 | { 45 | let model_path = model_path.as_ref(); 46 | let f = File::open(model_path) 47 | .map_err(|err| TokenizerError::open_error(model_path.to_string_lossy(), err))?; 48 | Self::read(BufReader::new(f), unknown_piece) 49 | } 50 | 51 | pub fn read( 52 | buf_read: R, 53 | unknown_piece: impl Into, 54 | ) -> Result 55 | where 56 | R: BufRead, 57 | { 58 | let word_pieces = WordPieces::from_buf_read(buf_read)?; 59 | Ok(Self::new(word_pieces, unknown_piece)) 60 | } 61 | } 62 | 63 | impl Tokenize for BertTokenizer { 64 | fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces { 65 | // An average of three pieces per token ought to enough for 66 | // everyone ;). 67 | let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3); 68 | let mut token_offsets = Vec::with_capacity(sentence.len()); 69 | 70 | pieces.push( 71 | self.word_pieces 72 | .get_initial("[CLS]") 73 | .expect("BERT model does not have a [CLS] token") as i64, 74 | ); 75 | 76 | for token in sentence.iter().filter_map(Node::token) { 77 | token_offsets.push(pieces.len()); 78 | 79 | match self 80 | .word_pieces 81 | .split(token.form()) 82 | .map(|piece| piece.idx().map(|piece| piece as i64)) 83 | .collect::>>() 84 | { 85 | Some(word_pieces) => pieces.extend(word_pieces), 86 | None => pieces.push( 87 | self.word_pieces 88 | .get_initial(&self.unknown_piece) 89 | .expect("Cannot get unknown piece") as i64, 90 | ), 91 | } 92 | } 93 | 94 | SentenceWithPieces { 95 | pieces: pieces.into(), 96 | sentence, 97 | token_offsets, 98 | } 99 | } 100 | } 101 | 102 | #[cfg(feature = "model-tests")] 103 | #[cfg(test)] 104 | mod tests { 105 | use std::fs::File; 106 | use std::io::BufReader; 107 | use std::iter::FromIterator; 108 | 109 | use ndarray::array; 110 | use udgraph::graph::Sentence; 111 | use udgraph::token::Token; 112 | use wordpieces::WordPieces; 113 | 114 | use super::BertTokenizer; 115 | use crate::Tokenize; 116 | 117 | fn read_pieces() -> WordPieces { 118 | let f = File::open(env!("BERT_BASE_GERMAN_CASED_VOCAB")).unwrap(); 119 | WordPieces::from_buf_read(BufReader::new(f)).unwrap() 120 | } 121 | 122 | fn sentence_from_forms(forms: &[&str]) -> Sentence { 123 | Sentence::from_iter(forms.iter().map(|&f| Token::new(f))) 124 | } 125 | 126 | #[test] 127 | fn test_pieces() { 128 | let tokenizer = BertTokenizer::new(read_pieces(), "[UNK]"); 129 | 130 | let sentence = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]); 131 | 132 | let sentence_pieces = tokenizer.tokenize(sentence); 133 | assert_eq!( 134 | sentence_pieces.pieces, 135 | array![3i64, 133, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 26972] 136 | ); 137 | assert_eq!(sentence_pieces.token_offsets, &[1, 4, 5, 8, 10]); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/src/albert.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use sentencepiece::SentencePieceProcessor; 4 | use udgraph::graph::{Node, Sentence}; 5 | 6 | use super::{SentenceWithPieces, Tokenize}; 7 | use crate::TokenizerError; 8 | 9 | /// Tokenizer for ALBERT models. 10 | /// 11 | /// ALBERT uses the sentencepiece tokenizer. However, we cannot use 12 | /// it in the intended way: we would have to detokenize sentences and 13 | /// it is not guaranteed that each token has a unique piece, which is 14 | /// required in sequence labeling. 
So instead, we use the tokenizer as 15 | /// a subword tokenizer. 16 | pub struct AlbertTokenizer { 17 | spp: SentencePieceProcessor, 18 | } 19 | 20 | impl AlbertTokenizer { 21 | pub fn new(spp: SentencePieceProcessor) -> Self { 22 | AlbertTokenizer { spp } 23 | } 24 | 25 | pub fn open
<P>
(model: P) -> Result 26 | where 27 | P: AsRef, 28 | { 29 | let spp = SentencePieceProcessor::open(model)?; 30 | Ok(Self::new(spp)) 31 | } 32 | } 33 | 34 | impl From for AlbertTokenizer { 35 | fn from(spp: SentencePieceProcessor) -> Self { 36 | AlbertTokenizer::new(spp) 37 | } 38 | } 39 | 40 | impl Tokenize for AlbertTokenizer { 41 | fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces { 42 | // An average of three pieces per token ought to be enough for 43 | // everyone ;). 44 | let mut pieces = Vec::with_capacity((sentence.len() + 1) * 3); 45 | let mut token_offsets = Vec::with_capacity(sentence.len()); 46 | 47 | pieces.push( 48 | self.spp 49 | .piece_to_id("[CLS]") 50 | .expect("ALBERT model does not have a [CLS] token") 51 | .expect("ALBERT model does not have a [CLS] token") as i64, 52 | ); 53 | 54 | for token in sentence.iter().filter_map(Node::token) { 55 | token_offsets.push(pieces.len()); 56 | 57 | let token_pieces = self 58 | .spp 59 | .encode(token.form()) 60 | .expect("The sentencepiece tokenizer failed"); 61 | 62 | if !token_pieces.is_empty() { 63 | pieces.extend(token_pieces.into_iter().map(|piece| piece.id as i64)); 64 | } else { 65 | // Use the unknown token id if sentencepiece does not 66 | // give an output for the token. This should not 67 | // happen under normal circumstances, since 68 | // sentencepiece does return this id for unknown 69 | // tokens. However, the input may be corrupt and use 70 | // some form of non-tab whitespace as a form, for which 71 | // sentencepiece does not return any identifier. 72 | pieces.push(self.spp.unk_id() as i64); 73 | } 74 | } 75 | 76 | pieces.push( 77 | self.spp 78 | .piece_to_id("[SEP]") 79 | .expect("ALBERT model does not have a [SEP] token") 80 | .expect("ALBERT model does not have a [SEP] token") as i64, 81 | ); 82 | 83 | SentenceWithPieces { 84 | pieces: pieces.into(), 85 | sentence, 86 | token_offsets, 87 | } 88 | } 89 | } 90 | 91 | #[cfg(feature = "model-tests")] 92 | #[cfg(test)] 93 | mod tests { 94 | use std::iter::FromIterator; 95 | 96 | use ndarray::array; 97 | use sentencepiece::SentencePieceProcessor; 98 | use udgraph::graph::Sentence; 99 | use udgraph::token::Token; 100 | 101 | use super::AlbertTokenizer; 102 | use crate::Tokenize; 103 | 104 | fn sentence_from_forms(forms: &[&str]) -> Sentence { 105 | Sentence::from_iter(forms.iter().map(|&f| Token::new(f))) 106 | } 107 | 108 | fn albert_tokenizer() -> AlbertTokenizer { 109 | let spp = SentencePieceProcessor::open(env!("ALBERT_BASE_V2_SENTENCEPIECE")).unwrap(); 110 | AlbertTokenizer::new(spp) 111 | } 112 | 113 | #[test] 114 | fn tokenizer_gives_expected_output() { 115 | let tokenizer = albert_tokenizer(); 116 | let sent = sentence_from_forms(&["pierre", "vinken", "will", "join", "the", "board", "."]); 117 | let pieces = tokenizer.tokenize(sent); 118 | assert_eq!( 119 | pieces.pieces, 120 | array![2, 5399, 9730, 2853, 129, 1865, 14, 686, 13, 9, 3] 121 | ); 122 | } 123 | 124 | #[test] 125 | fn handles_missing_sentence_pieces() { 126 | let tokenizer = albert_tokenizer(); 127 | let sent = sentence_from_forms(&["pierre", " ", "vinken"]); 128 | let pieces = tokenizer.tokenize(sent); 129 | assert_eq!(pieces.pieces, array![2, 5399, 1, 9730, 2853, 3]); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /syntaxdot-tokenizers/src/xlm_roberta.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use sentencepiece::SentencePieceProcessor; 4 | use 
udgraph::graph::{Node, Sentence}; 5 | 6 | use super::{SentenceWithPieces, Tokenize}; 7 | use crate::TokenizerError; 8 | 9 | const FAIRSEQ_BOS_ID: i64 = 0; 10 | const FAIRSEQ_EOS_ID: i64 = 2; 11 | const FAIRSEQ_OFFSET: i64 = 1; 12 | const FAIRSEQ_UNK: i64 = 3; 13 | 14 | /// Tokenizer for Roberta models. 15 | /// 16 | /// Roberta uses the sentencepiece tokenizer. However, we cannot use 17 | /// it in the intended way: we would have to detokenize sentences and 18 | /// it is not guaranteed that each token has a unique piece, which is 19 | /// required in sequence labeling. So instead, we use the tokenizer as 20 | /// a subword tokenizer. 21 | pub struct XlmRobertaTokenizer { 22 | spp: SentencePieceProcessor, 23 | } 24 | 25 | impl XlmRobertaTokenizer { 26 | pub fn new(spp: SentencePieceProcessor) -> Self { 27 | XlmRobertaTokenizer { spp } 28 | } 29 | 30 | pub fn open

(model: P) -> Result 31 | where 32 | P: AsRef, 33 | { 34 | let spp = SentencePieceProcessor::open(model)?; 35 | Ok(Self::new(spp)) 36 | } 37 | } 38 | 39 | impl From for XlmRobertaTokenizer { 40 | fn from(spp: SentencePieceProcessor) -> Self { 41 | XlmRobertaTokenizer::new(spp) 42 | } 43 | } 44 | 45 | impl Tokenize for XlmRobertaTokenizer { 46 | fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces { 47 | // An average of three pieces per token ought to be enough for 48 | // everyone ;). 49 | let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3); 50 | let mut token_offsets = Vec::with_capacity(sentence.len()); 51 | 52 | pieces.push(FAIRSEQ_BOS_ID); 53 | 54 | for token in sentence.iter().filter_map(Node::token) { 55 | token_offsets.push(pieces.len()); 56 | 57 | let token_pieces = self 58 | .spp 59 | .encode(token.form()) 60 | .expect("The sentencepiece tokenizer failed"); 61 | 62 | if !token_pieces.is_empty() { 63 | pieces.extend(token_pieces.into_iter().map(|piece| { 64 | let piece_id = piece.id as i64; 65 | if piece_id == self.spp.unk_id() as i64 { 66 | FAIRSEQ_UNK 67 | } else { 68 | piece_id + FAIRSEQ_OFFSET 69 | } 70 | })); 71 | } else { 72 | // Use the unknown token id if sentencepiece does not 73 | // give an output for the token. This should not 74 | // happen under normal circumstances, since 75 | // sentencepiece does return this id for unknown 76 | // tokens. However, the input may be corrupt and use 77 | // some form of non-tab whitespace as a form, for which 78 | // sentencepiece does not return any identifier. 79 | pieces.push(FAIRSEQ_UNK); 80 | } 81 | } 82 | 83 | pieces.push(FAIRSEQ_EOS_ID); 84 | 85 | SentenceWithPieces { 86 | pieces: pieces.into(), 87 | sentence, 88 | token_offsets, 89 | } 90 | } 91 | } 92 | 93 | #[cfg(feature = "model-tests")] 94 | #[cfg(test)] 95 | mod tests { 96 | use std::iter::FromIterator; 97 | 98 | use ndarray::array; 99 | use sentencepiece::SentencePieceProcessor; 100 | use udgraph::graph::Sentence; 101 | use udgraph::token::Token; 102 | 103 | use super::XlmRobertaTokenizer; 104 | use crate::Tokenize; 105 | 106 | fn sentence_from_forms(forms: &[&str]) -> Sentence { 107 | Sentence::from_iter(forms.iter().map(|&f| Token::new(f))) 108 | } 109 | 110 | fn xlm_roberta_tokenizer() -> XlmRobertaTokenizer { 111 | let spp = SentencePieceProcessor::open(env!("XLM_ROBERTA_BASE_SENTENCEPIECE")).unwrap(); 112 | XlmRobertaTokenizer::from(spp) 113 | } 114 | 115 | #[test] 116 | fn tokenizer_gives_expected_output() { 117 | let tokenizer = xlm_roberta_tokenizer(); 118 | let sent = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]); 119 | let pieces = tokenizer.tokenize(sent); 120 | assert_eq!( 121 | pieces.pieces, 122 | array![0, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2] 123 | ); 124 | } 125 | 126 | #[test] 127 | fn handles_missing_sentence_pieces() { 128 | let tokenizer = xlm_roberta_tokenizer(); 129 | let sent = sentence_from_forms(&["die", " ", "AWO"]); 130 | let pieces = tokenizer.tokenize(sent); 131 | assert_eq!(pieces.pieces, array![0, 68, 1, 62, 43789, 2]); 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/transform/test_helpers.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufRead, BufReader}; 3 | use std::path::Path; 4 | 5 | use petgraph::graph::{DiGraph, NodeIndex}; 6 | use petgraph::visit::EdgeRef; 7 | use petgraph::Direction; 8 | 9 | 
use super::{DependencyGraph, Token, TokenMut, Transform}; 10 | 11 | pub struct TestCase { 12 | graph: TestCaseGraph, 13 | index: usize, 14 | correct: String, 15 | } 16 | 17 | struct TestCaseGraph(pub DiGraph); 18 | 19 | impl DependencyGraph for TestCaseGraph { 20 | fn dependents<'a>(&'a self, idx: usize) -> Box + 'a> { 21 | Box::new( 22 | self.0 23 | .edges_directed(NodeIndex::new(idx), Direction::Outgoing) 24 | .map(|e| (e.target().index(), e.weight().to_owned())), 25 | ) 26 | } 27 | 28 | fn token(&self, idx: usize) -> &dyn Token { 29 | &self.0[NodeIndex::new(idx)] 30 | } 31 | 32 | fn token_mut(&mut self, idx: usize) -> &mut dyn TokenMut { 33 | &mut self.0[NodeIndex::new(idx)] 34 | } 35 | 36 | fn len(&self) -> usize { 37 | self.0.node_count() 38 | } 39 | } 40 | 41 | pub struct TestToken { 42 | form: String, 43 | lemma: String, 44 | upos: String, 45 | xpos: String, 46 | } 47 | 48 | impl Token for TestToken { 49 | fn form(&self) -> &str { 50 | &self.form 51 | } 52 | 53 | fn lemma(&self) -> &str { 54 | &self.lemma 55 | } 56 | 57 | fn upos(&self) -> &str { 58 | &self.upos 59 | } 60 | 61 | fn xpos(&self) -> &str { 62 | &self.xpos 63 | } 64 | } 65 | 66 | impl TokenMut for TestToken { 67 | fn set_lemma(&mut self, lemma: Option) { 68 | self.lemma = lemma.expect("Missing lemma for test token"); 69 | } 70 | } 71 | 72 | fn read_dependency(iter: &mut dyn Iterator) -> Option<(String, TestToken)> { 73 | // If there is a relation, read it, otherwise bail out. 74 | let rel = iter.next()?.to_owned(); 75 | 76 | // However, if there is a relation and no token, panic. 77 | Some(( 78 | rel, 79 | read_token(iter).expect("Incomplete dependency relation"), 80 | )) 81 | } 82 | 83 | fn read_token(iter: &mut dyn Iterator) -> Option { 84 | Some(TestToken { 85 | form: iter.next()?.to_owned(), 86 | lemma: iter.next()?.to_owned(), 87 | upos: iter.next()?.to_owned(), 88 | xpos: iter.next()?.to_owned(), 89 | }) 90 | } 91 | 92 | fn read_test_cases(buf_read: R) -> Vec 93 | where 94 | R: BufRead, 95 | { 96 | let mut test_cases = Vec::new(); 97 | 98 | for line in buf_read.lines() { 99 | let line = line.unwrap(); 100 | let line_str = line.trim(); 101 | 102 | // Skip empty lines 103 | if line_str.is_empty() { 104 | continue; 105 | } 106 | 107 | // Skip comments 108 | if line_str.starts_with('#') { 109 | continue; 110 | } 111 | 112 | let mut iter = line.split_whitespace(); 113 | 114 | let mut graph = DiGraph::new(); 115 | 116 | graph.add_node(TestToken { 117 | form: "ROOT".to_string(), 118 | lemma: "ROOT".to_string(), 119 | upos: "root".to_string(), 120 | xpos: "root".to_string(), 121 | }); 122 | 123 | let test_token = read_token(&mut iter).unwrap(); 124 | let index = graph.add_node(test_token); 125 | let correct = iter 126 | .next() 127 | .unwrap_or_else(|| panic!("Gold standard lemma missing: {}", line_str)) 128 | .to_owned(); 129 | 130 | // Optional: read head 131 | if let Some((rel, head)) = read_dependency(&mut iter) { 132 | let head_index = graph.add_node(head); 133 | graph.add_edge(head_index, index, rel); 134 | } 135 | 136 | // Optional: read dependents 137 | while let Some((rel, dep)) = read_dependency(&mut iter) { 138 | let dep_index = graph.add_node(dep); 139 | graph.add_edge(index, dep_index, rel); 140 | } 141 | 142 | let test_case = TestCase { 143 | graph: TestCaseGraph(graph), 144 | index: index.index(), 145 | correct, 146 | }; 147 | 148 | test_cases.push(test_case); 149 | } 150 | 151 | test_cases 152 | } 153 | 154 | pub fn run_test_cases(filename: P, transform: T) 155 | where 156 | P: AsRef, 157 | T: Transform, 
158 | { 159 | let f = File::open(filename).unwrap(); 160 | let test_cases = read_test_cases(BufReader::new(f)); 161 | 162 | for test_case in test_cases { 163 | assert_eq!( 164 | test_case.correct, 165 | transform.transform(&test_case.graph, test_case.index) 166 | ) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /syntaxdot-transformers/src/models/albert/embeddings.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | 3 | use syntaxdot_tch_ext::PathExt; 4 | use tch::nn::{Linear, Module}; 5 | use tch::Tensor; 6 | 7 | use crate::models::albert::AlbertConfig; 8 | use crate::models::bert::{bert_linear, BertConfig, BertEmbeddings}; 9 | use crate::module::FallibleModuleT; 10 | use crate::TransformerError; 11 | 12 | /// ALBERT embeddings. 13 | /// 14 | /// These embeddings are the same as BERT embeddings. However, we do 15 | /// some wrapping to ensure that the right embedding dimensionality is 16 | /// used. 17 | #[derive(Debug)] 18 | pub struct AlbertEmbeddings { 19 | embeddings: BertEmbeddings, 20 | } 21 | 22 | impl AlbertEmbeddings { 23 | /// Construct new ALBERT embeddings with the given variable store 24 | /// and ALBERT configuration. 25 | pub fn new<'a>( 26 | vs: impl Borrow>, 27 | config: &AlbertConfig, 28 | ) -> Result { 29 | let vs = vs.borrow(); 30 | 31 | // BERT uses the hidden size as the vocab size. 32 | let mut bert_config: BertConfig = config.into(); 33 | bert_config.hidden_size = config.embedding_size; 34 | 35 | let embeddings = BertEmbeddings::new(vs, &bert_config)?; 36 | 37 | Ok(AlbertEmbeddings { embeddings }) 38 | } 39 | 40 | pub fn forward( 41 | &self, 42 | input_ids: &Tensor, 43 | token_type_ids: Option<&Tensor>, 44 | position_ids: Option<&Tensor>, 45 | train: bool, 46 | ) -> Result { 47 | self.embeddings 48 | .forward(input_ids, token_type_ids, position_ids, train) 49 | } 50 | } 51 | 52 | impl FallibleModuleT for AlbertEmbeddings { 53 | type Error = TransformerError; 54 | 55 | fn forward_t(&self, input: &Tensor, train: bool) -> Result { 56 | self.forward(input, None, None, train) 57 | } 58 | } 59 | 60 | /// Projection of ALBERT embeddings into the encoder hidden size. 
61 | #[derive(Debug)] 62 | pub struct AlbertEmbeddingProjection { 63 | projection: Linear, 64 | } 65 | 66 | impl AlbertEmbeddingProjection { 67 | pub fn new<'a>( 68 | vs: impl Borrow<PathExt<'a>>, 69 | config: &AlbertConfig, 70 | ) -> Result<Self, TransformerError> { 71 | let vs = vs.borrow(); 72 | 73 | let projection = bert_linear( 74 | vs / "embedding_projection", 75 | &config.into(), 76 | config.embedding_size, 77 | config.hidden_size, 78 | "weight", 79 | "bias", 80 | )?; 81 | 82 | Ok(AlbertEmbeddingProjection { projection }) 83 | } 84 | } 85 | 86 | impl Module for AlbertEmbeddingProjection { 87 | fn forward(&self, input: &Tensor) -> Tensor { 88 | self.projection.forward(input) 89 | } 90 | } 91 | 92 | #[cfg(feature = "model-tests")] 93 | #[cfg(test)] 94 | mod tests { 95 | use std::collections::BTreeSet; 96 | 97 | use maplit::btreeset; 98 | use syntaxdot_tch_ext::RootExt; 99 | use tch::nn::VarStore; 100 | use tch::Device; 101 | 102 | use crate::activations::Activation; 103 | use crate::models::albert::{AlbertConfig, AlbertEmbeddings}; 104 | 105 | fn albert_config() -> AlbertConfig { 106 | AlbertConfig { 107 | attention_probs_dropout_prob: 0., 108 | embedding_size: 128, 109 | hidden_act: Activation::GeluNew, 110 | hidden_dropout_prob: 0., 111 | hidden_size: 768, 112 | initializer_range: 0.02, 113 | inner_group_num: 1, 114 | intermediate_size: 3072, 115 | max_position_embeddings: 512, 116 | num_attention_heads: 12, 117 | num_hidden_groups: 1, 118 | num_hidden_layers: 12, 119 | type_vocab_size: 2, 120 | vocab_size: 30000, 121 | } 122 | } 123 | 124 | fn varstore_variables(vs: &VarStore) -> BTreeSet<String> { 125 | vs.variables() 126 | .into_iter() 127 | .map(|(k, _)| k) 128 | .collect::<BTreeSet<_>>() 129 | } 130 | 131 | #[test] 132 | fn albert_embeddings_names() { 133 | let config = albert_config(); 134 | 135 | let vs = VarStore::new(Device::Cpu); 136 | let root = vs.root_ext(|_| 0); 137 | 138 | let _embeddings = AlbertEmbeddings::new(root, &config); 139 | 140 | let variables = varstore_variables(&vs); 141 | 142 | assert_eq!( 143 | variables, 144 | btreeset![ 145 | "layer_norm.bias".to_string(), 146 | "layer_norm.weight".to_string(), 147 | "position_embeddings.embeddings".to_string(), 148 | "token_type_embeddings.embeddings".to_string(), 149 | "word_embeddings.embeddings".to_string() 150 | ] 151 | ); 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /syntaxdot-summary/src/summary_writer.rs: -------------------------------------------------------------------------------- 1 | use std::fs::{create_dir_all, File}; 2 | use std::io::{self, BufWriter, ErrorKind, Write}; 3 | use std::path::PathBuf; 4 | 5 | use crate::event_writer::event::What; 6 | use crate::event_writer::summary::value::Value::SimpleValue; 7 | use crate::event_writer::summary::Value; 8 | use crate::event_writer::{EventWriter, Summary}; 9 | use std::time::{SystemTime, UNIX_EPOCH}; 10 | 11 | /// TensorBoard summary writer. 12 | pub struct SummaryWriter<W> { 13 | writer: EventWriter<W>, 14 | } 15 | 16 | impl SummaryWriter<BufWriter<File>> { 17 | /// Construct a writer from a path prefix. 18 | /// 19 | /// For instance, a path such as `tensorboard/bert/opt` will create 20 | /// the directory `tensorboard/bert` if it does not exist. Within that 21 | /// directory, it will write to the file 22 | /// `opt.out.tfevents.<timestamp>.<hostname>`.
23 | pub fn from_prefix(path: impl Into) -> io::Result { 24 | let path = path.into(); 25 | 26 | if path.components().count() == 0 { 27 | return Err(io::Error::new( 28 | ErrorKind::NotFound, 29 | "summary prefix must not be empty".to_string(), 30 | )); 31 | } 32 | 33 | if let Some(dir) = path.parent() { 34 | create_dir_all(dir)?; 35 | } 36 | 37 | let timestamp = SystemTime::now() 38 | .duration_since(UNIX_EPOCH) 39 | .unwrap() 40 | .as_micros(); 41 | let hostname = hostname::get()?; 42 | 43 | let mut path_string = path.into_os_string(); 44 | path_string.push(format!(".out.tfevents.{}.", timestamp)); 45 | path_string.push(hostname); 46 | 47 | SummaryWriter::new(BufWriter::new(File::create(path_string)?)) 48 | } 49 | } 50 | 51 | impl SummaryWriter 52 | where 53 | W: Write, 54 | { 55 | /// Construct a writer from a `Write` type. 56 | pub fn new(write: W) -> io::Result { 57 | let writer = EventWriter::new(write)?; 58 | Ok(SummaryWriter { writer }) 59 | } 60 | 61 | /// Create a writer that uses the given wall time in the version record. 62 | /// 63 | /// This constructor is provided for unit tests. 64 | #[allow(dead_code)] 65 | fn new_with_wall_time(write: W, wall_time: f64) -> io::Result { 66 | let writer = EventWriter::new_with_wall_time(write, wall_time)?; 67 | Ok(SummaryWriter { writer }) 68 | } 69 | 70 | /// Write a scalar. 71 | pub fn write_scalar( 72 | &mut self, 73 | tag: impl Into, 74 | step: i64, 75 | scalar: f32, 76 | ) -> std::io::Result<()> { 77 | self.writer.write_event( 78 | step, 79 | What::Summary(Summary { 80 | value: vec![Value { 81 | node_name: "".to_string(), 82 | tag: tag.into(), 83 | value: Some(SimpleValue(scalar)), 84 | }], 85 | }), 86 | ) 87 | } 88 | 89 | /// Write a scalar with the given wall time. 90 | /// 91 | /// This method is provided for unit tests. 92 | #[allow(dead_code)] 93 | fn write_scalar_with_wall_time( 94 | &mut self, 95 | wall_time: f64, 96 | tag: impl Into, 97 | step: i64, 98 | scalar: f32, 99 | ) -> std::io::Result<()> { 100 | self.writer.write_event_with_wall_time( 101 | wall_time, 102 | step, 103 | What::Summary(Summary { 104 | value: vec![Value { 105 | node_name: "".to_string(), 106 | tag: tag.into(), 107 | value: Some(SimpleValue(scalar)), 108 | }], 109 | }), 110 | ) 111 | } 112 | } 113 | 114 | #[cfg(test)] 115 | mod tests { 116 | use crate::SummaryWriter; 117 | 118 | static CHECK_OUTPUT: [u8; 126] = [ 119 | 24, 0, 0, 0, 0, 0, 0, 0, 163, 127, 75, 34, 9, 0, 0, 128, 54, 111, 246, 215, 65, 26, 13, 98, 120 | 114, 97, 105, 110, 46, 69, 118, 101, 110, 116, 58, 50, 136, 162, 101, 134, 27, 0, 0, 0, 0, 121 | 0, 0, 0, 26, 13, 158, 19, 9, 188, 119, 164, 54, 111, 246, 215, 65, 16, 10, 42, 14, 10, 12, 122 | 10, 5, 104, 101, 108, 108, 111, 21, 0, 0, 40, 66, 93, 240, 111, 128, 27, 0, 0, 0, 0, 0, 0, 123 | 0, 26, 13, 158, 19, 9, 48, 127, 164, 54, 111, 246, 215, 65, 16, 20, 42, 14, 10, 12, 10, 5, 124 | 119, 111, 114, 108, 100, 21, 0, 0, 128, 63, 5, 210, 83, 151, 125 | ]; 126 | 127 | #[test] 128 | fn writes_the_same_output_as_tensorflow() { 129 | let mut data = vec![]; 130 | let mut writer = SummaryWriter::new_with_wall_time(&mut data, 1608105178.).unwrap(); 131 | writer 132 | .write_scalar_with_wall_time(1608105178.569808, "hello", 10, 42.) 133 | .unwrap(); 134 | writer 135 | .write_scalar_with_wall_time(1608105178.570263, "world", 20, 1.) 
136 | .unwrap(); 137 | 138 | assert_eq!(data, CHECK_OUTPUT); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /syntaxdot-summary/build.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2011, The Snappy-Rust Authors. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions are 5 | // met: 6 | // 7 | // * Redistributions of source code must retain the above copyright 8 | // notice, this list of conditions and the following disclaimer. 9 | // * Redistributions in binary form must reproduce the above 10 | // copyright notice, this list of conditions and the following disclaimer 11 | // in the documentation and/or other materials provided with the 12 | // distribution. 13 | // * Neither the name of the copyright holder nor the names of its 14 | // contributors may be used to endorse or promote products derived from 15 | // this software without specific prior written permission. 16 | // 17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 29 | use std::env; 30 | use std::fs::File; 31 | use std::io::{self, Write}; 32 | use std::path::{Path, PathBuf}; 33 | 34 | const CASTAGNOLI_POLY: u32 = 0x82f63b78; 35 | 36 | type Result = std::result::Result>; 37 | 38 | fn main() { 39 | if let Err(err) = try_main() { 40 | panic!("{}", err); 41 | } 42 | } 43 | 44 | fn try_main() -> Result<()> { 45 | let out_dir = match env::var_os("OUT_DIR") { 46 | None => return Err(From::from("OUT_DIR environment variable not defined")), 47 | Some(out_dir) => PathBuf::from(out_dir), 48 | }; 49 | write_tag_lookup_table(&out_dir)?; 50 | write_crc_tables(&out_dir)?; 51 | Ok(()) 52 | } 53 | 54 | fn write_tag_lookup_table(out_dir: &Path) -> Result<()> { 55 | let out_path = out_dir.join("tag.rs"); 56 | let mut out = io::BufWriter::new(File::create(out_path)?); 57 | 58 | writeln!(out, "pub const TAG_LOOKUP_TABLE: [u16; 256] = [")?; 59 | for b in 0u8..=255 { 60 | writeln!(out, " {},", tag_entry(b))?; 61 | } 62 | writeln!(out, "];")?; 63 | Ok(()) 64 | } 65 | 66 | fn tag_entry(b: u8) -> u16 { 67 | let b = b as u16; 68 | match b & 0b00000011 { 69 | 0b00 => { 70 | let lit_len = (b >> 2) + 1; 71 | if lit_len <= 60 { 72 | lit_len 73 | } else { 74 | assert!(lit_len <= 64); 75 | (lit_len - 60) << 11 76 | } 77 | } 78 | 0b01 => { 79 | let len = 4 + ((b >> 2) & 0b111); 80 | let offset = (b >> 5) & 0b111; 81 | (1 << 11) | (offset << 8) | len 82 | } 83 | 0b10 => { 84 | let len = 1 + (b >> 2); 85 | (2 << 11) | len 86 | } 87 | 0b11 => { 88 | let len = 1 + (b >> 2); 89 | (4 << 11) | len 90 | } 91 | _ => unreachable!(), 92 | } 93 | } 94 | 95 | fn write_crc_tables(out_dir: &Path) -> Result<()> { 96 | let out_path = out_dir.join("crc32_table.rs"); 97 | let mut out = io::BufWriter::new(File::create(out_path)?); 98 | 99 | let table = make_table(CASTAGNOLI_POLY); 100 | let table16 = make_table16(CASTAGNOLI_POLY); 101 | 102 | writeln!(out, "pub const TABLE: [u32; 256] = [")?; 103 | for &x in table.iter() { 104 | writeln!(out, " {},", x)?; 105 | } 106 | writeln!(out, "];\n")?; 107 | 108 | writeln!(out, "pub const TABLE16: [[u32; 256]; 16] = [")?; 109 | for table in table16.iter() { 110 | writeln!(out, " [")?; 111 | for &x in table.iter() { 112 | writeln!(out, " {},", x)?; 113 | } 114 | writeln!(out, " ],")?; 115 | } 116 | writeln!(out, "];")?; 117 | 118 | out.flush()?; 119 | 120 | Ok(()) 121 | } 122 | 123 | fn make_table16(poly: u32) -> [[u32; 256]; 16] { 124 | let mut tab = [[0; 256]; 16]; 125 | tab[0] = make_table(poly); 126 | for i in 0..256 { 127 | let mut crc = tab[0][i]; 128 | for j in 1..16 { 129 | crc = (crc >> 8) ^ tab[0][crc as u8 as usize]; 130 | tab[j][i] = crc; 131 | } 132 | } 133 | tab 134 | } 135 | 136 | fn make_table(poly: u32) -> [u32; 256] { 137 | let mut tab = [0; 256]; 138 | for i in 0u32..256u32 { 139 | let mut crc = i; 140 | for _ in 0..8 { 141 | if crc & 1 == 1 { 142 | crc = (crc >> 1) ^ poly; 143 | } else { 144 | crc >>= 1; 145 | } 146 | } 147 | tab[i as usize] = crc; 148 | } 149 | tab 150 | } 151 | -------------------------------------------------------------------------------- /syntaxdot-encoders/testdata/lang/de/tdz/lemma/simplify-pis-lemma.test: -------------------------------------------------------------------------------- 1 | alle alle _ PIS alle 2 | anderen anderer _ PIS ander 3 | was etwas _ PIS etwas 4 | nichts nichts _ PIS nichts 5 | etwas etwas _ PIS etwas 6 | allem alles _ PIS alle 7 | alles alles _ PIS alle 8 | mehr mehr _ PIS mehr 9 | wenig wenig _ PIS wenig 10 | man man _ PIS man 11 | bißchen bißchen _ PIS bißchen 12 | einer einer _ PIS ein 
13 | nix nichts _ PIS nichts 14 | meisten meisten _ PIS meist 15 | viele viele _ PIS viel 16 | jeder jeder _ PIS jed 17 | beide beide _ PIS beid 18 | beides beides _ PIS beid 19 | eine eine _ PIS ein 20 | ebensoviel ebensoviel _ PIS ebensoviel 21 | niemanden niemand _ PIS niemand 22 | anderes anderes _ PIS ander 23 | einen einer _ PIS ein 24 | beiden beide _ PIS beid 25 | jemand jemand _ PIS jemand 26 | letztere letzterer _ PIS letzter 27 | solchen solcher _ PIS solch 28 | andere anderer _ PIS ander 29 | keine keine _ PIS kein 30 | einem einer _ PIS ein 31 | eines eines _ PIS ein 32 | viel viel _ PIS viel 33 | einiges einiges _ PIS einig 34 | jeden jeder _ PIS jed 35 | niemand niemand _ PIS niemand 36 | anderem anderes _ PIS ander 37 | vielen viele _ PIS viel 38 | genug genug _ PIS genug 39 | irgendwer irgendwer _ PIS irgendwer 40 | keiner keiner _ PIS kein 41 | letzterem letzterer _ PIS letzter 42 | wenige wenige _ PIS wenig 43 | einige einige _ PIS einig 44 | ihresgleichen ihresgleichen _ PIS ihresgleichen 45 | allen alle _ PIS alle 46 | vieles vieles _ PIS viel 47 | sowas sowas _ PIS sowas 48 | manches manches _ PIS manch 49 | seinesgleichen seinesgleichen _ PIS seinesgleichen 50 | mancher mancher _ PIS manch 51 | meiste meistes _ PIS meist 52 | manche mancher _ PIS manch 53 | eins eines _ PIS ein 54 | a. a. _ PIS a. 55 | einiger einige _ PIS einig 56 | soviel soviel _ PIS soviel 57 | niemandem niemand _ PIS niemand 58 | keinem keiner _ PIS kein 59 | irgendwen irgendwer _ PIS irgendwer 60 | irgendwas irgendetwas _ PIS irgendetwas 61 | allzuviel allzuviel _ PIS allzuviel 62 | jedes jedes _ PIS jed 63 | solche solcher _ PIS solch 64 | weniger weniger _ PIS wenig 65 | jedem jeder _ PIS jed 66 | keinen keiner _ PIS kein 67 | keines keines _ PIS kein 68 | jemanden jemand _ PIS jemand 69 | frau frau _ PIS frau 70 | solches solches _ PIS solch 71 | derlei derlei _ PIS derlei 72 | mehrere mehrere _ PIS mehrere 73 | aller alle _ PIS alle 74 | paar paar _ PIS paar 75 | jemandem jemand _ PIS jemand 76 | wenigsten wenigst _ PIS wenigst 77 | alledem alledem _ PIS alledem 78 | genausoviel genausoviel _ PIS genausoviel 79 | jede(r) jeder _ PIS jed 80 | anderer anderer _ PIS ander 81 | manchem mancher _ PIS manch 82 | wenigen wenige _ PIS wenig 83 | ersteres ersteres _ PIS erster 84 | einzige einziger _ PIS einzig 85 | irgendwem irgendwer _ PIS irgendwer 86 | allerhand allerhand _ PIS allerhand 87 | solcher solcher _ PIS solch 88 | letzteres letzteres _ PIS letzter 89 | andern anderes _ PIS ander 90 | zuwenig zuwenig _ PIS zuwenig 91 | ersterer ersterer _ PIS erster 92 | letzteren letzterer _ PIS letzter 93 | zuviel zuviel _ PIS zuviel 94 | irgendjemand irgendjemand _ PIS irgendjemand 95 | vieler viele _ PIS viel 96 | anders anderes _ PIS ander 97 | unsereins unsereins _ PIS unsereins 98 | zweierlei zweierlei _ PIS zweierlei 99 | einigen einige _ PIS einig 100 | manchen mancher _ PIS manch 101 | jede jede _ PIS jed 102 | einziger einziger _ PIS einzig 103 | jedermann jedermann _ PIS jederman 104 | irgendetwas irgendetwas _ PIS irgendetwas 105 | mensch mensch _ PIS mensch 106 | etliche etliche _ PIS etlich 107 | allermeisten allermeisten _ PIS allermeisten 108 | irgendjemanden irgendjemand _ PIS irgendjemand 109 | allerlei allerlei _ PIS allerlei 110 | -frau jedefrau _ PIS jedefrau 111 | mann mann _ PIS mann 112 | nicht nichts _ PIS nichts 113 | letzterer letzterer _ PIS letzter 114 | unsereinem unsereiner _ PIS unsereiner 115 | etlichen etliche _ PIS etlich 116 | andre anderer _ PIS ander 117 | 
genausowenig genausowenig _ PIS genausowenig 118 | beidem beides _ PIS beid 119 | niemand's niemand _ PIS niemand 120 | an man _ PIS man 121 | vielem vieles _ PIS viel 122 | jederman jedermann _ PIS jederman 123 | bisserl bißchen_ _ PIS bißchen_ 124 | ma man_ _ PIS man_ 125 | irgendeine irgendeine _ PIS irgendein 126 | üllus alles_ _ PIS alles_ 127 | keins keines _ PIS kein 128 | ein eines _ PIS ein 129 | man(n) man _ PIS man 130 | erstere ersterer _ PIS erster 131 | einzigen einziger _ PIS einzig 132 | anderm anderes _ PIS ander 133 | irgendeinen irgendeiner _ PIS irgendein 134 | irgendeiner irgendeiner _ PIS irgendein 135 | ebensolchen ebensolcher _ PIS ebensolcher 136 | irgendeins irgendeines _ PIS irgendein 137 | dergleichen dergleichen _ PIS dergleichen 138 | unseresgleichen unseresgleichen _ PIS unseresgleichen 139 | wenigstens wenigst _ PIS wenigst 140 | spottwenig spottwenig _ PIS spottwenig 141 | wat etwas_ _ PIS etwas_ 142 | ersteren ersterer _ PIS erster 143 | mehreren mehrere _ PIS mehrere 144 | etliches etliches _ PIS etlich 145 | sonstwas sonstwas _ PIS sonstwas 146 | einzelnen einzelner _ PIS einzeln 147 | nüscht nichts_ _ PIS nichts_ 148 | sonstwen sonstwer _ PIS sonstwer 149 | jede/r jeder _ PIS jed 150 | bissl bißchen_ _ PIS bißchen_ 151 | koaner keiner_ _ PIS keiner_ 152 | ois alles_ _ PIS alles_ 153 | einzelne einzelner _ PIS einzeln 154 | ebensowenig ebensowenig _ PIS ebensowenig 155 | letzeres letzeres _ PIS letzeres 156 | zuviele zuviele _ PIS zuviel 157 | kein keiner _ PIS kein 158 | zig zig _ PIS zig 159 | soviele soviele _ PIS soviel 160 | allis alles _ PIS alles 161 | beider beide _ PIS beid 162 | wos etwas_ _ PIS etwas_ 163 | dreierlei dreierlei _ PIS dreierlei 164 | -------------------------------------------------------------------------------- /syntaxdot-encoders/src/lang/de/tdz/lemma/transform/named_entity.rs: -------------------------------------------------------------------------------- 1 | use std::iter; 2 | use std::iter::FromIterator; 3 | 4 | use caseless::Caseless; 5 | use seqalign::op::{archetype, Operation}; 6 | use seqalign::{Align, Measure, SeqPair}; 7 | use unicode_normalization::UnicodeNormalization; 8 | 9 | /// Levenshtein distance with case a case-insensitive match operation. 10 | #[derive(Clone, Debug)] 11 | struct CaseInsensitiveLevenshtein { 12 | ops: [CaseInsensitiveLevenshteinOp; 4], 13 | } 14 | 15 | impl CaseInsensitiveLevenshtein { 16 | /// Construct a Levenshtein measure with the associated insertion, deletion, 17 | /// and substitution cost. 18 | pub fn new(insert_cost: usize, delete_cost: usize, substitute_cost: usize) -> Self { 19 | use self::CaseInsensitiveLevenshteinOp::*; 20 | 21 | CaseInsensitiveLevenshtein { 22 | ops: [ 23 | Insert(insert_cost), 24 | Delete(delete_cost), 25 | Match, 26 | Substitute(substitute_cost), 27 | ], 28 | } 29 | } 30 | } 31 | 32 | impl Measure for CaseInsensitiveLevenshtein { 33 | type Operation = CaseInsensitiveLevenshteinOp; 34 | 35 | fn operations(&self) -> &[Self::Operation] { 36 | &self.ops 37 | } 38 | } 39 | 40 | /// Case-insensitive Levenshtein operation with associated cost. 
41 | #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] 42 | enum CaseInsensitiveLevenshteinOp { 43 | Insert(usize), 44 | Delete(usize), 45 | Match, 46 | Substitute(usize), 47 | } 48 | 49 | impl Operation<char> for CaseInsensitiveLevenshteinOp { 50 | fn backtrack( 51 | &self, 52 | seq_pair: &SeqPair<char>, 53 | source_idx: usize, 54 | target_idx: usize, 55 | ) -> Option<(usize, usize)> { 56 | use self::CaseInsensitiveLevenshteinOp::*; 57 | 58 | match *self { 59 | Delete(cost) => archetype::Delete(cost).backtrack(seq_pair, source_idx, target_idx), 60 | Insert(cost) => archetype::Insert(cost).backtrack(seq_pair, source_idx, target_idx), 61 | Match => archetype::Match.backtrack(seq_pair, source_idx, target_idx), 62 | Substitute(cost) => { 63 | archetype::Substitute(cost).backtrack(seq_pair, source_idx, target_idx) 64 | } 65 | } 66 | } 67 | 68 | fn cost( 69 | &self, 70 | seq_pair: &SeqPair<char>, 71 | cost_matrix: &[Vec<usize>], 72 | source_idx: usize, 73 | target_idx: usize, 74 | ) -> Option<usize> { 75 | use self::CaseInsensitiveLevenshteinOp::*; 76 | 77 | let (from_source_idx, from_target_idx) = 78 | self.backtrack(seq_pair, source_idx, target_idx)?; 79 | let orig_cost = cost_matrix[from_source_idx][from_target_idx]; 80 | 81 | match *self { 82 | Delete(cost) => { 83 | archetype::Delete(cost).cost(seq_pair, cost_matrix, source_idx, target_idx) 84 | } 85 | Insert(cost) => { 86 | archetype::Insert(cost).cost(seq_pair, cost_matrix, source_idx, target_idx) 87 | } 88 | Match => { 89 | if iter::once(seq_pair.source[from_source_idx]) 90 | .default_caseless_match(iter::once(seq_pair.target[from_target_idx])) 91 | { 92 | Some(orig_cost) 93 | } else { 94 | None 95 | } 96 | } 97 | Substitute(cost) => { 98 | archetype::Substitute(cost).cost(seq_pair, cost_matrix, source_idx, target_idx) 99 | } 100 | } 101 | } 102 | } 103 | 104 | /// This function restores uppercase characters in lowercased lemmas from 105 | /// the corresponding forms. This task is actually more complex than it 106 | /// may seem initially due to the properties of Unicode. In particular: 107 | /// 108 | /// * Many characters have precomposed code points in Unicode, but can also be formed 109 | /// using combining code points (e.g. characters with diacritics such as 110 | /// ë). This function applies Normalization Form C to ensure the equivalent 111 | /// representation of characters in the two strings. 112 | /// * Uppercasing or lowercasing a character that is a single code point may 113 | /// result in multiple codepoints. In particular, 'ẞ' (uppercased sz) can be 114 | /// lowercased to 'ß' (simple case folding) or 'ss' (full case folding). 115 | /// This is partially handled --- individual codepoints are compared using 116 | /// Unicode caseless matching. However, if a character is 1 codepoint in 117 | /// the form and >1 codepoint in the lemma (e.g. ẞ vs. ss) or vice versa, 118 | /// the characters will not be matched. 119 | pub(crate) fn restore_named_entity_case<S1, S2>(form: S1, lemma: S2) -> String 120 | where 121 | S1: AsRef<str>, 122 | S2: AsRef<str>, 123 | { 124 | // Get code points after NFC normalization. 125 | let form_chars: Vec<char> = form.as_ref().nfc().collect(); 126 | let mut lemma_chars: Vec<char> = lemma.as_ref().nfc().collect(); 127 | 128 | // Align the strings using case-insensitive Levenshtein distance. 129 | let levenshtein = CaseInsensitiveLevenshtein::new(1, 1, 1); 130 | let script = levenshtein.align(&form_chars, &lemma_chars).edit_script(); 131 | 132 | // Copy over aligned characters from the form to the lemma.
133 | for op in script { 134 | if let CaseInsensitiveLevenshteinOp::Match = op.operation() { 135 | lemma_chars[op.target_idx()] = form_chars[op.source_idx()]; 136 | } 137 | } 138 | 139 | String::from_iter(lemma_chars) 140 | } 141 | -------------------------------------------------------------------------------- /syntaxdot/src/optimizers/grad_scale.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryFrom; 2 | 3 | use tch::{Kind, Tensor}; 4 | 5 | use super::{Optimizer, ZeroGrad}; 6 | use crate::error::SyntaxDotError; 7 | 8 | /// Gradient scaler 9 | /// 10 | /// This data type implements gradient scaling. 11 | /// 12 | /// In mixed-precision training, gradients underflow more quickly in FP16 13 | /// as they become smaller, stopping backpropagation. Gradient scaling 14 | /// counters this by scaling up the loss, to increase the magnitude of 15 | /// gradients. The gradients are then unscaled in FP32. Since loss 16 | /// scaling can also lead to overflow of gradients, the gradients are 17 | /// checked for infinites before performing an optimizer step. If one or 18 | /// more infinite gradients are found, the optimizer step is skipped and 19 | /// the scale is reduced for the next step. 20 | /// 21 | /// `GradScaler` wraps an optimizer and implements the `Optimizer` 22 | /// trait, so that it can be used in the same contexts as an optimizer 23 | /// can be used. 24 | pub struct GradScaler<O> { 25 | enabled: bool, 26 | growth_factor: f64, 27 | backoff_factor: f64, 28 | growth_interval: i64, 29 | 30 | optimizer: O, 31 | 32 | found_inf: Tensor, 33 | growth_tracker: Tensor, 34 | scale: Tensor, 35 | } 36 | 37 | impl<O> GradScaler<O> 38 | where 39 | O: Optimizer, 40 | { 41 | fn new( 42 | enabled: bool, 43 | optimizer: O, 44 | init_scale: f64, 45 | growth_factor: f64, 46 | backoff_factor: f64, 47 | growth_interval: i64, 48 | ) -> Result<Self, SyntaxDotError> { 49 | let device = match optimizer.trainable_variables().first() { 50 | Some(tensor) => tensor.device(), 51 | None => return Err(SyntaxDotError::NoTrainableVariables), 52 | }; 53 | 54 | Ok(GradScaler { 55 | enabled, 56 | growth_factor, 57 | backoff_factor, 58 | growth_interval, 59 | 60 | optimizer, 61 | 62 | found_inf: Tensor::full([1], 0.0, (Kind::Float, device)), 63 | growth_tracker: Tensor::full([1], 0, (Kind::Int, device)), 64 | scale: Tensor::full([1], init_scale, (Kind::Float, device)), 65 | }) 66 | } 67 | 68 | /// Construct a new gradient scaler. 69 | /// 70 | /// The gradient scaler wraps the given optimizer. 71 | pub fn new_with_defaults(enabled: bool, optimizer: O) -> Result<Self, SyntaxDotError> { 72 | GradScaler::new(enabled, optimizer, 2f64.powi(16), 2., 0.5, 2000) 73 | } 74 | 75 | /// Get the current scale. 76 | pub fn current_scale(&self) -> f32 { 77 | Vec::<f32>::try_from(&self.scale).expect("Tensor cannot be converted to Vec")[0] 78 | } 79 | 80 | /// Get a reference to the wrapped optimizer. 81 | pub fn optimizer(&self) -> &O { 82 | &self.optimizer 83 | } 84 | 85 | /// Get a mutable reference to the wrapped optimizer. 86 | pub fn optimizer_mut(&mut self) -> &mut O { 87 | &mut self.optimizer 88 | } 89 | 90 | /// Scale the given tensor. 91 | fn scale(&mut self, t: &Tensor) -> Result<Tensor, SyntaxDotError> { 92 | Ok(if !self.enabled { 93 | t.shallow_clone() 94 | } else { 95 | t.f_mul(&self.scale)? 96 | }) 97 | } 98 | 99 | /// Update the scale for the next step. 100 | fn update(&mut self) { 101 | if !self.enabled { 102 | return; 103 | }; 104 | 105 | self.scale = self.scale.internal_amp_update_scale_( 106 | &self.growth_tracker, 107 | &self.found_inf, 108 | self.growth_factor, 109 | self.backoff_factor, 110 | self.growth_interval, 111 | ); 112 | 113 | // Clear infinity found status. 114 | self.found_inf = self.found_inf.zeros_like(); 115 | } 116 | } 117 | 118 | impl<O> Optimizer for GradScaler<O> 119 | where 120 | O: Optimizer, 121 | { 122 | fn backward_step(&mut self, loss: &Tensor) -> Result<(), SyntaxDotError> { 123 | self.optimizer.trainable_variables().zero_grad(); 124 | self.scale(loss)?.f_backward()?; 125 | tch::no_grad(|| self.step()); 126 | self.update(); 127 | Ok(()) 128 | } 129 | 130 | fn set_lr_group(&mut self, group: usize, learning_rate: f64) { 131 | self.optimizer.set_lr_group(group, learning_rate) 132 | } 133 | 134 | fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) { 135 | self.optimizer.set_weight_decay_group(group, weight_decay) 136 | } 137 | 138 | fn step(&mut self) { 139 | if !self.enabled { 140 | return self.optimizer.step(); 141 | } 142 | 143 | let inv_scale = self.scale.reciprocal().to_kind(Kind::Float); 144 | 145 | for tensor in &mut self.optimizer.trainable_variables() { 146 | if !tensor.grad().defined() { 147 | continue; 148 | } 149 | 150 | tensor 151 | .grad() 152 | .internal_amp_non_finite_check_and_unscale(&mut self.found_inf, &inv_scale); 153 | } 154 | 155 | let found_inf = (f32::try_from(&self.found_inf) 156 | .expect("Cannot convert boolean for infinity detection to f32") 157 | - 1.0) 158 | .abs() 159 | < f32::EPSILON; 160 | 161 | // Only step when there are no infinite gradients. 162 | if !found_inf { 163 | self.optimizer.step() 164 | } 165 | } 166 | 167 | fn trainable_variables(&self) -> Vec<Tensor> { 168 | self.optimizer.trainable_variables() 169 | } 170 | } 171 | --------------------------------------------------------------------------------
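
Usage note (not part of the repository sources): the two sentencepiece-based tokenizers above expose the same `Tokenize` interface, so a minimal sketch may help to see how they are driven. The snippet assumes that `XlmRobertaTokenizer` and `Tokenize` are re-exported from the `syntaxdot_tokenizers` crate root (as the in-crate test imports suggest), that the `SentenceWithPieces` fields are public, and that a sentencepiece model exists at the hypothetical path `xlm-roberta-base.spm`.

use std::iter::FromIterator;

use sentencepiece::SentencePieceProcessor;
use syntaxdot_tokenizers::{Tokenize, XlmRobertaTokenizer};
use udgraph::graph::Sentence;
use udgraph::token::Token;

fn subword_pieces(forms: &[&str]) -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical model path; any XLM-R sentencepiece model file would do.
    let spp = SentencePieceProcessor::open("xlm-roberta-base.spm")?;
    let tokenizer = XlmRobertaTokenizer::from(spp);

    // Tokens stay intact; each token is mapped to one or more sentencepiece pieces.
    let sentence = Sentence::from_iter(forms.iter().map(|&f| Token::new(f)));
    let with_pieces = tokenizer.tokenize(sentence);

    // `token_offsets[i]` is the index of the first piece of token `i`, which is
    // what sequence labeling needs to map piece-level output back to tokens.
    println!("pieces: {:?}", with_pieces.pieces);
    println!("token offsets: {:?}", with_pieces.token_offsets);

    Ok(())
}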
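
Similarly, `GradScaler` and `SummaryWriter` are meant to slot into a training loop: the scaler wraps any `Optimizer`, and the summary writer can log the current loss scale to TensorBoard. The sketch below is illustrative only; the re-export paths (`syntaxdot::optimizers`, `syntaxdot_summary`) are assumptions based on the module layout, the concrete optimizer is left generic, and the boxed error type is used purely to combine `SyntaxDotError` and `std::io::Error` for brevity.

use syntaxdot::optimizers::{GradScaler, Optimizer};
use syntaxdot_summary::SummaryWriter;
use tch::Tensor;

fn scaled_step<O>(optimizer: O, loss: &Tensor, step: i64) -> Result<(), Box<dyn std::error::Error>>
where
    O: Optimizer,
{
    // Wrap the optimizer; with `enabled = false` scaling becomes a no-op,
    // so the same code path also works for FP32 training.
    let mut scaler = GradScaler::new_with_defaults(true, optimizer)?;

    // Scales the loss, backpropagates, unscales the gradients, and skips
    // the optimizer step if any gradient became non-finite.
    scaler.backward_step(loss)?;

    // Log the current loss scale as a TensorBoard scalar; `from_prefix`
    // creates the parent directory and the tfevents file as described above.
    let mut summary = SummaryWriter::from_prefix("tensorboard/example/opt")?;
    summary.write_scalar("gradient_scale", step, scaler.current_scale())?;

    Ok(())
}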