├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── setup.cfg ├── setup.py ├── testfiles_x86_64 ├── README.md ├── custom_class.t7 ├── doubletensor.t7 ├── floattensor.t7 ├── function.t7 ├── function_upvals.t7 ├── gmodule_with_linear_identity.t7 ├── hello=123.t7 ├── list_table.t7 ├── map_table1.t7 ├── map_table2.t7 ├── nngraph_node.t7 ├── recursive_class.t7 ├── recursive_kv_table.t7 ├── tds_hash.t7 └── tds_vec.t7 ├── tests.py └── torchfile.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Sphinx documentation 60 | docs/_build/ 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | # IPython Notebook 66 | .ipynb_checkpoints 67 | 68 | # pyenv 69 | .python-version 70 | 71 | 72 | MANIFEST 73 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.4" 5 | - "3.5" 6 | - "3.6" 7 | install: 8 | - pip install numpy nose coverage python-coveralls 9 | - pip install . 10 | script: 11 | - nosetests --with-coverage --cover-package torchfile 12 | after_success: coveralls 13 | 14 | notifications: 15 | email: false 16 | 17 | matrix: 18 | fast_finish: true 19 | include: 20 | env: LINT_CHECK 21 | python: "2.7" 22 | addons: true 23 | install: pip install pep8 24 | script: pep8 torchfile.py 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Brendan Shillingford 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include *.md 3 | include tests.py 4 | recursive-include testfiles_x86_64 *.md 5 | recursive-include testfiles_x86_64 *.t7 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch serialization reader for Python 2 | [![Build Status](https://travis-ci.org/bshillingford/python-torchfile.svg?branch=master)](https://travis-ci.org/bshillingford/python-torchfile) 3 | [![Coverage Status](https://coveralls.io/repos/github/bshillingford/python-torchfile/badge.svg)](https://coveralls.io/github/bshillingford/python-torchfile) 4 | 5 | Mostly direct port of the torch7 Lua and C serialization implementation to 6 | Python, depending only on `numpy` (and the standard library: `array` 7 | and `struct`). Sharing of objects including `torch.Tensor`s is preserved. 8 | 9 | ```python 10 | import torchfile 11 | stuff = torchfile.load('a_bunch_of_stuff.t7') 12 | ``` 13 | 14 | ## Installation: 15 | Install from [PyPI](https://pypi.org/project/torchfile/): 16 | ```sh 17 | pip install torchfile 18 | ``` 19 | or clone this repository, then: 20 | ```sh 21 | pip install . 22 | ``` 23 | 24 | Supports Python 2.7, 3.4, 3.5, 3.6. Probably others too.
25 | 26 | ## More examples: 27 | ### Write from torch, read from Python: 28 | Lua: 29 | ```lua 30 | +th> torch.save('/tmp/test.t7', {hello=123, world=torch.rand(1,2,3)}) 31 | ``` 32 | Python: 33 | ```python 34 | In [3]: o = torchfile.load('/tmp/test.t7') 35 | In [4]: print(o['world'].shape) 36 | (1, 2, 3) 37 | In [5]: o 38 | Out[5]: 39 | {'hello': 123, 'world': array([[[ 0.52291083, 0.29261517, 0.11113465], 40 | [ 0.01017287, 0.21466237, 0.26572137]]])} 41 | ``` 42 | 43 | ### Arbitrary torch classes supported: 44 | ```python 45 | In [1]: import torchfile 46 | 47 | In [2]: o = torchfile.load('testfiles_x86_64/gmodule_with_linear_identity.t7') 48 | 49 | In [3]: o.forwardnodes[3].data.module 50 | Out[3]: TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)}) 51 | 52 | In [4]: for node in o.forwardnodes: print(repr(node.data.module)) 53 | None 54 | None 55 | None 56 | TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)}) 57 | None 58 | TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)}) 59 | TorchObject(nn.Linear, {'weight': array([[-0.0248373 ], 60 | [ 0.17503954]]), 'gradInput': array([], dtype=float64), 'gradWeight': array([[ 1.22317168e-312], 61 | [ 1.22317168e-312]]), 'bias': array([ 0.05159848, -0.25367146]), 'gradBias': array([ 1.22317168e-312, 1.22317168e-312]), 'output': array([], dtype=float64)}) 62 | TorchObject(nn.CAddTable, {'output': array([], dtype=float64), 'gradInput': []}) 63 | None 64 | 65 | In [5]: o.forwardnodes[6].data.module.weight 66 | Out[5]: 67 | array([[-0.0248373 ], 68 | [ 0.17503954]]) 69 | 70 | In [6]: o.forwardnodes[6].data.module.bias 71 | Out[6]: array([ 0.05159848, -0.25367146]) 72 | ``` 73 | 74 | ### More complex writing from torch: 75 | Lua: 76 | ```lua 77 | +th> f = torch.DiskFile('/tmp/test.t7', 'w'):binary() 78 | +th> f:writeBool(false) 79 | +th> f:writeObject({hello=123}) 80 | +th>
f:writeInt(456) 81 | +th> f:close() 82 | ``` 83 | Python: 84 | ```python 85 | In [1]: import torchfile 86 | In [2]: with open('/tmp/test.t7','rb') as f: 87 | ...: r = torchfile.T7Reader(f) 88 | ...: print(r.read_boolean()) 89 | ...: print(r.read_obj()) 90 | ...: print(r.read_int()) 91 | ...: 92 | False 93 | {'hello': 123} 94 | 456 95 | ``` 96 | 97 | 98 | ## Supported types: 99 | * `nil` to Python `None` 100 | * numbers to Python floats, or by default a heuristic changes them to ints or 101 | longs if they are integral 102 | * booleans 103 | * strings: read as byte strings (Python 3) or normal strings (Python 2), like 104 | lua strings which don't support unicode, and that can contain null chars 105 | * tables converted to a special dict (*); if they are list-like (i.e. have 106 | numeric keys from 1 through n) they become a python list by default 107 | * Torch classes: supports Tensors and Storages, and most classes such as 108 | modules. Trivially extensible much like the Torch serialization code. 109 | Trivial torch classes like most `nn.Module` subclasses become 110 | `TorchObject`s. The `type_handlers` dict contains the mapping from class 111 | names to reading functions. 112 | * functions: loaded into the `LuaFunction` `namedtuple`, 113 | which simply wraps the raw serialized data, i.e. upvalues and code. 114 | These are mostly useless, but exist so you can deserialize anything. 115 | * tds.Hash, tds.Vec 116 | 117 | (*) Since Lua allows you to index a table with a table but Python does not, we 118 | replace dicts with a subclass that is hashable, and change its 119 | equality comparison behaviour to compare by reference. 120 | See `hashable_uniq_dict`.
121 | 122 | 123 | ### Test files demonstrating various features: 124 | ```python 125 | In [1]: import torchfile 126 | 127 | In [2]: torchfile.load('testfiles_x86_64/list_table.t7') 128 | Out[2]: ['hello', 'world', 'third item', 123] 129 | 130 | In [3]: torchfile.load('testfiles_x86_64/doubletensor.t7') 131 | Out[3]: 132 | array([[ 1. , 2. , 3. ], 133 | [ 4. , 5. , 6.9]]) 134 | 135 | # ...also other files demonstrating various types. 136 | ``` 137 | 138 | The example `t7` files will work on any modern Intel or AMD 64-bit CPU, but the 139 | code will use the native byte ordering etc. Currently, the implementation 140 | assumes the system-dependent binary Torch format, but minor refactoring can 141 | give support for the ascii format as well. 142 | 143 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | version = '0.1.0' 4 | 5 | setup( 6 | name = 'torchfile', 7 | version = version, 8 | description = "Torch7 binary serialized file parser", 9 | classifiers = [ 10 | 'Development Status :: 5 - Production/Stable', 11 | 'License :: OSI Approved', 12 | 'Intended Audience :: Developers', 13 | 'Intended Audience :: Science/Research', 14 | 'Topic :: Software Development', 15 | 'Topic :: Software Development :: Libraries', 16 | 'Topic :: Scientific/Engineering', 17 | 'Operating System :: OS Independent', 18 | 'Programming Language :: Python :: 2', 19 | 'Programming Language :: Python :: 2.7', 20 | 'Programming Language :: Python :: 3', 21 | 'Programming Language :: Python :: 3.4', 22 | 'Programming Language :: Python :: 3.5', 23 | 'Programming Language :: Python :: 3.6' 24 | ], 25 | author = 
'Brendan Shillingford', 26 | author_email = 'brendan.shillingford@cs.ox.ac.uk', 27 | url = 'https://github.com/bshillingford/python-torchfile', 28 | license = 'BSD', 29 | py_modules=['torchfile'] 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /testfiles_x86_64/README.md: -------------------------------------------------------------------------------- 1 | # t7 test files 2 | 3 | These `.t7` files are all created on Linux x86_64. In practice, this is relevant because of the differing size of a long (see [this PR](https://github.com/bshillingford/python-torchfile/pull/1)) in VC++ and hence win32 Python, and almost all saved `t7` files are little-endian. 4 | 5 | Non-exhaustive explanation of test files: 6 | 7 | * `recursive_class.t7`: instance of a class created as `torch.class("A")`, containing a reference to itself 8 | * `recursive_kv_table.t7`: a table containing one element, with key and value both referencing the table they are contained in 9 | 10 | Some other test files check for the correctness of the heuristics. See `tests.py`. 
11 | 12 | -------------------------------------------------------------------------------- /testfiles_x86_64/custom_class.t7: -------------------------------------------------------------------------------- 1 | V 1Blah -------------------------------------------------------------------------------- /testfiles_x86_64/doubletensor.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/doubletensor.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/floattensor.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/floattensor.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/function.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/function.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/function_upvals.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/function_upvals.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/gmodule_with_linear_identity.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/gmodule_with_linear_identity.t7 
-------------------------------------------------------------------------------- /testfiles_x86_64/hello=123.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/hello=123.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/list_table.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/list_table.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/map_table1.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/map_table1.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/map_table2.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/map_table2.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/nngraph_node.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/nngraph_node.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/recursive_class.t7: -------------------------------------------------------------------------------- 1 | V 1Aa -------------------------------------------------------------------------------- /testfiles_x86_64/recursive_kv_table.t7: 
-------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /testfiles_x86_64/tds_hash.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/tds_hash.t7 -------------------------------------------------------------------------------- /testfiles_x86_64/tds_vec.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshillingford/python-torchfile/fbd434a5b5562c88b91a95e6476e11dbb7735436/testfiles_x86_64/tds_vec.t7 -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torchfile 3 | import os.path 4 | import sys 5 | import numpy as np 6 | 7 | 8 | unicode_type = str if sys.version_info > (3,) else unicode 9 | 10 | 11 | def make_filename(fn): 12 | TEST_FILE_DIRECTORY = 'testfiles_x86_64' 13 | return os.path.join(TEST_FILE_DIRECTORY, fn) 14 | 15 | 16 | def load(fn, **kwargs): 17 | return torchfile.load(make_filename(fn), **kwargs) 18 | 19 | 20 | class TestBasics(unittest.TestCase): 21 | 22 | def test_dict(self): 23 | obj = load('hello=123.t7') 24 | self.assertEqual(dict(obj), {b'hello': 123}) 25 | 26 | def test_custom_class(self): 27 | obj = load('custom_class.t7') 28 | self.assertEqual(obj.torch_typename(), b"Blah") 29 | 30 | def test_classnames_never_decoded(self): 31 | obj = load('custom_class.t7', utf8_decode_strings=True) 32 | self.assertNotIsInstance(obj.torch_typename(), unicode_type) 33 | 34 | obj = load('custom_class.t7', utf8_decode_strings=False) 35 | self.assertNotIsInstance(obj.torch_typename(), unicode_type) 36 | 37 | def test_basic_tensors(self): 38 | f64 = load('doubletensor.t7') 39 | 
self.assertTrue((f64 == np.array([[1, 2, 3, ], [4, 5, 6.9]], 40 | dtype=np.float64)).all()) 41 | 42 | f32 = load('floattensor.t7') 43 | self.assertAlmostEqual(f32.sum(), 12.97241666913, delta=1e-5) 44 | 45 | def test_function(self): 46 | func_with_upvals = load('function_upvals.t7') 47 | self.assertIsInstance(func_with_upvals, torchfile.LuaFunction) 48 | 49 | def test_dict_accessors(self): 50 | obj = load('hello=123.t7', 51 | use_int_heuristic=True, 52 | utf8_decode_strings=True) 53 | self.assertIsInstance(obj['hello'], int) 54 | self.assertIsInstance(obj.hello, int) 55 | 56 | obj = load('hello=123.t7', 57 | use_int_heuristic=True, 58 | utf8_decode_strings=False) 59 | self.assertIsInstance(obj[b'hello'], int) 60 | self.assertIsInstance(obj.hello, int) 61 | 62 | 63 | class TestRecursiveObjects(unittest.TestCase): 64 | 65 | def test_recursive_class(self): 66 | obj = load('recursive_class.t7') 67 | self.assertEqual(obj.a, obj) 68 | 69 | def test_recursive_table(self): 70 | obj = load('recursive_kv_table.t7') 71 | # both the key and value point to itself: 72 | key, = obj.keys() 73 | self.assertEqual(key, obj) 74 | self.assertEqual(obj[key], obj) 75 | 76 | 77 | class TestTDS(unittest.TestCase): 78 | 79 | def test_hash(self): 80 | obj = load('tds_hash.t7') 81 | self.assertEqual(len(obj), 3) 82 | self.assertEqual(obj[1], 2) 83 | self.assertEqual(obj[10], 11) 84 | 85 | def test_vec(self): 86 | # Should not be affected by list heuristic at all 87 | vec = load('tds_vec.t7', use_list_heuristic=False) 88 | self.assertEqual(vec, [123, 456]) 89 | 90 | 91 | class TestHeuristics(unittest.TestCase): 92 | 93 | def test_list_heuristic(self): 94 | obj = load('list_table.t7', use_list_heuristic=True) 95 | self.assertEqual(obj, [b'hello', b'world', b'third item', 123]) 96 | 97 | obj = load('list_table.t7', 98 | use_list_heuristic=False, 99 | use_int_heuristic=True) 100 | self.assertEqual( 101 | dict(obj), 102 | {1: b'hello', 2: b'world', 3: b'third item', 4: 123}) 103 | 104 | def 
test_int_heuristic(self): 105 | obj = load('hello=123.t7', use_int_heuristic=True) 106 | self.assertIsInstance(obj[b'hello'], int) 107 | 108 | obj = load('hello=123.t7', use_int_heuristic=False) 109 | self.assertNotIsInstance(obj[b'hello'], int) 110 | 111 | obj = load('list_table.t7', 112 | use_list_heuristic=False, 113 | use_int_heuristic=False) 114 | self.assertEqual( 115 | dict(obj), 116 | {1: b'hello', 2: b'world', 3: b'third item', 4: 123}) 117 | self.assertNotIsInstance(list(obj.keys())[0], int) 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /torchfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly direct port of the Lua and C serialization implementation to 3 | Python, depending only on `struct`, `array`, and numpy. 4 | 5 | Supported types: 6 | * `nil` to Python `None` 7 | * numbers to Python floats, or by default a heuristic changes them to ints or 8 | longs if they are integral 9 | * booleans 10 | * strings: read as byte strings (Python 3) or normal strings (Python 2), like 11 | lua strings which don't support unicode, and that can contain null chars 12 | * tables converted to a special dict (*); if they are list-like (i.e. have 13 | numeric keys from 1 through n) they become a python list by default 14 | * Torch classes: supports Tensors and Storages, and most classes such as 15 | modules. Trivially extensible much like the Torch serialization code. 16 | Trivial torch classes like most `nn.Module` subclasses become 17 | `TorchObject`s. The `type_handlers` dict contains the mapping from class 18 | names to reading functions. 19 | * functions: loaded into the `LuaFunction` `namedtuple`, 20 | which simply wraps the raw serialized data, i.e. upvalues and code. 21 | These are mostly useless, but exist so you can deserialize anything. 
22 | 23 | (*) Since Lua allows you to index a table with a table but Python does not, we 24 | replace dicts with a subclass that is hashable, and change its 25 | equality comparison behaviour to compare by reference. 26 | See `hashable_uniq_dict`. 27 | 28 | Currently, the implementation assumes the system-dependent binary Torch 29 | format, but minor refactoring can give support for the ascii format as well. 30 | """ 31 | import struct 32 | from array import array 33 | import numpy as np 34 | import sys 35 | from collections import namedtuple 36 | 37 | 38 | TYPE_NIL = 0 39 | TYPE_NUMBER = 1 40 | TYPE_STRING = 2 41 | TYPE_TABLE = 3 42 | TYPE_TORCH = 4 43 | TYPE_BOOLEAN = 5 44 | TYPE_FUNCTION = 6 45 | TYPE_RECUR_FUNCTION = 8 46 | LEGACY_TYPE_RECUR_FUNCTION = 7 47 | 48 | LuaFunction = namedtuple('LuaFunction', 49 | ['size', 'dumped', 'upvalues']) 50 | 51 | 52 | class hashable_uniq_dict(dict): 53 | """ 54 | Subclass of dict with equality and hashing semantics changed: 55 | equality and hashing is purely by reference/instance, to match 56 | the behaviour of lua tables. 57 | 58 | Supports lua-style dot indexing. 59 | 60 | This way, dicts can be keys of other dicts. 61 | """ 62 | 63 | def __hash__(self): 64 | return id(self) 65 | 66 | def __getattr__(self, key): 67 | if key in self: 68 | return self[key] 69 | if isinstance(key, (str, bytes)): 70 | return self.get(key.encode('utf8')) 71 | 72 | def __eq__(self, other): 73 | return id(self) == id(other) 74 | 75 | def __ne__(self, other): 76 | return id(self) != id(other) 77 | 78 | def _disabled_binop(self, other): 79 | raise TypeError( 80 | 'hashable_uniq_dict does not support these comparisons') 81 | __cmp__ = __ne__ = __le__ = __gt__ = __lt__ = _disabled_binop 82 | 83 | 84 | class TorchObject(object): 85 | """ 86 | Simple torch object, used by `add_trivial_class_reader`. 87 | Supports both forms of lua-style indexing, i.e. getattr and getitem. 88 | Use the `torch_typename` method to get the object's torch class name. 
89 | 90 | Equality is by reference, as usual for lua (and the default for Python 91 | objects). 92 | """ 93 | 94 | def __init__(self, typename, obj=None, version_number=0): 95 | self._typename = typename 96 | self._obj = obj 97 | self._version_number = version_number 98 | 99 | def __getattr__(self, k): 100 | if k in self._obj: 101 | return self._obj[k] 102 | if isinstance(k, (str, bytes)): 103 | return self._obj.get(k.encode('utf8')) 104 | 105 | def __getitem__(self, k): 106 | if k in self._obj: 107 | return self._obj[k] 108 | if isinstance(k, (str, bytes)): 109 | return self._obj.get(k.encode('utf8')) 110 | 111 | def torch_typename(self): 112 | return self._typename 113 | 114 | def __repr__(self): 115 | return "TorchObject(%s, %s)" % (self._typename, repr(self._obj)) 116 | 117 | def __str__(self): 118 | return repr(self) 119 | 120 | def __dir__(self): 121 | keys = list(self._obj.keys()) 122 | keys.append('torch_typename') 123 | return keys 124 | 125 | 126 | type_handlers = {} 127 | 128 | 129 | def register_handler(typename): 130 | def do_register(handler): 131 | type_handlers[typename] = handler 132 | return do_register 133 | 134 | 135 | def add_tensor_reader(typename, dtype): 136 | def read_tensor_generic(reader, version): 137 | # https://github.com/torch/torch7/blob/1e86025/generic/Tensor.c#L1249 138 | ndim = reader.read_int() 139 | 140 | size = reader.read_long_array(ndim) 141 | stride = reader.read_long_array(ndim) 142 | storage_offset = reader.read_long() - 1 # 0-indexing 143 | # read storage: 144 | storage = reader.read_obj() 145 | 146 | if storage is None or ndim == 0 or len(size) == 0 or len(stride) == 0: 147 | # empty torch tensor 148 | return np.empty((0), dtype=dtype) 149 | 150 | # convert stride to numpy style (i.e. 
in bytes) 151 | stride = [storage.dtype.itemsize * x for x in stride] 152 | 153 | # create numpy array that indexes into the storage: 154 | return np.lib.stride_tricks.as_strided( 155 | storage[storage_offset:], 156 | shape=size, 157 | strides=stride) 158 | type_handlers[typename] = read_tensor_generic 159 | add_tensor_reader(b'torch.ByteTensor', dtype=np.uint8) 160 | add_tensor_reader(b'torch.CharTensor', dtype=np.int8) 161 | add_tensor_reader(b'torch.ShortTensor', dtype=np.int16) 162 | add_tensor_reader(b'torch.IntTensor', dtype=np.int32) 163 | add_tensor_reader(b'torch.LongTensor', dtype=np.int64) 164 | add_tensor_reader(b'torch.FloatTensor', dtype=np.float32) 165 | add_tensor_reader(b'torch.DoubleTensor', dtype=np.float64) 166 | add_tensor_reader(b'torch.CudaTensor', dtype=np.float32) 167 | add_tensor_reader(b'torch.CudaByteTensor', dtype=np.uint8) 168 | add_tensor_reader(b'torch.CudaCharTensor', dtype=np.int8) 169 | add_tensor_reader(b'torch.CudaShortTensor', dtype=np.int16) 170 | add_tensor_reader(b'torch.CudaIntTensor', dtype=np.int32) 171 | add_tensor_reader(b'torch.CudaDoubleTensor', dtype=np.float64) 172 | 173 | 174 | def add_storage_reader(typename, dtype): 175 | def read_storage(reader, version): 176 | # https://github.com/torch/torch7/blob/1e86025/generic/Storage.c#L237 177 | size = reader.read_long() 178 | return np.fromfile(reader.f, dtype=dtype, count=size) 179 | type_handlers[typename] = read_storage 180 | add_storage_reader(b'torch.ByteStorage', dtype=np.uint8) 181 | add_storage_reader(b'torch.CharStorage', dtype=np.int8) 182 | add_storage_reader(b'torch.ShortStorage', dtype=np.int16) 183 | add_storage_reader(b'torch.IntStorage', dtype=np.int32) 184 | add_storage_reader(b'torch.LongStorage', dtype=np.int64) 185 | add_storage_reader(b'torch.FloatStorage', dtype=np.float32) 186 | add_storage_reader(b'torch.DoubleStorage', dtype=np.float64) 187 | add_storage_reader(b'torch.CudaStorage', dtype=np.float32) 188 | 
add_storage_reader(b'torch.CudaByteStorage', dtype=np.uint8) 189 | add_storage_reader(b'torch.CudaCharStorage', dtype=np.int8) 190 | add_storage_reader(b'torch.CudaShortStorage', dtype=np.int16) 191 | add_storage_reader(b'torch.CudaIntStorage', dtype=np.int32) 192 | add_storage_reader(b'torch.CudaDoubleStorage', dtype=np.float64) 193 | 194 | 195 | def add_notimpl_reader(typename): 196 | def read_notimpl(reader, version): 197 | raise NotImplementedError('Reader not implemented for: ' + typename) 198 | type_handlers[typename] = read_notimpl 199 | add_notimpl_reader(b'torch.HalfTensor') 200 | add_notimpl_reader(b'torch.HalfStorage') 201 | add_notimpl_reader(b'torch.CudaHalfTensor') 202 | add_notimpl_reader(b'torch.CudaHalfStorage') 203 | 204 | 205 | @register_handler(b'tds.Vec') 206 | def tds_Vec_reader(reader, version): 207 | size = reader.read_int() 208 | obj = [] 209 | _ = reader.read_obj() 210 | for i in range(size): 211 | e = reader.read_obj() 212 | obj.append(e) 213 | return obj 214 | 215 | 216 | @register_handler(b'tds.Hash') 217 | def tds_Hash_reader(reader, version): 218 | size = reader.read_int() 219 | obj = hashable_uniq_dict() 220 | _ = reader.read_obj() 221 | for i in range(size): 222 | k = reader.read_obj() 223 | v = reader.read_obj() 224 | obj[k] = v 225 | return obj 226 | 227 | 228 | class T7ReaderException(Exception): 229 | pass 230 | 231 | 232 | class T7Reader: 233 | 234 | def __init__(self, 235 | fileobj, 236 | use_list_heuristic=True, 237 | use_int_heuristic=True, 238 | utf8_decode_strings=False, 239 | force_deserialize_classes=None, 240 | force_8bytes_long=False): 241 | """ 242 | Params: 243 | * `fileobj`: file object to read from, must be an actual file object 244 | as it will be read by `array`, `struct`, and `numpy`. Since 245 | it is only read sequentially, certain objects like pipes or 246 | `sys.stdin` should work as well (untested). 
247 | * `use_list_heuristic`: automatically turn tables with only consecutive 248 | positive integral indices into lists 249 | (default True) 250 | * `use_int_heuristic`: cast all whole floats into ints (default True) 251 | * `utf8_decode_strings`: decode all strings as UTF8. By default they 252 | remain as byte strings. Version strings always 253 | are byte strings, but this setting affects 254 | class names. (default False) 255 | * `force_deserialize_classes`: deprecated. 256 | """ 257 | self.f = fileobj 258 | self.objects = {} # read objects so far 259 | 260 | if force_deserialize_classes is not None: 261 | raise DeprecationWarning( 262 | 'force_deserialize_classes is now always ' 263 | 'forced to be true, so no longer required') 264 | self.use_list_heuristic = use_list_heuristic 265 | self.use_int_heuristic = use_int_heuristic 266 | self.utf8_decode_strings = utf8_decode_strings 267 | self.force_8bytes_long = force_8bytes_long 268 | 269 | def _read(self, fmt): 270 | sz = struct.calcsize(fmt) 271 | return struct.unpack(fmt, self.f.read(sz)) 272 | 273 | def read_boolean(self): 274 | return self.read_int() == 1 275 | 276 | def read_int(self): 277 | return self._read('i')[0] 278 | 279 | def read_long(self): 280 | if self.force_8bytes_long: 281 | return self._read('q')[0] 282 | else: 283 | return self._read('l')[0] 284 | 285 | def read_long_array(self, n): 286 | if self.force_8bytes_long: 287 | lst = [] 288 | for i in range(n): 289 | lst.append(self.read_long()) 290 | return lst 291 | else: 292 | arr = array('l') 293 | arr.fromfile(self.f, n) 294 | return arr.tolist() 295 | 296 | def read_float(self): 297 | return self._read('f')[0] 298 | 299 | def read_double(self): 300 | return self._read('d')[0] 301 | 302 | def read_string(self, disable_utf8=False): 303 | size = self.read_int() 304 | s = self.f.read(size) 305 | if disable_utf8 or not self.utf8_decode_strings: 306 | return s 307 | return s.decode('utf8') 308 | 309 | def read_obj(self): 310 | typeidx = 
self.read_int() 311 | 312 | if typeidx == TYPE_NIL: 313 | return None 314 | 315 | elif typeidx == TYPE_NUMBER: 316 | x = self.read_double() 317 | # Extra checking for integral numbers: 318 | if self.use_int_heuristic and x.is_integer(): 319 | return int(x) 320 | return x 321 | 322 | elif typeidx == TYPE_BOOLEAN: 323 | return self.read_boolean() 324 | 325 | elif typeidx == TYPE_STRING: 326 | return self.read_string() 327 | 328 | elif (typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or 329 | typeidx == TYPE_FUNCTION or typeidx == TYPE_RECUR_FUNCTION or 330 | typeidx == LEGACY_TYPE_RECUR_FUNCTION): 331 | # read the object reference index 332 | index = self.read_int() 333 | 334 | # check it is loaded already 335 | if index in self.objects: 336 | return self.objects[index] 337 | 338 | # otherwise read it 339 | if (typeidx == TYPE_FUNCTION or typeidx == TYPE_RECUR_FUNCTION or 340 | typeidx == LEGACY_TYPE_RECUR_FUNCTION): 341 | size = self.read_int() 342 | dumped = self.f.read(size) 343 | upvalues = self.read_obj() 344 | obj = LuaFunction(size, dumped, upvalues) 345 | self.objects[index] = obj 346 | return obj 347 | 348 | elif typeidx == TYPE_TORCH: 349 | version = self.read_string(disable_utf8=True) 350 | if version.startswith(b'V '): 351 | version_number = int(float(version.partition(b' ')[2])) 352 | class_name = self.read_string(disable_utf8=True) 353 | else: 354 | class_name = version 355 | # created before existence of versioning 356 | version_number = 0 357 | if class_name in type_handlers: 358 | # TODO: can custom readers ever be self-referential? 359 | self.objects[index] = None # FIXME: if self-referential 360 | obj = type_handlers[class_name](self, version) 361 | self.objects[index] = obj 362 | else: 363 | # This must be performed in two steps to allow objects 364 | # to be a property of themselves. 
365 | obj = TorchObject( 366 | class_name, version_number=version_number) 367 | self.objects[index] = obj 368 | # After self.objects is populated, it's safe to read in 369 | # case self-referential 370 | obj._obj = self.read_obj() 371 | return obj 372 | 373 | else: # it is a table: returns a custom dict or a list 374 | size = self.read_int() 375 | # custom hashable dict, so that it can be a key, see above 376 | obj = hashable_uniq_dict() 377 | # For checking if keys are consecutive and positive ints; 378 | # if so, returns a list with indices converted to 0-indices. 379 | key_sum = 0 380 | keys_natural = True 381 | # bugfix: obj must be registered before reading keys and vals 382 | self.objects[index] = obj 383 | 384 | for _ in range(size): 385 | k = self.read_obj() 386 | v = self.read_obj() 387 | obj[k] = v 388 | 389 | if self.use_list_heuristic: 390 | if not isinstance(k, int) or k <= 0: 391 | keys_natural = False 392 | elif isinstance(k, int): 393 | key_sum += k 394 | 395 | if self.use_list_heuristic: 396 | # n(n+1)/2 = sum <=> consecutive and natural numbers 397 | n = len(obj) 398 | if keys_natural and n * (n + 1) == 2 * key_sum: 399 | lst = [] 400 | for i in range(len(obj)): 401 | elem = obj[i + 1] 402 | # In case it is self-referential. This is not 403 | # needed in lua torch since the tables are never 404 | # modified as they are here. 405 | if elem == obj: 406 | elem = lst 407 | lst.append(elem) 408 | self.objects[index] = obj = lst 409 | 410 | return obj 411 | 412 | else: 413 | raise T7ReaderException( 414 | "unknown object type / typeidx: {}".format(typeidx)) 415 | 416 | 417 | def load(filename, **kwargs): 418 | """ 419 | Loads the given t7 file using default settings; kwargs are forwarded 420 | to `T7Reader`. 421 | """ 422 | with open(filename, 'rb') as f: 423 | reader = T7Reader(f, **kwargs) 424 | return reader.read_obj() 425 | --------------------------------------------------------------------------------