├── imageutils ├── zebra.png ├── load.go ├── augmentation_test.go └── tensor.go ├── table ├── fixtures │ ├── fruits.csv │ └── census.csv ├── table_test.go └── table.go ├── examples ├── vgg16 │ ├── zebra.png │ ├── download_weights.sh │ └── main.go ├── tabnet │ ├── download_dataset.sh │ └── census.go ├── mnist │ ├── download_dataset.sh │ ├── loader_test.go │ ├── fc │ │ └── fc.go │ └── loader.go └── simple-tabnet │ └── simple.go ├── .gitignore ├── init.go ├── sequential.go ├── layer.go ├── init_test.go ├── tabnet ├── init_test.go ├── glu_block_test.go ├── feature_transformer.go ├── tab_net.go ├── glu_block.go ├── classifier.go ├── attentive_transformer.go ├── attentive_transformer_test.go ├── tab_net_regressor.go ├── feature_transformer_test.go ├── tab_net_test.go ├── tab_net_no_embeddings_test.go ├── tab_net_regressor_test.go └── tab_net_no_embeddings.go ├── errors.go ├── utils.go ├── batcher.go ├── module.go ├── activation └── activation.go ├── print.go ├── watchables.go ├── storage ├── storage.go ├── base.go └── nn1.go ├── activation.go ├── README.md ├── embedding.go ├── imagenet └── classifier.go ├── avg_pool.go ├── losses.go ├── max_pool.go ├── embedding_test.go ├── embedding_generator.go ├── data_loader_test.go ├── fc_test.go ├── weights.go ├── conv2d.go ├── glu.go ├── go.mod ├── fc.go ├── batch_norm.go ├── gbn.go ├── gbn_test.go ├── vggface2 ├── identity_block.go ├── conv_block.go └── vggface2.go ├── embedding_generator_test.go ├── vgg ├── block.go └── vgg16.go ├── losses_test.go ├── ui └── ui.go ├── batch_norm_test.go ├── validate.go ├── data_loader.go ├── model.go ├── train.go └── lstm └── lstm.go /imageutils/zebra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcu/godl/HEAD/imageutils/zebra.png -------------------------------------------------------------------------------- /table/fixtures/fruits.csv: -------------------------------------------------------------------------------- 1 | 
Banana, 100, 0.1 2 | Strawberry, 200, 0.2 3 | Lemon, 150, 0.1 -------------------------------------------------------------------------------- /examples/vgg16/zebra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcu/godl/HEAD/examples/vgg16/zebra.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.gz 2 | go.work 3 | go.work.sum 4 | *.py 5 | *~ 6 | .vscode 7 | .DS_Store 8 | *.svg 9 | *.log 10 | -------------------------------------------------------------------------------- /init.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | func init() { 9 | rand.Seed(time.Now().UnixNano()) 10 | } 11 | -------------------------------------------------------------------------------- /examples/tabnet/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## https://archive.ics.uci.edu/ml/datasets/adult 4 | 5 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data 6 | 7 | -------------------------------------------------------------------------------- /examples/vgg16/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | wget https://github.com/dcu/godl/releases/download/vgg16/vgg16.nn1.gz 4 | wget https://github.com/dcu/godl/releases/download/vgg16/vgg16_notop.nn1.gz 5 | 6 | -------------------------------------------------------------------------------- /sequential.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | // Sequential runs the given layers one after the other 4 | func Sequential(m *Model, modules ...Module) ModuleList { 
5 | _ = AddLayer("Sequential") 6 | 7 | list := ModuleList{} 8 | list.Add(modules...) 9 | 10 | return list 11 | } 12 | -------------------------------------------------------------------------------- /examples/mnist/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 4 | wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 5 | wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 6 | wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 7 | 8 | -------------------------------------------------------------------------------- /layer.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | var ( 9 | layersCount = map[string]uint32{} 10 | countMutex = sync.Mutex{} 11 | ) 12 | 13 | type LayerType string 14 | 15 | // AddLayer increments the counter kept for typ and returns a unique layer name of the form "<typ>_<count>". 16 | func AddLayer(typ string) LayerType { 17 | countMutex.Lock() 18 | layersCount[typ]++ 19 | // read the counter while the lock is still held so a concurrent AddLayer cannot change it first 20 | count := layersCount[typ] 21 | countMutex.Unlock() 22 | 23 | return LayerType(fmt.Sprintf("%s_%d", typ, count)) 24 | } 25 | -------------------------------------------------------------------------------- /init_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "log" 5 | "os" 6 | 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | var ( 11 | testLogger *log.Logger 12 | ) 13 | 14 | func init() { 15 | f, err := os.Create("test.log") 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | testLogger = log.New(f, "[G]", log.LstdFlags) 21 | } 22 | 23 | func initDummyWeights(dt tensor.Dtype, s ...int) interface{} { 24 | v := make([]float32, tensor.Shape(s).TotalSize()) 25 | 26 | for i := range v { 27 | v[i] = 1.0 28 | } 29 | 30 | return v 31 | } 32 | -------------------------------------------------------------------------------- /tabnet/init_test.go: 
-------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "log" 5 | "os" 6 | 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | var ( 11 | testLogger *log.Logger 12 | ) 13 | 14 | func init() { 15 | f, err := os.Create("test.log") 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | testLogger = log.New(f, "[G]", log.LstdFlags) 21 | } 22 | 23 | func initDummyWeights(dt tensor.Dtype, s ...int) interface{} { 24 | v := make([]float32, tensor.Shape(s).TotalSize()) 25 | 26 | for i := range v { 27 | v[i] = 1.0 28 | } 29 | 30 | return v 31 | } 32 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/fatih/color" 7 | ) 8 | 9 | // HandleErr panics if the given err is not nil 10 | func HandleErr(err error, where string, args ...interface{}) { 11 | if err == nil { 12 | return 13 | } 14 | 15 | message := fmt.Sprintf(where, args...) 16 | 17 | panic(fmt.Sprintf("%s: %v", color.RedString(message), err)) 18 | } 19 | 20 | func ErrorF(lt LayerType, template string, args ...interface{}) error { 21 | args = append([]interface{}{lt}, args...) 22 | return fmt.Errorf("[%s] "+template, args...) 
23 | } 24 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func anyNumberToFloat64(v interface{}) float64 { 8 | switch f := v.(type) { 9 | case float64: 10 | return f 11 | case int: 12 | return float64(f) 13 | case int64: 14 | return float64(f) 15 | } 16 | 17 | panic(fmt.Errorf("unsupported type: %T", v)) 18 | } 19 | 20 | func MustBeGreatherThan(lt LayerType, context string, v interface{}, base interface{}) { 21 | if anyNumberToFloat64(v) <= anyNumberToFloat64(base) { 22 | panic(fmt.Errorf("[%s] %s: %v must be greater than %v", lt, context, v, base)) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /batcher.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | func InBatches(x tensor.Tensor, batchSize int, cb func(v tensor.Tensor)) { 9 | totalSize := x.Shape()[0] 10 | batches := totalSize / batchSize 11 | 12 | for b := 0; b < batches; b++ { 13 | start := b * batchSize 14 | end := start + batchSize 15 | 16 | if start >= totalSize { 17 | break 18 | } 19 | 20 | if end > totalSize { 21 | end = totalSize 22 | } 23 | 24 | sliced, err := x.Slice(gorgonia.S(start, end)) 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | cb(sliced) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /module.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | ) 6 | 7 | type ( 8 | Node = gorgonia.Node 9 | Nodes = gorgonia.Nodes 10 | ) 11 | 12 | type Module interface { 13 | Forward(inputs ...*Node) Nodes 14 | Name() string 15 | } 16 | 17 | type ModuleList []Module 18 | 19 | func (m 
*ModuleList) Add(mods ...Module) { 20 | *m = append(*m, mods...) 21 | } 22 | 23 | func (m ModuleList) Forward(inputs ...*Node) (out Nodes) { 24 | out = inputs 25 | 26 | for _, mod := range m { 27 | out = mod.Forward(out...) 28 | } 29 | 30 | return out 31 | } 32 | 33 | func (m ModuleList) Name() string { 34 | return "ModuleList" 35 | } 36 | 37 | var ( 38 | _ Module = ModuleList{} 39 | ) 40 | -------------------------------------------------------------------------------- /activation/activation.go: -------------------------------------------------------------------------------- 1 | package activation 2 | 3 | import "gorgonia.org/gorgonia" 4 | 5 | // Function represents an activation function 6 | type Function func(*gorgonia.Node) (*gorgonia.Node, error) 7 | 8 | func Sigmoid(x *gorgonia.Node) (*gorgonia.Node, error) { 9 | return gorgonia.Sigmoid(x) 10 | } 11 | 12 | func Tanh(x *gorgonia.Node) (*gorgonia.Node, error) { 13 | return gorgonia.Tanh(x) 14 | } 15 | 16 | func Rectify(x *gorgonia.Node) (*gorgonia.Node, error) { 17 | return gorgonia.Rectify(x) 18 | } 19 | 20 | func SoftMax(x *gorgonia.Node) (*gorgonia.Node, error) { 21 | return gorgonia.SoftMax(x) 22 | } 23 | 24 | func SparseMax(x *gorgonia.Node) (*gorgonia.Node, error) { 25 | return gorgonia.Sparsemax(x) 26 | } 27 | -------------------------------------------------------------------------------- /print.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | 8 | "github.com/fatih/color" 9 | ) 10 | 11 | func info(tmpl string, args ...interface{}) { 12 | msg := fmt.Sprintf(tmpl, args...) 13 | 14 | log.Printf("=> %s", color.GreenString(msg)) 15 | } 16 | 17 | func warn(tmpl string, args ...interface{}) { 18 | msg := fmt.Sprintf(tmpl, args...) 19 | 20 | log.Printf("=> %s", color.YellowString(msg)) 21 | } 22 | 23 | func failure(tmpl string, args ...interface{}) { 24 | msg := fmt.Sprintf(tmpl, args...) 
25 | 26 | log.Printf("=> %s", color.RedString(msg)) 27 | } 28 | 29 | func fatal(tmpl string, args ...interface{}) { 30 | msg := fmt.Sprintf(tmpl, args...) 31 | 32 | log.Printf("=> %s", color.RedString(msg)) 33 | os.Exit(1) 34 | } 35 | -------------------------------------------------------------------------------- /examples/mnist/loader_test.go: -------------------------------------------------------------------------------- 1 | package mnist 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestLoadData(t *testing.T) { 10 | testCases := []struct { 11 | desc string 12 | mode Mode 13 | examples int 14 | }{ 15 | { 16 | desc: "ModeTrain", 17 | mode: ModeTrain, 18 | examples: 600000, 19 | }, 20 | { 21 | desc: "ModeTest", 22 | mode: ModeTest, 23 | examples: 100000, 24 | }, 25 | } 26 | for _, tC := range testCases { 27 | t.Run(tC.desc, func(t *testing.T) { 28 | c := require.New(t) 29 | 30 | x, y, err := Load(tC.mode, "") 31 | c.NoError(err) 32 | c.NotNil(x) 33 | c.NotNil(y) 34 | 35 | c.Equal(x.Shape()[0], y.Shape()[0]) 36 | 37 | c.Equal(tC.examples, y.Size()) 38 | c.Equal(tC.examples*28*28/10, x.Size()) 39 | }) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /watchables.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/fatih/color" 7 | "gorgonia.org/gorgonia" 8 | ) 9 | 10 | type watchable struct { 11 | name string 12 | node *gorgonia.Value 13 | } 14 | 15 | // Watch watches the given node 16 | func (m *Model) Watch(name string, node *gorgonia.Node) { 17 | var v gorgonia.Value 18 | 19 | name = fmt.Sprintf("%s <%s>", name, node.Name()) 20 | pointer := &v 21 | 22 | gorgonia.Read(node, pointer) 23 | 24 | m.watchables = append(m.watchables, watchable{name, pointer}) 25 | } 26 | 27 | func (m Model) PrintWatchables() { 28 | for _, w := range m.watchables { 29 | if w.node != nil { 30 | 
fmt.Printf("[w] %s: %v\n%v\n\n", color.GreenString(w.name), (*w.node).Shape(), (*w.node)) 31 | 32 | if m.Logger != nil { 33 | m.Logger.Printf("%s-%v\n%#v", w.name, (*w.node).Shape(), (*w.node).Data()) 34 | } 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /imageutils/load.go: -------------------------------------------------------------------------------- 1 | package imageutils 2 | 3 | import ( 4 | "fmt" 5 | "image" 6 | "os" 7 | 8 | "github.com/dcu/resize" 9 | 10 | _ "image/jpeg" // import for side effects 11 | _ "image/png" // import for side effects 12 | ) 13 | 14 | // LoadOpts contains options to load an image 15 | type LoadOpts struct { 16 | TargetSize []uint 17 | } 18 | 19 | // Load loads an image from the given path 20 | func Load(filePath string, opts LoadOpts) (image.Image, error) { 21 | file, err := os.Open(filePath) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | defer func() { _ = file.Close() }() 27 | 28 | img, _, err := image.Decode(file) 29 | if err != nil { 30 | return nil, fmt.Errorf("decoding image %v: %w", filePath, err) 31 | } 32 | 33 | if len(opts.TargetSize) == 2 { 34 | img = resize.Resize(opts.TargetSize[0], opts.TargetSize[1], img, resize.Lanczos3) 35 | } 36 | 37 | return img, nil 38 | } 39 | -------------------------------------------------------------------------------- /examples/vgg16/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | 8 | "github.com/dcu/godl/imagenet" 9 | "github.com/dcu/godl/imageutils" 10 | "github.com/dcu/godl/vgg" 11 | "github.com/fatih/color" 12 | ) 13 | 14 | func handleErr(err error) { 15 | if err != nil { 16 | panic(err) 17 | } 18 | } 19 | 20 | func main() { 21 | flag.Parse() 22 | if len(flag.Args()) == 0 { 23 | color.Yellow("pass an image to detect") 24 | 25 | os.Exit(1) 26 | } 27 | 28 | img, err := imageutils.Load(flag.Args()[0], 
imageutils.LoadOpts{ 29 | TargetSize: []uint{224, 224}, 30 | }) 31 | handleErr(err) 32 | 33 | vgg16 := vgg.VGG16Builder(vgg.Opts{ 34 | PreTrained: true, 35 | Learnable: false, 36 | WithBias: true, 37 | }) 38 | 39 | classifier := imagenet.NewClassifier(vgg16, 224, 224) 40 | 41 | label, prob, err := classifier.Predict(img) 42 | if err != nil { 43 | // classifier.Model().WriteSVG("model.svg") 44 | handleErr(err) 45 | } 46 | 47 | log.Printf("%v: %.2f%%", label, prob*100) 48 | } 49 | -------------------------------------------------------------------------------- /imageutils/augmentation_test.go: -------------------------------------------------------------------------------- 1 | package imageutils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/fogleman/gg" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestAugmentation(t *testing.T) { 11 | testCases := []struct { 12 | desc string 13 | filters []Filter 14 | }{ 15 | { 16 | desc: "Example 1", 17 | filters: []Filter{ 18 | WithCrop(6, 2, 55, 100), 19 | Either( 20 | WithRandomRotation(0.8, -15, 15), 21 | WithRandomShear(0.8, -15, 15), 22 | ), 23 | WithRandomGaussianBlur(0.5, 0.2, 1.0), 24 | WithRandomErosion(0.3, 0.5, 1.0), 25 | }, 26 | }, 27 | } 28 | 29 | img, err := gg.LoadImage("zebra.png") 30 | if err != nil { 31 | panic(err) 32 | } 33 | 34 | for _, tC := range testCases { 35 | t.Run(tC.desc, func(t *testing.T) { 36 | c := require.New(t) 37 | 38 | a := NewAugmenter(tC.filters...) 
39 | result, err := a.ApplyN(img, 10) 40 | 41 | c.NoError(err) 42 | c.Len(result, 10) 43 | 44 | for i, r := range result { 45 | c.NotEqual(img, r) 46 | 47 | _ = i 48 | // _ = gg.SavePNG(fmt.Sprintf("%d.png", i), r) 49 | } 50 | }) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /storage/storage.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | // Storage is in charge of loading the weights from files 12 | type Storage struct { 13 | Cost float64 14 | Learnables map[string]Weight 15 | } 16 | 17 | // NewStorage instantiates a storage 18 | func NewStorage() *Storage { 19 | return &Storage{ 20 | Cost: 0.0, 21 | Learnables: map[string]Weight{}, 22 | } 23 | } 24 | 25 | // TensorByName returns the tensor associated to a weight name 26 | func (l *Storage) TensorByName(name string) (tensor.Tensor, error) { 27 | t, ok := l.Learnables[name] 28 | if !ok { 29 | return nil, ErrLearnableNotFound 30 | } 31 | 32 | return t.Value.(tensor.Tensor), nil 33 | } 34 | 35 | // LoadFile loads the weights stored in the file at the given path; only paths containing ".nn1" are supported. 36 | func (l *Storage) LoadFile(filePath string) error { 37 | if strings.Contains(filePath, ".nn1") { 38 | return LoadNN1(l, filePath) 39 | } 40 | 41 | return fmt.Errorf("extension %v is not supported yet", filePath) 42 | } 43 | 44 | // AddWeights adds weights to the storage 45 | func (l *Storage) AddWeights(weights ...Weight) { 46 | for _, w := range weights { 47 | if _, ok := l.Learnables[w.Name]; ok { 48 | log.Panicf("weight %s is already present in the storage", w.Name) 49 | } 50 | 51 | l.Learnables[w.Name] = w 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /activation.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 
5 | ) 6 | 7 | type ActivationModule struct { 8 | name string 9 | fn func(x *Node) (*Node, error) 10 | } 11 | 12 | func (m *ActivationModule) Name() string { 13 | return m.name 14 | } 15 | 16 | func (m *ActivationModule) Forward(inputs ...*Node) Nodes { 17 | x := inputs[0] 18 | y := gorgonia.Must(m.fn(x)) 19 | 20 | return Nodes{y} 21 | } 22 | 23 | type ActivationAxisModule struct { 24 | name string 25 | axis []int 26 | fn func(x *Node, axis ...int) (*Node, error) 27 | } 28 | 29 | func (m *ActivationAxisModule) Forward(inputs ...*Node) Nodes { 30 | x := inputs[0] 31 | y := gorgonia.Must(m.fn(x, m.axis...)) 32 | 33 | return Nodes{y} 34 | } 35 | 36 | func (m *ActivationAxisModule) Name() string { 37 | return m.name 38 | } 39 | 40 | func Sigmoid() Module { 41 | return &ActivationModule{ 42 | name: "Sigmoid", 43 | fn: gorgonia.Sigmoid, 44 | } 45 | } 46 | 47 | func Tanh() Module { 48 | return &ActivationModule{ 49 | name: "Tanh", 50 | fn: gorgonia.Tanh, 51 | } 52 | } 53 | 54 | func Rectify() Module { 55 | return &ActivationModule{ 56 | name: "Rectify", 57 | fn: gorgonia.Rectify, 58 | } 59 | } 60 | 61 | func SparseMax(axis ...int) Module { 62 | return &ActivationAxisModule{ 63 | name: "SparseMax", 64 | axis: axis, 65 | fn: gorgonia.Sparsemax, 66 | } 67 | } 68 | 69 | func SoftMax(axis ...int) Module { 70 | return &ActivationAxisModule{ 71 | name: "SoftMax", 72 | axis: axis, 73 | fn: gorgonia.SoftMax, 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoDL 2 | 3 | godl is a **Go** **D**eep **L**earning framework written on top of Gorgonia. 4 | godl is to Gorgonia what Keras is to TensorFlow. 5 | 6 | ## API Stability 7 | The API is not stable and can change at any moment. 8 | I'm writing this framework mostly to learn and so I don't provide any guarantees 9 | that it'll work for you. Use it at your own risk. 
10 | 11 | ## Roadmap 12 | 13 | The following items are in the current roadmap, some of them need to be implemented in Gorgonia first. 14 | 15 | - [x] Data loader 16 | - [x] Base storage (save/load) 17 | - [ ] CLI to scaffold a project 18 | - [x] Embeddings 19 | - [x] Dense/Linear/FC 20 | 21 | ### Losses 22 | - [x] Cross Entropy 23 | - [x] MSE 24 | - [ ] BCE 25 | - [ ] BinaryXent 26 | - [ ] CTC Losses 27 | 28 | ### Pooling 29 | - [x] MaxPool 30 | - [x] AvgPool 31 | - [x] GlobalMaxPool 32 | - [x] GlobalAvgPool 33 | 34 | ### Normalization 35 | - [x] Batch Norm 36 | - [x] Ghost Batch Norm 37 | - [ ] GroupNorm 38 | - [ ] LayerNorm 39 | 40 | ### Recurrent Layers 41 | - [x] LSTM 42 | - [x] Bidirectional 43 | - [ ] GRU 44 | - [ ] ConvLSTM2D 45 | 46 | ### Reshaping 47 | - [ ] ZeroPadding 48 | - [ ] UpSampling 49 | 50 | ### Convolutional 51 | - [x] Conv2D 52 | - [ ] DepthWiseConv2D 53 | 54 | ### Applications 55 | - [x] TabNet 56 | - [x] VGG16 57 | - [ ] VGGFace2 (in progress) 58 | - [ ] VGG19 59 | - [ ] ResNet50 60 | - [ ] ResNet101 61 | - [ ] YOLO 62 | - [ ] BERT 63 | 64 | ### Future 65 | - [ ] Support ONNX 66 | - [ ] Support hdf5 files 67 | -------------------------------------------------------------------------------- /storage/base.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "errors" 7 | 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | var ( 13 | ErrLearnableNotFound = errors.New("learnable not found") 14 | ) 15 | 16 | type Weight struct { 17 | Name string 18 | Value gorgonia.Value 19 | } 20 | 21 | // GobEncode implements the gob.GobEncoder interface 22 | func (w *Weight) GobEncode() ([]byte, error) { 23 | buf := new(bytes.Buffer) 24 | encoder := gob.NewEncoder(buf) 25 | 26 | err := encoder.Encode(&w.Name) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | err = encoder.Encode(&w.Value) 32 | if err != nil { 33 | return nil, err 34 
| } 35 | 36 | return buf.Bytes(), err 37 | } 38 | 39 | // GobDecode implements the gob.GobDecoder interface 40 | func (w *Weight) GobDecode(buf []byte) error { 41 | reader := bytes.NewBuffer(buf) 42 | decoder := gob.NewDecoder(reader) 43 | 44 | err := decoder.Decode(&w.Name) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | err = decoder.Decode(&w.Value) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | return nil 55 | } 56 | 57 | type Item struct { 58 | Cost float64 59 | Weights []Weight 60 | } 61 | 62 | // NodesToItem converts a list of nodes to an storage.Item 63 | func NodesToItem(nodes ...*gorgonia.Node) Item { 64 | item := Item{} 65 | item.Weights = make([]Weight, len(nodes)) 66 | 67 | for i, n := range nodes { 68 | item.Weights[i].Name = n.Name() 69 | item.Weights[i].Value = n.Value().(tensor.Tensor) 70 | } 71 | 72 | return item 73 | } 74 | -------------------------------------------------------------------------------- /embedding.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | type EmbeddingOpts struct { 9 | WeightsInit gorgonia.InitWFn 10 | } 11 | 12 | type EmbeddingModule struct { 13 | model *Model 14 | layer LayerType 15 | opts EmbeddingOpts 16 | embeddingSize, embeddingDim int 17 | 18 | weight *Node 19 | } 20 | 21 | func (m *EmbeddingModule) Forward(inputs ...*Node) Nodes { 22 | err := m.model.CheckArity(m.layer, inputs, 1) 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | indices := inputs[0] 28 | indicesShape := indices.Shape().Clone() 29 | indices = gorgonia.Must(gorgonia.Reshape(indices, tensor.Shape{indicesShape.TotalSize()})) 30 | 31 | embedding, err := gorgonia.ByIndices(m.weight, indices, 0) 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | embedding = gorgonia.Must(gorgonia.Reshape(embedding, append(indicesShape, m.embeddingDim))) 37 | 38 | return Nodes{embedding} 39 | } 40 | 41 | // 
Embedding implements a embedding layer 42 | func Embedding(m *Model, embeddingSize int, embeddingDim int, opts EmbeddingOpts) *EmbeddingModule { 43 | lt := AddLayer("Embedding") 44 | 45 | if opts.WeightsInit == nil { 46 | opts.WeightsInit = gorgonia.Gaussian(0.0, 1.0) 47 | } 48 | 49 | w := m.AddWeights(lt, tensor.Shape{embeddingSize, embeddingDim}, NewWeightsOpts{ 50 | InitFN: opts.WeightsInit, 51 | }) 52 | 53 | return &EmbeddingModule{ 54 | model: m, 55 | layer: lt, 56 | opts: opts, 57 | embeddingSize: embeddingSize, 58 | embeddingDim: embeddingDim, 59 | weight: w, 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /imagenet/classifier.go: -------------------------------------------------------------------------------- 1 | package imagenet 2 | 3 | import ( 4 | "image" 5 | "sync" 6 | 7 | "github.com/dcu/godl" 8 | "github.com/dcu/godl/imageutils" 9 | "gorgonia.org/gorgonia" 10 | "gorgonia.org/tensor" 11 | ) 12 | 13 | // Classifier is an imagenet classifier 14 | type Classifier struct { 15 | m *godl.Model 16 | x *gorgonia.Node 17 | 18 | output gorgonia.Value 19 | mutex sync.Mutex 20 | } 21 | 22 | func NewClassifier(builder func(m *godl.Model) godl.Module, width, height int) *Classifier { 23 | m := godl.NewModel() 24 | module := builder(m) 25 | x := gorgonia.NewTensor(m.TrainGraph(), tensor.Float32, 4, gorgonia.WithShape(1, 3, width, height), gorgonia.WithName("x")) 26 | 27 | result := module.Forward(x) 28 | c := &Classifier{ 29 | m: m, 30 | x: x, 31 | } 32 | 33 | gorgonia.Read(result[0], &c.output) 34 | 35 | return c 36 | } 37 | 38 | func (c *Classifier) Model() *godl.Model { 39 | return c.m 40 | } 41 | 42 | func (c *Classifier) Predict(img image.Image) (string, float32, error) { 43 | c.mutex.Lock() 44 | defer c.mutex.Unlock() 45 | 46 | input := imageutils.ToTensor(img, imageutils.ToTensorOpts{}) 47 | 48 | err := gorgonia.Let(c.x, input) 49 | if err != nil { 50 | return "", 0.0, err 51 | } 52 | 53 | err = c.m.Run() 54 
| if err != nil { 55 | return "", 0.0, err 56 | } 57 | 58 | outputTensor := c.output.(tensor.Tensor) 59 | max, err := tensor.Argmax(outputTensor, 1) 60 | if err != nil { 61 | return "", 0.0, err 62 | } 63 | 64 | index := max.Data().([]int)[0] 65 | 66 | val, err := outputTensor.At(0, index) 67 | if err != nil { 68 | return "", 0.0, err 69 | } 70 | 71 | return Labels[index], val.(float32), nil 72 | } 73 | -------------------------------------------------------------------------------- /avg_pool.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | type AvgPool2DOpts struct { 9 | Kernel tensor.Shape 10 | Padding []int 11 | Stride []int 12 | } 13 | // setDefaults fills in a zero padding and a stride equal to the kernel when they are not set. 14 | func (opts *AvgPool2DOpts) setDefaults() { 15 | if opts.Padding == nil { 16 | opts.Padding = []int{0, 0} 17 | } 18 | 19 | if opts.Stride == nil { 20 | opts.Stride = []int(opts.Kernel) 21 | } 22 | } 23 | 24 | type GlobalAvgPool2DModule struct { 25 | model *Model 26 | layer LayerType 27 | } 28 | 29 | func (m *GlobalAvgPool2DModule) Name() string { 30 | return "GlobalAvgPool2d" 31 | } 32 | 33 | func (m *GlobalAvgPool2DModule) Forward(inputs ...*Node) Nodes { 34 | err := m.model.CheckArity(m.layer, inputs, 1) 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | x := inputs[0] 40 | x = gorgonia.Must(gorgonia.GlobalAveragePool2D(x)) 41 | 42 | return Nodes{x} 43 | } 44 | 45 | // GlobalAvgPool2D applies the global average pool operation to the given image 46 | func GlobalAvgPool2D(nn *Model) *GlobalAvgPool2DModule { 47 | lt := AddLayer("GlobalAvgPool2D") 48 | 49 | return &GlobalAvgPool2DModule{ 50 | model: nn, 51 | layer: lt, 52 | } 53 | } 54 | 55 | type AvgPool2DModule struct { 56 | model *Model 57 | opts AvgPool2DOpts 58 | layer LayerType 59 | } 60 | 61 | func (m *AvgPool2DModule) Name() string { 62 | return "AvgPool2d" 63 | } 64 | 65 | func (m *AvgPool2DModule) Forward(inputs ...*Node) Nodes 
{ 66 | err := m.model.CheckArity(m.layer, inputs, 1) 67 | if err != nil { 68 | panic(err) 69 | } 70 | 71 | x := inputs[0] 72 | x = gorgonia.Must(gorgonia.AveragePool2D(x, m.opts.Kernel, m.opts.Padding, m.opts.Stride)) 73 | 74 | return Nodes{x} 75 | } 76 | 77 | // AvgPool2D applies the average pool operation to the given image 78 | func AvgPool2D(nn *Model, opts AvgPool2DOpts) *AvgPool2DModule { 79 | lt := AddLayer("AvgPool2D") 80 | 81 | opts.setDefaults() 82 | 83 | return &AvgPool2DModule{ 84 | model: nn, 85 | opts: opts, 86 | layer: lt, 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /losses.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | ) 6 | 7 | var ( 8 | reductionMap = map[Reduction]func(*gorgonia.Node, ...int) (*gorgonia.Node, error){ 9 | ReductionNone: noopReduction, 10 | ReductionMean: gorgonia.Mean, 11 | ReductionSum: gorgonia.Sum, 12 | } 13 | ) 14 | 15 | type Reduction string 16 | 17 | func (r Reduction) Func() func(*gorgonia.Node, ...int) (*gorgonia.Node, error) { 18 | fn, ok := reductionMap[r] 19 | if !ok { 20 | return gorgonia.Mean 21 | } 22 | 23 | return fn 24 | } 25 | 26 | const ( 27 | ReductionNone Reduction = "none" 28 | ReductionSum Reduction = "sum" 29 | ReductionMean Reduction = "mean" 30 | ) 31 | 32 | type MSELossOpts struct { 33 | Reduction Reduction 34 | } 35 | 36 | type CostFn func(output Nodes, target *Node) *Node 37 | 38 | // MSELoss defines the squared error cost function; the squared differences are reduced with opts.Reduction (mean when unset) 39 | func MSELoss(opts MSELossOpts) CostFn { 40 | return func(output Nodes, target *gorgonia.Node) *gorgonia.Node { 41 | sub := gorgonia.Must(gorgonia.Sub(output[0], target)) 42 | 43 | return gorgonia.Must(opts.Reduction.Func()(gorgonia.Must(gorgonia.Square(sub)))) 44 | } 45 | } 46 | 47 | type CrossEntropyLossOpt struct { 48 | Reduction Reduction 49 | } 50 | 51 | // CrossEntropyLoss implements cross entropy loss 
function 52 | func CrossEntropyLoss(opts CrossEntropyLossOpt) CostFn { 53 | return func(output Nodes, target *Node) *gorgonia.Node { 54 | cost := gorgonia.Must(gorgonia.HadamardProd(gorgonia.Must(gorgonia.Neg(gorgonia.Must(gorgonia.Log(output[0])))), target)) 55 | 56 | return gorgonia.Must(opts.Reduction.Func()(cost)) 57 | } 58 | } 59 | 60 | // CategoricalCrossEntropyLoss is softmax + cce 61 | func CategoricalCrossEntropyLoss(opts CrossEntropyLossOpt) CostFn { 62 | return func(output Nodes, target *Node) *Node { 63 | cost := gorgonia.Must(gorgonia.SoftMax(output[0])) 64 | 65 | return CrossEntropyLoss(opts)(Nodes{cost}, target) 66 | } 67 | } 68 | 69 | func noopReduction(n *gorgonia.Node, along ...int) (*gorgonia.Node, error) { 70 | return n, nil 71 | } 72 | -------------------------------------------------------------------------------- /storage/nn1.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/gob" 6 | "fmt" 7 | "os" 8 | ) 9 | 10 | // LoadNN1 opens the given file path 11 | func LoadNN1(loader *Storage, filePath string) error { 12 | f, err := os.Open(filePath) 13 | if err != nil { 14 | return err 15 | } 16 | 17 | defer func() { 18 | _ = f.Close() 19 | }() 20 | 21 | ungzipper, err := gzip.NewReader(f) 22 | if err != nil { 23 | return err 24 | } 25 | 26 | defer func() { 27 | _ = ungzipper.Close() 28 | }() 29 | 30 | dec := gob.NewDecoder(ungzipper) 31 | 32 | version := 0 33 | weightsCount := 0 34 | 35 | if err = dec.Decode(&version); err != nil { 36 | return err 37 | } 38 | 39 | if err = dec.Decode(&loader.Cost); err != nil { 40 | return err 41 | } 42 | 43 | if err = dec.Decode(&weightsCount); err != nil { 44 | return err 45 | } 46 | 47 | for i := 0; i < weightsCount; i++ { 48 | weight := Weight{} 49 | if err = dec.Decode(&weight); err != nil { 50 | return err 51 | } 52 | 53 | loader.AddWeights(weight) 54 | } 55 | 56 | return nil 57 | } 58 | 59 | // SaveNN1 
saves the model in the given path 60 | func SaveNN1(path string, item Item) error { 61 | f, err := os.Create(path) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | defer func() { 67 | _ = f.Close() 68 | }() 69 | 70 | gzipper := gzip.NewWriter(f) 71 | defer func() { 72 | _ = gzipper.Close() 73 | }() 74 | 75 | enc := gob.NewEncoder(gzipper) 76 | version := 0 77 | 78 | if err = enc.Encode(version); err != nil { 79 | return fmt.Errorf("encoding version %d: %w", version, err) 80 | } 81 | 82 | if err = enc.Encode(item.Cost); err != nil { 83 | return fmt.Errorf("encoding cost %v: %w", item.Cost, err) 84 | } 85 | 86 | if err = enc.Encode(len(item.Weights)); err != nil { 87 | return fmt.Errorf("encoding weights count: %w", err) 88 | } 89 | 90 | for _, w := range item.Weights { 91 | if err = enc.Encode(&w); err != nil { 92 | return fmt.Errorf("encoding learnable %s: %w", w.Name, err) 93 | } 94 | } 95 | 96 | return nil 97 | } 98 | -------------------------------------------------------------------------------- /max_pool.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "gorgonia.org/gorgonia" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | type MaxPool2DOpts struct { 9 | Kernel tensor.Shape 10 | Padding []int 11 | Stride []int 12 | } 13 | 14 | func (opts *MaxPool2DOpts) setDefaults() { 15 | if opts.Padding == nil { 16 | opts.Padding = []int{0, 0} 17 | } 18 | 19 | if opts.Stride == nil { 20 | opts.Stride = []int(opts.Kernel) 21 | } 22 | } 23 | 24 | type GlobalMaxPool2DModule struct { 25 | model *Model 26 | layer LayerType 27 | } 28 | 29 | func (m *GlobalMaxPool2DModule) Name() string { 30 | return "GlobalMaxPool2d" 31 | } 32 | 33 | func (m *GlobalMaxPool2DModule) Forward(inputs ...*Node) Nodes { 34 | err := m.model.CheckArity(m.layer, inputs, 1) 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | x := inputs[0] 40 | x = gorgonia.Must(gorgonia.GlobalAveragePool2D(x)) 41 | 42 | return Nodes{x} 43 | } 44 | 45 
| // GlobalMaxPool2D applies the global average pool operation to the given image 46 | func GlobalMaxPool2D(nn *Model) *GlobalMaxPool2DModule { 47 | lt := AddLayer("GlobalMaxPool2d") 48 | 49 | return &GlobalMaxPool2DModule{ 50 | model: nn, 51 | layer: lt, 52 | } 53 | } 54 | 55 | type MaxPool2DModule struct { 56 | model *Model 57 | opts MaxPool2DOpts 58 | layer LayerType 59 | } 60 | 61 | func (m *MaxPool2DModule) Name() string { 62 | return "MaxPool2d" 63 | } 64 | 65 | func (m *MaxPool2DModule) Forward(inputs ...*Node) Nodes { 66 | err := m.model.CheckArity(m.layer, inputs, 1) 67 | if err != nil { 68 | panic(err) 69 | } 70 | 71 | x := inputs[0] 72 | x = gorgonia.Must(gorgonia.MaxPool2D(x, m.opts.Kernel, m.opts.Padding, m.opts.Stride)) 73 | 74 | return Nodes{x} 75 | } 76 | 77 | // MaxPool2D applies the average pool operation to the given image 78 | func MaxPool2D(nn *Model, opts MaxPool2DOpts) *MaxPool2DModule { 79 | lt := AddLayer("MaxPool2D") 80 | 81 | opts.setDefaults() 82 | 83 | return &MaxPool2DModule{ 84 | model: nn, 85 | opts: opts, 86 | layer: lt, 87 | } 88 | } 89 | 90 | var ( 91 | _ Module = &MaxPool2DModule{} 92 | _ Module = &GlobalMaxPool2DModule{} 93 | ) 94 | -------------------------------------------------------------------------------- /tabnet/glu_block_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/dcu/godl" 8 | "github.com/stretchr/testify/require" 9 | "gorgonia.org/gorgonia" 10 | "gorgonia.org/tensor" 11 | ) 12 | 13 | func TestGLUBlock(t *testing.T) { 14 | testCases := []struct { 15 | X tensor.Tensor 16 | BlockSize int 17 | Shared int 18 | Output int 19 | VBS int 20 | }{ 21 | { 22 | X: tensor.New( 23 | tensor.WithShape(6, 2), 24 | tensor.WithBacking([]float32{0.1, -0.5, 0.3, 0.9, 0.04, -0.3, 0.01, 0.09, -0.1, 0.9, 0.7, 0.04}), 25 | ), 26 | VBS: 3, 27 | Shared: 5, 28 | Output: 5, 29 | BlockSize: 2, 30 | }, 31 | } 32 | for i, tC 
:= range testCases { 33 | t.Run(fmt.Sprintf("#%d", i+1), func(t *testing.T) { 34 | c := require.New(t) 35 | nn := godl.NewModel() 36 | 37 | input := gorgonia.NewTensor(nn.TrainGraph(), tensor.Float32, tC.X.Dims(), gorgonia.WithShape(tC.X.Shape()...), gorgonia.WithName("x"), gorgonia.WithValue(tC.X)) 38 | 39 | shared := make([]*godl.LinearModule, tC.Shared) 40 | fcInput := input.Shape()[1] 41 | fcOutput := 2 * tC.Output 42 | for i := 0; i < tC.Shared; i++ { 43 | shared[i] = godl.Linear(nn, godl.LinearOpts{ 44 | OutputDimension: fcOutput, // double the size so we can take half and half 45 | WeightsInit: gorgonia.RangedFromWithStep(-0.1, 0.01), 46 | InputDimension: fcInput, 47 | }) 48 | 49 | fcInput = tC.Output 50 | } 51 | 52 | result := GLUBlock(nn, GLUBlockOpts{ 53 | InputDimension: tC.X.Shape()[1], 54 | OutputDimension: tC.Output, 55 | Shared: shared, 56 | VirtualBatchSize: tC.VBS, 57 | Size: tC.BlockSize, 58 | }).Forward(input) 59 | 60 | y := result[0] 61 | cost := gorgonia.Must(gorgonia.Mean(y)) 62 | _, err := gorgonia.Grad(cost, input) 63 | c.NoError(err) 64 | 65 | vm := gorgonia.NewTapeMachine(nn.TrainGraph(), gorgonia.BindDualValues(nn.Learnables()...)) 66 | c.NoError(vm.RunAll()) 67 | 68 | nn.PrintWatchables() 69 | 70 | t.Logf("y: %v", y.Value()) 71 | t.Logf("dx: %v", input.Deriv().Value()) 72 | }) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /embedding_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestEmbedding(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | classes int 15 | dim int 16 | input []int 17 | inputShape tensor.Shape 18 | expectedOutput []float32 19 | }{ 20 | { 21 | desc: "Example 1", 22 | input: []int{1, 2}, 23 | inputShape: tensor.Shape{2}, 24 | classes: 4, 25 | 
dim: 2, 26 | expectedOutput: []float32{2, 3, 4, 5}, 27 | }, 28 | { 29 | desc: "Example 2", 30 | input: []int{1, 2, 2, 0}, 31 | inputShape: tensor.Shape{4}, 32 | classes: 4, 33 | dim: 2, 34 | expectedOutput: []float32{2, 3, 4, 5, 4, 5, 0, 1}, 35 | }, 36 | { 37 | desc: "Example 3", 38 | input: []int{2}, 39 | inputShape: tensor.Shape{1, 1, 1}, 40 | classes: 4, 41 | dim: 2, 42 | expectedOutput: []float32{4, 5}, 43 | }, 44 | { 45 | desc: "Example 4", 46 | input: []int{0, 3, 2, 1}, 47 | inputShape: tensor.Shape{2, 1, 2}, 48 | classes: 4, 49 | dim: 2, 50 | expectedOutput: []float32{0, 1, 6, 7, 4, 5, 2, 3}, 51 | }, 52 | } 53 | 54 | for _, tcase := range testCases { 55 | t.Run(tcase.desc, func(t *testing.T) { 56 | c := require.New(t) 57 | 58 | tn := NewModel() 59 | emb := Embedding(tn, tcase.classes, tcase.dim, EmbeddingOpts{ 60 | WeightsInit: gorgonia.RangedFrom(0), 61 | }) 62 | 63 | ts := tensor.New(tensor.WithShape(tcase.inputShape...), tensor.WithBacking(tcase.input)) 64 | 65 | selector := gorgonia.NewTensor(tn.trainGraph, tensor.Int, tcase.inputShape.Dims(), gorgonia.WithShape(ts.Shape()...), gorgonia.WithValue(ts), gorgonia.WithName("selector")) 66 | output := emb.Forward(selector)[0] 67 | 68 | vm := gorgonia.NewTapeMachine(tn.trainGraph, gorgonia.BindDualValues(tn.learnables...)) 69 | c.NoError(vm.RunAll()) 70 | 71 | c.Equal(append(tcase.inputShape, tcase.dim), output.Shape()) 72 | c.Equal(tcase.expectedOutput, output.Value().Data()) 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /embedding_generator.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "sort" 5 | 6 | "gorgonia.org/gorgonia" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | type EmbeddingGeneratorModule struct { 11 | model *Model 12 | opts EmbeddingOpts 13 | skipEmbedding bool 14 | 15 | categoricalColumnIndexes []bool 16 | embeddings []*EmbeddingModule 17 | } 18 | 19 | func (m 
*EmbeddingGeneratorModule) Forward(inputs ...*Node) Nodes { 20 | err := m.model.CheckArity("EmbeddingGenerator", inputs, 1) 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | x := inputs[0] 26 | if m.skipEmbedding { 27 | return Nodes{x} 28 | } 29 | 30 | cols := make([]*gorgonia.Node, len(m.categoricalColumnIndexes)) 31 | catFeatCounter := 0 32 | 33 | for featInitIdx, isCategorical := range m.categoricalColumnIndexes { 34 | s := gorgonia.Must(gorgonia.Slice(x, nil, gorgonia.S(featInitIdx))) 35 | 36 | if isCategorical { 37 | s := gorgonia.Must(gorgonia.ConvType(s, tensor.Float32, tensor.Int)) 38 | result := m.embeddings[catFeatCounter].Forward(s) 39 | 40 | cols[featInitIdx] = result[0] 41 | 42 | catFeatCounter++ 43 | } else { 44 | cols[featInitIdx] = gorgonia.Must(gorgonia.Reshape(s, tensor.Shape{s.Shape().TotalSize(), 1})) 45 | } 46 | } 47 | 48 | output := gorgonia.Must(gorgonia.Concat(1, cols...)) 49 | 50 | return Nodes{output} 51 | } 52 | 53 | func EmbeddingGenerator(m *Model, inputDims int, catDims []int, catIdxs []int, catEmbDim []int, opts EmbeddingOpts) *EmbeddingGeneratorModule { 54 | skipEmbedding := false 55 | if len(catDims) == 0 || len(catIdxs) == 0 { 56 | skipEmbedding = true 57 | } 58 | 59 | sort.Slice(catIdxs, func(i, j int) bool { 60 | return catIdxs[i] < catIdxs[j] 61 | }) 62 | 63 | embeddings := make([]*EmbeddingModule, len(catIdxs)) 64 | categoricalColumnIndexes := make([]bool, inputDims) 65 | 66 | for i, v := range catIdxs { 67 | embeddings[i] = Embedding( 68 | m, 69 | catDims[i], 70 | catEmbDim[i], 71 | opts, 72 | ) 73 | 74 | categoricalColumnIndexes[v] = true 75 | } 76 | 77 | return &EmbeddingGeneratorModule{ 78 | model: m, 79 | opts: opts, 80 | skipEmbedding: skipEmbedding, 81 | categoricalColumnIndexes: categoricalColumnIndexes, 82 | embeddings: embeddings, 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /data_loader_test.go: 
-------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestDataLoader(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | x, y tensor.Tensor 15 | opts DataLoaderOpts 16 | loops int 17 | }{ 18 | { 19 | desc: "Example 1", 20 | x: tensor.New( 21 | tensor.WithShape(6), 22 | tensor.WithBacking( 23 | tensor.Range(tensor.Float32, 0, 6), 24 | ), 25 | ), 26 | y: tensor.New( 27 | tensor.WithShape(6), 28 | tensor.WithBacking( 29 | tensor.Range(tensor.Float32, 0, 6), 30 | ), 31 | ), 32 | opts: DataLoaderOpts{ 33 | Shuffle: true, 34 | BatchSize: 2, 35 | Drop: false, 36 | }, 37 | loops: 2, 38 | }, 39 | { 40 | desc: "Example 2", 41 | x: tensor.New( 42 | tensor.WithShape(6), 43 | tensor.WithBacking( 44 | tensor.Range(tensor.Float32, 0, 6), 45 | ), 46 | ), 47 | y: tensor.New( 48 | tensor.WithShape(6), 49 | tensor.WithBacking( 50 | tensor.Range(tensor.Float32, 0, 6), 51 | ), 52 | ), 53 | opts: DataLoaderOpts{ 54 | Shuffle: false, 55 | BatchSize: 4, 56 | Drop: false, 57 | }, 58 | loops: 2, 59 | }, 60 | } 61 | for _, tC := range testCases { 62 | t.Run(tC.desc, func(t *testing.T) { 63 | c := require.New(t) 64 | 65 | log.Printf("%v %v", tC.x, tC.y) 66 | 67 | dl := NewDataLoader(tC.x, tC.y, tC.opts) 68 | count := 0 69 | 70 | for i := 0; i < tC.loops; i++ { 71 | log.Printf("Loop %d", i) 72 | 73 | for dl.HasNext() { 74 | xVal, yVal := dl.Next() 75 | 76 | log.Printf("%v %v", xVal, yVal) 77 | 78 | for b := 0; b < tC.opts.BatchSize; b++ { 79 | xx, err := xVal.At(b) 80 | c.NoError(err) 81 | yy, err := yVal.At(b) 82 | c.NoError(err) 83 | 84 | c.Equal(xx, yy) 85 | } 86 | 87 | count++ 88 | } 89 | 90 | dl.Reset() 91 | } 92 | 93 | if !dl.opts.Drop && tC.x.Shape()[0]%dl.opts.BatchSize > 0 { 94 | n := tC.x.Shape()[0] % dl.opts.BatchSize 95 | c.Equal(count, tC.loops*(tC.x.Shape()[0]+n)/tC.opts.BatchSize) 96 | } 
else { 97 | c.Equal(count, tC.loops*tC.x.Shape()[0]/tC.opts.BatchSize) 98 | } 99 | 100 | log.Printf("count: %v", count) 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /table/fixtures/census.csv: -------------------------------------------------------------------------------- 1 | 39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K 2 | 50, Self-emp-not-inc, 83311, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 13, United-States, <=50K 3 | 38, Private, 215646, HS-grad, 9, Divorced, Handlers-cleaners, Not-in-family, White, Male, 0, 0, 40, United-States, <=50K 4 | 53, Private, 234721, 11th, 7, Married-civ-spouse, Handlers-cleaners, Husband, Black, Male, 0, 0, 40, United-States, <=50K 5 | 28, Private, 338409, Bachelors, 13, Married-civ-spouse, Prof-specialty, Wife, Black, Female, 0, 0, 40, Cuba, <=50K 6 | 37, Private, 284582, Masters, 14, Married-civ-spouse, Exec-managerial, Wife, White, Female, 0, 0, 40, United-States, <=50K 7 | 49, Private, 160187, 9th, 5, Married-spouse-absent, Other-service, Not-in-family, Black, Female, 0, 0, 16, Jamaica, <=50K 8 | 52, Self-emp-not-inc, 209642, HS-grad, 9, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 45, United-States, >50K 9 | 31, Private, 45781, Masters, 14, Never-married, Prof-specialty, Not-in-family, White, Female, 14084, 0, 50, United-States, >50K 10 | 42, Private, 159449, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 5178, 0, 40, United-States, >50K 11 | 37, Private, 280464, Some-college, 10, Married-civ-spouse, Exec-managerial, Husband, Black, Male, 0, 0, 80, United-States, >50K 12 | 30, State-gov, 141297, Bachelors, 13, Married-civ-spouse, Prof-specialty, Husband, Asian-Pac-Islander, Male, 0, 0, 40, India, >50K 13 | 23, Private, 122272, Bachelors, 13, Never-married, Adm-clerical, Own-child, White, Female, 
0, 0, 30, United-States, <=50K 14 | 32, Private, 205019, Assoc-acdm, 12, Never-married, Sales, Not-in-family, Black, Male, 0, 0, 50, United-States, <=50K 15 | 40, Private, 121772, Assoc-voc, 11, Married-civ-spouse, Craft-repair, Husband, Asian-Pac-Islander, Male, 0, 0, 40, ?, >50K 16 | 34, Private, 245487, 7th-8th, 4, Married-civ-spouse, Transport-moving, Husband, Amer-Indian-Eskimo, Male, 0, 0, 45, Mexico, <=50K 17 | 25, Self-emp-not-inc, 176756, HS-grad, 9, Never-married, Farming-fishing, Own-child, White, Male, 0, 0, 35, United-States, <=50K 18 | 32, Private, 186824, HS-grad, 9, Never-married, Machine-op-inspct, Unmarried, White, Male, 0, 0, 40, United-States, <=50K 19 | 38, Private, 28887, 11th, 7, Married-civ-spouse, Sales, Husband, White, Male, 0, 0, 50, United-States, <=50K 20 | 43, Self-emp-not-inc, 292175, Masters, 14, Divorced, Exec-managerial, Unmarried, White, Female, 0, 0, 45, United-States, >50K -------------------------------------------------------------------------------- /tabnet/feature_transformer.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "github.com/dcu/godl" 5 | "gorgonia.org/gorgonia" 6 | ) 7 | 8 | // FeatureTransformerOpts contains options for feature transformer layer 9 | type FeatureTransformerOpts struct { 10 | Shared []*godl.LinearModule 11 | VirtualBatchSize int 12 | IndependentBlocks int 13 | InputDimension int 14 | OutputDimension int 15 | WithBias bool 16 | Momentum float64 17 | 18 | WeightsInit gorgonia.InitWFn 19 | } 20 | 21 | func (o *FeatureTransformerOpts) setDefaults() { 22 | if o.InputDimension == 0 { 23 | panic("input dimension can't be nil") 24 | } 25 | 26 | if o.OutputDimension == 0 { 27 | panic("output dimension can't be nil") 28 | } 29 | 30 | if o.Momentum == 0 { 31 | o.Momentum = 0.01 32 | } 33 | } 34 | 35 | type FeatureTransformerModule struct { 36 | model *godl.Model 37 | layer godl.LayerType 38 | opts FeatureTransformerOpts 39 | shared 
*GLUBlockModule 40 | independent *GLUBlockModule 41 | } 42 | 43 | func (m *FeatureTransformerModule) Forward(inputs ...*godl.Node) godl.Nodes { 44 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 45 | panic(err) 46 | } 47 | 48 | x := inputs[0] 49 | res := m.shared.Forward(x) 50 | 51 | return m.independent.Forward(res[0]) 52 | } 53 | 54 | // FeatureTransformer implements a feature transformer layer 55 | func FeatureTransformer(nn *godl.Model, opts FeatureTransformerOpts) *FeatureTransformerModule { 56 | lt := godl.AddLayer("FeatureTransformer") 57 | 58 | opts.setDefaults() 59 | 60 | shared := GLUBlock(nn, GLUBlockOpts{ 61 | InputDimension: opts.InputDimension, 62 | OutputDimension: opts.OutputDimension, 63 | VirtualBatchSize: opts.VirtualBatchSize, 64 | Size: len(opts.Shared), 65 | Shared: opts.Shared, 66 | WithBias: opts.WithBias, 67 | Momentum: opts.Momentum, 68 | WeightsInit: opts.WeightsInit, 69 | }) 70 | 71 | independent := GLUBlock(nn, GLUBlockOpts{ 72 | InputDimension: opts.InputDimension, 73 | OutputDimension: opts.OutputDimension, 74 | VirtualBatchSize: opts.VirtualBatchSize, 75 | Size: opts.IndependentBlocks, 76 | Shared: nil, 77 | WithBias: opts.WithBias, 78 | Momentum: opts.Momentum, 79 | WeightsInit: opts.WeightsInit, 80 | }) 81 | 82 | return &FeatureTransformerModule{ 83 | model: nn, 84 | layer: lt, 85 | opts: opts, 86 | shared: shared, 87 | independent: independent, 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /tabnet/tab_net.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "github.com/dcu/godl" 5 | "github.com/dcu/godl/activation" 6 | "gorgonia.org/gorgonia" 7 | ) 8 | 9 | type TabNetOpts struct { 10 | OutputSize int 11 | InputSize int 12 | BatchSize int 13 | 14 | SharedBlocks int 15 | IndependentBlocks int 16 | DecisionSteps int 17 | PredictionLayerDim int 18 | AttentionLayerDim int 19 | 20 | MaskFunction 
activation.Function 21 | 22 | WithBias bool 23 | 24 | Gamma float64 25 | Momentum float64 26 | Epsilon float64 27 | VirtualBatchSize int 28 | WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn 29 | 30 | CatDims []int 31 | CatIdxs []int 32 | CatEmbDim []int 33 | } 34 | 35 | type TabNetModule struct { 36 | model *godl.Model 37 | embedder *godl.EmbeddingGeneratorModule 38 | tabnet *TabNetNoEmbeddingsModule 39 | } 40 | 41 | func (m *TabNetModule) Name() string { 42 | return "TabNet" 43 | } 44 | 45 | func (m *TabNetModule) Forward(inputs ...*godl.Node) godl.Nodes { 46 | x := inputs[0] 47 | res := m.embedder.Forward(x) 48 | 49 | return m.tabnet.Forward(res[0]) 50 | } 51 | 52 | func TabNet(nn *godl.Model, opts TabNetOpts) *TabNetModule { 53 | embedder := godl.EmbeddingGenerator(nn, opts.InputSize, opts.CatDims, opts.CatIdxs, opts.CatEmbDim, godl.EmbeddingOpts{ 54 | WeightsInit: opts.WeightsInit, 55 | }) 56 | 57 | embedDimSum := 0 58 | for _, v := range opts.CatEmbDim { 59 | embedDimSum += v 60 | } 61 | 62 | tabNetInputDim := opts.InputSize + embedDimSum - len(opts.CatEmbDim) 63 | tn := TabNetNoEmbeddings(nn, TabNetNoEmbeddingsOpts{ 64 | InputSize: tabNetInputDim, 65 | OutputSize: opts.OutputSize, 66 | BatchSize: opts.BatchSize, 67 | VirtualBatchSize: opts.VirtualBatchSize, 68 | MaskFunction: opts.MaskFunction, 69 | WithBias: opts.WithBias, 70 | WeightsInit: opts.WeightsInit, 71 | ScaleInit: opts.ScaleInit, 72 | BiasInit: opts.BiasInit, 73 | SharedBlocks: opts.SharedBlocks, 74 | IndependentBlocks: opts.IndependentBlocks, 75 | DecisionSteps: opts.DecisionSteps, 76 | PredictionLayerDim: opts.PredictionLayerDim, 77 | AttentionLayerDim: opts.AttentionLayerDim, 78 | Gamma: opts.Gamma, 79 | Momentum: opts.Momentum, 80 | Epsilon: opts.Epsilon, 81 | }) 82 | 83 | return &TabNetModule{ 84 | model: nn, 85 | tabnet: tn, 86 | embedder: embedder, 87 | } 88 | } 89 | 90 | var ( 91 | _ godl.Module = &TabNetModule{} 92 | ) 93 | 
-------------------------------------------------------------------------------- /tabnet/glu_block.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/chewxy/math32" 7 | "github.com/dcu/godl" 8 | "gorgonia.org/gorgonia" 9 | ) 10 | 11 | type GLUBlockOpts struct { 12 | InputDimension int 13 | OutputDimension int 14 | Shared []*godl.LinearModule 15 | VirtualBatchSize int 16 | 17 | Size int 18 | 19 | WithBias bool 20 | Momentum float64 21 | WeightsInit gorgonia.InitWFn 22 | } 23 | 24 | type GLUBlockModule struct { 25 | model *godl.Model 26 | layer godl.LayerType 27 | opts GLUBlockOpts 28 | 29 | gluLayers []*godl.GLUModule 30 | scale *godl.Node 31 | } 32 | 33 | func (m *GLUBlockModule) Forward(inputs ...*godl.Node) godl.Nodes { 34 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 35 | panic(err) 36 | } 37 | 38 | x := inputs[0] 39 | startAt := 0 40 | 41 | if len(m.opts.Shared) > 0 { 42 | result := m.gluLayers[0].Forward(x) 43 | 44 | x = result[0] 45 | startAt = 1 46 | } 47 | 48 | for _, glu := range m.gluLayers[startAt:] { 49 | result := glu.Forward(x)[0] 50 | 51 | x = gorgonia.Must(gorgonia.Add(x, result)) 52 | x = gorgonia.Must(gorgonia.Mul(x, m.scale)) 53 | } 54 | 55 | return godl.Nodes{x} 56 | } 57 | 58 | func GLUBlock(nn *godl.Model, opts GLUBlockOpts) *GLUBlockModule { 59 | lt := godl.AddLayer("GLUBlock") 60 | 61 | gluLayers := make([]*godl.GLUModule, 0, opts.Size) 62 | gluInput := opts.InputDimension 63 | if len(opts.Shared) == 0 { // for independent layers 64 | gluInput = opts.OutputDimension 65 | } 66 | 67 | gluOutput := opts.OutputDimension 68 | weightsInit := opts.WeightsInit 69 | 70 | if weightsInit == nil { 71 | gain := math.Sqrt(float64(gluInput+gluOutput) / math.Sqrt(float64(gluInput))) 72 | weightsInit = gorgonia.GlorotN(gain) 73 | } 74 | 75 | for i := 0; i < opts.Size; i++ { 76 | var fcLayer *godl.LinearModule 77 | if len(opts.Shared) > 0 { 
78 | fcLayer = opts.Shared[i] 79 | } 80 | 81 | gluLayers = append(gluLayers, godl.GLU(nn, godl.GLUOpts{ 82 | InputDimension: gluInput, 83 | OutputDimension: gluOutput, 84 | VirtualBatchSize: opts.VirtualBatchSize, 85 | Linear: fcLayer, 86 | WeightsInit: weightsInit, 87 | WithBias: opts.WithBias, 88 | Momentum: opts.Momentum, 89 | })) 90 | 91 | gluInput = gluOutput 92 | } 93 | 94 | scale := gorgonia.NewConstant(math32.Sqrt(0.5), gorgonia.WithName("ft.scale")) 95 | 96 | return &GLUBlockModule{ 97 | model: nn, 98 | layer: lt, 99 | opts: opts, 100 | gluLayers: gluLayers, 101 | scale: scale, 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /tabnet/classifier.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "github.com/dcu/godl" 5 | "github.com/dcu/godl/activation" 6 | "gorgonia.org/gorgonia" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | type Classifier struct { 11 | model *godl.Model 12 | tabnet *TabNetModule 13 | } 14 | 15 | type ClassifierOpts struct { 16 | BatchSize int 17 | VirtualBatchSize int 18 | MaskFunction activation.Function 19 | WithBias bool 20 | 21 | SharedBlocks int 22 | IndependentBlocks int 23 | DecisionSteps int 24 | PredictionLayerDim int 25 | AttentionLayerDim int 26 | 27 | Gamma float64 28 | Momentum float64 29 | Epsilon float64 30 | 31 | WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn 32 | } 33 | 34 | func NewClassifier(inputDim int, catDims []int, catIdxs []int, catEmbDim []int, opts ClassifierOpts) *Classifier { 35 | nn := godl.NewModel() 36 | 37 | tn := TabNet(nn, TabNetOpts{ 38 | OutputSize: 1, 39 | BatchSize: opts.BatchSize, 40 | VirtualBatchSize: opts.VirtualBatchSize, 41 | InputSize: inputDim, 42 | MaskFunction: gorgonia.Sigmoid, 43 | WithBias: opts.WithBias, 44 | WeightsInit: opts.WeightsInit, 45 | ScaleInit: opts.ScaleInit, 46 | BiasInit: opts.BiasInit, 47 | SharedBlocks: opts.SharedBlocks, 48 | IndependentBlocks: 
opts.IndependentBlocks, 49 | DecisionSteps: opts.DecisionSteps, 50 | PredictionLayerDim: opts.PredictionLayerDim, 51 | AttentionLayerDim: opts.AttentionLayerDim, 52 | Gamma: opts.Gamma, 53 | Momentum: opts.Momentum, 54 | Epsilon: opts.Epsilon, 55 | CatDims: catDims, 56 | CatIdxs: catIdxs, 57 | CatEmbDim: catEmbDim, 58 | }) 59 | 60 | return &Classifier{ 61 | model: nn, 62 | tabnet: tn, 63 | } 64 | } 65 | 66 | func (r *Classifier) Model() *godl.Model { 67 | return r.model 68 | } 69 | 70 | func (r *Classifier) Train(trainX, trainY, validateX, validateY tensor.Tensor, opts godl.TrainOpts) error { 71 | if opts.CostFn == nil { 72 | lambdaSparse := gorgonia.NewConstant(float32(1e-3)) 73 | crossEntropy := godl.CategoricalCrossEntropyLoss(godl.CrossEntropyLossOpt{}) 74 | 75 | opts.CostFn = func(output godl.Nodes, target *godl.Node) *gorgonia.Node { 76 | cost := crossEntropy(output, target) 77 | cost = gorgonia.Must(gorgonia.Sub(cost, gorgonia.Must(gorgonia.Mul(lambdaSparse, output[1])))) 78 | 79 | return cost 80 | } 81 | } 82 | 83 | if opts.Solver == nil { 84 | opts.Solver = gorgonia.NewAdamSolver(gorgonia.WithBatchSize(float64(opts.BatchSize)), gorgonia.WithLearnRate(0.02)) 85 | } 86 | 87 | return godl.Train(r.model, r.tabnet, trainX, trainY, validateX, validateY, opts) 88 | } 89 | -------------------------------------------------------------------------------- /fc_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestFC(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | input tensor.Tensor 15 | expectedOutput tensor.Tensor 16 | expectedOutputGrad tensor.Tensor 17 | expectedInputGrad tensor.Tensor 18 | expectedWeightGrad tensor.Tensor 19 | }{ 20 | { 21 | desc: "Example 1", 22 | input: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.1, 
0.01, 0.01, 0.1})), 23 | expectedOutput: tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float32{-0.0052000005, 0.0013999998, 0.008, 0.014599999, -0.0025000002, 0.0041, 0.0107, 0.0173})), 24 | expectedOutputGrad: tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float32{0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125})), 25 | expectedInputGrad: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.019999999999999997, 0.034999999999999996, 0.019999999999999997, 0.034999999999999996})), 26 | expectedWeightGrad: tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float32{0.01375, 0.01375, 0.01375, 0.01375, 0.01375, 0.01375, 0.01375, 0.01375})), 27 | }, 28 | } 29 | for _, tC := range testCases { 30 | t.Run(tC.desc, func(t *testing.T) { 31 | c := require.New(t) 32 | 33 | m := NewModel() 34 | fc := Linear(m, LinearOpts{ 35 | InputDimension: tC.input.Shape()[1], 36 | OutputDimension: 4, 37 | WeightsInit: gorgonia.RangedFromWithStep(-0.05, 0.03), 38 | }) 39 | 40 | x := gorgonia.NewTensor(m.trainGraph, tensor.Float32, 2, gorgonia.WithShape(tC.input.Shape()...), gorgonia.WithValue(tC.input), gorgonia.WithName("x")) 41 | 42 | result := fc.Forward(x)[0] 43 | 44 | cost := gorgonia.Must(gorgonia.Mean(result)) 45 | 46 | l := m.learnables 47 | 48 | _, err := gorgonia.Grad(cost, append(l, x)...) 
49 | c.NoError(err) 50 | 51 | // _ = ioutil.WriteFile("fc.dot", []byte(m.g.ToDot()), 0644) 52 | 53 | vm := gorgonia.NewTapeMachine(m.trainGraph, 54 | gorgonia.BindDualValues(l...), 55 | gorgonia.WithWatchlist(), 56 | gorgonia.TraceExec(), 57 | ) 58 | c.NoError(vm.RunAll()) 59 | c.NoError(vm.Close()) 60 | 61 | c.Equal(tC.expectedInputGrad.Data(), x.Deriv().Value().Data()) 62 | c.Equal(tC.expectedOutput.Data(), result.Value().Data()) 63 | 64 | outputGrad, err := result.Grad() 65 | c.NoError(err) 66 | c.Equal(tC.expectedOutputGrad.Data(), outputGrad.Data()) 67 | 68 | weightGrad, err := m.learnables[0].Grad() 69 | c.NoError(err) 70 | c.Equal(tC.expectedWeightGrad.Data(), weightGrad.Data()) 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /weights.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "sync/atomic" 6 | 7 | "github.com/fatih/color" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | var weightsCount int64 13 | 14 | // NewWeightsOpts defines the options to create a node 15 | // Value has priority if it's not defined then it uses the InitFN if it's not defined it uses Glorot/Xavier(1.0) 16 | // If UniqueName is empty an automatic one will be assigned. 17 | type NewWeightsOpts struct { 18 | UniqueName string 19 | Value gorgonia.Value 20 | InitFN gorgonia.InitWFn 21 | 22 | // Fixed indicates that the weights won't be learnable. 
By default the weights are learnable 23 | Fixed bool 24 | } 25 | 26 | // WeightsCount return the number of learnables 27 | func (t *Model) WeightsCount() int64 { 28 | return weightsCount 29 | } 30 | 31 | func (t *Model) AddWeights(lt LayerType, shape tensor.Shape, opts NewWeightsOpts) *gorgonia.Node { 32 | return t.AddLearnable(lt, "weight", shape, opts) 33 | } 34 | 35 | func (t *Model) AddBias(lt LayerType, shape tensor.Shape, opts NewWeightsOpts) *gorgonia.Node { 36 | return t.AddLearnable(lt, "bias", shape, opts) 37 | } 38 | 39 | func (t *Model) AddLearnable(lt LayerType, typ string, shape tensor.Shape, opts NewWeightsOpts) *gorgonia.Node { 40 | if opts.UniqueName == "" { 41 | opts.UniqueName = fmt.Sprintf("%s.%d.%s.%d.%d", lt, weightsCount, typ, len(t.learnables), shape.TotalSize()) 42 | } 43 | 44 | w := t.CreateWeightsNode(shape, opts) 45 | t.learnables = append(t.learnables, w) 46 | 47 | return w 48 | } 49 | 50 | func (t *Model) CreateWeightsNode(shape tensor.Shape, opts NewWeightsOpts) *gorgonia.Node { 51 | atomic.AddInt64(&weightsCount, 1) 52 | 53 | var init gorgonia.NodeConsOpt 54 | 55 | if opts.Value != nil { 56 | init = gorgonia.WithValue(opts.Value) 57 | } else if opts.InitFN != nil { 58 | init = gorgonia.WithInit(opts.InitFN) 59 | } else { 60 | init = gorgonia.WithInit(gorgonia.GlorotN(1.0)) 61 | } 62 | 63 | val, err := t.Storage.TensorByName(opts.UniqueName) 64 | if err == nil { 65 | color.Green("Loaded weights %v %v from storage", shape, opts.UniqueName) 66 | init = gorgonia.WithValue(val) 67 | } else { 68 | color.Yellow("Assigned random weights to %v %v", shape, opts.UniqueName) 69 | } 70 | 71 | var w *gorgonia.Node 72 | 73 | if shape.Dims() == 2 { 74 | w = gorgonia.NewMatrix( 75 | t.trainGraph, 76 | tensor.Float32, 77 | gorgonia.WithShape(shape...), 78 | gorgonia.WithName(opts.UniqueName), 79 | init, 80 | ) 81 | } else { 82 | w = gorgonia.NewTensor( 83 | t.trainGraph, 84 | tensor.Float32, 85 | shape.Dims(), 86 | gorgonia.WithShape(shape...), 87 | 
gorgonia.WithName(opts.UniqueName), 88 | init, 89 | ) 90 | } 91 | 92 | return w 93 | } 94 | -------------------------------------------------------------------------------- /tabnet/attentive_transformer.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/dcu/godl/activation" 8 | "gorgonia.org/gorgonia" 9 | ) 10 | 11 | type AttentiveTransformerOpts struct { 12 | InputDimension int 13 | OutputDimension int 14 | Momentum float64 15 | Epsilon float64 16 | VirtualBatchSize int 17 | Activation activation.Function 18 | WithBias bool 19 | WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn 20 | } 21 | 22 | func (o *AttentiveTransformerOpts) setDefaults() { 23 | if o.Activation == nil { 24 | o.Activation = activation.SparseMax 25 | } 26 | 27 | if o.WeightsInit == nil { 28 | gain := math.Sqrt(float64(o.InputDimension+o.OutputDimension) / math.Sqrt(float64(4*o.InputDimension))) 29 | o.WeightsInit = gorgonia.GlorotN(gain) 30 | } 31 | } 32 | 33 | type AttentiveTransformerModule struct { 34 | model *godl.Model 35 | layer godl.LayerType 36 | opts AttentiveTransformerOpts 37 | linear *godl.LinearModule 38 | gbn *godl.GhostBatchNormModule 39 | } 40 | 41 | func (m *AttentiveTransformerModule) Forward(inputs ...*godl.Node) godl.Nodes { 42 | if err := m.model.CheckArity(m.layer, inputs, 2); err != nil { 43 | panic(err) 44 | } 45 | 46 | x := inputs[0] 47 | prior := inputs[1] 48 | 49 | fc := m.linear.Forward(x) 50 | bn := m.gbn.Forward(fc...) 
51 | 52 | mul := gorgonia.Must(gorgonia.HadamardProd(bn[0], prior)) 53 | 54 | sm := gorgonia.Must(m.opts.Activation(mul)) 55 | 56 | return godl.Nodes{sm} 57 | } 58 | 59 | // AttentiveTransformer implements an attetion transformer layer 60 | func AttentiveTransformer(nn *godl.Model, opts AttentiveTransformerOpts) *AttentiveTransformerModule { 61 | lt := godl.AddLayer("AttentiveTransformer") 62 | 63 | opts.setDefaults() 64 | 65 | weightsInit := opts.WeightsInit 66 | if weightsInit == nil { 67 | gain := math.Sqrt(float64(opts.InputDimension+opts.OutputDimension) / math.Sqrt(float64(4*opts.InputDimension))) 68 | weightsInit = gorgonia.GlorotN(gain) 69 | } 70 | 71 | fcLayer := godl.Linear(nn, godl.LinearOpts{ 72 | InputDimension: opts.InputDimension, 73 | OutputDimension: opts.OutputDimension, 74 | WeightsInit: weightsInit, 75 | WithBias: opts.WithBias, 76 | }) 77 | 78 | gbnLayer := godl.GhostBatchNorm(nn, godl.GhostBatchNormOpts{ 79 | Momentum: opts.Momentum, 80 | Epsilon: opts.Epsilon, 81 | VirtualBatchSize: opts.VirtualBatchSize, 82 | OutputDimension: opts.OutputDimension, 83 | ScaleInit: opts.ScaleInit, 84 | BiasInit: opts.BiasInit, 85 | }) 86 | 87 | return &AttentiveTransformerModule{ 88 | model: nn, 89 | layer: lt, 90 | opts: opts, 91 | linear: fcLayer, 92 | gbn: gbnLayer, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /conv2d.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "math" 5 | 6 | "gorgonia.org/gorgonia" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | // Conv2dOpts are the options to run the conv2d operation 11 | type Conv2dOpts struct { 12 | InputDimension int 13 | OutputDimension int 14 | 15 | KernelSize tensor.Shape 16 | Pad []int 17 | Stride []int 18 | Dilation []int 19 | 20 | WithBias bool 21 | 22 | WeightsInit, BiasInit gorgonia.InitWFn 23 | WeightsName, BiasName string 24 | FixedWeights bool 25 | } 26 | 27 | func (o *Conv2dOpts) 
setDefaults() { 28 | if o.KernelSize == nil { 29 | o.KernelSize = tensor.Shape{3, 3} 30 | } 31 | 32 | if o.Pad == nil { 33 | o.Pad = []int{0, 0} 34 | } 35 | 36 | if o.Stride == nil { 37 | o.Stride = []int{1, 1} 38 | } 39 | 40 | if o.Dilation == nil { 41 | o.Dilation = []int{1, 1} 42 | } 43 | 44 | if o.WeightsInit == nil { 45 | k := math.Sqrt(1 / float64(o.OutputDimension*o.KernelSize[0]*o.KernelSize[1])) 46 | o.WeightsInit = gorgonia.Uniform(-k, k) 47 | } 48 | 49 | if o.BiasInit == nil { 50 | k := math.Sqrt(1 / float64(o.OutputDimension*o.KernelSize[0]*o.KernelSize[1])) 51 | o.WeightsInit = gorgonia.Uniform(-k, k) 52 | } 53 | } 54 | 55 | type Conv2dModule struct { 56 | model *Model 57 | layer LayerType 58 | 59 | opts Conv2dOpts 60 | 61 | weight, bias *Node 62 | } 63 | 64 | func (m *Conv2dModule) Name() string { 65 | return "Conv2d" 66 | } 67 | 68 | func (m *Conv2dModule) Forward(inputs ...*Node) Nodes { 69 | err := m.model.CheckArity(m.layer, inputs, 1) 70 | if err != nil { 71 | panic(err) 72 | } 73 | 74 | x := inputs[0] 75 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weight, m.opts.KernelSize, m.opts.Pad, m.opts.Stride, m.opts.Dilation)) 76 | 77 | if m.bias != nil { 78 | x = gorgonia.Must(gorgonia.BroadcastAdd(x, m.bias, nil, []byte{0, 2, 3})) 79 | } 80 | 81 | return Nodes{x} 82 | } 83 | 84 | // Conv2d applies a conv2d operation to the input 85 | func Conv2d(m *Model, opts Conv2dOpts) *Conv2dModule { 86 | opts.setDefaults() 87 | lt := AddLayer("Conv2d") 88 | 89 | w := m.AddWeights(lt, tensor.Shape{opts.OutputDimension, opts.InputDimension, opts.KernelSize[0], opts.KernelSize[0]}, NewWeightsOpts{ 90 | InitFN: opts.WeightsInit, 91 | UniqueName: opts.WeightsName, 92 | Fixed: opts.FixedWeights, 93 | }) 94 | 95 | var bias *gorgonia.Node 96 | if opts.WithBias { 97 | bias = m.AddBias(lt, tensor.Shape{1, opts.OutputDimension, 1, 1}, NewWeightsOpts{ 98 | InitFN: opts.BiasInit, 99 | UniqueName: opts.BiasName, 100 | Fixed: opts.FixedWeights, 101 | }) 102 | } 103 | 104 | return 
&Conv2dModule{ 105 | model: m, 106 | layer: lt, 107 | opts: opts, 108 | weight: w, 109 | bias: bias, 110 | } 111 | } 112 | 113 | var ( 114 | _ Module = &Conv2dModule{} 115 | ) 116 | -------------------------------------------------------------------------------- /glu.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "github.com/dcu/godl/activation" 5 | "gorgonia.org/gorgonia" 6 | ) 7 | 8 | var gluCount uint64 = 0 9 | 10 | // GLUOpts are the supported options for GLU 11 | type GLUOpts struct { 12 | InputDimension int 13 | OutputDimension int 14 | VirtualBatchSize int 15 | Activation activation.Function 16 | Linear *LinearModule 17 | WeightsInit gorgonia.InitWFn 18 | WithBias bool 19 | Momentum float64 20 | } 21 | 22 | func (opts *GLUOpts) setDefaults() { 23 | if opts.Momentum == 0 { 24 | opts.Momentum = 0.02 25 | } 26 | 27 | if opts.Activation == nil { 28 | opts.Activation = gorgonia.Sigmoid 29 | } 30 | 31 | if opts.InputDimension == 0 { 32 | panic("input dimension must be set") 33 | } 34 | 35 | if opts.OutputDimension == 0 { 36 | panic("output dimension must be set") 37 | } 38 | 39 | if opts.VirtualBatchSize == 0 { 40 | panic("virtual batch size must be set") 41 | } 42 | } 43 | 44 | type GLUModule struct { 45 | model *Model 46 | layer LayerType 47 | opts GLUOpts 48 | gbn *GhostBatchNormModule 49 | linear *LinearModule 50 | } 51 | 52 | func (m *GLUModule) Forward(inputs ...*Node) Nodes { 53 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 54 | panic(err) 55 | } 56 | 57 | x := inputs[0] 58 | 59 | fcResult := m.opts.Linear.Forward(x) 60 | gnbResult := m.gbn.Forward(fcResult...)[0] 61 | 62 | // GLU 63 | firstHalf := gorgonia.Must(gorgonia.Slice(gnbResult, nil, gorgonia.S(0, m.opts.OutputDimension))) 64 | secondHalf := gorgonia.Must(gorgonia.Slice(gnbResult, nil, gorgonia.S(m.opts.OutputDimension, gnbResult.Shape()[1]))) 65 | 66 | act, err := m.opts.Activation(secondHalf) 67 | if 
err != nil { 68 | panic(ErrorF(m.layer, "%s: applying activation function failed: %w", err)) 69 | } 70 | 71 | mul, err := gorgonia.HadamardProd(firstHalf, act) 72 | if err != nil { 73 | panic(ErrorF(m.layer, "%s: HadamardProd %d x %d: %w", firstHalf.Shape(), act.Shape(), err)) 74 | } 75 | 76 | return Nodes{mul} 77 | } 78 | 79 | // GLU implements a Gated Linear Unit Block 80 | func GLU(nn *Model, opts GLUOpts) *GLUModule { 81 | opts.setDefaults() 82 | 83 | lt := AddLayer("GLU") 84 | 85 | if opts.Linear == nil { 86 | opts.Linear = Linear(nn, LinearOpts{ 87 | InputDimension: opts.InputDimension, 88 | OutputDimension: opts.OutputDimension * 2, 89 | WeightsInit: opts.WeightsInit, 90 | WithBias: opts.WithBias, 91 | }) 92 | } 93 | 94 | gbn := GhostBatchNorm(nn, GhostBatchNormOpts{ 95 | VirtualBatchSize: opts.VirtualBatchSize, 96 | OutputDimension: opts.OutputDimension * 2, 97 | Momentum: opts.Momentum, 98 | }) 99 | 100 | return &GLUModule{ 101 | model: nn, 102 | layer: lt, 103 | opts: opts, 104 | gbn: gbn, 105 | linear: opts.Linear, 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dcu/godl 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/anthonynsimon/bild v0.13.0 7 | github.com/cheggaaa/pb/v3 v3.0.8 8 | github.com/chewxy/math32 v1.10.1 9 | github.com/dcu/resize v0.0.0-20201120200444-0e5185b92511 10 | github.com/fatih/color v1.13.0 11 | github.com/fogleman/gg v1.3.0 12 | github.com/mum4k/termdash v0.16.1 13 | github.com/olekukonko/tablewriter v0.0.5 14 | github.com/oliamb/cutter v0.2.2 15 | github.com/stretchr/testify v1.7.0 16 | gonum.org/v1/plot v0.11.0 17 | gorgonia.org/gorgonia v0.9.18-0.20220428013624-8f3502bcdaf8 18 | gorgonia.org/qol v0.0.0-20220326215349-708736a2aac5 19 | gorgonia.org/tensor v0.9.23 20 | ) 21 | 22 | require ( 23 | git.sr.ht/~sbinet/gg v0.3.1 // indirect 24 | 
github.com/VividCortex/ewma v1.2.0 // indirect 25 | github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect 26 | github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect 27 | github.com/awalterschulze/gographviz v2.0.3+incompatible // indirect 28 | github.com/chewxy/hm v1.0.0 // indirect 29 | github.com/davecgh/go-spew v1.1.1 // indirect 30 | github.com/go-fonts/liberation v0.2.0 // indirect 31 | github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81 // indirect 32 | github.com/go-pdf/fpdf v0.6.0 // indirect 33 | github.com/gogo/protobuf v1.3.2 // indirect 34 | github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect 35 | github.com/golang/protobuf v1.5.2 // indirect 36 | github.com/google/flatbuffers v2.0.6+incompatible // indirect 37 | github.com/google/uuid v1.3.0 // indirect 38 | github.com/kr/text v0.2.0 // indirect 39 | github.com/leesper/go_rng v0.0.0-20190531154944-a612b043e353 // indirect 40 | github.com/mattn/go-colorable v0.1.12 // indirect 41 | github.com/mattn/go-isatty v0.0.14 // indirect 42 | github.com/mattn/go-runewidth v0.0.13 // indirect 43 | github.com/nsf/termbox-go v1.1.1 // indirect 44 | github.com/pkg/errors v0.9.1 // indirect 45 | github.com/pmezard/go-difflib v1.0.0 // indirect 46 | github.com/rivo/uniseg v0.2.0 // indirect 47 | github.com/xtgo/set v1.0.0 // indirect 48 | go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect 49 | golang.org/x/exp v0.0.0-20220321173239-a90fa8a75705 // indirect 50 | golang.org/x/image v0.0.0-20220321031419-a8550c1d254a // indirect 51 | golang.org/x/sys v0.0.0-20220403205710-6acee93ad0eb // indirect 52 | golang.org/x/text v0.3.7 // indirect 53 | golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect 54 | gonum.org/v1/gonum v0.11.0 // indirect 55 | google.golang.org/protobuf v1.28.0 // indirect 56 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 57 | gopkg.in/yaml.v3 
v3.0.0-20200615113413-eeeca48fe776 // indirect 58 | gorgonia.org/cu v0.9.4 // indirect 59 | gorgonia.org/dawson v1.2.0 // indirect 60 | gorgonia.org/vecf32 v0.9.0 // indirect 61 | gorgonia.org/vecf64 v0.9.0 // indirect 62 | ) 63 | -------------------------------------------------------------------------------- /fc.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "github.com/dcu/godl/activation" 5 | "gorgonia.org/gorgonia" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // LinearOpts contains optional parameter for a layer 10 | type LinearOpts struct { 11 | Activation activation.Function 12 | Dropout float64 13 | OutputDimension int 14 | InputDimension int 15 | 16 | WeightsInit gorgonia.InitWFn 17 | BiasInit gorgonia.InitWFn 18 | WithBias bool 19 | WeightsName, BiasName string 20 | FixedWeights bool 21 | } 22 | 23 | type LinearModule struct { 24 | model *Model 25 | opts LinearOpts 26 | layer LayerType 27 | 28 | weight, bias *Node 29 | } 30 | 31 | func (m *LinearModule) Name() string { 32 | return "Linear" 33 | } 34 | 35 | func (m *LinearModule) Forward(inputs ...*Node) (out Nodes) { 36 | x := inputs[0] 37 | xShape := x.Shape() 38 | 39 | if x.Dims() > 2 { 40 | b, v := xShape[0], tensor.Shape(xShape[1:]).TotalSize() 41 | x = gorgonia.Must(gorgonia.Reshape(x, tensor.Shape{b, v})) 42 | } 43 | 44 | wT := gorgonia.Must(gorgonia.Transpose(m.weight, 1, 0)) 45 | 46 | result, err := gorgonia.Mul(x, wT) 47 | if err != nil { 48 | panic(ErrorF(m.layer, "error applying mul %v x %v: %w ", x.Shape(), wT.Shape(), err)) 49 | } 50 | 51 | if m.opts.WithBias { 52 | result, err = gorgonia.BroadcastAdd(result, m.bias, nil, []byte{0}) 53 | if err != nil { 54 | panic(ErrorF(m.layer, "error adding bias %w", err)) 55 | } 56 | } 57 | 58 | if m.opts.Activation != nil { 59 | result, err = m.opts.Activation(result) 60 | if err != nil { 61 | panic(ErrorF(m.layer, "error applying activation %w", err)) 62 | } 63 | } 64 | 65 | if 
m.opts.Dropout > 0.0 { 66 | result, err = gorgonia.Dropout(result, m.opts.Dropout) 67 | if err != nil { 68 | panic(ErrorF(m.layer, "error applying dropout %w", err)) 69 | } 70 | } 71 | 72 | return Nodes{result} 73 | } 74 | 75 | func Linear(nn *Model, opts LinearOpts) *LinearModule { 76 | lt := AddLayer("FC") 77 | 78 | MustBeGreatherThan(lt, "input dimension", opts.InputDimension, 0) 79 | MustBeGreatherThan(lt, "output dimension", opts.OutputDimension, 0) 80 | 81 | var ( 82 | bias *gorgonia.Node 83 | w = nn.AddWeights(lt, tensor.Shape{opts.OutputDimension, opts.InputDimension}, NewWeightsOpts{ 84 | InitFN: opts.WeightsInit, 85 | UniqueName: opts.WeightsName, 86 | Fixed: opts.FixedWeights, 87 | }) 88 | ) 89 | 90 | if opts.WithBias { 91 | bias = nn.AddBias(lt, tensor.Shape{1, opts.OutputDimension}, NewWeightsOpts{ 92 | InitFN: opts.BiasInit, 93 | UniqueName: opts.BiasName, 94 | Fixed: opts.FixedWeights, 95 | }) 96 | } 97 | 98 | return &LinearModule{ 99 | model: nn, 100 | layer: lt, 101 | opts: opts, 102 | bias: bias, 103 | weight: w, 104 | } 105 | } 106 | 107 | var ( 108 | _ Module = &LinearModule{} 109 | ) 110 | -------------------------------------------------------------------------------- /batch_norm.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorgonia.org/gorgonia" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | // BatchNormOpts are the options to configure a batch normalization 11 | type BatchNormOpts struct { 12 | Momentum float64 13 | Epsilon float64 14 | ScaleInit, BiasInit gorgonia.InitWFn 15 | 16 | ScaleName, BiasName string 17 | 18 | InputSize int 19 | } 20 | 21 | func (o *BatchNormOpts) setDefaults() { 22 | if o.InputSize == 0 { 23 | panic("output size for BN can't be 0") 24 | } 25 | 26 | if o.Momentum == 0.0 { 27 | o.Momentum = 0.01 28 | } 29 | 30 | if o.Epsilon == 0.0 { 31 | o.Epsilon = 1e-5 32 | } 33 | 34 | if o.ScaleInit == nil { 35 | o.ScaleInit = gorgonia.Ones() 
36 | } 37 | 38 | if o.BiasInit == nil { 39 | o.BiasInit = gorgonia.Zeroes() 40 | } 41 | } 42 | 43 | type BatchNormModule struct { 44 | model *Model 45 | layer LayerType 46 | opts BatchNormOpts 47 | 48 | scale, bias *Node 49 | } 50 | 51 | func (m *BatchNormModule) Name() string { 52 | return "BatchNorm" 53 | } 54 | 55 | func (m *BatchNormModule) Forward(inputs ...*Node) Nodes { 56 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 57 | panic(err) 58 | } 59 | 60 | x := inputs[0] 61 | 62 | ret, _, _, _, err := gorgonia.BatchNorm(x, m.scale, m.bias, float64(m.opts.Momentum), float64(m.opts.Epsilon)) 63 | if err != nil { 64 | panic(fmt.Errorf("%v: %w", m.layer, err)) 65 | } 66 | 67 | return Nodes{ret} 68 | } 69 | 70 | // BatchNorm1d defines the batch norm operation for tensors with shape (B, N) 71 | func BatchNorm1d(nn *Model, opts BatchNormOpts) *BatchNormModule { 72 | opts.setDefaults() 73 | lt := AddLayer("BatchNorm1d") 74 | 75 | scale := nn.AddLearnable(lt, "scale", tensor.Shape{1, opts.InputSize}, NewWeightsOpts{ 76 | UniqueName: opts.ScaleName, 77 | InitFN: opts.ScaleInit, 78 | }) 79 | bias := nn.AddBias(lt, tensor.Shape{1, opts.InputSize}, NewWeightsOpts{ 80 | UniqueName: opts.BiasName, 81 | InitFN: opts.BiasInit, 82 | }) 83 | 84 | return &BatchNormModule{ 85 | model: nn, 86 | layer: lt, 87 | opts: opts, 88 | scale: scale, 89 | bias: bias, 90 | } 91 | } 92 | 93 | // BatchNorm2d defines the batch norm operation for tensors with shape (B, C, W, H) 94 | func BatchNorm2d(nn *Model, opts BatchNormOpts) *BatchNormModule { 95 | opts.setDefaults() 96 | lt := AddLayer("BatchNorm2d") 97 | 98 | scale := nn.AddLearnable(lt, "scale", tensor.Shape{1, opts.InputSize, 1, 1}, NewWeightsOpts{ 99 | UniqueName: opts.ScaleName, 100 | InitFN: opts.ScaleInit, 101 | }) 102 | bias := nn.AddBias(lt, tensor.Shape{1, opts.InputSize, 1, 1}, NewWeightsOpts{ 103 | UniqueName: opts.BiasName, 104 | InitFN: opts.BiasInit, 105 | }) 106 | 107 | return &BatchNormModule{ 108 | model: 
nn, 109 | layer: lt, 110 | opts: opts, 111 | scale: scale, 112 | bias: bias, 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /gbn.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "math" 5 | 6 | "gorgonia.org/gorgonia" 7 | ) 8 | 9 | // GhostBatchNormOpts contains config options for the ghost batch normalization 10 | type GhostBatchNormOpts struct { 11 | Momentum float64 12 | Epsilon float64 13 | VirtualBatchSize int 14 | OutputDimension int 15 | 16 | ScaleInit, BiasInit gorgonia.InitWFn 17 | } 18 | 19 | func (o *GhostBatchNormOpts) setDefaults() { 20 | if o.VirtualBatchSize == 0 { 21 | o.VirtualBatchSize = 128 22 | } 23 | 24 | if o.Momentum == 0.0 { 25 | o.Momentum = 0.01 26 | } 27 | 28 | if o.Epsilon == 0.0 { 29 | o.Epsilon = 1e-5 30 | } 31 | } 32 | 33 | type GhostBatchNormModule struct { 34 | model *Model 35 | layer LayerType 36 | bn *BatchNormModule 37 | opts GhostBatchNormOpts 38 | } 39 | 40 | func (m *GhostBatchNormModule) Forward(inputs ...*Node) Nodes { 41 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 42 | panic(err) 43 | } 44 | 45 | x := inputs[0] 46 | xShape := x.Shape() 47 | inputSize := xShape[0] 48 | 49 | if m.opts.VirtualBatchSize > inputSize { 50 | m.opts.VirtualBatchSize = inputSize 51 | } 52 | 53 | if inputSize%m.opts.VirtualBatchSize != 0 { 54 | panic(ErrorF(m.layer, "input size (%d) must be divisible by virtual batch size (%v)", inputSize, m.opts.VirtualBatchSize)) 55 | } 56 | 57 | batches := int(math.Ceil(float64(inputSize) / float64(m.opts.VirtualBatchSize))) 58 | nodes := make([]*gorgonia.Node, 0, batches) 59 | 60 | // Split the vector in virtual batches 61 | for vb := 0; vb < batches; vb++ { 62 | start := vb * m.opts.VirtualBatchSize 63 | if start > inputSize { 64 | break 65 | } 66 | 67 | end := start + m.opts.VirtualBatchSize 68 | if end > inputSize { 69 | panic("this should not happen") 70 | } 71 
| 72 | virtualBatch := gorgonia.Must(gorgonia.Slice(x, gorgonia.S(start, end))) 73 | 74 | result := m.bn.Forward(virtualBatch) 75 | 76 | nodes = append(nodes, result...) 77 | } 78 | 79 | ret, err := gorgonia.Concat(0, nodes...) 80 | if err != nil { 81 | panic(ErrorF(m.layer, "error concatenating %d nodes: %w", len(nodes), err)) 82 | } 83 | 84 | return Nodes{ret} 85 | } 86 | 87 | // GhostBatchNorm implements a Ghost Batch Normalization: https://arxiv.org/pdf/1705.08741.pdf 88 | // momentum defaults to 0.01 if 0 is passed 89 | // epsilon defaults to 1e-5 if 0 is passed 90 | func GhostBatchNorm(nn *Model, opts GhostBatchNormOpts) *GhostBatchNormModule { 91 | opts.setDefaults() 92 | 93 | lt := AddLayer("GBN") 94 | 95 | MustBeGreatherThan(lt, "OutputDimesion", opts.OutputDimension, 0) 96 | 97 | bn := BatchNorm1d(nn, BatchNormOpts{ 98 | Momentum: opts.Momentum, 99 | Epsilon: opts.Epsilon, 100 | ScaleInit: opts.ScaleInit, 101 | BiasInit: opts.BiasInit, 102 | InputSize: opts.OutputDimension, 103 | }) 104 | 105 | return &GhostBatchNormModule{ 106 | model: nn, 107 | layer: lt, 108 | bn: bn, 109 | opts: opts, 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /examples/mnist/fc/fc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | 8 | "github.com/dcu/godl" 9 | "github.com/dcu/godl/examples/mnist" 10 | "gorgonia.org/gorgonia" 11 | ) 12 | 13 | var ( 14 | datasetDir string 15 | ) 16 | 17 | func init() { 18 | flag.StringVar(&datasetDir, "dataset-dir", "..", "The dir where the dataset is located") 19 | } 20 | 21 | func handleErr(what string, err error) { 22 | if err != nil { 23 | log.Panicf("%s: %v", what, err) 24 | } 25 | } 26 | 27 | func main() { 28 | trainX, trainY, err := mnist.Load(mnist.ModeTrain, datasetDir) 29 | handleErr("loading trainig mnist data", err) 30 | 31 | validateX, validateY, err := 
mnist.Load(mnist.ModeTrain, datasetDir) 32 | handleErr("loading validation mnist data", err) 33 | 34 | model := godl.NewModel() 35 | layer := godl.Sequential( 36 | model, 37 | godl.Linear(model, godl.LinearOpts{ 38 | InputDimension: 784, 39 | OutputDimension: 300, 40 | WithBias: true, 41 | }), 42 | godl.BatchNorm1d(model, godl.BatchNormOpts{ 43 | InputSize: 300, 44 | }), 45 | godl.Rectify(), 46 | godl.Linear(model, godl.LinearOpts{ 47 | InputDimension: 300, 48 | OutputDimension: 100, 49 | WithBias: true, 50 | }), 51 | godl.BatchNorm1d(model, godl.BatchNormOpts{ 52 | InputSize: 100, 53 | }), 54 | godl.Rectify(), 55 | godl.Linear(model, godl.LinearOpts{ 56 | InputDimension: 100, 57 | OutputDimension: 10, 58 | WithBias: true, 59 | }), 60 | ) 61 | 62 | err = godl.Train(model, layer, trainX, trainY, validateX, validateY, godl.TrainOpts{ 63 | Epochs: 3, 64 | ValidateEvery: 0, 65 | BatchSize: 64, 66 | // WriteGraphFileTo: "graph.svg", 67 | Solver: gorgonia.NewAdamSolver(gorgonia.WithLearnRate(5e-4)), 68 | CostObserver: func(epoch, totalEpoch, batch, totalBatch int, cost float32) { 69 | // log.Printf("batch=%d/%d epoch=%d/%d cost=%0.3f", batch, totalBatch, epoch, totalEpoch, cost) 70 | }, 71 | MatchTypeFor: func(predVal, targetVal []float32) godl.MatchType { 72 | var ( 73 | rowLabel int 74 | yRowHigh float32 75 | ) 76 | 77 | for k := 0; k < 10; k++ { 78 | if k == 0 { 79 | rowLabel = 0 80 | yRowHigh = targetVal[k] 81 | } else if targetVal[k] > yRowHigh { 82 | rowLabel = k 83 | yRowHigh = targetVal[k] 84 | } 85 | } 86 | 87 | var ( 88 | rowGuess int 89 | predRowHigh float32 90 | ) 91 | 92 | for k := 0; k < 10; k++ { 93 | if k == 0 { 94 | rowGuess = 0 95 | predRowHigh = predVal[k] 96 | } else if predVal[k] > predRowHigh { 97 | rowGuess = k 98 | predRowHigh = predVal[k] 99 | } 100 | } 101 | 102 | if rowLabel == rowGuess { 103 | return godl.MatchTypeTruePositive 104 | } 105 | 106 | return godl.MatchTypeFalseNegative 107 | }, 108 | ValidationObserver: func(confMat 
godl.ConfusionMatrix, cost float32) { 109 | fmt.Printf("%v\nCost: %0.4f", confMat, cost) 110 | }, 111 | CostFn: godl.CategoricalCrossEntropyLoss(godl.CrossEntropyLossOpt{}), 112 | }) 113 | handleErr("training", err) 114 | } 115 | -------------------------------------------------------------------------------- /gbn_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestGBN(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | input tensor.Tensor 15 | vbs int 16 | expectedShape tensor.Shape 17 | expectedErr string 18 | expectedOutput []float32 19 | expectedGrad []float32 20 | expectedCost float64 21 | }{ 22 | // { 23 | // desc: "Example 1", 24 | // input: tensor.New( 25 | // tensor.WithShape(10, 1), 26 | // tensor.WithBacking([]float32{0.4, 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4}), 27 | // ), 28 | // vbs: 5, 29 | // expectedShape: tensor.Shape{10, 1}, 30 | // expectedOutput: []float32{-1.4142100268524473, -0.7071050134262237, -1.8394620353687656e-17, 0.7071050134262237, 1.4142100268524476, -1.4142100268524476, -0.7071050134262239, -2.5948842597279634e-16, 0.7071050134262234, 1.4142100268524471}, 31 | // expectedGrad: []float32{0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}, 32 | // expectedCost: -1.3322676295501878e-16, 33 | // }, 34 | { 35 | desc: "Example 2", 36 | input: tensor.New( 37 | tensor.WithShape(5, 2), 38 | tensor.WithBacking([]float32{0.4, -1.4, 2.4, -3.4, 4.4, -5.4, 6.4, -7.4, 8.4, -9.4}), 39 | ), 40 | vbs: 5, 41 | expectedShape: tensor.Shape{5, 2}, 42 | expectedOutput: []float32{-1.4142126784904472, 1.4142126784904474, -0.7071063392452236, 0.7071063392452238, 1.0340285769764954e-16, 6.313059599612395e-17, 0.7071063392452237, -0.7071063392452235, 1.4142126784904472, -1.4142126784904472}, 43 | expectedGrad: []float32{0.1, 0.1, 
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}, 44 | expectedCost: 8.881784197001253e-17, 45 | }, 46 | } 47 | 48 | for _, tcase := range testCases { 49 | t.Run(tcase.desc, func(t *testing.T) { 50 | c := require.New(t) 51 | 52 | tn := NewModel() 53 | g := tn.TrainGraph() 54 | 55 | input := gorgonia.NewTensor(g, tensor.Float32, 2, gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithName("GBNInput"), gorgonia.WithValue(tcase.input)) 56 | 57 | y := GhostBatchNorm(tn, GhostBatchNormOpts{ 58 | VirtualBatchSize: tcase.vbs, 59 | OutputDimension: tcase.input.Shape()[1], 60 | }).Forward(input)[0] 61 | 62 | cost := gorgonia.Must(gorgonia.Mean(y)) 63 | _, err := gorgonia.Grad(cost, append(tn.Learnables(), input)...) 64 | c.NoError(err) 65 | 66 | c.Equal(tcase.expectedShape, y.Shape()) 67 | 68 | vm := gorgonia.NewTapeMachine(tn.trainGraph, 69 | gorgonia.WithLogger(testLogger), 70 | gorgonia.BindDualValues(tn.learnables...), 71 | gorgonia.WithValueFmt("%+v"), 72 | gorgonia.WithWatchlist(), 73 | ) 74 | c.NoError(vm.RunAll()) 75 | 76 | t.Logf("dx: %v", input.Deriv().Value()) 77 | 78 | yGrad, err := y.Grad() 79 | c.NoError(err) 80 | 81 | c.InDeltaSlice(tcase.expectedOutput, y.Value().Data().([]float32), 1e-5, "actual: %#v", y.Value().Data()) 82 | c.InDelta(tcase.expectedCost, cost.Value().Data(), 1e-5, "actual: %#v", cost.Value().Data()) 83 | c.InDeltaSlice(tcase.expectedGrad, yGrad.Data(), 1e-5, "actual: %#v", yGrad.Data()) 84 | }) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /examples/simple-tabnet/simple.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/dcu/godl" 10 | "github.com/dcu/godl/activation" 11 | "github.com/dcu/godl/table" 12 | "github.com/dcu/godl/tabnet" 13 | "gorgonia.org/gorgonia" 14 | ) 15 | 16 | func init() { 17 | rand.Seed(time.Now().UnixNano()) 18 | } 19 | 20 | func 
handleErr(err error) { 21 | if err == nil { 22 | return 23 | } 24 | 25 | panic(err) 26 | } 27 | 28 | func main() { 29 | p, err := table.ReadCSV("dataset.csv") 30 | handleErr(err) 31 | 32 | fmt.Printf(">> Uniq values per column\n") 33 | for col, classes := range p.ClassesByColumn { 34 | if len(classes) > 0 { 35 | fmt.Printf("%s: %d\n", p.Header[col], len(classes)) 36 | } 37 | } 38 | 39 | p.AddTag(table.RandValueIn(map[string]float64{ 40 | "train": 0.8, 41 | "validate": 0.1, 42 | "test": 0.1, 43 | })) 44 | 45 | trainX, trainY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{10}, SelectTags: []string{"train"}}) 46 | validateX, validateY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{10}, SelectTags: []string{"validate"}}) 47 | testX, testY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{10}, SelectTags: []string{"test"}}) 48 | 49 | log.Printf("rows: %v", len(p.Rows)) 50 | 51 | log.Printf("train x: %v train y: %v", trainX, trainY) 52 | log.Printf("validateX: %v validateY: %v", validateX, validateY) 53 | log.Printf("testX: %v testY: %v", testX, testY) 54 | 55 | batchSize := 128 56 | if trainX.Shape()[0] < batchSize { 57 | batchSize = trainX.Shape()[0] 58 | } 59 | 60 | virtualBatchSize := 8 61 | 62 | regressor := tabnet.NewRegressor( 63 | trainX.Shape()[1], []int{}, []int{}, []int{}, tabnet.RegressorOpts{ 64 | BatchSize: batchSize, 65 | VirtualBatchSize: virtualBatchSize, 66 | MaskFunction: activation.Sigmoid, 67 | PredictionLayerDim: 8, 68 | AttentionLayerDim: 8, 69 | Gamma: 1.3, 70 | DecisionSteps: 3, 71 | IndependentBlocks: 2, 72 | SharedBlocks: 2, 73 | Momentum: 0.02, 74 | WithBias: false, 75 | Epsilon: 1e-15, 76 | }, 77 | ) 78 | 79 | err = regressor.Train(trainX, trainY, validateX, validateY, godl.TrainOpts{ 80 | BatchSize: batchSize, 81 | Epochs: 3, 82 | DevMode: true, 83 | Solver: gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02), gorgonia.WithBatchSize(float64(batchSize))), 84 | MatchTypeFor: func(predVal, targetVal []float32) 
godl.MatchType { 85 | if targetVal[0] == 1 { 86 | if predVal[0] >= 0.5 { 87 | return godl.MatchTypeTruePositive 88 | } else { 89 | return godl.MatchTypeFalsePositive 90 | } 91 | } else { // == 0 92 | if predVal[0] < 0.5 { 93 | return godl.MatchTypeTrueNegative 94 | } else { 95 | return godl.MatchTypeFalseNegative 96 | } 97 | } 98 | }, 99 | ValidationObserver: func(confMat godl.ConfusionMatrix, cost float32) { 100 | fmt.Printf("%v\nCost: %0.4f", confMat, cost) 101 | }, 102 | WithLearnablesHeatmap: false, 103 | }) 104 | handleErr(err) 105 | 106 | out, err := regressor.Solve(testX, testY) 107 | handleErr(err) 108 | 109 | log.Printf("out: %v", out) 110 | } 111 | -------------------------------------------------------------------------------- /vggface2/identity_block.go: -------------------------------------------------------------------------------- 1 | package vggface2 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dcu/godl" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | type BlockOpts struct { 12 | Filters [3]int 13 | Stride []int 14 | Stage int 15 | KernelSize tensor.Shape 16 | Block int 17 | } 18 | 19 | type IdentityBlockModule struct { 20 | model *godl.Model 21 | layer godl.LayerType 22 | opts BlockOpts 23 | bns []*godl.BatchNormModule 24 | weights []*godl.Node 25 | } 26 | 27 | func (m *IdentityBlockModule) Name() string { 28 | return "IdentityBlock" 29 | } 30 | 31 | func (m *IdentityBlockModule) Forward(inputs ...*godl.Node) godl.Nodes { 32 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 33 | panic(err) 34 | } 35 | 36 | x := inputs[0] 37 | { 38 | 39 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weights[0], tensor.Shape{1, 1}, []int{0, 0}, []int{1, 1}, []int{1, 1})) 40 | 41 | result := m.bns[0].Forward(x) 42 | x = gorgonia.Must(gorgonia.Rectify(result[0])) 43 | } 44 | 45 | { 46 | 47 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weights[1], m.opts.KernelSize, []int{0, 0}, []int{1, 1}, []int{1, 1})) 48 | result := m.bns[1].Forward(x) 49 
// IdentityBlock creates the batch normalizations and the convolution weights
// of an identity block.
//
// The weight and batch norm names are derived from opts.Stage and opts.Block:
// "conv<stage>_<block>_1x1_reduce", "conv<stage>_<block>_3x3" and
// "conv<stage>_<block>_1x1_increase", each with "/kernel", "/bn/gamma" and
// "/bn/beta" suffixes.
//
// NOTE(review): every kernel is created with a 3x3 spatial size and the first
// one with 3 input channels, while Forward runs the first and third
// convolutions with a 1x1 kernel and the second one with opts.KernelSize.
// Confirm the intended shapes before relying on this block.
func IdentityBlock(m *godl.Model, opts BlockOpts) *IdentityBlockModule {
	lt := godl.AddLayer("vggface2.IdentityBlock")

	conv1ReduceName := fmt.Sprintf("conv%d_%d_1x1_reduce", opts.Stage, opts.Block)
	conv1IncreaseName := fmt.Sprintf("conv%d_%d_1x1_increase", opts.Stage, opts.Block)
	conv3Name := fmt.Sprintf("conv%d_%d_3x3", opts.Stage, opts.Block)

	// First stage: "reduce" convolution with opts.Filters[0] output channels.
	bn1 := godl.BatchNorm2d(m, godl.BatchNormOpts{
		InputSize: opts.Filters[0],
		ScaleName: conv1ReduceName + "/bn/gamma",
		BiasName:  conv1ReduceName + "/bn/beta",
	})
	w1 := m.AddWeights(lt, tensor.Shape{opts.Filters[0], 3, 3, 3}, godl.NewWeightsOpts{
		UniqueName: conv1ReduceName + "/kernel",
	})

	// Second stage: from opts.Filters[0] to opts.Filters[1] channels.
	bn2 := godl.BatchNorm2d(m, godl.BatchNormOpts{
		InputSize: opts.Filters[1],
		ScaleName: conv3Name + "/bn/gamma",
		BiasName:  conv3Name + "/bn/beta",
	})
	w2 := m.AddWeights(lt, tensor.Shape{opts.Filters[1], opts.Filters[0], 3, 3}, godl.NewWeightsOpts{
		UniqueName: conv3Name + "/kernel",
	})

	// Third stage: "increase" convolution from opts.Filters[1] to
	// opts.Filters[2] channels.
	bn3 := godl.BatchNorm2d(m, godl.BatchNormOpts{
		InputSize: opts.Filters[2],
		ScaleName: conv1IncreaseName + "/bn/gamma",
		BiasName:  conv1IncreaseName + "/bn/beta",
	})
	w3 := m.AddWeights(lt, tensor.Shape{opts.Filters[2], opts.Filters[1], 3, 3}, godl.NewWeightsOpts{
		UniqueName: conv1IncreaseName + "/kernel",
	})

	// Forward indexes bns and weights by stage, so the order matters.
	return &IdentityBlockModule{
		model:   m,
		layer:   lt,
		opts:    opts,
		bns:     []*godl.BatchNormModule{bn1, bn2, bn3},
		weights: []*godl.Node{w1, w2, w3},
	}
}
-------------------------------------------------------------------------------- /embedding_generator_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestEmbeddingGenerator(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | classes int 15 | catDims []int 16 | catIdxs []int 17 | catEmbDim []int 18 | 19 | input []float32 20 | inputShape tensor.Shape 21 | expectedOutputShape tensor.Shape 22 | expectedOutput []float32 23 | expectedGrad []float32 24 | expectedCost float32 25 | }{ 26 | { 27 | // 0 1 2 3 4 5 28 | // cat idxs: 1 4 29 | desc: "Example 1", 30 | classes: 5, 31 | catIdxs: []int{1, 4}, 32 | catDims: []int{4, 4}, 33 | catEmbDim: []int{2, 2}, 34 | input: []float32{0, 1, 2, 1, 3}, 35 | inputShape: tensor.Shape{1, 5}, 36 | expectedOutputShape: tensor.Shape{1, 7}, 37 | expectedOutput: []float32{0, 2, 3, 2, 1, 6, 7}, 38 | expectedGrad: []float32{0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285}, 39 | expectedCost: 3, 40 | }, 41 | { 42 | desc: "Example 2", 43 | classes: 5, 44 | catIdxs: []int{1, 4}, 45 | catDims: []int{4, 4}, 46 | catEmbDim: []int{2, 2}, 47 | input: []float32{0, 1, 2, 1, 3, 0, 1, 2, 1, 3}, 48 | inputShape: tensor.Shape{2, 5}, 49 | expectedOutputShape: tensor.Shape{2, 7}, 50 | expectedOutput: []float32{0, 2, 3, 2, 1, 6, 7, 0, 2, 3, 2, 1, 6, 7}, 51 | expectedGrad: []float32{0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142, 0.07142857142857142}, 52 | expectedCost: 3, 53 | }, 54 | } 55 | for _, tcase := range testCases 
{ 56 | t.Run(tcase.desc, func(t *testing.T) { 57 | c := require.New(t) 58 | 59 | tn := NewModel() 60 | embedder := EmbeddingGenerator(tn, tcase.classes, tcase.catDims, tcase.catIdxs, tcase.catEmbDim, EmbeddingOpts{ 61 | WeightsInit: gorgonia.RangedFrom(0), 62 | }) 63 | 64 | ts := tensor.New( 65 | tensor.WithShape(tcase.inputShape...), 66 | tensor.WithBacking(tcase.input), 67 | ) 68 | 69 | input := gorgonia.NewTensor(tn.trainGraph, tensor.Float32, tcase.inputShape.Dims(), gorgonia.WithShape(tcase.inputShape...), gorgonia.WithValue(ts), gorgonia.WithName("input")) 70 | result := embedder.Forward(input)[0] 71 | 72 | cost := gorgonia.Must(gorgonia.Mean(result)) 73 | _, err := gorgonia.Grad(cost, tn.Learnables()...) 74 | c.NoError(err) 75 | 76 | vm := gorgonia.NewTapeMachine(tn.trainGraph) 77 | c.NoError(vm.RunAll()) 78 | 79 | c.Equal(tcase.expectedOutputShape, result.Shape()) 80 | c.Equal(tcase.expectedOutput, result.Value().Data()) 81 | 82 | yGrad, err := result.Grad() 83 | c.NoError(err) 84 | 85 | c.Equal(tcase.expectedGrad, yGrad.Data()) 86 | c.Equal(tcase.expectedCost, cost.Value().Data()) 87 | }) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /vgg/block.go: -------------------------------------------------------------------------------- 1 | package vgg 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/dcu/godl/activation" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | // BlockOpts are the options for a VGG Block 13 | type BlockOpts struct { 14 | InputDimension int 15 | OutputDimension int 16 | 17 | Activation activation.Function 18 | Dropout float64 19 | KernelSize tensor.Shape 20 | Pad []int 21 | Stride []int 22 | Dilation []int 23 | WithBias bool 24 | WithPooling bool 25 | 26 | WeightsInit, BiasInit gorgonia.InitWFn 27 | WeightsName, BiasName string 28 | FixedWeights bool 29 | } 30 | 31 | func (o *BlockOpts) setDefaults() { 32 | if o.Activation == nil { 33 | 
o.Activation = activation.Rectify 34 | } 35 | 36 | if o.KernelSize == nil { 37 | o.KernelSize = tensor.Shape{3, 3} 38 | } 39 | 40 | if o.Pad == nil { 41 | o.Pad = []int{1, 1} 42 | } 43 | 44 | if o.Stride == nil { 45 | o.Stride = []int{1, 1} 46 | } 47 | 48 | if o.Dilation == nil { 49 | o.Dilation = []int{1, 1} 50 | } 51 | 52 | if o.WeightsInit == nil { 53 | k := math.Sqrt(1 / float64(o.OutputDimension*o.KernelSize[0]*o.KernelSize[1])) 54 | o.WeightsInit = gorgonia.Uniform(-k, k) 55 | } 56 | 57 | if o.BiasInit == nil { 58 | k := math.Sqrt(1 / float64(o.OutputDimension*o.KernelSize[0]*o.KernelSize[1])) 59 | o.WeightsInit = gorgonia.Uniform(-k, k) 60 | } 61 | } 62 | 63 | type BlockModule struct { 64 | model *godl.Model 65 | layer godl.LayerType 66 | opts BlockOpts 67 | 68 | weight, bias *godl.Node 69 | } 70 | 71 | func (m *BlockModule) Name() string { 72 | return "VGGBlock" 73 | } 74 | 75 | func (m *BlockModule) Forward(inputs ...*godl.Node) godl.Nodes { 76 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 77 | panic(err) 78 | } 79 | 80 | x := inputs[0] 81 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weight, m.opts.KernelSize, m.opts.Pad, m.opts.Stride, m.opts.Dilation)) 82 | 83 | if m.bias != nil { 84 | x = gorgonia.Must(gorgonia.BroadcastAdd(x, m.bias, nil, []byte{0, 2, 3})) 85 | } 86 | 87 | if m.opts.Activation != nil { 88 | x = gorgonia.Must(m.opts.Activation(x)) 89 | } 90 | 91 | if m.opts.WithPooling { 92 | x = gorgonia.Must(gorgonia.MaxPool2D(x, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2})) 93 | } 94 | 95 | if m.opts.Dropout > 0.0 { 96 | x = gorgonia.Must(gorgonia.Dropout(x, m.opts.Dropout)) 97 | } 98 | 99 | return godl.Nodes{x} 100 | } 101 | 102 | // Block is a VGG block composed of conv2d+maxpool with optional dropout and activation function 103 | func Block(m *godl.Model, opts BlockOpts) *BlockModule { 104 | opts.setDefaults() 105 | 106 | lt := godl.AddLayer("vgg.Block") 107 | 108 | w := m.AddWeights(lt, tensor.Shape{opts.OutputDimension, 
opts.InputDimension, opts.KernelSize[0], opts.KernelSize[0]}, godl.NewWeightsOpts{ 109 | InitFN: opts.WeightsInit, 110 | UniqueName: opts.WeightsName, 111 | Fixed: opts.FixedWeights, 112 | }) 113 | 114 | var bias *gorgonia.Node 115 | if opts.WithBias { 116 | bias = m.AddBias(lt, tensor.Shape{1, opts.OutputDimension, 1, 1}, godl.NewWeightsOpts{ 117 | InitFN: opts.BiasInit, 118 | UniqueName: opts.BiasName, 119 | Fixed: opts.FixedWeights, 120 | }) 121 | } 122 | 123 | return &BlockModule{ 124 | model: m, 125 | layer: lt, 126 | opts: opts, 127 | weight: w, 128 | bias: bias, 129 | } 130 | } 131 | 132 | var ( 133 | _ godl.Module = &BlockModule{} 134 | ) 135 | -------------------------------------------------------------------------------- /table/table_test.go: -------------------------------------------------------------------------------- 1 | package table 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestRandValueIn(t *testing.T) { 12 | testCases := []struct { 13 | valsAndProbs map[string]float64 14 | }{ 15 | { 16 | valsAndProbs: map[string]float64{ 17 | "large": 0.8, 18 | "medium": 0.1, 19 | "small": 0.1, 20 | }, 21 | }, 22 | } 23 | for i, tC := range testCases { 24 | t.Run(fmt.Sprintf("Example #%d", i+1), func(t *testing.T) { 25 | c := require.New(t) 26 | 27 | rv := RandValueIn(tC.valsAndProbs) 28 | counters := map[string]int{} 29 | n := 100 30 | 31 | for i := 0; i < n; i++ { 32 | counters[rv()]++ 33 | } 34 | 35 | t.Logf("counters: %#v", counters) 36 | 37 | for val, prob := range tC.valsAndProbs { 38 | expected := prob * float64(n) 39 | 40 | c.InDelta(expected, counters[val], 5) 41 | } 42 | }) 43 | } 44 | } 45 | 46 | func TestReadCSV(t *testing.T) { 47 | testCases := []struct { 48 | csvPath string 49 | SumColumn int 50 | ExpectedSum float32 51 | SliceAt int 52 | ExpectedSliceValue []float32 53 | 54 | TargetColumns []int 55 | ExpectedCatIndexes []int 56 | ExpectedCatDims []int 57 | 
}{ 58 | { 59 | csvPath: "fixtures/fruits.csv", 60 | SliceAt: 1, 61 | ExpectedSliceValue: []float32{2, 200}, 62 | SumColumn: 1, 63 | ExpectedSum: 450, 64 | TargetColumns: []int{2}, 65 | ExpectedCatIndexes: []int{0}, 66 | ExpectedCatDims: []int{3}, 67 | }, 68 | { 69 | csvPath: "fixtures/census.csv", 70 | SliceAt: 11, 71 | ExpectedSliceValue: []float32{30, 2, 141297, 5, 13, 1, 7, 0, 1, 1, 0, 0, 40, 2}, 72 | SumColumn: 4, 73 | ExpectedSum: 212, 74 | TargetColumns: []int{14}, 75 | ExpectedCatIndexes: []int{1, 3, 5, 6, 7, 8, 9, 13}, 76 | ExpectedCatDims: []int{3, 9, 4, 10, 5, 4, 2, 6}, 77 | }, 78 | } 79 | for i, tC := range testCases { 80 | t.Run(fmt.Sprintf("Example #%d (%s)", i+1, tC.csvPath), func(t *testing.T) { 81 | c := require.New(t) 82 | table, err := ReadCSV(tC.csvPath) 83 | c.NoError(err) 84 | 85 | t.Logf("rows: %#v", table.Rows) 86 | 87 | table.EachColumn(func(columnName string, v Cell) { 88 | if columnName == " 100" && v.Dtype == tensor.Int64 { 89 | c.Equal(100, v.V) 90 | } 91 | }) 92 | 93 | sum := float32(0.0) 94 | table.EachRow(func(row Row) { 95 | v := row.Cells[tC.SumColumn] 96 | if v.Dtype == tensor.Int { 97 | sum += float32(row.Cells[tC.SumColumn].Int()) 98 | } else if v.Dtype == tensor.Float32 { 99 | sum += row.Cells[tC.SumColumn].Float32() 100 | } 101 | }) 102 | 103 | c.Equal(tC.ExpectedSum, sum) 104 | 105 | t.Logf("classes: %v", table.ClassesByColumn) 106 | 107 | x, y := table.ToTensors(ToTensorOpts{ 108 | TargetColumns: tC.TargetColumns, 109 | }) 110 | s, err := x.Slice(tensor.S(tC.SliceAt)) 111 | c.NoError(err) 112 | 113 | t.Logf("x:\n%#v %v", x, x.Shape()) 114 | 115 | if y != nil { 116 | t.Logf("y:\n%#v %v", y, y.Shape()) 117 | } 118 | 119 | c.Equal(tC.ExpectedSliceValue, s.Data()) 120 | 121 | idx, dims := table.CategoricalColumns(tC.TargetColumns...) 
122 | t.Logf("%v %v", idx, dims) 123 | 124 | c.Equal(tC.ExpectedCatIndexes, idx) 125 | c.Equal(tC.ExpectedCatDims, dims) 126 | }) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /examples/tabnet/census.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/dcu/godl" 10 | "github.com/dcu/godl/activation" 11 | "github.com/dcu/godl/table" 12 | "github.com/dcu/godl/tabnet" 13 | "gorgonia.org/gorgonia" 14 | ) 15 | 16 | func init() { 17 | rand.Seed(time.Now().UnixNano()) 18 | } 19 | 20 | func handleErr(err error) { 21 | if err == nil { 22 | return 23 | } 24 | 25 | panic(err) 26 | } 27 | 28 | func main() { 29 | p, err := table.ReadCSV("adult.data") 30 | handleErr(err) 31 | 32 | fmt.Printf(">> Uniq values per column\n") 33 | for col, classes := range p.ClassesByColumn { 34 | if len(classes) > 0 { 35 | fmt.Printf("%s: %d\n", p.Header[col], len(classes)) 36 | } 37 | } 38 | 39 | p.AddTag(table.RandValueIn(map[string]float64{ 40 | "train": 0.8, 41 | "validate": 0.1, 42 | "test": 0.1, 43 | })) 44 | 45 | trainX, trainY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{14}, SelectTags: []string{"train"}}) 46 | validateX, validateY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{14}, SelectTags: []string{"validate"}}) 47 | testX, testY := p.ToTensors(table.ToTensorOpts{TargetColumns: []int{14}, SelectTags: []string{"test"}}) 48 | 49 | log.Printf("rows: %v", len(p.Rows)) 50 | 51 | log.Printf("train x: %v train y: %v", trainX, trainY) 52 | log.Printf("validateX: %v validateY: %v", validateX, validateY) 53 | log.Printf("testX: %v testY: %v", testX, testY) 54 | 55 | catIdxs, catDims := p.CategoricalColumns(14) 56 | 57 | batchSize := 128 58 | if trainX.Shape()[0] < batchSize { 59 | batchSize = trainX.Shape()[0] 60 | } 61 | 62 | virtualBatchSize := 16 63 | catEmbDim := []int{5, 4, 3, 6, 2, 2, 
1, 10} 64 | 65 | log.Printf("cat dims: %v", catDims) 66 | log.Printf("cat emb dims: %v", catEmbDim) 67 | log.Printf("cat idxs: %v", catIdxs) 68 | 69 | regressor := tabnet.NewRegressor( 70 | trainX.Shape()[1], catDims, catIdxs, catEmbDim, tabnet.RegressorOpts{ 71 | BatchSize: batchSize, 72 | VirtualBatchSize: virtualBatchSize, 73 | MaskFunction: activation.SparseMax, 74 | PredictionLayerDim: 8, 75 | AttentionLayerDim: 8, 76 | Gamma: 1.3, 77 | DecisionSteps: 3, 78 | IndependentBlocks: 2, 79 | SharedBlocks: 2, 80 | Momentum: 0.02, 81 | WithBias: false, 82 | Epsilon: 1e-15, 83 | }, 84 | ) 85 | 86 | err = regressor.Train(trainX, trainY, validateX, validateY, godl.TrainOpts{ 87 | BatchSize: batchSize, 88 | Epochs: 10, 89 | DevMode: false, 90 | Solver: gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02), gorgonia.WithBatchSize(float64(batchSize))), 91 | MatchTypeFor: func(predVal, targetVal []float32) godl.MatchType { 92 | // log.Printf("%v vs %v", predVal, targetVal) 93 | 94 | if targetVal[0] == 1 { 95 | if predVal[0] >= 0.5 { 96 | return godl.MatchTypeTruePositive 97 | } else { 98 | return godl.MatchTypeFalsePositive 99 | } 100 | } else { // == 0 101 | if predVal[0] < 0.5 { 102 | return godl.MatchTypeTrueNegative 103 | } else { 104 | return godl.MatchTypeFalseNegative 105 | } 106 | } 107 | }, 108 | ValidationObserver: func(confMat godl.ConfusionMatrix, cost float32) { 109 | fmt.Printf("%v\nCost: %0.4f", confMat, cost) 110 | }, 111 | WithLearnablesHeatmap: false, 112 | }) 113 | handleErr(err) 114 | 115 | out, err := regressor.Solve(testX, testY) 116 | handleErr(err) 117 | 118 | log.Printf("out: %v", out) 119 | } 120 | -------------------------------------------------------------------------------- /losses_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func 
TestMSELoss(t *testing.T) { 12 | testCases := []struct { 13 | desc string 14 | output tensor.Tensor 15 | target tensor.Tensor 16 | expectedLoss float32 17 | }{ 18 | { 19 | desc: "Example 1", 20 | output: tensor.New( 21 | tensor.WithShape(1), 22 | tensor.WithBacking([]float32{0.5}), 23 | ), 24 | target: tensor.New( 25 | tensor.WithShape(1), 26 | tensor.WithBacking([]float32{0.1}), 27 | ), 28 | expectedLoss: 0.16000001, 29 | }, 30 | { 31 | desc: "Example 2", 32 | output: tensor.New( 33 | tensor.WithShape(2, 2), 34 | tensor.WithBacking([]float32{0.5, 0.2, 0.5, 0.7}), 35 | ), 36 | target: tensor.New( 37 | tensor.WithShape(2, 2), 38 | tensor.WithBacking([]float32{0.1, 0.2, 0.3, 0.9}), 39 | ), 40 | expectedLoss: 0.06000000000000002, 41 | }, 42 | } 43 | for _, tC := range testCases { 44 | t.Run(tC.desc, func(t *testing.T) { 45 | c := require.New(t) 46 | 47 | g := gorgonia.NewGraph() 48 | 49 | outputNode := gorgonia.NewTensor(g, tensor.Float32, tC.output.Shape().Dims(), gorgonia.WithShape(tC.output.Shape()...), gorgonia.WithValue(tC.output), gorgonia.WithName("output")) 50 | targetNode := gorgonia.NewTensor(g, tensor.Float32, tC.target.Shape().Dims(), gorgonia.WithShape(tC.target.Shape()...), gorgonia.WithValue(tC.target), gorgonia.WithName("target")) 51 | 52 | loss := MSELoss(MSELossOpts{})(Nodes{outputNode}, targetNode) 53 | 54 | var lossV gorgonia.Value 55 | gorgonia.Read(loss, &lossV) 56 | 57 | vm := gorgonia.NewTapeMachine(g) 58 | c.NoError(vm.RunAll()) 59 | 60 | c.Equal(tC.expectedLoss, lossV.Data()) 61 | }) 62 | } 63 | } 64 | 65 | func TestCrossEntropyLoss(t *testing.T) { 66 | testCases := []struct { 67 | desc string 68 | reduction Reduction 69 | output tensor.Tensor 70 | target tensor.Tensor 71 | expectedLoss float32 72 | }{ 73 | { 74 | desc: "Example 1", 75 | reduction: ReductionSum, 76 | output: tensor.New( 77 | tensor.WithShape(2), 78 | tensor.WithBacking([]float32{0.5, 0.1}), 79 | ), 80 | target: tensor.New( 81 | tensor.WithShape(2), 82 | 
tensor.WithBacking([]float32{1, 0}), 83 | ), 84 | expectedLoss: 0.6931471805599453, 85 | }, 86 | { 87 | desc: "Example 2", 88 | reduction: ReductionMean, 89 | output: tensor.New( 90 | tensor.WithShape(2, 2), 91 | tensor.WithBacking([]float32{0.5, 0.2, 0.5, 0.7}), 92 | ), 93 | target: tensor.New( 94 | tensor.WithShape(2, 2), 95 | tensor.WithBacking([]float32{0.1, 0.2, 0.3, 0.9}), 96 | ), 97 | expectedLoss: 0.2300385, 98 | }, 99 | } 100 | for _, tC := range testCases { 101 | t.Run(tC.desc, func(t *testing.T) { 102 | c := require.New(t) 103 | 104 | g := gorgonia.NewGraph() 105 | 106 | outputNode := gorgonia.NewTensor(g, tensor.Float32, tC.output.Shape().Dims(), gorgonia.WithShape(tC.output.Shape()...), gorgonia.WithValue(tC.output), gorgonia.WithName("output")) 107 | targetNode := gorgonia.NewTensor(g, tensor.Float32, tC.target.Shape().Dims(), gorgonia.WithShape(tC.target.Shape()...), gorgonia.WithValue(tC.target), gorgonia.WithName("target")) 108 | 109 | loss := CrossEntropyLoss(CrossEntropyLossOpt{ 110 | Reduction: tC.reduction, 111 | })(Nodes{outputNode}, targetNode) 112 | 113 | var lossV gorgonia.Value 114 | gorgonia.Read(loss, &lossV) 115 | 116 | vm := gorgonia.NewTapeMachine(g) 117 | c.NoError(vm.RunAll()) 118 | 119 | c.Equal(tC.expectedLoss, lossV.Data()) 120 | }) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /vggface2/conv_block.go: -------------------------------------------------------------------------------- 1 | package vggface2 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dcu/godl" 7 | "gorgonia.org/gorgonia" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | type ConvBlockModule struct { 12 | model *godl.Model 13 | layer godl.LayerType 14 | opts BlockOpts 15 | bns []*godl.BatchNormModule 16 | weights godl.Nodes 17 | } 18 | 19 | func (m *ConvBlockModule) Name() string { 20 | return "ConvBlock" 21 | } 22 | 23 | func (m *ConvBlockModule) Forward(inputs ...*godl.Node) godl.Nodes { 24 | if err := 
m.model.CheckArity(m.layer, inputs, 1); err != nil { 25 | panic(err) 26 | } 27 | 28 | x := inputs[0] 29 | { 30 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weights[0], tensor.Shape{1, 1}, []int{0, 0}, m.opts.Stride, []int{1, 1})) 31 | 32 | result := m.bns[0].Forward(x) 33 | 34 | x = gorgonia.Must(gorgonia.Rectify(result[0])) 35 | } 36 | 37 | { 38 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weights[1], m.opts.KernelSize, []int{1, 1}, []int{1, 1}, []int{1, 1})) 39 | 40 | result := m.bns[1].Forward(x) 41 | 42 | x = gorgonia.Must(gorgonia.Rectify(result[0])) 43 | } 44 | 45 | { 46 | x = gorgonia.Must(gorgonia.Conv2d(x, m.weights[2], tensor.Shape{1, 1}, []int{0, 0}, []int{1, 1}, []int{1, 1})) 47 | 48 | result := m.bns[2].Forward(x) 49 | 50 | x = result[0] 51 | } 52 | 53 | { 54 | shortCut := gorgonia.Must(gorgonia.Conv2d(x, m.weights[3], tensor.Shape{1, 1}, []int{0, 0}, m.opts.Stride, []int{1, 1})) 55 | 56 | result := m.bns[3].Forward(shortCut) 57 | 58 | x = gorgonia.Must(gorgonia.Add(x, result[0])) 59 | x = gorgonia.Must(gorgonia.Rectify(x)) 60 | } 61 | 62 | return godl.Nodes{x} 63 | } 64 | 65 | func ConvBlock(m *godl.Model, opts BlockOpts) *ConvBlockModule { 66 | lt := godl.AddLayer("vggface2.ConvBlock") 67 | 68 | conv1ReduceName := fmt.Sprintf("conv%d_%d_1x1_reduce", opts.Stage, opts.Block) 69 | conv1IncreaseName := fmt.Sprintf("conv%d_%d_1x1_increase", opts.Stage, opts.Block) 70 | conv1ProjName := fmt.Sprintf("conv%d_%d_1x1_proj", opts.Stage, opts.Block) 71 | conv3Name := fmt.Sprintf("conv%d_%d_3x3", opts.Stage, opts.Block) 72 | 73 | bn1 := godl.BatchNorm2d(m, godl.BatchNormOpts{ 74 | InputSize: opts.Filters[0], 75 | ScaleName: conv1ReduceName + "/bn/gamma", 76 | BiasName: conv1ReduceName + "/bn/beta", 77 | }) 78 | w1 := m.AddWeights(lt, tensor.Shape{opts.Filters[0], 3, 3, 3}, godl.NewWeightsOpts{ 79 | UniqueName: conv1ReduceName + "/kernel", 80 | }) 81 | 82 | bn2 := godl.BatchNorm2d(m, godl.BatchNormOpts{ 83 | InputSize: opts.Filters[1], 84 | ScaleName: conv3Name + 
"/bn/gamma", 85 | BiasName: conv3Name + "/bn/beta", 86 | }) 87 | w2 := m.AddWeights(lt, tensor.Shape{opts.Filters[1], opts.Filters[0], 3, 3}, godl.NewWeightsOpts{ 88 | UniqueName: conv3Name + "/kernel", 89 | }) 90 | 91 | bn3 := godl.BatchNorm2d(m, godl.BatchNormOpts{ 92 | InputSize: opts.Filters[2], 93 | ScaleName: conv1IncreaseName + "/bn/gamma", 94 | BiasName: conv1IncreaseName + "/bn/beta", 95 | }) 96 | w3 := m.AddWeights(lt, tensor.Shape{opts.Filters[2], opts.Filters[1], 3, 3}, godl.NewWeightsOpts{ 97 | UniqueName: conv1IncreaseName + "/kernel", 98 | }) 99 | 100 | bn4 := godl.BatchNorm2d(m, godl.BatchNormOpts{ 101 | InputSize: opts.Filters[2], 102 | ScaleName: conv1ProjName + "/bn/gamma", 103 | BiasName: conv1ProjName + "/bn/beta", 104 | }) 105 | w4 := m.AddWeights(lt, tensor.Shape{opts.Filters[2], opts.Filters[1], 3, 3}, godl.NewWeightsOpts{ 106 | UniqueName: conv1ProjName + "/kernel", 107 | }) 108 | 109 | return &ConvBlockModule{ 110 | model: m, 111 | layer: lt, 112 | opts: opts, 113 | bns: []*godl.BatchNormModule{ 114 | bn1, bn2, bn3, bn4, 115 | }, 116 | weights: godl.Nodes{ 117 | w1, w2, w3, w4, 118 | }, 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /tabnet/attentive_transformer_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | func TestAttentiveTransformer(t *testing.T) { 13 | testCases := []struct { 14 | desc string 15 | input tensor.Tensor 16 | priors []float32 17 | vbs int 18 | output int 19 | expectedShape tensor.Shape 20 | expectedErr string 21 | expectedOutput []float32 22 | expectedGrad []float32 23 | expectedCost float32 24 | expectedInputGrad []float32 25 | }{ 26 | { 27 | desc: "Example 1", 28 | input: tensor.New( 29 | tensor.WithShape(6, 2), 30 | 
tensor.WithBacking([]float32{0.1, -0.5, 0.3, 0.9, 0.04, -0.3, 0.01, 0.09, -0.1, 0.9, 0.7, 0.04}), 31 | ), 32 | priors: []float32{-1.0143, 0.9077, 0.8760, -2.8345, 0.9163, -1.5155, -0.8302, 0.5957, -0.9591, 0.4161, -0.2541, 0.6725}, 33 | vbs: 2, 34 | output: 2, 35 | expectedShape: tensor.Shape{6, 2}, 36 | expectedOutput: []float32{-0.05, 0.009999998, -0.05, 0.009999998, -0.020000001, 0.04, -0.020000001, 0.04, -0.020000001, 0.04, -0.04882242, 0.01117758}, 37 | expectedGrad: []float32{0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333}, 38 | expectedCost: -0.0048037386, 39 | expectedInputGrad: []float32{0, 0, 0, 0, 0, 0, 0, 0, -0.00042192982004980896, -0.00042192982004980896, 0.00042192982004980896, 0.00042192982004980896}, 40 | }, 41 | } 42 | 43 | for _, tcase := range testCases { 44 | t.Run(tcase.desc, func(t *testing.T) { 45 | c := require.New(t) 46 | 47 | tn := godl.NewModel() 48 | g := tn.TrainGraph() 49 | 50 | input := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithName("input"), gorgonia.WithValue(tcase.input)) 51 | priors := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithValue( 52 | tensor.New( 53 | tensor.WithShape(input.Shape()...), 54 | tensor.WithBacking(tcase.priors), 55 | ), 56 | ), 57 | gorgonia.WithName("priors"), 58 | ) 59 | result := AttentiveTransformer(tn, AttentiveTransformerOpts{ 60 | VirtualBatchSize: tcase.vbs, 61 | InputDimension: input.Shape()[1], 62 | OutputDimension: tcase.output, 63 | WeightsInit: initDummyWeights, 64 | }).Forward(input, priors) 65 | 66 | fcWeight := gorgonia.NewTensor(g, tensor.Float32, 2, gorgonia.WithShape(input.Shape()[1], tcase.output), gorgonia.WithInit(gorgonia.RangedFromWithStep(-0.05, 
0.03)), gorgonia.WithName("fcWeight")) 67 | 68 | y := result[0] 69 | wT := gorgonia.Must(gorgonia.Transpose(fcWeight, 1, 0)) 70 | y = gorgonia.Must(gorgonia.Mul(y, wT)) 71 | 72 | cost := gorgonia.Must(gorgonia.Mean(y)) 73 | _, err := gorgonia.Grad(cost, input) 74 | c.NoError(err) 75 | 76 | vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(append(tn.Learnables(), fcWeight)...)) 77 | c.NoError(vm.RunAll()) 78 | 79 | tn.PrintWatchables() 80 | 81 | t.Logf("input: %v", input.Value()) 82 | t.Logf("priors: %v", priors.Value()) 83 | t.Logf("dx: %v", input.Deriv().Value()) 84 | t.Logf("att output: %v", y.Value()) 85 | 86 | c.Equal(tcase.expectedShape, y.Shape()) 87 | c.Equal(tcase.expectedOutput, y.Value().Data().([]float32)) 88 | 89 | yGrad, err := y.Grad() 90 | c.NoError(err) 91 | 92 | c.Equal(tcase.expectedGrad, yGrad.Data()) 93 | c.Equal(tcase.expectedCost, cost.Value().Data()) 94 | 95 | c.InDeltaSlice(tcase.expectedInputGrad, input.Deriv().Value().Data(), 1e-5, "actual: %#v", input.Deriv().Value().Data()) 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /imageutils/tensor.go: -------------------------------------------------------------------------------- 1 | package imageutils 2 | 3 | import ( 4 | "fmt" 5 | "image" 6 | "image/color" 7 | "io/fs" 8 | "path/filepath" 9 | "regexp" 10 | 11 | "gorgonia.org/tensor" 12 | ) 13 | 14 | var ( 15 | imgRegexp = regexp.MustCompile(`\.(png|jpg|jpeg)$`) 16 | ) 17 | 18 | type TensorMode string 19 | 20 | const ( 21 | TensorModeCaffe TensorMode = "caffe" 22 | TensorModeTensorFlow TensorMode = "tensorflow" 23 | TensorModeTorch TensorMode = "torch" 24 | ) 25 | 26 | // ToTensorOpts are the options to convert the image to a tensor 27 | type ToTensorOpts struct { 28 | // TensorMode is the mode to weight the pixels. 
The default one is Caffe 29 | TensorMode TensorMode 30 | } 31 | 32 | // ToTensorFromDirectory loads all images in a directory 33 | func ToTensorFromDirectory(dirPath string, loadOpts LoadOpts, tensorOpts ToTensorOpts) (tensor.Tensor, error) { 34 | if len(loadOpts.TargetSize) != 2 { 35 | return nil, fmt.Errorf("TargetSize must be defined") 36 | } 37 | 38 | backing := []float32{} 39 | imagesCount := 0 40 | 41 | err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { 42 | if err != nil { 43 | return err 44 | } 45 | 46 | if !imgRegexp.MatchString(path) { 47 | return nil 48 | } 49 | 50 | img, err := Load(path, loadOpts) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | weights := ToArray(img, tensorOpts) 56 | backing = append(backing, weights...) 57 | imagesCount++ 58 | 59 | return nil 60 | }) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return tensor.New( 66 | tensor.Of(tensor.Float32), 67 | tensor.WithShape(imagesCount, 3, int(loadOpts.TargetSize[0]), int(loadOpts.TargetSize[1])), // count, channels, width, height 68 | tensor.WithBacking(backing), 69 | ), nil 70 | } 71 | 72 | // ToTensor converts the given image to a tensor 73 | func ToTensor(img image.Image, opts ToTensorOpts) tensor.Tensor { 74 | bounds := img.Bounds() 75 | 76 | return tensor.New( 77 | tensor.Of(tensor.Float32), 78 | tensor.WithShape(1, 3, bounds.Max.X, bounds.Max.Y), // batchSize, channels, width, height 79 | tensor.WithBacking(ToArray(img, opts)), 80 | ) 81 | } 82 | 83 | // ToArray converts the image in a []float32 84 | func ToArray(img image.Image, opts ToTensorOpts) []float32 { 85 | bounds := img.Bounds() 86 | width, height := bounds.Max.X, bounds.Max.Y 87 | 88 | pixels := make([]float32, 3*width*height) 89 | 90 | for x := 0; x < width; x++ { 91 | for y := 0; y < height; y++ { 92 | w1, w2, w3 := pixelWeight(img.At(x, y), opts) 93 | 94 | pixels[width*y+x] = w1 95 | pixels[(width*y+x)+1*width*height] = w2 96 | pixels[(width*y+x)+2*width*height] = w3 
97 | } 98 | } 99 | 100 | return pixels 101 | } 102 | 103 | func pixelWeight(pixel color.Color, opts ToTensorOpts) (float32, float32, float32) { 104 | r, g, b, _ := pixel.RGBA() 105 | 106 | switch opts.TensorMode { 107 | case TensorModeTensorFlow: 108 | // https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/applications/imagenet_utils.py#L192 109 | 110 | return float32(r/256)/127.5 - 1.0, 111 | float32(g/256)/127.5 - 1.0, 112 | float32(b/256)/127.5 - 1.0 113 | case TensorModeTorch: 114 | // https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/applications/imagenet_utils.py#L197 115 | mean := []float32{0.485, 0.456, 0.406} 116 | std := []float32{0.229, 0.224, 0.225} 117 | 118 | return (float32(r)/65536 - mean[0]) / std[0], 119 | (float32(g)/65536 - mean[1]) / std[1], 120 | (float32(b)/65536 - mean[2]) / std[2] 121 | default: 122 | // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/imagenet_utils.py#L202 123 | mean := []float32{103.939, 116.779, 123.68} 124 | 125 | // RGB -> BGR 126 | return float32(b/256) - mean[0], 127 | float32(g/256) - mean[1], 128 | float32(r/256) - mean[2] 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /ui/ui.go: -------------------------------------------------------------------------------- 1 | package ui 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/mum4k/termdash" 9 | "github.com/mum4k/termdash/cell" 10 | "github.com/mum4k/termdash/container" 11 | "github.com/mum4k/termdash/keyboard" 12 | "github.com/mum4k/termdash/linestyle" 13 | "github.com/mum4k/termdash/terminal/termbox" 14 | "github.com/mum4k/termdash/terminal/terminalapi" 15 | "github.com/mum4k/termdash/widgets/gauge" 16 | "github.com/mum4k/termdash/widgets/linechart" 17 | "github.com/mum4k/termdash/widgets/text" 18 | ) 19 | 20 | const ( 21 | rootID = "root" 22 | ) 23 | 24 | type UI struct { 25 | term 
*termbox.Terminal 26 | container *container.Container 27 | epochsBar, batchesBar *gauge.Gauge 28 | costLine *linechart.LineChart 29 | costText *text.Text 30 | costs []float32 31 | } 32 | 33 | func New() *UI { 34 | t, err := termbox.New(termbox.ColorMode(terminalapi.ColorMode256)) 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | epochsBar, err := gauge.New( 40 | gauge.TextLabel("Epoch"), 41 | gauge.Color(cell.ColorCyan), 42 | gauge.FilledTextColor(cell.ColorBlack), 43 | gauge.EmptyTextColor(cell.ColorYellow), 44 | ) 45 | if err != nil { 46 | panic(err) 47 | } 48 | 49 | batchesBar, err := gauge.New( 50 | gauge.TextLabel("Batch"), 51 | gauge.Color(cell.ColorCyan), 52 | gauge.FilledTextColor(cell.ColorBlack), 53 | gauge.EmptyTextColor(cell.ColorYellow), 54 | ) 55 | if err != nil { 56 | panic(err) 57 | } 58 | 59 | costLine, err := linechart.New( 60 | linechart.AxesCellOpts(cell.FgColor(cell.ColorCyan)), 61 | linechart.YLabelCellOpts(cell.FgColor(cell.ColorGreen)), 62 | linechart.XLabelCellOpts(cell.FgColor(cell.ColorGreen)), 63 | ) 64 | if err != nil { 65 | panic(err) 66 | } 67 | 68 | costText, err := text.New(text.RollContent()) 69 | if err != nil { 70 | panic(err) 71 | } 72 | 73 | c, err := container.New( 74 | t, container.ID(rootID), 75 | container.SplitVertical( 76 | container.Left( 77 | container.SplitHorizontal( 78 | container.Top( 79 | container.Border(linestyle.Light), 80 | container.PlaceWidget(costLine), 81 | ), 82 | container.Bottom( 83 | container.PlaceWidget(costText), 84 | container.Border(linestyle.Light), 85 | ), 86 | ), 87 | ), 88 | container.Right( 89 | container.SplitHorizontal( 90 | container.Top( 91 | container.PlaceWidget(epochsBar), 92 | container.Border(linestyle.Light), 93 | ), 94 | container.Bottom( 95 | container.PlaceWidget(batchesBar), 96 | container.Border(linestyle.Light), 97 | ), 98 | ), 99 | ), 100 | ), 101 | ) 102 | if err != nil { 103 | panic(err) 104 | } 105 | 106 | return &UI{ 107 | term: t, 108 | container: c, 109 | epochsBar: 
epochsBar, 110 | batchesBar: batchesBar, 111 | costLine: costLine, 112 | costText: costText, 113 | } 114 | } 115 | 116 | func (ui *UI) Start() { 117 | ctx, cancel := context.WithCancel(context.Background()) 118 | 119 | quitter := func(k *terminalapi.Keyboard) { 120 | if k.Key == keyboard.KeyEsc || k.Key == keyboard.KeyCtrlC { 121 | cancel() 122 | } 123 | } 124 | if err := termdash.Run(ctx, ui.term, ui.container, termdash.KeyboardSubscriber(quitter), termdash.RedrawInterval(250*time.Millisecond)); err != nil { 125 | panic(err) 126 | } 127 | } 128 | 129 | func (ui *UI) UpdateCost(epoch, epochs int, batch, batches int, cost float64) { 130 | _ = ui.epochsBar.Percent(int(100 * (float64(epoch) / float64(epochs)))) 131 | _ = ui.batchesBar.Percent(int(100 * (float64(batch) / float64(batches)))) 132 | 133 | _ = ui.costText.Write(fmt.Sprintf("%d/%d cost: %v\n", epoch, epochs, cost)) 134 | 135 | ui.costs = append(ui.costs, float64(cost)) 136 | if len(ui.costs) > 100 { 137 | ui.costs = ui.costs[len(ui.costs)-100:] 138 | } 139 | 140 | _ = ui.costLine.Series("cost", ui.costs) 141 | } 142 | -------------------------------------------------------------------------------- /batch_norm_test.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | func TestBatchNorm(t *testing.T) { 13 | testCases := []struct { 14 | desc string 15 | input tensor.Tensor 16 | expectedOutput tensor.Tensor 17 | expectedOutputGrad tensor.Tensor 18 | expectedScaleGrad tensor.Tensor 19 | expectedBiasGrad tensor.Tensor 20 | expectedCost float32 21 | }{ 22 | { 23 | desc: "Example 1", 24 | input: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.3, 0.03, 0.07, 0.7})), 25 | expectedOutput: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.9996221424388056, -0.9999554496246411, 
-0.9996221424388058, 0.999955449624641})), 26 | expectedOutputGrad: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.2500, 0.2500, 0.2500, 0.2500})), 27 | expectedScaleGrad: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{2.9802322e-08, 2.9802322e-08})), 28 | expectedBiasGrad: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{0.5, 0.5})), 29 | expectedCost: 0, 30 | }, 31 | { 32 | desc: "Example 2", 33 | input: tensor.New(tensor.WithShape(2, 2, 1, 1), tensor.WithBacking([]float32{0.3, 0.03, 0.07, 0.7})), 34 | expectedOutput: tensor.New(tensor.WithShape(2, 2, 1, 1), tensor.WithBacking([]float32{0.9996221424388056, -0.9999554496246411, -0.9996221424388058, 0.999955449624641})), 35 | expectedOutputGrad: tensor.New(tensor.WithShape(2, 2, 1, 1), tensor.WithBacking([]float32{0.2500, 0.2500, 0.2500, 0.2500})), 36 | expectedScaleGrad: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{2.9802322e-08, 2.9802322e-08})), 37 | expectedBiasGrad: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{0.5, 0.5})), 38 | expectedCost: 0, 39 | }, 40 | } 41 | for _, tC := range testCases { 42 | t.Run(tC.desc, func(t *testing.T) { 43 | c := require.New(t) 44 | 45 | solver := gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.1)) 46 | 47 | m := NewModel() 48 | opts := BatchNormOpts{ 49 | InputSize: tC.input.Shape()[1], 50 | } 51 | 52 | var bnFunc func(*Model, BatchNormOpts) *BatchNormModule 53 | bnFunc = BatchNorm1d 54 | if tC.input.Dims() == 4 { 55 | bnFunc = BatchNorm2d 56 | } 57 | 58 | x := gorgonia.NewTensor(m.trainGraph, tensor.Float32, tC.input.Shape().Dims(), gorgonia.WithShape(tC.input.Shape()...), gorgonia.WithValue(tC.input), gorgonia.WithName("x")) 59 | 60 | bn := bnFunc(m, opts) 61 | result := bn.Forward(x) 62 | 63 | cost := gorgonia.Must(gorgonia.Mean(result[0])) 64 | 65 | l := m.learnables 66 | 67 | _, err := gorgonia.Grad(cost, l...) 
68 | c.NoError(err) 69 | 70 | vm := gorgonia.NewTapeMachine(m.trainGraph, 71 | gorgonia.BindDualValues(l...), 72 | gorgonia.WithWatchlist(), 73 | gorgonia.TraceExec(), 74 | ) 75 | c.NoError(vm.RunAll()) 76 | c.NoError(vm.Close()) 77 | 78 | outputGrad, err := result[0].Grad() 79 | c.NoError(err) 80 | 81 | scaleGrad, err := l[0].Grad() 82 | c.NoError(err) 83 | 84 | biasGrad, err := l[1].Grad() 85 | c.NoError(err) 86 | 87 | log.Printf("input: %v", tC.input) 88 | log.Printf("output: %v", result[0].Value()) 89 | log.Printf("output grad: %v", outputGrad) 90 | log.Printf("scale grad: %v", scaleGrad) 91 | log.Printf("bias grad: %v", biasGrad) 92 | log.Printf("cost: %v", cost.Value()) 93 | 94 | c.InDeltaSlice(tC.expectedOutput.Data(), result[0].Value().Data(), 1e-5, "actual: %#v", result[0].Value().Data()) 95 | c.Equal(tC.expectedOutputGrad.Data(), outputGrad.Data()) 96 | c.Equal(tC.expectedScaleGrad.Data(), scaleGrad.Data()) 97 | c.Equal(tC.expectedBiasGrad.Data(), biasGrad.Data()) 98 | c.InDelta(tC.expectedCost, cost.Value().Data(), 1e-5) 99 | 100 | c.NoError(solver.Step(gorgonia.NodesToValueGrads(m.Learnables()))) 101 | 102 | log.Printf("scale: %v", l[0].Value()) 103 | log.Printf("bias: %v", l[1].Value()) 104 | }) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /tabnet/tab_net_regressor.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/dcu/godl/activation" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | type Regressor struct { 13 | opts RegressorOpts 14 | model *godl.Model 15 | tabnet *TabNetModule 16 | } 17 | 18 | type RegressorOpts struct { 19 | BatchSize int 20 | VirtualBatchSize int 21 | MaskFunction activation.Function 22 | WithBias bool 23 | 24 | SharedBlocks int 25 | IndependentBlocks int 26 | DecisionSteps int 27 | PredictionLayerDim int 28 | AttentionLayerDim 
int 29 | 30 | Gamma float64 31 | Momentum float64 32 | Epsilon float64 33 | 34 | WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn 35 | } 36 | 37 | func newModel(training bool, batchSize int, inputDim int, catDims []int, catIdxs []int, catEmbDim []int, opts RegressorOpts) (*godl.Model, *TabNetModule) { 38 | nn := godl.NewModel() 39 | 40 | layer := TabNet(nn, TabNetOpts{ 41 | OutputSize: 1, 42 | BatchSize: batchSize, 43 | VirtualBatchSize: opts.VirtualBatchSize, 44 | InputSize: inputDim, 45 | MaskFunction: opts.MaskFunction, 46 | WithBias: opts.WithBias, 47 | WeightsInit: opts.WeightsInit, 48 | ScaleInit: opts.ScaleInit, 49 | BiasInit: opts.BiasInit, 50 | SharedBlocks: opts.SharedBlocks, 51 | IndependentBlocks: opts.IndependentBlocks, 52 | DecisionSteps: opts.DecisionSteps, 53 | PredictionLayerDim: opts.PredictionLayerDim, 54 | AttentionLayerDim: opts.AttentionLayerDim, 55 | Gamma: opts.Gamma, 56 | Momentum: opts.Momentum, 57 | Epsilon: opts.Epsilon, 58 | CatDims: catDims, 59 | CatIdxs: catIdxs, 60 | CatEmbDim: catEmbDim, 61 | }) 62 | 63 | return nn, layer 64 | } 65 | 66 | func NewRegressor(inputDim int, catDims []int, catIdxs []int, catEmbDim []int, opts RegressorOpts) *Regressor { 67 | model, layer := newModel(true, opts.BatchSize, inputDim, catDims, catIdxs, catEmbDim, opts) 68 | 69 | return &Regressor{ 70 | opts: opts, 71 | model: model, 72 | tabnet: layer, 73 | } 74 | } 75 | 76 | func (r *Regressor) Train(trainX, trainY, validateX, validateY tensor.Tensor, opts godl.TrainOpts) error { 77 | log.Printf("input: %v", trainX) 78 | 79 | if opts.CostFn == nil { 80 | lambdaSparse := gorgonia.NewConstant(float32(1e-3), gorgonia.WithName("LambdaSparse")) 81 | mseLoss := godl.MSELoss(godl.MSELossOpts{}) 82 | 83 | opts.CostFn = func(output godl.Nodes, target *godl.Node) *gorgonia.Node { 84 | cost := mseLoss(output, target) 85 | tmpLoss := gorgonia.Must(gorgonia.Mul(output[1], lambdaSparse)) 86 | cost = gorgonia.Must(gorgonia.Sub(cost, tmpLoss)) 87 | 88 | return cost 89 | } 
90 | } 91 | 92 | if opts.Solver == nil { 93 | opts.Solver = gorgonia.NewAdamSolver(gorgonia.WithBatchSize(float64(opts.BatchSize)), gorgonia.WithLearnRate(0.02)) 94 | } 95 | 96 | return godl.Train(r.model, r.tabnet, trainX, trainY, validateX, validateY, opts) 97 | } 98 | 99 | // FIXME: this shouldn't receive Y 100 | func (r *Regressor) Solve(x tensor.Tensor, y tensor.Tensor) (tensor.Tensor, error) { 101 | predictor, err := r.model.Predictor(r.tabnet, godl.PredictOpts{ 102 | InputShape: tensor.Shape{r.opts.BatchSize, x.Shape()[1]}, 103 | }) 104 | if err != nil { 105 | return nil, err 106 | } 107 | 108 | yPos := 0 109 | correct := 0.0 110 | 111 | godl.InBatches(x, r.opts.BatchSize, func(v tensor.Tensor) { 112 | val, err := predictor(v) 113 | if err != nil { 114 | panic(err) 115 | } 116 | 117 | t := val.(tensor.Tensor) 118 | 119 | log.Printf("output: %v", t.Shape()) 120 | 121 | for _, o := range t.Data().([]float32) { 122 | yVal, err := y.At(yPos, 0) 123 | if err != nil { 124 | panic(err) 125 | } 126 | 127 | // log.Printf("%v == %v", yVal, o) 128 | if yVal.(float32) == 1 { 129 | if o > 0.5 { 130 | correct++ 131 | } 132 | } else { 133 | if o <= 0.5 { 134 | correct++ 135 | } 136 | } 137 | 138 | yPos++ 139 | } 140 | }) 141 | 142 | log.Printf("r=%v", correct/float64(yPos)) 143 | 144 | return nil, nil 145 | } 146 | -------------------------------------------------------------------------------- /tabnet/feature_transformer_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | func TestFeatureTransformer(t *testing.T) { 13 | testCases := []struct { 14 | desc string 15 | input tensor.Tensor 16 | weight []float32 17 | vbs int 18 | independentBlocks int 19 | sharedBlocks int 20 | output int 21 | expectedShape tensor.Shape 22 | expectedErr string 23 | 
expectedOutput []float32 24 | expectedGrad []float32 25 | expectedCost float32 26 | }{ 27 | { 28 | desc: "Example1", 29 | input: tensor.New( 30 | tensor.WithShape(6, 2), 31 | tensor.WithBacking([]float32{0.4, 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4}), 32 | ), 33 | weight: []float32{-2.1376, -1.7072, 1.5896, 1.3657, 0.2603, -0.4051, -0.8271, 1.0830, 0.2617, -0.2792, -1.2426, 0.8678, 1.0771, -0.1787, -0.7482, 0.9506, -0.0861, -0.3015, -1.0695, -0.0246, -0.7007, 0.0354, -0.2400, -0.2516, -0.3165, 2.0425, 0.6425, 0.5848, -1.9183, 0.0099, -0.8387, -1.5346}, 34 | vbs: 2, 35 | output: 8 + 8, 36 | independentBlocks: 5, 37 | sharedBlocks: 5, 38 | expectedShape: tensor.Shape{6, 2}, 39 | expectedOutput: []float32{-0.24453057, 2.446192, 0.664702, -6.64944, -0.2445306, 2.446192, 0.66470236, -6.6494403, -0.2445307, 2.4461923, 0.6647016, -6.64944}, 40 | expectedGrad: []float32{0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333}, 41 | expectedCost: -0.94576913, 42 | }, 43 | } 44 | 45 | for _, tcase := range testCases { 46 | t.Run(tcase.desc, func(t *testing.T) { 47 | c := require.New(t) 48 | tn := godl.NewModel() 49 | g := tn.TrainGraph() 50 | 51 | input := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithName("x"), gorgonia.WithValue(tcase.input)) 52 | 53 | fcWeight := gorgonia.NewTensor(g, tensor.Float32, 2, gorgonia.WithShape(input.Shape()[1], tcase.output), gorgonia.WithValue( 54 | tensor.New( 55 | tensor.WithShape(input.Shape()[1], tcase.output), 56 | tensor.WithBacking(tcase.weight), 57 | ), 58 | ), gorgonia.WithName("fcWeight")) 59 | 60 | shared := make([]*godl.LinearModule, tcase.sharedBlocks) 61 | fcInput := input.Shape()[1] 62 | fcOutput := 2 * tcase.output 63 | for i := 0; i < 
tcase.sharedBlocks; i++ { 64 | shared[i] = godl.Linear(tn, godl.LinearOpts{ 65 | OutputDimension: fcOutput, // double the size so we can take half and half 66 | WeightsInit: gorgonia.RangedFromWithStep(-0.1, 0.01), 67 | InputDimension: fcInput, 68 | }) 69 | 70 | fcInput = tcase.output 71 | } 72 | 73 | result := FeatureTransformer(tn, FeatureTransformerOpts{ 74 | VirtualBatchSize: tcase.vbs, 75 | InputDimension: input.Shape()[1], 76 | OutputDimension: tcase.output, 77 | Shared: shared, 78 | IndependentBlocks: tcase.independentBlocks, 79 | WeightsInit: initDummyWeights, 80 | Momentum: 0.02, 81 | }).Forward(input) 82 | 83 | y := result[0] 84 | 85 | wT := gorgonia.Must(gorgonia.Transpose(fcWeight, 1, 0)) 86 | y = gorgonia.Must(gorgonia.Mul(y, wT)) 87 | 88 | cost := gorgonia.Must(gorgonia.Mean(y)) 89 | _, err := gorgonia.Grad(cost, input) 90 | c.NoError(err) 91 | 92 | vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(tn.Learnables()...)) 93 | c.NoError(vm.RunAll()) 94 | 95 | tn.PrintWatchables() 96 | 97 | t.Logf("feat output: %v", y.Value()) 98 | t.Logf("y: %v", y.Value()) 99 | t.Logf("dx: %v", input.Deriv().Value()) 100 | 101 | c.Equal(tcase.expectedShape, y.Shape()) 102 | c.InDeltaSlice(tcase.expectedOutput, y.Value().Data().([]float32), 1e-5, "%#v expected. 
Got: %#v", tcase.expectedOutput, y.Value().Data()) 103 | 104 | yGrad, err := y.Grad() 105 | c.NoError(err) 106 | 107 | c.Equal(tcase.expectedGrad, yGrad.Data()) 108 | c.InDelta(tcase.expectedCost, cost.Value().Data(), 1e-5) 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /examples/mnist/loader.go: -------------------------------------------------------------------------------- 1 | package mnist 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "os" 9 | "path/filepath" 10 | 11 | "gorgonia.org/tensor" 12 | ) 13 | 14 | type Mode string 15 | type ContentType string 16 | 17 | const ( 18 | ModeTrain Mode = "train" 19 | ModeTest Mode = "test" 20 | ) 21 | 22 | const ( 23 | ContentTypeData ContentType = "data" 24 | ContentTypeLabels ContentType = "labels" 25 | ) 26 | 27 | var ( 28 | fileNames = map[Mode]map[ContentType]string{ 29 | ModeTrain: { 30 | ContentTypeData: "train-images-idx3-ubyte.gz", 31 | ContentTypeLabels: "train-labels-idx1-ubyte.gz", 32 | }, 33 | ModeTest: { 34 | ContentTypeData: "t10k-images-idx3-ubyte.gz", 35 | ContentTypeLabels: "t10k-labels-idx1-ubyte.gz", 36 | }, 37 | } 38 | ) 39 | 40 | const ( 41 | imageMagic = 0x00000803 42 | labelMagic = 0x00000801 43 | ) 44 | 45 | func Load(mode Mode, baseDir string) (inputs, targets tensor.Tensor, err error) { 46 | labelFile := filepath.Join(baseDir, fileNames[mode][ContentTypeLabels]) 47 | dataFile := filepath.Join(baseDir, fileNames[mode][ContentTypeData]) 48 | 49 | labelData, err := loadLabelFile(labelFile) 50 | if err != nil { 51 | return nil, nil, fmt.Errorf("cannot load label data: %w", err) 52 | } 53 | 54 | imageData, err := loadDataFile(dataFile) 55 | if err != nil { 56 | return nil, nil, fmt.Errorf("cannot load image data: %w", err) 57 | } 58 | 59 | return imageDataToTensor(imageData), labelDataToTensor(labelData), nil 60 | } 61 | 62 | func pixelWeight(px byte) float32 { 63 | retVal := float32(px)/255*0.9 + 0.1 
64 | if retVal == 1.0 { 65 | return 0.999 66 | } 67 | 68 | return retVal 69 | } 70 | 71 | func imageDataToTensor(imageData [][]byte) tensor.Tensor { 72 | rows := len(imageData) 73 | cols := len(imageData[0]) 74 | 75 | backing := make([]float32, 0, rows*cols) 76 | 77 | for i := 0; i < rows; i++ { 78 | for j := 0; j < len(imageData[i]); j++ { 79 | backing = append(backing, pixelWeight(imageData[i][j])) 80 | } 81 | } 82 | 83 | return tensor.New(tensor.WithShape(rows, cols), tensor.WithBacking(backing)) 84 | } 85 | 86 | func labelDataToTensor(labelData []uint8) tensor.Tensor { 87 | rows := len(labelData) 88 | cols := 10 89 | 90 | backing := make([]float32, 0, rows*cols) 91 | 92 | for i := 0; i < rows; i++ { 93 | for j := 0; j < cols; j++ { 94 | if j == int(labelData[i]) { 95 | backing = append(backing, 0.9) 96 | } else { 97 | backing = append(backing, 0.1) 98 | } 99 | } 100 | } 101 | 102 | return tensor.New(tensor.WithShape(rows, cols), tensor.WithBacking(backing)) 103 | } 104 | 105 | func loadLabelFile(labelFile string) ([]uint8, error) { 106 | f, err := os.Open(labelFile) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | defer func() { _ = f.Close() }() 112 | 113 | r, err := gzip.NewReader(f) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | var ( 119 | magic int32 120 | n int32 121 | ) 122 | 123 | if err = binary.Read(r, binary.BigEndian, &magic); err != nil { 124 | return nil, err 125 | } 126 | 127 | if magic != labelMagic { 128 | return nil, os.ErrInvalid 129 | } 130 | 131 | if err = binary.Read(r, binary.BigEndian, &n); err != nil { 132 | return nil, err 133 | } 134 | 135 | labels := make([]uint8, n) 136 | for i := 0; i < int(n); i++ { 137 | var l uint8 138 | if err := binary.Read(r, binary.BigEndian, &l); err != nil { 139 | return nil, err 140 | } 141 | 142 | labels[i] = l 143 | } 144 | 145 | return labels, nil 146 | } 147 | 148 | func loadDataFile(dataFile string) ([][]byte, error) { 149 | f, err := os.Open(dataFile) 150 | if err != 
nil { 151 | return nil, err 152 | } 153 | 154 | defer func() { _ = f.Close() }() 155 | 156 | r, err := gzip.NewReader(f) 157 | if err != nil { 158 | return nil, err 159 | } 160 | 161 | var ( 162 | magic int32 163 | n int32 164 | nrow int32 165 | ncol int32 166 | ) 167 | 168 | if err = binary.Read(r, binary.BigEndian, &magic); err != nil { 169 | return nil, err 170 | } 171 | 172 | if magic != imageMagic { 173 | return nil, err 174 | } 175 | 176 | if err = binary.Read(r, binary.BigEndian, &n); err != nil { 177 | return nil, err 178 | } 179 | 180 | if err = binary.Read(r, binary.BigEndian, &nrow); err != nil { 181 | return nil, err 182 | } 183 | 184 | if err = binary.Read(r, binary.BigEndian, &ncol); err != nil { 185 | return nil, err 186 | } 187 | 188 | imgs := make([][]byte, n) 189 | m := int(nrow * ncol) 190 | for i := 0; i < int(n); i++ { 191 | imgs[i] = make([]byte, m) 192 | m_, err := io.ReadFull(r, imgs[i]) 193 | if err != nil { 194 | return nil, err 195 | } 196 | if m_ != int(m) { 197 | return nil, os.ErrInvalid 198 | } 199 | } 200 | 201 | return imgs, nil 202 | } 203 | -------------------------------------------------------------------------------- /validate.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | 8 | "github.com/olekukonko/tablewriter" 9 | "gorgonia.org/gorgonia" 10 | "gorgonia.org/tensor" 11 | ) 12 | 13 | type MatchType int 14 | 15 | const ( 16 | MatchTypeTruePositive MatchType = iota 17 | MatchTypeTrueNegative 18 | MatchTypeFalsePositive 19 | MatchTypeFalseNegative 20 | ) 21 | 22 | type ConfusionMatrix map[MatchType]int 23 | 24 | func (cmat ConfusionMatrix) Accuracy() float64 { 25 | v := float64(cmat[MatchTypeTruePositive]+cmat[MatchTypeTrueNegative]) / float64(cmat[MatchTypeTrueNegative]+cmat[MatchTypeTruePositive]+cmat[MatchTypeFalseNegative]+cmat[MatchTypeFalsePositive]) 26 | 27 | return v 28 | } 29 | 30 | func (cmat ConfusionMatrix) 
Precision() float64 {
	tp := cmat[MatchTypeTruePositive]

	return float64(tp) / float64(tp+cmat[MatchTypeFalsePositive])
}

// F1Score returns 2*TP / (2*TP + FP + FN).
func (cmat ConfusionMatrix) F1Score() float64 {
	tp := cmat[MatchTypeTruePositive]

	return float64(2*tp) / float64(2*tp+cmat[MatchTypeFalsePositive]+cmat[MatchTypeFalseNegative])
}

// Recall returns TP / (TP + FN).
func (cmat ConfusionMatrix) Recall() float64 {
	tp := cmat[MatchTypeTruePositive]

	return float64(tp) / float64(tp+cmat[MatchTypeFalseNegative])
}

// MissRate returns 1 - Recall.
func (cmat ConfusionMatrix) MissRate() float64 {
	return 1 - cmat.Recall()
}

// String renders the matrix as a coloured table followed by the accuracy,
// precision, F1 score and recall as percentages.
func (cmat ConfusionMatrix) String() string {
	var b strings.Builder

	good := tablewriter.Colors{tablewriter.BgHiGreenColor, tablewriter.FgBlackColor}
	bad := tablewriter.Colors{tablewriter.BgHiRedColor, tablewriter.FgBlackColor}
	count := func(mt MatchType) string { return fmt.Sprintf("%*d", 8, cmat[mt]) }

	w := tablewriter.NewWriter(&b)
	w.SetBorder(true)
	w.SetAlignment(tablewriter.ALIGN_CENTER)
	w.SetRowLine(true)

	w.SetHeader([]string{"Actual\\Predicted", "P", "N"})
	w.Rich([]string{"P", count(MatchTypeTruePositive), count(MatchTypeFalseNegative)}, []tablewriter.Colors{{}, good, bad})
	w.Rich([]string{"N", count(MatchTypeFalsePositive), count(MatchTypeTrueNegative)}, []tablewriter.Colors{{}, bad, good})

	w.Render()

	fmt.Fprintf(&b, `
Accuracy: %0.1f%%
Precision: %0.1f%%
F1 Score: %0.1f%%
Recall: %0.1f%%
`, cmat.Accuracy()*100, cmat.Precision()*100, cmat.F1Score()*100, cmat.Recall()*100)

	return b.String()
}

// float32Row returns row j of t as a []float32. It panics when the row cannot
// be sliced or when the data is neither float32 nor []float32.
func float32Row(t tensor.Tensor, j int) []float32 {
	rowT, err := t.Slice(gorgonia.S(j, j+1))
	if err != nil {
		panic(err)
	}

	switch v := rowT.Data().(type) {
	case []float32:
		return v
	case float32:
		return []float32{v}
	default:
		log.Panicf("type %T not supported", v)

		return nil
	}
}

// Validate runs the model's evaluation graph over validateX/validateY in
// batches, classifies every row with opts.MatchTypeFor and, when set, reports
// the resulting confusion matrix and costVal to opts.ValidationObserver.
//
// x and y are the graph's input and target nodes; predVal is the value the
// predictions are read into. Failures are reported through fatal or panic, the
// returned error is always nil.
func Validate(m *Model, x, y *gorgonia.Node, costVal, predVal gorgonia.Value, validateX, validateY tensor.Tensor, opts TrainOpts) error {
	opts.setDefaults()

	g := m.evalGraph
	if g == nil {
		fatal("evaluation graph not set")
	}

	dl := NewDataLoader(validateX, validateY, DataLoaderOpts{
		BatchSize: opts.BatchSize,
		Shuffle:   false,
	})

	vmOpts := []gorgonia.VMOpt{gorgonia.EvalMode()}
	if opts.DevMode {
		vmOpts = append(vmOpts, gorgonia.TraceExec(), gorgonia.WithInfWatch(), gorgonia.WithNaNWatch())
	}

	vm := gorgonia.NewTapeMachine(g, vmOpts...)

	defer vm.Close()

	confMat := ConfusionMatrix{}

	for dl.HasNext() {
		xVal, yVal := dl.Next()

		if err := gorgonia.Let(x, xVal); err != nil {
			fatal("error assigning x: %v", err)
		}

		if err := gorgonia.Let(y, yVal); err != nil {
			fatal("error assigning y: %v", err)
		}

		if err := vm.RunAll(); err != nil {
			fatal("Failed batch %d. Error: %v", dl.CurrentBatch, err)
		}

		for j := 0; j < predVal.Shape()[0]; j++ {
			// Target row first, then the matching prediction row.
			yRow := float32Row(yVal, j)
			predRow := float32Row(predVal.(tensor.Tensor), j)

			confMat[opts.MatchTypeFor(predRow, yRow)]++
		}

		vm.Reset()
	}

	if opts.ValidationObserver != nil {
		opts.ValidationObserver(confMat, costVal.Data().(float32))
	}

	return nil
}
-------------------------------------------------------------------------------- /data_loader.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "gorgonia.org/gorgonia" 7 | "gorgonia.org/tensor" 8 | "gorgonia.org/tensor/native" 9 | ) 10 | 11 | type DataLoaderOpts struct { 12 | Shuffle bool 13 | BatchSize int 14 | Drop bool 15 | } 16 | 17 | func (o *DataLoaderOpts) setDefaults() { 18 | if o.BatchSize <= 0 { 19 | panic("batch size must be greater than 0") 20 | } 21 | } 22 | 23 | type DataLoader struct { 24 | x tensor.Tensor 25 | y tensor.Tensor 26 | opts DataLoaderOpts 27 | 28 | FeaturesShape tensor.Shape 29 | Rows int 30 | Batches int 31 | CurrentBatch int 32 | } 33 | 34 | // NewDataLoader creates a data loader with the given data and options 35 | func NewDataLoader(x tensor.Tensor, y tensor.Tensor, opts DataLoaderOpts) *DataLoader { 36 | opts.setDefaults() 37 | 38 | var err error 39 | 40 | numExamples := x.Shape()[0] 41 | 42 | if !opts.Drop && numExamples%opts.BatchSize > 0 { 43 | missingRows := opts.BatchSize - (numExamples % opts.BatchSize) 44 | 45 | rowsX := make([]tensor.Tensor, missingRows) 46 | rowsY := make([]tensor.Tensor, missingRows) 47 | 48 | for i := 0; i < missingRows; i++ { 49 | row := rand.Intn(numExamples) 50 | 51 | xS, err := x.Slice(gorgonia.S(row)) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | newXShape := append(tensor.Shape{1}, x.Shape()[1:]...) 57 | err = xS.Reshape(newXShape...) 58 | if err != nil { 59 | panic(err) 60 | } 61 | 62 | rowsX[i] = xS 63 | 64 | yS, err := y.Slice(gorgonia.S(row)) 65 | if err != nil { 66 | panic(err) 67 | } 68 | 69 | newYShape := append(tensor.Shape{1}, y.Shape()[1:]...) 70 | err = yS.Reshape(newYShape...) 71 | if err != nil { 72 | panic(err) 73 | } 74 | 75 | rowsY[i] = yS 76 | } 77 | 78 | x, err = tensor.Concat(0, x, rowsX...) 
79 | if err != nil { 80 | panic(err) 81 | } 82 | 83 | y, err = tensor.Concat(0, y, rowsY...) 84 | if err != nil { 85 | panic(err) 86 | } 87 | 88 | numExamples += missingRows 89 | } 90 | 91 | batches := numExamples / opts.BatchSize 92 | 93 | dl := &DataLoader{ 94 | x: x, 95 | y: y, 96 | opts: opts, 97 | Rows: numExamples, 98 | Batches: batches, 99 | FeaturesShape: tensor.Shape(x.Shape()[1:]).Clone(), 100 | } 101 | 102 | dl.Reset() 103 | 104 | return dl 105 | } 106 | 107 | // HasNext returns true if there's more batches to fetch 108 | func (dl DataLoader) HasNext() bool { 109 | start := (dl.CurrentBatch) * dl.opts.BatchSize 110 | 111 | if start >= dl.Rows { 112 | return false 113 | } 114 | 115 | return true 116 | } 117 | 118 | // Reset resets the iterator 119 | func (dl *DataLoader) Reset() { 120 | dl.CurrentBatch = 0 121 | 122 | if dl.opts.Shuffle { 123 | err := dl.Shuffle() 124 | if err != nil { 125 | panic(err) 126 | } 127 | } 128 | } 129 | 130 | func (dl *DataLoader) toMatrix(t tensor.Tensor) tensor.Shape { 131 | prevShape := t.Shape().Clone() 132 | 133 | err := t.Reshape(append(tensor.Shape{prevShape[0]}, tensor.Shape(prevShape[1:]).TotalSize())...) 134 | if err != nil { 135 | panic(err) 136 | } 137 | 138 | return prevShape 139 | } 140 | 141 | // Shuffle shuffles the data 142 | func (dl *DataLoader) Shuffle() error { 143 | oldXShape := dl.toMatrix(dl.x) 144 | defer func() { 145 | _ = dl.x.Reshape(oldXShape...) 146 | }() 147 | 148 | iterX, err := native.MatrixF32(dl.x.(*tensor.Dense)) 149 | if err != nil { 150 | return err 151 | } 152 | 153 | oldYShape := dl.toMatrix(dl.y) 154 | defer func() { 155 | _ = dl.y.Reshape(oldYShape...) 
156 | }() 157 | 158 | iterY, err := native.MatrixF32(dl.y.(*tensor.Dense)) 159 | if err != nil { 160 | return err 161 | } 162 | 163 | tmp := make([]float32, dl.FeaturesShape.TotalSize()) 164 | rand.Shuffle(dl.Rows, func(i, j int) { 165 | copy(tmp, iterX[i]) 166 | copy(iterX[i], iterX[j]) 167 | copy(iterX[j], tmp) 168 | 169 | copy(tmp, iterY[i]) 170 | copy(iterY[i], iterY[j]) 171 | copy(iterY[j], tmp) 172 | }) 173 | 174 | return nil 175 | } 176 | 177 | // Next returns the next batch 178 | func (dl *DataLoader) Next() (tensor.Tensor, tensor.Tensor) { 179 | start := dl.CurrentBatch * dl.opts.BatchSize 180 | end := start + dl.opts.BatchSize 181 | 182 | if start >= dl.Rows { 183 | return nil, nil 184 | } 185 | 186 | if end > dl.Rows { 187 | end = dl.Rows 188 | } 189 | 190 | inputSize := end - start 191 | 192 | xVal, err := dl.x.Slice(gorgonia.S(start, end)) 193 | if err != nil { 194 | panic(err) 195 | } 196 | 197 | yVal, err := dl.y.Slice(gorgonia.S(start, end)) 198 | if err != nil { 199 | panic(err) 200 | } 201 | 202 | err = xVal.(*tensor.Dense).Reshape(append(tensor.Shape{inputSize}, dl.FeaturesShape...)...) 
203 | if err != nil { 204 | panic(err) 205 | } 206 | 207 | dl.CurrentBatch++ 208 | 209 | return xVal, yVal 210 | } 211 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "os/exec" 9 | "path/filepath" 10 | 11 | "github.com/dcu/godl/storage" 12 | "gonum.org/v1/plot/vg" 13 | "gorgonia.org/gorgonia" 14 | "gorgonia.org/qol/plot" 15 | "gorgonia.org/tensor" 16 | ) 17 | 18 | const ( 19 | heatmapPath = "heatmap" 20 | bufferSizeModel = 16 21 | ) 22 | 23 | // Model implements the tab net model 24 | type Model struct { 25 | trainGraph, evalGraph *gorgonia.ExprGraph 26 | learnables gorgonia.Nodes 27 | watchables []watchable 28 | 29 | Logger *log.Logger 30 | Storage *storage.Storage 31 | } 32 | 33 | // NewModel creates a new model for the neural network 34 | func NewModel() *Model { 35 | return &Model{ 36 | trainGraph: gorgonia.NewGraph(), 37 | learnables: make([]*gorgonia.Node, 0, bufferSizeModel), 38 | watchables: make([]watchable, 0), 39 | Storage: storage.NewStorage(), 40 | } 41 | } 42 | 43 | // WriteSVG creates a SVG representation of the node 44 | func (m *Model) WriteSVG(path string) error { 45 | b := m.trainGraph.ToDot() 46 | 47 | fileName := "graph.dot" 48 | 49 | err := ioutil.WriteFile(fileName, []byte(b), 0644) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | defer func() { _ = os.Remove(fileName) }() 55 | 56 | cmd := exec.Command("dot", "-T", "svg", fileName, "-o", path) 57 | 58 | return cmd.Run() 59 | } 60 | 61 | // TrainGraph returns the graph for the model 62 | func (m *Model) TrainGraph() *gorgonia.ExprGraph { 63 | return m.trainGraph 64 | } 65 | 66 | // Learnables returns all learnables in the model 67 | func (m *Model) Learnables() gorgonia.Nodes { 68 | return m.learnables 69 | } 70 | 71 | // Run runs the virtual machine in prediction mode 72 | func (m 
*Model) Run(vmOpts ...gorgonia.VMOpt) error { 73 | vm := gorgonia.NewTapeMachine(m.trainGraph, vmOpts...) 74 | 75 | err := vm.RunAll() 76 | if err != nil { 77 | return err 78 | } 79 | 80 | return vm.Close() 81 | } 82 | 83 | type PredictOpts struct { 84 | InputShape tensor.Shape 85 | DevMode bool 86 | } 87 | 88 | type Predictor func(x tensor.Tensor) (gorgonia.Value, error) 89 | 90 | func (o *PredictOpts) setDefaults() { 91 | if o.InputShape == nil { 92 | panic("InputShape is required") 93 | } 94 | } 95 | 96 | func (m *Model) Predictor(module Module, opts PredictOpts) (Predictor, error) { 97 | opts.setDefaults() 98 | 99 | x := gorgonia.NewTensor( 100 | m.trainGraph, 101 | tensor.Float32, 102 | opts.InputShape.Dims(), 103 | gorgonia.WithName("input"), 104 | gorgonia.WithShape(opts.InputShape...), 105 | ) 106 | 107 | result := module.Forward(x) 108 | vmOpts := []gorgonia.VMOpt{ 109 | gorgonia.EvalMode(), 110 | } 111 | 112 | if opts.DevMode { 113 | vmOpts = append( 114 | vmOpts, 115 | gorgonia.TraceExec(), 116 | gorgonia.WithInfWatch(), 117 | gorgonia.WithNaNWatch(), 118 | ) 119 | } 120 | 121 | var predVal gorgonia.Value 122 | 123 | gorgonia.Read(result[0], &predVal) 124 | 125 | return func(input tensor.Tensor) (gorgonia.Value, error) { 126 | gorgonia.Let(x, input) 127 | 128 | if err := m.Run(vmOpts...); err != nil { 129 | return nil, fmt.Errorf("failed to run prediction: %w", err) 130 | } 131 | 132 | return predVal, nil 133 | }, nil 134 | } 135 | 136 | func (m Model) saveHeatmaps(epoch, batch, batchSize, features int) { 137 | for _, v := range m.learnables { 138 | wt := v.Value().(tensor.Tensor) 139 | wtShape := wt.Shape().Clone() 140 | x, y := wtShape[0], tensor.Shape(wtShape[1:]).TotalSize() 141 | 142 | newShape := tensor.Shape{x, y} 143 | 144 | grad, err := v.Grad() 145 | if err != nil { 146 | panic(err) 147 | } 148 | 149 | gradT := grad.(tensor.Tensor) 150 | 151 | pathName := filepath.Join(heatmapPath, v.Name()) 152 | fileName := fmt.Sprintf("%s/%d_%d_%v.png", 
pathName, epoch, batch, wtShape) 153 | gradFileName := fmt.Sprintf("%s/grad_%d_%d_%v.png", pathName, epoch, batch, wtShape) 154 | 155 | err = wt.Reshape(newShape...) 156 | if err != nil { 157 | panic(err) 158 | } 159 | 160 | p, err := plot.Heatmap(wt, nil) 161 | if err != nil { 162 | panic(fmt.Errorf("failed to process %s: %w", fileName, err)) 163 | } 164 | 165 | err = gradT.Reshape(newShape...) 166 | if err != nil { 167 | panic(err) 168 | } 169 | 170 | pGrad, err := plot.Heatmap(gradT, nil) 171 | if err != nil { 172 | panic(fmt.Errorf("failed to process %s: %w", fileName, err)) 173 | } 174 | 175 | err = wt.Reshape(wtShape...) 176 | if err != nil { 177 | panic(err) 178 | } 179 | 180 | err = gradT.Reshape(wtShape...) 181 | if err != nil { 182 | panic(err) 183 | } 184 | 185 | width := vg.Length(newShape[0]) * vg.Centimeter 186 | height := vg.Length(newShape[1]) * vg.Centimeter 187 | 188 | _ = os.MkdirAll(pathName, 0755) 189 | _ = p.Save(width, height, fileName) 190 | 191 | _ = pGrad.Save(width, height, gradFileName) 192 | } 193 | } 194 | 195 | // CheckArity checks if the arity is the correct one 196 | func (m Model) CheckArity(lt LayerType, nodes []*gorgonia.Node, arity int) error { 197 | if len(nodes) != arity { 198 | return ErrorF(lt, "arity doesn't match, expected %d, got %d", arity, len(nodes)) 199 | } 200 | 201 | return nil 202 | } 203 | -------------------------------------------------------------------------------- /train.go: -------------------------------------------------------------------------------- 1 | package godl 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "time" 9 | 10 | "github.com/fatih/color" 11 | "gorgonia.org/gorgonia" 12 | "gorgonia.org/tensor" 13 | ) 14 | 15 | // TrainOpts are the options to train the model 16 | type TrainOpts struct { 17 | Epochs int 18 | BatchSize int 19 | 20 | // DevMode detects common issues like exploding and vanishing gradients at the cost of performance 21 | DevMode bool 22 | 23 | WriteGraphFileTo 
// Train builds the training graph for module on top of m and runs the
// optimisation loop over (trainX, trainY), validating against
// (validateX, validateY).
//
// opts is completed with opts.setDefaults first (Epochs 10, BatchSize 1024,
// ValidateEvery 1); setDefaults panics when opts.CostFn is nil. When
// opts.Solver is nil an RMSProp solver configured with the batch size is used.
//
// trainY is indexed as trainY.Shape()[1], so it must have at least two
// dimensions. The cost value is handed to opts.CostObserver through a float32
// type assertion, so a cost of any other type panics when an observer is set.
//
// The only error returned is a failure to compute the gradients. Failures
// inside the batch loop are reported through fatal (NOTE(review): presumably
// terminates the process — confirm against its definition), and validation
// failures are only printed. On success Train returns nil.
func Train(m *Model, module Module, trainX, trainY, validateX, validateY tensor.Tensor, opts TrainOpts) error {
	opts.setDefaults()

	if opts.DevMode {
		warn("Start training in dev mode")

		// In dev mode a panic anywhere below dumps the training graph to
		// graph.dot before the panic is re-raised, so the graph can be inspected.
		defer func() {
			if err := recover(); err != nil {
				graphFileName := "graph.dot"

				log.Printf("panic triggered, dumping the model graph to: %v", graphFileName)
				_ = ioutil.WriteFile(graphFileName, []byte(m.trainGraph.ToDot()), 0644)
				panic(err)
			}
		}()
	}

	if opts.WithLearnablesHeatmap {
		warn("Heatmaps will be stored in: %s", heatmapPath)
		// Start from a clean directory; a failure to remove it is ignored.
		_ = os.RemoveAll(heatmapPath)
	}

	dl := NewDataLoader(trainX, trainY, DataLoaderOpts{
		BatchSize: opts.BatchSize,
		Shuffle:   false,
	})

	// The input placeholder keeps every trailing dimension of trainX but uses
	// the batch size as its leading dimension.
	xShape := append(tensor.Shape{opts.BatchSize}, trainX.Shape()[1:]...)

	x := gorgonia.NewTensor(m.trainGraph, trainX.Dtype(), trainX.Shape().Dims(), gorgonia.WithShape(xShape...), gorgonia.WithName("x"))
	y := gorgonia.NewMatrix(m.trainGraph, trainY.Dtype(), gorgonia.WithShape(opts.BatchSize, trainY.Shape()[1]), gorgonia.WithName("y"))

	result := module.Forward(x)

	if opts.WriteGraphFileTo != "" {
		m.WriteSVG(opts.WriteGraphFileTo)
	}

	var (
		costVal gorgonia.Value
		predVal gorgonia.Value
	)

	{
		cost := opts.CostFn(result, y)

		// Bind the cost and the first output of the module to Go values that
		// are refreshed on every run of the VM.
		gorgonia.Read(cost, &costVal)
		gorgonia.Read(result[0], &predVal)

		if _, err := gorgonia.Grad(cost, m.Learnables()...); err != nil {
			return fmt.Errorf("error calculating gradient: %w", err)
		}
	}

	// The evaluation graph is the subgraph rooted at the first output of the
	// module, with the target placeholder removed.
	validationGraph := m.trainGraph.SubgraphRoots(result[0])
	validationGraph.RemoveNode(y)

	m.evalGraph = validationGraph

	vmOpts := []gorgonia.VMOpt{
		gorgonia.BindDualValues(m.learnables...),
	}

	if opts.DevMode {
		// Dev mode trades speed for diagnostics: execution tracing plus NaN
		// and Inf watches.
		vmOpts = append(
			vmOpts,
			gorgonia.TraceExec(),
			gorgonia.WithNaNWatch(),
			gorgonia.WithInfWatch(),
		)
	}

	vm := gorgonia.NewTapeMachine(m.trainGraph, vmOpts...)

	if opts.Solver == nil {
		info("defaulting to RMS solver")

		opts.Solver = gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(opts.BatchSize)))
	}

	defer vm.Close()

	startTime := time.Now()

	for i := 0; i < opts.Epochs; i++ {
		for dl.HasNext() {
			xVal, yVal := dl.Next()

			err := gorgonia.Let(x, xVal)
			if err != nil {
				fatal("error assigning x: %v", err)
			}

			err = gorgonia.Let(y, yVal)
			if err != nil {
				fatal("error assigning y: %v", err)
			}

			if err = vm.RunAll(); err != nil {
				fatal("Failed at epoch %d, batch %d. Error: %v", i, dl.CurrentBatch, err)
			}

			// Heatmaps are written after the run and before the solver step.
			if opts.WithLearnablesHeatmap {
				m.saveHeatmaps(i, dl.CurrentBatch, dl.opts.BatchSize, dl.FeaturesShape[0])
			}

			if err = opts.Solver.Step(gorgonia.NodesToValueGrads(m.learnables)); err != nil {
				fatal("Failed to update nodes with gradients at epoch %d, batch %d. Error %v", i, dl.CurrentBatch, err)
			}

			if opts.CostObserver != nil {
				// The type assertion panics when the cost is not a float32.
				opts.CostObserver(i, opts.Epochs, dl.CurrentBatch, dl.Batches, costVal.Data().(float32))
			} else {
				// color.Yellow(" Epoch %d %d | cost %v (%v)\n", i, dl.CurrentBatch, costVal, time.Since(startTime))
			}

			m.PrintWatchables()

			// Reset the VM so the next batch starts from a clean state.
			vm.Reset()
		}

		// Rewind the loader for the next epoch.
		dl.Reset()

		// Epoch 0 always validates because 0%ValidateEvery == 0.
		if i%opts.ValidateEvery == 0 {
			err := Validate(m, x, y, costVal, predVal, validateX, validateY, opts)
			if err != nil {
				// Validation failures are reported but do not stop training.
				color.Red("Failed to run validation on epoch %v: %v", i, err)
			}

			color.Yellow(" Epoch %d | cost %v (%v)\n", i, costVal, time.Since(startTime))
		}
	}

	return nil
}
panic(err) 35 | } 36 | 37 | x := inputs[0] 38 | 39 | result := godl.Conv2d(m.model, godl.Conv2dOpts{ 40 | InputDimension: 64, 41 | OutputDimension: 3, 42 | KernelSize: tensor.Shape{7, 7}, 43 | Pad: []int{0, 0}, 44 | WeightsName: "/conv1/7x7_s2/gamma", 45 | BiasName: "/conv1/7x7_s2/beta", 46 | }).Forward(x) 47 | 48 | result = godl.BatchNorm2d(m.model, godl.BatchNormOpts{ 49 | InputSize: result[0].Shape()[0], 50 | ScaleName: "/conv1/7x7_s2/bn/gamma", 51 | BiasName: "/conv1/7x7_s2/bn/beta", 52 | }).Forward(result[0]) 53 | 54 | x = gorgonia.Must(gorgonia.Rectify(result[0])) 55 | x = gorgonia.Must(gorgonia.MaxPool2D(x, tensor.Shape{3, 3}, []int{0, 0}, []int{1, 1})) 56 | 57 | result = m.seq.Forward(x) 58 | 59 | return result 60 | } 61 | 62 | func VGGFace2Builder(opts Opts) func(*godl.Model) godl.Module { 63 | return func(m *godl.Model) godl.Module { 64 | return VGGFace2(m, opts) 65 | } 66 | } 67 | 68 | func VGGFace2(m *godl.Model, opts Opts) *VGGFace2Module { 69 | lt := godl.AddLayer("VGGFace2") 70 | 71 | blocks := []godl.Module{ 72 | // Stage 2 73 | ConvBlock(m, BlockOpts{ 74 | KernelSize: tensor.Shape{3, 3}, 75 | Filters: [3]int{64, 64, 256}, 76 | Stage: 2, 77 | Block: 1, 78 | Stride: []int{1, 1}, 79 | }), 80 | IdentityBlock(m, BlockOpts{ 81 | KernelSize: tensor.Shape{3, 3}, 82 | Filters: [3]int{64, 64, 256}, 83 | Stage: 2, 84 | Block: 2, 85 | }), 86 | IdentityBlock(m, BlockOpts{ 87 | KernelSize: tensor.Shape{3, 3}, 88 | Filters: [3]int{64, 64, 256}, 89 | Stage: 2, 90 | Block: 3, 91 | }), 92 | // Stage 3 93 | ConvBlock(m, BlockOpts{ 94 | KernelSize: tensor.Shape{3, 3}, 95 | Filters: [3]int{128, 128, 512}, 96 | Stage: 3, 97 | Block: 1, 98 | Stride: []int{1, 1}, 99 | }), 100 | IdentityBlock(m, BlockOpts{ 101 | KernelSize: tensor.Shape{3, 3}, 102 | Filters: [3]int{128, 128, 512}, 103 | Stage: 3, 104 | Block: 2, 105 | }), 106 | IdentityBlock(m, BlockOpts{ 107 | KernelSize: tensor.Shape{3, 3}, 108 | Filters: [3]int{128, 128, 512}, 109 | Stage: 3, 110 | Block: 3, 111 | }), 
112 | IdentityBlock(m, BlockOpts{ 113 | KernelSize: tensor.Shape{3, 3}, 114 | Filters: [3]int{128, 128, 512}, 115 | Stage: 3, 116 | Block: 4, 117 | }), 118 | // Stage 4 119 | ConvBlock(m, BlockOpts{ 120 | KernelSize: tensor.Shape{3, 3}, 121 | Filters: [3]int{256, 256, 1024}, 122 | Stage: 4, 123 | Block: 1, 124 | Stride: []int{1, 1}, 125 | }), 126 | IdentityBlock(m, BlockOpts{ 127 | KernelSize: tensor.Shape{3, 3}, 128 | Filters: [3]int{256, 256, 1024}, 129 | Stage: 4, 130 | Block: 2, 131 | }), 132 | IdentityBlock(m, BlockOpts{ 133 | KernelSize: tensor.Shape{3, 3}, 134 | Filters: [3]int{256, 256, 1024}, 135 | Stage: 4, 136 | Block: 3, 137 | }), 138 | IdentityBlock(m, BlockOpts{ 139 | KernelSize: tensor.Shape{3, 3}, 140 | Filters: [3]int{256, 256, 1024}, 141 | Stage: 4, 142 | Block: 4, 143 | }), 144 | IdentityBlock(m, BlockOpts{ 145 | KernelSize: tensor.Shape{3, 3}, 146 | Filters: [3]int{256, 256, 1024}, 147 | Stage: 4, 148 | Block: 5, 149 | }), 150 | IdentityBlock(m, BlockOpts{ 151 | KernelSize: tensor.Shape{3, 3}, 152 | Filters: [3]int{256, 256, 1024}, 153 | Stage: 4, 154 | Block: 6, 155 | }), 156 | // Stage 5 157 | ConvBlock(m, BlockOpts{ 158 | KernelSize: tensor.Shape{3, 3}, 159 | Filters: [3]int{512, 512, 2048}, 160 | Stage: 5, 161 | Block: 1, 162 | Stride: []int{1, 1}, 163 | }), 164 | IdentityBlock(m, BlockOpts{ 165 | KernelSize: tensor.Shape{3, 3}, 166 | Filters: [3]int{512, 512, 2048}, 167 | Stage: 5, 168 | Block: 2, 169 | }), 170 | IdentityBlock(m, BlockOpts{ 171 | KernelSize: tensor.Shape{3, 3}, 172 | Filters: [3]int{512, 512, 2048}, 173 | Stage: 5, 174 | Block: 3, 175 | }), 176 | godl.AvgPool2D(m, godl.AvgPool2DOpts{ 177 | Kernel: tensor.Shape{7, 7}, 178 | }), 179 | } 180 | 181 | if !opts.OnlyFeatureExtraction { 182 | blocks = append(blocks, godl.Linear(m, godl.LinearOpts{ 183 | InputDimension: 0, // FIXME 184 | OutputDimension: opts.Classes, 185 | WeightsName: "classifier/kernel", 186 | BiasName: "classifier/bias", 187 | })) 188 | } else { 189 | // TODO: 
give option to apply global max pool2d 190 | blocks = append(blocks, godl.GlobalAvgPool2D(m)) 191 | } 192 | 193 | seq := godl.Sequential(m, blocks...) 194 | 195 | return &VGGFace2Module{ 196 | model: m, 197 | layer: lt, 198 | seq: seq, 199 | } 200 | } 201 | 202 | func handleErr(err error) { 203 | if err != nil { 204 | panic(err) 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /tabnet/tab_net_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "testing" 7 | 8 | "github.com/dcu/godl" 9 | "github.com/stretchr/testify/require" 10 | "gorgonia.org/gorgonia" 11 | "gorgonia.org/tensor" 12 | ) 13 | 14 | func TestTabNetEmbeddings(t *testing.T) { 15 | testCases := []struct { 16 | desc string 17 | input tensor.Tensor 18 | vbs int 19 | independentBlocks int 20 | sharedBlocks int 21 | output int 22 | steps int 23 | gamma float64 24 | prediction int 25 | attentive int 26 | expectedShape tensor.Shape 27 | expectedErr string 28 | expectedOutput []float32 29 | expectedGrad []float32 30 | expectedCost float32 31 | expectedAcumLoss float32 32 | }{ 33 | { 34 | desc: "Example 1", 35 | input: tensor.New( 36 | tensor.WithShape(4, 4), 37 | tensor.WithBacking([]float32{0.4, 1.4, 2.4, 0, 4.4, 5.4, 6.4, 1, 8.4, 9.4, 10.4, 2, 12.4, 13.4, 14.4, 3}), 38 | ), 39 | vbs: 4, 40 | output: 12, 41 | independentBlocks: 2, 42 | sharedBlocks: 2, 43 | steps: 5, 44 | gamma: 1.2, 45 | prediction: 8, 46 | attentive: 8, 47 | expectedShape: tensor.Shape{4, 12}, 48 | expectedOutput: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 0.81077474, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235, 102.183235}, 49 | 
expectedGrad: []float32{0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332}, 50 | expectedCost: 25.748503, 51 | expectedAcumLoss: -1.609438, 52 | }, 53 | } 54 | 55 | for _, tcase := range testCases { 56 | t.Run(tcase.desc, func(t *testing.T) { 57 | c := require.New(t) 58 | 59 | tn := godl.NewModel() 60 | logFile, _ := os.Create("tabnet.log") 61 | defer func() { _ = logFile.Close() }() 62 | 63 | tn.Logger = log.New(logFile, "", log.LstdFlags) 64 | 65 | g := tn.TrainGraph() 66 | 67 | x := gorgonia.NewTensor(g, tensor.Float32, 2, gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithName("Input"), gorgonia.WithValue(tcase.input)) 68 | 69 | a := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithInit(gorgonia.Ones()), gorgonia.WithName("AttentiveX")) 70 | priors := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithInit(gorgonia.Ones()), 
gorgonia.WithName("Priors")) 71 | 72 | result := TabNet(tn, TabNetOpts{ 73 | VirtualBatchSize: tcase.vbs, 74 | IndependentBlocks: tcase.independentBlocks, 75 | PredictionLayerDim: tcase.prediction, 76 | AttentionLayerDim: tcase.attentive, 77 | OutputSize: tcase.output, 78 | SharedBlocks: tcase.sharedBlocks, 79 | DecisionSteps: tcase.steps, 80 | Gamma: tcase.gamma, 81 | InputSize: a.Shape()[0], 82 | BatchSize: a.Shape()[0], 83 | WeightsInit: initDummyWeights, 84 | ScaleInit: gorgonia.Ones(), 85 | BiasInit: gorgonia.Zeroes(), 86 | Epsilon: 1e-10, 87 | CatIdxs: []int{3}, 88 | CatDims: []int{4}, 89 | CatEmbDim: []int{2}, 90 | }).Forward(x, a, priors) 91 | 92 | y := result[0] 93 | 94 | cost := gorgonia.Must(gorgonia.Mean(y)) 95 | _, err := gorgonia.Grad(cost, append([]*gorgonia.Node{x}, tn.Learnables()...)...) 96 | c.NoError(err) 97 | 98 | optimizer := gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02)) 99 | 100 | vm := gorgonia.NewTapeMachine(g, 101 | gorgonia.BindDualValues(tn.Learnables()...), 102 | gorgonia.WithLogger(testLogger), 103 | gorgonia.WithValueFmt("%+v"), 104 | gorgonia.WithWatchlist(), 105 | gorgonia.WithNaNWatch(), 106 | gorgonia.WithInfWatch(), 107 | ) 108 | 109 | err = vm.RunAll() 110 | tn.PrintWatchables() 111 | c.NoError(err) 112 | 113 | err = optimizer.Step(gorgonia.NodesToValueGrads(tn.Learnables())) 114 | c.NoError(err) 115 | 116 | vm.Reset() 117 | 118 | // fmt.Printf("%v\n", g.String()) 119 | 120 | log.Printf("input grad: %v", x.Deriv().Value()) 121 | 122 | c.Equal(tcase.expectedShape, y.Shape()) 123 | 124 | log.Printf("[train] y: %#v", y.Value().Data()) 125 | log.Printf("[train] cost: %#v", cost.Value().Data()) 126 | log.Printf("[train] accum lost: %#v", result[1].Value().Data()) 127 | 128 | c.InDeltaSlice(tcase.expectedOutput, y.Value().Data().([]float32), 1e-5) 129 | 130 | c.InDelta(tcase.expectedCost, cost.Value().Data(), 1e-5) 131 | c.Equal(tcase.expectedAcumLoss, result[1].Value().Data()) 132 | 133 | vmEval := gorgonia.NewTapeMachine(g, 
134 | gorgonia.EvalMode(), 135 | gorgonia.WithLogger(testLogger), 136 | gorgonia.WithValueFmt("%+v"), 137 | gorgonia.WithWatchlist(), 138 | gorgonia.WithNaNWatch(), 139 | gorgonia.WithInfWatch(), 140 | ) 141 | 142 | err = vmEval.RunAll() 143 | tn.PrintWatchables() 144 | c.NoError(err) 145 | 146 | log.Printf("[eval] y: %#v", y.Value().Data()) 147 | log.Printf("[eval] accum lost: %#v", result[1].Value().Data()) 148 | }) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /tabnet/tab_net_no_embeddings_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | 7 | "github.com/dcu/godl" 8 | "github.com/stretchr/testify/require" 9 | "gorgonia.org/gorgonia" 10 | "gorgonia.org/tensor" 11 | ) 12 | 13 | func TestTabNetNoEmbeddings(t *testing.T) { 14 | testCases := []struct { 15 | desc string 16 | input tensor.Tensor 17 | vbs int 18 | independentBlocks int 19 | sharedBlocks int 20 | output int 21 | steps int 22 | gamma float64 23 | epsilon float64 24 | momentum float64 25 | prediction int 26 | attentive int 27 | expectedShape tensor.Shape 28 | expectedErr string 29 | expectedOutput []float32 30 | expectedCost float32 31 | expectedAcumLoss float32 32 | }{ 33 | // { 34 | // desc: "Example 1", 35 | // input: tensor.New( 36 | // tensor.WithShape(4, 4), 37 | // tensor.WithBacking([]float32{0.4, 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4}), 38 | // ), 39 | // vbs: 2, 40 | // output: 12, 41 | // independentBlocks: 2, 42 | // sharedBlocks: 2, 43 | // epsilon: 1e-10, 44 | // steps: 5, 45 | // gamma: 1.2, 46 | // momentum: 0.02, 47 | // prediction: 64, 48 | // attentive: 64, 49 | // expectedShape: tensor.Shape{4, 12}, 50 | // expectedOutput: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 
447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638, 447.8060947055638}, 51 | // expectedGrad: []float32{0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332}, 52 | // expectedCost: 223.9030473527819, 53 | // expectedAcumLoss: -1.3862943607198905, 54 | // }, 55 | { 56 | desc: "Example 2", 57 | input: tensor.New( 58 | tensor.WithShape(4, 5), 59 | tensor.WithBacking([]float32{0.4, 1.4, 2.4, 1, 1, 4.4, 5.4, 6.4, 1, 1, 8.4, 9.4, 10.4, 1, 1, 12.4, 13.4, 14.4, 1, 1}), 60 | ), 61 | vbs: 128, 62 | output: 1, 63 | independentBlocks: 2, 64 | sharedBlocks: 2, 65 | steps: 3, 66 | gamma: 1.3, 67 | epsilon: 1e-15, 68 | momentum: 0.02, 69 | prediction: 8, 70 | attentive: 8, 71 | expectedShape: 
tensor.Shape{4, 1}, 72 | expectedOutput: []float32{0, 0, 0.4864648, 61.30994}, 73 | expectedCost: 15.449101, 74 | expectedAcumLoss: -1.6094379124340954, 75 | }, 76 | } 77 | 78 | for _, tcase := range testCases { 79 | t.Run(tcase.desc, func(t *testing.T) { 80 | c := require.New(t) 81 | 82 | tn := godl.NewModel() 83 | 84 | g := tn.TrainGraph() 85 | 86 | x := gorgonia.NewTensor(g, tensor.Float32, 2, gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithName("Input"), gorgonia.WithValue(tcase.input)) 87 | 88 | a := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithInit(gorgonia.Ones()), gorgonia.WithName("AttentiveX")) 89 | priors := gorgonia.NewTensor(g, tensor.Float32, tcase.input.Dims(), gorgonia.WithShape(tcase.input.Shape()...), gorgonia.WithInit(gorgonia.Ones()), gorgonia.WithName("Priors")) 90 | 91 | result := TabNetNoEmbeddings(tn, TabNetNoEmbeddingsOpts{ 92 | VirtualBatchSize: tcase.vbs, 93 | IndependentBlocks: tcase.independentBlocks, 94 | PredictionLayerDim: tcase.prediction, 95 | AttentionLayerDim: tcase.attentive, 96 | OutputSize: tcase.output, 97 | SharedBlocks: tcase.sharedBlocks, 98 | DecisionSteps: tcase.steps, 99 | Gamma: tcase.gamma, 100 | InputSize: a.Shape()[1], 101 | BatchSize: a.Shape()[0], 102 | WeightsInit: initDummyWeights, 103 | ScaleInit: gorgonia.Ones(), 104 | BiasInit: gorgonia.Zeroes(), 105 | Epsilon: tcase.epsilon, 106 | Momentum: tcase.momentum, 107 | }).Forward(x, a, priors) 108 | 109 | y := result[0] 110 | 111 | cost := gorgonia.Must(gorgonia.Mean(y)) 112 | _, err := gorgonia.Grad(cost, append(tn.Learnables(), x)...) 
113 | c.NoError(err) 114 | 115 | vm := gorgonia.NewTapeMachine(g, 116 | gorgonia.BindDualValues(tn.Learnables()...), 117 | gorgonia.WithLogger(testLogger), 118 | gorgonia.WithValueFmt("%+v"), 119 | gorgonia.WithWatchlist(), 120 | ) 121 | c.NoError(vm.RunAll()) 122 | 123 | tn.PrintWatchables() 124 | // fmt.Printf("%v\n", g.String()) 125 | 126 | log.Printf("input grad: %v", x.Deriv().Value()) 127 | 128 | c.Equal(tcase.expectedShape, y.Shape()) 129 | c.Equal(tcase.expectedOutput, y.Value().Data().([]float32)) 130 | 131 | c.Equal(tcase.expectedCost, cost.Value().Data()) 132 | c.Equal(tcase.expectedAcumLoss, result[1].Value().Data()) 133 | 134 | w := tn.Learnables()[len(tn.Learnables())-1] 135 | 136 | optim := gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02)) 137 | err = optim.Step([]gorgonia.ValueGrad{w}) 138 | c.NoError(err) 139 | 140 | log.Printf("weight updated: %v\n\n\n", w.Value()) 141 | }) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /tabnet/tab_net_regressor_test.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/stretchr/testify/require" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | func TestTabNetRegressor(t *testing.T) { 13 | testCases := []struct { 14 | desc string 15 | epochs int 16 | input tensor.Tensor 17 | target tensor.Tensor 18 | vbs int 19 | independentBlocks int 20 | sharedBlocks int 21 | output int 22 | steps int 23 | gamma float64 24 | prediction int 25 | attentive int 26 | expectedShape tensor.Shape 27 | expectedErr string 28 | expectedOutput []float32 29 | expectedGrad []float32 30 | expectedCost float64 31 | expectedAcumLoss float64 32 | }{ 33 | { 34 | desc: "Example 1", 35 | epochs: 1, 36 | input: tensor.New( 37 | tensor.WithShape(4, 4), 38 | tensor.WithBacking([]float32{0.4, 1.4, 2.4, 0, 4.4, 5.4, 6.4, 1, 8.4, 9.4, 10.4, 2, 12.4, 
13.4, 14.4, 3}), 39 | ), 40 | target: tensor.New( 41 | tensor.WithShape(4, 1), 42 | tensor.WithBacking([]float32{1, 1, 0, 0}), 43 | ), 44 | vbs: 128, 45 | output: 1, 46 | independentBlocks: 2, 47 | sharedBlocks: 2, 48 | steps: 3, 49 | gamma: 1.3, 50 | prediction: 8, 51 | attentive: 8, 52 | expectedShape: tensor.Shape{4, 12}, 53 | expectedOutput: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882, 447.8014308162882}, 54 | expectedGrad: []float32{0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332, 0.020833333333333332}, 55 | expectedCost: 
223.90071540814404, 56 | expectedAcumLoss: -1.6094379119341007, 57 | }, 58 | } 59 | 60 | for _, tcase := range testCases { 61 | t.Run(tcase.desc, func(t *testing.T) { 62 | c := require.New(t) 63 | 64 | regressor := NewRegressor(tcase.input.Shape()[1], []int{4}, []int{3}, []int{2}, RegressorOpts{ 65 | VirtualBatchSize: tcase.vbs, 66 | IndependentBlocks: tcase.independentBlocks, 67 | PredictionLayerDim: tcase.prediction, 68 | AttentionLayerDim: tcase.attentive, 69 | SharedBlocks: tcase.sharedBlocks, 70 | DecisionSteps: tcase.steps, 71 | Gamma: tcase.gamma, 72 | BatchSize: tcase.input.Shape()[0], 73 | WeightsInit: initDummyWeights, 74 | ScaleInit: gorgonia.Ones(), 75 | BiasInit: gorgonia.Zeroes(), 76 | Epsilon: 1e-15, 77 | Momentum: 0.02, 78 | WithBias: false, 79 | }) 80 | 81 | err := regressor.Train(tcase.input, tcase.target, tcase.input, tcase.target, godl.TrainOpts{ 82 | Epochs: tcase.epochs, 83 | BatchSize: tcase.input.Shape()[0], 84 | DevMode: true, 85 | Solver: gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02), gorgonia.WithClip(1.0)), 86 | MatchTypeFor: func(predVal, targetVal []float32) godl.MatchType { 87 | t.Logf("%v vs %v", predVal, targetVal) 88 | 89 | if targetVal[0] == 1 { 90 | if predVal[0] >= 0.5 { 91 | return godl.MatchTypeTruePositive 92 | } else { 93 | return godl.MatchTypeFalsePositive 94 | } 95 | } else { // == 0 96 | if predVal[0] < 0.5 { 97 | return godl.MatchTypeTrueNegative 98 | } else { 99 | return godl.MatchTypeFalseNegative 100 | } 101 | } 102 | }, 103 | ValidationObserver: func(confMat godl.ConfusionMatrix, cost float32) { 104 | t.Logf("%v\nCost: %0.4f", confMat, cost) 105 | }, 106 | }) 107 | c.NoError(err) 108 | 109 | regressor.model.PrintWatchables() 110 | 111 | for _, n := range regressor.model.Learnables() { 112 | t.Logf("%s: %v", n.Name(), n.Value().Data().([]float32)[0:2]) 113 | } 114 | 115 | // y := result.Output 116 | 117 | // if tcase.expectedErr != "" { 118 | // c.Error(err) 119 | 120 | // c.Equal(tcase.expectedErr, 
err.Error()) 121 | 122 | // return 123 | // } else { 124 | // c.NoError(err) 125 | // } 126 | 127 | // cost := gorgonia.Must(gorgonia.Mean(y)) 128 | // _, err = gorgonia.Grad(cost, append([]*gorgonia.Node{x}, tn.Learnables()...)...) 129 | // c.NoError(err) 130 | 131 | // vm := gorgonia.NewTapeMachine(g, 132 | // gorgonia.BindDualValues(tn.Learnables()...), 133 | // gorgonia.WithLogger(testLogger), 134 | // gorgonia.WithValueFmt("%+v"), 135 | // gorgonia.WithWatchlist(), 136 | // gorgonia.WithNaNWatch(), 137 | // gorgonia.WithInfWatch(), 138 | // ) 139 | // c.NoError(vm.RunAll()) 140 | 141 | // tn.PrintWatchables() 142 | // // fmt.Printf("%v\n", g.String()) 143 | 144 | // log.Printf("input grad: %v", x.Deriv().Value()) 145 | 146 | // c.Equal(tcase.expectedShape, y.Shape()) 147 | 148 | // log.Printf("y: %#v", y.Value().Data()) 149 | // c.InDeltaSlice(tcase.expectedOutput, y.Value().Data().([]float32), 1e-5) 150 | 151 | // yGrad, err := y.Grad() 152 | // c.NoError(err) 153 | 154 | // c.Equal(tcase.expectedGrad, yGrad.Data()) 155 | // c.InDelta(tcase.expectedCost, cost.Value().Data(), 1e-5) 156 | // c.Equal(tcase.expectedAcumLoss, result.Loss.Value().Data()) 157 | 158 | // weightsByName := map[string]*gorgonia.Node{} 159 | 160 | // for _, n := range tn.Learnables() { 161 | // weightsByName[n.Name()] = n 162 | 163 | // wGrad, err := n.Grad() 164 | // c.NoError(err) 165 | // log.Printf("%s: %v", n.Name(), wGrad.Data().([]float32)[0:2]) 166 | // } 167 | 168 | // // { 169 | // // w := weightsByName["BatchNorm1d_31.81.scale.1.5"] 170 | // // wGrad, err := w.Grad() 171 | // // c.NoError(err) 172 | // // c.Equal([]float32{0.0024, 0.0024, 0.0024, 0.0000, 0.0000}, wGrad.Data(), w.Name()) 173 | // // } 174 | 175 | // optim := gorgonia.NewAdamSolver(gorgonia.WithLearnRate(0.02)) 176 | // err = optim.Step(gorgonia.NodesToValueGrads(tn.Learnables())) 177 | // c.NoError(err) 178 | 179 | // { 180 | // w := weightsByName["BatchNorm1d_1.1.scale.1.5"] 181 | // log.Printf("weight 
updated: %v\n\n\n", w.Value()) 182 | // c.Equal([]float32{0.9800000823404622, 0.9800000823404622, 0.9800000823404622, 1, 1}, w.Value().Data(), w.Name()) 183 | // } 184 | 185 | // for _, n := range tn.Learnables() { 186 | // log.Printf("%s: %v", n.Name(), n.Value().Data().([]float32)[0:2]) 187 | // } 188 | }) 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /table/table.go: -------------------------------------------------------------------------------- 1 | package table 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "io" 7 | "log" 8 | "math/rand" 9 | "os" 10 | "sort" 11 | "strconv" 12 | "strings" 13 | 14 | "gorgonia.org/tensor" 15 | ) 16 | 17 | type Cell struct { 18 | Dtype tensor.Dtype 19 | V any 20 | } 21 | 22 | func (v Cell) Int() int { 23 | return v.V.(int) 24 | } 25 | 26 | func (v Cell) Float32() float32 { 27 | return v.V.(float32) 28 | } 29 | 30 | func (v Cell) String() string { 31 | return fmt.Sprintf("%v", v.V) 32 | } 33 | 34 | func StringToCell(v string) *Cell { 35 | i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64) 36 | if err == nil { 37 | return &Cell{tensor.Int, int(i)} 38 | } 39 | 40 | f, err := strconv.ParseFloat(strings.TrimSpace(v), 32) 41 | if err == nil { 42 | return &Cell{tensor.Float32, float32(f)} 43 | } 44 | 45 | return &Cell{tensor.String, v} 46 | } 47 | 48 | type Row struct { 49 | Cells []*Cell 50 | Tags map[string]bool 51 | } 52 | 53 | func (r *Row) AddTag(tags ...string) { 54 | for _, tag := range tags { 55 | r.Tags[tag] = true 56 | } 57 | } 58 | 59 | func (r Row) HasAnyTag(tags []string) bool { 60 | if len(tags) == 0 { 61 | return true 62 | } 63 | 64 | for _, tag := range tags { 65 | if r.Tags[tag] { 66 | return true 67 | } 68 | } 69 | 70 | return false 71 | } 72 | 73 | func StringsToRow(values []string) Row { 74 | cells := make([]*Cell, len(values)) 75 | for i, v := range values { 76 | cells[i] = StringToCell(v) 77 | } 78 | 79 | return Row{cells, map[string]bool{}} 80 | } 
81 | 82 | type Rows []*Row 83 | 84 | type Table struct { 85 | Header []string 86 | Rows Rows 87 | 88 | ClassesByColumn map[int][]string 89 | } 90 | 91 | // ReadCSV loads a CSV table 92 | func ReadCSV(pathCSV string) (*Table, error) { 93 | f, err := os.Open(pathCSV) 94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | defer func() { _ = f.Close() }() 99 | 100 | t := &Table{ 101 | ClassesByColumn: map[int][]string{}, 102 | } 103 | 104 | knownClasses := map[int]map[string]bool{} 105 | 106 | csvReader := csv.NewReader(f) 107 | for { 108 | record, err := csvReader.Read() 109 | if err == io.EOF { 110 | break 111 | } 112 | 113 | if t.Header == nil { 114 | t.Header = record 115 | 116 | for i := range t.Header { 117 | knownClasses[i] = map[string]bool{} 118 | } 119 | } 120 | 121 | if err != nil { 122 | return nil, err 123 | } 124 | 125 | cells := make([]*Cell, len(record)) 126 | for i, r := range record { 127 | v := StringToCell(r) 128 | 129 | cells[i] = v 130 | 131 | if v.Dtype == tensor.String && !knownClasses[i][v.V.(string)] { 132 | t.ClassesByColumn[i] = append(t.ClassesByColumn[i], v.V.(string)) 133 | 134 | knownClasses[i][v.V.(string)] = true 135 | } 136 | } 137 | 138 | t.Rows = append(t.Rows, &Row{cells, map[string]bool{}}) 139 | } 140 | 141 | return t, nil 142 | } 143 | 144 | func (t *Table) Has(column string) bool { 145 | for _, n := range t.Header { 146 | if n == column { 147 | return true 148 | } 149 | } 150 | 151 | return false 152 | } 153 | 154 | func (t *Table) AddColumn(columnName string, val interface{}) { 155 | switch v := val.(type) { 156 | case func() string: 157 | { 158 | for i := range t.Rows { 159 | t.Rows[i].Cells = append(t.Rows[i].Cells, StringToCell(v())) 160 | } 161 | } 162 | default: 163 | for i := range t.Rows { 164 | t.Rows[i].Cells = append(t.Rows[i].Cells, StringToCell(fmt.Sprintf("%v", val))) 165 | } 166 | } 167 | } 168 | 169 | func (t *Table) AddTag(tagFunc func() string) { 170 | for i := range t.Rows { 171 | tag := tagFunc() 172 | 
if tag != "" { 173 | t.Rows[i].Tags[tag] = true 174 | } 175 | } 176 | } 177 | 178 | func (t *Table) EachColumn(f func(columnName string, v *Cell)) { 179 | if len(t.Rows) == 0 { 180 | return 181 | } 182 | 183 | for i, h := range t.Rows[0].Cells { 184 | f(t.Header[i], h) 185 | } 186 | } 187 | 188 | func (t *Table) EachRow(f func(row *Row)) { 189 | if len(t.Rows) == 0 { 190 | return 191 | } 192 | 193 | for _, row := range t.Rows { 194 | f(row) 195 | } 196 | } 197 | 198 | func (t *Table) EachCell(cb func(rowNumber, columnNumber int, cell *Cell)) { 199 | for rowNumber, row := range t.Rows { 200 | for columnNumber, cell := range row.Cells { 201 | cb(rowNumber, columnNumber, cell) 202 | } 203 | } 204 | } 205 | 206 | type ToTensorOpts struct { 207 | TargetColumns []int 208 | SelectTags []string 209 | } 210 | 211 | func (opt *ToTensorOpts) setDefaults() { 212 | if opt.TargetColumns == nil { 213 | opt.TargetColumns = []int{} 214 | } 215 | } 216 | 217 | func (t *Table) ToTensors(opts ToTensorOpts) (x *tensor.Dense, y *tensor.Dense) { 218 | opts.setDefaults() 219 | 220 | indexes := make(map[int]map[string]int, len(t.ClassesByColumn)) 221 | 222 | for col, cats := range t.ClassesByColumn { 223 | sort.Strings(cats) 224 | 225 | indexes[col] = make(map[string]int, len(cats)) 226 | for i, c := range cats { 227 | indexes[col][c] = i 228 | } 229 | } 230 | 231 | targetColumnsIndexed := make(map[int]bool, len(opts.TargetColumns)) 232 | for _, col := range opts.TargetColumns { 233 | targetColumnsIndexed[col] = true 234 | } 235 | 236 | width := len(t.Header) - len(opts.TargetColumns) 237 | backing := make([]float32, 0, width*len(t.Rows)) 238 | targetBacking := make([]float32, 0, len(t.Rows)*len(opts.TargetColumns)) 239 | 240 | var rowsCount int 241 | 242 | for _, row := range t.Rows { 243 | include := row.HasAnyTag(opts.SelectTags) 244 | if !include { 245 | continue 246 | } 247 | 248 | for colIndex, cell := range row.Cells { 249 | var val float32 250 | 251 | if m, ok := indexes[colIndex]; 
ok { 252 | // this is a categorical column which is encoded as a number 253 | val = float32(m[fmt.Sprintf("%v", cell.V)]) 254 | } else if cell.Dtype == tensor.Int { 255 | val = float32(cell.V.(int)) 256 | } else if cell.Dtype == tensor.Float64 { 257 | val = float32(cell.V.(float64)) 258 | } else if cell.Dtype == tensor.Float32 { 259 | val = cell.V.(float32) 260 | } else { 261 | log.Panicf("unsupported type: %v", cell.Dtype) 262 | } 263 | 264 | if targetColumnsIndexed[colIndex] { 265 | targetBacking = append(targetBacking, val) 266 | } else { 267 | backing = append(backing, val) 268 | } 269 | } 270 | 271 | rowsCount++ 272 | } 273 | 274 | x = tensor.New( 275 | tensor.Of(tensor.Float32), 276 | tensor.WithShape(rowsCount, width), 277 | tensor.WithBacking(backing), 278 | ) 279 | 280 | if len(opts.TargetColumns) > 0 { 281 | y = tensor.New( 282 | tensor.Of(tensor.Float32), 283 | tensor.WithShape(rowsCount, len(opts.TargetColumns)), 284 | tensor.WithBacking(targetBacking), 285 | ) 286 | } 287 | 288 | return x, y 289 | } 290 | 291 | func (t *Table) CategoricalColumns(excludeColumns ...int) (columns []int, dimensions []int) { 292 | for col := range t.ClassesByColumn { 293 | if !isIn(col, excludeColumns) { 294 | columns = append(columns, col) 295 | } 296 | } 297 | 298 | sort.Ints(columns) 299 | 300 | for _, col := range columns { 301 | dimensions = append(dimensions, len(t.ClassesByColumn[col])) 302 | } 303 | 304 | return columns, dimensions 305 | } 306 | 307 | func RandValueIn(valueAndProbability map[string]float64) func() string { 308 | totalProb := 0.0 309 | 310 | values := make([]string, 0, len(valueAndProbability)) 311 | thresholds := make([]float64, 0, len(valueAndProbability)) 312 | 313 | for val, prob := range valueAndProbability { 314 | totalProb += prob 315 | 316 | thresholds = append(thresholds, totalProb) 317 | values = append(values, val) 318 | } 319 | 320 | if totalProb != 1 { 321 | log.Panicf("probabilities sum must be 1") 322 | } 323 | 324 | return func() 
string { 325 | r := rand.Float64() 326 | 327 | for i, p := range thresholds { 328 | if r <= p { 329 | return values[i] 330 | } 331 | } 332 | 333 | // this can't happen 334 | return "" 335 | } 336 | } 337 | 338 | func isIn(x int, a []int) bool { 339 | for _, v := range a { 340 | if x == v { 341 | return true 342 | } 343 | } 344 | 345 | return false 346 | } 347 | 348 | var ( 349 | _ fmt.Stringer = Cell{} 350 | ) 351 | -------------------------------------------------------------------------------- /tabnet/tab_net_no_embeddings.go: -------------------------------------------------------------------------------- 1 | package tabnet 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | 7 | "github.com/dcu/godl" 8 | "github.com/dcu/godl/activation" 9 | "gorgonia.org/gorgonia" 10 | "gorgonia.org/tensor" 11 | ) 12 | 13 | // TabNetNoEmbeddingsOpts contains parameters to configure the tab net algorithm 14 | type TabNetNoEmbeddingsOpts struct { 15 | OutputSize int 16 | InputSize int 17 | BatchSize int 18 | 19 | SharedBlocks int 20 | IndependentBlocks int 21 | DecisionSteps int 22 | PredictionLayerDim int 23 | AttentionLayerDim int 24 | 25 | MaskFunction activation.Function 26 | 27 | WithBias bool 28 | 29 | Gamma float64 30 | Momentum float64 31 | Epsilon float64 32 | VirtualBatchSize int 33 | WeightsInit, ScaleInit, BiasInit gorgonia.InitWFn 34 | } 35 | 36 | func (o *TabNetNoEmbeddingsOpts) setDefaults() { 37 | if o.SharedBlocks == 0 { 38 | o.SharedBlocks = 2 39 | } 40 | 41 | if o.IndependentBlocks == 0 { 42 | o.IndependentBlocks = 2 43 | } 44 | 45 | if o.DecisionSteps == 0 { 46 | o.DecisionSteps = 3 47 | } 48 | 49 | if o.PredictionLayerDim == 0 { 50 | o.PredictionLayerDim = 8 51 | } 52 | 53 | if o.AttentionLayerDim == 0 { 54 | o.AttentionLayerDim = 8 55 | } 56 | 57 | if o.Epsilon == 0.0 { 58 | o.Epsilon = 1e-15 59 | } 60 | 61 | if o.Gamma == 0.0 { 62 | o.Gamma = 1.3 63 | } 64 | 65 | if o.Momentum == 0 { 66 | o.Momentum = 0.02 67 | } 68 | 69 | if o.VirtualBatchSize == 0 { 70 | 
o.VirtualBatchSize = 128 71 | } 72 | } 73 | 74 | type TabNetNoEmbeddingsModule struct { 75 | model *godl.Model 76 | opts TabNetNoEmbeddingsOpts 77 | 78 | bn *godl.BatchNormModule 79 | initialSplitter *FeatureTransformerModule 80 | finalMapping *godl.LinearModule 81 | 82 | attentiveTransformers []*AttentiveTransformerModule 83 | featureTransformers []*FeatureTransformerModule 84 | } 85 | 86 | func (m *TabNetNoEmbeddingsModule) Forward(inputs ...*godl.Node) godl.Nodes { 87 | x := inputs[0] 88 | 89 | bn := m.bn.Forward(x)[0] 90 | ft := m.initialSplitter.Forward(bn)[0] 91 | 92 | xAttentiveLayer := gorgonia.Must(gorgonia.Slice(ft, nil, gorgonia.S(m.opts.PredictionLayerDim, ft.Shape()[1]))) 93 | 94 | gamma := gorgonia.NewConstant(float32(m.opts.Gamma)) 95 | epsilon := gorgonia.NewConstant(float32(m.opts.Epsilon)) 96 | 97 | loss := gorgonia.NewScalar(m.model.TrainGraph(), tensor.Float32, gorgonia.WithValue(float32(0.0)), gorgonia.WithName("TabNetLoss")) 98 | stepsCount := gorgonia.NewScalar(m.model.TrainGraph(), tensor.Float32, gorgonia.WithValue(float32(m.opts.DecisionSteps)), gorgonia.WithName("Steps")) 99 | 100 | prior := gorgonia.NewTensor(m.model.TrainGraph(), tensor.Float32, 2, gorgonia.WithShape(m.opts.BatchSize, m.opts.InputSize), gorgonia.WithInit(gorgonia.Ones()), gorgonia.WithName("Prior")) 101 | out := gorgonia.NewTensor(m.model.TrainGraph(), tensor.Float32, 2, gorgonia.WithShape(m.opts.BatchSize, m.opts.PredictionLayerDim), gorgonia.WithInit(gorgonia.Zeroes()), gorgonia.WithName("Output")) 102 | 103 | for i := 0; i < m.opts.DecisionSteps; i++ { 104 | attentiveTransformer := m.attentiveTransformers[i] 105 | featureTransformer := m.featureTransformers[i] 106 | 107 | result := attentiveTransformer.Forward(xAttentiveLayer, prior) 108 | 109 | mask := result[0] 110 | 111 | stepLoss := gorgonia.Must(gorgonia.Mean( 112 | gorgonia.Must(gorgonia.Sum( 113 | gorgonia.Must(gorgonia.HadamardProd( 114 | mask, 115 | gorgonia.Must(gorgonia.Log( 116 | 
gorgonia.Must(gorgonia.Add(mask, epsilon)), 117 | )), 118 | )), 119 | 1, 120 | )), 121 | )) 122 | 123 | // accum losses 124 | loss = gorgonia.Must(gorgonia.Add(loss, stepLoss)) 125 | 126 | // Update prior 127 | { 128 | prior = gorgonia.Must(gorgonia.HadamardProd(gorgonia.Must(gorgonia.Sub(gamma, mask)), prior)) 129 | } 130 | 131 | maskedX := gorgonia.Must(gorgonia.HadamardProd(mask, bn)) 132 | 133 | ds := featureTransformer.Forward(maskedX)[0] 134 | 135 | firstPart := gorgonia.Must(gorgonia.Slice(ds, nil, gorgonia.S(0, m.opts.PredictionLayerDim))) 136 | 137 | relu := gorgonia.Must(gorgonia.Rectify(firstPart)) 138 | 139 | out = gorgonia.Must(gorgonia.Add(out, relu)) 140 | 141 | xAttentiveLayer = gorgonia.Must(gorgonia.Slice(ds, nil, gorgonia.S(m.opts.PredictionLayerDim, ds.Shape()[1]))) 142 | } 143 | 144 | loss = gorgonia.Must(gorgonia.Div(loss, stepsCount)) 145 | result := m.finalMapping.Forward(out)[0] 146 | 147 | return godl.Nodes{result, loss} 148 | } 149 | 150 | // TabNetNoEmbeddings implements the tab net architecture 151 | func TabNetNoEmbeddings(nn *godl.Model, opts TabNetNoEmbeddingsOpts) *TabNetNoEmbeddingsModule { 152 | opts.setDefaults() 153 | 154 | bnLayer := godl.BatchNorm1d(nn, godl.BatchNormOpts{ 155 | ScaleInit: opts.ScaleInit, 156 | BiasInit: opts.BiasInit, 157 | InputSize: opts.InputSize, 158 | Momentum: 0.01, 159 | }) 160 | 161 | shared := make([]*godl.LinearModule, 0, opts.SharedBlocks) 162 | outputDim := 2 * (opts.PredictionLayerDim + opts.AttentionLayerDim) // double the size so we can take half and half 163 | 164 | { 165 | fcInput := opts.InputSize 166 | fcOutput := outputDim 167 | 168 | for i := 0; i < opts.SharedBlocks; i++ { 169 | sharedWeightsInit := opts.WeightsInit 170 | 171 | if sharedWeightsInit == nil { 172 | gain := math.Sqrt(float64(fcInput+fcOutput) / math.Sqrt(float64(fcInput))) 173 | 174 | sharedWeightsInit = gorgonia.GlorotN(gain) 175 | } 176 | 177 | shared = append(shared, godl.Linear(nn, godl.LinearOpts{ 178 | InputDimension: 
fcInput, 179 | OutputDimension: fcOutput, 180 | WeightsInit: sharedWeightsInit, 181 | WithBias: opts.WithBias, 182 | WeightsName: fmt.Sprintf("shared.weight.%d", i), 183 | BiasName: fmt.Sprintf("shared.bias.%d", i), 184 | })) 185 | 186 | fcInput = opts.PredictionLayerDim + opts.AttentionLayerDim 187 | } 188 | } 189 | 190 | // first step 191 | initialSplitter := FeatureTransformer(nn, FeatureTransformerOpts{ 192 | Shared: shared, 193 | VirtualBatchSize: opts.VirtualBatchSize, 194 | IndependentBlocks: opts.IndependentBlocks, 195 | InputDimension: opts.InputSize, 196 | OutputDimension: opts.AttentionLayerDim + opts.PredictionLayerDim, 197 | WeightsInit: opts.WeightsInit, 198 | WithBias: opts.WithBias, 199 | Momentum: opts.Momentum, 200 | }) 201 | 202 | featureTransformers := make([]*FeatureTransformerModule, 0, opts.DecisionSteps) 203 | attentiveTransformers := make([]*AttentiveTransformerModule, 0, opts.DecisionSteps) 204 | 205 | for i := 0; i < opts.DecisionSteps; i++ { 206 | featureTransformer := FeatureTransformer(nn, FeatureTransformerOpts{ 207 | Shared: shared, 208 | VirtualBatchSize: opts.VirtualBatchSize, 209 | InputDimension: opts.BatchSize, 210 | OutputDimension: opts.AttentionLayerDim + opts.PredictionLayerDim, 211 | IndependentBlocks: opts.IndependentBlocks, 212 | WeightsInit: opts.WeightsInit, 213 | WithBias: opts.WithBias, 214 | Momentum: opts.Momentum, 215 | }) 216 | featureTransformers = append(featureTransformers, featureTransformer) 217 | } 218 | 219 | for i := 0; i < opts.DecisionSteps; i++ { 220 | attentiveTransformer := AttentiveTransformer(nn, AttentiveTransformerOpts{ 221 | InputDimension: opts.AttentionLayerDim, // or prediction? 
222 | OutputDimension: opts.InputSize, 223 | Momentum: opts.Momentum, 224 | Epsilon: opts.Epsilon, 225 | VirtualBatchSize: opts.VirtualBatchSize, 226 | ScaleInit: opts.ScaleInit, 227 | BiasInit: opts.BiasInit, 228 | WeightsInit: opts.WeightsInit, 229 | Activation: opts.MaskFunction, 230 | WithBias: opts.WithBias, 231 | }) 232 | attentiveTransformers = append(attentiveTransformers, attentiveTransformer) 233 | } 234 | 235 | weightsInit := opts.WeightsInit 236 | if weightsInit == nil { 237 | gain := math.Sqrt(float64(opts.PredictionLayerDim+opts.OutputSize) / math.Sqrt(float64(4*opts.PredictionLayerDim))) 238 | 239 | weightsInit = gorgonia.GlorotN(gain) 240 | } 241 | 242 | finalMapping := godl.Linear(nn, godl.LinearOpts{ 243 | InputDimension: opts.PredictionLayerDim, 244 | OutputDimension: opts.OutputSize, 245 | WeightsInit: weightsInit, 246 | WeightsName: "FinalMapping", 247 | WithBias: opts.WithBias, 248 | }) 249 | 250 | return &TabNetNoEmbeddingsModule{ 251 | model: nn, 252 | opts: opts, 253 | bn: bnLayer, 254 | initialSplitter: initialSplitter, 255 | finalMapping: finalMapping, 256 | attentiveTransformers: attentiveTransformers, 257 | featureTransformers: featureTransformers, 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /vgg/vgg16.go: -------------------------------------------------------------------------------- 1 | package vgg 2 | 3 | import ( 4 | "github.com/dcu/godl" 5 | "github.com/dcu/godl/activation" 6 | "gorgonia.org/gorgonia" 7 | ) 8 | 9 | // Opts are the options for VGG 10 | type Opts struct { 11 | WithBias bool 12 | WeightsInit, BiasInit gorgonia.InitWFn 13 | Learnable bool 14 | 15 | PreTrained bool 16 | OnlyFeatureExtraction bool 17 | } 18 | 19 | type VGG16Module struct { 20 | model *godl.Model 21 | opts Opts 22 | layer godl.LayerType 23 | 24 | seq godl.ModuleList 25 | } 26 | 27 | func (m *VGG16Module) Name() string { 28 | return "VGG16" 29 | } 30 | 31 | func (m *VGG16Module) Forward(inputs 
...*godl.Node) godl.Nodes { 32 | if err := m.model.CheckArity(m.layer, inputs, 1); err != nil { 33 | panic(err) 34 | } 35 | 36 | x := inputs[0] 37 | 38 | return m.seq.Forward(x) 39 | } 40 | 41 | func VGG16Builder(opts Opts) func(*godl.Model) godl.Module { 42 | return func(m *godl.Model) godl.Module { 43 | return VGG16(m, opts) 44 | } 45 | } 46 | 47 | // VGG16 returns the layer for the VGG16 network 48 | func VGG16(m *godl.Model, opts Opts) *VGG16Module { 49 | lt := godl.AddLayer("vgg.VGG16") 50 | fixedWeights := false 51 | 52 | if opts.PreTrained { 53 | fileName := "vgg16.nn1.gz" 54 | if opts.OnlyFeatureExtraction { 55 | fileName = "vgg16_notop.nn1.gz" 56 | } 57 | 58 | err := m.Storage.LoadFile(fileName) 59 | if err != nil { 60 | panic(err) 61 | } 62 | 63 | fixedWeights = !opts.Learnable 64 | opts.WithBias = true 65 | } 66 | 67 | layers := []godl.Module{ 68 | Block(m, BlockOpts{ 69 | InputDimension: 3, 70 | OutputDimension: 64, 71 | Activation: gorgonia.Rectify, 72 | Dropout: 0.0, 73 | WithBias: opts.WithBias, 74 | BiasInit: opts.BiasInit, 75 | WeightsInit: opts.WeightsInit, 76 | WeightsName: "/block1_conv1/block1_conv1_W:0", 77 | BiasName: "/block1_conv1/block1_conv1_b:0", 78 | FixedWeights: fixedWeights, 79 | }), 80 | Block(m, BlockOpts{ 81 | InputDimension: 64, 82 | OutputDimension: 64, 83 | Activation: gorgonia.Rectify, 84 | Dropout: 0.0, 85 | WithBias: opts.WithBias, 86 | BiasInit: opts.BiasInit, 87 | WeightsInit: opts.WeightsInit, 88 | WithPooling: true, 89 | WeightsName: "/block1_conv2/block1_conv2_W:0", 90 | BiasName: "/block1_conv2/block1_conv2_b:0", 91 | FixedWeights: fixedWeights, 92 | }), 93 | Block(m, BlockOpts{ 94 | InputDimension: 64, 95 | OutputDimension: 128, 96 | Activation: gorgonia.Rectify, 97 | Dropout: 0.0, 98 | WithBias: opts.WithBias, 99 | BiasInit: opts.BiasInit, 100 | WeightsInit: opts.WeightsInit, 101 | WeightsName: "/block2_conv1/block2_conv1_W:0", 102 | BiasName: "/block2_conv1/block2_conv1_b:0", 103 | FixedWeights: fixedWeights, 104 | 
}), 105 | Block(m, BlockOpts{ 106 | InputDimension: 128, 107 | OutputDimension: 128, 108 | Activation: gorgonia.Rectify, 109 | Dropout: 0.0, 110 | WithBias: opts.WithBias, 111 | BiasInit: opts.BiasInit, 112 | WeightsInit: opts.WeightsInit, 113 | WithPooling: true, 114 | WeightsName: "/block2_conv2/block2_conv2_W:0", 115 | BiasName: "/block2_conv2/block2_conv2_b:0", 116 | FixedWeights: fixedWeights, 117 | }), 118 | Block(m, BlockOpts{ 119 | InputDimension: 128, 120 | OutputDimension: 256, 121 | Activation: gorgonia.Rectify, 122 | Dropout: 0.0, 123 | WithBias: opts.WithBias, 124 | BiasInit: opts.BiasInit, 125 | WeightsInit: opts.WeightsInit, 126 | WeightsName: "/block3_conv1/block3_conv1_W:0", 127 | BiasName: "/block3_conv1/block3_conv1_b:0", 128 | FixedWeights: fixedWeights, 129 | }), 130 | Block(m, BlockOpts{ 131 | InputDimension: 256, 132 | OutputDimension: 256, 133 | Activation: gorgonia.Rectify, 134 | Dropout: 0.0, 135 | WithBias: opts.WithBias, 136 | BiasInit: opts.BiasInit, 137 | WeightsInit: opts.WeightsInit, 138 | WeightsName: "/block3_conv2/block3_conv2_W:0", 139 | BiasName: "/block3_conv2/block3_conv2_b:0", 140 | FixedWeights: fixedWeights, 141 | }), 142 | Block(m, BlockOpts{ 143 | InputDimension: 256, 144 | OutputDimension: 256, 145 | Activation: gorgonia.Rectify, 146 | Dropout: 0.0, 147 | WithBias: opts.WithBias, 148 | BiasInit: opts.BiasInit, 149 | WeightsInit: opts.WeightsInit, 150 | WithPooling: true, 151 | WeightsName: "/block3_conv3/block3_conv3_W:0", 152 | BiasName: "/block3_conv3/block3_conv3_b:0", 153 | FixedWeights: fixedWeights, 154 | }), 155 | Block(m, BlockOpts{ 156 | InputDimension: 256, 157 | OutputDimension: 512, 158 | Activation: gorgonia.Rectify, 159 | Dropout: 0.0, 160 | WithBias: opts.WithBias, 161 | BiasInit: opts.BiasInit, 162 | WeightsInit: opts.WeightsInit, 163 | WeightsName: "/block4_conv1/block4_conv1_W:0", 164 | BiasName: "/block4_conv1/block4_conv1_b:0", 165 | FixedWeights: fixedWeights, 166 | }), 167 | Block(m, BlockOpts{ 168 
| InputDimension: 512, 169 | OutputDimension: 512, 170 | Activation: gorgonia.Rectify, 171 | Dropout: 0.0, 172 | WithBias: opts.WithBias, 173 | BiasInit: opts.BiasInit, 174 | WeightsInit: opts.WeightsInit, 175 | WeightsName: "/block4_conv2/block4_conv2_W:0", 176 | BiasName: "/block4_conv2/block4_conv2_b:0", 177 | FixedWeights: fixedWeights, 178 | }), 179 | Block(m, BlockOpts{ 180 | InputDimension: 512, 181 | OutputDimension: 512, 182 | Activation: gorgonia.Rectify, 183 | Dropout: 0.0, 184 | WithBias: opts.WithBias, 185 | BiasInit: opts.BiasInit, 186 | WeightsInit: opts.WeightsInit, 187 | WithPooling: true, 188 | WeightsName: "/block4_conv3/block4_conv3_W:0", 189 | BiasName: "/block4_conv3/block4_conv3_b:0", 190 | FixedWeights: fixedWeights, 191 | }), 192 | Block(m, BlockOpts{ 193 | InputDimension: 512, 194 | OutputDimension: 512, 195 | Activation: gorgonia.Rectify, 196 | Dropout: 0.0, 197 | WithBias: opts.WithBias, 198 | BiasInit: opts.BiasInit, 199 | WeightsInit: opts.WeightsInit, 200 | WeightsName: "/block5_conv1/block5_conv1_W:0", 201 | BiasName: "/block5_conv1/block5_conv1_b:0", 202 | FixedWeights: fixedWeights, 203 | }), 204 | Block(m, BlockOpts{ 205 | InputDimension: 512, 206 | OutputDimension: 512, 207 | Activation: gorgonia.Rectify, 208 | Dropout: 0.0, 209 | WithBias: opts.WithBias, 210 | BiasInit: opts.BiasInit, 211 | WeightsInit: opts.WeightsInit, 212 | WeightsName: "/block5_conv2/block5_conv2_W:0", 213 | BiasName: "/block5_conv2/block5_conv2_b:0", 214 | FixedWeights: fixedWeights, 215 | }), 216 | Block(m, BlockOpts{ 217 | InputDimension: 512, 218 | OutputDimension: 512, 219 | Activation: gorgonia.Rectify, 220 | Dropout: 0.0, 221 | WithBias: opts.WithBias, 222 | BiasInit: opts.BiasInit, 223 | WeightsInit: opts.WeightsInit, 224 | WithPooling: true, 225 | WeightsName: "/block5_conv3/block5_conv3_W:0", 226 | BiasName: "/block5_conv3/block5_conv3_b:0", 227 | FixedWeights: fixedWeights, 228 | }), 229 | } 230 | 231 | if !opts.OnlyFeatureExtraction { 232 | 
layers = append(layers, 233 | godl.Linear(m, godl.LinearOpts{ 234 | InputDimension: 25088, 235 | OutputDimension: 4096, 236 | WithBias: opts.WithBias, 237 | Activation: gorgonia.Rectify, 238 | Dropout: 0.0, 239 | WeightsInit: opts.WeightsInit, 240 | BiasInit: opts.BiasInit, 241 | WeightsName: "/fc1/fc1_W:0", 242 | BiasName: "/fc1/fc1_b:0", 243 | FixedWeights: fixedWeights, 244 | }), 245 | godl.Linear(m, godl.LinearOpts{ 246 | InputDimension: 4096, 247 | OutputDimension: 4096, 248 | WithBias: opts.WithBias, 249 | Activation: gorgonia.Rectify, 250 | Dropout: 0.0, 251 | WeightsInit: opts.WeightsInit, 252 | BiasInit: opts.BiasInit, 253 | WeightsName: "/fc2/fc2_W:0", 254 | BiasName: "/fc2/fc2_b:0", 255 | FixedWeights: fixedWeights, 256 | }), 257 | godl.Linear(m, godl.LinearOpts{ 258 | InputDimension: 4096, 259 | OutputDimension: 1000, 260 | WithBias: opts.WithBias, 261 | Activation: activation.SoftMax, 262 | Dropout: 0.0, 263 | WeightsInit: opts.WeightsInit, 264 | BiasInit: opts.BiasInit, 265 | WeightsName: "/predictions/predictions_W:0", 266 | BiasName: "/predictions/predictions_b:0", 267 | FixedWeights: fixedWeights, 268 | }), 269 | ) 270 | } 271 | 272 | seq := godl.Sequential(m, layers...) 
273 | 274 | return &VGG16Module{ 275 | model: m, 276 | opts: opts, 277 | layer: lt, 278 | seq: seq, 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /lstm/lstm.go: -------------------------------------------------------------------------------- 1 | package lstm 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dcu/godl" 7 | "github.com/dcu/godl/activation" 8 | "gorgonia.org/gorgonia" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | type MergeMode string 13 | 14 | const ( 15 | MergeModeConcat MergeMode = "concat" 16 | MergeModeSum MergeMode = "sum" 17 | MergeModeAverage MergeMode = "avg" 18 | MergeModeMul MergeMode = "mul" 19 | ) 20 | 21 | type LSTMOpts struct { 22 | InputDimension int 23 | HiddenSize int 24 | // Layers int 25 | Bidirectional bool 26 | MergeMode MergeMode 27 | 28 | WithBias bool 29 | WeightsInit, BiasInit gorgonia.InitWFn 30 | 31 | RecurrentActivation activation.Function 32 | Activation activation.Function 33 | } 34 | 35 | func (o *LSTMOpts) setDefaults() { 36 | if o.RecurrentActivation == nil { 37 | o.RecurrentActivation = activation.Sigmoid 38 | } 39 | 40 | if o.Activation == nil { 41 | o.Activation = activation.Tanh 42 | } 43 | 44 | if o.MergeMode == "" { 45 | o.MergeMode = MergeModeConcat 46 | } 47 | } 48 | 49 | type LSTMModule struct { 50 | model *godl.Model 51 | layer godl.LayerType 52 | opts LSTMOpts 53 | weights []lstmParams 54 | } 55 | 56 | func (m *LSTMModule) Name() string { 57 | return "LSTM" 58 | } 59 | 60 | func (m *LSTMModule) Forward(inputs ...*godl.Node) godl.Nodes { 61 | two := gorgonia.NewConstant(float32(2.0), gorgonia.WithName("two")) 62 | x := inputs[0] 63 | 64 | var ( 65 | prevHidden, prevCell *gorgonia.Node 66 | ) 67 | 68 | xShape := x.Shape() 69 | 70 | if xShape.Dims() != 3 { 71 | var err error 72 | newShape := tensor.Shape{m.weights[0].inputWeights.Shape()[0], m.opts.HiddenSize * 4, m.opts.InputDimension} 73 | 74 | x, err = gorgonia.Reshape(x, newShape) 75 | if err != nil { 76 | 
panic(fmt.Errorf("x %v cannot be reshaped as %v", xShape, newShape)) 77 | } 78 | } 79 | 80 | switch len(inputs) { 81 | case 1: 82 | batchSize := x.Shape()[1] 83 | 84 | dummyHidden := gorgonia.NewTensor(m.model.TrainGraph(), tensor.Float32, 3, gorgonia.WithShape(1, batchSize, m.opts.HiddenSize), gorgonia.WithInit(gorgonia.Zeroes()), gorgonia.WithName(string(m.layer)+"LSTMDummyHidden")) 85 | dummyCell := gorgonia.NewTensor(m.model.TrainGraph(), tensor.Float32, 3, gorgonia.WithShape(1, batchSize, m.opts.HiddenSize), gorgonia.WithInit(gorgonia.Zeroes()), gorgonia.WithName(string(m.layer)+"LSTMDummyCell")) 86 | 87 | prevHidden = dummyHidden 88 | prevCell = dummyCell 89 | case 3: 90 | prevHidden = inputs[1] 91 | prevCell = inputs[2] 92 | default: 93 | panic(fmt.Errorf("%v: invalid input size", m.layer)) 94 | } 95 | 96 | if m.opts.InputDimension != x.Shape()[2] { 97 | panic(fmt.Errorf("expecting input size = %v and got %v", m.opts.InputDimension, x.Shape()[2])) 98 | } 99 | 100 | if !m.opts.Bidirectional { 101 | return lstm(m.model, x, m.weights[0].inputWeights, prevHidden, m.weights[0].hiddenWeights, prevCell, m.weights[0].bias, false, m.opts) 102 | } 103 | 104 | x1 := gorgonia.Must(gorgonia.BatchedMatMul(x, m.weights[0].inputWeights)) 105 | forwardOutput := lstm(m.model, x1, m.weights[0].inputWeights, prevHidden, m.weights[0].hiddenWeights, prevCell, m.weights[0].bias, true, m.opts) 106 | 107 | x2 := gorgonia.Must(gorgonia.BatchedMatMul(x, m.weights[1].inputWeights)) 108 | x2 = gorgonia.Must(Reverse(x2, 0)) 109 | backwardOutput := lstm(m.model, x2, m.weights[1].inputWeights, prevHidden, m.weights[1].hiddenWeights, prevCell, m.weights[1].bias, true, m.opts) 110 | backwardOutputReversed := gorgonia.Must(Reverse(backwardOutput[0], 0)) 111 | 112 | var output *godl.Node 113 | 114 | switch m.opts.MergeMode { 115 | case MergeModeAverage: 116 | output = gorgonia.Must(gorgonia.Div(gorgonia.Must(gorgonia.Add(forwardOutput[0], backwardOutputReversed)), two)) 117 | case 
MergeModeConcat: 118 | output = gorgonia.Must(gorgonia.Concat(forwardOutput[0].Dims()-1, forwardOutput[0], backwardOutputReversed)) 119 | case MergeModeMul: 120 | output = gorgonia.Must(gorgonia.HadamardProd(forwardOutput[0], backwardOutputReversed)) 121 | case MergeModeSum: 122 | output = gorgonia.Must(gorgonia.Add(forwardOutput[0], backwardOutputReversed)) 123 | } 124 | 125 | hidden := gorgonia.Must(gorgonia.Concat(0, forwardOutput[1], backwardOutput[1])) 126 | cell := gorgonia.Must(gorgonia.Concat(0, forwardOutput[2], backwardOutput[2])) 127 | 128 | return godl.Nodes{output, hidden, cell} 129 | } 130 | 131 | func Reverse(x *gorgonia.Node, axis int) (*gorgonia.Node, error) { 132 | indicesA := make([]int, 0, x.Shape()[axis]) 133 | for i := x.Shape()[axis] - 1; i >= 0; i-- { 134 | indicesA = append(indicesA, i) 135 | } 136 | 137 | t := tensor.New(tensor.WithShape(len(indicesA)), tensor.WithBacking(indicesA)) 138 | indices := gorgonia.NewTensor(x.Graph(), tensor.Int, 1, gorgonia.WithShape(len(indicesA)), gorgonia.WithValue(t), gorgonia.WithName("indices-"+x.Name())) 139 | 140 | return gorgonia.ByIndices(x, indices, axis) 141 | } 142 | 143 | func LSTM(m *godl.Model, opts LSTMOpts) *LSTMModule { 144 | opts.setDefaults() 145 | lt := godl.AddLayer("LSTM") 146 | 147 | paramsCount := 1 148 | if opts.Bidirectional { 149 | paramsCount = 2 150 | } 151 | 152 | weights := buildParamsList(paramsCount, m, lt, opts) 153 | 154 | return &LSTMModule{ 155 | model: m, 156 | layer: lt, 157 | weights: weights, 158 | opts: opts, 159 | } 160 | } 161 | 162 | type lstmParams struct { 163 | inputWeights, hiddenWeights, bias *gorgonia.Node 164 | } 165 | 166 | func buildParamsList(count int, m *godl.Model, lt godl.LayerType, opts LSTMOpts) []lstmParams { 167 | list := make([]lstmParams, count) 168 | for i := 0; i < count; i++ { 169 | list[i] = newParams(m, lt, opts) 170 | } 171 | 172 | return list 173 | } 174 | 175 | func newParams(m *godl.Model, lt godl.LayerType, opts LSTMOpts) lstmParams { 
176 | inputWeightsSize := 1 177 | if opts.Bidirectional { 178 | inputWeightsSize = 2 179 | } 180 | 181 | inputWeights := m.AddWeights(lt, tensor.Shape{inputWeightsSize, opts.InputDimension, opts.HiddenSize * 4}, godl.NewWeightsOpts{ 182 | InitFN: opts.WeightsInit, 183 | }) 184 | hiddenWeights := m.AddWeights(lt, tensor.Shape{1, opts.HiddenSize, opts.HiddenSize * 4}, godl.NewWeightsOpts{ 185 | InitFN: opts.WeightsInit, 186 | }) 187 | 188 | var bias *gorgonia.Node 189 | 190 | if opts.WithBias { 191 | bias = m.AddBias(lt, tensor.Shape{1, 1, opts.HiddenSize * 4}, godl.NewWeightsOpts{ 192 | InitFN: opts.BiasInit, 193 | }) 194 | } 195 | 196 | return lstmParams{ 197 | inputWeights: inputWeights, 198 | hiddenWeights: hiddenWeights, 199 | bias: bias, 200 | } 201 | } 202 | 203 | func lstm(m *godl.Model, x, inputWeights, prevHidden, hiddenWeights, prevCell, bias *gorgonia.Node, withPrecomputedInput bool, opts LSTMOpts) godl.Nodes { 204 | seqs := x.Shape()[0] 205 | outputs := make([]*gorgonia.Node, seqs) 206 | 207 | var err error 208 | 209 | for seq := 0; seq < seqs; seq++ { 210 | seqX := gorgonia.Must(gorgonia.Slice(x, gorgonia.S(seq), nil, nil)) 211 | seqX = gorgonia.Must(gorgonia.Reshape(seqX, tensor.Shape{1, seqX.Shape()[0], seqX.Shape()[1]})) 212 | 213 | prevHidden, prevCell, err = lstmGate(m, seqX, inputWeights, prevHidden, hiddenWeights, prevCell, bias, withPrecomputedInput, opts) 214 | if err != nil { 215 | panic(err) 216 | } 217 | 218 | outputs[seq] = prevHidden 219 | } 220 | 221 | outputGate := gorgonia.Must(gorgonia.Concat(0, outputs...)) 222 | 223 | return godl.Nodes{outputGate, prevHidden, prevCell} 224 | } 225 | 226 | func lstmGate(m *godl.Model, seqX, inputWeights, prevHidden, hiddenWeights, prevCell, bias *gorgonia.Node, withPrecomputedInput bool, opts LSTMOpts) (*gorgonia.Node, *gorgonia.Node, error) { 227 | prevHidden = gorgonia.Must(gorgonia.BatchedMatMul(prevHidden, hiddenWeights)) 228 | 229 | if !withPrecomputedInput { 230 | seqX = 
gorgonia.Must(gorgonia.BatchedMatMul(seqX, inputWeights)) 231 | } 232 | 233 | gates := gorgonia.Must(gorgonia.Add(prevHidden, seqX)) 234 | 235 | if bias != nil { 236 | gates = gorgonia.Must(gorgonia.BroadcastAdd(gates, bias, nil, []byte{0, 1})) 237 | } 238 | 239 | inputGate := gorgonia.Must(gorgonia.Slice(gates, nil, nil, gorgonia.S(0, opts.HiddenSize))) 240 | inputGate = gorgonia.Must(opts.RecurrentActivation(inputGate)) 241 | 242 | forgetGate := gorgonia.Must(gorgonia.Slice(gates, nil, nil, gorgonia.S(opts.HiddenSize, opts.HiddenSize*2))) 243 | forgetGate = gorgonia.Must(opts.RecurrentActivation(forgetGate)) 244 | 245 | cellGate := gorgonia.Must(gorgonia.Slice(gates, nil, nil, gorgonia.S(opts.HiddenSize*2, opts.HiddenSize*3))) 246 | cellGate = gorgonia.Must(opts.Activation(cellGate)) 247 | 248 | outputGate := gorgonia.Must(gorgonia.Slice(gates, nil, nil, gorgonia.S(opts.HiddenSize*3, opts.HiddenSize*4))) 249 | 250 | outputGate = gorgonia.Must(opts.RecurrentActivation(outputGate)) 251 | 252 | retain, err := gorgonia.BroadcastHadamardProd(forgetGate, prevCell, nil, []byte{0}) 253 | if err != nil { 254 | return nil, nil, err 255 | } 256 | 257 | write, err := gorgonia.BroadcastHadamardProd(inputGate, cellGate, nil, []byte{0}) 258 | if err != nil { 259 | return nil, nil, err 260 | } 261 | 262 | prevCell, err = gorgonia.Add(retain, write) 263 | if err != nil { 264 | return nil, nil, err 265 | } 266 | 267 | cellTan := gorgonia.Must(activation.Tanh(prevCell)) 268 | 269 | prevHidden, err = gorgonia.BroadcastHadamardProd(outputGate, cellTan, nil, []byte{0}) 270 | if err != nil { 271 | return nil, nil, err 272 | } 273 | 274 | return prevHidden, prevCell, nil 275 | } 276 | 277 | var ( 278 | _ godl.Module = &LSTMModule{} 279 | ) 280 | --------------------------------------------------------------------------------