├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── fashion_dataset.py
│   └── pose_utils.py
├── datasets
│   └── fashion
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── cocos_model.py
│   ├── contextual_loss.py
│   ├── correspondence_net.py
│   ├── discriminator.py
│   ├── loss.py
│   ├── networks.py
│   └── translation_net.py
├── options.py
├── test.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # other stuff
132 | **/*.png
133 | **/*.pdf
134 | **/*.pth
135 | **/*.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU AFFERO GENERAL PUBLIC LICENSE
2 | Version 3, 19 November 2007
3 | 
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 
76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 
134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 
197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. 
This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. 
But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 
375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. 
You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. 
"Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. 
This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 
611 | 
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 | 
619 | END OF TERMS AND CONDITIONS
620 | 
621 | How to Apply These Terms to Your New Programs
622 | 
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these terms.
626 | 
627 | To do so, attach the following notices to the program. It is safest
628 | to attach them to the start of each source file to most effectively
629 | state the exclusion of warranty; and each file should have at least
630 | the "copyright" line and a pointer to where the full notice is found.
631 | 
632 | <one line to give the program's name and a brief idea of what it does.>
633 | Copyright (C) <year>  <name of author>
634 | 
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 | 
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 | 
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 | 
648 | Also add information on how to contact you by electronic and paper mail.
649 | 
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 | 
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![python](https://img.shields.io/badge/python-3.6+-blue.svg)
2 | ![pytorch](https://img.shields.io/badge/pytorch-1.0%2B-brightgreen)
3 | [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/2004.05571)
4 | 
5 | # CoCosNet
6 | A PyTorch implementation of the paper ["Cross-domain Correspondence Learning for Exemplar-based Image Translation"](https://panzhang0212.github.io/CoCosNet) (CVPR 2020 oral).
7 | 
8 | 
9 | ![teaser](https://panzhang0212.github.io/CoCosNet/images/teaser.png)
10 | 
11 | ### Update:
12 | 20200525: Training code for DeepFashion is complete. Due to memory limitations, I made the following modifications:
13 | - Disabled the non-local layer, since its memory cost is infeasible on common hardware: if, as the paper states, the non-local layer operates on 128×128×256 tensors, each attention matrix alone would contain 128^4 ≈ 2.7×10^8 elements (about 1 GB in float32).
14 | - Reduced the correspondence map size from 64 to 32, leading to a 4× memory saving on the dense correspondence matrices.
15 | - Reduced the base number of filters from 64 to 16.
16 | 
17 | The truncated model barely fits on a 12 GB GTX Titan X, but its performance will not match the full model.
18 | 
19 | # Environment
20 | - Ubuntu/CentOS
21 | - PyTorch 1.0+
22 | - opencv-python
23 | - tqdm
24 | 
25 | # TODO list
26 | - [x] Prepare dataset
27 | - [x] Implement the network
28 | - [x] Implement the loss functions
29 | - [x] Implement the trainer
30 | - [x] Training on DeepFashion
31 | - [ ] Adjust the network architecture to fit on a single 16 GB GPU.
32 | - [ ] Training for other tasks
33 | 
34 | # Dataset Preparation
35 | ### DeepFashion
36 | Follow the data preparation routine in [the PATN repo](https://github.com/Lotayou/Pose-Transfer).
37 | 
38 | # Pretrained Model
39 | The pretrained model for the human pose transfer task: [TO BE RELEASED](https://github.com/Lotayou)
40 | 
41 | # Training
42 | Run `python train.py`. A sketch of a typical invocation is shown below.
43 | 
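The full flag set lives in `options.py` (not included in this dump), so treat the following as an illustrative sketch: `--dataset_mode`, `--model`, `--dataroot`, `--batch_size`, `--gpu_ids`, and `--name` are inferred from how `opt` is consumed in `data/` and `model/`, and the values are hypothetical.

```bash
# Hypothetical invocation -- verify flag names against options.py
python train.py --name fashion_cocos --model cocos --dataset_mode fashion \
                --dataroot datasets/fashion --batch_size 4 --gpu_ids 0
```
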
44 | # Citations
45 | If you find this repo useful for your research, please cite the original paper:
46 | ```
47 | @article{Zhang2020CrossdomainCL,
48 |   title={Cross-domain Correspondence Learning for Exemplar-based Image Translation},
49 |   author={Pan Zhang and Bo Zhang and Dong Chen and Lu Yuan and Fang Wen},
50 |   journal={ArXiv},
51 |   year={2020},
52 |   volume={abs/2004.05571}
53 | }
54 | ```
55 | 
56 | # Acknowledgement
57 | TODO.
58 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 | 
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of the dataset.
7 | -- <__getitem__>: get a data point from the data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | 
17 | 
18 | def find_dataset_using_name(dataset_name):
19 |     """Import the module "data/[dataset_name]_dataset.py".
20 | 
21 |     In the file, the class called DatasetNameDataset() will
22 |     be instantiated. It has to be a subclass of BaseDataset,
23 |     and it is case-insensitive.
24 |     """
25 |     dataset_filename = "data." + dataset_name + "_dataset"
26 |     datasetlib = importlib.import_module(dataset_filename)
27 | 
28 |     dataset = None
29 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 |     for name, cls in datasetlib.__dict__.items():
31 |         if name.lower() == target_dataset_name.lower() \
32 |                 and issubclass(cls, BaseDataset):
33 |             dataset = cls
34 | 
35 |     if dataset is None:
36 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 | 
38 |     return dataset
39 | 
40 | 
41 | def get_option_setter(dataset_name):
42 |     """Return the static method <modify_commandline_options> of the dataset class."""
43 |     dataset_class = find_dataset_using_name(dataset_name)
44 |     return dataset_class.modify_commandline_options
45 | 
46 | 
47 | def create_dataset(opt):
48 |     """Create a dataset given the option.
49 | 
50 |     This function wraps the class CustomDatasetDataLoader.
51 |     This is the main interface between this package and 'train.py'/'test.py'
52 | 
53 |     Example:
54 |         >>> from data import create_dataset
55 |         >>> dataset = create_dataset(opt)
56 |     """
57 |     data_loader = CustomDatasetDataLoader(opt)
58 |     dataset = data_loader.load_data()
59 |     return dataset
60 | 
61 | 
62 | class CustomDatasetDataLoader():
63 |     """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 | 
65 |     def __init__(self, opt):
66 |         """Initialize this class
67 | 
68 |         Step 1: create a dataset instance given the name [dataset_mode]
69 |         Step 2: create a multi-threaded data loader.
70 |         """
71 |         self.opt = opt
72 |         dataset_class = find_dataset_using_name(opt.dataset_mode)
73 |         self.dataset = dataset_class(opt)
74 |         print("dataset [%s] was created" % type(self.dataset).__name__)
75 |         self.dataloader = torch.utils.data.DataLoader(
76 |             self.dataset,
77 |             batch_size=opt.batch_size,
78 |             shuffle=not opt.serial_batches,
79 |             num_workers=int(opt.num_workers),
80 |             drop_last=True,
81 |             pin_memory=True)
82 | 
83 |     def load_data(self):
84 |         return self
85 | 
86 |     def __len__(self):
87 |         """Return the number of data points in the dataset"""
88 |         return min(len(self.dataset), self.opt.max_dataset_size)
89 | 
90 |     def __iter__(self):
91 |         """Return a batch of data"""
92 |         for i, data in enumerate(self.dataloader):
93 |             if i * self.opt.batch_size >= self.opt.max_dataset_size:
94 |                 break
95 |             yield data
96 | 
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | 
3 | class BaseDataset(Dataset):
4 |     def __init__(self, opt):
5 |         super().__init__()
6 | 
7 |     def __getitem__(self, index):
8 |         pass
9 | 
10 |     def __len__(self): pass
11 | 
--------------------------------------------------------------------------------
/data/fashion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Fashion dataset: loads DeepFashion model images.
3 | Requires skeleton inputs rendered as stick figures.
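Pose inputs are RGB stick figures rendered with data/pose_utils.draw_pose_from_cords;
if the <phase>_pose_rgb folder is missing, __init__ regenerates it from the backup
annotation CSV via draw_stick_figures (see below).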
4 | """ 5 | 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.utils.data as data 10 | import cv2 11 | from tqdm import tqdm 12 | import os 13 | from data.base_dataset import BaseDataset 14 | from data.pose_utils import draw_pose_from_cords, load_pose_cords_from_strings 15 | 16 | class FashionDataset(BaseDataset): 17 | # Beware, the pose annotation is fitted for 256*176 images, need additional resizing 18 | def __init__(self, opt): 19 | super().__init__(opt) 20 | self.opt = opt 21 | self.h = opt.image_size 22 | self.w = opt.image_size - 2 * opt.padding 23 | self.size = (self.h, self.w) 24 | self.pd = opt.padding 25 | 26 | self.white = torch.ones((3, self.h, self.h), dtype=torch.float32) 27 | self.black = -1 * self.white 28 | 29 | self.dir_Img = os.path.join(opt.dataroot, opt.phase) # person images (exemplar) 30 | self.dir_Anno = os.path.join(opt.dataroot, opt.phase + '_pose_rgb') # rgb pose images 31 | 32 | pairLst = os.path.join(opt.dataroot, 'fasion-resize-pairs-%s.csv' % opt.phase) 33 | self.init_categories(pairLst) 34 | 35 | if not os.path.isdir(self.dir_Anno): 36 | print('Folder %s not found or annotation incomplete...' % self.dir_Anno) 37 | annotation_csv = os.path.join(opt.dataroot, 'fasion-resize-annotation-%s.csv' % opt.phase) 38 | if os.path.isfile(annotation_csv): 39 | print('Found backup annotation file, start generating required pose images...') 40 | self.draw_stick_figures(annotation_csv, self.dir_Anno) 41 | 42 | 43 | def trans(self, x, bg='black'): 44 | x = torch.from_numpy(x / 127.5 - 1).permute(2, 0, 1).float() 45 | full = torch.ones((3, self.h, self.h), dtype=torch.float32) 46 | if bg == 'black': 47 | full = -1 * full 48 | 49 | full[:,:,self.pd:self.pd+self.w] = x 50 | return full 51 | 52 | def draw_stick_figures(self, annotation, target_dir): 53 | os.makedirs(target_dir, exist_ok=True) 54 | with open(annotation, 'r') as f: 55 | lines = [l.strip() for l in f][1:] 56 | 57 | for l in tqdm(lines): 58 | name, str_y, str_x = l.split(':') 59 | target_name = os.path.join(target_dir, name) 60 | cords = load_pose_cords_from_strings(str_y, str_x) 61 | target_im, _ = draw_pose_from_cords(cords, self.size) 62 | cv2.imwrite(target_name, target_im) 63 | 64 | 65 | def init_categories(self, pairLst): 66 | ''' 67 | Using pandas is too f**king slow... 
82 |     def __getitem__(self, index):
83 |         P1_name, P2_name = self.pairs[index]
84 | 
85 |         P1 = self.trans(cv2.imread(os.path.join(self.dir_Img, P1_name)), bg='white')    # person 1
86 |         BP1 = self.trans(cv2.imread(os.path.join(self.dir_Anno, P1_name)), bg='black')  # bone of person 1
87 |         P2 = self.trans(cv2.imread(os.path.join(self.dir_Img, P2_name)), bg='white')    # person 2
88 |         BP2 = self.trans(cv2.imread(os.path.join(self.dir_Anno, P2_name)), bg='black')  # bone of person 2
89 |         # domain x: posemap
90 |         # domain y: exemplar
91 |         return {'a': BP2, 'b_gt': P2, 'a_exemplar': BP1, 'b_exemplar': P1}
92 | 
93 | 
94 |     def __len__(self):
95 |         return len(self.pairs)
96 | 
97 | 
--------------------------------------------------------------------------------
/data/pose_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage.filters import gaussian_filter  # NOTE: deprecated alias; newer SciPy exposes this as scipy.ndimage.gaussian_filter
3 | from skimage.draw import circle, line_aa, polygon  # NOTE: skimage.draw.circle was removed in newer scikit-image releases; use skimage.draw.disk there
4 | import json
5 | from pandas import Series
6 | import matplotlib
7 | matplotlib.use('Agg')
8 | import matplotlib.pyplot as plt
9 | import matplotlib.patches as mpatches
10 | from collections import defaultdict
11 | import skimage.measure, skimage.transform
12 | import sys
13 | 
14 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9],
15 |             [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16],
16 |             [0,15], [15,17], [2,16], [5,17]]
17 | 
18 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
19 |           [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
20 |           [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
21 | 
22 | 
23 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
24 |           'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']
25 | 
26 | MISSING_VALUE = -1
27 | def MISSING(x):
28 |     return x == -1 or x == 0
29 | 
30 | 
31 | def map_to_cord(pose_map, threshold=0.1):
32 |     all_peaks = [[] for i in range(18)]
33 |     pose_map = pose_map[..., :18]
34 | 
35 |     y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis=(0, 1)),
36 |                                       pose_map > threshold))
37 |     for x_i, y_i, z_i in zip(x, y, z):
38 |         all_peaks[z_i].append([x_i, y_i])
39 | 
40 |     x_values = []
41 |     y_values = []
42 | 
43 |     for i in range(18):
44 |         if len(all_peaks[i]) != 0:
45 |             x_values.append(all_peaks[i][0][0])
46 |             y_values.append(all_peaks[i][0][1])
47 |         else:
48 |             x_values.append(MISSING_VALUE)
49 |             y_values.append(MISSING_VALUE)
50 | 
51 |     return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1)
52 | 
53 | 
54 | def cords_to_map(cords, img_size, sigma=6):
55 |     result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
56 |     for i, point in enumerate(cords):
57 |         if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
58 |             continue
59 |         xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
60 |         result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2))
61 |     return result
62 | 
63 | 
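# Worked example for cords_to_map (hypothetical numbers): a joint at
# (y, x) = (10, 20) with the default sigma=6 fills its channel with
# exp(-((yy - 10)**2 + (xx - 20)**2) / 72), i.e. 1.0 at the joint,
# ~0.5 at a radius of ~7 px, and near zero elsewhere; missing joints
# leave an all-zero channel. map_to_cord above inverts this mapping by
# taking each channel's argmax, so the two functions round-trip for
# in-bounds joints.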
64 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
65 |     colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8)
66 |     mask = np.zeros(shape=img_size, dtype=bool)
67 | 
68 |     if draw_joints:
69 |         for f, t in LIMB_SEQ:
70 |             from_missing = MISSING(pose_joints[f][0]) or MISSING(pose_joints[f][1])
71 |             to_missing = MISSING(pose_joints[t][0]) or MISSING(pose_joints[t][1])
72 |             if from_missing or to_missing:
73 |                 continue
74 | 
75 |             '''
76 |             Trick: use a 4-vertex polygon of 1-pixel width to represent a line, which allows control over its shape.
77 | 
78 |             yy, xx = polygon(
79 |                 [pose_joints[f][0], pose_joints[t][0], pose_joints[t][0]+1, pose_joints[f][0]+1],
80 |                 [pose_joints[f][1], pose_joints[t][1], pose_joints[t][1]+1, pose_joints[f][1]+1],
81 |                 shape=img_size
82 |             )
83 |             '''
84 |             yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1])
85 |             valid_ids = [i for i in range(len(yy)) if 0 < yy[i] < img_size[0] and 0 < xx[i] < img_size[1]]
86 |             yy, xx, val = yy[valid_ids], xx[valid_ids], val[valid_ids]
87 |             colors[yy, xx] = np.expand_dims(val, 1) * 255
88 |             mask[yy, xx] = True
89 | 
90 |     for i, joint in enumerate(pose_joints):
91 |         if MISSING(pose_joints[i][0]) or MISSING(pose_joints[i][1]):
92 |             continue
93 |         yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size)
94 |         colors[yy, xx] = COLORS[i]
95 |         mask[yy, xx] = True
96 | 
97 |     return colors, mask
98 | 
99 | 
100 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs):
101 |     cords = map_to_cord(pose_map, threshold=threshold)
102 |     return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs)
103 | 
104 | 
105 | def load_pose_cords_from_strings(y_str, x_str):
106 |     ## 20181114: FIX bug, convert pandas.Series object to a string-formatted int list
107 |     if isinstance(y_str, Series):
108 |         y_str = y_str.values[0]
109 |     if isinstance(x_str, Series):
110 |         x_str = x_str.values[0]
111 |     y_cords = json.loads(y_str)
112 |     x_cords = json.loads(x_str)
113 |     # 20191117: modify PATN processed coords by adding 40 to non-negative indices
114 |     # NOTE: For fasion dataset only.
115 |     # print(x_cords)
116 |     # 20191123: deprecate this.
117 |     # x_cords = [item + 40 if item > 0 else item for item in x_cords]
118 |     # print(x_cords)
119 |     return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
120 | 
121 | def mean_inputation(X):  # mean imputation: fill MISSING_VALUE (-1) entries with per-location means
122 |     X = X.copy()
123 |     for i in range(X.shape[1]):
124 |         for j in range(X.shape[2]):
125 |             val = np.mean(X[:, i, j][X[:, i, j] != -1])
126 |             X[:, i, j][X[:, i, j] == -1] = val
127 |     return X
128 | 
129 | def draw_legend():
130 |     handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)]
131 |     plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
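# Usage sketch (illustrative paths, mirroring the __main__ block at the
# bottom of this file): render one annotation row as a stick figure.
#   cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])  # (18, 2) ints
#   colors, mask = draw_pose_from_cords(cords, (256, 176))  # uint8 RGB image + bool mask
#   plt.imsave('pose.png', colors)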
132 | 
133 | def produce_ma_mask(kp_array, img_size, point_radius=4):
134 |     from skimage.morphology import dilation, erosion, square
135 |     mask = np.zeros(shape=img_size, dtype=bool)
136 |     limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
137 |              [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
138 |              [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
139 |     limbs = np.array(limbs) - 1
140 |     for f, t in limbs:
141 |         from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
142 |         to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
143 |         if from_missing or to_missing:
144 |             continue
145 | 
146 |         norm_vec = kp_array[f] - kp_array[t]
147 |         norm_vec = np.array([-norm_vec[1], norm_vec[0]])
148 |         norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
149 | 
150 | 
151 |         vertices = np.array([
152 |             kp_array[f] + norm_vec,
153 |             kp_array[f] - norm_vec,
154 |             kp_array[t] - norm_vec,
155 |             kp_array[t] + norm_vec
156 |         ])
157 |         yy, xx = polygon(vertices[:, 0], vertices[:, 1], shape=img_size)
158 |         mask[yy, xx] = True
159 | 
160 |     for i, joint in enumerate(kp_array):
161 |         if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE:
162 |             continue
163 |         yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size)
164 |         mask[yy, xx] = True
165 | 
166 |     mask = dilation(mask, square(5))
167 |     mask = erosion(mask, square(5))
168 |     return mask
169 | 
170 | if __name__ == "__main__":
171 |     import pandas as pd
172 |     from skimage.io import imread
173 |     import pylab as plt
174 |     import os
175 |     i = 5
176 |     df = pd.read_csv('data/market-annotation-train.csv', sep=':')
177 | 
178 |     for index, row in df.iterrows():
179 |         pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
180 | 
181 |         colors, mask = draw_pose_from_cords(pose_cords, (128, 64))
182 | 
183 |         mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1)
184 |         print(mmm.shape)
185 |         img = imread('data/market-dataset/train/' + row['name'])
186 | 
187 |         mmm[mask] = colors[mask]
188 | 
189 |         print(mmm)
190 |         plt.subplot(1, 1, 1)
191 |         plt.imshow(mmm)
192 |         plt.show()
193 | 
--------------------------------------------------------------------------------
/datasets/fashion:
--------------------------------------------------------------------------------
1 | /backup1/lingboyang/human_image_generation/CVPR2019_pose_transfer/fashion_data
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
3 | You need to implement the following five functions:
4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
5 | -- <set_input>: unpack data from dataset and apply preprocessing.
6 | -- <forward>: produce intermediate results.
7 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
8 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
9 | In the function <__init__>, you need to define four lists:
10 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
11 | -- self.model_names (str list): define the networks used in training.
12 | -- self.visual_names (str list): specify the images that you want to display and save.
13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
14 | Now you can use the model class by specifying the flag '--model dummy'.
15 | See our template model class 'template_model.py' for an example.
16 | """
17 | 
18 | import importlib
19 | from model.base_model import BaseModel
20 | 
21 | 
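# Naming convention implemented by find_model_using_name below: passing
# '--model cocos' imports model/cocos_model.py and instantiates the class
# whose lowercased name equals 'cocosmodel' (presumably CocosModel), which
# must subclass BaseModel. The same convention backs '--dataset_mode' in
# data/__init__.py.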
17 | 
18 | import importlib
19 | from model.base_model import BaseModel
20 | 
21 | 
22 | def find_model_using_name(model_name):
23 |     """Import the module "model/[model_name]_model.py".
24 |     In the file, the class whose name matches [model_name] + "Model"
25 |     will be instantiated; the match is case-insensitive, and the class
26 |     has to be a subclass of BaseModel.
27 |     """
28 |     model_filename = "model." + model_name + "_model"
29 |     modellib = importlib.import_module(model_filename)
30 |     model = None
31 |     target_model_name = model_name.replace('_', '') + 'model'
32 |     for name, cls in modellib.__dict__.items():
33 |         if name.lower() == target_model_name.lower() \
34 |            and issubclass(cls, BaseModel):
35 |             model = cls
36 | 
37 |     if model is None:
38 |         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
39 |         exit(0)
40 | 
41 |     return model
42 | 
43 | 
44 | def get_option_setter(model_name):
45 |     """Return the static method <modify_commandline_options> of the model class."""
46 |     model_class = find_model_using_name(model_name)
47 |     return model_class.modify_commandline_options
48 | 
49 | 
50 | def create_model(opt):
51 |     """Create a model given the option.
52 |     This function wraps the model class found by <find_model_using_name>.
53 |     This is the main interface between this package and 'train.py'/'test.py'
54 |     Example:
55 |         >>> from model import create_model
56 |         >>> model = create_model(opt)
57 |     """
58 |     model = find_model_using_name(opt.model)
59 |     instance = model(opt)
60 |     print("model [%s] was created" % type(instance).__name__)
61 |     return instance
62 | 
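# A short usage sketch (hypothetical `opt` and `dataset`; this mirrors how
# train.py is expected to consume the interface above):
#
#   model = create_model(opt)        # e.g. opt.model == 'cocos' -> CoCosModel
#   model.setup(opt)                 # load or initialize networks, build schedulers
#   for data in dataset:
#       model.set_input(data)
#       model.optimize_parameters()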
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 | 
7 | 
8 | class BaseModel(ABC):
9 |     """This class is an abstract base class (ABC) for models.
10 |     To create a subclass, you need to implement the following five functions:
11 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 |         -- <set_input>: unpack data from dataset and apply preprocessing.
13 |         -- <forward>: produce intermediate results.
14 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 |     """
17 | 
18 |     def __init__(self, opt):
19 |         """Initialize the BaseModel class.
20 | 
21 |         Parameters:
22 |             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 | 
24 |         When creating your custom class, you need to implement your own initialization.
25 |         In this function, you should first call `BaseModel.__init__(self, opt)`
26 |         Then, you need to define four lists:
27 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 |             -- self.model_names (str list): define the networks used in our training.
29 |             -- self.visual_names (str list): specify the images that you want to display and save.
30 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 |         """
32 |         self.opt = opt
33 |         self.gpu_ids = opt.gpu_ids
34 |         self.isTrain = opt.isTrain
35 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
36 |         # build all checkpoint directories recursively
37 |         self.mkdir_recursive(opt.checkpoints_dir, opt.name, 'images')
38 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
39 |         self.save_image_dir = os.path.join(self.save_dir, 'images')
40 | 
41 |         self.loss_names = []
42 |         self.model_names = []
43 |         self.visual_names = []
44 |         self.optimizers = []
45 |         self.image_paths = []
46 | 
47 |     @staticmethod
48 |     def mkdir_recursive(*folders):
49 |         cur_folder = None
50 |         for folder in folders:
51 |             cur_folder = folder if cur_folder is None else os.path.join(cur_folder, folder)
52 |             os.makedirs(cur_folder, exist_ok=True)
53 | 
54 |     @staticmethod
55 |     def modify_commandline_options(parser, is_train):
56 |         """Add new model-specific options, and rewrite default values for existing options.
57 | 
58 |         Parameters:
59 |             parser -- original option parser
60 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
61 | 
62 |         Returns:
63 |             the modified parser.
64 |         """
65 |         return parser
66 | 
67 |     @abstractmethod
68 |     def set_input(self, input):
69 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
70 | 
71 |         Parameters:
72 |             input (dict): includes the data itself and its metadata information.
73 |         """
74 |         pass
75 | 
76 |     @abstractmethod
77 |     def forward(self):
78 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
79 |         pass
80 | 
81 |     def is_train(self):
82 |         """Check if the current batch is good for training."""
83 |         return True
84 | 
85 |     @abstractmethod
86 |     def optimize_parameters(self):
87 |         """Calculate losses, gradients, and update network weights; called in every training iteration"""
88 |         pass
89 | 
90 |     def setup(self, opt):
91 |         """Load and print networks; create schedulers
92 | 
93 |         Parameters:
94 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
95 |         """
96 |         if self.isTrain:
97 |             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
98 |         if not self.isTrain or opt.continue_train:
99 |             self.load_networks(opt.which_epoch)
100 |         else:
101 |             self.init_networks(opt)
102 |         self.print_networks(opt.verbose)
103 | 
104 |     def init_networks(self, opt):
105 |         print('Initializing networks with %s initialization; training from scratch' % opt.init_type)
106 |         for name in self.model_names:
107 |             net = getattr(self, 'net' + name)
108 |             if isinstance(net, torch.nn.DataParallel):
109 |                 net = net.module
110 |             net.to(self.device)
111 |             networks.init_weights(net, opt.init_type, opt.init_gain)
112 | 
113 |     def eval(self):
114 |         """Make models eval mode during test time"""
115 |         for name in self.model_names:
116 |             if isinstance(name, str):
117 |                 net = getattr(self, 'net' + name)
118 |                 net.eval()
119 | 
120 |     def test(self):
121 |         """Forward function used in test time.
122 | 
123 |         This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop
124 |         It also calls <compute_visuals> to produce additional visualization results
125 |         """
126 |         with torch.no_grad():
127 |             self.forward()
128 |             self.compute_visuals()
129 | 
130 |     def compute_visuals(self):
131 |         """Calculate additional output images for visdom and HTML visualization"""
132 |         pass
133 | 
134 |     def get_image_paths(self):
135 |         """Return image paths that are used to load current data"""
136 |         return self.image_paths
137 | 
138 |     def update_learning_rate(self):
139 |         """Update learning rates for all the networks; called at the end of every epoch"""
140 |         for scheduler in self.schedulers:
141 |             scheduler.step()
142 |         lr = self.optimizers[0].param_groups[0]['lr']
143 |         print('learning rate = %.7f' % lr)
144 | 
145 |     def get_current_visuals(self):
146 |         """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
147 |         visual_ret = OrderedDict()
148 |         for name in self.visual_names:
149 |             if isinstance(name, str):
150 |                 visual_ret[name] = getattr(self, name)
151 |         return visual_ret
152 | 
153 |     def get_current_losses(self):
154 |         """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
155 |         errors_ret = OrderedDict()
156 |         for name in self.loss_names:
157 |             if isinstance(name, str):
158 |                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
159 |         return errors_ret
160 | 
161 |     def save_networks(self, epoch):
162 |         """Save all the networks to the disk.
163 | 
164 |         Parameters:
165 |             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
166 |         """
167 |         for name in self.model_names:
168 |             if isinstance(name, str):
169 |                 save_filename = '%s_net_%s.pth' % (epoch, name)
170 |                 save_path = os.path.join(self.save_dir, save_filename)
171 |                 net = getattr(self, 'net' + name)
172 |                 torch.save(net.state_dict(), save_path)
173 | 
174 |     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
175 |         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
176 |         key = keys[i]
177 |         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
178 |             if module.__class__.__name__.startswith('InstanceNorm') and \
179 |                     (key == 'running_mean' or key == 'running_var'):
180 |                 if getattr(module, key) is None:
181 |                     state_dict.pop('.'.join(keys))
182 |             if module.__class__.__name__.startswith('InstanceNorm') and \
183 |                     (key == 'num_batches_tracked'):
184 |                 state_dict.pop('.'.join(keys))
185 |         else:
186 |             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
187 | 
188 |     def load_networks(self, epoch):
189 |         """Load all the networks from the disk.
190 | 191 | Parameters: 192 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 193 | """ 194 | for name in self.model_names: 195 | if isinstance(name, str): 196 | load_filename = '%s_net_%s.pth' % (epoch, name) 197 | load_path = os.path.join(self.save_dir, load_filename) 198 | net = getattr(self, 'net' + name) 199 | if isinstance(net, torch.nn.DataParallel): 200 | net = net.module 201 | print('loading the model from %s' % load_path) 202 | # if you are using PyTorch newer than 0.4 (e.g., built from 203 | # GitHub source), you can remove str() on self.device 204 | state_dict = torch.load(load_path, map_location=str(self.device)) 205 | if hasattr(state_dict, '_metadata'): 206 | del state_dict._metadata 207 | 208 | # patch InstanceNorm checkpoints prior to 0.4 209 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 210 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 211 | net.load_state_dict(state_dict) 212 | 213 | def print_networks(self, verbose): 214 | """Print the total number of parameters in the network and (if verbose) network architecture 215 | 216 | Parameters: 217 | verbose (bool) -- if verbose: print the network architecture 218 | """ 219 | print('---------- Networks initialized -------------') 220 | for name in self.model_names: 221 | if isinstance(name, str): 222 | net = getattr(self, 'net' + name) 223 | num_params = 0 224 | for param in net.parameters(): 225 | num_params += param.numel() 226 | if verbose: 227 | print(net) 228 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 229 | print('-----------------------------------------------') 230 | 231 | def set_requires_grad(self, nets, requires_grad=False): 232 | """Set requires_grad=False for all the networks to avoid unnecessary computations 233 | Parameters: 234 | nets (network list) -- a list of networks 235 | requires_grad (bool) -- whether the networks require gradients or not 236 | """ 237 | if not isinstance(nets, list): 238 | nets = [nets] 239 | for net in nets: 240 | if net is not None: 241 | for param in net.parameters(): 242 | param.requires_grad = requires_grad 243 | -------------------------------------------------------------------------------- /model/cocos_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import os 5 | import cv2 6 | import numpy as np 7 | import itertools 8 | from model import networks 9 | from model.base_model import BaseModel 10 | from model.translation_net import TranslationNet 11 | from model.correspondence_net import CorrespondenceNet 12 | from model.discriminator import Discriminator 13 | from model.loss import VGGLoss, GANLoss 14 | ''' 15 | Cross-Domian Correpondence Model 16 | ''' 17 | class CoCosModel(BaseModel): 18 | @staticmethod 19 | def modify_commandline_options(parser, is_train=True): 20 | return parser 21 | 22 | @staticmethod 23 | def torch2numpy(x): 24 | # from [-1,1] to [0,255] 25 | return ((x.detach().cpu().numpy().transpose(1,2,0) + 1) * 127.5).astype(np.uint8) 26 | 27 | def __name__(self): 28 | return 'CoCosModel' 29 | 30 | def __init__(self, opt): 31 | super().__init__(opt) 32 | self.w = opt.image_size 33 | # make a folder for save images 34 | self.image_dir = os.path.join(self.save_dir, 'images') 35 | if not os.path.isdir(self.image_dir): 36 | os.mkdir(self.image_dir) 37 | 38 | # initialize networks 39 | self.model_names = 
['C', 'T'] 40 | self.netC = CorrespondenceNet(opt) 41 | self.netT = TranslationNet(opt) 42 | if opt.isTrain: 43 | self.model_names.append('D') 44 | self.netD = Discriminator(opt) 45 | 46 | self.visual_names = ['b_exemplar', 'a', 'b_gen', 'b_gt'] # HPT convention 47 | 48 | if opt.isTrain: 49 | # assign losses 50 | self.loss_names = ['perc', 'domain', 'feat', 'context', 'reg', 'adv'] 51 | self.visual_names += ['b_warp'] 52 | self.criterionFeat = torch.nn.L1Loss() 53 | # Both interface for VGG and perceptual loss 54 | # call with different mode and layer params 55 | self.criterionVGG = VGGLoss(self.device) 56 | # Support hinge loss 57 | self.criterionAdv = GANLoss(gan_mode=opt.gan_mode).to(self.device) 58 | self.criterionDomain = nn.L1Loss() 59 | self.criterionReg = torch.nn.L1Loss() 60 | 61 | 62 | # initialize optimizers 63 | gen_params = itertools.chain(self.netT.parameters(), self.netC.parameters()) 64 | self.optG = torch.optim.Adam(gen_params, lr=opt.lr, betas=(opt.beta1, 0.999)) 65 | self.optD = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 66 | self.optimizers = [self.optG, self.optD] 67 | 68 | # Finally, load checkpoints and recover schedulers 69 | self.setup(opt) 70 | torch.autograd.set_detect_anomaly(True) 71 | 72 | def set_input(self, batch): 73 | # expecting 'a' -> 'b_gt', 'a_exemplar' -> 'b_exemplar', ('b_deform') 74 | # for human pose transfer, 'b_deform' is already 'b_exemplar' 75 | for k, v in batch.items(): 76 | setattr(self, k, v.to(self.device)) 77 | 78 | def forward(self): 79 | self.sa, self.sb, self.fb_warp, self.b_warp = self.netC(self.a, self.b_exemplar) # 3*HW*HW 80 | self.b_gen = self.netT(self.b_warp) 81 | # self.b_gen = self.netT(self.fb_warp) retain original feature or use warped rgb? 82 | 83 | # TODO: Implement backward warping (maybe we should adjust the input size?) 84 | _, _, _, self.b_reg = self.netC(self.a_exemplar, 85 | F.interpolate(self.b_warp, (self.w, self.w), mode='bilinear') 86 | ) 87 | #print(self.b_gen.shape, self.b_reg.shape, self.b_gt.shape) 88 | 89 | def test(self): 90 | with torch.no_grad(): 91 | _, _, _, self.b_warp = self.netC(self.a, self.b_exemplar) # 3*HW*HW 92 | self.b_gen = self.netT(self.b_warp) 93 | 94 | def backward_G(self): 95 | self.optG.zero_grad() 96 | # Damn, do we really need 6 losses? 97 | # 1. Perc loss(For human pose transfer we abandon it, it's all in the criterion Feat) 98 | self.loss_perc = 0 99 | # 2. domain loss 100 | self.loss_domain = self.opt.lambda_domain * self.criterionDomain(self.sa, self.sb) 101 | # 3. losses for pseudo exemplar pairs 102 | self.loss_feat = self.opt.lambda_feat * self.criterionVGG(self.b_gen, self.b_gt, mode='perceptual') 103 | # 4. Contextural loss 104 | self.loss_context = self.opt.lambda_context * self.criterionVGG(self.b_gen, self.b_exemplar, mode='contextual', layers=[2,3,4,5]) 105 | # 5. Reg loss 106 | b_exemplar_small = F.interpolate(self.b_exemplar, self.b_reg.size()[2:], mode='bilinear') 107 | self.loss_reg = self.opt.lambda_reg * self.criterionReg(self.b_reg, b_exemplar_small) 108 | # 6. 
GAN loss
109 |         pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
110 |         self.loss_adv = self.opt.lambda_adv * self.criterionAdv(pred_fake, True, for_discriminator=False)
111 | 
112 |         g_loss = self.loss_perc + self.loss_domain + self.loss_feat \
113 |             + self.loss_context + self.loss_reg + self.loss_adv
114 | 
115 |         g_loss.backward()
116 |         self.optG.step()
117 | 
118 |     def discriminate(self, real, fake):
119 |         fake_and_real = torch.cat([fake, real], dim=0)
120 |         discriminator_out = self.netD(fake_and_real)
121 |         pred_fake, pred_real = self.divide_pred(discriminator_out)
122 | 
123 |         return pred_fake, pred_real
124 | 
125 |     # Take the prediction of fake and real images from the combined batch
126 |     def divide_pred(self, pred):
127 |         # the prediction contains the intermediate outputs of a multiscale GAN,
128 |         # so it is usually a list
129 |         if isinstance(pred, list):
130 |             fake = [p[:p.size(0) // 2] for p in pred]
131 |             real = [p[p.size(0) // 2:] for p in pred]
132 |         else:
133 |             fake = pred[:pred.size(0) // 2]
134 |             real = pred[pred.size(0) // 2:]
135 | 
136 |         return fake, real
137 | 
138 |     def backward_D(self):
139 |         self.optD.zero_grad()
140 |         # regenerate b_gen under no_grad so the generator graph is not reused
141 |         self.test()
142 | 
143 |         pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
144 | 
145 |         self.d_fake = self.criterionAdv(pred_fake, False, for_discriminator=True)
146 |         self.d_real = self.criterionAdv(pred_real, True, for_discriminator=True)
147 | 
148 |         d_loss = (self.d_fake + self.d_real) / 2
149 |         d_loss.backward()
150 |         self.optD.step()
151 | 
152 |     def optimize_parameters(self):
153 |         # must call self.set_input(data) first
154 |         self.forward()
155 |         self.backward_G()
156 |         self.backward_D()
157 | 
158 |     ### Standalone utility functions
159 |     def log_loss(self, epoch, iter):
160 |         msg = 'Epoch %d iter %d\n ' % (epoch, iter)
161 |         for name in self.loss_names:
162 |             val = getattr(self, 'loss_%s' % name)
163 |             if isinstance(val, torch.Tensor):
164 |                 val = val.item()
165 |             msg += '%s: %.4f, ' % (name, val)
166 |         print(msg)
167 | 
168 |     def log_visual(self, epoch, iter):
169 |         save_path = os.path.join(self.save_image_dir, 'epoch%03d_iter%05d.png' % (epoch, iter))
170 |         # the warped image has a lower resolution and needs scaling
171 |         self.b_warp = F.interpolate(self.b_warp, (self.w, self.w), mode='bicubic')
172 |         pack = torch.cat(
173 |             [getattr(self, name) for name in self.visual_names], dim=3
174 |         )[0]  # only save one example
175 |         cv2.imwrite(save_path, self.torch2numpy(pack))
176 |         cv2.imwrite(save_path.replace('.png', '_b_ex.png'), self.torch2numpy(self.b_exemplar[0]))
177 | 
178 |     def update_learning_rate(self):
179 |         '''
180 |         Update learning rates for all the networks;
181 |         called at the end of every epoch by train.py
182 |         '''
183 |         for scheduler in self.schedulers:
184 |             scheduler.step()
185 |         lr = self.optimizers[0].param_groups[0]['lr']
186 |         print('learning rate updated to %.7f' % lr)
187 | 
--------------------------------------------------------------------------------
/model/contextual_loss.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/roimehrez/contextualLoss/blob/master/CX/CX_distance.py
3 | '''
4 | import torch
5 | import numpy as np
6 | 
7 | class TensorAxis:
8 |     N = 0
9 |     H = 1
10 |     W = 2
11 |     C = 3
12 | 
13 | 
14 | class CSFlow:
15 |     def __init__(self, sigma=float(0.1), b=float(1.0)):
16 |         self.b = b
17 |         self.sigma = sigma
18 | 
19 |     def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C):
20 | 
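        # Contextual similarity: relative distances are mapped to affinities via
        # w = exp((b - d) / sigma); note that the normalization over the channel
        # axis (sum_normalize) is currently commented out below.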
self.scaled_distances = scaled_distances 21 | self.cs_weights_before_normalization = torch.exp((self.b - scaled_distances) / self.sigma) 22 | # self.cs_weights_before_normalization = 1 / (1 + scaled_distances) 23 | # self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization) 24 | self.cs_NHWC = self.cs_weights_before_normalization 25 | 26 | # def reversed_direction_CS(self): 27 | # cs_flow_opposite = CSFlow(self.sigma, self.b) 28 | # cs_flow_opposite.raw_distances = self.raw_distances 29 | # work_axis = [TensorAxis.H, TensorAxis.W] 30 | # relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis) 31 | # cs_flow_opposite.__calculate_CS(relative_dist, work_axis) 32 | # return cs_flow_opposite 33 | 34 | # -- 35 | @staticmethod 36 | def create_using_L2(I_features, T_features, sigma=float(0.5), b=float(1.0)): 37 | cs_flow = CSFlow(sigma, b) 38 | sT = T_features.shape 39 | sI = I_features.shape 40 | 41 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3])) 42 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3])) 43 | r_Ts = torch.sum(Tvecs * Tvecs, 2) 44 | r_Is = torch.sum(Ivecs * Ivecs, 2) 45 | raw_distances_list = [] 46 | for i in range(sT[0]): 47 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i] 48 | A = Tvec @ torch.transpose(Ivec, 0, 1) # (matrix multiplication) 49 | cs_flow.A = A 50 | # A = tf.matmul(Tvec, tf.transpose(Ivec)) 51 | r_T = torch.reshape(r_T, [-1, 1]) # turn to column vector 52 | dist = r_T - 2 * A + r_I 53 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0])) 54 | # protecting against numerical problems, dist should be positive 55 | dist = torch.clamp(dist, min=float(0.0)) 56 | # dist = tf.sqrt(dist) 57 | raw_distances_list += [dist] 58 | 59 | cs_flow.raw_distances = torch.cat(raw_distances_list) 60 | 61 | relative_dist = cs_flow.calc_relative_distances() 62 | cs_flow.__calculate_CS(relative_dist) 63 | return cs_flow 64 | 65 | # -- 66 | @staticmethod 67 | def create_using_L1(I_features, T_features, sigma=float(0.5), b=float(1.0)): 68 | cs_flow = CSFlow(sigma, b) 69 | sT = T_features.shape 70 | sI = I_features.shape 71 | 72 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3])) 73 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3])) 74 | raw_distances_list = [] 75 | for i in range(sT[0]): 76 | Ivec, Tvec = Ivecs[i], Tvecs[i] 77 | dist = torch.abs(torch.sum(Ivec.unsqueeze(1) - Tvec.unsqueeze(0), dim=2)) 78 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0])) 79 | # protecting against numerical problems, dist should be positive 80 | dist = torch.clamp(dist, min=float(0.0)) 81 | # dist = tf.sqrt(dist) 82 | raw_distances_list += [dist] 83 | 84 | cs_flow.raw_distances = torch.cat(raw_distances_list) 85 | 86 | relative_dist = cs_flow.calc_relative_distances() 87 | cs_flow.__calculate_CS(relative_dist) 88 | return cs_flow 89 | 90 | # -- 91 | @staticmethod 92 | def create_using_dotP(I_features, T_features, sigma=float(0.5), b=float(1.0)): 93 | cs_flow = CSFlow(sigma, b) 94 | # prepare feature before calculating cosine distance 95 | T_features, I_features = cs_flow.center_by_T(T_features, I_features) 96 | T_features = CSFlow.l2_normalize_channelwise(T_features) 97 | I_features = CSFlow.l2_normalize_channelwise(I_features) 98 | 99 | # work seperatly for each example in dim 1 100 | cosine_dist_l = [] 101 | N = T_features.size()[0] 102 | for i in range(N): 103 | T_features_i = T_features[i, :, :, :].unsqueeze_(0) # 1HWC --> 1CHW 104 | 
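            # channels-last -> channels-first: each (1, 1, C) patch of T_features
            # becomes a 1x1 conv kernel below, so the I features must be NCHW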
I_features_i = I_features[i, :, :, :].unsqueeze_(0).permute((0, 3, 1, 2)) 105 | patches_PC11_i = cs_flow.patch_decomposition(T_features_i) # 1HWC --> PC11, with P=H*W 106 | cosine_dist_i = torch.nn.functional.conv2d(I_features_i, patches_PC11_i) 107 | cosine_dist_1HWC = cosine_dist_i.permute((0, 2, 3, 1)) 108 | cosine_dist_l.append(cosine_dist_i.permute((0, 2, 3, 1))) # back to 1HWC 109 | 110 | cs_flow.cosine_dist = torch.cat(cosine_dist_l, dim=0) 111 | 112 | cs_flow.raw_distances = - (cs_flow.cosine_dist - 1) / 2 ### why - 113 | 114 | relative_dist = cs_flow.calc_relative_distances() 115 | cs_flow.__calculate_CS(relative_dist) 116 | return cs_flow 117 | 118 | def calc_relative_distances(self, axis=TensorAxis.C): 119 | epsilon = 1e-5 120 | div = torch.min(self.raw_distances, dim=axis, keepdim=True)[0] 121 | relative_dist = self.raw_distances / (div + epsilon) 122 | return relative_dist 123 | 124 | @staticmethod 125 | def sum_normalize(cs, axis=TensorAxis.C): 126 | reduce_sum = torch.sum(cs, dim=axis, keepdim=True) 127 | cs_normalize = torch.div(cs, reduce_sum) 128 | return cs_normalize 129 | 130 | def center_by_T(self, T_features, I_features): 131 | # assuming both input are of the same size 132 | # calculate stas over [batch, height, width], expecting 1x1xDepth tensor 133 | axes = [0, 1, 2] 134 | self.meanT = T_features.mean(0, keepdim=True).mean(1, keepdim=True).mean(2, keepdim=True) 135 | self.varT = T_features.var(0, keepdim=True).var(1, keepdim=True).var(2, keepdim=True) 136 | self.T_features_centered = T_features - self.meanT 137 | self.I_features_centered = I_features - self.meanT 138 | 139 | return self.T_features_centered, self.I_features_centered 140 | 141 | @staticmethod 142 | def l2_normalize_channelwise(features): 143 | norms = features.norm(p=2, dim=TensorAxis.C, keepdim=True) 144 | features = features.div(norms) 145 | return features 146 | 147 | def patch_decomposition(self, T_features): 148 | # 1HWC --> 11PC --> PC11, with P=H*W 149 | (N, H, W, C) = T_features.shape 150 | P = H * W 151 | patches_PC11 = T_features.reshape(shape=(1, 1, P, C)).permute(dims=(2, 3, 0, 1)) 152 | return patches_PC11 153 | 154 | @staticmethod 155 | def pdist2(x, keepdim=False): 156 | sx = x.shape 157 | x = x.reshape(shape=(sx[0], sx[1] * sx[2], sx[3])) 158 | differences = x.unsqueeze(2) - x.unsqueeze(1) 159 | distances = torch.sum(differences**2, -1) 160 | if keepdim: 161 | distances = distances.reshape(shape=(sx[0], sx[1], sx[2], sx[3])) 162 | return distances 163 | 164 | @staticmethod 165 | def calcR_static(sT, order='C', deformation_sigma=0.05): 166 | # oreder can be C or F (matlab order) 167 | pixel_count = sT[0] * sT[1] 168 | 169 | rangeRows = range(0, sT[1]) 170 | rangeCols = range(0, sT[0]) 171 | Js, Is = np.meshgrid(rangeRows, rangeCols) 172 | row_diff_from_first_row = Is 173 | col_diff_from_first_col = Js 174 | 175 | row_diff_from_first_row_3d_repeat = np.repeat(row_diff_from_first_row[:, :, np.newaxis], pixel_count, axis=2) 176 | col_diff_from_first_col_3d_repeat = np.repeat(col_diff_from_first_col[:, :, np.newaxis], pixel_count, axis=2) 177 | 178 | rowDiffs = -row_diff_from_first_row_3d_repeat + row_diff_from_first_row.flatten(order).reshape(1, 1, -1) 179 | colDiffs = -col_diff_from_first_col_3d_repeat + col_diff_from_first_col.flatten(order).reshape(1, 1, -1) 180 | R = rowDiffs ** 2 + colDiffs ** 2 181 | R = R.astype(np.float32) 182 | R = np.exp(-(R) / (2 * deformation_sigma ** 2)) 183 | return R 184 | 185 | 186 | 187 | 188 | 189 | 190 | # 
-------------------------------------------------- 191 | # CX loss 192 | # -------------------------------------------------- 193 | 194 | 195 | 196 | def CX_loss(T_features, I_features, deformation=False, dis=False): 197 | # T_features = tf.convert_to_tensor(T_features, dtype=tf.float32) 198 | # I_features = tf.convert_to_tensor(I_features, dtype=tf.float32) 199 | # since this is a convertion of tensorflow to pytorch we permute the tensor from 200 | # T_features = normalize_tensor(T_features) 201 | # I_features = normalize_tensor(I_features) 202 | 203 | # since this originally Tensorflow implemntation 204 | # we modify all tensors to be as TF convention and not as the convention of pytorch. 205 | def from_pt2tf(Tpt): 206 | Ttf = Tpt.permute(0, 2, 3, 1) 207 | return Ttf 208 | # N x C x H x W --> N x H x W x C 209 | T_features_tf = from_pt2tf(T_features) 210 | I_features_tf = from_pt2tf(I_features) 211 | 212 | # cs_flow = CSFlow.create_using_dotP(I_features_tf, T_features_tf, sigma=1.0) 213 | cs_flow = CSFlow.create_using_L2(I_features_tf, T_features_tf, sigma=1.0) 214 | # sum_normalize: 215 | # To: 216 | cs = cs_flow.cs_NHWC 217 | 218 | if deformation: 219 | deforma_sigma = 0.001 220 | sT = T_features_tf.shape[1:2 + 1] 221 | R = CSFlow.calcR_static(sT, deformation_sigma=deforma_sigma) 222 | cs *= torch.Tensor(R).unsqueeze(dim=0).cuda() 223 | 224 | if dis: 225 | CS = [] 226 | k_max_NC = torch.max(torch.max(cs, dim=1)[1], dim=1)[1] 227 | indices = k_max_NC.cpu() 228 | N, C = indices.shape 229 | for i in range(N): 230 | CS.append((C - len(torch.unique(indices[i, :]))) / C) 231 | score = torch.FloatTensor(CS) 232 | else: 233 | # reduce_max X and Y dims 234 | # cs = CSFlow.pdist2(cs,keepdim=True) 235 | k_max_NC = torch.max(torch.max(cs, dim=1)[0], dim=1)[0] 236 | # reduce mean over C dim 237 | CS = torch.mean(k_max_NC, dim=1) 238 | # score = 1/CS 239 | # score = torch.exp(-CS*10) 240 | score = -torch.log(CS) 241 | # reduce mean over N dim 242 | # CX_loss = torch.mean(CX_loss) 243 | return score 244 | 245 | 246 | def symmetric_CX_loss(T_features, I_features): 247 | score = (CX_loss(T_features, I_features) + CX_loss(I_features, T_features)) / 2 248 | return score -------------------------------------------------------------------------------- /model/correspondence_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import model.networks 5 | import itertools 6 | 7 | ''' 8 | CorrespondenceNet: Align images in different domains 9 | into a shared domain S, and compute the correlation 10 | matrix (vectorized) 11 | 12 | Note that a is guidance, b is exemplar 13 | e.g. 
for human pose transfer
14 |     a is the target pose, b is the source image
15 | 
16 | output: b_warp: a 3*H*W image
17 | -----------------
18 | # TODO: Add synchronized batchnorm to support multi-GPU training
19 | 
20 | 20200525: Potential Bug: Insufficient memory to support a 4096*4096 correspondence, fall back to 1024*1024 instead
21 | '''
22 | class CorrespondenceNet(nn.Module):
23 |     def __init__(self, opt):
24 |         super().__init__()
25 |         print('Making a CorrespondenceNet')
26 |         # domain adaptors are not shared
27 |         ngf = opt.ngf
28 |         self.domainA_adaptor = self.create_adaptor(opt.ncA, ngf)
29 |         self.domainB_adaptor = self.create_adaptor(opt.ncB, ngf)
30 |         self.softmax_alpha = 100
31 |         ada_blocks = []
32 |         for i in range(4):
33 |             ada_blocks += [BasicBlock(ngf*4, ngf*4)]
34 | 
35 |         ada_blocks += [nn.Conv2d(ngf*4, ngf*4, kernel_size=1, stride=1, padding=0)]
36 |         self.adaptive_feature_block = nn.Sequential(*ada_blocks)
37 | 
38 |         self.to_rgb = nn.Conv2d(ngf*4, 3, kernel_size=1, stride=1, padding=0)
39 | 
40 |     @staticmethod
41 |     def warp(fa, fb, b_raw, alpha):
42 |         '''
43 |         calculate the correspondence matrix and warp the exemplar features
44 |         '''
45 |         assert fa.shape == fb.shape, \
46 |             'Feature shape must match. Got %s in a and %s in b' % (fa.shape, fb.shape)
47 |         n,c,h,w = fa.shape
48 |         # subtract mean
49 |         fa = fa - torch.mean(fa, dim=(2,3), keepdim=True)
50 |         fb = fb - torch.mean(fb, dim=(2,3), keepdim=True)
51 | 
52 |         # vectorize (merge dim H, W) and normalize channelwise vectors
53 |         fa = fa.view(n, c, -1)
54 |         fb = fb.view(n, c, -1)
55 |         fa = fa / torch.norm(fa, dim=1, keepdim=True)
56 |         fb = fb / torch.norm(fb, dim=1, keepdim=True)
57 | 
58 |         # correlation matrix, gonna be huge (4096*4096)
59 |         # use matrix multiplication for CUDA speed up
60 |         # Also, calculate the transpose of the atob correlation
61 | 
62 |         # warp the exemplar features b, taking the softmax along the b dimension
63 |         # (dim=1 of the transposed matrix); alpha is a softmax temperature that
64 |         # sharpens the correspondence
65 |         corr_ab_T = F.softmax(alpha * torch.bmm(fb.transpose(-2,-1), fa), dim=1)  # n*HW*C @ n*C*HW -> n*HW*HW
66 |         b_warp = torch.bmm(b_raw.view(n, c, h*w), corr_ab_T)  # n*C*HW
67 |         return b_warp.view(n,c,h,w)
68 | 
69 |     def create_adaptor(self, nc, ngf):
70 |         model_parts = [self.combo(nc, ngf, 3, 1, 1),
71 |                        self.combo(ngf, ngf*2, 4, 2, 1),
72 |                        self.combo(ngf*2, ngf*4, 3, 1, 1),
73 |                        self.combo(ngf*4, ngf*8, 4, 2, 1),
74 |                        self.combo(ngf*8, ngf*8, 3, 1, 1),
75 |                        # The following line shrinks the spatial dimension to 32*32
76 |                        self.combo(ngf*8, ngf*8, 4, 2, 1),
77 |                        [BasicBlock(ngf*8, ngf*4)],
78 |                        [BasicBlock(ngf*4, ngf*4)],
79 |                        [BasicBlock(ngf*4, ngf*4)]
80 |                        ]
81 |         model = itertools.chain(*model_parts)
82 |         return nn.Sequential(*model)
83 | 
84 |     def combo(self, cin, cout, kw, stride, padw):
85 |         layers = [
86 |             nn.Conv2d(cin, cout, kernel_size=kw, stride=stride, padding=padw),
87 |             nn.InstanceNorm2d(cout),
88 |             nn.LeakyReLU(0.2),
89 |         ]
90 |         return layers
91 | 
92 |     def forward(self, a, b):
93 |         sa = self.domainA_adaptor(a)
94 |         sb = self.domainB_adaptor(b)
95 |         fa = self.adaptive_feature_block(sa)
96 |         fb = self.adaptive_feature_block(sb)
97 |         # pass sb (the pre-refinement adapted features) as b_raw; fa/fb are only used to build the correspondence
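        # warp() returns the exemplar features resampled onto the target layout;
        # sa/sb also feed the domain-alignment loss in cocos_model.py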
98 | b_warp = self.warp(fa, fb, b_raw=sb, alpha=self.softmax_alpha) 99 | b_img = F.tanh(self.to_rgb(b_warp)) 100 | return sa, sb, b_warp, b_img 101 | 102 | # Basic residual block 103 | class BasicBlock(nn.Module): 104 | def __init__(self, cin, cout): 105 | super(BasicBlock, self).__init__() 106 | layers = [ 107 | nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1), 108 | nn.InstanceNorm2d(cout), 109 | nn.LeakyReLU(0.2), 110 | nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1), 111 | nn.InstanceNorm2d(cout), 112 | ] 113 | self.conv = nn.Sequential(*layers) 114 | if cin != cout: 115 | self.shortcut = nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0) 116 | else: 117 | self.shortcut = lambda x:x 118 | 119 | def forward(self, x): 120 | out = self.conv(x) + self.shortcut(x) 121 | return out 122 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model.networks 4 | 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, opt): 8 | super(Discriminator, self).__init__() 9 | print('Making a discriminator') 10 | input_nc = opt.ncB 11 | ndf = opt.ndf 12 | n_layers = opt.nd_layers 13 | self.num_D = opt.numD 14 | norm_layer = nn.BatchNorm2d 15 | 16 | if self.num_D == 1: 17 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 18 | self.model = nn.Sequential(*layers) 19 | else: 20 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 21 | self.add_module("model_0", nn.Sequential(*layers)) 22 | self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 23 | for i in range(1, self.num_D): 24 | ndf_i = int(round(ndf / (2**i))) 25 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) 26 | self.add_module("model_%d" % i, nn.Sequential(*layers)) 27 | 28 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 29 | kw = 4 30 | padw = 1 31 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, 32 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 33 | 34 | nf_mult = 1 35 | nf_mult_prev = 1 36 | for n in range(1, n_layers): 37 | nf_mult_prev = nf_mult 38 | nf_mult = min(2**n, 8) 39 | sequence += [ 40 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 41 | kernel_size=kw, stride=2, padding=padw), 42 | norm_layer(ndf * nf_mult), 43 | nn.LeakyReLU(0.2, True) 44 | ] 45 | 46 | nf_mult_prev = nf_mult 47 | nf_mult = min(2**n_layers, 8) 48 | sequence += [ 49 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 50 | kernel_size=kw, stride=1, padding=padw), 51 | norm_layer(ndf * nf_mult), 52 | nn.LeakyReLU(0.2, True) 53 | ] 54 | 55 | sequence += [nn.Conv2d(ndf * nf_mult, 1, 56 | kernel_size=kw, stride=1, padding=padw)] 57 | 58 | return sequence 59 | 60 | def forward(self, input): 61 | if self.num_D == 1: 62 | return self.model(input) 63 | result = [] 64 | down = input 65 | for i in range(self.num_D): 66 | model = getattr(self, "model_%d" % i) 67 | result.append(model(down)) 68 | if i != self.num_D - 1: 69 | down = self.down(down) 70 | return result -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision.models import vgg19 10 | from model.contextual_loss import symmetric_CX_loss 11 | 12 | 13 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 14 | # When LSGAN is used, it is basically same as MSELoss, 15 | # but it abstracts away the need to create the target label tensor 16 | # that has the same size as the input 17 | class GANLoss(nn.Module): 18 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 19 | tensor=torch.cuda.FloatTensor, opt=None): 20 | super(GANLoss, self).__init__() 21 | self.real_label = target_real_label 22 | self.fake_label = target_fake_label 23 | self.real_label_tensor = None 24 | self.fake_label_tensor = None 25 | self.zero_tensor = None 26 | self.Tensor = tensor 27 | self.gan_mode = gan_mode 28 | self.opt = opt 29 | if gan_mode == 'ls': 30 | pass 31 | elif gan_mode == 'original': 32 | pass 33 | elif gan_mode == 'w': 34 | pass 35 | elif gan_mode == 'hinge': 36 | pass 37 | else: 38 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 39 | 40 | def get_target_tensor(self, input, target_is_real): 41 | if target_is_real: 42 | if self.real_label_tensor is None: 43 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 44 | self.real_label_tensor.requires_grad_(False) 45 | return self.real_label_tensor.expand_as(input) 46 | else: 47 | if self.fake_label_tensor is None: 48 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 49 | self.fake_label_tensor.requires_grad_(False) 50 | return self.fake_label_tensor.expand_as(input) 51 | 52 | def get_zero_tensor(self, input): 53 | if self.zero_tensor is None: 54 | self.zero_tensor = self.Tensor(1).fill_(0) 55 | self.zero_tensor.requires_grad_(False) 56 | return self.zero_tensor.expand_as(input) 57 | 58 | def loss(self, input, target_is_real, for_discriminator=True): 59 | if self.gan_mode == 'original': # cross entropy loss 60 | target_tensor = self.get_target_tensor(input, target_is_real) 61 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 62 | return loss 63 | elif self.gan_mode == 'ls': 64 | target_tensor = self.get_target_tensor(input, target_is_real) 65 | return F.mse_loss(input, target_tensor) 66 | elif self.gan_mode == 'hinge': 67 | if for_discriminator: 68 | if target_is_real: 69 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 70 | loss = -torch.mean(minval) 71 | else: 72 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 73 | loss = -torch.mean(minval) 74 | else: 75 | assert target_is_real, "The generator's hinge loss must be aiming for real" 76 | loss = -torch.mean(input) 77 | return loss 78 | else: 79 | # wgan 80 | if target_is_real: 81 | return -input.mean() 82 | else: 83 | return input.mean() 84 | 85 | def __call__(self, input, target_is_real, for_discriminator=True): 86 | # computing loss is a bit complicated because |input| may not be 87 | # a tensor, but list of tensors in case of multiscale discriminator 88 | if isinstance(input, list): 89 | loss = 0 90 | for pred_i in input: 91 | if isinstance(pred_i, list): 92 | pred_i = pred_i[-1] 93 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 94 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 95 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 96 | loss += new_loss 97 | return loss / len(input) 98 | else: 99 | return self.loss(input, target_is_real, for_discriminator) 100 | 101 | 102 | # Perceptual loss and 
contextual loss that both 103 | # use a pretrained VGG network to extract features 104 | # To calculate different losses, assign mode when calling it 105 | class VGGLoss(nn.Module): 106 | def __init__(self, device, active_layers=None): 107 | super(VGGLoss, self).__init__() 108 | self.vgg = VGG19().to(device) 109 | self.criterion_perceptual = nn.L1Loss() 110 | self.criterion_contextual = symmetric_CX_loss 111 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 112 | 113 | def forward(self, x, y, mode='perceptual', layers=None): 114 | ''' 115 | Control feature usage 116 | Say you only want to compute relu4_2 117 | set active_layers = [4] 118 | Or, you want to include relu2_2 to 5_2 119 | set active_layers = [2,3,4,5] 120 | ''' 121 | criterion = getattr(self, 'criterion_%s' % mode) 122 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 123 | loss = 0 124 | if layers is None: 125 | layers = range(len(x_vgg)) 126 | else: 127 | layers = [l-1 for l in layers] # 0-index 128 | for i in layers: 129 | #print(i, x_vgg[i].shape, y_vgg[i].shape) 130 | loss += self.weights[i] * criterion(x_vgg[i], y_vgg[i].detach()) 131 | 132 | return loss 133 | 134 | 135 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 136 | class VGG19(torch.nn.Module): 137 | def __init__(self, requires_grad=False): 138 | super().__init__() 139 | vgg_pretrained_features = vgg19(pretrained=True).features 140 | self.slice1 = torch.nn.Sequential() 141 | self.slice2 = torch.nn.Sequential() 142 | self.slice3 = torch.nn.Sequential() 143 | self.slice4 = torch.nn.Sequential() 144 | self.slice5 = torch.nn.Sequential() 145 | for x in range(2): 146 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 147 | for x in range(2, 7): 148 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 149 | for x in range(7, 12): 150 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 151 | for x in range(12, 21): 152 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 153 | for x in range(21, 30): 154 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 155 | if not requires_grad: 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | 159 | def forward(self, X): 160 | h_relu1 = self.slice1(X) 161 | h_relu2 = self.slice2(h_relu1) 162 | h_relu3 = self.slice3(h_relu2) 163 | h_relu4 = self.slice4(h_relu3) 164 | h_relu5 = self.slice5(h_relu4) 165 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 166 | return out 167 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | ############################################################################### 8 | # Helper functions 9 | ############################################################################### 10 | 11 | 12 | def init_weights(net, init_type='normal', init_gain=0.02): 13 | """Initialize network weights. 14 | Parameters: 15 | net (network) -- network to be initialized 16 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 17 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 18 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 19 | work better for some applications. Feel free to try yourself. 
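    Example (a minimal sketch; `net` can be any nn.Module):
        >>> net = nn.Conv2d(3, 64, kernel_size=3)
        >>> init_weights(net, init_type='kaiming')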
20 | """ 21 | def init_func(m): # define the initialization function 22 | classname = m.__class__.__name__ 23 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 24 | if init_type == 'normal': 25 | init.normal_(m.weight.data, 0.0, init_gain) 26 | elif init_type == 'xavier': 27 | init.xavier_normal_(m.weight.data, gain=init_gain) 28 | elif init_type == 'kaiming': 29 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 30 | elif init_type == 'orthogonal': 31 | init.orthogonal_(m.weight.data, gain=init_gain) 32 | else: 33 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 37 | init.normal_(m.weight.data, 1.0, init_gain) 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | print('initialize network with %s' % init_type) 41 | net.apply(init_func) # apply the initialization function 42 | 43 | 44 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 45 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 46 | Parameters: 47 | net (network) -- the network to be initialized 48 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 49 | gain (float) -- scaling factor for normal, xavier and orthogonal. 50 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 51 | Return an initialized network. 52 | """ 53 | if len(gpu_ids) > 0: 54 | assert(torch.cuda.is_available()) 55 | net.to(gpu_ids[0]) 56 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 57 | init_weights(net, init_type, init_gain=init_gain) 58 | return net 59 | 60 | 61 | def get_scheduler(optimizer, opt): 62 | """Return a learning rate scheduler 63 | Parameters: 64 | optimizer -- the optimizer of the network 65 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  66 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 67 | For 'linear', we keep the same learning rate for the first epochs 68 | and linearly decay the rate to zero over the next epochs. 69 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 70 | See https://pytorch.org/docs/stable/optim.html for more details. 
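    Example (hypothetical flags): with opt.lr_policy == 'step' and
    opt.lr_decay_iters == 30, the learning rate is multiplied by 0.1
    every 30 epochs.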
71 | """ 72 | if opt.lr_policy == 'linear': 73 | def lambda_rule(epoch): 74 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 75 | return lr_l 76 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 77 | elif opt.lr_policy == 'step': 78 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 79 | elif opt.lr_policy == 'plateau': 80 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 81 | elif opt.lr_policy == 'cosine': 82 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 83 | else: 84 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 85 | return scheduler 86 | 87 | 88 | def get_norm_layer(norm_type='instance'): 89 | """Return a normalization layer 90 | Parameters: 91 | norm_type (str) -- the name of the normalization layer: batch | instance | none 92 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 93 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 94 | """ 95 | if norm_type == 'batch': 96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 97 | elif norm_type == 'instance': 98 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 99 | elif norm_type == 'none': 100 | norm_layer = None 101 | else: 102 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 103 | return norm_layer 104 | 105 | 106 | def get_non_linearity(layer_type='relu'): 107 | if layer_type == 'relu': 108 | nl_layer = functools.partial(nn.ReLU, inplace=True) 109 | elif layer_type == 'lrelu': 110 | nl_layer = functools.partial( 111 | nn.LeakyReLU, negative_slope=0.2, inplace=True) 112 | elif layer_type == 'elu': 113 | nl_layer = functools.partial(nn.ELU, inplace=True) 114 | else: 115 | raise NotImplementedError( 116 | 'nonlinearity activitation [%s] is not found' % layer_type) 117 | return nl_layer 118 | 119 | 120 | def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', 121 | use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'): 122 | net = None 123 | norm_layer = get_norm_layer(norm_type=norm) 124 | nl_layer = get_non_linearity(layer_type=nl) 125 | 126 | if nz == 0: 127 | where_add = 'input' 128 | 129 | if netG == 'unet_128' and where_add == 'input': 130 | net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 131 | use_dropout=use_dropout, upsample=upsample) 132 | elif netG == 'unet_256' and where_add == 'input': 133 | net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 134 | use_dropout=use_dropout, upsample=upsample) 135 | elif netG == 'unet_128' and where_add == 'all': 136 | net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 137 | use_dropout=use_dropout, upsample=upsample) 138 | elif netG == 'unet_256' and where_add == 'all': 139 | net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 140 | use_dropout=use_dropout, upsample=upsample) 141 | else: 142 | raise NotImplementedError('Generator model name [%s] is not recognized' % net) 143 | 144 | return init_net(net, init_type, init_gain, gpu_ids) 145 | 146 | 147 | def define_D(input_nc, ndf, netD, norm='batch', 
nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]): 148 | net = None 149 | norm_layer = get_norm_layer(norm_type=norm) 150 | nl = 'lrelu' # use leaky relu for D 151 | nl_layer = get_non_linearity(layer_type=nl) 152 | 153 | if netD == 'basic_128': 154 | net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer) 155 | elif netD == 'basic_256': 156 | net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer) 157 | elif netD == 'basic_128_multi': 158 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds) 159 | elif netD == 'basic_256_multi': 160 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds) 161 | else: 162 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) 163 | return init_net(net, init_type, init_gain, gpu_ids) 164 | 165 | 166 | def define_E(input_nc, output_nc, ndf, netE, 167 | norm='batch', nl='lrelu', 168 | init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False): 169 | net = None 170 | norm_layer = get_norm_layer(norm_type=norm) 171 | nl = 'lrelu' # use leaky relu for E 172 | nl_layer = get_non_linearity(layer_type=nl) 173 | if netE == 'resnet_128': 174 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer, 175 | nl_layer=nl_layer, vaeLike=vaeLike) 176 | elif netE == 'resnet_256': 177 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer, 178 | nl_layer=nl_layer, vaeLike=vaeLike) 179 | elif netE == 'conv_128': 180 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer, 181 | nl_layer=nl_layer, vaeLike=vaeLike) 182 | elif netE == 'conv_256': 183 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer, 184 | nl_layer=nl_layer, vaeLike=vaeLike) 185 | else: 186 | raise NotImplementedError('Encoder model name [%s] is not recognized' % net) 187 | 188 | return init_net(net, init_type, init_gain, gpu_ids) 189 | 190 | 191 | class D_NLayersMulti(nn.Module): 192 | def __init__(self, input_nc, ndf=64, n_layers=3, 193 | norm_layer=nn.BatchNorm2d, num_D=1): 194 | super(D_NLayersMulti, self).__init__() 195 | # st() 196 | self.num_D = num_D 197 | if num_D == 1: 198 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 199 | self.model = nn.Sequential(*layers) 200 | else: 201 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 202 | self.add_module("model_0", nn.Sequential(*layers)) 203 | self.down = nn.AvgPool2d(3, stride=2, padding=[ 204 | 1, 1], count_include_pad=False) 205 | for i in range(1, num_D): 206 | ndf_i = int(round(ndf / (2**i))) 207 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) 208 | self.add_module("model_%d" % i, nn.Sequential(*layers)) 209 | 210 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 211 | kw = 4 212 | padw = 1 213 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, 214 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 215 | 216 | nf_mult = 1 217 | nf_mult_prev = 1 218 | for n in range(1, n_layers): 219 | nf_mult_prev = nf_mult 220 | nf_mult = min(2**n, 8) 221 | sequence += [ 222 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 223 | kernel_size=kw, stride=2, padding=padw), 224 | norm_layer(ndf * nf_mult), 225 | nn.LeakyReLU(0.2, True) 226 | ] 227 | 228 | nf_mult_prev = nf_mult 229 | nf_mult = min(2**n_layers, 8) 230 | sequence += [ 231 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 232 | kernel_size=kw, 
stride=1, padding=padw), 233 | norm_layer(ndf * nf_mult), 234 | nn.LeakyReLU(0.2, True) 235 | ] 236 | 237 | sequence += [nn.Conv2d(ndf * nf_mult, 1, 238 | kernel_size=kw, stride=1, padding=padw)] 239 | 240 | return sequence 241 | 242 | def forward(self, input): 243 | if self.num_D == 1: 244 | return self.model(input) 245 | result = [] 246 | down = input 247 | for i in range(self.num_D): 248 | model = getattr(self, "model_%d" % i) 249 | result.append(model(down)) 250 | if i != self.num_D - 1: 251 | down = self.down(down) 252 | return result 253 | 254 | 255 | class D_NLayers(nn.Module): 256 | """Defines a PatchGAN discriminator""" 257 | 258 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 259 | """Construct a PatchGAN discriminator 260 | Parameters: 261 | input_nc (int) -- the number of channels in input images 262 | ndf (int) -- the number of filters in the last conv layer 263 | n_layers (int) -- the number of conv layers in the discriminator 264 | norm_layer -- normalization layer 265 | """ 266 | super(D_NLayers, self).__init__() 267 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 268 | use_bias = norm_layer.func != nn.BatchNorm2d 269 | else: 270 | use_bias = norm_layer != nn.BatchNorm2d 271 | 272 | kw = 4 273 | padw = 1 274 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 275 | nf_mult = 1 276 | nf_mult_prev = 1 277 | for n in range(1, n_layers): # gradually increase the number of filters 278 | nf_mult_prev = nf_mult 279 | nf_mult = min(2 ** n, 8) 280 | sequence += [ 281 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 282 | norm_layer(ndf * nf_mult), 283 | nn.LeakyReLU(0.2, True) 284 | ] 285 | 286 | nf_mult_prev = nf_mult 287 | nf_mult = min(2 ** n_layers, 8) 288 | sequence += [ 289 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 290 | norm_layer(ndf * nf_mult), 291 | nn.LeakyReLU(0.2, True) 292 | ] 293 | 294 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 295 | self.model = nn.Sequential(*sequence) 296 | 297 | def forward(self, input): 298 | """Standard forward.""" 299 | return self.model(input) 300 | 301 | 302 | ############################################################################## 303 | # Classes 304 | ############################################################################## 305 | class RecLoss(nn.Module): 306 | def __init__(self, use_L2=True): 307 | super(RecLoss, self).__init__() 308 | self.use_L2 = use_L2 309 | 310 | def __call__(self, input, target, batch_mean=True): 311 | if self.use_L2: 312 | diff = (input - target) ** 2 313 | else: 314 | diff = torch.abs(input - target) 315 | if batch_mean: 316 | return torch.mean(diff) 317 | else: 318 | return torch.mean(torch.mean(torch.mean(diff, dim=1), dim=2), dim=3) 319 | 320 | 321 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 322 | # When LSGAN is used, it is basically same as MSELoss, 323 | # but it abstracts away the need to create the target label tensor 324 | # that has the same size as the input 325 | class GANLoss(nn.Module): 326 | """Define different GAN objectives. 327 | 328 | The GANLoss class abstracts away the need to create the target label tensor 329 | that has the same size as the input. 
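    Example (a minimal sketch; `netD` is any discriminator returning logits,
    and __call__ expects a list of predictions):
        >>> criterion = GANLoss('lsgan')
        >>> loss_D_fake, _ = criterion([netD(fake_images)], False)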
330 | """ 331 | 332 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 333 | """ Initialize the GANLoss class. 334 | 335 | Parameters: 336 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 337 | target_real_label (bool) - - label for a real image 338 | target_fake_label (bool) - - label of a fake image 339 | 340 | Note: Do not use sigmoid as the last layer of Discriminator. 341 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 342 | """ 343 | super(GANLoss, self).__init__() 344 | self.register_buffer('real_label', torch.tensor(target_real_label)) 345 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 346 | self.gan_mode = gan_mode 347 | if gan_mode == 'lsgan': 348 | self.loss = nn.MSELoss() 349 | elif gan_mode == 'vanilla': 350 | self.loss = nn.BCEWithLogitsLoss() 351 | elif gan_mode in ['wgangp']: 352 | self.loss = None 353 | else: 354 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 355 | 356 | def get_target_tensor(self, prediction, target_is_real): 357 | """Create label tensors with the same size as the input. 358 | 359 | Parameters: 360 | prediction (tensor) - - tpyically the prediction from a discriminator 361 | target_is_real (bool) - - if the ground truth label is for real images or fake images 362 | 363 | Returns: 364 | A label tensor filled with ground truth label, and with the size of the input 365 | """ 366 | 367 | if target_is_real: 368 | target_tensor = self.real_label 369 | else: 370 | target_tensor = self.fake_label 371 | return target_tensor.expand_as(prediction) 372 | 373 | def __call__(self, predictions, target_is_real): 374 | """Calculate loss given Discriminator's output and grount truth labels. 375 | 376 | Parameters: 377 | prediction (tensor list) - - tpyically the prediction output from a discriminator; supports multi Ds. 378 | target_is_real (bool) - - if the ground truth label is for real images or fake images 379 | 380 | Returns: 381 | the calculated loss. 382 | """ 383 | all_losses = [] 384 | for prediction in predictions: 385 | if self.gan_mode in ['lsgan', 'vanilla']: 386 | target_tensor = self.get_target_tensor(prediction, target_is_real) 387 | loss = self.loss(prediction, target_tensor) 388 | elif self.gan_mode == 'wgangp': 389 | if target_is_real: 390 | loss = -prediction.mean() 391 | else: 392 | loss = prediction.mean() 393 | all_losses.append(loss) 394 | total_loss = sum(all_losses) 395 | return total_loss, all_losses 396 | 397 | 398 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 399 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 400 | Arguments: 401 | netD (network) -- discriminator network 402 | real_data (tensor array) -- real images 403 | fake_data (tensor array) -- generated images from the generator 404 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 405 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 406 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 407 | lambda_gp (float) -- weight for this loss 408 | Returns the gradient penalty loss 409 | """ 410 | if lambda_gp > 0.0: 411 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
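            # WGAN-GP as published penalizes gradients at 'mixed' samples (random
            # interpolates between real and fake); the 'real'/'fake' variants
            # evaluate the penalty directly on the data points instead.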
398 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
399 |     """Calculate the gradient penalty loss, used in the WGAN-GP paper https://arxiv.org/abs/1704.00028
400 |     Arguments:
401 |         netD (network)      -- discriminator network
402 |         real_data (tensor)  -- real images
403 |         fake_data (tensor)  -- generated images from the generator
404 |         device (str)        -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
405 |         type (str)          -- whether we mix real and fake data or not [real | fake | mixed]
406 |         constant (float)    -- the constant used in the formula (||gradient||_2 - constant)^2
407 |         lambda_gp (float)   -- weight for this loss
408 |     Returns the gradient penalty loss
409 |     """
410 |     if lambda_gp > 0.0:
411 |         if type == 'real':  # either use real images, fake images, or a linear interpolation of the two
412 |             interpolatesv = real_data
413 |         elif type == 'fake':
414 |             interpolatesv = fake_data
415 |         elif type == 'mixed':
416 |             alpha = torch.rand(real_data.shape[0], 1)
417 |             alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
418 |             alpha = alpha.to(device)
419 |             interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
420 |         else:
421 |             raise NotImplementedError('{} not implemented'.format(type))
422 |         interpolatesv.requires_grad_(True)
423 |         disc_interpolates = netD(interpolatesv)
424 |         gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
425 |                                         grad_outputs=torch.ones(disc_interpolates.size()).to(device),
426 |                                         create_graph=True, retain_graph=True, only_inputs=True)
427 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
428 |         gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # eps added to avoid division by zero in the norm's backward pass
429 |         return gradient_penalty, gradients
430 |     else:
431 |         return 0.0, None
432 | 
433 | # Defines the Unet generator.
434 | # |num_downs|: number of downsamplings in UNet. For example,
435 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
436 | # at the bottleneck
437 | 
438 | 
439 | class G_Unet_add_input(nn.Module):
440 |     def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
441 |                  norm_layer=None, nl_layer=None, use_dropout=False,
442 |                  upsample='basic'):
443 |         super(G_Unet_add_input, self).__init__()
444 |         self.nz = nz
445 |         max_nchn = 8
446 |         # construct unet structure
447 |         unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn,
448 |                                innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
449 |         for i in range(num_downs - 5):
450 |             unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block,
451 |                                    norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
452 |         unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block,
453 |                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
454 |         unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block,
455 |                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
456 |         unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block,
457 |                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
458 |         unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block,
459 |                                outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
460 | 
461 |         self.model = unet_block
462 | 
463 |     def forward(self, x, z=None):
464 |         if self.nz > 0:
465 |             z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
466 |                 z.size(0), z.size(1), x.size(2), x.size(3))
467 |             x_with_z = torch.cat([x, z_img], 1)
468 |         else:
469 |             x_with_z = x  # no z
470 | 
471 |         return self.model(x_with_z)
472 | 
473 | 
474 | def upsampleLayer(inplanes, outplanes, upsample='basic', padding_type='zero'):
475 |     # padding_type = 'zero'
476 |     if upsample == 'basic':
477 |         upconv = [nn.ConvTranspose2d(
478 |             inplanes, outplanes, kernel_size=4, stride=2, padding=1)]
479 |     elif upsample == 'bilinear':
480 |         upconv = [nn.Upsample(scale_factor=2, mode='bilinear'),
481 |                   nn.ReflectionPad2d(1),
482 |                   nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1, padding=0)]
483 |     else:
484 |         raise NotImplementedError(
485 |             'upsample layer [%s] not implemented' % upsample)
486 |     return upconv
487 | 
488 | 
489 | # Defines the submodule with skip connection.
490 | # X -------------------identity---------------------- X
491 | #   |-- downsampling -- |submodule| -- upsampling --|
492 | class UnetBlock(nn.Module):
493 |     def __init__(self, input_nc, outer_nc, inner_nc,
494 |                  submodule=None, outermost=False, innermost=False,
495 |                  norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'):
496 |         super(UnetBlock, self).__init__()
497 |         self.outermost = outermost
498 |         p = 0
499 |         downconv = []
500 |         if padding_type == 'reflect':
501 |             downconv += [nn.ReflectionPad2d(1)]
502 |         elif padding_type == 'replicate':
503 |             downconv += [nn.ReplicationPad2d(1)]
504 |         elif padding_type == 'zero':
505 |             p = 1
506 |         else:
507 |             raise NotImplementedError(
508 |                 'padding [%s] is not implemented' % padding_type)
509 |         downconv += [nn.Conv2d(input_nc, inner_nc,
510 |                                kernel_size=4, stride=2, padding=p)]
511 |         # downsample is different from upsample
512 |         downrelu = nn.LeakyReLU(0.2, True)
513 |         downnorm = norm_layer(inner_nc) if norm_layer is not None else None
514 |         uprelu = nl_layer()
515 |         upnorm = norm_layer(outer_nc) if norm_layer is not None else None
516 | 
517 |         if outermost:
518 |             upconv = upsampleLayer(
519 |                 inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
520 |             down = downconv
521 |             up = [uprelu] + upconv + [nn.Tanh()]
522 |             model = down + [submodule] + up
523 |         elif innermost:
524 |             upconv = upsampleLayer(
525 |                 inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
526 |             down = [downrelu] + downconv
527 |             up = [uprelu] + upconv
528 |             if upnorm is not None:
529 |                 up += [upnorm]
530 |             model = down + up
531 |         else:
532 |             upconv = upsampleLayer(
533 |                 inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
534 |             down = [downrelu] + downconv
535 |             if downnorm is not None:
536 |                 down += [downnorm]
537 |             up = [uprelu] + upconv
538 |             if upnorm is not None:
539 |                 up += [upnorm]
540 | 
541 |             if use_dropout:
542 |                 model = down + [submodule] + up + [nn.Dropout(0.5)]
543 |             else:
544 |                 model = down + [submodule] + up
545 | 
546 |         self.model = nn.Sequential(*model)
547 | 
548 |     def forward(self, x):
549 |         if self.outermost:
550 |             return self.model(x)
551 |         else:
552 |             return torch.cat([self.model(x), x], 1)
553 | 
554 | 
555 | def conv3x3(in_planes, out_planes):
556 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
557 |                      padding=1, bias=True)
558 | 
559 | 
560 | # two usage cases, depending on kw and padw
561 | def upsampleConv(inplanes, outplanes, kw, padw):
562 |     sequence = []
563 |     sequence += [nn.Upsample(scale_factor=2, mode='nearest')]
564 |     sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw,
565 |                            stride=1, padding=padw, bias=True)]
566 |     return nn.Sequential(*sequence)
567 | 
568 | 
569 | def meanpoolConv(inplanes, outplanes):
570 |     sequence = []
571 |     sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
572 |     sequence += [nn.Conv2d(inplanes, outplanes,
573 |                            kernel_size=1, stride=1, padding=0, bias=True)]
574 |     return nn.Sequential(*sequence)
575 | 
576 | 
577 | def convMeanpool(inplanes, outplanes):
578 |     sequence = []
579 |     sequence += [conv3x3(inplanes, outplanes)]
580 |     sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
581 |     return nn.Sequential(*sequence)
582 | 
583 | 
584 | class BasicBlockUp(nn.Module):
585 |     def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
586 |         super(BasicBlockUp, self).__init__()
587 |         layers = []
588 |         if norm_layer is not None:
589 |             layers += [norm_layer(inplanes)]
590 |         layers += [nl_layer()]
591 |         layers += [upsampleConv(inplanes, outplanes, kw=3, padw=1)]
592 |         if norm_layer is not None:
593 |             layers += [norm_layer(outplanes)]
594 |         layers += [conv3x3(outplanes, outplanes)]
595 |         self.conv = nn.Sequential(*layers)
596 |         self.shortcut = upsampleConv(inplanes, outplanes, kw=1, padw=0)
597 | 
598 |     def forward(self, x):
599 |         out = self.conv(x) + self.shortcut(x)
600 |         return out
601 | 
602 | 
603 | class BasicBlock(nn.Module):
604 |     def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
605 |         super(BasicBlock, self).__init__()
606 |         layers = []
607 |         if norm_layer is not None:
608 |             layers += [norm_layer(inplanes)]
609 |         layers += [nl_layer()]
610 |         layers += [conv3x3(inplanes, inplanes)]
611 |         if norm_layer is not None:
612 |             layers += [norm_layer(inplanes)]
613 |         layers += [nl_layer()]
614 |         layers += [convMeanpool(inplanes, outplanes)]
615 |         self.conv = nn.Sequential(*layers)
616 |         self.shortcut = meanpoolConv(inplanes, outplanes)
617 | 
618 |     def forward(self, x):
619 |         out = self.conv(x) + self.shortcut(x)
620 |         return out
621 | 
622 | 
623 | class E_ResNet(nn.Module):
624 |     def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
625 |                  norm_layer=None, nl_layer=None, vaeLike=False):
626 |         super(E_ResNet, self).__init__()
627 |         self.vaeLike = vaeLike
628 |         max_ndf = 4
629 |         conv_layers = [
630 |             nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=True)]
631 |         for n in range(1, n_blocks):
632 |             input_ndf = ndf * min(max_ndf, n)
633 |             output_ndf = ndf * min(max_ndf, n + 1)
634 |             conv_layers += [BasicBlock(input_ndf,
635 |                                        output_ndf, norm_layer, nl_layer)]
636 |         conv_layers += [nl_layer(), nn.AvgPool2d(8)]
637 |         if vaeLike:  # a second head predicts the variance for VAE-style sampling
638 |             self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
639 |             self.fcVar = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
640 |         else:
641 |             self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
642 |         self.conv = nn.Sequential(*conv_layers)
643 | 
644 |     def forward(self, x):
645 |         x_conv = self.conv(x)
646 |         conv_flat = x_conv.view(x.size(0), -1)
647 |         output = self.fc(conv_flat)
648 |         if self.vaeLike:
649 |             outputVar = self.fcVar(conv_flat)
650 |             return output, outputVar
651 |         else:
652 |             return output
653 | 
654 | 
655 | 
656 | # Defines the Unet generator.
657 | # |num_downs|: number of downsamplings in UNet.
For example, 658 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 659 | # at the bottleneck 660 | class G_Unet_add_all(nn.Module): 661 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, 662 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic'): 663 | super(G_Unet_add_all, self).__init__() 664 | self.nz = nz 665 | # construct unet structure 666 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, None, innermost=True, 667 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 668 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block, 669 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) 670 | for i in range(num_downs - 6): 671 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block, 672 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) 673 | unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, unet_block, 674 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 675 | unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, unet_block, 676 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 677 | unet_block = UnetBlock_with_z( 678 | ngf, ngf, ngf * 2, nz, unet_block, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 679 | unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, unet_block, 680 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 681 | self.model = unet_block 682 | 683 | def forward(self, x, z): 684 | return self.model(x, z) 685 | 686 | 687 | class UnetBlock_with_z(nn.Module): 688 | def __init__(self, input_nc, outer_nc, inner_nc, nz=0, 689 | submodule=None, outermost=False, innermost=False, 690 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'): 691 | super(UnetBlock_with_z, self).__init__() 692 | p = 0 693 | downconv = [] 694 | if padding_type == 'reflect': 695 | downconv += [nn.ReflectionPad2d(1)] 696 | elif padding_type == 'replicate': 697 | downconv += [nn.ReplicationPad2d(1)] 698 | elif padding_type == 'zero': 699 | p = 1 700 | else: 701 | raise NotImplementedError( 702 | 'padding [%s] is not implemented' % padding_type) 703 | 704 | self.outermost = outermost 705 | self.innermost = innermost 706 | self.nz = nz 707 | input_nc = input_nc + nz 708 | downconv += [nn.Conv2d(input_nc, inner_nc, 709 | kernel_size=4, stride=2, padding=p)] 710 | # downsample is different from upsample 711 | downrelu = nn.LeakyReLU(0.2, True) 712 | uprelu = nl_layer() 713 | 714 | if outermost: 715 | upconv = upsampleLayer( 716 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 717 | down = downconv 718 | up = [uprelu] + upconv + [nn.Tanh()] 719 | elif innermost: 720 | upconv = upsampleLayer( 721 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) 722 | down = [downrelu] + downconv 723 | up = [uprelu] + upconv 724 | if norm_layer is not None: 725 | up += [norm_layer(outer_nc)] 726 | else: 727 | upconv = upsampleLayer( 728 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 729 | down = [downrelu] + downconv 730 | if norm_layer is not None: 731 | down += [norm_layer(inner_nc)] 732 | up = [uprelu] + upconv 733 | 734 | if norm_layer is not None: 735 | up += [norm_layer(outer_nc)] 736 | 737 | if use_dropout: 738 | up += [nn.Dropout(0.5)] 739 | self.down = nn.Sequential(*down) 740 | self.submodule = submodule 741 | self.up = nn.Sequential(*up) 742 | 743 
| def forward(self, x, z): 744 | # print(x.size()) 745 | if self.nz > 0: 746 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3)) 747 | x_and_z = torch.cat([x, z_img], 1) 748 | else: 749 | x_and_z = x 750 | 751 | if self.outermost: 752 | x1 = self.down(x_and_z) 753 | x2 = self.submodule(x1, z) 754 | return self.up(x2) 755 | elif self.innermost: 756 | x1 = self.up(self.down(x_and_z)) 757 | return torch.cat([x1, x], 1) 758 | else: 759 | x1 = self.down(x_and_z) 760 | x2 = self.submodule(x1, z) 761 | return torch.cat([self.up(x2), x], 1) 762 | 763 | 764 | class E_NLayers(nn.Module): 765 | def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=3, 766 | norm_layer=None, nl_layer=None, vaeLike=False): 767 | super(E_NLayers, self).__init__() 768 | self.vaeLike = vaeLike 769 | 770 | kw, padw = 4, 1 771 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, 772 | stride=2, padding=padw), nl_layer()] 773 | 774 | nf_mult = 1 775 | nf_mult_prev = 1 776 | for n in range(1, n_layers): 777 | nf_mult_prev = nf_mult 778 | nf_mult = min(2**n, 4) 779 | sequence += [ 780 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 781 | kernel_size=kw, stride=2, padding=padw)] 782 | if norm_layer is not None: 783 | sequence += [norm_layer(ndf * nf_mult)] 784 | sequence += [nl_layer()] 785 | sequence += [nn.AvgPool2d(8)] 786 | self.conv = nn.Sequential(*sequence) 787 | self.fc = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)]) 788 | if vaeLike: 789 | self.fcVar = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)]) 790 | 791 | def forward(self, x): 792 | x_conv = self.conv(x) 793 | conv_flat = x_conv.view(x.size(0), -1) 794 | output = self.fc(conv_flat) 795 | if self.vaeLike: 796 | outputVar = self.fcVar(conv_flat) 797 | return output, outputVar 798 | return output 799 | -------------------------------------------------------------------------------- /model/translation_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import model.networks 5 | from torch.nn.utils import spectral_norm 6 | 7 | # Also, we figure it would be better to inject the warped 8 | # guidance at the beginning rather than a constant tensor 9 | 10 | class TranslationNet(nn.Module): 11 | def __init__(self, opt): 12 | super().__init__() 13 | print('Making a TranslationNet') 14 | self.fc = nn.Conv2d(3, 16 * opt.ngf, 3, padding=1) 15 | self.sw = opt.image_size // (2**5) # fixed, 5 upsample layers 16 | self.head = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim) 17 | self.G_middle_0 = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim) 18 | self.G_middle_1 = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim) 19 | self.up_0 = SPADEResBlk(16 * opt.ngf, 8 * opt.ngf, opt.seg_dim) 20 | self.up_1 = SPADEResBlk(8 * opt.ngf, 4 * opt.ngf, opt.seg_dim) 21 | self.non_local = NonLocalLayer(opt.ngf*4) 22 | self.up_2 = SPADEResBlk(4 * opt.ngf, 2 * opt.ngf, opt.seg_dim) 23 | self.up_3 = SPADEResBlk(2 * opt.ngf, 1 * opt.ngf, opt.seg_dim) 24 | 25 | self.conv_img = nn.Conv2d(opt.ngf, 3, kernel_size=3, stride=1, padding=1) 26 | 27 | @staticmethod 28 | def up(x): 29 | return F.interpolate(x, scale_factor=2, mode='bilinear') 30 | 31 | def forward(self, x, seg=None): 32 | if seg is None: 33 | seg = x 34 | # separate execute 35 | x = F.interpolate(x, (self.sw, self.sw), mode='bilinear') # how can I forget this one? 
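#         (Editor's note, not in the original file: with image_size=256 from
#         options.py, self.sw = 256 // 2**5 = 8, so the warped guidance is first
#         collapsed to an 8x8 base map here and then upsampled five times below:
#         8 -> 16 -> 32 -> 64 -> 128 -> 256.)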
36 |         x = self.fc(x)
37 |         x = self.head(x, seg)
38 | 
39 |         x = self.up(x)  # 16
40 |         x = self.G_middle_0(x, seg)
41 |         x = self.G_middle_1(x, seg)
42 | 
43 |         x = self.up(x)  # 32
44 |         x = self.up_0(x, seg)
45 |         x = self.up(x)  # 64
46 |         x = self.up_1(x, seg)
47 |         x = self.up(x)  # 128
48 | 
49 |         # 20200525: Critical Bug:
50 |         # Using the non-local layer at such a huge spatial resolution (128*128)
51 |         # occupied way too much memory (the intermediate attention matrix takes O(h ** 4) memory),
52 |         # so it is disabled here. I sincerely hope it's an honest mistake:)
53 |         # x = self.non_local(x)
54 | 
55 |         x = self.up_2(x, seg)
56 |         x = self.up(x)  # 256
57 |         x = self.up_3(x, seg)
58 | 
59 |         x = self.conv_img(F.leaky_relu(x, 2e-1))
60 |         x = torch.tanh(x)  # F.tanh is deprecated
61 |         return x
62 | 
63 | 
64 | # NOTE: This SPADE implementation differs slightly from
65 | # the original https://github.com/NVlabs/SPADE
66 | # in that BN is replaced with PN (positional normalization).
67 | class SPADE(nn.Module):
68 |     def __init__(self, cin, seg_dim):
69 |         super().__init__()
70 |         self.conv = nn.Sequential(
71 |             nn.Conv2d(seg_dim, 128, kernel_size=3, stride=1, padding=1),
72 |             nn.ReLU(),
73 |         )
74 |         self.alpha = nn.Conv2d(128, cin,
75 |                                kernel_size=3, stride=1, padding=1)
76 |         self.beta = nn.Conv2d(128, cin,
77 |                               kernel_size=3, stride=1, padding=1)
78 | 
79 |     @staticmethod
80 |     def PN(x):
81 |         '''
82 |         positional normalization: normalize each positional vector along the channel dimension
83 |         '''
84 |         assert len(x.shape) == 4, 'Only works for 4D (image) tensors'
85 |         x = x - x.mean(dim=1, keepdim=True)
86 |         x_norm = x.norm(dim=1, keepdim=True) + 1e-6
87 |         x = x / x_norm
88 |         return x
89 | 
90 |     def DPN(self, x, s):
91 |         h, w = x.shape[2], x.shape[3]
92 |         s = F.interpolate(s, (h, w), mode='bilinear')
93 |         s = self.conv(s)
94 |         a = self.alpha(s)
95 |         b = self.beta(s)
96 |         return x * (1 + a) + b  # denormalize with a scale and shift predicted from the segmentation map
97 | 
98 |     def forward(self, x, s):
99 |         x_out = self.DPN(self.PN(x), s)
100 |         return x_out
101 | 
102 | class SPADEResBlk(nn.Module):
103 |     def __init__(self, fin, fout, seg_fin):
104 |         super().__init__()
105 |         # Attributes
106 |         self.learned_shortcut = (fin != fout)
107 |         fmiddle = min(fin, fout)
108 | 
109 |         # create conv layers
110 |         self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
111 |         self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
112 |         if self.learned_shortcut:
113 |             self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
114 | 
115 |         # apply spectral norm if specified
116 |         self.conv_0 = spectral_norm(self.conv_0)
117 |         self.conv_1 = spectral_norm(self.conv_1)
118 |         if self.learned_shortcut:
119 |             self.conv_s = spectral_norm(self.conv_s)
120 | 
121 |         # define normalization layers
122 |         self.norm_0 = SPADE(fin, seg_fin)
123 |         self.norm_1 = SPADE(fmiddle, seg_fin)
124 |         if self.learned_shortcut:
125 |             self.norm_s = SPADE(fin, seg_fin)
126 | 
127 |     # note the resnet block with SPADE also takes in |seg|,
128 |     # the semantic segmentation map, as input
129 |     def forward(self, x, seg):
130 |         x_s = self.shortcut(x, seg)
131 | 
132 |         dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
133 |         dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
134 | 
135 |         out = x_s + dx
136 | 
137 |         return out
138 | 
139 |     def shortcut(self, x, seg):
140 |         if self.learned_shortcut:
141 |             x_s = self.conv_s(self.norm_s(x, seg))
142 |         else:
143 |             x_s = x
144 |         return x_s
145 | 
146 |     def actvn(self, x):
147 |         return F.leaky_relu(x, 2e-1)
148 | 
149 | 
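# --- Editor's note, illustrating the memory estimate above (not in the original file). ---
# For a single 128x128 feature map, the attention matrix built in NonLocalLayer.forward
# below has (128*128)^2 = 16384^2 ≈ 2.7e8 entries; at fp32 that is already about 1 GB
# per sample, before counting the softmax output and autograd buffers, and it grows
# linearly with batch size. That is the O(h ** 4) blow-up noted in TranslationNet.forward,
# and why the self.non_local call there is commented out.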
150 | class NonLocalLayer(nn.Module):
151 |     # Non-local layer for 2D inputs (embedded Gaussian, https://arxiv.org/abs/1711.07971)
152 |     def __init__(self, cin):
153 |         super().__init__()
154 |         self.cinter = cin // 2
155 |         self.theta = nn.Conv2d(cin, self.cinter,
156 |                                kernel_size=1, stride=1, padding=0)
157 |         self.phi = nn.Conv2d(cin, self.cinter,
158 |                              kernel_size=1, stride=1, padding=0)
159 |         self.g = nn.Conv2d(cin, self.cinter,
160 |                            kernel_size=1, stride=1, padding=0)
161 | 
162 |         self.w = nn.Conv2d(self.cinter, cin,
163 |                            kernel_size=1, stride=1, padding=0)
164 | 
165 |     def forward(self, x):
166 |         n, c, h, w = x.shape
167 |         g_x = self.g(x).view(n, self.cinter, -1)
168 |         phi_x = self.phi(x).view(n, self.cinter, -1)
169 |         theta_x = self.theta(x).view(n, self.cinter, -1)
170 |         # the (h*w) x (h*w) affinity matrix below is what occupies too much memory
171 |         f_x = torch.bmm(phi_x.transpose(-1, -2), theta_x)  # note the transpose here
172 |         f_x = F.softmax(f_x, dim=-2)  # normalize over the axis that the bmm below sums over
173 |         # reshape back to 4D before the 1x1 conv; this bmm order saves a permute of g_x
174 |         res_x = self.w(torch.bmm(g_x, f_x).view(n, self.cinter, h, w))
175 |         return x + res_x
176 | 
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | INF = 99999999
2 | 
3 | class BaseOptions(object):
4 |     # Data options
5 |     dataroot='datasets/fashion'
6 |     dataset_mode='fashion'
7 |     name='fashion_cocosnet'
8 |     checkpoints_dir='checkpoints'
9 |     results_dir='results'
10 |     num_workers=0
11 |     batch_size=1
12 |     serial_batches=False
13 |     max_dataset_size=INF
14 |     gpu_ids = [2]
15 | 
16 |     # Model options
17 |     image_size=256
18 |     padding=40  # For the DeepFashion dataset, the input image may be cropped
19 |     model='cocos'
20 |     ncA=3
21 |     ncB=3
22 |     seg_dim=3
23 |     ngf=16
24 |     ndf=16
25 |     numD=2
26 |     nd_layers=3
27 | 
28 |     # Training options
29 |     niter=30
30 |     niter_decay=20
31 |     epoch_count=0
32 |     continue_train=False
33 |     which_epoch='latest'
34 | 
35 |     # Logging options
36 |     verbose=True
37 |     print_every=10
38 |     visual_every=1000
39 |     save_every=5
40 | 
41 | 
42 | class TrainOptions(BaseOptions):
43 |     phase='train'
44 |     isTrain=True
45 | 
46 |     # Training Options
47 |     lr=0.0002
48 |     beta1=0.5
49 |     gan_mode='hinge'
50 |     lr_policy='linear'
51 |     init_type='xavier'
52 |     init_gain=0.02
53 | 
54 |     lambda_perc = 1.0
55 |     lambda_domain = 5.0
56 |     lambda_feat = 10.0
57 |     lambda_context = 10.0
58 |     lambda_reg = 1.0
59 |     lambda_adv = 1.0
60 | 
61 |     # To resume training, uncomment the following lines
62 |     # continue_train=True
63 |     # which_epoch='latest'  # or a certain number (e.g. '10' or '20200525-112233')
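# --- Editor's note: a resume-training sketch, not part of the original file. ---
# Instead of uncommenting the lines above in place, one could subclass TrainOptions,
# mirroring how DebugOptions below overrides it (the class name is hypothetical):
#
#   class ResumeOptions(TrainOptions):
#       continue_train=True
#       which_epoch='latest'  # or e.g. '10' to resume from the epoch-10 checkpoint
#
# and then use `opt = ResumeOptions()` in train.py.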
64 | 
65 | class DebugOptions(TrainOptions):
66 |     max_dataset_size=4
67 |     num_workers=0
68 |     print_every=1
69 |     visual_every=1
70 |     save_every=1
71 |     niter=2
72 |     niter_decay=1
73 |     verbose=False
74 | 
75 | class TestOptions(BaseOptions):
76 |     phase='test'
77 |     isTrain=False
78 |     serial_batches=True
79 |     num_workers=0
80 |     batch_size=1
81 |     which_epoch='latest'
82 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotayou/CoCosNet/93142f55ff09e8ee6052d8b5c81931e7f9570093/test.py
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from options import DebugOptions, TrainOptions
2 | from data import create_dataset
3 | from model import create_model
4 | from torch.backends import cudnn
5 | import torch
6 | # opt = DebugOptions()  # swap in for a quick smoke test
7 | opt = TrainOptions()
8 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids[0])  # test single GPU first (uncommenting requires `import os`)
9 | 
10 | torch.cuda.set_device(opt.gpu_ids[0])
11 | cudnn.enabled = True
12 | cudnn.benchmark = True
13 | 
14 | loader = create_dataset(opt)
15 | dataset_size = len(loader)
16 | print('#training images = %d' % dataset_size)
17 | 
18 | net = create_model(opt)
19 | 
20 | for epoch in range(1, opt.niter + opt.niter_decay + 1):
21 |     print('Begin epoch %d' % epoch)
22 |     for i, data_i in enumerate(loader):
23 |         net.set_input(data_i)
24 |         net.optimize_parameters()
25 | 
26 |         #### logging, visualizing, saving
27 |         if i % opt.print_every == 0:
28 |             net.log_loss(epoch, i)
29 |         if i % opt.visual_every == 0:
30 |             net.log_visual(epoch, i)
31 | 
32 |     net.save_networks('latest')
33 |     if epoch % opt.save_every == 0:
34 |         net.save_networks(epoch)
35 |     net.update_learning_rate()
36 | 
--------------------------------------------------------------------------------
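# --- Editor's note: test.py above is linked rather than inlined in this dump. ---
# A test-time driver under the same conventions would presumably mirror train.py
# with TestOptions (serial batches, batch_size=1, isTrain=False); everything past
# set_input below is an assumption, not confirmed by the files shown:
#
#   from options import TestOptions
#   opt = TestOptions()
#   loader = create_dataset(opt)
#   net = create_model(opt)  # presumably restores the `which_epoch` checkpoint
#   for i, data_i in enumerate(loader):
#       net.set_input(data_i)
#       net.log_visual(0, i)  # hypothetical: write results under opt.results_dir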