├── LICENSE
└── README.md

/LICENSE:

MIT License

Copyright (c) 2024 Axel Chemla--Romeu-Santos
with free permission of use (without commercial application) to ACIDS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

/README.md:

# torchbend

`torchbend` is a library built on `torch.fx`, focused on the analysis and creative bending of generative neural networks. The library allows you to:

- [✔︎] extend the tracing abilities of `torch.fx` with augmented parsers and proxies
  - dynamic parsing (wrapping un-traceable functions, shape propagation)
  - tracing torch distributions (currently implemented: `Bernoulli`, `Normal`, `Categorical`)
- [✔︎] easily parse and analyze a model's graphs
- [✕︎] bend a model's weights and activations
- [✕︎] adapt the library to specific generative models, and provide handy interfaces for Python notebooks
  - [✕︎] handy classes for image, text, and sound
  - [✕︎] panel implementation for real-time bending
  - [✕︎] model analysis UI
- [✕︎] script generative models with JIT and additional bending inputs (for use in [nn~], for example)

`torchbend` provides end-to-end examples for the following libraries:
- **Audio**
  - vschaos2
  - RAVE
  - audiocraft
- **Image**
  - StyleGAN3
  - StableDiffusion
- **Text**
  - Llama

## Parsing and analyzing a model's graphs

```python
import torch, torchbend

# create a dummy module for testing
module = torchbend.TestModule()
module_in = torch.randn(1, 1, 512)

# init BendedModule with the module, and trace target functions with given inputs
bended_module = torchbend.BendedModule(module)
bended_module.trace("forward", x=module_in)

# print weights and activations
print("weights:")
bended_module.print_weights()
print("activations:")
bended_module.print_activations()

outs = bended_module.get_activations('pre_conv', x=module_in)
print("pre_conv activation min and max:", outs['pre_conv'].min(), outs['pre_conv'].max())
```
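Tracing is not limited to `TestModule`: as the feature list above mentions, the augmented tracer also targets forwards that return `torch.distributions` objects (`Bernoulli`, `Normal`, `Categorical`). The sketch below is a minimal, hypothetical example of that use case; `GaussianEncoder` is not part of the library, and the example assumes a distribution-returning module is wrapped and traced exactly like `TestModule`.

```python
import torch
import torch.nn as nn
import torchbend

# hypothetical module whose forward returns a torch.distributions object
class GaussianEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.mean = nn.Linear(16, 8)
        self.scale = nn.Linear(16, 8)

    def forward(self, x):
        return torch.distributions.Normal(self.mean(x), self.scale(x).exp())

module = GaussianEncoder()
module_in = torch.randn(4, 16)

# plain torch.fx symbolic tracing usually fails on distribution constructors;
# torchbend's augmented tracer is meant to handle Normal, Bernoulli and Categorical
bended_module = torchbend.BendedModule(module)
bended_module.trace("forward", x=module_in)

# inspect the activations discovered while tracing through the distribution
bended_module.print_activations()
```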
## Bending weights and activations

```python
import torch, torchbend

# create a dummy module for testing
module = torchbend.TestModule()
module_in = torch.randn(1, 1, 512)

# init BendedModule with the module, and trace target functions with given inputs
bended_module = torchbend.BendedModule(module)
bended_module.trace("forward", x=module_in)

# bend target weights and make a forward pass
bended_module.bend(torchbend.bending.Mask(0.), "pre_conv.weight")
outs = bended_module(x=module_in)
print("pre_conv bended weight std:", bended_module.bended_state_dict()['pre_conv.weight'].std())
print("pre_conv original weight std:", bended_module.module.state_dict()['pre_conv.weight'].std())

# reset bending
bended_module.reset()

# bend target activation
bended_module.bend(torchbend.bending.Mask(0.), "pre_conv")
outs = bended_module.get_activations('pre_conv', x=module_in, bended=True)
print("pre_conv activation min and max:", outs['pre_conv'].min(), outs['pre_conv'].max())
```
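The same calls apply to modules of your own. The sketch below is a hypothetical example (`SmallConvNet` is not part of the library); it assumes `BendedModule` wraps an arbitrary `nn.Module` the way it wraps `TestModule`, and, as the weight example above suggests, that bending leaves the wrapped module's own parameters untouched.

```python
import torch
import torch.nn as nn
import torchbend

# hypothetical user module: a small 1-D convolutional stack
class SmallConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 8, 3, padding=1)
        self.conv2 = nn.Conv1d(8, 1, 3, padding=1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))

module = SmallConvNet()
module_in = torch.randn(1, 1, 512)

# wrap and trace as with TestModule above
bended_module = torchbend.BendedModule(module)
bended_module.trace("forward", x=module_in)

# list the weight and activation names discovered by the tracer
bended_module.print_weights()
bended_module.print_activations()

# zero out the first convolution's weights and compare against the unbended module
bended_module.bend(torchbend.bending.Mask(0.), "conv1.weight")
out_bended = bended_module(x=module_in)
out_original = module(module_in)
print("max absolute difference:", (out_bended - out_original).abs().max())
```

Weight keys follow the usual `state_dict` naming, so `"conv1.weight"` targets the first convolution's kernel, just as `"pre_conv.weight"` does for `TestModule` above.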