├── .gitignore ├── LICENSE ├── README.md ├── packages ├── candle │ ├── __init__.py │ ├── engine.py │ ├── modules.py │ ├── recurrent.py │ ├── tracing.py │ └── utils.py └── siglayer │ ├── __init__.py │ ├── backend.py │ ├── examples.py │ └── modules.py ├── poster └── deepsig_poster.pdf ├── requirements.txt └── src ├── README.md ├── base.ipynb ├── example_generative_model.ipynb ├── example_hurst_parameter.ipynb ├── example_reinforcement_learning.ipynb ├── example_signature_inversion.ipynb ├── generative_model.py ├── hurst_parameter.py ├── reinforcement_learning.py ├── signature_inversion.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/.ipynb_checkpoints/ 3 | *.py[cod] 4 | .idea/ 5 | .vs/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Signature Transforms 2 | Using the signature transform as a pooling layer in a neural network. 3 | 4 | This is the code for the paper [Deep Signature Transforms](https://arxiv.org/abs/1905.08494) by Bonnier, Kidger, Perez Arribas, Salvi, Lyons 2019. 5 | 6 | Look at [Signatory](https://github.com/patrick-kidger/signatory) for a PyTorch implementation of the signature transform. 
7 | 8 | ## Overview 9 | If you're coming at this already knowing something about neural networks, then the idea is that the 'signature transform' is a transformation that does a particularly good job of extracting features from streams of data, so it's a natural thing to try and build into our neural network models. 10 | 11 | If you're coming at this already knowing something about signatures, then you probably know that they've previously only been used as a feature transformation, on top of which a model is built. But it is actually possible to backpropagate through the signature transform, so as long as you design your model correctly (it has to be 'stream-preserving'; see the paper), then it actually makes sense to embed the signature within a neural network. Learning a nonlinearity before the signature transform provides a compact way to select which terms in the signature (of the original path) are useful for the given dataset. 12 | 13 | ## What are signatures? 14 | The signature of a stream of data is essentially a collection of statistics about that stream of data. This collection of statistics does such a good job of capturing the information about the stream of data that it actually determines the stream of data uniquely. (Up to something called 'tree-like equivalence' anyway, which is really just a technicality. It's an equivalence relation that matters about as much as two functions being equal almost everywhere. That is to say, not much at all.) The signature transform is a particularly attractive tool in machine learning because it is what we call a 'universal nonlinearity': it is sufficiently rich that it captures every possible nonlinear function of the original stream of data. Any function of a stream is *linear* on its signature. Now for various reasons this is a mathematical idealisation not borne out in practice (which is why we put them in a neural network and don't just use a simple linear model), but they still work very well! 15 | 16 | ## Directory layout and reproducibility 17 | The `src` directory contains the scripts for our experiments. Reproducibility should be easy: just run the `.ipynb` files. 18 | 19 | (The `packages` directory just contains some separate packages that were put together to support this project.) 20 | 21 | ## Dependencies 22 | Python 3.7 was used. Virtual environments and packages were managed with [Miniconda](https://docs.conda.io/en/latest/miniconda.html). The following external packages were used, and may be installed via `pip3 install -r requirements.txt`. 23 | 24 | [`fbm==0.2.0`](https://pypi.org/project/fbm/) for generating fractional Brownian motion. 25 | 26 | [`gym==0.12.1`](https://gym.openai.com/) 27 | 28 | [`pytorch-ignite==0.1.2`](https://pytorch.org/ignite/) is an extension to PyTorch. 29 | 30 | [`iisignature==0.23`](https://github.com/bottler/iisignature) for calculating signatures. (Which was used as [Signatory](https://github.com/patrick-kidger/signatory) had not been developed yet.) 31 | 32 | [`jupyter==1.0.0`](https://jupyter.org/) 33 | 34 | [`matplotlib==2.2.4`](https://matplotlib.org/) 35 | 36 | [`pandas==0.24.2`](https://pandas.pydata.org/) 37 | 38 | [`torch==1.0.1`](https://pytorch.org/) 39 | 40 | [`scikit-learn==0.20.3`](https://scikit-learn.org/) 41 | 42 | [`sdepy==1.0.1`](https://pypi.org/project/sdepy/) for simulating solutions to stochastic differential equations. 43 | 44 | [`tqdm==4.31.1`](https://github.com/tqdm/tqdm) for progress bars.
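As a quick illustration of the core operation, here is a minimal sketch of computing a signature directly with `iisignature` (the path here is just random illustrative data):

```python
import iisignature
import numpy as np

path = np.random.randn(100, 2)  # a stream of 100 points in R^2
sig = iisignature.sig(path, 3)  # signature truncated to depth 3
# 2 + 2**2 + 2**3 == 14 terms: a geometric sum over depths 1 to 3
assert sig.shape == (14,)
```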
45 | -------------------------------------------------------------------------------- /packages/candle/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .engine import create_supervised_trainer 3 | except ImportError: # no ignite 4 | def create_supervised_trainer(*args, **kwargs): 5 | raise ImportError('Ignite not installed') 6 | 7 | from .modules import (Lambda, 8 | Flatten, 9 | View, 10 | Concat, 11 | Split, 12 | SkipConnection, 13 | NoInputSpec, 14 | CannedNet, 15 | CannedResNet) 16 | 17 | from .recurrent import (Window, 18 | Recur) 19 | 20 | from .tracing import (convert_to_tensor, 21 | Integer) 22 | 23 | from .utils import (batch_fn, 24 | outer_product, 25 | flatten, 26 | batch_flatten, 27 | cat, 28 | stack) 29 | -------------------------------------------------------------------------------- /packages/candle/engine.py: -------------------------------------------------------------------------------- 1 | import ignite.engine as engine 2 | import torch 3 | import torch.nn.utils as nnutils 4 | 5 | 6 | def create_supervised_trainer(model, optimizer, loss_fn, 7 | device=None, non_blocking=False, 8 | prepare_batch=engine._prepare_batch, 9 | check_nan=False, 10 | grad_clip=None, 11 | output_predictions=False): 12 | """As ignite.engine.create_supervised_trainer, but may also optionally perform: 13 | - NaN checking on predictions (in a more debuggable way than ignite.handlers.TerminateOnNaN) 14 | - Gradient clipping 15 | - Recording the predictions made by a model 16 | 17 | Arguments: 18 | (as ignite.engine.create_supervised_trainer, plus) 19 | check_nan: Optional boolean specifying whether the engine should check predictions for NaN values. Defaults to 20 | False. If True, and a NaN value is encountered, then a RuntimeError will be raised with attributes 'x', 'y', 21 | 'y_pred', 'model', detailing the feature, label, prediction and model, respectively, on which this occurred. 22 | grad_clip: Optional number, boolean or None, specifying the value to clip the infinity-norm of the gradient to. 23 | Defaults to None. If False or None then no gradient clipping will be applied. If True then the gradient is 24 | clipped to 1.0. 25 | output_predictions: Optional boolean specifying whether the engine should record the predictions the model made 26 | on a batch. Defaults to False. If True then state.output will be a tuple of (loss, predictions). If False 27 | then state.output will just be the loss. (Not wrapped in a tuple.)
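    Example (a sketch; assumes `model`, `optimizer`, `loss_fn` and a `train_loader` already exist):

        trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                            check_nan=True, grad_clip=1.0)
        trainer.run(train_loader, max_epochs=10)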
28 | """ 29 | 30 | if device: 31 | model.to(device) 32 | 33 | if grad_clip is False: 34 | grad_clip = None 35 | elif grad_clip is True: 36 | grad_clip = 1.0 37 | 38 | def _update(engine, batch): 39 | model.train() 40 | optimizer.zero_grad() 41 | x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) 42 | y_pred = model(x) 43 | 44 | if check_nan and torch.isnan(y_pred).any(): 45 | e = RuntimeError('Model generated NaN value.') 46 | e.y = y 47 | e.y_pred = y_pred 48 | e.x = x 49 | e.model = model 50 | raise e 51 | 52 | loss = loss_fn(y_pred, y) 53 | loss.backward() 54 | 55 | if grad_clip is not None: 56 | nnutils.clip_grad_norm_(model.parameters(), grad_clip, norm_type='inf') 57 | 58 | optimizer.step() 59 | 60 | if output_predictions: 61 | return loss.item(), y_pred 62 | else: 63 | return loss.item() 64 | 65 | return engine.Engine(_update) 66 | -------------------------------------------------------------------------------- /packages/candle/modules.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import torch 3 | import torch.nn as nn 4 | 5 | from . import utils 6 | 7 | 8 | class Lambda(nn.Module): 9 | """Wraps an arbitrary PyTorch function into a Module.""" 10 | 11 | def __init__(self, fn, fn_args=(), fn_kwargs=None): 12 | super(Lambda, self).__init__() 13 | if isinstance(fn, ft.partial): 14 | fn.__name__ = fn.func.__name__ 15 | fn.__qualname__ = fn.func.__qualname__ 16 | self.fn = fn 17 | self.fn_args = fn_args 18 | self.fn_kwargs = {} if fn_kwargs is None else fn_kwargs 19 | 20 | def forward(self, x): 21 | return self.fn(x, *self.fn_args, **self.fn_kwargs) 22 | 23 | def extra_repr(self): 24 | return f'fn={self.fn.__qualname__}, fn_args={self.fn_args}, fn_kwargs={self.fn_kwargs}' 25 | 26 | 27 | class Flatten(nn.Module): 28 | """Flattening Module.""" 29 | 30 | def forward(self, x): 31 | return utils.batch_flatten(x) 32 | 33 | 34 | class View(nn.Module): 35 | """View Module.""" 36 | 37 | def __init__(self, shape, **kwargs): 38 | super(View, self).__init__(**kwargs) 39 | self.shape = shape 40 | 41 | def forward(self, x): 42 | return x.view(x.size(0), *self.shape) 43 | 44 | def extra_repr(self): 45 | return f'shape={self.shape}' 46 | 47 | 48 | class Concat(nn.Module): 49 | """Concatenation Module.""" 50 | 51 | def __init__(self, dim=-1, **kwargs): 52 | super(Concat, self).__init__(**kwargs) 53 | self.dim = dim 54 | 55 | def forward(self, xs): 56 | return torch.cat(xs, dim=self.dim) 57 | 58 | def extra_repr(self): 59 | return f'dim={self.dim}' 60 | 61 | 62 | class Split(nn.Module): 63 | """Split Module.""" 64 | 65 | def __init__(self, split, dim=-1, **kwargs): 66 | super(Split, self).__init__(**kwargs) 67 | self.split = split 68 | self.dim = dim 69 | 70 | def forward(self, x): 71 | return torch.split(x, self.split, self.dim) 72 | 73 | def extra_repr(self): 74 | return f'split={self.split}, dim={self.dim}' 75 | 76 | 77 | class SkipConnection(nn.Module): 78 | """Applies a Module, and then adds its input to its output and returns that.""" 79 | 80 | def __init__(self, module, **kwargs): 81 | super(SkipConnection, self).__init__(**kwargs) 82 | self.module = module 83 | 84 | def forward(self, x): 85 | y = self.module(x) 86 | if isinstance(x, (tuple, list)): 87 | return tuple(xi + yi for xi, yi in zip(x, y)) 88 | else: 89 | return x + y 90 | 91 | 92 | class NoInputSpec(nn.Module): 93 | """Used to create a Module without specifying its number of inputs. 
A necessary evil is that the Module must 94 | be called on an example batch of inputs before use, so that it can figure out the input shapes. 95 | """ 96 | 97 | def __init__(self, moduletype, *args, **kwargs): 98 | self._parameters_property = None 99 | self._parameters_not_specified = True 100 | super(NoInputSpec, self).__init__() 101 | 102 | self.moduletype = moduletype 103 | self.args = args 104 | self.kwargs = kwargs 105 | 106 | self.module = None 107 | 108 | @property 109 | def _parameters(self): 110 | if self._parameters_not_specified: 111 | raise RuntimeError('Module has not yet been called on an example batch of inputs, so it is not yet ' 112 | 'fully specified.') 113 | else: 114 | return self._parameters_property 115 | 116 | @_parameters.setter 117 | def _parameters(self, item): 118 | self._parameters_property = item 119 | 120 | def create_module(self, x): 121 | """The specification for creating a module from the input tensor :x:, and given :args: and :kwargs:.""" 122 | # x.size(1) is the feature dimension of a two-dimensional input, which is used for creating Linear layers, and 123 | # the channel dimension of a three-dimensional (batch, channel, length) input, which is used to create 124 | # convolutional layers. Even so this is quite fragile. 125 | # TODO: Improve this 126 | return self.moduletype(x.size(1), *self.args, **self.kwargs) 127 | 128 | def forward(self, x): 129 | if self.module is None: 130 | # Can't use self.module == None to test because the self.module = self.create_module(x) line actually 131 | # accesses self._parameters, because it's assigning a Module. 132 | self._parameters_not_specified = False 133 | self.module = self.create_module(x) 134 | return self.module(x) 135 | 136 | def extra_repr(self): 137 | if self.module is None: 138 | return f'not called yet: moduletype={self.moduletype.__name__}, args={self.args}, kwargs={self.kwargs}' 139 | else: 140 | return super(NoInputSpec, self).extra_repr() 141 | 142 | 143 | class CannedNet(nn.Module): 144 | """Provides a simple extensible way to specify a neural network without having to define a class. A bit like 145 | Sequential, but more general. 146 | - A Module may be specified by something as simple as an integer, for example the width of a Linear layer. 147 | - In particular the input size of a Linear layer does not need computing in advance. 148 | - The framework may be extended to easily create more complicated nets, for example ResNets; see CannedResNet. 149 | 150 | Subclasses wishing to define how to interpret such simpler specifications should override the _interpret_element 151 | method. 152 | 153 | As it uses NoInputSpec internally, instances of this Module should be called on an example batch of inputs before 154 | use. 155 | """ 156 | 157 | def __init__(self, hidden_blocks, debug=False, **kwargs): 158 | """Create a neural network. 159 | 160 | Note that no activation functions are automatically applied: these should be specified in the :hidden_blocks: 161 | argument along with everything else. 162 | 163 | Arguments: 164 | hidden_blocks: A tuple specifying the layers of the network. Subclasses may specify how to interpret the 165 | values of the tuple. In the default implementation, integers are interpreted as a hidden layer of that 166 | size, callables are wrapped into a Module, and Modules are used as they are. The documentation of a 167 | subclass may provide more information on other objects it can interpret.
Note that any Module instances 168 | which are elements of the tuple must have an 'output_shapes' method taking one parameter 'input_shapes', 169 | specifying the output shapes of the tensors that the module produces, given particular input shapes. (In 170 | both cases excluding batch dimension.) 171 | debug: Optional, defaults to False. Whether to print the sizes of Tensors as they go through the network 172 | layer-by-layer. 173 | """ 174 | 175 | super(CannedNet, self).__init__(**kwargs) 176 | 177 | self.hidden_blocks = hidden_blocks 178 | self.debug = debug 179 | 180 | self.layers = nn.ModuleList() 181 | for elem in hidden_blocks: 182 | self.layers.append(self._interpret_element_wrapper(elem)) 183 | 184 | @classmethod 185 | def spec(cls, *args, **kwargs): 186 | """Returns a function of no arguments which returns instances of the class with the specified arguments and 187 | keyword arguments given now. 188 | """ 189 | def specced(): 190 | return cls(*args, **kwargs) 191 | return specced 192 | 193 | def __iter__(self): 194 | # Allows for *unpacking 195 | return iter(self.layers) 196 | 197 | def _interpret_element(self, elem): 198 | """Specifies how an element of the :hidden_blocks: argument of __init__ should be interpreted. 199 | 200 | If overriding this method in a subclass, the expected pattern is (note in particular the super() call): 201 | 202 | def _interpret_element(self, elem): 203 | if isinstance(elem, my_type): 204 | if is_valid(elem): 205 | ... 206 | return module 207 | else: 208 | raise ValueError(...) 209 | return super()._interpret_element(elem) 210 | 211 | Arguments: 212 | elem: The element of the :hidden_blocks: tuple. 213 | 214 | Returns: 215 | If :elem: could not be interpreted it will return None. 216 | If :elem: could be interpreted it will return a module, as specified by :elem:. 217 | 218 | Raises: 219 | ValueError if :elem: could be interpreted but did not correspond to a well-defined module. For example, 220 | integers correspond to dense layers of that size. A negative integer would thus raise a ValueError. 221 | """ 222 | 223 | if isinstance(elem, int): 224 | if elem < 1: 225 | raise ValueError(f'Integers specifying layer sizes must be greater than or equal to one. Given ' 226 | f'{elem}.') 227 | layer = NoInputSpec(nn.Linear, elem) 228 | return layer 229 | elif isinstance(elem, nn.Module): 230 | return elem 231 | elif callable(elem): 232 | return Lambda(elem) 233 | 234 | def _interpret_element_wrapper(self, elem): 235 | """Wraps the _interpret_element method to check whether or not an element has been interpreted.""" 236 | 237 | out = self._interpret_element(elem) 238 | if out is None: 239 | raise ValueError(f'Element {elem} of type {type(elem)} in hidden_blocks argument was not understood.') 240 | return out 241 | 242 | def forward(self, x): 243 | if self.debug: 244 | print(f'Input: {x.shape}') 245 | for layer in self.layers: 246 | x = layer(x) 247 | if self.debug: 248 | print(f'{type(layer).__name__}: {x.shape}') 249 | return x 250 | 251 | 252 | class CannedResNet(CannedNet): 253 | """As CannedNet, but is also capable of understanding tuples as elements of the :hidden_blocks: argument to 254 | __init__. This tuple element will be interpreted recursively as another CannedResNet, and a skip connection added 255 | across the layers specified in the tuple.
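    For example (a sketch, assuming two-dimensional (batch, feature) inputs, `import torch.nn.functional as F`, and
    illustrative sizes):

        net = CannedResNet((32, F.relu, (32, F.relu, 32), F.relu, 10))

    Here the inner tuple (32, F.relu, 32) is wrapped in a skip connection, so its 32-dimensional input is added to its
    32-dimensional output. As with CannedNet, the instance must be called on an example batch of inputs before use.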
256 | """ 257 | 258 | def _interpret_element(self, elem): 259 | if isinstance(elem, (tuple, list)): 260 | subnet = self.__class__(hidden_blocks=elem) 261 | return SkipConnection(subnet) 262 | return super(CannedResNet, self)._interpret_element(elem) 263 | -------------------------------------------------------------------------------- /packages/candle/recurrent.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import queue 3 | import torch 4 | import torch.nn as nn 5 | import warnings 6 | 7 | from . import utils 8 | 9 | 10 | def identity(x): 11 | return x 12 | 13 | 14 | class Window(nn.Module): 15 | """Creates a sliding window along a Tensor, yielding slices of the Tensor as it goes along. It is given Tensors to 16 | store in memory, and then when requested, will yield slices from them as if those Tensors were one long Tensor, 17 | concatenated along a specified axis. This is useful for time series data, for example, when data may be arriving 18 | continuously, at variable rates and lengths. 19 | 20 | A Window instance can be iterated over to get all of the possible slices from its memory. For particular choices of length 21 | and stride, given a collection of input Tensors of some total length (along the specified dimension), this may mean 22 | that the final slice is smaller than requested. Instead of being yielded, it will be retained in the Window's memory 23 | and used to start off from later, once new Tensors have been added to its memory. 24 | 25 | May also be used as a PyTorch Module. In this case the single input is pushed into the Window, and the results 26 | returned as a single tensor, stacked along a new dimension. (Which will be the last dimension in the shape.) This 27 | usage is dependent on the results all having the same shape. That is, if adjust_length != 0 then a suitable 28 | transformation should be applied to ensure that the results remain the same length. 29 | 30 | Note that if you just want a sliding window with no transformations then torch.Tensor.unfold is going to be much 31 | quicker. 32 | """ 33 | 34 | def __init__(self, length, stride, adjust_length=0, dim=-1, clone=True, transformation=identity, **kwargs): 35 | """See Window.__doc__. 36 | 37 | Arguments: 38 | length: The length of the slice taken from the input Tensor. 39 | stride: How much to move the start point of the slice after the previous slice has been yielded. Interesting 40 | choices include 1 (to yield overlapping slices, each one just offset from the previous one), or 41 | :length:, which will yield completely nonoverlapping slices, or 0, which will yield slices starting from 42 | the same start point every time (for example, to use alongside a nonzero value for :adjust_length:). 43 | adjust_length: Optional, integer or callable, defaults to 0. How much the length is changed when a slice has 44 | been yielded. If an integer it will be added on to the length. If a callable then the current length 45 | will be passed as an input, and the new length should be returned as an output. For example, setting 46 | :stride:=0 and :adjust_length:=1 will give an expanding window. 47 | dim: Optional, defaults to -1. The dimension of the input Tensor to move along whilst yielding slices. 48 | clone: Optional, defaults to True. Whether to clone the output before yielding it. Otherwise later in-place 49 | operations could affect the Window's memory.
If you're sure this isn't going to happen, for example 50 | because a copy is made somewhere within the transformation argument, then setting this to False 51 | will give a speed-up. 52 | transformation: Optional, defaults to no transformation. A transformation to apply to the output before 53 | yielding it. 54 | """ 55 | 56 | super(Window, self).__init__(**kwargs) 57 | 58 | self.length = length 59 | self._original_length = length 60 | self.stride = stride 61 | self.adjust_length = adjust_length 62 | self.dim = dim 63 | self.clone = clone 64 | self.transformation = transformation if transformation is not None else identity 65 | 66 | self.last = torch.zeros(0) 67 | self.queue = queue.Queue() 68 | 69 | self._device = None 70 | 71 | def extra_repr(self): 72 | msg = f'length={self.length}, stride={self.stride}, adjust_length={self.adjust_length}, dim={self.dim}' 73 | if self.transformation is not identity and not isinstance(self.transformation, nn.Module): 74 | if hasattr(self.transformation, '__name__'): 75 | msg += f', transformation={self.transformation.__name__}' 76 | elif isinstance(self.transformation, ft.partial): 77 | fn = self.transformation 78 | msg += f', transformation=partial({fn.func.__name__}, args={fn.args}, keywords={fn.keywords})' 79 | return msg 80 | 81 | def push(self, item): 82 | """Add a Tensor to the Window's memory.""" 83 | if self._device is None: 84 | self._device = item.device 85 | self.last = self.last.to(device=item.device) 86 | if self._device != item.device: 87 | raise RuntimeError(f'{self.__class__.__name__} previously had tensors on device {self._device} pushed, but' 88 | f' has now had a tensor on device {item.device} pushed.') 89 | self.queue.put_nowait(item) 90 | 91 | def pull(self): 92 | """Take a slice from the Tensors in the Window's memory.""" 93 | 94 | size_so_far = self.last.size(self.dim) 95 | items = [self.last] 96 | 97 | while True: 98 | if size_so_far < self.length: 99 | try: 100 | last = self.queue.get_nowait() 101 | except queue.Empty: 102 | self.last = utils.cat(items, dim=self.dim) 103 | raise 104 | size_so_far += last.size(self.dim) 105 | items.append(last) 106 | else: 107 | break 108 | 109 | out = utils.cat(items, dim=self.dim) 110 | rem = out.size(self.dim) - self.stride 111 | out, self.last = out.narrow(self.dim, 0, self.length), out.narrow(self.dim, self.stride, rem) 112 | 113 | if callable(self.adjust_length): 114 | self.length = self.adjust_length(self.length) 115 | else: 116 | self.length = self.length + self.adjust_length 117 | 118 | out = self.transformation(out) 119 | if self.clone: 120 | out = out.clone() 121 | return out 122 | 123 | def clear(self): 124 | """Clear the Window's memory.""" 125 | 126 | try: 127 | while True: 128 | self.queue.get_nowait() 129 | except queue.Empty: 130 | pass 131 | self.last = torch.zeros(0) 132 | self._device = None 133 | self.length = self._original_length 134 | 135 | def __iter__(self): 136 | not_iterated = True 137 | try: 138 | while True: 139 | yield self.pull() 140 | not_iterated = False 141 | except queue.Empty: 142 | if not_iterated: 143 | warnings.simplefilter('always', RuntimeWarning) 144 | warnings.warn(f'{self.__class__.__name__} did not iterate over any windows.
This means there was not ' 145 | f'enough input data to create one entire window: either increase the length of the input ' 146 | f'data or decrease the size of the window.', RuntimeWarning) 147 | 148 | def forward(self, x): 149 | self.push(x) 150 | try: 151 | out = utils.stack(list(self), dim=-1) 152 | finally: 153 | self.clear() 154 | return out 155 | 156 | 157 | class Recur(nn.Module): 158 | """Takes a tensor of shape (..., channels, path), splits it up into individual tensors along the last (path) 159 | dimension, and applies the specified network to them in a recurrent manner. 160 | """ 161 | 162 | def __init__(self, module, memory_shape, intermediate_outputs=True, **kwargs): 163 | super(Recur, self).__init__(**kwargs) 164 | 165 | self.module = module 166 | self.memory_shape = memory_shape 167 | self.intermediate_outputs = intermediate_outputs 168 | 169 | def extra_repr(self): 170 | return f'memory_shape={self.memory_shape}, intermediate_outputs={self.intermediate_outputs}' 171 | 172 | def forward(self, x): 173 | outs = [] 174 | memory = torch.zeros(x.size(0), *self.memory_shape, device=x.device) 175 | xs = x.unbind(dim=-1) 176 | for inp in xs: 177 | memory, out = self.module((memory, inp)) 178 | memory = memory.view(x.size(0), *self.memory_shape) 179 | if self.intermediate_outputs: 180 | outs.append(out) 181 | if self.intermediate_outputs: 182 | return utils.stack(outs, dim=-1) 183 | else: 184 | return out 185 | -------------------------------------------------------------------------------- /packages/candle/tracing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def convert_to_tensor(tensor, **kwargs): 5 | """Converts the argument :tensor: to a tensor if it isn't already one. If converted to a tensor, then any :**kwargs: 6 | will be passed as well. 7 | 8 | This is useful (over torch.tensor or torch.as_tensor) when using torch.jit.trace, which interprets both torch.tensor 9 | and torch.as_tensor as static values; this will correctly use the same tensor that was passed, if possible. 10 | """ 11 | if isinstance(tensor, torch.Tensor): 12 | return tensor 13 | else: 14 | return torch.tensor(tensor, **kwargs) 15 | 16 | 17 | # For tracing, when everything gets turned into a tensor. 
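# (E.g. so that an isinstance(depth, Integer) check passes both for a plain Python int and for an int that
# tracing has wrapped into an integer tensor.)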
18 | Integer = (int, torch.IntTensor, torch.LongTensor) 19 | -------------------------------------------------------------------------------- /packages/candle/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batch_fn(fn): 5 | """Transforms a function :fn: to act on each element of a batch individually.""" 6 | def batched_fn(x, **kwargs): 7 | return torch.stack([fn(xi, **kwargs) for xi in x]) 8 | return batched_fn 9 | 10 | 11 | def outer_product(tensor1, tensor2): 12 | """Computes the outer product of two tensors.""" 13 | return torch.tensordot(tensor1.unsqueeze(0), tensor2.unsqueeze(0), dims=((0,), (0,))) 14 | 15 | 16 | def flatten(tensor): 17 | """Flattens a tensor.""" 18 | return tensor.view(-1) 19 | 20 | 21 | def batch_flatten(tensor): 22 | """Flattens a tensor except for the batch dimension.""" 23 | return tensor.view(tensor.size(0), -1) 24 | 25 | 26 | def cat(tensors, dim=0, out=None): 27 | """As torch.cat, but returns the original tensor if len(tensors) == 1, so that an unneeded copy is not made.""" 28 | if len(tensors) == 1: 29 | return tensors[0] 30 | else: 31 | return torch.cat(tensors, dim=dim, out=out) 32 | 33 | 34 | def stack(tensors, dim=0, out=None): 35 | """As torch.stack, but returns the original tensor if len(tensors) == 1, so that an unneeded copy is not made.""" 36 | if len(tensors) == 1: 37 | return tensors[0].unsqueeze(dim=dim) 38 | else: 39 | return torch.stack(tensors, dim=dim, out=out) 40 | -------------------------------------------------------------------------------- /packages/siglayer/__init__.py: -------------------------------------------------------------------------------- 1 | from .backend import (sig_dim, 2 | path_sig, 3 | batch_path_sig, 4 | Signature) 5 | 6 | from .modules import (Augment, 7 | ViewSignature) 8 | -------------------------------------------------------------------------------- /packages/siglayer/backend.py: -------------------------------------------------------------------------------- 1 | import candle 2 | import iisignature 3 | import torch 4 | import torch.autograd as autograd 5 | import torch.nn as nn 6 | import warnings 7 | 8 | 9 | def sig_dim(alphabet_size, depth): 10 | """Calculates the number of terms in a signature of depth :depth: over an alphabet of size :alphabet_size:.""" 11 | return int(alphabet_size * (1 - alphabet_size ** depth) / (1 - alphabet_size)) 12 | # == sum(alphabet_size ** i for i in range(1, depth + 1)) (geometric sum formula) 13 | 14 | 15 | class path_sig_fn(autograd.Function): 16 | """An autograd.Function corresponding to the signature map. See also siglayer/backend/pytorch_implementation.py.""" 17 | 18 | @staticmethod 19 | def forward(ctx, path, depth): 20 | device = path.device 21 | # transpose because the PyTorch convention for convolutions is channels first. The iisignature expectation is 22 | # that channels are last.
23 | path = path.detach().cpu().numpy().transpose() # sloooow CPU :( 24 | ctx.path = path 25 | ctx.depth = depth 26 | return torch.tensor(iisignature.sig(path, depth), dtype=torch.float, device=device) 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | device = grad_output.device 31 | backprop = iisignature.sigbackprop(grad_output.cpu().numpy(), ctx.path, ctx.depth) 32 | # transpose again to go back to the PyTorch convention of channels first 33 | out = torch.tensor(backprop, dtype=torch.float, device=device).t() 34 | 35 | # better safe than sorry 36 | # https://discuss.pytorch.org/t/when-should-you-save-for-backward-vs-storing-in-ctx/6522/9 37 | # not sure this is actually necessary though 38 | del ctx.path 39 | del ctx.depth 40 | return out, None 41 | 42 | 43 | def path_sig(path, depth): 44 | """Calculates the signature transform of a :path: to signature depth :depth:.""" 45 | return path_sig_fn.apply(path, depth) 46 | 47 | 48 | batch_path_sig = candle.batch_fn(path_sig) 49 | 50 | 51 | class Signature(nn.Module): 52 | """Given some path mapping from, say, [0, 1] into \reals^d, we may define the 'signature' of the path as a 53 | particular sigtensor with respect to an alphabet of d letters. (Note that d is the target dimension of the path.) 54 | That is, the signature is a map from the space of paths to the tensor algebra. Up to certain mathematical niceties, 55 | this map may be inverted; the signature is sufficient to define the path. (Technically speaking, it defines the path 56 | up to 'tree-like equivalence': this means that the signature does not pick up on back-tracking.) 57 | 58 | Thus the signature is a natural way to characterise a path; in the language of machine learning it is an excellent 59 | feature map. 60 | 61 | Given a tensor of shape (x, y), then one may interpret this as a piecewise constant path from [0, x] into \reals^y, 62 | changing its value at each integer. Whether this is a natural interpretation depends on the data that the tensor 63 | represents, of course, but this allows for taking the signature of a tensor, which is precisely what this Module 64 | does. 65 | """ 66 | 67 | def __init__(self, depth, **kwargs): 68 | if not isinstance(depth, candle.Integer) or depth < 1: 69 | raise ValueError(f'Depth must be an integer greater than or equal to one. Given {depth} of type ' 70 | f'{type(depth)}') 71 | super(Signature, self).__init__(**kwargs) 72 | self.depth = depth 73 | 74 | def forward(self, path): 75 | if path.size(1) == 1: 76 | warnings.warn(f'{self.__class__.__name__} called on path with only one channel; the signature is now just ' 77 | f'the moments of the path, so there is no interesting information from cross terms.') 78 | # path is expected to be a 3-dimensional tensor, with batch, channel and length axes respectively, say of shape 79 | # (b, c, l). Each batch element is treated separately. Then values are interpreted as l sample points from a 80 | # path in \reals^c 81 | return batch_path_sig(path, depth=self.depth) 82 | 83 | def extra_repr(self): 84 | return f'depth={self.depth}' 85 | -------------------------------------------------------------------------------- /packages/siglayer/examples.py: -------------------------------------------------------------------------------- 1 | import candle 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from . import backend 6 | from .
import modules 7 | 8 | 9 | def create_feedforward(output_shape, sig=True, sig_depth=4, final_nonlinearity=lambda x: x, 10 | layer_sizes=(32, 32, 32)): 11 | """This simple model uses a few hidden layers with signature nonlinearities between them. 12 | If :sig: is falsy then the signature layers will be replaced with ReLU instead. 13 | It expects input tensors of two dimensions: (batch, features). 14 | 15 | Note that whilst this is a simple example, this is fundamentally quite a strange idea (with sig=True). 16 | There's no natural path-like structure here. 17 | """ 18 | 19 | if sig: 20 | nonlinearity = lambda: modules.ViewSignature(channels=2, length=16, sig_depth=sig_depth) 21 | else: 22 | nonlinearity = lambda: F.relu 23 | 24 | layers = [] 25 | for layer_size in layer_sizes: 26 | layers.append(layer_size) 27 | layers.append(nonlinearity()) 28 | return candle.CannedNet((candle.Flatten(), 29 | *layers, 30 | torch.Size(output_shape).numel(), 31 | candle.View(output_shape), 32 | final_nonlinearity)) 33 | 34 | 35 | def create_simple(output_shape, sig=True, sig_depth=4, final_nonlinearity=lambda x: x, 36 | augment_layer_sizes=(8, 8, 2), augment_kernel_size=1, 37 | augment_include_original=True, augment_include_time=True, 38 | layer_sizes=(32, 32)): 39 | """This model uses a single signature layer: 40 | - Augment the features with something learnable 41 | - Apply signature 42 | - Small ReLU network afterwards. 43 | If :sig: is falsy then the signature layers will be replaced with flatten-and-ReLU instead. 44 | It expects input tensors of three dimensions: (batch, channels, length). 45 | """ 46 | 47 | if sig: 48 | siglayer = (backend.Signature(sig_depth),) 49 | else: 50 | siglayer = (candle.Flatten(), F.relu) 51 | 52 | layers = [] 53 | for layer_size in layer_sizes: 54 | layers.append(layer_size) 55 | layers.append(F.relu) 56 | 57 | return candle.CannedNet((modules.Augment(layer_sizes=augment_layer_sizes, 58 | kernel_size=augment_kernel_size, 59 | include_original=augment_include_original, 60 | include_time=augment_include_time), 61 | *siglayer, 62 | *layers, 63 | torch.Size(output_shape).numel(), 64 | candle.View(output_shape), 65 | final_nonlinearity)) 66 | 67 | 68 | # This example is deliberately not as flexible as the other examples here, to make it easier 69 | # to understand 70 | def create_windowed(output_shape, sig_depth=4, final_nonlinearity=lambda x: x): 71 | """This model applies two signature layers: 72 | - Augment the features with something learnable 73 | - Apply signature 74 | - Recurrent network 75 | - Apply signature 76 | - Recurrent network 77 | - Take the final output of the last recurrent block and reshape and return it. 78 | 79 | Basically it applies a couple of RNNs to the input data with signatures in between them. 80 | """ 81 | 87 | output_size = torch.Size(output_shape).numel() 88 | 89 | return candle.CannedNet((modules.Augment(layer_sizes=(16, 16, 2), kernel_size=4), 90 | candle.Window(length=5, stride=1, transformation=backend.Signature(depth=sig_depth)), 91 | # We could equally well have an Augment here instead of a Recur; both are path-preserving 92 | # neural networks.
93 | candle.Recur(module=candle.CannedNet((candle.Concat(), 94 | 32, F.relu, 95 | 16, # memory size + output size 96 | candle.Split((8, 8)))), # memory size, output size 97 | memory_shape=(8,)), # memory size 98 | candle.Window(length=10, stride=5, transformation=backend.Signature(depth=sig_depth)), 99 | candle.Recur(module=candle.CannedNet((candle.Concat(), 100 | 32, F.relu, 16, F.relu, 101 | 8 + output_size, # memory size + output size 102 | candle.Split((8, output_size)))), # memory size, output size 103 | memory_shape=(8,), # memory size 104 | intermediate_outputs=False), 105 | candle.View(output_shape), 106 | final_nonlinearity)) 107 | -------------------------------------------------------------------------------- /packages/siglayer/modules.py: -------------------------------------------------------------------------------- 1 | import candle 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from . import backend 7 | 8 | 9 | class Augment(nn.Module): 10 | """Augmenting the path before feeding it into a signature is crucial to obtain higher-order information. A way to do 11 | this is to apply a feedforward neural network to sections of the path, to obtain a nonlinear function of the path 12 | before the signature is applied. 13 | 14 | Both the original path and time can be specifically included in the augmentation. 15 | """ 16 | 17 | def __init__(self, layer_sizes, kernel_size, activation=F.relu, include_original=True, include_time=True, **kwargs): 18 | """See Augment.__doc__. 19 | 20 | The assumption is that the input tensor is three dimensional, with dimensions (batch, channel, path). A 21 | feedforward neural network is applied to subtensors of shape (channel, :kernel_size:); the input tensor is split 22 | up along its batch dimension, and the subtensors are successive overlapping slices along the path dimension. The 23 | result of this feedforward neural network provides the augmentation of the path, thus giving an output tensor of 24 | shape (batch, new_channel, path - :kernel_size: + 1), where the size of new_channel depends on the 25 | :include_original: and :include_time: arguments. In pseudocode: 26 | 27 | new_channel = :layer_sizes:[-1] 28 | if :include_original:: 29 | new_channel += channel 30 | if :include_time:: 31 | new_channel += 1 32 | 33 | Arguments: 34 | layer_sizes: tuple of int. Specifies the size of the feedforward neural network to apply to the path. The 35 | final value of this tuple specifies the number of channels in the augmented path. 36 | kernel_size: int specifying the size of the kernel to slide over the path. 37 | activation: Optional, defaults to ReLU. The activation function to use in the feedforward neural network. 38 | include_original: Optional, defaults to True. Whether or not to include the original path (pre-augmentation) 39 | in the augmented path. 40 | include_time: Optional, defaults to True. Whether or not to also augment the path with a 'time' value. These 41 | are values in [0, 1] corresponding to how far along the path dimension the element is. 
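        Example (a sketch; the sizes here are illustrative):

            augment = Augment(layer_sizes=(8, 8, 2), kernel_size=4)
            x = torch.randn(32, 3, 100)  # (batch, channel, path)
            y = augment(x)  # shape (32, 3 + 1 + 2, 100 - 4 + 1) == (32, 6, 97)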
42 | """ 43 | 44 | super(Augment, self).__init__(**kwargs) 45 | 46 | if isinstance(layer_sizes, int): 47 | layer_sizes = (layer_sizes,) 48 | 49 | self.layer_sizes = layer_sizes 50 | self.kernel_size = kernel_size 51 | self.activation = activation 52 | self.include_original = include_original 53 | self.include_time = include_time 54 | 55 | self.convs = nn.ModuleList() 56 | if layer_sizes: 57 | self.convs.append(candle.NoInputSpec(nn.Conv1d, out_channels=layer_sizes[0], kernel_size=kernel_size)) 58 | last_layer_channels = layer_sizes[0] 59 | for augment_channel in layer_sizes[1:]: 60 | # These pointwise convolutions correspond to sliding a standard feedforward network across the input. 61 | self.convs.append(nn.Conv1d(in_channels=last_layer_channels, out_channels=augment_channel, 62 | kernel_size=1)) 63 | last_layer_channels = augment_channel 64 | 65 | def extra_repr(self): 66 | return f'include_original={self.include_original}, include_time={self.include_time}' 67 | 68 | def forward(self, x): 69 | if len(x.shape) != 3: 70 | raise RuntimeError(f'Argument x should have three dimensions: batch, channel, path. Given shape ' 71 | f'{x.shape}.') 72 | pieces = [] 73 | if self.include_original: 74 | truncated_x = x.narrow(2, self.kernel_size - 1, x.size(2) - self.kernel_size + 1) 75 | pieces.append(truncated_x) 76 | 77 | if self.include_time: 78 | time = torch.linspace(0, 1, x.size(2) - self.kernel_size + 1, dtype=torch.float, device=x.device) 79 | time = time.expand(x.size(0), 1, -1) 80 | pieces.append(time) 81 | 82 | if self.layer_sizes: 83 | augmented_x = self.convs[0](x) 84 | for conv in self.convs[1:]: 85 | augmented_x = self.activation(augmented_x) 86 | augmented_x = conv(augmented_x) 87 | pieces.append(augmented_x) 88 | return candle.cat(pieces, dim=1) # concatenate along channel axis 89 | 90 | 91 | class ViewSignature(nn.Module): 92 | """Applies a signature Module in a manner akin to any other nonlinearity. As the signature requires a 93 | three-dimensional input of shape (batch, channel, path), this Module reshapes the input automatically before passing 94 | it to the signature Module. 95 | 96 | Note that this is fundamentally quite a strange idea - this Module is more for demonstration purposes; you probably 97 | don't want to use it. 98 | """ 99 | 100 | def __init__(self, channels, length, sig_depth, **kwargs): 101 | """See ViewSignature.__doc__. 102 | 103 | Arguments: 104 | channels: int, specifying the number of channels in the reshaped tensor. 105 | length: int, specifying the length of the path in the reshaped tensor. 106 | sig_depth: int, specifying the depth at which to truncate the signature.
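        Example (a sketch; the numbers here are illustrative):

            layer = ViewSignature(channels=2, length=16, sig_depth=3)
            x = torch.randn(32, 32)  # batch of flat features, with 32 == 2 * 16
            y = layer(x)  # x is viewed as (32, 2, 16) and the signature applied,
                          # giving sig_dim(2, 3) == 14 features per batch element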
107 | """ 108 | 109 | super(ViewSignature, self).__init__(**kwargs) 110 | 111 | self.channels = channels 112 | self.length = length 113 | 114 | self.sig = backend.Signature(sig_depth) 115 | 116 | def forward(self, x): 117 | x = x.view(x.size(0), self.channels, self.length) 118 | return self.sig(x) 119 | 120 | def extra_repr(self): 121 | return f'channels={self.channels}, length={self.length}' 122 | -------------------------------------------------------------------------------- /poster/deepsig_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrick-kidger/Deep-Signature-Transforms/76d79bc88a0fba2a7b670ce6f12d1ee8c21aedfe/poster/deepsig_poster.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fbm==0.2.0 2 | gym==0.12.1 3 | pytorch-ignite==0.1.2 4 | iisignature==0.23 5 | jupyter==1.0.0 6 | matplotlib==2.2.4 7 | pandas==0.24.2 8 | torch==1.0.1 9 | scikit-learn==0.20.3 10 | sdepy==1.0.1 11 | tqdm==4.31.1 12 | 13 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | ## src 2 | 3 | Every example is split into two files, a `.ipynb` and a `.py` file. Just run the `.ipynb` file to reproduce the experiments. (Its corresponding `.py` file just defines the objects it needs, and will be imported.) 4 | 5 | There are two other files here. First there is `base.ipynb`, which is run by every other notebook, and just performs path hacking to get access to the packages folder. And then there is also `utils.py`, which is exactly what it sounds like. 
-------------------------------------------------------------------------------- /src/base.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.insert(0, '../packages')" 11 | ] 12 | } 13 | ], 14 | "metadata": { 15 | "kernelspec": { 16 | "display_name": "Python 3", 17 | "language": "python", 18 | "name": "python3" 19 | }, 20 | "language_info": { 21 | "codemirror_mode": { 22 | "name": "ipython", 23 | "version": 3 24 | }, 25 | "file_extension": ".py", 26 | "mimetype": "text/x-python", 27 | "name": "python", 28 | "nbconvert_exporter": "python", 29 | "pygments_lexer": "ipython3", 30 | "version": "3.7.3" 31 | } 32 | }, 33 | "nbformat": 4, 34 | "nbformat_minor": 2 35 | } 36 | -------------------------------------------------------------------------------- /src/example_generative_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%run base.ipynb\n", 12 | "%matplotlib inline\n", 13 | "\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import matplotlib.patches as mpatches\n", 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "import torch.optim as optim\n", 19 | "import torch.utils.data as torchdata\n", 20 | "\n", 21 | "import generative_model\n", 22 | "import utils\n", 23 | "\n", 24 | "plt.rcParams['axes.labelsize'] = 20\n", 25 | "plt.rcParams['xtick.labelsize'] = 20\n", 26 | "plt.rcParams['ytick.labelsize'] = 20" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Hyperparameters" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "train_batch_size = 2 ** 10\n", 43 | "val_batch_size = 2 ** 10\n", 44 | "max_epochs = 100\n", 45 | "\n", 46 | "optimizer_fn = lambda x: optim.Adam(x, lr=0.01)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Data" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "n_points = 100\n", 63 | "\n", 64 | "train_dataset = generative_model.get_noise(n_points=n_points, num_samples=train_batch_size)\n", 65 | "eval_dataset = generative_model.get_noise(n_points=n_points, num_samples=val_batch_size)\n", 66 | "signals = generative_model.get_signal(num_samples=train_batch_size, n_points=n_points,).tensors[0]\n", 67 | "\n", 68 | "train_dataloader = torchdata.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)\n", 69 | "eval_dataloader = torchdata.DataLoader(eval_dataset, batch_size=val_batch_size, shuffle=False, num_workers=8)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "scrolled": false 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "example_batch, _ = next(iter(train_dataloader))\n", 81 | "example = example_batch[0]\n", 82 | "\n", 83 | "print(f'Feature shape: {tuple(example.shape)}')\n", 84 | "plt.plot(*example.numpy())\n", 85 | "for path in signals[:100]:\n", 86 | " plt.plot(*path.numpy(), \"orange\", alpha=0.1)\n", 87 | "plt.show()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | 
"execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Loss function\n", 97 | "loss_fn = generative_model.loss(signals, sig_depth=4, normalise_sigs=True)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## Define Neural Network model" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "model = generative_model.create_generative_model()" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## Train Model" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "history = {}\n", 130 | "train_model = utils.create_train_model_fn(max_epochs, optimizer_fn, loss_fn, train_dataloader, eval_dataloader, \n", 131 | " example_batch)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "scrolled": true 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "train_model(model, 'SigNet', history)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## Results" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "scrolled": false 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "fig, axs = plt.subplots(1, 2, gridspec_kw={'wspace': 0.6, 'hspace': 0.6}, figsize=(12, 4))\n", 161 | "axs = axs.flatten()\n", 162 | "for i, metric_name in enumerate(('train_loss', 'val_loss')):\n", 163 | " ax = axs[i]\n", 164 | " for model_history in history.values():\n", 165 | " metric = model_history[metric_name]\n", 166 | "\n", 167 | " # Moving average\n", 168 | " metric = np.convolve(metric, np.ones(10), 'valid') / 10.\n", 169 | " ax.semilogy(np.exp(metric))\n", 170 | " ax.set_xlabel('Epoch')\n", 171 | " ax.set_ylabel(metric_name)\n", 172 | "plt.show()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "batch, _ = next(iter(eval_dataloader))\n", 182 | "batch = batch.to(device=next(model.parameters()).device)\n", 183 | "generated = model(batch).cpu()\n", 184 | "plt.figure(figsize=(12, 8))\n", 185 | "plt.plot(generated[50:100].detach().numpy().T, \"b\", alpha=0.2)\n", 186 | "plt.plot(signals[50:100, 1, :99].detach().numpy().T, \"#ba0404\", alpha=0.2)\n", 187 | "\n", 188 | "orange_patch = mpatches.Patch(color='#ba0404', label='Ornstein–Uhlenbeck process')\n", 189 | "blue_patch = mpatches.Patch(color='blue', label='Generated paths')\n", 190 | "plt.legend(mode='expand', ncol=2, prop={'size': 18}, bbox_to_anchor=(0, 1, 1, 0), \n", 191 | " handles=[blue_patch, orange_patch])\n", 192 | "plt.ylim([-2,2])\n", 193 | "plt.yticks([-2, -1, 0, 1, 2])\n", 194 | "\n", 195 | "plt.show()" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.7.3" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 1 220 | } 221 | 
-------------------------------------------------------------------------------- /src/example_hurst_parameter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id": "P7uYfpGxKr3U" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "%run base.ipynb\n", 14 | "%matplotlib inline\n", 15 | "\n", 16 | "import iisignature\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "import siglayer.examples as examples\n", 21 | "import torch\n", 22 | "import torch.nn.functional as F\n", 23 | "import torch.optim as optim\n", 24 | "\n", 25 | "import hurst_parameter\n", 26 | "import utils\n", 27 | "\n", 28 | "plt.rcParams['axes.labelsize'] = 20\n", 29 | "plt.rcParams['xtick.labelsize'] = 20\n", 30 | "plt.rcParams['ytick.labelsize'] = 20" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "colab_type": "text", 37 | "id": "V0qWPxHnKr3y" 38 | }, 39 | "source": [ 40 | "## Dataset hyperparameters " 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "colab": {}, 48 | "colab_type": "code", 49 | "id": "8JZ6vFuxKr3z" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "# dataset parameters\n", 54 | "n_paths_train=600\n", 55 | "n_paths_test=100 \n", 56 | "n_samples=300\n", 57 | "hurst_exponents=np.around(np.linspace(0.2, 0.8, 7), decimals=1).tolist()\n", 58 | "\n", 59 | "# target shape\n", 60 | "output_shape = (1,)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "colab_type": "text", 67 | "id": "K-52_W35Kr35" 68 | }, 69 | "source": [ 70 | "## Learning hyperparameters" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "colab": {}, 78 | "colab_type": "code", 79 | "id": "ghOjJD9gKr37" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# batch and epoch sizes\n", 84 | "train_batch_size = 128\n", 85 | "val_batch_size = 128\n", 86 | "max_epochs = 100\n", 87 | "\n", 88 | "optimizer_fn = optim.Adam\n", 89 | "\n", 90 | "def loss_fn(x,y):\n", 91 | " return torch.log(F.mse_loss(x, y))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "colab_type": "text", 98 | "id": "s5wgLp6r62t2" 99 | }, 100 | "source": [ 101 | "## On to the training!" 
102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "history = {}\n", 111 | "x_train, y_train, x_test, y_test = hurst_parameter.generate_data(n_paths_train, \n", 112 | " n_paths_test, \n", 113 | " n_samples, \n", 114 | " hurst_exponents)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "colab_type": "text", 121 | "id": "_ejmE5-Bo--E" 122 | }, 123 | "source": [ 124 | "### Feedforward, RNN, DeepSigNet, DeeperSigNet" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "colab": {}, 132 | "colab_type": "code", 133 | "id": "PiyEil9a62t2" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "x_train_, x_test_ = hurst_parameter.preprocess_data(x_train, x_test)\n", 138 | "\n", 139 | "(train_dataloader, test_dataloader, \n", 140 | " example_batch_x, example_batch_y) = hurst_parameter.generate_torch_batched_data(x_train_, \n", 141 | " y_train, \n", 142 | " x_test_, \n", 143 | " y_test,\n", 144 | " train_batch_size, \n", 145 | " val_batch_size)\n", 146 | "\n", 147 | "train_model = utils.create_train_model_fn(max_epochs, optimizer_fn, loss_fn, train_dataloader, \n", 148 | " test_dataloader, example_batch_x)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "colab": { 156 | "base_uri": "https://localhost:8080/", 157 | "height": 419 158 | }, 159 | "colab_type": "code", 160 | "id": "C6nlT9o2Kr4C", 161 | "outputId": "1c4df83b-2145-4dbb-f6b7-3ba129762438" 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "feedforward = examples.create_simple(output_shape, sig=False, augment_layer_sizes=(), \n", 166 | " layer_sizes = (16, 16, 16),\n", 167 | " final_nonlinearity=torch.sigmoid)\n", 168 | "train_model(feedforward, 'Feedforward', history)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "rnn = hurst_parameter.deep_recurrent(output_shape, \n", 178 | " sig=False,\n", 179 | " augment_layer_sizes=(), \n", 180 | " layer_sizes_s=((64,64,32), (32,32,32)),\n", 181 | " lengths=(4,4), \n", 182 | " strides=(2,4), \n", 183 | " adjust_lengths=(0, 0),\n", 184 | " memory_sizes=(2,4),\n", 185 | " hidden_output_sizes=(4,),\n", 186 | " final_nonlinearity=torch.sigmoid)\n", 187 | "train_model(rnn, 'RNN', history)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "colab": { 195 | "base_uri": "https://localhost:8080/", 196 | "height": 419 197 | }, 198 | "colab_type": "code", 199 | "id": "C6nlT9o2Kr4C", 200 | "outputId": "1c4df83b-2145-4dbb-f6b7-3ba129762438" 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "deepsignet = examples.create_simple(output_shape,\n", 205 | " sig=True,\n", 206 | " sig_depth=3,\n", 207 | " augment_layer_sizes=(3,),\n", 208 | " augment_kernel_size=3,\n", 209 | " layer_sizes = (32, 32, 32, 32, 32),\n", 210 | " final_nonlinearity=torch.sigmoid)\n", 211 | "train_model(deepsignet, 'DeepSigNet', history)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "colab": { 219 | "base_uri": "https://localhost:8080/", 220 | "height": 419 221 | }, 222 | "colab_type": "code", 223 | "id": "C6nlT9o2Kr4C", 224 | "outputId": "1c4df83b-2145-4dbb-f6b7-3ba129762438" 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "deepersignet = 
hurst_parameter.deep_recurrent(output_shape, \n", 229 | " sig=True, \n", 230 | " sig_depth=3,\n", 231 | " augment_layer_sizes=(16, 16, 3), \n", 232 | " augment_kernel_size=4,\n", 233 | " lengths=(10, 10, 10), \n", 234 | " strides=(0, 0, 0), \n", 235 | " adjust_lengths=(5, 5, 5),\n", 236 | " layer_sizes_s=((16, 16), (16, 16), (16, 16)), \n", 237 | " memory_sizes=(8, 8, 8),\n", 238 | " hidden_output_sizes=(5, 5),\n", 239 | " final_nonlinearity=torch.sigmoid)\n", 240 | "train_model(deepersignet, 'DeeperSigNet', history)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "colab_type": "text", 247 | "id": "YThvrzJBo--S" 248 | }, 249 | "source": [ 250 | "### GRU, LSTM" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "colab": {}, 258 | "colab_type": "code", 259 | "id": "Ts3zEWQtNW4S" 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "x_train_, x_test_ = hurst_parameter.preprocess_data(x_train, x_test, flag='lstm')\n", 264 | "\n", 265 | "(train_dataloader_lstm, test_dataloader_lstm, \n", 266 | " example_batch_lstm_x, example_batch_lstm_y) = hurst_parameter.generate_torch_batched_data(x_train_, \n", 267 | " y_train,\n", 268 | " x_test_, \n", 269 | " y_test,\n", 270 | " train_batch_size,\n", 271 | " val_batch_size)\n", 272 | "\n", 273 | "train_model_lstm = utils.create_train_model_fn(max_epochs, \n", 274 | " optimizer_fn, \n", 275 | " loss_fn, \n", 276 | " train_dataloader_lstm, \n", 277 | " test_dataloader_lstm, \n", 278 | " example_batch_lstm_x)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": { 285 | "colab": { 286 | "base_uri": "https://localhost:8080/", 287 | "height": 104 288 | }, 289 | "colab_type": "code", 290 | "id": "fzjaWHWWp3IW", 291 | "outputId": "14dc3f03-b34f-46e8-a0dc-76e4284d3e76" 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "lstmnet = hurst_parameter.LSTM(input_dim=1, \n", 296 | " num_layers=2,\n", 297 | " hidden_dim=32,\n", 298 | " output_dim=1,\n", 299 | " final_nonlinearity=torch.sigmoid)\n", 300 | "train_model_lstm(lstmnet, 'LSTM', history)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": { 307 | "colab": { 308 | "base_uri": "https://localhost:8080/", 309 | "height": 87 310 | }, 311 | "colab_type": "code", 312 | "id": "-2bpWU-bqT5L", 313 | "outputId": "ab8beb3d-3a94-4800-af21-70386d90662d" 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "grunet = hurst_parameter.GRU(input_dim=1, \n", 318 | " num_layers=2, \n", 319 | " hidden_dim=32,\n", 320 | " output_dim=1,\n", 321 | " final_nonlinearity=torch.sigmoid)\n", 322 | "train_model_lstm(grunet, 'GRU', history)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "colab_type": "text", 329 | "id": "ZnQK1Eplo--l" 330 | }, 331 | "source": [ 332 | "### Neural-Signature" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": { 339 | "colab": {}, 340 | "colab_type": "code", 341 | "id": "RGb-mqL3Of_H" 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "# generate dataset\n", 346 | "x_train_, x_test_ = hurst_parameter.preprocess_data(x_train, x_test, flag='neuralsig')\n", 347 | "\n", 348 | "# generate torch dataloaders\n", 349 | "(train_dataloader_sig, test_dataloader_sig, \n", 350 | " example_batch_sig_x, example_batch_sig_y) = hurst_parameter.generate_torch_batched_data(x_train_,\n", 351 | " y_train,\n", 352 | " x_test_,\n", 353 | " y_test,\n", 354 
| " train_batch_size,\n", 355 | " val_batch_size)\n", 356 | "\n", 357 | "# trainer function\n", 358 | "train_model_sig = utils.create_train_model_fn(max_epochs, \n", 359 | " optimizer_fn, \n", 360 | " loss_fn, \n", 361 | " train_dataloader_sig, \n", 362 | " test_dataloader_sig, \n", 363 | " example_batch_sig_x)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": { 370 | "colab": { 371 | "base_uri": "https://localhost:8080/", 372 | "height": 752 373 | }, 374 | "colab_type": "code", 375 | "id": "eGDg-jFp-bGD", 376 | "outputId": "97ed12b1-7ced-4d5a-c026-f7235f274934" 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "neuralsig = examples.create_feedforward(output_shape, sig=False, \n", 381 | " layer_sizes=(64, 64, 32, 32, 16, 16),\n", 382 | " final_nonlinearity=torch.sigmoid)\n", 383 | "train_model_sig(neuralsig, 'Neural-Sig', history)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": { 389 | "colab_type": "text", 390 | "id": "NHYyc-qqXasc" 391 | }, 392 | "source": [ 393 | "## Results" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": { 400 | "colab": {}, 401 | "colab_type": "code", 402 | "id": "ObD1hgF2He9c" 403 | }, 404 | "outputs": [], 405 | "source": [ 406 | "params = {}\n", 407 | "for k, m in zip(('DeeperSigNet', 'DeepSigNet', 'Neural-Sig', 'LSTM', 'GRU', 'RNN', 'Feedforward'), \n", 408 | " (deepersignet, deepsignet, neuralsig, lstmnet, grunet, rnn, feedforward)):\n", 409 | " params[k] = utils.count_parameters(m)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "for key in history:\n", 419 | " print('{:12} {:6.4f} {}'.format(key, history[key]['val_loss'][-1], params[key]))" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "# Loss for the non-neural-network mathematically-derived rescaled range method\n", 429 | "rescaled_range_pred = [hurst_parameter.hurst_rescaled_range(x_test_i) for x_test_i in x_test]\n", 430 | "loss_fn(torch.Tensor(rescaled_range_pred), torch.Tensor(y_test))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "# adapted from jet\n", 440 | "colors = np.array([[0.5 , 0.5 , 0.5 , 1. ],\n", 441 | " [0. , 0.06470588, 1. , 1. ],\n", 442 | " [0. , 0.64509804, 1. , 1. ],\n", 443 | " [0.05882352, 0.51764705, 0.17647058, 1. ],\n", 444 | " [0.9 , 0.7 , 0. , 1. ],\n", 445 | " [1. , 0.18954248, 0. , 1. ],\n", 446 | " [0.28627450, 0.18823529, 0.06666666, 1. 
]])\n", 447 | "\n", 448 | "# define pd dataframe for losses\n", 449 | "df_test_log = pd.DataFrame()\n", 450 | "for k in ('Feedforward', 'RNN', 'GRU', 'LSTM', 'Neural-Sig', 'DeepSigNet', 'DeeperSigNet'):\n", 451 | " df_test_log[k] = history[k]['val_loss']\n", 452 | "\n", 453 | "fig, axes = plt.subplots(figsize=(10, 8))\n", 454 | "np.power(np.e, df_test_log.rolling(5).mean()).plot(grid=False, ax=axes, color=colors, lw=1.5, alpha=0.8)\n", 455 | "plt.yscale('log', basey=10)\n", 456 | "axes.set_xlabel('Epoch')\n", 457 | "axes.set_ylabel('Test MSE')\n", 458 | "plt.legend(mode='expand', bbox_to_anchor=(0, 1, 1, 0), ncol=3, prop={'size': 18})\n", 459 | "\n", 460 | "plt.show()" 461 | ] 462 | } 463 | ], 464 | "metadata": { 465 | "accelerator": "GPU", 466 | "colab": { 467 | "collapsed_sections": [], 468 | "include_colab_link": true, 469 | "name": "Copy of example_hurst_exponent.ipynb", 470 | "provenance": [], 471 | "toc_visible": true, 472 | "version": "0.3.2" 473 | }, 474 | "kernelspec": { 475 | "display_name": "Python 3", 476 | "language": "python", 477 | "name": "python3" 478 | }, 479 | "language_info": { 480 | "codemirror_mode": { 481 | "name": "ipython", 482 | "version": 3 483 | }, 484 | "file_extension": ".py", 485 | "mimetype": "text/x-python", 486 | "name": "python", 487 | "nbconvert_exporter": "python", 488 | "pygments_lexer": "ipython3", 489 | "version": "3.7.3" 490 | } 491 | }, 492 | "nbformat": 4, 493 | "nbformat_minor": 1 494 | } 495 | -------------------------------------------------------------------------------- /src/example_reinforcement_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Signatures and RL" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%run base.ipynb\n", 17 | "\n", 18 | "import gym\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import numpy as np\n", 21 | "import pandas as pd\n", 22 | "import sys\n", 23 | "import torch\n", 24 | "import tqdm\n", 25 | "\n", 26 | "import reinforcement_learning\n", 27 | "import utils" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "env = gym.make('MountainCar-v0')\n", 37 | "\n", 38 | "steps = 300\n", 39 | "episodes = 2000" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Random play exploration" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "A random play exploration, unsurprisingly, does not tend to work." 
54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "random_policy = reinforcement_learning.RandomPolicy(env)\n", 63 | "successes = [reinforcement_learning.play(env, random_policy, steps, render=False)[1] \n", 64 | " for _ in tqdm.trange(episodes, file=sys.stdout)]\n", 65 | "print(\"Number of successes: {}\".format(sum(successes)))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Training a policy" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "sigpolicy = reinforcement_learning.SigPolicy(env)\n", 82 | "rnnpolicy = reinforcement_learning.RNNPolicy(env)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "scrolled": false 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "sighistory = reinforcement_learning.train(env, sigpolicy, steps, episodes)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "rnnhistory = reinforcement_learning.train(env, rnnpolicy, steps, episodes)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Plot Results" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "tuple(utils.count_parameters(x) for x in (sigpolicy, rnnpolicy))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "plt.figure(2, figsize=[10,5])\n", 128 | "\n", 129 | "sigp = pd.Series(sighistory[2])\n", 130 | "sigma = sigp.rolling(100).mean()\n", 131 | "plt.plot(sigma, label=\"Signatures\")\n", 132 | "rnnp = pd.Series(rnnhistory[2])\n", 133 | "rnnma = rnnp.rolling(100).mean()\n", 134 | "plt.plot(rnnma, label=\"RNN\")\n", 135 | "\n", 136 | "plt.xlabel('Generation')\n", 137 | "plt.ylabel('Final position')\n", 138 | "plt.legend(mode='expand', bbox_to_anchor=(0, 1, 1, 0), ncol=3, prop={'size': 16})\n", 139 | "plt.show()" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## Play" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "states, success = reinforcement_learning.play(env, sigpolicy, steps, render=True)\n", 156 | "print(f\"Success: {success}\")" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.7.3" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 2 181 | } 182 | -------------------------------------------------------------------------------- /src/generative_model.py: -------------------------------------------------------------------------------- 1 | import candle 2 | import numpy as np 3 | import sdepy 4 | import siglayer 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data as torchdata 8 | 9 | 10 | def 
gen_data(n_points=100):
11 |     """Generate an Ornstein-Uhlenbeck process."""
12 | 
13 |     sde = sdepy.ornstein_uhlenbeck_process()
14 |     timeline = np.linspace(0, 1, n_points)
15 |     values = sde(timeline).flatten()
16 |     path = np.c_[timeline, values.tolist()]
17 |     return path.T
18 | 
19 | 
20 | def gen_noise(n_points=100):
21 |     """Generate a Brownian motion."""
22 | 
23 |     dt = 1 / np.sqrt(n_points)  # each increment has standard deviation sqrt(1 / n_points)
24 |     bm = dt * np.r_[0., np.random.randn(n_points - 1).cumsum()]
25 |     timeline = np.linspace(0, 1, n_points)
26 |     return np.c_[timeline, bm].T
27 | 
28 | 
29 | def get_signal(num_samples=1000, **kwargs):
30 |     """Generate examples of an Ornstein-Uhlenbeck process."""
31 | 
32 |     paths = np.array([gen_data(**kwargs) for _ in range(num_samples)])
33 |     return torchdata.TensorDataset(torch.tensor(paths, dtype=torch.float))
34 | 
35 | 
36 | def get_noise(num_samples=1000, **kwargs):
37 |     """Generate examples of a Brownian motion."""
38 | 
39 |     paths = np.array([gen_noise(**kwargs) for _ in range(num_samples)])
40 |     y = np.zeros_like(paths[:, 0, :-1])
41 |     return torchdata.TensorDataset(torch.tensor(paths, dtype=torch.float), torch.tensor(y, dtype=torch.float))
42 | 
43 | 
44 | def scalar_orders(dim, order):
45 |     """The order of the scalar basis elements as one moves along the signature."""
46 | 
47 |     for i in range(order + 1):
48 |         for _ in range(dim ** i):
49 |             yield i
50 | 
51 | 
52 | def psi(x, M=4, a=1):
53 |     """Psi function, as defined in the following paper:
54 | 
55 |     Chevyrev, I. and Oberhauser, H., 2018. Signature moments to
56 |     characterize laws of stochastic processes. arXiv preprint arXiv:1810.10971.
57 | 
58 |     """
59 | 
60 |     if x <= M:
61 |         return x
62 | 
63 |     return M + M ** (1 + a) * (M ** (-a) - x ** (-a)) / a
64 | 
65 | 
66 | def normalise_instance(x, order):
67 |     """Normalise signature, following the paper
68 | 
69 |     Chevyrev, I. and Oberhauser, H., 2018. Signature moments to
70 |     characterize laws of stochastic processes. arXiv preprint arXiv:1810.10971.
71 | 
72 |     """
73 | 
74 |     x = torch.cat([torch.tensor([1.], device=x.device), x])  # prepend the constant depth-0 term of the signature
75 | 
76 |     a = x ** 2
77 |     a[0] -= psi(torch.norm(x))
78 | 
79 | 
80 |     x0 = 1.  # Starting point for Newton-Raphson
81 | 
82 |     moments = torch.tensor([x0 ** (2 * m) for m in range(len(x))], device=x.device)
83 |     polx0 = torch.dot(a, moments)
84 | 
85 |     d_moments = torch.tensor([2 * m * x0 ** (2 * m - 1) for m in range(len(x))], device=x.device)
86 |     d_polx0 = torch.dot(a, d_moments)
87 |     x1 = x0 - polx0 / d_polx0  # single Newton-Raphson step
88 | 
89 |     if x1 < 0.2:
90 |         x1 = 1.
91 | 
92 |     lambda_ = torch.tensor([x1 ** t for t in scalar_orders(2, order)], device=x.device)
93 | 
94 | 
95 |     return lambda_ * x
96 | 
97 | 
98 | def normalise(x, order):
99 |     """Normalise signature."""
100 | 
101 |     return torch.stack([normalise_instance(sig, order) for sig in x])
102 | 
103 | 
104 | def loss(orig_paths, sig_depth=2, normalise_sigs=True):
105 |     """Loss function is the T statistic defined in
106 | 
107 |     Chevyrev, I. and Oberhauser, H., 2018. Signature moments to
108 |     characterize laws of stochastic processes. arXiv preprint arXiv:1810.10971.
109 | 
110 |     """
111 | 
112 |     sig = siglayer.Signature(sig_depth)
113 |     orig_signatures = sig(orig_paths)
114 |     if normalise_sigs:
115 |         orig_signatures = normalise(orig_signatures, sig_depth)
116 | 
117 |     T1 = torch.mean(torch.mm(orig_signatures, orig_signatures.t()))  # estimates E<S(X), S(X')>
118 | 
119 |     def loss_fn(output, *args):
120 |         nonlocal T1, orig_signatures
121 |         T1 = T1.to(device=output.device)
122 |         orig_signatures = orig_signatures.to(device=output.device)
123 | 
124 |         timeline = torch.tensor(np.linspace(0, 1, output.shape[1] + 1), dtype=torch.float32, device=output.device)
125 |         paths = torch.stack([torch.stack([timeline, torch.cat([torch.tensor([0.], device=output.device), path])])
126 |                              for path in output])
127 | 
128 |         generated_sigs = sig(paths)
129 | 
130 |         if normalise_sigs:
131 |             generated_sigs = normalise(generated_sigs, sig_depth)
132 | 
133 |         T2 = torch.mean(torch.mm(orig_signatures, generated_sigs.t()))  # estimates E<S(X), S(Y)>
134 |         T3 = torch.mean(torch.mm(generated_sigs, generated_sigs.t()))  # estimates E<S(Y), S(Y')>
135 | 
136 |         return torch.log(T1 - 2 * T2 + T3)
137 | 
138 |     return loss_fn
139 | 
140 | 
141 | def create_generative_model():
142 |     return candle.CannedNet((siglayer.Augment((8, 8, 2), 1, include_original=True, include_time=False),
143 |                              candle.Window(2, 0, 1, transformation=siglayer.Signature(3)),
144 |                              siglayer.Augment((1,), 1, include_original=False, include_time=False),
145 |                              candle.batch_flatten  # just squeezing out the channel dimension of length 1
146 |                              ))
--------------------------------------------------------------------------------
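A minimal sketch of how the pieces of `generative_model.py` compose, mirroring the notebook above; it assumes `packages` is on `sys.path` (as `base.ipynb` arranges), and the sizes are made up:

```python
import torch
import generative_model

# Build the loss from 64 sample Ornstein-Uhlenbeck paths, then score a
# random (so presumably poor) batch of candidate paths against them.
signals = generative_model.get_signal(num_samples=64, n_points=100).tensors[0]
loss_fn = generative_model.loss(signals, sig_depth=2, normalise_sigs=True)

candidate = torch.randn(16, 99)  # values only; the loss adds the timeline and a zero start
print(loss_fn(candidate))
```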
/src/hurst_parameter.py:
--------------------------------------------------------------------------------
1 | import candle
2 | import fbm
3 | import iisignature
4 | import numpy as np
5 | import random
6 | import siglayer
7 | import sklearn.base as base
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.data as torchdata
12 | import torchvision.transforms as transforms
13 | 
14 | 
15 | class AddTime(base.BaseEstimator, base.TransformerMixin):
16 |     """Augments the path with time."""
17 |     def __init__(self, init_time=0.):
18 |         self.init_time = init_time
19 | 
20 |     def fit(self, X, y=None):
21 |         return self
22 | 
23 |     def transform_instance(self, X):
24 |         t = np.linspace(self.init_time, self.init_time + 1, len(X))
25 |         return np.c_[t, X]
26 | 
27 |     def transform(self, X, y=None):
28 |         return [self.transform_instance(x) for x in X]
29 | 
30 | 
31 | def generate_fBM(n_paths, n_samples, hurst_exponents):
32 |     """Generate fBM paths."""
33 |     X = []
34 |     y = []
35 |     for j in range(n_paths):
36 |         hurst = random.choice(hurst_exponents)
37 |         X.append(fbm.FBM(n=n_samples, hurst=hurst, length=1, method='daviesharte').fbm())
38 |         y.append(hurst)
39 |     return np.array(X), np.array(y)
40 | 
41 | 
42 | def generate_data(n_paths_train, n_paths_test, n_samples, hurst_exponents):
43 |     """Generate train and test datasets."""
44 | 
45 |     # generate dataset
46 |     x_train, y_train = generate_fBM(n_paths_train, n_samples, hurst_exponents)
47 |     x_test, y_test = generate_fBM(n_paths_test, n_samples, hurst_exponents)
48 | 
49 |     # reshape targets
50 |     y_train = np.expand_dims(y_train, axis=1)
51 |     y_test = np.expand_dims(y_test, axis=1)
52 | 
53 |     return x_train, y_train, x_test, y_test
54 | 
55 | 
56 | def preprocess_data(x_train, x_test, flag=None):
57 |     """Performs model-dependent preprocessing."""
58 |     if flag == 'neuralsig':
59 |         # We don't need to backprop through the signature if we're just building a model on top
60 |         # so we actually perform the signature here as a feature
transformation, rather than in 61 | # the model. 62 | path_transform = AddTime() 63 | x_train = np.array([iisignature.sig(x, 4) for x in path_transform.fit_transform(x_train)]) 64 | x_test = np.array([iisignature.sig(x, 4) for x in path_transform.fit_transform(x_test)]) 65 | elif flag == 'lstm': 66 | # LSTM wants another dimension in one place... 67 | x_train = np.expand_dims(x_train, 2) 68 | x_test = np.expand_dims(x_test, 2) 69 | else: 70 | # ...everyone else wants the extra dimension in another 71 | x_train = np.expand_dims(x_train, 1) 72 | x_test = np.expand_dims(x_test, 1) 73 | return x_train, x_test 74 | 75 | 76 | def generate_torch_batched_data(x_train, y_train, x_test, y_test, train_batch_size, test_batch_size): 77 | """Generate torch dataloaders""" 78 | 79 | # make torch dataset 80 | train_dataset = torchdata.TensorDataset(torch.tensor(x_train, dtype=torch.float), torch.tensor(y_train, dtype=torch.float)) 81 | test_dataset = torchdata.TensorDataset(torch.tensor(x_test, dtype=torch.float), torch.tensor(y_test, dtype=torch.float)) 82 | 83 | # process with torch dataloader 84 | train_dataloader = torchdata.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8) 85 | test_dataloader = torchdata.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=8) 86 | 87 | example_batch_x, example_batch_y = next(iter(train_dataloader)) 88 | 89 | return train_dataloader, test_dataloader, example_batch_x, example_batch_y 90 | 91 | 92 | def hurst_rescaled_range(ts): 93 | """Uses the rescaled range method to estimate the Hurst parameter.""" 94 | 95 | # calculate standard deviation of differenced series using various lags 96 | lags = range(2, 20) 97 | tau = [np.sqrt(np.std(np.subtract(ts[lag:], ts[:-lag]))) for lag in lags] 98 | # calculate Hurst as slope of log-log plot 99 | m = np.polyfit(np.log(lags), np.log(tau), 1) 100 | hurst = m[0]*2.0 101 | return hurst 102 | 103 | 104 | class LSTM(nn.Module): 105 | def __init__(self, input_dim, hidden_dim, num_layers, output_dim, final_nonlinearity=lambda x: x, **kwargs): 106 | super(LSTM, self).__init__(**kwargs) 107 | 108 | self.mod = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True) 109 | self.fc = nn.Linear(hidden_dim, output_dim) 110 | self.final = final_nonlinearity 111 | 112 | def forward(self, x): 113 | out, _ = self.mod(x) 114 | out = out[:, -1, :] 115 | return self.final(self.fc(out)) 116 | 117 | 118 | class GRU(nn.Module): 119 | def __init__(self, input_dim, hidden_dim, num_layers, output_dim, final_nonlinearity=lambda x: x, **kwargs): 120 | super(GRU, self).__init__(**kwargs) 121 | 122 | self.mod = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True) 123 | self.fc = nn.Linear(hidden_dim, output_dim) 124 | self.final = final_nonlinearity 125 | 126 | def forward(self, x): 127 | out, _ = self.mod(x) 128 | out = out[:, -1, :] 129 | return self.final(self.fc(out)) 130 | 131 | 132 | # Now THIS is deep signatures! 133 | def deep_recurrent(output_shape, sig=True, sig_depth=4, final_nonlinearity=lambda x: x, 134 | augment_layer_sizes=(32, 32, 2), augment_kernel_size=8, augment_include_original=True, 135 | augment_include_time=True, 136 | lengths=(5, 5, 10), strides=(1, 1, 5), adjust_lengths=(0, 0, 0), memory_sizes=(8, 8, 8), 137 | layer_sizes_s=((32,), (32,), (32, 16)), hidden_output_sizes=(8, 8)): 138 | """This model stacks multiple layers of signatures on top of one another in a natural way. 
139 | 140 | - Augment the features with something learnable 141 | - Slide a window across the augmented features 142 | - Take the signature of each window 143 | - Put this list of signatures back together to recover the path dimension 144 | - Apply an RNN across the path dimension, preserving the intermediate outputs, so the path dimension is preserved 145 | - Slide another window 146 | - Take another signature 147 | - Reassemble signatures along path dimension 148 | - Another RNN 149 | - ... 150 | - etc. for some number of times 151 | - ... 152 | - Slide another window 153 | - Take another signature 154 | - Reassemble signatures along path dimension 155 | - Another RNN; this time throw away intermediate outputs and just present the final output as the overall output. 156 | If :sig: is falsy then the signature layers will be replaced with flattening instead. 157 | It expects input tensors of three dimensions: (batch, channels, length). 158 | 159 | For a simpler example in the same vein, see siglayer.examples.create_windowed. 160 | 161 | Arguments: 162 | output_shape: The final output shape from the network. 163 | sig: Optional, whether to use signatures in the network. If True a signature will be applied between each 164 | window. If False then the output is simply flattened. Defaults to True. 165 | sig_depth: Optional. If signatures are used, then this specifies how deep they should be truncated to. 166 | final_nonlinearity: Optional. What final nonlinearity to feed the final tensors of the network through, e.g. a 167 | sigmoid when desiring output between 0 and 1. Defaults to the identity. 168 | augment_layer_sizes: Optional. A tuple of integers specifying the size of the hidden layers of the feedforward 169 | network that is swept across the input stream to augment it. May be set to the empty tuple to do no 170 | augmentation. 171 | augment_kernel_size: Optional. How far into the past the swept feedforward network (that is doing augmenting) 172 | should take inputs from. For example if this is 1 then it will just take data from a single 'time', making 173 | it operate in a 'pointwise' manner. If this is 2 then it will take the present and the most recent piece of 174 | past information, and so on. 175 | augment_include_original: Optional. Whether to include the original path in the augmentation. 176 | augment_include_time: Optional. Whether to include an increasing 'time' parameter in the augmentation. 177 | lengths, strides, adjust_lengths, memory_sizes: Optional. Should each be a tuple of integers, all of the same 178 | length as one another. The length of these arguments determines the number of windows; this length must be 179 | at least one. The ith values determine the length, stride and adjust_length arguments of the ith Window, 180 | and the size of the memory of the ith RNN. 181 | layer_sizes_s: Optional. Should be a tuple of the same length as lengths, strides, adjust_lengths, 182 | memory_sizes. Each element of the tuple should itself be a tuple of integers specifying the sizes of the 183 | hidden layers of each RNN. 184 | hidden_output_sizes: Optional. Should be a tuple of integers one shorter than the length of lengths, strides, 185 | adjust_lengths, memory_sizes. It determines the output size of each RNN. It is of a slightly shorter length 186 | because the final output size is actually already determined by the output_shape argument! 
187 |     """
188 | 
189 |     num_windows = len(lengths)
190 |     assert num_windows >= 1
191 |     assert len(strides) == num_windows
192 |     assert len(adjust_lengths) == num_windows
193 |     assert len(layer_sizes_s) == num_windows
194 |     assert len(memory_sizes) == num_windows
195 |     assert len(hidden_output_sizes) == num_windows - 1
196 | 
197 |     if sig:
198 |         transformation = siglayer.Signature(depth=sig_depth)
199 |     else:
200 |         transformation = lambda x: candle.batch_flatten(x.contiguous())
201 | 
202 |     final_output_size = torch.Size(output_shape).numel()
203 |     output_sizes = (*hidden_output_sizes, final_output_size)
204 | 
205 |     recurrent_layers = []
206 |     for (i, length, stride, adjust_length, layer_sizes, memory_size, output_size
207 |          ) in zip(range(num_windows), lengths, strides, adjust_lengths, layer_sizes_s, memory_sizes, output_sizes):
208 | 
209 |         window_layers = []
210 |         for layer_size in layer_sizes:
211 |             window_layers.append(layer_size)
212 |             window_layers.append(F.relu)
213 | 
214 |         intermediate_outputs = (num_windows - 1 != i)  # keep per-step outputs everywhere except after the final window
215 | 
216 |         recurrent_layers.append(candle.Window(length=length, stride=stride, adjust_length=adjust_length,
217 |                                               transformation=transformation))
218 |         recurrent_layers.append(candle.Recur(module=candle.CannedNet((candle.Concat(),
219 |                                                                       *window_layers,
220 |                                                                       memory_size + output_size,
221 |                                                                       candle.Split((memory_size, output_size)))),
222 |                                              memory_shape=(memory_size,),
223 |                                              intermediate_outputs=intermediate_outputs))
224 | 
225 |     return candle.CannedNet((siglayer.Augment(layer_sizes=augment_layer_sizes,
226 |                                               kernel_size=augment_kernel_size,
227 |                                               include_original=augment_include_original,
228 |                                               include_time=augment_include_time),
229 |                              *recurrent_layers,
230 |                              candle.View(output_shape),
231 |                              final_nonlinearity))
--------------------------------------------------------------------------------
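As a quick sanity check of the non-neural baseline in `hurst_parameter.py`, one can generate a single fractional Brownian motion with a known exponent and see whether the estimator roughly recovers it; a sketch, with made-up sizes:

```python
import fbm
import hurst_parameter

# One fBM path with known Hurst parameter H = 0.7...
path = fbm.FBM(n=300, hurst=0.7, length=1, method='daviesharte').fbm()
# ...which the rescaled-range-style estimator should approximately recover.
print(hurst_parameter.hurst_rescaled_range(path))

# The same generator drives the train/test datasets used by the notebook.
x_train, y_train, x_test, y_test = hurst_parameter.generate_data(
    n_paths_train=60, n_paths_test=10, n_samples=300,
    hurst_exponents=[0.2, 0.5, 0.8])
print(x_train.shape, y_train.shape)  # (60, 301), (60, 1)
```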
/src/reinforcement_learning.py:
--------------------------------------------------------------------------------
1 | import candle
2 | import numpy as np
3 | import siglayer
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | import tqdm
10 | 
11 | 
12 | def play(env, policy, steps, render=True):
13 |     state_sequence = []
14 |     # We're going to ignore velocity, and record a normalised timestep in its place
15 |     position, velocity = env.reset()
16 |     state = np.array([position, 0.])  # Record the timestep
17 | 
18 |     for s in range(steps):
19 |         state_sequence.append(state)
20 |         if render:
21 |             env.render()
22 | 
23 |         state_seq_var = torch.tensor(state_sequence, requires_grad=True, dtype=torch.float).t().unsqueeze(0)
24 | 
25 |         # select action
26 |         Q = policy(state_seq_var)
27 |         _, action = torch.max(Q, -1)
28 |         action = action.item()
29 | 
30 |         # take action
31 |         state, _, done, _ = env.step(action)
32 |         position, velocity = state
33 |         state = np.array([position, (s + 1) / steps])  # Record the timestep
34 | 
35 |         if done:
36 |             break
37 | 
38 |     return np.array(state_sequence), state[0] > 0.5
39 | 
40 | 
41 | def train(env, policy, steps, episodes, epsilon=0.2, gamma=0.99, learning_rate=0.001):
42 |     successes = 0
43 |     max_position = -0.4
44 | 
45 |     loss_history = []
46 |     reward_history = []
47 |     position_history = []
48 | 
49 |     # call policy on example inputs to fully specify model
50 |     # (needed for things using candle.NoInputSpec, such as siglayer.Augment, which is
51 |     # used in the signature policy)
52 |     state_seq_var = torch.tensor([[env.reset()[0], 0.]], dtype=torch.float).t().unsqueeze(0)
53 |     policy(state_seq_var)
54 | 
55 |     loss_fn = nn.MSELoss()
56 |     optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
57 |     # note: this 0.9 is the learning rate decay, not the discount factor gamma above
58 |     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
59 | 
60 |     pbar = tqdm.trange(episodes, file=sys.stdout)
61 |     for episode in pbar:
62 |         pbar.set_description("{} successes".format(successes))
63 |         episode_loss = 0
64 |         episode_reward = 0
65 | 
66 |         state_sequence = []
67 |         position, velocity = env.reset()
68 |         state = np.array([position, 0.])
69 | 
70 |         for s in range(steps):
71 |             state_sequence.append(state)
72 | 
73 |             # wrap into batch of length 1
74 |             state_seq_var = torch.tensor(state_sequence, requires_grad=True, dtype=torch.float).t().unsqueeze(0)
75 | 
76 |             # Choose action epsilon-greedily
77 |             Q = policy(state_seq_var)[0]  # unwrap from batch dimension
78 |             if np.random.rand(1) < epsilon:
79 |                 action = np.random.randint(0, 3)
80 |             else:
81 |                 _, action = torch.max(Q, -1)
82 |                 action = action.item()
83 | 
84 |             # take action
85 |             state_1, reward, done, _ = env.step(action)
86 |             position_1, velocity_1 = state_1
87 |             state_1 = np.array([position_1, (s + 1) / steps])
88 | 
89 |             # keep track of max position
90 |             max_position = max(position_1, max_position)
91 | 
92 |             # Increase reward for task completion
93 |             reward = position_1
94 |             if position_1 >= 0.5:
95 |                 reward += 1
96 | 
97 |             # Find max Q for t+1 state
98 |             state_sequence_1 = state_sequence + [state_1]
99 |             # wrap into batch of length 1
100 |             state_seq_var_1 = torch.tensor(state_sequence_1, dtype=torch.float).t().unsqueeze(0)
101 |             Q1 = policy(state_seq_var_1)[0]  # unwrap from batch dimension
102 |             maxQ1, _ = torch.max(Q1, -1)
103 | 
104 |             # Create target Q value for training the policy
105 |             Q_target = reward + torch.mul(maxQ1, gamma)
106 |             Q_target.detach_()
107 | 
108 |             # Calculate loss
109 |             loss = loss_fn(Q[action], Q_target)
110 | 
111 |             # Update policy
112 |             policy.zero_grad()
113 |             loss.backward()
114 |             optimizer.step()
115 | 
116 |             episode_loss += loss.item()
117 |             episode_reward += reward
118 | 
119 |             if done:
120 |                 if position_1 >= 0.5:
121 |                     # On successful episodes, adjust the following parameters
122 |                     epsilon *= .95
123 |                     scheduler.step()
124 |                     successes += 1
125 | 
126 |                 # Record history
127 |                 loss_history.append(episode_loss)
128 |                 reward_history.append(episode_reward)
129 |                 position_history.append(position_1)
130 |                 break
131 |             else:
132 |                 state = state_1
133 | 
134 |         if episode % 1000 == 0 and episode > 0:
135 |             optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
136 |             # note: this 0.9 is the learning rate decay, not the discount factor gamma above
137 |             scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
138 | 
139 |     print('successful episodes: {:d} - {:.4f}%'.format(successes, successes / episodes * 100))
140 | 
141 |     return loss_history, reward_history, position_history
142 | 
143 | 
144 | class RandomPolicy(nn.Module):
145 |     def __init__(self, env, **kwargs):
146 |         super(RandomPolicy, self).__init__(**kwargs)
147 |         self.action_space = env.action_space.n
148 | 
149 |     def forward(self, seq):
150 |         r = np.random.randint(0, self.action_space)
151 |         return torch.eye(self.action_space)[r]
152 | 
153 | 
154 | class SigPolicy(nn.Module):
155 |     def __init__(self, env, sig_depth=3, **kwargs):
156 |         super(SigPolicy, self).__init__(**kwargs)
157 | 
158 |         channels = 2
159 |         self.augmentation = siglayer.Augment((channels,), 1, include_original=True, include_time=False)
160 |         self.sig = siglayer.Signature(sig_depth)
161 |         self.l1 = nn.Linear(siglayer.sig_dim(channels + 2, sig_depth) + 2, 64)
162 |         self.l2 = nn.Linear(64, env.action_space.n)
163 | 
164 |     def forward(self, seq):
165 |         x = self.augmentation(seq)
166 |         x = self.sig(x)
167 |         x = torch.cat([x, seq[:, :, -1]], dim=-1)  # append the most recent raw state to the signature features
168 |         x = self.l1(x)
169 |         x = F.relu(x)
170 |         return self.l2(x)
171 | 
172 | 
173 | class RNNPolicy(nn.Module):
174 |     def __init__(self, env, **kwargs):
175 |         super(RNNPolicy, self).__init__(**kwargs)
176 | 
177 |         self.rnn = nn.RNN(2, 32, 3, nonlinearity="relu", batch_first=True)
178 |         self.fc1 = nn.Linear(32, env.action_space.n)
179 | 
180 |     def forward(self, seq):
181 |         # seq.shape == (batch, feature, seq)
182 |         seq = seq.transpose(1, 2)
183 |         # seq.shape == (batch, seq, feature)
184 |         out, _ = self.rnn(seq)
185 |         out = out[:, -1, :]
186 | 
187 |         out = self.fc1(out)
188 |         return out
--------------------------------------------------------------------------------
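The core of `train` above is a one-step Q-learning update: the target is r + gamma * max_a' Q(s', a'), detached so that gradients only flow through the prediction. Distilled into an illustrative helper (`q_learning_target` is not part of the module):

```python
import torch

def q_learning_target(reward, next_q_values, gamma=0.99):
    # r + gamma * max_a' Q(s', a'), treated as a constant during backprop
    max_next_q, _ = torch.max(next_q_values, -1)
    return (reward + gamma * max_next_q).detach()

next_q = torch.tensor([0.1, 0.4, -0.2])  # Q-values for the three MountainCar actions
print(q_learning_target(1.0, next_q))    # tensor(1.3960)
```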
/src/signature_inversion.py:
--------------------------------------------------------------------------------
1 | import iisignature
2 | import numpy as np
3 | import siglayer
4 | import torch
5 | import torch.nn as nn
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | def _get_tree_reduced_steps(X, order=4, steps=4, tol=0.1):
10 |     if len(X) < steps:
11 |         return X
12 | 
13 |     dim = X.shape[1]
14 | 
15 |     for i in range(steps - 1, len(X)):
16 |         new_path = X[i - steps + 1 : i + 1]
17 |         new_path2 = np.r_[X[i - steps + 1].reshape(-1, dim), X[i].reshape(-1, dim)]
18 | 
19 |         new_path_sig = iisignature.sig(new_path, order)
20 |         new_path2_sig = iisignature.sig(new_path2, order)
21 | 
22 |         norm = np.linalg.norm(new_path_sig - new_path2_sig)
23 |         if norm < tol:
24 |             return _get_tree_reduced_steps(np.r_[X[:i - steps + 2], X[i:]], order, steps, tol)  # recurse with the same parameters
25 | 
26 |     return X
27 | 
28 | def get_tree_reduced(X, order=4, tol=0.1):
29 |     """Removes tree-like pieces of the path."""
30 | 
31 |     X = np.r_[X, [X[-1]]]
32 | 
33 |     for step in range(3, len(X) + 1):
34 |         X = _get_tree_reduced_steps(X, order, step, tol)
35 | 
36 |     if (X[-1] == X[-2]).all():
37 |         return X[:-1]
38 | 
39 |     return X
40 | 
41 | 
42 | def loss_fn(order):
43 |     normalisation = torch.tensor([np.floor(np.log(i + 1) / np.log(2))
44 |                                   for i in range(1, 2 ** (order + 1) - 1)],
45 |                                  dtype=torch.float)
46 |     def loss(output, target):
47 |         output = output * normalisation  # out-of-place, so the caller's tensors are not mutated between calls
48 |         target = target * normalisation
49 | 
50 |         return torch.log(((output - target) ** 2).mean())
51 | 
52 |     return loss
53 | 
54 | 
55 | class Invert(nn.Module):
56 |     """Given a signature, we build a neural network that learns the inverse of that signature."""
57 | 
58 |     def __init__(self, n_steps, order, derivatives=False, **kwargs):
59 |         super(Invert, self).__init__(**kwargs)
60 | 
61 |         self.n_steps = n_steps
62 |         self.order = order
63 |         self.derivatives = derivatives
64 | 
65 |         self.path = nn.Linear(1, 2 * n_steps, bias=False)
66 |         self.sig = siglayer.Signature(self.order)
67 | 
68 |     def _forward(self, x):
69 |         x = self.path(x)
70 |         if self.derivatives:
71 |             x = torch.cumsum(x, 1)
72 |         return x.view(x.size(0), 2, self.n_steps)
73 | 
74 |     def forward(self, x):
75 |         x = self._forward(x)
76 |         return self.sig(x)
77 | 
78 |     def get_path(self):
79 |         x = torch.ones(1, 1, 1)
80 |         x = self._forward(x)
81 |         return np.array(x.detach().numpy()[0].T, dtype=float)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import candle
2 | import ignite.engine as engine
3 | import ignite.metrics as ignite_metrics
4 | import sys
5 | import time
6 | import torch
7 | import tqdm
8 | 
9 | 
10 | def
create_train_model_fn(max_epochs, optimizer_fn, loss_fn, train_dataloader, eval_dataloader, example_batch_x): 11 | 12 | def train_model(model, name, history, device=None): 13 | # Initialise all layers in model before passing parameters to optimizer 14 | # (necessary with the candle framework) 15 | model(example_batch_x) 16 | optimizer = optimizer_fn(model.parameters()) 17 | 18 | history[name] = {'train_loss': [], 'train_mse': [], 'val_loss': [], 'val_mse': []} 19 | 20 | if device not in ('cuda', 'cpu'): 21 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 22 | 23 | trainer = candle.create_supervised_trainer(model, optimizer, loss_fn, check_nan=True, grad_clip=1.0, device=device) 24 | evaluator = engine.create_supervised_evaluator(model, device=device, 25 | metrics={'mse': ignite_metrics.MeanSquaredError(), 26 | 'loss': ignite_metrics.Loss(loss_fn)}) 27 | 28 | log_interval = 10 29 | desc = "Epoch: {:4}{:12}" 30 | num_batches = len(train_dataloader) 31 | 32 | @trainer.on(engine.Events.STARTED) 33 | def log_results(trainer): 34 | 35 | # training 36 | evaluator.run(train_dataloader) 37 | train_mse = evaluator.state.metrics['mse'] 38 | train_loss = evaluator.state.metrics['loss'] 39 | 40 | # testing 41 | evaluator.run(eval_dataloader) 42 | val_mse = evaluator.state.metrics['mse'] 43 | val_loss = evaluator.state.metrics['loss'] 44 | 45 | 46 | tqdm.tqdm.write("train mse: {:5.4f} --- train loss: {:5.4f} --- val mse: {:5.4f} --- val loss: {:5.4f}" 47 | .format(train_mse, train_loss, val_mse, val_loss), file=sys.stdout) 48 | 49 | model_history = history[name] 50 | model_history['train_loss'].append(train_loss) 51 | model_history['train_mse'].append(train_mse) 52 | model_history['val_loss'].append(val_loss) 53 | model_history['val_mse'].append(val_mse) 54 | 55 | @trainer.on(engine.Events.EPOCH_STARTED) 56 | def create_pbar(trainer): 57 | trainer.state.pbar = tqdm.tqdm(initial=0, total=num_batches, desc=desc.format(trainer.state.epoch, ''), 58 | file=sys.stdout) 59 | 60 | @trainer.on(engine.Events.ITERATION_COMPLETED) 61 | def log_training_loss(trainer): 62 | iteration = (trainer.state.iteration - 1) % len(train_dataloader) + 1 63 | if iteration % log_interval == 0: 64 | trainer.state.pbar.desc = desc.format(trainer.state.epoch, ' Loss: {:5.4f}'.format(trainer.state.output)) 65 | trainer.state.pbar.update(log_interval) 66 | 67 | @trainer.on(engine.Events.EPOCH_COMPLETED) 68 | def log_results_(trainer): 69 | trainer.state.pbar.n = num_batches 70 | trainer.state.pbar.last_print_n = num_batches 71 | trainer.state.pbar.refresh() 72 | trainer.state.pbar.close() 73 | log_results(trainer) 74 | 75 | start = time.time() 76 | trainer.run(train_dataloader, max_epochs=max_epochs) 77 | end = time.time() 78 | tqdm.tqdm.write("Training took {:.2f} seconds.".format(end - start), file=sys.stdout) 79 | 80 | return train_model 81 | 82 | 83 | def count_parameters(model): 84 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 85 | --------------------------------------------------------------------------------
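Finally, a toy end-to-end sketch of how `create_train_model_fn` is meant to be driven, mirroring the notebooks; the model and data here are hypothetical stand-ins, and `packages` is assumed to be on `sys.path`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as torchdata

import utils  # src/utils.py

# Toy stand-ins for the real datasets used in the examples.
x = torch.rand(256, 1, 100)  # (batch, channels, length)
y = torch.rand(256, 1)
dataset = torchdata.TensorDataset(x, y)
train_dataloader = torchdata.DataLoader(dataset, batch_size=128, shuffle=True)
eval_dataloader = torchdata.DataLoader(dataset, batch_size=128, shuffle=False)
example_batch_x, _ = next(iter(train_dataloader))

class Baseline(nn.Module):
    def __init__(self):
        super(Baseline, self).__init__()
        self.fc = nn.Linear(100, 1)

    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

history = {}
train_model = utils.create_train_model_fn(10, optim.Adam, F.mse_loss,
                                          train_dataloader, eval_dataloader,
                                          example_batch_x)
train_model(Baseline(), 'Baseline', history)
print(history['Baseline']['val_loss'][-1])
```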