├── .github
│   ├── CODEOWNERS
│   ├── README.md
│   └── images
│       ├── ddpm_both.gif
│       ├── ddpm_fashion.gif
│       ├── ddpm_mnist.gif
│       ├── dqn_cartpole.gif
│       ├── gnns_training.png
│       ├── nf_generated_images.png
│       ├── ppo_cartpole.gif
│       └── vit_architecture.png
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   └── launch.json
├── LICENSE
├── data
│   └── nlp
│       ├── bert
│       │   └── masked_sentences.txt
│       ├── gpt
│       │   └── prompts.txt
│       └── original
│           └── translate_sentences.txt
├── environment.yml
├── requirements.txt
└── src
    ├── __init__.py
    ├── cv
    │   ├── __init__.py
    │   ├── ddpm
    │   │   ├── __init__.py
    │   │   ├── ddpm.py
    │   │   ├── models.py
    │   │   └── notebook
    │   │       └── DDPM.ipynb
    │   ├── ign
    │   │   ├── ign.py
    │   │   ├── main.py
    │   │   └── model.py
    │   ├── nf
    │   │   ├── __init__.py
    │   │   └── normalizing_flows.py
    │   ├── vir
    │   │   ├── train.py
    │   │   └── vir.py
    │   └── vit
    │       ├── __init__.py
    │       └── vit_torch.py
    ├── fff
    │   ├── __init__.py
    │   ├── fff.py
    │   └── main.py
    ├── gnns
    │   ├── __init__.py
    │   └── gnns.py
    ├── nlp
    │   ├── __init__.py
    │   ├── bert
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── data.py
    │   │   ├── main.py
    │   │   └── model.py
    │   ├── gpt
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── main.py
    │   │   └── model.py
    │   ├── layers
    │   │   ├── __init__.py
    │   │   ├── attention.py
    │   │   ├── decoder.py
    │   │   ├── embeddings.py
    │   │   ├── encoder.py
    │   │   └── mlp.py
    │   ├── lm_is_compression
    │   │   ├── __init__.py
    │   │   └── lm_is_compression.py
    │   ├── lm_watermarking
    │   │   ├── main.py
    │   │   ├── plot.py
    │   │   └── watermarking.py
    │   ├── original
    │   │   ├── __init__.py
    │   │   ├── data.py
    │   │   ├── main.py
    │   │   └── model.py
    │   └── tokenizers
    │       ├── bpe.py
    │       └── wordpiece.py
    └── rl
        ├── dqn
        │   ├── dqn.py
        │   └── main.py
        └── ppo
            ├── PPO.ipynb
            ├── __init__.py
            └── ppo.py

/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @BrianPulfer
2 |
--------------------------------------------------------------------------------
/.github/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | Personal re-implementations of Machine Learning papers. Re-implementations might use different hyper-parameters / datasets / settings compared to the original paper.
4 |
5 | Current re-implementations include:
6 |
7 | | Paper | Code | Blog |
8 | | ----------- | ----------- | ----------- |
9 | | Natural language processing |
10 | |[A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226v2)|[Code](/src/nlp/lm_watermarking/)|
11 | | [Attention is all you need](https://arxiv.org/abs/1706.03762) | [Code](/src/nlp/original/) | *Coming soon*
12 | | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) | [Code](/src/nlp/bert/) | *Coming soon*
13 | | [Language Modeling Is Compression](https://arxiv.org/abs/2309.10668) | [Code](/src/nlp/lm_is_compression/) |
14 | | [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)| [Code](/src/nlp/gpt/) | *Coming soon*
15 | | Computer Vision |
16 | |[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)| [Code](/src/cv/vit/) | [Blog](https://www.brianpulfer.ch/blog/vit)
17 | |[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | [Code](/src/cv/ddpm/) | [Blog](https://www.brianpulfer.ch/blog/ddpm)
18 | | [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)| [Code](/src/cv/nf/) |
19 | | [Idempotent Generative Network](https://arxiv.org/abs/2311.01462)| [Code](/src/cv/ign/) | [Blog](https://brianpulfer.ch/blog/ign)
20 | |[ViR: Vision Retention Networks](https://arxiv.org/abs/2310.19731)|[Code](/src/cv/vir/)| [Blog](https://brianpulfer.ch/blog/vir)
21 | | Reinforcement Learning |
22 | |[Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)| [Code](/src/rl/ppo/) | [Blog](https://www.brianpulfer.ch/blog/ppo) |
23 | | [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) | [Code](/src/rl/dqn/) | *Coming soon* |
24 | | Others |
25 | |[Everything is Connected: Graph Neural Networks](https://arxiv.org/abs/2301.08210)| [Code](/src/gnns/) |
26 | |[Fast Feedforward Networks](https://arxiv.org/abs/2308.14711)| [Code](/src/fff/) |
27 |
28 |
29 | # Contributing
30 | While this repo is a personal attempt to familiarize myself with these ideas down to the nitty-gritty details, contributions are welcome for re-implementations that are already in the repository. In particular, I am open to discussing doubts, questions, suggestions to improve the code, and spotted mistakes / bugs. If you would like to contribute, simply raise an issue before submitting a pull request.
31 |
32 | # License
33 | The code is released under the [MIT license](/LICENSE).
34 | -------------------------------------------------------------------------------- /.github/images/ddpm_both.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_both.gif -------------------------------------------------------------------------------- /.github/images/ddpm_fashion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_fashion.gif -------------------------------------------------------------------------------- /.github/images/ddpm_mnist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_mnist.gif -------------------------------------------------------------------------------- /.github/images/dqn_cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/dqn_cartpole.gif -------------------------------------------------------------------------------- /.github/images/gnns_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/gnns_training.png -------------------------------------------------------------------------------- /.github/images/nf_generated_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/nf_generated_images.png -------------------------------------------------------------------------------- /.github/images/ppo_cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ppo_cartpole.gif -------------------------------------------------------------------------------- /.github/images/vit_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/vit_architecture.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom folders and files 2 | .DS_Store 3 | /datasets 4 | *.pt 5 | wandb/ 6 | checkpoints/ 7 | lightning_logs/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 23.3.0 11 | hooks: 12 | - id: black-jupyter 13 | 14 | - repo: https://github.com/pycqa/isort 15 | rev: 5.12.0 16 | hooks: 17 | - id: isort 18 | args: ["--profile", "black", "--filter-files"] 19 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Python: Current File", 6 | "type": "python", 7 | "request": "launch", 8 | "program": "${file}", 9 | "console": "integratedTerminal", 10 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 11 | "justMyCode": false 12 | }, 13 | { 14 | "name": "cv - ddpm", 15 | "type": "python", 16 | "request": "launch", 17 | "program": "${workspaceFolder}/src/cv/ddpm/ddpm.py", 18 | "console": "integratedTerminal", 19 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 20 | "justMyCode": true 21 | }, 22 | { 23 | "name": "cv - ign", 24 | "type": "python", 25 | "request": "launch", 26 | "program": "${workspaceFolder}/src/cv/ign/main.py", 27 | "console": "integratedTerminal", 28 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 29 | "justMyCode": true 30 | }, 31 | { 32 | "name": "cv - normalizing flows", 33 | "type": "python", 34 | "request": "launch", 35 | "program": "${workspaceFolder}/src/cv/nf/normalizing_flows.py", 36 | "console": "integratedTerminal", 37 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 38 | "justMyCode": true 39 | }, 40 | { 41 | "name": "cv - vir", 42 | "type": "python", 43 | "request": "launch", 44 | "program": "${workspaceFolder}/src/cv/vir/train.py", 45 | "console": "integratedTerminal", 46 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 47 | "justMyCode": true 48 | }, 49 | { 50 | "name": "cv - vit", 51 | "type": "python", 52 | "request": "launch", 53 | "program": "${workspaceFolder}/src/cv/vit/vit_torch.py", 54 | "console": "integratedTerminal", 55 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 56 | "justMyCode": true 57 | }, 58 | { 59 | "name": "fff", 60 | "type": "python", 61 | "request": "launch", 62 | "program": "${workspaceFolder}/src/fff/main.py", 63 | "console": "integratedTerminal", 64 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 65 | "justMyCode": true 66 | }, 67 | { 68 | "name": "gnns", 69 | "type": "python", 70 | "request": "launch", 71 | "program": "${workspaceFolder}/src/gnns/gnns.py", 72 | "console": "integratedTerminal", 73 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 74 | "justMyCode": true 75 | }, 76 | { 77 | "name": "nlp - bert", 78 | "type": "python", 79 | "request": "launch", 80 | "program": "${workspaceFolder}/src/nlp/bert/main.py", 81 | "console": "integratedTerminal", 82 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 83 | "args": [ 84 | "--max_train_steps", "80", 85 | "--warmup_steps", "30", 86 | ], 87 | "justMyCode": true 88 | }, 89 | { 90 | "name": "nlp - gpt", 91 | "type": "python", 92 | "request": "launch", 93 | "program": "${workspaceFolder}/src/nlp/gpt/main.py", 94 | "console": "integratedTerminal", 95 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 96 | "args": [ 97 | "--max_train_steps", "80", 
98 | "--warmup_steps", "8", 99 | "--batch_size", "16", 100 | ], 101 | "justMyCode": true 102 | }, 103 | { 104 | "name": "nlp - lm compression", 105 | "type": "python", 106 | "request": "launch", 107 | "program": "${workspaceFolder}/src/nlp/lm_is_compression/lm_is_compression.py", 108 | "console": "integratedTerminal", 109 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 110 | "justMyCode": true 111 | }, 112 | { 113 | "name": "nlp - original", 114 | "type": "python", 115 | "request": "launch", 116 | "program": "${workspaceFolder}/src/nlp/original/main.py", 117 | "console": "integratedTerminal", 118 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 119 | "args": [ 120 | "--max_train_steps", "80", 121 | "--warmup_steps", "8", 122 | "--batch_size", "4", 123 | ], 124 | "justMyCode": true 125 | }, 126 | { 127 | "name": "nlp - watermark", 128 | "type": "python", 129 | "request": "launch", 130 | "program": "${workspaceFolder}/src/nlp/lm_watermarking/main.py", 131 | "console": "integratedTerminal", 132 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 133 | "justMyCode": true 134 | }, 135 | { 136 | "name": "rl - dqn", 137 | "type": "python", 138 | "request": "launch", 139 | "program": "${workspaceFolder}/src/rl/dqn/main.py", 140 | "console": "integratedTerminal", 141 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 142 | "justMyCode": true 143 | }, 144 | { 145 | "name": "rl - ppo", 146 | "type": "python", 147 | "request": "launch", 148 | "program": "${workspaceFolder}/src/rl/ppo/ppo.py", 149 | "console": "integratedTerminal", 150 | "env": {"PYTHONPATH": "${workspaceFolder}"}, 151 | "justMyCode": true 152 | } 153 | ] 154 | } 155 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Peutlefaire 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/nlp/bert/masked_sentences.txt: -------------------------------------------------------------------------------- 1 | Paris is the [MASK] of France. 2 | World War II started in [MASK] and ended in [MASK]. 3 | The term AI stands for [MASK] Intelligence. 
4 | -------------------------------------------------------------------------------- /data/nlp/gpt/prompts.txt: -------------------------------------------------------------------------------- 1 | A dice has 6 sides, and 2 | Bern is the capital of 3 | Throughout a year there are four seasons: 4 | -------------------------------------------------------------------------------- /data/nlp/original/translate_sentences.txt: -------------------------------------------------------------------------------- 1 | Der Stift liegt auf dem Tisch. 2 | Ich kenne kein Alter, keine Müdigkeit, keine Niederlage. 3 | Man weiß nicht, was man nicht weiß. 4 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: reimplementations 2 | channels: 3 | - nvidia 4 | - pytorch-nightly 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py311h9bf148f_1002 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2023.08.22=h06a4308_0 13 | - certifi=2023.7.22=py311h06a4308_0 14 | - cffi=1.15.1=py311h9bf148f_3 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=38.0.4=py311h46ebde7_0 17 | - cuda-cudart=12.1.105=0 18 | - cuda-cupti=12.1.105=0 19 | - cuda-libraries=12.1.0=0 20 | - cuda-nvrtc=12.1.105=0 21 | - cuda-nvtx=12.1.105=0 22 | - cuda-opencl=12.2.140=0 23 | - cuda-runtime=12.1.0=0 24 | - ffmpeg=4.2.2=h20bf706_0 25 | - freetype=2.12.1=h4a9f257_0 26 | - giflib=5.2.1=h5eee18b_3 27 | - gmp=6.2.1=h295c915_3 28 | - gmpy2=2.1.2=py311hc9b5ff0_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - idna=3.4=py311h06a4308_0 31 | - intel-openmp=2021.4.0=h06a4308_3561 32 | - jinja2=3.1.2=py311h06a4308_0 33 | - jpeg=9e=h5eee18b_1 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=12.1.0.26=0 39 | - libcufft=11.0.2.4=0 40 | - libcufile=1.7.2.10=0 41 | - libcurand=10.3.3.141=0 42 | - libcusolver=11.4.4.55=0 43 | - libcusparse=12.0.2.55=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_0 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libidn2=2.3.4=h5eee18b_0 49 | - libjpeg-turbo=2.0.0=h9bf148f_0 50 | - libnpp=12.0.2.50=0 51 | - libnvjitlink=12.1.105=0 52 | - libnvjpeg=12.1.1.14=0 53 | - libopus=1.3.1=h7b6447c_0 54 | - libpng=1.6.39=h5eee18b_0 55 | - libstdcxx-ng=11.2.0=h1234567_1 56 | - libtasn1=4.19.0=h5eee18b_0 57 | - libtiff=4.5.1=h6a678d5_0 58 | - libunistring=0.9.10=h27cfd23_0 59 | - libuuid=1.41.5=h5eee18b_0 60 | - libvpx=1.7.0=h439df22_0 61 | - libwebp=1.2.4=h11a3e52_1 62 | - libwebp-base=1.2.4=h5eee18b_1 63 | - llvm-openmp=14.0.6=h9e868ea_0 64 | - lz4-c=1.9.4=h6a678d5_0 65 | - markupsafe=2.1.1=py311h5eee18b_0 66 | - mkl=2021.4.0=h06a4308_640 67 | - mkl-service=2.4.0=py311h9bf148f_0 68 | - mkl_fft=1.3.1=py311hc796f24_0 69 | - mkl_random=1.2.2=py311hbba84a0_0 70 | - mpc=1.1.0=h10f8cd9_1 71 | - mpfr=4.0.2=hb69a4c5_1 72 | - mpmath=1.2.1=py311_0 73 | - ncurses=6.4=h6a678d5_0 74 | - nettle=3.7.3=hbbd107a_1 75 | - networkx=3.1=py311h06a4308_0 76 | - numpy=1.24.3=py311hc206e33_0 77 | - numpy-base=1.24.3=py311hfd5febd_0 78 | - openh264=2.1.1=h4ff587b_0 79 | - openssl=3.0.11=h7f8727e_2 80 | - pillow=9.3.0=py311h3fd9d12_2 81 | - pip=23.2.1=py311h06a4308_0 82 | - pycparser=2.21=pyhd3eb1b0_0 83 | - pyopenssl=23.2.0=py311h06a4308_0 84 | - pysocks=1.7.1=py311_0 85 | - python=3.11.5=h955ad1f_0 86 | - 
pytorch=2.2.0.dev20231001=py3.11_cuda12.1_cudnn8.9.2_0 87 | - pytorch-cuda=12.1=ha16c6d3_5 88 | - pytorch-mutex=1.0=cuda 89 | - pyyaml=6.0=py311h5eee18b_1 90 | - readline=8.2=h5eee18b_0 91 | - requests=2.28.1=py311_0 92 | - six=1.16.0=pyhd3eb1b0_1 93 | - sqlite=3.41.2=h5eee18b_0 94 | - sympy=1.11.1=py311h06a4308_0 95 | - tk=8.6.12=h1ccaba5_0 96 | - torchaudio=2.2.0.dev20231001=py311_cu121 97 | - torchtriton=2.1.0+6e4932cda8=py311 98 | - torchvision=0.17.0.dev20231001=py311_cu121 99 | - typing_extensions=4.7.1=py311h06a4308_0 100 | - urllib3=1.26.14=py311_0 101 | - wheel=0.41.2=py311h06a4308_0 102 | - x264=1!157.20191217=h7b6447c_0 103 | - xz=5.4.2=h5eee18b_0 104 | - yaml=0.2.5=h7b6447c_0 105 | - zlib=1.2.13=h5eee18b_0 106 | - zstd=1.5.5=hc292b87_0 107 | - pip: 108 | - absl-py==2.0.0 109 | - accelerate==0.23.0 110 | - aiohttp==3.8.5 111 | - aiosignal==1.3.1 112 | - apache-beam==2.50.0 113 | - appdirs==1.4.4 114 | - async-timeout==4.0.3 115 | - attrs==23.1.0 116 | - black==23.9.1 117 | - blis==0.7.11 118 | - bpytop==1.0.68 119 | - cachetools==5.3.1 120 | - catalogue==2.0.10 121 | - cfgv==3.4.0 122 | - click==8.1.7 123 | - cloudpathlib==0.15.1 124 | - cloudpickle==2.2.1 125 | - confection==0.1.3 126 | - contourpy==1.1.1 127 | - crcmod==1.7 128 | - cycler==0.12.0 129 | - cymem==2.0.8 130 | - datasets==2.14.5 131 | - deepspeed==0.10.3 132 | - dill==0.3.1.1 133 | - distlib==0.3.7 134 | - dnspython==2.4.2 135 | - docker-pycreds==0.4.0 136 | - docopt==0.6.2 137 | - einops==0.7.0 138 | - fastavro==1.8.4 139 | - fasteners==0.19 140 | - filelock==3.12.4 141 | - flake8==6.1.0 142 | - fonttools==4.43.0 143 | - frozenlist==1.4.0 144 | - fsspec==2023.6.0 145 | - gitdb==4.0.10 146 | - gitpython==3.1.37 147 | - google-auth==2.23.2 148 | - google-auth-oauthlib==1.0.0 149 | - grpcio==1.59.0 150 | - gym==0.26.2 151 | - gym-notices==0.0.8 152 | - hdfs==2.7.2 153 | - hjson==3.1.0 154 | - httplib2==0.22.0 155 | - huggingface-hub==0.16.4 156 | - identify==2.5.30 157 | - imageio==2.31.5 158 | - importlib-metadata==6.8.0 159 | - isort==5.12.0 160 | - jedi==0.19.1 161 | - kiwisolver==1.4.5 162 | - langcodes==3.3.0 163 | - lightning-utilities==0.9.0 164 | - markdown==3.4.4 165 | - matplotlib==3.8.0 166 | - mccabe==0.7.0 167 | - multidict==6.0.4 168 | - multiprocess==0.70.15 169 | - murmurhash==1.0.10 170 | - mwparserfromhell==0.6.5 171 | - mypy-extensions==1.0.0 172 | - ninja==1.11.1 173 | - nodeenv==1.8.0 174 | - nvidia-ml-py==12.535.108 175 | - nvitop==1.3.0 176 | - oauthlib==3.2.2 177 | - objsize==0.6.1 178 | - orjson==3.9.7 179 | - packaging==23.2 180 | - pandas==2.1.1 181 | - parso==0.8.3 182 | - pathspec==0.11.2 183 | - pathtools==0.1.2 184 | - pathy==0.10.2 185 | - platformdirs==3.10.0 186 | - pre-commit==3.4.0 187 | - preshed==3.0.9 188 | - prompt-toolkit==3.0.39 189 | - proto-plus==1.22.3 190 | - protobuf==4.23.4 191 | - psutil==5.9.5 192 | - ptpython==3.0.23 193 | - py-cpuinfo==9.0.0 194 | - pyarrow==11.0.0 195 | - pyasn1==0.5.0 196 | - pyasn1-modules==0.3.0 197 | - pycodestyle==2.11.0 198 | - pydantic==1.10.13 199 | - pydot==1.4.2 200 | - pyflakes==3.1.0 201 | - pygments==2.16.1 202 | - pymongo==4.5.0 203 | - pyparsing==3.1.1 204 | - python-dateutil==2.8.2 205 | - python-dotenv==1.0.0 206 | - pytorch-lightning==2.0.9.post0 207 | - pytz==2023.3.post1 208 | - regex==2023.10.3 209 | - requests-oauthlib==1.3.1 210 | - rsa==4.9 211 | - safetensors==0.3.3 212 | - seaborn==0.13.0 213 | - sentencepiece==0.1.99 214 | - sentry-sdk==1.31.0 215 | - setproctitle==1.3.3 216 | - setuptools==68.2.2 217 | - 
smart-open==6.4.0 218 | - smmap==5.0.1 219 | - spacy==3.7.0 220 | - spacy-legacy==3.0.12 221 | - spacy-loggers==1.0.5 222 | - srsly==2.4.8 223 | - tensorboard==2.14.1 224 | - tensorboard-data-server==0.7.1 225 | - termcolor==2.3.0 226 | - thinc==8.2.1 227 | - tokenizers==0.14.0 228 | - tomli==2.0.1 229 | - torchmetrics==1.2.0 230 | - tqdm==4.66.1 231 | - transformers==4.34.0 232 | - typer==0.7.0 233 | - tzdata==2023.3 234 | - virtualenv==20.24.5 235 | - wandb==0.15.11 236 | - wasabi==1.1.2 237 | - wcwidth==0.2.8 238 | - weasel==0.3.1 239 | - werkzeug==3.0.0 240 | - xxhash==3.3.0 241 | - yapf==0.40.2 242 | - yarl==1.9.2 243 | - zipp==3.17.0 244 | - zstandard==0.21.0 245 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | apache_beam 3 | datasets 4 | deepspeed 5 | einops 6 | gymnasium 7 | gymnasium[Box2D] 8 | imageio 9 | matplotlib 10 | mwparserfromhell 11 | numpy 12 | pandas 13 | pillow 14 | pre-commit 15 | python-dotenv 16 | pytorch-lightning 17 | sacremoses 18 | seaborn 19 | spacy 20 | tensorboard 21 | torch 22 | torchvision 23 | tqdm 24 | transformers 25 | wandb 26 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/__init__.py -------------------------------------------------------------------------------- /src/cv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/__init__.py -------------------------------------------------------------------------------- /src/cv/ddpm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/ddpm/__init__.py -------------------------------------------------------------------------------- /src/cv/ddpm/ddpm.py: -------------------------------------------------------------------------------- 1 | # Import of libraries 2 | import random 3 | from argparse import ArgumentParser 4 | 5 | import einops 6 | import imageio 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | from torchvision.datasets.mnist import MNIST, FashionMNIST 14 | from torchvision.transforms import Compose, Lambda, ToTensor 15 | from tqdm.auto import tqdm 16 | 17 | # Import of custom models 18 | from src.cv.ddpm.models import MyDDPM, MyUNet 19 | 20 | # Setting reproducibility 21 | SEED = 0 22 | random.seed(SEED) 23 | np.random.seed(SEED) 24 | torch.manual_seed(SEED) 25 | 26 | # Definitions 27 | STORE_PATH_MNIST = f"ddpm_model_mnist.pt" 28 | STORE_PATH_FASHION = f"ddpm_model_fashion.pt" 29 | 30 | 31 | def show_images(images, title=""): 32 | """Shows the provided images as sub-pictures in a square""" 33 | 34 | # Converting images to CPU numpy arrays 35 | if type(images) is torch.Tensor: 36 | images = images.detach().cpu().numpy() 37 | 38 | # Defining number of rows and columns 39 | fig = plt.figure(figsize=(8, 8)) 40 | rows = int(len(images) ** (1 / 2)) 41 | cols = 
round(len(images) / rows) 42 | 43 | # Populating figure with sub-plots 44 | idx = 0 45 | for r in range(rows): 46 | for c in range(cols): 47 | fig.add_subplot(rows, cols, idx + 1) 48 | 49 | if idx < len(images): 50 | plt.imshow(images[idx][0], cmap="gray") 51 | idx += 1 52 | fig.suptitle(title, fontsize=30) 53 | 54 | # Showing the figure 55 | plt.show() 56 | 57 | 58 | def show_first_batch(loader): 59 | for batch in loader: 60 | show_images(batch[0], "Images in the first batch") 61 | break 62 | 63 | 64 | def show_forward(ddpm, loader, device): 65 | # Showing the forward process 66 | for batch in loader: 67 | imgs = batch[0] 68 | 69 | show_images(imgs, "Original images") 70 | 71 | for percent in [0.25, 0.5, 0.75, 1]: 72 | show_images( 73 | ddpm( 74 | imgs.to(device), 75 | [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))], 76 | ), 77 | f"DDPM Noisy images {int(percent * 100)}%", 78 | ) 79 | break 80 | 81 | 82 | def generate_new_images( 83 | ddpm, 84 | n_samples=16, 85 | device=None, 86 | frames_per_gif=100, 87 | gif_name="sampling.gif", 88 | c=1, 89 | h=28, 90 | w=28, 91 | ): 92 | """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples""" 93 | frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint) 94 | frames = [] 95 | 96 | with torch.no_grad(): 97 | if device is None: 98 | device = ddpm.device 99 | 100 | # Starting from random noise 101 | x = torch.randn(n_samples, c, h, w).to(device) 102 | 103 | for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]): 104 | # Estimating noise to be removed 105 | time_tensor = (torch.ones(n_samples, 1) * t).to(device).long() 106 | eta_theta = ddpm.backward(x, time_tensor) 107 | 108 | alpha_t = ddpm.alphas[t] 109 | alpha_t_bar = ddpm.alpha_bars[t] 110 | 111 | # Partially denoising the image 112 | x = (1 / alpha_t.sqrt()) * ( 113 | x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta 114 | ) 115 | 116 | if t > 0: 117 | z = torch.randn(n_samples, c, h, w).to(device) 118 | 119 | # Option 1: sigma_t squared = beta_t 120 | beta_t = ddpm.betas[t] 121 | sigma_t = beta_t.sqrt() 122 | 123 | # Option 2: sigma_t squared = beta_tilda_t 124 | # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0] 125 | # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t 126 | # sigma_t = beta_tilda_t.sqrt() 127 | 128 | # Adding some more noise like in Langevin Dynamics fashion 129 | x = x + sigma_t * z 130 | 131 | # Adding frames to the GIF 132 | if idx in frame_idxs or t == 0: 133 | # Putting digits in range [0, 255] 134 | normalized = x.clone() 135 | for i in range(len(normalized)): 136 | normalized[i] -= torch.min(normalized[i]) 137 | normalized[i] *= 255 / torch.max(normalized[i]) 138 | 139 | # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame 140 | frame = einops.rearrange( 141 | normalized, 142 | "(b1 b2) c h w -> (b1 h) (b2 w) c", 143 | b1=int(n_samples**0.5), 144 | ) 145 | frame = frame.cpu().numpy().astype(np.uint8) 146 | 147 | # Rendering frame 148 | frames.append(frame) 149 | 150 | # Storing the gif 151 | with imageio.get_writer(gif_name, mode="I") as writer: 152 | for idx, frame in enumerate(frames): 153 | rgb_frame = np.repeat(frame, 3, axis=2) 154 | writer.append_data(rgb_frame) 155 | 156 | # Showing the last frame for a longer time 157 | if idx == len(frames) - 1: 158 | last_rgb_frame = np.repeat(frames[-1], 3, axis=2) 159 | for _ in range(frames_per_gif // 3): 160 | writer.append_data(last_rgb_frame) 161 | return x 162 | 163 | 164 | 
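# A quick usage sketch (comments only, not executed; assumes a trained `ddpm` instance).
# The reverse loop above implements the DDPM sampling step
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
# with sigma_t^2 = beta_t ("Option 1" above), while collecting intermediate frames into a GIF:
#
#   samples = generate_new_images(ddpm, n_samples=16, device=ddpm.device, gif_name="samples.gif")
#   show_images(samples, "Generated samples")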
def training_loop( 165 | ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt" 166 | ): 167 | mse = nn.MSELoss() 168 | best_loss = float("inf") 169 | n_steps = ddpm.n_steps 170 | 171 | for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"): 172 | epoch_loss = 0.0 173 | for step, batch in enumerate( 174 | tqdm( 175 | loader, 176 | leave=False, 177 | desc=f"Epoch {epoch + 1}/{n_epochs}", 178 | colour="#005500", 179 | ) 180 | ): 181 | # Loading data 182 | x0 = batch[0].to(device) 183 | n = len(x0) 184 | 185 | # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars 186 | eta = torch.randn_like(x0).to(device) 187 | t = torch.randint(0, n_steps, (n,)).to(device) 188 | 189 | # Computing the noisy image based on x0 and the time-step (forward process) 190 | noisy_imgs = ddpm(x0, t, eta) 191 | 192 | # Getting model estimation of noise based on the images and the time-step 193 | eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) 194 | 195 | # Optimizing the MSE between the noise plugged and the predicted noise 196 | loss = mse(eta_theta, eta) 197 | optim.zero_grad() 198 | loss.backward() 199 | optim.step() 200 | 201 | epoch_loss += loss.item() * len(x0) / len(loader.dataset) 202 | 203 | # Display images generated at this epoch 204 | if display: 205 | show_images( 206 | generate_new_images(ddpm, device=device), 207 | f"Images generated at epoch {epoch + 1}", 208 | ) 209 | 210 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}" 211 | 212 | # Storing the model 213 | if best_loss > epoch_loss: 214 | best_loss = epoch_loss 215 | torch.save(ddpm.state_dict(), store_path) 216 | log_string += " --> Best model ever (stored)" 217 | 218 | print(log_string) 219 | 220 | 221 | def main(): 222 | # Program arguments 223 | parser = ArgumentParser() 224 | parser.add_argument( 225 | "--no_train", action="store_true", help="Whether to train a new model or not" 226 | ) 227 | parser.add_argument( 228 | "--fashion", 229 | action="store_true", 230 | help="Uses MNIST if true, Fashion MNIST otherwise", 231 | ) 232 | parser.add_argument("--bs", type=int, default=128, help="Batch size") 233 | parser.add_argument("--epochs", type=int, default=20, help="Number of epochs") 234 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") 235 | args = vars(parser.parse_args()) 236 | print(args) 237 | 238 | # Model store path 239 | store_path = "ddpm_fashion.pt" if args["fashion"] else "ddpm_mnist.pt" 240 | 241 | # Loading the data (converting each image into a tensor and normalizing between [-1, 1]) 242 | transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)]) 243 | ds_fn = MNIST if not args["fashion"] else FashionMNIST 244 | dataset = ds_fn("./../datasets", download=True, train=True, transform=transform) 245 | loader = DataLoader(dataset, args["bs"], shuffle=True) 246 | 247 | # Getting device 248 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 249 | print( 250 | f"Using device: {device}\t" 251 | + (f"{torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "CPU") 252 | ) 253 | 254 | # Defining model 255 | n_steps, min_beta, max_beta = 1000, 10**-4, 0.02 # Originally used by the authors 256 | ddpm = MyDDPM( 257 | MyUNet(n_steps), 258 | n_steps=n_steps, 259 | min_beta=min_beta, 260 | max_beta=max_beta, 261 | device=device, 262 | ) 263 | 264 | # Optionally, load a pre-trained model that will be further trained 265 | # ddpm.load_state_dict(torch.load(store_path, 
map_location=device)) 266 | 267 | # Optionally, show a batch of regular images 268 | # show_first_batch(loader) 269 | 270 | # Optionally, show the diffusion (forward) process 271 | # show_forward(ddpm, loader, device) 272 | 273 | # Optionally, show the denoising (backward) process 274 | # generated = generate_new_images(ddpm, gif_name="before_training.gif") 275 | # show_images(generated, "Images generated before training") 276 | 277 | # Training 278 | if not args["no_train"]: 279 | n_epochs, lr = args["epochs"], args["lr"] 280 | training_loop( 281 | ddpm, 282 | loader, 283 | n_epochs, 284 | optim=Adam(ddpm.parameters(), lr), 285 | device=device, 286 | store_path=store_path, 287 | ) 288 | 289 | # Loading the trained model 290 | best_model = MyDDPM(MyUNet(), n_steps=n_steps, device=device) 291 | best_model.load_state_dict(torch.load(store_path, map_location=device)) 292 | best_model.eval() 293 | print("Model loaded: Generating new images") 294 | 295 | # Showing generated images 296 | generated = generate_new_images( 297 | best_model, 298 | n_samples=100, 299 | device=device, 300 | gif_name="fashion.gif" if args["fashion"] else "mnist.gif", 301 | ) 302 | show_images(generated, "Final result") 303 | 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /src/cv/ddpm/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def sinusoidal_embedding(n, d): 6 | # Returns the standard positional embedding 7 | embedding = torch.zeros(n, d) 8 | wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)]) 9 | wk = wk.reshape((1, d)) 10 | t = torch.arange(n).reshape((n, 1)) 11 | embedding[:, ::2] = torch.sin(t * wk[:, ::2]) 12 | embedding[:, 1::2] = torch.cos(t * wk[:, ::2]) 13 | 14 | return embedding 15 | 16 | 17 | # DDPM class 18 | class MyDDPM(nn.Module): 19 | def __init__( 20 | self, 21 | network, 22 | n_steps=200, 23 | min_beta=10**-4, 24 | max_beta=0.02, 25 | device=None, 26 | image_chw=(1, 28, 28), 27 | ): 28 | super(MyDDPM, self).__init__() 29 | self.n_steps = n_steps 30 | self.device = device 31 | self.image_chw = image_chw 32 | self.network = network.to(device) 33 | self.betas = torch.linspace(min_beta, max_beta, n_steps).to( 34 | device 35 | ) # Number of steps is typically in the order of thousands 36 | self.alphas = 1 - self.betas 37 | self.alpha_bars = torch.tensor( 38 | [torch.prod(self.alphas[: i + 1]) for i in range(len(self.alphas))] 39 | ).to(device) 40 | 41 | def forward(self, x0, t, eta=None): 42 | # Make input image more noisy (we can directly skip to the desired step) 43 | n, c, h, w = x0.shape 44 | a_bar = self.alpha_bars[t] 45 | 46 | if eta is None: 47 | eta = torch.randn(n, c, h, w).to(self.device) 48 | 49 | noisy = ( 50 | a_bar.sqrt().reshape(n, 1, 1, 1) * x0 51 | + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta 52 | ) 53 | return noisy 54 | 55 | def backward(self, x, t): 56 | # Run each image through the network for each timestep t in the vector t. 57 | # The network returns its estimation of the noise that was added. 
58 | return self.network(x, t) 59 | 60 | 61 | class MyBlock(nn.Module): 62 | def __init__( 63 | self, 64 | shape, 65 | in_c, 66 | out_c, 67 | kernel_size=3, 68 | stride=1, 69 | padding=1, 70 | activation=None, 71 | normalize=True, 72 | ): 73 | super(MyBlock, self).__init__() 74 | self.ln = nn.LayerNorm(shape) 75 | self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding) 76 | self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding) 77 | self.activation = nn.SiLU() if activation is None else activation 78 | self.normalize = normalize 79 | 80 | def forward(self, x): 81 | out = self.ln(x) if self.normalize else x 82 | out = self.conv1(out) 83 | out = self.activation(out) 84 | out = self.conv2(out) 85 | out = self.activation(out) 86 | return out 87 | 88 | 89 | class MyUNet(nn.Module): 90 | def __init__(self, n_steps=1000, time_emb_dim=100): 91 | super(MyUNet, self).__init__() 92 | 93 | # Sinusoidal embedding 94 | self.time_embed = nn.Embedding(n_steps, time_emb_dim) 95 | self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) 96 | self.time_embed.requires_grad_(False) 97 | 98 | # First half 99 | self.te1 = self._make_te(time_emb_dim, 1) 100 | self.b1 = nn.Sequential( 101 | MyBlock((1, 28, 28), 1, 10), 102 | MyBlock((10, 28, 28), 10, 10), 103 | MyBlock((10, 28, 28), 10, 10), 104 | ) 105 | self.down1 = nn.Conv2d(10, 10, 4, 2, 1) 106 | 107 | self.te2 = self._make_te(time_emb_dim, 10) 108 | self.b2 = nn.Sequential( 109 | MyBlock((10, 14, 14), 10, 20), 110 | MyBlock((20, 14, 14), 20, 20), 111 | MyBlock((20, 14, 14), 20, 20), 112 | ) 113 | self.down2 = nn.Conv2d(20, 20, 4, 2, 1) 114 | 115 | self.te3 = self._make_te(time_emb_dim, 20) 116 | self.b3 = nn.Sequential( 117 | MyBlock((20, 7, 7), 20, 40), 118 | MyBlock((40, 7, 7), 40, 40), 119 | MyBlock((40, 7, 7), 40, 40), 120 | ) 121 | self.down3 = nn.Sequential( 122 | nn.Conv2d(40, 40, 2, 1), nn.SiLU(), nn.Conv2d(40, 40, 4, 2, 1) 123 | ) 124 | 125 | # Bottleneck 126 | self.te_mid = self._make_te(time_emb_dim, 40) 127 | self.b_mid = nn.Sequential( 128 | MyBlock((40, 3, 3), 40, 20), 129 | MyBlock((20, 3, 3), 20, 20), 130 | MyBlock((20, 3, 3), 20, 40), 131 | ) 132 | 133 | # Second half 134 | self.up1 = nn.Sequential( 135 | nn.ConvTranspose2d(40, 40, 4, 2, 1), 136 | nn.SiLU(), 137 | nn.ConvTranspose2d(40, 40, 2, 1), 138 | ) 139 | 140 | self.te4 = self._make_te(time_emb_dim, 80) 141 | self.b4 = nn.Sequential( 142 | MyBlock((80, 7, 7), 80, 40), 143 | MyBlock((40, 7, 7), 40, 20), 144 | MyBlock((20, 7, 7), 20, 20), 145 | ) 146 | 147 | self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1) 148 | self.te5 = self._make_te(time_emb_dim, 40) 149 | self.b5 = nn.Sequential( 150 | MyBlock((40, 14, 14), 40, 20), 151 | MyBlock((20, 14, 14), 20, 10), 152 | MyBlock((10, 14, 14), 10, 10), 153 | ) 154 | 155 | self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1) 156 | self.te_out = self._make_te(time_emb_dim, 20) 157 | self.b_out = nn.Sequential( 158 | MyBlock((20, 28, 28), 20, 10), 159 | MyBlock((10, 28, 28), 10, 10), 160 | MyBlock((10, 28, 28), 10, 10, normalize=False), 161 | ) 162 | 163 | self.conv_out = nn.Conv2d(10, 1, 3, 1, 1) 164 | 165 | def forward(self, x, t): 166 | # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension) 167 | t = self.time_embed(t) 168 | n = len(x) 169 | out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1)) # (N, 10, 28, 28) 170 | out2 = self.b2( 171 | self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1) 172 | ) # (N, 20, 14, 14) 173 | out3 = self.b3( 174 | self.down2(out2) + self.te3(t).reshape(n, -1, 
1, 1) 175 | ) # (N, 40, 7, 7) 176 | 177 | out_mid = self.b_mid( 178 | self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1) 179 | ) # (N, 40, 3, 3) 180 | 181 | out4 = torch.cat((out3, self.up1(out_mid)), dim=1) # (N, 80, 7, 7) 182 | out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1)) # (N, 20, 7, 7) 183 | 184 | out5 = torch.cat((out2, self.up2(out4)), dim=1) # (N, 40, 14, 14) 185 | out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1)) # (N, 10, 14, 14) 186 | 187 | out = torch.cat((out1, self.up3(out5)), dim=1) # (N, 20, 28, 28) 188 | out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1)) # (N, 1, 28, 28) 189 | 190 | out = self.conv_out(out) 191 | 192 | return out 193 | 194 | def _make_te(self, dim_in, dim_out): 195 | return nn.Sequential( 196 | nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out) 197 | ) 198 | -------------------------------------------------------------------------------- /src/cv/ign/ign.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytorch_lightning as pl 4 | from torch.nn import L1Loss 5 | from torch.optim import Adam 6 | 7 | 8 | class IdempotentNetwork(pl.LightningModule): 9 | def __init__( 10 | self, 11 | prior, 12 | model, 13 | lr=1e-4, 14 | criterion=L1Loss(), 15 | lrec_w=20.0, 16 | lidem_w=20.0, 17 | ltight_w=2.5, 18 | ): 19 | super(IdempotentNetwork, self).__init__() 20 | self.prior = prior 21 | self.model = model 22 | self.model_copy = deepcopy(model) 23 | self.lr = lr 24 | self.criterion = criterion 25 | self.lrec_w = lrec_w 26 | self.lidem_w = lidem_w 27 | self.ltight_w = ltight_w 28 | 29 | def forward(self, x): 30 | return self.model(x) 31 | 32 | def configure_optimizers(self): 33 | optim = Adam(self.model.parameters(), lr=self.lr, betas=(0.5, 0.999)) 34 | return optim 35 | 36 | def get_losses(self, x): 37 | # Prior samples 38 | z = self.prior.sample_n(x.shape[0]).to(x.device) 39 | 40 | # Updating the copy 41 | self.model_copy.load_state_dict(self.model.state_dict()) 42 | 43 | # Forward passes 44 | fx = self(x) 45 | fz = self(z) 46 | fzd = fz.detach() 47 | 48 | l_rec = self.lrec_w * self.criterion(fx, x) 49 | l_idem = self.lidem_w * self.criterion(self.model_copy(fz), fz) 50 | l_tight = -self.ltight_w * self.criterion(self(fzd), fzd) 51 | 52 | return l_rec, l_idem, l_tight 53 | 54 | def training_step(self, batch, batch_idx): 55 | l_rec, l_idem, l_tight = self.get_losses(batch) 56 | loss = l_rec + l_idem + l_tight 57 | 58 | self.log_dict( 59 | { 60 | "train/loss_rec": l_rec, 61 | "train/loss_idem": l_idem, 62 | "train/loss_tight": l_tight, 63 | "train/loss": l_rec + l_idem + l_tight, 64 | }, 65 | sync_dist=True, 66 | ) 67 | 68 | return loss 69 | 70 | def validation_step(self, batch, batch_idx): 71 | l_rec, l_idem, l_tight = self.get_losses(batch) 72 | loss = l_rec + l_idem + l_tight 73 | 74 | self.log_dict( 75 | { 76 | "val/loss_rec": l_rec, 77 | "val/loss_idem": l_idem, 78 | "val/loss_tight": l_tight, 79 | "val/loss": loss, 80 | }, 81 | sync_dist=True, 82 | ) 83 | 84 | def test_step(self, batch, batch_idx): 85 | l_rec, l_idem, l_tight = self.get_losses(batch) 86 | loss = l_rec + l_idem + l_tight 87 | 88 | self.log_dict( 89 | { 90 | "test/loss_rec": l_rec, 91 | "test/loss_idem": l_idem, 92 | "test/loss_tight": l_tight, 93 | "test/loss": loss, 94 | }, 95 | sync_dist=True, 96 | ) 97 | 98 | def generate_n(self, n, device=None): 99 | z = self.prior.sample_n(n) 100 | 101 | if device is not None: 102 | z = z.to(device) 103 | 104 | return self(z) 105 | 
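# Usage sketch (comments only; assumes `torch` and DCGANLikeModel from src.cv.ign.model are
# imported, as done in main.py):
#
#   prior = torch.distributions.Normal(torch.zeros(1, 28, 28), torch.ones(1, 28, 28))
#   ign = IdempotentNetwork(prior, DCGANLikeModel())
#   samples = ign.generate_n(16)
#
# Note on the loss terms above: the reconstruction loss keeps f(x) close to x on real data; the
# idempotence loss pushes f(f(z)) towards f(z), with the second application done by the frozen
# copy so that only the first application is optimized; the tightness loss enters with a negative
# sign and acts on the outer application of a detached f(z), counteracting the idempotence term
# so the model does not collapse onto a trivial fixed set.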
-------------------------------------------------------------------------------- /src/cv/ign/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.loggers import WandbLogger 8 | from torch.utils.data import DataLoader 9 | from torchvision.datasets import MNIST 10 | from torchvision.transforms import Compose, Lambda, ToTensor 11 | from torchvision.utils import save_image 12 | 13 | from src.cv.ign.ign import IdempotentNetwork 14 | from src.cv.ign.model import DCGANLikeModel 15 | 16 | 17 | def main(args): 18 | # Set seed 19 | pl.seed_everything(args["seed"]) 20 | 21 | # Load datas 22 | normalize = Lambda(lambda x: (x - 0.5) * 2) 23 | noise = Lambda(lambda x: (x + torch.randn_like(x) * 0.15).clamp(-1, 1)) 24 | train_transform = Compose([ToTensor(), normalize, noise]) 25 | val_transform = Compose([ToTensor(), normalize]) 26 | 27 | train_set = MNIST( 28 | root="data/mnist", train=True, download=True, transform=train_transform 29 | ) 30 | val_set = MNIST( 31 | root="data/mnist", train=False, download=True, transform=val_transform 32 | ) 33 | 34 | def collate_fn(samples): 35 | return torch.stack([sample[0] for sample in samples]) 36 | 37 | train_loader = DataLoader( 38 | train_set, 39 | batch_size=args["batch_size"], 40 | shuffle=True, 41 | collate_fn=collate_fn, 42 | num_workers=args["num_workers"], 43 | ) 44 | val_loader = DataLoader( 45 | val_set, 46 | batch_size=args["batch_size"], 47 | shuffle=False, 48 | collate_fn=collate_fn, 49 | num_workers=args["num_workers"], 50 | ) 51 | 52 | # Initialize model 53 | prior = torch.distributions.Normal(torch.zeros(1, 28, 28), torch.ones(1, 28, 28)) 54 | net = DCGANLikeModel() 55 | model = IdempotentNetwork(prior, net, args["lr"]) 56 | 57 | if not args["skip_train"]: 58 | # Train model 59 | logger = WandbLogger(name="IGN", project="Papers Re-implementations") 60 | callbacks = [ 61 | ModelCheckpoint( 62 | monitor="val/loss", 63 | mode="min", 64 | dirpath="checkpoints/ign", 65 | filename="best", 66 | ) 67 | ] 68 | trainer = pl.Trainer( 69 | strategy="ddp", 70 | accelerator="auto", 71 | max_epochs=args["epochs"], 72 | logger=logger, 73 | callbacks=callbacks, 74 | ) 75 | trainer.fit(model, train_loader, val_loader) 76 | 77 | # Loading the best model 78 | device = "cuda" if torch.cuda.is_available() else "cpu" 79 | model = ( 80 | IdempotentNetwork.load_from_checkpoint( 81 | "checkpoints/ign/best.ckpt", prior=prior, model=net 82 | ) 83 | .eval() 84 | .to(device) 85 | ) 86 | 87 | # Generating images with the trained model 88 | os.makedirs("generated", exist_ok=True) 89 | 90 | images = model.generate_n(100, device=device) 91 | save_image(images, "generated.png", nrow=10, normalize=True) 92 | 93 | print("Done!") 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = ArgumentParser() 98 | parser.add_argument("--seed", type=int, default=0) 99 | parser.add_argument("--lr", type=float, default=1e-4) 100 | parser.add_argument("--batch_size", type=int, default=256) 101 | parser.add_argument("--epochs", type=int, default=50) 102 | parser.add_argument("--num_workers", type=int, default=8) 103 | parser.add_argument("--skip_train", action="store_true") 104 | args = vars(parser.parse_args()) 105 | 106 | print("\n\n", args, "\n\n") 107 | main(args) 108 | -------------------------------------------------------------------------------- 
/src/cv/ign/model.py: -------------------------------------------------------------------------------- 1 | """DCGAN code from https://github.com/kpandey008/dcgan""" 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, in_channels=1, base_c=64): 7 | super(Discriminator, self).__init__() 8 | self.main = nn.Sequential( 9 | # Input Size: 1 x 28 x 28 10 | nn.Conv2d(in_channels, base_c, 4, 2, 1, bias=False), 11 | nn.Dropout2d(0.1), 12 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 13 | # Input Size: 32 x 14 x 14 14 | nn.BatchNorm2d(base_c), 15 | nn.Conv2d(base_c, base_c * 2, 4, 2, 1, bias=False), 16 | nn.Dropout2d(0.1), 17 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 18 | # Input Size: 64 x 7 x 7 19 | nn.BatchNorm2d(base_c * 2), 20 | nn.Conv2d(base_c * 2, base_c * 4, 3, 1, 0, bias=False), 21 | nn.Dropout2d(0.1), 22 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 23 | # Input Size: 128 x 7 x 7 24 | nn.BatchNorm2d(base_c * 4), 25 | nn.Conv2d(base_c * 4, base_c * 8, 3, 1, 0, bias=False), 26 | nn.Dropout2d(0.1), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 28 | # Input Size: 256 x 7 x 7 29 | nn.Conv2d(base_c * 8, base_c * 8, 3, 1, 0, bias=False), 30 | ) 31 | 32 | def forward(self, input): 33 | return self.main(input) 34 | 35 | 36 | class Generator(nn.Module): 37 | def __init__(self, in_channels=512, out_channels=1): 38 | super(Generator, self).__init__() 39 | self.main = nn.Sequential( 40 | # Input Size: 256 x 7 x 7 41 | nn.BatchNorm2d(in_channels), 42 | nn.ConvTranspose2d(in_channels, in_channels // 2, 3, 1, 0, bias=False), 43 | nn.Dropout2d(0.1), 44 | nn.ReLU(True), 45 | # Input Size: 128 x 7 x 7 46 | nn.BatchNorm2d(in_channels // 2), 47 | nn.ConvTranspose2d(in_channels // 2, in_channels // 4, 3, 1, 0, bias=False), 48 | nn.Dropout2d(0.1), 49 | nn.ReLU(True), 50 | # Input Size: 64 x 7 x 7 51 | nn.BatchNorm2d(in_channels // 4), 52 | nn.ConvTranspose2d(in_channels // 4, in_channels // 8, 3, 1, 0, bias=False), 53 | nn.Dropout2d(0.1), 54 | nn.ReLU(True), 55 | # Input Size: 32 x 14 x 14 56 | nn.BatchNorm2d(in_channels // 8), 57 | nn.ConvTranspose2d( 58 | in_channels // 8, in_channels // 16, 4, 2, 1, bias=False 59 | ), 60 | nn.Dropout2d(0.1), 61 | nn.ReLU(True), 62 | # Input Size : 16 x 28 x 28 63 | nn.ConvTranspose2d(in_channels // 16, out_channels, 4, 2, 1, bias=False), 64 | nn.Tanh(), 65 | # Final Output : 1 x 28 x 28 66 | ) 67 | 68 | def forward(self, input): 69 | return self.main(input) 70 | 71 | 72 | class DCGANLikeModel(nn.Module): 73 | def __init__(self, in_channels=1, base_c=64): 74 | super(DCGANLikeModel, self).__init__() 75 | self.discriminator = Discriminator(in_channels=in_channels, base_c=base_c) 76 | self.generator = Generator(base_c * 8, out_channels=in_channels) 77 | 78 | def forward(self, x): 79 | return self.generator(self.discriminator(x)) 80 | -------------------------------------------------------------------------------- /src/cv/nf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/nf/__init__.py -------------------------------------------------------------------------------- /src/cv/nf/normalizing_flows.py: -------------------------------------------------------------------------------- 1 | """ 2 | Personal reimplementation of 3 | Density estimation using Real NVP 4 | (https://arxiv.org/abs/1605.08803) 5 | 6 | Useful links: 7 | - 
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html 8 | """ 9 | 10 | import os 11 | from argparse import ArgumentParser 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | from PIL import Image 18 | from torch.optim import Adam 19 | from torch.optim.lr_scheduler import StepLR 20 | from torch.utils.data import DataLoader 21 | from torchvision.datasets import MNIST 22 | from torchvision.transforms import Compose, Lambda, ToTensor 23 | from torchvision.utils import save_image 24 | from tqdm.auto import tqdm 25 | 26 | # Seeding 27 | SEED = 17 28 | np.random.seed(SEED) 29 | torch.random.manual_seed(SEED) 30 | torch.use_deterministic_algorithms(True) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def test_reversability(model, x): 35 | """Tests that x ≈ model.backward(model.forward(x)) and shows images""" 36 | with torch.no_grad(): 37 | # Running input forward and backward 38 | z = model.forward(x)[0] 39 | x_tilda = model.backward(z)[0] 40 | 41 | # Printing MSE 42 | mse = ((x_tilda - x) ** 2).mean() 43 | print(f"MSE between input and reconstruction: {mse}") 44 | 45 | # Comparing images visually 46 | plt.imshow(x[0][0].cpu().numpy(), cmap="gray") 47 | plt.title("Original image") 48 | plt.show() 49 | 50 | plt.imshow(z[0][0].cpu().numpy(), cmap="gray") 51 | plt.title("After forward pass") 52 | plt.show() 53 | 54 | plt.imshow(x_tilda[0][0].cpu().numpy(), cmap="gray") 55 | plt.title("Reconstructed image") 56 | plt.show() 57 | 58 | 59 | class LayerNormChannels(nn.Module): 60 | def __init__(self, c_in, eps=1e-5): 61 | super().__init__() 62 | self.gamma = nn.Parameter(torch.ones(1, c_in, 1, 1)) 63 | self.beta = nn.Parameter(torch.zeros(1, c_in, 1, 1)) 64 | self.eps = eps 65 | 66 | def forward(self, x): 67 | mean = x.mean(dim=1, keepdim=True) 68 | var = x.var(dim=1, unbiased=False, keepdim=True) 69 | y = (x - mean) / torch.sqrt(var + self.eps) 70 | y = y * self.gamma + self.beta 71 | return y 72 | 73 | 74 | class CNNBlock(nn.Module): 75 | """A simple CNN architecture which will applied at each Affine Coupling step""" 76 | 77 | def __init__(self, n_channels, kernel_size=3): 78 | super(CNNBlock, self).__init__() 79 | self.elu = nn.ELU() 80 | 81 | self.conv1 = nn.Conv2d( 82 | 2 * n_channels, n_channels, kernel_size, 1, kernel_size // 2 83 | ) 84 | self.conv2 = nn.Conv2d(2 * n_channels, 2 * n_channels, 1, 1) 85 | 86 | def forward(self, x): 87 | out = torch.cat((self.elu(x), self.elu(-x)), dim=1) 88 | out = self.conv1(out) 89 | out = torch.cat((self.elu(out), self.elu(-out)), dim=1) 90 | out = self.conv2(out) 91 | val, gate = out.chunk(2, 1) 92 | return x + val * torch.sigmoid(gate) 93 | 94 | 95 | class SimpleCNN(nn.Module): 96 | def __init__(self, blocks=3, channels_in=1, channels_hidden=32, kernel_size=3): 97 | super(SimpleCNN, self).__init__() 98 | 99 | self.elu = nn.ELU() 100 | self.conv_in = nn.Conv2d(channels_in, channels_hidden, 3, 1, 1) 101 | self.net = nn.Sequential( 102 | *[ 103 | nn.Sequential( 104 | CNNBlock(channels_hidden, kernel_size), 105 | LayerNormChannels(channels_hidden), 106 | ) 107 | for _ in range(blocks) 108 | ] 109 | ) 110 | self.conv_out = nn.Conv2d(2 * channels_hidden, 2 * channels_in, 3, 1, 1) 111 | 112 | # Initializing final convolution weights to zeros 113 | self.conv_out.weight.data.zero_() 114 | self.conv_out.bias.data.zero_() 115 | 116 | def forward(self, x): 117 | out = self.net(self.conv_in(x)) 118 | out = torch.cat((self.elu(out), self.elu(-out)), 
dim=1) 119 | return self.conv_out(out) 120 | 121 | 122 | class Dequantization(nn.Module): 123 | """Dequantizes the image. Dequantization is the first step for flows, as it allows to not load datapoints 124 | with high likelihoods and put volume on other input data as well.""" 125 | 126 | def __init__(self, max_val): 127 | super(Dequantization, self).__init__() 128 | self.eps = 1e-5 129 | self.max_val = max_val 130 | self.sigmoid_fn = nn.Sigmoid() 131 | 132 | def sigmoid(self, x): 133 | return self.sigmoid_fn(x) 134 | 135 | def log_det_sigmoid(self, x): 136 | s = self.sigmoid(x) 137 | return torch.log(s - s**2) 138 | 139 | def inv_sigmoid(self, x): 140 | return -torch.log((x) ** -1 - 1) 141 | 142 | def log_det_inv_sigmoid(self, x): 143 | return torch.log(1 / (x - x**2)) 144 | 145 | def forward(self, x): 146 | # Dequantizing input (adding continuous noise in range [0, 1]) and putting in range [0, 1] 147 | x = x.to(torch.float32) 148 | log_det = ( 149 | -np.log(self.max_val) 150 | * np.prod(x.shape[1:]) 151 | * torch.ones(len(x)).to(x.device) 152 | ) 153 | out = (x + torch.rand_like(x).detach()) / self.max_val 154 | 155 | # Making sure the input is not too close to either 0 or 1 (bounds of inverse sigmoid) --> put closer to 0.5 156 | log_det += np.log(1 - self.eps) * np.prod(x.shape[1:]) 157 | out = (1 - self.eps) * out + self.eps * 0.5 158 | 159 | # Running the input through the inverse sigmoid function 160 | log_det += self.log_det_inv_sigmoid(out).sum(dim=[1, 2, 3]) 161 | out = self.inv_sigmoid(out) 162 | 163 | return out, log_det 164 | 165 | def backward(self, x): 166 | # Running through the Sigmoid function 167 | log_det = self.log_det_sigmoid(x).sum(dim=[1, 2, 3]) 168 | out = self.sigmoid(x) 169 | 170 | # Undoing the weighted sum 171 | log_det -= np.log(1 - self.eps) * np.prod(x.shape[1:]) 172 | out = (out - self.eps * 0.5) / (1 - self.eps) 173 | 174 | # Undoing the dequantization 175 | log_det += np.log(self.max_val) * np.prod(x.shape[1:]) 176 | out *= self.max_val 177 | out = torch.floor(out).clamp(min=0, max=self.max_val) 178 | 179 | return out, log_det 180 | 181 | 182 | class AffineCoupling(nn.Module): 183 | """Affine Coupling layer. 
Only modifies half of the input by running the other half through some non-linear function.""" 184 | 185 | def __init__(self, m: nn.Module, modify_x2=True, chw=(1, 28, 28)): 186 | super(AffineCoupling, self).__init__() 187 | self.m = m 188 | self.modify_x2 = modify_x2 189 | 190 | c, h, w = chw 191 | self.scaling_fac = nn.Parameter(torch.ones(c)) 192 | self.mask = torch.tensor( 193 | [[(j + k) % 2 == 0 for k in range(w)] for j in range(h)] 194 | ) 195 | self.mask = self.mask.unsqueeze(0).unsqueeze(0) 196 | 197 | if self.modify_x2: 198 | self.mask = ~self.mask 199 | 200 | def forward(self, x): 201 | # Splitting input in two halves 202 | mask = self.mask.to(x.device) 203 | x1 = mask * x 204 | 205 | # Computing scale and shift for x2 206 | scale, shift = self.m(x1).chunk(2, 1) # Non linear network 207 | s_fac = self.scaling_fac.exp().view(1, -1, 1, 1) 208 | scale = torch.tanh(scale / s_fac) * s_fac # Stabilizes training 209 | 210 | # Masking scale and shift 211 | scale = ~mask * scale 212 | shift = ~mask * shift 213 | 214 | # Computing output 215 | out = (x + shift) * torch.exp(scale) 216 | 217 | # Computing log of the determinant of the Jacobian 218 | log_det_j = torch.sum(scale, dim=[1, 2, 3]) 219 | 220 | return out, log_det_j 221 | 222 | def backward(self, y): 223 | # Splitting input 224 | mask = self.mask.to(y.device) 225 | 226 | x1 = mask * y 227 | 228 | # Computing scale and shift 229 | scale, shift = self.m(x1).chunk(2, 1) 230 | s_fac = self.scaling_fac.exp().view(1, -1, 1, 1) 231 | scale = torch.tanh(scale / s_fac) * s_fac 232 | 233 | # Masking scale and shift 234 | scale = ~mask * scale 235 | shift = ~mask * shift 236 | 237 | # Computing inverse transformation 238 | out = y / torch.exp(scale) - shift 239 | 240 | # Computing log of the determinant of the Jacobian (for backward tranformation) 241 | log_det_j = -torch.sum(scale, dim=[1, 2, 3]) 242 | 243 | return out, log_det_j 244 | 245 | 246 | class Flow(nn.Module): 247 | """General Flow model. Uses invertible layers to map distributions.""" 248 | 249 | def __init__(self, layers): 250 | super(Flow, self).__init__() 251 | self.layers = nn.ModuleList(layers) 252 | 253 | def forward(self, x): 254 | # Computing forward pass (images --> gaussian noise) 255 | out, log_det_j = x, 0 256 | for layer in self.layers: 257 | out, log_det_j_layer = layer(out) 258 | log_det_j += log_det_j_layer 259 | 260 | return out, log_det_j 261 | 262 | def backward(self, y): 263 | # Sampling with backward pass (gaussian noise --> images) 264 | out, log_det_j = y, 0 265 | for layer in self.layers[::-1]: 266 | out, log_det_j_layer = layer.backward(out) 267 | log_det_j += log_det_j_layer 268 | 269 | return out, log_det_j 270 | 271 | 272 | def training_loop(model, epochs, lr, loader, device, dir): 273 | """Trains the model""" 274 | 275 | model.train() 276 | best_loss = float("inf") 277 | optim = Adam(model.parameters(), lr=lr) 278 | scheduler = StepLR(optimizer=optim, step_size=1, gamma=0.99) 279 | to_bpd = np.log2(np.exp(1)) / ( 280 | 28 * 28 * 1 281 | ) # Constant that normalizes w.r.t. 
input shape 282 | 283 | prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0) 284 | 285 | for epoch in tqdm(range(epochs), desc="Training progress", colour="#00ff00"): 286 | epoch_loss = 0.0 287 | for batch in tqdm( 288 | loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500" 289 | ): 290 | # Getting a batch of images and applying dequantization 291 | x = batch[0].to(device) 292 | 293 | # Running images forward and getting log likelihood (log_px) 294 | z, log_det_j = model(x) 295 | # log_pz = -np.log(np.sqrt(2*np.pi)) -(z**2).sum(dim=[1,2,3]) # Because we are mapping to a normal N(0, 1) 296 | log_pz = prior.log_prob(z).sum(dim=[1, 2, 3]) 297 | log_px = log_pz + log_det_j 298 | 299 | # Getting the loss to be optimized (scaling with bits per dimension) 300 | loss = (-(log_px * to_bpd)).mean() 301 | 302 | # Optimization step 303 | optim.zero_grad() 304 | loss.backward() 305 | torch.nn.utils.clip_grad_norm_( 306 | model.parameters(), 1 307 | ) # Clipping gradient norm 308 | optim.step() 309 | 310 | # Logging variable 311 | epoch_loss += loss.item() / len(loader) 312 | 313 | # Stepping with the LR scheduler 314 | scheduler.step() 315 | 316 | # Logging epoch result and storing best model 317 | log_str = f"Epoch {epoch + 1}/{epochs} loss: {epoch_loss:.3f}" 318 | if best_loss > epoch_loss: 319 | best_loss = epoch_loss 320 | log_str += " --> Storing model" 321 | torch.save(model.state_dict(), os.path.join(dir, "nf_model.pt")) 322 | print(log_str) 323 | 324 | 325 | def main(): 326 | # Program arguments 327 | parser = ArgumentParser() 328 | parser.add_argument( 329 | "--epochs", type=int, default=500, help="Number of training epochs" 330 | ) 331 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") 332 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size") 333 | parser.add_argument("--gpu", type=int, default=0, help="GPU number") 334 | parser.add_argument( 335 | "--store_dir", type=str, default=os.getcwd(), help="Store directory" 336 | ) 337 | args = vars(parser.parse_args()) 338 | 339 | N_EPOCHS = args["epochs"] 340 | LR = args["lr"] 341 | BATCH_SIZE = args["batch_size"] 342 | GPU = args["gpu"] 343 | DIR = args["store_dir"] 344 | 345 | # Loading data (images are put in range [0, 255] and are copied on the channel dimension) 346 | transform = Compose([ToTensor(), Lambda(lambda x: (255 * x).to(torch.int32))]) 347 | dataset = MNIST( 348 | root="./../datasets", train=True, download=True, transform=transform 349 | ) 350 | loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) 351 | 352 | # Device 353 | device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu") 354 | device_log = f"Using device: {device} " + ( 355 | f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "" 356 | ) 357 | print(device_log) 358 | 359 | # Creating the model 360 | model = Flow( 361 | [ 362 | Dequantization(256), 363 | *[AffineCoupling(SimpleCNN(), modify_x2=i % 2 == 0) for i in range(30)], 364 | ] 365 | ).to(device) 366 | 367 | # Showing number of trainable paramsk 368 | trainable_params = 0 369 | for param in model.parameters(): 370 | trainable_params += np.prod(param.shape) if param.requires_grad else 0 371 | print(f"The model has {trainable_params} trainable parameters.") 372 | 373 | # Loading pre-trained model (if any) 374 | sd_path = os.path.join(DIR, "nf_model.pt") 375 | pretrained_exists = os.path.isfile(sd_path) 376 | if pretrained_exists: 377 | model.load_state_dict(torch.load(sd_path, 
map_location=device)) 378 | print("Pre-trained model found and loaded") 379 | 380 | # Testing reversability with first image in the dataset 381 | test_reversability(model, dataset[0][0].unsqueeze(0).to(device)) 382 | 383 | # Training loop (ony if model doesn't exist) 384 | if not pretrained_exists: 385 | training_loop(model, N_EPOCHS, LR, loader, device, DIR) 386 | sd_path = os.path.join(DIR, "nf_model.pt") 387 | model.load_state_dict(torch.load(sd_path, map_location=device)) 388 | 389 | # Testing the trained model 390 | model.eval() 391 | with torch.no_grad(): 392 | # Mapping the normally distributed noise to new images 393 | noise = torch.randn(64, 1, 28, 28).to(device) 394 | images = model.backward(noise)[0] 395 | 396 | save_image(images.float(), "Generated digits.png") 397 | Image.open("Generated digits.png").show() 398 | 399 | # Showing new latent mapping of first image in the dataset 400 | test_reversability(model, dataset[0][0].unsqueeze(0).to(device)) 401 | 402 | 403 | if __name__ == "__main__": 404 | main() 405 | -------------------------------------------------------------------------------- /src/cv/vir/train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from datasets import load_dataset 6 | from torch.utils.data import DataLoader 7 | from torchvision.transforms import ( 8 | Compose, 9 | Normalize, 10 | RandomHorizontalFlip, 11 | RandomRotation, 12 | Resize, 13 | ToTensor, 14 | ) 15 | 16 | from src.cv.vir.vir import ViR, ViRModes 17 | 18 | 19 | class ViRLightningModule(pl.LightningModule): 20 | def __init__( 21 | self, 22 | lr=1e-3, 23 | out_dim=10, 24 | patch_size=14, 25 | depth=12, 26 | heads=12, 27 | embed_dim=768, 28 | max_len=257, 29 | alpha=1.0, 30 | mode=ViRModes.PARALLEL, 31 | dropout=0.1, 32 | ): 33 | super(ViRLightningModule, self).__init__() 34 | self.lr = lr 35 | self.model = ViR( 36 | out_dim, 37 | patch_size, 38 | depth, 39 | heads, 40 | embed_dim, 41 | max_len, 42 | alpha, 43 | mode, 44 | dropout, 45 | ) 46 | 47 | def forward(self, x): 48 | return self.model(x) 49 | 50 | def training_step(self, batch, batch_idx): 51 | x, y = batch["image"], batch["label"] 52 | y_hat = self(x) 53 | acc = (y_hat.argmax(dim=1) == y).float().mean() 54 | loss = torch.nn.functional.cross_entropy(y_hat, y) 55 | self.log_dict( 56 | { 57 | "train_loss": loss, 58 | "train_acc": acc, 59 | } 60 | ) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, y = batch["image"], batch["label"] 65 | y_hat = self(x) 66 | acc = (y_hat.argmax(dim=1) == y).float().mean() 67 | loss = torch.nn.functional.cross_entropy(y_hat, y) 68 | self.log_dict( 69 | { 70 | "validation_loss": loss, 71 | "validation_acc": acc, 72 | } 73 | ) 74 | return loss 75 | 76 | def test_step(self, batch, batch_idx): 77 | x, y = batch["image"], batch["label"] 78 | y_hat = self(x) 79 | acc = (y_hat.argmax(dim=1) == y).float().mean() 80 | loss = torch.nn.functional.cross_entropy(y_hat, y) 81 | self.log_dict( 82 | { 83 | "test_loss": loss, 84 | "test_acc": acc, 85 | } 86 | ) 87 | return loss 88 | 89 | def configure_optimizers(self): 90 | optim = torch.optim.Adam(self.trainer.model.parameters(), self.lr) 91 | return optim 92 | 93 | 94 | def main(args): 95 | # Seed everything 96 | pl.seed_everything(args["seed"]) 97 | 98 | # Data 99 | resize = Resize((args["image_size"], args["image_size"])) 100 | normalize = Normalize([0.5] * 3, [0.5] * 3) 101 | train_transform = Compose( 102 | 
[resize, ToTensor(), normalize, RandomHorizontalFlip(), RandomRotation(5)] 103 | ) 104 | val_transform = Compose([resize, ToTensor(), normalize]) 105 | 106 | def make_transform(fn): 107 | def transform(samples): 108 | samples["image"] = [fn(img.convert("RGB")) for img in samples["image"]] 109 | samples["label"] = torch.tensor(samples["label"]) 110 | return samples 111 | 112 | return transform 113 | 114 | train_set = load_dataset("frgfm/imagenette", "320px", split="train") 115 | val_set = load_dataset("frgfm/imagenette", "320px", split="validation") 116 | 117 | train_set.set_transform(make_transform(train_transform)) 118 | val_set.set_transform(make_transform(val_transform)) 119 | 120 | train_loader = DataLoader( 121 | train_set, 122 | batch_size=args["batch_size"], 123 | shuffle=True, 124 | num_workers=args["num_workers"], 125 | ) 126 | val_loader = DataLoader( 127 | val_set, 128 | batch_size=args["batch_size"], 129 | shuffle=False, 130 | num_workers=args["num_workers"], 131 | ) 132 | 133 | # Load model 134 | model = ViRLightningModule( 135 | args["lr"], 136 | out_dim=10, 137 | patch_size=args["patch_size"], 138 | depth=args["depth"], 139 | heads=args["heads"], 140 | embed_dim=args["embed_dim"], 141 | max_len=(args["image_size"] // args["patch_size"]) ** 2 + 1, 142 | alpha=args["alpha"], 143 | mode=ViRModes.PARALLEL, 144 | dropout=args["dropout"], 145 | ) 146 | 147 | # Train model 148 | logger = pl.loggers.WandbLogger(project="Papers Re-implementations", name="ViR") 149 | logger.experiment.config.update(args) 150 | trainer = pl.Trainer( 151 | strategy="ddp", 152 | accelerator="auto", 153 | max_epochs=args["epochs"], 154 | logger=logger, 155 | callbacks=[ 156 | pl.callbacks.ModelCheckpoint( 157 | dirpath=args["checkpoint_dir"], 158 | filename="vir-model", 159 | save_top_k=3, 160 | monitor="train_loss", 161 | mode="min", 162 | ) 163 | ], 164 | ) 165 | trainer.fit(model, train_loader) 166 | 167 | # Evaluate model (setting to recurrent mode) 168 | model.model.set_compute_mode(ViRModes.RECURRENT) 169 | trainer.test(model, val_loader) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = ArgumentParser() 174 | 175 | # Training arguments 176 | parser.add_argument("--seed", help="Random seed", type=int, default=0) 177 | parser.add_argument( 178 | "--checkpoint_dir", help="Checkpoint directory", type=str, default="checkpoints" 179 | ) 180 | parser.add_argument("--epochs", help="Number of epochs", type=int, default=40) 181 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-3) 182 | parser.add_argument("--batch_size", help="Batch size", type=int, default=64) 183 | parser.add_argument("--image_size", help="Image size", type=int, default=224) 184 | parser.add_argument("--num_workers", help="Number of workers", type=int, default=4) 185 | 186 | # Model arguments 187 | parser.add_argument("--patch_size", help="Patch size", type=int, default=14) 188 | parser.add_argument("--depth", help="Depth", type=int, default=12) 189 | parser.add_argument("--heads", help="Heads", type=int, default=3) 190 | parser.add_argument( 191 | "--embed_dim", help="Embedding dimension", type=int, default=192 192 | ) 193 | parser.add_argument("--alpha", help="Alpha", type=float, default=0.99) 194 | parser.add_argument("--dropout", help="Dropout", type=float, default=0.1) 195 | 196 | args = vars(parser.parse_args()) 197 | 198 | print("\n\nProgram arguments:\n\n", args, "\n\n") 199 | main(args) 200 | -------------------------------------------------------------------------------- /src/cv/vir/vir.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | DEFAULT_ALPHA = 1.00 5 | 6 | 7 | class ViRModes: 8 | PARALLEL = "parallel" 9 | RECURRENT = "recurrent" 10 | CHUNKWISE = "chunkwise" 11 | 12 | 13 | class Retention(nn.Module): 14 | def __init__( 15 | self, 16 | embed_dim, 17 | max_len, 18 | alpha=DEFAULT_ALPHA, 19 | mode=ViRModes.PARALLEL, 20 | chunk_size=20, 21 | ): 22 | super(Retention, self).__init__() 23 | self.dim = embed_dim 24 | self.max_len = max_len 25 | self.chunk_size = chunk_size 26 | self.alpha = alpha 27 | self.mode = mode 28 | 29 | # Useful buffers 30 | self.register_buffer("dim_sqrt", torch.tensor(embed_dim**0.5)) 31 | self.register_buffer( 32 | "decay_mask", 33 | torch.tensor( 34 | [[alpha ** (i - j) for j in range(max_len)] for i in range(max_len)] 35 | ), 36 | ) 37 | self.register_buffer("causal_mask", torch.ones(max_len, max_len).tril()) 38 | self.register_buffer( 39 | "retention_mask_chunkwise", 40 | torch.tensor( 41 | [self.alpha ** (chunk_size - i - 1) for i in range(chunk_size)] 42 | ), 43 | ) 44 | 45 | self.register_buffer( 46 | "cross_mask_chunkwise", 47 | torch.tensor([self.alpha ** (i + 1) for i in range(chunk_size)]), 48 | ) 49 | self.qkv = nn.Linear(embed_dim, embed_dim * 3) 50 | 51 | def forward_parallel(self, x): 52 | # Getting queries, keys, values 53 | bs, sl, d = x.shape 54 | qkv = self.qkv(x) 55 | q, k, v = torch.chunk(qkv, 3, dim=-1) 56 | 57 | # Causal and decay masking 58 | M = (self.causal_mask[:sl, :sl] * self.decay_mask[:sl, :sl]).repeat(bs, 1, 1) 59 | 60 | # Retention 61 | out = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v 62 | 63 | return out 64 | 65 | def forward_recurrent(self, x, state): 66 | batch_size, length, dim = x.shape 67 | 68 | all_outputs = [] 69 | state = torch.zeros(batch_size, dim, dim).to(x.device) 70 | for i in range(length): 71 | xi = x[:, i] 72 | q, k, v = self.qkv(xi).chunk(3, dim=-1) 73 | 74 | state = self.alpha * state + k.unsqueeze(-1) @ v.unsqueeze(1) 75 | out = q.unsqueeze(1) @ state / self.dim_sqrt 76 | all_outputs.append(out.squeeze()) 77 | 78 | x = torch.stack(all_outputs, dim=1) 79 | return x 80 | 81 | def forward_chunkwise(self, x, chunk_size=None): 82 | # Getting queries, keys, values 83 | if chunk_size is None: 84 | chunk_size = self.chunk_size 85 | 86 | bs, sl, d = x.shape 87 | 88 | # Adding dummy tokens to make the sequence length divisible by chunk_size 89 | if sl % chunk_size != 0: 90 | x = torch.cat( 91 | [x, torch.zeros(bs, chunk_size - sl % chunk_size, d).to(x.device)], 92 | dim=1, 93 | ) 94 | n_chunks = x.shape[1] // chunk_size 95 | 96 | # Running all chunks in parallel 97 | x = x.reshape(bs, n_chunks, chunk_size, d) 98 | q, k, v = self.qkv(x).chunk(3, dim=-1) 99 | 100 | M = ( 101 | self.causal_mask[:chunk_size, :chunk_size] 102 | * self.decay_mask[:chunk_size, :chunk_size] 103 | ).repeat(bs, n_chunks, 1, 1) 104 | 105 | inner_chunk = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v 106 | 107 | # Updating outputs with chunk-wise recurrent 108 | retention_mask = self.retention_mask_chunkwise.repeat(bs, d, 1).transpose( 109 | -1, -2 110 | ) 111 | cross_mask = self.cross_mask_chunkwise.repeat(bs, n_chunks, d, 1).transpose( 112 | -1, -2 113 | ) 114 | 115 | states = torch.zeros(bs, n_chunks, d, d).to(x.device) 116 | for i in range(1, n_chunks): 117 | chunk_state = k[:, i - 1].transpose(-1, -2) @ (v[:, i - 1] * retention_mask) 118 | states[:, i] = chunk_state + states[:, i - 1] * self.alpha**chunk_size 119 | 120 | cross_chunk = (q 
@ states) / self.dim_sqrt * cross_mask 121 | 122 | # Combining inner and cross chunk 123 | out = inner_chunk + cross_chunk 124 | 125 | # Removing dummy tokens 126 | out = out.flatten(1, 2)[:, :sl] 127 | return out 128 | 129 | def forward(self, x, state=None, mode=ViRModes.PARALLEL, chunk_size=None): 130 | if mode is None: 131 | mode = self.mode 132 | 133 | if mode == ViRModes.PARALLEL: 134 | return self.forward_parallel(x) 135 | elif mode == ViRModes.RECURRENT: 136 | return self.forward_recurrent(x, state) 137 | elif mode == ViRModes.CHUNKWISE: 138 | return self.forward_chunkwise(x, chunk_size) 139 | else: 140 | raise ValueError(f"Unknown mode {mode}") 141 | 142 | 143 | class MultiHeadRetention(nn.Module): 144 | def __init__( 145 | self, 146 | heads, 147 | embed_dim, 148 | max_len, 149 | alpha=DEFAULT_ALPHA, 150 | mode=ViRModes.PARALLEL, 151 | chunk_size=20, 152 | ): 153 | super(MultiHeadRetention, self).__init__() 154 | self.n_heads = heads 155 | self.embed_dim = embed_dim 156 | self.head_dim = embed_dim // heads 157 | self.mode = mode 158 | self.chunk_size = chunk_size 159 | 160 | assert ( 161 | embed_dim % heads == 0 162 | ), "Embedding dimension must be divisible by the number of heads" 163 | 164 | self.heads = nn.ModuleList( 165 | [ 166 | Retention(embed_dim // heads, max_len, alpha, chunk_size) 167 | for _ in range(heads) 168 | ] 169 | ) 170 | self.ln = nn.LayerNorm(embed_dim) 171 | self.gelu = nn.GELU() 172 | self.linear = nn.Linear(embed_dim, embed_dim) 173 | 174 | def forward(self, x, mode=None, chunk_size=None): 175 | if mode is None: 176 | mode = self.mode 177 | 178 | if chunk_size is None: 179 | chunk_size = self.chunk_size 180 | 181 | out = torch.cat( 182 | [ 183 | head( 184 | x[:, :, i * self.head_dim : (i + 1) * self.head_dim], 185 | mode=mode, 186 | chunk_size=chunk_size, 187 | ) 188 | for i, head in enumerate(self.heads) 189 | ], 190 | dim=-1, 191 | ) 192 | return self.linear(self.gelu(self.ln(out))) 193 | 194 | 195 | class MLP(nn.Module): 196 | def __init__(self, embed_dim, hidden_dim=None): 197 | super(MLP, self).__init__() 198 | 199 | if hidden_dim is None: 200 | hidden_dim = 4 * embed_dim 201 | 202 | self.linear1 = nn.Linear(embed_dim, hidden_dim) 203 | self.linear2 = nn.Linear(hidden_dim, embed_dim) 204 | self.gelu = nn.GELU() 205 | 206 | def forward(self, x): 207 | return self.linear2(self.gelu(self.linear1(x))) 208 | 209 | 210 | class ViRBlock(nn.Module): 211 | def __init__( 212 | self, 213 | heads, 214 | embed_dim, 215 | max_len, 216 | alpha=DEFAULT_ALPHA, 217 | mode=ViRModes.PARALLEL, 218 | chunk_size=20, 219 | dropout=0.1, 220 | ): 221 | super(ViRBlock, self).__init__() 222 | self.mode = mode 223 | self.chunk_size = chunk_size 224 | 225 | self.ln1 = nn.LayerNorm(embed_dim) 226 | self.retention = MultiHeadRetention( 227 | heads, embed_dim, max_len, alpha, mode, chunk_size 228 | ) 229 | self.ln2 = nn.LayerNorm(embed_dim) 230 | self.mlp = MLP(embed_dim) 231 | self.dropout1 = nn.Dropout(dropout) 232 | self.dropout2 = nn.Dropout(dropout) 233 | 234 | def forward(self, x, mode=None, chunk_size=None): 235 | if mode is None: 236 | mode = self.mode 237 | 238 | if chunk_size is None: 239 | chunk_size = self.chunk_size 240 | 241 | x = ( 242 | self.dropout1(self.retention(self.ln1(x), mode=mode, chunk_size=chunk_size)) 243 | + x 244 | ) 245 | x = self.dropout2(self.mlp(self.ln2(x))) + x 246 | return x 247 | 248 | 249 | class ViR(nn.Module): 250 | def __init__( 251 | self, 252 | out_dim=10, 253 | patch_size=14, 254 | depth=12, 255 | heads=12, 256 | embed_dim=768, 257 | 
max_len=257, 258 | alpha=DEFAULT_ALPHA, 259 | mode=ViRModes.PARALLEL, 260 | chunk_size=20, 261 | dropout=0.1, 262 | ): 263 | super(ViR, self).__init__() 264 | 265 | # Local parameters 266 | self.out_dim = 10 267 | self.patch_size = patch_size 268 | self.depth = depth 269 | self.heads = heads 270 | self.embed_dim = embed_dim 271 | self.max_len = max_len 272 | self.alpha = alpha 273 | self.mode = mode 274 | self.chunk_size = chunk_size 275 | 276 | # Embeddings 277 | self.patch_embed = nn.Conv2d( 278 | 3, embed_dim, (patch_size, patch_size), stride=(patch_size, patch_size) 279 | ) 280 | self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim)) 281 | self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim)) 282 | 283 | # ViR blocks 284 | self.blocks = nn.ModuleList( 285 | [ 286 | ViRBlock(heads, embed_dim, max_len, alpha, mode, chunk_size, dropout) 287 | for _ in range(depth) 288 | ] 289 | ) 290 | 291 | # Head 292 | self.ln = nn.LayerNorm(embed_dim) 293 | self.linear = nn.Linear(embed_dim, out_dim) 294 | 295 | def set_compute_mode(self, mode): 296 | self.mode = mode 297 | 298 | def forward(self, x, mode=None, chunk_size=None): 299 | if mode is None: 300 | mode = self.mode 301 | 302 | if chunk_size is None: 303 | chunk_size = self.chunk_size 304 | 305 | # Patch embedding, positional embedding, CLS token 306 | x = self.patch_embed(x).permute(0, 2, 3, 1).flatten(1, 2) 307 | bs, sl = x.shape[:2] 308 | x = x + self.pos_embed.repeat(bs, 1, 1)[:, :sl] 309 | x = torch.cat( 310 | (x, self.class_token.repeat(bs, 1, 1)), dim=1 311 | ) # Important: CLS token goes last 312 | 313 | # Blocks 314 | for block in self.blocks: 315 | x = block(x, mode=mode, chunk_size=chunk_size) 316 | 317 | # Head on the CLS token 318 | x = self.linear(self.ln(x[:, -1])) 319 | 320 | return x 321 | 322 | 323 | if __name__ == "__main__": 324 | """Tests that parallel and recurrent modes give the same output for ViR""" 325 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 326 | x = torch.randn(16, 3, 224, 224).to(device) 327 | model = ViR(depth=12, heads=3, embed_dim=192).eval().to(device) 328 | 329 | with torch.no_grad(): 330 | model.set_compute_mode(ViRModes.CHUNKWISE) 331 | chunk_size = 20 332 | y3 = model(x, chunk_size=chunk_size) 333 | -------------------------------------------------------------------------------- /src/cv/vit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/vit/__init__.py -------------------------------------------------------------------------------- /src/cv/vit/vit_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | from torch.optim import Adam 6 | from torch.utils.data import DataLoader 7 | from torchvision.datasets.mnist import MNIST 8 | from torchvision.transforms import ToTensor 9 | from tqdm import tqdm, trange 10 | 11 | np.random.seed(0) 12 | torch.manual_seed(0) 13 | 14 | 15 | def patchify(images, n_patches): 16 | n, c, h, w = images.shape 17 | 18 | assert h == w, "Patchify method is implemented for square images only" 19 | 20 | patches = torch.zeros(n, n_patches**2, h * w * c // n_patches**2) 21 | patch_size = h // n_patches 22 | 23 | for idx, image in enumerate(images): 24 | for i in range(n_patches): 25 | for j in range(n_patches): 26 | patch = image[ 27 | 
:, 28 | i * patch_size : (i + 1) * patch_size, 29 | j * patch_size : (j + 1) * patch_size, 30 | ] 31 | patches[idx, i * n_patches + j] = patch.flatten() 32 | return patches 33 | 34 | 35 | class MyMSA(nn.Module): 36 | def __init__(self, d, n_heads=2): 37 | super(MyMSA, self).__init__() 38 | self.d = d 39 | self.n_heads = n_heads 40 | 41 | assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads" 42 | 43 | d_head = int(d / n_heads) 44 | self.q_mappings = nn.ModuleList( 45 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)] 46 | ) 47 | self.k_mappings = nn.ModuleList( 48 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)] 49 | ) 50 | self.v_mappings = nn.ModuleList( 51 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)] 52 | ) 53 | self.d_head = d_head 54 | self.softmax = nn.Softmax(dim=-1) 55 | 56 | def forward(self, sequences): 57 | # Sequences has shape (N, seq_length, token_dim) 58 | # We go into shape (N, seq_length, n_heads, token_dim / n_heads) 59 | # And come back to (N, seq_length, item_dim) (through concatenation) 60 | result = [] 61 | for sequence in sequences: 62 | seq_result = [] 63 | for head in range(self.n_heads): 64 | q_mapping = self.q_mappings[head] 65 | k_mapping = self.k_mappings[head] 66 | v_mapping = self.v_mappings[head] 67 | 68 | seq = sequence[:, head * self.d_head : (head + 1) * self.d_head] 69 | q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq) 70 | 71 | attention = self.softmax(q @ k.T / (self.d_head**0.5)) 72 | seq_result.append(attention @ v) 73 | result.append(torch.hstack(seq_result)) 74 | return torch.cat([torch.unsqueeze(r, dim=0) for r in result]) 75 | 76 | 77 | class MyViTBlock(nn.Module): 78 | def __init__(self, hidden_d, n_heads, mlp_ratio=4): 79 | super(MyViTBlock, self).__init__() 80 | self.hidden_d = hidden_d 81 | self.n_heads = n_heads 82 | 83 | self.norm1 = nn.LayerNorm(hidden_d) 84 | self.mhsa = MyMSA(hidden_d, n_heads) 85 | self.norm2 = nn.LayerNorm(hidden_d) 86 | self.mlp = nn.Sequential( 87 | nn.Linear(hidden_d, mlp_ratio * hidden_d), 88 | nn.GELU(), 89 | nn.Linear(mlp_ratio * hidden_d, hidden_d), 90 | ) 91 | 92 | def forward(self, x): 93 | out = x + self.mhsa(self.norm1(x)) 94 | out = out + self.mlp(self.norm2(out)) 95 | return out 96 | 97 | 98 | class MyViT(nn.Module): 99 | def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10): 100 | # Super constructor 101 | super(MyViT, self).__init__() 102 | 103 | # Attributes 104 | self.chw = chw # ( C , H , W ) 105 | self.n_patches = n_patches 106 | self.n_blocks = n_blocks 107 | self.n_heads = n_heads 108 | self.hidden_d = hidden_d 109 | 110 | # Input and patches sizes 111 | assert ( 112 | chw[1] % n_patches == 0 113 | ), "Input shape not entirely divisible by number of patches" 114 | assert ( 115 | chw[2] % n_patches == 0 116 | ), "Input shape not entirely divisible by number of patches" 117 | self.patch_size = (chw[1] / n_patches, chw[2] / n_patches) 118 | 119 | # 1) Linear mapper 120 | self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1]) 121 | self.linear_mapper = nn.Linear(self.input_d, self.hidden_d) 122 | 123 | # 2) Learnable classification token 124 | self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)) 125 | 126 | # 3) Positional embedding 127 | self.register_buffer( 128 | "positional_embeddings", 129 | get_positional_embeddings(n_patches**2 + 1, hidden_d), 130 | persistent=False, 131 | ) 132 | 133 | # 4) Transformer encoder blocks 134 | self.blocks = nn.ModuleList( 135 | [MyViTBlock(hidden_d, 
n_heads) for _ in range(n_blocks)] 136 | ) 137 | 138 | # 5) Classification MLPk 139 | self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1)) 140 | 141 | def forward(self, images): 142 | # Dividing images into patches 143 | n, c, h, w = images.shape 144 | patches = patchify(images, self.n_patches).to(self.positional_embeddings.device) 145 | 146 | # Running linear layer tokenization 147 | # Map the vector corresponding to each patch to the hidden size dimension 148 | tokens = self.linear_mapper(patches) 149 | 150 | # Adding classification token to the tokens 151 | tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1) 152 | 153 | # Adding positional embedding 154 | out = tokens + self.positional_embeddings.repeat(n, 1, 1) 155 | 156 | # Transformer Blocks 157 | for block in self.blocks: 158 | out = block(out) 159 | 160 | # Getting the classification token only 161 | out = out[:, 0] 162 | 163 | return self.mlp(out) # Map to output dimension, output category distribution 164 | 165 | 166 | def get_positional_embeddings(sequence_length, d): 167 | result = torch.ones(sequence_length, d) 168 | for i in range(sequence_length): 169 | for j in range(d): 170 | result[i][j] = ( 171 | np.sin(i / (10000 ** (j / d))) 172 | if j % 2 == 0 173 | else np.cos(i / (10000 ** ((j - 1) / d))) 174 | ) 175 | return result 176 | 177 | 178 | def main(): 179 | # Loading data 180 | transform = ToTensor() 181 | 182 | train_set = MNIST( 183 | root="./../datasets", train=True, download=True, transform=transform 184 | ) 185 | test_set = MNIST( 186 | root="./../datasets", train=False, download=True, transform=transform 187 | ) 188 | 189 | train_loader = DataLoader(train_set, shuffle=True, batch_size=128) 190 | test_loader = DataLoader(test_set, shuffle=False, batch_size=128) 191 | 192 | # Defining model and training options 193 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 194 | print( 195 | "Using device: ", 196 | device, 197 | f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "", 198 | ) 199 | model = MyViT( 200 | (1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10 201 | ).to(device) 202 | N_EPOCHS = 5 203 | LR = 0.005 204 | 205 | # Training loop 206 | optimizer = Adam(model.parameters(), lr=LR) 207 | criterion = CrossEntropyLoss() 208 | for epoch in trange(N_EPOCHS, desc="Training"): 209 | train_loss = 0.0 210 | for batch in tqdm( 211 | train_loader, desc=f"Epoch {epoch + 1} in training", leave=False 212 | ): 213 | x, y = batch 214 | x, y = x.to(device), y.to(device) 215 | y_hat = model(x) 216 | loss = criterion(y_hat, y) 217 | 218 | train_loss += loss.detach().cpu().item() / len(train_loader) 219 | 220 | optimizer.zero_grad() 221 | loss.backward() 222 | optimizer.step() 223 | 224 | print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}") 225 | 226 | # Test loop 227 | with torch.no_grad(): 228 | correct, total = 0, 0 229 | test_loss = 0.0 230 | for batch in tqdm(test_loader, desc="Testing"): 231 | x, y = batch 232 | x, y = x.to(device), y.to(device) 233 | y_hat = model(x) 234 | loss = criterion(y_hat, y) 235 | test_loss += loss.detach().cpu().item() / len(test_loader) 236 | 237 | correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item() 238 | total += len(x) 239 | print(f"Test loss: {test_loss:.2f}") 240 | print(f"Test accuracy: {correct / total * 100:.2f}%") 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | 
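A minimal smoke test for the ViT defined above (an illustrative sketch, not a file of the repository): it builds MyViT with the same MNIST configuration used in main() and checks that a dummy batch maps to per-class probabilities. The import path is an assumption based on the repository layout (it presumes the repository root is on PYTHONPATH).

# Illustrative usage sketch (assumption: run from the repository root).
import torch

from src.cv.vit.vit_torch import MyViT

model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)
x = torch.randn(4, 1, 28, 28)  # dummy batch of 4 MNIST-sized images
probs = model(x)               # patchify -> linear mapper -> CLS token -> blocks -> MLP head
print(probs.shape)             # torch.Size([4, 10]); rows sum to 1 because of the final Softmax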
-------------------------------------------------------------------------------- /src/fff/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/fff/__init__.py -------------------------------------------------------------------------------- /src/fff/fff.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unofficial re-implementation of 3 | 4 | 'Fast Feedforward Networks' 5 | 6 | by Peter Belcak and Roger Wattenhofer 7 | 8 | ArXiv: https://arxiv.org/abs/2308.14711 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class FFFMLP(nn.Module): 16 | def __init__( 17 | self, 18 | in_dim: int, 19 | hidden_dim: int, 20 | out_dim: int, 21 | hidden_activation: nn.Module = nn.ReLU(), 22 | out_activation: nn.Module = nn.Sigmoid(), 23 | swap_prob: float = 0.1, 24 | ): 25 | super(FFFMLP, self).__init__() 26 | self.in_dim = in_dim 27 | self.hidden_dim = hidden_dim 28 | self.out_dim = out_dim 29 | self.swap_prob = swap_prob 30 | 31 | self.linear1 = nn.Linear(in_dim, hidden_dim) 32 | self.hidden_activation = hidden_activation 33 | self.linear2 = nn.Linear(hidden_dim, out_dim) 34 | self.activation = out_activation 35 | 36 | def forward(self, x: torch.Tensor): 37 | out = self.linear2(self.hidden_activation(self.linear1(x))) 38 | out = self.activation(out) 39 | 40 | if self.training and torch.rand(1) < self.swap_prob: 41 | out = 1 - out 42 | 43 | return out 44 | 45 | 46 | class FFFLayer(nn.Module): 47 | def __init__( 48 | self, depth: int, in_dim: int, node_hidden: int, leaf_hidden: int, out_dim: int 49 | ): 50 | super(FFFLayer, self).__init__() 51 | 52 | self.depth = depth 53 | self.node_hidden = node_hidden 54 | self.leaf_hidden = leaf_hidden 55 | 56 | nodes = [FFFMLP(in_dim, node_hidden, 1) for _ in range(2 ** (depth) - 1)] 57 | leaves = [ 58 | FFFMLP( 59 | in_dim, 60 | leaf_hidden, 61 | out_dim, 62 | out_activation=nn.Identity(), 63 | swap_prob=0.0, 64 | ) 65 | for _ in range(2**depth) 66 | ] 67 | self.tree = nn.ModuleList(nodes + leaves) 68 | 69 | def forward(self, x: torch.Tensor, idx: int = 1, activations: list = None): 70 | c = self.tree[idx - 1](x) 71 | 72 | if 2 * idx + 1 <= len(self.tree): 73 | # Continuing down the tree 74 | if self.training: 75 | # During training, split signal all over the tree 76 | left, a_left = self.forward(x, 2 * idx, activations) 77 | right, a_right = self.forward(x, 2 * idx + 1, activations) 78 | 79 | children_are_leaves = a_left is None and a_right is None 80 | activations = [c] if children_are_leaves else a_left + a_right + [c] 81 | 82 | return c * left + (1 - c) * right, activations 83 | else: 84 | # During inference, input goes through a single path 85 | left, a_left = self.forward(x, 2 * idx, activations) 86 | right, a_right = self.forward(x, 2 * idx + 1, activations) 87 | 88 | children_are_leaves = a_left is None and a_right is None 89 | activations = [c] if children_are_leaves else a_left + a_right + [c] 90 | 91 | # Hardening 92 | # TODO: Improve and actually go down only one path 93 | c = (c >= 0.5).float() 94 | 95 | return c * left + (1 - c) * right, activations 96 | 97 | return c, activations 98 | -------------------------------------------------------------------------------- /src/fff/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import pytorch_lightning 
as pl 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar 7 | from pytorch_lightning.loggers import WandbLogger 8 | from torch.optim import SGD 9 | from torch.optim.lr_scheduler import LinearLR 10 | from torch.utils.data import DataLoader 11 | from torchmetrics import Accuracy 12 | from torchvision.datasets import MNIST 13 | from torchvision.transforms import Compose, ToTensor 14 | 15 | from src.fff.fff import FFFLayer 16 | 17 | 18 | class FlattenMNIST: 19 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 20 | return x.flatten() 21 | 22 | 23 | class PLWrapper(pl.LightningModule): 24 | def __init__(self, model, total_iters=10, hardening_weight=3.0, lr=0.2): 25 | super(PLWrapper, self).__init__() 26 | self.model = model 27 | self.cross_entropy = nn.CrossEntropyLoss() 28 | self.accuracy = Accuracy("multiclass", num_classes=10) 29 | self.total_iters = total_iters 30 | self.hardening_weight = hardening_weight 31 | self.lr = lr 32 | 33 | def forward(self, x): 34 | return self.model(x) 35 | 36 | def entropy(self, x): 37 | entropies = x * torch.log(x + 1e-8) + (1 - x) * torch.log(1 - x + 1e-8) 38 | return -torch.mean(entropies) 39 | 40 | def configure_optimizers(self): 41 | optimizer = SGD(self.model.parameters(), lr=self.lr) 42 | scheduler = LinearLR(optimizer, 1, 0, total_iters=self.total_iters) 43 | return [optimizer], [scheduler] 44 | 45 | def training_step(self, batch, batch_idx): 46 | x, y = batch 47 | y_hat, activations = self.model(x) 48 | activations = torch.cat(activations, dim=0) 49 | le = self.hardening_weight * self.entropy(activations) 50 | lce = self.cross_entropy(y_hat, y) 51 | loss = lce + le 52 | acc = self.accuracy(y_hat, y) 53 | self.log_dict( 54 | { 55 | "train_loss": loss, 56 | "train_loss_ce": lce, 57 | "train_loss_entropy": le, 58 | "train_acc": acc, 59 | } 60 | ) 61 | return loss 62 | 63 | def test_step(self, batch, batch_idx): 64 | x, y = batch 65 | y_hat, _ = self.model(x) 66 | loss = self.cross_entropy(y_hat, y) 67 | acc = self.accuracy(y_hat, y) 68 | self.log("test_loss", loss) 69 | self.log("test_acc", acc) 70 | 71 | 72 | def main(args): 73 | """ 74 | Training and evaluating an FFF model on MNIST image classification. 75 | Over-engineering code with Lightning and W&B to keep the good habits. 
76 | """ 77 | # Program arguments 78 | batch_size = args["batch_size"] 79 | max_epochs = args["max_epochs"] 80 | lr = args["lr"] 81 | hardening_weight = args["hardening_weight"] 82 | seed = args["seed"] 83 | 84 | # Setting reproducibility 85 | pl.seed_everything(seed) 86 | 87 | # Getting data 88 | transform = Compose([ToTensor(), FlattenMNIST()]) 89 | train = MNIST("./", train=True, download=True, transform=transform) 90 | test = MNIST("./", train=False, download=True, transform=transform) 91 | train_loader = DataLoader(train, batch_size=batch_size, num_workers=4, shuffle=True) 92 | test_loader = DataLoader(test, batch_size=batch_size, num_workers=4, shuffle=False) 93 | 94 | # Getting model 95 | fff_model = FFFLayer( 96 | depth=3, in_dim=28 * 28, node_hidden=32, leaf_hidden=32, out_dim=10 97 | ) 98 | 99 | model = PLWrapper( 100 | fff_model, 101 | total_iters=max_epochs * len(train_loader), 102 | hardening_weight=hardening_weight, 103 | lr=lr, 104 | ).train() 105 | 106 | # Training 107 | trainer = pl.Trainer( 108 | logger=WandbLogger( 109 | name="FFF MNIST", 110 | project="Papers Re-implementations", 111 | ), 112 | accelerator="auto", 113 | max_epochs=max_epochs, 114 | callbacks=[ 115 | ModelCheckpoint(dirpath="./checkpoints", monitor="train_loss", mode="min"), 116 | TQDMProgressBar(), 117 | ], 118 | ) 119 | trainer.fit(model, train_loader) 120 | 121 | # Loading best model 122 | model = PLWrapper.load_from_checkpoint( 123 | trainer.checkpoint_callback.best_model_path, model=fff_model 124 | ) 125 | 126 | # Testing and logging 127 | model.eval() 128 | trainer.test(model, test_loader) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = ArgumentParser() 133 | parser.add_argument( 134 | "--batch_size", 135 | type=int, 136 | default=64, 137 | help="Batch size for training and evaluation.", 138 | ) 139 | parser.add_argument( 140 | "--max_epochs", 141 | type=int, 142 | default=10, 143 | help="Number of epochs to train the model.", 144 | ) 145 | parser.add_argument( 146 | "--lr", type=float, default=0.2, help="Learning rate for the optimizer." 147 | ) 148 | parser.add_argument( 149 | "--hardening_weight", 150 | type=float, 151 | default=3.0, 152 | help="Hardening weight for the entropy loss.", 153 | ) 154 | parser.add_argument( 155 | "--seed", type=int, default=0, help="Seed for the random number generators." 
156 | ) 157 | args = vars(parser.parse_args()) 158 | print(args) 159 | main(args) 160 | -------------------------------------------------------------------------------- /src/gnns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/gnns/__init__.py -------------------------------------------------------------------------------- /src/gnns/gnns.py: -------------------------------------------------------------------------------- 1 | """Implementation of convolutional, attentional and message-passing GNNs, inspired by the paper 2 | Everything is Connected: Graph Neural Networks 3 | (https://arxiv.org/abs/2301.08210) 4 | 5 | Useful links: 6 | - Petar Veličković PDF talk: https://petar-v.com/talks/GNN-EEML.pdf 7 | """ 8 | 9 | import warnings 10 | from argparse import ArgumentParser 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.optim import Adam 15 | from torch.utils.data import DataLoader 16 | from torchvision.datasets import MNIST 17 | from torchvision.transforms import Compose, Lambda, Resize, ToTensor 18 | from tqdm import tqdm 19 | 20 | import wandb 21 | 22 | # Definitions 23 | NETWORK_TYPES = ["attn", "conv"] 24 | AGGREGATION_FUNCTIONS = { 25 | "sum": lambda X, dim=1: torch.sum(X, dim=dim), 26 | "avg": lambda X, dim=1: torch.mean(X, dim=dim), 27 | } 28 | 29 | 30 | def parse_args(): 31 | """Parses the program arguments""" 32 | parser = ArgumentParser() 33 | 34 | # Data arguments 35 | parser.add_argument( 36 | f"--image_size", 37 | type=int, 38 | help="Size to which reshape CIFAR images. Default is 14 (196 nodes).", 39 | default=14, 40 | ) 41 | 42 | # Model arguments 43 | parser.add_argument( 44 | f"--type", 45 | type=str, 46 | help="Type of the network used. Default is 'attn'", 47 | choices=NETWORK_TYPES, 48 | default="attn", 49 | ) 50 | parser.add_argument( 51 | f"--aggregation", 52 | type=str, 53 | help="Aggregation function", 54 | choices=list(AGGREGATION_FUNCTIONS.keys()), 55 | default="avg", 56 | ) 57 | parser.add_argument( 58 | f"--aggregation_out", 59 | type=str, 60 | help="Aggregation function for graph classification", 61 | choices=list(AGGREGATION_FUNCTIONS.keys()), 62 | default="avg", 63 | ) 64 | parser.add_argument( 65 | f"--n_layers", type=int, help="Number of layers of the GNNs", default=8 66 | ) 67 | 68 | # Training arguments 69 | parser.add_argument( 70 | f"--epochs", type=int, help="Training epochs. Default is 10.", default=10 71 | ) 72 | parser.add_argument( 73 | f"--lr", type=float, help="Learning rate. Default is 1e-3.", default=0.001 74 | ) 75 | parser.add_argument( 76 | f"--batch_size", 77 | type=int, 78 | help="Batch size used for training. Default is 64.", 79 | default=64, 80 | ) 81 | parser.add_argument( 82 | f"--checkpoint", 83 | type=str, 84 | help="Path to model checkpoints. 
Default is 'gnn.pt'.", 85 | default="gnn.pt", 86 | ) 87 | 88 | return vars(parser.parse_args()) 89 | 90 | 91 | def get_device(): 92 | """Gets the CUDA device if available, warns that code will run on CPU only otherwise""" 93 | if torch.cuda.is_available(): 94 | device = torch.device("cuda") 95 | print("\nFound GPU: ", torch.cuda.get_device_name(device)) 96 | return device 97 | elif torch.backends.mps.is_available(): 98 | device = torch.device("mps") 99 | print("\nFound Apple MPS chip.") 100 | 101 | warnings.warn("\nWARNING: No GPU nor MPS found - Training on CPU.") 102 | return torch.device("cpu") 103 | 104 | 105 | class PsiNetwork(nn.Module): 106 | """ 107 | Simple MLP network denoted as the 'psi' function in the paper. 108 | The role of this network is to extract relevant features to be passed to neighbouring edges. 109 | """ 110 | 111 | def __init__(self, in_size, out_size): 112 | super(PsiNetwork, self).__init__() 113 | self.linear = nn.Linear(in_size, out_size) 114 | self.relu = nn.ReLU() 115 | 116 | def forward(self, X): 117 | return self.relu(self.linear(X)) 118 | 119 | 120 | class GraphConvLayer(nn.Module): 121 | """ 122 | Graph Convolutional layer. 123 | It computes the next hidden states of the edges as a convolution over neighbours. 124 | """ 125 | 126 | def __init__(self, n, d, aggr): 127 | super(GraphConvLayer, self).__init__() 128 | 129 | self.coefficients = nn.Parameter(torch.ones((n, n)) / n) 130 | self.ln = nn.LayerNorm(d) 131 | self.psi = PsiNetwork(d, d) 132 | self.aggr = aggr 133 | 134 | def forward(self, H, A): 135 | weights = self.coefficients * A # (N, N) 136 | messages = self.psi(self.ln(H)) # (B, N, D) 137 | messages = torch.einsum("nm, bmd -> bndm", weights, messages) # (B, N, D, N) 138 | messages = self.aggr(messages, dim=-1) # (B, N, D) 139 | return messages 140 | 141 | 142 | class Attention(nn.Module): 143 | def __init__(self, dim): 144 | super(Attention, self).__init__() 145 | self.dim = dim 146 | 147 | def forward(self, x, mask=None): 148 | # x has shape (B, N, D) 149 | attn_cues = (x @ x.transpose(-2, -1)) / (self.dim**0.5 + 1e-5) #  (B, N, N) 150 | 151 | if mask is not None: 152 | attn_cues = attn_cues.masked_fill(mask == 0, float("-inf")) 153 | 154 | attn_cues = attn_cues.softmax(-1) 155 | return attn_cues 156 | 157 | 158 | class GraphAttentionLayer(nn.Module): 159 | """ 160 | Graph Attentional Layer. 161 | It computes the next hidden states of the edges using attention. 
162 | """ 163 | 164 | def __init__(self, n, d, aggr): 165 | super(GraphAttentionLayer, self).__init__() 166 | 167 | self.aggr = aggr 168 | self.psi = PsiNetwork(d, d) 169 | self.ln1 = nn.LayerNorm(d) 170 | self.ln2 = nn.LayerNorm(2 * d) 171 | self.sa = Attention(d) 172 | 173 | def forward(self, H, A): 174 | messages = self.psi(self.ln1(H)) #  (B, N, D) 175 | attn = self.sa(H, A) # (B, N, N) 176 | 177 | messages = torch.einsum("bnm, bmd -> bndm", attn, messages) #  (B, N, D, N) 178 | messages = self.aggr(messages, dim=-1) # (B, N, D) 179 | return messages 180 | 181 | 182 | class GraphNeuralNetwork(nn.Module): 183 | """Graph Neural Network class.""" 184 | 185 | _NET_TYPE_TO_LAYER = {"attn": GraphAttentionLayer, "conv": GraphConvLayer} 186 | 187 | def _get_phi_net(self, dim_in, dim_out): 188 | return nn.Sequential( 189 | nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) 190 | ) 191 | 192 | def __init__(self, net_type, n_layers, n, d_in, d_hidden, d_out, aggr, aggr_out): 193 | super(GraphNeuralNetwork, self).__init__() 194 | 195 | assert ( 196 | net_type in NETWORK_TYPES 197 | ), f"ERROR: GNN type {net_type} not supported. Pick one of {NETWORK_TYPES}" 198 | 199 | self.net_type = net_type 200 | self.n_layers = n_layers 201 | self.encoding = nn.Linear(d_in, d_hidden) 202 | self.layers = nn.ModuleList( 203 | [ 204 | self._NET_TYPE_TO_LAYER[net_type]( 205 | n, d_hidden, AGGREGATION_FUNCTIONS[aggr] 206 | ) 207 | for _ in range(n_layers) 208 | ] 209 | ) 210 | 211 | self.phi_nets = nn.ModuleList( 212 | [self._get_phi_net(2 * d_hidden, d_hidden) for _ in range(n_layers)] 213 | ) 214 | 215 | self.out_aggr = AGGREGATION_FUNCTIONS[aggr_out] 216 | self.out_mlp = nn.Sequential( 217 | nn.Linear(d_hidden, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_out) 218 | ) 219 | 220 | def forward(self, X, A): 221 | # X has shape (B, N, D) and represents the edges. 222 | # A is binary with shape (N, N) and represents the adjacency matrix. 223 | H = self.encoding(X) 224 | for l, p in zip(self.layers, self.phi_nets): 225 | messages = l(H, A) 226 | H = H + p(torch.cat((H, messages), dim=-1)) 227 | return self.out_mlp(self.out_aggr(H)) 228 | 229 | 230 | def main(): 231 | # Parsing arguments 232 | args = parse_args() 233 | print("Launched program with the following arguments:", args) 234 | 235 | # Getting device 236 | device = get_device() 237 | 238 | # Loading data 239 | # We reshape the image such that each pixel is an edge with C features. 240 | img_size = args["image_size"] 241 | transform = Compose( 242 | [ 243 | ToTensor(), 244 | Resize((img_size, img_size)), 245 | # (C, H, W) -> (H*W, C). 
246 | Lambda(lambda x: x.flatten().reshape(-1, 1)), 247 | ] 248 | ) 249 | train_set = MNIST("./../datasets", train=True, transform=transform, download=True) 250 | test_set = MNIST("./../datasets", train=False, transform=transform, download=True) 251 | 252 | train_loader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True) 253 | test_loader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=False) 254 | 255 | # Building the Neighbourhood matrix (1024 x 1024) for all "graphs" (images of size 32x32) 256 | A = torch.zeros((img_size**2, img_size**2)).to(device) 257 | nums = torch.arange(img_size**2).reshape((img_size, img_size)) 258 | for i in range(img_size): 259 | for j in range(img_size): 260 | start_x, start_y = i - 1 if i > 0 else 0, j - 1 if j > 0 else 0 261 | neighbours = nums[start_x : i + 2, start_y : j + 2].flatten() 262 | 263 | for n in neighbours: 264 | A[i * img_size + j, n] = A[n, i * img_size + j] = 1 265 | 266 | # Creating model 267 | # Number of edges, edge dimensionality, hidden dimensionality and number of output classes 268 | n, d, h, o = img_size**2, 1, 16, 10 269 | model = GraphNeuralNetwork( 270 | args["type"], 271 | args["n_layers"], 272 | n, 273 | d, 274 | h, 275 | o, 276 | aggr=args["aggregation"], 277 | aggr_out=args["aggregation_out"], 278 | ) 279 | 280 | # Training loop 281 | n_epochs = args["epochs"] 282 | checkpoint = ( 283 | args["checkpoint"] if args["checkpoint"] else f"({args['type']}_gnn).pt" 284 | ) 285 | optim = Adam(model.parameters(), args["lr"]) 286 | criterion = nn.CrossEntropyLoss() 287 | model = model.to(device) 288 | model.train() 289 | 290 | min_loss = float("inf") 291 | progress_bar = tqdm(range(1, n_epochs + 1)) 292 | 293 | wandb.init( 294 | project="Papers Re-implementations", 295 | entity="peutlefaire", 296 | name=f"GNN ({args['type']})", 297 | config={ 298 | "type": args["type"], 299 | "aggregation": args["aggregation"], 300 | "layers": args["n_layers"], 301 | "epochs": n_epochs, 302 | "batch_size": args["batch_size"], 303 | "lr": args["lr"], 304 | }, 305 | ) 306 | 307 | for epoch in progress_bar: 308 | epoch_loss = 0.0 309 | for batch in tqdm(train_loader, leave=False): 310 | x, y = batch 311 | x, y = x.to(device), y.to(device) 312 | 313 | loss = criterion(model(x, A), y) 314 | epoch_loss += loss.item() / len(train_loader) 315 | optim.zero_grad() 316 | loss.backward() 317 | optim.step() 318 | 319 | wandb.log({"batch loss": loss.item()}) 320 | 321 | description = f"Epoch {epoch}/{n_epochs} - Training loss: {epoch_loss:.3f}" 322 | if epoch_loss < min_loss: 323 | min_loss = epoch_loss 324 | torch.save(model.state_dict(), checkpoint) 325 | description += " -> ✅ Stored best model." 
326 | 327 | wandb.log({"epoch loss": epoch_loss}) 328 | progress_bar.set_description(description) 329 | wandb.finish() 330 | 331 | # Testing loop 332 | model.load_state_dict(torch.load(checkpoint, map_location=device)) 333 | model = model.eval() 334 | with torch.no_grad(): 335 | correct, total = 0, 0 336 | for batch in test_loader: 337 | x, y = batch 338 | x, y = x.to(device), y.to(device) 339 | 340 | correct += (model(x, A).argmax(1) == y).sum().item() 341 | total += len(y) 342 | print(f"\n\nFinal test accuracy: {(correct / total * 100):.2f}%") 343 | print("Program completed successfully!") 344 | 345 | 346 | if __name__ == "__main__": 347 | main() 348 | -------------------------------------------------------------------------------- /src/nlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/__init__.py -------------------------------------------------------------------------------- /src/nlp/bert/README.md: -------------------------------------------------------------------------------- 1 | # BERT 2 | 3 | ## Dataset 4 | Training BERT will download the Wikipedia dataset from March 1st, 2022 from [huggingface datasets](https://huggingface.co/datasets/wikipedia). The total disk size required for the dataset is ~ `43GB`, and it will be downloaded under your `HF_DATASETS_CACHE`. 5 | -------------------------------------------------------------------------------- /src/nlp/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/bert/__init__.py -------------------------------------------------------------------------------- /src/nlp/bert/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BertDataset(Dataset): 8 | """Creates a dataset for BERT training where each sample is broken into two sentences.""" 9 | 10 | def __init__( 11 | self, 12 | dataset, 13 | tokenizer, 14 | max_length, 15 | sentence_divider="\n\n", 16 | mlm_ratio=0.15, 17 | mask_ratio=0.8, 18 | ): 19 | super(BertDataset, self).__init__() 20 | 21 | # Filtering empty sentences 22 | self.dataset = dataset.filter( 23 | lambda sample: len(sample["text"]) > 0 24 | and sentence_divider in sample["text"], 25 | ) 26 | 27 | # Dataset parameters 28 | self.tokenizer = tokenizer 29 | self.max_length = max_length 30 | self.sentence_divider = sentence_divider 31 | self.mlm_ratio = mlm_ratio 32 | self.mask_ratio = mask_ratio 33 | 34 | # MLM distribution 35 | nothing_prob = 1 - mlm_ratio 36 | mask_prob = (1 - nothing_prob) * mask_ratio 37 | different_word_prob = (1 - nothing_prob - mask_prob) / 2 38 | probs = torch.tensor( 39 | [nothing_prob, mask_prob, different_word_prob, different_word_prob] 40 | ) 41 | self.mask_dist = torch.distributions.Multinomial(probs=probs) 42 | 43 | def __len__(self): 44 | return len(self.dataset) 45 | 46 | def __getitem__(self, index): 47 | # First sentence 48 | sentences = self.dataset[index]["text"].split(self.sentence_divider) 49 | n_sentences = len(sentences) 50 | 51 | # Picking first sentence 52 | s1_idx = random.randint(0, n_sentences - 2) 53 | s1 = sentences[s1_idx] 54 | 55 | # Next sentence and prediction label 56 | if torch.rand(1) > 0.5: 57 | # 50% of the 
time, pick the real next "sentence"
58 |             s2 = sentences[(s1_idx + 1)]
59 |             nsp_label = 1
60 |         else:
61 |             # The other 50% of the time, pick a random "sentence" from another document
62 |             idx = random.randint(0, len(self.dataset) - 1)
63 |             if idx == index:
64 |                 idx = (idx + 1) % len(self.dataset)
65 |             sentences_other = self.dataset[idx]["text"].split(self.sentence_divider)
66 |             s2 = sentences_other[random.randint(0, len(sentences_other) - 1)]
67 |             nsp_label = 0
68 | 
69 |         # Preparing input ids
70 |         tokenizer_out = self.tokenizer(
71 |             s1,
72 |             s2,
73 |             return_tensors="pt",
74 |             padding="max_length",
75 |             max_length=self.max_length,
76 |             truncation=True,
77 |         )
78 | 
79 |         input_ids, segment_idx, attn_mask = (
80 |             tokenizer_out["input_ids"][0],
81 |             tokenizer_out["token_type_ids"][0],
82 |             tokenizer_out["attention_mask"][0],
83 |         )
84 | 
85 |         # Getting mask indexes
86 |         mask_idx = self.mask_dist.sample((len(input_ids),))
87 | 
88 |         # Not masking CLS and SEP
89 |         sep_idx = -1
90 |         for i in range(len(segment_idx)):
91 |             if segment_idx[i] == 1:
92 |                 sep_idx = i - 1
93 |                 break
94 |         mask_idx[0] = mask_idx[-1] = mask_idx[sep_idx] = torch.tensor([1, 0, 0, 0])
95 |         mask_idx = mask_idx.argmax(dim=-1)
96 | 
97 |         # Getting labels for masked tokens
98 |         mlm_idx = (mask_idx != 0).long()
99 |         mlm_labels = input_ids.clone()
100 | 
101 |         # Masking input tokens according to strategy
102 |         input_ids[mask_idx == 1] = self.tokenizer.mask_token_id
103 |         input_ids[mask_idx == 2] = torch.randint(0, self.tokenizer.vocab_size, (1,))
104 | 
105 |         return {
106 |             "input_ids": input_ids,
107 |             "segment_ids": segment_idx,
108 |             "attention_mask": attn_mask,
109 |             "mlm_labels": mlm_labels,
110 |             "mlm_idx": mlm_idx,
111 |             "nsp_labels": nsp_label,
112 |         }
113 | 
-------------------------------------------------------------------------------- /src/nlp/bert/main.py: --------------------------------------------------------------------------------
1 | """
2 | Re-implementation of
3 | 
4 | BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
5 | Devlin et al. (2018)
6 | (https://arxiv.org/abs/1810.04805)
7 | 
8 | on the Wikipedia dataset.
9 | """
10 | 
11 | import os
12 | from argparse import ArgumentParser
13 | 
14 | import torch
15 | from torch.utils.data import DataLoader
16 | 
17 | torch.set_float32_matmul_precision("medium")
18 | 
19 | import pytorch_lightning as pl
20 | import transformers
21 | from datasets import load_dataset
22 | from pytorch_lightning.callbacks import ModelCheckpoint
23 | from pytorch_lightning.loggers import WandbLogger
24 | from transformers import BertTokenizer
25 | 
26 | transformers.logging.set_verbosity_error()
27 | 
28 | from src.nlp.bert.data import BertDataset
29 | from src.nlp.bert.model import Bert
30 | 
31 | 
32 | @torch.no_grad()
33 | def unmask_sentences(bert, tokenizer, max_length, file_path):
34 |     """Uses the bert model to unmask sentences from a file.
Prints the unmaksed sentences.""" 35 | file = open(file_path, "r") 36 | lines = file.readlines() 37 | lines = [line if not line.endswith("\n") else line[:-1] for line in lines] 38 | file.close() 39 | 40 | bert = bert.cuda() 41 | for i, line in enumerate(lines): 42 | # Preparing input 43 | input_ids = tokenizer( 44 | line, 45 | return_tensors="pt", 46 | max_length=max_length, 47 | padding="max_length", 48 | truncation=True, 49 | )["input_ids"].cuda() 50 | segment_ids = torch.zeros_like(input_ids) 51 | attn_mask = None 52 | 53 | # Running BERT 54 | _, _, mlm_preds = bert(input_ids) 55 | 56 | # Getting predictions for the MASK'd words 57 | unmasked_words = mlm_preds[input_ids == tokenizer.mask_token_id].argmax(dim=-1) 58 | unmasked_words = tokenizer.decode(unmasked_words).split(" ") 59 | 60 | # Reconstructing the unmasked sentence 61 | sentence = "" 62 | parts = line.split("[MASK]") 63 | for word in unmasked_words: 64 | sentence += parts.pop(0) + word 65 | sentence += parts.pop(0) 66 | 67 | # Showing results 68 | print(f"\n\nSENTENCE {i+1}:") 69 | print(f"\tOriginal: {line}\n\tUnmasked: {sentence}") 70 | 71 | 72 | def main(args): 73 | """ 74 | Train a BERT model on Wikipedia. 75 | Use the model to "unmask" sentences from a file. 76 | """ 77 | # Unpacking parameters 78 | n_blocks = args["n_blocks"] 79 | n_heads = args["n_heads"] 80 | hidden_dim = args["hidden_dim"] 81 | dropout = args["dropout"] 82 | max_len = args["max_len"] 83 | batch_size = args["batch_size"] 84 | max_train_steps = args["max_train_steps"] 85 | lr = args["lr"] 86 | weight_decay = args["weight_decay"] 87 | warmup_steps = args["warmup_steps"] 88 | save_dir = args["save"] 89 | file_path = args["masked_sentences"] 90 | seed = args["seed"] 91 | 92 | # Setting random seed 93 | pl.seed_everything(seed) 94 | 95 | # Load the dataset (wikipedia only has 'train', so we split it ourselves) 96 | train_set = load_dataset("wikipedia", "20220301.en", split="train[:98%]") 97 | val_set = load_dataset("wikipedia", "20220301.en", split="train[98%:99%]") 98 | test_set = load_dataset("wikipedia", "20220301.en", split="train[99%:]") 99 | 100 | # Setting format to torch (not striclty necessary) 101 | # train_set.set_format(type="torch", columns=["text"]) 102 | # val_set.set_format(type="torch", columns=["text"]) 103 | # test_set.set_format(type="torch", columns=["text"]) 104 | 105 | # Bert tokenizer 106 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 107 | 108 | # Wrapping dataset with Bert logic for batches 109 | train_set, val_set, test_set = ( 110 | BertDataset(train_set, tokenizer, max_len), 111 | BertDataset(val_set, tokenizer, max_len), 112 | BertDataset(test_set, tokenizer, max_len), 113 | ) 114 | 115 | # Data loaders 116 | cpus = os.cpu_count() 117 | train_loader = DataLoader( 118 | train_set, batch_size=batch_size, shuffle=True, num_workers=cpus 119 | ) 120 | val_loader = DataLoader( 121 | val_set, batch_size=batch_size, shuffle=False, num_workers=cpus 122 | ) 123 | test_loader = DataLoader( 124 | test_set, batch_size=batch_size, shuffle=False, num_workers=cpus 125 | ) 126 | 127 | # Initialize the model 128 | vocab_size = tokenizer.vocab_size 129 | bert = Bert( 130 | vocab_size, 131 | max_len, 132 | hidden_dim, 133 | n_heads, 134 | n_blocks, 135 | dropout=dropout, 136 | train_config={ 137 | "lr": lr, 138 | "weight_decay": weight_decay, 139 | "max_train_steps": max_train_steps, 140 | "warmup_steps": warmup_steps, 141 | }, 142 | ) 143 | 144 | # Training 145 | checkpointing_freq = max_train_steps // 10 146 | wandb_logger = 
WandbLogger(project="Papers Re-implementations", name="BERT") 147 | wandb_logger.experiment.config.update(args) 148 | callbacks = [ModelCheckpoint(save_dir, monitor="val_loss", filename="best", every_n_train_steps=checkpointing_freq)] 149 | trainer = pl.Trainer( 150 | accelerator="auto", 151 | strategy="ddp", # State dict not saved with fsdp for some reason 152 | max_steps=max_train_steps, 153 | logger=wandb_logger, 154 | callbacks=callbacks, 155 | val_check_interval=checkpointing_freq, 156 | profiler="simple", 157 | ) 158 | trainer.fit(bert, train_loader, val_loader) 159 | 160 | # Testing the best model 161 | bert = Bert.load_from_checkpoint(os.path.join(save_dir, "best.ckpt")) 162 | trainer.test(bert, test_loader) 163 | 164 | if file_path is not None and os.path.isfile(file_path): 165 | # Unmasking sentences 166 | unmask_sentences(bert, tokenizer, max_len, file_path) 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = ArgumentParser() 171 | 172 | # Model hyper-parameters 173 | parser.add_argument("--n_blocks", type=int, default=12) 174 | parser.add_argument("--n_heads", type=int, default=12) 175 | parser.add_argument("--hidden_dim", type=int, default=768) 176 | parser.add_argument("--dropout", type=float, default=0.1) 177 | parser.add_argument("--max_len", type=int, default=128) 178 | 179 | # Training parameters 180 | parser.add_argument("--batch_size", type=int, default=32) 181 | parser.add_argument("--max_train_steps", type=int, default=10_000) 182 | parser.add_argument("--lr", type=float, default=1e-4) 183 | parser.add_argument("--weight_decay", type=float, default=0.01) 184 | parser.add_argument("--warmup_steps", type=int, default=100) 185 | parser.add_argument("--save", type=str, default="checkpoints/bert") 186 | 187 | # Testing parameters 188 | parser.add_argument( 189 | "--masked_sentences", type=str, default="data/nlp/bert/masked_sentences.txt" 190 | ) 191 | 192 | # Seed 193 | parser.add_argument("--seed", type=int, default=0) 194 | 195 | args = vars(parser.parse_args()) 196 | print(args) 197 | main(args) 198 | -------------------------------------------------------------------------------- /src/nlp/bert/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning.utilities.types import STEP_OUTPUT 7 | from torch.optim import AdamW 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | from src.nlp.layers.embeddings import get_learnable_embedding 11 | from src.nlp.layers.encoder import EncoderTransformer 12 | 13 | 14 | class Bert(pl.LightningModule): 15 | DEFAULT_BERT_CONFIG = { 16 | "lr": 1e-4, 17 | "betas": (0.9, 0.999), 18 | "weight_decay": 0.01, 19 | "max_train_steps": 10_000, 20 | "warmup_steps": 100, 21 | } 22 | 23 | def __init__( 24 | self, 25 | vocab_size, 26 | max_len, 27 | hidden_dim, 28 | n_heads, 29 | depth, 30 | dropout=0.1, 31 | train_config=None, 32 | ): 33 | super(Bert, self).__init__() 34 | 35 | # Saving hyper-parameters so that they are logged 36 | self.save_hyperparameters() 37 | 38 | # Local parameters 39 | self.vocab_size = vocab_size 40 | self.max_len = max_len 41 | self.hidden_dim = hidden_dim 42 | self.n_heads = n_heads 43 | self.depth = depth 44 | self.dropout = dropout 45 | self.train_config = Bert.DEFAULT_BERT_CONFIG 46 | 47 | # Schedulers 48 | self.linear_scheduler = None 49 | self.warmup_scheduler = None 50 | 51 | # Training config 52 | if train_config is not None: 53 | 
self.train_config.update(train_config) 54 | 55 | # Embeddings 56 | self.embeddings = get_learnable_embedding( 57 | vocab_size, hidden_dim 58 | ) # nn.Embedding(vocab_size, hidden_dim) 59 | self.pos_embeddings = get_learnable_embedding( 60 | max_len, hidden_dim 61 | ) # nn.Embedding(max_len, hidden_dim) 62 | self.sentence_embeddings = get_learnable_embedding( 63 | 2, hidden_dim 64 | ) # nn.Embedding(2, hidden_dim) 65 | 66 | # Transformer and output layer 67 | self.transformer = EncoderTransformer( 68 | hidden_dim, n_heads, depth, dropout_p=dropout 69 | ) 70 | 71 | # Next sentence classifier 72 | self.classifier = nn.Sequential( 73 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 2) 74 | ) 75 | 76 | # Masked language modeling head 77 | self.mask_predictor = nn.Sequential( 78 | nn.Linear(hidden_dim, hidden_dim), 79 | nn.GELU(), 80 | nn.LayerNorm(hidden_dim), 81 | nn.Linear(hidden_dim, vocab_size, bias=True), 82 | ) 83 | 84 | def forward(self, ids, segment_ids=None, attn_mask=None): 85 | # Embedding 86 | b, t = ids.shape 87 | hidden = self.embeddings(ids) 88 | hidden += self.pos_embeddings(torch.arange(t, device=ids.device)).repeat( 89 | b, 1, 1 90 | ) 91 | 92 | if segment_ids is not None: 93 | hidden += self.sentence_embeddings(segment_ids) 94 | else: 95 | hidden += self.sentence_embeddings( 96 | torch.zeros(1, dtype=torch.long, device=ids.device) 97 | ).repeat(b, 1, 1) 98 | 99 | # Transformer 100 | hidden = self.transformer(hidden, attn_mask=attn_mask) 101 | 102 | # Classification head based on CLS token 103 | class_preds = self.classifier(hidden[:, 0]) 104 | 105 | # Masked language modeling head 106 | mlm_preds = self.mask_predictor(hidden) 107 | 108 | return hidden, class_preds, mlm_preds 109 | 110 | def get_losses(self, batch): 111 | # Unpacking batch 112 | ids = batch["input_ids"] 113 | segment_ids = batch["segment_ids"] 114 | attn_mask = batch["attention_mask"] 115 | mlm_labels = batch["mlm_labels"] 116 | mlm_idx = batch["mlm_idx"] 117 | nsp_labels = batch["nsp_labels"] 118 | 119 | # Running forward 120 | b, t = ids.shape 121 | _, class_preds, mlm_preds = self( 122 | ids, segment_ids, attn_mask.repeat(1, t).reshape(b, t, t) 123 | ) 124 | mlm_preds = mlm_preds[mlm_idx + attn_mask == 2] 125 | mlm_labels = mlm_labels[mlm_idx + attn_mask == 2] 126 | 127 | # Classification loss 128 | class_loss = torch.nn.functional.cross_entropy(class_preds, nsp_labels) 129 | 130 | # Masked language modeling loss 131 | mlm_loss = torch.nn.functional.cross_entropy(mlm_preds, mlm_labels) 132 | 133 | # Getting accuracies 134 | class_acc = (class_preds.argmax(dim=-1) == nsp_labels).float().mean() 135 | mlm_acc = (mlm_preds.argmax(dim=-1) == mlm_labels).float().mean() 136 | 137 | return class_loss, mlm_loss, class_acc, mlm_acc 138 | 139 | def configure_optimizers(self): 140 | optim = AdamW( 141 | self.trainer.model.parameters(), 142 | lr=self.train_config["lr"], 143 | weight_decay=self.train_config["weight_decay"], 144 | betas=self.train_config["betas"], 145 | ) 146 | 147 | self.linear_scheduler = LambdaLR(optim, self.scheduler_fn) 148 | 149 | return {"optimizer": optim} 150 | 151 | def scheduler_fn(self, step): 152 | warmup_steps = self.train_config["warmup_steps"] 153 | max_steps = self.train_config["max_train_steps"] 154 | 155 | if step < warmup_steps: 156 | return step / warmup_steps 157 | return 1 - (step - warmup_steps) / (max_steps - warmup_steps) 158 | 159 | def on_train_batch_end( 160 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int 161 | ) -> None: 162 | if 
self.linear_scheduler is not None: 163 | self.linear_scheduler.step() 164 | 165 | return super().on_train_batch_end(outputs, batch, batch_idx) 166 | 167 | def training_step(self, batch, batch_idx): 168 | # Getting losses & accuracies 169 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch) 170 | 171 | # Total loss 172 | loss = class_loss + mlm_loss 173 | 174 | # Logging 175 | self.log_dict( 176 | { 177 | "lr": self.optimizers().param_groups[0]["lr"], 178 | "train_loss": loss, 179 | "train_class_loss": class_loss, 180 | "train_mlm_loss": mlm_loss, 181 | "train_class_acc": c_acc, 182 | "train_mlm_acc": m_acc, 183 | }, 184 | sync_dist=True, 185 | ) 186 | 187 | return loss 188 | 189 | def validation_step(self, batch, batch_idx): 190 | # Getting losses & accuracies 191 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch) 192 | 193 | # Total loss 194 | loss = class_loss + mlm_loss 195 | 196 | # Logging 197 | self.log_dict( 198 | { 199 | "val_loss": loss, 200 | "val_class_loss": class_loss, 201 | "val_mlm_loss": mlm_loss, 202 | "val_class_acc": c_acc, 203 | "val_mlm_acc": m_acc, 204 | }, 205 | sync_dist=True, 206 | ) 207 | 208 | return loss 209 | 210 | def test_step(self, batch, batch_idx): 211 | # Getting losses & accuracies 212 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch) 213 | 214 | # Total loss 215 | loss = class_loss + mlm_loss 216 | 217 | # Logging 218 | self.log_dict( 219 | { 220 | "test_loss": loss, 221 | "test_class_loss": class_loss, 222 | "test_mlm_loss": mlm_loss, 223 | "test_class_acc": c_acc, 224 | "test_mlm_acc": m_acc, 225 | }, 226 | sync_dist=True, 227 | ) 228 | 229 | return loss 230 | -------------------------------------------------------------------------------- /src/nlp/gpt/README.md: -------------------------------------------------------------------------------- 1 | # GPT 2 | 3 | ## Dataset 4 | Training BERT will download the Wikipedia dataset from March 1st, 2022 from [huggingface datasets](https://huggingface.co/datasets/wikipedia). The total disk size required for the dataset is ~ `43GB`, and it will be downloaded under your `HF_DATASETS_CACHE`. 5 | -------------------------------------------------------------------------------- /src/nlp/gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/gpt/__init__.py -------------------------------------------------------------------------------- /src/nlp/gpt/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Re-implementation of 3 | 4 | Language Models are Few-Shot Learners 5 | Brown et al. (2020) 6 | (https://arxiv.org/abs/2005.14165) 7 | 8 | on the WikiPedia dataset. 9 | """ 10 | 11 | import os 12 | from argparse import ArgumentParser 13 | 14 | import torch 15 | from torch.utils.data import DataLoader 16 | 17 | torch.set_float32_matmul_precision("medium") 18 | 19 | import pytorch_lightning as pl 20 | import transformers 21 | from datasets import load_dataset 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from pytorch_lightning.loggers import WandbLogger 24 | from transformers import GPT2Tokenizer 25 | 26 | transformers.logging.set_verbosity_error() 27 | 28 | from src.nlp.gpt.model import GPT 29 | 30 | 31 | def continue_sentences(gpt, tokenizer, max_len, file_path): 32 | """Uses the gpt model to continue sentences from a file. 
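    Each prompt is extended token by token with GPT.generate, which samples the
    next token from the softmax over the vocabulary until max_len is reached.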
Prints the continued sentences.""" 33 | file = open(file_path, "r") 34 | lines = file.readlines() 35 | lines = [line if not line.endswith("\n") else line[:-1] for line in lines] 36 | file.close() 37 | 38 | gpt = gpt.cuda() 39 | for i, line in enumerate(lines): 40 | # Preparing input 41 | input_ids = tokenizer( 42 | line, 43 | return_tensors="pt", 44 | max_length=max_len, 45 | )["input_ids"].cuda() 46 | 47 | # Generating sentence 48 | all_ids = gpt.generate(input_ids, max_len) 49 | 50 | # Decoding the sentence 51 | sentence = tokenizer.decode(all_ids.squeeze().tolist()) 52 | print(f"\n\nSentence {i+1}:") 53 | print(f"\tOriginal: {line}\n\tCompleted: {sentence}") 54 | 55 | 56 | def main(args): 57 | """ 58 | Train a GPT model on Wikipedia. 59 | Use the model to continue sentences from a file. 60 | """ 61 | # Unpacking parameters 62 | n_blocks = args["n_blocks"] 63 | n_heads = args["n_heads"] 64 | hidden_dim = args["hidden_dim"] 65 | dropout = args["dropout"] 66 | max_len = args["max_len"] 67 | batch_size = args["batch_size"] 68 | max_train_steps = args["max_train_steps"] 69 | lr = args["lr"] 70 | weight_decay = args["weight_decay"] 71 | warmup_steps = args["warmup_steps"] 72 | save_dir = args["save"] 73 | file_path = args["prompts"] 74 | seed = args["seed"] 75 | 76 | # Setting random seed 77 | pl.seed_everything(seed) 78 | 79 | # Load the dataset (wikipedia only has 'train', so we split it ourselves) 80 | train_set = load_dataset("wikipedia", "20220301.en", split="train[:98%]") 81 | val_set = load_dataset("wikipedia", "20220301.en", split="train[98%:99%]") 82 | test_set = load_dataset("wikipedia", "20220301.en", split="train[99%:]") 83 | 84 | # Setting format to torch (not striclty necessary) 85 | # train_set.set_format(type="torch", columns=["text"]) 86 | # val_set.set_format(type="torch", columns=["text"]) 87 | # test_set.set_format(type="torch", columns=["text"]) 88 | 89 | # Loading the GPT2 tokenizer 90 | added_pad_token = False 91 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 92 | if tokenizer.pad_token is None: 93 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 94 | added_pad_token = True 95 | 96 | # Tokenizing the whole dataset 97 | def collate(batch): 98 | return tokenizer( 99 | [sample["text"] for sample in batch], 100 | return_tensors="pt", 101 | truncation=True, 102 | padding="max_length", 103 | max_length=max_len, 104 | ) 105 | 106 | # Data loaders 107 | cpus = os.cpu_count() 108 | train_loader = DataLoader( 109 | train_set, 110 | batch_size=batch_size, 111 | shuffle=True, 112 | num_workers=cpus, 113 | collate_fn=collate, 114 | ) 115 | val_loader = DataLoader( 116 | val_set, 117 | batch_size=batch_size, 118 | shuffle=False, 119 | num_workers=cpus, 120 | collate_fn=collate, 121 | ) 122 | test_loader = DataLoader( 123 | test_set, 124 | batch_size=batch_size, 125 | shuffle=False, 126 | num_workers=cpus, 127 | collate_fn=collate, 128 | ) 129 | 130 | # Initialize the model 131 | vocab_size = tokenizer.vocab_size + 1 if added_pad_token else tokenizer.vocab_size 132 | gpt = GPT( 133 | vocab_size, 134 | max_len, 135 | hidden_dim, 136 | n_heads, 137 | n_blocks, 138 | dropout=dropout, 139 | train_config={ 140 | "lr": lr, 141 | "weight_decay": weight_decay, 142 | "max_train_steps": max_train_steps, 143 | "warmup_steps": warmup_steps, 144 | }, 145 | ) 146 | 147 | # Training 148 | checkpointing_freq = max_train_steps // 10 149 | wandb_logger = WandbLogger(project="Papers Re-implementations", name="GPT") 150 | wandb_logger.experiment.config.update(args) 151 | callbacks = 
[ModelCheckpoint(save_dir, monitor="val_loss", filename="best", every_n_train_steps=checkpointing_freq)] 152 | trainer = pl.Trainer( 153 | accelerator="auto", 154 | strategy="ddp", # State dict not saved with fsdp for some reason 155 | max_steps=max_train_steps, 156 | logger=wandb_logger, 157 | callbacks=callbacks, 158 | profiler="simple", 159 | val_check_interval=checkpointing_freq 160 | ) 161 | trainer.fit(gpt, train_loader, val_loader) 162 | 163 | # Testing the best model 164 | gpt = GPT.load_from_checkpoint(os.path.join(save_dir, "best.ckpt")) 165 | trainer.test(gpt, test_loader) 166 | 167 | if file_path is not None and os.path.isfile(file_path): 168 | # Generating sentences from prompts 169 | continue_sentences(gpt, tokenizer, max_len, file_path) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = ArgumentParser() 174 | 175 | # Model hyper-parameters 176 | parser.add_argument("--n_blocks", type=int, default=12) 177 | parser.add_argument("--n_heads", type=int, default=12) 178 | parser.add_argument("--hidden_dim", type=int, default=768) 179 | parser.add_argument("--dropout", type=float, default=0.1) 180 | parser.add_argument("--max_len", type=int, default=128) 181 | 182 | # Training parameters 183 | parser.add_argument("--batch_size", type=int, default=32) 184 | parser.add_argument("--max_train_steps", type=int, default=10_000) 185 | parser.add_argument("--lr", type=float, default=1e-4) 186 | parser.add_argument("--weight_decay", type=float, default=0.01) 187 | parser.add_argument("--warmup_steps", type=int, default=100) 188 | parser.add_argument("--save", type=str, default="checkpoints/gpt") 189 | 190 | # Testing parameters 191 | parser.add_argument("--prompts", type=str, default="data/nlp/gpt/prompts.txt") 192 | 193 | # Seed 194 | parser.add_argument("--seed", type=int, default=0) 195 | 196 | args = vars(parser.parse_args()) 197 | print(args) 198 | main(args) 199 | -------------------------------------------------------------------------------- /src/nlp/gpt/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning.utilities.types import STEP_OUTPUT 7 | from torch.optim import AdamW 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | from src.nlp.layers.decoder import DecoderTransformer 11 | from src.nlp.layers.embeddings import get_learnable_embedding 12 | 13 | 14 | class GPT(pl.LightningModule): 15 | DEFAULT_GPT_CONFIG = { 16 | "lr": 1e-4, 17 | "betas": (0.9, 0.999), 18 | "weight_decay": 0.01, 19 | "max_train_steps": 10_000, 20 | "warmup_steps": 100, 21 | } 22 | 23 | def __init__( 24 | self, 25 | vocab_size, 26 | max_len, 27 | hidden_dim, 28 | n_heads, 29 | depth, 30 | dropout=0.1, 31 | train_config=None, 32 | ): 33 | super(GPT, self).__init__() 34 | 35 | # Saving hyper-parameters so that they are logged 36 | self.save_hyperparameters() 37 | 38 | # Local parameters 39 | self.vocab_size = vocab_size 40 | self.max_len = max_len 41 | self.hidden_dim = hidden_dim 42 | self.n_heads = n_heads 43 | self.depth = depth 44 | self.dropout = dropout 45 | self.train_config = GPT.DEFAULT_GPT_CONFIG 46 | 47 | # Schedulers 48 | self.linear_scheduler = None 49 | self.warmup_scheduler = None 50 | 51 | # Training config 52 | if train_config is not None: 53 | self.train_config.update(train_config) 54 | 55 | # Embeddings 56 | self.embeddings = get_learnable_embedding( 57 | vocab_size, hidden_dim 58 | ) # 
nn.Embedding(vocab_size, hidden_dim) 59 | self.pos_embeddings = get_learnable_embedding( 60 | max_len, hidden_dim 61 | ) # nn.Embedding(max_len, hidden_dim) 62 | 63 | # Transformer and output layer 64 | self.transformer = DecoderTransformer( 65 | hidden_dim, n_heads, depth, dropout_p=dropout 66 | ) 67 | 68 | # Next token classifier 69 | self.classifier = nn.Sequential( 70 | nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, vocab_size) 71 | ) 72 | 73 | def forward(self, ids, attn_mask=None): 74 | # Embedding 75 | b, t = ids.shape 76 | hidden = self.embeddings(ids) 77 | hidden += self.pos_embeddings(torch.arange(t, device=ids.device)).repeat( 78 | b, 1, 1 79 | ) 80 | 81 | # Transformer 82 | hidden = self.transformer(hidden, self_attn_mask=attn_mask) 83 | 84 | # Classification 85 | return self.classifier(hidden), hidden 86 | 87 | def get_losses(self, batch): 88 | # Unpacking batch 89 | ids = batch["input_ids"] 90 | attn_mask = batch["attention_mask"] 91 | 92 | # Running forward 93 | b, t = ids.shape 94 | out, _ = self(ids, attn_mask.repeat(1, t).reshape(b, t, t).tril()) 95 | 96 | # Computing cross-entropy loss 97 | preds, labels = out[:, :-1], ids[:, 1:] 98 | preds, labels = preds[attn_mask[:, :-1] == 1], labels[attn_mask[:, :-1] == 1] 99 | ce_loss = nn.functional.cross_entropy( 100 | preds.reshape(-1, self.vocab_size), labels.reshape(-1) 101 | ) 102 | 103 | accuracy = (preds.argmax(dim=-1) == labels).float().mean() 104 | perplexity = torch.exp(ce_loss) 105 | 106 | return ce_loss, accuracy, perplexity 107 | 108 | def configure_optimizers(self): 109 | optim = AdamW( 110 | self.trainer.model.parameters(), 111 | lr=self.train_config["lr"], 112 | weight_decay=self.train_config["weight_decay"], 113 | betas=self.train_config["betas"], 114 | ) 115 | 116 | self.linear_scheduler = LambdaLR(optim, self.scheduler_fn) 117 | return {"optimizer": optim} 118 | 119 | def scheduler_fn(self, step): 120 | warmup_steps = self.train_config["warmup_steps"] 121 | max_steps = self.train_config["max_train_steps"] 122 | 123 | if step < warmup_steps: 124 | return step / warmup_steps 125 | return 1 - (step - warmup_steps) / (max_steps - warmup_steps) 126 | 127 | def on_train_batch_end( 128 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int 129 | ) -> None: 130 | if self.linear_scheduler is not None: 131 | self.linear_scheduler.step() 132 | 133 | return super().on_train_batch_end(outputs, batch, batch_idx) 134 | 135 | def training_step(self, batch, batch_idx): 136 | # Getting losses & accuracies 137 | ce_loss, accuracy, perplexity = self.get_losses(batch) 138 | 139 | # Logging 140 | self.log_dict( 141 | { 142 | "lr": self.optimizers().param_groups[0]["lr"], 143 | "train_loss": ce_loss, 144 | "train_acc": accuracy, 145 | "train_ppl": perplexity, 146 | }, 147 | sync_dist=True, 148 | ) 149 | 150 | return ce_loss 151 | 152 | def validation_step(self, batch, batch_idx): 153 | # Getting losses & accuracies 154 | ce_loss, accuracy, perplexity = self.get_losses(batch) 155 | 156 | # Logging 157 | self.log_dict( 158 | { 159 | "lr": self.optimizers().param_groups[0]["lr"], 160 | "val_loss": ce_loss, 161 | "val_acc": accuracy, 162 | "val_ppl": perplexity, 163 | }, 164 | sync_dist=True, 165 | ) 166 | 167 | return ce_loss 168 | 169 | def test_step(self, batch, batch_idx): 170 | # Getting losses & accuracies 171 | ce_loss, accuracy, perplexity = self.get_losses(batch) 172 | 173 | # Logging 174 | self.log_dict( 175 | { 176 | "lr": self.optimizers().param_groups[0]["lr"], 177 | "test_loss": ce_loss, 178 | "test_acc": accuracy, 179 | 
"test_ppl": perplexity, 180 | }, 181 | sync_dist=True, 182 | ) 183 | 184 | return ce_loss 185 | 186 | def generate(self, input_ids, max_len): 187 | # Predicting next token until max_len 188 | remaining = max_len - input_ids.shape[1] 189 | for _ in range(remaining): 190 | # Running GPT 191 | preds = self(input_ids)[0] 192 | 193 | # Getting probability of next token 194 | probs = preds[:, -1, :].softmax(dim=-1) 195 | 196 | # Sampling next token with multinomial sampling 197 | next_token = torch.multinomial(probs, num_samples=1) 198 | 199 | # Adding token to input_ids 200 | input_ids = torch.cat((input_ids, next_token), dim=-1) 201 | return input_ids 202 | -------------------------------------------------------------------------------- /src/nlp/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/layers/__init__.py -------------------------------------------------------------------------------- /src/nlp/layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Attention(nn.Module): 6 | """Single Attention Head.""" 7 | 8 | def __init__(self, dropout_p=0.1): 9 | super(Attention, self).__init__() 10 | self.softmax = nn.Softmax(dim=-1) 11 | self.dropout = nn.Dropout(dropout_p) 12 | 13 | def forward(self, q, k, v, mask=None): 14 | # Comput the attention scores by computing the dot product of queries with keys 15 | b, t, d = q.shape 16 | attn = q @ k.transpose(-2, -1) / (d**0.5) # b, nq, nk 17 | 18 | # Mask interactions that should not be captured 19 | if mask is not None: 20 | assert ( 21 | mask.shape == attn.shape 22 | ), f"Mask has shape {mask.shape} != {attn.shape}" 23 | attn = attn.masked_fill(mask == 0, float("-inf")) 24 | 25 | # Computing final output by multiplying attention scores with values 26 | attn = self.softmax(attn) 27 | out = attn @ v 28 | 29 | # Dropping out as regularization during training 30 | out = self.dropout(out) 31 | return out 32 | 33 | 34 | class MultiHeadAttention(nn.Module): 35 | """Multi-Head Attention.""" 36 | 37 | def __init__(self, n_heads, dropout_p=0.1): 38 | super(MultiHeadAttention, self).__init__() 39 | self.n_heads = n_heads 40 | self.heads = nn.ModuleList([Attention(dropout_p) for _ in range(self.n_heads)]) 41 | 42 | def forward(self, q, k, v, mask=None): 43 | # Check that dimensionalities are divisible by the number of heads 44 | b, nq, d = q.shape 45 | b, nk, dv = v.shape 46 | assert ( 47 | d % self.n_heads == 0 48 | ), f"{d}-dimensional query cannot be broken into {self.n_heads} heads." 49 | assert ( 50 | dv % self.n_heads == 0 51 | ), f"{dv}-dimensional value cannot be broken into {self.n_heads} heads." 
52 | 53 | # Computing attention in all sub-vectors 54 | qk_dim_per_head = int(d / self.n_heads) 55 | v_dim_per_head = int(dv / self.n_heads) 56 | out = torch.cat( 57 | [ 58 | head( 59 | q[:, :, i * qk_dim_per_head : (i + 1) * qk_dim_per_head], 60 | k[:, :, i * qk_dim_per_head : (i + 1) * qk_dim_per_head], 61 | v[:, :, i * v_dim_per_head : (i + 1) * v_dim_per_head], 62 | mask, 63 | ) 64 | for i, head in enumerate(self.heads) 65 | ], 66 | dim=-1, 67 | ) 68 | 69 | return out 70 | -------------------------------------------------------------------------------- /src/nlp/layers/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.nlp.layers.attention import MultiHeadAttention 5 | from src.nlp.layers.mlp import MLP 6 | 7 | 8 | class DecoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | hidden_dim, 12 | n_heads, 13 | mlp_hidden=None, 14 | mlp_out=None, 15 | mlp_activation=nn.GELU(), 16 | dropout_p=0.1, 17 | with_xa=False, 18 | ): 19 | super(DecoderBlock, self).__init__() 20 | self.with_xa = with_xa 21 | 22 | self.ln1 = nn.LayerNorm(hidden_dim) 23 | self.sa_qkv = nn.Linear(hidden_dim, 3 * hidden_dim) 24 | self.mhsa = MultiHeadAttention(n_heads, dropout_p) 25 | self.sa_o = nn.Linear(hidden_dim, hidden_dim) 26 | 27 | if with_xa: 28 | self.ln2 = nn.LayerNorm(hidden_dim) 29 | self.xa_q = nn.Linear(hidden_dim, hidden_dim) 30 | self.xa_kv = nn.Linear(hidden_dim, 2 * hidden_dim) 31 | self.mhxa = MultiHeadAttention(n_heads, dropout_p) 32 | self.xa_o = nn.Linear(hidden_dim, hidden_dim) 33 | 34 | self.ln3 = nn.LayerNorm(hidden_dim) 35 | self.mlp = MLP(hidden_dim, mlp_hidden, mlp_out, mlp_activation, dropout_p) 36 | 37 | def forward(self, x, kv=None, self_attn_mask=None, cross_attn_mask=None): 38 | # Self-attention and residual part 39 | q, k, v = self.sa_qkv(self.ln1(x)).chunk(3, -1) 40 | sa_out = self.sa_o(self.mhsa(q, k, v, self_attn_mask)) 41 | x = x + sa_out 42 | 43 | # Cross-attention and residual part 44 | if self.with_xa and kv is not None: 45 | q = self.xa_q(self.ln2(x)) 46 | k, v = self.xa_kv(kv).chunk(2, -1) 47 | xa_out = self.xa_o(self.mhxa(q, k, v, cross_attn_mask)) 48 | x = x + xa_out 49 | 50 | # MLP and residual part 51 | out = x + self.mlp(self.ln3(x)) 52 | return out 53 | 54 | 55 | class DecoderTransformer(nn.Module): 56 | def __init__( 57 | self, 58 | hidden_dim, 59 | n_heads, 60 | depth, 61 | mlp_hidden=None, 62 | mlp_out=None, 63 | mlp_activation=nn.GELU(), 64 | dropout_p=0.1, 65 | with_xa=False, 66 | ): 67 | super(DecoderTransformer, self).__init__() 68 | 69 | self.blocks = nn.ModuleList( 70 | [ 71 | DecoderBlock( 72 | hidden_dim, 73 | n_heads, 74 | mlp_hidden, 75 | mlp_out, 76 | mlp_activation, 77 | dropout_p, 78 | with_xa, 79 | ) 80 | for _ in range(depth) 81 | ] 82 | ) 83 | 84 | def forward(self, hidden, kv=None, self_attn_mask=None, cross_attn_mask=None): 85 | # Creating causal mask if not provided 86 | if self_attn_mask is None: 87 | b, l, d = hidden.shape 88 | self_attn_mask = torch.tril(torch.ones(l, l, device=hidden.device)).repeat( 89 | b, 1, 1 90 | ) 91 | 92 | # Running blocks 93 | for block in self.blocks: 94 | hidden = block(hidden, kv, self_attn_mask, cross_attn_mask) 95 | 96 | return hidden 97 | -------------------------------------------------------------------------------- /src/nlp/layers/embeddings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def 
get_learnable_embedding(n, hidden_dim): 7 | return nn.Embedding(n, hidden_dim) 8 | 9 | 10 | def get_sinusoidal_embedding(n, hidden_dim): 11 | emb = nn.Embedding(n, hidden_dim) 12 | 13 | def arg(pos, coord): 14 | return pos / np.power(10000, coord / hidden_dim) 15 | 16 | weight = [ 17 | [ 18 | np.sin(arg(pos, coord)) if coord % 2 == 0 else np.cos(arg(pos, coord)) 19 | for coord in range(hidden_dim) 20 | ] 21 | for pos in range(n) 22 | ] 23 | 24 | emb.weight.data.copy_(torch.tensor(weight)) 25 | emb.requires_grad_(False) 26 | return emb 27 | 28 | 29 | def get_rope_embedding(n, hidden_dim): 30 | # TODO ... 31 | pass 32 | 33 | 34 | if __name__ == "__main__": 35 | import matplotlib.pyplot as plt 36 | 37 | sin_emb = get_sinusoidal_embedding(256, 768) 38 | sin_emb = sin_emb.cpu().weight.data.numpy() 39 | plt.imshow(sin_emb) 40 | plt.xlabel("Hidden Dimension") 41 | plt.ylabel("Position") 42 | plt.title("Sinusoidal Embedding") 43 | plt.show() 44 | 45 | plt.imshow(sin_emb @ sin_emb.T) 46 | plt.xlabel("Position") 47 | plt.ylabel("Position") 48 | plt.title("Dot product of Sinusoidal Embeddings") 49 | plt.show() 50 | -------------------------------------------------------------------------------- /src/nlp/layers/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.nlp.layers.attention import MultiHeadAttention 5 | from src.nlp.layers.mlp import MLP 6 | 7 | 8 | class EncoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | hidden_dim, 12 | n_heads, 13 | mlp_hidden=None, 14 | mlp_out=None, 15 | mlp_activation=nn.GELU(), 16 | dropout_p=0.1, 17 | ): 18 | super(EncoderBlock, self).__init__() 19 | 20 | self.ln1 = nn.LayerNorm(hidden_dim) 21 | self.ln2 = nn.LayerNorm(hidden_dim) 22 | 23 | self.to_qkv = nn.Linear(hidden_dim, 3 * hidden_dim) 24 | self.to_o = nn.Linear(hidden_dim, hidden_dim) 25 | 26 | self.mhsa = MultiHeadAttention(n_heads, dropout_p) 27 | self.mlp = MLP(hidden_dim, mlp_hidden, mlp_out, mlp_activation, dropout_p) 28 | 29 | def forward(self, x, attn_mask=None): 30 | # Attention and residual connection 31 | q, k, v = self.to_qkv(self.ln1(x)).chunk(3, -1) 32 | attn_out = self.to_o(self.mhsa(q, k, v, mask=attn_mask)) 33 | x = x + attn_out 34 | 35 | # MLP and residual connection 36 | mlp_out = self.mlp(self.ln2(x)) 37 | out = x + mlp_out 38 | 39 | return out 40 | 41 | 42 | class EncoderTransformer(nn.Module): 43 | def __init__( 44 | self, 45 | hidden_dim, 46 | n_heads, 47 | depth, 48 | mlp_hidden=None, 49 | mlp_out=None, 50 | mlp_activation=nn.GELU(), 51 | dropout_p=0.1, 52 | ): 53 | super(EncoderTransformer, self).__init__() 54 | 55 | self.blocks = nn.ModuleList( 56 | [ 57 | EncoderBlock( 58 | hidden_dim, n_heads, mlp_hidden, mlp_out, mlp_activation, dropout_p 59 | ) 60 | for _ in range(depth) 61 | ] 62 | ) 63 | 64 | def forward(self, hidden, attn_mask=None): 65 | # Creating full attention mask if not provided 66 | if attn_mask is None: 67 | b, l, d = hidden.shape 68 | attn_mask = torch.ones((l, l), device=hidden.device).repeat(b, 1, 1) 69 | 70 | # Running blocks 71 | for block in self.blocks: 72 | hidden = block(hidden, attn_mask) 73 | return hidden 74 | -------------------------------------------------------------------------------- /src/nlp/layers/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLP(nn.Module): 5 | def __init__( 6 | self, in_dim, hidden_dim=None, out_dim=None, activation=nn.GELU(), drop_p=0.1 7 | 
) -> None: 8 | super(MLP, self).__init__() 9 | 10 | self.in_dim = in_dim 11 | self.hidden_dim = hidden_dim if hidden_dim is not None else in_dim * 4 12 | self.out_dim = out_dim if out_dim is not None else in_dim 13 | 14 | self.linear1 = nn.Linear(self.in_dim, self.hidden_dim) 15 | self.linear2 = nn.Linear(self.hidden_dim, self.out_dim) 16 | self.activation = activation 17 | self.dropout = nn.Dropout(drop_p) 18 | 19 | def forward(self, x): 20 | out = self.linear1(x) 21 | out = self.activation(out) 22 | out = self.linear2(out) 23 | out = self.dropout(out) 24 | return out 25 | -------------------------------------------------------------------------------- /src/nlp/lm_is_compression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/lm_is_compression/__init__.py -------------------------------------------------------------------------------- /src/nlp/lm_is_compression/lm_is_compression.py: -------------------------------------------------------------------------------- 1 | """Re-implementation of 2 | 3 | Language Modeling Is Compression 4 | (https://arxiv.org/abs/2309.10668) 5 | 6 | by Delétang, Ruoss et. al. 7 | """ 8 | 9 | from argparse import ArgumentParser 10 | 11 | import torch 12 | import torch.nn as nn 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | 16 | def set_reproducibility(seed=0): 17 | torch.manual_seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class ArithmeticEncoder: 23 | """The arithmetic encoder converts a sequence of tokens into a single 24 | number according to the probability distribution of the next token given 25 | by the language model.""" 26 | 27 | def __init__(self, model: nn.Module, bos_token_id: int): 28 | self.bos_token_id = bos_token_id 29 | 30 | self.model = model.eval() 31 | self.softmax = nn.Softmax(dim=-1) 32 | 33 | def __call__(self, ids: torch.Tensor) -> bytes: 34 | return self.encode(ids) 35 | 36 | def highs_lows_to_lambdas( 37 | self, highs: torch.Tensor, lows: torch.Tensor 38 | ) -> torch.Tensor: 39 | # Now returning the midpoints 40 | # TODO: Return the numbers with the shortest binary representation 41 | return (highs + lows) / 2 42 | 43 | @torch.no_grad() 44 | def encode(self, ids: torch.Tensor) -> bytes: 45 | """Encode a sequence of tokens into a binary sequence of bits. 46 | The encoding is done by finding the scalar number in range [0, 1] 47 | using arithmetic encoding based on the probability distribution of the 48 | next token given by the language model. 49 | 50 | Args: 51 | ids (torch.Tensor): Two-dimensional tensor of token ids. Omit the BOS token. 
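
        Illustrative walk-through (made-up probabilities, not from a real model):
        if the next-token distribution over a 3-token vocabulary were always
        (0.5, 0.3, 0.2), encoding the ids [1, 0] would narrow the interval
        [0, 1) to [0.5, 0.8) after the first token and to [0.5, 0.65) after the
        second, so the returned midpoint would be 0.575.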
52 | 53 | Returns: 54 | bytes: The encoded sequence of bits 55 | """ 56 | # Appending the BOS token to the beginning of each sequence 57 | bos_tokens = torch.full( 58 | (ids.shape[0], 1), self.bos_token_id, dtype=torch.long, device=ids.device 59 | ) 60 | ids = torch.cat([bos_tokens, ids], dim=1) 61 | N, T = ids.shape 62 | 63 | # Getting the probabilities of the next token 64 | logits = self.model(ids)["logits"] 65 | probs = self.softmax(logits) 66 | 67 | # Find the lambda number for each sequence 68 | lows, highs = torch.zeros(N, dtype=torch.double), torch.ones( 69 | N, dtype=torch.double 70 | ) 71 | for i in range(T - 1): 72 | intervals = highs - lows 73 | 74 | # Getting cumulative probabilities 75 | # TODO: Parallelize this loop 76 | c_probs = torch.empty(N) 77 | n_probs = torch.empty(N) 78 | for j in range(N): 79 | c_probs[j] = probs[j, i, : ids[j, i + 1]].sum() 80 | n_probs[j] = probs[j, i, : ids[j, i + 1] + 1].sum() 81 | 82 | # Updating lows and highs 83 | highs = lows + intervals * n_probs 84 | lows = lows + intervals * c_probs 85 | 86 | # Return the lambda numbers 87 | return self.highs_lows_to_lambdas(highs, lows) 88 | 89 | @torch.no_grad() 90 | def decode(self, lambdas: torch.Tensor, atol=1e-30) -> torch.Tensor: 91 | """Undo the encoding and, given scalar lambdas, return the original input ids.""" 92 | N, dev = lambdas.shape[0], lambdas.device 93 | ids = torch.full((N, 1), self.bos_token_id, dtype=torch.long, device=dev) 94 | 95 | # Recovering the ids 96 | lows, highs = torch.zeros(N, dtype=torch.double, device=dev), torch.ones( 97 | N, dtype=torch.double, device=dev 98 | ) 99 | while not torch.allclose( 100 | self.highs_lows_to_lambdas(highs, lows), lambdas, atol=atol 101 | ): 102 | probs = self.softmax(self.model(ids)["logits"][:, -1]) 103 | 104 | next_ids = torch.empty(N, dtype=torch.long, device=lambdas.device) 105 | for i in range(N): 106 | lamb = lambdas[i] 107 | low, high = lows[i], highs[i] 108 | for j in range(probs.shape[1]): 109 | l = low + (high - low) * probs[i, :j].sum() 110 | u = low + (high - low) * probs[i, : j + 1].sum() 111 | 112 | if l <= lamb < u: 113 | highs[i], lows[i] = u, l 114 | next_ids[i] = j 115 | break 116 | 117 | ids = torch.cat([ids, next_ids.unsqueeze(1)], dim=1) 118 | 119 | return ids 120 | 121 | def to(self, device): 122 | self.model.to(device) 123 | self.softmax.to(device) 124 | return self 125 | 126 | 127 | def main(args): 128 | # Getting program parameters 129 | model_ckpt = args["model"] 130 | seed = args["seed"] 131 | 132 | # Setting reproducibility 133 | set_reproducibility(seed) 134 | 135 | # Preparing sentences to encode 136 | sentences = ["The quick brown fox jumps over the lazy dog."] 137 | 138 | # Loading model and tokenizer 139 | model = AutoModelForCausalLM.from_pretrained( 140 | model_ckpt, torch_dtype=torch.float32, device_map="auto" 141 | ).eval() 142 | tokenizer = AutoTokenizer.from_pretrained(model_ckpt) 143 | 144 | if tokenizer.pad_token is None: 145 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 146 | model.resize_token_embeddings(len(tokenizer)) 147 | 148 | # Encoding sentences 149 | ids = tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"].cuda() 150 | encoder = ArithmeticEncoder(model, tokenizer.bos_token_id) 151 | encoded = encoder(ids) 152 | decoded = encoder.decode(encoded.cuda()) 153 | 154 | # Printing results 155 | print("\n\nOriginal sentences:", sentences) 156 | print( 157 | "Decoded sentences:", tokenizer.batch_decode(decoded, skip_special_tokens=True) 158 | ) 159 | 160 | print("\n\nOriginal 
ids:", ids) 161 | print("Decoded ids:", decoded[:, 1:]) 162 | 163 | print("\n\nEncoded sentences (as scalars):", encoded.cpu().numpy()) 164 | 165 | 166 | if __name__ == "__main__": 167 | parser = ArgumentParser() 168 | parser.add_argument("--model", type=str, default="EleutherAI/pythia-1.4b-v0") 169 | parser.add_argument("--seed", type=int, default=0) 170 | args = vars(parser.parse_args()) 171 | print(args) 172 | main(args) 173 | -------------------------------------------------------------------------------- /src/nlp/lm_watermarking/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from accelerate import Accelerator 3 | from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed 4 | 5 | from src.nlp.lm_watermarking.watermarking import ( 6 | detect_watermark, 7 | generate, 8 | get_perplexities, 9 | ) 10 | 11 | 12 | class GPT2Wrapper(torch.nn.Module): 13 | """A wrapper around the GPT2 model to take ids as input and return logits as output.""" 14 | 15 | def __init__(self): 16 | super(GPT2Wrapper, self).__init__() 17 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2") 18 | self.model = GPT2LMHeadModel.from_pretrained("gpt2") 19 | 20 | def forward(self, input_ids): 21 | outputs = self.model(input_ids) 22 | return outputs.logits 23 | 24 | 25 | def main(): 26 | """Plots the perplexity of the GPT2 model and the z-static for sentences generated with and without watermarking.""" 27 | # Setting seed 28 | set_seed(0) 29 | 30 | # Device 31 | device = Accelerator().device 32 | 33 | # Language Model (GPT2) 34 | model = GPT2Wrapper().to(device) 35 | vocab_size = model.tokenizer.vocab_size 36 | 37 | # Prior text 38 | prior = model.tokenizer("Some text to be continued", return_tensors="pt")[ 39 | "input_ids" 40 | ].to(device) 41 | 42 | # A sentence generated without watermarking 43 | normal_ids = generate(model, prior, max_length=200, watermark=False) 44 | n_ppl = get_perplexities(model, normal_ids).item() # Perplexity 45 | n_z = detect_watermark(normal_ids, vocab_size).item() # Z-statistic 46 | 47 | # A sentence generated with watermarking 48 | watermarked_ids = generate(model, prior, max_length=200, watermark=True) 49 | w_ppl = get_perplexities(model, watermarked_ids).item() # Perplexity 50 | w_z = detect_watermark(watermarked_ids, vocab_size).item() # Z-statistic 51 | 52 | # Showing non-watermarked text, PPL and probability of watermark 53 | print( 54 | f"\n\n\033[92mNormal text (PPL = {n_ppl:.2f}, Z-statistic = {n_z:.2f})\033[0m:\n" 55 | ) 56 | print(model.tokenizer.decode(normal_ids[0])) 57 | 58 | # Showing watermarked text, PPL and probability of watermark 59 | print(f"\n\n\033[93mWM text (PPL = {w_ppl:.2f}, Z-statistic = {w_z:.2f})\033[0m:\n") 60 | print(model.tokenizer.decode(watermarked_ids[0])) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /src/nlp/lm_watermarking/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from tqdm import tqdm 3 | 4 | plt.rcParams.update({"font.size": 22}) 5 | from argparse import ArgumentParser 6 | 7 | import torch 8 | from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed 9 | 10 | from src.nlp.lm_watermarking.watermarking import ( 11 | detect_watermark, 12 | generate, 13 | get_perplexities, 14 | ) 15 | 16 | 17 | def parse_args(): 18 | parser = ArgumentParser() 19 | 20 | parser.add_argument( 21 | "--n_sentences", 
type=int, default=128, help="Number of sentences to generate" 22 | ) 23 | parser.add_argument( 24 | "--seq_len", type=int, default=200, help="Length of the generated sentences" 25 | ) 26 | parser.add_argument( 27 | "--batch_size", type=int, default=16, help="Batch size for the generation" 28 | ) 29 | parser.add_argument( 30 | "--gamma", type=float, default=0.5, help="Green list proportion" 31 | ) 32 | parser.add_argument( 33 | "--delta", 34 | type=float, 35 | default=2, 36 | help="Amount to add to the logits of the model when watermarking", 37 | ) 38 | parser.add_argument( 39 | "--device", type=int, default=0, help="Device to use for generation" 40 | ) 41 | parser.add_argument("--seed", type=int, default=0, help="Seed for the generation") 42 | 43 | return vars(parser.parse_args()) 44 | 45 | 46 | class GPT2Wrapper(torch.nn.Module): 47 | """A wrapper around the GPT2 model to take ids as input and return logits as output.""" 48 | 49 | def __init__(self): 50 | super(GPT2Wrapper, self).__init__() 51 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2") 52 | self.model = GPT2LMHeadModel.from_pretrained("gpt2") 53 | 54 | def forward(self, input_ids): 55 | outputs = self.model(input_ids) 56 | return outputs.logits 57 | 58 | 59 | def main(): 60 | # Plotting parameters 61 | args = parse_args() 62 | n_sentences = args["n_sentences"] 63 | seq_len = args["seq_len"] 64 | batch_size = args["batch_size"] 65 | gamma = args["gamma"] 66 | delta = args["delta"] 67 | seed = args["seed"] 68 | 69 | # Setting seed 70 | set_seed(seed) 71 | 72 | # Device 73 | device = torch.device( 74 | "cuda:" + str(args["device"]) if torch.cuda.is_available() else "cpu" 75 | ) 76 | 77 | # Model 78 | model = GPT2Wrapper().to(device) 79 | vocab_size = model.tokenizer.vocab_size 80 | 81 | # Prior text (BOS token) 82 | prior = ( 83 | (model.tokenizer.bos_token_id * torch.ones((n_sentences, 1))).long().to(device) 84 | ) 85 | 86 | # Collecting generations with and without watermark 87 | regular_ppls, regular_z_scores = [], [] 88 | watermarked_ppls, watermarked_z_scores = [], [] 89 | for i in tqdm(range(0, n_sentences, batch_size), desc="Generating sentences"): 90 | batch = prior[i : i + batch_size] 91 | 92 | # Regular sentences 93 | regular = generate(model, batch, max_length=seq_len, watermark=False) 94 | regular_ppls.extend(get_perplexities(model, regular).tolist()) 95 | regular_z_scores.extend(detect_watermark(regular, vocab_size).tolist()) 96 | 97 | # Watermarked sentences 98 | watermarked = generate( 99 | model, batch, max_length=seq_len, watermark=True, gamma=gamma, delta=delta 100 | ) 101 | watermarked_ppls.extend(get_perplexities(model, watermarked).tolist()) 102 | watermarked_z_scores.extend(detect_watermark(watermarked, vocab_size).tolist()) 103 | 104 | # Scatter plot of perplexity vs z-score 105 | plt.figure(figsize=(10, 10)) 106 | plt.scatter(regular_ppls, regular_z_scores, label="Regular") 107 | plt.scatter(watermarked_ppls, watermarked_z_scores, label="Watermarked") 108 | plt.legend() 109 | plt.title("Perplexity vs Z-score") 110 | plt.xlabel("Perplexity") 111 | plt.ylabel("Z-score") 112 | plt.savefig( 113 | f"perplexity_vs_zscore_(n={n_sentences}, seq_len={seq_len}, gamma={gamma}, delta={delta}, seed={seed}).png" 114 | ) 115 | plt.show() 116 | print("Program completed successfully!") 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /src/nlp/lm_watermarking/watermarking.py: 
-------------------------------------------------------------------------------- 1 | """Implementation of watermarking for language models as proposed in the paper 2 | 3 | "A Watermark for Large Language Models" 4 | by Kirchenbauer, Geiping et. al. (https://arxiv.org/abs/2301.10226v2). 5 | """ 6 | 7 | from hashlib import sha256 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def default_hash_fn(tensor): 14 | """Returns the hash of the given tensor using the sha256 algorithm and by converting the tensor to a string first. 15 | 16 | Args: 17 | tensor: The tensor to hash. 18 | 19 | Returns: 20 | The hash of the tensor. 21 | """ 22 | return int(sha256(str(tensor).encode("utf-8")).hexdigest(), 16) % (10**8) 23 | 24 | 25 | @torch.no_grad() 26 | def generate( 27 | model, 28 | prior_tokens, 29 | max_length=200, 30 | watermark=True, 31 | gamma=0.5, 32 | delta=2, 33 | hash_function=default_hash_fn, 34 | ): 35 | """Generates text with the given model. Optionally, the text can be (soft) watermarked. 36 | 37 | Args: 38 | model: The model which outputs logits for the next token. 39 | prior_tokens: The input tensor from which the model starts generating of shape (B, T). 40 | max_length: The maximum length of the generated text. Default is 100. 41 | gamma: The proportion of the green list. Default is 0.5. 42 | delta: The hardness parameter. Default is 20. 43 | hash_function: The function to use for hashing. Default is default_hash_fn. 44 | 45 | Returns: 46 | The generated text of shape (B, T). 47 | """ 48 | B, T = prior_tokens.shape 49 | device = prior_tokens.device 50 | 51 | generated_tokens = prior_tokens 52 | for _ in range(max_length - T): 53 | # Getting logits 54 | l_t = model(generated_tokens)[:, -1, :] 55 | 56 | if watermark: 57 | # Seeding generators based on previous token 58 | seeds = [hash_function(generated_tokens[i, -1]) for i in range(B)] 59 | generators = [ 60 | torch.Generator(device=device).manual_seed(seed) for seed in seeds 61 | ] 62 | 63 | # Increasing probability of green list indices 64 | vs = l_t.shape[-1] # Vocabulary size 65 | gls = int(gamma * vs) # Green list size 66 | gli = torch.stack( 67 | [ 68 | torch.randperm(vs, generator=generators[i], device=device) 69 | for i in range(B) 70 | ] 71 | ) # Green list indices 72 | 73 | l_t = l_t + delta * (gli < gls) 74 | 75 | # Sampling from the distribution 76 | l_t = torch.softmax(l_t, dim=-1) 77 | next_tokens = torch.multinomial(l_t, 1) 78 | generated_tokens = torch.cat([generated_tokens, next_tokens], dim=-1) 79 | 80 | return generated_tokens 81 | 82 | 83 | def detect_watermark(ids, vocab_size, gamma=0.5, hash_function=default_hash_fn): 84 | """Returns the probability that a text was created by a Language Model with watermarking. 85 | 86 | Args: 87 | ids: The tensor with the generated text indices of shape (B, T). 88 | gamma: The proportion of the green list. Default is 0.5. 89 | delta: The hardness parameter. Default is 20. 90 | hash_function: The function used for watermarking. Default is default_hash_fn. 91 | 92 | Returns: 93 | The z-statistic of the watermarking probability. 
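        Concretely (mirroring the computation below): z = (g - gamma * T) /
        sqrt(T * gamma * (1 - gamma)), where g is the number of generated tokens
        whose id falls in the green list seeded by the preceding token. Large
        positive values suggest watermarked text.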
94 | """ 95 | B, T = ids.shape 96 | device = ids.device 97 | gls = int(gamma * vocab_size) # Green list size 98 | in_green_list = torch.zeros(B, dtype=torch.float32).to( 99 | device 100 | ) # Number of tokens in the green list 101 | 102 | for i in range(T - 1): 103 | # Seeding generators based on previous token 104 | seeds = [hash_function(ids[j, i]) for j in range(B)] 105 | generators = [ 106 | torch.Generator(device=device).manual_seed(seed) for seed in seeds 107 | ] 108 | 109 | # Increasing probability of green list indices 110 | gli = torch.stack( 111 | [ 112 | torch.randperm(vocab_size, generator=generators[i], device=device) 113 | for i in range(B) 114 | ] 115 | ) # Green list indices 116 | 117 | # Counting tokens that are in the green list and adding to the total 118 | in_green_list += (gli.gather(1, ids[:, i + 1].unsqueeze(-1)) < gls).squeeze() 119 | 120 | z = (in_green_list - gamma * T) / np.sqrt(T * gamma * (1 - gamma)) 121 | return z 122 | 123 | 124 | @torch.no_grad() 125 | def get_perplexities(model, ids): 126 | """Returns the perplexities of the model for the given texts. 127 | 128 | Args: 129 | model: The model which outputs logits for the next token. 130 | ids: The tensor with the generated text indices of shape (B, T) 131 | 132 | Returns: 133 | The perplexities of the model for the given texts as a tensor of shape (B,). 134 | """ 135 | B, T = ids.shape 136 | 137 | perplexities = torch.zeros(B).to(ids.device) 138 | for i in range(T - 1): 139 | l_t = model(ids[:, : i + 1])[:, -1, :] 140 | l_t = torch.softmax(l_t, dim=-1) 141 | l_t = l_t[range(B), ids[:, i + 1]] 142 | l_t = torch.log(l_t) 143 | perplexities += l_t 144 | 145 | return torch.exp(-perplexities / (T - 1)) 146 | -------------------------------------------------------------------------------- /src/nlp/original/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/original/__init__.py -------------------------------------------------------------------------------- /src/nlp/original/data.py: -------------------------------------------------------------------------------- 1 | from os import cpu_count 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from pytorch_lightning import LightningDataModule 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | 9 | class WMT14Subset(Dataset): 10 | def __init__(self, subset, tokenizer, max_len=128, languages="de-en"): 11 | super(WMT14Subset, self).__init__() 12 | self.tokenizer = tokenizer 13 | self.max_len = max_len 14 | self.lang1, self.lang2 = languages.split("-") 15 | self.subset = subset 16 | 17 | def __len__(self): 18 | return len(self.subset) 19 | 20 | def __getitem__(self, index): 21 | return self.preprocess(self.subset[index]) 22 | 23 | def preprocess(self, sample): 24 | # Getting sentences 25 | s1 = sample["translation"][self.lang1] 26 | s2 = sample["translation"][self.lang2] 27 | 28 | # Tokenizing sentencens 29 | enc_tok = self.tokenizer( 30 | s1, 31 | return_tensors="pt", 32 | padding="max_length", 33 | truncation=True, 34 | max_length=self.max_len, 35 | ) 36 | dec_tok = self.tokenizer( 37 | s2, 38 | return_tensors="pt", 39 | padding="max_length", 40 | truncation=True, 41 | max_length=self.max_len, 42 | ) 43 | 44 | # Unpacking return values 45 | x_enc = enc_tok["input_ids"][0] 46 | x_dec = dec_tok["input_ids"][0] 47 | enc_attn_mask = enc_tok["attention_mask"][0] 48 | dec_attn_mask = 
dec_tok["attention_mask"][0] 49 | 50 | return { 51 | "x_enc": x_enc, 52 | "x_dec": x_dec, 53 | "enc_attn_mask": enc_attn_mask, 54 | "dec_attn_mask": dec_attn_mask, 55 | "enc_dec_attn_mask": enc_attn_mask, 56 | } 57 | 58 | 59 | class WMT14Dataset(LightningDataModule): 60 | def __init__(self, tokenizer, batch_size=32, max_len=128, languages="de-en"): 61 | super(WMT14Dataset, self).__init__() 62 | self.tokenizer = tokenizer 63 | self.batch_size = batch_size 64 | self.max_len = max_len 65 | self.languages = languages 66 | 67 | def prepare_data(self): 68 | self.wmt14 = load_dataset("wmt14", self.languages) 69 | 70 | def setup(self, stage): 71 | self.train = WMT14Subset( 72 | self.wmt14["train"], self.tokenizer, self.max_len, self.languages 73 | ) 74 | self.val = WMT14Subset( 75 | self.wmt14["validation"], self.tokenizer, self.max_len, self.languages 76 | ) 77 | self.test = WMT14Subset( 78 | self.wmt14["test"], self.tokenizer, self.max_len, self.languages 79 | ) 80 | 81 | def train_dataloader(self): 82 | return DataLoader( 83 | self.train, 84 | batch_size=self.batch_size, 85 | shuffle=True, 86 | num_workers=cpu_count(), 87 | ) 88 | 89 | def val_dataloader(self): 90 | return DataLoader( 91 | self.val, batch_size=self.batch_size, shuffle=False, num_workers=cpu_count() 92 | ) 93 | 94 | def test_dataloader(self): 95 | return DataLoader( 96 | self.test, 97 | batch_size=self.batch_size, 98 | shuffle=False, 99 | num_workers=cpu_count(), 100 | ) 101 | 102 | def teardown(self, stage): 103 | pass 104 | -------------------------------------------------------------------------------- /src/nlp/original/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Re-implementation of 3 | 4 | Attention is all you need 5 | Vaswani et al. (2017) 6 | (https://arxiv.org/abs/1706.03762) 7 | 8 | on the WMT14 dataset. 
9 | """ 10 | import os 11 | from argparse import ArgumentParser 12 | 13 | import pytorch_lightning as pl 14 | import torch 15 | from pytorch_lightning.callbacks import ModelCheckpoint 16 | from pytorch_lightning.loggers import WandbLogger 17 | from transformers import FSMTTokenizer 18 | 19 | from src.nlp.original.data import WMT14Dataset 20 | from src.nlp.original.model import EncoderDecoderModel 21 | 22 | 23 | def translate_sentences(file_path, model, tokenizer, max_len=128): 24 | file = open(file_path, "r") 25 | sentences = file.readlines() 26 | file.close() 27 | 28 | model = model.cuda() 29 | dec_input = tokenizer.bos_token_id * torch.ones(1, 1) 30 | dec_input = dec_input.long().cuda() 31 | for i, sentence in enumerate(sentences): 32 | x = tokenizer( 33 | sentence, 34 | return_tensors="pt", 35 | padding="max_length", 36 | truncation=True, 37 | max_length=max_len, 38 | )["input_ids"].cuda() 39 | y = model.generate(x, dec_input, max_len)[0] 40 | translated = tokenizer.decode(y) 41 | 42 | print(f"\n\nSENTENCE {i+1}:") 43 | print(f"\tOriginal: {sentence}\n\tTranslated: {translated}") 44 | 45 | 46 | def main(args): 47 | # Unpacking arguments 48 | enc_n_blocks = args["enc_n_blocks"] 49 | enc_n_heads = args["enc_n_heads"] 50 | dec_n_blocks = args["dec_n_blocks"] 51 | dec_n_heads = args["dec_n_heads"] 52 | hidden_dim = args["hidden_dim"] 53 | dropout = args["dropout"] 54 | max_len = args["max_len"] 55 | batch_size = args["batch_size"] 56 | max_train_steps = args["max_train_steps"] 57 | lr = args["lr"] 58 | weight_decay = args["weight_decay"] 59 | warmup_steps = args["warmup_steps"] 60 | save_dir = args["save"] 61 | file_path = args["file"] 62 | seed = args["seed"] 63 | 64 | # Setting seed 65 | pl.seed_everything(seed) 66 | 67 | # Loading dataset 68 | tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en") 69 | vocab_size = tokenizer.vocab_size 70 | if tokenizer.pad_token is None: 71 | tokenizer.add_special_tokens({"pad_token": ""}) 72 | vocab_size += 1 73 | 74 | dataset = WMT14Dataset(tokenizer, batch_size, max_len) 75 | 76 | # Loading model 77 | model = EncoderDecoderModel( 78 | vocab_size, 79 | max_len, 80 | hidden_dim, 81 | enc_n_heads, 82 | enc_n_blocks, 83 | dec_n_heads, 84 | dec_n_blocks, 85 | dropout, 86 | lr, 87 | weight_decay, 88 | warmup_steps, 89 | ) 90 | 91 | # Training 92 | checkpointing_freq = max_train_steps // 10 93 | callbacks = [ 94 | ModelCheckpoint( 95 | save_dir, 96 | monitor="val_loss", 97 | filename="best", 98 | every_n_train_steps=checkpointing_freq, 99 | ) 100 | ] 101 | logger = WandbLogger(project="Papers Re-implementations", name="ORIGINAL") 102 | logger.experiment.config.update(args) 103 | trainer = pl.Trainer( 104 | devices="auto", 105 | strategy="ddp", 106 | max_steps=max_train_steps, 107 | callbacks=callbacks, 108 | logger=logger, 109 | val_check_interval=checkpointing_freq, 110 | ) 111 | trainer.fit(model, datamodule=dataset) 112 | 113 | # Testing 114 | model = EncoderDecoderModel.load_from_checkpoint( 115 | os.path.join(save_dir, "best.ckpt") 116 | ) 117 | trainer.test(model, datamodule=dataset) 118 | 119 | # Translating custom sentences 120 | translate_sentences(file_path, model, tokenizer, max_len) 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = ArgumentParser() 125 | 126 | # Model hyper-parameters 127 | parser.add_argument("--enc_n_blocks", type=int, default=12) 128 | parser.add_argument("--enc_n_heads", type=int, default=12) 129 | parser.add_argument("--dec_n_blocks", type=int, default=12) 130 | parser.add_argument("--dec_n_heads", 
type=int, default=12) 131 | parser.add_argument("--hidden_dim", type=int, default=768) 132 | parser.add_argument("--dropout", type=float, default=0.1) 133 | parser.add_argument("--max_len", type=int, default=128) 134 | 135 | # Training parameters 136 | parser.add_argument("--batch_size", type=int, default=32) 137 | parser.add_argument("--max_train_steps", type=int, default=100_000) 138 | parser.add_argument("--lr", type=float, default=1e-4) 139 | parser.add_argument("--weight_decay", type=float, default=0.01) 140 | parser.add_argument("--warmup_steps", type=int, default=4000) 141 | parser.add_argument("--save", type=str, default="checkpoints/original") 142 | 143 | # Testing parameters 144 | parser.add_argument( 145 | "--file", type=str, default="data/nlp/original/translate_sentences.txt" 146 | ) 147 | 148 | # Seed 149 | parser.add_argument("--seed", type=int, default=0) 150 | 151 | args = vars(parser.parse_args()) 152 | print(args) 153 | main(args) 154 | -------------------------------------------------------------------------------- /src/nlp/original/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning.utilities.types import STEP_OUTPUT 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | from src.nlp.layers.decoder import DecoderTransformer 11 | from src.nlp.layers.embeddings import get_learnable_embedding 12 | from src.nlp.layers.encoder import EncoderTransformer 13 | 14 | 15 | class EncoderDecoderModel(pl.LightningModule): 16 | def __init__( 17 | self, 18 | vocab_size, 19 | max_len, 20 | hidden_dim, 21 | enc_n_heads, 22 | enc_n_blocks, 23 | dec_n_heads, 24 | dec_n_blocks, 25 | dropout, 26 | lr, 27 | weight_decay, 28 | warmup_steps, 29 | ): 30 | super(EncoderDecoderModel, self).__init__() 31 | 32 | # Saving hyper-parameters so that they are logged 33 | self.save_hyperparameters() 34 | 35 | # Local parameters 36 | self.vocab_size = vocab_size 37 | self.max_len = max_len 38 | self.hidden_dim = hidden_dim 39 | self.enc_n_heads = enc_n_heads 40 | self.enc_n_blocks = enc_n_blocks 41 | self.dec_n_heads = dec_n_heads 42 | self.dec_n_blocks = dec_n_blocks 43 | self.dropout = dropout 44 | self.lr = lr 45 | self.weight_decay = weight_decay 46 | self.warmup_steps = warmup_steps 47 | self.scheduler = None 48 | 49 | # Embeddings (note: we're learning embeddings for both languages in MT) 50 | self.embedding = get_learnable_embedding(vocab_size, hidden_dim) 51 | self.enc_pos_embedding = get_learnable_embedding(max_len, hidden_dim) 52 | self.dec_pos_embedding = get_learnable_embedding(max_len, hidden_dim) 53 | 54 | # Encoder and decoder models 55 | self.enc_transformer = EncoderTransformer( 56 | hidden_dim, enc_n_heads, enc_n_blocks, dropout_p=dropout 57 | ) 58 | self.dec_transformer = DecoderTransformer( 59 | hidden_dim, dec_n_heads, dec_n_blocks, dropout_p=dropout, with_xa=True 60 | ) 61 | 62 | # Decoding head 63 | self.head = nn.Sequential( 64 | nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, vocab_size) 65 | ) 66 | 67 | def scheduler_fn(self, step): 68 | step += 1 69 | return self.hidden_dim ** (-0.5) * min( 70 | step ** (-0.5), step * self.warmup_steps ** (-1.5) 71 | ) 72 | 73 | def configure_optimizers(self): 74 | optim = Adam( 75 | self.trainer.model.parameters(), 76 | lr=self.lr, 77 | weight_decay=self.weight_decay, 78 | betas=(0.9, 0.98), 79 | eps=1e-9, 80 | ) 81 | self.scheduler =
LambdaLR(optim, self.scheduler_fn) 82 | return optim 83 | 84 | def forward( 85 | self, ids_enc, ids_dec, enc_attn_mask, dec_attn_mask, enc_dec_attn_mask 86 | ): 87 | assert ids_enc.shape[0] == ids_dec.shape[0] 88 | 89 | enc_out = self.forward_enc(ids_enc, enc_attn_mask) 90 | dec_out = self.forward_dec(ids_dec, dec_attn_mask, enc_dec_attn_mask, enc_out) 91 | 92 | return self.head(dec_out), enc_out, dec_out 93 | 94 | def forward_enc(self, ids_enc, enc_attn_mask): 95 | b, t = ids_enc.shape 96 | x_enc = self.embedding(ids_enc) + self.enc_pos_embedding( 97 | torch.arange(t).cuda() 98 | ).repeat(b, 1, 1) 99 | enc_out = self.enc_transformer(x_enc, enc_attn_mask) 100 | return enc_out 101 | 102 | def forward_dec(self, ids_dec, dec_attn_mask, enc_dec_attn_mask, enc_out): 103 | b, t = ids_dec.shape 104 | x_dec = self.embedding(ids_dec) + self.dec_pos_embedding( 105 | torch.arange(t).cuda() 106 | ).repeat(b, 1, 1) 107 | dec_out = self.dec_transformer( 108 | x_dec, 109 | kv=enc_out, 110 | self_attn_mask=dec_attn_mask, 111 | cross_attn_mask=enc_dec_attn_mask, 112 | ) 113 | return dec_out 114 | 115 | def compute_loss(self, batch): 116 | ids_enc = batch["x_enc"] 117 | ids_dec = batch["x_dec"] 118 | enc_attn_mask = batch["enc_attn_mask"] 119 | dec_attn_mask = batch["dec_attn_mask"] 120 | enc_dec_attn_mask = batch["enc_dec_attn_mask"] 121 | 122 | b, te = ids_enc.shape 123 | td = ids_dec.shape[1] 124 | 125 | y_pred, _, _ = self( 126 | ids_enc, 127 | ids_dec, 128 | enc_attn_mask.repeat(1, te).reshape(b, te, te), 129 | dec_attn_mask.repeat(1, td).reshape(b, td, td).tril(), 130 | enc_dec_attn_mask.repeat(1, td).reshape(b, td, td), 131 | ) 132 | 133 | y_pred = y_pred[dec_attn_mask == 1][:-1] 134 | y = ids_dec[dec_attn_mask == 1][1:] 135 | 136 | loss = nn.functional.cross_entropy( 137 | y_pred.reshape(-1, self.vocab_size), y.reshape(-1) 138 | ) 139 | return loss 140 | 141 | def training_step(self, batch, batch_idx): 142 | loss = self.compute_loss(batch) 143 | self.log("train_loss", loss) 144 | return loss 145 | 146 | def on_train_batch_end( 147 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int 148 | ) -> None: 149 | if self.scheduler is not None: 150 | self.scheduler.step() 151 | 152 | return super().on_train_batch_end(outputs, batch, batch_idx) 153 | 154 | def validation_step(self, batch, batch_idx): 155 | loss = self.compute_loss(batch) 156 | self.log("val_loss", loss) 157 | return loss 158 | 159 | def test_step(self, batch, batch_idx): 160 | loss = self.compute_loss(batch) 161 | self.log("test_loss", loss) 162 | return loss 163 | 164 | def generate(self, x_enc, x_dec, max_len=None): 165 | if max_len is None: 166 | max_len = self.max_len 167 | 168 | attn_enc = attn_enc_dec = torch.ones(x_enc.shape[1]).cuda() 169 | enc_out = self.forward_enc(x_enc, attn_enc) 170 | 171 | while x_dec.shape[1] < max_len: 172 | t = x_dec.shape[-1] 173 | attn_dec = torch.tril(torch.ones(t, t)).cuda() 174 | dec_out = self.forward_dec(x_dec, attn_dec, attn_enc_dec, enc_out) 175 | probs = torch.softmax(self.head(dec_out)[:, -1, :], dim=-1) 176 | token = torch.multinomial(probs, num_samples=1) 177 | x_dec = torch.cat((x_dec, token), dim=-1) 178 | return x_dec 179 | -------------------------------------------------------------------------------- /src/nlp/tokenizers/bpe.py: -------------------------------------------------------------------------------- 1 | """Byte-Pair Encoding (BPE) tokenizer.""" 2 | 3 | from tqdm.auto import tqdm 4 | 5 | 6 | def bpe_train(corpus, vocab_size): 7 | """Uses the BPE algorithm to train a tokenizer on 
a corpus. Returns the vocabulary.""" 8 | # Initialize the vocabulary 9 | corpus = [ 10 | l if word[0] == l else "##" + l for word in corpus.split(" ") for l in word 11 | ] 12 | vocab = set(corpus) 13 | 14 | # Keep merging most likely pairs until the vocabulary is the desired size 15 | for i in tqdm(range(vocab_size - len(vocab)), desc="Training bpe tokenizer..."): 16 | counts = {} 17 | 18 | # Count how many times each pair appears 19 | for i in range(len(corpus) - 1): 20 | pair = corpus[i] + corpus[i + 1] 21 | counts[pair] = counts.get(pair, 0) + 1 22 | 23 | # Find the pair that appeared the most 24 | max_pair, max_count = None, -1 25 | for k in counts: 26 | if counts[k] > max_count: 27 | max_pair, max_count = k, counts[k] 28 | 29 | # Adding pair to vocabulary 30 | vocab.add(max_pair) 31 | 32 | # Updating corpus 33 | new_corpus, added = [], False 34 | for i in range(len(corpus) - 1): 35 | if added: 36 | added = False 37 | continue 38 | 39 | if corpus[i] + corpus[i + 1] == max_pair: 40 | new_corpus.append(max_pair) 41 | added = True 42 | else: 43 | new_corpus.append(corpus[i]) 44 | 45 | corpus = new_corpus 46 | 47 | # Remove the "##" prefix and return vocabulary 48 | vocab = set( 49 | [ 50 | ("##" if elem.startswith("##") else "") + elem.replace("##", "") 51 | for elem in vocab 52 | ] 53 | ) 54 | return vocab 55 | 56 | 57 | if __name__ == "__main__": 58 | # Example 59 | corpus = "machine learning and meta learning allow machines to learn how to learn" 60 | vocabulary = bpe_train(corpus, 30) 61 | print(vocabulary) 62 | -------------------------------------------------------------------------------- /src/nlp/tokenizers/wordpiece.py: -------------------------------------------------------------------------------- 1 | """Wordpiece tokenizer.""" 2 | 3 | from tqdm.auto import tqdm 4 | 5 | 6 | def wordpiece_train(corpus, vocab_size): 7 | """Uses the wordpiece algorithm to train a tokenizer on a corpus. Returns the vocabulary.""" 8 | # Initialize the vocabulary with letters 9 | corpus = [ 10 | l if word[0] == l else "##" + l for word in corpus.split(" ") for l in word 11 | ] 12 | vocab = set(corpus) 13 | 14 | # Keep merging most likely pairs until the vocabulary is the desired size 15 | for i in tqdm( 16 | range(vocab_size - len(vocab)), desc="Training wordpiece tokenizer..." 
17 | ): 18 | counts = {} 19 | pair_counts = {} 20 | 21 | # Keep count of each word and each pair 22 | for i in range(len(corpus)): 23 | counts[corpus[i]] = counts.get(corpus[i], 0) + 1 24 | 25 | if i == len(corpus) - 1: 26 | continue 27 | 28 | pair = (corpus[i], corpus[i + 1]) 29 | pair_counts[pair] = pair_counts.get(pair, 0) + 1 30 | 31 | # Find the pair that has the highest score 32 | # The score is count(w1, w2) / (count(w1) * count(w2)) 33 | max_pair, max_score = None, -1 34 | for w1, w2 in pair_counts: 35 | p_count = pair_counts[(w1, w2)] 36 | pair_score = p_count / (counts[w1] * counts[w2]) 37 | if pair_score > max_score: 38 | max_pair, max_score = w1 + w2, pair_score 39 | 40 | # Add the pair with the highest score to the vocabulary 41 | vocab.add(max_pair) 42 | 43 | # Update the corpus by merging the pair 44 | new_corpus, added = [], False 45 | for i in range(len(corpus) - 1): 46 | if added: 47 | added = False 48 | continue 49 | 50 | if corpus[i] + corpus[i + 1] == max_pair: 51 | new_corpus.append(max_pair) 52 | added = True 53 | else: 54 | new_corpus.append(corpus[i]) 55 | 56 | corpus = new_corpus 57 | 58 | # Remove the "##" prefix and return vocabulary 59 | vocab = set( 60 | [ 61 | ("##" if elem.startswith("##") else "") + elem.replace("##", "") 62 | for elem in vocab 63 | ] 64 | ) 65 | return vocab 66 | 67 | 68 | if __name__ == "__main__": 69 | # Example 70 | corpus = "machine learning and meta learning allow machines to learn how to learn" 71 | vocabulary = wordpiece_train(corpus, 30) 72 | print(vocabulary) 73 | -------------------------------------------------------------------------------- /src/rl/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | """Reimplementation of 'Playing Atari with Deep Reinforcement Learning' by Mnih et al.
(2013)""" 2 | import os 3 | import random 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from tqdm.auto import tqdm 8 | 9 | import wandb 10 | 11 | 12 | class ReplayBuffer: 13 | def __init__(self, capacity=10000): 14 | self.capacity = capacity 15 | self.buffer = [] 16 | 17 | def push(self, state, action, reward, next_state, done): 18 | if len(self.buffer) >= self.capacity: 19 | self.buffer.pop(0) 20 | self.buffer.append((state, action, reward, next_state, done)) 21 | 22 | def sample(self, batch_size): 23 | state, action, reward, next_state, done = zip( 24 | *random.sample(self.buffer, batch_size) 25 | ) 26 | return ( 27 | torch.stack(state), 28 | torch.tensor(action), 29 | torch.tensor(reward), 30 | torch.stack(next_state), 31 | torch.tensor(done).int(), 32 | ) 33 | 34 | 35 | def dqn_training( 36 | model, 37 | environment, 38 | optimizer, 39 | gamma, 40 | epsilon, 41 | batch_size, 42 | episodes, 43 | checkpoint_path, 44 | buffer_capacity=10000, 45 | ): 46 | # Initialize run 47 | wandb.init(project="Papers Re-implementations", name="DQN") 48 | wandb.watch(model) 49 | 50 | # Create checkpoint directory 51 | os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) 52 | 53 | # Initialize accelerator 54 | accelerator = Accelerator() 55 | model, optimizer = accelerator.prepare(model, optimizer) 56 | 57 | # Initialize replay buffer 58 | buffer = ReplayBuffer(buffer_capacity) 59 | 60 | def state_process_fn(x): 61 | return torch.tensor(x).float().to(accelerator.device) 62 | 63 | def optimization_step(): 64 | # Sample replay buffer 65 | state, action, reward, next_state, done = buffer.sample(batch_size) 66 | action = action.to(accelerator.device) 67 | reward = reward.to(accelerator.device).float() 68 | done = done.to(accelerator.device) 69 | 70 | # Compute the target Q value 71 | with torch.no_grad(): 72 | target_q = reward + gamma * model(next_state).max(1)[0] * (1 - done) 73 | 74 | # Get current Q estimate 75 | current_q = model(state).gather(1, action.unsqueeze(1)) 76 | 77 | # Compute loss 78 | loss = torch.nn.functional.mse_loss(current_q, target_q.unsqueeze(1)) 79 | 80 | # Optimize the model 81 | optimizer.zero_grad() 82 | accelerator.backward(loss) 83 | accelerator.clip_grad_norm_(model.parameters(), max_norm=10.0) 84 | optimizer.step() 85 | 86 | return loss.item() 87 | 88 | # Training loop 89 | checkpoint_loss = float("inf") 90 | pbar = tqdm(range(episodes)) 91 | for ep in pbar: 92 | # Initialize episode 93 | pbar.set_description(f"Episode {ep+1}/{episodes}") 94 | state = state_process_fn(environment.reset()[0]) 95 | episode_loss, episode_reward, episode_length = 0, 0, 0 96 | 97 | done, truncated = False, False 98 | while not done and not truncated: 99 | # Act in the environment 100 | if random.random() < epsilon: 101 | action = environment.action_space.sample() 102 | else: 103 | action = model(state).argmax().item() 104 | 105 | # Update environment 106 | next_state, reward, done, truncated, _ = environment.step(action) 107 | next_state = state_process_fn(next_state) 108 | 109 | # Register transition in replay buffer 110 | buffer.push(state, action, reward, next_state, done) 111 | 112 | # Update Q-function 113 | if len(buffer.buffer) >= batch_size: 114 | loss = optimization_step() 115 | episode_reward += reward 116 | episode_loss += loss 117 | 118 | state = next_state 119 | episode_length += 1 120 | 121 | if len(buffer.buffer) >= batch_size: 122 | # Log episode stats 123 | wandb.log( 124 | { 125 | "loss": episode_loss, 126 | "reward": episode_reward, 127 | "ep. 
length": episode_length, 128 | } 129 | ) 130 | 131 | if len(buffer.buffer) >= batch_size and episode_loss < checkpoint_loss: 132 | torch.save(model.state_dict(), checkpoint_path) 133 | checkpoint_loss = episode_loss 134 | print( 135 | f"Checkpoint saved with loss {checkpoint_loss:.3f} at episode {ep+1} / {episodes}" 136 | ) 137 | 138 | wandb.finish() 139 | -------------------------------------------------------------------------------- /src/rl/dqn/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import gymnasium as gym 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning import seed_everything 7 | 8 | from src.rl.dqn.dqn import dqn_training 9 | 10 | DISCRETE_ACTION_ENVIRONMENTS = [ 11 | "Acrobot-v1", 12 | "CartPole-v1", 13 | "LunarLander-v2", 14 | "MountainCar-v0", 15 | ] 16 | 17 | 18 | class MLP(nn.Module): 19 | def __init__(self, input_dim, output_dim, hidden_dims=(64, 64)): 20 | super().__init__() 21 | self.relu = nn.ReLU() 22 | self.in_layer = nn.Linear(input_dim, hidden_dims[0]) 23 | self.blocks = nn.ModuleList( 24 | [ 25 | nn.Sequential( 26 | nn.LayerNorm(hidden_dims[i]), 27 | nn.Linear(hidden_dims[i], hidden_dims[i + 1]), 28 | nn.ReLU(), 29 | ) 30 | for i in range(len(hidden_dims) - 1) 31 | ] 32 | ) 33 | self.out = nn.Linear(hidden_dims[-1], output_dim) 34 | 35 | def forward(self, x): 36 | x = self.in_layer(x) 37 | for block in self.blocks: 38 | x = block(x) 39 | return self.out(x) 40 | 41 | 42 | def main(args): 43 | # Setting seed 44 | seed_everything(args["seed"]) 45 | 46 | # Unpacking args 47 | gamma = args["gamma"] 48 | epsilon = args["epsilon"] 49 | batch_size = args["batch_size"] 50 | lr = args["lr"] 51 | episodes = args["train_episodes"] 52 | checkpoint_path = args["checkpoint_path"] 53 | buffer_capacity = args["buffer_capacity"] 54 | optimizer_fn = getattr(torch.optim, args["optimizer"]) 55 | 56 | # Environment 57 | env = gym.make(args["env"]) 58 | 59 | # Training 60 | n_inputs, n_outputs = env.observation_space.shape[0], env.action_space.n 61 | model = MLP(n_inputs, n_outputs) 62 | optimizer = optimizer_fn(model.parameters(), lr=lr) 63 | dqn_training( 64 | model, 65 | env, 66 | optimizer, 67 | gamma, 68 | epsilon, 69 | batch_size, 70 | episodes, 71 | checkpoint_path, 72 | buffer_capacity, 73 | ) 74 | 75 | # Showing episodes 76 | env = gym.make(args["env"], render_mode="human") 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | model = MLP(n_inputs, n_outputs) 79 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) 80 | model = model.eval().to(device) 81 | for episode in range(args["test_episodes"]): 82 | state = env.reset()[0] 83 | done = False 84 | while not done: 85 | action = model(torch.tensor(state).float().to(device)).argmax().item() 86 | state, _, done, _, _ = env.step(action) 87 | env.render() 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = ArgumentParser() 92 | parser.add_argument( 93 | "--env", 94 | type=str, 95 | default="CartPole-v1", 96 | choices=DISCRETE_ACTION_ENVIRONMENTS, 97 | ) 98 | parser.add_argument("--seed", type=int, default=0) 99 | parser.add_argument("--gamma", type=float, default=0.99) 100 | parser.add_argument("--epsilon", type=float, default=0.05) 101 | parser.add_argument("--batch_size", type=int, default=256) 102 | parser.add_argument("--lr", type=float, default=0.001) 103 | parser.add_argument("--train_episodes", type=int, default=100) 104 | parser.add_argument("--test_episodes", 
type=int, default=10) 105 | parser.add_argument("--checkpoint_path", type=str, default="checkpoints/dqn/dqn.pt") 106 | parser.add_argument("--buffer_capacity", type=int, default=1024) 107 | parser.add_argument( 108 | "--optimizer", type=str, default="Adam", choices=["Adam", "SGD"] 109 | ) 110 | args = vars(parser.parse_args()) 111 | print(args) 112 | main(args) 113 | -------------------------------------------------------------------------------- /src/rl/ppo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/rl/ppo/__init__.py -------------------------------------------------------------------------------- /src/rl/ppo/ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Personal reimplementation of 3 | Proximal Policy Optimization Algorithms 4 | (https://arxiv.org/abs/1707.06347) 5 | """ 6 | 7 | from argparse import ArgumentParser 8 | 9 | import gym 10 | import numpy as np 11 | import pytorch_lightning as pl 12 | import torch 13 | import torch.nn as nn 14 | from torch.distributions.categorical import Categorical 15 | from torch.optim import Adam 16 | from torch.optim.lr_scheduler import LinearLR 17 | 18 | import wandb 19 | 20 | # Definitions 21 | MODEL_PATH = "model.pt" 22 | 23 | 24 | def parse_args(): 25 | """Pareser program arguments""" 26 | # Parser 27 | parser = ArgumentParser() 28 | 29 | # Program arguments (default for Atari games) 30 | parser.add_argument( 31 | "--max_iterations", 32 | type=int, 33 | help="Number of iterations of training", 34 | default=100, 35 | ) 36 | parser.add_argument( 37 | "--n_actors", type=int, help="Number of actors for each update", default=8 38 | ) 39 | parser.add_argument( 40 | "--horizon", type=int, help="Number of timestamps for each actor", default=128 41 | ) 42 | parser.add_argument("--epsilon", type=float, help="Epsilon parameter", default=0.1) 43 | parser.add_argument( 44 | "--n_epochs", 45 | type=int, 46 | help="Number of training epochs per iteration", 47 | default=3, 48 | ) 49 | parser.add_argument("--batch_size", type=int, help="Batch size", default=32 * 8) 50 | parser.add_argument("--lr", type=float, help="Learning rate", default=2.5 * 1e-4) 51 | parser.add_argument( 52 | "--gamma", type=float, help="Discount factor gamma", default=0.99 53 | ) 54 | parser.add_argument( 55 | "--c1", 56 | type=float, 57 | help="Weight for the value function in the loss function", 58 | default=1, 59 | ) 60 | parser.add_argument( 61 | "--c2", 62 | type=float, 63 | help="Weight for the entropy bonus in the loss function", 64 | default=0.01, 65 | ) 66 | parser.add_argument( 67 | "--n_test_episodes", type=int, help="Number of episodes to render", default=5 68 | ) 69 | parser.add_argument( 70 | "--seed", type=int, help="Randomizing seed for the experiment", default=0 71 | ) 72 | 73 | # Dictionary with program arguments 74 | return vars(parser.parse_args()) 75 | 76 | 77 | def get_device(): 78 | """Gets the device (GPU if any) and logs the type""" 79 | if torch.cuda.is_available(): 80 | device = torch.device("cuda") 81 | print(f"Found GPU device: {torch.cuda.get_device_name(device)}") 82 | else: 83 | device = torch.device("cpu") 84 | print("No GPU found: Running on CPU") 85 | return device 86 | 87 | 88 | class MyPPO(nn.Module): 89 | """Implementation of a PPO model. 
The same backbone is used to get actor and critic values.""" 90 | 91 | def __init__(self, in_shape, n_actions, hidden_d=100, share_backbone=False): 92 | # Super constructor 93 | super(MyPPO, self).__init__() 94 | 95 | # Attributes 96 | self.in_shape = in_shape 97 | self.n_actions = n_actions 98 | self.hidden_d = hidden_d 99 | self.share_backbone = share_backbone 100 | 101 | # Shared backbone for policy and value functions 102 | in_dim = np.prod(in_shape) 103 | 104 | def to_features(): 105 | return nn.Sequential( 106 | nn.Flatten(), 107 | nn.Linear(in_dim, hidden_d), 108 | nn.ReLU(), 109 | nn.Linear(hidden_d, hidden_d), 110 | nn.ReLU(), 111 | ) 112 | 113 | self.backbone = to_features() if self.share_backbone else nn.Identity() 114 | 115 | # State action function 116 | self.actor = nn.Sequential( 117 | nn.Identity() if self.share_backbone else to_features(), 118 | nn.Linear(hidden_d, hidden_d), 119 | nn.ReLU(), 120 | nn.Linear(hidden_d, n_actions), 121 | nn.Softmax(dim=-1), 122 | ) 123 | 124 | # Value function 125 | self.critic = nn.Sequential( 126 | nn.Identity() if self.share_backbone else to_features(), 127 | nn.Linear(hidden_d, hidden_d), 128 | nn.ReLU(), 129 | nn.Linear(hidden_d, 1), 130 | ) 131 | 132 | def forward(self, x): 133 | features = self.backbone(x) 134 | action = self.actor(features) 135 | value = self.critic(features) 136 | return Categorical(action).sample(), action, value 137 | 138 | 139 | @torch.no_grad() 140 | def run_timestamps(env, model, timestamps=128, render=False, device="cpu"): 141 | """Runs the given policy on the given environment for the given amount of timestamps. 142 | Returns a buffer with state action transitions and rewards.""" 143 | buffer = [] 144 | state = env.reset()[0] 145 | 146 | # Running timestamps and collecting state, actions, rewards and terminations 147 | for ts in range(timestamps): 148 | # Taking a step into the environment 149 | model_input = torch.from_numpy(state).unsqueeze(0).to(device).float() 150 | action, action_logits, value = model(model_input) 151 | new_state, reward, terminated, truncated, info = env.step(action.item()) 152 | 153 | # Rendering / storing (s, a, r, t) in the buffer 154 | if render: 155 | env.render() 156 | else: 157 | buffer.append( 158 | [ 159 | model_input, 160 | action, 161 | action_logits, 162 | value, 163 | reward, 164 | terminated or truncated, 165 | ] 166 | ) 167 | 168 | # Updating current state 169 | state = new_state 170 | 171 | # Resetting environment if episode terminated or truncated 172 | if terminated or truncated: 173 | state = env.reset()[0] 174 | 175 | return buffer 176 | 177 | 178 | def compute_cumulative_rewards(buffer, gamma): 179 | """Given a buffer with states, policy action logits, rewards and terminations, 180 | computes the cumulative rewards for each timestamp and substitutes them into the buffer. 
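Illustrative example (made-up numbers, assuming gamma = 0.99): for a buffer holding rewards [1, 1, 1] where only the last transition is terminal, the reverse pass yields cumulative rewards [1.99, 1.0, 0.0], since this implementation resets the running reward to 0 at terminal steps; the values are then normalized to zero mean and unit variance before being written back into the buffer.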
181 | """ 182 | curr_rew = 0.0 183 | 184 | # Traversing the buffer on the reverse direction 185 | for i in range(len(buffer) - 1, -1, -1): 186 | r, t = buffer[i][-2], buffer[i][-1] 187 | 188 | if t: 189 | curr_rew = 0 190 | else: 191 | curr_rew = r + gamma * curr_rew 192 | 193 | buffer[i][-2] = curr_rew 194 | 195 | # Getting the average reward before normalizing (for logging and checkpointing) 196 | avg_rew = np.mean([buffer[i][-2] for i in range(len(buffer))]) 197 | 198 | # Normalizing cumulative rewards 199 | mean = np.mean([buffer[i][-2] for i in range(len(buffer))]) 200 | std = np.std([buffer[i][-2] for i in range(len(buffer))]) + 1e-6 201 | for i in range(len(buffer)): 202 | buffer[i][-2] = (buffer[i][-2] - mean) / std 203 | 204 | return avg_rew 205 | 206 | 207 | def get_losses(model, batch, epsilon, annealing, device="cpu"): 208 | """Returns the three loss terms for a given model and a given batch and additional parameters""" 209 | # Getting old data 210 | n = len(batch) 211 | states = torch.cat([batch[i][0] for i in range(n)]) 212 | actions = torch.cat([batch[i][1] for i in range(n)]).view(n, 1) 213 | logits = torch.cat([batch[i][2] for i in range(n)]) 214 | values = torch.cat([batch[i][3] for i in range(n)]) 215 | cumulative_rewards = ( 216 | torch.tensor([batch[i][-2] for i in range(n)]).view(-1, 1).float().to(device) 217 | ) 218 | 219 | # Computing predictions with the new model 220 | _, new_logits, new_values = model(states) 221 | 222 | # Loss on the state-action-function / actor (L_CLIP) 223 | advantages = cumulative_rewards - values 224 | margin = epsilon * annealing 225 | ratios = new_logits.gather(1, actions) / logits.gather(1, actions) 226 | 227 | l_clip = torch.mean( 228 | torch.min( 229 | torch.cat( 230 | ( 231 | ratios * advantages, 232 | torch.clip(ratios, 1 - margin, 1 + margin) * advantages, 233 | ), 234 | dim=1, 235 | ), 236 | dim=1, 237 | ).values 238 | ) 239 | 240 | # Loss on the value-function / critic (L_VF) 241 | l_vf = torch.mean((cumulative_rewards - new_values) ** 2) 242 | 243 | # Bonus for entropy of the actor 244 | entropy_bonus = torch.mean( 245 | torch.sum(-new_logits * (torch.log(new_logits + 1e-5)), dim=1) 246 | ) 247 | 248 | return l_clip, l_vf, entropy_bonus 249 | 250 | 251 | def training_loop( 252 | env, 253 | model, 254 | max_iterations, 255 | n_actors, 256 | horizon, 257 | gamma, 258 | epsilon, 259 | n_epochs, 260 | batch_size, 261 | lr, 262 | c1, 263 | c2, 264 | device, 265 | env_name="", 266 | ): 267 | """Train the model on the given environment using multiple actors acting up to n timestamps.""" 268 | 269 | # Starting a new Weights & Biases run 270 | wandb.init( 271 | project="Papers Re-implementations", 272 | entity="peutlefaire", 273 | name=f"PPO - {env_name}", 274 | config={ 275 | "env": str(env), 276 | "number of actors": n_actors, 277 | "horizon": horizon, 278 | "gamma": gamma, 279 | "epsilon": epsilon, 280 | "epochs": n_epochs, 281 | "batch size": batch_size, 282 | "learning rate": lr, 283 | "c1": c1, 284 | "c2": c2, 285 | }, 286 | ) 287 | 288 | # Training variables 289 | max_reward = float("-inf") 290 | optimizer = Adam(model.parameters(), lr=lr, maximize=True) 291 | scheduler = LinearLR(optimizer, 1, 0, max_iterations * n_epochs) 292 | anneals = np.linspace(1, 0, max_iterations) 293 | 294 | # Training loop 295 | for iteration in range(max_iterations): 296 | buffer = [] 297 | annealing = anneals[iteration] 298 | 299 | # Collecting timestamps for all actors with the current policy 300 | for actor in range(1, n_actors + 1): 301 | 
buffer.extend(run_timestamps(env, model, horizon, False, device)) 302 | 303 | # Computing cumulative rewards and shuffling the buffer 304 | avg_rew = compute_cumulative_rewards(buffer, gamma) 305 | np.random.shuffle(buffer) 306 | 307 | # Running optimization for a few epochs 308 | for epoch in range(n_epochs): 309 | for batch_idx in range(len(buffer) // batch_size): 310 | # Getting batch for this buffer 311 | start = batch_size * batch_idx 312 | end = start + batch_size if start + batch_size < len(buffer) else -1 313 | batch = buffer[start:end] 314 | 315 | # Zero-ing optimizers gradients 316 | optimizer.zero_grad() 317 | 318 | # Getting the losses 319 | l_clip, l_vf, entropy_bonus = get_losses( 320 | model, batch, epsilon, annealing, device 321 | ) 322 | 323 | # Computing total loss and back-propagating it 324 | loss = l_clip - c1 * l_vf + c2 * entropy_bonus 325 | loss.backward() 326 | 327 | # Optimizing 328 | optimizer.step() 329 | scheduler.step() 330 | 331 | # Logging information to stdout 332 | curr_loss = loss.item() 333 | log = ( 334 | f"Iteration {iteration + 1} / {max_iterations}: " 335 | f"Average Reward: {avg_rew:.2f}\t" 336 | f"Loss: {curr_loss:.3f} " 337 | f"(L_CLIP: {l_clip.item():.1f} | L_VF: {l_vf.item():.1f} | L_bonus: {entropy_bonus.item():.1f})" 338 | ) 339 | if avg_rew > max_reward: 340 | torch.save(model.state_dict(), MODEL_PATH) 341 | max_reward = avg_rew 342 | log += " --> Stored model with highest average reward" 343 | print(log) 344 | 345 | # Logging information to W&B 346 | wandb.log( 347 | { 348 | "loss (total)": curr_loss, 349 | "loss (clip)": l_clip.item(), 350 | "loss (vf)": l_vf.item(), 351 | "loss (entropy bonus)": entropy_bonus.item(), 352 | "average reward": avg_rew, 353 | } 354 | ) 355 | 356 | # Finishing W&B session 357 | wandb.finish() 358 | 359 | 360 | def testing_loop(env, model, n_episodes, device): 361 | """Runs the learned policy on the environment for n episodes""" 362 | for _ in range(n_episodes): 363 | run_timestamps(env, model, timestamps=128, render=True, device=device) 364 | 365 | 366 | def main(): 367 | # Parsing program arguments 368 | args = parse_args() 369 | print(args) 370 | 371 | # Setting seed 372 | pl.seed_everything(args["seed"]) 373 | 374 | # Getting device 375 | device = get_device() 376 | 377 | # Creating environment (discrete action space) 378 | env_name = "CartPole-v1" 379 | env = gym.make(env_name) 380 | 381 | # Creating the model (both actor and critic) 382 | model = MyPPO(env.observation_space.shape, env.action_space.n).to(device) 383 | 384 | # Training 385 | training_loop( 386 | env, 387 | model, 388 | args["max_iterations"], 389 | args["n_actors"], 390 | args["horizon"], 391 | args["gamma"], 392 | args["epsilon"], 393 | args["n_epochs"], 394 | args["batch_size"], 395 | args["lr"], 396 | args["c1"], 397 | args["c2"], 398 | device, 399 | env_name, 400 | ) 401 | 402 | # Loading best model 403 | model = MyPPO(env.observation_space.shape, env.action_space.n).to(device) 404 | model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) 405 | 406 | # Testing 407 | env = gym.make(env_name, render_mode="human") 408 | testing_loop(env, model, args["n_test_episodes"], device) 409 | env.close() 410 | 411 | 412 | if __name__ == "__main__": 413 | main() 414 | --------------------------------------------------------------------------------