├── dataset_utils
│   ├── ref-dot.png
│   └── dataset_util.py
├── docs
│   ├── figures
│   │   ├── dynamic1.gif
│   │   ├── dynamic2.gif
│   │   ├── framework.png
│   │   ├── Krita
│   │   │   ├── krita1.png
│   │   │   ├── krita2.png
│   │   │   └── krita3.png
│   │   ├── inbetweens
│   │   │   ├── 0.png
│   │   │   ├── 5.png
│   │   │   ├── 24.png
│   │   │   ├── 27.png
│   │   │   ├── 3-1.png
│   │   │   ├── 3-2.png
│   │   │   ├── 3-3.png
│   │   │   ├── 4-1.png
│   │   │   ├── 4-2.png
│   │   │   ├── 21-1.png
│   │   │   ├── 21-2.png
│   │   │   ├── 25-1.png
│   │   │   ├── 25-2.png
│   │   │   ├── 0-dynamic.gif
│   │   │   ├── 0-reference.png
│   │   │   ├── 21-dynamic.gif
│   │   │   ├── 24-dynamic.gif
│   │   │   ├── 25-dynamic.gif
│   │   │   ├── 27-dynamic.gif
│   │   │   ├── 3-dynamic.gif
│   │   │   ├── 3-reference.png
│   │   │   ├── 4-dynamic2.gif
│   │   │   ├── 4-reference.png
│   │   │   ├── 5-dynamic.gif
│   │   │   ├── 5-reference.png
│   │   │   ├── 21-reference.png
│   │   │   ├── 24-reference.png
│   │   │   ├── 25-reference.png
│   │   │   ├── 27-reference.png
│   │   │   └── rough
│   │   │       ├── 21-target1.png
│   │   │       ├── 21-target2.png
│   │   │       ├── 24-target.png
│   │   │       ├── 25-target1.png
│   │   │       ├── 25-target2.png
│   │   │       ├── 27-target.png
│   │   │       ├── 21-reference.png
│   │   │       ├── 24-reference.png
│   │   │       ├── 25-reference.png
│   │   │       └── 27-reference.png
│   │   ├── teaser-sub1.png
│   │   ├── teaser-sub2.png
│   │   └── CACANi
│   │       ├── canani1.png
│   │       ├── canani2.png
│   │       ├── canani3.png
│   │       ├── canani4.png
│   │       ├── canani5.png
│   │       ├── canani6.png
│   │       └── canani7.png
│   └── assets
│       ├── font.css
│       └── style.css
├── sample_inputs
│   ├── clean
│   │   ├── raster
│   │   │   ├── 0-0.png
│   │   │   ├── 0-1.png
│   │   │   ├── 2-0.png
│   │   │   ├── 2-1.png
│   │   │   ├── 2-2.png
│   │   │   ├── 3-0.png
│   │   │   ├── 3-1.png
│   │   │   ├── 3-2.png
│   │   │   ├── 3-3.png
│   │   │   ├── 4-0.png
│   │   │   ├── 4-1.png
│   │   │   ├── 4-2.png
│   │   │   ├── 5-0.png
│   │   │   ├── 5-1.png
│   │   │   ├── 7-0.png
│   │   │   ├── 7-1.png
│   │   │   ├── 9-0.png
│   │   │   ├── 9-1.png
│   │   │   ├── 9-2.png
│   │   │   ├── 11-0.png
│   │   │   └── 11-1.png
│   │   └── svg
│   │       ├── 2-0.svg
│   │       ├── 0-0.svg
│   │       ├── 3-0.svg
│   │       ├── 4-0.svg
│   │       ├── 5-0.svg
│   │       └── 11-0.svg
│   └── rough
│       ├── raster
│       │   ├── 20-0.png
│       │   ├── 20-1.png
│       │   ├── 21-0.png
│       │   ├── 21-1.png
│       │   ├── 21-2.png
│       │   ├── 21-3.png
│       │   ├── 22-0.png
│       │   ├── 22-1.png
│       │   ├── 23-0.png
│       │   ├── 23-1.png
│       │   ├── 23-2.png
│       │   ├── 24-0.png
│       │   ├── 24-1.png
│       │   ├── 25-0.png
│       │   ├── 25-1.png
│       │   ├── 25-2.png
│       │   ├── 26-0.png
│       │   ├── 26-1.png
│       │   ├── 27-0.png
│       │   ├── 27-1.png
│       │   ├── 27-2.png
│       │   ├── 28-0.png
│       │   ├── 28-1.png
│       │   ├── 30-0.png
│       │   ├── 30-1.png
│       │   ├── 34-0.png
│       │   └── 34-1.png
│       └── svg
│           ├── 34-0.svg
│           ├── 30-0.svg
│           ├── 25-0.svg
│           ├── 27-0.svg
│           ├── 22-0.svg
│           └── 24-0.svg
├── .gitignore
├── tools
│   ├── image_squaring.py
│   ├── make_inbetweening.py
│   ├── vis_difference.py
│   ├── npz_to_svg.py
│   └── svg_to_npz.py
├── tutorials
│   ├── Krita_vector_generation.md
│   └── CACANi_inbetweening_generation.md
├── sketch_tracing_inference.py
├── README.md
├── vgg_utils
│   └── VGG16.py
├── LICENSE
└── rnn2.py
/dataset_utils/ref-dot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/dataset_utils/ref-dot.png
--------------------------------------------------------------------------------
/docs/figures/dynamic1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/dynamic1.gif
--------------------------------------------------------------------------------
/docs/figures/dynamic2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/dynamic2.gif
--------------------------------------------------------------------------------
/docs/figures/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/framework.png
--------------------------------------------------------------------------------
/docs/figures/Krita/krita1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita1.png
--------------------------------------------------------------------------------
/docs/figures/Krita/krita2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita2.png
--------------------------------------------------------------------------------
/docs/figures/Krita/krita3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/Krita/krita3.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5.png
--------------------------------------------------------------------------------
/docs/figures/teaser-sub1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/teaser-sub1.png
--------------------------------------------------------------------------------
/docs/figures/teaser-sub2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/teaser-sub2.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani1.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani2.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani3.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani4.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani5.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani6.png
--------------------------------------------------------------------------------
/docs/figures/CACANi/canani7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/CACANi/canani7.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/3-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-2.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/3-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-3.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/4-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/4-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-2.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/21-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/21-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-2.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/25-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/25-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-2.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/0-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/0-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/0-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/0-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/2-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/2-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/2-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/2-2.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/3-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/3-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-2.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/3-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/3-3.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/4-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/4-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/4-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/4-2.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/5-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/5-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/5-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/5-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/7-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/7-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/7-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/7-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/9-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/9-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-1.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/9-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/9-2.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/11-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/11-0.png
--------------------------------------------------------------------------------
/sample_inputs/clean/raster/11-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/clean/raster/11-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/20-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/20-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/20-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/20-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/21-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/21-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/21-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-2.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/21-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/21-3.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/22-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/22-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/22-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/22-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/23-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/23-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/23-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/23-2.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/24-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/24-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/24-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/24-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/25-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/25-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/25-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/25-2.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/26-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/26-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/26-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/26-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/27-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/27-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/27-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/27-2.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/28-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/28-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/28-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/30-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/30-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/30-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/30-1.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/34-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/34-0.png
--------------------------------------------------------------------------------
/sample_inputs/rough/raster/34-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/sample_inputs/rough/raster/34-1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/0-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/0-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/0-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/21-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/24-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/25-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/27-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/3-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/3-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/3-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/4-dynamic2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-dynamic2.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/4-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/4-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/5-dynamic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5-dynamic.gif
--------------------------------------------------------------------------------
/docs/figures/inbetweens/5-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/5-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/21-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/21-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/24-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/24-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/25-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/25-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/27-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/27-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/21-target1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-target1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/21-target2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-target2.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/24-target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/24-target.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/25-target1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-target1.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/25-target2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-target2.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/27-target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/27-target.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/21-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/21-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/24-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/24-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/25-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/25-reference.png
--------------------------------------------------------------------------------
/docs/figures/inbetweens/rough/27-reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarkMoHR/JoSTC/HEAD/docs/figures/inbetweens/rough/27-reference.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .idea/
3 | data/
4 | datas/
5 | dataset/
6 | datasets/
7 | model/
8 | models/
9 | testData/
10 | output/
11 | outputs/
12 |
13 | *.csv
14 |
15 | # temporary files
16 | *.txt~
17 | *.pyc
18 | .DS_Store
19 | .gitignore~
20 |
21 | *.h5
--------------------------------------------------------------------------------
/tools/image_squaring.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import argparse
4 |
5 |
6 | def image_squaring(img_path):  # pad the image to a square with white pixels and save it as '<name>-pad.png'
7 | img = Image.open(img_path).convert('RGB')
8 | height, width = img.height, img.width
9 |
10 | max_dim = max(height, width)
11 |
12 | pad_top = (max_dim - height) // 2
13 | pad_down = max_dim - height - pad_top
14 | pad_left = (max_dim - width) // 2
15 | pad_right = max_dim - width - pad_left
16 |
17 | img = np.array(img, dtype=np.uint8)
18 | img_p = np.pad(img, ((pad_top, pad_down), (pad_left, pad_right), (0, 0)), 'constant', constant_values=255)
19 | img_p = Image.fromarray(img_p, 'RGB')
20 | img_p.save(img_path[:-4] + '-pad.png', 'PNG')
21 |
22 |
23 | if __name__ == '__main__':
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--file', '-f', type=str, default='', help="define an image")
26 | args = parser.parse_args()
27 |
28 | image_squaring(args.file)
29 |
--------------------------------------------------------------------------------
/tutorials/Krita_vector_generation.md:
--------------------------------------------------------------------------------
1 | # Make Vector Reference Frame with Krita
2 |
3 | 1. Install [Krita](https://krita.org/en/)
4 | 2. New File - Create (arbitrary configurations)
5 | 3. In the Layers panel, remove the default Paint Layer. Then, add a File Layer, select your reference image, and choose the 'Scale to Image Size' option. You should see the image now. Remember to lock this layer in the panel.
6 | 4. Add a Vector Layer and move it above the File Layer. Then, choose the 'Bezier Curve Tool' on the left. Select any color you like and set the Size of the curve.
7 |
8 |
9 |
10 | 5. Trace the drawing. Set all the intermediate points to 'Corner point' for better control. Then, drag them to align the curves with the drawing.
11 |
12 |
13 |
14 | 6. After tracing the entire drawing, save it as an SVG file: Layer - Import/Export - Save Vector Layer as SVG. Done!
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tutorials/CACANi_inbetweening_generation.md:
--------------------------------------------------------------------------------
1 | # Make Inbetweening and Animation with CACANi
2 |
3 | 1. Install [CACANi](https://cacani.sg/).
4 | 2. Add keyframes according to the number of SVG files. In this example, there are 3 keyframes:
5 |
6 |
7 |
8 | 3. Import the SVG files in the `separate/` folder one by one: File - Import - CACS/SVG File
9 |
10 |
11 |
12 | 4. Note that the imported keyframe is placed in a new cel. Copy all the vector strokes to the default cel (Cel 1 - Layer 1):
13 |
14 |
15 |
16 | 5. Delete the cel above.
17 |
18 |
19 |
20 | 6. Repeat steps 3 to 5 to import the other SVG files into the corresponding keyframes. Then, add intermediate frames:
21 |
22 |
23 |
24 | 7. Select the empty intermediate frames between two keyframes, and generate the inbetweens. Repeat the process for the other empty intermediate frames.
25 |
26 |
27 |
28 | 8. Finally, play back the animation. You can choose among different animation modes and set different frame rates.
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/docs/assets/font.css:
--------------------------------------------------------------------------------
1 | /* Homepage Font */
2 |
3 | /* latin-ext */
4 | @font-face {
5 | font-family: 'Lato';
6 | font-style: normal;
7 | font-weight: 400;
8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');
9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
10 | }
11 |
12 | /* latin */
13 | @font-face {
14 | font-family: 'Lato';
15 | font-style: normal;
16 | font-weight: 400;
17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');
18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
19 | }
20 |
21 | /* latin-ext */
22 | @font-face {
23 | font-family: 'Lato';
24 | font-style: normal;
25 | font-weight: 700;
26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');
27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
28 | }
29 |
30 | /* latin */
31 | @font-face {
32 | font-family: 'Lato';
33 | font-style: normal;
34 | font-weight: 700;
35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');
36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
37 | }
38 |
--------------------------------------------------------------------------------
/sketch_tracing_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import six
3 | import argparse
4 |
5 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
6 | os.environ["KMP_WARNINGS"] = "0"
7 |
8 | import model_full as sketch_tracing_model
9 | from utils import load_dataset
10 |
11 |
12 | def main(dataset_base, data_type, img_seq, mode='inference'):
13 | model_params = sketch_tracing_model.get_default_hparams()
14 | model_params.add_hparam('dataset_base', dataset_base)
15 | model_params.add_hparam('data_type', data_type)
16 | model_params.add_hparam('img_seq', img_seq)
17 |
18 | print('Hyperparams:')
19 | for key, val in six.iteritems(model_params.values()):
20 | print('%s = %s' % (key, str(val)))
21 | print('-' * 100)
22 |
23 | val_set = load_dataset(model_params)
24 | model = sketch_tracing_model.FullModel(model_params, val_set)
25 |
26 | if mode == 'inference':
27 | inference_root = model_params.inference_root
28 | sub_inference_root = os.path.join(inference_root, model_params.data_type)
29 | os.makedirs(sub_inference_root, exist_ok=True)
30 | model.inference(sub_inference_root, model_params.img_seq)
31 | # elif mode == 'test':
32 | # model.evaluate(load_trained_weights=True)
33 | else:
34 | raise Exception('Unknown mode:', mode)
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--dataset_base', '-db', type=str, default='sample_inputs', help="define the data base")
40 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'], help="define the data type")
41 | parser.add_argument('--img_seq', '-is', type=str, default=['23-0.png', '23-1.png', '23-2.png'], nargs='+', help="define the image sequence")
42 |
43 | args = parser.parse_args()
44 |
45 | main(args.dataset_base, args.data_type, args.img_seq)
46 |
--------------------------------------------------------------------------------
/tools/make_inbetweening.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | def gen_intensity_list(max_inten, min_inten, num):
7 | interval = (max_inten - min_inten) // (num - 1)
8 | intensity_list = [min_inten + i * interval for i in range(num)]
9 | intensity_list = intensity_list[::-1]
10 | return intensity_list
11 |
12 |
13 | def make_inbetweening_img(data_base, image_sequence):
14 | max_intensity = 200
15 | min_intensity = 0
16 | black_threshold = 128
17 |
18 | img_num = len(image_sequence)
19 |
20 | intensity_list = gen_intensity_list(max_intensity, min_intensity, img_num)
21 | print('intensity_list', intensity_list)
22 |
23 | img_inbetween = None
24 |
25 | for i, img_name in enumerate(image_sequence):
26 | img_path = os.path.join(data_base, img_name)
27 | img = Image.open(img_path).convert('RGB')
28 | img = np.array(img, dtype=np.uint8)[:, :, 0]
29 |
30 | intensity = intensity_list[i]
31 |
32 | if img_inbetween is None:
33 | img_inbetween = np.ones_like(img) * 255
34 |
35 | img_inbetween[img <= black_threshold] = intensity
36 |
37 | img_inbetween = Image.fromarray(img_inbetween, 'L')
38 | save_path = os.path.join(data_base, 'inbetweening.png')
39 | img_inbetween.save(save_path)
40 |
41 |
42 | def make_inbetweening_gif(data_base, image_sequence):
43 | all_files = image_sequence + image_sequence[::-1][1:]  # forward then backward (ping-pong) for a looping GIF
44 |
45 | gif_frames = []
46 | for img_name in all_files:
47 | img_i = Image.open(os.path.join(data_base, img_name))
48 | gif_frames.append(img_i)
49 |
50 | print('gif_frames', len(gif_frames))
51 | save_path = os.path.join(data_base, 'inbetweening.gif')
52 | first_frame = gif_frames[0]
53 | first_frame.save(save_path, save_all=True, append_images=gif_frames, loop=0, duration=0.01)  # note: Pillow interprets `duration` as milliseconds per frame
54 |
55 |
56 | if __name__ == '__main__':
57 | data_base = '../outputs/inference/clean/inbetweening/9'
58 | image_sequence = ['10.png', '18.png', '25.png', '32.png', '40.png']
59 |
60 | make_inbetweening_img(data_base, image_sequence)
61 | make_inbetweening_gif(data_base, image_sequence)
62 |
--------------------------------------------------------------------------------
/tools/vis_difference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import argparse
5 |
6 |
7 | def vis_difference(database_input, data_type, reference_image, target_image):
8 | data_base = os.path.join(database_input, data_type)
9 | raster_base = os.path.join(data_base, 'raster')
10 | raster_diff_base = os.path.join(data_base, 'raster_diff')
11 | os.makedirs(raster_diff_base, exist_ok=True)
12 |
13 | ref_img_path = os.path.join(raster_base, reference_image)
14 | tar_img_path = os.path.join(raster_base, target_image)
15 |
16 | ref_img = Image.open(ref_img_path).convert('L')
17 | ref_img = np.array(ref_img, dtype=np.float32)
18 | tar_img = Image.open(tar_img_path).convert('RGB')
19 | tar_img = np.array(tar_img, dtype=np.float32)
20 |
21 | target_mask = (tar_img < 200).any(-1)
22 | ref_img_bg = np.expand_dims(ref_img, axis=-1).astype(np.float32)
23 | ref_img_bg = np.concatenate(
24 | [np.ones_like(ref_img_bg) * 255,
25 | ref_img_bg,
26 | np.ones_like(ref_img_bg) * 255], axis=-1)
27 | ref_img_bg = 255 - (255 - ref_img_bg) * 0.5  # lighten the magenta-colored reference strokes by blending 50% toward white
28 | ref_img_bg[target_mask] = tar_img[target_mask]
29 | ref_img_bg = ref_img_bg.astype(np.uint8)
30 |
31 | save_path = os.path.join(raster_diff_base, reference_image[:-4] + '_and_' + target_image[:-4] + '.png')
32 | ref_img_bg_png = Image.fromarray(ref_img_bg, 'RGB')
33 | ref_img_bg_png.save(save_path, 'PNG')
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--database_input', '-dbi', type=str, default='../sample_inputs',
39 | help="define the input data base")
40 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'],
41 | help="define the data type")
42 | parser.add_argument('--reference_image', '-ri', type=str, default='23-0.png',
43 | help="define the reference image")
44 | parser.add_argument('--target_image', '-ti', type=str, default='23-1.png',
45 | help="define the target image")
46 |
47 | args = parser.parse_args()
48 |
49 | vis_difference(args.database_input, args.data_type, args.reference_image, args.target_image)
50 |
--------------------------------------------------------------------------------
/docs/assets/style.css:
--------------------------------------------------------------------------------
1 | /* Body */
2 | body {
3 | background: #e3e5e8;
4 | color: #ffffff;
5 | font-family: 'Lato', Verdana, Helvetica, sans-serif;
6 | font-weight: 300;
7 | font-size: 14pt;
8 | }
9 |
10 | /* Hyperlinks */
11 | a {text-decoration: none;}
12 | a:link {color: #1772d0;}
13 | a:visited {color: #1772d0;}
14 | a:active {color: red;}
15 | a:hover {color: #f09228;}
16 |
17 | /* Pre-formatted Text */
18 | pre {
19 | margin: 5pt 0;
20 | border: 0;
21 | font-size: 12pt;
22 | background: #fcfcfc;
23 | }
24 |
25 | /* Project Page Style */
26 | /* Section */
27 | .section {
28 | width: 968pt;
29 | min-height: 100pt;
30 | margin: 15pt auto;
31 | padding: 20pt 30pt;
32 | border: 1pt hidden #000;
33 | text-align: justify;
34 | color: #000000;
35 | background: #ffffff;
36 | }
37 |
38 | /* Header (Title and Logo) */
39 | .section .header {
40 | min-height: 80pt;
41 | margin-top: 30pt;
42 | }
43 | .section .header .logo {
44 | width: 80pt;
45 | margin-left: 10pt;
46 | float: left;
47 | }
48 | .section .header .logo img {
49 | width: 80pt;
50 | object-fit: cover;
51 | }
52 | .section .header .title {
53 | margin: 0 120pt;
54 | text-align: center;
55 | font-size: 22pt;
56 | }
57 |
58 | /* Author */
59 | .section .author {
60 | margin: 5pt 0;
61 | text-align: center;
62 | font-size: 16pt;
63 | }
64 |
65 | /* Institution */
66 | .section .institution {
67 | margin: 5pt 0;
68 | text-align: center;
69 | font-size: 16pt;
70 | }
71 |
72 | /* Hyperlink (such as Paper and Code) */
73 | .section .link {
74 | margin: 5pt 0;
75 | text-align: center;
76 | font-size: 16pt;
77 | }
78 |
79 | /* Teaser */
80 | .section .teaser {
81 | margin: 20pt 0;
82 | text-align: left;
83 | }
84 | .section .teaser img {
85 | width: 95%;
86 | }
87 |
88 | /* Section Title */
89 | .section .title {
90 | text-align: center;
91 | font-size: 22pt;
92 | margin: 5pt 0 15pt 0; /* top right bottom left */
93 | }
94 |
95 | /* Section Body */
96 | .section .body {
97 | margin-bottom: 15pt;
98 | text-align: justify;
99 | font-size: 14pt;
100 | }
101 |
102 | /* BibTeX */
103 | .section .bibtex {
104 | margin: 5pt 0;
105 | text-align: left;
106 | font-size: 22pt;
107 | }
108 |
109 | /* Related Work */
110 | .section .ref {
111 | margin: 20pt 0 10pt 0; /* top right bottom left */
112 | text-align: left;
113 | font-size: 18pt;
114 | font-weight: bold;
115 | }
116 |
117 | /* Citation */
118 | .section .citation {
119 | min-height: 60pt;
120 | margin: 10pt 0;
121 | }
122 | .section .citation .image {
123 | width: 120pt;
124 | float: left;
125 | }
126 | .section .citation .image img {
127 | max-height: 60pt;
128 | width: 120pt;
129 | object-fit: cover;
130 | }
131 | .section .citation .comment{
132 | margin-left: 0pt;
133 | text-align: left;
134 | font-size: 14pt;
135 | }
136 |
--------------------------------------------------------------------------------
/dataset_utils/dataset_util.py:
--------------------------------------------------------------------------------
1 | import pydiffvg
2 | import torch
3 |
4 | from dataset_utils.common import generate_colors2, generate_colors_ordered
5 |
6 |
7 | def draw_segment_jointly(strokes_points_list, canvas_size, stroke_thickness, render, background, is_black=True):
8 | shapes = []
9 | shape_groups = []
10 |
11 | stroke_num = len(strokes_points_list)
12 | colors = generate_colors2(stroke_num) # list of (3), in [0., 1.]
13 |
14 | for i in range(stroke_num):
15 | points_single_stroke = strokes_points_list[i] # (N, 2)
16 | points_single_stroke = torch.tensor(points_single_stroke)
17 |
18 | segment_num = (len(points_single_stroke) - 1) // 3
19 | if is_black:
20 | stroke_color = [0.0, 0.0, 0.0, 1.0]
21 | else:
22 | stroke_color = [colors[i][0], colors[i][1], colors[i][2], 1.0]
23 |
24 | num_control_points = torch.zeros(segment_num, dtype=torch.int32) + 2
25 |
26 | path = pydiffvg.Path(num_control_points=num_control_points,
27 | points=points_single_stroke,
28 | is_closed=False,
29 | stroke_width=torch.tensor(stroke_thickness))
30 | shapes.append(path)
31 |
32 | path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
33 | fill_color=None,
34 | stroke_color=torch.tensor(stroke_color))
35 | shape_groups.append(path_group)
36 |
37 | scene_args_t = pydiffvg.RenderFunction.serialize_scene(
38 | canvas_size, canvas_size, shapes, shape_groups)
39 | img = render(canvas_size, # width
40 | canvas_size, # height
41 | 2, # num_samples_x
42 | 2, # num_samples_y
43 | 1, # seed
44 | background, # background_image
45 | *scene_args_t)
46 | return img
47 |
48 |
49 | def draw_segment_separately(strokes_points_list, canvas_size, stroke_thickness, max_seq_number,
50 | render, background, is_black=False, draw_order=False):
51 | shapes = []
52 | shape_groups = []
53 |
54 | stroke_num = len(strokes_points_list)
55 |
56 | if draw_order:
57 | colors = generate_colors_ordered(max_seq_number) # list of (3), in [0., 1.]
58 | else:
59 | colors = generate_colors2(max_seq_number) # list of (3), in [0., 1.]
60 | global_segment_idx = 0
61 |
62 | for i in range(stroke_num):
63 | points_single_stroke = strokes_points_list[i] # (N, 2)
64 | points_single_stroke = torch.tensor(points_single_stroke)
65 |
66 | segment_num = (len(points_single_stroke) - 1) // 3
67 |
68 | for j in range(segment_num):
69 | start_idx = j * 3
70 | points_single_segment = points_single_stroke[start_idx: start_idx + 4]
71 |
72 | if is_black:
73 | stroke_color = [0.0, 0.0, 0.0, 1.0]
74 | else:
75 | stroke_color = [colors[global_segment_idx][0], colors[global_segment_idx][1], colors[global_segment_idx][2], 1.0]
76 |
77 | num_control_points = torch.tensor([2])
78 |
79 | path = pydiffvg.Path(num_control_points=num_control_points,
80 | points=points_single_segment,
81 | is_closed=False,
82 | stroke_width=torch.tensor(stroke_thickness))
83 | shapes.append(path)
84 |
85 | path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
86 | fill_color=None,
87 | stroke_color=torch.tensor(stroke_color))
88 | shape_groups.append(path_group)
89 |
90 | global_segment_idx += 1
91 |
92 | scene_args_t = pydiffvg.RenderFunction.serialize_scene(
93 | canvas_size, canvas_size, shapes, shape_groups)
94 | img = render(canvas_size, # width
95 | canvas_size, # height
96 | 2, # num_samples_x
97 | 2, # num_samples_y
98 | 1, # seed
99 | background, # background_image
100 | *scene_args_t)
101 | return img
102 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/2-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/0-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/34-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/30-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Joint Stroke Tracing and Correspondence for 2D Animation - TOG & SIGGRAPH 2024
2 |
3 | [[Paper]](https://www.sysu-imsl.com/files/TOG2024/SketchTracing_TOG2024_personal.pdf) | [[Paper (ACM)]](https://dl.acm.org/doi/10.1145/3649890) | [[Project Page]](https://markmohr.github.io/JoSTC/)
4 |
5 | This code produces stroke tracing and correspondence results, which can be imported into the inbetweening software [CACANi](https://cacani.sg) to create 2D animations.
6 |
7 | 
8 |
9 | 
10 |
11 | ## Outline
12 | - [Dependencies](#dependencies)
13 | - [Quick Start](#quick-start)
14 | - [Vector Stroke Correspondence Dataset](#vector-stroke-correspondence-dataset)
15 | - [Citation](#citation)
16 |
17 | ## Dependencies
18 | - [cudatoolkit](https://www.anaconda.com/download/) == 11.0.3
19 | - [cudnn](https://www.anaconda.com/download/) == 8.4.1.50
20 | - [pytorch](https://pytorch.org/) == 1.9.0
21 | - [torchvision](https://pytorch.org/vision/0.9/) == 0.9.0
22 | - [diffvg](https://github.com/BachiLi/diffvg)
23 | - [Krita](https://krita.org/en/): for making reference vector frame
24 | - [CACANi](https://cacani.sg/): for making inbetweening and 2D animation
25 |
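A possible environment setup is sketched below. This is an untested outline, not an official install script: the package channels, the Python version (3.8 here), and the exact PyTorch wheel are assumptions; adjust them to your CUDA setup, and build diffvg from source following its own README.
```
conda create -n jostc python=3.8
conda activate jostc
# versions follow the list above; channel availability may vary
conda install cudatoolkit=11.0.3 cudnn=8.4.1.50
# pick the torch/torchvision wheels matching your CUDA setup (see pytorch.org)
pip install torch==1.9.0 torchvision==0.9.0

# diffvg (differentiable vector graphics renderer), built from source
git clone --recursive https://github.com/BachiLi/diffvg
cd diffvg
python setup.py install
```
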
26 | ## Quick Start
27 |
28 | ### Model Preparation
29 |
30 | Download the models [here](https://drive.google.com/drive/folders/15oAP7YbNKx4Cx1AmzC16wuoyQoAai2YV?usp=sharing), and place them in this file structure:
31 | ```
32 | models/
33 | quickdraw-perceptual.pth
34 | point_matching_model/
35 | sketch_correspondence_50000.pkl
36 | transform_module/
37 | sketch_transform_30000.pkl
38 | stroke_tracing_model/
39 | sketch_tracing_30000.pkl
40 | ```
41 |
42 | ### Create Reference Vector Frames with Krita
43 |
44 | Our method takes as input consecutive raster keyframes and a single vector drawing of the starting keyframe, and then generates vector images for the remaining keyframes with one-to-one stroke correspondence. So the vector image of the reference frame has to be created first.
45 |
46 | **Note**: We provide several examples for testing in the directory `sample_inputs/`. If you use them, you can skip steps 1 and 2 below and execute step 3 directly.
47 |
48 | 1. Our method takes square images as input, so please preprocess the images first (pad them to a square with white) using [tools/image_squaring.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/image_squaring.py):
49 | ```
50 | python3 tools/image_squaring.py --file path/to/the/image.png
51 | ```
52 |
53 | 2. Follow the tutorial [here](https://github.com/MarkMoHR/JoSTC/blob/main/tutorials/Krita_vector_generation.md) to create the reference vector frame in SVG format with [Krita](https://krita.org/en/).
54 | 3. Place the SVG files in `sample_inputs/*/svg/`. Then, convert them to npz format using [tools/svg_to_npz.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/svg_to_npz.py):
55 | ```
56 | cd tools/
57 | python3 svg_to_npz.py --database ../sample_inputs/rough/ --reference 23-0.png
58 | ```
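
For intuition about the vector representation used throughout this repo, each stroke is stored as an (N, 2) point array in which every cubic Bézier segment contributes three points after the start point (see `dataset_utils/dataset_util.py`, where `segment_num = (N - 1) // 3`). The snippet below is only a rough, hypothetical illustration of what such an SVG-to-points conversion amounts to; the actual `tools/svg_to_npz.py` may parse the files differently and store additional data.
```python
import numpy as np
from xml.dom import minidom

def svg_paths_to_points(svg_path):
    """Parse absolute 'M ... C ...' path data into a list of (N, 2) arrays (illustrative only)."""
    strokes = []
    doc = minidom.parse(svg_path)
    for path in doc.getElementsByTagName('path'):
        tokens = path.getAttribute('d').replace(',', ' ').split()
        coords = [float(t) for t in tokens if t not in ('M', 'C')]
        points = np.array(coords, dtype=np.float32).reshape(-1, 2)  # N = 3 * num_segments + 1
        strokes.append(points)
    doc.unlink()
    return strokes
```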
59 |
60 | ### Execute the Main Code
61 |
62 | Perform joint stroke tracing and correspondence using [sketch_tracing_inference.py](https://github.com/MarkMoHR/JoSTC/blob/main/sketch_tracing_inference.py). We provide several examples for testing in the directory `sample_inputs/`.
63 | ```
64 | python3 sketch_tracing_inference.py --dataset_base sample_inputs --data_type rough --img_seq 23-0.png 23-1.png 23-2.png
65 | ```
66 | - `--data_type`: specify the image type with `clean` or `rough`.
67 | - `--img_seq`: specify the animation frames here. The first one should be the reference frame.
68 | - The results are placed in `outputs/inference/*`. Inside this folder:
69 | - `raster/` stores rendered line drawings of the target vector frames.
70 | - `rgb/` stores visualization (with target images underneath) of the vector stroke correspondence to the reference frame.
71 | - `rgb-wo-bg/` stores visualization without target images underneath.
72 | - `parameter/` stores vector stroke parameters.
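
To take a quick look at the exported stroke parameters, the npz files can be opened directly with NumPy. The file name below is only an example, and the stored keys depend on the model's output format:
```python
import numpy as np

# hypothetical file name; use any file produced under outputs/inference/*/parameter/
params = np.load('outputs/inference/rough/parameter/example.npz', allow_pickle=True)
for key in params.files:
    arr = params[key]
    print(key, getattr(arr, 'shape', type(arr)))
```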
73 |
74 | ### Create Inbetweening and Animation with CACANi
75 |
76 | 1. Convert the output npz files (vector stroke parameters) into SVG format using [tools/npz_to_svg.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/npz_to_svg.py):
77 | ```
78 | cd tools/
79 | python3 npz_to_svg.py --database_input ../sample_inputs/ --database_output ../outputs/inference/ --data_type rough --file_names 23-1.png 23-2.png
80 | ```
81 | - The results are placed in `outputs/inference/*/svg`. There are two kinds of results:
82 | - `chain/`: the svg files store stroke chains, defining each path as a chain.
83 | - `separate/`: the svg files store separated strokes, defining each path as a single stroke. **Note that the automatic inbetweening in [CACANi](https://cacani.sg/) relies on this format.**
84 |
85 | 2. Follow tutorial [here](https://github.com/MarkMoHR/JoSTC/blob/main/tutorials/CACANi_inbetweening_generation.md) to generate inbetweening and 2D animation with [CACANi](https://cacani.sg/).
86 |
87 | ### More Tools
88 |
89 | - [tools/vis_difference.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/vis_difference.py): visualize the difference between the reference image and a target image. The results are placed in `sample_inputs/*/raster_diff/`:
90 | ```
91 | cd tools/
92 | python3 vis_difference.py --database_input ../sample_inputs --data_type rough --reference_image 23-0.png --target_image 23-1.png
93 | ```
94 |
95 | - [tools/make_inbetweening.py](https://github.com/MarkMoHR/JoSTC/blob/main/tools/make_inbetweening.py): visualize the inbetweening results as a single image or a GIF file (see the note below on discovering its options).
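
The command-line options of `make_inbetweening.py` are not listed in this README; assuming it exposes an argparse interface like the other tools (an assumption, not something stated here), its help output is the safest place to start:
```
cd tools/
python3 make_inbetweening.py --help
```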
96 |
97 |
98 |
99 | ## Vector Stroke Correspondence Dataset
100 |
101 | We collected a training dataset of 10k+ pairs of raster frames and their vector drawings with stroke correspondence. Please download it [here](https://drive.google.com/drive/folders/15oAP7YbNKx4Cx1AmzC16wuoyQoAai2YV?usp=sharing). We provide reference code in [dataset_utils/tuberlin_dataset_util.py](https://github.com/MarkMoHR/JoSTC/blob/main/dataset_utils/tuberlin_dataset_util.py) showing how to use the data.
102 |
103 |
104 |
105 | ## Citation
106 |
107 | If you use the code and models, please cite:
108 |
109 | ```
110 | @article{mo2024joint,
111 | title={Joint Stroke Tracing and Correspondence for 2D Animation},
112 | author={Mo, Haoran and Gao, Chengying and Wang, Ruomei},
113 | journal={ACM Transactions on Graphics},
114 | volume={43},
115 | number={3},
116 | pages={1--17},
117 | year={2024},
118 | publisher={ACM New York, NY}
119 | }
120 | ```
121 |
122 |
123 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/3-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/tools/npz_to_svg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import random
5 | import argparse
6 |
7 | from xml.dom import minidom
8 |
9 |
10 | def write_svg(points_list, height, width, svg_save_path):
11 | impl_save = minidom.getDOMImplementation()
12 |
13 | doc_save = impl_save.createDocument(None, None, None)
14 |
15 | rootElement_save = doc_save.createElement('svg')
16 | rootElement_save.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
17 |
18 | rootElement_save.setAttribute('height', str(height) + 'pt')
19 | rootElement_save.setAttribute('width', str(width) + 'pt')
20 |
21 | view_box = '0 0 ' + str(width) + ' ' + str(height)
22 | rootElement_save.setAttribute('viewBox', view_box)
23 |
24 | globl_path_i = 0
25 | for stroke_i, stroke_points in enumerate(points_list):
26 | # stroke_points: (N_point, 2), in image size
27 | segment_num = (stroke_points.shape[0] - 1) // 3
28 |
29 | for segment_i in range(segment_num):
30 | start_idx = segment_i * 3
31 | start_point = stroke_points[start_idx]
32 | ctrl_point1 = stroke_points[start_idx + 1]
33 | ctrl_point2 = stroke_points[start_idx + 2]
34 | end_point = stroke_points[start_idx + 3]
35 |
36 | command_str = 'M ' + str(start_point[0]) + ', ' + str(start_point[1]) + ' '
37 | command_str += 'C ' + str(ctrl_point1[0]) + ', ' + str(ctrl_point1[1]) + ' ' \
38 | + str(ctrl_point2[0]) + ', ' + str(ctrl_point2[1]) + ' ' \
39 | + str(end_point[0]) + ', ' + str(end_point[1]) + ' '
40 |
41 | childElement_save = doc_save.createElement('path')
42 | childElement_save.setAttribute('id', 'curve_' + str(globl_path_i))
43 | childElement_save.setAttribute('stroke', '#000000')
44 | childElement_save.setAttribute('stroke-linejoin', 'round')
45 | childElement_save.setAttribute('stroke-linecap', 'square')
46 | childElement_save.setAttribute('fill', 'none')
47 |
48 | childElement_save.setAttribute('d', command_str)
49 | childElement_save.setAttribute('stroke-width', str(2.375))
50 | rootElement_save.appendChild(childElement_save)
51 |
52 | globl_path_i += 1
53 |
54 | doc_save.appendChild(rootElement_save)
55 |
56 | f = open(svg_save_path, 'w')
57 | doc_save.writexml(f, addindent=' ', newl='\n')
58 | f.close()
59 |
60 |
61 | def write_svg_chain(points_list, height, width, svg_save_path):
62 | impl_save = minidom.getDOMImplementation()
63 |
64 | doc_save = impl_save.createDocument(None, None, None)
65 |
66 | rootElement_save = doc_save.createElement('svg')
67 | rootElement_save.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
68 |
69 | rootElement_save.setAttribute('height', str(height) + 'pt')
70 | rootElement_save.setAttribute('width', str(width) + 'pt')
71 |
72 | view_box = '0 0 ' + str(width) + ' ' + str(height)
73 | rootElement_save.setAttribute('viewBox', view_box)
74 |
75 | for stroke_i, stroke_points in enumerate(points_list):
76 | # stroke_points: (N_point, 2), in image size
77 | segment_num = (stroke_points.shape[0] - 1) // 3
78 |
79 | command_str = 'M ' + str(stroke_points[0][0]) + ', ' + str(stroke_points[0][1]) + ' '
80 |
81 | for segment_i in range(segment_num):
82 | start_idx = segment_i * 3
83 | ctrl_point1 = stroke_points[start_idx + 1]
84 | ctrl_point2 = stroke_points[start_idx + 2]
85 | end_point = stroke_points[start_idx + 3]
86 |
87 | command_str += 'C ' + str(ctrl_point1[0]) + ', ' + str(ctrl_point1[1]) + ' ' \
88 | + str(ctrl_point2[0]) + ', ' + str(ctrl_point2[1]) + ' ' \
89 | + str(end_point[0]) + ', ' + str(end_point[1]) + ' '
90 |
91 | childElement_save = doc_save.createElement('path')
92 | childElement_save.setAttribute('id', 'curve_' + str(stroke_i))
93 | childElement_save.setAttribute('stroke', '#000000')
94 | childElement_save.setAttribute('stroke-linejoin', 'round')
95 | childElement_save.setAttribute('stroke-linecap', 'square')
96 | childElement_save.setAttribute('fill', 'none')
97 |
98 | childElement_save.setAttribute('d', command_str)
99 | childElement_save.setAttribute('stroke-width', str(2.375))
100 | rootElement_save.appendChild(childElement_save)
101 |
102 | doc_save.appendChild(rootElement_save)
103 |
104 | f = open(svg_save_path, 'w')
105 | doc_save.writexml(f, addindent=' ', newl='\n')
106 | f.close()
107 |
108 |
109 | def npz_to_svg(file_names, database_input, database_output, data_type):
110 | svg_chain_save_base = os.path.join(database_output, data_type, 'svg', 'chain')
111 | svg_separate_save_base = os.path.join(database_output, data_type, 'svg', 'separate')
112 | os.makedirs(svg_chain_save_base, exist_ok=True)
113 | os.makedirs(svg_separate_save_base, exist_ok=True)
114 |
115 | for file_name in file_names:
116 | file_id = file_name[:file_name.find('.')]
117 |
118 | output_image_path = os.path.join(database_output, data_type, 'raster', file_name)
119 | output_image = Image.open(output_image_path)
120 | width, height = output_image.width, output_image.height
121 |
122 | # output
123 | npz_path = os.path.join(database_output, data_type, 'parameter', file_id + '.npz')
124 | npz = np.load(npz_path, encoding='latin1', allow_pickle=True)
125 | strokes_data_output = npz['strokes_data'] # list of (N_point, 2), in image size
126 |
127 | svg_chain_output_save_path = os.path.join(svg_chain_save_base, file_id + '.svg')
128 | write_svg_chain(strokes_data_output, height, width, svg_chain_output_save_path)
129 | svg_separate_output_save_path = os.path.join(svg_separate_save_base, file_id + '.svg')
130 | write_svg(strokes_data_output, height, width, svg_separate_output_save_path)
131 |
132 | # rewrite input
133 | file_id = file_name[:file_name.find('-')]
134 |
135 | npz_path = os.path.join(database_input, data_type, 'parameter', file_id + '-0.npz')
136 | npz = np.load(npz_path, encoding='latin1', allow_pickle=True)
137 | strokes_data_input = npz['strokes_data'] # list of (N_point, 2), in image size
138 |
139 | svg_chain_input_save_path = os.path.join(svg_chain_save_base, file_id + '-0.svg')
140 | write_svg_chain(strokes_data_input, height, width, svg_chain_input_save_path)
141 |         svg_separate_input_save_path = os.path.join(svg_separate_save_base, file_id + '-0.svg')
142 |         write_svg(strokes_data_input, height, width, svg_separate_input_save_path)
143 |
144 |
145 | if __name__ == '__main__':
146 | parser = argparse.ArgumentParser()
147 | parser.add_argument('--database_input', '-dbi', type=str, default='../sample_inputs', help="define the input data base")
148 | parser.add_argument('--database_output', '-dbo', type=str, default='../outputs/inference', help="define the output data base")
149 | parser.add_argument('--data_type', '-dt', type=str, default='rough', choices=['clean', 'rough'], help="define the data type")
150 | parser.add_argument('--file_names', '-fn', type=str, default=['23-1.png', '23-2.png'], nargs='+', help="define the file names")
151 |
152 | args = parser.parse_args()
153 |
154 | npz_to_svg(args.file_names, args.database_input, args.database_output, args.data_type)
155 |
--------------------------------------------------------------------------------
/vgg_utils/VGG16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 |
8 | class VGG_Slim(nn.Module):
9 | def __init__(self):
10 | super(VGG_Slim, self).__init__()
11 |
12 | self.conv11 = nn.Conv2d(1, 64, 3, 1, 1)
13 | self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
14 |
15 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
16 |
17 |         self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
18 |         self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
19 |
20 |         self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
21 |         self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
22 |         self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
23 |
24 |         self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
25 |         self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
26 |         self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
27 |
28 |         self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
29 |         self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
30 |         self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)
31 |
32 | def forward(self, input_imgs):
33 | return_map = {}
34 |
35 | x = input_imgs
36 | if x.dim() == 3:
37 | x = x.unsqueeze(dim=1) # NCHW
38 |
39 | x = F.relu(self.conv11(x))
40 | return_map['ReLU1_1'] = x
41 | x = F.relu(self.conv12(x))
42 | return_map['ReLU1_2'] = x
43 | x = self.max_pool(x)
44 |
45 | x = F.relu(self.conv21(x))
46 | return_map['ReLU2_1'] = x
47 | x = F.relu(self.conv22(x))
48 | return_map['ReLU2_2'] = x
49 | x = self.max_pool(x)
50 |
51 | x = F.relu(self.conv31(x))
52 | return_map['ReLU3_1'] = x
53 | x = F.relu(self.conv32(x))
54 | return_map['ReLU3_2'] = x
55 | x = F.relu(self.conv33(x))
56 | return_map['ReLU3_3'] = x
57 | x = self.max_pool(x)
58 |
59 | x = F.relu(self.conv41(x))
60 | return_map['ReLU4_1'] = x
61 | x = F.relu(self.conv42(x))
62 | return_map['ReLU4_2'] = x
63 | x = F.relu(self.conv43(x))
64 | return_map['ReLU4_3'] = x
65 | x = self.max_pool(x)
66 |
67 | x = F.relu(self.conv51(x))
68 | return_map['ReLU5_1'] = x
69 | x = F.relu(self.conv52(x))
70 | return_map['ReLU5_2'] = x
71 | x = F.relu(self.conv53(x))
72 | return_map['ReLU5_3'] = x
73 |
74 | return return_map
75 |
76 |
77 | if __name__ == '__main__':
78 | import matplotlib.pyplot as plt
79 | import os
80 | from PIL import Image
81 |
82 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
83 |
84 | use_cuda = torch.cuda.is_available()
85 |
86 | def load_image(img_path):
87 | img = Image.open(img_path).convert('RGB')
88 | img = np.array(img, dtype=np.float32)[:, :, 0] / 255.0
89 | return img
90 |
91 | raster_size = 128
92 | vis_layer1 = 'ReLU3_3'
93 | vis_layer2 = 'ReLU5_1'
94 |
95 | vgg_slim_model = VGG_Slim()
96 |
97 | model_path = '../models/quickdraw-perceptual.pth'
98 | pretrained_dict = torch.load(model_path)
99 | print('pretrained_dict')
100 | print(pretrained_dict.keys())
101 |
102 | model_dict = vgg_slim_model.state_dict()
103 | print('model_dict')
104 | print(model_dict.keys())
105 |
106 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
107 | model_dict.update(pretrained_dict)
108 | vgg_slim_model.load_state_dict(model_dict)
109 |
110 | vgg_slim_model.eval()
111 |
112 | param_list = vgg_slim_model.named_parameters()
113 | print('-' * 100)
114 | count_t_vars = 0
115 | for name, param in param_list:
116 | num_param = np.prod(list(param.size()))
117 | count_t_vars += num_param
118 | print('%s | shape: %s | num_param: %i' % (name, str(param.size()), num_param))
119 | print('Total trainable variables %i.' % count_t_vars)
120 | print('-' * 100)
121 |
122 | image_types = ['GT', 'Connected', 'Disconnected', 'Parallel', 'Part']
123 | image_info_map = {}
124 |
125 | all_files = os.listdir(os.path.join('../vgg_utils/testData', 'GT'))
126 | all_files.sort()
127 | img_ids = []
128 | for file_name in all_files:
129 | img_idx = file_name[:file_name.find('_')]
130 | img_ids.append(img_idx)
131 | print(img_ids)
132 |
133 | if use_cuda:
134 | vgg_slim_model = vgg_slim_model.cuda()
135 |
136 | for img_id in img_ids:
137 | for image_type in image_types:
138 | finalstr = '_gt.png' if image_type == 'GT' else '_pred.png'
139 | img_path = os.path.join('../vgg_utils/testData', image_type, str(img_id) + finalstr)
140 | print(img_path)
141 | image = load_image(img_path)
142 | image_tensor = torch.tensor(np.expand_dims(image, axis=0)).float()
143 |
144 | if use_cuda:
145 | image_tensor = image_tensor.cuda()
146 |
147 | feature_maps = vgg_slim_model(image_tensor)
148 |             feature_maps_11 = feature_maps[vis_layer1].cpu().data.numpy()
149 |             feature_maps_33 = feature_maps[vis_layer2].cpu().data.numpy()
150 |
151 |             print(vis_layer1, feature_maps_11.shape)
152 |             print(vis_layer2, feature_maps_33.shape)
153 |
154 | feature_maps_val11 = np.transpose(feature_maps_11[0], (1, 2, 0))
155 | feature_maps_val33 = np.transpose(feature_maps_33[0], (1, 2, 0))
156 |
157 | if image_type != 'GT':
158 | feature_maps_gt11 = image_info_map['GT'][4]
159 | feat_diff_all11 = np.mean(np.abs(feature_maps_gt11 - feature_maps_val11), axis=-1)
160 | perc_layer_loss11 = np.mean(np.abs(feature_maps_gt11 - feature_maps_val11))
161 |
162 | feature_maps_gt33 = image_info_map['GT'][1]
163 | feat_diff_all33 = np.mean(np.abs(feature_maps_gt33 - feature_maps_val33), axis=-1)
164 | perc_layer_loss33 = np.mean(np.abs(feature_maps_gt33 - feature_maps_val33))
165 | # print('perc_layer_loss', image_type, perc_layer_loss)
166 | else:
167 | feat_diff_all11 = np.zeros_like(feature_maps_val11[:, :, 0])
168 | perc_layer_loss11 = 0.0
169 |
170 | feat_diff_all33 = np.zeros_like(feature_maps_val33[:, :, 0])
171 | perc_layer_loss33 = 0.0
172 |
173 | image_info_map[image_type] = [image,
174 | feature_maps_val33, feat_diff_all33, perc_layer_loss33,
175 | feature_maps_val11, feat_diff_all11, perc_layer_loss11]
176 |
177 | rows = 3
178 | cols = len(image_types)
179 | plt.figure(figsize=(4 * cols, 4 * rows))
180 |
181 | for image_type_i, image_type in enumerate(image_types):
182 | input_image = image_info_map[image_type][0]
183 | feat_diff33 = image_info_map[image_type][2]
184 | perc_loss33 = image_info_map[image_type][3]
185 | feat_diff11 = image_info_map[image_type][5]
186 | perc_loss11 = image_info_map[image_type][6]
187 |
188 | plt.subplot(rows, cols, image_type_i + 1)
189 | plt.title(image_type, fontsize=12)
190 | if image_type_i != 0:
191 | plt.axis('off')
192 | plt.imshow(input_image)
193 |
194 | plt.subplot(rows, cols, image_type_i + 1 + len(image_types))
195 | plt.title(str(perc_loss11), fontsize=12)
196 | if image_type_i != 0:
197 | plt.axis('off')
198 | plt.imshow(feat_diff11)
199 |
200 | plt.subplot(rows, cols, image_type_i + 1 + len(image_types) + len(image_types))
201 | plt.title(str(perc_loss33), fontsize=12)
202 | if image_type_i != 0:
203 | plt.axis('off')
204 | plt.imshow(feat_diff33)
205 |
206 | plt.show()
207 |
--------------------------------------------------------------------------------
/tools/svg_to_npz.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['CUDA_VISIBLE_DEVICES'] = '7'
3 |
4 | import numpy as np
5 | from PIL import Image
6 | import argparse
7 | import sys
8 |
9 | sys.path.append("..")
10 |
11 | import pydiffvg
12 | import torch
13 |
14 | import xml.etree.ElementTree as ET
15 | from xml.dom import minidom
16 | from svg.path import parse_path, path
17 |
18 | from dataset_utils.dataset_util import draw_segment_jointly, draw_segment_separately
19 |
20 |
21 | invalid_svg_shapes = ['rect', 'circle', 'ellipse', 'line', 'polyline', 'polygon']
22 |
23 |
24 | def parse_single_path(path_str):
25 | ps = parse_path(path_str)
26 | # print(len(ps))
27 |
28 | stroke_points_list = []
29 | control_points_list = []
30 | for item_i, path_item in enumerate(ps):
31 | path_type = type(path_item)
32 |
33 | if path_type == path.Move:
34 | # assert item_i == 0
35 | if item_i != 0:
36 | assert (len(control_points_list) - 1) % 3 == 0
37 | assert len(control_points_list) > 1
38 | stroke_points_list.append(control_points_list)
39 | control_points_list = []
40 |
41 | start = path_item.start
42 | start_x, start_y = start.real, start.imag
43 | control_points_list.append((start_x, start_y))
44 | elif path_type == path.CubicBezier:
45 | start, control1, control2, end = path_item.start, path_item.control1, path_item.control2, path_item.end
46 | start_x, start_y = start.real, start.imag
47 | control1_x, control1_y = control1.real, control1.imag
48 | control2_x, control2_y = control2.real, control2.imag
49 | end_x, end_y = end.real, end.imag
50 | control_points_list.append((control1_x, control1_y))
51 | control_points_list.append((control2_x, control2_y))
52 | control_points_list.append((end_x, end_y))
53 | elif path_type == path.Arc:
54 | raise Exception('Arc is here')
55 | elif path_type == path.Line:
56 | # assert len(control_points_list) == 1
57 | # start, end = path_item.start, path_item.end
58 | # start_x, start_y = start.real, start.imag
59 | # end_x, end_y = end.real, end.imag
60 | #
61 | # control1_x = 2.0 / 3.0 * start_x + 1.0 / 3.0 * end_x
62 | # control1_y = 2.0 / 3.0 * start_y + 1.0 / 3.0 * end_y
63 | # control2_x = 1.0 / 3.0 * start_x + 2.0 / 3.0 * end_x
64 | # control2_y = 1.0 / 3.0 * start_y + 2.0 / 3.0 * end_y
65 | #
66 | # control1 = (control1_x, control1_y)
67 | # control2 = (control2_x, control2_y)
68 | # control1_dist = sample_random_position(control1, 4.0, 1.0)
69 | # control2_dist = sample_random_position(control2, 4.0, 1.0)
70 | #
71 | # control_points_list.append(control1_dist)
72 | # control_points_list.append(control2_dist)
73 | # control_points_list.append((end_x, end_y))
74 | raise Exception('Line is here')
75 | elif path_type == path.Close:
76 | assert item_i == len(ps) - 1
77 | else:
78 | raise Exception('Unknown path_type', path_type)
79 |
80 | assert (len(control_points_list) - 1) % 3 == 0
81 | assert len(control_points_list) > 1
82 | stroke_points_list.append(control_points_list)
83 |
84 | return stroke_points_list
85 |
86 |
87 | def matrix_transform(points, matrix_params):
88 | # points: (N, 2), (x, y)
89 | # matrix_params: (6)
90 | new_points = []
91 | a, b, c, d, e, f = matrix_params
92 | matrix = np.array([[a, c, e],
93 | [b, d, f],
94 | [0, 0, 1]], dtype=np.float32)
95 | for point in points:
96 | point_vec = [point[0], point[1], 1]
97 | new_point = np.matmul(matrix, point_vec)[:2]
98 | new_points.append(new_point)
99 | # print(point, new_point)
100 | new_points = np.stack(new_points).astype(np.float32)
101 | return new_points
102 |
103 |
104 | def parse_svg(svg_file):
105 | tree = ET.parse(svg_file)
106 | root = tree.getroot()
107 |
108 | width = root.get('width')
109 | height = root.get('height')
110 | width = int(width[:-2])
111 | height = int(height[:-2])
112 |
113 | view_box = root.get('viewBox')
114 | view_x, view_y, view_width, view_height = view_box.split(' ')
115 | view_x, view_y, view_width, view_height = int(view_x), int(view_y), int(view_width), int(view_height)
116 | assert view_x == 0 and view_y == 0
117 | assert width == view_width and height == view_height
118 |
119 | strokes_list = []
120 |
121 | for elem in root.iter():
122 | try:
123 | _, tag_suffix = elem.tag.split('}')
124 | except ValueError:
125 | continue
126 |
127 | assert tag_suffix not in invalid_svg_shapes
128 |
129 | if tag_suffix == 'path':
130 | path_d = elem.attrib['d']
131 | control_points_single_stroke_list = parse_single_path(path_d) # (N_point, 2)
132 | for control_points_single_stroke_ in control_points_single_stroke_list:
133 | control_points_single_stroke = np.array(control_points_single_stroke_, dtype=np.float32)
134 | # print('control_points_single_stroke', control_points_single_stroke.shape)
135 |
136 | if 'transform' in elem.attrib.keys():
137 | transformation = elem.attrib['transform']
138 | if 'translate' in transformation:
139 | translate_xy = transformation[transformation.find('(')+1:transformation.find(')')]
140 | assert ', ' in translate_xy
141 | translate_x = float(translate_xy[:translate_xy.find(', ')])
142 | translate_y = float(translate_xy[translate_xy.find(', ')+2:])
143 |
144 | control_points_single_stroke[:, 0] += translate_x
145 | control_points_single_stroke[:, 1] += translate_y
146 | elif 'matrix' in transformation:
147 | matrix_params = transformation[transformation.find('(')+1:transformation.find(')')]
148 | matrix_params = matrix_params.split(' ')
149 | assert len(matrix_params) == 6
150 | matrix_params = [float(item) for item in matrix_params]
151 | control_points_single_stroke = matrix_transform(control_points_single_stroke, matrix_params)
152 | else:
153 | raise Exception('Error transformation')
154 |
155 | strokes_list.append(control_points_single_stroke)
156 |
157 | assert len(strokes_list) > 0
158 | return (view_width, view_height), strokes_list
159 |
160 |
161 | def svg_to_npz(database, ref_image):
162 | svg_base = os.path.join(database, 'svg')
163 | img_base = os.path.join(database, 'raster')
164 |
165 | save_base_parameter = os.path.join(database, 'parameter')
166 | save_base_color_separate = os.path.join(database, 'vector_vis')
167 | os.makedirs(save_base_parameter, exist_ok=True)
168 | os.makedirs(save_base_color_separate, exist_ok=True)
169 |
170 | ref_id = ref_image[:ref_image.find('.')]
171 | svg_file_path = os.path.join(svg_base, ref_id + '.svg')
172 | img_file_path = os.path.join(img_base, ref_image)
173 |
174 | img = Image.open(img_file_path).convert('RGB')
175 | img_width, img_height = img.width, img.height
176 | assert img_width == img_height
177 |
178 | view_sizes, strokes_list = parse_svg(svg_file_path)
179 | # strokes_list: list of (N_point, 2), in view size
180 | assert view_sizes[0] == view_sizes[1]
181 | view_size = view_sizes[0]
182 |
183 | norm_strokes_list = [] # list of (N_point, 2), in image size
184 | max_seq_num = 0
185 | for single_stroke in strokes_list:
186 | single_stroke_norm = single_stroke / float(view_size) * float(img_width)
187 | norm_strokes_list.append(single_stroke_norm)
188 | segment_num = (len(single_stroke_norm) - 1) // 3
189 | # print('segment_num', segment_num)
190 | max_seq_num += segment_num
191 | print('max_seq_num', max_seq_num)
192 |
193 | save_npz_path = os.path.join(save_base_parameter, ref_id + '.npz')
194 | np.savez(save_npz_path, strokes_data=norm_strokes_list, canvas_size=img_width)
195 |
196 | # visualization
197 | pydiffvg.set_use_gpu(torch.cuda.is_available())
198 |
199 | background_ = torch.ones(img_width, img_width, 4)
200 | render_ = pydiffvg.RenderFunction.apply
201 |
202 | stroke_thickness = 1.2
203 |
204 | # black_stroke_img = draw_segment_separately(norm_strokes_list, img_width, stroke_thickness, max_seq_num,
205 | # render_, background_, is_black=True)
206 | # black_stroke_img_save_path = os.path.join(img_base, ref_id + '-rendered.png')
207 | # pydiffvg.imwrite(black_stroke_img.cpu(), black_stroke_img_save_path, gamma=1.0)
208 |
209 | color_stroke_img = draw_segment_separately(norm_strokes_list, img_width, stroke_thickness, max_seq_num,
210 | render_, background_, is_black=False)
211 | color_stroke_img_save_path = os.path.join(save_base_color_separate, ref_id + '.png')
212 | pydiffvg.imwrite(color_stroke_img.cpu(), color_stroke_img_save_path, gamma=1.0)
213 |
214 |
215 | if __name__ == '__main__':
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument('--database', '-db', type=str, default='../sample_inputs/rough/', help="define the data base")
218 | parser.add_argument('--reference', '-ref', type=str, default='20-0.png', help="define the reference image")
219 | args = parser.parse_args()
220 |
221 | svg_to_npz(args.database, args.reference)
222 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/4-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/5-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/clean/svg/11-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/25-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/rnn2.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class HyperNorm(nn.Module):
8 | def __init__(self, input_size: int, num_units: int, hyper_embedding_size: int, use_bias: bool = True):
9 | super().__init__()
10 | self.num_units = num_units
11 | self.embedding_size = hyper_embedding_size
12 | self.use_bias = use_bias
13 |
14 | self.z_w = nn.Linear(input_size, self.embedding_size, bias=True)
15 | self.alpha = nn.Linear(self.embedding_size, self.num_units, bias=False)
16 |
17 | if self.use_bias:
18 | self.z_b = nn.Linear(input_size, self.embedding_size, bias=False)
19 | self.beta = nn.Linear(self.embedding_size, self.num_units, bias=False)
20 |
21 | def __call__(self, input: torch.Tensor, hyper_output: torch.Tensor):
22 | zw = self.z_w(hyper_output)
23 | alpha = self.alpha(zw)
24 | result = torch.mul(alpha, input)
25 |
26 | if self.use_bias:
27 | zb = self.z_b(hyper_output)
28 | beta = self.beta(zb)
29 | result = torch.add(result, beta)
30 |
31 | return result
32 |
33 |
34 | class LSTMCell(nn.Module):
35 | """
36 | ## Long Short-Term Memory Cell
37 | LSTM Cell computes $c$, and $h$. $c$ is like the long-term memory,
38 | and $h$ is like the short term memory.
39 | We use the input $x$ and $h$ to update the long term memory.
40 | In the update, some features of $c$ are cleared with a forget gate $f$,
41 | and some features $i$ are added through a gate $g$.
42 | The new short term memory is the $\tanh$ of the long-term memory
43 | multiplied by the output gate $o$.
44 | Note that the cell doesn't look at long term memory $c$ when doing the update. It only modifies it.
45 | Also $c$ never goes through a linear transformation.
46 | This is what solves vanishing and exploding gradients.
47 | Here's the update rule.
48 | \begin{align}
49 | c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
50 | h_t &= \sigma(o_t) \odot \tanh(c_t)
51 | \end{align}
52 | $\odot$ stands for element-wise multiplication.
53 | Intermediate values and gates are computed as linear transformations of the hidden
54 | state and input.
55 | \begin{align}
56 | i_t &= lin_x^i(x_t) + lin_h^i(h_{t-1}) \\
57 | f_t &= lin_x^f(x_t) + lin_h^f(h_{t-1}) \\
58 | g_t &= lin_x^g(x_t) + lin_h^g(h_{t-1}) \\
59 | o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1})
60 | \end{align}
61 | """
62 |
63 | def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False,
64 | forget_bias: float = 1.0):
65 | super().__init__()
66 |
67 | # These are the linear layer to transform the `input` and `hidden` vectors.
68 | # One of them doesn't need a bias since we add the transformations.
69 | self.forget_bias = forget_bias
70 |
71 | self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
72 | self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
73 |
74 | if layer_norm:
75 | self.layer_norm_all = nn.LayerNorm(4 * hidden_size)
76 | self.layer_norm_c = nn.LayerNorm(hidden_size)
77 | else:
78 | self.layer_norm_all = nn.Identity()
79 | self.layer_norm_c = nn.Identity()
80 |
81 | def __call__(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
82 | ifgo = self.input_lin(x) + self.hidden_lin(h)
83 | ifgo = self.layer_norm_all(ifgo)
84 |
85 | ifgo = ifgo.chunk(4, dim=-1)
86 | i, j, f, o = ifgo
87 |
88 | g = torch.tanh(j)
89 |
90 | c_next = c * torch.sigmoid(f + self.forget_bias) + torch.sigmoid(i) * g
91 | h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
92 |
93 | return h_next, c_next
94 |
95 |
96 | class LSTM(nn.Module):
97 | """
98 | ## Multilayer LSTM
99 | """
100 |
101 | def __init__(self, input_size: int, hidden_size: int, n_layers: int):
102 | """
103 | Create a network of `n_layers` of LSTM.
104 | """
105 |
106 | super().__init__()
107 | self.n_layers = n_layers
108 | self.hidden_size = hidden_size
109 | # Create cells for each layer. Note that only the first layer gets the input directly.
110 | # Rest of the layers get the input from the layer below
111 | self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
112 | [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
113 |
114 | def __call__(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
115 | """
116 | `x` has shape `[n_steps, batch_size, input_size]` and
117 | `state` is a tuple of $h$ and $c$, each with a shape of `[batch_size, hidden_size]`.
118 | """
119 | n_steps, batch_size = x.shape[:2]
120 |
121 | # Initialize the state if `None`
122 | if state is None:
123 | h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
124 | c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
125 | else:
126 | (h, c) = state
127 | # Reverse stack the tensors to get the states of each layer
128 |             # Note: you can just work with the tensor itself, but this is easier to debug
129 | h, c = list(torch.unbind(h)), list(torch.unbind(c))
130 |
131 | # Array to collect the outputs of the final layer at each time step.
132 | out = []
133 | for t in range(n_steps):
134 | # Input to the first layer is the input itself
135 | inp = x[t]
136 | # Loop through the layers
137 | for layer in range(self.n_layers):
138 | # Get the state of the layer
139 | h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])
140 | # Input to the next layer is the state of this layer
141 | inp = h[layer]
142 | # Collect the output $h$ of the final layer
143 | out.append(h[-1])
144 |
145 | # Stack the outputs and states
146 | out = torch.stack(out)
147 | h = torch.stack(h)
148 | c = torch.stack(c)
149 |
150 | return out, (h, c)
151 |
152 |
153 | class HyperLSTMCell(nn.Module):
154 | """
155 | ## HyperLSTM Cell
156 | For HyperLSTM the smaller network and the larger network both have the LSTM structure.
157 | This is defined in Appendix A.2.2 in the paper.
158 | """
159 |
160 | def __init__(self, input_size: int, num_units: int, hyper_num_units: int = 256, hyper_embedding_size: int = 32,
161 | forget_bias: float = 1.0):
162 | """
163 | `input_size` is the size of the input $x_t$,
164 | `num_units` is the size of the LSTM, and
165 | `hyper_num_units` is the size of the smaller LSTM that alters the weights of the larger outer LSTM.
166 | `hyper_embedding_size` is the size of the feature vectors used to alter the LSTM weights.
167 | We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and
168 | $z_b^{i,f,g,o}$ using linear transformations.
169 | We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and
170 | $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again.
171 | These are then used to scale the rows of weight and bias tensors of the main LSTM.
172 |         Note: since the computations of $z$ and $d$ are two sequential linear transformations,
173 |         they could be combined into a single linear transformation.
174 |         However, we implement them separately so that the code matches the description
175 |         in the paper.
176 | """
177 | super().__init__()
178 | self.forget_bias = forget_bias
179 |
180 | self.hyper = LSTMCell(num_units + input_size, hyper_num_units, layer_norm=True)
181 |
182 | self.w_x = nn.Linear(input_size, 4 * num_units, bias=False)
183 | self.w_h = nn.Linear(num_units, 4 * num_units, bias=False)
184 |
185 |         self.hyper_ix = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=False)
186 |         self.hyper_jx = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=False)
187 |         self.hyper_fx = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=False)
188 |         self.hyper_ox = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=False)
189 |
190 |         self.hyper_ih = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=True)
191 |         self.hyper_jh = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=True)
192 |         self.hyper_fh = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=True)
193 |         self.hyper_oh = HyperNorm(hyper_num_units, num_units, hyper_embedding_size, use_bias=True)
194 |
195 | # zero initialization
196 | self.bias = nn.Parameter(torch.zeros(4 * num_units))
197 |
198 | self.layer_norm_all = nn.LayerNorm(num_units * 4)
199 | self.layer_norm_c = nn.LayerNorm(num_units)
200 |
201 | def __call__(self, x: torch.Tensor,
202 | h: torch.Tensor, c: torch.Tensor,
203 | h_hat: torch.Tensor, c_hat: torch.Tensor):
204 | hyper_input = torch.cat([x, h], dim=-1)
205 | h_hat, c_hat = self.hyper(hyper_input, h_hat, c_hat)
206 | hyper_output = h_hat
207 |
208 | xh = self.w_x(x)
209 | hh = self.w_h(h)
210 |
211 | ix, jx, fx, ox = torch.chunk(xh, 4, 1)
212 | ix = self.hyper_ix(ix, hyper_output)
213 | jx = self.hyper_jx(jx, hyper_output)
214 | fx = self.hyper_fx(fx, hyper_output)
215 | ox = self.hyper_ox(ox, hyper_output)
216 |
217 | ih, jh, fh, oh = torch.chunk(hh, 4, 1)
218 | ih = self.hyper_ih(ih, hyper_output)
219 | jh = self.hyper_jh(jh, hyper_output)
220 | fh = self.hyper_fh(fh, hyper_output)
221 | oh = self.hyper_oh(oh, hyper_output)
222 |
223 | ib, jb, fb, ob = torch.chunk(self.bias, 4, 0)
224 |
225 | i = ix + ih + ib
226 | j = jx + jh + jb
227 | f = fx + fh + fb
228 | o = ox + oh + ob
229 |
230 | concat = torch.cat([i, j, f, o], 1)
231 | concat = self.layer_norm_all(concat)
232 | i, j, f, o = torch.chunk(concat, 4, 1)
233 |
234 | g = torch.tanh(j)
235 |
236 | c_next = c * torch.sigmoid(f + self.forget_bias) + torch.sigmoid(i) * g
237 | h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
238 |
239 | return h_next, c_next, h_hat, c_hat
240 |
241 |
242 | class HyperLSTM(nn.Module):
243 | """
244 | # HyperLSTM module
245 | """
246 | def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
247 | """
248 | Create a network of `n_layers` of HyperLSTM.
249 | """
250 |
251 | super().__init__()
252 |
253 | # Store sizes to initialize state
254 | self.n_layers = n_layers
255 | self.hidden_size = hidden_size
256 | self.hyper_size = hyper_size
257 |
258 | # Create cells for each layer. Note that only the first layer gets the input directly.
259 | # Rest of the layers get the input from the layer below
260 | self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
261 | [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
262 | range(n_layers - 1)])
263 |
264 | def __call__(self, x: torch.Tensor,
265 | state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
266 | """
267 | * `x` has shape `[n_steps, batch_size, input_size]` and
268 | * `state` is a tuple of $h, c, \hat{h}, \hat{c}$.
269 | $h, c$ have shape `[batch_size, hidden_size]` and
270 | $\hat{h}, \hat{c}$ have shape `[batch_size, hyper_size]`.
271 | """
272 | n_steps, batch_size = x.shape[:2]
273 |
274 | # Initialize the state with zeros if `None`
275 | if state is None:
276 | h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
277 | c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
278 | h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
279 | c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
280 | #
281 | else:
282 | (h, c, h_hat, c_hat) = state
283 | # Reverse stack the tensors to get the states of each layer
284 | #
285 |             # Note: you can just work with the tensor itself, but this is easier to debug
286 | h, c = list(torch.unbind(h)), list(torch.unbind(c))
287 | h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
288 |
289 | # Collect the outputs of the final layer at each step
290 | out = []
291 | for t in range(n_steps):
292 | # Input to the first layer is the input itself
293 | inp = x[t]
294 | # Loop through the layers
295 | for layer in range(self.n_layers):
296 | # Get the state of the layer
297 | h[layer], c[layer], h_hat[layer], c_hat[layer] = \
298 | self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
299 | # Input to the next layer is the state of this layer
300 | inp = h[layer]
301 | # Collect the output $h$ of the final layer
302 | out.append(h[-1])
303 |
304 | # Stack the outputs and states
305 | out = torch.stack(out)
306 | h = torch.stack(h)
307 | c = torch.stack(c)
308 | h_hat = torch.stack(h_hat)
309 | c_hat = torch.stack(c_hat)
310 |
311 | #
312 | return out, (h, c, h_hat, c_hat)
313 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/27-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/22-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------
/sample_inputs/rough/svg/24-0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
14 |
--------------------------------------------------------------------------------