├── .gitignore ├── LICENSE ├── README.md ├── assets ├── examples │ ├── sample-0.png │ ├── sample-1.png │ ├── sample-15.png │ ├── sample-16.png │ ├── sample-17.png │ ├── sample-18.png │ ├── sample-4.png │ ├── sample-6.png │ └── sample-7.png ├── starvector-arch.png ├── starvector-teaser.png └── starvector-xyz.png ├── configs ├── accelerate │ ├── 1-gpu.yaml │ ├── 2-gpu.yaml │ ├── 4-gpu.yaml │ ├── 8-gpu.yaml │ ├── deepspeed-1-gpu.yaml │ ├── deepspeed-2-gpu.yaml │ ├── deepspeed-4-gpu.yaml │ ├── deepspeed-8-gpu.yaml │ ├── deespeed.json │ └── val-deepspeed-1-gpu.yaml ├── chat-template.jinja ├── generation │ ├── hf │ │ ├── starvector-1b │ │ │ └── im2svg.yaml │ │ └── starvector-8b │ │ │ └── im2svg.yaml │ └── vllm │ │ ├── starvector-1b │ │ └── im2svg.yaml │ │ └── starvector-8b │ │ └── im2svg.yaml ├── metrics │ ├── im2svg.yaml │ └── text2svg.yaml └── models │ ├── default.yaml │ ├── starvector-1b │ ├── im2svg-emoji.yaml │ ├── im2svg-fonts.yaml │ ├── im2svg-icons.yaml │ ├── im2svg-stack.yaml │ ├── text2svg-figr.yaml │ └── text2svg-stack.yaml │ └── starvector-8b │ ├── im2svg-emoji.yaml │ ├── im2svg-fonts-simple.yaml │ ├── im2svg-fonts.yaml │ ├── im2svg-icons.yaml │ ├── im2svg-stack.yaml │ ├── text2svg-figr.yaml │ └── text2svg-stack.yaml ├── docker ├── Dockerfile ├── Image2SVG Generation Example.ipynb └── README.md ├── pyproject.toml ├── scripts ├── quickstart-hf.py ├── quickstart-vllm.py ├── quickstart.py ├── train │ ├── train-starvector-1b-im2svg.sh │ ├── train-starvector-1b-text2svg.sh │ ├── train-starvector-8b-im2svg.sh │ └── train-starvector-8b-text2svg.sh └── validation │ ├── validate-starvector-1b-im2svg.sh │ └── validate-starvector-8b-im2svg.sh └── starvector ├── __init__.py ├── adapter.py ├── clip_model.py ├── data ├── augmentation.py ├── base.py ├── dataset.py ├── emojisvg.py ├── figrsvg.py ├── fontsvg.py ├── iconsvg.py ├── stacksvg.py └── util.py ├── image_encoder.py ├── metrics ├── base_metric.py ├── compute_LPIPS.py ├── compute_SSIM.py ├── compute_clip_score.py ├── compute_dino_score.py ├── compute_fid.py ├── compute_l2.py ├── count_token_length.py ├── inception.py ├── metrics.py └── util.py ├── model ├── adapters │ └── adapter.py ├── builder.py ├── gpt_bigcode │ ├── __init__.py │ ├── configuration_gpt_bigcode.py │ └── modeling_gpt_bigcode.py ├── image_encoder │ ├── clip_model.py │ └── image_encoder.py ├── llm │ ├── starcoder.py │ └── starcoder2.py ├── models │ ├── starvector_base.py │ ├── starvector_v1.py │ └── starvector_v2.py └── starvector_arch.py ├── serve ├── __init__.py ├── constants.py ├── controller.py ├── conversation.py ├── examples │ ├── sample-0.png │ ├── sample-1.png │ ├── sample-16.png │ ├── sample-17.png │ ├── sample-18.png │ ├── sample-4.png │ ├── sample-6.png │ └── sample-7.png ├── gradio_demo_with_updated_gradio.py ├── gradio_web_server.py ├── model_worker.py ├── register_worker.py ├── util.py └── vllm_api_gradio │ ├── controller.py │ ├── gradio_vllm.py │ ├── gradio_web_server.py │ ├── model_worker.py │ └── scroll.js ├── train ├── train.py ├── util.py └── zero_to_fp32.py ├── util.py └── validation ├── README.md ├── __init__.py ├── starvector_hf_validator.py ├── starvector_vllm_api_svg_validator.py ├── starvector_vllm_svg_validator.py ├── svg_validator_base.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | 
dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Other 163 | *vscode* 164 | *egg* 165 | *nfs* 166 | *conv.json* 167 | *rebuttal* 168 | *.log* 169 | *remove_files* 170 | *wandb* 171 | *tmp* 172 | *vscode* 173 | *.csv 174 | *avoid_samples* 175 | *logs* 176 | *results* 177 | *.pickle 178 | *.pkl 179 | *internal* 180 | *test.png* 181 | assets/reward_assets -------------------------------------------------------------------------------- /assets/examples/sample-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-0.png -------------------------------------------------------------------------------- /assets/examples/sample-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-1.png -------------------------------------------------------------------------------- /assets/examples/sample-15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-15.png -------------------------------------------------------------------------------- /assets/examples/sample-16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-16.png -------------------------------------------------------------------------------- /assets/examples/sample-17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-17.png -------------------------------------------------------------------------------- /assets/examples/sample-18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-18.png -------------------------------------------------------------------------------- /assets/examples/sample-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-4.png -------------------------------------------------------------------------------- /assets/examples/sample-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-6.png -------------------------------------------------------------------------------- /assets/examples/sample-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/examples/sample-7.png -------------------------------------------------------------------------------- /assets/starvector-arch.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/starvector-arch.png -------------------------------------------------------------------------------- /assets/starvector-teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/starvector-teaser.png -------------------------------------------------------------------------------- /assets/starvector-xyz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/assets/starvector-xyz.png -------------------------------------------------------------------------------- /configs/accelerate/1-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | dynamo_backend: 'NO' 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: 'bf16' 12 | num_machines: 1 13 | num_processes: 1 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/2-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | dynamo_backend: 'NO' 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: 'bf16' 12 | num_machines: 1 13 | num_processes: 2 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/4-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | dynamo_backend: 'NO' 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: 'bf16' 12 | num_machines: 1 13 | num_processes: 4 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/8-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | dynamo_backend: 'NO' 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: 'bf16' 12 | num_machines: 1 13 | num_processes: 8 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/deepspeed-1-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | 
distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 1 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /configs/accelerate/deepspeed-2-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 2 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /configs/accelerate/deepspeed-4-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 4 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/deepspeed-8-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /configs/accelerate/deespeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": false 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "offload_optimizer": { 8 | "device": "cpu" 9 | }, 10 | "offload_param": { 11 | "device": "cpu" 12 | }, 13 | "overlap_comm": true, 14 | "contiguous_gradients": true, 15 | "reduce_bucket_size": "auto", 16 | "stage3_prefetch_bucket_size": "auto", 17 | "stage3_param_persistence_threshold": "auto", 18 | "sub_group_size": 1e9, 19 | "stage3_max_live_parameters": 1e9, 20 | "stage3_max_reuse_distance": 1e9, 21 | "stage3_gather_16bit_weights_on_model_save": true 22 | }, 23 | "gradient_accumulation_steps": 4, 24 | "gradient_clipping": "auto", 25 | "steps_per_print": 2000, 26 | "train_batch_size": "auto", 27 | "train_micro_batch_size_per_gpu": "auto", 28 | 
"wall_clock_breakdown": false 29 | } 30 | -------------------------------------------------------------------------------- /configs/accelerate/val-deepspeed-1-gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_clipping: 1.0 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 1 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /configs/chat-template.jinja: -------------------------------------------------------------------------------- 1 | {% for message in messages %}{{ message.content }}{% endfor %} -------------------------------------------------------------------------------- /configs/generation/hf/starvector-1b/im2svg.yaml: -------------------------------------------------------------------------------- 1 | # General configuration 2 | run: 3 | project_name: "starvector-RL-eval" 4 | out_dir: "eval_results" 5 | device: cuda 6 | report_to: wandb 7 | run_id: test-run 8 | log_images: false 9 | 10 | # Model configuration 11 | model: 12 | name: "starvector/starvector-1b-im2svg" # Required: Model name for HF-based model 13 | from_checkpoint: false 14 | generation_engine: "hf" 15 | task: im2svg 16 | torch_dtype: float16 17 | # image_processor: clip # is this needed? 18 | 19 | # Dataset configuration 20 | dataset: 21 | dataset_name: starvector/svg-stack-RL # Required: Name of the dataset to evaluate on 22 | config_name: null # in bigodcs set Image2SVG 23 | split: test 24 | batch_size: 8 25 | num_workers: 4 26 | im_size: 224 27 | num_samples: -1 28 | 29 | # vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html 30 | # hf https://huggingface.co/docs/transformers/main_classes/text_generation 31 | generation_params: 32 | # Text generation parameters 33 | max_length: 7800 34 | min_length: 10 35 | num_beams: 1 36 | temperature: 0.2 37 | generation_sweep: false # Controls multi-temperature sampling, rank based sampling 38 | # num_generations_different_temp: 1 39 | # min_temperature: 0.0 40 | # max_temperature: 0.5 41 | num_captions: 1 42 | repetition_penalty: 1.0 43 | length_penalty: 1.0 44 | presence_penalty: 0.0 # only used in vllm 45 | frequency_penalty: 0.0 46 | top_p: 0.95 47 | do_sample: true # turn this off for greedy decoding 48 | use_nucleus_sampling: true 49 | logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id 50 | stream: false 51 | 52 | 53 | -------------------------------------------------------------------------------- /configs/generation/hf/starvector-8b/im2svg.yaml: -------------------------------------------------------------------------------- 1 | # General configuration 2 | run: 3 | project_name: "starvector-RL-eval" 4 | out_dir: "eval_results" 5 | device: cuda 6 | report_to: wandb 7 | run_id: test-run 8 | log_images: false 9 | 10 | # Model configuration 11 | model: 12 | name: "starvector/starvector-8b-im2svg" 13 | from_checkpoint: false 14 | generation_engine: "hf" 15 | task: im2svg 16 | torch_dtype: bfloat16 17 | # image_processor: siglip_384 18 | 19 | # Dataset configuration 20 | dataset: 21 | 
dataset_name: starvector/svg-stack-RL 22 | config_name: null 23 | split: test 24 | batch_size: 2 25 | num_workers: 4 26 | im_size: 384 27 | num_samples: -1 28 | 29 | # vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html 30 | # hf https://huggingface.co/docs/transformers/main_classes/text_generation 31 | generation_params: 32 | max_length: 16000 33 | min_length: 10 34 | num_beams: 1 35 | temperature: 0.7 36 | generation_sweep: false # Controls multi-temperature sampling, rank based sampling 37 | # num_generations_different_temp: 1 38 | # min_temperature: 0.0 39 | # max_temperature: 0.5 40 | num_captions: 1 41 | repetition_penalty: 1.0 42 | length_penalty: 0.5 43 | presence_penalty: 0.0 # only used in vllm 44 | frequency_penalty: 0.0 45 | top_p: 0.95 46 | do_sample: true # turn this off for greedy decoding 47 | use_nucleus_sampling: true 48 | logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id 49 | stream: false 50 | 51 | -------------------------------------------------------------------------------- /configs/generation/vllm/starvector-1b/im2svg.yaml: -------------------------------------------------------------------------------- 1 | # General configuration 2 | run: 3 | project_name: "starvector-RL-eval" 4 | out_dir: "eval_results" 5 | report_to: wandb 6 | run_id: test-eval3 7 | log_images: false 8 | 9 | # Model configuration 10 | model: 11 | name: "starvector/starvector-1b-im2svg" # Required: Model name for HF-based model 12 | from_checkpoint: false 13 | generation_engine: "vllm" 14 | task: im2svg 15 | torch_dtype: float16 16 | # image_processor: clip # is this needed? 17 | 18 | # Dataset configuration 19 | dataset: 20 | dataset_name: starvector/svg-stack-RL # Required: Name of the dataset to evaluate on 21 | config_name: null # in bigodcs set Image2SVG 22 | split: test 23 | batch_size: 8 24 | num_workers: 8 25 | im_size: 224 26 | num_samples: 500 27 | 28 | # vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html 29 | # hf https://huggingface.co/docs/transformers/main_classes/text_generation 30 | generation_params: 31 | # Text generation parameters 32 | max_length: 7933 # 8192 - (visual tokens) 33 | num_generations: 1 34 | min_length: 10 35 | num_beams: 1 36 | temperature: 0.5 37 | generation_sweep: false # Controls multi-temperature sampling, rank based sampling 38 | # num_generations_different_temp: 5 39 | # min_temperature: 0.0 40 | # max_temperature: 0.5 41 | num_captions: 1 42 | frequency_penalty: 0.0 43 | presence_penalty: 0.0 44 | repetition_penalty: 1.0 45 | top_p: 0.9 46 | logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id 47 | min_p: 0.0 48 | top_k: -1 49 | stream: false 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/generation/vllm/starvector-8b/im2svg.yaml: -------------------------------------------------------------------------------- 1 | # General configuration 2 | run: 3 | project_name: "starvector-RL-eval" 4 | out_dir: "eval_results" 5 | report_to: wandb 6 | run_id: test-eval3 7 | log_images: false 8 | 9 | # Model configuration 10 | model: 11 | name: "starvector/starvector-8b-im2svg" # Required: Model name for HF-based model 12 | from_checkpoint: false 13 | generation_engine: "vllm" 14 | task: im2svg 15 | torch_dtype: float16 16 | # image_processor: clip # is this needed? 
17 | 18 | # Dataset configuration 19 | dataset: 20 | dataset_name: starvector/svg-stack # Required: Name of the dataset to evaluate on 21 | config_name: null # in bigodcs set Image2SVG 22 | split: test 23 | batch_size: 8 24 | num_workers: 8 25 | im_size: 224 26 | num_samples: 500 27 | 28 | # vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html 29 | # hf https://huggingface.co/docs/transformers/main_classes/text_generation 30 | generation_params: 31 | # Text generation parameters 32 | max_length: 7933 # 8192 - (visual tokens) 33 | num_generations: 1 34 | min_length: 10 35 | num_beams: 1 36 | temperature: 0.5 37 | generation_sweep: false # Controls multi-temperature sampling, rank based sampling 38 | # num_generations_different_temp: 5 39 | # min_temperature: 0.0 40 | # max_temperature: 0.5 41 | num_captions: 1 42 | frequency_penalty: 0.0 43 | presence_penalty: 0.0 44 | repetition_penalty: 1.0 45 | top_p: 0.9 46 | logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id 47 | min_p: 0.0 48 | top_k: -1 49 | stream: false 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/metrics/im2svg.yaml: -------------------------------------------------------------------------------- 1 | metrics: 2 | L2: true 3 | Masked-L2: false 4 | LPIPS: true 5 | SSIM: true 6 | FID: false 7 | FID_clip: false 8 | CLIPScore: false 9 | CountTokenLength: true 10 | ratio_post_processed: false 11 | ratio_non_compiling: false 12 | DinoScore: true 13 | -------------------------------------------------------------------------------- /configs/metrics/text2svg.yaml: -------------------------------------------------------------------------------- 1 | metrics: 2 | L2: false 3 | Masked-L2: false 4 | LPIPS: false 5 | SSIM: false 6 | FID: true 7 | FID_clip: true 8 | CLIPScore: true 9 | CountTokenLength: true 10 | ratio_post_processed: true 11 | ratio_non_compiling: true 12 | DinoScore: false 13 | -------------------------------------------------------------------------------- /configs/models/default.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-im2svg 3 | use_wandb: false 4 | entity: abc 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: null # in case of creating a new model, set this to None (null) 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: glorot 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: None # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 3 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | start_generation_at_step: 0 37 | train_image_encoder: true 38 | train_LLM: true 39 | use_gradient_checkpointing: false 40 | fsdp: 41 | enable: false 42 | data: 43 | num_workers: 4 44 | train: 45 | batch_size: 2 46 | target: starvector.data.stacksvg.SVGStackDataset 47 | params: 48 | split: train 49 | dataset_name: ServiceNow/svg-stack 50 | im_size: 224 51 | num_samples: -1 52 | transforms: false 53 | select_dataset_name: false 54 | 
test: 55 | batch_size: 2 56 | target: starvector.data.stacksvg.SVGStackDataset 57 | params: 58 | split: test 59 | dataset_name: ServiceNow/svg-stack 60 | im_size: 224 61 | num_samples: -1 62 | transforms: false 63 | select_dataset_name: false 64 | generation: 65 | max_length: 8192 66 | min_length: 10 67 | num_beams: 3 68 | num_captions: 1 69 | num_generations_different_temp: 1.5 70 | start_temperature: 0.5 71 | repetition_penalty: 1.0 72 | length_penalty: 1.0 73 | temperature: 1.0 74 | top_p: 0.9 75 | use_nucleus_sampling: true 76 | im_size: 224 77 | dpi: 2 78 | scale: 300 79 | num_samples_to_generate: -1 80 | log_wandb_images: true 81 | start_generation_at_step: -1 82 | metrics: 83 | L2: true 84 | Masked-L2: false 85 | LPIPS: true 86 | SSIM: true 87 | FID: false 88 | FID_clip: false 89 | CLIPScore: false 90 | CountTokenLength: true 91 | ratio_post_processed: false 92 | ratio_non_compiling: false 93 | DinoScore: true 94 | -------------------------------------------------------------------------------- /configs/models/starvector-1b/im2svg-emoji.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-im2svg 3 | use_wandb: false 4 | entity: abc 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: null # in case of creating a new model, set this to None (null) 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: batch_norm 14 | init_type: glorot 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: null # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 3 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 10 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | train_image_encoder: true 37 | train_LLM: true 38 | use_gradient_checkpointing: false 39 | fsdp: 40 | enable: false 41 | data: 42 | num_workers: 16 43 | train: 44 | batch_size: 2 45 | target: starvector.data.emojisvg.EmojiSVGDataset 46 | params: 47 | split: train 48 | dataset_name: starvector/svg-emoji 49 | im_size: 224 50 | num_samples: -1 51 | transforms: false 52 | select_dataset_name: false 53 | test: 54 | batch_size: 8 55 | target: starvector.data.emojisvg.EmojiSVGDataset 56 | params: 57 | split: test 58 | dataset_name: starvector/svg-emoji 59 | im_size: 224 60 | num_samples: -1 61 | transforms: false 62 | select_dataset_name: false 63 | generation: 64 | max_length: 8192 65 | min_length: 10 66 | num_beams: 3 67 | temperature: 1.0 68 | num_captions: 1 69 | repetition_penalty: 1.0 70 | length_penalty: 0.5 71 | top_p: 0.95 72 | use_nucleus_sampling: true 73 | im_size: 224 74 | dpi: 2 75 | scale: 300 76 | metrics: 77 | L2: true 78 | Masked-L2: false 79 | LPIPS: true 80 | SSIM: true 81 | FID: false 82 | FID_clip: false 83 | CLIPScore: false 84 | CountTokenLength: true 85 | ratio_post_processed: false 86 | ratio_non_compiling: false 87 | DinoScore: true 88 | -------------------------------------------------------------------------------- /configs/models/starvector-1b/im2svg-fonts.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-im2svg 3 | use_wandb: false 
4 | entity: abc 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: null # in case of creating a new model, set this to None (null) 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: batch_norm 14 | init_type: glorot 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: null # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 3 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | train_image_encoder: true 37 | train_LLM: true 38 | use_gradient_checkpointing: false 39 | fsdp: 40 | enable: false 41 | data: 42 | num_workers: 16 43 | train: 44 | batch_size: 4 45 | target: starvector.data.fontsvg.FontSVGDataset 46 | params: 47 | split: train 48 | dataset_name: starvector/svg-fonts 49 | im_size: 224 50 | num_samples: -1 51 | transforms: false 52 | select_dataset_name: false 53 | test: 54 | batch_size: 8 55 | target: starvector.data.fontsvg.FontSVGDataset 56 | params: 57 | split: test 58 | dataset_name: starvector/svg-fonts 59 | im_size: 224 60 | num_samples: 1000 61 | transforms: false 62 | select_dataset_name: false 63 | generation: 64 | max_length: 8192 65 | min_length: 10 66 | num_beams: 3 67 | temperature: 1.0 68 | num_captions: 1 69 | repetition_penalty: 1.0 70 | length_penalty: 0.5 71 | top_p: 0.95 72 | use_nucleus_sampling: true 73 | im_size: 224 74 | dpi: 2 75 | scale: 300 76 | metrics: 77 | L2: true 78 | Masked-L2: false 79 | LPIPS: true 80 | SSIM: true 81 | FID: false 82 | FID_clip: false 83 | CLIPScore: false 84 | CountTokenLength: true 85 | ratio_post_processed: false 86 | ratio_non_compiling: false 87 | DinoScore: true 88 | -------------------------------------------------------------------------------- /configs/models/starvector-1b/im2svg-icons.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-im2svg 3 | use_wandb: false 4 | entity: abc 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: null # in case of creating a new model, set this to None (null) 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: batch_norm 14 | init_type: glorot 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: null # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 3 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 10 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | train_image_encoder: true 37 | train_LLM: true 38 | use_gradient_checkpointing: false 39 | fsdp: 40 | enable: false 41 | data: 42 | num_workers: 16 43 | train: 44 | batch_size: 4 45 | target: starvector.data.iconsvg.SVGIconsDataset 46 | params: 47 | split: train 48 | dataset_name: starvector/svg-icons 49 | im_size: 224 50 | 
num_samples: -1 51 | transforms: false 52 | select_dataset_name: false 53 | test: 54 | batch_size: 8 55 | target: starvector.data.iconsvg.SVGIconsDataset 56 | params: 57 | split: test 58 | dataset_name: starvector/svg-icons 59 | im_size: 224 60 | num_samples: 1000 61 | transforms: false 62 | select_dataset_name: false 63 | generation: 64 | max_length: 8192 65 | min_length: 10 66 | num_beams: 3 67 | temperature: 1.0 68 | num_captions: 1 69 | repetition_penalty: 1.0 70 | length_penalty: 0.5 71 | top_p: 0.95 72 | use_nucleus_sampling: true 73 | im_size: 224 74 | dpi: 2 75 | scale: 300 76 | metrics: 77 | L2: true 78 | Masked-L2: false 79 | LPIPS: true 80 | SSIM: true 81 | FID: false 82 | FID_clip: false 83 | CLIPScore: false 84 | CountTokenLength: true 85 | ratio_post_processed: false 86 | ratio_non_compiling: false 87 | DinoScore: true 88 | -------------------------------------------------------------------------------- /configs/models/starvector-1b/im2svg-stack.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-im2svg 3 | use_wandb: false 4 | entity: abc 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: null # in case of creating a new model, set this to None (null) 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: batch_norm 14 | init_type: glorot 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: null # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 3 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | train_image_encoder: true 37 | train_LLM: true 38 | use_gradient_checkpointing: false 39 | fsdp: 40 | enable: false 41 | data: 42 | num_workers: 4 43 | train: 44 | batch_size: 2 45 | target: starvector.data.stacksvg.SVGStackDataset 46 | params: 47 | split: train 48 | dataset_name: starvector/svg-stack 49 | im_size: 224 50 | num_samples: -1 51 | transforms: false 52 | select_dataset_name: false 53 | test: 54 | batch_size: 2 55 | target: starvector.data.stacksvg.SVGStackDataset 56 | params: 57 | split: test 58 | dataset_name: starvector/svg-stack 59 | im_size: 224 60 | num_samples: -1 61 | transforms: false 62 | select_dataset_name: false 63 | generation: 64 | max_length: 8192 65 | min_length: 10 66 | num_beams: 3 67 | temperature: 1.0 68 | num_captions: 1 69 | repetition_penalty: 1.0 70 | length_penalty: 0.5 71 | top_p: 0.95 72 | use_nucleus_sampling: true 73 | im_size: 224 74 | dpi: 2 75 | scale: 300 76 | metrics: 77 | L2: true 78 | Masked-L2: false 79 | LPIPS: true 80 | SSIM: true 81 | FID: false 82 | FID_clip: false 83 | CLIPScore: false 84 | CountTokenLength: true 85 | ratio_post_processed: false 86 | ratio_non_compiling: false 87 | DinoScore: true 88 | -------------------------------------------------------------------------------- /configs/models/starvector-1b/text2svg-figr.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-text2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | 
model_name: starvector/starvector-1b-im2svg 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: text2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 2 27 | lr: 2e-5 28 | gradient_accumulation_steps: 8 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 100 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: false 37 | train_image_encoder: false 38 | train_LLM: true 39 | fsdp: 40 | enable: false # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 4 52 | target: starvector.data.figrsvg.FigrSVGDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/FIGR-SVG 56 | im_size: 224 57 | num_samples: -1 58 | transforms: false 59 | select_dataset_name: false 60 | test: 61 | batch_size: 8 62 | target: starvector.data.figrsvg.FigrSVGDataset 63 | params: 64 | split: test 65 | dataset_name: starvector/FIGR-SVG 66 | im_size: 224 67 | num_samples: -1 68 | transforms: false 69 | select_dataset_name: false 70 | generation: 71 | max_length: 10000 72 | min_length: 10 73 | num_beams: 3 74 | temperature: 1.0 75 | num_captions: 1 76 | repetition_penalty: 1.0 77 | length_penalty: 0.5 78 | top_p: 0.95 79 | use_nucleus_sampling: true 80 | im_size: 384 81 | dpi: 2 82 | scale: 300 83 | metrics: 84 | L2: false 85 | Masked-L2: false 86 | LPIPS: false 87 | SSIM: false 88 | FID: false 89 | FID_clip: false 90 | CLIPScore: true 91 | CountTokenLength: true 92 | ratio_post_processed: false 93 | ratio_non_compiling: false 94 | DinoScore: false -------------------------------------------------------------------------------- /configs/models/starvector-1b/text2svg-stack.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-1b-text2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 8192 8 | model_name: starvector/starvector-1b-im2svg 9 | starcoder_model_name: bigcode/starcoderbase-1b 10 | pretrained: true 11 | image_encoder_type: clip 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: text2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 2 27 | lr: 2e-5 28 | gradient_accumulation_steps: 8 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 100 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: false 37 | train_image_encoder: false 38 | train_LLM: true 39 | fsdp: 40 | enable: false # TODO: set 
this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.stacksvg.SVGStackDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/text2svg-stack 56 | im_size: 224 57 | num_samples: -1 58 | transforms: false 59 | select_dataset_name: false 60 | image_processor: siglip_384 61 | test: 62 | batch_size: 4 63 | target: starvector.data.stacksvg.SVGStackDataset 64 | params: 65 | split: test 66 | dataset_name: starvector/text2svg-stack 67 | im_size: 224 68 | num_samples: 64 69 | transforms: false 70 | select_dataset_name: false 71 | image_processor: siglip_384 72 | generation: 73 | max_length: 10000 74 | min_length: 10 75 | num_beams: 3 76 | temperature: 1.0 77 | num_captions: 1 78 | repetition_penalty: 1.0 79 | length_penalty: 0.5 80 | top_p: 0.95 81 | use_nucleus_sampling: true 82 | im_size: 384 83 | dpi: 2 84 | scale: 300 85 | metrics: 86 | L2: false 87 | Masked-L2: false 88 | LPIPS: false 89 | SSIM: false 90 | FID: false 91 | FID_clip: false 92 | CLIPScore: true 93 | CountTokenLength: true 94 | ratio_post_processed: false 95 | ratio_non_compiling: false 96 | DinoScore: false -------------------------------------------------------------------------------- /configs/models/starvector-8b/im2svg-emoji.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-im2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.emojisvg.EmojiSVGDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/svg-emoji 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | image_processor: siglip_384 60 | test: 61 | batch_size: 1 62 | target: starvector.data.emojisvg.EmojiSVGDataset 63 | params: 64 | split: test 65 | dataset_name: starvector/svg-emoji 66 | im_size: 384 67 | num_samples: 128 68 | transforms: false 69 | image_processor: siglip_384 70 | generation: 71 
| max_length: 10000 72 | min_length: 10 73 | num_beams: 3 74 | temperature: 1.0 75 | num_captions: 1 76 | repetition_penalty: 1.0 77 | length_penalty: 0.5 78 | top_p: 0.95 79 | use_nucleus_sampling: true 80 | im_size: 384 81 | dpi: 2 82 | scale: 300 83 | metrics: 84 | L2: true 85 | Masked-L2: false 86 | LPIPS: true 87 | SSIM: true 88 | FID: false 89 | FID_clip: false 90 | CLIPScore: false 91 | CountTokenLength: true 92 | ratio_post_processed: false 93 | ratio_non_compiling: false 94 | DinoScore: true -------------------------------------------------------------------------------- /configs/models/starvector-8b/im2svg-fonts-simple.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-im2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.fontsvg.FontSVGDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/svg-fonts-simple 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | image_processor: siglip_384 60 | test: 61 | batch_size: 4 62 | target: starvector.data.fontsvg.FontSVGDataset 63 | params: 64 | split: test 65 | dataset_name: starvector/svg-fonts-simple 66 | im_size: 384 67 | num_samples: 128 68 | transforms: false 69 | image_processor: siglip_384 70 | generation: 71 | max_length: 10000 72 | min_length: 10 73 | num_beams: 3 74 | temperature: 1.0 75 | num_captions: 1 76 | repetition_penalty: 1.0 77 | length_penalty: 0.5 78 | top_p: 0.95 79 | use_nucleus_sampling: true 80 | im_size: 384 81 | dpi: 2 82 | scale: 300 83 | metrics: 84 | L2: true 85 | Masked-L2: false 86 | LPIPS: true 87 | SSIM: true 88 | FID: false 89 | FID_clip: false 90 | CLIPScore: false 91 | CountTokenLength: true 92 | ratio_post_processed: false 93 | ratio_non_compiling: false 94 | DinoScore: true -------------------------------------------------------------------------------- /configs/models/starvector-8b/im2svg-fonts.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-im2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: 
starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.fontsvg.FontSVGDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/svg-fonts 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | image_processor: siglip_384 60 | test: 61 | batch_size: 4 62 | target: starvector.data.fontsvg.FontSVGDataset 63 | params: 64 | split: test 65 | dataset_name: starvector/svg-fonts 66 | im_size: 384 67 | num_samples: 128 68 | transforms: false 69 | image_processor: siglip_384 70 | generation: 71 | max_length: 10000 72 | min_length: 10 73 | num_beams: 3 74 | temperature: 1.0 75 | num_captions: 1 76 | repetition_penalty: 1.0 77 | length_penalty: 0.5 78 | top_p: 0.95 79 | use_nucleus_sampling: true 80 | im_size: 384 81 | dpi: 2 82 | scale: 300 83 | metrics: 84 | L2: true 85 | Masked-L2: false 86 | LPIPS: true 87 | SSIM: true 88 | FID: false 89 | FID_clip: false 90 | CLIPScore: false 91 | CountTokenLength: true 92 | ratio_post_processed: false 93 | ratio_non_compiling: false 94 | DinoScore: true -------------------------------------------------------------------------------- /configs/models/starvector-8b/im2svg-icons.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-im2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this 
reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.iconsvg.SVGIconsDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/svg-icons 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | image_processor: siglip_384 60 | test: 61 | batch_size: 1 62 | target: starvector.data.iconsvg.SVGIconsDataset 63 | params: 64 | split: test 65 | dataset_name: starvector/svg-icons 66 | im_size: 384 67 | num_samples: 128 68 | transforms: false 69 | image_processor: siglip_384 70 | generation: 71 | max_length: 10000 72 | min_length: 10 73 | num_beams: 3 74 | temperature: 1.0 75 | num_captions: 1 76 | repetition_penalty: 1.0 77 | length_penalty: 0.5 78 | top_p: 0.95 79 | use_nucleus_sampling: true 80 | im_size: 384 81 | dpi: 2 82 | scale: 300 83 | metrics: 84 | L2: true 85 | Masked-L2: false 86 | LPIPS: true 87 | SSIM: true 88 | FID: false 89 | FID_clip: false 90 | CLIPScore: false 91 | CountTokenLength: true 92 | ratio_post_processed: false 93 | ratio_non_compiling: false 94 | DinoScore: true -------------------------------------------------------------------------------- /configs/models/starvector-8b/im2svg-stack.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-im2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: null 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: im2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 1 52 | target: starvector.data.stacksvg.SVGStackDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/svg-stack 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | select_dataset_name: false 60 | image_processor: siglip_384 61 | test: 62 | batch_size: 2 63 | target: starvector.data.stacksvg.SVGStackDataset 64 | params: 65 | split: test 66 | dataset_name: starvector/svg-stack 67 | im_size: 384 68 | num_samples: 64 69 | transforms: false 70 | select_dataset_name: false 71 | image_processor: siglip_384 72 | generation: 73 | max_length: 10000 74 | min_length: 10 75 | 
num_beams: 3 76 | temperature: 1.0 77 | num_captions: 1 78 | repetition_penalty: 1.0 79 | length_penalty: 0.5 80 | top_p: 0.95 81 | use_nucleus_sampling: true 82 | im_size: 384 83 | dpi: 2 84 | scale: 300 85 | metrics: 86 | L2: true 87 | Masked-L2: false 88 | LPIPS: true 89 | SSIM: true 90 | FID: false 91 | FID_clip: false 92 | CLIPScore: false 93 | CountTokenLength: true 94 | ratio_post_processed: false 95 | ratio_non_compiling: false 96 | DinoScore: true -------------------------------------------------------------------------------- /configs/models/starvector-8b/text2svg-figr.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-text2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: text2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 10 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 4 52 | target: starvector.data.stacksvg.SVGStackDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/FIGR-SVG 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | select_dataset_name: false 60 | image_processor: siglip_384 61 | test: 62 | batch_size: 4 63 | target: starvector.data.stacksvg.SVGStackDataset 64 | params: 65 | split: test 66 | dataset_name: starvector/FIGR-SVG 67 | im_size: 384 68 | num_samples: 64 69 | transforms: false 70 | select_dataset_name: false 71 | image_processor: siglip_384 72 | generation: 73 | max_length: 10000 74 | min_length: 10 75 | num_beams: 3 76 | temperature: 1.0 77 | num_captions: 1 78 | repetition_penalty: 1.0 79 | length_penalty: 0.5 80 | top_p: 0.95 81 | use_nucleus_sampling: true 82 | im_size: 384 83 | dpi: 2 84 | scale: 300 85 | metrics: 86 | L2: false 87 | Masked-L2: false 88 | LPIPS: false 89 | SSIM: false 90 | FID: false 91 | FID_clip: false 92 | CLIPScore: true 93 | CountTokenLength: true 94 | ratio_post_processed: false 95 | ratio_non_compiling: false 96 | DinoScore: false -------------------------------------------------------------------------------- /configs/models/starvector-8b/text2svg-stack.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | project: starvector-8b-text2svg 3 | use_wandb: false 4 | entity: joanrod 5 | copy_code: false 6 | model: 7 | max_length: 16000 8 | model_name: 
starvector/starvector-8b-im2svg 9 | starcoder_model_name: bigcode/starcoder2-7b 10 | pretrained: true 11 | image_encoder_type: siglip_384 12 | use_flash_attn: true 13 | adapter_norm: layer_norm 14 | init_type: normal 15 | dropout: 0.1 16 | task: text2svg 17 | transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific 18 | use_cache: false 19 | training: 20 | save_model_epochs: 1 21 | checkpointing_steps: 500 22 | checkpoints_total_limit: 5 23 | model_precision: bf16 24 | resume_from_checkpoint: false 25 | continue_training: false 26 | n_epochs: 4 27 | lr: 0.00001 28 | gradient_accumulation_steps: 4 29 | lr_scheduler: cosine 30 | lr_warmup_steps: 10 31 | adam_beta1: 0.95 32 | adam_beta2: 0.999 33 | adam_weight_decay: 1.0e-06 34 | adam_epsilon: 1e-08 35 | optimizer: adamw 36 | use_gradient_checkpointing: true 37 | train_image_encoder: true 38 | train_LLM: true 39 | fsdp: 40 | enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 41 | cpu_offload: false 42 | sharding_strategy: hsdp 43 | backward_prefetch: BACKWARD_PRE 44 | use_orig_params: true 45 | sync_module_states: true 46 | forward_prefetch: false 47 | cpu_ram_efficient_loading: true 48 | data: 49 | num_workers: 16 50 | train: 51 | batch_size: 4 52 | target: starvector.data.stacksvg.SVGStackDataset 53 | params: 54 | split: train 55 | dataset_name: starvector/text2svg-stack 56 | im_size: 384 57 | num_samples: -1 58 | transforms: false 59 | select_dataset_name: false 60 | image_processor: siglip_384 61 | test: 62 | batch_size: 4 63 | target: starvector.data.stacksvg.SVGStackDataset 64 | params: 65 | split: test 66 | dataset_name: starvector/text2svg-stack 67 | im_size: 384 68 | num_samples: 64 69 | transforms: false 70 | select_dataset_name: false 71 | image_processor: siglip_384 72 | generation: 73 | max_length: 10000 74 | min_length: 10 75 | num_beams: 3 76 | temperature: 1.0 77 | num_captions: 1 78 | repetition_penalty: 1.0 79 | length_penalty: 0.5 80 | top_p: 0.95 81 | use_nucleus_sampling: true 82 | im_size: 384 83 | dpi: 2 84 | scale: 300 85 | metrics: 86 | L2: false 87 | Masked-L2: false 88 | LPIPS: false 89 | SSIM: false 90 | FID: false 91 | FID_clip: false 92 | CLIPScore: true 93 | CountTokenLength: true 94 | ratio_post_processed: false 95 | ratio_non_compiling: false 96 | DinoScore: false -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel 2 | 3 | # Install necessary dependencies 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | wget \ 6 | git \ 7 | vim \ 8 | build-essential \ 9 | libcairo2 \ 10 | cuda-compiler-12-4 \ 11 | libaio-dev \ 12 | && rm -rf /var/lib/apt/lists/* 13 | 14 | 15 | # Install Jupyter Notebook, clone and install star-vector and dependencies 16 | RUN pip install --upgrade pip && pip install jupyter deepspeed \ 17 | && git clone https://github.com/joanrod/star-vector.git /tmp/star-vector \ 18 | && pip install /tmp/star-vector \ 19 | && rm -rf /tmp/star-vector 20 | 21 | # Cleanup unneeded packages 22 | RUN apt-get purge -y --auto-remove \ 23 | git \ 24 | build-essential \ 25 | && apt-get clean \ 26 | && rm -rf /var/lib/apt/lists/* 27 | 28 | # Expose Jupyter Notebook port 29 | EXPOSE 8888 30 | 31 | WORKDIR /workspace 32 | 33 | # Launch Jupyter Notebook 34 | CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root"] 35 | 
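# Optional sanity check (added note, not part of the original Dockerfile): after building
# the image with the tag used in docker/README.md (`starvector:latest`), GPU visibility
# inside the container can be verified with:
#   docker run --rm --gpus all starvector:latest \
#     python -c "import torch; print(torch.cuda.is_available())"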
-------------------------------------------------------------------------------- /docker/Image2SVG Generation Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "659c25fd-2a10-4010-ae4c-bf6c42c59a6d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from PIL import Image\n", 11 | "from starvector.model.starvector_arch import StarVectorForCausalLM\n", 12 | "from starvector.data.util import process_and_rasterize_svg\n", 13 | "import torch\n", 14 | "\n", 15 | "model_name = \"starvector/starvector-1b-im2svg\"\n", 16 | "#model_name = \"starvector/starvector-8b-im2svg\"\n", 17 | "\n", 18 | "starvector = StarVectorForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)\n", 19 | "\n", 20 | "starvector.cuda()\n", 21 | "starvector.eval()\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "0a3ae9d8-7e7c-4eba-b830-4f45d68f7b5f", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "image_pil = Image.open('../assets/examples/sample-18.png')\n", 32 | "\n", 33 | "image = starvector.process_images([image_pil])[0].cuda()\n", 34 | "batch = {\"image\": image}\n", 35 | "\n", 36 | "raw_svg = starvector.generate_im2svg(batch, max_length=5000)[0]\n", 37 | "svg, raster_image = process_and_rasterize_svg(raw_svg)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "5a100757-c512-4e4b-a05d-f5a74fbf9133", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "from IPython.display import SVG, display\n", 48 | "\n", 49 | "display(image_pil)\n", 50 | "display(SVG(svg))\n", 51 | "display(raster_image)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "b13dd1af-d579-452a-96f8-f0fae5289d24", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "del starvector\n", 62 | "torch.cuda.empty_cache()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "6cf1d4d9-2798-4b67-9942-7b3a5cf03122", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [] 72 | } 73 | ], 74 | "metadata": { 75 | "kernelspec": { 76 | "display_name": "Python 3 (ipykernel)", 77 | "language": "python", 78 | "name": "python3" 79 | }, 80 | "language_info": { 81 | "codemirror_mode": { 82 | "name": "ipython", 83 | "version": 3 84 | }, 85 | "file_extension": ".py", 86 | "mimetype": "text/x-python", 87 | "name": "python", 88 | "nbconvert_exporter": "python", 89 | "pygments_lexer": "ipython3", 90 | "version": "3.11.10" 91 | } 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 5 95 | } 96 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ## 🐳 Docker Setup for StarVector 2 | 3 | To simplify the setup process and avoid dependency issues, you can use a Docker container to run **StarVector** in a self-contained environment. The Dockerfile and this guide are located in the `docker/` subfolder of the repository. 4 | 5 | --- 6 | 7 | ### 🛠️ Build the Docker Image 8 | 9 | 1. **Clone the Repository**: 10 | 11 | ```bash 12 | git clone https://github.com/joanrod/star-vector.git 13 | cd star-vector/docker 14 | ``` 15 | 16 | 2. **Build the Image**: 17 | 18 | ```bash 19 | docker build -t starvector:latest . 
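# Note (added): this build installs StarVector from the GitHub clone made inside the
# Dockerfile, so local source edits are not baked into the image; they are only visible
# through the /workspace mount used in the run command below.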
20 | ``` 21 | 22 | --- 23 | 24 | ### 🚀 Run the Docker Container 25 | 26 | ```bash 27 | docker run -it \ 28 | --gpus all \ 29 | -p 8888:8888 \ 30 | -v $(pwd)/..:/workspace \ 31 | -v ~/.cache/huggingface:/root/.cache/huggingface \ 32 | --env HUGGING_FACE_HUB_TOKEN= \ 33 | --name starvector \ 34 | starvector:latest 35 | ``` 36 | 37 | **Options explained:** 38 | 39 | - `--gpus all`: Enables GPU acceleration (requires [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) 40 | - `-p 8888:8888`: Exposes Jupyter Notebook to your host machine 41 | - `-v $(pwd)/..:/workspace`: Mounts the project root directory into the container 42 | - `-v ~/.cache/huggingface:/root/.cache/huggingface`: Shares Hugging Face cache for offline usage and faster loading 43 | - `--env HUGGING_FACE_HUB_TOKEN=...`: Sets an environment variable with your token for accessing gated models 44 | - `--name starvector`: Assigns a name to the container 45 | 46 | --- 47 | 48 | ### 🔑 Hugging Face Token 49 | 50 | **StarVector models depend on** the Hugging Face model `bigcode/starcoderbase-1b`, which is a gated model. 51 | 52 | To access it, you need a **Hugging Face token** with the following permission: 53 | 54 | > ✅ Read access to contents of all public gated repos you can access 55 | 56 | You can generate it here: https://huggingface.co/settings/tokens 57 | 58 | --- 59 | 60 | ### 🌐 Access Jupyter Notebook 61 | 62 | After the container starts, you'll see a URL like: 63 | 64 | ``` 65 | http://127.0.0.1:8888/?token= 66 | ``` 67 | 68 | Copy and open it in your browser to start using **StarVector** in a notebook environment. 69 | 70 | --- -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "starvector" 7 | version = "1.0" 8 | description = "Generating Scalable Vector Graphics Code from Images and Text" 9 | readme = "README.md" 10 | requires-python = ">=3.11" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.5.1", 17 | "torchvision==0.20.1", 18 | "transformers==4.49.0", 19 | "tokenizers==0.21.1", 20 | "sentencepiece==0.2.0", 21 | "accelerate", 22 | "pydantic==2.10", 23 | "markdown2[all]", 24 | "numpy<2.0.0", 25 | "scikit-learn==1.2.2", 26 | "gradio==3.36.1", 27 | "gradio_client==0.2.9", 28 | "requests", 29 | "httpx==0.24.0", 30 | "uvicorn", 31 | "fastapi", 32 | "svgpathtools==1.6.1", 33 | "seaborn==0.12.2", 34 | "taming-transformers", 35 | "lpips", 36 | "cairosvg", 37 | "beautifulsoup4", 38 | "webcolors", 39 | "tqdm", 40 | "omegaconf", 41 | "open-clip-torch", 42 | "noise", 43 | "datasets", 44 | "scikit-image", 45 | "fairscale", 46 | "lxml", 47 | "torch-fidelity", 48 | "clip-openai", 49 | "scipy==1.11.1", 50 | "sentence-transformers", 51 | "reportlab", 52 | "svglib", 53 | "Pillow", 54 | "protobuf", 55 | "openai", 56 | "flash_attn==2.7.3" 57 | 58 | ] 59 | 60 | [project.optional-dependencies] 61 | train = ["deepspeed", "ninja", "wandb"] 62 | 63 | [project.urls] 64 | "Homepage" = "https://starvector.github.io" 65 | "Bug Tracker" = "https://github.com/joanrod/starvector/issues" 66 | 67 | [tool.setuptools.packages.find] 68 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 69 | 
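# Installation note (added comment): the optional "train" extra defined above (deepspeed,
# ninja, wandb) is not pulled in by a plain `pip install .`; from a source checkout it can
# be installed with, e.g., `pip install -e ".[train]"`.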
70 | [tool.wheel] 71 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 72 | -------------------------------------------------------------------------------- /scripts/quickstart-hf.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor 3 | from starvector.data.util import process_and_rasterize_svg 4 | import torch 5 | 6 | # model_name = "starvector/starvector-1b-im2svg" 7 | model_name = "starvector/starvector-8b-im2svg" 8 | 9 | starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True) 10 | processor = starvector.model.processor 11 | tokenizer = starvector.model.svg_transformer.tokenizer 12 | 13 | starvector.cuda() 14 | starvector.eval() 15 | 16 | image_pil = Image.open('assets/examples/sample-18.png') 17 | 18 | image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda() 19 | if not image.shape[0] == 1: 20 | image = image.squeeze(0) 21 | batch = {"image": image} 22 | 23 | raw_svg = starvector.generate_im2svg(batch, max_length=100)[0] 24 | svg, raster_image = process_and_rasterize_svg(raw_svg) 25 | -------------------------------------------------------------------------------- /scripts/quickstart-vllm.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from vllm import LLM, SamplingParams 3 | 4 | model_name = "starvector/starvector-1b-im2svg" 5 | # model_name = "starvector/starvector-8b-im2svg" 6 | 7 | sampling_params = SamplingParams( 8 | temperature=0.8, 9 | top_p=0.95, 10 | max_tokens=7900, 11 | n=1, 12 | frequency_penalty=0.0, 13 | repetition_penalty=1.0, 14 | top_k=-1, 15 | min_p=0.0, 16 | ) 17 | llm = LLM(model=model_name, trust_remote_code=True, max_model_len=8192) 18 | 19 | prompt_start = "" 20 | images = [Image.open('assets/examples/sample-18.png')] 21 | model_inputs_vllm = [] 22 | for i in range(len(images)): 23 | model_inputs_vllm.append({ 24 | "prompt": prompt_start, 25 | "multi_modal_data": {"image": images[i]} 26 | }) 27 | 28 | outputs = llm.generate(model_inputs_vllm, 29 | sampling_params=sampling_params, 30 | use_tqdm=False) 31 | 32 | completions = [] 33 | for i in range(len(outputs)): 34 | for j in range(len(outputs[i].outputs)): 35 | completions.append(outputs[i].outputs[j].text) 36 | 37 | print(completions) 38 | -------------------------------------------------------------------------------- /scripts/quickstart.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from starvector.model.starvector_arch import StarVectorForCausalLM 3 | from starvector.data.util import process_and_rasterize_svg 4 | import torch 5 | 6 | model_name = "starvector/starvector-1b-im2svg" 7 | # model_name = "starvector/starvector-8b-im2svg" 8 | 9 | starvector = StarVectorForCausalLM.from_pretrained(model_name, torch_dtype="auto") # add , torch_dtype="bfloat16" 10 | 11 | starvector.cuda() 12 | starvector.eval() 13 | 14 | image_pil = Image.open("assets/examples/sample-18.png") 15 | image_pil = image_pil.convert('RGB') 16 | image = starvector.process_images([image_pil])[0].to(torch.float16).cuda() 17 | batch = {"image": image} 18 | 19 | raw_svg = starvector.generate_im2svg(batch, max_length=4000, temperature=1.5, length_penalty=-1, repetition_penalty=3.1)[0] 20 | svg, raster_image = process_and_rasterize_svg(raw_svg) 21 | 
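# Optional follow-up (added sketch): assuming `process_and_rasterize_svg` returns the
# cleaned SVG string and a PIL image (as it is used in the notebook example above), both
# outputs can be saved for inspection. The filenames below are illustrative.
with open("sample-18-generated.svg", "w") as f:
    f.write(svg)
raster_image.save("sample-18-generated.png")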
-------------------------------------------------------------------------------- /scripts/train/train-starvector-1b-im2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME= 4 | export HF_TOKEN= 5 | export WANDB_API_KEY= 6 | export OUTPUT_DIR= 7 | 8 | accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml \ 9 | starvector/train/train.py \ 10 | config=configs/models/starvector-1b/im2svg-stack.yaml 11 | -------------------------------------------------------------------------------- /scripts/train/train-starvector-1b-text2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME= 4 | export HF_TOKEN= 5 | export WANDB_API_KEY= 6 | export OUTPUT_DIR= 7 | 8 | accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml \ 9 | starvector/train/train.py \ 10 | config=configs/models/starvector-1b/text2svg-stack.yaml 11 | -------------------------------------------------------------------------------- /scripts/train/train-starvector-8b-im2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME= 4 | export HF_TOKEN= 5 | export WANDB_API_KEY= 6 | export OUTPUT_DIR= 7 | 8 | torchrun \ 9 | --nproc-per-node=2 \ 10 | --nnodes=1 \ 11 | starvector/train/train.py \ 12 | config=configs/models/starvector-8b/im2svg-stack.yaml 13 | -------------------------------------------------------------------------------- /scripts/train/train-starvector-8b-text2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME= 4 | export HF_TOKEN= 5 | export WANDB_API_KEY= 6 | export OUTPUT_DIR= 7 | 8 | torchrun \ 9 | --nproc-per-node=2 \ 10 | --nnodes=1 \ 11 | starvector/train/train.py \ 12 | config=configs/models/starvector-8b/text2svg-stack.yaml 13 | -------------------------------------------------------------------------------- /scripts/validation/validate-starvector-1b-im2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python starvector/validation/run_validator.py \ 4 | config=configs/generation/starvector-1b/im2svg.yaml \ 5 | dataset.name svg-stack \ 6 | model.generation_engine=hf 7 | 8 | python starvector/validation/run_validator.py \ 9 | config=configs/generation/starvector-1b/im2svg.yaml \ 10 | dataset.name svg-emoji \ 11 | model.generation_engine=hf 12 | 13 | python starvector/validation/run_validator.py \ 14 | config=configs/generation/starvector-1b/im2svg.yaml \ 15 | dataset.name svg-fonts \ 16 | model.generation_engine=hf 17 | 18 | python starvector/validation/run_validator.py \ 19 | config=configs/generation/starvector-1b/im2svg.yaml \ 20 | dataset.name svg-diagrams \ 21 | model.generation_engine=hf 22 | 23 | python starvector/validation/run_validator.py \ 24 | config=configs/generation/starvector-1b/im2svg.yaml \ 25 | dataset.name svg-icons \ 26 | model.generation_engine=hf 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts/validation/validate-starvector-8b-im2svg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python starvector/validation/run_validator.py \ 4 | config=configs/generation/starvector-8b/im2svg.yaml \ 5 | dataset.name svg-stack \ 6 | model.generation_engine=hf 7 | 8 | python 
starvector/validation/run_validator.py \ 9 | config=configs/generation/starvector-8b/im2svg.yaml \ 10 | dataset.name svg-emoji \ 11 | model.generation_engine=hf 12 | 13 | python starvector/validation/run_validator.py \ 14 | config=configs/generation/starvector-8b/im2svg.yaml \ 15 | dataset.name svg-fonts \ 16 | model.generation_engine=hf 17 | 18 | python starvector/validation/run_validator.py \ 19 | config=configs/generation/starvector-8b/im2svg.yaml \ 20 | dataset.name svg-diagrams \ 21 | model.generation_engine=hf 22 | 23 | python starvector/validation/run_validator.py \ 24 | config=configs/generation/starvector-8b/im2svg.yaml \ 25 | dataset.name svg-icons \ 26 | model.generation_engine=hf 27 | 28 | 29 | -------------------------------------------------------------------------------- /starvector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/starvector/__init__.py -------------------------------------------------------------------------------- /starvector/adapter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch 4 | 5 | class Swish(nn.Module): 6 | def __init__(self): 7 | super(Swish, self).__init__() 8 | 9 | def forward(self, x): 10 | return x * torch.sigmoid(x) 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): 14 | super().__init__() 15 | self.query_length = query_length 16 | self.dropout_prob = dropout_prob 17 | self.adapter_norm = adapter_norm 18 | 19 | self.dropout = nn.Dropout(p=self.dropout_prob) 20 | 21 | self.c_fc = nn.Linear(input_size, input_size*2) 22 | self.act = Swish() 23 | self.c_proj = nn.Linear(input_size*2, output_size) 24 | 25 | if adapter_norm == "layer_norm": 26 | self.norm = nn.LayerNorm([self.query_length, output_size]) 27 | elif adapter_norm == "batch_norm": 28 | self.norm = nn.BatchNorm1d(self.query_length) 29 | 30 | self.init_type = init_type.lower() 31 | self._initialize_weights() 32 | 33 | def forward(self, hidden_states): 34 | hidden_states = self.dropout(hidden_states) 35 | hidden_states = self.c_fc(hidden_states) 36 | hidden_states = self.act(hidden_states) 37 | hidden_states = self.c_proj(hidden_states) 38 | hidden_states = self.norm(hidden_states) 39 | return hidden_states 40 | 41 | def _initialize_weights(self): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear): 44 | if self.init_type == "glorot": 45 | init.xavier_uniform_(m.weight) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | elif self.init_type == "normal": 49 | init.normal_(m.weight, mean=0, std=0.01) 50 | if m.bias is not None: 51 | init.constant_(m.bias, 0) 52 | else: 53 | raise ValueError("Invalid initialization type specified.") 54 | -------------------------------------------------------------------------------- /starvector/clip_model.py: -------------------------------------------------------------------------------- 1 | # Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py 2 | 3 | from collections import OrderedDict 4 | from itertools import repeat 5 | import collections.abc 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 11 | 12 | def 
convert_weights_to_precision(model: nn.Module, precision: torch.dtype): 13 | """Convert applicable model parameters to the specified precision""" 14 | 15 | def _convert_weights_to_precision(l): 16 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 17 | l.weight.data = l.weight.data.to(precision) 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.to(precision) 20 | 21 | elif isinstance(l, (nn.MultiheadAttention)): 22 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 23 | tensor = getattr(l, attr) 24 | if tensor is not None: 25 | tensor.data = tensor.data.to(precision) 26 | else: 27 | for _, p in l.named_parameters(): 28 | p.data = p.data.to(precision) 29 | 30 | model.apply(_convert_weights_to_precision) 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, inplanes, planes, stride=1): 36 | super().__init__() 37 | 38 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 39 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.relu1 = nn.ReLU(inplace=True) 42 | 43 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.relu2 = nn.ReLU(inplace=True) 46 | 47 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 48 | 49 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | self.relu3 = nn.ReLU(inplace=True) 52 | 53 | self.downsample = None 54 | self.stride = stride 55 | 56 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 57 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 58 | self.downsample = nn.Sequential(OrderedDict([ 59 | ("-1", nn.AvgPool2d(stride)), 60 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 61 | ("1", nn.BatchNorm2d(planes * self.expansion)) 62 | ])) 63 | 64 | def forward(self, x: torch.Tensor): 65 | identity = x 66 | 67 | out = self.relu1(self.bn1(self.conv1(x))) 68 | out = self.relu2(self.bn2(self.conv2(out))) 69 | out = self.avgpool(out) 70 | out = self.bn3(self.conv3(out)) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu3(out) 77 | return out 78 | 79 | 80 | class AttentionPool2d(nn.Module): 81 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 82 | super().__init__() 83 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 84 | self.k_proj = nn.Linear(embed_dim, embed_dim) 85 | self.q_proj = nn.Linear(embed_dim, embed_dim) 86 | self.v_proj = nn.Linear(embed_dim, embed_dim) 87 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 88 | self.num_heads = num_heads 89 | 90 | def forward(self, x): 91 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 92 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 93 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 94 | x, _ = F.multi_head_attention_forward( 95 | query=x, key=x, value=x, 96 | embed_dim_to_check=x.shape[-1], 97 | num_heads=self.num_heads, 98 | q_proj_weight=self.q_proj.weight, 99 | k_proj_weight=self.k_proj.weight, 100 | v_proj_weight=self.v_proj.weight, 101 | in_proj_weight=None, 102 | 
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 103 | bias_k=None, 104 | bias_v=None, 105 | add_zero_attn=False, 106 | dropout_p=0, 107 | out_proj_weight=self.c_proj.weight, 108 | out_proj_bias=self.c_proj.bias, 109 | use_separate_proj_weight=True, 110 | training=self.training, 111 | need_weights=False 112 | ) 113 | 114 | return x[0] 115 | 116 | 117 | class LayerNorm(nn.LayerNorm): 118 | """Subclass torch's LayerNorm to handle fp16.""" 119 | 120 | def forward(self, x: torch.Tensor): 121 | orig_type = x.dtype 122 | layernorm_dtype = self.weight.dtype 123 | ret = super().forward(x.type(layernorm_dtype)) 124 | return ret.type(orig_type) 125 | 126 | class QuickGELU(nn.Module): 127 | def forward(self, x: torch.Tensor): 128 | return x * torch.sigmoid(1.702 * x) 129 | 130 | class ResidualAttentionBlock(nn.Module): 131 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 132 | super().__init__() 133 | 134 | self.attn = nn.MultiheadAttention(d_model, n_head) 135 | self.ln_1 = LayerNorm(d_model) 136 | self.mlp = nn.Sequential(OrderedDict([ 137 | ("c_fc", nn.Linear(d_model, d_model * 4)), 138 | ("gelu", QuickGELU()), 139 | ("c_proj", nn.Linear(d_model * 4, d_model)) 140 | ])) 141 | self.ln_2 = LayerNorm(d_model) 142 | self.attn_mask = attn_mask 143 | 144 | if use_grad_checkpointing: 145 | self.attn = checkpoint_wrapper(self.attn) 146 | self.mlp = checkpoint_wrapper(self.mlp) 147 | 148 | def attention(self, x: torch.Tensor): 149 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 150 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 151 | 152 | def forward(self, x: torch.Tensor): 153 | x = x + self.attention(self.ln_1(x)) 154 | x = x + self.mlp(self.ln_2(x)) 155 | return x 156 | 157 | class Transformer(nn.Module): 158 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 159 | super().__init__() 160 | self.width = width 161 | self.layers = layers 162 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) 163 | 164 | def forward(self, x: torch.Tensor): 165 | return self.resblocks(x) 166 | 167 | class VisionTransformer(nn.Module): 168 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): 169 | super().__init__() 170 | self.input_resolution = input_resolution 171 | self.num_features = width 172 | self.num_heads = heads 173 | self.num_patches = (input_resolution // patch_size) ** 2 174 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 175 | scale = width ** -0.5 176 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 177 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) 178 | self.ln_pre = LayerNorm(width) 179 | self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) 180 | 181 | def forward(self, x: torch.Tensor): 182 | x = self.conv1(x) # shape = [*, width, grid, grid] 183 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 184 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 185 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # 
shape = [*, grid ** 2 + 1, width] 186 | x = x + self.positional_embedding.to(x.dtype) 187 | x = self.ln_pre(x) 188 | x = x.permute(1, 0, 2) # NLD -> LND 189 | x = self.transformer(x) 190 | x = x.permute(1, 0, 2) # LND -> NLD 191 | return x 192 | -------------------------------------------------------------------------------- /starvector/data/augmentation.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from svgpathtools import ( 4 | Path, Arc, CubicBezier, QuadraticBezier, 5 | svgstr2paths) 6 | import os 7 | from noise import pnoise1 8 | import re 9 | import matplotlib.colors as mcolors 10 | from bs4 import BeautifulSoup 11 | from starvector.data.util import rasterize_svg 12 | 13 | class SVGTransforms: 14 | def __init__(self, transformations): 15 | self.transformations = transformations 16 | self.noise_std = self.transformations.get('noise_std', False) 17 | self.noise_type = self.transformations.get('noise_type', False) 18 | self.rotate = self.transformations.get('rotate', False) 19 | self.shift_re = self.transformations.get('shift_re', False) 20 | self.shift_im = self.transformations.get('shift_im', False) 21 | self.scale = self.transformations.get('scale', False) 22 | self.color_noise = self.transformations.get('color_noise', False) 23 | self.p = self.transformations.get('p', 0.5) 24 | self.color_change = self.transformations.get('color_change', False) 25 | self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000']) 26 | 27 | def sample_transformations(self): 28 | if self.rotate: 29 | a, b = self.rotate['from'], self.rotate['to'] 30 | rotation_angle = np.random.uniform(a, b) 31 | self.rotation_angle = rotation_angle 32 | 33 | if self.shift_re or self.shift_im: 34 | self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to']) 35 | self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to']) 36 | 37 | if self.scale: 38 | self.scale = np.random.uniform(self.scale['from'], self.scale['to']) 39 | 40 | if self.color_noise: 41 | self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to']) 42 | 43 | 44 | def paths2str(self, groupped_paths, svg_opening_tag=''): 45 | 46 | keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry'] 47 | all_groups_srt = '' 48 | for group, elements in groupped_paths.items(): 49 | group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', []) 50 | group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items()) 51 | path_strings = [] 52 | path_str = '' 53 | for path, attributes in paths_and_attributes: 54 | path_attr_str = '' 55 | d_str = path.d() 56 | 57 | for key, value in attributes.items(): 58 | if key not in keys_to_exclude: 59 | path_attr_str += f' {key}="{value}"' 60 | 61 | path_strings.append(f'') 62 | path_str = "\n".join(path_strings) 63 | if 'no_group'in group: 64 | group_str = path_str 65 | else: 66 | group_str = f'\n{path_str}\n\n' 67 | all_groups_srt += group_str 68 | svg = f'{svg_opening_tag}\n{all_groups_srt}' 69 | return svg 70 | 71 | def add_noise(self, seg): 72 | noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to']) 73 | if self.noise_type == 'gaussian': 74 | noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \ 75 | 1j * np.random.normal(loc=0.0, scale=noise_scale) 76 | elif self.noise_type == 'perlin': 77 | noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), 
octaves=2))*noise_scale 78 | 79 | if isinstance(seg, CubicBezier): 80 | seg.control1 = seg.control1 + noise_sample 81 | seg.control2 = seg.control2 + noise_sample 82 | elif isinstance(seg, QuadraticBezier): 83 | seg.control = seg.control + noise_sample 84 | elif isinstance(seg, Arc): 85 | seg.radius = seg.radius + noise_sample 86 | 87 | 88 | return seg 89 | 90 | def do_rotate(self, path, viewbox_width, viewbox_height): 91 | if self.rotate: 92 | new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2)) 93 | return new_path 94 | else: 95 | return path 96 | 97 | def do_shift(self, path): 98 | if self.shift_re or self.shift_im: 99 | return path.translated(complex(self.shift_real, self.shift_imag)) 100 | else: 101 | return path 102 | 103 | def do_scale(self, path): 104 | if self.scale: 105 | return path.scaled(self.scale) 106 | else: 107 | return path 108 | 109 | def add_color_noise(self, source_color): 110 | # Convert color to RGB 111 | if source_color.startswith("#"): 112 | base_color = mcolors.hex2color(source_color) 113 | else: 114 | base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF')) 115 | 116 | # Add noise to each RGB component 117 | noise = np.random.normal(0, self.color_noise_std, 3) 118 | noisy_color = np.clip(np.array(base_color) + noise, 0, 1) 119 | 120 | # Convert the RGB color back to hex 121 | hex_color = mcolors.rgb2hex(noisy_color) 122 | 123 | return hex_color 124 | 125 | def do_color_change(self, attr): 126 | if 'fill' in attr: 127 | if self.color_noise or self.color_change: 128 | fill_value = attr['fill'] 129 | if fill_value == 'none': 130 | new_fill_value = 'none' 131 | else: 132 | if self.color_noise: 133 | new_fill_value = self.add_color_noise(fill_value) 134 | elif self.color_change: 135 | new_fill_value = np.random.choice(self.colors) 136 | attr['fill'] = new_fill_value 137 | return attr 138 | 139 | def clean_attributes(self, attr): 140 | attr_out = {} 141 | if 'fill' in attr: 142 | attr_out = attr 143 | elif 'style' in attr: 144 | fill_values = re.findall('fill:[^;]+', attr['style']) 145 | if fill_values: 146 | fill_value = fill_values[0].replace('fill:', '').strip() 147 | attr_out['fill'] = fill_value 148 | else: 149 | attr_out = attr 150 | else: 151 | attr_out = attr 152 | 153 | return attr_out 154 | 155 | def get_viewbox_size(self, svg): 156 | # Try to extract viewBox attribute 157 | match = re.search(r'viewBox="([^"]+)"', svg) 158 | if match: 159 | viewbox = match.group(1) 160 | else: 161 | # If viewBox is not found, try to extract width and height attributes 162 | match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg) 163 | if match: 164 | width, height = match.groups() 165 | viewbox = f"0 0 {width} {height}" 166 | else: 167 | viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found 168 | 169 | viewbox = [float(x) for x in viewbox.split()] 170 | viewbox_width, viewbox_height = viewbox[2], viewbox[3] 171 | return viewbox_width, viewbox_height 172 | 173 | def augment(self, svg): 174 | if os.path.isfile(svg): 175 | # open svg file 176 | with open(svg, 'r') as f: 177 | svg = f.read() 178 | 179 | # Sample transformations for this sample 180 | self.sample_transformations() 181 | 182 | 183 | # Parse the SVG content 184 | soup = BeautifulSoup(svg, 'xml') 185 | 186 | # Get opening tag 187 | svg_opening_tag = re.findall(']+>', svg)[0] 188 | 189 | viewbox_width, viewbox_height = self.get_viewbox_size(svg) 190 | 191 | # Get all svg parents 192 | groups = soup.findAll() 193 | 194 | # 
Create the groups of paths based on their original tag 195 | grouped_paths = {} 196 | for i, g in enumerate(groups): 197 | if g.name == 'g': 198 | group_id = group_id = g.get('id') if g.get('id') else f'none_{i}' 199 | group_attrs = g.attrs 200 | 201 | elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs': 202 | continue 203 | 204 | else: 205 | group_id = f'no_group_{i}' 206 | group_attrs = {} 207 | 208 | group_svg_string = f'{svg_opening_tag}{str(g)}' 209 | try: 210 | paths, attributes = svgstr2paths(group_svg_string) 211 | except: 212 | return svg, rasterize_svg(svg) 213 | if not paths: 214 | continue 215 | 216 | paths_and_attributes = [] 217 | 218 | # Rotation, shift, scale, noise addition 219 | new_paths = [] 220 | new_attributes = [] 221 | for path, attribute in zip(paths, attributes): 222 | attr = self.clean_attributes(attribute) 223 | 224 | new_path = self.do_rotate(path, viewbox_width, viewbox_height) 225 | new_path = self.do_shift(new_path) 226 | new_path = self.do_scale(new_path) 227 | 228 | if self.noise_std: 229 | # Add noise to path to deform svg 230 | noisy_path = [] 231 | for seg in new_path: 232 | noisy_seg = self.add_noise(seg) 233 | noisy_path.append(noisy_seg) 234 | new_paths.append(Path(*noisy_path)) 235 | else: 236 | new_paths.append(new_path) 237 | 238 | # Color change 239 | attr = self.do_color_change(attr) 240 | paths_and_attributes.append((new_path, attr)) 241 | 242 | grouped_paths[group_id] = { 243 | 'paths': paths_and_attributes, 244 | 'attrs': group_attrs 245 | } 246 | 247 | svg = self.paths2str(grouped_paths, svg_opening_tag) 248 | image = rasterize_svg(svg) 249 | 250 | return svg, image 251 | -------------------------------------------------------------------------------- /starvector/data/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg 3 | from starvector.util import instantiate_from_config 4 | import numpy as np 5 | from datasets import load_dataset 6 | 7 | class SVGDatasetBase(Dataset): 8 | def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): 9 | self.split = split 10 | self.im_size = im_size 11 | 12 | transforms = kwargs.get('transforms', False) 13 | if transforms: 14 | self.transforms = instantiate_from_config(transforms) 15 | self.p = self.transforms.p 16 | else: 17 | self.transforms = None 18 | self.p = 0.0 19 | 20 | normalization = kwargs.get('normalize', False) 21 | if normalization: 22 | mean = tuple(normalization.get('mean', None)) 23 | std = tuple(normalization.get('std', None)) 24 | else: 25 | mean = None 26 | std = None 27 | 28 | self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) 29 | self.data = load_dataset(dataset_name, split=split) 30 | 31 | print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") 32 | 33 | def __len__(self): 34 | return len(self.data_json) 35 | 36 | def get_svg_and_image(self, svg_str, sample_id): 37 | do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) 38 | svg, image = None, None 39 | 40 | # Try to augment the image if conditions are met 41 | if self.transforms is not None and do_augment: 42 | try: 43 | svg, image = self.transforms.augment(svg_str) 44 | except Exception as e: 45 | print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") 46 | 47 | # If augmentation failed or wasn't attempted, try to rasterize the SVG 48 | if svg is None or 
image is None: 49 | try: 50 | svg, image = svg_str, rasterize_svg(svg_str, self.im_size) 51 | except Exception as e: 52 | print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") 53 | svg = use_placeholder() 54 | image = rasterize_svg(svg, self.im_size) 55 | 56 | # If the image is completely white, use a placeholder image 57 | if np.array(image).mean() == 255.0: 58 | print(f"Image is full white, using placeholder image for {sample_id}") 59 | svg = use_placeholder() 60 | image = rasterize_svg(svg) 61 | 62 | # Process the image 63 | if 'siglip' in self.image_processor: 64 | image = self.processor(image).pixel_values[0] 65 | else: 66 | image = self.processor(image) 67 | 68 | return svg, image 69 | 70 | def __getitem__(self, idx): 71 | raise NotImplementedError("This method should be implemented by subclasses") 72 | -------------------------------------------------------------------------------- /starvector/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | from starvector.data.augmentation import SVGTransforms 4 | from starvector.data.util import ImageTrainProcessor 5 | from transformers import AutoProcessor 6 | 7 | class SVGDataset(SVGDatasetBase): 8 | def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): 9 | super().__init__(dataset_name, split, im_size, num_samples, **kwargs) 10 | 11 | self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) 12 | select_dataset_name = kwargs.get('select_dataset_name', False) 13 | 14 | if select_dataset_name: 15 | self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) 16 | 17 | self.num_samples = num_samples 18 | if self.num_samples != -1: 19 | self.data = self.data.select(range(self.num_samples)) 20 | 21 | self.image_processor = kwargs.get('image_processor', None) 22 | if 'siglip' in self.image_processor: 23 | model_name = {'siglip_512': 'google/siglip-base-patch16-512', 24 | 'siglip_384': 'google/siglip-large-patch16-384', 25 | 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] 26 | self.processor = AutoProcessor.from_pretrained(model_name).image_processor 27 | else: 28 | self.processor = ImageTrainProcessor(size=self.im_size) 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | def __getitem__(self, idx): 33 | svg_str = self.data[idx]['Svg'] 34 | sample_id = self.data[idx]['Filename'] 35 | svg, image = self.get_svg_and_image(svg_str, sample_id) 36 | caption = self.data[idx].get('Caption', "") 37 | return { 38 | 'svg': svg, 39 | 'image': image, 40 | 'id': sample_id, 41 | 'caption': caption 42 | } -------------------------------------------------------------------------------- /starvector/data/emojisvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | 4 | 5 | class EmojiSVGDataset(SVGDatasetBase): 6 | def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): 7 | super().__init__(dataset_name, split, im_size, **kwargs) 8 | 9 | self.num_samples = num_samples 10 | if self.num_samples != -1: 11 | self.data = self.data.select(range(self.num_samples)) 12 | 13 | def __len__(self): 14 | return len(self.data) 15 | 16 | def __getitem__(self, idx): 17 | 18 | svg_str = self.data[idx]['Svg'] 19 | sample_id = self.data[idx]['Filename'] 20 | svg, image = 
self.get_svg_and_image(svg_str, sample_id) 21 | caption = self.data[idx].get('Caption', "") 22 | return { 23 | 'svg': svg, 24 | 'image': image, 25 | 'id': sample_id, 26 | 'caption': caption 27 | } -------------------------------------------------------------------------------- /starvector/data/figrsvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | from transformers import AutoProcessor 4 | from starvector.data.util import ImageTrainProcessor 5 | 6 | class FigrSVGDataset(SVGDatasetBase): 7 | def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): 8 | super().__init__(dataset_name, split, im_size, **kwargs) 9 | 10 | self.num_samples = num_samples 11 | if self.num_samples != -1: 12 | self.data = self.data.select(range(self.num_samples)) 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, idx): 18 | svg_str = self.data[idx]['Svg'] 19 | sample_id = self.data[idx]['Id'] 20 | svg, image = self.get_svg_and_image(svg_str, sample_id) 21 | caption = self.data[idx].get('Caption', "") 22 | return { 23 | 'svg': svg, 24 | 'image': image, 25 | 'id': sample_id, 26 | 'caption': caption 27 | } 28 | -------------------------------------------------------------------------------- /starvector/data/fontsvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | from transformers import AutoProcessor 4 | from starvector.data.util import ImageTrainProcessor 5 | 6 | class FontSVGDataset(SVGDatasetBase): 7 | def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): 8 | super().__init__(dataset_name, split, im_size, **kwargs) 9 | 10 | self.num_samples = num_samples 11 | if self.num_samples != -1: 12 | self.data = self.data.select(range(self.num_samples)) 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, idx): 18 | 19 | svg_str = self.data[idx]['Svg'] 20 | sample_id = self.data[idx]['Filename'] 21 | svg, image = self.get_svg_and_image(svg_str, sample_id) 22 | caption = self.data[idx].get('Caption', "") 23 | return { 24 | 'svg': svg, 25 | 'image': image, 26 | 'id': sample_id, 27 | 'caption': caption 28 | } 29 | -------------------------------------------------------------------------------- /starvector/data/iconsvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | from starvector.data.util import ImageTrainProcessor 4 | from transformers import AutoProcessor 5 | 6 | class SVGIconsDataset(SVGDatasetBase): 7 | def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): 8 | super().__init__(dataset_name, split, im_size, **kwargs) 9 | 10 | self.num_samples = num_samples 11 | if self.num_samples != -1: 12 | self.data = self.data.select(range(self.num_samples)) 13 | 14 | self.image_processor = kwargs.get('image_processor', None) 15 | if 'siglip' in self.image_processor: 16 | model_name = {'siglip_512': 'google/siglip-base-patch16-512', 17 | 'siglip_384': 'google/siglip-large-patch16-384', 18 | 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] 19 | self.processor = AutoProcessor.from_pretrained(model_name).image_processor 20 | else: 21 | self.processor = ImageTrainProcessor(size=self.im_size) 22 | 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def 
__getitem__(self, idx): 28 | 29 | svg_str = self.data[idx]['Svg'] 30 | sample_id = self.data[idx]['Filename'] 31 | svg, image = self.get_svg_and_image(svg_str, sample_id) 32 | caption = self.data[idx].get('Caption', "") 33 | return { 34 | 'svg': svg, 35 | 'image': image, 36 | 'id': sample_id, 37 | 'caption': caption 38 | } 39 | -------------------------------------------------------------------------------- /starvector/data/stacksvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starvector.data.base import SVGDatasetBase 3 | from starvector.data.augmentation import SVGTransforms 4 | import random 5 | from transformers import AutoProcessor 6 | from starvector.data.util import ImageTrainProcessor 7 | 8 | text2svg_captions = [ 9 | "Draw an SVG of ", 10 | "Draw an SVG image of ", 11 | "Draw an SVG picture of ", 12 | "Generate an SVG of ", 13 | "Create an SVG of ", 14 | "Design an SVG of ", 15 | "Make an SVG of ", 16 | ] 17 | 18 | class SVGStackDataset(SVGDatasetBase): 19 | def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): 20 | super().__init__(dataset_name, split, im_size, num_samples, **kwargs) 21 | self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) 22 | 23 | # Text2SVG specific 24 | self.random_caption = kwargs.get('random_caption', True) 25 | select_dataset_name = kwargs.get('select_dataset_name', False) 26 | if select_dataset_name: 27 | self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) 28 | 29 | self.num_samples = num_samples 30 | if self.num_samples != -1: 31 | self.data = self.data.select(range(self.num_samples)) 32 | 33 | self.image_processor = kwargs.get('image_processor', None) 34 | if self.image_processor and 'siglip' in self.image_processor: 35 | model_name = {'siglip_512': 'google/siglip-base-patch16-512', 36 | 'siglip_384': 'google/siglip-large-patch16-384', 37 | 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] 38 | self.processor = AutoProcessor.from_pretrained(model_name).image_processor 39 | else: 40 | self.processor = ImageTrainProcessor(size=self.im_size) 41 | 42 | 43 | def __len__(self): 44 | return len(self.data) 45 | 46 | def __getitem__(self, idx): 47 | svg_str = self.data[idx]['Svg'] 48 | sample_id = self.data[idx]['Filename'] 49 | svg, image = self.get_svg_and_image(svg_str, sample_id) 50 | 51 | # Randomly choose between 'caption_blip' and 'caption_llava' 52 | caption_column = random.choice(['caption_blip2', 'caption_llava']) 53 | caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "") 54 | return { 55 | 'svg': svg, 56 | 'image': image, 57 | 'id': sample_id, 58 | 'caption': caption, 59 | } 60 | -------------------------------------------------------------------------------- /starvector/image_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | from omegaconf import OmegaConf 6 | from starvector.model.image_encoder.clip_model import convert_weights_to_precision 7 | from starvector.data.util import ImageTrainProcessor 8 | 9 | class ImageEncoder(nn.Module): 10 | def __init__(self, config, **kwargs): 11 | super(ImageEncoder, self).__init__() 12 | 13 | image_size = config.image_size 14 | torch_dtype = kwargs.get('model_precision', config.torch_dtype) 15 | self.image_encoder_type = config.image_encoder_type 16 | if 
self.image_encoder_type == 'clip': 17 | self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) 18 | convert_weights_to_precision(self, torch_dtype) 19 | self.processor = ImageTrainProcessor(size=config.image_size) 20 | 21 | elif self.image_encoder_type == 'vqgan': 22 | self.visual_encoder = self.build_vqgan_encoder() 23 | self.ln_vision = None 24 | self.processor = ImageTrainProcessor(size=config.image_size) 25 | 26 | elif self.image_encoder_type == 'convnext': 27 | self.visual_encoder = self.build_vqgan_encoder() 28 | self.ln_vision = None 29 | self.processor = ImageTrainProcessor(size=config.image_size) 30 | 31 | elif 'siglip' in self.image_encoder_type: 32 | if self.image_encoder_type == 'siglip_512': 33 | model_name = "google/siglip-base-patch16-512" 34 | elif self.image_encoder_type == 'siglip_384': 35 | model_name = "google/siglip-large-patch16-384" 36 | elif self.image_encoder_type == 'siglip_256': 37 | model_name = "google/siglip-base-patch16-256" 38 | 39 | from transformers import AutoProcessor, AutoModel 40 | 41 | self.visual_encoder = AutoModel.from_pretrained( 42 | model_name, torch_dtype = torch_dtype 43 | ).vision_model 44 | 45 | self.processor = AutoProcessor.from_pretrained( 46 | model_name, torch_dtype = torch_dtype 47 | ) 48 | 49 | def build_clip_encoder(self, image_size): 50 | from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm 51 | visual_encoder = VisionTransformer( 52 | input_resolution=image_size, 53 | patch_size=14, 54 | width=1024, 55 | layers=23, 56 | heads=16, 57 | use_grad_checkpointing=False) 58 | 59 | ln_vision = LayerNorm(visual_encoder.num_features) 60 | return visual_encoder, ln_vision 61 | 62 | def build_vqgan_encoder(self): 63 | from taming.modules.diffusionmodules.model import Encoder 64 | VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md 65 | vqgan_chkp_path = VQGAN_CHECKPOINT 66 | files_in_directory = os.listdir(vqgan_chkp_path + '/configs') 67 | vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] 68 | vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) 69 | visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) 70 | 71 | # Load checkpoint weights 72 | checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] 73 | 74 | # Create a new state_dict with modified keys 75 | new_state_dict = {} 76 | for key, value in checkpoint.items(): 77 | if key.startswith('encoder.'): 78 | new_key = key[len('encoder.'):] 79 | new_state_dict[new_key] = value 80 | 81 | # Load weights 82 | visual_encoder.load_state_dict(new_state_dict) 83 | return visual_encoder 84 | 85 | def build_convnext_encoder(self): 86 | import open_clip 87 | model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') 88 | return model.visual 89 | 90 | def forward(self, image): 91 | if self.image_encoder_type == 'clip': 92 | embeds = self.visual_encoder(image) 93 | out = self.ln_vision(embeds) 94 | elif self.image_encoder_type == 'open-clip': 95 | out = self.visual_encoder(image)[1] 96 | out = self.ln_vision(out) 97 | elif self.image_encoder_type == 'vqgan': 98 | out = self.visual_encoder(image) 99 | size = out.size() 100 | out = out.view(size[0], size[1], -1) 101 | out = out.permute(0, 2, 1) 102 | elif self.image_encoder_type == 'convnext': 103 | out = 
self.visual_encoder.trunk.forward_features(image) 104 | size = out.size() 105 | out = out.view(size[0], size[1], -1) 106 | out = out.permute(0, 2, 1) 107 | elif 'siglip' in self.image_encoder_type: 108 | out = self.visual_encoder(image)["last_hidden_state"] 109 | return out 110 | 111 | def process_images(self, images): 112 | if self.image_encoder_type == 'clip': 113 | res = [] 114 | for image in images: 115 | res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W 116 | return res 117 | else: 118 | return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) 119 | -------------------------------------------------------------------------------- /starvector/metrics/base_metric.py: -------------------------------------------------------------------------------- 1 | from starvector.metrics.util import AverageMeter 2 | from tqdm import tqdm 3 | import math 4 | 5 | class BaseMetric: 6 | def __init__(self): 7 | self.meter = AverageMeter() 8 | 9 | def reset(self): 10 | self.meter.reset() 11 | 12 | def calculate_score(self, batch, update=True): 13 | """ 14 | Batch: {"gt_im": [PIL Image], "gen_im": [Image]} 15 | """ 16 | values = [] 17 | batch_size = len(next(iter(batch.values()))) 18 | for index in tqdm(range(batch_size)): 19 | kwargs = {} 20 | for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]: 21 | if key in batch: 22 | kwargs[key] = batch[key][index] 23 | try: 24 | measure = self.metric(**kwargs) 25 | except Exception as e: 26 | print("Error calculating metric: {}".format(e)) 27 | continue 28 | if math.isnan(measure): 29 | continue 30 | values.append(measure) 31 | 32 | if not values: 33 | print("No valid values found for metric calculation.") 34 | return float("nan") 35 | 36 | score = sum(values) / len(values) 37 | if update: 38 | self.meter.update(score, len(values)) 39 | return self.meter.avg, values 40 | else: 41 | return score, values 42 | 43 | def metric(self, **kwargs): 44 | """ 45 | This method should be overridden by subclasses to provide the specific metric computation. 
46 | """ 47 | raise NotImplementedError("The metric method must be implemented by subclasses.") 48 | 49 | def get_average_score(self): 50 | return self.meter.avg 51 | 52 | -------------------------------------------------------------------------------- /starvector/metrics/compute_LPIPS.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ToTensor, Normalize 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from starvector.metrics.base_metric import BaseMetric 5 | import lpips 6 | from tqdm import tqdm 7 | 8 | 9 | class LPIPSDistanceCalculator(BaseMetric): 10 | def __init__(self, config=None, device='cuda'): 11 | super().__init__() 12 | self.class_name = self.__class__.__name__ 13 | self.config = config 14 | self.model = lpips.LPIPS(net='vgg').to(device) 15 | self.metric = self.LPIPS 16 | self.to_tensor = ToTensor() 17 | self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 18 | self.device = device 19 | 20 | def LPIPS(self, tensor_image1, tensor_image2): 21 | tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device) 22 | return self.model(tensor_image1, tensor_image2) 23 | 24 | def to_tensor_transform(self, pil_img): 25 | return self.normalize(self.to_tensor(pil_img)) 26 | 27 | def collate_fn(self, batch): 28 | gt_imgs, gen_imgs = zip(*batch) 29 | tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs]) 30 | tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs]) 31 | return tensor_gt_imgs, tensor_gen_imgs 32 | 33 | def calculate_score(self, batch, batch_size=8, update=True): 34 | gt_images = batch['gt_im'] 35 | gen_images = batch['gen_im'] 36 | 37 | # Create DataLoader with custom collate function 38 | data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False) 39 | 40 | values = [] 41 | for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader): 42 | # Compute LPIPS 43 | lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch) 44 | values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist()) 45 | 46 | if not values: 47 | print("No valid values found for metric calculation.") 48 | return float("nan") 49 | 50 | avg_score = sum(values) / len(values) 51 | if update: 52 | self.meter.update(avg_score, len(values)) 53 | return self.meter.avg, values 54 | else: 55 | return avg_score, values 56 | -------------------------------------------------------------------------------- /starvector/metrics/compute_SSIM.py: -------------------------------------------------------------------------------- 1 | from starvector.metrics.base_metric import BaseMetric 2 | from skimage.metrics import structural_similarity as ssim 3 | import numpy as np 4 | 5 | class SSIMDistanceCalculator(BaseMetric): 6 | def __init__(self, config=None): 7 | super().__init__() 8 | self.class_name = self.__class__.__name__ 9 | self.config = config 10 | self.metric = self.compute_SSIM 11 | 12 | def compute_SSIM(self, **kwargs): 13 | image1 = kwargs.get('gt_im') 14 | image2 = kwargs.get('gen_im') 15 | win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy 16 | channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1 17 | sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter 18 | 19 | # Convert images to numpy arrays if they aren't 
already 20 | img1_np = np.array(image1) 21 | img2_np = np.array(image2) 22 | 23 | # Check if images are grayscale or RGB 24 | if len(img1_np.shape) == 3 and img1_np.shape[2] == 3: 25 | # Compute SSIM for RGB images 26 | score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True) 27 | else: 28 | # Convert to grayscale if not already 29 | if len(img1_np.shape) == 3: 30 | img1_np = np.mean(img1_np, axis=2) 31 | img2_np = np.mean(img2_np, axis=2) 32 | 33 | score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True) 34 | 35 | return score -------------------------------------------------------------------------------- /starvector/metrics/compute_clip_score.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ToTensor 2 | import torch.nn.functional as F 3 | from starvector.metrics.base_metric import BaseMetric 4 | import torch 5 | from torchmetrics.multimodal.clip_score import CLIPScore 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | import torchvision.transforms as transforms 9 | from torchmetrics.functional.multimodal.clip_score import _clip_score_update 10 | 11 | class CLIPScoreCalculator(BaseMetric): 12 | def __init__(self): 13 | super().__init__() 14 | self.class_name = self.__class__.__name__ 15 | self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32") 16 | self.clip_score.to('cuda') 17 | 18 | def CLIP_Score(self, images, captions): 19 | all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor) 20 | return all_scores 21 | 22 | def collate_fn(self, batch): 23 | gen_imgs, captions = zip(*batch) 24 | tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs] 25 | return tensor_gen_imgs, captions 26 | 27 | def calculate_score(self, batch, batch_size=512, update=True): 28 | gen_images = batch['gen_im'] 29 | captions = batch['caption'] 30 | 31 | # Create DataLoader with custom collate function 32 | data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 33 | 34 | all_scores = [] 35 | for batch_eval in tqdm(data_loader): 36 | images, captions = batch_eval 37 | images = [img.to('cuda', non_blocking=True) * 255 for img in images] 38 | list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist() 39 | all_scores.extend(list_scores) 40 | 41 | if not all_scores: 42 | print("No valid scores found for metric calculation.") 43 | return float("nan"), [] 44 | 45 | avg_score = sum(all_scores) / len(all_scores) 46 | if update: 47 | self.meter.update(avg_score, len(all_scores)) 48 | return self.meter.avg, all_scores 49 | else: 50 | return avg_score, all_scores 51 | 52 | if __name__ == '__main__': 53 | import multiprocessing 54 | multiprocessing.set_start_method('spawn') 55 | # Rest of your code... 
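A minimal usage sketch for the CLIPScoreCalculator defined above (illustrative only, not part of the repository; the image path and caption are hypothetical, and a CUDA device is assumed because the calculator moves the CLIP model to cuda):

    from PIL import Image
    from starvector.metrics.compute_clip_score import CLIPScoreCalculator

    if __name__ == '__main__':
        import multiprocessing
        multiprocessing.set_start_method('spawn')  # the DataLoader above uses worker processes together with CUDA

        calculator = CLIPScoreCalculator()
        batch = {
            'gen_im': [Image.open('generated_sample.png')],  # hypothetical path to a rasterized generated SVG
            'caption': ['a flat red house icon'],            # hypothetical caption
        }
        avg_score, per_sample_scores = calculator.calculate_score(batch, batch_size=1)
        print(avg_score, per_sample_scores)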
-------------------------------------------------------------------------------- /starvector/metrics/compute_dino_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from starvector.metrics.base_metric import BaseMetric 4 | from tqdm import tqdm 5 | from transformers import AutoModel, AutoImageProcessor 6 | from PIL import Image 7 | import torch.nn as nn 8 | 9 | class DINOScoreCalculator(BaseMetric): 10 | def __init__(self, config=None, device='cuda'): 11 | super().__init__() 12 | self.class_name = self.__class__.__name__ 13 | self.config = config 14 | self.model, self.processor = self.get_DINOv2_model("base") 15 | self.model = self.model.to(device) 16 | self.device = device 17 | 18 | self.metric = self.calculate_DINOv2_similarity_score 19 | 20 | def get_DINOv2_model(self, model_size): 21 | if model_size == "small": 22 | model_size = "facebook/dinov2-small" 23 | elif model_size == "base": 24 | model_size = "facebook/dinov2-base" 25 | elif model_size == "large": 26 | model_size = "facebook/dinov2-large" 27 | else: 28 | raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}") 29 | return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size) 30 | 31 | def process_input(self, image, processor): 32 | if isinstance(image, str): 33 | image = Image.open(image) 34 | if isinstance(image, Image.Image): 35 | with torch.no_grad(): 36 | inputs = processor(images=image, return_tensors="pt").to(self.device) 37 | outputs = self.model(**inputs) 38 | features = outputs.last_hidden_state.mean(dim=1) 39 | elif isinstance(image, torch.Tensor): 40 | features = image.unsqueeze(0) if image.dim() == 1 else image 41 | else: 42 | raise ValueError("Input must be a file path, PIL Image, or tensor of features") 43 | return features 44 | 45 | def calculate_DINOv2_similarity_score(self, **kwargs): 46 | image1 = kwargs.get('gt_im') 47 | image2 = kwargs.get('gen_im') 48 | features1 = self.process_input(image1, self.processor) 49 | features2 = self.process_input(image2, self.processor) 50 | 51 | cos = nn.CosineSimilarity(dim=1) 52 | sim = cos(features1, features2).item() 53 | sim = (sim + 1) / 2 54 | 55 | return sim 56 | -------------------------------------------------------------------------------- /starvector/metrics/compute_fid.py: -------------------------------------------------------------------------------- 1 | # Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html 2 | # from torchmetrics.image.fid import FrechetInceptionDistance 3 | from PIL import Image 4 | from starvector.metrics.base_metric import BaseMetric 5 | import torch 6 | from torchvision import transforms 7 | import clip 8 | from torch.nn.functional import adaptive_avg_pool2d 9 | from starvector.metrics.inception import InceptionV3 10 | import numpy as np 11 | from tqdm import tqdm 12 | from scipy import linalg 13 | import torchvision.transforms as TF 14 | 15 | class FIDCalculator(BaseMetric): 16 | def __init__(self, model_name = 'InceptionV3',): 17 | self.class_name = self.__class__.__name__ 18 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.model_name = model_name 20 | if self.model_name == 'ViT-B/32': 21 | self.dims = 512 22 | model, preprocess = clip.load('ViT-B/32') 23 | 24 | elif self.model_name == 'InceptionV3': 25 | self.dims = 2048 26 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] 27 | model = 
InceptionV3([block_idx]).to(self.device) 28 | preprocess = TF.Compose([TF.ToTensor()]) 29 | 30 | self.model = model.cuda() 31 | self.preprocess = preprocess 32 | 33 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 34 | """Numpy implementation of the Frechet Distance. 35 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 36 | and X_2 ~ N(mu_2, C_2) is 37 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 38 | 39 | Stable version by Dougal J. Sutherland. 40 | 41 | Params: 42 | -- mu1 : Numpy array containing the activations of a layer of the 43 | inception net (like returned by the function 'get_predictions') 44 | for generated samples. 45 | -- mu2 : The sample mean over activations, precalculated on an 46 | representative data set. 47 | -- sigma1: The covariance matrix over activations for generated samples. 48 | -- sigma2: The covariance matrix over activations, precalculated on an 49 | representative data set. 50 | 51 | Returns: 52 | -- : The Frechet Distance. 53 | """ 54 | 55 | mu1 = np.atleast_1d(mu1) 56 | mu2 = np.atleast_1d(mu2) 57 | 58 | sigma1 = np.atleast_2d(sigma1) 59 | sigma2 = np.atleast_2d(sigma2) 60 | 61 | assert mu1.shape == mu2.shape, \ 62 | 'Training and test mean vectors have different lengths' 63 | assert sigma1.shape == sigma2.shape, \ 64 | 'Training and test covariances have different dimensions' 65 | 66 | diff = mu1 - mu2 67 | 68 | # Product might be almost singular 69 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 70 | if not np.isfinite(covmean).all(): 71 | msg = ('fid calculation produces singular product; ' 72 | 'adding %s to diagonal of cov estimates') % eps 73 | print(msg) 74 | offset = np.eye(sigma1.shape[0]) * eps 75 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 76 | 77 | # Numerical error might give slight imaginary component 78 | if np.iscomplexobj(covmean): 79 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 80 | m = np.max(np.abs(covmean.imag)) 81 | raise ValueError('Imaginary component {}'.format(m)) 82 | covmean = covmean.real 83 | 84 | tr_covmean = np.trace(covmean) 85 | 86 | return (diff.dot(diff) + np.trace(sigma1) 87 | + np.trace(sigma2) - 2 * tr_covmean) 88 | 89 | def get_activations(self, images): 90 | dataset = ImageDataset(images, self.preprocess) 91 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4) 92 | pred_arr = np.empty((len(images), self.dims)) 93 | start_idx = 0 94 | for batch in tqdm(dataloader): 95 | batch = batch.to(self.device) 96 | 97 | with torch.no_grad(): 98 | if self.model_name == 'ViT-B/32': 99 | pred = self.model.encode_image(batch).cpu().numpy() 100 | elif self.model_name == 'InceptionV3': 101 | pred = self.model(batch)[0] 102 | 103 | # If model output is not scalar, apply global spatial average pooling. 104 | # This happens if you choose a dimensionality not equal 2048. 
105 | if pred.size(2) != 1 or pred.size(3) != 1: 106 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 107 | 108 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 109 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 110 | start_idx = start_idx + pred.shape[0] 111 | 112 | return pred_arr 113 | 114 | def calculate_activation_statistics(self, images): 115 | act = self.get_activations(images) 116 | mu = np.mean(act, axis=0) 117 | sigma = np.cov(act, rowvar=False) 118 | return mu, sigma 119 | 120 | def pil_images_to_tensor(self, images_list): 121 | """Convert a list of PIL Images to a torch.Tensor.""" 122 | tensors_list = [self.preprocess(img) for img in images_list] 123 | return torch.stack(tensors_list).cuda() # BxCxHxW format 124 | 125 | def calculate_score(self, batch): 126 | m1, s1 = self.calculate_activation_statistics(batch['gt_im']) 127 | m2, s2 = self.calculate_activation_statistics(batch['gen_im']) 128 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 129 | return fid_value 130 | 131 | def reset(self): 132 | pass 133 | 134 | class ImageDataset(torch.utils.data.Dataset): 135 | def __init__(self, images, processor=None): 136 | self.images = images 137 | self.processor = processor 138 | 139 | def __len__(self): 140 | return len(self.images) 141 | 142 | def __getitem__(self, i): 143 | img = self.images[i] 144 | img = self.processor(img) 145 | return img -------------------------------------------------------------------------------- /starvector/metrics/compute_l2.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ToTensor 2 | import torch.nn.functional as F 3 | from starvector.metrics.base_metric import BaseMetric 4 | import torch 5 | 6 | class L2DistanceCalculator(BaseMetric): 7 | def __init__(self, config=None, masked_l2=False): 8 | super().__init__() 9 | self.class_name = self.__class__.__name__ 10 | self.config = config 11 | self.metric = self.l2_distance 12 | self.masked_l2 = masked_l2 13 | 14 | def l2_distance(self, **kwargs): 15 | image1 = kwargs.get('gt_im') 16 | image2 = kwargs.get('gen_im') 17 | image1_tensor = ToTensor()(image1) 18 | image2_tensor = ToTensor()(image2) 19 | 20 | if self.masked_l2: 21 | # Create binary masks: 0 for white pixels, 1 for non-white pixels 22 | mask1 = (image1_tensor != 1).any(dim=0).float() 23 | mask2 = (image2_tensor != 1).any(dim=0).float() 24 | 25 | # Create a combined mask for overlapping non-white pixels 26 | combined_mask = mask1 * mask2 27 | 28 | # Apply the combined mask to both images 29 | image1_tensor = image1_tensor * combined_mask.unsqueeze(0) 30 | image2_tensor = image2_tensor * combined_mask.unsqueeze(0) 31 | 32 | # Compute mean squared error 33 | mse = F.mse_loss(image1_tensor, image2_tensor) 34 | return mse.item() 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /starvector/metrics/count_token_length.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from starvector.metrics.base_metric import BaseMetric 4 | from tqdm import tqdm 5 | from starvector.metrics.util import AverageMeter 6 | 7 | from transformers import AutoTokenizer 8 | 9 | class CountTokenLength(BaseMetric): 10 | def __init__(self, config=None, device='cuda'): 11 | super().__init__() 12 | self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b") 13 | self.metric = self.calculate_token_length 14 | 
self.meter_gt_tokens = AverageMeter() 15 | self.meter_gen_tokens = AverageMeter() 16 | self.meter_diff = AverageMeter() 17 | 18 | def calculate_token_length(self, **kwargs): 19 | svg = kwargs.get('gt_svg') 20 | tokens = self.tokenizer.encode(svg) 21 | gen_svg = kwargs.get('gen_svg') 22 | gen_tokens = self.tokenizer.encode(gen_svg) 23 | diff = len(gen_tokens) - len(tokens) 24 | return len(tokens), len(gen_tokens), diff 25 | 26 | def calculate_score(self, batch, update=None): 27 | gt_svgs = batch['gt_svg'] 28 | gen_svgs = batch['gen_svg'] 29 | values = [] 30 | for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"): 31 | gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg) 32 | self.meter_gt_tokens.update(gt_tokens, 1) 33 | self.meter_gen_tokens.update(gen_tokens, 1) 34 | self.meter_diff.update(diff, 1) 35 | values.append({ 36 | 'gt_tokens': gt_tokens, 37 | 'gen_tokens': gen_tokens, 38 | 'diff': diff 39 | }) 40 | avg_score = { 41 | 'gt_tokens': self.meter_gt_tokens.avg, 42 | 'gen_tokens': self.meter_gen_tokens.avg, 43 | 'diff': self.meter_diff.avg 44 | } 45 | if not values: 46 | print("No valid values found for metric calculation.") 47 | return float("nan"), [] # return a (score, values) pair so callers can always unpack 48 | 49 | return avg_score, values 50 | 51 | def reset(self): 52 | self.meter_gt_tokens.reset() 53 | self.meter_gen_tokens.reset() 54 | self.meter_diff.reset() 55 | -------------------------------------------------------------------------------- /starvector/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from starvector.metrics.compute_l2 import L2DistanceCalculator 2 | from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator 3 | from starvector.metrics.compute_SSIM import SSIMDistanceCalculator 4 | from starvector.metrics.compute_fid import FIDCalculator 5 | from starvector.metrics.compute_clip_score import CLIPScoreCalculator 6 | from starvector.data.util import rasterize_svg 7 | from starvector.metrics.util import AverageMeter 8 | from starvector.metrics.compute_dino_score import DINOScoreCalculator 9 | from starvector.metrics.count_token_length import CountTokenLength 10 | import os 11 | from tqdm import tqdm 12 | 13 | class SVGMetrics: 14 | def __init__(self, config=None): 15 | self.class_name = self.__class__.__name__ 16 | 17 | default_config = { 18 | 'L2': True, 19 | 'Masked-L2': False, 20 | 'LPIPS': False, 21 | 'SSIM': False, 22 | 'FID': False, 23 | 'FID_clip': False, 24 | 'CLIPScore': False, 25 | 'CountTokenLength': False, 26 | 'ratio_post_processed': True, 27 | 'ratio_non_compiling': True, 28 | 'DinoScore': True, 29 | } 30 | self.config = config or default_config 31 | 32 | self.metrics = { 33 | 'L2': L2DistanceCalculator, 34 | 'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True), 35 | 'LPIPS': LPIPSDistanceCalculator, 36 | 'SSIM': SSIMDistanceCalculator, 37 | 'FID': lambda: FIDCalculator(model_name='InceptionV3'), 38 | 'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'), 39 | 'CLIPScore': CLIPScoreCalculator, 40 | 'CountTokenLength': CountTokenLength, 41 | 'ratio_post_processed': AverageMeter, 42 | 'ratio_non_compiling': AverageMeter, 43 | 'DinoScore': DINOScoreCalculator, 44 | } 45 | 46 | self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)} 47 | 48 | def reset(self): 49 | for metric in self.active_metrics.values(): 50 | metric.reset() 51 | 52 | def batch_contains_raster(self, batch): 53 | return "gt_im" in batch and "gen_im" in batch 54 | 55 
| def batch_contains_svg(self, batch): 56 | return "gt_svg" in batch and "gen_svg" in batch 57 | 58 | def calculate_metrics(self, batch, update=True): 59 | if not self.batch_contains_raster(batch): 60 | batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] 61 | batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] 62 | 63 | avg_results_dict = {} 64 | all_results_dict = {} 65 | 66 | def get_sample_id(json_item): 67 | return json_item.get('outpath_filename') or json_item.get('sample_id') 68 | 69 | # initialize all_results_dict 70 | for i, json_item in enumerate(batch['json']): 71 | sample_id = get_sample_id(json_item) 72 | if sample_id is None: 73 | raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]") 74 | all_results_dict[sample_id] = {} 75 | 76 | for metric_name, metric in self.active_metrics.items(): 77 | print(f"Calculating {metric_name}...") 78 | 79 | # Handle metrics that return both average and per-sample results 80 | if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']: 81 | avg_result, list_result = metric.calculate_score(batch, update=update) 82 | avg_results_dict[metric_name] = avg_result 83 | 84 | # Store individual results 85 | for i, result in enumerate(list_result): 86 | sample_id = get_sample_id(batch['json'][i]) 87 | all_results_dict[sample_id][metric_name] = result 88 | 89 | # Handle FID metrics that only return average 90 | elif metric_name in ['FID', 'FID_clip']: 91 | avg_results_dict[metric_name] = metric.calculate_score(batch) 92 | 93 | # Handle other metrics (ratio metrics) 94 | else: 95 | self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict) 96 | 97 | metric.reset() 98 | print("Average results: \n", avg_results_dict) 99 | return avg_results_dict, all_results_dict 100 | 101 | def calculate_fid(self, batch): 102 | if not self.batch_contains_raster(batch): 103 | batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] 104 | batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] 105 | 106 | return self.active_metrics['FID'].calculate_score(batch).item() 107 | 108 | def get_average_metrics(self): 109 | metrics = {} 110 | for metric_name, metric in self.active_metrics.items(): 111 | if hasattr(metric, 'avg'): 112 | metrics[metric_name] = metric.avg 113 | elif hasattr(metric, 'get_average_score'): 114 | metrics[metric_name] = metric.get_average_score() 115 | return metrics 116 | 117 | def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict): 118 | """Helper method to handle ratio-based metrics.""" 119 | metric_key = metric_name.replace('avg_', '').replace('ratio_', '') 120 | 121 | for item in batch['json']: 122 | sample_id = item.get('outpath_filename') or item.get('sample_id') # get_sample_id() is local to calculate_metrics, so resolve the id inline here 123 | value = item[metric_key] 124 | all_results_dict[sample_id][metric_name] = value 125 | metric.update(value, 1) 126 | 127 | avg_results_dict[metric_name] = metric.avg -------------------------------------------------------------------------------- /starvector/metrics/util.py: -------------------------------------------------------------------------------- 1 | 2 | # -------------- Metrics -------------- 3 | class AverageMeter(object): 4 | """Computes and stores the average and current value""" 5 | 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | 
self.count += n 19 | self.avg = self.sum / self.count 20 | -------------------------------------------------------------------------------- /starvector/model/adapters/adapter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch 4 | 5 | class Swish(nn.Module): 6 | def __init__(self): 7 | super(Swish, self).__init__() 8 | 9 | def forward(self, x): 10 | return x * torch.sigmoid(x) 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): 14 | super().__init__() 15 | self.query_length = query_length 16 | self.dropout_prob = dropout_prob 17 | self.adapter_norm = adapter_norm 18 | 19 | self.dropout = nn.Dropout(p=self.dropout_prob) 20 | 21 | self.c_fc = nn.Linear(input_size, input_size*2) 22 | self.act = Swish() 23 | self.c_proj = nn.Linear(input_size*2, output_size) 24 | 25 | if adapter_norm == "layer_norm": 26 | self.norm = nn.LayerNorm([self.query_length, output_size]) 27 | elif adapter_norm == "batch_norm": 28 | self.norm = nn.BatchNorm1d(self.query_length) 29 | 30 | self.init_type = init_type.lower() 31 | self._initialize_weights() 32 | 33 | def forward(self, hidden_states): 34 | hidden_states = self.dropout(hidden_states) 35 | hidden_states = self.c_fc(hidden_states) 36 | hidden_states = self.act(hidden_states) 37 | hidden_states = self.c_proj(hidden_states) 38 | hidden_states = self.norm(hidden_states) 39 | return hidden_states 40 | 41 | def _initialize_weights(self): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear): 44 | if self.init_type == "glorot": 45 | init.xavier_uniform_(m.weight) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | elif self.init_type == "normal": 49 | init.normal_(m.weight, mean=0, std=0.01) 50 | if m.bias is not None: 51 | init.constant_(m.bias, 0) 52 | else: 53 | raise ValueError("Invalid initialization type specified.") 54 | -------------------------------------------------------------------------------- /starvector/model/builder.py: -------------------------------------------------------------------------------- 1 | 2 | from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig 3 | from starvector.data.base import ImageTrainProcessor 4 | from starvector.util import dtype_mapping 5 | from transformers import AutoConfig 6 | 7 | def load_pretrained_model(model_path, device="cuda", **kwargs): 8 | model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device) 9 | tokenizer = model.model.svg_transformer.tokenizer 10 | image_processor = ImageTrainProcessor() 11 | context_len = model.model.query_length + model.model.max_length 12 | return tokenizer, model, image_processor, context_len 13 | 14 | def model_builder(config): 15 | model_name = config.model.get("model_name", False) 16 | 17 | args = { 18 | "task": config.model.task, 19 | "train_image_encoder": config.training.train_image_encoder, 20 | "ignore_mismatched_sizes": True, 21 | "starcoder_model_name": config.model.starcoder_model_name, 22 | "train_LLM": config.training.train_LLM, 23 | "torch_dtype": dtype_mapping[config.training.model_precision], 24 | "transformer_layer_cls": config.model.get("transformer_layer_cls", False), 25 | "use_cache": config.model.use_cache, 26 | } 27 | if model_name: 28 | model = StarVectorForCausalLM.from_pretrained(model_name, **args) 29 | else: 30 | starcoder_model_config = 
AutoConfig.from_pretrained(config.model.starcoder_model_name) 31 | 32 | starvector_config = StarVectorConfig( 33 | max_length_train=config.model.max_length, 34 | image_encoder_type=config.model.image_encoder_type, 35 | use_flash_attn=config.model.use_flash_attn, 36 | adapter_norm=config.model.adapter_norm, 37 | starcoder_model_name=config.model.starcoder_model_name, 38 | torch_dtype=dtype_mapping[config.training.model_precision], 39 | num_attention_heads=starcoder_model_config.num_attention_heads, 40 | num_hidden_layers=starcoder_model_config.num_hidden_layers, 41 | vocab_size=starcoder_model_config.vocab_size, 42 | hidden_size=starcoder_model_config.hidden_size, 43 | num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None), 44 | ) 45 | model = StarVectorForCausalLM(starvector_config, **args) 46 | 47 | return model 48 | 49 | -------------------------------------------------------------------------------- /starvector/model/gpt_bigcode/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from transformers.utils import ( 18 | OptionalDependencyNotAvailable, 19 | _LazyModule, 20 | is_torch_available, 21 | ) 22 | 23 | 24 | _import_structure = { 25 | "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"], 26 | } 27 | 28 | try: 29 | if not is_torch_available(): 30 | raise OptionalDependencyNotAvailable() 31 | except OptionalDependencyNotAvailable: 32 | pass 33 | else: 34 | _import_structure["modeling_gpt_bigcode"] = [ 35 | "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", 36 | "GPTBigCodeForSequenceClassification", 37 | "GPTBigCodeForTokenClassification", 38 | "GPTBigCodeForCausalLM", 39 | "GPTBigCodeModel", 40 | "GPTBigCodePreTrainedModel", 41 | ] 42 | 43 | if TYPE_CHECKING: 44 | from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig 45 | 46 | try: 47 | if not is_torch_available(): 48 | raise OptionalDependencyNotAvailable() 49 | except OptionalDependencyNotAvailable: 50 | pass 51 | else: 52 | from .modeling_gpt_bigcode import ( 53 | GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, 54 | GPTBigCodeForCausalLM, 55 | GPTBigCodeForSequenceClassification, 56 | GPTBigCodeForTokenClassification, 57 | GPTBigCodeModel, 58 | GPTBigCodePreTrainedModel, 59 | ) 60 | 61 | 62 | else: 63 | import sys 64 | 65 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 66 | -------------------------------------------------------------------------------- /starvector/model/gpt_bigcode/configuration_gpt_bigcode.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The BigCode team and HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ GPTBigCode configuration""" 16 | from transformers.configuration_utils import PretrainedConfig 17 | from transformers.utils import logging 18 | 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | 25 | 26 | class GPTBigCodeConfig(PretrainedConfig): 27 | """ 28 | This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a 29 | GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a 30 | configuration with the defaults will yield a similar configuration to that of the GPTBigCode 31 | [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture. 32 | 33 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 34 | documentation from [`PretrainedConfig`] for more information. 35 | 36 | 37 | Args: 38 | vocab_size (`int`, *optional*, defaults to 50257): 39 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 40 | `inputs_ids` passed when calling [`GPTBigCodeModel`]. 41 | n_positions (`int`, *optional*, defaults to 1024): 42 | The maximum sequence length that this model might ever be used with. Typically set this to something large 43 | just in case (e.g., 512 or 1024 or 2048). 44 | n_embd (`int`, *optional*, defaults to 768): 45 | Dimensionality of the embeddings and hidden states. 46 | n_layer (`int`, *optional*, defaults to 12): 47 | Number of hidden layers in the Transformer encoder. 48 | n_head (`int`, *optional*, defaults to 12): 49 | Number of attention heads for each attention layer in the Transformer encoder. 50 | n_inner (`int`, *optional*, defaults to None): 51 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 52 | activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): 53 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", 54 | "gelu_pytorch_tanh"]`. 55 | resid_pdrop (`float`, *optional*, defaults to 0.1): 56 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 57 | embd_pdrop (`float`, *optional*, defaults to 0.1): 58 | The dropout ratio for the embeddings. 59 | attn_pdrop (`float`, *optional*, defaults to 0.1): 60 | The dropout ratio for the attention. 61 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 62 | The epsilon to use in the layer normalization layers. 63 | initializer_range (`float`, *optional*, defaults to 0.02): 64 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 65 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 66 | Scale attention weights by dividing by sqrt(hidden_size).. 67 | use_cache (`bool`, *optional*, defaults to `True`): 68 | Whether or not the model should return the last key/values attentions (not used by all models). 
69 | attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): 70 | Whether to call the fused softmax in float32. 71 | scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): 72 | Whether to scale the attention softmax in float32. 73 | attention_type (`bool`, *optional*, defaults to `True`): 74 | Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`). 75 | Example: 76 | 77 | ```python 78 | >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel 79 | 80 | >>> # Initializing a GPTBigCode configuration 81 | >>> configuration = GPTBigCodeConfig() 82 | 83 | >>> # Initializing a model (with random weights) from the configuration 84 | >>> model = GPTBigCodeModel(configuration) 85 | 86 | >>> # Accessing the model configuration 87 | >>> configuration = model.config 88 | ```""" 89 | 90 | model_type = "gpt_bigcode" 91 | keys_to_ignore_at_inference = ["past_key_values"] 92 | attribute_map = { 93 | "hidden_size": "n_embd", 94 | "max_position_embeddings": "n_positions", 95 | "num_attention_heads": "n_head", 96 | "num_hidden_layers": "n_layer", 97 | } 98 | 99 | def __init__( 100 | self, 101 | vocab_size=50257, 102 | n_positions=1024, 103 | n_embd=768, 104 | n_layer=12, 105 | n_head=12, 106 | n_inner=None, 107 | activation_function="gelu_pytorch_tanh", 108 | resid_pdrop=0.1, 109 | embd_pdrop=0.1, 110 | attn_pdrop=0.1, 111 | layer_norm_epsilon=1e-5, 112 | initializer_range=0.02, 113 | scale_attn_weights=True, 114 | use_cache=True, 115 | bos_token_id=50256, 116 | eos_token_id=50256, 117 | attention_softmax_in_fp32=True, 118 | scale_attention_softmax_in_fp32=True, 119 | multi_query=True, 120 | **kwargs, 121 | ): 122 | self.vocab_size = vocab_size 123 | self.n_positions = n_positions 124 | self.n_embd = n_embd 125 | self.n_layer = n_layer 126 | self.n_head = n_head 127 | self.n_inner = n_inner 128 | self.activation_function = activation_function 129 | self.resid_pdrop = resid_pdrop 130 | self.embd_pdrop = embd_pdrop 131 | self.attn_pdrop = attn_pdrop 132 | self.layer_norm_epsilon = layer_norm_epsilon 133 | self.initializer_range = initializer_range 134 | self.scale_attn_weights = scale_attn_weights 135 | self.use_cache = use_cache 136 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32 137 | self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 138 | self.multi_query = multi_query 139 | 140 | self.bos_token_id = bos_token_id 141 | self.eos_token_id = eos_token_id 142 | 143 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 144 | -------------------------------------------------------------------------------- /starvector/model/image_encoder/clip_model.py: -------------------------------------------------------------------------------- 1 | # Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py 2 | 3 | from collections import OrderedDict 4 | from itertools import repeat 5 | import collections.abc 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 11 | 12 | def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): 13 | """Convert applicable model parameters to the specified precision""" 14 | 15 | def _convert_weights_to_precision(l): 16 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 17 | l.weight.data = l.weight.data.to(precision) 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.to(precision) 20 | 21 | elif isinstance(l, 
(nn.MultiheadAttention)): 22 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 23 | tensor = getattr(l, attr) 24 | if tensor is not None: 25 | tensor.data = tensor.data.to(precision) 26 | else: 27 | for _, p in l.named_parameters(): 28 | p.data = p.data.to(precision) 29 | 30 | model.apply(_convert_weights_to_precision) 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, inplanes, planes, stride=1): 36 | super().__init__() 37 | 38 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 39 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.relu1 = nn.ReLU(inplace=True) 42 | 43 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.relu2 = nn.ReLU(inplace=True) 46 | 47 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 48 | 49 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | self.relu3 = nn.ReLU(inplace=True) 52 | 53 | self.downsample = None 54 | self.stride = stride 55 | 56 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 57 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 58 | self.downsample = nn.Sequential(OrderedDict([ 59 | ("-1", nn.AvgPool2d(stride)), 60 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 61 | ("1", nn.BatchNorm2d(planes * self.expansion)) 62 | ])) 63 | 64 | def forward(self, x: torch.Tensor): 65 | identity = x 66 | 67 | out = self.relu1(self.bn1(self.conv1(x))) 68 | out = self.relu2(self.bn2(self.conv2(out))) 69 | out = self.avgpool(out) 70 | out = self.bn3(self.conv3(out)) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu3(out) 77 | return out 78 | 79 | 80 | class AttentionPool2d(nn.Module): 81 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 82 | super().__init__() 83 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 84 | self.k_proj = nn.Linear(embed_dim, embed_dim) 85 | self.q_proj = nn.Linear(embed_dim, embed_dim) 86 | self.v_proj = nn.Linear(embed_dim, embed_dim) 87 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 88 | self.num_heads = num_heads 89 | 90 | def forward(self, x): 91 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 92 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 93 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 94 | x, _ = F.multi_head_attention_forward( 95 | query=x, key=x, value=x, 96 | embed_dim_to_check=x.shape[-1], 97 | num_heads=self.num_heads, 98 | q_proj_weight=self.q_proj.weight, 99 | k_proj_weight=self.k_proj.weight, 100 | v_proj_weight=self.v_proj.weight, 101 | in_proj_weight=None, 102 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 103 | bias_k=None, 104 | bias_v=None, 105 | add_zero_attn=False, 106 | dropout_p=0, 107 | out_proj_weight=self.c_proj.weight, 108 | out_proj_bias=self.c_proj.bias, 109 | use_separate_proj_weight=True, 110 | training=self.training, 111 | need_weights=False 112 | ) 113 | 114 | return x[0] 115 | 116 | 117 | class LayerNorm(nn.LayerNorm): 118 | """Subclass 
torch's LayerNorm to handle fp16.""" 119 | 120 | def forward(self, x: torch.Tensor): 121 | orig_type = x.dtype 122 | layernorm_dtype = self.weight.dtype 123 | ret = super().forward(x.type(layernorm_dtype)) 124 | return ret.type(orig_type) 125 | 126 | class QuickGELU(nn.Module): 127 | def forward(self, x: torch.Tensor): 128 | return x * torch.sigmoid(1.702 * x) 129 | 130 | class ResidualAttentionBlock(nn.Module): 131 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 132 | super().__init__() 133 | 134 | self.attn = nn.MultiheadAttention(d_model, n_head) 135 | self.ln_1 = LayerNorm(d_model) 136 | self.mlp = nn.Sequential(OrderedDict([ 137 | ("c_fc", nn.Linear(d_model, d_model * 4)), 138 | ("gelu", QuickGELU()), 139 | ("c_proj", nn.Linear(d_model * 4, d_model)) 140 | ])) 141 | self.ln_2 = LayerNorm(d_model) 142 | self.attn_mask = attn_mask 143 | 144 | if use_grad_checkpointing: 145 | self.attn = checkpoint_wrapper(self.attn) 146 | self.mlp = checkpoint_wrapper(self.mlp) 147 | 148 | def attention(self, x: torch.Tensor): 149 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 150 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 151 | 152 | def forward(self, x: torch.Tensor): 153 | x = x + self.attention(self.ln_1(x)) 154 | x = x + self.mlp(self.ln_2(x)) 155 | return x 156 | 157 | class Transformer(nn.Module): 158 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): 159 | super().__init__() 160 | self.width = width 161 | self.layers = layers 162 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) 163 | 164 | def forward(self, x: torch.Tensor): 165 | return self.resblocks(x) 166 | 167 | class VisionTransformer(nn.Module): 168 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): 169 | super().__init__() 170 | self.input_resolution = input_resolution 171 | self.num_features = width 172 | self.num_heads = heads 173 | self.num_patches = (input_resolution // patch_size) ** 2 174 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 175 | scale = width ** -0.5 176 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 177 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) 178 | self.ln_pre = LayerNorm(width) 179 | self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) 180 | 181 | def forward(self, x: torch.Tensor): 182 | x = self.conv1(x) # shape = [*, width, grid, grid] 183 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 184 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 185 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 186 | x = x + self.positional_embedding.to(x.dtype) 187 | x = self.ln_pre(x) 188 | x = x.permute(1, 0, 2) # NLD -> LND 189 | x = self.transformer(x) 190 | x = x.permute(1, 0, 2) # LND -> NLD 191 | return x 192 | -------------------------------------------------------------------------------- /starvector/model/image_encoder/image_encoder.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | from omegaconf import OmegaConf 6 | from starvector.model.image_encoder.clip_model import convert_weights_to_precision 7 | from starvector.data.util import ImageTrainProcessor 8 | 9 | class ImageEncoder(nn.Module): 10 | def __init__(self, config, **kwargs): 11 | super(ImageEncoder, self).__init__() 12 | 13 | image_size = config.image_size 14 | torch_dtype = kwargs.get('model_precision', config.torch_dtype) 15 | # torch_dtype = torch.float32 16 | self.image_encoder_type = config.image_encoder_type 17 | if self.image_encoder_type == 'clip': 18 | self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) 19 | convert_weights_to_precision(self, torch_dtype) 20 | self.processor = ImageTrainProcessor(size=config.image_size) 21 | 22 | elif self.image_encoder_type == 'vqgan': 23 | self.visual_encoder = self.build_vqgan_encoder() 24 | self.ln_vision = None 25 | self.processor = ImageTrainProcessor(size=config.image_size) 26 | 27 | elif self.image_encoder_type == 'convnext': 28 | self.visual_encoder = self.build_convnext_encoder() # use the ConvNeXt builder here (forward() relies on its .trunk), not the VQGAN one 29 | self.ln_vision = None 30 | self.processor = ImageTrainProcessor(size=config.image_size) 31 | 32 | elif 'siglip' in self.image_encoder_type: 33 | if self.image_encoder_type == 'siglip_512': 34 | model_name = "google/siglip-base-patch16-512" 35 | elif self.image_encoder_type == 'siglip_384': 36 | model_name = "google/siglip-large-patch16-384" 37 | elif self.image_encoder_type == 'siglip_256': 38 | model_name = "google/siglip-base-patch16-256" 39 | 40 | from transformers import AutoProcessor, AutoModel 41 | 42 | self.visual_encoder = AutoModel.from_pretrained( 43 | model_name, torch_dtype = torch_dtype 44 | ).vision_model 45 | 46 | self.processor = AutoProcessor.from_pretrained( 47 | model_name, torch_dtype = torch_dtype 48 | ) 49 | 50 | def build_clip_encoder(self, image_size): 51 | from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm 52 | visual_encoder = VisionTransformer( 53 | input_resolution=image_size, 54 | patch_size=14, 55 | width=1024, 56 | layers=23, 57 | heads=16, 58 | use_grad_checkpointing=False) 59 | 60 | ln_vision = LayerNorm(visual_encoder.num_features) 61 | return visual_encoder, ln_vision 62 | 63 | def build_vqgan_encoder(self): 64 | from taming.modules.diffusionmodules.model import Encoder 65 | VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md 66 | vqgan_chkp_path = VQGAN_CHECKPOINT 67 | files_in_directory = os.listdir(vqgan_chkp_path + '/configs') 68 | vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] 69 | vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) 70 | visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) 71 | 72 | # Load checkpoint weights 73 | checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] 74 | 75 | # Create a new state_dict with modified keys 76 | new_state_dict = {} 77 | for key, value in checkpoint.items(): 78 | if key.startswith('encoder.'): 79 | new_key = key[len('encoder.'):] 80 | new_state_dict[new_key] = value 81 | 82 | # Load weights 83 | visual_encoder.load_state_dict(new_state_dict) 84 | return visual_encoder 85 | 86 | def build_convnext_encoder(self): 87 | import open_clip 88 | model, _, 
_ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') 89 | return model.visual 90 | 91 | def forward(self, image): 92 | if self.image_encoder_type == 'clip': 93 | embeds = self.visual_encoder(image) 94 | out = self.ln_vision(embeds) 95 | elif self.image_encoder_type == 'open-clip': 96 | out = self.visual_encoder(image)[1] 97 | out = self.ln_vision(out) 98 | elif self.image_encoder_type == 'vqgan': 99 | out = self.visual_encoder(image) 100 | size = out.size() 101 | out = out.view(size[0], size[1], -1) 102 | out = out.permute(0, 2, 1) 103 | elif self.image_encoder_type == 'convnext': 104 | out = self.visual_encoder.trunk.forward_features(image) 105 | size = out.size() 106 | out = out.view(size[0], size[1], -1) 107 | out = out.permute(0, 2, 1) 108 | elif 'siglip' in self.image_encoder_type: 109 | out = self.visual_encoder(image)["last_hidden_state"] 110 | return out 111 | 112 | def process_images(self, images): 113 | if self.image_encoder_type == 'clip': 114 | res = [] 115 | for image in images: 116 | res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W 117 | return res 118 | else: 119 | return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) 120 | -------------------------------------------------------------------------------- /starvector/model/llm/starcoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import ( 3 | AutoConfig, 4 | AutoModelForCausalLM, 5 | AutoTokenizer, 6 | ) 7 | 8 | class StarCoderModel(nn.Module): 9 | def __init__(self, config, **kwargs): 10 | super(StarCoderModel, self).__init__() 11 | 12 | self.init_tokenizer(config.starcoder_model_name) 13 | 14 | self.max_length = config.max_length 15 | model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True) 16 | kwargs = {} 17 | kwargs['trust_remote_code'] = True 18 | kwargs['torch_dtype'] = config.torch_dtype 19 | 20 | # Configure special tokens for generation 21 | model_config.eos_token_id = self.tokenizer.eos_token_id 22 | model_config.pad_token_id = self.tokenizer.pad_token_id 23 | model_config.bos_token_id = self.tokenizer.bos_token_id 24 | try: 25 | model_config.flash_attention = config.use_flash_attn 26 | model_config._attn_implementation = "flash_attention_2" 27 | except ImportError: 28 | config.use_flash_attn = False 29 | 30 | # model = GPTBigCodeForCausalLM(config=model_config) 31 | model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs) 32 | model.resize_token_embeddings(len(self.tokenizer)) 33 | self.transformer = model 34 | 35 | # Prompt the model after image 36 | self.prompt = ' BatchFeature: 53 | """ 54 | Process images and/or text inputs. 
55 | 56 | Args: 57 | images: Optional image input(s) 58 | text: Optional text input(s) 59 | **kwargs: Additional arguments 60 | """ 61 | if images is None and text is None: 62 | raise ValueError("You have to specify at least one of `images` or `text`.") 63 | 64 | image_inputs = {} 65 | if images is not None: 66 | if isinstance(images, (list, tuple)): 67 | images_ = torch.stack([self.transform(img) for img in images]) 68 | else: 69 | images_ = self.transform(images) 70 | image_inputs = {"pixel_values": images_} 71 | 72 | text_inputs = {} 73 | if text is not None: 74 | text_inputs = self.tokenizer( 75 | text, truncation=True, 76 | add_special_tokens=True, 77 | padding='longest', 78 | max_length=max_length, 79 | return_tensors="pt" 80 | ) 81 | 82 | return BatchFeature(data={**text_inputs, **image_inputs}) 83 | 84 | def _pad_to_square(self, img): 85 | # Calculate padding to make the image square 86 | width, height = img.size 87 | max_dim = max(width, height) 88 | padding = [(max_dim - width) // 2, (max_dim - height) // 2] 89 | padding += [max_dim - width - padding[0], max_dim - height - padding[1]] 90 | return pad(img, padding, fill=255) # Assuming white padding 91 | 92 | 93 | AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor) 94 | 95 | 96 | class StarVectorConfig(PretrainedConfig): 97 | model_type = "starvector" 98 | 99 | def __init__( 100 | self, 101 | starcoder_model_name: str = "bigcode/starcoderbase-1b", 102 | image_encoder_type: str = "clip", 103 | adapter_norm: str = "layer_norm", 104 | image_size: int = 224, 105 | max_length: int = 8192, 106 | max_length_train: int = 8192, 107 | use_flash_attn: bool = True, 108 | use_cache: bool = True, 109 | num_attention_heads: int = 16, 110 | num_hidden_layers: int = 24, 111 | vocab_size: int = 49152, 112 | hidden_size: int = 2048, 113 | num_kv_heads: int = 4, 114 | torch_dtype: str = "bfloat16", 115 | **kwargs, 116 | ): 117 | kwargs["torch_dtype"] = torch_dtype 118 | self.starcoder_model_name = starcoder_model_name 119 | self.image_encoder_type = image_encoder_type 120 | self.adapter_norm = adapter_norm 121 | self.image_size = image_size 122 | self.max_length = max_length 123 | self.max_length_train = max_length_train 124 | self.use_flash_attn = use_flash_attn 125 | self.use_cache = use_cache 126 | self.num_attention_heads = num_attention_heads 127 | self.num_hidden_layers = num_hidden_layers 128 | self.vocab_size = vocab_size 129 | self.hidden_size = hidden_size 130 | self.num_kv_heads = num_kv_heads 131 | super().__init__(**kwargs) 132 | 133 | class StarVectorForCausalLM(PreTrainedModel): 134 | config_class = StarVectorConfig 135 | _no_split_modules = [] 136 | 137 | def __init__(self, config: StarVectorConfig, **kwargs): 138 | super().__init__(config) 139 | starcoder_model_name = config.starcoder_model_name 140 | if 'starcoder2' in starcoder_model_name: 141 | from starvector.model.models.starvector_v2 import StarVectorStarCoder2 142 | self.model = StarVectorStarCoder2(config=config, **kwargs) 143 | else: 144 | from starvector.model.models.starvector_v1 import StarVectorStarCoder 145 | self.model = StarVectorStarCoder(config=config, **kwargs) 146 | 147 | 148 | @property 149 | def supports_gradient_checkpointing(self): 150 | # If the underlying transformer (e.g., the one in StarCoderModel) 151 | # supports gradient checkpointing, delegate to it. 
152 | if hasattr(self.model, 'svg_transformer'): 153 | return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False) 154 | return False 155 | 156 | def gradient_checkpointing_enable(self): 157 | # Optionally, forward this call to the internal transformer. 158 | if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'): 159 | self.model.svg_transformer.gradient_checkpointing_enable() 160 | 161 | def forward(self, vision_embeds, input_ids, num_generations, attention_mask, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 162 | completion_embeds = self.model._get_embeddings(input_ids) 163 | inputs_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1) 164 | 165 | transformer_outputs = self.model.svg_transformer.transformer.transformer( 166 | inputs_embeds=inputs_embeds, 167 | attention_mask=attention_mask, 168 | ) 169 | hidden_states = transformer_outputs[0] 170 | 171 | if num_logits_to_keep > 0: 172 | lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 173 | else: 174 | lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states) 175 | 176 | loss = None 177 | return CausalLMOutputWithCrossAttentions( 178 | loss=loss, 179 | logits=lm_logits, 180 | past_key_values=transformer_outputs.past_key_values, 181 | hidden_states=transformer_outputs.hidden_states, 182 | attentions=transformer_outputs.attentions, 183 | cross_attentions=transformer_outputs.cross_attentions, 184 | ) 185 | 186 | def generate_im2svg(self, batch, **kwargs): 187 | return self.model.generate_im2svg(batch, **kwargs) 188 | 189 | def generate_im2text(self, batch, **kwargs): 190 | return self.model.generate_im2text(batch, **kwargs) 191 | 192 | def process_images(self, images): 193 | return self.model.image_encoder.process_images(images) 194 | 195 | -------------------------------------------------------------------------------- /starvector/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joanrod/star-vector/250afbe55c4a5ca9cbae181eed8ab924b30b82bd/starvector/serve/__init__.py -------------------------------------------------------------------------------- /starvector/serve/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | 15 | CLIP_QUERY_LENGTH = 257 16 | 17 | -------------------------------------------------------------------------------- /starvector/serve/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | from PIL import Image 4 | import concurrent.futures 5 | from bs4 import BeautifulSoup 6 | import cairosvg 7 | from io import BytesIO 8 | 9 | @dataclasses.dataclass 10 | class Conversation: 11 | """A class that keeps all conversation history.""" 12 | system: str 13 | image_prompt: str 14 | roles: List[str] 15 | messages: List[List[str]] 16 | offset: int 17 | version: str = "Unknown" 18 | stop_sampling: bool = False 19 | skip_next: bool = False 20 | display_images: bool = False 21 | task: str = "Im2SVG" 22 | 23 | def set_task(self, task): 24 | self.task = task 25 | 26 | def get_image_prompt(self): 27 | return self.image_prompt 28 | 29 | def get_images(self, return_pil=False): 30 | images = [] 31 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 32 | if i % 2 == 0: 33 | if type(msg) is tuple: 34 | import base64 35 | from io import BytesIO 36 | from PIL import Image 37 | image, image_process_mode = msg 38 | if image_process_mode == "Pad": 39 | def expand2square(pil_img, background_color=(255, 255, 255)): 40 | width, height = pil_img.size 41 | if width == height: 42 | return pil_img 43 | elif width > height: 44 | result = Image.new(pil_img.mode, (width, width), background_color) 45 | result.paste(pil_img, (0, (width - height) // 2)) 46 | return result 47 | else: 48 | result = Image.new(pil_img.mode, (height, height), background_color) 49 | result.paste(pil_img, ((height - width) // 2, 0)) 50 | return result 51 | image = expand2square(image) 52 | elif image_process_mode in ["Default", "Crop"]: 53 | pass 54 | elif image_process_mode == "Resize": 55 | image = image.resize((224, 224)) 56 | else: 57 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 58 | max_hw, min_hw = max(image.size), min(image.size) 59 | aspect_ratio = max_hw / min_hw 60 | max_len, min_len = 800, 400 61 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 62 | longest_edge = int(shortest_edge * aspect_ratio) 63 | W, H = image.size 64 | if longest_edge != max(image.size): 65 | if H > W: 66 | H, W = longest_edge, shortest_edge 67 | else: 68 | H, W = shortest_edge, longest_edge 69 | image = image.resize((W, H)) 70 | if return_pil: 71 | images.append(image) 72 | else: 73 | buffered = BytesIO() 74 | image.save(buffered, format="PNG") 75 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 76 | images.append(img_b64_str) 77 | return images 78 | 79 | def append_message(self, role, message): 80 | self.messages.append([role, message]) 81 | 82 | def download_files(self): 83 | svg_string = self.messages[-1][-1][:-1] 84 | image = self.render_svg(svg_string) 85 | svg_out = self.clean_svg(svg_string) # clean_svg is a method of this class, not a module-level function 86 | 87 | return image, svg_out 88 | 89 | def rasterize_svg(self, svg_string, resolution=224, dpi = 128, scale=2): 90 | try: 91 | svg_raster_bytes = cairosvg.svg2png( 92 | bytestring=svg_string, 93 | background_color='white', 94 | output_width=resolution, 95 | output_height=resolution, 96 | dpi=dpi, 97 | scale=scale) 98 | svg_raster = 
Image.open(BytesIO(svg_raster_bytes)) 99 | except: 100 | try: 101 | svg = self.clean_svg(svg_string) 102 | svg_raster_bytes = cairosvg.svg2png( 103 | bytestring=svg, 104 | background_color='white', 105 | output_width=resolution, 106 | output_height=resolution, 107 | dpi=dpi, 108 | scale=scale) 109 | svg_raster = Image.open(BytesIO(svg_raster_bytes)) 110 | except: 111 | svg_raster = Image.new('RGB', (resolution, resolution), color = 'white') 112 | return svg_raster 113 | 114 | def clean_svg(self, svg_text, output_width=None, output_height=None): 115 | soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml 116 | svg_bs4 = soup.prettify() # Prettify to get a string 117 | svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode() 118 | svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith(" W: 151 | H, W = longest_edge, shortest_edge 152 | else: 153 | H, W = shortest_edge, longest_edge 154 | image = image.resize((W, H)) 155 | buffered = BytesIO() 156 | image.save(buffered, format="JPEG") 157 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 158 | img_str = f'user upload image' 159 | msg = img_str 160 | ret.append([msg, None]) 161 | else: 162 | ret.append([msg, None]) 163 | else: 164 | ret[-1][-1] = msg 165 | return ret 166 | 167 | def copy(self): 168 | return Conversation( 169 | system=self.system, 170 | image_prompt=self.image_prompt, 171 | roles=self.roles, 172 | messages=[[x, y] for x, y in self.messages], 173 | offset=self.offset, 174 | version=self.version 175 | 176 | ) 177 | def dict(self): 178 | if len(self.get_images()) > 0: 179 | return { 180 | "system": self.system, 181 | "image_prompt": self.image_prompt, 182 | "roles": self.roles, 183 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 184 | "offset": self.offset, 185 | } 186 | return { 187 | "system": self.system, 188 | "image_prompt": self.image_prompt, 189 | "roles": self.roles, 190 | "messages": self.messages, 191 | "offset": self.offset, 192 | } 193 | 194 | starvector_v1 = Conversation( 195 | system="StarVector", 196 | # prompt='', 197 | image_prompt='")[0] 62 | 63 | def get_dataloader(self): 64 | self.dataset = SVGValDataset(self.config.dataset.dataset_name, self.config.dataset.config_name, self.config.dataset.split, self.config.dataset.im_size, self.config.dataset.num_samples, self.processor) 65 | self.dataloader = DataLoader(self.dataset, batch_size=self.config.dataset.batch_size, shuffle=False, num_workers=self.config.dataset.num_workers) 66 | 67 | def release_memory(self): 68 | # Clear references to free GPU memory 69 | self.model.model.svg_transformer.tokenizer = None 70 | self.model.model.svg_transformer.model = None 71 | 72 | # Force CUDA garbage collection 73 | if torch.cuda.is_available(): 74 | torch.cuda.empty_cache() 75 | torch.cuda.ipc_collect() 76 | 77 | def generate_svg(self, batch, generate_config): 78 | if generate_config['temperature'] == 0: 79 | generate_config['temperature'] = 1.0 80 | generate_config['do_sample'] = False 81 | outputs = [] 82 | batch['image'] = batch['image'].to('cuda').to(self.torch_dtype) 83 | # for i, batch in enumerate(batch['svg']): 84 | if self.task == 'im2svg': 85 | outputs = self.model.model.generate_im2svg(batch = batch, **generate_config) 86 | elif self.task == 'text2svg': 87 | outputs = self.model.model.generate_text2svg(batch = batch, **generate_config) 88 | return outputs 89 | 
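Before the vLLM API validator, the rasterization fallback in `Conversation.rasterize_svg` above is worth calling out: render the raw SVG, retry with a cleaned copy if that fails, and finally return a blank white canvas so the demo never crashes on malformed markup. A self-contained sketch of the same chain follows; the function names mirror the methods above but this version does not import anything from the repo.

```python
from io import BytesIO

import cairosvg
from PIL import Image
from bs4 import BeautifulSoup

def clean_svg(svg_text: str) -> str:
    """Normalize possibly malformed SVG by round-tripping it through an XML parser and cairosvg."""
    pretty = BeautifulSoup(svg_text, "xml").prettify()
    return cairosvg.svg2svg(bytestring=pretty.encode()).decode()

def rasterize_svg(svg_string: str, resolution: int = 224) -> Image.Image:
    """Rasterize an SVG; fall back to a cleaned copy, then to a blank white canvas."""
    try:
        png = cairosvg.svg2png(bytestring=svg_string, background_color="white",
                               output_width=resolution, output_height=resolution)
        return Image.open(BytesIO(png))
    except Exception:
        try:
            png = cairosvg.svg2png(bytestring=clean_svg(svg_string), background_color="white",
                                   output_width=resolution, output_height=resolution)
            return Image.open(BytesIO(png))
        except Exception:
            return Image.new("RGB", (resolution, resolution), color="white")

# Broken markup degrades gracefully instead of raising:
print(rasterize_svg("<svg><rect width='10'").size)  # (224, 224)
```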
-------------------------------------------------------------------------------- /starvector/validation/starvector_vllm_api_svg_validator.py: -------------------------------------------------------------------------------- 1 | # vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html 2 | # TODO: This is not maintained, need to update it to use the new VLLM API 3 | 4 | from .svg_validator_base import SVGValidator, register_validator 5 | from starvector.data.util import rasterize_svg, clean_svg, use_placeholder 6 | from starvector.data.util import encode_image_base64 7 | from svgpathtools import svgstr2paths 8 | import os 9 | import json 10 | from copy import deepcopy 11 | from openai import OpenAI 12 | 13 | @register_validator 14 | class StarVectorVLLMAPIValidator(SVGValidator): 15 | def __init__(self, config): 16 | 17 | super().__init__(config) 18 | # Initialize VLLM OpenAI client here 19 | self.client = OpenAI( 20 | api_key=config.run.api.key, 21 | base_url=f"{config.run.api.base_url}", 22 | ) 23 | if 'starvector-1b' in config.model.name: 24 | self.svg_end_token_id = 49154 # Adjust as needed 25 | elif 'starvector-8b' in config.model.name: 26 | self.svg_end_token_id = 49156 # Adjust as needed 27 | 28 | def generate_svg(self, batch, generate_config): 29 | outputs = [] 30 | for i, sample in enumerate(batch['svg']): 31 | if self.task == "im2svg": 32 | image = rasterize_svg(sample, 512) 33 | base64_image = encode_image_base64(image) 34 | content = [ 35 | { 36 | "type": "image_url", 37 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, 38 | }, 39 | {"type": "text", "text": " 1, 57 | 'best_of': generate_config['num_beams'] 58 | }, 59 | stream=generate_config['stream'], 60 | logit_bias={self.svg_end_token_id: generate_config['logit_bias']} if generate_config['logit_bias'] else None, 61 | ) 62 | 63 | if generate_config['stream']: 64 | generated_text = self._handle_stream_response(response) 65 | else: 66 | generated_text = "
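The file is cut off mid-call above, but the shape of the request is visible: the validator talks to a vLLM server through the OpenAI-compatible chat API, sending the rasterized input image as a base64 data URI plus a text prompt, with vLLM-specific sampling options passed via `extra_body` and an optional `logit_bias` on the SVG end token. A hedged, standalone sketch of such a request is below; the endpoint URL, served model name, prompt string, and image path are placeholders, not values taken from this repo.

```python
import base64
from io import BytesIO

from openai import OpenAI
from PIL import Image

# Placeholder endpoint and model name; point them at your own `vllm serve` instance.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

def encode_image_base64(image: Image.Image) -> str:
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()

image = Image.open("input.png")  # placeholder path
response = client.chat.completions.create(
    model="starvector-8b",  # placeholder served-model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": f"data:image/png;base64,{encode_image_base64(image)}"}},
            {"type": "text", "text": "Generate SVG code for this image."},  # placeholder prompt
        ],
    }],
    max_tokens=4000,
    temperature=0.7,
    top_p=0.95,
    stream=False,
    extra_body={"use_beam_search": False},  # vLLM-specific options go here
)
print(response.choices[0].message.content)
```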