├── .github
├── CODEOWNERS
├── README.md
└── images
│ ├── ddpm_both.gif
│ ├── ddpm_fashion.gif
│ ├── ddpm_mnist.gif
│ ├── dqn_cartpole.gif
│ ├── gnns_training.png
│ ├── nf_generated_images.png
│ ├── ppo_cartpole.gif
│ └── vit_architecture.png
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
└── launch.json
├── LICENSE
├── data
└── nlp
│ ├── bert
│ └── masked_sentences.txt
│ ├── gpt
│ └── prompts.txt
│ └── original
│ └── translate_sentences.txt
├── environment.yml
├── requirements.txt
└── src
├── __init__.py
├── cv
├── __init__.py
├── ddpm
│ ├── __init__.py
│ ├── ddpm.py
│ ├── models.py
│ └── notebook
│ │ └── DDPM.ipynb
├── ign
│ ├── ign.py
│ ├── main.py
│ └── model.py
├── nf
│ ├── __init__.py
│ └── normalizing_flows.py
├── vir
│ ├── train.py
│ └── vir.py
└── vit
│ ├── __init__.py
│ └── vit_torch.py
├── fff
├── __init__.py
├── fff.py
└── main.py
├── gnns
├── __init__.py
└── gnns.py
├── nlp
├── __init__.py
├── bert
│ ├── README.md
│ ├── __init__.py
│ ├── data.py
│ ├── main.py
│ └── model.py
├── gpt
│ ├── README.md
│ ├── __init__.py
│ ├── main.py
│ └── model.py
├── layers
│ ├── __init__.py
│ ├── attention.py
│ ├── decoder.py
│ ├── embeddings.py
│ ├── encoder.py
│ └── mlp.py
├── lm_is_compression
│ ├── __init__.py
│ └── lm_is_compression.py
├── lm_watermarking
│ ├── main.py
│ ├── plot.py
│ └── watermarking.py
├── original
│ ├── __init__.py
│ ├── data.py
│ ├── main.py
│ └── model.py
└── tokenizers
│ ├── bpe.py
│ └── wordpiece.py
└── rl
├── dqn
├── dqn.py
└── main.py
└── ppo
├── PPO.ipynb
├── __init__.py
└── ppo.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @BrianPulfer
2 |
--------------------------------------------------------------------------------
/.github/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | Personal re-implementations of Machine Learning papers. Re-implementations might use different hyper-parameters / datasets / settings compared to the original paper.
4 |
5 | Current re-implementations include:
6 |
7 | | Paper | Code | Blog |
8 | | ----------- | ----------- | ----------- |
9 | | Natural language processing |
10 | |[A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226v2)|[Code](/src/nlp/lm_watermarking/)|
11 | | [Attention is all you need](https://arxiv.org/abs/1706.03762) | [Code](/src/nlp/original/) | *Coming soon*
12 | | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) | [Code](/src/nlp/bert/) | *Coming soon*
13 | | [Language Modeling Is Compression](https://arxiv.org/abs/2309.10668) | [Code](/src/nlp/lm_is_compression/) |
14 | | [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)| [Code](/src/nlp/gpt/) | *Coming soon*
15 | | Computer Vision |
16 | |[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)| [Code](/src/cv/vit/) | [Blog](https://www.brianpulfer.ch/blog/vit)
17 | |[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | [Code](/src/cv/ddpm/) | [Blog](https://www.brianpulfer.ch/blog/ddpm)
18 | | [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)| [Code](/src/cv/nf/) |
19 | | [Idempotent Generative Network](https://arxiv.org/abs/2311.01462)| [Code](/src/cv/ign/) | [Blog](https://brianpulfer.ch/blog/ign)
20 | |[ViR: Vision Retention Networks](https://arxiv.org/abs/2310.19731)|[Code](/src/cv/vir/)| [Blog](https://brianpulfer.ch/blog/vir)
21 | | Reinforcement Learning |
22 | |[Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)| [Code](/src/rl/ppo/) | [Blog](https://www.brianpulfer.ch/blog/ppo) |
23 | | [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) | [Code](/src/rl/dqn/) | *Coming soon* |
24 | | Others |
25 | |[Everything is Connected: Graph Neural Networks](https://arxiv.org/abs/2301.08210)| [Code](/src/gnns/) |
26 | |[Fast Feedforward Networks](https://arxiv.org/abs/2308.14711)| [Code](/src/fff/) |
27 |
28 |
29 | # Contributing
30 | While this repo is a personal attempt to familiarize myself with the ideas down to the nitty-gritty details, contributions are welcome for re-implementations that are already in the repository. In particular, I am open to discussing doubts, questions, suggestions for improving the code, and spotted mistakes / bugs. If you would like to contribute, simply raise an issue before submitting a pull request.
31 |
32 | # License
33 | The code is released with the [MIT license](/LICENSE).
34 |
--------------------------------------------------------------------------------
/.github/images/ddpm_both.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_both.gif
--------------------------------------------------------------------------------
/.github/images/ddpm_fashion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_fashion.gif
--------------------------------------------------------------------------------
/.github/images/ddpm_mnist.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ddpm_mnist.gif
--------------------------------------------------------------------------------
/.github/images/dqn_cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/dqn_cartpole.gif
--------------------------------------------------------------------------------
/.github/images/gnns_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/gnns_training.png
--------------------------------------------------------------------------------
/.github/images/nf_generated_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/nf_generated_images.png
--------------------------------------------------------------------------------
/.github/images/ppo_cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/ppo_cartpole.gif
--------------------------------------------------------------------------------
/.github/images/vit_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/.github/images/vit_architecture.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom folders and files
2 | .DS_Store
3 | /datasets
4 | *.pt
5 | wandb/
6 | checkpoints/
7 | lightning_logs/
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/#use-with-ide
118 | .pdm.toml
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 |
9 | - repo: https://github.com/psf/black
10 | rev: 23.3.0
11 | hooks:
12 | - id: black-jupyter
13 |
14 | - repo: https://github.com/pycqa/isort
15 | rev: 5.12.0
16 | hooks:
17 | - id: isort
18 | args: ["--profile", "black", "--filter-files"]
19 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "Python: Current File",
6 | "type": "python",
7 | "request": "launch",
8 | "program": "${file}",
9 | "console": "integratedTerminal",
10 | "env": {"PYTHONPATH": "${workspaceFolder}"},
11 | "justMyCode": false
12 | },
13 | {
14 | "name": "cv - ddpm",
15 | "type": "python",
16 | "request": "launch",
17 | "program": "${workspaceFolder}/src/cv/ddpm/ddpm.py",
18 | "console": "integratedTerminal",
19 | "env": {"PYTHONPATH": "${workspaceFolder}"},
20 | "justMyCode": true
21 | },
22 | {
23 | "name": "cv - ign",
24 | "type": "python",
25 | "request": "launch",
26 | "program": "${workspaceFolder}/src/cv/ign/main.py",
27 | "console": "integratedTerminal",
28 | "env": {"PYTHONPATH": "${workspaceFolder}"},
29 | "justMyCode": true
30 | },
31 | {
32 | "name": "cv - normalizing flows",
33 | "type": "python",
34 | "request": "launch",
35 | "program": "${workspaceFolder}/src/cv/nf/normalizing_flows.py",
36 | "console": "integratedTerminal",
37 | "env": {"PYTHONPATH": "${workspaceFolder}"},
38 | "justMyCode": true
39 | },
40 | {
41 | "name": "cv - vir",
42 | "type": "python",
43 | "request": "launch",
44 | "program": "${workspaceFolder}/src/cv/vir/train.py",
45 | "console": "integratedTerminal",
46 | "env": {"PYTHONPATH": "${workspaceFolder}"},
47 | "justMyCode": true
48 | },
49 | {
50 | "name": "cv - vit",
51 | "type": "python",
52 | "request": "launch",
53 | "program": "${workspaceFolder}/src/cv/vit/vit_torch.py",
54 | "console": "integratedTerminal",
55 | "env": {"PYTHONPATH": "${workspaceFolder}"},
56 | "justMyCode": true
57 | },
58 | {
59 | "name": "fff",
60 | "type": "python",
61 | "request": "launch",
62 | "program": "${workspaceFolder}/src/fff/main.py",
63 | "console": "integratedTerminal",
64 | "env": {"PYTHONPATH": "${workspaceFolder}"},
65 | "justMyCode": true
66 | },
67 | {
68 | "name": "gnns",
69 | "type": "python",
70 | "request": "launch",
71 | "program": "${workspaceFolder}/src/gnns/gnns.py",
72 | "console": "integratedTerminal",
73 | "env": {"PYTHONPATH": "${workspaceFolder}"},
74 | "justMyCode": true
75 | },
76 | {
77 | "name": "nlp - bert",
78 | "type": "python",
79 | "request": "launch",
80 | "program": "${workspaceFolder}/src/nlp/bert/main.py",
81 | "console": "integratedTerminal",
82 | "env": {"PYTHONPATH": "${workspaceFolder}"},
83 | "args": [
84 | "--max_train_steps", "80",
85 | "--warmup_steps", "30",
86 | ],
87 | "justMyCode": true
88 | },
89 | {
90 | "name": "nlp - gpt",
91 | "type": "python",
92 | "request": "launch",
93 | "program": "${workspaceFolder}/src/nlp/gpt/main.py",
94 | "console": "integratedTerminal",
95 | "env": {"PYTHONPATH": "${workspaceFolder}"},
96 | "args": [
97 | "--max_train_steps", "80",
98 | "--warmup_steps", "8",
99 | "--batch_size", "16",
100 | ],
101 | "justMyCode": true
102 | },
103 | {
104 | "name": "nlp - lm compression",
105 | "type": "python",
106 | "request": "launch",
107 | "program": "${workspaceFolder}/src/nlp/lm_is_compression/lm_is_compression.py",
108 | "console": "integratedTerminal",
109 | "env": {"PYTHONPATH": "${workspaceFolder}"},
110 | "justMyCode": true
111 | },
112 | {
113 | "name": "nlp - original",
114 | "type": "python",
115 | "request": "launch",
116 | "program": "${workspaceFolder}/src/nlp/original/main.py",
117 | "console": "integratedTerminal",
118 | "env": {"PYTHONPATH": "${workspaceFolder}"},
119 | "args": [
120 | "--max_train_steps", "80",
121 | "--warmup_steps", "8",
122 | "--batch_size", "4",
123 | ],
124 | "justMyCode": true
125 | },
126 | {
127 | "name": "nlp - watermark",
128 | "type": "python",
129 | "request": "launch",
130 | "program": "${workspaceFolder}/src/nlp/lm_watermarking/main.py",
131 | "console": "integratedTerminal",
132 | "env": {"PYTHONPATH": "${workspaceFolder}"},
133 | "justMyCode": true
134 | },
135 | {
136 | "name": "rl - dqn",
137 | "type": "python",
138 | "request": "launch",
139 | "program": "${workspaceFolder}/src/rl/dqn/main.py",
140 | "console": "integratedTerminal",
141 | "env": {"PYTHONPATH": "${workspaceFolder}"},
142 | "justMyCode": true
143 | },
144 | {
145 | "name": "rl - ppo",
146 | "type": "python",
147 | "request": "launch",
148 | "program": "${workspaceFolder}/src/rl/ppo/ppo.py",
149 | "console": "integratedTerminal",
150 | "env": {"PYTHONPATH": "${workspaceFolder}"},
151 | "justMyCode": true
152 | }
153 | ]
154 | }
155 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Peutlefaire
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data/nlp/bert/masked_sentences.txt:
--------------------------------------------------------------------------------
1 | Paris is the [MASK] of France.
2 | World War II started in [MASK] and ended in [MASK].
3 | The term AI stands for [MASK] Intelligence.
4 |
--------------------------------------------------------------------------------
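Note: below is a minimal sketch (not the repository's own `src/nlp/bert/main.py`) of how masked sentences like the ones above can be filled in with a pretrained BERT using the Hugging Face `transformers` fill-mask pipeline (`transformers` is listed in `requirements.txt`); the model name and file path are illustrative assumptions.

from transformers import pipeline

# Hypothetical illustration only: the repo trains and uses its own BERT implementation.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

with open("data/nlp/bert/masked_sentences.txt") as f:
    sentences = [line.strip() for line in f if line.strip()]

for sentence in sentences:
    predictions = fill_mask(sentence)
    # One [MASK] -> a list of candidate dicts; several [MASK]s -> one list per mask.
    if isinstance(predictions[0], list):
        best = [candidates[0]["token_str"] for candidates in predictions]
    else:
        best = [predictions[0]["token_str"]]
    print(sentence, "->", best)
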
/data/nlp/gpt/prompts.txt:
--------------------------------------------------------------------------------
1 | A dice has 6 sides, and
2 | Bern is the capital of
3 | Throughout a year there are four seasons:
4 |
--------------------------------------------------------------------------------
/data/nlp/original/translate_sentences.txt:
--------------------------------------------------------------------------------
1 | Der Stift liegt auf dem Tisch.
2 | Ich kenne kein Alter, keine Müdigkeit, keine Niederlage.
3 | Man weiß nicht, was man nicht weiß.
4 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: reimplementations
2 | channels:
3 | - nvidia
4 | - pytorch-nightly
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotlipy=0.7.0=py311h9bf148f_1002
11 | - bzip2=1.0.8=h7b6447c_0
12 | - ca-certificates=2023.08.22=h06a4308_0
13 | - certifi=2023.7.22=py311h06a4308_0
14 | - cffi=1.15.1=py311h9bf148f_3
15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
16 | - cryptography=38.0.4=py311h46ebde7_0
17 | - cuda-cudart=12.1.105=0
18 | - cuda-cupti=12.1.105=0
19 | - cuda-libraries=12.1.0=0
20 | - cuda-nvrtc=12.1.105=0
21 | - cuda-nvtx=12.1.105=0
22 | - cuda-opencl=12.2.140=0
23 | - cuda-runtime=12.1.0=0
24 | - ffmpeg=4.2.2=h20bf706_0
25 | - freetype=2.12.1=h4a9f257_0
26 | - giflib=5.2.1=h5eee18b_3
27 | - gmp=6.2.1=h295c915_3
28 | - gmpy2=2.1.2=py311hc9b5ff0_0
29 | - gnutls=3.6.15=he1e5248_0
30 | - idna=3.4=py311h06a4308_0
31 | - intel-openmp=2021.4.0=h06a4308_3561
32 | - jinja2=3.1.2=py311h06a4308_0
33 | - jpeg=9e=h5eee18b_1
34 | - lame=3.100=h7b6447c_0
35 | - lcms2=2.12=h3be6417_0
36 | - ld_impl_linux-64=2.38=h1181459_1
37 | - lerc=3.0=h295c915_0
38 | - libcublas=12.1.0.26=0
39 | - libcufft=11.0.2.4=0
40 | - libcufile=1.7.2.10=0
41 | - libcurand=10.3.3.141=0
42 | - libcusolver=11.4.4.55=0
43 | - libcusparse=12.0.2.55=0
44 | - libdeflate=1.17=h5eee18b_1
45 | - libffi=3.4.4=h6a678d5_0
46 | - libgcc-ng=11.2.0=h1234567_1
47 | - libgomp=11.2.0=h1234567_1
48 | - libidn2=2.3.4=h5eee18b_0
49 | - libjpeg-turbo=2.0.0=h9bf148f_0
50 | - libnpp=12.0.2.50=0
51 | - libnvjitlink=12.1.105=0
52 | - libnvjpeg=12.1.1.14=0
53 | - libopus=1.3.1=h7b6447c_0
54 | - libpng=1.6.39=h5eee18b_0
55 | - libstdcxx-ng=11.2.0=h1234567_1
56 | - libtasn1=4.19.0=h5eee18b_0
57 | - libtiff=4.5.1=h6a678d5_0
58 | - libunistring=0.9.10=h27cfd23_0
59 | - libuuid=1.41.5=h5eee18b_0
60 | - libvpx=1.7.0=h439df22_0
61 | - libwebp=1.2.4=h11a3e52_1
62 | - libwebp-base=1.2.4=h5eee18b_1
63 | - llvm-openmp=14.0.6=h9e868ea_0
64 | - lz4-c=1.9.4=h6a678d5_0
65 | - markupsafe=2.1.1=py311h5eee18b_0
66 | - mkl=2021.4.0=h06a4308_640
67 | - mkl-service=2.4.0=py311h9bf148f_0
68 | - mkl_fft=1.3.1=py311hc796f24_0
69 | - mkl_random=1.2.2=py311hbba84a0_0
70 | - mpc=1.1.0=h10f8cd9_1
71 | - mpfr=4.0.2=hb69a4c5_1
72 | - mpmath=1.2.1=py311_0
73 | - ncurses=6.4=h6a678d5_0
74 | - nettle=3.7.3=hbbd107a_1
75 | - networkx=3.1=py311h06a4308_0
76 | - numpy=1.24.3=py311hc206e33_0
77 | - numpy-base=1.24.3=py311hfd5febd_0
78 | - openh264=2.1.1=h4ff587b_0
79 | - openssl=3.0.11=h7f8727e_2
80 | - pillow=9.3.0=py311h3fd9d12_2
81 | - pip=23.2.1=py311h06a4308_0
82 | - pycparser=2.21=pyhd3eb1b0_0
83 | - pyopenssl=23.2.0=py311h06a4308_0
84 | - pysocks=1.7.1=py311_0
85 | - python=3.11.5=h955ad1f_0
86 | - pytorch=2.2.0.dev20231001=py3.11_cuda12.1_cudnn8.9.2_0
87 | - pytorch-cuda=12.1=ha16c6d3_5
88 | - pytorch-mutex=1.0=cuda
89 | - pyyaml=6.0=py311h5eee18b_1
90 | - readline=8.2=h5eee18b_0
91 | - requests=2.28.1=py311_0
92 | - six=1.16.0=pyhd3eb1b0_1
93 | - sqlite=3.41.2=h5eee18b_0
94 | - sympy=1.11.1=py311h06a4308_0
95 | - tk=8.6.12=h1ccaba5_0
96 | - torchaudio=2.2.0.dev20231001=py311_cu121
97 | - torchtriton=2.1.0+6e4932cda8=py311
98 | - torchvision=0.17.0.dev20231001=py311_cu121
99 | - typing_extensions=4.7.1=py311h06a4308_0
100 | - urllib3=1.26.14=py311_0
101 | - wheel=0.41.2=py311h06a4308_0
102 | - x264=1!157.20191217=h7b6447c_0
103 | - xz=5.4.2=h5eee18b_0
104 | - yaml=0.2.5=h7b6447c_0
105 | - zlib=1.2.13=h5eee18b_0
106 | - zstd=1.5.5=hc292b87_0
107 | - pip:
108 | - absl-py==2.0.0
109 | - accelerate==0.23.0
110 | - aiohttp==3.8.5
111 | - aiosignal==1.3.1
112 | - apache-beam==2.50.0
113 | - appdirs==1.4.4
114 | - async-timeout==4.0.3
115 | - attrs==23.1.0
116 | - black==23.9.1
117 | - blis==0.7.11
118 | - bpytop==1.0.68
119 | - cachetools==5.3.1
120 | - catalogue==2.0.10
121 | - cfgv==3.4.0
122 | - click==8.1.7
123 | - cloudpathlib==0.15.1
124 | - cloudpickle==2.2.1
125 | - confection==0.1.3
126 | - contourpy==1.1.1
127 | - crcmod==1.7
128 | - cycler==0.12.0
129 | - cymem==2.0.8
130 | - datasets==2.14.5
131 | - deepspeed==0.10.3
132 | - dill==0.3.1.1
133 | - distlib==0.3.7
134 | - dnspython==2.4.2
135 | - docker-pycreds==0.4.0
136 | - docopt==0.6.2
137 | - einops==0.7.0
138 | - fastavro==1.8.4
139 | - fasteners==0.19
140 | - filelock==3.12.4
141 | - flake8==6.1.0
142 | - fonttools==4.43.0
143 | - frozenlist==1.4.0
144 | - fsspec==2023.6.0
145 | - gitdb==4.0.10
146 | - gitpython==3.1.37
147 | - google-auth==2.23.2
148 | - google-auth-oauthlib==1.0.0
149 | - grpcio==1.59.0
150 | - gym==0.26.2
151 | - gym-notices==0.0.8
152 | - hdfs==2.7.2
153 | - hjson==3.1.0
154 | - httplib2==0.22.0
155 | - huggingface-hub==0.16.4
156 | - identify==2.5.30
157 | - imageio==2.31.5
158 | - importlib-metadata==6.8.0
159 | - isort==5.12.0
160 | - jedi==0.19.1
161 | - kiwisolver==1.4.5
162 | - langcodes==3.3.0
163 | - lightning-utilities==0.9.0
164 | - markdown==3.4.4
165 | - matplotlib==3.8.0
166 | - mccabe==0.7.0
167 | - multidict==6.0.4
168 | - multiprocess==0.70.15
169 | - murmurhash==1.0.10
170 | - mwparserfromhell==0.6.5
171 | - mypy-extensions==1.0.0
172 | - ninja==1.11.1
173 | - nodeenv==1.8.0
174 | - nvidia-ml-py==12.535.108
175 | - nvitop==1.3.0
176 | - oauthlib==3.2.2
177 | - objsize==0.6.1
178 | - orjson==3.9.7
179 | - packaging==23.2
180 | - pandas==2.1.1
181 | - parso==0.8.3
182 | - pathspec==0.11.2
183 | - pathtools==0.1.2
184 | - pathy==0.10.2
185 | - platformdirs==3.10.0
186 | - pre-commit==3.4.0
187 | - preshed==3.0.9
188 | - prompt-toolkit==3.0.39
189 | - proto-plus==1.22.3
190 | - protobuf==4.23.4
191 | - psutil==5.9.5
192 | - ptpython==3.0.23
193 | - py-cpuinfo==9.0.0
194 | - pyarrow==11.0.0
195 | - pyasn1==0.5.0
196 | - pyasn1-modules==0.3.0
197 | - pycodestyle==2.11.0
198 | - pydantic==1.10.13
199 | - pydot==1.4.2
200 | - pyflakes==3.1.0
201 | - pygments==2.16.1
202 | - pymongo==4.5.0
203 | - pyparsing==3.1.1
204 | - python-dateutil==2.8.2
205 | - python-dotenv==1.0.0
206 | - pytorch-lightning==2.0.9.post0
207 | - pytz==2023.3.post1
208 | - regex==2023.10.3
209 | - requests-oauthlib==1.3.1
210 | - rsa==4.9
211 | - safetensors==0.3.3
212 | - seaborn==0.13.0
213 | - sentencepiece==0.1.99
214 | - sentry-sdk==1.31.0
215 | - setproctitle==1.3.3
216 | - setuptools==68.2.2
217 | - smart-open==6.4.0
218 | - smmap==5.0.1
219 | - spacy==3.7.0
220 | - spacy-legacy==3.0.12
221 | - spacy-loggers==1.0.5
222 | - srsly==2.4.8
223 | - tensorboard==2.14.1
224 | - tensorboard-data-server==0.7.1
225 | - termcolor==2.3.0
226 | - thinc==8.2.1
227 | - tokenizers==0.14.0
228 | - tomli==2.0.1
229 | - torchmetrics==1.2.0
230 | - tqdm==4.66.1
231 | - transformers==4.34.0
232 | - typer==0.7.0
233 | - tzdata==2023.3
234 | - virtualenv==20.24.5
235 | - wandb==0.15.11
236 | - wasabi==1.1.2
237 | - wcwidth==0.2.8
238 | - weasel==0.3.1
239 | - werkzeug==3.0.0
240 | - xxhash==3.3.0
241 | - yapf==0.40.2
242 | - yarl==1.9.2
243 | - zipp==3.17.0
244 | - zstandard==0.21.0
245 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | apache_beam
3 | datasets
4 | deepspeed
5 | einops
6 | gymnasium
7 | gymnasium[Box2D]
8 | imageio
9 | matplotlib
10 | mwparserfromhell
11 | numpy
12 | pandas
13 | pillow
14 | pre-commit
15 | python-dotenv
16 | pytorch-lightning
17 | sacremoses
18 | seaborn
19 | spacy
20 | tensorboard
21 | torch
22 | torchvision
23 | tqdm
24 | transformers
25 | wandb
26 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/__init__.py
--------------------------------------------------------------------------------
/src/cv/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/__init__.py
--------------------------------------------------------------------------------
/src/cv/ddpm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/ddpm/__init__.py
--------------------------------------------------------------------------------
/src/cv/ddpm/ddpm.py:
--------------------------------------------------------------------------------
1 | # Import of libraries
2 | import random
3 | from argparse import ArgumentParser
4 |
5 | import einops
6 | import imageio
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from torch.optim import Adam
12 | from torch.utils.data import DataLoader
13 | from torchvision.datasets.mnist import MNIST, FashionMNIST
14 | from torchvision.transforms import Compose, Lambda, ToTensor
15 | from tqdm.auto import tqdm
16 |
17 | # Import of custom models
18 | from src.cv.ddpm.models import MyDDPM, MyUNet
19 |
20 | # Setting reproducibility
21 | SEED = 0
22 | random.seed(SEED)
23 | np.random.seed(SEED)
24 | torch.manual_seed(SEED)
25 |
26 | # Definitions
27 | STORE_PATH_MNIST = "ddpm_model_mnist.pt"
28 | STORE_PATH_FASHION = "ddpm_model_fashion.pt"
29 |
30 |
31 | def show_images(images, title=""):
32 | """Shows the provided images as sub-pictures in a square"""
33 |
34 | # Converting images to CPU numpy arrays
35 | if type(images) is torch.Tensor:
36 | images = images.detach().cpu().numpy()
37 |
38 | # Defining number of rows and columns
39 | fig = plt.figure(figsize=(8, 8))
40 | rows = int(len(images) ** (1 / 2))
41 | cols = round(len(images) / rows)
42 |
43 | # Populating figure with sub-plots
44 | idx = 0
45 | for r in range(rows):
46 | for c in range(cols):
47 | fig.add_subplot(rows, cols, idx + 1)
48 |
49 | if idx < len(images):
50 | plt.imshow(images[idx][0], cmap="gray")
51 | idx += 1
52 | fig.suptitle(title, fontsize=30)
53 |
54 | # Showing the figure
55 | plt.show()
56 |
57 |
58 | def show_first_batch(loader):
59 | for batch in loader:
60 | show_images(batch[0], "Images in the first batch")
61 | break
62 |
63 |
64 | def show_forward(ddpm, loader, device):
65 | # Showing the forward process
66 | for batch in loader:
67 | imgs = batch[0]
68 |
69 | show_images(imgs, "Original images")
70 |
71 | for percent in [0.25, 0.5, 0.75, 1]:
72 | show_images(
73 | ddpm(
74 | imgs.to(device),
75 | [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))],
76 | ),
77 | f"DDPM Noisy images {int(percent * 100)}%",
78 | )
79 | break
80 |
81 |
82 | def generate_new_images(
83 | ddpm,
84 | n_samples=16,
85 | device=None,
86 | frames_per_gif=100,
87 | gif_name="sampling.gif",
88 | c=1,
89 | h=28,
90 | w=28,
91 | ):
92 | """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples"""
93 | frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint)
94 | frames = []
95 |
96 | with torch.no_grad():
97 | if device is None:
98 | device = ddpm.device
99 |
100 | # Starting from random noise
101 | x = torch.randn(n_samples, c, h, w).to(device)
102 |
103 | for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]):
104 | # Estimating noise to be removed
105 | time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
106 | eta_theta = ddpm.backward(x, time_tensor)
107 |
108 | alpha_t = ddpm.alphas[t]
109 | alpha_t_bar = ddpm.alpha_bars[t]
110 |
111 | # Partially denoising the image
112 | x = (1 / alpha_t.sqrt()) * (
113 | x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta
114 | )
115 |
116 | if t > 0:
117 | z = torch.randn(n_samples, c, h, w).to(device)
118 |
119 | # Option 1: sigma_t squared = beta_t
120 | beta_t = ddpm.betas[t]
121 | sigma_t = beta_t.sqrt()
122 |
123 | # Option 2: sigma_t squared = beta_tilda_t
124 | # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0]
125 | # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t
126 | # sigma_t = beta_tilda_t.sqrt()
127 |
128 | # Adding some more noise like in Langevin Dynamics fashion
129 | x = x + sigma_t * z
130 |
131 | # Adding frames to the GIF
132 | if idx in frame_idxs or t == 0:
133 | # Putting digits in range [0, 255]
134 | normalized = x.clone()
135 | for i in range(len(normalized)):
136 | normalized[i] -= torch.min(normalized[i])
137 | normalized[i] *= 255 / torch.max(normalized[i])
138 |
139 | # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame
140 | frame = einops.rearrange(
141 | normalized,
142 | "(b1 b2) c h w -> (b1 h) (b2 w) c",
143 | b1=int(n_samples**0.5),
144 | )
145 | frame = frame.cpu().numpy().astype(np.uint8)
146 |
147 | # Rendering frame
148 | frames.append(frame)
149 |
150 | # Storing the gif
151 | with imageio.get_writer(gif_name, mode="I") as writer:
152 | for idx, frame in enumerate(frames):
153 | rgb_frame = np.repeat(frame, 3, axis=2)
154 | writer.append_data(rgb_frame)
155 |
156 | # Showing the last frame for a longer time
157 | if idx == len(frames) - 1:
158 | last_rgb_frame = np.repeat(frames[-1], 3, axis=2)
159 | for _ in range(frames_per_gif // 3):
160 | writer.append_data(last_rgb_frame)
161 | return x
162 |
163 |
164 | def training_loop(
165 | ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"
166 | ):
167 | mse = nn.MSELoss()
168 | best_loss = float("inf")
169 | n_steps = ddpm.n_steps
170 |
171 |     for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
172 | epoch_loss = 0.0
173 | for step, batch in enumerate(
174 | tqdm(
175 | loader,
176 | leave=False,
177 | desc=f"Epoch {epoch + 1}/{n_epochs}",
178 | colour="#005500",
179 | )
180 | ):
181 | # Loading data
182 | x0 = batch[0].to(device)
183 | n = len(x0)
184 |
185 | # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
186 | eta = torch.randn_like(x0).to(device)
187 | t = torch.randint(0, n_steps, (n,)).to(device)
188 |
189 | # Computing the noisy image based on x0 and the time-step (forward process)
190 | noisy_imgs = ddpm(x0, t, eta)
191 |
192 | # Getting model estimation of noise based on the images and the time-step
193 | eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))
194 |
195 | # Optimizing the MSE between the noise plugged and the predicted noise
196 | loss = mse(eta_theta, eta)
197 | optim.zero_grad()
198 | loss.backward()
199 | optim.step()
200 |
201 | epoch_loss += loss.item() * len(x0) / len(loader.dataset)
202 |
203 | # Display images generated at this epoch
204 | if display:
205 | show_images(
206 | generate_new_images(ddpm, device=device),
207 | f"Images generated at epoch {epoch + 1}",
208 | )
209 |
210 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"
211 |
212 | # Storing the model
213 | if best_loss > epoch_loss:
214 | best_loss = epoch_loss
215 | torch.save(ddpm.state_dict(), store_path)
216 | log_string += " --> Best model ever (stored)"
217 |
218 | print(log_string)
219 |
220 |
221 | def main():
222 | # Program arguments
223 | parser = ArgumentParser()
224 | parser.add_argument(
225 | "--no_train", action="store_true", help="Whether to train a new model or not"
226 | )
227 | parser.add_argument(
228 | "--fashion",
229 | action="store_true",
230 |         help="Uses Fashion MNIST if true, MNIST otherwise",
231 | )
232 | parser.add_argument("--bs", type=int, default=128, help="Batch size")
233 | parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
234 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
235 | args = vars(parser.parse_args())
236 | print(args)
237 |
238 | # Model store path
239 | store_path = "ddpm_fashion.pt" if args["fashion"] else "ddpm_mnist.pt"
240 |
241 | # Loading the data (converting each image into a tensor and normalizing between [-1, 1])
242 | transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
243 | ds_fn = MNIST if not args["fashion"] else FashionMNIST
244 | dataset = ds_fn("./../datasets", download=True, train=True, transform=transform)
245 | loader = DataLoader(dataset, args["bs"], shuffle=True)
246 |
247 | # Getting device
248 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
249 | print(
250 | f"Using device: {device}\t"
251 | + (f"{torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "CPU")
252 | )
253 |
254 | # Defining model
255 | n_steps, min_beta, max_beta = 1000, 10**-4, 0.02 # Originally used by the authors
256 | ddpm = MyDDPM(
257 | MyUNet(n_steps),
258 | n_steps=n_steps,
259 | min_beta=min_beta,
260 | max_beta=max_beta,
261 | device=device,
262 | )
263 |
264 | # Optionally, load a pre-trained model that will be further trained
265 | # ddpm.load_state_dict(torch.load(store_path, map_location=device))
266 |
267 | # Optionally, show a batch of regular images
268 | # show_first_batch(loader)
269 |
270 | # Optionally, show the diffusion (forward) process
271 | # show_forward(ddpm, loader, device)
272 |
273 | # Optionally, show the denoising (backward) process
274 | # generated = generate_new_images(ddpm, gif_name="before_training.gif")
275 | # show_images(generated, "Images generated before training")
276 |
277 | # Training
278 | if not args["no_train"]:
279 | n_epochs, lr = args["epochs"], args["lr"]
280 | training_loop(
281 | ddpm,
282 | loader,
283 | n_epochs,
284 | optim=Adam(ddpm.parameters(), lr),
285 | device=device,
286 | store_path=store_path,
287 | )
288 |
289 | # Loading the trained model
290 | best_model = MyDDPM(MyUNet(), n_steps=n_steps, device=device)
291 | best_model.load_state_dict(torch.load(store_path, map_location=device))
292 | best_model.eval()
293 | print("Model loaded: Generating new images")
294 |
295 | # Showing generated images
296 | generated = generate_new_images(
297 | best_model,
298 | n_samples=100,
299 | device=device,
300 | gif_name="fashion.gif" if args["fashion"] else "mnist.gif",
301 | )
302 | show_images(generated, "Final result")
303 |
304 |
305 | if __name__ == "__main__":
306 | main()
307 |
--------------------------------------------------------------------------------
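For reference, the two closed-form expressions that `ddpm.py` implements are the forward noising used in `training_loop` (via `MyDDPM.forward`) and the reverse sampling step used in `generate_new_images`, written here with the variance choice sigma_t^2 = beta_t ("Option 1" in the code):

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad \sigma_t = \sqrt{\beta_t}, \quad z \sim \mathcal{N}(0, I)$$

The MSE loss in `training_loop` is computed between the sampled noise epsilon and the network's estimate epsilon_theta of it.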
/src/cv/ddpm/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def sinusoidal_embedding(n, d):
6 | # Returns the standard positional embedding
7 | embedding = torch.zeros(n, d)
8 | wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
9 | wk = wk.reshape((1, d))
10 | t = torch.arange(n).reshape((n, 1))
11 | embedding[:, ::2] = torch.sin(t * wk[:, ::2])
12 | embedding[:, 1::2] = torch.cos(t * wk[:, ::2])
13 |
14 | return embedding
15 |
16 |
17 | # DDPM class
18 | class MyDDPM(nn.Module):
19 | def __init__(
20 | self,
21 | network,
22 | n_steps=200,
23 | min_beta=10**-4,
24 | max_beta=0.02,
25 | device=None,
26 | image_chw=(1, 28, 28),
27 | ):
28 | super(MyDDPM, self).__init__()
29 | self.n_steps = n_steps
30 | self.device = device
31 | self.image_chw = image_chw
32 | self.network = network.to(device)
33 | self.betas = torch.linspace(min_beta, max_beta, n_steps).to(
34 | device
35 | ) # Number of steps is typically in the order of thousands
36 | self.alphas = 1 - self.betas
37 | self.alpha_bars = torch.tensor(
38 | [torch.prod(self.alphas[: i + 1]) for i in range(len(self.alphas))]
39 | ).to(device)
40 |
41 | def forward(self, x0, t, eta=None):
42 | # Make input image more noisy (we can directly skip to the desired step)
43 | n, c, h, w = x0.shape
44 | a_bar = self.alpha_bars[t]
45 |
46 | if eta is None:
47 | eta = torch.randn(n, c, h, w).to(self.device)
48 |
49 | noisy = (
50 | a_bar.sqrt().reshape(n, 1, 1, 1) * x0
51 | + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
52 | )
53 | return noisy
54 |
55 | def backward(self, x, t):
56 | # Run each image through the network for each timestep t in the vector t.
57 | # The network returns its estimation of the noise that was added.
58 | return self.network(x, t)
59 |
60 |
61 | class MyBlock(nn.Module):
62 | def __init__(
63 | self,
64 | shape,
65 | in_c,
66 | out_c,
67 | kernel_size=3,
68 | stride=1,
69 | padding=1,
70 | activation=None,
71 | normalize=True,
72 | ):
73 | super(MyBlock, self).__init__()
74 | self.ln = nn.LayerNorm(shape)
75 | self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
76 | self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
77 | self.activation = nn.SiLU() if activation is None else activation
78 | self.normalize = normalize
79 |
80 | def forward(self, x):
81 | out = self.ln(x) if self.normalize else x
82 | out = self.conv1(out)
83 | out = self.activation(out)
84 | out = self.conv2(out)
85 | out = self.activation(out)
86 | return out
87 |
88 |
89 | class MyUNet(nn.Module):
90 | def __init__(self, n_steps=1000, time_emb_dim=100):
91 | super(MyUNet, self).__init__()
92 |
93 | # Sinusoidal embedding
94 | self.time_embed = nn.Embedding(n_steps, time_emb_dim)
95 | self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
96 | self.time_embed.requires_grad_(False)
97 |
98 | # First half
99 | self.te1 = self._make_te(time_emb_dim, 1)
100 | self.b1 = nn.Sequential(
101 | MyBlock((1, 28, 28), 1, 10),
102 | MyBlock((10, 28, 28), 10, 10),
103 | MyBlock((10, 28, 28), 10, 10),
104 | )
105 | self.down1 = nn.Conv2d(10, 10, 4, 2, 1)
106 |
107 | self.te2 = self._make_te(time_emb_dim, 10)
108 | self.b2 = nn.Sequential(
109 | MyBlock((10, 14, 14), 10, 20),
110 | MyBlock((20, 14, 14), 20, 20),
111 | MyBlock((20, 14, 14), 20, 20),
112 | )
113 | self.down2 = nn.Conv2d(20, 20, 4, 2, 1)
114 |
115 | self.te3 = self._make_te(time_emb_dim, 20)
116 | self.b3 = nn.Sequential(
117 | MyBlock((20, 7, 7), 20, 40),
118 | MyBlock((40, 7, 7), 40, 40),
119 | MyBlock((40, 7, 7), 40, 40),
120 | )
121 | self.down3 = nn.Sequential(
122 | nn.Conv2d(40, 40, 2, 1), nn.SiLU(), nn.Conv2d(40, 40, 4, 2, 1)
123 | )
124 |
125 | # Bottleneck
126 | self.te_mid = self._make_te(time_emb_dim, 40)
127 | self.b_mid = nn.Sequential(
128 | MyBlock((40, 3, 3), 40, 20),
129 | MyBlock((20, 3, 3), 20, 20),
130 | MyBlock((20, 3, 3), 20, 40),
131 | )
132 |
133 | # Second half
134 | self.up1 = nn.Sequential(
135 | nn.ConvTranspose2d(40, 40, 4, 2, 1),
136 | nn.SiLU(),
137 | nn.ConvTranspose2d(40, 40, 2, 1),
138 | )
139 |
140 | self.te4 = self._make_te(time_emb_dim, 80)
141 | self.b4 = nn.Sequential(
142 | MyBlock((80, 7, 7), 80, 40),
143 | MyBlock((40, 7, 7), 40, 20),
144 | MyBlock((20, 7, 7), 20, 20),
145 | )
146 |
147 | self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
148 | self.te5 = self._make_te(time_emb_dim, 40)
149 | self.b5 = nn.Sequential(
150 | MyBlock((40, 14, 14), 40, 20),
151 | MyBlock((20, 14, 14), 20, 10),
152 | MyBlock((10, 14, 14), 10, 10),
153 | )
154 |
155 | self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
156 | self.te_out = self._make_te(time_emb_dim, 20)
157 | self.b_out = nn.Sequential(
158 | MyBlock((20, 28, 28), 20, 10),
159 | MyBlock((10, 28, 28), 10, 10),
160 | MyBlock((10, 28, 28), 10, 10, normalize=False),
161 | )
162 |
163 | self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)
164 |
165 | def forward(self, x, t):
166 |         # x is (N, 1, 28, 28); the time embedding is added to the feature maps at each resolution
167 | t = self.time_embed(t)
168 | n = len(x)
169 | out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1)) # (N, 10, 28, 28)
170 | out2 = self.b2(
171 | self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1)
172 | ) # (N, 20, 14, 14)
173 | out3 = self.b3(
174 | self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1)
175 | ) # (N, 40, 7, 7)
176 |
177 | out_mid = self.b_mid(
178 | self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1)
179 | ) # (N, 40, 3, 3)
180 |
181 | out4 = torch.cat((out3, self.up1(out_mid)), dim=1) # (N, 80, 7, 7)
182 | out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1)) # (N, 20, 7, 7)
183 |
184 | out5 = torch.cat((out2, self.up2(out4)), dim=1) # (N, 40, 14, 14)
185 | out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1)) # (N, 10, 14, 14)
186 |
187 | out = torch.cat((out1, self.up3(out5)), dim=1) # (N, 20, 28, 28)
188 |         out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
189 |
190 | out = self.conv_out(out)
191 |
192 | return out
193 |
194 | def _make_te(self, dim_in, dim_out):
195 | return nn.Sequential(
196 | nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out)
197 | )
198 |
--------------------------------------------------------------------------------
/src/cv/ign/ign.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import pytorch_lightning as pl
4 | from torch.nn import L1Loss
5 | from torch.optim import Adam
6 |
7 |
8 | class IdempotentNetwork(pl.LightningModule):
9 | def __init__(
10 | self,
11 | prior,
12 | model,
13 | lr=1e-4,
14 | criterion=L1Loss(),
15 | lrec_w=20.0,
16 | lidem_w=20.0,
17 | ltight_w=2.5,
18 | ):
19 | super(IdempotentNetwork, self).__init__()
20 | self.prior = prior
21 | self.model = model
22 | self.model_copy = deepcopy(model)
23 | self.lr = lr
24 | self.criterion = criterion
25 | self.lrec_w = lrec_w
26 | self.lidem_w = lidem_w
27 | self.ltight_w = ltight_w
28 |
29 | def forward(self, x):
30 | return self.model(x)
31 |
32 | def configure_optimizers(self):
33 | optim = Adam(self.model.parameters(), lr=self.lr, betas=(0.5, 0.999))
34 | return optim
35 |
36 | def get_losses(self, x):
37 | # Prior samples
38 | z = self.prior.sample_n(x.shape[0]).to(x.device)
39 |
40 | # Updating the copy
41 | self.model_copy.load_state_dict(self.model.state_dict())
42 |
43 | # Forward passes
44 | fx = self(x)
45 | fz = self(z)
46 | fzd = fz.detach()
47 |
48 | l_rec = self.lrec_w * self.criterion(fx, x)
49 | l_idem = self.lidem_w * self.criterion(self.model_copy(fz), fz)
50 | l_tight = -self.ltight_w * self.criterion(self(fzd), fzd)
51 |
52 | return l_rec, l_idem, l_tight
53 |
54 | def training_step(self, batch, batch_idx):
55 | l_rec, l_idem, l_tight = self.get_losses(batch)
56 | loss = l_rec + l_idem + l_tight
57 |
58 | self.log_dict(
59 | {
60 | "train/loss_rec": l_rec,
61 | "train/loss_idem": l_idem,
62 | "train/loss_tight": l_tight,
63 | "train/loss": l_rec + l_idem + l_tight,
64 | },
65 | sync_dist=True,
66 | )
67 |
68 | return loss
69 |
70 | def validation_step(self, batch, batch_idx):
71 | l_rec, l_idem, l_tight = self.get_losses(batch)
72 | loss = l_rec + l_idem + l_tight
73 |
74 | self.log_dict(
75 | {
76 | "val/loss_rec": l_rec,
77 | "val/loss_idem": l_idem,
78 | "val/loss_tight": l_tight,
79 | "val/loss": loss,
80 | },
81 | sync_dist=True,
82 | )
83 |
84 | def test_step(self, batch, batch_idx):
85 | l_rec, l_idem, l_tight = self.get_losses(batch)
86 | loss = l_rec + l_idem + l_tight
87 |
88 | self.log_dict(
89 | {
90 | "test/loss_rec": l_rec,
91 | "test/loss_idem": l_idem,
92 | "test/loss_tight": l_tight,
93 | "test/loss": loss,
94 | },
95 | sync_dist=True,
96 | )
97 |
98 | def generate_n(self, n, device=None):
99 | z = self.prior.sample_n(n)
100 |
101 | if device is not None:
102 | z = z.to(device)
103 |
104 | return self(z)
105 |
--------------------------------------------------------------------------------
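A small sketch of the property the three losses above encourage, namely that applying the network twice should change a generated sample as little as possible (a hypothetical check, not part of the repository; with an untrained network the reported value will of course be large):

import torch

from src.cv.ign.ign import IdempotentNetwork
from src.cv.ign.model import DCGANLikeModel

prior = torch.distributions.Normal(torch.zeros(1, 28, 28), torch.ones(1, 28, 28))
model = IdempotentNetwork(prior, DCGANLikeModel()).eval()

with torch.no_grad():
    z = prior.sample((8,))   # 8 prior "images" of shape (1, 28, 28)
    fz = model(z)            # first application: generated samples
    ffz = model(fz)          # second application should be close to the first
    print(torch.nn.functional.l1_loss(ffz, fz).item())
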
/src/cv/ign/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.loggers import WandbLogger
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import MNIST
10 | from torchvision.transforms import Compose, Lambda, ToTensor
11 | from torchvision.utils import save_image
12 |
13 | from src.cv.ign.ign import IdempotentNetwork
14 | from src.cv.ign.model import DCGANLikeModel
15 |
16 |
17 | def main(args):
18 | # Set seed
19 | pl.seed_everything(args["seed"])
20 |
21 | # Load datas
22 | normalize = Lambda(lambda x: (x - 0.5) * 2)
23 | noise = Lambda(lambda x: (x + torch.randn_like(x) * 0.15).clamp(-1, 1))
24 | train_transform = Compose([ToTensor(), normalize, noise])
25 | val_transform = Compose([ToTensor(), normalize])
26 |
27 | train_set = MNIST(
28 | root="data/mnist", train=True, download=True, transform=train_transform
29 | )
30 | val_set = MNIST(
31 | root="data/mnist", train=False, download=True, transform=val_transform
32 | )
33 |
34 | def collate_fn(samples):
35 | return torch.stack([sample[0] for sample in samples])
36 |
37 | train_loader = DataLoader(
38 | train_set,
39 | batch_size=args["batch_size"],
40 | shuffle=True,
41 | collate_fn=collate_fn,
42 | num_workers=args["num_workers"],
43 | )
44 | val_loader = DataLoader(
45 | val_set,
46 | batch_size=args["batch_size"],
47 | shuffle=False,
48 | collate_fn=collate_fn,
49 | num_workers=args["num_workers"],
50 | )
51 |
52 | # Initialize model
53 | prior = torch.distributions.Normal(torch.zeros(1, 28, 28), torch.ones(1, 28, 28))
54 | net = DCGANLikeModel()
55 | model = IdempotentNetwork(prior, net, args["lr"])
56 |
57 | if not args["skip_train"]:
58 | # Train model
59 | logger = WandbLogger(name="IGN", project="Papers Re-implementations")
60 | callbacks = [
61 | ModelCheckpoint(
62 | monitor="val/loss",
63 | mode="min",
64 | dirpath="checkpoints/ign",
65 | filename="best",
66 | )
67 | ]
68 | trainer = pl.Trainer(
69 | strategy="ddp",
70 | accelerator="auto",
71 | max_epochs=args["epochs"],
72 | logger=logger,
73 | callbacks=callbacks,
74 | )
75 | trainer.fit(model, train_loader, val_loader)
76 |
77 | # Loading the best model
78 | device = "cuda" if torch.cuda.is_available() else "cpu"
79 | model = (
80 | IdempotentNetwork.load_from_checkpoint(
81 | "checkpoints/ign/best.ckpt", prior=prior, model=net
82 | )
83 | .eval()
84 | .to(device)
85 | )
86 |
87 | # Generating images with the trained model
88 | os.makedirs("generated", exist_ok=True)
89 |
90 | images = model.generate_n(100, device=device)
91 | save_image(images, "generated.png", nrow=10, normalize=True)
92 |
93 | print("Done!")
94 |
95 |
96 | if __name__ == "__main__":
97 | parser = ArgumentParser()
98 | parser.add_argument("--seed", type=int, default=0)
99 | parser.add_argument("--lr", type=float, default=1e-4)
100 | parser.add_argument("--batch_size", type=int, default=256)
101 | parser.add_argument("--epochs", type=int, default=50)
102 | parser.add_argument("--num_workers", type=int, default=8)
103 | parser.add_argument("--skip_train", action="store_true")
104 | args = vars(parser.parse_args())
105 |
106 | print("\n\n", args, "\n\n")
107 | main(args)
108 |
--------------------------------------------------------------------------------
/src/cv/ign/model.py:
--------------------------------------------------------------------------------
1 | """DCGAN code from https://github.com/kpandey008/dcgan"""
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | def __init__(self, in_channels=1, base_c=64):
7 | super(Discriminator, self).__init__()
8 | self.main = nn.Sequential(
9 | # Input Size: 1 x 28 x 28
10 | nn.Conv2d(in_channels, base_c, 4, 2, 1, bias=False),
11 | nn.Dropout2d(0.1),
12 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
13 | # Input Size: 32 x 14 x 14
14 | nn.BatchNorm2d(base_c),
15 | nn.Conv2d(base_c, base_c * 2, 4, 2, 1, bias=False),
16 | nn.Dropout2d(0.1),
17 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
18 | # Input Size: 64 x 7 x 7
19 | nn.BatchNorm2d(base_c * 2),
20 | nn.Conv2d(base_c * 2, base_c * 4, 3, 1, 0, bias=False),
21 | nn.Dropout2d(0.1),
22 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
23 | # Input Size: 128 x 7 x 7
24 | nn.BatchNorm2d(base_c * 4),
25 | nn.Conv2d(base_c * 4, base_c * 8, 3, 1, 0, bias=False),
26 | nn.Dropout2d(0.1),
27 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
28 | # Input Size: 256 x 7 x 7
29 | nn.Conv2d(base_c * 8, base_c * 8, 3, 1, 0, bias=False),
30 | )
31 |
32 | def forward(self, input):
33 | return self.main(input)
34 |
35 |
36 | class Generator(nn.Module):
37 | def __init__(self, in_channels=512, out_channels=1):
38 | super(Generator, self).__init__()
39 | self.main = nn.Sequential(
40 | # Input Size: 256 x 7 x 7
41 | nn.BatchNorm2d(in_channels),
42 | nn.ConvTranspose2d(in_channels, in_channels // 2, 3, 1, 0, bias=False),
43 | nn.Dropout2d(0.1),
44 | nn.ReLU(True),
45 | # Input Size: 128 x 7 x 7
46 | nn.BatchNorm2d(in_channels // 2),
47 | nn.ConvTranspose2d(in_channels // 2, in_channels // 4, 3, 1, 0, bias=False),
48 | nn.Dropout2d(0.1),
49 | nn.ReLU(True),
50 | # Input Size: 64 x 7 x 7
51 | nn.BatchNorm2d(in_channels // 4),
52 | nn.ConvTranspose2d(in_channels // 4, in_channels // 8, 3, 1, 0, bias=False),
53 | nn.Dropout2d(0.1),
54 | nn.ReLU(True),
55 | # Input Size: 32 x 14 x 14
56 | nn.BatchNorm2d(in_channels // 8),
57 | nn.ConvTranspose2d(
58 | in_channels // 8, in_channels // 16, 4, 2, 1, bias=False
59 | ),
60 | nn.Dropout2d(0.1),
61 | nn.ReLU(True),
62 | # Input Size : 16 x 28 x 28
63 | nn.ConvTranspose2d(in_channels // 16, out_channels, 4, 2, 1, bias=False),
64 | nn.Tanh(),
65 | # Final Output : 1 x 28 x 28
66 | )
67 |
68 | def forward(self, input):
69 | return self.main(input)
70 |
71 |
72 | class DCGANLikeModel(nn.Module):
73 | def __init__(self, in_channels=1, base_c=64):
74 | super(DCGANLikeModel, self).__init__()
75 | self.discriminator = Discriminator(in_channels=in_channels, base_c=base_c)
76 | self.generator = Generator(base_c * 8, out_channels=in_channels)
77 |
78 | def forward(self, x):
79 | return self.generator(self.discriminator(x))
80 |
--------------------------------------------------------------------------------
/src/cv/nf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/nf/__init__.py
--------------------------------------------------------------------------------
/src/cv/nf/normalizing_flows.py:
--------------------------------------------------------------------------------
1 | """
2 | Personal reimplementation of
3 | Density estimation using Real NVP
4 | (https://arxiv.org/abs/1605.08803)
5 |
6 | Useful links:
7 | - https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html
8 | """
9 |
10 | import os
11 | from argparse import ArgumentParser
12 |
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | from PIL import Image
18 | from torch.optim import Adam
19 | from torch.optim.lr_scheduler import StepLR
20 | from torch.utils.data import DataLoader
21 | from torchvision.datasets import MNIST
22 | from torchvision.transforms import Compose, Lambda, ToTensor
23 | from torchvision.utils import save_image
24 | from tqdm.auto import tqdm
25 |
26 | # Seeding
27 | SEED = 17
28 | np.random.seed(SEED)
29 | torch.random.manual_seed(SEED)
30 | torch.use_deterministic_algorithms(True)
31 | torch.backends.cudnn.deterministic = True
32 |
33 |
34 | def test_reversability(model, x):
35 | """Tests that x ≈ model.backward(model.forward(x)) and shows images"""
36 | with torch.no_grad():
37 | # Running input forward and backward
38 | z = model.forward(x)[0]
39 | x_tilda = model.backward(z)[0]
40 |
41 | # Printing MSE
42 | mse = ((x_tilda - x) ** 2).mean()
43 | print(f"MSE between input and reconstruction: {mse}")
44 |
45 | # Comparing images visually
46 | plt.imshow(x[0][0].cpu().numpy(), cmap="gray")
47 | plt.title("Original image")
48 | plt.show()
49 |
50 | plt.imshow(z[0][0].cpu().numpy(), cmap="gray")
51 | plt.title("After forward pass")
52 | plt.show()
53 |
54 | plt.imshow(x_tilda[0][0].cpu().numpy(), cmap="gray")
55 | plt.title("Reconstructed image")
56 | plt.show()
57 |
58 |
59 | class LayerNormChannels(nn.Module):
60 | def __init__(self, c_in, eps=1e-5):
61 | super().__init__()
62 | self.gamma = nn.Parameter(torch.ones(1, c_in, 1, 1))
63 | self.beta = nn.Parameter(torch.zeros(1, c_in, 1, 1))
64 | self.eps = eps
65 |
66 | def forward(self, x):
67 | mean = x.mean(dim=1, keepdim=True)
68 | var = x.var(dim=1, unbiased=False, keepdim=True)
69 | y = (x - mean) / torch.sqrt(var + self.eps)
70 | y = y * self.gamma + self.beta
71 | return y
72 |
73 |
74 | class CNNBlock(nn.Module):
75 |     """A simple CNN architecture which will be applied at each Affine Coupling step"""
76 |
77 | def __init__(self, n_channels, kernel_size=3):
78 | super(CNNBlock, self).__init__()
79 | self.elu = nn.ELU()
80 |
81 | self.conv1 = nn.Conv2d(
82 | 2 * n_channels, n_channels, kernel_size, 1, kernel_size // 2
83 | )
84 | self.conv2 = nn.Conv2d(2 * n_channels, 2 * n_channels, 1, 1)
85 |
86 | def forward(self, x):
87 | out = torch.cat((self.elu(x), self.elu(-x)), dim=1)
88 | out = self.conv1(out)
89 | out = torch.cat((self.elu(out), self.elu(-out)), dim=1)
90 | out = self.conv2(out)
91 | val, gate = out.chunk(2, 1)
92 | return x + val * torch.sigmoid(gate)
93 |
94 |
95 | class SimpleCNN(nn.Module):
96 | def __init__(self, blocks=3, channels_in=1, channels_hidden=32, kernel_size=3):
97 | super(SimpleCNN, self).__init__()
98 |
99 | self.elu = nn.ELU()
100 | self.conv_in = nn.Conv2d(channels_in, channels_hidden, 3, 1, 1)
101 | self.net = nn.Sequential(
102 | *[
103 | nn.Sequential(
104 | CNNBlock(channels_hidden, kernel_size),
105 | LayerNormChannels(channels_hidden),
106 | )
107 | for _ in range(blocks)
108 | ]
109 | )
110 | self.conv_out = nn.Conv2d(2 * channels_hidden, 2 * channels_in, 3, 1, 1)
111 |
112 | # Initializing final convolution weights to zeros
113 | self.conv_out.weight.data.zero_()
114 | self.conv_out.bias.data.zero_()
115 |
116 | def forward(self, x):
117 | out = self.net(self.conv_in(x))
118 | out = torch.cat((self.elu(out), self.elu(-out)), dim=1)
119 | return self.conv_out(out)
120 |
121 |
122 | class Dequantization(nn.Module):
123 | """Dequantizes the image. Dequantization is the first step for flows, as it allows to not load datapoints
124 | with high likelihoods and put volume on other input data as well."""
125 |
126 | def __init__(self, max_val):
127 | super(Dequantization, self).__init__()
128 | self.eps = 1e-5
129 | self.max_val = max_val
130 | self.sigmoid_fn = nn.Sigmoid()
131 |
132 | def sigmoid(self, x):
133 | return self.sigmoid_fn(x)
134 |
135 | def log_det_sigmoid(self, x):
136 | s = self.sigmoid(x)
137 | return torch.log(s - s**2)
138 |
139 | def inv_sigmoid(self, x):
140 | return -torch.log((x) ** -1 - 1)
141 |
142 | def log_det_inv_sigmoid(self, x):
143 | return torch.log(1 / (x - x**2))
144 |
145 | def forward(self, x):
146 |         # Dequantizing input (adding continuous noise in [0, 1)) and rescaling to range [0, 1]
147 | x = x.to(torch.float32)
148 | log_det = (
149 | -np.log(self.max_val)
150 | * np.prod(x.shape[1:])
151 | * torch.ones(len(x)).to(x.device)
152 | )
153 | out = (x + torch.rand_like(x).detach()) / self.max_val
154 |
155 | # Making sure the input is not too close to either 0 or 1 (bounds of inverse sigmoid) --> put closer to 0.5
156 | log_det += np.log(1 - self.eps) * np.prod(x.shape[1:])
157 | out = (1 - self.eps) * out + self.eps * 0.5
158 |
159 | # Running the input through the inverse sigmoid function
160 | log_det += self.log_det_inv_sigmoid(out).sum(dim=[1, 2, 3])
161 | out = self.inv_sigmoid(out)
162 |
163 | return out, log_det
164 |
165 | def backward(self, x):
166 | # Running through the Sigmoid function
167 | log_det = self.log_det_sigmoid(x).sum(dim=[1, 2, 3])
168 | out = self.sigmoid(x)
169 |
170 | # Undoing the weighted sum
171 | log_det -= np.log(1 - self.eps) * np.prod(x.shape[1:])
172 | out = (out - self.eps * 0.5) / (1 - self.eps)
173 |
174 | # Undoing the dequantization
175 | log_det += np.log(self.max_val) * np.prod(x.shape[1:])
176 | out *= self.max_val
177 | out = torch.floor(out).clamp(min=0, max=self.max_val)
178 |
179 | return out, log_det
180 |
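# Illustrative sketch (added note, not part of the original module): how Dequantization is meant to
# round-trip. Assuming 8-bit inputs (max_val=256), forward() adds uniform noise, rescales and applies
# the inverse sigmoid; backward() undoes these steps and re-quantizes with floor(), so the original
# pixels are recovered (up to rare rounding edge cases):
#
#     deq = Dequantization(max_val=256)
#     x = torch.randint(0, 256, (1, 1, 4, 4))
#     z, log_det_fwd = deq(x)               # continuous, unbounded representation
#     x_rec, log_det_bwd = deq.backward(z)
#     assert torch.allclose(x.float(), x_rec)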
181 |
182 | class AffineCoupling(nn.Module):
183 | """Affine Coupling layer. Only modifies half of the input by running the other half through some non-linear function."""
184 |
185 | def __init__(self, m: nn.Module, modify_x2=True, chw=(1, 28, 28)):
186 | super(AffineCoupling, self).__init__()
187 | self.m = m
188 | self.modify_x2 = modify_x2
189 |
190 | c, h, w = chw
191 | self.scaling_fac = nn.Parameter(torch.ones(c))
192 | self.mask = torch.tensor(
193 | [[(j + k) % 2 == 0 for k in range(w)] for j in range(h)]
194 | )
195 | self.mask = self.mask.unsqueeze(0).unsqueeze(0)
196 |
197 | if self.modify_x2:
198 | self.mask = ~self.mask
199 |
200 | def forward(self, x):
201 | # Splitting input in two halves
202 | mask = self.mask.to(x.device)
203 | x1 = mask * x
204 |
205 | # Computing scale and shift for x2
206 | scale, shift = self.m(x1).chunk(2, 1) # Non linear network
207 | s_fac = self.scaling_fac.exp().view(1, -1, 1, 1)
208 | scale = torch.tanh(scale / s_fac) * s_fac # Stabilizes training
209 |
210 | # Masking scale and shift
211 | scale = ~mask * scale
212 | shift = ~mask * shift
213 |
214 | # Computing output
215 | out = (x + shift) * torch.exp(scale)
216 |
217 | # Computing log of the determinant of the Jacobian
218 | log_det_j = torch.sum(scale, dim=[1, 2, 3])
219 |
220 | return out, log_det_j
221 |
222 | def backward(self, y):
223 | # Splitting input
224 | mask = self.mask.to(y.device)
225 |
226 | x1 = mask * y
227 |
228 | # Computing scale and shift
229 | scale, shift = self.m(x1).chunk(2, 1)
230 | s_fac = self.scaling_fac.exp().view(1, -1, 1, 1)
231 | scale = torch.tanh(scale / s_fac) * s_fac
232 |
233 | # Masking scale and shift
234 | scale = ~mask * scale
235 | shift = ~mask * shift
236 |
237 | # Computing inverse transformation
238 | out = y / torch.exp(scale) - shift
239 |
240 |         # Computing log of the determinant of the Jacobian (for backward transformation)
241 | log_det_j = -torch.sum(scale, dim=[1, 2, 3])
242 |
243 | return out, log_det_j
244 |
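# For reference (illustrative note matching the implementation above): with x1 the half selected by the
# mask (left unchanged) and x2 the half being transformed, and (s, t) = m(x1) after the tanh rescaling
# and masking above, this layer computes
#
#     forward:  y2 = (x2 + t) * exp(s),   log|det J| =  sum(s)
#     backward: x2 = y2 * exp(-s) - t,    log|det J| = -sum(s)
#
# (a slight variant of the usual RealNVP form y2 = x2 * exp(s) + t; both are trivially invertible).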
245 |
246 | class Flow(nn.Module):
247 | """General Flow model. Uses invertible layers to map distributions."""
248 |
249 | def __init__(self, layers):
250 | super(Flow, self).__init__()
251 | self.layers = nn.ModuleList(layers)
252 |
253 | def forward(self, x):
254 | # Computing forward pass (images --> gaussian noise)
255 | out, log_det_j = x, 0
256 | for layer in self.layers:
257 | out, log_det_j_layer = layer(out)
258 | log_det_j += log_det_j_layer
259 |
260 | return out, log_det_j
261 |
262 | def backward(self, y):
263 | # Sampling with backward pass (gaussian noise --> images)
264 | out, log_det_j = y, 0
265 | for layer in self.layers[::-1]:
266 | out, log_det_j_layer = layer.backward(out)
267 | log_det_j += log_det_j_layer
268 |
269 | return out, log_det_j
270 |
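# Illustrative note (added for clarity): training maximizes the change-of-variables likelihood
#
#     log p(x) = log p(z) + sum_k log|det J_k|,   with z = f(x)
#
# and the loss below is reported in bits per dimension: bpd = -log p(x) * log2(e) / (C * H * W),
# which is what the `to_bpd` constant implements for 1x28x28 MNIST images.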
271 |
272 | def training_loop(model, epochs, lr, loader, device, dir):
273 | """Trains the model"""
274 |
275 | model.train()
276 | best_loss = float("inf")
277 | optim = Adam(model.parameters(), lr=lr)
278 | scheduler = StepLR(optimizer=optim, step_size=1, gamma=0.99)
279 | to_bpd = np.log2(np.exp(1)) / (
280 | 28 * 28 * 1
281 | ) # Constant that normalizes w.r.t. input shape
282 |
283 | prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)
284 |
285 | for epoch in tqdm(range(epochs), desc="Training progress", colour="#00ff00"):
286 | epoch_loss = 0.0
287 | for batch in tqdm(
288 | loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500"
289 | ):
290 |             # Getting a batch of images (dequantization is applied by the model's first layer)
291 | x = batch[0].to(device)
292 |
293 | # Running images forward and getting log likelihood (log_px)
294 | z, log_det_j = model(x)
295 | # log_pz = -np.log(np.sqrt(2*np.pi)) -(z**2).sum(dim=[1,2,3]) # Because we are mapping to a normal N(0, 1)
296 | log_pz = prior.log_prob(z).sum(dim=[1, 2, 3])
297 | log_px = log_pz + log_det_j
298 |
299 | # Getting the loss to be optimized (scaling with bits per dimension)
300 | loss = (-(log_px * to_bpd)).mean()
301 |
302 | # Optimization step
303 | optim.zero_grad()
304 | loss.backward()
305 | torch.nn.utils.clip_grad_norm_(
306 | model.parameters(), 1
307 | ) # Clipping gradient norm
308 | optim.step()
309 |
310 | # Logging variable
311 | epoch_loss += loss.item() / len(loader)
312 |
313 | # Stepping with the LR scheduler
314 | scheduler.step()
315 |
316 | # Logging epoch result and storing best model
317 | log_str = f"Epoch {epoch + 1}/{epochs} loss: {epoch_loss:.3f}"
318 | if best_loss > epoch_loss:
319 | best_loss = epoch_loss
320 | log_str += " --> Storing model"
321 | torch.save(model.state_dict(), os.path.join(dir, "nf_model.pt"))
322 | print(log_str)
323 |
324 |
325 | def main():
326 | # Program arguments
327 | parser = ArgumentParser()
328 | parser.add_argument(
329 | "--epochs", type=int, default=500, help="Number of training epochs"
330 | )
331 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
332 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
333 | parser.add_argument("--gpu", type=int, default=0, help="GPU number")
334 | parser.add_argument(
335 | "--store_dir", type=str, default=os.getcwd(), help="Store directory"
336 | )
337 | args = vars(parser.parse_args())
338 |
339 | N_EPOCHS = args["epochs"]
340 | LR = args["lr"]
341 | BATCH_SIZE = args["batch_size"]
342 | GPU = args["gpu"]
343 | DIR = args["store_dir"]
344 |
345 |     # Loading data (pixel values are converted to integers in range [0, 255])
346 | transform = Compose([ToTensor(), Lambda(lambda x: (255 * x).to(torch.int32))])
347 | dataset = MNIST(
348 | root="./../datasets", train=True, download=True, transform=transform
349 | )
350 | loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
351 |
352 | # Device
353 | device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")
354 | device_log = f"Using device: {device} " + (
355 | f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else ""
356 | )
357 | print(device_log)
358 |
359 | # Creating the model
360 | model = Flow(
361 | [
362 | Dequantization(256),
363 | *[AffineCoupling(SimpleCNN(), modify_x2=i % 2 == 0) for i in range(30)],
364 | ]
365 | ).to(device)
366 |
367 |     # Showing number of trainable params
368 | trainable_params = 0
369 | for param in model.parameters():
370 | trainable_params += np.prod(param.shape) if param.requires_grad else 0
371 | print(f"The model has {trainable_params} trainable parameters.")
372 |
373 | # Loading pre-trained model (if any)
374 | sd_path = os.path.join(DIR, "nf_model.pt")
375 | pretrained_exists = os.path.isfile(sd_path)
376 | if pretrained_exists:
377 | model.load_state_dict(torch.load(sd_path, map_location=device))
378 | print("Pre-trained model found and loaded")
379 |
380 |     # Testing reversibility with the first image in the dataset
381 |     test_reversibility(model, dataset[0][0].unsqueeze(0).to(device))
382 |
383 |     # Training loop (only if the model doesn't already exist)
384 | if not pretrained_exists:
385 | training_loop(model, N_EPOCHS, LR, loader, device, DIR)
386 | sd_path = os.path.join(DIR, "nf_model.pt")
387 | model.load_state_dict(torch.load(sd_path, map_location=device))
388 |
389 | # Testing the trained model
390 | model.eval()
391 | with torch.no_grad():
392 | # Mapping the normally distributed noise to new images
393 | noise = torch.randn(64, 1, 28, 28).to(device)
394 | images = model.backward(noise)[0]
395 |
396 | save_image(images.float(), "Generated digits.png")
397 | Image.open("Generated digits.png").show()
398 |
399 | # Showing new latent mapping of first image in the dataset
400 |     test_reversibility(model, dataset[0][0].unsqueeze(0).to(device))
401 |
402 |
403 | if __name__ == "__main__":
404 | main()
405 |
--------------------------------------------------------------------------------
/src/cv/vir/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from datasets import load_dataset
6 | from torch.utils.data import DataLoader
7 | from torchvision.transforms import (
8 | Compose,
9 | Normalize,
10 | RandomHorizontalFlip,
11 | RandomRotation,
12 | Resize,
13 | ToTensor,
14 | )
15 |
16 | from src.cv.vir.vir import ViR, ViRModes
17 |
18 |
19 | class ViRLightningModule(pl.LightningModule):
20 | def __init__(
21 | self,
22 | lr=1e-3,
23 | out_dim=10,
24 | patch_size=14,
25 | depth=12,
26 | heads=12,
27 | embed_dim=768,
28 | max_len=257,
29 | alpha=1.0,
30 | mode=ViRModes.PARALLEL,
31 | dropout=0.1,
32 | ):
33 | super(ViRLightningModule, self).__init__()
34 | self.lr = lr
35 | self.model = ViR(
36 | out_dim,
37 | patch_size,
38 | depth,
39 | heads,
40 | embed_dim,
41 | max_len,
42 | alpha,
43 | mode,
44 | dropout,
45 | )
46 |
47 | def forward(self, x):
48 | return self.model(x)
49 |
50 | def training_step(self, batch, batch_idx):
51 | x, y = batch["image"], batch["label"]
52 | y_hat = self(x)
53 | acc = (y_hat.argmax(dim=1) == y).float().mean()
54 | loss = torch.nn.functional.cross_entropy(y_hat, y)
55 | self.log_dict(
56 | {
57 | "train_loss": loss,
58 | "train_acc": acc,
59 | }
60 | )
61 | return loss
62 |
63 | def validation_step(self, batch, batch_idx):
64 | x, y = batch["image"], batch["label"]
65 | y_hat = self(x)
66 | acc = (y_hat.argmax(dim=1) == y).float().mean()
67 | loss = torch.nn.functional.cross_entropy(y_hat, y)
68 | self.log_dict(
69 | {
70 | "validation_loss": loss,
71 | "validation_acc": acc,
72 | }
73 | )
74 | return loss
75 |
76 | def test_step(self, batch, batch_idx):
77 | x, y = batch["image"], batch["label"]
78 | y_hat = self(x)
79 | acc = (y_hat.argmax(dim=1) == y).float().mean()
80 | loss = torch.nn.functional.cross_entropy(y_hat, y)
81 | self.log_dict(
82 | {
83 | "test_loss": loss,
84 | "test_acc": acc,
85 | }
86 | )
87 | return loss
88 |
89 | def configure_optimizers(self):
90 | optim = torch.optim.Adam(self.trainer.model.parameters(), self.lr)
91 | return optim
92 |
93 |
94 | def main(args):
95 | # Seed everything
96 | pl.seed_everything(args["seed"])
97 |
98 | # Data
99 | resize = Resize((args["image_size"], args["image_size"]))
100 | normalize = Normalize([0.5] * 3, [0.5] * 3)
101 | train_transform = Compose(
102 | [resize, ToTensor(), normalize, RandomHorizontalFlip(), RandomRotation(5)]
103 | )
104 | val_transform = Compose([resize, ToTensor(), normalize])
105 |
106 | def make_transform(fn):
107 | def transform(samples):
108 | samples["image"] = [fn(img.convert("RGB")) for img in samples["image"]]
109 | samples["label"] = torch.tensor(samples["label"])
110 | return samples
111 |
112 | return transform
113 |
114 | train_set = load_dataset("frgfm/imagenette", "320px", split="train")
115 | val_set = load_dataset("frgfm/imagenette", "320px", split="validation")
116 |
117 | train_set.set_transform(make_transform(train_transform))
118 | val_set.set_transform(make_transform(val_transform))
119 |
120 | train_loader = DataLoader(
121 | train_set,
122 | batch_size=args["batch_size"],
123 | shuffle=True,
124 | num_workers=args["num_workers"],
125 | )
126 | val_loader = DataLoader(
127 | val_set,
128 | batch_size=args["batch_size"],
129 | shuffle=False,
130 | num_workers=args["num_workers"],
131 | )
132 |
133 | # Load model
134 | model = ViRLightningModule(
135 | args["lr"],
136 | out_dim=10,
137 | patch_size=args["patch_size"],
138 | depth=args["depth"],
139 | heads=args["heads"],
140 | embed_dim=args["embed_dim"],
141 | max_len=(args["image_size"] // args["patch_size"]) ** 2 + 1,
142 | alpha=args["alpha"],
143 | mode=ViRModes.PARALLEL,
144 | dropout=args["dropout"],
145 | )
146 |
147 | # Train model
148 | logger = pl.loggers.WandbLogger(project="Papers Re-implementations", name="ViR")
149 | logger.experiment.config.update(args)
150 | trainer = pl.Trainer(
151 | strategy="ddp",
152 | accelerator="auto",
153 | max_epochs=args["epochs"],
154 | logger=logger,
155 | callbacks=[
156 | pl.callbacks.ModelCheckpoint(
157 | dirpath=args["checkpoint_dir"],
158 | filename="vir-model",
159 | save_top_k=3,
160 | monitor="train_loss",
161 | mode="min",
162 | )
163 | ],
164 | )
165 | trainer.fit(model, train_loader)
166 |
167 | # Evaluate model (setting to recurrent mode)
168 | model.model.set_compute_mode(ViRModes.RECURRENT)
169 | trainer.test(model, val_loader)
170 |
171 |
172 | if __name__ == "__main__":
173 | parser = ArgumentParser()
174 |
175 | # Training arguments
176 | parser.add_argument("--seed", help="Random seed", type=int, default=0)
177 | parser.add_argument(
178 | "--checkpoint_dir", help="Checkpoint directory", type=str, default="checkpoints"
179 | )
180 | parser.add_argument("--epochs", help="Number of epochs", type=int, default=40)
181 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-3)
182 | parser.add_argument("--batch_size", help="Batch size", type=int, default=64)
183 | parser.add_argument("--image_size", help="Image size", type=int, default=224)
184 | parser.add_argument("--num_workers", help="Number of workers", type=int, default=4)
185 |
186 | # Model arguments
187 | parser.add_argument("--patch_size", help="Patch size", type=int, default=14)
188 | parser.add_argument("--depth", help="Depth", type=int, default=12)
189 | parser.add_argument("--heads", help="Heads", type=int, default=3)
190 | parser.add_argument(
191 | "--embed_dim", help="Embedding dimension", type=int, default=192
192 | )
193 | parser.add_argument("--alpha", help="Alpha", type=float, default=0.99)
194 | parser.add_argument("--dropout", help="Dropout", type=float, default=0.1)
195 |
196 | args = vars(parser.parse_args())
197 |
198 | print("\n\nProgram arguments:\n\n", args, "\n\n")
199 | main(args)
200 |
--------------------------------------------------------------------------------
/src/cv/vir/vir.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | DEFAULT_ALPHA = 1.00
5 |
6 |
7 | class ViRModes:
8 | PARALLEL = "parallel"
9 | RECURRENT = "recurrent"
10 | CHUNKWISE = "chunkwise"
11 |
12 |
13 | class Retention(nn.Module):
14 | def __init__(
15 | self,
16 | embed_dim,
17 | max_len,
18 | alpha=DEFAULT_ALPHA,
19 | mode=ViRModes.PARALLEL,
20 | chunk_size=20,
21 | ):
22 | super(Retention, self).__init__()
23 | self.dim = embed_dim
24 | self.max_len = max_len
25 | self.chunk_size = chunk_size
26 | self.alpha = alpha
27 | self.mode = mode
28 |
29 | # Useful buffers
30 | self.register_buffer("dim_sqrt", torch.tensor(embed_dim**0.5))
31 | self.register_buffer(
32 | "decay_mask",
33 | torch.tensor(
34 | [[alpha ** (i - j) for j in range(max_len)] for i in range(max_len)]
35 | ),
36 | )
37 | self.register_buffer("causal_mask", torch.ones(max_len, max_len).tril())
38 | self.register_buffer(
39 | "retention_mask_chunkwise",
40 | torch.tensor(
41 | [self.alpha ** (chunk_size - i - 1) for i in range(chunk_size)]
42 | ),
43 | )
44 |
45 | self.register_buffer(
46 | "cross_mask_chunkwise",
47 | torch.tensor([self.alpha ** (i + 1) for i in range(chunk_size)]),
48 | )
49 | self.qkv = nn.Linear(embed_dim, embed_dim * 3)
50 |
51 | def forward_parallel(self, x):
52 | # Getting queries, keys, values
53 | bs, sl, d = x.shape
54 | qkv = self.qkv(x)
55 | q, k, v = torch.chunk(qkv, 3, dim=-1)
56 |
57 | # Causal and decay masking
58 | M = (self.causal_mask[:sl, :sl] * self.decay_mask[:sl, :sl]).repeat(bs, 1, 1)
59 |
60 | # Retention
61 | out = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v
62 |
63 | return out
64 |
65 | def forward_recurrent(self, x, state):
66 | batch_size, length, dim = x.shape
67 |
68 | all_outputs = []
69 | state = torch.zeros(batch_size, dim, dim).to(x.device)
70 | for i in range(length):
71 | xi = x[:, i]
72 | q, k, v = self.qkv(xi).chunk(3, dim=-1)
73 |
74 | state = self.alpha * state + k.unsqueeze(-1) @ v.unsqueeze(1)
75 | out = q.unsqueeze(1) @ state / self.dim_sqrt
76 |             all_outputs.append(out.squeeze(1))  # squeeze the length dim only, keeping the batch dim
77 |
78 | x = torch.stack(all_outputs, dim=1)
79 | return x
80 |
81 | def forward_chunkwise(self, x, chunk_size=None):
82 | # Getting queries, keys, values
83 | if chunk_size is None:
84 | chunk_size = self.chunk_size
85 |
86 | bs, sl, d = x.shape
87 |
88 | # Adding dummy tokens to make the sequence length divisible by chunk_size
89 | if sl % chunk_size != 0:
90 | x = torch.cat(
91 | [x, torch.zeros(bs, chunk_size - sl % chunk_size, d).to(x.device)],
92 | dim=1,
93 | )
94 | n_chunks = x.shape[1] // chunk_size
95 |
96 | # Running all chunks in parallel
97 | x = x.reshape(bs, n_chunks, chunk_size, d)
98 | q, k, v = self.qkv(x).chunk(3, dim=-1)
99 |
100 | M = (
101 | self.causal_mask[:chunk_size, :chunk_size]
102 | * self.decay_mask[:chunk_size, :chunk_size]
103 | ).repeat(bs, n_chunks, 1, 1)
104 |
105 | inner_chunk = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v
106 |
107 | # Updating outputs with chunk-wise recurrent
108 | retention_mask = self.retention_mask_chunkwise.repeat(bs, d, 1).transpose(
109 | -1, -2
110 | )
111 | cross_mask = self.cross_mask_chunkwise.repeat(bs, n_chunks, d, 1).transpose(
112 | -1, -2
113 | )
114 |
115 | states = torch.zeros(bs, n_chunks, d, d).to(x.device)
116 | for i in range(1, n_chunks):
117 | chunk_state = k[:, i - 1].transpose(-1, -2) @ (v[:, i - 1] * retention_mask)
118 | states[:, i] = chunk_state + states[:, i - 1] * self.alpha**chunk_size
119 |
120 | cross_chunk = (q @ states) / self.dim_sqrt * cross_mask
121 |
122 | # Combining inner and cross chunk
123 | out = inner_chunk + cross_chunk
124 |
125 | # Removing dummy tokens
126 | out = out.flatten(1, 2)[:, :sl]
127 | return out
128 |
129 |     def forward(self, x, state=None, mode=None, chunk_size=None):
130 | if mode is None:
131 | mode = self.mode
132 |
133 | if mode == ViRModes.PARALLEL:
134 | return self.forward_parallel(x)
135 | elif mode == ViRModes.RECURRENT:
136 | return self.forward_recurrent(x, state)
137 | elif mode == ViRModes.CHUNKWISE:
138 | return self.forward_chunkwise(x, chunk_size)
139 | else:
140 | raise ValueError(f"Unknown mode {mode}")
141 |
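# Illustrative note (added for clarity, matching the code above): the three modes compute the same
# retention operation. With decay alpha, the per-token recurrent form is
#
#     S_n   = alpha * S_{n-1} + k_n^T v_n
#     out_n = q_n S_n / sqrt(d)
#
# which unrolls into the parallel form (Q K^T / sqrt(d) * D) V with the causal decay mask
# D_ij = alpha^(i-j) for i >= j (and 0 otherwise). The chunk-wise mode mixes both: it runs the
# parallel form inside fixed-size chunks and carries a recurrent state across chunks.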
142 |
143 | class MultiHeadRetention(nn.Module):
144 | def __init__(
145 | self,
146 | heads,
147 | embed_dim,
148 | max_len,
149 | alpha=DEFAULT_ALPHA,
150 | mode=ViRModes.PARALLEL,
151 | chunk_size=20,
152 | ):
153 | super(MultiHeadRetention, self).__init__()
154 | self.n_heads = heads
155 | self.embed_dim = embed_dim
156 | self.head_dim = embed_dim // heads
157 | self.mode = mode
158 | self.chunk_size = chunk_size
159 |
160 | assert (
161 | embed_dim % heads == 0
162 | ), "Embedding dimension must be divisible by the number of heads"
163 |
164 | self.heads = nn.ModuleList(
165 | [
166 |                 Retention(embed_dim // heads, max_len, alpha, mode, chunk_size)
167 | for _ in range(heads)
168 | ]
169 | )
170 | self.ln = nn.LayerNorm(embed_dim)
171 | self.gelu = nn.GELU()
172 | self.linear = nn.Linear(embed_dim, embed_dim)
173 |
174 | def forward(self, x, mode=None, chunk_size=None):
175 | if mode is None:
176 | mode = self.mode
177 |
178 | if chunk_size is None:
179 | chunk_size = self.chunk_size
180 |
181 | out = torch.cat(
182 | [
183 | head(
184 | x[:, :, i * self.head_dim : (i + 1) * self.head_dim],
185 | mode=mode,
186 | chunk_size=chunk_size,
187 | )
188 | for i, head in enumerate(self.heads)
189 | ],
190 | dim=-1,
191 | )
192 | return self.linear(self.gelu(self.ln(out)))
193 |
194 |
195 | class MLP(nn.Module):
196 | def __init__(self, embed_dim, hidden_dim=None):
197 | super(MLP, self).__init__()
198 |
199 | if hidden_dim is None:
200 | hidden_dim = 4 * embed_dim
201 |
202 | self.linear1 = nn.Linear(embed_dim, hidden_dim)
203 | self.linear2 = nn.Linear(hidden_dim, embed_dim)
204 | self.gelu = nn.GELU()
205 |
206 | def forward(self, x):
207 | return self.linear2(self.gelu(self.linear1(x)))
208 |
209 |
210 | class ViRBlock(nn.Module):
211 | def __init__(
212 | self,
213 | heads,
214 | embed_dim,
215 | max_len,
216 | alpha=DEFAULT_ALPHA,
217 | mode=ViRModes.PARALLEL,
218 | chunk_size=20,
219 | dropout=0.1,
220 | ):
221 | super(ViRBlock, self).__init__()
222 | self.mode = mode
223 | self.chunk_size = chunk_size
224 |
225 | self.ln1 = nn.LayerNorm(embed_dim)
226 | self.retention = MultiHeadRetention(
227 | heads, embed_dim, max_len, alpha, mode, chunk_size
228 | )
229 | self.ln2 = nn.LayerNorm(embed_dim)
230 | self.mlp = MLP(embed_dim)
231 | self.dropout1 = nn.Dropout(dropout)
232 | self.dropout2 = nn.Dropout(dropout)
233 |
234 | def forward(self, x, mode=None, chunk_size=None):
235 | if mode is None:
236 | mode = self.mode
237 |
238 | if chunk_size is None:
239 | chunk_size = self.chunk_size
240 |
241 | x = (
242 | self.dropout1(self.retention(self.ln1(x), mode=mode, chunk_size=chunk_size))
243 | + x
244 | )
245 | x = self.dropout2(self.mlp(self.ln2(x))) + x
246 | return x
247 |
248 |
249 | class ViR(nn.Module):
250 | def __init__(
251 | self,
252 | out_dim=10,
253 | patch_size=14,
254 | depth=12,
255 | heads=12,
256 | embed_dim=768,
257 | max_len=257,
258 | alpha=DEFAULT_ALPHA,
259 | mode=ViRModes.PARALLEL,
260 | chunk_size=20,
261 | dropout=0.1,
262 | ):
263 | super(ViR, self).__init__()
264 |
265 | # Local parameters
266 |         self.out_dim = out_dim
267 | self.patch_size = patch_size
268 | self.depth = depth
269 | self.heads = heads
270 | self.embed_dim = embed_dim
271 | self.max_len = max_len
272 | self.alpha = alpha
273 | self.mode = mode
274 | self.chunk_size = chunk_size
275 |
276 | # Embeddings
277 | self.patch_embed = nn.Conv2d(
278 | 3, embed_dim, (patch_size, patch_size), stride=(patch_size, patch_size)
279 | )
280 | self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))
281 | self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
282 |
283 | # ViR blocks
284 | self.blocks = nn.ModuleList(
285 | [
286 | ViRBlock(heads, embed_dim, max_len, alpha, mode, chunk_size, dropout)
287 | for _ in range(depth)
288 | ]
289 | )
290 |
291 | # Head
292 | self.ln = nn.LayerNorm(embed_dim)
293 | self.linear = nn.Linear(embed_dim, out_dim)
294 |
295 | def set_compute_mode(self, mode):
296 | self.mode = mode
297 |
298 | def forward(self, x, mode=None, chunk_size=None):
299 | if mode is None:
300 | mode = self.mode
301 |
302 | if chunk_size is None:
303 | chunk_size = self.chunk_size
304 |
305 | # Patch embedding, positional embedding, CLS token
306 | x = self.patch_embed(x).permute(0, 2, 3, 1).flatten(1, 2)
307 | bs, sl = x.shape[:2]
308 | x = x + self.pos_embed.repeat(bs, 1, 1)[:, :sl]
309 | x = torch.cat(
310 | (x, self.class_token.repeat(bs, 1, 1)), dim=1
311 | ) # Important: CLS token goes last
312 |
313 | # Blocks
314 | for block in self.blocks:
315 | x = block(x, mode=mode, chunk_size=chunk_size)
316 |
317 | # Head on the CLS token
318 | x = self.linear(self.ln(x[:, -1]))
319 |
320 | return x
321 |
322 |
323 | if __name__ == "__main__":
324 | """Tests that parallel and recurrent modes give the same output for ViR"""
325 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326 | x = torch.randn(16, 3, 224, 224).to(device)
327 | model = ViR(depth=12, heads=3, embed_dim=192).eval().to(device)
328 |
329 | with torch.no_grad():
330 | model.set_compute_mode(ViRModes.CHUNKWISE)
331 | chunk_size = 20
332 | y3 = model(x, chunk_size=chunk_size)
333 |
--------------------------------------------------------------------------------
/src/cv/vit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/cv/vit/__init__.py
--------------------------------------------------------------------------------
/src/cv/vit/vit_torch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import CrossEntropyLoss
5 | from torch.optim import Adam
6 | from torch.utils.data import DataLoader
7 | from torchvision.datasets.mnist import MNIST
8 | from torchvision.transforms import ToTensor
9 | from tqdm import tqdm, trange
10 |
11 | np.random.seed(0)
12 | torch.manual_seed(0)
13 |
14 |
15 | def patchify(images, n_patches):
16 | n, c, h, w = images.shape
17 |
18 | assert h == w, "Patchify method is implemented for square images only"
19 |
20 | patches = torch.zeros(n, n_patches**2, h * w * c // n_patches**2)
21 | patch_size = h // n_patches
22 |
23 | for idx, image in enumerate(images):
24 | for i in range(n_patches):
25 | for j in range(n_patches):
26 | patch = image[
27 | :,
28 | i * patch_size : (i + 1) * patch_size,
29 | j * patch_size : (j + 1) * patch_size,
30 | ]
31 | patches[idx, i * n_patches + j] = patch.flatten()
32 | return patches
33 |
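# Illustrative example (added note): for MNIST images of shape (N, 1, 28, 28) and n_patches=7,
# patchify returns a tensor of shape (N, 49, 16), i.e. 7x7 = 49 patches of 4x4 = 16 pixels each.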
34 |
35 | class MyMSA(nn.Module):
36 | def __init__(self, d, n_heads=2):
37 | super(MyMSA, self).__init__()
38 | self.d = d
39 | self.n_heads = n_heads
40 |
41 | assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
42 |
43 | d_head = int(d / n_heads)
44 | self.q_mappings = nn.ModuleList(
45 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
46 | )
47 | self.k_mappings = nn.ModuleList(
48 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
49 | )
50 | self.v_mappings = nn.ModuleList(
51 | [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
52 | )
53 | self.d_head = d_head
54 | self.softmax = nn.Softmax(dim=-1)
55 |
56 | def forward(self, sequences):
57 | # Sequences has shape (N, seq_length, token_dim)
58 | # We go into shape (N, seq_length, n_heads, token_dim / n_heads)
59 | # And come back to (N, seq_length, item_dim) (through concatenation)
60 | result = []
61 | for sequence in sequences:
62 | seq_result = []
63 | for head in range(self.n_heads):
64 | q_mapping = self.q_mappings[head]
65 | k_mapping = self.k_mappings[head]
66 | v_mapping = self.v_mappings[head]
67 |
68 | seq = sequence[:, head * self.d_head : (head + 1) * self.d_head]
69 | q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
70 |
71 | attention = self.softmax(q @ k.T / (self.d_head**0.5))
72 | seq_result.append(attention @ v)
73 | result.append(torch.hstack(seq_result))
74 | return torch.cat([torch.unsqueeze(r, dim=0) for r in result])
75 |
76 |
77 | class MyViTBlock(nn.Module):
78 | def __init__(self, hidden_d, n_heads, mlp_ratio=4):
79 | super(MyViTBlock, self).__init__()
80 | self.hidden_d = hidden_d
81 | self.n_heads = n_heads
82 |
83 | self.norm1 = nn.LayerNorm(hidden_d)
84 | self.mhsa = MyMSA(hidden_d, n_heads)
85 | self.norm2 = nn.LayerNorm(hidden_d)
86 | self.mlp = nn.Sequential(
87 | nn.Linear(hidden_d, mlp_ratio * hidden_d),
88 | nn.GELU(),
89 | nn.Linear(mlp_ratio * hidden_d, hidden_d),
90 | )
91 |
92 | def forward(self, x):
93 | out = x + self.mhsa(self.norm1(x))
94 | out = out + self.mlp(self.norm2(out))
95 | return out
96 |
97 |
98 | class MyViT(nn.Module):
99 | def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
100 | # Super constructor
101 | super(MyViT, self).__init__()
102 |
103 | # Attributes
104 | self.chw = chw # ( C , H , W )
105 | self.n_patches = n_patches
106 | self.n_blocks = n_blocks
107 | self.n_heads = n_heads
108 | self.hidden_d = hidden_d
109 |
110 | # Input and patches sizes
111 | assert (
112 | chw[1] % n_patches == 0
113 | ), "Input shape not entirely divisible by number of patches"
114 | assert (
115 | chw[2] % n_patches == 0
116 | ), "Input shape not entirely divisible by number of patches"
117 | self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)
118 |
119 | # 1) Linear mapper
120 | self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
121 | self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
122 |
123 | # 2) Learnable classification token
124 | self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
125 |
126 | # 3) Positional embedding
127 | self.register_buffer(
128 | "positional_embeddings",
129 | get_positional_embeddings(n_patches**2 + 1, hidden_d),
130 | persistent=False,
131 | )
132 |
133 | # 4) Transformer encoder blocks
134 | self.blocks = nn.ModuleList(
135 | [MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)]
136 | )
137 |
138 |         # 5) Classification MLP
139 | self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))
140 |
141 | def forward(self, images):
142 | # Dividing images into patches
143 | n, c, h, w = images.shape
144 | patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)
145 |
146 | # Running linear layer tokenization
147 | # Map the vector corresponding to each patch to the hidden size dimension
148 | tokens = self.linear_mapper(patches)
149 |
150 | # Adding classification token to the tokens
151 | tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
152 |
153 | # Adding positional embedding
154 | out = tokens + self.positional_embeddings.repeat(n, 1, 1)
155 |
156 | # Transformer Blocks
157 | for block in self.blocks:
158 | out = block(out)
159 |
160 | # Getting the classification token only
161 | out = out[:, 0]
162 |
163 | return self.mlp(out) # Map to output dimension, output category distribution
164 |
165 |
166 | def get_positional_embeddings(sequence_length, d):
167 | result = torch.ones(sequence_length, d)
168 | for i in range(sequence_length):
169 | for j in range(d):
170 | result[i][j] = (
171 | np.sin(i / (10000 ** (j / d)))
172 | if j % 2 == 0
173 | else np.cos(i / (10000 ** ((j - 1) / d)))
174 | )
175 | return result
176 |
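# For reference (added note): the loop above implements the sinusoidal positional embeddings of
# "Attention is All You Need":
#
#     PE(i, 2k)   = sin(i / 10000^(2k / d))
#     PE(i, 2k+1) = cos(i / 10000^(2k / d))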
177 |
178 | def main():
179 | # Loading data
180 | transform = ToTensor()
181 |
182 | train_set = MNIST(
183 | root="./../datasets", train=True, download=True, transform=transform
184 | )
185 | test_set = MNIST(
186 | root="./../datasets", train=False, download=True, transform=transform
187 | )
188 |
189 | train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
190 | test_loader = DataLoader(test_set, shuffle=False, batch_size=128)
191 |
192 | # Defining model and training options
193 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
194 | print(
195 | "Using device: ",
196 | device,
197 | f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
198 | )
199 | model = MyViT(
200 | (1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10
201 | ).to(device)
202 | N_EPOCHS = 5
203 | LR = 0.005
204 |
205 | # Training loop
206 | optimizer = Adam(model.parameters(), lr=LR)
207 | criterion = CrossEntropyLoss()
208 | for epoch in trange(N_EPOCHS, desc="Training"):
209 | train_loss = 0.0
210 | for batch in tqdm(
211 | train_loader, desc=f"Epoch {epoch + 1} in training", leave=False
212 | ):
213 | x, y = batch
214 | x, y = x.to(device), y.to(device)
215 | y_hat = model(x)
216 | loss = criterion(y_hat, y)
217 |
218 | train_loss += loss.detach().cpu().item() / len(train_loader)
219 |
220 | optimizer.zero_grad()
221 | loss.backward()
222 | optimizer.step()
223 |
224 | print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
225 |
226 | # Test loop
227 | with torch.no_grad():
228 | correct, total = 0, 0
229 | test_loss = 0.0
230 | for batch in tqdm(test_loader, desc="Testing"):
231 | x, y = batch
232 | x, y = x.to(device), y.to(device)
233 | y_hat = model(x)
234 | loss = criterion(y_hat, y)
235 | test_loss += loss.detach().cpu().item() / len(test_loader)
236 |
237 | correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
238 | total += len(x)
239 | print(f"Test loss: {test_loss:.2f}")
240 | print(f"Test accuracy: {correct / total * 100:.2f}%")
241 |
242 |
243 | if __name__ == "__main__":
244 | main()
245 |
--------------------------------------------------------------------------------
/src/fff/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/fff/__init__.py
--------------------------------------------------------------------------------
/src/fff/fff.py:
--------------------------------------------------------------------------------
1 | """
2 | Unofficial re-implementation of
3 |
4 | 'Fast Feedforward Networks'
5 |
6 | by Peter Belcak and Roger Wattenhofer
7 |
8 | ArXiv: https://arxiv.org/abs/2308.14711
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class FFFMLP(nn.Module):
16 | def __init__(
17 | self,
18 | in_dim: int,
19 | hidden_dim: int,
20 | out_dim: int,
21 | hidden_activation: nn.Module = nn.ReLU(),
22 | out_activation: nn.Module = nn.Sigmoid(),
23 | swap_prob: float = 0.1,
24 | ):
25 | super(FFFMLP, self).__init__()
26 | self.in_dim = in_dim
27 | self.hidden_dim = hidden_dim
28 | self.out_dim = out_dim
29 | self.swap_prob = swap_prob
30 |
31 | self.linear1 = nn.Linear(in_dim, hidden_dim)
32 | self.hidden_activation = hidden_activation
33 | self.linear2 = nn.Linear(hidden_dim, out_dim)
34 | self.activation = out_activation
35 |
36 | def forward(self, x: torch.Tensor):
37 | out = self.linear2(self.hidden_activation(self.linear1(x)))
38 | out = self.activation(out)
39 |
40 | if self.training and torch.rand(1) < self.swap_prob:
41 | out = 1 - out
42 |
43 | return out
44 |
45 |
46 | class FFFLayer(nn.Module):
47 | def __init__(
48 | self, depth: int, in_dim: int, node_hidden: int, leaf_hidden: int, out_dim: int
49 | ):
50 | super(FFFLayer, self).__init__()
51 |
52 | self.depth = depth
53 | self.node_hidden = node_hidden
54 | self.leaf_hidden = leaf_hidden
55 |
56 | nodes = [FFFMLP(in_dim, node_hidden, 1) for _ in range(2 ** (depth) - 1)]
57 | leaves = [
58 | FFFMLP(
59 | in_dim,
60 | leaf_hidden,
61 | out_dim,
62 | out_activation=nn.Identity(),
63 | swap_prob=0.0,
64 | )
65 | for _ in range(2**depth)
66 | ]
67 | self.tree = nn.ModuleList(nodes + leaves)
68 |
69 | def forward(self, x: torch.Tensor, idx: int = 1, activations: list = None):
70 | c = self.tree[idx - 1](x)
71 |
72 | if 2 * idx + 1 <= len(self.tree):
73 | # Continuing down the tree
74 | if self.training:
75 | # During training, split signal all over the tree
76 | left, a_left = self.forward(x, 2 * idx, activations)
77 | right, a_right = self.forward(x, 2 * idx + 1, activations)
78 |
79 | children_are_leaves = a_left is None and a_right is None
80 | activations = [c] if children_are_leaves else a_left + a_right + [c]
81 |
82 | return c * left + (1 - c) * right, activations
83 | else:
84 |             # During inference, the routing decision is hardened (ideally only a single path would be evaluated)
85 | left, a_left = self.forward(x, 2 * idx, activations)
86 | right, a_right = self.forward(x, 2 * idx + 1, activations)
87 |
88 | children_are_leaves = a_left is None and a_right is None
89 | activations = [c] if children_are_leaves else a_left + a_right + [c]
90 |
91 | # Hardening
92 | # TODO: Improve and actually go down only one path
93 | c = (c >= 0.5).float()
94 |
95 | return c * left + (1 - c) * right, activations
96 |
97 | return c, activations
98 |
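# Illustrative note (added for clarity): the tree is stored in heap order with 1-based index `idx`,
# so the children of node `idx` are `2 * idx` and `2 * idx + 1`. For depth=3 this gives
# 2**3 - 1 = 7 internal node MLPs (indices 1..7) followed by 2**3 = 8 leaf MLPs (indices 8..15).
# During training every internal node outputs a soft gate c in (0, 1) that mixes its two subtrees;
# at inference the gate is hardened to {0, 1} (see the TODO above about evaluating a single path only).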
--------------------------------------------------------------------------------
/src/fff/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
7 | from pytorch_lightning.loggers import WandbLogger
8 | from torch.optim import SGD
9 | from torch.optim.lr_scheduler import LinearLR
10 | from torch.utils.data import DataLoader
11 | from torchmetrics import Accuracy
12 | from torchvision.datasets import MNIST
13 | from torchvision.transforms import Compose, ToTensor
14 |
15 | from src.fff.fff import FFFLayer
16 |
17 |
18 | class FlattenMNIST:
19 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
20 | return x.flatten()
21 |
22 |
23 | class PLWrapper(pl.LightningModule):
24 | def __init__(self, model, total_iters=10, hardening_weight=3.0, lr=0.2):
25 | super(PLWrapper, self).__init__()
26 | self.model = model
27 | self.cross_entropy = nn.CrossEntropyLoss()
28 | self.accuracy = Accuracy("multiclass", num_classes=10)
29 | self.total_iters = total_iters
30 | self.hardening_weight = hardening_weight
31 | self.lr = lr
32 |
33 | def forward(self, x):
34 | return self.model(x)
35 |
36 | def entropy(self, x):
37 | entropies = x * torch.log(x + 1e-8) + (1 - x) * torch.log(1 - x + 1e-8)
38 | return -torch.mean(entropies)
39 |
40 | def configure_optimizers(self):
41 | optimizer = SGD(self.model.parameters(), lr=self.lr)
42 | scheduler = LinearLR(optimizer, 1, 0, total_iters=self.total_iters)
43 | return [optimizer], [scheduler]
44 |
45 | def training_step(self, batch, batch_idx):
46 | x, y = batch
47 | y_hat, activations = self.model(x)
48 | activations = torch.cat(activations, dim=0)
49 | le = self.hardening_weight * self.entropy(activations)
50 | lce = self.cross_entropy(y_hat, y)
51 | loss = lce + le
52 | acc = self.accuracy(y_hat, y)
53 | self.log_dict(
54 | {
55 | "train_loss": loss,
56 | "train_loss_ce": lce,
57 | "train_loss_entropy": le,
58 | "train_acc": acc,
59 | }
60 | )
61 | return loss
62 |
63 | def test_step(self, batch, batch_idx):
64 | x, y = batch
65 | y_hat, _ = self.model(x)
66 | loss = self.cross_entropy(y_hat, y)
67 | acc = self.accuracy(y_hat, y)
68 | self.log("test_loss", loss)
69 | self.log("test_acc", acc)
70 |
71 |
72 | def main(args):
73 | """
74 | Training and evaluating an FFF model on MNIST image classification.
75 |     Deliberately over-engineered with Lightning and W&B to keep up good habits.
76 | """
77 | # Program arguments
78 | batch_size = args["batch_size"]
79 | max_epochs = args["max_epochs"]
80 | lr = args["lr"]
81 | hardening_weight = args["hardening_weight"]
82 | seed = args["seed"]
83 |
84 | # Setting reproducibility
85 | pl.seed_everything(seed)
86 |
87 | # Getting data
88 | transform = Compose([ToTensor(), FlattenMNIST()])
89 | train = MNIST("./", train=True, download=True, transform=transform)
90 | test = MNIST("./", train=False, download=True, transform=transform)
91 | train_loader = DataLoader(train, batch_size=batch_size, num_workers=4, shuffle=True)
92 | test_loader = DataLoader(test, batch_size=batch_size, num_workers=4, shuffle=False)
93 |
94 | # Getting model
95 | fff_model = FFFLayer(
96 | depth=3, in_dim=28 * 28, node_hidden=32, leaf_hidden=32, out_dim=10
97 | )
98 |
99 | model = PLWrapper(
100 | fff_model,
101 | total_iters=max_epochs * len(train_loader),
102 | hardening_weight=hardening_weight,
103 | lr=lr,
104 | ).train()
105 |
106 | # Training
107 | trainer = pl.Trainer(
108 | logger=WandbLogger(
109 | name="FFF MNIST",
110 | project="Papers Re-implementations",
111 | ),
112 | accelerator="auto",
113 | max_epochs=max_epochs,
114 | callbacks=[
115 | ModelCheckpoint(dirpath="./checkpoints", monitor="train_loss", mode="min"),
116 | TQDMProgressBar(),
117 | ],
118 | )
119 | trainer.fit(model, train_loader)
120 |
121 | # Loading best model
122 | model = PLWrapper.load_from_checkpoint(
123 | trainer.checkpoint_callback.best_model_path, model=fff_model
124 | )
125 |
126 | # Testing and logging
127 | model.eval()
128 | trainer.test(model, test_loader)
129 |
130 |
131 | if __name__ == "__main__":
132 | parser = ArgumentParser()
133 | parser.add_argument(
134 | "--batch_size",
135 | type=int,
136 | default=64,
137 | help="Batch size for training and evaluation.",
138 | )
139 | parser.add_argument(
140 | "--max_epochs",
141 | type=int,
142 | default=10,
143 | help="Number of epochs to train the model.",
144 | )
145 | parser.add_argument(
146 | "--lr", type=float, default=0.2, help="Learning rate for the optimizer."
147 | )
148 | parser.add_argument(
149 | "--hardening_weight",
150 | type=float,
151 | default=3.0,
152 | help="Hardening weight for the entropy loss.",
153 | )
154 | parser.add_argument(
155 | "--seed", type=int, default=0, help="Seed for the random number generators."
156 | )
157 | args = vars(parser.parse_args())
158 | print(args)
159 | main(args)
160 |
--------------------------------------------------------------------------------
/src/gnns/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/gnns/__init__.py
--------------------------------------------------------------------------------
/src/gnns/gnns.py:
--------------------------------------------------------------------------------
1 | """Implementation of convolutional, attentional and message-passing GNNs, inspired by the paper
2 | Everything is Connected: Graph Neural Networks
3 | (https://arxiv.org/abs/2301.08210)
4 |
5 | Useful links:
6 | - Petar Veličković PDF talk: https://petar-v.com/talks/GNN-EEML.pdf
7 | """
8 |
9 | import warnings
10 | from argparse import ArgumentParser
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.optim import Adam
15 | from torch.utils.data import DataLoader
16 | from torchvision.datasets import MNIST
17 | from torchvision.transforms import Compose, Lambda, Resize, ToTensor
18 | from tqdm import tqdm
19 |
20 | import wandb
21 |
22 | # Definitions
23 | NETWORK_TYPES = ["attn", "conv"]
24 | AGGREGATION_FUNCTIONS = {
25 | "sum": lambda X, dim=1: torch.sum(X, dim=dim),
26 | "avg": lambda X, dim=1: torch.mean(X, dim=dim),
27 | }
28 |
29 |
30 | def parse_args():
31 | """Parses the program arguments"""
32 | parser = ArgumentParser()
33 |
34 | # Data arguments
35 | parser.add_argument(
36 | f"--image_size",
37 | type=int,
38 | help="Size to which reshape CIFAR images. Default is 14 (196 nodes).",
39 | default=14,
40 | )
41 |
42 | # Model arguments
43 | parser.add_argument(
44 | f"--type",
45 | type=str,
46 | help="Type of the network used. Default is 'attn'",
47 | choices=NETWORK_TYPES,
48 | default="attn",
49 | )
50 | parser.add_argument(
51 | f"--aggregation",
52 | type=str,
53 | help="Aggregation function",
54 | choices=list(AGGREGATION_FUNCTIONS.keys()),
55 | default="avg",
56 | )
57 | parser.add_argument(
58 | f"--aggregation_out",
59 | type=str,
60 | help="Aggregation function for graph classification",
61 | choices=list(AGGREGATION_FUNCTIONS.keys()),
62 | default="avg",
63 | )
64 | parser.add_argument(
65 | f"--n_layers", type=int, help="Number of layers of the GNNs", default=8
66 | )
67 |
68 | # Training arguments
69 | parser.add_argument(
70 | f"--epochs", type=int, help="Training epochs. Default is 10.", default=10
71 | )
72 | parser.add_argument(
73 | f"--lr", type=float, help="Learning rate. Default is 1e-3.", default=0.001
74 | )
75 | parser.add_argument(
76 | f"--batch_size",
77 | type=int,
78 | help="Batch size used for training. Default is 64.",
79 | default=64,
80 | )
81 | parser.add_argument(
82 | f"--checkpoint",
83 | type=str,
84 | help="Path to model checkpoints. Default is 'gnn.pt'.",
85 | default="gnn.pt",
86 | )
87 |
88 | return vars(parser.parse_args())
89 |
90 |
91 | def get_device():
92 | """Gets the CUDA device if available, warns that code will run on CPU only otherwise"""
93 | if torch.cuda.is_available():
94 | device = torch.device("cuda")
95 | print("\nFound GPU: ", torch.cuda.get_device_name(device))
96 | return device
97 | elif torch.backends.mps.is_available():
98 | device = torch.device("mps")
99 | print("\nFound Apple MPS chip.")
100 |
101 | warnings.warn("\nWARNING: No GPU nor MPS found - Training on CPU.")
102 | return torch.device("cpu")
103 |
104 |
105 | class PsiNetwork(nn.Module):
106 | """
107 | Simple MLP network denoted as the 'psi' function in the paper.
108 |     The role of this network is to extract the features that are sent as messages to neighbouring nodes.
109 | """
110 |
111 | def __init__(self, in_size, out_size):
112 | super(PsiNetwork, self).__init__()
113 | self.linear = nn.Linear(in_size, out_size)
114 | self.relu = nn.ReLU()
115 |
116 | def forward(self, X):
117 | return self.relu(self.linear(X))
118 |
119 |
120 | class GraphConvLayer(nn.Module):
121 | """
122 | Graph Convolutional layer.
123 |     It computes the next hidden states of the nodes as a convolution over their neighbours.
124 | """
125 |
126 | def __init__(self, n, d, aggr):
127 | super(GraphConvLayer, self).__init__()
128 |
129 | self.coefficients = nn.Parameter(torch.ones((n, n)) / n)
130 | self.ln = nn.LayerNorm(d)
131 | self.psi = PsiNetwork(d, d)
132 | self.aggr = aggr
133 |
134 | def forward(self, H, A):
135 | weights = self.coefficients * A # (N, N)
136 | messages = self.psi(self.ln(H)) # (B, N, D)
137 | messages = torch.einsum("nm, bmd -> bndm", weights, messages) # (B, N, D, N)
138 | messages = self.aggr(messages, dim=-1) # (B, N, D)
139 | return messages
140 |
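# Shape walk-through (illustrative note): with H of shape (B, N, D) and adjacency A of shape (N, N),
# `weights = coefficients * A` zeroes out non-neighbours, the einsum broadcasts each node's message
# to all N potential receivers, giving (B, N, D, N), and the aggregation over the last dimension
# (sum or mean) collapses the neighbour axis back to (B, N, D).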
141 |
142 | class Attention(nn.Module):
143 | def __init__(self, dim):
144 | super(Attention, self).__init__()
145 | self.dim = dim
146 |
147 | def forward(self, x, mask=None):
148 | # x has shape (B, N, D)
149 | attn_cues = (x @ x.transpose(-2, -1)) / (self.dim**0.5 + 1e-5) # (B, N, N)
150 |
151 | if mask is not None:
152 | attn_cues = attn_cues.masked_fill(mask == 0, float("-inf"))
153 |
154 | attn_cues = attn_cues.softmax(-1)
155 | return attn_cues
156 |
157 |
158 | class GraphAttentionLayer(nn.Module):
159 | """
160 | Graph Attentional Layer.
161 |     It computes the next hidden states of the nodes using attention over their neighbours.
162 | """
163 |
164 | def __init__(self, n, d, aggr):
165 | super(GraphAttentionLayer, self).__init__()
166 |
167 | self.aggr = aggr
168 | self.psi = PsiNetwork(d, d)
169 | self.ln1 = nn.LayerNorm(d)
170 | self.ln2 = nn.LayerNorm(2 * d)
171 | self.sa = Attention(d)
172 |
173 | def forward(self, H, A):
174 | messages = self.psi(self.ln1(H)) # (B, N, D)
175 | attn = self.sa(H, A) # (B, N, N)
176 |
177 | messages = torch.einsum("bnm, bmd -> bndm", attn, messages) # (B, N, D, N)
178 | messages = self.aggr(messages, dim=-1) # (B, N, D)
179 | return messages
180 |
181 |
182 | class GraphNeuralNetwork(nn.Module):
183 | """Graph Neural Network class."""
184 |
185 | _NET_TYPE_TO_LAYER = {"attn": GraphAttentionLayer, "conv": GraphConvLayer}
186 |
187 | def _get_phi_net(self, dim_in, dim_out):
188 | return nn.Sequential(
189 | nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out)
190 | )
191 |
192 | def __init__(self, net_type, n_layers, n, d_in, d_hidden, d_out, aggr, aggr_out):
193 | super(GraphNeuralNetwork, self).__init__()
194 |
195 | assert (
196 | net_type in NETWORK_TYPES
197 | ), f"ERROR: GNN type {net_type} not supported. Pick one of {NETWORK_TYPES}"
198 |
199 | self.net_type = net_type
200 | self.n_layers = n_layers
201 | self.encoding = nn.Linear(d_in, d_hidden)
202 | self.layers = nn.ModuleList(
203 | [
204 | self._NET_TYPE_TO_LAYER[net_type](
205 | n, d_hidden, AGGREGATION_FUNCTIONS[aggr]
206 | )
207 | for _ in range(n_layers)
208 | ]
209 | )
210 |
211 | self.phi_nets = nn.ModuleList(
212 | [self._get_phi_net(2 * d_hidden, d_hidden) for _ in range(n_layers)]
213 | )
214 |
215 | self.out_aggr = AGGREGATION_FUNCTIONS[aggr_out]
216 | self.out_mlp = nn.Sequential(
217 | nn.Linear(d_hidden, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_out)
218 | )
219 |
220 | def forward(self, X, A):
221 |         # X has shape (B, N, D) and represents the node features.
222 | # A is binary with shape (N, N) and represents the adjacency matrix.
223 | H = self.encoding(X)
224 | for l, p in zip(self.layers, self.phi_nets):
225 | messages = l(H, A)
226 | H = H + p(torch.cat((H, messages), dim=-1))
227 | return self.out_mlp(self.out_aggr(H))
228 |
229 |
230 | def main():
231 | # Parsing arguments
232 | args = parse_args()
233 | print("Launched program with the following arguments:", args)
234 |
235 | # Getting device
236 | device = get_device()
237 |
238 | # Loading data
239 |     # We reshape the image such that each pixel is a node with C features.
240 | img_size = args["image_size"]
241 | transform = Compose(
242 | [
243 | ToTensor(),
244 | Resize((img_size, img_size)),
245 | # (C, H, W) -> (H*W, C).
246 | Lambda(lambda x: x.flatten().reshape(-1, 1)),
247 | ]
248 | )
249 | train_set = MNIST("./../datasets", train=True, transform=transform, download=True)
250 | test_set = MNIST("./../datasets", train=False, transform=transform, download=True)
251 |
252 | train_loader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True)
253 | test_loader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=False)
254 |
255 |     # Building the adjacency matrix (img_size^2 x img_size^2) shared by all "graphs" (the resized images)
256 | A = torch.zeros((img_size**2, img_size**2)).to(device)
257 | nums = torch.arange(img_size**2).reshape((img_size, img_size))
258 | for i in range(img_size):
259 | for j in range(img_size):
260 | start_x, start_y = i - 1 if i > 0 else 0, j - 1 if j > 0 else 0
261 | neighbours = nums[start_x : i + 2, start_y : j + 2].flatten()
262 |
263 | for n in neighbours:
264 | A[i * img_size + j, n] = A[n, i * img_size + j] = 1
265 |
266 | # Creating model
267 |     # Number of nodes, node feature dimensionality, hidden dimensionality and number of output classes
268 | n, d, h, o = img_size**2, 1, 16, 10
269 | model = GraphNeuralNetwork(
270 | args["type"],
271 | args["n_layers"],
272 | n,
273 | d,
274 | h,
275 | o,
276 | aggr=args["aggregation"],
277 | aggr_out=args["aggregation_out"],
278 | )
279 |
280 | # Training loop
281 | n_epochs = args["epochs"]
282 | checkpoint = (
283 | args["checkpoint"] if args["checkpoint"] else f"({args['type']}_gnn).pt"
284 | )
285 | optim = Adam(model.parameters(), args["lr"])
286 | criterion = nn.CrossEntropyLoss()
287 | model = model.to(device)
288 | model.train()
289 |
290 | min_loss = float("inf")
291 | progress_bar = tqdm(range(1, n_epochs + 1))
292 |
293 | wandb.init(
294 | project="Papers Re-implementations",
295 | entity="peutlefaire",
296 | name=f"GNN ({args['type']})",
297 | config={
298 | "type": args["type"],
299 | "aggregation": args["aggregation"],
300 | "layers": args["n_layers"],
301 | "epochs": n_epochs,
302 | "batch_size": args["batch_size"],
303 | "lr": args["lr"],
304 | },
305 | )
306 |
307 | for epoch in progress_bar:
308 | epoch_loss = 0.0
309 | for batch in tqdm(train_loader, leave=False):
310 | x, y = batch
311 | x, y = x.to(device), y.to(device)
312 |
313 | loss = criterion(model(x, A), y)
314 | epoch_loss += loss.item() / len(train_loader)
315 | optim.zero_grad()
316 | loss.backward()
317 | optim.step()
318 |
319 | wandb.log({"batch loss": loss.item()})
320 |
321 | description = f"Epoch {epoch}/{n_epochs} - Training loss: {epoch_loss:.3f}"
322 | if epoch_loss < min_loss:
323 | min_loss = epoch_loss
324 | torch.save(model.state_dict(), checkpoint)
325 | description += " -> ✅ Stored best model."
326 |
327 | wandb.log({"epoch loss": epoch_loss})
328 | progress_bar.set_description(description)
329 | wandb.finish()
330 |
331 | # Testing loop
332 | model.load_state_dict(torch.load(checkpoint, map_location=device))
333 | model = model.eval()
334 | with torch.no_grad():
335 | correct, total = 0, 0
336 | for batch in test_loader:
337 | x, y = batch
338 | x, y = x.to(device), y.to(device)
339 |
340 | correct += (model(x, A).argmax(1) == y).sum().item()
341 | total += len(y)
342 | print(f"\n\nFinal test accuracy: {(correct / total * 100):.2f}%")
343 | print("Program completed successfully!")
344 |
345 |
346 | if __name__ == "__main__":
347 | main()
348 |
--------------------------------------------------------------------------------
/src/nlp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/__init__.py
--------------------------------------------------------------------------------
/src/nlp/bert/README.md:
--------------------------------------------------------------------------------
1 | # BERT
2 |
3 | ## Dataset
4 | Training BERT will download the Wikipedia dataset from March 1st, 2022 from [huggingface datasets](https://huggingface.co/datasets/wikipedia). The total disk size required for the dataset is ~ `43GB`, and it will be downloaded under your `HF_DATASETS_CACHE`.
5 |
--------------------------------------------------------------------------------
/src/nlp/bert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/bert/__init__.py
--------------------------------------------------------------------------------
/src/nlp/bert/data.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class BertDataset(Dataset):
8 | """Creates a dataset for BERT training where each sample is broken into two sentences."""
9 |
10 | def __init__(
11 | self,
12 | dataset,
13 | tokenizer,
14 | max_length,
15 | sentence_divider="\n\n",
16 | mlm_ratio=0.15,
17 | mask_ratio=0.8,
18 | ):
19 | super(BertDataset, self).__init__()
20 |
21 | # Filtering empty sentences
22 | self.dataset = dataset.filter(
23 | lambda sample: len(sample["text"]) > 0
24 | and sentence_divider in sample["text"],
25 | )
26 |
27 | # Dataset parameters
28 | self.tokenizer = tokenizer
29 | self.max_length = max_length
30 | self.sentence_divider = sentence_divider
31 | self.mlm_ratio = mlm_ratio
32 | self.mask_ratio = mask_ratio
33 |
34 | # MLM distribution
35 | nothing_prob = 1 - mlm_ratio
36 | mask_prob = (1 - nothing_prob) * mask_ratio
37 | different_word_prob = (1 - nothing_prob - mask_prob) / 2
38 | probs = torch.tensor(
39 | [nothing_prob, mask_prob, different_word_prob, different_word_prob]
40 | )
41 | self.mask_dist = torch.distributions.Multinomial(probs=probs)
42 |
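    # Worked example (added note): with the defaults mlm_ratio=0.15 and mask_ratio=0.8, the
    # multinomial above yields per-token probabilities of
    #     0.85  -> left untouched (not predicted)
    #     0.12  -> replaced by [MASK]           (0.15 * 0.8)
    #     0.015 -> replaced by a random token
    #     0.015 -> kept unchanged but still predicted
    # i.e. 15% of tokens enter the MLM loss, split 80/10/10 as in the BERT paper.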
43 | def __len__(self):
44 | return len(self.dataset)
45 |
46 | def __getitem__(self, index):
47 | # First sentence
48 | sentences = self.dataset[index]["text"].split(self.sentence_divider)
49 | n_sentences = len(sentences)
50 |
51 | # Picking first sentence
52 | s1_idx = random.randint(0, n_sentences - 2)
53 | s1 = sentences[s1_idx]
54 |
55 | # Next sentence and prediction label
56 | if torch.rand(1) > 0.5:
57 | # 50% of the time, pick the real next "sentence"
58 | s2 = sentences[(s1_idx + 1)]
59 | nsp_label = 1
60 | else:
61 | # The other 50% of the time, pick a random next "sentence"
62 |             idx = random.randint(0, len(self.dataset) - 1)
63 |             if idx == index:
64 |                 idx = (idx + 1) % len(self.dataset)
65 |             sentences_other = self.dataset[idx]["text"].split(self.sentence_divider)
66 | s2 = sentences_other[random.randint(0, len(sentences_other) - 1)]
67 | nsp_label = 0
68 |
69 | # Preparing input ids
70 | tokenizer_out = self.tokenizer(
71 | s1,
72 | s2,
73 | return_tensors="pt",
74 | padding="max_length",
75 | max_length=self.max_length,
76 | truncation=True,
77 | )
78 |
79 | input_ids, segment_idx, attn_mask = (
80 | tokenizer_out["input_ids"][0],
81 | tokenizer_out["token_type_ids"][0],
82 | tokenizer_out["attention_mask"][0],
83 | )
84 |
85 | # Getting mask indexes
86 | mask_idx = self.mask_dist.sample((len(input_ids),))
87 |
88 | # Not masking CLS and SEP
89 | sep_idx = -1
90 | for i in range(len(segment_idx)):
91 | if segment_idx[i] == 1:
92 | sep_idx = i - 1
93 | break
94 | mask_idx[0] = mask_idx[-1] = mask_idx[sep_idx] = torch.tensor([1, 0, 0, 0])
95 | mask_idx = mask_idx.argmax(dim=-1)
96 |
97 | # Getting labels for masked tokens
98 | mlm_idx = (mask_idx != 0).long()
99 | mlm_labels = input_ids.clone()
100 |
101 | # Masking input tokens according to strategy
102 | input_ids[mask_idx == 1] = self.tokenizer.mask_token_id
103 |         input_ids[mask_idx == 2] = torch.randint(0, self.tokenizer.vocab_size, (int((mask_idx == 2).sum()),))
104 |
105 | return {
106 | "input_ids": input_ids,
107 | "segment_ids": segment_idx,
108 | "attention_mask": attn_mask,
109 | "mlm_labels": mlm_labels,
110 | "mlm_idx": mlm_idx,
111 | "nsp_labels": nsp_label,
112 | }
113 |
--------------------------------------------------------------------------------
/src/nlp/bert/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Re-implementation of
3 |
4 | BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
5 | Devlin et al. (2018)
6 | (https://arxiv.org/abs/1810.04805)
7 |
8 | on the WikiPedia dataset.
9 | """
10 |
11 | import os
12 | from argparse import ArgumentParser
13 |
14 | import torch
15 | from torch.utils.data import DataLoader
16 |
17 | torch.set_float32_matmul_precision("medium")
18 |
19 | import pytorch_lightning as pl
20 | import transformers
21 | from datasets import load_dataset
22 | from pytorch_lightning.callbacks import ModelCheckpoint
23 | from pytorch_lightning.loggers import WandbLogger
24 | from transformers import BertTokenizer
25 |
26 | transformers.logging.set_verbosity_error()
27 |
28 | from src.nlp.bert.data import BertDataset
29 | from src.nlp.bert.model import Bert
30 |
31 |
32 | @torch.no_grad()
33 | def unmask_sentences(bert, tokenizer, max_length, file_path):
34 | """Uses the bert model to unmask sentences from a file. Prints the unmaksed sentences."""
35 | file = open(file_path, "r")
36 | lines = file.readlines()
37 | lines = [line if not line.endswith("\n") else line[:-1] for line in lines]
38 | file.close()
39 |
40 | bert = bert.cuda()
41 | for i, line in enumerate(lines):
42 | # Preparing input
43 | input_ids = tokenizer(
44 | line,
45 | return_tensors="pt",
46 | max_length=max_length,
47 | padding="max_length",
48 | truncation=True,
49 | )["input_ids"].cuda()
50 | segment_ids = torch.zeros_like(input_ids)
51 | attn_mask = None
52 |
53 | # Running BERT
54 |         _, _, mlm_preds = bert(input_ids, segment_ids, attn_mask)
55 |
56 | # Getting predictions for the MASK'd words
57 | unmasked_words = mlm_preds[input_ids == tokenizer.mask_token_id].argmax(dim=-1)
58 | unmasked_words = tokenizer.decode(unmasked_words).split(" ")
59 |
60 | # Reconstructing the unmasked sentence
61 | sentence = ""
62 | parts = line.split("[MASK]")
63 | for word in unmasked_words:
64 | sentence += parts.pop(0) + word
65 | sentence += parts.pop(0)
66 |
67 | # Showing results
68 | print(f"\n\nSENTENCE {i+1}:")
69 | print(f"\tOriginal: {line}\n\tUnmasked: {sentence}")
70 |
71 |
72 | def main(args):
73 | """
74 | Train a BERT model on Wikipedia.
75 | Use the model to "unmask" sentences from a file.
76 | """
77 | # Unpacking parameters
78 | n_blocks = args["n_blocks"]
79 | n_heads = args["n_heads"]
80 | hidden_dim = args["hidden_dim"]
81 | dropout = args["dropout"]
82 | max_len = args["max_len"]
83 | batch_size = args["batch_size"]
84 | max_train_steps = args["max_train_steps"]
85 | lr = args["lr"]
86 | weight_decay = args["weight_decay"]
87 | warmup_steps = args["warmup_steps"]
88 | save_dir = args["save"]
89 | file_path = args["masked_sentences"]
90 | seed = args["seed"]
91 |
92 | # Setting random seed
93 | pl.seed_everything(seed)
94 |
95 | # Load the dataset (wikipedia only has 'train', so we split it ourselves)
96 | train_set = load_dataset("wikipedia", "20220301.en", split="train[:98%]")
97 | val_set = load_dataset("wikipedia", "20220301.en", split="train[98%:99%]")
98 | test_set = load_dataset("wikipedia", "20220301.en", split="train[99%:]")
99 |
100 |     # Setting format to torch (not strictly necessary)
101 | # train_set.set_format(type="torch", columns=["text"])
102 | # val_set.set_format(type="torch", columns=["text"])
103 | # test_set.set_format(type="torch", columns=["text"])
104 |
105 | # Bert tokenizer
106 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
107 |
108 | # Wrapping dataset with Bert logic for batches
109 | train_set, val_set, test_set = (
110 | BertDataset(train_set, tokenizer, max_len),
111 | BertDataset(val_set, tokenizer, max_len),
112 | BertDataset(test_set, tokenizer, max_len),
113 | )
114 |
115 | # Data loaders
116 | cpus = os.cpu_count()
117 | train_loader = DataLoader(
118 | train_set, batch_size=batch_size, shuffle=True, num_workers=cpus
119 | )
120 | val_loader = DataLoader(
121 | val_set, batch_size=batch_size, shuffle=False, num_workers=cpus
122 | )
123 | test_loader = DataLoader(
124 | test_set, batch_size=batch_size, shuffle=False, num_workers=cpus
125 | )
126 |
127 | # Initialize the model
128 | vocab_size = tokenizer.vocab_size
129 | bert = Bert(
130 | vocab_size,
131 | max_len,
132 | hidden_dim,
133 | n_heads,
134 | n_blocks,
135 | dropout=dropout,
136 | train_config={
137 | "lr": lr,
138 | "weight_decay": weight_decay,
139 | "max_train_steps": max_train_steps,
140 | "warmup_steps": warmup_steps,
141 | },
142 | )
143 |
144 | # Training
145 | checkpointing_freq = max_train_steps // 10
146 | wandb_logger = WandbLogger(project="Papers Re-implementations", name="BERT")
147 | wandb_logger.experiment.config.update(args)
148 | callbacks = [ModelCheckpoint(save_dir, monitor="val_loss", filename="best", every_n_train_steps=checkpointing_freq)]
149 | trainer = pl.Trainer(
150 | accelerator="auto",
151 | strategy="ddp", # State dict not saved with fsdp for some reason
152 | max_steps=max_train_steps,
153 | logger=wandb_logger,
154 | callbacks=callbacks,
155 | val_check_interval=checkpointing_freq,
156 | profiler="simple",
157 | )
158 | trainer.fit(bert, train_loader, val_loader)
159 |
160 | # Testing the best model
161 | bert = Bert.load_from_checkpoint(os.path.join(save_dir, "best.ckpt"))
162 | trainer.test(bert, test_loader)
163 |
164 | if file_path is not None and os.path.isfile(file_path):
165 | # Unmasking sentences
166 | unmask_sentences(bert, tokenizer, max_len, file_path)
167 |
168 |
169 | if __name__ == "__main__":
170 | parser = ArgumentParser()
171 |
172 | # Model hyper-parameters
173 | parser.add_argument("--n_blocks", type=int, default=12)
174 | parser.add_argument("--n_heads", type=int, default=12)
175 | parser.add_argument("--hidden_dim", type=int, default=768)
176 | parser.add_argument("--dropout", type=float, default=0.1)
177 | parser.add_argument("--max_len", type=int, default=128)
178 |
179 | # Training parameters
180 | parser.add_argument("--batch_size", type=int, default=32)
181 | parser.add_argument("--max_train_steps", type=int, default=10_000)
182 | parser.add_argument("--lr", type=float, default=1e-4)
183 | parser.add_argument("--weight_decay", type=float, default=0.01)
184 | parser.add_argument("--warmup_steps", type=int, default=100)
185 | parser.add_argument("--save", type=str, default="checkpoints/bert")
186 |
187 | # Testing parameters
188 | parser.add_argument(
189 | "--masked_sentences", type=str, default="data/nlp/bert/masked_sentences.txt"
190 | )
191 |
192 | # Seed
193 | parser.add_argument("--seed", type=int, default=0)
194 |
195 | args = vars(parser.parse_args())
196 | print(args)
197 | main(args)
198 |
--------------------------------------------------------------------------------
/src/nlp/bert/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning.utilities.types import STEP_OUTPUT
7 | from torch.optim import AdamW
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 | from src.nlp.layers.embeddings import get_learnable_embedding
11 | from src.nlp.layers.encoder import EncoderTransformer
12 |
13 |
14 | class Bert(pl.LightningModule):
15 | DEFAULT_BERT_CONFIG = {
16 | "lr": 1e-4,
17 | "betas": (0.9, 0.999),
18 | "weight_decay": 0.01,
19 | "max_train_steps": 10_000,
20 | "warmup_steps": 100,
21 | }
22 |
23 | def __init__(
24 | self,
25 | vocab_size,
26 | max_len,
27 | hidden_dim,
28 | n_heads,
29 | depth,
30 | dropout=0.1,
31 | train_config=None,
32 | ):
33 | super(Bert, self).__init__()
34 |
35 | # Saving hyper-parameters so that they are logged
36 | self.save_hyperparameters()
37 |
38 | # Local parameters
39 | self.vocab_size = vocab_size
40 | self.max_len = max_len
41 | self.hidden_dim = hidden_dim
42 | self.n_heads = n_heads
43 | self.depth = depth
44 | self.dropout = dropout
45 |         self.train_config = dict(Bert.DEFAULT_BERT_CONFIG)  # Copy so the class-level default is not mutated
46 |
47 | # Schedulers
48 | self.linear_scheduler = None
49 | self.warmup_scheduler = None
50 |
51 | # Training config
52 | if train_config is not None:
53 | self.train_config.update(train_config)
54 |
55 | # Embeddings
56 | self.embeddings = get_learnable_embedding(
57 | vocab_size, hidden_dim
58 | ) # nn.Embedding(vocab_size, hidden_dim)
59 | self.pos_embeddings = get_learnable_embedding(
60 | max_len, hidden_dim
61 | ) # nn.Embedding(max_len, hidden_dim)
62 | self.sentence_embeddings = get_learnable_embedding(
63 | 2, hidden_dim
64 | ) # nn.Embedding(2, hidden_dim)
65 |
66 | # Transformer and output layer
67 | self.transformer = EncoderTransformer(
68 | hidden_dim, n_heads, depth, dropout_p=dropout
69 | )
70 |
71 | # Next sentence classifier
72 | self.classifier = nn.Sequential(
73 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 2)
74 | )
75 |
76 | # Masked language modeling head
77 | self.mask_predictor = nn.Sequential(
78 | nn.Linear(hidden_dim, hidden_dim),
79 | nn.GELU(),
80 | nn.LayerNorm(hidden_dim),
81 | nn.Linear(hidden_dim, vocab_size, bias=True),
82 | )
83 |
84 | def forward(self, ids, segment_ids=None, attn_mask=None):
85 | # Embedding
86 | b, t = ids.shape
87 | hidden = self.embeddings(ids)
88 | hidden += self.pos_embeddings(torch.arange(t, device=ids.device)).repeat(
89 | b, 1, 1
90 | )
91 |
92 | if segment_ids is not None:
93 | hidden += self.sentence_embeddings(segment_ids)
94 | else:
95 | hidden += self.sentence_embeddings(
96 | torch.zeros(1, dtype=torch.long, device=ids.device)
97 | ).repeat(b, 1, 1)
98 |
99 | # Transformer
100 | hidden = self.transformer(hidden, attn_mask=attn_mask)
101 |
102 | # Classification head based on CLS token
103 | class_preds = self.classifier(hidden[:, 0])
104 |
105 | # Masked language modeling head
106 | mlm_preds = self.mask_predictor(hidden)
107 |
108 | return hidden, class_preds, mlm_preds
109 |
110 | def get_losses(self, batch):
111 | # Unpacking batch
112 | ids = batch["input_ids"]
113 | segment_ids = batch["segment_ids"]
114 | attn_mask = batch["attention_mask"]
115 | mlm_labels = batch["mlm_labels"]
116 | mlm_idx = batch["mlm_idx"]
117 | nsp_labels = batch["nsp_labels"]
118 |
119 | # Running forward
120 | b, t = ids.shape
121 | _, class_preds, mlm_preds = self(
122 | ids, segment_ids, attn_mask.repeat(1, t).reshape(b, t, t)
123 | )
124 | mlm_preds = mlm_preds[mlm_idx + attn_mask == 2]
125 | mlm_labels = mlm_labels[mlm_idx + attn_mask == 2]
126 |
127 | # Classification loss
128 | class_loss = torch.nn.functional.cross_entropy(class_preds, nsp_labels)
129 |
130 | # Masked language modeling loss
131 | mlm_loss = torch.nn.functional.cross_entropy(mlm_preds, mlm_labels)
132 |
133 | # Getting accuracies
134 | class_acc = (class_preds.argmax(dim=-1) == nsp_labels).float().mean()
135 | mlm_acc = (mlm_preds.argmax(dim=-1) == mlm_labels).float().mean()
136 |
137 | return class_loss, mlm_loss, class_acc, mlm_acc
138 |
139 | def configure_optimizers(self):
140 | optim = AdamW(
141 | self.trainer.model.parameters(),
142 | lr=self.train_config["lr"],
143 | weight_decay=self.train_config["weight_decay"],
144 | betas=self.train_config["betas"],
145 | )
146 |
147 | self.linear_scheduler = LambdaLR(optim, self.scheduler_fn)
148 |
149 | return {"optimizer": optim}
150 |
151 | def scheduler_fn(self, step):
152 | warmup_steps = self.train_config["warmup_steps"]
153 | max_steps = self.train_config["max_train_steps"]
154 |
155 | if step < warmup_steps:
156 | return step / warmup_steps
157 | return 1 - (step - warmup_steps) / (max_steps - warmup_steps)
158 |
159 | def on_train_batch_end(
160 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
161 | ) -> None:
162 | if self.linear_scheduler is not None:
163 | self.linear_scheduler.step()
164 |
165 | return super().on_train_batch_end(outputs, batch, batch_idx)
166 |
167 | def training_step(self, batch, batch_idx):
168 | # Getting losses & accuracies
169 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch)
170 |
171 | # Total loss
172 | loss = class_loss + mlm_loss
173 |
174 | # Logging
175 | self.log_dict(
176 | {
177 | "lr": self.optimizers().param_groups[0]["lr"],
178 | "train_loss": loss,
179 | "train_class_loss": class_loss,
180 | "train_mlm_loss": mlm_loss,
181 | "train_class_acc": c_acc,
182 | "train_mlm_acc": m_acc,
183 | },
184 | sync_dist=True,
185 | )
186 |
187 | return loss
188 |
189 | def validation_step(self, batch, batch_idx):
190 | # Getting losses & accuracies
191 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch)
192 |
193 | # Total loss
194 | loss = class_loss + mlm_loss
195 |
196 | # Logging
197 | self.log_dict(
198 | {
199 | "val_loss": loss,
200 | "val_class_loss": class_loss,
201 | "val_mlm_loss": mlm_loss,
202 | "val_class_acc": c_acc,
203 | "val_mlm_acc": m_acc,
204 | },
205 | sync_dist=True,
206 | )
207 |
208 | return loss
209 |
210 | def test_step(self, batch, batch_idx):
211 | # Getting losses & accuracies
212 | class_loss, mlm_loss, c_acc, m_acc = self.get_losses(batch)
213 |
214 | # Total loss
215 | loss = class_loss + mlm_loss
216 |
217 | # Logging
218 | self.log_dict(
219 | {
220 | "test_loss": loss,
221 | "test_class_loss": class_loss,
222 | "test_mlm_loss": mlm_loss,
223 | "test_class_acc": c_acc,
224 | "test_mlm_acc": m_acc,
225 | },
226 | sync_dist=True,
227 | )
228 |
229 | return loss
230 |
--------------------------------------------------------------------------------
/src/nlp/gpt/README.md:
--------------------------------------------------------------------------------
1 | # GPT
2 |
3 | ## Dataset
4 | Training GPT downloads the English Wikipedia snapshot from March 1st, 2022 (config `20220301.en`) via [huggingface datasets](https://huggingface.co/datasets/wikipedia). The dataset requires roughly `43GB` of disk space and is stored under your `HF_DATASETS_CACHE`.
5 |
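6 | To pre-fetch the data and tokenizer before training, a minimal sketch (mirroring the calls in `main.py`; the 1% slice is only for a quick check):
7 | 
8 | ```python
9 | from datasets import load_dataset
10 | from transformers import GPT2Tokenizer
11 | 
12 | # Downloads part of the English Wikipedia snapshot under HF_DATASETS_CACHE
13 | wiki = load_dataset("wikipedia", "20220301.en", split="train[:1%]")
14 | 
15 | # The same tokenizer used in main.py; GPT-2 has no padding token, so a [PAD] token is added
16 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
17 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
18 | ids = tokenizer(wiki[0]["text"], truncation=True, padding="max_length", max_length=128, return_tensors="pt")["input_ids"]
19 | print(ids.shape)  # torch.Size([1, 128])
20 | ```
21 | 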
--------------------------------------------------------------------------------
/src/nlp/gpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/gpt/__init__.py
--------------------------------------------------------------------------------
/src/nlp/gpt/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Re-implementation of
3 |
4 | Language Models are Few-Shot Learners
5 | Brown et al. (2020)
6 | (https://arxiv.org/abs/2005.14165)
7 |
8 | on the WikiPedia dataset.
9 | """
10 |
11 | import os
12 | from argparse import ArgumentParser
13 |
14 | import torch
15 | from torch.utils.data import DataLoader
16 |
17 | torch.set_float32_matmul_precision("medium")
18 |
19 | import pytorch_lightning as pl
20 | import transformers
21 | from datasets import load_dataset
22 | from pytorch_lightning.callbacks import ModelCheckpoint
23 | from pytorch_lightning.loggers import WandbLogger
24 | from transformers import GPT2Tokenizer
25 |
26 | transformers.logging.set_verbosity_error()
27 |
28 | from src.nlp.gpt.model import GPT
29 |
30 |
31 | def continue_sentences(gpt, tokenizer, max_len, file_path):
32 | """Uses the gpt model to continue sentences from a file. Prints the continued sentences."""
33 | file = open(file_path, "r")
34 | lines = file.readlines()
35 | lines = [line if not line.endswith("\n") else line[:-1] for line in lines]
36 | file.close()
37 |
38 | gpt = gpt.cuda()
39 | for i, line in enumerate(lines):
40 | # Preparing input
41 | input_ids = tokenizer(
42 | line,
43 | return_tensors="pt",
44 | max_length=max_len,
45 | )["input_ids"].cuda()
46 |
47 | # Generating sentence
48 | all_ids = gpt.generate(input_ids, max_len)
49 |
50 | # Decoding the sentence
51 | sentence = tokenizer.decode(all_ids.squeeze().tolist())
52 | print(f"\n\nSentence {i+1}:")
53 | print(f"\tOriginal: {line}\n\tCompleted: {sentence}")
54 |
55 |
56 | def main(args):
57 | """
58 | Train a GPT model on Wikipedia.
59 | Use the model to continue sentences from a file.
60 | """
61 | # Unpacking parameters
62 | n_blocks = args["n_blocks"]
63 | n_heads = args["n_heads"]
64 | hidden_dim = args["hidden_dim"]
65 | dropout = args["dropout"]
66 | max_len = args["max_len"]
67 | batch_size = args["batch_size"]
68 | max_train_steps = args["max_train_steps"]
69 | lr = args["lr"]
70 | weight_decay = args["weight_decay"]
71 | warmup_steps = args["warmup_steps"]
72 | save_dir = args["save"]
73 | file_path = args["prompts"]
74 | seed = args["seed"]
75 |
76 | # Setting random seed
77 | pl.seed_everything(seed)
78 |
79 | # Load the dataset (wikipedia only has 'train', so we split it ourselves)
80 | train_set = load_dataset("wikipedia", "20220301.en", split="train[:98%]")
81 | val_set = load_dataset("wikipedia", "20220301.en", split="train[98%:99%]")
82 | test_set = load_dataset("wikipedia", "20220301.en", split="train[99%:]")
83 |
84 |     # Setting format to torch (not strictly necessary)
85 | # train_set.set_format(type="torch", columns=["text"])
86 | # val_set.set_format(type="torch", columns=["text"])
87 | # test_set.set_format(type="torch", columns=["text"])
88 |
89 | # Loading the GPT2 tokenizer
90 | added_pad_token = False
91 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
92 | if tokenizer.pad_token is None:
93 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
94 | added_pad_token = True
95 |
96 | # Tokenizing the whole dataset
97 | def collate(batch):
98 | return tokenizer(
99 | [sample["text"] for sample in batch],
100 | return_tensors="pt",
101 | truncation=True,
102 | padding="max_length",
103 | max_length=max_len,
104 | )
105 |
106 | # Data loaders
107 | cpus = os.cpu_count()
108 | train_loader = DataLoader(
109 | train_set,
110 | batch_size=batch_size,
111 | shuffle=True,
112 | num_workers=cpus,
113 | collate_fn=collate,
114 | )
115 | val_loader = DataLoader(
116 | val_set,
117 | batch_size=batch_size,
118 | shuffle=False,
119 | num_workers=cpus,
120 | collate_fn=collate,
121 | )
122 | test_loader = DataLoader(
123 | test_set,
124 | batch_size=batch_size,
125 | shuffle=False,
126 | num_workers=cpus,
127 | collate_fn=collate,
128 | )
129 |
130 | # Initialize the model
131 | vocab_size = tokenizer.vocab_size + 1 if added_pad_token else tokenizer.vocab_size
132 | gpt = GPT(
133 | vocab_size,
134 | max_len,
135 | hidden_dim,
136 | n_heads,
137 | n_blocks,
138 | dropout=dropout,
139 | train_config={
140 | "lr": lr,
141 | "weight_decay": weight_decay,
142 | "max_train_steps": max_train_steps,
143 | "warmup_steps": warmup_steps,
144 | },
145 | )
146 |
147 | # Training
148 | checkpointing_freq = max_train_steps // 10
149 | wandb_logger = WandbLogger(project="Papers Re-implementations", name="GPT")
150 | wandb_logger.experiment.config.update(args)
151 | callbacks = [ModelCheckpoint(save_dir, monitor="val_loss", filename="best", every_n_train_steps=checkpointing_freq)]
152 | trainer = pl.Trainer(
153 | accelerator="auto",
154 | strategy="ddp", # State dict not saved with fsdp for some reason
155 | max_steps=max_train_steps,
156 | logger=wandb_logger,
157 | callbacks=callbacks,
158 | profiler="simple",
159 | val_check_interval=checkpointing_freq
160 | )
161 | trainer.fit(gpt, train_loader, val_loader)
162 |
163 | # Testing the best model
164 | gpt = GPT.load_from_checkpoint(os.path.join(save_dir, "best.ckpt"))
165 | trainer.test(gpt, test_loader)
166 |
167 | if file_path is not None and os.path.isfile(file_path):
168 | # Generating sentences from prompts
169 | continue_sentences(gpt, tokenizer, max_len, file_path)
170 |
171 |
172 | if __name__ == "__main__":
173 | parser = ArgumentParser()
174 |
175 | # Model hyper-parameters
176 | parser.add_argument("--n_blocks", type=int, default=12)
177 | parser.add_argument("--n_heads", type=int, default=12)
178 | parser.add_argument("--hidden_dim", type=int, default=768)
179 | parser.add_argument("--dropout", type=float, default=0.1)
180 | parser.add_argument("--max_len", type=int, default=128)
181 |
182 | # Training parameters
183 | parser.add_argument("--batch_size", type=int, default=32)
184 | parser.add_argument("--max_train_steps", type=int, default=10_000)
185 | parser.add_argument("--lr", type=float, default=1e-4)
186 | parser.add_argument("--weight_decay", type=float, default=0.01)
187 | parser.add_argument("--warmup_steps", type=int, default=100)
188 | parser.add_argument("--save", type=str, default="checkpoints/gpt")
189 |
190 | # Testing parameters
191 | parser.add_argument("--prompts", type=str, default="data/nlp/gpt/prompts.txt")
192 |
193 | # Seed
194 | parser.add_argument("--seed", type=int, default=0)
195 |
196 | args = vars(parser.parse_args())
197 | print(args)
198 | main(args)
199 |
--------------------------------------------------------------------------------
/src/nlp/gpt/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning.utilities.types import STEP_OUTPUT
7 | from torch.optim import AdamW
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 | from src.nlp.layers.decoder import DecoderTransformer
11 | from src.nlp.layers.embeddings import get_learnable_embedding
12 |
13 |
14 | class GPT(pl.LightningModule):
15 | DEFAULT_GPT_CONFIG = {
16 | "lr": 1e-4,
17 | "betas": (0.9, 0.999),
18 | "weight_decay": 0.01,
19 | "max_train_steps": 10_000,
20 | "warmup_steps": 100,
21 | }
22 |
23 | def __init__(
24 | self,
25 | vocab_size,
26 | max_len,
27 | hidden_dim,
28 | n_heads,
29 | depth,
30 | dropout=0.1,
31 | train_config=None,
32 | ):
33 | super(GPT, self).__init__()
34 |
35 | # Saving hyper-parameters so that they are logged
36 | self.save_hyperparameters()
37 |
38 | # Local parameters
39 | self.vocab_size = vocab_size
40 | self.max_len = max_len
41 | self.hidden_dim = hidden_dim
42 | self.n_heads = n_heads
43 | self.depth = depth
44 | self.dropout = dropout
45 |         self.train_config = dict(GPT.DEFAULT_GPT_CONFIG)  # Copy so the class-level default is not mutated
46 |
47 | # Schedulers
48 | self.linear_scheduler = None
49 | self.warmup_scheduler = None
50 |
51 | # Training config
52 | if train_config is not None:
53 | self.train_config.update(train_config)
54 |
55 | # Embeddings
56 | self.embeddings = get_learnable_embedding(
57 | vocab_size, hidden_dim
58 | ) # nn.Embedding(vocab_size, hidden_dim)
59 | self.pos_embeddings = get_learnable_embedding(
60 | max_len, hidden_dim
61 | ) # nn.Embedding(max_len, hidden_dim)
62 |
63 | # Transformer and output layer
64 | self.transformer = DecoderTransformer(
65 | hidden_dim, n_heads, depth, dropout_p=dropout
66 | )
67 |
68 | # Next token classifier
69 | self.classifier = nn.Sequential(
70 | nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, vocab_size)
71 | )
72 |
73 | def forward(self, ids, attn_mask=None):
74 | # Embedding
75 | b, t = ids.shape
76 | hidden = self.embeddings(ids)
77 | hidden += self.pos_embeddings(torch.arange(t, device=ids.device)).repeat(
78 | b, 1, 1
79 | )
80 |
81 | # Transformer
82 | hidden = self.transformer(hidden, self_attn_mask=attn_mask)
83 |
84 | # Classification
85 | return self.classifier(hidden), hidden
86 |
87 | def get_losses(self, batch):
88 | # Unpacking batch
89 | ids = batch["input_ids"]
90 | attn_mask = batch["attention_mask"]
91 |
92 | # Running forward
93 | b, t = ids.shape
94 | out, _ = self(ids, attn_mask.repeat(1, t).reshape(b, t, t).tril())
95 |
96 | # Computing cross-entropy loss
97 | preds, labels = out[:, :-1], ids[:, 1:]
98 | preds, labels = preds[attn_mask[:, :-1] == 1], labels[attn_mask[:, :-1] == 1]
99 | ce_loss = nn.functional.cross_entropy(
100 | preds.reshape(-1, self.vocab_size), labels.reshape(-1)
101 | )
102 |
103 | accuracy = (preds.argmax(dim=-1) == labels).float().mean()
104 | perplexity = torch.exp(ce_loss)
105 |
106 | return ce_loss, accuracy, perplexity
107 |
108 | def configure_optimizers(self):
109 | optim = AdamW(
110 | self.trainer.model.parameters(),
111 | lr=self.train_config["lr"],
112 | weight_decay=self.train_config["weight_decay"],
113 | betas=self.train_config["betas"],
114 | )
115 |
116 | self.linear_scheduler = LambdaLR(optim, self.scheduler_fn)
117 | return {"optimizer": optim}
118 |
119 | def scheduler_fn(self, step):
120 | warmup_steps = self.train_config["warmup_steps"]
121 | max_steps = self.train_config["max_train_steps"]
122 |
123 | if step < warmup_steps:
124 | return step / warmup_steps
125 | return 1 - (step - warmup_steps) / (max_steps - warmup_steps)
126 |
127 | def on_train_batch_end(
128 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
129 | ) -> None:
130 | if self.linear_scheduler is not None:
131 | self.linear_scheduler.step()
132 |
133 | return super().on_train_batch_end(outputs, batch, batch_idx)
134 |
135 | def training_step(self, batch, batch_idx):
136 | # Getting losses & accuracies
137 | ce_loss, accuracy, perplexity = self.get_losses(batch)
138 |
139 | # Logging
140 | self.log_dict(
141 | {
142 | "lr": self.optimizers().param_groups[0]["lr"],
143 | "train_loss": ce_loss,
144 | "train_acc": accuracy,
145 | "train_ppl": perplexity,
146 | },
147 | sync_dist=True,
148 | )
149 |
150 | return ce_loss
151 |
152 | def validation_step(self, batch, batch_idx):
153 | # Getting losses & accuracies
154 | ce_loss, accuracy, perplexity = self.get_losses(batch)
155 |
156 | # Logging
157 | self.log_dict(
158 | {
159 | "lr": self.optimizers().param_groups[0]["lr"],
160 | "val_loss": ce_loss,
161 | "val_acc": accuracy,
162 | "val_ppl": perplexity,
163 | },
164 | sync_dist=True,
165 | )
166 |
167 | return ce_loss
168 |
169 | def test_step(self, batch, batch_idx):
170 | # Getting losses & accuracies
171 | ce_loss, accuracy, perplexity = self.get_losses(batch)
172 |
173 | # Logging
174 | self.log_dict(
175 | {
176 | "lr": self.optimizers().param_groups[0]["lr"],
177 | "test_loss": ce_loss,
178 | "test_acc": accuracy,
179 | "test_ppl": perplexity,
180 | },
181 | sync_dist=True,
182 | )
183 |
184 | return ce_loss
185 |
186 | def generate(self, input_ids, max_len):
187 | # Predicting next token until max_len
188 | remaining = max_len - input_ids.shape[1]
189 | for _ in range(remaining):
190 | # Running GPT
191 | preds = self(input_ids)[0]
192 |
193 | # Getting probability of next token
194 | probs = preds[:, -1, :].softmax(dim=-1)
195 |
196 | # Sampling next token with multinomial sampling
197 | next_token = torch.multinomial(probs, num_samples=1)
198 |
199 | # Adding token to input_ids
200 | input_ids = torch.cat((input_ids, next_token), dim=-1)
201 | return input_ids
202 |
--------------------------------------------------------------------------------
/src/nlp/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/layers/__init__.py
--------------------------------------------------------------------------------
/src/nlp/layers/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Attention(nn.Module):
6 | """Single Attention Head."""
7 |
8 | def __init__(self, dropout_p=0.1):
9 | super(Attention, self).__init__()
10 | self.softmax = nn.Softmax(dim=-1)
11 | self.dropout = nn.Dropout(dropout_p)
12 |
13 | def forward(self, q, k, v, mask=None):
14 |         # Compute the attention scores as the scaled dot product of queries with keys
15 | b, t, d = q.shape
16 | attn = q @ k.transpose(-2, -1) / (d**0.5) # b, nq, nk
17 |
18 | # Mask interactions that should not be captured
19 | if mask is not None:
20 | assert (
21 | mask.shape == attn.shape
22 | ), f"Mask has shape {mask.shape} != {attn.shape}"
23 | attn = attn.masked_fill(mask == 0, float("-inf"))
24 |
25 | # Computing final output by multiplying attention scores with values
26 | attn = self.softmax(attn)
27 | out = attn @ v
28 |
29 | # Dropping out as regularization during training
30 | out = self.dropout(out)
31 | return out
32 |
33 |
34 | class MultiHeadAttention(nn.Module):
35 | """Multi-Head Attention."""
36 |
37 | def __init__(self, n_heads, dropout_p=0.1):
38 | super(MultiHeadAttention, self).__init__()
39 | self.n_heads = n_heads
40 | self.heads = nn.ModuleList([Attention(dropout_p) for _ in range(self.n_heads)])
41 |
42 | def forward(self, q, k, v, mask=None):
43 | # Check that dimensionalities are divisible by the number of heads
44 | b, nq, d = q.shape
45 | b, nk, dv = v.shape
46 | assert (
47 | d % self.n_heads == 0
48 | ), f"{d}-dimensional query cannot be broken into {self.n_heads} heads."
49 | assert (
50 | dv % self.n_heads == 0
51 | ), f"{dv}-dimensional value cannot be broken into {self.n_heads} heads."
52 |
53 | # Computing attention in all sub-vectors
54 | qk_dim_per_head = int(d / self.n_heads)
55 | v_dim_per_head = int(dv / self.n_heads)
56 | out = torch.cat(
57 | [
58 | head(
59 | q[:, :, i * qk_dim_per_head : (i + 1) * qk_dim_per_head],
60 | k[:, :, i * qk_dim_per_head : (i + 1) * qk_dim_per_head],
61 | v[:, :, i * v_dim_per_head : (i + 1) * v_dim_per_head],
62 | mask,
63 | )
64 | for i, head in enumerate(self.heads)
65 | ],
66 | dim=-1,
67 | )
68 |
69 | return out
70 |
--------------------------------------------------------------------------------
/src/nlp/layers/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from src.nlp.layers.attention import MultiHeadAttention
5 | from src.nlp.layers.mlp import MLP
6 |
7 |
8 | class DecoderBlock(nn.Module):
9 | def __init__(
10 | self,
11 | hidden_dim,
12 | n_heads,
13 | mlp_hidden=None,
14 | mlp_out=None,
15 | mlp_activation=nn.GELU(),
16 | dropout_p=0.1,
17 | with_xa=False,
18 | ):
19 | super(DecoderBlock, self).__init__()
20 | self.with_xa = with_xa
21 |
22 | self.ln1 = nn.LayerNorm(hidden_dim)
23 | self.sa_qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
24 | self.mhsa = MultiHeadAttention(n_heads, dropout_p)
25 | self.sa_o = nn.Linear(hidden_dim, hidden_dim)
26 |
27 | if with_xa:
28 | self.ln2 = nn.LayerNorm(hidden_dim)
29 | self.xa_q = nn.Linear(hidden_dim, hidden_dim)
30 | self.xa_kv = nn.Linear(hidden_dim, 2 * hidden_dim)
31 | self.mhxa = MultiHeadAttention(n_heads, dropout_p)
32 | self.xa_o = nn.Linear(hidden_dim, hidden_dim)
33 |
34 | self.ln3 = nn.LayerNorm(hidden_dim)
35 | self.mlp = MLP(hidden_dim, mlp_hidden, mlp_out, mlp_activation, dropout_p)
36 |
37 | def forward(self, x, kv=None, self_attn_mask=None, cross_attn_mask=None):
38 | # Self-attention and residual part
39 | q, k, v = self.sa_qkv(self.ln1(x)).chunk(3, -1)
40 | sa_out = self.sa_o(self.mhsa(q, k, v, self_attn_mask))
41 | x = x + sa_out
42 |
43 | # Cross-attention and residual part
44 | if self.with_xa and kv is not None:
45 | q = self.xa_q(self.ln2(x))
46 | k, v = self.xa_kv(kv).chunk(2, -1)
47 | xa_out = self.xa_o(self.mhxa(q, k, v, cross_attn_mask))
48 | x = x + xa_out
49 |
50 | # MLP and residual part
51 | out = x + self.mlp(self.ln3(x))
52 | return out
53 |
54 |
55 | class DecoderTransformer(nn.Module):
56 | def __init__(
57 | self,
58 | hidden_dim,
59 | n_heads,
60 | depth,
61 | mlp_hidden=None,
62 | mlp_out=None,
63 | mlp_activation=nn.GELU(),
64 | dropout_p=0.1,
65 | with_xa=False,
66 | ):
67 | super(DecoderTransformer, self).__init__()
68 |
69 | self.blocks = nn.ModuleList(
70 | [
71 | DecoderBlock(
72 | hidden_dim,
73 | n_heads,
74 | mlp_hidden,
75 | mlp_out,
76 | mlp_activation,
77 | dropout_p,
78 | with_xa,
79 | )
80 | for _ in range(depth)
81 | ]
82 | )
83 |
84 | def forward(self, hidden, kv=None, self_attn_mask=None, cross_attn_mask=None):
85 | # Creating causal mask if not provided
86 | if self_attn_mask is None:
87 | b, l, d = hidden.shape
88 | self_attn_mask = torch.tril(torch.ones(l, l, device=hidden.device)).repeat(
89 | b, 1, 1
90 | )
91 |
92 | # Running blocks
93 | for block in self.blocks:
94 | hidden = block(hidden, kv, self_attn_mask, cross_attn_mask)
95 |
96 | return hidden
97 |
--------------------------------------------------------------------------------
/src/nlp/layers/embeddings.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def get_learnable_embedding(n, hidden_dim):
7 | return nn.Embedding(n, hidden_dim)
8 |
9 |
10 | def get_sinusoidal_embedding(n, hidden_dim):
11 | emb = nn.Embedding(n, hidden_dim)
12 |
13 | def arg(pos, coord):
14 | return pos / np.power(10000, coord / hidden_dim)
15 |
16 | weight = [
17 | [
18 | np.sin(arg(pos, coord)) if coord % 2 == 0 else np.cos(arg(pos, coord))
19 | for coord in range(hidden_dim)
20 | ]
21 | for pos in range(n)
22 | ]
23 |
24 | emb.weight.data.copy_(torch.tensor(weight))
25 | emb.requires_grad_(False)
26 | return emb
27 |
28 |
29 | def get_rope_embedding(n, hidden_dim):
30 | # TODO ...
31 | pass
32 |
33 |
34 | if __name__ == "__main__":
35 | import matplotlib.pyplot as plt
36 |
37 | sin_emb = get_sinusoidal_embedding(256, 768)
38 | sin_emb = sin_emb.cpu().weight.data.numpy()
39 | plt.imshow(sin_emb)
40 | plt.xlabel("Hidden Dimension")
41 | plt.ylabel("Position")
42 | plt.title("Sinusoidal Embedding")
43 | plt.show()
44 |
45 | plt.imshow(sin_emb @ sin_emb.T)
46 | plt.xlabel("Position")
47 | plt.ylabel("Position")
48 | plt.title("Dot product of Sinusoidal Embeddings")
49 | plt.show()
50 |
--------------------------------------------------------------------------------
/src/nlp/layers/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from src.nlp.layers.attention import MultiHeadAttention
5 | from src.nlp.layers.mlp import MLP
6 |
7 |
8 | class EncoderBlock(nn.Module):
9 | def __init__(
10 | self,
11 | hidden_dim,
12 | n_heads,
13 | mlp_hidden=None,
14 | mlp_out=None,
15 | mlp_activation=nn.GELU(),
16 | dropout_p=0.1,
17 | ):
18 | super(EncoderBlock, self).__init__()
19 |
20 | self.ln1 = nn.LayerNorm(hidden_dim)
21 | self.ln2 = nn.LayerNorm(hidden_dim)
22 |
23 | self.to_qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
24 | self.to_o = nn.Linear(hidden_dim, hidden_dim)
25 |
26 | self.mhsa = MultiHeadAttention(n_heads, dropout_p)
27 | self.mlp = MLP(hidden_dim, mlp_hidden, mlp_out, mlp_activation, dropout_p)
28 |
29 | def forward(self, x, attn_mask=None):
30 | # Attention and residual connection
31 | q, k, v = self.to_qkv(self.ln1(x)).chunk(3, -1)
32 | attn_out = self.to_o(self.mhsa(q, k, v, mask=attn_mask))
33 | x = x + attn_out
34 |
35 | # MLP and residual connection
36 | mlp_out = self.mlp(self.ln2(x))
37 | out = x + mlp_out
38 |
39 | return out
40 |
41 |
42 | class EncoderTransformer(nn.Module):
43 | def __init__(
44 | self,
45 | hidden_dim,
46 | n_heads,
47 | depth,
48 | mlp_hidden=None,
49 | mlp_out=None,
50 | mlp_activation=nn.GELU(),
51 | dropout_p=0.1,
52 | ):
53 | super(EncoderTransformer, self).__init__()
54 |
55 | self.blocks = nn.ModuleList(
56 | [
57 | EncoderBlock(
58 | hidden_dim, n_heads, mlp_hidden, mlp_out, mlp_activation, dropout_p
59 | )
60 | for _ in range(depth)
61 | ]
62 | )
63 |
64 | def forward(self, hidden, attn_mask=None):
65 | # Creating full attention mask if not provided
66 | if attn_mask is None:
67 | b, l, d = hidden.shape
68 | attn_mask = torch.ones((l, l), device=hidden.device).repeat(b, 1, 1)
69 |
70 | # Running blocks
71 | for block in self.blocks:
72 | hidden = block(hidden, attn_mask)
73 | return hidden
74 |
--------------------------------------------------------------------------------
/src/nlp/layers/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class MLP(nn.Module):
5 | def __init__(
6 | self, in_dim, hidden_dim=None, out_dim=None, activation=nn.GELU(), drop_p=0.1
7 | ) -> None:
8 | super(MLP, self).__init__()
9 |
10 | self.in_dim = in_dim
11 | self.hidden_dim = hidden_dim if hidden_dim is not None else in_dim * 4
12 | self.out_dim = out_dim if out_dim is not None else in_dim
13 |
14 | self.linear1 = nn.Linear(self.in_dim, self.hidden_dim)
15 | self.linear2 = nn.Linear(self.hidden_dim, self.out_dim)
16 | self.activation = activation
17 | self.dropout = nn.Dropout(drop_p)
18 |
19 | def forward(self, x):
20 | out = self.linear1(x)
21 | out = self.activation(out)
22 | out = self.linear2(out)
23 | out = self.dropout(out)
24 | return out
25 |
--------------------------------------------------------------------------------
/src/nlp/lm_is_compression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/lm_is_compression/__init__.py
--------------------------------------------------------------------------------
/src/nlp/lm_is_compression/lm_is_compression.py:
--------------------------------------------------------------------------------
1 | """Re-implementation of
2 |
3 | Language Modeling Is Compression
4 | (https://arxiv.org/abs/2309.10668)
5 |
6 | by Delétang, Ruoss et al.
7 | """
8 |
9 | from argparse import ArgumentParser
10 |
11 | import torch
12 | import torch.nn as nn
13 | from transformers import AutoModelForCausalLM, AutoTokenizer
14 |
15 |
16 | def set_reproducibility(seed=0):
17 | torch.manual_seed(seed)
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 |
21 |
22 | class ArithmeticEncoder:
23 | """The arithmetic encoder converts a sequence of tokens into a single
24 | number according to the probability distribution of the next token given
25 | by the language model."""
26 |
27 | def __init__(self, model: nn.Module, bos_token_id: int):
28 | self.bos_token_id = bos_token_id
29 |
30 | self.model = model.eval()
31 | self.softmax = nn.Softmax(dim=-1)
32 |
33 | def __call__(self, ids: torch.Tensor) -> bytes:
34 | return self.encode(ids)
35 |
36 | def highs_lows_to_lambdas(
37 | self, highs: torch.Tensor, lows: torch.Tensor
38 | ) -> torch.Tensor:
39 | # Now returning the midpoints
40 | # TODO: Return the numbers with the shortest binary representation
41 | return (highs + lows) / 2
42 |
43 | @torch.no_grad()
44 | def encode(self, ids: torch.Tensor) -> bytes:
45 | """Encode a sequence of tokens into a binary sequence of bits.
46 | The encoding is done by finding the scalar number in range [0, 1]
47 | using arithmetic encoding based on the probability distribution of the
48 | next token given by the language model.
49 |
50 | Args:
51 | ids (torch.Tensor): Two-dimensional tensor of token ids. Omit the BOS token.
52 |
53 | Returns:
54 | bytes: The encoded sequence of bits
55 | """
56 | # Appending the BOS token to the beginning of each sequence
57 | bos_tokens = torch.full(
58 | (ids.shape[0], 1), self.bos_token_id, dtype=torch.long, device=ids.device
59 | )
60 | ids = torch.cat([bos_tokens, ids], dim=1)
61 | N, T = ids.shape
62 |
63 | # Getting the probabilities of the next token
64 | logits = self.model(ids)["logits"]
65 | probs = self.softmax(logits)
66 |
67 | # Find the lambda number for each sequence
68 | lows, highs = torch.zeros(N, dtype=torch.double), torch.ones(
69 | N, dtype=torch.double
70 | )
71 | for i in range(T - 1):
72 | intervals = highs - lows
73 |
74 | # Getting cumulative probabilities
75 | # TODO: Parallelize this loop
76 | c_probs = torch.empty(N)
77 | n_probs = torch.empty(N)
78 | for j in range(N):
79 | c_probs[j] = probs[j, i, : ids[j, i + 1]].sum()
80 | n_probs[j] = probs[j, i, : ids[j, i + 1] + 1].sum()
81 |
82 | # Updating lows and highs
83 | highs = lows + intervals * n_probs
84 | lows = lows + intervals * c_probs
85 |
86 | # Return the lambda numbers
87 | return self.highs_lows_to_lambdas(highs, lows)
88 |
89 | @torch.no_grad()
90 | def decode(self, lambdas: torch.Tensor, atol=1e-30) -> torch.Tensor:
91 | """Undo the encoding and, given scalar lambdas, return the original input ids."""
92 | N, dev = lambdas.shape[0], lambdas.device
93 | ids = torch.full((N, 1), self.bos_token_id, dtype=torch.long, device=dev)
94 |
95 | # Recovering the ids
96 | lows, highs = torch.zeros(N, dtype=torch.double, device=dev), torch.ones(
97 | N, dtype=torch.double, device=dev
98 | )
99 | while not torch.allclose(
100 | self.highs_lows_to_lambdas(highs, lows), lambdas, atol=atol
101 | ):
102 | probs = self.softmax(self.model(ids)["logits"][:, -1])
103 |
104 | next_ids = torch.empty(N, dtype=torch.long, device=lambdas.device)
105 | for i in range(N):
106 | lamb = lambdas[i]
107 | low, high = lows[i], highs[i]
108 | for j in range(probs.shape[1]):
109 | l = low + (high - low) * probs[i, :j].sum()
110 | u = low + (high - low) * probs[i, : j + 1].sum()
111 |
112 | if l <= lamb < u:
113 | highs[i], lows[i] = u, l
114 | next_ids[i] = j
115 | break
116 |
117 | ids = torch.cat([ids, next_ids.unsqueeze(1)], dim=1)
118 |
119 | return ids
120 |
121 | def to(self, device):
122 | self.model.to(device)
123 | self.softmax.to(device)
124 | return self
125 |
126 |
127 | def main(args):
128 | # Getting program parameters
129 | model_ckpt = args["model"]
130 | seed = args["seed"]
131 |
132 | # Setting reproducibility
133 | set_reproducibility(seed)
134 |
135 | # Preparing sentences to encode
136 | sentences = ["The quick brown fox jumps over the lazy dog."]
137 |
138 | # Loading model and tokenizer
139 | model = AutoModelForCausalLM.from_pretrained(
140 | model_ckpt, torch_dtype=torch.float32, device_map="auto"
141 | ).eval()
142 | tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
143 |
144 | if tokenizer.pad_token is None:
145 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
146 | model.resize_token_embeddings(len(tokenizer))
147 |
148 | # Encoding sentences
149 | ids = tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"].cuda()
150 | encoder = ArithmeticEncoder(model, tokenizer.bos_token_id)
151 | encoded = encoder(ids)
152 | decoded = encoder.decode(encoded.cuda())
153 |
154 | # Printing results
155 | print("\n\nOriginal sentences:", sentences)
156 | print(
157 | "Decoded sentences:", tokenizer.batch_decode(decoded, skip_special_tokens=True)
158 | )
159 |
160 | print("\n\nOriginal ids:", ids)
161 | print("Decoded ids:", decoded[:, 1:])
162 |
163 | print("\n\nEncoded sentences (as scalars):", encoded.cpu().numpy())
164 |
165 |
166 | if __name__ == "__main__":
167 | parser = ArgumentParser()
168 | parser.add_argument("--model", type=str, default="EleutherAI/pythia-1.4b-v0")
169 | parser.add_argument("--seed", type=int, default=0)
170 | args = vars(parser.parse_args())
171 | print(args)
172 | main(args)
173 |
--------------------------------------------------------------------------------
/src/nlp/lm_watermarking/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from accelerate import Accelerator
3 | from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
4 |
5 | from src.nlp.lm_watermarking.watermarking import (
6 | detect_watermark,
7 | generate,
8 | get_perplexities,
9 | )
10 |
11 |
12 | class GPT2Wrapper(torch.nn.Module):
13 | """A wrapper around the GPT2 model to take ids as input and return logits as output."""
14 |
15 | def __init__(self):
16 | super(GPT2Wrapper, self).__init__()
17 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
18 | self.model = GPT2LMHeadModel.from_pretrained("gpt2")
19 |
20 | def forward(self, input_ids):
21 | outputs = self.model(input_ids)
22 | return outputs.logits
23 |
24 |
25 | def main():
26 | """Plots the perplexity of the GPT2 model and the z-static for sentences generated with and without watermarking."""
27 | # Setting seed
28 | set_seed(0)
29 |
30 | # Device
31 | device = Accelerator().device
32 |
33 | # Language Model (GPT2)
34 | model = GPT2Wrapper().to(device)
35 | vocab_size = model.tokenizer.vocab_size
36 |
37 | # Prior text
38 | prior = model.tokenizer("Some text to be continued", return_tensors="pt")[
39 | "input_ids"
40 | ].to(device)
41 |
42 | # A sentence generated without watermarking
43 | normal_ids = generate(model, prior, max_length=200, watermark=False)
44 | n_ppl = get_perplexities(model, normal_ids).item() # Perplexity
45 | n_z = detect_watermark(normal_ids, vocab_size).item() # Z-statistic
46 |
47 | # A sentence generated with watermarking
48 | watermarked_ids = generate(model, prior, max_length=200, watermark=True)
49 | w_ppl = get_perplexities(model, watermarked_ids).item() # Perplexity
50 | w_z = detect_watermark(watermarked_ids, vocab_size).item() # Z-statistic
51 |
52 | # Showing non-watermarked text, PPL and probability of watermark
53 | print(
54 | f"\n\n\033[92mNormal text (PPL = {n_ppl:.2f}, Z-statistic = {n_z:.2f})\033[0m:\n"
55 | )
56 | print(model.tokenizer.decode(normal_ids[0]))
57 |
58 | # Showing watermarked text, PPL and probability of watermark
59 | print(f"\n\n\033[93mWM text (PPL = {w_ppl:.2f}, Z-statistic = {w_z:.2f})\033[0m:\n")
60 | print(model.tokenizer.decode(watermarked_ids[0]))
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/src/nlp/lm_watermarking/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from tqdm import tqdm
3 |
4 | plt.rcParams.update({"font.size": 22})
5 | from argparse import ArgumentParser
6 |
7 | import torch
8 | from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
9 |
10 | from src.nlp.lm_watermarking.watermarking import (
11 | detect_watermark,
12 | generate,
13 | get_perplexities,
14 | )
15 |
16 |
17 | def parse_args():
18 | parser = ArgumentParser()
19 |
20 | parser.add_argument(
21 | "--n_sentences", type=int, default=128, help="Number of sentences to generate"
22 | )
23 | parser.add_argument(
24 | "--seq_len", type=int, default=200, help="Length of the generated sentences"
25 | )
26 | parser.add_argument(
27 | "--batch_size", type=int, default=16, help="Batch size for the generation"
28 | )
29 | parser.add_argument(
30 | "--gamma", type=float, default=0.5, help="Green list proportion"
31 | )
32 | parser.add_argument(
33 | "--delta",
34 | type=float,
35 | default=2,
36 | help="Amount to add to the logits of the model when watermarking",
37 | )
38 | parser.add_argument(
39 | "--device", type=int, default=0, help="Device to use for generation"
40 | )
41 | parser.add_argument("--seed", type=int, default=0, help="Seed for the generation")
42 |
43 | return vars(parser.parse_args())
44 |
45 |
46 | class GPT2Wrapper(torch.nn.Module):
47 | """A wrapper around the GPT2 model to take ids as input and return logits as output."""
48 |
49 | def __init__(self):
50 | super(GPT2Wrapper, self).__init__()
51 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
52 | self.model = GPT2LMHeadModel.from_pretrained("gpt2")
53 |
54 | def forward(self, input_ids):
55 | outputs = self.model(input_ids)
56 | return outputs.logits
57 |
58 |
59 | def main():
60 | # Plotting parameters
61 | args = parse_args()
62 | n_sentences = args["n_sentences"]
63 | seq_len = args["seq_len"]
64 | batch_size = args["batch_size"]
65 | gamma = args["gamma"]
66 | delta = args["delta"]
67 | seed = args["seed"]
68 |
69 | # Setting seed
70 | set_seed(seed)
71 |
72 | # Device
73 | device = torch.device(
74 | "cuda:" + str(args["device"]) if torch.cuda.is_available() else "cpu"
75 | )
76 |
77 | # Model
78 | model = GPT2Wrapper().to(device)
79 | vocab_size = model.tokenizer.vocab_size
80 |
81 | # Prior text (BOS token)
82 | prior = (
83 | (model.tokenizer.bos_token_id * torch.ones((n_sentences, 1))).long().to(device)
84 | )
85 |
86 | # Collecting generations with and without watermark
87 | regular_ppls, regular_z_scores = [], []
88 | watermarked_ppls, watermarked_z_scores = [], []
89 | for i in tqdm(range(0, n_sentences, batch_size), desc="Generating sentences"):
90 | batch = prior[i : i + batch_size]
91 |
92 | # Regular sentences
93 | regular = generate(model, batch, max_length=seq_len, watermark=False)
94 | regular_ppls.extend(get_perplexities(model, regular).tolist())
95 | regular_z_scores.extend(detect_watermark(regular, vocab_size).tolist())
96 |
97 | # Watermarked sentences
98 | watermarked = generate(
99 | model, batch, max_length=seq_len, watermark=True, gamma=gamma, delta=delta
100 | )
101 | watermarked_ppls.extend(get_perplexities(model, watermarked).tolist())
102 | watermarked_z_scores.extend(detect_watermark(watermarked, vocab_size).tolist())
103 |
104 | # Scatter plot of perplexity vs z-score
105 | plt.figure(figsize=(10, 10))
106 | plt.scatter(regular_ppls, regular_z_scores, label="Regular")
107 | plt.scatter(watermarked_ppls, watermarked_z_scores, label="Watermarked")
108 | plt.legend()
109 | plt.title("Perplexity vs Z-score")
110 | plt.xlabel("Perplexity")
111 | plt.ylabel("Z-score")
112 | plt.savefig(
113 | f"perplexity_vs_zscore_(n={n_sentences}, seq_len={seq_len}, gamma={gamma}, delta={delta}, seed={seed}).png"
114 | )
115 | plt.show()
116 | print("Program completed successfully!")
117 |
118 |
119 | if __name__ == "__main__":
120 | main()
121 |
--------------------------------------------------------------------------------
/src/nlp/lm_watermarking/watermarking.py:
--------------------------------------------------------------------------------
1 | """Implementation of watermarking for language models as proposed in the paper
2 |
3 | "A Watermark for Large Language Models"
4 | by Kirchenbauer, Geiping et. al. (https://arxiv.org/abs/2301.10226v2).
5 | """
6 |
7 | from hashlib import sha256
8 |
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def default_hash_fn(tensor):
14 | """Returns the hash of the given tensor using the sha256 algorithm and by converting the tensor to a string first.
15 |
16 | Args:
17 | tensor: The tensor to hash.
18 |
19 | Returns:
20 | The hash of the tensor.
21 | """
22 | return int(sha256(str(tensor).encode("utf-8")).hexdigest(), 16) % (10**8)
23 |
24 |
25 | @torch.no_grad()
26 | def generate(
27 | model,
28 | prior_tokens,
29 | max_length=200,
30 | watermark=True,
31 | gamma=0.5,
32 | delta=2,
33 | hash_function=default_hash_fn,
34 | ):
35 | """Generates text with the given model. Optionally, the text can be (soft) watermarked.
36 |
37 | Args:
38 | model: The model which outputs logits for the next token.
39 | prior_tokens: The input tensor from which the model starts generating of shape (B, T).
40 |         max_length: The maximum length of the generated text. Default is 200.
41 |         gamma: The proportion of the green list. Default is 0.5.
42 |         delta: The hardness parameter added to the green-list logits. Default is 2.
43 | hash_function: The function to use for hashing. Default is default_hash_fn.
44 |
45 | Returns:
46 | The generated text of shape (B, T).
47 | """
48 | B, T = prior_tokens.shape
49 | device = prior_tokens.device
50 |
51 | generated_tokens = prior_tokens
52 | for _ in range(max_length - T):
53 | # Getting logits
54 | l_t = model(generated_tokens)[:, -1, :]
55 |
56 | if watermark:
57 | # Seeding generators based on previous token
58 | seeds = [hash_function(generated_tokens[i, -1]) for i in range(B)]
59 | generators = [
60 | torch.Generator(device=device).manual_seed(seed) for seed in seeds
61 | ]
62 |
63 | # Increasing probability of green list indices
64 | vs = l_t.shape[-1] # Vocabulary size
65 | gls = int(gamma * vs) # Green list size
66 | gli = torch.stack(
67 | [
68 | torch.randperm(vs, generator=generators[i], device=device)
69 | for i in range(B)
70 | ]
71 | ) # Green list indices
72 |
73 | l_t = l_t + delta * (gli < gls)
74 |
75 | # Sampling from the distribution
76 | l_t = torch.softmax(l_t, dim=-1)
77 | next_tokens = torch.multinomial(l_t, 1)
78 | generated_tokens = torch.cat([generated_tokens, next_tokens], dim=-1)
79 |
80 | return generated_tokens
81 |
82 |
83 | def detect_watermark(ids, vocab_size, gamma=0.5, hash_function=default_hash_fn):
84 | """Returns the probability that a text was created by a Language Model with watermarking.
85 |
86 | Args:
87 | ids: The tensor with the generated text indices of shape (B, T).
88 |         vocab_size: The size of the model's vocabulary.
89 |         gamma: The proportion of the green list. Default is 0.5.
90 | hash_function: The function used for watermarking. Default is default_hash_fn.
91 |
92 | Returns:
93 | The z-statistic of the watermarking probability.
94 | """
95 | B, T = ids.shape
96 | device = ids.device
97 | gls = int(gamma * vocab_size) # Green list size
98 | in_green_list = torch.zeros(B, dtype=torch.float32).to(
99 | device
100 | ) # Number of tokens in the green list
101 |
102 | for i in range(T - 1):
103 | # Seeding generators based on previous token
104 | seeds = [hash_function(ids[j, i]) for j in range(B)]
105 | generators = [
106 | torch.Generator(device=device).manual_seed(seed) for seed in seeds
107 | ]
108 |
109 | # Increasing probability of green list indices
110 | gli = torch.stack(
111 | [
112 | torch.randperm(vocab_size, generator=generators[i], device=device)
113 | for i in range(B)
114 | ]
115 | ) # Green list indices
116 |
117 | # Counting tokens that are in the green list and adding to the total
118 | in_green_list += (gli.gather(1, ids[:, i + 1].unsqueeze(-1)) < gls).squeeze()
119 |
120 |     z = (in_green_list - gamma * (T - 1)) / np.sqrt((T - 1) * gamma * (1 - gamma))  # T - 1 tokens are scored (the first has no predecessor)
121 | return z
122 |
123 |
124 | @torch.no_grad()
125 | def get_perplexities(model, ids):
126 | """Returns the perplexities of the model for the given texts.
127 |
128 | Args:
129 | model: The model which outputs logits for the next token.
130 | ids: The tensor with the generated text indices of shape (B, T)
131 |
132 | Returns:
133 | The perplexities of the model for the given texts as a tensor of shape (B,).
134 | """
135 | B, T = ids.shape
136 |
137 | perplexities = torch.zeros(B).to(ids.device)
138 | for i in range(T - 1):
139 | l_t = model(ids[:, : i + 1])[:, -1, :]
140 | l_t = torch.softmax(l_t, dim=-1)
141 | l_t = l_t[range(B), ids[:, i + 1]]
142 | l_t = torch.log(l_t)
143 | perplexities += l_t
144 |
145 | return torch.exp(-perplexities / (T - 1))
146 |
--------------------------------------------------------------------------------
/src/nlp/original/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/nlp/original/__init__.py
--------------------------------------------------------------------------------
/src/nlp/original/data.py:
--------------------------------------------------------------------------------
1 | from os import cpu_count
2 |
3 | import torch
4 | from datasets import load_dataset
5 | from pytorch_lightning import LightningDataModule
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 |
9 | class WMT14Subset(Dataset):
10 | def __init__(self, subset, tokenizer, max_len=128, languages="de-en"):
11 | super(WMT14Subset, self).__init__()
12 | self.tokenizer = tokenizer
13 | self.max_len = max_len
14 | self.lang1, self.lang2 = languages.split("-")
15 | self.subset = subset
16 |
17 | def __len__(self):
18 | return len(self.subset)
19 |
20 | def __getitem__(self, index):
21 | return self.preprocess(self.subset[index])
22 |
23 | def preprocess(self, sample):
24 | # Getting sentences
25 | s1 = sample["translation"][self.lang1]
26 | s2 = sample["translation"][self.lang2]
27 |
28 |         # Tokenizing sentences
29 | enc_tok = self.tokenizer(
30 | s1,
31 | return_tensors="pt",
32 | padding="max_length",
33 | truncation=True,
34 | max_length=self.max_len,
35 | )
36 | dec_tok = self.tokenizer(
37 | s2,
38 | return_tensors="pt",
39 | padding="max_length",
40 | truncation=True,
41 | max_length=self.max_len,
42 | )
43 |
44 | # Unpacking return values
45 | x_enc = enc_tok["input_ids"][0]
46 | x_dec = dec_tok["input_ids"][0]
47 | enc_attn_mask = enc_tok["attention_mask"][0]
48 | dec_attn_mask = dec_tok["attention_mask"][0]
49 |
50 | return {
51 | "x_enc": x_enc,
52 | "x_dec": x_dec,
53 | "enc_attn_mask": enc_attn_mask,
54 | "dec_attn_mask": dec_attn_mask,
55 | "enc_dec_attn_mask": enc_attn_mask,
56 | }
57 |
58 |
59 | class WMT14Dataset(LightningDataModule):
60 | def __init__(self, tokenizer, batch_size=32, max_len=128, languages="de-en"):
61 | super(WMT14Dataset, self).__init__()
62 | self.tokenizer = tokenizer
63 | self.batch_size = batch_size
64 | self.max_len = max_len
65 | self.languages = languages
66 |
67 |     def prepare_data(self):
68 |         # Download/cache only: state assigned here is not shared across DDP processes
69 |         load_dataset("wmt14", self.languages)
70 | 
71 |     def setup(self, stage):
72 |         self.wmt14 = load_dataset("wmt14", self.languages)
73 |         self.train = WMT14Subset(self.wmt14["train"], self.tokenizer, self.max_len, self.languages)
74 | self.val = WMT14Subset(
75 | self.wmt14["validation"], self.tokenizer, self.max_len, self.languages
76 | )
77 | self.test = WMT14Subset(
78 | self.wmt14["test"], self.tokenizer, self.max_len, self.languages
79 | )
80 |
81 | def train_dataloader(self):
82 | return DataLoader(
83 | self.train,
84 | batch_size=self.batch_size,
85 | shuffle=True,
86 | num_workers=cpu_count(),
87 | )
88 |
89 | def val_dataloader(self):
90 | return DataLoader(
91 | self.val, batch_size=self.batch_size, shuffle=False, num_workers=cpu_count()
92 | )
93 |
94 | def test_dataloader(self):
95 | return DataLoader(
96 | self.test,
97 | batch_size=self.batch_size,
98 | shuffle=False,
99 | num_workers=cpu_count(),
100 | )
101 |
102 | def teardown(self, stage):
103 | pass
104 |
--------------------------------------------------------------------------------
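A minimal sketch (not part of the repository) of how the data module above is wired to a tokenizer; it mirrors what main.py below does, with illustrative batch size and sequence length:

from transformers import FSMTTokenizer

from src.nlp.original.data import WMT14Dataset

tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
dm = WMT14Dataset(tokenizer, batch_size=8, max_len=64, languages="de-en")
dm.prepare_data()  # downloads/caches WMT14 (large)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print(batch["x_enc"].shape, batch["x_dec"].shape)  # torch.Size([8, 64]) for both

--------------------------------------------------------------------------------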
/src/nlp/original/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Re-implementation of
3 |
4 | Attention is all you need
5 | Vaswani et al. (2017)
6 | (https://arxiv.org/abs/1706.03762)
7 |
8 | on the WMT14 dataset.
9 | """
10 | import os
11 | from argparse import ArgumentParser
12 |
13 | import pytorch_lightning as pl
14 | import torch
15 | from pytorch_lightning.callbacks import ModelCheckpoint
16 | from pytorch_lightning.loggers import WandbLogger
17 | from transformers import FSMTTokenizer
18 |
19 | from src.nlp.original.data import WMT14Dataset
20 | from src.nlp.original.model import EncoderDecoderModel
21 |
22 |
23 | def translate_sentences(file_path, model, tokenizer, max_len=128):
24 |     with open(file_path, "r") as file:
25 |         # One source sentence per line
26 |         sentences = file.readlines()
27 |
28 | model = model.cuda()
29 | dec_input = tokenizer.bos_token_id * torch.ones(1, 1)
30 | dec_input = dec_input.long().cuda()
31 | for i, sentence in enumerate(sentences):
32 | x = tokenizer(
33 | sentence,
34 | return_tensors="pt",
35 | padding="max_length",
36 | truncation=True,
37 | max_length=max_len,
38 | )["input_ids"].cuda()
39 | y = model.generate(x, dec_input, max_len)[0]
40 | translated = tokenizer.decode(y)
41 |
42 | print(f"\n\nSENTENCE {i+1}:")
43 | print(f"\tOriginal: {sentence}\n\tTranslated: {translated}")
44 |
45 |
46 | def main(args):
47 | # Unpacking arguments
48 | enc_n_blocks = args["enc_n_blocks"]
49 | enc_n_heads = args["enc_n_heads"]
50 | dec_n_blocks = args["dec_n_blocks"]
51 | dec_n_heads = args["dec_n_heads"]
52 | hidden_dim = args["hidden_dim"]
53 | dropout = args["dropout"]
54 | max_len = args["max_len"]
55 | batch_size = args["batch_size"]
56 | max_train_steps = args["max_train_steps"]
57 | lr = args["lr"]
58 | weight_decay = args["weight_decay"]
59 | warmup_steps = args["warmup_steps"]
60 | save_dir = args["save"]
61 | file_path = args["file"]
62 | seed = args["seed"]
63 |
64 | # Setting seed
65 | pl.seed_everything(seed)
66 |
67 | # Loading dataset
68 | tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
69 | vocab_size = tokenizer.vocab_size
70 | if tokenizer.pad_token is None:
71 | tokenizer.add_special_tokens({"pad_token": ""})
72 | vocab_size += 1
73 |
74 | dataset = WMT14Dataset(tokenizer, batch_size, max_len)
75 |
76 | # Loading model
77 | model = EncoderDecoderModel(
78 | vocab_size,
79 | max_len,
80 | hidden_dim,
81 | enc_n_heads,
82 | enc_n_blocks,
83 | dec_n_heads,
84 | dec_n_blocks,
85 | dropout,
86 | lr,
87 | weight_decay,
88 | warmup_steps,
89 | )
90 |
91 | # Training
92 | checkpointing_freq = max_train_steps // 10
93 | callbacks = [
94 | ModelCheckpoint(
95 | save_dir,
96 | monitor="val_loss",
97 | filename="best",
98 | every_n_train_steps=checkpointing_freq,
99 | )
100 | ]
101 | logger = WandbLogger(project="Papers Re-implementations", name="ORIGINAL")
102 | logger.experiment.config.update(args)
103 | trainer = pl.Trainer(
104 | devices="auto",
105 | strategy="ddp",
106 | max_steps=max_train_steps,
107 | callbacks=callbacks,
108 | logger=logger,
109 | val_check_interval=checkpointing_freq,
110 | )
111 | trainer.fit(model, datamodule=dataset)
112 |
113 | # Testing
114 | model = EncoderDecoderModel.load_from_checkpoint(
115 | os.path.join(save_dir, "best.ckpt")
116 | )
117 | trainer.test(model, datamodule=dataset)
118 |
119 | # Translating custom sentences
120 | translate_sentences(file_path, model, tokenizer, max_len)
121 |
122 |
123 | if __name__ == "__main__":
124 | parser = ArgumentParser()
125 |
126 | # Model hyper-parameters
127 | parser.add_argument("--enc_n_blocks", type=int, default=12)
128 | parser.add_argument("--enc_n_heads", type=int, default=12)
129 | parser.add_argument("--dec_n_blocks", type=int, default=12)
130 | parser.add_argument("--dec_n_heads", type=int, default=12)
131 | parser.add_argument("--hidden_dim", type=int, default=768)
132 | parser.add_argument("--dropout", type=float, default=0.1)
133 | parser.add_argument("--max_len", type=int, default=128)
134 |
135 | # Training parameters
136 | parser.add_argument("--batch_size", type=int, default=32)
137 | parser.add_argument("--max_train_steps", type=int, default=100_000)
138 | parser.add_argument("--lr", type=float, default=1e-4)
139 | parser.add_argument("--weight_decay", type=float, default=0.01)
140 | parser.add_argument("--warmup_steps", type=int, default=4000)
141 | parser.add_argument("--save", type=str, default="checkpoints/original")
142 |
143 | # Testing parameters
144 | parser.add_argument(
145 | "--file", type=str, default="data/nlp/original/translate_sentences.txt"
146 | )
147 |
148 | # Seed
149 | parser.add_argument("--seed", type=int, default=0)
150 |
151 | args = vars(parser.parse_args())
152 | print(args)
153 | main(args)
154 |
--------------------------------------------------------------------------------
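For reference, a sketch (checkpoint path, example sentence and the use of the tokenizer's BOS token are assumptions) of the inference step that the end of main performs: load the best checkpoint, tokenize a sentence, and decode autoregressively with `generate`:

import torch
from transformers import FSMTTokenizer

from src.nlp.original.model import EncoderDecoderModel

tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
model = EncoderDecoderModel.load_from_checkpoint("checkpoints/original/best.ckpt")
model = model.cuda().eval()

x = tokenizer(
    "Guten Morgen!",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,
)["input_ids"].cuda()
dec_input = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long).cuda()
print(tokenizer.decode(model.generate(x, dec_input, 128)[0]))

--------------------------------------------------------------------------------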
/src/nlp/original/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning.utilities.types import STEP_OUTPUT
7 | from torch.optim import Adam
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 | from src.nlp.layers.decoder import DecoderTransformer
11 | from src.nlp.layers.embeddings import get_learnable_embedding
12 | from src.nlp.layers.encoder import EncoderTransformer
13 |
14 |
15 | class EncoderDecoderModel(pl.LightningModule):
16 | def __init__(
17 | self,
18 | vocab_size,
19 | max_len,
20 | hidden_dim,
21 | enc_n_heads,
22 | enc_n_blocks,
23 | dec_n_heads,
24 | dec_n_blocks,
25 | dropout,
26 | lr,
27 | weight_decay,
28 | warmup_steps,
29 | ):
30 | super(EncoderDecoderModel, self).__init__()
31 |
32 | # Saving hyper-parameters so that they are logged
33 | self.save_hyperparameters()
34 |
35 | # Local parameters
36 | self.vocab_size = vocab_size
37 | self.max_len = max_len
38 | self.hidden_dim = hidden_dim
39 | self.enc_n_heads = enc_n_heads
40 | self.enc_n_blocks = enc_n_blocks
41 | self.dec_n_heads = dec_n_heads
42 | self.dec_n_blocks = dec_n_blocks
43 | self.dropout = dropout
44 | self.lr = lr
45 | self.weight_decay = weight_decay
46 | self.warmup_steps = warmup_steps
47 | self.scheduler = None
48 |
49 | # Embeddings (note: we're learning embeddings for both languages in MT)
50 | self.embedding = get_learnable_embedding(vocab_size, hidden_dim)
51 | self.enc_pos_embedding = get_learnable_embedding(max_len, hidden_dim)
52 | self.dec_pos_embedding = get_learnable_embedding(max_len, hidden_dim)
53 |
54 | # Encoder and decoder models
55 | self.enc_transformer = EncoderTransformer(
56 | hidden_dim, enc_n_heads, enc_n_blocks, dropout_p=dropout
57 | )
58 | self.dec_transformer = DecoderTransformer(
59 |             hidden_dim, dec_n_heads, dec_n_blocks, dropout_p=dropout, with_xa=True
60 | )
61 |
62 | # Decoding head
63 | self.head = nn.Sequential(
64 | nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, vocab_size)
65 | )
66 |
67 | def scheduler_fn(self, step):
68 | step += 1
69 | return self.hidden_dim ** (-0.5) * min(
70 | step ** (-0.5), step * self.warmup_steps ** (-1.5)
71 | )
72 |
73 | def configure_optimizers(self):
74 | optim = Adam(
75 | self.trainer.model.parameters(),
76 | lr=self.lr,
77 | weight_decay=self.weight_decay,
78 | betas=(0.9, 0.98),
79 | eps=1e-9,
80 | )
81 | self.scheduler = LambdaLR(optim, self.scheduler_fn)
82 | return optim
83 |
84 | def forward(
85 | self, ids_enc, ids_dec, enc_attn_mask, dec_attn_mask, enc_dec_attn_mask
86 | ):
87 | assert ids_enc.shape[0] == ids_dec.shape[0]
88 |
89 | enc_out = self.forward_enc(ids_enc, enc_attn_mask)
90 | dec_out = self.forward_dec(ids_dec, dec_attn_mask, enc_dec_attn_mask, enc_out)
91 |
92 | return self.head(dec_out), enc_out, dec_out
93 |
94 | def forward_enc(self, ids_enc, enc_attn_mask):
95 | b, t = ids_enc.shape
96 | x_enc = self.embedding(ids_enc) + self.enc_pos_embedding(
97 |             torch.arange(t, device=ids_enc.device)
98 | ).repeat(b, 1, 1)
99 | enc_out = self.enc_transformer(x_enc, enc_attn_mask)
100 | return enc_out
101 |
102 | def forward_dec(self, ids_dec, dec_attn_mask, enc_dec_attn_mask, enc_out):
103 | b, t = ids_dec.shape
104 | x_dec = self.embedding(ids_dec) + self.dec_pos_embedding(
105 |             torch.arange(t, device=ids_dec.device)
106 | ).repeat(b, 1, 1)
107 | dec_out = self.dec_transformer(
108 | x_dec,
109 | kv=enc_out,
110 | self_attn_mask=dec_attn_mask,
111 | cross_attn_mask=enc_dec_attn_mask,
112 | )
113 | return dec_out
114 |
115 | def compute_loss(self, batch):
116 | ids_enc = batch["x_enc"]
117 | ids_dec = batch["x_dec"]
118 | enc_attn_mask = batch["enc_attn_mask"]
119 | dec_attn_mask = batch["dec_attn_mask"]
120 | enc_dec_attn_mask = batch["enc_dec_attn_mask"]
121 |
122 | b, te = ids_enc.shape
123 | td = ids_dec.shape[1]
124 |
125 | y_pred, _, _ = self(
126 | ids_enc,
127 | ids_dec,
128 | enc_attn_mask.repeat(1, te).reshape(b, te, te),
129 | dec_attn_mask.repeat(1, td).reshape(b, td, td).tril(),
130 | enc_dec_attn_mask.repeat(1, td).reshape(b, td, td),
131 | )
132 |
133 | y_pred = y_pred[dec_attn_mask == 1][:-1]
134 | y = ids_dec[dec_attn_mask == 1][1:]
135 |
136 | loss = nn.functional.cross_entropy(
137 | y_pred.reshape(-1, self.vocab_size), y.reshape(-1)
138 | )
139 | return loss
140 |
141 | def training_step(self, batch, batch_idx):
142 | loss = self.compute_loss(batch)
143 | self.log("train_loss", loss)
144 | return loss
145 |
146 | def on_train_batch_end(
147 | self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
148 | ) -> None:
149 | if self.scheduler is not None:
150 | self.scheduler.step()
151 |
152 | return super().on_train_batch_end(outputs, batch, batch_idx)
153 |
154 | def validation_step(self, batch, batch_idx):
155 | loss = self.compute_loss(batch)
156 | self.log("val_loss", loss)
157 | return loss
158 |
159 | def test_step(self, batch, batch_idx):
160 | loss = self.compute_loss(batch)
161 | self.log("test_loss", loss)
162 | return loss
163 |
164 | def generate(self, x_enc, x_dec, max_len=None):
165 | if max_len is None:
166 | max_len = self.max_len
167 |
168 | attn_enc = attn_enc_dec = torch.ones(x_enc.shape[1]).cuda()
169 | enc_out = self.forward_enc(x_enc, attn_enc)
170 |
171 | while x_dec.shape[1] < max_len:
172 | t = x_dec.shape[-1]
173 | attn_dec = torch.tril(torch.ones(t, t)).cuda()
174 | dec_out = self.forward_dec(x_dec, attn_dec, attn_enc_dec, enc_out)
175 | probs = torch.softmax(self.head(dec_out)[:, -1, :], dim=-1)
176 | token = torch.multinomial(probs, num_samples=1)
177 | x_dec = torch.cat((x_dec, token), dim=-1)
178 | return x_dec
179 |
--------------------------------------------------------------------------------
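scheduler_fn above implements the warm-up schedule from the paper: the factor grows linearly for the first warmup_steps steps and then decays with the inverse square root of the step. Note that LambdaLR multiplies this factor by the base learning rate passed to Adam, whereas in the paper the expression is the learning rate itself; the shape of the schedule is the same either way. A quick way to inspect it (default hyper-parameters assumed):

# Evaluating the warm-up factor used by scheduler_fn at a few steps
hidden_dim, warmup_steps = 768, 4000


def factor(step):
    return hidden_dim ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))


for step in (1, 1000, 4000, 16000):
    print(step, factor(step))  # rises until warmup_steps, then decays as 1 / sqrt(step)

--------------------------------------------------------------------------------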
/src/nlp/tokenizers/bpe.py:
--------------------------------------------------------------------------------
1 | """Byte-Pair Encoding (BPE) tokenizer."""
2 |
3 | from tqdm.auto import tqdm
4 |
5 |
6 | def bpe_train(corpus, vocab_size):
7 | """Uses the BPE algorithm to train a tokenizer on a corpus. Returns the vocabulary."""
8 | # Initialize the vocabulary
 9 |     corpus = [
10 |         l if i == 0 else "##" + l for word in corpus.split(" ") for i, l in enumerate(word)
11 |     ]
12 | vocab = set(corpus)
13 |
14 | # Keep merging most likely pairs until the vocabulary is the desired size
15 | for i in tqdm(range(vocab_size - len(vocab)), desc="Training bpe tokenizer..."):
16 | counts = {}
17 |
18 | # Count how many times each pair appears
19 | for i in range(len(corpus) - 1):
20 | pair = corpus[i] + corpus[i + 1]
21 | counts[pair] = counts.get(pair, 0) + 1
22 |
23 | # Find the pair that appeared the most
24 | max_pair, max_count = None, -1
25 | for k in counts:
26 | if counts[k] > max_count:
27 | max_pair, max_count = k, counts[k]
28 |
29 | # Adding pair to vocabulary
30 | vocab.add(max_pair)
31 |
32 | # Updating corpus
33 | new_corpus, added = [], False
34 | for i in range(len(corpus) - 1):
35 | if added:
36 | added = False
37 | continue
38 |
39 | if corpus[i] + corpus[i + 1] == max_pair:
40 | new_corpus.append(max_pair)
41 | added = True
42 | else:
43 | new_corpus.append(corpus[i])
44 |         if not added:
45 |             new_corpus.append(corpus[-1])  # keep the trailing token if it was not merged
46 |         corpus = new_corpus
46 |
47 | # Remove the "##" prefix and return vocabulary
48 | vocab = set(
49 | [
50 | ("##" if elem.startswith("##") else "") + elem.replace("##", "")
51 | for elem in vocab
52 | ]
53 | )
54 | return vocab
55 |
56 |
57 | if __name__ == "__main__":
58 | # Example
59 | corpus = "machine learning and meta learning allow machines to learn how to learn"
60 | vocabulary = bpe_train(corpus, 30)
61 | print(vocabulary)
62 |
--------------------------------------------------------------------------------
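Since bpe_train returns only the final vocabulary and not the merge order, a sketch of applying that vocabulary to new text has to fall back on greedy longest-match segmentation (the WordPiece-style inference rule) rather than replaying merges; greedy_tokenize below is a hypothetical helper, not part of the repository:

from src.nlp.tokenizers.bpe import bpe_train


def greedy_tokenize(word, vocab):
    """Greedily matches the longest vocabulary entry, using '##' for word-internal pieces."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        while end > start:
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in vocab:
                tokens.append(piece)
                break
            end -= 1
        else:
            return ["[UNK]"]  # no vocabulary entry matches this span
        start = end
    return tokens


corpus = "machine learning and meta learning allow machines to learn how to learn"
vocabulary = bpe_train(corpus, 30)
print(greedy_tokenize("learning", vocabulary))

--------------------------------------------------------------------------------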
/src/nlp/tokenizers/wordpiece.py:
--------------------------------------------------------------------------------
1 | """Wordpiece tokenizer."""
2 |
3 | from tqdm.auto import tqdm
4 |
5 |
6 | def wordpiece_train(corpus, vocab_size):
7 | """Uses the wordpiece algorithm to train a tokenizer on a corpus. Returns the vocabulary."""
8 | # Initialize the vocabulary with letters
 9 |     corpus = [
10 |         l if i == 0 else "##" + l for word in corpus.split(" ") for i, l in enumerate(word)
11 |     ]
12 | vocab = set(corpus)
13 |
14 | # Keep merging most likely pairs until the vocabulary is the desired size
15 | for i in tqdm(
16 | range(vocab_size - len(vocab)), desc="Training wordpiece tokenizer..."
17 | ):
18 | counts = {}
19 | pair_counts = {}
20 |
21 | # Keep count of each word and each pair
22 | for i in range(len(corpus)):
23 | counts[corpus[i]] = counts.get(corpus[i], 0) + 1
24 |
25 | if i == len(corpus) - 1:
26 | continue
27 |
28 |             pair = (corpus[i], corpus[i + 1])
29 |             pair_counts[pair] = pair_counts.get(pair, 0) + 1
30 |
31 | # Find the pair that has the highest score
32 | # The score is count(w1, w2) / (count(w1) * count(w2))
33 | max_pair, max_score = None, -1
34 | for w1, w2 in pair_counts:
35 | p_count = pair_counts[(w1, w2)]
36 | pair_score = p_count / (counts[w1] * counts[w2])
37 | if pair_score > max_score:
38 | max_pair, max_score = w1 + w2, pair_score
39 |
40 | # Add the pair with the highest score to the vocabulary
41 | vocab.add(max_pair)
42 |
43 | # Update the corpus by merging the pair
44 | new_corpus, added = [], False
45 | for i in range(len(corpus) - 1):
46 | if added:
47 | added = False
48 | continue
49 |
50 | if corpus[i] + corpus[i + 1] == max_pair:
51 | new_corpus.append(max_pair)
52 | added = True
53 | else:
54 | new_corpus.append(corpus[i])
55 |         if not added:
56 |             new_corpus.append(corpus[-1])  # keep the trailing token if it was not merged
57 |         corpus = new_corpus
57 |
58 | # Remove the "##" prefix and return vocabulary
59 | vocab = set(
60 | [
61 | ("##" if elem.startswith("##") else "") + elem.replace("##", "")
62 | for elem in vocab
63 | ]
64 | )
65 | return vocab
66 |
67 |
68 | if __name__ == "__main__":
69 | # Example
70 | corpus = "machine learning and meta learning allow machines to learn how to learn"
71 | vocabulary = wordpiece_train(corpus, 30)
72 | print(vocabulary)
73 |
--------------------------------------------------------------------------------
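A small worked example (illustrative counts) of the scoring rule used above, score(a, b) = count(a, b) / (count(a) * count(b)), which prefers merging pairs whose parts rarely occur apart:

counts = {"le": 4, "##arn": 3, "##ing": 5}
pair_counts = {("le", "##arn"): 3, ("##arn", "##ing"): 2}
scores = {pair: c / (counts[pair[0]] * counts[pair[1]]) for pair, c in pair_counts.items()}
print(scores)  # ('le', '##arn'): 0.25 beats ('##arn', '##ing'): ~0.13, so ('le', '##arn') merges first

--------------------------------------------------------------------------------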
/src/rl/dqn/dqn.py:
--------------------------------------------------------------------------------
1 | """Reimplementation of 'Playing Atari with Deep Reinforcement Learning' by Mnih et al. (2013)"""
2 | import os
3 | import random
4 |
5 | import torch
6 | from accelerate import Accelerator
7 | from tqdm.auto import tqdm
8 |
9 | import wandb
10 |
11 |
12 | class ReplayBuffer:
13 | def __init__(self, capacity=10000):
14 | self.capacity = capacity
15 | self.buffer = []
16 |
17 | def push(self, state, action, reward, next_state, done):
18 | if len(self.buffer) >= self.capacity:
19 | self.buffer.pop(0)
20 | self.buffer.append((state, action, reward, next_state, done))
21 |
22 | def sample(self, batch_size):
23 | state, action, reward, next_state, done = zip(
24 | *random.sample(self.buffer, batch_size)
25 | )
26 | return (
27 | torch.stack(state),
28 | torch.tensor(action),
29 | torch.tensor(reward),
30 | torch.stack(next_state),
31 | torch.tensor(done).int(),
32 | )
33 |
34 |
35 | def dqn_training(
36 | model,
37 | environment,
38 | optimizer,
39 | gamma,
40 | epsilon,
41 | batch_size,
42 | episodes,
43 | checkpoint_path,
44 | buffer_capacity=10000,
45 | ):
46 | # Initialize run
47 | wandb.init(project="Papers Re-implementations", name="DQN")
48 | wandb.watch(model)
49 |
50 | # Create checkpoint directory
51 | os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
52 |
53 | # Initialize accelerator
54 | accelerator = Accelerator()
55 | model, optimizer = accelerator.prepare(model, optimizer)
56 |
57 | # Initialize replay buffer
58 | buffer = ReplayBuffer(buffer_capacity)
59 |
60 | def state_process_fn(x):
61 | return torch.tensor(x).float().to(accelerator.device)
62 |
63 | def optimization_step():
64 | # Sample replay buffer
65 | state, action, reward, next_state, done = buffer.sample(batch_size)
66 | action = action.to(accelerator.device)
67 | reward = reward.to(accelerator.device).float()
68 | done = done.to(accelerator.device)
69 |
70 | # Compute the target Q value
71 | with torch.no_grad():
72 | target_q = reward + gamma * model(next_state).max(1)[0] * (1 - done)
73 |
74 | # Get current Q estimate
75 | current_q = model(state).gather(1, action.unsqueeze(1))
76 |
77 | # Compute loss
78 | loss = torch.nn.functional.mse_loss(current_q, target_q.unsqueeze(1))
79 |
80 | # Optimize the model
81 | optimizer.zero_grad()
82 | accelerator.backward(loss)
83 | accelerator.clip_grad_norm_(model.parameters(), max_norm=10.0)
84 | optimizer.step()
85 |
86 | return loss.item()
87 |
88 | # Training loop
89 | checkpoint_loss = float("inf")
90 | pbar = tqdm(range(episodes))
91 | for ep in pbar:
92 | # Initialize episode
93 | pbar.set_description(f"Episode {ep+1}/{episodes}")
94 | state = state_process_fn(environment.reset()[0])
95 | episode_loss, episode_reward, episode_length = 0, 0, 0
96 |
97 | done, truncated = False, False
98 | while not done and not truncated:
99 | # Act in the environment
100 | if random.random() < epsilon:
101 | action = environment.action_space.sample()
102 | else:
103 | action = model(state).argmax().item()
104 |
105 | # Update environment
106 | next_state, reward, done, truncated, _ = environment.step(action)
107 | next_state = state_process_fn(next_state)
108 |
109 | # Register transition in replay buffer
110 | buffer.push(state, action, reward, next_state, done)
111 |
112 |             # Track reward and update the Q-function once the buffer is large enough
113 |             episode_reward += reward
114 |             if len(buffer.buffer) >= batch_size:
115 |                 loss = optimization_step()
116 |                 episode_loss += loss
117 |
118 | state = next_state
119 | episode_length += 1
120 |
121 | if len(buffer.buffer) >= batch_size:
122 | # Log episode stats
123 | wandb.log(
124 | {
125 | "loss": episode_loss,
126 | "reward": episode_reward,
127 | "ep. length": episode_length,
128 | }
129 | )
130 |
131 | if len(buffer.buffer) >= batch_size and episode_loss < checkpoint_loss:
132 | torch.save(model.state_dict(), checkpoint_path)
133 | checkpoint_loss = episode_loss
134 | print(
135 | f"Checkpoint saved with loss {checkpoint_loss:.3f} at episode {ep+1} / {episodes}"
136 | )
137 |
138 | wandb.finish()
139 |
--------------------------------------------------------------------------------
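optimization_step above regresses the network onto the one-step TD target r + gamma * max_a' Q(s', a'), zeroed when the transition ends the episode. A standalone sketch with toy numbers:

import torch

gamma = 0.99
reward = torch.tensor([1.0, 1.0])
done = torch.tensor([0, 1])  # the second transition terminates the episode
next_q = torch.tensor([[0.2, 0.7], [0.1, 0.4]])  # Q(s', .) for each transition
target = reward + gamma * next_q.max(1)[0] * (1 - done)
print(target)  # tensor([1.6930, 1.0000])

--------------------------------------------------------------------------------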
/src/rl/dqn/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import gymnasium as gym
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning import seed_everything
7 |
8 | from src.rl.dqn.dqn import dqn_training
9 |
10 | DISCRETE_ACTION_ENVIRONMENTS = [
11 | "Acrobot-v1",
12 | "CartPole-v1",
13 | "LunarLander-v2",
14 | "MountainCar-v0",
15 | ]
16 |
17 |
18 | class MLP(nn.Module):
19 | def __init__(self, input_dim, output_dim, hidden_dims=(64, 64)):
20 | super().__init__()
21 | self.relu = nn.ReLU()
22 | self.in_layer = nn.Linear(input_dim, hidden_dims[0])
23 | self.blocks = nn.ModuleList(
24 | [
25 | nn.Sequential(
26 | nn.LayerNorm(hidden_dims[i]),
27 | nn.Linear(hidden_dims[i], hidden_dims[i + 1]),
28 | nn.ReLU(),
29 | )
30 | for i in range(len(hidden_dims) - 1)
31 | ]
32 | )
33 | self.out = nn.Linear(hidden_dims[-1], output_dim)
34 |
35 | def forward(self, x):
36 | x = self.in_layer(x)
37 | for block in self.blocks:
38 | x = block(x)
39 | return self.out(x)
40 |
41 |
42 | def main(args):
43 | # Setting seed
44 | seed_everything(args["seed"])
45 |
46 | # Unpacking args
47 | gamma = args["gamma"]
48 | epsilon = args["epsilon"]
49 | batch_size = args["batch_size"]
50 | lr = args["lr"]
51 | episodes = args["train_episodes"]
52 | checkpoint_path = args["checkpoint_path"]
53 | buffer_capacity = args["buffer_capacity"]
54 | optimizer_fn = getattr(torch.optim, args["optimizer"])
55 |
56 | # Environment
57 | env = gym.make(args["env"])
58 |
59 | # Training
60 | n_inputs, n_outputs = env.observation_space.shape[0], env.action_space.n
61 | model = MLP(n_inputs, n_outputs)
62 | optimizer = optimizer_fn(model.parameters(), lr=lr)
63 | dqn_training(
64 | model,
65 | env,
66 | optimizer,
67 | gamma,
68 | epsilon,
69 | batch_size,
70 | episodes,
71 | checkpoint_path,
72 | buffer_capacity,
73 | )
74 |
75 | # Showing episodes
76 | env = gym.make(args["env"], render_mode="human")
77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78 | model = MLP(n_inputs, n_outputs)
79 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
80 | model = model.eval().to(device)
81 | for episode in range(args["test_episodes"]):
82 | state = env.reset()[0]
 83 |         done = truncated = False
 84 |         while not done and not truncated:
 85 |             action = model(torch.tensor(state).float().to(device)).argmax().item()
 86 |             state, _, done, truncated, _ = env.step(action)
 87 |             env.render()
88 |
89 |
90 | if __name__ == "__main__":
91 | parser = ArgumentParser()
92 | parser.add_argument(
93 | "--env",
94 | type=str,
95 | default="CartPole-v1",
96 | choices=DISCRETE_ACTION_ENVIRONMENTS,
97 | )
98 | parser.add_argument("--seed", type=int, default=0)
99 | parser.add_argument("--gamma", type=float, default=0.99)
100 | parser.add_argument("--epsilon", type=float, default=0.05)
101 | parser.add_argument("--batch_size", type=int, default=256)
102 | parser.add_argument("--lr", type=float, default=0.001)
103 | parser.add_argument("--train_episodes", type=int, default=100)
104 | parser.add_argument("--test_episodes", type=int, default=10)
105 | parser.add_argument("--checkpoint_path", type=str, default="checkpoints/dqn/dqn.pt")
106 | parser.add_argument("--buffer_capacity", type=int, default=1024)
107 | parser.add_argument(
108 | "--optimizer", type=str, default="Adam", choices=["Adam", "SGD"]
109 | )
110 | args = vars(parser.parse_args())
111 | print(args)
112 | main(args)
113 |
--------------------------------------------------------------------------------
/src/rl/ppo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrianPulfer/PapersReimplementations/ce70cab23f7157d911fa541b387797d9d64c65fc/src/rl/ppo/__init__.py
--------------------------------------------------------------------------------
/src/rl/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | """
2 | Personal reimplementation of
3 | Proximal Policy Optimization Algorithms
4 | (https://arxiv.org/abs/1707.06347)
5 | """
6 |
7 | from argparse import ArgumentParser
8 |
9 | import gym
10 | import numpy as np
11 | import pytorch_lightning as pl
12 | import torch
13 | import torch.nn as nn
14 | from torch.distributions.categorical import Categorical
15 | from torch.optim import Adam
16 | from torch.optim.lr_scheduler import LinearLR
17 |
18 | import wandb
19 |
20 | # Definitions
21 | MODEL_PATH = "model.pt"
22 |
23 |
24 | def parse_args():
25 |     """Parses program arguments"""
26 | # Parser
27 | parser = ArgumentParser()
28 |
29 | # Program arguments (default for Atari games)
30 | parser.add_argument(
31 | "--max_iterations",
32 | type=int,
33 | help="Number of iterations of training",
34 | default=100,
35 | )
36 | parser.add_argument(
37 | "--n_actors", type=int, help="Number of actors for each update", default=8
38 | )
39 | parser.add_argument(
40 | "--horizon", type=int, help="Number of timestamps for each actor", default=128
41 | )
42 | parser.add_argument("--epsilon", type=float, help="Epsilon parameter", default=0.1)
43 | parser.add_argument(
44 | "--n_epochs",
45 | type=int,
46 | help="Number of training epochs per iteration",
47 | default=3,
48 | )
49 | parser.add_argument("--batch_size", type=int, help="Batch size", default=32 * 8)
50 | parser.add_argument("--lr", type=float, help="Learning rate", default=2.5 * 1e-4)
51 | parser.add_argument(
52 | "--gamma", type=float, help="Discount factor gamma", default=0.99
53 | )
54 | parser.add_argument(
55 | "--c1",
56 | type=float,
57 | help="Weight for the value function in the loss function",
58 | default=1,
59 | )
60 | parser.add_argument(
61 | "--c2",
62 | type=float,
63 | help="Weight for the entropy bonus in the loss function",
64 | default=0.01,
65 | )
66 | parser.add_argument(
67 | "--n_test_episodes", type=int, help="Number of episodes to render", default=5
68 | )
69 | parser.add_argument(
70 | "--seed", type=int, help="Randomizing seed for the experiment", default=0
71 | )
72 |
73 | # Dictionary with program arguments
74 | return vars(parser.parse_args())
75 |
76 |
77 | def get_device():
78 | """Gets the device (GPU if any) and logs the type"""
79 | if torch.cuda.is_available():
80 | device = torch.device("cuda")
81 | print(f"Found GPU device: {torch.cuda.get_device_name(device)}")
82 | else:
83 | device = torch.device("cpu")
84 | print("No GPU found: Running on CPU")
85 | return device
86 |
87 |
88 | class MyPPO(nn.Module):
89 |     """Implementation of a PPO model. Actor and critic can optionally share the same backbone."""
90 |
91 | def __init__(self, in_shape, n_actions, hidden_d=100, share_backbone=False):
92 | # Super constructor
93 | super(MyPPO, self).__init__()
94 |
95 | # Attributes
96 | self.in_shape = in_shape
97 | self.n_actions = n_actions
98 | self.hidden_d = hidden_d
99 | self.share_backbone = share_backbone
100 |
101 | # Shared backbone for policy and value functions
102 | in_dim = np.prod(in_shape)
103 |
104 | def to_features():
105 | return nn.Sequential(
106 | nn.Flatten(),
107 | nn.Linear(in_dim, hidden_d),
108 | nn.ReLU(),
109 | nn.Linear(hidden_d, hidden_d),
110 | nn.ReLU(),
111 | )
112 |
113 | self.backbone = to_features() if self.share_backbone else nn.Identity()
114 |
115 | # State action function
116 | self.actor = nn.Sequential(
117 | nn.Identity() if self.share_backbone else to_features(),
118 | nn.Linear(hidden_d, hidden_d),
119 | nn.ReLU(),
120 | nn.Linear(hidden_d, n_actions),
121 | nn.Softmax(dim=-1),
122 | )
123 |
124 | # Value function
125 | self.critic = nn.Sequential(
126 | nn.Identity() if self.share_backbone else to_features(),
127 | nn.Linear(hidden_d, hidden_d),
128 | nn.ReLU(),
129 | nn.Linear(hidden_d, 1),
130 | )
131 |
132 | def forward(self, x):
133 | features = self.backbone(x)
134 | action = self.actor(features)
135 | value = self.critic(features)
136 | return Categorical(action).sample(), action, value
137 |
138 |
139 | @torch.no_grad()
140 | def run_timestamps(env, model, timestamps=128, render=False, device="cpu"):
141 |     """Runs the given policy on the given environment for the given number of timestamps.
142 | Returns a buffer with state action transitions and rewards."""
143 | buffer = []
144 | state = env.reset()[0]
145 |
146 | # Running timestamps and collecting state, actions, rewards and terminations
147 | for ts in range(timestamps):
148 | # Taking a step into the environment
149 | model_input = torch.from_numpy(state).unsqueeze(0).to(device).float()
150 | action, action_logits, value = model(model_input)
151 | new_state, reward, terminated, truncated, info = env.step(action.item())
152 |
153 | # Rendering / storing (s, a, r, t) in the buffer
154 | if render:
155 | env.render()
156 | else:
157 | buffer.append(
158 | [
159 | model_input,
160 | action,
161 | action_logits,
162 | value,
163 | reward,
164 | terminated or truncated,
165 | ]
166 | )
167 |
168 | # Updating current state
169 | state = new_state
170 |
171 | # Resetting environment if episode terminated or truncated
172 | if terminated or truncated:
173 | state = env.reset()[0]
174 |
175 | return buffer
176 |
177 |
178 | def compute_cumulative_rewards(buffer, gamma):
179 | """Given a buffer with states, policy action logits, rewards and terminations,
180 | computes the cumulative rewards for each timestamp and substitutes them into the buffer.
181 | """
182 | curr_rew = 0.0
183 |
184 |     # Traversing the buffer in reverse order
185 | for i in range(len(buffer) - 1, -1, -1):
186 | r, t = buffer[i][-2], buffer[i][-1]
187 |
188 |         if t:
189 |             curr_rew = r  # terminal step: the return is just the final reward
190 |         else:
191 |             curr_rew = r + gamma * curr_rew
192 |
193 | buffer[i][-2] = curr_rew
194 |
195 | # Getting the average reward before normalizing (for logging and checkpointing)
196 | avg_rew = np.mean([buffer[i][-2] for i in range(len(buffer))])
197 |
198 | # Normalizing cumulative rewards
199 | mean = np.mean([buffer[i][-2] for i in range(len(buffer))])
200 | std = np.std([buffer[i][-2] for i in range(len(buffer))]) + 1e-6
201 | for i in range(len(buffer)):
202 | buffer[i][-2] = (buffer[i][-2] - mean) / std
203 |
204 | return avg_rew
205 |
206 |
207 | def get_losses(model, batch, epsilon, annealing, device="cpu"):
208 | """Returns the three loss terms for a given model and a given batch and additional parameters"""
209 | # Getting old data
210 | n = len(batch)
211 | states = torch.cat([batch[i][0] for i in range(n)])
212 | actions = torch.cat([batch[i][1] for i in range(n)]).view(n, 1)
213 | logits = torch.cat([batch[i][2] for i in range(n)])
214 | values = torch.cat([batch[i][3] for i in range(n)])
215 | cumulative_rewards = (
216 | torch.tensor([batch[i][-2] for i in range(n)]).view(-1, 1).float().to(device)
217 | )
218 |
219 | # Computing predictions with the new model
220 | _, new_logits, new_values = model(states)
221 |
222 | # Loss on the state-action-function / actor (L_CLIP)
223 | advantages = cumulative_rewards - values
224 | margin = epsilon * annealing
225 | ratios = new_logits.gather(1, actions) / logits.gather(1, actions)
226 |
227 | l_clip = torch.mean(
228 | torch.min(
229 | torch.cat(
230 | (
231 | ratios * advantages,
232 | torch.clip(ratios, 1 - margin, 1 + margin) * advantages,
233 | ),
234 | dim=1,
235 | ),
236 | dim=1,
237 | ).values
238 | )
239 |
240 | # Loss on the value-function / critic (L_VF)
241 | l_vf = torch.mean((cumulative_rewards - new_values) ** 2)
242 |
243 | # Bonus for entropy of the actor
244 | entropy_bonus = torch.mean(
245 | torch.sum(-new_logits * (torch.log(new_logits + 1e-5)), dim=1)
246 | )
247 |
248 | return l_clip, l_vf, entropy_bonus
249 |
250 |
251 | def training_loop(
252 | env,
253 | model,
254 | max_iterations,
255 | n_actors,
256 | horizon,
257 | gamma,
258 | epsilon,
259 | n_epochs,
260 | batch_size,
261 | lr,
262 | c1,
263 | c2,
264 | device,
265 | env_name="",
266 | ):
267 | """Train the model on the given environment using multiple actors acting up to n timestamps."""
268 |
269 | # Starting a new Weights & Biases run
270 | wandb.init(
271 | project="Papers Re-implementations",
272 | entity="peutlefaire",
273 | name=f"PPO - {env_name}",
274 | config={
275 | "env": str(env),
276 | "number of actors": n_actors,
277 | "horizon": horizon,
278 | "gamma": gamma,
279 | "epsilon": epsilon,
280 | "epochs": n_epochs,
281 | "batch size": batch_size,
282 | "learning rate": lr,
283 | "c1": c1,
284 | "c2": c2,
285 | },
286 | )
287 |
288 | # Training variables
289 | max_reward = float("-inf")
290 | optimizer = Adam(model.parameters(), lr=lr, maximize=True)
291 | scheduler = LinearLR(optimizer, 1, 0, max_iterations * n_epochs)
292 | anneals = np.linspace(1, 0, max_iterations)
293 |
294 | # Training loop
295 | for iteration in range(max_iterations):
296 | buffer = []
297 | annealing = anneals[iteration]
298 |
299 | # Collecting timestamps for all actors with the current policy
300 | for actor in range(1, n_actors + 1):
301 | buffer.extend(run_timestamps(env, model, horizon, False, device))
302 |
303 | # Computing cumulative rewards and shuffling the buffer
304 | avg_rew = compute_cumulative_rewards(buffer, gamma)
305 | np.random.shuffle(buffer)
306 |
307 | # Running optimization for a few epochs
308 | for epoch in range(n_epochs):
309 | for batch_idx in range(len(buffer) // batch_size):
310 | # Getting batch for this buffer
311 | start = batch_size * batch_idx
312 |                 end = min(start + batch_size, len(buffer))
313 | batch = buffer[start:end]
314 |
315 | # Zero-ing optimizers gradients
316 | optimizer.zero_grad()
317 |
318 | # Getting the losses
319 | l_clip, l_vf, entropy_bonus = get_losses(
320 | model, batch, epsilon, annealing, device
321 | )
322 |
323 | # Computing total loss and back-propagating it
324 | loss = l_clip - c1 * l_vf + c2 * entropy_bonus
325 | loss.backward()
326 |
327 | # Optimizing
328 | optimizer.step()
329 | scheduler.step()
330 |
331 | # Logging information to stdout
332 | curr_loss = loss.item()
333 | log = (
334 | f"Iteration {iteration + 1} / {max_iterations}: "
335 | f"Average Reward: {avg_rew:.2f}\t"
336 | f"Loss: {curr_loss:.3f} "
337 | f"(L_CLIP: {l_clip.item():.1f} | L_VF: {l_vf.item():.1f} | L_bonus: {entropy_bonus.item():.1f})"
338 | )
339 | if avg_rew > max_reward:
340 | torch.save(model.state_dict(), MODEL_PATH)
341 | max_reward = avg_rew
342 | log += " --> Stored model with highest average reward"
343 | print(log)
344 |
345 | # Logging information to W&B
346 | wandb.log(
347 | {
348 | "loss (total)": curr_loss,
349 | "loss (clip)": l_clip.item(),
350 | "loss (vf)": l_vf.item(),
351 | "loss (entropy bonus)": entropy_bonus.item(),
352 | "average reward": avg_rew,
353 | }
354 | )
355 |
356 | # Finishing W&B session
357 | wandb.finish()
358 |
359 |
360 | def testing_loop(env, model, n_episodes, device):
361 | """Runs the learned policy on the environment for n episodes"""
362 | for _ in range(n_episodes):
363 | run_timestamps(env, model, timestamps=128, render=True, device=device)
364 |
365 |
366 | def main():
367 | # Parsing program arguments
368 | args = parse_args()
369 | print(args)
370 |
371 | # Setting seed
372 | pl.seed_everything(args["seed"])
373 |
374 | # Getting device
375 | device = get_device()
376 |
377 | # Creating environment (discrete action space)
378 | env_name = "CartPole-v1"
379 | env = gym.make(env_name)
380 |
381 | # Creating the model (both actor and critic)
382 | model = MyPPO(env.observation_space.shape, env.action_space.n).to(device)
383 |
384 | # Training
385 | training_loop(
386 | env,
387 | model,
388 | args["max_iterations"],
389 | args["n_actors"],
390 | args["horizon"],
391 | args["gamma"],
392 | args["epsilon"],
393 | args["n_epochs"],
394 | args["batch_size"],
395 | args["lr"],
396 | args["c1"],
397 | args["c2"],
398 | device,
399 | env_name,
400 | )
401 |
402 | # Loading best model
403 | model = MyPPO(env.observation_space.shape, env.action_space.n).to(device)
404 | model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
405 |
406 | # Testing
407 | env = gym.make(env_name, render_mode="human")
408 | testing_loop(env, model, args["n_test_episodes"], device)
409 | env.close()
410 |
411 |
412 | if __name__ == "__main__":
413 | main()
414 |
--------------------------------------------------------------------------------
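The heart of get_losses above is the clipped surrogate objective from the paper, min(r * A, clip(r, 1 - eps, 1 + eps) * A), with r the new-to-old probability ratio; together with the value loss and the entropy bonus it forms L = L_CLIP - c1 * L_VF + c2 * S, which the training loop maximizes (note Adam(..., maximize=True)). A toy illustration of the clipping (values are made up):

import torch

epsilon = 0.1
ratios = torch.tensor([0.5, 0.8, 1.5])  # new / old action probabilities
advantages = torch.tensor([1.0, -2.0, 1.0])
clipped = torch.clip(ratios, 1 - epsilon, 1 + epsilon) * advantages
l_clip = torch.min(ratios * advantages, clipped)
print(l_clip)  # tensor([ 0.5000, -1.8000,  1.1000]): the pessimistic value is kept element-wise

--------------------------------------------------------------------------------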