├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── fashion_dataset.py
│   └── pose_utils.py
├── datasets
│   └── fashion
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── cocos_model.py
│   ├── contextual_loss.py
│   ├── correspondence_net.py
│   ├── discriminator.py
│   ├── loss.py
│   ├── networks.py
│   └── translation_net.py
├── options.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # other stuff
132 | **/*.png
133 | **/*.pdf
134 | **/*.pth
135 | **/*.jpg
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU AFFERO GENERAL PUBLIC LICENSE
2 | Version 3, 19 November 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU Affero General Public License is a free, copyleft license for
11 | software and other kinds of works, specifically designed to ensure
12 | cooperation with the community in the case of network server software.
13 |
14 | The licenses for most software and other practical works are designed
15 | to take away your freedom to share and change the works. By contrast,
16 | our General Public Licenses are intended to guarantee your freedom to
17 | share and change all versions of a program--to make sure it remains free
18 | software for all its users.
19 |
20 | When we speak of free software, we are referring to freedom, not
21 | price. Our General Public Licenses are designed to make sure that you
22 | have the freedom to distribute copies of free software (and charge for
23 | them if you wish), that you receive source code or can get it if you
24 | want it, that you can change the software or use pieces of it in new
25 | free programs, and that you know you can do these things.
26 |
27 | Developers that use our General Public Licenses protect your rights
28 | with two steps: (1) assert copyright on the software, and (2) offer
29 | you this License which gives you legal permission to copy, distribute
30 | and/or modify the software.
31 |
32 | A secondary benefit of defending all users' freedom is that
33 | improvements made in alternate versions of the program, if they
34 | receive widespread use, become available for other developers to
35 | incorporate. Many developers of free software are heartened and
36 | encouraged by the resulting cooperation. However, in the case of
37 | software used on network servers, this result may fail to come about.
38 | The GNU General Public License permits making a modified version and
39 | letting the public access it on a server without ever releasing its
40 | source code to the public.
41 |
42 | The GNU Affero General Public License is designed specifically to
43 | ensure that, in such cases, the modified source code becomes available
44 | to the community. It requires the operator of a network server to
45 | provide the source code of the modified version running there to the
46 | users of that server. Therefore, public use of a modified version, on
47 | a publicly accessible server, gives the public access to the source
48 | code of the modified version.
49 |
50 | An older license, called the Affero General Public License and
51 | published by Affero, was designed to accomplish similar goals. This is
52 | a different license, not a version of the Affero GPL, but Affero has
53 | released a new version of the Affero GPL which permits relicensing under
54 | this license.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | TERMS AND CONDITIONS
60 |
61 | 0. Definitions.
62 |
63 | "This License" refers to version 3 of the GNU Affero General Public License.
64 |
65 | "Copyright" also means copyright-like laws that apply to other kinds of
66 | works, such as semiconductor masks.
67 |
68 | "The Program" refers to any copyrightable work licensed under this
69 | License. Each licensee is addressed as "you". "Licensees" and
70 | "recipients" may be individuals or organizations.
71 |
72 | To "modify" a work means to copy from or adapt all or part of the work
73 | in a fashion requiring copyright permission, other than the making of an
74 | exact copy. The resulting work is called a "modified version" of the
75 | earlier work or a work "based on" the earlier work.
76 |
77 | A "covered work" means either the unmodified Program or a work based
78 | on the Program.
79 |
80 | To "propagate" a work means to do anything with it that, without
81 | permission, would make you directly or secondarily liable for
82 | infringement under applicable copyright law, except executing it on a
83 | computer or modifying a private copy. Propagation includes copying,
84 | distribution (with or without modification), making available to the
85 | public, and in some countries other activities as well.
86 |
87 | To "convey" a work means any kind of propagation that enables other
88 | parties to make or receive copies. Mere interaction with a user through
89 | a computer network, with no transfer of a copy, is not conveying.
90 |
91 | An interactive user interface displays "Appropriate Legal Notices"
92 | to the extent that it includes a convenient and prominently visible
93 | feature that (1) displays an appropriate copyright notice, and (2)
94 | tells the user that there is no warranty for the work (except to the
95 | extent that warranties are provided), that licensees may convey the
96 | work under this License, and how to view a copy of this License. If
97 | the interface presents a list of user commands or options, such as a
98 | menu, a prominent item in the list meets this criterion.
99 |
100 | 1. Source Code.
101 |
102 | The "source code" for a work means the preferred form of the work
103 | for making modifications to it. "Object code" means any non-source
104 | form of a work.
105 |
106 | A "Standard Interface" means an interface that either is an official
107 | standard defined by a recognized standards body, or, in the case of
108 | interfaces specified for a particular programming language, one that
109 | is widely used among developers working in that language.
110 |
111 | The "System Libraries" of an executable work include anything, other
112 | than the work as a whole, that (a) is included in the normal form of
113 | packaging a Major Component, but which is not part of that Major
114 | Component, and (b) serves only to enable use of the work with that
115 | Major Component, or to implement a Standard Interface for which an
116 | implementation is available to the public in source code form. A
117 | "Major Component", in this context, means a major essential component
118 | (kernel, window system, and so on) of the specific operating system
119 | (if any) on which the executable work runs, or a compiler used to
120 | produce the work, or an object code interpreter used to run it.
121 |
122 | The "Corresponding Source" for a work in object code form means all
123 | the source code needed to generate, install, and (for an executable
124 | work) run the object code and to modify the work, including scripts to
125 | control those activities. However, it does not include the work's
126 | System Libraries, or general-purpose tools or generally available free
127 | programs which are used unmodified in performing those activities but
128 | which are not part of the work. For example, Corresponding Source
129 | includes interface definition files associated with source files for
130 | the work, and the source code for shared libraries and dynamically
131 | linked subprograms that the work is specifically designed to require,
132 | such as by intimate data communication or control flow between those
133 | subprograms and other parts of the work.
134 |
135 | The Corresponding Source need not include anything that users
136 | can regenerate automatically from other parts of the Corresponding
137 | Source.
138 |
139 | The Corresponding Source for a work in source code form is that
140 | same work.
141 |
142 | 2. Basic Permissions.
143 |
144 | All rights granted under this License are granted for the term of
145 | copyright on the Program, and are irrevocable provided the stated
146 | conditions are met. This License explicitly affirms your unlimited
147 | permission to run the unmodified Program. The output from running a
148 | covered work is covered by this License only if the output, given its
149 | content, constitutes a covered work. This License acknowledges your
150 | rights of fair use or other equivalent, as provided by copyright law.
151 |
152 | You may make, run and propagate covered works that you do not
153 | convey, without conditions so long as your license otherwise remains
154 | in force. You may convey covered works to others for the sole purpose
155 | of having them make modifications exclusively for you, or provide you
156 | with facilities for running those works, provided that you comply with
157 | the terms of this License in conveying all material for which you do
158 | not control copyright. Those thus making or running the covered works
159 | for you must do so exclusively on your behalf, under your direction
160 | and control, on terms that prohibit them from making any copies of
161 | your copyrighted material outside their relationship with you.
162 |
163 | Conveying under any other circumstances is permitted solely under
164 | the conditions stated below. Sublicensing is not allowed; section 10
165 | makes it unnecessary.
166 |
167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168 |
169 | No covered work shall be deemed part of an effective technological
170 | measure under any applicable law fulfilling obligations under article
171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172 | similar laws prohibiting or restricting circumvention of such
173 | measures.
174 |
175 | When you convey a covered work, you waive any legal power to forbid
176 | circumvention of technological measures to the extent such circumvention
177 | is effected by exercising rights under this License with respect to
178 | the covered work, and you disclaim any intention to limit operation or
179 | modification of the work as a means of enforcing, against the work's
180 | users, your or third parties' legal rights to forbid circumvention of
181 | technological measures.
182 |
183 | 4. Conveying Verbatim Copies.
184 |
185 | You may convey verbatim copies of the Program's source code as you
186 | receive it, in any medium, provided that you conspicuously and
187 | appropriately publish on each copy an appropriate copyright notice;
188 | keep intact all notices stating that this License and any
189 | non-permissive terms added in accord with section 7 apply to the code;
190 | keep intact all notices of the absence of any warranty; and give all
191 | recipients a copy of this License along with the Program.
192 |
193 | You may charge any price or no price for each copy that you convey,
194 | and you may offer support or warranty protection for a fee.
195 |
196 | 5. Conveying Modified Source Versions.
197 |
198 | You may convey a work based on the Program, or the modifications to
199 | produce it from the Program, in the form of source code under the
200 | terms of section 4, provided that you also meet all of these conditions:
201 |
202 | a) The work must carry prominent notices stating that you modified
203 | it, and giving a relevant date.
204 |
205 | b) The work must carry prominent notices stating that it is
206 | released under this License and any conditions added under section
207 | 7. This requirement modifies the requirement in section 4 to
208 | "keep intact all notices".
209 |
210 | c) You must license the entire work, as a whole, under this
211 | License to anyone who comes into possession of a copy. This
212 | License will therefore apply, along with any applicable section 7
213 | additional terms, to the whole of the work, and all its parts,
214 | regardless of how they are packaged. This License gives no
215 | permission to license the work in any other way, but it does not
216 | invalidate such permission if you have separately received it.
217 |
218 | d) If the work has interactive user interfaces, each must display
219 | Appropriate Legal Notices; however, if the Program has interactive
220 | interfaces that do not display Appropriate Legal Notices, your
221 | work need not make them do so.
222 |
223 | A compilation of a covered work with other separate and independent
224 | works, which are not by their nature extensions of the covered work,
225 | and which are not combined with it such as to form a larger program,
226 | in or on a volume of a storage or distribution medium, is called an
227 | "aggregate" if the compilation and its resulting copyright are not
228 | used to limit the access or legal rights of the compilation's users
229 | beyond what the individual works permit. Inclusion of a covered work
230 | in an aggregate does not cause this License to apply to the other
231 | parts of the aggregate.
232 |
233 | 6. Conveying Non-Source Forms.
234 |
235 | You may convey a covered work in object code form under the terms
236 | of sections 4 and 5, provided that you also convey the
237 | machine-readable Corresponding Source under the terms of this License,
238 | in one of these ways:
239 |
240 | a) Convey the object code in, or embodied in, a physical product
241 | (including a physical distribution medium), accompanied by the
242 | Corresponding Source fixed on a durable physical medium
243 | customarily used for software interchange.
244 |
245 | b) Convey the object code in, or embodied in, a physical product
246 | (including a physical distribution medium), accompanied by a
247 | written offer, valid for at least three years and valid for as
248 | long as you offer spare parts or customer support for that product
249 | model, to give anyone who possesses the object code either (1) a
250 | copy of the Corresponding Source for all the software in the
251 | product that is covered by this License, on a durable physical
252 | medium customarily used for software interchange, for a price no
253 | more than your reasonable cost of physically performing this
254 | conveying of source, or (2) access to copy the
255 | Corresponding Source from a network server at no charge.
256 |
257 | c) Convey individual copies of the object code with a copy of the
258 | written offer to provide the Corresponding Source. This
259 | alternative is allowed only occasionally and noncommercially, and
260 | only if you received the object code with such an offer, in accord
261 | with subsection 6b.
262 |
263 | d) Convey the object code by offering access from a designated
264 | place (gratis or for a charge), and offer equivalent access to the
265 | Corresponding Source in the same way through the same place at no
266 | further charge. You need not require recipients to copy the
267 | Corresponding Source along with the object code. If the place to
268 | copy the object code is a network server, the Corresponding Source
269 | may be on a different server (operated by you or a third party)
270 | that supports equivalent copying facilities, provided you maintain
271 | clear directions next to the object code saying where to find the
272 | Corresponding Source. Regardless of what server hosts the
273 | Corresponding Source, you remain obligated to ensure that it is
274 | available for as long as needed to satisfy these requirements.
275 |
276 | e) Convey the object code using peer-to-peer transmission, provided
277 | you inform other peers where the object code and Corresponding
278 | Source of the work are being offered to the general public at no
279 | charge under subsection 6d.
280 |
281 | A separable portion of the object code, whose source code is excluded
282 | from the Corresponding Source as a System Library, need not be
283 | included in conveying the object code work.
284 |
285 | A "User Product" is either (1) a "consumer product", which means any
286 | tangible personal property which is normally used for personal, family,
287 | or household purposes, or (2) anything designed or sold for incorporation
288 | into a dwelling. In determining whether a product is a consumer product,
289 | doubtful cases shall be resolved in favor of coverage. For a particular
290 | product received by a particular user, "normally used" refers to a
291 | typical or common use of that class of product, regardless of the status
292 | of the particular user or of the way in which the particular user
293 | actually uses, or expects or is expected to use, the product. A product
294 | is a consumer product regardless of whether the product has substantial
295 | commercial, industrial or non-consumer uses, unless such uses represent
296 | the only significant mode of use of the product.
297 |
298 | "Installation Information" for a User Product means any methods,
299 | procedures, authorization keys, or other information required to install
300 | and execute modified versions of a covered work in that User Product from
301 | a modified version of its Corresponding Source. The information must
302 | suffice to ensure that the continued functioning of the modified object
303 | code is in no case prevented or interfered with solely because
304 | modification has been made.
305 |
306 | If you convey an object code work under this section in, or with, or
307 | specifically for use in, a User Product, and the conveying occurs as
308 | part of a transaction in which the right of possession and use of the
309 | User Product is transferred to the recipient in perpetuity or for a
310 | fixed term (regardless of how the transaction is characterized), the
311 | Corresponding Source conveyed under this section must be accompanied
312 | by the Installation Information. But this requirement does not apply
313 | if neither you nor any third party retains the ability to install
314 | modified object code on the User Product (for example, the work has
315 | been installed in ROM).
316 |
317 | The requirement to provide Installation Information does not include a
318 | requirement to continue to provide support service, warranty, or updates
319 | for a work that has been modified or installed by the recipient, or for
320 | the User Product in which it has been modified or installed. Access to a
321 | network may be denied when the modification itself materially and
322 | adversely affects the operation of the network or violates the rules and
323 | protocols for communication across the network.
324 |
325 | Corresponding Source conveyed, and Installation Information provided,
326 | in accord with this section must be in a format that is publicly
327 | documented (and with an implementation available to the public in
328 | source code form), and must require no special password or key for
329 | unpacking, reading or copying.
330 |
331 | 7. Additional Terms.
332 |
333 | "Additional permissions" are terms that supplement the terms of this
334 | License by making exceptions from one or more of its conditions.
335 | Additional permissions that are applicable to the entire Program shall
336 | be treated as though they were included in this License, to the extent
337 | that they are valid under applicable law. If additional permissions
338 | apply only to part of the Program, that part may be used separately
339 | under those permissions, but the entire Program remains governed by
340 | this License without regard to the additional permissions.
341 |
342 | When you convey a copy of a covered work, you may at your option
343 | remove any additional permissions from that copy, or from any part of
344 | it. (Additional permissions may be written to require their own
345 | removal in certain cases when you modify the work.) You may place
346 | additional permissions on material, added by you to a covered work,
347 | for which you have or can give appropriate copyright permission.
348 |
349 | Notwithstanding any other provision of this License, for material you
350 | add to a covered work, you may (if authorized by the copyright holders of
351 | that material) supplement the terms of this License with terms:
352 |
353 | a) Disclaiming warranty or limiting liability differently from the
354 | terms of sections 15 and 16 of this License; or
355 |
356 | b) Requiring preservation of specified reasonable legal notices or
357 | author attributions in that material or in the Appropriate Legal
358 | Notices displayed by works containing it; or
359 |
360 | c) Prohibiting misrepresentation of the origin of that material, or
361 | requiring that modified versions of such material be marked in
362 | reasonable ways as different from the original version; or
363 |
364 | d) Limiting the use for publicity purposes of names of licensors or
365 | authors of the material; or
366 |
367 | e) Declining to grant rights under trademark law for use of some
368 | trade names, trademarks, or service marks; or
369 |
370 | f) Requiring indemnification of licensors and authors of that
371 | material by anyone who conveys the material (or modified versions of
372 | it) with contractual assumptions of liability to the recipient, for
373 | any liability that these contractual assumptions directly impose on
374 | those licensors and authors.
375 |
376 | All other non-permissive additional terms are considered "further
377 | restrictions" within the meaning of section 10. If the Program as you
378 | received it, or any part of it, contains a notice stating that it is
379 | governed by this License along with a term that is a further
380 | restriction, you may remove that term. If a license document contains
381 | a further restriction but permits relicensing or conveying under this
382 | License, you may add to a covered work material governed by the terms
383 | of that license document, provided that the further restriction does
384 | not survive such relicensing or conveying.
385 |
386 | If you add terms to a covered work in accord with this section, you
387 | must place, in the relevant source files, a statement of the
388 | additional terms that apply to those files, or a notice indicating
389 | where to find the applicable terms.
390 |
391 | Additional terms, permissive or non-permissive, may be stated in the
392 | form of a separately written license, or stated as exceptions;
393 | the above requirements apply either way.
394 |
395 | 8. Termination.
396 |
397 | You may not propagate or modify a covered work except as expressly
398 | provided under this License. Any attempt otherwise to propagate or
399 | modify it is void, and will automatically terminate your rights under
400 | this License (including any patent licenses granted under the third
401 | paragraph of section 11).
402 |
403 | However, if you cease all violation of this License, then your
404 | license from a particular copyright holder is reinstated (a)
405 | provisionally, unless and until the copyright holder explicitly and
406 | finally terminates your license, and (b) permanently, if the copyright
407 | holder fails to notify you of the violation by some reasonable means
408 | prior to 60 days after the cessation.
409 |
410 | Moreover, your license from a particular copyright holder is
411 | reinstated permanently if the copyright holder notifies you of the
412 | violation by some reasonable means, this is the first time you have
413 | received notice of violation of this License (for any work) from that
414 | copyright holder, and you cure the violation prior to 30 days after
415 | your receipt of the notice.
416 |
417 | Termination of your rights under this section does not terminate the
418 | licenses of parties who have received copies or rights from you under
419 | this License. If your rights have been terminated and not permanently
420 | reinstated, you do not qualify to receive new licenses for the same
421 | material under section 10.
422 |
423 | 9. Acceptance Not Required for Having Copies.
424 |
425 | You are not required to accept this License in order to receive or
426 | run a copy of the Program. Ancillary propagation of a covered work
427 | occurring solely as a consequence of using peer-to-peer transmission
428 | to receive a copy likewise does not require acceptance. However,
429 | nothing other than this License grants you permission to propagate or
430 | modify any covered work. These actions infringe copyright if you do
431 | not accept this License. Therefore, by modifying or propagating a
432 | covered work, you indicate your acceptance of this License to do so.
433 |
434 | 10. Automatic Licensing of Downstream Recipients.
435 |
436 | Each time you convey a covered work, the recipient automatically
437 | receives a license from the original licensors, to run, modify and
438 | propagate that work, subject to this License. You are not responsible
439 | for enforcing compliance by third parties with this License.
440 |
441 | An "entity transaction" is a transaction transferring control of an
442 | organization, or substantially all assets of one, or subdividing an
443 | organization, or merging organizations. If propagation of a covered
444 | work results from an entity transaction, each party to that
445 | transaction who receives a copy of the work also receives whatever
446 | licenses to the work the party's predecessor in interest had or could
447 | give under the previous paragraph, plus a right to possession of the
448 | Corresponding Source of the work from the predecessor in interest, if
449 | the predecessor has it or can get it with reasonable efforts.
450 |
451 | You may not impose any further restrictions on the exercise of the
452 | rights granted or affirmed under this License. For example, you may
453 | not impose a license fee, royalty, or other charge for exercise of
454 | rights granted under this License, and you may not initiate litigation
455 | (including a cross-claim or counterclaim in a lawsuit) alleging that
456 | any patent claim is infringed by making, using, selling, offering for
457 | sale, or importing the Program or any portion of it.
458 |
459 | 11. Patents.
460 |
461 | A "contributor" is a copyright holder who authorizes use under this
462 | License of the Program or a work on which the Program is based. The
463 | work thus licensed is called the contributor's "contributor version".
464 |
465 | A contributor's "essential patent claims" are all patent claims
466 | owned or controlled by the contributor, whether already acquired or
467 | hereafter acquired, that would be infringed by some manner, permitted
468 | by this License, of making, using, or selling its contributor version,
469 | but do not include claims that would be infringed only as a
470 | consequence of further modification of the contributor version. For
471 | purposes of this definition, "control" includes the right to grant
472 | patent sublicenses in a manner consistent with the requirements of
473 | this License.
474 |
475 | Each contributor grants you a non-exclusive, worldwide, royalty-free
476 | patent license under the contributor's essential patent claims, to
477 | make, use, sell, offer for sale, import and otherwise run, modify and
478 | propagate the contents of its contributor version.
479 |
480 | In the following three paragraphs, a "patent license" is any express
481 | agreement or commitment, however denominated, not to enforce a patent
482 | (such as an express permission to practice a patent or covenant not to
483 | sue for patent infringement). To "grant" such a patent license to a
484 | party means to make such an agreement or commitment not to enforce a
485 | patent against the party.
486 |
487 | If you convey a covered work, knowingly relying on a patent license,
488 | and the Corresponding Source of the work is not available for anyone
489 | to copy, free of charge and under the terms of this License, through a
490 | publicly available network server or other readily accessible means,
491 | then you must either (1) cause the Corresponding Source to be so
492 | available, or (2) arrange to deprive yourself of the benefit of the
493 | patent license for this particular work, or (3) arrange, in a manner
494 | consistent with the requirements of this License, to extend the patent
495 | license to downstream recipients. "Knowingly relying" means you have
496 | actual knowledge that, but for the patent license, your conveying the
497 | covered work in a country, or your recipient's use of the covered work
498 | in a country, would infringe one or more identifiable patents in that
499 | country that you have reason to believe are valid.
500 |
501 | If, pursuant to or in connection with a single transaction or
502 | arrangement, you convey, or propagate by procuring conveyance of, a
503 | covered work, and grant a patent license to some of the parties
504 | receiving the covered work authorizing them to use, propagate, modify
505 | or convey a specific copy of the covered work, then the patent license
506 | you grant is automatically extended to all recipients of the covered
507 | work and works based on it.
508 |
509 | A patent license is "discriminatory" if it does not include within
510 | the scope of its coverage, prohibits the exercise of, or is
511 | conditioned on the non-exercise of one or more of the rights that are
512 | specifically granted under this License. You may not convey a covered
513 | work if you are a party to an arrangement with a third party that is
514 | in the business of distributing software, under which you make payment
515 | to the third party based on the extent of your activity of conveying
516 | the work, and under which the third party grants, to any of the
517 | parties who would receive the covered work from you, a discriminatory
518 | patent license (a) in connection with copies of the covered work
519 | conveyed by you (or copies made from those copies), or (b) primarily
520 | for and in connection with specific products or compilations that
521 | contain the covered work, unless you entered into that arrangement,
522 | or that patent license was granted, prior to 28 March 2007.
523 |
524 | Nothing in this License shall be construed as excluding or limiting
525 | any implied license or other defenses to infringement that may
526 | otherwise be available to you under applicable patent law.
527 |
528 | 12. No Surrender of Others' Freedom.
529 |
530 | If conditions are imposed on you (whether by court order, agreement or
531 | otherwise) that contradict the conditions of this License, they do not
532 | excuse you from the conditions of this License. If you cannot convey a
533 | covered work so as to satisfy simultaneously your obligations under this
534 | License and any other pertinent obligations, then as a consequence you may
535 | not convey it at all. For example, if you agree to terms that obligate you
536 | to collect a royalty for further conveying from those to whom you convey
537 | the Program, the only way you could satisfy both those terms and this
538 | License would be to refrain entirely from conveying the Program.
539 |
540 | 13. Remote Network Interaction; Use with the GNU General Public License.
541 |
542 | Notwithstanding any other provision of this License, if you modify the
543 | Program, your modified version must prominently offer all users
544 | interacting with it remotely through a computer network (if your version
545 | supports such interaction) an opportunity to receive the Corresponding
546 | Source of your version by providing access to the Corresponding Source
547 | from a network server at no charge, through some standard or customary
548 | means of facilitating copying of software. This Corresponding Source
549 | shall include the Corresponding Source for any work covered by version 3
550 | of the GNU General Public License that is incorporated pursuant to the
551 | following paragraph.
552 |
553 | Notwithstanding any other provision of this License, you have
554 | permission to link or combine any covered work with a work licensed
555 | under version 3 of the GNU General Public License into a single
556 | combined work, and to convey the resulting work. The terms of this
557 | License will continue to apply to the part which is the covered work,
558 | but the work with which it is combined will remain governed by version
559 | 3 of the GNU General Public License.
560 |
561 | 14. Revised Versions of this License.
562 |
563 | The Free Software Foundation may publish revised and/or new versions of
564 | the GNU Affero General Public License from time to time. Such new versions
565 | will be similar in spirit to the present version, but may differ in detail to
566 | address new problems or concerns.
567 |
568 | Each version is given a distinguishing version number. If the
569 | Program specifies that a certain numbered version of the GNU Affero General
570 | Public License "or any later version" applies to it, you have the
571 | option of following the terms and conditions either of that numbered
572 | version or of any later version published by the Free Software
573 | Foundation. If the Program does not specify a version number of the
574 | GNU Affero General Public License, you may choose any version ever published
575 | by the Free Software Foundation.
576 |
577 | If the Program specifies that a proxy can decide which future
578 | versions of the GNU Affero General Public License can be used, that proxy's
579 | public statement of acceptance of a version permanently authorizes you
580 | to choose that version for the Program.
581 |
582 | Later license versions may give you additional or different
583 | permissions. However, no additional obligations are imposed on any
584 | author or copyright holder as a result of your choosing to follow a
585 | later version.
586 |
587 | 15. Disclaimer of Warranty.
588 |
589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597 |
598 | 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608 | SUCH DAMAGES.
609 |
610 | 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these terms.
626 |
627 | To do so, attach the following notices to the program. It is safest
628 | to attach them to the start of each source file to most effectively
629 | state the exclusion of warranty; and each file should have at least
630 | the "copyright" line and a pointer to where the full notice is found.
631 |
632 |     <one line to give the program's name and a brief idea of what it does.>
633 |     Copyright (C) <year>  <name of author>
634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | [arXiv:2004.05571](https://arxiv.org/abs/2004.05571)
4 |
5 | # CoCosNet
6 | PyTorch implementation of the paper ["Cross-domain Correspondence Learning for Exemplar-based Image Translation"](https://panzhang0212.github.io/CoCosNet) (CVPR 2020 oral).
7 |
8 |
9 | 
10 |
11 | ### Update:
12 | 20200525: Training code for DeepFashion is complete. Due to memory limitations, I made the following modifications:
13 | - Disabled the non-local layer, as its memory cost is infeasible on common hardware. If, as the original paper states, the non-local layer operates on (128, 128, 256) tensors, then each attention matrix would contain 128^4 elements (about 1 GB in float32).
14 | - Shrank the correspondence map size from 64 to 32, cutting each dense correspondence matrix from 64^4 to 32^4 elements (a 16x memory saving).
15 | - Shrank the base number of filters from 64 to 16.
16 |
17 | The truncated model barely fits on a 12 GB GTX Titan X card, but its performance will not match the full model.
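
The arithmetic behind these numbers, as a quick sanity check (a minimal sketch assuming float32 storage):

```python
# One attention matrix over a (128, 128) spatial grid, stored in float32.
h = w = 128
attn_bytes = (h * w) ** 2 * 4
print(attn_bytes / 2 ** 30)  # -> 1.0 GiB per attention matrix

# Saving from shrinking the correspondence map from 64x64 to 32x32:
# a dense correspondence matrix has (H*W)^2 entries.
print((64 * 64) ** 2 / (32 * 32) ** 2)  # -> 16.0
```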
18 |
19 | # Environment
20 | - Ubuntu/CentOS
21 | - PyTorch 1.0+
22 | - opencv-python
23 | - tqdm
24 |
25 | # TODO list
26 | - [x] Prepare dataset
27 | - [x] Implement the network
28 | - [x] Implement the loss functions
29 | - [x] Implement the trainer
30 | - [x] Training on DeepFashion
31 | - [ ] Adjust the network architecture to fit on a single 16 GB GPU.
32 | - [ ] Training for other tasks
33 |
34 | # Dataset Preparation
35 | ### DeepFashion
36 | Follow the data preparation routine in [the PATN repo](https://github.com/Lotayou/Pose-Transfer).
37 |
38 | # Pretrained Model
39 | The pretrained model for the human pose transfer task: [TO BE RELEASED](https://github.com/Lotayou)
40 |
41 | # Training
42 | Run `python train.py`. A hypothetical full invocation is sketched below.
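
For example (the flag names here mirror the `opt.*` attributes read in `data/__init__.py` and `data/fashion_dataset.py`; check `options.py` for the authoritative names and defaults):

```
python train.py --dataroot ./datasets/fashion --dataset_mode fashion --batch_size 4
```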
43 |
44 | # Citations
45 | If you find this repo useful for your research, don't forget to cite the original paper:
46 | ```
47 | @article{Zhang2020CrossdomainCL,
48 | title={Cross-domain Correspondence Learning for Exemplar-based Image Translation},
49 | author={Pan Zhang and Bo Zhang and Dong Chen and Lu Yuan and Fang Wen},
50 | journal={ArXiv},
51 | year={2020},
52 | volume={abs/2004.05571}
53 | }
54 | ```
55 |
56 | # Acknowledgement
57 | TODO.
58 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' that inherits from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- : (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
11 | See 'fashion_dataset.py' in this package for a concrete example, or the sketch below.
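
A minimal sketch of such a 'dummy' dataset (illustrative only; the file name,
class, and returned keys are hypothetical):

    # data/dummy_dataset.py
    from data.base_dataset import BaseDataset

    class DummyDataset(BaseDataset):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser  # no dataset-specific options

        def __init__(self, opt):
            BaseDataset.__init__(self, opt)
            self.items = list(range(100))

        def __len__(self):
            return len(self.items)

        def __getitem__(self, index):
            return {'a': float(self.items[index])}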
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 |     In that file, a class named DatasetNameDataset will be looked up
22 |     and returned. It has to be a subclass of BaseDataset, and the
23 |     name match is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 |     """Wrapper around a Dataset that performs multi-threaded data loading."""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | self.dataloader = torch.utils.data.DataLoader(
76 | self.dataset,
77 | batch_size=opt.batch_size,
78 | shuffle=not opt.serial_batches,
79 | num_workers=int(opt.num_workers),
80 | drop_last=True,
81 | pin_memory=True)
82 |
83 | def load_data(self):
84 | return self
85 |
86 | def __len__(self):
87 | """Return the number of data in the dataset"""
88 | return min(len(self.dataset), self.opt.max_dataset_size)
89 |
90 | def __iter__(self):
91 | """Return a batch of data"""
92 | for i, data in enumerate(self.dataloader):
93 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
94 | break
95 | yield data
96 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 | class BaseDataset(Dataset):
4 |     """Abstract base dataset; subclasses implement __getitem__ and __len__."""
5 |
6 |     @staticmethod
7 |     def modify_commandline_options(parser, is_train=True):
8 |         # data.get_option_setter() looks this method up on every dataset
9 |         # class, so provide a no-op default here.
10 |         return parser
11 |
12 |     def __init__(self, opt):
13 |         super().__init__()
14 |         self.opt = opt
15 |
16 |     def __getitem__(self, index):
17 |         raise NotImplementedError
18 |
19 |     def __len__(self):
20 |         raise NotImplementedError
21 |
--------------------------------------------------------------------------------
/data/fashion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Fashion dataset: loads DeepFashion model images (person photos).
3 | Requires skeleton input rendered as stick figures.
4 | """
5 |
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.utils.data as data
10 | import cv2
11 | from tqdm import tqdm
12 | import os
13 | from data.base_dataset import BaseDataset
14 | from data.pose_utils import draw_pose_from_cords, load_pose_cords_from_strings
15 |
16 | class FashionDataset(BaseDataset):
17 |     # Note: the pose annotations are fitted to 256x176 images and need additional resizing.
18 | def __init__(self, opt):
19 | super().__init__(opt)
20 | self.opt = opt
21 | self.h = opt.image_size
22 | self.w = opt.image_size - 2 * opt.padding
23 | self.size = (self.h, self.w)
24 | self.pd = opt.padding
25 |
26 | self.white = torch.ones((3, self.h, self.h), dtype=torch.float32)
27 | self.black = -1 * self.white
28 |
29 | self.dir_Img = os.path.join(opt.dataroot, opt.phase) # person images (exemplar)
30 | self.dir_Anno = os.path.join(opt.dataroot, opt.phase + '_pose_rgb') # rgb pose images
31 |
32 | pairLst = os.path.join(opt.dataroot, 'fasion-resize-pairs-%s.csv' % opt.phase)
33 | self.init_categories(pairLst)
34 |
35 | if not os.path.isdir(self.dir_Anno):
36 | print('Folder %s not found or annotation incomplete...' % self.dir_Anno)
37 | annotation_csv = os.path.join(opt.dataroot, 'fasion-resize-annotation-%s.csv' % opt.phase)
38 | if os.path.isfile(annotation_csv):
39 | print('Found backup annotation file, start generating required pose images...')
40 | self.draw_stick_figures(annotation_csv, self.dir_Anno)
41 |
42 |
43 | def trans(self, x, bg='black'):
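        # Normalize an HxWx3 uint8 image (BGR, as loaded by cv2) to [-1, 1],
        # convert it to CHW, and paste it into a square h-by-h canvas whose
        # left/right padding is filled with the requested background color.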
44 | x = torch.from_numpy(x / 127.5 - 1).permute(2, 0, 1).float()
45 | full = torch.ones((3, self.h, self.h), dtype=torch.float32)
46 | if bg == 'black':
47 | full = -1 * full
48 |
49 | full[:,:,self.pd:self.pd+self.w] = x
50 | return full
51 |
52 | def draw_stick_figures(self, annotation, target_dir):
53 | os.makedirs(target_dir, exist_ok=True)
54 | with open(annotation, 'r') as f:
55 | lines = [l.strip() for l in f][1:]
56 |
57 | for l in tqdm(lines):
58 | name, str_y, str_x = l.split(':')
59 | target_name = os.path.join(target_dir, name)
60 | cords = load_pose_cords_from_strings(str_y, str_x)
61 | target_im, _ = draw_pose_from_cords(cords, self.size)
62 | cv2.imwrite(target_name, target_im)
63 |
64 |
65 | def init_categories(self, pairLst):
66 | '''
67 |         Loading pairs with pandas is far too slow, so the csv is parsed
68 |         manually below. The pandas version, kept for reference:
69 | pairs_file_train = pd.read_csv(pairLst)
70 | self.size = len(pairs_file_train)
71 | self.pairs = []
72 | print('Loading data pairs ...')
73 | for i in range(self.size):
74 | pair = [pairs_file_train.iloc[i]['from'], pairs_file_train.iloc[i]['to']]
75 | self.pairs.append(pair)
76 | '''
77 | with open(pairLst, 'r') as f:
78 | lines = [l for l in f][1:self.opt.max_dataset_size+1]
79 | self.pairs = [l.strip().split(',') for l in lines]
80 | print('Loading data pairs finished ...')
81 |
82 | def __getitem__(self, index):
83 | P1_name, P2_name = self.pairs[index]
84 |
85 | P1 = self.trans(cv2.imread(os.path.join(self.dir_Img, P1_name)), bg='white') # person 1
86 | BP1 = self.trans(cv2.imread(os.path.join(self.dir_Anno, P1_name)), bg='black') # bone of person 1
87 | P2 = self.trans(cv2.imread(os.path.join(self.dir_Img, P2_name)), bg='white') # person 2
88 | BP2 = self.trans(cv2.imread(os.path.join(self.dir_Anno, P2_name)), bg='black') # bone of person 2
89 | # domain x: posemap
90 | # domain y: exemplar
91 | return {'a': BP2, 'b_gt': P2, 'a_exemplar': BP1, 'b_exemplar': P1}
92 |
93 |
94 | def __len__(self):
95 | return len(self.pairs)
96 |
97 |
--------------------------------------------------------------------------------
/data/pose_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage.filters import gaussian_filter
3 | from skimage.draw import circle, line_aa, polygon
4 | import json
5 | from pandas import Series
6 | import matplotlib
7 | matplotlib.use('Agg')
8 | import matplotlib.pyplot as plt
9 | import matplotlib.patches as mpatches
10 | from collections import defaultdict
11 | import skimage.measure, skimage.transform
12 | import sys
13 |
14 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9],
15 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16],
16 | [0,15], [15,17], [2,16], [5,17]]
17 |
18 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
19 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
20 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
21 |
22 |
23 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
24 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']
25 |
26 | MISSING_VALUE = -1
27 | def MISSING(x):
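    # Treat both the explicit -1 marker and 0 as missing coordinates.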
28 | return x == -1 or x == 0
29 |
30 |
31 | def map_to_cord(pose_map, threshold=0.1):
32 | all_peaks = [[] for i in range(18)]
33 | pose_map = pose_map[..., :18]
34 |
35 | y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis = (0, 1)),
36 | pose_map > threshold))
37 | for x_i, y_i, z_i in zip(x, y, z):
38 | all_peaks[z_i].append([x_i, y_i])
39 |
40 | x_values = []
41 | y_values = []
42 |
43 | for i in range(18):
44 | if len(all_peaks[i]) != 0:
45 | x_values.append(all_peaks[i][0][0])
46 | y_values.append(all_peaks[i][0][1])
47 | else:
48 | x_values.append(MISSING_VALUE)
49 | y_values.append(MISSING_VALUE)
50 |
51 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1)
52 |
53 |
54 | def cords_to_map(cords, img_size, sigma=6):
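    # Render each (y, x) joint as an isotropic Gaussian heatmap channel with
    # standard deviation `sigma`; missing joints yield an all-zero channel.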
55 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
56 | for i, point in enumerate(cords):
57 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
58 | continue
59 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
60 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2))
61 | return result
62 |
63 |
64 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
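    # Rasterize the skeleton: limbs as anti-aliased grayscale lines (optional)
    # and joints as colored discs; returns the RGB canvas and a boolean mask
    # of all drawn pixels.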
65 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8)
66 | mask = np.zeros(shape=img_size, dtype=bool)
67 |
68 | if draw_joints:
69 | for f, t in LIMB_SEQ:
70 | from_missing = MISSING(pose_joints[f][0]) or MISSING(pose_joints[f][1])
71 | to_missing = MISSING(pose_joints[t][0]) or MISSING(pose_joints[t][1])
72 | if from_missing or to_missing:
73 | continue
74 |
75 | '''
76 |             Trick: use a 4-vertex polygon of 1-pixel width to represent lines, which allows shape control.
77 |
78 | yy, xx = polygon(
79 | [pose_joints[f][0], pose_joints[t][0], pose_joints[t][0]+1, pose_joints[f][0]+1],
80 | [pose_joints[f][1], pose_joints[t][1], pose_joints[t][1]+1, pose_joints[f][1]+1],
81 | shape=img_size
82 | )
83 | '''
84 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1])
85 | valid_ids = [i for i in range(len(yy)) if 0 < yy[i] < img_size[0] and 0 < xx[i] < img_size[1]]
86 | yy, xx, val = yy[valid_ids], xx[valid_ids], val[valid_ids]
87 | colors[yy, xx] = np.expand_dims(val, 1) * 255
88 | mask[yy, xx] = True
89 |
90 | for i, joint in enumerate(pose_joints):
91 | if MISSING(pose_joints[i][0]) or MISSING(pose_joints[i][1]):
92 | continue
93 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size)
94 | colors[yy, xx] = COLORS[i]
95 | mask[yy, xx] = True
96 |
97 | return colors, mask
98 |
99 |
100 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs):
101 | cords = map_to_cord(pose_map, threshold=threshold)
102 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs)
103 |
104 |
105 | def load_pose_cords_from_strings(y_str, x_str):
106 |     ## 20181114: bug fix: if given a pandas.Series, extract the underlying string before JSON parsing.
107 | if isinstance(y_str, Series):
108 | y_str = y_str.values[0]
109 | if isinstance(x_str, Series):
110 | x_str = x_str.values[0]
111 | y_cords = json.loads(y_str)
112 | x_cords = json.loads(x_str)
113 | # 20191117: modify PATN processed coords by adding 40 to non-negative indices
114 | # NOTE: For fasion dataset only.
115 | # print(x_cords)
116 | # 20191123: deprecate this.
117 | # x_cords = [item + 40 if item > 0 else item for item in x_cords]
118 | # print(x_cords)
119 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
120 |
121 | def mean_inputation(X):
122 | X = X.copy()
123 | for i in range(X.shape[1]):
124 | for j in range(X.shape[2]):
125 | val = np.mean(X[:, i, j][X[:, i, j] != -1])
126 | X[:, i, j][X[:, i, j] == -1] = val
127 | return X
128 |
129 | def draw_legend():
130 | handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)]
131 | plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
132 |
133 | def produce_ma_mask(kp_array, img_size, point_radius=4):
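    # Approximate a body-region mask: sweep a 2*point_radius-wide quadrilateral
    # along each limb, stamp discs at the joints, then apply dilation followed
    # by erosion (a morphological closing) to fill small gaps.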
134 | from skimage.morphology import dilation, erosion, square
135 | mask = np.zeros(shape=img_size, dtype=bool)
136 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
137 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
138 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
139 | limbs = np.array(limbs) - 1
140 | for f, t in limbs:
141 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
142 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
143 | if from_missing or to_missing:
144 | continue
145 |
146 | norm_vec = kp_array[f] - kp_array[t]
147 | norm_vec = np.array([-norm_vec[1], norm_vec[0]])
148 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
149 |
150 |
151 |         vertices = np.array([
152 | kp_array[f] + norm_vec,
153 | kp_array[f] - norm_vec,
154 | kp_array[t] - norm_vec,
155 | kp_array[t] + norm_vec
156 | ])
157 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size)
158 | mask[yy, xx] = True
159 |
160 | for i, joint in enumerate(kp_array):
161 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE:
162 | continue
163 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size)
164 | mask[yy, xx] = True
165 |
166 | mask = dilation(mask, square(5))
167 | mask = erosion(mask, square(5))
168 | return mask
169 |
170 | if __name__ == "__main__":
171 | import pandas as pd
172 | from skimage.io import imread
173 | import pylab as plt
174 | import os
175 | i = 5
176 | df = pd.read_csv('data/market-annotation-train.csv', sep=':')
177 |
178 | for index, row in df.iterrows():
179 | pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
180 |
181 | colors, mask = draw_pose_from_cords(pose_cords, (128, 64))
182 |
183 | mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1)
184 | print(mmm.shape)
185 | img = imread('data/market-dataset/train/' + row['name'])
186 |
187 | mmm[mask] = colors[mask]
188 |
189 |     print(mmm)
190 | plt.subplot(1, 1, 1)
191 | plt.imshow(mmm)
192 | plt.show()
193 |
--------------------------------------------------------------------------------
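A minimal usage sketch of the pose utilities above, mirroring the __main__ block (the CSV path and the (128, 64) Market-1501 image size are taken from that block; the import path assumes pose_utils.py sits under data/ as in the repo tree):

    import pandas as pd
    from data.pose_utils import load_pose_cords_from_strings, draw_pose_from_cords

    df = pd.read_csv('data/market-annotation-train.csv', sep=':')
    row = df.iloc[0]
    # (N, 2) array of (y, x) joint coordinates; missing joints hold MISSING_VALUE
    cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
    colors, mask = draw_pose_from_cords(cords, (128, 64))
    # colors: uint8 H x W x 3 skeleton rendering; mask: bool H x W coverage mask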
/datasets/fashion:
--------------------------------------------------------------------------------
1 | /backup1/lingboyang/human_image_generation/CVPR2019_pose_transfer/fashion_data
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
3 | You need to implement the following five functions:
4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
5 | -- <set_input>: unpack data from dataset and apply preprocessing.
6 | -- <forward>: produce intermediate results.
7 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
8 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
9 | In the function <__init__>, you need to define four lists:
10 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
11 | -- self.model_names (str list): define networks used in our training.
12 | -- self.visual_names (str list): specify the images that you want to display and save.
13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
14 | Now you can use the model class by specifying flag '--model dummy'.
15 | See our template model class 'template_model.py' for an example.
16 | """
17 |
18 | import importlib
19 | from model.base_model import BaseModel
20 |
21 |
22 | def find_model_using_name(model_name):
23 | """Import the module "models/[model_name]_model.py".
24 | In the file, the class called DatasetNameModel() will
25 | be instantiated. It has to be a subclass of BaseModel,
26 | and it is case-insensitive.
27 | """
28 | model_filename = "model." + model_name + "_model"
29 | modellib = importlib.import_module(model_filename)
30 | model = None
31 | target_model_name = model_name.replace('_', '') + 'model'
32 | for name, cls in modellib.__dict__.items():
33 | if name.lower() == target_model_name.lower() \
34 | and issubclass(cls, BaseModel):
35 | model = cls
36 |
37 | if model is None:
38 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
39 | exit(0)
40 |
41 | return model
42 |
43 |
44 | def get_option_setter(model_name):
45 | """Return the static method of the model class."""
46 | model_class = find_model_using_name(model_name)
47 | return model_class.modify_commandline_options
48 |
49 |
50 | def create_model(opt):
51 | """Create a model given the option.
52 |     This function instantiates the model class specified by opt.model.
53 | This is the main interface between this package and 'train.py'/'test.py'
54 | Example:
55 |         >>> from model import create_model
56 | >>> model = create_model(opt)
57 | """
58 | model = find_model_using_name(opt.model)
59 | instance = model(opt)
60 | print("model [%s] was created" % type(instance).__name__)
61 | return instance
62 |
--------------------------------------------------------------------------------
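The docstring above describes the contract for adding a custom model; a minimal sketch of the 'dummy' example it mentions (a hypothetical model/dummy_model.py; the network, loss, and option names here are illustrative) might look like:

    import torch
    from model.base_model import BaseModel

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = ['l1']        # exposed to logging as self.loss_l1
            self.model_names = ['G']        # save/load machinery expects self.netG
            self.visual_names = ['x', 'y_gen']
            self.netG = torch.nn.Conv2d(3, 3, 3, padding=1).to(self.device)
            self.optimizers = [torch.optim.Adam(self.netG.parameters(), lr=opt.lr)]

        def set_input(self, input):
            self.x = input['x'].to(self.device)
            self.y = input['y'].to(self.device)

        def forward(self):
            self.y_gen = self.netG(self.x)

        def optimize_parameters(self):
            self.forward()
            self.optimizers[0].zero_grad()
            self.loss_l1 = torch.nn.functional.l1_loss(self.y_gen, self.y)
            self.loss_l1.backward()
            self.optimizers[0].step()

With such a file in place, find_model_using_name('dummy') locates DummyModel via the case-insensitive name match, and '--model dummy' selects it. Note that cocos_model.py below additionally calls self.setup(opt) at the end of __init__ to load checkpoints and build schedulers.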
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 |
7 |
8 | class BaseModel(ABC):
9 | """This class is an abstract base class (ABC) for models.
10 | To create a subclass, you need to implement the following five functions:
11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 |     -- <set_input>: unpack data from dataset and apply preprocessing.
13 |     -- <forward>: produce intermediate results.
14 |     -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 | """
17 |
18 | def __init__(self, opt):
19 | """Initialize the BaseModel class.
20 |
21 | Parameters:
22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 |
24 | When creating your custom class, you need to implement your own initialization.
25 |         In this function, you should first call `BaseModel.__init__(self, opt)`
26 | Then, you need to define four lists:
27 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 |             -- self.model_names (str list): define networks used in our training.
29 |             -- self.visual_names (str list): specify the images that you want to display and save.
30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 | """
32 | self.opt = opt
33 | self.gpu_ids = opt.gpu_ids
34 | self.isTrain = opt.isTrain
35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36 | # damn it, build all directories recursively
37 | self.mkdir_recursive(opt.checkpoints_dir, opt.name, 'images')
38 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
39 | self.save_image_dir = os.path.join(self.save_dir, 'images')
40 |
41 | self.loss_names = []
42 | self.model_names = []
43 | self.visual_names = []
44 | self.optimizers = []
45 | self.image_paths = []
46 |
47 | @staticmethod
48 | def mkdir_recursive(*folders):
49 | cur_folder = None
50 | for folder in folders:
51 | cur_folder = folder if cur_folder is None else os.path.join(cur_folder, folder)
52 | os.makedirs(cur_folder, exist_ok=True)
53 |
54 | @staticmethod
55 | def modify_commandline_options(parser, is_train):
56 | """Add new model-specific options, and rewrite default values for existing options.
57 |
58 | Parameters:
59 | parser -- original option parser
60 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
61 |
62 | Returns:
63 | the modified parser.
64 | """
65 | return parser
66 |
67 | @abstractmethod
68 | def set_input(self, input):
69 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
70 |
71 | Parameters:
72 | input (dict): includes the data itself and its metadata information.
73 | """
74 | pass
75 |
76 | @abstractmethod
77 | def forward(self):
78 | """Run forward pass; called by both functions and ."""
79 | pass
80 |
81 | def is_train(self):
82 | """check if the current batch is good for training."""
83 | return True
84 |
85 | @abstractmethod
86 | def optimize_parameters(self):
87 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
88 | pass
89 |
90 | def setup(self, opt):
91 | """Load and print networks; create schedulers
92 |
93 | Parameters:
94 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
95 | """
96 | if self.isTrain:
97 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
98 | if not self.isTrain or opt.continue_train:
99 | self.load_networks(opt.which_epoch)
100 | else:
101 | self.init_networks(opt)
102 | self.print_networks(opt.verbose)
103 |
104 | def init_networks(self, opt):
105 |         print('Initializing networks with %s initialization; training from scratch' % opt.init_type)
106 | for name in self.model_names:
107 | net = getattr(self, 'net' + name)
108 | if isinstance(net, torch.nn.DataParallel):
109 | net = net.module
110 | net.to(self.device)
111 | networks.init_weights(net, opt.init_type, opt.init_gain)
112 |
113 | def eval(self):
114 | """Make models eval mode during test time"""
115 | for name in self.model_names:
116 | if isinstance(name, str):
117 | net = getattr(self, 'net' + name)
118 | net.eval()
119 |
120 | def test(self):
121 | """Forward function used in test time.
122 |
123 |         This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
124 |         It also calls <compute_visuals> to produce additional visualization results.
125 | """
126 | with torch.no_grad():
127 | self.forward()
128 | self.compute_visuals()
129 |
130 | def compute_visuals(self):
131 | """Calculate additional output images for visdom and HTML visualization"""
132 | pass
133 |
134 | def get_image_paths(self):
135 | """ Return image paths that are used to load current data"""
136 | return self.image_paths
137 |
138 | def update_learning_rate(self):
139 | """Update learning rates for all the networks; called at the end of every epoch"""
140 | for scheduler in self.schedulers:
141 | scheduler.step()
142 | lr = self.optimizers[0].param_groups[0]['lr']
143 | print('learning rate = %.7f' % lr)
144 |
145 | def get_current_visuals(self):
146 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
147 | visual_ret = OrderedDict()
148 | for name in self.visual_names:
149 | if isinstance(name, str):
150 | visual_ret[name] = getattr(self, name)
151 | return visual_ret
152 |
153 | def get_current_losses(self):
154 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
155 | errors_ret = OrderedDict()
156 | for name in self.loss_names:
157 | if isinstance(name, str):
158 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
159 | return errors_ret
160 |
161 | def save_networks(self, epoch):
162 | """Save all the networks to the disk.
163 |
164 | Parameters:
165 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
166 | """
167 | for name in self.model_names:
168 | if isinstance(name, str):
169 | save_filename = '%s_net_%s.pth' % (epoch, name)
170 | save_path = os.path.join(self.save_dir, save_filename)
171 | net = getattr(self, 'net' + name)
172 | torch.save(net.state_dict(), save_path)
173 |
174 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
175 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
176 | key = keys[i]
177 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
178 | if module.__class__.__name__.startswith('InstanceNorm') and \
179 | (key == 'running_mean' or key == 'running_var'):
180 | if getattr(module, key) is None:
181 | state_dict.pop('.'.join(keys))
182 | if module.__class__.__name__.startswith('InstanceNorm') and \
183 | (key == 'num_batches_tracked'):
184 | state_dict.pop('.'.join(keys))
185 | else:
186 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
187 |
188 | def load_networks(self, epoch):
189 | """Load all the networks from the disk.
190 |
191 | Parameters:
192 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
193 | """
194 | for name in self.model_names:
195 | if isinstance(name, str):
196 | load_filename = '%s_net_%s.pth' % (epoch, name)
197 | load_path = os.path.join(self.save_dir, load_filename)
198 | net = getattr(self, 'net' + name)
199 | if isinstance(net, torch.nn.DataParallel):
200 | net = net.module
201 | print('loading the model from %s' % load_path)
202 | # if you are using PyTorch newer than 0.4 (e.g., built from
203 | # GitHub source), you can remove str() on self.device
204 | state_dict = torch.load(load_path, map_location=str(self.device))
205 | if hasattr(state_dict, '_metadata'):
206 | del state_dict._metadata
207 |
208 | # patch InstanceNorm checkpoints prior to 0.4
209 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
210 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
211 | net.load_state_dict(state_dict)
212 |
213 | def print_networks(self, verbose):
214 | """Print the total number of parameters in the network and (if verbose) network architecture
215 |
216 | Parameters:
217 | verbose (bool) -- if verbose: print the network architecture
218 | """
219 | print('---------- Networks initialized -------------')
220 | for name in self.model_names:
221 | if isinstance(name, str):
222 | net = getattr(self, 'net' + name)
223 | num_params = 0
224 | for param in net.parameters():
225 | num_params += param.numel()
226 | if verbose:
227 | print(net)
228 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
229 | print('-----------------------------------------------')
230 |
231 | def set_requires_grad(self, nets, requires_grad=False):
232 | """Set requires_grad=False for all the networks to avoid unnecessary computations
233 | Parameters:
234 | nets (network list) -- a list of networks
235 | requires_grad (bool) -- whether the networks require gradients or not
236 | """
237 | if not isinstance(nets, list):
238 | nets = [nets]
239 | for net in nets:
240 | if net is not None:
241 | for param in net.parameters():
242 | param.requires_grad = requires_grad
243 |
--------------------------------------------------------------------------------
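The set_requires_grad helper above is the usual pix2pix-style switch for freezing the discriminator during the generator update. A sketch of that common pattern (note this is not the exact flow of cocos_model.py below, which steps its optimizers inside backward_G/backward_D and does not freeze D):

    def optimize_parameters(self):
        self.forward()
        self.set_requires_grad(self.netD, False)  # G update: no grads into D's parameters
        self.backward_G()
        self.set_requires_grad(self.netD, True)   # D update: re-enable D's gradients
        self.backward_D()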
/model/cocos_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import os
5 | import cv2
6 | import numpy as np
7 | import itertools
8 | from model import networks
9 | from model.base_model import BaseModel
10 | from model.translation_net import TranslationNet
11 | from model.correspondence_net import CorrespondenceNet
12 | from model.discriminator import Discriminator
13 | from model.loss import VGGLoss, GANLoss
14 | '''
15 | Cross-Domain Correspondence Model
16 | '''
17 | class CoCosModel(BaseModel):
18 | @staticmethod
19 | def modify_commandline_options(parser, is_train=True):
20 | return parser
21 |
22 | @staticmethod
23 | def torch2numpy(x):
24 | # from [-1,1] to [0,255]
25 | return ((x.detach().cpu().numpy().transpose(1,2,0) + 1) * 127.5).astype(np.uint8)
26 |
27 | def __name__(self):
28 | return 'CoCosModel'
29 |
30 | def __init__(self, opt):
31 | super().__init__(opt)
32 | self.w = opt.image_size
33 |         # make a folder for saving images
34 | self.image_dir = os.path.join(self.save_dir, 'images')
35 | if not os.path.isdir(self.image_dir):
36 | os.mkdir(self.image_dir)
37 |
38 | # initialize networks
39 | self.model_names = ['C', 'T']
40 | self.netC = CorrespondenceNet(opt)
41 | self.netT = TranslationNet(opt)
42 | if opt.isTrain:
43 | self.model_names.append('D')
44 | self.netD = Discriminator(opt)
45 |
46 |         self.visual_names = ['b_exemplar', 'a', 'b_gen', 'b_gt'] # human pose transfer (HPT) naming convention
47 |
48 | if opt.isTrain:
49 | # assign losses
50 | self.loss_names = ['perc', 'domain', 'feat', 'context', 'reg', 'adv']
51 | self.visual_names += ['b_warp']
52 | self.criterionFeat = torch.nn.L1Loss()
53 | # Both interface for VGG and perceptual loss
54 | # call with different mode and layer params
55 | self.criterionVGG = VGGLoss(self.device)
56 | # Support hinge loss
57 | self.criterionAdv = GANLoss(gan_mode=opt.gan_mode).to(self.device)
58 | self.criterionDomain = nn.L1Loss()
59 | self.criterionReg = torch.nn.L1Loss()
60 |
61 |
62 | # initialize optimizers
63 | gen_params = itertools.chain(self.netT.parameters(), self.netC.parameters())
64 | self.optG = torch.optim.Adam(gen_params, lr=opt.lr, betas=(opt.beta1, 0.999))
65 | self.optD = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
66 | self.optimizers = [self.optG, self.optD]
67 |
68 | # Finally, load checkpoints and recover schedulers
69 | self.setup(opt)
70 | torch.autograd.set_detect_anomaly(True)
71 |
72 | def set_input(self, batch):
73 | # expecting 'a' -> 'b_gt', 'a_exemplar' -> 'b_exemplar', ('b_deform')
74 | # for human pose transfer, 'b_deform' is already 'b_exemplar'
75 | for k, v in batch.items():
76 | setattr(self, k, v.to(self.device))
77 |
78 | def forward(self):
79 | self.sa, self.sb, self.fb_warp, self.b_warp = self.netC(self.a, self.b_exemplar) # 3*HW*HW
80 | self.b_gen = self.netT(self.b_warp)
81 | # self.b_gen = self.netT(self.fb_warp) retain original feature or use warped rgb?
82 |
83 | # TODO: Implement backward warping (maybe we should adjust the input size?)
84 | _, _, _, self.b_reg = self.netC(self.a_exemplar,
85 | F.interpolate(self.b_warp, (self.w, self.w), mode='bilinear')
86 | )
87 | #print(self.b_gen.shape, self.b_reg.shape, self.b_gt.shape)
88 |
89 | def test(self):
90 | with torch.no_grad():
91 | _, _, _, self.b_warp = self.netC(self.a, self.b_exemplar) # 3*HW*HW
92 | self.b_gen = self.netT(self.b_warp)
93 |
94 | def backward_G(self):
95 | self.optG.zero_grad()
96 | # Damn, do we really need 6 losses?
97 |         # 1. Perceptual loss (abandoned for human pose transfer; its role is covered by the feature loss below)
98 | self.loss_perc = 0
99 | # 2. domain loss
100 | self.loss_domain = self.opt.lambda_domain * self.criterionDomain(self.sa, self.sb)
101 | # 3. losses for pseudo exemplar pairs
102 | self.loss_feat = self.opt.lambda_feat * self.criterionVGG(self.b_gen, self.b_gt, mode='perceptual')
103 |         # 4. Contextual loss
104 | self.loss_context = self.opt.lambda_context * self.criterionVGG(self.b_gen, self.b_exemplar, mode='contextual', layers=[2,3,4,5])
105 | # 5. Reg loss
106 | b_exemplar_small = F.interpolate(self.b_exemplar, self.b_reg.size()[2:], mode='bilinear')
107 | self.loss_reg = self.opt.lambda_reg * self.criterionReg(self.b_reg, b_exemplar_small)
108 | # 6. GAN loss
109 |         pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)  # note: discriminate() returns (fake, real)
110 | self.loss_adv = self.opt.lambda_adv * self.criterionAdv(pred_fake, True, for_discriminator=False)
111 |
112 | g_loss = self.loss_perc + self.loss_domain + self.loss_feat \
113 | + self.loss_context + self.loss_reg + self.loss_adv
114 |
115 | g_loss.backward()
116 | self.optG.step()
117 |
118 | def discriminate(self, real, fake):
119 | fake_and_real = torch.cat([fake, real], dim=0)
120 | discriminator_out = self.netD(fake_and_real)
121 | pred_fake, pred_real = self.divide_pred(discriminator_out)
122 |
123 | return pred_fake, pred_real
124 |
125 | # Take the prediction of fake and real images from the combined batch
126 | def divide_pred(self, pred):
127 | # the prediction contains the intermediate outputs of multiscale GAN,
128 | # so it's usually a list
129 | if isinstance(pred, list):
130 | fake = [p[:p.size(0) // 2] for p in pred]
131 | real = [p[p.size(0) // 2:] for p in pred]
132 | else:
133 | fake = pred[:pred.size(0) // 2]
134 | real = pred[pred.size(0) // 2:]
135 |
136 | return fake, real
137 |
138 | def backward_D(self):
139 | self.optD.zero_grad()
140 |         # re-run generation under no_grad so the D update does not backprop into the generator
141 | self.test()
142 |
143 | pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
144 |
145 | self.d_fake = self.criterionAdv(pred_fake, False, for_discriminator=True)
146 | self.d_real = self.criterionAdv(pred_real, True, for_discriminator=True)
147 |
148 | d_loss = (self.d_fake + self.d_real) / 2
149 | d_loss.backward()
150 | self.optD.step()
151 |
152 | def optimize_parameters(self):
153 | # must call self.set_input(data) first
154 | self.forward()
155 | self.backward_G()
156 | self.backward_D()
157 |
158 | ### Standalone utility functions
159 | def log_loss(self, epoch, iter):
160 | msg = 'Epoch %d iter %d\n ' % (epoch, iter)
161 | for name in self.loss_names:
162 | val = getattr(self, 'loss_%s' % name)
163 | if isinstance(val, torch.cuda.FloatTensor):
164 | val = val.item()
165 | msg += '%s: %.4f, ' % (name, val)
166 | print(msg)
167 |
168 | def log_visual(self, epoch, iter):
169 | save_path = os.path.join(self.save_image_dir, 'epoch%03d_iter%05d.png' % (epoch, iter))
170 | # warped image is not the same resolution, need scaling
171 | self.b_warp = F.interpolate(self.b_warp, (self.w, self.w), mode='bicubic')
172 | pack = torch.cat(
173 | [getattr(self, name) for name in self.visual_names], dim=3
174 | )[0] # only save one example
175 | cv2.imwrite(save_path, self.torch2numpy(pack))
176 |         cv2.imwrite(save_path.replace('.png', '_b_ex.png'), self.torch2numpy(self.b_exemplar[0]))  # prefixing 'b_ex' to the full path would point at a nonexistent directory
177 |
178 | def update_learning_rate(self):
179 | '''
180 | Update learning rates for all the networks;
181 | called at the end of every epoch by train.py
182 | '''
183 | for scheduler in self.schedulers:
184 | scheduler.step()
185 | lr = self.optimizers[0].param_groups[0]['lr']
186 | print('learning rate updated to %.7f' % lr)
187 |
--------------------------------------------------------------------------------
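A hedged sketch of how train.py presumably drives this model; the option object, loop bounds, dataloader, and log_freq below are placeholders rather than the actual train.py contents:

    from model import create_model

    model = create_model(opt)                   # -> CoCosModel for '--model cocos'
    for epoch in range(n_epochs):               # placeholder bound
        for i, batch in enumerate(dataloader):  # batch keys: 'a', 'a_exemplar', 'b_exemplar', 'b_gt'
            model.set_input(batch)
            model.optimize_parameters()
            if i % log_freq == 0:               # placeholder frequency
                model.log_loss(epoch, i)
                model.log_visual(epoch, i)
        model.update_learning_rate()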
/model/contextual_loss.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/roimehrez/contextualLoss/blob/master/CX/CX_distance.py
3 | '''
4 | import torch
5 | import numpy as np
6 |
7 | class TensorAxis:
8 | N = 0
9 | H = 1
10 | W = 2
11 | C = 3
12 |
13 |
14 | class CSFlow:
15 | def __init__(self, sigma=float(0.1), b=float(1.0)):
16 | self.b = b
17 | self.sigma = sigma
18 |
19 | def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C):
20 | self.scaled_distances = scaled_distances
21 | self.cs_weights_before_normalization = torch.exp((self.b - scaled_distances) / self.sigma)
22 | # self.cs_weights_before_normalization = 1 / (1 + scaled_distances)
23 | # self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
24 | self.cs_NHWC = self.cs_weights_before_normalization
25 |
26 | # def reversed_direction_CS(self):
27 | # cs_flow_opposite = CSFlow(self.sigma, self.b)
28 | # cs_flow_opposite.raw_distances = self.raw_distances
29 | # work_axis = [TensorAxis.H, TensorAxis.W]
30 | # relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
31 | # cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
32 | # return cs_flow_opposite
33 |
34 | # --
35 | @staticmethod
36 | def create_using_L2(I_features, T_features, sigma=float(0.5), b=float(1.0)):
37 | cs_flow = CSFlow(sigma, b)
38 | sT = T_features.shape
39 | sI = I_features.shape
40 |
41 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
42 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
43 | r_Ts = torch.sum(Tvecs * Tvecs, 2)
44 | r_Is = torch.sum(Ivecs * Ivecs, 2)
45 | raw_distances_list = []
46 | for i in range(sT[0]):
47 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
48 | A = Tvec @ torch.transpose(Ivec, 0, 1) # (matrix multiplication)
49 | cs_flow.A = A
50 | # A = tf.matmul(Tvec, tf.transpose(Ivec))
51 | r_T = torch.reshape(r_T, [-1, 1]) # turn to column vector
52 | dist = r_T - 2 * A + r_I
53 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0]))
54 | # protecting against numerical problems, dist should be positive
55 | dist = torch.clamp(dist, min=float(0.0))
56 | # dist = tf.sqrt(dist)
57 | raw_distances_list += [dist]
58 |
59 | cs_flow.raw_distances = torch.cat(raw_distances_list)
60 |
61 | relative_dist = cs_flow.calc_relative_distances()
62 | cs_flow.__calculate_CS(relative_dist)
63 | return cs_flow
64 |
65 | # --
66 | @staticmethod
67 | def create_using_L1(I_features, T_features, sigma=float(0.5), b=float(1.0)):
68 | cs_flow = CSFlow(sigma, b)
69 | sT = T_features.shape
70 | sI = I_features.shape
71 |
72 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
73 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
74 | raw_distances_list = []
75 | for i in range(sT[0]):
76 | Ivec, Tvec = Ivecs[i], Tvecs[i]
77 | dist = torch.abs(torch.sum(Ivec.unsqueeze(1) - Tvec.unsqueeze(0), dim=2))
78 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0]))
79 | # protecting against numerical problems, dist should be positive
80 | dist = torch.clamp(dist, min=float(0.0))
81 | # dist = tf.sqrt(dist)
82 | raw_distances_list += [dist]
83 |
84 | cs_flow.raw_distances = torch.cat(raw_distances_list)
85 |
86 | relative_dist = cs_flow.calc_relative_distances()
87 | cs_flow.__calculate_CS(relative_dist)
88 | return cs_flow
89 |
90 | # --
91 | @staticmethod
92 | def create_using_dotP(I_features, T_features, sigma=float(0.5), b=float(1.0)):
93 | cs_flow = CSFlow(sigma, b)
94 | # prepare feature before calculating cosine distance
95 | T_features, I_features = cs_flow.center_by_T(T_features, I_features)
96 | T_features = CSFlow.l2_normalize_channelwise(T_features)
97 | I_features = CSFlow.l2_normalize_channelwise(I_features)
98 |
99 |         # work separately for each example in the batch
100 | cosine_dist_l = []
101 | N = T_features.size()[0]
102 | for i in range(N):
103 |             T_features_i = T_features[i, :, :, :].unsqueeze_(0)  # stays 1HWC
104 |             I_features_i = I_features[i, :, :, :].unsqueeze_(0).permute((0, 3, 1, 2))  # 1HWC --> 1CHW for conv2d
105 | patches_PC11_i = cs_flow.patch_decomposition(T_features_i) # 1HWC --> PC11, with P=H*W
106 | cosine_dist_i = torch.nn.functional.conv2d(I_features_i, patches_PC11_i)
107 |             cosine_dist_1HWC = cosine_dist_i.permute((0, 2, 3, 1))  # back to 1HWC
108 |             cosine_dist_l.append(cosine_dist_1HWC)
109 |
110 | cs_flow.cosine_dist = torch.cat(cosine_dist_l, dim=0)
111 |
112 |         cs_flow.raw_distances = - (cs_flow.cosine_dist - 1) / 2  # map cosine similarity in [-1, 1] to a distance in [0, 1]
113 |
114 | relative_dist = cs_flow.calc_relative_distances()
115 | cs_flow.__calculate_CS(relative_dist)
116 | return cs_flow
117 |
118 | def calc_relative_distances(self, axis=TensorAxis.C):
119 | epsilon = 1e-5
120 | div = torch.min(self.raw_distances, dim=axis, keepdim=True)[0]
121 | relative_dist = self.raw_distances / (div + epsilon)
122 | return relative_dist
123 |
124 | @staticmethod
125 | def sum_normalize(cs, axis=TensorAxis.C):
126 | reduce_sum = torch.sum(cs, dim=axis, keepdim=True)
127 | cs_normalize = torch.div(cs, reduce_sum)
128 | return cs_normalize
129 |
130 | def center_by_T(self, T_features, I_features):
131 | # assuming both input are of the same size
132 |         # calculate statistics over [batch, height, width], expecting a 1x1x1xC tensor
133 | axes = [0, 1, 2]
134 | self.meanT = T_features.mean(0, keepdim=True).mean(1, keepdim=True).mean(2, keepdim=True)
135 | self.varT = T_features.var(0, keepdim=True).var(1, keepdim=True).var(2, keepdim=True)
136 | self.T_features_centered = T_features - self.meanT
137 | self.I_features_centered = I_features - self.meanT
138 |
139 | return self.T_features_centered, self.I_features_centered
140 |
141 | @staticmethod
142 | def l2_normalize_channelwise(features):
143 | norms = features.norm(p=2, dim=TensorAxis.C, keepdim=True)
144 | features = features.div(norms)
145 | return features
146 |
147 | def patch_decomposition(self, T_features):
148 | # 1HWC --> 11PC --> PC11, with P=H*W
149 | (N, H, W, C) = T_features.shape
150 | P = H * W
151 | patches_PC11 = T_features.reshape(shape=(1, 1, P, C)).permute(dims=(2, 3, 0, 1))
152 | return patches_PC11
153 |
154 | @staticmethod
155 | def pdist2(x, keepdim=False):
156 | sx = x.shape
157 | x = x.reshape(shape=(sx[0], sx[1] * sx[2], sx[3]))
158 | differences = x.unsqueeze(2) - x.unsqueeze(1)
159 | distances = torch.sum(differences**2, -1)
160 | if keepdim:
161 | distances = distances.reshape(shape=(sx[0], sx[1], sx[2], sx[3]))
162 | return distances
163 |
164 | @staticmethod
165 | def calcR_static(sT, order='C', deformation_sigma=0.05):
166 |         # order can be 'C' or 'F' (MATLAB order)
167 | pixel_count = sT[0] * sT[1]
168 |
169 | rangeRows = range(0, sT[1])
170 | rangeCols = range(0, sT[0])
171 | Js, Is = np.meshgrid(rangeRows, rangeCols)
172 | row_diff_from_first_row = Is
173 | col_diff_from_first_col = Js
174 |
175 | row_diff_from_first_row_3d_repeat = np.repeat(row_diff_from_first_row[:, :, np.newaxis], pixel_count, axis=2)
176 | col_diff_from_first_col_3d_repeat = np.repeat(col_diff_from_first_col[:, :, np.newaxis], pixel_count, axis=2)
177 |
178 | rowDiffs = -row_diff_from_first_row_3d_repeat + row_diff_from_first_row.flatten(order).reshape(1, 1, -1)
179 | colDiffs = -col_diff_from_first_col_3d_repeat + col_diff_from_first_col.flatten(order).reshape(1, 1, -1)
180 | R = rowDiffs ** 2 + colDiffs ** 2
181 | R = R.astype(np.float32)
182 | R = np.exp(-(R) / (2 * deformation_sigma ** 2))
183 | return R
184 |
185 |
186 |
187 |
188 |
189 |
190 | # --------------------------------------------------
191 | # CX loss
192 | # --------------------------------------------------
193 |
194 |
195 |
196 | def CX_loss(T_features, I_features, deformation=False, dis=False):
197 | # T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
198 | # I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
199 | # since this is a conversion of TensorFlow code to PyTorch, we permute the tensors from NCHW to NHWC
200 | # T_features = normalize_tensor(T_features)
201 | # I_features = normalize_tensor(I_features)
202 |
203 | # since this was originally a TensorFlow implementation,
204 | # we arrange all tensors in the TF convention (NHWC) rather than the PyTorch convention (NCHW).
205 | def from_pt2tf(Tpt):
206 | Ttf = Tpt.permute(0, 2, 3, 1)
207 | return Ttf
208 | # N x C x H x W --> N x H x W x C
209 | T_features_tf = from_pt2tf(T_features)
210 | I_features_tf = from_pt2tf(I_features)
211 |
212 | # cs_flow = CSFlow.create_using_dotP(I_features_tf, T_features_tf, sigma=1.0)
213 | cs_flow = CSFlow.create_using_L2(I_features_tf, T_features_tf, sigma=1.0)
214 |     # note: sum_normalize is disabled inside __calculate_CS,
215 |     # so cs holds unnormalized contextual similarity weights
216 | cs = cs_flow.cs_NHWC
217 |
218 | if deformation:
219 | deforma_sigma = 0.001
220 |         sT = T_features_tf.shape[1:3]  # (H, W)
221 | R = CSFlow.calcR_static(sT, deformation_sigma=deforma_sigma)
222 | cs *= torch.Tensor(R).unsqueeze(dim=0).cuda()
223 |
224 | if dis:
225 | CS = []
226 | k_max_NC = torch.max(torch.max(cs, dim=1)[1], dim=1)[1]
227 | indices = k_max_NC.cpu()
228 | N, C = indices.shape
229 | for i in range(N):
230 | CS.append((C - len(torch.unique(indices[i, :]))) / C)
231 | score = torch.FloatTensor(CS)
232 | else:
233 | # reduce_max X and Y dims
234 | # cs = CSFlow.pdist2(cs,keepdim=True)
235 | k_max_NC = torch.max(torch.max(cs, dim=1)[0], dim=1)[0]
236 | # reduce mean over C dim
237 | CS = torch.mean(k_max_NC, dim=1)
238 | # score = 1/CS
239 | # score = torch.exp(-CS*10)
240 | score = -torch.log(CS)
241 | # reduce mean over N dim
242 | # CX_loss = torch.mean(CX_loss)
243 | return score
244 |
245 |
246 | def symmetric_CX_loss(T_features, I_features):
247 | score = (CX_loss(T_features, I_features) + CX_loss(I_features, T_features)) / 2
248 | return score
--------------------------------------------------------------------------------
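A toy shape check for the CX loss above, using random stand-ins for the VGG activations that VGGLoss in model/loss.py actually feeds it. Inputs follow PyTorch's NCHW layout; CX_loss permutes to NHWC internally:

    import torch
    from model.contextual_loss import CX_loss, symmetric_CX_loss

    T = torch.rand(2, 64, 16, 16)   # 'target' features
    I = torch.rand(2, 64, 16, 16)   # 'input' features, same shape
    print(CX_loss(T, I).shape)      # torch.Size([2]): one -log similarity score per sample
    print(symmetric_CX_loss(T, I))  # average of both directions; lower means a better match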
/model/correspondence_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import model.networks
5 | import itertools
6 |
7 | '''
8 | CorrespondenceNet: Align images in different domains
9 | into a shared domain S, and compute the correlation
10 | matrix (vectorized)
11 |
12 | Note that a is guidance, b is exemplar
13 | e.g. for human pose transfer
14 | a is the target pose, b is the source image
15 |
16 | output: b_warp: a 3*H*W image
17 | -----------------
18 | # TODO: Add synchronized batch norm to support multi-GPU training
19 |
20 | 20200525: Potential bug: insufficient memory for a 4096*4096 correspondence matrix; fall back to 1024*1024 instead
21 | '''
22 | class CorrespondenceNet(nn.Module):
23 | def __init__(self, opt):
24 | super().__init__()
25 | print('Making a CorrespondenceNet')
26 | # domain adaptors are not shared
27 | ngf = opt.ngf
28 | self.domainA_adaptor = self.create_adaptor(opt.ncA, ngf)
29 | self.domainB_adaptor = self.create_adaptor(opt.ncB, ngf)
30 | self.softmax_alpha = 100
31 | ada_blocks = []
32 | for i in range(4):
33 | ada_blocks += [BasicBlock(ngf*4, ngf*4)]
34 |
35 | ada_blocks += [nn.Conv2d(ngf*4, ngf*4, kernel_size=1, stride=1, padding=0)]
36 | self.adaptive_feature_block = nn.Sequential(*ada_blocks)
37 |
38 | self.to_rgb = nn.Conv2d(ngf*4, 3, kernel_size=1, stride=1, padding=0)
39 |
40 | @staticmethod
41 | def warp(fa, fb, b_raw, alpha):
42 | '''
43 | calculate correspondence matrix and warp the exemplar features
44 | '''
45 |         assert fa.shape == fb.shape, \
46 |             'Feature shape must match. Got %s in fa and %s in fb.' % (fa.shape, fb.shape)
47 | n,c,h,w = fa.shape
48 | # subtract mean
49 | fa = fa - torch.mean(fa, dim=(2,3), keepdim=True)
50 | fb = fb - torch.mean(fb, dim=(2,3), keepdim=True)
51 |
52 | # vectorize (merge dim H, W) and normalize channelwise vectors
53 | fa = fa.view(n, c, -1)
54 | fb = fb.view(n, c, -1)
55 | fa = fa / torch.norm(fa, dim=1, keepdim=True)
56 | fb = fb / torch.norm(fb, dim=1, keepdim=True)
57 |
58 | # correlation matrix, gonna be huge (4096*4096)
59 | # use matrix multiplication for CUDA speed up
60 | # Also, calculate the transpose of the atob correlation
61 |
62 | # warp the exemplar features b, taking softmax along the b dimension
63 |         corr_ab_T = F.softmax(alpha * torch.bmm(fb.transpose(-2,-1), fa), dim=2) # temperature-scaled softmax; n*HW*C @ n*C*HW -> n*HW*HW (alpha was previously accepted but unused)
64 | #print(corr_ab_T.shape)
65 | #print(softmax_weights.shape, b_raw.shape)
66 |         b_warp = torch.bmm(b_raw.view(n, c, h*w), corr_ab_T) # n*C*HW @ n*HW*HW -> n*C*HW
67 | return b_warp.view(n,c,h,w)
68 |
69 | def create_adaptor(self, nc, ngf):
70 | model_parts = [self.combo(nc, ngf, 3, 1, 1),
71 | self.combo(ngf, ngf*2, 4, 2, 1),
72 | self.combo(ngf*2, ngf*4, 3, 1, 1),
73 | self.combo(ngf*4, ngf*8, 4, 2, 1),
74 | self.combo(ngf*8, ngf*8, 3, 1, 1),
75 | # The following line shrinks the spatial dimension to 32*32
76 | self.combo(ngf*8, ngf*8, 4, 2, 1),
77 | [BasicBlock(ngf*8, ngf*4)],
78 | [BasicBlock(ngf*4, ngf*4)],
79 | [BasicBlock(ngf*4, ngf*4)]
80 | ]
81 | model = itertools.chain(*model_parts)
82 | return nn.Sequential(*model)
83 |
84 | def combo(self, cin, cout, kw, stride, padw):
85 | layers = [
86 | nn.Conv2d(cin, cout, kernel_size=kw, stride=stride, padding=padw),
87 | nn.InstanceNorm2d(cout),
88 | nn.LeakyReLU(0.2),
89 | ]
90 | return layers
91 |
92 | def forward(self, a, b):
93 | sa = self.domainA_adaptor(a)
94 | sb = self.domainB_adaptor(b)
95 | fa = self.adaptive_feature_block(sa)
96 | fb = self.adaptive_feature_block(sb)
97 | # This should be sb, but who knows?
98 | b_warp = self.warp(fa, fb, b_raw=sb, alpha=self.softmax_alpha)
99 |         b_img = torch.tanh(self.to_rgb(b_warp))  # F.tanh is deprecated
100 | return sa, sb, b_warp, b_img
101 |
102 | # Basic residual block
103 | class BasicBlock(nn.Module):
104 | def __init__(self, cin, cout):
105 | super(BasicBlock, self).__init__()
106 | layers = [
107 | nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1),
108 | nn.InstanceNorm2d(cout),
109 | nn.LeakyReLU(0.2),
110 | nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1),
111 | nn.InstanceNorm2d(cout),
112 | ]
113 | self.conv = nn.Sequential(*layers)
114 | if cin != cout:
115 | self.shortcut = nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0)
116 | else:
117 | self.shortcut = lambda x:x
118 |
119 | def forward(self, x):
120 | out = self.conv(x) + self.shortcut(x)
121 | return out
122 |
--------------------------------------------------------------------------------
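Since warp is a @staticmethod of pure tensor algebra, its shape behavior can be checked without any trained weights; a small sketch:

    import torch
    from model.correspondence_net import CorrespondenceNet

    n, c, h, w = 1, 8, 4, 4
    fa = torch.rand(n, c, h, w)      # adapted guidance features
    fb = torch.rand(n, c, h, w)      # adapted exemplar features
    b_raw = torch.rand(n, c, h, w)   # exemplar features to be warped
    b_warp = CorrespondenceNet.warp(fa, fb, b_raw, alpha=100)
    print(b_warp.shape)  # torch.Size([1, 8, 4, 4]); each output position is a
                         # softmax-weighted mixture over all h*w exemplar positions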
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model.networks
4 |
5 |
6 | class Discriminator(nn.Module):
7 | def __init__(self, opt):
8 | super(Discriminator, self).__init__()
9 | print('Making a discriminator')
10 | input_nc = opt.ncB
11 | ndf = opt.ndf
12 | n_layers = opt.nd_layers
13 | self.num_D = opt.numD
14 | norm_layer = nn.BatchNorm2d
15 |
16 | if self.num_D == 1:
17 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
18 | self.model = nn.Sequential(*layers)
19 | else:
20 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
21 | self.add_module("model_0", nn.Sequential(*layers))
22 | self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
23 | for i in range(1, self.num_D):
24 | ndf_i = int(round(ndf / (2**i)))
25 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
26 | self.add_module("model_%d" % i, nn.Sequential(*layers))
27 |
28 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
29 | kw = 4
30 | padw = 1
31 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
32 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
33 |
34 | nf_mult = 1
35 | nf_mult_prev = 1
36 | for n in range(1, n_layers):
37 | nf_mult_prev = nf_mult
38 | nf_mult = min(2**n, 8)
39 | sequence += [
40 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
41 | kernel_size=kw, stride=2, padding=padw),
42 | norm_layer(ndf * nf_mult),
43 | nn.LeakyReLU(0.2, True)
44 | ]
45 |
46 | nf_mult_prev = nf_mult
47 | nf_mult = min(2**n_layers, 8)
48 | sequence += [
49 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
50 | kernel_size=kw, stride=1, padding=padw),
51 | norm_layer(ndf * nf_mult),
52 | nn.LeakyReLU(0.2, True)
53 | ]
54 |
55 | sequence += [nn.Conv2d(ndf * nf_mult, 1,
56 | kernel_size=kw, stride=1, padding=padw)]
57 |
58 | return sequence
59 |
60 | def forward(self, input):
61 | if self.num_D == 1:
62 | return self.model(input)
63 | result = []
64 | down = input
65 | for i in range(self.num_D):
66 | model = getattr(self, "model_%d" % i)
67 | result.append(model(down))
68 | if i != self.num_D - 1:
69 | down = self.down(down)
70 | return result
--------------------------------------------------------------------------------
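A quick shape check for the multiscale PatchGAN above; the option values below are illustrative, not the project defaults from options.py:

    import torch
    from types import SimpleNamespace
    from model.discriminator import Discriminator

    opt = SimpleNamespace(ncB=3, ndf=64, nd_layers=3, numD=2)
    netD = Discriminator(opt)
    preds = netD(torch.rand(1, 3, 128, 128))
    for p in preds:     # one patch map per scale; the input is halved between scales
        print(p.shape)  # torch.Size([1, 1, 14, 14]), then torch.Size([1, 1, 6, 6])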
/model/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torchvision.models import vgg19
10 | from model.contextual_loss import symmetric_CX_loss
11 |
12 |
13 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
14 | # When LSGAN is used, it is basically same as MSELoss,
15 | # but it abstracts away the need to create the target label tensor
16 | # that has the same size as the input
17 | class GANLoss(nn.Module):
18 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
19 | tensor=torch.cuda.FloatTensor, opt=None):
20 | super(GANLoss, self).__init__()
21 | self.real_label = target_real_label
22 | self.fake_label = target_fake_label
23 | self.real_label_tensor = None
24 | self.fake_label_tensor = None
25 | self.zero_tensor = None
26 | self.Tensor = tensor
27 | self.gan_mode = gan_mode
28 | self.opt = opt
29 | if gan_mode == 'ls':
30 | pass
31 | elif gan_mode == 'original':
32 | pass
33 | elif gan_mode == 'w':
34 | pass
35 | elif gan_mode == 'hinge':
36 | pass
37 | else:
38 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
39 |
40 | def get_target_tensor(self, input, target_is_real):
41 | if target_is_real:
42 | if self.real_label_tensor is None:
43 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
44 | self.real_label_tensor.requires_grad_(False)
45 | return self.real_label_tensor.expand_as(input)
46 | else:
47 | if self.fake_label_tensor is None:
48 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
49 | self.fake_label_tensor.requires_grad_(False)
50 | return self.fake_label_tensor.expand_as(input)
51 |
52 | def get_zero_tensor(self, input):
53 | if self.zero_tensor is None:
54 | self.zero_tensor = self.Tensor(1).fill_(0)
55 | self.zero_tensor.requires_grad_(False)
56 | return self.zero_tensor.expand_as(input)
57 |
58 | def loss(self, input, target_is_real, for_discriminator=True):
59 | if self.gan_mode == 'original': # cross entropy loss
60 | target_tensor = self.get_target_tensor(input, target_is_real)
61 | loss = F.binary_cross_entropy_with_logits(input, target_tensor)
62 | return loss
63 | elif self.gan_mode == 'ls':
64 | target_tensor = self.get_target_tensor(input, target_is_real)
65 | return F.mse_loss(input, target_tensor)
66 | elif self.gan_mode == 'hinge':
67 | if for_discriminator:
68 | if target_is_real:
69 | minval = torch.min(input - 1, self.get_zero_tensor(input))
70 | loss = -torch.mean(minval)
71 | else:
72 | minval = torch.min(-input - 1, self.get_zero_tensor(input))
73 | loss = -torch.mean(minval)
74 | else:
75 | assert target_is_real, "The generator's hinge loss must be aiming for real"
76 | loss = -torch.mean(input)
77 | return loss
78 | else:
79 | # wgan
80 | if target_is_real:
81 | return -input.mean()
82 | else:
83 | return input.mean()
84 |
85 | def __call__(self, input, target_is_real, for_discriminator=True):
86 | # computing loss is a bit complicated because |input| may not be
87 | # a tensor, but list of tensors in case of multiscale discriminator
88 | if isinstance(input, list):
89 | loss = 0
90 | for pred_i in input:
91 | if isinstance(pred_i, list):
92 | pred_i = pred_i[-1]
93 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
94 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
95 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
96 | loss += new_loss
97 | return loss / len(input)
98 | else:
99 | return self.loss(input, target_is_real, for_discriminator)
100 |
101 |
102 | # Perceptual loss and contextual loss that both
103 | # use a pretrained VGG network to extract features
104 | # To calculate different losses, assign mode when calling it
105 | class VGGLoss(nn.Module):
106 | def __init__(self, device, active_layers=None):
107 | super(VGGLoss, self).__init__()
108 | self.vgg = VGG19().to(device)
109 | self.criterion_perceptual = nn.L1Loss()
110 | self.criterion_contextual = symmetric_CX_loss
111 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
112 |
113 | def forward(self, x, y, mode='perceptual', layers=None):
114 | '''
115 |         Control feature usage via the `layers` argument.
116 |         Say you only want to compute relu4_2:
117 |         set layers = [4]
118 |         Or, to include relu2_2 through relu5_2:
119 |         set layers = [2,3,4,5]
120 | '''
121 | criterion = getattr(self, 'criterion_%s' % mode)
122 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
123 | loss = 0
124 | if layers is None:
125 | layers = range(len(x_vgg))
126 | else:
127 | layers = [l-1 for l in layers] # 0-index
128 | for i in layers:
129 | #print(i, x_vgg[i].shape, y_vgg[i].shape)
130 | loss += self.weights[i] * criterion(x_vgg[i], y_vgg[i].detach())
131 |
132 | return loss
133 |
134 |
135 | # VGG architecter, used for the perceptual loss using a pretrained VGG network
136 | class VGG19(torch.nn.Module):
137 | def __init__(self, requires_grad=False):
138 | super().__init__()
139 | vgg_pretrained_features = vgg19(pretrained=True).features
140 | self.slice1 = torch.nn.Sequential()
141 | self.slice2 = torch.nn.Sequential()
142 | self.slice3 = torch.nn.Sequential()
143 | self.slice4 = torch.nn.Sequential()
144 | self.slice5 = torch.nn.Sequential()
145 | for x in range(2):
146 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
147 | for x in range(2, 7):
148 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
149 | for x in range(7, 12):
150 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
151 | for x in range(12, 21):
152 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
153 | for x in range(21, 30):
154 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
155 | if not requires_grad:
156 | for param in self.parameters():
157 | param.requires_grad = False
158 |
159 | def forward(self, X):
160 | h_relu1 = self.slice1(X)
161 | h_relu2 = self.slice2(h_relu1)
162 | h_relu3 = self.slice3(h_relu2)
163 | h_relu4 = self.slice4(h_relu3)
164 | h_relu5 = self.slice5(h_relu4)
165 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
166 | return out
167 |
--------------------------------------------------------------------------------
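A short sanity check of the hinge objective in the GANLoss above; since its default tensor type is torch.cuda.FloatTensor, this sketch assumes a CUDA device:

    import torch
    from model.loss import GANLoss

    crit = GANLoss(gan_mode='hinge')
    pred = [torch.randn(2, 1, 14, 14, device='cuda')]   # multiscale-style list of patch maps
    d_real = crit(pred, True, for_discriminator=True)   # mean(relu(1 - D(x)))
    d_fake = crit(pred, False, for_discriminator=True)  # mean(relu(1 + D(G(z))))
    g_adv = crit(pred, True, for_discriminator=False)   # -mean(D(G(z)))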
/model/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 | ###############################################################################
8 | # Helper functions
9 | ###############################################################################
10 |
11 |
12 | def init_weights(net, init_type='normal', init_gain=0.02):
13 | """Initialize network weights.
14 | Parameters:
15 | net (network) -- network to be initialized
16 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
17 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
18 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
19 | work better for some applications. Feel free to try yourself.
20 | """
21 | def init_func(m): # define the initialization function
22 | classname = m.__class__.__name__
23 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
24 | if init_type == 'normal':
25 | init.normal_(m.weight.data, 0.0, init_gain)
26 | elif init_type == 'xavier':
27 | init.xavier_normal_(m.weight.data, gain=init_gain)
28 | elif init_type == 'kaiming':
29 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
30 | elif init_type == 'orthogonal':
31 | init.orthogonal_(m.weight.data, gain=init_gain)
32 | else:
33 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
34 | if hasattr(m, 'bias') and m.bias is not None:
35 | init.constant_(m.bias.data, 0.0)
36 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
37 | init.normal_(m.weight.data, 1.0, init_gain)
38 | init.constant_(m.bias.data, 0.0)
39 |
40 | print('initialize network with %s' % init_type)
41 | net.apply(init_func) # apply the initialization function
42 |
43 |
44 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
45 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
46 | Parameters:
47 | net (network) -- the network to be initialized
48 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
49 | gain (float) -- scaling factor for normal, xavier and orthogonal.
50 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
51 | Return an initialized network.
52 | """
53 | if len(gpu_ids) > 0:
54 | assert(torch.cuda.is_available())
55 | net.to(gpu_ids[0])
56 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
57 | init_weights(net, init_type, init_gain=init_gain)
58 | return net
59 |
60 |
61 | def get_scheduler(optimizer, opt):
62 | """Return a learning rate scheduler
63 | Parameters:
64 | optimizer -- the optimizer of the network
65 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
66 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
67 | For 'linear', we keep the same learning rate for the first epochs
68 | and linearly decay the rate to zero over the next epochs.
69 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
70 | See https://pytorch.org/docs/stable/optim.html for more details.
71 | """
72 | if opt.lr_policy == 'linear':
73 | def lambda_rule(epoch):
74 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
75 | return lr_l
76 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
77 | elif opt.lr_policy == 'step':
78 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
79 | elif opt.lr_policy == 'plateau':
80 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
81 | elif opt.lr_policy == 'cosine':
82 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
83 | else:
84 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
85 | return scheduler
86 |
87 |
88 | def get_norm_layer(norm_type='instance'):
89 | """Return a normalization layer
90 | Parameters:
91 | norm_type (str) -- the name of the normalization layer: batch | instance | none
92 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
93 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
94 | """
95 | if norm_type == 'batch':
96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
97 | elif norm_type == 'instance':
98 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
99 | elif norm_type == 'none':
100 | norm_layer = None
101 | else:
102 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
103 | return norm_layer
104 |
105 |
106 | def get_non_linearity(layer_type='relu'):
107 | if layer_type == 'relu':
108 | nl_layer = functools.partial(nn.ReLU, inplace=True)
109 | elif layer_type == 'lrelu':
110 | nl_layer = functools.partial(
111 | nn.LeakyReLU, negative_slope=0.2, inplace=True)
112 | elif layer_type == 'elu':
113 | nl_layer = functools.partial(nn.ELU, inplace=True)
114 | else:
115 | raise NotImplementedError(
116 |             'nonlinearity activation [%s] is not found' % layer_type)
117 | return nl_layer
118 |
119 |
120 | def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu',
121 | use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
122 | net = None
123 | norm_layer = get_norm_layer(norm_type=norm)
124 | nl_layer = get_non_linearity(layer_type=nl)
125 |
126 | if nz == 0:
127 | where_add = 'input'
128 |
129 | if netG == 'unet_128' and where_add == 'input':
130 | net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
131 | use_dropout=use_dropout, upsample=upsample)
132 | elif netG == 'unet_256' and where_add == 'input':
133 | net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
134 | use_dropout=use_dropout, upsample=upsample)
135 | elif netG == 'unet_128' and where_add == 'all':
136 | net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
137 | use_dropout=use_dropout, upsample=upsample)
138 | elif netG == 'unet_256' and where_add == 'all':
139 | net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
140 | use_dropout=use_dropout, upsample=upsample)
141 | else:
142 |         raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
143 |
144 | return init_net(net, init_type, init_gain, gpu_ids)
145 |
146 |
147 | def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
148 | net = None
149 | norm_layer = get_norm_layer(norm_type=norm)
150 | nl = 'lrelu' # use leaky relu for D
151 | nl_layer = get_non_linearity(layer_type=nl)
152 |
153 | if netD == 'basic_128':
154 | net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
155 | elif netD == 'basic_256':
156 | net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
157 | elif netD == 'basic_128_multi':
158 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds)
159 | elif netD == 'basic_256_multi':
160 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds)
161 | else:
162 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
163 | return init_net(net, init_type, init_gain, gpu_ids)
164 |
165 |
166 | def define_E(input_nc, output_nc, ndf, netE,
167 | norm='batch', nl='lrelu',
168 | init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
169 | net = None
170 | norm_layer = get_norm_layer(norm_type=norm)
171 | nl = 'lrelu' # use leaky relu for E
172 | nl_layer = get_non_linearity(layer_type=nl)
173 | if netE == 'resnet_128':
174 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
175 | nl_layer=nl_layer, vaeLike=vaeLike)
176 | elif netE == 'resnet_256':
177 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
178 | nl_layer=nl_layer, vaeLike=vaeLike)
179 | elif netE == 'conv_128':
180 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
181 | nl_layer=nl_layer, vaeLike=vaeLike)
182 | elif netE == 'conv_256':
183 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
184 | nl_layer=nl_layer, vaeLike=vaeLike)
185 | else:
186 |         raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)
187 |
188 | return init_net(net, init_type, init_gain, gpu_ids)
189 |
190 |
191 | class D_NLayersMulti(nn.Module):
192 | def __init__(self, input_nc, ndf=64, n_layers=3,
193 | norm_layer=nn.BatchNorm2d, num_D=1):
194 | super(D_NLayersMulti, self).__init__()
195 |
196 | self.num_D = num_D
197 | if num_D == 1:
198 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
199 | self.model = nn.Sequential(*layers)
200 | else:
201 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
202 | self.add_module("model_0", nn.Sequential(*layers))
203 | self.down = nn.AvgPool2d(3, stride=2, padding=[
204 | 1, 1], count_include_pad=False)
205 | for i in range(1, num_D):
206 | ndf_i = int(round(ndf / (2**i)))
207 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
208 | self.add_module("model_%d" % i, nn.Sequential(*layers))
209 |
210 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
211 | kw = 4
212 | padw = 1
213 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
214 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
215 |
216 | nf_mult = 1
217 | nf_mult_prev = 1
218 | for n in range(1, n_layers):
219 | nf_mult_prev = nf_mult
220 | nf_mult = min(2**n, 8)
221 | sequence += [
222 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
223 | kernel_size=kw, stride=2, padding=padw),
224 | norm_layer(ndf * nf_mult),
225 | nn.LeakyReLU(0.2, True)
226 | ]
227 |
228 | nf_mult_prev = nf_mult
229 | nf_mult = min(2**n_layers, 8)
230 | sequence += [
231 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
232 | kernel_size=kw, stride=1, padding=padw),
233 | norm_layer(ndf * nf_mult),
234 | nn.LeakyReLU(0.2, True)
235 | ]
236 |
237 | sequence += [nn.Conv2d(ndf * nf_mult, 1,
238 | kernel_size=kw, stride=1, padding=padw)]
239 |
240 | return sequence
241 |
242 | def forward(self, input):
243 | if self.num_D == 1:
244 | return self.model(input)
245 | result = []
246 | down = input
247 | for i in range(self.num_D):
248 | model = getattr(self, "model_%d" % i)
249 | result.append(model(down))
250 | if i != self.num_D - 1:
251 | down = self.down(down)
252 | return result
253 |
254 |
255 | class D_NLayers(nn.Module):
256 | """Defines a PatchGAN discriminator"""
257 |
258 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
259 | """Construct a PatchGAN discriminator
260 | Parameters:
261 | input_nc (int) -- the number of channels in input images
262 | ndf (int) -- the number of filters in the last conv layer
263 | n_layers (int) -- the number of conv layers in the discriminator
264 | norm_layer -- normalization layer
265 | """
266 | super(D_NLayers, self).__init__()
267 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
268 | use_bias = norm_layer.func != nn.BatchNorm2d
269 | else:
270 | use_bias = norm_layer != nn.BatchNorm2d
271 |
272 | kw = 4
273 | padw = 1
274 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
275 | nf_mult = 1
276 | nf_mult_prev = 1
277 | for n in range(1, n_layers): # gradually increase the number of filters
278 | nf_mult_prev = nf_mult
279 | nf_mult = min(2 ** n, 8)
280 | sequence += [
281 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
282 | norm_layer(ndf * nf_mult),
283 | nn.LeakyReLU(0.2, True)
284 | ]
285 |
286 | nf_mult_prev = nf_mult
287 | nf_mult = min(2 ** n_layers, 8)
288 | sequence += [
289 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
290 | norm_layer(ndf * nf_mult),
291 | nn.LeakyReLU(0.2, True)
292 | ]
293 |
294 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
295 | self.model = nn.Sequential(*sequence)
296 |
297 | def forward(self, input):
298 | """Standard forward."""
299 | return self.model(input)
300 |
301 |
302 | ##############################################################################
303 | # Classes
304 | ##############################################################################
305 | class RecLoss(nn.Module):
306 | def __init__(self, use_L2=True):
307 | super(RecLoss, self).__init__()
308 | self.use_L2 = use_L2
309 |
310 | def __call__(self, input, target, batch_mean=True):
311 | if self.use_L2:
312 | diff = (input - target) ** 2
313 | else:
314 | diff = torch.abs(input - target)
315 | if batch_mean:
316 | return torch.mean(diff)
317 | else:
318 |             return torch.mean(diff, dim=[1, 2, 3])
319 |
320 |
321 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
322 | # When LSGAN is used, it is basically the same as MSELoss,
323 | # but it abstracts away the need to create the target label tensor
324 | # that has the same size as the input
325 | class GANLoss(nn.Module):
326 | """Define different GAN objectives.
327 |
328 | The GANLoss class abstracts away the need to create the target label tensor
329 | that has the same size as the input.
330 | """
331 |
332 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
333 | """ Initialize the GANLoss class.
334 |
335 | Parameters:
336 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
337 |             target_real_label (float) - - label for a real image
338 |             target_fake_label (float) - - label for a fake image
339 |
340 | Note: Do not use sigmoid as the last layer of Discriminator.
341 |         LSGAN needs no sigmoid; vanilla GANs will handle it with BCEWithLogitsLoss.
342 | """
343 | super(GANLoss, self).__init__()
344 | self.register_buffer('real_label', torch.tensor(target_real_label))
345 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
346 | self.gan_mode = gan_mode
347 | if gan_mode == 'lsgan':
348 | self.loss = nn.MSELoss()
349 | elif gan_mode == 'vanilla':
350 | self.loss = nn.BCEWithLogitsLoss()
351 | elif gan_mode in ['wgangp']:
352 | self.loss = None
353 | else:
354 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
355 |
356 | def get_target_tensor(self, prediction, target_is_real):
357 | """Create label tensors with the same size as the input.
358 |
359 | Parameters:
360 |             prediction (tensor) - - typically the prediction from a discriminator
361 | target_is_real (bool) - - if the ground truth label is for real images or fake images
362 |
363 | Returns:
364 | A label tensor filled with ground truth label, and with the size of the input
365 | """
366 |
367 | if target_is_real:
368 | target_tensor = self.real_label
369 | else:
370 | target_tensor = self.fake_label
371 | return target_tensor.expand_as(prediction)
372 |
373 | def __call__(self, predictions, target_is_real):
374 |         """Calculate loss given Discriminator's output and ground truth labels.
375 |
376 | Parameters:
377 |             predictions (tensor list) - - typically the prediction outputs from a discriminator; supports multiple discriminators.
378 | target_is_real (bool) - - if the ground truth label is for real images or fake images
379 |
380 | Returns:
381 | the calculated loss.
382 | """
383 | all_losses = []
384 | for prediction in predictions:
385 | if self.gan_mode in ['lsgan', 'vanilla']:
386 | target_tensor = self.get_target_tensor(prediction, target_is_real)
387 | loss = self.loss(prediction, target_tensor)
388 | elif self.gan_mode == 'wgangp':
389 | if target_is_real:
390 | loss = -prediction.mean()
391 | else:
392 | loss = prediction.mean()
393 | all_losses.append(loss)
394 | total_loss = sum(all_losses)
395 | return total_loss, all_losses
396 |
397 |
398 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
399 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
400 | Arguments:
401 | netD (network) -- discriminator network
402 | real_data (tensor array) -- real images
403 | fake_data (tensor array) -- generated images from the generator
404 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
405 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
406 |         constant (float)            -- the constant used in formula (||gradient||_2 - constant)^2
407 | lambda_gp (float) -- weight for this loss
408 | Returns the gradient penalty loss
409 | """
410 | if lambda_gp > 0.0:
411 |         if type == 'real':   # either use real images, fake images, or a linear interpolation of the two
412 | interpolatesv = real_data
413 | elif type == 'fake':
414 | interpolatesv = fake_data
415 | elif type == 'mixed':
416 | alpha = torch.rand(real_data.shape[0], 1)
417 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
418 | alpha = alpha.to(device)
419 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
420 | else:
421 | raise NotImplementedError('{} not implemented'.format(type))
422 | interpolatesv.requires_grad_(True)
423 | disc_interpolates = netD(interpolatesv)
424 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
425 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
426 | create_graph=True, retain_graph=True, only_inputs=True)
427 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
428 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
429 | return gradient_penalty, gradients
430 | else:
431 | return 0.0, None
432 |
433 | # Defines the Unet generator.
434 | # |num_downs|: number of downsamplings in UNet. For example,
435 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
436 | # at the bottleneck
437 |
438 |
439 | class G_Unet_add_input(nn.Module):
440 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
441 | norm_layer=None, nl_layer=None, use_dropout=False,
442 | upsample='basic'):
443 | super(G_Unet_add_input, self).__init__()
444 | self.nz = nz
445 | max_nchn = 8
446 | # construct unet structure
447 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn,
448 | innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
449 | for i in range(num_downs - 5):
450 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block,
451 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
452 | unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block,
453 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
454 | unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block,
455 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
456 | unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block,
457 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
458 | unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block,
459 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
460 |
461 | self.model = unet_block
462 |
463 | def forward(self, x, z=None):
464 | if self.nz > 0:
465 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
466 | z.size(0), z.size(1), x.size(2), x.size(3))
467 | x_with_z = torch.cat([x, z_img], 1)
468 | else:
469 | x_with_z = x # no z
470 |
471 | return self.model(x_with_z)
472 |
473 |
474 | def upsampleLayer(inplanes, outplanes, upsample='basic', padding_type='zero'):
475 | # padding_type = 'zero'
476 | if upsample == 'basic':
477 | upconv = [nn.ConvTranspose2d(
478 | inplanes, outplanes, kernel_size=4, stride=2, padding=1)]
479 | elif upsample == 'bilinear':
480 |         upconv = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
481 | nn.ReflectionPad2d(1),
482 | nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1, padding=0)]
483 | else:
484 | raise NotImplementedError(
485 | 'upsample layer [%s] not implemented' % upsample)
486 | return upconv
487 |
488 |
489 | # Defines the submodule with skip connection.
490 | # X -------------------identity---------------------- X
491 | # |-- downsampling -- |submodule| -- upsampling --|
492 | class UnetBlock(nn.Module):
493 | def __init__(self, input_nc, outer_nc, inner_nc,
494 | submodule=None, outermost=False, innermost=False,
495 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'):
496 | super(UnetBlock, self).__init__()
497 | self.outermost = outermost
498 | p = 0
499 | downconv = []
500 | if padding_type == 'reflect':
501 | downconv += [nn.ReflectionPad2d(1)]
502 | elif padding_type == 'replicate':
503 | downconv += [nn.ReplicationPad2d(1)]
504 | elif padding_type == 'zero':
505 | p = 1
506 | else:
507 | raise NotImplementedError(
508 | 'padding [%s] is not implemented' % padding_type)
509 | downconv += [nn.Conv2d(input_nc, inner_nc,
510 | kernel_size=4, stride=2, padding=p)]
511 | # downsample is different from upsample
512 | downrelu = nn.LeakyReLU(0.2, True)
513 | downnorm = norm_layer(inner_nc) if norm_layer is not None else None
514 | uprelu = nl_layer()
515 | upnorm = norm_layer(outer_nc) if norm_layer is not None else None
516 |
517 | if outermost:
518 | upconv = upsampleLayer(
519 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
520 | down = downconv
521 | up = [uprelu] + upconv + [nn.Tanh()]
522 | model = down + [submodule] + up
523 | elif innermost:
524 | upconv = upsampleLayer(
525 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
526 | down = [downrelu] + downconv
527 | up = [uprelu] + upconv
528 | if upnorm is not None:
529 | up += [upnorm]
530 | model = down + up
531 | else:
532 | upconv = upsampleLayer(
533 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
534 | down = [downrelu] + downconv
535 | if downnorm is not None:
536 | down += [downnorm]
537 | up = [uprelu] + upconv
538 | if upnorm is not None:
539 | up += [upnorm]
540 |
541 | if use_dropout:
542 | model = down + [submodule] + up + [nn.Dropout(0.5)]
543 | else:
544 | model = down + [submodule] + up
545 |
546 | self.model = nn.Sequential(*model)
547 |
548 | def forward(self, x):
549 | if self.outermost:
550 | return self.model(x)
551 | else:
552 | return torch.cat([self.model(x), x], 1)
553 |
554 |
555 | def conv3x3(in_planes, out_planes):
556 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
557 | padding=1, bias=True)
558 |
559 |
560 | # two usage cases, depend on kw and padw
561 | def upsampleConv(inplanes, outplanes, kw, padw):
562 | sequence = []
563 | sequence += [nn.Upsample(scale_factor=2, mode='nearest')]
564 | sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw,
565 | stride=1, padding=padw, bias=True)]
566 | return nn.Sequential(*sequence)
567 |
568 |
569 | def meanpoolConv(inplanes, outplanes):
570 | sequence = []
571 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
572 | sequence += [nn.Conv2d(inplanes, outplanes,
573 | kernel_size=1, stride=1, padding=0, bias=True)]
574 | return nn.Sequential(*sequence)
575 |
576 |
577 | def convMeanpool(inplanes, outplanes):
578 | sequence = []
579 | sequence += [conv3x3(inplanes, outplanes)]
580 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
581 | return nn.Sequential(*sequence)
582 |
583 |
584 | class BasicBlockUp(nn.Module):
585 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
586 | super(BasicBlockUp, self).__init__()
587 | layers = []
588 | if norm_layer is not None:
589 | layers += [norm_layer(inplanes)]
590 | layers += [nl_layer()]
591 | layers += [upsampleConv(inplanes, outplanes, kw=3, padw=1)]
592 | if norm_layer is not None:
593 | layers += [norm_layer(outplanes)]
594 | layers += [conv3x3(outplanes, outplanes)]
595 | self.conv = nn.Sequential(*layers)
596 | self.shortcut = upsampleConv(inplanes, outplanes, kw=1, padw=0)
597 |
598 | def forward(self, x):
599 | out = self.conv(x) + self.shortcut(x)
600 | return out
601 |
602 |
603 | class BasicBlock(nn.Module):
604 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
605 | super(BasicBlock, self).__init__()
606 | layers = []
607 | if norm_layer is not None:
608 | layers += [norm_layer(inplanes)]
609 | layers += [nl_layer()]
610 | layers += [conv3x3(inplanes, inplanes)]
611 | if norm_layer is not None:
612 | layers += [norm_layer(inplanes)]
613 | layers += [nl_layer()]
614 | layers += [convMeanpool(inplanes, outplanes)]
615 | self.conv = nn.Sequential(*layers)
616 | self.shortcut = meanpoolConv(inplanes, outplanes)
617 |
618 | def forward(self, x):
619 | out = self.conv(x) + self.shortcut(x)
620 | return out
621 |
622 |
623 | class E_ResNet(nn.Module):
624 | def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
625 | norm_layer=None, nl_layer=None, vaeLike=False):
626 | super(E_ResNet, self).__init__()
627 | self.vaeLike = vaeLike
628 | max_ndf = 4
629 | conv_layers = [
630 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=True)]
631 | for n in range(1, n_blocks):
632 | input_ndf = ndf * min(max_ndf, n)
633 | output_ndf = ndf * min(max_ndf, n + 1)
634 | conv_layers += [BasicBlock(input_ndf,
635 | output_ndf, norm_layer, nl_layer)]
636 | conv_layers += [nl_layer(), nn.AvgPool2d(8)]
637 | if vaeLike:
638 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
639 | self.fcVar = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
640 | else:
641 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
642 | self.conv = nn.Sequential(*conv_layers)
643 |
644 | def forward(self, x):
645 | x_conv = self.conv(x)
646 | conv_flat = x_conv.view(x.size(0), -1)
647 | output = self.fc(conv_flat)
648 | if self.vaeLike:
649 | outputVar = self.fcVar(conv_flat)
650 | return output, outputVar
651 | else:
652 | return output
653 |
654 |
655 |
656 | # Defines the Unet generator.
657 | # |num_downs|: number of downsamplings in UNet. For example,
658 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
659 | # at the bottleneck
660 | class G_Unet_add_all(nn.Module):
661 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
662 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic'):
663 | super(G_Unet_add_all, self).__init__()
664 | self.nz = nz
665 | # construct unet structure
666 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, None, innermost=True,
667 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
668 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block,
669 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
670 | for i in range(num_downs - 6):
671 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block,
672 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
673 | unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, unet_block,
674 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
675 | unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, unet_block,
676 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
677 | unet_block = UnetBlock_with_z(
678 | ngf, ngf, ngf * 2, nz, unet_block, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
679 | unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, unet_block,
680 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
681 | self.model = unet_block
682 |
683 | def forward(self, x, z):
684 | return self.model(x, z)
685 |
686 |
687 | class UnetBlock_with_z(nn.Module):
688 | def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
689 | submodule=None, outermost=False, innermost=False,
690 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'):
691 | super(UnetBlock_with_z, self).__init__()
692 | p = 0
693 | downconv = []
694 | if padding_type == 'reflect':
695 | downconv += [nn.ReflectionPad2d(1)]
696 | elif padding_type == 'replicate':
697 | downconv += [nn.ReplicationPad2d(1)]
698 | elif padding_type == 'zero':
699 | p = 1
700 | else:
701 | raise NotImplementedError(
702 | 'padding [%s] is not implemented' % padding_type)
703 |
704 | self.outermost = outermost
705 | self.innermost = innermost
706 | self.nz = nz
707 | input_nc = input_nc + nz
708 | downconv += [nn.Conv2d(input_nc, inner_nc,
709 | kernel_size=4, stride=2, padding=p)]
710 | # downsample is different from upsample
711 | downrelu = nn.LeakyReLU(0.2, True)
712 | uprelu = nl_layer()
713 |
714 | if outermost:
715 | upconv = upsampleLayer(
716 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
717 | down = downconv
718 | up = [uprelu] + upconv + [nn.Tanh()]
719 | elif innermost:
720 | upconv = upsampleLayer(
721 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
722 | down = [downrelu] + downconv
723 | up = [uprelu] + upconv
724 | if norm_layer is not None:
725 | up += [norm_layer(outer_nc)]
726 | else:
727 | upconv = upsampleLayer(
728 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
729 | down = [downrelu] + downconv
730 | if norm_layer is not None:
731 | down += [norm_layer(inner_nc)]
732 | up = [uprelu] + upconv
733 |
734 | if norm_layer is not None:
735 | up += [norm_layer(outer_nc)]
736 |
737 | if use_dropout:
738 | up += [nn.Dropout(0.5)]
739 | self.down = nn.Sequential(*down)
740 | self.submodule = submodule
741 | self.up = nn.Sequential(*up)
742 |
743 | def forward(self, x, z):
744 | # print(x.size())
745 | if self.nz > 0:
746 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3))
747 | x_and_z = torch.cat([x, z_img], 1)
748 | else:
749 | x_and_z = x
750 |
751 | if self.outermost:
752 | x1 = self.down(x_and_z)
753 | x2 = self.submodule(x1, z)
754 | return self.up(x2)
755 | elif self.innermost:
756 | x1 = self.up(self.down(x_and_z))
757 | return torch.cat([x1, x], 1)
758 | else:
759 | x1 = self.down(x_and_z)
760 | x2 = self.submodule(x1, z)
761 | return torch.cat([self.up(x2), x], 1)
762 |
763 |
764 | class E_NLayers(nn.Module):
765 | def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=3,
766 | norm_layer=None, nl_layer=None, vaeLike=False):
767 | super(E_NLayers, self).__init__()
768 | self.vaeLike = vaeLike
769 |
770 | kw, padw = 4, 1
771 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
772 | stride=2, padding=padw), nl_layer()]
773 |
774 | nf_mult = 1
775 | nf_mult_prev = 1
776 | for n in range(1, n_layers):
777 | nf_mult_prev = nf_mult
778 | nf_mult = min(2**n, 4)
779 | sequence += [
780 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
781 | kernel_size=kw, stride=2, padding=padw)]
782 | if norm_layer is not None:
783 | sequence += [norm_layer(ndf * nf_mult)]
784 | sequence += [nl_layer()]
785 | sequence += [nn.AvgPool2d(8)]
786 | self.conv = nn.Sequential(*sequence)
787 | self.fc = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)])
788 | if vaeLike:
789 | self.fcVar = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)])
790 |
791 | def forward(self, x):
792 | x_conv = self.conv(x)
793 | conv_flat = x_conv.view(x.size(0), -1)
794 | output = self.fc(conv_flat)
795 | if self.vaeLike:
796 | outputVar = self.fcVar(conv_flat)
797 | return output, outputVar
798 | return output
799 |
--------------------------------------------------------------------------------
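A minimal usage sketch for the discriminator and loss classes above (illustrative, not part of the repo). Note that GANLoss.__call__ iterates over a list of predictions, so a single-tensor discriminator output must be wrapped in a list, while D_NLayersMulti with num_D > 1 already returns one; cal_gradient_penalty, by contrast, expects a discriminator that returns a single tensor.

    import torch
    from model.networks import D_NLayersMulti, GANLoss

    netD = D_NLayersMulti(input_nc=3, ndf=64, n_layers=3, num_D=2)  # two-scale PatchGAN
    criterion = GANLoss('lsgan')

    fake = torch.randn(4, 3, 256, 256)           # hypothetical batch
    preds = netD(fake)                            # list of per-scale prediction maps
    loss_D_fake, per_scale = criterion(preds, target_is_real=False)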
/model/translation_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import model.networks
5 | from torch.nn.utils import spectral_norm
6 |
7 | # Also, we figure it would be better to inject the warped
8 | # guidance at the beginning rather than starting from a constant tensor
9 |
10 | class TranslationNet(nn.Module):
11 | def __init__(self, opt):
12 | super().__init__()
13 | print('Making a TranslationNet')
14 | self.fc = nn.Conv2d(3, 16 * opt.ngf, 3, padding=1)
15 | self.sw = opt.image_size // (2**5) # fixed, 5 upsample layers
16 | self.head = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim)
17 | self.G_middle_0 = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim)
18 | self.G_middle_1 = SPADEResBlk(16 * opt.ngf, 16 * opt.ngf, opt.seg_dim)
19 | self.up_0 = SPADEResBlk(16 * opt.ngf, 8 * opt.ngf, opt.seg_dim)
20 | self.up_1 = SPADEResBlk(8 * opt.ngf, 4 * opt.ngf, opt.seg_dim)
21 | self.non_local = NonLocalLayer(opt.ngf*4)
22 | self.up_2 = SPADEResBlk(4 * opt.ngf, 2 * opt.ngf, opt.seg_dim)
23 | self.up_3 = SPADEResBlk(2 * opt.ngf, 1 * opt.ngf, opt.seg_dim)
24 |
25 | self.conv_img = nn.Conv2d(opt.ngf, 3, kernel_size=3, stride=1, padding=1)
26 |
27 | @staticmethod
28 | def up(x):
29 |         return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
30 |
31 | def forward(self, x, seg=None):
32 | if seg is None:
33 | seg = x
34 |         # downsample the input to the starting spatial resolution
35 |         x = F.interpolate(x, (self.sw, self.sw), mode='bilinear', align_corners=False)
36 | x = self.fc(x)
37 | x = self.head(x, seg)
38 |
39 | x = self.up(x) # 16
40 | x = self.G_middle_0(x, seg)
41 | x = self.G_middle_1(x, seg)
42 |
43 | x = self.up(x) # 32
44 | x = self.up_0(x, seg)
45 | x = self.up(x) # 64
46 | x = self.up_1(x, seg)
47 | x = self.up(x) # 128
48 |
49 | # 20200525: Critical Bug:
50 |         # Using a non-local layer at such a huge spatial resolution (128*128)
51 |         # occupies way too much memory (the intermediate tensor is O(h ** 4) in memory)
52 | # I sincerely hope it's an honest mistake:)
53 | # x = self.non_local(x)
54 |
55 | x = self.up_2(x, seg)
56 | x = self.up(x) # 256
57 | x = self.up_3(x, seg)
58 |
59 | x = self.conv_img(F.leaky_relu(x, 2e-1))
60 |         x = torch.tanh(x)
61 | return x
62 |
63 |
64 | # NOTE: this SPADE implementation differs slightly
65 | # from the original https://github.com/NVlabs/SPADE:
66 | # BatchNorm is replaced with positional normalization (PN).
67 | class SPADE(nn.Module):
68 | def __init__(self, cin, seg_dim):
69 | super().__init__()
70 | self.conv = nn.Sequential(
71 | nn.Conv2d(seg_dim, 128, kernel_size=3, stride=1, padding=1),
72 | nn.ReLU(),
73 | )
74 | self.alpha = nn.Conv2d(128, cin,
75 | kernel_size=3, stride=1, padding=1)
76 | self.beta = nn.Conv2d(128, cin,
77 | kernel_size=3, stride=1, padding=1)
78 |
79 | @staticmethod
80 | def PN(x):
81 | '''
82 | positional normalization: normalize each positional vector along the channel dimension
83 | '''
84 | assert len(x.shape) == 4, 'Only works for 4D(image) tensor'
85 | x = x - x.mean(dim=1, keepdim=True)
86 | x_norm = x.norm(dim=1, keepdim=True) + 1e-6
87 | x = x / x_norm
88 | return x
89 |
90 | def DPN(self, x, s):
91 | h, w = x.shape[2], x.shape[3]
92 |         s = F.interpolate(s, (h, w), mode='bilinear', align_corners=False)
93 | s = self.conv(s)
94 | a = self.alpha(s)
95 | b = self.beta(s)
96 | return x * (1 + a) + b
97 |
98 | def forward(self, x, s):
99 | x_out = self.DPN(self.PN(x), s)
100 | return x_out
101 |
102 | class SPADEResBlk(nn.Module):
103 | def __init__(self, fin, fout, seg_fin):
104 | super().__init__()
105 | # Attributes
106 | self.learned_shortcut = (fin != fout)
107 | fmiddle = min(fin, fout)
108 |
109 | # create conv layers
110 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
111 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
112 | if self.learned_shortcut:
113 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
114 |
115 |         # spectral norm is applied unconditionally in this implementation
116 | self.conv_0 = spectral_norm(self.conv_0)
117 | self.conv_1 = spectral_norm(self.conv_1)
118 | if self.learned_shortcut:
119 | self.conv_s = spectral_norm(self.conv_s)
120 |
121 | # define normalization layers
122 | self.norm_0 = SPADE(fin, seg_fin)
123 | self.norm_1 = SPADE(fmiddle, seg_fin)
124 | if self.learned_shortcut:
125 | self.norm_s = SPADE(fin, seg_fin)
126 |
127 | # note the resnet block with SPADE also takes in |seg|,
128 | # the semantic segmentation map as input
129 | def forward(self, x, seg):
130 | x_s = self.shortcut(x, seg)
131 |
132 | dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
133 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
134 |
135 | out = x_s + dx
136 |
137 | return out
138 |
139 | def shortcut(self, x, seg):
140 | if self.learned_shortcut:
141 | x_s = self.conv_s(self.norm_s(x, seg))
142 | else:
143 | x_s = x
144 | return x_s
145 |
146 | def actvn(self, x):
147 | return F.leaky_relu(x, 2e-1)
148 |
149 |
150 | class NonLocalLayer(nn.Module):
151 | # Non-local layer for 2D shape
152 | def __init__(self, cin):
153 | super().__init__()
154 | self.cinter = cin // 2
155 | self.theta = nn.Conv2d(cin, self.cinter,
156 | kernel_size=1, stride=1, padding=0)
157 | self.phi = nn.Conv2d(cin, self.cinter,
158 | kernel_size=1, stride=1, padding=0)
159 | self.g = nn.Conv2d(cin, self.cinter,
160 | kernel_size=1, stride=1, padding=0)
161 |
162 | self.w = nn.Conv2d(self.cinter, cin,
163 | kernel_size=1, stride=1, padding=0)
164 |
165 | def forward(self, x):
166 | n, c, h, w = x.shape
167 | g_x = self.g(x).view(n, self.cinter, -1)
168 | phi_x = self.phi(x).view(n, self.cinter, -1)
169 | theta_x = self.theta(x).view(n, self.cinter, -1)
170 |         # This non-local layer occupies too much memory: f_x below is an
171 |         # (h*w) x (h*w) matrix per sample
172 |         f_x = torch.bmm(phi_x.transpose(-1, -2), theta_x)  # (n, h*w, h*w); note the transpose here
173 |         f_x = F.softmax(f_x, dim=-2)  # normalize over the index summed out by the bmm below
174 |         res_x = self.w(torch.bmm(g_x, f_x).view(n, self.cinter, h, w))  # inverse order to save a permute of g_x; reshape back to 4D for the 1x1 conv
175 | return x + res_x
176 |
--------------------------------------------------------------------------------
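Two quick checks on the pieces above (a sketch under illustrative shapes, not repo code). First, SPADE.PN is per-position channel whitening: subtract the channel mean, then divide by the channel-wise L2 norm, which matches F.normalize on the centered tensor up to epsilon handling. Second, the memory complaint about NonLocalLayer is easy to make concrete, since the attention matrix alone holds (h*w)^2 entries per sample.

    import torch
    import torch.nn.functional as F
    from model.translation_net import SPADE

    x = torch.randn(2, 8, 16, 16)
    pn = SPADE.PN(x)
    # PN divides the centered tensor by (norm + 1e-6); F.normalize uses max(norm, eps)
    ref = F.normalize(x - x.mean(dim=1, keepdim=True), dim=1, eps=1e-6)
    print(torch.allclose(pn, ref, atol=1e-5))    # True

    h = w = 128                                  # the resolution flagged as the critical bug
    entries = (h * w) ** 2                       # 268,435,456 attention weights per sample
    print(entries * 4 / 2**30, 'GiB in fp32')    # ~1.0 GiB, before softmax buffers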
/options.py:
--------------------------------------------------------------------------------
1 | INF = 99999999
2 |
3 | class BaseOptions(object):
4 | # Data options
5 | dataroot='datasets/fashion'
6 | dataset_mode='fashion'
7 | name='fashion_cocosnet'
8 | checkpoints_dir='checkpoints'
9 | results_dir='results'
10 | num_workers=0
11 | batch_size=1
12 | serial_batches=False
13 | max_dataset_size=INF
14 | gpu_ids = [2]
15 |
16 | # Model options
17 | image_size=256
18 |     padding=40 # For the DeepFashion dataset, the input image may be cropped
19 | model='cocos'
20 | ncA=3
21 | ncB=3
22 | seg_dim=3
23 | ngf=16
24 | ndf=16
25 | numD=2
26 | nd_layers=3
27 |
28 | # Training options
29 | niter=30
30 | niter_decay=20
31 | epoch_count=0
32 | continue_train=False
33 | which_epoch='latest'
34 |
35 | # Logging options
36 | verbose=True
37 | print_every=10
38 | visual_every=1000
39 | save_every=5
40 |
41 |
42 | class TrainOptions(BaseOptions):
43 | phase='train'
44 | isTrain=True
45 |
46 | # Training Options
47 | lr=0.0002
48 | beta1=0.5
49 | gan_mode='hinge'
50 | lr_policy='linear'
51 | init_type='xavier'
52 | init_gain=0.02
53 |
54 | lambda_perc = 1.0
55 | lambda_domain = 5.0
56 | lambda_feat = 10.0
57 | lambda_context = 10.0
58 | lambda_reg = 1.0
59 | lambda_adv = 1.0
60 |
61 | # To resume training, uncomment the following lines
62 | # continue_train=True
63 | # which_epoch='latest' # or a certain number (e.g. '10' or '20200525-112233')
64 |
65 | class DebugOptions(TrainOptions):
66 | max_dataset_size=4
67 | num_workers=0
68 | print_every=1
69 | visual_every=1
70 | save_every=1
71 | niter=2
72 | niter_decay=1
73 | verbose=False
74 |
75 | class TestOptions(BaseOptions):
76 | phase='test'
77 | isTrain=False
78 | serial_batches=True
79 | num_workers=0
80 | batch_size=1
81 | which_epoch='latest'
82 |
--------------------------------------------------------------------------------
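Since options are plain class attributes rather than argparse flags, a run-specific configuration is just another subclass, mirroring how DebugOptions overrides TrainOptions. A hypothetical variant (the overridden values are illustrative):

    from options import TrainOptions

    class MyTrainOptions(TrainOptions):
        gpu_ids = [0]        # train on GPU 0 instead of the default GPU 2
        batch_size = 4
        continue_train = True
        which_epoch = '10'   # resume from the epoch-10 checkpoint

    opt = MyTrainOptions()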
/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lotayou/CoCosNet/93142f55ff09e8ee6052d8b5c81931e7f9570093/test.py
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from options import DebugOptions, TrainOptions
2 | from data import create_dataset
3 | from model import create_model
4 | from torch.backends import cudnn
5 | import torch
6 | #opt = DebugOptions()
7 | opt = TrainOptions()
8 | #os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids[0]) # test single GPU first
9 |
10 | torch.cuda.set_device(opt.gpu_ids[0])
11 | cudnn.enabled = True
12 | cudnn.benchmark = True
13 |
14 | loader = create_dataset(opt)
15 | dataset_size = len(loader)
16 | print('#training images = %d' % dataset_size)
17 |
18 | net = create_model(opt)
19 |
20 | for epoch in range(1, opt.niter + opt.niter_decay + 1):
21 | print('Begin epoch %d' % epoch)
22 | for i, data_i in enumerate(loader):
23 | net.set_input(data_i)
24 | net.optimize_parameters()
25 |
26 | #### logging, visualizing, saving
27 | if i % opt.print_every == 0:
28 | net.log_loss(epoch, i)
29 | if i % opt.visual_every == 0:
30 | net.log_visual(epoch, i)
31 |
32 | net.save_networks('latest')
33 | if epoch % opt.save_every == 0:
34 | net.save_networks(epoch)
35 | net.update_learning_rate()
36 |
--------------------------------------------------------------------------------
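The loop above trains at a fixed learning rate for opt.niter epochs, then relies on net.update_learning_rate() to decay it over the remaining opt.niter_decay epochs. With lr_policy='linear' this is conventionally the pix2pix-style linear ramp to zero; the actual scheduler lives in the model code not shown here, so the rule below is only a sketch of the expected behaviour:

    def expected_lr(epoch, lr=2e-4, niter=30, niter_decay=20):
        # constant for the first `niter` epochs, then a linear decay toward 0
        return lr * (1.0 - max(0, epoch - niter) / float(niter_decay + 1))

    print(expected_lr(30))   # 0.0002   (still flat)
    print(expected_lr(40))   # ~0.000105 (halfway through the decay)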