├── requirements.txt ├── samples ├── A_landmark │ └── example │ │ ├── timg.txt │ │ ├── WechatIMG256.txt │ │ ├── WechatIMG260.txt │ │ ├── WechatIMG261.txt │ │ ├── WechatIMG262.txt │ │ ├── WechatIMG263.txt │ │ ├── w_sexy_gr.txt │ │ ├── WechatIMG253_2.txt │ │ ├── girl-2099354_1920.txt │ │ ├── girl-2099357_1920.txt │ │ ├── girl-2122909_1920.txt │ │ ├── girl-2122927_1280.txt │ │ ├── girl-2128294_1920.txt │ │ ├── girl-2132171_1920.txt │ │ ├── girl-2143709_1920.txt │ │ ├── girl-2164409_1920.txt │ │ ├── girl-2177360_1920.txt │ │ ├── girl-2720476_1920.txt │ │ ├── girl-2999078_1920.txt │ │ ├── girl-4024238_1920.txt │ │ ├── girl-4024240_1920.txt │ │ ├── girl-4024244_1920.txt │ │ ├── male-467711_1920.txt │ │ ├── own-2553537_1280.txt │ │ ├── young-507297_1920.txt │ │ ├── brazil-1368806_1920.txt │ │ ├── garden-2768329_1920.txt │ │ ├── model-2134460_1920.txt │ │ ├── model-2911329_1920.txt │ │ ├── model-2911332_1920.txt │ │ ├── pinky-2727846_1920.txt │ │ ├── pinky-2727874_1920.txt │ │ ├── portrait-2164027_1920.txt │ │ ├── portrait-2554431_1920.txt │ │ ├── fconrad_Portrait_060414a.txt │ │ ├── mid-autumn-2752710_1920.txt │ │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.txt │ │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.txt │ │ └── portrait-laughing-businesswoman-long-dark-450w-235195312.txt ├── A │ └── example │ │ ├── timg.png │ │ ├── w_sexy_gr.png │ │ ├── WechatIMG256.png │ │ ├── WechatIMG260.png │ │ ├── WechatIMG261.png │ │ ├── WechatIMG262.png │ │ ├── WechatIMG263.png │ │ ├── WechatIMG253_2.png │ │ ├── girl-2099354_1920.png │ │ ├── girl-2099357_1920.png │ │ ├── girl-2122909_1920.png │ │ ├── girl-2122927_1280.png │ │ ├── girl-2128294_1920.png │ │ ├── girl-2132171_1920.png │ │ ├── girl-2143709_1920.png │ │ ├── girl-2164409_1920.png │ │ ├── girl-2177360_1920.png │ │ ├── girl-2720476_1920.png │ │ ├── girl-2999078_1920.png │ │ ├── girl-4024238_1920.png │ │ ├── girl-4024240_1920.png │ │ ├── girl-4024244_1920.png │ │ ├── male-467711_1920.png │ │ ├── 
own-2553537_1280.png │ │ ├── young-507297_1920.png │ │ ├── brazil-1368806_1920.png │ │ ├── garden-2768329_1920.png │ │ ├── model-2134460_1920.png │ │ ├── model-2911329_1920.png │ │ ├── model-2911332_1920.png │ │ ├── pinky-2727846_1920.png │ │ ├── pinky-2727874_1920.png │ │ ├── portrait-2164027_1920.png │ │ ├── portrait-2554431_1920.png │ │ ├── fconrad_Portrait_060414a.png │ │ ├── mid-autumn-2752710_1920.png │ │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.png │ │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.png │ │ └── portrait-laughing-businesswoman-long-dark-450w-235195312.png └── A_mask │ └── example │ ├── timg.png │ ├── w_sexy_gr.png │ ├── WechatIMG256.png │ ├── WechatIMG260.png │ ├── WechatIMG261.png │ ├── WechatIMG262.png │ ├── WechatIMG263.png │ ├── WechatIMG253_2.png │ ├── girl-2099354_1920.png │ ├── girl-2099357_1920.png │ ├── girl-2122909_1920.png │ ├── girl-2122927_1280.png │ ├── girl-2128294_1920.png │ ├── girl-2132171_1920.png │ ├── girl-2143709_1920.png │ ├── girl-2164409_1920.png │ ├── girl-2177360_1920.png │ ├── girl-2720476_1920.png │ ├── girl-2999078_1920.png │ ├── girl-4024238_1920.png │ ├── girl-4024240_1920.png │ ├── girl-4024244_1920.png │ ├── male-467711_1920.png │ ├── own-2553537_1280.png │ ├── young-507297_1920.png │ ├── brazil-1368806_1920.png │ ├── garden-2768329_1920.png │ ├── model-2134460_1920.png │ ├── model-2911329_1920.png │ ├── model-2911332_1920.png │ ├── pinky-2727846_1920.png │ ├── pinky-2727874_1920.png │ ├── portrait-2164027_1920.png │ ├── portrait-2554431_1920.png │ ├── fconrad_Portrait_060414a.png │ ├── mid-autumn-2752710_1920.png │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.png │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.png │ └── portrait-laughing-businesswoman-long-dark-450w-235195312.png ├── preprocess ├── example │ ├── img_1701_aligned.txt │ ├── img_1701.jpg │ ├── img_1701_aligned.png │ ├── img_1701_facial5point.mat │ └── img_1701_aligned_bgmask.png ├── 
face_align_512.m ├── combine_A_and_B.py └── readme.md ├── imgs ├── architecture.png └── samples │ ├── img_1673.png │ ├── img_1682.png │ ├── img_1696.png │ ├── img_1701.png │ ├── img_1794.png │ ├── img_1673_fake_B.png │ ├── img_1682_fake_B.png │ ├── img_1696_fake_B.png │ ├── img_1701_fake_B.png │ └── img_1794_fake_B.png ├── .gitignore ├── utils.py ├── readme.md ├── test.py ├── models.py ├── datasets.py └── apdrawing_gan.py /requirements.txt: -------------------------------------------------------------------------------- 1 | jittor==1.2.2.58 2 | opencv-python==4.2.0.34 3 | numpy==1.19.4 4 | Pillow==7.2.0 -------------------------------------------------------------------------------- /samples/A_landmark/example/timg.txt: -------------------------------------------------------------------------------- 1 | 205 240 2 | 313 241 3 | 240 329 4 | 217 367 5 | 305 370 6 | -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.txt: -------------------------------------------------------------------------------- 1 | 194 248 2 | 314 249 3 | 261 312 4 | 209 368 5 | 302 371 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG256.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 313 247 3 | 255 317 4 | 207 369 5 | 306 368 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG260.txt: -------------------------------------------------------------------------------- 1 | 195 247 2 | 316 246 3 | 259 314 4 | 212 371 5 | 298 370 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG261.txt: -------------------------------------------------------------------------------- 1 | 196 246 2 | 310 247 3 | 267 316 4 | 204 370 5 | 302 368 6 | 
-------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG262.txt: -------------------------------------------------------------------------------- 1 | 191 245 2 | 319 246 3 | 258 330 4 | 208 363 5 | 303 363 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG263.txt: -------------------------------------------------------------------------------- 1 | 199 242 2 | 318 244 3 | 241 329 4 | 217 364 5 | 306 368 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/w_sexy_gr.txt: -------------------------------------------------------------------------------- 1 | 200 245 2 | 312 245 3 | 256 316 4 | 210 371 5 | 302 371 6 | -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/architecture.png -------------------------------------------------------------------------------- /samples/A_landmark/example/WechatIMG253_2.txt: -------------------------------------------------------------------------------- 1 | 197 248 2 | 316 247 3 | 254 324 4 | 204 363 5 | 310 365 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2099354_1920.txt: -------------------------------------------------------------------------------- 1 | 200 247 2 | 316 244 3 | 253 313 4 | 213 372 5 | 299 371 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2099357_1920.txt: -------------------------------------------------------------------------------- 1 | 200 243 2 | 312 245 3 | 254 315 4 | 214 372 5 | 299 372 6 | -------------------------------------------------------------------------------- 
/samples/A_landmark/example/girl-2122909_1920.txt: -------------------------------------------------------------------------------- 1 | 197 242 2 | 310 243 3 | 267 317 4 | 213 373 5 | 293 372 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2122927_1280.txt: -------------------------------------------------------------------------------- 1 | 202 245 2 | 315 247 3 | 241 316 4 | 213 368 5 | 309 371 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2128294_1920.txt: -------------------------------------------------------------------------------- 1 | 195 246 2 | 311 247 3 | 266 320 4 | 204 367 5 | 303 368 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2132171_1920.txt: -------------------------------------------------------------------------------- 1 | 195 247 2 | 309 247 3 | 275 321 4 | 198 367 5 | 304 365 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2143709_1920.txt: -------------------------------------------------------------------------------- 1 | 201 245 2 | 312 244 3 | 255 308 4 | 215 375 5 | 296 375 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2164409_1920.txt: -------------------------------------------------------------------------------- 1 | 199 245 2 | 313 246 3 | 254 326 4 | 205 364 5 | 309 366 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2177360_1920.txt: -------------------------------------------------------------------------------- 1 | 197 248 2 | 316 249 3 | 253 321 4 | 201 365 5 | 311 363 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2720476_1920.txt: 
-------------------------------------------------------------------------------- 1 | 194 249 2 | 317 247 3 | 260 326 4 | 201 363 5 | 308 362 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-2999078_1920.txt: -------------------------------------------------------------------------------- 1 | 198 249 2 | 315 251 3 | 250 326 4 | 196 360 5 | 320 361 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-4024238_1920.txt: -------------------------------------------------------------------------------- 1 | 197 251 2 | 310 249 3 | 268 321 4 | 192 363 5 | 313 363 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-4024240_1920.txt: -------------------------------------------------------------------------------- 1 | 197 249 2 | 310 248 3 | 270 326 4 | 192 364 5 | 311 360 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/girl-4024244_1920.txt: -------------------------------------------------------------------------------- 1 | 197 250 2 | 313 249 3 | 261 327 4 | 193 361 5 | 315 361 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/male-467711_1920.txt: -------------------------------------------------------------------------------- 1 | 201 249 2 | 313 247 3 | 253 312 4 | 206 369 5 | 308 370 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/own-2553537_1280.txt: -------------------------------------------------------------------------------- 1 | 196 248 2 | 314 250 3 | 257 319 4 | 201 364 5 | 312 366 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/young-507297_1920.txt: 
-------------------------------------------------------------------------------- 1 | 205 249 2 | 314 246 3 | 244 310 4 | 208 372 5 | 308 370 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/brazil-1368806_1920.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 313 246 3 | 254 316 4 | 210 369 5 | 303 370 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/garden-2768329_1920.txt: -------------------------------------------------------------------------------- 1 | 197 249 2 | 311 248 3 | 269 326 4 | 192 364 5 | 311 360 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/model-2134460_1920.txt: -------------------------------------------------------------------------------- 1 | 201 246 2 | 317 246 3 | 242 311 4 | 217 372 5 | 303 372 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/model-2911329_1920.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 309 247 3 | 263 316 4 | 203 368 5 | 306 370 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/model-2911332_1920.txt: -------------------------------------------------------------------------------- 1 | 200 244 2 | 310 244 3 | 260 317 4 | 210 369 5 | 301 372 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/pinky-2727846_1920.txt: -------------------------------------------------------------------------------- 1 | 196 249 2 | 317 250 3 | 254 331 4 | 196 359 5 | 317 359 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/pinky-2727874_1920.txt: 
-------------------------------------------------------------------------------- 1 | 206 244 2 | 316 247 3 | 231 326 4 | 211 365 5 | 316 366 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/portrait-2164027_1920.txt: -------------------------------------------------------------------------------- 1 | 198 250 2 | 315 250 3 | 252 325 4 | 197 361 5 | 318 362 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/portrait-2554431_1920.txt: -------------------------------------------------------------------------------- 1 | 198 250 2 | 315 250 3 | 256 318 4 | 198 366 5 | 313 362 6 | -------------------------------------------------------------------------------- /imgs/samples/img_1673.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1673.png -------------------------------------------------------------------------------- /imgs/samples/img_1682.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1682.png -------------------------------------------------------------------------------- /imgs/samples/img_1696.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1696.png -------------------------------------------------------------------------------- /imgs/samples/img_1701.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1701.png -------------------------------------------------------------------------------- /imgs/samples/img_1794.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1794.png -------------------------------------------------------------------------------- /samples/A/example/timg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/timg.png -------------------------------------------------------------------------------- /samples/A_landmark/example/fconrad_Portrait_060414a.txt: -------------------------------------------------------------------------------- 1 | 200 250 2 | 309 251 3 | 259 307 4 | 199 368 5 | 312 371 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/mid-autumn-2752710_1920.txt: -------------------------------------------------------------------------------- 1 | 201 247 2 | 316 250 3 | 243 325 4 | 201 362 5 | 319 363 6 | -------------------------------------------------------------------------------- /imgs/samples/img_1673_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1673_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1682_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1682_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1696_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1696_fake_B.png 
-------------------------------------------------------------------------------- /imgs/samples/img_1701_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1701_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1794_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/imgs/samples/img_1794_fake_B.png -------------------------------------------------------------------------------- /preprocess/example/img_1701.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/preprocess/example/img_1701.jpg -------------------------------------------------------------------------------- /samples/A/example/w_sexy_gr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/w_sexy_gr.png -------------------------------------------------------------------------------- /samples/A_mask/example/timg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/timg.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG256.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG260.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG260.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG261.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG262.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG263.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG263.png -------------------------------------------------------------------------------- /samples/A/example/WechatIMG253_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/WechatIMG253_2.png -------------------------------------------------------------------------------- /samples/A_mask/example/w_sexy_gr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/w_sexy_gr.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/preprocess/example/img_1701_aligned.png 
-------------------------------------------------------------------------------- /samples/A/example/girl-2099354_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2099354_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2099357_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2099357_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2122909_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2122909_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2122927_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2122927_1280.png -------------------------------------------------------------------------------- /samples/A/example/girl-2128294_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2128294_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2132171_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2132171_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2143709_1920.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2143709_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2164409_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2164409_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2177360_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2177360_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2720476_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2720476_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-2999078_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-2999078_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-4024238_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-4024238_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-4024240_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-4024240_1920.png -------------------------------------------------------------------------------- /samples/A/example/girl-4024244_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/girl-4024244_1920.png -------------------------------------------------------------------------------- /samples/A/example/male-467711_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/male-467711_1920.png -------------------------------------------------------------------------------- /samples/A/example/own-2553537_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/own-2553537_1280.png -------------------------------------------------------------------------------- /samples/A/example/young-507297_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/young-507297_1920.png -------------------------------------------------------------------------------- /samples/A_landmark/example/portrait-smiling-woman-blue-shirt-450w-218101459.txt: -------------------------------------------------------------------------------- 1 | 195 245 2 | 315 247 3 | 258 331 4 | 202 362 5 | 309 362 6 | -------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG256.png 
-------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG260.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG260.png -------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG261.png -------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG262.png -------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG263.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG263.png -------------------------------------------------------------------------------- /samples/A/example/brazil-1368806_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/brazil-1368806_1920.png -------------------------------------------------------------------------------- /samples/A/example/garden-2768329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/garden-2768329_1920.png -------------------------------------------------------------------------------- 
/samples/A/example/model-2134460_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/model-2134460_1920.png -------------------------------------------------------------------------------- /samples/A/example/model-2911329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/model-2911329_1920.png -------------------------------------------------------------------------------- /samples/A/example/model-2911332_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/model-2911332_1920.png -------------------------------------------------------------------------------- /samples/A/example/pinky-2727846_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/pinky-2727846_1920.png -------------------------------------------------------------------------------- /samples/A/example/pinky-2727874_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/pinky-2727874_1920.png -------------------------------------------------------------------------------- /samples/A_landmark/example/passport-picture-businesswoman-brown-hair-450w-250775908.txt: -------------------------------------------------------------------------------- 1 | 196 251 2 | 315 253 3 | 256 320 4 | 194 361 5 | 319 361 6 | -------------------------------------------------------------------------------- /samples/A_landmark/example/portrait-laughing-businesswoman-long-dark-450w-235195312.txt: 
-------------------------------------------------------------------------------- 1 | 198 251 2 | 314 251 3 | 256 319 4 | 196 364 5 | 316 363 6 | -------------------------------------------------------------------------------- /samples/A_mask/example/WechatIMG253_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/WechatIMG253_2.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_facial5point.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/preprocess/example/img_1701_facial5point.mat -------------------------------------------------------------------------------- /samples/A/example/portrait-2164027_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/portrait-2164027_1920.png -------------------------------------------------------------------------------- /samples/A/example/portrait-2554431_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/portrait-2554431_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2099354_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2099354_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2099357_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2099357_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2122909_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2122909_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2122927_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2122927_1280.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2128294_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2128294_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2132171_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2132171_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2143709_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2143709_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2164409_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2164409_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2177360_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2177360_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2720476_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2720476_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-2999078_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-2999078_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-4024238_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-4024238_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-4024240_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-4024240_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/girl-4024244_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/girl-4024244_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/male-467711_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/male-467711_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/own-2553537_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/own-2553537_1280.png -------------------------------------------------------------------------------- /samples/A_mask/example/young-507297_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/young-507297_1920.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_bgmask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/preprocess/example/img_1701_aligned_bgmask.png -------------------------------------------------------------------------------- /samples/A/example/fconrad_Portrait_060414a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/fconrad_Portrait_060414a.png -------------------------------------------------------------------------------- /samples/A/example/mid-autumn-2752710_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/mid-autumn-2752710_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/brazil-1368806_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/brazil-1368806_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/garden-2768329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/garden-2768329_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/model-2134460_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/model-2134460_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/model-2911329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/model-2911329_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/model-2911332_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/model-2911332_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/pinky-2727846_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/pinky-2727846_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/pinky-2727874_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/pinky-2727874_1920.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */*.pyc 2 | */**/*.pyc 3 | */**/**/*.pyc 4 | */**/**/**/*.pyc 5 | */**/**/**/**/*.pyc 6 | checkpoints/ 7 | data/ 8 | images/ 9 | results/ 10 | -------------------------------------------------------------------------------- /samples/A_mask/example/portrait-2164027_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/portrait-2164027_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/portrait-2554431_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/portrait-2554431_1920.png -------------------------------------------------------------------------------- /samples/A_mask/example/fconrad_Portrait_060414a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/fconrad_Portrait_060414a.png -------------------------------------------------------------------------------- /samples/A_mask/example/mid-autumn-2752710_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/mid-autumn-2752710_1920.png -------------------------------------------------------------------------------- /samples/A/example/portrait-smiling-woman-blue-shirt-450w-218101459.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/portrait-smiling-woman-blue-shirt-450w-218101459.png -------------------------------------------------------------------------------- /samples/A_mask/example/portrait-smiling-woman-blue-shirt-450w-218101459.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/portrait-smiling-woman-blue-shirt-450w-218101459.png -------------------------------------------------------------------------------- /samples/A/example/passport-picture-businesswoman-brown-hair-450w-250775908.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/passport-picture-businesswoman-brown-hair-450w-250775908.png -------------------------------------------------------------------------------- /samples/A/example/portrait-laughing-businesswoman-long-dark-450w-235195312.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A/example/portrait-laughing-businesswoman-long-dark-450w-235195312.png -------------------------------------------------------------------------------- /samples/A_mask/example/passport-picture-businesswoman-brown-hair-450w-250775908.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/passport-picture-businesswoman-brown-hair-450w-250775908.png -------------------------------------------------------------------------------- /samples/A_mask/example/portrait-laughing-businesswoman-long-dark-450w-235195312.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN-Jittor/HEAD/samples/A_mask/example/portrait-laughing-businesswoman-long-dark-450w-235195312.png -------------------------------------------------------------------------------- /preprocess/face_align_512.m: -------------------------------------------------------------------------------- 1 | function [trans_img,trans_facial5point]=face_align_512(impath,facial5point,savedir) 2 | % align the faces by similarity transformation. 3 | % using 5 facial landmarks: 2 eyes, nose, 2 mouth corners. 4 | % impath: path to image 5 | % facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN 6 | % savedir: savedir for cropped image and transformed facial landmarks 7 | 8 | %% alignment settings 9 | imgSize = [512,512]; 10 | coord5point = [180,230; 11 | 300,230; 12 | 240,301; 13 | 186,365.6; 14 | 294,365.6];%480x480 15 | coord5point = (coord5point-240)/560 * 512 + 256; 16 | 17 | %% face alignment 18 | 19 | % load and align, resize image to imgSize 20 | img = imread(impath); 21 | facial5point = double(facial5point); 22 | transf = cp2tform(facial5point, coord5point, 'similarity'); 23 | trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],... 24 | 'YData', [1 imgSize(1)],... 25 | 'Size', imgSize,... 
26 | 'FillValues', [255;255;255]); 27 | trans_facial5point = round(tformfwd(transf,facial5point)); 28 | 29 | 30 | %% save results 31 | if ~exist(savedir,'dir') 32 | mkdir(savedir) 33 | end 34 | [~,name,~] = fileparts(impath); 35 | % save trans_img 36 | imwrite(trans_img, fullfile(savedir,[name,'_aligned.png'])); 37 | fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_aligned.png'])); 38 | % save trans_facial5point 39 | write_5pt(fullfile(savedir, [name, '_aligned.txt']), trans_facial5point); 40 | fprintf('write transformed facial landmark to %s\n',fullfile(savedir,[name,'_aligned.txt'])); 41 | 42 | %% show results 43 | imshow(trans_img); hold on; 44 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'b'); 45 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+'); 46 | 47 | end 48 | 49 | function [] = write_5pt(fn, trans_pt) 50 | fid = fopen(fn, 'w'); 51 | for i = 1:5 52 | fprintf(fid, '%d %d\n', trans_pt(i,1), trans_pt(i,2));%will be read as np.int32 53 | end 54 | fclose(fid); 55 | end -------------------------------------------------------------------------------- /preprocess/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 12 | args = parser.parse_args() 13 | 
14 | for arg in vars(args): 15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | img_fold_A = os.path.join(args.fold_A, sp) 21 | img_fold_B = os.path.join(args.fold_B, sp) 22 | img_list = os.listdir(img_fold_A) 23 | if args.use_AB: 24 | img_list = [img_path for img_path in img_list if '_A.' in img_path] 25 | 26 | num_imgs = min(args.num_imgs, len(img_list)) 27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 28 | img_fold_AB = os.path.join(args.fold_AB, sp) 29 | if not os.path.isdir(img_fold_AB): 30 | os.makedirs(img_fold_AB) 31 | print('split = %s, number of images = %d' % (sp, num_imgs)) 32 | for n in range(num_imgs): 33 | name_A = img_list[n] 34 | path_A = os.path.join(img_fold_A, name_A) 35 | if args.use_AB: 36 | name_B = name_A.replace('_A.', '_B.') 37 | else: 38 | name_B = name_A 39 | path_B = os.path.join(img_fold_B, name_B) 40 | if os.path.isfile(path_A) and os.path.isfile(path_B): 41 | name_AB = name_A 42 | if args.use_AB: 43 | name_AB = name_AB.replace('_A.', '.') # remove _A 44 | path_AB = os.path.join(img_fold_AB, name_AB) 45 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 46 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 47 | im_AB = np.concatenate([im_A, im_B], 1) 48 | cv2.imwrite(path_AB, im_AB) 49 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import numpy as np 3 | 4 | EYE_H = 40 5 | EYE_W = 56 6 | NOSE_H = 48 7 | NOSE_W = 48 8 | MOUTH_H = 40 9 | MOUTH_W = 64 10 | 11 | def masked(A, mask): 12 | return (A/2+0.5)*mask*2-1 13 | 14 | def inverse_mask(mask): 15 | return jt.ones(mask.shape)-mask 16 | 17 | def addone_with_mask(A, mask): 18 | return ((A/2+0.5)*mask + (jt.ones(mask.shape)-mask))*2-1 19 | 20 | def partCombiner2_bg(center, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op = 1, load_h = 512, 
load_w = 512): 21 | if comb_op == 0: 22 | # use max pooling, pad black for eyes etc 23 | padvalue = -1 24 | hair = masked(hair, maskh) 25 | bg = masked(bg, maskb) 26 | else: 27 | # use min pooling, pad white for eyes etc 28 | padvalue = 1 29 | hair = addone_with_mask(hair, maskh) 30 | bg = addone_with_mask(bg, maskb) 31 | ratio = load_h // 256 32 | rhs = np.array([EYE_H,EYE_H,NOSE_H,MOUTH_H]) * ratio 33 | rws = np.array([EYE_W,EYE_W,NOSE_W,MOUTH_W]) * ratio 34 | bs,nc,_,_ = eyel.shape 35 | eyel_p = jt.ones((bs,nc,load_h,load_w)) 36 | eyer_p = jt.ones((bs,nc,load_h,load_w)) 37 | nose_p = jt.ones((bs,nc,load_h,load_w)) 38 | mouth_p = jt.ones((bs,nc,load_h,load_w)) 39 | locals = [eyel, eyer, nose, mouth] 40 | locals_p = [eyel_p, eyer_p, nose_p, mouth_p] 41 | for i in range(bs): 42 | c = center[i].data#x,y 43 | for j in range(4): 44 | locals_p[j][i] = jt.nn.ConstantPad2d((int(c[j,0]-rws[j]/2), int(load_w-(c[j,0]+rws[j]/2)), int(c[j,1]-rhs[j]/2), int(load_h-(c[j,1]+rhs[j]/2))),padvalue)(locals[j][i]) 45 | if comb_op == 0: 46 | eyes = jt.maximum(locals_p[0], locals_p[1]) 47 | eye_nose = jt.maximum(eyes, locals_p[2]) 48 | eye_nose_mouth = jt.maximum(eye_nose, locals_p[3]) 49 | eye_nose_mouth_hair = jt.maximum(hair, eye_nose_mouth) 50 | result = jt.maximum(bg, eye_nose_mouth_hair) 51 | else: 52 | eyes = jt.minimum(locals_p[0], locals_p[1]) 53 | eye_nose = jt.minimum(eyes, locals_p[2]) 54 | eye_nose_mouth = jt.minimum(eye_nose, locals_p[3]) 55 | eye_nose_mouth_hair = jt.minimum(hair, eye_nose_mouth) 56 | result = jt.minimum(bg, eye_nose_mouth_hair) 57 | return result 58 | 59 | def getLocalParts(fakeAB, center, maskh, maskb, load_h = 512, load_w = 512): 60 | bs,nc,_,_ = fakeAB.shape 61 | ratio = load_h // 256 62 | rhs = np.array([EYE_H,EYE_H,NOSE_H,MOUTH_H]) * ratio 63 | rws = np.array([EYE_W,EYE_W,NOSE_W,MOUTH_W]) * ratio 64 | eyel = jt.ones((bs,nc,int(rhs[0]),int(rws[0]))) 65 | eyer = jt.ones((bs,nc,int(rhs[1]),int(rws[1]))) 66 | nose = 
jt.ones((bs,nc,int(rhs[2]),int(rws[2]))) 67 | mouth = jt.ones((bs,nc,int(rhs[3]),int(rws[3]))) 68 | locals = [eyel, eyer, nose, mouth] 69 | for i in range(bs): 70 | c = center[i].data 71 | for j in range(4): 72 | locals[j][i] = fakeAB[i, :, int(c[j,1]-rhs[j]//2):int(c[j,1]+rhs[j]//2), int(c[j,0]-rws[j]//2):int(c[j,0]+rws[j]//2)] 73 | hair = masked(fakeAB, maskh) 74 | bg = masked(fakeAB, maskb) 75 | locals += [hair, bg] 76 | return locals 77 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # APDrawingGAN Jittor Implementation 2 | 3 | We provide [Jittor](https://github.com/Jittor/jittor) implementations for our CVPR 2019 oral paper "APDrawingGAN: Generating Artistic Portrait Drawings from Face Photos with Hierarchical GANs". [[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/html/Yi_APDrawingGAN_Generating_Artistic_Portrait_Drawings_From_Face_Photos_With_Hierarchical_CVPR_2019_paper.html) 4 | 5 | This project generates artistic portrait drawings from face photos using a GAN-based model. 6 | 7 | ## Prerequisites 8 | - Linux or macOS 9 | - Python 3 10 | - CPU or NVIDIA GPU + CUDA CuDNN 11 | 12 | ## Sample Results 13 | Up: input, Down: output 14 |

15 | 16 | 17 | 18 | 19 | 20 |

21 |

22 | 23 | 24 | 25 | 26 | 27 |

28 | 29 | ## Installation 30 | - To install the dependencies, run 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Apply pretrained model 36 | 37 | - 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1ZTyV4gpMU45tp4d-bLwfmQ)(extract code: 9qhp) and rename the folder to `checkpoints`. 38 | 39 | - 2. Test for example photos: generate artistic portrait drawings for example photos in the folder `./samples/A/example` using models in `checkpoints/formal_author` 40 | ``` bash 41 | python test.py 42 | ``` 43 | Results are saved in `./results/portrait_drawing/formal_author_300/example` 44 | 45 | - 3. To test on your own photos: First run preprocess [here](preprocess/readme.md). Then specify the folder that contains test photos using option `--input_folder`, specify the folder of landmarks using `--lm_folder` and the folder of masks using `--mask_folder`, and run the `test.py` again. 46 | 47 | ## Train 48 | 49 | - 1. Download the APDrawing dataset from [GoogleDrive](https://drive.google.com/file/d/1vm8uwNcy113f-TZmRtgh_H4djNEaLXvO/view?usp=sharing) and put the folder to `data/apdrawing`. 50 | 51 | - 2. Train our model (300 epochs) 52 | ``` bash 53 | python apdrawing_gan.py 54 | ``` 55 | Models are saved in folder `checkpoints/apdrawing` 56 | 57 | - 3. Test the trained model 58 | ``` bash 59 | python test.py --which_epoch 300 --model_name apdrawing 60 | ``` 61 | Results are saved in `./results/portrait_drawing/apdrawing_300/example` 62 | 63 | ## Citation 64 | If you use this code or APDrawing dataset for your research, please cite our paper. 
65 | 66 | ``` 67 | @inproceedings{YiLLR19, 68 | title = {{APDrawingGAN}: Generating Artistic Portrait Drawings from Face Photos with Hierarchical GANs}, 69 | author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L}, 70 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '19)}, 71 | pages = {10743--10752}, 72 | year = {2019} 73 | } 74 | ``` -------------------------------------------------------------------------------- /preprocess/readme.md: -------------------------------------------------------------------------------- 1 | ## Preprocessing steps 2 | 3 | Face photos (and paired drawings) need to be aligned and have background mask detected. Aligned images, facial landmark files (txt) and background masks are needed for training and testing. 4 | 5 | ### 1. Align, resize, crop images to 512x512 and prepare facial landmarks 6 | 7 | All training and testing images in our model are aligned using facial landmarks. And landmarks after alignment are needed in our code. 8 | 9 | - First, 5 facial landmarks for a face photo need to be detected (we detect using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment)(MTCNNv1)). 10 | 11 | - Then, we provide a matlab function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align the image to 512x512. 12 | For example, for `img_1701.jpg` in `example` dir, the 5 detected facial landmarks are saved in `example/img_1701_facial5point.mat`. Call the following in MATLAB: 13 | ```bash 14 | load('example/img_1701_facial5point.mat'); 15 | [trans_img,trans_facial5point]=face_align_512('example/img_1701.jpg',facial5point,'example'); 16 | ``` 17 | 18 | This will align the image and output the aligned image and transformed facial landmarks (in txt format) in the `example` folder. 19 | See `face_align_512.m` for more instructions. 
20 | 21 | - The saved transformed facial landmarks need to be copied to `--lm_folder` (default is `./samples/A_landmark/example`), and must have the **same filename** as the aligned face photos (e.g. `./samples/A/example/girl-2099354_1920.png` should have landmark file `./samples/A_landmark/example/girl-2099354_1920.txt`). 22 | 23 | ### 2. Prepare background masks 24 | 25 | Background masks are needed in our code. 26 | 27 | In our work, the background mask is segmented by the method in 28 | "Automatic Portrait Segmentation for Image Stylization" 29 | Xiaoyong Shen, Aaron Hertzmann, Jiaya Jia, Sylvain Paris, Brian Price, Eli Shechtman, Ian Sachs. Computer Graphics Forum, 35(2)(Proc. Eurographics), 2016. 30 | 31 | - We use the code at http://xiaoyongshen.me/webpage_portrait/index.html to detect background masks for face photos. 32 | A sample background mask is shown in `example/img_1701_aligned_bgmask.png`. 33 | 34 | - The background masks need to be copied to `--mask_folder` (default is `./samples/A_mask/example`), and must have the **same filename** as the aligned face photos (e.g. `./samples/A/example/girl-2099354_1920.png` should have background mask `./samples/A_mask/example/girl-2099354_1920.png`). 35 | 36 | 37 | ### 3. (For training) Prepare more training data 38 | 39 | We provide a python script to generate training data in the form of pairs of images {A,B}, i.e. pairs {face photo, drawing}. This script will concatenate each pair of images horizontally into one single image. Then we can learn to translate A to B: 40 | 41 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `test`, etc. In `/path/to/data/A/train`, put training face photos. In `/path/to/data/B/train`, put the corresponding artist drawings. Repeat the same for `test`. 
42 | 43 | Corresponding images in a pair {A,B} must both be images after aligning and of size 512x512, and have the same filename, e.g., `/path/to/data/A/train/1.png` is considered to correspond to `/path/to/data/B/train/1.png`. 44 | 45 | Once the data is formatted this way, call: 46 | ```bash 47 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 48 | ``` 49 | 50 | This will combine each pair of images (A,B) into a single image file, ready for training. -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import datetime 6 | import time 7 | 8 | from models import * 9 | from datasets import * 10 | from utils import * 11 | 12 | jt.flags.use_cuda = 1 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--input_folder", type=str, default="./samples/A/example", help="the folder of input photos") 16 | parser.add_argument("--lm_folder", type=str, default="./samples/A_landmark/example", help="the folder of input landmarks") 17 | parser.add_argument("--mask_folder", type=str, default="./samples/A_mask/example", help="the folder of foreground landmarks") 18 | parser.add_argument("--model_name", type=str, default="formal_author", help="the load folder of model") 19 | parser.add_argument("--which_epoch", type=int, default=300, help="number of epoch to load") 20 | parser.add_argument("--dataset_name", type=str, default="portrait_drawing", help="name of the dataset") 21 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches") 22 | parser.add_argument("--img_height", type=int, default=512, help="size of image height") 23 | parser.add_argument("--img_width", type=int, default=512, help="size of image width") 24 | parser.add_argument("--in_channels", type=int, default=3, help="number 
of input channels") 25 | parser.add_argument("--out_channels", type=int, default=1, help="number of output channels") 26 | parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator") 27 | parser.add_argument("--extra_channel", type=int, default=3, help="extra channel for style feature") 28 | opt = parser.parse_args() 29 | print(opt) 30 | 31 | # Create save directories 32 | data_subdir = opt.input_folder.split('/')[-1] 33 | save_folder = "results/%s/%s_%d/%s" % (opt.dataset_name, opt.model_name, opt.which_epoch, data_subdir) 34 | os.makedirs(save_folder, exist_ok=True) 35 | 36 | input_shape = (opt.in_channels, opt.img_height, opt.img_width) 37 | output_shape = (opt.out_channels, opt.img_height, opt.img_width) 38 | 39 | # Initialize generator 40 | G_global = GeneratorUNet(in_channels=opt.in_channels, out_channels=opt.out_channels) 41 | G_l_eyel = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels) 42 | G_l_eyer = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels) 43 | G_l_nose = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels) 44 | G_l_mouth = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels) 45 | G_l_hair = PartUnet2(in_channels=opt.in_channels, out_channels=opt.out_channels) 46 | G_l_bg = PartUnet2(in_channels=opt.in_channels, out_channels=opt.out_channels) 47 | G_combine = Combiner(in_channels=2*opt.out_channels, out_channels=opt.out_channels) 48 | 49 | # Load weight 50 | model_path = os.path.join("checkpoints", opt.model_name, "{}_net_gen.pth".format(opt.which_epoch)) 51 | if os.path.exists(model_path): 52 | state_dict = jt.safeunpickle(model_path) 53 | G_global.load_state_dict(state_dict['G']) 54 | G_l_eyel.load_state_dict(state_dict['GLEyel']) 55 | G_l_eyer.load_state_dict(state_dict['GLEyer']) 56 | G_l_nose.load_state_dict(state_dict['GLNose']) 57 | G_l_mouth.load_state_dict(state_dict['GLMouth']) 58 | 
G_l_hair.load_state_dict(state_dict['GLHair']) 59 | G_l_bg.load_state_dict(state_dict['GLBG']) 60 | G_combine.load_state_dict(state_dict['GCombine']) 61 | else: 62 | G_global.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_global.pkl".format(opt.which_epoch))) 63 | G_l_eyel.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_eyel.pkl".format(opt.which_epoch))) 64 | G_l_eyer.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_eyer.pkl".format(opt.which_epoch))) 65 | G_l_nose.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_nose.pkl".format(opt.which_epoch))) 66 | G_l_mouth.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_mouth.pkl".format(opt.which_epoch))) 67 | G_l_hair.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_hair.pkl".format(opt.which_epoch))) 68 | G_l_bg.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_l_bg.pkl".format(opt.which_epoch))) 69 | G_combine.load(os.path.join("checkpoints", opt.model_name, "{}_net_G_combine.pkl".format(opt.which_epoch))) 70 | 71 | # Test data loader 72 | test_dataloader = TestDataset(opt.input_folder, opt.lm_folder, opt.mask_folder, mode="test", load_h=opt.img_height, load_w=opt.img_width).set_attrs(batch_size=1, shuffle=False, num_workers=1) 73 | import cv2 74 | def save_single_image(img, path): 75 | N,C,W,H = img.shape 76 | img = img[0] 77 | min_ = -1 78 | max_ = 1 79 | img=(img-min_)/(max_-min_)*255 80 | img=img.transpose((1,2,0)) 81 | if C==3: 82 | img = img[:,:,::-1] 83 | cv2.imwrite(path,img) 84 | 85 | # ---------- 86 | # Testing 87 | # ---------- 88 | 89 | prev_time = time.time() 90 | for i, batches in enumerate(test_dataloader): 91 | 92 | real_A = batches[0] 93 | real_A_eyel = batches[1] 94 | real_A_eyer = batches[2] 95 | real_A_nose = batches[3] 96 | real_A_mouth = batches[4] 97 | real_A_hair = batches[5] 98 | real_A_bg = batches[6] 99 | mask = batches[7] 100 | mask2 = batches[8] 101 | center = batches[9] 102 | 103 | maskh = mask*mask2 104 | 
import jittor as jt
from jittor import init
from jittor import nn

import pdb


def weights_init_normal(m):
    """pix2pix-style init: Conv* weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02); biases zero."""
    name = m.__class__.__name__
    if name.find("Conv") != -1:
        jt.init.gauss_(m.weight, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            jt.init.constant_(m.bias, 0.0)
    elif name.find("BatchNorm") != -1:
        jt.init.gauss_(m.weight, 1.0, 0.02)
        jt.init.constant_(m.bias, 0.0)


class ResidualBlock(nn.Module):
    """Reflection-padded 3x3 conv residual block: out = x + conv_block(x)."""

    def __init__(self, in_features, dropout=0.0):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
            nn.ReLU(),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        layers += [
            nn.ReflectionPad2d(1),
            nn.Conv(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
        ]
        self.conv_block = nn.Sequential(*layers)

    def execute(self, x):
        return x + self.conv_block(x)


class UNetDown(nn.Module):
    """Stride-2 4x4 conv downsampling stage: Conv -> [BN] -> LeakyReLU -> [Dropout]."""

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        stages = [nn.Conv(in_size, out_size, 4, stride=2, padding=1, bias=False)]
        if normalize:
            stages.append(nn.BatchNorm2d(out_size))
        stages.append(nn.LeakyReLU(scale=0.2))
        if dropout:
            stages.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stages)

    def execute(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """Stride-2 4x4 transposed-conv upsampling stage; concatenates the skip tensor on channels."""

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        stages = [
            nn.ConvTranspose(in_size, out_size, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(),
        ]
        if dropout:
            stages.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stages)

    def execute(self, x, skip_input):
        up = self.model(x)
        return jt.contrib.concat((up, skip_input), dim=1)


class UnetBlock(nn.Module):
    """Recursive U-Net block (innermost / middle / outermost variants), as in pix2pix.

    Inner blocks return their input concatenated with their output, which is how
    the skip connections are realized when blocks are nested.
    """

    def __init__(self, in_size, out_size, inner_nc, dropout=0.0, innermost=False, outermost=False, submodule=None):
        super(UnetBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv(in_size, inner_nc, 4, stride=2, padding=1, bias=False)
        downnorm = nn.BatchNorm2d(inner_nc)
        downrelu = nn.LeakyReLU(scale=0.2)
        upnorm = nn.BatchNorm2d(out_size)
        uprelu = nn.ReLU()

        if outermost:
            # Final stage: map to output range with Tanh; biased upconv, no norms.
            upconv = nn.ConvTranspose(2 * inner_nc, out_size, 4, stride=2, padding=1)
            stack = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # Bottleneck: no submodule, so the upconv input is not doubled.
            upconv = nn.ConvTranspose(inner_nc, out_size, 4, stride=2, padding=1, bias=False)
            stack = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose(2 * inner_nc, out_size, 4, stride=2, padding=1, bias=False)
            stack = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if dropout:
                stack.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*stack)

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x):
        if self.outermost:
            return self.model(x)
        return jt.contrib.concat((x, self.model(x)), dim=1)


class GeneratorUNet(nn.Module):
    """Global U-Net generator with num_downs down/upsampling levels built from nested UnetBlocks."""

    def __init__(self, in_channels=3, out_channels=1, num_downs=8):
        super(GeneratorUNet, self).__init__()

        block = UnetBlock(512, 512, inner_nc=512, submodule=None, innermost=True)  # deepest level
        for _ in range(num_downs - 5):
            block = UnetBlock(512, 512, inner_nc=512, submodule=block, dropout=0.5)
        block = UnetBlock(256, 256, inner_nc=512, submodule=block)
        block = UnetBlock(128, 128, inner_nc=256, submodule=block)
        block = UnetBlock(64, 64, inner_nc=128, submodule=block)
        block = UnetBlock(in_channels, out_channels, inner_nc=64, submodule=block, outermost=True)

        self.model = block

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x):
        return self.model(x)


class PartUnet(nn.Module):
    """Small 3-level U-Net for the local eye/nose/mouth patches."""

    def __init__(self, in_channels=3, out_channels=1):
        super(PartUnet, self).__init__()

        block = UnetBlock(128, 128, inner_nc=256, submodule=None, innermost=True)
        block = UnetBlock(64, 64, inner_nc=128, submodule=block)
        block = UnetBlock(in_channels, out_channels, inner_nc=64, submodule=block, outermost=True)
        self.model = block

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x):
        return self.model(x)


class PartUnet2(nn.Module):
    """Slightly deeper 4-level U-Net for the full-size hair/background branches."""

    def __init__(self, in_channels=3, out_channels=1):
        super(PartUnet2, self).__init__()

        block = UnetBlock(128, 128, inner_nc=128, submodule=None, innermost=True)
        block = UnetBlock(128, 128, inner_nc=128, submodule=block)
        block = UnetBlock(64, 64, inner_nc=128, submodule=block)
        block = UnetBlock(in_channels, out_channels, inner_nc=64, submodule=block, outermost=True)
        self.model = block

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x):
        return self.model(x)


class Combiner(nn.Module):
    """Fusion net: 7x7 conv head, two residual blocks, 7x7 conv + Tanh output."""

    def __init__(self, in_channels=3, out_channels=1):
        super(Combiner, self).__init__()

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv(in_channels, 64, 7, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        layers += [ResidualBlock(64, dropout=0.5) for _ in range(2)]
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv(64, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def execute(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator over concat(img_A, img_B); Sigmoid patch output."""

    def __init__(self, in_channels=3, out_channels=1):
        super(Discriminator, self).__init__()

        def block(in_filters, out_filters, stride=2, normalization=True):
            """Returns downsampling layers of each discriminator block."""
            layers = [nn.Conv(in_filters, out_filters, 4, stride=stride, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(scale=0.2))
            return layers

        self.model = nn.Sequential(
            *block(in_channels + out_channels, 64, normalization=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512, stride=1),
            nn.Conv(512, 1, 4, stride=1, padding=1),
            nn.Sigmoid(),
        )

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, img_A, img_B):
        pair = jt.contrib.concat((img_A, img_B), dim=1)
        return self.model(pair)
import glob
import random
import os
import csv

import numpy as np
import cv2
from PIL import Image

import jittor as jt
from jittor.dataset.dataset import Dataset
import jittor.transform as transform

# Crop sizes (at 256px base resolution) for the four facial regions.
EYE_H = 40
EYE_W = 56
NOSE_H = 48
NOSE_W = 48
MOUTH_H = 40
MOUTH_W = 64


def getfeats(featpath):
    """Read a 5-point landmark file (one space-separated 'x y' pair per row) as an int64 (5,2) array."""
    trans_points = np.empty([5, 2], dtype=np.int64)
    with open(featpath, 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for ind, row in enumerate(reader):
            trans_points[ind, :] = row
    return trans_points


def tocv2(ts):
    """Convert a CHW jittor tensor in [-1,1] to an HWC uint8 BGR image for OpenCV."""
    img = (ts.numpy() / 2 + 0.5) * 255
    img = img.astype('uint8')
    img = np.transpose(img, (1, 2, 0))
    img = img[:, :, ::-1]  # rgb->bgr
    return img


def dt(img):
    """Return normalized ([0,1]) distance transforms of the dark and bright sides of a drawing.

    Fix: when a thresholded image is entirely one value, distanceTransform returns
    all zeros and the original `dt/dt.max()` divided by zero; now guarded.
    """
    if img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # convert to BW both ways before the distance transforms
    ret1, thresh1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    ret2, thresh2 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    dt1 = cv2.distanceTransform(thresh1, cv2.DIST_L2, 5)
    dt2 = cv2.distanceTransform(thresh2, cv2.DIST_L2, 5)
    m1 = dt1.max()
    m2 = dt2.max()
    if m1 > 0:
        dt1 = dt1 / m1  # -> [0,1]
    if m2 > 0:
        dt2 = dt2 / m2
    return dt1, dt2


def get_transform(params, gray=False, mask=False):
    """Build the resize/flip/normalize pipeline.

    Masks are normalized to [0,1]; gray images to [-1,1] with a single channel;
    RGB images to [-1,1] per channel.
    """
    transform_ = []
    # resize
    transform_.append(transform.Resize((params['load_h'], params['load_w']), Image.BICUBIC))
    # flip
    if params['flip']:
        transform_.append(transform.Lambda(lambda img: transform.hflip(img)))
    if gray:
        transform_.append(transform.Gray())
    if mask:
        transform_.append(transform.ImageNormalize([0., ], [1., ]))
    else:
        if not gray:
            transform_.append(transform.ImageNormalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
        else:
            transform_.append(transform.ImageNormalize([0.5, ], [0.5, ]))
    return transform.Compose(transform_)


class ImageDataset(Dataset):
    """Paired training dataset: each image file holds photo (left half) and drawing (right half).

    Also loads 5-point landmarks and a foreground mask, and returns per-region crops,
    hair/background masked images, and ground-truth distance transforms of the drawing.
    """

    def __init__(self, root, mode="train", load_h=512, load_w=512):
        super().__init__()
        self.files = sorted(glob.glob(os.path.join(root, mode, "img") + "/*.*"))
        self.lmdir = os.path.join(root, mode, "landmark")
        self.maskdir = os.path.join(root, mode, "mask")
        self.set_attrs(total_len=len(self.files))
        self.load_h = load_h
        self.load_w = load_w

    def __getitem__(self, index):

        AB_path = self.files[index % len(self.files)]
        img = Image.open(AB_path)
        w, h = img.size
        img_A = img.crop((0, 0, w / 2, h))   # photo
        img_B = img.crop((w / 2, 0, w, h))   # drawing

        flip = random.random() > 0.5

        params = {'load_h': self.load_h, 'load_w': self.load_w, 'flip': flip}
        transform_A = get_transform(params)
        transform_B = get_transform(params, gray=True)
        transform_mask = get_transform(params, gray=True, mask=True)

        item_A = transform_A(img_A)
        item_A = jt.array(item_A)
        item_B = transform_B(img_B)
        item_B = jt.array(item_B)

        item_A_l = {}
        regions = ['eyel', 'eyer', 'nose', 'mouth']
        basen = os.path.basename(AB_path)[:-4]
        lm_path = os.path.join(self.lmdir, basen + '.txt')
        feats = getfeats(lm_path)
        if flip:
            # Mirror landmark x-coordinates and swap the left/right eye entries.
            for i in range(5):
                feats[i, 0] = self.load_w - feats[i, 0] - 1
            tmp = [feats[0, 0], feats[0, 1]]
            feats[0, :] = [feats[1, 0], feats[1, 1]]
            feats[1, :] = tmp
        mouth_x = int((feats[3, 0] + feats[4, 0]) / 2.0)
        mouth_y = int((feats[3, 1] + feats[4, 1]) / 2.0)
        ratio = self.load_h // 256
        rhs = np.array([EYE_H, EYE_H, NOSE_H, MOUTH_H]) * ratio
        rws = np.array([EYE_W, EYE_W, NOSE_W, MOUTH_W]) * ratio
        center = np.array([[feats[0, 0], feats[0, 1] - 4 * ratio], [feats[1, 0], feats[1, 1] - 4 * ratio], [feats[2, 0], feats[2, 1] - rhs[2] // 2 + 16 * ratio], [mouth_x, mouth_y]])

        for i in range(4):
            item_A_l[regions[i] + '_A'] = item_A[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2), int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)]

        mask = jt.ones([1, item_A.shape[1], item_A.shape[2]])  # mask out eyes, nose, mouth
        for i in range(4):
            mask[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2), int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)] = 0

        bgpath = os.path.join(self.maskdir, basen + '.png')
        im_bg = Image.open(bgpath)
        mask2 = transform_mask(im_bg)  # mask out background
        mask2 = jt.array(mask2)
        mask2 = (mask2 >= 0.5).float()  # foreground: 1, background: 0
        item_A_l['hair_A'] = (item_A / 2 + 0.5) * mask.repeat(3, 1, 1) * mask2.repeat(3, 1, 1) * 2 - 1
        item_A_l['bg_A'] = (item_A / 2 + 0.5) * (jt.ones(mask2.shape) - mask2).repeat(3, 1, 1) * 2 - 1

        img = tocv2(item_B)
        dt1, dt2 = dt(img)
        dt1 = jt.array(dt1)
        dt2 = jt.array(dt2)
        dt1 = dt1.unsqueeze(0)
        dt2 = dt2.unsqueeze(0)

        return item_A, item_A_l['eyel_A'], item_A_l['eyer_A'], item_A_l['nose_A'], item_A_l['mouth_A'], item_A_l['hair_A'], item_A_l['bg_A'], mask, mask2, center, item_B, dt1, dt2


class TestDataset(Dataset):
    """Inference dataset: photos only, with landmark and mask folders supplied separately."""

    def __init__(self, root, lmdir, maskdir, mode="test", load_h=512, load_w=512):
        super().__init__()
        transform_ = [
            transform.Resize((load_h, load_w), Image.BICUBIC),
            transform.ImageNormalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.transform = transform.Compose(transform_)
        transform_mask_ = [
            transform.Resize((load_h, load_w), Image.BICUBIC),
            transform.Gray(),
        ]
        self.transform_mask = transform.Compose(transform_mask_)

        self.files_A = sorted(glob.glob(root + "/*.*"))

        self.total_len = len(self.files_A)
        self.batch_size = None
        self.shuffle = False
        self.drop_last = False
        self.num_workers = None
        self.buffer_size = 512 * 1024 * 1024

        self.lmdir = lmdir
        self.maskdir = maskdir
        self.load_h = load_h

    def __getitem__(self, index):
        A_path = self.files_A[index % len(self.files_A)]
        image_A = Image.open(A_path)

        # Convert grayscale/palette images to RGB. (Fix: the original called an
        # undefined to_rgb() helper, which raised NameError for such inputs.)
        if image_A.mode != "RGB":
            image_A = image_A.convert("RGB")

        item_A = self.transform(image_A)
        item_A = jt.array(item_A)

        item_A_l = {}
        regions = ['eyel', 'eyer', 'nose', 'mouth']
        basen = os.path.basename(A_path)[:-4]
        lm_path = os.path.join(self.lmdir, basen + '.txt')
        feats = getfeats(lm_path)
        mouth_x = int((feats[3, 0] + feats[4, 0]) / 2.0)
        mouth_y = int((feats[3, 1] + feats[4, 1]) / 2.0)
        ratio = self.load_h // 256
        rhs = np.array([EYE_H, EYE_H, NOSE_H, MOUTH_H]) * ratio
        rws = np.array([EYE_W, EYE_W, NOSE_W, MOUTH_W]) * ratio
        center = np.array([[feats[0, 0], feats[0, 1] - 4 * ratio], [feats[1, 0], feats[1, 1] - 4 * ratio], [feats[2, 0], feats[2, 1] - rhs[2] // 2 + 16 * ratio], [mouth_x, mouth_y]])

        for i in range(4):
            item_A_l[regions[i] + '_A'] = item_A[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2), int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)]

        mask = jt.ones([1, item_A.shape[1], item_A.shape[2]])  # mask out eyes, nose, mouth
        for i in range(4):
            mask[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2), int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)] = 0

        bgpath = os.path.join(self.maskdir, basen + '.png')
        im_bg = Image.open(bgpath)
        mask2 = self.transform_mask(im_bg)  # mask out background
        mask2 = jt.array(mask2)
        mask2 = (mask2 >= 0.5).float()  # foreground: 1, background: 0
        item_A_l['hair_A'] = (item_A / 2 + 0.5) * mask.repeat(3, 1, 1) * mask2.repeat(3, 1, 1) * 2 - 1
        item_A_l['bg_A'] = (item_A / 2 + 0.5) * (jt.ones(mask2.shape) - mask2).repeat(3, 1, 1) * 2 - 1

        return item_A, item_A_l['eyel_A'], item_A_l['eyer_A'], item_A_l['nose_A'], item_A_l['mouth_A'], item_A_l['hair_A'], item_A_l['bg_A'], mask, mask2, center
import jittor as jt
from jittor import init
from jittor import nn
import jittor.transform as transform
import argparse
import os
import numpy as np
import math
import itertools
import time  # fix: was imported twice in the original
import datetime
import sys
import cv2

from models import *
from datasets import *
from utils import *

import warnings
warnings.filterwarnings("ignore")

jt.flags.use_cuda = 1

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=300, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="apdrawing", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument('--load_pre_train', type=int, default=1, help='whether load pre-trained model')
parser.add_argument('--load_pre_train_name', type=str, default="pre-training", help='the path to load pre-trained model')
parser.add_argument('--load_auxiliary_name', type=str, default="auxiliary", help='the path to load auxiliary model')
parser.add_argument('--nepoch', type=int, default=300, help='# of epoch at starting learning rate')
parser.add_argument('--nepoch_decay', type=int, default=0, help='# of epoch to linearly decay learning rate to zero')
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=512, help="size of image height")
parser.add_argument("--img_width", type=int, default=512, help="size of image width")
parser.add_argument("--in_channels", type=int, default=3, help="number of input channels")
parser.add_argument("--out_channels", type=int, default=1, help="number of output channels")
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
parser.add_argument('--lambda_local', type=float, default=25.0, help='weight for Local loss')
parser.add_argument('--lambda_chamfer', type=float, default=0.1, help='weight for chamfer loss')
parser.add_argument('--lambda_chamfer2', type=float, default=0.1, help='weight for chamfer loss2')
parser.add_argument(
    "--sample_interval", type=int, default=400, help="interval between sampling of images from generators"
)
parser.add_argument("--checkpoint_interval", type=int, default=10, help="interval between model checkpoints")
parser.add_argument("--val_input_folder", type=str, default="./samples/A/example", help="the folder of input photos")
parser.add_argument("--val_lm_folder", type=str, default="./samples/A_landmark/example", help="the folder of input landmarks")
parser.add_argument("--val_mask_folder", type=str, default="./samples/A_mask/example", help="the folder of foreground landmarks")
opt = parser.parse_args()
print(opt)


def save_image(img, path, nrow=10):
    """Tile a batch of images into an nrow-wide grid, min-max normalize to [0,255], and write with cv2.

    Fix: guard the min-max normalization against a constant image (max == min),
    which previously divided by zero.
    """
    N, C, W, H = img.shape
    if (N % nrow != 0):
        print("save_image error: N%nrow!=0")
        return
    img = img.transpose((1, 0, 2, 3))
    ncol = int(N / nrow)
    img2 = img.reshape([img.shape[0], -1, H])
    img = img2[:, :W * ncol, :]
    for i in range(1, int(img2.shape[1] / W / ncol)):
        img = np.concatenate([img, img2[:, W * ncol * i:W * ncol * (i + 1), :]], axis=2)
    min_ = img.min()
    max_ = img.max()
    span = max_ - min_
    if span == 0:
        span = 1.0  # constant image: save as uniform instead of dividing by zero
    img = (img - min_) / span * 255
    img = img.transpose((1, 2, 0))
    if C == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    cv2.imwrite(path, img)


def save_single_image(img, path):
    """Write the first image of a [-1,1]-ranged batch to disk as an 8-bit file."""
    N, C, W, H = img.shape
    img = img[0]
    min_ = -1
    max_ = 1
    img = (img - min_) / (max_ - min_) * 255
    img = img.transpose((1, 2, 0))
    if C == 3:
        img = img[:, :, ::-1]  # rgb->bgr for cv2
    cv2.imwrite(path, img)


os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("checkpoints/%s" % opt.dataset_name, exist_ok=True)

# Loss functions
criterion_GAN = nn.BCELoss()  # no lsgan
criterion_pixelwise = nn.L1Loss()

# Calculate output of image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

# Initialize generator and discriminator: one global net, six local part nets, one fusion net.
G_global = GeneratorUNet(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_eyel = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_eyer = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_nose = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_mouth = PartUnet(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_hair = PartUnet2(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_l_bg = PartUnet2(in_channels=opt.in_channels, out_channels=opt.out_channels)
G_combine = Combiner(in_channels=2 * opt.out_channels, out_channels=opt.out_channels)
G_nets = [G_global, G_l_eyel, G_l_eyer, G_l_nose, G_l_mouth, G_l_hair, G_l_bg, G_combine]
D_global = Discriminator()
D_l_eyel = Discriminator()
D_l_eyer = Discriminator()
D_l_nose = Discriminator()
D_l_mouth = Discriminator()
D_l_hair = Discriminator()
D_l_bg = Discriminator()
D_nets = [D_global, D_l_eyel, D_l_eyer, D_l_nose, D_l_mouth, D_l_hair, D_l_bg]

if opt.load_pre_train != 0:
    # Load pretrained models using npr data
    gen_model_path = os.path.join("checkpoints", opt.load_pre_train_name, "latest_net_gen.pth")
    gen_state_dict = jt.safeunpickle(gen_model_path)
    G_global.load_state_dict(gen_state_dict['G'])
    G_l_eyel.load_state_dict(gen_state_dict['GLEyel'])
    G_l_eyer.load_state_dict(gen_state_dict['GLEyer'])
    G_l_nose.load_state_dict(gen_state_dict['GLNose'])
    G_l_mouth.load_state_dict(gen_state_dict['GLMouth'])
    G_l_hair.load_state_dict(gen_state_dict['GLHair'])
    G_l_bg.load_state_dict(gen_state_dict['GLBG'])
    G_combine.load_state_dict(gen_state_dict['GCombine'])
    dis_model_path = os.path.join("checkpoints", opt.load_pre_train_name, "latest_net_dis.pth")
    dis_state_dict = jt.safeunpickle(dis_model_path)
    D_global.load_state_dict(dis_state_dict['D'])
    D_l_eyel.load_state_dict(dis_state_dict['DLEyel'])
    D_l_eyer.load_state_dict(dis_state_dict['DLEyer'])
    D_l_nose.load_state_dict(dis_state_dict['DLNose'])
    D_l_mouth.load_state_dict(dis_state_dict['DLMouth'])
    D_l_hair.load_state_dict(dis_state_dict['DLHair'])
    D_l_bg.load_state_dict(dis_state_dict['DLBG'])

# Frozen auxiliary nets for the chamfer (distance-transform) losses.
DT1 = GeneratorUNet(in_channels=1, out_channels=1, num_downs=9)
DT2 = GeneratorUNet(in_channels=1, out_channels=1, num_downs=9)
Line1 = GeneratorUNet(in_channels=1, out_channels=1, num_downs=9)
Line2 = GeneratorUNet(in_channels=1, out_channels=1, num_downs=9)
DT1.load(os.path.join("checkpoints", opt.load_auxiliary_name, "latest_net_DT1.pth"))
DT2.load(os.path.join("checkpoints", opt.load_auxiliary_name, "latest_net_DT2.pth"))
Line1.load(os.path.join("checkpoints", opt.load_auxiliary_name, "latest_net_Line1.pth"))
Line2.load(os.path.join("checkpoints", opt.load_auxiliary_name, "latest_net_Line2.pth"))

# Optimizers: one Adam over all generators, one over all discriminators.
G_nets_params = G_nets[0].parameters()
for net in G_nets[1:]:
    G_nets_params += net.parameters()
optimizer_G = jt.optim.Adam(G_nets_params, lr=opt.lr, betas=(opt.b1, opt.b2))
D_nets_params = D_nets[0].parameters()
for net in D_nets[1:]:
    D_nets_params += net.parameters()
optimizer_D = jt.optim.Adam(D_nets_params, lr=opt.lr, betas=(opt.b1, opt.b2))

# Configure dataloaders
dataloader = ImageDataset("data/%s" % opt.dataset_name, load_h=opt.img_height, load_w=opt.img_width).set_attrs(
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
val_dataloader = TestDataset(opt.val_input_folder, opt.val_lm_folder, opt.val_mask_folder, mode="val", load_h=opt.img_height, load_w=opt.img_width).set_attrs(
    batch_size=10,
    shuffle=True,
    num_workers=1,
)


def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    batches = next(iter(val_dataloader))

    real_A = batches[0]
    real_A_eyel = batches[1]
    real_A_eyer = batches[2]
    real_A_nose = batches[3]
    real_A_mouth = batches[4]
    real_A_hair = batches[5]
    real_A_bg = batches[6]
    mask = batches[7]
    mask2 = batches[8]
    center = batches[9]

    maskh = mask * mask2
    maskb = inverse_mask(mask2)

    fake_B0 = G_global(real_A)
    # EYES, NOSE, MOUTH
    fake_B_eyel = G_l_eyel(real_A_eyel)
    fake_B_eyer = G_l_eyer(real_A_eyer)
    fake_B_nose = G_l_nose(real_A_nose)
    fake_B_mouth = G_l_mouth(real_A_mouth)
    # HAIR, BG AND PARTCOMBINE
    fake_B_hair = G_l_hair(real_A_hair)
    fake_B_bg = G_l_bg(real_A_bg)
    fake_B1 = partCombiner2_bg(center, fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair, fake_B_bg, maskh, maskb, comb_op=1, load_h=opt.img_height, load_w=opt.img_width)
    # FUSION NET
    fake_B = G_combine(jt.contrib.concat((fake_B0, fake_B1), 1))

    img_sample = np.concatenate([real_A.data, fake_B.repeat(1, 3, 1, 1).data], -2)
    save_image(img_sample, "images/%s/%s.jpg" % (opt.dataset_name, batches_done), nrow=5)


# Benchmark-mode knobs: warmup_times == -1 means normal training/logging.
warmup_times = -1
run_times = 3000
total_time = 0.
cnt = 0

# ----------
#  Training
# ----------

prev_time = time.time()

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batches in enumerate(dataloader):

        real_A = batches[0]
        real_A_eyel = batches[1]
        real_A_eyer = batches[2]
        real_A_nose = batches[3]
        real_A_mouth = batches[4]
        real_A_hair = batches[5]
        real_A_bg = batches[6]
        mask = batches[7]
        mask2 = batches[8]
        center = batches[9]
        real_B = batches[10]
        dt1gt = batches[11]
        dt2gt = batches[12]

        maskh = mask * mask2
        maskb = inverse_mask(mask2)

        # Adversarial ground truths
        valid = jt.ones([real_A.shape[0], 1]).stop_grad()
        fake = jt.zeros([real_A.shape[0], 1]).stop_grad()

        fake_B0 = G_global(real_A)
        # EYES, NOSE, MOUTH
        fake_B_eyel = G_l_eyel(real_A_eyel)
        fake_B_eyer = G_l_eyer(real_A_eyer)
        fake_B_nose = G_l_nose(real_A_nose)
        fake_B_mouth = G_l_mouth(real_A_mouth)
        # HAIR, BG AND PARTCOMBINE
        fake_B_hair = G_l_hair(real_A_hair)
        fake_B_bg = G_l_bg(real_A_bg)
        fake_B1 = partCombiner2_bg(center, fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair, fake_B_bg, maskh, maskb, comb_op=1, load_h=opt.img_height, load_w=opt.img_width)
        # FUSION NET
        fake_B = G_combine(jt.contrib.concat((fake_B0, fake_B1), 1))

        # ------------------
        #  Train Generators
        # ------------------
        # GAN loss
        pred_fake = D_global(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        fake_B_locals = getLocalParts(fake_B, center, maskh, maskb, load_h=opt.img_height, load_w=opt.img_width)
        real_A_locals = [real_A_eyel, real_A_eyer, real_A_nose, real_A_mouth, real_A_hair, real_A_bg]
        loss_GAN_local = 0
        for j in range(6):
            pred_fake_local = D_nets[j + 1](fake_B_locals[j], real_A_locals[j])
            loss_GAN_local += criterion_GAN(pred_fake_local, valid)

        # L1 loss
        loss_pixel = criterion_pixelwise(fake_B, real_B) * opt.lambda_L1

        # DT loss
        ## 1) d_CM(a_i,G(p_i))
        fake_B_gray = fake_B
        real_B_gray = real_B

        dt1 = DT1(fake_B_gray)
        dt2 = DT2(fake_B_gray)
        dt1 = dt1 / 2.0 + 0.5  # [-1,1]->[0,1]
        dt2 = dt2 / 2.0 + 0.5
        bs = real_B_gray.shape[0]
        real_B_gray_line1 = Line1(real_B_gray)
        real_B_gray_line2 = Line2(real_B_gray)
        loss_G_chamfer = (dt1[(real_B_gray < 0) & (real_B_gray_line1 < 0)].sum() + dt2[(real_B_gray >= 0) & (real_B_gray_line2 >= 0)].sum()) / bs * opt.lambda_chamfer

        ## 2) d_CM(G(p_i),a_i)
        fake_B_gray_line1 = Line1(fake_B_gray)
        fake_B_gray_line2 = Line2(fake_B_gray)
        loss_G_chamfer2 = (dt1gt[(fake_B_gray < 0) & (fake_B_gray_line1 < 0)].sum() + dt2gt[(fake_B_gray >= 0) & (fake_B_gray_line2 >= 0)].sum()) / bs * opt.lambda_chamfer2

        # Local loss
        real_B_locals = getLocalParts(real_B, center, maskh, maskb, load_h=opt.img_height, load_w=opt.img_width)
        loss_G_local = 0
        for j in range(6):
            loss_G_local += criterion_pixelwise(fake_B_locals[j], real_B_locals[j]) * opt.lambda_local

        # Total loss
        loss_G = loss_GAN + 0.25 * loss_GAN_local + loss_pixel + (loss_G_chamfer + loss_G_chamfer2) + loss_G_local
        optimizer_G.step(loss_G)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Real loss
        pred_real = D_global(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)
        loss_real_local = 0
        for j in range(6):
            pred_real_local = D_nets[j + 1](real_B_locals[j], real_A_locals[j])
            loss_real_local += criterion_GAN(pred_real_local, valid)
        # Fake loss
        pred_fake = D_global(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)
        loss_fake_local = 0
        for j in range(6):
            pred_fake_local = D_nets[j + 1](fake_B_locals[j].detach(), real_A_locals[j])
            loss_fake_local += criterion_GAN(pred_fake_local, fake)
        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake + loss_real_local + loss_fake_local)
        optimizer_D.step(loss_D)

        # Fix: compute batches_done before the warmup/benchmark branch — the
        # benchmark path below referenced it without it ever being assigned.
        batches_done = epoch * len(dataloader) + i

        if warmup_times == -1:
            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            jt.sync_all()
            if batches_done % 5 == 0:
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, pixel: %f, dt: %f, local: %f] ETA: %s"
                    % (
                        epoch,
                        opt.n_epochs,
                        i,
                        len(dataloader),
                        loss_D.numpy()[0],
                        loss_G.numpy()[0],
                        loss_GAN.numpy()[0],
                        loss_pixel.numpy()[0],
                        (loss_G_chamfer + loss_G_chamfer2).numpy()[0],
                        loss_G_local.numpy()[0],
                        time_left,
                    )
                )

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                img_sample = np.concatenate([real_A.data, fake_B.repeat(1, 3, 1, 1).data, real_B.repeat(1, 3, 1, 1).data, fake_B0.repeat(1, 3, 1, 1).data, fake_B1.repeat(1, 3, 1, 1).data, (dt1 * 2 - 1).repeat(1, 3, 1, 1).data, (dt2 * 2 - 1).repeat(1, 3, 1, 1).data, (dt1gt * 2 - 1).repeat(1, 3, 1, 1).data, (dt2gt * 2 - 1).repeat(1, 3, 1, 1).data], -2)
                save_image(img_sample, "images/%s/train_%s.jpg" % (opt.dataset_name, batches_done), nrow=1)
                sample_images(batches_done)
        else:
            # Benchmark mode: time run_times iterations after warmup_times warmup iterations.
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                exit(0)

            if batches_done % opt.sample_interval == 0:
                img_sample = np.concatenate([real_A.data, fake_B.repeat(1, 3, 1, 1).data, real_B.repeat(1, 3, 1, 1).data, fake_B0.repeat(1, 3, 1, 1).data, fake_B1.repeat(1, 3, 1, 1).data, (dt1 * 2 - 1).repeat(1, 3, 1, 1).data, (dt2 * 2 - 1).repeat(1, 3, 1, 1).data, (dt1gt * 2 - 1).repeat(1, 3, 1, 1).data, (dt2gt * 2 - 1).repeat(1, 3, 1, 1).data], -2)
                save_image(img_sample, "images/%s/train_%s.jpg" % (opt.dataset_name, batches_done), nrow=1)
                sample_images(batches_done)

    if opt.checkpoint_interval != -1 and (epoch + 1) % opt.checkpoint_interval == 0:
        # Save model checkpoints
        G_global.save("checkpoints/%s/%d_net_G_global.pkl" % (opt.dataset_name, epoch + 1))
        G_l_eyel.save("checkpoints/%s/%d_net_G_l_eyel.pkl" % (opt.dataset_name, epoch + 1))
        G_l_eyer.save("checkpoints/%s/%d_net_G_l_eyer.pkl" % (opt.dataset_name, epoch + 1))
        G_l_nose.save("checkpoints/%s/%d_net_G_l_nose.pkl" % (opt.dataset_name, epoch + 1))
        G_l_mouth.save("checkpoints/%s/%d_net_G_l_mouth.pkl" % (opt.dataset_name, epoch + 1))
        G_l_hair.save("checkpoints/%s/%d_net_G_l_hair.pkl" % (opt.dataset_name, epoch + 1))
        G_l_bg.save("checkpoints/%s/%d_net_G_l_bg.pkl" % (opt.dataset_name, epoch + 1))
        G_combine.save("checkpoints/%s/%d_net_G_combine.pkl" % (opt.dataset_name, epoch + 1))
        D_global.save("checkpoints/%s/%d_net_D_global.pkl" % (opt.dataset_name, epoch + 1))
        D_l_eyel.save("checkpoints/%s/%d_net_D_l_eyel.pkl" % (opt.dataset_name, epoch + 1))
        D_l_eyer.save("checkpoints/%s/%d_net_D_l_eyer.pkl" % (opt.dataset_name, epoch + 1))
        D_l_nose.save("checkpoints/%s/%d_net_D_l_nose.pkl" % (opt.dataset_name, epoch + 1))
        D_l_mouth.save("checkpoints/%s/%d_net_D_l_mouth.pkl" % (opt.dataset_name, epoch + 1))
        D_l_hair.save("checkpoints/%s/%d_net_D_l_hair.pkl" % (opt.dataset_name, epoch + 1))
        D_l_bg.save("checkpoints/%s/%d_net_D_l_bg.pkl" % (opt.dataset_name, epoch + 1))