├── .flexci
│   ├── config.pbtxt
│   ├── run.sh
│   └── test.sh
├── .github
│   ├── CONTRIBUTING.md
│   └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── LICENSE
├── README.md
├── chainer_pytorch_migration
│   ├── __init__.py
│   ├── allocator.py
│   ├── chainermn
│   │   ├── __init__.py
│   │   └── optimizers.py
│   ├── datasets.py
│   ├── device.py
│   ├── ignite
│   │   ├── __init__.py
│   │   ├── collate.py
│   │   └── extensions.py
│   ├── links.py
│   ├── parameter.py
│   └── tensor.py
├── setup.py
└── tests
    ├── test_collate.py
    ├── test_datasets.py
    ├── test_device.py
    ├── test_extensions.py
    ├── test_links.py
    ├── test_parameter.py
    └── test_tensor.py

/.flexci/config.pbtxt:
--------------------------------------------------------------------------------
configs {
  key: "chainer-pytorch-migration"
  value {
    requirement {
      cpu: 2
      gpu: 1
      memory: 8
      disk: 10
    }
    time_limit: {
      seconds: 900
    }
    command: "bash .flexci/run.sh"
  }
}
--------------------------------------------------------------------------------
/.flexci/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash -uex

nvidia-docker run --volume ${PWD}:/work --workdir /work nvidia/cuda:10.0-cudnn7-runtime-ubuntu18.04 .flexci/test.sh
--------------------------------------------------------------------------------
/.flexci/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash -uex

perl -pi.bak -e 's|http://archive\.ubuntu\.com/ubuntu/|mirror://mirrors.ubuntu.com/mirrors.txt|g' /etc/apt/sources.list
apt update
apt -y install python3 python3-pip

pip3 install -U chainer 'cupy-cuda100<8'
pip3 install pytorch-ignite
pip3 install -e '.[test]'

python3 -m pytest tests/
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
Please refer to the contribution guide in the [README](../README.md).
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
Thank you for the contribution!

Please go through our [contribution guide in the README](../README.md) if this is your first time.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2019 Preferred Networks, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Chainer/PyTorch Migration Library

`chainer-pytorch-migration` is a provisional tool to help migrate Chainer projects to PyTorch.

## Description

chainer-pytorch-migration (CPM) is a tool that helps in the process of migrating a project from Chainer to PyTorch.
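
For example, the two wrapping directions look roughly like this (a minimal
sketch put together from this library's tests; the layer sizes are arbitrary):

```python
import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm

# Use a PyTorch model from Chainer code: TorchModule exposes the module as a
# chainer.Chain whose parameters share memory with the torch parameters.
model = torch.nn.Linear(3, 1)
chained = cpm.TorchModule(model)
assert isinstance(chained.weight, chainer.Variable)

# Or the other way around: drive a Chainer link with a torch optimizer.
link = chainer.links.Linear(3, 1)
link(numpy.zeros((1, 3), dtype='float32'))  # initialize the parameters first
module = cpm.LinkAsTorchModel(link)
optimizer = cpm.Optimizer(torch.optim.SGD(module.parameters(), lr=0.01))
```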

The main features in CPM are:

+ Use PyTorch models with Chainer training scripts
+ Use Chainer models with PyTorch training scripts
+ Use Chainer extensions with Ignite trainers
+ Use the PyTorch memory allocator in CuPy

The main goal of CPM is to allow components from the two frameworks to interact with each other while the migration of a project is ongoing.

Please refer to the migration guide for detailed usage.

## Installation

```sh
pip install chainer-pytorch-migration

# Required only if you want to use `chainer_pytorch_migration.ignite`:
pip install pytorch-ignite

# Required only if you want to use the CuPy integration:
# See: https://docs-cupy.chainer.org/en/latest/install.html#install-cupy
pip install cupy-cudaXXX
```

## Contribution Guide

You can contribute to this project by sending a pull request.
After approval, the pull request will be merged by the reviewer.

Before making a contribution, please confirm that:

- Code quality stays consistent across the script, module or package.
- Code is covered by unit tests.
- API is maintainable.
--------------------------------------------------------------------------------
/chainer_pytorch_migration/__init__.py:
--------------------------------------------------------------------------------
from . import links
from .allocator import use_mempool_in_cupy_malloc, use_torch_in_cupy_malloc
from .datasets import TransformDataset
from .links import TorchModule
from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
from .tensor import asarray, astensor, to_numpy_dtype
from .device import to_chainer_device, to_torch_device
--------------------------------------------------------------------------------
/chainer_pytorch_migration/allocator.py:
--------------------------------------------------------------------------------
try:
    import cupy
    _cupy_import_error = None
except ImportError as e:
    _cupy_import_error = e
import torch


def use_mempool_in_cupy_malloc():
    _ensure_cupy()
    cupy.cuda.set_allocator(cupy.get_default_memory_pool().malloc)


def use_torch_in_cupy_malloc():
    _ensure_cupy()
    cupy.cuda.set_allocator(_torch_alloc)


def _ensure_cupy():
    if _cupy_import_error is not None:
        raise RuntimeError(
            'cupy is not available; import error is:\n{}'.format(
                _cupy_import_error))


def _torch_alloc(size):
    device = cupy.cuda.Device().id
    tensor = torch.empty(size, dtype=torch.uint8, device=device)
    return cupy.cuda.MemoryPointer(
        cupy.cuda.UnownedMemory(tensor.data_ptr(), size, tensor), 0)
--------------------------------------------------------------------------------
/chainer_pytorch_migration/chainermn/__init__.py:
--------------------------------------------------------------------------------
from chainermn import *
from chainer_pytorch_migration.chainermn.optimizers import create_multi_node_optimizer
--------------------------------------------------------------------------------
/chainer_pytorch_migration/chainermn/optimizers.py:
--------------------------------------------------------------------------------
import chainermn


class ChainerMNOptimizer(object):
    def __init__(self, optimizer):
        super(ChainerMNOptimizer, self).__setattr__(
            'optimizer', optimizer)

    def update(self, lossfun=None, *args, **kwds):
        """Used to fool the chainermn optimizer wrappers."""
        self.optimizer.step()

    def setup(self, link):
        self.target = link
        return self

    def __getattr__(self, attr_name):
        return getattr(self.optimizer, attr_name)

    def __setattr__(self, attr_name, value):
        setattr(self.optimizer, attr_name, value)


def create_multi_node_optimizer(actual_optimizer, communicator,
                                double_buffering=False, zero_fill=True):
    return chainermn.optimizers.create_multi_node_optimizer(
        ChainerMNOptimizer(actual_optimizer), communicator,
        double_buffering, zero_fill)
--------------------------------------------------------------------------------
/chainer_pytorch_migration/datasets.py:
--------------------------------------------------------------------------------
import torch


class TransformDataset(torch.utils.data.Dataset):

    def __init__(self, dataset, transform):
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, i):
        in_data = self._dataset[i]
        return self._transform(in_data)
--------------------------------------------------------------------------------
/chainer_pytorch_migration/device.py:
--------------------------------------------------------------------------------
import chainer
import torch


def to_chainer_device(device):
    """Create a chainer device from a given torch device.

    Args:
        device (torch.device): Device to be converted.

    Returns:
        A ``chainer.device`` object corresponding to the given input.
    """
    if not isinstance(device, torch.device):
        raise TypeError('The argument should be a torch device.')
    if device.type == 'cpu':
        return chainer.get_device('@numpy')
    if device.type == 'cuda':
        device_index = 0 if device.index is None else device.index
        return chainer.get_device('@cupy:{}'.format(device_index))
    raise ValueError('{} is not supported.'.format(device.type))


def to_torch_device(device):
    """Create a torch device from a given chainer device.

    Args:
        device (chainer.Device): Device to be converted.

    Returns:
        A ``torch.device`` object corresponding to the given input.
    """
    if not isinstance(device, chainer.backend.Device):
        raise TypeError('The argument should be a chainer device.')
    if device.name == '@numpy':
        return torch.device('cpu')
    if device.name.startswith('@cupy:'):
        cuda_device_index = int(device.name.split(':')[1])
        return torch.device('cuda', cuda_device_index)
    raise ValueError('{} is not supported.'.format(device.name))
--------------------------------------------------------------------------------
/chainer_pytorch_migration/ignite/__init__.py:
--------------------------------------------------------------------------------
import re

import ignite

from .extensions import add_trainer_extension, load_chainer_snapshot  # NOQA
from .collate import collate_to_array  # NOQA


def _get_version(version):
    # We compare up to the minor version (first two digits).
    # This is because it is highly unlikely that these numbers
    # will contain characters other than digits.

    # The Ignite versioning system is not explicitly documented.
    # However, it seems to be using semver, so the
    # major and minor ids can only be integers.
    # Some examples of versions are:
    # 0.1.0, 0.1.1, 0.3.0.dev20191007, 0.3.0.
    version_regexp = r'^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$'
    if re.search(version_regexp, version):
        return [int(x) for x in version.split('.')[:2]]
    raise ValueError('Invalid version format')


if _get_version(ignite.__version__) < _get_version('0.3.0'):
    raise ImportError('Ignite version found {}. '
                      'Required is >=0.3.0'.format(ignite.__version__))
--------------------------------------------------------------------------------
/chainer_pytorch_migration/ignite/collate.py:
--------------------------------------------------------------------------------
import torch

from chainer_pytorch_migration import tensor


def collate_to_array(batch):
    data = torch.utils.data._utils.collate.default_collate(batch)
    return [tensor.asarray(x) for x in data]
--------------------------------------------------------------------------------
/chainer_pytorch_migration/ignite/extensions.py:
--------------------------------------------------------------------------------
import time
import six
import os

import chainer
import torch

import chainer_pytorch_migration as cpm
from chainer.training import trigger as trigger_module
from ignite.engine import Events, Engine

# Torch computational graph here
# https://gist.github.com/wangg12/f11258583ffcc4728eb71adc0f38e832I
# Make our own extensions for the nocompat ones and
# reroute them?


"""
Currently supported extensions

+ExponentialShift (optimizer must be None)
+FailOnNonNumber
+InverseShift (optimizer must be None)
+LinearShift (optimizer must be None)
+LogReport
+MicroAverage (must be registered before LogReport)
+MultistepShift
+ParameterStatistics
+PlotReport
+PolynomialShift (optimizer must be None)
+PrintReport
+ProgressBar
+SnapshotWriters
+StepShift (optimizer must be None)
+observe_lr (observe_value)
+VariableStatisticsPlot
+WarmupShift


Not working

+ComputationalGraph
+Evaluator
+unchain_variables
"""


# In case multiple engines are used?
engines = {}


def add_trainer_extension(engine, optimizer, extension,
                          name=None, trigger=None, priority=None, **kwargs):
    """Function to register a `chainer.training.Extension` with an
    `ignite.Engine` trainer.

    Args:
        engine (:class:`ignite.Engine`): The ignite trainer object to which
            the extension will be associated.
        optimizer (:class:`torch.optim.Optimizer`): A torch optimizer
            with its ``.target`` attribute set to the model.
        extension (:class:`chainer.training.Extension`): Chainer extension
            to be executed.
        trigger (tuple or Trigger): Trigger object that determines when to
            invoke the extension. If it is ``None``, ``extension.trigger``
            is used instead. If it is ``None`` and the extension does not
            have the trigger attribute, the extension is triggered at every
            iteration by default. If the trigger is not callable, it is
            passed to :class:`IntervalTrigger` to build an interval
            trigger.
        priority (int): Invocation priority of the extension. Extensions
            are invoked in the descending order of priorities in each
            iteration. If this is ``None``, ``extension.priority`` is used
            instead.

    """
    if isinstance(extension, (
            chainer.training.extensions.DumpGraph,
            chainer.training.extensions.Evaluator,
            chainer.training.extensions.unchain_variables)):
        raise ValueError('Extension {} is not supported in cpm'.format(
            extension.__class__.__name__))
    if id(engine) not in engines:
        engines[id(engine)] = ExtensionTrainerAdapter(engine, optimizer)

    adapter = engines[id(engine)]
    adapter._chainer_trainer.extend(extension, name, trigger, priority,
                                    **kwargs)


def load_chainer_snapshot(engine, optimizer, snapshot_file,
                          snapshot_file_torch=None):
    """Function to load a combined torch/chainer snapshot
    through the cpm interface.

    Args:
        engine (:class:`ignite.Engine`): The ignite trainer object to which
            the extension will be associated.
        optimizer (:class:`torch.optim.Optimizer`): A torch optimizer
            with its ``.target`` attribute set to the model.
        snapshot_file (str or file-like): Target chainer snapshot
            obtained with the `chainer.extensions.snapshot` extension
            through the cpm tools.
        snapshot_file_torch (str or file-like): Target torch snapshot.
            If not given, the torch data is loaded from
            "`snapshot_file`-torch".

    """
    if id(engine) not in engines:
        engines[id(engine)] = ExtensionTrainerAdapter(engine, optimizer)

    adapter = engines[id(engine)]
    if snapshot_file_torch is None:
        # If the torch snapshot is not given, we pass the filename
        # of the chainer snapshot and let the ExtensionUpdaterAdapter
        # generate the torch snapshot name.
        if isinstance(snapshot_file, six.string_types):
            adapter.torch_snapshot = snapshot_file
        else:
            adapter.torch_snapshot = snapshot_file.name
        adapter.torch_snapshot += '-torch'
    else:
        adapter.torch_snapshot = snapshot_file_torch

    # Need to defer state loading because of some ignite particularities.
    @engine.on(Events.STARTED)
    def set_load_snapshot_on_start(engine):
        chainer.serializers.load_npz(snapshot_file, adapter)


class ExtensionUpdaterAdapter(object):

    """Bridge between the extensions and an `ignite.Engine`.

    Keeps track of the current training status and allows
    the extensions to retrieve it through the same API as
    the Chainer updaters.

    """
    def __init__(self, engine, optimizer):
        self.engine = engine
        self._optimizers = {'main': optimizer}

    @property
    def iteration(self):
        return self.engine.state.iteration

    @property
    def epoch(self):
        return self.engine.state.epoch - 1

    @property
    def epoch_detail(self):
        epoch_size = len(self.engine.state.dataloader)
        return self.iteration / epoch_size

    def get_optimizer(self, name):
        return self._optimizers[name]

    def get_all_optimizers(self):
        return self._optimizers

    def connect_trainer(self, trainer):
        pass

    def serialize(self, serializer, state):
        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.serialize(serializer['optimizer:' + name], state)

        if isinstance(serializer, chainer.serializer.Serializer):
            state['iteration'] = self.engine.state.iteration
            state['epoch_length'] = self.engine.state.epoch_length
        elif isinstance(serializer, chainer.serializer.Deserializer):
            self.engine.state.iteration = state['iteration']
            self.engine.state.epoch = (state['iteration']
                                       // state['epoch_length'])


class ExtensionTrainerAdapter(object):

    """Bridge between the extensions and an `ignite.Engine`.

    Manages the extensions by using a dummy `chainer.training.Trainer`.
    It provides a chainer Trainer compatible API so that the extensions
    can interact with the `ignite.Engine` without modifications.

    This class registers several handlers on ignite and forces the order
    of handler execution so that user defined ignite handlers are executed
    before the chainer extensions.

    """
    def __init__(self, engine, optimizer):

        self.engine = engine
        engine.run = self.pre_run

        self.optimizer = ExtensionOptimizerAdapter(optimizer)
        self.updater = ExtensionUpdaterAdapter(engine, self.optimizer)

        self.max_epochs = 0
        self.stop_trigger = None
        self.observation = {}
        self.cm = None

        self._start_time = 0
        self.out = getattr(engine, 'out', 'result')
        if not os.path.exists(self.out):
            os.makedirs(self.out)
        self.snapshot_file = None

        # We hold a chainer.Trainer dummy object to deal with all
        # the extension registration mechanism and reporter population.
        self._chainer_trainer = chainer.training.Trainer(self.updater)
        # The reporter has several observers (links) associated with it.
        self.reporter = self._chainer_trainer.reporter

        self.set_ignite_handlers()

    @property
    def is_before_training(self):
        return self.updater.iteration == 0

    @property
    def elapsed_time(self):
        return time.time() - self._start_time

    def set_ignite_handlers(self):

        # Set a handler that sets the reporter scope on every iteration.
        @self.engine.on(Events.ITERATION_STARTED)
        def set_reporter_on_iter(engine):
            self.observation = {}
            self.cm = self.reporter.scope(self.observation)
            self.cm.__enter__()

        @self.engine.on(Events.STARTED)
        def set_training_started(engine):
            # self._is_before_training = True
            self._start_time = time.time()
            self.start_extensions()

        # Make all the following handlers execute
        # after the user defined ones.
        @self.engine.on(Events.ITERATION_COMPLETED)
        def run_extensions_on_iter(engine):
            self.run_extensions()

        # This should be the last extension to be run.
        @self.engine.on(Events.ITERATION_COMPLETED)
        def close_reporter_on_iter(engine):
            self.cm.__exit__(None, None, None)

    def start_extensions(self):
        exts = self._chainer_trainer._extensions
        extension_order = sorted(
            exts.keys(),
            key=lambda name: exts[name].priority, reverse=True)
        self.extensions = [(name, exts[name])
                           for name in extension_order]

        # Invoke the initializer of each extension.
        for _, entry in self.extensions:
            initializer = getattr(entry.extension, 'initialize', None)
            finished = getattr(entry.trigger, 'finished', False)
            if initializer and not finished:
                initializer(self)

        # Call extensions before the training loop.
        self.observation = {}
        if chainer.__version__ > "7.0.0b2":
            # call_before_training only works after 7.0.0b3
            with self.reporter.scope(self.observation):
                for name, entry in self.extensions:
                    if entry.call_before_training:
                        entry.extension(self)

    def run_extensions(self):
        for name, entry in self.extensions:
            if entry.trigger(self):
                ext = entry.extension
                self.cur_ext = (name, ext)
                entry.extension(self)

    def extend(self, extension):
        self.extensions.append(extension)

    def get_extension(self, class_name):
        return self._chainer_trainer.get_extension(class_name)

    def pre_run(self, data, max_epochs=1):
        # Method interception to capture max_epochs,
        # which is never saved in the Engine class.
        self.max_epochs = max_epochs
        self.stop_trigger = trigger_module.get_trigger((max_epochs, 'epoch'))
        Engine.run(self.engine, data, max_epochs)

    def serialize(self, serializer):

        # Save torch objects using the torch interface.
        if isinstance(serializer, chainer.serializer.Serializer):
            name, ext = self.cur_ext
            if type(ext).__name__ == '_MultiNodeSnapshot':
                ext = ext.snapshot
            snap_path = ext.filename.format(self)
            snap_path = os.path.join(self.out, snap_path + '-torch')
            state = {'updater': {}}
            self.updater.serialize(serializer['updater'], state['updater'])
            torch.save(state, snap_path)
        elif isinstance(serializer, chainer.serializer.Deserializer):
            state = torch.load(self.torch_snapshot)
            self.updater.serialize(serializer['updater'], state['updater'])

        if hasattr(self.stop_trigger, 'serialize'):
            self.stop_trigger.serialize(serializer['stop_trigger'])

        s = serializer['extensions']
        t = serializer['extension_triggers']
        for name, entry in six.iteritems(self._chainer_trainer._extensions):
            if hasattr(entry.extension, 'serialize'):
                entry.extension.serialize(s[name])
            if hasattr(entry.trigger, 'serialize'):
                entry.trigger.serialize(t[name])


class ExtensionOptimizerAdapter(object):

    """Adapts the torch optimizer interface to the chainer one
    so that the extensions are compatible.

    It only accesses the first of the optimizer's param_groups.
    TODO(ecastill): support multiple param_groups.
    """
    def __init__(self, optimizer):
        self.optimizer = optimizer
        # Torch doesn't track the optimizer params
        # until calculations are performed,
        # so we need to delay the conversion.
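        # (If ``optimizer.target`` is already a chainer.Link it is used
        # as-is; otherwise the torch module is wrapped below with
        # ``cpm.TorchModule`` so that extensions keep seeing a
        # chainer.Link through ``optimizer.target``.)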
        self.torch_model = None
        if isinstance(optimizer.target, chainer.Link):
            self.target = optimizer.target
        else:
            self.torch_model = optimizer.target
            self.target = cpm.TorchModule(optimizer.target)
            # There is no API in torch to know whether a model is on cuda.
            param_tensor = next(optimizer.target.parameters())
            if param_tensor.is_cuda:
                self.target.to_gpu(param_tensor.device.index)

    def __getattr__(self, item):
        if item not in ('target', 'optimizer', 'torch_model'):
            return self.optimizer.param_groups[0][item]
        return super().__getattribute__(item)

    def __setattr__(self, item, value):
        if item not in ('target', 'optimizer', 'torch_model'):
            self.optimizer.param_groups[0][item] = value
        else:
            super().__setattr__(item, value)

    def serialize(self, serializer, state):
        model_is_torch = self.torch_model is not None
        optimizer_is_torch = isinstance(self.optimizer, torch.optim.Optimizer)
        # If the model or the optimizer is a Chainer object,
        # use the Chainer serializers.
        if not model_is_torch:
            self.target.serialize(serializer['model'])
        if not optimizer_is_torch:
            self.optimizer.serialize(serializer['optimizer'])

        if isinstance(serializer, chainer.serializer.Serializer):
            if optimizer_is_torch:
                state['optimizer'] = self.optimizer.state_dict()
            if model_is_torch:
                state['model'] = self.torch_model.state_dict()
        elif isinstance(serializer, chainer.serializer.Deserializer):
            if optimizer_is_torch:
                self.optimizer.load_state_dict(state['optimizer'])
            if model_is_torch:
                self.torch_model.load_state_dict(state['model'])
--------------------------------------------------------------------------------
/chainer_pytorch_migration/links.py:
--------------------------------------------------------------------------------
import chainer
from chainer.backends import cuda

from chainer_pytorch_migration import tensor


class TorchModule(chainer.Chain):

    """Chain that wraps a PyTorch module.

    ``TorchModule`` wraps a :class:`torch.nn.Module` object with the
    :class:`chainer.Link` interface. The module hierarchy is reproduced as a
    link hierarchy. The parameters and persistent values of each link are
    views of the parameters and buffers of the corresponding module.

    This class does not provide a ``forward`` implementation. To perform
    forward (and backward) propagations, use :attr:`module` directly. When
    backprop is performed to compute the gradients, the gradient with
    respect to each parameter is automatically reflected to the
    corresponding link parameter.

    .. note::
        For device transfer, only :meth:`to_cpu` and :meth:`to_gpu` keep
        track of the mapping. Currently, :meth:`to_device` breaks the
        mapping and makes the module and link diverge from each other.

    """
    def __init__(self, module):
        super().__init__()
        self._module = module

        with self.init_scope():
            for name, child in module.named_children():
                if name == 'module':
                    # DataParallel objects have the model stored as
                    # `module`, causing a conflict.
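                    # For example, ``torch.nn.DataParallel(linear)``
                    # exposes ``linear`` as ``.module``; on the Chainer
                    # side it is re-exposed as ``.wrapped_module``.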
                    name = 'wrapped_module'
                setattr(self, name, TorchModule(child))
            for name, param in module.named_parameters(recurse=False):
                ch_param = chainer.Parameter(tensor.asarray(param))
                setattr(self, name, ch_param)
                # The gradient computed on the PyTorch side is
                # automatically synchronized to the Chainer side
                # with this hook.
                param.register_hook(_get_grad_setter(ch_param))
            for name, buffer in module.named_buffers(recurse=False):
                self.add_persistent(name, tensor.asarray(buffer))

    @property
    def module(self):
        """PyTorch module that this object wraps."""
        return self._module

    # TODO(beam2d): Fix Chainer to enable TorchModule to override to_device.
    def to_cpu(self):
        self.module.cpu()
        self._sync_from_torch()

        # This super call does not transfer the parameters and arrays, but is
        # needed to correctly change the metadata.
        super().to_cpu()

    def to_gpu(self, device=None):
        self.module.cuda(cuda.cupy.cuda.Device(device).id)
        self._sync_from_torch()

        # This super call does not transfer the parameters and arrays, but is
        # needed to correctly change the metadata.
        super().to_gpu(device)

    def _sync_from_torch(self):
        for child in self.children():
            child._sync_from_torch()
        for name, param in self.module.named_parameters(recurse=False):
            getattr(self, name).array = tensor.asarray(param)
        for name, buffer in self.module.named_buffers(recurse=False):
            setattr(self, name, tensor.asarray(buffer))


def _get_grad_setter(param):
    def hook(grad):
        param.grad = tensor.asarray(grad)
    return hook
--------------------------------------------------------------------------------
/chainer_pytorch_migration/parameter.py:
--------------------------------------------------------------------------------
import chainer
import torch

import chainer_pytorch_migration as cpm


def _named_children(link):
    assert isinstance(link, chainer.Link)
    if isinstance(link, chainer.Chain):
        for name in link._children:
            yield name, getattr(link, name)


def _named_params(link):
    assert isinstance(link, chainer.Link)
    for name in link._params:
        yield name, getattr(link, name)


# Corresponding to ``torch._C._nn._parse_to``.
def _parse_to(*args, device=None, dtype=None, non_blocking=False):
    args = list(args)

    if len(args) > 0:
        if isinstance(args[0], torch.Tensor):
            tensor = args.pop(0)
            device = tensor.device
            dtype = tensor.dtype
        elif isinstance(args[0], torch.dtype):
            dtype = args.pop(0)
        elif isinstance(args[0], (str, torch.device)):
            device = args.pop(0)
            if len(args) > 0 and isinstance(args[0], torch.dtype):
                dtype = args.pop(0)
        else:
            raise TypeError('Received an invalid combination of arguments.')

    if len(args) > 0:
        non_blocking = bool(args.pop(0))

    if len(args) > 0:
        raise TypeError('Received an invalid combination of arguments.')

    if device is not None:
        device = torch.device(device)

    return device, dtype, non_blocking


def _setattr_recursive(obj, name, value):
    attr_list = name.split('.')
    for attr in attr_list[:-1]:
        obj = getattr(obj, attr)
    setattr(obj, attr_list[-1], value)


class LinkAsTorchModel(torch.nn.Module):

    '''Converts a Chainer Link to a PyTorch module.

    The parameters of the link are automatically
    wrapped using `ChainerParameter` and added
    to the module as its own parameters.

    Args:
        link (:class:`chainer.Link`): A link. Must have been initialized.
    '''

    def __init__(self, link, **kwargs):
        super().__init__()
        device = kwargs.pop('_device', None)
        uninitialized_params = [
            n for n, p in sorted(_named_params(link)) if p.array is None]
        if uninitialized_params:
            raise RuntimeError(
                'A Link with uninitialized parameters cannot be wrapped '
                'with LinkAsTorchModel. '
                'Please initialize the parameters before wrapping, e.g. by '
                'feeding a dummy batch to the Chainer model. '
                'Uninitialized params: [{}]'.format(
                    ', '.join(repr(n) for n in uninitialized_params)))

        for name, child in _named_children(link):
            child_module = LinkAsTorchModel(child, _device=device)
            setattr(self, name, child_module)
        for name, param in sorted(_named_params(link)):
            if device is not None:
                param.to_device(device)
            setattr(self, name, ChainerParameter(param))

        self.link = link

    def forward(self, *input):
        # The computational graph should be built by Chainer.
        # ``forward`` converts the input tensors to numpy/cupy arrays
        # as accepted by Chainer.
        # The return value should be a tensor as well.
        input = [cpm.tensor.asarray(x) if isinstance(x, torch.Tensor)
                 else x for x in input]
        outputs = self.link.forward(*input)
        ret = self.__as_tensor(outputs)
        return ret

    def __as_tensor(self, value):
        if isinstance(value, tuple):
            return tuple(self.__as_tensor(x) for x in value)
        if isinstance(value, list):
            return [self.__as_tensor(x) for x in value]
        if isinstance(value, chainer.Variable):
            return _ChainerTensor(value)
        return value

    def to(self, *args, **kwargs):
        device, dtype, non_blocking = _parse_to(*args, **kwargs)
        chainer_device = cpm.to_chainer_device(device)
        if dtype is not None:
            raise NotImplementedError
        if non_blocking:
            raise NotImplementedError
        for name, value in self.named_parameters():
            assert isinstance(value, ChainerParameter)
            param = value._param
            param.to_device(chainer_device)
            value = ChainerParameter(param)
            _setattr_recursive(self, name, value)
        return self


class Optimizer(torch.optim.Optimizer):

    def __init__(self, base_optimizer):
        assert isinstance(base_optimizer, torch.optim.Optimizer)
        super().__init__(base_optimizer.param_groups, base_optimizer.defaults)
        self._base_optimizer = base_optimizer

    def __getattr__(self, name):
        if name in ('step', 'zero_grad', '_base_optimizer'):
            return object.__getattribute__(self, name)
        return getattr(self._base_optimizer, name)

    def step(self, closure=None):
        for param_group in self._base_optimizer.param_groups:
            for param in param_group['params']:
                assert isinstance(param, ChainerParameter)
                param.grad.copy_(cpm.astensor(param._param.grad))
        self._base_optimizer.step(closure)

    def zero_grad(self):
        self._base_optimizer.zero_grad()
        for param_group in self._base_optimizer.param_groups:
            for param in param_group['params']:
                assert isinstance(param, ChainerParameter)
                param._param.zerograd()


class _ChainerTensor(torch.Tensor):
    '''A torch tensor view of a Chainer variable, through which backprop
    can be performed on the Chainer side.
    '''
    def __new__(cls, variable):
        assert isinstance(variable, chainer.Variable)
        obj = cpm.astensor(variable.array)
        obj.__class__ = cls
        return obj

    def __init__(self, variable):
        self._variable = variable

    def backward(self, gradient=None, retain_graph=None, create_graph=False):
        # ``retain_graph=True`` is not supported.
        assert retain_graph is None or retain_graph is False
        assert self._variable is not None

        var = self._variable
        if gradient is not None:
            var.grad = cpm.tensor.asarray(gradient)
        var.backward(
            enable_double_backprop=create_graph,
        )

    def zero_(self):
        super().zero_()
        self._variable.array[...] = 0


class ChainerParameter(torch.nn.Parameter):

    '''Wraps a Chainer parameter for use with a PyTorch optimizer.

    It is used to share the data, and more importantly, the gradient memory
    buffer between Chainer and PyTorch, since :class:`chainer.Parameter.grad`
    may be reassigned a new buffer after each backward. Computational graphs
    must be constructed and backpropagated through on the Chainer side.

    Args:
        param (:class:`chainer.Parameter`): A parameter to convert.
    Returns:
        A :class:`ChainerParameter`.
    '''

    __grad = None

    def __new__(cls, param):
        return super().__new__(cls, cpm.astensor(param.array))

    def __init__(self, param):
        super().__init__()
        self._param = param

    @property
    def grad(self):
        if self.__grad is None:
            if self._param.grad is not None:
                self.__grad = _ChainerTensor(self._param.grad_var)
        return self.__grad

    @grad.setter
    def grad(self, g):
        if self._param.grad is not None:
            self.grad[...] = g
        else:
            self._param.grad = cpm.asarray(g)

    def zero_(self):
        super().zero_()
        self._param.cleargrad()
        self._param.array[...] = 0
        self.__grad = None
--------------------------------------------------------------------------------
/chainer_pytorch_migration/tensor.py:
--------------------------------------------------------------------------------
from chainer.backends import cuda
import numpy
import torch


def asarray(tensor):
    """Create an ndarray view of a given tensor.

    Args:
        tensor (torch.Tensor): Tensor to be converted.

    Returns:
        An ndarray view of ``tensor``. The returned array shares the
        underlying buffer with ``tensor``. The ownership is also shared, so
        the buffer is released only after both the original tensor and the
        returned ndarray view are gone.
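
    Example (a minimal CPU sketch; since the array is a view, writes are
    visible on the torch side as well):

        >>> t = torch.arange(3, dtype=torch.float32)
        >>> a = asarray(t)
        >>> a[0] = 7
        >>> float(t[0])
        7.0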

    """
    dev_type = tensor.device.type
    if dev_type == 'cuda':
        dev_id = tensor.device.index
        cupy = cuda.cupy
        with cupy.cuda.Device(dev_id):
            # If the tensor is not allocated in torch (empty),
            # we just create a new one.
            if tensor.data_ptr() == 0:
                return cupy.ndarray(
                    tuple(tensor.shape),
                    dtype=to_numpy_dtype(tensor.dtype))
            itemsize = tensor.element_size()
            storage = tensor.storage()
            memptr = cupy.cuda.MemoryPointer(
                cupy.cuda.UnownedMemory(
                    storage.data_ptr(), storage.size() * itemsize, tensor,
                ),
                tensor.storage_offset() * itemsize,
            )
            return cupy.ndarray(
                tuple(tensor.shape),
                dtype=to_numpy_dtype(tensor.dtype),
                memptr=memptr,
                strides=tuple(s * itemsize for s in tensor.stride()),
            )
    if dev_type == 'cpu':
        return tensor.detach().numpy()
    raise ValueError('tensor on device "{}" is not supported'.format(
        dev_type))


def astensor(array):
    """Create a tensor view of a given ndarray.

    Args:
        array (numpy.ndarray or cupy.ndarray): Source array to make a view of.

    Returns:
        A :class:`torch.Tensor` view of ``array``. The returned tensor shares
        the buffer with ``array``. The ownership is also shared, so the buffer
        is released only after both the original array and the returned tensor
        view are gone.

    Note:
        If the array has negative strides, a copy is made.
    """
    if array is None:
        raise TypeError('array cannot be None')

    # Torch does not support negative strides, so make an implicit copy of
    # the array in such a case.
    if any(s < 0 for s in array.strides):
        array = array.copy()
    if isinstance(array, cuda.ndarray):
        # If the array is not allocated (empty),
        # we just create a new one.
        if array.data.ptr == 0:
            return torch.empty(
                array.shape,
                dtype=to_torch_dtype(array.dtype),
                device=array.device.id
            )
        return torch.as_tensor(
            _ArrayWithCudaArrayInterfaceHavingStrides(array),
            device=array.device.id,
        )
    if isinstance(array, numpy.ndarray):
        return torch.from_numpy(array)
    raise TypeError('array of type {} is not supported'.format(type(array)))


# Workaround to avoid a bug in converting cupy.ndarray to torch.Tensor via
# __cuda_array_interface__. See: https://github.com/pytorch/pytorch/pull/24947
class _ArrayWithCudaArrayInterfaceHavingStrides:

    def __init__(self, array):
        self._array = array

    @property
    def __cuda_array_interface__(self):
        d = self._array.__cuda_array_interface__
        d['strides'] = self._array.strides
        return d


def to_numpy_dtype(torch_dtype):
    """Convert a PyTorch dtype to a NumPy dtype.

    Args:
        torch_dtype: PyTorch's dtype object.

    Returns:
        NumPy type object.

    """
    numpy_dtype = _torch_dtype_mapping.get(torch_dtype, None)
    if numpy_dtype is None:
        raise TypeError('{} does not have a corresponding numpy dtype'.format(
            torch_dtype
        ))
    return numpy_dtype


def to_torch_dtype(numpy_dtype):
    """Convert a NumPy dtype to a PyTorch dtype.

    Args:
        numpy_dtype: NumPy's dtype object.

    Returns:
        PyTorch type object.
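
    Example (derived from the mapping at the bottom of this file)::

        >>> to_torch_dtype(numpy.dtype('float32'))
        torch.float32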

    """
    torch_dtype = _numpy_dtype_mapping.get(numpy_dtype, None)
    if torch_dtype is None:
        raise TypeError('{} does not have a corresponding torch dtype'.format(
            numpy_dtype
        ))
    return torch_dtype


_torch_dtype_mapping = {
    torch.bool: numpy.dtype('bool'),
    torch.uint8: numpy.dtype('uint8'),
    torch.int8: numpy.dtype('int8'),
    torch.int16: numpy.dtype('int16'),
    torch.int32: numpy.dtype('int32'),
    torch.int64: numpy.dtype('int64'),
    torch.float16: numpy.dtype('float16'),
    torch.float32: numpy.dtype('float32'),
    torch.float64: numpy.dtype('float64'),
    # Note: numpy does not have complex32.
    # torch.complex32: numpy.dtype('complex32'),
    torch.complex64: numpy.dtype('complex64'),
    torch.complex128: numpy.dtype('complex128'),
}

_numpy_dtype_mapping = {v: k for k, v in _torch_dtype_mapping.items()}
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools


setuptools.setup(
    name='chainer_pytorch_migration',
    description='Chainer/PyTorch Migration Library',
    license='MIT License',
    version='0.0.2',
    install_requires=['chainer', 'numpy', 'torch'],
    extras_require={'test': ['pytest']},
    packages=[
        'chainer_pytorch_migration',
        'chainer_pytorch_migration.ignite',
        'chainer_pytorch_migration.chainermn',
    ],
)
--------------------------------------------------------------------------------
/tests/test_collate.py:
--------------------------------------------------------------------------------
import numpy
import torch
import pytest

import chainer_pytorch_migration.ignite as cpm_ignite


@pytest.mark.parametrize(
    'data, batch_size',
    [([], 1), (list(range(10)), 1), (list(range(100)), 10)])
def test_collate(data, batch_size):
    collate = cpm_ignite.collate_to_array
    dl = torch.utils.data.DataLoader(
        data, collate_fn=collate, batch_size=batch_size)
    for i, x in enumerate(dl):
        for e in x:
            assert isinstance(e, numpy.ndarray)
        expected = [
            numpy.array(e) for e in data[i * batch_size:(i + 1) * batch_size]]
        numpy.testing.assert_array_equal(x, expected)
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
import numpy

import chainer
import chainer_pytorch_migration as cpm


def _transform(in_data):
    img, label = in_data
    img = img - 0.5  # scale to [-0.5, 0.5]
    return img, label


def test_transform_dataset():
    dataset, _ = chainer.datasets.get_mnist()

    chainer_dataset = chainer.datasets.TransformDataset(dataset, _transform)
    cpm_dataset = cpm.TransformDataset(dataset, _transform)

    assert len(chainer_dataset) == len(cpm_dataset)

    for x, y in zip(chainer_dataset, cpm_dataset):
        x_data, x_label = x
        y_data, y_label = y
        numpy.testing.assert_array_equal(x_data, y_data)
        assert x_data.dtype == y_data.dtype
        assert x_label == y_label
--------------------------------------------------------------------------------
/tests/test_device.py:
--------------------------------------------------------------------------------
import chainer
import torch

import chainer_pytorch_migration as cpm


def test_to_chainer_device_cpu():
    device = torch.device('cpu')
    chainer_device = cpm.to_chainer_device(device)
    assert chainer_device.name == '@numpy'

def test_to_chainer_device_gpu():
    device = torch.device('cuda')
    chainer_device = cpm.to_chainer_device(device)
    assert chainer_device.name == '@cupy:0'

def test_to_chainer_device_gpu_0():
    device = torch.device('cuda:0')
    chainer_device = cpm.to_chainer_device(device)
    assert chainer_device.name == '@cupy:0'

def test_to_chainer_device_gpu_1():
    device = torch.device('cuda:1')
    chainer_device = cpm.to_chainer_device(device)
    assert chainer_device.name == '@cupy:1'

def test_to_torch_device_cpu():
    device = chainer.get_device('@numpy')
    torch_device = cpm.to_torch_device(device)
    assert torch_device.type == 'cpu'

def test_to_torch_device_gpu():
    device = chainer.get_device('@cupy:0')
    torch_device = cpm.to_torch_device(device)
    assert torch_device.type == 'cuda'
    assert torch_device.index == 0

def test_to_torch_device_gpu_1():
    device = chainer.get_device('@cupy:1')
    torch_device = cpm.to_torch_device(device)
    assert torch_device.type == 'cuda'
    assert torch_device.index == 1
--------------------------------------------------------------------------------
/tests/test_extensions.py:
--------------------------------------------------------------------------------
import os
import tempfile
import unittest

import chainer
import ignite
import torch

from ignite.engine import Events

import chainer_pytorch_migration.ignite
import chainer_pytorch_migration as cpm


def test_chainer_extensions():

    count = 0

    def dummy_extension(trainer):
        nonlocal count
        count += 1

    engine = ignite.engine.Engine(lambda engine, x: [])
    # We just create dummy models, as we won't be using them.
    # We only want the training loop to call our extension.
    model = torch.nn.Linear(128, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    optimizer.target = model
    cpm.ignite.add_trainer_extension(engine, optimizer, dummy_extension)
    engine.run([1, 2, 3], max_epochs=1)
    assert count == 3


class SnapshotBaseMixin(object):

    def setUp(self):
        self.engine = ignite.engine.Engine(lambda engine, x: [])
        self.engine.out = '.'
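        # The adapter reads ``engine.out`` (defaulting to 'result') as the
        # directory where the snapshot extensions write their files.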
39 | self.model = torch.nn.Linear(128, 1) 40 | self.optimizer = torch.optim.SGD( 41 | self.model.parameters(), lr=1.0, momentum=0.5) 42 | self.optimizer.target = self.model 43 | w = chainer.training.extensions.snapshot_writers.SimpleWriter() 44 | snapshot = chainer.training.extensions.snapshot(writer=w) 45 | 46 | cpm.ignite.add_trainer_extension( 47 | self.engine, self.optimizer, snapshot, trigger=(1, 'epoch')) 48 | 49 | def tearDown(self): 50 | snapshot_ch = os.path.join(self.engine.out, 'snapshot_iter_3') 51 | snapshot_pt = os.path.join(self.engine.out, 'snapshot_iter_3-torch') 52 | if os.path.exists(snapshot_ch): 53 | os.remove(snapshot_ch) 54 | if os.path.exists(snapshot_pt): 55 | os.remove(snapshot_pt) 56 | 57 | 58 | class TestSnapshotSaveFile(SnapshotBaseMixin, unittest.TestCase): 59 | 60 | def test_save_file(self): 61 | self.engine.run([1, 2, 3], max_epochs=1) 62 | snapshot_ch = os.path.join(self.engine.out, 'snapshot_iter_3') 63 | snapshot_pt = os.path.join(self.engine.out, 'snapshot_iter_3-torch') 64 | assert os.path.exists(snapshot_ch) 65 | assert os.path.exists(snapshot_pt) 66 | 67 | 68 | def compare_state_dicts(d1, d2): 69 | if type(d1) != type(d2): 70 | return False 71 | if isinstance(d1, torch.Tensor): 72 | return torch.equal(d1, d2) 73 | if type(d1) is dict: 74 | # Params just hold pointers, should not be restored 75 | return (d1.keys() == d2.keys()) and all([ 76 | compare_state_dicts(d1[k], d2[k]) for k in d1 if k != 'params']) 77 | if type(d1) is list: 78 | return len(d1) == len(d2) and all( 79 | [compare_state_dicts(l1, l2) for l1, l2 in zip(d1, d2)]) 80 | return d1 == d2 81 | 82 | 83 | class TestSnapshotLoadFile(SnapshotBaseMixin, unittest.TestCase): 84 | 85 | def verify_snapshot_on_start(self, engine, model, optimizer): 86 | assert engine.state.epoch == self.engine.state.epoch 87 | assert engine.state.iteration == self.engine.state.iteration 88 | for p1, p2 in zip(model.parameters(), self.model.parameters()): 89 | assert torch.equal(p1, p2) 90 | assert compare_state_dicts( 91 | optimizer.state_dict(), self.optimizer.state_dict()) 92 | 93 | def setup_model(self): 94 | self.engine.run([1, 2, 3], max_epochs=1) 95 | self.snapshot_ch = os.path.join(self.engine.out, 96 | 'snapshot_iter_3') 97 | self.snapshot_pt = os.path.join(self.engine.out, 98 | 'snapshot_iter_3-torch') 99 | assert os.path.exists(self.snapshot_ch) 100 | assert os.path.exists(self.snapshot_pt) 101 | 102 | # Create a new trainer, load the state and compare model and optimizer 103 | # params 104 | engine = ignite.engine.Engine(lambda engine, x: []) 105 | model = torch.nn.Linear(128, 1) 106 | optimizer = torch.optim.SGD(model.parameters(), lr=1.0) 107 | optimizer.target = model 108 | return engine, model, optimizer 109 | 110 | def test_load_file(self): 111 | engine, model, optimizer = self.setup_model() 112 | cpm.ignite.load_chainer_snapshot(engine, optimizer, 'snapshot_iter_3') 113 | # Need to defer state loading because of some ignite particularities 114 | engine.add_event_handler(Events.STARTED, self.verify_snapshot_on_start, 115 | model=model, optimizer=optimizer) 116 | engine.run([1, 2, 3], max_epochs=2) 117 | 118 | def test_load_file_torch(self): 119 | engine, model, optimizer = self.setup_model() 120 | cpm.ignite.load_chainer_snapshot(engine, optimizer, 121 | self.snapshot_ch, self.snapshot_pt) 122 | # Need to defer state loading because of some ignite particularities 123 | engine.add_event_handler(Events.STARTED, self.verify_snapshot_on_start, 124 | model=model, optimizer=optimizer) 125 | engine.run([1, 
2, 3], max_epochs=2) 126 | 127 | def test_load_single_file_obj(self): 128 | engine, model, optimizer = self.setup_model() 129 | # Need to defer state loading because of some ignite particularities 130 | with open(self.snapshot_ch, "rb") as f: 131 | cpm.ignite.load_chainer_snapshot(engine, optimizer, f) 132 | engine.add_event_handler(Events.STARTED, 133 | self.verify_snapshot_on_start, 134 | model=model, optimizer=optimizer) 135 | engine.run([1, 2, 3], max_epochs=2) 136 | 137 | def test_load_both_file_obj(self): 138 | engine, model, optimizer = self.setup_model() 139 | # Need to defer state loading because of some ignite particularities 140 | with open(self.snapshot_ch, "rb") as f_ch: 141 | with open(self.snapshot_pt, "rb") as f_pt: 142 | cpm.ignite.load_chainer_snapshot(engine, optimizer, 143 | f_ch, f_pt) 144 | engine.add_event_handler(Events.STARTED, 145 | self.verify_snapshot_on_start, 146 | model=model, optimizer=optimizer) 147 | engine.run([1, 2, 3], max_epochs=2) 148 | 149 | 150 | class TestResumeTrain(object): 151 | 152 | def create_trainer(self, out_dir): 153 | device = torch.device('cpu') 154 | model = torch.nn.Linear(3, 1).to(device) 155 | X = torch.randn(100, 3) 156 | y = torch.randint(high=1, size=(100,)).to(torch.int64) 157 | dataset = torch.utils.data.TensorDataset(X, y) 158 | optimizer = torch.optim.Adam(model.parameters()) 159 | trainer = ignite.engine.create_supervised_trainer( 160 | model, optimizer, torch.nn.functional.nll_loss, device=device) 161 | optimizer.target = model 162 | trainer.out = out_dir 163 | snapshot = chainer.training.extensions.snapshot( 164 | filename='snapshot_iter-{.updater.iteration}') 165 | cpm.ignite.add_trainer_extension( 166 | trainer, optimizer, snapshot, trigger=(1, 'iteration')) 167 | return trainer, optimizer, dataset 168 | 169 | def test_load_continue(self): 170 | with tempfile.TemporaryDirectory() as out_dir: 171 | trainer, optimizer, dataset = self.create_trainer(out_dir) 172 | train_loader = torch.utils.data.DataLoader( 173 | dataset, shuffle=True, batch_size=10) 174 | trainer.run(train_loader, max_epochs=1) 175 | trainer, optimizer, dataset = self.create_trainer(out_dir) 176 | cpm.ignite.load_chainer_snapshot( 177 | trainer, optimizer, os.path.join(out_dir, 'snapshot_iter-5')) 178 | 179 | # Needs to defer the assert because of the delayed snapshot load. 180 | @trainer.on(ignite.engine.Events.STARTED) 181 | def assert_iteration_count(engine): 182 | assert engine.state.iteration == 5 183 | 184 | trainer.run(train_loader, max_epochs=3) 185 | assert trainer.state.iteration == 30 186 | -------------------------------------------------------------------------------- /tests/test_links.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import torch 3 | import numpy 4 | 5 | import chainer_pytorch_migration as cpm 6 | 7 | 8 | def test_to_torch_module(): 9 | model = torch.nn.Linear(3, 1) 10 | model.weight.data = torch.ones(1, 3) 11 | # Conversion 12 | chained = cpm.TorchModule(model) 13 | 14 | assert isinstance(chained.weight, chainer.Variable) 15 | assert isinstance(chained.bias, chainer.Variable) 16 | assert chained.weight.shape == (1, 3) 17 | assert chained.bias.shape == (1,) 18 | assert (chained.weight.array == numpy.ones((1, 3))).all() 19 | 20 | # Test memory sharing 21 | chained.weight.array[...] 
= numpy.arange(3).reshape((1, 3)) 22 | assert (model.weight.data == torch.arange(3).reshape((1, 3))).all() 23 | 24 | 25 | def test_to_torch_module_data_parallel(): 26 | model = torch.nn.Linear(3, 1) 27 | model = torch.nn.DataParallel(model, device_ids=[0]) 28 | model.module.weight.data = torch.ones(1, 3) 29 | # Conversion 30 | chained = cpm.TorchModule(model) 31 | 32 | assert isinstance(chained.wrapped_module.weight, chainer.Variable) 33 | assert isinstance(chained.wrapped_module.bias, chainer.Variable) 34 | assert chained.wrapped_module.weight.shape == (1, 3) 35 | assert chained.wrapped_module.bias.shape == (1,) 36 | assert (chained.wrapped_module.weight.array == numpy.ones((1, 3))).all() 37 | 38 | # # Test memory sharing 39 | chained.wrapped_module.weight.array[...] = numpy.arange(3).reshape((1, 3)) 40 | assert (model.module.weight.data == torch.arange(3).reshape((1, 3))).all() 41 | -------------------------------------------------------------------------------- /tests/test_parameter.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import numpy 3 | import pytest 4 | import torch 5 | 6 | import chainer_pytorch_migration as cpm 7 | 8 | 9 | @pytest.mark.parametrize('shape', [(3, 2), (2, 0, 1)]) 10 | def test_chainer_parameter(shape): 11 | # initialized parameter 12 | arr = numpy.full(shape, 17, 'float32') 13 | chainer_param = chainer.Parameter(arr) 14 | 15 | # Conversion 16 | torch_param = cpm.ChainerParameter(chainer_param) 17 | 18 | assert isinstance(torch_param, torch.nn.Parameter) 19 | assert torch_param.shape == shape 20 | assert (torch_param.data.numpy() == numpy.full(shape, 17, 'float32')).all() 21 | 22 | # Test memory sharing 23 | new_arr = numpy.random.randint(-4, 4, shape) 24 | torch_param.data[...] = torch.tensor(new_arr.copy()) 25 | assert (chainer_param.array == new_arr).all() 26 | 27 | 28 | def test_chainer_parameter_uninitialized(): 29 | # Uninitialized parameters are not supported 30 | chainer_param = chainer.Parameter() 31 | 32 | with pytest.raises(TypeError): 33 | cpm.ChainerParameter(chainer_param) 34 | 35 | 36 | @pytest.mark.parametrize('shape', [(3, 2), (2, 0, 1)]) 37 | def test_chainer_parameter_grad_getter(shape): 38 | arr = numpy.full(shape, 17, 'float32') 39 | grad = numpy.full(shape, 9, 'float32') 40 | chainer_param = chainer.Parameter(arr) 41 | chainer_param.grad = grad.copy() 42 | 43 | # Conversion 44 | torch_param = cpm.ChainerParameter(chainer_param) 45 | 46 | # Getter 47 | torch_grad = torch_param.grad 48 | 49 | assert isinstance(torch_grad, torch.Tensor) 50 | assert (torch_grad.numpy() == grad).all() 51 | 52 | # Test memory sharing 53 | new_arr = numpy.random.randint(-4, 4, shape) 54 | torch_grad[...] 
55 |     assert (chainer_param.grad == new_arr).all()
56 | 
57 | 
58 | @pytest.mark.parametrize('shape', [(3, 2), (2, 0, 1)])
59 | def test_chainer_parameter_grad_setter(shape):
60 |     arr = numpy.full(shape, 17, 'float32')
61 |     chainer_param = chainer.Parameter(arr)
62 | 
63 |     # Conversion
64 |     torch_param = cpm.ChainerParameter(chainer_param)
65 |     # Initialize grad
66 |     torch_param.requires_grad = True
67 |     optimizer = torch.optim.SGD([torch_param], lr=0.01, momentum=0.9)
68 |     optimizer.zero_grad()
69 | 
70 |     # Setter
71 |     grad = torch.full(shape, 9, dtype=torch.float32)
72 |     torch_param.grad = grad
73 |     numpy.testing.assert_array_equal(grad, torch_param.grad)
74 | 
75 | 
76 | def test_link_as_torch_model():
77 |     # initialized parameter
78 |     a_arr = numpy.ones((3, 2), 'float32')
79 |     a_chainer_param = chainer.Parameter(a_arr)
80 |     # 0-size parameter
81 |     b_arr = numpy.ones((2, 0, 1), 'float32')
82 |     b_chainer_param = chainer.Parameter(b_arr)
83 | 
84 |     link = chainer.Link()
85 |     with link.init_scope():
86 |         link.a = a_chainer_param
87 |         link.b = b_chainer_param
88 | 
89 |     # Conversion
90 |     torched = cpm.LinkAsTorchModel(link)
91 |     params = list(torched.parameters())
92 |     assert len(params) == 2
93 |     assert isinstance(params[0], torch.nn.Parameter)
94 |     assert isinstance(params[1], torch.nn.Parameter)
95 |     assert params[0].shape == (3, 2)
96 |     assert params[1].shape == (2, 0, 1)
97 |     assert (params[0].data.numpy() == numpy.ones((3, 2))).all()
98 | 
99 |     # Test memory sharing
100 |     params[0].data[...] = torch.tensor(numpy.arange(6).reshape((3, 2)))
101 |     assert (a_chainer_param.array == numpy.arange(6).reshape((3, 2))).all()
102 | 
103 | 
104 | def test_link_as_torch_model_nested():
105 |     dtype = numpy.float32
106 | 
107 |     # l2: MyLink2 := p2 * l1(x)
108 |     #   - p2
109 |     #   - l1: MyLink1 := p1 * x
110 |     #     - p1
111 |     class MyLink1(chainer.Link):
112 |         def __init__(self):
113 |             super().__init__()
114 |             with self.init_scope():
115 |                 self.p1 = chainer.Parameter(numpy.array([2], dtype))
116 |         def forward(self, x1):
117 |             return self.p1 * x1
118 | 
119 |     class MyLink2(chainer.Chain):
120 |         def __init__(self):
121 |             super().__init__()
122 |             with self.init_scope():
123 |                 self.p2 = chainer.Parameter(numpy.array([3], dtype))
124 |                 self.l1 = MyLink1()
125 |         def forward(self, x2):
126 |             return self.p2 * self.l1(x2)
127 | 
128 |     # Dummy optimizer that always writes a constant value.
129 |     class MyOptim(torch.optim.Optimizer):
130 |         def __init__(self, params):
131 |             super().__init__(params, {})
132 |             self.constant = None
133 |         def set_constant(self, constant):
134 |             self.constant = constant
135 |         def step(self, closure=None):
136 |             for group in self.param_groups:
137 |                 for param in group['params']:
138 |                     param.data[...] = self.constant
139 | 
140 |     link = MyLink2()
141 |     module = cpm.LinkAsTorchModel(link)
142 |     assert isinstance(module.p2, torch.nn.Parameter)
143 |     assert isinstance(module.l1, torch.nn.Module)
144 |     assert isinstance(module.l1.p1, torch.nn.Parameter)
145 |     assert len(list(module.parameters(recurse=False))) == 1
146 |     assert len(list(module.l1.parameters(recurse=False))) == 1
147 |     assert len(list(module.parameters(recurse=True))) == 2
148 | 
149 |     optimizer = cpm.Optimizer(MyOptim(module.parameters()))
150 | 
151 |     #--------------
152 |     # Iteration 1
153 |     #--------------
154 |     x = numpy.array([4], dtype)
155 | 
156 |     ### Forward
157 |     y = module(x)
158 |     assert isinstance(y, torch.Tensor)
159 |     numpy.testing.assert_array_equal(y.detach().numpy(), [24])
160 | 
161 |     ### Backward
162 |     y.backward()
163 | 
164 |     numpy.testing.assert_array_equal(link.l1.p1.grad, [12])
165 |     numpy.testing.assert_array_equal(link.p2.grad, [8])
166 | 
167 |     ### Optimizer step
168 |     optimizer.set_constant(3)
169 |     optimizer.step()
170 | 
171 |     # (Torch grads are only synchronized after step())
172 |     numpy.testing.assert_array_equal(module.l1.p1.grad.detach().numpy(), [12])
173 |     numpy.testing.assert_array_equal(module.p2.grad.detach().numpy(), [8])
174 | 
175 |     numpy.testing.assert_array_equal(link.p2.array, [3])
176 |     numpy.testing.assert_array_equal(link.l1.p1.array, [3])
177 |     numpy.testing.assert_array_equal(module.p2.detach().numpy(), [3])
178 |     numpy.testing.assert_array_equal(module.l1.p1.detach().numpy(), [3])
179 | 
180 |     ### Zero grad
181 |     optimizer.zero_grad()
182 | 
183 |     numpy.testing.assert_array_equal(link.l1.p1.grad, [0])
184 |     numpy.testing.assert_array_equal(link.p2.grad, [0])
185 |     numpy.testing.assert_array_equal(module.l1.p1.grad.detach().numpy(), [0])
186 |     numpy.testing.assert_array_equal(module.p2.grad.detach().numpy(), [0])
187 | 
188 |     #--------------
189 |     # Iteration 2
190 |     #--------------
191 |     x = numpy.array([5], dtype)
192 | 
193 |     ### Forward
194 |     y = module(x)
195 |     assert isinstance(y, torch.Tensor)
196 |     numpy.testing.assert_array_equal(y.detach().numpy(), [45])
197 | 
198 |     ### Backward
199 |     y.backward()
200 | 
201 |     numpy.testing.assert_array_equal(link.l1.p1.grad, [15])
202 |     numpy.testing.assert_array_equal(link.p2.grad, [15])
203 | 
204 |     ### Optimizer step
205 |     optimizer.set_constant(9)
206 |     optimizer.step()
207 | 
208 |     # (Torch grads are only synchronized after step())
209 |     numpy.testing.assert_array_equal(module.l1.p1.grad.detach().numpy(), [15])
210 |     numpy.testing.assert_array_equal(module.p2.grad.detach().numpy(), [15])
211 | 
212 |     numpy.testing.assert_array_equal(link.p2.array, [9])
213 |     numpy.testing.assert_array_equal(link.l1.p1.array, [9])
214 |     numpy.testing.assert_array_equal(module.p2.detach().numpy(), [9])
215 |     numpy.testing.assert_array_equal(module.l1.p1.detach().numpy(), [9])
216 | 
217 |     ### Zero grad
218 |     optimizer.zero_grad()
219 | 
220 |     numpy.testing.assert_array_equal(link.l1.p1.grad, [0])
221 |     numpy.testing.assert_array_equal(link.p2.grad, [0])
222 |     numpy.testing.assert_array_equal(module.l1.p1.grad.detach().numpy(), [0])
223 |     numpy.testing.assert_array_equal(module.p2.grad.detach().numpy(), [0])
224 | 
225 | 
226 | def test_link_as_torch_model_uninitialized():
227 |     # Uninitialized parameters are not supported
228 |     a_chainer_param = chainer.Parameter()
229 | 
230 |     link = chainer.Link()
231 |     with link.init_scope():
232 |         link.a = a_chainer_param
233 | 
234 |     with pytest.raises(RuntimeError):
235 |         torched = cpm.LinkAsTorchModel(link)
236 |         torched.parameters()
237 | 
238 | 
239 | def test_state_dict():
240 |     a_arr = numpy.ones((3, 2), 'float32')
241 |     a_chainer_param = chainer.Parameter(a_arr)
242 |     # 0-size parameter
243 |     b_arr = numpy.ones((2, 0, 1), 'float32')
244 |     b_chainer_param = chainer.Parameter(b_arr)
245 | 
246 |     link = chainer.Link()
247 |     with link.init_scope():
248 |         link.a = a_chainer_param
249 |         link.b = b_chainer_param
250 | 
251 |     torched = cpm.LinkAsTorchModel(link)
252 |     state_dict = torched.state_dict()
253 |     assert 'a' in state_dict
254 |     numpy.testing.assert_array_equal(a_arr, state_dict['a'].detach())
255 |     assert 'b' in state_dict
256 |     numpy.testing.assert_array_equal(b_arr, state_dict['b'].detach())
257 | 
258 | 
259 | def test_named_params():
260 |     a_arr = numpy.ones((3, 2), 'float32')
261 |     a_chainer_param = chainer.Parameter(a_arr)
262 |     # 0-size parameter
263 |     b_arr = numpy.ones((2, 0, 1), 'float32')
264 |     b_chainer_param = chainer.Parameter(b_arr)
265 | 
266 |     link = chainer.Link()
267 |     with link.init_scope():
268 |         link.a = a_chainer_param
269 |         link.b = b_chainer_param
270 | 
271 |     torched = cpm.LinkAsTorchModel(link)
272 |     n_params = dict(torched.named_parameters())
273 |     assert 'a' in n_params
274 |     numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
275 |     assert 'b' in n_params
276 |     numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())
277 | 
278 | 
279 | def test_link_to_device():
280 |     a_arr = numpy.ones((3, 2), 'float32')
281 |     a_chainer_param = chainer.Parameter(a_arr)
282 |     # 0-size parameter
283 |     b_arr = numpy.ones((2, 0, 1), 'float32')
284 |     b_chainer_param = chainer.Parameter(b_arr)
285 | 
286 |     link = chainer.Link()
287 |     with link.init_scope():
288 |         link.a = a_chainer_param
289 |         link.b = b_chainer_param
290 | 
291 |     torched = cpm.LinkAsTorchModel(link)
292 |     ret = torched.to('cuda')
293 | 
294 |     assert torched is ret
295 | 
296 |     for _, param in torched.named_parameters():
297 |         assert param.device.type == 'cuda'
--------------------------------------------------------------------------------
/tests/test_tensor.py:
--------------------------------------------------------------------------------
1 | import cupy
2 | import numpy
3 | import pytest
4 | import torch
5 | 
6 | from chainer_pytorch_migration import tensor
7 | 
8 | 
9 | def test_asarray_cpu():
10 |     t = torch.arange(5, dtype=torch.float32)
11 |     a = tensor.asarray(t)
12 |     assert isinstance(a, numpy.ndarray)
13 |     a += 1
14 |     numpy.testing.assert_array_equal(a, t.numpy())
15 | 
16 | 
17 | def test_asarray_gpu():
18 |     t = torch.arange(5, dtype=torch.float32, device='cuda')
19 |     a = tensor.asarray(t)
20 |     assert isinstance(a, cupy.ndarray)
21 |     a += 1
22 |     numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
23 | 
24 | 
25 | def test_asarray_multi_gpu():
26 |     if torch.cuda.device_count() < 2:
27 |         pytest.skip('Not enough GPUs')
28 |     t = torch.arange(5, dtype=torch.float32, device='cuda:1')
29 |     a = tensor.asarray(t)
30 |     assert isinstance(a, cupy.ndarray)
31 |     with cupy.cuda.Device(1):
32 |         a += 1
33 |     numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
34 | 
35 | 
36 | def test_astensor_cpu():
37 |     a = numpy.arange(5, dtype=numpy.float32)
38 |     t = tensor.astensor(a)
39 |     assert isinstance(t, torch.Tensor)
40 |     t += 1
41 |     numpy.testing.assert_array_equal(a, t.numpy())
42 | 
43 | 
44 | def test_astensor_gpu():
45 |     a = cupy.arange(5, dtype=cupy.float32)
46 |     t = tensor.astensor(a)
47 |     assert isinstance(t, torch.Tensor)
48 |     t += 1
49 |     numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
50 | 
51 | 
52 | def test_astensor_negative_stride():
53 |     a = numpy.array([1, 2, 3])
54 |     a = a[::-1]
55 |     t = tensor.astensor(a)
56 |     numpy.testing.assert_array_equal(a, t.numpy())
57 | 
58 | 
59 | def test_asarray_empty_cpu():
60 |     t = torch.tensor([], dtype=torch.float32)
61 |     a = tensor.asarray(t)
62 |     assert isinstance(a, numpy.ndarray)
63 | 
64 | 
65 | def test_asarray_empty_gpu():
66 |     t = torch.tensor([], dtype=torch.float32, device='cuda')
67 |     a = tensor.asarray(t)
68 |     assert isinstance(a, cupy.ndarray)
69 | 
70 | 
71 | def test_astensor_empty_cpu():
72 |     a = numpy.array([], dtype=numpy.float32)
73 |     t = tensor.astensor(a)
74 |     assert t.device.type == 'cpu'
75 | 
76 | 
77 | def test_astensor_empty_gpu():
78 |     a = cupy.array([], dtype=cupy.float32)
79 |     t = tensor.astensor(a)
80 |     assert isinstance(t, torch.Tensor)
81 |     assert t.device.type == 'cuda'
82 |     t += 1
83 |     numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
84 | 
85 | 
86 | @pytest.mark.parametrize('dtype', [
87 |     'bool',
88 |     'uint8', 'int8', 'int16', 'int32', 'int64',
89 |     'float16', 'float32', 'float64',
90 |     'complex64', 'complex128',
91 | ])
92 | def test_to_numpy_dtype(dtype):
93 |     torch_dtype = getattr(torch, dtype)
94 |     numpy_dtype = numpy.dtype(dtype)
95 |     assert tensor.to_numpy_dtype(torch_dtype) == numpy_dtype
96 | 
--------------------------------------------------------------------------------