├── .DS_Store
├── res_noattn.gif
├── res_attention.gif
├── results
├── base.jpg
├── 0-step-0.jpg
├── 0-step-1.jpg
├── 0-step-2.jpg
├── 0-step-3.jpg
├── 0-step-4.jpg
├── 0-step-5.jpg
├── 0-step-6.jpg
├── 0-step-7.jpg
├── 0-step-8.jpg
├── 0-step-9.jpg
├── 800-step-0.jpg
├── 800-step-1.jpg
├── 800-step-2.jpg
├── 800-step-3.jpg
├── 800-step-4.jpg
├── 800-step-5.jpg
├── 800-step-6.jpg
├── 800-step-7.jpg
├── 800-step-8.jpg
├── 800-step-9.jpg
├── 1600-step-0.jpg
├── 1600-step-1.jpg
├── 1600-step-2.jpg
├── 1600-step-3.jpg
├── 1600-step-4.jpg
├── 1600-step-5.jpg
├── 1600-step-6.jpg
├── 1600-step-7.jpg
├── 1600-step-8.jpg
├── 1600-step-9.jpg
├── 2400-step-0.jpg
├── 2400-step-1.jpg
├── 2400-step-2.jpg
├── 2400-step-3.jpg
├── 2400-step-4.jpg
├── 2400-step-5.jpg
├── 2400-step-6.jpg
├── 2400-step-7.jpg
├── 2400-step-8.jpg
├── 2400-step-9.jpg
├── view-step-0.jpg
├── view-step-1.jpg
├── view-step-2.jpg
├── view-step-3.jpg
├── view-step-4.jpg
├── view-step-5.jpg
├── view-step-6.jpg
├── view-step-7.jpg
├── view-step-8.jpg
├── view-step-9.jpg
├── view-attn-step-0.jpg
├── view-attn-step-1.jpg
├── view-attn-step-2.jpg
├── view-attn-step-3.jpg
├── view-attn-step-4.jpg
├── view-attn-step-5.jpg
├── view-attn-step-6.jpg
├── view-attn-step-7.jpg
├── view-attn-step-8.jpg
├── view-attn-step-9.jpg
├── view-clean-step-0.jpg
├── view-clean-step-1.jpg
├── view-clean-step-2.jpg
├── view-clean-step-3.jpg
├── view-clean-step-4.jpg
├── view-clean-step-5.jpg
├── view-clean-step-6.jpg
├── view-clean-step-7.jpg
├── view-clean-step-8.jpg
└── view-clean-step-9.jpg
├── good-results
├── base.jpg
├── 100-step-0.jpg
├── 100-step-1.jpg
├── 100-step-2.jpg
├── 100-step-3.jpg
├── 100-step-4.jpg
├── 100-step-5.jpg
├── 100-step-6.jpg
├── 100-step-7.jpg
├── 100-step-8.jpg
├── 100-step-9.jpg
├── 200-step-0.jpg
├── 200-step-1.jpg
├── 200-step-2.jpg
├── 200-step-3.jpg
├── 200-step-4.jpg
├── 200-step-5.jpg
├── 200-step-6.jpg
├── 200-step-7.jpg
├── 200-step-8.jpg
├── 200-step-9.jpg
├── 300-step-0.jpg
├── 300-step-1.jpg
├── 300-step-2.jpg
├── 300-step-3.jpg
├── 300-step-4.jpg
├── 300-step-5.jpg
├── 300-step-6.jpg
├── 300-step-7.jpg
├── 300-step-8.jpg
├── 300-step-9.jpg
├── 400-step-0.jpg
├── 400-step-1.jpg
├── 400-step-2.jpg
├── 400-step-3.jpg
├── 400-step-4.jpg
├── 400-step-5.jpg
├── 400-step-6.jpg
├── 400-step-7.jpg
├── 400-step-8.jpg
├── 400-step-9.jpg
├── 500-step-0.jpg
├── 500-step-1.jpg
├── 500-step-2.jpg
├── 500-step-3.jpg
├── 500-step-4.jpg
├── 500-step-5.jpg
├── 500-step-6.jpg
├── 500-step-7.jpg
├── 500-step-8.jpg
├── 500-step-9.jpg
├── 600-step-0.jpg
├── 600-step-1.jpg
├── 3000-step-6.jpg
├── 3100-step-0.jpg
├── 3100-step-1.jpg
├── 3100-step-2.jpg
├── 3100-step-3.jpg
├── 3100-step-4.jpg
├── 3100-step-5.jpg
├── 3100-step-6.jpg
├── 3100-step-7.jpg
├── 3100-step-8.jpg
└── 3100-step-9.jpg
├── results-noattn
├── 14000-step-0.jpg
├── 14000-step-1.jpg
├── 14000-step-2.jpg
├── 14000-step-3.jpg
├── 14000-step-4.jpg
├── 14000-step-5.jpg
├── 14000-step-6.jpg
├── 14000-step-7.jpg
├── 14000-step-8.jpg
└── 14000-step-9.jpg
├── README.md
├── utils.py
├── plot_data.py
├── ops.py
├── input_data.py
├── main.py
└── color.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/.DS_Store
--------------------------------------------------------------------------------
/res_noattn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/res_noattn.gif
--------------------------------------------------------------------------------
/res_attention.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/res_attention.gif
--------------------------------------------------------------------------------
/results/base.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/base.jpg
--------------------------------------------------------------------------------
/good-results/base.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/base.jpg
--------------------------------------------------------------------------------
/results/0-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-0.jpg
--------------------------------------------------------------------------------
/results/0-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-1.jpg
--------------------------------------------------------------------------------
/results/0-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-2.jpg
--------------------------------------------------------------------------------
/results/0-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-3.jpg
--------------------------------------------------------------------------------
/results/0-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-4.jpg
--------------------------------------------------------------------------------
/results/0-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-5.jpg
--------------------------------------------------------------------------------
/results/0-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-6.jpg
--------------------------------------------------------------------------------
/results/0-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-7.jpg
--------------------------------------------------------------------------------
/results/0-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-8.jpg
--------------------------------------------------------------------------------
/results/0-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-9.jpg
--------------------------------------------------------------------------------
/results/800-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-0.jpg
--------------------------------------------------------------------------------
/results/800-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-1.jpg
--------------------------------------------------------------------------------
/results/800-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-2.jpg
--------------------------------------------------------------------------------
/results/800-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-3.jpg
--------------------------------------------------------------------------------
/results/800-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-4.jpg
--------------------------------------------------------------------------------
/results/800-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-5.jpg
--------------------------------------------------------------------------------
/results/800-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-6.jpg
--------------------------------------------------------------------------------
/results/800-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-7.jpg
--------------------------------------------------------------------------------
/results/800-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-8.jpg
--------------------------------------------------------------------------------
/results/800-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-9.jpg
--------------------------------------------------------------------------------
/results/1600-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-0.jpg
--------------------------------------------------------------------------------
/results/1600-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-1.jpg
--------------------------------------------------------------------------------
/results/1600-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-2.jpg
--------------------------------------------------------------------------------
/results/1600-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-3.jpg
--------------------------------------------------------------------------------
/results/1600-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-4.jpg
--------------------------------------------------------------------------------
/results/1600-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-5.jpg
--------------------------------------------------------------------------------
/results/1600-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-6.jpg
--------------------------------------------------------------------------------
/results/1600-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-7.jpg
--------------------------------------------------------------------------------
/results/1600-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-8.jpg
--------------------------------------------------------------------------------
/results/1600-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-9.jpg
--------------------------------------------------------------------------------
/results/2400-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-0.jpg
--------------------------------------------------------------------------------
/results/2400-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-1.jpg
--------------------------------------------------------------------------------
/results/2400-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-2.jpg
--------------------------------------------------------------------------------
/results/2400-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-3.jpg
--------------------------------------------------------------------------------
/results/2400-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-4.jpg
--------------------------------------------------------------------------------
/results/2400-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-5.jpg
--------------------------------------------------------------------------------
/results/2400-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-6.jpg
--------------------------------------------------------------------------------
/results/2400-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-7.jpg
--------------------------------------------------------------------------------
/results/2400-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-8.jpg
--------------------------------------------------------------------------------
/results/2400-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-9.jpg
--------------------------------------------------------------------------------
/results/view-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-0.jpg
--------------------------------------------------------------------------------
/results/view-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-1.jpg
--------------------------------------------------------------------------------
/results/view-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-2.jpg
--------------------------------------------------------------------------------
/results/view-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-3.jpg
--------------------------------------------------------------------------------
/results/view-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-4.jpg
--------------------------------------------------------------------------------
/results/view-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-5.jpg
--------------------------------------------------------------------------------
/results/view-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-6.jpg
--------------------------------------------------------------------------------
/results/view-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-7.jpg
--------------------------------------------------------------------------------
/results/view-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-8.jpg
--------------------------------------------------------------------------------
/results/view-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-9.jpg
--------------------------------------------------------------------------------
/good-results/100-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-0.jpg
--------------------------------------------------------------------------------
/good-results/100-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-1.jpg
--------------------------------------------------------------------------------
/good-results/100-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-2.jpg
--------------------------------------------------------------------------------
/good-results/100-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-3.jpg
--------------------------------------------------------------------------------
/good-results/100-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-4.jpg
--------------------------------------------------------------------------------
/good-results/100-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-5.jpg
--------------------------------------------------------------------------------
/good-results/100-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-6.jpg
--------------------------------------------------------------------------------
/good-results/100-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-7.jpg
--------------------------------------------------------------------------------
/good-results/100-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-8.jpg
--------------------------------------------------------------------------------
/good-results/100-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-9.jpg
--------------------------------------------------------------------------------
/good-results/200-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-0.jpg
--------------------------------------------------------------------------------
/good-results/200-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-1.jpg
--------------------------------------------------------------------------------
/good-results/200-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-2.jpg
--------------------------------------------------------------------------------
/good-results/200-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-3.jpg
--------------------------------------------------------------------------------
/good-results/200-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-4.jpg
--------------------------------------------------------------------------------
/good-results/200-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-5.jpg
--------------------------------------------------------------------------------
/good-results/200-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-6.jpg
--------------------------------------------------------------------------------
/good-results/200-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-7.jpg
--------------------------------------------------------------------------------
/good-results/200-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-8.jpg
--------------------------------------------------------------------------------
/good-results/200-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-9.jpg
--------------------------------------------------------------------------------
/good-results/300-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-0.jpg
--------------------------------------------------------------------------------
/good-results/300-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-1.jpg
--------------------------------------------------------------------------------
/good-results/300-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-2.jpg
--------------------------------------------------------------------------------
/good-results/300-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-3.jpg
--------------------------------------------------------------------------------
/good-results/300-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-4.jpg
--------------------------------------------------------------------------------
/good-results/300-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-5.jpg
--------------------------------------------------------------------------------
/good-results/300-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-6.jpg
--------------------------------------------------------------------------------
/good-results/300-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-7.jpg
--------------------------------------------------------------------------------
/good-results/300-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-8.jpg
--------------------------------------------------------------------------------
/good-results/300-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-9.jpg
--------------------------------------------------------------------------------
/good-results/400-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-0.jpg
--------------------------------------------------------------------------------
/good-results/400-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-1.jpg
--------------------------------------------------------------------------------
/good-results/400-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-2.jpg
--------------------------------------------------------------------------------
/good-results/400-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-3.jpg
--------------------------------------------------------------------------------
/good-results/400-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-4.jpg
--------------------------------------------------------------------------------
/good-results/400-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-5.jpg
--------------------------------------------------------------------------------
/good-results/400-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-6.jpg
--------------------------------------------------------------------------------
/good-results/400-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-7.jpg
--------------------------------------------------------------------------------
/good-results/400-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-8.jpg
--------------------------------------------------------------------------------
/good-results/400-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-9.jpg
--------------------------------------------------------------------------------
/good-results/500-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-0.jpg
--------------------------------------------------------------------------------
/good-results/500-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-1.jpg
--------------------------------------------------------------------------------
/good-results/500-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-2.jpg
--------------------------------------------------------------------------------
/good-results/500-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-3.jpg
--------------------------------------------------------------------------------
/good-results/500-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-4.jpg
--------------------------------------------------------------------------------
/good-results/500-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-5.jpg
--------------------------------------------------------------------------------
/good-results/500-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-6.jpg
--------------------------------------------------------------------------------
/good-results/500-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-7.jpg
--------------------------------------------------------------------------------
/good-results/500-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-8.jpg
--------------------------------------------------------------------------------
/good-results/500-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-9.jpg
--------------------------------------------------------------------------------
/good-results/600-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/600-step-0.jpg
--------------------------------------------------------------------------------
/good-results/600-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/600-step-1.jpg
--------------------------------------------------------------------------------
/good-results/3000-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3000-step-6.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-0.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-1.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-2.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-3.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-4.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-5.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-6.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-7.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-8.jpg
--------------------------------------------------------------------------------
/good-results/3100-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-9.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-0.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-1.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-2.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-3.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-4.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-5.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-6.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-7.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-8.jpg
--------------------------------------------------------------------------------
/results/view-attn-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-9.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-0.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-1.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-2.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-3.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-4.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-5.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-6.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-7.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-8.jpg
--------------------------------------------------------------------------------
/results/view-clean-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-9.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-0.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-1.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-2.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-3.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-4.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-5.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-6.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-7.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-8.jpg
--------------------------------------------------------------------------------
/results-noattn/14000-step-9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-9.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # draw-color
2 | A TensorFlow implementation of [DRAW](https://arxiv.org/abs/1502.04623), now with support for color images!
3 |
4 | This is code to accompany [my post on the DRAW model](http://kvfrans.com/what-is-draw-deep-recurrent-attentive-writer/).
5 |
6 | For an explanation of how I modified the original DRAW into a colored model, check out [my post on colorizing DRAW](http://kvfrans.com/colorizing-the-draw-model/).
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | The original (grayscale) DRAW model in `main.py` is a slightly rewritten, commented version of [ericjang's implementation](https://github.com/ericjang/draw), and runs on MNIST.
15 |
16 | Usage:
17 | ```
18 | python main.py
19 | ```
20 |
21 | The color DRAW model in `color.py` runs on the CelebA dataset.
22 |
23 | Usage:
24 | ```
25 | python color.py
26 | ```
27 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import numpy as np
3 | import random
4 | import tensorflow as tf
5 | import cPickle
6 |
7 |
def get_image(image_path, image_size, is_crop=True):
    """Load `image_path` as RGB and return it preprocessed by `transform`.

    Args:
        image_path: path of the image file to read.
        image_size: crop size handed to `transform` as `npx`.
        is_crop: whether `transform` should center-crop before rescaling.
    """
    raw = imread(image_path)
    return transform(raw, image_size, is_crop)
10 |
def transform(image, npx=64, is_crop=True):
    """Optionally center-crop `image`, then map pixel values from [0, 255] to [-1, 1].

    Args:
        image: image array with values in [0, 255].
        npx: side length (pixels) of the square center crop; only used when is_crop.
            Note center_crop then resizes the crop to its own default of 64 x 64.
        is_crop: when False the image is rescaled as-is.
    """
    source = center_crop(image, npx) if is_crop else image
    return np.array(source) / 127.5 - 1.
18 |
def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Cut a crop_h x crop_w window from the center of `x` and resize it.

    Args:
        x: image array; its first two axes are height and width.
        crop_h: height of the window.
        crop_w: width of the window; defaults to crop_h (square crop).
        resize_w: side length of the square output after resizing.
    """
    width = crop_h if crop_w is None else crop_w
    rows, cols = x.shape[:2]
    # Offsets of the window's top-left corner, rounded to the nearest pixel.
    top = int(round((rows - crop_h) / 2.))
    left = int(round((cols - width) / 2.))
    window = x[top:top + crop_h, left:left + width]
    return scipy.misc.imresize(window, [resize_w, resize_w])
27 |
def imread(path):
    """Read the image at `path` as an RGB array of float64 pixel values.

    Values are whatever the file stores (converted to float, not rescaled).
    """
    # `np.float` was only an alias of the builtin float (i.e. float64). It was
    # deprecated in NumPy 1.20 and removed in 1.24, so name the dtype explicitly.
    return scipy.misc.imread(path, mode="RGB").astype(np.float64)
31 |
def merge_color(images, size):
    """Tile a batch of RGB images into a single grid image.

    Args:
        images: array of shape (n, h, w, c) whose per-image slices broadcast
            into (h, w, 3); image k is placed at grid row k // size[1],
            column k % size[1].
        size: (rows, cols) of the grid; rows * cols must be >= n.

    Returns:
        float64 array of shape (h * rows, w * cols, 3), zero where no image was placed.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    h, w = images.shape[1], images.shape[2]
    if len(images) > size[0] * size[1]:
        raise ValueError("%d images do not fit in a %dx%d grid" % (len(images), size[0], size[1]))
    img = np.zeros((h * size[0], w * size[1], 3))

    for idx, image in enumerate(images):
        # divmod is floor division: a plain `/` yields a float under true
        # division (Python 3) and cannot be used as a slice bound.
        row, col = divmod(idx, size[1])
        img[row * h:row * h + h, col * w:col * w + w, :] = image

    return img
42 |
def unpickle(file):
    """Load and return the single object pickled in `file`.

    SECURITY: unpickling can execute arbitrary code; only call this on trusted files.
    """
    try:
        import cPickle as pickle  # Python 2: C-accelerated implementation
    except ImportError:
        import pickle  # Python 3: the stdlib pickle is already C-accelerated
    # `with` closes the file even if load() raises; the local name avoids
    # shadowing the builtin `dict`.
    with open(file, 'rb') as fo:
        return pickle.load(fo)
48 |
def ims(name, img):
    """Save the array `img` as an image file at path `name`.

    Values are scaled with cmin=0 and cmax=1, i.e. [0, 1] spans the output intensity range.
    """
    picture = scipy.misc.toimage(img, cmin=0, cmax=1)
    picture.save(name)
52 |
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + exp(-x)) for scalars or numpy arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
55 |
--------------------------------------------------------------------------------
/plot_data.py:
--------------------------------------------------------------------------------
1 | # takes data saved by DRAW model and generates animations
2 | # example usage: python plot_data.py noattn /tmp/draw/draw_data.npy
3 |
4 | import matplotlib
5 | import sys
6 | import numpy as np
7 |
# True: show every frame side by side in one matplotlib window.
# False: write one image file per timestep instead.
interactive = False

if not interactive:
    # Force matplotlib to not use any Xwindows backend; the backend is chosen
    # here, before pyplot is imported.
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
13 |
14 |
def xrecons_grid(X, B, A):
    """
    Lay out one timestep's canvases as a square grid of padded tiles.

    X is x_recons, (batch_size x img_size); each row is reshaped to a B x A image.
    batch_size is assumed to be a perfect square N*N.
    Every tile is surrounded by a 1-pixel border of value 0.5.
    Returns a (N*(B+2)) x (N*(A+2)) float array.
    """
    pad = 1
    cell_h = B + 2 * pad
    cell_w = A + 2 * pad
    n = int(np.sqrt(X.shape[0]))
    tiles = X.reshape((n, n, B, A))
    grid = np.full((n * cell_h, n * cell_w), .5)
    for r in range(n):
        top = r * cell_h + pad
        for c in range(n):
            left = c * cell_w + pad
            grid[top:top + B, left:left + A] = tiles[r, c, :, :]
    return grid
38 |
if __name__ == '__main__':
    # usage: python plot_data.py <prefix> <draw_data.npy>
    if len(sys.argv) != 3:
        sys.exit('usage: python plot_data.py <prefix> <draw_data.npy>')
    prefix = sys.argv[1]
    out_file = sys.argv[2]
    C = np.load(out_file)  # canvases, unpacked below as (T, batch_size, img_size)
    T, batch_size, img_size = C.shape
    X = 1.0 / (1.0 + np.exp(-C))  # x_recons=sigmoid(canvas)
    B = A = int(np.sqrt(img_size))
    if interactive:
        _fig, arr = plt.subplots(1, T)
    for t in range(T):
        img = xrecons_grid(X[t, :, :], B, A)
        if interactive:
            arr[t].matshow(img, cmap=plt.cm.gray)
            arr[t].set_xticks([])
            arr[t].set_yticks([])
        else:
            plt.matshow(img, cmap=plt.cm.gray)
            imgname = '%s_%d.png' % (prefix, t)  # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif
            plt.savefig(imgname)
            # matshow opens a new figure per frame; close it so figures do not pile up.
            plt.close()
            print(imgname)
    if interactive:
        # Without this the interactive figure is built but never displayed.
        plt.show()
59 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
class batch_norm(object):
    """Batch normalization with moving averages for evaluation.

    Code modification of http://stackoverflow.com/a/33950177

    Statistics are taken over axes 0, 1 and 2, with one beta/gamma per entry of
    the last axis (presumably a 4-D NHWC activation -- confirm at call sites).
    A train=True call creates beta/gamma and the moving-average ops; a later
    train=False call reuses them.
    """

    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum

            # decay of the moving averages used at evaluation time
            self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
            self.name = name

    def __call__(self, x, train=True):
        shape = x.get_shape().as_list()

        if train:
            with tf.variable_scope(self.name):
                self.beta = tf.get_variable("beta", [shape[-1]],
                                            initializer=tf.constant_initializer(0.))
                self.gamma = tf.get_variable("gamma", [shape[-1]],
                                             initializer=tf.random_normal_initializer(1., 0.02))

                batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
                ema_apply_op = self.ema.apply([batch_mean, batch_var])
                self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var)

                # Make the returned statistics depend on the moving-average update,
                # so every training step also refreshes the averages.
                with tf.control_dependencies([ema_apply_op]):
                    mean, var = tf.identity(batch_mean), tf.identity(batch_var)
        else:
            if not hasattr(self, "ema_mean"):
                # Evaluation reuses beta/gamma and the averages created by a
                # train=True call; fail clearly instead of with an AttributeError.
                raise RuntimeError(
                    "batch_norm '%s' must be called with train=True before train=False" % self.name)
            mean, var = self.ema_mean, self.ema_var

        normed = tf.nn.batch_norm_with_global_normalization(
            x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True)

        return normed
35 |
36 | return normed
37 |
# standard convolution layer
def conv2d(x, inputFeatures, outputFeatures, name):
    """5x5 convolution, stride 2, SAME padding, plus a learned bias.

    Variables "w" (init N(0, 0.02) truncated) and "b" (zeros) live under scope `name`.
    """
    with tf.variable_scope(name):
        kernel = tf.get_variable("w", [5, 5, inputFeatures, outputFeatures],
                                 initializer=tf.truncated_normal_initializer(stddev=0.02))
        bias = tf.get_variable("b", [outputFeatures], initializer=tf.constant_initializer(0.0))
        return tf.nn.conv2d(x, kernel, strides=[1, 2, 2, 1], padding="SAME") + bias
45 |
def conv_transpose(x, outputShape, name):
    """5x5 transposed convolution, stride 2, producing `outputShape`, plus a learned bias.

    Args:
        x: input tensor whose last dimension is the input channel count.
        outputShape: full output shape list; its last entry is the output channel count.
        name: variable scope holding "w" and "b".
    """
    with tf.variable_scope(name):
        # kernel layout for conv2d_transpose: [height, width, out_channels, in_channels]
        w = tf.get_variable("w", [5, 5, outputShape[-1], x.get_shape()[-1]],
                            initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
        convt = tf.nn.conv2d_transpose(x, w, output_shape=outputShape, strides=[1, 2, 2, 1])
        # `b` was previously created but never applied; add it, as conv2d does.
        return convt + b
53 |
# leaky ReLU unit
def lrelu(x, leak=0.2, name="lrelu"):
    """Branch-free leaky ReLU: x for x >= 0, leak * x for x < 0."""
    with tf.variable_scope(name):
        # x >= 0: (a + b) * x = x;  x < 0: (a - b) * x = leak * x
        a = 0.5 * (1 + leak)
        b = 0.5 * (1 - leak)
        return a * x + b * abs(x)
60 |
# fully-connected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
    """Affine layer x @ Matrix + bias built under `scope` (default "Linear").

    Returns the output tensor, or (output, Matrix, bias) when with_w is True.
    """
    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32,
                                  tf.random_normal_initializer(stddev=0.02))
        offset = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0))
        result = tf.matmul(x, weights) + offset
        if not with_w:
            return result
        return result, weights, offset
70 |
def merge(images, size):
    """Tile a batch of single-channel images into one grid image.

    Args:
        images: array of shape (n, h, w); image k goes to grid row k // size[1],
            column k % size[1].
        size: (rows, cols) of the grid; rows * cols must be >= n.

    Returns:
        float64 array of shape (h * rows, w * cols), zero where no image was placed.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    h, w = images.shape[1], images.shape[2]
    if len(images) > size[0] * size[1]:
        raise ValueError("%d images do not fit in a %dx%d grid" % (len(images), size[0], size[1]))
    img = np.zeros((h * size[0], w * size[1]))

    for idx, image in enumerate(images):
        # divmod is floor division: a plain `/` yields a float under true
        # division (Python 3) and cannot be used as a slice bound.
        row, col = divmod(idx, size[1])
        img[row * h:row * h + h, col * w:col * w + w] = image

    return img
81 |
--------------------------------------------------------------------------------
/input_data.py:
--------------------------------------------------------------------------------
1 | """Functions for downloading and reading MNIST data."""
2 | import gzip
3 | import os
4 | import urllib
5 | import numpy
6 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
7 |
8 |
def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here.

    Returns the local path os.path.join(work_directory, filename). The directory,
    including missing parents, is created if needed. The download is written to
    a temporary ".part" file and renamed only once it completes, so an
    interrupted download is never mistaken for a finished file on the next run.
    """
    if not os.path.exists(work_directory):
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        try:
            from urllib.request import urlretrieve  # Python 3
        except ImportError:
            from urllib import urlretrieve  # Python 2
        partial_path = filepath + '.part'
        urlretrieve(SOURCE_URL + filename, partial_path)
        os.rename(partial_path, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded %s %d bytes.' % (filename, statinfo.st_size))
    return filepath
19 |
20 |
21 | def _read32(bytestream):
22 | dt = numpy.dtype(numpy.uint32).newbyteorder('>')
23 | return numpy.frombuffer(bytestream.read(4), dtype=dt)
24 |
25 |
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Reads a gzipped IDX image file: magic 2051, then image count, rows and
    columns, then the raw pixel bytes.

    Raises:
        ValueError: if the magic number is not 2051.
    """
    print('Extracting %s' % filename)
    with gzip.open(filename) as stream:
        magic = _read32(stream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, filename))
        count, height, width = _read32(stream), _read32(stream), _read32(stream)
        pixels = numpy.frombuffer(stream.read(height * width * count), dtype=numpy.uint8)
        return pixels.reshape(count, height, width, 1)
42 |
43 |
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors.

    Returns a float64 array of shape (len(labels_dense), num_classes).
    """
    n = labels_dense.shape[0]
    one_hot = numpy.zeros((n, num_classes))
    # Row r begins at flat position r * num_classes; adding the label selects its column.
    row_starts = numpy.arange(n) * num_classes
    one_hot.flat[row_starts + labels_dense.ravel()] = 1
    return one_hot
51 |
52 |
def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index].

    Reads a gzipped IDX label file (magic 2049, item count, label bytes).
    With one_hot=True the labels are returned via dense_to_one_hot instead.

    Raises:
        ValueError: if the magic number is not 2049.
    """
    print('Extracting %s' % filename)
    with gzip.open(filename) as stream:
        magic = _read32(stream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, filename))
        count = _read32(stream)
        labels = numpy.frombuffer(stream.read(count), dtype=numpy.uint8)
    return dense_to_one_hot(labels) if one_hot else labels
68 |
69 |
class DataSet(object):
    """In-memory MNIST split that serves mini-batches and reshuffles between epochs.

    Images are given as [num examples, rows, columns, 1] and stored flattened to
    [num examples, rows * columns] as float32 scaled from [0, 255] to [0.0, 1.0].
    With fake_data=True nothing is converted and the split reports 10000 examples.
    """

    def __init__(self, images, labels, fake_data=False):
        if not fake_data:
            count = images.shape[0]
            assert count == labels.shape[0], (
                "images.shape: %s labels.shape: %s" % (images.shape, labels.shape))
            self._num_examples = count
            # Flattening one row per image assumes a single channel (depth == 1).
            assert images.shape[3] == 1
            flat = images.reshape(count, images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = numpy.multiply(flat.astype(numpy.float32), 1.0 / 255.0)
        else:
            self._num_examples = 10000
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        When the current epoch cannot supply a full batch, the leftover examples
        are skipped, images and labels are shuffled together, and the batch is
        taken from the start of the new epoch.
        """
        if fake_data:
            # One all-ones 784-pixel image, repeated (same list object) batch_size times.
            fake_image = [1.0] * 784
            return [fake_image] * batch_size, [0] * batch_size
        start = self._index_in_epoch
        self._index_in_epoch = start + batch_size
        if self._index_in_epoch > self._num_examples:
            self._epochs_completed += 1
            order = numpy.arange(self._num_examples)
            numpy.random.shuffle(order)
            self._images = self._images[order]
            self._labels = self._labels[order]
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        stop = self._index_in_epoch
        return self._images[start:stop], self._labels[start:stop]
131 |
132 |
def read_data_sets(train_dir, fake_data=False, one_hot=False):
    """Download (if needed) and load MNIST as train/validation/test DataSet splits.

    The first 5000 training examples become the validation split. With
    fake_data=True no files are touched and all three splits are fake DataSets.
    """
    class DataSets(object):
        pass

    data_sets = DataSets()
    if fake_data:
        for split in ('train', 'validation', 'test'):
            setattr(data_sets, split, DataSet([], [], fake_data=True))
        return data_sets

    validation_size = 5000
    train_images = extract_images(maybe_download('train-images-idx3-ubyte.gz', train_dir))
    train_labels = extract_labels(maybe_download('train-labels-idx1-ubyte.gz', train_dir), one_hot=one_hot)
    test_images = extract_images(maybe_download('t10k-images-idx3-ubyte.gz', train_dir))
    test_labels = extract_labels(maybe_download('t10k-labels-idx1-ubyte.gz', train_dir), one_hot=one_hot)

    data_sets.train = DataSet(train_images[validation_size:], train_labels[validation_size:])
    data_sets.validation = DataSet(train_images[:validation_size], train_labels[:validation_size])
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
163 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from ops import *
4 | from utils import *
5 | import input_data
6 | # from scipy.misc import imsave as ims
7 |
8 |
9 | class Draw():
10 | def __init__(self):
11 | self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
12 | self.n_samples = self.mnist.train.num_examples
13 |
14 | self.img_size = 28
15 | self.attention_n = 5
16 | self.n_hidden = 256
17 | self.n_z = 10
18 | self.sequence_length = 10
19 | self.batch_size = 64
20 | self.share_parameters = False
21 |
22 | self.images = tf.placeholder(tf.float32, [None, 784])
23 | self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1) # Qsampler noise
24 | self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # encoder Op
25 | self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # decoder Op
26 |
27 | self.cs = [0] * self.sequence_length
28 | self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length
29 |
30 | h_dec_prev = tf.zeros((self.batch_size, self.n_hidden))
31 | enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32)
32 | dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32)
33 |
34 | x = self.images
35 | for t in range(self.sequence_length):
36 | # error image + original image
37 | c_prev = tf.zeros((self.batch_size, self.img_size**2)) if t == 0 else self.cs[t-1]
38 | x_hat = x - tf.sigmoid(c_prev)
39 | # read the image
40 | r = self.read_basic(x,x_hat,h_dec_prev)
41 | print r.get_shape()
42 | # r = self.read_attention(x,x_hat,h_dec_prev)
43 | # encode it to guass distrib
44 | self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat(1, [r, h_dec_prev]))
45 | # sample from the distrib to get z
46 | z = self.sampleQ(self.mu[t],self.sigma[t])
47 | print z.get_shape()
48 | # retrieve the hidden layer of RNN
49 | h_dec, dec_state = self.decode_layer(dec_state, z)
50 |
51 | print h_dec.get_shape()
52 |
53 | # map from hidden layer -> image portion, and then write it.
54 | self.cs[t] = c_prev + self.write_basic(h_dec)
55 | # self.cs[t] = c_prev + self.write_attention(h_dec)
56 | h_dec_prev = h_dec
57 | self.share_parameters = True # from now on, share variables
58 |
59 | # the final timestep
60 | self.generated_images = tf.nn.sigmoid(self.cs[-1])
61 |
62 | self.generation_loss = tf.reduce_mean(-tf.reduce_sum(self.images * tf.log(1e-10 + self.generated_images) + (1-self.images) * tf.log(1e-10 + 1 - self.generated_images),1))
63 |
64 | kl_terms = [0]*self.sequence_length
65 | for t in xrange(self.sequence_length):
66 | mu2 = tf.square(self.mu[t])
67 | sigma2 = tf.square(self.sigma[t])
68 | logsigma = self.logsigma[t]
69 | kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2*logsigma, 1) - self.sequence_length*0.5
70 | self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms))
71 | self.cost = self.generation_loss + self.latent_loss
72 | optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5)
73 | grads = optimizer.compute_gradients(self.cost)
74 | for i,(g,v) in enumerate(grads):
75 | if g is not None:
76 | grads[i] = (tf.clip_by_norm(g,5),v)
77 | self.train_op = optimizer.apply_gradients(grads)
78 |
79 | self.sess = tf.Session()
80 | self.sess.run(tf.initialize_all_variables())
81 |
82 | def train(self):
83 | for i in xrange(15000):
84 | xtrain, _ = self.mnist.train.next_batch(self.batch_size)
85 | cs, gen_loss, lat_loss, _ = self.sess.run([self.cs, self.generation_loss, self.latent_loss, self.train_op], feed_dict={self.images: xtrain})
86 | print "iter %d genloss %f latloss %f" % (i, gen_loss, lat_loss)
87 | if i % 500 == 0:
88 |
89 | cs = 1.0/(1.0+np.exp(-np.array(cs))) # x_recons=sigmoid(canvas)
90 |
91 | for cs_iter in xrange(10):
92 | results = cs[cs_iter]
93 | results_square = np.reshape(results, [-1, 28, 28])
94 | print results_square.shape
95 | ims("results/"+str(i)+"-step-"+str(cs_iter)+".jpg",merge(results_square,[8,8]))
96 |
97 |
98 | # given a hidden decoder layer:
99 | # locate where to put attention filters
100 | def attn_window(self, scope, h_dec):
101 | with tf.variable_scope(scope, reuse=self.share_parameters):
102 | parameters = dense(h_dec, self.n_hidden, 5)
103 | # gx_, gy_: center of 2d gaussian on a scale of -1 to 1
104 | gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1,5,parameters)
105 |
106 | # move gx/gy to be a scale of -imgsize to +imgsize
107 | gx = (self.img_size+1)/2 * (gx_ + 1)
108 | gy = (self.img_size+1)/2 * (gy_ + 1)
109 |
110 | sigma2 = tf.exp(log_sigma2)
111 | # stride/delta: how far apart these patches will be
112 | delta = (self.img_size - 1) / ((self.attention_n-1) * tf.exp(log_delta))
113 | # returns [Fx, Fy, gamma]
114 | return self.filterbank(gx,gy,sigma2,delta) + (tf.exp(log_gamma),)
115 |
116 | # Given a center, distance, and spread
117 | # Construct [attention_n x attention_n] patches of gaussian filters
118 | # represented by Fx = horizontal gaussian, Fy = vertical guassian
119 | def filterbank(self, gx, gy, sigma2, delta):
120 | # 1 x N, look like [[0,1,2,3,4]]
121 | grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32),[1, -1])
122 | # centers for the individual patches
123 | mu_x = gx + (grid_i - self.attention_n/2 - 0.5) * delta
124 | mu_y = gy + (grid_i - self.attention_n/2 - 0.5) * delta
125 | mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1])
126 | mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1])
127 | # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]]
128 | im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1])
129 | # list of gaussian curves for x and y
130 | sigma2 = tf.reshape(sigma2, [-1, 1, 1])
131 | Fx = tf.exp(-tf.square((im - mu_x) / (2*sigma2)))
132 | Fy = tf.exp(-tf.square((im - mu_x) / (2*sigma2)))
133 | # normalize so area-under-curve = 1
134 | Fx = Fx / tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),1e-8)
135 | Fy = Fy / tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),1e-8)
136 | return Fx, Fy
137 |
138 |
139 | # the read() operation without attention
def read_basic(self, x, x_hat, h_dec_prev):
    """Attention-free read: the image and the error image concatenated along axis 1.

    h_dec_prev is accepted only to match read_attention's signature; it is unused.
    """
    both = [x, x_hat]
    return tf.concat(1, both)
142 |
def read_attention(self, x, x_hat, h_dec_prev):
    """Read an attention_n x attention_n glimpse from both x and the error image x_hat.

    The window is predicted from the previous decoder state. Each glimpse is
    Fy . img . Fx^T per example, flattened and scaled by gamma; the two glimpses
    are concatenated along axis 1.
    """
    Fx, Fy, gamma = self.attn_window("read", h_dec_prev)
    # Fx is [batch, N, img_size] (see filterbank); its transpose is shared by both glimpses.
    Fxt = tf.transpose(Fx, perm=[0, 2, 1])

    def filter_img(img):
        img = tf.reshape(img, [-1, self.img_size, self.img_size])
        # [batch, N, img] x ([batch, img, img] x [batch, img, N]) -> [batch, N, N]
        glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt))
        glimpse = tf.reshape(glimpse, [-1, self.attention_n ** 2])
        # scale the glimpse by the predicted intensity gamma
        return glimpse * tf.reshape(gamma, [-1, 1])

    return tf.concat(1, [filter_img(x), filter_img(x_hat)])
162 |
163 | # encode an attention patch
def encode(self, prev_state, image):
    """Advance the encoder LSTM on `image` and project its output to a diagonal Gaussian.

    Returns:
        (mu, logsigma, sigma, next_state), where sigma = exp(logsigma).
    """
    reuse = self.share_parameters
    with tf.variable_scope("encoder", reuse=reuse):
        enc_out, new_state = self.lstm_enc(image, prev_state)

    with tf.variable_scope("mu", reuse=reuse):
        mean = dense(enc_out, self.n_hidden, self.n_z)
    with tf.variable_scope("sigma", reuse=reuse):
        log_std = dense(enc_out, self.n_hidden, self.n_z)
    return mean, log_std, tf.exp(log_std), new_state
176 |
177 |
def sampleQ(self, mu, sigma):
    """Sample z ~ N(mu, sigma^2) with the reparameterisation trick.

    A fresh standard-normal tensor is created on every call, so each DRAW time step
    gets independent noise. Reusing one pre-drawn tensor (self.e) for every step
    would give all steps the same epsilon and correlate the latents.
    """
    epsilon = tf.random_normal(tf.shape(mu), mean=0, stddev=1)
    return mu + sigma * epsilon
180 |
def decode_layer(self, prev_state, latent):
    """Advance the decoder LSTM one step on `latent`; return (hidden_layer, next_state)."""
    with tf.variable_scope("decoder", reuse=self.share_parameters):
        return self.lstm_dec(latent, prev_state)
187 |
def write_basic(self, hidden_layer):
    """Attention-free write: map the decoder output to a full img_size**2 canvas update."""
    n_pixels = self.img_size ** 2
    with tf.variable_scope("write", reuse=self.share_parameters):
        return dense(hidden_layer, self.n_hidden, n_pixels)
193 |
def write_attention(self, hidden_layer):
    """Write a predicted attention_n x attention_n patch onto the full canvas.

    The patch w is spread over the image as Fy^T . w . Fx using the "write" window,
    flattened to [batch_size, img_size**2] and scaled by 1 / gamma.
    """
    n = self.attention_n
    with tf.variable_scope("writeW", reuse=self.share_parameters):
        patch = dense(hidden_layer, self.n_hidden, n ** 2)
    patch = tf.reshape(patch, [self.batch_size, n, n])

    Fx, Fy, gamma = self.attn_window("write", hidden_layer)
    # [batch, img, N] x ([batch, N, N] x [batch, N, img]) -> [batch, img, img]
    canvas = tf.batch_matmul(tf.transpose(Fy, perm=[0, 2, 1]), tf.batch_matmul(patch, Fx))
    canvas = tf.reshape(canvas, [self.batch_size, self.img_size ** 2])
    return canvas * tf.reshape(1.0 / gamma, [-1, 1])
204 |
205 |
206 |
# Build and train the model only when run as a script, not when imported.
if __name__ == "__main__":
    model = Draw()
    model.train()
209 |
--------------------------------------------------------------------------------
/color.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from ops import *
4 | from utils import *
5 | from glob import glob
6 | import os
7 |
class Draw():
    """DRAW (Deep Recurrent Attentive Writer) for RGB images, built with the TF 0.x graph API.

    __init__ unrolls the whole read/encode/sample/decode/write loop for
    sequence_length steps, defines the losses and optimiser, and opens a session.
    train() fits the model on the *.jpg files in data_dir; view() restores the latest
    checkpoint and renders the canvas after every time step.
    """

    # Folder scanned for *.jpg training / preview images.
    data_dir = "../Datasets/celebA"

    def __init__(self):
        self.img_size = 64
        self.num_colors = 3

        self.attention_n = 5
        self.n_hidden = 256
        self.n_z = 10
        self.sequence_length = 10
        self.batch_size = 64
        self.share_parameters = False

        n_pixels = self.img_size * self.img_size * self.num_colors

        self.images = tf.placeholder(tf.float32, [None, self.img_size, self.img_size, self.num_colors])

        # Kept for backward compatibility only: sampleQ now draws fresh noise per step.
        self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1)

        self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # encoder Op
        self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # decoder Op

        # Per-step canvases and latent distribution parameters, filled in below.
        self.cs = [0] * self.sequence_length
        self.mu = [0] * self.sequence_length
        self.logsigma = [0] * self.sequence_length
        self.sigma = [0] * self.sequence_length
        # [gx, gy, delta] of every attention window, appended by attn_window in call
        # order: read window of step 0, write window of step 0, read window of step 1, ...
        self.attn_params = []

        h_dec_prev = tf.zeros((self.batch_size, self.n_hidden))
        enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32)
        dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32)

        x = tf.reshape(self.images, [-1, n_pixels])
        for t in range(self.sequence_length):
            c_prev = tf.zeros((self.batch_size, n_pixels)) if t == 0 else self.cs[t - 1]
            # error image: the part of x the current canvas does not yet explain
            x_hat = x - tf.sigmoid(c_prev)
            r = self.read_attention(x, x_hat, h_dec_prev)
            self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(
                enc_state, tf.concat(1, [r, h_dec_prev]))
            z = self.sampleQ(self.mu[t], self.sigma[t])
            h_dec, dec_state = self.decode_layer(dec_state, z)
            self.cs[t] = c_prev + self.write_attention(h_dec)
            h_dec_prev = h_dec
            self.share_parameters = True  # later steps reuse the variables created at t == 0

        # the final timestep
        self.generated_images = tf.nn.sigmoid(self.cs[-1])

        # Squared reconstruction error per example, averaged over the batch so it is on
        # the same per-example scale as latent_loss (tf.nn.l2_loss summed over the batch).
        self.generation_loss = tf.reduce_mean(
            0.5 * tf.reduce_sum(tf.square(x - self.generated_images), 1))

        # KL(N(mu, sigma^2) || N(0, 1)) per step, summed over latent dims:
        # 0.5 * sum(mu^2 + sigma^2 - 2 log sigma) - n_z / 2. The constant depends on
        # n_z, not sequence_length (they only coincided because both are 10).
        kl_terms = [
            0.5 * tf.reduce_sum(tf.square(self.mu[t]) + tf.square(self.sigma[t]) - 2 * self.logsigma[t], 1)
            - 0.5 * self.n_z
            for t in range(self.sequence_length)
        ]
        self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms))
        self.cost = self.generation_loss + self.latent_loss

        optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        grads = optimizer.compute_gradients(self.cost)
        # Clip each gradient to norm 5; variables without a gradient pass through untouched.
        grads = [(tf.clip_by_norm(g, 5), v) if g is not None else (g, v) for g, v in grads]
        self.train_op = optimizer.apply_gradients(grads)

        self.sess = tf.Session()
        self.sess.run(tf.initialize_all_variables())

    def attn_window(self, scope, h_dec):
        """Predict one attention window from decoder state h_dec and build its filterbank.

        A dense layer (variables under `scope`) maps h_dec to five values: the centre
        (gx_, gy_) on a [-1, 1] scale, log variance, log stride and log intensity.
        [gx, gy, delta] is also appended to self.attn_params for view().

        Returns:
            (Fx, Fy, gamma): the filterbanks from filterbank() and gamma = exp(log_gamma).
        """
        with tf.variable_scope(scope, reuse=self.share_parameters):
            parameters = dense(h_dec, self.n_hidden, 5)
            gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1, 5, parameters)

            # Map the centre from [-1, 1] to pixel coordinates [0, img_size + 1]. The float
            # divisor matters: Python 2 integer division would turn 32.5 into 32.
            half_extent = (self.img_size + 1) / 2.0
            gx = half_extent * (gx_ + 1)
            gy = half_extent * (gy_ + 1)

            sigma2 = tf.exp(log_sigma2)
            # Stride between neighbouring filter centres. Dividing by exp(log_delta)
            # (rather than multiplying) only flips the sign of the learned log_delta.
            delta = (self.img_size - 1) / ((self.attention_n - 1) * tf.exp(log_delta))

            self.attn_params.append([gx, gy, delta])

            return self.filterbank(gx, gy, sigma2, delta) + (tf.exp(log_gamma),)

    def filterbank(self, gx, gy, sigma2, delta):
        """Build the horizontal (Fx) and vertical (Fy) Gaussian filterbanks of one window.

        gx, gy and delta are [batch, 1] columns (split out in attn_window). Returns
        (Fx, Fy), each [batch, attention_n, img_size]; every row is a Gaussian over pixel
        positions, normalised to sum to 1.
        """
        n = self.attention_n
        # Offsets symmetric about the centre: i - (N - 1) / 2 -> [-2, -1, 0, 1, 2] for N = 5.
        # The previous `i - N/2 - 0.5` gave [-2.5 .. 1.5] under Python 2 (off-centre).
        grid_i = tf.reshape(tf.cast(tf.range(n), tf.float32), [1, -1])
        offsets = grid_i - (n - 1) / 2.0
        # Filter centres along each axis, [batch, N, 1].
        mu_x = tf.reshape(gx + offsets * delta, [-1, n, 1])
        mu_y = tf.reshape(gy + offsets * delta, [-1, n, 1])
        # Pixel positions 0..img_size-1 as [1, 1, img_size].
        pixels = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1])
        sigma2 = tf.reshape(sigma2, [-1, 1, 1])
        # Gaussian exp(-(a - mu)^2 / (2 * sigma^2)); the vertical bank must use mu_y.
        Fx = tf.exp(-tf.square(pixels - mu_x) / (2 * sigma2))
        Fy = tf.exp(-tf.square(pixels - mu_y) / (2 * sigma2))
        # Normalise each filter to sum to 1, guarding against an all-zero row.
        Fx = Fx / tf.maximum(tf.reduce_sum(Fx, 2, keep_dims=True), 1e-8)
        Fy = Fy / tf.maximum(tf.reduce_sum(Fy, 2, keep_dims=True), 1e-8)
        return Fx, Fy

    def read_basic(self, x, x_hat, h_dec_prev):
        """Attention-free read: x and x_hat concatenated along axis 1 (h_dec_prev unused)."""
        return tf.concat(1, [x, x_hat])

    def read_attention(self, x, x_hat, h_dec_prev):
        """Read an N x N glimpse per colour channel from x and from the error image x_hat.

        Returns [batch_size, 2 * N * N * num_colors]: the two gamma-scaled glimpses,
        each laid out as [N, N, num_colors] per example, concatenated along axis 1.
        """
        n, colors, batch = self.attention_n, self.num_colors, self.batch_size
        Fx, Fy, gamma = self.attn_window("read", h_dec_prev)
        # The colour channels are stacked channel-major along the batch axis (rows
        # k*batch .. (k+1)*batch-1 hold channel k), so the filters are tiled the same way.
        Fx_array = tf.concat(0, [Fx] * colors)
        Fy_array = tf.concat(0, [Fy] * colors)
        Fxt = tf.transpose(Fx_array, perm=[0, 2, 1])

        def filter_img(img):
            img = tf.reshape(img, [-1, self.img_size, self.img_size, colors])
            # [batch, H, W, C] -> [C, batch, H, W] -> [C * batch, H, W]
            channels = tf.reshape(tf.transpose(img, perm=[3, 0, 1, 2]),
                                  [colors * batch, self.img_size, self.img_size])
            # [C*batch, N, H] x ([C*batch, H, W] x [C*batch, W, N]) -> [C*batch, N, N]
            glimpse = tf.batch_matmul(Fy_array, tf.batch_matmul(channels, Fxt))
            glimpse = tf.reshape(glimpse, [colors, batch, n, n])
            glimpse = tf.transpose(glimpse, [1, 2, 3, 0])
            glimpse = tf.reshape(glimpse, [batch, n * n * colors])
            # scale the glimpse by the predicted intensity gamma
            return glimpse * tf.reshape(gamma, [-1, 1])

        return tf.concat(1, [filter_img(x), filter_img(x_hat)])

    def encode(self, prev_state, image):
        """Advance the encoder LSTM and map its output to (mu, logsigma, sigma, next_state)."""
        with tf.variable_scope("encoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_enc(image, prev_state)

        # map the RNN hidden state to latent variables
        with tf.variable_scope("mu", reuse=self.share_parameters):
            mu = dense(hidden_layer, self.n_hidden, self.n_z)
        with tf.variable_scope("sigma", reuse=self.share_parameters):
            logsigma = dense(hidden_layer, self.n_hidden, self.n_z)
            sigma = tf.exp(logsigma)
        return mu, logsigma, sigma, next_state

    def sampleQ(self, mu, sigma):
        """Sample z ~ N(mu, sigma^2) with fresh standard-normal noise on every call.

        Each time step gets independent noise; reusing the single self.e tensor for
        every step (as before) correlated the latents across steps.
        """
        epsilon = tf.random_normal(tf.shape(mu), mean=0, stddev=1)
        return mu + sigma * epsilon

    def decode_layer(self, prev_state, latent):
        """Advance the decoder LSTM one step on `latent`; return (hidden_layer, next_state)."""
        with tf.variable_scope("decoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_dec(latent, prev_state)
        return hidden_layer, next_state

    def write_basic(self, hidden_layer):
        """Attention-free write: map the decoder output to a full flattened RGB canvas update."""
        with tf.variable_scope("write", reuse=self.share_parameters):
            decoded_image_portion = dense(hidden_layer, self.n_hidden,
                                          self.img_size * self.img_size * self.num_colors)
        return decoded_image_portion

    def write_attention(self, hidden_layer):
        """Write a predicted N x N x num_colors patch onto the full canvas.

        Each channel of the patch w is spread as Fy^T . w . Fx using the "write" window.
        The result is flattened to [batch_size, img_size * img_size * num_colors]
        (channel last) and scaled by 1 / gamma.
        """
        n, colors, batch = self.attention_n, self.num_colors, self.batch_size
        with tf.variable_scope("writeW", reuse=self.share_parameters):
            w = dense(hidden_layer, self.n_hidden, n * n * colors)

        w = tf.reshape(w, [batch, n, n, colors])
        # [batch, N, N, C] -> [C, batch, N, N] -> [C * batch, N, N] (channel-major)
        w_array = tf.reshape(tf.transpose(w, perm=[3, 0, 1, 2]), [colors * batch, n, n])
        Fx, Fy, gamma = self.attn_window("write", hidden_layer)

        # Tile the filters channel-major to line up with w_array.
        Fx_array = tf.concat(0, [Fx] * colors)
        Fy_array = tf.concat(0, [Fy] * colors)
        Fyt = tf.transpose(Fy_array, perm=[0, 2, 1])
        # [C*batch, H, N] x ([C*batch, N, N] x [C*batch, N, W]) -> [C*batch, H, W]
        wr = tf.batch_matmul(Fyt, tf.batch_matmul(w_array, Fx_array))
        wr = tf.reshape(wr, [colors, batch, self.img_size, self.img_size])
        wr = tf.transpose(wr, [1, 2, 3, 0])
        wr = tf.reshape(wr, [batch, self.img_size * self.img_size * colors])
        return wr * tf.reshape(1.0 / gamma, [-1, 1])

    @staticmethod
    def _ensure_dir(path):
        """Create directory `path` if it does not exist yet."""
        if not os.path.isdir(path):
            os.makedirs(path)

    def _list_images(self):
        """Return the *.jpg paths in data_dir; raise IOError if there are none."""
        data = glob(os.path.join(self.data_dir, "*.jpg"))
        if not data:
            raise IOError("no .jpg images found in %s" % self.data_dir)
        return data

    def _load_images(self, files):
        """Load `files` with get_image(path, 108, is_crop=True) as float32, then map x -> (x + 1) / 2.

        NOTE(review): presumably get_image returns values in [-1, 1], giving [0, 1] here.
        """
        images = np.array([get_image(path, 108, is_crop=True) for path in files]).astype(np.float32)
        return (images + 1) / 2

    def _to_images(self, cs):
        """Apply the sigmoid to fetched canvases; return [steps, batch, H, W, num_colors]."""
        canvases = 1.0 / (1.0 + np.exp(-np.array(cs)))  # x_recons = sigmoid(canvas)
        return np.reshape(canvases, [len(cs), -1, self.img_size, self.img_size, self.num_colors])

    def _save_canvases(self, cs, prefix):
        """Save the canvas of every step t as an 8x8 grid to prefix + str(t) + ".jpg"."""
        for t, images in enumerate(self._to_images(cs)):
            ims(prefix + str(t) + ".jpg", merge_color(images, [8, 8]))

    def _draw_read_window(self, images, window, radius=1):
        """Paint the read-filter centres of one step onto `images` (in place) as green squares.

        images: [batch, img_size, img_size, 3] float array.
        window: [gx, gy, delta] fetched from self.attn_params, each a [batch, 1] array.
        gx is the horizontal centre (Fx filters the width axis in read_attention), so it
        indexes columns; gy indexes rows.
        """
        gx, gy, delta = window
        offsets = np.arange(self.attention_n) - (self.attention_n - 1) / 2.0
        last = self.img_size - 1
        marker = [0.0, 1.0, 0.0]  # assumes RGB channel order
        for b in range(images.shape[0]):
            cols = np.clip(np.round(gx[b, 0] + offsets * delta[b, 0]), 0, last).astype(int)
            rows = np.clip(np.round(gy[b, 0] + offsets * delta[b, 0]), 0, last).astype(int)
            for r in rows:
                for c in cols:
                    images[b, max(r - radius, 0):r + radius + 1, max(c - radius, 0):c + radius + 1, :] = marker

    def train(self):
        """Train for 10 epochs, checkpointing and dumping per-step canvases every 800 iterations."""
        data = self._list_images()
        self._ensure_dir("results")
        self._ensure_dir(os.path.join(os.getcwd(), "training"))

        base = self._load_images(data[0:64])
        ims("results/base.jpg", merge_color(base, [8, 8]))

        saver = tf.train.Saver(max_to_keep=2)

        # NOTE(review): the last two full batches of each epoch are skipped, as in the
        # original loop bound -- confirm this is intentional.
        num_batches = len(data) // self.batch_size - 2
        for e in range(10):
            for i in range(num_batches):
                batch_files = data[i * self.batch_size:(i + 1) * self.batch_size]
                batch_images = self._load_images(batch_files)

                # Only fetch the canvases on iterations where they are saved.
                snapshot = i % 800 == 0
                fetches = [self.generation_loss, self.latent_loss, self.train_op]
                if snapshot:
                    fetches.append(self.cs)
                results = self.sess.run(fetches, feed_dict={self.images: batch_images})
                gen_loss, lat_loss = results[0], results[1]
                print("epoch %d iter %d genloss %f latloss %f" % (e, i, gen_loss, lat_loss))

                if snapshot:
                    saver.save(self.sess, os.getcwd() + "/training/train", global_step=e * 10000 + i)
                    self._save_canvases(results[3], "results/%d-%d-step-" % (e, i))

    def view(self, draw_attention=False):
        """Restore the latest checkpoint and save the canvas of every step for 64 images.

        With draw_attention=False (default) images go to results/view-clean-step-<t>.jpg.
        With draw_attention=True the read-filter centres of each step are overlaid in
        green and images go to results/view-attn-step-<t>.jpg.
        """
        data = self._list_images()
        self._ensure_dir("results")

        base = self._load_images(data[0:64])
        ims("results/base.jpg", merge_color(base, [8, 8]))

        saver = tf.train.Saver(max_to_keep=2)
        checkpoint_dir = os.getcwd() + "/training/"
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint is None:
            raise IOError("no checkpoint found in %s" % checkpoint_dir)
        saver.restore(self.sess, checkpoint)

        cs, attn_params, gen_loss, lat_loss = self.sess.run(
            [self.cs, self.attn_params, self.generation_loss, self.latent_loss],
            feed_dict={self.images: base})
        print("genloss %f latloss %f" % (gen_loss, lat_loss))

        if not draw_attention:
            self._save_canvases(cs, "results/view-clean-step-")
            return

        for t, images in enumerate(self._to_images(cs)):
            images = images.copy()
            # attn_params interleaves windows as [read_0, write_0, read_1, write_1, ...]
            # because read_attention runs before write_attention in each step.
            self._draw_read_window(images, attn_params[2 * t])
            ims("results/view-attn-step-%d.jpg" % t, merge_color(images, [8, 8]))
303 |
304 |
305 |
306 |
# Run only as a script. The default is to visualise the latest checkpoint;
# pass "train" as the first argument to train instead.
if __name__ == "__main__":
    import sys

    model = Draw()
    if len(sys.argv) > 1 and sys.argv[1] == "train":
        model.train()
    else:
        model.view()
310 |
--------------------------------------------------------------------------------