├── .github └── workflows │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── index_files └── figure-commonmark │ ├── cell-13-output-1.png │ ├── cell-14-output-1.png │ ├── cell-15-output-1.svg │ ├── cell-16-output-1.svg │ ├── cell-17-output-1.svg │ ├── cell-19-output-1.png │ ├── cell-21-output-1.png │ ├── cell-30-output-1.png │ ├── cell-31-output-1.svg │ ├── cell-32-output-1.png │ ├── cell-33-output-1.png │ ├── cell-34-output-1.png │ └── cell-37-output-1.png ├── lovely_jax ├── __init__.py ├── _modidx.py ├── patch.py ├── repr_chans.py ├── repr_plt.py ├── repr_rgb.py ├── repr_str.py └── utils │ ├── __init__.py │ ├── config.py │ └── misc.py ├── nbs ├── 00_repr_str.ipynb ├── 01_repr_rgb.ipynb ├── 02_repr_plt.ipynb ├── 03a_utils.config.ipynb ├── 03b_utils.misc.ipynb ├── 05_repr_chans.ipynb ├── 10_patch.ipynb ├── _quarto.yml ├── index.ipynb ├── matplotlib.ipynb ├── mysteryman.npy ├── nbdev.yml ├── sidebar.yml └── styles.css ├── settings.ini └── setup.py /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | on: 3 | push: 4 | branches: [ "main", "master" ] 5 | workflow_dispatch: 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | steps: [uses: fastai/workflows/quarto-ghp@master] 10 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [workflow_dispatch, pull_request, push] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: [uses: fastai/workflows/nbdev-ci@master] 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _proc 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alexey Zaytsev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 💘 Lovely JAX 2 | 3 | 4 | 5 | ## [Read full docs](https://xl0.github.io/lovely-jax) \| ❤️ [Lovely Tensors](https://github.com/xl0/lovely-tensors) \| 💟 [Lovely `NumPy`](https://github.com/xl0/lovely-numpy) \| [Discord](https://discord.gg/4NxRV7NH) 6 | 7 | ## Note: I’m pretty new to JAX 8 | 9 | If something does not make sense, shoot me an 10 | [Issue](https://github.com/xl0/lovely-jax/issues) or ping me on Discord 11 | and let me know how it’s supposed to work! 12 | 13 | Better support for sharded arrays and solid jit/pmap/vmap support coming 14 | soon! 15 | 16 | ## Install 17 | 18 | ``` sh 19 | pip install lovely-jax 20 | ``` 21 | 22 | ## How to use 23 | 24 | How often do you find yourself debugging JAX code? 
You dump an array to 25 | the cell output, and see this: 26 | 27 | ``` python 28 | numbers 29 | ``` 30 | 31 | Array([[[-0.354, -0.337, -0.405, ..., -0.56 , -0.474, 2.249], 32 | [-0.405, -0.423, -0.491, ..., -0.919, -0.851, 2.163], 33 | [-0.474, -0.474, -0.542, ..., -1.039, -1.039, 2.198], 34 | ..., 35 | [-0.902, -0.834, -0.936, ..., -1.467, -1.296, 2.232], 36 | [-0.851, -0.782, -0.936, ..., -1.604, -1.501, 2.18 ], 37 | [-0.834, -0.816, -0.971, ..., -1.656, -1.553, 2.112]], 38 | 39 | [[-0.197, -0.197, -0.303, ..., -0.478, -0.373, 2.411], 40 | [-0.25 , -0.232, -0.338, ..., -0.705, -0.67 , 2.359], 41 | [-0.303, -0.285, -0.39 , ..., -0.74 , -0.81 , 2.376], 42 | ..., 43 | [-0.425, -0.232, -0.373, ..., -1.09 , -1.02 , 2.429], 44 | [-0.39 , -0.232, -0.425, ..., -1.23 , -1.23 , 2.411], 45 | [-0.408, -0.285, -0.478, ..., -1.283, -1.283, 2.341]], 46 | 47 | [[-0.672, -0.985, -0.881, ..., -0.968, -0.689, 2.396], 48 | [-0.724, -1.072, -0.968, ..., -1.247, -1.02 , 2.326], 49 | [-0.828, -1.125, -1.02 , ..., -1.264, -1.16 , 2.379], 50 | ..., 51 | [-1.229, -1.473, -1.386, ..., -1.508, -1.264, 2.518], 52 | [-1.194, -1.456, -1.421, ..., -1.648, -1.473, 2.431], 53 | [-1.229, -1.526, -1.508, ..., -1.682, -1.526, 2.361]]], dtype=float32) 54 | 55 | Was it really useful for you, as a human, to see all these numbers? 56 | 57 | What is the shape? The size? 58 | What are the statistics? 59 | Are any of the values `nan` or `inf`? 60 | Is it an image of a man holding a tench? 61 | 62 | ``` python 63 | import lovely_jax as lj 64 | ``` 65 | 66 | ``` python 67 | lj.monkey_patch() 68 | ``` 69 | 70 | ## Summary 71 | 72 | ``` python 73 | numbers 74 | ``` 75 | 76 | Array[196, 196, 3] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 cpu:0 77 | 78 | Better, huh? 79 | 80 | ``` python 81 | numbers[1,:6,1] # Still shows values if there are not too many. 82 | ``` 83 | 84 | Array[6] x∈[-0.408, -0.232] μ=-0.340 σ=0.075 cpu:0 [-0.250, -0.232, -0.338, -0.408, -0.408, -0.408] 85 | 86 | ``` python 87 | spicy = numbers.flatten()[:12].copy() 88 | 89 | spicy = (spicy .at[0].mul(10000) 90 | .at[1].divide(10000) 91 | .at[2].set(float('inf')) 92 | .at[3].set(float('-inf')) 93 | .at[4].set(float('nan')) 94 | .reshape((2,6))) 95 | spicy # Spicy stuff 96 | ``` 97 | 98 | Array[2, 6] n=12 x∈[-3.541e+03, -1.975e-05] μ=-393.848 σ=1.113e+03 +Inf! -Inf! NaN! cpu:0 99 | 100 | ``` python 101 | jnp.zeros((10, 10)) # A zero array - make it obvious 102 | ``` 103 | 104 | Array[10, 10] n=100 all_zeros cpu:0 105 | 106 | ``` python 107 | spicy.v # Verbose 108 | ``` 109 | 110 | Array[2, 6] n=12 x∈[-3.541e+03, -1.975e-05] μ=-393.848 σ=1.113e+03 +Inf! -Inf! NaN! 
cpu:0 111 | Array([[-3.541e+03, -1.975e-05, inf, -inf, nan, -9.853e-01], 112 | [-4.054e-01, -3.025e-01, -8.807e-01, -4.397e-01, -3.025e-01, -7.761e-01]], dtype=float32) 113 | 114 | ``` python 115 | spicy.p # The plain old way 116 | ``` 117 | 118 | Array([[-3.541e+03, -1.975e-05, inf, -inf, nan, -9.853e-01], 119 | [-4.054e-01, -3.025e-01, -8.807e-01, -4.397e-01, -3.025e-01, -7.761e-01]], dtype=float32) 120 | 121 | ## Going `.deeper` 122 | 123 | ``` python 124 | numbers.deeper 125 | ``` 126 | 127 | Array[196, 196, 3] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 cpu:0 128 | Array[196, 3] n=588 x∈[-1.912, 2.411] μ=-0.728 σ=0.519 cpu:0 129 | Array[196, 3] n=588 x∈[-1.861, 2.359] μ=-0.778 σ=0.450 cpu:0 130 | Array[196, 3] n=588 x∈[-1.758, 2.379] μ=-0.838 σ=0.437 cpu:0 131 | Array[196, 3] n=588 x∈[-1.656, 2.466] μ=-0.878 σ=0.415 cpu:0 132 | Array[196, 3] n=588 x∈[-1.717, 2.448] μ=-0.882 σ=0.399 cpu:0 133 | Array[196, 3] n=588 x∈[-1.717, 2.431] μ=-0.905 σ=0.408 cpu:0 134 | Array[196, 3] n=588 x∈[-1.563, 2.448] μ=-0.859 σ=0.416 cpu:0 135 | Array[196, 3] n=588 x∈[-1.475, 2.431] μ=-0.791 σ=0.463 cpu:0 136 | Array[196, 3] n=588 x∈[-1.526, 2.429] μ=-0.759 σ=0.499 cpu:0 137 | ... 138 | 139 | ``` python 140 | # You can go deeper if you need to 141 | numbers[:3,:5,:3].deeper(2) 142 | ``` 143 | 144 | Array[3, 5, 3] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302 cpu:0 145 | Array[5, 3] n=15 x∈[-0.985, -0.197] μ=-0.491 σ=0.267 cpu:0 146 | Array[3] x∈[-0.672, -0.197] μ=-0.408 σ=0.197 cpu:0 [-0.354, -0.197, -0.672] 147 | Array[3] x∈[-0.985, -0.197] μ=-0.507 σ=0.343 cpu:0 [-0.337, -0.197, -0.985] 148 | Array[3] x∈[-0.881, -0.303] μ=-0.530 σ=0.252 cpu:0 [-0.405, -0.303, -0.881] 149 | Array[3] x∈[-0.776, -0.303] μ=-0.506 σ=0.199 cpu:0 [-0.440, -0.303, -0.776] 150 | Array[3] x∈[-0.916, -0.215] μ=-0.506 σ=0.298 cpu:0 [-0.388, -0.215, -0.916] 151 | Array[5, 3] n=15 x∈[-1.212, -0.232] μ=-0.609 σ=0.302 cpu:0 152 | Array[3] x∈[-0.724, -0.250] μ=-0.460 σ=0.197 cpu:0 [-0.405, -0.250, -0.724] 153 | Array[3] x∈[-1.072, -0.232] μ=-0.576 σ=0.360 cpu:0 [-0.423, -0.232, -1.072] 154 | Array[3] x∈[-0.968, -0.338] μ=-0.599 σ=0.268 cpu:0 [-0.491, -0.338, -0.968] 155 | Array[3] x∈[-0.968, -0.408] μ=-0.651 σ=0.235 cpu:0 [-0.577, -0.408, -0.968] 156 | Array[3] x∈[-1.212, -0.408] μ=-0.761 σ=0.336 cpu:0 [-0.662, -0.408, -1.212] 157 | Array[5, 3] n=15 x∈[-1.316, -0.285] μ=-0.677 σ=0.306 cpu:0 158 | Array[3] x∈[-0.828, -0.303] μ=-0.535 σ=0.219 cpu:0 [-0.474, -0.303, -0.828] 159 | Array[3] x∈[-1.125, -0.285] μ=-0.628 σ=0.360 cpu:0 [-0.474, -0.285, -1.125] 160 | Array[3] x∈[-1.020, -0.390] μ=-0.651 σ=0.268 cpu:0 [-0.542, -0.390, -1.020] 161 | Array[3] x∈[-1.003, -0.478] μ=-0.708 σ=0.219 cpu:0 [-0.645, -0.478, -1.003] 162 | Array[3] x∈[-1.316, -0.513] μ=-0.865 σ=0.336 cpu:0 [-0.765, -0.513, -1.316] 163 | 164 | ## Now in `.rgb` color 165 | 166 | The important queston - is it our man? 167 | 168 | ``` python 169 | numbers.rgb 170 | ``` 171 | 172 | ![](index_files/figure-commonmark/cell-13-output-1.png) 173 | 174 | *Maaaaybe?* Looks like someone normalized him. 175 | 176 | ``` python 177 | in_stats = ( (0.485, 0.456, 0.406), # mean 178 | (0.229, 0.224, 0.225) ) # std 179 | 180 | # numbers.rgb(in_stats, cl=True) # For channel-last input format 181 | numbers.rgb(in_stats) 182 | ``` 183 | 184 | ![](index_files/figure-commonmark/cell-14-output-1.png) 185 | 186 | It’s indeed our hero, the Tenchman! 
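
To make the `denorm` step concrete, here is a rough sketch of what `.rgb(in_stats)` amounts to, assuming the `numbers` and `in_stats` variables defined above. It mirrors the un-normalization used in the `.chans` section below and is an illustration, not the library's exact code path:

``` python
import jax.numpy as jnp

mean = jnp.array(in_stats[0])  # per-channel means used for normalization
std  = jnp.array(in_stats[1])  # per-channel stds used for normalization

# Channel-last data: broadcast over the trailing channel axis,
# then clip to [0, 1] so the values are valid display intensities.
denormed = (numbers * std + mean).clip(0, 1)
```
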
187 | 188 | ## `.plt` the statistics 189 | 190 | ``` python 191 | (numbers+3).plt 192 | ``` 193 | 194 | ![](index_files/figure-commonmark/cell-15-output-1.svg) 195 | 196 | ``` python 197 | (numbers+3).plt(center="mean", max_s=1000) 198 | ``` 199 | 200 | ![](index_files/figure-commonmark/cell-16-output-1.svg) 201 | 202 | ``` python 203 | (numbers+3).plt(center="range") 204 | ``` 205 | 206 | ![](index_files/figure-commonmark/cell-17-output-1.svg) 207 | 208 | ## See the `.chans` 209 | 210 | ``` python 211 | # .chans will map values betwen [-1,1] to colors. 212 | # Make our values fit into that range to avoid clipping. 213 | mean = jnp.array(in_stats[0]) 214 | std = jnp.array(in_stats[1]) 215 | numbers_01 = (numbers*std + mean) 216 | numbers_01 217 | ``` 218 | 219 | Array[196, 196, 3] n=115248 (0.4Mb) x∈[0., 1.000] μ=0.361 σ=0.248 cpu:0 220 | 221 | ``` python 222 | numbers_01.chans 223 | ``` 224 | 225 | ![](index_files/figure-commonmark/cell-19-output-1.png) 226 | 227 | ## Grouping 228 | 229 | ``` python 230 | # Make 8 images with progressively higher brightness and stack them 2x2x2. 231 | eight_images = (jnp.stack([numbers]*8) + jnp.linspace(-2, 2, 8)[:,None,None,None]) 232 | eight_images = (eight_images 233 | *jnp.array(in_stats[1]) 234 | +jnp.array(in_stats[0]) 235 | ).clip(0,1).reshape(2,2,2,196,196,3) 236 | 237 | eight_images 238 | ``` 239 | 240 | Array[2, 2, 2, 196, 196, 3] n=921984 (3.5Mb) x∈[0., 1.000] μ=0.382 σ=0.319 cpu:0 241 | 242 | ``` python 243 | eight_images.rgb 244 | ``` 245 | 246 | ![](index_files/figure-commonmark/cell-21-output-1.png) 247 | 248 | ## Sharding 249 | 250 | ``` python 251 | assert jax.__version_info__[0] == 0 252 | if jax.__version_info__[1] >= 4: 253 | from jax.sharding import PositionalSharding 254 | from jax.experimental import mesh_utils 255 | sharding = PositionalSharding(mesh_utils.create_device_mesh((4,2))) 256 | x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) 257 | y = jax.device_put(x, sharding) 258 | 259 | jax.debug.visualize_array_sharding(y) 260 | else: 261 | # Note: Looks like ShardedDeviceArray needs an explicit device axis? 262 | x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 8192)) 263 | y = jax.device_put_sharded([x for x in x], jax.devices()) 264 | 265 | print(x) 266 | print(y) 267 | ``` 268 | 269 |
                        
270 |    CPU 0       CPU 1    
271 |                         
272 |                         
273 |    CPU 2       CPU 3    
274 |                         
275 |                         
276 |    CPU 4       CPU 5    
277 |                         
278 |                         
279 |    CPU 6       CPU 7    
280 |                         
281 | 
282 | 283 | Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0 284 | Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0,1,2,3,4,5,6,7 285 | 286 | ## Options \| [Docs](utils.config.html) 287 | 288 | ``` python 289 | from lovely_jax import set_config, config, lovely, get_config 290 | ``` 291 | 292 | ``` python 293 | set_config(precision=5, sci_mode=True, color=False) 294 | jnp.array([1., 2, jnp.nan]) 295 | ``` 296 | 297 | Array[3] μ=1.50000e+00 σ=5.00000e-01 NaN! cpu:0 [1.00000e+00, 2.00000e+00, nan] 298 | 299 | ``` python 300 | set_config(precision=None, sci_mode=None, color=None) # None -> Reset to defaults 301 | ``` 302 | 303 | ``` python 304 | print(jnp.array([1., 2])) 305 | # Or with config context manager. 306 | with config(sci_mode=True, precision=5): 307 | print(jnp.array([1., 2])) 308 | 309 | print(jnp.array([1., 2])) 310 | ``` 311 | 312 | Array[2] μ=1.500 σ=0.500 cpu:0 [1.000, 2.000] 313 | Array[2] μ=1.50000e+00 σ=5.00000e-01 cpu:0 [1.00000e+00, 2.00000e+00] 314 | Array[2] μ=1.500 σ=0.500 cpu:0 [1.000, 2.000] 315 | 316 | ## Without `.monkey_patch` 317 | 318 | ``` python 319 | lj.lovely(spicy) 320 | ``` 321 | 322 | Array[2, 6] n=12 x∈[-3.541e+03, -1.975e-05] μ=-393.848 σ=1.113e+03 +Inf! -Inf! NaN! cpu:0 323 | 324 | ``` python 325 | lj.lovely(spicy, verbose=True) 326 | ``` 327 | 328 | Array[2, 6] n=12 x∈[-3.541e+03, -1.975e-05] μ=-393.848 σ=1.113e+03 +Inf! -Inf! NaN! cpu:0 329 | Array([[-3.541e+03, -1.975e-05, inf, -inf, nan, -9.853e-01], 330 | [-4.054e-01, -3.025e-01, -8.807e-01, -4.397e-01, -3.025e-01, -7.761e-01]], dtype=float32) 331 | 332 | ``` python 333 | lj.lovely(numbers, depth=1) 334 | ``` 335 | 336 | Array[196, 196, 3] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 cpu:0 337 | Array[196, 3] n=588 x∈[-1.912, 2.411] μ=-0.728 σ=0.519 cpu:0 338 | Array[196, 3] n=588 x∈[-1.861, 2.359] μ=-0.778 σ=0.450 cpu:0 339 | Array[196, 3] n=588 x∈[-1.758, 2.379] μ=-0.838 σ=0.437 cpu:0 340 | Array[196, 3] n=588 x∈[-1.656, 2.466] μ=-0.878 σ=0.415 cpu:0 341 | Array[196, 3] n=588 x∈[-1.717, 2.448] μ=-0.882 σ=0.399 cpu:0 342 | Array[196, 3] n=588 x∈[-1.717, 2.431] μ=-0.905 σ=0.408 cpu:0 343 | Array[196, 3] n=588 x∈[-1.563, 2.448] μ=-0.859 σ=0.416 cpu:0 344 | Array[196, 3] n=588 x∈[-1.475, 2.431] μ=-0.791 σ=0.463 cpu:0 345 | Array[196, 3] n=588 x∈[-1.526, 2.429] μ=-0.759 σ=0.499 cpu:0 346 | ... 
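
Since `lovely` returns a small wrapper whose `repr` is the summary string, it also drops neatly into ordinary logging. A minimal sketch, assuming the `numbers` array from above; `color=False` disables the ANSI escapes, which is handy for plain-text logs:

``` python
summary = str(lj.lovely(numbers, color=False))  # summary text without ANSI color codes
print(f"activations: {summary}")
```
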
347 |
348 | ``` python
349 | lj.rgb(numbers, in_stats)
350 | ```
351 |
352 | ![](index_files/figure-commonmark/cell-30-output-1.png)
353 |
354 | ``` python
355 | lj.plot(numbers, center="mean")
356 | ```
357 |
358 | ![](index_files/figure-commonmark/cell-31-output-1.svg)
359 |
360 | ``` python
361 | lj.chans(numbers_01)
362 | ```
363 |
364 | ![](index_files/figure-commonmark/cell-32-output-1.png)
365 |
366 | ## Matplotlib integration \| [Docs](matplotlib.html)
367 |
368 | ``` python
369 | numbers.rgb(in_stats).fig # matplotlib figure
370 | ```
371 |
372 | ![](index_files/figure-commonmark/cell-33-output-1.png)
373 |
374 | ``` python
375 | (numbers*0.3+0.5).chans.fig # matplotlib figure
376 | ```
377 |
378 | ![](index_files/figure-commonmark/cell-34-output-1.png)
379 |
380 | ``` python
381 | numbers.plt.fig.savefig('pretty.svg') # Save it
382 | ```
383 |
384 | ``` python
385 | !file pretty.svg; rm pretty.svg
386 | ```
387 |
388 | pretty.svg: SVG Scalable Vector Graphics image
389 |
390 | ### Add content to existing Axes
391 |
392 | ``` python
393 | fig = plt.figure(figsize=(8,3))
394 | fig.set_constrained_layout(True)
395 | gs = fig.add_gridspec(2,2)
396 | ax1 = fig.add_subplot(gs[0, :])
397 | ax2 = fig.add_subplot(gs[1, 0])
398 | ax3 = fig.add_subplot(gs[1,1:])
399 |
400 | ax2.set_axis_off()
401 | ax3.set_axis_off()
402 |
403 | numbers_01.plt(ax=ax1)
404 | numbers_01.rgb(ax=ax2)
405 | numbers_01.chans(ax=ax3);
406 | ```
407 |
408 | ![](index_files/figure-commonmark/cell-37-output-1.png)
409 |
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-13-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-13-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-14-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-14-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-15-output-1.svg:
--------------------------------------------------------------------------------
(matplotlib-generated SVG markup omitted)
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-16-output-1.svg:
--------------------------------------------------------------------------------
(matplotlib-generated SVG markup omitted)
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-19-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-19-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-21-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-21-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-30-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-30-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-32-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-32-output-1.png
--------------------------------------------------------------------------------
/index_files/figure-commonmark/cell-33-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-33-output-1.png -------------------------------------------------------------------------------- /index_files/figure-commonmark/cell-34-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-34-output-1.png -------------------------------------------------------------------------------- /index_files/figure-commonmark/cell-37-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/index_files/figure-commonmark/cell-37-output-1.png -------------------------------------------------------------------------------- /lovely_jax/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.3" 2 | 3 | 4 | from .repr_str import * 5 | from .repr_rgb import * 6 | from .repr_plt import * 7 | from .repr_chans import * 8 | from .patch import * 9 | from .utils import * -------------------------------------------------------------------------------- /lovely_jax/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'master', 4 | 'doc_baseurl': '/lovely-jax', 5 | 'doc_host': 'https://xl0.github.io', 6 | 'git_url': 'https://github.com/xl0/lovely-jax', 7 | 'lib_path': 'lovely_jax'}, 8 | 'syms': { 'lovely_jax.patch': { 'lovely_jax.patch._monkey_patch': ('patch.html#_monkey_patch', 'lovely_jax/patch.py'), 9 | 'lovely_jax.patch.monkey_patch': ('patch.html#monkey_patch', 'lovely_jax/patch.py')}, 10 | 'lovely_jax.repr_chans': { 'lovely_jax.repr_chans.ChanProxy': ('repr_chans.html#chanproxy', 'lovely_jax/repr_chans.py'), 11 | 'lovely_jax.repr_chans.ChanProxy.__call__': ( 'repr_chans.html#chanproxy.__call__', 12 | 'lovely_jax/repr_chans.py'), 13 | 'lovely_jax.repr_chans.ChanProxy.__init__': ( 'repr_chans.html#chanproxy.__init__', 14 | 'lovely_jax/repr_chans.py'), 15 | 'lovely_jax.repr_chans.ChanProxy._repr_png_': ( 'repr_chans.html#chanproxy._repr_png_', 16 | 'lovely_jax/repr_chans.py'), 17 | 'lovely_jax.repr_chans.ChanProxy.fig': ('repr_chans.html#chanproxy.fig', 'lovely_jax/repr_chans.py'), 18 | 'lovely_jax.repr_chans.chans': ('repr_chans.html#chans', 'lovely_jax/repr_chans.py')}, 19 | 'lovely_jax.repr_plt': { 'lovely_jax.repr_plt.PlotProxy': ('repr_plt.html#plotproxy', 'lovely_jax/repr_plt.py'), 20 | 'lovely_jax.repr_plt.PlotProxy.__call__': ( 'repr_plt.html#plotproxy.__call__', 21 | 'lovely_jax/repr_plt.py'), 22 | 'lovely_jax.repr_plt.PlotProxy.__init__': ( 'repr_plt.html#plotproxy.__init__', 23 | 'lovely_jax/repr_plt.py'), 24 | 'lovely_jax.repr_plt.PlotProxy._repr_png_': ( 'repr_plt.html#plotproxy._repr_png_', 25 | 'lovely_jax/repr_plt.py'), 26 | 'lovely_jax.repr_plt.PlotProxy._repr_svg_': ( 'repr_plt.html#plotproxy._repr_svg_', 27 | 'lovely_jax/repr_plt.py'), 28 | 'lovely_jax.repr_plt.PlotProxy.fig': ('repr_plt.html#plotproxy.fig', 'lovely_jax/repr_plt.py'), 29 | 'lovely_jax.repr_plt.plot': ('repr_plt.html#plot', 'lovely_jax/repr_plt.py')}, 30 | 'lovely_jax.repr_rgb': { 'lovely_jax.repr_rgb.RGBProxy': ('repr_rgb.html#rgbproxy', 'lovely_jax/repr_rgb.py'), 31 | 'lovely_jax.repr_rgb.RGBProxy.__call__': ('repr_rgb.html#rgbproxy.__call__', 
'lovely_jax/repr_rgb.py'), 32 | 'lovely_jax.repr_rgb.RGBProxy.__init__': ('repr_rgb.html#rgbproxy.__init__', 'lovely_jax/repr_rgb.py'), 33 | 'lovely_jax.repr_rgb.RGBProxy._repr_png_': ( 'repr_rgb.html#rgbproxy._repr_png_', 34 | 'lovely_jax/repr_rgb.py'), 35 | 'lovely_jax.repr_rgb.RGBProxy.fig': ('repr_rgb.html#rgbproxy.fig', 'lovely_jax/repr_rgb.py'), 36 | 'lovely_jax.repr_rgb.rgb': ('repr_rgb.html#rgb', 'lovely_jax/repr_rgb.py')}, 37 | 'lovely_jax.repr_str': { 'lovely_jax.repr_str.StrProxy': ('repr_str.html#strproxy', 'lovely_jax/repr_str.py'), 38 | 'lovely_jax.repr_str.StrProxy.__call__': ('repr_str.html#strproxy.__call__', 'lovely_jax/repr_str.py'), 39 | 'lovely_jax.repr_str.StrProxy.__init__': ('repr_str.html#strproxy.__init__', 'lovely_jax/repr_str.py'), 40 | 'lovely_jax.repr_str.StrProxy.__repr__': ('repr_str.html#strproxy.__repr__', 'lovely_jax/repr_str.py'), 41 | 'lovely_jax.repr_str.history_warning': ('repr_str.html#history_warning', 'lovely_jax/repr_str.py'), 42 | 'lovely_jax.repr_str.is_nasty': ('repr_str.html#is_nasty', 'lovely_jax/repr_str.py'), 43 | 'lovely_jax.repr_str.jax_to_str_common': ('repr_str.html#jax_to_str_common', 'lovely_jax/repr_str.py'), 44 | 'lovely_jax.repr_str.lovely': ('repr_str.html#lovely', 'lovely_jax/repr_str.py'), 45 | 'lovely_jax.repr_str.plain_repr': ('repr_str.html#plain_repr', 'lovely_jax/repr_str.py'), 46 | 'lovely_jax.repr_str.short_dtype': ('repr_str.html#short_dtype', 'lovely_jax/repr_str.py'), 47 | 'lovely_jax.repr_str.to_str': ('repr_str.html#to_str', 'lovely_jax/repr_str.py')}, 48 | 'lovely_jax.utils.config': { 'lovely_jax.utils.config.Config': ('utils.config.html#config', 'lovely_jax/utils/config.py'), 49 | 'lovely_jax.utils.config.Config.__init__': ( 'utils.config.html#config.__init__', 50 | 'lovely_jax/utils/config.py'), 51 | 'lovely_jax.utils.config._Default': ('utils.config.html#_default', 'lovely_jax/utils/config.py'), 52 | 'lovely_jax.utils.config._Default.__repr__': ( 'utils.config.html#_default.__repr__', 53 | 'lovely_jax/utils/config.py'), 54 | 'lovely_jax.utils.config.config': ('utils.config.html#config', 'lovely_jax/utils/config.py'), 55 | 'lovely_jax.utils.config.get_config': ( 'utils.config.html#get_config', 56 | 'lovely_jax/utils/config.py'), 57 | 'lovely_jax.utils.config.set_config': ( 'utils.config.html#set_config', 58 | 'lovely_jax/utils/config.py')}, 59 | 'lovely_jax.utils.misc': { 'lovely_jax.utils.misc.is_cpu': ('utils.misc.html#is_cpu', 'lovely_jax/utils/misc.py'), 60 | 'lovely_jax.utils.misc.test_array_repr': ( 'utils.misc.html#test_array_repr', 61 | 'lovely_jax/utils/misc.py'), 62 | 'lovely_jax.utils.misc.to_numpy': ('utils.misc.html#to_numpy', 'lovely_jax/utils/misc.py')}}} 63 | -------------------------------------------------------------------------------- /lovely_jax/patch.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_patch.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['monkey_patch'] 5 | 6 | # %% ../nbs/10_patch.ipynb 5 7 | import numpy as np 8 | import jax 9 | import jax.numpy as jnp 10 | from jax._src import array 11 | from fastcore.foundation import patch_to 12 | import matplotlib.pyplot as plt 13 | 14 | from .repr_str import StrProxy 15 | from .repr_rgb import RGBProxy 16 | from .repr_plt import PlotProxy 17 | from .repr_chans import ChanProxy 18 | 19 | # %% ../nbs/10_patch.ipynb 6 20 | def _monkey_patch(cls): 21 | "Monkey-patch lovely features into `cls`" 22 | 23 | if not hasattr(cls, '_plain_repr'): 24 | cls._plain_repr = cls.__repr__ 25 | cls._plain_str = cls.__str__ 26 | cls._plain_format = cls.__format__ 27 | 28 | @patch_to(cls) 29 | def __repr__(self: jax.Array): 30 | return str(StrProxy(self)) 31 | 32 | # __str__ is used when you do print(), and gives a less detailed version of the object. 33 | # __repr__ is used when you inspect an object in Jupyter or VSCode, and gives a more detailed version. 34 | # I think we want to patch both. 35 | @patch_to(cls) 36 | def __str__(self: jax.Array): 37 | return str(StrProxy(self)) 38 | 39 | # Without this, the native __format__ will call into numpy formatter 40 | # and will produce raw numbers. Idea: A way to pass fmt through? 41 | @patch_to(cls) 42 | def __format__(self: jax.Array, tmp: str): 43 | return str(StrProxy(self)) 44 | 45 | # Plain - the old behavior 46 | @patch_to(cls, as_prop=True) 47 | def p(self: jax.Array): 48 | return StrProxy(self, plain=True) 49 | 50 | # Verbose - print both stats and plain values 51 | @patch_to(cls, as_prop=True) 52 | def v(self: jax.Array): 53 | return StrProxy(self, verbose=True) 54 | 55 | @patch_to(cls, as_prop=True) 56 | def deeper(self: jax.Array): 57 | return StrProxy(self, depth=1) 58 | 59 | @patch_to(cls, as_prop=True) 60 | def rgb(t: jax.Array): 61 | return RGBProxy(t) 62 | 63 | @patch_to(cls, as_prop=True) 64 | def chans(t: jax.Array): 65 | return ChanProxy(t) 66 | 67 | @patch_to(cls, as_prop=True) 68 | def plt(t: jax.Array): 69 | return PlotProxy(t) 70 | 71 | 72 | def monkey_patch(): 73 | _monkey_patch(array.ArrayImpl) 74 | # To support jax version higher than 0.4.14 75 | if hasattr(array, "DeviceArray"): 76 | _monkey_patch(array.DeviceArray) 77 | 78 | # This was required for earlied version of jax 0.4.x 79 | # In jax version higher than 0.4.14 pxla is not accesible 80 | # instead we use jax.interpreters.pxla 81 | if not hasattr(jax, "interpreters"): 82 | if hasattr(jax.pxla, '_ShardedDeviceArray'): 83 | _monkey_patch(jax.pxla._ShardedDeviceArray) 84 | 85 | -------------------------------------------------------------------------------- /lovely_jax/repr_chans.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_repr_chans.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['chans'] 5 | 6 | # %% ../nbs/05_repr_chans.ipynb 4 7 | from typing import Any, Optional as O 8 | 9 | import jax, jax.numpy as jnp 10 | from matplotlib import pyplot as plt, axes, figure 11 | from IPython.core.pylabtools import print_figure 12 | 13 | from lovely_numpy.utils.utils import cached_property 14 | from lovely_numpy.repr_chans import fig_chans 15 | from lovely_numpy import config as np_config 16 | 17 | from .utils.misc import to_numpy 18 | from .utils.config import get_config 19 | 20 | 21 | # %% ../nbs/05_repr_chans.ipynb 6 22 | class ChanProxy(): 23 | def __init__(self, t: jax.Array): 24 | self.t = t 25 | self.params = dict(cmap = "twilight", 26 | cm_below="blue", 27 | cm_above="red", 28 | cm_ninf="cyan", 29 | cm_pinf="fuchsia", 30 | cm_nan="yellow", 31 | view_width=966, 32 | gutter_px=3, 33 | frame_px=1, 34 | scale=1, 35 | cl=True, 36 | ax=None) 37 | 38 | def __call__(self, 39 | cmap :O[str]=None, 40 | cm_below :O[str]=None, 41 | cm_above :O[str]=None, 42 | cm_ninf :O[str]=None, 43 | cm_pinf :O[str]=None, 44 | cm_nan :O[str]=None, 45 | view_width :O[int]=None, 46 | gutter_px :O[int]=None, 47 | frame_px :O[int]=None, 48 | scale :O[int]=None, 49 | cl :Any=None, 50 | ax :O[axes.Axes]=None): 51 | 52 | self.params.update( { k:v for 53 | k,v in locals().items() 54 | if k != "self" and v is not None } ) 55 | _ = self.fig # Trigger figure generation 56 | return self 57 | 58 | @cached_property 59 | def fig(self) -> figure.Figure: 60 | cfg = get_config() 61 | with np_config(fig_close=cfg.fig_close, fig_show=cfg.fig_show): 62 | return fig_chans(to_numpy(self.t), **self.params) 63 | 64 | def _repr_png_(self): 65 | return print_figure(self.fig, fmt="png", pad_inches=0, 66 | metadata={"Software": "Matplotlib, https://matplotlib.org/"}) 67 | 68 | # %% ../nbs/05_repr_chans.ipynb 7 69 | def chans( x: jax.Array, # Input, shape=([...], H, W) 70 | cmap :str ="twilight",# Use matplotlib colormap by this name 71 | cm_below :str ="blue", # Color for values below -1 72 | cm_above :str ="red", # Color for values above 1 73 | cm_ninf :str ="cyan", # Color for -inf values 74 | cm_pinf :str ="fuchsia", # Color for +inf values 75 | cm_nan :str ="yellow", # Color for NaN values 76 | view_width :int =966, # Try to produce an image at most this wide 77 | gutter_px :int =3, # Draw write gutters when tiling the images 78 | frame_px :int =1, # Draw black frame around each image 79 | scale :int =1, 80 | cl :Any =True, 81 | ax :O[axes.Axes]=None 82 | ) -> ChanProxy: 83 | 84 | "Map tensor values to colors. RGB[A] color is added as channel-last" 85 | args = locals() 86 | del args["x"] 87 | 88 | return ChanProxy(x)(**args) 89 | -------------------------------------------------------------------------------- /lovely_jax/repr_plt.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_repr_plt.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['plot'] 5 | 6 | # %% ../nbs/02_repr_plt.ipynb 4 7 | import math 8 | from typing import Union, Any, Optional as O 9 | 10 | import jax, jax.numpy as jnp 11 | from matplotlib import pyplot as plt, axes, figure, rc_context, rcParams 12 | from IPython.core.pylabtools import print_figure 13 | 14 | from lovely_numpy.utils.utils import cached_property 15 | from lovely_numpy.repr_plt import fig_plot 16 | from lovely_numpy import config as np_config 17 | 18 | from .repr_str import to_str, pretty_str 19 | from .utils.misc import to_numpy 20 | from .utils.config import get_config, config 21 | 22 | # %% ../nbs/02_repr_plt.ipynb 6 23 | # This is here for the monkey-patched tensor use case. 24 | # Gives the ability to call both .plt and .plt(ax=ax). 25 | 26 | class PlotProxy(): 27 | """Flexible `PIL.Image.Image` wrapper""" 28 | 29 | def __init__(self, x:jax.Array): 30 | self.x = x 31 | self.params = dict( center="zero", 32 | max_s=10000, 33 | plt0=True, 34 | ax=None) 35 | 36 | def __call__( self, 37 | center :O[str] =None, 38 | max_s :O[int] =None, 39 | plt0 :Any =None, 40 | ax :O[axes.Axes]=None): 41 | 42 | self.params.update( { k:v for 43 | k,v in locals().items() 44 | if k != "self" and v is not None } ) 45 | 46 | _ = self.fig # Trigger figure generation 47 | return self 48 | 49 | @cached_property 50 | def fig(self) -> figure.Figure: 51 | cfg = get_config() 52 | with np_config( fig_close=cfg.fig_close, 53 | fig_show=cfg.fig_show, 54 | plt_seed=cfg.plt_seed ), config(show_mem_above=jnp.inf): 55 | return fig_plot( to_numpy(self.x), 56 | summary=to_str(self.x, color=False), 57 | **self.params) 58 | 59 | def _repr_png_(self): 60 | return print_figure(self.fig, fmt="png", 61 | metadata={"Software": "Matplotlib, https://matplotlib.org/"}) 62 | 63 | def _repr_svg_(self): 64 | # Metadata and context for a mode deterministic svg generation 65 | metadata={ 66 | "Date": None, 67 | "Creator": "Matplotlib, https://matplotlib.org/", 68 | } 69 | with rc_context({"svg.hashsalt": "1"}): 70 | svg_repr = print_figure(self.fig, fmt="svg", metadata=metadata) 71 | return svg_repr 72 | 73 | 74 | # %% ../nbs/02_repr_plt.ipynb 7 75 | def plot( x :jax.Array, # Tensor to explore 76 | center :str ="zero", # Center plot on `zero`, `mean`, or `range` 77 | max_s :int =10000, # Draw up to this many samples. =0 to draw all 78 | plt0 :Any =True, # Take zero values into account 79 | ax :O[axes.Axes]=None # Optionally provide a matplotlib axes. 80 | ) -> PlotProxy: 81 | 82 | args = locals() 83 | del args["x"] 84 | 85 | return PlotProxy(x)(**args) 86 | 87 | -------------------------------------------------------------------------------- /lovely_jax/repr_rgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_repr_rgb.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['rgb'] 5 | 6 | # %% ../nbs/01_repr_rgb.ipynb 4 7 | from typing import Any, Optional as O 8 | 9 | 10 | from matplotlib import axes, figure 11 | from IPython.core.pylabtools import print_figure 12 | from PIL import Image 13 | import jax, jax.numpy as jnp 14 | 15 | from lovely_numpy.utils.utils import cached_property 16 | from lovely_numpy.utils.pad import pad_frame_gutters 17 | from lovely_numpy.utils.tile2d import hypertile 18 | from lovely_numpy.repr_rgb import fig_rgb 19 | from lovely_numpy import config as np_config 20 | 21 | from .utils.misc import to_numpy 22 | from .utils.config import get_config 23 | 24 | 25 | # %% ../nbs/01_repr_rgb.ipynb 6 26 | # This is here for the monkey-patched tensor use case. 27 | 28 | # I want to be able to call both `tensor.rgb` and `tensor.rgb(stats)`. For the 29 | # first case, the class defines `_repr_png_` to send the image to Jupyter. For 30 | # the later case, it defines __call__, which accps the argument. 31 | 32 | class RGBProxy(): 33 | """Flexible `PIL.Image.Image` wrapper""" 34 | 35 | def __init__(self, x: jax.Array): 36 | assert x.ndim >= 3, f"Expecting at least 3 dimensions, got shape{x.shape}={x.ndim}" 37 | self.x =x 38 | self.params = dict(denorm = None, 39 | cl = True, 40 | gutter_px = 3, 41 | frame_px = 1, 42 | scale = 1, 43 | view_width = 966, 44 | ax = None) 45 | 46 | 47 | def __call__(self, 48 | denorm :Any =None, 49 | cl :Any =True, 50 | gutter_px :O[int] =None, 51 | frame_px :O[int] =None, 52 | scale :O[int] =None, 53 | view_width :O[int] =None, 54 | ax :O[axes.Axes]=None): 55 | 56 | self.params.update( { k:v for 57 | k,v in locals().items() 58 | if k != "self" and v is not None } ) 59 | _ = self.fig # Trigger figure generation 60 | return self 61 | 62 | @cached_property 63 | def fig(self) -> figure.Figure: 64 | cfg = get_config() 65 | with np_config(fig_close=cfg.fig_close, fig_show=cfg.fig_show): 66 | return fig_rgb(to_numpy(self.x), **self.params) 67 | 68 | def _repr_png_(self): 69 | return print_figure(self.fig, fmt="png", pad_inches=0, 70 | metadata={"Software": "Matplotlib, https://matplotlib.org/"}) 71 | 72 | 73 | # %% ../nbs/01_repr_rgb.ipynb 7 74 | def rgb(x :jax.Array, # Tensor to display. [[...], C,H,W] or [[...], H,W,C] 75 | denorm :Any =None, # Reverse per-channel normalizatoin 76 | cl :Any =True, # Channel-last 77 | gutter_px :int =3, # If more than one tensor -> tile with this gutter width 78 | frame_px :int =1, # If more than one tensor -> tile with this frame width 79 | scale :int =1, # Scale up. Can't scale down. 80 | view_width :int =966, # target width of the image 81 | ax :O[axes.Axes] =None # Use this Axes 82 | ) -> RGBProxy: 83 | 84 | args = locals() 85 | del args["x"] 86 | 87 | return RGBProxy(x)(**args) 88 | -------------------------------------------------------------------------------- /lovely_jax/repr_str.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_repr_str.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['jax_to_str_common', 'lovely'] 5 | 6 | # %% ../nbs/00_repr_str.ipynb 5 7 | import warnings 8 | from typing import Union, Optional as O 9 | 10 | import numpy as np 11 | import jax, jax.numpy as jnp 12 | 13 | from lovely_numpy import np_to_str_common, pretty_str, sparse_join, ansi_color, in_debugger, bytes_to_human 14 | from lovely_numpy import config as lnp_config 15 | 16 | from .utils.config import get_config, config 17 | from .utils.misc import to_numpy, is_cpu, test_array_repr 18 | 19 | # %% ../nbs/00_repr_str.ipynb 8 20 | dtnames = { "float16": "f16", 21 | "float32": "", # Default dtype in jax 22 | "float64": "f64", 23 | "bfloat16": "bf16", 24 | "uint8": "u8", 25 | "uint16": "u16", 26 | "uint32": "u32", 27 | "uint64": "u64", 28 | "int8": "i8", 29 | "int16": "i16", 30 | "int32": "i32", 31 | "int64": "i64", 32 | } 33 | 34 | def short_dtype(x: jax.Array) -> str: 35 | return dtnames.get(x.dtype.name, str(x.dtype)) 36 | 37 | # %% ../nbs/00_repr_str.ipynb 10 38 | def plain_repr(x: jax.Array): 39 | "Pick the right function to get a plain repr" 40 | # assert isinstance(x, np.ndarray), f"expected np.ndarray but got {type(x)}" # Could be a sub-class. 41 | return x._plain_repr() if hasattr(x, "_plain_repr") else repr(x) 42 | 43 | # def plain_str(x: torch.Tensor): 44 | # "Pick the right function to get a plain str." 45 | # # assert isinstance(x, np.ndarray), f"expected np.ndarray but got {type(x)}" 46 | # return x._plain_str() if hasattr(type(x), "_plain_str") else str(x) 47 | 48 | # %% ../nbs/00_repr_str.ipynb 11 49 | def is_nasty(x: jax.Array): 50 | """Return true of any `x` values are inf or nan""" 51 | 52 | if x.size == 0: return False # min/max don't like zero-lenght arrays 53 | 54 | x_min = x.min() 55 | x_max = x.max() 56 | 57 | return jnp.isnan(x_min) or jnp.isinf(x_min) or jnp.isinf(x_max) 58 | 59 | # %% ../nbs/00_repr_str.ipynb 13 60 | def jax_to_str_common(x: jax.Array, # Input 61 | color=True, # ANSI color highlighting 62 | ddof=0): # For "std" unbiasing 63 | 64 | if x.size == 0: 65 | return ansi_color("empty", "grey", color) 66 | 67 | zeros = ansi_color("all_zeros", "grey", color) if jnp.equal(x, 0.).all() and x.size > 1 else None 68 | # pinf = ansi_color("+Inf!", "red", color) if jnp.isposinf(x).any() else None 69 | # ninf = ansi_color("-Inf!", "red", color) if jnp.isneginf(x).any() else None 70 | # nan = ansi_color("NaN!", "red", color) if jnp.isnan(x).any() else None 71 | 72 | # attention = sparse_join([zeros,pinf,ninf,nan]) 73 | 74 | summary = None 75 | if not zeros and x.ndim > 0: 76 | minmax = f"x∈[{pretty_str(x.min())}, {pretty_str(x.max())}]" if x.size > 2 else None 77 | meanstd = f"μ={pretty_str(x.mean())} σ={pretty_str(x.std(ddof=ddof))}" if x.size >= 2 else None 78 | summary = sparse_join([minmax, meanstd]) 79 | 80 | 81 | return sparse_join([ summary, zeros]) 82 | 83 | # %% ../nbs/00_repr_str.ipynb 14 84 | def to_str(x: jax.Array, # Input 85 | plain: bool=False, 86 | verbose: bool=False, 87 | depth=0, 88 | lvl=0, 89 | color=None) -> str: 90 | 91 | if plain: 92 | return plain_repr(x) 93 | 94 | conf = get_config() 95 | 96 | tname = type(x).__name__.split(".")[-1] 97 | if tname in ("ArrayImpl"): tname = "Array" 98 | shape = str(list(x.shape)) if x.ndim else None 99 | type_str = sparse_join([tname, shape], sep="") 100 | 101 | if hasattr(x, "devices"): # Unified Array (jax >= 0.4) 102 | int_dev_ids = sorted([d.id for d in x.devices()]) 103 | ids = ",".join(map(str, int_dev_ids)) 104 | dev = f"{list(x.devices())[0].platform}:{ids}" 105 | elif 
hasattr(x, "device"): # Old-style DeviceArray 106 | dev = f"{x.device().platform}:{x.device().id}" 107 | elif hasattr(x, "sharding"): 108 | int_dev_ids = sorted([d.id for d in x.sharding.devices]) 109 | ids = ",".join(map(str, int_dev_ids)) 110 | dev = f"{x.sharding.devices[0].platform}:{ids}" 111 | else: 112 | assert 0, f"Weird input type={type(input)}, expecrted Array, DeviceArray, or ShardedDeviceArray" 113 | 114 | dtype = short_dtype(x) 115 | # grad_fn = t.grad_fn.name() if t.grad_fn else None 116 | # PyTorch does not want you to know, but all `grad_fn`` 117 | # tensors actuall have `requires_grad=True`` too. 118 | # grad = "grad" if t.requires_grad else None 119 | grad = grad_fn = None 120 | 121 | # For complex tensors, just show the shape / size part for now. 122 | if not jnp.iscomplexobj(x): 123 | if color is None: color=conf.color 124 | if in_debugger(): color = False 125 | # `lovely-numpy` is used to calculate stats when doing so on GPU would require 126 | # memory allocation (not float tensors, tensors with bad numbers), or if the 127 | # data is on CPU (because numpy is faster). 128 | # 129 | # Temporarily set the numpy config to match our config for consistency. 130 | with lnp_config(precision=conf.precision, 131 | threshold_min=conf.threshold_min, 132 | threshold_max=conf.threshold_max, 133 | sci_mode=conf.sci_mode): 134 | 135 | if is_cpu(x) or is_nasty(x): 136 | common = np_to_str_common(np.array(x), color=color) 137 | else: 138 | common = jax_to_str_common(x, color=color) 139 | 140 | numel = None 141 | if x.shape and max(x.shape) != x.size: 142 | numel = f"n={x.size}" 143 | if get_config().show_mem_above <= x.nbytes: 144 | numel = sparse_join([numel, f"({bytes_to_human(x.nbytes)})"]) 145 | elif get_config().show_mem_above <= x.nbytes: 146 | numel = bytes_to_human(x.nbytes) 147 | 148 | vals = pretty_str(x) if 0 < x.size <= 10 else None 149 | res = sparse_join([type_str, dtype, numel, common, grad, grad_fn, dev, vals]) 150 | else: 151 | res = plain_repr(x) 152 | 153 | if verbose: 154 | res += "\n" + plain_repr(x) 155 | 156 | if depth and x.ndim > 1: 157 | with config(show_mem_above=jnp.inf): 158 | deep_width = min((x.shape[0]), conf.deeper_width) # Print at most this many lines 159 | deep_lines = [ " "*conf.indent*(lvl+1) + to_str(x[i,:], depth=depth-1, lvl=lvl+1) 160 | for i in range(deep_width)] 161 | 162 | # If we were limited by width, print ... 163 | if deep_width < x.shape[0]: deep_lines.append(" "*conf.indent*(lvl+1) + "...") 164 | 165 | res += "\n" + "\n".join(deep_lines) 166 | 167 | return res 168 | 169 | # %% ../nbs/00_repr_str.ipynb 15 170 | def history_warning(): 171 | "Issue a warning (once) ifw e are running in IPYthon with output cache enabled" 172 | 173 | if "get_ipython" in globals() and get_ipython().cache_size > 0: 174 | warnings.warn("IPYthon has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html") 175 | 176 | # %% ../nbs/00_repr_str.ipynb 18 177 | class StrProxy(): 178 | def __init__(self, x: jax.Array, plain=False, verbose=False, depth=0, lvl=0, color=None): 179 | self.x = x 180 | self.plain = plain 181 | self.verbose = verbose 182 | self.depth=depth 183 | self.lvl=lvl 184 | self.color=color 185 | history_warning() 186 | 187 | def __repr__(self): 188 | return to_str(self.x, plain=self.plain, verbose=self.verbose, 189 | depth=self.depth, lvl=self.lvl, color=self.color) 190 | 191 | # This is used for .deeper attribute and .deeper(depth=...). 192 | # The second onthe results in a __call__. 
193 | def __call__(self, depth=1): 194 | return StrProxy(self.x, depth=depth) 195 | 196 | # %% ../nbs/00_repr_str.ipynb 19 197 | def lovely(x: jax.Array, # Tensor of interest 198 | verbose=False, # Whether to show the full tensor 199 | plain=False, # Just print if exactly as before 200 | depth=0, # Show stats in depth 201 | color=None): # Force color (True/False) or auto. 202 | return StrProxy(x, verbose=verbose, plain=plain, depth=depth, color=color) 203 | -------------------------------------------------------------------------------- /lovely_jax/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .config import * -------------------------------------------------------------------------------- /lovely_jax/utils/config.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03a_utils.config.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['set_config', 'get_config', 'config'] 5 | 6 | # %% ../../nbs/03a_utils.config.ipynb 5 7 | from copy import copy 8 | from types import SimpleNamespace 9 | from typing import Optional, Union, Callable, TypeVar 10 | from contextlib import contextmanager 11 | from lovely_numpy import config as np_config 12 | 13 | # %% ../../nbs/03a_utils.config.ipynb 6 14 | class Config(SimpleNamespace): 15 | "Config" 16 | def __init__(self, 17 | precision = 3, # Digits after `.` 18 | threshold_max = 3, # .abs() larger than 1e3 -> Sci mode 19 | threshold_min = -4, # .abs() smaller that 1e-4 -> Sci mode 20 | sci_mode = None, # Sci mode (2.3e4). None=auto 21 | show_mem_above= 1024, # Show memory footprint above this size 22 | indent = 2, # Indent for .deeper() 23 | color = True, # ANSI colors in text 24 | deeper_width = 9, # For .deeper, width per level 25 | plt_seed = 42, # Sampling seed for `plot` 26 | fig_close = True, # Close matplotlib Figure 27 | fig_show = False,# Call `plt.show()` for `.plt`, `.chans` and `.rgb` 28 | 29 | ): 30 | super().__init__(**{k:v for k,v in locals().items() if k not in ["self", "__class__"]}) 31 | 32 | _defaults = Config() 33 | _config = copy(_defaults) 34 | 35 | # %% ../../nbs/03a_utils.config.ipynb 9 36 | # Allows passing None as an argument to reset the 37 | class _Default(): 38 | def __repr__(self): 39 | return "Ignore" 40 | D = _Default() 41 | Default = TypeVar("Default") 42 | 43 | # %% ../../nbs/03a_utils.config.ipynb 10 44 | def set_config( precision :Optional[Union[Default,int]] =D, 45 | threshold_min :Optional[Union[Default,int]] =D, 46 | threshold_max :Optional[Union[Default,int]] =D, 47 | sci_mode :Optional[Union[Default,bool]] =D, 48 | show_mem_above :Optional[Union[Default,int]] =D, 49 | indent :Optional[Union[Default,bool]] =D, 50 | color :Optional[Union[Default,bool]] =D, 51 | deeper_width :Optional[Union[Default,int]] =D, 52 | plt_seed :Optional[Union[Default,int]] =D, 53 | fig_close :Optional[Union[Default,bool]] =D, 54 | fig_show :Optional[Union[Default,bool]] =D 55 | ) -> None: 56 | 57 | "Set config variables" 58 | args = locals().copy() 59 | for k,v in args.items(): 60 | if v != D: 61 | if v is None: 62 | setattr(_config, k, getattr(_defaults, k)) 63 | else: 64 | setattr(_config, k, v) 65 | 66 | # %% ../../nbs/03a_utils.config.ipynb 11 67 | def get_config(): 68 | "Get a copy of config variables" 69 | return copy(_config) 70 | 71 | # %% ../../nbs/03a_utils.config.ipynb 12 72 | @contextmanager 73 | def config( precision :Optional[Union[Default,int]] =D, 74 | threshold_min 
:Optional[Union[Default,int]] =D, 75 | threshold_max :Optional[Union[Default,int]] =D, 76 | sci_mode :Optional[Union[Default,bool]] =D, 77 | show_mem_above :Optional[Union[Default,int]] =D, 78 | indent :Optional[Union[Default,bool]] =D, 79 | color :Optional[Union[Default,bool]] =D, 80 | deeper_width :Optional[Union[Default,int]] =D, 81 | plt_seed :Optional[Union[Default,int]] =D, 82 | fig_close :Optional[Union[Default,bool]] =D, 83 | fig_show :Optional[Union[Default,bool]] =D 84 | ): 85 | 86 | 87 | "Context manager for temporarily setting printing options." 88 | global _config 89 | new_opts = { k:v for k, v in locals().items() if v != D } 90 | old_opts = copy(get_config().__dict__) 91 | 92 | try: 93 | set_config(**new_opts) 94 | yield 95 | finally: 96 | set_config(**old_opts) 97 | -------------------------------------------------------------------------------- /lovely_jax/utils/misc.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03b_utils.misc.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = [] 5 | 6 | # %% ../../nbs/03b_utils.misc.ipynb 4 7 | import re 8 | import codecs 9 | import numpy as np 10 | import jax, jax.numpy as jnp 11 | 12 | from fastcore.test import test_eq 13 | 14 | # %% ../../nbs/03b_utils.misc.ipynb 6 15 | def to_numpy(t): 16 | # if t.dtype == jnp.bfloat16: 17 | return np.array(t, dtype=np.float32) 18 | # return t 19 | 20 | # %% ../../nbs/03b_utils.misc.ipynb 7 21 | def is_cpu(x: jax.Array): 22 | 23 | if hasattr(x, "devices"): # Unified Array (jax >= 0.4) 24 | return list(x.devices())[0] == jax.devices("cpu")[0] 25 | if hasattr(x, "device"): # Old-style DeviceArray 26 | return x.device() == jax.devices("cpu")[0] 27 | 28 | assert hasattr(x, "sharding"), f"Weird input type={type(input)}, expected Array, DeviceArray, or ShardedDeviceArray" 29 | return False 30 | 31 | 32 | # %% ../../nbs/03b_utils.misc.ipynb 8 33 | def test_array_repr(input: str, template:str): 34 | # Depending on the jax version, the array type can be either "Array" or "DeviceArray". 35 | # Depending on the platform, the default device can be "cpu:0" or "gpu:0" (or, I guess, "tpu:0"?) 36 | 37 | # Create a template to match the "Array" and "gpu:0" case; they will be replaced with 38 | # regexes that match either case. 39 | 40 | # Escape the template to make it a valid regex. 41 | 42 | template = re.escape(template) 43 | template = template.replace("Array", "(Array|DeviceArray)") 44 | template = template.replace("\\ gpu:0", "( cpu:0| gpu:0| tpu:0)?") 45 | 46 | # Does input match the regex?
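    # Illustrative example (values borrowed from the notebook tests below, not a new assertion):
    # the template "Array[2] μ=-0.466 σ=1.515 gpu:0 [-1.981, 1.048]" should also accept an input
    # that says "DeviceArray" instead of "Array", or " cpu:0"/" tpu:0" (or no device) instead of " gpu:0".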
47 | if not re.search(template, input): 48 | template = template.replace("\\", "") 49 | raise Exception(f"Template does not match\nTemplate: '{template}'\ninput: '{input}'") 50 | 51 | -------------------------------------------------------------------------------- /nbs/00_repr_str.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 🧾 View as a summary" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "#| default_exp repr_str" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "#| hide\n", 26 | "import nbdev; nbdev.nbdev_export()" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# |hide\n", 36 | "import os\n", 37 | "from nbdev.showdoc import *\n", 38 | "from fastcore.test import test_eq, test" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#| hide\n", 48 | "# For testing, I want to see 8 CPU devices.\n", 49 | "os.environ[\"JAX_PLATFORM_NAME\"] = \"cpu\"\n", 50 | "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "#| hide\n", 60 | "#| export\n", 61 | "\n", 62 | "import warnings\n", 63 | "from typing import Union, Optional as O\n", 64 | "\n", 65 | "import numpy as np\n", 66 | "import jax, jax.numpy as jnp\n", 67 | "\n", 68 | "from lovely_numpy import np_to_str_common, pretty_str, sparse_join, ansi_color, in_debugger, bytes_to_human\n", 69 | "from lovely_numpy import config as lnp_config\n", 70 | "\n", 71 | "from lovely_jax.utils.config import get_config, config\n", 72 | "from lovely_jax.utils.misc import to_numpy, is_cpu, test_array_repr" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# |hide\n", 82 | "key = jax.random.PRNGKey(0)\n", 83 | "randoms = jax.random.normal(key, (100,))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "spicy = (randoms[:12].at[0].mul(10000)\n", 93 | " .at[1].divide(10000)\n", 94 | " .at[3].set(float('inf'))\n", 95 | " .at[4].set(float('-inf'))\n", 96 | " .at[5].set(float('nan'))\n", 97 | " .reshape((2,6)))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "# |exporti\n", 107 | "dtnames = { \"float16\": \"f16\",\n", 108 | " \"float32\": \"\", # Default dtype in jax\n", 109 | " \"float64\": \"f64\", \n", 110 | " \"bfloat16\": \"bf16\",\n", 111 | " \"uint8\": \"u8\",\n", 112 | " \"uint16\": \"u16\",\n", 113 | " \"uint32\": \"u32\",\n", 114 | " \"uint64\": \"u64\",\n", 115 | " \"int8\": \"i8\",\n", 116 | " \"int16\": \"i16\",\n", 117 | " \"int32\": \"i32\",\n", 118 | " \"int64\": \"i64\",\n", 119 | " }\n", 120 | "\n", 121 | "def short_dtype(x: jax.Array) -> str:\n", 122 | " return dtnames.get(x.dtype.name, str(x.dtype))" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | 
"outputs": [], 130 | "source": [ 131 | "# |hide\n", 132 | "test_eq(short_dtype(jnp.array(1., dtype=jnp.bfloat16)), \"bf16\")" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# | exporti\n", 142 | "def plain_repr(x: jax.Array):\n", 143 | " \"Pick the right function to get a plain repr\"\n", 144 | " # assert isinstance(x, np.ndarray), f\"expected np.ndarray but got {type(x)}\" # Could be a sub-class.\n", 145 | " return x._plain_repr() if hasattr(x, \"_plain_repr\") else repr(x)\n", 146 | "\n", 147 | "# def plain_str(x: torch.Tensor):\n", 148 | "# \"Pick the right function to get a plain str.\"\n", 149 | "# # assert isinstance(x, np.ndarray), f\"expected np.ndarray but got {type(x)}\"\n", 150 | "# return x._plain_str() if hasattr(type(x), \"_plain_str\") else str(x)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# | exporti\n", 160 | "def is_nasty(x: jax.Array):\n", 161 | " \"\"\"Return true of any `x` values are inf or nan\"\"\"\n", 162 | " \n", 163 | " if x.size == 0: return False # min/max don't like zero-lenght arrays\n", 164 | " \n", 165 | " x_min = x.min()\n", 166 | " x_max = x.max()\n", 167 | " \n", 168 | " return jnp.isnan(x_min) or jnp.isinf(x_min) or jnp.isinf(x_max)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "#| hide\n", 178 | "\n", 179 | "test_eq(is_nasty(jnp.array([1, 2, float(\"nan\")])), True)\n", 180 | "test_eq(is_nasty(jnp.array([1, 2, float(\"inf\")])), True)\n", 181 | "test_eq(is_nasty(jnp.array([1, 2, 3])), False)\n", 182 | "test_eq(is_nasty(jnp.array([])), False)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# |export\n", 192 | "def jax_to_str_common(x: jax.Array, # Input\n", 193 | " color=True, # ANSI color highlighting\n", 194 | " ddof=0): # For \"std\" unbiasing\n", 195 | "\n", 196 | " if x.size == 0:\n", 197 | " return ansi_color(\"empty\", \"grey\", color)\n", 198 | "\n", 199 | " zeros = ansi_color(\"all_zeros\", \"grey\", color) if jnp.equal(x, 0.).all() and x.size > 1 else None\n", 200 | " # pinf = ansi_color(\"+Inf!\", \"red\", color) if jnp.isposinf(x).any() else None\n", 201 | " # ninf = ansi_color(\"-Inf!\", \"red\", color) if jnp.isneginf(x).any() else None\n", 202 | " # nan = ansi_color(\"NaN!\", \"red\", color) if jnp.isnan(x).any() else None\n", 203 | "\n", 204 | " # attention = sparse_join([zeros,pinf,ninf,nan])\n", 205 | "\n", 206 | " summary = None\n", 207 | " if not zeros and x.ndim > 0:\n", 208 | " minmax = f\"x∈[{pretty_str(x.min())}, {pretty_str(x.max())}]\" if x.size > 2 else None\n", 209 | " meanstd = f\"μ={pretty_str(x.mean())} σ={pretty_str(x.std(ddof=ddof))}\" if x.size >= 2 else None\n", 210 | " summary = sparse_join([minmax, meanstd])\n", 211 | "\n", 212 | "\n", 213 | " return sparse_join([ summary, zeros])" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "# |exporti\n", 223 | "\n", 224 | "def to_str(x: jax.Array, # Input\n", 225 | " plain: bool=False,\n", 226 | " verbose: bool=False,\n", 227 | " depth=0,\n", 228 | " lvl=0,\n", 229 | " color=None) -> str:\n", 230 | "\n", 231 | " if plain:\n", 232 | " return plain_repr(x)\n", 233 
| "\n", 234 | " conf = get_config()\n", 235 | "\n", 236 | " tname = type(x).__name__.split(\".\")[-1]\n", 237 | " if tname in (\"ArrayImpl\"): tname = \"Array\"\n", 238 | " shape = str(list(x.shape)) if x.ndim else None\n", 239 | " type_str = sparse_join([tname, shape], sep=\"\")\n", 240 | "\n", 241 | " if hasattr(x, \"devices\"): # Unified Array (jax >= 0.4)\n", 242 | " int_dev_ids = sorted([d.id for d in x.devices()])\n", 243 | " ids = \",\".join(map(str, int_dev_ids))\n", 244 | " dev = f\"{list(x.devices())[0].platform}:{ids}\"\n", 245 | " elif hasattr(x, \"device\"): # Old-style DeviceArray\n", 246 | " dev = f\"{x.device().platform}:{x.device().id}\"\n", 247 | " elif hasattr(x, \"sharding\"):\n", 248 | " int_dev_ids = sorted([d.id for d in x.sharding.devices])\n", 249 | " ids = \",\".join(map(str, int_dev_ids))\n", 250 | " dev = f\"{x.sharding.devices[0].platform}:{ids}\"\n", 251 | " else:\n", 252 | " assert 0, f\"Weird input type={type(input)}, expecrted Array, DeviceArray, or ShardedDeviceArray\"\n", 253 | "\n", 254 | " dtype = short_dtype(x)\n", 255 | " # grad_fn = t.grad_fn.name() if t.grad_fn else None\n", 256 | " # PyTorch does not want you to know, but all `grad_fn``\n", 257 | " # tensors actuall have `requires_grad=True`` too.\n", 258 | " # grad = \"grad\" if t.requires_grad else None \n", 259 | " grad = grad_fn = None\n", 260 | "\n", 261 | " # For complex tensors, just show the shape / size part for now.\n", 262 | " if not jnp.iscomplexobj(x):\n", 263 | " if color is None: color=conf.color\n", 264 | " if in_debugger(): color = False\n", 265 | " # `lovely-numpy` is used to calculate stats when doing so on GPU would require\n", 266 | " # memory allocation (not float tensors, tensors with bad numbers), or if the\n", 267 | " # data is on CPU (because numpy is faster).\n", 268 | " #\n", 269 | " # Temporarily set the numpy config to match our config for consistency.\n", 270 | " with lnp_config(precision=conf.precision,\n", 271 | " threshold_min=conf.threshold_min,\n", 272 | " threshold_max=conf.threshold_max,\n", 273 | " sci_mode=conf.sci_mode):\n", 274 | "\n", 275 | " if is_cpu(x) or is_nasty(x):\n", 276 | " common = np_to_str_common(np.array(x), color=color)\n", 277 | " else:\n", 278 | " common = jax_to_str_common(x, color=color)\n", 279 | "\n", 280 | " numel = None\n", 281 | " if x.shape and max(x.shape) != x.size:\n", 282 | " numel = f\"n={x.size}\"\n", 283 | " if get_config().show_mem_above <= x.nbytes:\n", 284 | " numel = sparse_join([numel, f\"({bytes_to_human(x.nbytes)})\"])\n", 285 | " elif get_config().show_mem_above <= x.nbytes:\n", 286 | " numel = bytes_to_human(x.nbytes)\n", 287 | "\n", 288 | " vals = pretty_str(x) if 0 < x.size <= 10 else None\n", 289 | " res = sparse_join([type_str, dtype, numel, common, grad, grad_fn, dev, vals])\n", 290 | " else:\n", 291 | " res = plain_repr(x)\n", 292 | "\n", 293 | " if verbose:\n", 294 | " res += \"\\n\" + plain_repr(x)\n", 295 | "\n", 296 | " if depth and x.ndim > 1:\n", 297 | " with config(show_mem_above=jnp.inf):\n", 298 | " deep_width = min((x.shape[0]), conf.deeper_width) # Print at most this many lines\n", 299 | " deep_lines = [ \" \"*conf.indent*(lvl+1) + to_str(x[i,:], depth=depth-1, lvl=lvl+1)\n", 300 | " for i in range(deep_width)] \n", 301 | "\n", 302 | " # If we were limited by width, print ...\n", 303 | " if deep_width < x.shape[0]: deep_lines.append(\" \"*conf.indent*(lvl+1) + \"...\")\n", 304 | "\n", 305 | " res += \"\\n\" + \"\\n\".join(deep_lines)\n", 306 | "\n", 307 | " return res" 308 | ] 309 | }, 310 | { 311 | 
"cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "# |exporti\n", 317 | "def history_warning():\n", 318 | " \"Issue a warning (once) ifw e are running in IPYthon with output cache enabled\"\n", 319 | "\n", 320 | " if \"get_ipython\" in globals() and get_ipython().cache_size > 0:\n", 321 | " warnings.warn(\"IPYthon has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html\")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "name": "stderr", 331 | "output_type": "stream", 332 | "text": [ 333 | "/tmp/ipykernel_285620/3648473780.py:6: UserWarning: IPYthon has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html\n", 334 | " warnings.warn(\"IPYthon has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html\")\n" 335 | ] 336 | } 337 | ], 338 | "source": [ 339 | "# |hide\n", 340 | "get_ipython().cache_size=1000\n", 341 | "history_warning()" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "# |hide\n", 351 | "get_ipython().cache_size=0" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "#| exporti\n", 361 | "\n", 362 | "class StrProxy():\n", 363 | " def __init__(self, x: jax.Array, plain=False, verbose=False, depth=0, lvl=0, color=None):\n", 364 | " self.x = x\n", 365 | " self.plain = plain\n", 366 | " self.verbose = verbose\n", 367 | " self.depth=depth\n", 368 | " self.lvl=lvl\n", 369 | " self.color=color\n", 370 | " history_warning()\n", 371 | " \n", 372 | " def __repr__(self):\n", 373 | " return to_str(self.x, plain=self.plain, verbose=self.verbose,\n", 374 | " depth=self.depth, lvl=self.lvl, color=self.color)\n", 375 | "\n", 376 | " # This is used for .deeper attribute and .deeper(depth=...).\n", 377 | " # The second onthe results in a __call__.\n", 378 | " def __call__(self, depth=1):\n", 379 | " return StrProxy(self.x, depth=depth)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "# |export\n", 389 | "def lovely(x: jax.Array, # Tensor of interest\n", 390 | " verbose=False, # Whether to show the full tensor\n", 391 | " plain=False, # Just print if exactly as before\n", 392 | " depth=0, # Show stats in depth\n", 393 | " color=None): # Force color (True/False) or auto.\n", 394 | " return StrProxy(x, verbose=verbose, plain=plain, depth=depth, color=color)" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "### Examples" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "Array cpu:0 -1.981\n", 414 | "Array[2] μ=-0.466 σ=1.515 cpu:0 [-1.981, 1.048]\n", 415 | "Array[2, 3] n=6 x∈[-1.981, 1.048] μ=-0.017 σ=1.113 cpu:0 [[-1.981, 1.048, 0.890], [0.035, -0.947, 0.851]]\n", 416 | "Array[11] x∈[-1.981, 1.048] μ=-0.191 σ=0.899 cpu:0\n" 417 | ] 418 | } 419 | ], 420 | "source": [ 421 | "print(lovely(randoms[0]))\n", 422 | "print(lovely(randoms[:2]))\n", 423 | "print(lovely(randoms[:6].reshape((2, 3)))) # More than 2 elements -> show statistics\n", 424 | 
"print(lovely(randoms[:11])) # More than 10 -> suppress data output\n" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "# |hide\n", 434 | "test_array_repr(str(lovely(randoms[0])), \"Array gpu:0 -1.981\")\n", 435 | "test_array_repr(str(lovely(randoms[:2])), \"Array[2] μ=-0.466 σ=1.515 gpu:0 [-1.981, 1.048]\")\n", 436 | "test_array_repr(str(lovely(randoms[:6].reshape(2, 3))), \"Array[2, 3] n=6 x∈[-1.981, 1.048] μ=-0.017 σ=1.113 gpu:0 [[-1.981, 1.048, 0.890], [0.035, -0.947, 0.851]]\")\n", 437 | "test_array_repr(str(lovely(randoms[:11])), \"Array[11] x∈[-1.981, 1.048] μ=-0.191 σ=0.899 gpu:0\")" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "Array f16 cpu:0 1.000\n", 450 | "Array f16 cpu:0 2.000\n" 451 | ] 452 | } 453 | ], 454 | "source": [ 455 | "grad = jnp.array(1., dtype=jnp.float16)\n", 456 | "print(lovely(grad)); print(lovely(grad+1))" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "# |hide\n", 466 | "# test_eq(str(lovely(grad)), \"tensor f64 grad 1.000\")\n", 467 | "# test_eq(str(lovely(grad+1)), \"tensor f64 grad AddBackward0 2.000\")" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "# if torch.cuda.is_available():\n", 477 | "# print(lovely(torch.tensor(1., device=torch.device(\"cuda:0\"))))\n", 478 | "# test_eq(str(lovely(torch.tensor(1., device=torch.device(\"cuda:0\")))), \"tensor cuda:0 1.000\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Do we have __any__ floating point nasties? Is the tensor __all__ zeros?" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "data": { 495 | "text/plain": [ 496 | "Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 \u001b[31m+Inf!\u001b[0m \u001b[31m-Inf!\u001b[0m \u001b[31mNaN!\u001b[0m cpu:0" 497 | ] 498 | }, 499 | "execution_count": null, 500 | "metadata": {}, 501 | "output_type": "execute_result" 502 | } 503 | ], 504 | "source": [ 505 | "# Statistics and range are calculated on good values only, if there are at lest 3 of them.\n", 506 | "lovely(spicy)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "# |hide\n", 516 | "test_array_repr(str(lovely(spicy)),\n", 517 | " 'Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 \\x1b[31m+Inf!\\x1b[0m \\x1b[31m-Inf!\\x1b[0m \\x1b[31mNaN!\\x1b[0m gpu:0')" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [ 525 | { 526 | "data": { 527 | "text/plain": [ 528 | "Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 +Inf! -Inf! NaN! 
cpu:0" 529 | ] 530 | }, 531 | "execution_count": null, 532 | "metadata": {}, 533 | "output_type": "execute_result" 534 | } 535 | ], 536 | "source": [ 537 | "lovely(spicy, color=False)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "metadata": {}, 544 | "outputs": [ 545 | { 546 | "data": { 547 | "text/plain": [ 548 | "'Array[11] \\x1b[31mNaN!\\x1b[0m cpu:0'" 549 | ] 550 | }, 551 | "execution_count": null, 552 | "metadata": {}, 553 | "output_type": "execute_result" 554 | } 555 | ], 556 | "source": [ 557 | "str(lovely(jnp.array([float(\"nan\")]*11)))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "# |hide \n", 567 | "test_array_repr(str(lovely(jnp.array([float(\"nan\")]*11))),\n", 568 | " 'Array[11] \\x1b[31mNaN!\\x1b[0m gpu:0')" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [ 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "Array[12] \u001b[38;2;127;127;127mall_zeros\u001b[0m cpu:0" 580 | ] 581 | }, 582 | "execution_count": null, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "lovely(jnp.zeros(12))" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "# |hide\n", 598 | "test_array_repr(str(lovely(jnp.zeros(12))),\n", 599 | " 'Array[12] \\x1b[38;2;127;127;127mall_zeros\\x1b[0m gpu:0')" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": {}, 606 | "outputs": [ 607 | { 608 | "data": { 609 | "text/plain": [ 610 | "Array[0, 0, 0] f16 \u001b[38;2;127;127;127mempty\u001b[0m cpu:0" 611 | ] 612 | }, 613 | "execution_count": null, 614 | "metadata": {}, 615 | "output_type": "execute_result" 616 | } 617 | ], 618 | "source": [ 619 | "lovely(jnp.array([], dtype=jnp.float16).reshape((0,0,0)))" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": null, 625 | "metadata": {}, 626 | "outputs": [], 627 | "source": [ 628 | "# |hide\n", 629 | "test_array_repr(str(lovely(jnp.array([], dtype=jnp.float16).reshape((0,0,0)))),\n", 630 | " 'Array[0, 0, 0] f16 \\x1b[38;2;127;127;127mempty\\x1b[0m gpu:0')" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 cpu:0 [1, 2, 3]" 642 | ] 643 | }, 644 | "execution_count": null, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | } 648 | ], 649 | "source": [ 650 | "lovely(jnp.array([1,2,3], dtype=jnp.int32))" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "# |hide\n", 660 | "test_array_repr(str(lovely(jnp.array([1,2,3], dtype=jnp.int32))),\n", 661 | " 'Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 gpu:0 [1, 2, 3]')" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "execution_count": null, 667 | "metadata": {}, 668 | "outputs": [ 669 | { 670 | "data": { 671 | "text/plain": [ 672 | "Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 \u001b[31m+Inf!\u001b[0m \u001b[31m-Inf!\u001b[0m \u001b[31mNaN!\u001b[0m cpu:0\n", 673 | "Array([[-1.98e+04, 1.05e-04, 8.90e-01, inf, -inf, nan],\n", 674 | " [ 3.12e-02, -3.90e-01, 1.32e-02, -4.21e-01, -1.23e+00, 
-1.25e+00]], dtype=float32)" 675 | ] 676 | }, 677 | "execution_count": null, 678 | "metadata": {}, 679 | "output_type": "execute_result" 680 | } 681 | ], 682 | "source": [ 683 | "jnp.set_printoptions(linewidth=120, precision=2)\n", 684 | "lovely(spicy, verbose=True)" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": null, 690 | "metadata": {}, 691 | "outputs": [ 692 | { 693 | "data": { 694 | "text/plain": [ 695 | "Array([[-1.98e+04, 1.05e-04, 8.90e-01, inf, -inf, nan],\n", 696 | " [ 3.12e-02, -3.90e-01, 1.32e-02, -4.21e-01, -1.23e+00, -1.25e+00]], dtype=float32)" 697 | ] 698 | }, 699 | "execution_count": null, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | } 703 | ], 704 | "source": [ 705 | "lovely(spicy, plain=True)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "metadata": {}, 712 | "outputs": [ 713 | { 714 | "data": { 715 | "text/plain": [ 716 | "Array[3, 196, 196] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 \u001b[31mNaN!\u001b[0m cpu:0\n", 717 | " Array[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 cpu:0\n", 718 | " Array[196] x∈[-1.912, 2.249] μ=-0.673 σ=0.521 cpu:0\n", 719 | " Array[196] x∈[-1.861, 2.163] μ=-0.738 σ=0.417 cpu:0\n", 720 | " Array[196] x∈[-1.758, 2.198] μ=-0.806 σ=0.396 cpu:0\n", 721 | " Array[196] x∈[-1.656, 2.249] μ=-0.849 σ=0.368 cpu:0\n", 722 | " Array[196] x∈[-1.673, 2.198] μ=-0.857 σ=0.356 cpu:0\n", 723 | " Array[196] x∈[-1.656, 2.146] μ=-0.848 σ=0.371 cpu:0\n", 724 | " Array[196] x∈[-1.433, 2.215] μ=-0.784 σ=0.396 cpu:0\n", 725 | " Array[196] x∈[-1.279, 2.249] μ=-0.695 σ=0.485 cpu:0\n", 726 | " Array[196] x∈[-1.364, 2.249] μ=-0.637 σ=0.538 cpu:0\n", 727 | " ...\n", 728 | " Array[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 \u001b[31mNaN!\u001b[0m cpu:0\n", 729 | " Array[196] x∈[-1.861, 2.411] μ=-0.529 σ=0.555 cpu:0\n", 730 | " Array[196] x∈[-1.826, 2.359] μ=-0.562 σ=0.472 cpu:0\n", 731 | " Array[196] x∈[-1.756, 2.376] μ=-0.622 σ=0.458 \u001b[31mNaN!\u001b[0m cpu:0\n", 732 | " Array[196] x∈[-1.633, 2.429] μ=-0.664 σ=0.429 cpu:0\n", 733 | " Array[196] x∈[-1.651, 2.376] μ=-0.669 σ=0.398 cpu:0\n", 734 | " Array[196] x∈[-1.633, 2.376] μ=-0.701 σ=0.390 cpu:0\n", 735 | " Array[196] x∈[-1.563, 2.429] μ=-0.670 σ=0.379 cpu:0\n", 736 | " Array[196] x∈[-1.475, 2.429] μ=-0.616 σ=0.385 cpu:0\n", 737 | " Array[196] x∈[-1.511, 2.429] μ=-0.593 σ=0.398 cpu:0\n", 738 | " ...\n", 739 | " Array[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 cpu:0\n", 740 | " Array[196] x∈[-1.717, 2.396] μ=-0.982 σ=0.349 cpu:0\n", 741 | " Array[196] x∈[-1.752, 2.326] μ=-1.034 σ=0.313 cpu:0\n", 742 | " Array[196] x∈[-1.648, 2.379] μ=-1.086 σ=0.313 cpu:0\n", 743 | " Array[196] x∈[-1.630, 2.466] μ=-1.121 σ=0.304 cpu:0\n", 744 | " Array[196] x∈[-1.717, 2.448] μ=-1.120 σ=0.301 cpu:0\n", 745 | " Array[196] x∈[-1.717, 2.431] μ=-1.166 σ=0.313 cpu:0\n", 746 | " Array[196] x∈[-1.560, 2.448] μ=-1.124 σ=0.325 cpu:0\n", 747 | " Array[196] x∈[-1.421, 2.431] μ=-1.064 σ=0.382 cpu:0\n", 748 | " Array[196] x∈[-1.526, 2.396] μ=-1.047 σ=0.416 cpu:0\n", 749 | " ..." 
750 | ] 751 | }, 752 | "execution_count": null, 753 | "metadata": {}, 754 | "output_type": "execute_result" 755 | } 756 | ], 757 | "source": [ 758 | "image = jnp.load(\"mysteryman.npy\")\n", 759 | "image = image.at[1,2,3].set(float('nan'))\n", 760 | "\n", 761 | "lovely(image, depth=2) # Limited by set_config(deeper_lines=N)" 762 | ] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "execution_count": null, 767 | "metadata": {}, 768 | "outputs": [], 769 | "source": [ 770 | "# |hide\n", 771 | "#### CUDA memory is not leaked" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": null, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "# |hide\n", 781 | "# |eval: false\n", 782 | "# def memstats():\n", 783 | "# allocated = int(torch.cuda.memory_allocated() // (1024*1024))\n", 784 | "# max_allocated = int(torch.cuda.max_memory_allocated() // (1024*1024))\n", 785 | "# return f\"Allocated: {allocated} MB, Max: {max_allocated} Mb\"\n", 786 | "\n", 787 | "# if torch.cuda.is_available():\n", 788 | "# cudamem = torch.cuda.memory_allocated()\n", 789 | "# print(f\"before allocation: {memstats()}\")\n", 790 | "# numbers = torch.randn((3, 1024, 1024), device=\"cuda\") # 12Mb image\n", 791 | "# torch.cuda.synchronize()\n", 792 | "\n", 793 | "# print(f\"after allocation: {memstats()}\")\n", 794 | "# # Note, the return value of lovely() is not a string, but a\n", 795 | "# # StrProxy that holds reference to 'numbers'. You have to del\n", 796 | "# # the references to it, but once it's gone, the reference to\n", 797 | "# # the tensor is gone too.\n", 798 | "# display(lovely(numbers) )\n", 799 | "# print(f\"after repr: {memstats()}\")\n", 800 | " \n", 801 | "# del numbers\n", 802 | "# # torch.cuda.memory.empty_cache()\n", 803 | "\n", 804 | "# print(f\"after cleanup: {memstats()}\")\n", 805 | "# test_eq(cudamem >= torch.cuda.memory_allocated(), True)" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": null, 811 | "metadata": {}, 812 | "outputs": [ 813 | { 814 | "data": { 815 | "text/plain": [ 816 | "Array([-0.4 -0.4j , 1.13+0.08j, -0.03+1.j , -0.46+0.61j, -1.15-0.99j], dtype=complex64)" 817 | ] 818 | }, 819 | "execution_count": null, 820 | "metadata": {}, 821 | "output_type": "execute_result" 822 | } 823 | ], 824 | "source": [ 825 | "# We don't really supposed complex numbers yet\n", 826 | "c = jnp.array([-0.4011-0.4035j, 1.1300+0.0788j, -0.0277+0.9978j, -0.4636+0.6064j, -1.1505-0.9865j])\n", 827 | "lovely(c)" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": null, 833 | "metadata": {}, 834 | "outputs": [ 835 | { 836 | "data": { 837 | "text/html": [ 838 | "
          CPU 0          \n",
839 |        "                         \n",
840 |        "          CPU 1          \n",
841 |        "                         \n",
842 |        "          CPU 2          \n",
843 |        "                         \n",
844 |        "          CPU 3          \n",
845 |        "                         \n",
846 |        "          CPU 4          \n",
847 |        "                         \n",
848 |        "          CPU 5          \n",
849 |        "                         \n",
850 |        "          CPU 6          \n",
851 |        "                         \n",
852 |        "          CPU 7          \n",
853 |        "                         \n",
854 |        "
\n" 855 | ], 856 | "text/plain": [ 857 | "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", 858 | "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", 859 | "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", 860 | "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", 861 | "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", 862 | "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", 863 | "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", 864 | "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", 865 | "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", 866 | "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", 867 | "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", 868 | "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", 869 | "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", 870 | "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", 871 | "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", 872 | "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" 873 | ] 874 | }, 875 | "metadata": {}, 876 | "output_type": "display_data" 877 | }, 878 | { 879 | "name": "stdout", 880 | "output_type": "stream", 881 | "text": [ 882 | "Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0\n", 883 | "Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0,1,2,3,4,5,6,7\n" 884 | ] 885 | } 886 | ], 887 | "source": [ 888 | "#| eval: false\n", 889 | "assert jax.__version_info__[0] == 0\n", 890 | "if jax.__version_info__[1] >= 4:\n", 891 | " from jax.sharding import PositionalSharding\n", 892 | " from jax.experimental import mesh_utils\n", 893 | " sharding = PositionalSharding(mesh_utils.create_device_mesh((8,1)))\n", 894 | " x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", 895 | " y = jax.device_put(x, sharding)\n", 896 | "\n", 897 | " jax.debug.visualize_array_sharding(y)\n", 898 | "else:\n", 899 | " # Note: Looks like ShardedDeviceArray needs an explicit device axis?\n", 900 | " x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 8192))\n", 901 | " y = jax.device_put_sharded([x for x in x], jax.devices())\n", 902 | "\n", 903 | "print(lovely(x))\n", 904 | "print(lovely(y))" 905 | ] 906 | } 907 | ], 908 | "metadata": { 909 | "kernelspec": { 910 | "display_name": "python3", 911 | "language": "python", 912 | "name": "python3" 913 | } 914 | }, 915 | "nbformat": 4, 916 | "nbformat_minor": 4 917 | } 918 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | css: styles.css 8 | # toc: true 9 | 
page-layout: full 10 | 11 | website: 12 | twitter-card: true 13 | open-graph: true 14 | repo-actions: [issue] 15 | navbar: 16 | background: primary 17 | search: true 18 | right: 19 | - icon: github 20 | href: "http://github.com/xl0/lovely-jax/" 21 | sidebar: 22 | style: floating 23 | 24 | metadata-files: [nbdev.yml, sidebar.yml] -------------------------------------------------------------------------------- /nbs/mysteryman.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xl0/lovely-jax/0bf049db981d5737893bc992a4799d385b1d229a/nbs/mysteryman.npy -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "lovely-jax" 6 | site-url: "https://xl0.github.io/lovely-jax" 7 | description: "💘 Lovely JAX" 8 | repo-branch: master 9 | repo-url: "https://github.com/xl0/lovely-jax" 10 | -------------------------------------------------------------------------------- /nbs/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | sidebar: 3 | contents: 4 | - index.ipynb 5 | - section: Data representations 6 | contents: 7 | - 00_repr_str.ipynb 8 | - 01_repr_rgb.ipynb 9 | - 02_repr_plt.ipynb 10 | - 05_repr_chans.ipynb 11 | # - section: Image utils 12 | # contents: 13 | # - 03a_utils.colormap.ipynb 14 | # - 03b_utils.pad.ipynb 15 | # - 03c_utils.tile2d.ipynb 16 | - section: Misc 17 | contents: 18 | - 03a_utils.config.ipynb 19 | - 10_patch.ipynb 20 | - matplotlib.ipynb -------------------------------------------------------------------------------- /nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | /* .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre, .ansi-escaped-output > pre { */ 14 | /* margin-left: 0.8rem; */ 15 | /* margin-top: 0; */ 16 | /* background: none; */ 17 | /* border-left: 2px solid lightsalmon; */ 18 | /* border-top-left-radius: 0; */ 19 | /* border-top-right-radius: 0; */ 20 | /* } */ 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | # All sections below are required unless otherwise specified. 3 | # See https://github.com/fastai/nbdev/blob/master/settings.ini for examples. 
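# For example, `lib_name = %(repo)s` below resolves to "lovely-jax" through ConfigParser
# interpolation; setup.py reads this file with ConfigParser(delimiters=['=']) and passes the
# resolved values to setuptools.setup().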
4 | 5 | ### Python library ### 6 | repo = lovely-jax 7 | lib_name = %(repo)s 8 | version = 0.1.3 9 | min_python = 3.7 10 | license = mit 11 | 12 | ### nbdev ### 13 | doc_path = _docs 14 | lib_path = lovely_jax 15 | nbs_path = nbs 16 | recursive = True 17 | tst_flags = notest 18 | put_version_in_init = True 19 | 20 | ### Docs ### 21 | branch = master 22 | custom_sidebar = True 23 | doc_host = https://%(user)s.github.io 24 | doc_baseurl = /%(repo)s 25 | git_url = https://github.com/%(user)s/%(repo)s 26 | title = %(lib_name)s 27 | 28 | ### PyPI ### 29 | audience = Developers 30 | author = Alexey Zaytsev 31 | author_email = alexey.zaytsev@gmail.com 32 | copyright = 2022 onwards, %(author)s 33 | description = 💘 Lovely JAX 34 | keywords = nbdev jupyter jax 35 | language = English 36 | status = 3 37 | user = xl0 38 | 39 | ### Optional ### 40 | # Note: [cpu] for CI. 41 | # If you already have the cuda version installed, this will not override it. 42 | requirements = jax[cpu] lovely-numpy>=0.2.9 43 | # pip_requirements = 44 | # dev_requirements = 45 | # console_scripts = -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = cfg.get('requirements','').split() 28 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() 29 | min_python = cfg['min_python'] 30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 31 | dev_requirements = (cfg.get('dev_requirements') or '').split() 32 | 33 | setuptools.setup( 34 | name = cfg['lib_name'], 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require={ 'dev': dev_requirements }, 46 | dependency_links = 
cfg.get('dep_links','').split(), 47 | python_requires = '>=' + cfg['min_python'], 48 | long_description = open('README.md').read(), 49 | long_description_content_type = 'text/markdown', 50 | zip_safe = False, 51 | entry_points = { 52 | 'console_scripts': cfg.get('console_scripts','').split(), 53 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 54 | }, 55 | **setup_cfg) 56 | 57 | 58 | --------------------------------------------------------------------------------