├── .github └── workflows │ └── go.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── errors.go ├── go.mod ├── go.sum ├── logo.png ├── model.go ├── model_test.go ├── onnx ├── graph_proto.go ├── onnx.proto3 └── onnx.proto3.pb.go ├── ops ├── abs │ ├── abs.go │ ├── abs_test.go │ └── versions.go ├── acos │ ├── acos.go │ ├── acos_test.go │ └── versions.go ├── acosh │ ├── acosh.go │ ├── acosh_test.go │ └── versions.go ├── activation.go ├── activation_test.go ├── add │ ├── add.go │ ├── add_test.go │ └── versions.go ├── and │ ├── and.go │ ├── and_test.go │ └── versions.go ├── argmax │ ├── argmax.go │ ├── argmax_test.go │ └── versions.go ├── asin │ ├── asin.go │ ├── asin_test.go │ └── versions.go ├── asinh │ ├── asinh.go │ ├── asinh_test.go │ └── versions.go ├── atan │ ├── atan.go │ ├── atan_test.go │ └── versions.go ├── atanh │ ├── atanh.go │ ├── atanh_test.go │ └── versions.go ├── base.go ├── binary_op.go ├── cast │ ├── cast.go │ ├── cast_test.go │ └── versions.go ├── concat │ ├── concat.go │ ├── concat_test.go │ └── versions.go ├── constant │ ├── constant.go │ ├── constant_11.go │ ├── constant_legacy.go │ ├── constant_test.go │ ├── constants.go │ └── versions.go ├── constantofshape │ ├── constant_of_shape.go │ ├── constant_of_shape_test.go │ └── versions.go ├── conv │ ├── conv.go │ ├── conv_test.go │ └── versions.go ├── convert.go ├── convert_test.go ├── cos │ ├── cos.go │ ├── cos_test.go │ └── versions.go ├── cosh │ ├── cosh.go │ ├── cosh_test.go │ └── versions.go ├── cumsum │ ├── cumsum.go │ ├── cumsum_test.go │ └── versions.go ├── div │ ├── div.go │ ├── div_13_test.go │ └── versions.go ├── equal │ ├── equal.go │ ├── equal_test.go │ └── versions.go ├── erf │ ├── erf.go │ ├── erf_test.go │ └── versions.go ├── errors.go ├── expand │ ├── expand.go │ ├── expand_test.go │ └── versions.go ├── fixtures.go ├── flatten │ ├── constants.go │ ├── flatten.go │ ├── flatten_test.go │ └── versions.go ├── gather │ ├── constants.go │ ├── gather.go │ ├── 
gather_test.go │ └── versions.go ├── gemm │ ├── constants.go │ ├── gemm.go │ ├── gemm_legacy.go │ ├── gemm_test.go │ └── versions.go ├── greater │ ├── greater.go │ ├── greater_test.go │ └── versions.go ├── greaterorequal │ ├── greater_or_equal.go │ ├── greater_or_equal_test.go │ └── versions.go ├── gru │ ├── gru.go │ ├── gru_test.go │ └── versions.go ├── identity │ ├── identity.go │ ├── identity_test.go │ └── versions.go ├── less │ ├── less.go │ ├── less_test.go │ └── versions.go ├── lessorequal │ ├── less_or_equal.go │ ├── less_or_equal_test.go │ └── versions.go ├── linearregressor │ ├── linear_regressor.go │ ├── linear_regressor_test.go │ └── versions.go ├── logsoftmax │ ├── logsoftmax.go │ ├── logsoftmax_test.go │ └── versions.go ├── lstm │ ├── lstm.go │ ├── lstm_test.go │ └── versions.go ├── matmul │ ├── matmul.go │ ├── matmul_test.go │ └── versions.go ├── mul │ ├── mul.go │ ├── mul_test.go │ └── versions.go ├── multidir_broadcast.go ├── multidir_broadcast_test.go ├── not │ ├── not.go │ ├── not_test.go │ └── versions.go ├── operator.go ├── or │ ├── or.go │ ├── or_test.go │ └── versions.go ├── pow │ ├── pow.go │ ├── pow_test.go │ └── versions.go ├── prelu │ ├── prelu.go │ ├── prelu_test.go │ └── versions.go ├── recurrent_utils.go ├── reducemax │ ├── constants.go │ ├── reduce_max.go │ ├── reduce_max_test.go │ └── versions.go ├── reducemean │ ├── reduce_mean.go │ ├── reduce_mean_test.go │ └── versions.go ├── reducemin │ ├── constants.go │ ├── reduce_min.go │ ├── reduce_min_test.go │ └── versions.go ├── relu │ ├── relu.go │ ├── relu_test.go │ └── versions.go ├── reshape │ ├── reshape.go │ ├── reshape_test.go │ └── versions.go ├── rnn │ ├── rnn.go │ ├── rnn_test.go │ └── versions.go ├── scaler │ ├── scaler.go │ ├── scaler_test.go │ └── versions.go ├── shape │ ├── shape.go │ ├── shape_test.go │ └── versions.go ├── sigmoid │ ├── sigmoid.go │ ├── sigmoid_test.go │ └── versions.go ├── sin │ ├── sin.go │ ├── sin_test.go │ └── versions.go ├── sinh │ ├── sinh.go │ ├── 
sinh_test.go │ └── versions.go ├── slice │ ├── slice.go │ ├── slice_1.go │ ├── slice_test.go │ └── versions.go ├── slicer.go ├── slicer_test.go ├── softmax │ ├── softmax.go │ ├── softmax_test.go │ └── versions.go ├── sqrt │ ├── sqrt.go │ ├── sqrt_test.go │ └── versions.go ├── squeeze │ ├── squeeze.go │ ├── squeeze_1.go │ ├── squeeze_11.go │ ├── squeeze_test.go │ └── versions.go ├── sub │ ├── sub.go │ ├── sub_test.go │ └── versions.go ├── tan │ ├── tan.go │ ├── tan_test.go │ └── versions.go ├── tanh │ ├── tanh.go │ ├── tanh_test.go │ └── versions.go ├── transpose │ ├── transpose.go │ ├── transpose_test.go │ └── versions.go ├── types.go ├── unidir_broadcast.go ├── unidir_broadcast_test.go ├── unsqueeze │ ├── unsqueeze.go │ ├── unsqueeze_1.go │ ├── unsqueeze_11.go │ ├── unsqueeze_test.go │ └── versions.go ├── utils.go ├── utils_test.go ├── validate_inputs.go ├── validate_inputs_test.go ├── where │ ├── versions.go │ ├── where.go │ └── where_test.go └── xor │ ├── versions.go │ ├── xor.go │ └── xor_test.go ├── ops_test.go ├── opset.go ├── opset_test.go └── sample_models ├── generate_sample_models.py ├── onnx_models ├── gru.onnx ├── mlp.onnx ├── mnist-8-opset13.onnx ├── ndm.onnx ├── nt_1.zip └── scaler.onnx ├── python_models ├── __init__.py ├── gru_torch.py ├── mlp_torch.py └── scaler_sklearn.py └── requirements.txt /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "develop" ] 9 | pull_request: 10 | branches: [ "develop" ] 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Go 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: 1.23 22 | 23 | - name: Install linter 24 | run: make install_lint 25 | 26 | - name: Lint 27
| run: make lint 28 | 29 | tests: 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v4 33 | 34 | - name: Set up Go 35 | uses: actions/setup-go@v5 36 | with: 37 | go-version: 1.23 38 | 39 | - name: Install dependencies 40 | run: make install 41 | 42 | - name: Install Gotestsum 43 | run: make install_gotestsum 44 | 45 | - name: Setup ONNX test data 46 | run: make test_data 47 | 48 | - name: Tests 49 | run: make test 50 | 51 | build_amd64: 52 | runs-on: ubuntu-latest 53 | steps: 54 | - uses: actions/checkout@v4 55 | 56 | - name: Set up Go 57 | uses: actions/setup-go@v5 58 | with: 59 | go-version: 1.23 60 | 61 | - name: Build amd64 62 | run: make build_amd64 63 | 64 | build_arm64: 65 | runs-on: ubuntu-latest 66 | steps: 67 | - uses: actions/checkout@v4 68 | 69 | - name: Set up Go 70 | uses: actions/setup-go@v5 71 | with: 72 | go-version: 1.23 73 | 74 | - name: Build arm64 75 | run: make build_arm64 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | test_data/ 2 | .coverage.out 3 | 4 | sample_models/.env 5 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - asasalint 5 | - asciicheck 6 | - bidichk 7 | - decorder 8 | - durationcheck 9 | - errcheck 10 | - errchkjson 11 | - errname 12 | - errorlint 13 | - exhaustive 14 | - exportloopref 15 | - forcetypeassert 16 | - gochecknoinits 17 | - goconst 18 | - gocritic 19 | - godot 20 | - godox 21 | - err113 22 | - goprintffuncname 23 | - govet 24 | - ineffassign 25 | - makezero 26 | - misspell 27 | - nilerr 28 | - nlreturn 29 | - prealloc 30 | - predeclared 31 | - reassign 32 | - revive 33 | - staticcheck 34 | - typecheck 35 | - unconvert 36 | - unparam 37 | - unused 38 | - usestdlibvars 39 | -
whitespace 40 | - wsl 41 | linters-settings: 42 | gomnd: 43 | ignored-functions: 44 | - "strconv.ParseInt" 45 | - "strconv.ParseFloat" 46 | - "strconv.FormatInt" 47 | - "strconv.FormatFloat" 48 | gocritic: 49 | disabled-checks: 50 | # In the world of AI, tensors are often denoted with a capital letter. 51 | # We want to adopt the Go style guide as much as possible, but we also want 52 | # to be able to easily show when a variable is a Tensor. So we chose to 53 | # disable captLocal. Note that any other parameter should use lowercase letters. 54 | - "captLocal" 55 | issues: 56 | max-issues-per-linter: 0 57 | max-same-issues: 0 58 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help lint test test_ci test_data test_html install install_lint install_gotestsum build_all build_amd64 build_arm64 # test_data also names a directory, so it must be phony. 2 | 3 | VERSION=$(shell git describe --always --tags --dirty) 4 | LDFLAGS=-ldflags "-s -w -X main.Version=${VERSION}" 5 | TEST=$(shell go list ./... | grep -v /onnx/) 6 | 7 | BUILD_PARAMS=CGO_ENABLED=0 8 | 9 | 10 | define echotask 11 | @tput setaf 6 12 | @echo -n " $1" 13 | @tput setaf 3 14 | @echo -n " - " 15 | @tput sgr0 16 | @echo $2 17 | endef 18 | 19 | help: 20 | $(call echotask,"help","Shows this page.") 21 | $(call echotask,"lint","Runs the GOLANGCI linter.") 22 | $(call echotask,"test","Runs the Go tests.") 23 | $(call echotask,"test_data","Downloads data for the ONNX test suite.") 24 | $(call echotask,"install","Install project dependencies.") 25 | $(call echotask,"install_lint","Install the Go linter.") 26 | $(call echotask,"install_gotestsum","Install the Go test runner.") 27 | $(call echotask,"build_all","Builds the project for both amd64 and arm64") 28 | $(call echotask,"build_amd64","Go amd64 build of the project.") 29 | $(call echotask,"build_arm64","Go arm64 build of the project.") 30 | 31 | lint: ## Run various linters. 32 | @golangci-lint run --timeout=1m --config .golangci.yml 33 | 34 | test: ## Run tests using gotestsum.
35 | @ ${BUILD_PARAMS} gotestsum \ 36 | --format=dots-v2 -- \ 37 | -timeout=30000ms \ 38 | -covermode=set \ 39 | -coverprofile=.coverage.out ${TEST} 40 | 41 | test_ci: ## Run tests using normal test runner for ci output. 42 | @ ${BUILD_PARAMS} go test \ 43 | -coverprofile .coverage.out ${TEST} && go tool cover -func=.coverage.out 44 | 45 | test_data: ## Creates test data from the ONNX test module. 46 | rm -Rf ./test_data temp_onnx; mkdir ./test_data; touch ./test_data/ 47 | git clone --depth 1 --branch v1.17.0 https://github.com/onnx/onnx.git temp_onnx 48 | cp -r temp_onnx/onnx/backend/test/data/node/* ./test_data 49 | rm -Rf temp_onnx 50 | 51 | test_html: ## Run tests showing coverage in the browser. 52 | @$(MAKE) test 53 | @ go tool cover -html=.coverage.out 54 | 55 | install: ## Install project with its dependencies. 56 | go get ./... 57 | @go mod download 58 | 59 | install_lint: ## Install the linter. 60 | curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh \ 61 | | sh -s -- -b $(shell go env GOPATH)/bin v1.61.0 62 | 63 | install_gotestsum: ## Install a tool for prettier test output. 64 | curl -sfL https://github.com/gotestyourself/gotestsum/releases/download/v1.9.0/gotestsum_1.9.0_linux_amd64.tar.gz \ 65 | | tar -C $(shell go env GOPATH)/bin -zxf - gotestsum 66 | 67 | build_all: build_amd64 build_arm64 68 | 69 | build_amd64: 70 | @GOARCH=amd64 GOOS=linux ${BUILD_PARAMS} go build ${LDFLAGS} ./... 71 | 72 | build_arm64: 73 | @GOARCH=arm64 GOOS=linux ${BUILD_PARAMS} go build ${LDFLAGS} ./...
74 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package gonnx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/advancedclimatesystems/gonnx/onnx" 8 | ) 9 | 10 | var errModel = errors.New("gonnx model error") 11 | 12 | // InvalidShapeError reports a mismatch between an expected ONNX shape and an actual shape. 13 | type InvalidShapeError struct { 14 | expected onnx.Shape 15 | actual []int 16 | } 17 | 18 | // Error implements the error interface. 19 | func (i InvalidShapeError) Error() string { 20 | return fmt.Sprintf("invalid shape error expected: %v actual %v", i.expected, i.actual) 21 | } 22 | 23 | // ErrInvalidShape returns an InvalidShapeError holding the expected and actual shapes. 24 | func ErrInvalidShape(expected onnx.Shape, actual []int) error { 25 | return InvalidShapeError{ 26 | expected: expected, 27 | actual: actual, 28 | } 29 | } 30 | 31 | // ErrModel is used when an error occurred while setting up or running ONNX models. 32 | // The user can specify a formatted message using the standard formatting rules. 33 | // The returned error wraps errModel, so errors.Is(err, errModel) reports true. 34 | func ErrModel(format string, a ...any) error { 35 | return fmt.Errorf("%w: %s", errModel, fmt.Sprintf(format, a...)) 36 | } 37 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/advancedclimatesystems/gonnx 2 | 3 | go 1.23 4 | 5 | require ( 6 | github.com/pkg/errors v0.9.1 7 | github.com/stretchr/testify v1.8.1 8 | google.golang.org/protobuf v1.31.0 9 | gorgonia.org/tensor v0.9.24 10 | ) 11 | 12 | require ( 13 | github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect 14 | github.com/chewxy/hm v1.0.0 // indirect 15 | github.com/chewxy/math32 v1.10.1 // indirect 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/gogo/protobuf v1.3.2 // indirect 18 | github.com/golang/protobuf v1.5.3 // indirect 19 | github.com/google/flatbuffers v23.5.26+incompatible // indirect 20 | github.com/pmezard/go-difflib v1.0.0 // indirect 21 | github.com/xtgo/set v1.0.0 // indirect
22 | go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect 23 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 24 | gonum.org/v1/gonum v0.14.0 // indirect 25 | gopkg.in/yaml.v3 v3.0.1 // indirect 26 | gorgonia.org/vecf32 v0.9.0 // indirect 27 | gorgonia.org/vecf64 v0.9.0 // indirect 28 | ) 29 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/logo.png -------------------------------------------------------------------------------- /ops/abs/abs.go: -------------------------------------------------------------------------------- 1 | package abs 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var absTypeConstraint = [][]tensor.Dtype{ 10 | {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | // Abs represents the ONNX abs operator. 14 | type Abs struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newAbs creates a new abs operator. 19 | func newAbs(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 20 | return &Abs{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraint, 26 | "abs", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the abs operator. 32 | func (a *Abs) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the abs operator.
37 | func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | // Unsigned values are already non-negative, so |x| == x: copy them element-wise 44 | // instead of calling tensor.Abs. NOTE(review): gorgonia's tensor.Abs appears to 45 | // only accept signed dtypes, which the type constraint above does not guarantee. 46 | switch inputs[0].Dtype() { 47 | case tensor.Uint8: 48 | out, err = inputs[0].Apply(unsignedAbs[uint8]) 49 | case tensor.Uint16: 50 | out, err = inputs[0].Apply(unsignedAbs[uint16]) 51 | case tensor.Uint32: 52 | out, err = inputs[0].Apply(unsignedAbs[uint32]) 53 | case tensor.Uint64: 54 | out, err = inputs[0].Apply(unsignedAbs[uint64]) 55 | default: 56 | out, err = tensor.Abs(inputs[0]) 57 | } 58 | 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | return []tensor.Tensor{out}, nil 64 | } 65 | 66 | // unsignedAbs is the absolute value for unsigned integers, which is the identity. 67 | func unsignedAbs[T uint8 | uint16 | uint32 | uint64](x T) T { 68 | return x 69 | } 70 | -------------------------------------------------------------------------------- /ops/abs/versions.go: -------------------------------------------------------------------------------- 1 | package abs 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var absVersions = ops.OperatorVersions{ 8 | 6: ops.NewOperatorConstructor(newAbs, 6, absTypeConstraint), 9 | 13: ops.NewOperatorConstructor(newAbs, 13, absTypeConstraint), 10 | } 11 | 12 | // GetVersions returns the abs operator constructors keyed by operator version. 13 | func GetVersions() ops.OperatorVersions { 14 | return absVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/acos/acos.go: -------------------------------------------------------------------------------- 1 | package acos 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | // Acos represents the ONNX acos operator. 12 | type Acos struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newAcos creates a new acos operator. 17 | func newAcos() ops.Operator { 18 | return &Acos{ 19 | BaseOperator: ops.NewBaseOperator( 20 | 7, 21 | 1, 22 | 1, 23 | [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, 24 | "acos", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the acos operator. 30 | func (c *Acos) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the acos operator.
35 | func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | var ( 37 | out tensor.Tensor 38 | err error 39 | ) 40 | 41 | switch inputs[0].Dtype() { 42 | case tensor.Float32: 43 | out, err = inputs[0].Apply(acos[float32]) 44 | case tensor.Float64: 45 | out, err = inputs[0].Apply(acos[float64]) 46 | default: 47 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) 48 | } 49 | 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return []tensor.Tensor{out}, nil 55 | } 56 | 57 | func acos[T ops.FloatType](x T) T { 58 | return T(math.Acos(float64(x))) 59 | } 60 | -------------------------------------------------------------------------------- /ops/acos/acos_test.go: -------------------------------------------------------------------------------- 1 | package acos 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAcosInit(t *testing.T) { 12 | c := &Acos{} 13 | 14 | // since 'acos' does not have any attributes we pass in nil. This should not 15 | // fail initializing the acos. 
16 | err := c.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestAcos(t *testing.T) { 21 | tests := []struct { 22 | acos ops.Operator 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | newAcos(), 29 | []float32{-1, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{3.1415927, 3.1415927, 1.5707964, 0}, 32 | }, 33 | { 34 | newAcos(), 35 | []float32{1, 0.5, 0.0, -0.5}, 36 | []int{1, 4}, 37 | []float32{0, 1.0471976, 1.5707964, 2.0943952}, 38 | }, 39 | { 40 | newAcos(), 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | res, err := test.acos.Apply(inputs) 53 | assert.Nil(t, err) 54 | 55 | assert.Nil(t, err) 56 | assert.Equal(t, test.expected, res[0].Data()) 57 | } 58 | } 59 | 60 | func TestInputValidationAcos(t *testing.T) { 61 | tests := []struct { 62 | inputs []tensor.Tensor 63 | err error 64 | }{ 65 | { 66 | []tensor.Tensor{ 67 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 68 | }, 69 | nil, 70 | }, 71 | { 72 | []tensor.Tensor{ 73 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 74 | }, 75 | nil, 76 | }, 77 | { 78 | []tensor.Tensor{}, 79 | ops.ErrInvalidInputCount(0, ops.NewBaseOperator(7, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acos")), 80 | }, 81 | { 82 | []tensor.Tensor{ 83 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 84 | }, 85 | ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(7, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acos")), 86 | }, 87 | } 88 | 89 | for _, test := range tests { 90 | acos := newAcos() 91 | validated, err := acos.ValidateInputs(test.inputs) 92 | 93 | assert.Equal(t, test.err, err) 94 | 95 | if test.err == nil { 96 | assert.Equal(t, test.inputs, validated) 97 | } 98 | } 99 | } 100 | 
-------------------------------------------------------------------------------- /ops/acos/versions.go: -------------------------------------------------------------------------------- 1 | package acos 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var acosVersions = ops.OperatorVersions{ 8 | 7: newAcos, 9 | } 10 | 11 | // GetVersions returns the acos operator constructors keyed by operator version. 12 | func GetVersions() ops.OperatorVersions { 13 | return acosVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/acosh/acosh.go: -------------------------------------------------------------------------------- 1 | package acosh 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | // Acosh represents the ONNX acosh operator. 12 | type Acosh struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newAcosh creates a new acosh operator. 17 | func newAcosh() ops.Operator { 18 | return &Acosh{ 19 | BaseOperator: ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh"), 20 | } 21 | } 22 | 23 | // Init initializes the acosh operator. 24 | func (c *Acosh) Init(*onnx.NodeProto) error { 25 | return nil 26 | } 27 | 28 | // Apply applies the acosh operator.
29 | func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 30 | var ( 31 | out tensor.Tensor 32 | err error 33 | ) 34 | 35 | switch inputs[0].Dtype() { 36 | case tensor.Float32: 37 | out, err = inputs[0].Apply(acosh[float32]) 38 | case tensor.Float64: 39 | out, err = inputs[0].Apply(acosh[float64]) 40 | default: 41 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) 42 | } 43 | 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | return []tensor.Tensor{out}, nil 49 | } 50 | 51 | func acosh[T ops.FloatType](x T) T { 52 | return T(math.Acosh(float64(x))) 53 | } 54 | -------------------------------------------------------------------------------- /ops/acosh/acosh_test.go: -------------------------------------------------------------------------------- 1 | package acosh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAcosh9Init(t *testing.T) { 12 | c := &Acosh{} 13 | 14 | // since 'acosh' does not have any attributes we pass in nil. This should not 15 | // fail initializing the acosh. 
16 | err := c.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestAcosh9(t *testing.T) { 21 | tests := []struct { 22 | acosh ops.Operator 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | newAcosh(), 29 | []float32{1, 2, 3, 4}, 30 | []int{2, 2}, 31 | []float32{0, 1.316958, 1.7627472, 2.063437}, 32 | }, 33 | { 34 | newAcosh(), 35 | []float32{1, 2, 3, 4}, 36 | []int{1, 4}, 37 | []float32{0, 1.316958, 1.7627472, 2.063437}, 38 | }, 39 | { 40 | newAcosh(), 41 | []float32{2, 2, 2, 2}, 42 | []int{1, 4}, 43 | []float32{1.316958, 1.316958, 1.316958, 1.316958}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | res, err := test.acosh.Apply(inputs) 53 | assert.Nil(t, err) 54 | 55 | assert.Nil(t, err) 56 | assert.Equal(t, test.expected, res[0].Data()) 57 | } 58 | } 59 | 60 | func TestInputValidationAcosh(t *testing.T) { 61 | tests := []struct { 62 | inputs []tensor.Tensor 63 | err error 64 | }{ 65 | { 66 | []tensor.Tensor{ 67 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 68 | }, 69 | nil, 70 | }, 71 | { 72 | []tensor.Tensor{ 73 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 74 | }, 75 | nil, 76 | }, 77 | { 78 | []tensor.Tensor{}, 79 | ops.ErrInvalidInputCount(0, ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), 80 | }, 81 | { 82 | []tensor.Tensor{ 83 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 84 | }, 85 | ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), 86 | }, 87 | } 88 | 89 | for _, test := range tests { 90 | acosh := newAcosh() 91 | validated, err := acosh.ValidateInputs(test.inputs) 92 | 93 | assert.Equal(t, test.err, err) 94 | 95 | if test.err == nil { 96 | assert.Equal(t, test.inputs, validated) 97 | } 98 | } 99 | } 100 | 
-------------------------------------------------------------------------------- /ops/acosh/versions.go: -------------------------------------------------------------------------------- 1 | package acosh 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var acoshVersions = ops.OperatorVersions{ 8 | 9: newAcosh, 9 | } 10 | 11 | // GetVersions returns the acosh operator constructors keyed by operator version. 12 | func GetVersions() ops.OperatorVersions { 13 | return acoshVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/activation.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "gorgonia.org/tensor" 5 | ) 6 | 7 | // Activation is an activation function. 8 | type Activation func(n tensor.Tensor) (tensor.Tensor, error) 9 | 10 | // activations maps strings to the activation function. This is 11 | // used by operators like LSTM, GRU and RNN. 12 | var activations = map[string]Activation{ 13 | "tanh": Tanh, 14 | "sigmoid": Sigmoid, 15 | "relu": ReLU, 16 | } 17 | 18 | // GetActivation returns the activation function registered under the given name, or an error if it is not implemented. 19 | func GetActivation(activation string) (Activation, error) { 20 | if a, ok := activations[activation]; ok { 21 | return a, nil 22 | } 23 | 24 | return nil, ErrActivationNotImplemented(activation) 25 | } 26 | 27 | // Tanh performs the tanh operation on a tensor. 28 | func Tanh(X tensor.Tensor) (tensor.Tensor, error) { 29 | return tensor.Tanh(X) 30 | } 31 | 32 | // Sigmoid performs the sigmoid operation 1 / (1 + exp(-X)) on a tensor.
33 | func Sigmoid(X tensor.Tensor) (tensor.Tensor, error) { 34 | negX, err := tensor.Neg(X) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | expX, err := tensor.Exp(negX) 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | typedOne, err := GetValueAsTensorType(1.0, expX.Dtype()) 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | denominatorX, err := tensor.Add(typedOne, expX) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return tensor.Div(typedOne, denominatorX) 55 | } 56 | 57 | // ReLU performs the ReLU operation on a tensor, computed as X * (X > 0). NOTE(review): an input of -Inf yields NaN (-Inf * 0) rather than 0. 58 | func ReLU(X tensor.Tensor) (tensor.Tensor, error) { 59 | typedZero, err := GetValueAsTensorType(0.0, X.Dtype()) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | comparison, err := tensor.Gt(X, typedZero, tensor.AsSameType()) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | return tensor.Mul(X, comparison) 70 | } 71 | -------------------------------------------------------------------------------- /ops/activation_test.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | func TestTanhActivation(t *testing.T) { 11 | tIn := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 2, 3, 4})) 12 | tOut, err := Tanh(tIn) 13 | 14 | assert.Nil(t, err) 15 | assert.Equal(t, []float32{0.7615942, 0.9640276, 0.9950548, 0.9993293}, tOut.Data()) 16 | } 17 | 18 | func TestSigmoidActivation(t *testing.T) { 19 | tIn := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 2, 3, 4})) 20 | tOut, err := Sigmoid(tIn) 21 | 22 | assert.Nil(t, err) 23 | assert.Equal(t, []float32{0.7310586, 0.880797, 0.95257413, 0.98201376}, tOut.Data()) 24 | } 25 | -------------------------------------------------------------------------------- /ops/add/add.go: -------------------------------------------------------------------------------- 1
| package add 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var addTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 12 | } 13 | 14 | // Add represents the ONNX add operator. 15 | type Add struct { 16 | ops.BaseOperator 17 | } 18 | 19 | // newAdd creates a new add operator. 20 | func newAdd(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Add{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 2, 25 | 2, 26 | typeConstraints, 27 | "add", 28 | ), 29 | } 30 | } 31 | 32 | // Init initializes the add operator. 33 | func (a *Add) Init(*onnx.NodeProto) error { 34 | return nil 35 | } 36 | 37 | // Apply applies the add operator. 38 | func (a *Add) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 39 | return ops.ApplyBinaryOperation( 40 | inputs[0], 41 | inputs[1], 42 | ops.Add, 43 | ops.MultidirectionalBroadcasting, 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /ops/add/versions.go: -------------------------------------------------------------------------------- 1 | package add 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var addVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newAdd, 7, addTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newAdd, 13, addTypeConstraints), 10 | } 11 | 12 | // GetVersions returns the add operator constructors keyed by operator version. 13 | func GetVersions() ops.OperatorVersions { 14 | return addVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/and/and.go: -------------------------------------------------------------------------------- 1 | package and 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 |
"github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var andTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} 10 | 11 | // And represents the ONNX and operator. 12 | type And struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newAnd creates a new and operator. 17 | func newAnd(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &And{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraints, 24 | "and", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the and operator. 30 | func (a *And) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the and operator. 35 | func (a *And) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | return ops.ApplyBinaryOperation( 37 | inputs[0], 38 | inputs[1], 39 | ops.And, 40 | ops.MultidirectionalBroadcasting, 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ops/and/and_test.go: -------------------------------------------------------------------------------- 1 | package and 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAndInit(t *testing.T) { 12 | a := &And{} 13 | 14 | // since 'and' does not have any attributes we pass in nil. This should not 15 | // fail initializing the and.
16 | err := a.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestAnd(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backings [][]bool 24 | shapes [][]int 25 | expected []bool 26 | }{ 27 | { 28 | 7, 29 | [][]bool{{true, false, true, false}, {true, true, true, false}}, 30 | [][]int{{2, 2}, {2, 2}}, 31 | []bool{true, false, true, false}, 32 | }, 33 | { 34 | 7, 35 | [][]bool{{true, false, true, false}, {true, false}}, 36 | [][]int{{2, 2}, {1, 2}}, 37 | []bool{true, false, true, false}, 38 | }, 39 | { 40 | 7, 41 | [][]bool{{true, false, true, false}, {true, false}}, 42 | [][]int{{2, 2}, {2, 1}}, 43 | []bool{true, false, false, false}, 44 | }, 45 | { 46 | 7, 47 | [][]bool{{true, false, true, false, true, false}, {false, false}}, 48 | [][]int{{3, 2}, {1, 2}}, 49 | []bool{false, false, false, false, false, false}, 50 | }, 51 | } 52 | 53 | for _, test := range tests { 54 | inputs := []tensor.Tensor{ 55 | ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), 56 | ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), 57 | } 58 | 59 | and := andVersions[test.version]() 60 | 61 | res, err := and.Apply(inputs) 62 | assert.Nil(t, err) 63 | 64 | assert.Len(t, res, 1) 65 | assert.Equal(t, test.expected, res[0].Data()) 66 | } 67 | } 68 | 69 | func TestInputValidationAnd(t *testing.T) { 70 | tests := []struct { 71 | version int64 72 | inputs []tensor.Tensor 73 | err error 74 | }{ 75 | { 76 | 7, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 79 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 80 | }, 81 | nil, 82 | }, 83 | { 84 | 7, 85 | []tensor.Tensor{ 86 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 87 | }, 88 | ops.ErrInvalidInputCount(1, and7BaseOpFixture()), 89 | }, 90 | { 91 | 7, 92 | []tensor.Tensor{ 93 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 94 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 95 | }, 96 | ops.ErrInvalidInputType(1, "int",
and7BaseOpFixture()), 97 | }, 98 | } 99 | 100 | for _, test := range tests { 101 | and := andVersions[test.version]() 102 | validated, err := and.ValidateInputs(test.inputs) 103 | 104 | assert.Equal(t, test.err, err) 105 | 106 | if test.err == nil { 107 | assert.Equal(t, test.inputs, validated) 108 | } 109 | } 110 | } 111 | 112 | func and7BaseOpFixture() ops.BaseOperator { 113 | return ops.NewBaseOperator(7, 2, 2, andTypeConstraints, "and") 114 | } 115 | -------------------------------------------------------------------------------- /ops/and/versions.go: -------------------------------------------------------------------------------- 1 | package and 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var andVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newAnd, 7, andTypeConstraints), 9 | } 10 | 11 | // GetVersions returns the and operator constructors keyed by operator version. 12 | func GetVersions() ops.OperatorVersions { 13 | return andVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/argmax/versions.go: -------------------------------------------------------------------------------- 1 | package argmax 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var argMaxVersions = ops.OperatorVersions{ 8 | 11: ops.NewOperatorConstructor(newArgMax, 11, argMaxTypeConstraints), 9 | 12: ops.NewOperatorConstructor(newArgMax, 12, argMaxTypeConstraints), 10 | 13: ops.NewOperatorConstructor(newArgMax, 13, argMaxTypeConstraints), 11 | } 12 | 13 | // GetVersions returns the argmax operator constructors keyed by operator version. 14 | func GetVersions() ops.OperatorVersions { 15 | return argMaxVersions 16 | } 17 | -------------------------------------------------------------------------------- /ops/asin/asin.go: -------------------------------------------------------------------------------- 1 | package asin 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var asinTypeConstraints =
[][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Asin represents the ONNX asin operator. 14 | type Asin struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newAsin creates a new asin operator. 19 | func newAsin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Asin{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "asin", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the asin operator. 32 | func (s *Asin) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the asin operator. 37 | func (s *Asin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(asin[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(asin[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | 59 | // asin computes the arc sine of x in float64 precision and converts the result back to T. 60 | func asin[T ops.FloatType](x T) T { 61 | return T(math.Asin(float64(x))) 62 | } 63 | -------------------------------------------------------------------------------- /ops/asin/asin_test.go: -------------------------------------------------------------------------------- 1 | package asin 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAsinInit(t *testing.T) { 12 | s := &Asin{} 13 | 14 | // since 'asin' does not have any attributes we pass in nil. This should not 15 | // fail initializing the asin.
16 | err := s.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestAsin(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 7, 29 | []float32{-1, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{-1.5707964, -1.5707964, 0, 1.5707964}, 32 | }, 33 | { 34 | 7, 35 | []float32{1, 0.5, 0.0, -0.5}, 36 | []int{1, 4}, 37 | []float32{1.5707964, 0.5235988, 0, -0.5235988}, 38 | }, 39 | { 40 | 7, 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{-1.5707964, -1.5707964, -1.5707964, -1.5707964}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | asin := asinVersions[test.version]() 53 | 54 | res, err := asin.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | 62 | func TestInputValidationAsin(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 7, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 7, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 7, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, asin7BaseOpFixture()), 86 | }, 87 | { 88 | 7, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", asin7BaseOpFixture()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | asin := asinVersions[test.version]() 98 | validated, err := asin.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | 108 | func asin7BaseOpFixture() ops.BaseOperator { 109 | return ops.NewBaseOperator(7, 1, 1, 
asinTypeConstraints, "asin") 110 | } 111 | -------------------------------------------------------------------------------- /ops/asin/versions.go: -------------------------------------------------------------------------------- 1 | package asin 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var asinVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newAsin, 7, asinTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return asinVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/asinh/asinh.go: -------------------------------------------------------------------------------- 1 | package asinh 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var asinhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Asinh represents the ONNX asinh operator. 14 | type Asinh struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newAsinh creates a new asinh operator. 19 | func newAsinh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Asinh{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "asinh", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the asinh operator. 32 | func (a *Asinh) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the asinh operator. 
37 | func (a *Asinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | // Dispatch on dtype to pick the matching instantiation of the generic helper. 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(asinh[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(asinh[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // asinh computes the inverse hyperbolic sine of x in float64 precision and converts the result back to T. 59 | func asinh[T ops.FloatType](x T) T { 60 | return T(math.Asinh(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/asinh/asinh_test.go: -------------------------------------------------------------------------------- 1 | package asinh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAsinhInit(t *testing.T) { 12 | c := &Asinh{} 13 | 14 | // since 'asinh' does not have any attributes we pass in nil. This should not 15 | // fail initializing the asinh.
16 | err := c.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | // TestAsinh checks Apply against precomputed float32 results. 20 | func TestAsinh(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 9, 29 | []float32{1, 2, 3, 4}, 30 | []int{2, 2}, 31 | []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, 32 | }, 33 | { 34 | 9, 35 | []float32{1, 2, 3, 4}, 36 | []int{1, 4}, 37 | []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, 38 | }, 39 | { 40 | 9, 41 | []float32{2, 2, 2, 2}, 42 | []int{1, 4}, 43 | []float32{1.4436355, 1.4436355, 1.4436355, 1.4436355}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | asinh := asinhVersions[test.version]() 53 | 54 | res, err := asinh.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | // TestInputValidationAsinh checks input count and dtype validation for version 9. 62 | func TestInputValidationAsinh(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 9, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 9, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 9, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, asinh9BaseOpFixture()), 86 | }, 87 | { 88 | 9, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", asinh9BaseOpFixture()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | asinh := asinhVersions[test.version]() 98 | validated, err := asinh.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | // asinh9BaseOpFixture returns the base operator that version 9 of the asinh operator is expected to build. 108 | func asinh9BaseOpFixture() ops.BaseOperator { 109 | return
ops.NewBaseOperator(9, 1, 1, asinhTypeConstraints, "asinh") 110 | } 111 | -------------------------------------------------------------------------------- /ops/asinh/versions.go: -------------------------------------------------------------------------------- 1 | package asinh 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // asinhVersions maps each supported opset version to its constructor. 7 | var asinhVersions = ops.OperatorVersions{ 8 | 9: ops.NewOperatorConstructor(newAsinh, 9, asinhTypeConstraints), 9 | } 10 | // GetVersions returns the supported versions of the asinh operator. 11 | func GetVersions() ops.OperatorVersions { 12 | return asinhVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/atan/atan.go: -------------------------------------------------------------------------------- 1 | package atan 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | // atanTypeConstraints requires the single input to be a float32 or float64 tensor. 11 | var atanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Atan represents the ONNX atan operator. 14 | type Atan struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newAtan creates a new atan operator. 19 | func newAtan(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Atan{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "atan", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the atan operator. It has no attributes, so this is a no-op. 32 | func (a *Atan) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the atan operator element-wise to the input tensor.
37 | func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | // Dispatch on dtype to pick the matching instantiation of the generic helper. 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(atan[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(atan[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // atan computes the arctangent of x in float64 precision and converts the result back to T. 59 | func atan[T ops.FloatType](x T) T { 60 | return T(math.Atan(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/atan/atan_test.go: -------------------------------------------------------------------------------- 1 | package atan 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAtanInit(t *testing.T) { 12 | a := &Atan{} 13 | 14 | // since 'atan' does not have any attributes we pass in nil. This should not 15 | // fail initializing the atan.
16 | err := a.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | // TestAtan checks Apply against precomputed float32 results. 20 | func TestAtan(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 7, 29 | []float32{1, 2, 3, 4}, 30 | []int{2, 2}, 31 | []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, 32 | }, 33 | { 34 | 7, 35 | []float32{1, 2, 3, 4}, 36 | []int{1, 4}, 37 | []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, 38 | }, 39 | { 40 | 7, 41 | []float32{2, 2, 2, 2}, 42 | []int{1, 4}, 43 | []float32{1.1071488, 1.1071488, 1.1071488, 1.1071488}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | atan := atanVersions[test.version]() 53 | 54 | res, err := atan.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | // TestInputValidationAtan checks input count and dtype validation for version 7. 62 | func TestInputValidationAtan(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 7, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 7, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 7, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, atan7BaseOpFixture()), 86 | }, 87 | { 88 | 7, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", atan7BaseOpFixture()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | atan := atanVersions[test.version]() 98 | validated, err := atan.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | // atan7BaseOpFixture returns the base operator that version 7 of the atan operator is expected to build. 108 | func atan7BaseOpFixture() ops.BaseOperator { 109 | return ops.NewBaseOperator(7, 1, 1,
atanTypeConstraints, "atan") 110 | } 111 | -------------------------------------------------------------------------------- /ops/atan/versions.go: -------------------------------------------------------------------------------- 1 | package atan 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // atanVersions maps each supported opset version to its constructor. 7 | var atanVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newAtan, 7, atanTypeConstraints), 9 | } 10 | // GetVersions returns the supported versions of the atan operator. 11 | func GetVersions() ops.OperatorVersions { 12 | return atanVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/atanh/atanh.go: -------------------------------------------------------------------------------- 1 | package atanh 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | // atanhTypeConstraints requires the single input to be a float32 or float64 tensor. 11 | var atanhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Atanh represents the ONNX atanh operator. 14 | type Atanh struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newAtanh creates a new atanh operator. 19 | func newAtanh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Atanh{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "atanh", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the atanh operator. It has no attributes, so this is a no-op. 32 | func (a *Atanh) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the atanh operator element-wise to the input tensor.
37 | func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | // Dispatch on dtype to pick the matching instantiation of the generic helper. 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(atanh[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(atanh[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // atanh computes the inverse hyperbolic tangent of x in float64 precision and converts the result back to T. 59 | func atanh[T ops.FloatType](x T) T { 60 | return T(math.Atanh(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/atanh/atanh_test.go: -------------------------------------------------------------------------------- 1 | package atanh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestAtanhInit(t *testing.T) { 12 | a := &Atanh{} 13 | 14 | // since 'atanh' does not have any attributes we pass in nil. This should not 15 | // fail initializing the atanh.
16 | err := a.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | // TestAtanh checks Apply against precomputed float32 results. 20 | func TestAtanh(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 9, 29 | []float32{-0.9, -0.5, 0, 0.5}, 30 | []int{2, 2}, 31 | []float32{-1.4722193, -0.54930615, 0, 0.54930615}, 32 | }, 33 | { 34 | 9, 35 | []float32{-0.9, -0.5, 0, 0.5}, 36 | []int{1, 4}, 37 | []float32{-1.4722193, -0.54930615, 0, 0.54930615}, 38 | }, 39 | { 40 | 9, 41 | []float32{0.5, 0.5, 0.5, 0.5}, 42 | []int{1, 4}, 43 | []float32{0.54930615, 0.54930615, 0.54930615, 0.54930615}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | atanh := atanhVersions[test.version]() 53 | 54 | res, err := atanh.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | // TestInputValidationAtanh checks input count and dtype validation for version 9. 62 | func TestInputValidationAtanh(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 9, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 9, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 9, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, atanh9BaseOpFixture()), 86 | }, 87 | { 88 | 9, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", atanh9BaseOpFixture()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | atanh := atanhVersions[test.version]() 98 | validated, err := atanh.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | // atanh9BaseOpFixture returns the base operator that version 9 of the atanh operator is expected to build. 108 | func atanh9BaseOpFixture() ops.BaseOperator { 109 | return
ops.NewBaseOperator(9, 1, 1, atanhTypeConstraints, "atanh") 110 | } 111 | -------------------------------------------------------------------------------- /ops/atanh/versions.go: -------------------------------------------------------------------------------- 1 | package atanh 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // atanhVersions maps each supported opset version to its constructor. 7 | var atanhVersions = ops.OperatorVersions{ 8 | 9: ops.NewOperatorConstructor(newAtanh, 9, atanhTypeConstraints), 9 | } 10 | // GetVersions returns the supported versions of the atanh operator. 11 | func GetVersions() ops.OperatorVersions { 12 | return atanhVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/base.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // BaseOperator is a concrete implementation of the methods shared by all operators. 10 | type BaseOperator struct { 11 | name string 12 | version int 13 | minInputs int 14 | maxInputs int 15 | inputTypeConstraints [][]tensor.Dtype 16 | } 17 | // NewBaseOperator creates a BaseOperator with the given version, accepted input count range [minInputs, maxInputs], per-input dtype constraints and operator name. 18 | func NewBaseOperator(version, minInputs, maxInputs int, inputTypeConstraints [][]tensor.Dtype, name string) BaseOperator { 19 | return BaseOperator{ 20 | name: name, 21 | version: version, 22 | minInputs: minInputs, 23 | maxInputs: maxInputs, 24 | inputTypeConstraints: inputTypeConstraints, 25 | } 26 | } 27 | 28 | // ValidateInputs validates the inputs for the operator by delegating to the package-level ValidateInputs function. 29 | func (f BaseOperator) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 30 | return ValidateInputs(f, inputs) 31 | } 32 | 33 | // Version returns the version of the operator. 34 | func (f BaseOperator) Version() int { 35 | return f.version 36 | } 37 | 38 | // GetMinInputs returns the minimum number of input tensors. 39 | func (f BaseOperator) GetMinInputs() int { 40 | return f.minInputs 41 | } 42 | 43 | // GetMaxInputs returns the maximum number of input tensors.
44 | func (f BaseOperator) GetMaxInputs() int { 45 | return f.maxInputs 46 | } 47 | 48 | // GetInputTypeConstraints returns the allowed dtypes for each input position. 49 | func (f BaseOperator) GetInputTypeConstraints() [][]tensor.Dtype { 50 | return f.inputTypeConstraints 51 | } 52 | // String returns the operator name followed by its version, e.g. "and v7". 53 | func (f BaseOperator) String() string { 54 | return fmt.Sprintf("%s v%d", f.name, f.version) 55 | } 56 | -------------------------------------------------------------------------------- /ops/cast/cast.go: -------------------------------------------------------------------------------- 1 | package cast 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // castTypeConstraints lists the dtypes accepted for the single input of the cast operator. 9 | var castTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Bool, tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | // Cast represents the ONNX cast operator. 14 | type Cast struct { 15 | ops.BaseOperator 16 | 17 | to int32 // DataType to cast to, as defined by TensorProto 18 | } 19 | 20 | // newCast creates a new cast operator. 21 | func newCast(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 22 | return &Cast{ 23 | BaseOperator: ops.NewBaseOperator( 24 | version, 25 | 1, 26 | 1, 27 | typeConstraints, 28 | "cast", 29 | ), 30 | } 31 | } 32 | 33 | // Init initializes the cast operator from its single, required `to` attribute. 34 | func (c *Cast) Init(n *onnx.NodeProto) error { 35 | attributes := n.GetAttribute() 36 | 37 | if len(attributes) != 1 { 38 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 39 | } 40 | 41 | attr := attributes[0] 42 | if attr.GetName() == "to" { 43 | c.to = int32(attr.GetI()) 44 | } else { 45 | return ops.ErrInvalidAttribute(attr.GetName(), c) 46 | } 47 | 48 | return nil 49 | } 50 | 51 | // Apply converts the input tensor to the dtype selected by the `to` attribute.
52 | func (c *Cast) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 53 | out, err := ops.ConvertTensorDtype(inputs[0], c.to) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | return []tensor.Tensor{out}, nil 59 | } 60 | -------------------------------------------------------------------------------- /ops/cast/versions.go: -------------------------------------------------------------------------------- 1 | package cast 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // castVersions maps each supported opset version to its constructor. 7 | var castVersions = ops.OperatorVersions{ 8 | 6: ops.NewOperatorConstructor(newCast, 6, castTypeConstraints), 9 | 9: ops.NewOperatorConstructor(newCast, 9, castTypeConstraints), 10 | 13: ops.NewOperatorConstructor(newCast, 13, castTypeConstraints), 11 | } 12 | // GetVersions returns the supported versions of the cast operator. 13 | func GetVersions() ops.OperatorVersions { 14 | return castVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/concat/concat.go: -------------------------------------------------------------------------------- 1 | package concat 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // concatTypeConstraints allows any dtype; ValidateInputs rebuilds the constraints per call with one ops.AllTypes entry per input. 9 | var concatTypeConstraints = [][]tensor.Dtype{ops.AllTypes} 10 | // MinConcatInputs is the minimum number of inputs the concat operator accepts. 11 | const ( 12 | MinConcatInputs = 1 13 | ) 14 | 15 | // Concat represents the ONNX concat operator. 16 | type Concat struct { 17 | ops.BaseOperator 18 | // axis is the dimension to concatenate along; negative values count back from the last dimension. 19 | axis int 20 | } 21 | 22 | // newConcat creates a new concat operator. 23 | func newConcat(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 24 | return &Concat{ 25 | BaseOperator: ops.NewBaseOperator( 26 | version, 27 | MinConcatInputs, 28 | 1, // placeholder: ValidateInputs resets the maximum to len(inputs) on every call 29 | typeConstraints, 30 | "concat", 31 | ), 32 | } 33 | } 34 | 35 | // Init initializes the concat operator.
36 | func (c *Concat) Init(n *onnx.NodeProto) error { 37 | attributes := n.GetAttribute() 38 | 39 | if len(attributes) != 1 { 40 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 41 | } 42 | // axis is Concat's only attribute; reject any other name instead of misreading its integer value as the axis. 43 | if attr := attributes[0]; attr.GetName() != "axis" { return ops.ErrInvalidAttribute(attr.GetName(), c) } 44 | c.axis = int(attributes[0].GetI()) 45 | return nil 46 | } 47 | 48 | // Apply concatenates the inputs along the configured axis. 49 | func (c *Concat) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 50 | // A single input is allowed (the minimum input count is 1); concatenating one tensor is a no-op, so it is returned unchanged. 51 | if len(inputs) == 1 { 52 | return inputs, nil 53 | } 54 | // Normalize a negative axis so that it counts back from the last dimension. 55 | axis := c.axis 56 | if axis < 0 { 57 | axis = len(inputs[0].Shape()) + axis 58 | } 59 | 60 | out, err := tensor.Concat(axis, inputs[0], inputs[1:]...) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return []tensor.Tensor{out}, nil 66 | } 67 | 68 | // ValidateInputs validates the inputs that will be given to Apply for this operator. 69 | // Because Concat accepts an arbitrary number of inputs (at least one), we set the maximum number 70 | // of inputs dynamically, based on our inputs. Every input can have any type.
71 | func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 72 | inputTypeConstraints := make([][]tensor.Dtype, len(inputs)) 73 | // Every input, however many there are, may have any dtype. 74 | for i := 0; i < len(inputs); i++ { 75 | inputTypeConstraints[i] = ops.AllTypes 76 | } 77 | // Rebuild the base operator so its maximum input count equals len(inputs). NOTE(review): this mutates c on every call. 78 | c.BaseOperator = ops.NewBaseOperator( 79 | c.BaseOperator.Version(), 80 | c.BaseOperator.GetMinInputs(), 81 | len(inputs), 82 | inputTypeConstraints, 83 | "concat", 84 | ) 85 | 86 | return ops.ValidateInputs(c.BaseOperator, inputs) 87 | } 88 | -------------------------------------------------------------------------------- /ops/concat/versions.go: -------------------------------------------------------------------------------- 1 | package concat 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // concatVersions maps each supported opset version to its constructor. 7 | var concatVersions = ops.OperatorVersions{ 8 | 4: ops.NewOperatorConstructor(newConcat, 4, concatTypeConstraints), 9 | 11: ops.NewOperatorConstructor(newConcat, 11, concatTypeConstraints), 10 | 13: ops.NewOperatorConstructor(newConcat, 13, concatTypeConstraints), 11 | } 12 | // GetVersions returns the supported versions of the concat operator. 13 | func GetVersions() ops.OperatorVersions { 14 | return concatVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/constant/constant.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Constant represents the ONNX constant operator (registered for opset versions 12 and 13). 10 | type Constant struct { 11 | ops.BaseOperator 12 | // value is the tensor built by Init and returned by Apply. 13 | value tensor.Tensor 14 | } 15 | 16 | // newConstant creates a new constant operator, which takes no inputs.
17 | func newConstant(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Constant{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 0, 22 | 0, 23 | typeConstraints, 24 | "constant", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the constant operator. It supports all constant types except 30 | // `sparse_value`, `value_string`, and `value_strings`. 31 | func (c *Constant) Init(n *onnx.NodeProto) error { 32 | attributes := n.GetAttribute() 33 | if len(attributes) != 1 { 34 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 35 | } 36 | 37 | attr := attributes[0] 38 | // value_float and value_int yield scalar tensors; value_floats and value_ints yield 1-D tensors. 39 | switch attr.GetName() { 40 | case sparseValue, valueString, valueStrings: 41 | return ops.ErrUnsupportedAttribute(attr.GetName(), c) 42 | case value: 43 | t, err := onnx.TensorFromProto(attr.GetT()) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | c.value = t 49 | case valueFloat: 50 | c.value = tensor.New(tensor.FromScalar(attr.GetF())) 51 | case valueFloats: 52 | floats := attr.GetFloats() 53 | c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) 54 | case valueInt: 55 | c.value = tensor.New(tensor.FromScalar(attr.GetI())) 56 | case valueInts: 57 | ints := attr.GetInts() 58 | c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) 59 | default: 60 | return ops.ErrUnsupportedAttribute(attr.GetName(), c) 61 | } 62 | 63 | return nil 64 | } 65 | 66 | // Apply returns the constant tensor built by Init; the inputs are ignored.
67 | func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { 68 | return []tensor.Tensor{c.value}, nil 69 | } 70 | -------------------------------------------------------------------------------- /ops/constant/constant_11.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Constant11 represents the ONNX constant operator for opset version 11. 10 | type Constant11 struct { 11 | ops.BaseOperator 12 | // value is the tensor built by Init and returned by Apply. 13 | value tensor.Tensor 14 | } 15 | 16 | // newConstant11 creates a new constant operator for opset version 11. It takes no arguments: the version and the (empty) type constraints are fixed. 17 | func newConstant11() ops.Operator { 18 | return &Constant11{ 19 | BaseOperator: ops.NewBaseOperator( 20 | 11, 21 | 0, 22 | 0, 23 | [][]tensor.Dtype{}, 24 | "constant", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the constant operator. Only the `value` attribute is supported; 30 | // `sparse_value` and any other attribute are rejected. 31 | func (c *Constant11) Init(n *onnx.NodeProto) error { 32 | attributes := n.GetAttribute() 33 | if len(attributes) != 1 { 34 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 35 | } 36 | 37 | attr := attributes[0] 38 | 39 | switch attr.GetName() { 40 | case sparseValue: 41 | return ops.ErrUnsupportedAttribute(attr.GetName(), c) 42 | case value: 43 | t, err := onnx.TensorFromProto(attr.GetT()) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | c.value = t 49 | default: 50 | return ops.ErrUnsupportedAttribute(attr.GetName(), c) 51 | } 52 | 53 | return nil 54 | } 55 | 56 | // Apply returns the constant tensor built by Init; the inputs are ignored.
57 | func (c *Constant11) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { 58 | return []tensor.Tensor{c.value}, nil 59 | } 60 | -------------------------------------------------------------------------------- /ops/constant/constant_legacy.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Constant9 represents the ONNX constant operator for opset versions 1 and 9. 10 | type Constant9 struct { 11 | ops.BaseOperator 12 | // value is the tensor built by Init and returned by Apply. 13 | value tensor.Tensor 14 | } 15 | 16 | // newConstant9 creates a new constant operator, which takes no inputs. 17 | func newConstant9(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Constant9{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 0, 22 | 0, 23 | typeConstraints, 24 | "constant", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the constant operator. Only the `value` attribute is supported. 30 | func (c *Constant9) Init(n *onnx.NodeProto) error { 31 | attributes := n.GetAttribute() 32 | if len(attributes) != 1 { 33 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 34 | } 35 | 36 | attr := attributes[0] 37 | 38 | switch attr.GetName() { 39 | case value: 40 | t, err := onnx.TensorFromProto(attr.GetT()) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | c.value = t 46 | default: 47 | return ops.ErrUnsupportedAttribute(attr.GetName(), c) 48 | } 49 | 50 | return nil 51 | } 52 | 53 | // Apply returns the constant tensor built by Init; the inputs are ignored.
54 | func (c *Constant9) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { 55 | return []tensor.Tensor{c.value}, nil 56 | } 57 | -------------------------------------------------------------------------------- /ops/constant/constants.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | const ( 4 | value = "value" 5 | sparseValue = "sparse_value" 6 | valueString = "value_string" 7 | valueStrings = "value_strings" 8 | valueFloat = "value_float" 9 | valueFloats = "value_floats" 10 | valueInt = "value_int" 11 | valueInts = "value_ints" 12 | ) 13 | -------------------------------------------------------------------------------- /ops/constant/versions.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | var constantVersions = ops.OperatorVersions{ 9 | 1: ops.NewOperatorConstructor(newConstant9, 1, [][]tensor.Dtype{}), 10 | 9: ops.NewOperatorConstructor(newConstant9, 9, [][]tensor.Dtype{}), 11 | 11: newConstant11, 12 | 12: ops.NewOperatorConstructor(newConstant, 12, [][]tensor.Dtype{}), 13 | 13: ops.NewOperatorConstructor(newConstant, 13, [][]tensor.Dtype{}), 14 | } 15 | 16 | func GetVersions() ops.OperatorVersions { 17 | return constantVersions 18 | } 19 | -------------------------------------------------------------------------------- /ops/constantofshape/constant_of_shape.go: -------------------------------------------------------------------------------- 1 | package constantofshape 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var constantOfShapeTypeConstraints = [][]tensor.Dtype{{tensor.Int64}} 10 | 11 | // ConstantOfShape represents the ONNX constant of shape operator. 
12 | type ConstantOfShape struct { 13 | ops.BaseOperator 14 | 15 | // One element tensor, giving the value and type of the output tensor 16 | // defaults to value 0 and type float32. 17 | value *tensor.Dense 18 | } 19 | 20 | // newConstantOfShape creates a new constant of shape operator. 21 | func newConstantOfShape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 22 | return &ConstantOfShape{ 23 | BaseOperator: ops.NewBaseOperator( 24 | version, 25 | 1, 26 | 1, 27 | typeConstraints, 28 | "constantofshape", 29 | ), 30 | } 31 | } 32 | 33 | // Init initializes the constant of shape operator. 34 | func (c *ConstantOfShape) Init(n *onnx.NodeProto) error { 35 | attributes := n.GetAttribute() 36 | 37 | if len(attributes) > 1 { 38 | return ops.ErrInvalidAttributeCount(1, len(attributes), c) 39 | } 40 | 41 | if len(attributes) == 1 { 42 | attr := attributes[0] 43 | if attr.GetName() == "value" { 44 | t, err := onnx.TensorFromProto(attr.GetT()) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | c.value = tensor.New(tensor.WithBacking(t.Data())) 50 | if c.value.Len() != 1 { 51 | return ops.ErrInvalidTensor("expected tensor to have one element", c) 52 | } 53 | } else { 54 | return ops.ErrInvalidAttribute(attr.GetName(), c) 55 | } 56 | } else { 57 | c.value = tensor.New(tensor.FromScalar(float32(0.0))) 58 | } 59 | 60 | return nil 61 | } 62 | 63 | // Apply applies the constant of shape operator. 
64 | func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 65 | shape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[0].Data())) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | // Empty dimensions in a tensor are not supported 71 | for i := range shape { 72 | if shape[i] <= 0 { 73 | return nil, ops.ErrInvalidTensor("empty dimensions are not allowed", c) 74 | } 75 | } 76 | 77 | t := tensor.New(tensor.WithShape(shape...), tensor.Of(c.value.Dtype())) 78 | 79 | t, err = t.AddScalar(c.value, true) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | return []tensor.Tensor{t}, err 85 | } 86 | -------------------------------------------------------------------------------- /ops/constantofshape/versions.go: -------------------------------------------------------------------------------- 1 | package constantofshape 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var constantOfShapeVersions = ops.OperatorVersions{ 8 | 9: ops.NewOperatorConstructor(newConstantOfShape, 9, constantOfShapeTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return constantOfShapeVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/conv/versions.go: -------------------------------------------------------------------------------- 1 | package conv 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var convVersions = ops.OperatorVersions{ 8 | 1: ops.NewOperatorConstructor(newConv, 1, convTypeConstraints), 9 | 11: ops.NewOperatorConstructor(newConv, 11, convTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return convVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/cos/cos.go: -------------------------------------------------------------------------------- 1 | package cos 2 | 3 | import ( 4 | "math" 5 | 6 | 
"github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var cosTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Cos represents the ONNX cos operator. 14 | type Cos struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newCos creates a new cos operator. 19 | func newCos(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Cos{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "cos", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the cos operator. 32 | func (c *Cos) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the cos operator. 37 | func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(cos[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(cos[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | 59 | func cos[T ops.FloatType](x T) T { 60 | return T(math.Cos(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/cos/cos_test.go: -------------------------------------------------------------------------------- 1 | package cos 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestCosInit(t *testing.T) { 12 | c := &Cos{} 13 | 14 | // since 'cos' does not have any attributes we pass in nil. This should not 15 | // fail initializing the cos. 
16 | err := c.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestCos(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 7, 29 | []float32{-2, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{-0.41614684, 0.5403023, 1, 0.5403023}, 32 | }, 33 | { 34 | 7, 35 | []float32{1, 3, 4, 5}, 36 | []int{1, 4}, 37 | []float32{0.5403023, -0.9899925, -0.6536436, 0.2836622}, 38 | }, 39 | { 40 | 7, 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{0.5403023, 0.5403023, 0.5403023, 0.5403023}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | cos := cosVersions[test.version]() 53 | 54 | res, err := cos.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | 62 | func TestInputValidationCos(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 7, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 7, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 7, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, cos7BaseOperator()), 86 | }, 87 | { 88 | 7, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", cos7BaseOperator()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | cos := cosVersions[test.version]() 98 | validated, err := cos.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | 108 | func cos7BaseOperator() ops.BaseOperator { 109 | return ops.NewBaseOperator(7, 1, 1, 
cosTypeConstraints, "cos") 110 | } 111 | -------------------------------------------------------------------------------- /ops/cos/versions.go: -------------------------------------------------------------------------------- 1 | package cos 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var cosVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newCos, 7, cosTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return cosVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/cosh/cosh.go: -------------------------------------------------------------------------------- 1 | package cosh 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var coshTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Cosh represents the ONNX cosh operator. 14 | type Cosh struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newCosh creates a new cosh operator. 19 | func newCosh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Cosh{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "cosh", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the cosh operator. 32 | func (c *Cosh) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the cosh operator. 
37 | func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(cosh[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(cosh[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | 59 | func cosh[T ops.FloatType](x T) T { 60 | return T(math.Cosh(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/cosh/cosh_test.go: -------------------------------------------------------------------------------- 1 | package cosh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestCoshInit(t *testing.T) { 12 | c := &Cosh{} 13 | 14 | // since 'cosh' does not have any attributes we pass in nil. This should not 15 | // fail initializing the cosh. 
16 | err := c.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestCosh(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 9, 29 | []float32{-2, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{3.7621956, 1.5430807, 1, 1.5430807}, 32 | }, 33 | { 34 | 9, 35 | []float32{1, 3, 4, 5}, 36 | []int{1, 4}, 37 | []float32{1.5430807, 10.067662, 27.308233, 74.209946}, 38 | }, 39 | { 40 | 9, 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{1.5430807, 1.5430807, 1.5430807, 1.5430807}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | cosh := coshVersions[test.version]() 53 | 54 | res, err := cosh.Apply(inputs) 55 | assert.Nil(t, err) 56 | 57 | assert.Nil(t, err) 58 | assert.Equal(t, test.expected, res[0].Data()) 59 | } 60 | } 61 | 62 | func TestInputValidationCosh(t *testing.T) { 63 | tests := []struct { 64 | version int64 65 | inputs []tensor.Tensor 66 | err error 67 | }{ 68 | { 69 | 9, 70 | []tensor.Tensor{ 71 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 72 | }, 73 | nil, 74 | }, 75 | { 76 | 9, 77 | []tensor.Tensor{ 78 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 79 | }, 80 | nil, 81 | }, 82 | { 83 | 9, 84 | []tensor.Tensor{}, 85 | ops.ErrInvalidInputCount(0, cosh9BaseOpFixture()), 86 | }, 87 | { 88 | 9, 89 | []tensor.Tensor{ 90 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 91 | }, 92 | ops.ErrInvalidInputType(0, "int", cosh9BaseOpFixture()), 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | cosh := coshVersions[test.version]() 98 | validated, err := cosh.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | 108 | func cosh9BaseOpFixture() ops.BaseOperator { 109 | return ops.NewBaseOperator(9, 1, 1, 
coshTypeConstraints, "cosh") 110 | } 111 | -------------------------------------------------------------------------------- /ops/cosh/versions.go: -------------------------------------------------------------------------------- 1 | package cosh 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var coshVersions = ops.OperatorVersions{ 8 | 9: ops.NewOperatorConstructor(newCosh, 9, coshTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return coshVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/cumsum/cumsum_test.go: -------------------------------------------------------------------------------- 1 | package cumsum 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "github.com/stretchr/testify/assert" 9 | "gorgonia.org/tensor" 10 | ) 11 | 12 | func TestCumSumInit(t *testing.T) { 13 | c := &CumSum{} 14 | err := c.Init( 15 | &onnx.NodeProto{ 16 | Attribute: []*onnx.AttributeProto{ 17 | {Name: "exclusive", I: 1}, 18 | {Name: "reverse", I: 1}, 19 | }, 20 | }, 21 | ) 22 | 23 | assert.Nil(t, err) 24 | assert.Equal(t, true, c.exclusive) 25 | assert.Equal(t, true, c.reverse) 26 | } 27 | 28 | func TestCumSumInitDefaults(t *testing.T) { 29 | c := &CumSum{} 30 | err := c.Init( 31 | &onnx.NodeProto{ 32 | Attribute: []*onnx.AttributeProto{}, 33 | }, 34 | ) 35 | 36 | assert.Nil(t, err) 37 | assert.Equal(t, false, c.exclusive) 38 | assert.Equal(t, false, c.reverse) 39 | } 40 | 41 | func TestCumSum(t *testing.T) { 42 | tests := []struct { 43 | version int64 44 | node *onnx.NodeProto 45 | backing []float32 46 | axis int32 47 | shape []int 48 | expected []float32 49 | }{ 50 | { 51 | 11, 52 | &onnx.NodeProto{ 53 | Attribute: []*onnx.AttributeProto{ 54 | {Name: "exclusive", I: 0}, 55 | {Name: "reverse", I: 0}, 56 | }, 57 | }, 58 | []float32{1, 2, 3, 4}, 59 | 0, 60 | []int{2, 2}, 61 | 
[]float32{1, 2, 4, 6}, 62 | }, 63 | { 64 | 11, 65 | &onnx.NodeProto{ 66 | Attribute: []*onnx.AttributeProto{ 67 | {Name: "exclusive", I: 0}, 68 | {Name: "reverse", I: 0}, 69 | }, 70 | }, 71 | []float32{1, 2, 3, 4}, 72 | 1, 73 | []int{2, 2}, 74 | []float32{1, 3, 3, 7}, 75 | }, 76 | { 77 | 11, 78 | &onnx.NodeProto{ 79 | Attribute: []*onnx.AttributeProto{ 80 | {Name: "exclusive", I: 1}, 81 | {Name: "reverse", I: 0}, 82 | }, 83 | }, 84 | []float32{1, 2, 3}, 85 | 0, 86 | []int{3}, 87 | []float32{0, 1, 3}, 88 | }, 89 | { 90 | 11, 91 | &onnx.NodeProto{ 92 | Attribute: []*onnx.AttributeProto{ 93 | {Name: "exclusive", I: 0}, 94 | {Name: "reverse", I: 1}, 95 | }, 96 | }, 97 | []float32{1, 2, 3}, 98 | 0, 99 | []int{3}, 100 | []float32{6, 5, 3}, 101 | }, 102 | } 103 | 104 | for _, test := range tests { 105 | inputs := []tensor.Tensor{ 106 | ops.TensorWithBackingFixture(test.backing, test.shape...), 107 | tensor.New(tensor.FromScalar(test.axis)), 108 | } 109 | 110 | cumsum := cumsumVersions[test.version]() 111 | err := cumsum.Init(test.node) 112 | assert.Nil(t, err) 113 | 114 | res, err := cumsum.Apply(inputs) 115 | assert.Nil(t, err) 116 | assert.Equal(t, test.expected, res[0].Data()) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /ops/cumsum/versions.go: -------------------------------------------------------------------------------- 1 | package cumsum 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var cumsumVersions = ops.OperatorVersions{ 8 | 11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return cumsumVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/div/div.go: -------------------------------------------------------------------------------- 1 | package div 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | 
"github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var divTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 12 | } 13 | 14 | // Div represents the ONNX div operator. 15 | type Div struct { 16 | ops.BaseOperator 17 | } 18 | 19 | // newDiv creates a new div operator. 20 | func newDiv(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Div{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 2, 25 | 2, 26 | typeConstraints, 27 | "div", 28 | ), 29 | } 30 | } 31 | 32 | // Init initializes the div operator. 33 | func (d *Div) Init(*onnx.NodeProto) error { 34 | return nil 35 | } 36 | 37 | // Apply applies the div operator. 38 | func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 39 | return ops.ApplyBinaryOperation( 40 | inputs[0], 41 | inputs[1], 42 | ops.Div, 43 | ops.MultidirectionalBroadcasting, 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /ops/div/versions.go: -------------------------------------------------------------------------------- 1 | package div 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var divVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newDiv, 7, divTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newDiv, 13, divTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return divVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/equal/equal.go: -------------------------------------------------------------------------------- 1 | package equal 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var 
equal7TypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Bool, tensor.Int32, tensor.Int64}, 11 | {tensor.Bool, tensor.Int32, tensor.Int64}, 12 | } 13 | 14 | var equalTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} 15 | 16 | // Equal represents the ONNX equal operator. 17 | type Equal struct { 18 | ops.BaseOperator 19 | } 20 | 21 | // newEqual creates a new equal operator. 22 | func newEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 23 | return &Equal{ 24 | BaseOperator: ops.NewBaseOperator( 25 | version, 26 | 2, 27 | 2, 28 | typeConstraints, 29 | "equal", 30 | ), 31 | } 32 | } 33 | 34 | // Init initializes the equal operator. 35 | func (e *Equal) Init(*onnx.NodeProto) error { 36 | return nil 37 | } 38 | 39 | // Apply applies the equal operator. 40 | func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 41 | return ops.ApplyBinaryOperation( 42 | inputs[0], 43 | inputs[1], 44 | ops.Equal, 45 | ops.MultidirectionalBroadcasting, 46 | ) 47 | } 48 | -------------------------------------------------------------------------------- /ops/equal/versions.go: -------------------------------------------------------------------------------- 1 | package equal 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var equalVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newEqual, 7, equal7TypeConstraints), 9 | 11: ops.NewOperatorConstructor(newEqual, 11, equalTypeConstraints), 10 | 13: ops.NewOperatorConstructor(newEqual, 13, equalTypeConstraints), 11 | } 12 | 13 | func GetVersions() ops.OperatorVersions { 14 | return equalVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/erf/erf.go: -------------------------------------------------------------------------------- 1 | package erf 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | 
"gorgonia.org/tensor" 9 | ) 10 | 11 | var erfTypeConstraints = [][]tensor.Dtype{ops.NumericTypes} 12 | 13 | // Erf represents the ONNX erf operator. 14 | type Erf struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newSin creates a new erf operator. 19 | func newErf(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Erf{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "erf", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the erf operator. 32 | func (e *Erf) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the erf operator. 37 | func (e *Erf) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Uint8: 45 | out, err = inputs[0].Apply(erf[uint8]) 46 | case tensor.Uint16: 47 | out, err = inputs[0].Apply(erf[uint16]) 48 | case tensor.Uint32: 49 | out, err = inputs[0].Apply(erf[uint32]) 50 | case tensor.Uint64: 51 | out, err = inputs[0].Apply(erf[uint64]) 52 | case tensor.Int8: 53 | out, err = inputs[0].Apply(erf[int8]) 54 | case tensor.Int16: 55 | out, err = inputs[0].Apply(erf[int16]) 56 | case tensor.Int32: 57 | out, err = inputs[0].Apply(erf[int32]) 58 | case tensor.Int64: 59 | out, err = inputs[0].Apply(erf[int64]) 60 | case tensor.Float32: 61 | out, err = inputs[0].Apply(erf[float32]) 62 | case tensor.Float64: 63 | out, err = inputs[0].Apply(erf[float64]) 64 | default: 65 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), e.BaseOperator) 66 | } 67 | 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return []tensor.Tensor{out}, nil 73 | } 74 | 75 | func erf[T ops.NumericType](x T) T { 76 | return T(math.Erf(float64(x))) 77 | } 78 | -------------------------------------------------------------------------------- /ops/erf/erf_test.go: -------------------------------------------------------------------------------- 1 | 
package erf 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestErfInit(t *testing.T) { 12 | e := &Erf{} 13 | err := e.Init(nil) 14 | assert.Nil(t, err) 15 | } 16 | 17 | func TestErf(t *testing.T) { 18 | tests := []struct { 19 | version int64 20 | backing []float32 21 | shape []int 22 | expected []float32 23 | }{ 24 | { 25 | 9, 26 | []float32{-1, -1, 0, 1}, 27 | []int{2, 2}, 28 | []float32{-0.8427008, -0.8427008, 0, 0.8427008}, 29 | }, 30 | { 31 | 13, 32 | []float32{1, 0.5, 0.0, -0.5}, 33 | []int{1, 4}, 34 | []float32{0.8427008, 0.5204999, 0, -0.5204999}, 35 | }, 36 | } 37 | 38 | for _, test := range tests { 39 | inputs := []tensor.Tensor{ 40 | ops.TensorWithBackingFixture(test.backing, test.shape...), 41 | } 42 | 43 | erf := erfVersions[test.version]() 44 | 45 | res, err := erf.Apply(inputs) 46 | assert.Nil(t, err) 47 | 48 | assert.Nil(t, err) 49 | assert.Equal(t, test.expected, res[0].Data()) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /ops/erf/versions.go: -------------------------------------------------------------------------------- 1 | package erf 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var erfVersions = ops.OperatorVersions{ 8 | 9: ops.NewOperatorConstructor(newErf, 9, erfTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newErf, 13, erfTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return erfVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/expand/expand.go: -------------------------------------------------------------------------------- 1 | package expand 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var expandTypeConstraints = 
[][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} 10 | 11 | // Expand represents the ONNX expand operator. 12 | type Expand struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newExpand creates a new expand operator. 17 | func newExpand(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Expand{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraints, 24 | "expand", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the expand operator. 30 | func (f *Expand) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the expand operator. 35 | func (f *Expand) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | input := inputs[0] 37 | 38 | shape, err := ops.AnyToIntSlice(inputs[1].Data()) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | // If the new shape has more dimensions than the input tensor, we 44 | // need to prepend some dimensions to the input tensor shape. 45 | if len(shape) > len(input.Shape()) { 46 | input, err = ops.AddExtraDimsToTensor(input, len(shape)-len(input.Shape())) 47 | if err != nil { 48 | return nil, err 49 | } 50 | } 51 | 52 | for axis := len(shape) - 1; axis >= 0; axis-- { 53 | if input.Shape()[axis] != shape[axis] { 54 | input, err = tensor.Repeat(input, axis, shape[axis]) 55 | if err != nil { 56 | return nil, err 57 | } 58 | } 59 | } 60 | 61 | return []tensor.Tensor{input}, nil 62 | } 63 | -------------------------------------------------------------------------------- /ops/expand/versions.go: -------------------------------------------------------------------------------- 1 | package expand 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var expandVersions = ops.OperatorVersions{ 8 | 8: ops.NewOperatorConstructor(newExpand, 8, expandTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newExpand, 13, expandTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return expandVersions 14 
| } 15 | -------------------------------------------------------------------------------- /ops/fixtures.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | // InputFixture is a function that generates inputs for ops. Useful in testing. 11 | type InputFixture func() []tensor.Tensor 12 | 13 | // Float32TensorFixture returns a float32 backed gorgonia node. It initializes all its values 14 | // using tensor.Range. 15 | func Float32TensorFixture(shp ...int) tensor.Tensor { 16 | return tensor.New( 17 | tensor.WithShape(shp...), 18 | tensor.WithBacking(tensor.Range(tensor.Float32, 0, NElements(shp...))), 19 | ) 20 | } 21 | 22 | func RandomFloat32TensorFixture(r *rand.Rand, shp ...int) tensor.Tensor { 23 | rands := make([]float32, NElements(shp...)) 24 | for i := 0; i < NElements(shp...); i++ { 25 | rands[i] = r.Float32() 26 | } 27 | 28 | return tensor.New( 29 | tensor.WithShape(shp...), 30 | tensor.WithBacking(rands), 31 | ) 32 | } 33 | 34 | // TensorWithBackingFixture returns a gorgonia node with a tensor using the given backing. 35 | func TensorWithBackingFixture(b interface{}, shp ...int) tensor.Tensor { 36 | return tensor.New(tensor.WithShape(shp...), tensor.WithBacking(b)) 37 | } 38 | 39 | // TensorInputsFixture returns a list with a given number of tensors. 40 | func TensorInputsFixture(nTensors int) []tensor.Tensor { 41 | result := make([]tensor.Tensor, nTensors) 42 | for i := 0; i < nTensors; i++ { 43 | result[i] = tensor.New(tensor.WithShape(1), tensor.WithBacking([]float32{0.0})) 44 | } 45 | 46 | return result 47 | } 48 | 49 | // EmptyNodeProto returns a node proto with no attributes. 
50 | func EmptyNodeProto() *onnx.NodeProto { 51 | return &onnx.NodeProto{Attribute: []*onnx.AttributeProto{}} 52 | } 53 | -------------------------------------------------------------------------------- /ops/flatten/constants.go: -------------------------------------------------------------------------------- 1 | package flatten 2 | 3 | const axis = "axis" 4 | -------------------------------------------------------------------------------- /ops/flatten/flatten.go: -------------------------------------------------------------------------------- 1 | package flatten 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Flatten provides common functionality for all Flatten versions. 10 | type Flatten struct { 11 | ops.BaseOperator 12 | axis int 13 | } 14 | 15 | func newFlatten(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 16 | return &Flatten{ 17 | BaseOperator: ops.NewBaseOperator( 18 | version, 19 | 1, 20 | 1, 21 | typeConstraint, 22 | "flatten", 23 | ), 24 | } 25 | } 26 | 27 | // Init initializes the flatten operator. 28 | func (f *Flatten) Init(n *onnx.NodeProto) error { 29 | for _, attr := range n.GetAttribute() { 30 | switch attr.GetName() { 31 | case axis: 32 | f.axis = int(attr.GetI()) 33 | default: 34 | return ops.ErrInvalidAttribute(attr.GetName(), f) 35 | } 36 | } 37 | 38 | return nil 39 | } 40 | 41 | // Apply applies the flatten operator. 42 | func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 43 | inputShape := inputs[0].Shape() 44 | rank := len(inputShape) 45 | 46 | axis := f.axis 47 | if axis < 0 { 48 | axis = rank + axis 49 | } 50 | 51 | out, ok := inputs[0].Clone().(tensor.Tensor) 52 | if !ok { 53 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 54 | } 55 | 56 | var err error 57 | // Handle the special case where axis is 0. 
58 | if axis == 0 { 59 | err = out.Reshape(1, ops.NElements(inputShape...)) 60 | } else { 61 | err = out.Reshape(ops.NElements(inputShape[:axis]...), ops.NElements(inputShape[axis:]...)) 62 | } 63 | 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | return []tensor.Tensor{out}, nil 69 | } 70 | -------------------------------------------------------------------------------- /ops/flatten/versions.go: -------------------------------------------------------------------------------- 1 | package flatten 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | var flattenVersions = ops.OperatorVersions{ 9 | 1: ops.NewOperatorConstructor(newFlatten, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}), 10 | 9: ops.NewOperatorConstructor(newFlatten, 9, [][]tensor.Dtype{ops.AllTypes}), 11 | 11: ops.NewOperatorConstructor(newFlatten, 11, [][]tensor.Dtype{ops.AllTypes}), 12 | 13: ops.NewOperatorConstructor(newFlatten, 13, [][]tensor.Dtype{ops.AllTypes}), 13 | } 14 | 15 | func GetVersions() ops.OperatorVersions { 16 | return flattenVersions 17 | } 18 | -------------------------------------------------------------------------------- /ops/gather/constants.go: -------------------------------------------------------------------------------- 1 | package gather 2 | 3 | const axis = "axis" 4 | -------------------------------------------------------------------------------- /ops/gather/versions.go: -------------------------------------------------------------------------------- 1 | package gather 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var gatherVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newGather, 1, gatherTypeConstraints), 7 | 11: ops.NewOperatorConstructor(newGather, 11, gatherTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newGather, 13, gatherTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return gatherVersions 13 | } 14 | 
-------------------------------------------------------------------------------- /ops/gemm/constants.go: -------------------------------------------------------------------------------- 1 | package gemm 2 | // Attribute names of the ONNX gemm operator, as matched by Init. 3 | const ( 4 | alpha = "alpha" 5 | beta = "beta" 6 | transA = "transA" 7 | transB = "transB" 8 | ) 9 | -------------------------------------------------------------------------------- /ops/gemm/gemm_legacy.go: -------------------------------------------------------------------------------- 1 | package gemm 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Gemm9 represents the ONNX gemm operator for versions <= 9 (registered for versions 7 and 9). 10 | type Gemm9 struct { 11 | ops.BaseOperator 12 | 13 | alpha float32 // multiplier applied to the product A * B 14 | beta float32 // multiplier applied to C 15 | transA bool // transpose A before multiplying 16 | transB bool // transpose B before multiplying 17 | } 18 | 19 | // newGemm9 creates a new gemm operator and initializes it with the default values (alpha = beta = 1, no transposes). 20 | func newGemm9(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Gemm9{ 22 | BaseOperator: ops.NewBaseOperator(version, 3, 3, typeConstraints, "gemm"), 23 | alpha: 1.0, 24 | beta: 1.0, 25 | transA: false, 26 | transB: false, 27 | } 28 | } 29 | 30 | // Init initializes the Gemm9 operator based on the NodeProto attributes. 31 | func (g *Gemm9) Init(n *onnx.NodeProto) error { 32 | for _, attr := range n.GetAttribute() { 33 | switch attr.GetName() { 34 | case alpha: 35 | g.alpha = attr.GetF() 36 | case beta: 37 | g.beta = attr.GetF() 38 | case transA: 39 | g.transA = ops.Int64ToBool(attr.GetI()) 40 | case transB: 41 | g.transB = ops.Int64ToBool(attr.GetI()) 42 | default: 43 | return ops.ErrInvalidAttribute(attr.GetName(), g) 44 | } 45 | } 46 | 47 | return nil 48 | } 49 | 50 | // Apply computes alpha * A' * B' + beta * C, where A' and B' are A and B, transposed when requested.
51 | func (g *Gemm9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 52 | var err error 53 | 54 | a := inputs[0] 55 | b := inputs[1] 56 | c := inputs[2] 57 | 58 | if g.transA { 59 | a, err = tensor.Transpose(a) 60 | if err != nil { 61 | return nil, err 62 | } 63 | } 64 | 65 | if g.transB { 66 | b, err = tensor.Transpose(b) 67 | if err != nil { 68 | return nil, err 69 | } 70 | } 71 | 72 | x, err := tensor.MatMul(a, b) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | x, err = tensor.Mul(x, g.alpha) // NOTE(review): alpha is a float32 scalar; confirm tensor.Mul accepts it for Float64 inputs (allowed by version 7) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | y, err := tensor.Mul(c, g.beta) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | x, y, err = ops.UnidirectionalBroadcast(x, y) // presumably broadcasts beta * C to the shape of alpha * A * B — confirm 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | output, err := tensor.Add(x, y) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | return []tensor.Tensor{output}, nil 98 | } 99 | -------------------------------------------------------------------------------- /ops/gemm/versions.go: -------------------------------------------------------------------------------- 1 | package gemm 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | var gemmVersions = ops.OperatorVersions{ 9 | 7: ops.NewOperatorConstructor( 10 | newGemm9, 11 | 7, 12 | [][]tensor.Dtype{ 13 | {tensor.Float32, tensor.Float64}, 14 | {tensor.Float32, tensor.Float64}, 15 | {tensor.Float32, tensor.Float64}, 16 | }, 17 | ), 18 | 9: ops.NewOperatorConstructor(newGemm9, 9, gemmTypeConstraints), 19 | 11: ops.NewOperatorConstructor(newGemm, 11, gemmTypeConstraints), 20 | 13: ops.NewOperatorConstructor(newGemm, 13, gemmTypeConstraints), 21 | } 22 | // GetVersions returns the gemm constructors keyed by operator version; versions 7 and 9 use the legacy Gemm9. 23 | func GetVersions() ops.OperatorVersions { 24 | return gemmVersions 25 | } 26 | -------------------------------------------------------------------------------- /ops/greater/greater.go: -------------------------------------------------------------------------------- 1 | package greater 2 | 3
| import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // greater7TypeConstraints restricts both inputs of version 7 to float dtypes. 9 | var greater7TypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}} 10 | // greaterTypeConstraints holds one allowed-dtype set per input (A, B), used by versions 9 and 13. 11 | var greaterTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} 12 | 13 | // Greater represents the ONNX greater operator. 14 | type Greater struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newGreater creates a new greater operator. 19 | func newGreater(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Greater{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 2, 24 | 2, 25 | typeConstraints, 26 | "greater", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the greater operator. The node is not read, since the operator takes no attributes. 32 | func (g *Greater) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply computes the element-wise comparison A > B with multidirectional broadcasting. 37 | func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | return ops.ApplyBinaryOperation( 39 | inputs[0], 40 | inputs[1], 41 | ops.Gt, 42 | ops.MultidirectionalBroadcasting, 43 | ) 44 | } 45 | -------------------------------------------------------------------------------- /ops/greater/versions.go: -------------------------------------------------------------------------------- 1 | package greater 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var greaterVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newGreater, 7, greater7TypeConstraints), 7 | 9: ops.NewOperatorConstructor(newGreater, 9, greaterTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newGreater, 13, greaterTypeConstraints), 9 | } 10 | // GetVersions returns the greater constructors keyed by operator version. 11 | func GetVersions() ops.OperatorVersions { 12 | return greaterVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/greaterorequal/greater_or_equal.go:
-------------------------------------------------------------------------------- 1 | package greaterorequal 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // greaterOrEqualTypeConstraints holds one allowed-dtype set per input (A, B). 9 | var greaterOrEqualTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} 10 | 11 | // GreaterOrEqual represents the ONNX greaterOrEqual operator. 12 | type GreaterOrEqual struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newGreaterOrEqual creates a new greaterOrEqual operator. 17 | func newGreaterOrEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &GreaterOrEqual{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraints, 24 | "greaterorequal", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the greaterOrEqual operator. The node is not read, since the operator takes no attributes. 30 | func (g *GreaterOrEqual) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply computes the element-wise comparison A >= B with multidirectional broadcasting.
35 | func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | return ops.ApplyBinaryOperation( 37 | inputs[0], 38 | inputs[1], 39 | ops.Gte, 40 | ops.MultidirectionalBroadcasting, 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ops/greaterorequal/versions.go: -------------------------------------------------------------------------------- 1 | package greaterorequal 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var greaterOrEqualVersions = ops.OperatorVersions{ 6 | 12: ops.NewOperatorConstructor(newGreaterOrEqual, 12, greaterOrEqualTypeConstraints), 7 | } 8 | // GetVersions returns the greaterOrEqual constructors keyed by operator version. 9 | func GetVersions() ops.OperatorVersions { 10 | return greaterOrEqualVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/gru/versions.go: -------------------------------------------------------------------------------- 1 | package gru 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var gruVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints), 7 | } 8 | // GetVersions returns the GRU constructors keyed by operator version. 9 | func GetVersions() ops.OperatorVersions { 10 | return gruVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/identity/identity.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // identityTypeConstraints holds the allowed dtypes for the single input. 9 | var identityTypeConstraints = [][]tensor.Dtype{ops.AllTypes} 10 | 11 | // Identity represents the ONNX identity operator. 12 | type Identity struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newIdentity creates a new identity operator.
17 | func newIdentity(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Identity{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "identity", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the identity operator. The node is not read, since the operator takes no attributes. 30 | func (a *Identity) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply returns a clone of the single input tensor. 35 | func (a *Identity) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | out, ok := inputs[0].Clone().(tensor.Tensor) 37 | if !ok { 38 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 39 | } 40 | 41 | return []tensor.Tensor{out}, nil 42 | } 43 | -------------------------------------------------------------------------------- /ops/identity/identity_test.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestIdentityInit(t *testing.T) { 12 | i := &Identity{} 13 | 14 | // Since 'identity' does not have any attributes, we pass in nil. This should not 15 | // fail initializing the identity operator.
16 | err := i.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestIdentity(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 13, 29 | []float32{0, 1, 2, 3}, 30 | []int{2, 2}, 31 | []float32{0, 1, 2, 3}, 32 | }, 33 | } 34 | 35 | for _, test := range tests { 36 | inputs := []tensor.Tensor{ 37 | ops.TensorWithBackingFixture(test.backing, test.shape...), 38 | } 39 | 40 | identity := identityVersions[test.version]() 41 | 42 | res, err := identity.Apply(inputs) 43 | assert.Nil(t, err) 44 | 45 | assert.Equal(t, test.expected, res[0].Data()) 46 | } 47 | } 48 | 49 | func TestInputValidationIdentity(t *testing.T) { 50 | tests := []struct { 51 | version int64 52 | inputs []tensor.Tensor 53 | err error 54 | }{ 55 | { 56 | 13, 57 | []tensor.Tensor{ 58 | ops.TensorWithBackingFixture([]uint32{1, 2}, 2), 59 | }, 60 | nil, 61 | }, 62 | { 63 | 13, 64 | []tensor.Tensor{ 65 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 66 | }, 67 | nil, 68 | }, 69 | { 70 | 13, 71 | []tensor.Tensor{ 72 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 73 | ops.TensorWithBackingFixture([]float32{3, 4}, 2), 74 | }, 75 | ops.ErrInvalidInputCount(2, identity13BaseOpFixture()), 76 | }, 77 | { 78 | 13, 79 | []tensor.Tensor{ 80 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 81 | }, 82 | ops.ErrInvalidInputType(0, "int", identity13BaseOpFixture()), 83 | }, 84 | } 85 | 86 | for _, test := range tests { 87 | identity := identityVersions[test.version]() 88 | validated, err := identity.ValidateInputs(test.inputs) 89 | 90 | assert.Equal(t, test.err, err) 91 | 92 | if test.err == nil { 93 | assert.Equal(t, test.inputs, validated) 94 | } 95 | } 96 | } 97 | // identity13BaseOpFixture builds the BaseOperator that version-13 validation errors are expected to reference. 98 | func identity13BaseOpFixture() ops.BaseOperator { 99 | return ops.NewBaseOperator(13, 1, 1, identityTypeConstraints, "identity") 100 | } 101 | -------------------------------------------------------------------------------- /ops/identity/versions.go:
-------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | // identityVersions maps operator versions to Identity constructors. Identity-1 and Identity-13 share the same tensor semantics, so models with opset < 13 can resolve the operator too. 7 | var identityVersions = ops.OperatorVersions{ 8 | 1: ops.NewOperatorConstructor(newIdentity, 1, identityTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newIdentity, 13, identityTypeConstraints), 10 | } 11 | 12 | // GetVersions returns the Identity constructors keyed by operator version. 13 | func GetVersions() ops.OperatorVersions { 14 | return identityVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/less/less.go: -------------------------------------------------------------------------------- 1 | package less 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // less7TypeConstraints restricts both inputs of version 7 to float dtypes. 9 | var less7TypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}} 10 | // lessTypeConstraints holds one allowed-dtype set per input (A, B), used by versions 9 and 13. 11 | var lessTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} 12 | 13 | // Less represents the ONNX less operator. 14 | type Less struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newLess creates a new less operator. 19 | func newLess(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Less{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 2, 24 | 2, 25 | typeConstraints, 26 | "less", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the less operator. The node is not read, since the operator takes no attributes. 32 | func (l *Less) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply computes the element-wise comparison A < B with multidirectional broadcasting.
37 | func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | return ops.ApplyBinaryOperation( 39 | inputs[0], 40 | inputs[1], 41 | ops.Lt, 42 | ops.MultidirectionalBroadcasting, 43 | ) 44 | } 45 | -------------------------------------------------------------------------------- /ops/less/versions.go: -------------------------------------------------------------------------------- 1 | package less 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var lessVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newLess, 7, less7TypeConstraints), 7 | 9: ops.NewOperatorConstructor(newLess, 9, lessTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newLess, 13, lessTypeConstraints), 9 | } 10 | // GetVersions returns the less constructors keyed by operator version. 11 | func GetVersions() ops.OperatorVersions { 12 | return lessVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/lessorequal/less_or_equal.go: -------------------------------------------------------------------------------- 1 | package lessorequal 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // lessOrEqualTypeConstraints holds one allowed-dtype set per input (A, B). 9 | var lessOrEqualTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} 10 | 11 | // LessOrEqual represents the ONNX lessOrEqual operator. 12 | type LessOrEqual struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newLessOrEqual creates a new lessOrEqual operator. 17 | func newLessOrEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &LessOrEqual{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraints, 24 | "lessorequal", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the lessOrEqual operator. The node is not read, since the operator takes no attributes. 30 | func (l *LessOrEqual) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply computes the element-wise comparison A <= B with multidirectional broadcasting.
35 | func (l *LessOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | return ops.ApplyBinaryOperation( 37 | inputs[0], 38 | inputs[1], 39 | ops.Lte, 40 | ops.MultidirectionalBroadcasting, 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ops/lessorequal/versions.go: -------------------------------------------------------------------------------- 1 | package lessorequal 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var lessOrEqualVersions = ops.OperatorVersions{ 6 | 12: ops.NewOperatorConstructor(newLessOrEqual, 12, lessOrEqualTypeConstraints), 7 | } 8 | // GetVersions returns the lessOrEqual constructors keyed by operator version. 9 | func GetVersions() ops.OperatorVersions { 10 | return lessOrEqualVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/linearregressor/linear_regressor.go: -------------------------------------------------------------------------------- 1 | package linearregressor 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // linearRegressorTypeConstraints lists the allowed dtypes for the single input X. 9 | var linearRegressorTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | // postTransformOption describes all possible post transform options for the 14 | // linear regressor operator. 15 | type postTransformOption string 16 | // Only noTransform (the default) is supported; the remaining options are declared but never applied. 17 | const ( 18 | noTransform postTransformOption = "NONE" 19 | softmaxTransform postTransformOption = "SOFTMAX" 20 | logisticTransform postTransformOption = "LOGISTIC" 21 | softmaxZeroTransform postTransformOption = "SOFTMAX_ZERO" 22 | probitTransform postTransformOption = "PROBIT" 23 | ) 24 | 25 | // LinearRegressor represents the ONNX-ml linearRegressor operator.
26 | type LinearRegressor struct { 27 | ops.BaseOperator 28 | 29 | coefficients tensor.Tensor // reshaped to (targets, n/targets) and then transposed by Init 30 | intercepts tensor.Tensor // one value per target; zeros when the attribute is absent 31 | postTransform postTransformOption 32 | targets int 33 | } 34 | 35 | // newLinearRegressor creates a new linearRegressor operator with one target and no post transform. 36 | func newLinearRegressor(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 37 | return &LinearRegressor{ 38 | BaseOperator: ops.NewBaseOperator( 39 | version, 40 | 1, 41 | 1, 42 | typeConstraints, 43 | "linearregressor", 44 | ), 45 | postTransform: noTransform, 46 | targets: 1, 47 | } 48 | } 49 | 50 | // Init initializes the linearRegressor operator from the node attributes. 51 | // coefficients is required and targets must be positive; intercepts defaults to zeros; 52 | // post_transform is only accepted when it is "NONE", because no post transform is implemented. 53 | func (l *LinearRegressor) Init(n *onnx.NodeProto) error { 54 | for _, attr := range n.GetAttribute() { 55 | switch attr.GetName() { 56 | case "coefficients": 57 | floats := attr.GetFloats() 58 | l.coefficients = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) 59 | case "intercepts": 60 | floats := attr.GetFloats() 61 | l.intercepts = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) 62 | case "post_transform": 63 | // Exporters often write the default explicitly, so NONE must be accepted. 64 | if postTransformOption(attr.GetS()) != noTransform { 65 | return ops.ErrUnsupportedAttribute(attr.GetName(), l) 66 | } 67 | 68 | l.postTransform = noTransform 69 | case "targets": 70 | l.targets = int(attr.GetI()) 71 | default: 72 | return ops.ErrInvalidAttribute(attr.GetName(), l) 73 | } 74 | } 75 | 76 | // Guard against a nil dereference (no coefficients) and a division by zero (targets <= 0) below. 77 | if l.coefficients == nil { 78 | return ops.ErrInvalidAttribute("coefficients", l) 79 | } 80 | 81 | if l.targets <= 0 { 82 | return ops.ErrInvalidAttribute("targets", l) 83 | } 84 | 85 | // intercepts is optional; zeros keep the broadcast-add in Apply well defined. 86 | if l.intercepts == nil { 87 | l.intercepts = tensor.New(tensor.WithShape(l.targets), tensor.WithBacking(make([]float32, l.targets))) 88 | } 89 | 90 | err := l.coefficients.Reshape(l.targets, ops.NElements(l.coefficients.Shape()...)/l.targets) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | // Transpose to (n/targets, targets) so Apply can compute X * coefficients directly. 96 | return l.coefficients.T() 97 | } 98 | 99 | // Apply computes Y = X * coefficients + intercepts using the matrices prepared by Init.
100 | func (l *LinearRegressor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 101 | X := inputs[0] 102 | 103 | result, err := tensor.MatMul(X, l.coefficients) 104 | if err != nil { 105 | return nil, err 106 | } 107 | 108 | result, intercepts, err := ops.UnidirectionalBroadcast(result, l.intercepts) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | Y, err := tensor.Add(result, intercepts) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | return []tensor.Tensor{Y}, nil 119 | } 120 | -------------------------------------------------------------------------------- /ops/linearregressor/versions.go: -------------------------------------------------------------------------------- 1 | package linearregressor 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var linearRegressorVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newLinearRegressor, 1, linearRegressorTypeConstraints), 7 | } 8 | // GetVersions returns the linearRegressor constructors keyed by operator version. 9 | func GetVersions() ops.OperatorVersions { 10 | return linearRegressorVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/logsoftmax/logsoftmax.go: -------------------------------------------------------------------------------- 1 | package logsoftmax 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // logSoftmaxTypeConstraints restricts the single input to float dtypes. 9 | var logSoftmaxTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 10 | 11 | // LogSoftmax represents the ONNX logsoftmax operator. 12 | type LogSoftmax struct { 13 | ops.BaseOperator 14 | 15 | // The axis along which to perform the LogSoftmax operation. 16 | axis int 17 | } 18 | 19 | // newLogSoftmax creates a new logsoftmax operator.
20 | func newLogSoftmax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &LogSoftmax{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 1, 25 | 1, 26 | typeConstraints, 27 | "logsoftmax", 28 | ), 29 | axis: -1, // NOTE(review): ONNX LogSoftmax-1/-11 default to axis 1 with coerce-to-2D semantics; confirm this default for versions < 13 30 | } 31 | } 32 | 33 | // Init initializes the logsoftmax operator. The only accepted attribute is axis. 34 | func (l *LogSoftmax) Init(n *onnx.NodeProto) error { 35 | attributes := n.GetAttribute() 36 | 37 | nAttributes := len(attributes) 38 | if nAttributes > 1 { 39 | return ops.ErrInvalidAttributeCount(1, nAttributes, l) 40 | } 41 | 42 | if nAttributes == 1 { 43 | // Reject unknown attributes instead of silently reading them as the axis. 44 | if attributes[0].GetName() != "axis" { 45 | return ops.ErrInvalidAttribute(attributes[0].GetName(), l) 46 | } 47 | 48 | l.axis = int(attributes[0].GetI()) 49 | } 50 | 51 | return nil 52 | } 53 | 54 | // Apply computes log(softmax(x)) along the configured axis, which must lie in [-rank, rank). 55 | func (l *LogSoftmax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 56 | input := inputs[0] 57 | nDims := len(input.Shape()) 58 | 59 | if l.axis < -nDims || l.axis >= nDims { 60 | return nil, ops.ErrAxisOutOfRange(-nDims, nDims, l.axis) 61 | } 62 | 63 | axis := l.axis 64 | if l.axis < 0 { 65 | axis += nDims 66 | } 67 | 68 | out, err := tensor.LogSoftMax(inputs[0], axis) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | return []tensor.Tensor{out}, nil 74 | } 75 | -------------------------------------------------------------------------------- /ops/logsoftmax/versions.go: -------------------------------------------------------------------------------- 1 | package logsoftmax 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var logSoftmaxVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newLogSoftmax, 1, logSoftmaxTypeConstraints), 7 | 11: ops.NewOperatorConstructor(newLogSoftmax, 11, logSoftmaxTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newLogSoftmax, 13, logSoftmaxTypeConstraints), 9 | } 10 | // GetVersions returns the logsoftmax constructors keyed by operator version. 11 | func GetVersions() ops.OperatorVersions { 12 | return logSoftmaxVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/lstm/versions.go:
-------------------------------------------------------------------------------- 1 | package lstm 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var lstmVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newLSTM, 7, lstmTypeConstraints), 7 | } 8 | // GetVersions returns the LSTM constructors keyed by operator version. 9 | func GetVersions() ops.OperatorVersions { 10 | return lstmVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/matmul/versions.go: -------------------------------------------------------------------------------- 1 | package matmul 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var matMulVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newMatMul, 1, matmul1TypeConstraints), 7 | 9: ops.NewOperatorConstructor(newMatMul, 9, matmulTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newMatMul, 13, matmulTypeConstraints), 9 | } 10 | // GetVersions returns the MatMul constructors keyed by operator version. 11 | func GetVersions() ops.OperatorVersions { 12 | return matMulVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/mul/mul.go: -------------------------------------------------------------------------------- 1 | package mul 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // mulTypeConstraints holds one allowed-dtype set per input (A, B); both sets are identical. 9 | var mulTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 12 | } 13 | 14 | // Mul represents the ONNX mul operator. 15 | type Mul struct { 16 | ops.BaseOperator 17 | } 18 | 19 | // newMul creates a new mul operator.
20 | func newMul(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Mul{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 2, 25 | 2, 26 | typeConstraints, 27 | "mul", 28 | ), 29 | } 30 | } 31 | 32 | // Init initializes the mul operator. The node is not read, since the operator takes no attributes. 33 | func (m *Mul) Init(*onnx.NodeProto) error { 34 | return nil 35 | } 36 | 37 | // Apply computes the element-wise product A * B with multidirectional broadcasting. 38 | func (m *Mul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 39 | return ops.ApplyBinaryOperation( 40 | inputs[0], 41 | inputs[1], 42 | ops.Mul, 43 | ops.MultidirectionalBroadcasting, 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /ops/mul/versions.go: -------------------------------------------------------------------------------- 1 | package mul 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var mulVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newMul, 7, mulTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newMul, 13, mulTypeConstraints), 8 | } 9 | // GetVersions returns the mul constructors keyed by operator version. 10 | func GetVersions() ops.OperatorVersions { 11 | return mulVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/multidir_broadcast_test.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorgonia.org/tensor" 8 | ) 9 | // TestMultidirectionalBroadcast checks the broadcast output shapes and the errors returned for incompatible shapes. 10 | func TestMultidirectionalBroadcast(t *testing.T) { 11 | tests := []struct { 12 | shapes [][]int 13 | expectedShape tensor.Shape 14 | err error 15 | }{ 16 | { 17 | [][]int{{2}, {2, 2}}, 18 | []int{2, 2}, 19 | nil, 20 | }, 21 | { 22 | [][]int{{2, 3, 4, 5}, {}}, 23 | []int{2, 3, 4, 5}, 24 | nil, 25 | }, 26 | { 27 | [][]int{{2, 3, 4, 5}, {5}}, 28 | []int{2, 3, 4, 5}, 29 | nil, 30 | }, 31 | { 32 | [][]int{{4, 5}, {2, 3, 4, 5}}, 33 | []int{2, 3, 4, 5}, 34 | nil, 35 | }, 36 | { 37 | [][]int{{1, 4, 5}, {2, 3, 1, 1}}, 38 | []int{2,
3, 4, 5}, 39 | nil, 40 | }, 41 | { 42 | [][]int{{3, 4, 5}, {2, 1, 1, 1}}, 43 | []int{2, 3, 4, 5}, 44 | nil, 45 | }, 46 | { 47 | [][]int{{1, 4, 5}, {2, 1, 1, 3}}, 48 | nil, 49 | ErrMultidirBroadcast([]int{1, 4, 5}, []int{2, 1, 1, 3}, ErrIncompatibleDimensions()), 50 | }, 51 | { 52 | [][]int{{5}, {2, 3, 4}}, 53 | nil, 54 | ErrMultidirBroadcast([]int{5}, []int{2, 3, 4}, ErrIncompatibleDimensions()), 55 | }, 56 | } 57 | 58 | for _, test := range tests { 59 | A := Float32TensorFixture(test.shapes[0]...) 60 | B := Float32TensorFixture(test.shapes[1]...) 61 | 62 | newA, newB, err := MultidirectionalBroadcast(A, B) 63 | 64 | assert.Equal(t, test.err, err) 65 | 66 | if err == nil { 67 | assert.Equal(t, test.expectedShape, newA.Shape()) 68 | assert.Equal(t, test.expectedShape, newB.Shape()) 69 | } else { 70 | assert.Nil(t, newA) 71 | assert.Nil(t, newB) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /ops/not/not.go: -------------------------------------------------------------------------------- 1 | package not 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // notTypeConstraints restricts the single input to boolean tensors. 9 | var notTypeConstraints = [][]tensor.Dtype{{tensor.Bool}} 10 | 11 | // Not represents the ONNX not operator. 12 | type Not struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newNot creates a new not operator. 17 | func newNot(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Not{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "not", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the not operator. The node is not read, since the operator takes no attributes. 30 | func (n *Not) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply negates every element of the boolean input tensor.
35 | func (n *Not) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | out, err := inputs[0].Apply(not) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return []tensor.Tensor{out}, nil 42 | } 43 | 44 | func not(x bool) bool { 45 | return !x 46 | } 47 | -------------------------------------------------------------------------------- /ops/not/not_test.go: -------------------------------------------------------------------------------- 1 | package not 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestNotInit(t *testing.T) { 12 | n := &Not{} 13 | 14 | // since 'not' does not have any attributes we pass in nil. This should not 15 | // fail initializing the not. 16 | err := n.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestNot(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []bool 24 | shape []int 25 | expected []bool 26 | }{ 27 | { 28 | 1, 29 | []bool{true, false, true, false}, 30 | []int{2, 2}, 31 | []bool{false, true, false, true}, 32 | }, 33 | { 34 | 1, 35 | []bool{true, true, false, false}, 36 | []int{1, 4}, 37 | []bool{false, false, true, true}, 38 | }, 39 | { 40 | 1, 41 | []bool{false, false, false, false}, 42 | []int{4, 1}, 43 | []bool{true, true, true, true}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | not := notVersions[test.version]() 53 | res, err := not.Apply(inputs) 54 | assert.Nil(t, err) 55 | 56 | assert.Nil(t, err) 57 | assert.Equal(t, test.expected, res[0].Data()) 58 | } 59 | } 60 | 61 | func TestInputValidationNot(t *testing.T) { 62 | tests := []struct { 63 | version int64 64 | inputs []tensor.Tensor 65 | err error 66 | }{ 67 | { 68 | 1, 69 | []tensor.Tensor{ 70 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 71 | }, 72 | nil, 73 | }, 74 | { 
75 | 1, 76 | []tensor.Tensor{}, 77 | ops.ErrInvalidInputCount(0, not1BaseOpFixture()), 78 | }, 79 | { 80 | 1, 81 | []tensor.Tensor{ 82 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 83 | }, 84 | ops.ErrInvalidInputType(0, "int", not1BaseOpFixture()), 85 | }, 86 | } 87 | 88 | for _, test := range tests { 89 | not := notVersions[test.version]() 90 | validated, err := not.ValidateInputs(test.inputs) 91 | 92 | assert.Equal(t, test.err, err) 93 | 94 | if test.err == nil { 95 | assert.Equal(t, test.inputs, validated) 96 | } 97 | } 98 | } 99 | 100 | func not1BaseOpFixture() ops.BaseOperator { 101 | return ops.NewBaseOperator( 102 | 1, 103 | 1, 104 | 1, 105 | notTypeConstraints, 106 | "not", 107 | ) 108 | } 109 | -------------------------------------------------------------------------------- /ops/not/versions.go: -------------------------------------------------------------------------------- 1 | package not 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var notVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newNot, 1, notTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return notVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/operator.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | type OperatorVersions map[int64]OperatorFactory 9 | 10 | type OperatorFactory func() Operator 11 | 12 | type Constructor func(int, [][]tensor.Dtype) Operator 13 | 14 | func NewOperatorConstructor(fn Constructor, version int, typeContstraint [][]tensor.Dtype) OperatorFactory { 15 | return func() Operator { 16 | return fn(version, typeContstraint) 17 | } 18 | } 19 | 20 | // Operator is the base interface for all operators. 
21 | type Operator interface { 22 | // String should return a simple string describing the operator. 23 | String() string 24 | 25 | // Init should initialize the operator based on the given node. 26 | // This node contains the attributes, the expected outputs and more. How these 27 | // attributes influence the operator is defined by the ONNX standard, and can be 28 | // found in https://github.com/onnx/onnx/blob/main/docs/Operators.md 29 | Init(*onnx.NodeProto) error 30 | 31 | // Apply should apply the operator to the list of input tensors. It should return a 32 | // list with output tensors, the result of the operator. 33 | Apply([]tensor.Tensor) ([]tensor.Tensor, error) 34 | 35 | // GetMinInputs should return the minimum number of inputs this operator expects. 36 | GetMinInputs() int 37 | 38 | // GetMaxInputs should return the maximum number of inputs this operator expects. 39 | GetMaxInputs() int 40 | 41 | // GetInputTypeConstraints should return a list. Every element represents a set of 42 | // allowed tensor dtypes for the corresponding input tensor. 43 | GetInputTypeConstraints() [][]tensor.Dtype 44 | 45 | // ValidateInputs should validate the list of input tensors. It should check both 46 | // the number of inputs and the dtypes of the tensors. 47 | ValidateInputs([]tensor.Tensor) ([]tensor.Tensor, error) 48 | 49 | // Version should return the version of this operator. 50 | Version() int 51 | } 52 | -------------------------------------------------------------------------------- /ops/or/or.go: -------------------------------------------------------------------------------- 1 | package or 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // orTypeConstraints restricts both inputs to boolean tensors. 9 | var orTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} 10 | 11 | // Or represents the ONNX or operator.
12 | type Or struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newOr creates a new or operator. 17 | func newOr(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Or{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraints, 24 | "or", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the or operator. 30 | func (o *Or) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the or operator. 35 | func (o *Or) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | return ops.ApplyBinaryOperation( 37 | inputs[0], 38 | inputs[1], 39 | ops.Or, 40 | ops.MultidirectionalBroadcasting, 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ops/or/versions.go: -------------------------------------------------------------------------------- 1 | package or 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var orVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newOr, 7, orTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return orVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/pow/pow.go: -------------------------------------------------------------------------------- 1 | package pow 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var pow7TypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Float32, tensor.Float64}, 11 | {tensor.Float32, tensor.Float64}, 12 | } 13 | 14 | var powTypeConstraints = [][]tensor.Dtype{ 15 | {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 16 | {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 17 | } 18 | 19 | // Pow represents the ONNX pow operator. 
20 | type Pow struct { 21 | ops.BaseOperator 22 | } 23 | 24 | // newPow creates a new pow operator. 25 | func newPow(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 26 | return &Pow{ 27 | BaseOperator: ops.NewBaseOperator( 28 | version, 29 | 2, 30 | 2, 31 | typeConstraints, 32 | "pow", 33 | ), 34 | } 35 | } 36 | 37 | // Init initializes the pow operator. 38 | func (a *Pow) Init(*onnx.NodeProto) error { 39 | return nil 40 | } 41 | 42 | // Apply applies the pow operator. 43 | func (a *Pow) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 44 | powTensor := inputs[1] 45 | if inputs[0].Dtype() != powTensor.Dtype() { 46 | to, err := ops.DTypeToONNXType(inputs[0].Dtype()) 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | powTensor, err = ops.ConvertTensorDtype(powTensor, to) 52 | if err != nil { 53 | return nil, err 54 | } 55 | } 56 | 57 | return ops.ApplyBinaryOperation( 58 | inputs[0], 59 | powTensor, 60 | ops.Pow, 61 | ops.MultidirectionalBroadcasting, 62 | ) 63 | } 64 | -------------------------------------------------------------------------------- /ops/pow/pow_test.go: -------------------------------------------------------------------------------- 1 | package pow 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestPowInit(t *testing.T) { 12 | p := &Pow{} 13 | err := p.Init(nil) 14 | assert.Nil(t, err) 15 | } 16 | 17 | func TestPow(t *testing.T) { 18 | tests := []struct { 19 | version int64 20 | backing0 any 21 | backing1 any 22 | shapes [][]int 23 | expected any 24 | }{ 25 | { 26 | 13, 27 | []float32{0, 1, 2, 3}, 28 | []float32{1, 1, 1, 1}, 29 | [][]int{{2, 2}, {2, 2}}, 30 | []float32{0, 1, 2, 3}, 31 | }, 32 | { 33 | 13, 34 | []float32{0, 1, 2, 3, 4, 5}, 35 | []float32{2, 2, 2, 2, 2, 2}, 36 | [][]int{{3, 2}, {3, 2}}, 37 | []float32{0, 1, 4, 9, 16, 25}, 38 | }, 39 | { 40 | 13, 41 | []float32{0, 1}, 42 | 
[]float32{0, 1, 2, 3}, 43 | [][]int{{2}, {2, 2}}, 44 | []float32{1, 1, 0, 1}, 45 | }, 46 | { 47 | 13, 48 | []int32{1, 2, 3}, 49 | []int32{4, 5, 6}, 50 | [][]int{{3}, {3}}, 51 | []int32{1, 32, 729}, 52 | }, 53 | } 54 | 55 | for _, test := range tests { 56 | inputs := []tensor.Tensor{ 57 | ops.TensorWithBackingFixture(test.backing0, test.shapes[0]...), 58 | ops.TensorWithBackingFixture(test.backing1, test.shapes[1]...), 59 | } 60 | 61 | pow := powVersions[test.version]() 62 | 63 | res, err := pow.Apply(inputs) 64 | assert.Nil(t, err) 65 | 66 | assert.Equal(t, test.expected, res[0].Data()) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /ops/pow/versions.go: -------------------------------------------------------------------------------- 1 | package pow 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | ) 6 | 7 | var powVersions = ops.OperatorVersions{ 8 | 7: ops.NewOperatorConstructor(newPow, 7, pow7TypeConstraints), 9 | 12: ops.NewOperatorConstructor(newPow, 12, powTypeConstraints), 10 | 13: ops.NewOperatorConstructor(newPow, 13, powTypeConstraints), 11 | } 12 | 13 | func GetVersions() ops.OperatorVersions { 14 | return powVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/prelu/versions.go: -------------------------------------------------------------------------------- 1 | package prelu 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/ops" 5 | "gorgonia.org/tensor" 6 | ) 7 | 8 | var preluVersions = ops.OperatorVersions{ 9 | 7: ops.NewOperatorConstructor(newPRelu, 7, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}}), 10 | 9: ops.NewOperatorConstructor(newPRelu, 9, preluTypeConstraints), 11 | } 12 | 13 | func GetVersions() ops.OperatorVersions { 14 | return preluVersions 15 | } 16 | -------------------------------------------------------------------------------- /ops/recurrent_utils.go: 
-------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "gorgonia.org/tensor" 5 | ) 6 | 7 | // SequenceProcessDirection is the direction in which a sequential input is processed. 8 | // We can process sequential inputs forward (from first to last), in reverse (from 9 | // last to first) or bidirectional (which is both forward and reverse added together). 10 | type SequenceProcessDirection string 11 | 12 | const ( 13 | Forward SequenceProcessDirection = "forward" 14 | Reverse SequenceProcessDirection = "reverse" 15 | Bidirectional SequenceProcessDirection = "bidirectional" 16 | ) 17 | 18 | // These constants define attributes that are applicable to GRU, LSTM and RNN operators. 19 | const ( 20 | ActivationAlphaAttr = "activation_alpha" 21 | ActivationBetaAttr = "activation_beta" 22 | ActivationsAttr = "activations" 23 | ClipAttr = "clip" 24 | DirectionAttr = "direction" 25 | HiddenSizeAttr = "hidden_size" 26 | ) 27 | 28 | // ExtractMatrices extracts a given number of matrices from tensor M. 29 | // M contains concatenated matrices along a certain dimension. 30 | // M is assumed to have a shape of (num_directions, nMatrices * hidden_size, ...) and we extract the 31 | // by slicing over the 'nMatrices * hidden_size' dimension. 32 | // This method is specific for recurrent operators RNN, GRU and LSTM. 33 | func ExtractMatrices(M tensor.Tensor, nMatrices, nDimensions, hiddenSize int) ([]tensor.Tensor, error) { 34 | dirSlice := NewSlicer(0) 35 | matrices := make([]tensor.Tensor, nMatrices) 36 | 37 | for i := 0; i < nMatrices; i++ { 38 | hiddenSlice := NewSlicer(i*hiddenSize, (i+1)*hiddenSize) 39 | 40 | allSlices := make([]tensor.Slice, nDimensions) 41 | allSlices[0] = dirSlice 42 | allSlices[1] = hiddenSlice 43 | 44 | for i := 2; i < nDimensions; i++ { 45 | allSlices[i] = nil 46 | } 47 | 48 | m, err := M.Slice(allSlices...) 
49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | matrices[i] = m 54 | } 55 | 56 | return matrices, nil 57 | } 58 | 59 | // ZeroTensor returns a tensor filled with zeros with the given shape. 60 | func ZeroTensor(shape ...int) tensor.Tensor { 61 | return tensor.New( 62 | tensor.WithShape(shape...), 63 | tensor.WithBacking(Zeros(NElements(shape...))), 64 | ) 65 | } 66 | 67 | // OnesTensor returns a new tensor with the same shape as the given tensor intialized with all ones. 68 | func OnesTensor(t tensor.Tensor) tensor.Tensor { 69 | return tensor.New( 70 | tensor.WithShape(t.Shape()...), 71 | tensor.WithBacking(Ones(NElements(t.Shape()...))), 72 | ) 73 | } 74 | -------------------------------------------------------------------------------- /ops/reducemax/constants.go: -------------------------------------------------------------------------------- 1 | package reducemax 2 | 3 | const ( 4 | axes = "axes" 5 | keepDims = "keepdims" 6 | ) 7 | -------------------------------------------------------------------------------- /ops/reducemax/reduce_max.go: -------------------------------------------------------------------------------- 1 | package reducemax 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var reduceMaxTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint8, tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | var reduceMax11TypeConstraints = [][]tensor.Dtype{ 14 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 15 | } 16 | 17 | const ( 18 | MinReduceMaxAttributes = 1 19 | MaxReduceMaxAttributes = 2 20 | ) 21 | 22 | // ReduceMax represents the ONNX reduceMax operator. 23 | type ReduceMax struct { 24 | ops.BaseOperator 25 | 26 | axes []int 27 | keepDims bool 28 | } 29 | 30 | // newReduceMax creates a new reduceMax operator. 
31 | func newReduceMax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 32 | return &ReduceMax{ 33 | BaseOperator: ops.NewBaseOperator( 34 | version, 35 | 1, 36 | 1, 37 | typeConstraints, 38 | "reducemax", 39 | ), 40 | axes: []int{}, 41 | keepDims: true, 42 | } 43 | } 44 | 45 | // Init initializes the reduceMax operator. 46 | func (r *ReduceMax) Init(n *onnx.NodeProto) error { 47 | attributes := n.GetAttribute() 48 | if len(attributes) == 0 || len(attributes) > MaxReduceMaxAttributes { 49 | return ops.ErrInvalidOptionalAttributeCount(MinReduceMaxAttributes, MaxReduceMaxAttributes, len(attributes), r) 50 | } 51 | 52 | for _, attr := range attributes { 53 | switch attr.GetName() { 54 | case axes: 55 | value, err := ops.AnyToIntSlice(attr.GetInts()) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | r.axes = value 61 | case keepDims: 62 | r.keepDims = attr.GetI() == 1 63 | default: 64 | return ops.ErrInvalidAttribute(attr.GetName(), r) 65 | } 66 | } 67 | 68 | return nil 69 | } 70 | 71 | // Apply applies the reduceMax operator. 72 | func (r *ReduceMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 73 | input := tensor.New(tensor.WithBacking(inputs[0].Data()), tensor.WithShape(inputs[0].Shape()...)) 74 | 75 | axes := make([]int, len(r.axes)) 76 | for i, axis := range r.axes { 77 | axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape())) 78 | } 79 | 80 | out, err := input.Max(axes...) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | if r.keepDims { 86 | newShape := input.Shape() 87 | for _, axes := range axes { 88 | newShape[axes] = 1 89 | } 90 | 91 | err := out.Reshape(newShape...) 
92 | if err != nil { 93 | return nil, err 94 | } 95 | } 96 | 97 | return []tensor.Tensor{out}, nil 98 | } 99 | -------------------------------------------------------------------------------- /ops/reducemax/versions.go: -------------------------------------------------------------------------------- 1 | package reducemax 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var reduceMaxVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newReduceMax, 1, reduceMax11TypeConstraints), 7 | 11: ops.NewOperatorConstructor(newReduceMax, 11, reduceMax11TypeConstraints), 8 | 12: ops.NewOperatorConstructor(newReduceMax, 12, reduceMaxTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newReduceMax, 13, reduceMaxTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return reduceMaxVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/reducemean/versions.go: -------------------------------------------------------------------------------- 1 | package reducemean 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var reduceMeanVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newReduceMean, 1, reduceMeanTypeConstraints), 7 | 11: ops.NewOperatorConstructor(newReduceMean, 11, reduceMeanTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newReduceMean, 13, reduceMeanTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return reduceMeanVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/reducemin/constants.go: -------------------------------------------------------------------------------- 1 | package reducemin 2 | 3 | const ( 4 | axes = "axes" 5 | keepDims = "keepdims" 6 | ) 7 | -------------------------------------------------------------------------------- /ops/reducemin/reduce_min.go: 
-------------------------------------------------------------------------------- 1 | package reducemin 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var reduceMinTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint8, tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | var reduceMin11TypeConstraints = [][]tensor.Dtype{ 14 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 15 | } 16 | 17 | const ( 18 | MinReduceMinAttributes = 1 19 | MaxReduceMinAttributes = 2 20 | ) 21 | 22 | // ReduceMin represents the ONNX reduceMin operator. 23 | type ReduceMin struct { 24 | ops.BaseOperator 25 | 26 | axes []int 27 | keepDims bool 28 | } 29 | 30 | // newReduceMin creates a new reduceMin operator. 31 | func newReduceMin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 32 | return &ReduceMin{ 33 | BaseOperator: ops.NewBaseOperator( 34 | version, 35 | 1, 36 | 1, 37 | typeConstraints, 38 | "reducemin", 39 | ), 40 | axes: []int{}, 41 | keepDims: true, 42 | } 43 | } 44 | 45 | // Init initializes the reduceMin operator. 46 | func (r *ReduceMin) Init(n *onnx.NodeProto) error { 47 | attributes := n.GetAttribute() 48 | if len(attributes) == 0 || len(attributes) > MaxReduceMinAttributes { 49 | return ops.ErrInvalidOptionalAttributeCount(MinReduceMinAttributes, MaxReduceMinAttributes, len(attributes), r) 50 | } 51 | 52 | for _, attr := range attributes { 53 | switch attr.GetName() { 54 | case axes: 55 | value, err := ops.AnyToIntSlice(attr.GetInts()) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | r.axes = value 61 | case keepDims: 62 | r.keepDims = attr.GetI() == 1 63 | default: 64 | return ops.ErrInvalidAttribute(attr.GetName(), r) 65 | } 66 | } 67 | 68 | return nil 69 | } 70 | 71 | // Apply applies the reduceMin operator. 
72 | func (r *ReduceMin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 73 | input := tensor.New(tensor.WithBacking(inputs[0].Data()), tensor.WithShape(inputs[0].Shape()...)) 74 | 75 | axes := make([]int, len(r.axes)) 76 | for i, axis := range r.axes { 77 | axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape())) 78 | } 79 | 80 | out, err := input.Min(axes...) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | if r.keepDims { 86 | newShape := input.Shape() 87 | for _, axes := range axes { 88 | newShape[axes] = 1 89 | } 90 | 91 | err := out.Reshape(newShape...) 92 | if err != nil { 93 | return nil, err 94 | } 95 | } 96 | 97 | return []tensor.Tensor{out}, nil 98 | } 99 | -------------------------------------------------------------------------------- /ops/reducemin/versions.go: -------------------------------------------------------------------------------- 1 | package reducemin 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var reduceMinVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newReduceMin, 1, reduceMin11TypeConstraints), 7 | 11: ops.NewOperatorConstructor(newReduceMin, 11, reduceMin11TypeConstraints), 8 | 12: ops.NewOperatorConstructor(newReduceMin, 12, reduceMinTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newReduceMin, 13, reduceMinTypeConstraints), 10 | } 11 | 12 | func GetVersions() ops.OperatorVersions { 13 | return reduceMinVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/relu/relu.go: -------------------------------------------------------------------------------- 1 | package relu 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var reluTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 10 | 11 | // Relu represents the ONNX relu operator. 
12 | type Relu struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newRelu creates a new relu operator. 17 | func newRelu(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Relu{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "relu", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the relu operator. 30 | func (r *Relu) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the relu operator. 35 | func (r *Relu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | out, err := ops.ReLU(inputs[0]) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return []tensor.Tensor{out}, nil 42 | } 43 | -------------------------------------------------------------------------------- /ops/relu/versions.go: -------------------------------------------------------------------------------- 1 | package relu 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var reluVersions = ops.OperatorVersions{ 6 | 6: ops.NewOperatorConstructor(newRelu, 6, reluTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newRelu, 13, reluTypeConstraints), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return reluVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/reshape/reshape.go: -------------------------------------------------------------------------------- 1 | package reshape 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var reshapeTypeConstraints = [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} 10 | 11 | const ( 12 | ReshapeMinInputs = 2 13 | ReshapeMaxInputs = 2 14 | ) 15 | 16 | // Reshape represents the ONNX reshape operator. 17 | type Reshape struct { 18 | ops.BaseOperator 19 | } 20 | 21 | // newReshape creates a new reshape operator. 
22 | func newReshape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 23 | return &Reshape{ 24 | BaseOperator: ops.NewBaseOperator( 25 | version, 26 | ReshapeMinInputs, 27 | ReshapeMaxInputs, 28 | typeConstraints, 29 | "reshape", 30 | ), 31 | } 32 | } 33 | 34 | // Init initializes the reshape operator. 35 | func (r *Reshape) Init(*onnx.NodeProto) error { 36 | return nil 37 | } 38 | 39 | // Apply applies the reshape operator. 40 | func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 41 | t := inputs[0] 42 | 43 | newShape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data().([]int64))) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | err = processShape(newShape, t.Shape()) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | out, ok := t.Clone().(tensor.Tensor) 54 | if !ok { 55 | return nil, ops.ErrTypeAssert("tensor.Tensor", t.Clone()) 56 | } 57 | 58 | err = out.Reshape(newShape...) 59 | 60 | return []tensor.Tensor{out}, err 61 | } 62 | 63 | func processShape(newShape, currentShape []int) error { 64 | for i := 0; i < len(newShape); i++ { 65 | if newShape[i] == 0 { 66 | if i >= len(currentShape) { 67 | return ops.ErrDimension("could not infer dim size") 68 | } 69 | 70 | newShape[i] = currentShape[i] 71 | } 72 | } 73 | 74 | // Calculate the total number of elements in the original tensor. 75 | totalSize := ops.NElements(currentShape...) 76 | 77 | for i := 0; i < len(newShape); i++ { 78 | // When encountering a -1 dim size, calculate which size this should be. 
79 | if newShape[i] == -1 { 80 | remainingSize := totalSize 81 | 82 | for j := 0; j < len(newShape); j++ { 83 | if j == i { 84 | continue 85 | } 86 | 87 | if newShape[j] == -1 { 88 | return ops.ErrDimension("at most one -1 dim size is allowed") 89 | } 90 | 91 | remainingSize /= newShape[j] 92 | } 93 | 94 | newShape[i] = remainingSize 95 | 96 | break 97 | } 98 | } 99 | 100 | return nil 101 | } 102 | -------------------------------------------------------------------------------- /ops/reshape/versions.go: -------------------------------------------------------------------------------- 1 | package reshape 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var reshapeVersions = ops.OperatorVersions{ 6 | 5: ops.NewOperatorConstructor(newReshape, 5, reshapeTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newReshape, 13, reshapeTypeConstraints), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return reshapeVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/rnn/versions.go: -------------------------------------------------------------------------------- 1 | package rnn 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var rnnVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newRNN, 7, rnnTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return rnnVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/scaler/scaler.go: -------------------------------------------------------------------------------- 1 | package scaler 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var scalerTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | const ( 14 | ScalerExpectedAttributes = 2 15 | ) 16 | 17 | // Scaler 
represents the ONNX-ml scaler operator. 18 | type Scaler struct { 19 | ops.BaseOperator 20 | 21 | offset tensor.Tensor 22 | scale tensor.Tensor 23 | } 24 | 25 | // newScaler creates a new scaler operator. 26 | func newScaler(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 27 | return &Scaler{ 28 | BaseOperator: ops.NewBaseOperator( 29 | version, 30 | 1, 31 | 1, 32 | typeConstraints, 33 | "scaler", 34 | ), 35 | } 36 | } 37 | 38 | // Init initializes the scaler operator. 39 | func (s *Scaler) Init(n *onnx.NodeProto) error { 40 | attributes := n.GetAttribute() 41 | if len(attributes) != ScalerExpectedAttributes { 42 | return ops.ErrInvalidAttributeCount(ScalerExpectedAttributes, len(attributes), s) 43 | } 44 | 45 | for _, attr := range attributes { 46 | switch attr.GetName() { 47 | case "offset": 48 | floats := attr.GetFloats() 49 | s.offset = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) 50 | case "scale": 51 | floats := attr.GetFloats() 52 | s.scale = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) 53 | default: 54 | return ops.ErrInvalidAttribute(attr.GetName(), s) 55 | } 56 | } 57 | 58 | return nil 59 | } 60 | 61 | // Apply applies the scaler operator. 
62 | func (s *Scaler) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 63 | X, offset, err := ops.UnidirectionalBroadcast(inputs[0], s.offset) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | X, err = tensor.Sub(X, offset) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | X, scale, err := ops.UnidirectionalBroadcast(X, s.scale) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | Y, err := tensor.Mul(X, scale) 79 | if err != nil { 80 | return nil, err 81 | } 82 | 83 | return []tensor.Tensor{Y}, nil 84 | } 85 | -------------------------------------------------------------------------------- /ops/scaler/versions.go: -------------------------------------------------------------------------------- 1 | package scaler 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var scalerVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newScaler, 1, scalerTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return scalerVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/shape/shape.go: -------------------------------------------------------------------------------- 1 | package shape 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var shapeTypeConstraints = [][]tensor.Dtype{ops.AllTypes} 10 | 11 | // Shape represents the ONNX shape operator. 12 | type Shape struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newShape creates a new shape operator. 17 | func newShape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Shape{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "shape", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the shape operator. 
30 | func (s *Shape) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply the shape operator to the graph. It creates a node that holds the shape of the 35 | // input node as 1D int64 tensor. 36 | func (s *Shape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 37 | nodeShape := inputs[0].Shape() 38 | shape := make([]int64, len(nodeShape)) 39 | 40 | for i, dimSize := range nodeShape { 41 | shape[i] = int64(dimSize) 42 | } 43 | 44 | out := tensor.New(tensor.WithShape(len(nodeShape)), tensor.WithBacking(shape)) 45 | 46 | return []tensor.Tensor{out}, nil 47 | } 48 | -------------------------------------------------------------------------------- /ops/shape/shape_test.go: -------------------------------------------------------------------------------- 1 | package shape 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestShapeInit(t *testing.T) { 12 | s := &Shape{} 13 | 14 | // since 'shape' does not have any attributes we pass in nil. This should not 15 | // fail initializing the shape operator. 
16 | err := s.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestShape(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | inputShape []int 24 | expected []int64 25 | }{ 26 | { 27 | 1, 28 | []int{1, 2, 3, 4}, 29 | []int64{1, 2, 3, 4}, 30 | }, 31 | { 32 | 13, 33 | []int{2, 3}, 34 | []int64{2, 3}, 35 | }, 36 | } 37 | 38 | for _, test := range tests { 39 | shape := shapeVersions[test.version]() 40 | inputs := []tensor.Tensor{ 41 | ops.Float32TensorFixture(test.inputShape...), 42 | } 43 | 44 | res, err := shape.Apply(inputs) 45 | assert.Nil(t, err) 46 | assert.Equal(t, test.expected, res[0].Data()) 47 | } 48 | } 49 | 50 | func TestInputValidationShape(t *testing.T) { 51 | tests := []struct { 52 | version int64 53 | inputs []tensor.Tensor 54 | err error 55 | }{ 56 | { 57 | 1, 58 | []tensor.Tensor{ops.TensorWithBackingFixture([]uint32{3, 4}, 2)}, 59 | nil, 60 | }, 61 | { 62 | 13, 63 | []tensor.Tensor{ops.TensorWithBackingFixture([]float32{3, 4}, 2)}, 64 | nil, 65 | }, 66 | { 67 | 13, 68 | []tensor.Tensor{}, 69 | ops.ErrInvalidInputCount(0, shape13BaseOpFixture()), 70 | }, 71 | { 72 | 13, 73 | []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, 74 | ops.ErrInvalidInputType(0, "int", shape13BaseOpFixture()), 75 | }, 76 | } 77 | 78 | for _, test := range tests { 79 | shape := shapeVersions[test.version]() 80 | validated, err := shape.ValidateInputs(test.inputs) 81 | 82 | assert.Equal(t, test.err, err) 83 | 84 | if test.err == nil { 85 | assert.Equal(t, test.inputs, validated) 86 | } 87 | } 88 | } 89 | 90 | func shape13BaseOpFixture() ops.BaseOperator { 91 | return ops.NewBaseOperator(13, 1, 1, shapeTypeConstraints, "shape") 92 | } 93 | -------------------------------------------------------------------------------- /ops/shape/versions.go: -------------------------------------------------------------------------------- 1 | package shape 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var shapeVersions = 
ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newShape, 1, shapeTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newShape, 13, shapeTypeConstraints), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return shapeVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/sigmoid/sigmoid.go: -------------------------------------------------------------------------------- 1 | package sigmoid 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var sigmoidTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 10 | 11 | // Sigmoid represents the ONNX sigmoid operator. 12 | type Sigmoid struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newSigmoid returns a new sigmoid operator. 17 | func newSigmoid(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Sigmoid{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "sigmoid", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the sigmoid operator. 30 | func (s *Sigmoid) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply the sigmoid operator to the input node. 
35 | func (s *Sigmoid) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | out, err := ops.Sigmoid(inputs[0]) 37 | 38 | return []tensor.Tensor{out}, err 39 | } 40 | -------------------------------------------------------------------------------- /ops/sigmoid/versions.go: -------------------------------------------------------------------------------- 1 | package sigmoid 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var sigmoidVersions = ops.OperatorVersions{ 6 | 6: ops.NewOperatorConstructor(newSigmoid, 6, sigmoidTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newSigmoid, 13, sigmoidTypeConstraints), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return sigmoidVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/sin/sin.go: -------------------------------------------------------------------------------- 1 | package sin 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var sinTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Sin represents the ONNX sin operator. 14 | type Sin struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newSin creates a new sin operator. 19 | func newSin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Sin{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "sin", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the sin operator. 32 | func (s *Sin) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the sin operator. 
37 | func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(sin[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(sin[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // sin computes the sine of x in float64 precision and converts the result back to T. 59 | func sin[T ops.FloatType](x T) T { 60 | return T(math.Sin(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/sin/sin_test.go: -------------------------------------------------------------------------------- 1 | package sin 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestSinInit(t *testing.T) { 12 | a := &Sin{} 13 | 14 | // since 'sin' does not have any attributes we pass in nil. This should not 15 | // fail initializing the sin.
16 | err := a.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestSin(t *testing.T) { 21 | tests := []struct { 22 | version int64 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | 7, 29 | []float32{-2, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{-0.9092974, -0.84147096, 0, 0.84147096}, 32 | }, 33 | { 34 | 7, 35 | []float32{1, 3, 4, 5}, 36 | []int{1, 4}, 37 | []float32{0.84147096, 0.14112, -0.7568025, -0.9589243}, 38 | }, 39 | { 40 | 7, 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{-0.84147096, -0.84147096, -0.84147096, -0.84147096}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | sin := sinVersions[test.version]() 49 | inputs := []tensor.Tensor{ 50 | ops.TensorWithBackingFixture(test.backing, test.shape...), 51 | } 52 | 53 | res, err := sin.Apply(inputs) 54 | assert.Nil(t, err) 55 | 56 | assert.Len(t, res, 1) 57 | assert.Equal(t, test.expected, res[0].Data()) 58 | } 59 | } 60 | 61 | func TestInputValidationSin(t *testing.T) { 62 | tests := []struct { 63 | version int64 64 | inputs []tensor.Tensor 65 | err error 66 | }{ 67 | { 68 | 7, 69 | []tensor.Tensor{ 70 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 71 | }, 72 | nil, 73 | }, 74 | { 75 | 7, 76 | []tensor.Tensor{ 77 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 78 | }, 79 | nil, 80 | }, 81 | { 82 | 7, 83 | []tensor.Tensor{}, 84 | ops.ErrInvalidInputCount(0, sin7BaseOpFixture()), 85 | }, 86 | { 87 | 7, 88 | []tensor.Tensor{ 89 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 90 | }, 91 | ops.ErrInvalidInputType(0, "int", sin7BaseOpFixture()), 92 | }, 93 | } 94 | 95 | for _, test := range tests { 96 | sin := sinVersions[test.version]() 97 | validated, err := sin.ValidateInputs(test.inputs) 98 | 99 | assert.Equal(t, test.err, err) 100 | 101 | if test.err == nil { 102 | assert.Equal(t, test.inputs, validated) 103 | } 104 | } 105 | } 106 | // sin7BaseOpFixture returns the opset 7 base operator used to build the expected validation errors. 
107 | func sin7BaseOpFixture() ops.BaseOperator { 108 | return ops.NewBaseOperator(7, 1, 1,
sinTypeConstraints, "sin") 109 | } 110 | -------------------------------------------------------------------------------- /ops/sin/versions.go: -------------------------------------------------------------------------------- 1 | package sin 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var sinVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newSin, 7, sinTypeConstraints), 7 | } 8 | // GetVersions returns the operator constructors registered for opset version 7 of sin. 9 | func GetVersions() ops.OperatorVersions { 10 | return sinVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/sinh/sinh.go: -------------------------------------------------------------------------------- 1 | package sinh 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var sinhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | // Sinh represents the ONNX sinh operator. 14 | type Sinh struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newSinh creates a new sinh operator for the given opset version that takes exactly one input. 19 | func newSinh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Sinh{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "sinh", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the sinh operator. Sinh has no attributes, so this is a no-op. 32 | func (s *Sinh) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the sinh operator element-wise to a float32 or float64 input tensor.
37 | func (s *Sinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(sinh[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(sinh[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // sinh computes the hyperbolic sine of x in float64 precision and converts the result back to T. 59 | func sinh[T ops.FloatType](x T) T { 60 | return T(math.Sinh(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/sinh/sinh_test.go: -------------------------------------------------------------------------------- 1 | package sinh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestSinhInit(t *testing.T) { 12 | s := &Sinh{} 13 | 14 | // since 'sinh' does not have any attributes we pass in nil. This should not 15 | // fail initializing the sinh.
16 | err := s.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestSinh(t *testing.T) { 21 | tests := []struct { 22 | sinh *Sinh 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | &Sinh{}, 29 | []float32{-2, -1, 0, 1}, 30 | []int{2, 2}, 31 | []float32{-3.6268604, -1.1752012, 0, 1.1752012}, 32 | }, 33 | { 34 | &Sinh{}, 35 | []float32{1, 3, 4, 5}, 36 | []int{1, 4}, 37 | []float32{1.1752012, 10.017875, 27.289917, 74.20321}, 38 | }, 39 | { 40 | &Sinh{}, 41 | []float32{-1, -1, -1, -1}, 42 | []int{1, 4}, 43 | []float32{-1.1752012, -1.1752012, -1.1752012, -1.1752012}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | res, err := test.sinh.Apply(inputs) 53 | assert.Nil(t, err) 54 | 55 | assert.Len(t, res, 1) 56 | assert.Equal(t, test.expected, res[0].Data()) 57 | } 58 | } 59 | 60 | func TestInputValidationSinh(t *testing.T) { 61 | tests := []struct { 62 | version int64 63 | inputs []tensor.Tensor 64 | err error 65 | }{ 66 | { 67 | 9, 68 | []tensor.Tensor{ 69 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 70 | }, 71 | nil, 72 | }, 73 | { 74 | 9, 75 | []tensor.Tensor{ 76 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 77 | }, 78 | nil, 79 | }, 80 | { 81 | 9, 82 | []tensor.Tensor{}, 83 | ops.ErrInvalidInputCount(0, sinh9BaseOpFixture()), 84 | }, 85 | { 86 | 9, 87 | []tensor.Tensor{ 88 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 89 | }, 90 | ops.ErrInvalidInputType(0, "int", sinh9BaseOpFixture()), 91 | }, 92 | } 93 | 94 | for _, test := range tests { 95 | sinh := sinhVersions[test.version]() 96 | validated, err := sinh.ValidateInputs(test.inputs) 97 | 98 | assert.Equal(t, test.err, err) 99 | 100 | if test.err == nil { 101 | assert.Equal(t, test.inputs, validated) 102 | } 103 | } 104 | } 105 | // sinh9BaseOpFixture returns the opset 9 base operator used to build the expected validation errors. 
106 | func sinh9BaseOpFixture() ops.BaseOperator { 107 | return ops.NewBaseOperator(9, 1, 1, sinhTypeConstraints, "sinh")
108 | } 109 | -------------------------------------------------------------------------------- /ops/sinh/versions.go: -------------------------------------------------------------------------------- 1 | package sinh 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var sinhVersions = ops.OperatorVersions{ 6 | 9: ops.NewOperatorConstructor(newSinh, 9, sinhTypeConstraints), 7 | } 8 | // GetVersions returns the operator constructors registered for opset version 9 of sinh. 9 | func GetVersions() ops.OperatorVersions { 10 | return sinhVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/slice/slice_1.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | // MinSlice1Attributes and MaxSlice1Attributes bound the number of attributes accepted by Slice1.Init. 9 | const ( 10 | MinSlice1Attributes = 2 11 | MaxSlice1Attributes = 3 12 | ) 13 | 14 | // Slice1 represents the opset 1 version of the ONNX slice operator. 15 | type Slice1 struct { 16 | ops.BaseOperator 17 | 18 | axes []int 19 | ends []int 20 | starts []int 21 | } 22 | 23 | // newSlice1 creates a new opset 1 slice operator. In opset 1 starts, ends and axes are attributes, so the operator takes exactly one data input of any type. 24 | func newSlice1() ops.Operator { 25 | return &Slice1{ 26 | BaseOperator: ops.NewBaseOperator( 27 | 1, 28 | 1, 29 | 1, 30 | [][]tensor.Dtype{ops.AllTypes}, 31 | "slice", 32 | ), 33 | } 34 | } 35 | 36 | // Init initializes the slice operator from its axes, ends and starts attributes; any other attribute is rejected.
37 | func (s *Slice1) Init(n *onnx.NodeProto) error { 38 | nAttrs := len(n.GetAttribute()) 39 | if nAttrs < MinSlice1Attributes || nAttrs > MaxSlice1Attributes { 40 | return ops.ErrInvalidOptionalAttributeCount(MinSlice1Attributes, MaxSlice1Attributes, nAttrs, s) 41 | } 42 | // Every attribute must be one of axes, ends or starts. 43 | for _, attr := range n.GetAttribute() { 44 | switch attr.GetName() { 45 | case "axes": 46 | axes, err := ops.AnyToIntSlice(attr.GetInts()) 47 | if err != nil { 48 | return err 49 | } 50 | 51 | s.axes = axes 52 | case "ends": 53 | ends, err := ops.AnyToIntSlice(attr.GetInts()) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | s.ends = ends 59 | case "starts": 60 | starts, err := ops.AnyToIntSlice(attr.GetInts()) 61 | if err != nil { 62 | return err 63 | } 64 | 65 | s.starts = starts 66 | default: 67 | return ops.ErrInvalidAttribute(attr.GetName(), s) 68 | } 69 | } 70 | // NOTE(review): Init does not check that starts and ends are both present and equally long. 71 | return nil 72 | } 73 | 74 | // Apply slices the data input using the starts, ends and (optional) axes attributes, with a step of 1 on every sliced axis. 75 | func (s *Slice1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 76 | data := inputs[0] 77 | // When axes is empty or absent, getDefaultAxes derives the axes from len(starts). 78 | axes := s.axes 79 | if len(s.axes) == 0 { 80 | axes = getDefaultAxes(len(s.starts)) 81 | } 82 | // Opset 1 Slice has no steps, so every sliced axis uses a step of 1. 83 | steps := make([]int, len(s.starts)) 84 | for i := range steps { 85 | steps[i] = 1 86 | } 87 | 88 | slices := constructSlices(s.starts, s.ends, steps, axes, len(data.Shape())) 89 | 90 | out, err := data.Slice(slices...)
91 | if err != nil { 92 | return nil, err 93 | } 94 | // data.Slice yields a view; it is materialized before being returned. 95 | return []tensor.Tensor{out.Materialize()}, nil 96 | } 97 | -------------------------------------------------------------------------------- /ops/slice/versions.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var sliceVersions = ops.OperatorVersions{ 6 | 1: newSlice1, 7 | 10: ops.NewOperatorConstructor(newSlice, 10, sliceTypeConstraints), 8 | 11: ops.NewOperatorConstructor(newSlice, 11, sliceTypeConstraints), 9 | 13: ops.NewOperatorConstructor(newSlice, 13, sliceTypeConstraints), 10 | } 11 | // GetVersions returns the operator constructors registered for opset versions 1, 10, 11 and 13 of slice. 12 | func GetVersions() ops.OperatorVersions { 13 | return sliceVersions 14 | } 15 | -------------------------------------------------------------------------------- /ops/slicer.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import "gorgonia.org/tensor" 4 | 5 | // Slicer implements the tensor.Slice interface. It holds the start, end and step used to slice a single dimension of a tensor. 6 | type Slicer struct { 7 | start int 8 | end int 9 | step int 10 | } 11 | 12 | // NewSlicer creates a new Slicer object. By default, end will be set to start + 1 and step 13 | // will be set to 1. If options are given, it is assumed that the first element will be the value 14 | // for the end index and the second element the value for the step of the Slicer. 15 | func NewSlicer(start int, options ...int) tensor.Slice { 16 | const maxOptionLength = 2 17 | 18 | end := start + 1 19 | step := 1 20 | 21 | if len(options) >= 1 { 22 | end = options[0] 23 | } 24 | 25 | if len(options) >= maxOptionLength { 26 | step = options[1] 27 | } 28 | 29 | return &Slicer{ 30 | start: start, 31 | end: end, 32 | step: step, 33 | } 34 | } 35 | 36 | // Start returns the start of the slice. 
37 | func (s *Slicer) Start() int { 38 | return s.start 39 | } 40 | 41 | // End returns the end of the slice.
42 | func (s *Slicer) End() int { 43 | return s.end 44 | } 45 | 46 | // Step returns the step of the slice. 47 | func (s *Slicer) Step() int { 48 | return s.step 49 | } 50 | -------------------------------------------------------------------------------- /ops/slicer_test.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | // TestNewSlicer checks the default end (start + 1) and step (1) of NewSlicer and the end/step options. 9 | func TestNewSlicer(t *testing.T) { 10 | tests := []struct { 11 | start int 12 | options []int 13 | expectedStart int 14 | expectedEnd int 15 | expectedStep int 16 | }{ 17 | {0, []int{}, 0, 1, 1}, 18 | {1, []int{}, 1, 2, 1}, 19 | {0, []int{2}, 0, 2, 1}, 20 | {0, []int{2, 2}, 0, 2, 2}, 21 | } 22 | 23 | for _, test := range tests { 24 | slicer := NewSlicer(test.start, test.options...) 25 | assert.Equal(t, test.expectedStart, slicer.Start()) 26 | assert.Equal(t, test.expectedEnd, slicer.End()) 27 | assert.Equal(t, test.expectedStep, slicer.Step()) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /ops/softmax/softmax.go: -------------------------------------------------------------------------------- 1 | package softmax 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var softmaxTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 10 | 11 | // Softmax represents the ONNX softmax operator. 12 | type Softmax struct { 13 | ops.BaseOperator 14 | 15 | // The axis along which to perform the Softmax operation. Negative values count from the back of the shape. 16 | axis int 17 | } 18 | 19 | // newSoftmax creates a new softmax operator that takes exactly one input; the axis defaults to -1.
20 | func newSoftmax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Softmax{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 1, 25 | 1, 26 | typeConstraints, 27 | "softmax", 28 | ), 29 | // NOTE(review): -1 is the opset 13 default. Opsets 1 and 11 default to axis 1 and coerce 30 | // the input to 2D; every registered version here uses the same single-axis behavior. 31 | axis: -1, 32 | } 33 | } 34 | 35 | // Init initializes the softmax operator from its optional "axis" attribute; any other attribute is rejected. 36 | func (s *Softmax) Init(n *onnx.NodeProto) error { 37 | attributes := n.GetAttribute() 38 | nAttributes := len(attributes) 39 | 40 | if nAttributes > 1 { 41 | return ops.ErrInvalidAttributeCount(1, nAttributes, s) 42 | } 43 | 44 | if nAttributes == 1 { 45 | attr := attributes[0] 46 | if attr.GetName() != "axis" { 47 | return ops.ErrInvalidAttribute(attr.GetName(), s) 48 | } 49 | 50 | s.axis = int(attr.GetI()) 51 | } 52 | 53 | return nil 54 | } 55 | 56 | // Apply applies the softmax operator along the configured axis. Negative axes count 57 | // from the back; an axis outside [-rank, rank) is rejected with ErrAxisOutOfRange. 58 | func (s *Softmax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 59 | input := inputs[0] 60 | nDims := len(input.Shape()) 61 | 62 | if s.axis < -nDims || s.axis >= nDims { 63 | return nil, ops.ErrAxisOutOfRange(-nDims, nDims, s.axis) 64 | } 65 | 66 | axis := s.axis 67 | if s.axis < 0 { 68 | axis += nDims 69 | } 70 | 71 | out, err := tensor.SoftMax(input, axis) 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | return []tensor.Tensor{out}, nil 77 | } 78 | -------------------------------------------------------------------------------- /ops/softmax/versions.go: -------------------------------------------------------------------------------- 1 | package softmax 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var softmaxVersions = ops.OperatorVersions{ 6 | 1: 
ops.NewOperatorConstructor(newSoftmax, 1, softmaxTypeConstraints), 7 | 11: ops.NewOperatorConstructor(newSoftmax, 11, softmaxTypeConstraints), 8 | 13: ops.NewOperatorConstructor(newSoftmax, 13, softmaxTypeConstraints), 9 | } 10 | // GetVersions returns the operator constructors registered for opset versions 1, 11 and 13 of softmax. 11 | func GetVersions() ops.OperatorVersions { 12 | return softmaxVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/sqrt/sqrt.go: -------------------------------------------------------------------------------- 1 |
package sqrt 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var sqrtTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 10 | 11 | // Sqrt represents the ONNX sqrt operator. 12 | type Sqrt struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newSqrt creates a new sqrt operator for the given opset version that takes exactly one input. 17 | func newSqrt(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 18 | return &Sqrt{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 1, 22 | 1, 23 | typeConstraints, 24 | "sqrt", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the sqrt operator. Sqrt has no attributes, so this is a no-op. 30 | func (s *Sqrt) Init(_ *onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the sqrt operator to the single input tensor using tensor.Sqrt. 35 | func (s *Sqrt) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | out, err := tensor.Sqrt(inputs[0]) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return []tensor.Tensor{out}, nil 42 | } 43 | -------------------------------------------------------------------------------- /ops/sqrt/sqrt_test.go: -------------------------------------------------------------------------------- 1 | package sqrt 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestSqrtInit(t *testing.T) { 12 | s := &Sqrt{} 13 | err := s.Init(nil) 14 | assert.Nil(t, err) 15 | } 16 | 17 | func TestSqrt(t *testing.T) { 18 | tests := []struct { 19 | version int64 20 | backing []float32 21 | shape []int 22 | expected []float32 23 | }{ 24 | { 25 | 13, 26 | []float32{1, 2, 3, 4}, 27 | []int{2, 2}, 28 | []float32{1, 1.4142135, 1.7320508, 2}, 29 | }, 30 | { 31 | 6, 32 | []float32{1, 3, 4, 5}, 33 | []int{1, 4}, 34 | []float32{1, 1.7320508, 2, 2.236068}, 35 | }, 36 | { 37 | 13, 38 | []float32{1, 1, 1, 1}, 39 | []int{1, 4}, 40 | []float32{1, 1, 1, 1}, 41 | }, 42 | } 43 | 44 | for
_, test := range tests { 45 | inputs := []tensor.Tensor{ 46 | ops.TensorWithBackingFixture(test.backing, test.shape...), 47 | } 48 | 49 | sqrt := sqrtVersions[test.version]() 50 | res, err := sqrt.Apply(inputs) 51 | assert.Nil(t, err) 52 | 53 | assert.Len(t, res, 1) 54 | assert.Equal(t, test.expected, res[0].Data()) 55 | } 56 | } 57 | 58 | func TestInputValidationSqrt(t *testing.T) { 59 | tests := []struct { 60 | version int64 61 | inputs []tensor.Tensor 62 | err error 63 | }{ 64 | { 65 | 13, 66 | []tensor.Tensor{ 67 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 68 | }, 69 | nil, 70 | }, 71 | { 72 | 13, 73 | []tensor.Tensor{ 74 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 75 | }, 76 | nil, 77 | }, 78 | { 79 | 13, 80 | []tensor.Tensor{}, 81 | ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")), 82 | }, 83 | { 84 | 13, 85 | []tensor.Tensor{ 86 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 87 | }, 88 | ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")), 89 | }, 90 | } 91 | 92 | for _, test := range tests { 93 | sqrt := sqrtVersions[test.version]() 94 | validated, err := sqrt.ValidateInputs(test.inputs) 95 | 96 | assert.Equal(t, test.err, err) 97 | 98 | if test.err == nil { 99 | assert.Equal(t, test.inputs, validated) 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /ops/sqrt/versions.go: -------------------------------------------------------------------------------- 1 | package sqrt 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var sqrtVersions = ops.OperatorVersions{ 6 | 6: ops.NewOperatorConstructor(newSqrt, 6, sqrtTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newSqrt, 13, sqrtTypeConstraints), 8 | } 9 | // GetVersions returns the operator constructors registered for opset versions 6 and 13 of sqrt. 10 | func GetVersions() ops.OperatorVersions { 11 | return sqrtVersions 12 | } 13 | 
-------------------------------------------------------------------------------- /ops/squeeze/squeeze_1.go:
-------------------------------------------------------------------------------- 1 | package squeeze 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Squeeze1 represents the ONNX squeeze operator. 10 | type Squeeze1 struct { 11 | ops.BaseOperator 12 | 13 | axes []int 14 | } 15 | 16 | // newSqueeze1 creates a new opset 1 squeeze operator that takes exactly one input of any type. 17 | func newSqueeze1() ops.Operator { 18 | return &Squeeze1{ 19 | BaseOperator: ops.NewBaseOperator( 20 | 1, 21 | 1, 22 | 1, 23 | [][]tensor.Dtype{ops.AllTypes}, 24 | "squeeze", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the squeeze operator from its optional axes attribute; any other attribute is rejected. 30 | func (s *Squeeze1) Init(n *onnx.NodeProto) error { 31 | for _, attr := range n.GetAttribute() { 32 | switch attr.GetName() { 33 | case "axes": 34 | axes, err := ops.AnyToIntSlice(attr.GetInts()) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | s.axes = axes 40 | default: 41 | return ops.ErrInvalidAttribute(attr.GetName(), s) 42 | } 43 | } 44 | 45 | return nil 46 | } 47 | 48 | // Apply removes the dimensions listed in the axes attribute or, when axes is empty, the dimensions selected by getDimsToSqueezeFromShape. Axes outside [0, rank-1] are rejected. 49 | func (s *Squeeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 50 | var err error 51 | 52 | currentShape := inputs[0].Shape() 53 | nDims := len(currentShape) 54 | dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) 55 | 56 | if len(s.axes) > 0 { 57 | if !ops.AllInRange(s.axes, 0, nDims-1) { // Squeeze-1 axes must be non-negative and below the rank. 58 | return nil, ops.ErrNotAllAxesInRange(nDims, nDims) 59 | } 60 | 61 | dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) 62 | } 63 | 64 | newShape := getNewShape(currentShape, dimsToSqueeze) 65 | 66 | out, ok := inputs[0].Clone().(tensor.Tensor) 67 | if !ok { 68 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 69 | } 70 | 71 | err = out.Reshape(newShape...)
72 | 73 | return []tensor.Tensor{out}, err 74 | } 75 | -------------------------------------------------------------------------------- /ops/squeeze/squeeze_11.go: -------------------------------------------------------------------------------- 1 | package squeeze 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | // Squeeze11 represents the ONNX squeeze operator. 10 | type Squeeze11 struct { 11 | ops.BaseOperator 12 | 13 | axes []int 14 | } 15 | 16 | // newSqueeze11 creates a new opset 11 squeeze operator that takes exactly one input of any type. 17 | func newSqueeze11() ops.Operator { 18 | return &Squeeze11{ 19 | BaseOperator: ops.NewBaseOperator( 20 | 11, 21 | 1, 22 | 1, 23 | [][]tensor.Dtype{ops.AllTypes}, 24 | "squeeze", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the squeeze operator from its optional axes attribute; any other attribute is rejected. 30 | func (s *Squeeze11) Init(n *onnx.NodeProto) error { 31 | for _, attr := range n.GetAttribute() { 32 | switch attr.GetName() { 33 | case "axes": 34 | axes, err := ops.AnyToIntSlice(attr.GetInts()) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | s.axes = axes 40 | default: 41 | return ops.ErrInvalidAttribute(attr.GetName(), s) 42 | } 43 | } 44 | 45 | return nil 46 | } 47 | 48 | // Apply removes the dimensions listed in the axes attribute (negative axes count from the back) or, when axes is empty, the dimensions selected by getDimsToSqueezeFromShape. 49 | func (s *Squeeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 50 | var err error 51 | 52 | currentShape := inputs[0].Shape() 53 | nDims := len(currentShape) 54 | dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) 55 | 56 | if len(s.axes) > 0 { 57 | // Validate the user-supplied axes attribute; the accepted range is [-r, r-1] 58 | // where r is the rank of the input. Negative entries are offset by the rank 59 | // in getDimsToSqueezeFromList, 60 | // i.e.
-1 -> nDims - 1, -nDims -> 0. 61 | if !ops.AllInRange(s.axes, -nDims, nDims-1) { 62 | return nil, ops.ErrNotAllAxesInRange(nDims, nDims) 63 | } 64 | 65 | dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) 66 | } 67 | 68 | newShape := getNewShape(currentShape, dimsToSqueeze) 69 | 70 | out, ok := inputs[0].Clone().(tensor.Tensor) 71 | if !ok { 72 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 73 | } 74 | 75 | err = out.Reshape(newShape...) 76 | 77 | return []tensor.Tensor{out}, err 78 | } 79 | 80 | // getDimsToSqueezeFromList creates a list with ints representing the dimensions/axes to squeeze 81 | // based on a list of ints. The list should contain dimensions/axes to squeeze. Negative dimensions 82 | // represent dimensions counting from the end of the shape, i.e. -2 represents the second 83 | // last dimension. 84 | func getDimsToSqueezeFromList(axes []int, nDims int) []int { 85 | dimsToSqueeze := make([]int, len(axes)) 86 | copy(dimsToSqueeze, axes) 87 | 88 | for i, val := range dimsToSqueeze { 89 | if val < 0 { 90 | dimsToSqueeze[i] = nDims + val 91 | } 92 | } 93 | 94 | return dimsToSqueeze 95 | } 96 | -------------------------------------------------------------------------------- /ops/squeeze/versions.go: -------------------------------------------------------------------------------- 1 | package squeeze 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var squeezeVersions = ops.OperatorVersions{ 6 | 1: newSqueeze1, 7 | 11: newSqueeze11, 8 | 13: ops.NewOperatorConstructor(newSqueeze, 13, squeezeTypeConstraints), 9 | } 10 | // GetVersions returns the operator constructors registered for opset versions 1, 11 and 13 of squeeze. 11 | func GetVersions() ops.OperatorVersions { 12 | return squeezeVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/sub/sub.go: -------------------------------------------------------------------------------- 1 | package sub 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | 
"github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9
| var subTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 11 | {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, 12 | } 13 | 14 | // Sub represents the ONNX sub operator. 15 | type Sub struct { 16 | ops.BaseOperator 17 | } 18 | 19 | // newSub creates a new sub operator for the given opset version that takes exactly two inputs. 20 | func newSub(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 21 | return &Sub{ 22 | BaseOperator: ops.NewBaseOperator( 23 | version, 24 | 2, 25 | 2, 26 | typeConstraints, 27 | "sub", 28 | ), 29 | } 30 | } 31 | 32 | // Init initializes the sub operator. Sub has no attributes, so this is a no-op. 33 | func (s *Sub) Init(*onnx.NodeProto) error { 34 | return nil 35 | } 36 | 37 | // Apply applies the sub operator (ops.Sub) to inputs[0] and inputs[1], broadcasting them multidirectionally. 38 | func (s *Sub) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 39 | return ops.ApplyBinaryOperation( 40 | inputs[0], 41 | inputs[1], 42 | ops.Sub, 43 | ops.MultidirectionalBroadcasting, 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /ops/sub/versions.go: -------------------------------------------------------------------------------- 1 | package sub 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var subVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newSub, 7, subTypeConstraints), 7 | 13: ops.NewOperatorConstructor(newSub, 13, subTypeConstraints), 8 | } 9 | // GetVersions returns the operator constructors registered for opset versions 7 and 13 of sub. 10 | func GetVersions() ops.OperatorVersions { 11 | return subVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/tan/tan.go: -------------------------------------------------------------------------------- 1 | package tan 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | 
"github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var tanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} 12 | 13 | //
Tan represents the ONNX tan operator. 14 | type Tan struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newTan creates a new tan operator for the given opset version that takes exactly one input. 19 | func newTan(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 20 | return &Tan{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraints, 26 | "tan", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the tan operator. Tan has no attributes, so this is a no-op. 32 | func (t *Tan) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the tan operator element-wise to a float32 or float64 input tensor. 37 | func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | var ( 39 | out tensor.Tensor 40 | err error 41 | ) 42 | 43 | switch inputs[0].Dtype() { 44 | case tensor.Float32: 45 | out, err = inputs[0].Apply(tan[float32]) 46 | case tensor.Float64: 47 | out, err = inputs[0].Apply(tan[float64]) 48 | default: 49 | return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t.BaseOperator) 50 | } 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return []tensor.Tensor{out}, nil 57 | } 58 | // tan computes the tangent of x in float64 precision and converts the result back to T. 59 | func tan[T ops.FloatType](x T) T { 60 | return T(math.Tan(float64(x))) 61 | } 62 | -------------------------------------------------------------------------------- /ops/tan/tan_test.go: -------------------------------------------------------------------------------- 1 | package tan 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestTanInit(t *testing.T) { 12 | a := &Tan{} 13 | 14 | // since 'tan' does not have any attributes we pass in nil. This should not 15 | // fail initializing the tan.
16 | err := a.Init(nil) 17 | assert.Nil(t, err) 18 | } 19 | 20 | func TestTan(t *testing.T) { 21 | tests := []struct { 22 | tan *Tan 23 | backing []float32 24 | shape []int 25 | expected []float32 26 | }{ 27 | { 28 | &Tan{}, 29 | []float32{1, 2, 3, 4}, 30 | []int{2, 2}, 31 | []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, 32 | }, 33 | { 34 | &Tan{}, 35 | []float32{1, 2, 3, 4}, 36 | []int{1, 4}, 37 | []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, 38 | }, 39 | { 40 | &Tan{}, 41 | []float32{2, 2, 2, 2}, 42 | []int{1, 4}, 43 | []float32{-2.1850398, -2.1850398, -2.1850398, -2.1850398}, 44 | }, 45 | } 46 | 47 | for _, test := range tests { 48 | inputs := []tensor.Tensor{ 49 | ops.TensorWithBackingFixture(test.backing, test.shape...), 50 | } 51 | 52 | res, err := test.tan.Apply(inputs) 53 | assert.Nil(t, err) 54 | 55 | assert.Len(t, res, 1) 56 | assert.Equal(t, test.expected, res[0].Data()) 57 | } 58 | } 59 | 60 | func TestInputValidationTan(t *testing.T) { 61 | tests := []struct { 62 | version int64 63 | inputs []tensor.Tensor 64 | err error 65 | }{ 66 | { 67 | 7, 68 | []tensor.Tensor{ 69 | ops.TensorWithBackingFixture([]float32{1, 2}, 2), 70 | }, 71 | nil, 72 | }, 73 | { 74 | 7, 75 | []tensor.Tensor{ 76 | ops.TensorWithBackingFixture([]float64{1, 2}, 2), 77 | }, 78 | nil, 79 | }, 80 | { 81 | 7, 82 | []tensor.Tensor{}, 83 | ops.ErrInvalidInputCount(0, tan7BaseOpFixture()), 84 | }, 85 | { 86 | 7, 87 | []tensor.Tensor{ 88 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 89 | }, 90 | ops.ErrInvalidInputType(0, "int", tan7BaseOpFixture()), 91 | }, 92 | } 93 | 94 | for _, test := range tests { 95 | tan := tanVersions[test.version]() 96 | validated, err := tan.ValidateInputs(test.inputs) 97 | 98 | assert.Equal(t, test.err, err) 99 | 100 | if test.err == nil { 101 | assert.Equal(t, test.inputs, validated) 102 | } 103 | } 104 | } 105 | // tan7BaseOpFixture returns the opset 7 base operator used to build the expected validation errors. 
106 | func tan7BaseOpFixture() ops.BaseOperator { 107 | return ops.NewBaseOperator(7, 1, 1, tanTypeConstraints, "tan") 108 | }
109 | -------------------------------------------------------------------------------- /ops/tan/versions.go: -------------------------------------------------------------------------------- 1 | package tan 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var tanVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newTan, 7, tanTypeConstraints), 7 | } 8 | // GetVersions returns the operator constructors registered for opset version 7 of tan. 9 | func GetVersions() ops.OperatorVersions { 10 | return tanVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/tanh/tanh.go: -------------------------------------------------------------------------------- 1 | package tanh 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var tanhTypeConstraint = [][]tensor.Dtype{ 10 | {tensor.Float32, tensor.Float64}, 11 | } 12 | 13 | // Tanh represents the ONNX tanh operator. 14 | type Tanh struct { 15 | ops.BaseOperator 16 | } 17 | 18 | // newTanh returns a new tanh operator for the given opset version that takes exactly one input. 19 | func newTanh(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 20 | return &Tanh{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraint, 26 | "tanh", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the tanh operator. Tanh has no attributes, so this is a no-op. 32 | func (t *Tanh) Init(*onnx.NodeProto) error { 33 | return nil 34 | } 35 | 36 | // Apply applies the tanh operator to the single input tensor via ops.Tanh.
37 | func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 38 | out, err := ops.Tanh(inputs[0]) 39 | 40 | return []tensor.Tensor{out}, err 41 | } 42 | -------------------------------------------------------------------------------- /ops/tanh/versions.go: -------------------------------------------------------------------------------- 1 | package tanh 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var tanhVersions = ops.OperatorVersions{ 6 | 6: ops.NewOperatorConstructor(newTanh, 6, tanhTypeConstraint), 7 | 13: ops.NewOperatorConstructor(newTanh, 13, tanhTypeConstraint), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return tanhVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/transpose/transpose.go: -------------------------------------------------------------------------------- 1 | package transpose 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var transposeTypeConstraint = [][]tensor.Dtype{ops.AllTypes} 10 | 11 | // Transpose represents the ONNX transpose operator. 12 | type Transpose struct { 13 | ops.BaseOperator 14 | 15 | perm []int 16 | } 17 | 18 | // newTranspose creates a new transpose operator. 19 | func newTranspose(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 20 | return &Transpose{ 21 | BaseOperator: ops.NewBaseOperator( 22 | version, 23 | 1, 24 | 1, 25 | typeConstraint, 26 | "transpose", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the transpose operator. 
32 | func (t *Transpose) Init(n *onnx.NodeProto) error { 33 | attributes := n.GetAttribute() 34 | 35 | if len(attributes) == 1 { 36 | attr := attributes[0] 37 | 38 | if attr.GetName() != "perm" { 39 | return ops.ErrInvalidAttribute(attr.GetName(), t) 40 | } 41 | 42 | attrPerm := attr.GetInts() 43 | 44 | perm := make([]int, 0) 45 | for _, val := range attrPerm { 46 | perm = append(perm, int(val)) 47 | } 48 | 49 | t.perm = perm 50 | } 51 | 52 | return nil 53 | } 54 | 55 | // Apply applies the transpose operator. 56 | func (t *Transpose) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 57 | out, err := tensor.Transpose(inputs[0], t.perm...) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | return []tensor.Tensor{out}, nil 63 | } 64 | -------------------------------------------------------------------------------- /ops/transpose/versions.go: -------------------------------------------------------------------------------- 1 | package transpose 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var transposeVersions = ops.OperatorVersions{ 6 | 1: ops.NewOperatorConstructor(newTranspose, 1, transposeTypeConstraint), 7 | 13: ops.NewOperatorConstructor(newTranspose, 13, transposeTypeConstraint), 8 | } 9 | 10 | func GetVersions() ops.OperatorVersions { 11 | return transposeVersions 12 | } 13 | -------------------------------------------------------------------------------- /ops/types.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import "gorgonia.org/tensor" 4 | 5 | // FloatType is a type that describes a float value. Can be either float32 or float64. 6 | type FloatType interface { 7 | float32 | float64 8 | } 9 | 10 | type IntType interface { 11 | uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64 12 | } 13 | 14 | type NumericType interface { 15 | IntType | FloatType 16 | } 17 | 18 | // AllTypes is a type constraint which allows all types. 
19 | var AllTypes = []tensor.Dtype{ 20 | tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, 21 | tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, 22 | tensor.Float32, tensor.Float64, 23 | tensor.Complex64, tensor.Complex128, 24 | tensor.String, 25 | tensor.Bool, 26 | } 27 | 28 | // IntTypes is a list with all integer types. 29 | var IntTypes = []tensor.Dtype{ 30 | tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, 31 | tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, 32 | } 33 | 34 | // NumericTypes is a list with all numeric types. 35 | var NumericTypes = []tensor.Dtype{ 36 | tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, 37 | tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, 38 | tensor.Float32, tensor.Float64, 39 | } 40 | -------------------------------------------------------------------------------- /ops/unidir_broadcast.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "gorgonia.org/tensor" 5 | ) 6 | 7 | type BroadcastType int 8 | 9 | const ( 10 | NoBroadcasting BroadcastType = 0 11 | UnidirectionalBroadcasting BroadcastType = 1 12 | MultidirectionalBroadcasting BroadcastType = 2 13 | ) 14 | 15 | // UnidirectionalBroadcast tries to broadcast tensor B to tensor A according to the ONNX standards. 16 | func UnidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { 17 | reshapedB, err := reshapeTensorsForUnidirBroadcast(A, B) 18 | if err != nil { 19 | return nil, nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) 20 | } 21 | 22 | newB, err := repeatTensorsForUnidirBroadcast(A, reshapedB) 23 | if err != nil { 24 | return nil, nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) 25 | } 26 | 27 | return A, newB, nil 28 | } 29 | 30 | // reshapeTensorsForUnidirBroadcast reshapes the B tensor to match the number of dimensions 31 | // of the A tensor. New dimensions of size 1 are added to the front. 
32 | // Example: shapeA=(2, 3, 4) and shapeB=(3, 4) yields shapeNewB=(1, 3, 4). 33 | func reshapeTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) { 34 | nDimsA := len(A.Shape()) 35 | nDimsB := len(B.Shape()) 36 | 37 | switch { 38 | case nDimsA > nDimsB: 39 | return AddExtraDimsToTensor(B, nDimsA-nDimsB) 40 | case nDimsA == nDimsB: 41 | return B, nil 42 | default: 43 | return nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) 44 | } 45 | } 46 | 47 | // repeatTensorsForUnidirBroadcast broadcasts tensor B such that it corresponds with the 48 | // shape of tensor A. Assumes the B tensor has already been reshaped such that it has 49 | // the same number of dimensions as tensor A. 50 | // Example: shapeA=(2, 3, 4) and shapeB=(1, 3, 4) yields shapeNewB=(2, 3, 4). 51 | func repeatTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) { 52 | var err error 53 | 54 | shapeA := A.Shape() 55 | shapeB := B.Shape() 56 | 57 | // Repeatedly repeat the B tensor along dimensions of size 1. 
58 | for axis := len(shapeA) - 1; axis >= 0; axis-- { 59 | sizeDimA := shapeA[axis] 60 | sizeDimB := shapeB[axis] 61 | 62 | if sizeDimA != sizeDimB { 63 | if sizeDimB != 1 { 64 | return nil, ErrUnidirBroadcast(shapeA, shapeB) 65 | } 66 | 67 | B, err = tensor.Repeat(B, axis, sizeDimA) 68 | if err != nil { 69 | return nil, err 70 | } 71 | } 72 | } 73 | 74 | return B, nil 75 | } 76 | -------------------------------------------------------------------------------- /ops/unidir_broadcast_test.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | func TestUnidirectionalBroadcast(t *testing.T) { 11 | tests := []struct { 12 | shapes [][]int 13 | expectedShape tensor.Shape 14 | err error 15 | }{ 16 | { 17 | [][]int{{2, 3, 4, 5}, {5}}, 18 | []int{2, 3, 4, 5}, 19 | nil, 20 | }, 21 | { 22 | [][]int{{2, 3, 4, 5}, {2, 1, 1, 5}}, 23 | []int{2, 3, 4, 5}, 24 | nil, 25 | }, 26 | { 27 | [][]int{{2, 3, 4, 5}, {1, 3, 1, 5}}, 28 | []int{2, 3, 4, 5}, 29 | nil, 30 | }, 31 | { 32 | [][]int{{1, 3, 1}, {3, 1}}, 33 | []int{1, 3, 1}, 34 | nil, 35 | }, 36 | { 37 | [][]int{{1, 3, 1}, {3, 1}}, 38 | []int{1, 3, 1}, 39 | nil, 40 | }, 41 | { 42 | [][]int{{1, 3, 1}, {3, 2}}, 43 | nil, 44 | ErrUnidirBroadcast([]int{1, 3, 1}, []int{3, 2}), 45 | }, 46 | { 47 | [][]int{{5}, {2, 3, 4}}, 48 | nil, 49 | ErrUnidirBroadcast([]int{5}, []int{2, 3, 4}), 50 | }, 51 | { 52 | [][]int{{1, 4, 5}, {1, 1, 3}}, 53 | nil, 54 | ErrUnidirBroadcast([]int{1, 4, 5}, []int{1, 1, 3}), 55 | }, 56 | } 57 | 58 | for _, test := range tests { 59 | A := Float32TensorFixture(test.shapes[0]...) 60 | B := Float32TensorFixture(test.shapes[1]...) 
61 | 62 | newA, newB, err := UnidirectionalBroadcast(A, B) 63 | 64 | assert.Equal(t, test.err, err) 65 | 66 | if err == nil { 67 | assert.Equal(t, test.expectedShape, newA.Shape()) 68 | assert.Equal(t, test.expectedShape, newB.Shape()) 69 | } else { 70 | assert.Nil(t, newA) 71 | assert.Nil(t, newB) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /ops/unsqueeze/unsqueeze.go: -------------------------------------------------------------------------------- 1 | package unsqueeze 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | var unsqueezeTypeConstraints = [][]tensor.Dtype{ 12 | ops.AllTypes, 13 | {tensor.Int64}, 14 | } 15 | 16 | const ( 17 | MinUnsqueezeInputs = 2 18 | MaxUnsqueezeInputs = 2 19 | ) 20 | 21 | // Unsqueeze represents the ONNX unsqueeze operator. 22 | type Unsqueeze struct { 23 | ops.BaseOperator 24 | } 25 | 26 | // newUnsqueeze creates a new unsqueeze operator. 27 | func newUnsqueeze(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 28 | return &Unsqueeze{ 29 | BaseOperator: ops.NewBaseOperator( 30 | version, 31 | MinUnsqueezeInputs, 32 | MaxUnsqueezeInputs, 33 | typeConstraint, 34 | "unsqueeze", 35 | ), 36 | } 37 | } 38 | 39 | // Init initializes the unsqueeze operator. 40 | func (u *Unsqueeze) Init(*onnx.NodeProto) error { 41 | return nil 42 | } 43 | 44 | // Apply applies the unsqueeze operator. 
45 | func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 46 | dataShape := inputs[0].Shape() 47 | 48 | axes, err := ops.AnyToIntSlice(inputs[1].Data()) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | outputRank := len(dataShape) + len(axes) 54 | 55 | if !ops.AllInRange(axes, -outputRank, outputRank-1) { 56 | return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) 57 | } 58 | 59 | // negative entries should be offset by the rank of the output tensor 60 | // i.e. -1 -> outputRank - 1, -outputrank -> 0 61 | ops.OffsetArrayIfNegative(axes, outputRank) 62 | 63 | sort.Ints(axes) 64 | 65 | if ops.HasDuplicates(axes) { 66 | return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) 67 | } 68 | 69 | newShape := insertOnes(dataShape, axes) 70 | 71 | out, ok := inputs[0].Clone().(tensor.Tensor) 72 | if !ok { 73 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 74 | } 75 | 76 | err = out.Reshape(newShape...) 77 | 78 | return []tensor.Tensor{out}, err 79 | } 80 | 81 | // Creates a new array, which is `original` with ones added at the indices specified by `indices` 82 | // `indices` may not contain duplicates, the elements are assumed to be in the range 0 <= x < N 83 | // and should be sorted in increasing order. 84 | // Is done in a single pass through the new array with length: len(original) + len(indices). 
85 | func insertOnes(original, indices []int) []int { 86 | N := len(indices) + len(original) 87 | 88 | // Pre-allocate the output shape 89 | newShape := make([]int, N) 90 | 91 | originalIdx := 0 92 | indicesIdx := 0 93 | 94 | for i := 0; i < N; i++ { 95 | if indicesIdx < len(indices) && indices[indicesIdx] == i { 96 | newShape[i] = 1 97 | indicesIdx++ 98 | } else { 99 | newShape[i] = original[originalIdx] 100 | originalIdx++ 101 | } 102 | } 103 | 104 | return newShape 105 | } 106 | -------------------------------------------------------------------------------- /ops/unsqueeze/unsqueeze_1.go: -------------------------------------------------------------------------------- 1 | package unsqueeze 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | // Unsqueeze1 represents version 1 of the ONNX unsqueeze operator. 12 | type Unsqueeze1 struct { 13 | ops.BaseOperator 14 | 15 | axes []int 16 | } 17 | 18 | // newUnsqueeze1 creates a new unsqueeze operator. 19 | func newUnsqueeze1() ops.Operator { 20 | return &Unsqueeze1{ 21 | BaseOperator: ops.NewBaseOperator( 22 | 1, 23 | 1, 24 | 1, 25 | [][]tensor.Dtype{ops.AllTypes}, 26 | "unsqueeze", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the unsqueeze operator. 32 | func (u *Unsqueeze1) Init(n *onnx.NodeProto) error { 33 | attrs := n.GetAttribute() 34 | if len(attrs) != 1 { 35 | return ops.ErrInvalidAttributeCount(1, len(attrs), u) 36 | } 37 | 38 | axes, err := ops.AnyToIntSlice(attrs[0].GetInts()) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | u.axes = axes 44 | 45 | return nil 46 | } 47 | 48 | // Apply applies the unsqueeze operator. 
49 | func (u *Unsqueeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 50 | dataShape := inputs[0].Shape() 51 | 52 | outputRank := len(dataShape) + len(u.axes) 53 | 54 | if !ops.AllInRange(u.axes, 0, outputRank-1) { 55 | return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) 56 | } 57 | 58 | sort.Ints(u.axes) 59 | 60 | if ops.HasDuplicates(u.axes) { 61 | return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) 62 | } 63 | 64 | newShape := insertOnes(dataShape, u.axes) 65 | 66 | out, ok := inputs[0].Clone().(tensor.Tensor) 67 | if !ok { 68 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 69 | } 70 | 71 | err := out.Reshape(newShape...) 72 | 73 | return []tensor.Tensor{out}, err 74 | } 75 | -------------------------------------------------------------------------------- /ops/unsqueeze/unsqueeze_11.go: -------------------------------------------------------------------------------- 1 | package unsqueeze 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/advancedclimatesystems/gonnx/onnx" 7 | "github.com/advancedclimatesystems/gonnx/ops" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | // Unsqueeze11 represents version 11 of the ONNX unsqueeze operator. 12 | type Unsqueeze11 struct { 13 | ops.BaseOperator 14 | 15 | axes []int 16 | } 17 | 18 | // newUnsqueeze11 creates a new unsqueeze operator. 19 | func newUnsqueeze11() ops.Operator { 20 | return &Unsqueeze11{ 21 | BaseOperator: ops.NewBaseOperator( 22 | 11, 23 | 1, 24 | 1, 25 | [][]tensor.Dtype{ops.AllTypes}, 26 | "unsqueeze", 27 | ), 28 | } 29 | } 30 | 31 | // Init initializes the unsqueeze operator. 
32 | func (u *Unsqueeze11) Init(n *onnx.NodeProto) error { 33 | attrs := n.GetAttribute() 34 | if len(attrs) != 1 { 35 | return ops.ErrInvalidAttributeCount(1, len(attrs), u) 36 | } 37 | 38 | axes, err := ops.AnyToIntSlice(attrs[0].GetInts()) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | u.axes = axes 44 | 45 | return nil 46 | } 47 | 48 | // Apply applies the unsqueeze operator. 49 | func (u *Unsqueeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 50 | dataShape := inputs[0].Shape() 51 | 52 | outputRank := len(dataShape) + len(u.axes) 53 | 54 | if !ops.AllInRange(u.axes, -outputRank, outputRank-1) { 55 | return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) 56 | } 57 | 58 | // negative entries should be offset by the rank of the output tensor 59 | // i.e. -1 -> outputRank - 1, -outputrank -> 0 60 | ops.OffsetArrayIfNegative(u.axes, outputRank) 61 | 62 | sort.Ints(u.axes) 63 | 64 | if ops.HasDuplicates(u.axes) { 65 | return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) 66 | } 67 | 68 | newShape := insertOnes(dataShape, u.axes) 69 | 70 | out, ok := inputs[0].Clone().(tensor.Tensor) 71 | if !ok { 72 | return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) 73 | } 74 | 75 | err := out.Reshape(newShape...) 
76 | 77 | return []tensor.Tensor{out}, err 78 | } 79 | -------------------------------------------------------------------------------- /ops/unsqueeze/versions.go: -------------------------------------------------------------------------------- 1 | package unsqueeze 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var unsqueezeVersions = ops.OperatorVersions{ 6 | 1: newUnsqueeze1, 7 | 11: newUnsqueeze11, 8 | 13: ops.NewOperatorConstructor(newUnsqueeze, 13, unsqueezeTypeConstraints), 9 | } 10 | 11 | func GetVersions() ops.OperatorVersions { 12 | return unsqueezeVersions 13 | } 14 | -------------------------------------------------------------------------------- /ops/validate_inputs.go: -------------------------------------------------------------------------------- 1 | package ops 2 | 3 | import ( 4 | "gorgonia.org/tensor" 5 | ) 6 | 7 | // ValidateInputs validates if a list of nodes has enough (not too few or too many) nodes. 8 | // When there are fewer input nodes then the given max, the list is padded with nils. 9 | // Expects either 1 requirement ==> the expected number of inputs, or 2 requirements, 10 | // the minimum and the maximum number of inputs. 
11 | func ValidateInputs(op BaseOperator, inputs []tensor.Tensor) ([]tensor.Tensor, error) { 12 | padLength, err := checkNInputs(op, inputs) 13 | if err != nil { 14 | return inputs, err 15 | } 16 | 17 | inputs = padInputs(inputs, padLength) 18 | 19 | err = checkInputTypes(op, inputs) 20 | if err != nil { 21 | return inputs, err 22 | } 23 | 24 | return inputs, nil 25 | } 26 | 27 | func checkNInputs(op BaseOperator, inputs []tensor.Tensor) (int, error) { 28 | nInputs := len(inputs) 29 | padLength := 0 30 | 31 | minInputs := op.GetMinInputs() 32 | maxInputs := op.GetMaxInputs() 33 | 34 | if minInputs == maxInputs { 35 | if nInputs != minInputs { 36 | return 0, ErrInvalidInputCount(nInputs, op) 37 | } 38 | 39 | padLength = minInputs 40 | } else { 41 | if nInputs < minInputs || nInputs > maxInputs { 42 | return 0, ErrInvalidOptionalInputCount(nInputs, op) 43 | } 44 | 45 | padLength = maxInputs 46 | } 47 | 48 | return padLength, nil 49 | } 50 | 51 | // padInputs pads a list of input nodes to the given length with nils. 52 | func padInputs(inputs []tensor.Tensor, length int) []tensor.Tensor { 53 | for len(inputs) < length { 54 | inputs = append(inputs, nil) 55 | } 56 | 57 | return inputs 58 | } 59 | 60 | func checkInputTypes(op BaseOperator, inputs []tensor.Tensor) error { 61 | typeConstraints := op.GetInputTypeConstraints() 62 | 63 | for i, input := range inputs { 64 | // Optional inputs can be nil, we can not check for type constraints then. 65 | if input == nil { 66 | continue 67 | } 68 | 69 | typeConstraint := newTypeConstraint(typeConstraints[i]) 70 | 71 | if _, ok := typeConstraint[input.Dtype()]; !ok { 72 | return ErrInvalidInputType(i, input.Dtype().Name(), op) 73 | } 74 | } 75 | 76 | return nil 77 | } 78 | 79 | // newTypeConstraint creates a map with for every type whether or not they are allowed. 
80 | func newTypeConstraint(allowedTypes []tensor.Dtype) map[tensor.Dtype]bool { 81 | typeConstraint := make(map[tensor.Dtype]bool) 82 | 83 | for _, allowedType := range allowedTypes { 84 | typeConstraint[allowedType] = true 85 | } 86 | 87 | return typeConstraint 88 | } 89 | -------------------------------------------------------------------------------- /ops/where/versions.go: -------------------------------------------------------------------------------- 1 | package where 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var whereVersions = ops.OperatorVersions{ 6 | 9: ops.NewOperatorConstructor(newWhere, 9, whereTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return whereVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/where/where.go: -------------------------------------------------------------------------------- 1 | package where 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var whereTypeConstraints = [][]tensor.Dtype{ 10 | {tensor.Bool}, 11 | ops.AllTypes, 12 | ops.AllTypes, 13 | } 14 | 15 | // Where represents the ONNX where operator. 16 | type Where struct { 17 | ops.BaseOperator 18 | } 19 | 20 | // newWhere creates a new where operator. 21 | func newWhere(version int, typeConstraints [][]tensor.Dtype) ops.Operator { 22 | return &Where{ 23 | BaseOperator: ops.NewBaseOperator( 24 | version, 25 | 3, 26 | 3, 27 | typeConstraints, 28 | "where", 29 | ), 30 | } 31 | } 32 | 33 | // Init initializes the where operator. 34 | func (w *Where) Init(*onnx.NodeProto) error { 35 | return nil 36 | } 37 | 38 | // Apply applies the where operator. 
39 | func (w *Where) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 40 | condition := inputs[0] 41 | 42 | X := inputs[1] 43 | Y := inputs[2] 44 | 45 | X, Y, err := ops.MultidirectionalBroadcast(X, Y) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | condition, X, err = ops.MultidirectionalBroadcast(condition, X) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | out, err := where(X, Y, condition) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | return []tensor.Tensor{out}, err 61 | } 62 | 63 | func where(X, Y, condition tensor.Tensor) (tensor.Tensor, error) { 64 | out := tensor.New(tensor.Of(X.Dtype()), tensor.WithShape(X.Shape()...)) 65 | 66 | iterator := condition.Iterator() 67 | iterator.Reset() 68 | 69 | for !iterator.Done() { 70 | coords := iterator.Coord() 71 | 72 | conditionRaw, err := condition.At(coords...) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | conditionValue, ok := conditionRaw.(bool) 78 | if !ok { 79 | return nil, ops.ErrCast 80 | } 81 | 82 | var value any 83 | if conditionValue { 84 | value, err = X.At(coords...) 85 | } else { 86 | value, err = Y.At(coords...) 87 | } 88 | 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | err = out.SetAt(value, coords...) 
94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | _, err = iterator.Next() 99 | if err != nil { 100 | return nil, err 101 | } 102 | } 103 | 104 | return out, nil 105 | } 106 | -------------------------------------------------------------------------------- /ops/where/where_test.go: -------------------------------------------------------------------------------- 1 | package where 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorgonia.org/tensor" 8 | ) 9 | 10 | func TestWhereInit(t *testing.T) { 11 | op := whereVersions[9]() 12 | err := op.Init(nil) 13 | assert.Nil(t, err) 14 | } 15 | 16 | func TestWhere(t *testing.T) { 17 | tests := []struct { 18 | version int64 19 | condition []bool 20 | conditionShape []int 21 | backing1 []float32 22 | backing1Shape []int 23 | backing2 []float32 24 | backing2Shape []int 25 | expectedBacking []float32 26 | }{ 27 | { 28 | 9, 29 | []bool{true, false, true}, 30 | []int{3}, 31 | []float32{1, 2, 3}, 32 | []int{3}, 33 | []float32{4, 5, 6}, 34 | []int{3}, 35 | []float32{1, 5, 3}, 36 | }, 37 | { 38 | 9, 39 | []bool{true, false, true, false}, 40 | []int{2, 2}, 41 | []float32{1, 2, 3, 4}, 42 | []int{2, 2}, 43 | []float32{4, 5}, 44 | []int{1, 2}, 45 | []float32{1, 5, 3, 5}, 46 | }, 47 | { 48 | 9, 49 | []bool{false, true}, 50 | []int{2}, 51 | []float32{1, 2, 3, 4}, 52 | []int{2, 2}, 53 | []float32{4, 5}, 54 | []int{1, 2}, 55 | []float32{4, 2, 4, 4}, 56 | }, 57 | { 58 | 9, 59 | []bool{false, false, false, true, true, true}, 60 | []int{2, 3}, 61 | []float32{1, 2, 3, 4, 5, 6}, 62 | []int{2, 3}, 63 | []float32{4, 5, 6}, 64 | []int{3}, 65 | []float32{4, 5, 6, 4, 5, 6}, 66 | }, 67 | { 68 | 9, 69 | []bool{false, true, true, false, false, true}, 70 | []int{2, 3}, 71 | []float32{1, 2, 3, 4, 5, 6}, 72 | []int{2, 3}, 73 | []float32{4, 5, 6}, 74 | []int{3}, 75 | []float32{4, 2, 3, 4, 5, 6}, 76 | }, 77 | } 78 | 79 | for _, test := range tests { 80 | inputs := []tensor.Tensor{ 81 | 
tensor.New(tensor.WithShape(test.conditionShape...), tensor.WithBacking(test.condition)), 82 | tensor.New(tensor.WithShape(test.backing1Shape...), tensor.WithBacking(test.backing1)), 83 | tensor.New(tensor.WithShape(test.backing2Shape...), tensor.WithBacking(test.backing2)), 84 | } 85 | 86 | op := whereVersions[test.version]() 87 | 88 | res, err := op.Apply(inputs) 89 | assert.Nil(t, err) 90 | assert.Equal(t, test.expectedBacking, res[0].Data()) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /ops/xor/versions.go: -------------------------------------------------------------------------------- 1 | package xor 2 | 3 | import "github.com/advancedclimatesystems/gonnx/ops" 4 | 5 | var xorVersions = ops.OperatorVersions{ 6 | 7: ops.NewOperatorConstructor(newXor, 7, xorTypeConstraints), 7 | } 8 | 9 | func GetVersions() ops.OperatorVersions { 10 | return xorVersions 11 | } 12 | -------------------------------------------------------------------------------- /ops/xor/xor.go: -------------------------------------------------------------------------------- 1 | package xor 2 | 3 | import ( 4 | "github.com/advancedclimatesystems/gonnx/onnx" 5 | "github.com/advancedclimatesystems/gonnx/ops" 6 | "gorgonia.org/tensor" 7 | ) 8 | 9 | var xorTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} 10 | 11 | // Xor represents the ONNX xor operator. 12 | type Xor struct { 13 | ops.BaseOperator 14 | } 15 | 16 | // newXor creates a new xor operator. 17 | func newXor(version int, typeConstraint [][]tensor.Dtype) ops.Operator { 18 | return &Xor{ 19 | BaseOperator: ops.NewBaseOperator( 20 | version, 21 | 2, 22 | 2, 23 | typeConstraint, 24 | "xor", 25 | ), 26 | } 27 | } 28 | 29 | // Init initializes the xor operator. 30 | func (x *Xor) Init(*onnx.NodeProto) error { 31 | return nil 32 | } 33 | 34 | // Apply applies the xor operator. 
35 | func (x *Xor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { 36 | return ops.ApplyBinaryOperation( 37 | inputs[0], 38 | inputs[1], 39 | ops.Xor, 40 | ops.MultidirectionalBroadcasting, 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ops/xor/xor_test.go: -------------------------------------------------------------------------------- 1 | package xor 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | "gorgonia.org/tensor" 9 | ) 10 | 11 | func TestXorInit(t *testing.T) { 12 | x := &Xor{} 13 | 14 | err := x.Init(nil) 15 | assert.Nil(t, err) 16 | } 17 | 18 | func TestXor(t *testing.T) { 19 | tests := []struct { 20 | xor *Xor 21 | backings [][]bool 22 | shapes [][]int 23 | expected []bool 24 | }{ 25 | { 26 | &Xor{}, 27 | [][]bool{{true, false, true, false}, {true, true, true, false}}, 28 | [][]int{{2, 2}, {2, 2}}, 29 | []bool{false, true, false, false}, 30 | }, 31 | { 32 | &Xor{}, 33 | [][]bool{{true, false, true, false}, {true, false}}, 34 | [][]int{{2, 2}, {1, 2}}, 35 | []bool{false, false, false, false}, 36 | }, 37 | { 38 | &Xor{}, 39 | [][]bool{{true, false, true, false}, {true, false}}, 40 | [][]int{{2, 2}, {2, 1}}, 41 | []bool{false, true, true, false}, 42 | }, 43 | { 44 | &Xor{}, 45 | [][]bool{{true, false, true, false, true, false}, {false, false}}, 46 | [][]int{{3, 2}, {1, 2}}, 47 | []bool{true, false, true, false, true, false}, 48 | }, 49 | } 50 | 51 | for _, test := range tests { 52 | inputs := []tensor.Tensor{ 53 | ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), 54 | ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), 55 | } 56 | 57 | res, err := test.xor.Apply(inputs) 58 | assert.Nil(t, err) 59 | 60 | assert.Nil(t, err) 61 | assert.Equal(t, test.expected, res[0].Data()) 62 | } 63 | } 64 | 65 | func TestInputValidationXor(t *testing.T) { 66 | tests := []struct { 67 | inputs 
[]tensor.Tensor 68 | err error 69 | version int64 70 | }{ 71 | { 72 | []tensor.Tensor{ 73 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 74 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 75 | }, 76 | nil, 77 | 7, 78 | }, 79 | { 80 | []tensor.Tensor{ 81 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 82 | }, 83 | ops.ErrInvalidInputCount(1, ops.NewBaseOperator(7, 2, 2, xorTypeConstraints, "xor")), 84 | 7, 85 | }, 86 | { 87 | []tensor.Tensor{ 88 | ops.TensorWithBackingFixture([]bool{false, false}, 2), 89 | ops.TensorWithBackingFixture([]int{1, 2}, 2), 90 | }, 91 | ops.ErrInvalidInputType(1, "int", ops.NewBaseOperator(7, 2, 2, xorTypeConstraints, "xor")), 92 | 7, 93 | }, 94 | } 95 | 96 | for _, test := range tests { 97 | xor := xorVersions[test.version]() 98 | validated, err := xor.ValidateInputs(test.inputs) 99 | 100 | assert.Equal(t, test.err, err) 101 | 102 | if test.err == nil { 103 | assert.Equal(t, test.inputs, validated) 104 | } 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /opset_test.go: -------------------------------------------------------------------------------- 1 | package gonnx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/advancedclimatesystems/gonnx/ops" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestResolveOpset(t *testing.T) { 11 | _, err := ResolveOpset(13) 12 | assert.Nil(t, err) 13 | } 14 | 15 | func TestResolveOpsetNotSupported(t *testing.T) { 16 | opset, err := ResolveOpset(6) 17 | assert.Nil(t, opset) 18 | assert.Equal(t, ops.ErrUnsupportedOpsetVersion, err) 19 | } 20 | -------------------------------------------------------------------------------- /sample_models/generate_sample_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from python_models import ( 4 | generate_mlp_onnx_from_torch, 5 | generate_gru_onnx_from_torch, 6 | generate_scaler_onnx_from_sklearn, 7 | 
test_mlp_torch, 8 | test_gru_torch, 9 | test_scaler_sklearn, 10 | ) 11 | 12 | 13 | def main(args): 14 | if args.action == "generate": 15 | generate_mlp_onnx_from_torch() 16 | generate_gru_onnx_from_torch() 17 | generate_scaler_onnx_from_sklearn() 18 | elif args.action == "test": 19 | test_mlp_torch() 20 | test_gru_torch() 21 | test_scaler_sklearn() 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--action", 28 | type=str, 29 | choices=["generate", "test"], 30 | required=True, 31 | help="Whether to generate sample models or test sample models (generate output)", 32 | ) 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /sample_models/onnx_models/gru.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/sample_models/onnx_models/gru.onnx -------------------------------------------------------------------------------- /sample_models/onnx_models/mlp.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/sample_models/onnx_models/mlp.onnx -------------------------------------------------------------------------------- /sample_models/onnx_models/mnist-8-opset13.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d6267e75ad19e51ad643554f861f21fc76bcb54b625074a845ccf329c465bad6 3 | size 26454 4 | -------------------------------------------------------------------------------- /sample_models/onnx_models/ndm.onnx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/sample_models/onnx_models/ndm.onnx -------------------------------------------------------------------------------- /sample_models/onnx_models/nt_1.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/sample_models/onnx_models/nt_1.zip -------------------------------------------------------------------------------- /sample_models/onnx_models/scaler.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdvancedClimateSystems/gonnx/c879ba407e657994a925ae985a06521cda34739d/sample_models/onnx_models/scaler.onnx -------------------------------------------------------------------------------- /sample_models/python_models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8:noqa 2 | 3 | from python_models.gru_torch import generate_gru_onnx_from_torch, test_gru_torch 4 | from python_models.mlp_torch import generate_mlp_onnx_from_torch, test_mlp_torch 5 | from python_models.scaler_sklearn import generate_scaler_onnx_from_sklearn, test_scaler_sklearn 6 | -------------------------------------------------------------------------------- /sample_models/python_models/gru_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | INPUT_SIZE = 3 6 | HIDDEN_SIZE = 5 7 | 8 | 9 | class GRU(nn.Module): 10 | """ 11 | Simple GRU model with only a single GRU unit. 
Structure: 12 | 13 | GRU( 14 | (gru): GRU(3, 5, batch_first=True) 15 | ) 16 | """ 17 | 18 | def __init__(self): 19 | super().__init__() 20 | 21 | self.gru = nn.GRU(INPUT_SIZE, HIDDEN_SIZE, batch_first=True) 22 | 23 | def forward(self, inputs, hidden): 24 | x, new_hidden = self.gru(inputs, hidden) 25 | return x, new_hidden 26 | 27 | def get_init_hidden(self, batch_size): 28 | return torch.zeros((1, batch_size, HIDDEN_SIZE)) 29 | 30 | 31 | def generate_gru_onnx_from_torch(): 32 | print("-" * 100) 33 | print("Generating 'gru.onnx'...") 34 | 35 | torch.manual_seed(42) 36 | gru = GRU() 37 | print(gru, "\n") 38 | 39 | batch_size = 1 40 | seq_length = 30 41 | data_input = torch.from_numpy( 42 | np.arange(0, batch_size * seq_length * INPUT_SIZE).reshape( 43 | batch_size, seq_length, INPUT_SIZE 44 | ) 45 | ).float() 46 | 47 | sample_in = (data_input, gru.get_init_hidden(batch_size)) 48 | torch.onnx.export( 49 | gru, 50 | sample_in, 51 | "./onnx_models/gru.onnx", 52 | opset_version=13, 53 | input_names=["data_input", "init_hidden"], 54 | output_names=["preds", "hidden_out"], 55 | dynamic_axes={ 56 | "data_input": {0: "batch_size", 1: "seq_length"}, 57 | "init_hidden": {1: "batch_size"}, 58 | "preds": {0: "batch_size", 1: "seq_length"}, 59 | "hidden_out": {1: "batch_size"}, 60 | }, 61 | ) 62 | 63 | 64 | def test_gru_torch(): 65 | torch.manual_seed(42) 66 | gru = GRU() 67 | 68 | batch_size = 1 69 | seq_length = 30 70 | data_input = torch.from_numpy( 71 | np.arange(0, batch_size * seq_length * INPUT_SIZE).reshape( 72 | batch_size, seq_length, INPUT_SIZE 73 | ) 74 | ).float() 75 | init_hidden = gru.get_init_hidden(batch_size) 76 | 77 | with torch.no_grad(): 78 | preds, hidden_out = gru(data_input, init_hidden) 79 | 80 | print("-" * 50) 81 | print("GRU sample:\n") 82 | print("---INPUTS---") 83 | print("data_input: ", data_input.shape) 84 | print(data_input, "\n") 85 | print("init_hidden: ", init_hidden.shape) 86 | print(init_hidden, "\n") 87 | print("---OUTPUTS---") 88 | 
print("preds: ", preds.shape) 89 | print(preds, "\n") 90 | print("hidden_out: ", hidden_out.shape) 91 | print(hidden_out, "\n") 92 | -------------------------------------------------------------------------------- /sample_models/python_models/mlp_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | INPUT_SIZE = 3 6 | HIDDEN_SIZE = 5 7 | OUTPUT_SIZE = 2 8 | 9 | 10 | class MLP(nn.Module): 11 | """ 12 | A simple MultiLayer Perceptron (MLP). It has structure: 13 | 14 | MLP( 15 | (layer1): Linear(in_features=3, out_features=5, bias=True) 16 | (relu): ReLU() 17 | (layer2): Linear(in_features=5, out_features=2, bias=True) 18 | ) 19 | """ 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self.layer1 = nn.Linear(INPUT_SIZE, HIDDEN_SIZE) 24 | self.relu = nn.ReLU() 25 | self.layer2 = nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE) 26 | 27 | def forward(self, inputs): 28 | x = self.relu(self.layer1(inputs)) 29 | return self.layer2(x) 30 | 31 | 32 | def generate_mlp_onnx_from_torch(): 33 | print("-" * 100) 34 | print("Generating 'mlp.onnx'...") 35 | 36 | torch.manual_seed(42) 37 | mlp = MLP() 38 | print(mlp, "\n") 39 | 40 | batch_size = 2 41 | sample_in = torch.from_numpy(np.random.rand(batch_size, INPUT_SIZE)).float() 42 | 43 | torch.onnx.export( 44 | mlp, 45 | sample_in, 46 | "./onnx_models/mlp.onnx", 47 | opset_version=13, 48 | input_names=["data_input"], 49 | output_names=["preds"], 50 | dynamic_axes={ 51 | "data_input": {0: "batch_size"}, 52 | "preds": {0: "batch_size"}, 53 | }, 54 | ) 55 | 56 | 57 | def test_mlp_torch(): 58 | torch.manual_seed(42) 59 | mlp = MLP() 60 | 61 | batch_size = 2 62 | data_input = torch.from_numpy( 63 | np.arange(0, batch_size * INPUT_SIZE).reshape(batch_size, INPUT_SIZE) 64 | ).float() 65 | 66 | with torch.no_grad(): 67 | preds = mlp(data_input) 68 | 69 | print("-" * 50) 70 | print("MLP sample:\n") 71 | print("---INPUTS---") 72 | 
print("data_input: ", data_input.shape) 73 | print(data_input, "\n") 74 | print("---OUTPUTS---") 75 | print("preds: ", preds.shape) 76 | print(preds, "\n") 77 | -------------------------------------------------------------------------------- /sample_models/python_models/scaler_sklearn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import StandardScaler 3 | from sklearn.pipeline import Pipeline 4 | from skl2onnx import to_onnx 5 | 6 | INPUT_SIZE = 3 7 | 8 | sample_data = np.array( 9 | [ 10 | [1, 10, 100], 11 | [1.5, 13, 120], 12 | [0.8, 9, 95], 13 | [0.9, 11, 105], 14 | [0.6, 12, 101], 15 | [1.3, 10, 110], 16 | [1.1, 7, 108], 17 | ], 18 | dtype=np.float32, 19 | ) 20 | 21 | 22 | def generate_scaler_onnx_from_sklearn(): 23 | print("-" * 100) 24 | print("Generating 'scaler.onnx'...") 25 | 26 | scaler = StandardScaler() 27 | scaler.fit(sample_data) 28 | 29 | scaler = Pipeline(steps=[("scaler", scaler)]) 30 | 31 | onnx_scaler = to_onnx(scaler, sample_data, target_opset=13) 32 | with open("./onnx_models/scaler.onnx", "wb") as onnx_file: 33 | onnx_file.write(onnx_scaler.SerializeToString()) 34 | 35 | print(scaler) 36 | 37 | 38 | def test_scaler_sklearn(): 39 | scaler = StandardScaler() 40 | scaler.fit(sample_data) 41 | 42 | Y = scaler.transform(sample_data[:2]) 43 | 44 | print("-" * 50) 45 | print("Scaler sample:\n") 46 | print("---INPUTS---") 47 | print("X: ", sample_data[:2].shape) 48 | print(sample_data[:2], "\n") 49 | print("---OUTPUTS---") 50 | print("Y: ", Y.shape) 51 | print(Y, "\n") 52 | -------------------------------------------------------------------------------- /sample_models/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.1 2 | scikit-learn==1.3.1 3 | skl2onnx==1.15.0 4 | torch==1.9.0 5 | --------------------------------------------------------------------------------