├── README.md ├── example.ipynb ├── setup.py └── torch_bandpass ├── __init__.py ├── dct.py └── layer.py /README.md: -------------------------------------------------------------------------------- 1 | # torch-bandpass 2 | 3 | This is an implementation of the [Prism layer](https://arxiv.org/abs/2011.04823), a DCT-based bandpass filter suitable for transformer sequence models. 4 | 5 | ## Usage 6 | 7 | See [example.ipynb](example.ipynb) for a complete example. Basic usage looks like this (the snippet assumes `import torch`, `from torch_bandpass import Prism`, and a `BATCH_SIZE` of your choosing): 8 | 9 | ```python 10 | seq_len = 512 # number of timesteps per sequence 11 | d_model = 768 # number of feature channels in the transformer 12 | 13 | # Create a Prism layer, which only needs to know about the 14 | # total sequence length and how you want to split up features. 15 | layer = Prism(seq_len, mid_periods=(2, 8, 32, 256)) 16 | 17 | # random [N x T x C] tensor. 18 | input_sequence = torch.randn(BATCH_SIZE, seq_len, d_model) 19 | 20 | # output is the same shape as input_sequence 21 | output_sequence = layer(input_sequence) 22 | ``` 23 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.7.3-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python37364bit2d832ff4253d4ad2949ee022c04a1bde", 18 | "display_name": "Python 3.7.3 64-bit" 19 | } 20 | }, 21 | "nbformat": 4, 22 | "nbformat_minor": 2, 23 | "cells": [ 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import matplotlib.pyplot as plt\n", 31 | "import numpy as np\n", 32 | "import torch as th\n", 33 | "\n", 34 | "from 
torch_bandpass import Prism" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "layer = Prism(64, mid_periods=(2, 8))\n", 44 | "data = th.randn(1, 64) + th.arange(64) / 8" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "output_type": "display_data", 54 | "data": { 55 | "text/plain": "
", 56 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 57 | "image/png": "\n" 58 | }, 59 | "metadata": { 60 | "needs_background": "light" 61 | } 62 | } 63 | ], 64 | "source": [ 65 | "repeated = data[..., None].repeat(1, 1, 3)\n", 66 | "filtered = layer(repeated)\n", 67 | "\n", 68 | "plt.plot(data[0].numpy())\n", 69 | "plt.plot(filtered[0, :, 0].numpy(), '--')\n", 70 | "plt.plot(filtered[0, :, 1].numpy(), '--')\n", 71 | "plt.plot(filtered[0, :, 2].numpy(), '--')\n", 72 | "plt.legend(['signal', 'low-pass', 'mid-pass', 'high-pass'])\n", 73 | "plt.show()" 74 | ] 75 | } 76 | ] 77 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="torch-bandpass", 5 | version="1.0.0", 6 | description="A PyTorch implementation of the Prism filter for Transformers", 7 | 
url="https://github.com/unixpickle/torch-bandpass", 8 | author="Alex Nichol", 9 | author_email="unixpickle@gmail.com", 10 | license="BSD", 11 | packages=["torch_bandpass"], 12 | install_requires=["numpy", "torch"], 13 | ) 14 | -------------------------------------------------------------------------------- /torch_bandpass/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of the Prism algorithm from https://arxiv.org/abs/2011.04823. 3 | """ 4 | 5 | from .layer import Prism 6 | 7 | __all__ = ["Prism"] 8 | -------------------------------------------------------------------------------- /torch_bandpass/dct.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def create_dct_matrix(num_samples): 5 | """ 6 | Create a matrix that can be left-multiplied to perform the discrete cosine 7 | transform. 8 | """ 9 | result = np.zeros([num_samples, num_samples], dtype=np.float32) 10 | result += np.pi / num_samples 11 | result *= np.arange(num_samples, dtype=np.float32) + 0.5 12 | result *= np.arange(num_samples, dtype=np.float32)[:, None] 13 | result = np.cos(result) 14 | return result 15 | 16 | 17 | def create_dct_inverse(num_samples): 18 | """ 19 | Create the inverse of create_dct_matrix(). 20 | """ 21 | return np.linalg.inv(create_dct_matrix(num_samples)) 22 | 23 | 24 | def dct_period_to_bin(num_samples, period): 25 | """ 26 | Get the DCT bin index closest to the given period. 
27 | """ 28 | return round(num_samples / (2 * period)) 29 | -------------------------------------------------------------------------------- /torch_bandpass/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .dct import create_dct_matrix, create_dct_inverse, dct_period_to_bin 5 | 6 | 7 | class Prism(nn.Module): 8 | """ 9 | A Prism module performs different bandpass filters on different subsets of 10 | the features in a batch of activations. 11 | 12 | :param num_samples: the number of timesteps to expect in the inputs. 13 | :param mid_periods: the periods separating the bands. Periods of 1 and 14 | infinity are implied at the two extremes. 15 | """ 16 | 17 | def __init__(self, num_samples, mid_periods=(2, 8, 32, 256)): 18 | super().__init__() 19 | self.num_samples = num_samples 20 | self.mid_periods = mid_periods 21 | bins = ( 22 | [0] 23 | + [dct_period_to_bin(num_samples, p) for p in mid_periods[::-1]] 24 | + [num_samples] 25 | ) 26 | self.bands = nn.ModuleList([]) 27 | for min_index, max_index in zip(bins, bins[1:]): 28 | self.bands.append(Bandpass(num_samples, min_index, max_index)) 29 | 30 | def forward(self, x): 31 | """ 32 | Apply the Prism layer to a batch of sequences. 33 | 34 | :param x: an [N x T x C] Tensor, where C is the number of features, 35 | T is the number of timesteps, and N is the batch size. 
36 | """ 37 | n, t, _c = x.shape 38 | assert t == self.num_samples 39 | 40 | x = x.permute(2, 0, 1).contiguous() # put C on the outer dimension 41 | 42 | chunks_in = _split_up_chunks(x, len(self.bands)) 43 | chunks_out = [] 44 | for bandpass, chunk in zip(self.bands, chunks_in): 45 | chunk = chunk.reshape(-1, t) 46 | chunk = bandpass(chunk) 47 | chunk = chunk.reshape(-1, n, t) 48 | chunks_out.append(chunk) 49 | joined_out = torch.cat(chunks_out, dim=0) 50 | joined_out = joined_out.permute( 51 | 1, 2, 0 52 | ).contiguous() # put C back as the inner dimension 53 | return joined_out 54 | 55 | 56 | class Bandpass(nn.Module): 57 | def __init__(self, num_samples, min_index, max_index): 58 | super().__init__() 59 | dct = create_dct_matrix(num_samples)[min_index:max_index] 60 | inv = create_dct_inverse(num_samples)[:, min_index:max_index] 61 | self.register_buffer("forward_dct", torch.from_numpy(dct.T)) 62 | self.register_buffer("backward_dct", torch.from_numpy(inv.T)) 63 | 64 | def forward(self, x): 65 | return (x @ self.forward_dct) @ self.backward_dct 66 | 67 | 68 | def _split_up_chunks(batch, num_chunks): 69 | chunk_size = batch.shape[0] // num_chunks 70 | num_larger_chunks = batch.shape[0] % num_chunks 71 | start_idx = 0 72 | result = [] 73 | for i in range(num_chunks): 74 | this_chunk_size = chunk_size 75 | if i < num_larger_chunks: 76 | this_chunk_size += 1 77 | result.append(batch[start_idx : start_idx + this_chunk_size]) 78 | start_idx += this_chunk_size 79 | return result 80 | --------------------------------------------------------------------------------