├── .gitignore ├── 1-setup.ipynb ├── 2-intro_FNO.ipynb ├── 3-darcy_flow.ipynb ├── 4-training-on-Darcy-Flow.ipynb ├── README.md ├── assets ├── 2023-Nik.pdf └── Bootcamp_DLI_Modulus_v2207_CalTech_share.pdf ├── config ├── darcy_config.yaml └── tfno_darcy_config.yaml └── images └── fourier_layer.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | data 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /1-setup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2c160c9c-a445-493e-9948-7ba507c606fb", 6 | "metadata": {}, 7 | "source": [ 8 | "# Running bash commands from your notebook\n", 9 | "\n", 10 | "First, let's install all the dependencies. \n", 11 | "\n", 12 | "You can directly run bash commands in your notebook, by either prefixing your commands with an exclamation mark `!`:\n", 13 | "```ipython\n", 14 | "[1] !echo \"this is a bash command\"\n", 15 | "this is a bash command\n", 16 | "\n", 17 | "[2] !ls\n", 18 | "/home/user/git_repos/FNO_workshop\n", 19 | "```\n", 20 | "\n", 21 | "or by starting your cell with the `%%bash` ipython magic. 
\n", 22 | "\n", 23 | "Let's see a simple example:" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "id": "24e20734-97e5-4295-9952-d67ac36b63a0", 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "Couldn't find program: 'bash'\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "%%bash\n", 42 | "\n", 43 | "for var in hello world\n", 44 | "do\n", 45 | " echo ${var} \n", 46 | "done" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "5b47acb6-a558-40bf-bd76-872941fdf879", 52 | "metadata": {}, 53 | "source": [ 54 | "# Installing the dependencies\n", 55 | "\n", 56 | "Now, let's install the dependencies." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "id": "bcce1c3e-b4d7-44ea-8b98-bce1f04182cf", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Couldn't find program: 'bash'\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "%%bash \n", 75 | "\n", 76 | "target_folder='./temp'\n", 77 | "[ -d ${target_folder} ] || mkdir -p ${target_folder}\n", 78 | "cd temp\n", 79 | "\n", 80 | "git clone https://github.com/tensorly/tensorly \n", 81 | "cd tensorly\n", 82 | "python -m pip install -e .\n", 83 | "cd ..\n", 84 | "\n", 85 | "git clone https://github.com/tensorly/torch\n", 86 | "cd torch\n", 87 | "python -m pip install -e .\n", 88 | "cd ..\n", 89 | "\n", 90 | "git clone https://github.com/NeuralOperator/neuraloperator\n", 91 | "cd neuraloperator\n", 92 | "python -m pip install -e ." 
93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "id": "e0bb548e-6e98-4fac-935e-52a8115c4aac", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Collecting gpustat\n", 106 | " Downloading gpustat-1.0.0.tar.gz (90 kB)\n", 107 | "Requirement already satisfied: six>=1.7 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gpustat) (1.16.0)\n", 108 | "Collecting nvidia-ml-py<=11.495.46,>=11.450.129\n", 109 | " Downloading nvidia_ml_py-11.495.46-py3-none-any.whl (25 kB)\n", 110 | "Requirement already satisfied: psutil>=5.6.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gpustat) (5.8.0)\n", 111 | "Collecting blessed>=1.17.1\n", 112 | " Downloading blessed-1.20.0-py2.py3-none-any.whl (58 kB)\n", 113 | "Collecting jinxed>=1.1.0\n", 114 | " Downloading jinxed-1.2.0-py2.py3-none-any.whl (33 kB)\n", 115 | "Requirement already satisfied: wcwidth>=0.1.4 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from blessed>=1.17.1->gpustat) (0.2.5)\n", 116 | "Collecting ansicon\n", 117 | " Downloading ansicon-1.89.0-py2.py3-none-any.whl (63 kB)\n", 118 | "Building wheels for collected packages: gpustat\n", 119 | " Building wheel for gpustat (setup.py): started\n", 120 | " Building wheel for gpustat (setup.py): finished with status 'done'\n", 121 | " Created wheel for gpustat: filename=gpustat-1.0.0-py3-none-any.whl size=19886 sha256=647135e0be6c489fa67d18d54e79c7dca544dfd5496efe4d20129d52a8c8803f\n", 122 | " Stored in directory: c:\\users\\devzh\\appdata\\local\\pip\\cache\\wheels\\1b\\ed\\14\\0d513c962b25da841c42022cb5847c2ef835902c8563b8fb01\n", 123 | "Successfully built gpustat\n", 124 | "Installing collected packages: ansicon, jinxed, nvidia-ml-py, blessed, gpustat\n", 125 | "Successfully installed ansicon-1.89.0 blessed-1.20.0 gpustat-1.0.0 jinxed-1.2.0 nvidia-ml-py-11.495.46\n", 126 | "Collecting gdown\n", 127 | " Downloading gdown-4.6.4-py3-none-any.whl 
(14 kB)\n", 128 | "Requirement already satisfied: requests[socks] in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (2.26.0)\n", 129 | "Requirement already satisfied: beautifulsoup4 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (4.10.0)\n", 130 | "Requirement already satisfied: filelock in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (3.3.1)\n", 131 | "Requirement already satisfied: six in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (1.16.0)\n", 132 | "Requirement already satisfied: tqdm in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (4.62.3)\n", 133 | "Requirement already satisfied: soupsieve>1.2 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from beautifulsoup4->gdown) (2.2.1)\n", 134 | "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (2.0.4)\n", 135 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (1.26.7)\n", 136 | "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (2021.10.8)\n", 137 | "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (3.2)\n", 138 | "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (1.7.1)\n", 139 | "Requirement already satisfied: colorama in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from tqdm->gdown) (0.4.4)\n", 140 | "Installing collected packages: gdown\n", 141 | "Successfully installed gdown-4.6.4\n", 142 | "Requirement already satisfied: opt-einsum in c:\\users\\devzh\\anaconda3\\lib\\site-packages (3.3.0)\n", 143 | "Requirement already satisfied: numpy>=1.7 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from opt-einsum) 
(1.21.2)\n", 144 | "Requirement already satisfied: h5py in c:\\users\\devzh\\anaconda3\\lib\\site-packages (3.6.0)\n", 145 | "Requirement already satisfied: wandb in c:\\users\\devzh\\anaconda3\\lib\\site-packages (0.12.1)\n", 146 | "Requirement already satisfied: ruamel.yaml in c:\\users\\devzh\\anaconda3\\lib\\site-packages (0.17.21)\n", 147 | "Requirement already satisfied: zarr in c:\\users\\devzh\\anaconda3\\lib\\site-packages (2.14.1)\n", 148 | "Requirement already satisfied: numpy>=1.14.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from h5py) (1.21.2)\n", 149 | "Requirement already satisfied: Click!=8.0.0,>=7.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n", 150 | "Requirement already satisfied: promise<3,>=2.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.3)\n", 151 | "Requirement already satisfied: sentry-sdk>=1.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.3.1)\n", 152 | "Requirement already satisfied: GitPython>=1.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.1.18)\n", 153 | "Requirement already satisfied: docker-pycreds>=0.4.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (0.4.0)\n", 154 | "Requirement already satisfied: protobuf>=3.12.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.17.3)\n", 155 | "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n", 156 | "Requirement already satisfied: subprocess32>=3.5.3 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.5.4)\n", 157 | "Requirement already satisfied: python-dateutil>=2.6.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.8.2)\n", 158 | "Requirement already satisfied: shortuuid>=0.5.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.0.1)\n", 159 | "Requirement already satisfied: requests<3,>=2.0.0 in 
c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n", 160 | "Requirement already satisfied: configparser>=3.8.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (5.0.2)\n", 161 | "Requirement already satisfied: PyYAML in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n", 162 | "Requirement already satisfied: six>=1.13.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.16.0)\n", 163 | "Requirement already satisfied: pathtools in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (0.1.2)\n", 164 | "Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from ruamel.yaml) (0.2.7)\n", 165 | "Requirement already satisfied: asciitree in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.3.3)\n", 166 | "Requirement already satisfied: numcodecs>=0.10.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.11.0)\n", 167 | "Requirement already satisfied: fasteners in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.18)\n", 168 | "Requirement already satisfied: colorama in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.0->wandb) (0.4.4)\n", 169 | "Requirement already satisfied: gitdb<5,>=4.0.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from GitPython>=1.0.0->wandb) (4.0.7)\n", 170 | "Requirement already satisfied: smmap<5,>=3.0.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb) (4.0.0)\n", 171 | "Requirement already satisfied: entrypoints in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from numcodecs>=0.10.0->zarr) (0.3)\n", 172 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n", 173 | "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from 
requests<3,>=2.0.0->wandb) (2.0.4)\n", 174 | "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", 175 | "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "!pip install gpustat\n", 181 | "!pip install gdown\n", 182 | "!pip install opt-einsum\n", 183 | "!pip install h5py wandb ruamel.yaml zarr" 184 | ] 185 | }, 186 | { 187 | "attachments": {}, 188 | "cell_type": "markdown", 189 | "id": "f4ed3b9d-fffd-4d5d-852c-7dc95dad086f", 190 | "metadata": {}, 191 | "source": [ 192 | "# Prepare data " 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 10, 198 | "id": "3a2484ab-0f02-45c9-acce-cb0bbe803dbb", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "import os\n", 203 | "import requests\n", 204 | "import hashlib\n", 205 | "url_dict = {\n", 206 | " 'darcyflow-1':'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth1.mat', \n", 207 | " 'darcyflow-2': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth2.mat', \n", 208 | " 'Navier-Stokes': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/ns_V1e-3_N5000_T50.mat', \n", 209 | " 'darcy-test-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt', \n", 210 | " 'darcy-test-64': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt', \n", 211 | " 'darcy-train-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt', \n", 212 | " 'darcy-train-64': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt', \n", 213 | " 'KF-Re100': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy'\n", 214 | "}\n", 215 | "\n", 216 | "chksum_dict = {\n", 217 | " 'piececonst_r241_N1024_smooth1.mat': 
'5ab3edf67bb5fd6d49ebf308cd79ed70340106d1a18af8a8439d3e7fc8e82d21', \n", 218 | " 'piececonst_r241_N1024_smooth2.mat': '51a818ed2e4f08752eea5d3f137f0e00271589c48297a46c641382a51eb80acf', \n", 219 | " 'ns_V1e-3_N5000_T50.mat': '78b8d9e83d767dc7050fb8145ee7e7f11e2d18d325bff9abc7f108ec3292ee78', \n", 220 | " 'darcy_train_64.pt': 'c05770239c91ebf093ea971e4d724008a49c9f21b5363fcf182e80499fae7fb4', \n", 221 | " 'darcy_train_32.pt': 'b8d8095d3832ed67f55b4a8fcb1970618b4ca2c6fc91aee2fe49b9c9b2c071ae', \n", 222 | " 'darcy_test_64.pt': '2220bb25c920109e9565a7fc07b675de16d124d563996f6e7256e2faa1fde24f', \n", 223 | " 'darcy_test_32.pt': '65137910193a553295c26e3d8273761daa44766597f4b34cfb12299fc6e3f311', \n", 224 | " 'KFvorticity_Re100_N50_T500.npy': '55f5af44a732a7843d631ace6384ac75c787d4fb36765b2e83ce1febb52d5463'\n", 225 | "}\n", 226 | "\n", 227 | "def download_file(url, file_path): # stream url to file_path; raise_for_status() raises on HTTP error codes\n", 228 | " with requests.get(url, stream=True) as r:\n", 229 | " r.raise_for_status()\n", 230 | " with open(file_path, 'wb') as f:\n", 231 | " for chunk in r.iter_content(chunk_size=1024 * 1024): # 1 MiB: iter_content holds up to chunk_size bytes in memory per chunk\n", 232 | " f.write(chunk)\n", 233 | " print('Complete')\n" 234 | ] 235 | }, 236 | { 237 | "attachments": {}, 238 | "cell_type": "markdown", 239 | "id": "d36ba93d", 240 | "metadata": {}, 241 | "source": [ 242 | "## Download Darcy datasets" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 7, 248 | "id": "70b2c9d0-990d-43fd-9a80-7af9dbc8dd64", 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt...\n", 256 | "Complete\n", 257 | "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt...\n", 258 | "Complete\n", 259 | "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt...\n", 260 | "Complete\n", 261 | "Downloading
https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt...\n", 262 | "Complete\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "data_root = 'data'\n", 268 | "darcy_dir = os.path.join(data_root, 'darcy_flow')\n", 269 | "os.makedirs(darcy_dir, exist_ok=True)\n", 270 | "\n", 271 | "day1_data = ['darcy-train-64', 'darcy-train-32', 'darcy-test-64', 'darcy-test-32']\n", 272 | "\n", 273 | "for key in day1_data:\n", 274 | " value = url_dict[key]\n", 275 | " print(f'Downloading {value}...')\n", 276 | " filename = os.path.basename(value)\n", 277 | " save_path = os.path.join(darcy_dir, filename)\n", 278 | " download_file(url=value, file_path=save_path)\n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "id": "db98503a", 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "# verify data integrity\n", 289 | "for data_file in os.listdir(darcy_dir):\n", 290 | " data_path = os.path.join(darcy_dir, data_file)\n", 291 | " with open(data_path, 'rb') as f:\n", 292 | " data = f.read()\n", 293 | " sha256 = hashlib.sha256(data).hexdigest()\n", 294 | " if sha256 == chksum_dict.get(data_file): # files without a known checksum are reported as failed instead of raising KeyError\n", 295 | " print(f'{data_file} verified!')\n", 296 | " else:\n", 297 | " print(f'{data_file} verification failed!')" 298 | ] 299 | }, 300 | { 301 | "attachments": {}, 302 | "cell_type": "markdown", 303 | "id": "6a5cc551", 304 | "metadata": {}, 305 | "source": [ 306 | "### Download KF datasets (2d NS)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 2, 312 | "id": "817f3d48", 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy to data\\kf\n", 320 | "Complete\n" 321 | ] 322 | } 323 | ], 324 | "source": [ 325 | "data_root = 'data'\n", 326 | "kf_dir = os.path.join(data_root, 'kf')\n", 327 | "os.makedirs(kf_dir, exist_ok=True)\n", 328 | "\n",
329 | "kf_data = ['KF-Re100']\n", 330 | "for key in kf_data:\n", 331 | " value = url_dict[key]\n", 332 | " print(f'Downloading {value} to {kf_dir}')\n", 333 | " filename = os.path.basename(value)\n", 334 | " save_path = os.path.join(kf_dir, filename)\n", 335 | " download_file(url=value, file_path=save_path)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 11, 341 | "id": "73acfd3d-23d4-4f02-9bc7-167438ac2de4", 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "KFvorticity_Re100_N50_T500.npy verified!\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "for data_file in os.listdir(kf_dir):\n", 354 | " data_path = os.path.join(kf_dir, data_file)\n", 355 | " with open(data_path, 'rb') as f:\n", 356 | " data = f.read()\n", 357 | " sha256 = hashlib.sha256(data).hexdigest()\n", 358 | " if sha256 == chksum_dict.get(data_file): # files without a known checksum are reported as failed instead of raising KeyError\n", 359 | " print(f'{data_file} verified!')\n", 360 | " else:\n", 361 | " print(f'{data_file} verification failed!')\n" 362 | ] 363 | } 364 | ], 365 | "metadata": { 366 | "kernelspec": { 367 | "display_name": "base", 368 | "language": "python", 369 | "name": "python3" 370 | }, 371 | "language_info": { 372 | "codemirror_mode": { 373 | "name": "ipython", 374 | "version": 3 375 | }, 376 | "file_extension": ".py", 377 | "mimetype": "text/x-python", 378 | "name": "python", 379 | "nbconvert_exporter": "python", 380 | "pygments_lexer": "ipython3", 381 | "version": "3.8.8" 382 | }, 383 | "vscode": { 384 | "interpreter": { 385 | "hash": "95d4b27ba6bfea4a66eebe0e0159b214d32a94d313a7f4c98bd9b87f5ee37cbe" 386 | } 387 | } 388 | }, 389 | "nbformat": 4, 390 | "nbformat_minor": 5 391 | } 392 | -------------------------------------------------------------------------------- /2-intro_FNO.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id":
"59194c45-83c9-4a77-a1b0-185eca26afd5", 6 | "metadata": {}, 7 | "source": [ 8 | "# Check the dependencies " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "af7a5c4c-b3a5-4f32-aee9-55290566ff56", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "tl.__version__='0.8.0'\n", 22 | "tltorch.__version__='0.3.0'\n", 23 | "no.__version__='0.1.0'\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import tensorly as tl\n", 29 | "import tltorch\n", 30 | "import neuralop as no\n", 31 | "\n", 32 | "print(f'{tl.__version__=}')\n", 33 | "print(f'{tltorch.__version__=}')\n", 34 | "print(f'{no.__version__=}')" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "a36bb3e2-c158-497c-babe-5eead700cbf1", 40 | "metadata": { 41 | "tags": [] 42 | }, 43 | "source": [ 44 | "# FFT and Spectral Convolution\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "4efa0d7f-e39c-496e-891d-6b34c62fbd9d", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from neuralop.models.fno_block import FactorizedSpectralConv\n", 55 | "from neuralop.models import TFNO2d\n", 56 | "import torch" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "2c8c3eb7-82e1-4df7-b6e6-3f34331637c4", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", 67 | " factorization=None, implementation='reconstructed')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "4eaf645a-b7f5-4dcc-b8de-fb388ccc9b26", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "in_data = torch.randn((2, 3, 16, 16))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "016b33e0-88a6-4215-99e0-19da4f8fd5f5", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "out = 
fourier_conv(in_data)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 6, 93 | "id": "36d0f546-9fa9-4936-a6b6-19d7bde03639", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "torch.Size([2, 10, 16, 16])" 100 | ] 101 | }, 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "out.shape" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 7, 114 | "id": "4936746b-5abb-4a8b-9e74-238502c65930", 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "FactorizedSpectralConv(\n", 121 | " (weight): ModuleList(\n", 122 | " (0): ComplexDenseTensor(shape=torch.Size([3, 10, 2, 2]), rank=None)\n", 123 | " (1): ComplexDenseTensor(shape=torch.Size([3, 10, 2, 2]), rank=None)\n", 124 | " )\n", 125 | ")" 126 | ] 127 | }, 128 | "execution_count": 7, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "fourier_conv" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "a616d68d-677a-4e6f-abd5-9e631ebf7fb6", 140 | "metadata": {}, 141 | "source": [ 142 | "The spectral convolution multiplies the (complex) Fourier coefficients of its input with (complex) weights that are learned end-to-end." 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "0c8d9860-d43d-47f3-a6aa-c7ed4522684e", 148 | "metadata": { 149 | "tags": [] 150 | }, 151 | "source": [ 152 | "# Tensorized Spectral Convolutions\n", 153 | "\n", 154 | "It is possible to express the weights of one or more layers in factorized form, i.e., as a low-rank decomposition of the full weights.\n", 155 | "\n", 156 | "`neuralop` supports tensorization out of the box: for example, to use a Tucker factorization, simply specify `factorization='tucker'`."
157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 8, 162 | "id": "b3f919de-97c2-4f0b-bb40-8e47cd2c1e0e", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", 167 | " factorization='tucker', implementation='reconstructed')" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "id": "91a7aa04-9cc3-4f8c-b34f-54fbc625b718", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "FactorizedSpectralConv(\n", 180 | " (weight): ModuleList(\n", 181 | " (0): ComplexTuckerTensor(shape=(3, 10, 2, 2), rank=(1, 5, 1, 1))\n", 182 | " (1): ComplexTuckerTensor(shape=(3, 10, 2, 2), rank=(1, 5, 1, 1))\n", 183 | " )\n", 184 | ")" 185 | ] 186 | }, 187 | "execution_count": 9, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "fourier_conv" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "f8df876d-72e1-40cd-9a86-330a57dc0e8d", 199 | "metadata": {}, 200 | "source": [ 201 | "## Efficient forward pass\n", 202 | "\n", 203 | "When factorizing the weights, we have two main options during the forward pass:\n", 204 | "1. reconstruct the full weights and use them for the forward pass \n", 205 | "2. contract the input directly with the factorized weights to predict the output\n", 206 | "\n", 207 | "When the factorized weights are small, the second option can lead to large speedups or memory reductions, particularly when coupled with checkpointing.
\n", 208 | "\n", 209 | "In `neuralop`, you can use those simply by specifying `implementation='reconstructed'` or `implementation='factorized'`:" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "id": "a0667a6b-1efe-47e0-8908-29c5fb0cf45a", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", 220 | " factorization='tucker', implementation='factorized')" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "id": "ec3ab24a-09fe-4864-b2ed-e96b54792e9f", 226 | "metadata": {}, 227 | "source": [ 228 | "# Full Tensorized Fourier Neural Operator \n", 229 | "\n", 230 | "The full architecture is composed of \n", 231 | "\n", 232 | "i) a lifting layer taking the number of input channels and lifting that to the desired number of hidden channels\n", 233 | "ii) a number of spectral convolutions, as shown above\n", 234 | "iii) a projection layer projecting back from the number of hidden channels to the desired number of output channels\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 11, 240 | "id": "d51aec17-2cf4-40c4-9452-84a4b5259db6", 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "tfno = TFNO2d(n_modes_height=16, n_modes_width=16, hidden_channels=16, \n", 245 | " factorization=None, skip='linear')" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 12, 251 | "id": "c87127e5-d24c-4096-be3a-8872a853a132", 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/plain": [ 257 | "TFNO2d(\n", 258 | " (convs): FactorizedSpectralConv2d(\n", 259 | " (weight): ModuleList(\n", 260 | " (0): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 261 | " (1): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 262 | " (2): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 263 | " (3): 
ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 264 | " (4): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 265 | " (5): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 266 | " (6): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 267 | " (7): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 268 | " )\n", 269 | " )\n", 270 | " (fno_skips): ModuleList(\n", 271 | " (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 272 | " (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 273 | " (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 274 | " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 275 | " )\n", 276 | " (lifting): Lifting(\n", 277 | " (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n", 278 | " )\n", 279 | " (projection): Projection(\n", 280 | " (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", 281 | " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", 282 | " )\n", 283 | ")" 284 | ] 285 | }, 286 | "execution_count": 12, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "tfno" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "id": "0e70efec-bf3c-48ac-b53a-59800055f1b9", 298 | "metadata": {}, 299 | "source": [ 300 | "## Lifting layer\n", 301 | "\n", 302 | "Increasing the number of channels" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 13, 308 | "id": "1deead74-bd3d-4aa9-8d2c-cfd9ab0763d7", 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "data": { 313 | "text/plain": [ 314 | "Lifting(\n", 315 | " (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n", 316 | ")" 317 | ] 318 | }, 319 | "execution_count": 13, 320 | "metadata": {}, 321 | "output_type": "execute_result" 322 | } 323 | ], 324 | "source": [ 325 | "tfno.lifting" 326 | ] 327 | }, 328 | { 329 | 
"cell_type": "markdown", 330 | "id": "08844bac-9335-4ac4-afc8-f1d67c3e31bb", 331 | "metadata": {}, 332 | "source": [ 333 | "## Spectral convolutions" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 14, 339 | "id": "f2bc28dc-1226-4ed3-b757-3c42357d276a", 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "data": { 344 | "text/plain": [ 345 | "FactorizedSpectralConv2d(\n", 346 | " (weight): ModuleList(\n", 347 | " (0): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 348 | " (1): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 349 | " (2): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 350 | " (3): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 351 | " (4): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 352 | " (5): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 353 | " (6): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 354 | " (7): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", 355 | " )\n", 356 | ")" 357 | ] 358 | }, 359 | "execution_count": 14, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "tfno.convs" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "1c7d9882-13db-447d-affd-07ef17256e1c", 371 | "metadata": {}, 372 | "source": [ 373 | "## Skip connections: recovering non-periodicity\n", 374 | "\n", 375 | "Recall that the FNO architecture has skip connections: the FFT-based spectral convolution loses non-periodic information, which has to be reinjected through skip connections. These skip connections also help with learning.\n", 376 | "\n", 377 | "![FNO_layer](./images/fourier_layer.png)\n", 378 | "\n", 379 | "Here, we use a linear skip layer (represented by weight W in the image). 
We can also use an identity skip (`skip='identity'`) or soft-gated connections (`skip='soft-gating'`)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 15, 385 | "id": "f063e3bf-34e5-4d7f-83f9-b3522aa6430b", 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "ModuleList(\n", 392 | " (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 393 | " (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 394 | " (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 395 | " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 396 | ")" 397 | ] 398 | }, 399 | "execution_count": 15, 400 | "metadata": {}, 401 | "output_type": "execute_result" 402 | } 403 | ], 404 | "source": [ 405 | "tfno.fno_skips" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "id": "070e930e-38b6-4d3c-b62a-3ca700294c99", 411 | "metadata": {}, 412 | "source": [ 413 | "## Projection: going back to the target number of channels \n", 414 | "\n", 415 | "Finally, the projection layer maps the hidden channels to `projection_channels` and then to the actual number of output channels (here, 1)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 16, 421 | "id": "88344f47-a7e8-458e-9fbb-775804fbbaad", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "Projection(\n", 428 | " (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", 429 | " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", 430 | ")" 431 | ] 432 | }, 433 | "execution_count": 16, 434 | "metadata": {}, 435 | "output_type": "execute_result" 436 | } 437 | ], 438 | "source": [ 439 | "tfno.projection" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "id": "7aae1ab6-852c-4720-9b3b-5791c2b42872", 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [] 449 | } 450 | ], 451 | "metadata": { 
452 | "kernelspec": { 453 | "display_name": "Python 3 (ipykernel)", 454 | "language": "python", 455 | "name": "python3" 456 | }, 457 | "language_info": { 458 | "codemirror_mode": { 459 | "name": "ipython", 460 | "version": 3 461 | }, 462 | "file_extension": ".py", 463 | "mimetype": "text/x-python", 464 | "name": "python", 465 | "nbconvert_exporter": "python", 466 | "pygments_lexer": "ipython3", 467 | "version": "3.9.15" 468 | } 469 | }, 470 | "nbformat": 4, 471 | "nbformat_minor": 5 472 | } 473 | -------------------------------------------------------------------------------- /3-darcy_flow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "4df7dcda-a364-4255-9339-a9a09c2a5e34", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from pathlib import Path\n", 11 | "from neuralop.datasets import load_darcy_pt" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "ff12d431-bde9-4eba-906b-d0faea8c49fb", 17 | "metadata": {}, 18 | "source": [ 19 | "# Load the data " 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "id": "aa9c49f5-878b-4cac-9a35-b9dc53085d11", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "data_path=\"/dli/task/bootcamp/data/darcy_flow/\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "f40e8c0a-c031-457b-863c-c728de7d1b80", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "train_loader, test_loaders, output_encoder = load_darcy_pt(data_path, n_train=100, n_tests=[10], \n", 40 | " batch_size=3, test_batch_sizes=[3],\n", 41 | " test_resolutions=[32], train_resolution=32)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 8, 47 | "id": "2f29d90a-4fc7-4b83-8ed4-f1bb4dce2574", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "train_dataset = train_loader.dataset" 52 | ] 53 | }, 54 
| { 55 | "cell_type": "markdown", 56 | "id": "21000189-ecac-42e2-b008-06eefa7b1710", 57 | "metadata": {}, 58 | "source": [ 59 | "# Visualizing the data " 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "1cf47f09-1fb3-4667-9b04-a98b3ee8d08d", 65 | "metadata": {}, 66 | "source": [ 67 | "The data is stored in a dictionary" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 9, 73 | "id": "b6a9aed5-6532-42ba-8131-0307460c960d", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "data = train_dataset[0]\n", 78 | "x = data['x']\n", 79 | "y = data['y']" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 10, 85 | "id": "9f475172-62b0-4ce3-8dce-a7d0d9dca9fb", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "torch.Size([3, 128, 128])" 92 | ] 93 | }, 94 | "execution_count": 6, 95 | "metadata": {}, 96 | "output_type": "execute_result" 97 | } 98 | ], 99 | "source": [ 100 | "x.shape" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "7d7947ad-f98c-414b-8a12-64270988ad1f", 106 | "metadata": {}, 107 | "source": [ 108 | "`x` is of shape (3, height, width). \n", 109 | "\n", 110 | "This is because, in addition to the binary input, we appended a positional encoding, so the model knows the location of each pixel.\n", 111 | "\n", 112 | "Let's check the actual data:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 8, 118 | "id": "077ebd7d-883b-4300-b13d-ed88813a3be1", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "%matplotlib inline\n", 123 | "import matplotlib.pyplot as plt" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 10, 129 | "id": "814a044d-a52f-4cf2-aa1b-859370012af5", 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "image/png": "", 135 | "text/plain": [ 136 | "
" 137 | ] 138 | }, 139 | "metadata": {}, 140 | "output_type": "display_data" 141 | } 142 | ], 143 | "source": [ 144 | "# Which sample to view\n", 145 | "index = 10\n", 146 | "\n", 147 | "data = train_dataset[index]\n", 148 | "x = data['x']\n", 149 | "y = data['y']\n", 150 | "fig = plt.figure(figsize=(7, 7))\n", 151 | "ax = fig.add_subplot(2, 2, 1)\n", 152 | "ax.imshow(x[0], cmap='gray')\n", 153 | "ax.set_title('input x')\n", 154 | "ax = fig.add_subplot(2, 2, 2)\n", 155 | "ax.imshow(y.squeeze())\n", 156 | "ax.set_title('input y')\n", 157 | "ax = fig.add_subplot(2, 2, 3)\n", 158 | "ax.imshow(x[1])\n", 159 | "ax.set_title('x: 1st pos embedding')\n", 160 | "ax = fig.add_subplot(2, 2, 4)\n", 161 | "ax.imshow(x[2])\n", 162 | "ax.set_title('x: 2nd pos embedding')\n", 163 | "fig.suptitle('Visualizing one input sample', y=0.98)\n", 164 | "plt.tight_layout()\n", 165 | "fig.show()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "ba9e2a5c-98e7-47c0-9e24-7a8e41c657dc", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 3 (ipykernel)", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.9.15" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 5 198 | } 199 | -------------------------------------------------------------------------------- /4-training-on-Darcy-Flow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "012a357f-8533-482c-823d-a4587c49e726", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 
10 | "import torch\n", 11 | "import wandb\n", 12 | "import sys\n", 13 | "from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig\n", 14 | "from neuralop import get_model\n", 15 | "from neuralop import Trainer\n", 16 | "from neuralop.training import setup\n", 17 | "from neuralop.datasets import load_darcy_pt\n", 18 | "from neuralop.utils import get_wandb_api_key, count_params\n", 19 | "from neuralop import LpLoss, H1Loss" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "9b6358b5-78d1-4baf-8928-6bb49b150680", 25 | "metadata": {}, 26 | "source": [ 27 | "# Loading the configuration\n", 28 | "\n", 29 | "You can open the YAML file `config/darcy_config.yaml` (in the same folder as this notebook) to inspect and change the parameters." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "4503f065-4063-4a4f-b00f-06a7c3a88e27", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# Read the configuration\n", 40 | "config_name = 'default'\n", 41 | "pipe = ConfigPipeline([YamlConfig('./darcy_config.yaml', config_name='default', config_folder='./config'),\n", 42 | " ])\n", 43 | "config = pipe.read_conf()\n", 44 | "config_name = pipe.steps[-1].config_name" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "e95d820d-9578-4ad7-80b4-05a5771f1642", 50 | "metadata": {}, 51 | "source": [ 52 | "## Setup\n", 53 | "\n", 54 | "Here we just set up PyTorch and print the configuration." 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "46066d9f-21a3-4aab-b6e1-f7f38e05f88b", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Set-up distributed communication, if using\n", 65 | "device, is_logger = setup(config)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "26d599f9-6463-4056-9a4d-72c01d05298e", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | 
"###############################\n", 79 | "##### CONFIGURATION #####\n", 80 | "###############################\n", 81 | "\n", 82 | "Steps:\n", 83 | "------\n", 84 | " (1) YamlConfig with config_file=./darcy_config.yaml, config_name=default, config_folder=./config\n", 85 | "\n", 86 | "-------------------------------\n", 87 | "\n", 88 | "Configuration:\n", 89 | "--------------\n", 90 | "\n", 91 | "n_params_baseline=None\n", 92 | "verbose=True\n", 93 | "arch=tfno2d\n", 94 | "distributed.use_distributed=False\n", 95 | "tfno2d.data_channels=3\n", 96 | "tfno2d.n_modes_height=32\n", 97 | "tfno2d.n_modes_width=32\n", 98 | "tfno2d.hidden_channels=64\n", 99 | "tfno2d.projection_channels=256\n", 100 | "tfno2d.n_layers=4\n", 101 | "tfno2d.domain_padding=None\n", 102 | "tfno2d.domain_padding_mode=one-sided\n", 103 | "tfno2d.fft_norm=forward\n", 104 | "tfno2d.norm=group_norm\n", 105 | "tfno2d.skip=linear\n", 106 | "tfno2d.implementation=factorized\n", 107 | "tfno2d.separable=0\n", 108 | "tfno2d.preactivation=0\n", 109 | "tfno2d.use_mlp=1\n", 110 | "tfno2d.mlp.expansion=0.5\n", 111 | "tfno2d.mlp.dropout=0\n", 112 | "tfno2d.factorization=None\n", 113 | "tfno2d.rank=1.0\n", 114 | "tfno2d.fixed_rank_modes=None\n", 115 | "tfno2d.dropout=0.0\n", 116 | "tfno2d.tensor_lasso_penalty=0.0\n", 117 | "tfno2d.joint_factorization=False\n", 118 | "opt.n_epochs=150\n", 119 | "opt.learning_rate=0.005\n", 120 | "opt.training_loss=h1\n", 121 | "opt.weight_decay=0.0001\n", 122 | "opt.amp_autocast=False\n", 123 | "opt.scheduler_T_max=300\n", 124 | "opt.scheduler_patience=5\n", 125 | "opt.scheduler=CosineAnnealingLR\n", 126 | "opt.step_size=50\n", 127 | "opt.gamma=0.5\n", 128 | "data.folder=/data/darcy_flow/\n", 129 | "data.batch_size=32\n", 130 | "data.n_train=3000\n", 131 | "data.train_resolution=32\n", 132 | "data.n_tests=[500, 500]\n", 133 | "data.test_resolutions=[32, 64]\n", 134 | "data.test_batch_sizes=[32, 32]\n", 135 | "data.positional_encoding=True\n", 136 | "data.encode_input=True\n", 137 | 
"data.encode_output=False\n", 138 | "patching.levels=0\n", 139 | "patching.padding=0\n", 140 | "patching.stitching=False\n", 141 | "wandb.log=False\n", 142 | "wandb.log_test_interval=1\n", 143 | "\n", 144 | "###############################\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "# Make sure we only print information when needed\n", 150 | "config.verbose = config.verbose and is_logger\n", 151 | "\n", 152 | "#Print config to screen\n", 153 | "if config.verbose and is_logger:\n", 154 | " pipe.log()\n", 155 | " sys.stdout.flush()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "id": "1339c794-3e1c-469b-b0a0-cf968fc1dfa1", 161 | "metadata": {}, 162 | "source": [ 163 | "# Loading the data \n", 164 | "\n", 165 | "We train in one resolution and test in several resolutions to show the zero-shot super-resolution capabilities of neural-operators. " 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "3515a85a-40fc-4223-9cdb-8768de37d6e2", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "UnitGaussianNormalizer init on 3000, reducing over [0, 1, 2, 3], samples of shape [1, 32, 32].\n", 179 | " Mean and std of shape torch.Size([1, 1, 1]), eps=1e-05\n", 180 | "Loading test db at resolution 64 with 500 samples and batch-size=32\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "# Loading the Darcy flow training set in 32x32 resolution, test set in 32x32 and 64x64 resolutions\n", 186 | "train_loader, test_loaders, output_encoder = load_darcy_pt(\n", 187 | " config.data.folder, train_resolution=config.data.train_resolution, n_train=config.data.n_train, batch_size=config.data.batch_size, \n", 188 | " positional_encoding=config.data.positional_encoding,\n", 189 | " test_resolutions=config.data.test_resolutions, n_tests=config.data.n_tests, test_batch_sizes=config.data.test_batch_sizes,\n", 190 | " encode_input=config.data.encode_input, 
encode_output=config.data.encode_output,\n", 191 | " )" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "8109298a-aca3-45b7-a8de-c5cf4e1c210b", 197 | "metadata": {}, 198 | "source": [ 199 | "# Creating the model and putting it on the GPU " 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 6, 205 | "id": "db295d23-ab86-4f37-83cc-7af0a8e485ea", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "Given argument key='dropout' that is not in TFNO2d's signature.\n", 213 | "Given argument key='tensor_lasso_penalty' that is not in TFNO2d's signature.\n", 214 | "Keyword argument out_channels not specified for model TFNO2d, using default=1.\n", 215 | "Keyword argument lifting_channels not specified for model TFNO2d, using default=256.\n", 216 | "Keyword argument non_linearity not specified for model TFNO2d, using default=.\n", 217 | "Keyword argument decomposition_kwargs not specified for model TFNO2d, using default={}.\n", 218 | "\n", 219 | "n_params: 16844673\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "model = get_model(config)\n", 225 | "model = model.to(device)\n", 226 | "\n", 227 | "#Log parameter count\n", 228 | "if is_logger:\n", 229 | " n_params = count_params(model)\n", 230 | "\n", 231 | " if config.verbose:\n", 232 | " print(f'\\nn_params: {n_params}')\n", 233 | " sys.stdout.flush()" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "id": "fec85d0a-4db4-4b1f-b599-8c2afc98520a", 239 | "metadata": {}, 240 | "source": [ 241 | "# Create the optimizer and learning rate scheduler\n", 242 | "\n", 243 | "Here, we use an Adam optimizer and a learning rate schedule depending on the configuration" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 7, 249 | "id": "5164537a-267b-4fda-9bcd-257dc3ac4826", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "#Create the optimizer\n", 254 | "optimizer 
= torch.optim.Adam(model.parameters(), \n", 255 | " lr=config.opt.learning_rate, \n", 256 | " weight_decay=config.opt.weight_decay)\n", 257 | "\n", 258 | "if config.opt.scheduler == 'ReduceLROnPlateau':\n", 259 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.opt.gamma, patience=config.opt.scheduler_patience, mode='min')\n", 260 | "elif config.opt.scheduler == 'CosineAnnealingLR':\n", 261 | " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.opt.scheduler_T_max)\n", 262 | "elif config.opt.scheduler == 'StepLR':\n", 263 | " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, \n", 264 | " step_size=config.opt.step_size,\n", 265 | " gamma=config.opt.gamma)\n", 266 | "else:\n", 267 | " raise ValueError(f'Got {config.opt.scheduler=}')" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "id": "e52a72eb-965a-4997-89a4-0cdfcbcb0a1a", 273 | "metadata": {}, 274 | "source": [ 275 | "# Creating the loss\n", 276 | "\n", 277 | "We will optimize the Sobolev norm but also evaluate our goal: the l2 relative error" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 8, 283 | "id": "07a53d9d-2d06-4d36-9b46-2c7f15f29c40", 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "# Creating the losses\n", 288 | "l2loss = LpLoss(d=2, p=2)\n", 289 | "h1loss = H1Loss(d=2)\n", 290 | "if config.opt.training_loss == 'l2':\n", 291 | " train_loss = l2loss\n", 292 | "elif config.opt.training_loss == 'h1':\n", 293 | " train_loss = h1loss\n", 294 | "else:\n", 295 | " raise ValueError(f'Got training_loss={config.opt.training_loss} but expected one of [\"l2\", \"h1\"]')\n", 296 | "eval_losses={'h1': h1loss, 'l2': l2loss}" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 9, 302 | "id": "5dad660e-43e9-4f38-91f6-8427b14b8ae0", 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "\n", 310 | 
"### MODEL ###\n", 311 | " TFNO2d(\n", 312 | " (convs): FactorizedSpectralConv2d(\n", 313 | " (weight): ModuleList(\n", 314 | " (0): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 315 | " (1): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 316 | " (2): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 317 | " (3): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 318 | " (4): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 319 | " (5): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 320 | " (6): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 321 | " (7): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", 322 | " )\n", 323 | " )\n", 324 | " (fno_skips): ModuleList(\n", 325 | " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 326 | " (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 327 | " (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 328 | " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 329 | " )\n", 330 | " (mlp): ModuleList(\n", 331 | " (0): MLP(\n", 332 | " (fcs): ModuleList(\n", 333 | " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 334 | " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", 335 | " )\n", 336 | " )\n", 337 | " (1): MLP(\n", 338 | " (fcs): ModuleList(\n", 339 | " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 340 | " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", 341 | " )\n", 342 | " )\n", 343 | " (2): MLP(\n", 344 | " (fcs): ModuleList(\n", 345 | " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 346 | " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", 347 | " )\n", 348 | " )\n", 349 | " (3): MLP(\n", 350 | " (fcs): ModuleList(\n", 351 | " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 352 | " (1): Conv2d(32, 
64, kernel_size=(1, 1), stride=(1, 1))\n", 353 | " )\n", 354 | " )\n", 355 | " )\n", 356 | " (mlp_skips): ModuleList(\n", 357 | " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 358 | " (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 359 | " (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 360 | " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 361 | " )\n", 362 | " (norm): ModuleList(\n", 363 | " (0): GroupNorm(1, 64, eps=1e-05, affine=True)\n", 364 | " (1): GroupNorm(1, 64, eps=1e-05, affine=True)\n", 365 | " (2): GroupNorm(1, 64, eps=1e-05, affine=True)\n", 366 | " (3): GroupNorm(1, 64, eps=1e-05, affine=True)\n", 367 | " )\n", 368 | " (lifting): Lifting(\n", 369 | " (fc): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))\n", 370 | " )\n", 371 | " (projection): Projection(\n", 372 | " (fc1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))\n", 373 | " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", 374 | " )\n", 375 | ")\n", 376 | "\n", 377 | "### OPTIMIZER ###\n", 378 | " Adam (\n", 379 | "Parameter Group 0\n", 380 | " amsgrad: False\n", 381 | " betas: (0.9, 0.999)\n", 382 | " capturable: False\n", 383 | " differentiable: False\n", 384 | " eps: 1e-08\n", 385 | " foreach: None\n", 386 | " fused: False\n", 387 | " initial_lr: 0.005\n", 388 | " lr: 0.005\n", 389 | " maximize: False\n", 390 | " weight_decay: 0.0001\n", 391 | ")\n", 392 | "\n", 393 | "### SCHEDULER ###\n", 394 | " \n", 395 | "\n", 396 | "### LOSSES ###\n", 397 | "\n", 398 | " * Train: \n", 399 | "\n", 400 | " * Test: {'h1': , 'l2': }\n", 401 | "\n", 402 | "### Beginning Training...\n", 403 | "\n" 404 | ] 405 | } 406 | ], 407 | "source": [ 408 | "if config.verbose and is_logger:\n", 409 | " print('\\n### MODEL ###\\n', model)\n", 410 | " print('\\n### OPTIMIZER ###\\n', optimizer)\n", 411 | " print('\\n### SCHEDULER ###\\n', scheduler)\n", 412 | " print('\\n### LOSSES ###')\n", 413 | " print(f'\\n 
* Train: {train_loss}')\n", 414 | " print(f'\\n * Test: {eval_losses}')\n", 415 | " print(f'\\n### Beginning Training...\\n')\n", 416 | " sys.stdout.flush()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "id": "b5967441-b8bc-4ea8-a4d9-7a5bea384cbf", 422 | "metadata": {}, 423 | "source": [ 424 | "# Creating the trainer" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 10, 430 | "id": "a19ebfd3-8a2b-42c0-af98-7a1db2dda0f6", 431 | "metadata": {}, 432 | "outputs": [ 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "Training on regular inputs (no multi-grid patching).\n", 438 | "MGPatching(self.n_patches=[1, 1], self.padding_fraction=[0, 0], self.levels=0, use_distributed=False, stitching=False)\n" 439 | ] 440 | } 441 | ], 442 | "source": [ 443 | "trainer = Trainer(model, n_epochs=config.opt.n_epochs,\n", 444 | " device=device,\n", 445 | " mg_patching_levels=config.patching.levels,\n", 446 | " mg_patching_padding=config.patching.padding,\n", 447 | " mg_patching_stitching=config.patching.stitching,\n", 448 | " wandb_log=config.wandb.log,\n", 449 | " log_test_interval=config.wandb.log_test_interval,\n", 450 | " log_output=False,\n", 451 | " use_distributed=config.distributed.use_distributed,\n", 452 | " verbose=config.verbose and is_logger)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "id": "b16a3727-313d-4219-8f8f-0cec58d74b00", 458 | "metadata": {}, 459 | "source": [ 460 | "# Training the model " 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 11, 466 | "id": "0d6e3298-99ee-4371-8bad-60e6aac03d56", 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | "Training on 3000 samples, testing on [32, 64].\n", 474 | "[0] time=3.03, avg_loss=7.8899, train_err=0.3945, 32_h1=0.2295, 32_l2=0.1710, 64_h1=0.2847, 64_l2=0.1733\n", 475 | "[1] time=1.38, avg_loss=3.7664, train_err=0.1883, 
32_h1=0.1646, 32_l2=0.1177, 64_h1=0.2326, 64_l2=0.1221\n", 476 | "[2] time=1.37, avg_loss=3.1005, train_err=0.1550, 32_h1=0.1411, 32_l2=0.1027, 64_h1=0.2156, 64_l2=0.1106\n", 477 | "[3] time=1.36, avg_loss=2.5222, train_err=0.1261, 32_h1=0.1238, 32_l2=0.0800, 64_h1=0.2026, 64_l2=0.0936\n", 478 | "[4] time=1.36, avg_loss=2.3043, train_err=0.1152, 32_h1=0.1235, 32_l2=0.0808, 64_h1=0.1874, 64_l2=0.0858\n", 479 | "[5] time=1.36, avg_loss=2.2108, train_err=0.1105, 32_h1=0.1332, 32_l2=0.1041, 64_h1=0.2055, 64_l2=0.1122\n", 480 | "[6] time=1.37, avg_loss=1.9753, train_err=0.0988, 32_h1=0.1077, 32_l2=0.0720, 64_h1=0.1885, 64_l2=0.0838\n", 481 | "[7] time=1.37, avg_loss=1.9352, train_err=0.0968, 32_h1=0.1032, 32_l2=0.0642, 64_h1=0.1847, 64_l2=0.0753\n", 482 | "[8] time=1.36, avg_loss=1.8174, train_err=0.0909, 32_h1=0.1013, 32_l2=0.0632, 64_h1=0.1798, 64_l2=0.0763\n", 483 | "[9] time=1.37, avg_loss=1.7847, train_err=0.0892, 32_h1=0.1053, 32_l2=0.0672, 64_h1=0.1909, 64_l2=0.0788\n", 484 | "[10] time=1.37, avg_loss=1.6375, train_err=0.0819, 32_h1=0.0926, 32_l2=0.0513, 64_h1=0.1808, 64_l2=0.0666\n", 485 | "[11] time=1.37, avg_loss=1.5826, train_err=0.0791, 32_h1=0.0958, 32_l2=0.0574, 64_h1=0.1810, 64_l2=0.0700\n", 486 | "[12] time=1.36, avg_loss=1.6231, train_err=0.0812, 32_h1=0.0940, 32_l2=0.0534, 64_h1=0.1740, 64_l2=0.0636\n", 487 | "[13] time=1.42, avg_loss=1.5427, train_err=0.0771, 32_h1=0.0937, 32_l2=0.0532, 64_h1=0.1834, 64_l2=0.0692\n", 488 | "[14] time=1.37, avg_loss=1.4741, train_err=0.0737, 32_h1=0.0989, 32_l2=0.0623, 64_h1=0.1844, 64_l2=0.0798\n", 489 | "[15] time=1.36, avg_loss=1.5156, train_err=0.0758, 32_h1=0.1020, 32_l2=0.0649, 64_h1=0.1844, 64_l2=0.0730\n", 490 | "[16] time=1.37, avg_loss=1.5620, train_err=0.0781, 32_h1=0.0940, 32_l2=0.0608, 64_h1=0.1803, 64_l2=0.0747\n", 491 | "[17] time=1.36, avg_loss=1.3939, train_err=0.0697, 32_h1=0.1018, 32_l2=0.0620, 64_h1=0.1842, 64_l2=0.0772\n", 492 | "[18] time=1.89, avg_loss=1.4904, train_err=0.0745, 32_h1=0.1010, 
32_l2=0.0704, 64_h1=0.1868, 64_l2=0.0794\n", 493 | "[19] time=1.83, avg_loss=1.4300, train_err=0.0715, 32_h1=0.0929, 32_l2=0.0525, 64_h1=0.1784, 64_l2=0.0679\n", 494 | "[20] time=1.84, avg_loss=1.3752, train_err=0.0688, 32_h1=0.0964, 32_l2=0.0635, 64_h1=0.1825, 64_l2=0.0694\n", 495 | "[21] time=1.84, avg_loss=1.4671, train_err=0.0734, 32_h1=0.0911, 32_l2=0.0513, 64_h1=0.1832, 64_l2=0.0696\n", 496 | "[22] time=1.88, avg_loss=1.3043, train_err=0.0652, 32_h1=0.0938, 32_l2=0.0538, 64_h1=0.1804, 64_l2=0.0687\n", 497 | "[23] time=1.37, avg_loss=1.2880, train_err=0.0644, 32_h1=0.0897, 32_l2=0.0492, 64_h1=0.1824, 64_l2=0.0629\n", 498 | "[24] time=1.37, avg_loss=1.3901, train_err=0.0695, 32_h1=0.1080, 32_l2=0.0701, 64_h1=0.1828, 64_l2=0.0785\n", 499 | "[25] time=1.37, avg_loss=1.3788, train_err=0.0689, 32_h1=0.0878, 32_l2=0.0514, 64_h1=0.1744, 64_l2=0.0613\n", 500 | "[26] time=1.37, avg_loss=1.3071, train_err=0.0654, 32_h1=0.0880, 32_l2=0.0489, 64_h1=0.1847, 64_l2=0.0698\n", 501 | "[27] time=1.36, avg_loss=1.3056, train_err=0.0653, 32_h1=0.0980, 32_l2=0.0679, 64_h1=0.1828, 64_l2=0.0830\n", 502 | "[28] time=1.37, avg_loss=1.2677, train_err=0.0634, 32_h1=0.0956, 32_l2=0.0621, 64_h1=0.1827, 64_l2=0.0692\n", 503 | "[29] time=1.37, avg_loss=1.2611, train_err=0.0631, 32_h1=0.0913, 32_l2=0.0500, 64_h1=0.1855, 64_l2=0.0652\n", 504 | "[30] time=1.37, avg_loss=1.1833, train_err=0.0592, 32_h1=0.0888, 32_l2=0.0512, 64_h1=0.1818, 64_l2=0.0655\n", 505 | "[31] time=1.36, avg_loss=1.2170, train_err=0.0608, 32_h1=0.0879, 32_l2=0.0481, 64_h1=0.1758, 64_l2=0.0625\n", 506 | "[32] time=1.36, avg_loss=1.1431, train_err=0.0572, 32_h1=0.0886, 32_l2=0.0479, 64_h1=0.1756, 64_l2=0.0594\n", 507 | "[33] time=1.37, avg_loss=1.2162, train_err=0.0608, 32_h1=0.0923, 32_l2=0.0522, 64_h1=0.1749, 64_l2=0.0629\n", 508 | "[34] time=1.37, avg_loss=1.1588, train_err=0.0579, 32_h1=0.0892, 32_l2=0.0526, 64_h1=0.1797, 64_l2=0.0656\n", 509 | "[35] time=1.37, avg_loss=1.1747, train_err=0.0587, 32_h1=0.0884, 
32_l2=0.0481, 64_h1=0.1829, 64_l2=0.0650\n", 510 | "[36] time=1.36, avg_loss=1.1491, train_err=0.0575, 32_h1=0.0936, 32_l2=0.0542, 64_h1=0.1787, 64_l2=0.0672\n", 511 | "[37] time=1.37, avg_loss=1.1532, train_err=0.0577, 32_h1=0.0950, 32_l2=0.0569, 64_h1=0.1737, 64_l2=0.0679\n", 512 | "[38] time=1.37, avg_loss=1.2426, train_err=0.0621, 32_h1=0.0875, 32_l2=0.0488, 64_h1=0.1750, 64_l2=0.0638\n", 513 | "[39] time=1.37, avg_loss=1.1345, train_err=0.0567, 32_h1=0.0874, 32_l2=0.0493, 64_h1=0.1780, 64_l2=0.0658\n", 514 | "[40] time=1.36, avg_loss=1.1238, train_err=0.0562, 32_h1=0.0914, 32_l2=0.0516, 64_h1=0.1796, 64_l2=0.0662\n", 515 | "[41] time=1.36, avg_loss=1.1093, train_err=0.0555, 32_h1=0.0855, 32_l2=0.0457, 64_h1=0.1741, 64_l2=0.0621\n", 516 | "[42] time=1.36, avg_loss=1.0772, train_err=0.0539, 32_h1=0.0899, 32_l2=0.0523, 64_h1=0.1807, 64_l2=0.0688\n", 517 | "[43] time=1.36, avg_loss=1.0772, train_err=0.0539, 32_h1=0.0894, 32_l2=0.0556, 64_h1=0.1769, 64_l2=0.0705\n", 518 | "[44] time=1.36, avg_loss=1.0901, train_err=0.0545, 32_h1=0.0843, 32_l2=0.0443, 64_h1=0.1750, 64_l2=0.0589\n", 519 | "[45] time=1.36, avg_loss=1.0783, train_err=0.0539, 32_h1=0.0874, 32_l2=0.0486, 64_h1=0.1778, 64_l2=0.0593\n", 520 | "[46] time=1.46, avg_loss=1.0837, train_err=0.0542, 32_h1=0.0874, 32_l2=0.0482, 64_h1=0.1722, 64_l2=0.0575\n", 521 | "[47] time=1.37, avg_loss=1.1760, train_err=0.0588, 32_h1=0.0873, 32_l2=0.0507, 64_h1=0.1706, 64_l2=0.0639\n", 522 | "[48] time=1.37, avg_loss=1.0357, train_err=0.0518, 32_h1=0.0889, 32_l2=0.0503, 64_h1=0.1799, 64_l2=0.0663\n", 523 | "[49] time=1.36, avg_loss=1.0873, train_err=0.0544, 32_h1=0.0846, 32_l2=0.0464, 64_h1=0.1725, 64_l2=0.0592\n", 524 | "[50] time=1.36, avg_loss=1.0996, train_err=0.0550, 32_h1=0.0861, 32_l2=0.0461, 64_h1=0.1696, 64_l2=0.0598\n", 525 | "[51] time=1.37, avg_loss=1.0487, train_err=0.0524, 32_h1=0.0839, 32_l2=0.0433, 64_h1=0.1752, 64_l2=0.0602\n", 526 | "[52] time=1.37, avg_loss=1.0527, train_err=0.0526, 32_h1=0.0858, 
32_l2=0.0469, 64_h1=0.1736, 64_l2=0.0588\n", 527 | "[53] time=1.36, avg_loss=1.0138, train_err=0.0507, 32_h1=0.0854, 32_l2=0.0475, 64_h1=0.1777, 64_l2=0.0619\n", 528 | "[54] time=1.36, avg_loss=1.0210, train_err=0.0511, 32_h1=0.0832, 32_l2=0.0431, 64_h1=0.1728, 64_l2=0.0580\n", 529 | "[55] time=1.36, avg_loss=0.9939, train_err=0.0497, 32_h1=0.0870, 32_l2=0.0474, 64_h1=0.1755, 64_l2=0.0609\n", 530 | "[56] time=1.37, avg_loss=1.0085, train_err=0.0504, 32_h1=0.0833, 32_l2=0.0438, 64_h1=0.1731, 64_l2=0.0603\n", 531 | "[57] time=1.37, avg_loss=1.0132, train_err=0.0507, 32_h1=0.0842, 32_l2=0.0462, 64_h1=0.1757, 64_l2=0.0613\n", 532 | "[58] time=1.36, avg_loss=0.9938, train_err=0.0497, 32_h1=0.0839, 32_l2=0.0439, 64_h1=0.1811, 64_l2=0.0651\n", 533 | "[59] time=1.36, avg_loss=0.9814, train_err=0.0491, 32_h1=0.0820, 32_l2=0.0425, 64_h1=0.1728, 64_l2=0.0565\n", 534 | "[60] time=1.36, avg_loss=0.9849, train_err=0.0492, 32_h1=0.0861, 32_l2=0.0477, 64_h1=0.1715, 64_l2=0.0616\n", 535 | "[61] time=1.37, avg_loss=0.9787, train_err=0.0489, 32_h1=0.0844, 32_l2=0.0450, 64_h1=0.1742, 64_l2=0.0623\n", 536 | "[62] time=1.36, avg_loss=1.0104, train_err=0.0505, 32_h1=0.0830, 32_l2=0.0437, 64_h1=0.1769, 64_l2=0.0605\n", 537 | "[63] time=1.36, avg_loss=0.9910, train_err=0.0495, 32_h1=0.0821, 32_l2=0.0415, 64_h1=0.1742, 64_l2=0.0579\n", 538 | "[64] time=1.36, avg_loss=0.9622, train_err=0.0481, 32_h1=0.0849, 32_l2=0.0462, 64_h1=0.1763, 64_l2=0.0608\n", 539 | "[65] time=1.36, avg_loss=1.0191, train_err=0.0510, 32_h1=0.0823, 32_l2=0.0419, 64_h1=0.1736, 64_l2=0.0570\n", 540 | "[66] time=1.37, avg_loss=0.9814, train_err=0.0491, 32_h1=0.0873, 32_l2=0.0492, 64_h1=0.1752, 64_l2=0.0643\n", 541 | "[67] time=1.36, avg_loss=0.9867, train_err=0.0493, 32_h1=0.0833, 32_l2=0.0446, 64_h1=0.1698, 64_l2=0.0588\n", 542 | "[68] time=1.36, avg_loss=0.9983, train_err=0.0499, 32_h1=0.0815, 32_l2=0.0417, 64_h1=0.1712, 64_l2=0.0590\n", 543 | "[69] time=1.37, avg_loss=0.9956, train_err=0.0498, 32_h1=0.0836, 
32_l2=0.0453, 64_h1=0.1756, 64_l2=0.0604\n", 544 | "[70] time=1.37, avg_loss=0.9433, train_err=0.0472, 32_h1=0.0830, 32_l2=0.0432, 64_h1=0.1739, 64_l2=0.0583\n", 545 | "[71] time=1.36, avg_loss=0.9813, train_err=0.0491, 32_h1=0.0830, 32_l2=0.0433, 64_h1=0.1691, 64_l2=0.0588\n", 546 | "[72] time=1.36, avg_loss=0.9456, train_err=0.0473, 32_h1=0.0828, 32_l2=0.0429, 64_h1=0.1695, 64_l2=0.0599\n", 547 | "[73] time=1.37, avg_loss=0.9099, train_err=0.0455, 32_h1=0.0835, 32_l2=0.0438, 64_h1=0.1716, 64_l2=0.0599\n", 548 | "[74] time=1.37, avg_loss=0.9241, train_err=0.0462, 32_h1=0.0816, 32_l2=0.0419, 64_h1=0.1699, 64_l2=0.0572\n", 549 | "[75] time=1.37, avg_loss=0.8907, train_err=0.0445, 32_h1=0.0825, 32_l2=0.0410, 64_h1=0.1772, 64_l2=0.0604\n", 550 | "[76] time=1.36, avg_loss=0.8940, train_err=0.0447, 32_h1=0.0821, 32_l2=0.0428, 64_h1=0.1733, 64_l2=0.0588\n", 551 | "[77] time=1.37, avg_loss=0.8958, train_err=0.0448, 32_h1=0.0828, 32_l2=0.0447, 64_h1=0.1756, 64_l2=0.0593\n", 552 | "[78] time=1.37, avg_loss=0.9276, train_err=0.0464, 32_h1=0.0816, 32_l2=0.0424, 64_h1=0.1740, 64_l2=0.0599\n", 553 | "[79] time=1.37, avg_loss=0.8763, train_err=0.0438, 32_h1=0.0818, 32_l2=0.0414, 64_h1=0.1715, 64_l2=0.0570\n", 554 | "[80] time=1.36, avg_loss=0.8634, train_err=0.0432, 32_h1=0.0812, 32_l2=0.0416, 64_h1=0.1753, 64_l2=0.0614\n", 555 | "[81] time=1.36, avg_loss=0.8450, train_err=0.0423, 32_h1=0.0832, 32_l2=0.0448, 64_h1=0.1701, 64_l2=0.0626\n", 556 | "[82] time=1.37, avg_loss=0.8997, train_err=0.0450, 32_h1=0.0818, 32_l2=0.0419, 64_h1=0.1718, 64_l2=0.0590\n", 557 | "[83] time=1.37, avg_loss=0.8658, train_err=0.0433, 32_h1=0.0816, 32_l2=0.0415, 64_h1=0.1703, 64_l2=0.0552\n", 558 | "[84] time=1.37, avg_loss=0.9292, train_err=0.0465, 32_h1=0.0815, 32_l2=0.0424, 64_h1=0.1674, 64_l2=0.0580\n", 559 | "[85] time=1.36, avg_loss=0.9417, train_err=0.0471, 32_h1=0.0825, 32_l2=0.0439, 64_h1=0.1755, 64_l2=0.0608\n", 560 | "[86] time=1.37, avg_loss=0.8608, train_err=0.0430, 32_h1=0.0792, 
32_l2=0.0392, 64_h1=0.1720, 64_l2=0.0573\n", 561 | "[87] time=1.38, avg_loss=0.9083, train_err=0.0454, 32_h1=0.0822, 32_l2=0.0440, 64_h1=0.1693, 64_l2=0.0602\n", 562 | "[88] time=1.57, avg_loss=0.8522, train_err=0.0426, 32_h1=0.0823, 32_l2=0.0427, 64_h1=0.1695, 64_l2=0.0571\n", 563 | "[89] time=1.36, avg_loss=0.8273, train_err=0.0414, 32_h1=0.0813, 32_l2=0.0414, 64_h1=0.1702, 64_l2=0.0568\n", 564 | "[90] time=1.36, avg_loss=0.8612, train_err=0.0431, 32_h1=0.0834, 32_l2=0.0468, 64_h1=0.1718, 64_l2=0.0641\n", 565 | "[91] time=1.36, avg_loss=0.8358, train_err=0.0418, 32_h1=0.0811, 32_l2=0.0410, 64_h1=0.1678, 64_l2=0.0558\n", 566 | "[92] time=1.37, avg_loss=0.8725, train_err=0.0436, 32_h1=0.0807, 32_l2=0.0408, 64_h1=0.1688, 64_l2=0.0557\n", 567 | "[93] time=1.36, avg_loss=0.8163, train_err=0.0408, 32_h1=0.0804, 32_l2=0.0417, 64_h1=0.1714, 64_l2=0.0593\n", 568 | "[94] time=1.36, avg_loss=0.8119, train_err=0.0406, 32_h1=0.0791, 32_l2=0.0393, 64_h1=0.1706, 64_l2=0.0581\n", 569 | "[95] time=1.36, avg_loss=0.8022, train_err=0.0401, 32_h1=0.0819, 32_l2=0.0416, 64_h1=0.1697, 64_l2=0.0555\n", 570 | "[96] time=1.37, avg_loss=0.8371, train_err=0.0419, 32_h1=0.0793, 32_l2=0.0393, 64_h1=0.1684, 64_l2=0.0570\n", 571 | "[97] time=1.37, avg_loss=0.8227, train_err=0.0411, 32_h1=0.0800, 32_l2=0.0407, 64_h1=0.1685, 64_l2=0.0583\n", 572 | "[98] time=1.43, avg_loss=0.8176, train_err=0.0409, 32_h1=0.0841, 32_l2=0.0471, 64_h1=0.1681, 64_l2=0.0578\n", 573 | "[99] time=1.85, avg_loss=0.8517, train_err=0.0426, 32_h1=0.0809, 32_l2=0.0401, 64_h1=0.1726, 64_l2=0.0607\n", 574 | "[100] time=1.85, avg_loss=0.8445, train_err=0.0422, 32_h1=0.0810, 32_l2=0.0408, 64_h1=0.1688, 64_l2=0.0558\n", 575 | "[101] time=1.42, avg_loss=0.7962, train_err=0.0398, 32_h1=0.0796, 32_l2=0.0393, 64_h1=0.1680, 64_l2=0.0577\n", 576 | "[102] time=1.84, avg_loss=0.7758, train_err=0.0388, 32_h1=0.0799, 32_l2=0.0398, 64_h1=0.1664, 64_l2=0.0556\n", 577 | "[103] time=1.87, avg_loss=0.8005, train_err=0.0400, 32_h1=0.0792, 
32_l2=0.0395, 64_h1=0.1688, 64_l2=0.0552\n", 578 | "[104] time=1.43, avg_loss=0.8099, train_err=0.0405, 32_h1=0.0791, 32_l2=0.0394, 64_h1=0.1664, 64_l2=0.0535\n", 579 | "[105] time=1.37, avg_loss=0.7828, train_err=0.0391, 32_h1=0.0815, 32_l2=0.0430, 64_h1=0.1691, 64_l2=0.0574\n", 580 | "[106] time=1.37, avg_loss=0.7799, train_err=0.0390, 32_h1=0.0795, 32_l2=0.0393, 64_h1=0.1679, 64_l2=0.0556\n", 581 | "[107] time=1.36, avg_loss=0.7685, train_err=0.0384, 32_h1=0.0810, 32_l2=0.0434, 64_h1=0.1725, 64_l2=0.0633\n", 582 | "[108] time=1.36, avg_loss=0.7581, train_err=0.0379, 32_h1=0.0801, 32_l2=0.0407, 64_h1=0.1744, 64_l2=0.0574\n", 583 | "[109] time=1.37, avg_loss=0.7415, train_err=0.0371, 32_h1=0.0782, 32_l2=0.0383, 64_h1=0.1670, 64_l2=0.0540\n", 584 | "[110] time=1.37, avg_loss=0.7387, train_err=0.0369, 32_h1=0.0790, 32_l2=0.0392, 64_h1=0.1664, 64_l2=0.0539\n", 585 | "[111] time=1.37, avg_loss=0.7338, train_err=0.0367, 32_h1=0.0788, 32_l2=0.0385, 64_h1=0.1694, 64_l2=0.0574\n", 586 | "[112] time=1.36, avg_loss=0.7426, train_err=0.0371, 32_h1=0.0811, 32_l2=0.0434, 64_h1=0.1745, 64_l2=0.0593\n", 587 | "[113] time=1.36, avg_loss=0.7849, train_err=0.0392, 32_h1=0.0817, 32_l2=0.0452, 64_h1=0.1653, 64_l2=0.0627\n", 588 | "[114] time=1.37, avg_loss=0.7933, train_err=0.0397, 32_h1=0.0803, 32_l2=0.0409, 64_h1=0.1715, 64_l2=0.0568\n", 589 | "[115] time=1.37, avg_loss=0.7377, train_err=0.0369, 32_h1=0.0789, 32_l2=0.0389, 64_h1=0.1688, 64_l2=0.0556\n", 590 | "[116] time=1.37, avg_loss=0.7639, train_err=0.0382, 32_h1=0.0794, 32_l2=0.0394, 64_h1=0.1683, 64_l2=0.0574\n", 591 | "[117] time=1.36, avg_loss=0.7515, train_err=0.0376, 32_h1=0.0785, 32_l2=0.0382, 64_h1=0.1665, 64_l2=0.0549\n", 592 | "[118] time=1.37, avg_loss=0.7180, train_err=0.0359, 32_h1=0.0792, 32_l2=0.0394, 64_h1=0.1671, 64_l2=0.0576\n", 593 | "[119] time=1.37, avg_loss=0.7191, train_err=0.0360, 32_h1=0.0795, 32_l2=0.0396, 64_h1=0.1672, 64_l2=0.0541\n", 594 | "[120] time=1.37, avg_loss=0.7148, train_err=0.0357, 
32_h1=0.0792, 32_l2=0.0389, 64_h1=0.1671, 64_l2=0.0575\n", 595 | "[121] time=1.36, avg_loss=0.7012, train_err=0.0351, 32_h1=0.0795, 32_l2=0.0399, 64_h1=0.1639, 64_l2=0.0555\n", 596 | "[122] time=1.37, avg_loss=0.6962, train_err=0.0348, 32_h1=0.0787, 32_l2=0.0388, 64_h1=0.1697, 64_l2=0.0570\n", 597 | "[123] time=1.37, avg_loss=0.6970, train_err=0.0349, 32_h1=0.0793, 32_l2=0.0388, 64_h1=0.1693, 64_l2=0.0567\n", 598 | "[124] time=1.37, avg_loss=0.6888, train_err=0.0344, 32_h1=0.0788, 32_l2=0.0382, 64_h1=0.1687, 64_l2=0.0570\n", 599 | "[125] time=1.37, avg_loss=0.7060, train_err=0.0353, 32_h1=0.0799, 32_l2=0.0412, 64_h1=0.1649, 64_l2=0.0576\n", 600 | "[126] time=1.36, avg_loss=0.6991, train_err=0.0350, 32_h1=0.0792, 32_l2=0.0393, 64_h1=0.1681, 64_l2=0.0583\n", 601 | "[127] time=1.37, avg_loss=0.7098, train_err=0.0355, 32_h1=0.0796, 32_l2=0.0406, 64_h1=0.1641, 64_l2=0.0574\n", 602 | "[128] time=1.37, avg_loss=0.6971, train_err=0.0349, 32_h1=0.0792, 32_l2=0.0399, 64_h1=0.1690, 64_l2=0.0588\n", 603 | "[129] time=1.37, avg_loss=0.6810, train_err=0.0340, 32_h1=0.0793, 32_l2=0.0393, 64_h1=0.1648, 64_l2=0.0559\n", 604 | "[130] time=1.36, avg_loss=0.6848, train_err=0.0342, 32_h1=0.0780, 32_l2=0.0378, 64_h1=0.1670, 64_l2=0.0536\n", 605 | "[131] time=1.94, avg_loss=0.6600, train_err=0.0330, 32_h1=0.0779, 32_l2=0.0379, 64_h1=0.1661, 64_l2=0.0545\n", 606 | "[132] time=1.87, avg_loss=0.6428, train_err=0.0321, 32_h1=0.0794, 32_l2=0.0394, 64_h1=0.1695, 64_l2=0.0588\n", 607 | "[133] time=1.85, avg_loss=0.6532, train_err=0.0327, 32_h1=0.0789, 32_l2=0.0392, 64_h1=0.1690, 64_l2=0.0568\n", 608 | "[134] time=1.88, avg_loss=0.6573, train_err=0.0329, 32_h1=0.0780, 32_l2=0.0376, 64_h1=0.1664, 64_l2=0.0559\n", 609 | "[135] time=1.78, avg_loss=0.6445, train_err=0.0322, 32_h1=0.0784, 32_l2=0.0386, 64_h1=0.1644, 64_l2=0.0560\n", 610 | "[136] time=1.82, avg_loss=0.6378, train_err=0.0319, 32_h1=0.0780, 32_l2=0.0383, 64_h1=0.1651, 64_l2=0.0527\n", 611 | "[137] time=1.85, avg_loss=0.6550, 
train_err=0.0327, 32_h1=0.0792, 32_l2=0.0400, 64_h1=0.1652, 64_l2=0.0551\n", 612 | "[138] time=1.75, avg_loss=0.6341, train_err=0.0317, 32_h1=0.0775, 32_l2=0.0376, 64_h1=0.1661, 64_l2=0.0556\n", 613 | "[139] time=2.00, avg_loss=0.8234, train_err=0.0412, 32_h1=0.0783, 32_l2=0.0386, 64_h1=0.1654, 64_l2=0.0566\n", 614 | "[140] time=1.97, avg_loss=0.6822, train_err=0.0341, 32_h1=0.0784, 32_l2=0.0380, 64_h1=0.1675, 64_l2=0.0558\n", 615 | "[141] time=1.98, avg_loss=0.6332, train_err=0.0317, 32_h1=0.0778, 32_l2=0.0379, 64_h1=0.1670, 64_l2=0.0546\n", 616 | "[142] time=1.99, avg_loss=0.6205, train_err=0.0310, 32_h1=0.0786, 32_l2=0.0394, 64_h1=0.1678, 64_l2=0.0592\n", 617 | "[143] time=1.82, avg_loss=0.6098, train_err=0.0305, 32_h1=0.0785, 32_l2=0.0386, 64_h1=0.1676, 64_l2=0.0584\n", 618 | "[144] time=1.38, avg_loss=0.6116, train_err=0.0306, 32_h1=0.0794, 32_l2=0.0422, 64_h1=0.1702, 64_l2=0.0591\n", 619 | "[145] time=1.37, avg_loss=0.6018, train_err=0.0301, 32_h1=0.0776, 32_l2=0.0373, 64_h1=0.1674, 64_l2=0.0564\n", 620 | "[146] time=1.38, avg_loss=0.6001, train_err=0.0300, 32_h1=0.0781, 32_l2=0.0386, 64_h1=0.1662, 64_l2=0.0583\n", 621 | "[147] time=1.38, avg_loss=0.5990, train_err=0.0300, 32_h1=0.0796, 32_l2=0.0416, 64_h1=0.1679, 64_l2=0.0572\n", 622 | "[148] time=1.38, avg_loss=0.6462, train_err=0.0323, 32_h1=0.0802, 32_l2=0.0411, 64_h1=0.1721, 64_l2=0.0580\n", 623 | "[149] time=1.37, avg_loss=0.6152, train_err=0.0308, 32_h1=0.0777, 32_l2=0.0373, 64_h1=0.1688, 64_l2=0.0562\n" 624 | ] 625 | } 626 | ], 627 | "source": [ 628 | "trainer.train(train_loader, test_loaders,\n", 629 | " output_encoder,\n", 630 | " model, \n", 631 | " optimizer,\n", 632 | " scheduler, \n", 633 | " regularizer=False, \n", 634 | " training_loss=train_loss,\n", 635 | " eval_losses=eval_losses)" 636 | ] 637 | }, 638 | { 639 | "cell_type": "markdown", 640 | "id": "1b20be56-d200-44dc-b97b-fca021e353c8", 641 | "metadata": {}, 642 | "source": [ 643 | "# Follow-up questions" 644 | ] 645 | }, 646 | { 647 | 
"cell_type": "markdown", 648 | "id": "9a67e1d5-4b9a-4be3-bff4-fb2a6b152f9c", 649 | "metadata": {}, 650 | "source": [ 651 | "You can now play with the configuration and see how the performance is impacted.\n", 652 | "\n", 653 | "Which parameters do you think will most influence performance? \n", 654 | "Learning rate? Learning schedule? hidden_channels? Number of training samples? \n", 655 | "\n", 656 | "Does your intuition match the results you are getting?" 657 | ] 658 | } 659 | ], 660 | "metadata": { 661 | "kernelspec": { 662 | "display_name": "Python 3 (ipykernel)", 663 | "language": "python", 664 | "name": "python3" 665 | }, 666 | "language_info": { 667 | "codemirror_mode": { 668 | "name": "ipython", 669 | "version": 3 670 | }, 671 | "file_extension": ".py", 672 | "mimetype": "text/x-python", 673 | "name": "python", 674 | "nbconvert_exporter": "python", 675 | "pygments_lexer": "ipython3", 676 | "version": "3.9.15" 677 | } 678 | }, 679 | "nbformat": 4, 680 | "nbformat_minor": 5 681 | } 682 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Caltech AI4Science Bootcamp 2 | 3 | - [Slides](assets/Bootcamp_DLI_Modulus_v2207_CalTech_share.pdf) 4 | - [Reference](assets/2023-Nik.pdf) 5 | 6 | ## Agenda 7 | Day 1: Feb 13th, 2023 8 | 9 | - 12:30 PM - 01:30 PM: Welcome and Introduction to PDEs and Neural Operators 10 | - 01:30 PM - 02:00 PM: Getting started with NVIDIA compute 11 | - 02:00 PM - 02:15 PM: Break 12 | - 02:15 PM - 03:45 PM: Fourier Neural Operator (FNO), NeuralOp library, and learning the Navier-Stokes equations 13 | - [Recording](https://caltech.zoom.us/rec/share/QoUOu2fRVjZohOBTBeeqkvivJWLbCEi15UBqatJwxgCy546A1Dv7VunQEjNKO_97.7Acc_2LX7J17mMem)- passcode: RG$?n2?v 14 | 15 | 16 | Day 2: Feb 14th, 2023 17 | - 12:30 PM - 02:00 PM: Introduction to Modulus, Physics Informed Neural Network (PINN), and Physics Informed Neural Operator 
(PINO) 18 | - 02:00 PM - 02:15 PM: Break 19 | - 02:15 PM - 04:00 PM: Re-visiting the Navier-Stokes equations and Mini project 20 | - [Recording]( 21 | https://caltech.zoom.us/rec/share/1UeB1MZiqhGzkzPKoHAWAX33IpMYniTOs0YkxDAo9lgFHXux4mBBoki5QOrZ6pOy.65onc8Bg6I41Cvga)- passcode: 2he6K@su 22 | 23 | ## Others 24 | - [Bootcamp material](https://github.com/NeuralOperator/bootcamp.git) 25 | - [NeuralOperator repo](https://github.com/NeuralOperator/neuraloperator.git) 26 | 27 | -------------------------------------------------------------------------------- /assets/2023-Nik.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/bootcamp/380749a4d670639806abc82de6eae11fda249ed9/assets/2023-Nik.pdf -------------------------------------------------------------------------------- /assets/Bootcamp_DLI_Modulus_v2207_CalTech_share.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/bootcamp/380749a4d670639806abc82de6eae11fda249ed9/assets/Bootcamp_DLI_Modulus_v2207_CalTech_share.pdf -------------------------------------------------------------------------------- /config/darcy_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | 3 | #General 4 | verbose: True 5 | arch: 'tfno2d' 6 | 7 | #Distributed computing 8 | distributed: 9 | use_distributed: False 10 | 11 | # FNO related 12 | tfno2d: 13 | data_channels: 3 14 | n_modes_height: 32 15 | n_modes_width: 32 16 | hidden_channels: 64 17 | projection_channels: 256 18 | n_layers: 4 19 | domain_padding: None #0.078125 20 | domain_padding_mode: 'one-sided' #symmetric 21 | fft_norm: 'forward' 22 | norm: 'group_norm' 23 | skip: 'linear' 24 | implementation: 'factorized' 25 | separable: 0 26 | preactivation: 0 27 | 28 | use_mlp: 1 29 | mlp: 30 | expansion: 0.5 31 | dropout: 0 32 | 33 | factorization: None 34 | 
rank: 1.0 35 | fixed_rank_modes: None 36 | dropout: 0.0 37 | tensor_lasso_penalty: 0.0 38 | joint_factorization: False 39 | 40 | # Optimizer 41 | opt: 42 | n_epochs: 150 43 | learning_rate: 5e-3 44 | training_loss: 'h1' 45 | weight_decay: 1e-4 46 | amp_autocast: False 47 | 48 | scheduler_T_max: 300 # For CosineAnnealingLR only; typically set to n_epochs 49 | scheduler_patience: 5 # For ReduceLROnPlateau only 50 | scheduler: 'CosineAnnealingLR' # One of 'CosineAnnealingLR', 'ReduceLROnPlateau', 'StepLR' 51 | step_size: 50 52 | gamma: 0.5 53 | 54 | # Dataset related 55 | data: 56 | folder: "/dli/task/bootcamp/data/darcy_flow/" 57 | batch_size: 32 58 | n_train: 3000 59 | train_resolution: 32 60 | n_tests: [500, 500] 61 | test_resolutions: [32, 64] 62 | test_batch_sizes: [32, 32] 63 | positional_encoding: True 64 | 65 | encode_input: True 66 | encode_output: False 67 | 68 | # Patching 69 | patching: 70 | levels: 0 71 | padding: 0 72 | stitching: False 73 | 74 | # Weights and biases 75 | wandb: 76 | log: False 77 | log_test_interval: 1 78 | -------------------------------------------------------------------------------- /config/tfno_darcy_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | 3 | #General 4 | verbose: True 5 | arch: 'tfno2d' 6 | 7 | #Distributed computing 8 | distributed: 9 | use_distributed: False 10 | 11 | # FNO related 12 | tfno2d: 13 | lifting_channels: 32 14 | data_channels: 3 15 | n_modes_height: 12 16 | n_modes_width: 12 17 | hidden_channels: 32 18 | projection_channels: 32 19 | n_layers: 4 20 | domain_padding: None #0.078125 21 | domain_padding_mode: 'one-sided' #symmetric 22 | fft_norm: 'forward' 23 | norm: None 24 | skip: 'linear' 25 | implementation: 'factorized' 26 | separable: 0 27 | preactivation: 0 28 | 29 | use_mlp: 1 30 | mlp: 31 | expansion: 0.5 32 | dropout: 0 33 | 34 | factorization: 'tucker' 35 | rank: 0.2 36 | fixed_rank_modes: None 37 | dropout: 0.0 38 | 
tensor_lasso_penalty: 0.0 
39 | joint_factorization: False 40 | 41 | # Optimizer 42 | opt: 43 | n_epochs: 150 44 | learning_rate: 5e-3 45 | training_loss: 'h1' 46 | weight_decay: 1e-4 47 | amp_autocast: False 48 | 49 | scheduler_T_max: 300 # For CosineAnnealingLR only; typically set to n_epochs 50 | scheduler_patience: 5 # For ReduceLROnPlateau only 51 | scheduler: 'CosineAnnealingLR' # One of 'CosineAnnealingLR', 'ReduceLROnPlateau', 'StepLR' 52 | step_size: 50 53 | gamma: 0.5 54 | 55 | # Dataset related 56 | data: 57 | folder: "/dli/task/bootcamp/data/darcy_flow/" 58 | batch_size: 32 59 | n_train: 3000 60 | train_resolution: 32 61 | n_tests: [500, 500] 62 | test_resolutions: [32, 64] 63 | test_batch_sizes: [32, 32] 64 | positional_encoding: True 65 | 66 | encode_input: True 67 | encode_output: False 68 | 69 | # Patching 70 | patching: 71 | levels: 0 72 | padding: 0 73 | stitching: False 74 | 75 | # Weights and biases 76 | wandb: 77 | log: False 78 | log_test_interval: 1 79 | -------------------------------------------------------------------------------- /images/fourier_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/bootcamp/380749a4d670639806abc82de6eae11fda249ed9/images/fourier_layer.png --------------------------------------------------------------------------------