├── .gitignore
├── .gitmodules
├── README.md
├── onnx_pytorch
│   ├── __init__.py
│   ├── pytorch_helper.py
│   └── verify.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
.eggs
*.egg-info
*.pyc
env
experimental

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
# Keep this file synced with the .gitmodules of our subtrees.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
onnx-pytorch
========

--------------------------------------------------------------------------------
/onnx_pytorch/__init__.py:
--------------------------------------------------------------------------------
import onnx_pytorch.verify

--------------------------------------------------------------------------------
/onnx_pytorch/pytorch_helper.py:
--------------------------------------------------------------------------------
import io
import torch.onnx
import onnx
from onnx_caffe2.backend import Caffe2Backend
from caffe2.python.core import BlobReference, Net


_next_idx = 0
# Clone net takes a dict instead of a lambda
# It should probably take a lambda; that would be more flexible
# We fake a dict here


class _FakeDict(object):
    def __init__(self, fn):
        self.fn = fn

    def get(self, name, _):
        return self.fn(name)


def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    Arguments:
        helper (caffe2.python.core.ModelHelper): the model helper where
            this imported network should be inserted
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network.  Must be
            the same length as sample_arguments
        prefix_name: prefix to prepend to each blob name; if None then
            a fresh prefix pytorch_import_N/ is used
    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        model's outputs, or a single BlobReference when the model returns a single
        value.
48 | """ 49 | if prefix_name is None: 50 | global _next_idx 51 | prefix_name = 'pytorch_import_' + str(_next_idx) + '/' 52 | _next_idx += 1 53 | 54 | # TODO: handle the case where model cannot be exported 55 | # and embed as a Python op in Caffe2 56 | f = io.BytesIO() 57 | torch.onnx.export( 58 | model, sample_arguments, f, export_params=True) 59 | onnx_model = onnx.load(io.BytesIO(f.getvalue())) 60 | init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net( 61 | onnx_model) 62 | 63 | initialized = set([x.name for x in onnx_model.graph.initializer]) 64 | uninitialized_inputs = {x.name: i for i, x in enumerate( 65 | onnx_model.graph.input) if x.name not in initialized} 66 | 67 | if(len(uninitialized_inputs) != len(caffe2_inputs)): 68 | raise ValueError('Expected {} inputs but found {}'.format( 69 | len(uninitialized_inputs), len(caffe2_inputs))) 70 | 71 | def remap_blob_name(name): 72 | if name in uninitialized_inputs: 73 | idx = uninitialized_inputs[name] 74 | return str(caffe2_inputs[idx]) 75 | return prefix_name + name 76 | 77 | predict_net = Net(predict_net).Clone('anon', _FakeDict(remap_blob_name)) 78 | helper.net.AppendNet(predict_net) 79 | 80 | init_net = Net(init_net).Clone('anon', _FakeDict(remap_blob_name)) 81 | helper.param_init_net.AppendNet(init_net) 82 | 83 | results = tuple([BlobReference(remap_blob_name(x.name), helper.net) 84 | for x in onnx_model.graph.output]) 85 | return results 86 | -------------------------------------------------------------------------------- /onnx_pytorch/verify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.jit 3 | import torch.onnx 4 | 5 | import onnx 6 | import onnx.helper 7 | 8 | import numpy as np 9 | 10 | import difflib 11 | import contextlib 12 | import io 13 | 14 | 15 | def colonize(msg, sep=": "): 16 | if not msg: 17 | return "" 18 | else: 19 | return msg + sep 20 | 21 | 22 | class Errors(object): 23 | """ 24 | An error-collecting object which supports error recovery. 25 | 26 | It is intended to be used like a context manager: 27 | 28 | >>> with Errors("Top-level error message") as errs: 29 | >>> ... 30 | """ 31 | 32 | def __init__(self, msg, decimal=3): 33 | self.msg = msg 34 | self.errors = [] 35 | self.context = [] 36 | self.decimal = decimal 37 | 38 | # Allocated upon instance creation so that multiple Errors 39 | # can be used 40 | class ShortCircuit(Exception): 41 | pass 42 | self.exc_class = ShortCircuit 43 | 44 | def requireAlmostEqual(self, x, y, msg=None): 45 | """ 46 | Test that x and y are nearly equal (equal within self.decimal 47 | precision); aborts execution if they are not. 48 | """ 49 | self.almostEqualAndThen(x, y, msg, self.failWith) 50 | 51 | def checkAlmostEqual(self, x, y, msg=None): 52 | """ 53 | Test that x and y are nearly equal (equal within self.decimal 54 | precision), but continue execution even if they are not equal. 55 | 56 | To prevent error cascades, you should remember to call 'failIfErrs' 57 | at some later point in time. 58 | """ 59 | self.almostEqualAndThen(x, y, msg, self.addErr) 60 | 61 | def almostEqualAndThen(self, x, y, msg, k): 62 | """ 63 | Helper for implementing 'requireAlmostEqual' and 'checkAlmostEqual'. 64 | Upon failure, invokes continuation 'k' with the error message. 65 | 66 | At the moment, only tests on 'numpy.ndarray' are supported. 
67 | """ 68 | if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 69 | try: 70 | np.testing.assert_almost_equal(x, y, decimal=self.decimal) 71 | except AssertionError as e: 72 | k("{}{}".format(colonize(msg), str(e).lstrip())) 73 | else: 74 | raise RuntimeError("Unsupported almost equal test") 75 | 76 | def requireEqual(self, x, y, msg=None): 77 | """ 78 | Test that x and y are equal; aborts execution if they are not. 79 | """ 80 | self.equalAndThen(x, y, msg, self.failWith) 81 | 82 | def checkEqual(self, x, y, msg=None): 83 | """ 84 | Test that x and y are equal, but continue execution even if they are not equal. 85 | 86 | To prevent error cascades, you should remember to call 'failIfErrs' 87 | at some later point in time. 88 | """ 89 | self.equalAndThen(x, y, msg, self.addErr) 90 | 91 | # Bit-for-bit accuracy test 92 | def equalAndThen(self, x, y, msg, k): 93 | """ 94 | Helper for implementing 'requireEqual' and 'checkEqual'. Upon failure, 95 | invokes continuation 'k' with the error message. 96 | """ 97 | if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto): 98 | self.equalAndThen(x.name, y.name, msg, k) 99 | # Use numpy for the comparison 100 | t1 = onnx.numpy_helper.to_array(x) 101 | t2 = onnx.numpy_helper.to_array(y) 102 | new_msg = "{}In embedded parameter '{}'".format(colonize(msg), x.name) 103 | self.equalAndThen(t1, t2, new_msg, k) 104 | elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 105 | try: 106 | np.testing.assert_equal(x, y) 107 | except AssertionError as e: 108 | k("{}{}".format(colonize(msg, ": "), str(e).lstrip())) 109 | else: 110 | if x != y: 111 | # TODO: Better algorithm for lists 112 | sx = str(x) 113 | sy = str(y) 114 | if len(sx) > 40 or len(sy) > 40 or '\n' in sx or '\n' in sy: 115 | # long form 116 | l = "=" * 50 117 | k("\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}" 118 | .format(colonize(msg, ":\n"), l, sx, l, l, sy, l)) 119 | else: 120 | k("{}{} != {}".format(colonize(msg), sx, sy)) 121 | 122 | def requireMultiLineEqual(self, x, y, msg=None): 123 | """ 124 | Test that long, multi-line strings x and y are equal; 125 | aborts execution if they are not. 126 | """ 127 | self.multiLineEqualAndThen(x, y, msg, self.failWith) 128 | 129 | def multiLineEqualAndThen(self, x, y, msg, k): 130 | """ 131 | Helper for implementing 'requireMultiLineEqual'. Upon failure, 132 | invokes continuation 'k' with the error message. 133 | """ 134 | if msg is None: 135 | msg = "Strings are not equal" 136 | if x != y: 137 | diff = difflib.ndiff(x.splitlines(True), y.splitlines(True)) 138 | k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff))) 139 | 140 | def addErr(self, msg): 141 | """ 142 | Add an error to the error context, but continue executing. 143 | """ 144 | # TODO: instead of immediately concatenating the context in the msg, 145 | # attach it as metadata and make a decision how to format it later. 146 | msg_w_ctx = msg 147 | for c in reversed(self.context): 148 | msg += "\n\n * " + "\n ".join(c.splitlines()) 149 | self.errors.append(msg) 150 | 151 | def fail(self): 152 | """ 153 | Immediately fail and short-circuit to the next recovery context. 154 | 155 | NB: It is an error to 'fail' without having added any errors to 156 | the error context. 157 | """ 158 | raise self.exc_class() 159 | 160 | def failWith(self, msg): 161 | """ 162 | Add an error to the error context, and then short-circuit. 
163 | """ 164 | self.addErr(msg) 165 | self.fail() 166 | 167 | def failIfErrs(self): 168 | """ 169 | If there are any errors in the error context, short-circuit. 170 | 171 | This is used to prevent error cascades. 172 | """ 173 | if self.errors: 174 | self.fail() 175 | 176 | def recover(parent_self): 177 | """ 178 | Returns a context manager which can be used to recover in case of 179 | an error. Example usage: 180 | 181 | >>> with errs.recover(): 182 | >>> ... 183 | """ 184 | class Recover(object): 185 | def __enter__(self): 186 | pass 187 | 188 | def __exit__(self, exc_type, exc_value, traceback): 189 | if exc_type == parent_self.exc_class: 190 | return True 191 | return Recover() 192 | 193 | def addErrCtxt(parent_self, msg): 194 | """ 195 | Returns a context manager which encloses a fragment of code with 196 | an extra contextual message, e.g., where an error occurred, or a hint 197 | applicable to all errors in the area. Example usage: 198 | 199 | >>> with errs.addErrCtx("Some text"): 200 | >>> ... 201 | """ 202 | class AddContext(object): 203 | def __enter__(self): 204 | parent_self.context.append(msg) 205 | 206 | def __exit__(self, exc_type, exc_value, traceback): 207 | parent_self.context.pop() 208 | return AddContext() 209 | 210 | def __enter__(self): 211 | return self 212 | 213 | def __exit__(self, exc_type, exc_value, traceback): 214 | if self.errors: 215 | errors_msg = "\n\n".join(map(lambda x: "ERROR: " + x, self.errors)) 216 | final_msg = "{}\n{}\n{}".format(self.msg, '-' * 70, errors_msg) 217 | raise AssertionError(final_msg) 218 | if exc_type == self.exc_class: 219 | raise RuntimeError("ShortCircuit was raised, but no errors were recorded") 220 | 221 | 222 | @contextlib.contextmanager 223 | def set_training(model, mode): 224 | """ 225 | A context manager to temporarily set the training mode of 'model' 226 | to 'mode', resetting it when we exit the with-block. 227 | """ 228 | old_mode = model.training 229 | if old_mode != mode: 230 | model.train(mode) 231 | try: 232 | yield 233 | finally: 234 | if old_mode != mode: 235 | model.train(old_mode) 236 | 237 | 238 | def verify(model, args, backend, verbose=False, training=False, decimal=3, test_args=2): 239 | """ 240 | Export a model into ONNX, import it into a specified ONNX backend, and then 241 | on a few random inputs verify that PyTorch and the backend produced the same 242 | results. Requires onnx to be installed. 243 | 244 | This function may spuriously fail: some operators are implemented with 245 | different numerical precision in an ONNX backend, in which case an unstable 246 | network (e.g., Inception) may blow up these numerical instabilities. This 247 | situation is less likely to happen if your model has been trained. However, 248 | if this is not the case, you may have found a bug! Please report it to the 249 | PyTorch developers. You can also debug the issue yourself by removing 250 | suffixes of operators from your model until verification passes. 251 | 252 | For reproduceability, we recommend explicitly setting PyTorch's seed before 253 | invoking this function. 254 | 255 | Arguments: 256 | model (torch.nn.Module): the model to be exported and verified 257 | args (tuple of arguments): the inputs to 258 | the model, e.g., such that ``model(*args)`` is a valid 259 | invocation of the model. Any non-Variable arguments will 260 | be hard-coded into the exported model; any Variable arguments 261 | will become inputs of the exported model, in the order they 262 | occur in args. 
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        backend (onnx.backend module): ONNX backend to verify with
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode.  At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        decimal (int, default 3): how many decimal places of precision to test
        test_args (int or iterable of args, default 2):
            either an integer specifying the number
            of random arguments to generate, or an iterable producing arguments
            to test under.
    """

    def is_variable(o):
        return isinstance(o, torch.autograd.Variable)

    def randomize_arg(arg):
        new_data = arg.data.clone()
        # For now, don't try randomizing non-float tensors; these
        # are likely to be things like indices, where just randomly
        # spattering some longs is unlikely to work.  One way we could
        # make this work is to apply a random permutation or something.
        if hasattr(new_data, 'uniform_'):
            new_data.uniform_()
        return torch.autograd.Variable(new_data, volatile=arg.volatile, requires_grad=arg.requires_grad)

    def randomize_args(args):
        return torch.autograd.function._nested_map(is_variable, randomize_arg)(args)

    def backend_args(args):
        # TODO: onnx should accept iterables
        return tuple(v.data.cpu().numpy() for v in torch.autograd.function._iter_variables(args))

    def load_bytes(b):
        b.seek(0)
        x = onnx.load(b)
        # doc_string has stack traces - let's remove them to make the comparison
        # sane
        onnx.helper.strip_doc_string(x)
        return x

    # Special case for the common case of passing a single Variable
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    with set_training(model, training):
        proto_bytes = io.BytesIO()
        torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose)
        proto = load_bytes(proto_bytes)
        prepared = backend.prepare(proto)

        def run(args):
            alt_proto_bytes = io.BytesIO()
            torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose)
            alt_proto = load_bytes(alt_proto_bytes)
            if proto.SerializeToString() != alt_proto.SerializeToString():
                # OK, let's try to figure out what happened.
                msg = "When I exported your model with different inputs, the result was different."
                if not verbose:
                    msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))"
                with Errors(msg) as errs:
                    # First, check that we have the same number of parameters, and
                    # that they're in the same order.  If they aren't, something has *really* gone wrong.
                    initializer_order_hint = ("This is really strange! The second time I exported your model,\n"
                                              "it had a different set of parameters.  Are you assigning Parameters\n"
                                              "in the forward() of your model definition?")
                    with errs.addErrCtxt(initializer_order_hint):
                        errs.requireEqual(list(map(lambda x: x.name, proto.graph.initializer)),
                                          list(map(lambda x: x.name, alt_proto.graph.initializer)),
                                          msg="Parameters list differs")

                    # Now check if the embedded parameters are actually the same
                    initializer_hint = ("A difference in embedded parameters usually means that\n"
                                        "your model is updating parameters/buffers even in inference\n"
                                        "mode.  Look for a buggy nn.Module which isn't respecting train().\n")
                    with errs.recover(), errs.addErrCtxt(initializer_hint):
                        for x, y in zip(proto.graph.initializer, alt_proto.graph.initializer):
                            errs.checkEqual(x, y)

                    # Next, check if the model structure lines up.
                    structure_hint = ("A difference in model structure usually means that\n"
                                      "your model has dynamic control flow.  These models are not\n"
                                      "currently supported by the exporter.")
                    with errs.recover(), errs.addErrCtxt(structure_hint):
                        # Delete initializers since we already tested them
                        stripped_proto = onnx.ModelProto()
                        stripped_proto.CopyFrom(proto)
                        del stripped_proto.graph.initializer[:]

                        stripped_alt_proto = onnx.ModelProto()
                        stripped_alt_proto.CopyFrom(alt_proto)
                        del stripped_alt_proto.graph.initializer[:]

                        # Compare the printable graph representations first
                        errs.requireMultiLineEqual(onnx.helper.printable_graph(stripped_proto.graph),
                                                   onnx.helper.printable_graph(stripped_alt_proto.graph))

                        # Compare the actual protobuf text formats now (not
                        # very user-friendly!)
                        errs.requireMultiLineEqual(str(stripped_proto), str(stripped_alt_proto))

                        # One last-ditch effort, using built-in equality on
                        # protobufs
                        errs.requireEqual(stripped_proto, stripped_alt_proto)

                    errs.failIfErrs()

                    # At this point, we should have figured out why the binary
                    # protobufs differed, and short-circuited out of this code
                    # with a helpful error message.  But what if we didn't?
                    # We had better still try to give a good error message in this
                    # case.  We EXPECT these requires to fail.  If they don't,
                    # that is a bug in verify.
                    errs.requireEqual(proto, alt_proto)
                    errs.requireEqual(proto_bytes.getvalue(), alt_proto_bytes.getvalue())
                    assert False

            # TODO: test that the traced model also returns the same thing...
            run_helper(torch_out, args)

        # Factored out so we can avoid one run of the model
        def run_helper(torch_out, args):
            backend_out = prepared.run(backend_args(args))
            if isinstance(torch_out, torch.autograd.Variable):
                torch_out = (torch_out,)
            # NB: onnx backend NEVER returns bare numpy array
            msg = "ONNX backend returned different results from PyTorch"
            result_hint = ("If you are not using trained parameters, a difference in results\n"
                           "could mean that your network is numerically unstable.  Otherwise\n"
                           "it indicates a bug in PyTorch/ONNX; please file a bug report.")
            with Errors(msg) as errs, errs.addErrCtxt(result_hint):
                for i, (x, y) in enumerate(zip(torch_out, backend_out)):
                    errs.checkAlmostEqual(x.data.cpu().numpy(), y, "In output {}".format(i))

        run_helper(torch_out, args)

        if isinstance(test_args, int):
            for i in range(test_args):
                run(randomize_args(args))
        else:
            for test_arg in test_args:
                run(test_arg)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sys
from setuptools import setup, find_packages

setup(
    name="onnx-pytorch",
    version='0.1',
    description="PyTorch helpers for working with the Open Neural Network Exchange format",
    install_requires=['numpy', 'onnx'],
    setup_requires=[],
    tests_require=[],
    packages=find_packages(),
    author='ezyang',
    author_email='ezyang@fb.com',
    url='https://github.com/ezyang/onnx-pytorch',
)
--------------------------------------------------------------------------------
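
Usage sketch (not part of the repository files above): a minimal, hypothetical example of
calling onnx_pytorch.verify.verify against the onnx-caffe2 backend that pytorch_helper.py
already depends on. TinyModel, its layer sizes, and the seed value are made up for
illustration; the call itself follows the verify() signature documented in verify.py.

    # Hypothetical usage example; assumes torch, onnx, and onnx-caffe2 are installed.
    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    import onnx_caffe2.backend
    from onnx_pytorch.verify import verify

    class TinyModel(nn.Module):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    torch.manual_seed(0)  # set the seed explicitly, as the verify() docstring recommends
    model = TinyModel()
    args = Variable(torch.randn(1, 4))

    # Exports the model to ONNX, runs it under the caffe2 backend on two sets of
    # randomized inputs, and raises AssertionError if PyTorch and the backend disagree.
    verify(model, args, onnx_caffe2.backend, test_args=2)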