├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── adv_utils.py
├── assets
│   ├── imgs
│   │   ├── celebA_resnet_adv_targeted_large_plot.jpg
│   │   ├── compare_adv_imgs.png
│   │   ├── mnist_madry_adv_targeted_large_plot.jpg
│   │   └── svhn_resnet_adv_targeted_large_plot.png
│   └── pretrained
│       ├── mnist_aditi_adv
│       │   ├── B1.npy
│       │   ├── B2.npy
│       │   ├── W1.npy
│       │   └── W2.npy
│       └── mnist_zico_adv
│           └── mnist.pth
├── main.py
├── models
│   ├── __init__.py
│   ├── acwgan_gp.py
│   ├── aditi_mnist.py
│   ├── libs
│   │   ├── ops.py
│   │   ├── resnet_ops.py
│   │   └── sn.py
│   ├── madry_mnist.py
│   ├── resnet_model.py
│   ├── vgg16.py
│   └── zico_mnist.py
├── mturk_websites
│   ├── mturk.html
│   ├── mturk_abtest.html
│   └── mturk_celeba.html
├── ops.py
├── train_acgan.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | assets/data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software.
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Constructing Unrestricted Adversarial Examples with Generative Models
2 |
3 | This repo contains the code needed to reproduce the main results of the paper [Constructing Unrestricted Adversarial Examples with Generative Models](https://arxiv.org/abs/1805.07894), NIPS 2018, Montréal, Canada,
4 |
5 | by [Yang Song](https://yang-song.github.io/), [Rui Shu](https://ruishu.io/), [Nate Kushman](http://www.kushman.org/) and [Stefano Ermon](https://cs.stanford.edu/~ermon/), Stanford AI Lab.
6 |
7 | ---
8 |
9 | We propose **Unrestricted Adversarial Examples**, a new kind of adversarial example for machine learning systems. Unlike traditional adversarial examples, which are crafted by adding norm-bounded perturbations to clean images, unrestricted adversarial examples are _realistic images synthesized entirely from scratch_, not restricted to small norm-balls. This new attack demonstrates the danger of a stronger **threat model**, under which traditional defense methods for perturbation-based adversarial examples fail.
10 |
11 | ## Datasets
12 |
13 | Here are links to the datasets used in our experiments:
14 | * [CelebA (gender)](https://drive.google.com/open?id=1coLQbEZW6zshHVYi00IYSRiexq4RkA2x)
15 | * [SVHN](https://drive.google.com/open?id=1uPxNdW4K-GLFhqhOgtfI1jFFNEqp2eZn)
16 |
17 | ## Running Experiments
18 |
19 | ### Training AC-GANs
20 |
21 | To perform unrestricted adversarial attacks, we first need a good conditional generative model, so that we can search the manifold of realistic images for adversarial ones. You can use `train_acgan.py` to do this. For example, the following command
22 |
23 | ```bash
24 | CUDA_VISIBLE_DEVICES=0 python train_acgan.py --dataset mnist --checkpoint_dir checkpoints/
25 | ```
26 |
27 | will train an AC-GAN on the `MNIST` dataset with GPU #0 and save the weight files to the `checkpoints/` directory.
28 |
29 | Run `python train_acgan.py --help` to see more available argument options.
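For intuition, here is a minimal sketch of the kind of objective the attacks in the next section optimize. This is an illustration, not the repo's exact code: it assumes `source` and `target` are integer label tensors, that `generator`, `classifier`, and `aux_classifier` are callables returning images and logits, and our reading of the "repulsive" term (keeping the AC-GAN's auxiliary classifier on the source class) is a paraphrase; the precise formulation is in the paper.

```python
import tensorflow as tf

def attack_loss(adv_z, ref_z, source, target, generator, classifier,
                aux_classifier, lambda1, lambda2, z_eps):
    """Illustrative sketch of the targeted attack objective (not the repo's exact loss)."""
    adv_images = generator(adv_z, source)    # search the source-class manifold
    cls_logits = classifier(adv_images)      # the classifier under attack
    aux_logits = aux_classifier(adv_images)  # AC-GAN auxiliary classifier

    # Adversarial term: push the attacked classifier toward the target class.
    adv_term = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=cls_logits, labels=target))
    # "Closeness" term (--lambda1): softly confine the latent search to an
    # L-infinity ball of radius z_eps (--z_eps) around the anchor ref_z.
    closeness = tf.reduce_mean(tf.maximum(tf.abs(adv_z - ref_z) - z_eps, 0.0))
    # "Repulsive" term (--lambda2): keep the auxiliary classifier confident
    # that the image still depicts the source class.
    repulsive = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=aux_logits, labels=source))
    return adv_term + lambda1 * closeness + lambda2 * repulsive
```

Minimizing such a loss over the latent variable (and, with `--noise`, over a bounded pixel perturbation) with learning rate `--lr` for `--n_iters` steps is, roughly, what `main.py` does under the hood.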
30 |
31 | ### Unrestricted Adversarial Attack
32 |
33 | After the AC-GAN is trained, you can use `main.py` to perform targeted or untargeted attacks. You can also use `main.py` to evaluate the accuracy and PGD-robustness of a trained neural network classifier. For example, the following command
34 |
35 | ```bash
36 | CUDA_VISIBLE_DEVICES=0 python main.py --mode targeted_attack --dataset mnist --classifier zico --source 0 --target 1
37 | ```
38 |
39 | attacks the provable defense method from [Kolter & Wong, 2018](https://arxiv.org/pdf/1711.00851.pdf) on the `MNIST` dataset, with source class 0 and target class 1.
40 |
41 | Run `python main.py --help` to view more argument options. For hyperparameters such as `--noise`, `--lambda1`, `--lambda2`, `--eps`, `--z_eps`, `--lr`, and `--n_iters` (in that order), please refer to **Table 4** in the Appendix of our [paper](https://arxiv.org/pdf/1805.07894.pdf).
42 |
43 | ### Evaluating Unrestricted Adversarial Examples
44 |
45 | In the paper, we use [Amazon Mechanical Turk](https://www.mturk.com/) to evaluate whether our unrestricted adversarial examples are legitimate. We have provided `html` files for the labelling interface in the `mturk_websites` folder.
46 |
47 |
48 | ## Samples
49 |
50 | Perturbation-based adversarial examples (top row) vs. unrestricted adversarial examples (bottom row):
51 |
52 | ![compare](assets/imgs/compare_adv_imgs.png)
53 |
54 | Targeted unrestricted adversarial examples against robust classifiers on `MNIST` (green borders denote legitimate unrestricted adversarial examples, while red borders denote illegitimate ones; the tiny white text at the top-left corner of a red image is the label given by the annotators):
55 |
56 | ![mnist](assets/imgs/mnist_madry_adv_targeted_large_plot.jpg)
57 |
58 | We also have samples for the `SVHN` dataset:
59 |
60 | ![svhn](assets/imgs/svhn_resnet_adv_targeted_large_plot.png)
61 |
62 | Finally, here are the results for `CelebA`:
63 |
64 | ![celeba](assets/imgs/celebA_resnet_adv_targeted_large_plot.jpg)
65 |
66 | ## Citation
67 |
68 | If you find the idea or code useful for your research, please consider citing our [paper](https://arxiv.org/abs/1805.07894):
69 |
70 | ```bib
71 | @inproceedings{song2018constructing,
72 |   author = {Song, Yang and Shu, Rui and Kushman, Nate and Ermon, Stefano},
73 |   booktitle = {Advances in Neural Information Processing Systems (NIPS)},
74 |   title = {Constructing Unrestricted Adversarial Examples with Generative Models},
75 |   year = {2018},
76 | }
77 | ```
78 |
79 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/__init__.py
--------------------------------------------------------------------------------
/adv_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import numpy as np
4 |
5 |
6 | def label_smooth(y, weight=0.9):
7 |     # requires y to be one_hot!
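    # Worked example (illustrative): for one-hot y with 10 classes and
    # weight=0.9, the clip below maps every 0 entry to (1 - 0.9) / 9 = 0.0111...
    # and every 1 entry to 0.9, i.e. standard label smoothing.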
8 |     return tf.clip_by_value(y, clip_value_min=(1.0 - weight) / (tf.cast(tf.shape(y)[-1], tf.float32) - 1.), clip_value_max=weight)  # class count inferred from the one-hot labels
9 |
10 |
11 | def random_flip_left_right(images):
12 |     images_flipped = tf.reverse(images, axis=[2])
13 |     flip = tf.cast(tf.contrib.distributions.Bernoulli(probs=tf.ones((tf.shape(images)[0],)) * 0.5).sample(), tf.bool)
14 |     final_images = tf.where(flip, x=images, y=images_flipped)
15 |     return final_images
16 |
17 |
18 | def feature_squeeze(images, dataset='cifar'):
19 |     # color depth reduction
20 |     if dataset == 'cifar':
21 |         npp = 2 ** 5
22 |     elif dataset == 'mnist':
23 |         npp = 2 ** 3
24 |     else: raise ValueError("feature_squeeze only supports 'cifar' and 'mnist'")
25 |     npp_int = npp - 1
26 |     images = images / 255.
27 |     x_int = tf.rint(tf.multiply(images, npp_int))
28 |     x_float = tf.div(x_int, npp_int)
29 |     return median_filtering_2x2(x_float, dataset=dataset)
30 |
31 |
32 | def median_filtering_2x2(images, dataset='cifar'):
33 |     def median_filtering_layer_2x2(channel):
34 |         top = tf.pad(channel, paddings=[[0, 0], [1, 0], [0, 0]], mode="REFLECT")[:, :-1, :]
35 |         left = tf.pad(channel, paddings=[[0, 0], [0, 0], [1, 0]], mode="REFLECT")[:, :, :-1]
36 |         top_left = tf.pad(channel, paddings=[[0, 0], [1, 0], [1, 0]], mode="REFLECT")[:, :-1, :-1]
37 |         comb = tf.stack([channel, top, left, top_left], axis=3)
38 |         return tf.nn.top_k(comb, 2).values[..., -1]  # 2nd largest of the 4 neighbours, i.e. the upper median
39 |
40 |     if dataset == 'cifar':
41 |         c0 = median_filtering_layer_2x2(images[..., 0])
42 |         c1 = median_filtering_layer_2x2(images[..., 1])
43 |         c2 = median_filtering_layer_2x2(images[..., 2])
44 |         return tf.stack([c0, c1, c2], axis=3)
45 |     elif dataset == 'mnist':
46 |         return median_filtering_layer_2x2(images[..., 0])[..., None]
47 |
48 |
49 | def normalize_image(images):
50 |     return (images.astype(np.int32) - 127.5) / 127.5
51 |
52 |
53 | def unnormalize_image(images):
54 |     return images * 127.5 + 127.5
55 |
56 | def get_weights_path(args):
57 |     prefix = os.path.join('assets', 'pretrained')
58 |     folder = args.dataset + '_' + args.classifier
59 |     if args.trained:
60 |         folder += '_trained'
61 |     if args.adv:
62 |         folder += '_adv'
63 |     if args.adv_gen:
64 |         folder += '_advgen'
65 |
66 |     path = os.path.join(prefix, folder)
67 |     if not os.path.exists(path):
68 |         os.makedirs(path)
69 |
70 |     ckpt_path = os.path.join(path, 'model.ckpt')
71 |     return path, ckpt_path
--------------------------------------------------------------------------------
/assets/imgs/celebA_resnet_adv_targeted_large_plot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/imgs/celebA_resnet_adv_targeted_large_plot.jpg
--------------------------------------------------------------------------------
/assets/imgs/compare_adv_imgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/imgs/compare_adv_imgs.png
--------------------------------------------------------------------------------
/assets/imgs/mnist_madry_adv_targeted_large_plot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/imgs/mnist_madry_adv_targeted_large_plot.jpg
--------------------------------------------------------------------------------
/assets/imgs/svhn_resnet_adv_targeted_large_plot.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/imgs/svhn_resnet_adv_targeted_large_plot.png -------------------------------------------------------------------------------- /assets/pretrained/mnist_aditi_adv/B1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/pretrained/mnist_aditi_adv/B1.npy -------------------------------------------------------------------------------- /assets/pretrained/mnist_aditi_adv/B2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/pretrained/mnist_aditi_adv/B2.npy -------------------------------------------------------------------------------- /assets/pretrained/mnist_aditi_adv/W1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/pretrained/mnist_aditi_adv/W1.npy -------------------------------------------------------------------------------- /assets/pretrained/mnist_aditi_adv/W2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/pretrained/mnist_aditi_adv/W2.npy -------------------------------------------------------------------------------- /assets/pretrained/mnist_zico_adv/mnist.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/assets/pretrained/mnist_zico_adv/mnist.pth -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import models.resnet_model as resnet_model 4 | from cleverhans.attacks_tf import fgm 5 | from models.madry_mnist import MadryModel 6 | from models.aditi_mnist import AditiMNIST 7 | from models.zico_mnist import ZicoMNIST 8 | from utils import * 9 | from adv_utils import * 10 | from models.vgg16 import vgg_16 11 | from models.acwgan_gp import ACWGAN_GP 12 | import argparse 13 | from scipy.misc import imsave 14 | 15 | parser = argparse.ArgumentParser("Generative Adversarial Examples") 16 | parser.add_argument('--dataset', type=str, default='mnist', help="mnist | svhn | celebA") 17 | parser.add_argument('--adv', action='store_true', help="using adversarially trained network") 18 | parser.add_argument('--classifier', type=str, default='resnet', help='resnet | vgg | madry | aditi | zico') 19 | parser.add_argument('--datapath', type=str, default=None, help="input data path") 20 | parser.add_argument('--seed', type=int, default=1234, help="random seed") 21 | parser.add_argument('--batch_size', type=int, default=64, help="batch size") 22 | parser.add_argument('--mode', type=str, default='attack', help='eval | targeted_attack | untargeted_attack') 23 | parser.add_argument('--top5', action='store_true', help="use top5 error") 24 | 25 | parser.add_argument('--lr', type=float, default=1, help="learning rate for doing 
targeted/untargeted attack") 26 | parser.add_argument('--n_adv_examples', type=int, default=1000000, 27 | help="number of adversarial examples batches to search") 28 | parser.add_argument('--n_iters', type=int, default=1000, 29 | help="number of inner iterations for computing adversarial examples") 30 | parser.add_argument('--z_dim', type=int, default=128, help="dimension of noise vector") 31 | parser.add_argument('--checkpoint_dir', type=str, default='assets/checkpoint', 32 | help='Directory name to save the checkpoints') 33 | parser.add_argument('--result_dir', type=str, default='assets/results', 34 | help='Directory name to save the generated images') 35 | parser.add_argument('--log_dir', type=str, default='assets/logs', 36 | help='Directory name to save training logs') 37 | parser.add_argument('--source', type=int, default=0, help="ground truth class (source class)") 38 | parser.add_argument('--target', type=int, default=1, help="target class") 39 | parser.add_argument('--lambda1', type=float, default=100, help="coefficient for the closeness regularization term") 40 | parser.add_argument('--lambda2', type=float, default=100, help="coefficient for the repulsive regularization term") 41 | parser.add_argument('--n2collect', type=int, default=1024, help="number of adversarial examples to collect") 42 | parser.add_argument('--eps', type=float, default=0.1, help="eps for attack augmented with noise") 43 | parser.add_argument('--noise', action="store_true", help="add noise augmentation to attacks") 44 | parser.add_argument('--z_eps', type=float, default=0.1, help="soft constraint for the search region of latent space") 45 | parser.add_argument('--adv_gen', action="store_true", help="adversarial training using generative adversarial examples") 46 | parser.add_argument('--trained', action="store_true", help="trained models") 47 | 48 | args = parser.parse_args() 49 | 50 | 51 | def resnet_template(images, training, hps): 52 | # Do per image standardization 53 | images_standardized = per_image_standardization(images) 54 | model = resnet_model.ResNet(hps, images_standardized, training) 55 | model.build_graph() 56 | return model.logits 57 | 58 | 59 | def vgg_template(images, training, hps): 60 | images_standardized = per_image_standardization(images) 61 | logits, _ = vgg_16(images_standardized, num_classes=hps.num_classes, is_training=training, dataset=hps.dataset) 62 | return logits 63 | 64 | 65 | def madry_template(images, training): 66 | model = MadryModel(images) 67 | return model.pre_softmax 68 | 69 | 70 | def aditi_template(images, training): 71 | model = AditiMNIST(images) 72 | return model.logits 73 | 74 | 75 | def zico_template(images, training): 76 | model = ZicoMNIST(images) 77 | return model.logits 78 | 79 | 80 | def evaluate(hps, data_X, data_y, eval_once=True): 81 | """Eval loop.""" 82 | images = tf.placeholder(tf.float32, shape=(None, args.image_size, args.image_size, args.channels)) 83 | 84 | labels_onehot = tf.placeholder(tf.int32, shape=(None, args.num_classes)) 85 | labels = tf.argmax(labels_onehot, axis=1) 86 | 87 | if args.classifier == "madry": 88 | net = tf.make_template('net', madry_template) 89 | logits = net(images, training=False) 90 | elif args.classifier == 'aditi': 91 | net = tf.make_template('net', aditi_template) 92 | logits = net(images, training=False) 93 | elif args.classifier == 'zico': 94 | net = tf.make_template('net', zico_template) 95 | logits = net(images, training=False) 96 | else: 97 | net = tf.make_template('net', resnet_template, hps=hps) if 
args.classifier == 'resnet' else \
98 |             tf.make_template('net', vgg_template, hps=hps)
99 |         logits = net(images, training=False)
100 |
101 |     pred = tf.argmax(logits, axis=1)
102 |     probs = tf.nn.softmax(logits)
103 |
104 |     cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels_onehot)
105 |     adv_image = fgm(images, tf.nn.softmax(logits), y=labels_onehot, eps=args.eps / 10, clip_min=0.0, clip_max=1.0)
106 |     top_5 = tf.nn.in_top_k(predictions=logits, targets=labels, k=5)
107 |
108 |     saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net'))
109 |     if args.classifier == 'madry' and not args.trained:
110 |         saver = tf.train.Saver(
111 |             {x.name[4:-2]: x for x in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="net")})
112 |
113 |     config = tf.ConfigProto(allow_soft_placement=True)
114 |     config.gpu_options.allow_growth = True
115 |     sess = tf.Session(config=config)
116 |
117 |     best_precision = 0.0
118 |     save_path, save_path_ckpt = get_weights_path(args)
119 |     while True:
120 |         try:
121 |             ckpt_state = tf.train.get_checkpoint_state(save_path)
122 |         except tf.errors.OutOfRangeError as e:
123 |             print('[!] Cannot restore checkpoint: %s' % e)
124 |             break
125 |         if not (ckpt_state and ckpt_state.model_checkpoint_path):
126 |             print('[!] No model to eval yet at %s' % save_path)
127 |             break
128 |         print('[*] Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
129 |         saver.restore(sess, ckpt_state.model_checkpoint_path)
130 |
131 |         total_prediction, correct_prediction = 0, 0
132 |         adv_prediction = 0
133 |         total_loss = 0
134 |         all_preds = []
135 |         batch_size = args.batch_size
136 |         num_batch = len(data_X) // batch_size
137 |         bad_images = []
138 |         bad_labels = []
139 |         confidences = []
140 |         adv_images = []
141 |         cls_preds = []
142 |         true_labels = []
143 |         for batch in range(num_batch):
144 |             x = data_X[batch * batch_size: (batch + 1) * batch_size]
145 |             x = x.astype(np.float32)
146 |             y = data_y[batch * batch_size: (batch + 1) * batch_size]
147 |             y = y.astype(np.int32)
148 |             if not args.top5:
149 |                 (loss, predictions, conf) = sess.run(
150 |                     [cost, pred, probs], feed_dict={
151 |                         images: x,
152 |                         labels_onehot: y
153 |                     })
154 |                 all_preds.extend(predictions)
155 |                 confidences.extend(conf[np.arange(conf.shape[0]), predictions])
156 |                 img_np = np.copy(x)
157 |                 for i in range(100):
158 |                     img_np = sess.run(adv_image, feed_dict={
159 |                         images: img_np,
160 |                         labels_onehot: y
161 |                     })
162 |                     img_np = np.clip(img_np, x - args.eps, x + args.eps)
163 |                     img_np = np.clip(img_np, 0.0, 1.0)
164 |                 adv_images.extend(img_np)
165 |
166 |                 adv_pred_np = pred.eval(session=sess, feed_dict={
167 |                     images: img_np,
168 |                     labels_onehot: y
169 |                 })
170 |
171 |                 cls_preds.extend(adv_pred_np)
172 |                 true_labels.extend(np.argmax(y, axis=1))
173 |             else:
174 |                 (loss, in_top5) = sess.run(
175 |                     [cost, top_5], feed_dict={
176 |                         images: x,
177 |                         labels_onehot: y
178 |                     }
179 |                 )
180 |             total_loss += np.sum(loss)
181 |             y = np.argmax(y, axis=1)
182 |             correct_prediction += np.sum(y == predictions) if not args.top5 else np.sum(in_top5)
183 |             bad_images.extend(x[y != predictions] if not args.top5 else [])  # predictions only exists in the top-1 branch
184 |             bad_labels.extend(predictions[y != predictions] if not args.top5 else [])
185 |             adv_prediction += np.sum(y == adv_pred_np) if not args.top5 else 0
186 |             total_prediction += loss.shape[0]
187 |
188 |         precision = 1.0 * correct_prediction / total_prediction
189 |         loss = 1.0 * total_loss / total_prediction
190 |         best_precision = max(precision, best_precision)
191 |         average_conf = np.mean(np.asarray(confidences)) if confidences else 0.0  # confidences is empty in --top5 mode
192 |         adv_images = np.asarray(adv_images)
193 |         cls_preds
= np.asarray(cls_preds) 194 | true_labels = np.asarray(true_labels) 195 | 196 | if not args.top5: 197 | print('[*] loss: %.6f, precision: %.6f, PGD precision: %.6f, Confidence: %.6f' % 198 | (loss, precision, adv_prediction / total_prediction, average_conf)) 199 | folder_format = '/atlas/u/yangsong/generative_adversary/{}_{}_pgd/' 200 | np.savez(os.path.join(check_folder(folder_format.format(args.dataset, args.classifier)), 201 | 'eps_{:.3f}.npz'.format(args.eps)), 202 | adv_images=adv_images, cls_preds=cls_preds, true_labels=true_labels) 203 | else: 204 | print('[*] loss: %.6f, top 5 accuracy: %.6f, best top 5 accuracy: %.6f' % 205 | (loss, precision, best_precision)) 206 | 207 | bad_images = np.asarray(bad_images) 208 | bad_images = (255. * bad_images).astype(np.uint8) 209 | bad_labels = np.asarray(bad_labels).astype(np.uint8) 210 | 211 | if len(bad_images) > 10: 212 | bad_images = bad_images[:10] 213 | bad_labels = bad_labels[:10] 214 | 215 | bad_images = np.reshape(bad_images, (len(bad_images) * args.image_size, args.image_size, args.channels)) 216 | bad_images = np.squeeze(bad_images) 217 | 218 | imsave(os.path.join(check_folder('tmp'), 'bad_images.png'), bad_images) 219 | print("bad_labels:\n{}".format(bad_labels)) 220 | 221 | if eval_once: 222 | break 223 | 224 | time.sleep(60) 225 | 226 | 227 | def untargeted_attack(hps, lambda1, lambda2, source, noise=False): 228 | """generative adversarial attack""" 229 | 230 | source_np = np.asarray([source] * args.batch_size).astype(np.int32) 231 | if args.classifier == "madry": 232 | net = tf.make_template('net', madry_template) 233 | elif args.classifier == 'aditi': 234 | net = tf.make_template('net', aditi_template) 235 | elif args.classifier == 'zico': 236 | net = tf.make_template('net', zico_template) 237 | else: 238 | net = tf.make_template('net', resnet_template, hps=hps) if args.classifier == 'resnet' else \ 239 | tf.make_template('net', vgg_template, hps=hps) 240 | 241 | adv_noise = tf.get_variable('adv_noise', shape=(args.batch_size, args.image_size, args.image_size, args.channels), 242 | dtype=tf.float32, initializer=tf.zeros_initializer) 243 | adv_z = tf.get_variable('adv_z', 244 | shape=(args.batch_size, args.z_dim), 245 | dtype=tf.float32, 246 | initializer=tf.random_normal_initializer) 247 | 248 | ref_z = tf.get_variable('ref_z', 249 | shape=(args.batch_size, args.z_dim), 250 | dtype=tf.float32, 251 | initializer=tf.zeros_initializer) 252 | 253 | config = tf.ConfigProto(allow_soft_placement=True) 254 | config.gpu_options.allow_growth = True 255 | sess = tf.Session(config=config) 256 | if args.dataset == 'mnist': 257 | dim_D = 32 258 | dim_G = 32 259 | elif args.dataset == 'svhn': 260 | dim_D = 128 261 | dim_G = 128 262 | elif args.dataset == 'celebA': 263 | dim_D = 64 264 | dim_G = 64 265 | 266 | acgan = ACWGAN_GP( 267 | sess, 268 | epoch=10, 269 | batch_size=args.batch_size, 270 | z_dim=args.z_dim, 271 | dataset_name=args.dataset, 272 | checkpoint_dir=args.checkpoint_dir, 273 | result_dir=args.result_dir, 274 | log_dir=args.log_dir, 275 | dim_D=dim_D, 276 | dim_G=dim_G 277 | ) 278 | 279 | acgan.build_model() 280 | 281 | adv_images = acgan.generator(adv_z, source_np, reuse=True) 282 | 283 | _, acgan_logits = acgan.discriminator(adv_images, update_collection=None, reuse=True) 284 | acgan_pred = tf.argmax(acgan_logits, axis=1) 285 | 286 | if noise: 287 | adv_images += args.eps * tf.nn.tanh(adv_noise) 288 | if args.dataset in ('svhn', 'celebA'): 289 | adv_images = tf.clip_by_value(adv_images, clip_value_min=-1., clip_value_max=1.0) 
290 | else: 291 | adv_images = tf.clip_by_value(adv_images, clip_value_min=0., clip_value_max=1.) 292 | 293 | net_logits = net(adv_images, training=False) 294 | net_softmax = tf.nn.softmax(net_logits) 295 | net_pred = tf.argmax(net_logits, axis=1) 296 | 297 | # loop over all classes 298 | obj_classes = [] 299 | for i in range(args.num_classes): 300 | if i == source: 301 | continue 302 | onehot = np.zeros((args.batch_size, args.num_classes), dtype=np.float32) 303 | onehot[:, i] = 1.0 304 | obj_classes.append(tf.nn.softmax_cross_entropy_with_logits(logits=net_logits, labels=onehot)) 305 | 306 | all_cross_entropy = tf.stack(obj_classes, axis=1) 307 | min_cross_entropy = tf.reduce_mean(tf.reduce_min(all_cross_entropy, axis=1)) 308 | 309 | obj = min_cross_entropy + \ 310 | lambda1 * tf.reduce_mean(tf.maximum(tf.square(ref_z - adv_z) - args.z_eps ** 2, 0.0)) + \ 311 | lambda2 * tf.reduce_mean( 312 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=acgan_logits, labels=source_np)) 313 | 314 | _iter = tf.placeholder(tf.float32, shape=(), name="iter") 315 | with tf.variable_scope("train_ops"): 316 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.lr) 317 | var = 0.01 / (1. + _iter) ** 0.55 318 | if noise: 319 | grads = optimizer.compute_gradients(obj, var_list=[adv_z, adv_noise]) 320 | else: 321 | grads = optimizer.compute_gradients(obj, var_list=[adv_z]) 322 | 323 | new_grads = [] 324 | for grad, v in grads: 325 | if v is not adv_noise: 326 | new_grads.append((grad + tf.random_normal(shape=grad.get_shape().as_list(), stddev=tf.sqrt(var)), v)) 327 | else: 328 | new_grads.append((grad / tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3], keep_dims=True)), v)) 329 | 330 | adv_op = optimizer.apply_gradients(new_grads) 331 | 332 | momentum_init = tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='train_ops')) 333 | init_op = tf.group(momentum_init, tf.variables_initializer([adv_z, adv_noise])) 334 | with tf.control_dependencies([init_op]): 335 | init_op = tf.group(init_op, tf.assign(ref_z, adv_z)) 336 | 337 | sess.run(tf.global_variables_initializer()) 338 | # load classifier 339 | save_path, save_path_ckpt = get_weights_path(args) 340 | if args.classifier == 'madry': 341 | if args.trained: 342 | saver4classifier = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="net")) 343 | else: 344 | saver4classifier = tf.train.Saver( 345 | {x.name[4:-2]: x for x in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="net")}) 346 | else: 347 | saver4classifier = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net')) 348 | 349 | checkpoint_dir = os.path.join(args.checkpoint_dir, acgan.model_dir, acgan.model_name) 350 | try: 351 | ckpt_state = tf.train.get_checkpoint_state(save_path) 352 | except tf.errors.OutOfRangeError as e: 353 | print('[!] Cannot restore checkpoint: %s' % e) 354 | if not (ckpt_state and ckpt_state.model_checkpoint_path): 355 | print('[!] No model to eval yet at %s' % save_path) 356 | print('[*] Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 357 | saver4classifier.restore(sess, ckpt_state.model_checkpoint_path) 358 | # load ACGAN 359 | saver4gen = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 360 | scope='generator|discriminator|classifier')) 361 | try: 362 | ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir) 363 | except tf.errors.OutOfRangeError as e: 364 | print('[!] 
Cannot restore checkpoint: %s' % e) 365 | if not (ckpt_state and ckpt_state.model_checkpoint_path): 366 | print('[!] No model to eval yet at %s' % checkpoint_dir) 367 | print('[*] Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 368 | saver4gen.restore(sess, ckpt_state.model_checkpoint_path) 369 | 370 | acc = 0. 371 | adv_acc = 0. 372 | adv_im_np = [] 373 | adv_labels_np = [] 374 | latent_z = [] 375 | for batch in range(args.n_adv_examples): 376 | # random ground truth classes 377 | sess.run(init_op) 378 | preds_np, probs_np, im_np, cost_before = sess.run([net_pred, net_softmax, adv_images, obj]) 379 | 380 | ###### Using GD for attacking 381 | # initialize optimizers 382 | for i in range(args.n_iters): 383 | _, now_cost, pred_np, acgan_pred_np = sess.run([adv_op, obj, net_pred, acgan_pred], feed_dict={_iter: i}) 384 | ok = np.logical_and(pred_np != source, acgan_pred_np == source) 385 | print(" [*] {}th iter, cost: {}, success: {}/{}".format(i + 1, now_cost, np.sum(ok), args.batch_size)) 386 | 387 | adv_preds_np, adv_probs_np, im_np, cost_after, hidden_z, acgan_pred_np = sess.run( 388 | [net_pred, net_softmax, adv_images, obj, adv_z, acgan_pred]) 389 | acc += np.sum(preds_np == source) 390 | adv_acc += np.sum(adv_preds_np == source) 391 | idx = np.logical_and(adv_preds_np != source, acgan_pred_np == source) 392 | adv_im_np.extend(im_np[idx]) 393 | adv_labels_np.extend(adv_preds_np[idx]) 394 | latent_z.extend(hidden_z[idx]) 395 | print("batch: {}, acc: {}, adv_acc: {}, num collected: {}, cost before: {}, cost after: {}". 396 | format(batch + 1, 397 | acc / ((batch + 1) * args.batch_size), 398 | adv_acc / ((batch + 1) * args.batch_size), 399 | len(adv_im_np), cost_before, cost_after)) 400 | 401 | if len(adv_im_np) >= args.n2collect: 402 | adv_im_np = np.asarray(adv_im_np) 403 | adv_labels_np = np.asarray(adv_labels_np) 404 | latent_z = np.asarray(latent_z) 405 | classifier = args.classifier 406 | if args.adv: 407 | classifier += '_adv' 408 | 409 | folder_format = '{}_{}_untargeted_attack' 410 | if noise: folder_format += '_noise' 411 | np.savez(os.path.join(check_folder(folder_format.format(args.dataset, classifier)), 412 | 'source_{}'.format(source)), adv_labels=adv_labels_np, adv_imgs=adv_im_np, 413 | latent_z=latent_z) 414 | size = int(np.sqrt(args.n2collect)) 415 | write_labels(adv_labels_np, args.dataset, size) 416 | img = label_images(adv_im_np[:args.n2collect, ...], adv_labels_np[:args.n2collect]) 417 | save_images(img, [size, size], os.path.join( 418 | check_folder(folder_format.format(args.dataset, classifier)), 's_{}_ims.png').format(args.source)) 419 | break 420 | 421 | 422 | def targeted_attack(hps, source, target, lambda1, lambda2, noise=False): 423 | """targeted generative adversarial attack""" 424 | 425 | source_np = np.asarray([source] * args.batch_size).astype(np.int32) 426 | target_np = np.asarray([target] * args.batch_size).astype(np.int32) 427 | 428 | if args.classifier == "madry": 429 | net = tf.make_template('net', madry_template) 430 | elif args.classifier == 'aditi': 431 | net = tf.make_template('net', aditi_template) 432 | elif args.classifier == 'zico': 433 | net = tf.make_template('net', zico_template) 434 | else: 435 | net = tf.make_template('net', resnet_template, hps=hps) if args.classifier == 'resnet' else \ 436 | tf.make_template('net', vgg_template, hps=hps) 437 | 438 | adv_noise = tf.get_variable('adv_noise', shape=(args.batch_size, args.image_size, args.image_size, args.channels), 439 | dtype=tf.float32, initializer=tf.zeros_initializer) 440 | 
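# NOTE: the attack optimizes three variables. adv_z is the generator latent
# being searched; adv_noise is an optional image-space perturbation that the
# graph keeps inside [-eps, eps] via args.eps * tanh(adv_noise); ref_z caches
# the initial latent so the objective can penalize
# max((adv_z - ref_z)**2 - z_eps**2, 0) per coordinate, softly constraining
# each coordinate of adv_z to stay within z_eps of where it started.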
adv_z = tf.get_variable('adv_z', 441 | shape=(args.batch_size, args.z_dim), 442 | dtype=tf.float32, 443 | initializer=tf.random_normal_initializer) 444 | 445 | ref_z = tf.get_variable('ref_z', 446 | shape=(args.batch_size, args.z_dim), 447 | dtype=tf.float32, 448 | initializer=tf.zeros_initializer) 449 | 450 | config = tf.ConfigProto(allow_soft_placement=True) 451 | config.gpu_options.allow_growth = True 452 | sess = tf.Session(config=config) 453 | if args.dataset == 'mnist': 454 | dim_D = 32 455 | dim_G = 32 456 | elif args.dataset == 'svhn': 457 | dim_D = 128 458 | dim_G = 128 459 | elif args.dataset == 'celebA': 460 | dim_D = 64 461 | dim_G = 64 462 | 463 | acgan = ACWGAN_GP( 464 | sess, 465 | epoch=10, 466 | batch_size=args.batch_size, 467 | z_dim=args.z_dim, 468 | dataset_name=args.dataset, 469 | checkpoint_dir=args.checkpoint_dir, 470 | result_dir=args.result_dir, 471 | log_dir=args.log_dir, 472 | dim_D=dim_D, 473 | dim_G=dim_G 474 | ) 475 | 476 | acgan.build_model() 477 | 478 | adv_images = acgan.generator(adv_z, source_np, reuse=True) 479 | _, acgan_logits = acgan.discriminator(adv_images, update_collection=None, reuse=True) 480 | acgan_pred = tf.argmax(acgan_logits, axis=1) 481 | acgan_softmax = tf.nn.softmax(acgan_logits) 482 | 483 | if noise: 484 | adv_images += args.eps * tf.tanh(adv_noise) 485 | if args.dataset in ('svhn', 'celebA'): 486 | adv_images = tf.clip_by_value(adv_images, clip_value_min=-1., clip_value_max=1.) 487 | else: 488 | adv_images = tf.clip_by_value(adv_images, clip_value_min=0., clip_value_max=1.0) 489 | 490 | net_logits = net(adv_images, training=False) 491 | net_softmax = tf.nn.softmax(net_logits) 492 | net_pred = tf.argmax(net_logits, axis=1) 493 | 494 | obj = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net_logits, labels=target_np)) + \ 495 | lambda1 * tf.reduce_mean(tf.maximum(tf.square(ref_z - adv_z) - args.z_eps ** 2, 0.0)) + \ 496 | lambda2 * tf.reduce_mean( 497 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=acgan_logits, labels=source_np)) 498 | 499 | _iter = tf.placeholder(tf.float32, shape=(), name="iter") 500 | with tf.variable_scope("train_ops"): 501 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.lr) 502 | var = 0.01 / (1. 
+ _iter) ** 0.55 503 | if noise: 504 | grads = optimizer.compute_gradients(obj, var_list=[adv_z, adv_noise]) 505 | else: 506 | grads = optimizer.compute_gradients(obj, var_list=[adv_z]) 507 | 508 | new_grads = [] 509 | for grad, v in grads: 510 | if v is not adv_noise: 511 | new_grads.append((grad + tf.random_normal(shape=grad.get_shape().as_list(), stddev=tf.sqrt(var)), v)) 512 | else: 513 | new_grads.append((grad / tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3], keep_dims=True)), v)) 514 | 515 | adv_op = optimizer.apply_gradients(new_grads) 516 | 517 | momentum_init = tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='train_ops')) 518 | init_op = tf.group(momentum_init, tf.variables_initializer([adv_z, adv_noise])) 519 | with tf.control_dependencies([init_op]): 520 | init_op = tf.group(init_op, tf.assign(ref_z, adv_z)) 521 | 522 | sess.run(tf.global_variables_initializer()) 523 | # load classifier 524 | save_path, save_path_ckpt = get_weights_path(args) 525 | if args.classifier == 'madry': 526 | if args.trained: 527 | saver4classifier = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="net")) 528 | else: 529 | saver4classifier = tf.train.Saver( 530 | {x.name[4:-2]: x for x in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="net")}) 531 | else: 532 | saver4classifier = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net')) 533 | checkpoint_dir = os.path.join(args.checkpoint_dir, acgan.model_dir, acgan.model_name) 534 | try: 535 | ckpt_state = tf.train.get_checkpoint_state(save_path) 536 | except tf.errors.OutOfRangeError as e: 537 | print('[!] Cannot restore checkpoint: %s' % e) 538 | if not (ckpt_state and ckpt_state.model_checkpoint_path): 539 | print('[!] No model to eval yet at %s' % save_path) 540 | print('[*] Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 541 | saver4classifier.restore(sess, ckpt_state.model_checkpoint_path) 542 | # load ACGAN 543 | 544 | saver4gen = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 545 | scope='generator|discriminator|classifier')) 546 | try: 547 | ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir) 548 | except tf.errors.OutOfRangeError as e: 549 | print('[!] Cannot restore checkpoint: %s' % e) 550 | if not (ckpt_state and ckpt_state.model_checkpoint_path): 551 | print('[!] No model to eval yet at %s' % checkpoint_dir) 552 | print('[*] Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 553 | saver4gen.restore(sess, ckpt_state.model_checkpoint_path) 554 | 555 | acc = 0. 556 | adv_acc = 0. 
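# NOTE: in the collection loop below, a sample counts as a successful
# targeted attack only when the victim classifier predicts `target` while the
# AC-GAN auxiliary classifier still labels the image as `source`:
#     idx = np.logical_and(adv_preds_np == target, acgan_preds_np == source)
# Accepted images and their (initial) latents accumulate until args.n2collect
# of them have been gathered, at which point everything is saved to an .npz.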
557 | adv_im_np = [] 558 | latent_z = [] 559 | init_latent_z = [] 560 | for batch in range(args.n_adv_examples): 561 | ###### random ground truth classes 562 | sess.run(init_op) 563 | preds_np, probs_np, im_np, cost_before = sess.run([net_pred, net_softmax, adv_images, obj]) 564 | 565 | ###### Using GD for attacking 566 | # initialize optimizers 567 | for i in range(args.n_iters): 568 | _, now_cost, pred_np, acgan_pred_np, acgan_probs = sess.run( 569 | [adv_op, obj, net_pred, acgan_pred, acgan_softmax], 570 | feed_dict={_iter: i}) 571 | ok = np.logical_and(pred_np == target, acgan_pred_np == source) 572 | print(" [*] {}th iter, cost: {}, success: {}/{}".format(i + 1, now_cost, np.sum(ok), args.batch_size)) 573 | 574 | adv_preds_np, acgan_preds_np, adv_probs_np, acgan_probs_np, im_np, hidden_z, init_z, cost_after = sess.run( 575 | [net_pred, acgan_pred, 576 | net_softmax, acgan_softmax, adv_images, adv_z, ref_z, obj]) 577 | acc += np.sum(preds_np == source) 578 | idx = np.logical_and(adv_preds_np == target, acgan_preds_np == source) 579 | adv_acc += np.sum(idx) 580 | adv_im_np.extend(im_np[idx]) 581 | latent_z.extend(hidden_z[idx]) 582 | init_latent_z.extend(init_z[idx]) 583 | print("batch: {}, acc: {}, adv_acc: {}, num collected: {}, cost before: {}, cost after: {}". 584 | format(batch + 1, acc / ((batch + 1) * args.batch_size), adv_acc / ((batch + 1) * args.batch_size), 585 | len(adv_im_np), cost_before, cost_after)) 586 | 587 | if len(adv_im_np) >= args.n2collect: 588 | adv_im_np = np.asarray(adv_im_np) 589 | latent_z = np.asarray(latent_z) 590 | size = int(np.sqrt(args.n2collect)) 591 | classifier = args.classifier 592 | if args.adv: 593 | classifier += '_adv' 594 | 595 | folder_format = '{}_{}_targeted_attack_with_z0' 596 | if noise: folder_format += '_noise' 597 | np.savez(os.path.join(check_folder(folder_format.format(args.dataset, classifier)), 598 | 'from{}to{}'.format(source, target)), adv_imgs=adv_im_np, latent_z=latent_z, 599 | init_latent_z=init_latent_z) 600 | save_images(adv_im_np[:args.n2collect, :, :, :], [size, size], 601 | os.path.join(check_folder(folder_format.format(args.dataset, classifier)), 602 | '{}_ims_from{}_to{}.png').format(args.dataset, source, target)) 603 | break 604 | 605 | 606 | def main(): 607 | np.random.seed(args.seed) 608 | tf.set_random_seed(args.seed) 609 | 610 | if args.dataset == 'mnist': 611 | num_classes = 10 612 | args.num_classes = 10 613 | args.image_size = 28 614 | args.channels = 1 615 | 616 | if args.mode == 'eval': 617 | data_X, data_y, test_X, test_y = load_mnist4classifier(args.dataset) 618 | 619 | elif args.dataset == 'svhn': 620 | num_classes = 10 621 | args.num_classes = 10 622 | args.image_size = 32 623 | args.channels = 3 624 | 625 | if args.mode == 'eval': 626 | data_X, data_y, test_X, test_y = load_svhn4classifier() 627 | 628 | elif args.dataset == 'celebA': 629 | num_classes = 2 630 | args.num_classes = 2 631 | args.image_size = 64 632 | args.channels = 3 633 | 634 | if args.mode == 'eval': 635 | data_X, data_y, test_X, test_y = load_celebA4classifier() 636 | 637 | else: 638 | raise NotImplementedError("Dataset {} not supported!".format(args.dataset)) 639 | 640 | print("[*] input args:\n", json.dumps(vars(args), indent=4, separators=(',', ':'))) 641 | 642 | num_residual_units = 5 643 | hps = resnet_model.HParams(batch_size=args.batch_size, 644 | num_classes=num_classes, 645 | min_lrn_rate=0.0001, 646 | lrn_rate=0.1, 647 | num_residual_units=num_residual_units, 648 | use_bottleneck=False, 649 | weight_decay_rate=0.0002, 650 | 
relu_leakiness=0.1, 651 | optimizer='mom', 652 | dataset=args.dataset) 653 | 654 | if args.mode == 'eval': 655 | evaluate(hps, test_X, test_y) 656 | elif args.mode == 'targeted_attack': 657 | targeted_attack(hps, args.source, args.target, args.lambda1, args.lambda2, noise=args.noise) 658 | elif args.mode == 'untargeted_attack': 659 | untargeted_attack(hps, args.lambda1, args.lambda2, source=args.source, noise=args.noise) 660 | else: 661 | raise NotImplementedError("No modes other than eval and attack!") 662 | 663 | 664 | if __name__ == '__main__': 665 | main() 666 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/generative_adversary/d276718254cbf40d1951d3f26e0ea43e7772af73/models/__init__.py -------------------------------------------------------------------------------- /models/acwgan_gp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | 4 | from utils import * 5 | from models.libs.resnet_ops import * 6 | from models.libs.ops import linear, gan_batch_norm 7 | 8 | 9 | class ACWGAN_GP(object): 10 | model_name = "ACWGAN_GP" # name for checkpoint 11 | 12 | def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir, 13 | dim_G=128, dim_D=128, learning_rate=None): 14 | self.sess = sess 15 | self.dataset_name = dataset_name 16 | self.checkpoint_dir = checkpoint_dir 17 | self.result_dir = result_dir 18 | self.log_dir = log_dir 19 | self.epoch = epoch 20 | self.batch_size = batch_size 21 | self.model_name = self.model_name 22 | self.dim_G = dim_G 23 | self.dim_D = dim_D 24 | 25 | if dataset_name == 'mnist': 26 | # parameters 27 | self.input_height = 28 28 | self.input_width = 28 29 | self.output_height = 28 30 | self.output_width = 28 31 | self.n_iters = 100000 32 | 33 | self.z_dim = z_dim # dimension of noise-vector 34 | self.y_dim = 10 35 | self.c_dim = 1 36 | 37 | # WGAN_GP parameter 38 | self.lambd = 10 # The higher value, the more stable, but the slower convergence 39 | self.disc_iters = 5 # The number of critic iterations for one-step of generator 40 | 41 | # train 42 | self.learning_rate = 0.0002 if learning_rate is None else learning_rate 43 | self.beta1 = 0.0 44 | self.beta2 = 0.9 45 | 46 | # test 47 | self.sample_num = 64 # number of generated images to be saved 48 | 49 | # code 50 | self.len_discrete_code = 10 # categorical distribution (i.e. label) 51 | self.len_continuous_code = 2 # gaussian distribution (e.g. 
rotation, thickness) 52 | 53 | # load mnist 54 | self.data_X, self.data_y = load_mnist(self.dataset_name) 55 | 56 | # get number of batches for a single epoch 57 | self.num_batches = len(self.data_X) // (self.batch_size * self.disc_iters) 58 | 59 | elif dataset_name == 'svhn': 60 | self.input_height = 32 61 | self.input_width = 32 62 | self.output_height = 32 63 | self.output_width = 32 64 | 65 | self.z_dim = z_dim # dimension of noise-vector 66 | self.c_dim = 3 67 | self.n_iters = 100000 68 | 69 | # WGAN_GP parameter 70 | self.lambd = 10 # The higher value, the more stable, but the slower convergence 71 | self.disc_iters = 5 # The number of critic iterations for one-step of generator 72 | # train 73 | self.beta1 = 0.0 74 | self.beta2 = 0.9 75 | 76 | # test 77 | self.sample_num = 64 # number of generated images to be saved 78 | 79 | # code 80 | self.len_continuous_code = 2 # gaussian distribution (e.g. rotation, thickness) 81 | 82 | # load svhn 83 | self.y_dim = 10 84 | self.len_discrete_code = 10 # categorical distribution (i.e. label) 85 | self.learning_rate = 0.0002 if learning_rate is None else learning_rate 86 | self.data_X, self.data_y = load_svhn() 87 | 88 | # get number of batches for a single epoch 89 | self.num_batches = len(self.data_X) // (self.batch_size * self.disc_iters) 90 | 91 | elif dataset_name == 'celebA': 92 | self.input_height = 64 93 | self.input_width = 64 94 | self.output_height = 64 95 | self.output_width = 64 96 | 97 | self.z_dim = z_dim 98 | self.y_dim = 2 99 | self.c_dim = 3 100 | self.n_iters = 200000 101 | 102 | self.lambd = 10 103 | self.disc_iters = 5 104 | self.learning_rate = 0.0001 if learning_rate is None else learning_rate 105 | 106 | self.beta1 = 0.0 107 | self.beta2 = 0.9 108 | 109 | self.sample_num = 64 110 | 111 | self.len_discrete_code = 2 112 | self.len_continuous_code = 2 113 | 114 | self.data_X, self.data_y = load_celebA() 115 | 116 | self.num_batches = len(self.data_X) // (self.batch_size * self.disc_iters) 117 | else: 118 | raise NotImplementedError 119 | 120 | 121 | def discriminator(self, x, update_collection, reuse=False): 122 | with tf.variable_scope("discriminator", reuse=reuse): 123 | output = tf.reshape(x, [-1, self.output_height, self.output_width, self.c_dim]) 124 | if self.dataset_name in ('mnist', 'svhn'): 125 | output = ResidualBlockDisc(output, self.dim_D, spectral_normed=False, update_collection=update_collection, name="d_residual_block") 126 | output = ResidualBlock(output, None, self.dim_D, 3, resample='down', spectral_normed=False, 127 | update_collection=update_collection, name='d_res1') 128 | output = ResidualBlock(output, None, self.dim_D, 3, resample=None, spectral_normed=False, 129 | update_collection=update_collection, name='d_res2') 130 | output = ResidualBlock(output, None, self.dim_D, 3, resample=None, spectral_normed=False, 131 | update_collection=update_collection, name='d_res3') 132 | output = tf.nn.relu(output) 133 | output = tf.reduce_mean(output, axis=[1,2]) # global average pooling 134 | output_logits = linear(output, 1, spectral_normed=False, update_collection=update_collection, name='d_output') 135 | output_acgan = linear(output, self.y_dim, spectral_normed=False, update_collection=update_collection, 136 | name='d_acgan_output') 137 | elif self.dataset_name == 'celebA': 138 | output = conv2d(output, self.dim_D, 3, he_init=False) 139 | output = ResidualBlock_celebA(output, None, 2 * self.dim_D, 3, resample='down', name='d_res1') 140 | output = ResidualBlock_celebA(output, None, 4 * self.dim_D, 3, resample='down', 
name='d_res2') 141 | output = ResidualBlock_celebA(output, None, 8 * self.dim_D, 3, resample='down', name='d_res3') 142 | output = ResidualBlock_celebA(output, None, 8 * self.dim_D, 3, resample='down', name='d_res4') 143 | 144 | output = tf.reshape(output, [-1, 4 * 4 * 8 * self.dim_D]) 145 | output_logits = linear(output, 1, spectral_normed=False, update_collection=update_collection, name='d_output') 146 | output_acgan = linear(output, self.y_dim, spectral_normed=False, update_collection=update_collection, 147 | name='d_acgan_output') 148 | 149 | else: 150 | raise NotImplementedError("do not support dataset {}".format(self.dataset_name)) 151 | 152 | return output_logits, output_acgan 153 | 154 | def generator(self, z, y, reuse=False): 155 | with tf.variable_scope("generator", reuse=reuse): 156 | onehot = tf.one_hot(y, depth=self.y_dim, dtype=tf.float32) 157 | z = tf.concat([z, onehot], axis=1) 158 | if self.dataset_name == "mnist": 159 | output = linear(z, 7 * 7 * self.dim_G, name='g_fc1') 160 | output = tf.reshape(output, [-1, 7, 7, self.dim_G]) 161 | output = ResidualBlock(output, y, self.dim_G, 3, resample='up', name='g_res1', n_labels=self.y_dim) 162 | output = ResidualBlock(output, y, self.dim_G, 3, resample='up', name='g_res2', n_labels=self.y_dim) 163 | output = gan_batch_norm(output, name='g_out') 164 | output = tf.nn.relu(output) 165 | output = conv2d(output, 1, 3, he_init=False, name='g_final') 166 | output = tf.sigmoid(output) 167 | 168 | elif self.dataset_name == 'svhn': 169 | output = linear(z, 4 * 4 * self.dim_G, name='g_fc1') 170 | output = tf.reshape(output, [-1, 4, 4, self.dim_G]) 171 | output = ResidualBlock(output, y, self.dim_G, 3, resample='up', name='g_res1', n_labels=self.y_dim) 172 | output = ResidualBlock(output, y, self.dim_G, 3, resample='up', name='g_res2', n_labels=self.y_dim) 173 | output = ResidualBlock(output, y, self.dim_G, 3, resample='up', name='g_res3', n_labels=self.y_dim) 174 | output = gan_batch_norm(output, name='g_out') 175 | output = tf.nn.relu(output) 176 | output = conv2d(output, 3, 3, he_init=False, name='g_final') 177 | output = tf.tanh(output) 178 | 179 | elif self.dataset_name == 'celebA': 180 | output = linear(z, 4 * 4 * 8 * self.dim_G, name='g_fc1') 181 | output = tf.reshape(output, [-1, 4, 4, 8 * self.dim_G]) 182 | output = ResidualBlock_celebA(output, y, self.dim_G * 8, 3, resample='up', name='g_res1', n_labels=self.y_dim) 183 | output = ResidualBlock_celebA(output, y, self.dim_G * 4, 3, resample='up', name='g_res2', n_labels=self.y_dim) 184 | output = ResidualBlock_celebA(output, y, self.dim_G * 2, 3, resample='up', name='g_res3', n_labels=self.y_dim) 185 | output = ResidualBlock_celebA(output, y, self.dim_G, 3, resample='up', name='g_res4', n_labels=self.y_dim) 186 | output = gan_batch_norm(output, name='g_out') 187 | output = tf.nn.relu(output) 188 | output = conv2d(output, 3, 3, he_init=True, name='g_final') 189 | output = tf.tanh(output) 190 | 191 | return output 192 | 193 | def build_model(self): 194 | # some parameters 195 | image_dims = [self.input_height, self.input_width, self.c_dim] 196 | bs = self.batch_size 197 | 198 | """ Graph Input """ 199 | # images 200 | self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') 201 | self.inputs += tf.random_uniform(shape=self.inputs.get_shape().as_list(), minval=0., maxval=1/255.) 
# dequantize 202 | 203 | # noises 204 | self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') 205 | self.y = tf.placeholder(tf.int32, [bs], name='y') 206 | self.lr = tf.placeholder(tf.float32, (), name='lr') 207 | 208 | """ Loss Function """ 209 | 210 | # output of D for real images 211 | D_real_logits, code_real_logits = self.discriminator(self.inputs, reuse=False, update_collection='spectral_norm') 212 | 213 | # output of D for fake images 214 | G = self.generator(self.z, self.y, reuse=False) 215 | D_fake_logits, code_fake_logits = self.discriminator(G, reuse=True, update_collection='NO_OPS') 216 | 217 | # get loss for discriminator 218 | d_loss_real = - tf.reduce_mean(D_real_logits) 219 | d_loss_fake = tf.reduce_mean(D_fake_logits) 220 | acgan_loss_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=code_real_logits, labels=self.y)) 221 | acgan_loss_fake = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=code_fake_logits, labels=self.y)) 222 | self.acgan_loss_real = acgan_loss_real 223 | self.acgan_loss_fake = acgan_loss_fake 224 | self.acgan_real_acc = tf.reduce_mean(tf.to_float(tf.equal(tf.cast(tf.argmax(code_real_logits, axis=1), tf.int32), self.y))) 225 | self.acgan_fake_acc = tf.reduce_mean(tf.to_float(tf.equal(tf.cast(tf.argmax(code_fake_logits, axis=1), tf.int32), self.y))) 226 | 227 | self.d_loss = d_loss_real + d_loss_fake + acgan_loss_real 228 | 229 | # get loss for generator 230 | self.g_loss = - d_loss_fake + acgan_loss_fake 231 | 232 | self.update_op = tf.group(*tf.get_collection("spectral_norm")) 233 | 234 | """ Gradient Penalty """ 235 | alpha = tf.random_uniform(shape=self.inputs.get_shape(), minval=0., maxval=1.) 236 | differences = G - self.inputs # This is different from MAGAN 237 | interpolates = self.inputs + (alpha * differences) 238 | D_inter = self.discriminator(interpolates, reuse=True, update_collection='NO_OPS')[0] 239 | gradients = tf.gradients(D_inter, [interpolates])[0] 240 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])) 241 | gradient_penalty = tf.reduce_mean((slopes - 1.) 
** 2) 242 | self.d_loss += self.lambd * gradient_penalty 243 | 244 | """ Training """ 245 | # divide trainable variables into a group for D and a group for G 246 | d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 247 | g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 248 | 249 | # optimizers 250 | self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \ 251 | .minimize(self.d_loss, var_list=d_vars) 252 | self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1) \ 253 | .minimize(self.g_loss, var_list=g_vars) 254 | 255 | """ Testing """ 256 | # for test 257 | self.fake_images = self.generator(self.z, self.y, reuse=True) 258 | 259 | """ Summary """ 260 | d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real) 261 | d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake) 262 | d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) 263 | g_loss_sum = tf.summary.scalar("g_loss", self.g_loss) 264 | 265 | 266 | # final summary operations 267 | self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum]) 268 | self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum]) 269 | 270 | def train(self): 271 | # initialize all variables 272 | tf.global_variables_initializer().run() 273 | 274 | # graph inputs for visualizing training results 275 | self.sample_z = np.random.normal(size=(self.batch_size, self.z_dim)) 276 | self.test_codes = np.argmax(self.data_y[0:self.batch_size], axis=1) 277 | 278 | # saver to save model 279 | self.saver = tf.train.Saver() 280 | 281 | # summary writer 282 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 283 | 284 | # restore checkpoint if it exists 285 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 286 | if could_load: 287 | start_epoch = int(checkpoint_counter / self.num_batches) 288 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 289 | counter = checkpoint_counter 290 | print(" [*] Load SUCCESS") 291 | else: 292 | start_epoch = 0 293 | start_batch_id = 0 294 | counter = 0 295 | print(" [!] Load failed...") 296 | 297 | # loop for epoch 298 | start_time = time.time() 299 | for epoch in range(start_epoch, self.epoch): 300 | # get batch data 301 | random_state = np.random.get_state() 302 | np.random.shuffle(self.data_X) 303 | np.random.set_state(random_state) 304 | np.random.shuffle(self.data_y) 305 | 306 | for idx in range(start_batch_id, self.num_batches): 307 | decay = np.maximum(0.0, 1. - counter / (self.n_iters - 1)) if self.dataset_name != 'celebA' else 1. 308 | batch_images = self.data_X[idx * self.batch_size * self.disc_iters :(idx + 1) * self.batch_size * self.disc_iters] 309 | if self.dataset_name in ('svhn', 'celebA'): 310 | batch_images = (batch_images - 0.5) * 2. 
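# NOTE: the generator ends in tf.tanh for svhn/celebA but tf.sigmoid for
# mnist, so real images loaded in [0, 1] are rescaled to [-1, 1] here to
# match the generator's output range before both are fed to the discriminator.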
311 | batch_y = np.argmax(self.data_y[idx * self.batch_size * self.disc_iters : (idx + 1) * self.batch_size * self.disc_iters], axis=1) 312 | for i in range(self.disc_iters): 313 | this_input = batch_images[i * self.batch_size: (i+1) * self.batch_size] 314 | this_y = batch_y[i * self.batch_size: (i+1) * self.batch_size] 315 | batch_z = np.random.normal(size=[self.batch_size, self.z_dim]).astype(np.float32) 316 | _, summary_str, d_loss, real_acc, fake_acc, acgan_l1, acgan_l2 = self.sess.run([self.d_optim, self.d_sum, self.d_loss, 317 | self.acgan_real_acc, self.acgan_fake_acc, self.acgan_loss_real, self.acgan_loss_fake], 318 | feed_dict={self.inputs: this_input, 319 | self.z: batch_z, 320 | self.y: this_y, 321 | self.lr: self.learning_rate * decay}) 322 | # self.sess.run([self.update_op]) 323 | self.writer.add_summary(summary_str, counter) 324 | 325 | batch_z = np.random.normal(size=[self.batch_size, self.z_dim]).astype(np.float32) 326 | random_y = np.random.choice(self.y_dim, self.batch_size).astype(np.int32) 327 | 328 | 329 | _, summary_str_g, g_loss, acgan_l2 = self.sess.run( 330 | [self.g_optim, self.g_sum, self.g_loss, self.acgan_loss_fake], 331 | feed_dict={self.z: batch_z, 332 | self.y: random_y, 333 | self.lr: self.learning_rate * decay}) 334 | self.writer.add_summary(summary_str_g, counter) 335 | 336 | counter += 1 337 | 338 | # display training status 339 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 340 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 341 | print("ACGAN accuracy -- real: {}, fake: {}".format(real_acc, fake_acc)) 342 | 343 | # save training results for every 300 steps 344 | if np.mod(counter, 300) == 0: 345 | samples = self.sess.run(self.fake_images, 346 | feed_dict={self.z: self.sample_z, self.y: self.test_codes}) 347 | tot_num_samples = min(self.sample_num, self.batch_size) 348 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 349 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 350 | save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], 351 | './' + check_folder( 352 | self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format( 353 | epoch, idx)) 354 | 355 | if counter == self.n_iters: 356 | self.save(self.checkpoint_dir, counter) 357 | self.visualize_results(epoch) 358 | return 359 | 360 | # After an epoch, start_batch_id is set to zero 361 | # non-zero value is only for the first epoch after loading pre-trained model 362 | start_batch_id = 0 363 | 364 | # save model 365 | self.save(self.checkpoint_dir, counter) 366 | 367 | # show temporal results 368 | self.visualize_results(epoch) 369 | 370 | # save model for final step 371 | self.save(self.checkpoint_dir, counter) 372 | 373 | def visualize_results(self, epoch): 374 | tot_num_samples = min(self.sample_num, self.batch_size) 375 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 376 | z_sample = np.random.normal(size=(self.batch_size, self.z_dim)) 377 | 378 | """ random noise, random discrete code, fixed continuous code """ 379 | y = np.random.choice(self.len_discrete_code, self.batch_size) 380 | 381 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y}) 382 | 383 | save_images(samples[:image_frame_dim*image_frame_dim,:,:,:], [image_frame_dim, image_frame_dim], 384 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png') 385 | 386 | """ specified condition, random 
noise """ 387 | n_styles = 10 # must be less than or equal to self.batch_size 388 | 389 | np.random.seed() 390 | si = np.random.choice(self.batch_size, n_styles) 391 | 392 | for l in range(self.len_discrete_code): 393 | y = np.zeros(self.batch_size, dtype=np.int64) + l 394 | 395 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y}) 396 | save_images(samples[:image_frame_dim*image_frame_dim,:,:,:], [image_frame_dim, image_frame_dim], 397 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l) 398 | 399 | samples = samples[si, :, :, :] 400 | 401 | if l == 0: 402 | all_samples = samples 403 | else: 404 | all_samples = np.concatenate((all_samples, samples), axis=0) 405 | 406 | """ save merged images to check style-consistency """ 407 | canvas = np.zeros_like(all_samples) 408 | for s in range(n_styles): 409 | for c in range(self.len_discrete_code): 410 | canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :] 411 | 412 | save_images(canvas, [n_styles, self.len_discrete_code], 413 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png') 414 | 415 | @property 416 | def model_dir(self): 417 | return "{}_{}_{}_{}".format( 418 | self.model_name, self.dataset_name, 419 | self.batch_size, self.z_dim) 420 | 421 | def save(self, checkpoint_dir, step): 422 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 423 | 424 | if not os.path.exists(checkpoint_dir): 425 | os.makedirs(checkpoint_dir) 426 | 427 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 428 | 429 | def load(self, checkpoint_dir): 430 | import re 431 | print(" [*] Reading checkpoints...") 432 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 433 | 434 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 435 | if ckpt and ckpt.model_checkpoint_path: 436 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 437 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 438 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0)) 439 | print(" [*] Successfully read {}".format(ckpt_name)) 440 | return True, counter 441 | else: 442 | print(" [*] Failed to find a checkpoint") 443 | return False, 0 444 | -------------------------------------------------------------------------------- /models/aditi_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | The model is adapted from the tensorflow tutorial: 3 | https://www.tensorflow.org/get_started/mnist/pros 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | class AditiMNIST(object): 12 | def __init__(self, images): 13 | self.x_image = images 14 | x = tf.reshape(self.x_image, shape=(-1, 784)) 15 | W1 = tf.get_variable('W1', shape=(784, 500), dtype=tf.float32, initializer=tf.random_normal_initializer) 16 | B1 = tf.get_variable('B1', shape=(500,), dtype=tf.float32, initializer=tf.random_normal_initializer) 17 | W2 = tf.get_variable('W2', shape=(500, 10), dtype=tf.float32, initializer=tf.random_normal_initializer) 18 | B2 = tf.get_variable('B2', shape=(10,), dtype=tf.float32, initializer=tf.random_normal_initializer) 19 | 20 | y = x @ W1 + B1 21 | y = tf.nn.relu(y) 22 | 
self.logits = y @ W2 + B2 -------------------------------------------------------------------------------- /models/libs/ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from models.libs.sn import spectral_normed_weight 5 | 6 | 7 | def scope_has_variables(scope): 8 | return len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)) > 0 9 | 10 | 11 | def conv2d(input_, output_dim, 12 | k_h=4, k_w=4, d_h=2, d_w=2, stddev=None, 13 | name="conv2d", spectral_normed=False, update_collection=None, with_w=False, padding="SAME"): 14 | # Glorot intialization 15 | # For RELU nonlinearity, it's sqrt(2./(n_in)) instead 16 | fan_in = k_h * k_w * input_.get_shape().as_list()[-1] 17 | fan_out = k_h * k_w * output_dim 18 | if stddev is None: 19 | stddev = np.sqrt(2. / (fan_in)) 20 | 21 | with tf.variable_scope(name) as scope: 22 | if scope_has_variables(scope): 23 | scope.reuse_variables() 24 | w = tf.get_variable("w", [k_h, k_w, input_.get_shape()[-1], output_dim], 25 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 26 | if spectral_normed: 27 | conv = tf.nn.conv2d(input_, spectral_normed_weight(w, update_collection=update_collection), 28 | strides=[1, d_h, d_w, 1], padding=padding) 29 | else: 30 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) 31 | 32 | biases = tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 33 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 34 | 35 | if with_w: 36 | return conv, w, biases 37 | else: 38 | return conv 39 | 40 | 41 | def deconv2d(input_, output_shape, 42 | k_h=4, k_w=4, d_h=2, d_w=2, stddev=None, 43 | name="deconv2d", spectral_normed=False, update_collection=None, with_w=False, padding="SAME"): 44 | # Glorot initialization 45 | # For RELU nonlinearity, it's sqrt(2./(n_in)) instead 46 | fan_in = k_h * k_w * input_.get_shape().as_list()[-1] 47 | fan_out = k_h * k_w * output_shape[-1] 48 | if stddev is None: 49 | stddev = np.sqrt(2. / (fan_in)) 50 | 51 | with tf.variable_scope(name) as scope: 52 | if scope_has_variables(scope): 53 | scope.reuse_variables() 54 | # filter : [height, width, output_channels, in_channels] 55 | w = tf.get_variable("w", [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 56 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 57 | if spectral_normed: 58 | deconv = tf.nn.conv2d_transpose(input_, spectral_normed_weight(w, update_collection=update_collection), 59 | output_shape=output_shape, 60 | strides=[1, d_h, d_w, 1], padding=padding) 61 | else: 62 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 63 | strides=[1, d_h, d_w, 1], padding=padding) 64 | 65 | biases = tf.get_variable("b", [output_shape[-1]], initializer=tf.constant_initializer(0)) 66 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 67 | if with_w: 68 | return deconv, w, biases 69 | else: 70 | return deconv 71 | 72 | 73 | def lrelu(x, leak=0.1): 74 | return tf.maximum(x, leak * x) 75 | 76 | 77 | def linear(input_, output_size, name="linear", spectral_normed=False, update_collection=None, stddev=None, 78 | bias_start=0.0, with_biases=True, 79 | with_w=False): 80 | shape = input_.get_shape().as_list() 81 | 82 | if stddev is None: 83 | stddev = np.sqrt(1. 
/ (shape[1])) 84 | with tf.variable_scope(name) as scope: 85 | if scope_has_variables(scope): 86 | scope.reuse_variables() 87 | weight = tf.get_variable("w", [shape[1], output_size], tf.float32, 88 | tf.truncated_normal_initializer(stddev=stddev)) 89 | if with_biases: 90 | bias = tf.get_variable("b", [output_size], 91 | initializer=tf.constant_initializer(bias_start)) 92 | if spectral_normed: 93 | mul = tf.matmul(input_, spectral_normed_weight(weight, update_collection=update_collection)) 94 | else: 95 | mul = tf.matmul(input_, weight) 96 | if with_w: 97 | if with_biases: 98 | return mul + bias, weight, bias 99 | else: 100 | return mul, weight, None 101 | else: 102 | if with_biases: 103 | return mul + bias 104 | else: 105 | return mul 106 | 107 | 108 | def batch_norm(input, is_training=True, momentum=0.9, epsilon=2e-5, in_place_update=True, name="batch_norm"): 109 | if in_place_update: 110 | return tf.contrib.layers.batch_norm(input, 111 | decay=momentum, 112 | center=True, 113 | scale=True, 114 | epsilon=epsilon, 115 | updates_collections=None, 116 | is_training=is_training, 117 | scope=name) 118 | else: 119 | return tf.contrib.layers.batch_norm(input, 120 | decay=momentum, 121 | center=True, 122 | scale=True, 123 | epsilon=epsilon, 124 | is_training=is_training, 125 | scope=name) 126 | 127 | 128 | def gan_cond_batch_norm(input, labels, n_labels=10, name="cond_batch_norm"): 129 | ''' 130 | Batch normalization in GANs is different. Do not use running statistics for testing 131 | ''' 132 | mean, var = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=True) 133 | shape = mean.get_shape().as_list() 134 | offset_m = tf.get_variable(name+'_offset', shape=(n_labels, shape[-1]), dtype=tf.float32, 135 | initializer=tf.zeros_initializer) 136 | scale_m = tf.get_variable(name+'_scale', shape=(n_labels, shape[-1]), dtype=tf.float32, 137 | initializer=tf.ones_initializer) 138 | offset = tf.nn.embedding_lookup(offset_m, labels) 139 | scale = tf.nn.embedding_lookup(scale_m, labels) 140 | 141 | result = tf.nn.batch_normalization(input, mean, var, offset[:, None, None, :], scale[:, None, None, :], 1e-5) 142 | return result 143 | 144 | def gan_batch_norm(input, name="batch_norm"): 145 | ''' 146 | Batch normalization in GANs is different. 
Do not use running statistics for testing 147 | ''' 148 | mean, var = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=True) 149 | shape = mean.get_shape().as_list() 150 | offset = tf.get_variable(name+'_offset', shape=shape, dtype=tf.float32, initializer=tf.zeros_initializer) 151 | scale = tf.get_variable(name+'_scale', shape=shape, dtype=tf.float32, initializer=tf.ones_initializer) 152 | result = tf.nn.batch_normalization(input, mean, var, offset, scale, 1e-5) 153 | return result 154 | -------------------------------------------------------------------------------- /models/libs/resnet_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from models.libs.sn import spectral_normed_weight 4 | from models.libs.ops import gan_cond_batch_norm 5 | from functools import partial 6 | 7 | def conv2d(input_, output_dim, filter_size, stddev=None, 8 | name="conv2d", spectral_normed=False, update_collection=None, with_w=False, he_init=True, padding="SAME"): 9 | # Glorot intialization 10 | # For RELU nonlinearity, it's sqrt(2./(n_in)) instead 11 | k_h = filter_size 12 | k_w = filter_size 13 | d_h = 1 14 | d_w = 1 15 | fan_in = k_h * k_w * input_.get_shape().as_list()[-1] 16 | fan_out = k_h * k_w * output_dim / (d_h * d_w) 17 | if stddev is None: 18 | if he_init: 19 | stddev = np.sqrt(4. / (fan_in + fan_out)) # He initialization 20 | else: 21 | stddev = np.sqrt(2. / (fan_in + fan_out)) # Glorot initialization 22 | 23 | with tf.variable_scope(name): 24 | w = tf.get_variable("w", [k_h, k_w, input_.get_shape()[-1], output_dim], 25 | initializer=tf.random_uniform_initializer( 26 | minval=-stddev * np.sqrt(3), 27 | maxval=stddev * np.sqrt(3) 28 | )) 29 | if spectral_normed: 30 | conv = tf.nn.conv2d(input_, spectral_normed_weight(w, update_collection=update_collection), 31 | strides=[1, d_h, d_w, 1], padding=padding) 32 | else: 33 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding) 34 | 35 | biases = tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 36 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 37 | 38 | if with_w: 39 | return conv, w, biases 40 | else: 41 | return conv 42 | 43 | def ConvMeanPool(input, output_dim, filter_size, name, spectral_normed=False, update_collection=None, he_init=True): 44 | output = conv2d(input, output_dim, filter_size, spectral_normed=spectral_normed, 45 | update_collection=update_collection, name=name, he_init=he_init) 46 | output = tf.add_n([output[:, ::2, ::2, :], 47 | output[:, 1::2, ::2, :], 48 | output[:, ::2, 1::2, :], 49 | output[:, 1::2, 1::2, :]]) / 4. 50 | return output 51 | 52 | def MeanPoolConv(input, output_dim, filter_size, name, spectral_normed=False, update_collection=None, he_init=True): 53 | output = input 54 | output = tf.add_n([output[:, ::2, ::2, :], 55 | output[:, 1::2, ::2, :], 56 | output[:, ::2, 1::2, :], 57 | output[:, 1::2, 1::2, :]]) / 4. 
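# NOTE: the four strided slices above average every 2x2 neighborhood, so for
# even spatial dimensions the add_n is equivalent to
#     tf.nn.avg_pool(output, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
# i.e. 2x2 mean pooling that halves height and width before the convolution.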
58 | return conv2d(output, output_dim, filter_size, spectral_normed=spectral_normed, 59 | update_collection=update_collection, name=name, he_init=he_init) 60 | 61 | def UpsampleConv(input, output_dim, filter_size, name, spectral_normed=False, update_collection=None, he_init=True): 62 | output = input 63 | output = tf.concat([output, output, output, output], axis=3) 64 | output = tf.depth_to_space(output, 2) 65 | return conv2d(output, output_dim, filter_size, spectral_normed=spectral_normed, 66 | update_collection=update_collection, name=name, he_init=he_init) 67 | 68 | def ResidualBlock_celebA(input, labels, output_dim, filter_size, resample, name, spectral_normed=False, update_collection=None, n_labels=10): 69 | input_dim = input.get_shape().as_list()[-1] 70 | if resample == 'down': 71 | conv_shortcut = MeanPoolConv 72 | conv1 = partial(conv2d, output_dim=input_dim) 73 | conv2 = partial(ConvMeanPool, output_dim=output_dim) 74 | elif resample == 'up': 75 | conv_shortcut = UpsampleConv 76 | conv1 = partial(UpsampleConv, output_dim=output_dim) 77 | conv2 = partial(conv2d, output_dim=output_dim) 78 | elif resample is None: 79 | conv_shortcut = conv2d 80 | conv1 = partial(conv2d, output_dim=input_dim) 81 | conv2 = partial(conv2d, output_dim=output_dim) 82 | 83 | if output_dim == input_dim and resample is None: 84 | shortcut = input 85 | else: 86 | shortcut = conv_shortcut(input, output_dim=output_dim, filter_size=1, spectral_normed=spectral_normed, 87 | update_collection=update_collection, he_init=False, name=name+'_shortcut') 88 | 89 | output = input 90 | if labels is not None: 91 | output = gan_cond_batch_norm(output, labels, n_labels=n_labels, name=name+'_bn1') 92 | 93 | output = tf.nn.relu(output) 94 | output = conv1(output, filter_size=filter_size, spectral_normed=spectral_normed, 95 | update_collection=update_collection, name=name+'_conv1') 96 | 97 | if labels is not None: 98 | output = gan_cond_batch_norm(output, labels, n_labels=n_labels, name=name+'_bn2') 99 | 100 | output = tf.nn.relu(output) 101 | output = conv2(output, filter_size=filter_size, spectral_normed=spectral_normed, 102 | update_collection=update_collection, name=name+'_conv2') 103 | 104 | return output + shortcut 105 | 106 | def ResidualBlock(input, labels, output_dim, filter_size, resample, name, spectral_normed=False, update_collection=None, n_labels=10): 107 | input_dim = input.get_shape().as_list()[-1] 108 | if resample == 'down': 109 | conv1 = partial(conv2d, output_dim=input_dim) 110 | conv2 = partial(ConvMeanPool, output_dim=output_dim) 111 | conv_shortcut = ConvMeanPool 112 | elif resample == 'up': 113 | conv1 = partial(UpsampleConv, output_dim=output_dim) 114 | conv_shortcut = UpsampleConv 115 | conv2 = partial(conv2d, output_dim=output_dim) 116 | elif resample is None: 117 | conv_shortcut = conv2d 118 | conv1 = partial(conv2d, output_dim=output_dim) 119 | conv2 = partial(conv2d, output_dim=output_dim) 120 | 121 | if output_dim == input_dim and resample is None: 122 | shortcut = input 123 | else: 124 | shortcut = conv_shortcut(input, output_dim=output_dim, filter_size=1, spectral_normed=spectral_normed, 125 | update_collection=update_collection, he_init=False, name=name+'_shortcut') 126 | 127 | output = input 128 | if labels is not None: 129 | output = gan_cond_batch_norm(output, labels, n_labels=n_labels, name=name+'_bn1') 130 | 131 | output = tf.nn.relu(output) 132 | output = conv1(output, filter_size=filter_size, spectral_normed=spectral_normed, 133 | update_collection=update_collection, name=name+'_conv1') 
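# NOTE: both halves of this block use the pre-activation ordering of
# He et al. (2016, arXiv:1603.05027): (conditional) batch norm -> ReLU -> conv;
# gan_cond_batch_norm looks up a per-class scale and offset when labels are
# given, which is how the generator output is conditioned on the class.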
134 | 135 | if labels is not None: 136 | output = gan_cond_batch_norm(output, labels, n_labels=n_labels, name=name+'_bn2') 137 | 138 | output = tf.nn.relu(output) 139 | output = conv2(output, filter_size=filter_size, spectral_normed=spectral_normed, 140 | update_collection=update_collection, name=name+'_conv2') 141 | 142 | return output + shortcut 143 | 144 | def ResidualBlockDisc(input, dim_D, name, spectral_normed=False, update_collection=None): 145 | conv1 = partial(conv2d, output_dim=dim_D, spectral_normed=spectral_normed, update_collection=update_collection) 146 | conv2 = partial(ConvMeanPool, output_dim=dim_D, spectral_normed=spectral_normed, update_collection=update_collection) 147 | conv_shortcut = partial(MeanPoolConv, output_dim=dim_D, spectral_normed=spectral_normed, update_collection=update_collection) 148 | 149 | shortcut = conv_shortcut(input, filter_size=1, he_init=False, name=name+'_shortcut') 150 | output = input 151 | output = conv1(output, filter_size=3, name=name+'_conv1') 152 | output = tf.nn.relu(output) 153 | output = conv2(output, filter_size=3, name=name+'_conv2') 154 | return shortcut + output 155 | 156 | -------------------------------------------------------------------------------- /models/libs/sn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import warnings 3 | 4 | NO_OPS = 'NO_OPS' 5 | 6 | def _l2normalize(v, eps=1e-12): 7 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 8 | 9 | 10 | def spectral_normed_weight(W, u=None, num_iters=1, update_collection=None, with_sigma=False): 11 | # Usually num_iters = 1 will be enough 12 | W_shape = W.shape.as_list() 13 | W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) 14 | if u is None: 15 | u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 16 | 17 | def power_iteration(i, u_i, v_i): 18 | v_ip1 = _l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) 19 | u_ip1 = _l2normalize(tf.matmul(v_ip1, W_reshaped)) 20 | return i + 1, u_ip1, v_ip1 21 | 22 | _, u_final, v_final = tf.while_loop( 23 | cond=lambda i, _1, _2: i < num_iters, 24 | body=power_iteration, 25 | loop_vars=(tf.constant(0, dtype=tf.int32), 26 | u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) 27 | ) 28 | sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0] 29 | # sigma = tf.reduce_sum(tf.matmul(u_final, tf.transpose(W_reshaped)) * v_final) 30 | W_bar = W_reshaped / sigma 31 | W_bar = tf.reshape(W_bar, W_shape) 32 | # Put NO_OPS to not update any collection. This is useful for the second call of discriminator if the update_op 33 | # has already been collected on the first call. 
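# NOTE: the power iteration above estimates the largest singular value sigma
# of the reshaped kernel, and W / sigma bounds the layer's spectral norm at 1
# (Miyato et al., 2018). Since u persists as a non-trainable variable across
# steps, one iteration per call suffices. u.assign(u_final) is only
# *registered* in update_collection; the caller runs the collected ops
# explicitly (cf. self.update_op in ACWGAN_GP.build_model).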
34 | if update_collection != NO_OPS: 35 | tf.add_to_collection(update_collection, u.assign(u_final)) 36 | 37 | if with_sigma: 38 | return W_bar, sigma 39 | else: 40 | return W_bar 41 | 42 | -------------------------------------------------------------------------------- /models/madry_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | The model is adapted from the tensorflow tutorial: 3 | https://www.tensorflow.org/get_started/mnist/pros 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | 12 | class MadryModel(object): 13 | def __init__(self, images): 14 | 15 | self.x_image = images 16 | 17 | # first convolutional layer 18 | W_conv1 = self._weight_variable([5, 5, 1, 32], name="Variable") 19 | b_conv1 = self._bias_variable([32], name="Variable_1") 20 | 21 | h_conv1 = tf.nn.relu(self._conv2d(self.x_image, W_conv1) + b_conv1) 22 | h_pool1 = self._max_pool_2x2(h_conv1) 23 | 24 | # second convolutional layer 25 | W_conv2 = self._weight_variable([5, 5, 32, 64], name="Variable_2") 26 | b_conv2 = self._bias_variable([64], name="Variable_3") 27 | 28 | h_conv2 = tf.nn.relu(self._conv2d(h_pool1, W_conv2) + b_conv2) 29 | h_pool2 = self._max_pool_2x2(h_conv2) 30 | 31 | # first fully connected layer 32 | W_fc1 = self._weight_variable([7 * 7 * 64, 1024], name="Variable_4") 33 | b_fc1 = self._bias_variable([1024], name="Variable_5") 34 | 35 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 36 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 37 | 38 | # output layer 39 | W_fc2 = self._weight_variable([1024, 10], name="Variable_6") 40 | b_fc2 = self._bias_variable([10], name="Variable_7") 41 | 42 | self.pre_softmax = tf.matmul(h_fc1, W_fc2) + b_fc2 43 | 44 | self.y_pred = tf.argmax(self.pre_softmax, 1) 45 | 46 | 47 | @staticmethod 48 | def _weight_variable(shape, name): 49 | return tf.get_variable(name=name, shape=shape, initializer=tf.truncated_normal_initializer(stddev=0.1)) 50 | 51 | @staticmethod 52 | def _bias_variable(shape, name): 53 | return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(value=0.1)) 54 | 55 | @staticmethod 56 | def _conv2d(x, W): 57 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 58 | 59 | @staticmethod 60 | def _max_pool_2x2(x): 61 | return tf.nn.max_pool(x, 62 | ksize=[1, 2, 2, 1], 63 | strides=[1, 2, 2, 1], 64 | padding='SAME') 65 | -------------------------------------------------------------------------------- /models/resnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ResNet model. 
17 | 18 | Related papers: 19 | https://arxiv.org/pdf/1603.05027v2.pdf 20 | https://arxiv.org/pdf/1512.03385v1.pdf 21 | https://arxiv.org/pdf/1605.07146v1.pdf 22 | """ 23 | from collections import namedtuple 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | import six 28 | 29 | 30 | HParams = namedtuple('HParams', 31 | 'batch_size, num_classes, min_lrn_rate, lrn_rate, ' 32 | 'num_residual_units, use_bottleneck, weight_decay_rate, ' 33 | 'relu_leakiness, optimizer, dataset') 34 | 35 | 36 | class ResNet(object): 37 | """ResNet model.""" 38 | 39 | def __init__(self, hps, images, training): 40 | """ResNet constructor. 41 | 42 | Args: 43 | hps: Hyperparameters. 44 | images: Batches of images. [batch_size, image_size, image_size, 3] 45 | training: Boolean; True to build the graph in training mode (affects 46 | batch normalization), False for evaluation. 47 | """ 48 | self.hps = hps 49 | self._images = images 50 | self.training = training 51 | 52 | self._extra_train_ops = [] 53 | 54 | def build_graph(self): 55 | """Build a whole graph for the model.""" 56 | self.global_step = tf.contrib.framework.get_or_create_global_step() 57 | self._build_model() 58 | self.summaries = tf.summary.merge_all() 59 | 60 | def _stride_arr(self, stride): 61 | """Map a stride scalar to the stride array for tf.nn.conv2d.""" 62 | return [1, stride, stride, 1] 63 | 64 | def _build_model(self): 65 | """Build the core model within the graph.""" 66 | with tf.variable_scope('init'): 67 | x = self._images 68 | if self.hps.dataset == 'mnist': 69 | channels = 1 70 | elif self.hps.dataset in ('svhn', 'celebA'): 71 | channels = 3 72 | if self.hps.dataset in ('svhn', 'celebA'): 73 | x = self._conv('init_conv', x, 3, channels, 16, self._stride_arr(1)) 74 | elif self.hps.dataset == 'mnist': 75 | x = self._conv('init_conv', x, 3, channels, 4, self._stride_arr(1)) 76 | else: 77 | raise NotImplementedError("Dataset {} is not supported!".format(self.hps.dataset)) 78 | 79 | strides = [1, 2, 2] 80 | activate_before_residual = [True, False, False] 81 | if self.hps.use_bottleneck: 82 | res_func = self._bottleneck_residual 83 | filters = [16, 64, 128, 256] 84 | 85 | else: 86 | res_func = self._residual 87 | filters = [16, 16, 32, 64] 88 | # Uncomment the following code to use the w28-10 wide residual network. 89 | # It is more memory-efficient than a very deep residual network and has 90 | # comparably good performance.
91 | # https://arxiv.org/pdf/1605.07146v1.pdf 92 | # filters = [16, 160, 320, 640] 93 | # Update hps.num_residual_units to 4 94 | if self.hps.dataset == 'mnist': 95 | filters = [_ // 4 for _ in filters] 96 | 97 | with tf.variable_scope('unit_1_0'): 98 | x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]), 99 | activate_before_residual[0]) 100 | for i in six.moves.range(1, self.hps.num_residual_units): 101 | with tf.variable_scope('unit_1_%d' % i): 102 | x = res_func(x, filters[1], filters[1], self._stride_arr(1), False) 103 | 104 | with tf.variable_scope('unit_2_0'): 105 | x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]), 106 | activate_before_residual[1]) 107 | for i in six.moves.range(1, self.hps.num_residual_units): 108 | with tf.variable_scope('unit_2_%d' % i): 109 | x = res_func(x, filters[2], filters[2], self._stride_arr(1), False) 110 | 111 | with tf.variable_scope('unit_3_0'): 112 | x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]), 113 | activate_before_residual[2]) 114 | for i in six.moves.range(1, self.hps.num_residual_units): 115 | with tf.variable_scope('unit_3_%d' % i): 116 | x = res_func(x, filters[3], filters[3], self._stride_arr(1), False) 117 | 118 | with tf.variable_scope('unit_last'): 119 | x = self._batch_norm('final_bn', x) 120 | x = self._relu(x, self.hps.relu_leakiness) 121 | x = self._global_avg_pool(x) 122 | 123 | with tf.variable_scope('logit'): 124 | self.logits = self._fully_connected(x, self.hps.num_classes) 125 | 126 | def _batch_norm(self, name, x): 127 | """Batch normalization.""" 128 | with tf.variable_scope(name, reuse=False): 129 | return tf.layers.batch_normalization(x, training=self.training, name="batch_norm") 130 | 131 | def _residual(self, x, in_filter, out_filter, stride, 132 | activate_before_residual=False): 133 | """Residual unit with 2 sub layers.""" 134 | if activate_before_residual: 135 | with tf.variable_scope('shared_activation'): 136 | x = self._batch_norm('init_bn', x) 137 | x = self._relu(x, self.hps.relu_leakiness) 138 | orig_x = x 139 | else: 140 | with tf.variable_scope('residual_only_activation'): 141 | orig_x = x 142 | x = self._batch_norm('init_bn', x) 143 | x = self._relu(x, self.hps.relu_leakiness) 144 | 145 | with tf.variable_scope('sub1'): 146 | x = self._conv('conv1', x, 3, in_filter, out_filter, stride) 147 | 148 | with tf.variable_scope('sub2'): 149 | x = self._batch_norm('bn2', x) 150 | x = self._relu(x, self.hps.relu_leakiness) 151 | x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1]) 152 | 153 | with tf.variable_scope('sub_add'): 154 | if in_filter != out_filter: 155 | orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID') 156 | orig_x = tf.pad( 157 | orig_x, [[0, 0], [0, 0], [0, 0], 158 | [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2]]) 159 | x += orig_x 160 | 161 | tf.logging.debug('image after unit %s', x.get_shape()) 162 | return x 163 | 164 | def _bottleneck_residual(self, x, in_filter, out_filter, stride, 165 | activate_before_residual=False): 166 | """Bottleneck residual unit with 3 sub layers.""" 167 | if activate_before_residual: 168 | with tf.variable_scope('common_bn_relu'): 169 | x = self._batch_norm('init_bn', x) 170 | x = self._relu(x, self.hps.relu_leakiness) 171 | orig_x = x 172 | else: 173 | with tf.variable_scope('residual_bn_relu'): 174 | orig_x = x 175 | x = self._batch_norm('init_bn', x) 176 | x = self._relu(x, self.hps.relu_leakiness) 177 | 178 | with tf.variable_scope('sub1'): 179 | x = self._conv('conv1', x, 
1, in_filter, out_filter // 4, stride) 180 | 181 | with tf.variable_scope('sub2'): 182 | x = self._batch_norm('bn2', x) 183 | x = self._relu(x, self.hps.relu_leakiness) 184 | x = self._conv('conv2', x, 3, out_filter // 4, out_filter // 4, [1, 1, 1, 1]) 185 | 186 | with tf.variable_scope('sub3'): 187 | x = self._batch_norm('bn3', x) 188 | x = self._relu(x, self.hps.relu_leakiness) 189 | x = self._conv('conv3', x, 1, out_filter // 4, out_filter, [1, 1, 1, 1]) 190 | 191 | with tf.variable_scope('sub_add'): 192 | if in_filter != out_filter: 193 | orig_x = self._conv('project', orig_x, 1, in_filter, out_filter, stride) 194 | x += orig_x 195 | 196 | tf.logging.info('image after unit %s', x.get_shape()) 197 | return x 198 | 199 | def _decay(self): 200 | """L2 weight decay loss.""" 201 | costs = [] 202 | for var in tf.trainable_variables(): 203 | if var.op.name.find(r'DW') > 0: 204 | costs.append(tf.nn.l2_loss(var)) 205 | # tf.summary.histogram(var.op.name, var) 206 | 207 | return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs)) 208 | 209 | def _conv(self, name, x, filter_size, in_filters, out_filters, strides): 210 | """Convolution.""" 211 | with tf.variable_scope(name): 212 | n = filter_size * filter_size * out_filters 213 | kernel = tf.get_variable( 214 | 'DW', [filter_size, filter_size, in_filters, out_filters], 215 | tf.float32, initializer=tf.random_normal_initializer( 216 | stddev=np.sqrt(2.0 / n))) 217 | return tf.nn.conv2d(x, kernel, strides, padding='SAME') 218 | 219 | def _relu(self, x, leakiness=0.0): 220 | """Relu, with optional leaky support.""" 221 | return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') 222 | 223 | def _fully_connected(self, x, out_dim): 224 | """FullyConnected layer for final output.""" 225 | x = tf.contrib.layers.flatten(x) 226 | w = tf.get_variable( 227 | 'DW', [x.get_shape()[1], out_dim], 228 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 229 | b = tf.get_variable('biases', [out_dim], 230 | initializer=tf.constant_initializer()) 231 | return tf.nn.xw_plus_b(x, w, b) 232 | 233 | def _global_avg_pool(self, x): 234 | assert x.get_shape().ndims == 4 235 | return tf.reduce_mean(x, [1, 2]) 236 | -------------------------------------------------------------------------------- /models/vgg16.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.contrib import layers 6 | from tensorflow.contrib.framework.python.ops import arg_scope 7 | from tensorflow.contrib.layers.python.layers import layers as layers_lib 8 | from tensorflow.contrib.layers.python.layers import utils 9 | from tensorflow.python.ops import variable_scope 10 | import tensorflow as tf 11 | 12 | 13 | def vgg_16(inputs, 14 | num_classes=1000, 15 | is_training=True, 16 | dataset='cifar', 17 | scope='vgg_16'): 18 | """Oxford Net VGG 16-Layers version D Example. 19 | 20 | Note: this variant uses a flattened fully connected classifier head and 21 | defaults to 32x32 inputs (see vgg_16.default_image_size). 22 | 23 | Args: 24 | inputs: a tensor of size [batch_size, height, width, channels]. 25 | num_classes: number of predicted classes. 26 | is_training: whether or not the model is being trained. 27 | dataset: one of 'cifar', 'svhn' or 'mnist'. For 'mnist' every layer is 28 | four times narrower; other datasets raise NotImplementedError. 29 | Dropout with keep probability 0.5 is applied around the hidden 30 | fully connected layer during training. 31 | scope: Optional scope for the variables. 32 | 33 | Returns: 34 | the last op containing the log predictions and end_points dict. 35 | """ 36 | with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc: 37 | end_points_collection = sc.original_name_scope + '_end_points' 38 | # Collect outputs for conv2d, fully_connected and max_pool2d. 39 | with arg_scope( 40 | [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d], 41 | outputs_collections=end_points_collection): 42 | def ConvBatchRelu(layer_input, n_output_plane, name): 43 | with variable_scope.variable_scope(name): 44 | output = layers.conv2d(layer_input, n_output_plane, [3, 3], scope='conv') 45 | output = layers.batch_norm(output, center=True, scale=True, activation_fn=tf.nn.relu, 46 | is_training=is_training) 47 | return output 48 | 49 | filters = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512] 50 | if dataset == 'mnist': 51 | filters = [_ // 4 for _ in filters] 52 | elif dataset not in ('cifar', 'svhn'): 53 | raise NotImplementedError("Dataset {} is not supported!".format(dataset)) 54 | 55 | net = ConvBatchRelu(inputs, filters[0], 'conv1_1') 56 | net = ConvBatchRelu(net, filters[1], 'conv1_2') 57 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') 58 | net = ConvBatchRelu(net, filters[2], 'conv2_1') 59 | net = ConvBatchRelu(net, filters[3], 'conv2_2') 60 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') 61 | net = ConvBatchRelu(net, filters[4], 'conv3_1') 62 | net = ConvBatchRelu(net, filters[5], 'conv3_2') 63 | net = ConvBatchRelu(net, filters[6], 'conv3_3') 64 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') 65 | net = ConvBatchRelu(net, filters[7], 'conv4_1') 66 | net = ConvBatchRelu(net, filters[8], 'conv4_2') 67 | net = ConvBatchRelu(net, filters[9], 'conv4_3') 68 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') 69 | net = ConvBatchRelu(net, filters[10], 'conv5_1') 70 | net = ConvBatchRelu(net, filters[11], 'conv5_2') 71 | net = ConvBatchRelu(net, filters[12], 'conv5_3') 72 | if dataset == 'cifar': 73 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') 74 | # Flatten, then use fully connected layers for the classifier head. 75 | net = layers.flatten(net, scope='flatten6') 76 | net = layers_lib.dropout(net, 0.5, is_training=is_training, scope='dropout6') 77 | net = layers.relu(net, filters[13]) 78 | net = layers_lib.dropout(net, 0.5, is_training=is_training, scope='dropout7') 79 | net = layers.linear(net, num_classes) 80 | # Convert end_points_collection into an end_points dict.
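      # A typical call, as a sketch (the argument values here are illustrative):
      #   logits, end_points = vgg_16(images, num_classes=10, is_training=True,
      #                               dataset='cifar')
      # `end_points` is assembled from the collection populated above.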
81 | end_points = utils.convert_collection_to_dict(end_points_collection) 82 | end_points[sc.name + '/fc8'] = net 83 | return net, end_points 84 | 85 | vgg_16.default_image_size = 32 86 | -------------------------------------------------------------------------------- /models/zico_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | The model is adapted from the tensorflow tutorial: 3 | https://www.tensorflow.org/get_started/mnist/pros 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | class ZicoMNIST(object): 12 | def __init__(self, images): 13 | self.x_image = images 14 | W0 = tf.get_variable('W0', dtype=tf.float32, shape=(4, 4, 1, 16)) 15 | B0 = tf.get_variable('B0', dtype=tf.float32, shape=(16,)) 16 | W2 = tf.get_variable('W2', dtype=tf.float32, shape=(4, 4, 16, 32)) 17 | B2 = tf.get_variable('B2', dtype=tf.float32, shape=(32,)) 18 | W5 = tf.get_variable('W5', dtype=tf.float32, shape=(1568, 100))  # 1568 = 32 * 7 * 7 19 | B5 = tf.get_variable('B5', dtype=tf.float32, shape=(100,)) 20 | W7 = tf.get_variable('W7', dtype=tf.float32, shape=(100, 10)) 21 | B7 = tf.get_variable('B7', dtype=tf.float32, shape=(10,)) 22 | 23 | y = tf.pad(self.x_image, [[0, 0], [1, 1], [1, 1], [0, 0]])  # pad 28x28 -> 30x30 so each strided VALID conv halves the size 24 | y = tf.nn.conv2d(y, W0, strides=[1, 2, 2, 1], padding='VALID') 25 | y = tf.nn.bias_add(y, B0) 26 | y = tf.nn.relu(y) 27 | y = tf.pad(y, [[0, 0], [1, 1], [1, 1], [0, 0]]) 28 | y = tf.nn.conv2d(y, W2, strides=[1, 2, 2, 1], padding="VALID") 29 | y = tf.nn.bias_add(y, B2) 30 | y = tf.nn.relu(y) 31 | y = tf.transpose(y, [0, 3, 1, 2])  # NHWC -> NCHW before flattening (channel-major ordering, matching the pretrained W5) 32 | y = tf.reshape(y, [tf.shape(y)[0], -1]) 33 | y = y @ W5 + B5 34 | y = tf.nn.relu(y) 35 | y = y @ W7 + B7 36 | 37 | self.logits = y -------------------------------------------------------------------------------- /mturk_websites/mturk.html: -------------------------------------------------------------------------------- 1 | 3 | 4 |
5 |
6 |
7 |
8 | 9 | Image Tagging 10 | Instructions, please read! (click to collapse) 11 |
12 |

You must identify the digit shown in each of the following images

13 |
    14 |
  • Select the digit below each image 15 |
  • 16 |
  • If an image contains several digits, tag the digit closest to the middle of 17 | the image. 18 |
  • 19 |
  • Select N/A only when the image doesn't resemble any digit. It's OK if the 20 | digit is blurry, noisy, or looks artificial.
  • 21 |
  • Irresponsible answers will be REJECTED and you WON'T BE 22 | PAID. 23 |
  • 24 |
  • Examples:
    25 | example1 26 | example2 27 |
  • 28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 | 37 | 38 | 39 | 40 | 42 | 44 | 46 | 48 | 50 | 51 | 52 | 53 | 71 | 89 | 107 | 125 | 143 | 144 | 145 | 146 | 148 | 150 | 152 | 154 | 156 | 157 | 158 | 159 | 177 | 195 | 213 | 231 | 249 | 250 | 251 |
Image:
image_url_0image_url_1image_url_2image_url_3image_url_4
Digit:
54 |
55 | 56 | 57 | 58 |
59 | 60 | 61 | 62 |
63 | 64 | 65 | 66 |
67 | 68 | 69 |
70 |
72 |
73 | 74 | 75 | 76 |
77 | 78 | 79 | 80 |
81 | 82 | 83 | 84 |
85 | 86 | 87 |
88 |
90 |
91 | 92 | 93 | 94 |
95 | 96 | 97 | 98 |
99 | 100 | 101 | 102 |
103 | 104 | 105 |
106 |
108 |
109 | 110 | 111 | 112 |
113 | 114 | 115 | 116 |
117 | 118 | 119 | 120 |
121 | 122 | 123 |
124 |
126 |
127 | 128 | 129 | 130 |
131 | 132 | 133 | 134 |
135 | 136 | 137 | 138 |
139 | 140 | 141 |
142 |
Image:
image_url_5image_url_6image_url_7image_url_8image_url_9
Digit:
160 |
161 | 162 | 163 | 164 |
165 | 166 | 167 | 168 |
169 | 170 | 171 | 172 |
173 | 174 | 175 |
176 |
178 |
179 | 180 | 181 | 182 |
183 | 184 | 185 | 186 |
187 | 188 | 189 | 190 |
191 | 192 | 193 |
194 |
196 |
197 | 198 | 199 | 200 |
201 | 202 | 203 | 204 |
205 | 206 | 207 | 208 |
209 | 210 | 211 |
212 |
214 |
215 | 216 | 217 | 218 |
219 | 220 | 221 | 222 |
223 | 224 | 225 | 226 |
227 | 228 | 229 |
230 |
232 |
233 | 234 | 235 | 236 |
237 | 238 | 239 | 240 |
241 | 242 | 243 | 244 |
245 | 246 | 247 |
248 |
252 |
253 |
254 |
255 |
256 | 274 | 275 | 276 | 278 | 281 | 299 | -------------------------------------------------------------------------------- /mturk_websites/mturk_abtest.html: -------------------------------------------------------------------------------- 1 | 3 | 4 |
5 |
6 |
7 |
8 | 9 | Instructions, please read! (click to collapse) 10 |
11 |

You must identify which digit is real and which is synthesized by the computer:

12 |
    13 |
  • 14 | For each group of digits, there is one drawn by human beings (real) and another generated by a computer program (fake). 15 |
  • 16 |
  • 17 | Try your best to identify the fake image, i.e., the one generated by the computer. 18 |
  • 19 |
  • 20 | Clean your screen and increase its brightness before working on this HIT! The fake images are usually noisy/spotty/unnatural in some way. 21 |
  • 22 |
  • Examples:
    23 | example1 24 | example2 25 |
  • 26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | 34 | 35 | 36 | 37 | 40 | 43 | 46 | 49 | 50 | 51 | 52 | 56 | 59 | 62 | 65 | 66 | 67 | 68 | 74 | 80 | 86 | 92 | 93 | 94 |
Image
(Top):
image_url_0 39 | image_url_1 42 | image_url_2 45 | image_url_3 48 |
Image
(Bottom):
53 | image_url_10 55 | image_url_11 58 | image_url_12 61 | image_url_13 64 |
Which
is
fake:
69 |
70 | 71 | 72 |
73 |
75 |
76 | 77 | 78 |
79 |
81 |
82 | 83 | 84 |
85 |
87 |
88 | 89 | 90 |
91 |
95 |
96 |
97 |
98 | 116 | 117 | 118 | 120 | 123 | 141 | -------------------------------------------------------------------------------- /mturk_websites/mturk_celeba.html: -------------------------------------------------------------------------------- 1 | 3 | 4 |
5 |
6 |
7 |
8 | 9 | Gender Tagging 10 | Instructions, please read! (click to collapse) 11 |
12 |

You must identify the gender of each face shown in the images

13 |
    14 |
  • Select the gender of each human face image. 15 |
  • 16 |
  • Select N/A only when the image contains no human face or the face is heavily distorted
  • 17 |
  • Irresponsible answers will be REJECTED and you WON'T GET 18 | PAID. 19 |
  • 20 |
  • Examples:
    21 | example1 22 | example2 23 | example3 24 | example4 25 |
  • 26 |
27 |
28 |
29 |
30 |
31 | 32 |
33 |
34 |
35 | 36 | 37 | 38 | 39 | 41 | 43 | 45 | 47 | 49 | 50 | 51 | 52 | 59 | 66 | 73 | 80 | 87 | 88 | 89 | 90 | 92 | 94 | 96 | 98 | 100 | 101 | 102 | 103 | 110 | 117 | 124 | 131 | 138 | 139 | 140 |
Image:
image_url_0image_url_1image_url_2image_url_3image_url_4
Gender:
53 |
54 | 55 | 56 | 57 |
58 |
60 |
61 | 62 | 63 | 64 |
65 |
67 |
68 | 69 | 70 | 71 |
72 |
74 |
75 | 76 | 77 | 78 |
79 |
81 |
82 | 83 | 84 | 85 |
86 |
Image:
image_url_5image_url_6image_url_7image_url_8image_url_9
Gender:
104 |
105 | 106 | 107 | 108 |
109 |
111 |
112 | 113 | 114 | 115 |
116 |
118 |
119 | 120 | 121 | 122 |
123 |
125 |
126 | 127 | 128 | 129 |
130 |
132 |
133 | 134 | 135 | 136 |
137 |
141 |
142 |
143 |
144 |
145 | 163 | 164 | 165 | 167 | 170 | 188 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of this code is from https://github.com/carpedm20/DCGAN-tensorflow 3 | """ 4 | import math 5 | 6 | from utils import * 7 | 8 | if "concat_v2" in dir(tf): 9 | def concat(tensors, axis, *args, **kwargs): 10 | return tf.concat_v2(tensors, axis, *args, **kwargs) 11 | else: 12 | def concat(tensors, axis, *args, **kwargs): 13 | return tf.concat(tensors, axis, *args, **kwargs) 14 | 15 | def bn(x, is_training, scope): 16 | return tf.contrib.layers.batch_norm(x, 17 | decay=0.9, 18 | updates_collections=None, 19 | epsilon=1e-5, 20 | scale=True, 21 | is_training=is_training, 22 | scope=scope) 23 | 24 | def conv_out_size_same(size, stride): 25 | return int(math.ceil(float(size) / float(stride))) 26 | 27 | def conv_cond_concat(x, y): 28 | """Concatenate conditioning vector on feature map axis.""" 29 | x_shapes = x.get_shape() 30 | y_shapes = y.get_shape() 31 | return concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 32 | 33 | def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): 34 | with tf.variable_scope(name): 35 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 36 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 37 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 38 | 39 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 40 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 41 | 42 | return conv 43 | 44 | def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False): 45 | with tf.variable_scope(name): 46 | # filter : [height, width, output_channels, in_channels] 47 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 48 | initializer=tf.random_normal_initializer(stddev=stddev)) 49 | 50 | try: 51 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 52 | 53 | # Support for versions of TensorFlow before 0.7.0 54 | except AttributeError: 55 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 56 | 57 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 58 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 59 | 60 | if with_w: 61 | return deconv, w, biases 62 | else: 63 | return deconv 64 | 65 | def lrelu(x, leak=0.2, name="lrelu"): 66 | return tf.maximum(x, leak*x) 67 | 68 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 69 | shape = input_.get_shape().as_list() 70 | 71 | with tf.variable_scope(scope or "Linear"): 72 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 73 | tf.random_normal_initializer(stddev=stddev)) 74 | bias = tf.get_variable("bias", [output_size], 75 | initializer=tf.constant_initializer(bias_start)) 76 | if with_w: 77 | return tf.matmul(input_, matrix) + bias, matrix, bias 78 | else: 79 | return tf.matmul(input_, matrix) + bias -------------------------------------------------------------------------------- /train_acgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import * 3 | from models.acwgan_gp import ACWGAN_GP 4 | import
json 5 | 6 | parser = argparse.ArgumentParser(description="Training Wasserstein ACGAN") 7 | parser.add_argument('--dataset', type=str, default='mnist', help="dataset: mnist | svhn | celebA") 8 | parser.add_argument('--n_epochs', type=int, default=50, help="number of epochs") 9 | parser.add_argument('--batch_size', type=int, default=64, help="batch size") 10 | parser.add_argument('--z_dim', type=int, default=128, help="dimension of noise vector") 11 | parser.add_argument('--checkpoint_dir', type=str, default='assets/checkpoint', 12 | help='Directory name to save the checkpoints') 13 | parser.add_argument('--result_dir', type=str, default='assets/results', 14 | help='Directory name to save the generated images') 15 | parser.add_argument('--log_dir', type=str, default='assets/logs', 16 | help='Directory name to save training logs') 17 | 18 | args = parser.parse_args() 19 | check_folder(args.checkpoint_dir) 20 | 21 | # --result_dir 22 | check_folder(args.result_dir) 23 | 24 | # --log_dir 25 | check_folder(args.log_dir) 26 | 27 | def main(): 28 | print("[*] input args:\n", json.dumps(vars(args), indent=4, separators=(',', ':'))) 29 | 30 | tf_config = tf.ConfigProto() 31 | tf_config.gpu_options.allow_growth = True 32 | if args.dataset == 'mnist': 33 | dim_D = 32 34 | dim_G = 32 35 | elif args.dataset == 'svhn': 36 | dim_D = 128 37 | dim_G = 128 38 | elif args.dataset == 'celebA': 39 | dim_D = 64 40 | dim_G = 64 41 | else: raise NotImplementedError("Dataset {} is not supported!".format(args.dataset)) 42 | with tf.Session(config=tf_config) as sess: 43 | gan = ACWGAN_GP( 44 | sess, 45 | epoch=args.n_epochs, 46 | batch_size=args.batch_size, 47 | z_dim=args.z_dim, 48 | dataset_name=args.dataset, 49 | checkpoint_dir=args.checkpoint_dir, 50 | result_dir=args.result_dir, 51 | log_dir=args.log_dir, 52 | dim_D=dim_D, 53 | dim_G=dim_G 54 | ) 55 | 56 | gan.build_model() 57 | show_all_variables() 58 | gan.train() 59 | print(" [*] Training finished") 60 | gan.visualize_results(args.n_epochs - 1) 61 | print(" [*] Testing finished") 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of this code is from https://github.com/carpedm20/DCGAN-tensorflow 3 | """ 4 | import scipy.misc 5 | import scipy.io as sio 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import os, gzip 9 | import cv2 as cv 10 | 11 | import tensorflow as tf 12 | import tensorflow.contrib.slim as slim 13 | 14 | 15 | def load_mnist(dataset_name, trainonly=False): 16 | data_dir = os.path.join("assets/data", dataset_name) 17 | 18 | def extract_data(filename, num_data, head_size, data_size): 19 | with gzip.open(filename) as bytestream: 20 | bytestream.read(head_size) 21 | buf = bytestream.read(data_size * num_data) 22 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float) 23 | return data 24 | 25 | data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28) 26 | trX = data.reshape((60000, 28, 28, 1)) 27 | 28 | data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1) 29 | trY = data.reshape((60000)) 30 | 31 | data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28) 32 | teX = data.reshape((10000, 28, 28, 1)) 33 | 34 | data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1) 35 | teY = data.reshape((10000)) 36 | 37 | trY = np.asarray(trY) 38 | teY = np.asarray(teY) 39 | 40 | if trainonly: 41 | X = trX 42 | y = trY.astype(np.int) 43 | else: 44 | X =
np.concatenate((trX, teX), axis=0) 45 | y = np.concatenate((trY, teY), axis=0).astype(np.int) 46 | 47 | seed = 547 48 | np.random.seed(seed) 49 | np.random.shuffle(X) 50 | np.random.seed(seed) 51 | np.random.shuffle(y) 52 | 53 | y_vec = np.zeros((len(y), 10), dtype=np.float) 54 | for i, label in enumerate(y): 55 | y_vec[i, y[i]] = 1.0 56 | 57 | return X / 255., y_vec 58 | 59 | 60 | def load_svhn(source_class=None, trainonly=False): 61 | print("[*] Loading SVHN") 62 | data_dir = os.path.join("assets", "data", "svhn") 63 | 64 | def extract_data(filename): 65 | data = sio.loadmat(os.path.join(data_dir, filename)) 66 | X = data['X'].transpose(3, 0, 1, 2) 67 | y = data['y'].reshape((-1)) 68 | y[y == 10] = 0 69 | return X, y.astype(np.int) 70 | 71 | trX, trY = extract_data('train_32x32.mat') 72 | teX, teY = extract_data('test_32x32.mat') 73 | exX, exY = extract_data('extra_32x32.mat') 74 | 75 | print("[*] SVHN loaded") 76 | 77 | if trainonly: 78 | X = trX 79 | y = trY 80 | else: 81 | X = np.concatenate([trX, teX, exX], axis=0) 82 | y = np.concatenate([trY, teY, exY], axis=0) 83 | 84 | if source_class is not None: 85 | idx = (y == source_class) 86 | X = X[idx] 87 | y = y[idx] 88 | 89 | seed = 547 90 | np.random.seed(seed) 91 | np.random.shuffle(X) 92 | np.random.seed(seed) 93 | np.random.shuffle(y) 94 | 95 | y_vec = np.zeros((len(y), 10), dtype=np.float) 96 | y_vec[np.arange(0, len(y)), y] = 1.0 97 | return X / 255., y_vec 98 | 99 | 100 | def load_celebA(): 101 | print("[*] Loading CelebA") 102 | X = sio.loadmat('/atlas/u/ruishu/data/celeba64_zoom.mat')['images'] 103 | y = sio.loadmat('/atlas/u/ruishu/data/celeba_gender.mat')['y'] 104 | y = np.eye(2)[y.reshape(-1)] 105 | 106 | seed = 547 107 | np.random.seed(seed) 108 | np.random.shuffle(X) 109 | np.random.seed(seed) 110 | np.random.shuffle(y) 111 | return X / 255., y 112 | 113 | 114 | def load_celebA4classifier(): 115 | print("[*] Loading CelebA") 116 | X = sio.loadmat('/atlas/u/ruishu/data/celeba64_zoom.mat')['images'] 117 | y = sio.loadmat('/atlas/u/ruishu/data/celeba_gender.mat')['y'] 118 | y = np.eye(2)[y.reshape(-1)] 119 | 120 | trX = X[:150000] 121 | trY = y[:150000] 122 | teX = X[150000:] 123 | teY = y[150000:] 124 | return trX / 255., trY, teX / 255., teY 125 | 126 | 127 | def load_svhn4classifier(): 128 | print("[*] Loading SVHN") 129 | data_dir = os.path.join("assets", "data", "svhn") 130 | 131 | def extract_data(filename): 132 | data = sio.loadmat(os.path.join(data_dir, filename)) 133 | X = data['X'].transpose(3, 0, 1, 2) 134 | y = data['y'].reshape((-1)) 135 | y[y == 10] = 0 136 | return X, y.astype(np.int) 137 | 138 | trX, trY = extract_data('train_32x32.mat') 139 | teX, teY = extract_data('test_32x32.mat') 140 | print("[*] SVHN loaded") 141 | seed = 547 142 | np.random.seed(seed) 143 | np.random.shuffle(trX) 144 | np.random.seed(seed) 145 | np.random.shuffle(trY) 146 | 147 | tr_y_vec = np.zeros((len(trY), 10), dtype=np.float) 148 | tr_y_vec[np.arange(0, len(trY)), trY] = 1.0 149 | 150 | te_y_vec = np.zeros((len(teY), 10), dtype=np.float) 151 | te_y_vec[np.arange(0, len(teY)), teY] = 1.0 152 | return trX / 255., tr_y_vec, teX / 255., te_y_vec 153 | 154 | 155 | def load_mnist4classifier(dataset_name): 156 | data_dir = os.path.join("assets/data", dataset_name) 157 | 158 | def extract_data(filename, num_data, head_size, data_size): 159 | with gzip.open(filename) as bytestream: 160 | bytestream.read(head_size) 161 | buf = bytestream.read(data_size * num_data) 162 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float) 163 | return 
data 164 | 165 | data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28) 166 | trX = data.reshape((60000, 28, 28, 1)) 167 | 168 | data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1) 169 | trY = data.reshape((60000)) 170 | 171 | data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28) 172 | teX = data.reshape((10000, 28, 28, 1)) 173 | 174 | data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1) 175 | teY = data.reshape((10000)) 176 | 177 | trY = np.asarray(trY).astype(np.int) 178 | teY = np.asarray(teY).astype(np.int) 179 | 180 | seed = 547 181 | np.random.seed(seed) 182 | np.random.shuffle(trX) 183 | np.random.seed(seed) 184 | np.random.shuffle(trY) 185 | 186 | tr_y_vec = np.zeros((len(trY), 10), dtype=np.float) 187 | tr_y_vec[np.arange(0, len(trY)), trY] = 1.0 188 | te_y_vec = np.zeros((len(teY), 10), dtype=np.float) 189 | te_y_vec[np.arange(0, len(teY)), teY] = 1.0 190 | 191 | return trX / 255., tr_y_vec, teX / 255., te_y_vec 192 | 193 | 194 | def check_folder(log_dir): 195 | if not os.path.exists(log_dir): 196 | os.makedirs(log_dir) 197 | return log_dir 198 | 199 | 200 | def show_all_variables(): 201 | model_vars = tf.trainable_variables() 202 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 203 | 204 | 205 | def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False): 206 | image = imread(image_path, grayscale) 207 | return transform(image, input_height, input_width, resize_height, resize_width, crop) 208 | 209 | 210 | def write_labels(labels, dataset, size): 211 | if dataset in ('mnist', 'svhn'): 212 | dic = {x: str(x) for x in range(10)} 213 | else: 214 | raise NotImplementedError("Dataset {} not supported".format(dataset)) 215 | print("adversarial labels:") 216 | for i in range(size): 217 | for j in range(size): 218 | print("{}".format(dic[labels[i * size + j]]), end='\t') 219 | print("") 220 | 221 | 222 | def save_images(images, size, image_path): 223 | return imsave(inverse_transform(images), size, image_path) 224 | 225 | 226 | def label_images(images, labels): 227 | font = cv.FONT_HERSHEY_SIMPLEX 228 | new_imgs = [] 229 | for i, img in enumerate(images): 230 | new_img = ((img.copy() + 1.) 
* 127.5).astype(np.uint8) 231 | if new_img.shape[-1] == 3: 232 | new_img = new_img[..., ::-1] 233 | new_img = cv.resize(new_img, (100, 100), interpolation=cv.INTER_LINEAR) 234 | new_img = cv.putText(new_img, str(labels[i]), (10, 30), font, 1, (255, 255, 255), 2, cv.LINE_AA) 235 | new_img = cv.copyMakeBorder(new_img, top=2, bottom=2, left=2, right=2, borderType=cv.BORDER_CONSTANT, 236 | value=(255, 255, 255)) 237 | else: 238 | new_img = np.squeeze(new_img) 239 | new_img = cv.resize(new_img, (100, 100), interpolation=cv.INTER_LINEAR) 240 | new_img = cv.putText(new_img, str(labels[i]), (10, 30), font, 1, (255), 2, cv.LINE_AA) 241 | new_img = new_img[..., None] 242 | 243 | new_img = (new_img / 127.5 - 1.0).astype(np.float32) 244 | new_imgs.append(new_img[..., ::-1]) 245 | return np.stack(new_imgs, axis=0) 246 | 247 | 248 | def imread(path, grayscale=False): 249 | if (grayscale): 250 | return scipy.misc.imread(path, flatten=True).astype(np.float) 251 | else: 252 | return scipy.misc.imread(path).astype(np.float) 253 | 254 | 255 | def merge_images(images, size): 256 | return inverse_transform(images) 257 | 258 | 259 | def merge(images, size): 260 | h, w = images.shape[1], images.shape[2] 261 | if (images.shape[3] in (3, 4)): 262 | c = images.shape[3] 263 | img = np.zeros((h * size[0], w * size[1], c)) 264 | for idx, image in enumerate(images): 265 | i = idx % size[1] 266 | j = idx // size[1] 267 | img[j * h:j * h + h, i * w:i * w + w, :] = image 268 | return img 269 | elif images.shape[3] == 1: 270 | img = np.zeros((h * size[0], w * size[1])) 271 | for idx, image in enumerate(images): 272 | i = idx % size[1] 273 | j = idx // size[1] 274 | img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0] 275 | return img 276 | else: 277 | raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') 278 | 279 | 280 | def imsave(images, size, path): 281 | image = np.squeeze(merge(images, size)) 282 | return scipy.misc.imsave(path, image) 283 | 284 | 285 | def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): 286 | if crop_w is None: 287 | crop_w = crop_h 288 | h, w = x.shape[:2] 289 | j = int(round((h - crop_h) / 2.)) 290 | i = int(round((w - crop_w) / 2.)) 291 | return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w], [resize_h, resize_w]) 292 | 293 | 294 | def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True): 295 | if crop: 296 | cropped_image = center_crop(image, input_height, input_width, resize_height, resize_width) 297 | else: 298 | cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) 299 | return np.array(cropped_image) / 127.5 - 1. 300 | 301 | 302 | def inverse_transform(images): 303 | return (images + 1.) / 2. 
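# Pixel-range conventions used in this module, for reference: transform()
# maps raw pixel values into [-1, 1] via x / 127.5 - 1; inverse_transform()
# maps [-1, 1] images back to [0, 1] so imsave() can write them; and
# label_images() round-trips through uint8 using the matching
# (x + 1.) * 127.5 convention.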
304 | 305 | 306 | """ Drawing Tools """ 307 | 308 | 309 | # borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb 310 | def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'): 311 | N = 10 312 | plt.figure(figsize=(8, 6)) 313 | plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet')) 314 | plt.colorbar(ticks=range(N)) 315 | axes = plt.gca() 316 | axes.set_xlim([-z_range_x, z_range_x]) 317 | axes.set_ylim([-z_range_y, z_range_y]) 318 | plt.grid(True) 319 | plt.savefig(name) 320 | 321 | 322 | # borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a 323 | def discrete_cmap(N, base_cmap=None): 324 | """Create an N-bin discrete colormap from the specified input map""" 325 | 326 | # Note that if base_cmap is a string or None, you can simply do 327 | # return plt.cm.get_cmap(base_cmap, N) 328 | # The following works for string, None, or a colormap instance: 329 | 330 | base = plt.cm.get_cmap(base_cmap) 331 | color_list = base(np.linspace(0, 1, N)) 332 | cmap_name = base.name + str(N) 333 | return base.from_list(cmap_name, color_list, N) 334 | 335 | 336 | def per_image_standardization(images, image_size=28): 337 | image_mean, image_std = tf.nn.moments(images, axes=[1, 2, 3]) 338 | image_std = tf.sqrt(image_std)[:, None, None, None] 339 | images_standardized = (images - image_mean[:, None, None, None]) / tf.maximum(image_std, 1.0 / np.sqrt( 340 | image_size ** 2 * 3)) 341 | return images_standardized 342 | 343 | 344 | def gradients(f, x, grad_ys=None): 345 | ''' 346 | An easier way of computing gradients in tensorflow. The difference from tf.gradients is 347 | * If f is not connected with x in the graph, it will output 0s instead of Nones. This will be more meaningful 348 | for computing higher-order gradients. 349 | 350 | * The output will have the same shape and type as x. If x is a list, it will be a list. If x is a Tensor, it 351 | will be a tensor as well. 352 | 353 | :param f: A `Tensor` or a list of tensors to be differentiated 354 | :param x: A `Tensor` or a list of tensors to be used for differentiation 355 | :param grad_ys: Optional. It is a `Tensor` or a list of tensors having exactly the same shape and type as `f` and 356 | holds gradients computed for each of `f`. 357 | :return: A `Tensor` or a list of tensors having the same shape and type as `x` 358 | ''' 359 | 360 | if isinstance(x, list): 361 | grad = tf.gradients(f, x, grad_ys=grad_ys) 362 | for i in range(len(x)): 363 | if grad[i] is None: 364 | grad[i] = tf.zeros_like(x[i]) 365 | return grad 366 | else: 367 | grad = tf.gradients(f, x, grad_ys=grad_ys)[0] 368 | if grad is None: 369 | return tf.zeros_like(x) 370 | else: 371 | return grad 372 | 373 | 374 | def Lop(f, x, v): 375 | ''' 376 | Compute Jacobian-vector product. The result is v^T @ J_x 377 | 378 | :param f: A `Tensor` or a list of tensors for computing the Jacobian J_x 379 | :param x: A `Tensor` or a list of tensors with respect to which the Jacobian is computed. 380 | :param v: A `Tensor` or a list of tensors having the same shape and type as `f` 381 | :return: A `Tensor` or a list of tensors having the same shape and type as `x` 382 | ''' 383 | assert not isinstance(f, list) or isinstance(v, list), "f and v should be of the same type" 384 | return gradients(f, x, grad_ys=v) 385 | 386 | 387 | def Rop(f, x, v): 388 | ''' 389 | Compute Jacobian-vector product. The result is J_x @ v. 
390 | The method is inspired by [deep yearning's blog](https://j-towns.github.io/2017/06/12/A-new-trick.html) 391 | :param f: A `Tensor` or a list of tensors for computing the Jacobian J_x 392 | :param x: A `Tensor` or a list of tensors with respect to which the Jacobian is computed 393 | :param v: A `Tensor` or a list of tensors having the same shape and type as `x` 394 | :return: A `Tensor` or a list of tensors having the same shape and type as `f` 395 | ''' 396 | assert not isinstance(x, list) or isinstance(v, list), "x and v should be of the same type" 397 | if isinstance(f, list): 398 | w = [tf.ones_like(_) for _ in f] 399 | else: 400 | w = tf.ones_like(f) 401 | return gradients(Lop(f, x, w), w, grad_ys=v) 402 | --------------------------------------------------------------------------------
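A worked example of the Lop/Rop trick in utils.py (a sketch under TF 1.x assumptions; the tensor values are illustrative). Lop is a vector-Jacobian product computed directly by reverse mode; Rop recovers the Jacobian-vector product J_x @ v from two reverse-mode passes, by differentiating Lop's output with respect to the dummy ones-like tensor w:

    import tensorflow as tf
    from utils import Rop

    x = tf.constant([1.0, 2.0, 3.0])
    f = tf.stack([x[0] * x[1], x[2] ** 2])   # f: R^3 -> R^2, so J = [[2, 1, 0], [0, 0, 6]] at this x
    v = tf.constant([1.0, 0.0, 2.0])         # same shape as x
    jvp = Rop(f, x, v)                       # same shape as f
    with tf.Session() as sess:
        print(sess.run(jvp))                 # [ 2. 12.]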