├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── contextual_loss ├── __init__.py ├── config.py ├── functional.py └── modules │ ├── __init__.py │ ├── contextual.py │ ├── contextual_bilateral.py │ └── vgg.py ├── doc └── small_example.ipynb ├── setup.py └── tests ├── __init__.py ├── test_contextual.py └── test_contextual_bilateral.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sou Uchida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | pytest = "*" 8 | flake8 = "*" 9 | pylint = "*" 10 | matplotlib = "*" 11 | ipython = "*" 12 | 13 | [packages] 14 | torch = "*" 15 | torchvision = "*" 16 | 17 | [requires] 18 | python_version = "3.7" 19 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "d73f48ef2416fa486483d856a1b6fb29ee661445de7c6d3b7ded11f777cabe2c" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "numpy": { 20 | "hashes": [ 21 | "sha256:0a7a1dd123aecc9f0076934288ceed7fd9a81ba3919f11a855a7887cbe82a02f", 22 | "sha256:0c0763787133dfeec19904c22c7e358b231c87ba3206b211652f8cbe1241deb6", 23 | "sha256:3d52298d0be333583739f1aec9026f3b09fdfe3ddf7c7028cb16d9d2af1cca7e", 24 | "sha256:43bb4b70585f1c2d153e45323a886839f98af8bfa810f7014b20be714c37c447", 25 | "sha256:475963c5b9e116c38ad7347e154e5651d05a2286d86455671f5b1eebba5feb76", 26 | "sha256:64874913367f18eb3013b16123c9fed113962e75d809fca5b78ebfbb73ed93ba", 27 | "sha256:683828e50c339fc9e68720396f2de14253992c495fdddef77a1e17de55f1decc", 28 | "sha256:6ca4000c4a6f95a78c33c7dadbb9495c10880be9c89316aa536eac359ab820ae", 29 | "sha256:75fd817b7061f6378e4659dd792c84c0b60533e867f83e0d1e52d5d8e53df88c", 30 | "sha256:7d81d784bdbed30137aca242ab307f3e65c8d93f4c7b7d8f322110b2e90177f9", 31 | "sha256:8d0af8d3664f142414fd5b15cabfd3b6cc3ef242a3c7a7493257025be5a6955f", 32 | "sha256:9679831005fb16c6df3dd35d17aa31dc0d4d7573d84f0b44cc481490a65c7725", 33 | "sha256:a8f67ebfae9f575d85fa859b54d3bdecaeece74e3274b0b5c5f804d7ca789fe1", 34 | "sha256:acbf5c52db4adb366c064d0b7c7899e3e778d89db585feadd23b06b587d64761", 35 | "sha256:ada4805ed51f5bcaa3a06d3dd94939351869c095e30a2b54264f5a5004b52170", 36 | "sha256:c7354e8f0eca5c110b7e978034cd86ed98a7a5ffcf69ca97535445a595e07b8e", 37 | "sha256:e2e9d8c87120ba2c591f60e32736b82b67f72c37ba88a4c23c81b5b8fa49c018", 38 | "sha256:e467c57121fe1b78a8f68dd9255fbb3bb3f4f7547c6b9e109f31d14569f490c3", 39 | "sha256:ede47b98de79565fcd7f2decb475e2dcc85ee4097743e551fe26cfc7eb3ff143", 40 | "sha256:f58913e9227400f1395c7b800503ebfdb0772f1c33ff8cb4d6451c06cabdf316", 41 | "sha256:fe39f5fd4103ec4ca3cb8600b19216cd1ff316b4990f4c0b6057ad982c0a34d5" 42 | ], 43 | "version": "==1.17.4" 44 | }, 45 | "pillow": { 46 | "hashes": [ 47 | "sha256:047d9473cf68af50ac85f8ee5d5f21a60f849bc17d348da7fc85711287a75031", 48 | "sha256:0f66dc6c8a3cc319561a633b6aa82c44107f12594643efa37210d8c924fc1c71", 49 | "sha256:12c9169c4e8fe0a7329e8658c7e488001f6b4c8e88740e76292c2b857af2e94c", 50 | "sha256:248cffc168896982f125f5c13e9317c059f74fffdb4152893339f3be62a01340", 51 | "sha256:27faf0552bf8c260a5cee21a76e031acaea68babb64daf7e8f2e2540745082aa", 52 | "sha256:285edafad9bc60d96978ed24d77cdc0b91dace88e5da8c548ba5937c425bca8b", 53 | "sha256:384b12c9aa8ef95558abdcb50aada56d74bc7cc131dd62d28c2d0e4d3aadd573", 54 | "sha256:38950b3a707f6cef09cd3cbb142474357ad1a985ceb44d921bdf7b4647b3e13e", 55 | "sha256:4aad1b88933fd6dc2846552b89ad0c74ddbba2f0884e2c162aa368374bf5abab", 56 | "sha256:4ac6148008c169603070c092e81f88738f1a0c511e07bd2bb0f9ef542d375da9", 57 | "sha256:4deb1d2a45861ae6f0b12ea0a786a03d19d29edcc7e05775b85ec2877cb54c5e", 58 | "sha256:59aa2c124df72cc75ed72c8d6005c442d4685691a30c55321e00ed915ad1a291", 59 | "sha256:5a47d2123a9ec86660fe0e8d0ebf0aa6bc6a17edc63f338b73ea20ba11713f12", 60 | "sha256:5cc901c2ab9409b4b7ac7b5bcc3e86ac14548627062463da0af3b6b7c555a871", 61 | "sha256:6c1db03e8dff7b9f955a0fb9907eb9ca5da75b5ce056c0c93d33100a35050281", 62 | "sha256:7ce80c0a65a6ea90ef9c1f63c8593fcd2929448613fc8da0adf3e6bfad669d08", 63 | "sha256:809c19241c14433c5d6135e1b6c72da4e3b56d5c865ad5736ab99af8896b8f41", 64 | "sha256:83792cb4e0b5af480588601467c0764242b9a483caea71ef12d22a0d0d6bdce2", 65 | "sha256:846fa202bd7ee0f6215c897a1d33238ef071b50766339186687bd9b7a6d26ac5", 66 | "sha256:9f5529fc02009f96ba95bea48870173426879dc19eec49ca8e08cd63ecd82ddb", 67 | "sha256:a423c2ea001c6265ed28700df056f75e26215fd28c001e93ef4380b0f05f9547", 68 | "sha256:ac4428094b42907aba5879c7c000d01c8278d451a3b7cccd2103e21f6397ea75", 69 | "sha256:b1ae48d87f10d1384e5beecd169c77502fcc04a2c00a4c02b85f0a94b419e5f9", 70 | "sha256:bf4e972a88f8841d8fdc6db1a75e0f8d763e66e3754b03006cbc3854d89f1cb1", 71 | "sha256:c6414f6aad598364aaf81068cabb077894eb88fed99c6a65e6e8217bab62ae7a", 72 | "sha256:c710fcb7ee32f67baf25aa9ffede4795fd5d93b163ce95fdc724383e38c9df96", 73 | "sha256:c7be4b8a09852291c3c48d3c25d1b876d2494a0a674980089ac9d5e0d78bd132", 74 | "sha256:c9e5ffb910b14f090ac9c38599063e354887a5f6d7e6d26795e916b4514f2c1a", 75 | "sha256:e0697b826da6c2472bb6488db4c0a7fa8af0d52fa08833ceb3681358914b14e5", 76 | "sha256:e9a3edd5f714229d41057d56ac0f39ad9bdba6767e8c888c951869f0bdd129b0" 77 | ], 78 | "version": "==6.2.1" 79 | }, 80 | "six": { 81 | "hashes": [ 82 | "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", 83 | "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" 84 | ], 85 | "version": "==1.13.0" 86 | }, 87 | "torch": { 88 | "hashes": [ 89 | "sha256:0cec2e13a2e95c24c34f17d437f354ee2a40902e8d515a524556b350e12555dd", 90 | "sha256:134e8291a97151b1ffeea09cb9ddde5238beb4e6d9dfb66657143d6990bfb865", 91 | "sha256:31062923ac2e60eac676f6a0ae14702b051c158bbcf7f440eaba266b0defa197", 92 | "sha256:3b05233481b51bb636cee63dc761bb7f602e198178782ff4159d385d1759608b", 93 | "sha256:458f1d87e5b7064b2c39e36675d84e163be3143dd2fc806057b7878880c461bc", 94 | "sha256:72a1c85bffd2154f085bc0a1d378d8a54e55a57d49664b874fe7c949022bf071", 95 | "sha256:77fd8866c0bf529861ffd850a5dada2190a8d9c5167719fb0cfa89163e23b143", 96 | "sha256:b6f01d851d1c5989d4a99b50ae0187762b15b7718dcd1a33704b665daa2402f9", 97 | "sha256:d8e1d904a6193ed14a4fed220b00503b2baa576e71471286d1ebba899c851fae" 98 | ], 99 | "index": "pypi", 100 | "version": "==1.3.1" 101 | }, 102 | "torchvision": { 103 | "hashes": [ 104 | "sha256:0f8245d6378acc86917f58492675f93df5279abae8bc5f832e3510722191f6c9", 105 | "sha256:1ad7593d94f6612ccb84a59467f0d10cdc213fb3e2bb91f1e773eb844787fa4c", 106 | "sha256:2553405b9afe3cedb410873b9877eb18b1526f8b01cb7c2747e51b69a936e0b5", 107 | "sha256:276a385f2f5fe484bf08467b5d081d9144b97eb458ba5b4a11e4640389e53149", 108 | "sha256:66deba9c577e36f4f071decdd894bf7ba794ac133dae64b3fd02fc3f0c6b989d", 109 | "sha256:7a458330e4efcd66f9f70127ab21fcf8cfea84acda8e707322fd2843aa6dd396", 110 | "sha256:8ff715c2323d9eca89126824ebfa74b282a95d6f64a4743fbe9b738d2de21c77", 111 | "sha256:dca4aadc12a123730957b501f9c5c2870d2f6727a2c28552cb7907b68b0ea10c", 112 | "sha256:dda25ce304978bba19e6543f7dcfee4f37d2f128ec83d4ab0c7e8f991d64865f" 113 | ], 114 | "index": "pypi", 115 | "version": "==0.4.2" 116 | } 117 | }, 118 | "develop": { 119 | "astroid": { 120 | "hashes": [ 121 | "sha256:71ea07f44df9568a75d0f354c49143a4575d90645e9fead6dfb52c26a85ed13a", 122 | "sha256:840947ebfa8b58f318d42301cf8c0a20fd794a33b61cc4638e28e9e61ba32f42" 123 | ], 124 | "version": "==2.3.3" 125 | }, 126 | "attrs": { 127 | "hashes": [ 128 | "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", 129 | "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" 130 | ], 131 | "version": "==19.3.0" 132 | }, 133 | "backcall": { 134 | "hashes": [ 135 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 136 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 137 | ], 138 | "version": "==0.1.0" 139 | }, 140 | "cycler": { 141 | "hashes": [ 142 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", 143 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" 144 | ], 145 | "version": "==0.10.0" 146 | }, 147 | "decorator": { 148 | "hashes": [ 149 | "sha256:54c38050039232e1db4ad7375cfce6748d7b41c29e95a081c8a6d2c30364a2ce", 150 | "sha256:5d19b92a3c8f7f101c8dd86afd86b0f061a8ce4540ab8cd401fa2542756bce6d" 151 | ], 152 | "version": "==4.4.1" 153 | }, 154 | "entrypoints": { 155 | "hashes": [ 156 | "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", 157 | "sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451" 158 | ], 159 | "version": "==0.3" 160 | }, 161 | "flake8": { 162 | "hashes": [ 163 | "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", 164 | "sha256:49356e766643ad15072a789a20915d3c91dc89fd313ccd71802303fd67e4deca" 165 | ], 166 | "index": "pypi", 167 | "version": "==3.7.9" 168 | }, 169 | "importlib-metadata": { 170 | "hashes": [ 171 | "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", 172 | "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af" 173 | ], 174 | "markers": "python_version < '3.8'", 175 | "version": "==0.23" 176 | }, 177 | "ipython": { 178 | "hashes": [ 179 | "sha256:060d19feef09453d3375ab23c7295ed36cb59e5a3904598ab903f93ec45f1f63", 180 | "sha256:e468b8f03a0168a667982b50f0b4e0828cc32721bbea32b23934e55b7970eb7a" 181 | ], 182 | "index": "pypi", 183 | "version": "==7.10.0" 184 | }, 185 | "ipython-genutils": { 186 | "hashes": [ 187 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 188 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 189 | ], 190 | "version": "==0.2.0" 191 | }, 192 | "isort": { 193 | "hashes": [ 194 | "sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1", 195 | "sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd" 196 | ], 197 | "version": "==4.3.21" 198 | }, 199 | "jedi": { 200 | "hashes": [ 201 | "sha256:786b6c3d80e2f06fd77162a07fed81b8baa22dde5d62896a790a331d6ac21a27", 202 | "sha256:ba859c74fa3c966a22f2aeebe1b74ee27e2a462f56d3f5f7ca4a59af61bfe42e" 203 | ], 204 | "version": "==0.15.1" 205 | }, 206 | "kiwisolver": { 207 | "hashes": [ 208 | "sha256:05b5b061e09f60f56244adc885c4a7867da25ca387376b02c1efc29cc16bcd0f", 209 | "sha256:210d8c39d01758d76c2b9a693567e1657ec661229bc32eac30761fa79b2474b0", 210 | "sha256:26f4fbd6f5e1dabff70a9ba0d2c4bd30761086454aa30dddc5b52764ee4852b7", 211 | "sha256:3b15d56a9cd40c52d7ab763ff0bc700edbb4e1a298dc43715ecccd605002cf11", 212 | "sha256:3b2378ad387f49cbb328205bda569b9f87288d6bc1bf4cd683c34523a2341efe", 213 | "sha256:400599c0fe58d21522cae0e8b22318e09d9729451b17ee61ba8e1e7c0346565c", 214 | "sha256:47b8cb81a7d18dbaf4fed6a61c3cecdb5adec7b4ac292bddb0d016d57e8507d5", 215 | "sha256:53eaed412477c836e1b9522c19858a8557d6e595077830146182225613b11a75", 216 | "sha256:58e626e1f7dfbb620d08d457325a4cdac65d1809680009f46bf41eaf74ad0187", 217 | "sha256:5a52e1b006bfa5be04fe4debbcdd2688432a9af4b207a3f429c74ad625022641", 218 | "sha256:5c7ca4e449ac9f99b3b9d4693debb1d6d237d1542dd6a56b3305fe8a9620f883", 219 | "sha256:682e54f0ce8f45981878756d7203fd01e188cc6c8b2c5e2cf03675390b4534d5", 220 | "sha256:76275ee077772c8dde04fb6c5bc24b91af1bb3e7f4816fd1852f1495a64dad93", 221 | "sha256:79bfb2f0bd7cbf9ea256612c9523367e5ec51d7cd616ae20ca2c90f575d839a2", 222 | "sha256:7f4dd50874177d2bb060d74769210f3bce1af87a8c7cf5b37d032ebf94f0aca3", 223 | "sha256:8944a16020c07b682df861207b7e0efcd2f46c7488619cb55f65882279119389", 224 | "sha256:8aa7009437640beb2768bfd06da049bad0df85f47ff18426261acecd1cf00897", 225 | "sha256:9105ce82dcc32c73eb53a04c869b6a4bc756b43e4385f76ea7943e827f529e4d", 226 | "sha256:933df612c453928f1c6faa9236161a1d999a26cd40abf1dc5d7ebbc6dbfb8fca", 227 | "sha256:939f36f21a8c571686eb491acfffa9c7f1ac345087281b412d63ea39ca14ec4a", 228 | "sha256:9491578147849b93e70d7c1d23cb1229458f71fc79c51d52dce0809b2ca44eea", 229 | "sha256:9733b7f64bd9f807832d673355f79703f81f0b3e52bfce420fc00d8cb28c6a6c", 230 | "sha256:a02f6c3e229d0b7220bd74600e9351e18bc0c361b05f29adae0d10599ae0e326", 231 | "sha256:a0c0a9f06872330d0dd31b45607197caab3c22777600e88031bfe66799e70bb0", 232 | "sha256:aa716b9122307c50686356cfb47bfbc66541868078d0c801341df31dca1232a9", 233 | "sha256:acc4df99308111585121db217681f1ce0eecb48d3a828a2f9bbf9773f4937e9e", 234 | "sha256:b64916959e4ae0ac78af7c3e8cef4becee0c0e9694ad477b4c6b3a536de6a544", 235 | "sha256:d22702cadb86b6fcba0e6b907d9f84a312db9cd6934ee728144ce3018e715ee1", 236 | "sha256:d3fcf0819dc3fea58be1fd1ca390851bdb719a549850e708ed858503ff25d995", 237 | "sha256:d52e3b1868a4e8fd18b5cb15055c76820df514e26aa84cc02f593d99fef6707f", 238 | "sha256:db1a5d3cc4ae943d674718d6c47d2d82488ddd94b93b9e12d24aabdbfe48caee", 239 | "sha256:e3a21a720791712ed721c7b95d433e036134de6f18c77dbe96119eaf7aa08004", 240 | "sha256:e8bf074363ce2babeb4764d94f8e65efd22e6a7c74860a4f05a6947afc020ff2", 241 | "sha256:f16814a4a96dc04bf1da7d53ee8d5b1d6decfc1a92a63349bb15d37b6a263dd9", 242 | "sha256:f2b22153870ca5cf2ab9c940d7bc38e8e9089fa0f7e5856ea195e1cf4ff43d5a", 243 | "sha256:f790f8b3dff3d53453de6a7b7ddd173d2e020fb160baff578d578065b108a05f", 244 | "sha256:fe51b79da0062f8e9d49ed0182a626a7dc7a0cbca0328f612c6ee5e4711c81e4" 245 | ], 246 | "version": "==1.1.0" 247 | }, 248 | "lazy-object-proxy": { 249 | "hashes": [ 250 | "sha256:0c4b206227a8097f05c4dbdd323c50edf81f15db3b8dc064d08c62d37e1a504d", 251 | "sha256:194d092e6f246b906e8f70884e620e459fc54db3259e60cf69a4d66c3fda3449", 252 | "sha256:1be7e4c9f96948003609aa6c974ae59830a6baecc5376c25c92d7d697e684c08", 253 | "sha256:4677f594e474c91da97f489fea5b7daa17b5517190899cf213697e48d3902f5a", 254 | "sha256:48dab84ebd4831077b150572aec802f303117c8cc5c871e182447281ebf3ac50", 255 | "sha256:5541cada25cd173702dbd99f8e22434105456314462326f06dba3e180f203dfd", 256 | "sha256:59f79fef100b09564bc2df42ea2d8d21a64fdcda64979c0fa3db7bdaabaf6239", 257 | "sha256:8d859b89baf8ef7f8bc6b00aa20316483d67f0b1cbf422f5b4dc56701c8f2ffb", 258 | "sha256:9254f4358b9b541e3441b007a0ea0764b9d056afdeafc1a5569eee1cc6c1b9ea", 259 | "sha256:9651375199045a358eb6741df3e02a651e0330be090b3bc79f6d0de31a80ec3e", 260 | "sha256:97bb5884f6f1cdce0099f86b907aa41c970c3c672ac8b9c8352789e103cf3156", 261 | "sha256:9b15f3f4c0f35727d3a0fba4b770b3c4ebbb1fa907dbcc046a1d2799f3edd142", 262 | "sha256:a2238e9d1bb71a56cd710611a1614d1194dc10a175c1e08d75e1a7bcc250d442", 263 | "sha256:a6ae12d08c0bf9909ce12385803a543bfe99b95fe01e752536a60af2b7797c62", 264 | "sha256:ca0a928a3ddbc5725be2dd1cf895ec0a254798915fb3a36af0964a0a4149e3db", 265 | "sha256:cb2c7c57005a6804ab66f106ceb8482da55f5314b7fcb06551db1edae4ad1531", 266 | "sha256:d74bb8693bf9cf75ac3b47a54d716bbb1a92648d5f781fc799347cfc95952383", 267 | "sha256:d945239a5639b3ff35b70a88c5f2f491913eb94871780ebfabb2568bd58afc5a", 268 | "sha256:eba7011090323c1dadf18b3b689845fd96a61ba0a1dfbd7f24b921398affc357", 269 | "sha256:efa1909120ce98bbb3777e8b6f92237f5d5c8ea6758efea36a473e1d38f7d3e4", 270 | "sha256:f3900e8a5de27447acbf900b4750b0ddfd7ec1ea7fbaf11dfa911141bc522af0" 271 | ], 272 | "version": "==1.4.3" 273 | }, 274 | "matplotlib": { 275 | "hashes": [ 276 | "sha256:08ccc8922eb4792b91c652d3e6d46b1c99073f1284d1b6705155643e8046463a", 277 | "sha256:161dcd807c0c3232f4dcd4a12a382d52004a498174cbfafd40646106c5bcdcc8", 278 | "sha256:1f9e885bfa1b148d16f82a6672d043ecf11197f6c71ae222d0546db706e52eb2", 279 | "sha256:2d6ab54015a7c0d727c33e36f85f5c5e4172059efdd067f7527f6e5d16ad01aa", 280 | "sha256:5d2e408a2813abf664bd79431107543ecb449136912eb55bb312317edecf597e", 281 | "sha256:61c8b740a008218eb604de518eb411c4953db0cb725dd0b32adf8a81771cab9e", 282 | "sha256:80f10af8378fccc136da40ea6aa4a920767476cdfb3241acb93ef4f0465dbf57", 283 | "sha256:819d4860315468b482f38f1afe45a5437f60f03eaede495d5ff89f2eeac89500", 284 | "sha256:8cc0e44905c2c8fda5637cad6f311eb9517017515a034247ab93d0cf99f8bb7a", 285 | "sha256:8e8e2c2fe3d873108735c6ee9884e6f36f467df4a143136209cff303b183bada", 286 | "sha256:98c2ffeab8b79a4e3a0af5dd9939f92980eb6e3fec10f7f313df5f35a84dacab", 287 | "sha256:d59bb0e82002ac49f4152963f8a1079e66794a4f454457fd2f0dcc7bf0797d30", 288 | "sha256:ee59b7bb9eb75932fe3787e54e61c99b628155b0cedc907864f24723ba55b309" 289 | ], 290 | "index": "pypi", 291 | "version": "==3.1.2" 292 | }, 293 | "mccabe": { 294 | "hashes": [ 295 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 296 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 297 | ], 298 | "version": "==0.6.1" 299 | }, 300 | "more-itertools": { 301 | "hashes": [ 302 | "sha256:53ff73f186307d9c8ef17a9600309154a6ae27f25579e80af4db8f047ba14bc2", 303 | "sha256:a0ea684c39bc4315ba7aae406596ef191fd84f873d2d2751f84d64e81a7a2d45" 304 | ], 305 | "version": "==8.0.0" 306 | }, 307 | "numpy": { 308 | "hashes": [ 309 | "sha256:0a7a1dd123aecc9f0076934288ceed7fd9a81ba3919f11a855a7887cbe82a02f", 310 | "sha256:0c0763787133dfeec19904c22c7e358b231c87ba3206b211652f8cbe1241deb6", 311 | "sha256:3d52298d0be333583739f1aec9026f3b09fdfe3ddf7c7028cb16d9d2af1cca7e", 312 | "sha256:43bb4b70585f1c2d153e45323a886839f98af8bfa810f7014b20be714c37c447", 313 | "sha256:475963c5b9e116c38ad7347e154e5651d05a2286d86455671f5b1eebba5feb76", 314 | "sha256:64874913367f18eb3013b16123c9fed113962e75d809fca5b78ebfbb73ed93ba", 315 | "sha256:683828e50c339fc9e68720396f2de14253992c495fdddef77a1e17de55f1decc", 316 | "sha256:6ca4000c4a6f95a78c33c7dadbb9495c10880be9c89316aa536eac359ab820ae", 317 | "sha256:75fd817b7061f6378e4659dd792c84c0b60533e867f83e0d1e52d5d8e53df88c", 318 | "sha256:7d81d784bdbed30137aca242ab307f3e65c8d93f4c7b7d8f322110b2e90177f9", 319 | "sha256:8d0af8d3664f142414fd5b15cabfd3b6cc3ef242a3c7a7493257025be5a6955f", 320 | "sha256:9679831005fb16c6df3dd35d17aa31dc0d4d7573d84f0b44cc481490a65c7725", 321 | "sha256:a8f67ebfae9f575d85fa859b54d3bdecaeece74e3274b0b5c5f804d7ca789fe1", 322 | "sha256:acbf5c52db4adb366c064d0b7c7899e3e778d89db585feadd23b06b587d64761", 323 | "sha256:ada4805ed51f5bcaa3a06d3dd94939351869c095e30a2b54264f5a5004b52170", 324 | "sha256:c7354e8f0eca5c110b7e978034cd86ed98a7a5ffcf69ca97535445a595e07b8e", 325 | "sha256:e2e9d8c87120ba2c591f60e32736b82b67f72c37ba88a4c23c81b5b8fa49c018", 326 | "sha256:e467c57121fe1b78a8f68dd9255fbb3bb3f4f7547c6b9e109f31d14569f490c3", 327 | "sha256:ede47b98de79565fcd7f2decb475e2dcc85ee4097743e551fe26cfc7eb3ff143", 328 | "sha256:f58913e9227400f1395c7b800503ebfdb0772f1c33ff8cb4d6451c06cabdf316", 329 | "sha256:fe39f5fd4103ec4ca3cb8600b19216cd1ff316b4990f4c0b6057ad982c0a34d5" 330 | ], 331 | "version": "==1.17.4" 332 | }, 333 | "packaging": { 334 | "hashes": [ 335 | "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", 336 | "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108" 337 | ], 338 | "version": "==19.2" 339 | }, 340 | "parso": { 341 | "hashes": [ 342 | "sha256:63854233e1fadb5da97f2744b6b24346d2750b85965e7e399bec1620232797dc", 343 | "sha256:666b0ee4a7a1220f65d367617f2cd3ffddff3e205f3f16a0284df30e774c2a9c" 344 | ], 345 | "version": "==0.5.1" 346 | }, 347 | "pexpect": { 348 | "hashes": [ 349 | "sha256:2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1", 350 | "sha256:9e2c1fd0e6ee3a49b28f95d4b33bc389c89b20af6a1255906e90ff1262ce62eb" 351 | ], 352 | "markers": "sys_platform != 'win32'", 353 | "version": "==4.7.0" 354 | }, 355 | "pickleshare": { 356 | "hashes": [ 357 | "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", 358 | "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56" 359 | ], 360 | "version": "==0.7.5" 361 | }, 362 | "pluggy": { 363 | "hashes": [ 364 | "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", 365 | "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" 366 | ], 367 | "version": "==0.13.1" 368 | }, 369 | "prompt-toolkit": { 370 | "hashes": [ 371 | "sha256:0278d2f51b5ceba6ea8da39f76d15684e84c996b325475f6e5720edc584326a7", 372 | "sha256:63daee79aa8366c8f1c637f1a4876b890da5fc92a19ebd2f7080ebacb901e990" 373 | ], 374 | "version": "==3.0.2" 375 | }, 376 | "ptyprocess": { 377 | "hashes": [ 378 | "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", 379 | "sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f" 380 | ], 381 | "version": "==0.6.0" 382 | }, 383 | "py": { 384 | "hashes": [ 385 | "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", 386 | "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" 387 | ], 388 | "version": "==1.8.0" 389 | }, 390 | "pycodestyle": { 391 | "hashes": [ 392 | "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", 393 | "sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c" 394 | ], 395 | "version": "==2.5.0" 396 | }, 397 | "pyflakes": { 398 | "hashes": [ 399 | "sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0", 400 | "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2" 401 | ], 402 | "version": "==2.1.1" 403 | }, 404 | "pygments": { 405 | "hashes": [ 406 | "sha256:2a3fe295e54a20164a9df49c75fa58526d3be48e14aceba6d6b1e8ac0bfd6f1b", 407 | "sha256:98c8aa5a9f778fcd1026a17361ddaf7330d1b7c62ae97c3bb0ae73e0b9b6b0fe" 408 | ], 409 | "version": "==2.5.2" 410 | }, 411 | "pylint": { 412 | "hashes": [ 413 | "sha256:3db5468ad013380e987410a8d6956226963aed94ecb5f9d3a28acca6d9ac36cd", 414 | "sha256:886e6afc935ea2590b462664b161ca9a5e40168ea99e5300935f6591ad467df4" 415 | ], 416 | "index": "pypi", 417 | "version": "==2.4.4" 418 | }, 419 | "pyparsing": { 420 | "hashes": [ 421 | "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", 422 | "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" 423 | ], 424 | "version": "==2.4.5" 425 | }, 426 | "pytest": { 427 | "hashes": [ 428 | "sha256:63344a2e3bce2e4d522fd62b4fdebb647c019f1f9e4ca075debbd13219db4418", 429 | "sha256:f67403f33b2b1d25a6756184077394167fe5e2f9d8bdaab30707d19ccec35427" 430 | ], 431 | "index": "pypi", 432 | "version": "==5.3.1" 433 | }, 434 | "python-dateutil": { 435 | "hashes": [ 436 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", 437 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" 438 | ], 439 | "version": "==2.8.1" 440 | }, 441 | "six": { 442 | "hashes": [ 443 | "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", 444 | "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" 445 | ], 446 | "version": "==1.13.0" 447 | }, 448 | "traitlets": { 449 | "hashes": [ 450 | "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44", 451 | "sha256:d023ee369ddd2763310e4c3eae1ff649689440d4ae59d7485eb4cfbbe3e359f7" 452 | ], 453 | "version": "==4.3.3" 454 | }, 455 | "typed-ast": { 456 | "hashes": [ 457 | "sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161", 458 | "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", 459 | "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", 460 | "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", 461 | "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", 462 | "sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47", 463 | "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", 464 | "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", 465 | "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", 466 | "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", 467 | "sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2", 468 | "sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e", 469 | "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", 470 | "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", 471 | "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", 472 | "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", 473 | "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", 474 | "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", 475 | "sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66", 476 | "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" 477 | ], 478 | "markers": "implementation_name == 'cpython' and python_version < '3.8'", 479 | "version": "==1.4.0" 480 | }, 481 | "wcwidth": { 482 | "hashes": [ 483 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 484 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 485 | ], 486 | "version": "==0.1.7" 487 | }, 488 | "wrapt": { 489 | "hashes": [ 490 | "sha256:565a021fd19419476b9362b05eeaa094178de64f8361e44468f9e9d7843901e1" 491 | ], 492 | "version": "==1.11.2" 493 | }, 494 | "zipp": { 495 | "hashes": [ 496 | "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", 497 | "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" 498 | ], 499 | "version": "==0.6.0" 500 | } 501 | } 502 | } 503 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contextual Loss 2 | PyTorch implementation of Contextual Loss (CX) and Contextual Bilateral Loss (CoBi). 3 | 4 | ## Introduction 5 | There are many image transformation tasks whose spatially aligned data is hard to capture in the wild. 6 | Pixel-to-pixel or global loss functions can NOT be directly applied such unaligned data. 7 | CX is a loss function to defeat the problem. 8 | The key idea of CX is interpreting images as sets of feature points that don't have spatial coordinates. 9 | If you want to know more about CX, please refer the original [paper](https://arxiv.org/abs/1803.02077), [repo](https://github.com/roimehrez/contextualLoss) and examples in [./doc](./doc) directory. 10 | 11 | ## Requirements 12 | - Python3.7+ 13 | - `torch` & `torchvision` 14 | 15 | ## Installation 16 | ``` 17 | pip install git+https://github.com/S-aiueo32/contextual_loss_pytorch.git 18 | ``` 19 | 20 | ## Usage 21 | You can use it like PyTorch APIs. 22 | ```python 23 | import torch 24 | 25 | import contextual_loss as cl 26 | import contextual_loss.fuctional as F 27 | 28 | 29 | # input features 30 | img1 = torch.rand(1, 3, 96, 96) 31 | img2 = torch.rand(1, 3, 96, 96) 32 | 33 | # contextual loss 34 | criterion = cl.ContextualLoss() 35 | loss = criterion(img1, img2) 36 | 37 | # functional call 38 | loss = F.contextual_loss(img1, img2, band_width=0.1, loss_type='cosine') 39 | 40 | # comparing with VGG features 41 | # if `use_vgg` is set, VGG model will be created inside of the criterion 42 | criterion = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4') 43 | loss = criterion(img1, img2) 44 | 45 | ``` 46 | 47 | ## Reference 48 | ### Papers 49 | 1. Mechrez, Roey, Itamar Talmi, and Lihi Zelnik-Manor. "The contextual loss for image transformation with non-aligned data." Proceedings of the European Conference on Computer Vision (ECCV). 2018. 50 | 2. Mechrez, Roey, et al. "Maintaining natural image statistics with the contextual loss." Asian Conference on Computer Vision. Springer, Cham, 2018. 51 | ### Implementations 52 | Thanks to the owners of the following awesome implementations. 53 | - Original Repository: https://github.com/roimehrez/contextualLoss 54 | - Simple PyTorch Implemantation: https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da 55 | - CoBi: https://github.com/ceciliavision/zoom-learn-zoom 56 | -------------------------------------------------------------------------------- /contextual_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /contextual_loss/config.py: -------------------------------------------------------------------------------- 1 | # TODO: add supports for L1, L2 etc. 2 | LOSS_TYPES = ['cosine', 'l1', 'l2'] 3 | -------------------------------------------------------------------------------- /contextual_loss/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .config import LOSS_TYPES 5 | 6 | __all__ = ['contextual_loss', 'contextual_bilateral_loss'] 7 | 8 | 9 | def contextual_loss(x: torch.Tensor, 10 | y: torch.Tensor, 11 | band_width: float = 0.5, 12 | loss_type: str = 'cosine'): 13 | """ 14 | Computes contextual loss between x and y. 15 | The most of this code is copied from 16 | https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da. 17 | 18 | Parameters 19 | --- 20 | x : torch.Tensor 21 | features of shape (N, C, H, W). 22 | y : torch.Tensor 23 | features of shape (N, C, H, W). 24 | band_width : float, optional 25 | a band-width parameter used to convert distance to similarity. 26 | in the paper, this is described as :math:`h`. 27 | loss_type : str, optional 28 | a loss type to measure the distance between features. 29 | Note: `l1` and `l2` frequently raises OOM. 30 | 31 | Returns 32 | --- 33 | cx_loss : torch.Tensor 34 | contextual loss between x and y (Eq (1) in the paper) 35 | """ 36 | 37 | assert x.size() == y.size(), 'input tensor must have the same size.' 38 | assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.' 39 | 40 | N, C, H, W = x.size() 41 | 42 | if loss_type == 'cosine': 43 | dist_raw = compute_cosine_distance(x, y) 44 | elif loss_type == 'l1': 45 | dist_raw = compute_l1_distance(x, y) 46 | elif loss_type == 'l2': 47 | dist_raw = compute_l2_distance(x, y) 48 | 49 | dist_tilde = compute_relative_distance(dist_raw) 50 | cx = compute_cx(dist_tilde, band_width) 51 | cx = torch.mean(torch.max(cx, dim=1)[0], dim=1) # Eq(1) 52 | cx_loss = torch.mean(-torch.log(cx + 1e-5)) # Eq(5) 53 | 54 | return cx_loss 55 | 56 | 57 | # TODO: Operation check 58 | def contextual_bilateral_loss(x: torch.Tensor, 59 | y: torch.Tensor, 60 | weight_sp: float = 0.1, 61 | band_width: float = 1., 62 | loss_type: str = 'cosine'): 63 | """ 64 | Computes Contextual Bilateral (CoBi) Loss between x and y, 65 | proposed in https://arxiv.org/pdf/1905.05169.pdf. 66 | 67 | Parameters 68 | --- 69 | x : torch.Tensor 70 | features of shape (N, C, H, W). 71 | y : torch.Tensor 72 | features of shape (N, C, H, W). 73 | band_width : float, optional 74 | a band-width parameter used to convert distance to similarity. 75 | in the paper, this is described as :math:`h`. 76 | loss_type : str, optional 77 | a loss type to measure the distance between features. 78 | Note: `l1` and `l2` frequently raises OOM. 79 | 80 | Returns 81 | --- 82 | cx_loss : torch.Tensor 83 | contextual loss between x and y (Eq (1) in the paper). 84 | k_arg_max_NC : torch.Tensor 85 | indices to maximize similarity over channels. 86 | """ 87 | 88 | assert x.size() == y.size(), 'input tensor must have the same size.' 89 | assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.' 90 | 91 | # spatial loss 92 | grid = compute_meshgrid(x.shape).to(x.device) 93 | dist_raw = compute_l2_distance(grid, grid) 94 | dist_tilde = compute_relative_distance(dist_raw) 95 | cx_sp = compute_cx(dist_tilde, band_width) 96 | 97 | # feature loss 98 | if loss_type == 'cosine': 99 | dist_raw = compute_cosine_distance(x, y) 100 | elif loss_type == 'l1': 101 | dist_raw = compute_l1_distance(x, y) 102 | elif loss_type == 'l2': 103 | dist_raw = compute_l2_distance(x, y) 104 | dist_tilde = compute_relative_distance(dist_raw) 105 | cx_feat = compute_cx(dist_tilde, band_width) 106 | 107 | # combined loss 108 | cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp 109 | 110 | k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True) 111 | 112 | cx = k_max_NC.mean(dim=1) 113 | cx_loss = torch.mean(-torch.log(cx + 1e-5)) 114 | 115 | return cx_loss 116 | 117 | 118 | def compute_cx(dist_tilde, band_width): 119 | w = torch.exp((1 - dist_tilde) / band_width) # Eq(3) 120 | cx = w / torch.sum(w, dim=2, keepdim=True) # Eq(4) 121 | return cx 122 | 123 | 124 | def compute_relative_distance(dist_raw): 125 | dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True) 126 | dist_tilde = dist_raw / (dist_min + 1e-5) 127 | return dist_tilde 128 | 129 | 130 | def compute_cosine_distance(x, y): 131 | # mean shifting by channel-wise mean of `y`. 132 | y_mu = y.mean(dim=(0, 2, 3), keepdim=True) 133 | x_centered = x - y_mu 134 | y_centered = y - y_mu 135 | 136 | # L2 normalization 137 | x_normalized = F.normalize(x_centered, p=2, dim=1) 138 | y_normalized = F.normalize(y_centered, p=2, dim=1) 139 | 140 | # channel-wise vectorization 141 | N, C, *_ = x.size() 142 | x_normalized = x_normalized.reshape(N, C, -1) # (N, C, H*W) 143 | y_normalized = y_normalized.reshape(N, C, -1) # (N, C, H*W) 144 | 145 | # consine similarity 146 | cosine_sim = torch.bmm(x_normalized.transpose(1, 2), 147 | y_normalized) # (N, H*W, H*W) 148 | 149 | # convert to distance 150 | dist = 1 - cosine_sim 151 | 152 | return dist 153 | 154 | 155 | # TODO: Considering avoiding OOM. 156 | def compute_l1_distance(x: torch.Tensor, y: torch.Tensor): 157 | N, C, H, W = x.size() 158 | x_vec = x.view(N, C, -1) 159 | y_vec = y.view(N, C, -1) 160 | 161 | dist = x_vec.unsqueeze(2) - y_vec.unsqueeze(3) 162 | dist = dist.sum(dim=1).abs() 163 | dist = dist.transpose(1, 2).reshape(N, H*W, H*W) 164 | dist = dist.clamp(min=0.) 165 | 166 | return dist 167 | 168 | 169 | # TODO: Considering avoiding OOM. 170 | def compute_l2_distance(x, y): 171 | N, C, H, W = x.size() 172 | x_vec = x.view(N, C, -1) 173 | y_vec = y.view(N, C, -1) 174 | x_s = torch.sum(x_vec ** 2, dim=1) 175 | y_s = torch.sum(y_vec ** 2, dim=1) 176 | 177 | A = y_vec.transpose(1, 2) @ x_vec 178 | dist = y_s - 2 * A + x_s.transpose(0, 1) 179 | dist = dist.transpose(1, 2).reshape(N, H*W, H*W) 180 | dist = dist.clamp(min=0.) 181 | 182 | return dist 183 | 184 | 185 | def compute_meshgrid(shape): 186 | N, C, H, W = shape 187 | rows = torch.arange(0, H, dtype=torch.float32) / (H + 1) 188 | cols = torch.arange(0, W, dtype=torch.float32) / (W + 1) 189 | 190 | feature_grid = torch.meshgrid(rows, cols) 191 | feature_grid = torch.stack(feature_grid).unsqueeze(0) 192 | feature_grid = torch.cat([feature_grid for _ in range(N)], dim=0) 193 | 194 | return feature_grid 195 | -------------------------------------------------------------------------------- /contextual_loss/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .contextual import ContextualLoss 2 | from .contextual_bilateral import ContextualBilateralLoss 3 | 4 | __all__ = ['ContextualLoss', 'ContextualBilateralLoss'] 5 | -------------------------------------------------------------------------------- /contextual_loss/modules/contextual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vgg import VGG19 5 | from .. import functional as F 6 | from ..config import LOSS_TYPES 7 | 8 | 9 | class ContextualLoss(nn.Module): 10 | """ 11 | Creates a criterion that measures the contextual loss. 12 | 13 | Parameters 14 | --- 15 | band_width : int, optional 16 | a band_width parameter described as :math:`h` in the paper. 17 | use_vgg : bool, optional 18 | if you want to use VGG feature, set this `True`. 19 | vgg_layer : str, optional 20 | intermidiate layer name for VGG feature. 21 | Now we support layer names: 22 | `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']` 23 | """ 24 | 25 | def __init__(self, 26 | band_width: float = 0.5, 27 | loss_type: str = 'cosine', 28 | use_vgg: bool = False, 29 | vgg_layer: str = 'relu3_4'): 30 | 31 | super(ContextualLoss, self).__init__() 32 | 33 | assert band_width > 0, 'band_width parameter must be positive.' 34 | assert loss_type in LOSS_TYPES,\ 35 | f'select a loss type from {LOSS_TYPES}.' 36 | 37 | self.band_width = band_width 38 | 39 | if use_vgg: 40 | self.vgg_model = VGG19() 41 | self.vgg_layer = vgg_layer 42 | self.register_buffer( 43 | name='vgg_mean', 44 | tensor=torch.tensor( 45 | [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False) 46 | ) 47 | self.register_buffer( 48 | name='vgg_std', 49 | tensor=torch.tensor( 50 | [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False) 51 | ) 52 | 53 | def forward(self, x, y): 54 | if hasattr(self, 'vgg_model'): 55 | assert x.shape[1] == 3 and y.shape[1] == 3,\ 56 | 'VGG model takes 3 chennel images.' 57 | 58 | # normalization 59 | x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 60 | y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 61 | 62 | # picking up vgg feature maps 63 | x = getattr(self.vgg_model(x), self.vgg_layer) 64 | y = getattr(self.vgg_model(y), self.vgg_layer) 65 | 66 | return F.contextual_loss(x, y, self.band_width) 67 | -------------------------------------------------------------------------------- /contextual_loss/modules/contextual_bilateral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vgg import VGG19 5 | from .. import functional as F 6 | from ..config import LOSS_TYPES 7 | 8 | 9 | class ContextualBilateralLoss(nn.Module): 10 | """ 11 | Creates a criterion that measures the contextual bilateral loss. 12 | 13 | Parameters 14 | --- 15 | weight_sp : float, optional 16 | a balancing weight between spatial and feature loss. 17 | band_width : int, optional 18 | a band_width parameter described as :math:`h` in the paper. 19 | use_vgg : bool, optional 20 | if you want to use VGG feature, set this `True`. 21 | vgg_layer : str, optional 22 | intermidiate layer name for VGG feature. 23 | Now we support layer names: 24 | `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']` 25 | """ 26 | 27 | def __init__(self, 28 | weight_sp: float = 0.1, 29 | band_width: float = 0.5, 30 | loss_type: str = 'cosine', 31 | use_vgg: bool = False, 32 | vgg_layer: str = 'relu3_4'): 33 | 34 | super(ContextualBilateralLoss, self).__init__() 35 | 36 | assert band_width > 0, 'band_width parameter must be positive.' 37 | assert loss_type in LOSS_TYPES,\ 38 | f'select a loss type from {LOSS_TYPES}.' 39 | 40 | self.band_width = band_width 41 | 42 | if use_vgg: 43 | self.vgg_model = VGG19() 44 | self.vgg_layer = vgg_layer 45 | self.register_buffer( 46 | name='vgg_mean', 47 | tensor=torch.tensor( 48 | [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False) 49 | ) 50 | self.register_buffer( 51 | name='vgg_std', 52 | tensor=torch.tensor( 53 | [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False) 54 | ) 55 | 56 | def forward(self, x, y): 57 | if hasattr(self, 'vgg_model'): 58 | assert x.shape[1] == 3 and y.shape[1] == 3,\ 59 | 'VGG model takes 3 chennel images.' 60 | 61 | # normalization 62 | x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 63 | y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 64 | 65 | # picking up vgg feature maps 66 | x = getattr(self.vgg_model(x), self.vgg_layer) 67 | y = getattr(self.vgg_model(y), self.vgg_layer) 68 | 69 | return F.contextual_bilateral_loss(x, y, self.band_width) 70 | -------------------------------------------------------------------------------- /contextual_loss/modules/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch.nn as nn 4 | import torchvision.models.vgg as vgg 5 | 6 | 7 | class VGG19(nn.Module): 8 | def __init__(self, requires_grad=False): 9 | super(VGG19, self).__init__() 10 | vgg_pretrained_features = vgg.vgg19(pretrained=True).features 11 | self.slice1 = nn.Sequential() 12 | self.slice2 = nn.Sequential() 13 | self.slice3 = nn.Sequential() 14 | self.slice4 = nn.Sequential() 15 | self.slice5 = nn.Sequential() 16 | for x in range(4): 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 18): 21 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(18, 27): 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(27, 36): 25 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 26 | if not requires_grad: 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, X): 31 | h = self.slice1(X) 32 | h_relu1_2 = h 33 | h = self.slice2(h) 34 | h_relu2_2 = h 35 | h = self.slice3(h) 36 | h_relu3_4 = h 37 | h = self.slice4(h) 38 | h_relu4_4 = h 39 | h = self.slice5(h) 40 | h_relu5_4 = h 41 | 42 | vgg_outputs = namedtuple( 43 | "VggOutputs", ['relu1_2', 'relu2_2', 44 | 'relu3_4', 'relu4_4', 'relu5_4']) 45 | out = vgg_outputs(h_relu1_2, h_relu2_2, 46 | h_relu3_4, h_relu4_4, h_relu5_4) 47 | 48 | return out 49 | -------------------------------------------------------------------------------- /doc/small_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Small exmaple of Contextual Loss (CX)\n", 8 | "[Open with Colab](https://colab.research.google.com/github/S-aiueo32/contextual_loss_pytorch/blob/master/doc/small_example.ipynb)\n", 9 | "\n", 10 | "I will show you a small example of CX here.\n", 11 | "\n", 12 | "## Calculate CX\n", 13 | "For the first, import dependencies." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Create images called $x$ and $y$ in the paper.\n", 31 | "You can see the pixels whose spatial correspondings have similar looks.\n", 32 | "For simplicity, I will use the RGB values as the features through this notebook." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | }, 50 | { 51 | "data": { 52 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAABWCAYAAADMpy0uAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAHDElEQVR4nO3dX4gdZx3G8efpJmlIU2NstK5paCtJi1XR2rUgxT/0D229aApaaC8khZZVJKg3YkSQ0hujFy0IvQk2UEXaShVcSyAobSlBLVlLrU1CmjUXbUI0bVISo2i68niRIxzWs93amXNmzrzfDyw7c+Zl3t+c/Z2Hs7Mze5xEAIDuO6/pAgAAo0HgA0AhCHwAKASBDwCFIPABoBAEPgAUolLg236P7V/bPtT7vnaRcf+2/ULva6bKnMAo0NvoIle5Dt/2DySdTLLd9jZJa5N8a8C4M0lWV6gTGCl6G11UNfAPSvpckmO2JyU9k+TKAeN4UWCs0Nvooqrn8C9Ocqy3/BdJFy8ybqXtWdu/t317xTmBUaC30TnLlhpg+zeS3j9g03f6V5LE9mK/Llya5KjtD0p6yvafkvx5wFzTkqYladXEeddsXL1qyQMYB4c+0I3jkCT//X1Nl1Cbf7zy0puSXh6waai9fcHKlddcsWF9xerb4dSq5U2XUBvPr2y6hFocP/qKTr9xwoO2LRn4SW5cbJvtv9qe7Pu19/gi+zja+37Y9jOSrpb0Py+KJDsk7ZCkj737wuz+9MeXKm8s3HzfJ5suoTYrZr/adAm1mf3KpheTTA3aNsze/sQVG7Pnhw/UdBTN2jX13qZLqM3Eax9uuoRafPOOzy66reopnRlJW3rLWyT9cuEA22ttn99bXifpOkn7K84LDBu9jc6pGvjbJd1k+5CkG3vrsj1l+0e9MR+SNGv7j5KelrQ9CS8KtB29jc5Z8pTOW0lyQtINAx6flXRvb/m3kj5aZR5g1OhtdBF32gJAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BCEPgAUAgCHwAKQeADQCEIfAAoBIEPAIUg8AGgEAQ+ABSCwAeAQhD4AFAIAh8ACkHgA0Ahagl827fYPmh7zva2AdvPt/14b/tzti+rY15g2OhtdEnlwLc9IekhSbdKukrSXbavWjDsHklvJNko6UFJ3686LzBs9Da6po53+NdKmktyOMlZSY9J2rxgzGZJj/SWn5B0g23XMDcwTPQ2OqWOwF8v6dW+9SO9xwaOSTIv6ZSkixbuyPa07VnbsyfOvllDaUAlQ+nt10+dHlK5wFtr1R9tk+xIMpVk6qIVy5suB6hNf2+vW/OupstBoeoI/KOSNvStX9J7bOAY28skrZF0ooa5gWGit9EpdQT+XkmbbF9ue4WkOyXNLBgzI2lLb/mLkp5KkhrmBoaJ3kanLKu6gyTztrdK2i1pQtLOJPts3y9pNsmMpIcl/cT2nKSTOvfCAVqN3kbXVA58SUqyS9KuBY99t2/5n5LuqGMuYJTobXRJq/5oCwAYHgIfAApB4ANAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BCEPgAUAgCHwAKQeADQCEIfAAoBIEPAIUg8AGgEAQ+ABSCwAeAQtQS+LZvsX3Q9pztbQO23237Ndsv9L7urWNeYNjobXRJ5U+8sj0h6SFJN0k6Immv7Zkk+xcMfTzJ1qrzAaNCb6Nr6niHf62kuSSHk5yV9JikzTXsF2gavY1OqSPw10t6tW/9SO+xhb5g+0XbT9jeUMO8wLDR2+iUWj7E/G34laRHk/zL9pclPSLp+oWDbE9Lmu6tnpl8cs/BEdS2TtLrQ53hyT1D3X2f4R+LHhzu7s8ZwXFIki6tYR/vqLcvuHXzsHt7VM/hKHAs/59F+9pJKu3Z9qck3Zfk5t76tyUpyfcWGT8h6WSSNZUmront2SRTTddRh64cS1uOY5x7uy3PYR04lvrUcUpnr6RNti+3vULSnZJm+gfYnuxbvU3SgRrmBYaN3kanVD6lk2Te9lZJuyVNSNqZZJ/t+yXNJpmR9DXbt0mal3RS0t1V5wWGjd5G11Q+pTPubE8n2dF0HXXoyrF05Tia1KXnkGOpcf7SAx8ASsG/VgCAQhQb+EvdMj8ubO+0fdz2S03XUpXtDbaftr3f9j7bX2+6pnFEb7dLm/q6yFM6vcvnXlbfLfOS7hpwy3zr2f6MpDOSfpzkI03XU0XvipfJJM/bvlDSHyTdPo4/l6bQ2+3Tpr4u9R1+Z26ZT/Kszl0dMvaSHEvyfG/5bzp3ieOgO1uxOHq7ZdrU16UG/tu9ZR4NsX2ZpKslPddsJWOH3m6xpvu61MBHi9leLennkr6R5HTT9QB1aENflxr4RyX1/5OrS3qPoWG2l+vci+KnSX7RdD1jiN5uobb0damBv+Qt8xg925b0sKQDSR5oup4xRW+3TJv6usjATzIv6b+3zB+Q9LMk+5qt6p2x/aik30m60vYR2/c0XVMF10n6kqTr+z5B6vNNFzVO6O1Wak1fF3lZJgCUqMh3+ABQIgIfAApB4ANAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BC/Acw2kFzNIE/cgAAAABJRU5ErkJggg==\n", 53 | "text/plain": [ 54 | "
" 55 | ] 56 | }, 57 | "metadata": { 58 | "needs_background": "light" 59 | }, 60 | "output_type": "display_data" 61 | } 62 | ], 63 | "source": [ 64 | "x = np.array([\n", 65 | " [231, 76, 60],\n", 66 | " [46, 204, 113],\n", 67 | " [52, 152, 219],\n", 68 | "]).reshape(1, 3, 3) / 255\n", 69 | "y = np.array([\n", 70 | " [245, 183, 177],\n", 71 | " [171, 235, 198],\n", 72 | " [174, 214, 241],\n", 73 | "]).reshape(1, 3, 3) / 255\n", 74 | "\n", 75 | "plt.figure()\n", 76 | "plt.subplot(1, 2, 1)\n", 77 | "plt.imshow(x)\n", 78 | "plt.subplot(1, 2, 2)\n", 79 | "plt.imshow(y)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Along with the key idea, convert $x$, $y$ to $X$, $Y$ respectively." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 3, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "X = x.reshape(-1, 3)\n", 96 | "Y = y.reshape(-1, 3)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "Calculate cosine distances between all points." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# normalize\n", 113 | "mu = Y.mean(axis=0, keepdims=True)\n", 114 | "X_centered = X -mu\n", 115 | "Y_centered = Y -mu\n", 116 | "X_normalized = X_centered / np.linalg.norm(X_centered, ord=2, axis=1, keepdims=True)\n", 117 | "Y_normalized = Y_centered / np.linalg.norm(Y_centered, ord=2, axis=1, keepdims=True)\n", 118 | "\n", 119 | "# cosine distance\n", 120 | "d = 1 - np.matmul(X_normalized, Y_normalized.transpose())" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Looking at the heatmap of `d`, you confirm the correspondings are similar." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "image/png": "\n", 138 | "text/plain": [ 139 | "
" 140 | ] 141 | }, 142 | "metadata": { 143 | "needs_background": "light" 144 | }, 145 | "output_type": "display_data" 146 | } 147 | ], 148 | "source": [ 149 | "plt.imshow(d)\n", 150 | "plt.colorbar()\n", 151 | "plt.title('Distance Map');\n", 152 | "plt.ylabel('x_i');\n", 153 | "plt.xlabel('y_j');" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Along with Eq.(2) in the paper, normalize using the minimum.\n", 161 | "Through this process, the most similar point over $y_j$ becomes `1.`" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "image/png": "\n", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": { 177 | "needs_background": "light" 178 | }, 179 | "output_type": "display_data" 180 | } 181 | ], 182 | "source": [ 183 | "d_tilde = d / (d.min(axis=1, keepdims=True) + 1e-5)\n", 184 | "\n", 185 | "plt.imshow(d_tilde)\n", 186 | "plt.colorbar()\n", 187 | "plt.title('Normalized Distance Map');\n", 188 | "plt.ylabel('x_i');\n", 189 | "plt.xlabel('y_j');" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Convert to similarity showed in Eq.(3)." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAEXCAYAAADPzN0RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaVUlEQVR4nO3de9RddX3n8feHCGQQEDAdQQhExyim6BRMCRQvtN4CdYhrqZ3QVsSFk/FCxwu2pTqDLZ2ZKjPjjI4sMVWGi5SL6MJHJ64IioOtBomICYEiAbVcolxNQBHI83zmj70fejjZ55LknPM7z5PPa629ztln/85vf58DfPnt/bts2SYiIra1W+kAIiLGVRJkREQHSZARER0kQUZEdJAEGRHRQRJkREQHSZC7MEl/JOnrO/jdV0i6rWX/J5JesxOxPCrp+Tv6/YhhSIKc5SS9XNJ3JG2W9JCkf5D02wC2L7H9uh2p1/a3bb9oUHHa3tv2nXXMF0j6zztaV52sn5A0r+3zH0iypAU7F23sKpIgZzFJ+wJfBf43cABwMPBXwOMl42ol6RlDqvrHwMkt53kJsNeQzhWzVBLk7PZCANuX2p60/Zjtr9teByDpVEl/P124bl29W9Ltkh6R9NeS/lXdAt0i6QpJe9Rlj5d0d9NJJR0t6buSfiFpk6RPTX+v5TzvkXQ7cHvLZy+QtAL4I+DP6svur0j6U0lfbDvHJyV9osvffjFwSsv+24CL2ur4/bpVuUXSXZL+suXYgjqmFZLurf+OD3Y5X8xCSZCz24+ASUkXSjpB0v59fOf1wMuAY4A/A1YCfwzMB46gpVXWxSTwfmAecCzwauDdbWXeCCwBFrV+aHslcAlwTn3Z/W+AzwNLJe0HT7U6l9OW8NqsAfaV9GJJc+ryn28r80uqJLof8PvAuyS9sa3M7wILgdcBf74z91lj5kmCnMVsbwFeDhj4W+B+SROSntPla+fY3mJ7A3Az8HXbd9reDHwNOLKP837f9hrbW23/BPgM8Kq2Yn9j+yHbj/VR3ybgOuAt9UdLgQdsf7/HV6dbka8FbgXuaav3W7bX256qW9WXNsT5V7Z/aXs98H/o738QMUskQc5ytm+1fartQ6hagM8F/leXr/y85f1jDft79zqnpBdK+qqkn0naAvxXqtZkq7v6+gP+2YVULVnq14v7+M7FwB8Cp9LQ2pS0RNK1ku6XtBl4Z484f0r1+8UuIglyF2L7H4ELqBLlMH0a+Edgoe19gQ8Bag+ny/ebjl0FvFTSEcAbqC7Du7L9U6rOmhOBLzUU+TtgAphv+1nAeQ1xzm95fyhwb6/zxuyRBDmLSTpc0hmSDqn351NdIq4Z8qn3AbYAj0o6HHjXdn7/58DTxkTa/jVwJVVS+57tf+qzrtOA37P9yw5xPmT715KOpmpttvtPkvaS9JvA24HL+/0jYuZLgpzdHqHqCLle0i+pEuPNwBlDPu8HqZLNI1T3Prc3qXwOWFT3gl/V8vmFwEvo7/IaANt32F7b4fC7gbMlPQKcBVzRUOb/ARuBbwD/3fYODayPmUlZMDdmCkmHUl26H1h3QA3zXAuoLs93t711mOeK8ZUWZMwIknYDPgBcNuzkGDGtWIKUdICkq+tByVd3GqMnaVLSTfU2Meo4ozxJz6S6p/la4COFw4kxJel8SfdJurnDcdUTDDZKWifpqJ51lrrElnQO1Q3yj0o6E9jf9p83lHvUds+hJRGxa5P0SuBR4CLb24zUkHQi8CdUoxqWAJ+wvaRbnSUvsZdR3XSnfm2fwRAR0Tfb1wEPdSmyjCp52vYaYD9JB3Wrc1gLBfTjOfUMCYCfAZ1md8yVtBbYCnzU9lVNheo5vCsAnrmXXnb4C/ZoKhbAj9ZlzYbYeY/w8AO2f2Nn6nj97z7TDz402bPc99c9vtr20p05F9ViLa0D/++uP9vUXHzICVLSNcCBDYc+3Lpj25I6XesfZvueeq3Ab0pab/uO9kL1HN6VAIv/9Vx/b/X89iJRe/1zf6t0CDELXOMrf7qzdTz40CTfW31oz3JzDrr98LqhNG1l/d/8UA01QdruOLFf0s8lHWR7U93Mva9DHffUr3dK+hbVXOBtEmREzDwGppjqp+gDthfv5Onu4ekzow6hbX5+u5L3ICeolqCifv1yewFJ+0vas34/DzgOuGVkEUbEUBnzpCd7bgMyAZxS92YfA2xuuc3XqOQ9yI8CV0g6jWoRgD8AkLQYeKftdwAvBj4jaYoqmX/UdhJkxCzSZwuyJ0mXAscD8+q1Sj8C7A5g+zxgFVUP9kbgV1RTR7sqliBtP0i1TmD752uBd9Tvv0M1tSwiZiFjJgc01NB216XoXI1pfM/21FmyBRkRwVTXhZ3KSoKMiGIMTCZBRkQ0SwsyIqKBgSfHeEWxJMiIKMY4l9gREY0Mk+ObH5MgI6KcaibN+EqCjIiCxOQ2z0kbH0mQEVGMgalcYkdEbMvAE2P85JckyIgoasq5xI6I2EY1kyYJMiJiG0ZM5hI7IqJZLrEjIhoY8YTnlA6joyTIiCimGiieS+yIiEbppImIaGCLSacFGRHRaCotyIiIbVXjINOCjIjYhhFPenzT0PhGFhG7hMmMg4yI2FZm0kREdDGVXuyIiG2lkyYiogOj3IOMiGhiM9a92MXbtpKWSrpN0kZJZzYc31PS5fXx6yUtGH2UETEcYqqPrZSiCVLSHOBc4ARgEXCypEVtxU4DHrb9AuB/Ah8bbZQRMSwGJr1bz62U0i3Io4GNtu+0/QRwGbCsrcwy4ML6/ZXAqyWN702LiNguk+zWcyuldII8GLirZf/u+rPGMra3ApuBZ48kuogYKiOm3HsrZXzvjm4nSSuAFQCHHjxr/qyIWc2kk6abe4D5LfuH1J81lpH0DOBZwIPtFdleaXux7cW/8ezxXaE4IlqJyT62UkonyBuAhZKeJ2kPYDkw0VZmAnhb/f7NwDdtj/GjxiOiX6aaSdNrK6Vo29b2VkmnA6uBOcD5tjdIOhtYa3sC+BxwsaSNwENUSTQiZomsKN6F7VXAqrbPzmp5/2vgLaOOKyKGz9bAWoiSlgKfoGpsfdb2R9uOH0o1Ima/usyZdf7pqHiCjIhd2yDGObaMqX4t1WiYGyRN2L6lpdh/BK6w/el6vPUqYEG3epMgI6KYasHcgXSqPjWmGkDS9Jjq1gRpYN/6/bOAe3tVmgQZEcVUnTR93YOcJ2lty/5K2ytb9pvGVC9pq+Mvga9L+hPgmcBrep00CTIiiupzpswDthfv5KlOBi6w/T8kHUvV+XuE7alOX0iCjIhipmfSDEA/Y6pPA5YC2P6upLnAPOC+TpWWHgcZEbu4KXbrufWhnzHV/wS8GkDSi4G5wP3dKk0LMiKKsQfz0K4+x1SfAfytpPdT3f48tdekkyTIiCjGiK1Tg5ka3MeY6luA47anziTIiCgqM2kiIhpsxzCfIpIgI6KgwU01HIYkyIgoquQzZ3pJgoyIYmx4ckCdNMOQBBkRxQxwoPhQJEFGRFG5xI6IaJBe7IiILtKLHRHRpPBjXXtJgoyIYgxsTQsyImJbuQcZEdFFEmRERIOMg4yI6CLjICMimjiX2BERjQxsnUovdkTENnIPMiKiCydBRkQ0G+dOmuIX/5KWSrpN0kZJZzYcP1XS/ZJuqrd3lIgzIgbPdSdNr62Uoi1ISXOAc4HXAncDN0iaqJ8+1upy26ePPMCIGDIxOcadNKUjOxrYaPtO208AlwHLCscUESNkq+dWSul7kAcDd7Xs3w0saSj3JkmvBH4EvN/2Xe0FJK0AVgDMZS9e/9zfGkK4s8Pqe28qHcLYy78/ozHuc7FLtyD78RVgge2XAlcDFzYVsr3S9mLbi3dnz5EGGBE7yNV9yF5bKaUT5D3A/Jb9Q+rPnmL7QduP17ufBV42otgiYgSmUM+tlNIJ8gZgoaTnSdoDWA5MtBaQdFDL7knArSOMLyKGyOQeZEe2t0o6HVgNzAHOt71B0tnAWtsTwH+QdBKwFXgIOLVYwBExYGJyanzvQZbupMH2KmBV22dntbz/C+AvRh1XRIxGZtJERDSoOmGSICMiGo3zMJ8kyIgoquQwnl6SICOiGCOmxniqYRJkRBQ1xg3I4uMgI2JX5sGNg+y1Mlhd5g8k3SJpg6S/61VnWpARUdYAmpD9rAwmaSHVkMHjbD8s6V/2qjctyIgoakAtyH5WBvt3wLm2H67O6/t6VZoEGRFF9blYxTxJa1u2FW3VNK0MdnBbmRcCL5T0D5LWSFraK7ZcYkdEMTa4v17sB2wv3snTPQNYCBxPtTDOdZJeYvsXnb6QFmREFDWg5c56rgxG1aqcsP2k7R9TrS+7sFulSZARUZb72HrruTIYcBVV6xFJ86guue/sVmkusSOioMEsZ9bnymCrgddJugWYBP7U9oPd6k2CjIiyBjRSvI+VwQx8oN76kgQZEeVkNZ+IiC6SICMiOhjjydhJkBFRVhJkREQDk0vsiIhOsmBuREQneaphREQzpQUZEdGg/6mERSRBRkRBSidNRERHaUFGRHSQBBkR0cCMdS921/UgJV1Rv66XtK5lWy9p3SACkHS+pPsk3dzhuCR9sn5S2TpJRw3ivBExHuTeWym9WpDvrV/fMMQYLgA+BVzU4fgJVKv+LgSWAJ+uXyNiNpipl9i2N9WvP+1WTtJ3bR+7IwHYvk7Sgi5FlgEX1Wu5rZG0n6SDpmOLiBiWQT1yYe6A6mnSz9PKkLRi+olnT/L4EMOJiEGayZfY/SreSLa9ElgJsK8OKB5PRPQp4yB3Sj9PK4uImcjAVOkgOuvrElvSoobPjm/dHVRADSaAU+re7GOAzbn/GDF7zIZL7CskXQycQ3W/8RxgMTDdMfPWHQ1A0qVUj2KcJ+lu4CPA7gC2z6N6CM+JwEbgV8Dbd/RcETGGxviGWL8JcgnwMeA7wD7AJcBx0wdtN45h7Iftk3scN/CeHa0/IsbcLEiQTwKPAf+CqgX5Y9tjfOcgImaC0pfQvfQ7zOcGqgT528ArgJMlfWFoUUXErmNKvbdC+m1BnmZ7bf1+E7BM0g7fd4yImDbOLci+EmRLcmz97OLBhxMRu5yZniAjIoZizO9BJkFGRFlJkBERHSRBRkQ0yyV2REQnSZAREQ3SSRMR0UUSZEREB0mQERHbEuN9iT2oRy5ERGw/g6Z6b/2QtFTSbfUTUM/sUu5Nkixpca86kyAjoiz3sfUgaQ5wLtVTUBdRLajTtND3PlRPa72+n9CSICOirAEkSOBoYKPtO20/AVxG9UTUdn9Ntbbtr/upNAkyIorq85EL86afWlpvK9qq6fn0U0lHAfNt/99+Y0snTUSU1V8L8QHbPe8ZdiJpN+DjwKnb870kyIgox/13wvTQ6+mn+wBHAN+SBHAgMCHppKblHKclQUZEWYMZ5nMDsFDS86gS43LgD586hb0ZmDe9L+lbwAe7JUfIPciIKGwQj321vRU4HVgN3ApcYXuDpLMlnbSjsaUFGRFlDWiguO1VVI+Jbv3srA5lj++nziTIiCin/2E8RSRBRkQxqrdxlQQZEUUNqBd7KJIgI6KsXGJHRHQwxgmy6DAfSedLuk/SzR2OHy9ps6Sb6q2xRyoiZqg+hviUXA6tdAvyAuBTwEVdynzb9htGE05EjNwYtyCLJkjb10laUDKGiCgrnTQ751hJPwTupZoatKGpUL26xwqAuew1wvBmnhNe9IrSIYy91fd+u3QIY2/OQYOpZ5xXFB/3BHkjcJjtRyWdCFwFLGwqaHslsBJgXx0wxj95RDxlzAeKj/VcbNtbbD9av18F7C5pXo+vRcRMMpgFc4dirBOkpANVr00k6WiqeB8sG1VEDMr0Q7vSi91A0qXA8VSrBd8NfATYHcD2ecCbgXdJ2go8Biy3PcYN8ojYbmP8X3TpXuyTexz/FNUwoIiYjQyaGt8MOe6dNBExy6UXOyKikyTIiIhmaUFGRHSSBBkR0aDwMJ5ekiAjohiRudgREZ2N8dDmJMiIKCqX2BERTcZ8sYokyIgoKvcgIyI6SIKMiGhi0kkTEdFJOmkiIjpJgoyI2Nb0grnjKgkyIsqxcw8yIqKT9GJHRHSQS+yIiCYG8siFiIgOxjc/jvdjXyNi9hvUY18lLZV0m6SNks5sOP4BSbdIWifpG5IO61VnEmRElDXdk91t60HSHOBc4ARgEXCypEVtxX4ALLb9UuBK4Jxe9SZBRkQ5rnqxe219OBrYaPtO208AlwHLnnYq+1rbv6p31wCH9Ko09yAjophqoHhf19DzJK1t2V9pe2XL/sHAXS37dwNLutR3GvC1XidNgoyIsvprIT5ge/EgTifpj4HFwKt6lU2CjIii+mxB9nIPML9l/5D6s6efS3oN8GHgVbYf71Vp7kFGRDnuc+vtBmChpOdJ2gNYDky0FpB0JPAZ4CTb9/VTadEEKWm+pGvrrvcNkt7bUEaSPll33a+TdFSJWCNiGIymem89a7G3AqcDq4FbgStsb5B0tqST6mL/Ddgb+IKkmyRNdKjuKaUvsbcCZ9i+UdI+wPclXW37lpYyJwAL620J8Gm633yNiJlkQItV2F4FrGr77KyW96/Z3jqLtiBtb7J9Y/3+EarMf3BbsWXARa6sAfaTdNCIQ42IYRjcMJ+hGJt7kJIWAEcC17cdauq+b0+iETFTDWCg+LCUvsQGQNLewBeB99nesoN1rABWAMxlrwFGFxFDNcZzsYsnSEm7UyXHS2x/qaFIX9339aDRlQD76oAx/skjotWAhvkMRelebAGfA261/fEOxSaAU+re7GOAzbY3jSzIiBgeA5PuvRVSugV5HPBWYL2km+rPPgQcCmD7PKpeqROBjcCvgLcXiDMihkB4rFuQRROk7b+nmo7ZrYyB94wmoogYuSTIiIgOkiAjIhqYfherKCIJMiKKyj3IiIhGhqnxbUImQUZEOSb3ICMiOhrfBmQSZESUlXuQERGdJEFGRDSwYXJ8r7GTICOirLQgIyI6SIKMiGhgoI9nzpSSBBkRBRmce5AREc1yiR0R0cCkFzsioqO0ICMimpR9amEvSZARUY7Jaj4RER2lBRkR0UESZEREAxtPTpaOoqMkyIgoKzNpIiI6yCV2REQD55k0ERGdpQUZEdEknTQREc2y3FlERBdjvNzZbiVPLmm+pGsl3SJpg6T3NpQ5XtJmSTfV21klYo2IwTPgKffc+iFpqaTbJG2UdGbD8T0lXV4fv17Sgl51lm5BbgXOsH2jpH2A70u62vYtbeW+bfsNBeKLiGHyYBbMlTQHOBd4LXA3cIOkibZcchrwsO0XSFoOfAz4t93qLdqCtL3J9o31+0eAW4GDS8YUEaM1oBbk0cBG23fafgK4DFjWVmYZcGH9/krg1ZLUrdLSLcin1M3dI4HrGw4fK+mHwL3AB21vaPj+CmBFvfv4Nb7y5iGFuqPmAQ+UDgKALcA4xVMZq3jmHDRe8dTGLaYX7WwFj/Dw6mumrpjXR9G5kta27K+0vbJl/2Dgrpb9u4ElbXU8Vcb2VkmbgWfT5TcdiwQpaW/gi8D7bG9pO3wjcJjtRyWdCFwFLGyvo/6xVtb1rbW9eMhhb5dxiynxdDdu8cD4xdSWsHaI7aWDiGVYil5iA0janSo5XmL7S+3HbW+x/Wj9fhWwu6R+/o8TEbuOe4D5LfuH1J81lpH0DOBZwIPdKi3diy3gc8Cttj/eocyB0/cJJB1NFXPXPyoidjk3AAslPU/SHsByYKKtzATwtvr9m4Fv2t2n8ZS+xD4OeCuwXtJN9WcfAg4FsH0e1R/yLklbgceA5b3+KOpL7TEzbjElnu7GLR4Yv5jGJp76nuLpwGpgDnC+7Q2SzgbW2p6gaoxdLGkj8BBVEu1KvXNNRMSuqfg9yIiIcZUEGRHRwaxIkJIOkHS1pNvr1/07lJtsmbLYfgN3EHEMfKrTCGI6VdL9Lb/LO4YYy/mS7pPUOEZVlU/Wsa6TdNSwYukznpFOc+1z6u2of6Ndezqw7Rm/AecAZ9bvzwQ+1qHco0OMYQ5wB/B8YA/gh8CitjLvBs6r3y8HLh/y79JPTKcCnxrRP6dXAkcBN3c4fiLwNUDAMcD1heM5HvjqKH6b+nwHAUfV7/cBftTwz2vUv1E/MY30dxrlNitakDx9CtGFwBsLxDCUqU4jiGlkbF9H1XvYyTLgIlfWAPtJOqhgPCPl/qbejvo32qWnA8+WBPkc25vq9z8DntOh3FxJayWtkTToJNo01an9X6SnTXUCpqc6DUs/MQG8qb5cu1LS/Ibjo9JvvKN0rKQfSvqapN8c1Um7TL0t9hv1Mx141L/TsJUeB9k3SdcABzYc+nDrjm1L6jR26TDb90h6PvBNSett3zHoWGeYrwCX2n5c0r+nauH+XuGYxkVf01wHrcfU2yIGMR14JpoxLUjbr7F9RMP2ZeDn05cZ9et9Heq4p369E/gW1f8NB2UoU52GHZPtB20/Xu9+FnjZEOPppZ/fcGRcYJprr6m3FPiNduXpwDMmQfbQOoXobcCX2wtI2l/SnvX7eVSzeNrXndwZQ5nqNOyY2u5fnUR1j6mUCeCUuqf2GGBzy62TkdOIp7nW5+o69ZYR/0b9xDTq32mkSvcSDWKjuo/3DeB24BrggPrzxcBn6/e/A6yn6sldD5w2hDhOpOrluwP4cP3Z2cBJ9fu5wBeAjcD3gOeP4LfpFdPfABvq3+Va4PAhxnIpsAl4kure2WnAO4F31sdFtejpHfU/o8VD/m16xXN6y2+zBvidIcfzcqpFttcBN9XbiYV/o35iGunvNMotUw0jIjqYLZfYEREDlwQZEdFBEmRERAdJkBERHSRBRkR0kAQZEdFBEmSMHUnPlXRl6TgiMg4yIqKDtCBjZCSdLel9Lfv/pcMCrAs6LWIbMUpJkDFK5wOnAEjajWpu+OeLRhTRxYxZ7ixmPts/kfSgpCOp1uz8ge3ZsahBzEpJkDFqn6V6zMOBVC3KiLGVTpoYqXrZtfXA7sBC25MNZRZQPePkiNFGF/F0aUHGSNl+QtK1wC+akmNr0VHFFNFJEmSMVN05cwzwli7Fns0YPUwrdl3pxY6RkbSIarHgb9i+vUOZxVQL2X5ilLFFNMk9yChG0kuAi9s+ftz2khLxRLRLgoyI6CCX2BERHSRBRkR0kAQZEdFBEmRERAf/HyfBk5lQwXMgAAAAAElFTkSuQmCC\n", 207 | "text/plain": [ 208 | "
" 209 | ] 210 | }, 211 | "metadata": { 212 | "needs_background": "light" 213 | }, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "w = np.exp((1 - d_tilde) / 0.1)\n", 219 | "\n", 220 | "plt.imshow(w)\n", 221 | "plt.colorbar()\n", 222 | "plt.title('Similarity Map');\n", 223 | "plt.ylabel('x_i');\n", 224 | "plt.xlabel('y_j');" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "The bandwidth parameter is set as `0.1` above.\n", 232 | "The setting is very important to get the acceptable result i.e. the inappropriate value prevents Top-1 feature enhancement like below." 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "image/png": "\n", 243 | "text/plain": [ 244 | "
" 245 | ] 246 | }, 247 | "metadata": { 248 | "needs_background": "light" 249 | }, 250 | "output_type": "display_data" 251 | } 252 | ], 253 | "source": [ 254 | "w_ = np.exp((1 - d_tilde) / 0.8)\n", 255 | "\n", 256 | "plt.imshow(w_)\n", 257 | "plt.colorbar()\n", 258 | "plt.title('Similarity Map (failure case)');\n", 259 | "plt.ylabel('x_i');\n", 260 | "plt.xlabel('y_j');" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "Normalize and get final result." 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 9, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | "CX: 0.9878605414928291\n" 280 | ] 281 | } 282 | ], 283 | "source": [ 284 | "cx_ij = w / np.sum(w, axis=1, keepdims=True) # normalize\n", 285 | "cx = np.mean(np.max(cx_ij, axis=0))\n", 286 | "print(f'CX: {cx}')" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "## The Robustness of CX\n", 294 | "CX is the metric invariant misalignment of images.\n", 295 | "To show the robustness, shuffle the pixels of $y$ and measure CX on all patterns." 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 10, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "[trial 0]: 0.9878605414928291\n", 308 | "[trial 1]: 0.9878605414928291\n", 309 | "[trial 2]: 0.9878605414928291\n", 310 | "[trial 3]: 0.9878605414928291\n", 311 | "[trial 4]: 0.9878605414928291\n", 312 | "[trial 5]: 0.9878605414928291\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "from itertools import permutations\n", 318 | "\n", 319 | "def compute_cx(x, y): # integrate as a function\n", 320 | " X = x.reshape(-1, 3)\n", 321 | " Y = y.reshape(-1, 3)\n", 322 | "\n", 323 | " mu = Y.mean(axis=0, keepdims=True)\n", 324 | " X_centered = X -mu\n", 325 | " Y_centered = Y -mu\n", 326 | " X_normalized = X_centered / np.linalg.norm(X_centered, ord=2, axis=1, keepdims=True)\n", 327 | " Y_normalized = Y_centered / np.linalg.norm(Y_centered, ord=2, axis=1, keepdims=True)\n", 328 | "\n", 329 | " d = 1 - np.matmul(X_normalized, Y_normalized.transpose())\n", 330 | " d_tilde = d / (d.min(axis=1, keepdims=True) + 1e-5)\n", 331 | " w = np.exp((1 - d_tilde) / 0.1)\n", 332 | " cx_ij = w / np.sum(w, axis=1, keepdims=True)\n", 333 | "\n", 334 | " return np.mean(np.max(cx_ij, axis=0))\n", 335 | "\n", 336 | "for i, p in enumerate(permutations([0, 1, 2])):\n", 337 | " y_ = y[:, p, :]\n", 338 | " cx = compute_cx(x, y_)\n", 339 | " print(f'[trial {i}]: {cx}')" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "Did you cnofirm all of CXs is same? It's robustness!\n", 347 | "\n", 348 | "That's all, happy CX!" 349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.7.4" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 4 373 | } 374 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='contextual_loss_pytorch', 5 | version='latest', 6 | description='Contextual Loss w/ PyTorch', 7 | packages=find_packages(exclude=('tests', 'doc')), 8 | author='So Uchida', 9 | author_email='s.aiueo32@gmail.com', 10 | install_requires=["torch", "torchvision"], 11 | url='https://github.com/S-aiueo32/contextual_loss_pytorch', 12 | ) 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/S-aiueo32/contextual_loss_pytorch/c886571b07be95788e99e6751560216acd54dae3/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_contextual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from contextual_loss import functional as F 4 | from contextual_loss import ContextualLoss 5 | 6 | test_shape = [1, 256, 48, 48] 7 | 8 | 9 | def test_module(): 10 | prediction = torch.rand(*test_shape) 11 | f = ContextualLoss() 12 | loss = f(prediction, prediction) 13 | assert loss.shape == torch.Size([]) 14 | 15 | 16 | def test_module_gpu(): 17 | prediction = torch.rand(*test_shape).to('cuda:0') 18 | f = ContextualLoss().to('cuda:0') 19 | loss = f(prediction, prediction) 20 | assert loss.shape == torch.Size([]) 21 | 22 | 23 | def test_vgg(): 24 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]) 25 | f = ContextualLoss(use_vgg=True) 26 | loss = f(prediction, prediction) 27 | assert loss.shape == torch.Size([]) 28 | 29 | 30 | def test_vgg_gpu(): 31 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]).to('cuda:0') 32 | f = ContextualLoss(use_vgg=True).to('cuda:0') 33 | loss = f(prediction, prediction) 34 | assert loss.shape == torch.Size([]) 35 | 36 | 37 | def test_cosine(): 38 | prediction = torch.rand(*test_shape) 39 | loss = F.contextual_loss(prediction, prediction, loss_type='cosine') 40 | assert loss.shape == torch.Size([]) 41 | 42 | 43 | def test_cosine_gpu(): 44 | prediction = torch.rand(*test_shape).to('cuda:0') 45 | loss = F.contextual_loss(prediction, prediction, loss_type='cosine') 46 | assert loss.shape == torch.Size([]) 47 | 48 | 49 | def test_l1(): 50 | prediction = torch.rand(*test_shape) 51 | loss = F.contextual_loss(prediction, prediction, loss_type='l1') 52 | assert loss.shape == torch.Size([]) 53 | 54 | 55 | def test_l1_gpu(): 56 | prediction = torch.rand(*test_shape).to('cuda:0') 57 | loss = F.contextual_loss(prediction, prediction, loss_type='l1') 58 | assert loss.shape == torch.Size([]) 59 | 60 | 61 | def test_l2(): 62 | prediction = torch.rand(*test_shape) 63 | loss = F.contextual_loss(prediction, prediction, loss_type='l2') 64 | assert loss.shape == torch.Size([]) 65 | 66 | 67 | def test_l2_gpu(): 68 | prediction = torch.rand(*test_shape).to('cuda:0') 69 | loss = F.contextual_loss(prediction, prediction, loss_type='l2') 70 | assert loss.shape == torch.Size([]) 71 | -------------------------------------------------------------------------------- /tests/test_contextual_bilateral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from contextual_loss import functional as F 4 | from contextual_loss import ContextualBilateralLoss 5 | 6 | test_shape = [1, 256, 48, 48] 7 | 8 | 9 | def test_module(): 10 | prediction = torch.rand(*test_shape) 11 | f = ContextualBilateralLoss() 12 | loss = f(prediction, prediction) 13 | assert loss.shape == torch.Size([]) 14 | 15 | 16 | def test_module_gpu(): 17 | prediction = torch.rand(*test_shape).to('cuda:0') 18 | f = ContextualBilateralLoss().to('cuda:0') 19 | loss = f(prediction, prediction) 20 | assert loss.shape == torch.Size([]) 21 | 22 | 23 | def test_vgg(): 24 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]) 25 | f = ContextualBilateralLoss(use_vgg=True) 26 | loss = f(prediction, prediction) 27 | assert loss.shape == torch.Size([]) 28 | 29 | 30 | def test_vgg_gpu(): 31 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]).to('cuda:0') 32 | f = ContextualBilateralLoss(use_vgg=True).to('cuda:0') 33 | loss = f(prediction, prediction) 34 | assert loss.shape == torch.Size([]) 35 | 36 | 37 | def test_cosine(): 38 | prediction = torch.rand(*test_shape) 39 | loss = F.contextual_bilateral_loss( 40 | prediction, prediction, loss_type='cosine') 41 | assert loss.shape == torch.Size([]) 42 | 43 | 44 | def test_cosine_gpu(): 45 | prediction = torch.rand(*test_shape).to('cuda:0') 46 | loss = F.contextual_bilateral_loss( 47 | prediction, prediction, loss_type='cosine') 48 | assert loss.shape == torch.Size([]) 49 | 50 | 51 | def test_l1(): 52 | prediction = torch.rand(*test_shape) 53 | loss = F.contextual_bilateral_loss( 54 | prediction, prediction, loss_type='l1') 55 | assert loss.shape == torch.Size([]) 56 | 57 | 58 | def test_l1_gpu(): 59 | prediction = torch.rand(*test_shape).to('cuda:0') 60 | loss = F.contextual_bilateral_loss( 61 | prediction, prediction, loss_type='l1') 62 | assert loss.shape == torch.Size([]) 63 | 64 | 65 | def test_l2(): 66 | prediction = torch.rand(*test_shape) 67 | loss = F.contextual_bilateral_loss( 68 | prediction, prediction, loss_type='l2') 69 | assert loss.shape == torch.Size([]) 70 | 71 | 72 | def test_l2_gpu(): 73 | prediction = torch.rand(*test_shape).to('cuda:0') 74 | loss = F.contextual_bilateral_loss( 75 | prediction, prediction, loss_type='l2') 76 | assert loss.shape == torch.Size([]) 77 | --------------------------------------------------------------------------------