├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── configs └── config.yaml ├── dataloader.py ├── inference.py ├── model.py ├── models └── model.pt ├── train.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | runs 4 | experiments 5 | div2k_lmdb 6 | outputs 7 | model_timing.py 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Hasnain Raza 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = 'https://pypi.python.org/simple' 3 | verify_ssl = true 4 | name = 'pypi' 5 | 6 | [requires] 7 | python_version = "3.10" 8 | 9 | 10 | [packages] 11 | torch = "==2.3" 12 | torchvision = "==0.18" 13 | torchmetrics = "==1.4.0" 14 | tensorboard = "==2.16.2" 15 | hydra-core = "==1.3.2" 16 | 17 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "488463251be5feb8b5af32706c4035bfbc4044eb934399d6265535f18bcc5365" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.10" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.python.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "absl-py": { 20 | "hashes": [ 21 | "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", 22 | "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff" 23 | ], 24 | "markers": "python_version >= '3.7'", 25 | "version": "==2.1.0" 26 | }, 27 | "antlr4-python3-runtime": { 28 | "hashes": [ 29 | "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b" 30 | ], 31 | "version": "==4.9.3" 32 | }, 33 | "colorama": { 34 | "hashes": [ 35 | "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", 36 | "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6" 37 | ], 38 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'", 39 | "version": "==0.4.6" 40 | }, 41 | "filelock": { 42 | "hashes": [ 43 | "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f", 44 | 
"sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a" 45 | ], 46 | "markers": "python_version >= '3.8'", 47 | "version": "==3.14.0" 48 | }, 49 | "fsspec": { 50 | "hashes": [ 51 | "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee", 52 | "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2" 53 | ], 54 | "markers": "python_version >= '3.8'", 55 | "version": "==2024.6.0" 56 | }, 57 | "grpcio": { 58 | "hashes": [ 59 | "sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040", 60 | "sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122", 61 | "sha256:0a2813093ddb27418a4c99f9b1c223fab0b053157176a64cc9db0f4557b69bd9", 62 | "sha256:0cc79c982ccb2feec8aad0e8fb0d168bcbca85bc77b080d0d3c5f2f15c24ea8f", 63 | "sha256:1257b76748612aca0f89beec7fa0615727fd6f2a1ad580a9638816a4b2eb18fd", 64 | "sha256:1262402af5a511c245c3ae918167eca57342c72320dffae5d9b51840c4b2f86d", 65 | "sha256:19264fc964576ddb065368cae953f8d0514ecc6cb3da8903766d9fb9d4554c33", 66 | "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762", 67 | "sha256:1de403fc1305fd96cfa75e83be3dee8538f2413a6b1685b8452301c7ba33c294", 68 | "sha256:20405cb8b13fd779135df23fabadc53b86522d0f1cba8cca0e87968587f50650", 69 | "sha256:2981c7365a9353f9b5c864595c510c983251b1ab403e05b1ccc70a3d9541a73b", 70 | "sha256:2c3c1b90ab93fed424e454e93c0ed0b9d552bdf1b0929712b094f5ecfe7a23ad", 71 | "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1", 72 | "sha256:3b64ae304c175671efdaa7ec9ae2cc36996b681eb63ca39c464958396697daff", 73 | "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59", 74 | "sha256:4d6dab6124225496010bd22690f2d9bd35c7cbb267b3f14e7a3eb05c911325d4", 75 | "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027", 76 | "sha256:55697ecec192bc3f2f3cc13a295ab670f51de29884ca9ae6cd6247df55df2502", 77 | "sha256:5841dd1f284bd1b3d8a6eca3a7f062b06f1eec09b184397e1d1d43447e89a7ae", 78 | 
"sha256:58b1041e7c870bb30ee41d3090cbd6f0851f30ae4eb68228955d973d3efa2e61", 79 | "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb", 80 | "sha256:5e56462b05a6f860b72f0fa50dca06d5b26543a4e88d0396259a07dc30f4e5aa", 81 | "sha256:5f8b75f64d5d324c565b263c67dbe4f0af595635bbdd93bb1a88189fc62ed2e5", 82 | "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1", 83 | "sha256:6503b64c8b2dfad299749cad1b595c650c91e5b2c8a1b775380fcf8d2cbba1e9", 84 | "sha256:6c024ffc22d6dc59000faf8ad781696d81e8e38f4078cb0f2630b4a3cf231a90", 85 | "sha256:73819689c169417a4f978e562d24f2def2be75739c4bed1992435d007819da1b", 86 | "sha256:75dbbf415026d2862192fe1b28d71f209e2fd87079d98470db90bebe57b33179", 87 | "sha256:8caee47e970b92b3dd948371230fcceb80d3f2277b3bf7fbd7c0564e7d39068e", 88 | "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a", 89 | "sha256:940e3ec884520155f68a3b712d045e077d61c520a195d1a5932c531f11883489", 90 | "sha256:a011ac6c03cfe162ff2b727bcb530567826cec85eb8d4ad2bfb4bd023287a52d", 91 | "sha256:a3a035c37ce7565b8f4f35ff683a4db34d24e53dc487e47438e434eb3f701b2a", 92 | "sha256:a5e771d0252e871ce194d0fdcafd13971f1aae0ddacc5f25615030d5df55c3a2", 93 | "sha256:ac15b6c2c80a4d1338b04d42a02d376a53395ddf0ec9ab157cbaf44191f3ffdd", 94 | "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb", 95 | "sha256:bac71b4b28bc9af61efcdc7630b166440bbfbaa80940c9a697271b5e1dabbc61", 96 | "sha256:bbc5b1d78a7822b0a84c6f8917faa986c1a744e65d762ef6d8be9d75677af2ca", 97 | "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6", 98 | "sha256:c84ad903d0d94311a2b7eea608da163dace97c5fe9412ea311e72c3684925602", 99 | "sha256:d4d29cc612e1332237877dfa7fe687157973aab1d63bd0f84cf06692f04c0367", 100 | "sha256:e3d9f8d1221baa0ced7ec7322a981e28deb23749c76eeeb3d33e18b72935ab62", 101 | "sha256:e7cd5c1325f6808b8ae31657d281aadb2a51ac11ab081ae335f4f7fc44c1721d", 102 | "sha256:ed6091fa0adcc7e4ff944090cf203a52da35c37a130efa564ded02b7aff63bcd", 
103 | "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22", 104 | "sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309" 105 | ], 106 | "markers": "python_version >= '3.8'", 107 | "version": "==1.64.1" 108 | }, 109 | "hydra-core": { 110 | "hashes": [ 111 | "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", 112 | "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b" 113 | ], 114 | "index": "pypi", 115 | "version": "==1.3.2" 116 | }, 117 | "jinja2": { 118 | "hashes": [ 119 | "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369", 120 | "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d" 121 | ], 122 | "markers": "python_version >= '3.7'", 123 | "version": "==3.1.4" 124 | }, 125 | "lightning-utilities": { 126 | "hashes": [ 127 | "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159", 128 | "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9" 129 | ], 130 | "markers": "python_version >= '3.8'", 131 | "version": "==0.11.2" 132 | }, 133 | "markdown": { 134 | "hashes": [ 135 | "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f", 136 | "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224" 137 | ], 138 | "markers": "python_version >= '3.8'", 139 | "version": "==3.6" 140 | }, 141 | "markupsafe": { 142 | "hashes": [ 143 | "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4", 144 | "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", 145 | "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0", 146 | "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", 147 | "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396", 148 | "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13", 149 | "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", 
150 | "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca", 151 | "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", 152 | "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832", 153 | "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0", 154 | "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", 155 | "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", 156 | "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a", 157 | "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c", 158 | "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff", 159 | "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c", 160 | "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", 161 | "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094", 162 | "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", 163 | "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e", 164 | "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5", 165 | "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", 166 | "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d", 167 | "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", 168 | "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b", 169 | "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", 170 | "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", 171 | "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", 172 | "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144", 173 | "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f", 174 | 
"sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", 175 | "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d", 176 | "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93", 177 | "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", 178 | "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", 179 | "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84", 180 | "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb", 181 | "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", 182 | "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", 183 | "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c", 184 | "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6", 185 | "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd", 186 | "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", 187 | "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1", 188 | "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d", 189 | "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca", 190 | "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a", 191 | "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29", 192 | "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe", 193 | "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798", 194 | "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", 195 | "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", 196 | "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", 197 | "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f", 198 | 
"sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a", 199 | "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178", 200 | "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", 201 | "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79", 202 | "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430", 203 | "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50" 204 | ], 205 | "markers": "python_version >= '3.9'", 206 | "version": "==3.0.2" 207 | }, 208 | "mpmath": { 209 | "hashes": [ 210 | "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", 211 | "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c" 212 | ], 213 | "version": "==1.3.0" 214 | }, 215 | "networkx": { 216 | "hashes": [ 217 | "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9", 218 | "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2" 219 | ], 220 | "markers": "python_version >= '3.10'", 221 | "version": "==3.3" 222 | }, 223 | "numpy": { 224 | "hashes": [ 225 | "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", 226 | "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", 227 | "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20", 228 | "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", 229 | "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", 230 | "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", 231 | "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea", 232 | "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c", 233 | "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71", 234 | "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", 235 | 
"sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be", 236 | "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a", 237 | "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", 238 | "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5", 239 | "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", 240 | "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd", 241 | "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c", 242 | "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e", 243 | "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0", 244 | "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c", 245 | "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a", 246 | "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", 247 | "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", 248 | "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6", 249 | "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", 250 | "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", 251 | "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30", 252 | "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", 253 | "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", 254 | "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", 255 | "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", 256 | "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", 257 | "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764", 258 | "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef", 259 | 
"sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3", 260 | "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f" 261 | ], 262 | "markers": "python_version >= '3.9'", 263 | "version": "==1.26.4" 264 | }, 265 | "omegaconf": { 266 | "hashes": [ 267 | "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", 268 | "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7" 269 | ], 270 | "markers": "python_version >= '3.6'", 271 | "version": "==2.3.0" 272 | }, 273 | "packaging": { 274 | "hashes": [ 275 | "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5", 276 | "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9" 277 | ], 278 | "markers": "python_version >= '3.7'", 279 | "version": "==24.0" 280 | }, 281 | "pillow": { 282 | "hashes": [ 283 | "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c", 284 | "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2", 285 | "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb", 286 | "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d", 287 | "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa", 288 | "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3", 289 | "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1", 290 | "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a", 291 | "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd", 292 | "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8", 293 | "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999", 294 | "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599", 295 | "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936", 296 | "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375", 297 | 
"sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d", 298 | "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b", 299 | "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60", 300 | "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572", 301 | "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3", 302 | "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced", 303 | "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f", 304 | "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b", 305 | "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19", 306 | "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f", 307 | "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d", 308 | "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383", 309 | "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795", 310 | "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355", 311 | "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57", 312 | "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09", 313 | "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b", 314 | "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462", 315 | "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf", 316 | "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f", 317 | "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a", 318 | "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad", 319 | "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9", 320 | "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d", 321 | 
"sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45", 322 | "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994", 323 | "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d", 324 | "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338", 325 | "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463", 326 | "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451", 327 | "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591", 328 | "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c", 329 | "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd", 330 | "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32", 331 | "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9", 332 | "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf", 333 | "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5", 334 | "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828", 335 | "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3", 336 | "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5", 337 | "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2", 338 | "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b", 339 | "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2", 340 | "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475", 341 | "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3", 342 | "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb", 343 | "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef", 344 | "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015", 345 | 
"sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002", 346 | "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170", 347 | "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84", 348 | "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57", 349 | "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f", 350 | "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27", 351 | "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a" 352 | ], 353 | "markers": "python_version >= '3.8'", 354 | "version": "==10.3.0" 355 | }, 356 | "pretty-errors": { 357 | "hashes": [ 358 | "sha256:8ce68ccd99e0f2a099265c8c1f1c23b7c60a15d69bb08816cb336e237d5dc983", 359 | "sha256:a16ba5c752c87c263bf92f8b4b58624e3b1e29271a9391f564f12b86e93c6755" 360 | ], 361 | "version": "==1.2.25" 362 | }, 363 | "protobuf": { 364 | "hashes": [ 365 | "sha256:07f2b9a15255e3cf3f137d884af7972407b556a7a220912b252f26dc3121e6bf", 366 | "sha256:2f83bf341d925650d550b8932b71763321d782529ac0eaf278f5242f513cc04e", 367 | "sha256:56937f97ae0dcf4e220ff2abb1456c51a334144c9960b23597f044ce99c29c89", 368 | "sha256:587be23f1212da7a14a6c65fd61995f8ef35779d4aea9e36aad81f5f3b80aec5", 369 | "sha256:673ad60f1536b394b4fa0bcd3146a4130fcad85bfe3b60eaa86d6a0ace0fa374", 370 | "sha256:744489f77c29174328d32f8921566fb0f7080a2f064c5137b9d6f4b790f9e0c1", 371 | "sha256:7cb65fc8fba680b27cf7a07678084c6e68ee13cab7cace734954c25a43da6d0f", 372 | "sha256:a17f4d664ea868102feaa30a674542255f9f4bf835d943d588440d1f49a3ed15", 373 | "sha256:aabbbcf794fbb4c692ff14ce06780a66d04758435717107c387f12fb477bf0d8", 374 | "sha256:b276e3f477ea1eebff3c2e1515136cfcff5ac14519c45f9b4aa2f6a87ea627c4", 375 | "sha256:f51f33d305e18646f03acfdb343aac15b8115235af98bc9f844bf9446573827b" 376 | ], 377 | "markers": "python_version >= '3.8'", 378 | "version": "==5.27.0" 379 | }, 380 | "pyyaml": { 381 | "hashes": [ 382 | 
"sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5", 383 | "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc", 384 | "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df", 385 | "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741", 386 | "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206", 387 | "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27", 388 | "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595", 389 | "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62", 390 | "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98", 391 | "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696", 392 | "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290", 393 | "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9", 394 | "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d", 395 | "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6", 396 | "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867", 397 | "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47", 398 | "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486", 399 | "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6", 400 | "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3", 401 | "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007", 402 | "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938", 403 | "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0", 404 | "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c", 405 | "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735", 406 | 
"sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d", 407 | "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28", 408 | "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4", 409 | "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba", 410 | "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8", 411 | "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef", 412 | "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5", 413 | "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd", 414 | "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3", 415 | "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0", 416 | "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515", 417 | "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c", 418 | "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c", 419 | "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924", 420 | "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34", 421 | "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43", 422 | "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859", 423 | "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673", 424 | "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54", 425 | "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a", 426 | "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b", 427 | "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab", 428 | "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa", 429 | "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c", 430 | 
"sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585", 431 | "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d", 432 | "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f" 433 | ], 434 | "markers": "python_version >= '3.6'", 435 | "version": "==6.0.1" 436 | }, 437 | "setuptools": { 438 | "hashes": [ 439 | "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4", 440 | "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0" 441 | ], 442 | "markers": "python_version >= '3.8'", 443 | "version": "==70.0.0" 444 | }, 445 | "six": { 446 | "hashes": [ 447 | "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", 448 | "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" 449 | ], 450 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'", 451 | "version": "==1.16.0" 452 | }, 453 | "sympy": { 454 | "hashes": [ 455 | "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88", 456 | "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515" 457 | ], 458 | "markers": "python_version >= '3.8'", 459 | "version": "==1.12.1" 460 | }, 461 | "tensorboard": { 462 | "hashes": [ 463 | "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45" 464 | ], 465 | "index": "pypi", 466 | "markers": "python_version >= '3.9'", 467 | "version": "==2.16.2" 468 | }, 469 | "tensorboard-data-server": { 470 | "hashes": [ 471 | "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", 472 | "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", 473 | "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530" 474 | ], 475 | "markers": "python_version >= '3.7'", 476 | "version": "==0.7.2" 477 | }, 478 | "torch": { 479 | "hashes": [ 480 | "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c", 481 | 
"sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459", 482 | "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061", 483 | "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788", 484 | "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea", 485 | "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6", 486 | "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba", 487 | "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877", 488 | "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5", 489 | "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380", 490 | "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542", 491 | "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410", 492 | "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace", 493 | "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9", 494 | "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73", 495 | "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac", 496 | "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad", 497 | "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80", 498 | "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932", 499 | "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd" 500 | ], 501 | "index": "pypi", 502 | "markers": "python_full_version >= '3.8.0'", 503 | "version": "==2.3.0" 504 | }, 505 | "torchmetrics": { 506 | "hashes": [ 507 | "sha256:0b1e5acdcc9beb05bfe369d3d56cfa5b143f060ebfd6079d19ccc59ba46465b3", 508 | "sha256:18599929a0fff7d4b840a3f9a7700054121850c378caaf7206f4161c0a5dc93c" 509 | ], 510 | "index": "pypi", 511 | "markers": "python_version >= '3.8'", 512 | "version": "==1.4.0" 513 | }, 514 | "torchvision": { 515 | 
"hashes": [ 516 | "sha256:2115a1906c015f5da9ceedc40a983313b0fd6e2c8a17108a92991706f51f6987", 517 | "sha256:36efd87001c6bee2383e043e46a025affb03179747c8f4777b9918527ffce756", 518 | "sha256:3d7955398d4ceaad77c487c2c44f6f7813112402c9bab8cd906d346005891048", 519 | "sha256:493c45f9937dad37aa1b64b14da17c7a589c72b91adc4837d431009cfe29bd53", 520 | "sha256:4c334b3e719ba0a9ba6e15d4aff1178f5e6d029174f346163fed525f0ccfffd3", 521 | "sha256:5337f6acfa1fe959d5cb340d01a00614d6b31ce7a4824ccb95435a85c5273b95", 522 | "sha256:6323f7e5423ff2594d5891863b919deb9d0de95f01c36bf26fbd879036b6ed08", 523 | "sha256:6896a52168befe1105fb3c9335287390ed227e71d1e4ec4d68b62e8a3099fc09", 524 | "sha256:6ad70ddfa879bda5ed886b2518fe562640e0059787cbd65cb2bffa7674541410", 525 | "sha256:75e22ecf44a13b8f95b8ad421c0261282d859c61816badaca1959e073ccdd691", 526 | "sha256:7c770f0f748e0b17f57c0297508d7254f686cdf03fc2e2949f422b20574f4c0f", 527 | "sha256:925d0a82cccf6f986c18b29b4392a942db65cbdb73c13a129c8493822eb9e36f", 528 | "sha256:95b42d0dc599b47a01530c7439a5751e67e45b85e3a67113989cf7c7c70f2039", 529 | "sha256:a964afbc7ddf50a46b941477f6c35729b416deedd139756befd488245e2e226d", 530 | "sha256:b657d052d146f24cb3b2a78219bfc82ae70a9706671c50f632528907d10cccec", 531 | "sha256:bd8e6f3b5beb49965f15c461302488edfa3d8c2d01d3bb79b150d6fb62711e3a", 532 | "sha256:ccc292e093771d5baacf5535ac4416306b6b5f15676341cd4d010d8542eace25", 533 | "sha256:dd61628a3d189c6852a12dc5ed4cd2eece66d2d67f35a866cb16f1dcb06c8c62", 534 | "sha256:e5a24d620cea14a4bb89f24aa2b506230c0a16a3ada57fc53ad80cfd256a2128", 535 | "sha256:eb9d83c0e1dbb54ecb0fb04c87f786333e3a6fb8b9c400aca7c31081f9aa5707" 536 | ], 537 | "index": "pypi", 538 | "markers": "python_version >= '3.8'", 539 | "version": "==0.18.0" 540 | }, 541 | "typing-extensions": { 542 | "hashes": [ 543 | "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a", 544 | "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1" 545 | ], 546 | "markers": "python_version >= 
'3.8'", 547 | "version": "==4.12.1" 548 | }, 549 | "werkzeug": { 550 | "hashes": [ 551 | "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17", 552 | "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d" 553 | ], 554 | "index": "pypi", 555 | "markers": "python_version >= '3.8'", 556 | "version": "==3.0.6" 557 | } 558 | }, 559 | "develop": {} 560 | } 561 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-SRGAN 2 | The goal of this repository is to enable real time super resolution for upsampling low resolution videos. Currently, the design follows the [SR-GAN](https://arxiv.org/pdf/1609.04802.pdf) architecture. For speed, the upsampling is done through pixel shuffle. 3 | 4 | The training setup looks like the following diagram: 5 | 6 |

7 | 8 |

9 | 10 | # Speed Benchmarks 11 | The following runtimes/fps are obtained by averaging runtimes over 800 frames. Measured on MPS (MacBook M1 Pro GPU). 12 | 13 | | Input Image Size | Output Size | Time (s) | FPS | 14 | | ------------- |:--------------------:|:---------:|:---:| 15 | | 90x160 | 360x640 (360p) | 0.01 | 82 | 16 | | 180x320 | 720x1280 (720p) | 0.04 | 27 | 17 | 18 | This shows that upsampling to 720p at close to 30 fps is possible. 19 | 20 | # Requirements 21 | This was tested on Python 3.10. To install the required packages, use the provided Pipfile: 22 | ```bash 23 | pip install pipenv --upgrade 24 | pipenv install --system --deploy 25 | ``` 26 | 27 | # Pre-trained Model 28 | A generator model pretrained on the DIV2K dataset is provided in the 'models' directory. It uses 8 residual blocks, with 64 filters in every layer of the generator. 29 | 30 | 31 | To try out the provided pretrained model on your own images, run the following: 32 | 33 | ```bash 34 | python inference.py --image_dir 'path/to/your/image/directory' --output_dir 'path/to/save/super/resolution/images' 35 | ``` 36 | 37 | # Training 38 | To train, edit the config file `configs/config.yaml` with your settings, and then launch the training with: 39 | ```bash 40 | python train.py 41 | ``` 42 | 43 | You can also change the config parameters from the command line. The following will run training with a `batch_size` of 32, a generator with 12 residual blocks, and a path to the image directory `/path/to/image/dataset`. 44 | ``` 45 | python train.py data.image_dir="/path/to/image/dataset" training.batch_size=32 generator.n_layers=12 46 | 47 | ``` 48 | This is powered by `hydra`, which means all the parameters in the config are editable via the CLI. 49 | 50 | Model checkpoints are saved as `.pt` files, and training summaries are written as TensorBoard logs. To monitor training progress, launch TensorBoard and point it to the `outputs` directory that will be created when you start training.
51 | 52 | # Samples 53 | Following are some results from the provided trained model. Left shows the low res image, after 4x bicubic upsampling. Middle is the output of the model. Right is the actual high resolution image. 54 | 55 |

56 | For comparison, the following shows images upsampled 4x by bicubic interpolation, the output of the pretrained model from this repository, and the original high resolution image. 57 | 58 | 59 | 60 | 61 |

62 | 63 | # Contributing 64 | If you have ideas on improving model performance, adding metrics, or any other changes, please make a pull request or open an issue. I'd be happy to accept any contributions. 65 | 66 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | name: "SRGAN" 3 | seed: 1234 4 | data: 5 | image_dir: "/Users/hasnain.raza/Datasets/DIV2K" 6 | numpy_dir: "/Users/hasnain.raza/Datasets/div2k_np" 7 | lr_image_size: 24 8 | scale_factor: 4 9 | generator: 10 | n_filters: 64 11 | n_layers: 8 12 | discriminator: 13 | n_filters: 64 14 | n_layers: 7 15 | training: 16 | compiled: false 17 | pretrain_iterations: 100 18 | iterations: 100 19 | device: mps 20 | log_iter: 5000 21 | checkpoint_iter: 5000 22 | batch_size: 24 23 | num_workers: 16 24 | generator_lr: 1e-4 25 | discriminator_lr: 1e-4 26 | 27 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms import v2 7 | 8 | 9 | class NumpyImagesDataset(Dataset): 10 | 11 | def __init__(self, numpy_paths, lr_image_size, scale_factor): 12 | self.numpy_paths = numpy_paths 13 | self.lr_image_size = lr_image_size 14 | self.hr_image_size = lr_image_size * scale_factor 15 | self.resize = v2.Resize( 16 | (self.lr_image_size, self.lr_image_size), 17 | antialias=True, 18 | interpolation=v2.InterpolationMode.BICUBIC, 19 | ) 20 | 21 | def __len__(self): 22 | return len(self.numpy_paths) 23 | 24 | def __getitem__(self, idx): 25 | image = np.load(self.numpy_paths[idx], mmap_mode="c") 26 | _, h, w = image.shape 27 | crop_h, crop_w = random.randint(0, h - self.hr_image_size), random.randint( 28 | 0, w - self.hr_image_size 29 | 
) 30 | hr_image = image[ 31 | :, crop_h : crop_h + self.hr_image_size, crop_w : crop_w + self.hr_image_size 32 | ] 33 | hr_image = torch.tensor(hr_image, dtype=torch.float32) 34 | lr_image = self.resize(hr_image) 35 | 36 | hr_image = (hr_image / 127.5) - 1.0 37 | lr_image = (lr_image / 127.5) - 1.0 38 | return lr_image, hr_image 39 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | import torch 6 | from omegaconf import OmegaConf 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | from model import Generator 11 | 12 | parser = ArgumentParser("Real Time Image Super Resolution") 13 | parser.add_argument("--image_dir", default=None, required=True, type=str) 14 | parser.add_argument("--output_dir", default=None, required=True, type=str) 15 | 16 | 17 | def main(): 18 | args = parser.parse_args() 19 | os.makedirs(args.output_dir, exist_ok=True) 20 | device = "cpu" 21 | if torch.cuda.is_available(): 22 | device = "cuda" 23 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 24 | device = "mps" 25 | print(f"Using device: {device}") 26 | config = OmegaConf.load("configs/config.yaml") 27 | model_path = "models/model.pt" 28 | model = Generator(config.generator) 29 | weights = torch.load(model_path, map_location="cpu") 30 | new_weights = {} 31 | for k, v in weights.items(): 32 | new_weights[k.replace("_orig_mod.", "")] = v 33 | model.load_state_dict(new_weights) 34 | model.to(device) 35 | model.eval() 36 | 37 | image_paths = sorted( 38 | [ 39 | x 40 | for x in os.listdir(args.image_dir) 41 | if x.lower().endswith(".png") 42 | or x.lower().endswith(".jpg") 43 | or x.lower().endswith("jpeg") 44 | ] 45 | ) 46 | print(f"Found {len(image_paths)} to super resolve, starting...") 47 | for image_path in tqdm(image_paths, 
total=len(image_paths), desc="Super Resolving"): 48 | lr_image = Image.open(os.path.join(args.image_dir, image_path)).convert("RGB") 49 | lr_image = np.array(lr_image) 50 | lr_image = (torch.from_numpy(lr_image) / 127.5) - 1.0 51 | lr_image = lr_image.permute(2, 0, 1).unsqueeze(dim=0).to(device) 52 | with torch.no_grad(): 53 | sr_image = model(lr_image).cpu() 54 | sr_image = (sr_image + 1.0) / 2.0 55 | sr_image = sr_image.permute(0, 2, 3, 1).squeeze() 56 | sr_image = (sr_image * 255).numpy().astype(np.uint8) 57 | Image.fromarray(sr_image).save(os.path.join(args.output_dir, os.path.basename(image_path))) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.vgg import VGG19_Weights, vgg19 3 | 4 | 5 | class VGG19(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:34] 9 | for param in self.vgg.parameters(): 10 | param.requires_grad = False 11 | self.register_buffer( 12 | "mean", 13 | torch.tensor([0.485, 0.456, 0.406], requires_grad=False).view(1, 3, 1, 1), 14 | ) 15 | self.register_buffer( 16 | "std", 17 | torch.tensor([0.229, 0.224, 0.225], requires_grad=False).view(1, 3, 1, 1), 18 | ) 19 | 20 | def forward(self, x): 21 | x = (x + 1.0) / 2.0 22 | x = (x - self.mean) / self.std 23 | return self.vgg(x) 24 | 25 | 26 | class UpSamplingBlock(torch.nn.Module): 27 | 28 | def __init__(self, config): 29 | super().__init__() 30 | self.conv = torch.nn.Conv2d( 31 | in_channels=config.n_filters, 32 | out_channels=config.n_filters * 4, 33 | kernel_size=3, 34 | padding=1, 35 | ) 36 | self.phase_shift = torch.nn.PixelShuffle(upscale_factor=2) 37 | self.relu = torch.nn.PReLU() 38 | 39 | def forward(self, x): 40 | return self.relu(self.phase_shift(self.conv(x))) 41 | 42 | 43 | 
class ResidualBlock(torch.nn.Module): 44 | 45 | def __init__(self, in_channels, out_channels): 46 | super().__init__() 47 | self.conv1 = torch.nn.Conv2d( 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1, 53 | bias=False, 54 | ) 55 | self.bn1 = torch.nn.InstanceNorm2d(out_channels) 56 | self.relu1 = torch.nn.PReLU() 57 | self.conv2 = torch.nn.Conv2d( 58 | in_channels=in_channels, 59 | out_channels=out_channels, 60 | kernel_size=3, 61 | stride=1, 62 | padding=1, 63 | bias=False, 64 | ) 65 | self.bn2 = torch.nn.InstanceNorm2d(out_channels) 66 | 67 | def forward(self, x): 68 | y = self.relu1(self.bn1(self.conv1(x))) 69 | return self.bn2(self.conv2(y)) + x 70 | 71 | 72 | class Generator(torch.nn.Module): 73 | def __init__(self, config): 74 | super().__init__() 75 | self.neck = torch.nn.Sequential( 76 | torch.nn.Conv2d(in_channels=3, out_channels=config.n_filters, kernel_size=3, padding=1), 77 | torch.nn.PReLU(), 78 | ) 79 | self.stem = torch.nn.Sequential( 80 | *[ 81 | ResidualBlock(in_channels=config.n_filters, out_channels=config.n_filters) 82 | for _ in range(config.n_layers) 83 | ] 84 | ) 85 | 86 | self.bottleneck = torch.nn.Sequential( 87 | torch.nn.Conv2d( 88 | in_channels=config.n_filters, 89 | out_channels=config.n_filters, 90 | kernel_size=3, 91 | padding=1, 92 | bias=False, 93 | ), 94 | torch.nn.InstanceNorm2d(config.n_filters), 95 | ) 96 | 97 | self.upsampling = torch.nn.Sequential( 98 | UpSamplingBlock(config), 99 | UpSamplingBlock(config), 100 | ) 101 | 102 | self.head = torch.nn.Sequential( 103 | torch.nn.Conv2d( 104 | in_channels=config.n_filters, 105 | out_channels=3, 106 | kernel_size=3, 107 | padding=1, 108 | ), 109 | torch.nn.Tanh(), 110 | ) 111 | 112 | def forward(self, x): 113 | residual = self.neck(x) 114 | x = self.stem(residual) 115 | x = self.bottleneck(x) + residual 116 | x = self.upsampling(x) 117 | return self.head(x) 118 | 119 | 120 | class SimpleBlock(torch.nn.Module): 121 | 122 
| def __init__(self, in_channels, out_channels, stride): 123 | super().__init__() 124 | self.conv = torch.nn.Conv2d( 125 | in_channels=in_channels, 126 | out_channels=out_channels, 127 | kernel_size=3, 128 | padding=1, 129 | stride=stride, 130 | bias=False, 131 | ) 132 | self.bn = torch.nn.InstanceNorm2d(out_channels) 133 | self.act = torch.nn.LeakyReLU() 134 | 135 | def forward(self, x): 136 | return self.act(self.bn(self.conv(x))) 137 | 138 | 139 | class Discriminator(torch.nn.Module): 140 | def __init__(self, config): 141 | super().__init__() 142 | self.config = config 143 | self.neck = torch.nn.Sequential( 144 | torch.nn.Conv2d(in_channels=3, out_channels=config.n_filters, kernel_size=3, padding=1), 145 | torch.nn.LeakyReLU(negative_slope=0.2), 146 | ) 147 | 148 | layers = [ 149 | SimpleBlock( 150 | in_channels=config.n_filters, 151 | out_channels=config.n_filters, 152 | stride=2, 153 | ), 154 | SimpleBlock( 155 | in_channels=config.n_filters, 156 | out_channels=config.n_filters * 2, 157 | stride=1, 158 | ), 159 | SimpleBlock( 160 | in_channels=config.n_filters * 2, 161 | out_channels=config.n_filters * 2, 162 | stride=2, 163 | ), 164 | SimpleBlock( 165 | in_channels=config.n_filters * 2, 166 | out_channels=config.n_filters * 4, 167 | stride=1, 168 | ), 169 | SimpleBlock( 170 | in_channels=config.n_filters * 4, 171 | out_channels=config.n_filters * 4, 172 | stride=2, 173 | ), 174 | SimpleBlock( 175 | in_channels=config.n_filters * 4, 176 | out_channels=config.n_filters * 8, 177 | stride=1, 178 | ), 179 | SimpleBlock( 180 | in_channels=config.n_filters * 8, 181 | out_channels=config.n_filters * 8, 182 | stride=2, 183 | ), 184 | torch.nn.Conv2d( 185 | in_channels=config.n_filters * 8, out_channels=1, kernel_size=1, padding=0, stride=1 186 | ), 187 | ] 188 | 189 | self.stem = torch.nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.neck(x) 193 | return self.stem(x) 194 | 
-------------------------------------------------------------------------------- /models/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HasnainRaz/Fast-SRGAN/2c09baa662f864ad9197f5b64171e6cc5c37a409/models/model.pt -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | import hydra 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader, RandomSampler 10 | from tqdm import tqdm 11 | 12 | from dataloader import NumpyImagesDataset 13 | from trainer import Trainer 14 | 15 | 16 | def seed(seed): 17 | torch.manual_seed(seed) 18 | np.random.seed(seed) 19 | random.seed(seed) 20 | 21 | 22 | def write_images_to_numpy_arrays(image_list, output_dir): 23 | os.makedirs(output_dir, exist_ok=True) 24 | 25 | def _write_image_to_numpy(image_path, numpy_path): 26 | image = Image.open(image_path).convert("RGB") 27 | image = np.array(image).astype(np.uint8) 28 | image = np.transpose(image, (2, 0, 1)) 29 | np.save(numpy_path, image) 30 | pbar.update(1) 31 | 32 | with tqdm(total=len(image_list)) as pbar: 33 | with ThreadPoolExecutor(max_workers=16) as executor: 34 | for image_path in image_list: 35 | file_name = os.path.basename(image_path).replace(".png", "") 36 | numpy_path = os.path.join(output_dir, file_name) 37 | executor.submit(_write_image_to_numpy, image_path, numpy_path) 38 | 39 | 40 | def seed_worker(_): 41 | worker_seed = torch.initial_seed() % 2**32 42 | np.random.seed(worker_seed) 43 | random.seed(worker_seed) 44 | 45 | 46 | @hydra.main(version_base="1.1", config_path="configs", config_name="config") 47 | def main(config): 48 | if not os.path.exists(config.data.numpy_dir): 49 | write_images_to_numpy_arrays( 50 | [ 51 | 
os.path.join(config.data.image_dir, x) 52 | for x in os.listdir(config.data.image_dir) 53 | if x.endswith(".png") 54 | ], 55 | config.data.numpy_dir, 56 | ) 57 | g = torch.Generator() 58 | g.manual_seed(config.experiment.seed) 59 | seed(config.experiment.seed) 60 | 61 | numpy_files = [ 62 | os.path.join(config.data.numpy_dir, x) 63 | for x in os.listdir(config.data.numpy_dir) 64 | if x.endswith(".npy") 65 | ] 66 | train_dataset = NumpyImagesDataset( 67 | numpy_files, config.data.lr_image_size, config.data.scale_factor 68 | ) 69 | pretrain_sampler = RandomSampler( 70 | train_dataset, 71 | replacement=True, 72 | num_samples=config.training.pretrain_iterations * config.training.batch_size, 73 | generator=g, 74 | ) 75 | train_sampler = RandomSampler( 76 | train_dataset, 77 | replacement=True, 78 | num_samples=config.training.iterations * config.training.batch_size, 79 | generator=g, 80 | ) 81 | val_dataloader = DataLoader( 82 | train_dataset, 83 | batch_size=config.training.batch_size, 84 | num_workers=config.training.num_workers, 85 | drop_last=True, 86 | shuffle=False, 87 | pin_memory=True, 88 | persistent_workers=True, 89 | worker_init_fn=seed_worker, 90 | generator=g, 91 | ) 92 | pretrain_dataloader = DataLoader( 93 | train_dataset, 94 | sampler=pretrain_sampler, 95 | batch_size=config.training.batch_size, 96 | num_workers=config.training.num_workers, 97 | drop_last=True, 98 | persistent_workers=True, 99 | pin_memory=True, 100 | generator=g, 101 | worker_init_fn=seed_worker, 102 | ) 103 | train_dataloader = DataLoader( 104 | train_dataset, 105 | batch_size=config.training.batch_size, 106 | num_workers=config.training.num_workers, 107 | drop_last=True, 108 | sampler=train_sampler, 109 | pin_memory=True, 110 | persistent_workers=True, 111 | worker_init_fn=seed_worker, 112 | generator=g, 113 | ) 114 | trainer = Trainer(config) 115 | trainer.pretrain(pretrain_dataloader, val_dataloader) 116 | trainer.train(train_dataloader, val_dataloader) 117 | 118 | 119 | if __name__ 
== "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | from torch.utils.tensorboard.writer import SummaryWriter 5 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 6 | from tqdm import tqdm 7 | 8 | from model import VGG19, Discriminator, Generator 9 | 10 | 11 | class Trainer: 12 | fixed_lr_images = torch.tensor([]) 13 | fixed_hr_images = torch.tensor([]) 14 | 15 | def __init__(self, config): 16 | self.config = config 17 | self.writer = SummaryWriter(log_dir=osp.join("runs", config.experiment.name)) 18 | self.generator = Generator(config=config.generator) 19 | self.generator.to(self.config.training.device) 20 | self.discriminator = Discriminator(config=config.discriminator) 21 | self.discriminator.to(self.config.training.device) 22 | self.perceptual_network = VGG19().to(self.config.training.device) 23 | if config.training.compiled and torch.cuda.is_available(): 24 | self.generator = torch.compile(self.generator, mode="max-autotune") 25 | self.discriminator = torch.compile(self.discriminator, mode="max-autotune") 26 | self.perceptual_network = torch.compile(self.perceptual_network, mode="max-autotune") 27 | 28 | # The VGG just provides features, no gradient needed 29 | self.perceptual_network.eval() 30 | for p in self.perceptual_network.parameters(): 31 | p.requires_grad = False 32 | 33 | self.optim_generator = torch.optim.AdamW( 34 | self.generator.parameters(), lr=self.config.training.generator_lr, fused=True 35 | ) 36 | self.optim_discriminator = torch.optim.AdamW( 37 | self.discriminator.parameters(), lr=self.config.training.discriminator_lr, fused=True 38 | ) 39 | 40 | # Loss function for the adversarial players 41 | self.loss_fn = torch.nn.BCEWithLogitsLoss() 42 | # Loss function for the content loss 43 | self.l1_loss = 
torch.nn.SmoothL1Loss() 44 | 45 | # Metrics for our optimization 46 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none").to( 47 | config.training.device 48 | ) 49 | self.psnr = PeakSignalNoiseRatio(data_range=1.0, reduction="none").to( 50 | config.training.device 51 | ) 52 | 53 | @torch.no_grad 54 | def _calculate_metrics_over_dataset(self, dataloader, phase, step): 55 | self.generator.eval() 56 | self.ssim.reset() 57 | self.psnr.reset() 58 | for lr_images, hr_images in tqdm( 59 | dataloader, desc="Calculating metrics", total=len(dataloader) 60 | ): 61 | lr_images, hr_images = lr_images.to( 62 | self.config.training.device, non_blocking=True 63 | ), hr_images.to(self.config.training.device, non_blocking=True) 64 | sr_images = (1.0 + self.generator(lr_images)) / 2.0 65 | self.ssim.update(sr_images, (1.0 + hr_images) / 2.0) 66 | self.psnr.update(sr_images, (1.0 + hr_images) / 2.0) 67 | self.writer.add_scalar(f"{phase}/SSIM", self.ssim.compute().mean(), global_step=step) 68 | self.writer.add_scalar(f"{phase}/PSNR", self.psnr.compute().mean(), global_step=step) 69 | self.writer.flush() 70 | 71 | def _log_fixed_images(self, phase): 72 | Trainer.fixed_hr_images = Trainer.fixed_hr_images.to(self.config.training.device) 73 | Trainer.fixed_lr_images = Trainer.fixed_lr_images.to(self.config.training.device) 74 | upsampled_images = torch.nn.functional.interpolate( 75 | Trainer.fixed_lr_images.cpu(), scale_factor=4, mode="bicubic", antialias=True 76 | ).to(self.config.training.device) 77 | self.writer.add_images(f"{phase}/HighRes", Trainer.fixed_hr_images, global_step=0) 78 | self.writer.add_images(f"{phase}/Bicubic", upsampled_images, global_step=0) 79 | 80 | @classmethod 81 | def _pre_train_setup(cls, dataloader): 82 | if cls.fixed_lr_images.ndim == 1: 83 | for fixed_lr_images, fixed_hr_images in dataloader: 84 | cls.fixed_lr_images = (fixed_lr_images + 1.0) / 2.0 85 | cls.fixed_hr_images = (fixed_hr_images + 1.0) / 2.0 86 | cls.images_are_set = 
True 87 | break 88 | 89 | def pretrain(self, train_dataloader, val_dataloader): 90 | if osp.exists("runs/pretrain.pt"): 91 | print("Pretrained model found, skipping pretraining") 92 | self.generator.load_state_dict(torch.load("runs/pretrain.pt")["model"]) 93 | self.optim_generator.load_state_dict(torch.load("runs/pretrain.pt")["optimizer"]) 94 | return 95 | self._calculate_metrics_over_dataset(val_dataloader, "Pretrain", step=0) 96 | self._pre_train_setup(val_dataloader) 97 | self._log_fixed_images("Pretrain") 98 | step = 0 99 | for step, (lr_images, hr_images) in tqdm( 100 | enumerate(train_dataloader, start=1), 101 | desc="Pretraining Generator", 102 | total=len(train_dataloader), 103 | ): 104 | lr_images, hr_images = lr_images.to( 105 | self.config.training.device, non_blocking=True 106 | ), hr_images.to(self.config.training.device, non_blocking=True) 107 | self.optim_generator.zero_grad(set_to_none=True) 108 | fake_hr_images = self.generator(lr_images) 109 | gen_loss = self.l1_loss(fake_hr_images, hr_images) 110 | gen_loss.backward() 111 | self.optim_generator.step() 112 | 113 | if step % self.config.training.log_iter == 0: 114 | self.writer.add_scalar( 115 | "Pretrain/Generator/Loss", 116 | gen_loss, 117 | global_step=step, 118 | ) 119 | if step % self.config.training.checkpoint_iter == 0: 120 | self.generator.eval() 121 | with torch.no_grad(): 122 | fake_hr_images = (1.0 + self.generator(2.0 * self.fixed_lr_images - 1.0)) / 2.0 123 | self.writer.add_images( 124 | "Pretrain/Generated", 125 | fake_hr_images, 126 | global_step=step, 127 | ) 128 | self._calculate_metrics_over_dataset(val_dataloader, "Pretrain", step) 129 | self.generator.train() 130 | 131 | torch.save( 132 | {"model": self.generator.state_dict(), "optimizer": self.optim_generator.state_dict()}, 133 | f"runs/pretrain_generator.pt", 134 | ) 135 | torch.save( 136 | { 137 | "model": self.discriminator.state_dict(), 138 | "optimizer": self.optim_discriminator.state_dict(), 139 | }, 140 | 
f"runs/pretrain_discriminator.pt", 141 | ) 142 | 143 | def save_checkpoints(self, step): 144 | save_dir = osp.join("runs", self.config.experiment.name) 145 | torch.save(self.generator.state_dict(), osp.join(save_dir, f"generator_epoch_{step}.pt")) 146 | torch.save( 147 | self.discriminator.state_dict(), osp.join(save_dir, f"discriminator_epoch_{step}.pt") 148 | ) 149 | torch.save( 150 | self.optim_generator.state_dict(), 151 | osp.join(save_dir, f"generator_optim_epoch_{step}.pt"), 152 | ) 153 | torch.save( 154 | self.optim_discriminator.state_dict(), 155 | osp.join(save_dir, f"discriminator_optim_epoch_{step}.pt"), 156 | ) 157 | 158 | def train(self, train_dataloader, val_dataloader): 159 | self._calculate_metrics_over_dataset(val_dataloader, "GAN", step=0) 160 | if Trainer.fixed_lr_images is None: 161 | self._pre_train_setup(train_dataloader) 162 | self._log_fixed_images("GAN") 163 | self.generator.train() 164 | self.discriminator.train() 165 | for step, (lr_images, hr_images) in tqdm( 166 | enumerate(train_dataloader, start=1), desc="GAN Training", total=len(train_dataloader) 167 | ): 168 | lr_images, hr_images = lr_images.to( 169 | self.config.training.device, non_blocking=True 170 | ), hr_images.to(self.config.training.device, non_blocking=True) 171 | self.optim_discriminator.zero_grad(set_to_none=True) 172 | y_real = self.discriminator(hr_images) 173 | sr_images = self.generator(lr_images).detach() 174 | y_fake = self.discriminator(sr_images) 175 | real_labels = 0.3 * torch.rand_like(y_real) + 0.8 176 | fake_labels = 0.3 * torch.rand_like(y_fake) 177 | loss_real = self.loss_fn(y_real, real_labels.to(self.config.training.device)) 178 | loss_fake = self.loss_fn(y_fake, fake_labels.to(self.config.training.device)) 179 | discriminator_loss = 0.5 * loss_real + 0.5 * loss_fake 180 | discriminator_loss.backward() 181 | self.optim_discriminator.step() 182 | 183 | # Get the adv loss for the generator 184 | self.optim_generator.zero_grad(set_to_none=True) 185 | 
sr_images = self.generator(lr_images) 186 | y_fake = self.discriminator(sr_images) 187 | real_labels = 0.3 * torch.rand_like(y_fake) + 0.7 188 | adv_loss = 1e-1 * self.loss_fn(y_fake, real_labels.to(self.config.training.device)) 189 | # Get the content loss for the generator 190 | fake_features = self.perceptual_network(sr_images) 191 | real_features = self.perceptual_network(hr_images) 192 | content_loss = self.l1_loss(fake_features, real_features) 193 | # Train the generator 194 | generator_loss = 0.5 * adv_loss + 0.5 * content_loss 195 | generator_loss.backward() 196 | self.optim_generator.step() 197 | 198 | if step % self.config.training.log_iter == 0: 199 | self.writer.add_scalar( 200 | "Loss/Discriminator/Real", 201 | loss_real, 202 | global_step=step, 203 | ) 204 | self.writer.add_scalar( 205 | "Loss/Discriminator/Fake", 206 | loss_fake, 207 | global_step=step, 208 | ) 209 | self.writer.add_scalar( 210 | "Loss/Generator/Adversarial", 211 | adv_loss, 212 | global_step=step, 213 | ) 214 | self.writer.add_scalar( 215 | "Loss/Generator/Content", 216 | content_loss, 217 | global_step=step, 218 | ) 219 | 220 | if step % self.config.training.checkpoint_iter == 0: 221 | self.generator.eval() 222 | with torch.no_grad(): 223 | generated_sr_image = ( 224 | 1.0 + self.generator(2 * self.fixed_lr_images - 1.0) 225 | ) / 2.0 226 | self.writer.add_images( 227 | "GAN/Generated", 228 | generated_sr_image, 229 | global_step=step, 230 | ) 231 | self._calculate_metrics_over_dataset(val_dataloader, "GAN", step=step) 232 | self.save_checkpoints(step) 233 | self.generator.train() 234 | --------------------------------------------------------------------------------