├── example
│   ├── ner
│   │   └── main.go
│   ├── qa
│   │   └── main.go
│   ├── modeling
│   │   └── main.go
│   └── bert
│       └── main.go
├── pipeline.go
├── trainer.go
├── util
│   ├── optimization.go
│   ├── dropout.go
│   ├── tensor.go
│   ├── init.go
│   ├── linear.go
│   ├── activation.go
│   └── file-util.go
├── pretrained
│   ├── tokenizer.go
│   ├── config.go
│   ├── model.go
│   ├── bert.go
│   └── roberta.go
├── .gitignore
├── pipeline
│   ├── token-classification.go
│   ├── ner.go
│   └── common.go
├── go.mod
├── config-example_test.go
├── bert
│   ├── opt.go
│   ├── tokenizer_test.go
│   ├── config_test.go
│   ├── tokenizer.go
│   ├── embedding.go
│   ├── config.go
│   ├── example_test.go
│   ├── encoder.go
│   ├── attention.go
│   ├── model_test.go
│   └── model.go
├── CHANGELOG.md
├── .travis.yml
├── tokenizer.go
├── modeling.go
├── config.go
├── config_test.go
├── roberta
│   ├── tokenizer.go
│   ├── embedding.go
│   ├── model_test.go
│   └── model.go
├── modeling_test.go
├── go.sum
├── coverage.out
├── README.md
└── LICENSE
/example/ner/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
--------------------------------------------------------------------------------
/example/qa/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
--------------------------------------------------------------------------------
/pipeline.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
--------------------------------------------------------------------------------
/trainer.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
--------------------------------------------------------------------------------
/util/optimization.go:
--------------------------------------------------------------------------------
1 | package util
2 |
--------------------------------------------------------------------------------
/pretrained/tokenizer.go:
--------------------------------------------------------------------------------
1 | package pretrained
2 |
3 | type Tokenizer interface {
4 | Load(modelNameOrPath string, params map[string]interface{}) error
5 | }
6 |
--------------------------------------------------------------------------------
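A minimal sketch of a type that would satisfy this interface; the `myTokenizer` type and its vocab field are hypothetical and used only for illustration, not part of this repository:

package mytokenizer

// myTokenizer is a hypothetical tokenizer used only to illustrate the
// pretrained.Tokenizer interface.
type myTokenizer struct {
	vocab map[string]int
}

// Load satisfies pretrained.Tokenizer: resolve modelNameOrPath (model name,
// local path, or URL) to a vocab file and populate the tokenizer state.
func (t *myTokenizer) Load(modelNameOrPath string, params map[string]interface{}) error {
	if t.vocab == nil {
		t.vocab = make(map[string]int)
	}
	// ... read and parse the vocab file here ...
	return nil
}
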
/.gitignore:
--------------------------------------------------------------------------------
1 | .directory
2 | *.swp
3 | *.swo
4 |
5 | input/
6 | example/gorgonia/testdata/
7 | example/testdata/
8 | *.dat
9 | *.log
10 |
11 | *.txt
12 | *.json
13 | *.bak
14 |
15 | data/
16 |
--------------------------------------------------------------------------------
/example/modeling/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | // "github.com/sugarme/tokenizer"
5 | )
6 |
7 | type Trainer struct {
8 | Dataset interface{}
9 | Tokenizer interface{}
10 | }
11 |
--------------------------------------------------------------------------------
/pretrained/config.go:
--------------------------------------------------------------------------------
1 | package pretrained
2 |
3 | // Config is an interface for pretrained model configuration.
4 | // It has a single method, `Load`, which loads configuration data
5 | // from a local or remote file.
6 | type Config interface {
7 | Load(modelNameOrPath string, params map[string]interface{}) error
8 | }
9 |
--------------------------------------------------------------------------------
/pipeline/token-classification.go:
--------------------------------------------------------------------------------
1 | package pipeline
2 |
3 | // Token classification pipeline (Named Entity Recognition, Part-of-Speech tagging).
4 | // A generic token classification pipeline that works with multiple models (Bert, Roberta).
5 |
6 | import (
7 |
8 | // "github.com/sugarme/gotch/nn"
9 | )
10 |
11 | type TokenClassificationModel struct{}
12 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/sugarme/transformer
2 |
3 | go 1.13
4 |
5 | require (
6 | github.com/sugarme/gotch v0.7.0
7 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect
8 | github.com/sugarme/tokenizer v0.1.17
9 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 // indirect
10 | golang.org/x/text v0.3.3 // indirect
11 | )
12 |
--------------------------------------------------------------------------------
/pretrained/model.go:
--------------------------------------------------------------------------------
1 | package pretrained
2 |
3 | import (
4 | "github.com/sugarme/gotch"
5 | )
6 |
7 | // Model is an interface for a pretrained model.
8 | // It has a single method, `Load`, which loads model weights
9 | // from a local or remote file.
10 | type Model interface {
11 | Load(modelNameOrPath string, config interface{ Config }, params map[string]interface{}, device gotch.Device) error
12 | }
13 |
--------------------------------------------------------------------------------
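A minimal sketch of a type satisfying this interface; the `myModel` type is hypothetical and only illustrates the expected Load signature:

package mymodel

import (
	"github.com/sugarme/gotch"

	"github.com/sugarme/transformer/pretrained"
)

// myModel is a hypothetical model used only to illustrate the
// pretrained.Model interface; it is not part of this repository.
type myModel struct{}

// Load satisfies pretrained.Model: resolve modelNameOrPath to a weights file,
// build the network from config, and load the weights onto device.
func (m *myModel) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
	// ... build the network and load weights here ...
	return nil
}
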
/util/dropout.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "github.com/sugarme/gotch/ts"
5 | )
6 |
7 | type Dropout struct {
8 | dropoutProb float64
9 | }
10 |
11 | func NewDropout(p float64) *Dropout {
12 | return &Dropout{
13 | dropoutProb: p,
14 | }
15 | }
16 |
17 | func (d *Dropout) ForwardT(input *ts.Tensor, train bool) (retVal *ts.Tensor) {
18 | return ts.MustDropout(input, d.dropoutProb, train)
19 | }
20 |
--------------------------------------------------------------------------------
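A short usage sketch for this dropout wrapper (not part of the repository; it only uses the gotch API already imported above):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/util"
)

func main() {
	drop := util.NewDropout(0.1)

	x := ts.MustOnes([]int64{2, 4}, gotch.Float, gotch.CPU)

	// train=true applies dropout; with train=false the call is an identity pass.
	out := drop.ForwardT(x, true)
	fmt.Println(out.MustSize()) // [2 4]
}
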
/config-example_test.go:
--------------------------------------------------------------------------------
1 | package transformer_test
2 |
3 | import (
4 | "fmt"
5 | "log"
6 |
7 | "github.com/sugarme/transformer"
8 | "github.com/sugarme/transformer/bert"
9 | )
10 |
11 | func ExampleLoadConfig() {
12 | modelNameOrPath := "bert-base-uncased"
13 | var config bert.BertConfig
14 | err := transformer.LoadConfig(&config, modelNameOrPath, nil)
15 | if err != nil {
16 | log.Fatal(err)
17 | }
18 |
19 | fmt.Println(config.VocabSize)
20 |
21 | // Output:
22 | // 30522
23 | }
24 |
--------------------------------------------------------------------------------
/util/tensor.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "reflect"
5 |
6 | "github.com/sugarme/gotch/ts"
7 | )
8 |
9 | // Equal compares two tensors in terms of shape and element values.
10 | func Equal(tensorA, tensorB *ts.Tensor) bool {
11 | var equal int64 = 0
12 | // 1. Compare shape
13 | if reflect.DeepEqual(tensorA.MustSize(), tensorB.MustSize()) {
14 | // 2. Compare values
15 | equal = tensorA.MustEqTensor(tensorB, false).MustAll(false).Int64Values()[0]
16 | }
17 | if equal == 0 {
18 | return false
19 | }
20 | return true
21 | }
22 |
--------------------------------------------------------------------------------
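A quick usage sketch for `Equal` (not part of the repository):

package main

import (
	"fmt"

	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/util"
)

func main() {
	a := ts.TensorFrom([]int64{1, 2, 3})
	b := ts.TensorFrom([]int64{1, 2, 3})
	c := ts.TensorFrom([]int64{1, 2, 4})

	fmt.Println(util.Equal(a, b)) // true: same shape, same values
	fmt.Println(util.Equal(a, c)) // false: same shape, different values
}
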
/bert/opt.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "github.com/sugarme/gotch/ts"
5 | )
6 |
7 | // TensorOpt is a function type that returns a pointer to a tensor.
8 | type TensorOpt func() *ts.Tensor
9 |
10 | func MaskTensorOpt(t *ts.Tensor) TensorOpt {
11 | return func() *ts.Tensor {
12 | return t
13 | }
14 | }
15 |
16 | func EncoderMaskTensorOpt(t *ts.Tensor) TensorOpt {
17 | return func() *ts.Tensor {
18 | return t
19 | }
20 | }
21 |
22 | func EncoderHiddenStateTensorOpt(t *ts.Tensor) TensorOpt {
23 | return func() *ts.Tensor {
24 | return t
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
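A minimal sketch of how these functional options might be consumed; the `forwardWithMask` helper is hypothetical, the real consumers live in the model code, which is not shown here:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/bert"
)

// forwardWithMask is a hypothetical consumer: it evaluates the option to
// retrieve the optional mask tensor captured by MaskTensorOpt.
func forwardWithMask(opt bert.TensorOpt) *ts.Tensor {
	return opt()
}

func main() {
	mask := ts.MustOnes([]int64{1, 8}, gotch.Int64, gotch.CPU)

	opt := bert.MaskTensorOpt(mask)
	fmt.Println(forwardWithMask(opt).MustSize()) // [1 8]
}
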
/bert/tokenizer_test.go:
--------------------------------------------------------------------------------
1 | package bert_test
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/sugarme/transformer/bert"
8 | )
9 |
10 | func TestBertTokenizer(t *testing.T) {
11 | var tk *bert.Tokenizer = bert.NewTokenizer()
12 | err := tk.Load("bert-base-uncased", nil)
13 | if err != nil {
14 | t.Error(err)
15 | }
16 |
17 | gotVocabSize := tk.GetVocabSize(false)
18 | wantVocabSize := 30522
19 |
20 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
21 | t.Errorf("Want %v\n", wantVocabSize)
22 | t.Errorf("Got %v\n", gotVocabSize)
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 |
4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6 |
7 | ## [Unreleased]
8 |
9 | ### Fixed
10 | - [#...]: Fix a bug with...
11 |
12 | ### Changed
13 | - [#...]:
14 |
15 | ### Added
16 | - [#...]:
17 |
18 |
19 | ## [0.1.2]
20 |
21 | ### Added
22 | - [#2]: Added Roberta model
23 |
24 | ### Changed
25 | - Updated comment document for Bert model.
26 |
27 | ### Fixed
28 | - None
29 |
30 |
--------------------------------------------------------------------------------
/util/init.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "os"
7 | )
8 |
9 | var (
10 | CachedDir string = "NOT_SETTING"
11 | transformerEnvKey string = "GO_TRANSFORMER"
12 | )
13 |
14 | func init() {
15 | // default path: {$HOME}/.cache/transformer
16 | homeDir := os.Getenv("HOME")
17 | CachedDir = fmt.Sprintf("%s/.cache/transformer", homeDir)
18 |
19 | initEnv()
20 |
21 | log.Printf("INFO: CachedDir=%q\n", CachedDir)
22 | }
23 |
24 | func initEnv() {
25 | val := os.Getenv(transformerEnvKey)
26 | if val != "" {
27 | CachedDir = val
28 | }
29 |
30 | if _, err := os.Stat(CachedDir); os.IsNotExist(err) {
31 | if err := os.MkdirAll(CachedDir, 0755); err != nil {
32 | log.Fatal(err)
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
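A small sketch showing how the resolved cache directory can be inspected at run time; set the GO_TRANSFORMER environment variable before starting the program to override the default location:

package main

import (
	"fmt"

	"github.com/sugarme/transformer/util"
)

func main() {
	// util's init() has already run by now: if GO_TRANSFORMER was set in the
	// environment, CachedDir points there; otherwise it defaults to
	// $HOME/.cache/transformer and the directory is created if missing.
	fmt.Println(util.CachedDir)
}
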
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - 1.14.x
5 |
6 | env:
7 | - GO111MODULE=on
8 |
9 | branches:
10 | only:
11 | - master
12 |
13 | dist: bionic
14 |
15 | before_install:
16 | - sudo apt-get install clang-tools-9
17 | - wget -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.5.1+cpu.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.5.1%2Bcpu.zip
18 | - unzip /tmp/libtorch-cxx11-abi-shared-with-deps-1.5.1+cpu.zip -d /opt
19 | - export LIBTORCH=/opt/libtorch
20 | - export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
21 | - printenv
22 | - ls
23 | - rm libtch/dummy_cuda_dependency.cpp
24 | - mv libtch/fake_cuda_dependency.cpp.cpu libtch/fake_cuda_dependency.cpp
25 | - rm libtch/lib.go
26 | - mv libtch/lib.go.cpu libtch/lib.go
27 | script:
28 | - go get -u ./...
29 | - go test -v github.com/sugarme/transformer
30 | - go test -v github.com/sugarme/transformer/bert
31 |
--------------------------------------------------------------------------------
/tokenizer.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
3 | import (
4 | "github.com/sugarme/transformer/pretrained"
5 | )
6 |
7 | // LoadTokenizer loads pretrained tokenizer from local or remote file.
8 | //
9 | // Parameters:
10 | // - `tk` pretrained.Tokenizer (any tokenizer model that implements pretrained `Tokenizer` interface)
11 | // - `modelNameOrPath` is a string of either
12 | // + Model name or
13 | // + File name or path or
14 | // + URL to remote file
15 | // If `modelNameOrPath` is resolved, the function caches data in the directory set by the
16 | // `GO_TRANSFORMER` environment variable if present, otherwise in the `$HOME/.cache/transformer/` directory.
17 | // If `modelNameOrPath` is a valid URL, the file will be downloaded and cached.
18 | // Finally, vocab data will be loaded into `tk`.
19 | func LoadTokenizer(tk pretrained.Tokenizer, modelNameOrPath string, customParams map[string]interface{}) error {
20 | return tk.Load(modelNameOrPath, customParams)
21 | }
22 |
--------------------------------------------------------------------------------
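A usage sketch for LoadTokenizer with the BERT tokenizer defined in the bert package; the expected vocab size matches the test in bert/tokenizer_test.go:

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/transformer"
	"github.com/sugarme/transformer/bert"
)

func main() {
	tk := bert.NewTokenizer()

	// *bert.Tokenizer implements pretrained.Tokenizer, so it can be passed here.
	if err := transformer.LoadTokenizer(tk, "bert-base-uncased", nil); err != nil {
		log.Fatal(err)
	}

	fmt.Println(tk.GetVocabSize(false)) // 30522
}
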
/util/linear.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "github.com/sugarme/gotch/nn"
5 | "github.com/sugarme/gotch/ts"
6 | )
7 |
8 | type LinearNoBiasConfig struct {
9 | WsInit nn.Init // interface
10 | }
11 |
12 | func DefaultLinearNoBiasConfig() *LinearNoBiasConfig {
13 |
14 | init := nn.NewKaimingUniformInit()
15 |
16 | return &LinearNoBiasConfig{WsInit: init}
17 | }
18 |
19 | type LinearNoBias struct {
20 | Ws *ts.Tensor
21 | }
22 |
23 | func NewLinearNoBias(vs *nn.Path, inDim, outDim int64, config *LinearNoBiasConfig) (*LinearNoBias, error) {
24 |
25 | ws, err := vs.NewVar("weight", []int64{outDim, inDim}, config.WsInit)
26 | if err != nil {
27 | return nil, err
28 | }
29 |
30 | return &LinearNoBias{
31 | Ws: ws,
32 | }, nil
33 | }
34 |
35 | // Forward implements Module interface for LinearNoBias
36 | func (lnb *LinearNoBias) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
37 | wsT := lnb.Ws.MustT(false)
38 | retVal = xs.MustMatmul(wsT, false)
39 | wsT.MustDrop()
40 |
41 | return retVal
42 | }
43 |
--------------------------------------------------------------------------------
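A short usage sketch for LinearNoBias (not part of the repository; tensor shapes are illustrative):

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/util"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)

	// Create a weight-only linear layer projecting 8 features to 4.
	lin, err := util.NewLinearNoBias(vs.Root().Sub("proj"), 8, 4, util.DefaultLinearNoBiasConfig())
	if err != nil {
		log.Fatal(err)
	}

	x := ts.MustOnes([]int64{2, 8}, gotch.Float, gotch.CPU)
	y := lin.Forward(x) // x [2 8] matmul Ws^T [8 4] -> [2 4]
	fmt.Println(y.MustSize())
}
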
/modeling.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
3 | import (
4 | "github.com/sugarme/gotch"
5 | "github.com/sugarme/transformer/pretrained"
6 | )
7 |
8 | // LoadModel loads pretrained model data from a local or remote file.
9 | //
10 | // Parameters:
11 | // - `model` pretrained Model (any model type that implements pretrained `Model` interface)
12 | // - `modelNameOrPath` is a string of either
13 | // + Model name or
14 | // + File name or path or
15 | // + URL to remote file
16 | // If `modelNameOrPath` is resolved, the function caches data in the directory set by the
17 | // `GO_TRANSFORMER` environment variable if present, otherwise in the `$HOME/.cache/transformer/` directory.
18 | // If `modelNameOrPath` is a valid URL, the file will be downloaded and cached.
19 | // Finally, model weights will be loaded into the model's var store.
20 | func LoadModel(model pretrained.Model, modelNameOrPath string, config pretrained.Config, customParams map[string]interface{}, device gotch.Device) error {
21 | return model.Load(modelNameOrPath, config, customParams, device)
22 | }
23 |
--------------------------------------------------------------------------------
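A usage sketch tying LoadConfig and LoadModel together. It assumes bert.BertForMaskedLM implements the pretrained.Model interface; that implementation lives in bert/model.go, which is not reproduced in this section:

package main

import (
	"log"

	"github.com/sugarme/gotch"

	"github.com/sugarme/transformer"
	"github.com/sugarme/transformer/bert"
)

func main() {
	device := gotch.CPU

	config := new(bert.BertConfig)
	if err := transformer.LoadConfig(config, "bert-base-uncased", nil); err != nil {
		log.Fatal(err)
	}

	// Assumption: bert.BertForMaskedLM satisfies pretrained.Model via its Load method.
	model := new(bert.BertForMaskedLM)
	if err := transformer.LoadModel(model, "bert-base-uncased", config, nil, device); err != nil {
		log.Fatal(err)
	}
}
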
/config.go:
--------------------------------------------------------------------------------
1 | package transformer
2 |
3 | import (
4 | "github.com/sugarme/transformer/pretrained"
5 | "github.com/sugarme/transformer/util"
6 | )
7 |
8 | // LoadConfig loads pretrained configuration data from local or remote file.
9 | //
10 | // Parameters:
11 | // - `config` pretrained.Config (any model config that implements pretrained `Config` interface)
12 | // - `modelNameOrPath` is a string of either
13 | // - Model name or
14 | // - File name or path or
15 | // - URL to remote file
16 | //
17 | // If `modelNameOrPath` is resolved, the function caches data in the directory set by the
18 | // `GO_TRANSFORMER` environment variable if present, otherwise in the `$HOME/.cache/transformer/` directory.
19 | // If `modelNameOrPath` is a valid URL, the file will be downloaded and cached.
20 | // Finally, configuration data will be loaded into the `config` parameter.
21 | func LoadConfig(config pretrained.Config, modelNameOrPath string, customParams map[string]interface{}) error {
22 | configFile, err := util.CachedPath(modelNameOrPath, "config.json")
23 | if err != nil {
24 | return err
25 | }
26 | return config.Load(configFile, customParams)
27 | }
28 |
--------------------------------------------------------------------------------
/bert/config_test.go:
--------------------------------------------------------------------------------
1 | package bert_test
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/sugarme/transformer/bert"
8 | )
9 |
10 | // No custom params
11 | func TestNewBertConfig_Default(t *testing.T) {
12 |
13 | config := bert.NewConfig(nil)
14 |
15 | wantHiddenAct := "gelu"
16 | gotHiddenAct := config.HiddenAct
17 | if !reflect.DeepEqual(wantHiddenAct, gotHiddenAct) {
18 | t.Errorf("Want: '%v'\n", wantHiddenAct)
19 | t.Errorf("Got: '%v'\n", gotHiddenAct)
20 | }
21 |
22 | wantVocabSize := int64(30522)
23 | gotVocabSize := config.VocabSize
24 |
25 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
26 | t.Errorf("Want: '%v'\n", wantVocabSize)
27 | t.Errorf("Got: '%v'\n", gotVocabSize)
28 | }
29 | }
30 |
31 | // With custom params
32 | func TestNewBertConfig_Custom(t *testing.T) {
33 |
34 | config := bert.NewConfig(map[string]interface{}{"VocabSize": int64(2000), "HiddenAct": "relu"})
35 |
36 | wantHiddenAct := "relu"
37 | gotHiddenAct := config.HiddenAct
38 | if !reflect.DeepEqual(wantHiddenAct, gotHiddenAct) {
39 | t.Errorf("Want: '%v'\n", wantHiddenAct)
40 | t.Errorf("Got: '%v'\n", gotHiddenAct)
41 | }
42 |
43 | wantVocabSize := int64(2000)
44 | gotVocabSize := config.VocabSize
45 |
46 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
47 | t.Errorf("Want: '%v'\n", wantVocabSize)
48 | t.Errorf("Got: '%v'\n", gotVocabSize)
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/pipeline/ner.go:
--------------------------------------------------------------------------------
1 | package pipeline
2 |
3 | // Named Entity Recognition pipeline
4 | // Extracts entities (Person, Location, Organization, Miscellaneous) from text.
5 | // Pretrained models are available for the following languages:
6 | // - English
7 | // - German
8 | // - Spanish
9 | // - Dutch
10 | //
11 | // The default NER model is an English BERT large cased model fine-tuned on CoNLL-2003, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
12 |
13 | // Entity holds entity data generated by NERModel
14 | type Entity struct {
15 | // String representation of the Entity
16 | Word string
17 | // Confidence score
18 | Score float64
19 | // Entity label (e.g. ORG, LOC...)
20 | Label string
21 | }
22 |
23 | // NERModel is a model to extract entities
24 | type NERModel struct {
25 | tokenClassificationModel TokenClassificationModel
26 | }
27 |
28 | // NewNERModel creates a NERModel from input config
29 | func NewNERModel(config TokenClassificationModel) *NERModel {
30 | return &NERModel{
31 | tokenClassificationModel: config,
32 | }
33 | }
34 |
35 | // Predict extracts entities from input text and returns slice of entities with score
36 | func (nm *NERModel) Predict(input []string) []Entity {
37 | tokens := nm.tokenClassificationModel.Predict(input, true, false)
38 |
39 | var entities []Entity
40 | for _, tok := range tokens {
41 | if tok.Label != "O" { // skip tokens labelled "O" (outside any entity)
42 | entities = append(entities, Entity{
43 | Word: tok.Text,
44 | Score: tok.Score,
45 | Label: tok.Label,
46 | })
47 | }
48 | }
49 |
50 | return entities
51 | }
52 |
--------------------------------------------------------------------------------
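An intended-usage sketch for the NER pipeline. TokenClassificationModel is still a stub in token-classification.go above, so this assumes it will eventually expose the Predict(input []string, consolidate, includeSpecial bool) method that NERModel.Predict calls:

package main

import (
	"fmt"

	"github.com/sugarme/transformer/pipeline"
)

func main() {
	// Model setup (config, weights, tokenizer) is omitted; see token-classification.go.
	var tcModel pipeline.TokenClassificationModel

	ner := pipeline.NewNERModel(tcModel)
	entities := ner.Predict([]string{"My name is Amy. I live in Paris."})

	for _, e := range entities {
		fmt.Printf("%-10s %-6s %.3f\n", e.Word, e.Label, e.Score)
	}
}
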
/config_test.go:
--------------------------------------------------------------------------------
1 | package transformer_test
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/sugarme/transformer"
8 | "github.com/sugarme/transformer/bert"
9 | )
10 |
11 | // With model name
12 | func TestConfigFromPretrained_ModelName(t *testing.T) {
13 | modelName := "bert-base-uncased"
14 | var config *bert.BertConfig = new(bert.BertConfig)
15 | err := transformer.LoadConfig(config, modelName, nil)
16 | if err != nil {
17 | t.Error(err)
18 | }
19 |
20 | wantVocabSize := int64(30522)
21 | gotVocabSize := config.VocabSize
22 |
23 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
24 | t.Errorf("Want: %v\n", wantVocabSize)
25 | t.Errorf("Got: %v\n", gotVocabSize)
26 | }
27 | }
28 |
29 | // No custom params
30 | func TestConfigFromPretrained(t *testing.T) {
31 | var config *bert.BertConfig = new(bert.BertConfig)
32 | err := transformer.LoadConfig(config, "bert-base-uncased", nil)
33 | if err != nil {
34 | t.Error(err)
35 | }
36 |
37 | wantVocabSize := int64(30522)
38 | gotVocabSize := config.VocabSize
39 |
40 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
41 | t.Errorf("Want: %v\n", wantVocabSize)
42 | t.Errorf("Got: %v\n", gotVocabSize)
43 | }
44 |
45 | }
46 |
47 | // With custom params
48 | func TestConfigFromPretrained_CustomParams(t *testing.T) {
49 | params := map[string]interface{}{
50 | "VocabSize": int64(2000),
51 | "NumLabels": int64(4),
52 | }
53 |
54 | var config *bert.BertConfig = new(bert.BertConfig)
55 | err := transformer.LoadConfig(config, "bert-base-uncased", params)
56 | if err != nil {
57 | t.Error(err)
58 | }
59 |
60 | wantVocabSize := int64(2000)
61 | gotVocabSize := config.VocabSize
62 |
63 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
64 | t.Errorf("Want: %v\n", wantVocabSize)
65 | t.Errorf("Got: %v\n", gotVocabSize)
66 | }
67 |
68 | wantNumLabels := int64(4)
69 | gotNumLabels := config.NumLabels
70 |
71 | if !reflect.DeepEqual(wantNumLabels, gotNumLabels) {
72 | t.Errorf("Want: %v\n", wantNumLabels)
73 | t.Errorf("Got: %v\n", gotNumLabels)
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/roberta/tokenizer.go:
--------------------------------------------------------------------------------
1 | package roberta
2 |
3 | import (
4 | // "fmt"
5 |
6 | "github.com/sugarme/tokenizer"
7 | "github.com/sugarme/tokenizer/model/bpe"
8 | "github.com/sugarme/tokenizer/normalizer"
9 | "github.com/sugarme/tokenizer/pretokenizer"
10 | "github.com/sugarme/tokenizer/processor"
11 |
12 | "github.com/sugarme/transformer/util"
13 | )
14 |
15 | // Tokenizer holds data for Roberta tokenizer.
16 | type Tokenizer struct {
17 | *tokenizer.Tokenizer
18 | }
19 |
20 | // NewTokenizer creates a new Roberta tokenizer.
21 | func NewTokenizer() *Tokenizer {
22 | tk := tokenizer.NewTokenizer(nil)
23 | return &Tokenizer{tk}
24 | }
25 |
26 | // Load loads the Roberta tokenizer from pretrained vocab and merges files.
27 | func (t *Tokenizer) Load(modelNameOrPath string, params map[string]interface{}) error {
28 | vocabFile, err := util.CachedPath(modelNameOrPath, "vocab.json")
29 | if err != nil {
30 | return err
31 | }
32 | mergesFile, err := util.CachedPath(modelNameOrPath, "merges.txt")
33 | if err != nil {
34 | return err
35 | }
36 |
37 | model, err := bpe.NewBpeFromFiles(vocabFile, mergesFile)
38 | if err != nil {
39 | return err
40 | }
41 |
42 | t.WithModel(model)
43 |
44 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
45 | t.WithNormalizer(bertNormalizer)
46 |
47 | blPreTokenizer := pretokenizer.NewByteLevel()
48 | // blPreTokenizer.SetAddPrefixSpace(false)
49 | t.WithPreTokenizer(blPreTokenizer)
50 |
51 | var specialTokens []tokenizer.AddedToken
52 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<s>", true))
53 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<pad>", true))
54 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("</s>", true))
55 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<unk>", true))
56 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<mask>", true))
57 | t.AddSpecialTokens(specialTokens)
58 |
59 | postProcess := processor.DefaultRobertaProcessing()
60 | t.WithPostProcessor(postProcess)
61 |
62 | return nil
63 | }
64 |
--------------------------------------------------------------------------------
/pretrained/bert.go:
--------------------------------------------------------------------------------
1 | package pretrained
2 |
3 | // BertConfigs is a map of pretrained Bert configuration names to corresponding URLs.
4 | var BertConfigs map[string]string = map[string]string{
5 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
6 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
7 | "bert-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
8 |
9 | // Roberta
10 | "roberta-base": "https://cdn.huggingface.co/roberta-base-config.json",
11 | "roberta-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/deepset/roberta-base-squad2/config.json",
12 | "xlm-roberta-ner-en": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
13 | "xlm-roberta-ner-de": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
14 | "xlm-roberta-ner-nl": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
15 | "xlm-roberta-ner-es": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
16 | }
17 |
18 | // BertModels is a map of pretrained Bert model names to corresponding URLs.
19 | var BertModels map[string]string = map[string]string{
20 | "bert-base-uncased": "https://cdn.huggingface.co/bert-base-uncased-rust_model.ot",
21 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
22 | "bert-qa": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-rust_model.ot",
23 | }
24 |
25 | // BertVocabs is a map of BERT model vocab names to corresponding URLs.
26 | var BertVocabs map[string]string = map[string]string{
27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
28 | "bert-ner": "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
29 | "bert-qa": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
30 | }
31 |
--------------------------------------------------------------------------------
/bert/tokenizer.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/sugarme/tokenizer"
7 | "github.com/sugarme/tokenizer/model/wordpiece"
8 | "github.com/sugarme/tokenizer/normalizer"
9 | "github.com/sugarme/tokenizer/pretokenizer"
10 | "github.com/sugarme/tokenizer/processor"
11 |
12 | // "github.com/sugarme/transformer/pretrained"
13 | "github.com/sugarme/transformer/util"
14 | )
15 |
16 | type BertTokenizerFast = tokenizer.Tokenizer
17 |
18 | // BertJapaneseTokenizerFromPretrained initiates a BERT tokenizer for the Japanese language from a pretrained file.
19 | func BertJapaneseTokenizerFromPretrained(pretrainedModelNameOrPath string, customParams map[string]interface{}) *tokenizer.Tokenizer {
20 |
21 | // TODO: implement it
22 |
23 | panic("Not implemented yet.")
24 | }
25 |
26 | type Tokenizer struct {
27 | *tokenizer.Tokenizer
28 | }
29 |
30 | func NewTokenizer() *Tokenizer {
31 | tk := tokenizer.NewTokenizer(nil)
32 | return &Tokenizer{tk}
33 | }
34 |
35 | func (bt *Tokenizer) Load(modelNameOrPath string, params map[string]interface{}) error {
36 | cachedFile, err := util.CachedPath(modelNameOrPath, "vocab.txt")
37 | if err != nil {
38 | return err
39 | }
40 |
41 | model, err := wordpiece.NewWordPieceFromFile(cachedFile, "[UNK]")
42 | if err != nil {
43 | return err
44 | }
45 |
46 | bt.WithModel(model)
47 |
48 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
49 | bt.WithNormalizer(bertNormalizer)
50 |
51 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer()
52 | bt.WithPreTokenizer(bertPreTokenizer)
53 |
54 | var specialTokens []tokenizer.AddedToken
55 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true))
56 |
57 | bt.AddSpecialTokens(specialTokens)
58 |
59 | sepId, ok := bt.TokenToId("[SEP]")
60 | if !ok {
61 | return fmt.Errorf("Cannot find ID for [SEP] token.\n")
62 | }
63 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"}
64 |
65 | clsId, ok := bt.TokenToId("[CLS]")
66 | if !ok {
67 | return fmt.Errorf("Cannot find ID for [CLS] token.\n")
68 | }
69 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"}
70 |
71 | postProcess := processor.NewBertProcessing(sep, cls)
72 | bt.WithPostProcessor(postProcess)
73 |
74 | // TODO: update params
75 |
76 | return nil
77 | }
78 |
--------------------------------------------------------------------------------
/modeling_test.go:
--------------------------------------------------------------------------------
1 | package transformer_test
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/sugarme/transformer"
8 | "github.com/sugarme/transformer/bert"
9 | )
10 |
11 | // With model name
12 | func TestModelFromPretrained_ModelName(t *testing.T) {
13 | modelName := "bert-base-uncased"
14 | var config *bert.BertConfig = new(bert.BertConfig)
15 | err := transformer.LoadConfig(config, modelName, nil)
16 | if err != nil {
17 | t.Error(err)
18 | }
19 |
20 | wantVocabSize := int64(30522)
21 | gotVocabSize := config.VocabSize
22 |
23 | if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
24 | t.Errorf("Want: %v\n", wantVocabSize)
25 | t.Errorf("Got: %v\n", gotVocabSize)
26 | }
27 | }
28 |
29 | // With local file
30 |
31 | /*
32 | * // No custom params
33 | * func TestModelFromPretrained(t *testing.T) {
34 | * // bertURL := transformer.AllPretrainedConfigs["bert-base-uncased"]
35 | * url := "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json"
36 | *
37 | * var config *bert.BertConfig = new(bert.BertConfig)
38 | * err := transformer.LoadConfig(config, url, nil)
39 | * if err != nil {
40 | * t.Error(err)
41 | * }
42 | *
43 | * wantVocabSize := int64(30522)
44 | * gotVocabSize := config.VocabSize
45 | *
46 | * if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
47 | * t.Errorf("Want: %v\n", wantVocabSize)
48 | * t.Errorf("Got: %v\n", gotVocabSize)
49 | * }
50 | *
51 | * }
52 | *
53 | * // With custom params
54 | * func TestModelFromPretrained_CustomParams(t *testing.T) {
55 | * url := "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json"
56 | *
57 | * params := map[string]interface{}{
58 | * "VocabSize": int64(2000),
59 | * "NumLabels": int64(4),
60 | * }
61 | *
62 | * var config *bert.BertConfig = new(bert.BertConfig)
63 | * err := transformer.LoadConfig(config, url, params)
64 | * if err != nil {
65 | * t.Error(err)
66 | * }
67 | *
68 | * wantVocabSize := int64(2000)
69 | * gotVocabSize := config.VocabSize
70 | *
71 | * if !reflect.DeepEqual(wantVocabSize, gotVocabSize) {
72 | * t.Errorf("Want: %v\n", wantVocabSize)
73 | * t.Errorf("Got: %v\n", gotVocabSize)
74 | * }
75 | *
76 | * wantNumLabels := int64(4)
77 | * gotNumLabels := config.NumLabels
78 | *
79 | * if !reflect.DeepEqual(wantNumLabels, gotNumLabels) {
80 | * t.Errorf("Want: %v\n", wantNumLabels)
81 | * t.Errorf("Got: %v\n", gotNumLabels)
82 | * }
83 | * } */
84 |
--------------------------------------------------------------------------------
/pretrained/roberta.go:
--------------------------------------------------------------------------------
1 | package pretrained
2 |
3 | // RobertaConfigs is a map of pretrained Roberta configuration names to corresponding URLs.
4 | var RobertaConfigs map[string]string = map[string]string{
5 | "roberta-base": "https://cdn.huggingface.co/roberta-base-config.json",
6 | "roberta-qa": "https://s3.amazonaws.com/models.huggingface.co/bert/deepset/roberta-base-squad2/config.json",
7 | "xlm-roberta-ner-en": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
8 | "xlm-roberta-ner-de": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
9 | "xlm-roberta-ner-nl": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
10 | "xlm-roberta-ner-es": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
11 | }
12 |
13 | // RobertaModels is a map of pretrained Roberta model names to corresponding URLs.
14 | var RobertaModels map[string]string = map[string]string{
15 | "roberta-base": "https://cdn.huggingface.co/roberta-base-rust_model.ot",
16 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/rust_model.ot",
17 | "xlm-roberta-ner-en": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-rust_model.ot",
18 | "xlm-roberta-ner-de": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-rust_model.ot",
19 | "xlm-roberta-ner-nl": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-rust_model.ot",
20 | "xlm-roberta-ner-es": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-rust_model.ot",
21 | }
22 |
23 | // RobertaVocabs is a map of pretrained Roberta vocab names to corresponding URLs.
24 | var RobertaVocabs map[string]string = map[string]string{
25 | "roberta-base": "https://cdn.huggingface.co/roberta-base-vocab.json",
26 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/vocab.json",
27 | "xlm-roberta-ner-en": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
28 | "xlm-roberta-ner-de": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
29 | "xlm-roberta-ner-nl": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
30 | "xlm-roberta-ner-es": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
31 | }
32 |
33 | // RobertaMerges is a map of pretrained Roberta merges file names to corresponding URLs.
34 | var RobertaMerges map[string]string = map[string]string{
35 | "roberta-base": "https://cdn.huggingface.co/roberta-base-merges.txt",
36 | "roberta-qa": "https://cdn.huggingface.co/deepset/roberta-base-squad2/merges.txt",
37 | }
38 |
--------------------------------------------------------------------------------
/util/activation.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "github.com/sugarme/gotch/ts"
5 | )
6 |
7 | // ActivationFn is an activation function.
8 | type ActivationFn interface {
9 | // Fwd is a forward pass through x.
10 | Fwd(x *ts.Tensor) *ts.Tensor
11 | Name() string
12 | }
13 |
14 | // ReLU activation:
15 | // ===============
16 |
17 | type ReluActivation struct {
18 | name string
19 | }
20 |
21 | var Relu = NewRelu()
22 |
23 | func NewRelu() ReluActivation {
24 | return ReluActivation{"relu"}
25 | }
26 |
27 | func (r ReluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) {
28 | return x.MustRelu(false)
29 | }
30 |
31 | func (r ReluActivation) Name() (retVal string) {
32 | return r.name
33 | }
34 |
35 | // GeLU activation:
36 | // ===============
37 |
38 | type GeluActivation struct {
39 | name string
40 | }
41 |
42 | var Gelu = NewGelu()
43 |
44 | func NewGelu() GeluActivation {
45 | return GeluActivation{"gelu"}
46 | }
47 |
48 | func (g GeluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) {
49 | return x.MustGelu(false)
50 | }
51 |
52 | func (g GeluActivation) Name() (retVal string) {
53 | return g.name
54 | }
55 |
56 | // Tanh activation:
57 | // ===============
58 |
59 | type TanhActivation struct {
60 | name string
61 | }
62 |
63 | var Tanh = NewTanh()
64 |
65 | func NewTanh() TanhActivation {
66 | return TanhActivation{"tanh"}
67 | }
68 |
69 | func (t TanhActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) {
70 | return x.MustTanh(false)
71 | }
72 |
73 | func (t TanhActivation) Name() string {
74 | return t.name
75 | }
76 |
77 | // Swish activation:
78 | // ===============
79 |
80 | type SwishActivation struct {
81 | name string
82 | }
83 |
84 | var Swish = NewSwish()
85 |
86 | func NewSwish() SwishActivation {
87 | return SwishActivation{"swish"}
88 | }
89 |
90 | func (s SwishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) {
91 | return x.Swish()
92 | }
93 |
94 | func (s SwishActivation) Name() (retVal string) {
95 | return s.name
96 | }
97 |
98 | // Mish activation:
99 | // =================
100 |
101 | type MishActivation struct {
102 | name string
103 | }
104 |
105 | var Mish = NewMish()
106 |
107 | func NewMish() MishActivation {
108 | return MishActivation{"mish"}
109 | }
110 |
111 | func (m MishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor) {
112 | softplus := x.MustSoftplus(false)
113 | tanh := softplus.MustTanh(true)
114 | retVal = x.MustMul(tanh, false) // element-wise: mish(x) = x * tanh(softplus(x))
115 | tanh.MustDrop()
116 | return retVal
117 | }
118 |
119 | func (m MishActivation) Name() (retVal string) {
120 | return m.name
121 | }
122 |
123 | func geluNew(xs *ts.Tensor) (retVal *ts.Tensor) {
124 | // TODO: implement
125 | // x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
126 | return retVal
127 | }
128 |
129 | var ActivationFnMap map[string]ActivationFn = map[string]ActivationFn{
130 | "gelu": NewGelu(),
131 | "relu": NewRelu(),
132 | "tanh": NewTanh(),
133 | "swish": NewSwish(),
134 | "mish": NewMish(),
135 | }
136 |
--------------------------------------------------------------------------------
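A short sketch showing how the map resolves an activation name (for example BertConfig.HiddenAct, which defaults to "gelu") to an ActivationFn:

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"

	"github.com/sugarme/transformer/util"
)

func main() {
	actFn, ok := util.ActivationFnMap["gelu"]
	if !ok {
		log.Fatal("unknown activation name")
	}

	x := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
	y := actFn.Fwd(x)
	fmt.Println(actFn.Name(), y.MustSize()) // gelu [2 3]
}
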
/bert/embedding.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/sugarme/gotch"
7 | "github.com/sugarme/gotch/nn"
8 | "github.com/sugarme/gotch/ts"
9 |
10 | "github.com/sugarme/transformer/util"
11 | )
12 |
13 | // BertEmbedding defines an interface for BertModel or RobertaModel embeddings.
14 | type BertEmbedding interface {
15 | ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (*ts.Tensor, error)
16 | }
17 |
18 | type BertEmbeddings struct {
19 | WordEmbeddings *nn.Embedding
20 | PositionEmbeddings *nn.Embedding
21 | TokenTypeEmbeddings *nn.Embedding
22 | LayerNorm *nn.LayerNorm
23 | Dropout *util.Dropout
24 | }
25 |
26 | // NewBertEmbeddings builds a new BertEmbeddings
27 | func NewBertEmbeddings(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertEmbeddings {
28 | changeName := true
29 | if len(changeNameOpt) > 0 {
30 | changeName = changeNameOpt[0]
31 | }
32 | embeddingConfig := nn.DefaultEmbeddingConfig()
33 | embeddingConfig.PaddingIdx = 0
34 |
35 | wEmbedPath := p.Sub("word_embeddings")
36 | wordEmbeddings := nn.NewEmbedding(wEmbedPath, config.VocabSize, config.HiddenSize, embeddingConfig)
37 |
38 | posEmbedPath := p.Sub("position_embeddings")
39 | positionEmbeddings := nn.NewEmbedding(posEmbedPath, config.MaxPositionEmbeddings, config.HiddenSize, embeddingConfig)
40 |
41 | ttEmbedPath := p.Sub("token_type_embeddings")
42 | tokenTypeEmbeddings := nn.NewEmbedding(ttEmbedPath, config.TypeVocabSize, config.HiddenSize, embeddingConfig)
43 |
44 | layerNormConfig := nn.DefaultLayerNormConfig()
45 | if changeName {
46 | layerNormConfig.WsName = "gamma"
47 | layerNormConfig.BsName = "beta"
48 | }
49 | layerNormConfig.Eps = 1e-12
50 |
51 | lnPath := p.Sub("LayerNorm")
52 | layerNorm := nn.NewLayerNorm(lnPath, []int64{config.HiddenSize}, layerNormConfig)
53 |
54 | dropout := util.NewDropout(config.HiddenDropoutProb)
55 |
56 | return &BertEmbeddings{wordEmbeddings, positionEmbeddings, tokenTypeEmbeddings, layerNorm, dropout}
57 | }
58 |
59 | // ForwardT implements the BertEmbedding interface: it passes the input through the embedding layer.
60 | func (be *BertEmbeddings) ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, err error) {
61 |
62 | var (
63 | inputEmbeddings *ts.Tensor
64 | inputShape []int64
65 | )
66 |
67 | if inputIds.MustDefined() {
68 | if inputEmbeds.MustDefined() {
69 | err = fmt.Errorf("Only one of input Ids or input embeddings may be set.")
70 | return retVal, err
71 | } else {
72 | inputEmbeddings = inputIds.ApplyT(be.WordEmbeddings, train)
73 | inputShape = inputIds.MustSize()
74 | }
75 | } else {
76 | if inputEmbeds.MustDefined() {
77 | inputEmbeddings = inputEmbeds
78 | size := inputEmbeds.MustSize()
79 | inputShape = []int64{size[0], size[1]}
80 | } else {
81 | err = fmt.Errorf("Only one of input Ids or input embeddings may be set.")
82 | return retVal, err
83 | }
84 | }
85 |
86 | seqLength := inputEmbeddings.MustSize()[1]
87 |
88 | var posIds *ts.Tensor
89 | if positionIds.MustDefined() {
90 | posIds = positionIds
91 | } else {
92 | tmp1 := ts.MustArange(ts.IntScalar(seqLength), gotch.Int64, inputEmbeddings.MustDevice())
93 | tmp2 := tmp1.MustUnsqueeze(0, true)
94 | posIds = tmp2.MustExpand(inputShape, true, true)
95 | }
96 |
97 | var tokTypeIds *ts.Tensor
98 | if tokenTypeIds.MustDefined() {
99 | tokTypeIds = tokenTypeIds
100 | } else {
101 | tokTypeIds = ts.MustZeros(inputShape, gotch.Int64, inputEmbeddings.MustDevice())
102 | }
103 |
104 | posEmbeddings := posIds.Apply(be.PositionEmbeddings)
105 | posIds.MustDrop()
106 | tokEmbeddings := tokTypeIds.Apply(be.TokenTypeEmbeddings)
107 | tokTypeIds.MustDrop()
108 |
109 | input := inputEmbeddings.MustAdd(posEmbeddings, true)
110 | posEmbeddings.MustDrop()
111 | input.MustAdd_(tokEmbeddings)
112 | tokEmbeddings.MustDrop()
113 |
114 | retTmp1 := input.Apply(be.LayerNorm)
115 | input.MustDrop()
116 | retVal = retTmp1.ApplyT(be.Dropout, train)
117 | retTmp1.MustDrop()
118 |
119 | return retVal, nil
120 | }
121 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4 | github.com/emirpasic/gods v1.12.0 h1:QAUIPSaCu4G+POclxeqb3F+WPpdKqFGlw36+yOzGlrg=
5 | github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o=
6 | github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
7 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
8 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
11 | github.com/rivo/uniseg v0.1.0 h1:+2KBaVoUmb9XzDsrx/Ct0W/EYOSFf/nWTauy++DprtY=
12 | github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
13 | github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
14 | github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
16 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
17 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
18 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
19 | github.com/sugarme/gotch v0.0.0-20200924012111-ff31d3c62dbe h1:8i3jLRuqZvwR5mvuioHO7Hq6qITQkSPzFc/M9LxkkkY=
20 | github.com/sugarme/gotch v0.0.0-20200924012111-ff31d3c62dbe/go.mod h1:w3zHhlZfnNS//C7YQd/89fOmgofd0RK8jf/GJ+cqkO4=
21 | github.com/sugarme/gotch v0.7.0 h1:vDQqLmuo5uhqNTfTyR7xbye9pPK9a4l57YWMKH41gGU=
22 | github.com/sugarme/gotch v0.7.0/go.mod h1:ydo7fmsmT+2L5p8Am1YhOLSWGN9WV9nyrfk/RSmhTfo=
23 | github.com/sugarme/regexpset v0.0.0-20200813070853-0a3212d91786/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
24 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
25 | github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
26 | github.com/sugarme/tokenizer v0.0.0-20200930080132-dbbba9ea4756 h1:zPZePVVB2+NiDDqtII00+BQ+MqTJpCt6HBqmtbGBUmA=
27 | github.com/sugarme/tokenizer v0.0.0-20200930080132-dbbba9ea4756/go.mod h1:mkypAc0rOi4EDcKuuJTdbPJOk8QMMZABaJRNs1M1OeU=
28 | github.com/sugarme/tokenizer v0.1.16 h1:M1YpaEaZjQw7HiXrgujCcsp2ccbOLVtWNIK37o3kZlk=
29 | github.com/sugarme/tokenizer v0.1.16/go.mod h1:a1EffeqKJAQtJz/IEAgaN1SIG/TCgawEgWOV9rquM5M=
30 | github.com/sugarme/tokenizer v0.1.17 h1:1UDhHtz/nG7FGQfEJPXF3VbFVIYuKiBymZRAxdrLIBM=
31 | github.com/sugarme/tokenizer v0.1.17/go.mod h1:a1EffeqKJAQtJz/IEAgaN1SIG/TCgawEgWOV9rquM5M=
32 | golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
33 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
34 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=
35 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
36 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
37 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
38 | golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
39 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
40 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
41 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
42 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
43 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
44 | gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
45 | gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
46 |
--------------------------------------------------------------------------------
/coverage.out:
--------------------------------------------------------------------------------
1 | mode: set
2 | github.com/sugarme/sermo/tokenize/tokenizer.go:21.60,23.2 1 1
3 | github.com/sugarme/sermo/tokenize/tokenizer.go:30.41,35.2 1 1
4 | github.com/sugarme/sermo/tokenize/tokenizer.go:45.89,46.28 1 1
5 | github.com/sugarme/sermo/tokenize/tokenizer.go:46.28,52.3 1 1
6 | github.com/sugarme/sermo/tokenize/tokenizer.go:55.57,58.25 2 1
7 | github.com/sugarme/sermo/tokenize/tokenizer.go:62.2,62.12 1 1
8 | github.com/sugarme/sermo/tokenize/tokenizer.go:58.25,60.3 1 1
9 | github.com/sugarme/sermo/tokenize/unicode.go:37.29,38.11 1 1
10 | github.com/sugarme/sermo/tokenize/unicode.go:46.2,46.46 1 1
11 | github.com/sugarme/sermo/tokenize/unicode.go:39.12,40.15 1 1
12 | github.com/sugarme/sermo/tokenize/unicode.go:41.12,42.15 1 1
13 | github.com/sugarme/sermo/tokenize/unicode.go:43.12,44.15 1 0
14 | github.com/sugarme/sermo/tokenize/unicode.go:50.27,51.11 1 1
15 | github.com/sugarme/sermo/tokenize/unicode.go:61.2,61.34 1 1
16 | github.com/sugarme/sermo/tokenize/unicode.go:52.11,53.14 1 1
17 | github.com/sugarme/sermo/tokenize/unicode.go:54.12,55.14 1 1
18 | github.com/sugarme/sermo/tokenize/unicode.go:56.12,57.14 1 1
19 | github.com/sugarme/sermo/tokenize/unicode.go:58.12,59.14 1 0
20 | github.com/sugarme/sermo/tokenize/unicode.go:65.27,67.2 1 1
21 | github.com/sugarme/sermo/tokenize/unicode.go:70.29,72.2 1 1
22 | github.com/sugarme/sermo/tokenize/word.go:16.63,21.2 1 1
23 | github.com/sugarme/sermo/tokenize/word.go:29.54,37.2 4 1
24 | github.com/sugarme/sermo/tokenize/word.go:39.31,41.24 2 1
25 | github.com/sugarme/sermo/tokenize/word.go:50.2,50.19 1 1
26 | github.com/sugarme/sermo/tokenize/word.go:41.24,42.44 1 1
27 | github.com/sugarme/sermo/tokenize/word.go:42.44,43.12 1 0
28 | github.com/sugarme/sermo/tokenize/word.go:44.9,44.24 1 1
29 | github.com/sugarme/sermo/tokenize/word.go:44.24,46.4 1 1
30 | github.com/sugarme/sermo/tokenize/word.go:46.9,48.4 1 1
31 | github.com/sugarme/sermo/tokenize/word.go:53.38,55.41 2 1
32 | github.com/sugarme/sermo/tokenize/word.go:60.2,60.19 1 1
33 | github.com/sugarme/sermo/tokenize/word.go:55.41,56.33 1 1
34 | github.com/sugarme/sermo/tokenize/word.go:56.33,58.4 1 1
35 | github.com/sugarme/sermo/tokenize/word.go:63.38,67.24 3 1
36 | github.com/sugarme/sermo/tokenize/word.go:77.2,77.17 1 1
37 | github.com/sugarme/sermo/tokenize/word.go:80.2,80.13 1 1
38 | github.com/sugarme/sermo/tokenize/word.go:67.24,68.17 1 1
39 | github.com/sugarme/sermo/tokenize/word.go:68.17,72.4 3 1
40 | github.com/sugarme/sermo/tokenize/word.go:72.9,74.4 1 1
41 | github.com/sugarme/sermo/tokenize/word.go:77.17,79.3 1 1
42 | github.com/sugarme/sermo/tokenize/word.go:83.53,88.11 3 1
43 | github.com/sugarme/sermo/tokenize/word.go:92.2,93.27 2 1
44 | github.com/sugarme/sermo/tokenize/word.go:98.2,99.16 2 1
45 | github.com/sugarme/sermo/tokenize/word.go:88.11,90.3 1 1
46 | github.com/sugarme/sermo/tokenize/word.go:93.27,96.3 2 1
47 | github.com/sugarme/sermo/tokenize/word.go:102.37,104.24 2 1
48 | github.com/sugarme/sermo/tokenize/word.go:110.2,110.13 1 1
49 | github.com/sugarme/sermo/tokenize/word.go:104.24,108.3 3 1
50 | github.com/sugarme/sermo/tokenize/word.go:115.43,121.27 3 1
51 | github.com/sugarme/sermo/tokenize/word.go:126.2,126.16 1 1
52 | github.com/sugarme/sermo/tokenize/word.go:121.27,122.16 1 1
53 | github.com/sugarme/sermo/tokenize/word.go:122.16,124.4 1 1
54 | github.com/sugarme/sermo/tokenize/word.go:131.36,133.24 2 1
55 | github.com/sugarme/sermo/tokenize/word.go:142.2,142.19 1 1
56 | github.com/sugarme/sermo/tokenize/word.go:133.24,134.19 1 1
57 | github.com/sugarme/sermo/tokenize/word.go:134.19,138.4 3 1
58 | github.com/sugarme/sermo/tokenize/word.go:138.9,140.4 1 1
59 | github.com/sugarme/sermo/tokenize/wordpiece.go:25.49,30.2 1 0
60 | github.com/sugarme/sermo/tokenize/wordpiece.go:35.60,41.25 3 1
61 | github.com/sugarme/sermo/tokenize/wordpiece.go:44.2,44.16 1 1
62 | github.com/sugarme/sermo/tokenize/wordpiece.go:41.25,43.3 1 1
63 | github.com/sugarme/sermo/tokenize/wordpiece.go:47.70,49.43 2 1
64 | github.com/sugarme/sermo/tokenize/wordpiece.go:64.2,64.13 1 1
65 | github.com/sugarme/sermo/tokenize/wordpiece.go:49.43,50.33 1 1
66 | github.com/sugarme/sermo/tokenize/wordpiece.go:54.3,54.35 1 1
67 | github.com/sugarme/sermo/tokenize/wordpiece.go:50.33,52.12 2 0
68 | github.com/sugarme/sermo/tokenize/wordpiece.go:54.35,56.17 2 1
69 | github.com/sugarme/sermo/tokenize/wordpiece.go:60.4,61.45 2 1
70 | github.com/sugarme/sermo/tokenize/wordpiece.go:56.17,58.10 2 0
71 | github.com/sugarme/sermo/tokenize/wordpiece.go:69.53,71.2 1 0
72 | github.com/sugarme/sermo/tokenize/wordpiece.go:74.58,76.2 1 0
73 |
--------------------------------------------------------------------------------
/bert/config.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "io/ioutil"
7 | "log"
8 | "os"
9 | "path/filepath"
10 | "reflect"
11 | )
12 |
13 | // BertConfig defines the BERT model architecture (i.e., number of layers,
14 | // hidden layer size, label mapping...)
15 | type BertConfig struct {
16 | HiddenAct string `json:"hidden_act"`
17 | AttentionProbsDropoutProb float64 `json:"attention_probs_dropout_prob"`
18 | HiddenDropoutProb float64 `json:"hidden_dropout_prob"`
19 | HiddenSize int64 `json:"hidden_size"`
20 | InitializerRange float32 `json:"initializer_range"`
21 | IntermediateSize int64 `json:"intermediate_size"`
22 | MaxPositionEmbeddings int64 `json:"max_position_embeddings"`
23 | NumAttentionHeads int64 `json:"num_attention_heads"`
24 | NumHiddenLayers int64 `json:"num_hidden_layers"`
25 | TypeVocabSize int64 `json:"type_vocab_size"`
26 | VocabSize int64 `json:"vocab_size"`
27 | OutputAttentions bool `json:"output_attentions"`
28 | OutputHiddenStates bool `json:"output_hidden_states"`
29 | IsDecoder bool `json:"is_decoder"`
30 | Id2Label map[int64]string `json:"id2label"`
31 | Label2Id map[string]int64 `json:"label2id"`
32 | NumLabels int64 `json:"num_labels"`
33 | }
34 |
35 | // NewConfig creates a BertConfig from the given input parameters, falling back to default values.
36 | func NewConfig(customParams map[string]interface{}) *BertConfig {
37 | defaultValues := map[string]interface{}{
38 | "VocabSize": int64(30522),
39 | "HiddenSize": int64(768),
40 | "NumHiddenLayers": int64(12),
41 | "NumAttentionHeads": int64(12),
42 | "IntermediateSize": int64(3072),
43 | "HiddenAct": "gelu",
44 | "HiddenDropoutProb": float64(0.1),
45 | "AttentionProbDropoutProb": float64(0.1),
46 | "MaxPositionEmbeddings": int64(512),
47 | "TypeVocabSize": int64(2),
48 | "InitializerRange": float32(0.02),
49 | "LayerNormEps": 1e-12, // not applied yet
50 | "PadTokenId": 0, // not applied yet
51 | "GradientCheckpointing": false, // not applied yet
52 | }
53 |
54 | params := defaultValues
55 | for k, v := range customParams {
56 | if _, ok := params[k]; ok {
57 | params[k] = v
58 | }
59 | }
60 |
61 | config := new(BertConfig)
62 | config.updateParams(params)
63 |
64 | return config
65 | }
66 |
67 | func ConfigFromFile(filename string) (*BertConfig, error) {
68 | filePath, err := filepath.Abs(filename)
69 | if err != nil {
70 | return nil, err
71 | }
72 |
73 | f, err := os.Open(filePath)
74 | if err != nil {
75 | return nil, err
76 | }
77 | defer f.Close()
78 |
79 | buff, err := ioutil.ReadAll(f)
80 | if err != nil {
81 | return nil, err
82 | }
83 |
84 | var config BertConfig
85 | err = json.Unmarshal(buff, &config)
86 | if err != nil {
87 | fmt.Println(err)
88 | log.Fatalf("Could not parse configuration to BertConfiguration.\n")
89 | }
90 | return &config, nil
91 | }
92 |
93 | // Load loads model configuration from file or model name. It also updates
94 | // default configuration parameters if provided.
95 | // This method implements `pretrained.Config` interface.
96 | func (c *BertConfig) Load(modelNameOrPath string, params map[string]interface{}) error {
97 | err := c.fromFile(modelNameOrPath)
98 | if err != nil {
99 | return err
100 | }
101 |
102 | // Update custom parameters
103 | c.updateParams(params)
104 |
105 | return nil
106 | }
107 |
108 | func (c *BertConfig) fromFile(filename string) error {
109 | filePath, err := filepath.Abs(filename)
110 | if err != nil {
111 | return err
112 | }
113 |
114 | f, err := os.Open(filePath)
115 | if err != nil {
116 | return err
117 | }
118 | defer f.Close()
119 |
120 | buff, err := ioutil.ReadAll(f)
121 | if err != nil {
122 | return err
123 | }
124 |
125 | err = json.Unmarshal(buff, c)
126 | if err != nil {
127 | fmt.Println(err)
128 | log.Fatalf("Could not parse configuration to BertConfiguration.\n")
129 | }
130 |
131 | return nil
132 | }
133 |
134 | func (c *BertConfig) GetVocabSize() int64 {
135 | return c.VocabSize
136 | }
137 |
138 | func (c *BertConfig) updateParams(params map[string]interface{}) {
139 | for k, v := range params {
140 | c.updateField(k, v)
141 | }
142 | }
143 |
144 | func (c *BertConfig) updateField(field string, value interface{}) {
145 | // Check whether field name exists
146 | if reflect.ValueOf(c).Elem().FieldByName(field).IsValid() {
147 | // Check whether same type
148 | if reflect.ValueOf(c).Elem().FieldByName(field).Kind() == reflect.TypeOf(value).Kind() {
149 | reflect.ValueOf(c).Elem().FieldByName(field).Set(reflect.ValueOf(value))
150 | }
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/bert/example_test.go:
--------------------------------------------------------------------------------
1 | package bert_test
2 |
3 | import (
4 | "fmt"
5 | "log"
6 |
7 | "github.com/sugarme/gotch"
8 | "github.com/sugarme/gotch/nn"
9 | "github.com/sugarme/gotch/pickle"
10 | "github.com/sugarme/gotch/ts"
11 | "github.com/sugarme/tokenizer"
12 |
13 | "github.com/sugarme/transformer/bert"
14 | "github.com/sugarme/transformer/util"
15 | )
16 |
17 | func ExampleBertForMaskedLM() {
18 | // Config
19 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
20 | if err != nil {
21 | log.Fatal(err)
22 | }
23 | config := new(bert.BertConfig)
24 | err = config.Load(configFile, nil)
25 | if err != nil {
26 | log.Fatal(err)
27 | }
28 |
29 | // Model
30 | device := gotch.CPU
31 | vs := nn.NewVarStore(device)
32 |
33 | model := new(bert.BertForMaskedLM)
34 |
35 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin")
36 | if err != nil {
37 | log.Fatal(err)
38 | }
39 | err = pickle.LoadAll(vs, modelFile)
40 | if err != nil {
41 | log.Fatal(err)
42 | }
43 |
44 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
45 | if err != nil {
46 | log.Fatal(err)
47 | }
48 | tk := getBertTokenizer(vocabFile)
49 | sentence1 := "Looks like one [MASK] is missing"
50 | sentence2 := "It was a very nice and [MASK] day"
51 |
52 | var input []tokenizer.EncodeInput
53 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
54 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
55 |
56 | encodings, err := tk.EncodeBatch(input, true)
57 | if err != nil {
58 | log.Fatal(err)
59 | }
60 |
61 | var maxLen int = 0
62 | for _, en := range encodings {
63 | if len(en.Ids) > maxLen {
64 | maxLen = len(en.Ids)
65 | }
66 | }
67 |
68 | var tensors []ts.Tensor
69 | for _, en := range encodings {
70 | var tokInput []int64 = make([]int64, maxLen)
71 | for i := 0; i < len(en.Ids); i++ {
72 | tokInput[i] = int64(en.Ids[i])
73 | }
74 |
75 | tensors = append(tensors, *ts.TensorFrom(tokInput))
76 | }
77 |
78 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
79 |
80 | var output *ts.Tensor
81 | ts.NoGrad(func() {
82 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
83 | })
84 |
85 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0]
86 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0]
87 |
88 | got1, ok := tk.IdToToken(int(index1))
89 | if !ok {
90 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
91 | }
92 | got2, ok := tk.IdToToken(int(index2))
93 | if !ok {
94 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
95 | }
96 |
97 | fmt.Println(got1)
98 | fmt.Println(got2)
99 | /*
100 | * // Output:
101 | * // person
102 | * // pleasant
103 | * */
104 | }
105 |
106 | /*
107 | * func ExampleBertForSequenceClassification() {
108 | *
109 | * device := gotch.CPU
110 | * vs := nn.NewVarStore(device)
111 | *
112 | * config := bert.ConfigFromFile("../data/bert/config.json")
113 | *
114 | * var dummyLabelMap map[int64]string = make(map[int64]string)
115 | * dummyLabelMap[0] = "positive"
116 | * dummyLabelMap[1] = "negative"
117 | * dummyLabelMap[3] = "neutral"
118 | *
119 | * config.Id2Label = dummyLabelMap
120 | * config.OutputAttentions = true
121 | * config.OutputHiddenStates = true
122 | * model := bert.NewBertForSequenceClassification(vs.Root(), config)
123 | * tk := getBertTokenizer()
124 | *
125 | * // Define input
126 | * sentence1 := "Looks like one thing is missing"
127 | * sentence2 := `It's like comparing oranges to apples`
128 | * var input []tokenizer.EncodeInput
129 | * input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
130 | * input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
131 | * encodings, err := tk.EncodeBatch(input, true)
132 | * if err != nil {
133 | * log.Fatal(err)
134 | * }
135 | *
136 | * // Find max length of token Ids from slice of encodings
137 | * var maxLen int = 0
138 | * for _, en := range encodings {
139 | * if len(en.Ids) > maxLen {
140 | * maxLen = len(en.Ids)
141 | * }
142 | * }
143 | *
144 | * var tensors []ts.Tensor
145 | * for _, en := range encodings {
146 | * var tokInput []int64 = make([]int64, maxLen)
147 | * for i := 0; i < len(en.Ids); i++ {
148 | * tokInput[i] = int64(en.Ids[i])
149 | * }
150 | *
151 | * tensors = append(tensors, ts.TensorFrom(tokInput))
152 | * }
153 | *
154 | * inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
155 | *
156 | * var (
157 | * output ts.Tensor
158 | * allHiddenStates, allAttentions []ts.Tensor
159 | * )
160 | *
161 | * ts.NoGrad(func() {
162 | * output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
163 | * })
164 | *
165 | * fmt.Println(output.MustSize())
166 | * fmt.Println(len(allHiddenStates))
167 | * fmt.Println(len(allAttentions))
168 | *
169 | * // Output:
170 | * // [2 3]
171 | * // 12
172 | * // 12
173 | * } */
174 |
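175 | // getBertTokenizer is defined elsewhere in this test package. As a point of reference,
176 | // here is a minimal sketch of what it is assumed to look like (mirroring `getBert` in
177 | // example/bert/main.go): build a WordPiece tokenizer from the vocab file and attach the
178 | // standard BERT normalizer, pre-tokenizer, special tokens and post-processor.
179 | //
180 | //	func getBertTokenizer(vocabFile string) *tokenizer.Tokenizer {
181 | //		model, err := wordpiece.NewWordPieceFromFile(vocabFile, "[UNK]")
182 | //		if err != nil {
183 | //			log.Fatal(err)
184 | //		}
185 | //		tk := tokenizer.NewTokenizer(model)
186 | //		tk.WithNormalizer(normalizer.NewBertNormalizer(true, true, true, true))
187 | //		tk.WithPreTokenizer(pretokenizer.NewBertPreTokenizer())
188 | //		tk.AddSpecialTokens([]tokenizer.AddedToken{tokenizer.NewAddedToken("[MASK]", true)})
189 | //
190 | //		sepId, _ := tk.TokenToId("[SEP]")
191 | //		clsId, _ := tk.TokenToId("[CLS]")
192 | //		sep := processor.PostToken{Id: sepId, Value: "[SEP]"}
193 | //		cls := processor.PostToken{Id: clsId, Value: "[CLS]"}
194 | //		tk.WithPostProcessor(processor.NewBertProcessing(sep, cls))
195 | //
196 | //		return tk
197 | //	}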
--------------------------------------------------------------------------------
/bert/encoder.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/sugarme/gotch/nn"
7 | "github.com/sugarme/gotch/ts"
8 | )
9 |
10 | // `BertLayer`:
11 | //=============
12 |
13 | // BertLayer defines a layer in BERT encoder
14 | type BertLayer struct {
15 | Attention *BertAttention
16 | IsDecoder bool
17 | CrossAttention *BertAttention
18 | Intermediate *BertIntermediate
19 | Output *BertOutput
20 | }
21 |
22 | // NewBertLayer creates a new BertLayer.
23 | func NewBertLayer(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertLayer {
24 | changeName := true
25 | if len(changeNameOpt) > 0 {
26 | changeName = changeNameOpt[0]
27 | }
28 | path := p.Sub("attention")
29 | attention := NewBertAttention(path, config, changeName)
30 | var (
31 | isDecoder bool = false
32 | crossAttention *BertAttention
33 | )
34 |
35 | if config.IsDecoder {
36 | isDecoder = true
37 | attPath := p.Sub("cross_attention")
38 | crossAttention = NewBertAttention(attPath, config)
39 | }
40 |
41 | intermediatePath := p.Sub("intermediate")
42 | intermediate := NewBertIntermediate(intermediatePath, config)
43 | outputPath := p.Sub("output")
44 | output := NewBertOutput(outputPath, config, changeName)
45 |
46 | return &BertLayer{attention, isDecoder, crossAttention, intermediate, output}
47 | }
48 |
49 | // ForwardT performs a forward pass through the layer.
50 | func (bl *BertLayer) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, retValOpt1, retValOpt2 *ts.Tensor) {
51 | var (
52 | attentionOutput *ts.Tensor
53 | attentionWeights *ts.Tensor
54 | crossAttentionWeights *ts.Tensor
55 | )
56 |
57 | if bl.IsDecoder && encoderHiddenStates.MustDefined() {
58 | var attentionOutputTmp *ts.Tensor
59 | attentionOutputTmp, attentionWeights = bl.Attention.ForwardT(hiddenStates, mask, ts.None, ts.None, train)
60 | attentionOutput, crossAttentionWeights = bl.CrossAttention.ForwardT(attentionOutputTmp, mask, encoderHiddenStates, encoderMask, train)
61 | attentionOutputTmp.MustDrop()
62 | } else {
63 | attentionOutput, attentionWeights = bl.Attention.ForwardT(hiddenStates, mask, ts.None, ts.None, train)
64 | crossAttentionWeights = ts.None
65 | }
66 |
67 | outputTmp := bl.Intermediate.Forward(attentionOutput)
68 | output := bl.Output.ForwardT(outputTmp, attentionOutput, train)
69 | attentionOutput.MustDrop()
70 | outputTmp.MustDrop()
71 |
72 | return output, attentionWeights, crossAttentionWeights
73 | }
74 |
75 | // `BertEncoder`:
76 | //===============
77 |
78 | // BertEncoder defines an encoder for BERT model
79 | type BertEncoder struct {
80 | OutputAttentions bool
81 | OutputHiddenStates bool
82 | Layers []BertLayer
83 | }
84 |
85 | // NewBertEncoder creates a new BertEncoder.
86 | func NewBertEncoder(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertEncoder {
87 | changeName := true
88 | if len(changeNameOpt) > 0 {
89 | changeName = changeNameOpt[0]
90 | }
91 | path := p.Sub("layer")
92 | outputAttentions := false
93 | if config.OutputAttentions {
94 | outputAttentions = true
95 | }
96 |
97 | outputHiddenStates := false
98 | if config.OutputHiddenStates {
99 | outputHiddenStates = true
100 | }
101 |
102 | var layers []BertLayer
103 | for lIdx := 0; lIdx < int(config.NumHiddenLayers); lIdx++ {
104 | layers = append(layers, *NewBertLayer(path.Sub(fmt.Sprintf("%v", lIdx)), config, changeName))
105 | }
106 |
107 | return &BertEncoder{outputAttentions, outputHiddenStates, layers}
108 |
109 | }
110 |
111 | // ForwardT performs a forward pass through the encoder.
112 | func (be *BertEncoder) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) {
113 | var (
114 | allHiddenStates, allAttentions []ts.Tensor = nil, nil
115 | )
116 |
117 | hiddenState := hiddenStates
118 |
119 | if be.OutputHiddenStates {
120 | allHiddenStates = make([]ts.Tensor, 0) // initialize it
121 | }
122 | if be.OutputAttentions {
123 | allAttentions = make([]ts.Tensor, 0)
124 | }
125 |
126 | for _, layer := range be.Layers {
127 | if allHiddenStates != nil {
128 | allHiddenStates = append(allHiddenStates, *hiddenState)
129 | }
130 |
131 | stateTmp, attnWeightsTmp, _ := layer.ForwardT(hiddenState, mask, encoderHiddenStates, encoderMask, train)
132 | hiddenState.MustDrop()
133 | hiddenState = stateTmp
134 |
135 | if allAttentions != nil {
136 | allAttentions = append(allAttentions, *attnWeightsTmp)
137 | }
138 |
139 | // TODO: should we need to delete `stateTmp` and `attnWeightsTmp` after all?
140 |
141 | }
142 |
143 | return hiddenState, allHiddenStates, allAttentions
144 | }
145 |
146 | // `BertPooler`:
147 | //==============
148 |
149 | // BertPooler defines a linear layer which can be applied to the
150 | // first element of the sequence (the `[CLS]` token).
151 | type BertPooler struct {
152 | Lin *nn.Linear
153 | }
154 |
155 | // NewBertPooler creates a new BertPooler.
156 | func NewBertPooler(p *nn.Path, config *BertConfig) *BertPooler {
157 | path := p.Sub("dense")
158 | lconfig := nn.DefaultLinearConfig()
159 | lin := nn.NewLinear(path, config.HiddenSize, config.HiddenSize, lconfig)
160 |
161 | return &BertPooler{lin}
162 | }
163 |
164 | // Forward performs a forward pass through the pooler.
165 | func (bp *BertPooler) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) {
166 |
167 | selectTs := hiddenStates.MustSelect(1, 0, false)
168 | tmp := selectTs.Apply(bp.Lin)
169 | retVal = tmp.MustTanh(true)
170 | return retVal
171 | }
172 |
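173 | // Usage sketch (illustration only; the "encoder" sub-path is an assumption based on the
174 | // standard BERT variable naming): the encoder consumes embedded input of shape
175 | // (batch size, sequence length, hidden size) and returns a hidden state of the same shape,
176 | // plus per-layer hidden states and attention weights when `OutputHiddenStates` and
177 | // `OutputAttentions` are enabled in the configuration.
178 | //
179 | //	encoder := NewBertEncoder(p.Sub("encoder"), config)
180 | //	hiddenState, allHiddenStates, allAttentions := encoder.ForwardT(embeddedInput, mask, ts.None, ts.None, false)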
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer [](https://opensource.org/licenses/Apache-2.0)[](https://pkg.go.dev/github.com/sugarme/transformer?tab=doc)[](https://travis-ci.org/sugarme/transformer)[](https://goreportcard.com/report/github.com/sugarme/transformer)
2 |
3 | ## Overview
4 |
5 | `transformer` is a pure Go package for training, evaluating, and running inference with Natural Language Processing (NLP) models in Go.
6 |
7 | This package is under active development and many changes are still ahead, so use it entirely at your own risk. The package will be considered stable when version 1.0 is released.
8 |
9 | `transformer` is heavily inspired by and based on the popular [Python HuggingFace Transformers](https://github.com/huggingface/transformers). It is also influenced by the [Rust version - rust-bert](https://github.com/guillaume-be/rust-bert). In fact, all pre-trained models for Rust can be imported into this Go `transformer` package, as both `rust-bert`'s PyTorch Rust binding - [**`tch-rs`**](https://github.com/LaurentMazare/tch-rs) - and the Go binding [**`gotch`**](https://github.com/sugarme/gotch) are built on similar principles.
10 |
11 | `transformer` is part of an ambitious goal (together with [**tokenizer**](https://github.com/sugarme/tokenizer) and [**gotch**](https://github.com/sugarme/gotch)) to bring more AI/deep-learning tools to Gophers so that they can stick to the language they love and are good at, and build faster software in production.
12 |
13 | ## Dependencies
14 |
15 | The two main dependencies are:
16 |
17 | - `tokenizer`
18 | - `gotch`
19 |
20 | ## Prerequisites and installation
21 |
22 | - As this package depends on `gotch`, a PyTorch C++ API binding for Go, a pre-compiled Libtorch copy (CPU or GPU) must be installed on your machine. Please see the [gotch](https://github.com/sugarme/gotch) installation instructions for details.
23 | - Install package: `go get -u github.com/sugarme/transformer`
24 |
25 | ## Basic example
26 |
27 | ```go
28 | import (
29 | "fmt"
30 | "log"
31 | 
32 | "github.com/sugarme/gotch"
33 | "github.com/sugarme/gotch/ts"
34 | "github.com/sugarme/tokenizer"
35 | 
36 | "github.com/sugarme/transformer"
37 | "github.com/sugarme/transformer/bert"
38 | )
39 | 
40 | func main() {
41 | device := gotch.CPU
42 | 
43 | var config *bert.BertConfig = new(bert.BertConfig)
44 | if err := transformer.LoadConfig(config, "bert-base-uncased", nil); err != nil {
45 | log.Fatal(err)
46 | }
47 | 
48 | var model *bert.BertForMaskedLM = new(bert.BertForMaskedLM)
49 | if err := transformer.LoadModel(model, "bert-base-uncased", config, nil, device); err != nil {
50 | log.Fatal(err)
51 | }
52 | 
53 | var tk *bert.Tokenizer = bert.NewTokenizer()
54 | if err := tk.Load("bert-base-uncased", nil); err != nil {
55 | log.Fatal(err)
56 | }
57 | 
58 | sentence1 := "Looks like one [MASK] is missing"
59 | sentence2 := "It was a very nice and [MASK] day"
60 | 
61 | var input []tokenizer.EncodeInput
62 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
63 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
64 | 
65 | encodings, err := tk.EncodeBatch(input, true)
66 | if err != nil {
67 | log.Fatal(err)
68 | }
69 | 
70 | var maxLen int = 0
71 | for _, en := range encodings {
72 | if len(en.Ids) > maxLen {
73 | maxLen = len(en.Ids)
74 | }
75 | }
76 | 
77 | var tensors []ts.Tensor
78 | for _, en := range encodings {
79 | var tokInput []int64 = make([]int64, maxLen)
80 | for i := 0; i < len(en.Ids); i++ {
81 | tokInput[i] = int64(en.Ids[i])
82 | }
83 | tensors = append(tensors, *ts.TensorFrom(tokInput))
84 | }
85 | 
86 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
87 | var output *ts.Tensor
88 | ts.NoGrad(func() {
89 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
90 | })
91 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0]
92 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0]
93 | 
94 | got1, ok := tk.IdToToken(int(index1))
95 | if !ok {
96 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
97 | }
98 | got2, ok := tk.IdToToken(int(index2))
99 | if !ok {
100 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
101 | }
102 | fmt.Println(got1)
103 | fmt.Println(got2)
104 | // Output:
105 | // person
106 | // pleasant
107 | }
108 | ```
109 |
110 | ## Getting Started
111 |
112 | - See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/transformer?tab=doc) for detail APIs
113 |
114 |
115 | ## License
116 |
117 | `transformer` is Apache 2.0 licensed.
118 |
119 |
120 | ## Acknowledgement
121 |
122 | - This project has been inspired and used many concepts from [Python HuggingFace Transformers](https://github.com/huggingface/transformers) and [Rust version - rust-bert](https://github.com/guillaume-be/rust-bert).
123 |
124 | - Pre-trained models and configurations are downloaded remotely from HuggingFace.
125 |
126 |
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
/roberta/embedding.go:
--------------------------------------------------------------------------------
1 | package roberta
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/sugarme/gotch"
7 | "github.com/sugarme/gotch/nn"
8 | "github.com/sugarme/gotch/ts"
9 |
10 | "github.com/sugarme/transformer/bert"
11 | "github.com/sugarme/transformer/util"
12 | )
13 |
14 | // RobertaEmbeddings holds embedding struct for Roberta model.
15 | // It also implements `BertEmbedding` interface for Roberta models.
16 | type RobertaEmbeddings struct {
17 | wordEmbeddings *nn.Embedding
18 | positionEmbeddings *nn.Embedding
19 | tokenTypeEmbeddings *nn.Embedding
20 | layerNorm *nn.LayerNorm
21 | dropout *util.Dropout
22 | paddingIndex int64
23 | }
24 |
25 | func (re *RobertaEmbeddings) createPositionIdsFromInputIds(x *ts.Tensor) *ts.Tensor {
26 | mask := x.MustNe(ts.IntScalar(re.paddingIndex), false).MustTotype(gotch.Int64, true)
27 | cumSum := mask.MustCumsum(1, gotch.Int64, false)
28 | mul := cumSum.MustMul(mask, true)
29 | retVal := mul.MustAddScalar(ts.IntScalar(re.paddingIndex), false)
30 | mul.MustDrop()
31 |
32 | return retVal
33 | }
34 |
35 | func (re *RobertaEmbeddings) createPositionIdsFromEmbeddings(x *ts.Tensor) *ts.Tensor {
36 | shape := x.MustSize()
37 | var inputShape []int64 = []int64{shape[0], shape[1]}
38 |
39 | positionIds := ts.MustArangeStart(ts.IntScalar(re.paddingIndex+1), ts.IntScalar(inputShape[0]), gotch.Int64, x.MustDevice())
40 | retVal := positionIds.MustUnsqueeze(0, false).MustExpand(inputShape, true, true)
41 |
42 | return retVal
43 | }
44 |
45 | // NewRobertaEmbeddings creates a new RobertaEmbeddings.
46 | //
47 | // Params:
48 | // - `p` - Variable store path for the root of the BertEmbeddings model
49 | // - `config` - `BertConfig` object defining the model architecture and vocab/hidden size.
50 | func NewRobertaEmbeddings(p nn.Path, config *bert.BertConfig) *RobertaEmbeddings {
51 |
52 | embeddingConfig := nn.DefaultEmbeddingConfig()
53 | embeddingConfig.PaddingIdx = 1
54 |
55 | wordEmbeddings := nn.NewEmbedding(p.Sub("word_embeddings"), config.VocabSize, config.HiddenSize, embeddingConfig)
56 | positionEmbeddings := nn.NewEmbedding(p.Sub("position_embeddings"), config.MaxPositionEmbeddings, config.HiddenSize, nn.DefaultEmbeddingConfig())
57 | tokenTypeEmbeddings := nn.NewEmbedding(p.Sub("token_type_embeddings"), config.TypeVocabSize, config.HiddenSize, nn.DefaultEmbeddingConfig())
58 |
59 | layerNormConfig := nn.DefaultLayerNormConfig()
60 | layerNormConfig.Eps = 1e-12
61 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig)
62 | dropout := util.NewDropout(config.HiddenDropoutProb)
63 |
64 | return &RobertaEmbeddings{
65 | wordEmbeddings: wordEmbeddings,
66 | positionEmbeddings: positionEmbeddings,
67 | tokenTypeEmbeddings: tokenTypeEmbeddings,
68 | layerNorm: layerNorm,
69 | dropout: dropout,
70 | }
71 | }
72 |
73 | // ForwardT forwards pass through the embedding layer.
74 | // This differs from the original BERT embeddings in how the position ids are calculated when not provided.
75 | //
76 | // Params:
77 | // - `inputIds`: Optional input tensor of shape (batch size, sequence length).
78 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
79 | // - `tokenTypeIds`: Optional segment id of shape (batch size, sequence length).
80 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
81 | // - `positionIds`: Optional position ids of shape (batch size, sequence length).
82 | // If None, will be incremented from 0.
83 | // - `inputEmbeds`: Optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
84 | // If None, input ids must be provided (see `inputIds`)
85 | // - `train`: boolean flag to turn on/off the dropout layers in the model.
86 | // Should be set to false for inference.
87 | //
88 | // Return:
89 | // - `embeddedOutput`: tensor of shape (batch size, sequence length, hidden size)
90 | func (re *RobertaEmbeddings) ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (*ts.Tensor, error) {
91 |
92 | var (
93 | inputEmbeddings *ts.Tensor
94 | inputShape []int64
95 | )
96 |
97 | if !inputIds.MustDefined() {
98 | if inputEmbeds.MustDefined() {
99 | inputEmbeddings = inputEmbeds
100 | inputEmbedsShape := inputEmbeds.MustSize()
101 | inputShape = []int64{inputEmbedsShape[0], inputEmbedsShape[1]}
102 | } else {
103 | err := fmt.Errorf("At least one of input ids or input embeddings must be provided.")
104 | return ts.None, err
105 | }
106 | } else {
107 | // if inputIds == inputEmbeds
108 | if util.Equal(inputIds, inputEmbeds) {
109 | err := fmt.Errorf("Only one of input Ids or input embeddings may be set.")
110 | return ts.None, err
111 | } else {
112 | inputEmbeddings = inputIds.ApplyT(re.wordEmbeddings, train)
113 | inputShape = inputIds.MustSize()
114 | }
115 | }
116 |
117 | var posIds *ts.Tensor
118 | if positionIds.MustDefined() {
119 | posIds = positionIds
120 | } else {
121 | if inputIds.MustDefined() {
122 | posIds = re.createPositionIdsFromInputIds(inputIds)
123 | } else {
124 | posIds = re.createPositionIdsFromEmbeddings(inputEmbeds)
125 | }
126 | }
127 |
128 | var tokTypeIds *ts.Tensor
129 | if tokenTypeIds.MustDefined() {
130 | tokTypeIds = tokenTypeIds
131 | } else {
132 | tokTypeIds = ts.MustZeros(inputShape, gotch.Int64, inputEmbeddings.MustDevice())
133 | }
134 |
135 | positionEmbeddings := posIds.Apply(re.positionEmbeddings)
136 | tokenTypeEmbeddings := tokTypeIds.Apply(re.tokenTypeEmbeddings)
137 |
138 | add1 := inputEmbeddings.MustAdd(positionEmbeddings, true)
139 | newInputEmbeddings := add1.MustAdd(tokenTypeEmbeddings, true)
140 | inputEmbeddings.MustDrop()
141 |
142 | appliedLN := newInputEmbeddings.Apply(re.layerNorm)
143 | appliedDropout := appliedLN.ApplyT(re.dropout, train)
144 | appliedLN.MustDrop()
145 |
146 | return appliedDropout, nil
147 | }
148 |
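149 | // Usage sketch (illustration only): exactly one of `inputIds` or `inputEmbeds` should be
150 | // provided; the remaining optional tensors can be passed as `ts.None`. `path` is assumed
151 | // to be the var-store path of the embeddings sub-module.
152 | //
153 | //	embeddings := NewRobertaEmbeddings(path, config)
154 | //	output, err := embeddings.ForwardT(inputIds, ts.None, ts.None, ts.None, false)
155 | //	if err != nil {
156 | //		log.Fatal(err)
157 | //	}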
--------------------------------------------------------------------------------
/pipeline/common.go:
--------------------------------------------------------------------------------
1 | package pipeline
2 |
3 | import (
4 | "log"
5 | "reflect"
6 |
7 | "github.com/sugarme/tokenizer"
8 | "github.com/sugarme/tokenizer/model/wordpiece"
9 | "github.com/sugarme/tokenizer/normalizer"
10 | "github.com/sugarme/tokenizer/pretokenizer"
11 | "github.com/sugarme/tokenizer/processor"
12 | "github.com/sugarme/transformer/bert"
13 | )
14 |
15 | // Common blocks for generic pipelines (e.g. token classification or sequence classification)
16 | // Provides Enums holding configuration or tokenization resources that can be used to create
17 | // generic pipelines. The model component is defined in the generic pipeline itself as the
18 | // pre-processing, forward pass and post-processing differ between pipelines while basic config and
19 | // tokenization objects don't.
20 |
21 | // ModelType is an enum-like type identifying the type of model.
22 | type ModelType int
23 |
24 | const (
25 | Bert ModelType = iota
26 | DistilBert
27 | Roberta
28 | XLMRoberta
29 | Electra
30 | Marian
31 | T5
32 | Albert
33 | )
34 |
35 | type ModelOption struct {
36 | model ModelType
37 | }
38 |
39 | type Config interface{}
40 |
41 | // ConfigOption holds a model configuration
42 | type ConfigOption struct {
43 | model ModelType
44 | config Config
45 | }
46 |
47 | func NewBertConfigOption(config bert.BertConfig) *ConfigOption {
48 | return &ConfigOption{
49 | model: Bert,
50 | config: config,
51 | }
52 | }
53 |
54 | type TokenizerType int
55 |
56 | const (
57 | BertTokenizer TokenizerType = iota
58 | RobertaTokenizer
59 | XLMRobertaTokenizer
60 | MarianTokenizer
61 | T5Tokenizer
62 | AlbertTokenizer
63 | )
64 |
65 | // TokenizerOption specifies a tokenizer
66 | type TokenizerOption struct {
67 | model ModelType
68 | tokenizer *tokenizer.Tokenizer
69 | }
70 |
71 | // ConfigOption methods:
72 | // =====================
73 |
74 | // ConfigOptionFromFile loads configuration for corresponding model type from file.
75 | func ConfigOptionFromFile(modelType ModelType, path string) *ConfigOption {
76 | 
77 | var configOpt *ConfigOption
78 | 
79 | switch modelType {
80 | case Bert:
81 | config, err := bert.ConfigFromFile(path)
82 | if err != nil {
83 | log.Fatal(err)
84 | }
85 | configOpt = &ConfigOption{model: Bert, config: *config}
86 | 
87 | // TODO: implement others
88 | // case DistilBert:
89 | default:
90 | log.Fatalf("Invalid modelType: '%v'\n", modelType)
91 | }
92 | 
93 | return configOpt
94 | }
95 |
96 | // GetLabelMapping returns the label mapping for the corresponding model type.
97 | func (co *ConfigOption) GetLabelMapping() map[int64]string {
98 | 
99 | var labelMap map[int64]string = make(map[int64]string)
100 | 
101 | switch co.model {
102 | case Bert:
103 | labelMap = co.config.(bert.BertConfig).Id2Label
104 | 
105 | // TODO: implement others
106 | 
107 | default:
108 | log.Fatalf("ConfigOption GetLabelMapping error: invalid model type ('%v')\n", co.model)
109 | }
110 | 
111 | return labelMap
112 | }
113 |
114 | // TokenizerOptionFromFile loads TokenizerOption from file corresponding to model type.
115 | func TokenizerOptionFromFile(modelType ModelType, path string) *TokenizerOption {
116 | 
117 | var tk *TokenizerOption
118 | 
119 | switch modelType {
120 | case Bert:
121 | tk = &TokenizerOption{
122 | model: modelType,
123 | tokenizer: getBert(path),
124 | }
125 | 
126 | // TODO: implement others
127 | 
128 | default:
129 | log.Fatalf("Unsupported model type: '%v'", modelType)
130 | }
131 | 
132 | return tk
133 | }
134 |
135 | func getBert(path string) (retVal *tokenizer.Tokenizer) {
136 | model, err := wordpiece.NewWordPieceFromFile(path, "[UNK]")
137 | if err != nil {
138 | log.Fatal(err)
139 | }
140 |
141 | tk := tokenizer.NewTokenizer(model)
142 |
143 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
144 | tk.WithNormalizer(bertNormalizer)
145 |
146 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer()
147 | tk.WithPreTokenizer(bertPreTokenizer)
148 |
149 | var specialTokens []tokenizer.AddedToken
150 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true))
151 |
152 | tk.AddSpecialTokens(specialTokens)
153 |
154 | sepId, ok := tk.TokenToId("[SEP]")
155 | if !ok {
156 | log.Fatalf("Cannot find ID for [SEP] token.\n")
157 | }
158 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"}
159 |
160 | clsId, ok := tk.TokenToId("[CLS]")
161 | if !ok {
162 | log.Fatalf("Cannot find ID for [CLS] token.\n")
163 | }
164 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"}
165 |
166 | postProcess := processor.NewBertProcessing(sep, cls)
167 | tk.WithPostProcessor(postProcess)
168 |
169 | return tk
170 | }
171 |
172 | // ModelType returns chosen model type
173 | func (tk *TokenizerOption) ModelType() ModelType {
174 | return tk.model
175 | }
176 |
177 | // EncodeList encodes a slice of input strings.
178 | func (tk *TokenizerOption) EncodeList(sentences []string) ([]tokenizer.Encoding, error) {
179 | var input []tokenizer.EncodeInput
180 | for _, sentence := range sentences {
181 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence)))
182 | }
183 |
184 | return tk.tokenizer.EncodeBatch(input, true)
185 | }
186 |
187 | // Tokenize tokenizes input string
188 | func (tk *TokenizerOption) Tokenize(sentence string) ([]string, error) {
189 |
190 | input := tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence))
191 |
192 | encoding, err := tk.tokenizer.Encode(input, true)
193 | if err != nil {
194 | return nil, err
195 | }
196 |
197 | return encoding.Tokens, nil
198 | }
199 |
200 | // AddSpecialTokens adds special tokens to tokenizer
201 | func (tk *TokenizerOption) AddSpecialTokens(tokens []string) {
202 |
203 | var addedToks []tokenizer.AddedToken
204 | for _, tok := range tokens {
205 | addedToks = append(addedToks, tokenizer.NewAddedToken(tok, true))
206 | }
207 |
208 | tk.tokenizer.AddSpecialTokens(addedToks)
209 | }
210 |
211 | // TokensToIds converts a slice of tokens to corresponding Ids.
212 | func (tk *TokenizerOption) TokensToIds(tokens []string) (ids []int64, ok bool) {
213 | for _, tok := range tokens {
214 | id, ok := tk.tokenizer.TokenToId(tok)
215 | if !ok {
216 | return nil, false
217 | }
218 | ids = append(ids, int64(id))
219 | }
220 |
221 | return ids, true
222 | }
223 |
224 | // PadId returns a PAD id if any.
225 | func (tk *TokenizerOption) PadId() (id int64, ok bool) {
226 | paddingParam := tk.tokenizer.GetPadding()
227 | if paddingParam == nil {
228 | return -1, false
229 | }
230 |
231 | return int64(paddingParam.PadId), true
232 | }
233 |
234 | // SepId returns a SEP id if any.
235 | // If the optional sepOpt is not specified, the default value is "[SEP]".
236 | func (tk *TokenizerOption) SepId(sepOpt ...string) (id int64, ok bool) {
237 |
238 | sep := "[SEP]" // default sep string
239 | if len(sepOpt) > 0 {
240 | sep = sepOpt[0]
241 | }
242 |
243 | i, ok := tk.tokenizer.TokenToId(sep)
244 | return int64(i), ok
245 | }
246 |
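247 | // Usage sketch (illustration only; file paths are placeholders and only the BERT variants
248 | // are currently wired up): build the configuration and tokenizer options, read the label
249 | // mapping, and encode a batch of sentences.
250 | //
251 | //	configOpt := ConfigOptionFromFile(Bert, "path/to/config.json")
252 | //	labelMap := configOpt.GetLabelMapping()
253 | //
254 | //	tkOpt := TokenizerOptionFromFile(Bert, "path/to/vocab.txt")
255 | //	encodings, err := tkOpt.EncodeList([]string{"Looks like one thing is missing."})
256 | //	if err != nil {
257 | //		log.Fatal(err)
258 | //	}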
--------------------------------------------------------------------------------
/example/bert/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log"
6 |
7 | "github.com/sugarme/gotch"
8 | "github.com/sugarme/gotch/nn"
9 | "github.com/sugarme/gotch/pickle"
10 | "github.com/sugarme/gotch/ts"
11 | "github.com/sugarme/tokenizer"
12 | "github.com/sugarme/tokenizer/model/wordpiece"
13 | "github.com/sugarme/tokenizer/normalizer"
14 | "github.com/sugarme/tokenizer/pretokenizer"
15 | "github.com/sugarme/tokenizer/processor"
16 | "github.com/sugarme/transformer/bert"
17 | "github.com/sugarme/transformer/util"
18 | )
19 |
20 | func main() {
21 | bertForMaskedLM()
22 | // bertForSequenceClassification()
23 | }
24 |
25 | func getBert() (retVal *tokenizer.Tokenizer) {
26 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
27 | if err != nil {
28 | panic(err)
29 | }
30 |
31 | model, err := wordpiece.NewWordPieceFromFile(vocabFile, "[UNK]")
32 | if err != nil {
33 | log.Fatal(err)
34 | }
35 |
36 | tk := tokenizer.NewTokenizer(model)
37 | fmt.Printf("Vocab size: %v\n", tk.GetVocabSize(false))
38 |
39 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
40 | tk.WithNormalizer(bertNormalizer)
41 |
42 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer()
43 | tk.WithPreTokenizer(bertPreTokenizer)
44 |
45 | var specialTokens []tokenizer.AddedToken
46 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true))
47 |
48 | tk.AddSpecialTokens(specialTokens)
49 |
50 | sepId, ok := tk.TokenToId("[SEP]")
51 | if !ok {
52 | log.Fatalf("Cannot find ID for [SEP] token.\n")
53 | }
54 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"}
55 |
56 | clsId, ok := tk.TokenToId("[CLS]")
57 | if !ok {
58 | log.Fatalf("Cannot find ID for [CLS] token.\n")
59 | }
60 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"}
61 |
62 | postProcess := processor.NewBertProcessing(sep, cls)
63 | tk.WithPostProcessor(postProcess)
64 |
65 | return tk
66 | }
67 |
68 | func bertForMaskedLM() {
69 | device := gotch.CPU
70 | vs := nn.NewVarStore(device)
71 |
72 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
73 | if err != nil {
74 | panic(err)
75 | }
76 | config, err := bert.ConfigFromFile(configFile)
77 | if err != nil {
78 | panic(err)
79 | }
80 | // fmt.Printf("Bert Configuration:\n%+v\n", config)
81 |
82 | model, err := bert.NewBertForMaskedLM(vs.Root(), config)
83 | if err != nil {
84 | panic(err)
85 | }
86 |
87 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin")
88 | if err != nil {
89 | panic(err)
90 | }
91 | err = pickle.LoadAll(vs, modelFile)
92 | if err != nil {
93 | log.Fatalf("Load model weight error: \n%v", err)
94 | }
95 |
96 | // fmt.Printf("Varstore weights have been loaded\n")
97 | // fmt.Printf("Num of variables: %v\n", len(vs.Variables()))
98 | // fmt.Printf("%v\n", vs.Variables())
99 | // fmt.Printf("Bert is Decoder: %v\n", model.bert.IsDecoder)
100 |
101 | tk := getBert()
102 | sentence1 := "Looks like one [MASK] is missing"
103 | sentence2 := "It was a very nice and [MASK] day"
104 |
105 | var input []tokenizer.EncodeInput
106 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
107 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
108 |
109 | encodings, err := tk.EncodeBatch(input, true)
110 | if err != nil {
111 | log.Fatal(err)
112 | }
113 |
114 | // Find max length of token Ids from slice of encodings
115 | var maxLen int = 0
116 | for _, en := range encodings {
117 | if len(en.Ids) > maxLen {
118 | maxLen = len(en.Ids)
119 | }
120 | }
121 |
122 | var tensors []ts.Tensor
123 | for _, en := range encodings {
124 | var tokInput []int64 = make([]int64, maxLen)
125 | for i := 0; i < len(en.Ids); i++ {
126 | tokInput[i] = int64(en.Ids[i])
127 | }
128 |
129 | tensors = append(tensors, *ts.TensorFrom(tokInput))
130 | }
131 |
132 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
133 | // inputTensor.Print()
134 |
135 | var output *ts.Tensor
136 | ts.NoGrad(func() {
137 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
138 | })
139 |
140 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0]
141 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0]
142 |
143 | word1, ok := tk.IdToToken(int(index1))
144 | if !ok {
145 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
146 | }
147 | fmt.Printf("Input: '%v' \t- Output: '%v'\n", sentence1, word1)
148 |
149 | word2, ok := tk.IdToToken(int(index2))
150 | if !ok {
151 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
152 | }
153 | fmt.Printf("Input: '%v' \t- Output: '%v'\n", sentence2, word2)
154 | }
155 |
156 | func bertForSequenceClassification() {
157 | device := gotch.CPU
158 | vs := nn.NewVarStore(device)
159 |
160 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
161 | if err != nil {
162 | panic(err)
163 | }
164 | config, err := bert.ConfigFromFile(configFile)
165 | if err != nil {
166 | panic(err)
167 | }
168 |
169 | var dummyLabelMap map[int64]string = make(map[int64]string)
170 | dummyLabelMap[0] = "positive"
171 | dummyLabelMap[1] = "negative"
172 | dummyLabelMap[3] = "neutral"
173 |
174 | config.Id2Label = dummyLabelMap
175 | config.OutputAttentions = true
176 | config.OutputHiddenStates = true
177 | model := bert.NewBertForSequenceClassification(vs.Root(), config)
178 | tk := getBert()
179 |
180 | // Define input
181 | sentence1 := "Looks like one thing is missing"
182 | sentence2 := `It's like comparing oranges to apples`
183 | var input []tokenizer.EncodeInput
184 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
185 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
186 | encodings, err := tk.EncodeBatch(input, true)
187 | if err != nil {
188 | log.Fatal(err)
189 | }
190 |
191 | // Find max length of token Ids from slice of encodings
192 | var maxLen int = 0
193 | for _, en := range encodings {
194 | if len(en.Ids) > maxLen {
195 | maxLen = len(en.Ids)
196 | }
197 | }
198 |
199 | fmt.Printf("encodings: %v\n", encodings)
200 | var tensors []ts.Tensor
201 | for _, en := range encodings {
202 | var tokInput []int64 = make([]int64, maxLen)
203 | for i := 0; i < len(en.Ids); i++ {
204 | tokInput[i] = int64(en.Ids[i])
205 | }
206 |
207 | tensors = append(tensors, *ts.TensorFrom(tokInput))
208 | }
209 |
210 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
211 | // inputTensor.Print()
212 |
213 | var (
214 | output *ts.Tensor
215 | allHiddenStates, allAttentions []ts.Tensor
216 | )
217 |
218 | ts.NoGrad(func() {
219 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
220 | })
221 |
222 | fmt.Printf("output size: %v\n", output.MustSize())
223 |
224 | fmt.Printf("NumHiddenLayers: %v\n", config.NumHiddenLayers)
225 | fmt.Printf("allHiddenStates length: %v\n", len(allHiddenStates))
226 | fmt.Printf("allAttentions length: %v\n", len(allAttentions))
227 |
228 | }
229 |
--------------------------------------------------------------------------------
/util/file-util.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 | "net/http"
8 | "os"
9 | "path"
10 | "strconv"
11 | "strings"
12 | )
13 |
14 | // This file provides functions to work with local dataset cache, ...
15 |
16 | const (
17 | WeightName = "pytorch_model.gt"
18 | ConfigName = "config.json"
19 |
20 | // NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName`
21 | HFpath = "https://huggingface.co"
22 | )
23 |
24 | var (
25 | DUMMY_INPUT [][]int64 = [][]int64{
26 | {7, 6, 0, 0, 1},
27 | {1, 2, 3, 0, 0},
28 | {0, 0, 0, 4, 5},
29 | }
30 | )
31 |
32 | // CachedPath resolves and caches data based on input string, then returns fullpath to the cached data.
33 | //
34 | // Parameters:
35 | // - `modelNameOrPath`: model name, e.g. "bert-base-uncased", or path to a directory containing model/config files.
36 | // - `fileName`: model or config file name, e.g. "pytorch_model.bin", "config.json"
37 | //
38 | // CachedPath does several things in order:
39 | // 1. Resolves the input string to a full-path cached filename candidate.
40 | // 2. Checks whether the candidate already exists in `CachedDir`; if so, returns it. If not,
41 | // 3. Retrieves and caches the data to `CachedDir` and returns the path to the cached data.
42 | //
43 | // NOTE. default `CachedDir` is at "{$HOME}/.cache/transformer"
44 | // Custom `CachedDir` can be changed by setting with environment `GO_TRANSFORMER`
45 | func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error) {
46 |
47 | // Resolves to "candidate" filename at `CacheDir`
48 | cachedFileCandidate := fmt.Sprintf("%s/%s/%s", CachedDir, modelNameOrPath, fileName)
49 |
50 | // 1. Cached candidate file exists
51 | if _, err := os.Stat(cachedFileCandidate); err == nil {
52 | return cachedFileCandidate, nil
53 | }
54 |
55 | // 2. If valid fullpath to local file, caches it and return cached filename
56 | filepath := fmt.Sprintf("%s/%s", modelNameOrPath, fileName)
57 | if _, err := os.Stat(filepath); err == nil {
58 | err := copyFile(filepath, cachedFileCandidate)
59 | if err != nil {
60 | err := fmt.Errorf("CachedPath() failed at copying file: %w", err)
61 | return "", err
62 | }
63 | return cachedFileCandidate, nil
64 | }
65 |
66 | // 3. Cached candidate file does NOT exist. Try to download it and save to `CachedDir`.
67 | url := fmt.Sprintf("%s/%s/resolve/main/%s", HFpath, modelNameOrPath, fileName)
68 | // url := fmt.Sprintf("%s/%s/raw/main/%s", HFpath, modelNameOrPath, fileName)
69 | if isValidURL(url) {
70 | if _, err := http.Get(url); err == nil {
71 | err := downloadFile(url, cachedFileCandidate)
72 | if err != nil {
73 | err = fmt.Errorf("CachedPath() failed at trying to download file: %w", err)
74 | return "", err
75 | }
76 |
77 | return cachedFileCandidate, nil
78 | } else {
79 | err = fmt.Errorf("CachedPath() failed: unable to reach '%v': %w", url, err)
80 | return "", err
81 | }
82 | }
83 |
84 | // Not resolved
85 | err = fmt.Errorf("CachedPath() failed: Unable to parse '%v' as a URL or as a local path.\n", url)
86 | return "", err
87 | }
88 |
89 | func isValidURL(url string) bool {
90 |
91 | // TODO: implement
92 | return true
93 | }
94 |
95 | // downloadFile downloads file from URL and stores it in local filepath.
96 | // It writes to the destination file as it downloads it, without loading
97 | // the entire file into memory. An `io.TeeReader` is passed into Copy()
98 | // to report progress on the download.
99 | func downloadFile(url string, filepath string) error {
100 | // Create path if not existing
101 | dir := path.Dir(filepath)
102 | filename := path.Base(filepath)
103 | if _, err := os.Stat(dir); os.IsNotExist(err) {
104 | if err := os.MkdirAll(dir, 0755); err != nil {
105 | log.Fatal(err)
106 | }
107 | }
108 |
109 | // Create the file with .tmp extension, so that we won't overwrite a
110 | // file until it's downloaded fully
111 | out, err := os.Create(filepath + ".tmp")
112 | if err != nil {
113 | return err
114 | }
115 | defer out.Close()
116 |
117 | // Get the data
118 | resp, err := http.Get(url)
119 | if err != nil {
120 | return err
121 | }
122 | defer resp.Body.Close()
123 |
124 | // Check server response
125 | if resp.StatusCode != http.StatusOK {
126 | err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode)
127 |
128 | if resp.StatusCode == 404 {
129 | // if filename == "rust_model.ot" {
130 | // msg := fmt.Sprintf("model weight file not found. That means a compatible pretrained model weight file for Go is not available.\n")
131 | // msg = msg + fmt.Sprintf("You might need to manually convert a 'pytorch_model.bin' for Go. ")
132 | // msg = msg + fmt.Sprintf("See tutorial at: 'example/convert'")
133 | // err = fmt.Errorf(msg)
134 | // } else {
135 | // err = fmt.Errorf("download file not found: %q for downloading", url)
136 | // }
137 | err = fmt.Errorf("download file not found: %q for downloading", url)
138 | } else {
139 | err = fmt.Errorf("download file failed: %q", url)
140 | }
141 | return err
142 | }
143 |
144 | // the total file size to download
145 | size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
146 | downloadSize := uint64(size)
147 |
148 | // Create our bytes counter and pass it to be used alongside our writer
149 | counter := &writeCounter{FileSize: downloadSize}
150 | _, err = io.Copy(out, io.TeeReader(resp.Body, counter))
151 | if err != nil {
152 | return err
153 | }
154 |
155 | fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
156 | // The progress use the same line so print a new line once it's finished downloading
157 | fmt.Println()
158 |
159 | // Rename the tmp file back to the original file
160 | err = os.Rename(filepath+".tmp", filepath)
161 | if err != nil {
162 | return err
163 | }
164 |
165 | return nil
166 | }
167 |
168 | // writeCounter counts the number of bytes written to it. By implementing the Write method,
169 | // it satisfies the io.Writer interface, so it can be passed into io.TeeReader().
170 | // Every write to this writer prints the progress of the file write.
171 | type writeCounter struct {
172 | Total uint64
173 | FileSize uint64
174 | }
175 |
176 | func (wc *writeCounter) Write(p []byte) (int, error) {
177 | n := len(p)
178 | wc.Total += uint64(n)
179 | wc.printProgress()
180 | return n, nil
181 | }
182 |
183 | // printProgress prints the progress of a file write.
184 | func (wc writeCounter) printProgress() {
185 | // Clear the line by using a character return to go back to the start and remove
186 | // the remaining characters by filling it with spaces
187 | fmt.Printf("\r%s", strings.Repeat(" ", 50))
188 |
189 | // Return again and print current status of download
190 | fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
191 | }
192 |
193 | // byteCountIEC converts bytes to human-readable string in binary (IEC) format.
194 | func byteCountIEC(b uint64) string {
195 | const unit = 1024
196 | if b < unit {
197 | return fmt.Sprintf("%d B", b)
198 | }
199 | div, exp := uint64(unit), 0
200 | for n := b / unit; n >= unit; n /= unit {
201 | div *= unit
202 | exp++
203 | }
204 | return fmt.Sprintf("%.1f %ciB",
205 | float64(b)/float64(div), "KMGTPE"[exp])
206 | }
207 |
208 | func copyFile(src, dst string) error {
209 | sourceFileStat, err := os.Stat(src)
210 | if err != nil {
211 | return err
212 | }
213 |
214 | if !sourceFileStat.Mode().IsRegular() {
215 | return fmt.Errorf("%s is not a regular file", src)
216 | }
217 |
218 | source, err := os.Open(src)
219 | if err != nil {
220 | return err
221 | }
222 | defer source.Close()
223 |
224 | destination, err := os.Create(dst)
225 | if err != nil {
226 | return err
227 | }
228 | defer destination.Close()
229 | _, err = io.Copy(destination, source)
230 | return err
231 | }
232 |
233 | // CleanCache removes all files cached in transformer cache directory `CachedDir`.
234 | //
235 | // NOTE. custom `CachedDir` can be changed by setting environment `GO_TRANSFORMER`
236 | func CleanCache() error {
237 | err := os.RemoveAll(CachedDir)
238 | if err != nil {
239 | err = fmt.Errorf("CleanCache() failed: %w", err)
240 | return err
241 | }
242 |
243 | return nil
244 | }
245 |
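246 | // Usage sketch (illustration only): resolve and cache a remote file, then clear the cache.
247 | // The default cache directory is "{$HOME}/.cache/transformer" and can be overridden through
248 | // the GO_TRANSFORMER environment variable (see `CachedDir`).
249 | //
250 | //	configFile, err := CachedPath("bert-base-uncased", "config.json")
251 | //	if err != nil {
252 | //		log.Fatal(err)
253 | //	}
254 | //	fmt.Println(configFile) // e.g. ~/.cache/transformer/bert-base-uncased/config.json
255 | //
256 | //	// Remove everything under CachedDir.
257 | //	if err := CleanCache(); err != nil {
258 | //		log.Fatal(err)
259 | //	}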
--------------------------------------------------------------------------------
/bert/attention.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | // "fmt"
5 | "log"
6 | "math"
7 |
8 | "github.com/sugarme/gotch"
9 | "github.com/sugarme/gotch/nn"
10 | "github.com/sugarme/gotch/ts"
11 |
12 | "github.com/sugarme/transformer/util"
13 | )
14 |
15 | // BertSelfAttention:
16 | //===================
17 |
18 | type BertSelfAttention struct {
19 | NumAttentionHeads int64
20 | AttentionHeadSize int64
21 | Dropout *util.Dropout
22 | OutputAttentions bool
23 | Query *nn.Linear
24 | Key *nn.Linear
25 | Value *nn.Linear
26 | }
27 |
28 | // NewBertSelfAttention creates a new `BertSelfAttention`
29 | func NewBertSelfAttention(p *nn.Path, config *BertConfig) *BertSelfAttention {
30 | if config.HiddenSize%config.NumAttentionHeads != 0 {
31 | log.Fatal("Hidden size is not a multiple of the number of attention heads.")
32 | }
33 |
34 | lconfig := nn.DefaultLinearConfig()
35 | query := nn.NewLinear(p.Sub("query"), config.HiddenSize, config.HiddenSize, lconfig)
36 | key := nn.NewLinear(p.Sub("key"), config.HiddenSize, config.HiddenSize, lconfig)
37 | value := nn.NewLinear(p.Sub("value"), config.HiddenSize, config.HiddenSize, lconfig)
38 |
39 | dropout := util.NewDropout(config.AttentionProbsDropoutProb)
40 | attentionHeadSize := int64(config.HiddenSize) / config.NumAttentionHeads
41 | outputAttentions := config.OutputAttentions
42 |
43 | return &BertSelfAttention{
44 | NumAttentionHeads: config.NumAttentionHeads,
45 | AttentionHeadSize: attentionHeadSize,
46 | Dropout: dropout,
47 | OutputAttentions: outputAttentions,
48 | Query: query,
49 | Key: key,
50 | Value: value,
51 | }
52 |
53 | }
54 |
55 | func (bsa *BertSelfAttention) splitHeads(x *ts.Tensor, bs, dimPerHead int64) (retVal *ts.Tensor) {
56 |
57 | xview := x.MustView([]int64{bs, -1, bsa.NumAttentionHeads, dimPerHead}, false)
58 |
59 | return xview.MustTranspose(1, 2, true)
60 | }
61 |
62 | func (bsa *BertSelfAttention) flatten(x *ts.Tensor, bs, dimPerHead int64) (retVal *ts.Tensor) {
63 |
64 | xT := x.MustTranspose(1, 2, false)
65 | xCon := xT.MustContiguous(true)
66 | retVal = xCon.MustView([]int64{bs, -1, bsa.NumAttentionHeads * dimPerHead}, true)
67 |
68 | return retVal
69 | }
70 |
71 | // ForwardT implements ModuleT interface for BertSelfAttention
72 | //
73 | // NOTE: mask, encoderHiddenStates and encoderMask are optional tensors;
74 | // for a `None` value, `ts.None` can be used.
75 | func (bsa *BertSelfAttention) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, retValOpt *ts.Tensor) {
76 |
77 | key := bsa.Key.Forward(hiddenStates)
78 | value := bsa.Value.Forward(hiddenStates)
79 |
80 | if encoderHiddenStates.MustDefined() {
81 | key = bsa.Key.Forward(encoderHiddenStates)
82 | value = bsa.Value.Forward(encoderHiddenStates)
83 | }
84 |
85 | bs := hiddenStates.MustSize()[0]
86 |
87 | hiddenStatesQ := hiddenStates.Apply(bsa.Query)
88 | query := bsa.splitHeads(hiddenStatesQ, bs, bsa.AttentionHeadSize)
89 |
90 | hiddenStatesQ.MustDrop()
91 | keyLayer := bsa.splitHeads(key, bs, bsa.AttentionHeadSize)
92 | key.MustDrop()
93 | valueLayer := bsa.splitHeads(value, bs, bsa.AttentionHeadSize)
94 | value.MustDrop()
95 |
96 | size := math.Sqrt(float64(bsa.AttentionHeadSize))
97 | queryLayer := query.MustDivScalar(ts.FloatScalar(size), true)
98 |
99 | // Calculate score
100 | var scores *ts.Tensor
101 | if mask.MustDefined() {
102 | keyLayerT := keyLayer.MustTranspose(-1, -2, true)
103 | // The attention mask is additive and applied to the raw scores (Q.K^T), not to the keys.
104 | scores = queryLayer.MustMatmul(keyLayerT, true).MustAdd(mask, true)
105 | } else {
106 | keyLayerT := keyLayer.MustTranspose(-1, -2, true)
107 | scores = queryLayer.MustMatmul(keyLayerT, true)
108 | }
109 |
110 | weights := scores.MustSoftmax(-1, gotch.Float, true).ApplyT(bsa.Dropout, train)
111 |
112 | weightsMul := weights.MustMatmul(valueLayer, false)
113 |
114 | context := bsa.flatten(weightsMul, bs, bsa.AttentionHeadSize)
115 | weightsMul.MustDrop()
116 |
117 | if !bsa.OutputAttentions {
118 | weights.MustDrop()
119 | return context, ts.None
120 | } else {
121 | return context, weights
122 | }
123 |
124 | }
125 |
126 | // BertSelfOutput:
127 | //================
128 |
129 | type BertSelfOutput struct {
130 | Linear *nn.Linear
131 | LayerNorm *nn.LayerNorm
132 | Dropout *util.Dropout
133 | }
134 |
135 | func NewBertSelfOutput(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertSelfOutput {
136 | changeName := true
137 | if len(changeNameOpt) > 0 {
138 | changeName = changeNameOpt[0]
139 | }
140 |
141 | path := p.Sub("dense")
142 | lconfig := nn.DefaultLinearConfig()
143 | linear := nn.NewLinear(path, config.HiddenSize, config.HiddenSize, lconfig)
144 |
145 | layerNormConfig := nn.DefaultLayerNormConfig()
146 | if changeName {
147 | layerNormConfig.WsName = "gamma"
148 | layerNormConfig.BsName = "beta"
149 | }
150 | layerNormConfig.Eps = 1e-12
151 |
152 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig)
153 | dropout := util.NewDropout(config.HiddenDropoutProb)
154 |
155 | return &BertSelfOutput{linear, layerNorm, dropout}
156 | }
157 |
158 | func (bso *BertSelfOutput) ForwardT(hiddenStates *ts.Tensor, inputTensor *ts.Tensor, train bool) (retVal *ts.Tensor) {
159 |
160 | state1 := hiddenStates.Apply(bso.Linear)
161 | state2 := state1.ApplyT(bso.Dropout, train)
162 | state3 := inputTensor.MustAdd(state2, false)
163 |
164 | retVal = state3.Apply(bso.LayerNorm)
165 | state1.MustDrop()
166 | state2.MustDrop()
167 | state3.MustDrop()
168 |
169 | return retVal
170 | }
171 |
172 | // BertAttention:
173 | //===============
174 |
175 | type BertAttention struct {
176 | Bsa *BertSelfAttention
177 | Output *BertSelfOutput
178 | }
179 |
180 | func NewBertAttention(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertAttention {
181 | changeName := true
182 | if len(changeNameOpt) > 0 {
183 | changeName = changeNameOpt[0]
184 | }
185 | self := NewBertSelfAttention(p.Sub("self"), config)
186 | output := NewBertSelfOutput(p.Sub("output"), config, changeName)
187 |
188 | return &BertAttention{self, output}
189 | }
190 |
191 | func (ba *BertAttention) ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal, RetValOpt *ts.Tensor) {
192 |
193 | selfOutput, attentionWeights := ba.Bsa.ForwardT(hiddenStates, mask, encoderHiddenStates, encoderMask, train)
194 | selfOutput = ba.Output.ForwardT(selfOutput, hiddenStates, train)
195 |
196 | return selfOutput, attentionWeights
197 | }
198 |
199 | // BertIntermediate:
200 | //=================
201 |
202 | type BertIntermediate struct {
203 | Lin *nn.Linear
204 | Activation util.ActivationFn // interface
205 | }
206 |
207 | func NewBertIntermediate(p *nn.Path, config *BertConfig) *BertIntermediate {
208 | lconfig := nn.DefaultLinearConfig()
209 | lin := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.IntermediateSize, lconfig)
210 |
211 | actFn, ok := util.ActivationFnMap[config.HiddenAct]
212 | if !ok {
213 | log.Fatalf("Unsupported activation function - %v\n", config.HiddenAct)
214 | }
215 |
216 | return &BertIntermediate{lin, actFn}
217 | }
218 |
219 | func (bi *BertIntermediate) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) {
220 |
221 | states := hiddenStates.Apply(bi.Lin)
222 |
223 | retVal = bi.Activation.Fwd(states)
224 | states.MustDrop()
225 |
226 | return retVal
227 | }
228 |
229 | // BertOutput:
230 | //============
231 |
232 | type BertOutput struct {
233 | Lin *nn.Linear
234 | LayerNorm *nn.LayerNorm
235 | Dropout *util.Dropout
236 | }
237 |
238 | func NewBertOutput(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertOutput {
239 | changeName := true
240 | if len(changeNameOpt) > 0 {
241 | changeName = changeNameOpt[0]
242 | }
243 |
244 | lconfig := nn.DefaultLinearConfig()
245 | lin := nn.NewLinear(p.Sub("dense"), config.IntermediateSize, config.HiddenSize, lconfig)
246 |
247 | layerNormConfig := nn.DefaultLayerNormConfig()
248 | if changeName {
249 | layerNormConfig.WsName = "gamma"
250 | layerNormConfig.BsName = "beta"
251 | }
252 | layerNormConfig.Eps = 1e-12
253 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, layerNormConfig)
254 |
255 | dropout := util.NewDropout(config.HiddenDropoutProb)
256 |
257 | return &BertOutput{lin, layerNorm, dropout}
258 | }
259 |
260 | func (bo *BertOutput) ForwardT(hiddenStates, inputTensor *ts.Tensor, train bool) (retVal *ts.Tensor) {
261 |
262 | state1 := hiddenStates.Apply(bo.Lin)
263 | state2 := state1.ApplyT(bo.Dropout, train)
264 | state3 := inputTensor.MustAdd(state2, false)
265 |
266 | retVal = state3.Apply(bo.LayerNorm)
267 |
268 | state1.MustDrop()
269 | state2.MustDrop()
270 | state3.MustDrop()
271 |
272 | return retVal
273 | }
274 |
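275 | // Note on the attention mask (a clarifying sketch; shapes follow the standard BERT
276 | // conventions and are not checked by this package): the `mask` passed to
277 | // BertSelfAttention.ForwardT is an additive mask broadcastable to the score matrix,
278 | // with 0 for positions to keep and a large negative value for padded positions, so
279 | // that softmax assigns them near-zero weight.
280 | //
281 | //	// Illustrative shapes for hidden states of (bs, seqLen, hidden) with nHeads heads:
282 | //	// splitHeads: (bs, seqLen, hidden)  -> (bs, nHeads, seqLen, hidden/nHeads)
283 | //	// scores:     Q.matmul(K^T) + mask  -> (bs, nHeads, seqLen, seqLen)
284 | //	// flatten:    context               -> (bs, seqLen, hidden)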
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | Copyright 2020 Thang Tran.
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | http://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
191 |
--------------------------------------------------------------------------------
/bert/model_test.go:
--------------------------------------------------------------------------------
1 | package bert_test
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "reflect"
7 | "testing"
8 |
9 | "github.com/sugarme/gotch"
10 | "github.com/sugarme/gotch/nn"
11 | "github.com/sugarme/gotch/pickle"
12 | "github.com/sugarme/gotch/ts"
13 | "github.com/sugarme/tokenizer"
14 | "github.com/sugarme/tokenizer/model/wordpiece"
15 | "github.com/sugarme/tokenizer/normalizer"
16 | "github.com/sugarme/tokenizer/pretokenizer"
17 | "github.com/sugarme/tokenizer/processor"
18 |
19 | "github.com/sugarme/transformer/bert"
20 | "github.com/sugarme/transformer/util"
21 | )
22 |
23 | func getBertTokenizer(vocabFile string) (retVal *tokenizer.Tokenizer) {
24 | model, err := wordpiece.NewWordPieceFromFile(vocabFile, "[UNK]")
25 | if err != nil {
26 | log.Fatal(err)
27 | }
28 |
29 | tk := tokenizer.NewTokenizer(model)
30 |
31 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
32 | tk.WithNormalizer(bertNormalizer)
33 |
34 | bertPreTokenizer := pretokenizer.NewBertPreTokenizer()
35 | tk.WithPreTokenizer(bertPreTokenizer)
36 |
37 | var specialTokens []tokenizer.AddedToken
38 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("[MASK]", true))
39 |
40 | tk.AddSpecialTokens(specialTokens)
41 |
42 | sepId, ok := tk.TokenToId("[SEP]")
43 | if !ok {
44 | log.Fatalf("Cannot find ID for [SEP] token.\n")
45 | }
46 | sep := processor.PostToken{Id: sepId, Value: "[SEP]"}
47 |
48 | clsId, ok := tk.TokenToId("[CLS]")
49 | if !ok {
50 | log.Fatalf("Cannot find ID for [CLS] token.\n")
51 | }
52 | cls := processor.PostToken{Id: clsId, Value: "[CLS]"}
53 |
54 | postProcess := processor.NewBertProcessing(sep, cls)
55 | tk.WithPostProcessor(postProcess)
56 |
57 | return tk
58 | }
59 |
60 | func TestBertForMaskedLM(t *testing.T) {
61 | // Config
62 | config := new(bert.BertConfig)
63 |
64 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
65 | if err != nil {
66 | log.Fatal(err)
67 | }
68 |
69 | err = config.Load(configFile, nil)
70 | if err != nil {
71 | log.Fatal(err)
72 | }
73 |
74 | // Model
75 | device := gotch.CPU
76 | vs := nn.NewVarStore(device)
77 |
78 | model, err := bert.NewBertForMaskedLM(vs.Root(), config)
79 | if err != nil {
80 | log.Fatal(err)
81 | }
82 | modelFile, err := util.CachedPath("bert-base-uncased", "pytorch_model.bin")
83 | if err != nil {
84 | log.Fatal(err)
85 | }
86 |
87 | err = pickle.LoadAll(vs, modelFile)
88 | if err != nil {
89 | t.Error(err)
90 | }
91 |
92 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
93 | tk := getBertTokenizer(vocabFile)
94 | sentence1 := "Looks like one [MASK] is missing"
95 | sentence2 := "It was a very nice and [MASK] day"
96 |
97 | var input []tokenizer.EncodeInput
98 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
99 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
100 |
101 | encodings, err := tk.EncodeBatch(input, true)
102 | if err != nil {
103 | log.Fatal(err)
104 | }
105 |
106 | var maxLen int = 0
107 | for _, en := range encodings {
108 | if len(en.Ids) > maxLen {
109 | maxLen = len(en.Ids)
110 | }
111 | }
112 |
113 | var tensors []ts.Tensor
114 | for _, en := range encodings {
115 | var tokInput []int64 = make([]int64, maxLen)
116 | for i := 0; i < len(en.Ids); i++ {
117 | tokInput[i] = int64(en.Ids[i])
118 | }
119 |
120 | tensors = append(tensors, *ts.TensorFrom(tokInput))
121 | }
122 |
123 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
124 |
125 | var output *ts.Tensor
126 | ts.NoGrad(func() {
127 | output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
128 | })
129 |
130 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0]
131 | index2 := output.MustGet(1).MustGet(7).MustArgmax([]int64{0}, false, false).Int64Values()[0]
132 |
133 | got1, ok := tk.IdToToken(int(index1))
134 | if !ok {
135 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
136 | }
137 | want1 := "person"
138 |
139 | if !reflect.DeepEqual(want1, got1) {
140 | t.Errorf("Want: '%v'\n", want1)
141 | t.Errorf("Got '%v'\n", got1)
142 | }
143 |
144 | got2, ok := tk.IdToToken(int(index2))
145 | if !ok {
146 | fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
147 | }
148 | want2 := "pleasant"
149 |
150 | if !reflect.DeepEqual(want2, got2) {
151 | t.Errorf("Want: '%v'\n", want2)
152 | t.Errorf("Got '%v'\n", got2)
153 | }
154 | }
155 |
156 | func TestBertForSequenceClassification(t *testing.T) {
157 | device := gotch.CPU
158 | vs := nn.NewVarStore(device)
159 |
160 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
161 | if err != nil {
162 | t.Error(err)
163 | }
164 | config, err := bert.ConfigFromFile(configFile)
165 | if err != nil {
166 | t.Error(err)
167 | }
168 |
169 | var dummyLabelMap map[int64]string = make(map[int64]string)
170 | dummyLabelMap[0] = "positive"
171 | dummyLabelMap[1] = "negative"
172 | dummyLabelMap[3] = "neutral"
173 |
174 | config.Id2Label = dummyLabelMap
175 | config.OutputAttentions = true
176 | config.OutputHiddenStates = true
177 | model := bert.NewBertForSequenceClassification(vs.Root(), config)
178 |
179 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
180 | tk := getBertTokenizer(vocabFile)
181 |
182 | // Define input
183 | sentence1 := "Looks like one thing is missing"
184 | sentence2 := `It's like comparing oranges to apples`
185 | var input []tokenizer.EncodeInput
186 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
187 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
188 | encodings, err := tk.EncodeBatch(input, true)
189 | if err != nil {
190 | log.Fatal(err)
191 | }
192 |
193 | // Find max length of token Ids from slice of encodings
194 | var maxLen int = 0
195 | for _, en := range encodings {
196 | if len(en.Ids) > maxLen {
197 | maxLen = len(en.Ids)
198 | }
199 | }
200 |
201 | var tensors []ts.Tensor
202 | for _, en := range encodings {
203 | var tokInput []int64 = make([]int64, maxLen)
204 | for i := 0; i < len(en.Ids); i++ {
205 | tokInput[i] = int64(en.Ids[i])
206 | }
207 |
208 | tensors = append(tensors, *ts.TensorFrom(tokInput))
209 | }
210 |
211 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
212 |
213 | var (
214 | output *ts.Tensor
215 | allHiddenStates, allAttentions []ts.Tensor
216 | )
217 |
218 | ts.NoGrad(func() {
219 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
220 | })
221 |
222 | fmt.Printf("output size: %v\n", output.MustSize())
223 | gotOutputSize := output.MustSize()
224 | wantOutputSize := []int64{2, 3}
225 | if !reflect.DeepEqual(wantOutputSize, gotOutputSize) {
226 | t.Errorf("Want: %v\n", wantOutputSize)
227 | t.Errorf("Got: %v\n", gotOutputSize)
228 | }
229 |
230 | numHiddenLayers := int(config.NumHiddenLayers)
231 |
232 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) {
233 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers)
234 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates))
235 | }
236 |
237 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) {
238 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers)
239 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions))
240 | }
241 | }
242 |
243 | func TestBertForMultipleChoice(t *testing.T) {
244 | device := gotch.CPU
245 | vs := nn.NewVarStore(device)
246 |
247 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
248 | if err != nil {
249 | t.Error(err)
250 | }
251 | config, err := bert.ConfigFromFile(configFile)
252 | if err != nil {
253 | t.Error(err)
254 | }
255 |
256 | config.OutputAttentions = true
257 | config.OutputHiddenStates = true
258 | model := bert.NewBertForMultipleChoice(vs.Root(), config)
259 |
260 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
261 | tk := getBertTokenizer(vocabFile)
262 |
263 | // Define input
264 | sentence1 := "Looks like one thing is missing"
265 | sentence2 := `It's like comparing oranges to apples`
266 | var input []tokenizer.EncodeInput
267 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
268 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
269 | encodings, err := tk.EncodeBatch(input, true)
270 | if err != nil {
271 | log.Fatal(err)
272 | }
273 |
274 | // Find max length of token Ids from slice of encodings
275 | var maxLen int = 0
276 | for _, en := range encodings {
277 | if len(en.Ids) > maxLen {
278 | maxLen = len(en.Ids)
279 | }
280 | }
281 |
282 | var tensors []ts.Tensor
283 | for _, en := range encodings {
284 | var tokInput []int64 = make([]int64, maxLen)
285 | for i := 0; i < len(en.Ids); i++ {
286 | tokInput[i] = int64(en.Ids[i])
287 | }
288 |
289 | tensors = append(tensors, *ts.TensorFrom(tokInput))
290 | }
291 |
292 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true).MustUnsqueeze(0, true)
293 |
294 | var (
295 | output *ts.Tensor
296 | allHiddenStates, allAttentions []ts.Tensor
297 | )
298 |
299 | ts.NoGrad(func() {
300 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, false)
301 | })
302 |
303 | fmt.Printf("output size: %v\n", output.MustSize())
304 | gotOutputSize := output.MustSize()
305 | wantOutputSize := []int64{1, 2}
306 | if !reflect.DeepEqual(wantOutputSize, gotOutputSize) {
307 | t.Errorf("Want: %v\n", wantOutputSize)
308 | t.Errorf("Got: %v\n", gotOutputSize)
309 | }
310 |
311 | numHiddenLayers := int(config.NumHiddenLayers)
312 |
313 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) {
314 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers)
315 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates))
316 | }
317 |
318 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) {
319 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers)
320 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions))
321 | }
322 | }
323 |
324 | func TestBertForTokenClassification(t *testing.T) {
325 | device := gotch.CPU
326 | vs := nn.NewVarStore(device)
327 |
328 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
329 | if err != nil {
330 | t.Error(err)
331 | }
332 | config, err := bert.ConfigFromFile(configFile)
333 | if err != nil {
334 | t.Error(err)
335 | }
336 |
337 | var dummyLabelMap map[int64]string = make(map[int64]string)
338 | dummyLabelMap[0] = "O"
339 | dummyLabelMap[1] = "LOC"
340 | dummyLabelMap[2] = "PER"
341 | dummyLabelMap[3] = "ORG"
342 |
343 | config.Id2Label = dummyLabelMap
344 | config.OutputAttentions = true
345 | config.OutputHiddenStates = true
346 | model := bert.NewBertForTokenClassification(vs.Root(), config)
347 |
348 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
349 | tk := getBertTokenizer(vocabFile)
350 |
351 | // Define input
352 | sentence1 := "Looks like one thing is missing"
353 | sentence2 := `It's like comparing oranges to apples`
354 | var input []tokenizer.EncodeInput
355 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
356 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
357 | encodings, err := tk.EncodeBatch(input, true)
358 | if err != nil {
359 | log.Fatal(err)
360 | }
361 |
362 | // Find max length of token Ids from slice of encodings
363 | var maxLen int = 0
364 | for _, en := range encodings {
365 | if len(en.Ids) > maxLen {
366 | maxLen = len(en.Ids)
367 | }
368 | }
369 |
370 | var tensors []ts.Tensor
371 | for _, en := range encodings {
372 | var tokInput []int64 = make([]int64, maxLen)
373 | for i := 0; i < len(en.Ids); i++ {
374 | tokInput[i] = int64(en.Ids[i])
375 | }
376 |
377 | tensors = append(tensors, *ts.TensorFrom(tokInput))
378 | }
379 |
380 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
381 |
382 | var (
383 | output *ts.Tensor
384 | allHiddenStates, allAttentions []ts.Tensor
385 | )
386 |
387 | ts.NoGrad(func() {
388 | output, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
389 | })
390 |
391 | fmt.Printf("output size: %v\n", output.MustSize())
392 | gotOutputSize := output.MustSize()
393 | wantOutputSize := []int64{2, 11, 4}
394 | if !reflect.DeepEqual(wantOutputSize, gotOutputSize) {
395 | t.Errorf("Want: %v\n", wantOutputSize)
396 | t.Errorf("Got: %v\n", gotOutputSize)
397 | }
398 |
399 | numHiddenLayers := int(config.NumHiddenLayers)
400 |
401 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) {
402 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers)
403 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates))
404 | }
405 |
406 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) {
407 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers)
408 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions))
409 | }
410 | }
411 |
412 | func TestBertForQuestionAnswering(t *testing.T) {
413 | device := gotch.CPU
414 | vs := nn.NewVarStore(device)
415 |
416 | configFile, err := util.CachedPath("bert-base-uncased", "config.json")
417 | if err != nil {
418 | t.Error(err)
419 | }
420 | config, err := bert.ConfigFromFile(configFile)
421 | if err != nil {
422 | t.Error(err)
423 | }
424 |
425 | config.OutputAttentions = true
426 | config.OutputHiddenStates = true
427 | model := bert.NewForBertQuestionAnswering(vs.Root(), config)
428 |
429 | vocabFile, err := util.CachedPath("bert-base-uncased", "vocab.txt")
430 | tk := getBertTokenizer(vocabFile)
431 |
432 | // Define input
433 | sentence1 := "Looks like one thing is missing"
434 | sentence2 := `It's like comparing oranges to apples`
435 | var input []tokenizer.EncodeInput
436 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
437 | input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))
438 | encodings, err := tk.EncodeBatch(input, true)
439 | if err != nil {
440 | log.Fatal(err)
441 | }
442 |
443 | // Find max length of token Ids from slice of encodings
444 | var maxLen int = 0
445 | for _, en := range encodings {
446 | if len(en.Ids) > maxLen {
447 | maxLen = len(en.Ids)
448 | }
449 | }
450 |
451 | var tensors []ts.Tensor
452 | for _, en := range encodings {
453 | var tokInput []int64 = make([]int64, maxLen)
454 | for i := 0; i < len(en.Ids); i++ {
455 | tokInput[i] = int64(en.Ids[i])
456 | }
457 |
458 | tensors = append(tensors, *ts.TensorFrom(tokInput))
459 | }
460 |
461 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
462 |
463 | var (
464 | startScores, endScores *ts.Tensor
465 | allHiddenStates, allAttentions []ts.Tensor
466 | )
467 |
468 | ts.NoGrad(func() {
469 | startScores, endScores, allHiddenStates, allAttentions = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
470 | })
471 |
472 | gotStartScoresSize := startScores.MustSize()
473 | wantStartScoresSize := []int64{2, 11}
474 | if !reflect.DeepEqual(wantStartScoresSize, gotStartScoresSize) {
475 | t.Errorf("Want: %v\n", wantStartScoresSize)
476 | t.Errorf("Got: %v\n", gotStartScoresSize)
477 | }
478 |
479 | gotEndScoresSize := endScores.MustSize()
480 | wantEndScoresSize := []int64{2, 11}
481 | if !reflect.DeepEqual(wantEndScoresSize, gotEndScoresSize) {
482 | t.Errorf("Want: %v\n", wantEndScoresSize)
483 | t.Errorf("Got: %v\n", gotEndScoresSize)
484 | }
485 |
486 | numHiddenLayers := int(config.NumHiddenLayers)
487 |
488 | if !reflect.DeepEqual(numHiddenLayers, len(allHiddenStates)) {
489 | t.Errorf("Want num of allHiddenStates: %v\n", numHiddenLayers)
490 | t.Errorf("Got num of allHiddenStates: %v\n", len(allHiddenStates))
491 | }
492 |
493 | if !reflect.DeepEqual(numHiddenLayers, len(allAttentions)) {
494 | t.Errorf("Want num of allAttentions: %v\n", numHiddenLayers)
495 | t.Errorf("Got num of allAttentions: %v\n", len(allAttentions))
496 | }
497 | }
498 |
--------------------------------------------------------------------------------
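Note: every test in bert/model_test.go repeats the same padding step — find the longest encoding, copy each `Ids` slice into a zero-padded `[]int64`, and stack the results into one input tensor. A minimal sketch of that step as a standalone helper, assuming the same `gotch`, `ts`, and `tokenizer` packages used in the tests (the function name `encodingsToTensor` is illustrative, not part of the repository):

package example

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
	"github.com/sugarme/tokenizer"
)

// encodingsToTensor pads every encoding to the length of the longest one and
// stacks the padded id slices into a (batch size, max length) int64 tensor on
// the given device.
func encodingsToTensor(encodings []tokenizer.Encoding, device gotch.Device) *ts.Tensor {
	maxLen := 0
	for _, en := range encodings {
		if len(en.Ids) > maxLen {
			maxLen = len(en.Ids)
		}
	}

	var tensors []ts.Tensor
	for _, en := range encodings {
		tokInput := make([]int64, maxLen) // zero-padded up to maxLen
		for i := 0; i < len(en.Ids); i++ {
			tokInput[i] = int64(en.Ids[i])
		}
		tensors = append(tensors, *ts.TensorFrom(tokInput))
	}

	return ts.MustStack(tensors, 0).MustTo(device, true)
}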
/roberta/model_test.go:
--------------------------------------------------------------------------------
1 | package roberta_test
2 |
3 | import (
4 | // "fmt"
5 | "log"
6 | "reflect"
7 | "testing"
8 |
9 | "github.com/sugarme/gotch"
10 | "github.com/sugarme/gotch/nn"
11 | "github.com/sugarme/gotch/pickle"
12 | "github.com/sugarme/gotch/ts"
13 | "github.com/sugarme/tokenizer"
14 | "github.com/sugarme/tokenizer/model/bpe"
15 | "github.com/sugarme/tokenizer/normalizer"
16 | "github.com/sugarme/tokenizer/pretokenizer"
17 | "github.com/sugarme/tokenizer/processor"
18 | "github.com/sugarme/transformer/bert"
19 | "github.com/sugarme/transformer/roberta"
20 | "github.com/sugarme/transformer/util"
21 | )
22 |
23 | func getRobertaTokenizer(vocabFile string, mergeFile string) (retVal *tokenizer.Tokenizer) {
24 |
25 | model, err := bpe.NewBpeFromFiles(vocabFile, mergeFile)
26 | if err != nil {
27 | log.Fatal(err)
28 | }
29 |
30 | tk := tokenizer.NewTokenizer(model)
31 |
32 | bertNormalizer := normalizer.NewBertNormalizer(true, true, true, true)
33 | tk.WithNormalizer(bertNormalizer)
34 |
35 | blPreTokenizer := pretokenizer.NewByteLevel()
36 | // blPreTokenizer.SetAddPrefixSpace(false)
37 | tk.WithPreTokenizer(blPreTokenizer)
38 |
39 | var specialTokens []tokenizer.AddedToken
40 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<s>", true))
41 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<pad>", true))
42 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("</s>", true))
43 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<unk>", true))
44 | specialTokens = append(specialTokens, tokenizer.NewAddedToken("<mask>", true))
45 | tk.AddSpecialTokens(specialTokens)
46 |
47 | postProcess := processor.DefaultRobertaProcessing()
48 | tk.WithPostProcessor(postProcess)
49 |
50 | return tk
51 | }
52 |
53 | func TestRobertaForMaskedLM(t *testing.T) {
54 |
55 | // Config
56 | config := new(bert.BertConfig)
57 | configFile, err := util.CachedPath("roberta-base", "config.json")
58 | if err != nil {
59 | log.Fatal(err)
60 | }
61 | err = config.Load(configFile, nil)
62 | if err != nil {
63 | log.Fatal(err)
64 | }
65 |
66 | // Model
67 | device := gotch.CPU
68 | vs := nn.NewVarStore(device)
69 |
70 | model, err := roberta.NewRobertaForMaskedLM(vs.Root(), config)
71 | if err != nil {
72 | log.Fatal(err)
73 | }
74 |
75 | modelFile, err := util.CachedPath("roberta-base", "pytorch_model.bin")
76 | if err != nil {
77 | log.Fatal(err)
78 | }
79 | // err = vs.Load("../data/roberta/roberta-base-model.gt")
80 | err = pickle.LoadAll(vs, modelFile)
81 | if err != nil {
82 | log.Fatal(err)
83 | }
84 |
85 | // Roberta tokenizer
86 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json")
87 | if err != nil {
88 | log.Fatal(err)
89 | }
90 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt")
91 | if err != nil {
92 | log.Fatal(err)
93 | }
94 |
95 | tk := getRobertaTokenizer(vocabFile, mergesFile)
96 |
97 | sentence1 := "Looks like one <mask> is missing"
98 | sentence2 := "It's like comparing <mask> to apples"
99 |
100 | input := []tokenizer.EncodeInput{
101 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)),
102 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)),
103 | }
104 |
105 | /*
106 | * // NOTE: EncodeBatch does encode concurrently, so it does not keep original order!
107 | * encodings, err := tk.EncodeBatch(input, true)
108 | * if err != nil {
109 | * log.Fatal(err)
110 | * }
111 | * */
112 | var encodings []tokenizer.Encoding
113 | for _, i := range input {
114 | en, err := tk.Encode(i, true)
115 | if err != nil {
116 | log.Fatal(err)
117 | }
118 | encodings = append(encodings, *en)
119 | }
120 |
121 | // fmt.Printf("encodings:\n%+v\n", encodings)
122 |
123 | var maxLen int = 0
124 | for _, en := range encodings {
125 | if len(en.Ids) > maxLen {
126 | maxLen = len(en.Ids)
127 | }
128 | }
129 |
130 | var tensors []ts.Tensor
131 | for _, en := range encodings {
132 | var tokInput []int64 = make([]int64, maxLen)
133 | for i := 0; i < len(en.Ids); i++ {
134 | tokInput[i] = int64(en.Ids[i])
135 | }
136 |
137 | tensors = append(tensors, *ts.TensorFrom(tokInput))
138 | }
139 |
140 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
141 |
142 | var output *ts.Tensor
143 | ts.NoGrad(func() {
144 | output, _, _, err = model.Forward(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
145 | if err != nil {
146 | log.Fatal(err)
147 | }
148 | })
149 |
150 | index1 := output.MustGet(0).MustGet(4).MustArgmax([]int64{0}, false, false).Int64Values()[0]
151 | index2 := output.MustGet(1).MustGet(5).MustArgmax([]int64{0}, false, false).Int64Values()[0]
152 | gotMask1 := tk.Decode([]int{int(index1)}, false)
153 | gotMask2 := tk.Decode([]int{int(index2)}, false)
154 |
155 | // fmt.Printf("index1: '%v' - mask1: '%v'\n", index1, gotMask1)
156 | // fmt.Printf("index2: '%v' - mask2: '%v'\n", index2, gotMask2)
157 | wantMask1 := "Ġperson"
158 | wantMask2 := "Ġapples"
159 |
160 | if !reflect.DeepEqual(wantMask1, gotMask1) {
161 | t.Errorf("Want: %v got %v\n", wantMask1, gotMask1)
162 | }
163 |
164 | if !reflect.DeepEqual(wantMask2, gotMask2) {
165 | t.Errorf("Want: %v got %v\n", wantMask2, gotMask2)
166 | }
167 |
168 | }
169 |
170 | func TestRobertaForSequenceClassification(t *testing.T) {
171 |
172 | // Config
173 | config := new(bert.BertConfig)
174 | configFile, err := util.CachedPath("roberta-base", "config.json")
175 | if err != nil {
176 | log.Fatal(err)
177 | }
178 | err = config.Load(configFile, nil)
179 | if err != nil {
180 | log.Fatal(err)
181 | }
182 |
183 | // Model
184 | device := gotch.CPU
185 | vs := nn.NewVarStore(device)
186 |
187 | var dummyLabelMap map[int64]string = make(map[int64]string)
188 | dummyLabelMap[0] = "Positive"
189 | dummyLabelMap[1] = "Negative"
190 | dummyLabelMap[3] = "Neutral"
191 |
192 | config.Id2Label = dummyLabelMap
193 | config.OutputAttentions = true
194 | config.OutputHiddenStates = true
195 | model := roberta.NewRobertaForSequenceClassification(vs.Root(), config)
196 |
197 | // Roberta tokenizer
198 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json")
199 | if err != nil {
200 | log.Fatal(err)
201 | }
202 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt")
203 | if err != nil {
204 | log.Fatal(err)
205 | }
206 |
207 | tk := getRobertaTokenizer(vocabFile, mergesFile)
208 |
209 | sentence1 := "Looks like one thing is missing"
210 | sentence2 := "It's like comparing oranges to apples"
211 |
212 | input := []tokenizer.EncodeInput{
213 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)),
214 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)),
215 | }
216 |
217 | var encodings []tokenizer.Encoding
218 | for _, i := range input {
219 | en, err := tk.Encode(i, true)
220 | if err != nil {
221 | log.Fatal(err)
222 | }
223 | encodings = append(encodings, *en)
224 | }
225 |
226 | var maxLen int = 0
227 | for _, en := range encodings {
228 | if len(en.Ids) > maxLen {
229 | maxLen = len(en.Ids)
230 | }
231 | }
232 |
233 | var tensors []ts.Tensor
234 | for _, en := range encodings {
235 | var tokInput []int64 = make([]int64, maxLen)
236 | for i := 0; i < len(en.Ids); i++ {
237 | tokInput[i] = int64(en.Ids[i])
238 | }
239 |
240 | tensors = append(tensors, *ts.TensorFrom(tokInput))
241 | }
242 |
243 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
244 |
245 | var (
246 | output *ts.Tensor
247 | hiddenStates, attentions []ts.Tensor
248 | )
249 |
250 | ts.NoGrad(func() {
251 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
252 | if err != nil {
253 | log.Fatal(err)
254 | }
255 | })
256 |
257 | wantOutput := []int64{2, 3}
258 | gotOutput := output.MustSize()
259 |
260 | wantNumHiddenLayers := config.NumHiddenLayers
261 | gotNumHiddenLayers := int64(len(hiddenStates))
262 |
263 | wantAttentions := config.NumHiddenLayers
264 | gotAttentions := int64(len(attentions))
265 |
266 | if !reflect.DeepEqual(wantOutput, gotOutput) {
267 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput)
268 | }
269 |
270 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) {
271 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers)
272 | }
273 |
274 | if !reflect.DeepEqual(wantAttentions, gotAttentions) {
275 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions)
276 | }
277 | }
278 |
279 | func TestRobertaForMultipleChoice(t *testing.T) {
280 | // Config
281 | config := new(bert.BertConfig)
282 | config.OutputAttentions = true
283 | config.OutputHiddenStates = true
284 | configFile, err := util.CachedPath("roberta-base", "config.json")
285 | if err != nil {
286 | log.Fatal(err)
287 | }
288 | err = config.Load(configFile, nil)
289 | if err != nil {
290 | log.Fatal(err)
291 | }
292 |
293 | // Model
294 | device := gotch.CPU
295 | vs := nn.NewVarStore(device)
296 |
297 | model := roberta.NewRobertaForMultipleChoice(vs.Root(), config)
298 |
299 | // Roberta tokenizer
300 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json")
301 | if err != nil {
302 | log.Fatal(err)
303 | }
304 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt")
305 | if err != nil {
306 | log.Fatal(err)
307 | }
308 |
309 | tk := getRobertaTokenizer(vocabFile, mergesFile)
310 |
311 | sentence1 := "Looks like one <mask> is missing"
312 | sentence2 := "It's like comparing <mask> to apples"
313 |
314 | input := []tokenizer.EncodeInput{
315 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)),
316 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)),
317 | }
318 |
319 | var encodings []tokenizer.Encoding
320 | for _, i := range input {
321 | en, err := tk.Encode(i, true)
322 | if err != nil {
323 | log.Fatal(err)
324 | }
325 | encodings = append(encodings, *en)
326 | }
327 |
328 | var maxLen int = 0
329 | for _, en := range encodings {
330 | if len(en.Ids) > maxLen {
331 | maxLen = len(en.Ids)
332 | }
333 | }
334 |
335 | var tensors []ts.Tensor
336 | for _, en := range encodings {
337 | var tokInput []int64 = make([]int64, maxLen)
338 | for i := 0; i < len(en.Ids); i++ {
339 | tokInput[i] = int64(en.Ids[i])
340 | }
341 |
342 | tensors = append(tensors, *ts.TensorFrom(tokInput))
343 | }
344 |
345 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true).MustUnsqueeze(0, true)
346 |
347 | var (
348 | output *ts.Tensor
349 | hiddenStates, attentions []ts.Tensor
350 | )
351 | ts.NoGrad(func() {
352 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, false)
353 | if err != nil {
354 | log.Fatal(err)
355 | }
356 | })
357 |
358 | wantOutput := []int64{1, 2}
359 | gotOutput := output.MustSize()
360 |
361 | wantHiddenStates := config.NumHiddenLayers
362 | gotHiddenStates := int64(len(hiddenStates))
363 |
364 | wantAttentions := config.NumHiddenLayers
365 | gotAttentions := int64(len(attentions))
366 |
367 | if !reflect.DeepEqual(wantOutput, gotOutput) {
368 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput)
369 | }
370 |
371 | if !reflect.DeepEqual(wantHiddenStates, gotHiddenStates) {
372 | t.Errorf("want %v - got %v\n", wantHiddenStates, gotHiddenStates)
373 | }
374 |
375 | if !reflect.DeepEqual(wantAttentions, gotAttentions) {
376 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions)
377 | }
378 | }
379 |
380 | func TestRobertaForTokenClassification(t *testing.T) {
381 |
382 | // Config
383 | config := new(bert.BertConfig)
384 | configFile, err := util.CachedPath("roberta-base", "config.json")
385 | if err != nil {
386 | log.Fatal(err)
387 | }
388 | err = config.Load(configFile, nil)
389 | if err != nil {
390 | log.Fatal(err)
391 | }
392 |
393 | // Model
394 | device := gotch.CPU
395 | vs := nn.NewVarStore(device)
396 |
397 | var dummyLabelMap map[int64]string = make(map[int64]string)
398 | dummyLabelMap[0] = "O"
399 | dummyLabelMap[1] = "LOC"
400 | dummyLabelMap[2] = "PER"
401 | dummyLabelMap[3] = "ORG"
402 |
403 | config.Id2Label = dummyLabelMap
404 | config.OutputAttentions = true
405 | config.OutputHiddenStates = true
406 | model := roberta.NewRobertaForTokenClassification(vs.Root(), config)
407 |
408 | // Roberta tokenizer
409 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json")
410 | if err != nil {
411 | log.Fatal(err)
412 | }
413 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt")
414 | if err != nil {
415 | log.Fatal(err)
416 | }
417 |
418 | tk := getRobertaTokenizer(vocabFile, mergesFile)
419 |
420 | sentence1 := "Looks like one thing is missing"
421 | sentence2 := "It's like comparing oranges to apples"
422 |
423 | input := []tokenizer.EncodeInput{
424 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)),
425 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)),
426 | }
427 |
428 | var encodings []tokenizer.Encoding
429 | for _, i := range input {
430 | en, err := tk.Encode(i, true)
431 | if err != nil {
432 | log.Fatal(err)
433 | }
434 | encodings = append(encodings, *en)
435 | }
436 |
437 | var maxLen int = 0
438 | for _, en := range encodings {
439 | if len(en.Ids) > maxLen {
440 | maxLen = len(en.Ids)
441 | }
442 | }
443 |
444 | var tensors []ts.Tensor
445 | for _, en := range encodings {
446 | var tokInput []int64 = make([]int64, maxLen)
447 | for i := 0; i < len(en.Ids); i++ {
448 | tokInput[i] = int64(en.Ids[i])
449 | }
450 |
451 | tensors = append(tensors, *ts.TensorFrom(tokInput))
452 | }
453 |
454 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
455 |
456 | var (
457 | output *ts.Tensor
458 | hiddenStates, attentions []ts.Tensor
459 | )
460 |
461 | ts.NoGrad(func() {
462 | output, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
463 | if err != nil {
464 | log.Fatal(err)
465 | }
466 | })
467 |
468 | wantOutput := []int64{2, 9, 4}
469 | gotOutput := output.MustSize()
470 |
471 | wantNumHiddenLayers := config.NumHiddenLayers
472 | gotNumHiddenLayers := int64(len(hiddenStates))
473 |
474 | wantAttentions := config.NumHiddenLayers
475 | gotAttentions := int64(len(attentions))
476 |
477 | if !reflect.DeepEqual(wantOutput, gotOutput) {
478 | t.Errorf("want %v - got %v\n", wantOutput, gotOutput)
479 | }
480 |
481 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) {
482 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers)
483 | }
484 |
485 | if !reflect.DeepEqual(wantAttentions, gotAttentions) {
486 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions)
487 | }
488 | }
489 |
490 | func TestRobertaForQuestionAnswering(t *testing.T) {
491 |
492 | // Config
493 | config := new(bert.BertConfig)
494 | configFile, err := util.CachedPath("roberta-base", "config.json")
495 | if err != nil {
496 | log.Fatal(err)
497 | }
498 |
499 | err = config.Load(configFile, nil)
500 | if err != nil {
501 | log.Fatal(err)
502 | }
503 |
504 | // Model
505 | device := gotch.CPU
506 | vs := nn.NewVarStore(device)
507 |
508 | var dummyLabelMap map[int64]string = make(map[int64]string)
509 | dummyLabelMap[0] = "Positive"
510 | dummyLabelMap[1] = "Negative"
511 | dummyLabelMap[3] = "Neutral"
512 |
513 | config.Id2Label = dummyLabelMap
514 | config.OutputAttentions = true
515 | config.OutputHiddenStates = true
516 | model := roberta.NewRobertaForQuestionAnswering(vs.Root(), config)
517 |
518 | // Roberta tokenizer
519 | vocabFile, err := util.CachedPath("roberta-base", "vocab.json")
520 | if err != nil {
521 | log.Fatal(err)
522 | }
523 | mergesFile, err := util.CachedPath("roberta-base", "merges.txt")
524 | if err != nil {
525 | log.Fatal(err)
526 | }
527 |
528 | tk := getRobertaTokenizer(vocabFile, mergesFile)
529 |
530 | sentence1 := "Looks like one thing is missing"
531 | sentence2 := "It's like comparing oranges to apples"
532 |
533 | input := []tokenizer.EncodeInput{
534 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)),
535 | tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)),
536 | }
537 |
538 | var encodings []tokenizer.Encoding
539 | for _, i := range input {
540 | en, err := tk.Encode(i, true)
541 | if err != nil {
542 | log.Fatal(err)
543 | }
544 | encodings = append(encodings, *en)
545 | }
546 |
547 | var maxLen int = 0
548 | for _, en := range encodings {
549 | if len(en.Ids) > maxLen {
550 | maxLen = len(en.Ids)
551 | }
552 | }
553 |
554 | var tensors []ts.Tensor
555 | for _, en := range encodings {
556 | var tokInput []int64 = make([]int64, maxLen)
557 | for i := 0; i < len(en.Ids); i++ {
558 | tokInput[i] = int64(en.Ids[i])
559 | }
560 |
561 | tensors = append(tensors, *ts.TensorFrom(tokInput))
562 | }
563 |
564 | inputTensor := ts.MustStack(tensors, 0).MustTo(device, true)
565 |
566 | var (
567 | startScores, endScores *ts.Tensor
568 | hiddenStates, attentions []ts.Tensor
569 | )
570 |
571 | ts.NoGrad(func() {
572 | startScores, endScores, hiddenStates, attentions, err = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, false)
573 | if err != nil {
574 | log.Fatal(err)
575 | }
576 | })
577 |
578 | wantStartScores := []int64{2, 9}
579 | gotStartScores := startScores.MustSize()
580 |
581 | wantEndScores := []int64{2, 9}
582 | gotEndScores := endScores.MustSize()
583 |
584 | wantNumHiddenLayers := config.NumHiddenLayers
585 | gotNumHiddenLayers := int64(len(hiddenStates))
586 |
587 | wantAttentions := config.NumHiddenLayers
588 | gotAttentions := int64(len(attentions))
589 |
590 | if !reflect.DeepEqual(wantStartScores, gotStartScores) {
591 | t.Errorf("want %v - got %v\n", wantStartScores, gotStartScores)
592 | }
593 |
594 | if !reflect.DeepEqual(wantEndScores, gotEndScores) {
595 | t.Errorf("want %v - got %v\n", wantEndScores, gotEndScores)
596 | }
597 |
598 | if !reflect.DeepEqual(wantNumHiddenLayers, gotNumHiddenLayers) {
599 | t.Errorf("want %v - got %v\n", wantNumHiddenLayers, gotNumHiddenLayers)
600 | }
601 |
602 | if !reflect.DeepEqual(wantAttentions, gotAttentions) {
603 | t.Errorf("want %v - got %v\n", wantAttentions, gotAttentions)
604 | }
605 | }
606 |
607 | func TestRobertaQA(t *testing.T) {
608 | // TODO: implement via pipelines
609 | }
610 |
611 | func TestRobertaNER(t *testing.T) {
612 | // TODO: implement via pipelines
613 | }
614 |
--------------------------------------------------------------------------------
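Note: the RoBERTa tests above avoid `tk.EncodeBatch` because, as the commented-out block in TestRobertaForMaskedLM points out, it encodes concurrently and does not keep the original input order; each input is encoded one at a time instead. A short sketch of that sequential pattern as a helper, assuming the same `tokenizer` API used in the tests (the function name `encodeSequential` is illustrative, not part of the repository):

package example

import (
	"github.com/sugarme/tokenizer"
)

// encodeSequential encodes inputs one by one so the returned encodings stay
// in the same order as the inputs, unlike the concurrent EncodeBatch.
func encodeSequential(tk *tokenizer.Tokenizer, inputs []tokenizer.EncodeInput) ([]tokenizer.Encoding, error) {
	var encodings []tokenizer.Encoding
	for _, in := range inputs {
		en, err := tk.Encode(in, true) // true: add special tokens, as in the tests
		if err != nil {
			return nil, err
		}
		encodings = append(encodings, *en)
	}
	return encodings, nil
}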
/roberta/model.go:
--------------------------------------------------------------------------------
1 | package roberta
2 |
3 | // Package roberta implements the RoBERTa transformer model.
4 |
5 | import (
6 | // "fmt"
7 |
8 | "github.com/sugarme/gotch"
9 | "github.com/sugarme/gotch/nn"
10 | "github.com/sugarme/gotch/pickle"
11 | "github.com/sugarme/gotch/ts"
12 |
13 | "github.com/sugarme/transformer/bert"
14 | "github.com/sugarme/transformer/pretrained"
15 | "github.com/sugarme/transformer/util"
16 | )
17 |
18 | // RobertaLMHead holds data of Roberta LM head.
19 | type RobertaLMHead struct {
20 | dense *nn.Linear
21 | decoder *util.LinearNoBias
22 | layerNorm *nn.LayerNorm
23 | bias *ts.Tensor
24 | }
25 |
26 | // NewRobertaLMHead creates new RobertaLMHead.
27 | func NewRobertaLMHead(p *nn.Path, config *bert.BertConfig) (*RobertaLMHead, error) {
28 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig())
29 |
30 | layerNormConfig := nn.DefaultLayerNormConfig()
31 | layerNormConfig.Eps = 1e-12
32 | layerNorm := nn.NewLayerNorm(p.Sub("layer_norm"), []int64{config.HiddenSize}, layerNormConfig)
33 |
34 | decoder, err := util.NewLinearNoBias(p.Sub("decoder"), config.HiddenSize, config.VocabSize, util.DefaultLinearNoBiasConfig())
35 | if err != nil {
36 | return nil, err
37 | }
38 |
39 | bias, err := p.NewVar("bias", []int64{config.VocabSize}, nn.NewKaimingUniformInit())
40 | if err != nil {
41 | return nil, err
42 | }
43 |
44 | return &RobertaLMHead{
45 | dense: dense,
46 | decoder: decoder,
47 | layerNorm: layerNorm,
48 | bias: bias,
49 | }, nil
50 | }
51 |
52 | // Forward forwards a pass through the RobertaLMHead model.
53 | func (rh *RobertaLMHead) Forward(hiddenStates *ts.Tensor) *ts.Tensor {
54 | gelu := util.NewGelu()
55 | appliedDense := hiddenStates.Apply(rh.dense)
56 | geluFwd := gelu.Fwd(appliedDense)
57 | appliedLN := geluFwd.Apply(rh.layerNorm)
58 | appliedDecoder := appliedLN.Apply(rh.decoder)
59 | appliedBias := appliedDecoder.MustAdd(rh.bias, true)
60 |
61 | geluFwd.MustDrop()
62 | appliedDense.MustDrop()
63 | appliedLN.MustDrop()
64 |
65 | return appliedBias
66 | }
67 |
68 | // RobertaForMaskedLM holds data for Roberta masked language model.
69 | //
70 | // Base RoBERTa model with a RoBERTa masked language model head to predict
71 | // missing tokens.
72 | type RobertaForMaskedLM struct {
73 | roberta *bert.BertModel
74 | lmHead *RobertaLMHead
75 | }
76 |
77 | // NewRobertaForMaskedLM builds a new RobertaForMaskedLM.
78 | func NewRobertaForMaskedLM(p *nn.Path, config *bert.BertConfig) (*RobertaForMaskedLM, error) {
79 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false)
80 | lmHead, err := NewRobertaLMHead(p.Sub("lm_head"), config)
81 | if err != nil {
82 | return nil, err
83 | }
84 |
85 | return &RobertaForMaskedLM{
86 | roberta: roberta,
87 | lmHead: lmHead,
88 | }, nil
89 | }
90 |
91 | // Load loads model from file or model name. It also updates
92 | // default configuration parameters if provided.
93 | // This method implements `PretrainedModel` interface.
94 | func (mlm *RobertaForMaskedLM) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
95 | // var urlOrFilename string
96 | // // If modelName, infer to default configuration filename:
97 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok {
98 | // urlOrFilename = modelFile
99 | // } else {
100 | // // Otherwise, just take the input
101 | // urlOrFilename = modelNameOrPath
102 | // }
103 |
104 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin")
105 | if err != nil {
106 | return err
107 | }
108 |
109 | vs := nn.NewVarStore(device)
110 | p := vs.Root()
111 |
112 | mlm.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false)
113 | mlm.lmHead, err = NewRobertaLMHead(p.Sub("lm_head"), config.(*bert.BertConfig))
114 |
115 | // err = vs.Load(cachedFile)
116 | err = pickle.LoadAll(vs, cachedFile)
117 | if err != nil {
118 | return err
119 | }
120 |
121 | return nil
122 | }
123 |
124 | // Forward forwards a pass through the model.
125 | //
126 | // Params:
127 | // - `inputIds`: Optional input tensor of shape (batch size, sequence length).
128 | // If None, pre-computed embeddings must be provided (see inputEmbeds).
129 | // - `mask`: Optional mask of shape (batch size, sequence length).
130 | // Masked position have value 0, non-masked value 1. If None set to 1.
131 | // - `tokenTypeIds`: Optional segment id of shape (batch size, sequence length).
132 | // Convention is value of 0 for the first sentence (incl. ) and 1 for the
133 | // second sentence. If None set to 0.
134 | // - `positionIds`: Optional position ids of shape (batch size, sequence length).
135 | // If None, will be incremented from 0.
136 | // - `inputEmbeds`: Optional pre-computed input embeddings of shape (batch size,
137 | // sequence length, hidden size). If None, input ids must be provided (see inputIds).
138 | // - `encoderHiddenStates`: Optional encoder hidden state of shape (batch size,
139 | // encoder sequence length, hidden size). If the model is defined as a decoder and
140 | // the encoder hidden states is not None, used in the cross-attention layer as
141 | // keys and values (query from the decoder).
142 | // - `encoderMask`: Optional encoder attention mask of shape (batch size, encoder sequence length).
143 | // If the model is defined as a decoder and the *encoder_hidden_states* is not None,
144 | // used to mask encoder values. Positions with value 0 will be masked.
145 | // - `train`: boolean flag to turn on/off the dropout layers in the model.
146 | // Should be set to false for inference.
147 | //
148 | // Returns:
149 | // - `output`: tensor of shape (batch size, sequence length, vocab size)
150 | // - `hiddenStates`: optional slice of tensors of length numHiddenLayers with shape
151 | // (batch size, sequence length, hidden size).
152 | // - `attentions`: optional slice of tensors of length num hidden layers with shape
153 | // (batch size, sequence length, hidden size).
154 | // - `err`: error
155 | func (mlm *RobertaForMaskedLM) Forward(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) {
156 |
157 | hiddenState, _, allHiddenStates, allAttentions, err := mlm.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask, train)
158 |
159 | if err != nil {
160 | return ts.None, nil, nil, err
161 | }
162 |
163 | predictionScores := mlm.lmHead.Forward(hiddenState)
164 |
165 | return predictionScores, allHiddenStates, allAttentions, nil
166 | }
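// Example (an illustrative sketch, assuming `model` was built with
// NewRobertaForMaskedLM and `inputTensor` holds padded token ids as in
// roberta/model_test.go):
//
//	var output *ts.Tensor
//	var err error
//	ts.NoGrad(func() {
//		output, _, _, err = model.Forward(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
//	})
//	// output has shape (batch size, sequence length, vocab size); the argmax
//	// over the last dimension at a masked position is the predicted token id.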
167 |
168 | // RobertaClassificationHead holds data for the Roberta classification head.
169 | type RobertaClassificationHead struct {
170 | dense *nn.Linear
171 | dropout *util.Dropout
172 | outProj *nn.Linear
173 | }
174 |
175 | // NewRobertaClassificationHead creates a new RobertaClassificationHead.
176 | func NewRobertaClassificationHead(p *nn.Path, config *bert.BertConfig) *RobertaClassificationHead {
177 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig())
178 | numLabels := int64(len(config.Id2Label))
179 | outProj := nn.NewLinear(p.Sub("out_proj"), config.HiddenSize, numLabels, nn.DefaultLinearConfig())
180 | dropout := util.NewDropout(config.HiddenDropoutProb)
181 |
182 | return &RobertaClassificationHead{
183 | dense: dense,
184 | dropout: dropout,
185 | outProj: outProj,
186 | }
187 | }
188 |
189 | // ForwardT forwards a pass through the model.
190 | func (ch *RobertaClassificationHead) ForwardT(hiddenStates *ts.Tensor, train bool) *ts.Tensor {
191 | appliedDO1 := hiddenStates.MustSelect(1, 0, false).ApplyT(ch.dropout, train)
192 | appliedDense := appliedDO1.Apply(ch.dense)
193 | tanhTs := appliedDense.MustTanh(false)
194 | appliedDO2 := tanhTs.ApplyT(ch.dropout, train)
195 | retVal := appliedDO2.Apply(ch.outProj)
196 |
197 | appliedDO1.MustDrop()
198 | appliedDense.MustDrop()
199 | tanhTs.MustDrop()
200 | appliedDO2.MustDrop()
201 |
202 | return retVal
203 | }
204 |
205 | // RobertaForSequenceClassification holds data for Roberta sequence classification model.
206 | // It's used for performing sentence or document-level classification.
207 | type RobertaForSequenceClassification struct {
208 | roberta *bert.BertModel
209 | classifier *RobertaClassificationHead
210 | }
211 |
212 | // NewRobertaForSequenceClassification creates a new RobertaForSequenceClassification model.
213 | func NewRobertaForSequenceClassification(p *nn.Path, config *bert.BertConfig) *RobertaForSequenceClassification {
214 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false)
215 | classifier := NewRobertaClassificationHead(p.Sub("classifier"), config)
216 |
217 | return &RobertaForSequenceClassification{
218 | roberta: roberta,
219 | classifier: classifier,
220 | }
221 | }
222 |
223 | // Load loads model from file or model name. It also updates default configuration parameters if provided.
224 | //
225 | // This method implements `PretrainedModel` interface.
226 | func (sc *RobertaForSequenceClassification) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
227 | // var urlOrFilename string
228 | // // If modelName, infer to default configuration filename:
229 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok {
230 | // urlOrFilename = modelFile
231 | // } else {
232 | // // Otherwise, just take the input
233 | // urlOrFilename = modelNameOrPath
234 | // }
235 |
236 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin")
237 | if err != nil {
238 | return err
239 | }
240 |
241 | vs := nn.NewVarStore(device)
242 | p := vs.Root()
243 |
244 | sc.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false)
245 | sc.classifier = NewRobertaClassificationHead(p.Sub("classifier"), config.(*bert.BertConfig))
246 |
247 | // err = vs.Load(cachedFile)
248 | err = pickle.LoadAll(vs, cachedFile)
249 | if err != nil {
250 | return err
251 | }
252 |
253 | return nil
254 | }
255 |
256 | // ForwardT forwards a pass through the model.
257 | func (sc *RobertaForSequenceClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (labels *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) {
258 |
259 | hiddenState, _, hiddenStates, attentions, err := sc.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
260 | if err != nil {
261 | return ts.None, nil, nil, err
262 | }
263 |
264 | labels = sc.classifier.ForwardT(hiddenState, train)
265 | hiddenState.MustDrop()
266 |
267 | return labels, hiddenStates, attentions, nil
268 | }
269 |
270 | // RobertaForMultipleChoice holds data for Roberta multiple choice model.
271 | //
272 | // Input should be in the form `<s> Context </s> Possible choice </s>`.
273 | // The choice is made along the batch axis, assuming all elements of the batch are
274 | // alternatives to be chosen from for a given context.
275 | type RobertaForMultipleChoice struct {
276 | roberta *bert.BertModel
277 | dropout *util.Dropout
278 | classifier *nn.Linear
279 | }
280 |
281 | // NewRobertaForMultipleChoice creates a new RobertaForMultipleChoice model.
282 | func NewRobertaForMultipleChoice(p *nn.Path, config *bert.BertConfig) *RobertaForMultipleChoice {
283 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false)
284 | dropout := util.NewDropout(config.HiddenDropoutProb)
285 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, 1, nn.DefaultLinearConfig())
286 |
287 | return &RobertaForMultipleChoice{
288 | roberta: roberta,
289 | dropout: dropout,
290 | classifier: classifier,
291 | }
292 | }
293 |
294 | // Load loads model from file or model name. It also updates default configuration parameters if provided.
295 | //
296 | // This method implements `PretrainedModel` interface.
297 | func (mc *RobertaForMultipleChoice) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
298 | // var urlOrFilename string
299 | // // If modelName, infer to default configuration filename:
300 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok {
301 | // urlOrFilename = modelFile
302 | // } else {
303 | // // Otherwise, just take the input
304 | // urlOrFilename = modelNameOrPath
305 | // }
306 |
307 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin")
308 | if err != nil {
309 | return err
310 | }
311 |
312 | vs := nn.NewVarStore(device)
313 | p := vs.Root()
314 |
315 | mc.roberta = bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false)
316 | mc.dropout = util.NewDropout(config.(*bert.BertConfig).HiddenDropoutProb)
317 | classifier := nn.NewLinear(p.Sub("classifier"), config.(*bert.BertConfig).HiddenSize, 1, nn.DefaultLinearConfig())
318 | mc.classifier = classifier
319 |
320 | err = pickle.LoadAll(vs, cachedFile)
321 | if err != nil {
322 | return err
323 | }
324 |
325 | return nil
326 | }
327 |
328 | // ForwardT forwards pass through the model.
329 | func (mc *RobertaForMultipleChoice) ForwardT(inputIds, mask, tokenTypeIds, positionIds *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) {
330 |
331 | numChoices := inputIds.MustSize()[1]
332 |
333 | inputIdsSize := inputIds.MustSize()
334 | flatInputIds := inputIds.MustView([]int64{-1, inputIdsSize[len(inputIdsSize)-1]}, false)
335 |
336 | flatPositionIds := ts.None
337 | if positionIds.MustDefined() {
338 | positionIdsSize := positionIds.MustSize()
339 | flatPositionIds = positionIds.MustView([]int64{-1, positionIdsSize[len(positionIdsSize)-1]}, false)
340 | }
341 |
342 | flatTokenTypeIds := ts.None
343 | if tokenTypeIds.MustDefined() {
344 | tokenTypeIdsSize := tokenTypeIds.MustSize()
345 | flatTokenTypeIds = tokenTypeIds.MustView([]int64{-1, tokenTypeIdsSize[len(tokenTypeIdsSize)-1]}, false)
346 | }
347 |
348 | flatMask := ts.None
349 | if mask.MustDefined() {
350 | maskSize := mask.MustSize()
351 | flatMask = mask.MustView([]int64{-1, maskSize[len(maskSize)-1]}, false)
352 | }
353 |
354 | var pooledOutput *ts.Tensor
355 | _, pooledOutput, hiddenStates, attentions, err = mc.roberta.ForwardT(flatInputIds, flatMask, flatTokenTypeIds, flatPositionIds, ts.None, ts.None, ts.None, train)
356 | if err != nil {
357 | return ts.None, nil, nil, err
358 | }
359 |
360 | appliedDO := pooledOutput.ApplyT(mc.dropout, train)
361 | appliedCls := appliedDO.Apply(mc.classifier)
362 | output = appliedCls.MustView([]int64{-1, numChoices}, true)
363 |
364 | appliedDO.MustDrop()
365 |
366 | return output, hiddenStates, attentions, nil
367 | }
368 |
369 | // RobertaForTokenClassification holds data for Roberta token classification model.
370 | type RobertaForTokenClassification struct {
371 | roberta *bert.BertModel
372 | dropout *util.Dropout
373 | classifier *nn.Linear
374 | }
375 |
376 | // NewRobertaForTokenClassification creates a new RobertaForTokenClassification model.
377 | func NewRobertaForTokenClassification(p *nn.Path, config *bert.BertConfig) *RobertaForTokenClassification {
378 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false)
379 | dropout := util.NewDropout(config.HiddenDropoutProb)
380 | numLabels := int64(len(config.Id2Label))
381 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, numLabels, nn.DefaultLinearConfig())
382 |
383 | return &RobertaForTokenClassification{
384 | roberta: roberta,
385 | dropout: dropout,
386 | classifier: classifier,
387 | }
388 | }
389 |
390 | // Load loads model from file or model name. It also updates default configuration parameters if provided.
391 | //
392 | // This method implements `PretrainedModel` interface.
393 | func (tc *RobertaForTokenClassification) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
394 | // var urlOrFilename string
395 | // // If modelName, infer to default configuration filename:
396 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok {
397 | // urlOrFilename = modelFile
398 | // } else {
399 | // // Otherwise, just take the input
400 | // urlOrFilename = modelNameOrPath
401 | // }
402 |
403 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin")
404 | if err != nil {
405 | return err
406 | }
407 |
408 | vs := nn.NewVarStore(device)
409 | p := vs.Root()
410 |
411 | roberta := bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false)
412 | dropout := util.NewDropout(config.(*bert.BertConfig).HiddenDropoutProb)
413 | numLabels := int64(len(config.(*bert.BertConfig).Id2Label))
414 | classifier := nn.NewLinear(p.Sub("classifier"), config.(*bert.BertConfig).HiddenSize, numLabels, nn.DefaultLinearConfig())
415 |
416 | tc.roberta = roberta
417 | tc.dropout = dropout
418 | tc.classifier = classifier
419 |
420 | err = vs.Load(cachedFile)
421 | if err != nil {
422 | return err
423 | }
424 |
425 | return nil
426 | }
427 |
428 | // ForwardT performs a forward pass through the model.
429 | func (tc *RobertaForTokenClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (output *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) {
430 | hiddenState, _, hiddenStates, attentions, err := tc.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
431 | if err != nil {
432 | return ts.None, nil, nil, err
433 | }
434 |
435 | appliedDO := hiddenState.ApplyT(tc.dropout, train)
436 | output = appliedDO.Apply(tc.classifier)
437 |
438 | appliedDO.MustDrop()
439 |
440 | return output, hiddenStates, attentions, nil
441 | }
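
// Note (added commentary): `output` holds per-token label logits of shape
// (batch size, sequence length, number of labels), where the number of labels
// comes from config.Id2Label at construction time.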
442 |
443 | // RobertaForQuestionAnswering constructs layers for Roberta question answering model.
444 | type RobertaForQuestionAnswering struct {
445 | roberta *bert.BertModel
446 | qaOutputs *nn.Linear
447 | }
448 |
449 | // NewRobertaForQuestionAnswering creates a new RobertaForQuestionAnswering model.
450 | func NewRobertaForQuestionAnswering(p *nn.Path, config *bert.BertConfig) *RobertaForQuestionAnswering {
451 | roberta := bert.NewBertModel(p.Sub("roberta"), config, false)
452 | numLabels := int64(2)
453 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.HiddenSize, numLabels, nn.DefaultLinearConfig())
454 |
455 | return &RobertaForQuestionAnswering{
456 | roberta: roberta,
457 | qaOutputs: qaOutputs,
458 | }
459 | }
460 |
461 | // Load loads model from file or model name. It also updates default configuration parameters if provided.
462 | //
463 | // This method implements `PretrainedModel` interface.
464 | func (qa *RobertaForQuestionAnswering) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
465 | // var urlOrFilename string
466 | // // If modelName, infer to default configuration filename:
467 | // if modelFile, ok := pretrained.RobertaModels[modelNameOrPath]; ok {
468 | // urlOrFilename = modelFile
469 | // } else {
470 | // // Otherwise, just take the input
471 | // urlOrFilename = modelNameOrPath
472 | // }
473 |
474 | cachedFile, err := util.CachedPath(modelNameOrPath, "pytorch_model.bin")
475 | if err != nil {
476 | return err
477 | }
478 |
479 | vs := nn.NewVarStore(device)
480 | p := vs.Root()
481 |
482 | roberta := bert.NewBertModel(p.Sub("roberta"), config.(*bert.BertConfig), false)
483 | numLabels := int64(2)
484 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.(*bert.BertConfig).HiddenSize, numLabels, nn.DefaultLinearConfig())
485 |
486 | qa.roberta = roberta
487 | qa.qaOutputs = qaOutputs
488 |
489 | err = vs.Load(cachedFile)
490 | if err != nil {
491 | return err
492 | }
493 |
494 | return nil
495 | }
496 |
497 | // ForwardT performs a forward pass through the model.
498 | func (qa *RobertaForQuestionAnswering) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (startScores, endScores *ts.Tensor, hiddenStates, attentions []ts.Tensor, err error) {
499 | hiddenState, _, hiddenStates, attentions, err := qa.roberta.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
500 | if err != nil {
501 | return ts.None, ts.None, nil, nil, err
502 | }
503 |
504 | sequenceOutput := hiddenState.Apply(qa.qaOutputs)
505 | logits := sequenceOutput.MustSplit(1, -1, true)
506 | startScores = logits[0].MustSqueezeDim(-1, false)
507 | endScores = logits[1].MustSqueezeDim(-1, false)
508 |
509 | for _, x := range logits {
510 | x.MustDrop()
511 | }
512 |
513 | return startScores, endScores, hiddenStates, attentions, nil
514 | }
515 |
--------------------------------------------------------------------------------
/bert/model.go:
--------------------------------------------------------------------------------
1 | package bert
2 |
3 | import (
4 | "fmt"
5 | "log"
6 |
7 | "github.com/sugarme/gotch"
8 | "github.com/sugarme/gotch/nn"
9 | "github.com/sugarme/gotch/pickle"
10 | "github.com/sugarme/gotch/ts"
11 |
12 | "github.com/sugarme/transformer/pretrained"
13 | "github.com/sugarme/transformer/util"
14 | )
15 |
16 | // BertModel defines base architecture for BERT models.
17 | // Task-specific models can be built from this base model.
18 | //
19 | // Fields:
20 | // - Embeddings: for `token`, `position` and `segment` embeddings
21 | // - Encoder: a stack of layers. Each layer is composed of a `self-attention`,
22 | //
23 | // an `intermediate` (linear) and an output (linear + layer norm) sub-layer.
24 | // - Pooler: linear layer applied to the first element of the sequence (the `[CLS]` token)
25 | // - IsDecoder: whether the model is used as a decoder. If set to `true`,
26 | //
27 | // a causal mask will be applied to hide future positions that should not be attended to.
28 | type BertModel struct {
29 | Embeddings *BertEmbeddings
30 | Encoder *BertEncoder
31 | Pooler *BertPooler
32 | IsDecoder bool
33 | }
34 |
35 | // NewBertModel builds a new `BertModel`.
36 | //
37 | // Params:
38 | // - `p`: Variable store path for the root of the BERT Model
39 | // - `config`: BertConfig configuration for model architecture and decoder status
40 | func NewBertModel(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertModel {
41 | changeName := true
42 | if len(changeNameOpt) > 0 {
43 | changeName = changeNameOpt[0]
44 | }
45 | isDecoder := false
46 | if config.IsDecoder {
47 | isDecoder = true
48 | }
49 |
50 | embeddings := NewBertEmbeddings(p.Sub("embeddings"), config, changeName)
51 | encoder := NewBertEncoder(p.Sub("encoder"), config, changeName)
52 | pooler := NewBertPooler(p.Sub("pooler"), config)
53 |
54 | return &BertModel{embeddings, encoder, pooler, isDecoder}
55 | }
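
// Example (a minimal construction sketch, not from the original source; it assumes a
// config loaded with ConfigFromFile as in the NewBertForSequenceClassification example
// further below):
//
//	vs := nn.NewVarStore(gotch.CPU)
//	config := ConfigFromFile("path/to/config.json")
//	model := NewBertModel(vs.Root().Sub("bert"), config)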
56 |
57 | // ForwardT performs a forward pass through the model.
58 | //
59 | // Params:
60 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
61 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
62 | // - `mask`: optional mask of shape (batch size, sequence length).
63 | // Masked position have value 0, non-masked value 1. If None set to 1.
64 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
65 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
66 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
67 | // If None, will be incremented from 0.
68 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
69 | // If None, input ids must be provided (see `inputIds`).
70 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size).
71 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None,
72 | // used in the cross-attention layer as keys and values (query from the decoder).
73 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length).
74 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values.
75 | // Positions with value 0 will be masked.
76 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
77 | //
78 | // Returns:
79 | // - `output`: tensor of shape (batch size, sequence length, hidden size)
80 | // - `pooledOutput`: tensor of shape (batch size, hidden size)
81 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
82 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
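//
// Example (a minimal sketch added for illustration, not part of the original source;
// it assumes `model` is a *BertModel built with NewBertModel and passes `ts.None` for
// every optional argument):
//
//	ids := ts.MustOnes([]int64{1, 16}, gotch.Int64, gotch.CPU)  // (batch size, sequence length) token ids
//	mask := ts.MustOnes([]int64{1, 16}, gotch.Int64, gotch.CPU) // attend to every position
//	output, pooled, hiddenStates, attentions, err := model.ForwardT(ids, mask, ts.None, ts.None, ts.None, ts.None, ts.None, false)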
83 | func (b *BertModel) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal1, retVal2 *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor, err error) {
84 |
85 | var (
86 | inputShape []int64
87 | device gotch.Device
88 | )
89 |
90 | if inputIds.MustDefined() {
91 | if inputEmbeds.MustDefined() {
92 | err = fmt.Errorf("Only one of input ids or input embeddings may be set\n")
93 | return
94 | }
95 | inputShape = inputIds.MustSize()
96 | device = inputIds.MustDevice()
97 | } else {
98 | if inputEmbeds.MustDefined() {
99 | size := inputEmbeds.MustSize()
100 | inputShape = []int64{size[0], size[1]}
101 | device = inputEmbeds.MustDevice()
102 | } else {
103 | err = fmt.Errorf("At least one of input ids or input embeddings must be set\n")
104 | return
105 | }
106 | }
107 |
108 | var maskTs *ts.Tensor
109 | if mask.MustDefined() {
110 | maskTs = mask
111 | } else {
112 | maskTs = ts.MustOnes(inputShape, gotch.Int64, device)
113 | }
114 |
115 | var extendedAttentionMask *ts.Tensor
116 | switch maskTs.Dim() {
117 | case 3:
118 | extendedAttentionMask = maskTs.MustUnsqueeze(1, false) // TODO: check and delete maskTs if not using later
119 | case 2:
120 | if b.IsDecoder {
121 | seqIds := ts.MustArange(ts.IntScalar(inputShape[1]), gotch.Float, device)
122 | causalMaskTmp := seqIds.MustUnsqueeze(0, false).MustUnsqueeze(0, true).MustRepeat([]int64{inputShape[0], inputShape[1], 1}, true)
123 | causalMask := causalMaskTmp.MustLeTensor(seqIds.MustUnsqueeze(0, true).MustUnsqueeze(1, true), true)
124 | extendedAttentionMask = causalMask.MustMatmul(mask.MustUnsqueeze(1, false).MustUnsqueeze(1, true), true)
125 | } else {
126 | extendedAttentionMask = maskTs.MustUnsqueeze(1, false).MustUnsqueeze(1, true)
127 | }
128 |
129 | default:
130 | return ts.None, ts.None, nil, nil, fmt.Errorf("Invalid attention mask dimension, must be 2 or 3, got %v\n", maskTs.Dim())
131 | }
132 |
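// NOTE. Added commentary: the {0,1} attention mask is converted into an additive mask
// below: allowed positions become 0 and masked positions become -10000, so that masked
// positions receive effectively zero weight after the attention softmax.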
133 | extendedAttnMask := extendedAttentionMask.MustOnesLike(false).MustSub(extendedAttentionMask, true).MustMulScalar(ts.FloatScalar(-10000.0), true)
134 |
135 | // NOTE. encoderExtendedAttentionMask is an optional tensor
136 | var encoderExtendedAttentionMask *ts.Tensor
137 | if b.IsDecoder && encoderHiddenStates.MustDefined() {
138 | size := encoderHiddenStates.MustSize()
139 | var encoderMaskTs *ts.Tensor
140 | if encoderMask.MustDefined() {
141 | encoderMaskTs = encoderMask
142 | } else {
143 | encoderMaskTs = ts.MustOnes([]int64{size[0], size[1]}, gotch.Int64, device)
144 | }
145 |
146 | switch encoderMaskTs.Dim() {
147 | case 2:
148 | encoderExtendedAttentionMask = encoderMaskTs.MustUnsqueeze(1, true).MustUnsqueeze(1, true)
149 | case 3:
150 | encoderExtendedAttentionMask = encoderMaskTs.MustUnsqueeze(1, true)
151 | default:
152 | err = fmt.Errorf("Invalid encoder attention mask dimension, must be 2, or 3 got %v\n", encoderMaskTs.Dim())
153 | return
154 | }
155 | } else {
156 | encoderExtendedAttentionMask = ts.None
157 | }
158 |
159 | embeddingOutput, err := b.Embeddings.ForwardT(inputIds, tokenTypeIds, positionIds, inputEmbeds, train)
160 | if err != nil {
161 | return
162 | }
163 |
164 | hiddenState, allHiddenStates, allAttentions := b.Encoder.ForwardT(embeddingOutput, extendedAttnMask, encoderHiddenStates, encoderExtendedAttentionMask, train)
165 |
166 | pooledOutput := b.Pooler.Forward(hiddenState)
167 |
168 | return hiddenState, pooledOutput, allHiddenStates, allAttentions, nil
169 | }
170 |
171 | // BertPredictionHeadTransform:
172 | // ============================
173 |
174 | // BertPredictionHeadTransform holds layers of BERT prediction head transform.
175 | type BertPredictionHeadTransform struct {
176 | Dense *nn.Linear
177 | Activation util.ActivationFn
178 | LayerNorm *nn.LayerNorm
179 | }
180 |
181 | // NewBertPredictionHeadTransform creates a new BertPredictionHeadTransform.
182 | func NewBertPredictionHeadTransform(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertPredictionHeadTransform {
183 | changeName := true
184 | if len(changeNameOpt) > 0 {
185 | changeName = changeNameOpt[0]
186 | }
187 | dense := nn.NewLinear(p.Sub("dense"), config.HiddenSize, config.HiddenSize, nn.DefaultLinearConfig())
188 | activation, ok := util.ActivationFnMap[config.HiddenAct]
189 | if !ok {
190 | log.Fatalf("Unsupported activation function - %v\n", config.HiddenAct)
191 | }
192 |
193 | lnConfig := nn.DefaultLayerNormConfig()
194 | if changeName {
195 | lnConfig.WsName = "gamma"
196 | lnConfig.BsName = "beta"
197 | }
198 | layerNorm := nn.NewLayerNorm(p.Sub("LayerNorm"), []int64{config.HiddenSize}, lnConfig)
199 |
200 | return &BertPredictionHeadTransform{dense, activation, layerNorm}
201 | }
202 |
203 | // Forward performs a forward pass through the transform.
204 | func (bpht *BertPredictionHeadTransform) Forward(hiddenStates *ts.Tensor) (retVal *ts.Tensor) {
205 | tmp1 := hiddenStates.Apply(bpht.Dense)
206 | tmp2 := bpht.Activation.Fwd(tmp1)
207 | retVal = tmp2.Apply(bpht.LayerNorm)
208 | tmp1.MustDrop()
209 | tmp2.MustDrop()
210 |
211 | return retVal
212 | }
213 |
214 | // BertLMPredictionHead:
215 | // =====================
216 |
217 | // BertLMPredictionHead constructs layers for BERT prediction head.
218 | type BertLMPredictionHead struct {
219 | Transform *BertPredictionHeadTransform
220 | Decoder *util.LinearNoBias
221 | Bias *ts.Tensor
222 | }
223 |
224 | // NewBertLMPredictionHead creates BertLMPredictionHead.
225 | func NewBertLMPredictionHead(p *nn.Path, config *BertConfig) (*BertLMPredictionHead, error) {
226 | path := p.Sub("predictions")
227 | transform := NewBertPredictionHeadTransform(path.Sub("transform"), config)
228 | decoder, err := util.NewLinearNoBias(path.Sub("decoder"), config.HiddenSize, config.VocabSize, util.DefaultLinearNoBiasConfig())
229 | if err != nil {
230 | return nil, err
231 | }
232 | bias, err := path.NewVar("bias", []int64{config.VocabSize}, nn.NewKaimingUniformInit())
233 | if err != nil {
234 | return nil, err
235 | }
236 |
237 | return &BertLMPredictionHead{transform, decoder, bias}, nil
238 | }
239 |
240 | // Forward performs a forward pass through the prediction head.
241 | func (ph *BertLMPredictionHead) Forward(hiddenState *ts.Tensor) *ts.Tensor {
242 | fwTensor := ph.Transform.Forward(hiddenState).Apply(ph.Decoder)
243 |
244 | retVal := fwTensor.MustAdd(ph.Bias, false)
245 | fwTensor.MustDrop()
246 |
247 | return retVal
248 | }
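
// Note (added commentary): the decoder projects hidden states of shape
// (batch size, sequence length, hidden size) to vocabulary logits of shape
// (batch size, sequence length, vocab size); the separate `Bias` variable is added
// afterwards because the decoder layer itself is created without a bias
// (util.NewLinearNoBias above).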
249 |
250 | // BertForMaskedLM:
251 | // ================
252 |
253 | // BertForMaskedLM is a BERT model for masked language modeling.
254 | type BertForMaskedLM struct {
255 | bert *BertModel
256 | cls *BertLMPredictionHead
257 | }
258 |
259 | // NewBertForMaskedLM creates BertForMaskedLM.
260 | func NewBertForMaskedLM(p *nn.Path, config *BertConfig, changeNameOpt ...bool) (*BertForMaskedLM, error) {
261 | changeName := true
262 | if len(changeNameOpt) > 0 {
263 | changeName = changeNameOpt[0]
264 | }
265 | bert := NewBertModel(p.Sub("bert"), config, changeName)
266 | cls, err := NewBertLMPredictionHead(p.Sub("cls"), config)
267 | if err != nil {
268 | return nil, err
269 | }
270 |
271 | return &BertForMaskedLM{bert, cls}, nil
272 | }
273 |
274 | // Load loads model from file or model name. It also updates
275 | // default configuration parameters if provided.
276 | // This method implements `PretrainedModel` interface.
277 | func (mlm *BertForMaskedLM) Load(modelNameOrPath string, config interface{ pretrained.Config }, params map[string]interface{}, device gotch.Device) error {
278 | vs := nn.NewVarStore(device)
279 | p := vs.Root()
280 | mlm.bert = NewBertModel(p.Sub("bert"), config.(*BertConfig))
281 | var err error
282 | mlm.cls, err = NewBertLMPredictionHead(p.Sub("cls"), config.(*BertConfig))
283 | if err != nil {
284 | return err
285 | }
286 |
287 | // err = vs.Load(cachedFile)
288 | err = pickle.LoadAll(vs, modelNameOrPath)
289 | if err != nil {
290 | log.Fatalf("Load model weight error: \n%v", err)
291 | }
292 |
293 | return nil
294 | }
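
// Example (a minimal usage sketch, not from the original source; `config` is assumed to
// be a *BertConfig, loaded elsewhere, that satisfies the pretrained.Config interface):
//
//	mlm := new(BertForMaskedLM)
//	if err := mlm.Load("path/to/pytorch_model.bin", config, nil, gotch.CPU); err != nil {
//		// handle error
//	}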
295 |
296 | // ForwardT performs a forward pass through the model.
297 | //
298 | // Params:
299 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
300 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
301 | // - `mask`: optional mask of shape (batch size, sequence length).
302 | // Masked position have value 0, non-masked value 1. If None set to 1.
303 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
304 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
305 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
306 | // If None, will be incremented from 0.
307 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
308 | // If None, input ids must be provided (see `inputIds`).
309 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size).
310 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None,
311 | // used in the cross-attention layer as keys and values (query from the decoder).
312 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length).
313 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values.
314 | // Positions with value 0 will be masked.
315 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
316 | //
317 | // Returns:
318 | // - `output`: prediction scores of shape (batch size, sequence length, vocabulary size)
319 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
320 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
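//
// Example (an illustrative sketch, not from the original source; `ids` and `mask` are
// assumed to be (batch size, sequence length) int64 tensors produced by a tokenizer,
// with the position to predict set to the [MASK] token id):
//
//	scores, hiddenStates, attentions := mlm.ForwardT(ids, mask, ts.None, ts.None, ts.None, ts.None, ts.None, false)
//	// `scores` has shape (batch size, sequence length, vocab size); the predicted token id
//	// at a masked position is the index with the highest score along the last dimension.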
321 | func (mlm *BertForMaskedLM) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask *ts.Tensor, train bool) (retVal1 *ts.Tensor, optRetVal1, optRetVal2 []ts.Tensor) {
322 |
323 | hiddenState, _, allHiddenStates, allAttentions, err := mlm.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, encoderHiddenStates, encoderMask, train)
324 | if err != nil {
325 | log.Fatal(err)
326 | }
327 |
328 | predictionScores := mlm.cls.Forward(hiddenState)
329 |
330 | return predictionScores, allHiddenStates, allAttentions
331 | }
332 |
333 | // BERT for sequence classification:
334 | // =================================
335 |
336 | // BertForSequenceClassification is a base BERT model with a classifier head to perform
337 | // sentence or document-level classification.
338 | //
339 | // It is made of the following blocks:
340 | // - `bert`: Base BertModel
341 | // - `classifier`: BERT linear layer for classification
342 | type BertForSequenceClassification struct {
343 | bert *BertModel
344 | dropout *util.Dropout
345 | classifier *nn.Linear
346 | }
347 |
348 | // NewBertForSequenceClassification creates a new `BertForSequenceClassification`.
349 | //
350 | // Params:
351 | // - `p`: Variable store path for the root of the BertForSequenceClassification model
352 | // - `config`: `BertConfig` object defining the model architecture and number of classes
353 | //
354 | // Example:
355 | //
356 | // device := gotch.CPU
357 | // vs := nn.NewVarStore(device)
358 | // config := bert.ConfigFromFile("path/to/config.json")
359 | // p := vs.Root()
360 | // model := bert.NewBertForSequenceClassification(p.Sub("bert"), config)
361 | func NewBertForSequenceClassification(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForSequenceClassification {
362 | changeName := true
363 | if len(changeNameOpt) > 0 {
364 | changeName = changeNameOpt[0]
365 | }
366 | bert := NewBertModel(p.Sub("bert"), config, changeName)
367 | dropout := util.NewDropout(config.HiddenDropoutProb)
368 | numLabels := len(config.Id2Label)
369 |
370 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig())
371 |
372 | return &BertForSequenceClassification{
373 | bert: bert,
374 | dropout: dropout,
375 | classifier: classifier,
376 | }
377 | }
378 |
379 | // ForwardT performs a forward pass through the model.
380 | //
381 | // Params:
382 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
383 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
384 | // - `mask`: optional mask of shape (batch size, sequence length).
385 | // Masked position have value 0, non-masked value 1. If None set to 1.
386 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
387 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
388 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
389 | // If None, will be incremented from 0.
390 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
391 | // If None, input ids must be provided (see `inputIds`).
392 | // - `encoderHiddenStates`: optional encoder hidden state of shape (batch size, encoder sequence length, hidden size).
393 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None,
394 | // used in the cross-attention layer as keys and values (query from the decoder).
395 | // - `encoderMask`: optional encoder attention mask of shape (batch size, encoder sequence length).
396 | // If the model is defined as a decoder and the `encoderHiddenStates` is not None, used to mask encoder values.
397 | // Positions with value 0 will be masked.
398 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
399 | //
400 | // Returns:
401 | // - `output`: classification logits of shape (batch size, number of labels)
402 | //   (computed from the pooled output, which has shape (batch size, hidden size))
403 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
404 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
405 | func (bsc *BertForSequenceClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) {
406 | _, pooledOutput, allHiddenStates, allAttentions, err := bsc.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
407 |
408 | if err != nil {
409 | log.Fatalf("call bert.ForwardT error: %v", err)
410 | }
411 |
412 | dropoutOutput := pooledOutput.ApplyT(bsc.dropout, train)
413 |
414 | output := dropoutOutput.Apply(bsc.classifier)
415 | dropoutOutput.MustDrop()
416 |
417 | return output, allHiddenStates, allAttentions
418 | }
419 |
420 | // BERT for multiple choices :
421 | // ===========================
422 |
423 | // BertForMultipleChoice constructs multiple choices model using a BERT base model and a linear classifier.
424 | // Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
425 | // assuming all elements of the batch are alternatives to be chosen from for a given context.
426 | //
427 | // It is made of the following blocks:
428 | // - `bert`: Base BertModel
429 | // - `classifier`: Linear layer for multiple choices
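//
// For example (added illustration, not from the original source), a question with three
// candidate answers is encoded as three sequences
//
//	[CLS] question [SEP] choice_0 [SEP]
//	[CLS] question [SEP] choice_1 [SEP]
//	[CLS] question [SEP] choice_2 [SEP]
//
// stacked into an input id tensor of shape (batch size, number of choices, sequence length);
// ForwardT flattens the first two dimensions before calling the base model and returns
// choice logits of shape (batch size, number of choices).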
430 | type BertForMultipleChoice struct {
431 | bert *BertModel
432 | dropout *util.Dropout
433 | classifier *nn.Linear
434 | }
435 |
436 | // NewBertForMultipleChoice creates a new `BertForMultipleChoice`.
437 | //
438 | // Params:
439 | // - `p`: Variable store path for the root of the BertForMultipleChoice model
440 | // - `config`: `BertConfig` object defining the model architecture
441 | func NewBertForMultipleChoice(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForMultipleChoice {
442 | changeName := true
443 | if len(changeNameOpt) > 0 {
444 | changeName = changeNameOpt[0]
445 | }
446 | bert := NewBertModel(p.Sub("bert"), config, changeName)
447 | dropout := util.NewDropout(config.HiddenDropoutProb)
448 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, 1, nn.DefaultLinearConfig())
449 |
450 | return &BertForMultipleChoice{
451 | bert: bert,
452 | dropout: dropout,
453 | classifier: classifier,
454 | }
455 | }
456 |
457 | // ForwardT performs a forward pass through the model.
458 | //
459 | // Params:
460 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
461 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
462 | // - `mask`: optional mask of shape (batch size, sequence length).
463 | // Masked position have value 0, non-masked value 1. If None set to 1.
464 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
465 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
466 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
467 | // If None, will be incremented from 0.
468 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
469 | //
470 | // Returns:
471 | // - `output`: choice logits of shape (batch size, number of choices)
472 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
473 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
474 | func (mc *BertForMultipleChoice) ForwardT(inputIds, mask, tokenTypeIds, positionIds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) {
475 | inputIdsSize := inputIds.MustSize()
476 | fmt.Printf("inputIdsSize: %v\n", inputIdsSize)
477 | numChoices := inputIdsSize[1]
478 | inputIdsView := inputIds.MustView([]int64{-1, inputIdsSize[len(inputIdsSize)-1]}, false)
479 |
480 | maskView := ts.None
481 | if mask.MustDefined() {
482 | maskSize := mask.MustSize()
483 | maskView = mask.MustView([]int64{-1, maskSize[len(maskSize)-1]}, false)
484 | }
485 |
486 | tokenTypeIdsView := ts.None
487 | if tokenTypeIds.MustDefined() {
488 | tokenTypeIdsSize := tokenTypeIds.MustSize()
489 | tokenTypeIdsView = tokenTypeIds.MustView([]int64{-1, tokenTypeIdsSize[len(tokenTypeIdsSize)-1]}, false)
490 | }
491 |
492 | positionIdsView := ts.None
493 | if positionIds.MustDefined() {
494 | positionIdsSize := positionIds.MustSize()
495 | positionIdsView = positionIds.MustView([]int64{-1, positionIdsSize[len(positionIdsSize)-1]}, false)
496 | }
497 |
498 | _, pooledOutput, allHiddenStates, allAttentions, err := mc.bert.ForwardT(inputIdsView, maskView, tokenTypeIdsView, positionIdsView, ts.None, ts.None, ts.None, train)
499 | if err != nil {
500 | log.Fatalf("Call 'BertForMultipleChoice ForwardT' method error: %v\n", err)
501 | }
502 |
503 | outputDropout := pooledOutput.ApplyT(mc.dropout, train)
504 | outputClassifier := outputDropout.Apply(mc.classifier)
505 |
506 | output := outputClassifier.MustView([]int64{-1, numChoices}, false)
507 |
508 | outputDropout.MustDrop()
509 | outputClassifier.MustDrop()
510 |
511 | return output, allHiddenStates, allAttentions
512 | }
513 |
514 | // BERT for token classification (e.g., NER, POS):
515 | // ===============================================
516 |
517 | // BertForTokenClassification constructs a token-level classifier predicting a label for each token provided.
518 | // Note that because of wordpiece tokenization, the labels predicted are not necessarily aligned with words in the sentence.
519 | //
520 | // It is made of the following blocks:
521 | // - `bert`: Base BertModel
522 | // - `classifier`: Linear layer for token classification
523 | type BertForTokenClassification struct {
524 | bert *BertModel
525 | dropout *util.Dropout
526 | classifier *nn.Linear
527 | }
528 |
529 | // NewBertForTokenClassification creates a new `BertForTokenClassification`
530 | //
531 | // Params:
532 | // - `p`: Variable store path for the root of the BertForTokenClassification model
533 | // - `config`: `BertConfig` object defining the model architecture, number of output labels and label mapping
534 | func NewBertForTokenClassification(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForTokenClassification {
535 | changeName := true
536 | if len(changeNameOpt) > 0 {
537 | changeName = changeNameOpt[0]
538 | }
539 | bert := NewBertModel(p.Sub("bert"), config, changeName)
540 | dropout := util.NewDropout(config.HiddenDropoutProb)
541 |
542 | numLabels := len(config.Id2Label)
543 | classifier := nn.NewLinear(p.Sub("classifier"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig())
544 |
545 | return &BertForTokenClassification{
546 | bert: bert,
547 | dropout: dropout,
548 | classifier: classifier,
549 | }
550 | }
551 |
552 | // ForwardT performs a forward pass through the model.
553 | //
554 | // Params:
555 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
556 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
557 | // - `mask`: optional mask of shape (batch size, sequence length).
558 | // Masked position have value 0, non-masked value 1. If None set to 1.
559 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
560 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
561 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
562 | // If None, will be incremented from 0.
563 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
564 | // If None, input ids must be provided (see `inputIds`).
565 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
566 | //
567 | // Returns:
568 | // - `output`: token classification logits of shape (batch size, sequence length, number of labels)
569 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
570 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
571 | func (tc *BertForTokenClassification) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) {
572 |
573 | hiddenState, _, allHiddenStates, allAttentions, err := tc.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
574 | if err != nil {
575 | log.Fatalf("Call 'BertForTokenClassification ForwardT' method error: %v\n", err)
576 | }
577 |
578 | outputDropout := hiddenState.ApplyT(tc.dropout, train)
579 | output := outputDropout.Apply(tc.classifier)
580 |
581 | outputDropout.MustDrop()
582 |
583 | return output, allHiddenStates, allAttentions
584 | }
585 |
586 | // BERT for question answering:
587 | // ============================
588 |
589 | // BertForQuestionAnswering constructs an extractive question-answering model based on a BERT language model. It identifies the segment of a context that answers a provided question.
590 | //
591 | // Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
592 | // See the question answering pipeline (also provided in this package) for more details.
593 | //
594 | // It is made of the following blocks:
595 | // - `bert`: Base BertModel
596 | // - `qa_outputs`: Linear layer for question answering
597 | type BertForQuestionAnswering struct {
598 | bert *BertModel
599 | qaOutputs *nn.Linear
600 | }
601 |
602 | // NewBertForQuestionAnswering creates a new `BertForQuestionAnswering`.
603 | //
604 | // Params:
605 | // - `p`: Variable store path for the root of the BertForQuestionAnswering model
606 | // - `config`: `BertConfig` object defining the model architecture
607 | func NewForBertQuestionAnswering(p *nn.Path, config *BertConfig, changeNameOpt ...bool) *BertForQuestionAnswering {
608 | changeName := true
609 | if len(changeNameOpt) > 0 {
610 | changeName = changeNameOpt[0]
611 | }
612 | bert := NewBertModel(p.Sub("bert"), config, changeName)
613 |
614 | numLabels := 2
615 | qaOutputs := nn.NewLinear(p.Sub("qa_outputs"), config.HiddenSize, int64(numLabels), nn.DefaultLinearConfig())
616 |
617 | return &BertForQuestionAnswering{
618 | bert: bert,
619 | qaOutputs: qaOutputs,
620 | }
621 | }
622 |
623 | // ForwardT performs a forward pass through the model.
624 | //
625 | // Params:
626 | // - `inputIds`: optional input tensor of shape (batch size, sequence length).
627 | // If None, pre-computed embeddings must be provided (see `inputEmbeds`)
628 | // - `mask`: optional mask of shape (batch size, sequence length).
629 | // Masked position have value 0, non-masked value 1. If None set to 1.
630 | // - `tokenTypeIds`: optional segment id of shape (batch size, sequence length).
631 | // Convention is value of 0 for the first sentence (incl. [SEP]) and 1 for the second sentence. If None set to 0.
632 | // - `positionIds`: optional position ids of shape (batch size, sequence length).
633 | // If None, will be incremented from 0.
634 | // - `inputEmbeds`: optional pre-computed input embeddings of shape (batch size, sequence length, hidden size).
635 | // If None, input ids must be provided (see `inputIds`).
636 | // - `train`: boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
637 | //
638 | // Returns:
639 | // - `startScores`, `endScores`: span logits, each of shape (batch size, sequence length)
640 | // - `hiddenStates`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
641 | // - `attentions`: slice of tensors of length numHiddenLayers with shape (batch size, sequenceLength, hiddenSize)
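//
// Example (an illustrative sketch, not from the original source; `ids` and `mask` are
// assumed to encode "[CLS] question [SEP] context [SEP]" as (batch size, sequence length)
// int64 tensors):
//
//	startScores, endScores, _, _ := qa.ForwardT(ids, mask, ts.None, ts.None, ts.None, false)
//	// The answer span is typically read off as the tokens between the highest-scoring
//	// start position and the highest-scoring end position (with end >= start).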
642 | func (qa *BertForQuestionAnswering) ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds *ts.Tensor, train bool) (retVal1, retVal2 *ts.Tensor, retValOpt1, retValOpt2 []ts.Tensor) {
643 |
644 | hiddenState, _, allHiddenStates, allAttentions, err := qa.bert.ForwardT(inputIds, mask, tokenTypeIds, positionIds, inputEmbeds, ts.None, ts.None, train)
645 | if err != nil {
646 | log.Fatalf("Call 'BertForQuestionAnswering ForwardT' method error: %v\n", err)
647 | }
648 |
649 | sequenceOutput := hiddenState.Apply(qa.qaOutputs)
650 | logits := sequenceOutput.MustSplit(1, -1, false) // -1: split along the last dimension
651 | startLogits := logits[0].MustSqueezeDim(int64(-1), false)
652 | endLogits := logits[1].MustSqueezeDim(int64(-1), false)
653 |
654 | return startLogits, endLogits, allHiddenStates, allAttentions
655 | }
656 |
--------------------------------------------------------------------------------