├── .gitignore ├── .gitlab-ci.yml ├── LICENSE ├── README.md ├── example.py ├── s5 ├── __init__.py ├── init.py ├── jax_bench.py ├── jax_compat.py └── s5_model.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info/ 3 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | include: 2 | - remote: https://git.devdroplets.com/Ryan/ci-cd-templates/-/raw/main/py-packaging/gitlab-build.yml 3 | 4 | stages: 5 | - build 6 | 7 | build-gitlab: 8 | extends: .gitlab-build 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License Version 2.0 2 | ================================== 3 | 4 | 1. Definitions 5 | -------------- 6 | 7 | 1.1. "Contributor" 8 | means each individual or legal entity that creates, contributes to 9 | the creation of, or owns Covered Software. 10 | 11 | 1.2. "Contributor Version" 12 | means the combination of the Contributions of others (if any) used 13 | by a Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | means Covered Software of a particular Contributor. 17 | 18 | 1.4. "Covered Software" 19 | means Source Code Form to which the initial Contributor has attached 20 | the notice in Exhibit A, the Executable Form of such Source Code 21 | Form, and Modifications of such Source Code Form, in each case 22 | including portions thereof. 23 | 24 | 1.5. "Incompatible With Secondary Licenses" 25 | means 26 | 27 | (a) that the initial Contributor has attached the notice described 28 | in Exhibit B to the Covered Software; or 29 | 30 | (b) that the Covered Software was made available under the terms of 31 | version 1.1 or earlier of the License, but not also under the 32 | terms of a Secondary License. 33 | 34 | 1.6. "Executable Form" 35 | means any form of the work other than Source Code Form. 36 | 37 | 1.7. "Larger Work" 38 | means a work that combines Covered Software with other material, in 39 | a separate file or files, that is not Covered Software. 40 | 41 | 1.8. "License" 42 | means this document. 43 | 44 | 1.9. "Licensable" 45 | means having the right to grant, to the maximum extent possible, 46 | whether at the time of the initial grant or subsequently, any and 47 | all of the rights conveyed by this License. 48 | 49 | 1.10. "Modifications" 50 | means any of the following: 51 | 52 | (a) any file in Source Code Form that results from an addition to, 53 | deletion from, or modification of the contents of Covered 54 | Software; or 55 | 56 | (b) any new file in Source Code Form that contains any Covered 57 | Software. 58 | 59 | 1.11. "Patent Claims" of a Contributor 60 | means any patent claim(s), including without limitation, method, 61 | process, and apparatus claims, in any patent Licensable by such 62 | Contributor that would be infringed, but for the grant of the 63 | License, by the making, using, selling, offering for sale, having 64 | made, import, or transfer of either its Contributions or its 65 | Contributor Version. 66 | 67 | 1.12. "Secondary License" 68 | means either the GNU General Public License, Version 2.0, the GNU 69 | Lesser General Public License, Version 2.1, the GNU Affero General 70 | Public License, Version 3.0, or any later versions of those 71 | licenses. 
72 | 73 | 1.13. "Source Code Form" 74 | means the form of the work preferred for making modifications. 75 | 76 | 1.14. "You" (or "Your") 77 | means an individual or a legal entity exercising rights under this 78 | License. For legal entities, "You" includes any entity that 79 | controls, is controlled by, or is under common control with You. For 80 | purposes of this definition, "control" means (a) the power, direct 81 | or indirect, to cause the direction or management of such entity, 82 | whether by contract or otherwise, or (b) ownership of more than 83 | fifty percent (50%) of the outstanding shares or beneficial 84 | ownership of such entity. 85 | 86 | 2. License Grants and Conditions 87 | -------------------------------- 88 | 89 | 2.1. Grants 90 | 91 | Each Contributor hereby grants You a world-wide, royalty-free, 92 | non-exclusive license: 93 | 94 | (a) under intellectual property rights (other than patent or trademark) 95 | Licensable by such Contributor to use, reproduce, make available, 96 | modify, display, perform, distribute, and otherwise exploit its 97 | Contributions, either on an unmodified basis, with Modifications, or 98 | as part of a Larger Work; and 99 | 100 | (b) under Patent Claims of such Contributor to make, use, sell, offer 101 | for sale, have made, import, and otherwise transfer either its 102 | Contributions or its Contributor Version. 103 | 104 | 2.2. Effective Date 105 | 106 | The licenses granted in Section 2.1 with respect to any Contribution 107 | become effective for each Contribution on the date the Contributor first 108 | distributes such Contribution. 109 | 110 | 2.3. Limitations on Grant Scope 111 | 112 | The licenses granted in this Section 2 are the only rights granted under 113 | this License. No additional rights or licenses will be implied from the 114 | distribution or licensing of Covered Software under this License. 115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 116 | Contributor: 117 | 118 | (a) for any code that a Contributor has removed from Covered Software; 119 | or 120 | 121 | (b) for infringements caused by: (i) Your and any other third party's 122 | modifications of Covered Software, or (ii) the combination of its 123 | Contributions with other software (except as part of its Contributor 124 | Version); or 125 | 126 | (c) under Patent Claims infringed by Covered Software in the absence of 127 | its Contributions. 128 | 129 | This License does not grant any rights in the trademarks, service marks, 130 | or logos of any Contributor (except as may be necessary to comply with 131 | the notice requirements in Section 3.4). 132 | 133 | 2.4. Subsequent Licenses 134 | 135 | No Contributor makes additional grants as a result of Your choice to 136 | distribute the Covered Software under a subsequent version of this 137 | License (see Section 10.2) or under the terms of a Secondary License (if 138 | permitted under the terms of Section 3.3). 139 | 140 | 2.5. Representation 141 | 142 | Each Contributor represents that the Contributor believes its 143 | Contributions are its original creation(s) or it has sufficient rights 144 | to grant the rights to its Contributions conveyed by this License. 145 | 146 | 2.6. Fair Use 147 | 148 | This License is not intended to limit any rights You have under 149 | applicable copyright doctrines of fair use, fair dealing, or other 150 | equivalents. 151 | 152 | 2.7. Conditions 153 | 154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted 155 | in Section 2.1. 
156 | 157 | 3. Responsibilities 158 | ------------------- 159 | 160 | 3.1. Distribution of Source Form 161 | 162 | All distribution of Covered Software in Source Code Form, including any 163 | Modifications that You create or to which You contribute, must be under 164 | the terms of this License. You must inform recipients that the Source 165 | Code Form of the Covered Software is governed by the terms of this 166 | License, and how they can obtain a copy of this License. You may not 167 | attempt to alter or restrict the recipients' rights in the Source Code 168 | Form. 169 | 170 | 3.2. Distribution of Executable Form 171 | 172 | If You distribute Covered Software in Executable Form then: 173 | 174 | (a) such Covered Software must also be made available in Source Code 175 | Form, as described in Section 3.1, and You must inform recipients of 176 | the Executable Form how they can obtain a copy of such Source Code 177 | Form by reasonable means in a timely manner, at a charge no more 178 | than the cost of distribution to the recipient; and 179 | 180 | (b) You may distribute such Executable Form under the terms of this 181 | License, or sublicense it under different terms, provided that the 182 | license for the Executable Form does not attempt to limit or alter 183 | the recipients' rights in the Source Code Form under this License. 184 | 185 | 3.3. Distribution of a Larger Work 186 | 187 | You may create and distribute a Larger Work under terms of Your choice, 188 | provided that You also comply with the requirements of this License for 189 | the Covered Software. If the Larger Work is a combination of Covered 190 | Software with a work governed by one or more Secondary Licenses, and the 191 | Covered Software is not Incompatible With Secondary Licenses, this 192 | License permits You to additionally distribute such Covered Software 193 | under the terms of such Secondary License(s), so that the recipient of 194 | the Larger Work may, at their option, further distribute the Covered 195 | Software under the terms of either this License or such Secondary 196 | License(s). 197 | 198 | 3.4. Notices 199 | 200 | You may not remove or alter the substance of any license notices 201 | (including copyright notices, patent notices, disclaimers of warranty, 202 | or limitations of liability) contained within the Source Code Form of 203 | the Covered Software, except that You may alter any license notices to 204 | the extent required to remedy known factual inaccuracies. 205 | 206 | 3.5. Application of Additional Terms 207 | 208 | You may choose to offer, and to charge a fee for, warranty, support, 209 | indemnity or liability obligations to one or more recipients of Covered 210 | Software. However, You may do so only on Your own behalf, and not on 211 | behalf of any Contributor. You must make it absolutely clear that any 212 | such warranty, support, indemnity, or liability obligation is offered by 213 | You alone, and You hereby agree to indemnify every Contributor for any 214 | liability incurred by such Contributor as a result of warranty, support, 215 | indemnity or liability terms You offer. You may include additional 216 | disclaimers of warranty and limitations of liability specific to any 217 | jurisdiction. 218 | 219 | 4. 
Inability to Comply Due to Statute or Regulation 220 | --------------------------------------------------- 221 | 222 | If it is impossible for You to comply with any of the terms of this 223 | License with respect to some or all of the Covered Software due to 224 | statute, judicial order, or regulation then You must: (a) comply with 225 | the terms of this License to the maximum extent possible; and (b) 226 | describe the limitations and the code they affect. Such description must 227 | be placed in a text file included with all distributions of the Covered 228 | Software under this License. Except to the extent prohibited by statute 229 | or regulation, such description must be sufficiently detailed for a 230 | recipient of ordinary skill to be able to understand it. 231 | 232 | 5. Termination 233 | -------------- 234 | 235 | 5.1. The rights granted under this License will terminate automatically 236 | if You fail to comply with any of its terms. However, if You become 237 | compliant, then the rights granted under this License from a particular 238 | Contributor are reinstated (a) provisionally, unless and until such 239 | Contributor explicitly and finally terminates Your grants, and (b) on an 240 | ongoing basis, if such Contributor fails to notify You of the 241 | non-compliance by some reasonable means prior to 60 days after You have 242 | come back into compliance. Moreover, Your grants from a particular 243 | Contributor are reinstated on an ongoing basis if such Contributor 244 | notifies You of the non-compliance by some reasonable means, this is the 245 | first time You have received notice of non-compliance with this License 246 | from such Contributor, and You become compliant prior to 30 days after 247 | Your receipt of the notice. 248 | 249 | 5.2. If You initiate litigation against any entity by asserting a patent 250 | infringement claim (excluding declaratory judgment actions, 251 | counter-claims, and cross-claims) alleging that a Contributor Version 252 | directly or indirectly infringes any patent, then the rights granted to 253 | You by any and all Contributors for the Covered Software under Section 254 | 2.1 of this License shall terminate. 255 | 256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all 257 | end user license agreements (excluding distributors and resellers) which 258 | have been validly granted by You or Your distributors under this License 259 | prior to termination shall survive termination. 260 | 261 | ************************************************************************ 262 | * * 263 | * 6. Disclaimer of Warranty * 264 | * ------------------------- * 265 | * * 266 | * Covered Software is provided under this License on an "as is" * 267 | * basis, without warranty of any kind, either expressed, implied, or * 268 | * statutory, including, without limitation, warranties that the * 269 | * Covered Software is free of defects, merchantable, fit for a * 270 | * particular purpose or non-infringing. The entire risk as to the * 271 | * quality and performance of the Covered Software is with You. * 272 | * Should any Covered Software prove defective in any respect, You * 273 | * (not any Contributor) assume the cost of any necessary servicing, * 274 | * repair, or correction. This disclaimer of warranty constitutes an * 275 | * essential part of this License. No use of any Covered Software is * 276 | * authorized under this License except under this disclaimer. 
* 277 | * * 278 | ************************************************************************ 279 | 280 | ************************************************************************ 281 | * * 282 | * 7. Limitation of Liability * 283 | * -------------------------- * 284 | * * 285 | * Under no circumstances and under no legal theory, whether tort * 286 | * (including negligence), contract, or otherwise, shall any * 287 | * Contributor, or anyone who distributes Covered Software as * 288 | * permitted above, be liable to You for any direct, indirect, * 289 | * special, incidental, or consequential damages of any character * 290 | * including, without limitation, damages for lost profits, loss of * 291 | * goodwill, work stoppage, computer failure or malfunction, or any * 292 | * and all other commercial damages or losses, even if such party * 293 | * shall have been informed of the possibility of such damages. This * 294 | * limitation of liability shall not apply to liability for death or * 295 | * personal injury resulting from such party's negligence to the * 296 | * extent applicable law prohibits such limitation. Some * 297 | * jurisdictions do not allow the exclusion or limitation of * 298 | * incidental or consequential damages, so this exclusion and * 299 | * limitation may not apply to You. * 300 | * * 301 | ************************************************************************ 302 | 303 | 8. Litigation 304 | ------------- 305 | 306 | Any litigation relating to this License may be brought only in the 307 | courts of a jurisdiction where the defendant maintains its principal 308 | place of business and such litigation shall be governed by laws of that 309 | jurisdiction, without reference to its conflict-of-law provisions. 310 | Nothing in this Section shall prevent a party's ability to bring 311 | cross-claims or counter-claims. 312 | 313 | 9. Miscellaneous 314 | ---------------- 315 | 316 | This License represents the complete agreement concerning the subject 317 | matter hereof. If any provision of this License is held to be 318 | unenforceable, such provision shall be reformed only to the extent 319 | necessary to make it enforceable. Any law or regulation which provides 320 | that the language of a contract shall be construed against the drafter 321 | shall not be used to construe this License against a Contributor. 322 | 323 | 10. Versions of the License 324 | --------------------------- 325 | 326 | 10.1. New Versions 327 | 328 | Mozilla Foundation is the license steward. Except as provided in Section 329 | 10.3, no one other than the license steward has the right to modify or 330 | publish new versions of this License. Each version will be given a 331 | distinguishing version number. 332 | 333 | 10.2. Effect of New Versions 334 | 335 | You may distribute the Covered Software under the terms of the version 336 | of the License under which You originally received the Covered Software, 337 | or under the terms of any subsequent version published by the license 338 | steward. 339 | 340 | 10.3. Modified Versions 341 | 342 | If you create software not governed by this License, and you want to 343 | create a new license for such software, you may create and use a 344 | modified version of this License if you rename the license and remove 345 | any references to the name of the license steward (except to note that 346 | such modified license differs from this License). 347 | 348 | 10.4. 
Distributing Source Code Form that is Incompatible With Secondary
349 | Licenses
350 |
351 | If You choose to distribute Source Code Form that is Incompatible With
352 | Secondary Licenses under the terms of this version of the License, the
353 | notice described in Exhibit B of this License must be attached.
354 |
355 | Exhibit A - Source Code Form License Notice
356 | -------------------------------------------
357 |
358 | This Source Code Form is subject to the terms of the Mozilla Public
359 | License, v. 2.0. If a copy of the MPL was not distributed with this
360 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
361 |
362 | If it is not possible or desirable to put the notice in a particular
363 | file, then You may include the notice in a location (such as a LICENSE
364 | file in a relevant directory) where a recipient would be likely to look
365 | for such a notice.
366 |
367 | You may add additional accurate notices of copyright ownership.
368 |
369 | Exhibit B - "Incompatible With Secondary Licenses" Notice
370 | ---------------------------------------------------------
371 |
372 | This Source Code Form is "Incompatible With Secondary Licenses", as
373 | defined by the Mozilla Public License, v. 2.0.
374 |
375 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # S5: Simplified State Space Layers for Sequence Modeling
2 | This is a ported version derived from the official JAX implementation of S5 and related reference implementations.
3 | It includes a number of functions ported from jax/lax/flax, since PyTorch equivalents did not exist yet.
4 |
5 | ~~Jax is required because it relies on the pytree structure, but it's not used for any computation.~~
6 | Since version 0.2.0 jax is no longer required; the port now uses the pytorch-native `torch.utils._pytree` (a private API that may be incompatible with future pytorch versions).
7 | Pytorch 2 or later is required because the code makes heavy use of `torch.vmap` and `torch.utils._pytree` to substitute their jax counterparts.
8 | Python 3.10 or later is required due to the use of the `match` statement.
9 |
10 | ---
11 |
12 | Update:
13 |
14 | In my experiments this matches the results found in the [Hyena Hierarchy](https://arxiv.org/abs/2302.10866) (& H3) paper: state spaces alone lack the recall capabilities required for LLMs, but they seem to work well for regular sequence feature extraction at linear complexity.
15 |
16 | You can use a variable step size as described in the paper by passing a 1D tensor for `step_scale`; however, this takes **a lot of memory** because many intermediate values need to be held (which I believe is also true for the official S5 repo, though it is not mentioned in the paper unless I missed it).
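For example, a minimal sketch of variable step sizes (at the batched `S5` level the per-timestep steps are passed as a `[batch, seq]` tensor; per the `S5.forward` signature a plain float or a `[batch]` tensor keeps the cheap static path):

```py3
import torch
from s5 import S5

model = S5(32, 32)
x = torch.rand(2, 256, 32)

# Constant step size (static discretization, cheap)
y = model(x)  # [2, 256, 32]

# Variable step size: one scale per sample and per timestep.
# This takes the dynamic-timestep path and is far more memory hungry.
step_scale = 0.5 + torch.rand(2, 256)
y = model(x, step_scale=step_scale)  # [2, 256, 32]
```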
17 |
18 | ## Install
19 |
20 | ```sh
21 | pip install s5-pytorch
22 | ```
23 |
24 | ## Example
25 |
26 | ```py3
27 | import torch
28 | from s5 import S5, S5Block
29 | # Raw S5 operator
30 | x = torch.rand([2, 256, 32])
31 | model = S5(32, 32)
32 | model(x) # [2, 256, 32]
33 |
34 | # S5-former block (S5+FFN-GLU w/ layernorm, dropout & residual)
35 | model = S5Block(32, 32, False)
36 | model(x) # [2, 256, 32]
37 | ```
38 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import torch.profiler as profiler
2 | import torch
3 | import torchinfo
4 | from s5 import S5, S5Block
5 |
6 | dim = 512
7 | x = torch.rand(2, 8192, dim)
8 | # model = S5(32, 32)
9 | model = S5Block(dim, 512, block_count=8, bidir=False)
10 |
11 | print(torchinfo.summary(model, (2, 8192, dim), device='cpu', depth=5))
12 |
13 | for i in range(5):
14 |     y = model(x)
15 |     print(y.shape, y)  # [2, 8192, 512]
16 |
17 | with profiler.profile(with_stack=True, profile_memory=True) as prof:
18 |     res = model(x)
19 |
20 | print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cpu_memory_usage", row_limit=10))
--------------------------------------------------------------------------------
/s5/__init__.py:
--------------------------------------------------------------------------------
1 | from .s5_model import *
2 |
--------------------------------------------------------------------------------
/s5/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .jax_compat import variance_scaling, lecun_normal, uniform
4 | import scipy.linalg
5 |
6 | # Initialization Functions
7 |
8 | def make_HiPPO(N):
9 |     """ Create a HiPPO-LegS matrix.
10 |     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
11 |     Args:
12 |         N (int32): state size
13 |     Returns:
14 |         N x N HiPPO LegS matrix
15 |     """
16 |     P = np.sqrt(1 + 2 * np.arange(N))
17 |     A = P[:, np.newaxis] * P[np.newaxis, :]
18 |     A = np.tril(A) - np.diag(np.arange(N))
19 |     return -A
20 |
21 |
22 | def make_NPLR_HiPPO(N):
23 |     """
24 |     Makes components needed for NPLR representation of HiPPO-LegS
25 |     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
26 |     Args:
27 |         N (int32): state size
28 |     Returns:
29 |         N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
30 |     """
31 |     # Make -HiPPO
32 |     hippo = make_HiPPO(N)
33 |
34 |     # Add in a rank 1 term. Makes it Normal.
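    # Note: with this choice of P, S = A + P P^T (built in make_DPLR_HiPPO below)
    # has every diagonal entry equal to -1/2, so the eigenvalues of the normal part
    # all share real part -1/2; this is why make_DPLR_HiPPO can take Lambda_real
    # from the mean of the diagonal.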
35 | P = np.sqrt(np.arange(N) + 0.5) 36 | 37 | # HiPPO also specifies the B matrix 38 | B = np.sqrt(2 * np.arange(N) + 1.0) 39 | return hippo, P, B 40 | 41 | 42 | def make_DPLR_HiPPO(N): 43 | """ 44 | Makes components needed for DPLR representation of HiPPO-LegS 45 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 46 | Note, we will only use the diagonal part 47 | Args: 48 | N: 49 | Returns: 50 | eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, 51 | eigenvectors V, HiPPO B pre-conjugation 52 | """ 53 | A, P, B = make_NPLR_HiPPO(N) 54 | 55 | S = A + P[:, np.newaxis] * P[np.newaxis, :] 56 | 57 | S_diag = np.diagonal(S) 58 | Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) 59 | 60 | # Diagonalize S to V \Lambda V^* 61 | Lambda_imag, V = np.linalg.eigh(S * -1j) 62 | 63 | P = V.conj().T @ P 64 | B_orig = B 65 | B = V.conj().T @ B 66 | return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig 67 | 68 | 69 | def make_Normal_S(N): 70 | nhippo = make_HiPPO(N) 71 | # Add in a rank 1 term. Makes it Normal. 72 | p = 0.5 * np.sqrt(2 * np.arange(1, N + 1) + 1.0) 73 | q = 2 * p 74 | S = nhippo + p[:, np.newaxis] * q[np.newaxis, :] 75 | return S 76 | 77 | 78 | def make_Normal_HiPPO(N, B=1): 79 | """Create a normal approximation to HiPPO-LegS matrix. 80 | For HiPPO matrix A, A=S+pqT is normal plus low-rank for 81 | a certain normal matrix S and low rank terms p and q. 82 | We are going to approximate the HiPPO matrix with the normal matrix S. 83 | Note we use original numpy instead of jax.numpy first to use the 84 | onp.linalg.eig function. This is because Jax's linalg.eig function does not run 85 | on GPU for non-symmetric matrices. This creates tracing issues. 86 | So we instead use onp.linalg eig and then cast to a jax array 87 | (since we only have to do this once in the beginning to initialize). 88 | Args: 89 | N (int32): state size 90 | B (int32): diagonal blocks 91 | Returns: 92 | Lambda (complex64): eigenvalues of S (N,) 93 | V (complex64): eigenvectors of S (N,N) 94 | """ 95 | 96 | assert N % B == 0, "N must divide blocks" 97 | S = (make_Normal_S(N // B),) * B 98 | S = scipy.linalg.block_diag(*S) 99 | 100 | # Diagonalize S to V \Lambda V^* 101 | Lambda, V = np.linalg.eig(S) 102 | 103 | # Convert to jax array 104 | return torch.tensor(Lambda), torch.tensor(V) 105 | 106 | 107 | def log_step_initializer(dt_min=0.001, dt_max=0.1): 108 | """ Initialize the learnable timescale Delta by sampling 109 | uniformly between dt_min and dt_max. 
110 | Args: 111 | dt_min (float32): minimum value 112 | dt_max (float32): maximum value 113 | Returns: 114 | init function 115 | """ 116 | def init(shape): 117 | """ Init function 118 | Args: 119 | key: jax random key 120 | shape tuple: desired shape 121 | Returns: 122 | sampled log_step (float32) 123 | """ 124 | return uniform(shape, minval=np.log(dt_min), maxval=np.log(dt_max)) 125 | # return torch.rand(shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min) 126 | 127 | return init 128 | 129 | 130 | def init_log_steps(H, dt_min, dt_max): 131 | """ Initialize an array of learnable timescale parameters 132 | Args: 133 | key: jax random key 134 | input: tuple containing the array shape H and 135 | dt_min and dt_max 136 | Returns: 137 | initialized array of timescales (float32): (H,) 138 | """ 139 | log_steps = [] 140 | for i in range(H): 141 | log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(shape=(1,)) 142 | log_steps.append(log_step) 143 | 144 | return torch.tensor(log_steps) 145 | 146 | 147 | def init_VinvB(init_fun, Vinv): 148 | """ Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. 149 | Note we will parameterize this with two different matrices for complex 150 | numbers. 151 | Args: 152 | init_fun: the initialization function to use, e.g. lecun_normal() 153 | shape (tuple): desired shape (P,H) 154 | Vinv: (complex64) the inverse eigenvectors used for initialization 155 | Returns: 156 | B_tilde (complex64) of shape (P,H,2) 157 | """ 158 | def init(shape, dtype): 159 | B = init_fun(shape, dtype) 160 | VinvB = Vinv @ B.type(Vinv.dtype) 161 | VinvB_real = VinvB.real 162 | VinvB_imag = VinvB.imag 163 | return torch.cat((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) 164 | return init 165 | 166 | 167 | def trunc_standard_normal(shape): 168 | """ Sample C with a truncated normal distribution with standard deviation 1. 169 | Args: 170 | key: jax random key 171 | shape (tuple): desired shape, of length 3, (H,P,_) 172 | Returns: 173 | sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) 174 | """ 175 | H, P, _ = shape 176 | Cs = [] 177 | for i in range(H): 178 | C = lecun_normal()(shape=(1, P, 2)) 179 | Cs.append(C) 180 | return torch.tensor(Cs)[:, 0] 181 | 182 | 183 | def init_CV(init_fun, shape, V) -> torch.Tensor: 184 | """ Initialize C_tilde=CV. First sample C. Then compute CV. 185 | Note we will parameterize this with two different matrices for complex 186 | numbers. 187 | Args: 188 | init_fun: the initialization function to use, e.g. lecun_normal() 189 | shape (tuple): desired shape (H,P) 190 | V: (complex64) the eigenvectors used for initialization 191 | Returns: 192 | C_tilde (complex64) of shape (H,P,2) 193 | """ 194 | C_ = init_fun(shape + (2,)) 195 | C = C_[..., 0] + 1j * C_[..., 1] 196 | CV = C @ V 197 | return CV 198 | 199 | 200 | def init_columnwise_B(shape, dtype): 201 | """Initialize B matrix in columnwise fashion. 202 | We will sample each column of B from a lecun_normal distribution. 203 | This gives a different fan-in size then if we sample the entire 204 | matrix B at once. We found this approach to be helpful for PathX 205 | It appears to be related to the point in 206 | https://arxiv.org/abs/2206.12037 regarding the initialization of 207 | the C matrix in S4, so potentially more important for the 208 | C initialization than for B. 
209 | Args: 210 | key: jax random key 211 | shape (tuple): desired shape, either of length 3, (P,H,_), or 212 | of length 2 (N,H) depending on if the function is called 213 | from the low-rank factorization initialization or a dense 214 | initialization 215 | Returns: 216 | sampled B matrix (float32), either of shape (H,P) or 217 | shape (H,P,2) (for complex parameterization) 218 | """ 219 | shape = shape[:2] + ((2,) if len(shape) == 3 else ()) 220 | lecun = variance_scaling(0.5 if len(shape) == 3 else 1.0, fan_in_axes=(0,)) 221 | return lecun(shape, dtype) 222 | 223 | 224 | def init_columnwise_VinvB(init_fun, Vinv): 225 | """Same function as above, but with transpose applied to prevent shape mismatch 226 | when using the columnwise initialization. In general this is unnecessary 227 | and will be removed in future versions, but is left for now consistency with 228 | certain random seeds until we rerun experiments.""" 229 | 230 | def init(shape, dtype): 231 | B = init_fun(shape[:2], dtype) 232 | VinvB = Vinv @ B 233 | VinvB_real = VinvB.real 234 | VinvB_imag = VinvB.imag 235 | return torch.cat((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) 236 | 237 | return init 238 | 239 | 240 | def init_rowwise_C(shape, dtype): 241 | """Initialize C matrix in rowwise fashion. Analogous to init_columnwise_B function above. 242 | We will sample each row of C from a lecun_normal distribution. 243 | This gives a different fan-in size then if we sample the entire 244 | matrix B at once. We found this approach to be helpful for PathX. 245 | It appears to be related to the point in 246 | https://arxiv.org/abs/2206.12037 regarding the initialization of 247 | the C matrix in S4. 248 | Args: 249 | shape (tuple): desired shape, of length 3, (H,P,_) 250 | Returns: 251 | sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) 252 | """ 253 | shape = shape[:2] + ((2,) if len(shape) == 3 else ()) 254 | lecun = variance_scaling(0.5, fan_in_axes=(0,)) 255 | return lecun(shape, dtype) -------------------------------------------------------------------------------- /s5/jax_bench.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | import triton 5 | import triton.language as tl 6 | from triton.runtime.jit import TensorWrapper, reinterpret 7 | from s5.jax_compat import associative_scan 8 | 9 | int_dtypes = ['int8', 'int16', 'int32', 'int64'] 10 | uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] 11 | float_dtypes = ['float16', 'float32', 'float64'] 12 | dtypes = int_dtypes + uint_dtypes + float_dtypes 13 | dtypes_with_bfloat16 = dtypes + ['bfloat16'] 14 | torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] 15 | 16 | 17 | def to_triton(x: np.ndarray, device='cuda', dst_type=None): 18 | t = x.dtype.name 19 | if t in uint_dtypes: 20 | signed_type_name = t.lstrip('u') # e.g. 
"uint16" -> "int16" 21 | x_signed = x.astype(getattr(np, signed_type_name)) 22 | return reinterpret(torch.tensor(x_signed, device=device).contiguous(), getattr(tl, t)) 23 | else: 24 | if dst_type and 'float8' in dst_type: 25 | return reinterpret(torch.tensor(x, device=device).contiguous(), getattr(tl, dst_type)) 26 | if t == 'float32' and dst_type == 'bfloat16': 27 | return torch.tensor(x, device=device).contiguous().bfloat16() 28 | return torch.tensor(x, device=device).contiguous() 29 | 30 | 31 | def to_numpy(x): 32 | if isinstance(x, TensorWrapper): 33 | # FIXME: torch_dtype_name doesn't exist 34 | return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) 35 | elif isinstance(x, torch.Tensor): 36 | if x.dtype is torch.bfloat16: 37 | return x.cpu().float().numpy() 38 | return x.cpu().numpy() 39 | else: 40 | raise ValueError(f"Not a triton-compatible tensor: {x}") 41 | 42 | 43 | if __name__ == "__main__": 44 | use_gpu = True 45 | 46 | if use_gpu: 47 | device = torch.device('cuda:0') 48 | else: 49 | device = None 50 | 51 | triton_times = [] 52 | loop_times = [] 53 | loop_comp_times = [] 54 | jax_compat_times = [] 55 | 56 | print("Initializing") 57 | op = 'cumsum' 58 | num_warps = 16 59 | 60 | dim = 1 61 | seq_len = 2048 62 | batch = 4 63 | 64 | dtype_str = 'float32' 65 | axis = 0 66 | shape = (batch, seq_len, dim) 67 | n_timings = 10000 68 | 69 | x = np.random.rand(*shape).astype(dtype=np.float32) 70 | inp = torch.tensor(x, device=device, requires_grad=True, dtype=torch.float32) 71 | init = torch.zeros(shape[1], 1, device=device, requires_grad=True) 72 | inp_scan = inp 73 | 74 | @triton.jit 75 | def sum_op(a, b): 76 | return a + b 77 | 78 | @triton.jit 79 | def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): 80 | range_m = tl.arange(0, BLOCK_M) 81 | range_n = tl.arange(0, BLOCK_N) 82 | x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) 83 | #tl.device_print("z", x) 84 | z = tl.associative_scan(x, 0, sum_op) 85 | #tl.device_print("z", z) 86 | tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) 87 | 88 | print("Triton") 89 | z = np.empty_like(x) 90 | x_tri = to_triton(x, device=device) 91 | numpy_op = np.cumsum 92 | z_dtype_str = dtype_str 93 | z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) 94 | # triton result 95 | z_tri = to_triton(z, device=device) 96 | val = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) 97 | out_triton = to_numpy(z_tri) 98 | 99 | for _ in range(n_timings): 100 | # print('.', end='', flush=True) 101 | start = time.monotonic_ns() 102 | kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) 103 | stop = time.monotonic_ns() 104 | triton_times.append((stop - start) / (10 ** 9)) 105 | 106 | print("\nFake scan") 107 | def f(carry, x): 108 | return carry+x, carry+x 109 | 110 | def _fake_scan(f, init, x): 111 | zs = [] 112 | carry = init 113 | for xp in x: 114 | carry, out = f(carry, xp) 115 | zs.append(out) 116 | return carry, torch.stack(zs) 117 | 118 | expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan) 119 | 120 | for _ in range(n_timings): 121 | # print('.', end='', flush=True) 122 | start = time.monotonic_ns() 123 | expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan) 124 | stop = time.monotonic_ns() 125 | loop_times.append((stop - start) / (10 ** 9)) 126 | 127 | # _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False) 128 | 129 | # # 
Warm-up cycles 130 | # print("\nFake scan-compiled") 131 | # for _ in range(5): 132 | # expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan) 133 | 134 | # for _ in range(n_timings): 135 | # print('.', end='', flush=True) 136 | # start = time.monotonic_ns() 137 | # expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan) 138 | # stop = time.monotonic_ns() 139 | # loop_comp_times.append((stop - start) / (10 ** 9)) 140 | 141 | def sum_op2(a, b): 142 | return a+b, a + b 143 | 144 | # Warm-up 145 | print("\njax_compat") 146 | for _ in range(5): 147 | expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1) 148 | 149 | for _ in range(n_timings): 150 | # print('.', end='', flush=True) 151 | start = time.monotonic_ns() 152 | expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1) 153 | stop = time.monotonic_ns() 154 | jax_compat_times.append((stop - start) / (10 ** 9)) 155 | 156 | print() 157 | print('Times regular loop ' + str(np.array(loop_times).mean())) 158 | # print('Times compiled loop ' + str(np.array(loop_comp_times).mean())) 159 | print('Times triton ' + str(np.array(triton_times).mean())) 160 | print('Times jax_compat ' + str(np.array(jax_compat_times).mean())) 161 | print('Script ended') 162 | -------------------------------------------------------------------------------- /s5/jax_compat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils._pytree import tree_flatten, tree_unflatten 4 | from typing import overload, Callable, Iterable, List, TypeVar, Any, Literal, Sequence, Optional 5 | from functools import partial 6 | import math 7 | 8 | """ 9 | Jax-Pytorch ported functions, mostly interfaces are kept the same but unsupported features are removed: 10 | * Jax-Keyed RNGs are sampled from global RNG 11 | * Canonical/Named shapes/dtypes/etc are now regular shapes,dtypes 12 | """ 13 | 14 | T = TypeVar("T") 15 | T1 = TypeVar("T1") 16 | T2 = TypeVar("T2") 17 | T3 = TypeVar("T3") 18 | 19 | 20 | @overload 21 | def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ... 22 | 23 | 24 | @overload 25 | def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ... 26 | 27 | 28 | @overload 29 | def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ... 30 | 31 | 32 | @overload 33 | def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ... 34 | 35 | 36 | def safe_map(f, *args): 37 | args = list(map(list, args)) 38 | n = len(args[0]) 39 | for arg in args[1:]: 40 | assert len(arg) == n, f'length mismatch: {list(map(len, args))}' 41 | return list(map(f, *args)) 42 | 43 | 44 | def combine(tree, operator, a_flat, b_flat): 45 | # Lower `fn` to operate on flattened sequences of elems. 46 | a = tree_unflatten(a_flat, tree) 47 | b = tree_unflatten(b_flat, tree) 48 | c = operator(a, b) 49 | c_flat, _ = tree_flatten(c) 50 | return c_flat 51 | 52 | 53 | def _scan(tree, operator, elems, axis: int): 54 | """Perform scan on `elems`.""" 55 | num_elems = elems[0].shape[axis] 56 | 57 | if num_elems < 2: 58 | return elems 59 | 60 | # Combine adjacent pairs of elements. 
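    # As in jax.lax.associative_scan: reduce adjacent pairs, recursively scan the
    # half-length sequence, then combine back and interleave. Each level does O(n)
    # applications of `operator`, with O(log n) recursion depth.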
61 | reduced_elems = combine(tree, operator, 62 | [torch.ops.aten.slice(elem, axis, 0, -1, 2) for elem in elems], 63 | [torch.ops.aten.slice(elem, axis, 1, None, 2) for elem in elems]) 64 | 65 | # Recursively compute scan for partially reduced tensors. 66 | odd_elems = _scan(tree, operator, reduced_elems, axis) 67 | 68 | if num_elems % 2 == 0: 69 | even_elems = combine(tree, operator, 70 | [torch.ops.aten.slice(e, axis, 0, -1) for e in odd_elems], 71 | [torch.ops.aten.slice(e, axis, 2, None, 2) for e in elems]) 72 | else: 73 | even_elems = combine(tree, operator, 74 | odd_elems, 75 | [torch.ops.aten.slice(e, axis, 2, None, 2) for e in elems]) 76 | 77 | # The first element of a scan is the same as the first element 78 | # of the original `elems`. 79 | even_elems = [ 80 | torch.cat([torch.ops.aten.slice(elem, axis, 0, 1), result], dim=axis) 81 | if result.shape.numel() > 0 and elem.shape[axis] > 0 else 82 | result if result.shape.numel() > 0 else 83 | torch.ops.aten.slice(elem, axis, 0, 1) # Jax allows/ignores concat with 0-dim, Pytorch does not 84 | for (elem, result) in zip(elems, even_elems)] 85 | 86 | return list(safe_map(partial(_interleave, axis=axis), even_elems, odd_elems)) 87 | 88 | # Pytorch impl. of jax.lax.associative_scan 89 | 90 | 91 | def associative_scan(operator: Callable, elems, axis: int = 0, reverse: bool = False): 92 | # if not callable(operator): 93 | # raise TypeError("lax.associative_scan: fn argument should be callable.") 94 | elems_flat, tree = tree_flatten(elems) 95 | 96 | if reverse: 97 | elems_flat = [torch.flip(elem, [axis]) for elem in elems_flat] 98 | 99 | assert axis >= 0 or axis < elems_flat[0].ndim, "Axis should be within bounds of input" 100 | num_elems = int(elems_flat[0].shape[axis]) 101 | if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]): 102 | raise ValueError('Array inputs to associative_scan must have the same ' 103 | 'first dimension. 
(saw: {})' 104 | .format([elem.shape for elem in elems_flat])) 105 | 106 | scans = _scan(tree, operator, elems_flat, axis) 107 | 108 | if reverse: 109 | scans = [torch.flip(scanned, [axis]) for scanned in scans] 110 | 111 | return tree_unflatten(scans, tree) 112 | 113 | 114 | def test_associative_scan(shape=(1, 24, 24)): 115 | import jax.lax 116 | import jax 117 | 118 | x = np.random.randn(*shape) 119 | jx = jax.numpy.array(x) 120 | tx = torch.tensor(x, dtype=torch.float32) 121 | 122 | def nested_func(a, b): 123 | a_i, b_i = a 124 | a_j, b_j = b 125 | return a_j*a_i, a_j*b_i + b_j 126 | jy1, jy2 = jax.lax.associative_scan(nested_func, (jx, jx)) 127 | ty1, ty2 = associative_scan(nested_func, (tx, tx)) 128 | assert np.isclose(ty1.numpy(), np.array(jy1)).all() and np.isclose(ty2.numpy(), np.array(jy2)).all(), "Expected jax & pytorch impl to be close" 129 | 130 | jy1, jy2 = jax.lax.associative_scan(nested_func, (jx, jx), reverse=True) 131 | ty1, ty2 = associative_scan(nested_func, (tx, tx), reverse=True) 132 | assert np.isclose(ty1.numpy(), np.array(jy1)).all() and np.isclose(ty2.numpy(), np.array(jy2)).all(), "Expected jax & pytorch reverse impl to be close" 133 | 134 | 135 | # def _interleave(a, b, axis): 136 | # assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 137 | # if b_trunc := (a.shape[axis] == b.shape[axis] + 1): 138 | # pad = [0, 0] * b.ndim 139 | # pad[(b.ndim-axis-1)*2+1] = 1 # +1=always end of dim, pad-order is reversed so start is at end 140 | # b = torch.nn.functional.pad(b, pad) 141 | 142 | # keys = list('ijklmnop')[:a.ndim] # Get enough keys for each dim 143 | # expr = 't ' + ' '.join(keys) + ' -> ' 144 | 145 | # keys[axis] = f'({keys[axis]} t)' # Interleave along desired axis 146 | # expr += ' '.join(keys) 147 | # # for example 't i j -> (i t) j' 148 | # out: torch.Tensor = rearrange([a, b], expr) 149 | # if b_trunc: 150 | # out = out[slice_along_axis(0, b.shape[axis]+a.shape[axis]-1, axis=axis)] 151 | # return out 152 | 153 | # @torch.jit.script 154 | def _interleave(a, b, axis: int): 155 | # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors 156 | b_trunc = (a.shape[axis] == b.shape[axis] + 1) 157 | if b_trunc: 158 | pad = [0, 0] * b.ndim 159 | pad[(b.ndim-axis-1)*2+1] = 1 # +1=always end of dim, pad-order is reversed so start is at end 160 | b = torch.nn.functional.pad(b, pad) 161 | 162 | stacked = torch.stack([a, b], dim=axis+1) 163 | interleaved = torch.flatten(stacked, start_dim=axis, end_dim=axis+1) 164 | if b_trunc: 165 | # TODO: find torch alternative for slice_along axis for torch.jit.script to work 166 | interleaved = torch.ops.aten.slice(interleaved, axis, 0, b.shape[axis]+a.shape[axis]-1) 167 | return interleaved 168 | 169 | 170 | def test_interleave(): 171 | x, y = torch.randn(1, 32, 32), torch.randn(1, 32, 32) 172 | v = _interleave(x, y, axis=1) 173 | assert v.shape == (1, 64, 32) 174 | assert (v[:, 0] == x[:, 0]).all() 175 | assert (v[:, 1] == y[:, 0]).all() 176 | assert (v[:, 2] == x[:, 1]).all() 177 | assert (v[:, 3] == y[:, 1]).all() 178 | assert (v[:, 4] == x[:, 2]).all() 179 | 180 | v = _interleave(x, y, axis=2) 181 | assert v.shape == (1, 32, 64) 182 | assert (v[..., 0] == x[..., 0]).all() 183 | assert (v[..., 1] == y[..., 0]).all() 184 | assert (v[..., 2] == x[..., 1]).all() 185 | assert (v[..., 3] == y[..., 1]).all() 186 | assert (v[..., 4] == x[..., 2]).all() 187 | 188 | x, y = torch.randn(1, 24, 24), torch.randn(1, 24, 24) 189 | assert _interleave(x, y, axis=1).shape == (1, 48, 24) 190 | 
assert _interleave(x, y, axis=2).shape == (1, 24, 48) 191 | 192 | x, y = torch.randn(3, 96), torch.randn(2, 96) 193 | v = _interleave(x, y, axis=0) 194 | assert v.shape == (5, 96) 195 | assert (v[0] == x[0]).all() 196 | assert (v[1] == y[0]).all() 197 | assert (v[2] == x[1]).all() 198 | assert (v[3] == y[1]).all() 199 | assert (v[4] == x[2]).all() 200 | print('Interleave working as expected!') 201 | 202 | 203 | def _compute_fans(shape, fan_in_axes=None): 204 | """Computes the number of input and output units for a weight shape.""" 205 | if len(shape) < 1: 206 | fan_in = fan_out = 1 207 | elif len(shape) == 1: 208 | fan_in = fan_out = shape[0] 209 | elif len(shape) == 2: 210 | fan_in, fan_out = shape 211 | else: 212 | if fan_in_axes is not None: 213 | # Compute fan-in using user-specified fan-in axes. 214 | fan_in = np.prod([shape[i] for i in fan_in_axes]) 215 | fan_out = np.prod([s for i, s in enumerate(shape) 216 | if i not in fan_in_axes]) 217 | else: 218 | # If no axes specified, assume convolution kernels (2D, 3D, or more.) 219 | # kernel_shape: (..., input_depth, depth) 220 | receptive_field_size = np.prod(shape[:-2]) 221 | fan_in = shape[-2] * receptive_field_size 222 | fan_out = shape[-1] * receptive_field_size 223 | return fan_in, fan_out 224 | 225 | 226 | def uniform(shape, dtype=torch.float, minval=0., maxval=1.0, device=None): 227 | src = torch.rand(shape, dtype=dtype, device=device) 228 | if minval == 0 and maxval == 1.: 229 | return src 230 | else: 231 | return (src * (maxval - minval)) + minval 232 | 233 | 234 | def _complex_uniform(shape: Sequence[int], 235 | dtype, device=None) -> torch.Tensor: 236 | """ 237 | Sample uniform random values within a disk on the complex plane, 238 | with zero mean and unit variance. 239 | """ 240 | r = torch.sqrt(2 * torch.rand(shape, dtype=dtype, device=device)) 241 | theta = 2 * torch.pi * torch.rand(shape, dtype=dtype, device=device) 242 | return r * torch.exp(1j * theta) 243 | 244 | 245 | def complex_as_float_dtype(dtype): 246 | match dtype: 247 | case torch.complex32: 248 | return torch.float32 # NOTE: complexe32 is not wel supported yet 249 | case torch.complex64: 250 | return torch.float32 251 | case torch.complex128: 252 | return torch.float64 253 | case _: 254 | return dtype 255 | 256 | 257 | def _complex_truncated_normal(upper: float, 258 | shape: Sequence[int], 259 | dtype, device=None) -> torch.Tensor: 260 | """ 261 | Sample random values from a centered normal distribution on the complex plane, 262 | whose modulus is truncated to `upper`, and the variance before the truncation 263 | is one. 
264 | """ 265 | real_dtype = torch.tensor(0, dtype=dtype).real.dtype 266 | t = ((1 - torch.exp(torch.tensor(-(upper ** 2), dtype=dtype, device=device))) 267 | * torch.rand(shape, dtype=real_dtype, device=device).type(dtype)) 268 | r = torch.sqrt(-torch.log(1 - t)) 269 | theta = 2 * torch.pi * torch.rand(shape, dtype=real_dtype, device=device).type(dtype) 270 | return r * torch.exp(1j * theta) 271 | 272 | 273 | def _truncated_normal(lower, upper, shape, dtype=torch.float): 274 | if shape is None: 275 | shape = torch.broadcast_shapes(np.shape(lower), np.shape(upper)) 276 | 277 | sqrt2 = math.sqrt(2) 278 | a = math.erf(lower / sqrt2) 279 | b = math.erf(upper / sqrt2) 280 | 281 | # a>> import jax, jax.numpy as jnp 356 | >>> initializer = jax.nn.initializers.lecun_normal() 357 | >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP 358 | Array([[ 0.46700746, 0.8414632 , 0.8518669 ], 359 | [-0.61677957, -0.67402434, 0.09683388]], dtype=float32) 360 | 361 | .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515 362 | """ 363 | return variance_scaling(1.0, "fan_in", "truncated_normal", fan_in_axes=fan_in_axes, dtype=dtype) 364 | 365 | 366 | def test_variance_scaling(): 367 | v = variance_scaling(1.0, distribution='normal') 368 | n_f32 = v((1, 10000), dtype=torch.float) 369 | assert np.isclose(n_f32.std().item(), 1.0, rtol=0.015, atol=0.015), f'std for f32 normal[0,1.0] is {n_f32.std()} != 1.0' 370 | del n_f32 371 | # NOTE: this is used in the original as `complex_normal` (but with stddev=0.5**0.5) 372 | n_c64 = v((1, 10000), dtype=torch.complex64) 373 | assert np.isclose(n_c64.std().item(), 1.0, rtol=0.015, atol=0.015), f'std for c64 normal[0,1.0] is {n_c64.std()} != 1.0' 374 | del n_c64 375 | 376 | # Truncated normal 377 | v = variance_scaling(1.0, distribution='truncated_normal') 378 | tn_f32 = v((1, 10000), dtype=torch.float) 379 | assert np.isclose(tn_f32.std().item(), 0.775, rtol=0.015, atol=0.015), f'std for f32 truncated normal[0,1.0] is {tn_f32.std()} != 0.775' 380 | del tn_f32 381 | 382 | # NOTE: this is used in the original (both trunc_standard_normal & lecun_normal it seems), 383 | # seems that they are using the fan-in/out feature to 'hide the low variance initialization' 384 | # The actual std observed is np.sqrt(2/shape[1]/(2*shape[0])); shape[2] has no impact 385 | v = variance_scaling(1.0, distribution='truncated_normal') 386 | tn_f32 = v((1, 10000, 2), dtype=torch.float) 387 | tn_c32 = torch.complex(tn_f32[..., 0], tn_f32[..., 1]) 388 | expected_std = np.sqrt(2/tn_f32.shape[1]/(2*tn_f32.shape[0])) 389 | print(tn_c32.shape) 390 | assert np.isclose(tn_c32.std().item(), expected_std, rtol=0.015, atol=0.015), f'std for f32 truncated normal[0,1.0] is {tn_c32.std()} != {expected_std}' 391 | del tn_f32 392 | del tn_c32 393 | 394 | 395 | # tn_c64 = v((1, 10000), dtype=torch.complex64) 396 | # assert np.isclose(tn_c64.std().item(), 0.775, rtol=0.015, atol=0.015), f'std for c64 truncated normal[0,1.0] is {tn_c64.std()} != 0.775' 397 | # del tn_c64 398 | 399 | 400 | if __name__ == '__main__': 401 | test_variance_scaling() 402 | test_interleave() 403 | test_associative_scan() 404 | test_associative_scan(shape=(2, 256, 24)) 405 | test_associative_scan(shape=(360, 96)) 406 | -------------------------------------------------------------------------------- /s5/s5_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple, Optional, Literal 4 | from 
.jax_compat import associative_scan 5 | from .init import * 6 | 7 | # Runtime functions 8 | 9 | 10 | @torch.jit.script 11 | def binary_operator(q_i: Tuple[torch.Tensor, torch.Tensor], q_j: Tuple[torch.Tensor, torch.Tensor]): 12 | """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. 13 | Args: 14 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,) 15 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,) 16 | Returns: 17 | new element ( A_out, Bu_out ) 18 | """ 19 | A_i, Bu_i = q_i 20 | A_j, Bu_j = q_j 21 | # return A_j * A_i, A_j * Bu_i + Bu_j 22 | return A_j * A_i, torch.addcmul(Bu_j, A_j, Bu_i) 23 | 24 | 25 | def apply_ssm(Lambda_bars: torch.Tensor, B_bars, C_tilde, D, input_sequence, state=None, bidir: bool = False): 26 | cinput_sequence = input_sequence.type(Lambda_bars.dtype) # Cast to correct complex type 27 | 28 | if B_bars.ndim == 3: 29 | # Dynamic timesteps (significantly more expensive) 30 | Bu_elements = torch.vmap(lambda B_bar, u: B_bar @ u)(B_bars, cinput_sequence) 31 | else: 32 | # Static timesteps 33 | Bu_elements = torch.vmap(lambda u: B_bars @ u)(cinput_sequence) 34 | 35 | if Lambda_bars.ndim == 1: # Zero-pad for associative_scan 36 | Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1) 37 | 38 | if state is not None: 39 | # Bu_elements = torch.cat(((state).unsqueeze(0), Bu_elements), dim=0) 40 | # Lambda_bars = torch.cat((torch.ones_like(state.unsqueeze(0)), Lambda_bars), dim=0) 41 | # Manually compute first step (Lambda_bar=1 so no change) 42 | Bu_elements[0] = Bu_elements[0] + state * Lambda_bars[0] 43 | 44 | _, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements)) 45 | 46 | if bidir: 47 | _, xs2 = associative_scan(binary_operator, (Lambda_bars, Bu_elements), reverse=True) 48 | xs = torch.cat((xs, xs2), axis=-1) 49 | 50 | Du = torch.vmap(lambda u: D * u)(input_sequence) 51 | return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1] #torch.stack((_[-1], xs[-1])) 52 | 53 | 54 | def apply_ssm_liquid(Lambda_bars, B_bars, C_tilde, D, input_sequence, state=None, bidir: bool = False): 55 | """Liquid time constant SSM \u00e1 la dynamical systems given in Eq. 8 of 56 | https://arxiv.org/abs/2209.12951""" 57 | cinput_sequence = input_sequence.type(Lambda_bars.dtype) # Cast to correct complex type 58 | 59 | if B_bars.ndim == 3: 60 | # Dynamic timesteps (significantly more expensive) 61 | Bu_elements = torch.vmap(lambda B_bar, u: B_bar @ u)(B_bars, cinput_sequence) 62 | else: 63 | # Static timesteps 64 | Bu_elements = torch.vmap(lambda u: B_bars @ u)(cinput_sequence) 65 | 66 | if Lambda_bars.ndim == 1: # Zero-pad for associative_scan 67 | Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1) 68 | 69 | if state is not None: 70 | # Manually compute first step (Lambda_bar=1 so no change) 71 | Bu_elements[0] = Bu_elements[0] + state * Lambda_bars[0] 72 | 73 | _, xs = associative_scan(binary_operator, (Lambda_bars + Bu_elements, Bu_elements)) 74 | 75 | if bidir: 76 | _, xs2 = associative_scan(binary_operator, (Lambda_bars, Bu_elements), reverse=True) 77 | xs = torch.cat((xs, xs2), axis=-1) 78 | 79 | Du = torch.vmap(lambda u: D * u)(input_sequence) 80 | return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1] 81 | 82 | 83 | # Discretization functions 84 | def discretize_bilinear(Lambda, B_tilde, Delta): 85 | """Discretize a diagonalized, continuous-time linear SSM 86 | using bilinear transform method. 
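    Concretely (matching the code below, elementwise over the diagonal):
        Lambda_bar = (I + Delta/2 * Lambda) / (I - Delta/2 * Lambda)
        B_bar = (Delta / (I - Delta/2 * Lambda))[..., None] * B_tilde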
87 | Args: 88 | Lambda (complex64): diagonal state matrix (P,) 89 | B_tilde (complex64): input matrix (P, H) 90 | Delta (float32): discretization step sizes (P,) 91 | Returns: 92 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) 93 | """ 94 | Identity = torch.ones(Lambda.shape[0], device=Lambda.device) 95 | BL = 1 / (Identity - (Delta / 2.0) * Lambda) 96 | Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) 97 | B_bar = (BL * Delta)[..., None] * B_tilde 98 | return Lambda_bar, B_bar 99 | 100 | 101 | def discretize_zoh(Lambda, B_tilde, Delta): 102 | """Discretize a diagonalized, continuous-time linear SSM 103 | using zero-order hold method. 104 | Args: 105 | Lambda (complex64): diagonal state matrix (P,) 106 | B_tilde (complex64): input matrix (P, H) 107 | Delta (float32): discretization step sizes (P,) 108 | Returns: 109 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) 110 | """ 111 | Identity = torch.ones(Lambda.shape[0], device=Lambda.device) # (replaced by -1) 112 | Lambda_bar = torch.exp(Lambda * Delta) 113 | B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde 114 | return Lambda_bar, B_bar 115 | 116 | 117 | def as_complex(t: torch.Tensor, dtype=torch.complex64): 118 | assert t.shape[-1] == 2, "as_complex can only be done on tensors with shape=(...,2)" 119 | nt = torch.complex(t[..., 0], t[..., 1]) 120 | if nt.dtype != dtype: 121 | nt = nt.type(dtype) 122 | return nt 123 | 124 | 125 | Initialization = Literal['dense_columns', 'dense', 'factorized'] 126 | 127 | 128 | class S5SSM(torch.nn.Module): 129 | def __init__(self, lambdaInit: torch.Tensor, 130 | V: torch.Tensor, Vinv: torch.Tensor, h: int, p: int, 131 | dt_min: float, 132 | dt_max: float, 133 | liquid: bool = False, 134 | factor_rank: Optional[int] = None, 135 | discretization: Literal['zoh', 'bilinear'] = 'zoh', 136 | bcInit: Initialization = 'factorized', 137 | degree: int = 1, 138 | bidir: bool = False): 139 | """The S5 SSM 140 | Args: 141 | lambdaInit (complex64): Initial diagonal state matrix (P,) 142 | V (complex64): Eigenvectors used for init (P,P) 143 | Vinv (complex64): Inverse eigenvectors used for init (P,P) 144 | h (int32): Number of features of input seq 145 | p (int32): state size 146 | k (int32): rank of low-rank factorization (if used) 147 | bcInit (string): Specifies How B and C are initialized 148 | Options: [factorized: low-rank factorization, 149 | dense: dense matrix drawn from Lecun_normal] 150 | dense_columns: dense matrix where the columns 151 | of B and the rows of C are each drawn from Lecun_normal 152 | separately (i.e. different fan-in then the dense option). 153 | We found this initialization to be helpful for Pathx. 154 | discretization: (string) Specifies discretization method 155 | options: [zoh: zero-order hold method, 156 | bilinear: bilinear transform] 157 | liquid: (bool): use liquid_ssm from LiquidS4 158 | dt_min: (float32): minimum value to draw timescale values from when 159 | initializing log_step 160 | dt_max: (float32): maximum value to draw timescale values from when 161 | initializing log_step 162 | step_scale: (float32): allows for changing the step size, e.g. 
after training
163 |             on a different resolution for the speech commands benchmark
164 |         """
165 |         super().__init__()
166 |         self.Lambda = torch.nn.Parameter(lambdaInit)
167 |         self.degree = degree
168 |         self.liquid = liquid
169 |         self.bcInit = bcInit
170 |         self.bidir = bidir
171 |         # TODO:
172 |         # if self.clip_eigs:
173 |         #     self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
174 | 
175 |         # the P-dim of C needs to be 2P for bidirectional models
176 |         cp = p
177 |         if self.bidir:
178 |             cp *= 2
179 | 
180 |         match bcInit:
181 |             case 'complex_normal':
182 |                 self.C = torch.nn.Parameter(torch.normal(0, 0.5 ** 0.5, (h, cp), dtype=torch.complex64))
183 |                 self.B = torch.nn.Parameter(init_VinvB(lecun_normal(), Vinv)((p, h), torch.float))
184 |             case 'dense_columns' | 'dense':
185 |                 if bcInit == "dense_columns":
186 |                     B_eigen_init = init_columnwise_VinvB
187 |                     B_init = init_columnwise_B
188 |                     C_init = init_rowwise_C
189 |                 elif bcInit == "dense":
190 |                     B_eigen_init = init_VinvB
191 |                     B_init = C_init = lecun_normal()
192 |                 # TODO: give all the init_*VinvB initializers the same interface
193 |                 self.B = torch.nn.Parameter(B_eigen_init(B_init, Vinv)((p, h), torch.float))
194 |                 if self.bidir:
195 |                     C = torch.cat([init_CV(C_init, (h, p), V), init_CV(C_init, (h, p), V)], axis=-1)
196 |                 else:
197 |                     C = init_CV(C_init, (h, p), V)
198 |                 self.C = torch.nn.Parameter(C)
199 |             case 'factorized':
200 |                 print('[WARN]: factorized init was removed from the original repo, possibly for a reason')
201 |                 # Use a low rank factorization of rank k for B and C
202 |                 self.BH = torch.nn.Parameter(as_complex(init_columnwise_B((h, k, 2), torch.float32)))
203 |                 self.BP = torch.nn.Parameter(as_complex(init_columnwise_B((p, k, 2), torch.float32)))
204 |                 self.CH = torch.nn.Parameter(as_complex(init_rowwise_C((k, h, 2), torch.float32)))
205 |                 self.CP = torch.nn.Parameter(as_complex(init_rowwise_C((k, cp, 2), torch.float32)))
206 |                 # self.BH = torch.nn.Parameter(init_columnwise_B((h, k), torch.complex64))
207 |                 # self.BP = torch.nn.Parameter(init_columnwise_B((p, k), torch.complex64))
208 |                 # self.CH = torch.nn.Parameter(init_rowwise_C((k, h), torch.complex64))
209 |                 # self.CP = torch.nn.Parameter(init_rowwise_C((k, p), torch.complex64))
210 |             case _:
211 |                 raise NotImplementedError(f"BC_init method {bcInit} not implemented")
212 | 
213 |         # Initialize feedthrough (D) matrix
214 |         self.D = torch.nn.Parameter(torch.rand(h,))
215 |         self.log_step = torch.nn.Parameter(init_log_steps(p, dt_min, dt_max))
216 |         match discretization:
217 |             case 'zoh':
218 |                 self.discretize = discretize_zoh
219 |             case 'bilinear':
220 |                 self.discretize = discretize_bilinear
221 |             case _:
222 |                 raise ValueError(f'Unknown discretization {discretization}')
223 | 
224 |     def initial_state(self, batch_size: Optional[int]):
225 |         batch_shape = (batch_size,) if batch_size is not None else ()
226 |         return torch.zeros((*batch_shape, self.C_tilde.shape[-2]))
227 | 
228 |     def get_BC_tilde(self):
229 |         match self.bcInit:
230 |             case 'dense_columns' | 'dense' | 'complex_normal':
231 |                 B_tilde = as_complex(self.B)
232 |                 C_tilde = self.C
233 |             case 'factorized':
234 |                 B_tilde = self.BP @ self.BH.T
235 |                 C_tilde = self.CH.T @ self.CP
236 |         return B_tilde, C_tilde
237 | 
238 |     def forward_rnn(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):
239 |         assert not self.bidir, "Can't use bidirectional when manually stepping"
240 |         B_tilde, C_tilde = self.get_BC_tilde()
241 |         step = step_scale * torch.exp(self.log_step)
242 |         Lambda_bar, B_bar = self.discretize(self.Lambda, B_tilde, step)
243 |         if self.degree != 1:
244 |             assert (B_bar.shape[-2] == B_bar.shape[-1]), "higher-order input operators must be full-rank"
245 |             B_bar **= self.degree
246 | 
247 |         # https://arxiv.org/abs/2209.12951v1, Eq. 9
248 |         Bu = B_bar @ signal.type(B_bar.dtype)
249 |         if self.liquid:
250 |             Lambda_bar += Bu
251 |         # https://arxiv.org/abs/2208.04933v2, Eq. 2
252 |         x = Lambda_bar * prev_state + Bu
253 |         y = (C_tilde @ x + self.D * signal).real
254 |         return y, x
255 | 
256 |     def forward(self, signal, step_scale: float | torch.Tensor = 1.0, state=None, return_state=False):
257 |         B_tilde, C_tilde = self.get_BC_tilde()
258 | 
259 |         if not torch.is_tensor(step_scale) or step_scale.ndim == 0:
260 |             step = step_scale * torch.exp(self.log_step)
261 |         else:
262 |             # TODO: This is very expensive due to individual steps being multiplied by B_tilde in self.discretize
263 |             step = step_scale[:, None] * torch.exp(self.log_step)
264 | 
265 |         # print(f'{self.Lambda.shape=} {B_tilde.shape=} {step.shape=}')
266 |         # Lambda_bars, B_bars = torch.vmap(lambda s: self.discretize(self.Lambda, B_tilde, s))(step)
267 |         # print(Lambda_bars.shape, B_bars.shape)
268 |         Lambda_bars, B_bars = self.discretize(self.Lambda, B_tilde, step)
269 |         if self.degree != 1:
270 |             assert (B_bars.shape[-2] == B_bars.shape[-1]), "higher-order input operators must be full-rank"
271 |             B_bars **= self.degree
272 | 
273 |         assert not (self.bidir and (state is not None)), "injecting state is not compatible with bidirectional S5"
274 | 
275 |         forward = apply_ssm_liquid if self.liquid else apply_ssm
276 |         out, state = forward(Lambda_bars, B_bars, C_tilde, self.D, signal, state=state, bidir=self.bidir)
277 |         # NOTE: technically it could work in a limited sense; taking the first and last element
278 |         # but that wouldn't be equivalent to running bidir on full sequences.
279 |         # It would be more like a circular S5 where you keep splicing the new signal into it;
280 |         # we leave implementing/testing this as an exercise to the reader
281 |         assert not (self.bidir and return_state), "return_state does not work with bidirectional S5"
282 |         if return_state:
283 |             return out, state
284 |         return out
285 | 
286 | 
287 | class S5(torch.nn.Module):
288 |     def __init__(self,
289 |                  width: int,
290 |                  state_width: Optional[int] = None,
291 |                  factor_rank: Optional[int] = None,
292 |                  block_count: int = 1,
293 |                  dt_min: float = 0.001,
294 |                  dt_max: float = 0.1,
295 |                  liquid: bool = False,
296 |                  degree: int = 1,
297 |                  bidir: bool = False,
298 |                  bcInit: Optional[Initialization] = None):
299 |         super().__init__()
300 |         state_width = state_width or width
301 |         assert state_width % block_count == 0, "block_count should be a factor of state_width"
302 | 
303 |         block_size = state_width // block_count
304 |         Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)
305 |         Vinv = V.conj().T
306 |         Lambda, B, V, B_orig, Vinv = map(lambda v: torch.tensor(v, dtype=torch.complex64), (Lambda, B, V, B_orig, Vinv))
307 |         if block_count > 1:
308 |             Lambda = Lambda[:block_size]
309 |             V = V[:, :block_size]
310 |             Lambda = (Lambda * torch.ones((block_count, block_size))).ravel()
311 |             V = torch.block_diag(*([V] * block_count))
312 |             Vinv = torch.block_diag(*([Vinv] * block_count))
313 | 
314 |         assert bool(factor_rank) != bool(bcInit != 'factorized'), "Can't have `bcInit != factorized` and `factor_rank` defined"
315 |         bc_init = "factorized" if factor_rank is not None else (bcInit or "dense")
316 |         self.width = width
317 |         self.seq = S5SSM(
318 |             Lambda,
319 |             V,
320 |             Vinv,
321 |             width,
322 |             state_width,
323 |             dt_min,
324 |             dt_max,
325 |             factor_rank=factor_rank,
326 |             bcInit=bc_init,
327 |             liquid=liquid,
328 |             degree=degree,
329 |             bidir=bidir
330 |         )
331 | 
332 |     def initial_state(self, batch_size: Optional[int] = None):
333 |         return self.seq.initial_state(batch_size)
334 | 
335 |     def forward(self, signal, step_scale: float | torch.Tensor = 1.0, state=None, return_state=False):
336 |         # NOTE: step_scale can be float | Tensor[batch] | Tensor[batch, seq]
337 |         if not torch.is_tensor(step_scale):
338 |             # Duplicate across the batch dimension
339 |             step_scale = torch.ones(signal.shape[0], device=signal.device) * step_scale
340 | 
341 |         if state is None:
342 |             return torch.vmap(lambda s, ss: self.seq(s, step_scale=ss, return_state=return_state))(signal, step_scale)
343 |         else:
344 |             return torch.vmap(lambda s, ss, _state: self.seq(s, step_scale=ss, state=_state, return_state=return_state))(signal, step_scale, state)
345 | 
346 | 
347 | class GEGLU(torch.nn.Module):
348 |     def forward(self, x):
349 |         x, gates = x.chunk(2, dim=-1)
350 |         return x * F.gelu(gates)
351 | 
352 | 
353 | class S5Block(torch.nn.Module):
354 |     def __init__(self, dim: int, state_dim: int, bidir: bool, block_count: int = 1, liquid: bool = False, degree: int = 1, factor_rank: int | None = None, bcInit: Optional[Initialization] = None, ff_mult: float = 1., glu: bool = True,
355 |                  ff_dropout: float = 0.0, attn_dropout: float = 0.0):
356 |         super().__init__()
357 |         self.s5 = S5(dim, state_width=state_dim, bidir=bidir, block_count=block_count, liquid=liquid, degree=degree, factor_rank=factor_rank, bcInit=bcInit)
358 |         self.attn_norm = torch.nn.LayerNorm(dim)
359 |         self.attn_dropout = torch.nn.Dropout(p=attn_dropout)
360 |         self.geglu = GEGLU() if glu else None
361 |         self.ff_enc = torch.nn.Linear(dim, int(dim * ff_mult) * (1 + glu), bias=False)
362 |         self.ff_dec = torch.nn.Linear(int(dim * ff_mult), dim, bias=False)
363 |         self.ff_norm = torch.nn.LayerNorm(dim)
364 |         self.ff_dropout = torch.nn.Dropout(p=ff_dropout)
365 | 
366 |     def forward(self, x, state=None, return_state=False):
367 |         # Standard transformer-style block with GEGLU/Pre-LayerNorm
368 |         fx = self.attn_norm(x)
369 |         res = fx.clone()
370 |         x = self.s5(fx, state=state, return_state=return_state)
371 |         if return_state:
372 |             x, next_state = x
373 | 
374 |         x = F.gelu(x) + res
375 |         x = self.attn_dropout(x)
376 | 
377 |         fx = self.ff_norm(x)
378 |         res = fx.clone()
379 |         x = self.ff_enc(fx)
380 |         if self.geglu is not None:
381 |             x = self.geglu(x)
382 |         x = self.ff_dec(x) + res
383 |         x = self.ff_dropout(x)  # TODO: test whether dropout should sit between the FF layers or after them
384 | 
385 |         if return_state:
386 |             return x, next_state
387 |         return x
388 | 
389 | 
390 | if __name__ == '__main__':
391 |     # import lovely_tensors as lt
392 |     # lt.monkey_patch()
393 |     from tqdm import tqdm
394 | 
395 |     def tensor_stats(t: torch.Tensor):  # Clone of lovely_tensors for complex support
396 |         return f'tensor[{t.shape}] n={t.shape.numel()}, u={t.mean()}, s={round(t.std().item(), 3)} var={round(t.var().item(), 3)}\n'
397 | 
398 |     x = torch.rand([2, 768, 32])
399 |     model = S5(32, 128)
400 |     print('B', tensor_stats(model.seq.B.data))
401 |     print('C', tensor_stats(model.seq.C.data))
402 |     # print('B', tensor_stats(model.seq.BH.data), tensor_stats(model.seq.BP.data))
403 |     # print('C', tensor_stats(model.seq.CH.data), tensor_stats(model.seq.CP.data))
404 |     # state = model.initial_state(256)
405 |     # res = model(x, prev_state=state)
406 |     # print(res.shape, res.dtype, res)
407 |     with torch.no_grad():
408 |         res, state = model(x, return_state=True)
409 |         print(state.shape, state.dtype, tensor_stats(state), f'{state[..., :10]=}')
410 |         print(res.shape, res.dtype, res[:, -1])
411 | 
412 |         print("Now with 100% more state:")
413 |         res, state = model(x[:, :256], return_state=True)
414 |         # print(state.shape, state.dtype, tensor_stats(state))
415 |         # print(res.shape, res.dtype, res)
416 |         res, state = model(x[:, 256:512], state=state, return_state=True)
417 |         # print(state.shape, state.dtype, tensor_stats(state))
418 |         # print(res.shape, res.dtype, res)
419 |         res, state = model(x[:, 512:768], state=state, return_state=True)
420 |         print(state.shape, state.dtype, tensor_stats(state), f'{state[..., :10]=}')
421 |         print(res.shape, res.dtype, res[:, -1])
422 | 
423 |         print("Corrupted state (negative test):")
424 |         res, state = model(x[:, 512:768], state=torch.randn_like(state)/2, return_state=True)
425 |         print(state.shape, state.dtype, tensor_stats(state), f'{state[..., :10]=}')
426 |         print(res.shape, res.dtype, res[:, -1])
427 | 
428 |         print("SSM specifics:")
429 |         ssm = model.seq
430 |         print("block:")
431 |         res, state = ssm.forward(x[0, :512], return_state=True)
432 |         print(res[-1], state[..., :10], state.shape)
433 | 
434 |         print("block-recurrent:")
435 |         res, state = ssm.forward(x[0, :256], return_state=True)
436 |         # print(res[-1], state)
437 |         res, state = ssm.forward(x[0, 256:512], state=state, return_state=True)
438 |         print(res[-1], state[..., :10], state.shape)
439 | 
440 |         print("Now as rnn:")
441 |         state = torch.zeros_like(state[0])
442 |         for i in tqdm(range(512)):
443 |             res, state = ssm.forward_rnn(x[0,i], state)
444 |         print(res, state[..., :10], state.shape)
445 | 
446 | 
447 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name='s5-pytorch',
5 |     packages=find_packages(exclude=[]),
6 |     version='0.2.1',
7 |     license='MPL-2.0',
8 |     description='S5 - Simplified State Space Layers for Sequence Modeling - Pytorch',
9 |     author='Ferris Kwaijtaal',
10 |     author_email='ferris+gh@devdroplets.com',
11 |     long_description_content_type='text/markdown',
12 |     long_description=open('README.md', 'r').read(),
13 |     url='https://github.com/i404788/s5-pytorch',
14 |     keywords=[
15 |         'artificial intelligence',
16 |         'deep learning',
17 |         'transformers',
18 |         'attention mechanism',
19 |         'audio generation'
20 |     ],
21 |     install_requires=[
22 |         'einops>=0.6',
23 |         'scipy',
24 |         'torch>=2',
25 |     ],
26 |     extras_require={
27 |         "dev": ["jax"],
28 |     },
29 |     classifiers=[
30 |         'Development Status :: 4 - Beta',
31 |         'Intended Audience :: Developers',
32 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
33 |         'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
34 |         'Programming Language :: Python :: 3.10',
35 |     ],
36 | )
37 | 
--------------------------------------------------------------------------------
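
A minimal usage sketch of the modules defined in s5_model.py, distilled from the `__main__` test block above. It assumes the classes can be imported from `s5.s5_model`; depending on what `s5/__init__.py` re-exports, `from s5 import S5, S5Block` may work as well.

import torch
from s5.s5_model import S5, S5Block

# S5 maps (batch, seq, width) -> (batch, seq, width); state_width sets the size of the SSM state.
model = S5(32, state_width=128)
x = torch.rand(2, 768, 32)

with torch.no_grad():
    y = model(x)                                   # run over the full sequence at once
    y1, state = model(x[:, :256], return_state=True)
    y2 = model(x[:, 256:], state=state)            # continue from the returned state

# S5Block wraps the SSM with pre-LayerNorm, a GEGLU feed-forward and dropout,
# so it can be stacked like a transformer block.
block = S5Block(dim=32, state_dim=128, bidir=False)
out = block(x)

Note that bidirectional models (bidir=True) process whole sequences only: the asserts in S5SSM.forward reject both state injection and return_state in that mode.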