├── tests ├── __init__.py └── test_ezmup.py ├── requirements.txt ├── contents ├── meme.png ├── image.png ├── coord-check.png ├── hyperparams.png └── example.csv ├── docs ├── source │ ├── modules.rst │ ├── ezmup.rst │ ├── index.rst │ └── conf.py ├── modules.rst ├── ezmup.rst ├── Makefile └── make.bat ├── ezmup ├── __init__.py └── ezmup.py ├── setup.py ├── .readthedocs.yaml ├── .github └── workflows │ └── tag-and-release.yml ├── example.py ├── .gitignore ├── pyproject.toml ├── requirements-dev.txt └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | pandas 4 | torch 5 | -------------------------------------------------------------------------------- /contents/meme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/ezmup/HEAD/contents/meme.png -------------------------------------------------------------------------------- /contents/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/ezmup/HEAD/contents/image.png -------------------------------------------------------------------------------- /contents/coord-check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/ezmup/HEAD/contents/coord-check.png -------------------------------------------------------------------------------- /contents/hyperparams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/ezmup/HEAD/contents/hyperparams.png -------------------------------------------------------------------------------- 
/docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | ezmup 2 | ===== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | ezmup 8 | -------------------------------------------------------------------------------- /tests/test_ezmup.py: -------------------------------------------------------------------------------- 1 | from ezmup import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ == "0.0.1" 6 | -------------------------------------------------------------------------------- /ezmup/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level pacakge for the ezmup module.""" 2 | 3 | __author__ = """Simo Ryu""" 4 | __email__ = "cloneofsimo@gmail.com" 5 | __version__ = "0.0.1" 6 | 7 | from .ezmup import * 8 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | ezmup 2 | ========== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | ezmup 8 | 9 | .. automodule:: ezmup 10 | :members: 11 | :undoc-members: 12 | :show-inheritance: 13 | :inherited-members: 14 | -------------------------------------------------------------------------------- /docs/ezmup.rst: -------------------------------------------------------------------------------- 1 | ezmup package 2 | ================== 3 | 4 | Submodules 5 | ------------ 6 | 7 | ezmup module 8 | ---------------------- 9 | 10 | .. automodule:: ezmup 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: ezmup 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/ezmup.rst: -------------------------------------------------------------------------------- 1 | ezmup package 2 | ============= 3 | 4 | Submodules 5 | ---------- 6 | 7 | ezmup.ezmup module 8 | ------------------ 9 | 10 | .. automodule:: ezmup.ezmup 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: ezmup 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="ezmup", 8 | py_modules=["ezmup"], 9 | version="0.0.1", 10 | description="Really Easy MuP", 11 | author="Simo Ryu", 12 | packages=find_packages(), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")), 17 | ) 18 | ], 19 | include_package_data=True, 20 | ) 21 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. ezmup documentation master file, created by 2 | sphinx-quickstart on Tue Oct 3 22:44:23 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to ezmup's documentation! 7 | ==================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | modules 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # We recommend specifying your dependencies to enable reproducible builds: 19 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: requirements-dev.txt 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "ezmup" 10 | copyright = "2024, Simo Ryu" 11 | author = "Simo Ryu" 12 | release = "0.0.1" 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | "sphinx.ext.autodoc", 19 | "sphinx.ext.viewcode", 20 | "sphinx.ext.napoleon", 21 | "myst_parser", 22 | "sphinx.ext.intersphinx", 23 | "sphinxcontrib.youtube", 24 | ] 25 | 26 | templates_path = ["_templates"] 27 | exclude_patterns = [] 28 | 29 | 30 | # -- Options for HTML output ------------------------------------------------- 31 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 32 | 33 | html_theme = "pydata_sphinx_theme" 34 | html_static_path = ["_static"] 35 | -------------------------------------------------------------------------------- /.github/workflows/tag-and-release.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/publish-release.yml 2 | name: Publish release 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | check-permission: 10 | name: check permission 11 | runs-on: ubuntu-latest 12 | outputs: 13 | permission: ${{ steps.check.outputs.permission }} 14 | steps: 15 | - id: check 16 | uses: shogo82148/actions-check-permissions@v1 17 | env: 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | 20 | publish: 21 | name: Publish release 22 | runs-on: ubuntu-latest 23 | permissions: 24 | contents: write 25 | pull-requests: read 26 | needs: 27 | - check-permission 28 | steps: 29 | ########################## 30 | # Checkout 
the code base # 31 | ########################## 32 | - name: Checkout code 33 | uses: actions/checkout@v4 34 | ########################## 35 | # Release from tags # 36 | ########################## 37 | - name: Publish Release 38 | id: publish_release 39 | uses: ghalactic/github-release-from-tag@v5 40 | if: github.ref_type == 'tag' 41 | with: 42 | prerelease: "false" 43 | reactions: rocket, +1, eyes 44 | token: ${{ secrets.GITHUB_TOKEN }} 45 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ezmup import Ezmup, get_coord_data, plot_coord_data 6 | 7 | 8 | class AttentionLayer(nn.Module): 9 | def __init__(self, hidden_dim): 10 | super(AttentionLayer, self).__init__() 11 | self.hidden_dim = hidden_dim 12 | # The query, key, and value layers now map from 'hidden_dim' to 'hidden_dim' 13 | self.query = nn.Linear(hidden_dim, hidden_dim) 14 | self.key = nn.Linear(hidden_dim, hidden_dim) 15 | self.value = nn.Linear(hidden_dim, hidden_dim) 16 | 17 | def forward(self, x): 18 | Q = self.query(x) 19 | K = self.key(x) 20 | V = self.value(x) 21 | attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.hidden_dim 22 | attention = F.softmax(attention_scores, dim=-1) 23 | return torch.matmul(attention, V) 24 | 25 | 26 | class MyModel(nn.Module): 27 | def __init__(self, input_dim, output_dim, hidden_dim, num_layers): 28 | super(MyModel, self).__init__() 29 | self.layers = nn.ModuleList( 30 | [AttentionLayer(hidden_dim) for _ in range(num_layers)], 31 | ) 32 | self.fin = nn.Linear(input_dim, hidden_dim) 33 | self.fout = nn.Linear(hidden_dim, output_dim) 34 | 35 | def forward(self, x): 36 | x = self.fin(x) 37 | x = nn.ReLU()(x) 38 | print(x.shape) 39 | for layer in self.layers: 40 | x = layer(x) 41 | x = nn.ReLU()(x) 42 | x = self.fout(x) 43 | return x 44 | 45 | 46 
| model = MyModel(input_dim=41, output_dim=41, hidden_dim=47, num_layers=4) 47 | model.to("cuda:0") 48 | 49 | mup_engine = Ezmup(47, model, init_std=1.0) 50 | mup_engine.change_width_as(64) 51 | 52 | 53 | def loss_fn(batch, model): 54 | x, y = batch 55 | y_pred = model(x) 56 | return F.mse_loss(y_pred, y) 57 | 58 | 59 | mup_engine.forward = loss_fn 60 | 61 | # example run 62 | x = torch.randn(4, 33, 41).to("cuda:0") 63 | y = torch.randn(4, 33, 41).to("cuda:0") 64 | # y = model(x) 65 | 66 | 67 | df = get_coord_data(mup_engine, (x, y), n_seeds=1, n_steps=3) 68 | df.to_csv("contents/example.csv") 69 | 70 | 71 | plot_coord_data( 72 | df, 73 | y="l1", 74 | save_to="contents/coord-check.png", 75 | suptitle=None, 76 | x="width", 77 | hue="module", 78 | ) 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pyenv 132 | .python-version 133 | 134 | # vscode 135 | .vscode/ 136 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ezmup" # Change this 3 | version = "0.0.1" # Change this 4 | description = "Really Easy MuP" # Change this 5 | readme = "README.md" 6 | license = {file="LICENSE"} 7 | authors = [ 8 | {name="Simo Ryu", email="cloneofsimo@gmail.com"} 9 | ] 10 | dependencies = [ 11 | "matplotlib", 12 | "numpy", 13 | "pandas", 14 | "torch", 15 | ] 16 | requires-python = ">=3.10, <3.12" 17 | keywords = [ 18 | "template", 19 | "python", 20 | "project", 21 | "template-project" 22 | ] # Change this 23 | classifiers = [ 24 | "Development Status :: 3 - Alpha", 25 | "Intended Audience :: Developers", 26 | "Intended Audience :: Science/Research", 27 | "License :: OSI Approved :: MIT License", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Environment :: GPU :: NVIDIA CUDA :: 12", 31 | "Topic :: Software Development :: Libraries :: Python Modules" 32 | ] 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "black", 37 | "myst-parser", 38 | "ruff", 39 | "pip-tools", 40 | "pip-autoremove", 41 | "pydata-sphinx-theme", 42 | "pytest", 43 | "pytest-cov", 44 | "pytest-html", 45 | 
"sphinx", 46 | "sphinxcontrib-youtube", 47 | ] 48 | 49 | [build-system] 50 | requires = ["setuptools >= 67.0", "wheel"] 51 | build-backend = "setuptools.build_meta" 52 | 53 | [tool.setuptools] 54 | packages = ["ezmup"] 55 | 56 | [tool.black] 57 | target-version = ['py310', 'py311'] 58 | include = '\.pyi?$' 59 | color = true 60 | extend-exclude = ''' 61 | /( 62 | # The following are specific to Black, you probably don't want those. 63 | \.git 64 | | \.hg 65 | | \.mypy_cache 66 | | \.tox 67 | | \.venv 68 | | _build 69 | | buck-out 70 | | build 71 | | dist 72 | | env 73 | | venv 74 | | tests/data 75 | )/ 76 | ''' 77 | 78 | [tool.ruff] 79 | include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] 80 | exclude = [ 81 | "git", 82 | "__pycache__", 83 | "data/*", 84 | "notebooks/*", 85 | "logs/*", 86 | "**/__pycache__", 87 | ".bzr", 88 | ".direnv", 89 | ".eggs", 90 | ".git", 91 | ".git-rewrite", 92 | ".hg", 93 | ".mypy_cache", 94 | ".nox", 95 | ".pants.d", 96 | ".pytype", 97 | ".ruff_cache", 98 | ".svn", 99 | ".tox", 100 | ".venv", 101 | "__pypackages__", 102 | "_build", 103 | "buck-out", 104 | "build", 105 | "dist", 106 | "node_modules", 107 | "venv", 108 | "docs", 109 | "data", 110 | ] 111 | ignore = [ 112 | "D100", 113 | "B017", 114 | "C408", 115 | "C901", 116 | "E501", 117 | "E741", 118 | "F401", 119 | "F403", 120 | "F811", 121 | "F841", 122 | "FBT002", 123 | "PD901", 124 | ] 125 | select = [ 126 | "ARG", # flake8-unused-arguments 127 | "B", 128 | "B9", 129 | "C", 130 | "COM", # flake8-commas 131 | "D", # pydocstyle 132 | "DTZ", # flake8-datetimez 133 | "E", # pycodestyle 134 | "EM", # flake8-errmsg 135 | "F", # pyflakes 136 | "FBT", # flake8-boolean-trap 137 | "G", # flake8-logging-format 138 | "ISC", # flake8-implicit-str-concat 139 | "N", # pep8-naming 140 | "NPY", # NumPy-specific rules 141 | "PD", # pandas-vet 142 | "PT", # flake8-pytest-style 143 | "PTH", # flake8-use-pathlib 144 | "RUF", # Ruff-specific rules 145 | "S", # flake8-bandit 146 | "SIM", # 
flake8-simplify 147 | "TID", # flake8-tidy-imports 148 | "UP", # pyupgrade 149 | "W", # pycodestyle 150 | ] 151 | # Same as Black. 152 | line-length = 88 153 | # Allow unused variables when underscore-prefixed. 154 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 155 | # Assume Python 3.10. 156 | target-version = "py310" 157 | 158 | [tool.ruff.mccabe] 159 | max-complexity = 18 160 | 161 | [tool.ruff.per-file-ignores] 162 | "**/configs/**.py" = [ 163 | "F401", 164 | "E402", 165 | ] 166 | "**/__init__.py" = [ 167 | "F401", 168 | "F403", 169 | "E402", 170 | ] 171 | "**/tests/config/**.py" = [ 172 | "F401", 173 | "E402", 174 | ] 175 | "**/tests/**.py" = [ 176 | "D100", 177 | "D103", 178 | "D104", 179 | "S101", 180 | ] 181 | "configs/**.py" = [ 182 | "F401", 183 | "E402", 184 | ] 185 | "tests/config/**.py" = [ 186 | "F401", 187 | "E402", 188 | ] 189 | 190 | [tool.ruff.pydocstyle] 191 | convention = "google" 192 | 193 | [tool.ruff.isort] 194 | known-third-party = ["numpy", "scipy", "pandas", "matplotlib", "sklearn", "tensorflow", "tqdm"] 195 | known-first-party = ["ezmup"] 196 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml 6 | # 7 | accessible-pygments==0.0.4 8 | # via pydata-sphinx-theme 9 | alabaster==0.7.13 10 | # via sphinx 11 | babel==2.14.0 12 | # via 13 | # pydata-sphinx-theme 14 | # sphinx 15 | beautifulsoup4==4.12.2 16 | # via pydata-sphinx-theme 17 | black==23.12.1 18 | # via ezmup (pyproject.toml) 19 | build==1.0.3 20 | # via pip-tools 21 | certifi==2023.11.17 22 | # via requests 23 | charset-normalizer==3.3.2 24 | # via requests 25 | click==8.1.7 26 | # via 27 | # black 28 | # pip-tools 29 | contourpy==1.2.0 30 | # via 
matplotlib 31 | coverage[toml]==7.4.0 32 | # via 33 | # coverage 34 | # pytest-cov 35 | cycler==0.12.1 36 | # via matplotlib 37 | docutils==0.20.1 38 | # via 39 | # myst-parser 40 | # pydata-sphinx-theme 41 | # sphinx 42 | filelock==3.13.1 43 | # via 44 | # torch 45 | # triton 46 | fonttools==4.47.0 47 | # via matplotlib 48 | fsspec==2023.12.2 49 | # via torch 50 | idna==3.6 51 | # via requests 52 | imagesize==1.4.1 53 | # via sphinx 54 | iniconfig==2.0.0 55 | # via pytest 56 | jinja2==3.1.2 57 | # via 58 | # myst-parser 59 | # pytest-html 60 | # sphinx 61 | # torch 62 | kiwisolver==1.4.5 63 | # via matplotlib 64 | markdown-it-py==3.0.0 65 | # via 66 | # mdit-py-plugins 67 | # myst-parser 68 | markupsafe==2.1.3 69 | # via jinja2 70 | matplotlib==3.8.2 71 | # via ezmup (pyproject.toml) 72 | mdit-py-plugins==0.4.0 73 | # via myst-parser 74 | mdurl==0.1.2 75 | # via markdown-it-py 76 | mpmath==1.3.0 77 | # via sympy 78 | mypy-extensions==1.0.0 79 | # via black 80 | myst-parser==2.0.0 81 | # via ezmup (pyproject.toml) 82 | networkx==3.2.1 83 | # via torch 84 | numpy==1.26.3 85 | # via 86 | # contourpy 87 | # ezmup (pyproject.toml) 88 | # matplotlib 89 | # pandas 90 | nvidia-cublas-cu12==12.1.3.1 91 | # via 92 | # nvidia-cudnn-cu12 93 | # nvidia-cusolver-cu12 94 | # torch 95 | nvidia-cuda-cupti-cu12==12.1.105 96 | # via torch 97 | nvidia-cuda-nvrtc-cu12==12.1.105 98 | # via torch 99 | nvidia-cuda-runtime-cu12==12.1.105 100 | # via torch 101 | nvidia-cudnn-cu12==8.9.2.26 102 | # via torch 103 | nvidia-cufft-cu12==11.0.2.54 104 | # via torch 105 | nvidia-curand-cu12==10.3.2.106 106 | # via torch 107 | nvidia-cusolver-cu12==11.4.5.107 108 | # via torch 109 | nvidia-cusparse-cu12==12.1.0.106 110 | # via 111 | # nvidia-cusolver-cu12 112 | # torch 113 | nvidia-nccl-cu12==2.18.1 114 | # via torch 115 | nvidia-nvjitlink-cu12==12.3.101 116 | # via 117 | # nvidia-cusolver-cu12 118 | # nvidia-cusparse-cu12 119 | nvidia-nvtx-cu12==12.1.105 120 | # via torch 121 | packaging==23.2 
122 | # via 123 | # black 124 | # build 125 | # matplotlib 126 | # pydata-sphinx-theme 127 | # pytest 128 | # sphinx 129 | pandas==2.1.4 130 | # via ezmup (pyproject.toml) 131 | pathspec==0.12.1 132 | # via black 133 | pillow==10.2.0 134 | # via matplotlib 135 | pip-autoremove==0.10.0 136 | # via ezmup (pyproject.toml) 137 | pip-tools==7.3.0 138 | # via ezmup (pyproject.toml) 139 | platformdirs==4.1.0 140 | # via black 141 | pluggy==1.3.0 142 | # via pytest 143 | pydata-sphinx-theme==0.15.1 144 | # via ezmup (pyproject.toml) 145 | pygments==2.17.2 146 | # via 147 | # accessible-pygments 148 | # pydata-sphinx-theme 149 | # sphinx 150 | pyparsing==3.1.1 151 | # via matplotlib 152 | pyproject-hooks==1.0.0 153 | # via build 154 | pytest==7.4.4 155 | # via 156 | # ezmup (pyproject.toml) 157 | # pytest-cov 158 | # pytest-html 159 | # pytest-metadata 160 | pytest-cov==4.1.0 161 | # via ezmup (pyproject.toml) 162 | pytest-html==4.1.1 163 | # via ezmup (pyproject.toml) 164 | pytest-metadata==3.0.0 165 | # via pytest-html 166 | python-dateutil==2.8.2 167 | # via 168 | # matplotlib 169 | # pandas 170 | pytz==2023.3.post1 171 | # via pandas 172 | pyyaml==6.0.1 173 | # via myst-parser 174 | requests==2.31.0 175 | # via 176 | # sphinx 177 | # sphinxcontrib-youtube 178 | ruff==0.1.11 179 | # via ezmup (pyproject.toml) 180 | six==1.16.0 181 | # via python-dateutil 182 | snowballstemmer==2.2.0 183 | # via sphinx 184 | soupsieve==2.5 185 | # via beautifulsoup4 186 | sphinx==7.2.6 187 | # via 188 | # ezmup (pyproject.toml) 189 | # myst-parser 190 | # pydata-sphinx-theme 191 | # sphinxcontrib-applehelp 192 | # sphinxcontrib-devhelp 193 | # sphinxcontrib-htmlhelp 194 | # sphinxcontrib-qthelp 195 | # sphinxcontrib-serializinghtml 196 | # sphinxcontrib-youtube 197 | sphinxcontrib-applehelp==1.0.7 198 | # via sphinx 199 | sphinxcontrib-devhelp==1.0.5 200 | # via sphinx 201 | sphinxcontrib-htmlhelp==2.0.4 202 | # via sphinx 203 | sphinxcontrib-jsmath==1.0.1 204 | # via sphinx 205 | 
sphinxcontrib-qthelp==1.0.6 206 | # via sphinx 207 | sphinxcontrib-serializinghtml==1.1.9 208 | # via sphinx 209 | sphinxcontrib-youtube==1.4.1 210 | # via ezmup (pyproject.toml) 211 | sympy==1.12 212 | # via torch 213 | torch==2.1.2 214 | # via ezmup (pyproject.toml) 215 | triton==2.1.0 216 | # via torch 217 | typing-extensions==4.9.0 218 | # via 219 | # pydata-sphinx-theme 220 | # torch 221 | tzdata==2023.4 222 | # via pandas 223 | urllib3==2.1.0 224 | # via requests 225 | wheel==0.42.0 226 | # via pip-tools 227 | 228 | # The following packages are considered to be unsafe in a requirements file: 229 | # pip 230 | # setuptools 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal Implementation of muP (Maximal Update Parametrization) that happens to be also Easy 2 | 3 | > This is radical implementation of the muP algorithm (Maximal Update Parametrization) for the paper [Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer](https://arxiv.org/abs/2203.03466) and [A Spectral Condition for Feature Learning](https://arxiv.org/abs/2310.17813), series of research driven by [Greg Yang](https://thegregyang.com/). *This is not an official implementation, which can be found [here](https://github.com/microsoft/mup).* 4 | 5 | ## What is muP? 6 | 7 | 1. You want to train a large NN with pytorch (you have a *width* as a hyperparameter) 8 | 2. You want to find optimal hyperparameters, such as learning rate, adam beta, lr scheduler, init varaince, etc (see list of hparams that are relevent [here](#list-of-hyperparameters). Note that not all hparams are relevant to muP!) 9 | 3. So of course you are going to search for the optimal hparams using the *small model*, and apply it to the *big model*... right? 
10 | 11 | ![Alt text](contents/meme.png) 12 | 13 | But as you all know, with *default pytorch settings*, you can't do this! You often have to **search the hparams for the big model, because hparams you find on the small model is often not applicable to the big model.** 14 | 15 | However, this can be fixed! *muP* is a method of **scaling width** so that you can search the hyperparams on the smaller model (small width), and then apply it to the big model (large width): **you can transfer the hparams from small model to big model!, that saves you a lot of time and money!** 16 | 17 | ![Alt text](contents/image.png) 18 | 19 | > You see? in practice, your optimal Learing rate found with width 128 does not transfer to width 8192. But with muP, you can transfer the hyperparameters from small model to big model! Figure directly from [Tensor Programs V](https://arxiv.org/abs/2203.03466) 20 | 21 | ## How to use this repo? 22 | 23 | This package suggest the following approach: 24 | 25 | First, in your project, change the config so that some weird large enough prime number represents the *varying width*. By *weird*, I mean such number should not be used in your other hyperparameters of model shapes. 47 is such good number. 26 | 27 | ```python 28 | 29 | from ezmup import Ezmup 30 | 31 | model = MyModel(ff_width = 47 * 32, hidden_dim = 47 * 8, head_dim = 47, num_layers = 4, ...) 32 | 33 | ``` 34 | 35 | Then, you can use the ezmup package so that it changes all of the width variable to value you want. 36 | 37 | 38 | ```python 39 | mup_engine = Ezmup(width_basis = 47, model = model) # this will treat 47 as a width variable. i.e., replace 47 occuring in all of the model as a variable length. 40 | 41 | # do hyperparameter-search, such as scaling, with fixed width. here, we take 32. 42 | mup_engine.change_width_as(32) 43 | 44 | optimizer = mup_engine.get_optimizer(lr = test_lr) 45 | 46 | # now after you get the best parameters, you can apply them to your desired width. 
47 | mup_engine.change_width_as(1024) 48 | 49 | # if you want to use the model for other purposes, say, other frameworks just save them as state-dict, safetensors, etc. 50 | torch.save(model.state_dict(), 'model.pt') 51 | 52 | ``` 53 | 54 | 55 | In that way, we can 56 | 57 | 1. Use muP in our own projects without too much hassle (just set the varying-width you desire to be prime number like 47! Explained later) 58 | 2. Check the correctness of muP via Coord-Checking 59 | 60 | Oh well, this is exactly how you use this package! The code is very minimal, so do have a look as well. 61 | 62 | ## Installation 63 | 64 | ```bash 65 | python3 -m pip install ezmup+https://github.com/cloneofsimo/ezmup.git 66 | ``` 67 | 68 | ## Other methods: 69 | 70 | The code does two things: 71 | 72 | 1. Without needing to replace the code-layers in your implementation, it changes the width variable to the value you want, by actually replacing the neural network parameters into desired ones. While doing that, it will use a standard scaling InitStd, which is also some value you can change (per layer!) 73 | 74 | 2. It also tells you the learning-rate scaling factor you need per layer, which is a varying factor depending on the width. 75 | 76 | Say you have know that you want different learning rate for different layers such as Embedding and FFN. 77 | 78 | Then, you can do the following: 79 | 80 | ```python 81 | # You can do it manually by getting the parameter_name, lr_scaling dictionary. 82 | mup_engine = Ezmup(width_basis = 47, model = model) 83 | mup_scaling = mup_engine.lr_scaling_dict 84 | 85 | lr_dict = { 86 | 'embedding' : 1e-3, 87 | 'ffn' : 1e-4 88 | } 89 | 90 | optimizer_groups = [ 91 | {'params' : [p] , 'lr' : lr_dict[name] * mup_scaling[p_name] } for p_name, p in model.named_parameters() 92 | ] 93 | optimizer = Adam(optimizer_groups, **kwargs) 94 | ``` 95 | 96 | 97 | # List of hyperparameters 98 | 99 | Not all hyperparameters are relevant to muP. 
Here is the list of hyperparameters that are relevant to muP: 100 | 101 | ![Alt text](contents/hyperparams.png) 102 | 103 | ## So what is with the prime number? 104 | 105 | Ok so the difficulty of muP implementation for the drop-in replacement is that you need to change the width variable *automagically*. muP requires you to identify which `fan_in` or `fan_out` refers to *infinite-width*. 106 | 107 | So, the idea is to use a prime number as a width variable, which is then used as an indicator via dividability : if *some* shape is multiple of this prime number, then it is infinite-width. With design choice, users don't have to manually specify which shape is infinite-width, nor change the code. 108 | 109 | 110 | ## Coord-Checking 111 | 112 | You can see how to run the coord checking in the `example.py` file. The result should looks as: 113 | 114 | ![Alt text](contents/coord-check.png) 115 | 116 | 117 | # For Developers 118 | ## Installation 119 | 120 | ```bash 121 | git clone https://github.com/cloneofsimo/ezmup.git 122 | cd ezmup 123 | python3 -m pip install ".[dev]" 124 | ``` 125 | 126 | ## Build docs 127 | ```bash 128 | sphinx-apidoc -f -o docs/source ezmup 129 | cd docs 130 | make html 131 | ``` 132 | 133 | ## Update `requirements.txt` and `requirements-dev.txt` from `pyproject.toml` (Use `pip-compile`) 134 | ```bash 135 | pip-compile -o requirements.txt pyproject.toml 136 | pip-compile --extra dev -o requirements-dev.txt pyproject.toml 137 | ``` 138 | -------------------------------------------------------------------------------- /contents/example.csv: -------------------------------------------------------------------------------- 1 | ,width,module,t,l1 2 | 0,64,fin,0,0.8149098753929138 3 | 1,64,layers.0.query,0,0.5772204995155334 4 | 2,64,layers.0.key,0,0.581022322177887 5 | 3,64,layers.0.value,0,0.5941454768180847 6 | 4,64,layers.0,0,0.3674468398094177 7 | 5,64,layers.1.query,0,0.17641261219978333 8 | 6,64,layers.1.key,0,0.24193255603313446 9 | 
7,64,layers.1.value,0,0.2423495352268219 10 | 8,64,layers.1,0,0.24230468273162842 11 | 9,64,layers.2.query,0,0.17052577435970306 12 | 10,64,layers.2.key,0,0.1610855758190155 13 | 11,64,layers.2.value,0,0.18306007981300354 14 | 12,64,layers.2,0,0.18306009471416473 15 | 13,64,layers.3.query,0,0.11143392324447632 16 | 14,64,layers.3.key,0,0.14773796498775482 17 | 15,64,layers.3.value,0,0.13598960638046265 18 | 16,64,layers.3,0,0.13598962128162384 19 | 17,64,fout,0,0.10186693072319031 20 | 18,64,,0,0.10186693072319031 21 | 19,64,fin,1,0.8148900866508484 22 | 20,64,layers.0.query,1,0.5771677494049072 23 | 21,64,layers.0.key,1,0.5809863209724426 24 | 22,64,layers.0.value,1,0.5941069722175598 25 | 23,64,layers.0,1,0.36739522218704224 26 | 24,64,layers.1.query,1,0.17632868885993958 27 | 25,64,layers.1.key,1,0.24185262620449066 28 | 26,64,layers.1.value,1,0.24217648804187775 29 | 27,64,layers.1,1,0.24213044345378876 30 | 28,64,layers.2.query,1,0.17040590941905975 31 | 29,64,layers.2.key,1,0.16093800961971283 32 | 30,64,layers.2.value,1,0.18281416594982147 33 | 31,64,layers.2,1,0.18281416594982147 34 | 32,64,layers.3.query,1,0.11127031594514847 35 | 33,64,layers.3.key,1,0.14751508831977844 36 | 34,64,layers.3.value,1,0.13572008907794952 37 | 35,64,layers.3,1,0.13572008907794952 38 | 36,64,fout,1,0.10114678740501404 39 | 37,64,,1,0.10114678740501404 40 | 38,64,fin,2,0.8148705363273621 41 | 39,64,layers.0.query,2,0.5771153569221497 42 | 40,64,layers.0.key,2,0.5809506177902222 43 | 41,64,layers.0.value,2,0.5940680503845215 44 | 42,64,layers.0,2,0.36734333634376526 45 | 43,64,layers.1.query,2,0.17624594271183014 46 | 44,64,layers.1.key,2,0.24177080392837524 47 | 45,64,layers.1.value,2,0.24200372397899628 48 | 46,64,layers.1,2,0.24195601046085358 49 | 47,64,layers.2.query,2,0.17028766870498657 50 | 48,64,layers.2.key,2,0.16079214215278625 51 | 49,64,layers.2.value,2,0.18256743252277374 52 | 50,64,layers.2,2,0.18256743252277374 53 | 51,64,layers.3.query,2,0.11110606044530869 54 | 
52,64,layers.3.key,2,0.1472921371459961 55 | 53,64,layers.3.value,2,0.13545045256614685 56 | 54,64,layers.3,2,0.13545045256614685 57 | 55,64,fout,2,0.10042726248502731 58 | 56,64,,2,0.10042726248502731 59 | 57,128,fin,0,0.7967124581336975 60 | 58,128,layers.0.query,0,0.5854323506355286 61 | 59,128,layers.0.key,0,0.5547754764556885 62 | 60,128,layers.0.value,0,0.5630514025688171 63 | 61,128,layers.0,0,0.34021008014678955 64 | 62,128,layers.1.query,0,0.24652083218097687 65 | 63,128,layers.1.key,0,0.23194226622581482 66 | 64,128,layers.1.value,0,0.21953068673610687 67 | 65,128,layers.1,0,0.21947798132896423 68 | 66,128,layers.2.query,0,0.18088005483150482 69 | 67,128,layers.2.key,0,0.17130425572395325 70 | 68,128,layers.2.value,0,0.1707664281129837 71 | 69,128,layers.2,0,0.1707664281129837 72 | 70,128,layers.3.query,0,0.1157296821475029 73 | 71,128,layers.3.key,0,0.11210741102695465 74 | 72,128,layers.3.value,0,0.1076396256685257 75 | 73,128,layers.3,0,0.10763963311910629 76 | 74,128,fout,0,0.08033119887113571 77 | 75,128,,0,0.08033119887113571 78 | 76,128,fin,1,0.7966850996017456 79 | 77,128,layers.0.query,1,0.5854005813598633 80 | 78,128,layers.0.key,1,0.5547543168067932 81 | 79,128,layers.0.value,1,0.5629923343658447 82 | 80,128,layers.0,1,0.3400992751121521 83 | 81,128,layers.1.query,1,0.2463548332452774 84 | 82,128,layers.1.key,1,0.2318057268857956 85 | 83,128,layers.1.value,1,0.21936655044555664 86 | 84,128,layers.1,1,0.21931549906730652 87 | 85,128,layers.2.query,1,0.18071752786636353 88 | 86,128,layers.2.key,1,0.17112348973751068 89 | 87,128,layers.2.value,1,0.17058396339416504 90 | 88,128,layers.2,1,0.17058396339416504 91 | 89,128,layers.3.query,1,0.11543787270784378 92 | 90,128,layers.3.key,1,0.11180929839611053 93 | 91,128,layers.3.value,1,0.10727155208587646 94 | 92,128,layers.3,1,0.10727155953645706 95 | 93,128,fout,1,0.07958870381116867 96 | 94,128,,1,0.07958870381116867 97 | 95,128,fin,2,0.7966573238372803 98 | 96,128,layers.0.query,2,0.5853700041770935 
99 | 97,128,layers.0.key,2,0.5547330975532532 100 | 98,128,layers.0.value,2,0.5629326105117798 101 | 99,128,layers.0,2,0.33998575806617737 102 | 100,128,layers.1.query,2,0.2461901307106018 103 | 101,128,layers.1.key,2,0.23166987299919128 104 | 102,128,layers.1.value,2,0.21919913589954376 105 | 103,128,layers.1,2,0.21914896368980408 106 | 104,128,layers.2.query,2,0.18055394291877747 107 | 105,128,layers.2.key,2,0.17094431817531586 108 | 106,128,layers.2.value,2,0.17040283977985382 109 | 107,128,layers.2,2,0.17040283977985382 110 | 108,128,layers.3.query,2,0.11514024436473846 111 | 109,128,layers.3.key,2,0.11150948703289032 112 | 110,128,layers.3.value,2,0.10690951347351074 113 | 111,128,layers.3,2,0.10690951347351074 114 | 112,128,fout,2,0.07887613028287888 115 | 113,128,,2,0.07887613028287888 116 | 114,256,fin,0,0.8135337233543396 117 | 115,256,layers.0.query,0,0.581088662147522 118 | 116,256,layers.0.key,0,0.5684283971786499 119 | 117,256,layers.0.value,0,0.5643811821937561 120 | 118,256,layers.0,0,0.32293927669525146 121 | 119,256,layers.1.query,0,0.2351231724023819 122 | 120,256,layers.1.key,0,0.2321474850177765 123 | 121,256,layers.1.value,0,0.21924404799938202 124 | 122,256,layers.1,0,0.21911177039146423 125 | 123,256,layers.2.query,0,0.15226450562477112 126 | 124,256,layers.2.key,0,0.1517157107591629 127 | 125,256,layers.2.value,0,0.16001679003238678 128 | 126,256,layers.2,0,0.16001679003238678 129 | 127,256,layers.3.query,0,0.10916223376989365 130 | 128,256,layers.3.key,0,0.11315370351076126 131 | 129,256,layers.3.value,0,0.10016275197267532 132 | 130,256,layers.3,0,0.10016273707151413 133 | 131,256,fout,0,0.0744924545288086 134 | 132,256,,0,0.0744924545288086 135 | 133,256,fin,1,0.8135054111480713 136 | 134,256,layers.0.query,1,0.581052839756012 137 | 135,256,layers.0.key,1,0.5684000849723816 138 | 136,256,layers.0.value,1,0.5643109679222107 139 | 137,256,layers.0,1,0.32283008098602295 140 | 138,256,layers.1.query,1,0.23495745658874512 141 | 
139,256,layers.1.key,1,0.23206350207328796 142 | 140,256,layers.1.value,1,0.21901115775108337 143 | 141,256,layers.1,1,0.21888233721256256 144 | 142,256,layers.2.query,1,0.15209117531776428 145 | 143,256,layers.2.key,1,0.1514565348625183 146 | 144,256,layers.2.value,1,0.15985962748527527 147 | 145,256,layers.2,1,0.15985962748527527 148 | 146,256,layers.3.query,1,0.1089741438627243 149 | 147,256,layers.3.key,1,0.11287224292755127 150 | 148,256,layers.3.value,1,0.09990844875574112 151 | 149,256,layers.3,1,0.09990844130516052 152 | 150,256,fout,1,0.07373299449682236 153 | 151,256,,1,0.07373299449682236 154 | 152,256,fin,2,0.8134771585464478 155 | 153,256,layers.0.query,2,0.5810161828994751 156 | 154,256,layers.0.key,2,0.5683727860450745 157 | 155,256,layers.0.value,2,0.5642446279525757 158 | 156,256,layers.0,2,0.3227340281009674 159 | 157,256,layers.1.query,2,0.2347920835018158 160 | 158,256,layers.1.key,2,0.23199455440044403 161 | 159,256,layers.1.value,2,0.2187986820936203 162 | 160,256,layers.1,2,0.21867269277572632 163 | 161,256,layers.2.query,2,0.15195396542549133 164 | 162,256,layers.2.key,2,0.1511940062046051 165 | 163,256,layers.2.value,2,0.15972399711608887 166 | 164,256,layers.2,2,0.15972401201725006 167 | 165,256,layers.3.query,2,0.10879963636398315 168 | 166,256,layers.3.key,2,0.11261562258005142 169 | 167,256,layers.3.value,2,0.09967256337404251 170 | 168,256,layers.3,2,0.09967257082462311 171 | 169,256,fout,2,0.07303734123706818 172 | 170,256,,2,0.07303734123706818 173 | 171,512,fin,0,0.8129843473434448 174 | 172,512,layers.0.query,0,0.5762840509414673 175 | 173,512,layers.0.key,0,0.5762239098548889 176 | 174,512,layers.0.value,0,0.5780788660049438 177 | 175,512,layers.0,0,0.3462582230567932 178 | 176,512,layers.1.query,0,0.2269355058670044 179 | 177,512,layers.1.key,0,0.22915934026241302 180 | 178,512,layers.1.value,0,0.22600020468235016 181 | 179,512,layers.1,0,0.22568926215171814 182 | 180,512,layers.2.query,0,0.1657772958278656 183 | 
181,512,layers.2.key,0,0.1607837975025177 184 | 182,512,layers.2.value,0,0.15545342862606049 185 | 183,512,layers.2,0,0.15545342862606049 186 | 184,512,layers.3.query,0,0.11198598891496658 187 | 185,512,layers.3.key,0,0.10839514434337616 188 | 186,512,layers.3.value,0,0.11297659575939178 189 | 187,512,layers.3,0,0.11297659575939178 190 | 188,512,fout,0,0.06560536473989487 191 | 189,512,,0,0.06560536473989487 192 | 190,512,fin,1,0.8129409551620483 193 | 191,512,layers.0.query,1,0.5762478113174438 194 | 192,512,layers.0.key,1,0.5762039422988892 195 | 193,512,layers.0.value,1,0.5779631733894348 196 | 194,512,layers.0,1,0.3460577130317688 197 | 195,512,layers.1.query,1,0.22681958973407745 198 | 196,512,layers.1.key,1,0.22915402054786682 199 | 197,512,layers.1.value,1,0.22572264075279236 200 | 198,512,layers.1,1,0.22541484236717224 201 | 199,512,layers.2.query,1,0.1656753122806549 202 | 200,512,layers.2.key,1,0.16069597005844116 203 | 201,512,layers.2.value,1,0.15516871213912964 204 | 202,512,layers.2,1,0.15516872704029083 205 | 203,512,layers.3.query,1,0.11176659166812897 206 | 204,512,layers.3.key,1,0.10822875797748566 207 | 205,512,layers.3.value,1,0.1128641739487648 208 | 206,512,layers.3,1,0.1128641739487648 209 | 207,512,fout,1,0.06447132676839828 210 | 208,512,,1,0.06447132676839828 211 | 209,512,fin,2,0.8129048347473145 212 | 210,512,layers.0.query,2,0.5762245059013367 213 | 211,512,layers.0.key,2,0.5761887431144714 214 | 212,512,layers.0.value,2,0.5778694152832031 215 | 213,512,layers.0,2,0.34588584303855896 216 | 214,512,layers.1.query,2,0.2267587035894394 217 | 215,512,layers.1.key,2,0.22920525074005127 218 | 216,512,layers.1.value,2,0.22550083696842194 219 | 217,512,layers.1,2,0.22518934309482574 220 | 218,512,layers.2.query,2,0.16565018892288208 221 | 219,512,layers.2.key,2,0.1606561243534088 222 | 220,512,layers.2.value,2,0.15494635701179504 223 | 221,512,layers.2,2,0.15494635701179504 224 | 222,512,layers.3.query,2,0.11158950626850128 225 | 
223,512,layers.3.key,2,0.10808990150690079 226 | 224,512,layers.3.value,2,0.11280868202447891 227 | 225,512,layers.3,2,0.11280868947505951 228 | 226,512,fout,2,0.06335920840501785 229 | 227,512,,2,0.06335920840501785 230 | 228,1024,fin,0,0.8044828176498413 231 | 229,1024,layers.0.query,0,0.5675575733184814 232 | 230,1024,layers.0.key,0,0.5664249658584595 233 | 231,1024,layers.0.value,0,0.5735787153244019 234 | 232,1024,layers.0,0,0.34725043177604675 235 | 233,1024,layers.1.query,0,0.24275292456150055 236 | 234,1024,layers.1.key,0,0.23795805871486664 237 | 235,1024,layers.1.value,0,0.24160075187683105 238 | 236,1024,layers.1,0,0.24107715487480164 239 | 237,1024,layers.2.query,0,0.1675582379102707 240 | 238,1024,layers.2.key,0,0.16335546970367432 241 | 239,1024,layers.2.value,0,0.16832022368907928 242 | 240,1024,layers.2,0,0.16832022368907928 243 | 241,1024,layers.3.query,0,0.1206839457154274 244 | 242,1024,layers.3.key,0,0.12115409970283508 245 | 243,1024,layers.3.value,0,0.11666642874479294 246 | 244,1024,layers.3,0,0.11666642874479294 247 | 245,1024,fout,0,0.06721170991659164 248 | 246,1024,,0,0.06721170991659164 249 | 247,1024,fin,1,0.8044159412384033 250 | 248,1024,layers.0.query,1,0.5673855543136597 251 | 249,1024,layers.0.key,1,0.5663749575614929 252 | 250,1024,layers.0.value,1,0.5733946561813354 253 | 251,1024,layers.0,1,0.34696149826049805 254 | 252,1024,layers.1.query,1,0.24206887185573578 255 | 253,1024,layers.1.key,1,0.23740629851818085 256 | 254,1024,layers.1.value,1,0.24068087339401245 257 | 255,1024,layers.1,1,0.2401561439037323 258 | 256,1024,layers.2.query,1,0.16700904071331024 259 | 257,1024,layers.2.key,1,0.16271911561489105 260 | 258,1024,layers.2.value,1,0.16755135357379913 261 | 259,1024,layers.2,1,0.16755133867263794 262 | 260,1024,layers.3.query,1,0.1200239434838295 263 | 261,1024,layers.3.key,1,0.12011966109275818 264 | 262,1024,layers.3.value,1,0.115775927901268 265 | 263,1024,layers.3,1,0.115775927901268 266 | 
264,1024,fout,1,0.0655176043510437 267 | 265,1024,,1,0.0655176043510437 268 | 266,1024,fin,2,0.8043627142906189 269 | 267,1024,layers.0.query,2,0.5673120021820068 270 | 268,1024,layers.0.key,2,0.566342830657959 271 | 269,1024,layers.0.value,2,0.5732734203338623 272 | 270,1024,layers.0,2,0.346832275390625 273 | 271,1024,layers.1.query,2,0.24160249531269073 274 | 272,1024,layers.1.key,2,0.23704542219638824 275 | 273,1024,layers.1.value,2,0.2399766445159912 276 | 274,1024,layers.1,2,0.23942111432552338 277 | 275,1024,layers.2.query,2,0.16663087904453278 278 | 276,1024,layers.2.key,2,0.16221968829631805 279 | 277,1024,layers.2.value,2,0.1670490950345993 280 | 278,1024,layers.2,2,0.1670490801334381 281 | 279,1024,layers.3.query,2,0.11962497234344482 282 | 280,1024,layers.3.key,2,0.1193583682179451 283 | 281,1024,layers.3.value,2,0.11511623114347458 284 | 282,1024,layers.3,2,0.11511623114347458 285 | 283,1024,fout,2,0.06397029757499695 286 | 284,1024,,2,0.06397029757499695 287 | 285,2048,fin,0,0.8067775368690491 288 | 286,2048,layers.0.query,0,0.5643748044967651 289 | 287,2048,layers.0.key,0,0.5702745914459229 290 | 288,2048,layers.0.value,0,0.5669212937355042 291 | 289,2048,layers.0,0,0.33085957169532776 292 | 290,2048,layers.1.query,0,0.2401261329650879 293 | 291,2048,layers.1.key,0,0.23920074105262756 294 | 292,2048,layers.1.value,0,0.22682078182697296 295 | 293,2048,layers.1,0,0.22560590505599976 296 | 294,2048,layers.2.query,0,0.15996049344539642 297 | 295,2048,layers.2.key,0,0.1558251827955246 298 | 296,2048,layers.2.value,0,0.1603100299835205 299 | 297,2048,layers.2,0,0.1603100299835205 300 | 298,2048,layers.3.query,0,0.12006624788045883 301 | 299,2048,layers.3.key,0,0.11523381620645523 302 | 300,2048,layers.3.value,0,0.12021595239639282 303 | 301,2048,layers.3,0,0.12021595239639282 304 | 302,2048,fout,0,0.0644032433629036 305 | 303,2048,,0,0.0644032433629036 306 | 304,2048,fin,1,0.8067179322242737 307 | 305,2048,layers.0.query,1,0.5643737316131592 308 | 
306,2048,layers.0.key,1,0.5702306628227234 309 | 307,2048,layers.0.value,1,0.5668097138404846 310 | 308,2048,layers.0,1,0.3313599228858948 311 | 309,2048,layers.1.query,1,0.24051904678344727 312 | 310,2048,layers.1.key,1,0.23936232924461365 313 | 311,2048,layers.1.value,1,0.22722578048706055 314 | 312,2048,layers.1,1,0.22598044574260712 315 | 313,2048,layers.2.query,1,0.15993644297122955 316 | 314,2048,layers.2.key,1,0.15580180287361145 317 | 315,2048,layers.2.value,1,0.1605413854122162 318 | 316,2048,layers.2,1,0.16054140031337738 319 | 317,2048,layers.3.query,1,0.12049165368080139 320 | 318,2048,layers.3.key,1,0.11510790139436722 321 | 319,2048,layers.3.value,1,0.12005327641963959 322 | 320,2048,layers.3,1,0.12005327641963959 323 | 321,2048,fout,1,0.0611707977950573 324 | 322,2048,,1,0.0611707977950573 325 | 323,2048,fin,2,0.8067548871040344 326 | 324,2048,layers.0.query,2,0.5646882653236389 327 | 325,2048,layers.0.key,2,0.5702722668647766 328 | 326,2048,layers.0.value,2,0.5669767260551453 329 | 327,2048,layers.0,2,0.33317604660987854 330 | 328,2048,layers.1.query,2,0.24184982478618622 331 | 329,2048,layers.1.key,2,0.2406063675880432 332 | 330,2048,layers.1.value,2,0.22914132475852966 333 | 331,2048,layers.1,2,0.22787687182426453 334 | 332,2048,layers.2.query,2,0.16122350096702576 335 | 333,2048,layers.2.key,2,0.15707577764987946 336 | 334,2048,layers.2.value,2,0.1621721237897873 337 | 335,2048,layers.2,2,0.1621721237897873 338 | 336,2048,layers.3.query,2,0.12229496240615845 339 | 337,2048,layers.3.key,2,0.11628807336091995 340 | 338,2048,layers.3.value,2,0.12139161676168442 341 | 339,2048,layers.3,2,0.12139161676168442 342 | 340,2048,fout,2,0.058184914290905 343 | 341,2048,,2,0.058184914290905 344 | 342,4096,fin,0,0.7988824844360352 345 | 343,4096,layers.0.query,0,0.5606533288955688 346 | 344,4096,layers.0.key,0,0.5624681115150452 347 | 345,4096,layers.0.value,0,0.5628334879875183 348 | 346,4096,layers.0,0,0.3314320743083954 349 | 
347,4096,layers.1.query,0,0.2283448427915573 350 | 348,4096,layers.1.key,0,0.23269470036029816 351 | 349,4096,layers.1.value,0,0.23543527722358704 352 | 350,4096,layers.1,0,0.2332807034254074 353 | 351,4096,layers.2.query,0,0.1615886390209198 354 | 352,4096,layers.2.key,0,0.1679946780204773 355 | 353,4096,layers.2.value,0,0.16637186706066132 356 | 354,4096,layers.2,0,0.16637185215950012 357 | 355,4096,layers.3.query,0,0.12361430376768112 358 | 356,4096,layers.3.key,0,0.11976416409015656 359 | 357,4096,layers.3.value,0,0.1222844123840332 360 | 358,4096,layers.3,0,0.1222844123840332 361 | 359,4096,fout,0,0.06032772362232208 362 | 360,4096,,0,0.06032772362232208 363 | 361,4096,fin,1,0.7990217804908752 364 | 362,4096,layers.0.query,1,0.5608536601066589 365 | 363,4096,layers.0.key,1,0.5626009106636047 366 | 364,4096,layers.0.value,1,0.5635536909103394 367 | 365,4096,layers.0,1,0.3360650837421417 368 | 366,4096,layers.1.query,1,0.232390895485878 369 | 367,4096,layers.1.key,1,0.2356625199317932 370 | 368,4096,layers.1.value,1,0.2396964728832245 371 | 369,4096,layers.1,1,0.2373056262731552 372 | 370,4096,layers.2.query,1,0.16517041623592377 373 | 371,4096,layers.2.key,1,0.17163842916488647 374 | 372,4096,layers.2.value,1,0.1708918809890747 375 | 373,4096,layers.2,1,0.17089185118675232 376 | 374,4096,layers.3.query,1,0.1273987591266632 377 | 375,4096,layers.3.key,1,0.1231566071510315 378 | 376,4096,layers.3.value,1,0.12683652341365814 379 | 377,4096,layers.3,1,0.12683652341365814 380 | 378,4096,fout,1,0.05613601207733154 381 | 379,4096,,1,0.05613601207733154 382 | 380,4096,fin,2,0.7994076013565063 383 | 381,4096,layers.0.query,2,0.5614911913871765 384 | 382,4096,layers.0.key,2,0.562955379486084 385 | 383,4096,layers.0.value,2,0.5651777386665344 386 | 384,4096,layers.0,2,0.34917548298835754 387 | 385,4096,layers.1.query,2,0.24306939542293549 388 | 386,4096,layers.1.key,2,0.2457692176103592 389 | 387,4096,layers.1.value,2,0.25185996294021606 390 | 
388,4096,layers.1,2,0.2489781379699707 391 | 389,4096,layers.2.query,2,0.17529898881912231 392 | 390,4096,layers.2.key,2,0.18186745047569275 393 | 391,4096,layers.2.value,2,0.18276286125183105 394 | 392,4096,layers.2,2,0.18276286125183105 395 | 393,4096,layers.3.query,2,0.13728417456150055 396 | 394,4096,layers.3.key,2,0.13262803852558136 397 | 395,4096,layers.3.value,2,0.1378001570701599 398 | 396,4096,layers.3,2,0.1378001570701599 399 | 397,4096,fout,2,0.053119465708732605 400 | 398,4096,,2,0.053119465708732605 401 | 399,8192,fin,0,0.8027874827384949 402 | 400,8192,layers.0.query,0,0.5685182213783264 403 | 401,8192,layers.0.key,0,0.5699728727340698 404 | 402,8192,layers.0.value,0,0.5704889893531799 405 | 403,8192,layers.0,0,0.3496896028518677 406 | 404,8192,layers.1.query,0,0.2447047233581543 407 | 405,8192,layers.1.key,0,0.24606159329414368 408 | 406,8192,layers.1.value,0,0.24192270636558533 409 | 407,8192,layers.1,0,0.23695224523544312 410 | 408,8192,layers.2.query,0,0.17106039822101593 411 | 409,8192,layers.2.key,0,0.1683071106672287 412 | 410,8192,layers.2.value,0,0.16687315702438354 413 | 411,8192,layers.2,0,0.16687312722206116 414 | 412,8192,layers.3.query,0,0.11674073338508606 415 | 413,8192,layers.3.key,0,0.1190183088183403 416 | 414,8192,layers.3.value,0,0.11664886772632599 417 | 415,8192,layers.3,0,0.11664886772632599 418 | 416,8192,fout,0,0.06016528233885765 419 | 417,8192,,0,0.06016528233885765 420 | 418,8192,fin,1,0.8029922246932983 421 | 419,8192,layers.0.query,1,0.5698066353797913 422 | 420,8192,layers.0.key,1,0.5701940059661865 423 | 421,8192,layers.0.value,1,0.5715948343276978 424 | 422,8192,layers.0,1,0.4077516198158264 425 | 423,8192,layers.1.query,1,0.28662800788879395 426 | 424,8192,layers.1.key,1,0.28804048895835876 427 | 425,8192,layers.1.value,1,0.2864210903644562 428 | 426,8192,layers.1,1,0.2807700037956238 429 | 427,8192,layers.2.query,1,0.20267140865325928 430 | 428,8192,layers.2.key,1,0.19874384999275208 431 | 
429,8192,layers.2.value,1,0.20031467080116272 432 | 430,8192,layers.2,1,0.20031426846981049 433 | 431,8192,layers.3.query,1,0.14127890765666962 434 | 432,8192,layers.3.key,1,0.14288511872291565 435 | 433,8192,layers.3.value,1,0.1436128318309784 436 | 434,8192,layers.3,1,0.1436128318309784 437 | 435,8192,fout,1,0.05321292206645012 438 | 436,8192,,1,0.05321292206645012 439 | 437,8192,fin,2,0.8035334944725037 440 | 438,8192,layers.0.query,2,0.571211040019989 441 | 439,8192,layers.0.key,2,0.5706230401992798 442 | 440,8192,layers.0.value,2,0.5751043558120728 443 | 441,8192,layers.0,2,0.5039435029029846 444 | 442,8192,layers.1.query,2,0.35737693309783936 445 | 443,8192,layers.1.key,2,0.3586585819721222 446 | 444,8192,layers.1.value,2,0.3617681562900543 447 | 445,8192,layers.1,2,0.35565856099128723 448 | 446,8192,layers.2.query,2,0.26041749119758606 449 | 447,8192,layers.2.key,2,0.2544698417186737 450 | 448,8192,layers.2.value,2,0.2612403929233551 451 | 449,8192,layers.2,2,0.261238157749176 452 | 450,8192,layers.3.query,2,0.18750232458114624 453 | 451,8192,layers.3.key,2,0.18866753578186035 454 | 452,8192,layers.3.value,2,0.19270314276218414 455 | 453,8192,layers.3,2,0.19270314276218414 456 | 454,8192,fout,2,0.04850998893380165 457 | 455,8192,,2,0.04850998893380165 458 | -------------------------------------------------------------------------------- /ezmup/ezmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import copy 3 | from typing import Any 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from torch import nn 10 | from torch.optim import Adam 11 | 12 | 13 | def spectral_sigma(fan_in, fan_out, init_std): 14 | """Spectral parameterization from the [paper](https://arxiv.org/abs/2310.17813).""" 15 | return (init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in)) 16 | 17 | 18 | def spectral_lr(fan_in, fan_out): 19 | """Spectral 
def spectral_sigma(fan_in, fan_out, init_std):
    """Per-layer init std from the spectral parameterization [paper](https://arxiv.org/abs/2310.17813)."""
    return (init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))


def spectral_lr(fan_in, fan_out):
    """Per-layer learning-rate scale from the spectral parameterization [paper](https://arxiv.org/abs/2310.17813)."""
    return fan_out / fan_in


# Default (init_std_rule, lr_scaling_rule) pair for weight-like parameters.
SPECTRAL_DEFAULT = (spectral_sigma, spectral_lr)


# Maps "<ModuleClass>.<param name>" to a pair of rules. Each rule is either a
# callable (init: f(fan_in, fan_out, init_std); lr: f(fan_in, fan_out)) or a
# plain float used as a constant multiplier.
LAYER_REGISTRY = {
    "Conv1d.weight": SPECTRAL_DEFAULT,
    "Conv1d.bias": SPECTRAL_DEFAULT,
    "Conv2d.weight": SPECTRAL_DEFAULT,
    "Conv2d.bias": SPECTRAL_DEFAULT,
    "Conv3d.weight": SPECTRAL_DEFAULT,
    "Conv3d.bias": SPECTRAL_DEFAULT,
    "ConvTranspose1d.weight": SPECTRAL_DEFAULT,
    "ConvTranspose1d.bias": SPECTRAL_DEFAULT,
    "ConvTranspose2d.weight": SPECTRAL_DEFAULT,
    "ConvTranspose2d.bias": SPECTRAL_DEFAULT,
    "ConvTranspose3d.weight": SPECTRAL_DEFAULT,
    "ConvTranspose3d.bias": SPECTRAL_DEFAULT,
    "Linear.weight": SPECTRAL_DEFAULT,
    "Linear.bias": SPECTRAL_DEFAULT,
    "Embedding.weight": (1.0, 1.0),
    "BatchNorm2d.weight": (1.0, 1.0),
    "BatchNorm2d.bias": (0.0, 1.0),
    "LayerNorm.weight": (1.0, 1.0),
    "LayerNorm.bias": (0.0, 1.0),
    "GroupNorm.weight": (1.0, 1.0),
    "GroupNorm.bias": (0.0, 1.0),
}


class Ezmup:
    """Easier maximal update parametrization (muP)."""

    def __init__(self, width_basis: int, model: nn.Module, init_std: float = 1.0):
        """Initialize Ezmup by specifying the width basis, the model, and the init_std.

        Args:
            width_basis (int): A base dimension; every parameter dimension divisible
                by it is treated as "infinite width" and rescaled by `change_width_as`.
            model (nn.Module): A model to be changed.
            init_std (float, optional): The initial standard deviation of the model
                parameters. May also be a dict mapping parameter name -> std.
                Defaults to 1.0.
        """
        self.width_basis = width_basis
        self.model = model
        self.init_std = init_std  # float, or dict of per-parameter stds
        # Shapes are recorded once at the basis width. `change_width_as` always
        # rescales relative to these original shapes, so repeated calls do not
        # compound.
        self.model_param_shape_dict = {
            name: param.shape for name, param in self.model.named_parameters()
        }
        self.lr_scaling_dict = {}

    @torch.no_grad()
    def change_width_as(self, new_width: int):
        """Update model parameters with a new width multiplier.

        Every dimension divisible by `width_basis` is rescaled to
        `new_width * (dim // width_basis)`; all other dimensions are kept.
        Affected parameters are re-initialized per `LAYER_REGISTRY`, and their
        learning-rate scaling factors are stored in `self.lr_scaling_dict`.

        Args:
            new_width (int): Width multiplier used for calculating muP scaling.

        Raises:
            ValueError: When the module owning a parameter cannot be found.
            NotImplementedError: When the parameter class is not in LAYER_REGISTRY.
        """
        new_param_dict = {}
        first_param = next(self.model.parameters())
        dtype, device = first_param.dtype, first_param.device

        for name, param in self.model.named_parameters():
            shape = self.model_param_shape_dict[name]
            new_shape = [
                new_width * (dim // self.width_basis)
                if dim % self.width_basis == 0
                else dim
                for dim in shape
            ]

            print(f"Now changing {name} from {shape} to {new_shape}")

            # Shape unchanged -> keep the existing parameter as-is.
            if all(dim == new_shape[i] for i, dim in enumerate(shape)):
                new_param_dict[name] = param
                continue

            weight_shape = new_shape
            init_std, lr_scaling = SPECTRAL_DEFAULT

            if name.endswith("bias") or name.endswith("weight"):
                # Strip the trailing ".bias"/".weight" to get the owning module.
                oname = (
                    name[: -len(".bias")]
                    if name.endswith("bias")
                    else name[: -len(".weight")]
                )
                module_with_name = self.model.get_submodule(oname)

                if module_with_name is None:
                    msg = f"Could not find module with name {name}"
                    raise ValueError(msg)

                module_class = module_with_name.__class__.__name__
                # BUG FIX: previously this always appended ".bias", even for
                # weight parameters. That made Embedding.weight raise
                # NotImplementedError (there is no "Embedding.bias" entry) and
                # made norm-layer weights pick up the bias rule (0.0 init).
                suffix = ".bias" if name.endswith("bias") else ".weight"
                param_classname = module_class + suffix

                if param_classname in LAYER_REGISTRY:
                    init_std, lr_scaling = LAYER_REGISTRY[param_classname]

                    if name.endswith("bias"):
                        # A bias alone does not imply a fan_in; borrow it from
                        # the matching weight's last dimension (rescaled if it
                        # is a basis multiple).
                        fan_in_of_weight = self.model_param_shape_dict[
                            oname + ".weight"
                        ][-1]
                        if fan_in_of_weight % self.width_basis == 0:
                            fan_in_of_weight = fan_in_of_weight * (
                                new_width // self.width_basis
                            )
                        fan_in = fan_in_of_weight
                        fan_out = 1
                    else:
                        # NOTE(review): using the last dim as fan_in matches
                        # Linear, but for Conv weights the last dim is a kernel
                        # size — confirm intended semantics before changing.
                        fan_in = weight_shape[-1]
                        fan_out = int(np.prod(weight_shape[:-1]))
                else:
                    msg = f"Could not find {param_classname} in LAYER_REGISTRY"
                    raise NotImplementedError(msg)
            else:
                # Unrecognized parameter (neither bias nor weight): fall back to
                # last dim = fan_in, product of the rest = fan_out.
                fan_in = weight_shape[-1]
                fan_out = int(np.prod(weight_shape[:-1]))
                init_std, lr_scaling = SPECTRAL_DEFAULT

            print(f"{name} fan_in: {fan_in}, fan_out: {fan_out}")

            # Float registry entries are constant multipliers; callables are
            # evaluated on (fan_in, fan_out[, base std]).
            if isinstance(init_std, float):
                init_std = init_std * self._get_init_std(name)
            else:
                init_std = init_std(fan_in, fan_out, self._get_init_std(name))

            # NOTE(review): the division by 64 looks like a hard-coded base
            # width — confirm before changing.
            if isinstance(lr_scaling, float):
                lr_scaling = lr_scaling / 64
            else:
                lr_scaling = lr_scaling(fan_in, fan_out) / 64

            self.lr_scaling_dict[name] = lr_scaling

            new_param_dict[name] = torch.randn(new_shape) * init_std

        for name, named_module in self.model.named_modules():
            # BUG FIX: check `is not None` instead of bare hasattr — modules
            # built with bias=False expose `.bias is None`, and looking their
            # bias up in new_param_dict raised KeyError before.
            if getattr(named_module, "weight", None) is not None:
                named_module.weight = torch.nn.Parameter(
                    new_param_dict[name + ".weight"],
                    requires_grad=True,
                ).to(dtype=named_module.weight.dtype)

            if getattr(named_module, "bias", None) is not None:
                named_module.bias = torch.nn.Parameter(
                    new_param_dict[name + ".bias"],
                    requires_grad=True,
                ).to(dtype=named_module.bias.dtype)

        # Restore the original dtype/device for the whole (rebuilt) model.
        self.model.to(dtype=dtype, device=device)

    def get_optimizer(self, optimizer_class: Any, lr: float, **kwargs):
        """Get an optimizer with per-parameter muP learning rates.

        Parameters without an entry in `lr_scaling_dict` (i.e. whose shape was
        not rescaled) use the base learning rate unscaled.

        Args:
            optimizer_class (Any): Optimizer class.
            lr (float): Base learning rate.
            **kwargs: Forwarded to the optimizer constructor.

        Returns:
            Any: The constructed optimizer with one param group per parameter.
        """
        mup_scaling = self.lr_scaling_dict

        optimizer_groups = [
            {"params": [p], "lr": lr * mup_scaling.get(name, 1.0)}
            for name, p in self.model.named_parameters()
        ]

        return optimizer_class(optimizer_groups, **kwargs)

    def _get_init_std(self, name):
        """Return the configured init std for `name` (per-name dict lookup, default 1.0)."""
        if isinstance(self.init_std, dict):
            return self.init_std.get(name, 1.0)
        return self.init_std

    def forward(self, *args, **kwargs):
        """No-op placeholder; Ezmup wraps a model but is not itself a module."""
        pass


def cov(x: torch.Tensor) -> torch.Tensor:
    """Treat `x` as a collection of vectors and compute its Gram matrix.

    Args:
        x (torch.Tensor): If it has shape [..., d], then it's treated as
            a collection of d-dimensional vectors

    Returns:
        torch.Tensor: a matrix of size N x N where N is the product of
            the non-last dimensions of `x`.
    """
    if x.nelement() == 1:
        width = 1
        xx = x.reshape(1, 1)
    else:
        width = x.shape[-1]
        xx = x.reshape(-1, x.shape[-1])
    return xx @ xx.T / width
251 | """ 252 | c = cov(x) 253 | return c[~torch.eye(c.shape[0], dtype=bool)] 254 | 255 | 256 | #: dict of provided functions for use in coord check 257 | FDICT = { 258 | "l1": lambda x: torch.abs(x).mean(), 259 | "l2": lambda x: (x**2).mean() ** 0.5, 260 | "mean": lambda x: x.mean(), 261 | "std": lambda x: x.std(), 262 | "covl1": lambda x: torch.abs(cov(x)).mean(), 263 | "covl2": lambda x: (cov(x) ** 2).mean() ** 0.5, 264 | "covoffdiagl1": lambda x: torch.abs(covoffdiag(x)).mean(), 265 | "covoffdiagl2": lambda x: (covoffdiag(x) ** 2).mean() ** 0.5, 266 | } 267 | 268 | 269 | def convert_fdict(d: dict[Any, Any]) -> dict[Any, Any]: 270 | """Convert a dict `d` with string values to function values. 271 | 272 | Args: 273 | d (dict[Any, Any]): a dict whose values are either strings or functions 274 | 275 | Returns: 276 | dict[Any, Any]: a new dict, with the same keys as `d`, but the string values are 277 | converted to functions using `FDICT`. 278 | """ 279 | return dict( 280 | [((k, FDICT[v]) if isinstance(v, str) else (k, v)) for k, v in d.items()], 281 | ) 282 | 283 | 284 | def _record_coords( 285 | records, 286 | width, 287 | modulename, 288 | t, 289 | output_fdict=None, 290 | input_fdict=None, 291 | param_fdict=None, 292 | ): 293 | """Returns a forward hook that records coordinate statistics. 294 | 295 | Returns a forward hook that records statistics regarding the output, input, 296 | and/or parameters of a `nn.Module`. This hook is intended to run only once, 297 | on the timestep specified by `t`. 298 | 299 | On forward pass, the returned hook calculates statistics specified in 300 | `output_fdict`, `input_fdict`, and `param_fdict`, such as the normalized l1 301 | norm, of output, input, and/or parameters of the module. The statistics are 302 | recorded along with the `width`, `modulename`, and `t` (the time step) as a 303 | dict and inserted into `records` (which should be a list). 
More precisely, 304 | for each output, input, and/or parameter, the inserted dict is of the form 305 | 306 | { 307 | 'width': width, 'module': modified_modulename, 't': t, 308 | # keys are keys in fdict 309 | 'l1': 0.241, 'l2': 0.420, 'mean': 0.0, ... 310 | } 311 | 312 | where `modified_modulename` is a string that combines the `modulename` with 313 | an indicator of which output, input, or parameter tensor is the statistics 314 | computed over. 315 | 316 | The `*_fdict` inputs should be dictionaries with string keys and whose 317 | values can either be functions or strings. The string values are converted 318 | to functions via `convert_fdict`. The default values of `*_dict` inputs are 319 | converted to `output_fdict = dict(l1=FDICT['l1'])`, `input_fdict = {}`, 320 | `param_fdict = {}`, i.e., only the average coordinate size (`l1`) of the 321 | output activations are recorded. 322 | 323 | Args: 324 | records: list to append coordinate data to. 325 | width: width of the model. This is used only for plotting coord check later 326 | on, so it can be any notion of width. 327 | modulename: string name of the module. This is used only for plotting coord check. 328 | t: timestep of training. This is used only for plotting coord check. 329 | output_fdict: dict with string keys and whose values can either be functions or strings. 330 | The string values are converted to functions via `convert_fdict`. 331 | input_fdict: dict with string keys and whose values can either be functions or strings. 332 | The string values are converted to functions via `convert_fdict`. 333 | param_fdict: dict with string keys and whose values can either be functions or strings. 334 | The string values are converted to functions via `convert_fdict`. 335 | 336 | Returns: 337 | Any: a forward hook that records statistics regarding the output, input, 338 | and/or parameters of a `nn.Module`, as discussed above. 
339 | """ 340 | if output_fdict is None: 341 | output_fdict = dict(l1=FDICT["l1"]) 342 | else: 343 | output_fdict = convert_fdict(output_fdict) 344 | # SIM108: Use the ternary operator if it's reasonable 345 | input_fdict = {} if input_fdict is None else convert_fdict(input_fdict) 346 | param_fdict = {} if param_fdict is None else convert_fdict(param_fdict) 347 | 348 | def f(module, input, output): 349 | def get_stat(d, x, fdict): 350 | if isinstance(x, tuple | list): 351 | for i, _x in enumerate(x): 352 | _d = copy(d) 353 | _d["module"] += f"[{i}]" 354 | get_stat(_d, _x, fdict) 355 | elif isinstance(x, dict): 356 | for name, _x in x.items(): 357 | _d = copy(d) 358 | _d["module"] += f"[{name}]" 359 | get_stat(_d, _x, fdict) 360 | elif isinstance(x, torch.Tensor): 361 | _d = copy(d) 362 | for fname, f in fdict.items(): 363 | _d[fname] = f(x).item() 364 | records.append(_d) 365 | else: 366 | msg = f"Unexpected output type: {type(x)}" 367 | raise NotImplementedError(msg) 368 | 369 | with torch.no_grad(): 370 | ret = {"width": width, "module": modulename, "t": t} 371 | 372 | # output stats 373 | if isinstance(output, tuple | list): 374 | for i, out in enumerate(output): 375 | _ret = copy(ret) 376 | _ret["module"] += f":out[{i}]" 377 | get_stat(_ret, out, output_fdict) 378 | elif isinstance(output, dict): 379 | for name, out in output.items(): 380 | _ret = copy(ret) 381 | _ret["module"] += f":out[{name}]" 382 | get_stat(_ret, out, output_fdict) 383 | elif isinstance(output, torch.Tensor): 384 | _ret = copy(ret) 385 | for fname, f in output_fdict.items(): 386 | _ret[fname] = f(output).item() 387 | records.append(_ret) 388 | else: 389 | msg = f"Unexpected output type: {type(output)}" 390 | raise NotImplementedError(msg) 391 | 392 | # input stats 393 | if input_fdict: 394 | if isinstance(input, tuple | list): 395 | for i, out in enumerate(input): 396 | _ret = copy(ret) 397 | _ret["module"] += f":in[{i}]" 398 | get_stat(_ret, out, input_fdict) 399 | elif isinstance(input, 
dict): 400 | for name, out in input.items(): 401 | _ret = copy(ret) 402 | _ret["module"] += f":in[{name}]" 403 | get_stat(_ret, out, input_fdict) 404 | elif isinstance(input, torch.Tensor): 405 | _ret = copy(ret) 406 | for fname, f in input_fdict.items(): 407 | _ret[fname] = f(input).item() 408 | records.append(_ret) 409 | else: 410 | msg = f"Unexpected output type: {type(input)}" 411 | raise NotImplementedError(msg) 412 | 413 | # param stats 414 | if param_fdict: 415 | for name, p in module.named_parameters(): 416 | _ret = copy(ret) 417 | _ret["module"] += f":param[{name}]" 418 | for fname, f in param_fdict.items(): 419 | _ret[fname] = f(p).item() 420 | records.append(_ret) 421 | 422 | return f 423 | 424 | 425 | def get_coord_data( 426 | model_engine, 427 | datapoint, 428 | width_list=None, 429 | optim_cls=Adam, 430 | optim_kwargs=None, 431 | n_seeds=1, 432 | n_steps=3, 433 | ) -> pd.DataFrame: 434 | """Get coordinate data for coord check. 435 | 436 | Args: 437 | model_engine (_type_): Ezmup model engine. 438 | datapoint (_type_): A datapoint to be used for forward pass. 439 | width_list (list, optional): _description_. Defaults to [64, 128, 256, 512, 1024, 2048, 4096, 8192]. 440 | optim_cls (_type_, optional): _description_. Defaults to Adam. 441 | optim_kwargs (_type_, optional): _description_. Defaults to None. 442 | n_seeds (int, optional): _description_. Defaults to 1. 443 | n_steps (int, optional): _description_. Defaults to 3. 444 | 445 | Returns: 446 | pd.DataFrame: A dataframe containing the coordinate data. 447 | """ 448 | df = [] 449 | 450 | # Mutable default arguments are dangerous. Use None instead. 
451 | if width_list is None: 452 | width_list = [64, 128, 256, 512, 1024, 2048, 4096, 8192] 453 | 454 | for i in range(n_seeds): 455 | torch.manual_seed(i) 456 | for width in width_list: 457 | model_engine.change_width_as(width) 458 | optim = ( 459 | model_engine.get_optimizer(optim_cls, lr=1e-3) 460 | if optim_kwargs is None 461 | else model_engine.get_optimizer(optim_cls, lr=1e-3, **optim_kwargs) 462 | ) 463 | 464 | for j in range(n_steps): 465 | remove_hooks = [] 466 | for name, module in model_engine.model.named_modules(): 467 | remove_hooks.append( 468 | module.register_forward_hook( 469 | _record_coords( 470 | df, 471 | width, 472 | name, 473 | j, 474 | output_fdict=None, 475 | input_fdict=None, 476 | param_fdict=None, 477 | ), 478 | ), 479 | ) 480 | 481 | model_engine.model.train() 482 | 483 | loss = model_engine.forward(datapoint, model_engine.model) 484 | loss.backward() 485 | optim.step() 486 | optim.zero_grad() 487 | 488 | for handle in remove_hooks: 489 | handle.remove() 490 | 491 | return pd.DataFrame(df) 492 | 493 | 494 | def plot_coord_data( 495 | df, 496 | y="l1", 497 | save_to=None, 498 | suptitle=None, 499 | x="width", 500 | hue="module", 501 | legend=True, 502 | name_contains=None, 503 | name_not_contains=None, 504 | loglog=True, 505 | logbase=2, 506 | face_color=None, 507 | jitter=True, 508 | jitter_strength=0.1, 509 | ): 510 | """Plot coord check data `df`. 511 | 512 | Args: 513 | df: pandas DataFrame 514 | y: column for y-axis. Default: 'l1' 515 | save_to: path to save the figure, or None. Default: None. 516 | suptitle: The title of the entire figure. 517 | x: column for x-axis. Default: 'width' 518 | hue: column for color. Default: 'module' 519 | legend: whether to show legend. Default: True 520 | name_contains: filter modules by name inclusion 521 | name_not_contains: filter modules by name exclusion 522 | loglog: use loglog scale. Default: True 523 | logbase: log base if using loglog. 
Default: 2 524 | face_color: background color of the plot. Default: None 525 | jitter: Whether to apply jitter to the y-axis. Default: True 526 | jitter_strength: The strength of the jitter. Default: 0.1 527 | 528 | Returns: 529 | the matplotlib figure object 530 | """ 531 | 532 | def apply_jitter(values, jitter_strength): 533 | # Apply a random multiplicative shift to each data point 534 | rng = np.random.default_rng() 535 | jitter = rng.uniform(-jitter_strength, jitter_strength, size=len(values)) 536 | return values * np.exp(jitter) 537 | 538 | # Preprocessing 539 | df = df.copy() 540 | df = df[df.module != ""] 541 | if name_contains is not None: 542 | df = df[df["module"].str.contains(name_contains)] 543 | elif name_not_contains is not None: 544 | df = df[~(df["module"].str.contains(name_not_contains))] 545 | 546 | ts = df.t.unique() 547 | 548 | # Plot 549 | fig, axes = plt.subplots(1, len(ts), figsize=(5 * len(ts), 4)) 550 | if face_color: 551 | fig.patch.set_facecolor(face_color) 552 | if suptitle: 553 | plt.suptitle(suptitle) 554 | 555 | for idx, t in enumerate(ts): 556 | ax = axes[idx] if len(ts) > 1 else axes 557 | subset = df[df.t == t] 558 | groups = subset.groupby(hue) 559 | 560 | for name, group in groups: 561 | x_values = group[x] 562 | y_values = group[y] 563 | 564 | if jitter: 565 | y_values = apply_jitter(y_values, jitter_strength) 566 | 567 | ax.plot(x_values, y_values, label=name) 568 | 569 | ax.set_title(f"t={t}") 570 | if loglog: 571 | ax.set_xscale("log", base=logbase) 572 | ax.set_yscale("log", base=logbase) 573 | 574 | if legend and idx == len(ts) - 1: 575 | ax.legend() 576 | 577 | fig.tight_layout(rect=[0, 0.03, 1, 0.95]) 578 | 579 | if save_to: 580 | plt.savefig(save_to) 581 | print(f"Plot saved to {save_to}") 582 | 583 | return fig 584 | --------------------------------------------------------------------------------