├── .gitignore ├── test └── data │ ├── model1.t7 │ ├── model2.t7 │ ├── model3.t7 │ ├── model4.t7 │ ├── model5.t7 │ ├── model1batch.t7 │ ├── model2.lua │ ├── model4.lua │ ├── model5.lua │ ├── model1.lua │ ├── model1batch.lua │ └── model3.lua ├── onnx ├── init.lua ├── helper.lua ├── operators │ ├── abs.lua │ ├── init.lua │ ├── sigmoid.lua │ ├── softmax.lua │ ├── identity.lua │ ├── tanh.lua │ ├── add.lua │ ├── mul.lua │ ├── transpose.lua │ ├── reshape.lua │ ├── tile.lua │ ├── squeeze.lua │ ├── unsqueeze.lua │ ├── gather.lua │ ├── matmul.lua │ ├── concat.lua │ ├── split.lua │ └── gemm.lua ├── node.lua ├── checker.lua └── graph.lua ├── convertors ├── init.lua ├── onnx_onmt.lua ├── onnx_nngraph.lua └── onnx_nn.lua ├── README.md ├── convert.lua └── onnx_pb.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.onnxdir 2 | -------------------------------------------------------------------------------- /test/data/model1.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model1.t7 -------------------------------------------------------------------------------- /test/data/model2.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model2.t7 -------------------------------------------------------------------------------- /test/data/model3.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model3.t7 -------------------------------------------------------------------------------- /test/data/model4.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model4.t7 
-------------------------------------------------------------------------------- /test/data/model5.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model5.t7 -------------------------------------------------------------------------------- /test/data/model1batch.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/luatorch-onnx-convert/HEAD/test/data/model1batch.t7 -------------------------------------------------------------------------------- /onnx/init.lua: -------------------------------------------------------------------------------- 1 | torch.class('onnx') 2 | 3 | require 'onnx.graph' 4 | require 'onnx.node' 5 | require 'onnx.helper' 6 | require 'onnx.checker' 7 | 8 | require 'onnx.operators.init' -------------------------------------------------------------------------------- /onnx/helper.lua: -------------------------------------------------------------------------------- 1 | local Helper = torch.class('onnx.helper') 2 | 3 | local onnx_pb = require('onnx_pb') 4 | 5 | function Helper.convertPrecision(_) 6 | return onnx_pb.TensorProto.FLOAT 7 | end 8 | -------------------------------------------------------------------------------- /test/data/model2.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | 3 | local h1 = - nn.Linear(20,10) 4 | local h2 = h1 5 | - nn.Tanh() 6 | - nn.Linear(10,10) 7 | - nn.Tanh() 8 | - nn.Linear(10, 1) 9 | local mlp = nn.gModule({h1}, {h2}) 10 | 11 | mlp:forward(torch.randn(3,20)) 12 | 13 | torch.save("model2.t7", mlp) -------------------------------------------------------------------------------- /test/data/model4.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local mlp = nn.Sequential() 4 | mlp:add(nn.Linear(10, 25)) 
-- Linear module (10 inputs, 25 hidden units) 5 | mlp:add(nn.Tanh()) -- apply hyperbolic tangent transfer function on each hidden units 6 | mlp:add(nn.Linear(25, 1)) -- Linear module (25 inputs, 1 output) 7 | 8 | mlp:forward(torch.randn(10)) 9 | 10 | torch.save("model4.t7", mlp) -------------------------------------------------------------------------------- /test/data/model5.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | 3 | local id1 = nn.Identity()() 4 | local id2 = nn.Identity()() 5 | 6 | local a12 = nn.CAddTable()({id1, id2}) 7 | local mod1 = nn.gModule({id1,id2}, {a12, nn.Sigmoid()(id1)}) 8 | 9 | local id3 = nn.Identity()() 10 | local id4 = nn.Identity()() 11 | local m12 = nn.CMulTable()({id3, id4}) 12 | local o1, o2 = mod1({m12, id3}):split(2) 13 | local mod2 = nn.gModule({id3, id4}, {m12, o1, o2}) 14 | 15 | mod2:forward({torch.randn(5), torch.randn(5)}) 16 | 17 | torch.save("model5.t7", mod2) -------------------------------------------------------------------------------- /onnx/operators/abs.lua: -------------------------------------------------------------------------------- 1 | local Abs, parent = torch.class('onnx.node.Abs', 'onnx.node') 2 | 3 | function Abs:__init(inputs, outputs) 4 | parent.__init(self, "Abs", inputs, 1, outputs, 1) 5 | end 6 | 7 | -- given some constraint for the named parameters, check the compatibility 8 | -- and refine these constraints 9 | function Abs:getShapeConstraint(checker) 10 | checker:setChange(false) 11 | 12 | local cx = checker:getParam(self._inputs[1]) 13 | local cy = checker:getParam(self._outputs[1]) 14 | 15 | self._pass = checker:sameShape({cx, cy}) or checker:fail() 16 | 17 | return checker:hasChange() 18 | end -------------------------------------------------------------------------------- /onnx/operators/init.lua: -------------------------------------------------------------------------------- 1 | require 'onnx.operators.gemm' 2 | require 
'onnx.operators.matmul' 3 | require 'onnx.operators.transpose' 4 | require 'onnx.operators.identity' 5 | require 'onnx.operators.abs' 6 | require 'onnx.operators.tanh' 7 | require 'onnx.operators.sigmoid' 8 | require 'onnx.operators.softmax' 9 | require 'onnx.operators.gather' 10 | require 'onnx.operators.add' 11 | require 'onnx.operators.mul' 12 | require 'onnx.operators.reshape' 13 | require 'onnx.operators.squeeze' 14 | require 'onnx.operators.unsqueeze' 15 | require 'onnx.operators.tile' 16 | require 'onnx.operators.split' 17 | require 'onnx.operators.concat' 18 | -------------------------------------------------------------------------------- /onnx/operators/sigmoid.lua: -------------------------------------------------------------------------------- 1 | local Sigmoid, parent = torch.class('onnx.node.Sigmoid', 'onnx.node') 2 | 3 | function Sigmoid:__init(inputs, outputs) 4 | parent.__init(self, "Sigmoid", inputs, 1, outputs, 1) 5 | end 6 | 7 | -- given some constraint for the named parameters, check the compatibility 8 | -- and refine these constraints 9 | function Sigmoid:getShapeConstraint(checker) 10 | checker:setChange(false) 11 | 12 | local cx = checker:getParam(self._inputs[1]) 13 | local cy = checker:getParam(self._outputs[1]) 14 | 15 | self._pass = checker:sameShape({cx, cy}) or checker:fail() 16 | 17 | return checker:hasChange() 18 | end -------------------------------------------------------------------------------- /onnx/operators/softmax.lua: -------------------------------------------------------------------------------- 1 | local SoftMax, parent = torch.class('onnx.node.SoftMax', 'onnx.node') 2 | 3 | function SoftMax:__init(inputs, outputs) 4 | parent.__init(self, "SoftMax", inputs, 1, outputs, 1) 5 | end 6 | 7 | -- given some constraint for the named parameters, check the compatibility 8 | -- and refine these constraints 9 | function SoftMax:getShapeConstraint(checker) 10 | checker:setChange(false) 11 | 12 | local cx = 
checker:getParam(self._inputs[1]) 13 | local cy = checker:getParam(self._outputs[1]) 14 | 15 | self._pass = checker:sameShape({cx, cy}) or checker:fail() 16 | 17 | return checker:hasChange() 18 | end -------------------------------------------------------------------------------- /onnx/operators/identity.lua: -------------------------------------------------------------------------------- 1 | local Identity, parent = torch.class('onnx.node.Identity', 'onnx.node') 2 | 3 | function Identity:__init(inputs, outputs) 4 | parent.__init(self, "Identity", inputs, 1, outputs, 1) 5 | end 6 | 7 | -- given some constraint for the named parameters, check the compatibility 8 | -- and refine these constraints 9 | function Identity:getShapeConstraint(checker) 10 | checker:setChange(false) 11 | 12 | local cx = checker:getParam(self._inputs[1]) 13 | local cy = checker:getParam(self._outputs[1]) 14 | 15 | self._pass = checker:sameShape({cx, cy}) or checker:fail() 16 | 17 | return checker:hasChange() 18 | end -------------------------------------------------------------------------------- /onnx/operators/tanh.lua: -------------------------------------------------------------------------------- 1 | local Tanh, parent = torch.class('onnx.node.Tanh', 'onnx.node') 2 | 3 | function Tanh:__init(inputs, outputs, precision) 4 | parent.__init(self, "Tanh", inputs, 1, outputs, 1) 5 | self._precision = precision 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | function Tanh:getShapeConstraint(checker) 11 | checker:setChange(false) 12 | 13 | local cx = checker:getParam(self._inputs[1]) 14 | local cy = checker:getParam(self._outputs[1]) 15 | 16 | self._pass = checker:sameShape({cx, cy}) or checker:fail() 17 | 18 | return checker:hasChange() 19 | end -------------------------------------------------------------------------------- /onnx/operators/add.lua: 
-------------------------------------------------------------------------------- 1 | local Add, parent = torch.class('onnx.node.Add', 'onnx.node') 2 | 3 | function Add:__init(inputs, outputs, precision) 4 | parent.__init(self, "Add", inputs, 2, outputs, 1) 5 | self._precision = precision 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | function Add:getShapeConstraint(checker) 11 | checker:setChange(false) 12 | 13 | local cx1 = checker:getParam(self._inputs[1]) 14 | local cx2 = checker:getParam(self._inputs[2]) 15 | local cy = checker:getParam(self._outputs[1]) 16 | 17 | self._pass = checker:sameShape({cx1, cx2, cy}) or checker:fail() 18 | 19 | return checker:hasChange() 20 | end 21 | -------------------------------------------------------------------------------- /onnx/operators/mul.lua: -------------------------------------------------------------------------------- 1 | local Mul, parent = torch.class('onnx.node.Mul', 'onnx.node') 2 | 3 | function Mul:__init(inputs, outputs, precision) 4 | parent.__init(self, "Mul", inputs, 2, outputs, 1) 5 | self._precision = precision 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | function Mul:getShapeConstraint(checker) 11 | checker:setChange(false) 12 | 13 | local cx1 = checker:getParam(self._inputs[1]) 14 | local cx2 = checker:getParam(self._inputs[2]) 15 | local cy = checker:getParam(self._outputs[1]) 16 | 17 | self._pass = checker:sameShape({cx1, cx2, cy}) or checker:fail() 18 | 19 | return checker:hasChange() 20 | end 21 | -------------------------------------------------------------------------------- /test/data/model1.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local mod = {} 4 | 5 | mod['linear-bias'] = nn.Linear(20, 10) 6 | mod['linear-bias']:forward(torch.randn(20)) 7 | 8 | 
mod['linear-nobias'] = nn.Linear(20, 10, false) 9 | mod['linear-nobias']:forward(torch.randn(20)) 10 | 11 | mod['cadd-table'] = nn.CAddTable() 12 | mod['cadd-table']:forward({torch.randn(3), torch.randn(3)}) 13 | 14 | mod['abs'] = nn.Abs() 15 | mod['abs']:forward(torch.randn(15)) 16 | 17 | mod['tanh'] = nn.Tanh() 18 | mod['tanh']:forward(torch.randn(15)) 19 | 20 | mod['sigmoid'] = nn.Sigmoid() 21 | mod['sigmoid']:forward(torch.randn(15)) 22 | 23 | mod['lookup'] = nn.LookupTable(20, 100) 24 | mod['lookup']:forward((torch.rand(3):abs()*20):int()) 25 | 26 | mod['reshape'] = nn.Reshape(4, 3, 7) 27 | mod['reshape']:forward(torch.rand(42, 2)) 28 | 29 | mod['splittable'] = nn.SplitTable(2) 30 | mod['splittable']:forward(torch.rand(42, 2)) 31 | 32 | mod['replicate'] = nn.Replicate(3, 1) 33 | mod['replicate']:forward(torch.linspace(1, 5, 5)) 34 | 35 | torch.save("model1.t7", mod) -------------------------------------------------------------------------------- /onnx/operators/transpose.lua: -------------------------------------------------------------------------------- 1 | local Transpose, parent = torch.class('onnx.node.Transpose', 'onnx.node') 2 | 3 | function Transpose:__init(inputs, outputs, perm) 4 | parent.__init(self, "Transpose", inputs, 1, outputs, 1) 5 | self._perm = perm 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | function Transpose:getShapeConstraint(checker) 11 | local cx = checker:assert2D(self._inputs[1]) 12 | local cy = checker:assert2D(self._outputs[1]) 13 | 14 | local count = 0 15 | checker:setChange(true) 16 | while checker:hasChange() do 17 | count = count + 1 18 | checker:setChange(false) 19 | self._pass = checker:dimCheck(cx, 1, cy, 2) or checker:fail() 20 | self._pass = checker:dimCheck(cx, 2, cy, 1) or checker:fail() 21 | end 22 | 23 | return count ~= 1 24 | end 25 | 26 | function Transpose:build(onnx_pb, node) 27 | parent.build(self, onnx_pb, node) 28 | 
self.addAttribute(node, "perm", 'ints', self._perm, onnx_pb.AttributeProto.INTS) 29 | end -------------------------------------------------------------------------------- /test/data/model1batch.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local mod = {} 4 | 5 | mod['linear-bias'] = nn.Linear(20, 10) 6 | mod['linear-bias']:forward(torch.randn(7, 20)) 7 | 8 | mod['linear-nobias'] = nn.Linear(20, 10, false) 9 | mod['linear-nobias']:forward(torch.randn(7, 20)) 10 | 11 | mod['cadd-table'] = nn.CAddTable() 12 | mod['cadd-table']:forward({torch.randn(7, 3), torch.randn(7, 3)}) 13 | 14 | mod['abs'] = nn.Abs() 15 | mod['abs']:forward(torch.randn(7, 15)) 16 | 17 | mod['tanh'] = nn.Tanh() 18 | mod['tanh']:forward(torch.randn(7, 15)) 19 | 20 | mod['sigmoid'] = nn.Sigmoid() 21 | mod['sigmoid']:forward(torch.randn(7, 15)) 22 | 23 | mod['lookup'] = nn.LookupTable(20, 100) 24 | mod['lookup']:forward((torch.rand(3):abs()*20):int()) 25 | 26 | mod['reshape'] = nn.Reshape(4, 3, 7, true) 27 | mod['reshape']:forward(torch.rand(7, 42, 2)) 28 | 29 | mod['splittable'] = nn.SplitTable(2, 2) 30 | mod['splittable']:forward(torch.rand(7, 42, 2)) 31 | 32 | mod['replicate'] = nn.Replicate(3, 1, 1) 33 | mod['replicate']:forward(torch.rand(7, 5)) 34 | 35 | torch.save("model1batch.t7", mod) -------------------------------------------------------------------------------- /onnx/node.lua: -------------------------------------------------------------------------------- 1 | local Node = torch.class('onnx.node') 2 | 3 | function Node:__init(name, inputs, ninputs, outputs, noutputs) 4 | assert(#inputs==ninputs, "invalid number of inputs parameters") 5 | self._inputs = inputs 6 | assert(#outputs==noutputs, "invalid number of outputs parameters") 7 | self._outputs = outputs 8 | self._name = name 9 | end 10 | 11 | function Node:inputs() 12 | return self._inputs 13 | end 14 | 15 | function Node:outputs() 16 | return self._outputs 17 | end 18 | 
19 | function Node:getShapeConstraint(_) 20 | error("getShapeConstraint not implemented - for operator "..self._name) 21 | end 22 | 23 | function Node:build(onnx_pb, node) 24 | for _, p in ipairs(self._inputs) do 25 | node.input:append(p) 26 | end 27 | for _, p in ipairs(self._outputs) do 28 | node.output:append(p) 29 | end 30 | node.op_type = self._name 31 | end 32 | 33 | function Node.addAttribute(onnx_node, name, namev, v, precision) 34 | local attribute = onnx_node.attribute:add() 35 | attribute.name = name 36 | attribute.type = precision 37 | if type(v) == "table" then 38 | for _, av in ipairs(v) do 39 | attribute[namev]:append(av) 40 | end 41 | else 42 | attribute[namev] = v 43 | end 44 | end 45 | -------------------------------------------------------------------------------- /onnx/operators/reshape.lua: -------------------------------------------------------------------------------- 1 | local Reshape, parent = torch.class('onnx.node.Reshape', 'onnx.node') 2 | 3 | function Reshape:__init(inputs, outputs, fixedShape) 4 | parent.__init(self, "Reshape", inputs, 2, outputs, 1) 5 | self._fixedShape = fixedShape 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | 11 | function Reshape:getShapeConstraint(checker) 12 | checker:setChange(false) 13 | 14 | local cx = checker:getParam(self._inputs[1]) 15 | local cind = checker:assert1D(self._inputs[2]) 16 | local cy = checker:getParam(self._outputs[1]) 17 | 18 | -- reshape is not inversible - we can not infer shape of input given output 19 | if self._fixedShape ~= nil then 20 | if #cy == 0 then 21 | for i = 1, #self._fixedShape do 22 | if self._fixedShape[i] == -1 then 23 | table.insert(cy, checker:getUnkDimIdx()) 24 | else 25 | table.insert(cy, self._fixedShape[i]) 26 | end 27 | end 28 | else 29 | assert(#cy == #self._fixedShape, "invalid output shape") 30 | for i = 1, #self._fixedShape do 31 | if self._fixedShape[i] ~= -1 then 32 | 
checker:dimCheck(self._fixedShape, i, cy, i) 33 | end 34 | end 35 | end 36 | end 37 | 38 | return checker:hasChange() 39 | end -------------------------------------------------------------------------------- /onnx/operators/tile.lua: -------------------------------------------------------------------------------- 1 | local Tile, parent = torch.class('onnx.node.Tile', 'onnx.node') 2 | 3 | function Tile:__init(inputs, outputs, fixedRepeats) 4 | parent.__init(self, "Tile", inputs, 2, outputs, 1) 5 | self._fixedRepeats = fixedRepeats 6 | end 7 | 8 | -- given some constraint for the named parameters, check the compatibility 9 | -- and refine these constraints 10 | 11 | function Tile:getShapeConstraint(checker) 12 | checker:setChange(false) 13 | 14 | local cx = checker:getParam(self._inputs[1]) 15 | local crepeats = checker:assert1D(self._inputs[2]) 16 | local cy = checker:getParam(self._outputs[1]) 17 | 18 | if self._fixedRepeats then 19 | local n = #self._fixedRepeats 20 | assert(crepeats[1] == n, 'pb with tile initialization') 21 | cx = checker:assertND(self._inputs[1], n) or checker:fail() 22 | cy = checker:assertND(self._outputs[1], n) or checker:fail() 23 | for i = 1, n do 24 | if cx[i] > 0 then 25 | if cy[i] > 0 then 26 | assert(cy[i] == cx[i] * self._fixedRepeats[i], 'invalid tile result') 27 | else 28 | checker:changeUnk(cy[i], cx[i] * self._fixedRepeats[i]) 29 | end 30 | elseif cy[i] > 0 then 31 | assert(cy[i] % self._fixedRepeats[i] == 0, 'tile size not consistent with multiplier') 32 | checker:changeUnk(cx[i], cy[i] / self._fixedRepeats[i]) 33 | end 34 | end 35 | end 36 | 37 | return checker:hasChange() 38 | end -------------------------------------------------------------------------------- /test/data/model3.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | 3 | function _buildLayer(inputSize, hiddenSize) 4 | local inputs = {} 5 | table.insert(inputs, nn.Identity()()) 6 | table.insert(inputs, 
nn.Identity()()) 7 | table.insert(inputs, nn.Identity()()) 8 | 9 | local prevC = inputs[1] 10 | local prevH = inputs[2] 11 | local x = inputs[3] 12 | 13 | -- Evaluate the input sums at once for efficiency. 14 | local i2h = nn.Linear(inputSize, 4 * hiddenSize)(x) 15 | local h2h = nn.Linear(hiddenSize, 4 * hiddenSize)(prevH) 16 | local allInputSums = nn.CAddTable()({i2h, h2h}) 17 | 18 | local reshaped = nn.Reshape(4, hiddenSize)(allInputSums) 19 | local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) 20 | 21 | -- Decode the gates. 22 | local inGate = nn.Sigmoid()(n1) 23 | local forgetGate = nn.Sigmoid()(n2) 24 | local outGate = nn.Sigmoid()(n3) 25 | 26 | -- Decode the write inputs. 27 | local inTransform = nn.Tanh()(n4) 28 | 29 | -- Perform the LSTM update. 30 | local nextC = nn.CAddTable()({ 31 | nn.CMulTable()({forgetGate, prevC}), 32 | nn.CMulTable()({inGate, inTransform}) 33 | }) 34 | 35 | -- Gated cells form the output. 36 | local nextH = nn.CMulTable()({outGate, nn.Tanh()(nextC)}) 37 | 38 | return nn.gModule(inputs, {nextC, nextH}) 39 | end 40 | 41 | local gmod = _buildLayer(10,5) 42 | 43 | gmod:forward({torch.randn(3, 5), torch.randn(3,5), torch.randn(3,10)}) 44 | 45 | torch.save("model3.t7", gmod) -------------------------------------------------------------------------------- /onnx/operators/squeeze.lua: -------------------------------------------------------------------------------- 1 | local Squeeze, parent = torch.class('onnx.node.Squeeze', 'onnx.node') 2 | 3 | function Squeeze:__init(inputs, outputs, axes) 4 | parent.__init(self, "Squeeze", inputs, 1, outputs, 1) 5 | self._axes = axes 6 | end 7 | 8 | function _find(v, t) 9 | for _, x in ipairs(t) do 10 | if x == v then 11 | return true 12 | end 13 | end 14 | return false 15 | end 16 | 17 | -- given some constraint for the named parameters, check the compatibility 18 | -- and refine these constraints 19 | function Squeeze:getShapeConstraint(checker) 20 | checker:setChange(false) 21 | 22 | local cx = 
checker:getParam(self._inputs[1]) 23 | local cy = checker:getParam(self._outputs[1]) 24 | 25 | if #cx == 0 and #cy ~=0 then 26 | for i = 1, #cy do 27 | table.insert(cx, cy[i]) 28 | end 29 | for i = 1, #self._axes do 30 | table.insert(cx, i+1, 1) 31 | end 32 | checker:setChange(true) 33 | elseif #cx ~= 0 and #cy == 0 then 34 | for i = 1, #cx do 35 | if not _find(i-1, self._axes) then 36 | table.insert(cy, cx[i]) 37 | end 38 | end 39 | checker:setChange(true) 40 | elseif #cx ~= 0 and #cy ~= 0 then 41 | assert(#cy==#cx-#self._axes, "invalid shapes") 42 | end 43 | 44 | return checker:hasChange() 45 | end 46 | 47 | function Squeeze:build(onnx_pb, node) 48 | parent.build(self, onnx_pb, node) 49 | self.addAttribute(node, "axes", 'ints', self._axes, onnx_pb.AttributeProto.INTS) 50 | end -------------------------------------------------------------------------------- /onnx/operators/unsqueeze.lua: -------------------------------------------------------------------------------- 1 | local Unsqueeze, parent = torch.class('onnx.node.Unsqueeze', 'onnx.node') 2 | 3 | function Unsqueeze:__init(inputs, outputs, axes) 4 | parent.__init(self, "Unsqueeze", inputs, 1, outputs, 1) 5 | self._axes = axes 6 | end 7 | 8 | function _find(v, t) 9 | for _, x in ipairs(t) do 10 | if x == v then 11 | return true 12 | end 13 | end 14 | return false 15 | end 16 | 17 | -- given some constraint for the named parameters, check the compatibility 18 | -- and refine these constraints 19 | function Unsqueeze:getShapeConstraint(checker) 20 | checker:setChange(false) 21 | 22 | local cx = checker:getParam(self._inputs[1]) 23 | local cy = checker:getParam(self._outputs[1]) 24 | 25 | if #cy == 0 and #cx ~= 0 then 26 | for i = 1, #cx do 27 | table.insert(cy, cx[i]) 28 | end 29 | for i = 1, #self._axes do 30 | table.insert(cy, i+1, 1) 31 | end 32 | checker:setChange(true) 33 | elseif #cy ~= 0 and #cx == 0 then 34 | for i = 1, #cy do 35 | if not _find(i-1, self._axes) then 36 | table.insert(cx, cy[i]) 37 | end 
38 | end 39 | checker:setChange(true) 40 | elseif #cx ~= 0 and #cy ~= 0 then 41 | assert(#cx==#cy-#self._axes, "invalid shapes") 42 | end 43 | 44 | return checker:hasChange() 45 | end 46 | 47 | function Unsqueeze:build(onnx_pb, node) 48 | parent.build(self, onnx_pb, node) 49 | self.addAttribute(node, "axes", 'ints', self._axes, onnx_pb.AttributeProto.INTS) 50 | end -------------------------------------------------------------------------------- /convertors/init.lua: -------------------------------------------------------------------------------- 1 | local convertor = {} 2 | 3 | -- cache the open convertors file 4 | local convertors = {} 5 | 6 | local function split(str, sep) 7 | local res = {} 8 | local index = 1 9 | 10 | while index <= str:len() do 11 | local sepStart, sepEnd = str:find(sep, index) 12 | 13 | local sub 14 | if not sepStart then 15 | sub = str:sub(index) 16 | table.insert(res, sub) 17 | index = str:len() + 1 18 | else 19 | sub = str:sub(index, sepStart - 1) 20 | table.insert(res, sub) 21 | index = sepEnd + 1 22 | if index > str:len() then 23 | table.insert(res, '') 24 | end 25 | end 26 | end 27 | 28 | return res 29 | end 30 | 31 | function convertor.mtype(object) 32 | if type(object) == 'table' and object.__typename then 33 | return object.__typename 34 | else 35 | return torch.type(object) 36 | end 37 | end 38 | 39 | function convertor.isSupported(tname) 40 | local namespace = tname 41 | local object = '' 42 | local decomp_name = split(tname, '%.') 43 | if #decomp_name > 1 then 44 | namespace = tname:sub(1, -decomp_name[#decomp_name]:len()-2) 45 | object = tname:sub(-decomp_name[#decomp_name]:len()) 46 | end 47 | if convertors[namespace] == nil then 48 | local _, err = pcall(function() 49 | convertors[namespace] = require('convertors.onnx_'..namespace) 50 | end) 51 | if err then 52 | print('no convertors for '..namespace) 53 | convertors[namespace] = false 54 | end 55 | end 56 | return convertors[namespace] and convertors[namespace][object] 57 | 
end 58 | 59 | return convertor -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # luatorch-onnx-convert 2 | 3 | ## Overview 4 | 5 | This repository provides an utility to extract model(s) from a serialized lua-torch model (`.t7` file) and convert it/them to onnx format. 6 | 7 | ## Requirements 8 | 9 | * You need first to install lua torch as well as the library used in the model that are necessary to load the model from torch (for instance, `nn`, `nngraph`, `onmt` or your own libraries). 10 | * Install protobuf lua library from this repository https://github.com/jsenellart/protobuf-lua (this version is mandatory since it fixes some issues with the original implementation). 11 | 12 | You can regenerate the lua interpreter of the onnx proto file by doing: 13 | 14 | ``` 15 | protoc --lua_out=. onnx.proto 16 | ``` 17 | 18 | This will generate an updated version of `onnx_pb.lua` 19 | 20 | ## Extracting model 21 | 22 | ``` 23 | $ th convert.lua -t7 test/data/model1.t7 -require nn -force 24 | ``` 25 | 26 | `convert.lua` goes through the serialized torch object and looks for supported modules or models. Each of them is converted into `test/data/model1.onnxdir` directory with the corresponding name: 27 | 28 | ``` 29 | model.linear-bias.onnx 30 | model.linear-nobias.onnx 31 | ``` 32 | 33 | The option `-require nn` indicates which library are necessary to deserialize the model. GPU model needs `cunn` installation to be read. 34 | 35 | ## Convertors 36 | 37 | To perform conversion to onnx, each of the used module must be described with onnx operators. New convertors for specific library can be added in `convertors`. Their name must match `onnx_class` where `class` is the torch classname of the modules. 
local Gather, parent = torch.class('onnx.node.Gather', 'onnx.node')

-- ONNX Gather node: gathers slices of a data tensor along `axis` using an
-- index tensor.
-- NOTE(review): inputs[1] is typed FLOAT and inputs[2] INT32 below, so the
-- local `cind` actually holds the FLOAT (data) parameter and `cx` the INT32
-- (index) parameter — the names look swapped, but the shape arithmetic is
-- consistent with output.shape = cx.shape .. cind.shape[2:]. Confirm against
-- the callers before renaming.
function Gather:__init(inputs, outputs, precision, axis)
   parent.__init(self, "Gather", inputs, 2, outputs, 1)
   self._precision = precision
   self._axis = axis
end

-- given some constraint for the named parameters, check the compatibility
-- and refine these constraints
function Gather:getShapeConstraint(checker)
   local cind = checker:getParam(self._inputs[1])
   local cx = checker:getParam(self._inputs[2])
   local cy = checker:getParam(self._outputs[1])

   checker:setType(self._inputs[1], "FLOAT")
   checker:setType(self._inputs[2], "INT32")
   checker:setType(self._outputs[1], "FLOAT")

   checker:setChange(false)

   if #cx ~= 0 and #cind ~= 0 then
      if #cy == 0 then
         -- build the output shape: all of cx's dims, then cind's dims 2..#cind
         for i = 1, #cx do
            table.insert(cy, cx[i])
         end
         for i = 1, #cind - 1 do
            table.insert(cy, cind[i+1])
         end
         checker:setChange(true)
      else
         assert(#cy == #cx + #cind - 1, "invalid size of gather output")
         for i = 1, #cx do
            checker:dimCheck(cy, i, cx, i)
         end
         -- BUG FIX: cy[#cx + i] was built from cind[i+1] in the branch above,
         -- so those are the positions to compare; the original used
         -- `i + #cx - 1`, an off-by-one that compared cx's last dim instead.
         for i = 1, #cind - 1 do
            checker:dimCheck(cy, i + #cx, cind, i+1)
         end
      end
   elseif #cy ~= 0 then
      checker:setChange(true)
      -- BUG FIX: the original called methods on the undefined global `check`
      -- instead of the `checker` argument, which crashed at runtime whenever
      -- only the output shape was known.
      if #cx == 0 then
         cx = checker:assertND(self._inputs[2], #cy - #cind + 1)
      else
         cind = checker:assertND(self._inputs[1], #cy - #cx + 1)
      end
   end

   return checker:hasChange()
end

function Gather:build(onnx_pb, node)
   parent.build(self, onnx_pb, node)
   self.addAttribute(node, "axis", 'i', self._axis, onnx_pb.AttributeProto.INT)
end
local MatMul, parent = torch.class('onnx.node.MatMul', 'onnx.node')

-- ONNX MatMul node: y = a * b, supporting 1-D vectors, 2-D matrices and
-- N-D batched operands (leading dims are batch dims).
function MatMul:__init(inputs, outputs, precision)
   parent.__init(self, "MatMul", inputs, 2, outputs, 1)
   self._precision = precision
end

-- given some constraint for the named parameters, check the compatibility
-- and refine these constraints
function MatMul:getShapeConstraint(checker)
   checker:setChange(false)

   local ca = checker:getParam(self._inputs[1])
   local cb = checker:getParam(self._inputs[2])
   local cy = checker:getParam(self._outputs[1])

   if #ca > 2 or #cb > 2 or #cy >= 2 then
      -- batched / 2-D case: n is the common rank of all three operands
      local n = #ca
      if #cb > n then n = #cb end
      -- BUG FIX: the original assigned `n = #cb` here, ignoring the output
      -- rank, so a known higher-rank output never raised n.
      if #cy > n then n = #cy end
      ca = checker:assertND(self._inputs[1], n)
      cb = checker:assertND(self._inputs[2], n)
      cy = checker:assertND(self._outputs[1], n)
      -- the first b dims are batch dims and must agree across a, b and y
      local b = n - 2
      for i = 1, b do
         checker:dimCheck(ca, i, cb, i)
         checker:dimCheck(ca, i, cy, i)
      end
      -- matrix dims: y[r, c] = a[r, k] * b[k, c]
      self._pass = checker:dimCheck(ca, 1+b, cy, 1+b) or checker:fail()
      self._pass = checker:dimCheck(ca, 2+b, cb, 1+b) or checker:fail()
      self._pass = checker:dimCheck(cb, 2+b, cy, 2+b) or checker:fail()
   elseif #ca == 1 or #cb == 1 or #cy == 1 then
      -- vector cases: exactly one operand is 1-D and the result is 1-D
      cy = checker:assert1D(self._outputs[1])
      if #ca == 1 or #cb == 2 then
         -- vector * matrix
         ca = checker:assert1D(self._inputs[1])
         cb = checker:assert2D(self._inputs[2])
         self._pass = checker:dimCheck(ca, 1, cb, 1) or checker:fail()
         self._pass = checker:dimCheck(cb, 2, cy, 1) or checker:fail()
      elseif #cb == 1 then
         -- matrix * vector
         ca = checker:assert2D(self._inputs[1])
         cb = checker:assert1D(self._inputs[2])
         self._pass = checker:dimCheck(ca, 1, cy, 1) or checker:fail()
         self._pass = checker:dimCheck(ca, 2, cb, 1) or checker:fail()
      end
   end

   return checker:hasChange()
end
--------------------------------------------------------------------------------
-- onnx/operators/concat.lua
-- ONNX Concat node: N inputs concatenated along `axis` (0-based), 1 output.
local Concat, parent = torch.class('onnx.node.Concat', 'onnx.node')

function Concat:__init(inputs, outputs, axis)
  parent.__init(self, "Concat", inputs, #inputs, outputs, 1)
  self._axis = axis
end

-- given some constraint for the named parameters, check the compatibility
-- and refine these constraints
function Concat:getShapeConstraint(checker)
  checker:setChange(false)

  local cy = checker:getParam(self._outputs[1])

  local nbdim
  local sizes
  local sumaxisx = 0 -- running sum of the axis dims; nil once any is unknown

  if #cy ~= 0 then
    sizes = cy
    nbdim = #cy
  end
  for _, p in pairs(self._inputs) do
    local cx = checker:getParam(p)
    if #cx ~= 0 then
      assert(nbdim == nil or nbdim == #cx, "inconsistent number of dimensions")
      -- axis is 0-based (ONNX convention); Lua tables are 1-based.
      if sumaxisx ~= nil and cx[self._axis+1] > 0 then
        sumaxisx = sumaxisx + cx[self._axis+1]
      else
        sumaxisx = nil
      end
      nbdim = #cx
      sizes = {}
      for _, v in ipairs(cx) do
        table.insert(sizes, v)
      end
    else
      sumaxisx = nil
    end
  end

  if nbdim ~= nil then
    -- all inputs share every dim except the concat axis
    for _, p in pairs(self._inputs) do
      local cx = checker:assertND(p, nbdim)
      for i = 1, nbdim do
        if i-1 ~= self._axis then
          checker:dimCheck(cx, i, sizes, i)
        end
      end
    end

    if #cy == 0 then
      cy = checker:assertND(self._outputs[1], nbdim)
      checker:setChange(true)
    end
    if sumaxisx ~= nil then
      sizes[self._axis+1] = sumaxisx
    else
      sizes[self._axis+1] = cy[self._axis+1]
    end
    for i = 1, nbdim do
      checker:dimCheck(cy, i, sizes, i)
    end
  end

  return checker:hasChange()
end

function Concat:build(onnx_pb, node)
  parent.build(self, onnx_pb, node)
  self.addAttribute(node, "axis", 'i', self._axis, onnx_pb.AttributeProto.INT)
end
--------------------------------------------------------------------------------
-- onnx/operators/split.lua
-- ONNX Split node: 1 input split into #outputs slices of size 1 along `axis`.
local Split, parent = torch.class('onnx.node.Split', 'onnx.node')

function Split:__init(inputs, outputs, axis)
  parent.__init(self, "Split", inputs, 1, outputs, #outputs)
  self._axis = axis
end

-- given some constraint for the named parameters, check the compatibility
-- and refine these constraints
function Split:getShapeConstraint(checker)
  checker:setChange(false)

  local cx = checker:getParam(self._inputs[1])

  local nbDimOutput
  local cys = {}
  for _, p in pairs(self._outputs) do
    local cy = checker:getParam(p)
    if #cy ~= 0 then
      -- each output slice has size 1 on the split axis
      if cy[self._axis+1] < 0 then
        checker:changeUnk(cy[self._axis+1], 1)
      else
        assert(cy[self._axis+1] == 1, "inconsistent size of axis output")
      end
      -- FIX: was `nbDimOuput` (typo) — the misspelled global was always nil,
      -- so the rank-consistency assert below could never fire.
      if nbDimOutput == nil then
        nbDimOutput = #cy
      else
        assert(nbDimOutput == #cy, "inconsistent dimension of Split output")
      end
    end
    table.insert(cys, cy)
  end

  if nbDimOutput then
    checker:sameShape(cys)
    if #cx == 0 then
      -- derive the input shape from the (now unified) output shape
      for i, d in pairs(cys[1]) do
        if i-1 == self._axis then
          table.insert(cx, #self._outputs)
        else
          table.insert(cx, d)
        end
      end
    else
      assert(#cx == #cys[1], "incorrect dimension of split input")
      for i, d in pairs(cys[1]) do
        if i-1 ~= self._axis then
          checker:dimCheck(cx, i, cys[1], i)
        else
          if cx[i] < 0 then
            checker:changeUnk(cx[i], #self._outputs)
          else
            assert(cx[i] == #self._outputs, "invalid split axis dimension")
          end
        end
      end
    end
  end

  return checker:hasChange()
end

function Split:build(onnx_pb, node)
  parent.build(self, onnx_pb, node)
  self.addAttribute(node,
"axis", 'i', self._axis, onnx_pb.AttributeProto.INT)
end
--------------------------------------------------------------------------------
-- convertors/onnx_onmt.lua
-- Convertors for OpenNMT wrapper modules: each unwraps the inner network
-- and dispatches to the registered convertor for its type.
local onnx_nn = require 'convertors.onnx_nn'
local convertor = require 'convertors.init'

local onnx_onmt = {}

function onnx_onmt.WordEmbedding(obj, nInputs, nonbatch_mode)
  return onnx_nn.LookupTable(obj.net, nInputs, nonbatch_mode)
end

function onnx_onmt.LSTM(obj, nInputs, nonbatch_mode)
  return onnx_nn.gModule(obj.net, nInputs, nonbatch_mode)
end

-- Shared dispatch for wrapper modules: convert `net` with the registered
-- convertor for its type, or raise. `wrapperName` is only used in the
-- "unsupported" error message.
local function convertInner(net, wrapperName, nInputs, nonbatch_mode)
  local tname = convertor.mtype(net)
  if type(net) == 'userdata' or type(net) == 'table' then
    local convert_func = convertor.isSupported(tname)
    if convert_func then
      return convert_func(net, nInputs, nonbatch_mode)
    else
      error('module `'..tname..'` not supported')
    end
  else
    error("unsupported module in "..wrapperName..": `"..tname.."`")
  end
end

function onnx_onmt.Bridge(obj, nInputs, nonbatch_mode)
  return convertInner(obj.net, "onmt.Bridge", nInputs, nonbatch_mode)
end

function onnx_onmt.GlobalAttention(obj, nInputs, nonbatch_mode)
  -- FIX: the original error message said "onmt.Bridge" here (copy-paste)
  return convertInner(obj.net, "onmt.GlobalAttention", nInputs, nonbatch_mode)
end

function onnx_onmt.Encoder(obj, nInputs, nonbatch_mode)
  return convertInner(obj.network, "onmt.Encoder", nInputs, nonbatch_mode)
end

function onnx_onmt.Decoder(obj, nInputs, nonbatch_mode)
  return convertInner(obj.network, "onmt.Decoder", nInputs, nonbatch_mode)
end

return onnx_onmt
--------------------------------------------------------------------------------
-- onnx/operators/gemm.lua
-- ONNX Gemm node: y = alpha * op(a) * op(b) + beta * c, 3 inputs, 1 output.
local Gemm, parent = torch.class('onnx.node.Gemm', 'onnx.node')

-- broadcastC/transposeA/transposeB are passed as 0/1 integers.
function Gemm:__init(inputs, outputs, precision,
                     alpha, beta, broadcastC, transposeA, transposeB)
  parent.__init(self, "Gemm", inputs, 3, outputs, 1)
  self._precision = precision
  self._alpha = alpha
  self._beta = beta
  self._broadcastC = broadcastC == 1
  self._transposeA = transposeA == 1
  self._transposeB = transposeB == 1
end

-- given some constraint for the named parameters, check the compatibility
-- and refine these constraints
function Gemm:getShapeConstraint(checker)
  checker:setChange(false)
  local ca = checker:getParam(self._inputs[1])
  local cb = checker:assert2D(self._inputs[2])
  local cc = checker:getParam(self._inputs[3])
  local cy = checker:getParam(self._outputs[1])

  if #cy == 1 or #ca == 1 then
    -- 1D input
    ca = checker:assert1D(self._inputs[1])
    cc = checker:assert1D(self._inputs[3])
    cy = checker:assert1D(self._outputs[1])
    assert(not self._transposeA, "cannot use transposeA with 1D input matrix")
    if self._transposeB then
      self._pass =
checker:dimCheck(ca, 1, cb, 2) or checker:fail()
      self._pass = checker:dimCheck(cb, 1, cy, 1) or checker:fail()
    else
      self._pass = checker:dimCheck(ca, 1, cb, 1) or checker:fail()
      self._pass = checker:dimCheck(cb, 2, cy, 1) or checker:fail()
    end
    self._pass = checker:dimCheck(cc, 1, cy, 1) or checker:fail()
  elseif #ca == 2 or #cy == 2 then
    -- 2D input: y is (rows(op(a)), cols(op(b))); c broadcasts as 1D or 2D.
    ca = checker:assert2D(self._inputs[1])
    cc = checker:assert1or2D(self._inputs[3])
    cy = checker:assert2D(self._outputs[1])
    if not self._transposeA then
      self._pass = checker:dimCheck(ca, 1, cy, 1) or checker:fail()
      if self._transposeB then
        self._pass = checker:dimCheck(ca, 2, cb, 2) or checker:fail()
      else
        self._pass = checker:dimCheck(ca, 2, cb, 1) or checker:fail()
      end
    else
      self._pass = checker:dimCheck(ca, 2, cy, 1) or checker:fail()
      if self._transposeB then
        self._pass = checker:dimCheck(ca, 1, cb, 2) or checker:fail()
      else
        self._pass = checker:dimCheck(ca, 1, cb, 1) or checker:fail()
      end
    end
    if self._transposeB then
      self._pass = checker:dimCheck(cb, 1, cy, 2) or checker:fail()
    else
      self._pass = checker:dimCheck(cb, 2, cy, 2) or checker:fail()
    end
    if #cc == 1 then
      -- 1D bias may broadcast along either output dimension
      self._pass = checker:dimCheck(cc, 1, cy, 2) or
                   checker:dimCheck(cc, 1, cy, 1) or checker:fail()
    elseif #cc == 2 then
      self._pass = checker:dimCheck(cc, 1, cy, 1) or checker:fail()
      self._pass = checker:dimCheck(cc, 2, cy, 2) or checker:fail()
    end
  end

  return checker:hasChange()
end

function Gemm:build(onnx_pb, node)
  parent.build(self, onnx_pb, node)
  self.addAttribute(node, "alpha", 'f', self._alpha, onnx_pb.AttributeProto.FLOAT)
  self.addAttribute(node, "beta", 'f', self._beta, onnx_pb.AttributeProto.FLOAT)
  self.addAttribute(node, "broadcast", 'i', self._broadcastC and 1 or 0, onnx_pb.AttributeProto.INT)
  self.addAttribute(node, "transA", 'i', self._transposeA and 1 or 0, onnx_pb.AttributeProto.INT)
  self.addAttribute(node, "transB", 'i', self._transposeB and 1 or 0, onnx_pb.AttributeProto.INT)
end
--------------------------------------------------------------------------------
-- convertors/onnx_nngraph.lua
-- Convertor for nngraph gModule: replays the forward graph symbolically,
-- converting each node's module into a subgraph and merging it.
local onnx_nn = {}

local convertor = require 'convertors.init'
require 'onnx.init'

-- Debug helper: render a nested table as "{k:v,k:v}".
-- NOTE(review): declared global in the original; probably intended local.
function serialize(t)
  if type(t) == 'table' then
    local parts = ''
    for key, value in pairs(t) do
      if parts ~= '' then
        parts = parts .. ','
      end
      parts = parts .. key .. ':' .. serialize(value)
    end
    return '{' .. parts .. '}'
  else
    return t
  end
end

function onnx_nn.gModule(obj, _, nonbatch_mode)
  -- symbolic parameter names for the graph boundary
  local inputs = {}
  local outputs = {}
  for i = 1, obj.nInputs do
    table.insert(inputs, "x" .. i)
  end
  for i = 1, #obj.outnode.children do
    table.insert(outputs, "y" ..
i)
  end

  local graph = onnx.graph.new(inputs, outputs)

  -- Evaluate one forward node symbolically: instead of tensors, parameter
  -- NAMES flow through the graph; converted subgraphs are merged in.
  local function neteval(idx, node)
    -- hand this node's "output" names to each child at the right input slot
    local function propagate(node, x)
      for _, child in ipairs(node.children) do
        child.data.input = child.data.input or {}
        local mapindex = child.data.mapindex[node.data]
        assert(not child.data.input[mapindex], "each input should have one source")
        child.data.input[mapindex] = x
      end
    end
    if node.data.selectindex then
      assert(not node.data.module, "the selectindex-handling nodes should have no module")
      local input = node.data.input
      input = input[1][node.data.selectindex]
      propagate(node, input)
    else
      local inputs = node.data.input
      -- a parameter node is captured
      if inputs == nil and node.data.module ~= nil then
        inputs = {}
      end
      if #inputs == 1 then
        inputs = inputs[1]
      end

      -- forward through this node
      -- If no module is present, the node behaves like nn.Identity.
      local outputs
      if not node.data.module then
        outputs = inputs
      else
        local object = node.data.module
        local tname = convertor.mtype(object)
        if type(object) == 'userdata' or type(object) == 'table' then
          local convert_func = convertor.isSupported(tname)
          if convert_func then
            if type(inputs) ~= 'table' then
              inputs = { inputs }
            end
            -- NOTE(review): `nonbatch_mode` here resolves as a global, not the
            -- gModule argument of the same name — verify intent upstream.
            local subgraph = convert_func(object, #inputs, nonbatch_mode)
            outputs = subgraph._outputs
            graph:merge(subgraph, idx)
            for i = 1, #inputs do
              graph:substitute_param(subgraph._inputs[i], inputs[i])
            end
          else
            error('module `'..tname..'` not supported')
          end
        end
      end
      if #outputs == 1 then
        outputs = outputs[1]
      end
      propagate(node, outputs)
    end
  end

  local innode = obj.innode -- kept from original (unused)

  -- first clear the input states left over from real forward passes
  for _, node in ipairs(obj.forwardnodes) do
    local input = node.data.input
    while input and #input > 0 do
      table.remove(input)
    end
  end
  -- Set the starting input.
  -- We do copy instead of modifying the passed input.
  obj.innode.data.input = obj.innode.data.input or {}
  for i, item in ipairs(inputs) do
    obj.innode.data.input[i] = item
  end

  -- then run forward over the topologically ordered nodes
  for i, node in ipairs(obj.forwardnodes) do
    neteval(i, node)
  end

  -- bind the graph outputs to the names produced by the output node
  for i, p in ipairs(outputs) do
    graph:substitute_param(obj.outnode.data.input[i], p)
  end

  return graph
end

return onnx_nn
--------------------------------------------------------------------------------
-- onnx/checker.lua
-- Shape/type constraint propagation. Unknown dimensions are encoded as
-- unique negative integers and unified via changeUnk.
local Checker = torch.class('onnx.checker')

function Checker:__init()
  self._params = {}      -- param name -> dims table
  self._types = {}       -- param name -> element type ("FLOAT"/"INT32")
  self._change = false   -- did the last constraint pass refine anything?
  self._unkDimIdx = -1   -- next fresh "unknown dimension" id
end

function Checker:setChange(v)
  self._change = v
end

function Checker:hasChange()
  return self._change
end

-- allocate and return a fresh unknown-dimension id (negative, unique)
function Checker:getUnkDimIdx()
  self._unkDimIdx = self._unkDimIdx - 1
  return self._unkDimIdx + 1
end

function Checker:fail()
  error(self._err)
end

function Checker:params()
  return self._params
end

function Checker:setType(p, t)
  self._types[p] = t
end

function Checker:getType(p)
  return self._types[p] or "FLOAT"
end

-- get or create a param, we don't know dimension
function Checker:getParam(param)
  if self._params[param] == nil then
    self._change = true
    self._params[param] = {}
  end
  return self._params[param]
end

-- replace unknown-dim id v1 with v2 everywhere
function Checker:changeUnk(v1, v2)
  self._change = true
  for _, dims in pairs(self._params) do
    for i, d in ipairs(dims) do
      if d == v1 then
        dims[i] = v2
      end
    end
  end
end

function Checker:dimCheck(p1, i1,
p2, i2)
  -- Unify dimension i1 of p1 with dimension i2 of p2.
  -- Returns true on success (possibly resolving an unknown), false with
  -- self._err set on a hard mismatch.
  if p1[i1] == p2[i2] then
    return true
  end
  if p1[i1] < 0 then
    self:changeUnk(p1[i1], p2[i2])
    return true
  elseif p2[i2] < 0 then
    self:changeUnk(p2[i2], p1[i1])
    return true
  else
    self._err = '`'..p1[i1]..'` (dim '..i1..') different from `'..p2[i2]..'` (dim '..i2..')'
    return false
  end
end

-- Unify a list of shape tables in place; empty shapes are filled from the
-- first non-empty one. Returns false (with self._err) on rank mismatch.
function Checker:sameShape(t)
  local idx_nz = 1
  while idx_nz <= #t and #t[idx_nz] == 0 do
    idx_nz = idx_nz + 1
  end
  if idx_nz > #t then
    -- cannot find a non null member
    return true
  end
  for j = 1, #t do
    if #t[j] ~= 0 and #t[j] ~= #t[idx_nz] then
      self._err = 'different shapes: '..#t[idx_nz]..'/'..#t[j]
      return false
    end
    if #t[j] == 0 then
      for _, d in ipairs(t[idx_nz]) do
        table.insert(t[j], d)
      end
      self._change = true
    else
      for h = 1, #t[idx_nz] do
        self:dimCheck(t[idx_nz], h, t[j], h)
      end
    end
  end
  return true
end

-- Force concrete dimensions on a param; errors on hard conflict.
function Checker:setDims(p1, dims)
  local p1dim = self._params[p1]
  if p1dim == nil or #p1dim == 0 then
    self._params[p1] = dims
    self._change = true
    return true
  end
  for i, d in ipairs(p1dim) do
    if dims[i] ~= d then
      if d < 0 then
        self:changeUnk(d, dims[i])
      elseif dims[i] < 0 then
        self:changeUnk(dims[i], d)
      else
        error('incompatible dimension setting')
      end
    end
  end
  -- FIX: original fell off the end (implicit nil) on this path while the
  -- other path returned true; made the return value consistent.
  return true
end

-- Assert a param has rank n, creating fresh unknown dims if it is new.
function Checker:assertND(param, n)
  if self._params[param] == nil or #self._params[param] == 0 then
    self._change = true
    self._params[param] = { }
    for i = 1, n do
      table.insert(self._params[param], self:getUnkDimIdx())
    end
  else
    assert(#self._params[param] == n, "param `"..param.."` has inconsistent number of dimension "
      ..#self._params[param].."/"..n)
  end
  return self._params[param]
end

function Checker:assert2D(param)
  return self:assertND(param, 2)
end

function Checker:assert1D(param)
  return self:assertND(param, 1)
end

-- Like assertND but accepts rank 1 or 2; returns {} for an unseen param
-- (note: unlike assertND, it does not register the param).
function Checker:assert1or2D(param)
  if self._params[param] == nil or #self._params[param] == 0 then
    return {}
  else
    assert(#self._params[param] == 2 or #self._params[param] == 1,
      "param `"..param.."` has inconsistent number of dimension")
  end
  return self._params[param]
end
--------------------------------------------------------------------------------
-- convert.lua
-- CLI entry point: load a serialized torch model and emit .onnx files.
local onnx_pb = require 'onnx_pb'
require('onnx.init')
local convertor = require 'convertors.init'

local path = require('pl.path')

local cmd = torch.CmdLine.new()

-- split `str` on pattern `sep`; a trailing separator yields a trailing ''.
local function split(str, sep)
  local res = {}
  local index = 1

  while index <= str:len() do
    local sepStart, sepEnd = str:find(sep, index)

    local sub
    if not sepStart then
      sub = str:sub(index)
      table.insert(res, sub)
      index = str:len() + 1
    else
      sub = str:sub(index, sepStart - 1)
      table.insert(res, sub)
      index = sepEnd + 1
      if index > str:len() then
        table.insert(res, '')
      end
    end
  end

  return res
end

cmd:option('-t7', '', [[Path to the torch serialized file.]])
cmd:option('-require', 'nngraph', [[List of modules to import for loading the torch object (default nngraph).]])
cmd:option('-models', '', [[Field in the object where the models is/are.]])
cmd:option('-output_dir', '', [[Path to directory where onnx models will be serialized.]] ..
[[If not set, the extension is changed to _onnxdir.]])
cmd:option('-nonbatch_mode', true, [[Set if the models were using in non batch mode.]])
cmd:option('-force', false, [[Force output model creation even if the target file exists.]])

local opt = cmd:parse(arg)

-- Recursively walk `object`; convert every supported module into
-- `<output_dir>/<thepath>.onnx`, recursing into plain tables.
local function convert(output_dir, object, thepath)
  thepath = thepath or ''
  local prefpath = thepath
  if prefpath ~= '' then
    prefpath = prefpath .. '.'
  end

  local tname = convertor.mtype(object)
  if tname == 'table' then
    for k, v in pairs(object) do
      convert(output_dir, v, prefpath..k)
    end
  elseif type(object) == 'userdata' or type(object) == 'table' then
    local convert_func = convertor.isSupported(tname)
    if convert_func then
      print('convert '..thepath..'=`'..tname..'`')
      -- remember forward outputs so we can pin the output shapes below
      local save_outputs = {}
      if object.output then
        local outputs = object.output
        if type(outputs) ~= 'table' then
          outputs = { outputs }
        end
        for _, o in ipairs(outputs) do
          table.insert(save_outputs, o)
        end
      end
      -- FIX: was the bare global `nonbatch_mode` (always nil); the parsed
      -- command-line flag lives in `opt.nonbatch_mode`.
      local graph = convert_func(object, nil, opt.nonbatch_mode)
      for i, p in ipairs(graph._outputs) do
        if save_outputs[i] ~= nil then
          graph:set_dimension(p, save_outputs[i]:size():totable())
        end
      end
      local model = onnx_pb.ModelProto()
      model.ir_version = onnx_pb.VERSION_IR_VERSION_ENUM.number
      model.producer_name = 'lua-onnx-convert'
      model.producer_version = '0.0.1'
      local version = model.opset_import:add()
      version.version = 6
      model.graph.name = thepath
      graph:build(onnx_pb, model.graph)
      local output = assert(io.open(output_dir .. '/' .. thepath .. '.onnx', "wb"))
      model:SerializeToIOString(output)
      output:close()
    else
      -- single-module containers: try one level deeper before giving up
      if object.modules and #object.modules == 1 then
        convert(output_dir, object.modules, prefpath..'modules')
      end
      print('\tskipping '..thepath..' ('..tname..')')
    end
  end
end

local function main()
  assert(path.exists(opt.t7), 'file \'' .. opt.t7 .. '\' does not exist.')

  if opt.output_dir:len() == 0 then
    if opt.t7:sub(-3) == '.t7' then
      opt.output_dir = opt.t7:sub(1, -4) -- copy input model without '.t7' extension
    else
      opt.output_dir = opt.t7
    end
    opt.output_dir = opt.output_dir .. '.onnxdir'
  end

  if not opt.force then
    assert(not path.exists(opt.output_dir),
      'output dir already exists; use -force to overwrite.')
  end

  if path.exists(opt.output_dir) then
    assert(path.isdir(opt.output_dir),
      'output ('..opt.output_dir..') is not a directory')
    assert(opt.force,
      'output dir already exists; use -force to overwrite.')
  else
    path.mkdir(opt.output_dir)
  end

  if opt.require ~= '' then
    local requires = split(opt.require, ',')
    for _, r in ipairs(requires) do
      print('import module `'..r..'`')
      require(r)
    end
  end

  -- try loading cutorch modules - while issue warning if not installed
  local _, err = pcall(function()
    require('cutorch')
    require('cunn')
  end)

  if err then
    print('warning: Failed loading cutorch/cunn, GPU models cannot be read')
  end

  print('Loading model \'' .. opt.t7 .. '\'...')

  local obj
  _, err = pcall(function ()
    obj = torch.load(opt.t7)
  end)
  if err then
    error('unable to load the file (' .. err .. ').')
  end

  print('... done.')

  print('Converting model...')
  local models
  if opt.models ~= '' then
    models = obj[opt.models]
  else
    models = obj
  end
  convert(opt.output_dir, models, 'model')
  print('... done.')
end

main()
--------------------------------------------------------------------------------
-- onnx/graph.lua
-- In-memory ONNX graph under construction: nodes, initializers, boundary
-- params, plus a shape checker shared by all nodes.
local Graph = torch.class('onnx.graph')
local paths = require 'paths'

function Graph:__init(inputs, outputs)
  self._nodes = {}
  self._node_input_map = {}   -- param name -> list of consumer node pointers
  self._node_output_map = {}  -- param name -> producer node pointer
  self._node_map = {}
  self._initializer = {}      -- param name -> weight tensor
  self._inputs = inputs or {}
  self._outputs = outputs or {}
  self._checker = onnx.checker.new()
  self._tmpfile = paths.tmpname() -- scratch file for raw tensor serialization
end

function Graph:add_node(node)
  table.insert(self._nodes, node)
  self._node_map[torch.pointer(node)] = node
  for _, p in ipairs(node:inputs()) do
    if self._node_input_map[p] == nil then
      self._node_input_map[p] = {}
    end
    table.insert(self._node_input_map[p], torch.pointer(node))
  end
  for _, p in ipairs(node:outputs()) do
    self._node_output_map[p] = torch.pointer(node)
  end
end

-- Attach a weight tensor to param `p`; also pins its dimensions.
function Graph:add_initializer(p, obj)
  assert(self._node_input_map[p] ~= nil or self._node_output_map[p] ~= nil, "unknown param `"..p.."`")
  assert(self._initializer[p] == nil, "two initializers defined for param `"..p.."`")
  self._initializer[p] = obj
  self._checker:setDims(p, torch.totable(obj:size()))
end

function Graph:set_dimension(p, dims)
  if dims ~= nil then
    self._checker:setDims(p, dims)
  end
end

-- Rename param p1 to p2 everywhere (nodes and graph boundary), then unify
-- their shape constraints.
function Graph:substitute_param(p1, p2)
  for _, n in ipairs(self._nodes) do
    for i, v in ipairs(n._inputs) do
      if v == p1 then
        n._inputs[i] = p2
      end
    end
    for i, v in ipairs(n._outputs) do
      if v == p1 then
        n._outputs[i] = p2
      end
    end
  end
  for i, v in ipairs(self._inputs) do
    if v == p1 then
      self._inputs[i] = p2
    end
  end
  for i, v in ipairs(self._outputs) do
if v == p1 then
      self._outputs[i] = p2
    end
  end
  self._checker:sameShape({self._checker:getParam(p1), self._checker:getParam(p2)})
end

-- Splice `subgraph` into this graph, prefixing every param with "n<idx>.".
function Graph:merge(subgraph, idx)
  for i, v in ipairs(subgraph._inputs) do
    subgraph._inputs[i] = 'n'..idx..'.'..v
  end
  for i, v in ipairs(subgraph._outputs) do
    subgraph._outputs[i] = 'n'..idx..'.'..v
  end
  for _, n in ipairs(subgraph._nodes) do
    table.insert(self._nodes, n)
    self._node_map[torch.pointer(n)] = n
    for i, v in ipairs(n._inputs) do
      n._inputs[i] = 'n'..idx..'.'..v
    end
    for i, v in ipairs(n._outputs) do
      n._outputs[i] = 'n'..idx..'.'..v
    end
  end
  -- NOTE(review): these two maps insert the subgraph's whole value under the
  -- prefixed key, which differs from add_node's per-pointer layout — the maps
  -- are only nil-checked (add_initializer), so this is tolerated; confirm
  -- before relying on their contents.
  for param, pn in pairs(subgraph._node_input_map) do
    if self._node_input_map['n'..idx..'.'..param] == nil then
      self._node_input_map['n'..idx..'.'..param] = {}
    end
    table.insert(self._node_input_map['n'..idx..'.'..param], pn)
  end
  for param, pn in pairs(subgraph._node_output_map) do
    if self._node_output_map['n'..idx..'.'..param] == nil then
      self._node_output_map['n'..idx..'.'..param] = {}
    end
    table.insert(self._node_output_map['n'..idx..'.'..param], pn)
  end
  for param, obj in pairs(subgraph._initializer) do
    self:add_initializer('n'..idx..'.'..param, obj)
  end
  -- propagate parameter constrain from subgraph into graph
  for param, shape in pairs(subgraph._checker._params) do
    self._checker._params['n'..idx..'.'..param] = shape
  end
  -- propagate parameter type from subgraph into graph
  for param, t in pairs(subgraph._checker._types) do
    self._checker._types['n'..idx..'.'..param] = t
  end
end

-- Run shape inference to a fixpoint, then serialize into `onnx_graph`.
function Graph:build(onnx_pb, onnx_graph)

  -- shape inference
  local change = true
  while change do
    change = false
    for _, n in pairs(self._nodes) do
      if n:getShapeConstraint(self._checker) then
        change = true
      end
    end
  end

  -- build the graph - input params
  for _, p in ipairs(self._inputs) do
    local input = onnx_graph.input:add()
    input.name = p
    if self._checker:getType(p) == "INT32" then
      input.type.tensor_type.elem_type = onnx_pb.TensorProto.INT32
    else
      input.type.tensor_type.elem_type = onnx_pb.TensorProto.FLOAT
    end
    -- needed because of bug in protobuf library
    input.type:_Modified(true)
    for _, d in ipairs(self._checker:params()[p]) do
      input.type.tensor_type.shape.dim:add().dim_value = d
    end
  end

  -- add parameters for which we have initializer
  for p, _ in pairs(self._initializer) do
    local input = onnx_graph.input:add()
    input.name = p
    if self._checker:getType(p) == "INT32" then
      input.type.tensor_type.elem_type = onnx_pb.TensorProto.INT32
    else
      input.type.tensor_type.elem_type = onnx_pb.TensorProto.FLOAT
    end
    -- needed because of bug in protobuf library
    input.type:_Modified(true)
    for _, d in ipairs(self._checker:params()[p]) do
      input.type.tensor_type.shape.dim:add().dim_value = d
    end
  end

  -- build the graph - output params
  for _, p in ipairs(self._outputs) do
    local output = onnx_graph.output:add()
    output.name = p
    if self._checker:getType(p) == "INT32" then
      output.type.tensor_type.elem_type = onnx_pb.TensorProto.INT32
    else
      output.type.tensor_type.elem_type = onnx_pb.TensorProto.FLOAT
    end
    -- needed because of bug in protobuf library
    output.type:_Modified(true)
    for _, d in ipairs(self._checker:getParam(p)) do
      output.type.tensor_type.shape.dim:add().dim_value = d
    end
  end

  -- build the graph - the actual nodes
  for _, n in pairs(self._nodes) do
    local node = onnx_graph.node:add()
    n:build(onnx_pb, node)
  end

  -- dump initializer
  for p, w in pairs(self._initializer) do
    w = w:float()
    local initializer = onnx_graph.initializer:add()
    for _, d in ipairs(self._checker:params()[p]) do
      initializer.dims:append(d)
    end
    -- FIX: was `== "INT"`, a value no caller ever sets — every other type
    -- check in this file uses "INT32".
    if self._checker:getType(p) == "INT32" then
      initializer.data_type = onnx_pb.TensorProto.INT32
    else
      initializer.data_type = onnx_pb.TensorProto.FLOAT
    end
    initializer.name = p
    local file = torch.DiskFile(self._tmpfile, 'w'):binary()
    file:writeFloat(w:storage().new(w:storage(), w:storageOffset(), w:nElement()))
    file:close()
    local inp = assert(io.open(self._tmpfile, "rb"))
    initializer.raw_data = inp:read("*all")
    inp:close() -- FIX: the original leaked this handle on every initializer
  end

end
--------------------------------------------------------------------------------
-- convertors/onnx_nn.lua
-- Convertors for standard nn modules; each returns an onnx.graph with
-- symbolic boundary params ('x'/'y', 'a'/'b', ...).
local onnx_nn = require 'convertors.onnx_nngraph'
local convertor = require 'convertors.init'

function onnx_nn.Linear(obj, nInputs, nonbatch_mode)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Linear can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  if obj.bias == nil then
    -- no bias: express as x * W^T via Transpose + MatMul
    local perms = { 0, 2, 1 }
    if nonbatch_mode or (obj and #obj.output:size()==1) then
      perms = { 1, 0 }
    end
    graph:add_node(onnx.node.Transpose.new({'b'}, {'bt'}, perms))
    graph:add_node(onnx.node.MatMul.new({'x', 'bt'}, {'y'},
      onnx.helper.convertPrecision(obj.weight)))
  else
    graph:add_node(onnx.node.Gemm.new({'x', 'b', 'c'}, {'y'},
      onnx.helper.convertPrecision(obj.weight),
      1.0, -- alpha
      1.0, -- beta
      1,   -- broadcast C
      0,   -- transpose A
      1))  -- transpose B
    graph:add_initializer('c', obj.bias)
  end
  graph:add_initializer('b', obj.weight)
  return graph
end

function onnx_nn.MM(obj, nInputs, nonbatch_mode)
  local inputs = { 'a', 'b' }
  local graph =
onnx.graph.new(inputs, {'y'})

  -- insert explicit Transpose nodes for each transposed operand
  if obj.transA then
    inputs[1] = 'at'
    local perms = { 0, 2, 1 }
    if nonbatch_mode or (obj and #obj.output:size()==2) then
      perms = { 1, 0 }
    end
    graph:add_node(onnx.node.Transpose.new({'a'}, {'at'}, perms))
  end
  if obj.transB then
    inputs[2] = 'bt'
    local perms = { 0, 2, 1 }
    if nonbatch_mode or (obj and #obj.output:size()==2) then
      perms = { 1, 0 }
    end
    graph:add_node(onnx.node.Transpose.new({'b'}, {'bt'}, perms))
  end
  graph:add_node(onnx.node.MatMul.new(inputs, {'y'}))
  return graph
end

function onnx_nn.Squeeze(obj, nInputs, nonbatch_mode)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Squeeze can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  -- shift axes by one when a leading batch dimension is present
  local batch_offset = 1
  if nonbatch_mode or obj.numInputDims == nil then
    batch_offset = 0
  end
  graph:add_node(onnx.node.Squeeze.new({'x'}, {'y'}, { obj.dim - 1 + batch_offset }))
  return graph
end

function onnx_nn.SoftMax(obj, nInputs)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.SoftMax can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.SoftMax.new({'x'}, {'y'}))
  return graph
end

function onnx_nn.Identity(obj, nInputs)
  nInputs = nInputs or 1
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.Identity.new({'x'}, {'y'}))
  return graph
end

function onnx_nn.Reshape(obj, nInputs, nonbatch_mode)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Reshape can not have multiple inputs")
  local batchMode = obj.batchMode ~= false and not nonbatch_mode
  local reshape = obj.size:totable()
  if batchMode then
    -- leading 0 means "keep the batch dimension as-is" in ONNX Reshape
    table.insert(reshape, 1, 0)
  end
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.Reshape.new({'x', 'ind'}, {'y'}, reshape))
  graph:add_initializer('ind', torch.Tensor(reshape))
  return graph
end

-- nn.Replicate maps to Unsqueeze (add the replicated axis) + Tile.
function onnx_nn.Replicate(obj, nInputs, nonbatch_mode)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Replicate can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  local batch_offset = 1
  if nonbatch_mode or obj.ndim == nil then
    batch_offset = 0
  end
  graph:add_node(onnx.node.Unsqueeze.new({'x'}, {'xu'}, {obj.dim-1+batch_offset}))
  local repeats = torch.Tensor(obj.output:nDimension()):fill(1)
  repeats[obj.dim+batch_offset] = obj.nfeatures
  graph:add_node(onnx.node.Tile.new({'xu', 'repeats'}, {'y'}, repeats:totable()))
  graph:add_initializer('repeats', repeats)
  return graph
end

function onnx_nn.Abs(obj, nInputs)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Abs can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.Abs.new({'x'}, {'y'}))
  return graph
end

-- nn.LookupTable maps to Gather over the embedding matrix.
function onnx_nn.LookupTable(obj, nInputs)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Lookup can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.Gather.new({'ind', 'x'}, {'y'},
    onnx.helper.convertPrecision(obj.weight),
    0)) -- axis
  graph:add_initializer('ind', obj.weight)
  graph._checker:assert1D('x')
  return graph
end

function onnx_nn.Tanh(obj, nInputs)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Tanh can not have multiple inputs")
  local graph = onnx.graph.new({'x'}, {'y'})
  graph:add_node(onnx.node.Tanh.new({'x'}, {'y'}))
  return graph
end

function onnx_nn.Sigmoid(obj, nInputs)
  nInputs = nInputs or 1
  assert(nInputs == 1, "nn.Sigmoid can not have multiple inputs")
  local graph =
onnx.graph.new({'x'}, {'y'}) 147 | graph:add_node(onnx.node.Sigmoid.new({'x'}, {'y'})) 148 | return graph 149 | end 150 | 151 | -- convert Dropout to identity 152 | function onnx_nn.Dropout(obj, nInputs) 153 | nInputs = nInputs or 1 154 | assert(nInputs == 1, "nn.Dropout can not have multiple inputs") 155 | local graph = onnx.graph.new({'x'}, {'y'}) 156 | graph:add_node(onnx.node.Identity.new({'x'}, {'y'})) 157 | return graph 158 | end 159 | 160 | function onnx_nn.MapTable(obj, nInputs, nonbatch_mode) 161 | if nInputs == nil then 162 | nInputs = #obj.output 163 | end 164 | local tname = convertor.mtype(obj.modules[1]) 165 | local convert_func = convertor.isSupported(tname) 166 | if convert_func == nil then 167 | error('module `'..tname..'` not supported') 168 | end 169 | local inputs = {} 170 | local outputs = {} 171 | for i = 1, nInputs do 172 | table.insert(inputs, "x"..i) 173 | table.insert(outputs, "y"..i) 174 | end 175 | local graph = onnx.graph.new(inputs, outputs) 176 | for i = 1, nInputs do 177 | local subgraph = convert_func(obj.modules[1], 1, nonbatch_mode) 178 | assert(#subgraph._outputs == 1) 179 | graph:merge(subgraph, i) 180 | graph:substitute_param(subgraph._inputs[1], inputs[i]) 181 | graph:substitute_param(subgraph._outputs[1], outputs[i]) 182 | end 183 | return graph 184 | end 185 | 186 | function onnx_nn.ConcatTable(obj, nInputs, nonbatch_mode) 187 | nInputs = nInputs or 1 188 | assert(nInputs == 1, "nn.ConcatTable can not have multiple inputs") 189 | local inputs = {'x'} 190 | for i = 1, #obj.output do 191 | table.insert(outputs, 'y'..i) 192 | end 193 | local graph = onnx.graph.new(inputs, outputs) 194 | for i, subobj in pair(obj.modules) do 195 | local tname = convertor.mtype(subobj) 196 | local convert_func = convertor.isSupported(tname) 197 | if convert_func == nil then 198 | error('module `'..tname..'` not supported') 199 | end 200 | local subgraph = convert_func(subobj, 1, nonbatch_mode) 201 | assert(#subgraph._outputs == 1) 202 | 
graph:merge(subgraph, i) 203 | graph:substitute_param(subgraph._inputs[1], inputs[1]) 204 | graph:substitute_param(subgraph._outputs[1], outputs[i]) 205 | end 206 | return graph 207 | end 208 | 209 | 210 | function onnx_nn.JoinTable(obj, nInputs, nonbatch_mode) 211 | assert(nInputs ~= nil, "JoinTable can only be converted part of a gModule") 212 | local inputs = {} 213 | for i = 1, nInputs do 214 | table.insert(inputs, 'x'..i) 215 | end 216 | local batch_offset = 1 217 | if nonbatch_mode or obj.numInputDims == nil then 218 | batch_offset = 0 219 | end 220 | local graph = onnx.graph.new(inputs, {'y'}) 221 | graph:add_node(onnx.node.Concat(inputs, {'y'}, obj.dimension-1+batch_offset)) 222 | return graph 223 | end 224 | 225 | function onnx_nn.SplitTable(obj, nInputs, nonbatch_mode) 226 | nInputs = nInputs or 1 227 | assert(nInputs == 1, "nn.SplitTable can not have multiple inputs") 228 | local soutputs = {} 229 | local outputs = {} 230 | assert(obj.output ~= nil, "can only convert model with outputs") 231 | for i = 1, #obj.output do 232 | table.insert(soutputs, 'sy'..i) 233 | table.insert(outputs, 'y'..i) 234 | end 235 | local batch_offset = 1 236 | if nonbatch_mode or obj.numInputDims == nil then 237 | batch_offset = 0 238 | end 239 | local graph = onnx.graph.new({'x'}, outputs) 240 | graph:add_node(onnx.node.Split({'x'}, soutputs, obj.dimension-1+batch_offset)) 241 | for i = 1, #obj.output do 242 | graph:add_node(onnx.node.Squeeze({'sy'..i}, {'y'..i}, {obj.dimension-1+batch_offset})) 243 | end 244 | return graph 245 | end 246 | 247 | function onnx_nn.CAddTable(obj, nInputs) 248 | nInputs = nInputs or 2 249 | if nInputs == 1 then 250 | return onnx_nn.Identity(obj, 1) 251 | end 252 | local inputs = {} 253 | for i = 1, nInputs do 254 | table.insert(inputs, 'x'..i) 255 | end 256 | local graph = onnx.graph.new(inputs, {'y'}) 257 | local intSum = 'x1' 258 | for i = 2, nInputs do 259 | local resSum = 'y' 260 | if i < nInputs then 261 | resSum = 'y' .. 
i 262 | end 263 | graph:add_node(onnx.node.Add.new({intSum, inputs[i]}, {resSum}, 264 | onnx.helper.convertPrecision(obj.weight))) 265 | intSum = resSum 266 | end 267 | return graph 268 | end 269 | 270 | function onnx_nn.CMulTable(obj, nInputs) 271 | nInputs = nInputs or 2 272 | if nInputs == 1 then 273 | return onnx_nn.Identity(obj, 1) 274 | end 275 | local inputs = {} 276 | for i = 1, nInputs do 277 | table.insert(inputs, 'x'..i) 278 | end 279 | local graph = onnx.graph.new(inputs, {'y'}) 280 | local intMul = 'x1' 281 | for i = 2, nInputs do 282 | local resMul = 'y' 283 | if i < nInputs then 284 | resMul = 'y' .. i 285 | end 286 | graph:add_node(onnx.node.Mul.new({intMul, inputs[i]}, {resMul}, 287 | onnx.helper.convertPrecision(obj.weight))) 288 | intMul = resMul 289 | end 290 | return graph 291 | end 292 | 293 | function onnx_nn.Sequential(obj, nInputs, nonbatch_mode) 294 | local subgraphs = {} 295 | for i = 1, #obj.modules do 296 | local obj = obj.modules[i] 297 | local tname = convertor.mtype(obj) 298 | if type(obj) == 'userdata' or type(obj) == 'table' then 299 | local convert_func = convertor.isSupported(tname) 300 | if convert_func then 301 | local subgraph = convert_func(obj, nInputs, nonbatch_mode) 302 | nInputs = #subgraph._outputs 303 | table.insert(subgraphs, subgraph) 304 | else 305 | error('module `'..tname..'` not supported') 306 | end 307 | else 308 | error("unsupported module in nn.Sequential: `"+tname+"`") 309 | end 310 | end 311 | local inputs = {} 312 | local outputs = {} 313 | for i = 1, #subgraphs[1]._inputs do 314 | table.insert(inputs, "x"..i) 315 | end 316 | for i = 1, #subgraphs[#subgraphs]._outputs do 317 | table.insert(outputs, "y"..i) 318 | end 319 | local graph = onnx.graph.new(inputs, outputs) 320 | for i, subgraph in ipairs(subgraphs) do 321 | graph:merge(subgraph, i) 322 | for i = 1, #inputs do 323 | graph:substitute_param(subgraph._inputs[i], inputs[i]) 324 | end 325 | inputs = subgraph._outputs 326 | end 327 | for i = 1, #outputs 
do 328 | graph:substitute_param(subgraphs[#subgraphs]._outputs[i], outputs[i]) 329 | end 330 | return graph 331 | end 332 | 333 | return onnx_nn -------------------------------------------------------------------------------- /onnx_pb.lua: -------------------------------------------------------------------------------- 1 | -- Generated by protobuf; do not edit 2 | local module = {} 3 | local protobuf = require 'protobuf' 4 | 5 | module.VERSION = protobuf.EnumDescriptor() 6 | module.VERSION__START_VERSION_ENUM = protobuf.EnumValueDescriptor() 7 | module.VERSION_IR_VERSION_2017_10_10_ENUM = protobuf.EnumValueDescriptor() 8 | module.VERSION_IR_VERSION_2017_10_30_ENUM = protobuf.EnumValueDescriptor() 9 | module.VERSION_IR_VERSION_ENUM = protobuf.EnumValueDescriptor() 10 | module.ATTRIBUTEPROTO = protobuf.Descriptor() 11 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE = protobuf.EnumDescriptor() 12 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_UNDEFINED_ENUM = protobuf.EnumValueDescriptor() 13 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOAT_ENUM = protobuf.EnumValueDescriptor() 14 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INT_ENUM = protobuf.EnumValueDescriptor() 15 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRING_ENUM = protobuf.EnumValueDescriptor() 16 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSOR_ENUM = protobuf.EnumValueDescriptor() 17 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPH_ENUM = protobuf.EnumValueDescriptor() 18 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOATS_ENUM = protobuf.EnumValueDescriptor() 19 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INTS_ENUM = protobuf.EnumValueDescriptor() 20 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRINGS_ENUM = protobuf.EnumValueDescriptor() 21 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSORS_ENUM = protobuf.EnumValueDescriptor() 22 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPHS_ENUM = protobuf.EnumValueDescriptor() 23 | module.ATTRIBUTEPROTO_NAME_FIELD = protobuf.FieldDescriptor() 24 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD = protobuf.FieldDescriptor() 25 | 
module.ATTRIBUTEPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 26 | module.ATTRIBUTEPROTO_TYPE_FIELD = protobuf.FieldDescriptor() 27 | module.ATTRIBUTEPROTO_F_FIELD = protobuf.FieldDescriptor() 28 | module.ATTRIBUTEPROTO_I_FIELD = protobuf.FieldDescriptor() 29 | module.ATTRIBUTEPROTO_S_FIELD = protobuf.FieldDescriptor() 30 | module.ATTRIBUTEPROTO_T_FIELD = protobuf.FieldDescriptor() 31 | module.ATTRIBUTEPROTO_G_FIELD = protobuf.FieldDescriptor() 32 | module.ATTRIBUTEPROTO_FLOATS_FIELD = protobuf.FieldDescriptor() 33 | module.ATTRIBUTEPROTO_INTS_FIELD = protobuf.FieldDescriptor() 34 | module.ATTRIBUTEPROTO_STRINGS_FIELD = protobuf.FieldDescriptor() 35 | module.ATTRIBUTEPROTO_TENSORS_FIELD = protobuf.FieldDescriptor() 36 | module.ATTRIBUTEPROTO_GRAPHS_FIELD = protobuf.FieldDescriptor() 37 | module.VALUEINFOPROTO = protobuf.Descriptor() 38 | module.VALUEINFOPROTO_NAME_FIELD = protobuf.FieldDescriptor() 39 | module.VALUEINFOPROTO_TYPE_FIELD = protobuf.FieldDescriptor() 40 | module.VALUEINFOPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 41 | module.NODEPROTO = protobuf.Descriptor() 42 | module.NODEPROTO_INPUT_FIELD = protobuf.FieldDescriptor() 43 | module.NODEPROTO_OUTPUT_FIELD = protobuf.FieldDescriptor() 44 | module.NODEPROTO_NAME_FIELD = protobuf.FieldDescriptor() 45 | module.NODEPROTO_OP_TYPE_FIELD = protobuf.FieldDescriptor() 46 | module.NODEPROTO_DOMAIN_FIELD = protobuf.FieldDescriptor() 47 | module.NODEPROTO_ATTRIBUTE_FIELD = protobuf.FieldDescriptor() 48 | module.NODEPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 49 | module.MODELPROTO = protobuf.Descriptor() 50 | module.MODELPROTO_IR_VERSION_FIELD = protobuf.FieldDescriptor() 51 | module.MODELPROTO_OPSET_IMPORT_FIELD = protobuf.FieldDescriptor() 52 | module.MODELPROTO_PRODUCER_NAME_FIELD = protobuf.FieldDescriptor() 53 | module.MODELPROTO_PRODUCER_VERSION_FIELD = protobuf.FieldDescriptor() 54 | module.MODELPROTO_DOMAIN_FIELD = protobuf.FieldDescriptor() 55 | module.MODELPROTO_MODEL_VERSION_FIELD 
= protobuf.FieldDescriptor() 56 | module.MODELPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 57 | module.MODELPROTO_GRAPH_FIELD = protobuf.FieldDescriptor() 58 | module.MODELPROTO_METADATA_PROPS_FIELD = protobuf.FieldDescriptor() 59 | module.STRINGSTRINGENTRYPROTO = protobuf.Descriptor() 60 | module.STRINGSTRINGENTRYPROTO_KEY_FIELD = protobuf.FieldDescriptor() 61 | module.STRINGSTRINGENTRYPROTO_VALUE_FIELD = protobuf.FieldDescriptor() 62 | module.GRAPHPROTO = protobuf.Descriptor() 63 | module.GRAPHPROTO_NODE_FIELD = protobuf.FieldDescriptor() 64 | module.GRAPHPROTO_NAME_FIELD = protobuf.FieldDescriptor() 65 | module.GRAPHPROTO_INITIALIZER_FIELD = protobuf.FieldDescriptor() 66 | module.GRAPHPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 67 | module.GRAPHPROTO_INPUT_FIELD = protobuf.FieldDescriptor() 68 | module.GRAPHPROTO_OUTPUT_FIELD = protobuf.FieldDescriptor() 69 | module.GRAPHPROTO_VALUE_INFO_FIELD = protobuf.FieldDescriptor() 70 | module.TENSORPROTO = protobuf.Descriptor() 71 | module.TENSORPROTO_SEGMENT = protobuf.Descriptor() 72 | module.TENSORPROTO_SEGMENT_BEGIN_FIELD = protobuf.FieldDescriptor() 73 | module.TENSORPROTO_SEGMENT_END_FIELD = protobuf.FieldDescriptor() 74 | module.TENSORPROTO_DATATYPE = protobuf.EnumDescriptor() 75 | module.TENSORPROTO_DATATYPE_UNDEFINED_ENUM = protobuf.EnumValueDescriptor() 76 | module.TENSORPROTO_DATATYPE_FLOAT_ENUM = protobuf.EnumValueDescriptor() 77 | module.TENSORPROTO_DATATYPE_UINT8_ENUM = protobuf.EnumValueDescriptor() 78 | module.TENSORPROTO_DATATYPE_INT8_ENUM = protobuf.EnumValueDescriptor() 79 | module.TENSORPROTO_DATATYPE_UINT16_ENUM = protobuf.EnumValueDescriptor() 80 | module.TENSORPROTO_DATATYPE_INT16_ENUM = protobuf.EnumValueDescriptor() 81 | module.TENSORPROTO_DATATYPE_INT32_ENUM = protobuf.EnumValueDescriptor() 82 | module.TENSORPROTO_DATATYPE_INT64_ENUM = protobuf.EnumValueDescriptor() 83 | module.TENSORPROTO_DATATYPE_STRING_ENUM = protobuf.EnumValueDescriptor() 84 | 
module.TENSORPROTO_DATATYPE_BOOL_ENUM = protobuf.EnumValueDescriptor() 85 | module.TENSORPROTO_DATATYPE_FLOAT16_ENUM = protobuf.EnumValueDescriptor() 86 | module.TENSORPROTO_DATATYPE_DOUBLE_ENUM = protobuf.EnumValueDescriptor() 87 | module.TENSORPROTO_DATATYPE_UINT32_ENUM = protobuf.EnumValueDescriptor() 88 | module.TENSORPROTO_DATATYPE_UINT64_ENUM = protobuf.EnumValueDescriptor() 89 | module.TENSORPROTO_DATATYPE_COMPLEX64_ENUM = protobuf.EnumValueDescriptor() 90 | module.TENSORPROTO_DATATYPE_COMPLEX128_ENUM = protobuf.EnumValueDescriptor() 91 | module.TENSORPROTO_DIMS_FIELD = protobuf.FieldDescriptor() 92 | module.TENSORPROTO_DATA_TYPE_FIELD = protobuf.FieldDescriptor() 93 | module.TENSORPROTO_SEGMENT_FIELD = protobuf.FieldDescriptor() 94 | module.TENSORPROTO_FLOAT_DATA_FIELD = protobuf.FieldDescriptor() 95 | module.TENSORPROTO_INT32_DATA_FIELD = protobuf.FieldDescriptor() 96 | module.TENSORPROTO_STRING_DATA_FIELD = protobuf.FieldDescriptor() 97 | module.TENSORPROTO_INT64_DATA_FIELD = protobuf.FieldDescriptor() 98 | module.TENSORPROTO_NAME_FIELD = protobuf.FieldDescriptor() 99 | module.TENSORPROTO_DOC_STRING_FIELD = protobuf.FieldDescriptor() 100 | module.TENSORPROTO_RAW_DATA_FIELD = protobuf.FieldDescriptor() 101 | module.TENSORPROTO_DOUBLE_DATA_FIELD = protobuf.FieldDescriptor() 102 | module.TENSORPROTO_UINT64_DATA_FIELD = protobuf.FieldDescriptor() 103 | module.TENSORSHAPEPROTO = protobuf.Descriptor() 104 | module.TENSORSHAPEPROTO_DIMENSION = protobuf.Descriptor() 105 | module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD = protobuf.FieldDescriptor() 106 | module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD = protobuf.FieldDescriptor() 107 | module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD = protobuf.FieldDescriptor() 108 | module.TENSORSHAPEPROTO_DIM_FIELD = protobuf.FieldDescriptor() 109 | module.DENOTATIONCONSTPROTO = protobuf.Descriptor() 110 | module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD = protobuf.FieldDescriptor() 111 | 
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD = protobuf.FieldDescriptor() 112 | module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD = protobuf.FieldDescriptor() 113 | module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD = protobuf.FieldDescriptor() 114 | module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD = protobuf.FieldDescriptor() 115 | module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD = protobuf.FieldDescriptor() 116 | module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD = protobuf.FieldDescriptor() 117 | module.TYPEPROTO = protobuf.Descriptor() 118 | module.TYPEPROTO_TENSOR = protobuf.Descriptor() 119 | module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD = protobuf.FieldDescriptor() 120 | module.TYPEPROTO_TENSOR_SHAPE_FIELD = protobuf.FieldDescriptor() 121 | module.TYPEPROTO_TENSOR_TYPE_FIELD = protobuf.FieldDescriptor() 122 | module.OPERATORSETIDPROTO = protobuf.Descriptor() 123 | module.OPERATORSETIDPROTO_DOMAIN_FIELD = protobuf.FieldDescriptor() 124 | module.OPERATORSETIDPROTO_VERSION_FIELD = protobuf.FieldDescriptor() 125 | 126 | module.VERSION__START_VERSION_ENUM.name = '_START_VERSION' 127 | module.VERSION__START_VERSION_ENUM.index = 0 128 | module.VERSION__START_VERSION_ENUM.number = 0 129 | module.VERSION_IR_VERSION_2017_10_10_ENUM.name = 'IR_VERSION_2017_10_10' 130 | module.VERSION_IR_VERSION_2017_10_10_ENUM.index = 1 131 | module.VERSION_IR_VERSION_2017_10_10_ENUM.number = 1 132 | module.VERSION_IR_VERSION_2017_10_30_ENUM.name = 'IR_VERSION_2017_10_30' 133 | module.VERSION_IR_VERSION_2017_10_30_ENUM.index = 2 134 | module.VERSION_IR_VERSION_2017_10_30_ENUM.number = 2 135 | module.VERSION_IR_VERSION_ENUM.name = 'IR_VERSION' 136 | module.VERSION_IR_VERSION_ENUM.index = 3 137 | module.VERSION_IR_VERSION_ENUM.number = 3 138 | module.VERSION.name = 'Version' 139 | module.VERSION.full_name = '.onnx.Version' 140 | module.VERSION.values = 
{module.VERSION__START_VERSION_ENUM,module.VERSION_IR_VERSION_2017_10_10_ENUM,module.VERSION_IR_VERSION_2017_10_30_ENUM,module.VERSION_IR_VERSION_ENUM} 141 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_UNDEFINED_ENUM.name = 'UNDEFINED' 142 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_UNDEFINED_ENUM.index = 0 143 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_UNDEFINED_ENUM.number = 0 144 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOAT_ENUM.name = 'FLOAT' 145 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOAT_ENUM.index = 1 146 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOAT_ENUM.number = 1 147 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INT_ENUM.name = 'INT' 148 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INT_ENUM.index = 2 149 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INT_ENUM.number = 2 150 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRING_ENUM.name = 'STRING' 151 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRING_ENUM.index = 3 152 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRING_ENUM.number = 3 153 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSOR_ENUM.name = 'TENSOR' 154 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSOR_ENUM.index = 4 155 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSOR_ENUM.number = 4 156 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPH_ENUM.name = 'GRAPH' 157 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPH_ENUM.index = 5 158 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPH_ENUM.number = 5 159 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOATS_ENUM.name = 'FLOATS' 160 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOATS_ENUM.index = 6 161 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOATS_ENUM.number = 6 162 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INTS_ENUM.name = 'INTS' 163 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INTS_ENUM.index = 7 164 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INTS_ENUM.number = 7 165 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRINGS_ENUM.name = 'STRINGS' 166 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRINGS_ENUM.index = 8 167 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRINGS_ENUM.number = 8 168 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSORS_ENUM.name = 
'TENSORS' 169 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSORS_ENUM.index = 9 170 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSORS_ENUM.number = 9 171 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPHS_ENUM.name = 'GRAPHS' 172 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPHS_ENUM.index = 10 173 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPHS_ENUM.number = 10 174 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE.name = 'AttributeType' 175 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE.full_name = '.onnx.AttributeProto.AttributeType' 176 | module.ATTRIBUTEPROTO_ATTRIBUTETYPE.values = {module.ATTRIBUTEPROTO_ATTRIBUTETYPE_UNDEFINED_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOAT_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INT_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRING_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSOR_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPH_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_FLOATS_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_INTS_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_STRINGS_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_TENSORS_ENUM,module.ATTRIBUTEPROTO_ATTRIBUTETYPE_GRAPHS_ENUM} 177 | module.ATTRIBUTEPROTO_NAME_FIELD.name = 'name' 178 | module.ATTRIBUTEPROTO_NAME_FIELD.full_name = '.onnx.AttributeProto.name' 179 | module.ATTRIBUTEPROTO_NAME_FIELD.number = 1 180 | module.ATTRIBUTEPROTO_NAME_FIELD.index = 0 181 | module.ATTRIBUTEPROTO_NAME_FIELD.label = 1 182 | module.ATTRIBUTEPROTO_NAME_FIELD.has_default_value = false 183 | module.ATTRIBUTEPROTO_NAME_FIELD.default_value = '' 184 | module.ATTRIBUTEPROTO_NAME_FIELD.type = 9 185 | module.ATTRIBUTEPROTO_NAME_FIELD.cpp_type = 9 186 | 187 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.name = 'ref_attr_name' 188 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.full_name = '.onnx.AttributeProto.ref_attr_name' 189 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.number = 21 190 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.index = 1 191 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.label = 1 192 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.has_default_value = 
false 193 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.default_value = '' 194 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.type = 9 195 | module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD.cpp_type = 9 196 | 197 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.name = 'doc_string' 198 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.full_name = '.onnx.AttributeProto.doc_string' 199 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.number = 13 200 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.index = 2 201 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.label = 1 202 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.has_default_value = false 203 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.default_value = '' 204 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.type = 9 205 | module.ATTRIBUTEPROTO_DOC_STRING_FIELD.cpp_type = 9 206 | 207 | module.ATTRIBUTEPROTO_TYPE_FIELD.name = 'type' 208 | module.ATTRIBUTEPROTO_TYPE_FIELD.full_name = '.onnx.AttributeProto.type' 209 | module.ATTRIBUTEPROTO_TYPE_FIELD.number = 20 210 | module.ATTRIBUTEPROTO_TYPE_FIELD.index = 3 211 | module.ATTRIBUTEPROTO_TYPE_FIELD.label = 1 212 | module.ATTRIBUTEPROTO_TYPE_FIELD.has_default_value = false 213 | module.ATTRIBUTEPROTO_TYPE_FIELD.default_value = nil 214 | module.ATTRIBUTEPROTO_TYPE_FIELD.enum_type = module.ATTRIBUTEPROTO_ATTRIBUTETYPE 215 | module.ATTRIBUTEPROTO_TYPE_FIELD.type = 14 216 | module.ATTRIBUTEPROTO_TYPE_FIELD.cpp_type = 8 217 | 218 | module.ATTRIBUTEPROTO_F_FIELD.name = 'f' 219 | module.ATTRIBUTEPROTO_F_FIELD.full_name = '.onnx.AttributeProto.f' 220 | module.ATTRIBUTEPROTO_F_FIELD.number = 2 221 | module.ATTRIBUTEPROTO_F_FIELD.index = 4 222 | module.ATTRIBUTEPROTO_F_FIELD.label = 1 223 | module.ATTRIBUTEPROTO_F_FIELD.has_default_value = false 224 | module.ATTRIBUTEPROTO_F_FIELD.default_value = 0.0 225 | module.ATTRIBUTEPROTO_F_FIELD.type = 2 226 | module.ATTRIBUTEPROTO_F_FIELD.cpp_type = 6 227 | 228 | module.ATTRIBUTEPROTO_I_FIELD.name = 'i' 229 | module.ATTRIBUTEPROTO_I_FIELD.full_name = '.onnx.AttributeProto.i' 230 | 
module.ATTRIBUTEPROTO_I_FIELD.number = 3 231 | module.ATTRIBUTEPROTO_I_FIELD.index = 5 232 | module.ATTRIBUTEPROTO_I_FIELD.label = 1 233 | module.ATTRIBUTEPROTO_I_FIELD.has_default_value = false 234 | module.ATTRIBUTEPROTO_I_FIELD.default_value = 0 235 | module.ATTRIBUTEPROTO_I_FIELD.type = 3 236 | module.ATTRIBUTEPROTO_I_FIELD.cpp_type = 2 237 | 238 | module.ATTRIBUTEPROTO_S_FIELD.name = 's' 239 | module.ATTRIBUTEPROTO_S_FIELD.full_name = '.onnx.AttributeProto.s' 240 | module.ATTRIBUTEPROTO_S_FIELD.number = 4 241 | module.ATTRIBUTEPROTO_S_FIELD.index = 6 242 | module.ATTRIBUTEPROTO_S_FIELD.label = 1 243 | module.ATTRIBUTEPROTO_S_FIELD.has_default_value = false 244 | module.ATTRIBUTEPROTO_S_FIELD.default_value = '' 245 | module.ATTRIBUTEPROTO_S_FIELD.type = 12 246 | module.ATTRIBUTEPROTO_S_FIELD.cpp_type = 9 247 | 248 | module.ATTRIBUTEPROTO_T_FIELD.name = 't' 249 | module.ATTRIBUTEPROTO_T_FIELD.full_name = '.onnx.AttributeProto.t' 250 | module.ATTRIBUTEPROTO_T_FIELD.number = 5 251 | module.ATTRIBUTEPROTO_T_FIELD.index = 7 252 | module.ATTRIBUTEPROTO_T_FIELD.label = 1 253 | module.ATTRIBUTEPROTO_T_FIELD.has_default_value = false 254 | module.ATTRIBUTEPROTO_T_FIELD.default_value = nil 255 | module.ATTRIBUTEPROTO_T_FIELD.message_type = module.TENSORPROTO 256 | module.ATTRIBUTEPROTO_T_FIELD.type = 11 257 | module.ATTRIBUTEPROTO_T_FIELD.cpp_type = 10 258 | 259 | module.ATTRIBUTEPROTO_G_FIELD.name = 'g' 260 | module.ATTRIBUTEPROTO_G_FIELD.full_name = '.onnx.AttributeProto.g' 261 | module.ATTRIBUTEPROTO_G_FIELD.number = 6 262 | module.ATTRIBUTEPROTO_G_FIELD.index = 8 263 | module.ATTRIBUTEPROTO_G_FIELD.label = 1 264 | module.ATTRIBUTEPROTO_G_FIELD.has_default_value = false 265 | module.ATTRIBUTEPROTO_G_FIELD.default_value = nil 266 | module.ATTRIBUTEPROTO_G_FIELD.message_type = module.GRAPHPROTO 267 | module.ATTRIBUTEPROTO_G_FIELD.type = 11 268 | module.ATTRIBUTEPROTO_G_FIELD.cpp_type = 10 269 | 270 | module.ATTRIBUTEPROTO_FLOATS_FIELD.name = 'floats' 271 | 
module.ATTRIBUTEPROTO_FLOATS_FIELD.full_name = '.onnx.AttributeProto.floats' 272 | module.ATTRIBUTEPROTO_FLOATS_FIELD.number = 7 273 | module.ATTRIBUTEPROTO_FLOATS_FIELD.index = 9 274 | module.ATTRIBUTEPROTO_FLOATS_FIELD.label = 3 275 | module.ATTRIBUTEPROTO_FLOATS_FIELD.has_default_value = false 276 | module.ATTRIBUTEPROTO_FLOATS_FIELD.default_value = {} 277 | module.ATTRIBUTEPROTO_FLOATS_FIELD.type = 2 278 | module.ATTRIBUTEPROTO_FLOATS_FIELD.cpp_type = 6 279 | 280 | module.ATTRIBUTEPROTO_INTS_FIELD.name = 'ints' 281 | module.ATTRIBUTEPROTO_INTS_FIELD.full_name = '.onnx.AttributeProto.ints' 282 | module.ATTRIBUTEPROTO_INTS_FIELD.number = 8 283 | module.ATTRIBUTEPROTO_INTS_FIELD.index = 10 284 | module.ATTRIBUTEPROTO_INTS_FIELD.label = 3 285 | module.ATTRIBUTEPROTO_INTS_FIELD.has_default_value = false 286 | module.ATTRIBUTEPROTO_INTS_FIELD.default_value = {} 287 | module.ATTRIBUTEPROTO_INTS_FIELD.type = 3 288 | module.ATTRIBUTEPROTO_INTS_FIELD.cpp_type = 2 289 | 290 | module.ATTRIBUTEPROTO_STRINGS_FIELD.name = 'strings' 291 | module.ATTRIBUTEPROTO_STRINGS_FIELD.full_name = '.onnx.AttributeProto.strings' 292 | module.ATTRIBUTEPROTO_STRINGS_FIELD.number = 9 293 | module.ATTRIBUTEPROTO_STRINGS_FIELD.index = 11 294 | module.ATTRIBUTEPROTO_STRINGS_FIELD.label = 3 295 | module.ATTRIBUTEPROTO_STRINGS_FIELD.has_default_value = false 296 | module.ATTRIBUTEPROTO_STRINGS_FIELD.default_value = {} 297 | module.ATTRIBUTEPROTO_STRINGS_FIELD.type = 12 298 | module.ATTRIBUTEPROTO_STRINGS_FIELD.cpp_type = 9 299 | 300 | module.ATTRIBUTEPROTO_TENSORS_FIELD.name = 'tensors' 301 | module.ATTRIBUTEPROTO_TENSORS_FIELD.full_name = '.onnx.AttributeProto.tensors' 302 | module.ATTRIBUTEPROTO_TENSORS_FIELD.number = 10 303 | module.ATTRIBUTEPROTO_TENSORS_FIELD.index = 12 304 | module.ATTRIBUTEPROTO_TENSORS_FIELD.label = 3 305 | module.ATTRIBUTEPROTO_TENSORS_FIELD.has_default_value = false 306 | module.ATTRIBUTEPROTO_TENSORS_FIELD.default_value = {} 307 | 
module.ATTRIBUTEPROTO_TENSORS_FIELD.message_type = module.TENSORPROTO 308 | module.ATTRIBUTEPROTO_TENSORS_FIELD.type = 11 309 | module.ATTRIBUTEPROTO_TENSORS_FIELD.cpp_type = 10 310 | 311 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.name = 'graphs' 312 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.full_name = '.onnx.AttributeProto.graphs' 313 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.number = 11 314 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.index = 13 315 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.label = 3 316 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.has_default_value = false 317 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.default_value = {} 318 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.message_type = module.GRAPHPROTO 319 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.type = 11 320 | module.ATTRIBUTEPROTO_GRAPHS_FIELD.cpp_type = 10 321 | 322 | module.ATTRIBUTEPROTO.name = 'AttributeProto' 323 | module.ATTRIBUTEPROTO.full_name = '.onnx.AttributeProto' 324 | module.ATTRIBUTEPROTO.nested_types = {} 325 | module.ATTRIBUTEPROTO.enum_types = {module.ATTRIBUTEPROTO_ATTRIBUTETYPE} 326 | module.ATTRIBUTEPROTO.fields = {module.ATTRIBUTEPROTO_NAME_FIELD, module.ATTRIBUTEPROTO_REF_ATTR_NAME_FIELD, module.ATTRIBUTEPROTO_DOC_STRING_FIELD, module.ATTRIBUTEPROTO_TYPE_FIELD, module.ATTRIBUTEPROTO_F_FIELD, module.ATTRIBUTEPROTO_I_FIELD, module.ATTRIBUTEPROTO_S_FIELD, module.ATTRIBUTEPROTO_T_FIELD, module.ATTRIBUTEPROTO_G_FIELD, module.ATTRIBUTEPROTO_FLOATS_FIELD, module.ATTRIBUTEPROTO_INTS_FIELD, module.ATTRIBUTEPROTO_STRINGS_FIELD, module.ATTRIBUTEPROTO_TENSORS_FIELD, module.ATTRIBUTEPROTO_GRAPHS_FIELD} 327 | module.ATTRIBUTEPROTO.is_extendable = false 328 | module.ATTRIBUTEPROTO.extensions = {} 329 | module.VALUEINFOPROTO_NAME_FIELD.name = 'name' 330 | module.VALUEINFOPROTO_NAME_FIELD.full_name = '.onnx.ValueInfoProto.name' 331 | module.VALUEINFOPROTO_NAME_FIELD.number = 1 332 | module.VALUEINFOPROTO_NAME_FIELD.index = 0 333 | module.VALUEINFOPROTO_NAME_FIELD.label = 1 334 | module.VALUEINFOPROTO_NAME_FIELD.has_default_value = false 335 
| module.VALUEINFOPROTO_NAME_FIELD.default_value = '' 336 | module.VALUEINFOPROTO_NAME_FIELD.type = 9 337 | module.VALUEINFOPROTO_NAME_FIELD.cpp_type = 9 338 | 339 | module.VALUEINFOPROTO_TYPE_FIELD.name = 'type' 340 | module.VALUEINFOPROTO_TYPE_FIELD.full_name = '.onnx.ValueInfoProto.type' 341 | module.VALUEINFOPROTO_TYPE_FIELD.number = 2 342 | module.VALUEINFOPROTO_TYPE_FIELD.index = 1 343 | module.VALUEINFOPROTO_TYPE_FIELD.label = 1 344 | module.VALUEINFOPROTO_TYPE_FIELD.has_default_value = false 345 | module.VALUEINFOPROTO_TYPE_FIELD.default_value = nil 346 | module.VALUEINFOPROTO_TYPE_FIELD.message_type = module.TYPEPROTO 347 | module.VALUEINFOPROTO_TYPE_FIELD.type = 11 348 | module.VALUEINFOPROTO_TYPE_FIELD.cpp_type = 10 349 | 350 | module.VALUEINFOPROTO_DOC_STRING_FIELD.name = 'doc_string' 351 | module.VALUEINFOPROTO_DOC_STRING_FIELD.full_name = '.onnx.ValueInfoProto.doc_string' 352 | module.VALUEINFOPROTO_DOC_STRING_FIELD.number = 3 353 | module.VALUEINFOPROTO_DOC_STRING_FIELD.index = 2 354 | module.VALUEINFOPROTO_DOC_STRING_FIELD.label = 1 355 | module.VALUEINFOPROTO_DOC_STRING_FIELD.has_default_value = false 356 | module.VALUEINFOPROTO_DOC_STRING_FIELD.default_value = '' 357 | module.VALUEINFOPROTO_DOC_STRING_FIELD.type = 9 358 | module.VALUEINFOPROTO_DOC_STRING_FIELD.cpp_type = 9 359 | 360 | module.VALUEINFOPROTO.name = 'ValueInfoProto' 361 | module.VALUEINFOPROTO.full_name = '.onnx.ValueInfoProto' 362 | module.VALUEINFOPROTO.nested_types = {} 363 | module.VALUEINFOPROTO.enum_types = {} 364 | module.VALUEINFOPROTO.fields = {module.VALUEINFOPROTO_NAME_FIELD, module.VALUEINFOPROTO_TYPE_FIELD, module.VALUEINFOPROTO_DOC_STRING_FIELD} 365 | module.VALUEINFOPROTO.is_extendable = false 366 | module.VALUEINFOPROTO.extensions = {} 367 | module.NODEPROTO_INPUT_FIELD.name = 'input' 368 | module.NODEPROTO_INPUT_FIELD.full_name = '.onnx.NodeProto.input' 369 | module.NODEPROTO_INPUT_FIELD.number = 1 370 | module.NODEPROTO_INPUT_FIELD.index = 0 371 | 
module.NODEPROTO_INPUT_FIELD.label = 3 372 | module.NODEPROTO_INPUT_FIELD.has_default_value = false 373 | module.NODEPROTO_INPUT_FIELD.default_value = {} 374 | module.NODEPROTO_INPUT_FIELD.type = 9 375 | module.NODEPROTO_INPUT_FIELD.cpp_type = 9 376 | 377 | module.NODEPROTO_OUTPUT_FIELD.name = 'output' 378 | module.NODEPROTO_OUTPUT_FIELD.full_name = '.onnx.NodeProto.output' 379 | module.NODEPROTO_OUTPUT_FIELD.number = 2 380 | module.NODEPROTO_OUTPUT_FIELD.index = 1 381 | module.NODEPROTO_OUTPUT_FIELD.label = 3 382 | module.NODEPROTO_OUTPUT_FIELD.has_default_value = false 383 | module.NODEPROTO_OUTPUT_FIELD.default_value = {} 384 | module.NODEPROTO_OUTPUT_FIELD.type = 9 385 | module.NODEPROTO_OUTPUT_FIELD.cpp_type = 9 386 | 387 | module.NODEPROTO_NAME_FIELD.name = 'name' 388 | module.NODEPROTO_NAME_FIELD.full_name = '.onnx.NodeProto.name' 389 | module.NODEPROTO_NAME_FIELD.number = 3 390 | module.NODEPROTO_NAME_FIELD.index = 2 391 | module.NODEPROTO_NAME_FIELD.label = 1 392 | module.NODEPROTO_NAME_FIELD.has_default_value = false 393 | module.NODEPROTO_NAME_FIELD.default_value = '' 394 | module.NODEPROTO_NAME_FIELD.type = 9 395 | module.NODEPROTO_NAME_FIELD.cpp_type = 9 396 | 397 | module.NODEPROTO_OP_TYPE_FIELD.name = 'op_type' 398 | module.NODEPROTO_OP_TYPE_FIELD.full_name = '.onnx.NodeProto.op_type' 399 | module.NODEPROTO_OP_TYPE_FIELD.number = 4 400 | module.NODEPROTO_OP_TYPE_FIELD.index = 3 401 | module.NODEPROTO_OP_TYPE_FIELD.label = 1 402 | module.NODEPROTO_OP_TYPE_FIELD.has_default_value = false 403 | module.NODEPROTO_OP_TYPE_FIELD.default_value = '' 404 | module.NODEPROTO_OP_TYPE_FIELD.type = 9 405 | module.NODEPROTO_OP_TYPE_FIELD.cpp_type = 9 406 | 407 | module.NODEPROTO_DOMAIN_FIELD.name = 'domain' 408 | module.NODEPROTO_DOMAIN_FIELD.full_name = '.onnx.NodeProto.domain' 409 | module.NODEPROTO_DOMAIN_FIELD.number = 7 410 | module.NODEPROTO_DOMAIN_FIELD.index = 4 411 | module.NODEPROTO_DOMAIN_FIELD.label = 1 412 | 
module.NODEPROTO_DOMAIN_FIELD.has_default_value = false 413 | module.NODEPROTO_DOMAIN_FIELD.default_value = '' 414 | module.NODEPROTO_DOMAIN_FIELD.type = 9 415 | module.NODEPROTO_DOMAIN_FIELD.cpp_type = 9 416 | 417 | module.NODEPROTO_ATTRIBUTE_FIELD.name = 'attribute' 418 | module.NODEPROTO_ATTRIBUTE_FIELD.full_name = '.onnx.NodeProto.attribute' 419 | module.NODEPROTO_ATTRIBUTE_FIELD.number = 5 420 | module.NODEPROTO_ATTRIBUTE_FIELD.index = 5 421 | module.NODEPROTO_ATTRIBUTE_FIELD.label = 3 422 | module.NODEPROTO_ATTRIBUTE_FIELD.has_default_value = false 423 | module.NODEPROTO_ATTRIBUTE_FIELD.default_value = {} 424 | module.NODEPROTO_ATTRIBUTE_FIELD.message_type = module.ATTRIBUTEPROTO 425 | module.NODEPROTO_ATTRIBUTE_FIELD.type = 11 426 | module.NODEPROTO_ATTRIBUTE_FIELD.cpp_type = 10 427 | 428 | module.NODEPROTO_DOC_STRING_FIELD.name = 'doc_string' 429 | module.NODEPROTO_DOC_STRING_FIELD.full_name = '.onnx.NodeProto.doc_string' 430 | module.NODEPROTO_DOC_STRING_FIELD.number = 6 431 | module.NODEPROTO_DOC_STRING_FIELD.index = 6 432 | module.NODEPROTO_DOC_STRING_FIELD.label = 1 433 | module.NODEPROTO_DOC_STRING_FIELD.has_default_value = false 434 | module.NODEPROTO_DOC_STRING_FIELD.default_value = '' 435 | module.NODEPROTO_DOC_STRING_FIELD.type = 9 436 | module.NODEPROTO_DOC_STRING_FIELD.cpp_type = 9 437 | 438 | module.NODEPROTO.name = 'NodeProto' 439 | module.NODEPROTO.full_name = '.onnx.NodeProto' 440 | module.NODEPROTO.nested_types = {} 441 | module.NODEPROTO.enum_types = {} 442 | module.NODEPROTO.fields = {module.NODEPROTO_INPUT_FIELD, module.NODEPROTO_OUTPUT_FIELD, module.NODEPROTO_NAME_FIELD, module.NODEPROTO_OP_TYPE_FIELD, module.NODEPROTO_DOMAIN_FIELD, module.NODEPROTO_ATTRIBUTE_FIELD, module.NODEPROTO_DOC_STRING_FIELD} 443 | module.NODEPROTO.is_extendable = false 444 | module.NODEPROTO.extensions = {} 445 | module.MODELPROTO_IR_VERSION_FIELD.name = 'ir_version' 446 | module.MODELPROTO_IR_VERSION_FIELD.full_name = '.onnx.ModelProto.ir_version' 447 | 
-- Descriptor data for onnx.ModelProto fields (mirrors onnx.proto).
-- type 3 = int64, 9 = string, 11 = message; label 1 = optional, 3 = repeated
-- per google.protobuf descriptor conventions.
module.MODELPROTO_IR_VERSION_FIELD.number = 1
module.MODELPROTO_IR_VERSION_FIELD.index = 0
module.MODELPROTO_IR_VERSION_FIELD.label = 1
module.MODELPROTO_IR_VERSION_FIELD.has_default_value = false
module.MODELPROTO_IR_VERSION_FIELD.default_value = 0
module.MODELPROTO_IR_VERSION_FIELD.type = 3
module.MODELPROTO_IR_VERSION_FIELD.cpp_type = 2

-- ModelProto.opset_import: repeated OperatorSetIdProto message, field number 8.
module.MODELPROTO_OPSET_IMPORT_FIELD.name = 'opset_import'
module.MODELPROTO_OPSET_IMPORT_FIELD.full_name = '.onnx.ModelProto.opset_import'
module.MODELPROTO_OPSET_IMPORT_FIELD.number = 8
module.MODELPROTO_OPSET_IMPORT_FIELD.index = 1
module.MODELPROTO_OPSET_IMPORT_FIELD.label = 3
module.MODELPROTO_OPSET_IMPORT_FIELD.has_default_value = false
module.MODELPROTO_OPSET_IMPORT_FIELD.default_value = {}
module.MODELPROTO_OPSET_IMPORT_FIELD.message_type = module.OPERATORSETIDPROTO
module.MODELPROTO_OPSET_IMPORT_FIELD.type = 11
module.MODELPROTO_OPSET_IMPORT_FIELD.cpp_type = 10

-- ModelProto.producer_name: optional string, field number 2.
module.MODELPROTO_PRODUCER_NAME_FIELD.name = 'producer_name'
module.MODELPROTO_PRODUCER_NAME_FIELD.full_name = '.onnx.ModelProto.producer_name'
module.MODELPROTO_PRODUCER_NAME_FIELD.number = 2
module.MODELPROTO_PRODUCER_NAME_FIELD.index = 2
module.MODELPROTO_PRODUCER_NAME_FIELD.label = 1
module.MODELPROTO_PRODUCER_NAME_FIELD.has_default_value = false
module.MODELPROTO_PRODUCER_NAME_FIELD.default_value = ''
module.MODELPROTO_PRODUCER_NAME_FIELD.type = 9
module.MODELPROTO_PRODUCER_NAME_FIELD.cpp_type = 9

-- ModelProto.producer_version: optional string, field number 3.
module.MODELPROTO_PRODUCER_VERSION_FIELD.name = 'producer_version'
module.MODELPROTO_PRODUCER_VERSION_FIELD.full_name = '.onnx.ModelProto.producer_version'
module.MODELPROTO_PRODUCER_VERSION_FIELD.number = 3
module.MODELPROTO_PRODUCER_VERSION_FIELD.index = 3
module.MODELPROTO_PRODUCER_VERSION_FIELD.label = 1
module.MODELPROTO_PRODUCER_VERSION_FIELD.has_default_value = false
module.MODELPROTO_PRODUCER_VERSION_FIELD.default_value = ''
module.MODELPROTO_PRODUCER_VERSION_FIELD.type = 9
module.MODELPROTO_PRODUCER_VERSION_FIELD.cpp_type = 9

-- ModelProto.domain: optional string, field number 4.
module.MODELPROTO_DOMAIN_FIELD.name = 'domain'
module.MODELPROTO_DOMAIN_FIELD.full_name = '.onnx.ModelProto.domain'
module.MODELPROTO_DOMAIN_FIELD.number = 4
module.MODELPROTO_DOMAIN_FIELD.index = 4
module.MODELPROTO_DOMAIN_FIELD.label = 1
module.MODELPROTO_DOMAIN_FIELD.has_default_value = false
module.MODELPROTO_DOMAIN_FIELD.default_value = ''
module.MODELPROTO_DOMAIN_FIELD.type = 9
module.MODELPROTO_DOMAIN_FIELD.cpp_type = 9

-- ModelProto.model_version: optional int64, field number 5.
module.MODELPROTO_MODEL_VERSION_FIELD.name = 'model_version'
module.MODELPROTO_MODEL_VERSION_FIELD.full_name = '.onnx.ModelProto.model_version'
module.MODELPROTO_MODEL_VERSION_FIELD.number = 5
module.MODELPROTO_MODEL_VERSION_FIELD.index = 5
module.MODELPROTO_MODEL_VERSION_FIELD.label = 1
module.MODELPROTO_MODEL_VERSION_FIELD.has_default_value = false
module.MODELPROTO_MODEL_VERSION_FIELD.default_value = 0
module.MODELPROTO_MODEL_VERSION_FIELD.type = 3
module.MODELPROTO_MODEL_VERSION_FIELD.cpp_type = 2

-- ModelProto.doc_string: optional string, field number 6.
module.MODELPROTO_DOC_STRING_FIELD.name = 'doc_string'
module.MODELPROTO_DOC_STRING_FIELD.full_name = '.onnx.ModelProto.doc_string'
module.MODELPROTO_DOC_STRING_FIELD.number = 6
module.MODELPROTO_DOC_STRING_FIELD.index = 6
module.MODELPROTO_DOC_STRING_FIELD.label = 1
module.MODELPROTO_DOC_STRING_FIELD.has_default_value = false
module.MODELPROTO_DOC_STRING_FIELD.default_value = ''
module.MODELPROTO_DOC_STRING_FIELD.type = 9
module.MODELPROTO_DOC_STRING_FIELD.cpp_type = 9

-- ModelProto.graph: optional GraphProto message, field number 7 (continued below).
module.MODELPROTO_GRAPH_FIELD.name = 'graph'
module.MODELPROTO_GRAPH_FIELD.full_name = '.onnx.ModelProto.graph'
module.MODELPROTO_GRAPH_FIELD.number = 7
-- Remainder of the ModelProto.graph field descriptor (name/number set above),
-- then the ModelProto, StringStringEntryProto and first GraphProto descriptors.
module.MODELPROTO_GRAPH_FIELD.index = 7
module.MODELPROTO_GRAPH_FIELD.label = 1
module.MODELPROTO_GRAPH_FIELD.has_default_value = false
module.MODELPROTO_GRAPH_FIELD.default_value = nil
module.MODELPROTO_GRAPH_FIELD.message_type = module.GRAPHPROTO
module.MODELPROTO_GRAPH_FIELD.type = 11
module.MODELPROTO_GRAPH_FIELD.cpp_type = 10

-- ModelProto.metadata_props: repeated StringStringEntryProto, field number 14.
module.MODELPROTO_METADATA_PROPS_FIELD.name = 'metadata_props'
module.MODELPROTO_METADATA_PROPS_FIELD.full_name = '.onnx.ModelProto.metadata_props'
module.MODELPROTO_METADATA_PROPS_FIELD.number = 14
module.MODELPROTO_METADATA_PROPS_FIELD.index = 8
module.MODELPROTO_METADATA_PROPS_FIELD.label = 3
module.MODELPROTO_METADATA_PROPS_FIELD.has_default_value = false
module.MODELPROTO_METADATA_PROPS_FIELD.default_value = {}
module.MODELPROTO_METADATA_PROPS_FIELD.message_type = module.STRINGSTRINGENTRYPROTO
module.MODELPROTO_METADATA_PROPS_FIELD.type = 11
module.MODELPROTO_METADATA_PROPS_FIELD.cpp_type = 10

-- Assemble the ModelProto message descriptor from the field descriptors above.
module.MODELPROTO.name = 'ModelProto'
module.MODELPROTO.full_name = '.onnx.ModelProto'
module.MODELPROTO.nested_types = {}
module.MODELPROTO.enum_types = {}
module.MODELPROTO.fields = {module.MODELPROTO_IR_VERSION_FIELD, module.MODELPROTO_OPSET_IMPORT_FIELD, module.MODELPROTO_PRODUCER_NAME_FIELD, module.MODELPROTO_PRODUCER_VERSION_FIELD, module.MODELPROTO_DOMAIN_FIELD, module.MODELPROTO_MODEL_VERSION_FIELD, module.MODELPROTO_DOC_STRING_FIELD, module.MODELPROTO_GRAPH_FIELD, module.MODELPROTO_METADATA_PROPS_FIELD}
module.MODELPROTO.is_extendable = false
module.MODELPROTO.extensions = {}
-- StringStringEntryProto.key: optional string, field number 1.
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.name = 'key'
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.full_name = '.onnx.StringStringEntryProto.key'
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.number = 1
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.index = 0
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.label = 1
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.has_default_value = false
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.default_value = ''
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.type = 9
module.STRINGSTRINGENTRYPROTO_KEY_FIELD.cpp_type = 9

-- StringStringEntryProto.value: optional string, field number 2.
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.name = 'value'
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.full_name = '.onnx.StringStringEntryProto.value'
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.number = 2
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.index = 1
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.label = 1
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.has_default_value = false
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.default_value = ''
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.type = 9
module.STRINGSTRINGENTRYPROTO_VALUE_FIELD.cpp_type = 9

-- Assemble the StringStringEntryProto message descriptor.
module.STRINGSTRINGENTRYPROTO.name = 'StringStringEntryProto'
module.STRINGSTRINGENTRYPROTO.full_name = '.onnx.StringStringEntryProto'
module.STRINGSTRINGENTRYPROTO.nested_types = {}
module.STRINGSTRINGENTRYPROTO.enum_types = {}
module.STRINGSTRINGENTRYPROTO.fields = {module.STRINGSTRINGENTRYPROTO_KEY_FIELD, module.STRINGSTRINGENTRYPROTO_VALUE_FIELD}
module.STRINGSTRINGENTRYPROTO.is_extendable = false
module.STRINGSTRINGENTRYPROTO.extensions = {}
-- GraphProto.node: repeated NodeProto message, field number 1.
module.GRAPHPROTO_NODE_FIELD.name = 'node'
module.GRAPHPROTO_NODE_FIELD.full_name = '.onnx.GraphProto.node'
module.GRAPHPROTO_NODE_FIELD.number = 1
module.GRAPHPROTO_NODE_FIELD.index = 0
module.GRAPHPROTO_NODE_FIELD.label = 3
module.GRAPHPROTO_NODE_FIELD.has_default_value = false
module.GRAPHPROTO_NODE_FIELD.default_value = {}
module.GRAPHPROTO_NODE_FIELD.message_type = module.NODEPROTO
module.GRAPHPROTO_NODE_FIELD.type = 11
module.GRAPHPROTO_NODE_FIELD.cpp_type = 10

-- GraphProto.name: optional string, field number 2.
module.GRAPHPROTO_NAME_FIELD.name = 'name'
module.GRAPHPROTO_NAME_FIELD.full_name = '.onnx.GraphProto.name'
module.GRAPHPROTO_NAME_FIELD.number = 2
module.GRAPHPROTO_NAME_FIELD.index = 1
module.GRAPHPROTO_NAME_FIELD.label = 1
module.GRAPHPROTO_NAME_FIELD.has_default_value = false
module.GRAPHPROTO_NAME_FIELD.default_value = ''
module.GRAPHPROTO_NAME_FIELD.type = 9
module.GRAPHPROTO_NAME_FIELD.cpp_type = 9

-- GraphProto.initializer: repeated TensorProto message, field number 5.
module.GRAPHPROTO_INITIALIZER_FIELD.name = 'initializer'
module.GRAPHPROTO_INITIALIZER_FIELD.full_name = '.onnx.GraphProto.initializer'
module.GRAPHPROTO_INITIALIZER_FIELD.number = 5
module.GRAPHPROTO_INITIALIZER_FIELD.index = 2
module.GRAPHPROTO_INITIALIZER_FIELD.label = 3
module.GRAPHPROTO_INITIALIZER_FIELD.has_default_value = false
module.GRAPHPROTO_INITIALIZER_FIELD.default_value = {}
module.GRAPHPROTO_INITIALIZER_FIELD.message_type = module.TENSORPROTO
module.GRAPHPROTO_INITIALIZER_FIELD.type = 11
module.GRAPHPROTO_INITIALIZER_FIELD.cpp_type = 10

-- GraphProto.doc_string: optional string, field number 10.
module.GRAPHPROTO_DOC_STRING_FIELD.name = 'doc_string'
module.GRAPHPROTO_DOC_STRING_FIELD.full_name = '.onnx.GraphProto.doc_string'
module.GRAPHPROTO_DOC_STRING_FIELD.number = 10
module.GRAPHPROTO_DOC_STRING_FIELD.index = 3
module.GRAPHPROTO_DOC_STRING_FIELD.label = 1
module.GRAPHPROTO_DOC_STRING_FIELD.has_default_value = false
module.GRAPHPROTO_DOC_STRING_FIELD.default_value = ''
module.GRAPHPROTO_DOC_STRING_FIELD.type = 9
module.GRAPHPROTO_DOC_STRING_FIELD.cpp_type = 9

-- GraphProto.input: repeated ValueInfoProto message, field number 11 (continued below).
module.GRAPHPROTO_INPUT_FIELD.name = 'input'
module.GRAPHPROTO_INPUT_FIELD.full_name = '.onnx.GraphProto.input'
module.GRAPHPROTO_INPUT_FIELD.number = 11
module.GRAPHPROTO_INPUT_FIELD.index = 4
module.GRAPHPROTO_INPUT_FIELD.label = 3
module.GRAPHPROTO_INPUT_FIELD.has_default_value = false
module.GRAPHPROTO_INPUT_FIELD.default_value = {}
module.GRAPHPROTO_INPUT_FIELD.message_type = module.VALUEINFOPROTO
-- Remainder of GraphProto.input (set above), GraphProto output/value_info
-- fields, the GraphProto descriptor, TensorProto.Segment, and the
-- TensorProto.DataType enum values (which match onnx.proto's DataType codes).
module.GRAPHPROTO_INPUT_FIELD.type = 11
module.GRAPHPROTO_INPUT_FIELD.cpp_type = 10

-- GraphProto.output: repeated ValueInfoProto message, field number 12.
module.GRAPHPROTO_OUTPUT_FIELD.name = 'output'
module.GRAPHPROTO_OUTPUT_FIELD.full_name = '.onnx.GraphProto.output'
module.GRAPHPROTO_OUTPUT_FIELD.number = 12
module.GRAPHPROTO_OUTPUT_FIELD.index = 5
module.GRAPHPROTO_OUTPUT_FIELD.label = 3
module.GRAPHPROTO_OUTPUT_FIELD.has_default_value = false
module.GRAPHPROTO_OUTPUT_FIELD.default_value = {}
module.GRAPHPROTO_OUTPUT_FIELD.message_type = module.VALUEINFOPROTO
module.GRAPHPROTO_OUTPUT_FIELD.type = 11
module.GRAPHPROTO_OUTPUT_FIELD.cpp_type = 10

-- GraphProto.value_info: repeated ValueInfoProto message, field number 13.
module.GRAPHPROTO_VALUE_INFO_FIELD.name = 'value_info'
module.GRAPHPROTO_VALUE_INFO_FIELD.full_name = '.onnx.GraphProto.value_info'
module.GRAPHPROTO_VALUE_INFO_FIELD.number = 13
module.GRAPHPROTO_VALUE_INFO_FIELD.index = 6
module.GRAPHPROTO_VALUE_INFO_FIELD.label = 3
module.GRAPHPROTO_VALUE_INFO_FIELD.has_default_value = false
module.GRAPHPROTO_VALUE_INFO_FIELD.default_value = {}
module.GRAPHPROTO_VALUE_INFO_FIELD.message_type = module.VALUEINFOPROTO
module.GRAPHPROTO_VALUE_INFO_FIELD.type = 11
module.GRAPHPROTO_VALUE_INFO_FIELD.cpp_type = 10

-- Assemble the GraphProto message descriptor from the field descriptors above.
module.GRAPHPROTO.name = 'GraphProto'
module.GRAPHPROTO.full_name = '.onnx.GraphProto'
module.GRAPHPROTO.nested_types = {}
module.GRAPHPROTO.enum_types = {}
module.GRAPHPROTO.fields = {module.GRAPHPROTO_NODE_FIELD, module.GRAPHPROTO_NAME_FIELD, module.GRAPHPROTO_INITIALIZER_FIELD, module.GRAPHPROTO_DOC_STRING_FIELD, module.GRAPHPROTO_INPUT_FIELD, module.GRAPHPROTO_OUTPUT_FIELD, module.GRAPHPROTO_VALUE_INFO_FIELD}
module.GRAPHPROTO.is_extendable = false
module.GRAPHPROTO.extensions = {}
-- TensorProto.Segment.begin: optional int64, field number 1.
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.name = 'begin'
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.full_name = '.onnx.TensorProto.Segment.begin'
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.number = 1
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.index = 0
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.label = 1
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.has_default_value = false
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.default_value = 0
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.type = 3
module.TENSORPROTO_SEGMENT_BEGIN_FIELD.cpp_type = 2

-- TensorProto.Segment.end: optional int64, field number 2.
module.TENSORPROTO_SEGMENT_END_FIELD.name = 'end'
module.TENSORPROTO_SEGMENT_END_FIELD.full_name = '.onnx.TensorProto.Segment.end'
module.TENSORPROTO_SEGMENT_END_FIELD.number = 2
module.TENSORPROTO_SEGMENT_END_FIELD.index = 1
module.TENSORPROTO_SEGMENT_END_FIELD.label = 1
module.TENSORPROTO_SEGMENT_END_FIELD.has_default_value = false
module.TENSORPROTO_SEGMENT_END_FIELD.default_value = 0
module.TENSORPROTO_SEGMENT_END_FIELD.type = 3
module.TENSORPROTO_SEGMENT_END_FIELD.cpp_type = 2

-- Segment is a nested message of TensorProto (containing_type set below).
module.TENSORPROTO_SEGMENT.name = 'Segment'
module.TENSORPROTO_SEGMENT.full_name = '.onnx.TensorProto.Segment'
module.TENSORPROTO_SEGMENT.nested_types = {}
module.TENSORPROTO_SEGMENT.enum_types = {}
module.TENSORPROTO_SEGMENT.fields = {module.TENSORPROTO_SEGMENT_BEGIN_FIELD, module.TENSORPROTO_SEGMENT_END_FIELD}
module.TENSORPROTO_SEGMENT.is_extendable = false
module.TENSORPROTO_SEGMENT.extensions = {}
module.TENSORPROTO_SEGMENT.containing_type = module.TENSORPROTO
-- TensorProto.DataType enum values: name / declaration index / wire number.
module.TENSORPROTO_DATATYPE_UNDEFINED_ENUM.name = 'UNDEFINED'
module.TENSORPROTO_DATATYPE_UNDEFINED_ENUM.index = 0
module.TENSORPROTO_DATATYPE_UNDEFINED_ENUM.number = 0
module.TENSORPROTO_DATATYPE_FLOAT_ENUM.name = 'FLOAT'
module.TENSORPROTO_DATATYPE_FLOAT_ENUM.index = 1
module.TENSORPROTO_DATATYPE_FLOAT_ENUM.number = 1
module.TENSORPROTO_DATATYPE_UINT8_ENUM.name = 'UINT8'
module.TENSORPROTO_DATATYPE_UINT8_ENUM.index = 2
module.TENSORPROTO_DATATYPE_UINT8_ENUM.number = 2
module.TENSORPROTO_DATATYPE_INT8_ENUM.name = 'INT8'
module.TENSORPROTO_DATATYPE_INT8_ENUM.index = 3
module.TENSORPROTO_DATATYPE_INT8_ENUM.number = 3
module.TENSORPROTO_DATATYPE_UINT16_ENUM.name = 'UINT16'
module.TENSORPROTO_DATATYPE_UINT16_ENUM.index = 4
module.TENSORPROTO_DATATYPE_UINT16_ENUM.number = 4
module.TENSORPROTO_DATATYPE_INT16_ENUM.name = 'INT16'
module.TENSORPROTO_DATATYPE_INT16_ENUM.index = 5
module.TENSORPROTO_DATATYPE_INT16_ENUM.number = 5
module.TENSORPROTO_DATATYPE_INT32_ENUM.name = 'INT32'
module.TENSORPROTO_DATATYPE_INT32_ENUM.index = 6
module.TENSORPROTO_DATATYPE_INT32_ENUM.number = 6
module.TENSORPROTO_DATATYPE_INT64_ENUM.name = 'INT64'
module.TENSORPROTO_DATATYPE_INT64_ENUM.index = 7
module.TENSORPROTO_DATATYPE_INT64_ENUM.number = 7
module.TENSORPROTO_DATATYPE_STRING_ENUM.name = 'STRING'
module.TENSORPROTO_DATATYPE_STRING_ENUM.index = 8
module.TENSORPROTO_DATATYPE_STRING_ENUM.number = 8
module.TENSORPROTO_DATATYPE_BOOL_ENUM.name = 'BOOL'
module.TENSORPROTO_DATATYPE_BOOL_ENUM.index = 9
module.TENSORPROTO_DATATYPE_BOOL_ENUM.number = 9
module.TENSORPROTO_DATATYPE_FLOAT16_ENUM.name = 'FLOAT16'
module.TENSORPROTO_DATATYPE_FLOAT16_ENUM.index = 10
module.TENSORPROTO_DATATYPE_FLOAT16_ENUM.number = 10
module.TENSORPROTO_DATATYPE_DOUBLE_ENUM.name = 'DOUBLE'
module.TENSORPROTO_DATATYPE_DOUBLE_ENUM.index = 11
module.TENSORPROTO_DATATYPE_DOUBLE_ENUM.number = 11
module.TENSORPROTO_DATATYPE_UINT32_ENUM.name = 'UINT32'
module.TENSORPROTO_DATATYPE_UINT32_ENUM.index = 12
module.TENSORPROTO_DATATYPE_UINT32_ENUM.number = 12
module.TENSORPROTO_DATATYPE_UINT64_ENUM.name = 'UINT64'
module.TENSORPROTO_DATATYPE_UINT64_ENUM.index = 13
module.TENSORPROTO_DATATYPE_UINT64_ENUM.number = 13
-- Remaining TensorProto.DataType enum values, the DataType enum descriptor,
-- and the first TensorProto field descriptors.
module.TENSORPROTO_DATATYPE_COMPLEX64_ENUM.name = 'COMPLEX64'
module.TENSORPROTO_DATATYPE_COMPLEX64_ENUM.index = 14
module.TENSORPROTO_DATATYPE_COMPLEX64_ENUM.number = 14
module.TENSORPROTO_DATATYPE_COMPLEX128_ENUM.name = 'COMPLEX128'
module.TENSORPROTO_DATATYPE_COMPLEX128_ENUM.index = 15
module.TENSORPROTO_DATATYPE_COMPLEX128_ENUM.number = 15
-- Enum descriptor listing all 16 DataType values in declaration order.
module.TENSORPROTO_DATATYPE.name = 'DataType'
module.TENSORPROTO_DATATYPE.full_name = '.onnx.TensorProto.DataType'
module.TENSORPROTO_DATATYPE.values = {module.TENSORPROTO_DATATYPE_UNDEFINED_ENUM,module.TENSORPROTO_DATATYPE_FLOAT_ENUM,module.TENSORPROTO_DATATYPE_UINT8_ENUM,module.TENSORPROTO_DATATYPE_INT8_ENUM,module.TENSORPROTO_DATATYPE_UINT16_ENUM,module.TENSORPROTO_DATATYPE_INT16_ENUM,module.TENSORPROTO_DATATYPE_INT32_ENUM,module.TENSORPROTO_DATATYPE_INT64_ENUM,module.TENSORPROTO_DATATYPE_STRING_ENUM,module.TENSORPROTO_DATATYPE_BOOL_ENUM,module.TENSORPROTO_DATATYPE_FLOAT16_ENUM,module.TENSORPROTO_DATATYPE_DOUBLE_ENUM,module.TENSORPROTO_DATATYPE_UINT32_ENUM,module.TENSORPROTO_DATATYPE_UINT64_ENUM,module.TENSORPROTO_DATATYPE_COMPLEX64_ENUM,module.TENSORPROTO_DATATYPE_COMPLEX128_ENUM}
-- TensorProto.dims: repeated int64, field number 1.
module.TENSORPROTO_DIMS_FIELD.name = 'dims'
module.TENSORPROTO_DIMS_FIELD.full_name = '.onnx.TensorProto.dims'
module.TENSORPROTO_DIMS_FIELD.number = 1
module.TENSORPROTO_DIMS_FIELD.index = 0
module.TENSORPROTO_DIMS_FIELD.label = 3
module.TENSORPROTO_DIMS_FIELD.has_default_value = false
module.TENSORPROTO_DIMS_FIELD.default_value = {}
module.TENSORPROTO_DIMS_FIELD.type = 3
module.TENSORPROTO_DIMS_FIELD.cpp_type = 2

-- TensorProto.data_type: enum DataType (type 14 = TYPE_ENUM), field number 2.
module.TENSORPROTO_DATA_TYPE_FIELD.name = 'data_type'
module.TENSORPROTO_DATA_TYPE_FIELD.full_name = '.onnx.TensorProto.data_type'
module.TENSORPROTO_DATA_TYPE_FIELD.number = 2
module.TENSORPROTO_DATA_TYPE_FIELD.index = 1
module.TENSORPROTO_DATA_TYPE_FIELD.label = 1
module.TENSORPROTO_DATA_TYPE_FIELD.has_default_value = false
module.TENSORPROTO_DATA_TYPE_FIELD.default_value = nil
module.TENSORPROTO_DATA_TYPE_FIELD.enum_type = module.TENSORPROTO_DATATYPE
module.TENSORPROTO_DATA_TYPE_FIELD.type = 14
module.TENSORPROTO_DATA_TYPE_FIELD.cpp_type = 8

-- TensorProto.segment: optional nested Segment message, field number 3.
module.TENSORPROTO_SEGMENT_FIELD.name = 'segment'
module.TENSORPROTO_SEGMENT_FIELD.full_name = '.onnx.TensorProto.segment'
module.TENSORPROTO_SEGMENT_FIELD.number = 3
module.TENSORPROTO_SEGMENT_FIELD.index = 2
module.TENSORPROTO_SEGMENT_FIELD.label = 1
module.TENSORPROTO_SEGMENT_FIELD.has_default_value = false
module.TENSORPROTO_SEGMENT_FIELD.default_value = nil
module.TENSORPROTO_SEGMENT_FIELD.message_type = module.TENSORPROTO_SEGMENT
module.TENSORPROTO_SEGMENT_FIELD.type = 11
module.TENSORPROTO_SEGMENT_FIELD.cpp_type = 10

-- TensorProto.float_data: repeated float (type 2 = TYPE_FLOAT), field number 4.
module.TENSORPROTO_FLOAT_DATA_FIELD.name = 'float_data'
module.TENSORPROTO_FLOAT_DATA_FIELD.full_name = '.onnx.TensorProto.float_data'
module.TENSORPROTO_FLOAT_DATA_FIELD.number = 4
module.TENSORPROTO_FLOAT_DATA_FIELD.index = 3
module.TENSORPROTO_FLOAT_DATA_FIELD.label = 3
module.TENSORPROTO_FLOAT_DATA_FIELD.has_default_value = false
module.TENSORPROTO_FLOAT_DATA_FIELD.default_value = {}
module.TENSORPROTO_FLOAT_DATA_FIELD.type = 2
module.TENSORPROTO_FLOAT_DATA_FIELD.cpp_type = 6

-- TensorProto.int32_data: repeated int32 (type 5 = TYPE_INT32), field number 5.
module.TENSORPROTO_INT32_DATA_FIELD.name = 'int32_data'
module.TENSORPROTO_INT32_DATA_FIELD.full_name = '.onnx.TensorProto.int32_data'
module.TENSORPROTO_INT32_DATA_FIELD.number = 5
module.TENSORPROTO_INT32_DATA_FIELD.index = 4
module.TENSORPROTO_INT32_DATA_FIELD.label = 3
module.TENSORPROTO_INT32_DATA_FIELD.has_default_value = false
module.TENSORPROTO_INT32_DATA_FIELD.default_value = {}
module.TENSORPROTO_INT32_DATA_FIELD.type = 5
module.TENSORPROTO_INT32_DATA_FIELD.cpp_type = 1

-- Remaining TensorProto data-payload fields, the TensorProto descriptor, and
-- the first TensorShapeProto.Dimension field descriptors.
-- TensorProto.string_data: repeated bytes (type 12 = TYPE_BYTES), field number 6.
module.TENSORPROTO_STRING_DATA_FIELD.name = 'string_data'
module.TENSORPROTO_STRING_DATA_FIELD.full_name = '.onnx.TensorProto.string_data'
module.TENSORPROTO_STRING_DATA_FIELD.number = 6
module.TENSORPROTO_STRING_DATA_FIELD.index = 5
module.TENSORPROTO_STRING_DATA_FIELD.label = 3
module.TENSORPROTO_STRING_DATA_FIELD.has_default_value = false
module.TENSORPROTO_STRING_DATA_FIELD.default_value = {}
module.TENSORPROTO_STRING_DATA_FIELD.type = 12
module.TENSORPROTO_STRING_DATA_FIELD.cpp_type = 9

-- TensorProto.int64_data: repeated int64, field number 7.
module.TENSORPROTO_INT64_DATA_FIELD.name = 'int64_data'
module.TENSORPROTO_INT64_DATA_FIELD.full_name = '.onnx.TensorProto.int64_data'
module.TENSORPROTO_INT64_DATA_FIELD.number = 7
module.TENSORPROTO_INT64_DATA_FIELD.index = 6
module.TENSORPROTO_INT64_DATA_FIELD.label = 3
module.TENSORPROTO_INT64_DATA_FIELD.has_default_value = false
module.TENSORPROTO_INT64_DATA_FIELD.default_value = {}
module.TENSORPROTO_INT64_DATA_FIELD.type = 3
module.TENSORPROTO_INT64_DATA_FIELD.cpp_type = 2

-- TensorProto.name: optional string, field number 8.
module.TENSORPROTO_NAME_FIELD.name = 'name'
module.TENSORPROTO_NAME_FIELD.full_name = '.onnx.TensorProto.name'
module.TENSORPROTO_NAME_FIELD.number = 8
module.TENSORPROTO_NAME_FIELD.index = 7
module.TENSORPROTO_NAME_FIELD.label = 1
module.TENSORPROTO_NAME_FIELD.has_default_value = false
module.TENSORPROTO_NAME_FIELD.default_value = ''
module.TENSORPROTO_NAME_FIELD.type = 9
module.TENSORPROTO_NAME_FIELD.cpp_type = 9

-- TensorProto.doc_string: optional string, field number 12.
module.TENSORPROTO_DOC_STRING_FIELD.name = 'doc_string'
module.TENSORPROTO_DOC_STRING_FIELD.full_name = '.onnx.TensorProto.doc_string'
module.TENSORPROTO_DOC_STRING_FIELD.number = 12
module.TENSORPROTO_DOC_STRING_FIELD.index = 8
module.TENSORPROTO_DOC_STRING_FIELD.label = 1
module.TENSORPROTO_DOC_STRING_FIELD.has_default_value = false
module.TENSORPROTO_DOC_STRING_FIELD.default_value = ''
module.TENSORPROTO_DOC_STRING_FIELD.type = 9
module.TENSORPROTO_DOC_STRING_FIELD.cpp_type = 9

-- TensorProto.raw_data: optional bytes, field number 9.
module.TENSORPROTO_RAW_DATA_FIELD.name = 'raw_data'
module.TENSORPROTO_RAW_DATA_FIELD.full_name = '.onnx.TensorProto.raw_data'
module.TENSORPROTO_RAW_DATA_FIELD.number = 9
module.TENSORPROTO_RAW_DATA_FIELD.index = 9
module.TENSORPROTO_RAW_DATA_FIELD.label = 1
module.TENSORPROTO_RAW_DATA_FIELD.has_default_value = false
module.TENSORPROTO_RAW_DATA_FIELD.default_value = ''
module.TENSORPROTO_RAW_DATA_FIELD.type = 12
module.TENSORPROTO_RAW_DATA_FIELD.cpp_type = 9

-- TensorProto.double_data: repeated double (type 1 = TYPE_DOUBLE), field number 10.
module.TENSORPROTO_DOUBLE_DATA_FIELD.name = 'double_data'
module.TENSORPROTO_DOUBLE_DATA_FIELD.full_name = '.onnx.TensorProto.double_data'
module.TENSORPROTO_DOUBLE_DATA_FIELD.number = 10
module.TENSORPROTO_DOUBLE_DATA_FIELD.index = 10
module.TENSORPROTO_DOUBLE_DATA_FIELD.label = 3
module.TENSORPROTO_DOUBLE_DATA_FIELD.has_default_value = false
module.TENSORPROTO_DOUBLE_DATA_FIELD.default_value = {}
module.TENSORPROTO_DOUBLE_DATA_FIELD.type = 1
module.TENSORPROTO_DOUBLE_DATA_FIELD.cpp_type = 5

-- TensorProto.uint64_data: repeated uint64 (type 4 = TYPE_UINT64), field number 11.
module.TENSORPROTO_UINT64_DATA_FIELD.name = 'uint64_data'
module.TENSORPROTO_UINT64_DATA_FIELD.full_name = '.onnx.TensorProto.uint64_data'
module.TENSORPROTO_UINT64_DATA_FIELD.number = 11
module.TENSORPROTO_UINT64_DATA_FIELD.index = 11
module.TENSORPROTO_UINT64_DATA_FIELD.label = 3
module.TENSORPROTO_UINT64_DATA_FIELD.has_default_value = false
module.TENSORPROTO_UINT64_DATA_FIELD.default_value = {}
module.TENSORPROTO_UINT64_DATA_FIELD.type = 4
module.TENSORPROTO_UINT64_DATA_FIELD.cpp_type = 4

-- Assemble the TensorProto message descriptor (nests Segment and DataType).
module.TENSORPROTO.name = 'TensorProto'
module.TENSORPROTO.full_name = '.onnx.TensorProto'
module.TENSORPROTO.nested_types = {module.TENSORPROTO_SEGMENT}
module.TENSORPROTO.enum_types = {module.TENSORPROTO_DATATYPE}
module.TENSORPROTO.fields = {module.TENSORPROTO_DIMS_FIELD, module.TENSORPROTO_DATA_TYPE_FIELD, module.TENSORPROTO_SEGMENT_FIELD, module.TENSORPROTO_FLOAT_DATA_FIELD, module.TENSORPROTO_INT32_DATA_FIELD, module.TENSORPROTO_STRING_DATA_FIELD, module.TENSORPROTO_INT64_DATA_FIELD, module.TENSORPROTO_NAME_FIELD, module.TENSORPROTO_DOC_STRING_FIELD, module.TENSORPROTO_RAW_DATA_FIELD, module.TENSORPROTO_DOUBLE_DATA_FIELD, module.TENSORPROTO_UINT64_DATA_FIELD}
module.TENSORPROTO.is_extendable = false
module.TENSORPROTO.extensions = {}
-- TensorShapeProto.Dimension.dim_value: optional int64, field number 1.
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.name = 'dim_value'
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.full_name = '.onnx.TensorShapeProto.Dimension.dim_value'
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.number = 1
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.index = 0
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.label = 1
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.has_default_value = false
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.default_value = 0
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.type = 3
module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD.cpp_type = 2

-- TensorShapeProto.Dimension.dim_param: optional string, field number 2.
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.name = 'dim_param'
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.full_name = '.onnx.TensorShapeProto.Dimension.dim_param'
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.number = 2
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.index = 1
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.label = 1
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.has_default_value = false
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.default_value = ''
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.type = 9
module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD.cpp_type = 9

-- TensorShapeProto descriptors and the start of DenotationConstProto.
-- TensorShapeProto.Dimension.denotation: optional string, field number 3.
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.name = 'denotation'
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.full_name = '.onnx.TensorShapeProto.Dimension.denotation'
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.number = 3
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.index = 2
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.label = 1
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.has_default_value = false
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.default_value = ''
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.type = 9
module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD.cpp_type = 9

-- Dimension is a nested message of TensorShapeProto (containing_type below).
module.TENSORSHAPEPROTO_DIMENSION.name = 'Dimension'
module.TENSORSHAPEPROTO_DIMENSION.full_name = '.onnx.TensorShapeProto.Dimension'
module.TENSORSHAPEPROTO_DIMENSION.nested_types = {}
module.TENSORSHAPEPROTO_DIMENSION.enum_types = {}
module.TENSORSHAPEPROTO_DIMENSION.fields = {module.TENSORSHAPEPROTO_DIMENSION_DIM_VALUE_FIELD, module.TENSORSHAPEPROTO_DIMENSION_DIM_PARAM_FIELD, module.TENSORSHAPEPROTO_DIMENSION_DENOTATION_FIELD}
module.TENSORSHAPEPROTO_DIMENSION.is_extendable = false
module.TENSORSHAPEPROTO_DIMENSION.extensions = {}
module.TENSORSHAPEPROTO_DIMENSION.containing_type = module.TENSORSHAPEPROTO
-- TensorShapeProto.dim: repeated Dimension message, field number 1.
module.TENSORSHAPEPROTO_DIM_FIELD.name = 'dim'
module.TENSORSHAPEPROTO_DIM_FIELD.full_name = '.onnx.TensorShapeProto.dim'
module.TENSORSHAPEPROTO_DIM_FIELD.number = 1
module.TENSORSHAPEPROTO_DIM_FIELD.index = 0
module.TENSORSHAPEPROTO_DIM_FIELD.label = 3
module.TENSORSHAPEPROTO_DIM_FIELD.has_default_value = false
module.TENSORSHAPEPROTO_DIM_FIELD.default_value = {}
module.TENSORSHAPEPROTO_DIM_FIELD.message_type = module.TENSORSHAPEPROTO_DIMENSION
module.TENSORSHAPEPROTO_DIM_FIELD.type = 11
module.TENSORSHAPEPROTO_DIM_FIELD.cpp_type = 10

-- Assemble the TensorShapeProto message descriptor.
module.TENSORSHAPEPROTO.name = 'TensorShapeProto'
module.TENSORSHAPEPROTO.full_name = '.onnx.TensorShapeProto'
module.TENSORSHAPEPROTO.nested_types = {module.TENSORSHAPEPROTO_DIMENSION}
module.TENSORSHAPEPROTO.enum_types = {}
module.TENSORSHAPEPROTO.fields = {module.TENSORSHAPEPROTO_DIM_FIELD}
module.TENSORSHAPEPROTO.is_extendable = false
module.TENSORSHAPEPROTO.extensions = {}
-- DenotationConstProto fields: each has a string default equal to its own
-- name (has_default_value = true), acting as named denotation constants.
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.name = 'DATA_BATCH'
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.full_name = '.onnx.DenotationConstProto.DATA_BATCH'
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.number = 1
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.index = 0
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.label = 1
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.default_value = 'DATA_BATCH'
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.type = 9
module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD.cpp_type = 9

-- DenotationConstProto.DATA_CHANNEL: field number 2.
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.name = 'DATA_CHANNEL'
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.full_name = '.onnx.DenotationConstProto.DATA_CHANNEL'
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.number = 2
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.index = 1
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.label = 1
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.default_value = 'DATA_CHANNEL'
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.type = 9
module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD.cpp_type = 9

-- DenotationConstProto.DATA_TIME: field number 3 (continued below).
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.name = 'DATA_TIME'
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.full_name = '.onnx.DenotationConstProto.DATA_TIME'
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.number = 3
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.index = 2
-- Remaining DenotationConstProto field descriptors and the message
-- descriptor itself. Each field defaults to its own name (string constants).
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.label = 1
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.default_value = 'DATA_TIME'
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.type = 9
module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD.cpp_type = 9

-- DenotationConstProto.DATA_FEATURE: field number 4.
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.name = 'DATA_FEATURE'
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.full_name = '.onnx.DenotationConstProto.DATA_FEATURE'
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.number = 4
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.index = 3
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.label = 1
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.default_value = 'DATA_FEATURE'
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.type = 9
module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD.cpp_type = 9

-- DenotationConstProto.FILTER_IN_CHANNEL: field number 5.
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.name = 'FILTER_IN_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.full_name = '.onnx.DenotationConstProto.FILTER_IN_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.number = 5
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.index = 4
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.label = 1
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.default_value = 'FILTER_IN_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.type = 9
module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD.cpp_type = 9

-- DenotationConstProto.FILTER_OUT_CHANNEL: field number 6.
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.name = 'FILTER_OUT_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.full_name = '.onnx.DenotationConstProto.FILTER_OUT_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.number = 6
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.index = 5
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.label = 1
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.default_value = 'FILTER_OUT_CHANNEL'
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.type = 9
module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD.cpp_type = 9

-- DenotationConstProto.FILTER_SPATIAL: field number 7.
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.name = 'FILTER_SPATIAL'
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.full_name = '.onnx.DenotationConstProto.FILTER_SPATIAL'
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.number = 7
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.index = 6
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.label = 1
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.has_default_value = true
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.default_value = 'FILTER_SPATIAL'
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.type = 9
module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD.cpp_type = 9

-- Assemble the DenotationConstProto message descriptor.
module.DENOTATIONCONSTPROTO.name = 'DenotationConstProto'
module.DENOTATIONCONSTPROTO.full_name = '.onnx.DenotationConstProto'
module.DENOTATIONCONSTPROTO.nested_types = {}
module.DENOTATIONCONSTPROTO.enum_types = {}
module.DENOTATIONCONSTPROTO.fields = {module.DENOTATIONCONSTPROTO_DATA_BATCH_FIELD, module.DENOTATIONCONSTPROTO_DATA_CHANNEL_FIELD, module.DENOTATIONCONSTPROTO_DATA_TIME_FIELD, module.DENOTATIONCONSTPROTO_DATA_FEATURE_FIELD, module.DENOTATIONCONSTPROTO_FILTER_IN_CHANNEL_FIELD, module.DENOTATIONCONSTPROTO_FILTER_OUT_CHANNEL_FIELD, module.DENOTATIONCONSTPROTO_FILTER_SPATIAL_FIELD}
module.DENOTATIONCONSTPROTO.is_extendable = false
module.DENOTATIONCONSTPROTO.extensions = {}
-- Auto-generated descriptors for onnx.TypeProto (with nested Tensor) and the
-- field descriptors of onnx.OperatorSetIdProto.
-- Numeric codes follow the protobuf FieldDescriptorProto convention:
--   label 1 = LABEL_OPTIONAL; type 14 = TYPE_ENUM, 11 = TYPE_MESSAGE,
--   9 = TYPE_STRING, 3 = TYPE_INT64.
-- NOTE(review): generated code — do not hand-edit values; regenerate instead.

-- TypeProto.Tensor.elem_type: enum field backed by TensorProto.DataType.
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.name = 'elem_type'
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.full_name = '.onnx.TypeProto.Tensor.elem_type'
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.number = 1
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.index = 0
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.label = 1
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.has_default_value = false
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.default_value = nil
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.enum_type = module.TENSORPROTO_DATATYPE
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.type = 14
module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD.cpp_type = 8

-- TypeProto.Tensor.shape: message field backed by TensorShapeProto.
module.TYPEPROTO_TENSOR_SHAPE_FIELD.name = 'shape'
module.TYPEPROTO_TENSOR_SHAPE_FIELD.full_name = '.onnx.TypeProto.Tensor.shape'
module.TYPEPROTO_TENSOR_SHAPE_FIELD.number = 2
module.TYPEPROTO_TENSOR_SHAPE_FIELD.index = 1
module.TYPEPROTO_TENSOR_SHAPE_FIELD.label = 1
module.TYPEPROTO_TENSOR_SHAPE_FIELD.has_default_value = false
module.TYPEPROTO_TENSOR_SHAPE_FIELD.default_value = nil
module.TYPEPROTO_TENSOR_SHAPE_FIELD.message_type = module.TENSORSHAPEPROTO
module.TYPEPROTO_TENSOR_SHAPE_FIELD.type = 11
module.TYPEPROTO_TENSOR_SHAPE_FIELD.cpp_type = 10

-- Nested message descriptor for TypeProto.Tensor.
module.TYPEPROTO_TENSOR.name = 'Tensor'
module.TYPEPROTO_TENSOR.full_name = '.onnx.TypeProto.Tensor'
module.TYPEPROTO_TENSOR.nested_types = {}
module.TYPEPROTO_TENSOR.enum_types = {}
module.TYPEPROTO_TENSOR.fields = {module.TYPEPROTO_TENSOR_ELEM_TYPE_FIELD, module.TYPEPROTO_TENSOR_SHAPE_FIELD}
module.TYPEPROTO_TENSOR.is_extendable = false
module.TYPEPROTO_TENSOR.extensions = {}
module.TYPEPROTO_TENSOR.containing_type = module.TYPEPROTO

-- TypeProto.tensor_type: message field pointing at the nested Tensor type.
module.TYPEPROTO_TENSOR_TYPE_FIELD.name = 'tensor_type'
module.TYPEPROTO_TENSOR_TYPE_FIELD.full_name = '.onnx.TypeProto.tensor_type'
module.TYPEPROTO_TENSOR_TYPE_FIELD.number = 1
module.TYPEPROTO_TENSOR_TYPE_FIELD.index = 0
module.TYPEPROTO_TENSOR_TYPE_FIELD.label = 1
module.TYPEPROTO_TENSOR_TYPE_FIELD.has_default_value = false
module.TYPEPROTO_TENSOR_TYPE_FIELD.default_value = nil
module.TYPEPROTO_TENSOR_TYPE_FIELD.message_type = module.TYPEPROTO_TENSOR
module.TYPEPROTO_TENSOR_TYPE_FIELD.type = 11
module.TYPEPROTO_TENSOR_TYPE_FIELD.cpp_type = 10

-- Top-level TypeProto message descriptor.
module.TYPEPROTO.name = 'TypeProto'
module.TYPEPROTO.full_name = '.onnx.TypeProto'
module.TYPEPROTO.nested_types = {module.TYPEPROTO_TENSOR}
module.TYPEPROTO.enum_types = {}
module.TYPEPROTO.fields = {module.TYPEPROTO_TENSOR_TYPE_FIELD}
module.TYPEPROTO.is_extendable = false
module.TYPEPROTO.extensions = {}

-- OperatorSetIdProto.domain: string field (empty string default).
module.OPERATORSETIDPROTO_DOMAIN_FIELD.name = 'domain'
module.OPERATORSETIDPROTO_DOMAIN_FIELD.full_name = '.onnx.OperatorSetIdProto.domain'
module.OPERATORSETIDPROTO_DOMAIN_FIELD.number = 1
module.OPERATORSETIDPROTO_DOMAIN_FIELD.index = 0
module.OPERATORSETIDPROTO_DOMAIN_FIELD.label = 1
module.OPERATORSETIDPROTO_DOMAIN_FIELD.has_default_value = false
module.OPERATORSETIDPROTO_DOMAIN_FIELD.default_value = ''
module.OPERATORSETIDPROTO_DOMAIN_FIELD.type = 9
module.OPERATORSETIDPROTO_DOMAIN_FIELD.cpp_type = 9

-- OperatorSetIdProto.version: int64 field (zero default).
module.OPERATORSETIDPROTO_VERSION_FIELD.name = 'version'
module.OPERATORSETIDPROTO_VERSION_FIELD.full_name = '.onnx.OperatorSetIdProto.version'
module.OPERATORSETIDPROTO_VERSION_FIELD.number = 2
module.OPERATORSETIDPROTO_VERSION_FIELD.index = 1
module.OPERATORSETIDPROTO_VERSION_FIELD.label = 1
module.OPERATORSETIDPROTO_VERSION_FIELD.has_default_value = false
module.OPERATORSETIDPROTO_VERSION_FIELD.default_value = 0
module.OPERATORSETIDPROTO_VERSION_FIELD.type = 3
module.OPERATORSETIDPROTO_VERSION_FIELD.cpp_type = 2
-- Message descriptor for onnx.OperatorSetIdProto (fields: domain, version).
local opset = module.OPERATORSETIDPROTO
opset.name = 'OperatorSetIdProto'
opset.full_name = '.onnx.OperatorSetIdProto'
opset.nested_types = {}
opset.enum_types = {}
opset.fields = {module.OPERATORSETIDPROTO_DOMAIN_FIELD, module.OPERATORSETIDPROTO_VERSION_FIELD}
opset.is_extendable = false
opset.extensions = {}

-- Turn every descriptor into a runtime message class.  Nested classes
-- (TensorProto.Segment, TensorShapeProto.Dimension, TypeProto.Tensor) are
-- attached after their enclosing class exists.
local Message = protobuf.Message
module.AttributeProto = Message(module.ATTRIBUTEPROTO)
module.DenotationConstProto = Message(module.DENOTATIONCONSTPROTO)
module.GraphProto = Message(module.GRAPHPROTO)
module.ModelProto = Message(module.MODELPROTO)
module.NodeProto = Message(module.NODEPROTO)
module.OperatorSetIdProto = Message(module.OPERATORSETIDPROTO)
module.StringStringEntryProto = Message(module.STRINGSTRINGENTRYPROTO)
module.TensorProto = Message(module.TENSORPROTO)
module.TensorProto.Segment = Message(module.TENSORPROTO_SEGMENT)
module.TensorShapeProto = Message(module.TENSORSHAPEPROTO)
module.TensorShapeProto.Dimension = Message(module.TENSORSHAPEPROTO_DIMENSION)
module.TypeProto = Message(module.TYPEPROTO)
module.TypeProto.Tensor = Message(module.TYPEPROTO_TENSOR)
module.ValueInfoProto = Message(module.VALUEINFOPROTO)

-- ONNX IR version constants (the onnx.Version enum).
module.Version = {
  _START_VERSION = 0,
  IR_VERSION_2017_10_10 = 1,
  IR_VERSION_2017_10_30 = 2,
  IR_VERSION = 3,
}

-- Names of all exported message classes and enums, for introspection.
module.MESSAGE_TYPES = {
  'AttributeProto', 'ValueInfoProto', 'NodeProto', 'ModelProto',
  'StringStringEntryProto', 'GraphProto', 'TensorProto', 'TensorShapeProto',
  'DenotationConstProto', 'TypeProto', 'OperatorSetIdProto',
}
module.ENUM_TYPES = {'Version'}

return module
--------------------------------------------------------------------------------