├── dataset_utils ├── ref-dot.png └── dataset_util.py ├── docs ├── figures │ ├── dynamic1.gif │ ├── dynamic2.gif │ ├── framework.png │ ├── Krita │ │ ├── krita1.png │ │ ├── krita2.png │ │ └── krita3.png │ ├── inbetweens │ │ ├── 0.png │ │ ├── 5.png │ │ ├── 24.png │ │ ├── 27.png │ │ ├── 3-1.png │ │ ├── 3-2.png │ │ ├── 3-3.png │ │ ├── 4-1.png │ │ ├── 4-2.png │ │ ├── 21-1.png │ │ ├── 21-2.png │ │ ├── 25-1.png │ │ ├── 25-2.png │ │ ├── 0-dynamic.gif │ │ ├── 0-reference.png │ │ ├── 21-dynamic.gif │ │ ├── 24-dynamic.gif │ │ ├── 25-dynamic.gif │ │ ├── 27-dynamic.gif │ │ ├── 3-dynamic.gif │ │ ├── 3-reference.png │ │ ├── 4-dynamic2.gif │ │ ├── 4-reference.png │ │ ├── 5-dynamic.gif │ │ ├── 5-reference.png │ │ ├── 21-reference.png │ │ ├── 24-reference.png │ │ ├── 25-reference.png │ │ ├── 27-reference.png │ │ └── rough │ │ │ ├── 21-target1.png │ │ │ ├── 21-target2.png │ │ │ ├── 24-target.png │ │ │ ├── 25-target1.png │ │ │ ├── 25-target2.png │ │ │ ├── 27-target.png │ │ │ ├── 21-reference.png │ │ │ ├── 24-reference.png │ │ │ ├── 25-reference.png │ │ │ └── 27-reference.png │ ├── teaser-sub1.png │ ├── teaser-sub2.png │ └── CACANi │ │ ├── canani1.png │ │ ├── canani2.png │ │ ├── canani3.png │ │ ├── canani4.png │ │ ├── canani5.png │ │ ├── canani6.png │ │ └── canani7.png └── assets │ ├── font.css │ └── style.css ├── sample_inputs ├── clean │ ├── raster │ │ ├── 0-0.png │ │ ├── 0-1.png │ │ ├── 2-0.png │ │ ├── 2-1.png │ │ ├── 2-2.png │ │ ├── 3-0.png │ │ ├── 3-1.png │ │ ├── 3-2.png │ │ ├── 3-3.png │ │ ├── 4-0.png │ │ ├── 4-1.png │ │ ├── 4-2.png │ │ ├── 5-0.png │ │ ├── 5-1.png │ │ ├── 7-0.png │ │ ├── 7-1.png │ │ ├── 9-0.png │ │ ├── 9-1.png │ │ ├── 9-2.png │ │ ├── 11-0.png │ │ └── 11-1.png │ └── svg │ │ ├── 2-0.svg │ │ ├── 0-0.svg │ │ ├── 3-0.svg │ │ ├── 4-0.svg │ │ ├── 5-0.svg │ │ └── 11-0.svg └── rough │ ├── raster │ ├── 20-0.png │ ├── 20-1.png │ ├── 21-0.png │ ├── 21-1.png │ ├── 21-2.png │ ├── 21-3.png │ ├── 22-0.png │ ├── 22-1.png │ ├── 23-0.png │ ├── 23-1.png │ ├── 23-2.png │ ├── 24-0.png │ ├── 24-1.png │ ├── 25-0.png │ ├── 25-1.png │ ├── 25-2.png │ ├── 26-0.png │ ├── 26-1.png │ ├── 27-0.png │ ├── 27-1.png │ ├── 27-2.png │ ├── 28-0.png │ ├── 28-1.png │ ├── 30-0.png │ ├── 30-1.png │ ├── 34-0.png │ └── 34-1.png │ └── svg │ ├── 34-0.svg │ ├── 30-0.svg │ ├── 25-0.svg │ ├── 27-0.svg │ ├── 22-0.svg │ └── 24-0.svg ├── .gitignore ├── tools ├── image_squaring.py ├── make_inbetweening.py ├── vis_difference.py ├── npz_to_svg.py └── svg_to_npz.py ├── tutorials ├── Krita_vector_generation.md └── CACANi_inbetweening_generation.md ├── sketch_tracing_inference.py ├── README.md ├── vgg_utils └── VGG16.py ├── LICENSE └── rnn2.py /dataset_utils/ref-dot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/dataset_utils/ref-dot.png -------------------------------------------------------------------------------- /docs/figures/dynamic1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/dynamic1.gif -------------------------------------------------------------------------------- /docs/figures/dynamic2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/dynamic2.gif -------------------------------------------------------------------------------- /docs/figures/framework.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/framework.png -------------------------------------------------------------------------------- /docs/figures/Krita/krita1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita1.png -------------------------------------------------------------------------------- /docs/figures/Krita/krita2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita2.png -------------------------------------------------------------------------------- /docs/figures/Krita/krita3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita3.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5.png -------------------------------------------------------------------------------- /docs/figures/teaser-sub1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/teaser-sub1.png -------------------------------------------------------------------------------- /docs/figures/teaser-sub2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/teaser-sub2.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani1.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani2.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani3.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani4.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani5.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani6.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani6.png -------------------------------------------------------------------------------- /docs/figures/CACANi/canani7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani7.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-2.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/3-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-3.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/4-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-2.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/21-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/21-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-2.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/25-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/25-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-2.png 
-------------------------------------------------------------------------------- /sample_inputs/clean/raster/0-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/0-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/0-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/0-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/2-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/2-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-2.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/3-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-2.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/3-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-3.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/4-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/4-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-2.png -------------------------------------------------------------------------------- 
/sample_inputs/clean/raster/5-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/5-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/5-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/5-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/7-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/7-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/7-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/7-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/9-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/9-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-1.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/9-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-2.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/11-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/11-0.png -------------------------------------------------------------------------------- /sample_inputs/clean/raster/11-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/11-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/20-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/20-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/20-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/20-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/21-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/21-1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/21-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-2.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/21-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-3.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/22-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/22-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/22-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/22-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/23-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/23-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/23-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-2.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/24-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/24-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/24-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/24-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/25-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/25-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/25-2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-2.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/26-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/26-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/26-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/26-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/27-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/27-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/27-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-2.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/28-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/28-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/28-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/30-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/30-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/30-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/30-1.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/34-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/34-0.png -------------------------------------------------------------------------------- /sample_inputs/rough/raster/34-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/34-1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/0-dynamic.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/0-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/21-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/24-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/25-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/27-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/3-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/3-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/4-dynamic2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-dynamic2.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/4-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/5-dynamic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5-dynamic.gif -------------------------------------------------------------------------------- /docs/figures/inbetweens/5-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5-reference.png -------------------------------------------------------------------------------- 
/docs/figures/inbetweens/21-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/24-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/25-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/27-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/21-target1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-target1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/21-target2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-target2.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/24-target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/24-target.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/25-target1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-target1.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/25-target2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-target2.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/27-target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/27-target.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/21-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/24-reference.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/24-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/25-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-reference.png -------------------------------------------------------------------------------- /docs/figures/inbetweens/rough/27-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/27-reference.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .idea/ 3 | data/ 4 | datas/ 5 | dataset/ 6 | datasets/ 7 | model/ 8 | models/ 9 | testData/ 10 | output/ 11 | outputs/ 12 | 13 | *.csv 14 | 15 | # temporary files 16 | *.txt~ 17 | *.pyc 18 | .DS_Store 19 | .gitignore~ 20 | 21 | *.h5 -------------------------------------------------------------------------------- /tools/image_squaring.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import argparse 4 | 5 | 6 | def image_squaring(img_path): 7 | img = Image.open(img_path).convert('RGB') 8 | height, width = img.height, img.width 9 | 10 | max_dim = max(height, width) 11 | 12 | pad_top = (max_dim - height) // 2 13 | pad_down = max_dim - height - pad_top 14 | pad_left = (max_dim - width) // 2 15 | pad_right = max_dim - width - pad_left 16 | 17 | img = np.array(img, dtype=np.uint8) 18 | img_p = np.pad(img, ((pad_top, pad_down), (pad_left, pad_right), (0, 0)), 'constant', constant_values=255) 19 | img_p = Image.fromarray(img_p, 'RGB') 20 | img_p.save(img_path[:-4] + '-pad.png', 'PNG') 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--file', '-f', type=str, default='', help="define an image") 26 | args = parser.parse_args() 27 | 28 | image_squaring(args.file) 29 | -------------------------------------------------------------------------------- /tutorials/Krita_vector_generation.md: -------------------------------------------------------------------------------- 1 | # Make Vector Reference Frame with Krita 2 | 3 | 1. Install [Krita](https://krita.org/en/) 4 | 2. New File - Create (arbitrary configurations) 5 | 3. In Layers panel, remove the default Paint Layer. Then, add a File Layer. Select your reference image, and choose the 'Scale to Image Size' option. You will see image now. Remember to lock the layer in the panel. 6 | 4. Add a Vector Layer and move it above the File Layer. Then, choose the 'Bezier Curve Tool' on the left. Select any color you like and set the Size of the curve. 7 | 8 | 9 | 10 | 5. Trace the drawing. Set all the intermediate points to 'Corner point' for better control. Then, drag them to align the curves with the drawing. 11 | 12 | 13 | 14 | 6. After tracing the entire drawing, save it to an svg file: Layer - Import/Export - Save Vector Layer as SVG. Done! 
15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tutorials/CACANi_inbetweening_generation.md: -------------------------------------------------------------------------------- 1 | # Make Inbetweening and Animation with CACANi 2 | 3 | 1. Install [CACANi](https://cacani.sg/). 4 | 2. Add keyframes according to the number of SVGs. There are 3 keyframes now: 5 | 6 | 7 | 8 | 3. Import SVG files in `separate/` folder one by one: File - Import - CACS/SVG File 9 | 10 | 11 | 12 | 4. Note that the imported keyframe is in a new cell. You should copy all the vector strokes to the default cell (Cel 1 - Layer 1): 13 | 14 | 15 | 16 | 5. Delete the cell above. 17 | 18 | 19 | 20 | 6. Repeat step-3 to step-5 to import the other SVGs to the corresponding keyframes. Then, add intermediate frames: 21 | 22 | 23 | 24 | 7. Select the empty intermediate frames between two keyframes, and generate the inbetweens. Repeat the process for the other empty intermediate frames. 25 | 26 | 27 | 28 | 8. Finally, broadcast the animation. You can choose different animation modes and set different frame rates. 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/assets/font.css: -------------------------------------------------------------------------------- 1 | /* Homepage Font */ 2 | 3 | /* latin-ext */ 4 | @font-face { 5 | font-family: 'Lato'; 6 | font-style: normal; 7 | font-weight: 400; 8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); 9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 10 | } 11 | 12 | /* latin */ 13 | @font-face { 14 | font-family: 'Lato'; 15 | font-style: normal; 16 | font-weight: 400; 17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); 18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 19 | } 20 | 21 | /* latin-ext */ 22 | @font-face { 23 | font-family: 'Lato'; 24 | font-style: normal; 25 | font-weight: 700; 26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); 27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 28 | } 29 | 30 | /* latin */ 31 | @font-face { 32 | font-family: 'Lato'; 33 | font-style: normal; 34 | font-weight: 700; 35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); 36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 37 | } 38 | -------------------------------------------------------------------------------- /sketch_tracing_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import six 3 | import argparse 4 | 5 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 6 | os.environ["KMP_WARNINGS"] = "0" 7 | 8 | import model_full as sketch_tracing_model 9 | from utils import load_dataset 10 | 11 | 12 | def main(dataset_base, data_type, img_seq, mode='inference'): 13 | model_params = 
sketch_tracing_model.get_default_hparams() 14 | model_params.add_hparam('dataset_base', dataset_base) 15 | model_params.add_hparam('data_type', data_type) 16 | model_params.add_hparam('img_seq', img_seq) 17 | 18 | print('Hyperparams:') 19 | for key, val in six.iteritems(model_params.values()): 20 | print('%s = %s' % (key, str(val))) 21 | print('-' * 100) 22 | 23 | val_set = load_dataset(model_params) 24 | model = sketch_tracing_model.FullModel(model_params, val_set) 25 | 26 | if mode == 'inference': 27 | inference_root = model_params.inference_root 28 | sub_inference_root = os.path.join(inference_root, model_params.data_type) 29 | os.makedirs(sub_inference_root, exist_ok=True) 30 | model.inference(sub_inference_root, model_params.img_seq) 31 | # elif mode == 'test': 32 | # model.evaluate(load_trained_weights=True) 33 | else: 34 | raise Exception('Unknown mode:', mode) 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--dataset_base', '-db', type=str, default='sample_inputs', help="define the data base") 40 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'], help="define the data type") 41 | parser.add_argument('--img_seq', '-is', type=str, default=['23-0.png', '23-1.png', '23-2.png'], nargs='+', help="define the image sequence") 42 | 43 | args = parser.parse_args() 44 | 45 | main(args.dataset_base, args.data_type, args.img_seq) 46 | -------------------------------------------------------------------------------- /tools/make_inbetweening.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def gen_intensity_list(max_inten, min_inten, num): 7 | interval = (max_inten - min_inten) // (num - 1) 8 | intensity_list = [min_inten + i * interval for i in range(num)] 9 | intensity_list = intensity_list[::-1] 10 | return intensity_list 11 | 12 | 13 | def make_inbetweening_img(data_base, image_sequence): 14 | max_intensity = 200 15 | min_intensity = 0 16 | black_threshold = 128 17 | 18 | img_num = len(image_sequence) 19 | 20 | intensity_list = gen_intensity_list(max_intensity, min_intensity, img_num) 21 | print('intensity_list', intensity_list) 22 | 23 | img_inbetween = None 24 | 25 | for i, img_name in enumerate(image_sequence): 26 | img_path = os.path.join(data_base, img_name) 27 | img = Image.open(img_path).convert('RGB') 28 | img = np.array(img, dtype=np.uint8)[:, :, 0] 29 | 30 | intensity = intensity_list[i] 31 | 32 | if img_inbetween is None: 33 | img_inbetween = np.ones_like(img) * 255 34 | 35 | img_inbetween[img <= black_threshold] = intensity 36 | 37 | img_inbetween = Image.fromarray(img_inbetween, 'L') 38 | save_path = os.path.join(data_base, 'inbetweening.png') 39 | img_inbetween.save(save_path) 40 | 41 | 42 | def make_inbetweening_gif(data_base, image_sequence): 43 | all_files = image_sequence + image_sequence[::-1][1:] 44 | 45 | gif_frames = [] 46 | for img_name in all_files: 47 | img_i = Image.open(os.path.join(data_base, img_name)) 48 | gif_frames.append(img_i) 49 | 50 | print('gif_frames', len(gif_frames)) 51 | save_path = os.path.join(data_base, 'inbetweening.gif') 52 | first_frame = gif_frames[0] 53 | first_frame.save(save_path, save_all=True, append_images=gif_frames, loop=0, duration=0.01) 54 | 55 | 56 | if __name__ == '__main__': 57 | data_base = '../outputs/inference/clean/inbetweening/9' 58 | image_sequence = ['10.png', '18.png', '25.png', '32.png', '40.png'] 59 | 60 | 
make_inbetweening_img(data_base, image_sequence) 61 | make_inbetweening_gif(data_base, image_sequence) 62 | -------------------------------------------------------------------------------- /tools/vis_difference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import argparse 5 | 6 | 7 | def vis_difference(database_input, data_type, reference_image, target_image): 8 | data_base = os.path.join(database_input, data_type) 9 | raster_base = os.path.join(data_base, 'raster') 10 | raster_diff_base = os.path.join(data_base, 'raster_diff') 11 | os.makedirs(raster_diff_base, exist_ok=True) 12 | 13 | ref_img_path = os.path.join(raster_base, reference_image) 14 | tar_img_path = os.path.join(raster_base, target_image) 15 | 16 | ref_img = Image.open(ref_img_path).convert('L') 17 | ref_img = np.array(ref_img, dtype=np.float32) 18 | tar_img = Image.open(tar_img_path).convert('RGB') 19 | tar_img = np.array(tar_img, dtype=np.float32) 20 | 21 | target_mask = (tar_img < 200).any(-1) 22 | ref_img_bg = np.expand_dims(ref_img, axis=-1).astype(np.float32) 23 | ref_img_bg = np.concatenate( 24 | [np.ones_like(ref_img_bg) * 255, 25 | ref_img_bg, 26 | np.ones_like(ref_img_bg) * 255], axis=-1) 27 | ref_img_bg = 255 - (255 - ref_img_bg) * 0.5 28 | ref_img_bg[target_mask] = tar_img[target_mask] 29 | ref_img_bg = ref_img_bg.astype(np.uint8) 30 | 31 | save_path = os.path.join(raster_diff_base, reference_image[:-4] + '_and_' + target_image[:-4] + '.png') 32 | ref_img_bg_png = Image.fromarray(ref_img_bg, 'RGB') 33 | ref_img_bg_png.save(save_path, 'PNG') 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--database_input', '-dbi', type=str, default='../sample_inputs', 39 | help="define the input data base") 40 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'], 41 | help="define the data type") 42 | parser.add_argument('--reference_image', '-ri', type=str, default='23-0.png', 43 | help="define the reference image") 44 | parser.add_argument('--target_image', '-ti', type=str, default='23-1.png', 45 | help="define the target image") 46 | 47 | args = parser.parse_args() 48 | 49 | vis_difference(args.database_input, args.data_type, args.reference_image, args.target_image) 50 | -------------------------------------------------------------------------------- /docs/assets/style.css: -------------------------------------------------------------------------------- 1 | /* Body */ 2 | body { 3 | background: #e3e5e8; 4 | color: #ffffff; 5 | font-family: 'Lato', Verdana, Helvetica, sans-serif; 6 | font-weight: 300; 7 | font-size: 14pt; 8 | } 9 | 10 | /* Hyperlinks */ 11 | a {text-decoration: none;} 12 | a:link {color: #1772d0;} 13 | a:visited {color: #1772d0;} 14 | a:active {color: red;} 15 | a:hover {color: #f09228;} 16 | 17 | /* Pre-formatted Text */ 18 | pre { 19 | margin: 5pt 0; 20 | border: 0; 21 | font-size: 12pt; 22 | background: #fcfcfc; 23 | } 24 | 25 | /* Project Page Style */ 26 | /* Section */ 27 | .section { 28 | width: 968pt; 29 | min-height: 100pt; 30 | margin: 15pt auto; 31 | padding: 20pt 30pt; 32 | border: 1pt hidden #000; 33 | text-align: justify; 34 | color: #000000; 35 | background: #ffffff; 36 | } 37 | 38 | /* Header (Title and Logo) */ 39 | .section .header { 40 | min-height: 80pt; 41 | margin-top: 30pt; 42 | } 43 | .section .header .logo { 44 | width: 80pt; 45 | margin-left: 10pt; 46 | float: left; 47 | } 48 | 
.section .header .logo img { 49 | width: 80pt; 50 | object-fit: cover; 51 | } 52 | .section .header .title { 53 | margin: 0 120pt; 54 | text-align: center; 55 | font-size: 22pt; 56 | } 57 | 58 | /* Author */ 59 | .section .author { 60 | margin: 5pt 0; 61 | text-align: center; 62 | font-size: 16pt; 63 | } 64 | 65 | /* Institution */ 66 | .section .institution { 67 | margin: 5pt 0; 68 | text-align: center; 69 | font-size: 16pt; 70 | } 71 | 72 | /* Hyperlink (such as Paper and Code) */ 73 | .section .link { 74 | margin: 5pt 0; 75 | text-align: center; 76 | font-size: 16pt; 77 | } 78 | 79 | /* Teaser */ 80 | .section .teaser { 81 | margin: 20pt 0; 82 | text-align: left; 83 | } 84 | .section .teaser img { 85 | width: 95%; 86 | } 87 | 88 | /* Section Title */ 89 | .section .title { 90 | text-align: center; 91 | font-size: 22pt; 92 | margin: 5pt 0 15pt 0; /* top right bottom left */ 93 | } 94 | 95 | /* Section Body */ 96 | .section .body { 97 | margin-bottom: 15pt; 98 | text-align: justify; 99 | font-size: 14pt; 100 | } 101 | 102 | /* BibTeX */ 103 | .section .bibtex { 104 | margin: 5pt 0; 105 | text-align: left; 106 | font-size: 22pt; 107 | } 108 | 109 | /* Related Work */ 110 | .section .ref { 111 | margin: 20pt 0 10pt 0; /* top right bottom left */ 112 | text-align: left; 113 | font-size: 18pt; 114 | font-weight: bold; 115 | } 116 | 117 | /* Citation */ 118 | .section .citation { 119 | min-height: 60pt; 120 | margin: 10pt 0; 121 | } 122 | .section .citation .image { 123 | width: 120pt; 124 | float: left; 125 | } 126 | .section .citation .image img { 127 | max-height: 60pt; 128 | width: 120pt; 129 | object-fit: cover; 130 | } 131 | .section .citation .comment{ 132 | margin-left: 0pt; 133 | text-align: left; 134 | font-size: 14pt; 135 | } 136 | -------------------------------------------------------------------------------- /dataset_utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | import pydiffvg 2 | import torch 3 | 4 | from dataset_utils.common import generate_colors2, generate_colors_ordered 5 | 6 | 7 | def draw_segment_jointly(strokes_points_list, canvas_size, stroke_thickness, render, background, is_black=True): 8 | shapes = [] 9 | shape_groups = [] 10 | 11 | stroke_num = len(strokes_points_list) 12 | colors = generate_colors2(stroke_num) # list of (3), in [0., 1.] 
13 | 14 | for i in range(stroke_num): 15 | points_single_stroke = strokes_points_list[i] # (N, 2) 16 | points_single_stroke = torch.tensor(points_single_stroke) 17 | 18 | segment_num = (len(points_single_stroke) - 1) // 3 19 | if is_black: 20 | stroke_color = [0.0, 0.0, 0.0, 1.0] 21 | else: 22 | stroke_color = [colors[i][0], colors[i][1], colors[i][2], 1.0] 23 | 24 | num_control_points = torch.zeros(segment_num, dtype=torch.int32) + 2 25 | 26 | path = pydiffvg.Path(num_control_points=num_control_points, 27 | points=points_single_stroke, 28 | is_closed=False, 29 | stroke_width=torch.tensor(stroke_thickness)) 30 | shapes.append(path) 31 | 32 | path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 33 | fill_color=None, 34 | stroke_color=torch.tensor(stroke_color)) 35 | shape_groups.append(path_group) 36 | 37 | scene_args_t = pydiffvg.RenderFunction.serialize_scene( 38 | canvas_size, canvas_size, shapes, shape_groups) 39 | img = render(canvas_size, # width 40 | canvas_size, # height 41 | 2, # num_samples_x 42 | 2, # num_samples_y 43 | 1, # seed 44 | background, # background_image 45 | *scene_args_t) 46 | return img 47 | 48 | 49 | def draw_segment_separately(strokes_points_list, canvas_size, stroke_thickness, max_seq_number, 50 | render, background, is_black=False, draw_order=False): 51 | shapes = [] 52 | shape_groups = [] 53 | 54 | stroke_num = len(strokes_points_list) 55 | 56 | if draw_order: 57 | colors = generate_colors_ordered(max_seq_number) # list of (3), in [0., 1.] 58 | else: 59 | colors = generate_colors2(max_seq_number) # list of (3), in [0., 1.] 60 | global_segment_idx = 0 61 | 62 | for i in range(stroke_num): 63 | points_single_stroke = strokes_points_list[i] # (N, 2) 64 | points_single_stroke = torch.tensor(points_single_stroke) 65 | 66 | segment_num = (len(points_single_stroke) - 1) // 3 67 | 68 | for j in range(segment_num): 69 | start_idx = j * 3 70 | points_single_segment = points_single_stroke[start_idx: start_idx + 4] 71 | 72 | if is_black: 73 | stroke_color = [0.0, 0.0, 0.0, 1.0] 74 | else: 75 | stroke_color = [colors[global_segment_idx][0], colors[global_segment_idx][1], colors[global_segment_idx][2], 1.0] 76 | 77 | num_control_points = torch.tensor([2]) 78 | 79 | path = pydiffvg.Path(num_control_points=num_control_points, 80 | points=points_single_segment, 81 | is_closed=False, 82 | stroke_width=torch.tensor(stroke_thickness)) 83 | shapes.append(path) 84 | 85 | path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 86 | fill_color=None, 87 | stroke_color=torch.tensor(stroke_color)) 88 | shape_groups.append(path_group) 89 | 90 | global_segment_idx += 1 91 | 92 | scene_args_t = pydiffvg.RenderFunction.serialize_scene( 93 | canvas_size, canvas_size, shapes, shape_groups) 94 | img = render(canvas_size, # width 95 | canvas_size, # height 96 | 2, # num_samples_x 97 | 2, # num_samples_y 98 | 1, # seed 99 | background, # background_image 100 | *scene_args_t) 101 | return img 102 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/2-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/0-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- 
/sample_inputs/rough/svg/34-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/rough/svg/30-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Joint Stroke Tracing and Correspondence for 2D Animation - TOG & SIGGRAPH 2024 2 | 3 | [[Paper]](https://www.sysu-imsl.com/files/TOG2024/SketchTracing_TOG2024_personal.pdf) | [[Paper (ACM)]](https://dl.acm.org/doi/10.1145/3649890) | [[Project Page]](https://markmohr.github.io/JoSTC/) 4 | 5 | This code is used for producing stroke tracing and correspondence results, which can be imported into an inbetweening product named [CACANi](https://cacani.sg) for making 2D animations. 6 | 7 | 8 | 9 | 10 | 11 | ## Outline 12 | - [Dependencies](#dependencies) 13 | - [Quick Start](#quick-start) 14 | - [Vector Stroke Correspondence Dataset](#vector-stroke-correspondence-dataset) 15 | - [Citation](#citation) 16 | 17 | ## Dependencies 18 | - [cudatoolkit](https://www.anaconda.com/download/) == 11.0.3 19 | - [cudnn](https://www.anaconda.com/download/) == 8.4.1.50 20 | - [pytorch](https://pytorch.org/) == 1.9.0 21 | - [torchvision](https://pytorch.org/vision/0.9/) == 0.9.0 22 | - [diffvg](https://github.com/BachiLi/diffvg) 23 | - [Krita](https://krita.org/en/): for making reference vector frame 24 | - [CACANi](https://cacani.sg/): for making inbetweening and 2D animation 25 | 26 | ## Quick Start 27 | 28 | ### Model Preparation 29 | 30 | Download the models [here](https://drive.google.com/drive/folders/15oAP7YbNKx4Cx1AmzC16wuoyQoAai2YV?usp=sharing), and place them in this file structure: 31 | ``` 32 | models/ 33 | quickdraw-perceptual.pth 34 | point_matching_model/ 35 | sketch_correspondence_50000.pkl 36 | transform_module/ 37 | sketch_transform_30000.pkl 38 | stroke_tracing_model/ 39 | sketch_tracing_30000.pkl 40 | ``` 41 | 42 | ### Create Reference Vector Frames with Krita 43 | 44 | Our method takes as inputs consecutive raster keyframes and a single vector drawing from the starting keyframe, and then generates vector images for the remaining keyframes with one-to-one stroke correspondence. So we have to create the vector image for the reference frame here. 45 | 46 | **Note**: We provide several examples for testing in directory `sample_inputs/`. If you use them, you can skip step-1 and step-2 below and execute step-3 directly. 47 | 48 | 1. Our method takes squared images as input, so please preprocess the images first using [tools/image_squaring.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/image_squaring.py): 49 | ``` 50 | python3 tools/image_squaring.py --file path/to/the/image.png 51 | ``` 52 | 53 | 2. Follow tutorial [here](https://github.com/MarkMoHR/JoSTC/blob/main/tutorials/Krita_vector_generation.md) to make vector frames as a reference in svg format with [Krita](https://krita.org/en/). 54 | 3. Place the svg files in `sample_inputs/*/svg/`. 
Then, convert them into npz format using [tools/svg_to_npz.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/svg_to_npz.py): 55 | ``` 56 | cd tools/ 57 | python3 svg_to_npz.py --database ../sample_inputs/rough/ --reference 23-0.png 58 | ``` 59 | 60 | ### Execute the Main Code 61 | 62 | Perform joint stroke tracing and correspondence using [sketch_tracing_inference.py](https://github.com/MarkMoHR/JoSTC/blob/main/sketch_tracing_inference.py). We provide several examples for testing in directory `sample_inputs/`. 63 | ``` 64 | python3 sketch_tracing_inference.py --dataset_base sample_inputs --data_type rough --img_seq 23-0.png 23-1.png 23-2.png 65 | ``` 66 | - `--data_type`: specify the image type with `clean` or `rough`. 67 | - `--img_seq`: specify the animation frames here. The first one should be the reference frame. 68 | - The results are placed in `outputs/inference/*`. Inside this folder: 69 | - `raster/` stores rendered line drawings of the target vector frames. 70 | - `rgb/` stores visualization (with target images underneath) of the vector stroke correspondence to the reference frame. 71 | - `rgb-wo-bg/` stores visualization without target images underneath. 72 | - `parameter/` stores vector stroke parameters. 73 | 74 | ### Create Inbetweening and Animation with CACANi 75 | 76 | 1. Convert the output npz file (vector stroke parameters) into svg format using [tools/npz_to_svg.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/npz_to_svg.py): 77 | ``` 78 | cd tools/ 79 | python3 npz_to_svg.py --database_input ../sample_inputs/ --database_output ../outputs/inference/ --data_type rough --file_names 23-1.png 23-2.png 80 | ``` 81 | - The results are placed in `outputs/inference/*/svg`. There are two kinds of results: 82 | - `chain/`: the svg files store stroke chains, defining each path as a chain. 83 | - `separate/`: the svg files store separated strokes, defining each path as a single stroke. **Note that the automatic inbetweening in [CACANi](https://cacani.sg/) relies on this format.** 84 | 85 | 2. Follow tutorial [here](https://github.com/MarkMoHR/JoSTC/blob/main/tutorials/CACANi_inbetweening_generation.md) to generate inbetweening and 2D animation with [CACANi](https://cacani.sg/). 86 | 87 | ### More Tools 88 | 89 | - [tools/vis_difference.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/vis_difference.py): visualize difference between the reference image and the target one. The results are placed in `sample_inputs/*/raster_diff/` 90 | ``` 91 | cd tools/ 92 | python3 vis_difference.py --database_input ../sample_inputs --data_type rough --reference_image 23-0.png --target_image 23-1.png 93 | ``` 94 | 95 | - [tools/make_inbetweening.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/make_inbetweening.py): visualize the inbetweening in a single image or a gif file. 96 | 97 |
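  For `make_inbetweening.py`, note that the script has no command-line arguments: edit the `data_base` and `image_sequence` variables in its `__main__` block so they point at a folder of your own rendered frames, then run it from `tools/`. A minimal sketch follows (the values shown are simply the defaults already hard-coded at the bottom of the script):
  ```
  cd tools/
  # at the bottom of make_inbetweening.py, set for example:
  #   data_base = '../outputs/inference/clean/inbetweening/9'
  #   image_sequence = ['10.png', '18.png', '25.png', '32.png', '40.png']
  python3 make_inbetweening.py
  ```
  It writes `inbetweening.png` (all frames overlaid at different gray levels) and `inbetweening.gif` (a back-and-forth animation of the sequence) into `data_base`.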
98 | 99 | ## Vector Stroke Correspondence Dataset 100 | 101 | For training, we collected a dataset with 10k+ pairs of raster frames and their vector drawings with stroke correspondence. Please download it [here](https://drive.google.com/drive/folders/15oAP7YbNKx4Cx1AmzC16wuoyQoAai2YV?usp=sharing). We provide reference code in [dataset_utils/tuberlin_dataset_util.py](https://github.com/MarkMoHR/JoSTC/blob/main/dataset_utils/tuberlin_dataset_util.py) showing how to use the data. 102 | 103 |
104 | 105 | ## Citation 106 | 107 | If you use the code and models please cite: 108 | 109 | ``` 110 | @article{mo2024joint, 111 | title={Joint Stroke Tracing and Correspondence for 2D Animation}, 112 | author={Mo, Haoran and Gao, Chengying and Wang, Ruomei}, 113 | journal={ACM Transactions on Graphics}, 114 | volume={43}, 115 | number={3}, 116 | pages={1--17}, 117 | year={2024}, 118 | publisher={ACM New York, NY} 119 | } 120 | ``` 121 | 122 | 123 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/3-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /tools/npz_to_svg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import random 5 | import argparse 6 | 7 | from xml.dom import minidom 8 | 9 | 10 | def write_svg(points_list, height, width, svg_save_path): 11 | impl_save = minidom.getDOMImplementation() 12 | 13 | doc_save = impl_save.createDocument(None, None, None) 14 | 15 | rootElement_save = doc_save.createElement('svg') 16 | rootElement_save.setAttribute('xmlns', 'http://www.w3.org/2000/svg') 17 | 18 | rootElement_save.setAttribute('height', str(height) + 'pt') 19 | rootElement_save.setAttribute('width', str(width) + 'pt') 20 | 21 | view_box = '0 0 ' + str(width) + ' ' + str(height) 22 | rootElement_save.setAttribute('viewBox', view_box) 23 | 24 | globl_path_i = 0 25 | for stroke_i, stroke_points in enumerate(points_list): 26 | # stroke_points: (N_point, 2), in image size 27 | segment_num = (stroke_points.shape[0] - 1) // 3 28 | 29 | for segment_i in range(segment_num): 30 | start_idx = segment_i * 3 31 | start_point = stroke_points[start_idx] 32 | ctrl_point1 = stroke_points[start_idx + 1] 33 | ctrl_point2 = stroke_points[start_idx + 2] 34 | end_point = stroke_points[start_idx + 3] 35 | 36 | command_str = 'M ' + str(start_point[0]) + ', ' + str(start_point[1]) + ' ' 37 | command_str += 'C ' + str(ctrl_point1[0]) + ', ' + str(ctrl_point1[1]) + ' ' \ 38 | + str(ctrl_point2[0]) + ', ' + str(ctrl_point2[1]) + ' ' \ 39 | + str(end_point[0]) + ', ' + str(end_point[1]) + ' ' 40 | 41 | childElement_save = doc_save.createElement('path') 42 | childElement_save.setAttribute('id', 'curve_' + str(globl_path_i)) 43 | childElement_save.setAttribute('stroke', '#000000') 44 | childElement_save.setAttribute('stroke-linejoin', 'round') 45 | childElement_save.setAttribute('stroke-linecap', 'square') 46 | childElement_save.setAttribute('fill', 'none') 47 | 48 | childElement_save.setAttribute('d', command_str) 49 | childElement_save.setAttribute('stroke-width', str(2.375)) 50 | rootElement_save.appendChild(childElement_save) 51 | 52 | globl_path_i += 1 53 | 54 | doc_save.appendChild(rootElement_save) 55 | 56 | f = open(svg_save_path, 'w') 57 | doc_save.writexml(f, addindent=' ', newl='\n') 58 | f.close() 59 | 60 | 61 | def write_svg_chain(points_list, height, width, svg_save_path): 62 | impl_save = minidom.getDOMImplementation() 63 | 64 | doc_save = impl_save.createDocument(None, None, None) 65 | 66 | rootElement_save = doc_save.createElement('svg') 67 | rootElement_save.setAttribute('xmlns', 'http://www.w3.org/2000/svg') 68 | 69 | rootElement_save.setAttribute('height', str(height) + 'pt') 70 | rootElement_save.setAttribute('width', str(width) + 'pt') 71 | 72 | view_box = '0 0 ' + str(width) + ' ' + 
str(height) 73 | rootElement_save.setAttribute('viewBox', view_box) 74 | 75 | for stroke_i, stroke_points in enumerate(points_list): 76 | # stroke_points: (N_point, 2), in image size 77 | segment_num = (stroke_points.shape[0] - 1) // 3 78 | 79 | command_str = 'M ' + str(stroke_points[0][0]) + ', ' + str(stroke_points[0][1]) + ' ' 80 | 81 | for segment_i in range(segment_num): 82 | start_idx = segment_i * 3 83 | ctrl_point1 = stroke_points[start_idx + 1] 84 | ctrl_point2 = stroke_points[start_idx + 2] 85 | end_point = stroke_points[start_idx + 3] 86 | 87 | command_str += 'C ' + str(ctrl_point1[0]) + ', ' + str(ctrl_point1[1]) + ' ' \ 88 | + str(ctrl_point2[0]) + ', ' + str(ctrl_point2[1]) + ' ' \ 89 | + str(end_point[0]) + ', ' + str(end_point[1]) + ' ' 90 | 91 | childElement_save = doc_save.createElement('path') 92 | childElement_save.setAttribute('id', 'curve_' + str(stroke_i)) 93 | childElement_save.setAttribute('stroke', '#000000') 94 | childElement_save.setAttribute('stroke-linejoin', 'round') 95 | childElement_save.setAttribute('stroke-linecap', 'square') 96 | childElement_save.setAttribute('fill', 'none') 97 | 98 | childElement_save.setAttribute('d', command_str) 99 | childElement_save.setAttribute('stroke-width', str(2.375)) 100 | rootElement_save.appendChild(childElement_save) 101 | 102 | doc_save.appendChild(rootElement_save) 103 | 104 | f = open(svg_save_path, 'w') 105 | doc_save.writexml(f, addindent=' ', newl='\n') 106 | f.close() 107 | 108 | 109 | def npz_to_svg(file_names, database_input, database_output, data_type): 110 | svg_chain_save_base = os.path.join(database_output, data_type, 'svg', 'chain') 111 | svg_separate_save_base = os.path.join(database_output, data_type, 'svg', 'separate') 112 | os.makedirs(svg_chain_save_base, exist_ok=True) 113 | os.makedirs(svg_separate_save_base, exist_ok=True) 114 | 115 | for file_name in file_names: 116 | file_id = file_name[:file_name.find('.')] 117 | 118 | output_image_path = os.path.join(database_output, data_type, 'raster', file_name) 119 | output_image = Image.open(output_image_path) 120 | width, height = output_image.width, output_image.height 121 | 122 | # output 123 | npz_path = os.path.join(database_output, data_type, 'parameter', file_id + '.npz') 124 | npz = np.load(npz_path, encoding='latin1', allow_pickle=True) 125 | strokes_data_output = npz['strokes_data'] # list of (N_point, 2), in image size 126 | 127 | svg_chain_output_save_path = os.path.join(svg_chain_save_base, file_id + '.svg') 128 | write_svg_chain(strokes_data_output, height, width, svg_chain_output_save_path) 129 | svg_separate_output_save_path = os.path.join(svg_separate_save_base, file_id + '.svg') 130 | write_svg(strokes_data_output, height, width, svg_separate_output_save_path) 131 | 132 | # rewrite input 133 | file_id = file_name[:file_name.find('-')] 134 | 135 | npz_path = os.path.join(database_input, data_type, 'parameter', file_id + '-0.npz') 136 | npz = np.load(npz_path, encoding='latin1', allow_pickle=True) 137 | strokes_data_input = npz['strokes_data'] # list of (N_point, 2), in image size 138 | 139 | svg_chain_input_save_path = os.path.join(svg_chain_save_base, file_id + '-0.svg') 140 | write_svg_chain(strokes_data_input, height, width, svg_chain_input_save_path) 141 | svg_chain_input_save_path = os.path.join(svg_separate_save_base, file_id + '-0.svg') 142 | write_svg(strokes_data_input, height, width, svg_chain_input_save_path) 143 | 144 | 145 | if __name__ == '__main__': 146 | parser = argparse.ArgumentParser() 147 | 
parser.add_argument('--database_input', '-dbi', type=str, default='../sample_inputs', help="define the input data base") 148 | parser.add_argument('--database_output', '-dbo', type=str, default='../outputs/inference', help="define the output data base") 149 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'], help="define the data type") 150 | parser.add_argument('--file_names', '-fn', type=str, default=['23-1.png', '23-2.png'], nargs='+', help="define the file names") 151 | 152 | args = parser.parse_args() 153 | 154 | npz_to_svg(args.file_names, args.database_input, args.database_output, args.data_type) 155 | -------------------------------------------------------------------------------- /vgg_utils/VGG16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | class VGG_Slim(nn.Module): 9 | def __init__(self): 10 | super(VGG_Slim, self).__init__() 11 | 12 | self.conv11 = nn.Conv2d(1, 64, 3, 1, 1) 13 | self.conv12 = nn.Conv2d(64, 64, 3, 1, 1) 14 | 15 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 16 | 17 | self.conv21 =nn.Conv2d(64, 128, 3, 1, 1) 18 | self.conv22 =nn.Conv2d(128, 128, 3, 1, 1) 19 | 20 | self.conv31 =nn.Conv2d(128, 256, 3, 1, 1) 21 | self.conv32 =nn.Conv2d(256, 256, 3, 1, 1) 22 | self.conv33 =nn.Conv2d(256, 256, 3, 1, 1) 23 | 24 | self.conv41 =nn.Conv2d(256, 512, 3, 1, 1) 25 | self.conv42 =nn.Conv2d(512, 512, 3, 1, 1) 26 | self.conv43 =nn.Conv2d(512, 512, 3, 1, 1) 27 | 28 | self.conv51 =nn.Conv2d(512, 512, 3, 1, 1) 29 | self.conv52 =nn.Conv2d(512, 512, 3, 1, 1) 30 | self.conv53 =nn.Conv2d(512, 512, 3, 1, 1) 31 | 32 | def forward(self, input_imgs): 33 | return_map = {} 34 | 35 | x = input_imgs 36 | if x.dim() == 3: 37 | x = x.unsqueeze(dim=1) # NCHW 38 | 39 | x = F.relu(self.conv11(x)) 40 | return_map['ReLU1_1'] = x 41 | x = F.relu(self.conv12(x)) 42 | return_map['ReLU1_2'] = x 43 | x = self.max_pool(x) 44 | 45 | x = F.relu(self.conv21(x)) 46 | return_map['ReLU2_1'] = x 47 | x = F.relu(self.conv22(x)) 48 | return_map['ReLU2_2'] = x 49 | x = self.max_pool(x) 50 | 51 | x = F.relu(self.conv31(x)) 52 | return_map['ReLU3_1'] = x 53 | x = F.relu(self.conv32(x)) 54 | return_map['ReLU3_2'] = x 55 | x = F.relu(self.conv33(x)) 56 | return_map['ReLU3_3'] = x 57 | x = self.max_pool(x) 58 | 59 | x = F.relu(self.conv41(x)) 60 | return_map['ReLU4_1'] = x 61 | x = F.relu(self.conv42(x)) 62 | return_map['ReLU4_2'] = x 63 | x = F.relu(self.conv43(x)) 64 | return_map['ReLU4_3'] = x 65 | x = self.max_pool(x) 66 | 67 | x = F.relu(self.conv51(x)) 68 | return_map['ReLU5_1'] = x 69 | x = F.relu(self.conv52(x)) 70 | return_map['ReLU5_2'] = x 71 | x = F.relu(self.conv53(x)) 72 | return_map['ReLU5_3'] = x 73 | 74 | return return_map 75 | 76 | 77 | if __name__ == '__main__': 78 | import matplotlib.pyplot as plt 79 | import os 80 | from PIL import Image 81 | 82 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 83 | 84 | use_cuda = torch.cuda.is_available() 85 | 86 | def load_image(img_path): 87 | img = Image.open(img_path).convert('RGB') 88 | img = np.array(img, dtype=np.float32)[:, :, 0] / 255.0 89 | return img 90 | 91 | raster_size = 128 92 | vis_layer1 = 'ReLU3_3' 93 | vis_layer2 = 'ReLU5_1' 94 | 95 | vgg_slim_model = VGG_Slim() 96 | 97 | model_path = '../models/quickdraw-perceptual.pth' 98 | pretrained_dict = torch.load(model_path) 99 | print('pretrained_dict') 100 | print(pretrained_dict.keys()) 101 | 102 | 
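    # Partial loading: keep only the pretrained entries whose names also exist in this
    # model's state_dict, merge them in, and load the merged dict (next few lines).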
model_dict = vgg_slim_model.state_dict() 103 | print('model_dict') 104 | print(model_dict.keys()) 105 | 106 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 107 | model_dict.update(pretrained_dict) 108 | vgg_slim_model.load_state_dict(model_dict) 109 | 110 | vgg_slim_model.eval() 111 | 112 | param_list = vgg_slim_model.named_parameters() 113 | print('-' * 100) 114 | count_t_vars = 0 115 | for name, param in param_list: 116 | num_param = np.prod(list(param.size())) 117 | count_t_vars += num_param 118 | print('%s | shape: %s | num_param: %i' % (name, str(param.size()), num_param)) 119 | print('Total trainable variables %i.' % count_t_vars) 120 | print('-' * 100) 121 | 122 | image_types = ['GT', 'Connected', 'Disconnected', 'Parallel', 'Part'] 123 | image_info_map = {} 124 | 125 | all_files = os.listdir(os.path.join('../vgg_utils/testData', 'GT')) 126 | all_files.sort() 127 | img_ids = [] 128 | for file_name in all_files: 129 | img_idx = file_name[:file_name.find('_')] 130 | img_ids.append(img_idx) 131 | print(img_ids) 132 | 133 | if use_cuda: 134 | vgg_slim_model = vgg_slim_model.cuda() 135 | 136 | for img_id in img_ids: 137 | for image_type in image_types: 138 | finalstr = '_gt.png' if image_type == 'GT' else '_pred.png' 139 | img_path = os.path.join('../vgg_utils/testData', image_type, str(img_id) + finalstr) 140 | print(img_path) 141 | image = load_image(img_path) 142 | image_tensor = torch.tensor(np.expand_dims(image, axis=0)).float() 143 | 144 | if use_cuda: 145 | image_tensor = image_tensor.cuda() 146 | 147 | feature_maps = vgg_slim_model(image_tensor) 148 | feature_maps_11 = feature_maps['ReLU3_3'].cpu().data.numpy() 149 | feature_maps_33 = feature_maps['ReLU5_1'].cpu().data.numpy() 150 | 151 | print('ReLU3_3', feature_maps_11.shape) 152 | print('ReLU5_1', feature_maps_33.shape) 153 | 154 | feature_maps_val11 = np.transpose(feature_maps_11[0], (1, 2, 0)) 155 | feature_maps_val33 = np.transpose(feature_maps_33[0], (1, 2, 0)) 156 | 157 | if image_type != 'GT': 158 | feature_maps_gt11 = image_info_map['GT'][4] 159 | feat_diff_all11 = np.mean(np.abs(feature_maps_gt11 - feature_maps_val11), axis=-1) 160 | perc_layer_loss11 = np.mean(np.abs(feature_maps_gt11 - feature_maps_val11)) 161 | 162 | feature_maps_gt33 = image_info_map['GT'][1] 163 | feat_diff_all33 = np.mean(np.abs(feature_maps_gt33 - feature_maps_val33), axis=-1) 164 | perc_layer_loss33 = np.mean(np.abs(feature_maps_gt33 - feature_maps_val33)) 165 | # print('perc_layer_loss', image_type, perc_layer_loss) 166 | else: 167 | feat_diff_all11 = np.zeros_like(feature_maps_val11[:, :, 0]) 168 | perc_layer_loss11 = 0.0 169 | 170 | feat_diff_all33 = np.zeros_like(feature_maps_val33[:, :, 0]) 171 | perc_layer_loss33 = 0.0 172 | 173 | image_info_map[image_type] = [image, 174 | feature_maps_val33, feat_diff_all33, perc_layer_loss33, 175 | feature_maps_val11, feat_diff_all11, perc_layer_loss11] 176 | 177 | rows = 3 178 | cols = len(image_types) 179 | plt.figure(figsize=(4 * cols, 4 * rows)) 180 | 181 | for image_type_i, image_type in enumerate(image_types): 182 | input_image = image_info_map[image_type][0] 183 | feat_diff33 = image_info_map[image_type][2] 184 | perc_loss33 = image_info_map[image_type][3] 185 | feat_diff11 = image_info_map[image_type][5] 186 | perc_loss11 = image_info_map[image_type][6] 187 | 188 | plt.subplot(rows, cols, image_type_i + 1) 189 | plt.title(image_type, fontsize=12) 190 | if image_type_i != 0: 191 | plt.axis('off') 192 | plt.imshow(input_image) 193 | 194 | plt.subplot(rows, cols, 
image_type_i + 1 + len(image_types)) 195 | plt.title(str(perc_loss11), fontsize=12) 196 | if image_type_i != 0: 197 | plt.axis('off') 198 | plt.imshow(feat_diff11) 199 | 200 | plt.subplot(rows, cols, image_type_i + 1 + len(image_types) + len(image_types)) 201 | plt.title(str(perc_loss33), fontsize=12) 202 | if image_type_i != 0: 203 | plt.axis('off') 204 | plt.imshow(feat_diff33) 205 | 206 | plt.show() 207 | -------------------------------------------------------------------------------- /tools/svg_to_npz.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES'] = '7' 3 | 4 | import numpy as np 5 | from PIL import Image 6 | import argparse 7 | import sys 8 | 9 | sys.path.append("..") 10 | 11 | import pydiffvg 12 | import torch 13 | 14 | import xml.etree.ElementTree as ET 15 | from xml.dom import minidom 16 | from svg.path import parse_path, path 17 | 18 | from dataset_utils.dataset_util import draw_segment_jointly, draw_segment_separately 19 | 20 | 21 | invalid_svg_shapes = ['rect', 'circle', 'ellipse', 'line', 'polyline', 'polygon'] 22 | 23 | 24 | def parse_single_path(path_str): 25 | ps = parse_path(path_str) 26 | # print(len(ps)) 27 | 28 | stroke_points_list = [] 29 | control_points_list = [] 30 | for item_i, path_item in enumerate(ps): 31 | path_type = type(path_item) 32 | 33 | if path_type == path.Move: 34 | # assert item_i == 0 35 | if item_i != 0: 36 | assert (len(control_points_list) - 1) % 3 == 0 37 | assert len(control_points_list) > 1 38 | stroke_points_list.append(control_points_list) 39 | control_points_list = [] 40 | 41 | start = path_item.start 42 | start_x, start_y = start.real, start.imag 43 | control_points_list.append((start_x, start_y)) 44 | elif path_type == path.CubicBezier: 45 | start, control1, control2, end = path_item.start, path_item.control1, path_item.control2, path_item.end 46 | start_x, start_y = start.real, start.imag 47 | control1_x, control1_y = control1.real, control1.imag 48 | control2_x, control2_y = control2.real, control2.imag 49 | end_x, end_y = end.real, end.imag 50 | control_points_list.append((control1_x, control1_y)) 51 | control_points_list.append((control2_x, control2_y)) 52 | control_points_list.append((end_x, end_y)) 53 | elif path_type == path.Arc: 54 | raise Exception('Arc is here') 55 | elif path_type == path.Line: 56 | # assert len(control_points_list) == 1 57 | # start, end = path_item.start, path_item.end 58 | # start_x, start_y = start.real, start.imag 59 | # end_x, end_y = end.real, end.imag 60 | # 61 | # control1_x = 2.0 / 3.0 * start_x + 1.0 / 3.0 * end_x 62 | # control1_y = 2.0 / 3.0 * start_y + 1.0 / 3.0 * end_y 63 | # control2_x = 1.0 / 3.0 * start_x + 2.0 / 3.0 * end_x 64 | # control2_y = 1.0 / 3.0 * start_y + 2.0 / 3.0 * end_y 65 | # 66 | # control1 = (control1_x, control1_y) 67 | # control2 = (control2_x, control2_y) 68 | # control1_dist = sample_random_position(control1, 4.0, 1.0) 69 | # control2_dist = sample_random_position(control2, 4.0, 1.0) 70 | # 71 | # control_points_list.append(control1_dist) 72 | # control_points_list.append(control2_dist) 73 | # control_points_list.append((end_x, end_y)) 74 | raise Exception('Line is here') 75 | elif path_type == path.Close: 76 | assert item_i == len(ps) - 1 77 | else: 78 | raise Exception('Unknown path_type', path_type) 79 | 80 | assert (len(control_points_list) - 1) % 3 == 0 81 | assert len(control_points_list) > 1 82 | stroke_points_list.append(control_points_list) 83 | 84 | return stroke_points_list 
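# Illustrative example: for the path string "M 10,20 C 30,40 50,60 70,80",
# parse_single_path returns [[(10.0, 20.0), (30.0, 40.0), (50.0, 60.0), (70.0, 80.0)]]:
# a single stroke whose control-point list holds the start point followed by
# 3 points per cubic segment, i.e. 1 + 3 * segment_num entries.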
85 | 86 | 87 | def matrix_transform(points, matrix_params): 88 | # points: (N, 2), (x, y) 89 | # matrix_params: (6) 90 | new_points = [] 91 | a, b, c, d, e, f = matrix_params 92 | matrix = np.array([[a, c, e], 93 | [b, d, f], 94 | [0, 0, 1]], dtype=np.float32) 95 | for point in points: 96 | point_vec = [point[0], point[1], 1] 97 | new_point = np.matmul(matrix, point_vec)[:2] 98 | new_points.append(new_point) 99 | # print(point, new_point) 100 | new_points = np.stack(new_points).astype(np.float32) 101 | return new_points 102 | 103 | 104 | def parse_svg(svg_file): 105 | tree = ET.parse(svg_file) 106 | root = tree.getroot() 107 | 108 | width = root.get('width') 109 | height = root.get('height') 110 | width = int(width[:-2]) 111 | height = int(height[:-2]) 112 | 113 | view_box = root.get('viewBox') 114 | view_x, view_y, view_width, view_height = view_box.split(' ') 115 | view_x, view_y, view_width, view_height = int(view_x), int(view_y), int(view_width), int(view_height) 116 | assert view_x == 0 and view_y == 0 117 | assert width == view_width and height == view_height 118 | 119 | strokes_list = [] 120 | 121 | for elem in root.iter(): 122 | try: 123 | _, tag_suffix = elem.tag.split('}') 124 | except ValueError: 125 | continue 126 | 127 | assert tag_suffix not in invalid_svg_shapes 128 | 129 | if tag_suffix == 'path': 130 | path_d = elem.attrib['d'] 131 | control_points_single_stroke_list = parse_single_path(path_d) # (N_point, 2) 132 | for control_points_single_stroke_ in control_points_single_stroke_list: 133 | control_points_single_stroke = np.array(control_points_single_stroke_, dtype=np.float32) 134 | # print('control_points_single_stroke', control_points_single_stroke.shape) 135 | 136 | if 'transform' in elem.attrib.keys(): 137 | transformation = elem.attrib['transform'] 138 | if 'translate' in transformation: 139 | translate_xy = transformation[transformation.find('(')+1:transformation.find(')')] 140 | assert ', ' in translate_xy 141 | translate_x = float(translate_xy[:translate_xy.find(', ')]) 142 | translate_y = float(translate_xy[translate_xy.find(', ')+2:]) 143 | 144 | control_points_single_stroke[:, 0] += translate_x 145 | control_points_single_stroke[:, 1] += translate_y 146 | elif 'matrix' in transformation: 147 | matrix_params = transformation[transformation.find('(')+1:transformation.find(')')] 148 | matrix_params = matrix_params.split(' ') 149 | assert len(matrix_params) == 6 150 | matrix_params = [float(item) for item in matrix_params] 151 | control_points_single_stroke = matrix_transform(control_points_single_stroke, matrix_params) 152 | else: 153 | raise Exception('Error transformation') 154 | 155 | strokes_list.append(control_points_single_stroke) 156 | 157 | assert len(strokes_list) > 0 158 | return (view_width, view_height), strokes_list 159 | 160 | 161 | def svg_to_npz(database, ref_image): 162 | svg_base = os.path.join(database, 'svg') 163 | img_base = os.path.join(database, 'raster') 164 | 165 | save_base_parameter = os.path.join(database, 'parameter') 166 | save_base_color_separate = os.path.join(database, 'vector_vis') 167 | os.makedirs(save_base_parameter, exist_ok=True) 168 | os.makedirs(save_base_color_separate, exist_ok=True) 169 | 170 | ref_id = ref_image[:ref_image.find('.')] 171 | svg_file_path = os.path.join(svg_base, ref_id + '.svg') 172 | img_file_path = os.path.join(img_base, ref_image) 173 | 174 | img = Image.open(img_file_path).convert('RGB') 175 | img_width, img_height = img.width, img.height 176 | assert img_width == img_height 177 | 178 | view_sizes, 
strokes_list = parse_svg(svg_file_path) 179 | # strokes_list: list of (N_point, 2), in view size 180 | assert view_sizes[0] == view_sizes[1] 181 | view_size = view_sizes[0] 182 | 183 | norm_strokes_list = [] # list of (N_point, 2), in image size 184 | max_seq_num = 0 185 | for single_stroke in strokes_list: 186 | single_stroke_norm = single_stroke / float(view_size) * float(img_width) 187 | norm_strokes_list.append(single_stroke_norm) 188 | segment_num = (len(single_stroke_norm) - 1) // 3 189 | # print('segment_num', segment_num) 190 | max_seq_num += segment_num 191 | print('max_seq_num', max_seq_num) 192 | 193 | save_npz_path = os.path.join(save_base_parameter, ref_id + '.npz') 194 | np.savez(save_npz_path, strokes_data=norm_strokes_list, canvas_size=img_width) 195 | 196 | # visualization 197 | pydiffvg.set_use_gpu(torch.cuda.is_available()) 198 | 199 | background_ = torch.ones(img_width, img_width, 4) 200 | render_ = pydiffvg.RenderFunction.apply 201 | 202 | stroke_thickness = 1.2 203 | 204 | # black_stroke_img = draw_segment_separately(norm_strokes_list, img_width, stroke_thickness, max_seq_num, 205 | # render_, background_, is_black=True) 206 | # black_stroke_img_save_path = os.path.join(img_base, ref_id + '-rendered.png') 207 | # pydiffvg.imwrite(black_stroke_img.cpu(), black_stroke_img_save_path, gamma=1.0) 208 | 209 | color_stroke_img = draw_segment_separately(norm_strokes_list, img_width, stroke_thickness, max_seq_num, 210 | render_, background_, is_black=False) 211 | color_stroke_img_save_path = os.path.join(save_base_color_separate, ref_id + '.png') 212 | pydiffvg.imwrite(color_stroke_img.cpu(), color_stroke_img_save_path, gamma=1.0) 213 | 214 | 215 | if __name__ == '__main__': 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument('--database', '-db', type=str, default='../sample_inputs/rough/', help="define the data base") 218 | parser.add_argument('--reference', '-ref', type=str, default='20-0.png', help="define the reference image") 219 | args = parser.parse_args() 220 | 221 | svg_to_npz(args.database, args.reference) 222 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/4-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/5-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/clean/svg/11-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/rough/svg/25-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /rnn2.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class HyperNorm(nn.Module): 8 | def __init__(self, input_size: int, num_units: int, hyper_embedding_size: int, use_bias: bool = True): 9 | super().__init__() 10 | self.num_units = num_units 11 | self.embedding_size = hyper_embedding_size 12 | self.use_bias = use_bias 13 | 14 | self.z_w = nn.Linear(input_size, self.embedding_size, bias=True) 15 | self.alpha = nn.Linear(self.embedding_size, self.num_units, bias=False) 16 | 17 | if self.use_bias: 18 | self.z_b = nn.Linear(input_size, self.embedding_size, bias=False) 19 | self.beta = nn.Linear(self.embedding_size, self.num_units, bias=False) 20 | 21 | def __call__(self, input: torch.Tensor, hyper_output: torch.Tensor): 22 | zw = self.z_w(hyper_output) 23 | alpha = self.alpha(zw) 24 | result = torch.mul(alpha, input) 25 | 26 | if self.use_bias: 27 | zb = self.z_b(hyper_output) 28 | beta = self.beta(zb) 29 | result = torch.add(result, beta) 30 | 31 | return result 32 | 33 | 34 | class LSTMCell(nn.Module): 35 | """ 36 | ## Long Short-Term Memory Cell 37 | LSTM Cell computes $c$, and $h$. $c$ is like the long-term memory, 38 | and $h$ is like the short term memory. 39 | We use the input $x$ and $h$ to update the long term memory. 40 | In the update, some features of $c$ are cleared with a forget gate $f$, 41 | and some features $i$ are added through a gate $g$. 42 | The new short term memory is the $\tanh$ of the long-term memory 43 | multiplied by the output gate $o$. 44 | Note that the cell doesn't look at long term memory $c$ when doing the update. It only modifies it. 45 | Also $c$ never goes through a linear transformation. 46 | This is what solves vanishing and exploding gradients. 47 | Here's the update rule. 48 | \begin{align} 49 | c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\ 50 | h_t &= \sigma(o_t) \odot \tanh(c_t) 51 | \end{align} 52 | $\odot$ stands for element-wise multiplication. 53 | Intermediate values and gates are computed as linear transformations of the hidden 54 | state and input. 55 | \begin{align} 56 | i_t &= lin_x^i(x_t) + lin_h^i(h_{t-1}) \\ 57 | f_t &= lin_x^f(x_t) + lin_h^f(h_{t-1}) \\ 58 | g_t &= lin_x^g(x_t) + lin_h^g(h_{t-1}) \\ 59 | o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1}) 60 | \end{align} 61 | """ 62 | 63 | def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False, 64 | forget_bias: float = 1.0): 65 | super().__init__() 66 | 67 | # These are the linear layer to transform the `input` and `hidden` vectors. 68 | # One of them doesn't need a bias since we add the transformations. 
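        # (In this implementation both projections below are created without a bias;
        #  the only additive terms are the constant forget-gate offset and, when
        #  `layer_norm` is enabled, the affine parameters of the LayerNorm layers.)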
69 | self.forget_bias = forget_bias 70 | 71 | self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size, bias=False) 72 | self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False) 73 | 74 | if layer_norm: 75 | self.layer_norm_all = nn.LayerNorm(4 * hidden_size) 76 | self.layer_norm_c = nn.LayerNorm(hidden_size) 77 | else: 78 | self.layer_norm_all = nn.Identity() 79 | self.layer_norm_c = nn.Identity() 80 | 81 | def __call__(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor): 82 | ifgo = self.input_lin(x) + self.hidden_lin(h) 83 | ifgo = self.layer_norm_all(ifgo) 84 | 85 | ifgo = ifgo.chunk(4, dim=-1) 86 | i, j, f, o = ifgo 87 | 88 | g = torch.tanh(j) 89 | 90 | c_next = c * torch.sigmoid(f + self.forget_bias) + torch.sigmoid(i) * g 91 | h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next)) 92 | 93 | return h_next, c_next 94 | 95 | 96 | class LSTM(nn.Module): 97 | """ 98 | ## Multilayer LSTM 99 | """ 100 | 101 | def __init__(self, input_size: int, hidden_size: int, n_layers: int): 102 | """ 103 | Create a network of `n_layers` of LSTM. 104 | """ 105 | 106 | super().__init__() 107 | self.n_layers = n_layers 108 | self.hidden_size = hidden_size 109 | # Create cells for each layer. Note that only the first layer gets the input directly. 110 | # Rest of the layers get the input from the layer below 111 | self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] + 112 | [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)]) 113 | 114 | def __call__(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): 115 | """ 116 | `x` has shape `[n_steps, batch_size, input_size]` and 117 | `state` is a tuple of $h$ and $c$, each with a shape of `[batch_size, hidden_size]`. 118 | """ 119 | n_steps, batch_size = x.shape[:2] 120 | 121 | # Initialize the state if `None` 122 | if state is None: 123 | h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] 124 | c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] 125 | else: 126 | (h, c) = state 127 | # Reverse stack the tensors to get the states of each layer
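            # (torch.unbind along dim 0 turns the stacked [n_layers, batch_size, hidden_size]
            #  state tensors back into one tensor per layer.)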
128 | # ?? You can just work with the tensor itself but this is easier to debug 129 | h, c = list(torch.unbind(h)), list(torch.unbind(c)) 130 | 131 | # Array to collect the outputs of the final layer at each time step. 132 | out = [] 133 | for t in range(n_steps): 134 | # Input to the first layer is the input itself 135 | inp = x[t] 136 | # Loop through the layers 137 | for layer in range(self.n_layers): 138 | # Get the state of the layer 139 | h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer]) 140 | # Input to the next layer is the state of this layer 141 | inp = h[layer] 142 | # Collect the output $h$ of the final layer 143 | out.append(h[-1]) 144 | 145 | # Stack the outputs and states 146 | out = torch.stack(out) 147 | h = torch.stack(h) 148 | c = torch.stack(c) 149 | 150 | return out, (h, c) 151 | 152 | 153 | class HyperLSTMCell(nn.Module): 154 | """ 155 | ## HyperLSTM Cell 156 | For HyperLSTM the smaller network and the larger network both have the LSTM structure. 157 | This is defined in Appendix A.2.2 in the paper. 158 | """ 159 | 160 | def __init__(self, input_size: int, num_units: int, hyper_num_units: int = 256, hyper_embedding_size: int = 32, 161 | forget_bias: float = 1.0): 162 | """ 163 | `input_size` is the size of the input $x_t$, 164 | `num_units` is the size of the LSTM, and 165 | `hyper_num_units` is the size of the smaller LSTM that alters the weights of the larger outer LSTM. 166 | `hyper_embedding_size` is the size of the feature vectors used to alter the LSTM weights. 167 | We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and 168 | $z_b^{i,f,g,o}$ using linear transformations. 169 | We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and 170 | $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. 171 | These are then used to scale the rows of weight and bias tensors of the main LSTM. 172 | ?? Since the computation of $z$ and $d$ are two sequential linear transformations 173 | these can be combined into a single linear transformation. 174 | However we've implemented this separately so that it matches with the description 175 | in the paper. 
176 | """ 177 | super().__init__() 178 | self.forget_bias = forget_bias 179 | 180 | self.hyper = LSTMCell(num_units + input_size, hyper_num_units, layer_norm=True) 181 | 182 | self.w_x = nn.Linear(input_size, 4 * num_units, bias=False) 183 | self.w_h = nn.Linear(num_units, 4 * num_units, bias=False) 184 | 185 | self.hyper_ix = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=False) 186 | self.hyper_jx = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=False) 187 | self.hyper_fx = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=False) 188 | self.hyper_ox = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=False) 189 | 190 | self.hyper_ih = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=True) 191 | self.hyper_jh = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=True) 192 | self.hyper_fh = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=True) 193 | self.hyper_oh = HyperNorm(num_units, num_units, hyper_embedding_size, use_bias=True) 194 | 195 | # zero initialization 196 | self.bias = nn.Parameter(torch.zeros(4 * num_units)) 197 | 198 | self.layer_norm_all = nn.LayerNorm(num_units * 4) 199 | self.layer_norm_c = nn.LayerNorm(num_units) 200 | 201 | def __call__(self, x: torch.Tensor, 202 | h: torch.Tensor, c: torch.Tensor, 203 | h_hat: torch.Tensor, c_hat: torch.Tensor): 204 | hyper_input = torch.cat([x, h], dim=-1) 205 | h_hat, c_hat = self.hyper(hyper_input, h_hat, c_hat) 206 | hyper_output = h_hat 207 | 208 | xh = self.w_x(x) 209 | hh = self.w_h(h) 210 | 211 | ix, jx, fx, ox = torch.chunk(xh, 4, 1) 212 | ix = self.hyper_ix(ix, hyper_output) 213 | jx = self.hyper_jx(jx, hyper_output) 214 | fx = self.hyper_fx(fx, hyper_output) 215 | ox = self.hyper_ox(ox, hyper_output) 216 | 217 | ih, jh, fh, oh = torch.chunk(hh, 4, 1) 218 | ih = self.hyper_ih(ih, hyper_output) 219 | jh = self.hyper_jh(jh, hyper_output) 220 | fh = self.hyper_fh(fh, hyper_output) 221 | oh = self.hyper_oh(oh, hyper_output) 222 | 223 | ib, jb, fb, ob = torch.chunk(self.bias, 4, 0) 224 | 225 | i = ix + ih + ib 226 | j = jx + jh + jb 227 | f = fx + fh + fb 228 | o = ox + oh + ob 229 | 230 | concat = torch.cat([i, j, f, o], 1) 231 | concat = self.layer_norm_all(concat) 232 | i, j, f, o = torch.chunk(concat, 4, 1) 233 | 234 | g = torch.tanh(j) 235 | 236 | c_next = c * torch.sigmoid(f + self.forget_bias) + torch.sigmoid(i) * g 237 | h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next)) 238 | 239 | return h_next, c_next, h_hat, c_hat 240 | 241 | 242 | class HyperLSTM(nn.Module): 243 | """ 244 | # HyperLSTM module 245 | """ 246 | def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int): 247 | """ 248 | Create a network of `n_layers` of HyperLSTM. 249 | """ 250 | 251 | super().__init__() 252 | 253 | # Store sizes to initialize state 254 | self.n_layers = n_layers 255 | self.hidden_size = hidden_size 256 | self.hyper_size = hyper_size 257 | 258 | # Create cells for each layer. Note that only the first layer gets the input directly. 
259 | # Rest of the layers get the input from the layer below 260 | self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] + 261 | [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in 262 | range(n_layers - 1)]) 263 | 264 | def __call__(self, x: torch.Tensor, 265 | state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None): 266 | """ 267 | * `x` has shape `[n_steps, batch_size, input_size]` and 268 | * `state` is a tuple of $h, c, \hat{h}, \hat{c}$. 269 | $h, c$ have shape `[batch_size, hidden_size]` and 270 | $\hat{h}, \hat{c}$ have shape `[batch_size, hyper_size]`. 271 | """ 272 | n_steps, batch_size = x.shape[:2] 273 | 274 | # Initialize the state with zeros if `None` 275 | if state is None: 276 | h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] 277 | c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] 278 | h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)] 279 | c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)] 280 | # 281 | else: 282 | (h, c, h_hat, c_hat) = state 283 | # Reverse stack the tensors to get the states of each layer 284 | # 285 | # ?? You can just work with the tensor itself but this is easier to debug 286 | h, c = list(torch.unbind(h)), list(torch.unbind(c)) 287 | h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat)) 288 | 289 | # Collect the outputs of the final layer at each step 290 | out = [] 291 | for t in range(n_steps): 292 | # Input to the first layer is the input itself 293 | inp = x[t] 294 | # Loop through the layers 295 | for layer in range(self.n_layers): 296 | # Get the state of the layer 297 | h[layer], c[layer], h_hat[layer], c_hat[layer] = \ 298 | self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer]) 299 | # Input to the next layer is the state of this layer 300 | inp = h[layer] 301 | # Collect the output $h$ of the final layer 302 | out.append(h[-1]) 303 | 304 | # Stack the outputs and states 305 | out = torch.stack(out) 306 | h = torch.stack(h) 307 | c = torch.stack(c) 308 | h_hat = torch.stack(h_hat) 309 | c_hat = torch.stack(c_hat) 310 | 311 | # 312 | return out, (h, c, h_hat, c_hat) 313 | -------------------------------------------------------------------------------- /sample_inputs/rough/svg/27-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/rough/svg/22-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sample_inputs/rough/svg/24-0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | --------------------------------------------------------------------------------
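A quick usage sketch for the recurrent modules in `rnn2.py` (illustrative only; the sizes below are arbitrary). Note that `HyperNorm` is constructed with `num_units` as the input size of its `z_w` projection, so in this implementation `hyper_size` must match `hidden_size` for the shapes to line up:
```
import torch
from rnn2 import LSTM, HyperLSTM

x = torch.randn(10, 4, 32)  # [n_steps, batch_size, input_size]

lstm = LSTM(input_size=32, hidden_size=64, n_layers=2)
out, (h, c) = lstm(x)       # out: [10, 4, 64]; h, c: [2, 4, 64]

# hyper_size equals hidden_size so that HyperNorm's linear layers accept the
# hyper-LSTM output (see HyperNorm.z_w and HyperLSTMCell.hyper_ix above)
hyper = HyperLSTM(input_size=32, hidden_size=128, hyper_size=128, n_z=32, n_layers=2)
out, (h, c, h_hat, c_hat) = hyper(x)  # out: [10, 4, 128]; h_hat, c_hat: [2, 4, 128]
```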