├── example ├── ner │ └── main.go ├── qa │ └── main.go ├── modeling │ └── main.go └── bert │ └── main.go ├── pipeline.go ├── trainer.go ├── util ├── optimization.go ├── dropout.go ├── tensor.go ├── init.go ├── linear.go ├── activation.go └── file-util.go ├── pretrained ├── tokenizer.go ├── config.go ├── model.go ├── bert.go └── roberta.go ├── .gitignore ├── pipeline ├── token-classification.go ├── ner.go └── common.go ├── go.mod ├── config-example_test.go ├── bert ├── opt.go ├── tokenizer_test.go ├── config_test.go ├── tokenizer.go ├── embedding.go ├── config.go ├── example_test.go ├── encoder.go ├── attention.go ├── model_test.go └── model.go ├── CHANGELOG.md ├── .travis.yml ├── tokenizer.go ├── modeling.go ├── config.go ├── config_test.go ├── roberta ├── tokenizer.go ├── embedding.go ├── model_test.go └── model.go ├── modeling_test.go ├── go.sum ├── coverage.out ├── README.md └── LICENSE /example/ner/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | -------------------------------------------------------------------------------- /example/qa/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | -------------------------------------------------------------------------------- /pipeline.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | -------------------------------------------------------------------------------- /trainer.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | -------------------------------------------------------------------------------- /util/optimization.go: -------------------------------------------------------------------------------- 1 | package util 2 | -------------------------------------------------------------------------------- /pretrained/tokenizer.go: -------------------------------------------------------------------------------- 1 | package pretrained 2 | 3 | type Tokenizer interface { 4 | Load(modelNamOrPath string, params map[string]interface{}) error 5 | } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .directory 2 | *.swp 3 | *.swo 4 | 5 | input/ 6 | example/gorgonia/testdata/ 7 | example/testdata/ 8 | *.dat 9 | *.log 10 | 11 | *.txt 12 | *.json 13 | *.bak 14 | 15 | data/ 16 | -------------------------------------------------------------------------------- /example/modeling/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | // "github.com/sugarme/tokenizer" 5 | ) 6 | 7 | type Trainer struct { 8 | Dataset interface{} 9 | Tokenizer interface{} 10 | } 11 | -------------------------------------------------------------------------------- /pretrained/config.go: -------------------------------------------------------------------------------- 1 | package pretrained 2 | 3 | // Config is an interface for pretrained model configuration. 4 | // It has only one method `Load(string) error` to load configuration 5 | // from local or remote file. 
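// A minimal usage sketch (the file path is illustrative; bert.BertConfig is one
// concrete implementation of this interface):
//
//	var config bert.BertConfig
//	err := config.Load("/path/to/config.json", nil)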
6 | type Config interface { 7 | Load(modelNamOrPath string, params map[string]interface{}) error 8 | } 9 | -------------------------------------------------------------------------------- /pipeline/token-classification.go: -------------------------------------------------------------------------------- 1 | package pipeline 2 | 3 | // Token classification pipeline (Named Entity Recognition, Part-of-Speech tagging). 4 | // More generic token classification pipeline, works with multiple models (Bert, Roberta). 5 | 6 | import ( 7 | 8 | // "github.com/sugarme/gotch/nn" 9 | ) 10 | 11 | type TokenClassificationModel struct{} 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sugarme/transformer 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/sugarme/gotch v0.7.0 7 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect 8 | github.com/sugarme/tokenizer v0.1.17 9 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 // indirect 10 | golang.org/x/text v0.3.3 // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /pretrained/model.go: -------------------------------------------------------------------------------- 1 | package pretrained 2 | 3 | import ( 4 | "github.com/sugarme/gotch" 5 | ) 6 | 7 | // Model is an interface for pretrained model. 8 | // It has only one method `Load(string) error` to load model 9 | // from local or remote file. 10 | type Model interface { 11 | Load(modelNamOrPath string, config interface{ Config }, params map[string]interface{}, device gotch.Device) error 12 | } 13 | -------------------------------------------------------------------------------- /util/dropout.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/sugarme/gotch/ts" 5 | ) 6 | 7 | type Dropout struct { 8 | dropoutProb float64 9 | } 10 | 11 | func NewDropout(p float64) *Dropout { 12 | return &Dropout{ 13 | dropoutProb: p, 14 | } 15 | } 16 | 17 | func (d *Dropout) ForwardT(input *ts.Tensor, train bool) (retVal *ts.Tensor) { 18 | return ts.MustDropout(input, d.dropoutProb, train) 19 | } 20 | -------------------------------------------------------------------------------- /config-example_test.go: -------------------------------------------------------------------------------- 1 | package transformer_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/sugarme/transformer" 8 | "github.com/sugarme/transformer/bert" 9 | ) 10 | 11 | func ExampleLoadConfig() { 12 | modelNameOrPath := "bert-base-uncased" 13 | var config bert.BertConfig 14 | err := transformer.LoadConfig(&config, modelNameOrPath, nil) 15 | if err != nil { 16 | log.Fatal(err) 17 | } 18 | 19 | fmt.Println(config.VocabSize) 20 | 21 | // Output: 22 | // 30522 23 | } 24 | -------------------------------------------------------------------------------- /util/tensor.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/sugarme/gotch/ts" 7 | ) 8 | 9 | // Equal compares 2 tensors in terms of shape, and every element values. 10 | func Equal(tensorA, tensorB *ts.Tensor) bool { 11 | var equal int64 = 0 12 | // 1. Compare shape 13 | if reflect.DeepEqual(tensorA.MustSize(), tensorB.MustSize()) { 14 | // 2. 
Compare values 15 | equal = tensorA.MustEqTensor(tensorB, false).MustAll(false).Int64Values()[0] 16 | } 17 | if equal == 0 { 18 | return false 19 | } 20 | return true 21 | } 22 | -------------------------------------------------------------------------------- /bert/opt.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "github.com/sugarme/gotch/ts" 5 | ) 6 | 7 | // TensorOpt is a function type to create pointer to tensor. 8 | type TensorOpt func() *ts.Tensor 9 | 10 | func MaskTensorOpt(t *ts.Tensor) TensorOpt { 11 | return func() *ts.Tensor { 12 | return t 13 | } 14 | } 15 | 16 | func EncoderMaskTensorOpt(t *ts.Tensor) TensorOpt { 17 | return func() *ts.Tensor { 18 | return t 19 | } 20 | } 21 | 22 | func EncoderHiddenStateTensorOpt(t *ts.Tensor) TensorOpt { 23 | return func() *ts.Tensor { 24 | return t 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /bert/tokenizer_test.go: -------------------------------------------------------------------------------- 1 | package bert_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/sugarme/transformer/bert" 8 | ) 9 | 10 | func TestBertTokenizer(t *testing.T) { 11 | var tk *bert.Tokenizer = bert.NewTokenizer() 12 | err := tk.Load("bert-base-uncased", nil) 13 | if err != nil { 14 | t.Error(err) 15 | } 16 | 17 | gotVocabSize := tk.GetVocabSize(false) 18 | wantVocabSize := 30522 19 | 20 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 21 | t.Errorf("Want %v\n", wantVocabSize) 22 | t.Errorf("Got %v\n", gotVocabSize) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | 9 | ### Fixed 10 | - [#...]: Fix a bug with... 11 | 12 | ### Changed 13 | - [#...]: 14 | 15 | ### Added 16 | - [#...]: 17 | 18 | 19 | ## [0.1.2] 20 | 21 | ### Added 22 | - [#2]: Added Roberta model 23 | 24 | ### Changed 25 | - Updated comment document for Bert model. 
26 | 27 | ### Fixed 28 | - None 29 | 30 | -------------------------------------------------------------------------------- /util/init.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | ) 8 | 9 | var ( 10 | CachedDir string = "NOT_SETTING" 11 | transformerEnvKey string = "GO_TRANSFORMER" 12 | ) 13 | 14 | func init() { 15 | // default path: {$HOME}/.cache/transfomer 16 | homeDir := os.Getenv("HOME") 17 | CachedDir = fmt.Sprintf("%s/.cache/transformer", homeDir) 18 | 19 | initEnv() 20 | 21 | log.Printf("INFO: CachedDir=%q\n", CachedDir) 22 | } 23 | 24 | func initEnv() { 25 | val := os.Getenv(transformerEnvKey) 26 | if val != "" { 27 | CachedDir = val 28 | } 29 | 30 | if _, err := os.Stat(CachedDir); os.IsNotExist(err) { 31 | if err := os.MkdirAll(CachedDir, 0755); err != nil { 32 | log.Fatal(err) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.14.x 5 | 6 | env: 7 | - GO111MODULE=on 8 | 9 | branches: 10 | only: 11 | - master 12 | 13 | dist: bionic 14 | 15 | before_install: 16 | - sudo apt-get install clang-tools-9 17 | - wget -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.5.1+cpu.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.5.1%2Bcpu.zip 18 | - unzip /tmp/libtorch-cxx11-abi-shared-with-deps-1.5.1+cpu.zip -d /opt 19 | - export LIBTORCH=/opt/libtorch 20 | - export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH 21 | - printenv 22 | - ls 23 | - rm libtch/dummy_cuda_dependency.cpp 24 | - mv libtch/fake_cuda_dependency.cpp.cpu libtch/fake_cuda_dependency.cpp 25 | - rm libtch/lib.go 26 | - mv libtch/lib.go.cpu libtch/lib.go 27 | script: 28 | - go get -u ./... 29 | - go test -v github.com/sugarme/transformer 30 | - go test -v github.com/sugarme/transformer/bert 31 | -------------------------------------------------------------------------------- /tokenizer.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | 3 | import ( 4 | "github.com/sugarme/transformer/pretrained" 5 | ) 6 | 7 | // LoadTokenizer loads pretrained tokenizer from local or remote file. 8 | // 9 | // Parameters: 10 | // - `tk` pretrained.Tokenizer (any tokenizer model that implements pretrained `Tokenizer` interface) 11 | // - `modelNameOrPath` is a string of either 12 | // + Model name or 13 | // + File name or path or 14 | // + URL to remote file 15 | // If `modelNameOrPath` is resolved, function will cache data using `TransformerCache` 16 | // environment if existing, otherwise it will be cached in `$HOME/.cache/transformers/` directory. 17 | // If `modleNameOrPath` is valid URL, file will be downloaded and cached. 18 | // Finally, vocab data will be loaded to `tk`. 
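// A short usage sketch (model name illustrative; bert.Tokenizer implements
// the pretrained.Tokenizer interface):
//
//	tk := bert.NewTokenizer()
//	err := transformer.LoadTokenizer(tk, "bert-base-uncased", nil)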
19 | func LoadTokenizer(tk pretrained.Tokenizer, modelNameOrPath string, customParams map[string]interface{}) error { 20 | return tk.Load(modelNameOrPath, customParams) 21 | } 22 | -------------------------------------------------------------------------------- /util/linear.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/sugarme/gotch/nn" 5 | "github.com/sugarme/gotch/ts" 6 | ) 7 | 8 | type LinearNoBiasConfig struct { 9 | WsInit nn.Init // interface 10 | } 11 | 12 | func DefaultLinearNoBiasConfig() *LinearNoBiasConfig { 13 | 14 | init := nn.NewKaimingUniformInit() 15 | 16 | return &LinearNoBiasConfig{WsInit: init} 17 | } 18 | 19 | type LinearNoBias struct { 20 | Ws *ts.Tensor 21 | } 22 | 23 | func NewLinearNoBias(vs *nn.Path, inDim, outDim int64, config *LinearNoBiasConfig) (*LinearNoBias, error) { 24 | 25 | ws, err := vs.NewVar("weight", []int64{outDim, inDim}, config.WsInit) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | return &LinearNoBias{ 31 | Ws: ws, 32 | }, nil 33 | } 34 | 35 | // Forward implements Module interface for LinearNoBias 36 | func (lnb *LinearNoBias) Forward(xs *ts.Tensor) (retVal *ts.Tensor) { 37 | wsT := lnb.Ws.MustT(false) 38 | retVal = xs.MustMatmul(wsT, false) 39 | wsT.MustDrop() 40 | 41 | return retVal 42 | } 43 | -------------------------------------------------------------------------------- /modeling.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | 3 | import ( 4 | "github.com/sugarme/gotch" 5 | "github.com/sugarme/transformer/pretrained" 6 | ) 7 | 8 | // LoadConfig loads pretrained model data from local or remote file. 9 | // 10 | // Parameters: 11 | // - `model` pretrained Model (any model type that implements pretrained `Model` interface) 12 | // - `modelNameOrPath` is a string of either 13 | // + Model name or 14 | // + File name or path or 15 | // + URL to remote file 16 | // If `modelNameOrPath` is resolved, function will cache data using `TransformerCache` 17 | // environment if existing, otherwise it will be cached in `$HOME/.cache/transformers/` directory. 18 | // If `modleNameOrPath` is valid URL, file will be downloaded and cached. 19 | // Finally, model weights will be loaded to `varstore`. 20 | func LoadModel(model pretrained.Model, modelNameOrPath string, config pretrained.Config, customParams map[string]interface{}, device gotch.Device) error { 21 | return model.Load(modelNameOrPath, config, customParams, device) 22 | } 23 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package transformer 2 | 3 | import ( 4 | "github.com/sugarme/transformer/pretrained" 5 | "github.com/sugarme/transformer/util" 6 | ) 7 | 8 | // LoadConfig loads pretrained configuration data from local or remote file. 9 | // 10 | // Parameters: 11 | // - `config` pretrained.Config (any model config that implements pretrained `Config` interface) 12 | // - `modelNameOrPath` is a string of either 13 | // - Model name or 14 | // - File name or path or 15 | // - URL to remote file 16 | // 17 | // If `modelNameOrPath` is resolved, function will cache data using `TransformerCache` 18 | // environment if existing, otherwise it will be cached in `$HOME/.cache/transformers/` directory. 19 | // If `modleNameOrPath` is valid URL, file will be downloaded and cached. 
20 | // Finally, configuration data will be loaded to `config` parameter. 21 | func LoadConfig(config pretrained.Config, modelNameOrPath string, customParams map[string]interface{}) error { 22 | configFile, err := util.CachedPath(modelNameOrPath, "config.json") 23 | if err != nil { 24 | return err 25 | } 26 | return config.Load(configFile, customParams) 27 | } 28 | -------------------------------------------------------------------------------- /bert/config_test.go: -------------------------------------------------------------------------------- 1 | package bert_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/sugarme/transformer/bert" 8 | ) 9 | 10 | // No custom params 11 | func TestNewBertConfig_Default(t *testing.T) { 12 | 13 | config := bert.NewConfig(nil) 14 | 15 | wantHiddenAct := "gelu" 16 | gotHiddenAct := config.HiddenAct 17 | if !reflect.DeepEqual(wantHiddenAct, gotHiddenAct) { 18 | t.Errorf("Want: '%v'\n", wantHiddenAct) 19 | t.Errorf("Got: '%v'\n", gotHiddenAct) 20 | } 21 | 22 | wantVocabSize := int64(30522) 23 | gotVocabSize := config.VocabSize 24 | 25 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 26 | t.Errorf("Want: '%v'\n", wantVocabSize) 27 | t.Errorf("Got: '%v'\n", gotVocabSize) 28 | } 29 | } 30 | 31 | // With custom params 32 | func TestNewBertConfig_Custom(t *testing.T) { 33 | 34 | config := bert.NewConfig(map[string]interface{}{"VocabSize": int64(2000), "HiddenAct": "relu"}) 35 | 36 | wantHiddenAct := "relu" 37 | gotHiddenAct := config.HiddenAct 38 | if !reflect.DeepEqual(wantHiddenAct, gotHiddenAct) { 39 | t.Errorf("Want: '%v'\n", wantHiddenAct) 40 | t.Errorf("Got: '%v'\n", gotHiddenAct) 41 | } 42 | 43 | wantVocabSize := int64(2000) 44 | gotVocabSize := config.VocabSize 45 | 46 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 47 | t.Errorf("Want: '%v'\n", wantVocabSize) 48 | t.Errorf("Got: '%v'\n", gotVocabSize) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /pipeline/ner.go: -------------------------------------------------------------------------------- 1 | package pipeline 2 | 3 | // Named Entity Recognition pipeline 4 | // Extracts entities (Person, Location, Organization, Miscellaneous) from text. 5 | // Pretrained models are available for the following languages: 6 | // - English 7 | // - German 8 | // - Spanish 9 | // - Dutch 10 | // 11 | // The default NER mode is an English BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz) 12 | 13 | // Entity holds entity data generated by NERModel 14 | type Entity struct { 15 | // String representation of the Entity 16 | Word string 17 | // Confidence score 18 | Score float64 19 | // Entity label (e.g. ORG, LOC...) 
20 | Label string 21 | } 22 | 23 | // NERModel is a model to extract entities 24 | type NERModel struct { 25 | tokenClassificationModel TokenClassificationModel 26 | } 27 | 28 | // NewNERModel creates a NERModel from input config 29 | func NewNERModel(config TokenClassificationModel) *NERModel { 30 | return &NERModel{ 31 | tokenClassificationModel: config, 32 | } 33 | } 34 | 35 | // Predict extracts entities from input text and returns slice of entities with score 36 | func (nm *NERModel) Predict(input []string) []Entity { 37 | tokens := nm.tokenClassificationModel.Predict(input, true, false) 38 | 39 | var entities []Entity 40 | for _, tok := range tokens { 41 | if tok.Label != "0" { 42 | entities = append(entities, Entity{ 43 | Word: tok.Text, 44 | Score: tok.Score, 45 | Label: tok.Label, 46 | }) 47 | } 48 | } 49 | 50 | return entities 51 | } 52 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package transformer_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/sugarme/transformer" 8 | "github.com/sugarme/transformer/bert" 9 | ) 10 | 11 | // With model name 12 | func TestConfigFromPretrained_ModelName(t *testing.T) { 13 | modelName := "bert-base-uncased" 14 | var config *bert.BertConfig = new(bert.BertConfig) 15 | err := transformer.LoadConfig(config, modelName, nil) 16 | if err != nil { 17 | t.Error(err) 18 | } 19 | 20 | wantVocabSize := int64(30522) 21 | gotVocabSize := config.VocabSize 22 | 23 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 24 | t.Errorf("Want: %v\n", wantVocabSize) 25 | t.Errorf("Got: %v\n", gotVocabSize) 26 | } 27 | } 28 | 29 | // No custom params 30 | func TestConfigFromPretrained(t *testing.T) { 31 | var config *bert.BertConfig = new(bert.BertConfig) 32 | err := transformer.LoadConfig(config, "bert-base-uncased", nil) 33 | if err != nil { 34 | t.Error(err) 35 | } 36 | 37 | wantVocabSize := int64(30522) 38 | gotVocabSize := config.VocabSize 39 | 40 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 41 | t.Errorf("Want: %v\n", wantVocabSize) 42 | t.Errorf("Got: %v\n", gotVocabSize) 43 | } 44 | 45 | } 46 | 47 | // With custom params 48 | func TestConfigFromPretrained_CustomParams(t *testing.T) { 49 | params := map[string]interface{}{ 50 | "VocabSize": int64(2000), 51 | "NumLabels": int64(4), 52 | } 53 | 54 | var config *bert.BertConfig = new(bert.BertConfig) 55 | err := transformer.LoadConfig(config, "bert-base-uncased", params) 56 | if err != nil { 57 | t.Error(err) 58 | } 59 | 60 | wantVocabSize := int64(2000) 61 | gotVocabSize := config.VocabSize 62 | 63 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 64 | t.Errorf("Want: %v\n", wantVocabSize) 65 | t.Errorf("Got: %v\n", gotVocabSize) 66 | } 67 | 68 | wantNumLabels := int64(4) 69 | gotNumLabels := config.NumLabels 70 | 71 | if !reflect.DeepEqual(wantNumLabels, gotNumLabels) { 72 | t.Errorf("Want: %v\n", wantNumLabels) 73 | t.Errorf("Got: %v\n", gotNumLabels) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /roberta/tokenizer.go: -------------------------------------------------------------------------------- 1 | package roberta 2 | 3 | import ( 4 | // "fmt" 5 | 6 | "github.com/sugarme/tokenizer" 7 | "github.com/sugarme/tokenizer/model/bpe" 8 | "github.com/sugarme/tokenizer/normalizer" 9 | "github.com/sugarme/tokenizer/pretokenizer" 10 | "github.com/sugarme/tokenizer/processor" 11 | 12 | 
"github.com/sugarme/transformer/util" 13 | ) 14 | 15 | // Tokenizer holds data for Roberta tokenizer. 16 | type Tokenizer struct { 17 | *tokenizer.Tokenizer 18 | } 19 | 20 | // NewTokenizer creates a new Roberta tokenizer. 21 | func NewTokenizer() *Tokenizer { 22 | tk := tokenizer.NewTokenizer(nil) 23 | return &Tokenizer{tk} 24 | } 25 | 26 | // Load loads Roberta tokenizer from pretrain vocab and merges files. 27 | func (t *Tokenizer) Load(modelNameOrPath string, params map[string]interface{}) error { 28 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 29 | if err != nil { 30 | return err 31 | } 32 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 33 | if err != nil { 34 | return err 35 | } 36 | 37 | model, err := bpe.NewBpeFromFiles(vocabFile, mergesFile) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | t.WithModel(model) 43 | 44 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 45 | t.WithNormalizer(bertNormalizer) 46 | 47 | blPreTokenizer := pretokenizer.NewByteLevel() 48 | // blPreTokenizer.SetAddPrefixSpace(false) 49 | t.WithPreTokenizer(blPreTokenizer) 50 | 51 | var specialTokens []tokenizer.AddedToken 52 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 53 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 54 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 55 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 56 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 57 | t.AddSpecialTokens(specialTokens) 58 | 59 | postProcess := processor.DefaultRobertaProcessing() 60 | t.WithPostProcessor(postProcess) 61 | 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /pretrained/bert.go: -------------------------------------------------------------------------------- 1 | package pretrained 2 | 3 | // BertConfigs is a map of pretrained Bert configuration names to corresponding URLs. 4 | var BertConfigs map[string]string = map[string]string{ 5 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 6 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json", 7 | "bert-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 8 | 9 | // Roberta 10 | "roberta-base": "https://cdn.huggingface.co/roberta-base-config.json", 11 | "roberta-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/deepset/roberta-base-squad2/config.json", 12 | "xlm-roberta-ner-en": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json", 13 | "xlm-roberta-ner-de": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json", 14 | "xlm-roberta-ner-nl": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json", 15 | "xlm-roberta-ner-es": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json", 16 | } 17 | 18 | // BertModels is a map of pretrained Bert model names to corresponding URLs. 
19 | var BertModels map[string]string = map[string]string{ 20 | "bert-base-uncased": "https://cdn.huggingface.co/bert-base-uncased-rust_model.ot", 21 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot", 22 | "bert-qa": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-rust_model.ot", 23 | } 24 | 25 | // BertVocabs is a map of BERT model vocab name to corresponding URLs. 26 | var BertVocabs map[string]string = map[string]string{ 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 28 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt", 29 | "bert-qa": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 30 | } 31 | -------------------------------------------------------------------------------- /bert/tokenizer.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/sugarme/tokenizer" 7 | "github.com/sugarme/tokenizer/model/wordpiece" 8 | "github.com/sugarme/tokenizer/normalizer" 9 | "github.com/sugarme/tokenizer/pretokenizer" 10 | "github.com/sugarme/tokenizer/processor" 11 | 12 | // "github.com/sugarme/transformer/pretrained" 13 | "github.com/sugarme/transformer/util" 14 | ) 15 | 16 | type BertTokenizerFast = tokenizer.Tokenizer 17 | 18 | // BertJapaneseTokenizerFromPretrained initiate BERT tokenizer for Japanese language from pretrained file. 19 | func BertJapaneseTokenizerFromPretrained(pretrainedModelNameOrPath string, customParams map[string]interface{}) *tokenizer.Tokenizer { 20 | 21 | // TODO: implement it 22 | 23 | panic("Not implemented yet.") 24 | } 25 | 26 | type Tokenizer struct { 27 | *tokenizer.Tokenizer 28 | } 29 | 30 | func NewTokenizer() *Tokenizer { 31 | tk := tokenizer.NewTokenizer(nil) 32 | return &Tokenizer{tk} 33 | } 34 | 35 | func (bt *Tokenizer) Load(modelNameOrPath string, params map[string]interface{}) error { 36 | cachedFile, err := util.CachedPath(modelNameOrPath, "vocab.txt") 37 | if err != nil { 38 | return err 39 | } 40 | 41 | model, err := wordpiece.NewWordPieceFromFile(cachedFile, "[UNK]") 42 | if err != nil { 43 | return err 44 | } 45 | 46 | bt.WithModel(model) 47 | 48 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 49 | bt.WithNormalizer(bertNormalizer) 50 | 51 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer() 52 | bt.WithPreTokenizer(bertPreTokenizer) 53 | 54 | var specialTokens []tokenizer.AddedToken 55 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true)) 56 | 57 | bt.AddSpecialTokens(specialTokens) 58 | 59 | sepId, ok := bt.TokenToId("[SEP]") 60 | if !ok { 61 | return fmt.Errorf("Cannot find ID for [SEP] token.\n") 62 | } 63 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"} 64 | 65 | clsId, ok := bt.TokenToId("[CLS]") 66 | if !ok { 67 | return fmt.Errorf("Cannot find ID for [CLS] token.\n") 68 | } 69 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"} 70 | 71 | postProcess := processor.NewBertProcessing(sep, cls) 72 | bt.WithPostProcessor(postProcess) 73 | 74 | // TODO: update params 75 | 76 | return nil 77 | } 78 | -------------------------------------------------------------------------------- /modeling_test.go: -------------------------------------------------------------------------------- 1 | package transformer_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 
| 7 | "github.com/sugarme/transformer" 8 | "github.com/sugarme/transformer/bert" 9 | ) 10 | 11 | // With model name 12 | func TestModelFromPretrained_ModelName(t *testing.T) { 13 | modelName := "bert-base-uncased" 14 | var config *bert.BertConfig = new(bert.BertConfig) 15 | err := transformer.LoadConfig(config, modelName, nil) 16 | if err != nil { 17 | t.Error(err) 18 | } 19 | 20 | wantVocabSize := int64(30522) 21 | gotVocabSize := config.VocabSize 22 | 23 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 24 | t.Errorf("Want: %v\n", wantVocabSize) 25 | t.Errorf("Got: %v\n", gotVocabSize) 26 | } 27 | } 28 | 29 | // With local file 30 | 31 | /* 32 | * // No custom params 33 | * func TestModelFromPretrained(t *testing.T) { 34 | * // bertURL := transformer.AllPretrainedConfigs["bert-base-uncased"] 35 | * url := "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json" 36 | * 37 | * var config *bert.BertConfig = new(bert.BertConfig) 38 | * err := transformer.LoadConfig(config, url, nil) 39 | * if err != nil { 40 | * t.Error(err) 41 | * } 42 | * 43 | * wantVocabSize := int64(30522) 44 | * gotVocabSize := config.VocabSize 45 | * 46 | * if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 47 | * t.Errorf("Want: %v\n", wantVocabSize) 48 | * t.Errorf("Got: %v\n", gotVocabSize) 49 | * } 50 | * 51 | * } 52 | * 53 | * // With custom params 54 | * func TestModelFromPretrained_CustomParams(t *testing.T) { 55 | * url := "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json" 56 | * 57 | * params := map[string]interface{}{ 58 | * "VocabSize": int64(2000), 59 | * "NumLabels": int64(4), 60 | * } 61 | * 62 | * var config *bert.BertConfig = new(bert.BertConfig) 63 | * err := transformer.LoadConfig(config, url, params) 64 | * if err != nil { 65 | * t.Error(err) 66 | * } 67 | * 68 | * wantVocabSize := int64(2000) 69 | * gotVocabSize := config.VocabSize 70 | * 71 | * if !reflect.DeepEqual(wantVocabSize, gotVocabSize) { 72 | * t.Errorf("Want: %v\n", wantVocabSize) 73 | * t.Errorf("Got: %v\n", gotVocabSize) 74 | * } 75 | * 76 | * wantNumLabels := int64(4) 77 | * gotNumLabels := config.NumLabels 78 | * 79 | * if !reflect.DeepEqual(wantNumLabels, gotNumLabels) { 80 | * t.Errorf("Want: %v\n", wantNumLabels) 81 | * t.Errorf("Got: %v\n", gotNumLabels) 82 | * } 83 | * } */ 84 | -------------------------------------------------------------------------------- /pretrained/roberta.go: -------------------------------------------------------------------------------- 1 | package pretrained 2 | 3 | // RobertaConfigs is a map of pretrained Roberta configuration names to corresponding URLs. 4 | var RobertaConfigs map[string]string = map[string]string{ 5 | "roberta-base": "https://cdn.huggingface.co/roberta-base-config.json", 6 | "roberta-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/deepset/roberta-base-squad2/config.json", 7 | "xlm-roberta-ner-en": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json", 8 | "xlm-roberta-ner-de": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json", 9 | "xlm-roberta-ner-nl": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json", 10 | "xlm-roberta-ner-es": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json", 11 | } 12 | 13 | // RobertaModels is a map of pretrained Roberta model names to corresponding URLs. 
14 | var RobertaModels map[string]string = map[string]string{ 15 | "roberta-base": "https://cdn.huggingface.co/roberta-base-rust_model.ot", 16 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/rust_model.ot", 17 | "xlm-roberta-ner-en": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-rust_model.ot", 18 | "xlm-roberta-ner-de": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-rust_model.ot", 19 | "xlm-roberta-ner-nl": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-rust_model.ot", 20 | "xlm-roberta-ner-es": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-rust_model.ot", 21 | } 22 | 23 | // RobertaVocabs is a map of pretrained Roberta vocab name to corresponding URLs. 24 | var RobertaVocabs map[string]string = map[string]string{ 25 | "roberta-base": "https://cdn.huggingface.co/roberta-base-vocab.json", 26 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/vocab.json", 27 | "xlm-roberta-ner-en": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model", 28 | "xlm-roberta-ner-de": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model", 29 | "xlm-roberta-ner-nl": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model", 30 | "xlm-roberta-ner-es": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model", 31 | } 32 | 33 | // RobertaMerges is a map of pretrained Roberta vocab merges name to corresponding URLs. 34 | var RobertaMerges map[string]string = map[string]string{ 35 | "roberta-base": "https://cdn.huggingface.co/roberta-base-merges.txt", 36 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/merges.txt", 37 | } 38 | -------------------------------------------------------------------------------- /util/activation.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/sugarme/gotch/ts" 5 | ) 6 | 7 | // ActivationFn is an activation function. 8 | type ActivationFn interface { 9 | // Fwd is a forward pass through x. 
10 | Fwd(x *ts.Tensor) *ts.Tensor 11 | Name() string 12 | } 13 | 14 | // ReLU activation: 15 | // =============== 16 | 17 | type ReluActivation struct { 18 | name string 19 | } 20 | 21 | var Relu = ReluActivation{} 22 | 23 | func NewRelu() ReluActivation { 24 | return ReluActivation{"relu"} 25 | } 26 | 27 | func (r ReluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) { 28 | return x.MustRelu(false) 29 | } 30 | 31 | func (r ReluActivation) Name() (retVal string) { 32 | return r.name 33 | } 34 | 35 | // GeLU activation: 36 | // =============== 37 | 38 | type GeluActivation struct { 39 | name string 40 | } 41 | 42 | var Gelu = GeluActivation{} 43 | 44 | func NewGelu() GeluActivation { 45 | return GeluActivation{"gelu"} 46 | } 47 | 48 | func (g GeluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) { 49 | return x.MustGelu(false) 50 | } 51 | 52 | func (g GeluActivation) Name() (retVal string) { 53 | return g.name 54 | } 55 | 56 | // Tanh activation: 57 | // =============== 58 | 59 | type TanhActivation struct { 60 | name string 61 | } 62 | 63 | var Tanh = TanhActivation{} 64 | 65 | func NewTanh() TanhActivation { 66 | return TanhActivation{"tanh"} 67 | } 68 | 69 | func (t TanhActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) { 70 | return x.MustTanh(false) 71 | } 72 | 73 | func (t TanhActivation) Name() string { 74 | return t.name 75 | } 76 | 77 | // Swish activation: 78 | // =============== 79 | 80 | type SwishActivation struct { 81 | name string 82 | } 83 | 84 | var Swish = SwishActivation{} 85 | 86 | func NewSwish() SwishActivation { 87 | return SwishActivation{"swish"} 88 | } 89 | 90 | func (s SwishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) { 91 | return x.Swish() 92 | } 93 | 94 | func (s SwishActivation) Name() (retVal string) { 95 | return s.name 96 | } 97 | 98 | // Mish activation: 99 | // ================= 100 | 101 | type MishActivation struct { 102 | name string 103 | } 104 | 105 | var Mish = MishActivation{} 106 | 107 | func NewMish() MishActivation { 108 | return MishActivation{"mish"} 109 | } 110 | 111 | func (m MishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) { 112 | softplus := x.MustSoftplus(false) 113 | tanh := softplus.MustTanh(true) 114 | retVal = x.MustMm(tanh, false) 115 | tanh.MustDrop() 116 | return retVal 117 | } 118 | 119 | func (m MishActivation) Name() (retVal string) { 120 | return m.name 121 | } 122 | 123 | func geluNew(xs *ts.Tensor) (retVal *ts.Tensor) { 124 | // TODO: implement 125 | // x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) 126 | return retVal 127 | } 128 | 129 | var ActivationFnMap map[string]ActivationFn = map[string]ActivationFn{ 130 | "gelu": NewGelu(), 131 | "relu": NewRelu(), 132 | "tanh": NewTanh(), 133 | "swish": NewSwish(), 134 | "mish": NewMish(), 135 | } 136 | -------------------------------------------------------------------------------- /bert/embedding.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/sugarme/gotch" 7 | "github.com/sugarme/gotch/nn" 8 | "github.com/sugarme/gotch/ts" 9 | 10 | "github.com/sugarme/transformer/util" 11 | ) 12 | 13 | // BertEmbedding defines interface for BertModel or RoBertaModel. 
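// BertEmbeddings (below) is the standard implementation: word, position and
// token-type embeddings are summed, then passed through LayerNorm and dropout.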
14 | type BertEmbedding interface { 15 | ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (*ts.Tensor, error) 16 | } 17 | 18 | type BertEmbeddings struct { 19 | WordEmbeddings *nn.Embedding 20 | PositionEmbeddings *nn.Embedding 21 | TokenTypeEmbeddings *nn.Embedding 22 | LayerNorm *nn.LayerNorm 23 | Dropout *util.Dropout 24 | } 25 | 26 | // NewBertEmbeddings builds a new BertEmbeddings 27 | func NewBertEmbeddings(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertEmbeddings { 28 | changeName := true 29 | if len(changeNameOpt) > 0 { 30 | changeName = changeNameOpt[0] 31 | } 32 | embeddingConfig := nn.DefaultEmbeddingConfig() 33 | embeddingConfig.PaddingIdx = 0 34 | 35 | wEmbedPath := p.Sub("word_embeddings") 36 | wordEmbeddings := nn.NewEmbedding(wEmbedPath, config.VocabSize, config.HiddenSize, embeddingConfig) 37 | 38 | posEmbedPath := p.Sub("position_embeddings") 39 | positionEmbeddings := nn.NewEmbedding(posEmbedPath, config.MaxPositionEmbeddings, config.HiddenSize, embeddingConfig) 40 | 41 | ttEmbedPath := p.Sub("token_type_embeddings") 42 | tokenTypeEmbeddings := nn.NewEmbedding(ttEmbedPath, config.TypeVocabSize, config.HiddenSize, embeddingConfig) 43 | 44 | layerNormConfig := nn.DefaultLayerNormConfig() 45 | if changeName { 46 | layerNormConfig.WsName = "gamma" 47 | layerNormConfig.BsName = "beta" 48 | } 49 | layerNormConfig.Eps = 1e-12 50 | 51 | lnPath := p.Sub("LayerNorm") 52 | layerNorm := nn.NewLayerNorm(lnPath, []int64{config.HiddenSize}, layerNormConfig) 53 | 54 | dropout := util.NewDropout(config.HiddenDropoutProb) 55 | 56 | return &BertEmbeddings{wordEmbeddings, positionEmbeddings, tokenTypeEmbeddings, layerNorm, dropout} 57 | } 58 | 59 | // ForwardT implements BertEmbedding interface, passes throught the embedding layer 60 | func (be *BertEmbeddings) ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, err error) { 61 | 62 | var ( 63 | inputEmbeddings *ts.Tensor 64 | inputShape []int64 65 | ) 66 | 67 | if inputIds.MustDefined() { 68 | if inputEmbeds.MustDefined() { 69 | err = fmt.Errorf("Only one of input Ids or input embeddings may be set.") 70 | return retVal, err 71 | } else { 72 | inputEmbeddings = inputIds.ApplyT(be.WordEmbeddings, train) 73 | inputShape = inputIds.MustSize() 74 | } 75 | } else { 76 | if inputEmbeds.MustDefined() { 77 | inputEmbeddings = inputEmbeds 78 | size := inputEmbeds.MustSize() 79 | inputShape = []int64{size[0], size[1]} 80 | } else { 81 | err = fmt.Errorf("Only one of input Ids or input embeddings may be set.") 82 | return retVal, err 83 | } 84 | } 85 | 86 | seqLength := inputEmbeddings.MustSize()[1] 87 | 88 | var posIds *ts.Tensor 89 | if positionIds.MustDefined() { 90 | posIds = positionIds 91 | } else { 92 | tmp1 := ts.MustArange(ts.IntScalar(seqLength), gotch.Int64, inputEmbeddings.MustDevice()) 93 | tmp2 := tmp1.MustUnsqueeze(0, true) 94 | posIds = tmp2.MustExpand(inputShape, true, true) 95 | } 96 | 97 | var tokTypeIds *ts.Tensor 98 | if tokenTypeIds.MustDefined() { 99 | tokTypeIds = tokenTypeIds 100 | } else { 101 | tokTypeIds = ts.MustZeros(inputShape, gotch.Int64, inputEmbeddings.MustDevice()) 102 | } 103 | 104 | posEmbeddings := posIds.Apply(be.PositionEmbeddings) 105 | posIds.MustDrop() 106 | tokEmbeddings := tokTypeIds.Apply(be.TokenTypeEmbeddings) 107 | tokTypeIds.MustDrop() 108 | 109 | input := inputEmbeddings.MustAdd(posEmbeddings, true) 110 | posEmbeddings.MustDrop() 111 | input.MustAdd_(tokEmbeddings) 112 | tokEmbeddings.MustDrop() 113 | 114 
| retTmp1 := input.Apply(be.LayerNorm) 115 | input.MustDrop() 116 | retVal = retTmp1.ApplyT(be.Dropout, train) 117 | retTmp1.MustDrop() 118 | 119 | return retVal, nil 120 | } 121 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/emirpasic/gods v1.12.0 h1:QAUIPSaCu4G+POclxeqb3F+WPpdKqFGlw36+yOzGlrg= 5 | github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o= 6 | github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= 7 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= 8 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/rivo/uniseg v0.1.0 h1:+2KBaVoUmb9XzDsrx/Ct0W/EYOSFf/nWTauy++DprtY= 12 | github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 13 | github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8= 14 | github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 17 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 18 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 19 | github.com/sugarme/gotch v0.0.0-20200924012111-ff31d3c62dbe h1:8i3jLRuqZvwR5mvuioHO7Hq6qITQkSPzFc/M9LxkkkY= 20 | github.com/sugarme/gotch v0.0.0-20200924012111-ff31d3c62dbe/go.mod h1:w3zHhlZfnNS//C7YQd/89fOmgofd0RK8jf/GJ+cqkO4= 21 | github.com/sugarme/gotch v0.7.0 h1:vDQqLmuo5uhqNTfTyR7xbye9pPK9a4l57YWMKH41gGU= 22 | github.com/sugarme/gotch v0.7.0/go.mod h1:ydo7fmsmT+2L5p8Am1YhOLSWGN9WV9nyrfk/RSmhTfo= 23 | github.com/sugarme/regexpset v0.0.0-20200813070853-0a3212d91786/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw= 24 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4= 25 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw= 26 | github.com/sugarme/tokenizer v0.0.0-20200930080132-dbbba9ea4756 h1:zPZePVVB2+NiDDqtII00+BQ+MqTJpCt6HBqmtbGBUmA= 27 | github.com/sugarme/tokenizer v0.0.0-20200930080132-dbbba9ea4756/go.mod h1:mkypAc0rOi4EDcKuuJTdbPJOk8QMMZABaJRNs1M1OeU= 28 | github.com/sugarme/tokenizer v0.1.16 h1:M1YpaEaZjQw7HiXrgujCcsp2ccbOLVtWNIK37o3kZlk= 29 | github.com/sugarme/tokenizer v0.1.16/go.mod h1:a1EffeqKJAQtJz/IEAgaN1SIG/TCgawEgWOV9rquM5M= 30 | github.com/sugarme/tokenizer v0.1.17 h1:1UDhHtz/nG7FGQfEJPXF3VbFVIYuKiBymZRAxdrLIBM= 31 | github.com/sugarme/tokenizer v0.1.17/go.mod h1:a1EffeqKJAQtJz/IEAgaN1SIG/TCgawEgWOV9rquM5M= 32 | golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod 
h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= 33 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 34 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= 35 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 36 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 37 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 38 | golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= 39 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 40 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 41 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 42 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 43 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 44 | gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= 45 | gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 46 | -------------------------------------------------------------------------------- /coverage.out: -------------------------------------------------------------------------------- 1 | mode: set 2 | github.com/sugarme/sermo/tokenize/tokenizer.go:21.60,23.2 1 1 3 | github.com/sugarme/sermo/tokenize/tokenizer.go:30.41,35.2 1 1 4 | github.com/sugarme/sermo/tokenize/tokenizer.go:45.89,46.28 1 1 5 | github.com/sugarme/sermo/tokenize/tokenizer.go:46.28,52.3 1 1 6 | github.com/sugarme/sermo/tokenize/tokenizer.go:55.57,58.25 2 1 7 | github.com/sugarme/sermo/tokenize/tokenizer.go:62.2,62.12 1 1 8 | github.com/sugarme/sermo/tokenize/tokenizer.go:58.25,60.3 1 1 9 | github.com/sugarme/sermo/tokenize/unicode.go:37.29,38.11 1 1 10 | github.com/sugarme/sermo/tokenize/unicode.go:46.2,46.46 1 1 11 | github.com/sugarme/sermo/tokenize/unicode.go:39.12,40.15 1 1 12 | github.com/sugarme/sermo/tokenize/unicode.go:41.12,42.15 1 1 13 | github.com/sugarme/sermo/tokenize/unicode.go:43.12,44.15 1 0 14 | github.com/sugarme/sermo/tokenize/unicode.go:50.27,51.11 1 1 15 | github.com/sugarme/sermo/tokenize/unicode.go:61.2,61.34 1 1 16 | github.com/sugarme/sermo/tokenize/unicode.go:52.11,53.14 1 1 17 | github.com/sugarme/sermo/tokenize/unicode.go:54.12,55.14 1 1 18 | github.com/sugarme/sermo/tokenize/unicode.go:56.12,57.14 1 1 19 | github.com/sugarme/sermo/tokenize/unicode.go:58.12,59.14 1 0 20 | github.com/sugarme/sermo/tokenize/unicode.go:65.27,67.2 1 1 21 | github.com/sugarme/sermo/tokenize/unicode.go:70.29,72.2 1 1 22 | github.com/sugarme/sermo/tokenize/word.go:16.63,21.2 1 1 23 | github.com/sugarme/sermo/tokenize/word.go:29.54,37.2 4 1 24 | github.com/sugarme/sermo/tokenize/word.go:39.31,41.24 2 1 25 | github.com/sugarme/sermo/tokenize/word.go:50.2,50.19 1 1 26 | github.com/sugarme/sermo/tokenize/word.go:41.24,42.44 1 1 27 | github.com/sugarme/sermo/tokenize/word.go:42.44,43.12 1 0 28 | github.com/sugarme/sermo/tokenize/word.go:44.9,44.24 1 1 29 | github.com/sugarme/sermo/tokenize/word.go:44.24,46.4 1 1 30 | github.com/sugarme/sermo/tokenize/word.go:46.9,48.4 1 1 31 | github.com/sugarme/sermo/tokenize/word.go:53.38,55.41 2 1 32 | github.com/sugarme/sermo/tokenize/word.go:60.2,60.19 1 1 33 | github.com/sugarme/sermo/tokenize/word.go:55.41,56.33 1 1 34 
| github.com/sugarme/sermo/tokenize/word.go:56.33,58.4 1 1 35 | github.com/sugarme/sermo/tokenize/word.go:63.38,67.24 3 1 36 | github.com/sugarme/sermo/tokenize/word.go:77.2,77.17 1 1 37 | github.com/sugarme/sermo/tokenize/word.go:80.2,80.13 1 1 38 | github.com/sugarme/sermo/tokenize/word.go:67.24,68.17 1 1 39 | github.com/sugarme/sermo/tokenize/word.go:68.17,72.4 3 1 40 | github.com/sugarme/sermo/tokenize/word.go:72.9,74.4 1 1 41 | github.com/sugarme/sermo/tokenize/word.go:77.17,79.3 1 1 42 | github.com/sugarme/sermo/tokenize/word.go:83.53,88.11 3 1 43 | github.com/sugarme/sermo/tokenize/word.go:92.2,93.27 2 1 44 | github.com/sugarme/sermo/tokenize/word.go:98.2,99.16 2 1 45 | github.com/sugarme/sermo/tokenize/word.go:88.11,90.3 1 1 46 | github.com/sugarme/sermo/tokenize/word.go:93.27,96.3 2 1 47 | github.com/sugarme/sermo/tokenize/word.go:102.37,104.24 2 1 48 | github.com/sugarme/sermo/tokenize/word.go:110.2,110.13 1 1 49 | github.com/sugarme/sermo/tokenize/word.go:104.24,108.3 3 1 50 | github.com/sugarme/sermo/tokenize/word.go:115.43,121.27 3 1 51 | github.com/sugarme/sermo/tokenize/word.go:126.2,126.16 1 1 52 | github.com/sugarme/sermo/tokenize/word.go:121.27,122.16 1 1 53 | github.com/sugarme/sermo/tokenize/word.go:122.16,124.4 1 1 54 | github.com/sugarme/sermo/tokenize/word.go:131.36,133.24 2 1 55 | github.com/sugarme/sermo/tokenize/word.go:142.2,142.19 1 1 56 | github.com/sugarme/sermo/tokenize/word.go:133.24,134.19 1 1 57 | github.com/sugarme/sermo/tokenize/word.go:134.19,138.4 3 1 58 | github.com/sugarme/sermo/tokenize/word.go:138.9,140.4 1 1 59 | github.com/sugarme/sermo/tokenize/wordpiece.go:25.49,30.2 1 0 60 | github.com/sugarme/sermo/tokenize/wordpiece.go:35.60,41.25 3 1 61 | github.com/sugarme/sermo/tokenize/wordpiece.go:44.2,44.16 1 1 62 | github.com/sugarme/sermo/tokenize/wordpiece.go:41.25,43.3 1 1 63 | github.com/sugarme/sermo/tokenize/wordpiece.go:47.70,49.43 2 1 64 | github.com/sugarme/sermo/tokenize/wordpiece.go:64.2,64.13 1 1 65 | github.com/sugarme/sermo/tokenize/wordpiece.go:49.43,50.33 1 1 66 | github.com/sugarme/sermo/tokenize/wordpiece.go:54.3,54.35 1 1 67 | github.com/sugarme/sermo/tokenize/wordpiece.go:50.33,52.12 2 0 68 | github.com/sugarme/sermo/tokenize/wordpiece.go:54.35,56.17 2 1 69 | github.com/sugarme/sermo/tokenize/wordpiece.go:60.4,61.45 2 1 70 | github.com/sugarme/sermo/tokenize/wordpiece.go:56.17,58.10 2 0 71 | github.com/sugarme/sermo/tokenize/wordpiece.go:69.53,71.2 1 0 72 | github.com/sugarme/sermo/tokenize/wordpiece.go:74.58,76.2 1 0 73 | -------------------------------------------------------------------------------- /bert/config.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "path/filepath" 10 | "reflect" 11 | ) 12 | 13 | // BertConfig defines the BERT model architecture (i.e., number of layers, 14 | // hidden layer size, label mapping...) 
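// Fields are populated from a HuggingFace-style config.json via the json tags
// below; NewConfig fills in bert-base-uncased defaults and applies any
// caller-supplied overrides by field name.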
15 | type BertConfig struct { 16 | HiddenAct string `json:"hidden_act"` 17 | AttentionProbsDropoutProb float64 `json:"attention_probs_dropout_prob"` 18 | HiddenDropoutProb float64 `json:"hidden_dropout_prob"` 19 | HiddenSize int64 `json:"hidden_size"` 20 | InitializerRange float32 `json:"initializer_range"` 21 | IntermediateSize int64 `json:"intermediate_size"` 22 | MaxPositionEmbeddings int64 `json:"max_position_embeddings"` 23 | NumAttentionHeads int64 `json:"num_attention_heads"` 24 | NumHiddenLayers int64 `json:"num_hidden_layers"` 25 | TypeVocabSize int64 `json:"type_vocab_size"` 26 | VocabSize int64 `json:"vocab_size"` 27 | OutputAttentions bool `json:"output_attentions"` 28 | OutputHiddenStates bool `json:"output_hidden_states"` 29 | IsDecoder bool `json:"is_decoder"` 30 | Id2Label map[int64]string `json:"id_2_label"` 31 | Label2Id map[string]int64 `json:"label_2_id"` 32 | NumLabels int64 `json:"num_labels"` 33 | } 34 | 35 | // NewBertConfig initiates BertConfig with given input parameters or default values. 36 | func NewConfig(customParams map[string]interface{}) *BertConfig { 37 | defaultValues := map[string]interface{}{ 38 | "VocabSize": int64(30522), 39 | "HiddenSize": int64(768), 40 | "NumHiddenLayers": int64(12), 41 | "NumAttentionHeads": int64(12), 42 | "IntermediateSize": int64(3072), 43 | "HiddenAct": "gelu", 44 | "HiddenDropoutProb": float64(0.1), 45 | "AttentionProbDropoutProb": float64(0.1), 46 | "MaxPositionEmbeddings": int64(512), 47 | "TypeVocabSize": int64(2), 48 | "InitializerRange": float32(0.02), 49 | "LayerNormEps": 1e-12, // not applied yet 50 | "PadTokenId": 0, // not applied yet 51 | "GradientCheckpointing": false, // not applied yet 52 | } 53 | 54 | params := defaultValues 55 | for k, v := range customParams { 56 | if _, ok := params[k]; ok { 57 | params[k] = v 58 | } 59 | } 60 | 61 | config := new(BertConfig) 62 | config.updateParams(params) 63 | 64 | return config 65 | } 66 | 67 | func ConfigFromFile(filename string) (*BertConfig, error) { 68 | filePath, err := filepath.Abs(filename) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | f, err := os.Open(filePath) 74 | if err != nil { 75 | return nil, err 76 | } 77 | defer f.Close() 78 | 79 | buff, err := ioutil.ReadAll(f) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | var config BertConfig 85 | err = json.Unmarshal(buff, &config) 86 | if err != nil { 87 | fmt.Println(err) 88 | log.Fatalf("Could not parse configuration to BertConfiguration.\n") 89 | } 90 | return &config, nil 91 | } 92 | 93 | // Load loads model configuration from file or model name. It also updates 94 | // default configuration parameters if provided. 95 | // This method implements `pretrained.Config` interface. 
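// Example (file name and override value are illustrative):
//
//	config := new(bert.BertConfig)
//	err := config.Load("config.json", map[string]interface{}{"NumLabels": int64(4)})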
96 | func (c *BertConfig) Load(modelNameOrPath string, params map[string]interface{}) error { 97 | err := c.fromFile(modelNameOrPath) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | // Update custom parameters 103 | c.updateParams(params) 104 | 105 | return nil 106 | } 107 | 108 | func (c *BertConfig) fromFile(filename string) error { 109 | filePath, err := filepath.Abs(filename) 110 | if err != nil { 111 | return err 112 | } 113 | 114 | f, err := os.Open(filePath) 115 | if err != nil { 116 | return err 117 | } 118 | defer f.Close() 119 | 120 | buff, err := ioutil.ReadAll(f) 121 | if err != nil { 122 | return err 123 | } 124 | 125 | err = json.Unmarshal(buff, c) 126 | if err != nil { 127 | fmt.Println(err) 128 | log.Fatalf("Could not parse configuration to BertConfiguration.\n") 129 | } 130 | 131 | return nil 132 | } 133 | 134 | func (c *BertConfig) GetVocabSize() int64 { 135 | return c.VocabSize 136 | } 137 | 138 | func (c *BertConfig) updateParams(params map[string]interface{}) { 139 | for k, v := range params { 140 | c.updateField(k, v) 141 | } 142 | } 143 | 144 | func (c *BertConfig) updateField(field string, value interface{}) { 145 | // Check whether field name exists 146 | if reflect.ValueOf(c).Elem().FieldByName(field).IsValid() { 147 | // Check whether same type 148 | if reflect.ValueOf(c).Elem().FieldByName(field).Kind() == reflect.TypeOf(value).Kind() { 149 | reflect.ValueOf(c).Elem().FieldByName(field).Set(reflect.ValueOf(value)) 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /bert/example_test.go: -------------------------------------------------------------------------------- 1 | package bert_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/sugarme/gotch" 8 | "github.com/sugarme/gotch/nn" 9 | "github.com/sugarme/gotch/pickle" 10 | "github.com/sugarme/gotch/ts" 11 | "github.com/sugarme/tokenizer" 12 | 13 | "github.com/sugarme/transformer/bert" 14 | "github.com/sugarme/transformer/util" 15 | ) 16 | 17 | func ExampleBertForMaskedLM() { 18 | // Config 19 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 20 | if err != nil { 21 | log.Fatal(err) 22 | } 23 | config := new(bert.BertConfig) 24 | err = config.Load(configFile, nil) 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | 29 | // Model 30 | device := gotch.CPU 31 | vs := nn.NewVarStore(device) 32 | 33 | model := new(bert.BertForMaskedLM) 34 | 35 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin") 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | err = pickle.LoadAll(vs, modelFile) 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | 44 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | tk := getBertTokenizer(vocabFile) 49 | sentence1 := "Looks like one [MASK] is missing" 50 | sentence2 := "It was a very nice and [MASK] day" 51 | 52 | var input []tokenizer.EncodeInput 53 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 54 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 55 | 56 | encodings, err := tk.EncodeBatch(input, true) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | 61 | var maxLen int = 0 62 | for _, en := range encodings { 63 | if len(en.Ids) > maxLen { 64 | maxLen = len(en.Ids) 65 | } 66 | } 67 | 68 | var tensors []ts.Tensor 69 | for _, en := range encodings { 70 | var tokInput []int64 = make([]int64, maxLen) 71 | 
for i := 0; i < len(en.Ids); i++ { 72 | tokInput[i] = int64(en.Ids[i]) 73 | } 74 | 75 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 76 | } 77 | 78 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 79 | 80 | var output *ts.Tensor 81 | ts.NoGrad(func() { 82 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false) 83 | }) 84 | 85 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0] 86 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0] 87 | 88 | got1, ok := tk.IdToToken(int(index1)) 89 | if !ok { 90 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1) 91 | } 92 | got2, ok := tk.IdToToken(int(index2)) 93 | if !ok { 94 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2) 95 | } 96 | 97 | fmt.Println(got1) 98 | fmt.Println(got2) 99 | /* 100 | * // Output: 101 | * // person 102 | * // pleasant 103 | * */ 104 | } 105 | 106 | /* 107 | * func ExampleBertForSequenceClassification() { 108 | * 109 | * device := gotch.CPU 110 | * vs := nn.NewVarStore(device) 111 | * 112 | * config := bert.ConfigFromFile("../data/bert/config.json") 113 | * 114 | * var dummyLabelMap map[int64]string = make(map[int64]string) 115 | * dummyLabelMap[0] = "positive" 116 | * dummyLabelMap[1] = "negative" 117 | * dummyLabelMap[3] = "neutral" 118 | * 119 | * config.Id2Label = dummyLabelMap 120 | * config.OutputAttentions = true 121 | * config.OutputHiddenStates = true 122 | * model := bert.NewBertForSequenceClassification(vs.Root(), config) 123 | * tk := getBertTokenizer() 124 | * 125 | * // Define input 126 | * sentence1 := "Looks like one thing is missing" 127 | * sentence2 := `It's like comparing oranges to apples` 128 | * var input []tokenizer.EncodeInput 129 | * input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 130 | * input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 131 | * encodings, err := tk.EncodeBatch(input, true) 132 | * if err != nil { 133 | * log.Fatal(err) 134 | * } 135 | * 136 | * // Find max length of token Ids from slice of encodings 137 | * var maxLen int = 0 138 | * for _, en := range encodings { 139 | * if len(en.Ids) > maxLen { 140 | * maxLen = len(en.Ids) 141 | * } 142 | * } 143 | * 144 | * var tensors []ts.Tensor 145 | * for _, en := range encodings { 146 | * var tokInput []int64 = make([]int64, maxLen) 147 | * for i := 0; i < len(en.Ids); i++ { 148 | * tokInput[i] = int64(en.Ids[i]) 149 | * } 150 | * 151 | * tensors = append(tensors, ts.TensorFrom(tokInput)) 152 | * } 153 | * 154 | * inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 155 | * 156 | * var ( 157 | * output ts.Tensor 158 | * allHiddenStates, allAttentions []ts.Tensor 159 | * ) 160 | * 161 | * ts.NoGrad(func() { 162 | * output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 163 | * }) 164 | * 165 | * fmt.Println(output.MustSize()) 166 | * fmt.Println(len(allHiddenStates)) 167 | * fmt.Println(len(allAttentions)) 168 | * 169 | * // Output: 170 | * // [2 3] 171 | * // 12 172 | * // 12 173 | * } */ 174 | -------------------------------------------------------------------------------- /bert/encoder.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/sugarme/gotch/nn" 7 | 
"github.com/sugarme/gotch/ts" 8 | ) 9 | 10 | // `BertLayer`: 11 | //============= 12 | 13 | // BertLayer defines a layer in BERT encoder 14 | type BertLayer struct { 15 | Attention *BertAttention 16 | IsDecoder bool 17 | CrossAttention *BertAttention 18 | Intermediate *BertIntermediate 19 | Output *BertOutput 20 | } 21 | 22 | // NewBertLayer creates a new BertLayer. 23 | func NewBertLayer(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertLayer { 24 | changeName := true 25 | if len(changeNameOpt) > 0 { 26 | changeName = changeNameOpt[0] 27 | } 28 | path := p.Sub("attention") 29 | attention := NewBertAttention(path, config, changeName) 30 | var ( 31 | isDecoder bool = false 32 | crossAttention *BertAttention 33 | ) 34 | 35 | if config.IsDecoder { 36 | isDecoder = true 37 | attPath := p.Sub("cross_attention") 38 | crossAttention = NewBertAttention(attPath, config) 39 | } 40 | 41 | intermediatePath := p.Sub("intermediate") 42 | intermediate := NewBertIntermediate(intermediatePath, config) 43 | outputPath := p.Sub("output") 44 | output := NewBertOutput(outputPath, config, changeName) 45 | 46 | return &BertLayer{attention, isDecoder, crossAttention, intermediate, output} 47 | } 48 | 49 | // ForwardT forwards pass through the model. 50 | func (bl *BertLayer) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, retValOpt1, retValOpt2 *ts.Tensor) { 51 | var ( 52 | attentionOutput *ts.Tensor 53 | attentionWeights *ts.Tensor 54 | crossAttentionWeights *ts.Tensor 55 | ) 56 | 57 | if bl.IsDecoder && encoderHiddenStates.MustDefined() { 58 | var attentionOutputTmp *ts.Tensor 59 | attentionOutputTmp, attentionWeights = bl.Attention.ForwardT(hiddenStates, mask, ts.None, ts.None, train) 60 | attentionOutput, crossAttentionWeights = bl.CrossAttention.ForwardT(attentionOutputTmp, mask, encoderHiddenStates, encoderMask, train) 61 | attentionOutputTmp.MustDrop() 62 | } else { 63 | attentionOutput, attentionWeights = bl.Attention.ForwardT(hiddenStates, mask, ts.None, ts.None, train) 64 | crossAttentionWeights = ts.None 65 | } 66 | 67 | outputTmp := bl.Intermediate.Forward(attentionOutput) 68 | output := bl.Output.ForwardT(outputTmp, attentionOutput, train) 69 | attentionOutput.MustDrop() 70 | outputTmp.MustDrop() 71 | 72 | return output, attentionWeights, crossAttentionWeights 73 | } 74 | 75 | // `BertEncoder`: 76 | //=============== 77 | 78 | // BertEncoder defines an encoder for BERT model 79 | type BertEncoder struct { 80 | OutputAttentions bool 81 | OutputHiddenStates bool 82 | Layers []BertLayer 83 | } 84 | 85 | // NewBertEncoder creates a new BertEncoder. 86 | func NewBertEncoder(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertEncoder { 87 | changeName := true 88 | if len(changeNameOpt) > 0 { 89 | changeName = changeNameOpt[0] 90 | } 91 | path := p.Sub("layer") 92 | outputAttentions := false 93 | if config.OutputAttentions { 94 | outputAttentions = true 95 | } 96 | 97 | outputHiddenStates := false 98 | if config.OutputHiddenStates { 99 | outputHiddenStates = true 100 | } 101 | 102 | var layers []BertLayer 103 | for lIdx := 0; lIdx < int(config.NumHiddenLayers); lIdx++ { 104 | layers = append(layers, *NewBertLayer(path.Sub(fmt.Sprintf("%v", lIdx)), config, changeName)) 105 | } 106 | 107 | return &BertEncoder{outputAttentions, outputHiddenStates, layers} 108 | 109 | } 110 | 111 | // ForwardT forwards pass through the model. 
112 | func (be *BertEncoder) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) { 113 | var ( 114 | allHiddenStates, allAttentions []ts.Tensor = nil, nil 115 | ) 116 | 117 | hiddenState := hiddenStates 118 | 119 | if be.OutputHiddenStates { 120 | allHiddenStates = make([]ts.Tensor, 0) // initialize it 121 | } 122 | if be.OutputAttentions { 123 | allAttentions = make([]ts.Tensor, 0) 124 | } 125 | 126 | for _, layer := range be.Layers { 127 | if allHiddenStates != nil { 128 | allHiddenStates = append(allHiddenStates, *hiddenState) 129 | } 130 | 131 | stateTmp, attnWeightsTmp, _ := layer.ForwardT(hiddenState, mask, encoderHiddenStates, encoderMask, train) 132 | hiddenState.MustDrop() 133 | hiddenState = stateTmp 134 | 135 | if allAttentions != nil { 136 | allAttentions = append(allAttentions, *attnWeightsTmp) 137 | } 138 | 139 | // TODO: do we need to delete `stateTmp` and `attnWeightsTmp` after all? 140 | 141 | } 142 | 143 | return hiddenState, allHiddenStates, allAttentions 144 | } 145 | 146 | // `BertPooler`: 147 | //============== 148 | 149 | // BertPooler defines a linear layer which can be applied to the 150 | // first element of the sequence (the `[CLS]` token). 151 | type BertPooler struct { 152 | Lin *nn.Linear 153 | } 154 | 155 | // NewBertPooler creates a new BertPooler. 156 | func NewBertPooler(p *nn.Path, config *BertConfig) *BertPooler { 157 | path := p.Sub("dense") 158 | lconfig := nn.DefaultLinearConfig() 159 | lin := nn.NewLinear(path, config.HiddenSize, config.HiddenSize, lconfig) 160 | 161 | return &BertPooler{lin} 162 | } 163 | 164 | // Forward forwards pass through the model. 165 | func (bp *BertPooler) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) { 166 | 167 | selectTs := hiddenStates.MustSelect(1, 0, false) 168 | tmp := selectTs.Apply(bp.Lin) 169 | retVal = tmp.MustTanh(true) 170 | return retVal 171 | } 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer [![License](https://img.shields.io/:license-apache-blue.svg)](https://opensource.org/licenses/Apache-2.0)[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/sugarme/transformer?tab=doc)[![Travis CI](https://api.travis-ci.org/sugarme/transformer.svg?branch=master)](https://travis-ci.org/sugarme/transformer)[![Go Report Card](https://goreportcard.com/badge/github.com/sugarme/transformer)](https://goreportcard.com/report/github.com/sugarme/transformer) 2 | 3 | ## Overview 4 | 5 | `transformer` is a pure Go package that facilitates training, testing and inference of Natural Language Processing (NLP) models in Go. 6 | 7 | This package is under active development and many changes are ahead. Hence, use it at your own risk. The package will be considered stable when version 1.0 is released. 8 | 9 | `transformer` is heavily inspired by and based on the popular [Python HuggingFace Transformers](https://github.com/huggingface/transformers). It's also influenced by the [Rust version - rust-bert](https://github.com/guillaume-be/rust-bert).
In fact, all pre-trained models for the Rust version can be imported into this Go `transformer` package, as both `rust-bert`'s dependency - the Pytorch Rust binding [**`tch-rs`**](https://github.com/LaurentMazare/tch-rs) - and the Go binding [**`gotch`**](https://github.com/sugarme/gotch) are built on similar principles. 10 | 11 | `transformer` is part of an ambitious goal (together with [**tokenizer**](https://github.com/sugarme/tokenizer) and [**gotch**](https://github.com/sugarme/gotch)) to bring more AI/deep-learning tools to Gophers so that they can stick to the language they love and are good at, and build faster software in production. 12 | 13 | ## Dependencies 14 | 15 | The two main dependencies are: 16 | 17 | - `tokenizer` 18 | - `gotch` 19 | 20 | ## Prerequisites and installation 21 | 22 | - As this package depends on `gotch`, a Go binding to the Pytorch C++ API, a pre-compiled Libtorch copy (CPU or GPU) must be installed on your machine. Please see the [gotch](https://github.com/sugarme/gotch) installation instructions for details. 23 | - Install the package: `go get -u github.com/sugarme/transformer` 24 | 25 | ## Basic example 26 | 27 | ```go 28 | import ( 29 | "fmt" 30 | "log" 31 | 32 | "github.com/sugarme/gotch" 33 | "github.com/sugarme/gotch/ts" 34 | "github.com/sugarme/tokenizer" 35 | "github.com/sugarme/transformer" 36 | "github.com/sugarme/transformer/bert" 37 | ) 38 | 39 | func main() { 40 | var config *bert.BertConfig = new(bert.BertConfig) 41 | if err := transformer.LoadConfig(config, "bert-base-uncased", nil); err != nil { 42 | log.Fatal(err) 43 | } 44 | 45 | var model *bert.BertForMaskedLM = new(bert.BertForMaskedLM) 46 | if err := transformer.LoadModel(model, "bert-base-uncased", config, nil, gotch.CPU); err != nil { 47 | log.Fatal(err) 48 | } 49 | 50 | var tk *bert.Tokenizer = bert.NewTokenizer() 51 | if err := tk.Load("bert-base-uncased", nil); err != nil { 52 | log.Fatal(err) 53 | } 54 | 55 | sentence1 := "Looks like one [MASK] is missing" 56 | sentence2 := "It was a very nice and [MASK] day" 57 | 58 | var input []tokenizer.EncodeInput 59 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 60 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 61 | 62 | encodings, err := tk.EncodeBatch(input, true) 63 | if err != nil { 64 | log.Fatal(err) 65 | } 66 | 67 | var maxLen int = 0 68 | for _, en := range encodings { 69 | if len(en.Ids) > maxLen { 70 | maxLen = len(en.Ids) 71 | } 72 | } 73 | 74 | var tensors []ts.Tensor 75 | for _, en := range encodings { 76 | var tokInput []int64 = make([]int64, maxLen) 77 | for i := 0; i < len(en.Ids); i++ { 78 | tokInput[i] = int64(en.Ids[i]) 79 | } 80 | 81 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 82 | } 83 | 84 | inputTensor := ts.MustStack(tensors, 0).MustTo(gotch.CPU, true) 85 | var output *ts.Tensor 86 | ts.NoGrad(func() { 87 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false) 88 | }) 89 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0] 90 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0] 91 | 92 | got1, ok := tk.IdToToken(int(index1)) 93 | if !ok { 94 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1) 95 | } 96 | got2, ok := tk.IdToToken(int(index2)) 97 | if !ok { 98 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2) 99 | } 100 | 101 | fmt.Println(got1) 102 | fmt.Println(got2) 103 | 104 |
// Output: 105 | // person 106 | // pleasant 107 | } 108 | ``` 109 | 110 | ## Getting Started 111 | 112 | - See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/transformer?tab=doc) for detail APIs 113 | 114 | 115 | ## License 116 | 117 | `transformer` is Apache 2.0 licensed. 118 | 119 | 120 | ## Acknowledgement 121 | 122 | - This project has been inspired and used many concepts from [Python HuggingFace Transformers](https://github.com/huggingface/transformers) and [Rust version - rust-bert](https://github.com/guillaume-be/rust-bert). 123 | 124 | - Pre-trained models and configurations are downloaded remotely from HuggingFace. 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /roberta/embedding.go: -------------------------------------------------------------------------------- 1 | package roberta 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/sugarme/gotch" 7 | "github.com/sugarme/gotch/nn" 8 | "github.com/sugarme/gotch/ts" 9 | 10 | "github.com/sugarme/transformer/bert" 11 | "github.com/sugarme/transformer/util" 12 | ) 13 | 14 | // RobertaEmbeddings holds embedding struct for Roberta model. 15 | // It also implements `BertEmbedding` interface for Roberta models. 16 | type RobertaEmbeddings struct { 17 | wordEmbeddings *nn.Embedding 18 | positionEmbeddings *nn.Embedding 19 | tokenTypeEmbeddings *nn.Embedding 20 | layerNorm *nn.LayerNorm 21 | dropout *util.Dropout 22 | paddingIndex int64 23 | } 24 | 25 | func (re *RobertaEmbeddings) createPositionIdsFromInputIds(x *ts.Tensor) *ts.Tensor { 26 | mask := x.MustNe(ts.IntScalar(re.paddingIndex), false).MustTotype(gotch.Int64, true) 27 | cumSum := mask.MustCumsum(1, gotch.Int64, false) 28 | mul := cumSum.MustMul(mask, true) 29 | retVal := mul.MustAddScalar(ts.IntScalar(re.paddingIndex), false) 30 | mul.MustDrop() 31 | 32 | return retVal 33 | } 34 | 35 | func (re *RobertaEmbeddings) createPositionIdsFromEmbeddings(x *ts.Tensor) *ts.Tensor { 36 | shape := x.MustSize() 37 | var inputShape []int64 = []int64{shape[0], shape[1]} 38 | 39 | positionIds := ts.MustArangeStart(ts.IntScalar(re.paddingIndex+1), ts.IntScalar(inputShape[0]), gotch.Int64, x.MustDevice()) 40 | retVal := positionIds.MustUnsqueeze(0, false).MustExpand(inputShape, true, true) 41 | 42 | return retVal 43 | } 44 | 45 | // NewRobertaEmbeddings creates a new RobertaEmbeddings. 46 | // 47 | // Params: 48 | // - `p` - Variable store path for the root of the BertEmbeddings model 49 | // - `config` - `BertConfig` object defining the model architecture and vocab/hidden size. 
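// As a worked illustration of createPositionIdsFromInputIds above (the ids are
// made-up values, assuming RoBERTa's padding index of 1): for input ids
// [0, 31414, 232, 2, 1, 1] the non-padding mask is [1, 1, 1, 1, 0, 0], its
// cumulative sum along the sequence is [1, 2, 3, 4, 4, 4], and after re-masking
// and adding the padding index the position ids become [2, 3, 4, 5, 1, 1].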
50 | func NewRobertaEmbeddings(p nn.Path, config *bert.BertConfig) *RobertaEmbeddings { 51 | 52 | embeddingConfig := nn.DefaultEmbeddingConfig() 53 | embeddingConfig.PaddingIdx = 1 54 | 55 | wordEmbeddings := nn.NewEmbedding(p.Sub("word_embeddings"), config.VocabSize, config.HiddenSize, embeddingConfig) 56 | positionEmbeddings := nn.NewEmbedding(p.Sub("position_embeddings"), config.MaxPositionEmbeddings, config.HiddenSize, nn.DefaultEmbeddingConfig()) 57 | tokenTypeEmbeddings := nn.NewEmbedding(p.Sub("token_type_embeddings"), config.TypeVocabSize, config.HiddenSize, nn.DefaultEmbeddingConfig()) 58 | 59 | layerNormConfig := nn.DefaultLayerNormConfig() 60 | layerNormConfig.Eps = 1e-12 61 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig) 62 | dropout := util.NewDropout(config.HiddenDropoutProb) 63 | 64 | return &RobertaEmbeddings{ 65 | wordEmbeddings: wordEmbeddings, 66 | positionEmbeddings: positionEmbeddings, 67 | tokenTypeEmbeddings: tokenTypeEmbeddings, 68 | layerNorm: layerNorm, 69 | dropout: dropout, 70 | } 71 | } 72 | 73 | // ForwardT forwards pass through the embedding layer. 74 | // This differs from the original BERT embeddings in how the position ids are calculated when not provided. 75 | // 76 | // Params: 77 | // - `inputIds`: Optional input tensor of shape (batch size, sequence length). 78 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 79 | // - `tokenTypeIds`: Optional segment id of shape (batch size, sequence length). 80 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 81 | // - `positionIds`: Optional position ids of shape (batch size, sequence length). 82 | // If None, will be incremented from 0. 83 | // - `inputEmbeds`: Optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 84 | // If None, input ids must be provided (see `inputIds`) 85 | // - `train`: boolean flag to turn on/off the dropout layers in the model. 86 | // Should be set to false for inference. 
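// Note: exactly one of `inputIds` or `inputEmbeds` should be provided; if
// neither is set the method returns an error (see the checks in the body below).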
87 | // 88 | // Return: 89 | // - `embeddedOutput`: tensor of shape (batch size, sequence length, hidden size) 90 | func (re *RobertaEmbeddings) ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (*ts.Tensor, error) { 91 | 92 | var ( 93 | inputEmbeddings *ts.Tensor 94 | inputShape []int64 95 | ) 96 | 97 | if !inputIds.MustDefined() { 98 | if inputEmbeds.MustDefined() { 99 | inputEmbeddings = inputEmbeds 100 | inputEmbedsShape := inputEmbeds.MustSize() 101 | inputShape = []int64{inputEmbedsShape[0], inputEmbedsShape[1]} 102 | } else { 103 | err := fmt.Errorf("Only one of input Ids or input embeddings may be set.") 104 | return ts.None, err 105 | } 106 | } else { 107 | // if inputIds == inputEmbeds 108 | if util.Equal(inputIds, inputEmbeds) { 109 | err := fmt.Errorf("Only one of input Ids or input embeddings may be set.") 110 | return ts.None, err 111 | } else { 112 | inputEmbeddings = inputIds.ApplyT(re.wordEmbeddings, train) 113 | inputShape = inputIds.MustSize() 114 | } 115 | } 116 | 117 | var posIds *ts.Tensor 118 | if positionIds.MustDefined() { 119 | posIds = positionIds 120 | } else { 121 | if inputIds.MustDefined() { 122 | posIds = re.createPositionIdsFromInputIds(inputIds) 123 | } else { 124 | posIds = re.createPositionIdsFromEmbeddings(inputEmbeds) 125 | } 126 | } 127 | 128 | var tokTypeIds *ts.Tensor 129 | if tokenTypeIds.MustDefined() { 130 | tokTypeIds = tokenTypeIds 131 | } else { 132 | tokTypeIds = ts.MustZeros(inputShape, gotch.Int64, inputEmbeddings.MustDevice()) 133 | } 134 | 135 | positionEmbeddings := posIds.Apply(re.positionEmbeddings) 136 | tokenTypeEmbeddings := tokTypeIds.Apply(re.tokenTypeEmbeddings) 137 | 138 | add1 := inputEmbeddings.MustAdd(positionEmbeddings, true) 139 | newInputEmbeddings := add1.MustAdd(tokenTypeEmbeddings, true) 140 | inputEmbeddings.MustDrop() 141 | 142 | appliedLN := newInputEmbeddings.Apply(re.layerNorm) 143 | appliedDropout := appliedLN.ApplyT(re.dropout, train) 144 | appliedLN.MustDrop() 145 | 146 | return appliedDropout, nil 147 | } 148 | -------------------------------------------------------------------------------- /pipeline/common.go: -------------------------------------------------------------------------------- 1 | package pipeline 2 | 3 | import ( 4 | "log" 5 | "reflect" 6 | 7 | "github.com/sugarme/tokenizer" 8 | "github.com/sugarme/tokenizer/model/wordpiece" 9 | "github.com/sugarme/tokenizer/normalizer" 10 | "github.com/sugarme/tokenizer/pretokenizer" 11 | "github.com/sugarme/tokenizer/processor" 12 | "github.com/sugarme/transformer/bert" 13 | ) 14 | 15 | // Common blocks for generic pipelines (e.g. token classification or sequence classification) 16 | // Provides Enums holding configuration or tokenization resources that can be used to create 17 | // generic pipelines. The model component is defined in the generic pipeline itself as the 18 | // pre-processing, forward pass and postprocessing differs between pipelines while basic config and 19 | // tokenization objects don't. 
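// For illustration, a hypothetical wiring of these building blocks (the file
// paths are placeholders, and the model forward pass itself belongs to the
// concrete pipeline, e.g. token classification):
//
//	configOpt := ConfigOptionFromFile(Bert, "/path/to/config.json")
//	tkOpt := TokenizerOptionFromFile(Bert, "/path/to/vocab.txt")
//	labels := configOpt.GetLabelMapping()
//	encodings, err := tkOpt.EncodeList([]string{"My name is Amy. I live in Paris."})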
20 | 21 | // ModelType is a enum-like, identifying the type of model 22 | type ModelType int 23 | 24 | const ( 25 | Bert ModelType = iota 26 | DistilBert 27 | Roberta 28 | XLMRoberta 29 | Electra 30 | Marian 31 | T5 32 | Albert 33 | ) 34 | 35 | type ModelOption struct { 36 | model ModelType 37 | } 38 | 39 | type Config interface{} 40 | 41 | // ConfigOption holds a model configuration 42 | type ConfigOption struct { 43 | model ModelType 44 | config Config 45 | } 46 | 47 | func NewBertConfigOption(config bert.BertConfig) *ConfigOption { 48 | return &ConfigOption{ 49 | model: Bert, 50 | config: config, 51 | } 52 | } 53 | 54 | type TokenizerType int 55 | 56 | const ( 57 | BertTokenizer TokenizerType = iota 58 | RobertaTokenizer 59 | XLMRobertaTokenizer 60 | MarianTokenizer 61 | T5Tokenizer 62 | AlbertTokenizer 63 | ) 64 | 65 | // TokenizerOption specifies a tokenizer 66 | type TokenizerOption struct { 67 | model ModelType 68 | tokenizer *tokenizer.Tokenizer 69 | } 70 | 71 | // ConfigOption methods: 72 | // ===================== 73 | 74 | // ConfigOptionFromFile loads configuration for corresponding model type from file. 75 | func ConfigOptionFromFile(modelType ModelType, path string) *ConfigOption { 76 | 77 | var configOpt *ConfigOption 78 | 79 | switch reflect.TypeOf(modelType).Kind().String() { 80 | case "Bert": 81 | config := bert.ConfigFromFile(path) 82 | configOpt = &ConfigOption{ 83 | model: Bert, 84 | config: config, 85 | } 86 | 87 | // TODO: implement others 88 | // case "DistilBert": 89 | default: 90 | log.Fatalf("Invalid modelType: '%v'\n", reflect.TypeOf(modelType).Kind().String()) 91 | } 92 | 93 | return configOpt 94 | } 95 | 96 | // GetLabelMap returns label mapping for corresponding model type. 97 | func (co *ConfigOption) GetLabelMapping() map[int64]string { 98 | 99 | var labelMap map[int64]string = make(map[int64]string) 100 | 101 | modelTypeStr := reflect.TypeOf(co.model).Kind().String() 102 | switch modelTypeStr { 103 | case "Bert": 104 | labelMap = co.config.(bert.BertConfig).Id2Label 105 | 106 | // TODO: implement others 107 | default: 108 | log.Fatalf("ConfigOption GetLabelMapping error: invalid model type ('%v')\n", modelTypeStr) 109 | } 110 | 111 | return labelMap 112 | } 113 | 114 | // TOkenizerOptionFromFile loads TokenizerOption from file corresponding to model type. 
115 | func TokenizerOptionFromFile(modelType ModelType, path string) *TokenizerOption { 116 | modelTypeStr := reflect.TypeOf(modelType).Kind().String() 117 | 118 | var tk *TokenizerOption 119 | switch modelTypeStr { 120 | case "Bert": 121 | tk = &TokenizerOption{ 122 | model: modelType, 123 | tokenizer: getBert(path), 124 | } 125 | 126 | // TODO: implement others 127 | 128 | default: 129 | log.Fatalf("Unsupported model type: '%v'", modelTypeStr) 130 | } 131 | 132 | return tk 133 | } 134 | 135 | func getBert(path string) (retVal *tokenizer.Tokenizer) { 136 | model, err := wordpiece.NewWordPieceFromFile(path, "[UNK]") 137 | if err != nil { 138 | log.Fatal(err) 139 | } 140 | 141 | tk := tokenizer.NewTokenizer(model) 142 | 143 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 144 | tk.WithNormalizer(bertNormalizer) 145 | 146 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer() 147 | tk.WithPreTokenizer(bertPreTokenizer) 148 | 149 | var specialTokens []tokenizer.AddedToken 150 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true)) 151 | 152 | tk.AddSpecialTokens(specialTokens) 153 | 154 | sepId, ok := tk.TokenToId("[SEP]") 155 | if !ok { 156 | log.Fatalf("Cannot find ID for [SEP] token.\n") 157 | } 158 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"} 159 | 160 | clsId, ok := tk.TokenToId("[CLS]") 161 | if !ok { 162 | log.Fatalf("Cannot find ID for [CLS] token.\n") 163 | } 164 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"} 165 | 166 | postProcess := processor.NewBertProcessing(sep, cls) 167 | tk.WithPostProcessor(postProcess) 168 | 169 | return tk 170 | } 171 | 172 | // ModelType returns chosen model type 173 | func (tk *TokenizerOption) ModelType() ModelType { 174 | return tk.model 175 | } 176 | 177 | // EncodeList encodes a slice of input string 178 | func (tk *TokenizerOption) EncodeList(sentences []string) ([]tokenizer.Encoding, error) { 179 | var input []tokenizer.EncodeInput 180 | for _, sentence := range sentences { 181 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence))) 182 | } 183 | 184 | return tk.tokenizer.EncodeBatch(input, true) 185 | } 186 | 187 | // Tokenize tokenizes input string 188 | func (tk *TokenizerOption) Tokenize(sentence string) ([]string, error) { 189 | 190 | input := tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence)) 191 | 192 | encoding, err := tk.tokenizer.Encode(input, true) 193 | if err != nil { 194 | return nil, err 195 | } 196 | 197 | return encoding.Tokens, nil 198 | } 199 | 200 | // AddSpecialTokens adds special tokens to tokenizer 201 | func (tk *TokenizerOption) AddSpecialTokens(tokens []string) { 202 | 203 | var addedToks []tokenizer.AddedToken 204 | for _, tok := range tokens { 205 | addedToks = append(addedToks, tokenizer.NewAddedToken(tok, true)) 206 | } 207 | 208 | tk.tokenizer.AddSpecialTokens(addedToks) 209 | } 210 | 211 | // TokensToIds converts a slice of tokens to corresponding Ids. 212 | func (tk *TokenizerOption) TokensToIds(tokens []string) (ids []int64, ok bool) { 213 | for _, tok := range tokens { 214 | id, ok := tk.tokenizer.TokenToId(tok) 215 | if !ok { 216 | return nil, false 217 | } 218 | ids = append(ids, int64(id)) 219 | } 220 | 221 | return ids, true 222 | } 223 | 224 | // PadId returns a PAD id if any. 
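// For example (an illustrative sketch, not code from this repo): when padding
// a batch of encodings manually, PadId can supply the fill value:
//
//	padId, ok := tkOpt.PadId()
//	if !ok {
//		padId = 0 // fall back to BERT's [PAD] id when no padding strategy is set
//	}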
225 | func (tk *TokenizerOption) PadId() (id int64, ok bool) { 226 | paddingParam := tk.tokenizer.GetPadding() 227 | if paddingParam == nil { 228 | return -1, false 229 | } 230 | 231 | return int64(paddingParam.PadId), true 232 | } 233 | 234 | // SepId returns a SEP id if any. 235 | // If optional sepOpt is not specify, default value is "[SEP]" 236 | func (tk *TokenizerOption) SepId(sepOpt ...string) (id int64, ok bool) { 237 | 238 | sep := "[SEP]" // default sep string 239 | if len(sepOpt) > 0 { 240 | sep = sepOpt[0] 241 | } 242 | 243 | i, ok := tk.tokenizer.TokenToId(sep) 244 | return int64(i), ok 245 | } 246 | -------------------------------------------------------------------------------- /example/bert/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/sugarme/gotch" 8 | "github.com/sugarme/gotch/nn" 9 | "github.com/sugarme/gotch/pickle" 10 | "github.com/sugarme/gotch/ts" 11 | "github.com/sugarme/tokenizer" 12 | "github.com/sugarme/tokenizer/model/wordpiece" 13 | "github.com/sugarme/tokenizer/normalizer" 14 | "github.com/sugarme/tokenizer/pretokenizer" 15 | "github.com/sugarme/tokenizer/processor" 16 | "github.com/sugarme/transformer/bert" 17 | "github.com/sugarme/transformer/util" 18 | ) 19 | 20 | func main() { 21 | bertForMaskedLM() 22 | // bertForSequenceClassification() 23 | } 24 | 25 | func getBert() (retVal *tokenizer.Tokenizer) { 26 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 27 | if err != nil { 28 | panic(err) 29 | } 30 | 31 | model, err := wordpiece.NewWordPieceFromFile(vocabFile, "[UNK]") 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | 36 | tk := tokenizer.NewTokenizer(model) 37 | fmt.Printf("Vocab size: %v\n", tk.GetVocabSize(false)) 38 | 39 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 40 | tk.WithNormalizer(bertNormalizer) 41 | 42 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer() 43 | tk.WithPreTokenizer(bertPreTokenizer) 44 | 45 | var specialTokens []tokenizer.AddedToken 46 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true)) 47 | 48 | tk.AddSpecialTokens(specialTokens) 49 | 50 | sepId, ok := tk.TokenToId("[SEP]") 51 | if !ok { 52 | log.Fatalf("Cannot find ID for [SEP] token.\n") 53 | } 54 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"} 55 | 56 | clsId, ok := tk.TokenToId("[CLS]") 57 | if !ok { 58 | log.Fatalf("Cannot find ID for [CLS] token.\n") 59 | } 60 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"} 61 | 62 | postProcess := processor.NewBertProcessing(sep, cls) 63 | tk.WithPostProcessor(postProcess) 64 | 65 | return tk 66 | } 67 | 68 | func bertForMaskedLM() { 69 | device := gotch.CPU 70 | vs := nn.NewVarStore(device) 71 | 72 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 73 | if err != nil { 74 | panic(err) 75 | } 76 | config, err := bert.ConfigFromFile(configFile) 77 | if err != nil { 78 | panic(err) 79 | } 80 | // fmt.Printf("Bert Configuration:\n%+v\n", config) 81 | 82 | model, err := bert.NewBertForMaskedLM(vs.Root(), config) 83 | if err != nil { 84 | panic(err) 85 | } 86 | 87 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin") 88 | if err != nil { 89 | panic(err) 90 | } 91 | err = pickle.LoadAll(vs, modelFile) 92 | if err != nil { 93 | log.Fatalf("Load model weight error: \n%v", err) 94 | } 95 | 96 | // fmt.Printf("Varstore weights have been loaded\n") 97 | // fmt.Printf("Num of variables: 
%v\n", len(vs.Variables())) 98 | // fmt.Printf("%v\n", vs.Variables()) 99 | // fmt.Printf("Bert is Decoder: %v\n", model.bert.IsDecoder) 100 | 101 | tk := getBert() 102 | sentence1 := "Looks like one [MASK] is missing" 103 | sentence2 := "It was a very nice and [MASK] day" 104 | 105 | var input []tokenizer.EncodeInput 106 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 107 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 108 | 109 | encodings, err := tk.EncodeBatch(input, true) 110 | if err != nil { 111 | log.Fatal(err) 112 | } 113 | 114 | // Find max length of token Ids from slice of encodings 115 | var maxLen int = 0 116 | for _, en := range encodings { 117 | if len(en.Ids) > maxLen { 118 | maxLen = len(en.Ids) 119 | } 120 | } 121 | 122 | var tensors []ts.Tensor 123 | for _, en := range encodings { 124 | var tokInput []int64 = make([]int64, maxLen) 125 | for i := 0; i < len(en.Ids); i++ { 126 | tokInput[i] = int64(en.Ids[i]) 127 | } 128 | 129 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 130 | } 131 | 132 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 133 | // inputTensor.Print() 134 | 135 | var output *ts.Tensor 136 | ts.NoGrad(func() { 137 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false) 138 | }) 139 | 140 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0] 141 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0] 142 | 143 | word1, ok := tk.IdToToken(int(index1)) 144 | if !ok { 145 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1) 146 | } 147 | fmt.Printf("Input: '%v' \t- Output: '%v'\n", sentence1, word1) 148 | 149 | word2, ok := tk.IdToToken(int(index2)) 150 | if !ok { 151 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2) 152 | } 153 | fmt.Printf("Input: '%v' \t- Output: '%v'\n", sentence2, word2) 154 | } 155 | 156 | func bertForSequenceClassification() { 157 | device := gotch.CPU 158 | vs := nn.NewVarStore(device) 159 | 160 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 161 | if err != nil { 162 | panic(err) 163 | } 164 | config, err := bert.ConfigFromFile(configFile) 165 | if err != nil { 166 | panic(err) 167 | } 168 | 169 | var dummyLabelMap map[int64]string = make(map[int64]string) 170 | dummyLabelMap[0] = "positive" 171 | dummyLabelMap[1] = "negative" 172 | dummyLabelMap[3] = "neutral" 173 | 174 | config.Id2Label = dummyLabelMap 175 | config.OutputAttentions = true 176 | config.OutputHiddenStates = true 177 | model := bert.NewBertForSequenceClassification(vs.Root(), config) 178 | tk := getBert() 179 | 180 | // Define input 181 | sentence1 := "Looks like one thing is missing" 182 | sentence2 := `It's like comparing oranges to apples` 183 | var input []tokenizer.EncodeInput 184 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 185 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 186 | encodings, err := tk.EncodeBatch(input, true) 187 | if err != nil { 188 | log.Fatal(err) 189 | } 190 | 191 | // Find max length of token Ids from slice of encodings 192 | var maxLen int = 0 193 | for _, en := range encodings { 194 | if len(en.Ids) > maxLen { 195 | maxLen = len(en.Ids) 196 | } 197 | } 198 | 199 | fmt.Printf("encodings: %v\n", 
encodings) 200 | var tensors []ts.Tensor 201 | for _, en := range encodings { 202 | var tokInput []int64 = make([]int64, maxLen) 203 | for i := 0; i < len(en.Ids); i++ { 204 | tokInput[i] = int64(en.Ids[i]) 205 | } 206 | 207 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 208 | } 209 | 210 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 211 | // inputTensor.Print() 212 | 213 | var ( 214 | output *ts.Tensor 215 | allHiddenStates, allAttentions []ts.Tensor 216 | ) 217 | 218 | ts.NoGrad(func() { 219 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 220 | }) 221 | 222 | fmt.Printf("output size: %v\n", output.MustSize()) 223 | 224 | fmt.Printf("NumHiddenLayers: %v\n", config.NumHiddenLayers) 225 | fmt.Printf("allHiddenStates length: %v\n", len(allHiddenStates)) 226 | fmt.Printf("allAttentions length: %v\n", len(allAttentions)) 227 | 228 | } 229 | -------------------------------------------------------------------------------- /util/file-util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net/http" 8 | "os" 9 | "path" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | // This file provides functions to work with local dataset cache, ... 15 | 16 | const ( 17 | WeightName = "pytorch_model.gt" 18 | ConfigName = "config.json" 19 | 20 | // NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName` 21 | HFpath = "https://huggingface.co" 22 | ) 23 | 24 | var ( 25 | DUMMY_INPUT [][]int64 = [][]int64{ 26 | {7, 6, 0, 0, 1}, 27 | {1, 2, 3, 0, 0}, 28 | {0, 0, 0, 4, 5}, 29 | } 30 | ) 31 | 32 | // CachedPath resolves and caches data based on input string, then returns fullpath to the cached data. 33 | // 34 | // Parameters: 35 | // - `modelNameOrPath`: model name e.g., "bert-base-uncased" or path to directory contains model/config files. 36 | // - `fileName`: model or config file name. E.g., "pytorch_model.py", "config.json" 37 | // 38 | // CachedPath does several things consequently: 39 | // 1. Resolves input string to a fullpath cached filename candidate. 40 | // 2. Check it at `CachedPath`, if exists, then return the candidate. If not 41 | // 3. Retrieves and Caches data to `CachedPath` and returns path to cached data 42 | // 43 | // NOTE. default `CachedDir` is at "{$HOME}/.cache/transformer" 44 | // Custom `CachedDir` can be changed by setting with environment `GO_TRANSFORMER` 45 | func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error) { 46 | 47 | // Resolves to "candidate" filename at `CacheDir` 48 | cachedFileCandidate := fmt.Sprintf("%s/%s/%s", CachedDir, modelNameOrPath, fileName) 49 | 50 | // 1. Cached candidate file exists 51 | if _, err := os.Stat(cachedFileCandidate); err == nil { 52 | return cachedFileCandidate, nil 53 | } 54 | 55 | // 2. If valid fullpath to local file, caches it and return cached filename 56 | filepath := fmt.Sprintf("%s/%s", modelNameOrPath, fileName) 57 | if _, err := os.Stat(filepath); err == nil { 58 | err := copyFile(filepath, cachedFileCandidate) 59 | if err != nil { 60 | err := fmt.Errorf("CachedPath() failed at copying file: %w", err) 61 | return "", err 62 | } 63 | return cachedFileCandidate, nil 64 | } 65 | 66 | // 3. Cached candidate file NOT exist. 
Try to download it and save to `CachedDir` 67 | url := fmt.Sprintf("%s/%s/resolve/main/%s", HFpath, modelNameOrPath, fileName) 68 | // url := fmt.Sprintf("%s/%s/raw/main/%s", HFpath, modelNameOrPath, fileName) 69 | if isValidURL(url) { 70 | if _, err := http.Get(url); err == nil { 71 | err := downloadFile(url, cachedFileCandidate) 72 | if err != nil { 73 | err = fmt.Errorf("CachedPath() failed at trying to download file: %w", err) 74 | return "", err 75 | } 76 | 77 | return cachedFileCandidate, nil 78 | } else { 79 | err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url) 80 | return "", err 81 | } 82 | } 83 | 84 | // Not resolves 85 | err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url) 86 | return "", err 87 | } 88 | 89 | func isValidURL(url string) bool { 90 | 91 | // TODO: implement 92 | return true 93 | } 94 | 95 | // downloadFile downloads file from URL and stores it in local filepath. 96 | // It writes to the destination file as it downloads it, without loading 97 | // the entire file into memory. An `io.TeeReader` is passed into Copy() 98 | // to report progress on the download. 99 | func downloadFile(url string, filepath string) error { 100 | // Create path if not existing 101 | dir := path.Dir(filepath) 102 | filename := path.Base(filepath) 103 | if _, err := os.Stat(dir); os.IsNotExist(err) { 104 | if err := os.MkdirAll(dir, 0755); err != nil { 105 | log.Fatal(err) 106 | } 107 | } 108 | 109 | // Create the file with .tmp extension, so that we won't overwrite a 110 | // file until it's downloaded fully 111 | out, err := os.Create(filepath + ".tmp") 112 | if err != nil { 113 | return err 114 | } 115 | defer out.Close() 116 | 117 | // Get the data 118 | resp, err := http.Get(url) 119 | if err != nil { 120 | return err 121 | } 122 | defer resp.Body.Close() 123 | 124 | // Check server response 125 | if resp.StatusCode != http.StatusOK { 126 | err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode) 127 | 128 | if resp.StatusCode == 404 { 129 | // if filename == "rust_model.ot" { 130 | // msg := fmt.Sprintf("model weight file not found. That means a compatible pretrained model weight file for Go is not available.\n") 131 | // msg = msg + fmt.Sprintf("You might need to manually convert a 'pytorch_model.bin' for Go. ") 132 | // msg = msg + fmt.Sprintf("See tutorial at: 'example/convert'") 133 | // err = fmt.Errorf(msg) 134 | // } else { 135 | // err = fmt.Errorf("download file not found: %q for downloading", url) 136 | // } 137 | err = fmt.Errorf("download file not found: %q for downloading", url) 138 | } else { 139 | err = fmt.Errorf("download file failed: %q", url) 140 | } 141 | return err 142 | } 143 | 144 | // the total file size to download 145 | size, _ := strconv.Atoi(resp.Header.Get("Content-Length")) 146 | downloadSize := uint64(size) 147 | 148 | // Create our bytes counter and pass it to be used alongside our writer 149 | counter := &writeCounter{FileSize: downloadSize} 150 | _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) 151 | if err != nil { 152 | return err 153 | } 154 | 155 | fmt.Printf("\r%s... 
%s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize)) 156 | // The progress use the same line so print a new line once it's finished downloading 157 | fmt.Println() 158 | 159 | // Rename the tmp file back to the original file 160 | err = os.Rename(filepath+".tmp", filepath) 161 | if err != nil { 162 | return err 163 | } 164 | 165 | return nil 166 | } 167 | 168 | // writeCounter counts the number of bytes written to it. By implementing the Write method, 169 | // it is of the io.Writer interface and we can pass this into io.TeeReader() 170 | // Every write to this writer, will print the progress of the file write. 171 | type writeCounter struct { 172 | Total uint64 173 | FileSize uint64 174 | } 175 | 176 | func (wc *writeCounter) Write(p []byte) (int, error) { 177 | n := len(p) 178 | wc.Total += uint64(n) 179 | wc.printProgress() 180 | return n, nil 181 | } 182 | 183 | // PrintProgress prints the progress of a file write 184 | func (wc writeCounter) printProgress() { 185 | // Clear the line by using a character return to go back to the start and remove 186 | // the remaining characters by filling it with spaces 187 | fmt.Printf("\r%s", strings.Repeat(" ", 50)) 188 | 189 | // Return again and print current status of download 190 | fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize)) 191 | } 192 | 193 | // byteCountIEC converts bytes to human-readable string in binary (IEC) format. 194 | func byteCountIEC(b uint64) string { 195 | const unit = 1024 196 | if b < unit { 197 | return fmt.Sprintf("%d B", b) 198 | } 199 | div, exp := uint64(unit), 0 200 | for n := b / unit; n >= unit; n /= unit { 201 | div *= unit 202 | exp++ 203 | } 204 | return fmt.Sprintf("%.1f %ciB", 205 | float64(b)/float64(div), "KMGTPE"[exp]) 206 | } 207 | 208 | func copyFile(src, dst string) error { 209 | sourceFileStat, err := os.Stat(src) 210 | if err != nil { 211 | return err 212 | } 213 | 214 | if !sourceFileStat.Mode().IsRegular() { 215 | return fmt.Errorf("%s is not a regular file", src) 216 | } 217 | 218 | source, err := os.Open(src) 219 | if err != nil { 220 | return err 221 | } 222 | defer source.Close() 223 | 224 | destination, err := os.Create(dst) 225 | if err != nil { 226 | return err 227 | } 228 | defer destination.Close() 229 | _, err = io.Copy(destination, source) 230 | return err 231 | } 232 | 233 | // CleanCache removes all files cached in transformer cache directory `CachedDir`. 234 | // 235 | // NOTE. 
custom `CachedDir` can be changed by setting environment `GO_TRANSFORMER` 236 | func CleanCache() error { 237 | err := os.RemoveAll(CachedDir) 238 | if err != nil { 239 | err = fmt.Errorf("CleanCache() failed: %w", err) 240 | return err 241 | } 242 | 243 | return nil 244 | } 245 | -------------------------------------------------------------------------------- /bert/attention.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | // "fmt" 5 | "log" 6 | "math" 7 | 8 | "github.com/sugarme/gotch" 9 | "github.com/sugarme/gotch/nn" 10 | "github.com/sugarme/gotch/ts" 11 | 12 | "github.com/sugarme/transformer/util" 13 | ) 14 | 15 | // BertSelfAttention: 16 | //=================== 17 | 18 | type BertSelfAttention struct { 19 | NumAttentionHeads int64 20 | AttentionHeadSize int64 21 | Dropout *util.Dropout 22 | OutputAttentions bool 23 | Query *nn.Linear 24 | Key *nn.Linear 25 | Value *nn.Linear 26 | } 27 | 28 | // NewBertSelfAttention creates a new `BertSelfAttention` 29 | func NewBertSelfAttention(p *nn.Path, config *BertConfig) *BertSelfAttention { 30 | if config.HiddenSize%config.NumAttentionHeads != 0 { 31 | log.Fatal("Hidden size is not a multiple of the number of attention heads.") 32 | } 33 | 34 | lconfig := nn.DefaultLinearConfig() 35 | query := nn.NewLinear(p.Sub("query"), config.HiddenSize, config.HiddenSize, lconfig) 36 | key := nn.NewLinear(p.Sub("key"), config.HiddenSize, config.HiddenSize, lconfig) 37 | value := nn.NewLinear(p.Sub("value"), config.HiddenSize, config.HiddenSize, lconfig) 38 | 39 | dropout := util.NewDropout(config.AttentionProbsDropoutProb) 40 | attentionHeadSize := int64(config.HiddenSize) / config.NumAttentionHeads 41 | outputAttentions := config.OutputAttentions 42 | 43 | return &BertSelfAttention{ 44 | NumAttentionHeads: config.NumAttentionHeads, 45 | AttentionHeadSize: attentionHeadSize, 46 | Dropout: dropout, 47 | OutputAttentions: outputAttentions, 48 | Query: query, 49 | Key: key, 50 | Value: value, 51 | } 52 | 53 | } 54 | 55 | func (bsa *BertSelfAttention) splitHeads(x *ts.Tensor, bs, dimPerHead int64) (retVal *ts.Tensor) { 56 | 57 | xview := x.MustView([]int64{bs, -1, bsa.NumAttentionHeads, dimPerHead}, false) 58 | 59 | return xview.MustTranspose(1, 2, true) 60 | } 61 | 62 | func (bsa *BertSelfAttention) flatten(x *ts.Tensor, bs, dimPerHead int64) (retVal *ts.Tensor) { 63 | 64 | xT := x.MustTranspose(1, 2, false) 65 | xCon := xT.MustContiguous(true) 66 | retVal = xCon.MustView([]int64{bs, -1, bsa.NumAttentionHeads * dimPerHead}, true) 67 | 68 | return retVal 69 | } 70 | 71 | // ForwardT implements ModuleT interface for BertSelfAttention 72 | // 73 | // NOTE. mask, encoderHiddenStates, encoderMask are optional tensors 74 | // for `None` value, `ts.None` can be used. 
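// For illustration (an assumed call, with `hidden` standing in for the layer
// input of shape (batchSize, seqLen, hiddenSize)):
//
//	context, weights := bsa.ForwardT(hidden, ts.None, ts.None, ts.None, false)
//	// `context` has shape (batchSize, seqLen, hiddenSize); `weights` is
//	// ts.None unless config.OutputAttentions is set.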
75 | func (bsa *BertSelfAttention) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, retValOpt *ts.Tensor) { 76 | 77 | key := bsa.Key.Forward(hiddenStates) 78 | value := bsa.Value.Forward(hiddenStates) 79 | 80 | if encoderHiddenStates.MustDefined() { 81 | key = bsa.Key.Forward(encoderHiddenStates) 82 | value = bsa.Value.Forward(encoderHiddenStates) 83 | } 84 | 85 | bs := hiddenStates.MustSize()[0] 86 | 87 | hiddenStatesQ := hiddenStates.Apply(bsa.Query) 88 | query := bsa.splitHeads(hiddenStatesQ, bs, bsa.AttentionHeadSize) 89 | 90 | hiddenStatesQ.MustDrop() 91 | keyLayer := bsa.splitHeads(key, bs, bsa.AttentionHeadSize) 92 | key.MustDrop() 93 | valueLayer := bsa.splitHeads(value, bs, bsa.AttentionHeadSize) 94 | value.MustDrop() 95 | 96 | size := math.Sqrt(float64(bsa.AttentionHeadSize)) 97 | queryLayer := query.MustDivScalar(ts.FloatScalar(size), true) 98 | 99 | // Calculate score 100 | var scores *ts.Tensor 101 | if mask.MustDefined() { 102 | keyLayerT := keyLayer.MustTranspose(-1, -2, true) 103 | keyLayerT.MustAdd_(mask) 104 | scores = queryLayer.MustMatmul(keyLayerT, true) 105 | } else { 106 | keyLayerT := keyLayer.MustTranspose(-1, -2, true) 107 | scores = queryLayer.MustMatmul(keyLayerT, true) 108 | } 109 | 110 | weights := scores.MustSoftmax(-1, gotch.Float, true).ApplyT(bsa.Dropout, train) 111 | 112 | weightsMul := weights.MustMatmul(valueLayer, false) 113 | 114 | context := bsa.flatten(weightsMul, bs, bsa.AttentionHeadSize) 115 | weightsMul.MustDrop() 116 | 117 | if !bsa.OutputAttentions { 118 | weights.MustDrop() 119 | return context, ts.None 120 | } else { 121 | return context, weights 122 | } 123 | 124 | } 125 | 126 | // BertSelfOutput: 127 | //================ 128 | 129 | type BertSelfOutput struct { 130 | Linear *nn.Linear 131 | LayerNorm *nn.LayerNorm 132 | Dropout *util.Dropout 133 | } 134 | 135 | func NewBertSelfOutput(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertSelfOutput { 136 | changeName := true 137 | if len(changeNameOpt) > 0 { 138 | changeName = changeNameOpt[0] 139 | } 140 | 141 | path := p.Sub("dense") 142 | lconfig := nn.DefaultLinearConfig() 143 | linear := nn.NewLinear(path, config.HiddenSize, config.HiddenSize, lconfig) 144 | 145 | layerNormConfig := nn.DefaultLayerNormConfig() 146 | if changeName { 147 | layerNormConfig.WsName = "gamma" 148 | layerNormConfig.BsName = "beta" 149 | } 150 | layerNormConfig.Eps = 1e-12 151 | 152 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig) 153 | dropout := util.NewDropout(config.HiddenDropoutProb) 154 | 155 | return &BertSelfOutput{linear, layerNorm, dropout} 156 | } 157 | 158 | func (bso *BertSelfOutput) ForwardT(hiddenStates *ts.Tensor, inputTensor *ts.Tensor, train bool) (retVal *ts.Tensor) { 159 | 160 | state1 := hiddenStates.Apply(bso.Linear) 161 | state2 := state1.ApplyT(bso.Dropout, train) 162 | state3 := inputTensor.MustAdd(state2, false) 163 | 164 | retVal = state3.Apply(bso.LayerNorm) 165 | state1.MustDrop() 166 | state2.MustDrop() 167 | state3.MustDrop() 168 | 169 | return retVal 170 | } 171 | 172 | // BertAttention: 173 | //=============== 174 | 175 | type BertAttention struct { 176 | Bsa *BertSelfAttention 177 | Output *BertSelfOutput 178 | } 179 | 180 | func NewBertAttention(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertAttention { 181 | changeName := true 182 | if len(changeNameOpt) > 0 { 183 | changeName = changeNameOpt[0] 184 | } 185 | self := NewBertSelfAttention(p.Sub("self"), config) 186 
| output := NewBertSelfOutput(p.Sub("output"), config, changeName) 187 | 188 | return &BertAttention{self, output} 189 | } 190 | 191 | func (ba *BertAttention) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, RetValOpt *ts.Tensor) { 192 | 193 | selfOutput, attentionWeights := ba.Bsa.ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask, train) 194 | selfOutput = ba.Output.ForwardT(selfOutput, hiddenStates, train) 195 | 196 | return selfOutput, attentionWeights 197 | } 198 | 199 | // BertIntermedate: 200 | //================= 201 | 202 | type BertIntermediate struct { 203 | Lin *nn.Linear 204 | Activation util.ActivationFn // interface 205 | } 206 | 207 | func NewBertIntermediate(p *nn.Path, config *BertConfig) *BertIntermediate { 208 | lconfig := nn.DefaultLinearConfig() 209 | lin := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.IntermediateSize, lconfig) 210 | 211 | actFn, ok := util.ActivationFnMap[config.HiddenAct] 212 | if !ok { 213 | log.Fatalf("Unsupported activation function - %v\n", config.HiddenAct) 214 | } 215 | 216 | return &BertIntermediate{lin, actFn} 217 | } 218 | 219 | func (bi *BertIntermediate) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) { 220 | 221 | states := hiddenStates.Apply(bi.Lin) 222 | 223 | retVal = bi.Activation.Fwd(states) 224 | states.MustDrop() 225 | 226 | return retVal 227 | } 228 | 229 | // BertOutput: 230 | //============ 231 | 232 | type BertOutput struct { 233 | Lin *nn.Linear 234 | LayerNorm *nn.LayerNorm 235 | Dropout *util.Dropout 236 | } 237 | 238 | func NewBertOutput(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertOutput { 239 | changeName := true 240 | if len(changeNameOpt) > 0 { 241 | changeName = changeNameOpt[0] 242 | } 243 | 244 | lconfig := nn.DefaultLinearConfig() 245 | lin := nn.NewLinear(p.Sub("dense"), config.IntermediateSize, config.HiddenSize, lconfig) 246 | 247 | layerNormConfig := nn.DefaultLayerNormConfig() 248 | if changeName { 249 | layerNormConfig.WsName = "gamma" 250 | layerNormConfig.BsName = "beta" 251 | } 252 | layerNormConfig.Eps = 1e-12 253 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig) 254 | 255 | dropout := util.NewDropout(config.HiddenDropoutProb) 256 | 257 | return &BertOutput{lin, layerNorm, dropout} 258 | } 259 | 260 | func (bo *BertOutput) ForwardT(hiddenStates, inputTensor *ts.Tensor, train bool) (retVal *ts.Tensor) { 261 | 262 | state1 := hiddenStates.Apply(bo.Lin) 263 | state2 := state1.ApplyT(bo.Dropout, train) 264 | state3 := inputTensor.MustAdd(state2, false) 265 | 266 | retVal = state3.Apply(bo.LayerNorm) 267 | 268 | state1.MustDrop() 269 | state2.MustDrop() 270 | state3.MustDrop() 271 | 272 | return retVal 273 | } 274 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2020 Thang Tran. 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
191 | -------------------------------------------------------------------------------- /bert/model_test.go: -------------------------------------------------------------------------------- 1 | package bert_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/sugarme/gotch" 10 | "github.com/sugarme/gotch/nn" 11 | "github.com/sugarme/gotch/pickle" 12 | "github.com/sugarme/gotch/ts" 13 | "github.com/sugarme/tokenizer" 14 | "github.com/sugarme/tokenizer/model/wordpiece" 15 | "github.com/sugarme/tokenizer/normalizer" 16 | "github.com/sugarme/tokenizer/pretokenizer" 17 | "github.com/sugarme/tokenizer/processor" 18 | 19 | "github.com/sugarme/transformer/bert" 20 | "github.com/sugarme/transformer/util" 21 | ) 22 | 23 | func getBertTokenizer(vocabFile string) (retVal *tokenizer.Tokenizer) { 24 | model, err := wordpiece.NewWordPieceFromFile(vocabFile, "[UNK]") 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | 29 | tk := tokenizer.NewTokenizer(model) 30 | 31 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 32 | tk.WithNormalizer(bertNormalizer) 33 | 34 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer() 35 | tk.WithPreTokenizer(bertPreTokenizer) 36 | 37 | var specialTokens []tokenizer.AddedToken 38 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true)) 39 | 40 | tk.AddSpecialTokens(specialTokens) 41 | 42 | sepId, ok := tk.TokenToId("[SEP]") 43 | if !ok { 44 | log.Fatalf("Cannot find ID for [SEP] token.\n") 45 | } 46 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"} 47 | 48 | clsId, ok := tk.TokenToId("[CLS]") 49 | if !ok { 50 | log.Fatalf("Cannot find ID for [CLS] token.\n") 51 | } 52 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"} 53 | 54 | postProcess := processor.NewBertProcessing(sep, cls) 55 | tk.WithPostProcessor(postProcess) 56 | 57 | return tk 58 | } 59 | 60 | func TestBertForMaskedLM(t *testing.T) { 61 | // Config 62 | config := new(bert.BertConfig) 63 | 64 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 65 | if err != nil { 66 | log.Fatal(err) 67 | } 68 | 69 | err = config.Load(configFile, nil) 70 | if err != nil { 71 | log.Fatal(err) 72 | } 73 | 74 | // Model 75 | device := gotch.CPU 76 | vs := nn.NewVarStore(device) 77 | 78 | model, err := bert.NewBertForMaskedLM(vs.Root(), config) 79 | if err != nil { 80 | log.Fatal(err) 81 | } 82 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin") 83 | if err != nil { 84 | log.Fatal(err) 85 | } 86 | 87 | err = pickle.LoadAll(vs, modelFile) 88 | if err != nil { 89 | t.Error(err) 90 | } 91 | 92 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 93 | tk := getBertTokenizer(vocabFile) 94 | sentence1 := "Looks like one [MASK] is missing" 95 | sentence2 := "It was a very nice and [MASK] day" 96 | 97 | var input []tokenizer.EncodeInput 98 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 99 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 100 | 101 | encodings, err := tk.EncodeBatch(input, true) 102 | if err != nil { 103 | log.Fatal(err) 104 | } 105 | 106 | var maxLen int = 0 107 | for _, en := range encodings { 108 | if len(en.Ids) > maxLen { 109 | maxLen = len(en.Ids) 110 | } 111 | } 112 | 113 | var tensors []ts.Tensor 114 | for _, en := range encodings { 115 | var tokInput []int64 = make([]int64, maxLen) 116 | for i := 0; i < len(en.Ids); i++ { 117 | tokInput[i] = int64(en.Ids[i]) 118 | } 119 
| 120 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 121 | } 122 | 123 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 124 | 125 | var output *ts.Tensor 126 | ts.NoGrad(func() { 127 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false) 128 | }) 129 | 130 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0] 131 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0] 132 | 133 | got1, ok := tk.IdToToken(int(index1)) 134 | if !ok { 135 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1) 136 | } 137 | want1 := "person" 138 | 139 | if !reflect.DeepEqual(want1, got1) { 140 | t.Errorf("Want: '%v'\n", want1) 141 | t.Errorf("Got '%v'\n", got1) 142 | } 143 | 144 | got2, ok := tk.IdToToken(int(index2)) 145 | if !ok { 146 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2) 147 | } 148 | want2 := "pleasant" 149 | 150 | if !reflect.DeepEqual(want2, got2) { 151 | t.Errorf("Want: '%v'\n", want2) 152 | t.Errorf("Got '%v'\n", got2) 153 | } 154 | } 155 | 156 | func TestBertForSequenceClassification(t *testing.T) { 157 | device := gotch.CPU 158 | vs := nn.NewVarStore(device) 159 | 160 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 161 | if err != nil { 162 | t.Error(err) 163 | } 164 | config, err := bert.ConfigFromFile(configFile) 165 | if err != nil { 166 | t.Error(err) 167 | } 168 | 169 | var dummyLabelMap map[int64]string = make(map[int64]string) 170 | dummyLabelMap[0] = "positive" 171 | dummyLabelMap[1] = "negative" 172 | dummyLabelMap[3] = "neutral" 173 | 174 | config.Id2Label = dummyLabelMap 175 | config.OutputAttentions = true 176 | config.OutputHiddenStates = true 177 | model := bert.NewBertForSequenceClassification(vs.Root(), config) 178 | 179 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 180 | tk := getBertTokenizer(vocabFile) 181 | 182 | // Define input 183 | sentence1 := "Looks like one thing is missing" 184 | sentence2 := `It's like comparing oranges to apples` 185 | var input []tokenizer.EncodeInput 186 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 187 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 188 | encodings, err := tk.EncodeBatch(input, true) 189 | if err != nil { 190 | log.Fatal(err) 191 | } 192 | 193 | // Find max length of token Ids from slice of encodings 194 | var maxLen int = 0 195 | for _, en := range encodings { 196 | if len(en.Ids) > maxLen { 197 | maxLen = len(en.Ids) 198 | } 199 | } 200 | 201 | var tensors []ts.Tensor 202 | for _, en := range encodings { 203 | var tokInput []int64 = make([]int64, maxLen) 204 | for i := 0; i < len(en.Ids); i++ { 205 | tokInput[i] = int64(en.Ids[i]) 206 | } 207 | 208 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 209 | } 210 | 211 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 212 | 213 | var ( 214 | output *ts.Tensor 215 | allHiddenStates, allAttentions []ts.Tensor 216 | ) 217 | 218 | ts.NoGrad(func() { 219 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 220 | }) 221 | 222 | fmt.Printf("output size: %v\n", output.MustSize()) 223 | gotOuputSize := output.MustSize() 224 | wantOuputSize := []int64{2, 3} 225 | if !reflect.DeepEqual(wantOuputSize, gotOuputSize) { 226 | t.Errorf("Want: 
%v\n", wantOuputSize) 227 | t.Errorf("Got: %v\n", gotOuputSize) 228 | } 229 | 230 | numHiddenLayers := int(config.NumHiddenLayers) 231 | 232 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) { 233 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers) 234 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates)) 235 | } 236 | 237 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) { 238 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers) 239 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions)) 240 | } 241 | } 242 | 243 | func TestBertForMultipleChoice(t *testing.T) { 244 | device := gotch.CPU 245 | vs := nn.NewVarStore(device) 246 | 247 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 248 | if err != nil { 249 | t.Error(err) 250 | } 251 | config, err := bert.ConfigFromFile(configFile) 252 | if err != nil { 253 | t.Error(err) 254 | } 255 | 256 | config.OutputAttentions = true 257 | config.OutputHiddenStates = true 258 | model := bert.NewBertForMultipleChoice(vs.Root(), config) 259 | 260 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 261 | tk := getBertTokenizer(vocabFile) 262 | 263 | // Define input 264 | sentence1 := "Looks like one thing is missing" 265 | sentence2 := `It's like comparing oranges to apples` 266 | var input []tokenizer.EncodeInput 267 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 268 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 269 | encodings, err := tk.EncodeBatch(input, true) 270 | if err != nil { 271 | log.Fatal(err) 272 | } 273 | 274 | // Find max length of token Ids from slice of encodings 275 | var maxLen int = 0 276 | for _, en := range encodings { 277 | if len(en.Ids) > maxLen { 278 | maxLen = len(en.Ids) 279 | } 280 | } 281 | 282 | var tensors []ts.Tensor 283 | for _, en := range encodings { 284 | var tokInput []int64 = make([]int64, maxLen) 285 | for i := 0; i < len(en.Ids); i++ { 286 | tokInput[i] = int64(en.Ids[i]) 287 | } 288 | 289 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 290 | } 291 | 292 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true).MustUnsqueeze(0, true) 293 | 294 | var ( 295 | output *ts.Tensor 296 | allHiddenStates, allAttentions []ts.Tensor 297 | ) 298 | 299 | ts.NoGrad(func() { 300 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, false) 301 | }) 302 | 303 | fmt.Printf("output size: %v\n", output.MustSize()) 304 | gotOuputSize := output.MustSize() 305 | wantOuputSize := []int64{1, 2} 306 | if !reflect.DeepEqual(wantOuputSize, gotOuputSize) { 307 | t.Errorf("Want: %v\n", wantOuputSize) 308 | t.Errorf("Got: %v\n", gotOuputSize) 309 | } 310 | 311 | numHiddenLayers := int(config.NumHiddenLayers) 312 | 313 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) { 314 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers) 315 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates)) 316 | } 317 | 318 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) { 319 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers) 320 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions)) 321 | } 322 | } 323 | 324 | func TestBertForTokenClassification(t *testing.T) { 325 | device := gotch.CPU 326 | vs := nn.NewVarStore(device) 327 | 328 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 329 | if err != nil { 330 | 
t.Error(err) 331 | } 332 | config, err := bert.ConfigFromFile(configFile) 333 | if err != nil { 334 | t.Error(err) 335 | } 336 | 337 | var dummyLabelMap map[int64]string = make(map[int64]string) 338 | dummyLabelMap[0] = "O" 339 | dummyLabelMap[1] = "LOC" 340 | dummyLabelMap[2] = "PER" 341 | dummyLabelMap[3] = "ORG" 342 | 343 | config.Id2Label = dummyLabelMap 344 | config.OutputAttentions = true 345 | config.OutputHiddenStates = true 346 | model := bert.NewBertForTokenClassification(vs.Root(), config) 347 | 348 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 349 | tk := getBertTokenizer(vocabFile) 350 | 351 | // Define input 352 | sentence1 := "Looks like one thing is missing" 353 | sentence2 := `It's like comparing oranges to apples` 354 | var input []tokenizer.EncodeInput 355 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 356 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 357 | encodings, err := tk.EncodeBatch(input, true) 358 | if err != nil { 359 | log.Fatal(err) 360 | } 361 | 362 | // Find max length of token Ids from slice of encodings 363 | var maxLen int = 0 364 | for _, en := range encodings { 365 | if len(en.Ids) > maxLen { 366 | maxLen = len(en.Ids) 367 | } 368 | } 369 | 370 | var tensors []ts.Tensor 371 | for _, en := range encodings { 372 | var tokInput []int64 = make([]int64, maxLen) 373 | for i := 0; i < len(en.Ids); i++ { 374 | tokInput[i] = int64(en.Ids[i]) 375 | } 376 | 377 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 378 | } 379 | 380 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 381 | 382 | var ( 383 | output *ts.Tensor 384 | allHiddenStates, allAttentions []ts.Tensor 385 | ) 386 | 387 | ts.NoGrad(func() { 388 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 389 | }) 390 | 391 | fmt.Printf("output size: %v\n", output.MustSize()) 392 | gotOuputSize := output.MustSize() 393 | wantOuputSize := []int64{2, 11, 4} 394 | if !reflect.DeepEqual(wantOuputSize, gotOuputSize) { 395 | t.Errorf("Want: %v\n", wantOuputSize) 396 | t.Errorf("Got: %v\n", gotOuputSize) 397 | } 398 | 399 | numHiddenLayers := int(config.NumHiddenLayers) 400 | 401 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) { 402 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers) 403 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates)) 404 | } 405 | 406 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) { 407 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers) 408 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions)) 409 | } 410 | } 411 | 412 | func TestBertForQuestionAnswering(t *testing.T) { 413 | device := gotch.CPU 414 | vs := nn.NewVarStore(device) 415 | 416 | configFile, err := util.CachedPath("bert-base-uncased", "config.json") 417 | if err != nil { 418 | t.Error(err) 419 | } 420 | config, err := bert.ConfigFromFile(configFile) 421 | if err != nil { 422 | t.Error(err) 423 | } 424 | 425 | config.OutputAttentions = true 426 | config.OutputHiddenStates = true 427 | model := bert.NewForBertQuestionAnswering(vs.Root(), config) 428 | 429 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt") 430 | tk := getBertTokenizer(vocabFile) 431 | 432 | // Define input 433 | sentence1 := "Looks like one thing is missing" 434 | sentence2 := `It's like comparing oranges to apples` 435 | var input []tokenizer.EncodeInput 436 | input = 
append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1))) 437 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2))) 438 | encodings, err := tk.EncodeBatch(input, true) 439 | if err != nil { 440 | log.Fatal(err) 441 | } 442 | 443 | // Find max length of token Ids from slice of encodings 444 | var maxLen int = 0 445 | for _, en := range encodings { 446 | if len(en.Ids) > maxLen { 447 | maxLen = len(en.Ids) 448 | } 449 | } 450 | 451 | var tensors []ts.Tensor 452 | for _, en := range encodings { 453 | var tokInput []int64 = make([]int64, maxLen) 454 | for i := 0; i < len(en.Ids); i++ { 455 | tokInput[i] = int64(en.Ids[i]) 456 | } 457 | 458 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 459 | } 460 | 461 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 462 | 463 | var ( 464 | startScores, endScores *ts.Tensor 465 | allHiddenStates, allAttentions []ts.Tensor 466 | ) 467 | 468 | ts.NoGrad(func() { 469 | startScores, endScores, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 470 | }) 471 | 472 | gotStartScoresSize := startScores.MustSize() 473 | wantStartScoresSize := []int64{2, 11} 474 | if !reflect.DeepEqual(wantStartScoresSize, gotStartScoresSize) { 475 | t.Errorf("Want: %v\n", wantStartScoresSize) 476 | t.Errorf("Got: %v\n", gotStartScoresSize) 477 | } 478 | 479 | gotEndScoresSize := endScores.MustSize() 480 | wantEndScoresSize := []int64{2, 11} 481 | if !reflect.DeepEqual(wantEndScoresSize, gotEndScoresSize) { 482 | t.Errorf("Want: %v\n", wantEndScoresSize) 483 | t.Errorf("Got: %v\n", gotEndScoresSize) 484 | } 485 | 486 | numHiddenLayers := int(config.NumHiddenLayers) 487 | 488 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) { 489 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers) 490 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates)) 491 | } 492 | 493 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) { 494 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers) 495 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions)) 496 | } 497 | } 498 | -------------------------------------------------------------------------------- /roberta/model_test.go: -------------------------------------------------------------------------------- 1 | package roberta_test 2 | 3 | import ( 4 | // "fmt" 5 | "log" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/sugarme/gotch" 10 | "github.com/sugarme/gotch/nn" 11 | "github.com/sugarme/gotch/pickle" 12 | "github.com/sugarme/gotch/ts" 13 | "github.com/sugarme/tokenizer" 14 | "github.com/sugarme/tokenizer/model/bpe" 15 | "github.com/sugarme/tokenizer/normalizer" 16 | "github.com/sugarme/tokenizer/pretokenizer" 17 | "github.com/sugarme/tokenizer/processor" 18 | "github.com/sugarme/transformer/bert" 19 | "github.com/sugarme/transformer/roberta" 20 | "github.com/sugarme/transformer/util" 21 | ) 22 | 23 | func getRobertaTokenizer(vocabFile string, mergeFile string) (retVal *tokenizer.Tokenizer) { 24 | 25 | model, err := bpe.NewBpeFromFiles(vocabFile, mergeFile) 26 | if err != nil { 27 | log.Fatal(err) 28 | } 29 | 30 | tk := tokenizer.NewTokenizer(model) 31 | 32 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true) 33 | tk.WithNormalizer(bertNormalizer) 34 | 35 | blPreTokenizer := pretokenizer.NewByteLevel() 36 | // blPreTokenizer.SetAddPrefixSpace(false) 37 | tk.WithPreTokenizer(blPreTokenizer) 38 | 39 | var specialTokens 
[]tokenizer.AddedToken 40 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 41 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 42 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 43 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 44 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("", true)) 45 | tk.AddSpecialTokens(specialTokens) 46 | 47 | postProcess := processor.DefaultRobertaProcessing() 48 | tk.WithPostProcessor(postProcess) 49 | 50 | return tk 51 | } 52 | 53 | func TestRobertaForMaskedLM(t *testing.T) { 54 | 55 | // Config 56 | config := new(bert.BertConfig) 57 | configFile, err := util.CachedPath("roberta-base", "config.json") 58 | if err != nil { 59 | log.Fatal(err) 60 | } 61 | err = config.Load(configFile, nil) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | // Model 67 | device := gotch.CPU 68 | vs := nn.NewVarStore(device) 69 | 70 | model, err := roberta.NewRobertaForMaskedLM(vs.Root(), config) 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | 75 | modelFile, err := util.CachedPath("roberta-base", "pytorch_model.bin") 76 | if err != nil { 77 | log.Fatal(err) 78 | } 79 | // err = vs.Load("../data/roberta/roberta-base-model.gt") 80 | err = pickle.LoadAll(vs, modelFile) 81 | if err != nil { 82 | log.Fatal(err) 83 | } 84 | 85 | // Roberta tokenizer 86 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 87 | if err != nil { 88 | log.Fatal(err) 89 | } 90 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 91 | if err != nil { 92 | log.Fatal(err) 93 | } 94 | 95 | tk := getRobertaTokenizer(vocabFile, mergesFile) 96 | 97 | sentence1 := "Looks like one is missing" 98 | sentence2 := "It's like comparing to apples" 99 | 100 | input := []tokenizer.EncodeInput{ 101 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)), 102 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)), 103 | } 104 | 105 | /* 106 | * // NOTE: EncodeBatch does encode concurrently, so it does not keep original order! 
107 | * encodings, err := tk.EncodeBatch(input, true) 108 | * if err != nil { 109 | * log.Fatal(err) 110 | * } 111 | * */ 112 | var encodings []tokenizer.Encoding 113 | for _, i := range input { 114 | en, err := tk.Encode(i, true) 115 | if err != nil { 116 | log.Fatal(err) 117 | } 118 | encodings = append(encodings, *en) 119 | } 120 | 121 | // fmt.Printf("encodings:\n%+v\n", encodings) 122 | 123 | var maxLen int = 0 124 | for _, en := range encodings { 125 | if len(en.Ids) > maxLen { 126 | maxLen = len(en.Ids) 127 | } 128 | } 129 | 130 | var tensors []ts.Tensor 131 | for _, en := range encodings { 132 | var tokInput []int64 = make([]int64, maxLen) 133 | for i := 0; i < len(en.Ids); i++ { 134 | tokInput[i] = int64(en.Ids[i]) 135 | } 136 | 137 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 138 | } 139 | 140 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 141 | 142 | var output *ts.Tensor 143 | ts.NoGrad(func() { 144 | output, _, _, err = model.Forward(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false) 145 | if err != nil { 146 | log.Fatal(err) 147 | } 148 | }) 149 | 150 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0] 151 | index2 := output.MustGet(1).MustGet(5).MustArgmax([]int64{0}, false, false).Int64Values()[0] 152 | gotMask1 := tk.Decode([]int{int(index1)}, false) 153 | gotMask2 := tk.Decode([]int{int(index2)}, false) 154 | 155 | // fmt.Printf("index1: '%v' - mask1: '%v'\n", index1, gotMask1) 156 | // fmt.Printf("index2: '%v' - mask2: '%v'\n", index2, gotMask2) 157 | wantMask1 := "Ġperson" 158 | wantMask2 := "Ġapples" 159 | 160 | if !reflect.DeepEqual(wantMask1, gotMask1) { 161 | t.Errorf("Want: %v got %v\n", wantMask1, gotMask1) 162 | } 163 | 164 | if !reflect.DeepEqual(wantMask2, gotMask2) { 165 | t.Errorf("Want: %v got %v\n", wantMask2, gotMask2) 166 | } 167 | 168 | } 169 | 170 | func TestRobertaForSequenceClassification(t *testing.T) { 171 | 172 | // Config 173 | config := new(bert.BertConfig) 174 | configFile, err := util.CachedPath("roberta-base", "config.json") 175 | if err != nil { 176 | log.Fatal(err) 177 | } 178 | err = config.Load(configFile, nil) 179 | if err != nil { 180 | log.Fatal(err) 181 | } 182 | 183 | // Model 184 | device := gotch.CPU 185 | vs := nn.NewVarStore(device) 186 | 187 | var dummyLabelMap map[int64]string = make(map[int64]string) 188 | dummyLabelMap[0] = "Positive" 189 | dummyLabelMap[1] = "Negative" 190 | dummyLabelMap[3] = "Neutral" 191 | 192 | config.Id2Label = dummyLabelMap 193 | config.OutputAttentions = true 194 | config.OutputHiddenStates = true 195 | model := roberta.NewRobertaForSequenceClassification(vs.Root(), config) 196 | 197 | // Roberta tokenizer 198 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 199 | if err != nil { 200 | log.Fatal(err) 201 | } 202 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 203 | if err != nil { 204 | log.Fatal(err) 205 | } 206 | 207 | tk := getRobertaTokenizer(vocabFile, mergesFile) 208 | 209 | sentence1 := "Looks like one thing is missing" 210 | sentence2 := "It's like comparing oranges to apples" 211 | 212 | input := []tokenizer.EncodeInput{ 213 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)), 214 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)), 215 | } 216 | 217 | var encodings []tokenizer.Encoding 218 | for _, i := range input { 219 | en, err := tk.Encode(i, true) 220 | if err != nil { 221 | log.Fatal(err) 222 | } 223 | encodings = 
append(encodings, *en) 224 | } 225 | 226 | var maxLen int = 0 227 | for _, en := range encodings { 228 | if len(en.Ids) > maxLen { 229 | maxLen = len(en.Ids) 230 | } 231 | } 232 | 233 | var tensors []ts.Tensor 234 | for _, en := range encodings { 235 | var tokInput []int64 = make([]int64, maxLen) 236 | for i := 0; i < len(en.Ids); i++ { 237 | tokInput[i] = int64(en.Ids[i]) 238 | } 239 | 240 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 241 | } 242 | 243 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 244 | 245 | var ( 246 | output *ts.Tensor 247 | hiddenStates, attentions []ts.Tensor 248 | ) 249 | 250 | ts.NoGrad(func() { 251 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 252 | if err != nil { 253 | log.Fatal(err) 254 | } 255 | }) 256 | 257 | wantOutput := []int64{2, 3} 258 | gotOutput := output.MustSize() 259 | 260 | wantNumHiddenLayers := config.NumHiddenLayers 261 | gotNumHiddenLayers := int64(len(hiddenStates)) 262 | 263 | wantAttentions := config.NumHiddenLayers 264 | gotAttentions := int64(len(attentions)) 265 | 266 | if !reflect.DeepEqual(wantOutput, gotOutput) { 267 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput) 268 | } 269 | 270 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) { 271 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers) 272 | } 273 | 274 | if !reflect.DeepEqual(wantAttentions, gotAttentions) { 275 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions) 276 | } 277 | } 278 | 279 | func TestRobertaForMultipleChoice(t *testing.T) { 280 | // Config 281 | config := new(bert.BertConfig) 282 | config.OutputAttentions = true 283 | config.OutputHiddenStates = true 284 | configFile, err := util.CachedPath("roberta-base", "config.json") 285 | if err != nil { 286 | log.Fatal(err) 287 | } 288 | err = config.Load(configFile, nil) 289 | if err != nil { 290 | log.Fatal(err) 291 | } 292 | 293 | // Model 294 | device := gotch.CPU 295 | vs := nn.NewVarStore(device) 296 | 297 | model := roberta.NewRobertaForMultipleChoice(vs.Root(), config) 298 | 299 | // Roberta tokenizer 300 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 301 | if err != nil { 302 | log.Fatal(err) 303 | } 304 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 305 | if err != nil { 306 | log.Fatal(err) 307 | } 308 | 309 | tk := getRobertaTokenizer(vocabFile, mergesFile) 310 | 311 | sentence1 := "Looks like one is missing" 312 | sentence2 := "It's like comparing to apples" 313 | 314 | input := []tokenizer.EncodeInput{ 315 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)), 316 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)), 317 | } 318 | 319 | var encodings []tokenizer.Encoding 320 | for _, i := range input { 321 | en, err := tk.Encode(i, true) 322 | if err != nil { 323 | log.Fatal(err) 324 | } 325 | encodings = append(encodings, *en) 326 | } 327 | 328 | var maxLen int = 0 329 | for _, en := range encodings { 330 | if len(en.Ids) > maxLen { 331 | maxLen = len(en.Ids) 332 | } 333 | } 334 | 335 | var tensors []ts.Tensor 336 | for _, en := range encodings { 337 | var tokInput []int64 = make([]int64, maxLen) 338 | for i := 0; i < len(en.Ids); i++ { 339 | tokInput[i] = int64(en.Ids[i]) 340 | } 341 | 342 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 343 | } 344 | 345 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true).MustUnsqueeze(0, true) 346 | 347 | var ( 348 | output *ts.Tensor 349 | 
hiddenStates, attentions []ts.Tensor 350 | ) 351 | ts.NoGrad(func() { 352 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, false) 353 | if err != nil { 354 | log.Fatal(err) 355 | } 356 | }) 357 | 358 | wantOutput := []int64{1, 2} 359 | gotOutput := output.MustSize() 360 | 361 | wantHiddenStates := config.NumHiddenLayers 362 | gotHiddenStates := int64(len(hiddenStates)) 363 | 364 | wantAttentions := config.NumHiddenLayers 365 | gotAttentions := int64(len(attentions)) 366 | 367 | if !reflect.DeepEqual(wantOutput, gotOutput) { 368 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput) 369 | } 370 | 371 | if !reflect.DeepEqual(wantHiddenStates, gotHiddenStates) { 372 | t.Errorf("want %v - got %v\n", wantHiddenStates, gotHiddenStates) 373 | } 374 | 375 | if !reflect.DeepEqual(wantAttentions, gotAttentions) { 376 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions) 377 | } 378 | } 379 | 380 | func TestRobertaForTokenClassification(t *testing.T) { 381 | 382 | // Config 383 | config := new(bert.BertConfig) 384 | configFile, err := util.CachedPath("roberta-base", "config.json") 385 | if err != nil { 386 | log.Fatal(err) 387 | } 388 | err = config.Load(configFile, nil) 389 | if err != nil { 390 | log.Fatal(err) 391 | } 392 | 393 | // Model 394 | device := gotch.CPU 395 | vs := nn.NewVarStore(device) 396 | 397 | var dummyLabelMap map[int64]string = make(map[int64]string) 398 | dummyLabelMap[0] = "O" 399 | dummyLabelMap[1] = "LOC" 400 | dummyLabelMap[2] = "PER" 401 | dummyLabelMap[3] = "ORG" 402 | 403 | config.Id2Label = dummyLabelMap 404 | config.OutputAttentions = true 405 | config.OutputHiddenStates = true 406 | model := roberta.NewRobertaForTokenClassification(vs.Root(), config) 407 | 408 | // Roberta tokenizer 409 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 410 | if err != nil { 411 | log.Fatal(err) 412 | } 413 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 414 | if err != nil { 415 | log.Fatal(err) 416 | } 417 | 418 | tk := getRobertaTokenizer(vocabFile, mergesFile) 419 | 420 | sentence1 := "Looks like one thing is missing" 421 | sentence2 := "It's like comparing oranges to apples" 422 | 423 | input := []tokenizer.EncodeInput{ 424 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)), 425 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)), 426 | } 427 | 428 | var encodings []tokenizer.Encoding 429 | for _, i := range input { 430 | en, err := tk.Encode(i, true) 431 | if err != nil { 432 | log.Fatal(err) 433 | } 434 | encodings = append(encodings, *en) 435 | } 436 | 437 | var maxLen int = 0 438 | for _, en := range encodings { 439 | if len(en.Ids) > maxLen { 440 | maxLen = len(en.Ids) 441 | } 442 | } 443 | 444 | var tensors []ts.Tensor 445 | for _, en := range encodings { 446 | var tokInput []int64 = make([]int64, maxLen) 447 | for i := 0; i < len(en.Ids); i++ { 448 | tokInput[i] = int64(en.Ids[i]) 449 | } 450 | 451 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 452 | } 453 | 454 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 455 | 456 | var ( 457 | output *ts.Tensor 458 | hiddenStates, attentions []ts.Tensor 459 | ) 460 | 461 | ts.NoGrad(func() { 462 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 463 | if err != nil { 464 | log.Fatal(err) 465 | } 466 | }) 467 | 468 | wantOutput := []int64{2, 9, 4} 469 | gotOutput := output.MustSize() 470 | 471 | wantNumHiddenLayers := 
config.NumHiddenLayers 472 | gotNumHiddenLayers := int64(len(hiddenStates)) 473 | 474 | wantAttentions := config.NumHiddenLayers 475 | gotAttentions := int64(len(attentions)) 476 | 477 | if !reflect.DeepEqual(wantOutput, gotOutput) { 478 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput) 479 | } 480 | 481 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) { 482 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers) 483 | } 484 | 485 | if !reflect.DeepEqual(wantAttentions, gotAttentions) { 486 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions) 487 | } 488 | } 489 | 490 | func TestRobertaForQuestionAnswering(t *testing.T) { 491 | 492 | // Config 493 | config := new(bert.BertConfig) 494 | configFile, err := util.CachedPath("roberta-base", "config.json") 495 | if err != nil { 496 | log.Fatal(err) 497 | } 498 | 499 | err = config.Load(configFile, nil) 500 | if err != nil { 501 | log.Fatal(err) 502 | } 503 | 504 | // Model 505 | device := gotch.CPU 506 | vs := nn.NewVarStore(device) 507 | 508 | var dummyLabelMap map[int64]string = make(map[int64]string) 509 | dummyLabelMap[0] = "Positive" 510 | dummyLabelMap[1] = "Negative" 511 | dummyLabelMap[3] = "Neutral" 512 | 513 | config.Id2Label = dummyLabelMap 514 | config.OutputAttentions = true 515 | config.OutputHiddenStates = true 516 | model := roberta.NewRobertaForQuestionAnswering(vs.Root(), config) 517 | 518 | // Roberta tokenizer 519 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json") 520 | if err != nil { 521 | log.Fatal(err) 522 | } 523 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt") 524 | if err != nil { 525 | log.Fatal(err) 526 | } 527 | 528 | tk := getRobertaTokenizer(vocabFile, mergesFile) 529 | 530 | sentence1 := "Looks like one thing is missing" 531 | sentence2 := "It's like comparing oranges to apples" 532 | 533 | input := []tokenizer.EncodeInput{ 534 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)), 535 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)), 536 | } 537 | 538 | var encodings []tokenizer.Encoding 539 | for _, i := range input { 540 | en, err := tk.Encode(i, true) 541 | if err != nil { 542 | log.Fatal(err) 543 | } 544 | encodings = append(encodings, *en) 545 | } 546 | 547 | var maxLen int = 0 548 | for _, en := range encodings { 549 | if len(en.Ids) > maxLen { 550 | maxLen = len(en.Ids) 551 | } 552 | } 553 | 554 | var tensors []ts.Tensor 555 | for _, en := range encodings { 556 | var tokInput []int64 = make([]int64, maxLen) 557 | for i := 0; i < len(en.Ids); i++ { 558 | tokInput[i] = int64(en.Ids[i]) 559 | } 560 | 561 | tensors = append(tensors, *ts.TensorFrom(tokInput)) 562 | } 563 | 564 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true) 565 | 566 | var ( 567 | startScores, endScores *ts.Tensor 568 | hiddenStates, attentions []ts.Tensor 569 | ) 570 | 571 | ts.NoGrad(func() { 572 | startScores, endScores, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false) 573 | if err != nil { 574 | log.Fatal(err) 575 | } 576 | }) 577 | 578 | wantStartScores := []int64{2, 9} 579 | gotStartScores := startScores.MustSize() 580 | 581 | wantEndScores := []int64{2, 9} 582 | gotEndScores := endScores.MustSize() 583 | 584 | wantNumHiddenLayers := config.NumHiddenLayers 585 | gotNumHiddenLayers := int64(len(hiddenStates)) 586 | 587 | wantAttentions := config.NumHiddenLayers 588 | gotAttentions := int64(len(attentions)) 589 | 590 | if 
!reflect.DeepEqual(wantStartScores, gotStartScores) { 591 | t.Errorf("want %v - got %v\n", wantStartScores, gotStartScores) 592 | } 593 | 594 | if !reflect.DeepEqual(wantEndScores, gotEndScores) { 595 | t.Errorf("want %v - got %v\n", wantEndScores, gotEndScores) 596 | } 597 | 598 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) { 599 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers) 600 | } 601 | 602 | if !reflect.DeepEqual(wantAttentions, gotAttentions) { 603 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions) 604 | } 605 | } 606 | 607 | func TestRobertaQA(t *testing.T) { 608 | // TODO: implement via pipelines 609 | } 610 | 611 | func TestRobertaNER(t *testing.T) { 612 | // TODO: implement via pipelines 613 | } 614 | -------------------------------------------------------------------------------- /roberta/model.go: -------------------------------------------------------------------------------- 1 | package roberta 2 | 3 | // roberta package implements Roberta transformer model. 4 | 5 | import ( 6 | // "fmt" 7 | 8 | "github.com/sugarme/gotch" 9 | "github.com/sugarme/gotch/nn" 10 | "github.com/sugarme/gotch/pickle" 11 | "github.com/sugarme/gotch/ts" 12 | 13 | "github.com/sugarme/transformer/bert" 14 | "github.com/sugarme/transformer/pretrained" 15 | "github.com/sugarme/transformer/util" 16 | ) 17 | 18 | // RobertaLMHead holds data of Roberta LM head. 19 | type RobertaLMHead struct { 20 | dense *nn.Linear 21 | decoder *util.LinearNoBias 22 | layerNorm *nn.LayerNorm 23 | bias *ts.Tensor 24 | } 25 | 26 | // NewRobertaLMHead creates new RobertaLMHead. 27 | func NewRobertaLMHead(p *nn.Path, config *bert.BertConfig) (*RobertaLMHead, error) { 28 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig()) 29 | 30 | layerNormConfig := nn.DefaultLayerNormConfig() 31 | layerNormConfig.Eps = 1e-12 32 | layerNorm := nn.NewLayerNorm(p.Sub("layer_norm"), []int64{config.HiddenSize}, layerNormConfig) 33 | 34 | decoder, err := util.NewLinearNoBias(p.Sub("decoder"), config.HiddenSize, config.VocabSize, util.DefaultLinearNoBiasConfig()) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | bias, err := p.NewVar("bias", []int64{config.VocabSize}, nn.NewKaimingUniformInit()) 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | return &RobertaLMHead{ 45 | dense: dense, 46 | decoder: decoder, 47 | layerNorm: layerNorm, 48 | bias: bias, 49 | }, nil 50 | } 51 | 52 | // Foward forwards pass through RobertaLMHead model. 53 | func (rh *RobertaLMHead) Forward(hiddenStates *ts.Tensor) *ts.Tensor { 54 | gelu := util.NewGelu() 55 | appliedDense := hiddenStates.Apply(rh.dense) 56 | geluFwd := gelu.Fwd(appliedDense) 57 | appliedLN := geluFwd.Apply(rh.layerNorm) 58 | appliedDecoder := appliedLN.Apply(rh.decoder) 59 | appliedBias := appliedDecoder.MustAdd(rh.bias, true) 60 | 61 | geluFwd.MustDrop() 62 | appliedDense.MustDrop() 63 | appliedLN.MustDrop() 64 | 65 | return appliedBias 66 | } 67 | 68 | // RobertaForMaskedLM holds data for Roberta masked language model. 69 | // 70 | // Base RoBERTa model with a RoBERTa masked language model head to predict 71 | // missing tokens. 72 | type RobertaForMaskedLM struct { 73 | roberta *bert.BertModel 74 | lmHead *RobertaLMHead 75 | } 76 | 77 | // NewRobertaForMaskedLM builds a new RobertaForMaskedLM. 
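// A minimal construction sketch, added for illustration only and mirroring
// TestRobertaForMaskedLM above; the "roberta-base" cache names are assumptions
// taken from that test, not part of this file, and error handling is elided.
//
//	config := new(bert.BertConfig)
//	configFile, _ := util.CachedPath("roberta-base", "config.json")
//	_ = config.Load(configFile, nil)
//
//	vs := nn.NewVarStore(gotch.CPU)
//	model, err := NewRobertaForMaskedLM(vs.Root(), config)
//	// err checking elided; load pretrained weights into the var store:
//	modelFile, _ := util.CachedPath("roberta-base", "pytorch_model.bin")
//	_ = pickle.LoadAll(vs, modelFile)
//	_ = model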
78 | func NewRobertaForMaskedLM(p *nn.Path, config *bert.BertConfig) (*RobertaForMaskedLM, error) { 79 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false) 80 | lmHead, err := NewRobertaLMHead(p.Sub("lm_head"), config) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | return &RobertaForMaskedLM{ 86 | roberta: roberta, 87 | lmHead: lmHead, 88 | }, nil 89 | } 90 | 91 | // Load loads model from file or model name. It also updates 92 | // default configuration parameters if provided. 93 | // This method implements `PretrainedModel` interface. 94 | func (mlm *RobertaForMaskedLM) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 95 | // var urlOrFilename string 96 | // // If modelName, infer to default configuration filename: 97 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok { 98 | // urlOrFilename = modelFile 99 | // } else { 100 | // // Otherwise, just take the input 101 | // urlOrFilename = modelNameOrPath 102 | // } 103 | 104 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin") 105 | if err != nil { 106 | return err 107 | } 108 | 109 | vs := nn.NewVarStore(device) 110 | p := vs.Root() 111 | 112 | mlm.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false) 113 | mlm.lmHead, err = NewRobertaLMHead(p.Sub("lm_head"), config.(*bert.BertConfig)) 114 | 115 | // err = vs.Load(cachedFile) 116 | err = pickle.LoadAll(vs, cachedFile) 117 | if err != nil { 118 | return err 119 | } 120 | 121 | return nil 122 | } 123 | 124 | // Forwad forwads pass through the model. 125 | // 126 | // Params: 127 | // - `inputIds`: Optional input tensor of shape (batch size, sequence length). 128 | // If None, pre-computed embeddings must be provided (see inputEmbeds). 129 | // - `mask`: Optional mask of shape (batch size, sequence length). 130 | // Masked position have value 0, non-masked value 1. If None set to 1. 131 | // - `tokenTypeIds`: Optional segment id of shape (batch size, sequence length). 132 | // Convention is value of 0 for the first sentence (incl. ) and 1 for the 133 | // second sentence. If None set to 0. 134 | // - `positionIds`: Optional position ids of shape (batch size, sequence length). 135 | // If None, will be incremented from 0. 136 | // - `inputEmbeds`: Optional pre-computed input embeddings of shape (batch size, 137 | // sequence length, hidden size). If None, input ids must be provided (see inputIds). 138 | // - `encoderHiddenStates`: Optional encoder hidden state of shape (batch size, 139 | // encoder sequence length, hidden size). If the model is defined as a decoder and 140 | // the encoder hidden states is not None, used in the cross-attention layer as 141 | // keys and values (query from the decoder). 142 | // - `encoderMask`: Optional encoder attention mask of shape (batch size, encoder sequence length). 143 | // If the model is defined as a decoder and the *encoder_hidden_states* is not None, 144 | // used to mask encoder values. Positions with value 0 will be masked. 145 | // - `train`: boolean flag to turn on/off the dropout layers in the model. 146 | // Should be set to false for inference. 147 | // 148 | // Returns: 149 | // - `output`: tensor of shape (batch size, numLabels, vocab size) 150 | // - `hiddenStates`: optional slice of tensors of length numHiddenLayers with shape 151 | // (batch size, sequence length, hidden size). 
152 | // - `attentions`: optional slice of tensors of length num hidden layers with shape 153 | // (batch size, sequence length, hidden size). 154 | // - `err`: error 155 | func (mlm *RobertaForMaskedLM) Forward(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) { 156 | 157 | hiddenState, _, allHiddenStates, allAttentions, err := mlm.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask, train) 158 | 159 | if err != nil { 160 | return ts.None, nil, nil, err 161 | } 162 | 163 | predictionScores := mlm.lmHead.Forward(hiddenState) 164 | 165 | return predictionScores, allHiddenStates, allAttentions, nil 166 | } 167 | 168 | // RoberatClassificationHead holds data for Roberta classification head. 169 | type RobertaClassificationHead struct { 170 | dense *nn.Linear 171 | dropout *util.Dropout 172 | outProj *nn.Linear 173 | } 174 | 175 | // NewRobertaClassificationHead create a new RobertaClassificationHead. 176 | func NewRobertaClassificationHead(p *nn.Path, config *bert.BertConfig) *RobertaClassificationHead { 177 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig()) 178 | numLabels := int64(len(config.Id2Label)) 179 | outProj := nn.NewLinear(p.Sub("out_proj"), config.HiddenSize, numLabels, nn.DefaultLinearConfig()) 180 | dropout := util.NewDropout(config.HiddenDropoutProb) 181 | 182 | return &RobertaClassificationHead{ 183 | dense: dense, 184 | dropout: dropout, 185 | outProj: outProj, 186 | } 187 | } 188 | 189 | // ForwardT forwards pass through model. 190 | func (ch *RobertaClassificationHead) ForwardT(hiddenStates *ts.Tensor, train bool) *ts.Tensor { 191 | appliedDO1 := hiddenStates.MustSelect(1, 0, false).ApplyT(ch.dropout, train) 192 | appliedDense := appliedDO1.Apply(ch.dense) 193 | tanhTs := appliedDense.MustTanh(false) 194 | appliedDO2 := tanhTs.ApplyT(ch.dropout, train) 195 | retVal := appliedDO2.Apply(ch.outProj) 196 | 197 | appliedDO1.MustDrop() 198 | appliedDense.MustDrop() 199 | tanhTs.MustDrop() 200 | appliedDO2.MustDrop() 201 | 202 | return retVal 203 | } 204 | 205 | // RobertaForSequenceClassification holds data for Roberta sequence classification model. 206 | // It's used for performing sentence or document-level classification. 207 | type RobertaForSequenceClassification struct { 208 | roberta *bert.BertModel 209 | classifier *RobertaClassificationHead 210 | } 211 | 212 | // NewRobertaForSequenceClassification creates a new RobertaForSequenceClassification model. 213 | func NewRobertaForSequenceClassification(p *nn.Path, config *bert.BertConfig) *RobertaForSequenceClassification { 214 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false) 215 | classifier := NewRobertaClassificationHead(p.Sub("classifier"), config) 216 | 217 | return &RobertaForSequenceClassification{ 218 | roberta: roberta, 219 | classifier: classifier, 220 | } 221 | } 222 | 223 | // Load loads model from file or model name. It also updates default configuration parameters if provided. 224 | // 225 | // This method implements `PretrainedModel` interface. 
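// A hedged usage sketch, added for illustration; the "roberta-base" name and the
// label map values are assumptions taken from the tests in this package. Note that
// the classification head width is derived from len(config.Id2Label), so the label
// map should be set before loading. Error handling is elided.
//
//	config := new(bert.BertConfig)
//	configFile, _ := util.CachedPath("roberta-base", "config.json")
//	_ = config.Load(configFile, nil)
//	config.Id2Label = map[int64]string{0: "positive", 1: "negative"}
//
//	var sc RobertaForSequenceClassification
//	err := sc.Load("roberta-base", config, nil, gotch.CPU)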
226 | func (sc *RobertaForSequenceClassification) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 227 | // var urlOrFilename string 228 | // // If modelName, infer to default configuration filename: 229 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok { 230 | // urlOrFilename = modelFile 231 | // } else { 232 | // // Otherwise, just take the input 233 | // urlOrFilename = modelNameOrPath 234 | // } 235 | 236 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin") 237 | if err != nil { 238 | return err 239 | } 240 | 241 | vs := nn.NewVarStore(device) 242 | p := vs.Root() 243 | 244 | sc.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false) 245 | sc.classifier = NewRobertaClassificationHead(p.Sub("classifier"), config.(*bert.BertConfig)) 246 | 247 | // err = vs.Load(cachedFile) 248 | err = pickle.LoadAll(vs, cachedFile) 249 | if err != nil { 250 | return err 251 | } 252 | 253 | return nil 254 | } 255 | 256 | // Forward forwards pass through the model. 257 | func (sc *RobertaForSequenceClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (labels *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) { 258 | 259 | hiddenState, _, hiddenStates, attentions, err := sc.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 260 | if err != nil { 261 | return ts.None, nil, nil, err 262 | } 263 | 264 | labels = sc.classifier.ForwardT(hiddenState, train) 265 | hiddenState.MustDrop() 266 | 267 | return labels, hiddenStates, attentions, nil 268 | } 269 | 270 | // RobertaForMultipleChoice holds data for Roberta multiple choice model. 271 | // 272 | // Input should be in form of ` Context Possible choice `. 273 | // The choice is made along the batch axis, assuming all elements of the batch are 274 | // alternatives to be chosen from for a given context. 275 | type RobertaForMultipleChoice struct { 276 | roberta *bert.BertModel 277 | dropout *util.Dropout 278 | classifier *nn.Linear 279 | } 280 | 281 | // NewRobertaForMultipleChoice creates a new RobertaForMultipleChoice model. 282 | func NewRobertaForMultipleChoice(p *nn.Path, config *bert.BertConfig) *RobertaForMultipleChoice { 283 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false) 284 | dropout := util.NewDropout(config.HiddenDropoutProb) 285 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, 1, nn.DefaultLinearConfig()) 286 | 287 | return &RobertaForMultipleChoice{ 288 | roberta: roberta, 289 | dropout: dropout, 290 | classifier: classifier, 291 | } 292 | } 293 | 294 | // Load loads model from file or model name. It also updates default configuration parameters if provided. 295 | // 296 | // This method implements `PretrainedModel` interface. 
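// Shape note, added for illustration and based on TestRobertaForMultipleChoice and
// the ForwardT method below (tensors, device and mc are assumed to be prepared as in
// that test): input ids are expected as (batch size, number of choices, sequence
// length); ForwardT flattens them to (batch size * number of choices, sequence length)
// and reshapes the classifier output back to (batch size, number of choices).
//
//	// two candidate sequences stacked, then unsqueezed into a single example
//	inputTensor := ts.MustStack(tensors, 0).MustTo(device, true).MustUnsqueeze(0, true)
//	output, _, _, err := mc.ForwardT(inputTensor, ts.None, ts.None, ts.None, false)
//	// output.MustSize() -> []int64{1, 2}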
297 | func (mc *RobertaForMultipleChoice) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 298 | // var urlOrFilename string 299 | // // If modelName, infer to default configuration filename: 300 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok { 301 | // urlOrFilename = modelFile 302 | // } else { 303 | // // Otherwise, just take the input 304 | // urlOrFilename = modelNameOrPath 305 | // } 306 | 307 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin") 308 | if err != nil { 309 | return err 310 | } 311 | 312 | vs := nn.NewVarStore(device) 313 | p := vs.Root() 314 | 315 | mc.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false) 316 | mc.dropout = util.NewDropout(config.(*bert.BertConfig).HiddenDropoutProb) 317 | classifier := nn.NewLinear(p.Sub("classifier"), config.(*bert.BertConfig).HiddenSize, 1, nn.DefaultLinearConfig()) 318 | mc.classifier = classifier 319 | 320 | err = vs.Load(cachedFile) 321 | if err != nil { 322 | return err 323 | } 324 | 325 | return nil 326 | } 327 | 328 | // ForwardT forwards pass through the model. 329 | func (mc *RobertaForMultipleChoice) ForwardT(inputIds, mask, tokenTypeIds, positionIds *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) { 330 | 331 | numChoices := inputIds.MustSize()[1] 332 | 333 | inputIdsSize := inputIds.MustSize() 334 | flatInputIds := inputIds.MustView([]int64{-1, inputIdsSize[len(inputIdsSize)-1]}, false) 335 | 336 | flatPositionIds := ts.None 337 | if positionIds.MustDefined() { 338 | positionIdsSize := positionIds.MustSize() 339 | flatPositionIds = positionIds.MustView([]int64{-1, positionIdsSize[len(positionIdsSize)-1]}, false) 340 | } 341 | 342 | flatTokenTypeIds := ts.None 343 | if tokenTypeIds.MustDefined() { 344 | tokenTypeIdsSize := tokenTypeIds.MustSize() 345 | flatTokenTypeIds = tokenTypeIds.MustView([]int64{-1, tokenTypeIdsSize[len(tokenTypeIdsSize)-1]}, false) 346 | } 347 | 348 | flatMask := ts.None 349 | if mask.MustDefined() { 350 | flatMaskSize := flatMask.MustSize() 351 | flatMask = mask.MustView([]int64{-1, flatMaskSize[len(flatMaskSize)-1]}, false) 352 | } 353 | 354 | var pooledOutput *ts.Tensor 355 | _, pooledOutput, hiddenStates, attentions, err = mc.roberta.ForwardT(flatInputIds, flatMask, flatTokenTypeIds, flatPositionIds, ts.None, ts.None, ts.None, train) 356 | if err != nil { 357 | return ts.None, nil, nil, err 358 | } 359 | 360 | appliedDO := pooledOutput.ApplyT(mc.dropout, train) 361 | appliedCls := appliedDO.Apply(mc.classifier) 362 | output = appliedCls.MustView([]int64{-1, numChoices}, true) 363 | 364 | appliedDO.MustDrop() 365 | 366 | return output, hiddenStates, attentions, nil 367 | } 368 | 369 | // RobertaForTokenClassification holds data for Roberta token classification model. 370 | type RobertaForTokenClassification struct { 371 | roberta *bert.BertModel 372 | dropout *util.Dropout 373 | classifier *nn.Linear 374 | } 375 | 376 | // NewRobertaForTokenClassification creates a new RobertaForTokenClassification model. 
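// A construction sketch, added for illustration; vs and config are assumed to be
// prepared as in the tests of this package, and the label values are placeholders
// rather than a real tagging scheme. The classifier width equals len(config.Id2Label).
//
//	config.Id2Label = map[int64]string{0: "O", 1: "LOC", 2: "PER", 3: "ORG"}
//	model := NewRobertaForTokenClassification(vs.Root(), config)
//	// ForwardT returns logits of shape (batch size, sequence length, len(config.Id2Label))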
377 | func NewRobertaForTokenClassification(p *nn.Path, config *bert.BertConfig) *RobertaForTokenClassification { 378 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false) 379 | dropout := util.NewDropout(config.HiddenDropoutProb) 380 | numLabels := int64(len(config.Id2Label)) 381 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, numLabels, nn.DefaultLinearConfig()) 382 | 383 | return &RobertaForTokenClassification{ 384 | roberta: roberta, 385 | dropout: dropout, 386 | classifier: classifier, 387 | } 388 | } 389 | 390 | // Load loads model from file or model name. It also updates default configuration parameters if provided. 391 | // 392 | // This method implements `PretrainedModel` interface. 393 | func (tc *RobertaForTokenClassification) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 394 | // var urlOrFilename string 395 | // // If modelName, infer to default configuration filename: 396 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok { 397 | // urlOrFilename = modelFile 398 | // } else { 399 | // // Otherwise, just take the input 400 | // urlOrFilename = modelNameOrPath 401 | // } 402 | 403 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin") 404 | if err != nil { 405 | return err 406 | } 407 | 408 | vs := nn.NewVarStore(device) 409 | p := vs.Root() 410 | 411 | roberta := bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false) 412 | dropout := util.NewDropout(config.(*bert.BertConfig).HiddenDropoutProb) 413 | numLabels := int64(len(config.(*bert.BertConfig).Id2Label)) 414 | classifier := nn.NewLinear(p.Sub("classifier"), config.(*bert.BertConfig).HiddenSize, numLabels, nn.DefaultLinearConfig()) 415 | 416 | tc.roberta = roberta 417 | tc.dropout = dropout 418 | tc.classifier = classifier 419 | 420 | err = vs.Load(cachedFile) 421 | if err != nil { 422 | return err 423 | } 424 | 425 | return nil 426 | } 427 | 428 | // ForwardT forwards pass through the model. 429 | func (tc *RobertaForTokenClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) { 430 | hiddenState, _, hiddenStates, attentions, err := tc.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 431 | if err != nil { 432 | return ts.None, nil, nil, err 433 | } 434 | 435 | appliedDO := hiddenState.ApplyT(tc.dropout, train) 436 | output = appliedDO.Apply(tc.classifier) 437 | 438 | appliedDO.MustDrop() 439 | 440 | return output, hiddenStates, attentions, nil 441 | } 442 | 443 | // RobertaForQuestionAnswering constructs layers for Roberta question answering model. 444 | type RobertaForQuestionAnswering struct { 445 | roberta *bert.BertModel 446 | qaOutputs *nn.Linear 447 | } 448 | 449 | // NewRobertaQuestionAnswering creates a new RobertaForQuestionAnswering model. 450 | func NewRobertaForQuestionAnswering(p *nn.Path, config *bert.BertConfig) *RobertaForQuestionAnswering { 451 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false) 452 | numLabels := int64(2) 453 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.HiddenSize, numLabels, nn.DefaultLinearConfig()) 454 | 455 | return &RobertaForQuestionAnswering{ 456 | roberta: roberta, 457 | qaOutputs: qaOutputs, 458 | } 459 | } 460 | 461 | // Load loads model from file or model name. It also updates default configuration parameters if provided. 
462 | // 463 | // This method implements `PretrainedModel` interface. 464 | func (qa *RobertaForQuestionAnswering) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 465 | // var urlOrFilename string 466 | // // If modelName, infer to default configuration filename: 467 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok { 468 | // urlOrFilename = modelFile 469 | // } else { 470 | // // Otherwise, just take the input 471 | // urlOrFilename = modelNameOrPath 472 | // } 473 | 474 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin") 475 | if err != nil { 476 | return err 477 | } 478 | 479 | vs := nn.NewVarStore(device) 480 | p := vs.Root() 481 | 482 | roberta := bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false) 483 | numLabels := int64(2) 484 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.(*bert.BertConfig).HiddenSize, numLabels, nn.DefaultLinearConfig()) 485 | 486 | qa.roberta = roberta 487 | qa.qaOutputs = qaOutputs 488 | 489 | err = vs.Load(cachedFile) 490 | if err != nil { 491 | return err 492 | } 493 | 494 | return nil 495 | } 496 | 497 | // ForwadT forwards pass through the model. 498 | func (qa *RobertaForQuestionAnswering) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (startScores, endScores *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) { 499 | hiddenState, _, hiddenStates, attentions, err := qa.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 500 | if err != nil { 501 | return ts.None, ts.None, nil, nil, err 502 | } 503 | 504 | sequenceOutput := hiddenState.Apply(qa.qaOutputs) 505 | logits := sequenceOutput.MustSplit(1, -1, true) 506 | startScores = logits[0].MustSqueezeDim(-1, false) 507 | endScores = logits[1].MustSqueezeDim(-1, false) 508 | 509 | for _, x := range logits { 510 | x.MustDrop() 511 | } 512 | 513 | return startScores, endScores, hiddenStates, attentions, nil 514 | } 515 | -------------------------------------------------------------------------------- /bert/model.go: -------------------------------------------------------------------------------- 1 | package bert 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/sugarme/gotch" 8 | "github.com/sugarme/gotch/nn" 9 | "github.com/sugarme/gotch/pickle" 10 | "github.com/sugarme/gotch/ts" 11 | 12 | "github.com/sugarme/transformer/pretrained" 13 | "github.com/sugarme/transformer/util" 14 | ) 15 | 16 | // BertModel defines base architecture for BERT models. 17 | // Task-specific models can be built from this base model. 18 | // 19 | // Fields: 20 | // - Embeddings: for `token`, `position` and `segment` embeddings 21 | // - Encoder: is a vector of layers. Each layer compose of a `self-attention`, 22 | // 23 | // an `intermedate` (linear) and an output ( linear + layer norm) sub-layers. 24 | // - Pooler: linear layer applied to the first element of the sequence (`[MASK]` token) 25 | // - IsDecoder: whether model is used as a decoder. If set to `true` 26 | // 27 | // a casual mask will be applied to hide future positions that should be attended to. 28 | type BertModel struct { 29 | Embeddings *BertEmbeddings 30 | Encoder *BertEncoder 31 | Pooler *BertPooler 32 | IsDecoder bool 33 | } 34 | 35 | // NewBertModel builds a new `BertModel`. 
36 | // 37 | // Params: 38 | // - `p`: Variable store path for the root of the BERT Model 39 | // - `config`: BertConfig onfiguration for model architecture and decoder status 40 | func NewBertModel(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertModel { 41 | changeName := true 42 | if len(changeNameOpt) > 0 { 43 | changeName = changeNameOpt[0] 44 | } 45 | isDecoder := false 46 | if config.IsDecoder { 47 | isDecoder = true 48 | } 49 | 50 | embeddings := NewBertEmbeddings(p.Sub("embeddings"), config, changeName) 51 | encoder := NewBertEncoder(p.Sub("encoder"), config, changeName) 52 | pooler := NewBertPooler(p.Sub("pooler"), config) 53 | 54 | return &BertModel{embeddings, encoder, pooler, isDecoder} 55 | } 56 | 57 | // ForwardT forwards pass through the model. 58 | // 59 | // Params: 60 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 61 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 62 | // - `mask`: optional mask of shape (batch size, sequence length). 63 | // Masked position have value 0, non-masked value 1. If None set to 1. 64 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 65 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 66 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 67 | // If None, will be incremented from 0. 68 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 69 | // If None, input ids must be provided (see `inputIds`). 70 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size). 71 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, 72 | // used in the cross-attention layer as keys and values (query from the decoder). 73 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length). 74 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values. 75 | // Positions with value 0 will be masked. 76 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 
77 | // 78 | // Returns: 79 | // - `output`: tensor of shape (batch size, sequence length, hidden size) 80 | // - `pooledOutput`: tensor of shape (batch size, hidden size) 81 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 82 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 83 | func (b *BertModel) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal1, retVal2 *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor, err error) { 84 | 85 | var ( 86 | inputShape []int64 87 | device gotch.Device 88 | ) 89 | 90 | if inputIds.MustDefined() { 91 | if inputEmbeds.MustDefined() { 92 | err = fmt.Errorf("Only one of input ids or input embeddings may be set\n") 93 | return 94 | } 95 | inputShape = inputIds.MustSize() 96 | device = inputIds.MustDevice() 97 | } else { 98 | if inputEmbeds.MustDefined() { 99 | size := inputEmbeds.MustSize() 100 | inputShape = []int64{size[0], size[1]} 101 | device = inputEmbeds.MustDevice() 102 | } else { 103 | err = fmt.Errorf("At least one of input ids or input embeddings must be set\n") 104 | return 105 | } 106 | } 107 | 108 | var maskTs *ts.Tensor 109 | if mask.MustDefined() { 110 | maskTs = mask 111 | } else { 112 | maskTs = ts.MustOnes(inputShape, gotch.Int64, device) 113 | } 114 | 115 | var extendedAttentionMask *ts.Tensor 116 | switch maskTs.Dim() { 117 | case 3: 118 | extendedAttentionMask = maskTs.MustUnsqueeze(1, false) // TODO: check and delete maskTs if not using later 119 | case 2: 120 | if b.IsDecoder { 121 | seqIds := ts.MustArange(ts.IntScalar(inputShape[1]), gotch.Float, device) 122 | causalMaskTmp := seqIds.MustUnsqueeze(0, false).MustUnsqueeze(0, true).MustRepeat([]int64{inputShape[0], inputShape[1], 1}, true) 123 | causalMask := causalMaskTmp.MustLeTensor(seqIds.MustUnsqueeze(0, true).MustUnsqueeze(1, true), true) 124 | extendedAttentionMask = causalMask.MustMatmul(mask.MustUnsqueeze(1, false).MustUnsqueeze(1, true), true) 125 | } else { 126 | extendedAttentionMask = maskTs.MustUnsqueeze(1, false).MustUnsqueeze(1, true) 127 | } 128 | 129 | default: 130 | err = fmt.Errorf("Invalid attention mask dimension, must be 2 or 3, got %v\n", maskTs.Dim()); return // return here so the nil extended mask is not dereferenced below 131 | } 132 | 133 | extendedAttnMask := extendedAttentionMask.MustOnesLike(false).MustSub(extendedAttentionMask, true).MustMulScalar(ts.FloatScalar(-10000.0), true) 134 | 135 | // NOTE. 
encoderExtendedAttentionMask is an optional tensor 136 | var encoderExtendedAttentionMask *ts.Tensor 137 | if b.IsDecoder && encoderHiddenStates.MustDefined() { 138 | size := encoderHiddenStates.MustSize() 139 | var encoderMaskTs *ts.Tensor 140 | if encoderMask.MustDefined() { 141 | encoderMaskTs = encoderMask 142 | } else { 143 | encoderMaskTs = ts.MustOnes([]int64{size[0], size[1]}, gotch.Int64, device) 144 | } 145 | 146 | switch encoderMaskTs.Dim() { 147 | case 2: 148 | encoderExtendedAttentionMask = encoderMaskTs.MustUnsqueeze(1, true).MustUnsqueeze(1, true) 149 | case 3: 150 | encoderExtendedAttentionMask = encoderMaskTs.MustUnsqueeze(1, true) 151 | default: 152 | err = fmt.Errorf("Invalid encoder attention mask dimension, must be 2, or 3 got %v\n", encoderMaskTs.Dim()) 153 | return 154 | } 155 | } else { 156 | encoderExtendedAttentionMask = ts.None 157 | } 158 | 159 | embeddingOutput, err := b.Embeddings.ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds, train) 160 | if err != nil { 161 | return 162 | } 163 | 164 | hiddenState, allHiddenStates, allAttentions := b.Encoder.ForwardT(embeddingOutput, extendedAttnMask, encoderHiddenStates, encoderExtendedAttentionMask, train) 165 | 166 | pooledOutput := b.Pooler.Forward(hiddenState) 167 | 168 | return hiddenState, pooledOutput, allHiddenStates, allAttentions, nil 169 | } 170 | 171 | // BertPredictionHeadTransform: 172 | // ============================ 173 | 174 | // BertPredictionHeadTransform holds layers of BERT prediction head transform. 175 | type BertPredictionHeadTransform struct { 176 | Dense *nn.Linear 177 | Activation util.ActivationFn 178 | LayerNorm *nn.LayerNorm 179 | } 180 | 181 | // NewBertPredictionHead creates BertPredictionHeadTransform. 182 | func NewBertPredictionHeadTransform(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertPredictionHeadTransform { 183 | changeName := true 184 | if len(changeNameOpt) > 0 { 185 | changeName = changeNameOpt[0] 186 | } 187 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig()) 188 | activation, ok := util.ActivationFnMap[config.HiddenAct] 189 | if !ok { 190 | log.Fatalf("Unsupported activation function - %v\n", config.HiddenAct) 191 | } 192 | 193 | lnConfig := nn.DefaultLayerNormConfig() 194 | if changeName { 195 | lnConfig.WsName = "gamma" 196 | lnConfig.BsName = "beta" 197 | } 198 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, lnConfig) 199 | 200 | return &BertPredictionHeadTransform{dense, activation, layerNorm} 201 | } 202 | 203 | // Forward forwards through the model. 204 | func (bpht *BertPredictionHeadTransform) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) { 205 | tmp1 := hiddenStates.Apply(bpht.Dense) 206 | tmp2 := bpht.Activation.Fwd(tmp1) 207 | retVal = tmp2.Apply(bpht.LayerNorm) 208 | tmp1.MustDrop() 209 | tmp2.MustDrop() 210 | 211 | return retVal 212 | } 213 | 214 | // BertLMPredictionHead: 215 | // ===================== 216 | 217 | // BertLMPredictionHead constructs layers for BERT prediction head. 218 | type BertLMPredictionHead struct { 219 | Transform *BertPredictionHeadTransform 220 | Decoder *util.LinearNoBias 221 | Bias *ts.Tensor 222 | } 223 | 224 | // NewBertLMPredictionHead creates BertLMPredictionHead. 
225 | func NewBertLMPredictionHead(p *nn.Path, config *BertConfig) (*BertLMPredictionHead, error) { 226 | path := p.Sub("predictions") 227 | transform := NewBertPredictionHeadTransform(path.Sub("transform"), config) 228 | decoder, err := util.NewLinearNoBias(path.Sub("decoder"), config.HiddenSize, config.VocabSize, util.DefaultLinearNoBiasConfig()) 229 | if err != nil { 230 | return nil, err 231 | } 232 | bias, err := path.NewVar("bias", []int64{config.VocabSize}, nn.NewKaimingUniformInit()) 233 | if err != nil { 234 | return nil, err 235 | } 236 | 237 | return &BertLMPredictionHead{transform, decoder, bias}, nil 238 | } 239 | 240 | // Forward fowards through the model. 241 | func (ph *BertLMPredictionHead) Forward(hiddenState *ts.Tensor) *ts.Tensor { 242 | fwTensor := ph.Transform.Forward(hiddenState).Apply(ph.Decoder) 243 | 244 | retVal := fwTensor.MustAdd(ph.Bias, false) 245 | fwTensor.MustDrop() 246 | 247 | return retVal 248 | } 249 | 250 | // BertForMaskedLM: 251 | // ================ 252 | 253 | // BertForMaskedLM is BERT for masked language model 254 | type BertForMaskedLM struct { 255 | bert *BertModel 256 | cls *BertLMPredictionHead 257 | } 258 | 259 | // NewBertForMaskedLM creates BertForMaskedLM. 260 | func NewBertForMaskedLM(p *nn.Path, config *BertConfig, changeNameOpt ...bool) (*BertForMaskedLM, error) { 261 | changeName := true 262 | if len(changeNameOpt) > 0 { 263 | changeName = changeNameOpt[0] 264 | } 265 | bert := NewBertModel(p.Sub("bert"), config, changeName) 266 | cls, err := NewBertLMPredictionHead(p.Sub("cls"), config) 267 | if err != nil { 268 | return nil, err 269 | } 270 | 271 | return &BertForMaskedLM{bert, cls}, nil 272 | } 273 | 274 | // Load loads model from file or model name. It also updates 275 | // default configuration parameters if provided. 276 | // This method implements `PretrainedModel` interface. 277 | func (mlm *BertForMaskedLM) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error { 278 | vs := nn.NewVarStore(device) 279 | p := vs.Root() 280 | mlm.bert = NewBertModel(p.Sub("bert"), config.(*BertConfig)) 281 | var err error 282 | mlm.cls, err = NewBertLMPredictionHead(p.Sub("cls"), config.(*BertConfig)) 283 | if err != nil { 284 | return err 285 | } 286 | 287 | // err = vs.Load(cachedFile) 288 | err = pickle.LoadAll(vs, modelNameOrPath) 289 | if err != nil { 290 | log.Fatalf("Load model weight error: \n%v", err) 291 | } 292 | 293 | return nil 294 | } 295 | 296 | // ForwardT forwards pass through the model. 297 | // 298 | // Params: 299 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 300 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 301 | // - `mask`: optional mask of shape (batch size, sequence length). 302 | // Masked position have value 0, non-masked value 1. If None set to 1. 303 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 304 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 305 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 306 | // If None, will be incremented from 0. 307 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 308 | // If None, input ids must be provided (see `inputIds`). 309 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size). 
310 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, 311 | // used in the cross-attention layer as keys and values (query from the decoder). 312 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length). 313 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values. 314 | // Positions with value 0 will be masked. 315 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 316 | // 317 | // Returns: 318 | // - `output`: tensor of shape (batch size, sequence length, hidden size) 319 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 320 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 321 | func (mlm *BertForMaskedLM) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal1 *ts.Tensor, optRetVal1, optRetVal2 []ts.Tensor) { 322 | 323 | hiddenState, _, allHiddenStates, allAttentions, err := mlm.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask, train) 324 | if err != nil { 325 | log.Fatal(err) 326 | } 327 | 328 | predictionScores := mlm.cls.Forward(hiddenState) 329 | 330 | return predictionScores, allHiddenStates, allAttentions 331 | } 332 | 333 | // BERT for sequence classification: 334 | // ================================= 335 | 336 | // BertForSequenceClassification is Base BERT model with a classifier head to perform 337 | // sentence or document-level classification. 338 | // 339 | // It is made of the following blocks: 340 | // - `bert`: Base BertModel 341 | // - `classifier`: BERT linear layer for classification 342 | type BertForSequenceClassification struct { 343 | bert *BertModel 344 | dropout *util.Dropout 345 | classifier *nn.Linear 346 | } 347 | 348 | // NewBertForSequenceClassification creates a new `BertForSequenceClassification`. 349 | // 350 | // Params: 351 | // - `p`: ariable store path for the root of the BertForSequenceClassification model 352 | // - `config`: `BertConfig` object defining the model architecture and number of classes 353 | // 354 | // Example: 355 | // 356 | // device := gotch.CPU 357 | // vs := nn.NewVarStore(device) 358 | // config := bert.ConfigFromFile("path/to/config.json") 359 | // p := vs.Root() 360 | // bert := NewBertForSequenceClassification(p.Sub("bert"), config) 361 | func NewBertForSequenceClassification(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForSequenceClassification { 362 | changeName := true 363 | if len(changeNameOpt) > 0 { 364 | changeName = changeNameOpt[0] 365 | } 366 | bert := NewBertModel(p.Sub("bert"), config, changeName) 367 | dropout := util.NewDropout(config.HiddenDropoutProb) 368 | numLabels := len(config.Id2Label) 369 | 370 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig()) 371 | 372 | return &BertForSequenceClassification{ 373 | bert: bert, 374 | dropout: dropout, 375 | classifier: classifier, 376 | } 377 | } 378 | 379 | // ForwardT forwards pass through the model. 380 | // 381 | // Params: 382 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 
383 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 384 | // - `mask`: optional mask of shape (batch size, sequence length). 385 | // Masked position have value 0, non-masked value 1. If None set to 1. 386 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 387 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 388 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 389 | // If None, will be incremented from 0. 390 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 391 | // If None, input ids must be provided (see `inputIds`). 392 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size). 393 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, 394 | // used in the cross-attention layer as keys and values (query from the decoder). 395 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length). 396 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values. 397 | // Positions with value 0 will be masked. 398 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 399 | // 400 | // Returns: 401 | // - `output`: tensor of shape (batch size, sequence length, hidden size) 402 | // - `pooledOutput`: tensor of shape (batch size, hidden size) 403 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 404 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 405 | func (bsc *BertForSequenceClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) { 406 | _, pooledOutput, allHiddenStates, allAttentions, err := bsc.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 407 | 408 | if err != nil { 409 | log.Fatalf("call bert.ForwardT error: %v", err) 410 | } 411 | 412 | dropoutOutput := pooledOutput.ApplyT(bsc.dropout, train) 413 | 414 | output := dropoutOutput.Apply(bsc.classifier) 415 | dropoutOutput.MustDrop() 416 | 417 | return output, allHiddenStates, allAttentions 418 | } 419 | 420 | // BERT for multiple choices : 421 | // =========================== 422 | 423 | // BertForMultipleChoice constructs multiple choices model using a BERT base model and a linear classifier. 424 | // Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis, 425 | // assuming all elements of the batch are alternatives to be chosen from for a given context. 426 | // 427 | // It is made of the following blocks: 428 | // - `bert`: Base BertModel 429 | // - `classifier`: Linear layer for multiple choices 430 | type BertForMultipleChoice struct { 431 | bert *BertModel 432 | dropout *util.Dropout 433 | classifier *nn.Linear 434 | } 435 | 436 | // NewBertForMultipleChoice creates a new `BertForMultipleChoice`. 
437 | // 438 | // Params: 439 | // - `p`: Variable store path for the root of the BertForMultipleChoice model 440 | // - `config`: `BertConfig` object defining the model architecture 441 | func NewBertForMultipleChoice(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForMultipleChoice { 442 | changeName := true 443 | if len(changeNameOpt) > 0 { 444 | changeName = changeNameOpt[0] 445 | } 446 | bert := NewBertModel(p.Sub("bert"), config, changeName) 447 | dropout := util.NewDropout(config.HiddenDropoutProb) 448 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, 1, nn.DefaultLinearConfig()) 449 | 450 | return &BertForMultipleChoice{ 451 | bert: bert, 452 | dropout: dropout, 453 | classifier: classifier, 454 | } 455 | } 456 | 457 | // ForwardT forwards pass through the model. 458 | // 459 | // Params: 460 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 461 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 462 | // - `mask`: optional mask of shape (batch size, sequence length). 463 | // Masked position have value 0, non-masked value 1. If None set to 1. 464 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 465 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 466 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 467 | // If None, will be incremented from 0. 468 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 469 | // 470 | // Returns: 471 | // - `output`: tensor of shape (batch size, sequence length, hidden size) 472 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 473 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 474 | func (mc *BertForMultipleChoice) ForwardT(inputIds, mask, tokenTypeIds, positionIds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) { 475 | inputIdsSize := inputIds.MustSize() 476 | fmt.Printf("inputIdsSize: %v\n", inputIdsSize) 477 | numChoices := inputIdsSize[1] 478 | inputIdsView := inputIds.MustView([]int64{-1, inputIdsSize[len(inputIdsSize)-1]}, false) 479 | 480 | maskView := ts.None 481 | if mask.MustDefined() { 482 | maskSize := mask.MustSize() 483 | maskView = mask.MustView([]int64{-1, maskSize[len(maskSize)-1]}, false) 484 | } 485 | 486 | tokenTypeIdsView := ts.None 487 | if tokenTypeIds.MustDefined() { 488 | tokenTypeIdsSize := tokenTypeIds.MustSize() 489 | tokenTypeIdsView = tokenTypeIds.MustView([]int64{-1, tokenTypeIdsSize[len(tokenTypeIdsSize)-1]}, false) 490 | } 491 | 492 | positionIdsView := ts.None 493 | if positionIds.MustDefined() { 494 | positionIdsSize := positionIds.MustSize() 495 | positionIdsView = positionIds.MustView([]int64{-1, positionIdsSize[len(positionIdsSize)-1]}, false) 496 | } 497 | 498 | _, pooledOutput, allHiddenStates, allAttentions, err := mc.bert.ForwardT(inputIdsView, maskView, tokenTypeIdsView, positionIdsView, ts.None, ts.None, ts.None, train) 499 | if err != nil { 500 | log.Fatalf("Call 'BertForMultipleChoice ForwordT' method error: %v\n", err) 501 | } 502 | 503 | outputDropout := pooledOutput.ApplyT(mc.dropout, train) 504 | outputClassifier := outputDropout.Apply(mc.classifier) 505 | 506 | output := outputClassifier.MustView([]int64{-1, numChoices}, false) 507 | 508 | 
outputDropout.MustDrop() 509 | outputClassifier.MustDrop() 510 | 511 | return output, allHiddenStates, allAttentions 512 | } 513 | 514 | // BERT for token classification (e.g., NER, POS): 515 | // =============================================== 516 | 517 | // BertForTokenClassification constructs token-level classifier predicting a label for each token provided. 518 | // Note that because of wordpiece tokenization, the labels predicted are not necessarily aligned with words in the sentence. 519 | // 520 | // It is made of the following blocks: 521 | // - `bert`: Base BertModel 522 | // - `classifier`: Linear layer for token classification 523 | type BertForTokenClassification struct { 524 | bert *BertModel 525 | dropout *util.Dropout 526 | classifier *nn.Linear 527 | } 528 | 529 | // NewBertForTokenClassification creates a new `BertForTokenClassification` 530 | // 531 | // Params: 532 | // - `p`: Variable store path for the root of the BertForTokenClassification model 533 | // - `config`: `BertConfig` object defining the model architecture, number of output labels and label mapping 534 | func NewBertForTokenClassification(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForTokenClassification { 535 | changeName := true 536 | if len(changeNameOpt) > 0 { 537 | changeName = changeNameOpt[0] 538 | } 539 | bert := NewBertModel(p.Sub("bert"), config, changeName) 540 | dropout := util.NewDropout(config.HiddenDropoutProb) 541 | 542 | numLabels := len(config.Id2Label) 543 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig()) 544 | 545 | return &BertForTokenClassification{ 546 | bert: bert, 547 | dropout: dropout, 548 | classifier: classifier, 549 | } 550 | } 551 | 552 | // ForwordT forwards pass through the model. 553 | // 554 | // Params: 555 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 556 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 557 | // - `mask`: optional mask of shape (batch size, sequence length). 558 | // Masked position have value 0, non-masked value 1. If None set to 1. 559 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 560 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 561 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 562 | // If None, will be incremented from 0. 563 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 564 | // If None, input ids must be provided (see `inputIds`). 565 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 
566 | // 567 | // Returns: 568 | // - `output`: tensor of shape (batch size, sequence length, hidden size) 569 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 570 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 571 | func (tc *BertForTokenClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) { 572 | 573 | hiddenState, _, allHiddenStates, allAttentions, err := tc.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 574 | if err != nil { 575 | log.Fatalf("Call 'BertForTokenClassification ForwardT' method error: %v\n", err) 576 | } 577 | 578 | outputDropout := hiddenState.ApplyT(tc.dropout, train) 579 | output := outputDropout.Apply(tc.classifier) 580 | 581 | outputDropout.MustDrop() 582 | 583 | return output, allHiddenStates, allAttentions 584 | } 585 | 586 | // BERT for question answering: 587 | // ============================ 588 | 589 | // BertForQuestionAnswering constructs extractive question-answering model based on a BERT language model. Identifies the segment of a context that answers a provided question. 590 | // 591 | // Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering. 592 | // See the question answering pipeline (also provided in this crate) for more details. 593 | // 594 | // It is made of the following blocks: 595 | // - `bert`: Base BertModel 596 | // - `qa_outputs`: Linear layer for question answering 597 | type BertForQuestionAnswering struct { 598 | bert *BertModel 599 | qaOutputs *nn.Linear 600 | } 601 | 602 | // NewBertForQuestionAnswering creates a new `BertForQuestionAnswering`. 603 | // 604 | // Params: 605 | // - `p`: Variable store path for the root of the BertForQuestionAnswering model 606 | // - `config`: `BertConfig` object defining the model architecture 607 | func NewForBertQuestionAnswering(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForQuestionAnswering { 608 | changeName := true 609 | if len(changeNameOpt) > 0 { 610 | changeName = changeNameOpt[0] 611 | } 612 | bert := NewBertModel(p.Sub("bert"), config, changeName) 613 | 614 | numLabels := 2 615 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig()) 616 | 617 | return &BertForQuestionAnswering{ 618 | bert: bert, 619 | qaOutputs: qaOutputs, 620 | } 621 | } 622 | 623 | // ForwardT forwards pass through the model. 624 | // 625 | // Params: 626 | // - `inputIds`: optional input tensor of shape (batch size, sequence length). 627 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`) 628 | // - `mask`: optional mask of shape (batch size, sequence length). 629 | // Masked position have value 0, non-masked value 1. If None set to 1. 630 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length). 631 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0. 632 | // - `positionIds`: optional position ids of shape (batch size, sequence length). 633 | // If None, will be incremented from 0. 634 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size). 635 | // If None, input ids must be provided (see `inputIds`). 
636 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 637 | // 638 | // Returns: 639 | // - `startScores`, `endScores`: tensors of shape (batch size, sequence length) holding the logits for the answer start and end positions 640 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 641 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize) 642 | func (qa *BertForQuestionAnswering) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal1, retVal2 *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) { 643 | 644 | hiddenState, _, allHiddenStates, allAttentions, err := qa.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train) 645 | if err != nil { 646 | log.Fatalf("Call 'BertForQuestionAnswering ForwardT' method error: %v\n", err) 647 | } 648 | 649 | sequenceOutput := hiddenState.Apply(qa.qaOutputs) 650 | logits := sequenceOutput.MustSplit(1, -1, true) // split along the last dimension (size 2) into start/end logits 651 | startLogits := logits[0].MustSqueezeDim(int64(-1), false) 652 | endLogits := logits[1].MustSqueezeDim(int64(-1), false) 653 | 654 | for _, x := range logits { 655 | x.MustDrop() 656 | } 657 | 658 | return startLogits, endLogits, allHiddenStates, allAttentions 659 | } 660 | --------------------------------------------------------------------------------
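The listings above define the task-specific heads but do not show how to drive one end to end. Below is a minimal usage sketch for `BertForQuestionAnswering`, following the construct-then-`vs.Load` pattern used by the `Load` methods in this file and the `bert.ConfigFromFile` call shown in the `NewBertForSequenceClassification` doc comment. The file paths, the dummy token ids, the use of `ts.MustOfSlice`, and the exact return shape of `bert.ConfigFromFile` are assumptions for illustration, not part of the repository.

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/bert"
)

func main() {
	device := gotch.CPU
	vs := nn.NewVarStore(device)

	// Placeholder paths: load a BERT config, then build the QA head under the
	// variable store root (the constructor adds the "bert" and "qa_outputs"
	// sub-paths itself).
	config := bert.ConfigFromFile("path/to/config.json")
	model := bert.NewForBertQuestionAnswering(vs.Root(), config)

	// Load pretrained weights into the variable store, mirroring the Load
	// methods defined above.
	if err := vs.Load("path/to/pytorch_model.bin"); err != nil {
		log.Fatal(err)
	}

	// A dummy batch of token ids with shape (batch size 1, sequence length 8).
	// Real ids would come from the tokenizer.
	inputIds := ts.MustOfSlice([]int64{101, 2054, 2003, 1996, 3007, 1029, 102, 0}).MustView([]int64{1, 8}, true)

	// Optional tensors are passed as ts.None; train=false disables dropout.
	startScores, endScores, _, _ := model.ForwardT(inputIds, ts.None, ts.None, ts.None, ts.None, false)

	// Each score tensor has shape (1, 8); an argmax over the sequence
	// dimension gives the predicted answer span boundaries.
	fmt.Println(startScores.MustSize(), endScores.MustSize())
}

Note that the Roberta `Load` methods above call `vs.Load` directly on the cached weight file, while `BertForMaskedLM.Load` goes through `pickle.LoadAll`; the sketch follows the former, and which loader applies in practice depends on the weight file format.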