├── .DS_Store ├── res_noattn.gif ├── res_attention.gif ├── results ├── base.jpg ├── 0-step-0.jpg ├── 0-step-1.jpg ├── 0-step-2.jpg ├── 0-step-3.jpg ├── 0-step-4.jpg ├── 0-step-5.jpg ├── 0-step-6.jpg ├── 0-step-7.jpg ├── 0-step-8.jpg ├── 0-step-9.jpg ├── 800-step-0.jpg ├── 800-step-1.jpg ├── 800-step-2.jpg ├── 800-step-3.jpg ├── 800-step-4.jpg ├── 800-step-5.jpg ├── 800-step-6.jpg ├── 800-step-7.jpg ├── 800-step-8.jpg ├── 800-step-9.jpg ├── 1600-step-0.jpg ├── 1600-step-1.jpg ├── 1600-step-2.jpg ├── 1600-step-3.jpg ├── 1600-step-4.jpg ├── 1600-step-5.jpg ├── 1600-step-6.jpg ├── 1600-step-7.jpg ├── 1600-step-8.jpg ├── 1600-step-9.jpg ├── 2400-step-0.jpg ├── 2400-step-1.jpg ├── 2400-step-2.jpg ├── 2400-step-3.jpg ├── 2400-step-4.jpg ├── 2400-step-5.jpg ├── 2400-step-6.jpg ├── 2400-step-7.jpg ├── 2400-step-8.jpg ├── 2400-step-9.jpg ├── view-step-0.jpg ├── view-step-1.jpg ├── view-step-2.jpg ├── view-step-3.jpg ├── view-step-4.jpg ├── view-step-5.jpg ├── view-step-6.jpg ├── view-step-7.jpg ├── view-step-8.jpg ├── view-step-9.jpg ├── view-attn-step-0.jpg ├── view-attn-step-1.jpg ├── view-attn-step-2.jpg ├── view-attn-step-3.jpg ├── view-attn-step-4.jpg ├── view-attn-step-5.jpg ├── view-attn-step-6.jpg ├── view-attn-step-7.jpg ├── view-attn-step-8.jpg ├── view-attn-step-9.jpg ├── view-clean-step-0.jpg ├── view-clean-step-1.jpg ├── view-clean-step-2.jpg ├── view-clean-step-3.jpg ├── view-clean-step-4.jpg ├── view-clean-step-5.jpg ├── view-clean-step-6.jpg ├── view-clean-step-7.jpg ├── view-clean-step-8.jpg └── view-clean-step-9.jpg ├── good-results ├── base.jpg ├── 100-step-0.jpg ├── 100-step-1.jpg ├── 100-step-2.jpg ├── 100-step-3.jpg ├── 100-step-4.jpg ├── 100-step-5.jpg ├── 100-step-6.jpg ├── 100-step-7.jpg ├── 100-step-8.jpg ├── 100-step-9.jpg ├── 200-step-0.jpg ├── 200-step-1.jpg ├── 200-step-2.jpg ├── 200-step-3.jpg ├── 200-step-4.jpg ├── 200-step-5.jpg ├── 200-step-6.jpg ├── 200-step-7.jpg ├── 200-step-8.jpg ├── 200-step-9.jpg ├── 300-step-0.jpg ├── 300-step-1.jpg ├── 300-step-2.jpg ├── 300-step-3.jpg ├── 300-step-4.jpg ├── 300-step-5.jpg ├── 300-step-6.jpg ├── 300-step-7.jpg ├── 300-step-8.jpg ├── 300-step-9.jpg ├── 400-step-0.jpg ├── 400-step-1.jpg ├── 400-step-2.jpg ├── 400-step-3.jpg ├── 400-step-4.jpg ├── 400-step-5.jpg ├── 400-step-6.jpg ├── 400-step-7.jpg ├── 400-step-8.jpg ├── 400-step-9.jpg ├── 500-step-0.jpg ├── 500-step-1.jpg ├── 500-step-2.jpg ├── 500-step-3.jpg ├── 500-step-4.jpg ├── 500-step-5.jpg ├── 500-step-6.jpg ├── 500-step-7.jpg ├── 500-step-8.jpg ├── 500-step-9.jpg ├── 600-step-0.jpg ├── 600-step-1.jpg ├── 3000-step-6.jpg ├── 3100-step-0.jpg ├── 3100-step-1.jpg ├── 3100-step-2.jpg ├── 3100-step-3.jpg ├── 3100-step-4.jpg ├── 3100-step-5.jpg ├── 3100-step-6.jpg ├── 3100-step-7.jpg ├── 3100-step-8.jpg └── 3100-step-9.jpg ├── results-noattn ├── 14000-step-0.jpg ├── 14000-step-1.jpg ├── 14000-step-2.jpg ├── 14000-step-3.jpg ├── 14000-step-4.jpg ├── 14000-step-5.jpg ├── 14000-step-6.jpg ├── 14000-step-7.jpg ├── 14000-step-8.jpg └── 14000-step-9.jpg ├── README.md ├── utils.py ├── plot_data.py ├── ops.py ├── input_data.py ├── main.py └── color.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/.DS_Store -------------------------------------------------------------------------------- /res_noattn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/res_noattn.gif -------------------------------------------------------------------------------- /res_attention.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/res_attention.gif -------------------------------------------------------------------------------- /results/base.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/base.jpg -------------------------------------------------------------------------------- /good-results/base.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/base.jpg -------------------------------------------------------------------------------- /results/0-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-0.jpg -------------------------------------------------------------------------------- /results/0-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-1.jpg -------------------------------------------------------------------------------- /results/0-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-2.jpg -------------------------------------------------------------------------------- /results/0-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-3.jpg -------------------------------------------------------------------------------- /results/0-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-4.jpg -------------------------------------------------------------------------------- /results/0-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-5.jpg -------------------------------------------------------------------------------- /results/0-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-6.jpg -------------------------------------------------------------------------------- /results/0-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-7.jpg -------------------------------------------------------------------------------- /results/0-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-8.jpg -------------------------------------------------------------------------------- /results/0-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/0-step-9.jpg -------------------------------------------------------------------------------- /results/800-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-0.jpg -------------------------------------------------------------------------------- /results/800-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-1.jpg -------------------------------------------------------------------------------- /results/800-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-2.jpg -------------------------------------------------------------------------------- /results/800-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-3.jpg -------------------------------------------------------------------------------- /results/800-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-4.jpg -------------------------------------------------------------------------------- /results/800-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-5.jpg -------------------------------------------------------------------------------- /results/800-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-6.jpg -------------------------------------------------------------------------------- /results/800-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-7.jpg -------------------------------------------------------------------------------- /results/800-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-8.jpg -------------------------------------------------------------------------------- /results/800-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/800-step-9.jpg -------------------------------------------------------------------------------- /results/1600-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-0.jpg -------------------------------------------------------------------------------- /results/1600-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-1.jpg -------------------------------------------------------------------------------- /results/1600-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-2.jpg -------------------------------------------------------------------------------- /results/1600-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-3.jpg -------------------------------------------------------------------------------- /results/1600-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-4.jpg -------------------------------------------------------------------------------- /results/1600-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-5.jpg -------------------------------------------------------------------------------- /results/1600-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-6.jpg -------------------------------------------------------------------------------- /results/1600-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-7.jpg -------------------------------------------------------------------------------- /results/1600-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-8.jpg -------------------------------------------------------------------------------- /results/1600-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/1600-step-9.jpg -------------------------------------------------------------------------------- /results/2400-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-0.jpg -------------------------------------------------------------------------------- /results/2400-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-1.jpg -------------------------------------------------------------------------------- /results/2400-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-2.jpg -------------------------------------------------------------------------------- /results/2400-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-3.jpg -------------------------------------------------------------------------------- /results/2400-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-4.jpg -------------------------------------------------------------------------------- /results/2400-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-5.jpg -------------------------------------------------------------------------------- /results/2400-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-6.jpg -------------------------------------------------------------------------------- /results/2400-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-7.jpg -------------------------------------------------------------------------------- /results/2400-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-8.jpg -------------------------------------------------------------------------------- /results/2400-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/2400-step-9.jpg -------------------------------------------------------------------------------- /results/view-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-0.jpg -------------------------------------------------------------------------------- /results/view-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-1.jpg -------------------------------------------------------------------------------- /results/view-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-2.jpg -------------------------------------------------------------------------------- /results/view-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-3.jpg -------------------------------------------------------------------------------- /results/view-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-4.jpg -------------------------------------------------------------------------------- /results/view-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-5.jpg -------------------------------------------------------------------------------- /results/view-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-6.jpg -------------------------------------------------------------------------------- /results/view-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-7.jpg -------------------------------------------------------------------------------- /results/view-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-8.jpg -------------------------------------------------------------------------------- /results/view-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-step-9.jpg -------------------------------------------------------------------------------- /good-results/100-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-0.jpg -------------------------------------------------------------------------------- /good-results/100-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-1.jpg -------------------------------------------------------------------------------- /good-results/100-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-2.jpg -------------------------------------------------------------------------------- /good-results/100-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-3.jpg -------------------------------------------------------------------------------- /good-results/100-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-4.jpg -------------------------------------------------------------------------------- /good-results/100-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-5.jpg -------------------------------------------------------------------------------- /good-results/100-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-6.jpg -------------------------------------------------------------------------------- /good-results/100-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-7.jpg -------------------------------------------------------------------------------- /good-results/100-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-8.jpg -------------------------------------------------------------------------------- /good-results/100-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/100-step-9.jpg -------------------------------------------------------------------------------- /good-results/200-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-0.jpg -------------------------------------------------------------------------------- /good-results/200-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-1.jpg -------------------------------------------------------------------------------- /good-results/200-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-2.jpg -------------------------------------------------------------------------------- /good-results/200-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-3.jpg -------------------------------------------------------------------------------- /good-results/200-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-4.jpg -------------------------------------------------------------------------------- /good-results/200-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-5.jpg -------------------------------------------------------------------------------- /good-results/200-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-6.jpg -------------------------------------------------------------------------------- /good-results/200-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-7.jpg -------------------------------------------------------------------------------- /good-results/200-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-8.jpg -------------------------------------------------------------------------------- /good-results/200-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/200-step-9.jpg -------------------------------------------------------------------------------- /good-results/300-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-0.jpg -------------------------------------------------------------------------------- /good-results/300-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-1.jpg -------------------------------------------------------------------------------- /good-results/300-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-2.jpg -------------------------------------------------------------------------------- /good-results/300-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-3.jpg -------------------------------------------------------------------------------- /good-results/300-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-4.jpg -------------------------------------------------------------------------------- /good-results/300-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-5.jpg -------------------------------------------------------------------------------- /good-results/300-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-6.jpg -------------------------------------------------------------------------------- /good-results/300-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-7.jpg -------------------------------------------------------------------------------- /good-results/300-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-8.jpg -------------------------------------------------------------------------------- /good-results/300-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/300-step-9.jpg -------------------------------------------------------------------------------- /good-results/400-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-0.jpg -------------------------------------------------------------------------------- /good-results/400-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-1.jpg -------------------------------------------------------------------------------- /good-results/400-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-2.jpg -------------------------------------------------------------------------------- /good-results/400-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-3.jpg -------------------------------------------------------------------------------- /good-results/400-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-4.jpg -------------------------------------------------------------------------------- /good-results/400-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-5.jpg -------------------------------------------------------------------------------- /good-results/400-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-6.jpg -------------------------------------------------------------------------------- /good-results/400-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-7.jpg -------------------------------------------------------------------------------- /good-results/400-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-8.jpg -------------------------------------------------------------------------------- /good-results/400-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/400-step-9.jpg -------------------------------------------------------------------------------- /good-results/500-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-0.jpg -------------------------------------------------------------------------------- /good-results/500-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-1.jpg -------------------------------------------------------------------------------- /good-results/500-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-2.jpg -------------------------------------------------------------------------------- /good-results/500-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-3.jpg -------------------------------------------------------------------------------- /good-results/500-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-4.jpg -------------------------------------------------------------------------------- /good-results/500-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-5.jpg -------------------------------------------------------------------------------- /good-results/500-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-6.jpg -------------------------------------------------------------------------------- /good-results/500-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-7.jpg -------------------------------------------------------------------------------- /good-results/500-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-8.jpg -------------------------------------------------------------------------------- /good-results/500-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/500-step-9.jpg -------------------------------------------------------------------------------- /good-results/600-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/600-step-0.jpg -------------------------------------------------------------------------------- /good-results/600-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/600-step-1.jpg -------------------------------------------------------------------------------- /good-results/3000-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3000-step-6.jpg -------------------------------------------------------------------------------- /good-results/3100-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-0.jpg -------------------------------------------------------------------------------- /good-results/3100-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-1.jpg -------------------------------------------------------------------------------- /good-results/3100-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-2.jpg -------------------------------------------------------------------------------- /good-results/3100-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-3.jpg -------------------------------------------------------------------------------- /good-results/3100-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-4.jpg -------------------------------------------------------------------------------- /good-results/3100-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-5.jpg -------------------------------------------------------------------------------- /good-results/3100-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-6.jpg -------------------------------------------------------------------------------- /good-results/3100-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-7.jpg -------------------------------------------------------------------------------- /good-results/3100-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-8.jpg -------------------------------------------------------------------------------- /good-results/3100-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/good-results/3100-step-9.jpg -------------------------------------------------------------------------------- /results/view-attn-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-0.jpg -------------------------------------------------------------------------------- /results/view-attn-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-1.jpg -------------------------------------------------------------------------------- /results/view-attn-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-2.jpg -------------------------------------------------------------------------------- /results/view-attn-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-3.jpg -------------------------------------------------------------------------------- /results/view-attn-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-4.jpg -------------------------------------------------------------------------------- /results/view-attn-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-5.jpg -------------------------------------------------------------------------------- /results/view-attn-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-6.jpg -------------------------------------------------------------------------------- /results/view-attn-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-7.jpg -------------------------------------------------------------------------------- /results/view-attn-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-8.jpg -------------------------------------------------------------------------------- /results/view-attn-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-attn-step-9.jpg -------------------------------------------------------------------------------- /results/view-clean-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-0.jpg -------------------------------------------------------------------------------- /results/view-clean-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-1.jpg -------------------------------------------------------------------------------- /results/view-clean-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-2.jpg -------------------------------------------------------------------------------- /results/view-clean-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-3.jpg -------------------------------------------------------------------------------- /results/view-clean-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-4.jpg -------------------------------------------------------------------------------- /results/view-clean-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-5.jpg -------------------------------------------------------------------------------- /results/view-clean-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-6.jpg -------------------------------------------------------------------------------- /results/view-clean-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-7.jpg -------------------------------------------------------------------------------- /results/view-clean-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-8.jpg -------------------------------------------------------------------------------- /results/view-clean-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results/view-clean-step-9.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-0.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-1.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-2.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-3.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-4.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-5.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-6.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-7.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-8.jpg -------------------------------------------------------------------------------- /results-noattn/14000-step-9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/draw-color/HEAD/results-noattn/14000-step-9.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # draw-color 2 | A Tensorflow implementation of [DRAW](https://arxiv.org/abs/1502.04623). Now includes support for colored images! 3 | 4 | This is code to accompany [my post on the DRAW model](http://kvfrans.com/what-is-draw-deep-recurrent-attentive-writer/). 5 | 6 | For an explanation of how I modified the original DRAW into a colored model, check out [my post on colorizing DRAW](http://kvfrans.com/colorizing-the-draw-model/). 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | The original DRAW is a slightly rewritten, commented version of [ericjang's implementation](https://github.com/ericjang/draw), running on MNIST. 15 | 16 | Usage: 17 | ``` 18 | python main.py 19 | ``` 20 | 21 | The colored DRAW runs on the celebA dataset. 22 | 23 | Usage: 24 | ``` 25 | python color.py 26 | ``` 27 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import numpy as np 3 | import random 4 | import tensorflow as tf 5 | import cPickle 6 | 7 | 8 | def get_image(image_path, image_size, is_crop=True): 9 | return transform(imread(image_path), image_size, is_crop) 10 | 11 | def transform(image, npx=64, is_crop=True): 12 | # npx : # of pixels width/height of image 13 | if is_crop: 14 | cropped_image = center_crop(image, npx) 15 | else: 16 | cropped_image = image 17 | return np.array(cropped_image)/127.5 - 1. 18 | 19 | def center_crop(x, crop_h, crop_w=None, resize_w=64): 20 | if crop_w is None: 21 | crop_w = crop_h 22 | h, w = x.shape[:2] 23 | j = int(round((h - crop_h)/2.)) 24 | i = int(round((w - crop_w)/2.)) 25 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], 26 | [resize_w, resize_w]) 27 | 28 | def imread(path): 29 | readimage = scipy.misc.imread(path, mode="RGB").astype(np.float) 30 | return readimage 31 | 32 | def merge_color(images, size): 33 | h, w = images.shape[1], images.shape[2] 34 | img = np.zeros((h * size[0], w * size[1], 3)) 35 | 36 | for idx, image in enumerate(images): 37 | i = idx % size[1] 38 | j = idx / size[1] 39 | img[j*h:j*h+h, i*w:i*w+w, :] = image 40 | 41 | return img 42 | 43 | def unpickle(file): 44 | fo = open(file, 'rb') 45 | dict = cPickle.load(fo) 46 | fo.close() 47 | return dict 48 | 49 | def ims(name, img): 50 | # print img[:10][:10] 51 | scipy.misc.toimage(img, cmin=0, cmax=1).save(name) 52 | 53 | def sigmoid(x): 54 | return 1 / (1 + np.exp(-x)) 55 | -------------------------------------------------------------------------------- /plot_data.py: -------------------------------------------------------------------------------- 1 | # takes data saved by DRAW model and generates animations 2 | # example usage: python plot_data.py noattn /tmp/draw/draw_data.npy 3 | 4 | import matplotlib 5 | import sys 6 | import numpy as np 7 | 8 | interactive=False # set to False if you want to write images to file 9 | 10 | if not interactive: 11 | matplotlib.use('Agg') # Force matplotlib to not use any Xwindows backend. 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | def xrecons_grid(X,B,A): 16 | """ 17 | plots canvas for single time step 18 | X is x_recons, (batch_size x img_size) 19 | assumes features = BxA images 20 | batch is assumed to be a square number 21 | """ 22 | padsize=1 23 | padval=.5 24 | ph=B+2*padsize 25 | pw=A+2*padsize 26 | batch_size=X.shape[0] 27 | N=int(np.sqrt(batch_size)) 28 | X=X.reshape((N,N,B,A)) 29 | img=np.ones((N*ph,N*pw))*padval 30 | for i in range(N): 31 | for j in range(N): 32 | startr=i*ph+padsize 33 | endr=startr+B 34 | startc=j*pw+padsize 35 | endc=startc+A 36 | img[startr:endr,startc:endc]=X[i,j,:,:] 37 | return img 38 | 39 | if __name__ == '__main__': 40 | prefix=sys.argv[1] 41 | out_file=sys.argv[2] 42 | C=np.load(out_file) 43 | T,batch_size,img_size=C.shape 44 | X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas) 45 | B=A=int(np.sqrt(img_size)) 46 | if interactive: 47 | f,arr=plt.subplots(1,T) 48 | for t in range(T): 49 | img=xrecons_grid(X[t,:,:],B,A) 50 | if interactive: 51 | arr[t].matshow(img,cmap=plt.cm.gray) 52 | arr[t].set_xticks([]) 53 | arr[t].set_yticks([]) 54 | else: 55 | plt.matshow(img,cmap=plt.cm.gray) 56 | imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif 57 | plt.savefig(imgname) 58 | print(imgname) 59 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | class batch_norm(object): 5 | """Code modification of http://stackoverflow.com/a/33950177""" 6 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 7 | with tf.variable_scope(name): 8 | self.epsilon = epsilon 9 | self.momentum = momentum 10 | 11 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum) 12 | self.name = name 13 | 14 | def __call__(self, x, train=True): 15 | shape = x.get_shape().as_list() 16 | 17 | if train: 18 | with tf.variable_scope(self.name) as scope: 19 | self.beta = tf.get_variable("beta", [shape[-1]], 20 | initializer=tf.constant_initializer(0.)) 21 | self.gamma = tf.get_variable("gamma", [shape[-1]], 22 | initializer=tf.random_normal_initializer(1., 0.02)) 23 | 24 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') 25 | ema_apply_op = self.ema.apply([batch_mean, batch_var]) 26 | self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var) 27 | 28 | with tf.control_dependencies([ema_apply_op]): 29 | mean, var = tf.identity(batch_mean), tf.identity(batch_var) 30 | else: 31 | mean, var = self.ema_mean, self.ema_var 32 | 33 | normed = tf.nn.batch_norm_with_global_normalization( 34 | x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True) 35 | 36 | return normed 37 | 38 | # standard convolution layer 39 | def conv2d(x, inputFeatures, outputFeatures, name): 40 | with tf.variable_scope(name): 41 | w = tf.get_variable("w",[5,5,inputFeatures, outputFeatures], initializer=tf.truncated_normal_initializer(stddev=0.02)) 42 | b = tf.get_variable("b",[outputFeatures], initializer=tf.constant_initializer(0.0)) 43 | conv = tf.nn.conv2d(x, w, strides=[1,2,2,1], padding="SAME") + b 44 | return conv 45 | 46 | def conv_transpose(x, outputShape, name): 47 | with tf.variable_scope(name): 48 | # h, w, out, in 49 | w = tf.get_variable("w",[5,5, outputShape[-1], x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(stddev=0.02)) 50 | b = tf.get_variable("b",[outputShape[-1]], initializer=tf.constant_initializer(0.0)) 51 | convt = tf.nn.conv2d_transpose(x, w, output_shape=outputShape, strides=[1,2,2,1]) 52 | return convt 53 | 54 | # leaky reLu unit 55 | def lrelu(x, leak=0.2, name="lrelu"): 56 | with tf.variable_scope(name): 57 | f1 = 0.5 * (1 + leak) 58 | f2 = 0.5 * (1 - leak) 59 | return f1 * x + f2 * abs(x) 60 | 61 | # fully-conected layer 62 | def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False): 63 | with tf.variable_scope(scope or "Linear"): 64 | matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32, tf.random_normal_initializer(stddev=0.02)) 65 | bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0)) 66 | if with_w: 67 | return tf.matmul(x, matrix) + bias, matrix, bias 68 | else: 69 | return tf.matmul(x, matrix) + bias 70 | 71 | def merge(images, size): 72 | h, w = images.shape[1], images.shape[2] 73 | img = np.zeros((h * size[0], w * size[1])) 74 | 75 | for idx, image in enumerate(images): 76 | i = idx % size[1] 77 | j = idx / size[1] 78 | img[j*h:j*h+h, i*w:i*w+w] = image 79 | 80 | return img 81 | -------------------------------------------------------------------------------- /input_data.py: -------------------------------------------------------------------------------- 1 | """Functions for downloading and reading MNIST data.""" 2 | import gzip 3 | import os 4 | import urllib 5 | import numpy 6 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 7 | 8 | 9 | def maybe_download(filename, work_directory): 10 | """Download the data from Yann's website, unless it's already here.""" 11 | if not os.path.exists(work_directory): 12 | os.mkdir(work_directory) 13 | filepath = os.path.join(work_directory, filename) 14 | if not os.path.exists(filepath): 15 | filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) 16 | statinfo = os.stat(filepath) 17 | print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.' 18 | return filepath 19 | 20 | 21 | def _read32(bytestream): 22 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 23 | return numpy.frombuffer(bytestream.read(4), dtype=dt) 24 | 25 | 26 | def extract_images(filename): 27 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 28 | print 'Extracting', filename 29 | with gzip.open(filename) as bytestream: 30 | magic = _read32(bytestream) 31 | if magic != 2051: 32 | raise ValueError( 33 | 'Invalid magic number %d in MNIST image file: %s' % 34 | (magic, filename)) 35 | num_images = _read32(bytestream) 36 | rows = _read32(bytestream) 37 | cols = _read32(bytestream) 38 | buf = bytestream.read(rows * cols * num_images) 39 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 40 | data = data.reshape(num_images, rows, cols, 1) 41 | return data 42 | 43 | 44 | def dense_to_one_hot(labels_dense, num_classes=10): 45 | """Convert class labels from scalars to one-hot vectors.""" 46 | num_labels = labels_dense.shape[0] 47 | index_offset = numpy.arange(num_labels) * num_classes 48 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 49 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 50 | return labels_one_hot 51 | 52 | 53 | def extract_labels(filename, one_hot=False): 54 | """Extract the labels into a 1D uint8 numpy array [index].""" 55 | print 'Extracting', filename 56 | with gzip.open(filename) as bytestream: 57 | magic = _read32(bytestream) 58 | if magic != 2049: 59 | raise ValueError( 60 | 'Invalid magic number %d in MNIST label file: %s' % 61 | (magic, filename)) 62 | num_items = _read32(bytestream) 63 | buf = bytestream.read(num_items) 64 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 65 | if one_hot: 66 | return dense_to_one_hot(labels) 67 | return labels 68 | 69 | 70 | class DataSet(object): 71 | def __init__(self, images, labels, fake_data=False): 72 | if fake_data: 73 | self._num_examples = 10000 74 | else: 75 | assert images.shape[0] == labels.shape[0], ( 76 | "images.shape: %s labels.shape: %s" % (images.shape, 77 | labels.shape)) 78 | self._num_examples = images.shape[0] 79 | # Convert shape from [num examples, rows, columns, depth] 80 | # to [num examples, rows*columns] (assuming depth == 1) 81 | assert images.shape[3] == 1 82 | images = images.reshape(images.shape[0], 83 | images.shape[1] * images.shape[2]) 84 | # Convert from [0, 255] -> [0.0, 1.0]. 85 | images = images.astype(numpy.float32) 86 | images = numpy.multiply(images, 1.0 / 255.0) 87 | self._images = images 88 | self._labels = labels 89 | self._epochs_completed = 0 90 | self._index_in_epoch = 0 91 | 92 | @property 93 | def images(self): 94 | return self._images 95 | 96 | @property 97 | def labels(self): 98 | return self._labels 99 | 100 | @property 101 | def num_examples(self): 102 | return self._num_examples 103 | 104 | @property 105 | def epochs_completed(self): 106 | return self._epochs_completed 107 | 108 | def next_batch(self, batch_size, fake_data=False): 109 | """Return the next `batch_size` examples from this data set.""" 110 | if fake_data: 111 | fake_image = [1.0 for _ in xrange(784)] 112 | fake_label = 0 113 | return [fake_image for _ in xrange(batch_size)], [ 114 | fake_label for _ in xrange(batch_size)] 115 | start = self._index_in_epoch 116 | self._index_in_epoch += batch_size 117 | if self._index_in_epoch > self._num_examples: 118 | # Finished epoch 119 | self._epochs_completed += 1 120 | # Shuffle the data 121 | perm = numpy.arange(self._num_examples) 122 | numpy.random.shuffle(perm) 123 | self._images = self._images[perm] 124 | self._labels = self._labels[perm] 125 | # Start next epoch 126 | start = 0 127 | self._index_in_epoch = batch_size 128 | assert batch_size <= self._num_examples 129 | end = self._index_in_epoch 130 | return self._images[start:end], self._labels[start:end] 131 | 132 | 133 | def read_data_sets(train_dir, fake_data=False, one_hot=False): 134 | class DataSets(object): 135 | pass 136 | data_sets = DataSets() 137 | if fake_data: 138 | data_sets.train = DataSet([], [], fake_data=True) 139 | data_sets.validation = DataSet([], [], fake_data=True) 140 | data_sets.test = DataSet([], [], fake_data=True) 141 | return data_sets 142 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 143 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 144 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 145 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 146 | VALIDATION_SIZE = 5000 147 | local_file = maybe_download(TRAIN_IMAGES, train_dir) 148 | train_images = extract_images(local_file) 149 | local_file = maybe_download(TRAIN_LABELS, train_dir) 150 | train_labels = extract_labels(local_file, one_hot=one_hot) 151 | local_file = maybe_download(TEST_IMAGES, train_dir) 152 | test_images = extract_images(local_file) 153 | local_file = maybe_download(TEST_LABELS, train_dir) 154 | test_labels = extract_labels(local_file, one_hot=one_hot) 155 | validation_images = train_images[:VALIDATION_SIZE] 156 | validation_labels = train_labels[:VALIDATION_SIZE] 157 | train_images = train_images[VALIDATION_SIZE:] 158 | train_labels = train_labels[VALIDATION_SIZE:] 159 | data_sets.train = DataSet(train_images, train_labels) 160 | data_sets.validation = DataSet(validation_images, validation_labels) 161 | data_sets.test = DataSet(test_images, test_labels) 162 | return data_sets 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from ops import * 4 | from utils import * 5 | import input_data 6 | # from scipy.misc import imsave as ims 7 | 8 | 9 | class Draw(): 10 | def __init__(self): 11 | self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 12 | self.n_samples = self.mnist.train.num_examples 13 | 14 | self.img_size = 28 15 | self.attention_n = 5 16 | self.n_hidden = 256 17 | self.n_z = 10 18 | self.sequence_length = 10 19 | self.batch_size = 64 20 | self.share_parameters = False 21 | 22 | self.images = tf.placeholder(tf.float32, [None, 784]) 23 | self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1) # Qsampler noise 24 | self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # encoder Op 25 | self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # decoder Op 26 | 27 | self.cs = [0] * self.sequence_length 28 | self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length 29 | 30 | h_dec_prev = tf.zeros((self.batch_size, self.n_hidden)) 31 | enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32) 32 | dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32) 33 | 34 | x = self.images 35 | for t in range(self.sequence_length): 36 | # error image + original image 37 | c_prev = tf.zeros((self.batch_size, self.img_size**2)) if t == 0 else self.cs[t-1] 38 | x_hat = x - tf.sigmoid(c_prev) 39 | # read the image 40 | r = self.read_basic(x,x_hat,h_dec_prev) 41 | print r.get_shape() 42 | # r = self.read_attention(x,x_hat,h_dec_prev) 43 | # encode it to guass distrib 44 | self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat(1, [r, h_dec_prev])) 45 | # sample from the distrib to get z 46 | z = self.sampleQ(self.mu[t],self.sigma[t]) 47 | print z.get_shape() 48 | # retrieve the hidden layer of RNN 49 | h_dec, dec_state = self.decode_layer(dec_state, z) 50 | 51 | print h_dec.get_shape() 52 | 53 | # map from hidden layer -> image portion, and then write it. 54 | self.cs[t] = c_prev + self.write_basic(h_dec) 55 | # self.cs[t] = c_prev + self.write_attention(h_dec) 56 | h_dec_prev = h_dec 57 | self.share_parameters = True # from now on, share variables 58 | 59 | # the final timestep 60 | self.generated_images = tf.nn.sigmoid(self.cs[-1]) 61 | 62 | self.generation_loss = tf.reduce_mean(-tf.reduce_sum(self.images * tf.log(1e-10 + self.generated_images) + (1-self.images) * tf.log(1e-10 + 1 - self.generated_images),1)) 63 | 64 | kl_terms = [0]*self.sequence_length 65 | for t in xrange(self.sequence_length): 66 | mu2 = tf.square(self.mu[t]) 67 | sigma2 = tf.square(self.sigma[t]) 68 | logsigma = self.logsigma[t] 69 | kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2*logsigma, 1) - self.sequence_length*0.5 70 | self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms)) 71 | self.cost = self.generation_loss + self.latent_loss 72 | optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5) 73 | grads = optimizer.compute_gradients(self.cost) 74 | for i,(g,v) in enumerate(grads): 75 | if g is not None: 76 | grads[i] = (tf.clip_by_norm(g,5),v) 77 | self.train_op = optimizer.apply_gradients(grads) 78 | 79 | self.sess = tf.Session() 80 | self.sess.run(tf.initialize_all_variables()) 81 | 82 | def train(self): 83 | for i in xrange(15000): 84 | xtrain, _ = self.mnist.train.next_batch(self.batch_size) 85 | cs, gen_loss, lat_loss, _ = self.sess.run([self.cs, self.generation_loss, self.latent_loss, self.train_op], feed_dict={self.images: xtrain}) 86 | print "iter %d genloss %f latloss %f" % (i, gen_loss, lat_loss) 87 | if i % 500 == 0: 88 | 89 | cs = 1.0/(1.0+np.exp(-np.array(cs))) # x_recons=sigmoid(canvas) 90 | 91 | for cs_iter in xrange(10): 92 | results = cs[cs_iter] 93 | results_square = np.reshape(results, [-1, 28, 28]) 94 | print results_square.shape 95 | ims("results/"+str(i)+"-step-"+str(cs_iter)+".jpg",merge(results_square,[8,8])) 96 | 97 | 98 | # given a hidden decoder layer: 99 | # locate where to put attention filters 100 | def attn_window(self, scope, h_dec): 101 | with tf.variable_scope(scope, reuse=self.share_parameters): 102 | parameters = dense(h_dec, self.n_hidden, 5) 103 | # gx_, gy_: center of 2d gaussian on a scale of -1 to 1 104 | gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1,5,parameters) 105 | 106 | # move gx/gy to be a scale of -imgsize to +imgsize 107 | gx = (self.img_size+1)/2 * (gx_ + 1) 108 | gy = (self.img_size+1)/2 * (gy_ + 1) 109 | 110 | sigma2 = tf.exp(log_sigma2) 111 | # stride/delta: how far apart these patches will be 112 | delta = (self.img_size - 1) / ((self.attention_n-1) * tf.exp(log_delta)) 113 | # returns [Fx, Fy, gamma] 114 | return self.filterbank(gx,gy,sigma2,delta) + (tf.exp(log_gamma),) 115 | 116 | # Given a center, distance, and spread 117 | # Construct [attention_n x attention_n] patches of gaussian filters 118 | # represented by Fx = horizontal gaussian, Fy = vertical guassian 119 | def filterbank(self, gx, gy, sigma2, delta): 120 | # 1 x N, look like [[0,1,2,3,4]] 121 | grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32),[1, -1]) 122 | # centers for the individual patches 123 | mu_x = gx + (grid_i - self.attention_n/2 - 0.5) * delta 124 | mu_y = gy + (grid_i - self.attention_n/2 - 0.5) * delta 125 | mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1]) 126 | mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1]) 127 | # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]] 128 | im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1]) 129 | # list of gaussian curves for x and y 130 | sigma2 = tf.reshape(sigma2, [-1, 1, 1]) 131 | Fx = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) 132 | Fy = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) 133 | # normalize so area-under-curve = 1 134 | Fx = Fx / tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),1e-8) 135 | Fy = Fy / tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),1e-8) 136 | return Fx, Fy 137 | 138 | 139 | # the read() operation without attention 140 | def read_basic(self, x, x_hat, h_dec_prev): 141 | return tf.concat(1,[x,x_hat]) 142 | 143 | def read_attention(self, x, x_hat, h_dec_prev): 144 | Fx, Fy, gamma = self.attn_window("read", h_dec_prev) 145 | # we have the parameters for a patch of gaussian filters. apply them. 146 | def filter_img(img, Fx, Fy, gamma): 147 | Fxt = tf.transpose(Fx, perm=[0,2,1]) 148 | img = tf.reshape(img, [-1, self.img_size, self.img_size]) 149 | # Apply the gaussian patches: 150 | # keep in mind: horiz = imgsize = verts (they are all the image size) 151 | # keep in mind: attn = height/length of attention patches 152 | # allfilters = [attn, vert] * [imgsize,imgsize] * [horiz, attn] 153 | # we have batches, so the full batch_matmul equation looks like: 154 | # [1, 1, vert] * [batchsize,imgsize,imgsize] * [1, horiz, 1] 155 | glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt)) 156 | glimpse = tf.reshape(glimpse, [-1, self.attention_n**2]) 157 | # finally scale this glimpse w/ the gamma parameter 158 | return glimpse * tf.reshape(gamma, [-1, 1]) 159 | x = filter_img(x, Fx, Fy, gamma) 160 | x_hat = filter_img(x_hat, Fx, Fy, gamma) 161 | return tf.concat(1, [x, x_hat]) 162 | 163 | # encode an attention patch 164 | def encode(self, prev_state, image): 165 | # update the RNN with image 166 | with tf.variable_scope("encoder",reuse=self.share_parameters): 167 | hidden_layer, next_state = self.lstm_enc(image, prev_state) 168 | 169 | # map the RNN hidden state to latent variables 170 | with tf.variable_scope("mu", reuse=self.share_parameters): 171 | mu = dense(hidden_layer, self.n_hidden, self.n_z) 172 | with tf.variable_scope("sigma", reuse=self.share_parameters): 173 | logsigma = dense(hidden_layer, self.n_hidden, self.n_z) 174 | sigma = tf.exp(logsigma) 175 | return mu, logsigma, sigma, next_state 176 | 177 | 178 | def sampleQ(self, mu, sigma): 179 | return mu + sigma*self.e 180 | 181 | def decode_layer(self, prev_state, latent): 182 | # update decoder RNN with latent var 183 | with tf.variable_scope("decoder", reuse=self.share_parameters): 184 | hidden_layer, next_state = self.lstm_dec(latent, prev_state) 185 | 186 | return hidden_layer, next_state 187 | 188 | def write_basic(self, hidden_layer): 189 | # map RNN hidden state to image 190 | with tf.variable_scope("write", reuse=self.share_parameters): 191 | decoded_image_portion = dense(hidden_layer, self.n_hidden, self.img_size**2) 192 | return decoded_image_portion 193 | 194 | def write_attention(self, hidden_layer): 195 | with tf.variable_scope("writeW", reuse=self.share_parameters): 196 | w = dense(hidden_layer, self.n_hidden, self.attention_n**2) 197 | w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n]) 198 | Fx, Fy, gamma = self.attn_window("write", hidden_layer) 199 | Fyt = tf.transpose(Fy, perm=[0,2,1]) 200 | # [vert, attn_n] * [attn_n, attn_n] * [attn_n, horiz] 201 | wr = tf.batch_matmul(Fyt, tf.batch_matmul(w, Fx)) 202 | wr = tf.reshape(wr, [self.batch_size, self.img_size**2]) 203 | return wr * tf.reshape(1.0/gamma, [-1, 1]) 204 | 205 | 206 | 207 | model = Draw() 208 | model.train() 209 | -------------------------------------------------------------------------------- /color.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from ops import * 4 | from utils import * 5 | from glob import glob 6 | import os 7 | 8 | class Draw(): 9 | def __init__(self): 10 | 11 | self.img_size = 64 12 | self.num_colors = 3 13 | 14 | self.attention_n = 5 15 | self.n_hidden = 256 16 | self.n_z = 10 17 | self.sequence_length = 10 18 | self.batch_size = 64 19 | self.share_parameters = False 20 | 21 | self.images = tf.placeholder(tf.float32, [None, self.img_size, self.img_size, self.num_colors]) 22 | 23 | self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1) # Qsampler noise 24 | 25 | self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # encoder Op 26 | self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True) # decoder Op 27 | 28 | self.cs = [0] * self.sequence_length 29 | self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length 30 | 31 | h_dec_prev = tf.zeros((self.batch_size, self.n_hidden)) 32 | enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32) 33 | dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32) 34 | 35 | x = tf.reshape(self.images, [-1, self.img_size*self.img_size*self.num_colors]) 36 | self.attn_params = [] 37 | for t in range(self.sequence_length): 38 | # error image + original image 39 | c_prev = tf.zeros((self.batch_size, self.img_size * self.img_size * self.num_colors)) if t == 0 else self.cs[t-1] 40 | x_hat = x - tf.sigmoid(c_prev) 41 | # read the image 42 | # r = self.read_basic(x,x_hat,h_dec_prev) 43 | r = self.read_attention(x,x_hat,h_dec_prev) 44 | # encode it to gauss distrib 45 | self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat(1, [r, h_dec_prev])) 46 | # sample from the distrib to get z 47 | z = self.sampleQ(self.mu[t],self.sigma[t]) 48 | # retrieve the hidden layer of RNN 49 | h_dec, dec_state = self.decode_layer(dec_state, z) 50 | # map from hidden layer -> image portion, and then write it. 51 | # self.cs[t] = c_prev + self.write_basic(h_dec) 52 | self.cs[t] = c_prev + self.write_attention(h_dec) 53 | h_dec_prev = h_dec 54 | self.share_parameters = True # from now on, share variables 55 | 56 | # the final timestep 57 | self.generated_images = tf.nn.sigmoid(self.cs[-1]) 58 | 59 | # self.generation_loss = tf.reduce_mean(-tf.reduce_sum(self.images * tf.log(1e-10 + self.generated_images) + (1-self.images) * tf.log(1e-10 + 1 - self.generated_images),1)) 60 | self.generation_loss = tf.nn.l2_loss(x - self.generated_images) 61 | 62 | kl_terms = [0]*self.sequence_length 63 | for t in xrange(self.sequence_length): 64 | mu2 = tf.square(self.mu[t]) 65 | sigma2 = tf.square(self.sigma[t]) 66 | logsigma = self.logsigma[t] 67 | kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2*logsigma, 1) - self.sequence_length*0.5 68 | self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms)) 69 | self.cost = self.generation_loss + self.latent_loss 70 | optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5) 71 | grads = optimizer.compute_gradients(self.cost) 72 | for i,(g,v) in enumerate(grads): 73 | if g is not None: 74 | grads[i] = (tf.clip_by_norm(g,5),v) 75 | self.train_op = optimizer.apply_gradients(grads) 76 | 77 | self.sess = tf.Session() 78 | self.sess.run(tf.initialize_all_variables()) 79 | 80 | # given a hidden decoder layer: 81 | # locate where to put attention filters 82 | def attn_window(self, scope, h_dec): 83 | with tf.variable_scope(scope, reuse=self.share_parameters): 84 | parameters = dense(h_dec, self.n_hidden, 5) 85 | # gx_, gy_: center of 2d gaussian on a scale of -1 to 1 86 | gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1,5,parameters) 87 | 88 | # move gx/gy to be a scale of -imgsize to +imgsize 89 | gx = (self.img_size+1)/2 * (gx_ + 1) 90 | gy = (self.img_size+1)/2 * (gy_ + 1) 91 | 92 | sigma2 = tf.exp(log_sigma2) 93 | # stride/delta: how far apart these patches will be 94 | delta = (self.img_size - 1) / ((self.attention_n-1) * tf.exp(log_delta)) 95 | # returns [Fx, Fy, gamma] 96 | 97 | self.attn_params.append([gx, gy, delta]) 98 | 99 | return self.filterbank(gx,gy,sigma2,delta) + (tf.exp(log_gamma),) 100 | 101 | # Given a center, distance, and spread 102 | # Construct [attention_n x attention_n] patches of gaussian filters 103 | # represented by Fx = horizontal gaussian, Fy = vertical guassian 104 | def filterbank(self, gx, gy, sigma2, delta): 105 | # 1 x N, look like [[0,1,2,3,4]] 106 | grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32),[1, -1]) 107 | # centers for the individual patches 108 | mu_x = gx + (grid_i - self.attention_n/2 - 0.5) * delta 109 | mu_y = gy + (grid_i - self.attention_n/2 - 0.5) * delta 110 | mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1]) 111 | mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1]) 112 | # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]] 113 | im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1]) 114 | # list of gaussian curves for x and y 115 | sigma2 = tf.reshape(sigma2, [-1, 1, 1]) 116 | Fx = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) 117 | Fy = tf.exp(-tf.square((im - mu_x) / (2*sigma2))) 118 | # normalize so area-under-curve = 1 119 | Fx = Fx / tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),1e-8) 120 | Fy = Fy / tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),1e-8) 121 | return Fx, Fy 122 | 123 | 124 | # the read() operation without attention 125 | def read_basic(self, x, x_hat, h_dec_prev): 126 | return tf.concat(1,[x,x_hat]) 127 | 128 | def read_attention(self, x, x_hat, h_dec_prev): 129 | Fx, Fy, gamma = self.attn_window("read", h_dec_prev) 130 | # we have the parameters for a patch of gaussian filters. apply them. 131 | def filter_img(img, Fx, Fy, gamma): 132 | # Fx,Fy = [64,5,32] 133 | # img = [64, 32*32*3] 134 | 135 | img = tf.reshape(img, [-1, self.img_size, self.img_size, self.num_colors]) 136 | img_t = tf.transpose(img, perm=[3,0,1,2]) 137 | 138 | # color1, color2, color3, color1, color2, color3, etc. 139 | batch_colors_array = tf.reshape(img_t, [self.num_colors * self.batch_size, self.img_size, self.img_size]) 140 | Fx_array = tf.concat(0, [Fx, Fx, Fx]) 141 | Fy_array = tf.concat(0, [Fy, Fy, Fy]) 142 | 143 | Fxt = tf.transpose(Fx_array, perm=[0,2,1]) 144 | 145 | # Apply the gaussian patches: 146 | glimpse = tf.batch_matmul(Fy_array, tf.batch_matmul(batch_colors_array, Fxt)) 147 | glimpse = tf.reshape(glimpse, [self.num_colors, self.batch_size, self.attention_n, self.attention_n]) 148 | glimpse = tf.transpose(glimpse, [1,2,3,0]) 149 | glimpse = tf.reshape(glimpse, [self.batch_size, self.attention_n*self.attention_n*self.num_colors]) 150 | # finally scale this glimpse w/ the gamma parameter 151 | return glimpse * tf.reshape(gamma, [-1, 1]) 152 | x = filter_img(x, Fx, Fy, gamma) 153 | x_hat = filter_img(x_hat, Fx, Fy, gamma) 154 | return tf.concat(1, [x, x_hat]) 155 | 156 | # encode an attention patch 157 | def encode(self, prev_state, image): 158 | # update the RNN with image 159 | with tf.variable_scope("encoder",reuse=self.share_parameters): 160 | hidden_layer, next_state = self.lstm_enc(image, prev_state) 161 | 162 | # map the RNN hidden state to latent variables 163 | with tf.variable_scope("mu", reuse=self.share_parameters): 164 | mu = dense(hidden_layer, self.n_hidden, self.n_z) 165 | with tf.variable_scope("sigma", reuse=self.share_parameters): 166 | logsigma = dense(hidden_layer, self.n_hidden, self.n_z) 167 | sigma = tf.exp(logsigma) 168 | return mu, logsigma, sigma, next_state 169 | 170 | 171 | def sampleQ(self, mu, sigma): 172 | return mu + sigma*self.e 173 | 174 | def decode_layer(self, prev_state, latent): 175 | # update decoder RNN with latent var 176 | with tf.variable_scope("decoder", reuse=self.share_parameters): 177 | hidden_layer, next_state = self.lstm_dec(latent, prev_state) 178 | 179 | return hidden_layer, next_state 180 | 181 | def write_basic(self, hidden_layer): 182 | # map RNN hidden state to image 183 | with tf.variable_scope("write", reuse=self.share_parameters): 184 | decoded_image_portion = dense(hidden_layer, self.n_hidden, self.img_size*self.img_size*self.num_colors) 185 | # decoded_image_portion = tf.reshape(decoded_image_portion, [-1, self.img_size, self.img_size, self.num_colors]) 186 | return decoded_image_portion 187 | 188 | def write_attention(self, hidden_layer): 189 | with tf.variable_scope("writeW", reuse=self.share_parameters): 190 | w = dense(hidden_layer, self.n_hidden, self.attention_n*self.attention_n*self.num_colors) 191 | 192 | w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n, self.num_colors]) 193 | w_t = tf.transpose(w, perm=[3,0,1,2]) 194 | Fx, Fy, gamma = self.attn_window("write", hidden_layer) 195 | 196 | # color1, color2, color3, color1, color2, color3, etc. 197 | w_array = tf.reshape(w_t, [self.num_colors * self.batch_size, self.attention_n, self.attention_n]) 198 | Fx_array = tf.concat(0, [Fx, Fx, Fx]) 199 | Fy_array = tf.concat(0, [Fy, Fy, Fy]) 200 | 201 | Fyt = tf.transpose(Fy_array, perm=[0,2,1]) 202 | # [vert, attn_n] * [attn_n, attn_n] * [attn_n, horiz] 203 | wr = tf.batch_matmul(Fyt, tf.batch_matmul(w_array, Fx_array)) 204 | sep_colors = tf.reshape(wr, [self.batch_size, self.num_colors, self.img_size**2]) 205 | wr = tf.reshape(wr, [self.num_colors, self.batch_size, self.img_size, self.img_size]) 206 | wr = tf.transpose(wr, [1,2,3,0]) 207 | wr = tf.reshape(wr, [self.batch_size, self.img_size * self.img_size * self.num_colors]) 208 | return wr * tf.reshape(1.0/gamma, [-1, 1]) 209 | 210 | 211 | def train(self): 212 | data = glob(os.path.join("../Datasets/celebA", "*.jpg")) 213 | base = np.array([get_image(sample_file, 108, is_crop=True) for sample_file in data[0:64]]) 214 | base += 1 215 | base /= 2 216 | 217 | ims("results/base.jpg",merge_color(base,[8,8])) 218 | 219 | saver = tf.train.Saver(max_to_keep=2) 220 | 221 | for e in xrange(10): 222 | for i in range((len(data) / self.batch_size) - 2): 223 | 224 | batch_files = data[i*self.batch_size:(i+1)*self.batch_size] 225 | batch = [get_image(batch_file, 108, is_crop=True) for batch_file in batch_files] 226 | batch_images = np.array(batch).astype(np.float32) 227 | batch_images += 1 228 | batch_images /= 2 229 | 230 | cs, attn_params, gen_loss, lat_loss, _ = self.sess.run([self.cs, self.attn_params, self.generation_loss, self.latent_loss, self.train_op], feed_dict={self.images: batch_images}) 231 | print "epoch %d iter %d genloss %f latloss %f" % (e, i, gen_loss, lat_loss) 232 | # print attn_params[0].shape 233 | # print attn_params[1].shape 234 | # print attn_params[2].shape 235 | if i % 800 == 0: 236 | 237 | saver.save(self.sess, os.getcwd() + "/training/train", global_step=e*10000 + i) 238 | 239 | cs = 1.0/(1.0+np.exp(-np.array(cs))) # x_recons=sigmoid(canvas) 240 | 241 | for cs_iter in xrange(10): 242 | results = cs[cs_iter] 243 | results_square = np.reshape(results, [-1, self.img_size, self.img_size, self.num_colors]) 244 | print results_square.shape 245 | ims("results/"+str(e)+"-"+str(i)+"-step-"+str(cs_iter)+".jpg",merge_color(results_square,[8,8])) 246 | 247 | 248 | def view(self): 249 | data = glob(os.path.join("../Datasets/celebA", "*.jpg")) 250 | base = np.array([get_image(sample_file, 108, is_crop=True) for sample_file in data[0:64]]) 251 | base += 1 252 | base /= 2 253 | 254 | ims("results/base.jpg",merge_color(base,[8,8])) 255 | 256 | saver = tf.train.Saver(max_to_keep=2) 257 | saver.restore(self.sess, tf.train.latest_checkpoint(os.getcwd()+"/training/")) 258 | 259 | cs, attn_params, gen_loss, lat_loss = self.sess.run([self.cs, self.attn_params, self.generation_loss, self.latent_loss], feed_dict={self.images: base}) 260 | print "genloss %f latloss %f" % (gen_loss, lat_loss) 261 | 262 | cs = 1.0/(1.0+np.exp(-np.array(cs))) # x_recons=sigmoid(canvas) 263 | 264 | print np.shape(cs) 265 | print np.shape(attn_params) 266 | # cs[0][cent] 267 | 268 | for cs_iter in xrange(10): 269 | results = cs[cs_iter] 270 | results_square = np.reshape(results, [-1, self.img_size, self.img_size, self.num_colors]) 271 | 272 | print np.shape(results_square) 273 | 274 | for i in xrange(64): 275 | center_x = int(attn_params[cs_iter][0][i][0]) 276 | center_y = int(attn_params[cs_iter][1][i][0]) 277 | distance = int(attn_params[cs_iter][2][i][0]) 278 | 279 | size = 2; 280 | 281 | # for x in xrange(3): 282 | # for y in xrange(3): 283 | # nx = x - 1; 284 | # ny = y - 1; 285 | # 286 | # xpos = center_x + nx*distance 287 | # ypos = center_y + ny*distance 288 | # 289 | # xpos2 = min(max(0, xpos + size), 63) 290 | # ypos2 = min(max(0, ypos + size), 63) 291 | # 292 | # xpos = min(max(0, xpos), 63) 293 | # ypos = min(max(0, ypos), 63) 294 | # 295 | # results_square[i,xpos:xpos2,ypos:ypos2,0] = 0; 296 | # results_square[i,xpos:xpos2,ypos:ypos2,1] = 1; 297 | # results_square[i,xpos:xpos2,ypos:ypos2,2] = 0; 298 | # print "%f , %f" % (center_x, center_y) 299 | 300 | print results_square 301 | 302 | ims("results/view-clean-step-"+str(cs_iter)+".jpg",merge_color(results_square,[8,8])) 303 | 304 | 305 | 306 | 307 | model = Draw() 308 | # model.train() 309 | model.view() 310 | --------------------------------------------------------------------------------