├── DGAN.ipynb ├── README.md ├── download.py ├── main.py ├── model.py ├── ops.py ├── uniform.py ├── utils.py └── wikiart_scraper.ipynb /DGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DGAN", 7 | "version": "0.3.2", 8 | "views": {}, 9 | "default_view": {}, 10 | "provenance": [], 11 | "collapsed_sections": [ 12 | "Qu0fuzQ4Sh9d" 13 | ] 14 | }, 15 | "kernelspec": { 16 | "name": "python2", 17 | "display_name": "Python 2" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "metadata": { 24 | "id": "zmtd_DvDErgZ", 25 | "colab_type": "text" 26 | }, 27 | "cell_type": "markdown", 28 | "source": [ 29 | "# Setup" 30 | ] 31 | }, 32 | { 33 | "metadata": { 34 | "id": "k3S5q576kbgz", 35 | "colab_type": "text" 36 | }, 37 | "cell_type": "markdown", 38 | "source": [ 39 | "Check if GPU is enabled" 40 | ] 41 | }, 42 | { 43 | "metadata": { 44 | "id": "gdZxqX5qkMn2", 45 | "colab_type": "code", 46 | "colab": { 47 | "autoexec": { 48 | "startup": false, 49 | "wait_interval": 0 50 | } 51 | } 52 | }, 53 | "cell_type": "code", 54 | "source": [ 55 | "import tensorflow as tf\n", 56 | "tf.test.gpu_device_name()" 57 | ], 58 | "execution_count": 0, 59 | "outputs": [] 60 | }, 61 | { 62 | "metadata": { 63 | "id": "CHP1ugLUbA7F", 64 | "colab_type": "text" 65 | }, 66 | "cell_type": "markdown", 67 | "source": [ 68 | "check how much memory this gpu has" 69 | ] 70 | }, 71 | { 72 | "metadata": { 73 | "id": "PZ8_gNv14yD4", 74 | "colab_type": "code", 75 | "colab": { 76 | "autoexec": { 77 | "startup": false, 78 | "wait_interval": 0 79 | } 80 | } 81 | }, 82 | "cell_type": "code", 83 | "source": [ 84 | "!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi\n", 85 | "!pip install gputil\n", 86 | "!pip install psutil\n", 87 | "!pip install humanize\n", 88 | "import psutil\n", 89 | "import humanize\n", 90 | "import os\n", 91 | "import GPUtil as GPU\n", 92 | "GPUs 
= GPU.getGPUs()\n", 93 | "# XXX: only one GPU on Colab and isn’t guaranteed\n", 94 | "gpu = GPUs[0]\n", 95 | "def printm():\n", 96 | " process = psutil.Process(os.getpid())\n", 97 | " print(\"Gen RAM Free: \" + humanize.naturalsize( psutil.virtual_memory().available ), \" I Proc size: \" + humanize.naturalsize( process.memory_info().rss))\n", 98 | " print(\"GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB\".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))\n", 99 | "printm()" 100 | ], 101 | "execution_count": 0, 102 | "outputs": [] 103 | }, 104 | { 105 | "metadata": { 106 | "id": "sBImtdUF45iW", 107 | "colab_type": "text" 108 | }, 109 | "cell_type": "markdown", 110 | "source": [ 111 | "DON'T RUN - only to remove everything\n", 112 | "💀💀💀💀💀💀💀💀\n", 113 | "!kill -9 -1\n", 114 | "💀💀💀💀💀💀💀💀\n" 115 | ] 116 | }, 117 | { 118 | "metadata": { 119 | "id": "tV4_I7Q0B9i-", 120 | "colab_type": "code", 121 | "colab": { 122 | "autoexec": { 123 | "startup": false, 124 | "wait_interval": 0 125 | } 126 | } 127 | }, 128 | "cell_type": "code", 129 | "source": [ 130 | "import os\n", 131 | "os.listdir('.')" 132 | ], 133 | "execution_count": 0, 134 | "outputs": [] 135 | }, 136 | { 137 | "metadata": { 138 | "id": "jRWG4VLlzD4e", 139 | "colab_type": "code", 140 | "colab": { 141 | "autoexec": { 142 | "startup": false, 143 | "wait_interval": 0 144 | } 145 | } 146 | }, 147 | "cell_type": "code", 148 | "source": [ 149 | "" 150 | ], 151 | "execution_count": 0, 152 | "outputs": [] 153 | }, 154 | { 155 | "metadata": { 156 | "id": "FoI6Ic8HmNZ2", 157 | "colab_type": "text" 158 | }, 159 | "cell_type": "markdown", 160 | "source": [ 161 | "# Install repo and import images" 162 | ] 163 | }, 164 | { 165 | "metadata": { 166 | "id": "38qyMuE3XtP9", 167 | "colab_type": "text" 168 | }, 169 | "cell_type": "markdown", 170 | "source": [ 171 | "Lets use the repo made by carpedm20, a tensorflow implementation to Dgans" 172 | ] 173 | }, 174 | { 175 | 
"metadata": { 176 | "id": "qZ49GKGLmK_S", 177 | "colab_type": "code", 178 | "colab": { 179 | "autoexec": { 180 | "startup": false, 181 | "wait_interval": 0 182 | } 183 | } 184 | }, 185 | "cell_type": "code", 186 | "source": [ 187 | "!git clone https://github.com/sinanatra/DCGAN-Art-Tensorflow.git" 188 | ], 189 | "execution_count": 0, 190 | "outputs": [] 191 | }, 192 | { 193 | "metadata": { 194 | "id": "R7_8N0XzmLF8", 195 | "colab_type": "code", 196 | "colab": { 197 | "autoexec": { 198 | "startup": false, 199 | "wait_interval": 0 200 | } 201 | } 202 | }, 203 | "cell_type": "code", 204 | "source": [ 205 | "cd DCGAN-Art-Tensorflow" 206 | ], 207 | "execution_count": 0, 208 | "outputs": [] 209 | }, 210 | { 211 | "metadata": { 212 | "id": "Ob5XE7hZwuwc", 213 | "colab_type": "code", 214 | "colab": { 215 | "autoexec": { 216 | "startup": false, 217 | "wait_interval": 0 218 | } 219 | } 220 | }, 221 | "cell_type": "code", 222 | "source": [ 223 | "# install the necessary libraries\n", 224 | "!pip install tqdm\n", 225 | "!pip install -U -q PyDrive\n", 226 | "!pip install googledrivedownloader" 227 | ], 228 | "execution_count": 0, 229 | "outputs": [] 230 | }, 231 | { 232 | "metadata": { 233 | "id": "nUFAWzEgUrIy", 234 | "colab_type": "text" 235 | }, 236 | "cell_type": "markdown", 237 | "source": [ 238 | "\n", 239 | "\n", 240 | "```\n", 241 | "# This is formatted as code\n", 242 | "```\n", 243 | "\n", 244 | "Import from your Google Drive your dataset -* never upload with the files.upload(), it crashes if file is too big*" 245 | ] 246 | }, 247 | { 248 | "metadata": { 249 | "id": "z6IqbD3hfdjZ", 250 | "colab_type": "code", 251 | "colab": { 252 | "autoexec": { 253 | "startup": false, 254 | "wait_interval": 0 255 | } 256 | } 257 | }, 258 | "cell_type": "code", 259 | "source": [ 260 | "from pydrive.auth import GoogleAuth\n", 261 | "from pydrive.drive import GoogleDrive\n", 262 | "from google.colab import auth\n", 263 | "from oauth2client.client import GoogleCredentials\n", 264 | "\n", 
265 | "# 1. Authenticate and create the PyDrive client.\n", 266 | "auth.authenticate_user()\n", 267 | "gauth = GoogleAuth()\n", 268 | "gauth.credentials = GoogleCredentials.get_application_default()\n", 269 | "drive = GoogleDrive(gauth)\n", 270 | "print(\"all right\")" 271 | ], 272 | "execution_count": 0, 273 | "outputs": [] 274 | }, 275 | { 276 | "metadata": { 277 | "id": "8JnlmpLxoKsD", 278 | "colab_type": "text" 279 | }, 280 | "cell_type": "markdown", 281 | "source": [ 282 | "# Import Dataset" 283 | ] 284 | }, 285 | { 286 | "metadata": { 287 | "id": "auwW5mgYboJO", 288 | "colab_type": "text" 289 | }, 290 | "cell_type": "markdown", 291 | "source": [ 292 | "Some of the dataset i'm using, not sure they will work on other computers, just check or else upload a folder with the same img size and colors (RGB, not RGBA) to your drive. Make it a sharable link and copy the last part like this: drive.google.com/open?id=**1LWolfUnkoAwHnxmyOz3PnI1gWsQI-oUc**" 293 | ] 294 | }, 295 | { 296 | "metadata": { 297 | "id": "YLBwMdxMW3PR", 298 | "colab_type": "code", 299 | "colab": { 300 | "autoexec": { 301 | "startup": false, 302 | "wait_interval": 0 303 | } 304 | } 305 | }, 306 | "cell_type": "code", 307 | "source": [ 308 | "#link_name = \"16PP-ayI5CoGB04nKxA_qSsH2h48qxSJN\" # Prints\n", 309 | "#link_name =\"1LWolfUnkoAwHnxmyOz3PnI1gWsQI-oUc\" #BCBF\n", 310 | "link_name =\"1DMGmyr-SuB8ugKMafxPEYBe-4ta30aUf\" # portraits - wiki_art / if you want to use wiki_art images i have a scraper somewhere\n", 311 | "#link_name=\"1wDOHDKd7LFyEuicYe1MesMrkIDrC-tQ_\"\n", 312 | "#link_name = \"1oQ8aH64pNejs-DOxZ8Sk2onLHFEm7m-X\" CHEKPOINTS" 313 | ], 314 | "execution_count": 0, 315 | "outputs": [] 316 | }, 317 | { 318 | "metadata": { 319 | "id": "ROE0HrCCnVvh", 320 | "colab_type": "text" 321 | }, 322 | "cell_type": "markdown", 323 | "source": [ 324 | "create dirs to place dataset" 325 | ] 326 | }, 327 | { 328 | "metadata": { 329 | "id": "H_-jyFNY58zz", 330 | "colab_type": "code", 331 | "colab": { 
332 | "autoexec": { 333 | "startup": false, 334 | "wait_interval": 0 335 | } 336 | } 337 | }, 338 | "cell_type": "code", 339 | "source": [ 340 | "mkdir data" 341 | ], 342 | "execution_count": 0, 343 | "outputs": [] 344 | }, 345 | { 346 | "metadata": { 347 | "id": "yiHcgEyWfZNk", 348 | "colab_type": "code", 349 | "colab": { 350 | "autoexec": { 351 | "startup": false, 352 | "wait_interval": 0 353 | } 354 | } 355 | }, 356 | "cell_type": "code", 357 | "source": [ 358 | "mkdir data/portrait" 359 | ], 360 | "execution_count": 0, 361 | "outputs": [] 362 | }, 363 | { 364 | "metadata": { 365 | "id": "ZIGaVL0gkSXZ", 366 | "colab_type": "code", 367 | "colab": { 368 | "autoexec": { 369 | "startup": false, 370 | "wait_interval": 0 371 | } 372 | } 373 | }, 374 | "cell_type": "code", 375 | "source": [ 376 | "#rm -R data" 377 | ], 378 | "execution_count": 0, 379 | "outputs": [] 380 | }, 381 | { 382 | "metadata": { 383 | "id": "bP036Q00cVIK", 384 | "colab_type": "text" 385 | }, 386 | "cell_type": "markdown", 387 | "source": [ 388 | "import from drive and unzip the file to *dest*" 389 | ] 390 | }, 391 | { 392 | "metadata": { 393 | "id": "d06JZAcwcRQa", 394 | "colab_type": "code", 395 | "colab": { 396 | "autoexec": { 397 | "startup": false, 398 | "wait_interval": 0 399 | } 400 | } 401 | }, 402 | "cell_type": "code", 403 | "source": [ 404 | "dest =\"./data/portrait/\"+ link_name+\".zip\"\n" 405 | ], 406 | "execution_count": 0, 407 | "outputs": [] 408 | }, 409 | { 410 | "metadata": { 411 | "id": "S9byes1iUp0T", 412 | "colab_type": "code", 413 | "colab": { 414 | "autoexec": { 415 | "startup": false, 416 | "wait_interval": 0 417 | } 418 | } 419 | }, 420 | "cell_type": "code", 421 | "source": [ 422 | "from google_drive_downloader import GoogleDriveDownloader as gdd\n", 423 | "\n", 424 | "gdd.download_file_from_google_drive(file_id=link_name,\n", 425 | " dest_path=dest,\n", 426 | " unzip=True)" 427 | ], 428 | "execution_count": 0, 429 | "outputs": [] 430 | }, 431 | { 432 | "metadata": { 
433 | "id": "Nl9YW79Fe5ko", 434 | "colab_type": "text" 435 | }, 436 | "cell_type": "markdown", 437 | "source": [ 438 | "# Import checkpoints from drive" 439 | ] 440 | }, 441 | { 442 | "metadata": { 443 | "id": "dnW2nvY95tSs", 444 | "colab_type": "code", 445 | "colab": { 446 | "autoexec": { 447 | "startup": false, 448 | "wait_interval": 0 449 | } 450 | } 451 | }, 452 | "cell_type": "code", 453 | "source": [ 454 | "link1 = \"19dTjHVgIDzzIpirdPkDoZacGQ-I47W7P\" #CHEKPOINTS " 455 | ], 456 | "execution_count": 0, 457 | "outputs": [] 458 | }, 459 | { 460 | "metadata": { 461 | "id": "0sfkWm6R0rU-", 462 | "colab_type": "code", 463 | "colab": { 464 | "autoexec": { 465 | "startup": false, 466 | "wait_interval": 0 467 | } 468 | } 469 | }, 470 | "cell_type": "code", 471 | "source": [ 472 | "mkdir checkpoint" 473 | ], 474 | "execution_count": 0, 475 | "outputs": [] 476 | }, 477 | { 478 | "metadata": { 479 | "id": "DDfxMtIofZTy", 480 | "colab_type": "code", 481 | "colab": { 482 | "autoexec": { 483 | "startup": false, 484 | "wait_interval": 0 485 | } 486 | } 487 | }, 488 | "cell_type": "code", 489 | "source": [ 490 | "loc =\"./checkpoint/\"+ link1+\".zip\"" 491 | ], 492 | "execution_count": 0, 493 | "outputs": [] 494 | }, 495 | { 496 | "metadata": { 497 | "id": "TZ1K4Xtd64Yv", 498 | "colab_type": "code", 499 | "colab": { 500 | "autoexec": { 501 | "startup": false, 502 | "wait_interval": 0 503 | } 504 | } 505 | }, 506 | "cell_type": "code", 507 | "source": [ 508 | "from google_drive_downloader import GoogleDriveDownloader as gdd\n", 509 | "\n", 510 | "gdd.download_file_from_google_drive(file_id=link1,\n", 511 | " dest_path=loc,\n", 512 | " unzip=True)" 513 | ], 514 | "execution_count": 0, 515 | "outputs": [] 516 | }, 517 | { 518 | "metadata": { 519 | "id": "oW_WrHlg6X-a", 520 | "colab_type": "code", 521 | "colab": { 522 | "autoexec": { 523 | "startup": false, 524 | "wait_interval": 0 525 | } 526 | } 527 | }, 528 | "cell_type": "code", 529 | "source": [ 530 | "import os\n", 531 | 
"os.listdir(\"./checkpoint\")" 532 | ], 533 | "execution_count": 0, 534 | "outputs": [] 535 | }, 536 | { 537 | "metadata": { 538 | "id": "RgTJnDwRzRB9", 539 | "colab_type": "code", 540 | "colab": { 541 | "autoexec": { 542 | "startup": false, 543 | "wait_interval": 0 544 | } 545 | } 546 | }, 547 | "cell_type": "code", 548 | "source": [ 549 | "rm -r ./checkpoint/1AIEx-Gqqx4FaYDf4LWRNoS630MTCii0u.zip" 550 | ], 551 | "execution_count": 0, 552 | "outputs": [] 553 | }, 554 | { 555 | "metadata": { 556 | "id": "ocKWm5bXXBjO", 557 | "colab_type": "code", 558 | "colab": { 559 | "autoexec": { 560 | "startup": false, 561 | "wait_interval": 0 562 | } 563 | } 564 | }, 565 | "cell_type": "code", 566 | "source": [ 567 | "import os\n", 568 | "os.listdir(\"./checkpoint\")" 569 | ], 570 | "execution_count": 0, 571 | "outputs": [] 572 | }, 573 | { 574 | "metadata": { 575 | "id": "ZfL7NeCofR_r", 576 | "colab_type": "text" 577 | }, 578 | "cell_type": "markdown", 579 | "source": [ 580 | "# Uniform images" 581 | ] 582 | }, 583 | { 584 | "metadata": { 585 | "id": "eVvolST0hKDY", 586 | "colab_type": "text" 587 | }, 588 | "cell_type": "markdown", 589 | "source": [ 590 | "Run this to avoid this error: *ValueError: could not broadcast input array from shape.. 
*\n", 591 | "\n", 592 | "It converts all images in a directory to RGB (it removes the alpha channel) and moves everything to a new folder" 593 | ] 594 | }, 595 | { 596 | "metadata": { 597 | "id": "xbZk_1JIY_mr", 598 | "colab_type": "code", 599 | "colab": { 600 | "autoexec": { 601 | "startup": false, 602 | "wait_interval": 0 603 | } 604 | } 605 | }, 606 | "cell_type": "code", 607 | "source": [ 608 | "from PIL import Image\n", 609 | "import os, sys\n", 610 | "\n", 611 | "path ='./data/portrait/'\n", 612 | "\n", 613 | "dirs = os.listdir( path )\n", 614 | "\n", 615 | "\n", 616 | "for item in dirs:\n", 617 | " try:\n", 618 | " if os.path.isfile(path+item):\n", 619 | " im = Image.open(path+item)\n", 620 | " longer_side = max(im.size)\n", 621 | "\n", 622 | " horizontal_padding = (longer_side - im.size[0]) / 2\n", 623 | " vertical_padding = (longer_side - im.size[1]) / 2\n", 624 | " f, e = os.path.splitext(path+item)\n", 625 | " imResize = im.crop(\n", 626 | " (\n", 627 | " -horizontal_padding,\n", 628 | " -vertical_padding,\n", 629 | " im.size[0] + horizontal_padding,\n", 630 | " im.size[1] + vertical_padding\n", 631 | " )\n", 632 | " )\n", 633 | " RGB = imResize.convert('RGB')\n", 634 | " little = RGB.resize((32,32), Image.ANTIALIAS)\n", 635 | "\n", 636 | " little.save(f + 'resize.jpg', 'JPEG', quality=30)\n", 637 | " \n", 638 | " except Exception as e:\n", 639 | " print(e)" 640 | ], 641 | "execution_count": 0, 642 | "outputs": [] 643 | }, 644 | { 645 | "metadata": { 646 | "id": "9HSVWO--Wuw1", 647 | "colab_type": "code", 648 | "colab": { 649 | "autoexec": { 650 | "startup": false, 651 | "wait_interval": 0 652 | } 653 | } 654 | }, 655 | "cell_type": "code", 656 | "source": [ 657 | "import os\n", 658 | "os.listdir('./')" 659 | ], 660 | "execution_count": 0, 661 | "outputs": [] 662 | }, 663 | { 664 | "metadata": { 665 | "id": "_-NvQtkJxCeB", 666 | "colab_type": "text" 667 | }, 668 | "cell_type": "markdown", 669 | "source": [ 670 | "# Train - Test" 671 | ] 672 | }, 673 | { 
674 | "metadata": { 675 | "id": "rKZxrFrG2a6v", 676 | "colab_type": "code", 677 | "colab": { 678 | "autoexec": { 679 | "startup": false, 680 | "wait_interval": 0 681 | } 682 | } 683 | }, 684 | "cell_type": "code", 685 | "source": [ 686 | "cd DCGAN-Art-Tensorflow" 687 | ], 688 | "execution_count": 0, 689 | "outputs": [] 690 | }, 691 | { 692 | "metadata": { 693 | "id": "qPPAxSqXO5iV", 694 | "colab_type": "code", 695 | "colab": { 696 | "autoexec": { 697 | "startup": false, 698 | "wait_interval": 0 699 | } 700 | } 701 | }, 702 | "cell_type": "code", 703 | "source": [ 704 | "!python main.py --batch_size 64 --dataset portrait --crop --train --epoch 150 --input_fname_pattern \"*resize.jpg\"" 705 | ], 706 | "execution_count": 0, 707 | "outputs": [] 708 | }, 709 | { 710 | "metadata": { 711 | "id": "8v4YAdGacx_G", 712 | "colab_type": "code", 713 | "colab": { 714 | "autoexec": { 715 | "startup": false, 716 | "wait_interval": 0 717 | } 718 | } 719 | }, 720 | "cell_type": "code", 721 | "source": [ 722 | "" 723 | ], 724 | "execution_count": 0, 725 | "outputs": [] 726 | }, 727 | { 728 | "metadata": { 729 | "id": "w4KRFkpS2GjH", 730 | "colab_type": "code", 731 | "colab": { 732 | "autoexec": { 733 | "startup": false, 734 | "wait_interval": 0 735 | } 736 | } 737 | }, 738 | "cell_type": "code", 739 | "source": [ 740 | "!python main.py --batch_size 64 --dataset portrait --crop --epoch 50 --input_fname_pattern \"*resize.jpg\"" 741 | ], 742 | "execution_count": 0, 743 | "outputs": [] 744 | }, 745 | { 746 | "metadata": { 747 | "id": "V6Qi4eBP2wog", 748 | "colab_type": "code", 749 | "colab": { 750 | "autoexec": { 751 | "startup": false, 752 | "wait_interval": 0 753 | } 754 | } 755 | }, 756 | "cell_type": "code", 757 | "source": [ 758 | "from google.colab import files\n", 759 | "files.download( \"./samples/test_arange_11.png\" )\n", 760 | "print(\"aaaa right\")" 761 | ], 762 | "execution_count": 0, 763 | "outputs": [] 764 | }, 765 | { 766 | "metadata": { 767 | "id": "asotRId_2tRj", 768 | 
"colab_type": "code", 769 | "colab": { 770 | "autoexec": { 771 | "startup": false, 772 | "wait_interval": 0 773 | } 774 | } 775 | }, 776 | "cell_type": "code", 777 | "source": [ 778 | "from google.colab import files\n", 779 | "files.download( \"./samples/test_arange_99.png\" )\n", 780 | "print(\"aaaa right\")" 781 | ], 782 | "execution_count": 0, 783 | "outputs": [] 784 | }, 785 | { 786 | "metadata": { 787 | "id": "Qu0fuzQ4Sh9d", 788 | "colab_type": "text" 789 | }, 790 | "cell_type": "markdown", 791 | "source": [ 792 | "# Save files" 793 | ] 794 | }, 795 | { 796 | "metadata": { 797 | "id": "x_0drhNwxAw8", 798 | "colab_type": "text" 799 | }, 800 | "cell_type": "markdown", 801 | "source": [ 802 | "save specific img" 803 | ] 804 | }, 805 | { 806 | "metadata": { 807 | "id": "l-dL0keI9SFB", 808 | "colab_type": "code", 809 | "colab": { 810 | "autoexec": { 811 | "startup": false, 812 | "wait_interval": 0 813 | } 814 | } 815 | }, 816 | "cell_type": "code", 817 | "source": [ 818 | "from google.colab import files\n", 819 | "files.download( \"./samples/test_arange_18.png\" )\n", 820 | "print(\"aaaa right\")" 821 | ], 822 | "execution_count": 0, 823 | "outputs": [] 824 | }, 825 | { 826 | "metadata": { 827 | "id": "fjth_GeX5PiW", 828 | "colab_type": "text" 829 | }, 830 | "cell_type": "markdown", 831 | "source": [ 832 | "save checkpoint folder - to continue when the VM changes" 833 | ] 834 | }, 835 | { 836 | "metadata": { 837 | "id": "jzVFJsBp8DLK", 838 | "colab_type": "code", 839 | "colab": { 840 | "autoexec": { 841 | "startup": false, 842 | "wait_interval": 0 843 | } 844 | } 845 | }, 846 | "cell_type": "code", 847 | "source": [ 848 | "cd .." 
849 | ], 850 | "execution_count": 0, 851 | "outputs": [] 852 | }, 853 | { 854 | "metadata": { 855 | "id": "VpSsIvK9xg95", 856 | "colab_type": "text" 857 | }, 858 | "cell_type": "markdown", 859 | "source": [ 860 | "# Save to Drive - Zip file\n" 861 | ] 862 | }, 863 | { 864 | "metadata": { 865 | "id": "b5zglqmkSBoZ", 866 | "colab_type": "code", 867 | "colab": { 868 | "autoexec": { 869 | "startup": false, 870 | "wait_interval": 0 871 | } 872 | } 873 | }, 874 | "cell_type": "code", 875 | "source": [ 876 | "CPT_PATH = \"./checkpoint\"" 877 | ], 878 | "execution_count": 0, 879 | "outputs": [] 880 | }, 881 | { 882 | "metadata": { 883 | "id": "sR-5MGDkd0D6", 884 | "colab_type": "text" 885 | }, 886 | "cell_type": "markdown", 887 | "source": [ 888 | "Same logic as before, this time it saves the .zip to a specific drive folder " 889 | ] 890 | }, 891 | { 892 | "metadata": { 893 | "id": "xeMebbocTBh3", 894 | "colab_type": "code", 895 | "colab": { 896 | "autoexec": { 897 | "startup": false, 898 | "wait_interval": 0 899 | } 900 | } 901 | }, 902 | "cell_type": "code", 903 | "source": [ 904 | "drive_folder = \"116OGJqQpksx6AaRpcH5EUt2M99wJ9mfJ\" # change the link, or it will fill my drive" 905 | ], 906 | "execution_count": 0, 907 | "outputs": [] 908 | }, 909 | { 910 | "metadata": { 911 | "id": "hg2emDrpWeYT", 912 | "colab_type": "code", 913 | "colab": { 914 | "autoexec": { 915 | "startup": false, 916 | "wait_interval": 0 917 | } 918 | } 919 | }, 920 | "cell_type": "code", 921 | "source": [ 922 | "from pydrive.auth import GoogleAuth\n", 923 | "from pydrive.drive import GoogleDrive\n", 924 | "from google.colab import auth\n", 925 | "from oauth2client.client import GoogleCredentials\n", 926 | "\n", 927 | "# 1. 
Authenticate and create the PyDrive client.\n", 928 | "auth.authenticate_user()\n", 929 | "gauth = GoogleAuth()\n", 930 | "gauth.credentials = GoogleCredentials.get_application_default()\n", 931 | "drive = GoogleDrive(gauth)\n", 932 | "print(\"all right\")" 933 | ], 934 | "execution_count": 0, 935 | "outputs": [] 936 | }, 937 | { 938 | "metadata": { 939 | "id": "WXSbM8WlmqWd", 940 | "colab_type": "code", 941 | "colab": { 942 | "autoexec": { 943 | "startup": false, 944 | "wait_interval": 0 945 | } 946 | } 947 | }, 948 | "cell_type": "code", 949 | "source": [ 950 | "\"\"\"import zipfile\n", 951 | "import os\n", 952 | "\n", 953 | "d = CPT_PATH\n", 954 | "\n", 955 | "os.chdir(os.path.dirname(d))\n", 956 | "with zipfile.ZipFile(d + '.zip',\n", 957 | " \"w\",\n", 958 | " zipfile.ZIP_DEFLATED,\n", 959 | " allowZip64=True) as zf:\n", 960 | " for root, _, filenames in os.walk(os.path.basename(d)):\n", 961 | " for name in filenames:\n", 962 | " name = os.path.join(root, name)\n", 963 | " name = os.path.normpath(name)\n", 964 | " zipFile = zf.write(name, name)\n", 965 | " \n", 966 | "print(zipFile)\n", 967 | "\n", 968 | "# get the folder id where you want to save your file\n", 969 | "file = drive.CreateFile({'parents':[{u'id': folder_id}]})\n", 970 | "file.SetContentFile(zipFile)\n", 971 | "file.Upload() \n", 972 | "print(\"zip file saved on Drive\")\"\"\"" 973 | ], 974 | "execution_count": 0, 975 | "outputs": [] 976 | }, 977 | { 978 | "metadata": { 979 | "id": "KaazvsPXSj6I", 980 | "colab_type": "code", 981 | "colab": { 982 | "autoexec": { 983 | "startup": false, 984 | "wait_interval": 0 985 | } 986 | }, 987 | "cellView": "form" 988 | }, 989 | "cell_type": "code", 990 | "source": [ 991 | "#@title\n", 992 | "import shutil\n", 993 | "folder_id = drive_folder\n", 994 | "\n", 995 | "print(\"go go go!\")\n", 996 | "zip_name = CPT_PATH\n", 997 | "directory_name = '.'\n", 998 | "directory_name = CPT_PATH\n", 999 | "foo = shutil.make_archive(zip_name, 'zip', directory_name)\n", 1000 
| "\n", 1001 | "# get the folder id where you want to save your file\n", 1002 | "file = drive.CreateFile({'parents':[{u'id': folder_id}]})\n", 1003 | "file.SetContentFile(foo)\n", 1004 | "file.Upload() \n", 1005 | "print(\"zip file saved on Drive\")" 1006 | ], 1007 | "execution_count": 0, 1008 | "outputs": [] 1009 | }, 1010 | { 1011 | "metadata": { 1012 | "id": "U95uZH1n9coy", 1013 | "colab_type": "text" 1014 | }, 1015 | "cell_type": "markdown", 1016 | "source": [ 1017 | "" 1018 | ] 1019 | } 1020 | ] 1021 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Art and Design DCGAN in Tensorflow 2 | 3 | Modified version of Taehoon Kim’s tensorflow implementation of DCGAN `https://carpedm20.github.io/faces/` with a focus on generating paintings and soon graphic design. 4 | 5 | It includes a script to scrape WikiArt, one to uniform images in a format that the Dcgan can work with and a Google Colab Notebook to train it on a free GPU. 6 | 7 | ![](https://pbs.twimg.com/media/DdvgUjdVwAAyANO.jpg:large) 8 | 9 | ## Prerequisites 10 | 11 | - Python 2.7 or Python 3.3+ 12 | - [Tensorflow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12) 13 | - [SciPy](http://www.scipy.org/install.html) 14 | - [pillow](https://github.com/python-pillow/Pillow) 15 | 16 | You can find a zip of my dataset in: 17 | `https://drive.google.com/open?id=17Cm2352V9G1tR4kii5yHI_KUkevLC67_` 18 | 19 | Checkpoints here: 20 | `https://drive.google.com/open?id=1yABe4LsWeDQz5p5IO2AYJPosGOgqtD2Z` 21 | 22 | Colab Notebook: 23 | `https://colab.research.google.com/drive/18RglimpA1JH7bRbTXtxx9fAbDl60sFVQ#scrollTo=YLBwMdxMW3PR` 24 | 25 | You will have to convert the images to RGB and resize them with: `uniform.py` by changing `path` with your directory 26 | 27 | Train your own dataset: 28 | 29 | $ mkdir data/DATASET_NAME 30 | ... add images to data/DATASET_NAME ... 
31 | $ python main.py --dataset DATASET_NAME --train 32 | $ python main.py --dataset DATASET_NAME 33 | $ # example 34 | $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train 35 | 36 | If your dataset is located in a different root directory: 37 | 38 | $ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR --train 39 | $ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR 40 | $ # example 41 | $ python main.py --dataset=eyes --data_dir ../datasets/ --input_fname_pattern="*_cropped.png" --train 42 | 43 | 44 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py 3 | 4 | Downloads the following: 5 | - Celeb-A dataset 6 | - LSUN dataset 7 | - MNIST dataset 8 | """ 9 | 10 | from __future__ import print_function 11 | import os 12 | import sys 13 | import gzip 14 | import json 15 | import shutil 16 | import zipfile 17 | import argparse 18 | import requests 19 | import subprocess 20 | from tqdm import tqdm 21 | from six.moves import urllib 22 | 23 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 24 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 25 | help='name of dataset to download [celebA, lsun, mnist]') 26 | 27 | def download(url, dirpath): 28 | filename = url.split('/')[-1] 29 | filepath = os.path.join(dirpath, filename) 30 | u = urllib.request.urlopen(url) 31 | f = open(filepath, 'wb') 32 | filesize = int(u.headers["Content-Length"]) 33 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 34 | 35 | downloaded = 0 36 | block_sz = 8192 37 | status_width = 70 38 | while True: 39 | buf = u.read(block_sz) 40 | if not buf: 41 | print('') 42 | break 43 | else: 44 | print('', end='\r') 45 | downloaded += len(buf) 46 
| f.write(buf) 47 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 48 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize)) 49 | print(status, end='') 50 | sys.stdout.flush() 51 | f.close() 52 | return filepath 53 | 54 | def download_file_from_google_drive(id, destination): 55 | URL = "https://docs.google.com/uc?export=download" 56 | session = requests.Session() 57 | 58 | response = session.get(URL, params={ 'id': id }, stream=True) 59 | token = get_confirm_token(response) 60 | 61 | if token: 62 | params = { 'id' : id, 'confirm' : token } 63 | response = session.get(URL, params=params, stream=True) 64 | 65 | save_response_content(response, destination) 66 | 67 | def get_confirm_token(response): 68 | for key, value in response.cookies.items(): 69 | if key.startswith('download_warning'): 70 | return value 71 | return None 72 | 73 | def save_response_content(response, destination, chunk_size=32*1024): 74 | total_size = int(response.headers.get('content-length', 0)) 75 | with open(destination, "wb") as f: 76 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 77 | unit='B', unit_scale=True, desc=destination): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | 81 | def unzip(filepath): 82 | print("Extracting: " + filepath) 83 | dirpath = os.path.dirname(filepath) 84 | with zipfile.ZipFile(filepath) as zf: 85 | zf.extractall(dirpath) 86 | os.remove(filepath) 87 | 88 | def download_celeb_a(dirpath): 89 | data_dir = 'celebA' 90 | if os.path.exists(os.path.join(dirpath, data_dir)): 91 | print('Found Celeb-A - skip') 92 | return 93 | 94 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 95 | save_path = os.path.join(dirpath, filename) 96 | 97 | if os.path.exists(save_path): 98 | print('[*] {} already exists'.format(save_path)) 99 | else: 100 | download_file_from_google_drive(drive_id, save_path) 101 | 102 | zip_dir = '' 103 | with 
zipfile.ZipFile(save_path) as zf: 104 | zip_dir = zf.namelist()[0] 105 | zf.extractall(dirpath) 106 | os.remove(save_path) 107 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 108 | 109 | def _list_categories(tag): 110 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 111 | f = urllib.request.urlopen(url) 112 | return json.loads(f.read()) 113 | 114 | def _download_lsun(out_dir, category, set_name, tag): 115 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 116 | '&category={category}&set={set_name}'.format(**locals()) 117 | print(url) 118 | if set_name == 'test': 119 | out_name = 'test_lmdb.zip' 120 | else: 121 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 122 | out_path = os.path.join(out_dir, out_name) 123 | cmd = ['curl', url, '-o', out_path] 124 | print('Downloading', category, set_name, 'set') 125 | subprocess.call(cmd) 126 | 127 | def download_lsun(dirpath): 128 | data_dir = os.path.join(dirpath, 'lsun') 129 | if os.path.exists(data_dir): 130 | print('Found LSUN - skip') 131 | return 132 | else: 133 | os.mkdir(data_dir) 134 | 135 | tag = 'latest' 136 | #categories = _list_categories(tag) 137 | categories = ['bedroom'] 138 | 139 | for category in categories: 140 | _download_lsun(data_dir, category, 'train', tag) 141 | _download_lsun(data_dir, category, 'val', tag) 142 | _download_lsun(data_dir, '', 'test', tag) 143 | 144 | def download_mnist(dirpath): 145 | data_dir = os.path.join(dirpath, 'mnist') 146 | if os.path.exists(data_dir): 147 | print('Found MNIST - skip') 148 | return 149 | else: 150 | os.mkdir(data_dir) 151 | url_base = 'http://yann.lecun.com/exdb/mnist/' 152 | file_names = ['train-images-idx3-ubyte.gz', 153 | 'train-labels-idx1-ubyte.gz', 154 | 't10k-images-idx3-ubyte.gz', 155 | 't10k-labels-idx1-ubyte.gz'] 156 | for file_name in file_names: 157 | url = (url_base+file_name).format(**locals()) 158 | print(url) 159 | out_path = os.path.join(data_dir,file_name) 160 | 
cmd = ['curl', url, '-o', out_path] 161 | print('Downloading ', file_name) 162 | subprocess.call(cmd) 163 | cmd = ['gzip', '-d', out_path] 164 | print('Decompressing ', file_name) 165 | subprocess.call(cmd) 166 | 167 | def prepare_data_dir(path = './data'): 168 | if not os.path.exists(path): 169 | os.mkdir(path) 170 | 171 | if __name__ == '__main__': 172 | args = parser.parse_args() 173 | prepare_data_dir() 174 | 175 | if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']): 176 | download_celeb_a('./data') 177 | if 'lsun' in args.datasets: 178 | download_lsun('./data') 179 | if 'mnist' in args.datasets: 180 | download_mnist('./data') 181 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.misc 3 | import numpy as np 4 | 5 | from model import DCGAN 6 | from utils import pp, visualize, to_json, show_all_variables 7 | 8 | import tensorflow as tf 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") 12 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") 13 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") 14 | flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]") 15 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") 16 | flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") 17 | flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") 18 | flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") 19 | flags.DEFINE_integer("output_width", None, "The size of the output images to produce. 
If None, same value as output_height [None]") 20 | flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") 21 | flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") 22 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") 23 | flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]") 24 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") 25 | flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") 26 | flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") 27 | flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") 28 | flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]") 29 | FLAGS = flags.FLAGS 30 | 31 | def main(_): 32 | pp.pprint(flags.FLAGS.__flags) 33 | 34 | if FLAGS.input_width is None: 35 | FLAGS.input_width = FLAGS.input_height 36 | if FLAGS.output_width is None: 37 | FLAGS.output_width = FLAGS.output_height 38 | 39 | if not os.path.exists(FLAGS.checkpoint_dir): 40 | os.makedirs(FLAGS.checkpoint_dir) 41 | if not os.path.exists(FLAGS.sample_dir): 42 | os.makedirs(FLAGS.sample_dir) 43 | 44 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) 45 | run_config = tf.ConfigProto(gpu_options=gpu_options) 46 | #run_confing = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 47 | run_config.gpu_options.allow_growth=True 48 | 49 | 50 | 51 | 52 | 53 | with tf.Session(config=run_config) as sess: 54 | if FLAGS.dataset == 'mnist': 55 | dcgan = DCGAN( 56 | sess, 57 | input_width=FLAGS.input_width, 58 | input_height=FLAGS.input_height, 59 | output_width=FLAGS.output_width, 60 | output_height=FLAGS.output_height, 61 | batch_size=FLAGS.batch_size, 62 | sample_num=FLAGS.batch_size, 63 | y_dim=10, 64 | 
z_dim=FLAGS.generate_test_images, 65 | dataset_name=FLAGS.dataset, 66 | input_fname_pattern=FLAGS.input_fname_pattern, 67 | crop=FLAGS.crop, 68 | checkpoint_dir=FLAGS.checkpoint_dir, 69 | sample_dir=FLAGS.sample_dir, 70 | data_dir=FLAGS.data_dir) 71 | else: 72 | dcgan = DCGAN( 73 | sess, 74 | input_width=FLAGS.input_width, 75 | input_height=FLAGS.input_height, 76 | output_width=FLAGS.output_width, 77 | output_height=FLAGS.output_height, 78 | batch_size=FLAGS.batch_size, 79 | sample_num=FLAGS.batch_size, 80 | z_dim=FLAGS.generate_test_images, 81 | dataset_name=FLAGS.dataset, 82 | input_fname_pattern=FLAGS.input_fname_pattern, 83 | crop=FLAGS.crop, 84 | checkpoint_dir=FLAGS.checkpoint_dir, 85 | sample_dir=FLAGS.sample_dir, 86 | data_dir=FLAGS.data_dir) 87 | 88 | show_all_variables() 89 | 90 | if FLAGS.train: 91 | dcgan.train(FLAGS) 92 | else: 93 | if not dcgan.load(FLAGS.checkpoint_dir)[0]: 94 | raise Exception("[!] Train a model first, then run test mode") 95 | 96 | 97 | # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], 98 | # [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1], 99 | # [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2], 100 | # [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3], 101 | # [dcgan.h4_w, dcgan.h4_b, None]) 102 | 103 | # Below is codes for visualization 104 | OPTION = 1 105 | visualize(sess, dcgan, FLAGS, OPTION) 106 | 107 | if __name__ == '__main__': 108 | tf.app.run() 109 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import math 5 | from glob import glob 6 | import tensorflow as tf 7 | import numpy as np 8 | from six.moves import xrange 9 | 10 | from ops import * 11 | from utils import * 12 | 13 | def conv_out_size_same(size, stride): 14 | return int(math.ceil(float(size) / float(stride))) 15 | 16 | class DCGAN(object): 17 | def __init__(self, sess, 
input_height=108, input_width=108, crop=True, 18 | batch_size=64, sample_num = 64, output_height=64, output_width=64, 19 | y_dim=None, z_dim=100, gf_dim=64, df_dim=64, 20 | gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', 21 | input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'): 22 | """ 23 | 24 | Args: 25 | sess: TensorFlow session 26 | batch_size: The size of batch. Should be specified before training. 27 | y_dim: (optional) Dimension of dim for y. [None] 28 | z_dim: (optional) Dimension of dim for Z. [100] 29 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 30 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 31 | gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024] 32 | dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024] 33 | c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3] 34 | """ 35 | self.sess = sess 36 | self.crop = crop 37 | 38 | self.batch_size = batch_size 39 | self.sample_num = sample_num 40 | 41 | self.input_height = input_height 42 | self.input_width = input_width 43 | self.output_height = output_height 44 | self.output_width = output_width 45 | 46 | self.y_dim = y_dim 47 | self.z_dim = z_dim 48 | 49 | self.gf_dim = gf_dim 50 | self.df_dim = df_dim 51 | 52 | self.gfc_dim = gfc_dim 53 | self.dfc_dim = dfc_dim 54 | 55 | # batch normalization : deals with poor initialization helps gradient flow 56 | self.d_bn1 = batch_norm(name='d_bn1') 57 | self.d_bn2 = batch_norm(name='d_bn2') 58 | 59 | if not self.y_dim: 60 | self.d_bn3 = batch_norm(name='d_bn3') 61 | 62 | self.g_bn0 = batch_norm(name='g_bn0') 63 | self.g_bn1 = batch_norm(name='g_bn1') 64 | self.g_bn2 = batch_norm(name='g_bn2') 65 | 66 | if not self.y_dim: 67 | self.g_bn3 = batch_norm(name='g_bn3') 68 | 69 | self.dataset_name = dataset_name 70 | self.input_fname_pattern = input_fname_pattern 71 | self.checkpoint_dir = 
checkpoint_dir 72 | self.data_dir = data_dir 73 | 74 | if self.dataset_name == 'mnist': 75 | self.data_X, self.data_y = self.load_mnist() 76 | self.c_dim = self.data_X[0].shape[-1] 77 | else: 78 | self.data = glob(os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern)) 79 | imreadImg = imread(self.data[0]) 80 | if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number 81 | self.c_dim = imread(self.data[0]).shape[-1] 82 | else: 83 | self.c_dim = 1 84 | 85 | self.grayscale = (self.c_dim == 1) 86 | 87 | self.build_model() 88 | 89 | def build_model(self): 90 | if self.y_dim: 91 | self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y') 92 | else: 93 | self.y = None 94 | 95 | if self.crop: 96 | image_dims = [self.output_height, self.output_width, self.c_dim] 97 | else: 98 | image_dims = [self.input_height, self.input_width, self.c_dim] 99 | 100 | self.inputs = tf.placeholder( 101 | tf.float32, [self.batch_size] + image_dims, name='real_images') 102 | 103 | inputs = self.inputs 104 | 105 | self.z = tf.placeholder( 106 | tf.float32, [None, self.z_dim], name='z') 107 | self.z_sum = histogram_summary("z", self.z) 108 | 109 | self.G = self.generator(self.z, self.y) 110 | self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False) 111 | self.sampler = self.sampler(self.z, self.y) 112 | self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True) 113 | 114 | self.d_sum = histogram_summary("d", self.D) 115 | self.d__sum = histogram_summary("d_", self.D_) 116 | self.G_sum = image_summary("G", self.G) 117 | 118 | def sigmoid_cross_entropy_with_logits(x, y): 119 | try: 120 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y) 121 | except: 122 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y) 123 | 124 | self.d_loss_real = tf.reduce_mean( 125 | sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D))) 126 | self.d_loss_fake = 
tf.reduce_mean( 127 | sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_))) 128 | self.g_loss = tf.reduce_mean( 129 | sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) 130 | 131 | self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real) 132 | self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake) 133 | 134 | self.d_loss = self.d_loss_real + self.d_loss_fake 135 | 136 | self.g_loss_sum = scalar_summary("g_loss", self.g_loss) 137 | self.d_loss_sum = scalar_summary("d_loss", self.d_loss) 138 | 139 | t_vars = tf.trainable_variables() 140 | 141 | self.d_vars = [var for var in t_vars if 'd_' in var.name] 142 | self.g_vars = [var for var in t_vars if 'g_' in var.name] 143 | 144 | self.saver = tf.train.Saver() 145 | 146 | def train(self, config): 147 | d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 148 | .minimize(self.d_loss, var_list=self.d_vars) 149 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 150 | .minimize(self.g_loss, var_list=self.g_vars) 151 | try: 152 | tf.global_variables_initializer().run() 153 | except: 154 | tf.initialize_all_variables().run() 155 | 156 | self.g_sum = merge_summary([self.z_sum, self.d__sum, 157 | self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) 158 | self.d_sum = merge_summary( 159 | [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) 160 | self.writer = SummaryWriter("./logs", self.sess.graph) 161 | 162 | sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim)) 163 | 164 | if config.dataset == 'mnist': 165 | sample_inputs = self.data_X[0:self.sample_num] 166 | sample_labels = self.data_y[0:self.sample_num] 167 | else: 168 | sample_files = self.data[0:self.sample_num] 169 | sample = [ 170 | get_image(sample_file, 171 | input_height=self.input_height, 172 | input_width=self.input_width, 173 | resize_height=self.output_height, 174 | resize_width=self.output_width, 175 | 
crop=self.crop, 176 | grayscale=self.grayscale) for sample_file in sample_files] 177 | if (self.grayscale): 178 | sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] 179 | else: 180 | sample_inputs = np.array(sample).astype(np.float32) 181 | 182 | counter = 1 183 | start_time = time.time() 184 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 185 | if could_load: 186 | counter = checkpoint_counter 187 | print(" [*] Load SUCCESS") 188 | else: 189 | print(" [!] Load failed...") 190 | 191 | for epoch in xrange(config.epoch): 192 | if config.dataset == 'mnist': 193 | batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size 194 | else: 195 | self.data = glob(os.path.join( 196 | config.data_dir, config.dataset, self.input_fname_pattern)) 197 | batch_idxs = min(len(self.data), config.train_size) // config.batch_size 198 | 199 | for idx in xrange(0, batch_idxs): 200 | if config.dataset == 'mnist': 201 | batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size] 202 | batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size] 203 | else: 204 | batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size] 205 | batch = [ 206 | get_image(batch_file, 207 | input_height=self.input_height, 208 | input_width=self.input_width, 209 | resize_height=self.output_height, 210 | resize_width=self.output_width, 211 | crop=self.crop, 212 | grayscale=self.grayscale) for batch_file in batch_files] 213 | if self.grayscale: 214 | batch_images = np.array(batch).astype(np.float32)[:, :, :, None] 215 | else: 216 | batch_images = np.array(batch).astype(np.float32) 217 | 218 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ 219 | .astype(np.float32) 220 | 221 | if config.dataset == 'mnist': 222 | # Update D network 223 | _, summary_str = self.sess.run([d_optim, self.d_sum], 224 | feed_dict={ 225 | self.inputs: batch_images, 226 | self.z: batch_z, 227 | self.y:batch_labels, 228 | }) 
229 | self.writer.add_summary(summary_str, counter) 230 | 231 | # Update G network 232 | _, summary_str = self.sess.run([g_optim, self.g_sum], 233 | feed_dict={ 234 | self.z: batch_z, 235 | self.y:batch_labels, 236 | }) 237 | self.writer.add_summary(summary_str, counter) 238 | 239 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 240 | _, summary_str = self.sess.run([g_optim, self.g_sum], 241 | feed_dict={ self.z: batch_z, self.y:batch_labels }) 242 | self.writer.add_summary(summary_str, counter) 243 | 244 | errD_fake = self.d_loss_fake.eval({ 245 | self.z: batch_z, 246 | self.y:batch_labels 247 | }) 248 | errD_real = self.d_loss_real.eval({ 249 | self.inputs: batch_images, 250 | self.y:batch_labels 251 | }) 252 | errG = self.g_loss.eval({ 253 | self.z: batch_z, 254 | self.y: batch_labels 255 | }) 256 | else: 257 | # Update D network 258 | _, summary_str = self.sess.run([d_optim, self.d_sum], 259 | feed_dict={ self.inputs: batch_images, self.z: batch_z }) 260 | self.writer.add_summary(summary_str, counter) 261 | 262 | # Update G network 263 | _, summary_str = self.sess.run([g_optim, self.g_sum], 264 | feed_dict={ self.z: batch_z }) 265 | self.writer.add_summary(summary_str, counter) 266 | 267 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 268 | _, summary_str = self.sess.run([g_optim, self.g_sum], 269 | feed_dict={ self.z: batch_z }) 270 | self.writer.add_summary(summary_str, counter) 271 | 272 | errD_fake = self.d_loss_fake.eval({ self.z: batch_z }) 273 | errD_real = self.d_loss_real.eval({ self.inputs: batch_images }) 274 | errG = self.g_loss.eval({self.z: batch_z}) 275 | 276 | counter += 1 277 | print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 278 | % (epoch, config.epoch, idx, batch_idxs, 279 | time.time() - start_time, errD_fake+errD_real, errG)) 280 | 281 | if np.mod(counter, 100) == 1: 282 | if config.dataset == 'mnist': 283 | samples, d_loss, 
g_loss = self.sess.run( 284 | [self.sampler, self.d_loss, self.g_loss], 285 | feed_dict={ 286 | self.z: sample_z, 287 | self.inputs: sample_inputs, 288 | self.y:sample_labels, 289 | } 290 | ) 291 | save_images(samples, image_manifold_size(samples.shape[0]), 292 | './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) 293 | print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 294 | else: 295 | try: 296 | samples, d_loss, g_loss = self.sess.run( 297 | [self.sampler, self.d_loss, self.g_loss], 298 | feed_dict={ 299 | self.z: sample_z, 300 | self.inputs: sample_inputs, 301 | }, 302 | ) 303 | save_images(samples, image_manifold_size(samples.shape[0]), 304 | './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) 305 | print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 306 | except: 307 | print("one pic error!...") 308 | 309 | if np.mod(counter, 500) == 2: 310 | self.save(config.checkpoint_dir, counter) 311 | 312 | def discriminator(self, image, y=None, reuse=False): 313 | with tf.variable_scope("discriminator") as scope: 314 | if reuse: 315 | scope.reuse_variables() 316 | 317 | if not self.y_dim: 318 | h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv')) 319 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv'))) 320 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv'))) 321 | h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv'))) 322 | h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin') 323 | 324 | return tf.nn.sigmoid(h4), h4 325 | else: 326 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 327 | x = conv_cond_concat(image, yb) 328 | 329 | h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv')) 330 | h0 = conv_cond_concat(h0, yb) 331 | 332 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv'))) 333 | h1 = tf.reshape(h1, [self.batch_size, -1]) 334 | h1 = concat([h1, y], 1) 335 | 336 | h2 = lrelu(self.d_bn2(linear(h1, 
self.dfc_dim, 'd_h2_lin'))) 337 | h2 = concat([h2, y], 1) 338 | 339 | h3 = linear(h2, 1, 'd_h3_lin') 340 | 341 | return tf.nn.sigmoid(h3), h3 342 | 343 | def generator(self, z, y=None): 344 | with tf.variable_scope("generator") as scope: 345 | if not self.y_dim: 346 | s_h, s_w = self.output_height, self.output_width 347 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 348 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 349 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 350 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 351 | 352 | # project `z` and reshape 353 | self.z_, self.h0_w, self.h0_b = linear( 354 | z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True) 355 | 356 | self.h0 = tf.reshape( 357 | self.z_, [-1, s_h16, s_w16, self.gf_dim * 8]) 358 | h0 = tf.nn.relu(self.g_bn0(self.h0)) 359 | 360 | self.h1, self.h1_w, self.h1_b = deconv2d( 361 | h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True) 362 | h1 = tf.nn.relu(self.g_bn1(self.h1)) 363 | 364 | h2, self.h2_w, self.h2_b = deconv2d( 365 | h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True) 366 | h2 = tf.nn.relu(self.g_bn2(h2)) 367 | 368 | h3, self.h3_w, self.h3_b = deconv2d( 369 | h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True) 370 | h3 = tf.nn.relu(self.g_bn3(h3)) 371 | 372 | h4, self.h4_w, self.h4_b = deconv2d( 373 | h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True) 374 | 375 | return tf.nn.tanh(h4) 376 | else: 377 | s_h, s_w = self.output_height, self.output_width 378 | s_h2, s_h4 = int(s_h/2), int(s_h/4) 379 | s_w2, s_w4 = int(s_w/2), int(s_w/4) 380 | 381 | # yb = tf.expand_dims(tf.expand_dims(y, 1),2) 382 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 383 | z = concat([z, y], 1) 384 | 385 | h0 = tf.nn.relu( 386 | self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'))) 387 | h0 = concat([h0, y], 1) 388 | 389 | 
h1 = tf.nn.relu(self.g_bn1( 390 | linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'))) 391 | h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2]) 392 | 393 | h1 = conv_cond_concat(h1, yb) 394 | 395 | h2 = tf.nn.relu(self.g_bn2(deconv2d(h1, 396 | [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'))) 397 | h2 = conv_cond_concat(h2, yb) 398 | 399 | return tf.nn.sigmoid( 400 | deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3')) 401 | 402 | def sampler(self, z, y=None): 403 | with tf.variable_scope("generator") as scope: 404 | scope.reuse_variables() 405 | 406 | if not self.y_dim: 407 | s_h, s_w = self.output_height, self.output_width 408 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 409 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 410 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 411 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 412 | 413 | # project `z` and reshape 414 | h0 = tf.reshape( 415 | linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'), 416 | [-1, s_h16, s_w16, self.gf_dim * 8]) 417 | h0 = tf.nn.relu(self.g_bn0(h0, train=False)) 418 | 419 | h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1') 420 | h1 = tf.nn.relu(self.g_bn1(h1, train=False)) 421 | 422 | h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2') 423 | h2 = tf.nn.relu(self.g_bn2(h2, train=False)) 424 | 425 | h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3') 426 | h3 = tf.nn.relu(self.g_bn3(h3, train=False)) 427 | 428 | h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4') 429 | 430 | return tf.nn.tanh(h4) 431 | else: 432 | s_h, s_w = self.output_height, self.output_width 433 | s_h2, s_h4 = int(s_h/2), int(s_h/4) 434 | s_w2, s_w4 = int(s_w/2), int(s_w/4) 435 | 436 | # yb = tf.reshape(y, [-1, 1, 1, self.y_dim]) 437 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 
438 | z = concat([z, y], 1) 439 | 440 | h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False)) 441 | h0 = concat([h0, y], 1) 442 | 443 | h1 = tf.nn.relu(self.g_bn1( 444 | linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False)) 445 | h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2]) 446 | h1 = conv_cond_concat(h1, yb) 447 | 448 | h2 = tf.nn.relu(self.g_bn2( 449 | deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False)) 450 | h2 = conv_cond_concat(h2, yb) 451 | 452 | return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3')) 453 | 454 | def load_mnist(self): 455 | data_dir = os.path.join(self.data_dir, self.dataset_name) 456 | 457 | fd = open(os.path.join(data_dir,'train-images-idx3-ubyte')) 458 | loaded = np.fromfile(file=fd,dtype=np.uint8) 459 | trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float) 460 | 461 | fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte')) 462 | loaded = np.fromfile(file=fd,dtype=np.uint8) 463 | trY = loaded[8:].reshape((60000)).astype(np.float) 464 | 465 | fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte')) 466 | loaded = np.fromfile(file=fd,dtype=np.uint8) 467 | teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float) 468 | 469 | fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte')) 470 | loaded = np.fromfile(file=fd,dtype=np.uint8) 471 | teY = loaded[8:].reshape((10000)).astype(np.float) 472 | 473 | trY = np.asarray(trY) 474 | teY = np.asarray(teY) 475 | 476 | X = np.concatenate((trX, teX), axis=0) 477 | y = np.concatenate((trY, teY), axis=0).astype(np.int) 478 | 479 | seed = 547 480 | np.random.seed(seed) 481 | np.random.shuffle(X) 482 | np.random.seed(seed) 483 | np.random.shuffle(y) 484 | 485 | y_vec = np.zeros((len(y), self.y_dim), dtype=np.float) 486 | for i, label in enumerate(y): 487 | y_vec[i,y[i]] = 1.0 488 | 489 | return X/255.,y_vec 490 | 491 | @property 492 | def model_dir(self): 493 | return 
"{}_{}_{}_{}".format( 494 | self.dataset_name, self.batch_size, 495 | self.output_height, self.output_width) 496 | 497 | def save(self, checkpoint_dir, step): 498 | model_name = "DCGAN.model" 499 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 500 | 501 | if not os.path.exists(checkpoint_dir): 502 | os.makedirs(checkpoint_dir) 503 | 504 | self.saver.save(self.sess, 505 | os.path.join(checkpoint_dir, model_name), 506 | global_step=step) 507 | 508 | def load(self, checkpoint_dir): 509 | import re 510 | print(" [*] Reading checkpoints...") 511 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 512 | 513 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 514 | if ckpt and ckpt.model_checkpoint_path: 515 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 516 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 517 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 518 | print(" [*] Success to read {}".format(ckpt_name)) 519 | return True, counter 520 | else: 521 | print(" [*] Failed to find a checkpoint") 522 | return False, 0 523 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.framework import ops 6 | 7 | from utils import * 8 | 9 | try: 10 | image_summary = tf.image_summary 11 | scalar_summary = tf.scalar_summary 12 | histogram_summary = tf.histogram_summary 13 | merge_summary = tf.merge_summary 14 | SummaryWriter = tf.train.SummaryWriter 15 | except: 16 | image_summary = tf.summary.image 17 | scalar_summary = tf.summary.scalar 18 | histogram_summary = tf.summary.histogram 19 | merge_summary = tf.summary.merge 20 | SummaryWriter = tf.summary.FileWriter 21 | 22 | if "concat_v2" in dir(tf): 23 | def concat(tensors, axis, *args, **kwargs): 24 | return tf.concat_v2(tensors, 
axis, *args, **kwargs) 25 | else: 26 | def concat(tensors, axis, *args, **kwargs): 27 | return tf.concat(tensors, axis, *args, **kwargs) 28 | 29 | class batch_norm(object): 30 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 31 | with tf.variable_scope(name): 32 | self.epsilon = epsilon 33 | self.momentum = momentum 34 | self.name = name 35 | 36 | def __call__(self, x, train=True): 37 | return tf.contrib.layers.batch_norm(x, 38 | decay=self.momentum, 39 | updates_collections=None, 40 | epsilon=self.epsilon, 41 | scale=True, 42 | is_training=train, 43 | scope=self.name) 44 | 45 | def conv_cond_concat(x, y): 46 | """Concatenate conditioning vector on feature map axis.""" 47 | x_shapes = x.get_shape() 48 | y_shapes = y.get_shape() 49 | return concat([ 50 | x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 51 | 52 | def conv2d(input_, output_dim, 53 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 54 | name="conv2d"): 55 | with tf.variable_scope(name): 56 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 57 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 58 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 59 | 60 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 61 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 62 | 63 | return conv 64 | 65 | def deconv2d(input_, output_shape, 66 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 67 | name="deconv2d", with_w=False): 68 | with tf.variable_scope(name): 69 | # filter : [height, width, output_channels, in_channels] 70 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 71 | initializer=tf.random_normal_initializer(stddev=stddev)) 72 | 73 | try: 74 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 75 | strides=[1, d_h, d_w, 1]) 76 | 77 | # Support for verisons of TensorFlow before 0.7.0 78 | except AttributeError: 79 
| deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 80 | strides=[1, d_h, d_w, 1]) 81 | 82 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 83 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 84 | 85 | if with_w: 86 | return deconv, w, biases 87 | else: 88 | return deconv 89 | 90 | def lrelu(x, leak=0.2, name="lrelu"): 91 | return tf.maximum(x, leak*x) 92 | 93 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 94 | shape = input_.get_shape().as_list() 95 | 96 | with tf.variable_scope(scope or "Linear"): 97 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 98 | tf.random_normal_initializer(stddev=stddev)) 99 | bias = tf.get_variable("bias", [output_size], 100 | initializer=tf.constant_initializer(bias_start)) 101 | if with_w: 102 | return tf.matmul(input_, matrix) + bias, matrix, bias 103 | else: 104 | return tf.matmul(input_, matrix) + bias 105 | -------------------------------------------------------------------------------- /uniform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os, sys 3 | 4 | path ='./data/portrait/' 5 | 6 | dirs = os.listdir( path ) 7 | 8 | 9 | for item in dirs: 10 | try: 11 | if os.path.isfile(path+item): 12 | im = Image.open(path+item) 13 | longer_side = max(im.size) 14 | 15 | horizontal_padding = (longer_side - im.size[0]) / 2 16 | vertical_padding = (longer_side - im.size[1]) / 2 17 | f, e = os.path.splitext(path+item) 18 | imResize = im.crop( 19 | ( 20 | -horizontal_padding, 21 | -vertical_padding, 22 | im.size[0] + horizontal_padding, 23 | im.size[1] + vertical_padding 24 | ) 25 | ) 26 | RGB = imResize.convert('RGB') 27 | little = RGB.resize((32,32), Image.ANTIALIAS) 28 | 29 | little.save(f + 'resize.jpg', 'JPEG', quality=30) 30 | 31 | except Exception as e: 32 | print(e) 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some codes from https://github.com/Newmu/dcgan_code 3 | """ 4 | from __future__ import division 5 | import math 6 | import json 7 | import random 8 | import pprint 9 | import scipy.misc 10 | import numpy as np 11 | from time import gmtime, strftime 12 | from six.moves import xrange 13 | 14 | import tensorflow as tf 15 | import tensorflow.contrib.slim as slim 16 | 17 | pp = pprint.PrettyPrinter() 18 | 19 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 20 | 21 | def show_all_variables(): 22 | model_vars = tf.trainable_variables() 23 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 24 | 25 | def get_image(image_path, input_height, input_width, 26 | resize_height=64, resize_width=64, 27 | crop=True, grayscale=False): 28 | image = imread(image_path, grayscale) 29 | return transform(image, input_height, input_width, 30 | resize_height, resize_width, crop) 31 | 32 | def save_images(images, size, image_path): 33 | return imsave(inverse_transform(images), size, image_path) 34 | 35 | def imread(path, grayscale = False): 36 | if (grayscale): 37 | return scipy.misc.imread(path, flatten = True).astype(np.float) 38 | else: 39 | return scipy.misc.imread(path).astype(np.float) 40 | 41 | def merge_images(images, size): 42 | return inverse_transform(images) 43 | 44 | def merge(images, size): 45 | h, w = images.shape[1], images.shape[2] 46 | if (images.shape[3] in (3,4)): 47 | c = images.shape[3] 48 | img = np.zeros((h * size[0], w * size[1], c)) 49 | for idx, image in enumerate(images): 50 | i = idx % size[1] 51 | j = idx // size[1] 52 | img[j * h:j * h + h, i * w:i * w + w, :] = image 53 | return img 54 | elif images.shape[3]==1: 55 | img = np.zeros((h * size[0], w * size[1])) 56 | for idx, image in enumerate(images): 57 | i = idx % size[1] 58 | j = idx // 
size[1] 59 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 60 | return img 61 | else: 62 | raise ValueError('in merge(images,size) images parameter ' 63 | 'must have dimensions: HxW or HxWx3 or HxWx4') 64 | 65 | def imsave(images, size, path): 66 | image = np.squeeze(merge(images, size)) 67 | return scipy.misc.imsave(path, image) 68 | 69 | def center_crop(x, crop_h, crop_w, 70 | resize_h=64, resize_w=64): 71 | if crop_w is None: 72 | crop_w = crop_h 73 | h, w = x.shape[:2] 74 | j = int(round((h - crop_h)/2.)) 75 | i = int(round((w - crop_w)/2.)) 76 | return scipy.misc.imresize( 77 | x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) 78 | 79 | def transform(image, input_height, input_width, 80 | resize_height=64, resize_width=64, crop=True): 81 | if crop: 82 | cropped_image = center_crop( 83 | image, input_height, input_width, 84 | resize_height, resize_width) 85 | else: 86 | cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) 87 | return np.array(cropped_image)/127.5 - 1. 88 | 89 | def inverse_transform(images): 90 | return (images+1.)/2. 
def to_json(output_path, *layers):
    """Dump (weight, bias, batch-norm) layer triples as ConvNetJS-style JS.

    Each element of *layers* is a (w, b, bn) tuple of TF-style variables
    (objects exposing ``.name`` and ``.eval()``); *bn* may be None.  Layers
    whose name contains "lin/" are written as "fc" records, the rest as
    "deconv" records.  The output is whitespace-normalised and stripped of
    single quotes before being written to *output_path*.
    """
    with open(output_path, "w") as layer_f:
        lines = ""
        for w, b, bn in layers:
            # Layer index is parsed out of names like "g_h1/..." -- the part
            # after 'h' in the first path component.
            layer_idx = w.name.split('/')[0].split('h')[1]

            B = b.eval()

            if "lin/" in w.name:
                W = w.eval()
                depth = W.shape[1]
            else:
                # Move the output-channel axis to the front so each filter
                # can be iterated as one entry.
                W = np.rollaxis(w.eval(), 2, 0)
                depth = W.shape[0]

            biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
            if bn != None:
                gamma = bn.gamma.eval()
                beta = bn.beta.eval()

                gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
                beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
            else:
                gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
                beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}

            if "lin/" in w.name:
                fs = []
                # Fix: use a fresh loop name instead of shadowing `w` from
                # the enclosing `for w, b, bn in layers` loop.
                for col in W.T:
                    fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(col)]})

                lines += """
                    var layer_%s = {
                        "layer_type": "fc",
                        "sy": 1, "sx": 1,
                        "out_sx": 1, "out_sy": 1,
                        "stride": 1, "pad": 0,
                        "out_depth": %s, "in_depth": %s,
                        "biases": %s,
                        "gamma": %s,
                        "beta": %s,
                        "filters": %s
                    };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
            else:
                fs = []
                for w_ in W:
                    fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})

                lines += """
                    var layer_%s = {
                        "layer_type": "deconv",
                        "sy": 5, "sx": 5,
                        "out_sx": %s, "out_sy": %s,
                        "stride": 2, "pad": 1,
                        "out_depth": %s, "in_depth": %s,
                        "biases": %s,
                        "gamma": %s,
                        "beta": %s,
                        "filters": %s
                    };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
                             W.shape[0], W.shape[3], biases, gamma, beta, fs)
        # The split/join collapses ALL whitespace runs to single spaces, so
        # the indentation of the templates above does not affect the output.
        layer_f.write(" ".join(lines.replace("'", "").split()))


def make_gif(images, fname, duration=2, true_image=False):
    """Write *images* (a sequence of HxWxC arrays) to an animated gif.

    When *true_image* is False the frames are assumed to lie in [-1, 1]
    and are rescaled to uint8 [0, 255].
    """
    import moviepy.editor as mpy

    def make_frame(t):
        # Fix: catch only the expected out-of-range error instead of a bare
        # `except:` that would also swallow KeyboardInterrupt/SystemExit.
        try:
            x = images[int(len(images) / duration * t)]
        except IndexError:
            x = images[-1]

        if true_image:
            return x.astype(np.uint8)
        else:
            return ((x + 1) / 2 * 255).astype(np.uint8)

    clip = mpy.VideoClip(make_frame, duration=duration)
    clip.write_gif(fname, fps=len(images) / duration)


def visualize(sess, dcgan, config, option):
    """Sample the generator in one of five modes and save images/gifs.

    option 0: one random batch saved as a grid PNG.
    option 1: sweep each latent dimension over [0, 1) across the batch.
    option 2: like 1, but perturbing random dimensions of a single z; also
              tries to save a gif per dimension.
    option 3: sweep from a zero z vector, saving a gif per dimension.
    option 4: like 3, then merges all sweeps into one looping gif.

    NOTE(review): relies on save_images/merge defined elsewhere in this file.
    """
    image_frame_dim = int(math.ceil(config.batch_size**.5))
    if option == 0:
        z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
        save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 1:
        values = np.arange(0, 1, 1./config.batch_size)
        # Fix: xrange is Python-2-only; range behaves identically here.
        for idx in range(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]

            if config.dataset == "mnist":
                y = np.random.choice(10, config.batch_size)
                y_one_hot = np.zeros((config.batch_size, 10))
                y_one_hot[np.arange(config.batch_size), y] = 1

                samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
            else:
                samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

            save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))
    elif option == 2:
        values = np.arange(0, 1, 1./config.batch_size)
        for idx in [random.randint(0, dcgan.z_dim - 1) for _ in range(dcgan.z_dim)]:
            print(" [*] %d" % idx)
            z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
            z_sample = np.tile(z, (config.batch_size, 1))
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]

            if config.dataset == "mnist":
                y = np.random.choice(10, config.batch_size)
                y_one_hot = np.zeros((config.batch_size, 10))
                y_one_hot[np.arange(config.batch_size), y] = 1

                samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
            else:
                samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

            try:
                print("trying to save gif")
                make_gif(samples, './samples/test_gif_%s.gif' % (idx))
            except Exception as e:
                print("not saving gif")
                print(e)
            save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 3:
        values = np.arange(0, 1, 1./config.batch_size)
        for idx in range(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            make_gif(samples, './samples/test_gif_%s.gif' % (idx))
    elif option == 4:
        image_set = []
        values = np.arange(0, 1, 1./config.batch_size)

        for idx in range(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]

            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

        # Fix: `range(64) + range(63, -1, -1)` is Python-2-only list
        # concatenation; wrap in list() so it also runs on Python 3.
        new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10])
                         for idx in list(range(64)) + list(range(63, -1, -1))]
        make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)


def image_manifold_size(num_images):
    """Return (rows, cols) of the sample grid holding *num_images* images.

    rows = floor(sqrt(n)), cols = ceil(sqrt(n)); asserts that the grid is
    filled exactly (kept as an assert to preserve the existing contract).
    """
    manifold_h = int(np.floor(np.sqrt(num_images)))
    manifold_w = int(np.ceil(np.sqrt(num_images)))
    assert manifold_h * manifold_w == num_images
    return manifold_h, manifold_w
re.search('"https://uploads(.*).wikiart.org/(.*).JPG', str(img) )\n", 50 | "\n", 51 | " raw = result.group()\n", 52 | " raw1 = result.group()\n", 53 | "\n", 54 | " clean = raw.split(\""\")\n", 55 | " clean1 = raw1.split(\""\")\n", 56 | "\n", 57 | " for i in clean or i in clean1:\n", 58 | " try:\n", 59 | " if len(i) < 16:\n", 60 | " continue\n", 61 | " else:\n", 62 | " #print(i)\n", 63 | " save_img = urllib.request.urlretrieve(i,hard+genre+\"_\"+name+\"_\"+str(num)+\".jpg\")\n", 64 | " num +=1\n", 65 | " except:\n", 66 | " continue\n", 67 | " except:\n", 68 | " continue\n", 69 | " \n", 70 | " \n", 71 | "print(\"That's All Folks!\")" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 13, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "cannot identify image file '/Volumes/sinanatra/Programming/Wiki_art_dataset//test/.DS_Store'\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "from PIL import Image\n", 89 | "import os, sys\n", 90 | "\n", 91 | "path = \"/Volumes/sinanatra/Programming/Wiki_art_dataset/\"+\"/test/\"\n", 92 | "\n", 93 | "dirs = os.listdir( path )\n", 94 | "\n", 95 | "\n", 96 | "\n", 97 | "for item in dirs:\n", 98 | " try:\n", 99 | " if os.path.isfile(path+item):\n", 100 | " im = Image.open(path+item)\n", 101 | " longer_side = max(im.size)\n", 102 | "\n", 103 | " horizontal_padding = (longer_side - im.size[0]) / 2\n", 104 | " vertical_padding = (longer_side - im.size[1]) / 2\n", 105 | " f, e = os.path.splitext(path+item)\n", 106 | " imResize = im.crop(\n", 107 | " (\n", 108 | " -horizontal_padding,\n", 109 | " -vertical_padding,\n", 110 | " im.size[0] + horizontal_padding,\n", 111 | " im.size[1] + vertical_padding\n", 112 | " )\n", 113 | " )\n", 114 | " imResize.save(f + ' resized.jpg', 'JPEG', quality=90)\n", 115 | " except Exception as e:\n", 116 | " print(e)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 10, 122 | "metadata": {}, 123 
| "outputs": [ 124 | { 125 | "ename": "TypeError", 126 | "evalue": "'NoneType' object is not callable", 127 | "output_type": "error", 128 | "traceback": [ 129 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 130 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 131 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 14\u001b[0m )\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mimg5\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"img5.jpg\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 132 | "\u001b[0;31mTypeError\u001b[0m: 'NoneType' object is not callable" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "from PIL import Image\n", 138 | "img = Image.open(\"/Volumes/sinanatra/Programming/Wiki_art_dataset/\"+\"/test/\"+\"4.jpg\")\n", 139 | "\n", 140 | "longer_side = max(img.size)\n", 141 | "horizontal_padding = (longer_side - img.size[0]) / 2\n", 142 | "vertical_padding = (longer_side - img.size[1]) / 2\n", 143 | "img5 = img.crop(\n", 144 | " (\n", 145 | " -horizontal_padding,\n", 146 | " -vertical_padding,\n", 147 | " img.size[0] + horizontal_padding,\n", 148 | " img.size[1] + vertical_padding\n", 149 | " )\n", 150 | ")\n", 151 | "img5.show()(\"img5.jpg\")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "anaconda-cloud": {}, 164 | "kernelspec": { 165 | "display_name": "python3env", 166 | "language": "python", 167 | "name": "python3env" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": 
"ipython3", 179 | "version": "3.6.4" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | --------------------------------------------------------------------------------