├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── contextual_loss ├── __init__.py ├── config.py ├── functional.py └── modules │ ├── __init__.py │ ├── contextual.py │ ├── contextual_bilateral.py │ └── vgg.py ├── doc └── small_example.ipynb ├── setup.py └── tests ├── __init__.py ├── test_contextual.py └── test_contextual_bilateral.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sou Uchida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | pytest = "*" 8 | flake8 = "*" 9 | pylint = "*" 10 | matplotlib = "*" 11 | ipython = "*" 12 | 13 | [packages] 14 | torch = "*" 15 | torchvision = "*" 16 | 17 | [requires] 18 | python_version = "3.7" 19 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "d73f48ef2416fa486483d856a1b6fb29ee661445de7c6d3b7ded11f777cabe2c" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "numpy": { 20 | "hashes": [ 21 | "sha256:0a7a1dd123aecc9f0076934288ceed7fd9a81ba3919f11a855a7887cbe82a02f", 22 | "sha256:0c0763787133dfeec19904c22c7e358b231c87ba3206b211652f8cbe1241deb6", 23 | "sha256:3d52298d0be333583739f1aec9026f3b09fdfe3ddf7c7028cb16d9d2af1cca7e", 24 | "sha256:43bb4b70585f1c2d153e45323a886839f98af8bfa810f7014b20be714c37c447", 25 | "sha256:475963c5b9e116c38ad7347e154e5651d05a2286d86455671f5b1eebba5feb76", 26 | 
"sha256:64874913367f18eb3013b16123c9fed113962e75d809fca5b78ebfbb73ed93ba", 27 | "sha256:683828e50c339fc9e68720396f2de14253992c495fdddef77a1e17de55f1decc", 28 | "sha256:6ca4000c4a6f95a78c33c7dadbb9495c10880be9c89316aa536eac359ab820ae", 29 | "sha256:75fd817b7061f6378e4659dd792c84c0b60533e867f83e0d1e52d5d8e53df88c", 30 | "sha256:7d81d784bdbed30137aca242ab307f3e65c8d93f4c7b7d8f322110b2e90177f9", 31 | "sha256:8d0af8d3664f142414fd5b15cabfd3b6cc3ef242a3c7a7493257025be5a6955f", 32 | "sha256:9679831005fb16c6df3dd35d17aa31dc0d4d7573d84f0b44cc481490a65c7725", 33 | "sha256:a8f67ebfae9f575d85fa859b54d3bdecaeece74e3274b0b5c5f804d7ca789fe1", 34 | "sha256:acbf5c52db4adb366c064d0b7c7899e3e778d89db585feadd23b06b587d64761", 35 | "sha256:ada4805ed51f5bcaa3a06d3dd94939351869c095e30a2b54264f5a5004b52170", 36 | "sha256:c7354e8f0eca5c110b7e978034cd86ed98a7a5ffcf69ca97535445a595e07b8e", 37 | "sha256:e2e9d8c87120ba2c591f60e32736b82b67f72c37ba88a4c23c81b5b8fa49c018", 38 | "sha256:e467c57121fe1b78a8f68dd9255fbb3bb3f4f7547c6b9e109f31d14569f490c3", 39 | "sha256:ede47b98de79565fcd7f2decb475e2dcc85ee4097743e551fe26cfc7eb3ff143", 40 | "sha256:f58913e9227400f1395c7b800503ebfdb0772f1c33ff8cb4d6451c06cabdf316", 41 | "sha256:fe39f5fd4103ec4ca3cb8600b19216cd1ff316b4990f4c0b6057ad982c0a34d5" 42 | ], 43 | "version": "==1.17.4" 44 | }, 45 | "pillow": { 46 | "hashes": [ 47 | "sha256:047d9473cf68af50ac85f8ee5d5f21a60f849bc17d348da7fc85711287a75031", 48 | "sha256:0f66dc6c8a3cc319561a633b6aa82c44107f12594643efa37210d8c924fc1c71", 49 | "sha256:12c9169c4e8fe0a7329e8658c7e488001f6b4c8e88740e76292c2b857af2e94c", 50 | "sha256:248cffc168896982f125f5c13e9317c059f74fffdb4152893339f3be62a01340", 51 | "sha256:27faf0552bf8c260a5cee21a76e031acaea68babb64daf7e8f2e2540745082aa", 52 | "sha256:285edafad9bc60d96978ed24d77cdc0b91dace88e5da8c548ba5937c425bca8b", 53 | "sha256:384b12c9aa8ef95558abdcb50aada56d74bc7cc131dd62d28c2d0e4d3aadd573", 54 | "sha256:38950b3a707f6cef09cd3cbb142474357ad1a985ceb44d921bdf7b4647b3e13e", 55 | 
"sha256:4aad1b88933fd6dc2846552b89ad0c74ddbba2f0884e2c162aa368374bf5abab", 56 | "sha256:4ac6148008c169603070c092e81f88738f1a0c511e07bd2bb0f9ef542d375da9", 57 | "sha256:4deb1d2a45861ae6f0b12ea0a786a03d19d29edcc7e05775b85ec2877cb54c5e", 58 | "sha256:59aa2c124df72cc75ed72c8d6005c442d4685691a30c55321e00ed915ad1a291", 59 | "sha256:5a47d2123a9ec86660fe0e8d0ebf0aa6bc6a17edc63f338b73ea20ba11713f12", 60 | "sha256:5cc901c2ab9409b4b7ac7b5bcc3e86ac14548627062463da0af3b6b7c555a871", 61 | "sha256:6c1db03e8dff7b9f955a0fb9907eb9ca5da75b5ce056c0c93d33100a35050281", 62 | "sha256:7ce80c0a65a6ea90ef9c1f63c8593fcd2929448613fc8da0adf3e6bfad669d08", 63 | "sha256:809c19241c14433c5d6135e1b6c72da4e3b56d5c865ad5736ab99af8896b8f41", 64 | "sha256:83792cb4e0b5af480588601467c0764242b9a483caea71ef12d22a0d0d6bdce2", 65 | "sha256:846fa202bd7ee0f6215c897a1d33238ef071b50766339186687bd9b7a6d26ac5", 66 | "sha256:9f5529fc02009f96ba95bea48870173426879dc19eec49ca8e08cd63ecd82ddb", 67 | "sha256:a423c2ea001c6265ed28700df056f75e26215fd28c001e93ef4380b0f05f9547", 68 | "sha256:ac4428094b42907aba5879c7c000d01c8278d451a3b7cccd2103e21f6397ea75", 69 | "sha256:b1ae48d87f10d1384e5beecd169c77502fcc04a2c00a4c02b85f0a94b419e5f9", 70 | "sha256:bf4e972a88f8841d8fdc6db1a75e0f8d763e66e3754b03006cbc3854d89f1cb1", 71 | "sha256:c6414f6aad598364aaf81068cabb077894eb88fed99c6a65e6e8217bab62ae7a", 72 | "sha256:c710fcb7ee32f67baf25aa9ffede4795fd5d93b163ce95fdc724383e38c9df96", 73 | "sha256:c7be4b8a09852291c3c48d3c25d1b876d2494a0a674980089ac9d5e0d78bd132", 74 | "sha256:c9e5ffb910b14f090ac9c38599063e354887a5f6d7e6d26795e916b4514f2c1a", 75 | "sha256:e0697b826da6c2472bb6488db4c0a7fa8af0d52fa08833ceb3681358914b14e5", 76 | "sha256:e9a3edd5f714229d41057d56ac0f39ad9bdba6767e8c888c951869f0bdd129b0" 77 | ], 78 | "version": "==6.2.1" 79 | }, 80 | "six": { 81 | "hashes": [ 82 | "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", 83 | "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" 84 | ], 85 | 
"version": "==1.13.0" 86 | }, 87 | "torch": { 88 | "hashes": [ 89 | "sha256:0cec2e13a2e95c24c34f17d437f354ee2a40902e8d515a524556b350e12555dd", 90 | "sha256:134e8291a97151b1ffeea09cb9ddde5238beb4e6d9dfb66657143d6990bfb865", 91 | "sha256:31062923ac2e60eac676f6a0ae14702b051c158bbcf7f440eaba266b0defa197", 92 | "sha256:3b05233481b51bb636cee63dc761bb7f602e198178782ff4159d385d1759608b", 93 | "sha256:458f1d87e5b7064b2c39e36675d84e163be3143dd2fc806057b7878880c461bc", 94 | "sha256:72a1c85bffd2154f085bc0a1d378d8a54e55a57d49664b874fe7c949022bf071", 95 | "sha256:77fd8866c0bf529861ffd850a5dada2190a8d9c5167719fb0cfa89163e23b143", 96 | "sha256:b6f01d851d1c5989d4a99b50ae0187762b15b7718dcd1a33704b665daa2402f9", 97 | "sha256:d8e1d904a6193ed14a4fed220b00503b2baa576e71471286d1ebba899c851fae" 98 | ], 99 | "index": "pypi", 100 | "version": "==1.3.1" 101 | }, 102 | "torchvision": { 103 | "hashes": [ 104 | "sha256:0f8245d6378acc86917f58492675f93df5279abae8bc5f832e3510722191f6c9", 105 | "sha256:1ad7593d94f6612ccb84a59467f0d10cdc213fb3e2bb91f1e773eb844787fa4c", 106 | "sha256:2553405b9afe3cedb410873b9877eb18b1526f8b01cb7c2747e51b69a936e0b5", 107 | "sha256:276a385f2f5fe484bf08467b5d081d9144b97eb458ba5b4a11e4640389e53149", 108 | "sha256:66deba9c577e36f4f071decdd894bf7ba794ac133dae64b3fd02fc3f0c6b989d", 109 | "sha256:7a458330e4efcd66f9f70127ab21fcf8cfea84acda8e707322fd2843aa6dd396", 110 | "sha256:8ff715c2323d9eca89126824ebfa74b282a95d6f64a4743fbe9b738d2de21c77", 111 | "sha256:dca4aadc12a123730957b501f9c5c2870d2f6727a2c28552cb7907b68b0ea10c", 112 | "sha256:dda25ce304978bba19e6543f7dcfee4f37d2f128ec83d4ab0c7e8f991d64865f" 113 | ], 114 | "index": "pypi", 115 | "version": "==0.4.2" 116 | } 117 | }, 118 | "develop": { 119 | "astroid": { 120 | "hashes": [ 121 | "sha256:71ea07f44df9568a75d0f354c49143a4575d90645e9fead6dfb52c26a85ed13a", 122 | "sha256:840947ebfa8b58f318d42301cf8c0a20fd794a33b61cc4638e28e9e61ba32f42" 123 | ], 124 | "version": "==2.3.3" 125 | }, 126 | "attrs": { 127 | "hashes": [ 128 | 
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", 129 | "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" 130 | ], 131 | "version": "==19.3.0" 132 | }, 133 | "backcall": { 134 | "hashes": [ 135 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 136 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 137 | ], 138 | "version": "==0.1.0" 139 | }, 140 | "cycler": { 141 | "hashes": [ 142 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", 143 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" 144 | ], 145 | "version": "==0.10.0" 146 | }, 147 | "decorator": { 148 | "hashes": [ 149 | "sha256:54c38050039232e1db4ad7375cfce6748d7b41c29e95a081c8a6d2c30364a2ce", 150 | "sha256:5d19b92a3c8f7f101c8dd86afd86b0f061a8ce4540ab8cd401fa2542756bce6d" 151 | ], 152 | "version": "==4.4.1" 153 | }, 154 | "entrypoints": { 155 | "hashes": [ 156 | "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", 157 | "sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451" 158 | ], 159 | "version": "==0.3" 160 | }, 161 | "flake8": { 162 | "hashes": [ 163 | "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", 164 | "sha256:49356e766643ad15072a789a20915d3c91dc89fd313ccd71802303fd67e4deca" 165 | ], 166 | "index": "pypi", 167 | "version": "==3.7.9" 168 | }, 169 | "importlib-metadata": { 170 | "hashes": [ 171 | "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", 172 | "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af" 173 | ], 174 | "markers": "python_version < '3.8'", 175 | "version": "==0.23" 176 | }, 177 | "ipython": { 178 | "hashes": [ 179 | "sha256:060d19feef09453d3375ab23c7295ed36cb59e5a3904598ab903f93ec45f1f63", 180 | "sha256:e468b8f03a0168a667982b50f0b4e0828cc32721bbea32b23934e55b7970eb7a" 181 | ], 182 | "index": "pypi", 183 | "version": "==7.10.0" 
184 | }, 185 | "ipython-genutils": { 186 | "hashes": [ 187 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 188 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 189 | ], 190 | "version": "==0.2.0" 191 | }, 192 | "isort": { 193 | "hashes": [ 194 | "sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1", 195 | "sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd" 196 | ], 197 | "version": "==4.3.21" 198 | }, 199 | "jedi": { 200 | "hashes": [ 201 | "sha256:786b6c3d80e2f06fd77162a07fed81b8baa22dde5d62896a790a331d6ac21a27", 202 | "sha256:ba859c74fa3c966a22f2aeebe1b74ee27e2a462f56d3f5f7ca4a59af61bfe42e" 203 | ], 204 | "version": "==0.15.1" 205 | }, 206 | "kiwisolver": { 207 | "hashes": [ 208 | "sha256:05b5b061e09f60f56244adc885c4a7867da25ca387376b02c1efc29cc16bcd0f", 209 | "sha256:210d8c39d01758d76c2b9a693567e1657ec661229bc32eac30761fa79b2474b0", 210 | "sha256:26f4fbd6f5e1dabff70a9ba0d2c4bd30761086454aa30dddc5b52764ee4852b7", 211 | "sha256:3b15d56a9cd40c52d7ab763ff0bc700edbb4e1a298dc43715ecccd605002cf11", 212 | "sha256:3b2378ad387f49cbb328205bda569b9f87288d6bc1bf4cd683c34523a2341efe", 213 | "sha256:400599c0fe58d21522cae0e8b22318e09d9729451b17ee61ba8e1e7c0346565c", 214 | "sha256:47b8cb81a7d18dbaf4fed6a61c3cecdb5adec7b4ac292bddb0d016d57e8507d5", 215 | "sha256:53eaed412477c836e1b9522c19858a8557d6e595077830146182225613b11a75", 216 | "sha256:58e626e1f7dfbb620d08d457325a4cdac65d1809680009f46bf41eaf74ad0187", 217 | "sha256:5a52e1b006bfa5be04fe4debbcdd2688432a9af4b207a3f429c74ad625022641", 218 | "sha256:5c7ca4e449ac9f99b3b9d4693debb1d6d237d1542dd6a56b3305fe8a9620f883", 219 | "sha256:682e54f0ce8f45981878756d7203fd01e188cc6c8b2c5e2cf03675390b4534d5", 220 | "sha256:76275ee077772c8dde04fb6c5bc24b91af1bb3e7f4816fd1852f1495a64dad93", 221 | "sha256:79bfb2f0bd7cbf9ea256612c9523367e5ec51d7cd616ae20ca2c90f575d839a2", 222 | "sha256:7f4dd50874177d2bb060d74769210f3bce1af87a8c7cf5b37d032ebf94f0aca3", 
223 | "sha256:8944a16020c07b682df861207b7e0efcd2f46c7488619cb55f65882279119389", 224 | "sha256:8aa7009437640beb2768bfd06da049bad0df85f47ff18426261acecd1cf00897", 225 | "sha256:9105ce82dcc32c73eb53a04c869b6a4bc756b43e4385f76ea7943e827f529e4d", 226 | "sha256:933df612c453928f1c6faa9236161a1d999a26cd40abf1dc5d7ebbc6dbfb8fca", 227 | "sha256:939f36f21a8c571686eb491acfffa9c7f1ac345087281b412d63ea39ca14ec4a", 228 | "sha256:9491578147849b93e70d7c1d23cb1229458f71fc79c51d52dce0809b2ca44eea", 229 | "sha256:9733b7f64bd9f807832d673355f79703f81f0b3e52bfce420fc00d8cb28c6a6c", 230 | "sha256:a02f6c3e229d0b7220bd74600e9351e18bc0c361b05f29adae0d10599ae0e326", 231 | "sha256:a0c0a9f06872330d0dd31b45607197caab3c22777600e88031bfe66799e70bb0", 232 | "sha256:aa716b9122307c50686356cfb47bfbc66541868078d0c801341df31dca1232a9", 233 | "sha256:acc4df99308111585121db217681f1ce0eecb48d3a828a2f9bbf9773f4937e9e", 234 | "sha256:b64916959e4ae0ac78af7c3e8cef4becee0c0e9694ad477b4c6b3a536de6a544", 235 | "sha256:d22702cadb86b6fcba0e6b907d9f84a312db9cd6934ee728144ce3018e715ee1", 236 | "sha256:d3fcf0819dc3fea58be1fd1ca390851bdb719a549850e708ed858503ff25d995", 237 | "sha256:d52e3b1868a4e8fd18b5cb15055c76820df514e26aa84cc02f593d99fef6707f", 238 | "sha256:db1a5d3cc4ae943d674718d6c47d2d82488ddd94b93b9e12d24aabdbfe48caee", 239 | "sha256:e3a21a720791712ed721c7b95d433e036134de6f18c77dbe96119eaf7aa08004", 240 | "sha256:e8bf074363ce2babeb4764d94f8e65efd22e6a7c74860a4f05a6947afc020ff2", 241 | "sha256:f16814a4a96dc04bf1da7d53ee8d5b1d6decfc1a92a63349bb15d37b6a263dd9", 242 | "sha256:f2b22153870ca5cf2ab9c940d7bc38e8e9089fa0f7e5856ea195e1cf4ff43d5a", 243 | "sha256:f790f8b3dff3d53453de6a7b7ddd173d2e020fb160baff578d578065b108a05f", 244 | "sha256:fe51b79da0062f8e9d49ed0182a626a7dc7a0cbca0328f612c6ee5e4711c81e4" 245 | ], 246 | "version": "==1.1.0" 247 | }, 248 | "lazy-object-proxy": { 249 | "hashes": [ 250 | "sha256:0c4b206227a8097f05c4dbdd323c50edf81f15db3b8dc064d08c62d37e1a504d", 251 | 
"sha256:194d092e6f246b906e8f70884e620e459fc54db3259e60cf69a4d66c3fda3449", 252 | "sha256:1be7e4c9f96948003609aa6c974ae59830a6baecc5376c25c92d7d697e684c08", 253 | "sha256:4677f594e474c91da97f489fea5b7daa17b5517190899cf213697e48d3902f5a", 254 | "sha256:48dab84ebd4831077b150572aec802f303117c8cc5c871e182447281ebf3ac50", 255 | "sha256:5541cada25cd173702dbd99f8e22434105456314462326f06dba3e180f203dfd", 256 | "sha256:59f79fef100b09564bc2df42ea2d8d21a64fdcda64979c0fa3db7bdaabaf6239", 257 | "sha256:8d859b89baf8ef7f8bc6b00aa20316483d67f0b1cbf422f5b4dc56701c8f2ffb", 258 | "sha256:9254f4358b9b541e3441b007a0ea0764b9d056afdeafc1a5569eee1cc6c1b9ea", 259 | "sha256:9651375199045a358eb6741df3e02a651e0330be090b3bc79f6d0de31a80ec3e", 260 | "sha256:97bb5884f6f1cdce0099f86b907aa41c970c3c672ac8b9c8352789e103cf3156", 261 | "sha256:9b15f3f4c0f35727d3a0fba4b770b3c4ebbb1fa907dbcc046a1d2799f3edd142", 262 | "sha256:a2238e9d1bb71a56cd710611a1614d1194dc10a175c1e08d75e1a7bcc250d442", 263 | "sha256:a6ae12d08c0bf9909ce12385803a543bfe99b95fe01e752536a60af2b7797c62", 264 | "sha256:ca0a928a3ddbc5725be2dd1cf895ec0a254798915fb3a36af0964a0a4149e3db", 265 | "sha256:cb2c7c57005a6804ab66f106ceb8482da55f5314b7fcb06551db1edae4ad1531", 266 | "sha256:d74bb8693bf9cf75ac3b47a54d716bbb1a92648d5f781fc799347cfc95952383", 267 | "sha256:d945239a5639b3ff35b70a88c5f2f491913eb94871780ebfabb2568bd58afc5a", 268 | "sha256:eba7011090323c1dadf18b3b689845fd96a61ba0a1dfbd7f24b921398affc357", 269 | "sha256:efa1909120ce98bbb3777e8b6f92237f5d5c8ea6758efea36a473e1d38f7d3e4", 270 | "sha256:f3900e8a5de27447acbf900b4750b0ddfd7ec1ea7fbaf11dfa911141bc522af0" 271 | ], 272 | "version": "==1.4.3" 273 | }, 274 | "matplotlib": { 275 | "hashes": [ 276 | "sha256:08ccc8922eb4792b91c652d3e6d46b1c99073f1284d1b6705155643e8046463a", 277 | "sha256:161dcd807c0c3232f4dcd4a12a382d52004a498174cbfafd40646106c5bcdcc8", 278 | "sha256:1f9e885bfa1b148d16f82a6672d043ecf11197f6c71ae222d0546db706e52eb2", 279 | 
"sha256:2d6ab54015a7c0d727c33e36f85f5c5e4172059efdd067f7527f6e5d16ad01aa", 280 | "sha256:5d2e408a2813abf664bd79431107543ecb449136912eb55bb312317edecf597e", 281 | "sha256:61c8b740a008218eb604de518eb411c4953db0cb725dd0b32adf8a81771cab9e", 282 | "sha256:80f10af8378fccc136da40ea6aa4a920767476cdfb3241acb93ef4f0465dbf57", 283 | "sha256:819d4860315468b482f38f1afe45a5437f60f03eaede495d5ff89f2eeac89500", 284 | "sha256:8cc0e44905c2c8fda5637cad6f311eb9517017515a034247ab93d0cf99f8bb7a", 285 | "sha256:8e8e2c2fe3d873108735c6ee9884e6f36f467df4a143136209cff303b183bada", 286 | "sha256:98c2ffeab8b79a4e3a0af5dd9939f92980eb6e3fec10f7f313df5f35a84dacab", 287 | "sha256:d59bb0e82002ac49f4152963f8a1079e66794a4f454457fd2f0dcc7bf0797d30", 288 | "sha256:ee59b7bb9eb75932fe3787e54e61c99b628155b0cedc907864f24723ba55b309" 289 | ], 290 | "index": "pypi", 291 | "version": "==3.1.2" 292 | }, 293 | "mccabe": { 294 | "hashes": [ 295 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 296 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 297 | ], 298 | "version": "==0.6.1" 299 | }, 300 | "more-itertools": { 301 | "hashes": [ 302 | "sha256:53ff73f186307d9c8ef17a9600309154a6ae27f25579e80af4db8f047ba14bc2", 303 | "sha256:a0ea684c39bc4315ba7aae406596ef191fd84f873d2d2751f84d64e81a7a2d45" 304 | ], 305 | "version": "==8.0.0" 306 | }, 307 | "numpy": { 308 | "hashes": [ 309 | "sha256:0a7a1dd123aecc9f0076934288ceed7fd9a81ba3919f11a855a7887cbe82a02f", 310 | "sha256:0c0763787133dfeec19904c22c7e358b231c87ba3206b211652f8cbe1241deb6", 311 | "sha256:3d52298d0be333583739f1aec9026f3b09fdfe3ddf7c7028cb16d9d2af1cca7e", 312 | "sha256:43bb4b70585f1c2d153e45323a886839f98af8bfa810f7014b20be714c37c447", 313 | "sha256:475963c5b9e116c38ad7347e154e5651d05a2286d86455671f5b1eebba5feb76", 314 | "sha256:64874913367f18eb3013b16123c9fed113962e75d809fca5b78ebfbb73ed93ba", 315 | "sha256:683828e50c339fc9e68720396f2de14253992c495fdddef77a1e17de55f1decc", 316 | 
"sha256:6ca4000c4a6f95a78c33c7dadbb9495c10880be9c89316aa536eac359ab820ae", 317 | "sha256:75fd817b7061f6378e4659dd792c84c0b60533e867f83e0d1e52d5d8e53df88c", 318 | "sha256:7d81d784bdbed30137aca242ab307f3e65c8d93f4c7b7d8f322110b2e90177f9", 319 | "sha256:8d0af8d3664f142414fd5b15cabfd3b6cc3ef242a3c7a7493257025be5a6955f", 320 | "sha256:9679831005fb16c6df3dd35d17aa31dc0d4d7573d84f0b44cc481490a65c7725", 321 | "sha256:a8f67ebfae9f575d85fa859b54d3bdecaeece74e3274b0b5c5f804d7ca789fe1", 322 | "sha256:acbf5c52db4adb366c064d0b7c7899e3e778d89db585feadd23b06b587d64761", 323 | "sha256:ada4805ed51f5bcaa3a06d3dd94939351869c095e30a2b54264f5a5004b52170", 324 | "sha256:c7354e8f0eca5c110b7e978034cd86ed98a7a5ffcf69ca97535445a595e07b8e", 325 | "sha256:e2e9d8c87120ba2c591f60e32736b82b67f72c37ba88a4c23c81b5b8fa49c018", 326 | "sha256:e467c57121fe1b78a8f68dd9255fbb3bb3f4f7547c6b9e109f31d14569f490c3", 327 | "sha256:ede47b98de79565fcd7f2decb475e2dcc85ee4097743e551fe26cfc7eb3ff143", 328 | "sha256:f58913e9227400f1395c7b800503ebfdb0772f1c33ff8cb4d6451c06cabdf316", 329 | "sha256:fe39f5fd4103ec4ca3cb8600b19216cd1ff316b4990f4c0b6057ad982c0a34d5" 330 | ], 331 | "version": "==1.17.4" 332 | }, 333 | "packaging": { 334 | "hashes": [ 335 | "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", 336 | "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108" 337 | ], 338 | "version": "==19.2" 339 | }, 340 | "parso": { 341 | "hashes": [ 342 | "sha256:63854233e1fadb5da97f2744b6b24346d2750b85965e7e399bec1620232797dc", 343 | "sha256:666b0ee4a7a1220f65d367617f2cd3ffddff3e205f3f16a0284df30e774c2a9c" 344 | ], 345 | "version": "==0.5.1" 346 | }, 347 | "pexpect": { 348 | "hashes": [ 349 | "sha256:2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1", 350 | "sha256:9e2c1fd0e6ee3a49b28f95d4b33bc389c89b20af6a1255906e90ff1262ce62eb" 351 | ], 352 | "markers": "sys_platform != 'win32'", 353 | "version": "==4.7.0" 354 | }, 355 | "pickleshare": { 356 | "hashes": [ 357 | 
"sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", 358 | "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56" 359 | ], 360 | "version": "==0.7.5" 361 | }, 362 | "pluggy": { 363 | "hashes": [ 364 | "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", 365 | "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" 366 | ], 367 | "version": "==0.13.1" 368 | }, 369 | "prompt-toolkit": { 370 | "hashes": [ 371 | "sha256:0278d2f51b5ceba6ea8da39f76d15684e84c996b325475f6e5720edc584326a7", 372 | "sha256:63daee79aa8366c8f1c637f1a4876b890da5fc92a19ebd2f7080ebacb901e990" 373 | ], 374 | "version": "==3.0.2" 375 | }, 376 | "ptyprocess": { 377 | "hashes": [ 378 | "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", 379 | "sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f" 380 | ], 381 | "version": "==0.6.0" 382 | }, 383 | "py": { 384 | "hashes": [ 385 | "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", 386 | "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" 387 | ], 388 | "version": "==1.8.0" 389 | }, 390 | "pycodestyle": { 391 | "hashes": [ 392 | "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", 393 | "sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c" 394 | ], 395 | "version": "==2.5.0" 396 | }, 397 | "pyflakes": { 398 | "hashes": [ 399 | "sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0", 400 | "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2" 401 | ], 402 | "version": "==2.1.1" 403 | }, 404 | "pygments": { 405 | "hashes": [ 406 | "sha256:2a3fe295e54a20164a9df49c75fa58526d3be48e14aceba6d6b1e8ac0bfd6f1b", 407 | "sha256:98c8aa5a9f778fcd1026a17361ddaf7330d1b7c62ae97c3bb0ae73e0b9b6b0fe" 408 | ], 409 | "version": "==2.5.2" 410 | }, 411 | "pylint": { 412 | "hashes": [ 413 | 
"sha256:3db5468ad013380e987410a8d6956226963aed94ecb5f9d3a28acca6d9ac36cd", 414 | "sha256:886e6afc935ea2590b462664b161ca9a5e40168ea99e5300935f6591ad467df4" 415 | ], 416 | "index": "pypi", 417 | "version": "==2.4.4" 418 | }, 419 | "pyparsing": { 420 | "hashes": [ 421 | "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", 422 | "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" 423 | ], 424 | "version": "==2.4.5" 425 | }, 426 | "pytest": { 427 | "hashes": [ 428 | "sha256:63344a2e3bce2e4d522fd62b4fdebb647c019f1f9e4ca075debbd13219db4418", 429 | "sha256:f67403f33b2b1d25a6756184077394167fe5e2f9d8bdaab30707d19ccec35427" 430 | ], 431 | "index": "pypi", 432 | "version": "==5.3.1" 433 | }, 434 | "python-dateutil": { 435 | "hashes": [ 436 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", 437 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" 438 | ], 439 | "version": "==2.8.1" 440 | }, 441 | "six": { 442 | "hashes": [ 443 | "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", 444 | "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" 445 | ], 446 | "version": "==1.13.0" 447 | }, 448 | "traitlets": { 449 | "hashes": [ 450 | "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44", 451 | "sha256:d023ee369ddd2763310e4c3eae1ff649689440d4ae59d7485eb4cfbbe3e359f7" 452 | ], 453 | "version": "==4.3.3" 454 | }, 455 | "typed-ast": { 456 | "hashes": [ 457 | "sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161", 458 | "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", 459 | "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", 460 | "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", 461 | "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", 462 | "sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47", 463 | 
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", 464 | "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", 465 | "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", 466 | "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", 467 | "sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2", 468 | "sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e", 469 | "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", 470 | "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", 471 | "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", 472 | "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", 473 | "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", 474 | "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", 475 | "sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66", 476 | "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" 477 | ], 478 | "markers": "implementation_name == 'cpython' and python_version < '3.8'", 479 | "version": "==1.4.0" 480 | }, 481 | "wcwidth": { 482 | "hashes": [ 483 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 484 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 485 | ], 486 | "version": "==0.1.7" 487 | }, 488 | "wrapt": { 489 | "hashes": [ 490 | "sha256:565a021fd19419476b9362b05eeaa094178de64f8361e44468f9e9d7843901e1" 491 | ], 492 | "version": "==1.11.2" 493 | }, 494 | "zipp": { 495 | "hashes": [ 496 | "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", 497 | "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" 498 | ], 499 | "version": "==0.6.0" 500 | } 501 | } 502 | } 503 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # Contextual Loss 2 | PyTorch implementation of Contextual Loss (CX) and Contextual Bilateral Loss (CoBi). 3 | 4 | ## Introduction 5 | There are many image transformation tasks for which spatially aligned data is hard to capture in the wild. 6 | Pixel-to-pixel or global loss functions cannot be directly applied to such unaligned data. 7 | CX is a loss function designed to address this problem. 8 | The key idea of CX is to interpret images as sets of feature points that don't have spatial coordinates. 9 | If you want to know more about CX, please refer to the original [paper](https://arxiv.org/abs/1803.02077), [repo](https://github.com/roimehrez/contextualLoss) and the examples in the [./doc](./doc) directory. 10 | 11 | ## Requirements 12 | - Python 3.7+ 13 | - `torch` & `torchvision` 14 | 15 | ## Installation 16 | ``` 17 | pip install git+https://github.com/S-aiueo32/contextual_loss_pytorch.git 18 | ``` 19 | 20 | ## Usage 21 | You can use it like any other PyTorch API. 22 | ```python 23 | import torch 24 | 25 | import contextual_loss as cl 26 | import contextual_loss.functional as F 27 | 28 | 29 | # input features 30 | img1 = torch.rand(1, 3, 96, 96) 31 | img2 = torch.rand(1, 3, 96, 96) 32 | 33 | # contextual loss 34 | criterion = cl.ContextualLoss() 35 | loss = criterion(img1, img2) 36 | 37 | # functional call 38 | loss = F.contextual_loss(img1, img2, band_width=0.1, loss_type='cosine') 39 | 40 | # comparing with VGG features 41 | # if `use_vgg` is set, a VGG model will be created inside the criterion 42 | criterion = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4') 43 | loss = criterion(img1, img2) 44 | 45 | ``` 46 | 47 | ## Reference 48 | ### Papers 49 | 1. Mechrez, Roey, Itamar Talmi, and Lihi Zelnik-Manor. "The contextual loss for image transformation with non-aligned data." Proceedings of the European Conference on Computer Vision (ECCV). 2018. 50 | 2. Mechrez, Roey, et al.
"Maintaining natural image statistics with the contextual loss." Asian Conference on Computer Vision. Springer, Cham, 2018. 51 | ### Implementations 52 | Thanks to the owners of the following awesome implementations. 53 | - Original Repository: https://github.com/roimehrez/contextualLoss 54 | - Simple PyTorch Implemantation: https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da 55 | - CoBi: https://github.com/ceciliavision/zoom-learn-zoom 56 | -------------------------------------------------------------------------------- /contextual_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /contextual_loss/config.py: -------------------------------------------------------------------------------- 1 | # TODO: add supports for L1, L2 etc. 2 | LOSS_TYPES = ['cosine', 'l1', 'l2'] 3 | -------------------------------------------------------------------------------- /contextual_loss/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .config import LOSS_TYPES 5 | 6 | __all__ = ['contextual_loss', 'contextual_bilateral_loss'] 7 | 8 | 9 | def contextual_loss(x: torch.Tensor, 10 | y: torch.Tensor, 11 | band_width: float = 0.5, 12 | loss_type: str = 'cosine'): 13 | """ 14 | Computes contextual loss between x and y. 15 | The most of this code is copied from 16 | https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da. 17 | 18 | Parameters 19 | --- 20 | x : torch.Tensor 21 | features of shape (N, C, H, W). 22 | y : torch.Tensor 23 | features of shape (N, C, H, W). 24 | band_width : float, optional 25 | a band-width parameter used to convert distance to similarity. 26 | in the paper, this is described as :math:`h`. 
27 | loss_type : str, optional 28 | a loss type to measure the distance between features. 29 | Note: `l1` and `l2` frequently raises OOM. 30 | 31 | Returns 32 | --- 33 | cx_loss : torch.Tensor 34 | contextual loss between x and y (Eq (1) in the paper) 35 | """ 36 | 37 | assert x.size() == y.size(), 'input tensor must have the same size.' 38 | assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.' 39 | 40 | N, C, H, W = x.size() 41 | 42 | if loss_type == 'cosine': 43 | dist_raw = compute_cosine_distance(x, y) 44 | elif loss_type == 'l1': 45 | dist_raw = compute_l1_distance(x, y) 46 | elif loss_type == 'l2': 47 | dist_raw = compute_l2_distance(x, y) 48 | 49 | dist_tilde = compute_relative_distance(dist_raw) 50 | cx = compute_cx(dist_tilde, band_width) 51 | cx = torch.mean(torch.max(cx, dim=1)[0], dim=1) # Eq(1) 52 | cx_loss = torch.mean(-torch.log(cx + 1e-5)) # Eq(5) 53 | 54 | return cx_loss 55 | 56 | 57 | # TODO: Operation check 58 | def contextual_bilateral_loss(x: torch.Tensor, 59 | y: torch.Tensor, 60 | weight_sp: float = 0.1, 61 | band_width: float = 1., 62 | loss_type: str = 'cosine'): 63 | """ 64 | Computes Contextual Bilateral (CoBi) Loss between x and y, 65 | proposed in https://arxiv.org/pdf/1905.05169.pdf. 66 | 67 | Parameters 68 | --- 69 | x : torch.Tensor 70 | features of shape (N, C, H, W). 71 | y : torch.Tensor 72 | features of shape (N, C, H, W). 73 | band_width : float, optional 74 | a band-width parameter used to convert distance to similarity. 75 | in the paper, this is described as :math:`h`. 76 | loss_type : str, optional 77 | a loss type to measure the distance between features. 78 | Note: `l1` and `l2` frequently raises OOM. 79 | 80 | Returns 81 | --- 82 | cx_loss : torch.Tensor 83 | contextual loss between x and y (Eq (1) in the paper). 84 | k_arg_max_NC : torch.Tensor 85 | indices to maximize similarity over channels. 86 | """ 87 | 88 | assert x.size() == y.size(), 'input tensor must have the same size.' 
89 | assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.' 90 | 91 | # spatial loss 92 | grid = compute_meshgrid(x.shape).to(x.device) 93 | dist_raw = compute_l2_distance(grid, grid) 94 | dist_tilde = compute_relative_distance(dist_raw) 95 | cx_sp = compute_cx(dist_tilde, band_width) 96 | 97 | # feature loss 98 | if loss_type == 'cosine': 99 | dist_raw = compute_cosine_distance(x, y) 100 | elif loss_type == 'l1': 101 | dist_raw = compute_l1_distance(x, y) 102 | elif loss_type == 'l2': 103 | dist_raw = compute_l2_distance(x, y) 104 | dist_tilde = compute_relative_distance(dist_raw) 105 | cx_feat = compute_cx(dist_tilde, band_width) 106 | 107 | # combined loss 108 | cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp 109 | 110 | k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True) 111 | 112 | cx = k_max_NC.mean(dim=1) 113 | cx_loss = torch.mean(-torch.log(cx + 1e-5)) 114 | 115 | return cx_loss 116 | 117 | 118 | def compute_cx(dist_tilde, band_width): 119 | w = torch.exp((1 - dist_tilde) / band_width) # Eq(3) 120 | cx = w / torch.sum(w, dim=2, keepdim=True) # Eq(4) 121 | return cx 122 | 123 | 124 | def compute_relative_distance(dist_raw): 125 | dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True) 126 | dist_tilde = dist_raw / (dist_min + 1e-5) 127 | return dist_tilde 128 | 129 | 130 | def compute_cosine_distance(x, y): 131 | # mean shifting by channel-wise mean of `y`. 
132 | y_mu = y.mean(dim=(0, 2, 3), keepdim=True) 133 | x_centered = x - y_mu 134 | y_centered = y - y_mu 135 | 136 | # L2 normalization 137 | x_normalized = F.normalize(x_centered, p=2, dim=1) 138 | y_normalized = F.normalize(y_centered, p=2, dim=1) 139 | 140 | # channel-wise vectorization 141 | N, C, *_ = x.size() 142 | x_normalized = x_normalized.reshape(N, C, -1) # (N, C, H*W) 143 | y_normalized = y_normalized.reshape(N, C, -1) # (N, C, H*W) 144 | 145 | # consine similarity 146 | cosine_sim = torch.bmm(x_normalized.transpose(1, 2), 147 | y_normalized) # (N, H*W, H*W) 148 | 149 | # convert to distance 150 | dist = 1 - cosine_sim 151 | 152 | return dist 153 | 154 | 155 | # TODO: Considering avoiding OOM. 156 | def compute_l1_distance(x: torch.Tensor, y: torch.Tensor): 157 | N, C, H, W = x.size() 158 | x_vec = x.view(N, C, -1) 159 | y_vec = y.view(N, C, -1) 160 | 161 | dist = x_vec.unsqueeze(2) - y_vec.unsqueeze(3) 162 | dist = dist.sum(dim=1).abs() 163 | dist = dist.transpose(1, 2).reshape(N, H*W, H*W) 164 | dist = dist.clamp(min=0.) 165 | 166 | return dist 167 | 168 | 169 | # TODO: Considering avoiding OOM. 170 | def compute_l2_distance(x, y): 171 | N, C, H, W = x.size() 172 | x_vec = x.view(N, C, -1) 173 | y_vec = y.view(N, C, -1) 174 | x_s = torch.sum(x_vec ** 2, dim=1) 175 | y_s = torch.sum(y_vec ** 2, dim=1) 176 | 177 | A = y_vec.transpose(1, 2) @ x_vec 178 | dist = y_s - 2 * A + x_s.transpose(0, 1) 179 | dist = dist.transpose(1, 2).reshape(N, H*W, H*W) 180 | dist = dist.clamp(min=0.) 
181 | 182 | return dist 183 | 184 | 185 | def compute_meshgrid(shape): 186 | N, C, H, W = shape 187 | rows = torch.arange(0, H, dtype=torch.float32) / (H + 1) 188 | cols = torch.arange(0, W, dtype=torch.float32) / (W + 1) 189 | 190 | feature_grid = torch.meshgrid(rows, cols) 191 | feature_grid = torch.stack(feature_grid).unsqueeze(0) 192 | feature_grid = torch.cat([feature_grid for _ in range(N)], dim=0) 193 | 194 | return feature_grid 195 | -------------------------------------------------------------------------------- /contextual_loss/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .contextual import ContextualLoss 2 | from .contextual_bilateral import ContextualBilateralLoss 3 | 4 | __all__ = ['ContextualLoss', 'ContextualBilateralLoss'] 5 | -------------------------------------------------------------------------------- /contextual_loss/modules/contextual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vgg import VGG19 5 | from .. import functional as F 6 | from ..config import LOSS_TYPES 7 | 8 | 9 | class ContextualLoss(nn.Module): 10 | """ 11 | Creates a criterion that measures the contextual loss. 12 | 13 | Parameters 14 | --- 15 | band_width : int, optional 16 | a band_width parameter described as :math:`h` in the paper. 17 | use_vgg : bool, optional 18 | if you want to use VGG feature, set this `True`. 19 | vgg_layer : str, optional 20 | intermidiate layer name for VGG feature. 21 | Now we support layer names: 22 | `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']` 23 | """ 24 | 25 | def __init__(self, 26 | band_width: float = 0.5, 27 | loss_type: str = 'cosine', 28 | use_vgg: bool = False, 29 | vgg_layer: str = 'relu3_4'): 30 | 31 | super(ContextualLoss, self).__init__() 32 | 33 | assert band_width > 0, 'band_width parameter must be positive.' 
34 | assert loss_type in LOSS_TYPES,\ 35 | f'select a loss type from {LOSS_TYPES}.' 36 | 37 | self.band_width = band_width 38 | 39 | if use_vgg: 40 | self.vgg_model = VGG19() 41 | self.vgg_layer = vgg_layer 42 | self.register_buffer( 43 | name='vgg_mean', 44 | tensor=torch.tensor( 45 | [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False) 46 | ) 47 | self.register_buffer( 48 | name='vgg_std', 49 | tensor=torch.tensor( 50 | [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False) 51 | ) 52 | 53 | def forward(self, x, y): 54 | if hasattr(self, 'vgg_model'): 55 | assert x.shape[1] == 3 and y.shape[1] == 3,\ 56 | 'VGG model takes 3 chennel images.' 57 | 58 | # normalization 59 | x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 60 | y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 61 | 62 | # picking up vgg feature maps 63 | x = getattr(self.vgg_model(x), self.vgg_layer) 64 | y = getattr(self.vgg_model(y), self.vgg_layer) 65 | 66 | return F.contextual_loss(x, y, self.band_width) 67 | -------------------------------------------------------------------------------- /contextual_loss/modules/contextual_bilateral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vgg import VGG19 5 | from .. import functional as F 6 | from ..config import LOSS_TYPES 7 | 8 | 9 | class ContextualBilateralLoss(nn.Module): 10 | """ 11 | Creates a criterion that measures the contextual bilateral loss. 12 | 13 | Parameters 14 | --- 15 | weight_sp : float, optional 16 | a balancing weight between spatial and feature loss. 17 | band_width : int, optional 18 | a band_width parameter described as :math:`h` in the paper. 19 | use_vgg : bool, optional 20 | if you want to use VGG feature, set this `True`. 21 | vgg_layer : str, optional 22 | intermidiate layer name for VGG feature. 
23 | Now we support layer names: 24 | `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']` 25 | """ 26 | 27 | def __init__(self, 28 | weight_sp: float = 0.1, 29 | band_width: float = 0.5, 30 | loss_type: str = 'cosine', 31 | use_vgg: bool = False, 32 | vgg_layer: str = 'relu3_4'): 33 | 34 | super(ContextualBilateralLoss, self).__init__() 35 | 36 | assert band_width > 0, 'band_width parameter must be positive.' 37 | assert loss_type in LOSS_TYPES,\ 38 | f'select a loss type from {LOSS_TYPES}.' 39 | 40 | self.band_width = band_width 41 | 42 | if use_vgg: 43 | self.vgg_model = VGG19() 44 | self.vgg_layer = vgg_layer 45 | self.register_buffer( 46 | name='vgg_mean', 47 | tensor=torch.tensor( 48 | [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False) 49 | ) 50 | self.register_buffer( 51 | name='vgg_std', 52 | tensor=torch.tensor( 53 | [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False) 54 | ) 55 | 56 | def forward(self, x, y): 57 | if hasattr(self, 'vgg_model'): 58 | assert x.shape[1] == 3 and y.shape[1] == 3,\ 59 | 'VGG model takes 3 chennel images.' 
60 | 61 | # normalization 62 | x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 63 | y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach()) 64 | 65 | # picking up vgg feature maps 66 | x = getattr(self.vgg_model(x), self.vgg_layer) 67 | y = getattr(self.vgg_model(y), self.vgg_layer) 68 | 69 | return F.contextual_bilateral_loss(x, y, self.band_width) 70 | -------------------------------------------------------------------------------- /contextual_loss/modules/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch.nn as nn 4 | import torchvision.models.vgg as vgg 5 | 6 | 7 | class VGG19(nn.Module): 8 | def __init__(self, requires_grad=False): 9 | super(VGG19, self).__init__() 10 | vgg_pretrained_features = vgg.vgg19(pretrained=True).features 11 | self.slice1 = nn.Sequential() 12 | self.slice2 = nn.Sequential() 13 | self.slice3 = nn.Sequential() 14 | self.slice4 = nn.Sequential() 15 | self.slice5 = nn.Sequential() 16 | for x in range(4): 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 18): 21 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(18, 27): 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(27, 36): 25 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 26 | if not requires_grad: 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, X): 31 | h = self.slice1(X) 32 | h_relu1_2 = h 33 | h = self.slice2(h) 34 | h_relu2_2 = h 35 | h = self.slice3(h) 36 | h_relu3_4 = h 37 | h = self.slice4(h) 38 | h_relu4_4 = h 39 | h = self.slice5(h) 40 | h_relu5_4 = h 41 | 42 | vgg_outputs = namedtuple( 43 | "VggOutputs", ['relu1_2', 'relu2_2', 44 | 'relu3_4', 'relu4_4', 'relu5_4']) 45 | out = vgg_outputs(h_relu1_2, h_relu2_2, 
46 | h_relu3_4, h_relu4_4, h_relu5_4) 47 | 48 | return out 49 | -------------------------------------------------------------------------------- /doc/small_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Small exmaple of Contextual Loss (CX)\n", 8 | "[Open with Colab](https://colab.research.google.com/github/S-aiueo32/contextual_loss_pytorch/blob/master/doc/small_example.ipynb)\n", 9 | "\n", 10 | "I will show you a small example of CX here.\n", 11 | "\n", 12 | "## Calculate CX\n", 13 | "For the first, import dependencies." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Create images called $x$ and $y$ in the paper.\n", 31 | "You can see the pixels whose spatial correspondings have similar looks.\n", 32 | "For simplicity, I will use the RGB values as the features through this notebook." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | }, 50 | { 51 | "data": { 52 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAABWCAYAAADMpy0uAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAHDElEQVR4nO3dX4gdZx3G8efpJmlIU2NstK5paCtJi1XR2rUgxT/0D229aApaaC8khZZVJKg3YkSQ0hujFy0IvQk2UEXaShVcSyAobSlBLVlLrU1CmjUXbUI0bVISo2i68niRIxzWs93amXNmzrzfDyw7c+Zl3t+c/Z2Hs7Mze5xEAIDuO6/pAgAAo0HgA0AhCHwAKASBDwCFIPABoBAEPgAUolLg236P7V/bPtT7vnaRcf+2/ULva6bKnMAo0NvoIle5Dt/2DySdTLLd9jZJa5N8a8C4M0lWV6gTGCl6G11UNfAPSvpckmO2JyU9k+TKAeN4UWCs0Nvooqrn8C9Ocqy3/BdJFy8ybqXtWdu/t317xTmBUaC30TnLlhpg+zeS3j9g03f6V5LE9mK/Llya5KjtD0p6yvafkvx5wFzTkqYladXEeddsXL1qyQMYB4c+0I3jkCT//X1Nl1Cbf7zy0puSXh6waai9fcHKlddcsWF9xerb4dSq5U2XUBvPr2y6hFocP/qKTr9xwoO2LRn4SW5cbJvtv9qe7Pu19/gi+zja+37Y9jOSrpb0Py+KJDsk7ZCkj737wuz+9MeXKm8s3HzfJ5suoTYrZr/adAm1mf3KpheTTA3aNsze/sQVG7Pnhw/UdBTN2jX13qZLqM3Eax9uuoRafPOOzy66reopnRlJW3rLWyT9cuEA22ttn99bXifpOkn7K84LDBu9jc6pGvjbJd1k+5CkG3vrsj1l+0e9MR+SNGv7j5KelrQ9CS8KtB29jc5Z8pTOW0lyQtINAx6flXRvb/m3kj5aZR5g1OhtdBF32gJAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BCEPgAUAgCHwAKQeADQCEIfAAoBIEPAIUg8AGgEAQ+ABSCwAeAQhD4AFAIAh8ACkHgA0Ahagl827fYPmh7zva2AdvPt/14b/tzti+rY15g2OhtdEnlwLc9IekhSbdKukrSXbavWjDsHklvJNko6UFJ3686LzBs9Da6po53+NdKmktyOMlZSY9J2rxgzGZJj/SWn5B0g23XMDcwTPQ2OqWOwF8v6dW+9SO9xwaOSTIv6ZSkixbuyPa07VnbsyfOvllDaUAlQ+nt10+dHlK5wFtr1R9tk+xIMpVk6qIVy5suB6hNf2+vW/OupstBoeoI/KOSNvStX9J7bOAY28skrZF0ooa5gWGit9EpdQT+XkmbbF9ue4WkOyXNLBgzI2lLb/mLkp5KkhrmBoaJ3kanLKu6gyTztrdK2i1pQtLOJPts3y9pNsmMpIcl/cT2nKSTOvfCAVqN3kbXVA58SUqyS9KuBY99t2/5n5LuqGMuYJTobXRJq/5oCwAYHgIfAApB4ANAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BCEPgAUAgCHwAKQeADQCEIfAAoBIEPAIUg8AGgEAQ+ABSCwAeAQtQS+LZvsX3Q9pztbQO23237Ndsv9L7urWNeYNjobXRJ5U+8sj0h6S
FJN0k6Immv7Zkk+xcMfTzJ1qrzAaNCb6Nr6niHf62kuSSHk5yV9JikzTXsF2gavY1OqSPw10t6tW/9SO+xhb5g+0XbT9jeUMO8wLDR2+iUWj7E/G34laRHk/zL9pclPSLp+oWDbE9Lmu6tnpl8cs/BEdS2TtLrQ53hyT1D3X2f4R+LHhzu7s8ZwXFIki6tYR/vqLcvuHXzsHt7VM/hKHAs/59F+9pJKu3Z9qck3Zfk5t76tyUpyfcWGT8h6WSSNZUmront2SRTTddRh64cS1uOY5x7uy3PYR04lvrUcUpnr6RNti+3vULSnZJm+gfYnuxbvU3SgRrmBYaN3kanVD6lk2Te9lZJuyVNSNqZZJ/t+yXNJpmR9DXbt0mal3RS0t1V5wWGjd5G11Q+pTPubE8n2dF0HXXoyrF05Tia1KXnkGOpcf7SAx8ASsG/VgCAQhQb+EvdMj8ubO+0fdz2S03XUpXtDbaftr3f9j7bX2+6pnFEb7dLm/q6yFM6vcvnXlbfLfOS7hpwy3zr2f6MpDOSfpzkI03XU0XvipfJJM/bvlDSHyTdPo4/l6bQ2+3Tpr4u9R1+Z26ZT/Kszl0dMvaSHEvyfG/5bzp3ieOgO1uxOHq7ZdrU16UG/tu9ZR4NsX2ZpKslPddsJWOH3m6xpvu61MBHi9leLennkr6R5HTT9QB1aENflxr4RyX1/5OrS3qPoWG2l+vci+KnSX7RdD1jiN5uobb0damBv+Qt8xg925b0sKQDSR5oup4xRW+3TJv6usjATzIv6b+3zB+Q9LMk+5qt6p2x/aik30m60vYR2/c0XVMF10n6kqTr+z5B6vNNFzVO6O1Wak1fF3lZJgCUqMh3+ABQIgIfAApB4ANAIQh8ACgEgQ8AhSDwAaAQBD4AFILAB4BC/Acw2kFzNIE/cgAAAABJRU5ErkJggg==\n", 53 | "text/plain": [ 54 | "
" 55 | ] 56 | }, 57 | "metadata": { 58 | "needs_background": "light" 59 | }, 60 | "output_type": "display_data" 61 | } 62 | ], 63 | "source": [ 64 | "x = np.array([\n", 65 | " [231, 76, 60],\n", 66 | " [46, 204, 113],\n", 67 | " [52, 152, 219],\n", 68 | "]).reshape(1, 3, 3) / 255\n", 69 | "y = np.array([\n", 70 | " [245, 183, 177],\n", 71 | " [171, 235, 198],\n", 72 | " [174, 214, 241],\n", 73 | "]).reshape(1, 3, 3) / 255\n", 74 | "\n", 75 | "plt.figure()\n", 76 | "plt.subplot(1, 2, 1)\n", 77 | "plt.imshow(x)\n", 78 | "plt.subplot(1, 2, 2)\n", 79 | "plt.imshow(y)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Along with the key idea, convert $x$, $y$ to $X$, $Y$ respectively." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 3, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "X = x.reshape(-1, 3)\n", 96 | "Y = y.reshape(-1, 3)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "Calculate cosine distances between all points." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# normalize\n", 113 | "mu = Y.mean(axis=0, keepdims=True)\n", 114 | "X_centered = X -mu\n", 115 | "Y_centered = Y -mu\n", 116 | "X_normalized = X_centered / np.linalg.norm(X_centered, ord=2, axis=1, keepdims=True)\n", 117 | "Y_normalized = Y_centered / np.linalg.norm(Y_centered, ord=2, axis=1, keepdims=True)\n", 118 | "\n", 119 | "# cosine distance\n", 120 | "d = 1 - np.matmul(X_normalized, Y_normalized.transpose())" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Looking at the heatmap of `d`, you confirm the correspondings are similar." 
128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAEXCAYAAADPzN0RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbhklEQVR4nO3df9RdVX3n8feHEEBL5FdEIoRfJToilqoRsGiNo64GFiWdSjvgDIgLJ0NbrLY6FbVFh9YWdOmMDlhWiiwELYioEG0Uf0HRQTCBIiGwgAgqiWgk/K4K5Hk+88c5D3O9ufeem+Teu+/z5PNinfWce85+9vnmEr7ss/fZ+8g2ERGxuR1KBxARMa6SICMiukiCjIjoIgkyIqKLJMiIiC6SICMiukiCjGdIukDS35SOI2JcJEFuJyT9UNIvJT0u6RFJN0g6XdIzfwdsn277b/us6/XDjXjLSFokyZK+2Hb88Pr4dYVCi2ksCXL78vu25wAHAOcA7wY+WTakgfo58EpJe7UcezNwd6F4YppLgtwO2X7U9nLgPwNvlnQYgKSLJf1dvT9X0pfr1uZDkr4taQdJlwL7A1+S9ISkv6rLf07STyU9Kul6SS+eul5d7/mS/qVuwd4k6Tdbzr9Y0tfr6/xM0nvr4ztIOlPSDyRtlHSFpD17/NGeAq4CTqx/f1b9Z/xMayFJH5N0v6THJN0s6dUt5z4g6UpJn61jvUXS4dvwdcc0lgS5HbP9PWAd8OoOp99Zn3su8DzgvdWv+GTgx1St0V1tf6gu/xVgAbA3cAttSYkqaf1PYA9gLfBBAElzgG8AXwWeDxwCfLP+nbcBfwC8pj73MHB+wx/rEuCUev/3gNuBn7SVWQn8NrAn8M/A5yTt0nJ+CfC5lvNXSZrdcN2YgZIg4ydUiaDd08A84ADbT9v+tntM3Ld9ke3HbT8JfAA4XNJuLUW+aPt7tjdRJc/fro8fB/zU9kds/6qu46b63OnA+2yva6n3BEk79ojjBmBPSS+kSpSXdCjzadsbbW+y/RFgZ+CFLUVutn2l7aeBjwK7AEd1u2bMXEmQsS/wUIfjH6Zq6X1N0r2SzuxWgaRZks6pb4UfA35Yn5rbUuynLfu/AHat9+cDP+hS9QHAF+vb/EeAO4EJqhZtL5cCZwCvBb7YflLSuyTdWXcHPALs1hbr/VM7tiepWtLPb7hmzEBJkNsxSa+gSpDfaT9Xt+Teaftg4HjgLyW9bup0W/E3Ud2Wvp4q2Rw4dYk+wrgfOLjHuWNs796y7WJ7fUOdlwJ/Cqyw/YvWE3V/418BfwzsYXt34NG2WOe3lN8B2I/Nb9NjO5AEuR2S9BxJxwGXA5+2vbpDmeMkHSJJVAlkApisT/+MX09qc4AngY3As4G/34JwvgzMk/QOSTtLmiPpyPrcBcAHJR1Qx/RcSUuaKrR9H1W/5fs6nJ4DbKIa8d5R0lnAc9rKvFzSH9a38u+o/2w3bsGfKWaIJMjty5ckPU7VMnsfVf/aW7qUXUA1ePIE8F3gE7avrc/9A/DX9a3vu6j6+X4ErAfuYAuSie3HgTcAv091G34P1a0xwMeA5VS3+Y/X9R7ZqZ4O9X7HdqdW3zVUA0J31zH/ipZb6trVVKPfDwMnA39Y90fGdkZZMDfi/5P0AeAQ2/+1dCxRXlqQERFdFEuQkvasHw6+p/65R5dyE5Jurbflo44zIrZfxW6xJX0IeMj2OfUjJHvYfneHck/Y3nXzGiIihqtkgrwLWGT7AUnzgOtsv7BDuSTIiCiiZIJ8pH4GjfpRkoenPreV2wTcSvVoxjm2r+pS31JgKcA
sZr382Zs9uRFTnvWi0hGMv/mzf9FcaDt3821PPmj7udtSx++99je88aGJfq51je3F23KtrdF1ytYgSPoGsE+HU7/2fJptS+qWqQ+wvV7SwcC3JK22vdnMC9vLgGUAz9GePvKZZ5qj3WGfydhck4/Mu6V0CGNv1ry1P9rWOjY+NMH3rtm/j2vdM7ex0BAMNUHa7rpmYL1qy7yWW+wNXepYX/+8t17T76V0n5oWEdOIgcln5h+Mn5JNieVUa/VR/7y6vYCkPSTtXO/PBY6mehA5ImYAY572RONWSskEeQ7wBkn3UM3hPQdA0kJJF9ZlXgSskvR94FqqPsgkyIgZZLKPf0oZ6i12L7Y3Apt1FNpeBby13r8BeMmIQ4uIETFmYoxn8xVLkBERAJObLQ41PpIgI6IYAxNJkBERnaUFGRHRgYGn0wcZEbE549xiR0R0ZJgY3/yYBBkR5VQzacZXEmREFCQm+nq3WxlJkBFRjIHJ3GJHRGzOwFNj/OaXJMiIKGrSucWOiNhMNZMmCTIiYjNGTIzxLfb4RhYR24VJq3Hrh6SLJG2QdHuPMovqN6SukfSvTXWmBRkRxRjxlGcNqrqLgfOASzqdlLQ78Algse0fS9q7qcIkyIgopnpQfDA3sravl3RgjyJvAr5g+8d1+Y6veWmVW+yIKGqifli81wbMlbSqZVu6FZd6AbCHpOsk3SzplKZfSAsyIoqxxYT7aqc9aHvhNl5uR+DlVG8yeBbwXUk32r671y9ERBQzObrHfNYBG23/O/Dvkq4HDge6JsjcYkdEMdVzkDs0bgNyNfAqSTtKejZwJHBnr19ICzIiijHiaQ8mDUm6DFhE1V+5Dng/MBvA9gW275T0VeA2qkWELrTd9ZEgSIKMiMImBjTV0PZJfZT5MPDhfutMgoyIYsZ9Jk0SZEQUNdnfKHYRSZARUczUIM24SoKMiGKMBtYHOQxJkBFRjM3ARrGHoXjbVtJiSXdJWivpzA7nd5b02fr8TQ1zLSNiWhGTfWylFE2QkmYB5wPHAIcCJ0k6tK3YacDDtg8B/hdw7mijjIhhMTDhHRq3Ukq3II8A1tq+1/ZTwOXAkrYyS4BP1ftXAq+TNL6dFhGxRUY4k2aLlU6Q+wL3t3xeVx/rWMb2JuBRYK+RRBcRQ2WaF8st+c6a8e0d3UL18kdLAXbh2YWjiYh+mAzS9LIemN/yeb/6WMcyknYEdgM2tldke5nthbYXzmbnIYUbEYPVvBZkyZd6lU6QK4EFkg6StBNwIrC8rcxy4M31/gnAt2yP8avGI6JfpppJ07SVUrRta3uTpDOAa4BZwEW210g6G1hleznwSeBSSWuBh6iSaETMEHntaw+2VwAr2o6d1bL/K+CPRh1XRAyfrczFjojopuRzjk2SICOimGrB3IG99nXgkiAjophqkCZ9kBERHWW5s4iIDqZm0oyrJMiIKGpyjFuQ4xtZRMx4dvXSrqatH5IukrRBUs83FUp6haRNkk5oqjMJMiKKMWLT5KzGrU8XA4t7FaiXWDwX+Fo/FSZBRkRRg5qLbft6qtl2vbwN+DywoZ860wcZEcVswWM+cyWtavm8zPayLbmWpH2B/wS8FnhFP7+TBBkRBfU91fBB2wu38WL/G3i37cl+19xOgoyIokb4zpmFwOV1cpwLHCtpk+2ruv1CEmREFGPD0/0PwmzjtXzQ1L6ki4Ev90qOkAQZEQUN8kFxSZcBi6j6K9cB7wdmA9i+YGvqTIKMiKIGdYtt+6QtKHtqP+WSICOimCxWERHRQxbMjYjopPBrXZskQUZEMQY2pQUZEbG59EFGRPSQBBkR0UEWzI2I6GGEUw23WBJkRJTj3GJHRHRkYNNkRrEjIjaTPsiIiB6cBBkR0dk4D9IUv/mXtFjSXZLWSjqzw/lTJf1c0q319tYScUbE4LkepGnaSinagqzfMHY+8AZgHbBS0nLbd7QV/aztM0YeYEQMmZgY40Ga0pEdAay1fa/tp4DLgSWFY4q
IEbLVuJVSug9yX+D+ls/rgCM7lHujpN8F7gb+wvb97QUkLQWWAuy6z7NZ8OWdhxDuzHDn63YpHcLYO+RvTi8dwjTwrm2uYdznYpduQfbjS8CBtn8L+DrwqU6FbC+zvdD2wmftkQQQMS246ods2kopnSDXA/NbPu9XH3uG7Y22n6w/Xgi8fESxRcQITKLGrZTSCXIlsEDSQZJ2Ak4ElrcWkDSv5ePxwJ0jjC8ihsiMdx9k0QRpexNwBnANVeK7wvYaSWdLOr4u9ueS1kj6PvDnwKlloo2IwRMTk81bXzVJF0naIOn2Luf/i6TbJK2WdIOkw5vqLD1Ig+0VwIq2Y2e17L8HeM+o44qI0RhgC/Fi4Dzgki7n7wNeY/thSccAy+g8KPyM4gkyIrZf1SDMwF77er2kA3ucv6Hl441UYx49JUFGRFF9PuYzV9Kqls/LbC/bhsueBnylqVASZEQU1edjPA/aXjiI60l6LVWCfFVT2STIiCjGiMkRTjWU9FtUjwseY3tjU/kkyIgoalTPgUvaH/gCcLLtu/v5nSTIiChngIM0ki4DFlH1V64D3g/MBrB9AXAWsBfwCUkAm5pu25MgI6KsATUhbZ/UcP6twBYtl5gEGRFFZUXxiIguSi5G0SQJMiKKscFjvGBuEmREFJUWZEREN0mQERGdlF3OrEkSZESUlRZkREQHA3xQfBiSICOirCTIiIgucosdEdFFEmRERAcmt9gREd3kQfGIiG76fGthCUmQEVGU0oKMiOjAZJAmIqIzZZAmIqKrtCAjIrpIgoyI6MCM9Sh2z6V8JV1R/1wt6baWbbWk2wYRgKSLJG2QdHuX85L0cUlr62u/bBDXjYjxIDdvfdUzhFzS1IJ8e/3zuP5C3CoXA+cBl3Q5fwywoN6OBP6x/hkRM8HgbrEvZsC5pGeCtP1A/fNHvcpJ+q7tV/Yq0+Ma10s6sEeRJcAltg3cKGl3SfOmYouIgOHkkkG9LWeXAdXTyb7A/S2f19XHfo2kpZJWSVr1y4d/NcRwImKQ+rzFnjv133e9Ld2KS/WVS1oNapCm+DiU7WXAMoC9D92reDwR0af+noN80PbCYYfSbjqMYq8H5rd83q8+FhHTnYHJkV1ti3NJX7fYkg7tcGxR68d+6tlKy4FT6hGoo4BH0/8YMXMMahS7D1ucS/ptQV4h6VLgQ1T9jR8CFgJTAzMnb2XASLoMWETVx7AOeD8wG8D2BcAK4FhgLfAL4C1be62IGEMDSoDDyCX9JsgjgXOBG4A5wGeAo6dO2u743FE/bJ/UcN7An21t/REx5gaUIIeRS/pNkE8DvwSeRdWCvM/26HoOImJGGvAt9MD1+5jPSqoE+Qrg1cBJkj43tKgiYvsxqeatkH5bkKfZXlXvPwAskbTV/Y4REVPGuQXZV4JsSY6txy4dfDgRsd2Z7gkyImIoxrwPMgkyIspKgoyI6CIJMiKis9xiR0R0kwQZEdFBBmkiInpIgoyI6CIJMiJicyK32BERnRk0xsveJEFGRFlpQUZEdJEEGRHRWfogIyK6SYKMiOhgzAdp+l1RPCJiONzH1gdJiyXdJWmtpDM7nN9f0rWS/k3SbZKObaozCTIiihrEa18lzQLOB44BDqV6LUz766r/GrjC9kuBE4FPNNWbBBkRZQ2mBXkEsNb2vbafAi4HlnS40nPq/d2AnzRVmj7IiChnC26hG+wL3N/yeR3V66pbfQD4mqS3Ab8BvL6p0rQgI6IY9bkBcyWtatmWbsXlTgIutr0fcCxwqaSeOTAtyIgoqs9R7AdtL+xxfj0wv+XzfvWxVqcBiwFsf1fSLsBcYEO3StOCjIiyBtMHuRJYIOkgSTtRDcIsbyvzY+B1AJJeBOwC/LxXpWlBRkRZA+iDtL1J0hnANcAs4CLbaySdDayyvRx4J/BPkv6ivuqptntevWiClHQRcBywwfZhHc4vAq4G7qsPfcH22aOLMCKGaoAritteAaxoO3ZWy/4dwNFbUmfpFuT
FwHnAJT3KfNv2caMJJyJGLlMNO7N9vaQDS8YQEWWN81TD0i3IfrxS0vepHup8l+01nQrVw/5LAfZ+/o4cv8ctIwxxevm717yldAhj75BPP146hLF3X3ORvozzaj7jPop9C3CA7cOB/wNc1a2g7WW2F9peuNue0yHvR0RfI9gFE+hYJ0jbj9l+ot5fAcyWNLdwWBExSEmQW0fSPpJU7x9BFe/GslFFxKBMvbRrWxerGJbSj/lcBiyimka0Dng/MBvA9gXACcCfSNoE/BI4sem5pYiYZsb4v+jSo9gnNZw/j+oxoIiYiQyaHN8MmdGMiChqnEexkyAjoqwkyIiIztKCjIjoJgkyIqKDwo/xNEmCjIhiROZiR0R0N8aPNidBRkRRucWOiOik8FzrJkmQEVFU+iAjIrpIgoyI6MSM9SDNWC93FhEz36CWO5O0WNJdktZKOrNLmT+WdIekNZL+uanOtCAjoqwBNCAlzQLOB94ArANWSlpev8lwqswC4D3A0bYflrR3U71pQUZEMQNcMPcIYK3te20/BVwOLGkr89+A820/DGB7Q1OlSZARUY7d39ZsX+D+ls/r6mOtXgC8QNL/lXSjpMVNleYWOyKK6nMUe66kVS2fl9letoWX2hFYQPUWg/2A6yW9xPYjvX4hIqKYPm+hH7S9sMf59cD8ls/71cdarQNusv00cJ+ku6kS5spuleYWOyLKMTDp5q3ZSmCBpIMk7QScCCxvK3MVVeuR+u2oLwDu7VVpEmRElDWA177a3gScAVwD3AlcYXuNpLMlHV8XuwbYKOkO4Frgf9ju+ZbU3GJHRFGDWqzC9gpgRduxs1r2DfxlvfUlCTIiyhrjmTRJkBFRjjMXOyKio+pB8bQgIyI6SwsyIqKztCAjIjoZ8xXFiz4HKWm+pGtblh96e4cykvTxegmj2yS9rESsETEMRpPNWymlW5CbgHfavkXSHOBmSV9vXaIIOIZqOtAC4EjgH+ufETETjPEtdtEWpO0HbN9S7z9O9QR8+wocS4BLXLkR2F3SvBGHGhHDUD/m07SVMjZTDSUdCLwUuKntVD/LGEXEdDWY5c6GovQtNgCSdgU+D7zD9mNbWcdSYCnA3s8fiz9WRPRjfO+wy7cgJc2mSo6fsf2FDkX6WcYI28tsL7S9cLc9kyAjpgvZjVsppUexBXwSuNP2R7sUWw6cUo9mHwU8avuBkQUZEcNjYMLNWyGlm1pHAycDqyXdWh97L7A/gO0LqFbnOBZYC/wCeEuBOCNiCETZFmKTognS9neopmP2KmPgz0YTUUSMXBJkREQXSZARER2YLFYREdFN+iAjIjoyTI5vEzIJMiLKMWPdB1n8QfGI2M5N9rH1QdJiSXfVK3+d2aPcGyVZUq/3bANJkBFR2CBm0kiaBZxPtfrXocBJkg7tUG4O8HY2X/OhoyTIiChrMItVHAGstX2v7aeAy6lWAmv3t8C5wK/6qTQJMiLKsWFisnmDuZJWtWxL22pqXPWrXmx7vu1/6Te8DNJERFn9tRAftN3YZ9iNpB2AjwKnbsnvpQUZEWUN5ha7adWvOcBhwHWSfggcBSxvGqhJCzIiyjEwmHfOrAQWSDqIKjGeCLzpmcvYjwJzpz5Lug54l+1VvSpNCzIiCjJ4snlrqsXeBJwBXEP16pYrbK+RdLak47c2urQgI6KsAT0obnsF1fKIrcfO6lJ2UT91JkFGRDlmapR6LCVBRkRZYzzVMAkyIgoq+9bCJkmQEVGOyWo+ERFdpQUZEdFFEmRERAc2npgoHUVXSZARUdZgZtIMRRJkRJSVW+yIiA6cd9JERHSXFmRERCcZpImI6Gxwy50NRRJkRJTVx3JmpRRdD1LSfEnXSrpD0hpJb+9QZpGkRyXdWm8dly+KiOnHgCfduJVSugW5CXin7Vvq1zHeLOnrtu9oK/dt28cViC8ihske6xZk0QRp+wHggXr/cUl3Ur2JrD1BRsQMVbKF2EQekyF2SQcC1wOH2X6s5fg
i4PNUr3H8CdV7JNZ0+P2lwNSrIA8Dbh9uxFtsLvBg6SBaJJ7exi0eGL+YXmh7zrZUIOmrtLwrpocHbS/elmttjbFIkJJ2Bf4V+KDtL7Sdew4wafsJSccCH7O9oKG+VdvyishhGLeYEk9v4xYPjF9M4xbPMBR/aZek2VQtxM+0J0cA24/ZfqLeXwHMltTP/3EiIrZJ6VFsAZ8E7rT90S5l9qnLIekIqpg3ji7KiNhelR7FPho4GVgt6db62HuB/QFsXwCcAPyJpE3AL4ET3dwvsGxI8W6LcYsp8fQ2bvHA+MU0bvEM3Fj0QUZEjKPifZAREeMqCTIioosZkSAl7Snp65LuqX/u0aXcRMuUxeVDiGOxpLskrZV0ZofzO0v6bH3+pvrZz6HqI6ZTJf285Xt56xBjuUjSBkkdn1FV5eN1rLdJetmwYukznpFOc+1z6u2ov6Ptezqw7Wm/AR8Czqz3zwTO7VLuiSHGMAv4AXAwsBPwfeDQtjJ/ClxQ758IfHbI30s/MZ0KnDeif0+/C7wMuL3L+WOBrwACjgJuKhzPIuDLo/hu6uvNA15W788B7u7w72vU31E/MY30exrlNiNakMAS4FP1/qeAPygQwxHAWtv32n4KuLyOq1VrnFcCr5t6hKlgTCNj+3rgoR5FlgCXuHIjsLukeQXjGSnbD9i+pd5/HJiaettq1N9RPzHNWDMlQT7P1bxugJ8Cz+tSbhdJqyTdKGnQSXRf4P6Wz+vY/C/SM2VsbwIeBfYacBxbGhPAG+vbtSslzR9iPE36jXeUXinp+5K+IunFo7po3f3yUuCmtlPFvqMeMUGh72nYSj8H2TdJ3wD26XDqfa0fbFtSt2eXDrC9XtLBwLckrbb9g0HHOs18CbjM9pOS/jtVC/c/Fo5pXNxC9XdmaprrVUDPaa6DUE+9/TzwDresS1BSQ0xFvqdRmDYtSNuvt31Yh+1q4GdTtxn1zw1d6lhf/7wXuI7q/4aDsh5obX3tVx/rWEbSjsBuDHdWUGNMtjfafrL+eCHw8iHG06Sf73BkXGCaa9PUWwp8R9vzdOBpkyAbLAfeXO+/Gbi6vYCkPSTtXO/PpZrFM8hl1VYCCyQdJGknqkGY9pHy1jhPAL7lupd7SBpjauu/Op6qj6mU5cAp9UjtUcCjLV0nIzfqaa71tXpOvWXE31E/MY36exqp0qNEg9io+vG+CdwDfAPYsz6+ELiw3v8dYDXVSO5q4LQhxHEs1SjfD4D31cfOBo6v93cBPgesBb4HHDyC76Yppn8A1tTfy7XAfxhiLJdRrf/5NFXf2WnA6cDp9XkB59exrgYWDvm7aYrnjJbv5kbgd4Ycz6uoFtm+Dbi13o4t/B31E9NIv6dRbplqGBHRxUy5xY6IGLgkyIiILpIgIyK6SIKMiOgiCTIiooskyIiILpIgY+xIer6kK0vHEZHnICMiukgLMkZG0tmS3tHy+YNdFmA9sNsithGjlAQZo3QRcAqApB2o5oZ/umhEET1Mm+XOYvqz/UNJGyW9lGrNzn+zPTMWNYgZKQkyRu1Cqtc87EPVoowYWxmkiZGql11bDcwGFtie6FDmQKp3nBw22ugifl1akDFStp+SdC3wSKfk2Fp0VDFFdJMEGSNVD84cBfxRj2J7MUYv04rtV0axY2QkHUq1WPA3bd/TpcxCqoVsPzbK2CI6SR9kFCPpJcClbYeftH1kiXgi2iVBRkR0kVvsiIgukiAjIrpIgoyI6CIJMiKii/8HxD0znaHgHtwAAAAASUVORK5CYII=\n", 138 | "text/plain": [ 139 | "
" 140 | ] 141 | }, 142 | "metadata": { 143 | "needs_background": "light" 144 | }, 145 | "output_type": "display_data" 146 | } 147 | ], 148 | "source": [ 149 | "plt.imshow(d)\n", 150 | "plt.colorbar()\n", 151 | "plt.title('Distance Map');\n", 152 | "plt.ylabel('x_i');\n", 153 | "plt.xlabel('y_j');" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Along with Eq.(2) in the paper, normalize using the minimum.\n", 161 | "Through this process, the most similar point over $y_j$ becomes `1.`" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT4AAAEXCAYAAAA0myjOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbmUlEQVR4nO3de7QlZX3m8e9D04ByhzaA0NAyQ0wIyQTsgLcwHVFHkEBmhUxQB4Sl0+KI0Vk6DuoSEiasKIk6sHDEFlBuQRAjtogXDBC8QWiwpWku0hAZumlpaO4qlz7nmT/qPc1ms8/Z1c3euw6nng+r1q7Le6p+u+jzO+9bb71Vsk1ERJts0nQAERGjlsQXEa2TxBcRrZPEFxGtk8QXEa2TxBcRrZPE11KS/lrSBWV+d0lPSJo14GP8QtIbX8DPv0PS9wYZUwQk8Q1N+aVfI2nLjnXvlnRNg2H1ZPv/2d7K9tiojinpy5KelvR4mW6R9HeStu2I60Lbb665r78dbsQbTpLLv4FNO9bNLutyA22DkviGaxbwgRe6E1Vm4v+rU21vDbwMOBZ4NfCjzj8WM8DDwMEdyweXddGgmfjLNJ38PfBhSdv12ijptZJukPRo+Xxtx7ZrJJ0i6UfAr4E9y7q/lfTj0jT9pqQdJV0o6bGyj3kd+zhN0r1l242S/niSOOaV2smmkl5T9j0xPSnpF6XcJpJOkHSXpLWSLpG0Q8d+jpJ0T9n28bonyfaTtm8ADgN2pEqCSDpG0g/LvCR9ttSWHpO0TNI+khYC7wA+MnFOSvmJOB+XdKuk/9wR5zGSfijpHyQ9LOnfJB3csX0HSV+SdF/ZflnHtkMlLZX0SPn/8Ad9vt75wNEdy0cD53Wd/2Ml3VZivVvSezq2LZC0UtLHJD1YWhLvqHtuo7ckvuFaAlwDfLh7Q0kY3wJOp/pl/wzwLUk7dhQ7ClgIbA3cU9YdWdbvCvw74CfAl4AdgNuAkzp+/gbgD8u2fwS+KmmLqQK2/ZPS7N0K2B64HriobH4/8GfAfwReTlVz+Vz5PnsDny+xvbx8p92mOlaPYz8OXAn0StBvBg4EfhvYFvgvwFrbi4ALqWqPW9n+01L+rrKfbYG/AS6QtEvH/g4A7gDmAKcCZ0tS2XY+8FLg94DfAj5bvuO+wDnAe8r3+wKwWNLmU3yty4ADJW0nafsS0ze6yqwBDgW2oUr6n5W0X8f2nUucuwLvBBZJeuUUx4w+kviG70Tg/ZJe1rX+rcCdts+3vc72RcDtwJ92lPm
y7eVl+zNl3Zds32X7UeDbwF22v297HfBVYN+JH7Z9ge215ec/DWwObMgvzOnA48BE7e044OO2V9p+Cvhr4IhyDesI4HLb15ZtnwDGN+BYE+6jStTdnqH6A/A7gGzfZnv1ZDux/VXb99ket30xcCewf0eRe2x/sVzXPBfYBdipJMeDgeNsP2z7Gdv/Un5mIfAF29fbHrN9LvAUVRN9Mk8C3wT+skyLy7rOWL9V/p+6HOt7PD/5f8L2U2X7t6gSf2ykJL4hs30LcDlwQteml/NsLW7CPVR/1Sfc22OX93fM/6bH8lYTC5I+XJpQj0p6hKr2M6dO3KW5tQB4u+2JBLYH8PXSzHuEqoY5BuxUvs/6eG3/Clhb51hddgUe6l5p+yrgDKoa5hpJiyRtM0X8R3c0SR8B9uG53/2XHfv+dZndCpgLPGS713W4PYAPTeyz7Hcu1XefynlUTdznNXNLrAdLuk7SQ2Wfh3TF+nA5nxPuqXHMmEIS32icBPw3npvU7qP6Req0O7CqY3mje/7K9byPUNUMtre9HfAooCl/8Nmf/d/A4bYf69h0L3Cw7e06pi1srwJWUyWBiX28lKo5uCExbwW8EfhBr+22T7f9KmBvqibv/5zY1LWfPYAvAscDO5bvfgs1vnv5jjuo93XZe4FTur7/S0ttfSo/oNQogR92xbo58DXgH4CdSqxXdMW6vZ7b4bM71b+f2EhJfCNgewVwMfBXHauvAH5b0ttLp8JfUv1CXz6gw24NrAMeADaVdCLVNaQpSZoLXAIcbfvnXZvPBE4piQVJL5N0eNl2KXCopNdL2gw4mZr/viRtLulVVNfDHqa6Ztld5o8kHSBpNvArqubiRE30fmDPjuJbUiXDB8rPHktV4+urNJ+/DfxfSduruv3kwLL5i8BxJQ5J2lLSWyVt3WefprqEcZif/xy4zaguQTwArCudLL1u4fkbSZuVP0qHUl3WiI2UxDc6J1P9QgJgey3VP+APUTUJPwIcavvBAR3vu8B3gJ9TNY2epHfTudtBVDWTS/Vsz+7ysu00qmtU35P0OHAdVScBtpcD76PqRFlNlcBW9jnWR8p+1lI1AW8EXtvVrJuwDVXiebh8n7VUveYAZwN7l+bnZbZvBT5N1fFzP/D7wI9qfPcJR1FdU7ydquPhg+U7LqGquZ9R4lgBHFNnh+Va7fIe6x+n+oN4Sdnn26nOcadflm33UXXkHGf79g34PtFFeRBpxPQlaQFwge0N6iGPqaXGFxGt01jiKzeJXinpzvK5/STlxkrv3FJJ3U2AiIgN1lhTV9KpVLcNfFLSCVQ9j/+rR7knys20ERED0WTiuwNYYHt1uWn0GtvPu7k2iS8iBq3JxPdIuWeJMlTo4YnlrnLrgKVUt2Z80vZl3WVKuYVUd9Yzi1mvemn/Ozda6yW/23QE09/c2b/uX6jlbrz5qQdtd49I2iD/6U+29NqH+j8U6Mabn/qu7be8kGN12rR/kY0n6ftU4wy7PWcAu21r8sf07GF7laQ9gaskLbN9V3ehMmZzEcA22sEH6KAXGP3Mtc+F6dPq59O73NR0CNPerF1WdI882mBrHxrjX7+7e41j3VlrxFFdQ018tid9CKWk+yXt0tHUXTPJPlaVz7tVPctuX6oB6BHxImdgfKOGdL8wTf7pX0z1pAnKZ/cTKyh3zm9e5ucArwNuHVmEETFUxjzjsb7ToDWZ+D4JvEnSnVTjMz8JIGm+pLNKmd8Flkj6GXA11TW+JL6IGWS8xn/9SHplx21vS1U9s/GDk5UfalN3KmXI1vMuxJVhQe8u8z+mGm4UETOQMWMD6GC1fQfVsydR9e6YVcDXJyvfWOKLiAAY3/iHEE3mIKrnVE7a+ZLEFxGNMTBWL/HNkbSkY3lRuZOjlyN59qnhPSXxRUSjatb4HrQ9v1+h8ki0w4CPTlUuiS8iGmPgmcEOojgYuMn2/VMVSuKLiMYY123q1vU2+jRzIYkvIppkGBtQ3iuP538T1Vv
wppTEFxGNqUZuDGhf1ZO7a73nJYkvIhokxmq9A2qwkvgiojEGxht4QFQSX0Q0xsDTDYycTeKLiEaNO03diGiRauRGEl9EtIgRY2nqRkTbpKkbEa1ixNOeNfLjJvFFRGOqG5jT1I2IlknnRkS0ii3GnBpfRLTMeGp8EdEm1X18qfFFRIsY8YxHn4aS+CKiUWO5jy8i2iQjNyKilcbTqxsRbZLOjYhoHaNc44uIdrFppFd39HXMLpLeIukOSSskndBj++aSLi7br5c0b/RRRsRwiPEa06A1mvgkzQI+R/US4L2Bt0nau6vYu4CHbf974LPAp0YbZUQMi4Exb9J3GrSma3z7Ayts3237aeArwOFdZQ4Hzi3zlwIHSRr9RYGIGIoxNuk7DVrTiW9X4N6O5ZVlXc8yttcBj1Lz3ZkRMb0ZMe7+06DNmM4NSQuBhQBb8NKGo4mIOkw7OzdWAXM7lncr63qWkbQpsC2wtntHthfZnm97/mw2H1K4ETFY1QvF+02D1nTiuwHYS9IrJG0GHAks7iqzGHhnmT8CuMp2A68gjohBq14ovknfqQ5J20m6VNLtkm6T9JrJyjba1LW9TtLxwHeBWcA5tpdLOhlYYnsxcDZwvqQVwENUyTEiZogB1uhOA75j+4hSkZr0mlfj1/hsXwFc0bXuxI75J4G/GHVcETF8tgYyVlfStsCBwDHVfv008PRk5Ztu6kZEy9W8j2+OpCUd08Ku3bwCeAD4kqSfSjpL0paTHbPxGl9EtFf1INJar5d80Pb8KbZvCuwHvN/29ZJOA04APtGrcGp8EdGYqnNjIPfxrQRW2r6+LF9KlQh7SuKLiEYNYuSG7V8C90p6ZVl1EHDrZOXT1I2IxkyM3BiQ9wMXlh7du4FjJyuYxBcRjRofUMPT9lJgquuA6yXxRURj7LxsKCJaxoh147V6dQcqiS8iGjWMsbj9JPFFRGMmbmcZtSS+iGjQYIasbagkvoho1DDeqdFPEl9ENMaGZ9K5ERFtMuAbmGtL4ouIRqWpGxGtkl7diGil9OpGRLsM6fWR/STxRURjDKxLjS8i2iTX+CKilZL4IqJVch9fRLRS7uOLiHZxmroR0TIG1o2nVzciWiTX+CKilZzEFxFt00TnRuMvFJf0Fkl3SFoh6YQe24+R9ICkpWV6dxNxRsTguXRu9JsGrdEan6RZwOeANwErgRskLbbd/Qb0i20fP/IAI2LIxFgDnRtN1/j2B1bYvtv208BXgMMbjikiRshW32nQmr7Gtytwb8fySuCAHuX+XNKBwM+B/2H73u4CkhYCCwFm7bgdP//7Wi9Ub6dXLWk6gmlvv/e+t+kQXgQ+9IL30NRY3aZrfHV8E5hn+w+AK4FzexWyvcj2fNvzZ2215UgDjIiN5Oo6X7+pDkm/kLSs9AVM+de96RrfKmBux/JuZd16ttd2LJ4FnDqCuCJiRAbcq/snth/sV6jpGt8NwF6SXiFpM+BIYHFnAUm7dCweBtw2wvgiYohMC6/x2V4n6Xjgu8As4BzbyyWdDCyxvRj4K0mHAeuAh4BjGgs4IgZMjI3XSmxzupqvi2wv6ipj4HuSDHyhx/b1mm7qYvsK4IqudSd2zH8U+Oio44qI0ahZo3vQdr8ey9fbXiXpt4ArJd1u+9peBZtu6kZEi1WdF4Np6tpeVT7XAF+nul2upyS+iGjUIEZuSNpS0tYT88CbgVsmK994Uzci2q3u7Sp97AR8XRJUee0fbX9nssJJfBHRGCPGBzBkzfbdwH+oWz6JLyIaNZgK34ZJ4ouI5jjP44uINmqgypfEFxGNSo0vIlpnQL26GySJLyIaY4PzlrWIaJvU+CKifZL4IqJdhvPYqX6S+CKiWanxRUSr5AbmiGilJL6IaJ00dSOidZL4IqJVTJq6EdE+uYE5Itqn3lvWBiqJLyIapdT4IqJVTDo3IqJtlM6NiGih1PgionWS+CKiVUwjvbpTPvpU0iXlc5mkmzumZZJuHkQAks6RtEZ
Sz7eeq3K6pBXl2PsN4rgRMT3I/adB61fj+0D5PHTwh17vy8AZwHmTbD8Y2KtMBwCfL58RMRNMt6au7dXl856pykn6ie3XbEwAtq+VNG+KIocD59k2cJ2k7STtMhFbRMSGGtRbPrYY0H562RW4t2N5ZVn3HJIWSloiacnYE78aYjgRMUiDbOpKmiXpp5Iun6rcoBJfA5XVrgDsRbbn254/a6stmw4nIuqy+k/1fQC4rV+h0b/XbcOtAuZ2LO9W1kXEi52B8RpTDZJ2A94KnNWvbK3EJ2nvHusWdC7WC22jLAaOLr27rwYezfW9iJljgE3d/wN8hBqpsu59fJdIOh84lep63qnAfGCiQ+Oo2qF1kXQRsACYI2klcBIwG8D2mcAVwCHACuDXwLEbe6yImIbqJbY5kpZ0LC+yvWhiQdKhwBrbN3ZVynqqm/gOAD4F/BjYGrgQeN3ERts978Grw/bb+mw38L6N3X9ETHP1Et+DtudPsf11wGGSDqGqnG0j6QLb/7VX4brX+J4BfgO8pOz032zXbHlHRPRWp5lbp6lr+6O2d7M9DzgSuGqypAf1E98NVInvj4A/Bt4m6as1fzYiYnLj6j8NWN2m7rtsT7SvVwOHS9ro63oRERMGPSTN9jXANVOVqZX4OpJe57rzNyqqiIhO023IWkTEUA3pIQT9JPFFRLOS+CKidZL4IqJt0tSNiPZJ4ouIVknnRkS0UhJfRLROEl9EtIlIUzci2sagBh53ksQXEc1KjS8iWieJLyLaJtf4IqJ9kvgiolXSuRERrZQaX0S0Ta7xRUT7JPFFRKuYJL6IaBeVadSS+CKiUenVjYj2SVM3IlqngcS3yegP+SxJ50haI+mWSbYvkPSopKVlOnHUMUbEEJUnMPebBq3pGt+XgTOA86Yo8wPbh44mnIgYubY1dW1fK2lekzFERLMG0bkhaQvgWmBzqrx2qe2TJivfdI2vjtdI+hlwH/Bh28t7FZK0EFgIMHub7dnq9s1GGOKLy6wdd2g6hGnvZZ//SdMhtMaAmrJPAW+w/YSk2cAPJX3b9nW9Ck/3xHcTsEf5MocAlwF79SpoexGwCOAlO89toPIcERtsQDcw2zbwRFmcXaZJ99xo50Y/th+z/USZvwKYLWlOw2FFxCC5xlSDpFmSlgJrgCttXz9Z2Wmd+CTtLEllfn+qeNc2G1VEDMrEy4Zq9OrOkbSkY1rYvS/bY7b/ENgN2F/SPpMdt9GmrqSLgAVUX2olcBJVFRXbZwJHAO+VtA74DXBkqdJGxExR7zf6Qdvza+3OfkTS1cBbgJ63yjXdq/u2PtvPoLrdJSJmIoPGX3hdRtLLgGdK0nsJ8CbgU5OVn+6dGxExww2oV3cX4FxJs6guiV1i+/LJCifxRUSzBtOrezOwb93ySXwR0ag8gTki2ieJLyJaZUgPIegniS8iGiPyINKIaKMGbs1N4ouIRqWpGxHtkresRUQb5RpfRLROEl9EtItJ50ZEtE86NyKifZL4IqJNJh5EOmpJfBHRHDvX+CKifdKrGxGtk6ZuRLSLgQE8en5DJfFFRLNS44uItklTNyLaJ726EdEqTq9uRLRMdQNzanwR0Tap8UVE26TGFxHt0tATmDcZ/SGfJWmupKsl3SppuaQP9CgjSadLWiHpZkn7NRFrRAyD0Xj/adCarvGtAz5k+yZJWwM3SrrS9q0dZQ4G9irTAcDny2dEzAQNNHUbrfHZXm37pjL/OHAbsGtXscOB81y5DthO0i4jDjUihqHcztJv6qdO67FT0zW+9STNA/YFru/atCtwb8fyyrJu9UgCi4jhGkyNr07rcb1Ga3wTJG0FfA34oO3HNnIfCyUtkbRk7De/GmyAETE8rjH120W91uN6jSc+SbOpkt6Ftv+pR5FVwNyO5d3Kuuewvcj2fNvzZ71ky+EEGxEDJ7vvBMyZqNiUaeGk+5u89bheo01dSQLOBm6z/ZlJii0Gjpf0FapOjUdtp5kbMRMYGKvV1H3
Q9vx+heq2Hpu+xvc64ChgmaSlZd3HgN0BbJ8JXAEcAqwAfg0c20CcETEEwgO7gblG63G9RhOf7R9SDdebqoyB940moogYuQEkvpqtx/Uav8YXES038cKhqab+JlqPb5C0tEyHTFa46aZuRLSZGchDCuq0Hjsl8UVEo/KQgohoGcP46J9LlcQXEc0xefR8RLRQHkQaEW2Ta3wR0T5JfBHRKjaMpXMjItomNb6IaJ0kvohoFQNDeKdGP0l8EdEgg3ONLyLaJk3diGgVk17diGih1Pgiol1qP29voJL4IqI5Jk9niYgWSo0vIloniS8iWsXGY2MjP2wSX0Q0KyM3IqJ10tSNiFZx3rkREW2UGl9EtEs6NyKibfJYqohopQYeS7XJyI/YQdJcSVdLulXSckkf6FFmgaRHJS0t04lNxBoRg2fA4+471SHpHElrJN3Sr2zTNb51wIds3yRpa+BGSVfavrWr3A9sH9pAfBExTB7og0i/DJwBnNevYKM1Pturbd9U5h8HbgN2bTKmiBitQdX4bF8LPFSnrNxAV3IvkuYB1wL72H6sY/0C4GvASuA+4MO2l/f4+YXAwrK4D9C3ujtic4AHmw6iQ+KZ2nSLB6ZfTK+0vfUL2YGk71B9r362AJ7sWF5ke1GP/c0DLre9z5THnQ6JT9JWwL8Ap9j+p65t2wDjtp+QdAhwmu29+uxvie35w4t4w023mBLP1KZbPDD9Yppu8UD9xNdoUxdA0myqGt2F3UkPwPZjtp8o81cAsyXV+QsREdFT0726As4GbrP9mUnK7FzKIWl/qpjXji7KiJhpmu7VfR1wFLBM0tKy7mPA7gC2zwSOAN4raR3wG+BI92+fP6/tPw1Mt5gSz9SmWzww/WKaVvFIughYAMyRtBI4yfbZPctOh2t8ERGj1Pg1voiIUUvii4jWmRGJT9IOkq6UdGf53H6ScmMdQ98WDyGOt0i6Q9IKSSf02L65pIvL9utL1/tQ1YjpGEkPdJyXdw8xlimHFKlyeon1Zkn7DSuWmvGMdLhkzSGcoz5HM3NYqe0X/QScCpxQ5k8APjVJuSeGGMMs4C5gT2Az4GfA3l1l/jtwZpk/Erh4yOelTkzHAGeM6P/TgcB+wC2TbD8E+DYg4NXA9Q3Hs4DqnrChn5tyvF2A/cr81sDPe/z/GvU5qhPTSM/TIKYZUeMDDgfOLfPnAn/WQAz7Ayts3237aeArJa5OnXFeChw0catOgzGNjPsPKTocOM+V64DtJO3SYDwj5XpDOEd9jmbksNKZkvh2sr26zP8S2GmScltIWiLpOkmDTo67Avd2LK/k+f9A1pexvQ54FNhxwHFsaEwAf16aTZdKmjvEePqpG+8ovUbSzyR9W9Lvjeqg5TLIvsD1XZsaO0dTxAQNnaeN1fR9fLVJ+j6wc49NH+9csG1Jk92js4ftVZL2BK6StMz2XYOO9UXmm8BFtp+S9B6qGukbGo5puriJ6t/MxHDJy4Aph0sOQhnC+TXgg+4Yt96kPjE1cp5eiBdNjc/2G23v02P6BnD/RHW/fK6ZZB+ryufdwDVUf70GZRXQWVvarazrWUbSpsC2DHcUSt+YbK+1/VRZPAt41RDj6afOORwZNzBcst8QTho4RzNxWOmLJvH1sRh4Z5l/J/CN7gKStpe0eZmfQzVqpPu5fy/EDcBekl4haTOqzovunuPOOI8ArnK5OjwkfWPquj50GNU1nKYsBo4uPZevBh7tuIQxcqMeLlmONeUQTkZ8jurENOrzNBBN964MYqK6TvbPwJ3A94Edyvr5wFll/rXAMqqezWXAu4YQxyFUvV53AR8v604GDivzWwBfBVYA/wrsOYJz0y+mvwOWl/NyNfA7Q4zlImA18AzVtal3AccBx5XtAj5XYl0GzB/yuekXz/Ed5+Y64LVDjuf1VA8lvhlYWqZDGj5HdWIa6XkaxJQhaxHROjOlqRsRUVsSX0S0ThJfRLROEl9EtE4SX0S0ThJfRLROEl9MO5J
eLunSpuOImSv38UVE66TGFyMj6WRJH+xYPmWSB1vOm+zhoBGDkMQXo3QOcDSApE2oxg5f0GhE0UovmsdSxYuf7V9IWitpX6pnJv7U9vQezB4zUhJfjNpZVI+735mqBhgxcunciJEqj8daBswG9rI91qPMPKp3OOwz2uiiLVLji5Gy/bSkq4FHeiW9zqKjiinaJ4kvRqp0arwa+Ispiu3INHoJUMw86dWNkZG0N9VDWP/Z9p2TlJlP9YDQ00YZW7RLrvFFYyT9PnB+1+qnbB/QRDzRHkl8EdE6aepGROsk8UVE6yTxRUTrJPFFROv8f9PWvVHkWD8FAAAAAElFTkSuQmCC\n", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": { 177 | "needs_background": "light" 178 | }, 179 | "output_type": "display_data" 180 | } 181 | ], 182 | "source": [ 183 | "d_tilde = d / (d.min(axis=1, keepdims=True) + 1e-5)\n", 184 | "\n", 185 | "plt.imshow(d_tilde)\n", 186 | "plt.colorbar()\n", 187 | "plt.title('Normalized Distance Map');\n", 188 | "plt.ylabel('x_i');\n", 189 | "plt.xlabel('y_j');" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Convert to similarity showed in Eq.(3)." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAEXCAYAAADPzN0RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaVUlEQVR4nO3de9RddX3n8feHCGQQEDAdQQhExyim6BRMCRQvtN4CdYhrqZ3QVsSFk/FCxwu2pTqDLZ2ZKjPjjI4sMVWGi5SL6MJHJ64IioOtBomICYEiAbVcolxNQBHI83zmj70fejjZ55LknPM7z5PPa629ztln/85vf58DfPnt/bts2SYiIra1W+kAIiLGVRJkREQHSZARER0kQUZEdJAEGRHRQRJkREQHSZC7MEl/JOnrO/jdV0i6rWX/J5JesxOxPCrp+Tv6/YhhSIKc5SS9XNJ3JG2W9JCkf5D02wC2L7H9uh2p1/a3bb9oUHHa3tv2nXXMF0j6zztaV52sn5A0r+3zH0iypAU7F23sKpIgZzFJ+wJfBf43cABwMPBXwOMl42ol6RlDqvrHwMkt53kJsNeQzhWzVBLk7PZCANuX2p60/Zjtr9teByDpVEl/P124bl29W9Ltkh6R9NeS/lXdAt0i6QpJe9Rlj5d0d9NJJR0t6buSfiFpk6RPTX+v5TzvkXQ7cHvLZy+QtAL4I+DP6svur0j6U0lfbDvHJyV9osvffjFwSsv+24CL2ur4/bpVuUXSXZL+suXYgjqmFZLurf+OD3Y5X8xCSZCz24+ASUkXSjpB0v59fOf1wMuAY4A/A1YCfwzMB46gpVXWxSTwfmAecCzwauDdbWXeCCwBFrV+aHslcAlwTn3Z/W+AzwNLJe0HT7U6l9OW8NqsAfaV9GJJc+ryn28r80uqJLof8PvAuyS9sa3M7wILgdcBf74z91lj5kmCnMVsbwFeDhj4W+B+SROSntPla+fY3mJ7A3Az8HXbd9reDHwNOLKP837f9hrbW23/BPgM8Kq2Yn9j+yHbj/VR3ybgOuAt9UdLgQdsf7/HV6dbka8FbgXuaav3W7bX256qW9WXNsT5V7Z/aXs98H/o738QMUskQc5ytm+1fartQ6hagM8F/leXr/y85f1jDft79zqnpBdK+qqkn0naAvxXqtZkq7v6+gP+2YVULVnq14v7+M7FwB8Cp9LQ2pS0RNK1ku6XtBl4Z484f0r1+8UuIglyF2L7H4ELqBLlMH0a+Edgoe19gQ8Bag+ny/ebjl0FvFTSEcAbqC7D
u7L9U6rOmhOBLzUU+TtgAphv+1nAeQ1xzm95fyhwb6/zxuyRBDmLSTpc0hmSDqn351NdIq4Z8qn3AbYAj0o6HHjXdn7/58DTxkTa/jVwJVVS+57tf+qzrtOA37P9yw5xPmT715KOpmpttvtPkvaS9JvA24HL+/0jYuZLgpzdHqHqCLle0i+pEuPNwBlDPu8HqZLNI1T3Prc3qXwOWFT3gl/V8vmFwEvo7/IaANt32F7b4fC7gbMlPQKcBVzRUOb/ARuBbwD/3fYODayPmUlZMDdmCkmHUl26H1h3QA3zXAuoLs93t711mOeK8ZUWZMwIknYDPgBcNuzkGDGtWIKUdICkq+tByVd3GqMnaVLSTfU2Meo4ozxJz6S6p/la4COFw4kxJel8SfdJurnDcdUTDDZKWifpqJ51lrrElnQO1Q3yj0o6E9jf9p83lHvUds+hJRGxa5P0SuBR4CLb24zUkHQi8CdUoxqWAJ+wvaRbnSUvsZdR3XSnfm2fwRAR0Tfb1wEPdSmyjCp52vYaYD9JB3Wrc1gLBfTjOfUMCYCfAZ1md8yVtBbYCnzU9lVNheo5vCsAnrmXXnb4C/ZoKhbAj9ZlzYbYeY/w8AO2f2Nn6nj97z7TDz402bPc99c9vtr20p05F9ViLa0D/++uP9vUXHzICVLSNcCBDYc+3Lpj25I6XesfZvueeq3Ab0pab/uO9kL1HN6VAIv/9Vx/b/X89iJRe/1zf6t0CDELXOMrf7qzdTz40CTfW31oz3JzDrr98LqhNG1l/d/8UA01QdruOLFf0s8lHWR7U93Mva9DHffUr3dK+hbVXOBtEmREzDwGppjqp+gDthfv5Onu4ekzow6hbX5+u5L3ICeolqCifv1yewFJ+0vas34/DzgOuGVkEUbEUBnzpCd7bgMyAZxS92YfA2xuuc3XqOQ9yI8CV0g6jWoRgD8AkLQYeKftdwAvBj4jaYoqmX/UdhJkxCzSZwuyJ0mXAscD8+q1Sj8C7A5g+zxgFVUP9kbgV1RTR7sqliBtP0i1TmD752uBd9Tvv0M1tSwiZiFjJgc01NB216XoXI1pfM/21FmyBRkRwVTXhZ3KSoKMiGIMTCZBRkQ0SwsyIqKBgSfHeEWxJMiIKMY4l9gREY0Mk+ObH5MgI6KcaibN+EqCjIiCxOQ2z0kbH0mQEVGMgalcYkdEbMvAE2P85JckyIgoasq5xI6I2EY1kyYJMiJiG0ZM5hI7IqJZLrEjIhoY8YTnlA6joyTIiCimGiieS+yIiEbppImIaGCLSacFGRHRaCotyIiIbVXjINOCjIjYhhFPenzT0PhGFhG7hMmMg4yI2FZm0kREdDGVXuyIiG2lkyYiogOj3IOMiGhiM9a92MXbtpKWSrpN0kZJZzYc31PS5fXx6yUtGH2UETEcYqqPrZSiCVLSHOBc4ARgEXCypEVtxU4DHrb9AuB/Ah8bbZQRMSwGJr1bz62U0i3Io4GNtu+0/QRwGbCsrcwy4ML6/ZXAqyWN702LiNguk+zWcyuldII8GLirZf/u+rPGMra3ApuBZ48kuogYKiOm3HsrZXzvjm4nSSuAFQCHHjxr/qyIWc2kk6abe4D5LfuH1J81lpH0DOBZwIPtFdleaXux7cW/8ezxXaE4IlqJyT62UkonyBuAhZKeJ2kPYDkw0VZmAnhb/f7NwDdtj/GjxiOiX6aaSdNrK6Vo29b2VkmnA6uBOcD5tjdIOhtYa3sC+BxwsaSNwENUSTQiZomsKN6F7VXAqrbPzmp5/2vgLaOOKyKGz9bAWoiSlgKfoGpsfdb2R9uOH0o1Ima/usyZdf7pqHiCjIhd2yDGObaMqX4t1WiYGyRN2L6lpdh/BK6w/el6vPUqYEG3epMgI6KYasHcgXSqPjWmGkDS9Jjq1gRpYN/6/bOAe3tVmgQZEcVUnTR93YOcJ2lty/5K2ytb9pvGVC9pq+Mvga9L+hPgmcBrep00CTIiiupzpswDthfv5KlOBi6w/T8k
HUvV+XuE7alOX0iCjIhipmfSDEA/Y6pPA5YC2P6upLnAPOC+TpWWHgcZEbu4KXbrufWhnzHV/wS8GkDSi4G5wP3dKk0LMiKKsQfz0K4+x1SfAfytpPdT3f48tdekkyTIiCjGiK1Tg5ka3MeY6luA47anziTIiCgqM2kiIhpsxzCfIpIgI6KgwU01HIYkyIgoquQzZ3pJgoyIYmx4ckCdNMOQBBkRxQxwoPhQJEFGRFG5xI6IaJBe7IiILtKLHRHRpPBjXXtJgoyIYgxsTQsyImJbuQcZEdFFEmRERIOMg4yI6CLjICMimjiX2BERjQxsnUovdkTENnIPMiKiCydBRkQ0G+dOmuIX/5KWSrpN0kZJZzYcP1XS/ZJuqrd3lIgzIgbPdSdNr62Uoi1ISXOAc4HXAncDN0iaqJ8+1upy26ePPMCIGDIxOcadNKUjOxrYaPtO208AlwHLCscUESNkq+dWSul7kAcDd7Xs3w0saSj3JkmvBH4EvN/2Xe0FJK0AVgDMZS9e/9zfGkK4s8Pqe28qHcLYy78/ozHuc7FLtyD78RVgge2XAlcDFzYVsr3S9mLbi3dnz5EGGBE7yNV9yF5bKaUT5D3A/Jb9Q+rPnmL7QduP17ufBV42otgiYgSmUM+tlNIJ8gZgoaTnSdoDWA5MtBaQdFDL7knArSOMLyKGyOQeZEe2t0o6HVgNzAHOt71B0tnAWtsTwH+QdBKwFXgIOLVYwBExYGJyanzvQZbupMH2KmBV22dntbz/C+AvRh1XRIxGZtJERDSoOmGSICMiGo3zMJ8kyIgoquQwnl6SICOiGCOmxniqYRJkRBQ1xg3I4uMgI2JX5sGNg+y1Mlhd5g8k3SJpg6S/61VnWpARUdYAmpD9rAwmaSHVkMHjbD8s6V/2qjctyIgoakAtyH5WBvt3wLm2H67O6/t6VZoEGRFF9blYxTxJa1u2FW3VNK0MdnBbmRcCL5T0D5LWSFraK7ZcYkdEMTa4v17sB2wv3snTPQNYCBxPtTDOdZJeYvsXnb6QFmREFDWg5c56rgxG1aqcsP2k7R9TrS+7sFulSZARUZb72HrruTIYcBVV6xFJ86guue/sVmkusSOioMEsZ9bnymCrgddJugWYBP7U9oPd6k2CjIiyBjRSvI+VwQx8oN76kgQZEeVkNZ+IiC6SICMiOhjjydhJkBFRVhJkREQDk0vsiIhOsmBuREQneaphREQzpQUZEdGg/6mERSRBRkRBSidNRERHaUFGRHSQBBkR0cCMdS921/UgJV1Rv66XtK5lWy9p3SACkHS+pPsk3dzhuCR9sn5S2TpJRw3ivBExHuTeWym9WpDvrV/fMMQYLgA+BVzU4fgJVKv+LgSWAJ+uXyNiNpipl9i2N9WvP+1WTtJ3bR+7IwHYvk7Sgi5FlgEX1Wu5rZG0n6SDpmOLiBiWQT1yYe6A6mnSz9PKkLRi+olnT/L4EMOJiEGayZfY/SreSLa9ElgJsK8OKB5PRPQp4yB3Sj9PK4uImcjAVOkgOuvrElvSoobPjm/dHVRADSaAU+re7GOAzbn/GDF7zIZL7CskXQycQ3W/8RxgMTDdMfPWHQ1A0qVUj2KcJ+lu4CPA7gC2z6N6CM+JwEbgV8Dbd/RcETGGxviGWL8JcgnwMeA7wD7AJcBx0wdtN45h7Iftk3scN/CeHa0/IsbcLEiQTwKPAf+CqgX5Y9tjfOcgImaC0pfQvfQ7zOcGqgT528ArgJMlfWFoUUXErmNKvbdC+m1BnmZ7bf1+E7BM0g7fd4yImDbOLci+EmRLcmz97OLBhxMRu5yZniAjIoZizO9BJkFGRFlJkBERHSRBRkQ0yyV2REQnSZAREQ3SSRMR0UUSZEREB0mQERHbEuN9iT2oRy5ERGw/g6Z6b/2QtFTSbfUTUM/sUu5Nkixpca86kyAjoiz3sfUgaQ5wLtVTUBdRLajTtND3PlRPa72+n9CSICOirAEkSOBoYKPtO20/AVxG9UTUdn9N
tbbtr/upNAkyIorq85EL86afWlpvK9qq6fn0U0lHAfNt/99+Y0snTUSU1V8L8QHbPe8ZdiJpN+DjwKnb870kyIgox/13wvTQ6+mn+wBHAN+SBHAgMCHppKblHKclQUZEWYMZ5nMDsFDS86gS43LgD586hb0ZmDe9L+lbwAe7JUfIPciIKGwQj321vRU4HVgN3ApcYXuDpLMlnbSjsaUFGRFlDWiguO1VVI+Jbv3srA5lj++nziTIiCin/2E8RSRBRkQxqrdxlQQZEUUNqBd7KJIgI6KsXGJHRHQwxgmy6DAfSedLuk/SzR2OHy9ps6Sb6q2xRyoiZqg+hviUXA6tdAvyAuBTwEVdynzb9htGE05EjNwYtyCLJkjb10laUDKGiCgrnTQ751hJPwTupZoatKGpUL26xwqAuew1wvBmnhNe9IrSIYy91fd+u3QIY2/OQYOpZ5xXFB/3BHkjcJjtRyWdCFwFLGwqaHslsBJgXx0wxj95RDxlzAeKj/VcbNtbbD9av18F7C5pXo+vRcRMMpgFc4dirBOkpANVr00k6WiqeB8sG1VEDMr0Q7vSi91A0qXA8VSrBd8NfATYHcD2ecCbgXdJ2go8Biy3PcYN8ojYbmP8X3TpXuyTexz/FNUwoIiYjQyaGt8MOe6dNBExy6UXOyKikyTIiIhmaUFGRHSSBBkR0aDwMJ5ekiAjohiRudgREZ2N8dDmJMiIKCqX2BERTcZ8sYokyIgoKvcgIyI6SIKMiGhi0kkTEdFJOmkiIjpJgoyI2Nb0grnjKgkyIsqxcw8yIqKT9GJHRHSQS+yIiCYG8siFiIgOxjc/jvdjXyNi9hvUY18lLZV0m6SNks5sOP4BSbdIWifpG5IO61VnEmRElDXdk91t60HSHOBc4ARgEXCypEVtxX4ALLb9UuBK4Jxe9SZBRkQ5rnqxe219OBrYaPtO208AlwHLnnYq+1rbv6p31wCH9Ko09yAjophqoHhf19DzJK1t2V9pe2XL/sHAXS37dwNLutR3GvC1XidNgoyIsvprIT5ge/EgTifpj4HFwKt6lU2CjIii+mxB9nIPML9l/5D6s6efS3oN8GHgVbYf71Vp7kFGRDnuc+vtBmChpOdJ2gNYDky0FpB0JPAZ4CTb9/VTadEEKWm+pGvrrvcNkt7bUEaSPll33a+TdFSJWCNiGIymem89a7G3AqcDq4FbgStsb5B0tqST6mL/Ddgb+IKkmyRNdKjuKaUvsbcCZ9i+UdI+wPclXW37lpYyJwAL620J8Gm633yNiJlkQItV2F4FrGr77KyW96/Z3jqLtiBtb7J9Y/3+EarMf3BbsWXARa6sAfaTdNCIQ42IYRjcMJ+hGJt7kJIWAEcC17cdauq+b0+iETFTDWCg+LCUvsQGQNLewBeB99nesoN1rABWAMxlrwFGFxFDNcZzsYsnSEm7UyXHS2x/qaFIX9339aDRlQD76oAx/skjotWAhvkMRelebAGfA261/fEOxSaAU+re7GOAzbY3jSzIiBgeA5PuvRVSugV5HPBWYL2km+rPPgQcCmD7PKpeqROBjcCvgLcXiDMihkB4rFuQRROk7b+nmo7ZrYyB94wmoogYuSTIiIgOkiAjIhqYfherKCIJMiKKyj3IiIhGhqnxbUImQUZEOSb3ICMiOhrfBmQSZESUlXuQERGdJEFGRDSwYXJ8r7GTICOirLQgIyI6SIKMiGhgoI9nzpSSBBkRBRmce5AREc1yiR0R0cCkFzsioqO0ICMimpR9amEvSZARUY7Jaj4RER2lBRkR0UESZEREAxtPTpaOoqMkyIgoKzNpIiI6yCV2REQD55k0ERGdpQUZEdEknTQREc2y3FlERBdjvNzZbiVPLmm+pGsl3SJpg6T3NpQ5XtJmSTfV21klYo2IwTPgKffc+iFpqaTbJG2UdGbD8T0lXV4fv17Sgl51lm5BbgXOsH2jpH2A70u62vYtbeW+bfsNBeKLiGHyYBbMlTQHOBd4LXA3
cIOkibZcchrwsO0XSFoOfAz4t93qLdqCtL3J9o31+0eAW4GDS8YUEaM1oBbk0cBG23fafgK4DFjWVmYZcGH9/krg1ZLUrdLSLcin1M3dI4HrGw4fK+mHwL3AB21vaPj+CmBFvfv4Nb7y5iGFuqPmAQ+UDgKALcA4xVMZq3jmHDRe8dTGLaYX7WwFj/Dw6mumrpjXR9G5kta27K+0vbJl/2Dgrpb9u4ElbXU8Vcb2VkmbgWfT5TcdiwQpaW/gi8D7bG9pO3wjcJjtRyWdCFwFLGyvo/6xVtb1rbW9eMhhb5dxiynxdDdu8cD4xdSWsHaI7aWDiGVYil5iA0janSo5XmL7S+3HbW+x/Wj9fhWwu6R+/o8TEbuOe4D5LfuH1J81lpH0DOBZwIPdKi3diy3gc8Cttj/eocyB0/cJJB1NFXPXPyoidjk3AAslPU/SHsByYKKtzATwtvr9m4Fv2t2n8ZS+xD4OeCuwXtJN9WcfAg4FsH0e1R/yLklbgceA5b3+KOpL7TEzbjElnu7GLR4Yv5jGJp76nuLpwGpgDnC+7Q2SzgbW2p6gaoxdLGkj8BBVEu1KvXNNRMSuqfg9yIiIcZUEGRHRwaxIkJIOkHS1pNvr1/07lJtsmbLYfgN3EHEMfKrTCGI6VdL9Lb/LO4YYy/mS7pPUOEZVlU/Wsa6TdNSwYukznpFOc+1z6u2of6Ndezqw7Rm/AecAZ9bvzwQ+1qHco0OMYQ5wB/B8YA/gh8CitjLvBs6r3y8HLh/y79JPTKcCnxrRP6dXAkcBN3c4fiLwNUDAMcD1heM5HvjqKH6b+nwHAUfV7/cBftTwz2vUv1E/MY30dxrlNitakDx9CtGFwBsLxDCUqU4jiGlkbF9H1XvYyTLgIlfWAPtJOqhgPCPl/qbejvo32qWnA8+WBPkc25vq9z8DntOh3FxJayWtkTToJNo01an9X6SnTXUCpqc6DUs/MQG8qb5cu1LS/Ibjo9JvvKN0rKQfSvqapN8c1Um7TL0t9hv1Mx141L/TsJUeB9k3SdcABzYc+nDrjm1L6jR26TDb90h6PvBNSett3zHoWGeYrwCX2n5c0r+nauH+XuGYxkVf01wHrcfU2yIGMR14JpoxLUjbr7F9RMP2ZeDn05cZ9et9Heq4p369E/gW1f8NB2UoU52GHZPtB20/Xu9+FnjZEOPppZ/fcGRcYJprr6m3FPiNduXpwDMmQfbQOoXobcCX2wtI2l/SnvX7eVSzeNrXndwZQ5nqNOyY2u5fnUR1j6mUCeCUuqf2GGBzy62TkdOIp7nW5+o69ZYR/0b9xDTq32mkSvcSDWKjuo/3DeB24BrggPrzxcBn6/e/A6yn6sldD5w2hDhOpOrluwP4cP3Z2cBJ9fu5wBeAjcD3gOeP4LfpFdPfABvq3+Va4PAhxnIpsAl4kure2WnAO4F31sdFtejpHfU/o8VD/m16xXN6y2+zBvidIcfzcqpFttcBN9XbiYV/o35iGunvNMotUw0jIjqYLZfYEREDlwQZEdFBEmRERAdJkBERHSRBRkR0kAQZEdFBEmSMHUnPlXRl6TgiMg4yIqKDtCBjZCSdLel9Lfv/pcMCrAs6LWIbMUpJkDFK5wOnAEjajWpu+OeLRhTRxYxZ7ixmPts/kfSgpCOp1uz8ge3ZsahBzEpJkDFqn6V6zMOBVC3KiLGVTpoYqXrZtfXA7sBC25MNZRZQPePkiNFGF/F0aUHGSNl+QtK1wC+akmNr0VHFFNFJEmSMVN05cwzwli7Fns0YPUwrdl3pxY6RkbSIarHgb9i+vUOZxVQL2X5ilLFFNMk9yChG0kuAi9s+ftz2khLxRLRLgoyI6CCX2BERHSRBRkR0kAQZEdFBEmRERAf/HyfBk5lQwXMgAAAAAElFTkSuQmCC\n", 207 | "text/plain": [ 208 | "
" 209 | ] 210 | }, 211 | "metadata": { 212 | "needs_background": "light" 213 | }, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "w = np.exp((1 - d_tilde) / 0.1)\n", 219 | "\n", 220 | "plt.imshow(w)\n", 221 | "plt.colorbar()\n", 222 | "plt.title('Similarity Map');\n", 223 | "plt.ylabel('x_i');\n", 224 | "plt.xlabel('y_j');" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "The bandwidth parameter is set as `0.1` above.\n", 232 | "The setting is very important to get the acceptable result i.e. the inappropriate value prevents Top-1 feature enhancement like below." 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAEXCAYAAADPzN0RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAdbElEQVR4nO3dfbQddX3v8fcn4SEI4UFjBSEQrLGKaCtGHsQHWlEgeolrqV1gFbDYFJWqLXpL9V60tLbIbblXLiwxIlfgWpQHF0YbL6KCaJWHgDEhUCREKAlRIEAIyFPO+dw/Zg7d7MycvXOy956dk89rrVlnz57fnvnuSc73/Ob3+81vZJuIiNjYlKYDiIgYVkmQERE1kiAjImokQUZE1EiCjIiokQQZEVEjCbIPJP2JpO9N8LNvlHRHy/rdkg7fjFgek/SSiX5+GEg6QtKVLeuHSrqz/G7v7PDZnp7PYSbpRkmvbDqOySQJcoIkvUHSTyWtk/SQpH+T9DoA21+z/baJ7Nf2j23/Xq/itL2T7ZVlzF+V9PcT3VeZXJ6WNKPt/Z9LsqRZmxdtrc8BZ7Ssnw6cU363K2s+A/T+fA65f6I4N9EjSZATIGln4DvA/waeD+wJ/C3wVJNxtZK0TZ92/Svg2JbjvAp4Xp+ORflHZxfb17e8vQ+wvF/HrImjX+ezlxYCfyhp96YDmSySICfmZQC2L7E9YvsJ29+zvRRA0gmSfjJWuKxdfbi8LFwv6e8k/W5ZA31U0qWStivLHiZpVdVBJR0o6WeSHpG0RtI5Y59rOc5HJN0J3Nny3kslzQf+BPiv5aXptyV9UtIVbcc4W9IXxvnuFwPHtawfD1zUto+3l7XKRyXdK+mzLdtmlTHNl3Rf+T0+Mc7xjgJ+1PL5u4CXAN8uv8f2kj4g6fby3K6U9Oct5cc7n8+pUbeXLWvMfy1pKfC4pG0kvVjSFZIekPQrSR+tC1zSDpL+WdI95ZXGTyTtUG67TNKvy/eva700ljRX0m3l91nden4kvUPSkvL/wE8lvXpsm+0ngZuBI8Y5n7EpbGfZxAXYGVgLXEjxC7xb2/YTgJ+0rBv4Vvm5V1LUNH9A8Yu+C3AbcHxZ9jBgVctn7wYOL1+/FjgY2AaYBdwOfLztOFdT1Gp3aHnvpeXrrwJ/31J+D+
BxYNdyfRvgfuC1Nd/7buBw4A7gFcBUYBVFjc7ArJbv8CqKP8CvBn4DvLPcNqssewmwY1nugbHvWHHMy4BPVsXRsv524HcBAW8Gfgsc0MX5bD8fVWWXADOBHcrvczNwGrBd+e+3EjiiJvZzgWsprjCmAq8Hti+3/SkwHdge+F/AkpbPrQHeWL7ereW7vKb89zmo3N/xZYzbt3z2bOCspn9HJsuSGuQE2H4UeAPFL/qXgQckLZT0onE+dqbtR20vB24Fvmd7pe11wHcp/vN3Ou7Ntq+3vcH23cCXKBJCq3+0/ZDtJ7rY3xrgOuA95VtHAg/avrnDR8dqkW+lSNKr2/Z7re1ltkdd1KovqYjzb20/bnsZ8H9ouWxvsyuwvsP3+Ffbd7nwI+B7wBs7fIdunW373vJ8vg54oe3TbT/tom33y8Ax7R+SNIUiCX7M9moXVxo/tf1UGfMFtteX658Ffl/SLuXHnwH2k7Sz7Ydt31K+Px/4ku0byv1dSPHH9uCWQ6+nOGfRA0mQE2T7dtsn2N4L2B94MUVNoM5vWl4/UbG+U6djSnqZpO+Ul2aPAv8AzGgrdm9XX+A/XQi8r3z9Pork18nFwHspasoXtW+UdJCka8rL0HXASR3ivIfi/FV5mKKmVUvSUZKuLzvLHgHmVhxvolrj3Ad4cXl5+0h5rE8BVX8YZwDTgLsq4p0q6QxJd5X/jne3fAbgXeV3uEfSjyQd0nL8U9qOP5PnnrvpwCMT+qaxkSTIHrD97xSXa/v3+VBfBP4dmG17Z4pfTrWHM87nq7ZdCbxa0v7AO4CvdQrC9j0UnTVzgW9WFPkXig6DmbZ3Ac6riHNmy+u9gftqDreUss23iqTtgSsoenBfZHtXYFHF8ao8znM7mKo6N1rP2b3Ar2zv2rJMtz234nMPAk9SXPq3ey8wj6K5YheKZgfGYrZ9k+15wO9Q/Ptc2nL8z7Ud/3m2L2nZ9yuAX9R/5dgUSZATIOnlkk6RtFe5PpPiEvH68T+52aYDjwKPSXo58KFN/PxvKNrNnuWiYf9yiqR2o+3/6HJfJwJ/ZPvxmjgfsv2kpAMpEkK7/y7peWXnxAeAb9QcZxEbX5632o6iHe8BYIOko4Buh1gtAeZKer6Knt+Pdyh/I7C+7LjZoawJ7q9yeFcr26PABcBZZcfOVEmHlAl9OsWl8VqKBP0PY5+TtJ2KcbS72H6G4t97tNz8ZeCksoYuSTuWHWLTy89Oo2invrrL7x8dJEFOzHqKhvIbJD1OkRhvBU7p83E/QZFs1lP8stQllTpfoWjbekQtA68pLrNfRXeX1wCUbX6LazZ/GDhd0nqKDo1LK8r8CFhB0Vn1T7YrB9aX7W/rJB1Us3098NHyGA9TnJ+FXX6NiylqW3dTtFuOez5tj1DUsv+Aogb9IHA+RS2wyieAZcBNwEPA5yl+5y6iaFZYTdFB1/6H9f3A3eXl90kUow8oz/efAeeU33UFRTPHmP8CXGu7rjYem0h2Jszd2knam+LSffeyA6qfx5pFkVy2tb2hy8+8Dfiw7XHvmtnaSboBONH2rU3HMlkkQW7lyt7Ws4Cdbf/pAI43i01MkBFNaewSu2z3uVrF4OmrJe1WU26kHBi7RFK3l07RBUk7UrRxvRX4TMPhRGwWSRdIul9SZQ26bLc9W9IKSUslHdBxn03VICWdSdGQf4akUykGW/91RbnHbHccAhMRWzdJbwIeAy6yvdGIEklzgb+gGH1xEPAF25Vt22Oa7KSZR9E5QPkz7UsRMWG2r6PoDKszjyJ52sW9/btK2mO8fTZ5A/6Lyjs5AH5N9WBbgGmSFgMbgDNcM3uLinuN5wPs+Dy99uUv3a6qWAC/XNq3uSViK7Kehx+0/cLN2ccRf7ij1z400rHczUufWk4xrnTMAtsLNvFwe/Lcgf+ryvfWVBfvc4KU9H2qB99+unXFtiXVXevvY3u1ijkNfyhpme2N7k4oT9YCgDm/P803XjWzvUiUjnjxHzQdQk
wC3/fl92zuPtY+NMKNV+3dsdzUPe580vaczT3epuprgrRdOzGppN9I2sP2mrKae3/NPlaXP1dKupbinuWNEmREbHkMjD47Dr7vVvPcO7j2om0egXZNtkEupJiNhPLnt9oLSNqtvPMAFZO0HkoxsDYiJgFjnvFIx6VHFgLHlb3ZBwPrWpr5KjXZBnkGcKmkEynuKvhjAElzgJNsf5DivtIvSRqlSOZn2E6CjJhEelWDlHQJxZR1M1TM6/kZYFsA2+dR3LY6l+IOpN9S3OI6rsYSpO21wFsq3l8MfLB8/VOKW+AiYhIyZqRHQw1t102ZN7bdwEc2ZZ9bwjTyETGJjY47AVWzkiAjojEGRpIgIyKqpQYZEVHBwDNDPGFOEmRENMY4l9gREZUMI8ObH5MgI6I5xZ00wysJMiIaJEa6er5aM5IgI6IxBkZziR0RsTEDTw/xswOTICOiUaPOJXZExEaKO2mSICMiNmLESC6xIyKq5RI7IqKCEU97atNh1EqCjIjGFAPFc4kdEVEpnTQRERVsMeLUICMiKo2mBhkRsbFiHGRqkBERGzHiGQ9vGhreyCJiqzCScZARERvLnTQREeMYTS92RMTG0kkTEVHDKG2QERFVbIa6F7vxuq2kIyXdIWmFpFMrtm8v6Rvl9hskzRp8lBHRH2K0i6UpjSZISVOBc4GjgP2AYyXt11bsROBh2y8F/ifw+cFGGRH9YmDEUzouTWm6BnkgsML2SttPA18H5rWVmQdcWL6+HHiLpOFttIiITTLClI5LU5pOkHsC97asryrfqyxjewOwDnjBQKKLiL4yYtSdl6YMb+voJpI0H5gPsPeek+ZrRUxqJp0041kNzGxZ36t8r7KMpG2AXYC17TuyvcD2HNtzXviC4Z2hOCJaiZEulqY0nSBvAmZL2lfSdsAxwMK2MguB48vX7wZ+aHuIHzUeEd0yxZ00nZamNFq3tb1B0snAVcBU4ALbyyWdDiy2vRD4CnCxpBXAQxRJNCImicwoPg7bi4BFbe+d1vL6SeA9g44rIvrPVs9qiJKOBL5AUdk63/YZbdv3phgRs2tZ5tQy/9RqPEFGxNatF+McW8ZUv5ViNMxNkhbavq2l2H8DLrX9xXK89SJg1nj7TYKMiMYUE+b2pFP12THVAJLGxlS3JkgDO5evdwHu67TTJMiIaEzRSdNVG+QMSYtb1hfYXtCyXjWm+qC2fXwW+J6kvwB2BA7vdNAkyIhoVJd3yjxoe85mHupY4Ku2/1nSIRSdv/vbHq37QBJkRDRm7E6aHuhmTPWJwJEAtn8maRowA7i/bqdNj4OMiK3cKFM6Ll3oZkz1fwBvAZD0CmAa8MB4O00NMiIaY/fmoV1djqk+BfiypL+kaP48odNNJ0mQEdEYIzaM9ubW4C7GVN8GHLop+0yCjIhG5U6aiIgKmzDMpxFJkBHRoN7datgPSZAR0agmnznTSRJkRDTGhmd61EnTD0mQEdGYHg4U74skyIhoVC6xIyIqpBc7ImIc6cWOiKjS8GNdO0mCjIjGGNiQGmRExMbSBhkRMY4kyIiIChkHGRExjoyDjIio4lxiR0RUMrBhNL3YEREbSRtkRMQ4nAQZEVFtmDtpGr/4l3SkpDskrZB0asX2EyQ9IGlJuXywiTgjovdcdtJ0WprSaA1S0lTgXOCtwCrgJkkLy6ePtfqG7ZMHHmBE9JkYGeJOmqYjOxBYYXul7aeBrwPzGo4pIgbIVselKU23Qe4J3Nuyvgo4qKLcuyS9Cfgl8Je2720vIGk+MB9gGs/jyL3n9CHcyeGq+xY3HcLQe/sBRzQdwvBbs/m7GPZ7sZuuQXbj28As268GrgYurCpke4HtObbnbKvtBxpgREyQi3bITktTmk6Qq4GZLet7le89y/Za20+Vq+cDrx1QbBExAKOo49KUphPkTcBsSftK2g44BljYWkDSHi2rRwO3DzC+iOgjkzbIWrY3SDoZuAqYClxge7mk04HFthcCH5V0NLABeA
g4obGAI6LHxMjo8LZBNt1Jg+1FwKK2905ref03wN8MOq6IGIzcSRMRUaHohEmCjIioNMzDfJIgI6JRTQ7j6SQJMiIaY8ToEN9qmAQZEY0a4gpk4+MgI2Jr5t6Ng+w0M1hZ5o8l3SZpuaR/6bTP1CAjolk9qEJ2MzOYpNkUQwYPtf2wpN/ptN/UICOiUT2qQXYzM9ifAefafrg4ru/vtNMkyIhoVJeTVcyQtLhlmd+2m6qZwfZsK/My4GWS/k3S9ZKO7BRbLrEjojE2uLte7Adtb+4chtsAs4HDKCbGuU7Sq2w/UveB1CAjolE9mu6s48xgFLXKhbafsf0rivllZ4+30yTIiGiWu1g66zgzGHAlRe0RSTMoLrlXjrfTXGJHRIN6M51ZlzODXQW8TdJtwAjwSdtrx9tvEmRENKtHI8W7mBnMwF+VS1eSICOiOZnNJyJiHEmQERE1hvhm7CTIiGhWEmRERAWTS+yIiDqZMDciok6eahgRUU2pQUZEVOj+VsJGJEFGRIOUTpqIiFqpQUZE1EiCjIioYIa6F3vc+SAlXVr+XCZpacuyTNLSXgQg6QJJ90u6tWa7JJ1dPqlsqaQDenHciBgOcuelKZ1qkB8rf76jjzF8FTgHuKhm+1EUs/7OBg4Cvlj+jIjJYEu9xLa9pvx5z3jlJP3M9iETCcD2dZJmjVNkHnBROZfb9ZJ2lbTHWGwREf3Sq0cuTOvRfqp087QyJM0fe+LZM36qj+FERC9tyZfY3Wq8kmx7AbAAYOcpz288nojoUsZBbpZunlYWEVsiA6NNB1Gvq0tsSftVvHdY62qvAqqwEDiu7M0+GFiX9seIyWMyXGJfKuli4EyK9sYzgTnAWMfM+ycagKRLKB7FOEPSKuAzwLYAts+jeAjPXGAF8FvgAxM9VkQMoSFuEOs2QR4EfB74KTAd+Bpw6NhG25VjGLth+9gO2w18ZKL7j4ghNwkS5DPAE8AOFDXIX9ke4paDiNgSNH0J3Um3w3xuokiQrwPeCBwr6bK+RRURW49RdV4a0m0N8kTbi8vXa4B5kibc7hgRMWaYa5BdJciW5Nj63sW9DycitjpbeoKMiOiLIW+DTIKMiGYlQUZE1EiCjIiolkvsiIg6SZARERXSSRMRMY4kyIiIGkmQEREbE8N9id2rRy5ERGw6g0Y7L92QdKSkO8onoJ46Trl3SbKkOZ32mQQZEc1yF0sHkqYC51I8BXU/igl1qib6nk7xtNYbugktCTIimtWDBAkcCKywvdL208DXKZ6I2u7vKOa2fbKbnSZBRkSjunzkwoyxp5aWy/y23XR8+qmkA4CZtv+129jSSRMRzequhvig7Y5thnUkTQHOAk7YlM8lQUZEc9x9J0wHnZ5+Oh3YH7hWEsDuwEJJR1dN5zgmCTIimtWbYT43AbMl7UuRGI8B3vvsIex1wIyxdUnXAp8YLzlC2iAjomG9eOyr7Q3AycBVwO3ApbaXSzpd0tETjS01yIhoVo8GitteRPGY6Nb3Tqspe1g3+0yCjIjmdD+MpxFJkBHRGJXLsEqCjIhG9agXuy+SICOiWbnEjoioMcQJstFhPpIukHS/pFtrth8maZ2kJeVS2SMVEVuoLob4NDkdWtM1yK8C5wAXjVPmx7bfMZhwImLghrgG2WiCtH2dpFlNxhARzUonzeY5RNIvgPsobg1aXlWonN1jPsC0KTsxZfr0AYa4ZTnw5+9pOoShd+MtlzUdwtCbukdv9jPMM4oPe4K8BdjH9mOS5gJXArOrCtpeACwA2GWbFw7xKY+IZw35QPGhvhfb9qO2HytfLwK2lTSjw8ciYkvSmwlz+2KoE6Sk3VXOTSTpQIp41zYbVUT0ythDu9KLXUHSJcBhFLMFrwI+A2wLYPs84N3AhyRtAJ4AjrE9xBXyiNhkQ/wb3XQv9rEdtp9DMQwoIiYjg0aHN0MOeydNRExy6cWOiKiTBBkRUS01yIiIOkmQEREVGh7G00kSZEQ0RuRe7IiIek
M8tDkJMiIalUvsiIgqQz5ZRRJkRDQqbZARETWSICMiqph00kRE1EknTUREnSTIiIiNjU2YO6ySICOiOXbaICMi6qQXOyKiRi6xIyKqGMgjFyIiagxvfhzux75GxOTXq8e+SjpS0h2SVkg6tWL7X0m6TdJSST+QtE+nfSZBRkSzxnqyx1s6kDQVOBc4CtgPOFbSfm3Ffg7Msf1q4HLgzE77TYKMiOa46MXutHThQGCF7ZW2nwa+Dsx7zqHsa2z/tly9Htir007TBhkRjSkGind1DT1D0uKW9QW2F7Ss7wnc27K+CjhonP2dCHy300GTICOiWd3VEB+0PacXh5P0PmAO8OZOZZMgI6JRXdYgO1kNzGxZ36t877nHkg4HPg282fZTnXaaNsiIaI67XDq7CZgtaV9J2wHHAAtbC0h6DfAl4Gjb93ez00YTpKSZkq4pu96XS/pYRRlJOrvsul8q6YAmYo2IfjAa7bx03Iu9ATgZuAq4HbjU9nJJp0s6uiz2P4CdgMskLZG0sGZ3z2r6EnsDcIrtWyRNB26WdLXt21rKHAXMLpeDgC8yfuNrRGxJejRZhe1FwKK2905reX34pu6z0Rqk7TW2bylfr6fI/Hu2FZsHXOTC9cCukvYYcKgR0Q+9G+bTF0PTBilpFvAa4Ia2TVXd9+1JNCK2VD0YKN4vTV9iAyBpJ+AK4OO2H53gPuYD8wGmTdmph9FFRF8N8b3YjSdISdtSJMev2f5mRZGuuu/LQaMLAHbZ5oVDfMojolWPhvn0RdO92AK+Atxu+6yaYguB48re7IOBdbbXDCzIiOgfAyPuvDSk6RrkocD7gWWSlpTvfQrYG8D2eRS9UnOBFcBvgQ80EGdE9IHwUNcgG02Qtn9CcTvmeGUMfGQwEUXEwCVBRkTUSIKMiKhgup2sohFJkBHRqLRBRkRUMowObxUyCTIimmPSBhkRUWt4K5BJkBHRrLRBRkTUSYKMiKhgw8jwXmMnQUZEs1KDjIiokQQZEVHBQBfPnGlKEmRENMjgtEFGRFTLJXZERAWTXuyIiFqpQUZEVGn2qYWdJEFGRHNMZvOJiKiVGmRERI0kyIiICjYeGWk6ilpJkBHRrNxJExFRI5fYEREVnGfSRETUSw0yIqJKOmkiIqplurOIiHEM8XRnU5o8uKSZkq6RdJuk5ZI+VlHmMEnrJC0pl9OaiDUies+AR91xaUrTNcgNwCm2b5E0HbhZ0tW2b2sr92Pb72ggvojoJ2fC3Fq21wBrytfrJd0O7Am0J8iImKSarCF2Ig9JF7ukWcB1wP62H215/zDgCmAVcB/wCdvLKz4/H5hfru4P3NrfiDfZDODBpoNokXjGN2zxwPDF9Hu2p2/ODiT9P4rv1cmDto/cnGNNxFAkSEk7AT8CPmf7m23bdgZGbT8maS7wBduzO+xvse05/Yt40w1bTIlnfMMWDwxfTMMWTz802kkDIGlbihri19qTI4DtR20/Vr5eBGwrqZu/OBERm6XpXmwBXwFut31WTZndy3JIOpAi5rWDizIitlZN92IfCrwfWCZpSfnep4C9AWyfB7wb+JCkDcATwDHu3C6woE/xbo5hiynxjG/Y4oHhi2nY4um5oWiDjIgYRo23QUZEDKskyIiIGpMiQUp6vqSrJd1Z/tytptxIyy2LC/sQx5GS7pC0QtKpFdu3l/SNcvsN5djPvuoiphMkPdByXj7Yx1gukHS/pMoxqiqcXca6VNIB/Yqly3gGeptrl7feDvocbd23A9ve4hfgTODU8vWpwOdryj3WxximAncBLwG2A34B7NdW5sPAeeXrY4Bv9Pm8dBPTCcA5A/p3ehNwAHBrzfa5wHcBAQcDNzQcz2HAdwZxbsrj7QEcUL6eDvyy4t9r0Oeom5gGep4GuUyKGiQwD7iwfH0h8M4GYjgQWGF7pe2nga+XcbVqjfNy4C1jQ5gajGlgbF8HPDROkXnARS5cD+wqaY8G4xko22ts31K+Xg+M3XrbatDnqJ
uYJq3JkiBf5OK+boBfAy+qKTdN0mJJ10vqdRLdE7i3ZX0VG/9HeraM7Q3AOuAFPY5jU2MCeFd5uXa5pJl9jKeTbuMdpEMk/ULSdyW9clAHLZtfXgPc0LapsXM0TkzQ0Hnqt6bHQXZN0veB3Ss2fbp1xbYl1Y1d2sf2akkvAX4oaZntu3od6xbm28Altp+S9OcUNdw/ajimYXELxf+ZsdtcrwTGvc21F8pbb68APu6WeQma1CGmRs7TIGwxNUjbh9vev2L5FvCbscuM8uf9NftYXf5cCVxL8dewV1YDrbWvvcr3KstI2gbYhf7eFdQxJttrbT9Vrp4PvLaP8XTSzTkcGDdwm2unW29p4BxtzbcDbzEJsoOFwPHl6+OBb7UXkLSbpO3L1zMo7uLp5bRqNwGzJe0raTuKTpj2nvLWON8N/NBlK3efdIyprf3qaIo2pqYsBI4re2oPBta1NJ0M3KBvcy2PNe6ttwz4HHUT06DP00A13UvUi4WiHe8HwJ3A94Hnl+/PAc4vX78eWEbRk7sMOLEPccyl6OW7C/h0+d7pwNHl62nAZcAK4EbgJQM4N51i+kdgeXlergFe3sdYLqGY//MZirazE4GTgJPK7QLOLWNdBszp87npFM/JLefmeuD1fY7nDRSTbC8FlpTL3IbPUTcxDfQ8DXLJrYYRETUmyyV2RETPJUFGRNRIgoyIqJEEGRFRIwkyIqJGEmRERI0kyBg6kl4s6fKm44jIOMiIiBqpQcbASDpd0sdb1j9XMwHrrLpJbCMGKQkyBukC4DgASVMo7g3/v41GFDGOLWa6s9jy2b5b0lpJr6GYs/PntifHpAYxKSVBxqCdT/GYh90papQRQyudNDFQ5bRry4Btgdm2RyrKzKJ4xsn+g40u4rlSg4yBsv20pGuAR6qSY2vRQcUUUScJMgaq7Jw5GHjPOMVewBA9TCu2XunFjoGRtB/FZME/sH1nTZk5FBPZfmGQsUVUSRtkNEbSq4CL295+yvZBTcQT0S4JMiKiRi6xIyJqJEFGRNRIgoyIqJEEGRFR4/8DxZQEAZbZnpsAAAAASUVORK5CYII=\n", 243 | "text/plain": [ 244 | "
" 245 | ] 246 | }, 247 | "metadata": { 248 | "needs_background": "light" 249 | }, 250 | "output_type": "display_data" 251 | } 252 | ], 253 | "source": [ 254 | "w_ = np.exp((1 - d_tilde) / 0.8)\n", 255 | "\n", 256 | "plt.imshow(w_)\n", 257 | "plt.colorbar()\n", 258 | "plt.title('Similarity Map (failure case)');\n", 259 | "plt.ylabel('x_i');\n", 260 | "plt.xlabel('y_j');" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "Normalize to get the final result." 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 9, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | "CX: 0.9878605414928291\n" 280 | ] 281 | } 282 | ], 283 | "source": [ 284 | "cx_ij = w / np.sum(w, axis=1, keepdims=True) # normalize\n", 285 | "cx = np.mean(np.max(cx_ij, axis=0))\n", 286 | "print(f'CX: {cx}')" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "## The Robustness of CX\n", 294 | "CX is a metric that is invariant to misalignment between images.\n", 295 | "To demonstrate this robustness, shuffle the pixels of $y$ and measure CX for every permutation."
296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 10, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "[trial 0]: 0.9878605414928291\n", 308 | "[trial 1]: 0.9878605414928291\n", 309 | "[trial 2]: 0.9878605414928291\n", 310 | "[trial 3]: 0.9878605414928291\n", 311 | "[trial 4]: 0.9878605414928291\n", 312 | "[trial 5]: 0.9878605414928291\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "from itertools import permutations\n", 318 | "\n", 319 | "def compute_cx(x, y): # integrate as a function\n", 320 | " X = x.reshape(-1, 3)\n", 321 | " Y = y.reshape(-1, 3)\n", 322 | "\n", 323 | " mu = Y.mean(axis=0, keepdims=True)\n", 324 | " X_centered = X -mu\n", 325 | " Y_centered = Y -mu\n", 326 | " X_normalized = X_centered / np.linalg.norm(X_centered, ord=2, axis=1, keepdims=True)\n", 327 | " Y_normalized = Y_centered / np.linalg.norm(Y_centered, ord=2, axis=1, keepdims=True)\n", 328 | "\n", 329 | " d = 1 - np.matmul(X_normalized, Y_normalized.transpose())\n", 330 | " d_tilde = d / (d.min(axis=1, keepdims=True) + 1e-5)\n", 331 | " w = np.exp((1 - d_tilde) / 0.1)\n", 332 | " cx_ij = w / np.sum(w, axis=1, keepdims=True)\n", 333 | "\n", 334 | " return np.mean(np.max(cx_ij, axis=0))\n", 335 | "\n", 336 | "for i, p in enumerate(permutations([0, 1, 2])):\n", 337 | " y_ = y[:, p, :]\n", 338 | " cx = compute_cx(x, y_)\n", 339 | " print(f'[trial {i}]: {cx}')" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "Did you confirm that all the CX values are the same? That is its robustness!\n", 347 | "\n", 348 | "That's all, happy CX!"
349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.7.4" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 4 373 | } 374 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='contextual_loss_pytorch', 5 | version='latest', 6 | description='Contextual Loss w/ PyTorch', 7 | packages=find_packages(exclude=('tests', 'doc')), 8 | author='So Uchida', 9 | author_email='s.aiueo32@gmail.com', 10 | install_requires=["torch", "torchvision"], 11 | url='https://github.com/S-aiueo32/contextual_loss_pytorch', 12 | ) 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/S-aiueo32/contextual_loss_pytorch/c886571b07be95788e99e6751560216acd54dae3/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_contextual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from contextual_loss import functional as F 4 | from contextual_loss import ContextualLoss 5 | 6 | test_shape = [1, 256, 48, 48] 7 | 8 | 9 | def test_module(): 10 | prediction = torch.rand(*test_shape) 11 | f = ContextualLoss() 12 | loss = f(prediction, prediction) 13 | assert loss.shape == torch.Size([]) 14 | 15 | 16 | def test_module_gpu(): 17 | 
prediction = torch.rand(*test_shape).to('cuda:0') 18 | f = ContextualLoss().to('cuda:0') 19 | loss = f(prediction, prediction) 20 | assert loss.shape == torch.Size([]) 21 | 22 | 23 | def test_vgg(): 24 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]) 25 | f = ContextualLoss(use_vgg=True) 26 | loss = f(prediction, prediction) 27 | assert loss.shape == torch.Size([]) 28 | 29 | 30 | def test_vgg_gpu(): 31 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]).to('cuda:0') 32 | f = ContextualLoss(use_vgg=True).to('cuda:0') 33 | loss = f(prediction, prediction) 34 | assert loss.shape == torch.Size([]) 35 | 36 | 37 | def test_cosine(): 38 | prediction = torch.rand(*test_shape) 39 | loss = F.contextual_loss(prediction, prediction, loss_type='cosine') 40 | assert loss.shape == torch.Size([]) 41 | 42 | 43 | def test_cosine_gpu(): 44 | prediction = torch.rand(*test_shape).to('cuda:0') 45 | loss = F.contextual_loss(prediction, prediction, loss_type='cosine') 46 | assert loss.shape == torch.Size([]) 47 | 48 | 49 | def test_l1(): 50 | prediction = torch.rand(*test_shape) 51 | loss = F.contextual_loss(prediction, prediction, loss_type='l1') 52 | assert loss.shape == torch.Size([]) 53 | 54 | 55 | def test_l1_gpu(): 56 | prediction = torch.rand(*test_shape).to('cuda:0') 57 | loss = F.contextual_loss(prediction, prediction, loss_type='l1') 58 | assert loss.shape == torch.Size([]) 59 | 60 | 61 | def test_l2(): 62 | prediction = torch.rand(*test_shape) 63 | loss = F.contextual_loss(prediction, prediction, loss_type='l2') 64 | assert loss.shape == torch.Size([]) 65 | 66 | 67 | def test_l2_gpu(): 68 | prediction = torch.rand(*test_shape).to('cuda:0') 69 | loss = F.contextual_loss(prediction, prediction, loss_type='l2') 70 | assert loss.shape == torch.Size([]) 71 | -------------------------------------------------------------------------------- /tests/test_contextual_bilateral.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | from contextual_loss import functional as F 4 | from contextual_loss import ContextualBilateralLoss 5 | 6 | test_shape = [1, 256, 48, 48] 7 | 8 | 9 | def test_module(): 10 | prediction = torch.rand(*test_shape) 11 | f = ContextualBilateralLoss() 12 | loss = f(prediction, prediction) 13 | assert loss.shape == torch.Size([]) 14 | 15 | 16 | def test_module_gpu(): 17 | prediction = torch.rand(*test_shape).to('cuda:0') 18 | f = ContextualBilateralLoss().to('cuda:0') 19 | loss = f(prediction, prediction) 20 | assert loss.shape == torch.Size([]) 21 | 22 | 23 | def test_vgg(): 24 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]) 25 | f = ContextualBilateralLoss(use_vgg=True) 26 | loss = f(prediction, prediction) 27 | assert loss.shape == torch.Size([]) 28 | 29 | 30 | def test_vgg_gpu(): 31 | prediction = torch.rand(test_shape[0], 3, *test_shape[2:]).to('cuda:0') 32 | f = ContextualBilateralLoss(use_vgg=True).to('cuda:0') 33 | loss = f(prediction, prediction) 34 | assert loss.shape == torch.Size([]) 35 | 36 | 37 | def test_cosine(): 38 | prediction = torch.rand(*test_shape) 39 | loss = F.contextual_bilateral_loss( 40 | prediction, prediction, loss_type='cosine') 41 | assert loss.shape == torch.Size([]) 42 | 43 | 44 | def test_cosine_gpu(): 45 | prediction = torch.rand(*test_shape).to('cuda:0') 46 | loss = F.contextual_bilateral_loss( 47 | prediction, prediction, loss_type='cosine') 48 | assert loss.shape == torch.Size([]) 49 | 50 | 51 | def test_l1(): 52 | prediction = torch.rand(*test_shape) 53 | loss = F.contextual_bilateral_loss( 54 | prediction, prediction, loss_type='l1') 55 | assert loss.shape == torch.Size([]) 56 | 57 | 58 | def test_l1_gpu(): 59 | prediction = torch.rand(*test_shape).to('cuda:0') 60 | loss = F.contextual_bilateral_loss( 61 | prediction, prediction, loss_type='l1') 62 | assert loss.shape == torch.Size([]) 63 | 64 | 65 | def test_l2(): 66 | prediction = torch.rand(*test_shape) 67 | loss = F.contextual_bilateral_loss( 68 
| prediction, prediction, loss_type='l2') 69 | assert loss.shape == torch.Size([]) 70 | 71 | 72 | def test_l2_gpu(): 73 | prediction = torch.rand(*test_shape).to('cuda:0') 74 | loss = F.contextual_bilateral_loss( 75 | prediction, prediction, loss_type='l2') 76 | assert loss.shape == torch.Size([]) 77 | --------------------------------------------------------------------------------