├── CONTRIBUTING_ROADMAP.md ├── LICENSE ├── README.md ├── accelerate_configs ├── 1_gpu.yaml ├── 8_gpus_deepspeed_zero2.yaml └── multi_nodes │ ├── 8_gpus_node_0.yaml │ ├── 8_gpus_node_1.yaml │ ├── 8_gpus_node_2.yaml │ ├── 8_gpus_node_3.yaml │ ├── 8_gpus_node_4.yaml │ └── 8_gpus_node_5.yaml ├── configs ├── showo_demo.yaml ├── showo_demo_512x512.yaml ├── showo_demo_w_clip_vit.yaml ├── showo_demo_w_clip_vit_512x512.yaml ├── showo_instruction_tuning_1.yaml ├── showo_instruction_tuning_1_512x512.yaml ├── showo_instruction_tuning_1_w_clip_vit.yaml ├── showo_instruction_tuning_1_w_clip_vit_512x512.yaml ├── showo_instruction_tuning_2.yaml ├── showo_instruction_tuning_2_512x512.yaml ├── showo_instruction_tuning_2_w_clip_vit.yaml ├── showo_instruction_tuning_2_w_clip_vit_512x512.yaml ├── showo_pretraining_stage1.yaml ├── showo_pretraining_stage2.yaml └── showo_pretraining_stage3.yaml ├── docs ├── characteristic_comparison.png ├── github_extrapolation.png ├── github_inpainting.png ├── github_mmu.png ├── github_t2i.png ├── show-o-512x512-mmu.png ├── show-o-512x512-t2i.png ├── show-o-ablation.png ├── show-o-geneval.png ├── show-o-want-u.png ├── showo.png ├── showo_title.png └── wechat_qa_3.jpg ├── inference_mmu.ipynb ├── inference_mmu.py ├── inference_t2i.py ├── inpainting_validation ├── .DS_Store ├── alpine_lake.jpg ├── bedroom.jpg ├── bedroom_mask.webp ├── bench.jpg ├── bench_mask.webp ├── bus.jpg ├── bus_mask.webp ├── lake_mountain.jpg ├── maya.png ├── river.png ├── train.jpg ├── train_mask.webp ├── truebsee.jpg ├── truebsee_mask.webp ├── wukong1.jpg └── wukong2.jpg ├── llava ├── __init__.py ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ └── utils.py ├── llava_data_vq_unified.py ├── llava_instruct_data.py └── llava_pretrain_data.py ├── mmu_validation ├── dog.png └── sofa_under_water.jpg ├── models ├── __init__.py ├── clip_encoder.py ├── common_modules.py ├── logging.py ├── lr_schedulers.py ├── misc.py ├── modeling_magvitv2.py ├── modeling_showo.py ├── modeling_utils.py ├── phi.py ├── sampling.py └── training_utils.py ├── parquet ├── __init__.py └── refinedweb_dataset.py ├── requirements.txt ├── training ├── __init__.py ├── data.py ├── imagenet_dataset.py ├── imagenet_label_mapping ├── omni_attention.py ├── optimizer.py ├── prompting_utils.py ├── questions.json ├── train.py ├── train_w_clip_vit.py └── utils.py └── validation_prompts ├── imagenet_prompts.txt ├── showoprompts.txt └── text2image_prompts.txt /CONTRIBUTING_ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Show-o Project 2 | 3 | The Show-o project is open-sourced to the community to push the boundary of unified multimodal models. We invite you to join this exciting journey and contribute to the Show-o project! 4 | 5 | ## Submitting a Pull Request (PR) 6 | 7 | As a contributor, before submitting your request, kindly follow these guidelines: 8 | 9 | 1. Start by checking the [Show-o GitHub](https://github.com/showlab/Show-o/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work. 10 | 11 | 2. [Fork](https://github.com/showlab/Show-o/fork) the [Show-o](https://github.com/showlab/Show-o) repository and download your forked repository to your local machine. 12 | 13 | ```bash 14 | git clone [your-forked-repository-url] 15 | ``` 16 | 17 | 3. 
Add the original repository as a remote to sync with the latest updates: 18 | 19 | ```bash 20 | git remote add upstream https://github.com/showlab/Show-o 21 | ``` 22 | 23 | 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository. 24 | 25 | ``` 26 | # Pull the latest code from the upstream branch 27 | git fetch upstream 28 | 29 | # Switch to the main branch 30 | git checkout main 31 | 32 | # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream 33 | git merge upstream/main 34 | 35 | # Additionally, sync the local main branch to the remote branch of your forked repository 36 | git push origin main 37 | ``` 38 | 39 | 40 | > Note: Sync the code from the main repository before each submission. 41 | 42 | 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful. 43 | 44 | ```bash 45 | git checkout -b my-docs-branch main 46 | ``` 47 | 48 | 6. While making modifications and committing changes, adhere to our [Commit Message Format](#commit-message-format). 49 | 50 | ```bash 51 | git commit -m "[docs]: xxxx" 52 | ``` 53 | 54 | 7. Push your changes to your GitHub repository. 55 | 56 | ```bash 57 | git push origin my-docs-branch 58 | ``` 59 | 60 | 8. Submit a pull request to `Show-o:main` on the GitHub repository page. 61 | 62 | ## Commit Message Format 63 | 64 | Commit messages must include both `<type>` and `<summary>` sections. 65 | 66 | ```bash 67 | [<type>]: <summary> 68 | │ │ 69 | │ └─⫸ Briefly describe your changes, without ending with a period. 70 | │ 71 | └─⫸ Commit Type: |docs|feat|fix|refactor| 72 | ``` 73 | 74 | ### Type 75 | 76 | * **docs**: Modify or add documents. 77 | * **feat**: Introduce a new feature. 78 | * **fix**: Fix a bug. 79 | * **refactor**: Restructure code, excluding new features or bug fixes. 80 | 81 | ### Summary 82 | 83 | Describe modifications in English, without ending with a period. 84 | 85 | > e.g., git commit -m "[docs]: add a contributing.md file" 86 | 87 | ## Roadmap 88 | - 🛠️ Mixed-modal generation. (In progress by [@hrodruck](https://github.com/hrodruck)) 89 | - 🛠️ Support more modalities. (In progress by [@LJungang](https://github.com/LJungang)) 90 | - 🛠️ Efficient training/inference. (In progress by [@KevinZeng08](https://github.com/KevinZeng08)) 91 | - 📣 Support training on more datasets. (Help wanted!) 92 | - 📣 Visual tokenizer training. (Help wanted!) 93 | 94 | ### Acknowledgement 95 | This guideline is modified from [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/main) and [minisora](https://github.com/mini-sora/minisora). Thanks for their awesome templates. 96 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /accelerate_configs/1_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: 'NO' 3 | downcast_bf16: 'no' 4 | gpu_ids: '0' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 1 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/8_gpus_deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_0.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 24 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 1 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 
24 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 2 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 24 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 3 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 24 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_4.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 4 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 24 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/multi_nodes/8_gpus_node_5.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 5 13 | main_process_ip: change to your main process ip 14 | main_process_port: 9999 15 | main_training_function: main 16 | mixed_precision: bf16 17 | num_machines: 6 18 | num_processes: 48 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 
| tpu_use_sudo: false 24 | use_cpu: false -------------------------------------------------------------------------------- /configs/showo_demo.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "demo" 8 | name: "show-o-demo" 9 | output_dir: "show-o-demo" 10 | 11 | model: 12 | vq_model: 13 | type: "magvitv2" 14 | vq_model_name: "showlab/magvitv2" 15 | 16 | showo: 17 | pretrained_model_path: "showlab/show-o" 18 | w_clip_vit: False 19 | vocab_size: 58498 20 | llm_vocab_size: 50295 21 | llm_model_path: 'microsoft/phi-1_5' 22 | codebook_size: 8192 23 | num_vq_tokens: 256 24 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 25 | 26 | gradient_checkpointing: True 27 | 28 | dataset: 29 | gen_type: "t2i" 30 | und_type: "captioning" 31 | params: 32 | batch_size: ${training.batch_size} 33 | shuffle_buffer_size: 1000 34 | num_workers: 32 35 | resolution: 256 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | preprocessing: 40 | max_seq_length: 128 41 | resolution: 256 42 | center_crop: False 43 | random_flip: False 44 | 45 | training: 46 | gradient_accumulation_steps: 1 47 | cond_dropout_prob: 0.1 48 | batch_size: 20 49 | -------------------------------------------------------------------------------- /configs/showo_demo_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "demo" 8 | name: "show-o-demo" 9 | output_dir: "show-o-demo" 10 | 11 | model: 12 | vq_model: 13 | type: "magvitv2" 14 | vq_model_name: "showlab/magvitv2" 15 | 16 | showo: 17 | pretrained_model_path: "showlab/show-o-512x512" 18 | w_clip_vit: False 19 | vocab_size: 58498 20 | llm_vocab_size: 50295 21 | llm_model_path: 'microsoft/phi-1_5' 22 | codebook_size: 8192 23 | num_vq_tokens: 1024 24 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 25 | 26 | gradient_checkpointing: True 27 | 28 | dataset: 29 | gen_type: "t2i" 30 | und_type: "captioning" 31 | params: 32 | batch_size: ${training.batch_size} 33 | shuffle_buffer_size: 1000 34 | num_workers: 32 35 | resolution: 512 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | preprocessing: 40 | max_seq_length: 128 41 | resolution: 512 42 | center_crop: False 43 | random_flip: False 44 | 45 | training: 46 | gradient_accumulation_steps: 1 47 | cond_dropout_prob: 0.1 48 | batch_size: 20 49 | -------------------------------------------------------------------------------- /configs/showo_demo_w_clip_vit.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "demo" 8 | name: "show-o-demo" 9 | output_dir: "show-o-demo" 10 | 11 | model: 12 | vq_model: 13 | type: "magvitv2" 14 | vq_model_name: "showlab/magvitv2" 15 | 16 | showo: 17 | pretrained_model_path: "showlab/show-o-w-clip-vit" 18 | w_clip_vit: True 19 | vocab_size: 58498 20 | llm_vocab_size: 50295 21 | llm_model_path: 'microsoft/phi-1_5' 22 | codebook_size: 8192 23 | num_vq_tokens: 256 24 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 25 | 26 | gradient_checkpointing: True 27 | 28 | dataset: 29 | gen_type: "t2i" 30 | und_type: 
"captioning" 31 | params: 32 | batch_size: ${training.batch_size} 33 | shuffle_buffer_size: 1000 34 | num_workers: 32 35 | resolution: 256 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | preprocessing: 40 | max_seq_length: 128 41 | resolution: 256 42 | center_crop: False 43 | random_flip: False 44 | 45 | training: 46 | gradient_accumulation_steps: 1 47 | cond_dropout_prob: 0.1 48 | batch_size: 20 49 | -------------------------------------------------------------------------------- /configs/showo_demo_w_clip_vit_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "demo" 8 | name: "show-o-demo" 9 | output_dir: "show-o-demo" 10 | 11 | model: 12 | vq_model: 13 | type: "magvitv2" 14 | vq_model_name: "showlab/magvitv2" 15 | 16 | showo: 17 | pretrained_model_path: "showlab/show-o-w-clip-vit-512x512" 18 | w_clip_vit: True 19 | vocab_size: 58498 20 | llm_vocab_size: 50295 21 | llm_model_path: 'microsoft/phi-1_5' 22 | codebook_size: 8192 23 | num_vq_tokens: 1024 24 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 25 | 26 | gradient_checkpointing: True 27 | 28 | dataset: 29 | gen_type: "t2i" 30 | und_type: "captioning" 31 | params: 32 | batch_size: ${training.batch_size} 33 | shuffle_buffer_size: 1000 34 | num_workers: 32 35 | resolution: 512 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | preprocessing: 40 | max_seq_length: 128 41 | resolution: 512 42 | center_crop: False 43 | random_flip: False 44 | 45 | training: 46 | gradient_accumulation_steps: 1 47 | cond_dropout_prob: 0.1 48 | batch_size: 20 49 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_1.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage1" 9 | output_dir: "show-o-tuning-stage1" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_pretrain" 40 | combined_loader_mode: "min_size" 41 | add_system_prompt: False 42 | params: 43 | train_t2i_shards_path_or_url: [ "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar", 44 | "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz" ] 45 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 46 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 47 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 48 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 49 | 
add_caption_prompt: True 50 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 51 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 52 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 53 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 54 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 55 | shuffle_buffer_size: 1000 56 | num_workers: 32 57 | resolution: 256 58 | pin_memory: True 59 | persistent_workers: True 60 | 61 | preprocessing: 62 | max_seq_length: 381 # for text tokens 63 | resolution: 256 64 | center_crop: False 65 | random_flip: False 66 | 67 | optimizer: 68 | name: adamw 69 | params: # default adamw params 70 | learning_rate: 0.00002 71 | scale_lr: False # scale learning rate by total batch size 72 | beta1: 0.9 73 | beta2: 0.999 74 | weight_decay: 0.01 75 | epsilon: 1e-8 76 | 77 | lr_scheduler: 78 | scheduler: "cosine" 79 | params: 80 | learning_rate: ${optimizer.params.learning_rate} 81 | warmup_steps: 1000 82 | 83 | training: 84 | gradient_accumulation_steps: 1 85 | noise_type: "mask" 86 | batch_size_t2i: 4 87 | batch_size_lm: 2 88 | batch_size_mmu: 7 89 | mixed_precision: "bf16" 90 | enable_tf32: True 91 | seed: 10086 92 | max_train_steps: 10000 93 | overfit_one_batch: False 94 | cond_dropout_prob: 0.1 95 | min_masking_rate: 0.0 96 | label_smoothing: 0.0 97 | max_grad_norm: null 98 | guidance_scale: 0.0 99 | generation_timesteps: 12 100 | t2i_coeff: 1.0 101 | lm_coeff: 0.1 102 | mmu_coeff: 1.0 103 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_1_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage1-512x512" 9 | output_dir: "show-o-tuning-stage1-512x512" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 1024 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_pretrain" 40 | combined_loader_mode: "min_size" 41 | add_system_prompt: False 42 | params: 43 | train_t2i_shards_path_or_url: [ 44 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData/XData-{000000..000009}.tar", 45 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-2/XData-2-{000000..000009}.tar", 46 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-3/XData-3-{000000..000024}.tar", 47 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-4/XData-4-{000000..000009}.tar", 48 | # "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz", 49 | ] 50 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 51 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 52 | 
"/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 53 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 54 | add_caption_prompt: True 55 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 56 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 57 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 58 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 59 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 60 | shuffle_buffer_size: 1000 61 | num_workers: 32 62 | resolution: 512 63 | pin_memory: True 64 | persistent_workers: True 65 | 66 | preprocessing: 67 | max_seq_length: 381 # for text tokens 68 | # max_seq_length: 512 # for text tokens 69 | resolution: 512 70 | center_crop: False 71 | random_flip: False 72 | 73 | optimizer: 74 | name: adamw 75 | params: # default adamw params 76 | learning_rate: 0.00002 77 | scale_lr: False # scale learning rate by total batch size 78 | beta1: 0.9 79 | beta2: 0.999 80 | weight_decay: 0.01 81 | epsilon: 1e-8 82 | 83 | lr_scheduler: 84 | scheduler: "cosine" 85 | params: 86 | learning_rate: ${optimizer.params.learning_rate} 87 | warmup_steps: 1000 88 | 89 | training: 90 | gradient_accumulation_steps: 1 91 | noise_type: "mask" 92 | batch_size_t2i: 5 93 | batch_size_lm: 2 94 | batch_size_mmu: 5 95 | # batch_size_t2i: 4 96 | # batch_size_lm: 1 97 | # batch_size_mmu: 4 98 | mixed_precision: "bf16" 99 | enable_tf32: True 100 | seed: 10086 101 | max_train_steps: 14000 102 | overfit_one_batch: False 103 | cond_dropout_prob: 0.1 104 | min_masking_rate: 0.0 105 | label_smoothing: 0.0 106 | max_grad_norm: null 107 | guidance_scale: 0.0 108 | generation_timesteps: 12 109 | t2i_coeff: 1.0 110 | lm_coeff: 0.1 111 | mmu_coeff: 1.0 112 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_1_w_clip_vit.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage1-w-clip-vit" 9 | output_dir: "show-o-tuning-stage1-w-clip-vit" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o-w-clip-vit" 27 | w_clip_vit: True 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_pretrain" 40 | combined_loader_mode: "min_size" 41 | params: 42 | train_t2i_shards_path_or_url: [ "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar", 43 | "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz" ] 44 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 45 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 46 | 
"/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 47 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 48 | add_caption_prompt: True 49 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 50 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 51 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 52 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 53 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 54 | shuffle_buffer_size: 1000 55 | num_workers: 32 56 | resolution: 256 57 | pin_memory: True 58 | persistent_workers: True 59 | 60 | preprocessing: 61 | max_seq_length: 512 # for text tokens 62 | resolution: 256 63 | center_crop: False 64 | random_flip: False 65 | 66 | optimizer: 67 | name: adamw 68 | params: # default adamw params 69 | learning_rate: 0.002 70 | scale_lr: False # scale learning rate by total batch size 71 | beta1: 0.9 72 | beta2: 0.999 73 | weight_decay: 0.01 74 | epsilon: 1e-8 75 | 76 | lr_scheduler: 77 | scheduler: "cosine" 78 | params: 79 | learning_rate: ${optimizer.params.learning_rate} 80 | warmup_steps: 1000 81 | 82 | training: 83 | gradient_accumulation_steps: 1 84 | noise_type: "mask" 85 | batch_size_t2i: 2 86 | batch_size_lm: 2 87 | batch_size_mmu: 10 88 | mixed_precision: "bf16" 89 | enable_tf32: True 90 | seed: 10086 91 | max_train_steps: 10000 92 | overfit_one_batch: False 93 | cond_dropout_prob: 0.1 94 | min_masking_rate: 0.0 95 | label_smoothing: 0.0 96 | max_grad_norm: null 97 | guidance_scale: 0.0 98 | generation_timesteps: 12 99 | t2i_coeff: 1.0 100 | lm_coeff: 0.1 101 | mmu_coeff: 1.0 102 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_1_w_clip_vit_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage1-w-clip-vit-512x512" 9 | output_dir: "show-o-tuning-stage1-w-clip-vit-512x512" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o-w-clip-vit" 27 | w_clip_vit: True 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 1024 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_pretrain" 40 | combined_loader_mode: "min_size" 41 | params: 42 | train_t2i_shards_path_or_url: [ 43 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData/XData-{000000..000009}.tar", 44 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-2/XData-2-{000000..000009}.tar", 45 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-3/XData-3-{000000..000024}.tar", 46 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-4/XData-4-{000000..000009}.tar", 47 | # "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz", 48 | ] 49 | 
train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 50 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 51 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 52 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 53 | add_caption_prompt: True 54 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 55 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 56 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 57 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 58 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 59 | shuffle_buffer_size: 1000 60 | num_workers: 32 61 | resolution: 512 62 | pin_memory: True 63 | persistent_workers: True 64 | 65 | preprocessing: 66 | max_seq_length: 512 # for text tokens 67 | resolution: 512 68 | center_crop: False 69 | random_flip: False 70 | 71 | optimizer: 72 | name: adamw 73 | params: # default adamw params 74 | learning_rate: 0.002 75 | scale_lr: False # scale learning rate by total batch size 76 | beta1: 0.9 77 | beta2: 0.999 78 | weight_decay: 0.01 79 | epsilon: 1e-8 80 | 81 | lr_scheduler: 82 | scheduler: "cosine" 83 | params: 84 | learning_rate: ${optimizer.params.learning_rate} 85 | warmup_steps: 1000 86 | 87 | training: 88 | gradient_accumulation_steps: 1 89 | noise_type: "mask" 90 | batch_size_t2i: 2 91 | batch_size_lm: 2 92 | batch_size_mmu: 10 93 | mixed_precision: "bf16" 94 | enable_tf32: True 95 | seed: 10086 96 | max_train_steps: 7000 97 | overfit_one_batch: False 98 | cond_dropout_prob: 0.1 99 | min_masking_rate: 0.0 100 | label_smoothing: 0.0 101 | max_grad_norm: null 102 | guidance_scale: 0.0 103 | generation_timesteps: 12 104 | t2i_coeff: 1.0 105 | lm_coeff: 0.1 106 | mmu_coeff: 1.0 107 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_2.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage2" 9 | output_dir: "show-o-tuning-stage2" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_tuning" 40 | combined_loader_mode: "min_size" 41 | add_system_prompt: False 42 | params: 43 | train_t2i_shards_path_or_url: [ "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar", 44 | "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz" ] 45 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 46 | 
"/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 47 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 48 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 49 | add_caption_prompt: True 50 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 51 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 52 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 53 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 54 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 55 | shuffle_buffer_size: 1000 56 | num_workers: 32 57 | resolution: 256 58 | pin_memory: True 59 | persistent_workers: True 60 | 61 | preprocessing: 62 | max_seq_length: 381 # for text tokens 63 | resolution: 256 64 | center_crop: False 65 | random_flip: False 66 | 67 | optimizer: 68 | name: adamw 69 | params: # default adamw params 70 | learning_rate: 5e-05 71 | scale_lr: False # scale learning rate by total batch size 72 | beta1: 0.9 73 | beta2: 0.999 74 | weight_decay: 0.01 75 | epsilon: 1e-8 76 | 77 | lr_scheduler: 78 | scheduler: "cosine" 79 | params: 80 | learning_rate: ${optimizer.params.learning_rate} 81 | warmup_steps: 1000 82 | 83 | training: 84 | gradient_accumulation_steps: 1 85 | noise_type: "mask" 86 | batch_size_t2i: 4 87 | batch_size_lm: 2 88 | batch_size_mmu: 6 89 | mixed_precision: "bf16" 90 | enable_tf32: True 91 | seed: 10086 92 | max_train_steps: 14000 93 | overfit_one_batch: False 94 | cond_dropout_prob: 0.1 95 | min_masking_rate: 0.0 96 | label_smoothing: 0.0 97 | max_grad_norm: null 98 | guidance_scale: 0.0 99 | generation_timesteps: 12 100 | t2i_coeff: 1.0 101 | lm_coeff: 0.1 102 | mmu_coeff: 1.0 103 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_2_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage2-512x512" 9 | output_dir: "show-o-tuning-stage2-512x512" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 1024 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_tuning" 40 | combined_loader_mode: "min_size" 41 | add_system_prompt: False 42 | params: 43 | train_t2i_shards_path_or_url: [ 44 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData/XData-{000000..000009}.tar", 45 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-2/XData-2-{000000..000009}.tar", 46 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-3/XData-3-{000000..000024}.tar", 47 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-4/XData-4-{000000..000009}.tar", 48 | # 
"/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz", 49 | ] 50 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 51 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 52 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 53 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 54 | add_caption_prompt: True 55 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 56 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 57 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 58 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 59 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 60 | shuffle_buffer_size: 1000 61 | num_workers: 32 62 | resolution: 512 63 | pin_memory: True 64 | persistent_workers: True 65 | 66 | preprocessing: 67 | max_seq_length: 381 # for text tokens 68 | # max_seq_length: 512 # for text tokens 69 | resolution: 512 70 | center_crop: False 71 | random_flip: False 72 | 73 | optimizer: 74 | name: adamw 75 | params: # default adamw params 76 | learning_rate: 5e-05 77 | scale_lr: False # scale learning rate by total batch size 78 | beta1: 0.9 79 | beta2: 0.999 80 | weight_decay: 0.01 81 | epsilon: 1e-8 82 | 83 | lr_scheduler: 84 | scheduler: "cosine" 85 | params: 86 | learning_rate: ${optimizer.params.learning_rate} 87 | warmup_steps: 1000 88 | 89 | training: 90 | gradient_accumulation_steps: 1 91 | noise_type: "mask" 92 | batch_size_t2i: 5 93 | batch_size_lm: 2 94 | batch_size_mmu: 5 95 | # batch_size_t2i: 4 96 | # batch_size_lm: 1 97 | # batch_size_mmu: 4 98 | mixed_precision: "bf16" 99 | enable_tf32: True 100 | seed: 10086 101 | max_train_steps: 16000 102 | overfit_one_batch: False 103 | cond_dropout_prob: 0.1 104 | min_masking_rate: 0.0 105 | label_smoothing: 0.0 106 | max_grad_norm: null 107 | guidance_scale: 0.0 108 | generation_timesteps: 12 109 | t2i_coeff: 1.0 110 | lm_coeff: 0.1 111 | mmu_coeff: 1.0 112 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_2_w_clip_vit.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage2-w-clip-vit" 9 | output_dir: "show-o-tuning-stage2-w-clip-vit" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o-w-clip-vit" 27 | w_clip_vit: True 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_tuning" 40 | combined_loader_mode: "min_size" 41 | params: 42 | train_t2i_shards_path_or_url: [ "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar", 43 | 
"/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz" ] 44 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 45 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 46 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 47 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 48 | add_caption_prompt: True 49 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 50 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 51 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 52 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 53 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 54 | shuffle_buffer_size: 1000 55 | num_workers: 32 56 | resolution: 256 57 | pin_memory: True 58 | persistent_workers: True 59 | 60 | preprocessing: 61 | max_seq_length: 576 # for text tokens 62 | resolution: 256 63 | center_crop: False 64 | random_flip: False 65 | 66 | optimizer: 67 | name: adamw 68 | params: # default adamw params 69 | learning_rate: 0.0001 70 | scale_lr: False # scale learning rate by total batch size 71 | beta1: 0.9 72 | beta2: 0.999 73 | weight_decay: 0.01 74 | epsilon: 1e-8 75 | 76 | lr_scheduler: 77 | scheduler: "cosine" 78 | params: 79 | learning_rate: ${optimizer.params.learning_rate} 80 | warmup_steps: 1000 81 | 82 | training: 83 | gradient_accumulation_steps: 1 84 | noise_type: "mask" 85 | batch_size_t2i: 3 86 | batch_size_lm: 1 87 | batch_size_mmu: 4 88 | mixed_precision: "bf16" 89 | enable_tf32: True 90 | seed: 10086 91 | max_train_steps: 19600 92 | overfit_one_batch: False 93 | cond_dropout_prob: 0.1 94 | min_masking_rate: 0.0 95 | label_smoothing: 0.0 96 | max_grad_norm: null 97 | guidance_scale: 0.0 98 | generation_timesteps: 12 99 | t2i_coeff: 1.0 100 | lm_coeff: 0.1 101 | mmu_coeff: 1.0 102 | -------------------------------------------------------------------------------- /configs/showo_instruction_tuning_2_w_clip_vit_512x512.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "tuning" 8 | name: "show-o-tuning-stage2-w-clip-vit-512x512" 9 | output_dir: "show-o-tuning-stage2-w-clip-vit-512x512" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o-w-clip-vit" 27 | w_clip_vit: True 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 1024 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "llava_tuning" 40 | combined_loader_mode: "min_size" 41 | params: 42 | train_t2i_shards_path_or_url: [ 43 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData/XData-{000000..000009}.tar", 44 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-2/XData-2-{000000..000009}.tar", 45 | 
"/mnt/bn/vgfm2/test_mlx/xavier/data/XData-3/XData-3-{000000..000024}.tar", 46 | "/mnt/bn/vgfm2/test_mlx/xavier/data/XData-4/XData-4-{000000..000009}.tar", 47 | # "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz", 48 | ] 49 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 50 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 51 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 52 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 53 | add_caption_prompt: True 54 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 55 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 56 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 57 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 58 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 59 | shuffle_buffer_size: 1000 60 | num_workers: 32 61 | resolution: 512 62 | pin_memory: True 63 | persistent_workers: True 64 | 65 | preprocessing: 66 | max_seq_length: 576 # for text tokens 67 | resolution: 512 68 | center_crop: False 69 | random_flip: False 70 | 71 | optimizer: 72 | name: adamw 73 | params: # default adamw params 74 | learning_rate: 0.0001 75 | scale_lr: False # scale learning rate by total batch size 76 | beta1: 0.9 77 | beta2: 0.999 78 | weight_decay: 0.01 79 | epsilon: 1e-8 80 | 81 | lr_scheduler: 82 | scheduler: "cosine" 83 | params: 84 | learning_rate: ${optimizer.params.learning_rate} 85 | warmup_steps: 1000 86 | 87 | training: 88 | gradient_accumulation_steps: 1 89 | noise_type: "mask" 90 | batch_size_t2i: 3 91 | batch_size_lm: 1 92 | batch_size_mmu: 4 93 | mixed_precision: "bf16" 94 | enable_tf32: True 95 | seed: 10086 96 | max_train_steps: 19600 97 | overfit_one_batch: False 98 | cond_dropout_prob: 0.1 99 | min_masking_rate: 0.0 100 | label_smoothing: 0.0 101 | max_grad_norm: null 102 | guidance_scale: 0.0 103 | generation_timesteps: 12 104 | t2i_coeff: 1.0 105 | lm_coeff: 0.1 106 | mmu_coeff: 1.0 107 | -------------------------------------------------------------------------------- /configs/showo_pretraining_stage1.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "training" 8 | name: "show-o-training-stage1" 9 | output_dir: "show-o-training-stage1" 10 | max_train_examples_t2i: 40000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "imagenet1k" 39 | und_type: "captioning" 40 | combined_loader_mode: "max_size_cycle" 41 | params: 42 | train_t2i_shards_path_or_url: "/mnt/bn/vgfm2/test_dit/imagenet/ILSVRC/Data/CLS-LOC/train" 43 | 
train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 44 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 45 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 46 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 47 | add_caption_prompt: True 48 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 49 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 50 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 51 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 52 | validation_prompts_file: "validation_prompts/imagenet_prompts.txt" 53 | shuffle_buffer_size: 1000 54 | num_workers: 32 55 | resolution: 256 56 | pin_memory: True 57 | persistent_workers: True 58 | 59 | preprocessing: 60 | max_seq_length: 128 # for text tokens 61 | resolution: 256 62 | center_crop: False 63 | random_flip: False 64 | 65 | optimizer: 66 | name: adamw 67 | params: # default adamw params 68 | learning_rate: 1e-4 69 | scale_lr: False # scale learning rate by total batch size 70 | beta1: 0.9 71 | beta2: 0.999 72 | weight_decay: 0.01 73 | epsilon: 1e-8 74 | 75 | lr_scheduler: 76 | scheduler: "cosine" 77 | params: 78 | learning_rate: ${optimizer.params.learning_rate} 79 | warmup_steps: 5000 80 | 81 | training: 82 | gradient_accumulation_steps: 1 83 | noise_type: "mask" 84 | batch_size_t2i: 15 85 | batch_size_lm: 4 86 | batch_size_mmu: 10 87 | mixed_precision: "bf16" 88 | enable_tf32: True 89 | seed: 10086 90 | max_train_steps: 500000 91 | overfit_one_batch: False 92 | cond_dropout_prob: 0.1 93 | min_masking_rate: 0.0 94 | label_smoothing: 0.0 95 | max_grad_norm: null 96 | guidance_scale: 1.5 97 | generation_timesteps: 12 98 | t2i_coeff: 1.0 99 | lm_coeff: 0.1 100 | mmu_coeff: 1.0 101 | -------------------------------------------------------------------------------- /configs/showo_pretraining_stage2.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "training" 8 | name: "show-o-training-stage2" 9 | output_dir: "show-o-training-stage2" 10 | max_train_examples_t2i: 40000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "captioning" 40 | combined_loader_mode: "max_size_cycle" 41 | params: 42 | train_t2i_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 43 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 44 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar"] 45 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 46 | 
"/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 47 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 48 | train_lm_shards_path_or_url: "/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 49 | add_caption_prompt: True 50 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 51 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 52 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 53 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 54 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 55 | shuffle_buffer_size: 1000 56 | num_workers: 32 57 | resolution: 256 58 | pin_memory: True 59 | persistent_workers: True 60 | 61 | preprocessing: 62 | max_seq_length: 128 # for text tokens 63 | resolution: 256 64 | center_crop: False 65 | random_flip: False 66 | 67 | optimizer: 68 | name: adamw 69 | params: # default adamw params 70 | learning_rate: 1e-4 71 | scale_lr: False # scale learning rate by total batch size 72 | beta1: 0.9 73 | beta2: 0.999 74 | weight_decay: 0.01 75 | epsilon: 1e-8 76 | 77 | lr_scheduler: 78 | scheduler: "cosine" 79 | params: 80 | learning_rate: ${optimizer.params.learning_rate} 81 | warmup_steps: 5000 82 | 83 | training: 84 | gradient_accumulation_steps: 1 85 | noise_type: "mask" 86 | batch_size_t2i: 10 87 | batch_size_lm: 4 88 | batch_size_mmu: 10 89 | mixed_precision: "bf16" 90 | enable_tf32: True 91 | seed: 10086 92 | max_train_steps: 1000000 93 | overfit_one_batch: False 94 | cond_dropout_prob: 0.1 95 | min_masking_rate: 0.0 96 | label_smoothing: 0.0 97 | max_grad_norm: null 98 | guidance_scale: 0.0 99 | generation_timesteps: 12 100 | t2i_coeff: 1.0 101 | lm_coeff: 0.1 102 | mmu_coeff: 1.0 103 | -------------------------------------------------------------------------------- /configs/showo_pretraining_stage3.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | # run_id: askkz9i2 4 | resume: 'auto' 5 | 6 | experiment: 7 | project: "training" 8 | name: "show-o-training-stage3" 9 | output_dir: "show-o-training-stage3" 10 | max_train_examples_t2i: 20000000 11 | max_train_examples_mmu: 40000000 12 | save_every: 10000 13 | eval_every: 2500 14 | generate_every: 1000 15 | log_every: 50 16 | log_grad_norm_every: 500 17 | resume_from_checkpoint: 'latest' 18 | 19 | model: 20 | vq_model: 21 | type: "magvitv2" 22 | vq_model_name: "showlab/magvitv2" 23 | 24 | showo: 25 | load_from_showo: False 26 | pretrained_model_path: "showlab/show-o" 27 | w_clip_vit: False 28 | vocab_size: 58498 29 | llm_vocab_size: 50295 30 | llm_model_path: 'microsoft/phi-1_5' 31 | codebook_size: 8192 32 | num_vq_tokens: 256 33 | num_new_special_tokens: 10 # <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|> 34 | 35 | gradient_checkpointing: True 36 | 37 | dataset: 38 | gen_type: "t2i" 39 | und_type: "captioning" 40 | combined_loader_mode: "max_size_cycle" 41 | params: 42 | train_t2i_shards_path_or_url: ["/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar", 43 | "/mnt/bn/vgfm/JourneyDB/JourneyDB/data/train/imgs/{000..199}.tgz" ] 44 | train_mmu_shards_path_or_url: [ "/mnt/bn/vgfm2/test_mlx/xavier/data/SA1B2/sa_{000000..000999}.tar", 45 | "/mnt/bn/vgfm/cc12m/images/{00000..01242}.tar", 46 | "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-images/{00000..01209}.tar" ] 47 | train_lm_shards_path_or_url: 
"/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet" 48 | add_caption_prompt: True 49 | external_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/data/SAM-LLaVA-Captions10M" 50 | external_journeydb_caption_path: "/mnt/bn/vgfm2/test_mlx/xavier/code/3062/open_muse/train_journeydb_anno.json" 51 | external_laion12m_caption_path: "/mnt/bn/vgfm/laion5b/laion-aesthetics-12m-captions" 52 | external_cc12m_caption_path: '/mnt/bn/vgfm/cc12m/captions/' 53 | validation_prompts_file: "validation_prompts/text2image_prompts.txt" 54 | shuffle_buffer_size: 1000 55 | num_workers: 32 56 | resolution: 256 57 | pin_memory: True 58 | persistent_workers: True 59 | 60 | preprocessing: 61 | max_seq_length: 128 # for text tokens 62 | resolution: 256 63 | center_crop: False 64 | random_flip: False 65 | 66 | optimizer: 67 | name: adamw 68 | params: # default adamw params 69 | learning_rate: 2e-5 70 | scale_lr: False # scale learning rate by total batch size 71 | beta1: 0.9 72 | beta2: 0.999 73 | weight_decay: 0.01 74 | epsilon: 1e-8 75 | 76 | lr_scheduler: 77 | scheduler: "cosine" 78 | params: 79 | learning_rate: ${optimizer.params.learning_rate} 80 | warmup_steps: 5000 81 | 82 | training: 83 | gradient_accumulation_steps: 1 84 | noise_type: "mask" 85 | batch_size_t2i: 10 86 | batch_size_lm: 4 87 | batch_size_mmu: 10 88 | mixed_precision: "bf16" 89 | enable_tf32: True 90 | seed: 10086 91 | max_train_steps: 50000 # to be determined according to the scale of high-quality dataset 92 | overfit_one_batch: False 93 | cond_dropout_prob: 0.1 94 | min_masking_rate: 0.0 95 | label_smoothing: 0.0 96 | max_grad_norm: null 97 | guidance_scale: 0.0 98 | generation_timesteps: 12 99 | t2i_coeff: 1.0 100 | lm_coeff: 0.1 101 | mmu_coeff: 1.0 102 | -------------------------------------------------------------------------------- /docs/characteristic_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/characteristic_comparison.png -------------------------------------------------------------------------------- /docs/github_extrapolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/github_extrapolation.png -------------------------------------------------------------------------------- /docs/github_inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/github_inpainting.png -------------------------------------------------------------------------------- /docs/github_mmu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/github_mmu.png -------------------------------------------------------------------------------- /docs/github_t2i.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/github_t2i.png -------------------------------------------------------------------------------- /docs/show-o-512x512-mmu.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/show-o-512x512-mmu.png -------------------------------------------------------------------------------- /docs/show-o-512x512-t2i.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/show-o-512x512-t2i.png -------------------------------------------------------------------------------- /docs/show-o-ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/show-o-ablation.png -------------------------------------------------------------------------------- /docs/show-o-geneval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/show-o-geneval.png -------------------------------------------------------------------------------- /docs/show-o-want-u.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/show-o-want-u.png -------------------------------------------------------------------------------- /docs/showo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/showo.png -------------------------------------------------------------------------------- /docs/showo_title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/showo_title.png -------------------------------------------------------------------------------- /docs/wechat_qa_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/docs/wechat_qa_3.jpg -------------------------------------------------------------------------------- /inference_mmu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "source": [ 8 | "import torch\n", 9 | "from models import Showo, MAGVITv2\n", 10 | "from training.prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit\n", 11 | "from training.utils import get_config, flatten_omega_conf, image_transform\n", 12 | "from transformers import AutoTokenizer\n", 13 | "from models.clip_encoder import CLIPVisionTower\n", 14 | "from transformers import CLIPImageProcessor\n", 15 | "from llava.llava import conversation as conversation_lib\n", 16 | "\n", 17 | "conversation_lib.default_conversation = conversation_lib.conv_templates[\"phi1.5\"]\n" 18 | ], 19 | "outputs": [] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "source": [ 26 | "# config load - 'showo_demo_w_clip_vit.yaml'\n", 27 | "from omegaconf import DictConfig, ListConfig, OmegaConf\n", 28 | "config = OmegaConf.load('configs/showo_demo_w_clip_vit.yaml')" 29 | ], 30 | "outputs": [] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {},
36 | "source": [ 37 | "# device setup\n", 38 | "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n", 39 | "# device = \"cpu\"" 40 | ], 41 | "outputs": [] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "metadata": {}, 47 | "source": [ 48 | "\n", 49 | "# show o tokenizer setup and adding special tokens to universal prompting\n", 50 | "# llm model : 'microsoft/phi-1_5'\n", 51 | "tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side =\"left\")\n", 52 | "uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,\n", 53 | " special_tokens=(\"<|soi|>\", \"<|eoi|>\", \"<|sov|>\", \"<|eov|>\", \"<|t2i|>\", \"<|mmu|>\", \"<|t2v|>\", \"<|v2v|>\", \"<|lvg|>\"),\n", 54 | " ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)\n" 55 | ], 56 | "outputs": [] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": {}, 62 | "source": [ 63 | "# setting up the visual question answering model: magvit-v2\n", 64 | "vq_model = MAGVITv2\n", 65 | "vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)\n", 66 | "vq_model.requires_grad_(False)\n", 67 | "vq_model.eval()" 68 | ], 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 7, 74 | "metadata": {}, 75 | "source": [ 76 | "# setting up vision tower: clip-vit\n", 77 | "vision_tower_name =\"openai/clip-vit-large-patch14-336\"\n", 78 | "vision_tower = CLIPVisionTower(vision_tower_name).to(device)\n", 79 | "clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)\n" 80 | ], 81 | "outputs": [] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 8, 86 | "metadata": {}, 87 | "source": [ 88 | "# setting up the showo model \n", 89 | "model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)\n", 90 | "model.eval()" 91 | ], 92 | "outputs": [] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 9, 97 | "metadata": {}, 98 | "source": [ 99 | "# setting up the parameters\n", 100 | "temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions\n", 101 | "top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability\n", 102 | "SYSTEM_PROMPT = \"A chat between a curious user and an artificial intelligence assistant. 
\" \\\n", 103 | " \"The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n", 104 | "SYSTEM_PROMPT_LEN = 28\n" 105 | ], 106 | "outputs": [] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## Inference " 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 1, 118 | "metadata": {}, 119 | "source": [ 120 | "import os\n", 121 | "import requests\n", 122 | "from IPython.display import Image\n", 123 | "from urllib.parse import urlparse\n", 124 | "\n", 125 | "def load_image(path_or_url, save_dir=\"downloaded_images\"):\n", 126 | " \"\"\"Load image from local path or URL.\"\"\"\n", 127 | " if os.path.exists(path_or_url):\n", 128 | " return Image(filename=path_or_url)\n", 129 | "\n", 130 | " os.makedirs(save_dir, exist_ok=True)\n", 131 | " filename = os.path.join(save_dir, os.path.basename(urlparse(path_or_url).path))\n", 132 | " \n", 133 | " with requests.get(path_or_url, stream=True) as r:\n", 134 | " if r.status_code == 200:\n", 135 | " with open(filename, \"wb\") as f:\n", 136 | " for chunk in r.iter_content(1024):\n", 137 | " f.write(chunk)\n", 138 | " return Image(filename=filename)\n", 139 | " \n", 140 | " print(\"Failed to load image.\")\n", 141 | " return None\n", 142 | "\n", 143 | "# Example usage\n", 144 | "image_path_or_url = \"/home/grads/h/hasnat.md.abdullah/Show-o/mmu_validation/sofa_under_water.jpg\" # Or a URL\n", 145 | "load_image(image_path_or_url)" 146 | ], 147 | "outputs": [] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 14, 152 | "metadata": {}, 153 | "source": [ 154 | "# inference\n", 155 | "from PIL import Image\n", 156 | "## arguments\n", 157 | "input_image_path =\"./mmu_validation/sofa_under_water.jpg\"\n", 158 | "questions ='Please describe this image in detail. 
*** Do you think the image is unusual or not?'\n", 159 | "\n", 160 | "## processing\n", 161 | "questions = questions.split('***')\n", 162 | "image_ori = Image.open(input_image_path).convert(\"RGB\")\n", 163 | "# tranforming the image to the required resolution:256x256\n", 164 | "image = image_transform(image_ori, resolution = config.dataset.params.resolution).to(device)\n", 165 | "image = image.unsqueeze(0)\n", 166 | "print(f\"image shape: {image.shape}\") # torch.Size([1, 3, 256, 256])\n", 167 | "pixel_values = clip_image_processor.preprocess(image_ori,return_tensors=\"pt\")['pixel_values'][0]\n", 168 | "print(f\"pixel values shape: {pixel_values.shape}\")\n", 169 | "image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)\n", 170 | "print(f\"image tokens shape: {image_tokens.shape}\") # torch.Size([1, 256])\n", 171 | "batch_size = 1\n", 172 | "\n", 173 | "## inference\n", 174 | "for question in questions: \n", 175 | " conv = conversation_lib.default_conversation.copy()\n", 176 | " print(f\"conversation: {conv}\")\n", 177 | " conv.append_message(conv.roles[0], question)\n", 178 | " conv.append_message(conv.roles[1], None)\n", 179 | " prompt_question = conv.get_prompt()\n", 180 | " # print(prompt_question)\n", 181 | " question_input = []\n", 182 | " question_input.append(prompt_question.strip())\n", 183 | " print(f\"system prompt: {SYSTEM_PROMPT}\")\n", 184 | " input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors=\"pt\", padding=\"longest\").input_ids for _ in range(batch_size)]\n", 185 | " print(f\"system prompt input ids: {input_ids_system}\")\n", 186 | " input_ids_system = torch.stack(input_ids_system, dim=0)\n", 187 | " assert input_ids_system.shape[-1] == 28\n", 188 | " print(f\"after torch stacking: {input_ids_system}\")\n", 189 | " input_ids_system = input_ids_system.clone().detach().to(device)\n", 190 | " # inputs_ids_system = input_ids_system.to(device)\n", 191 | "# inputs_ids_system = torch.tensor(input_ids_system).to(device).squeeze(0)\n", 192 | " \n", 193 | " print(f\"after moving to device: {input_ids_system}\")\n", 194 | " input_ids_system = input_ids_system[0]\n", 195 | " print(f\"after indexing 0: {input_ids_system}\")\n", 196 | " \n", 197 | " \n", 198 | " print(f\"question input: {question_input}\")\n", 199 | " input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors=\"pt\", padding=\"longest\").input_ids for prompt in question_input]\n", 200 | " print(f\"after tokenizing the question: {input_ids}\")\n", 201 | " input_ids = torch.stack(input_ids)\n", 202 | " print(f\"after torch stacking: {input_ids}\")\n", 203 | " input_ids = torch.nn.utils.rnn.pad_sequence(\n", 204 | " input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id\n", 205 | " )\n", 206 | " print(f\"after padding: {input_ids}\")\n", 207 | " # input_ids = torch.tensor(input_ids).to(device).squeeze(0)\n", 208 | " input_ids = input_ids.clone().detach().to(device).squeeze(0)\n", 209 | " print(f\"after moving to device: {input_ids}\")\n", 210 | " input_ids_llava = torch.cat([\n", 211 | " (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),\n", 212 | " input_ids_system,\n", 213 | " (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),\n", 214 | " # place your img embedding here\n", 215 | " (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),\n", 216 | " input_ids,\n", 217 | " ], dim=1).long()\n", 218 | " print(input_ids_llava)\n", 219 | " 
\n", 220 | " images_embeddings = vision_tower(pixel_values[None])\n", 221 | " print(f\"images embeddings shape: {images_embeddings.shape}\")# torch.Size([1, 576, 1024])\n", 222 | " images_embeddings = model.mm_projector(images_embeddings)\n", 223 | " print(f\"images embeddings shape after projection: {images_embeddings.shape}\") \n", 224 | "\n", 225 | " text_embeddings = model.showo.model.embed_tokens(input_ids_llava)\n", 226 | "\n", 227 | " #full input seq\n", 228 | " part1 = text_embeddings[:, :2+SYSTEM_PROMPT_LEN,:]\n", 229 | " part2 = text_embeddings[:, 2+SYSTEM_PROMPT_LEN:,:]\n", 230 | " input_embeddings = torch.cat((part1,images_embeddings,part2),dim=1)\n", 231 | "\n", 232 | " attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,system_prompt_len=SYSTEM_PROMPT_LEN)\n", 233 | "\n", 234 | " cont_toks_list = model.mmu_generate(\n", 235 | " input_embeddings = input_embeddings,\n", 236 | " attention_mask = attention_mask_llava[0].unsqueeze(0),\n", 237 | " max_new_tokens = 100,\n", 238 | " top_k = top_k,\n", 239 | " eot_token = uni_prompting.sptids_dict['<|eov|>']\n", 240 | " )\n", 241 | " \n", 242 | " cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]\n", 243 | " text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list,skip_special_tokens=True)\n", 244 | " print(f\"User: {question}, \\nAnswer: {text[0]}\")\n", 245 | "\n", 246 | "\n" 247 | ], 248 | "outputs": [] 249 | } 250 | ], 251 | "metadata": { 252 | "kernelspec": { 253 | "display_name": "Python 3", 254 | "language": "python", 255 | "name": "python3" 256 | }, 257 | "language_info": { 258 | "codemirror_mode": { 259 | "name": "ipython", 260 | "version": 3 261 | }, 262 | "file_extension": ".py", 263 | "mimetype": "text/x-python", 264 | "name": "python", 265 | "nbconvert_exporter": "python", 266 | "pygments_lexer": "ipython3", 267 | "version": "3.9.19" 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 2 272 | } 273 | -------------------------------------------------------------------------------- /inference_mmu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 NUS Show Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 18 | from PIL import Image 19 | from tqdm import tqdm 20 | import numpy as np 21 | import torch 22 | import wandb 23 | from models import Showo, MAGVITv2, CLIPVisionTower 24 | from training.prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit 25 | from training.utils import get_config, flatten_omega_conf, image_transform 26 | from transformers import AutoTokenizer 27 | from transformers import CLIPImageProcessor 28 | 29 | from llava.llava import conversation as conversation_lib 30 | 31 | conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"] 32 | SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \ 33 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 34 | SYSTEM_PROMPT_LEN = 28 35 | 36 | def get_vq_model_class(model_type): 37 | if model_type == "magvitv2": 38 | return MAGVITv2 39 | else: 40 | raise ValueError(f"model_type {model_type} not supported.") 41 | 42 | if __name__ == '__main__': 43 | 44 | config = get_config() 45 | 46 | resume_wandb_run = config.wandb.resume 47 | run_id = config.wandb.get("run_id", None) 48 | if run_id is None: 49 | resume_wandb_run = False 50 | run_id = wandb.util.generate_id() 51 | config.wandb.run_id = run_id 52 | 53 | wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} 54 | 55 | wandb.init( 56 | project="demo", 57 | name=config.experiment.name + '_mmu', 58 | config=wandb_config, 59 | ) 60 | 61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left") 63 | 64 | uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, 65 | special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), 66 | ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob) 67 | 68 | vq_model = get_vq_model_class(config.model.vq_model.type) 69 | vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) 70 | vq_model.requires_grad_(False) 71 | vq_model.eval() 72 | 73 | vision_tower_name = "openai/clip-vit-large-patch14-336" 74 | vision_tower = CLIPVisionTower(vision_tower_name).to(device) 75 | clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name) 76 | 77 | model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device) 78 | model.eval() 79 | 80 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 81 | top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability 82 | 83 | file_list = os.listdir(config.mmu_image_root) 84 | responses = ['' for i in range(len(file_list))] 85 | images = [] 86 | config.question = config.question.split(' *** ') 87 | for i, file_name in enumerate(tqdm(file_list)): 88 | image_path = os.path.join(config.mmu_image_root, file_name) 89 | image_ori = Image.open(image_path).convert("RGB") 90 | image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) 91 | image = image.unsqueeze(0) 92 | images.append(image) 93 | 94 | pixel_values = clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0] 95 | 96 | image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) 97 | batch_size = 1 98 | 99 | for 
question in config.question: 100 | if config.model.showo.w_clip_vit: 101 | conv = conversation_lib.default_conversation.copy() 102 | conv.append_message(conv.roles[0], question) 103 | conv.append_message(conv.roles[1], None) 104 | prompt_question = conv.get_prompt() 105 | question_input = [] 106 | question_input.append(prompt_question.strip()) 107 | 108 | input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids 109 | for _ in range(batch_size)] 110 | input_ids_system = torch.stack(input_ids_system, dim=0) 111 | assert input_ids_system.shape[-1] == 28 112 | input_ids_system = input_ids_system.to(device) 113 | input_ids_system = input_ids_system[0] 114 | 115 | input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids 116 | for prompt in question_input] 117 | 118 | input_ids = torch.stack(input_ids) 119 | input_ids = torch.nn.utils.rnn.pad_sequence( 120 | input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id 121 | ) 122 | input_ids = torch.tensor(input_ids).to(device).squeeze(0) 123 | # import pdb; pdb.set_trace() 124 | input_ids_llava = torch.cat([ 125 | (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device), 126 | input_ids_system, 127 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), 128 | # place your img embedding here 129 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), 130 | input_ids, 131 | ], dim=1).long() 132 | 133 | images_embeddings = vision_tower(pixel_values[None]) 134 | images_embeddings = model.mm_projector(images_embeddings) 135 | 136 | text_embeddings = model.showo.model.embed_tokens(input_ids_llava) 137 | 138 | # Full input seq 139 | part1 = text_embeddings[:, :2 + SYSTEM_PROMPT_LEN, :] 140 | part2 = text_embeddings[:, 2 + SYSTEM_PROMPT_LEN:, :] 141 | input_embeddings = torch.cat((part1, images_embeddings, part2), dim=1) 142 | 143 | attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings, 144 | system_prompt_len=SYSTEM_PROMPT_LEN) 145 | 146 | cont_toks_list = model.mmu_generate(input_embeddings=input_embeddings, 147 | attention_mask=attention_mask_llava[0].unsqueeze(0), 148 | max_new_tokens=config.max_new_tokens, 149 | top_k=top_k, 150 | eot_token=tokenizer.eos_token_id 151 | ) 152 | else: 153 | input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[ 154 | 'input_ids'] 155 | input_ids = torch.tensor(input_ids).to(device) 156 | 157 | input_ids = torch.cat([ 158 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), 159 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), 160 | image_tokens, 161 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), 162 | (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), 163 | input_ids 164 | ], dim=1).long() 165 | 166 | attention_mask = create_attention_mask_for_mmu(input_ids.to(device), 167 | eoi_id=int(uni_prompting.sptids_dict['<|eoi|>'])) 168 | 169 | cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask, 170 | max_new_tokens=config.max_new_tokens, top_k=top_k, 171 | eot_token=uni_prompting.sptids_dict['<|eot|>']) 172 | 173 | cont_toks_list = torch.stack(cont_toks_list).squeeze()[None] 174 | 175 | text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True) 176 | 
print(text) 177 | responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n' 178 | 179 | images = torch.cat(images, dim=0) 180 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 181 | images *= 255.0 182 | images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) 183 | pil_images = [Image.fromarray(image) for image in images] 184 | 185 | wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] 186 | wandb.log({"multimodal understanding": wandb_images}, step=0) 187 | 188 | -------------------------------------------------------------------------------- /inpainting_validation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/.DS_Store -------------------------------------------------------------------------------- /inpainting_validation/alpine_lake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/alpine_lake.jpg -------------------------------------------------------------------------------- /inpainting_validation/bedroom.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bedroom.jpg -------------------------------------------------------------------------------- /inpainting_validation/bedroom_mask.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bedroom_mask.webp -------------------------------------------------------------------------------- /inpainting_validation/bench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bench.jpg -------------------------------------------------------------------------------- /inpainting_validation/bench_mask.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bench_mask.webp -------------------------------------------------------------------------------- /inpainting_validation/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bus.jpg -------------------------------------------------------------------------------- /inpainting_validation/bus_mask.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/bus_mask.webp -------------------------------------------------------------------------------- /inpainting_validation/lake_mountain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/lake_mountain.jpg 
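For reference, in the multimodal-understanding path of `inference_mmu.py` above that does not use CLIP-ViT features (the `else` branch), the input ids are concatenated as `<|mmu|> <|soi|> [image tokens] <|eoi|> <|sot|>` followed by the tokenized `USER: ... ASSISTANT:` text. The snippet below is only a minimal sketch of that sequence layout; all token ids in it are hypothetical placeholders rather than Show-o's real vocabulary ids.

```python
# Minimal sketch of the sequence layout used by the non-CLIP-ViT branch of
# inference_mmu.py: <|mmu|> <|soi|> [image tokens] <|eoi|> <|sot|> question text.
# All token ids below are hypothetical placeholders, not Show-o's actual vocabulary.
import torch

MMU, SOI, EOI, SOT = 50296, 50297, 50298, 50299        # placeholder special-token ids
image_tokens = torch.randint(50305, 58497, (1, 256))   # stands in for vq_model.get_code(image) + text-vocab offset
question_ids = torch.randint(0, 50295, (1, 12))        # stands in for the tokenized "USER: ... ASSISTANT:" text

def special(tok):
    # a single special token as a (1, 1) id tensor
    return torch.full((1, 1), tok, dtype=torch.long)

input_ids = torch.cat(
    [special(MMU), special(SOI), image_tokens, special(EOI), special(SOT), question_ids],
    dim=1,
)
print(input_ids.shape)  # torch.Size([1, 272]) = 4 special tokens + 256 image tokens + 12 text tokens
```

In the actual script the image tokens come from `vq_model.get_code(image)` offset by `len(uni_prompting.text_tokenizer)`, so image-codebook ids and text-token ids occupy disjoint ranges of the shared embedding table.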
-------------------------------------------------------------------------------- /inpainting_validation/maya.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/maya.png -------------------------------------------------------------------------------- /inpainting_validation/river.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/river.png -------------------------------------------------------------------------------- /inpainting_validation/train.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/train.jpg -------------------------------------------------------------------------------- /inpainting_validation/train_mask.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/train_mask.webp -------------------------------------------------------------------------------- /inpainting_validation/truebsee.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/truebsee.jpg -------------------------------------------------------------------------------- /inpainting_validation/truebsee_mask.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/truebsee_mask.webp -------------------------------------------------------------------------------- /inpainting_validation/wukong1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/wukong1.jpg -------------------------------------------------------------------------------- /inpainting_validation/wukong2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/inpainting_validation/wukong2.jpg -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/llava/__init__.py -------------------------------------------------------------------------------- /llava/llava/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/llava/llava/__init__.py -------------------------------------------------------------------------------- /llava/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from .constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 
87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | 
new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 187 | 188 | def insert_separator(X, sep): 189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 190 | 191 | input_ids = [] 192 | offset = 0 193 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 194 | offset = 1 195 | input_ids.append(prompt_chunks[0][0]) 196 | 197 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 198 | input_ids.extend(x[offset:]) 199 | 200 | if return_tensors is not None: 201 | if return_tensors == 'pt': 202 | return torch.tensor(input_ids, dtype=torch.long) 203 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 204 | return input_ids 205 | 206 | 207 | def get_model_name_from_path(model_path): 208 | model_path = model_path.strip("/") 209 | model_paths = model_path.split("/") 210 | if model_paths[-1].startswith('checkpoint-'): 211 | return model_paths[-2] + "_" + model_paths[-1] 212 | else: 213 | return model_paths[-1] 214 | 215 | 216 | class KeywordsStoppingCriteria(StoppingCriteria): 217 | def __init__(self, keywords, tokenizer, input_ids): 218 | self.keywords = keywords 219 | self.keyword_ids = [] 220 | self.max_keyword_len = 0 221 | for keyword in keywords: 222 | cur_keyword_ids = tokenizer(keyword).input_ids 223 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 224 | cur_keyword_ids = cur_keyword_ids[1:] 225 | if len(cur_keyword_ids) > self.max_keyword_len: 226 | self.max_keyword_len = len(cur_keyword_ids) 227 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 228 | self.tokenizer = tokenizer 229 | self.start_len = input_ids.shape[1] 230 | 231 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 232 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 233 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 234 | for keyword_id in self.keyword_ids: 235 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 236 | if torch.equal(truncated_output_ids, keyword_id): 237 | return True 238 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 239 | for keyword in self.keywords: 240 | if keyword in outputs: 241 | return True 242 | return False 243 | 244 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 245 | outputs = [] 246 | for i in range(output_ids.shape[0]): 247 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 248 | return all(outputs) 249 | -------------------------------------------------------------------------------- /llava/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from .constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /llava/llava_data_vq_unified.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import torch 7 | from PIL import ImageFile 8 | 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | from torch.utils.data.distributed import DistributedSampler 13 | from training.utils import image_transform 14 | from llava.llava import conversation as conversation_lib 15 | 16 | DEFAULT_IMAGE_TOKEN = "" 17 | IGNORE_INDEX = -100 18 | conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"] 19 | SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \ 20 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 21 | 22 | def preprocess_multimodal(sources): 23 | for source in sources: 24 | for sentence in source: 25 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 26 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 27 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 28 | sentence['value'] = sentence['value'].strip() 29 | 30 | # Customized operation, get rid of special token. 
Edited by Zechen 31 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "") 32 | sentence['value'] = sentence['value'].strip() 33 | 34 | return sources 35 | 36 | 37 | def preprocess_v0( 38 | sources, 39 | tokenizer, 40 | ): 41 | # Let's assume has_image is false, since we will process the image token separately 42 | has_image = False 43 | 44 | # Adapted from llava-phi/mipha/train/train.py 45 | conv = conversation_lib.default_conversation.copy() 46 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 47 | 48 | # Apply prompt templates 49 | conversations = [] 50 | for i, source in enumerate(sources): 51 | if roles[source[0]["from"]] != conv.roles[0]: 52 | # Skip the first one if it is not from human 53 | source = source[1:] 54 | 55 | conv.messages = [] 56 | for j, sentence in enumerate(source): 57 | role = roles[sentence["from"]] 58 | assert role == conv.roles[j % 2] 59 | conv.append_message(role, sentence["value"]) 60 | conversation_str = str(conv.get_prompt()).strip() 61 | conversations.append(conversation_str) 62 | 63 | input_ids = tokenizer( 64 | conversations, 65 | return_tensors="pt", 66 | padding="longest", 67 | max_length=tokenizer.model_max_length, 68 | truncation=True, 69 | ).input_ids 70 | 71 | targets = input_ids.clone() 72 | 73 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 74 | 75 | # Mask targets 76 | sep = conv.sep + conv.roles[1] + ": " # ' ASSISTANT: ' 77 | for conversation, target in zip(conversations, targets): # loop for instances in a batch 78 | # total_len = int(target.ne(tokenizer.pad_token_id).sum()) + conversation.count(conv.sep2) # in phi-2, pad_token_id == eos_token_id 79 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 80 | 81 | rounds = conversation.split(conv.sep2) # handle multi-round conversation regarding one image 82 | cur_len = 0 # no bos token in phi, so set the initial len to 0 83 | if cur_len > 0: 84 | target[:cur_len] = IGNORE_INDEX 85 | for i, rou in enumerate(rounds): 86 | if rou == "": 87 | break 88 | 89 | parts = rou.split(sep) 90 | if len(parts) != 2: 91 | break 92 | parts[0] += sep 93 | 94 | round_len = len(tokenizer(rou).input_ids) + 1 # +1 for <|endoftext|> 95 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 96 | 97 | target[cur_len: cur_len + instruction_len] = IGNORE_INDEX 98 | 99 | cur_len += round_len 100 | target[cur_len:] = IGNORE_INDEX 101 | 102 | if cur_len < tokenizer.model_max_length: 103 | if cur_len != total_len: 104 | target[:] = IGNORE_INDEX 105 | print(conversation) 106 | print( 107 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
108 | f" (ignored)" 109 | ) 110 | 111 | input_ids_system = tokenizer( 112 | [SYSTEM_PROMPT for _ in range(len(conversations))], 113 | return_tensors="pt", 114 | padding="longest", 115 | max_length=tokenizer.model_max_length, 116 | truncation=True, 117 | ).input_ids 118 | 119 | return dict( 120 | input_ids=input_ids, 121 | labels=targets, 122 | input_ids_system=input_ids_system 123 | ) 124 | 125 | 126 | class LLaVADataset(Dataset): 127 | 128 | def __init__(self, 129 | tokenizer, 130 | phase, 131 | ): 132 | super(LLaVADataset, self).__init__() 133 | 134 | self.tokenizer = tokenizer 135 | 136 | if phase == "pretrain": 137 | data_file_path = "/mnt/bn/vgfm2/test_dit/blip_laion_cc_sbu_558k.json" 138 | self.image_root = "/mnt/bn/vgfm2/test_dit/pretraining_data" 139 | else: 140 | data_file_path = "/mnt/bn/vgfm2/test_dit/llava_v1_5_mix665k.json" 141 | self.image_root = "/mnt/bn/vgfm2/test_dit/tuning_data" 142 | 143 | with open(data_file_path, 'r') as f: 144 | data = json.load(f) 145 | self.list_data_dict = [] 146 | for item in data: 147 | if 'image' in item.keys(): 148 | self.list_data_dict.append(item) 149 | 150 | print("Formatting llava instruction data") 151 | 152 | def __len__(self): 153 | return len(self.list_data_dict) 154 | 155 | def __getitem__(self, i): 156 | sources = self.list_data_dict[i] 157 | if isinstance(i, int): 158 | sources = [sources] 159 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 160 | 161 | assert 'image' in sources[0] 162 | image_file = self.list_data_dict[i]['image'] 163 | image_folder = self.image_root 164 | try: 165 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 166 | image = image_transform(image) 167 | except: 168 | print("Read image error. Use dummy data.") 169 | crop_size = 256 170 | image = torch.zeros(3, crop_size, crop_size) 171 | 172 | sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources])) 173 | 174 | data_dict = preprocess_v0(sources, self.tokenizer) 175 | 176 | if isinstance(i, int): 177 | data_dict = dict(input_ids=data_dict["input_ids"][0], 178 | labels=data_dict["labels"][0], 179 | input_ids_system=data_dict["input_ids_system"][0]) 180 | 181 | # image exist in the data 182 | if 'image' in self.list_data_dict[i]: 183 | data_dict['image'] = image 184 | else: 185 | # image does not exist in the data, but the model is multimodal 186 | crop_size = 256 187 | data_dict['image'] = torch.zeros(3, crop_size, crop_size) 188 | 189 | return data_dict 190 | 191 | 192 | def collate_fn( 193 | instances, 194 | tokenizer=None, 195 | max_length=77, 196 | ): 197 | input_ids, labels, input_ids_system = tuple([instance[key] for instance in instances] 198 | for key in ("input_ids", "labels", "input_ids_system")) 199 | input_ids = torch.nn.utils.rnn.pad_sequence( 200 | input_ids, 201 | batch_first=True, 202 | padding_value=tokenizer.pad_token_id) 203 | labels = torch.nn.utils.rnn.pad_sequence(labels, 204 | batch_first=True, 205 | padding_value=IGNORE_INDEX) 206 | input_ids_system = torch.stack(input_ids_system, dim=0) 207 | 208 | offset = max_length - input_ids.shape[-1] - input_ids_system.shape[-1] 209 | 210 | if input_ids.shape[-1] < max_length - input_ids_system.shape[-1]: 211 | pad_tube = torch.ones(size=(input_ids.shape[0], offset), dtype=input_ids.dtype) * tokenizer.pad_token_id 212 | input_ids = torch.cat([input_ids, pad_tube], dim=1) 213 | 214 | pad_tube = torch.ones(size=(labels.shape[0], offset), dtype=labels.dtype) * IGNORE_INDEX 215 | labels = torch.cat([labels, pad_tube], dim=1) 
216 | 217 | min_max_len = min( 218 | max_length - input_ids_system.shape[-1], 219 | tokenizer.model_max_length - input_ids_system.shape[-1], 220 | ) 221 | 222 | input_ids = input_ids[:, :min_max_len] 223 | labels = labels[:, :min_max_len] 224 | batch = dict( 225 | input_ids=input_ids, 226 | labels=labels, 227 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 228 | input_ids_system=input_ids_system, 229 | ) 230 | 231 | if 'image' in instances[0]: 232 | images = [instance['image'] for instance in instances] 233 | if all(x is not None and x.shape == images[0].shape for x in images): 234 | batch['images'] = torch.stack(images) 235 | else: 236 | batch['images'] = images 237 | 238 | return batch 239 | 240 | 241 | def get_instruct_data_loader( 242 | tokenizer, 243 | batch_size, 244 | num_workers, 245 | world_size, 246 | local_rank, 247 | max_length, 248 | phase, 249 | ): 250 | train_dataset = LLaVADataset( 251 | tokenizer, 252 | phase, 253 | ) 254 | datasampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank) 255 | dataloader = torch.utils.data.DataLoader( 256 | train_dataset, 257 | batch_size=batch_size, 258 | num_workers=num_workers, 259 | pin_memory=True, 260 | collate_fn=partial( 261 | collate_fn, 262 | tokenizer=tokenizer, 263 | max_length=max_length, 264 | ), 265 | sampler=datasampler 266 | ) 267 | 268 | return dataloader 269 | 270 | 271 | if __name__ == '__main__': 272 | import transformers 273 | pretrained_model_path = '/mnt/bn/vgfm2/test_mlx/xavier/pretrained_weights/phi-1_5' 274 | tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_path, 275 | padding_side="left") 276 | special_tokens = ("soi", "eoi", "sovi", "eovi", "t2i", "mmu", "t2v", "v2v", "lvg") 277 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 278 | tokenizer.add_tokens(list(special_tokens)) 279 | 280 | dataset = LLaVADataset( 281 | tokenizer, 282 | "tuning" 283 | ) 284 | 285 | item = dataset.__getitem__(0) 286 | import pdb 287 | pdb.set_trace() 288 | 289 | -------------------------------------------------------------------------------- /llava/llava_instruct_data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import torch 7 | from PIL import ImageFile 8 | from transformers import CLIPImageProcessor 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | from llava.llava import conversation as conversation_lib 16 | 17 | DEFAULT_IMAGE_TOKEN = "" 18 | IGNORE_INDEX = -100 19 | conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"] 20 | SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \ 21 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 22 | 23 | def preprocess_multimodal(sources): 24 | for source in sources: 25 | for sentence in source: 26 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 27 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 28 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 29 | sentence['value'] = sentence['value'].strip() 30 | 31 | # Customized operation, get rid of special token. 
Edited by Zechen 32 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "") 33 | sentence['value'] = sentence['value'].strip() 34 | 35 | return sources 36 | 37 | 38 | def preprocess_v0( 39 | sources, 40 | tokenizer, 41 | ): 42 | # Let's assume has_image is false, since we will process the image token separately 43 | has_image = False 44 | 45 | # Adapted from llava-phi/mipha/train/train.py 46 | conv = conversation_lib.default_conversation.copy() 47 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 48 | 49 | # Apply prompt templates 50 | conversations = [] 51 | for i, source in enumerate(sources): 52 | if roles[source[0]["from"]] != conv.roles[0]: 53 | # Skip the first one if it is not from human 54 | source = source[1:] 55 | 56 | conv.messages = [] 57 | for j, sentence in enumerate(source): 58 | role = roles[sentence["from"]] 59 | assert role == conv.roles[j % 2] 60 | conv.append_message(role, sentence["value"]) 61 | conversation_str = str(conv.get_prompt()).strip() 62 | conversations.append(conversation_str) 63 | 64 | input_ids = tokenizer( 65 | conversations, 66 | return_tensors="pt", 67 | padding="longest", 68 | max_length=tokenizer.model_max_length, 69 | truncation=True, 70 | ).input_ids 71 | 72 | targets = input_ids.clone() 73 | 74 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 75 | 76 | # Mask targets 77 | sep = conv.sep + conv.roles[1] + ": " # ' ASSISTANT: ' 78 | for conversation, target in zip(conversations, targets): # loop for instances in a batch 79 | # total_len = int(target.ne(tokenizer.pad_token_id).sum()) + conversation.count(conv.sep2) # in phi-2, pad_token_id == eos_token_id 80 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 81 | 82 | rounds = conversation.split(conv.sep2) # handle multi-round conversation regarding one image 83 | cur_len = 0 # no bos token in phi, so set the initial len to 0 84 | if cur_len > 0: 85 | target[:cur_len] = IGNORE_INDEX 86 | for i, rou in enumerate(rounds): 87 | if rou == "": 88 | break 89 | 90 | parts = rou.split(sep) 91 | if len(parts) != 2: 92 | break 93 | parts[0] += sep 94 | 95 | # if has_image: 96 | # round_len = len(tokenizer_image_token(rou, tokenizer)) + 1 # +1 for <|endoftext|> 97 | # instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # -1 for 98 | # else: 99 | round_len = len(tokenizer(rou).input_ids) + 1 # +1 for <|endoftext|> 100 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 101 | 102 | target[cur_len: cur_len + instruction_len] = IGNORE_INDEX 103 | 104 | cur_len += round_len 105 | target[cur_len:] = IGNORE_INDEX 106 | 107 | if cur_len < tokenizer.model_max_length: 108 | if cur_len != total_len: 109 | target[:] = IGNORE_INDEX 110 | print(conversation) 111 | print( 112 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
113 | f" (ignored)" 114 | ) 115 | 116 | input_ids_system = tokenizer( 117 | [SYSTEM_PROMPT for _ in range(len(conversations))], 118 | return_tensors="pt", 119 | padding="longest", 120 | max_length=tokenizer.model_max_length, 121 | truncation=True, 122 | ).input_ids 123 | 124 | return dict( 125 | input_ids=input_ids, 126 | labels=targets, 127 | input_ids_system=input_ids_system 128 | ) 129 | 130 | 131 | class LLaVAInstructDataset(Dataset): 132 | 133 | def __init__(self, tokenizer): 134 | super(LLaVAInstructDataset, self).__init__() 135 | 136 | self.tokenizer = tokenizer 137 | 138 | data_file_path = "/mnt/bn/vgfm2/test_dit/llava_v1_5_mix665k.json" 139 | self.image_root = "/mnt/bn/vgfm2/test_dit/tuning_data" 140 | 141 | with open(data_file_path, 'r') as f: 142 | data = json.load(f) 143 | self.list_data_dict = [] 144 | for item in data: 145 | if 'image' in item.keys(): 146 | self.list_data_dict.append(item) 147 | 148 | self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336") 149 | 150 | print("Formatting llava instruction data") 151 | 152 | def __len__(self): 153 | return len(self.list_data_dict) 154 | 155 | def __getitem__(self, i): 156 | sources = self.list_data_dict[i] 157 | if isinstance(i, int): 158 | sources = [sources] 159 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 160 | 161 | assert 'image' in sources[0] 162 | image_file = self.list_data_dict[i]['image'] 163 | image_folder = self.image_root 164 | try: 165 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 166 | image = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 167 | except: 168 | print("Read image error. Use dummy data.") 169 | crop_size = 336 170 | image = torch.zeros(3, crop_size, crop_size) 171 | 172 | sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources])) 173 | 174 | data_dict = preprocess_v0(sources, self.tokenizer) 175 | 176 | if isinstance(i, int): 177 | data_dict = dict(input_ids=data_dict["input_ids"][0], 178 | labels=data_dict["labels"][0], 179 | input_ids_system=data_dict["input_ids_system"][0]) 180 | 181 | # image exist in the data 182 | if 'image' in self.list_data_dict[i]: 183 | data_dict['image'] = image 184 | else: 185 | # image does not exist in the data, but the model is multimodal 186 | crop_size = 336 187 | data_dict['image'] = torch.zeros(3, crop_size, crop_size) 188 | 189 | return data_dict 190 | 191 | 192 | def collate_fn( 193 | instances, 194 | tokenizer=None, 195 | max_length=77, 196 | ): 197 | input_ids, labels, input_ids_system = tuple([instance[key] for instance in instances] 198 | for key in ("input_ids", "labels", "input_ids_system")) 199 | input_ids = torch.nn.utils.rnn.pad_sequence( 200 | input_ids, 201 | batch_first=True, 202 | padding_value=tokenizer.pad_token_id) 203 | labels = torch.nn.utils.rnn.pad_sequence(labels, 204 | batch_first=True, 205 | padding_value=IGNORE_INDEX) 206 | input_ids_system = torch.stack(input_ids_system, dim=0) 207 | 208 | offset = max_length - input_ids.shape[-1] - input_ids_system.shape[-1] 209 | 210 | if input_ids.shape[-1] < max_length - input_ids_system.shape[-1]: 211 | pad_tube = torch.ones(size=(input_ids.shape[0], offset), dtype=input_ids.dtype) * tokenizer.pad_token_id 212 | input_ids = torch.cat([input_ids, pad_tube], dim=1) 213 | 214 | pad_tube = torch.ones(size=(labels.shape[0], offset), dtype=labels.dtype) * IGNORE_INDEX 215 | labels = torch.cat([labels, pad_tube], dim=1) 216 | 217 | min_max_len = min( 218 | 
max_length - input_ids_system.shape[-1], 219 | tokenizer.model_max_length - input_ids_system.shape[-1], 220 | ) 221 | 222 | input_ids = input_ids[:, :min_max_len] 223 | labels = labels[:, :min_max_len] 224 | batch = dict( 225 | input_ids=input_ids, 226 | labels=labels, 227 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 228 | input_ids_system=input_ids_system, 229 | ) 230 | 231 | if 'image' in instances[0]: 232 | images = [instance['image'] for instance in instances] 233 | if all(x is not None and x.shape == images[0].shape for x in images): 234 | batch['images'] = torch.stack(images) 235 | else: 236 | batch['images'] = images 237 | 238 | return batch 239 | 240 | 241 | def get_instruct_data_loader( 242 | tokenizer, 243 | batch_size, 244 | num_workers, 245 | world_size, 246 | local_rank, 247 | max_length, 248 | ): 249 | train_dataset = LLaVAInstructDataset(tokenizer) 250 | datasampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank) 251 | dataloader = torch.utils.data.DataLoader( 252 | train_dataset, 253 | batch_size=batch_size, 254 | num_workers=num_workers, 255 | pin_memory=True, 256 | collate_fn=partial( 257 | collate_fn, 258 | tokenizer=tokenizer, 259 | max_length=max_length, 260 | ), 261 | sampler=datasampler 262 | ) 263 | 264 | return dataloader 265 | 266 | 267 | if __name__ == '__main__': 268 | import transformers 269 | pretrained_model_path = '/mnt/bn/vgfm2/test_mlx/xavier/pretrained_weights/phi-1_5' 270 | tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_path, 271 | padding_side="left") 272 | special_tokens = ("soi", "eoi", "sovi", "eovi", "t2i", "mmu", "t2v", "v2v", "lvg") 273 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 274 | tokenizer.add_tokens(list(special_tokens)) 275 | 276 | dataset = LLaVAInstructDataset(tokenizer) 277 | 278 | dataset.__getitem__(0) 279 | 280 | -------------------------------------------------------------------------------- /llava/llava_pretrain_data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import torch 7 | from PIL import Image 8 | from llava.llava import conversation as conversation_lib 9 | from torch.utils.data import Dataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | from transformers import CLIPImageProcessor 12 | 13 | DEFAULT_IMAGE_TOKEN = "" 14 | IGNORE_INDEX = -100 15 | conversation_lib.default_conversation = conversation_lib.conv_templates["plain"] 16 | 17 | def preprocess_multimodal(sources): 18 | for source in sources: 19 | for sentence in source: 20 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 21 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 22 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 23 | sentence['value'] = sentence['value'].strip() 24 | 25 | # Customized operation, get rid of special token. 
Edited by Zechen 26 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "") 27 | sentence['value'] = sentence['value'].strip() 28 | 29 | return sources 30 | 31 | 32 | def preprocess_plain(sources, tokenizer): 33 | # add end signal and concatenate together 34 | conversations = [] 35 | for source in sources: 36 | assert len(source) == 2 37 | # assert DEFAULT_IMAGE_TOKEN in source[0]['value'] 38 | # source[0]['value'] = DEFAULT_IMAGE_TOKEN 39 | source[0]['value'] = "" 40 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep 41 | conversations.append(conversation) 42 | 43 | # tokenize conversations 44 | # input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 45 | input_ids = [tokenizer(prompt)["input_ids"] + [tokenizer.eos_token_id] for prompt in conversations] 46 | targets = copy.deepcopy(input_ids) 47 | 48 | for target, source in zip(targets, sources): 49 | # tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) 50 | tokenized_len = len(tokenizer(source[0]['value'])["input_ids"]) 51 | if tokenized_len > 0: 52 | target[:tokenized_len] = IGNORE_INDEX 53 | 54 | return dict(input_ids=torch.tensor(input_ids), labels=torch.tensor(targets)) 55 | 56 | 57 | class LLaVAPretrainCaptioningDataset(Dataset): 58 | 59 | def __init__(self, tokenizer): 60 | super(LLaVAPretrainCaptioningDataset, self).__init__() 61 | 62 | self.tokenizer = tokenizer 63 | 64 | data_file_path = "/mnt/bn/vgfm2/test_dit/blip_laion_cc_sbu_558k.json" 65 | self.image_root = "/mnt/bn/vgfm2/test_dit/pretraining_data" 66 | 67 | with open(data_file_path, 'r') as f: 68 | data = json.load(f) 69 | self.list_data_dict = [] 70 | for item in data: 71 | if 'image' in item.keys(): 72 | self.list_data_dict.append(item) 73 | 74 | self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336") 75 | 76 | print("Formatting llava captioning data") 77 | 78 | def __len__(self): 79 | return len(self.list_data_dict) 80 | 81 | def __getitem__(self, i): 82 | sources = self.list_data_dict[i] 83 | if isinstance(i, int): 84 | sources = [sources] 85 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 86 | 87 | assert 'image' in sources[0] 88 | image_file = self.list_data_dict[i]['image'] 89 | image_folder = self.image_root 90 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 91 | image = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 92 | 93 | sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources])) 94 | 95 | data_dict = preprocess_plain(sources, self.tokenizer) 96 | 97 | if isinstance(i, int): 98 | data_dict = dict(input_ids=data_dict["input_ids"][0], 99 | labels=data_dict["labels"][0]) 100 | 101 | # image exist in the data 102 | if 'image' in self.list_data_dict[i]: 103 | data_dict['image'] = image 104 | else: 105 | # image does not exist in the data, but the model is multimodal 106 | crop_size = 256 107 | data_dict['image'] = torch.zeros(3, crop_size, crop_size) 108 | 109 | return data_dict 110 | 111 | 112 | def collate_fn( 113 | instances, 114 | tokenizer=None, 115 | max_length=77, 116 | ): 117 | input_ids, labels = tuple([instance[key] for instance in instances] 118 | for key in ("input_ids", "labels")) 119 | input_ids = torch.nn.utils.rnn.pad_sequence( 120 | input_ids, 121 | batch_first=True, 122 | padding_value=tokenizer.pad_token_id) 123 | labels = torch.nn.utils.rnn.pad_sequence(labels, 124 | 
batch_first=True, 125 | padding_value=IGNORE_INDEX) 126 | 127 | if input_ids.shape[-1] < max_length: 128 | offset = max_length - input_ids.shape[-1] 129 | pad_tube = torch.ones(size=(input_ids.shape[0], offset), dtype=input_ids.dtype) * tokenizer.pad_token_id 130 | input_ids = torch.cat([input_ids, pad_tube], dim=1) 131 | 132 | offset = max_length - labels.shape[-1] 133 | pad_tube = torch.ones(size=(labels.shape[0], offset), dtype=labels.dtype) * IGNORE_INDEX 134 | labels = torch.cat([labels, pad_tube], dim=1) 135 | 136 | min_max_len = min(max_length, tokenizer.model_max_length) 137 | 138 | input_ids = input_ids[:, :min_max_len] 139 | labels = labels[:, :min_max_len] 140 | batch = dict( 141 | input_ids=input_ids, 142 | labels=labels, 143 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 144 | ) 145 | 146 | if 'image' in instances[0]: 147 | images = [instance['image'] for instance in instances] 148 | if all(x is not None and x.shape == images[0].shape for x in images): 149 | batch['images'] = torch.stack(images) 150 | else: 151 | batch['images'] = images 152 | 153 | return batch 154 | 155 | 156 | def get_plain_data_loader( 157 | tokenizer, 158 | batch_size, 159 | num_workers, 160 | world_size, 161 | local_rank, 162 | max_length, 163 | ): 164 | train_dataset = LLaVAPretrainCaptioningDataset(tokenizer) 165 | datasampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank) 166 | dataloader = torch.utils.data.DataLoader( 167 | train_dataset, 168 | batch_size=batch_size, 169 | num_workers=num_workers, 170 | pin_memory=True, 171 | collate_fn=partial( 172 | collate_fn, 173 | tokenizer=tokenizer, 174 | max_length=max_length, 175 | ), 176 | sampler=datasampler 177 | ) 178 | 179 | return dataloader 180 | 181 | 182 | if __name__ == '__main__': 183 | import transformers 184 | pretrained_model_path = '/mnt/bn/vgfm2/test_mlx/xavier/pretrained_weights/phi-1_5' 185 | tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_path, 186 | padding_side="left") 187 | special_tokens = ("soi", "eoi", "sovi", "eovi", "t2i", "mmu", "t2v", "v2v", "lvg") 188 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 189 | tokenizer.add_tokens(list(special_tokens)) 190 | 191 | dataset = LLaVAPretrainCaptioningDataset(tokenizer) 192 | 193 | dataset.__getitem__(0) 194 | 195 | -------------------------------------------------------------------------------- /mmu_validation/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/mmu_validation/dog.png -------------------------------------------------------------------------------- /mmu_validation/sofa_under_water.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/mmu_validation/sofa_under_water.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_showo import Showo 2 | from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2 3 | from .sampling import * 4 | from .clip_encoder import CLIPVisionTower -------------------------------------------------------------------------------- /models/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | class CLIPVisionTower(nn.Module): 7 | def __init__(self, vision_tower): 8 | super().__init__() 9 | 10 | self.is_loaded = False 11 | 12 | self.vision_tower_name = vision_tower 13 | self.select_layer = -2 14 | self.select_feature = "patch" 15 | self.load_model() 16 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 17 | 18 | def load_model(self, device_map=None): 19 | if self.is_loaded: 20 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 21 | return 22 | 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches_per_side(self): 78 | return self.config.image_size // self.config.patch_size 79 | 80 | @property 81 | def num_patches(self): 82 | return (self.config.image_size // self.config.patch_size) ** 2 83 | 84 | 85 | class CLIPVisionTowerS2(CLIPVisionTower): 86 | def __init__(self, vision_tower, args, delay_load=False): 87 | super().__init__(vision_tower, args, delay_load) 88 | 89 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 90 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 91 | self.s2_scales.sort() 92 | self.s2_split_size = self.s2_scales[0] 93 | self.s2_image_size = self.s2_scales[-1] 94 | 95 | try: 96 | from s2wrapper import forward as multiscale_forward 97 | except ImportError: 98 | raise ImportError('Package s2wrapper not found! 
Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') 99 | self.multiscale_forward = multiscale_forward 100 | 101 | # change resize/crop size in preprocessing to the largest image size in s2_scale 102 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 103 | self.image_processor.size['shortest_edge'] = self.s2_image_size 104 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 105 | 106 | def load_model(self, device_map=None): 107 | if self.is_loaded: 108 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 109 | return 110 | 111 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 112 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 113 | self.vision_tower.requires_grad_(False) 114 | 115 | self.image_processor.size['shortest_edge'] = self.s2_image_size 116 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 117 | 118 | self.is_loaded = True 119 | 120 | @torch.no_grad() 121 | def forward_feature(self, images): 122 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 123 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 124 | return image_features 125 | 126 | @torch.no_grad() 127 | def forward(self, images): 128 | if type(images) is list: 129 | image_features = [] 130 | for image in images: 131 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 132 | image_features.append(image_feature) 133 | else: 134 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 135 | 136 | return image_features 137 | 138 | @property 139 | def hidden_size(self): 140 | return self.config.hidden_size * len(self.s2_scales) 141 | -------------------------------------------------------------------------------- /models/common_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34 3 | """ 4 | 5 | import math 6 | from typing import Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange, repeat 13 | from einops.layers.torch import Rearrange 14 | 15 | 16 | def nonlinearity(x): 17 | # swish 18 | return x * torch.sigmoid(x) 19 | 20 | 21 | def Normalize(in_channels): 22 | return torch.nn.GroupNorm( 23 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 24 | ) 25 | 26 | 27 | class Upsample(nn.Module): 28 | def __init__(self, in_channels, with_conv): 29 | super().__init__() 30 | self.with_conv = with_conv 31 | if self.with_conv: 32 | self.conv = torch.nn.Conv2d( 33 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 34 | ) 35 | 36 | def forward(self, x): 37 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 38 | if self.with_conv: 39 | x = self.conv(x) 40 | return x 41 | 42 | 43 | class DepthToSpaceUpsample(nn.Module): 44 | def __init__( 45 | self, 46 | in_channels, 47 | ): 48 | super().__init__() 49 | conv = nn.Conv2d(in_channels, in_channels * 4, 1) 50 | 51 | 
self.net = nn.Sequential( 52 | conv, 53 | nn.SiLU(), 54 | Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2), 55 | ) 56 | 57 | self.init_conv_(conv) 58 | 59 | def init_conv_(self, conv): 60 | o, i, h, w = conv.weight.shape 61 | conv_weight = torch.empty(o // 4, i, h, w) 62 | nn.init.kaiming_uniform_(conv_weight) 63 | conv_weight = repeat(conv_weight, "o ... -> (o 4) ...") 64 | 65 | conv.weight.data.copy_(conv_weight) 66 | nn.init.zeros_(conv.bias.data) 67 | 68 | def forward(self, x): 69 | out = self.net(x) 70 | return out 71 | 72 | 73 | class Downsample(nn.Module): 74 | def __init__(self, in_channels, with_conv): 75 | super().__init__() 76 | self.with_conv = with_conv 77 | if self.with_conv: 78 | # no asymmetric padding in torch conv, must do it ourselves 79 | self.conv = torch.nn.Conv2d( 80 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 81 | ) 82 | 83 | def forward(self, x): 84 | if self.with_conv: 85 | pad = (0, 1, 0, 1) 86 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 87 | x = self.conv(x) 88 | else: 89 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 90 | return x 91 | 92 | 93 | def unpack_time(t, batch): 94 | _, c, w, h = t.size() 95 | out = torch.reshape(t, [batch, -1, c, w, h]) 96 | out = rearrange(out, "b t c h w -> b c t h w") 97 | return out 98 | 99 | 100 | def pack_time(t): 101 | out = rearrange(t, "b c t h w -> b t c h w") 102 | _, _, c, w, h = out.size() 103 | return torch.reshape(out, [-1, c, w, h]) 104 | 105 | 106 | class TimeDownsample2x(nn.Module): 107 | def __init__( 108 | self, 109 | dim, 110 | dim_out=None, 111 | kernel_size=3, 112 | ): 113 | super().__init__() 114 | if dim_out is None: 115 | dim_out = dim 116 | self.time_causal_padding = (kernel_size - 1, 0) 117 | self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2) 118 | 119 | def forward(self, x): 120 | x = rearrange(x, "b c t h w -> b h w c t") 121 | b, h, w, c, t = x.size() 122 | x = torch.reshape(x, [-1, c, t]) 123 | 124 | x = F.pad(x, self.time_causal_padding) 125 | out = self.conv(x) 126 | 127 | out = torch.reshape(out, [b, h, w, c, t]) 128 | out = rearrange(out, "b h w c t -> b c t h w") 129 | out = rearrange(out, "b h w c t -> b c t h w") 130 | return out 131 | 132 | 133 | class TimeUpsample2x(nn.Module): 134 | def __init__(self, dim, dim_out=None): 135 | super().__init__() 136 | if dim_out is None: 137 | dim_out = dim 138 | conv = nn.Conv1d(dim, dim_out * 2, 1) 139 | 140 | self.net = nn.Sequential( 141 | nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2) 142 | ) 143 | 144 | self.init_conv_(conv) 145 | 146 | def init_conv_(self, conv): 147 | o, i, t = conv.weight.shape 148 | conv_weight = torch.empty(o // 2, i, t) 149 | nn.init.kaiming_uniform_(conv_weight) 150 | conv_weight = repeat(conv_weight, "o ... 
-> (o 2) ...") 151 | 152 | conv.weight.data.copy_(conv_weight) 153 | nn.init.zeros_(conv.bias.data) 154 | 155 | def forward(self, x): 156 | x = rearrange(x, "b c t h w -> b h w c t") 157 | b, h, w, c, t = x.size() 158 | x = torch.reshape(x, [-1, c, t]) 159 | 160 | out = self.net(x) 161 | out = out[:, :, 1:].contiguous() 162 | 163 | out = torch.reshape(out, [b, h, w, c, t]) 164 | out = rearrange(out, "b h w c t -> b c t h w") 165 | return out 166 | 167 | 168 | class AttnBlock(nn.Module): 169 | def __init__(self, in_channels): 170 | super().__init__() 171 | self.in_channels = in_channels 172 | 173 | self.norm = Normalize(in_channels) 174 | self.q = torch.nn.Conv2d( 175 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 176 | ) 177 | self.k = torch.nn.Conv2d( 178 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 179 | ) 180 | self.v = torch.nn.Conv2d( 181 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 182 | ) 183 | self.proj_out = torch.nn.Conv2d( 184 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 185 | ) 186 | 187 | def forward(self, x): 188 | h_ = x 189 | h_ = self.norm(h_) 190 | q = self.q(h_) 191 | k = self.k(h_) 192 | v = self.v(h_) 193 | 194 | # compute attention 195 | b, c, h, w = q.shape 196 | q = q.reshape(b, c, h * w) 197 | q = q.permute(0, 2, 1) # b,hw,c 198 | k = k.reshape(b, c, h * w) # b,c,hw 199 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 200 | w_ = w_ * (int(c) ** (-0.5)) 201 | w_ = torch.nn.functional.softmax(w_, dim=2) 202 | 203 | # attend to values 204 | v = v.reshape(b, c, h * w) 205 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 206 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 207 | h_ = h_.reshape(b, c, h, w) 208 | 209 | h_ = self.proj_out(h_) 210 | 211 | return x + h_ 212 | 213 | 214 | class TimeAttention(AttnBlock): 215 | def forward(self, x, *args, **kwargs): 216 | x = rearrange(x, "b c t h w -> b h w t c") 217 | b, h, w, t, c = x.size() 218 | x = torch.reshape(x, (-1, t, c)) 219 | 220 | x = super().forward(x, *args, **kwargs) 221 | 222 | x = torch.reshape(x, [b, h, w, t, c]) 223 | return rearrange(x, "b h w t c -> b c t h w") 224 | 225 | 226 | class Residual(nn.Module): 227 | def __init__(self, fn: nn.Module): 228 | super().__init__() 229 | self.fn = fn 230 | 231 | def forward(self, x, **kwargs): 232 | return self.fn(x, **kwargs) + x 233 | 234 | 235 | def cast_tuple(t, length=1): 236 | return t if isinstance(t, tuple) else ((t,) * length) 237 | 238 | 239 | class CausalConv3d(nn.Module): 240 | def __init__( 241 | self, 242 | chan_in, 243 | chan_out, 244 | kernel_size: Union[int, Tuple[int, int, int]], 245 | pad_mode="constant", 246 | **kwargs 247 | ): 248 | super().__init__() 249 | kernel_size = cast_tuple(kernel_size, 3) 250 | 251 | time_kernel_size, height_kernel_size, width_kernel_size = kernel_size 252 | 253 | dilation = kwargs.pop("dilation", 1) 254 | stride = kwargs.pop("stride", 1) 255 | 256 | self.pad_mode = pad_mode 257 | time_pad = dilation * (time_kernel_size - 1) + (1 - stride) 258 | height_pad = height_kernel_size // 2 259 | width_pad = width_kernel_size // 2 260 | 261 | self.time_pad = time_pad 262 | self.time_causal_padding = ( 263 | width_pad, 264 | width_pad, 265 | height_pad, 266 | height_pad, 267 | time_pad, 268 | 0, 269 | ) 270 | 271 | stride = (stride, 1, 1) 272 | dilation = (dilation, 1, 1) 273 | self.conv = nn.Conv3d( 274 | chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs 275 | ) 276 | 277 
| def forward(self, x): 278 | pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant" 279 | 280 | x = F.pad(x, self.time_causal_padding, mode=pad_mode) 281 | return self.conv(x) 282 | 283 | 284 | def ResnetBlockCausal3D( 285 | dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant" 286 | ): 287 | net = nn.Sequential( 288 | Normalize(dim), 289 | nn.SiLU(), 290 | CausalConv3d(dim, dim, kernel_size, pad_mode), 291 | Normalize(dim), 292 | nn.SiLU(), 293 | CausalConv3d(dim, dim, kernel_size, pad_mode), 294 | ) 295 | return Residual(net) 296 | 297 | 298 | class ResnetBlock(nn.Module): 299 | def __init__( 300 | self, 301 | *, 302 | in_channels, 303 | out_channels=None, 304 | conv_shortcut=False, 305 | dropout, 306 | temb_channels=512 307 | ): 308 | super().__init__() 309 | self.in_channels = in_channels 310 | out_channels = in_channels if out_channels is None else out_channels 311 | self.out_channels = out_channels 312 | self.use_conv_shortcut = conv_shortcut 313 | 314 | self.norm1 = Normalize(in_channels) 315 | self.conv1 = torch.nn.Conv2d( 316 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 317 | ) 318 | if temb_channels > 0: 319 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 320 | else: 321 | self.temb_proj = None 322 | self.norm2 = Normalize(out_channels) 323 | self.dropout = torch.nn.Dropout(dropout) 324 | self.conv2 = torch.nn.Conv2d( 325 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 326 | ) 327 | if self.in_channels != self.out_channels: 328 | if self.use_conv_shortcut: 329 | self.conv_shortcut = torch.nn.Conv2d( 330 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 331 | ) 332 | else: 333 | self.nin_shortcut = torch.nn.Conv2d( 334 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 335 | ) 336 | 337 | def forward(self, x, temb): 338 | h = x 339 | h = self.norm1(h) 340 | h = nonlinearity(h) 341 | h = self.conv1(h) 342 | 343 | if temb is not None: 344 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 345 | 346 | h = self.norm2(h) 347 | h = nonlinearity(h) 348 | h = self.dropout(h) 349 | h = self.conv2(h) 350 | 351 | if self.in_channels != self.out_channels: 352 | if self.use_conv_shortcut: 353 | x = self.conv_shortcut(x) 354 | else: 355 | x = self.nin_shortcut(x) 356 | 357 | return x + h 358 | -------------------------------------------------------------------------------- /models/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
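# Usage sketch (assumed import path; the helpers below mirror the HuggingFace-style
# logging API this module is adapted from):
#   from models import logging
#   logger = logging.get_logger(__name__)
#   logging.set_verbosity_info()   # or export muse_VERBOSITY=info
#   logger.info("...")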
15 | """ Logging utilities.""" 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | from tqdm import auto as tqdm_lib 32 | 33 | _lock = threading.Lock() 34 | _default_handler: Optional[logging.Handler] = None 35 | 36 | log_levels = { 37 | "debug": logging.DEBUG, 38 | "info": logging.INFO, 39 | "warning": logging.WARNING, 40 | "error": logging.ERROR, 41 | "critical": logging.CRITICAL, 42 | } 43 | 44 | _default_log_level = logging.WARNING 45 | 46 | _tqdm_active = True 47 | 48 | 49 | def _get_default_logging_level(): 50 | """ 51 | If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is 52 | not - fall back to `_default_log_level` 53 | """ 54 | env_level_str = os.getenv("muse_VERBOSITY", None) 55 | if env_level_str: 56 | if env_level_str in log_levels: 57 | return log_levels[env_level_str] 58 | else: 59 | logging.getLogger().warning( 60 | f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" 61 | ) 62 | return _default_log_level 63 | 64 | 65 | def _get_library_name() -> str: 66 | return __name__.split(".")[0] 67 | 68 | 69 | def _get_library_root_logger() -> logging.Logger: 70 | return logging.getLogger(_get_library_name()) 71 | 72 | 73 | def _configure_library_root_logger() -> None: 74 | global _default_handler 75 | 76 | with _lock: 77 | if _default_handler: 78 | # This library has already configured the library root logger. 79 | return 80 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 81 | _default_handler.flush = sys.stderr.flush 82 | 83 | # Apply our default configuration to the library root logger. 84 | library_root_logger = _get_library_root_logger() 85 | library_root_logger.addHandler(_default_handler) 86 | library_root_logger.setLevel(_get_default_logging_level()) 87 | library_root_logger.propagate = False 88 | 89 | 90 | def _reset_library_root_logger() -> None: 91 | global _default_handler 92 | 93 | with _lock: 94 | if not _default_handler: 95 | return 96 | 97 | library_root_logger = _get_library_root_logger() 98 | library_root_logger.removeHandler(_default_handler) 99 | library_root_logger.setLevel(logging.NOTSET) 100 | _default_handler = None 101 | 102 | 103 | def get_log_levels_dict(): 104 | return log_levels 105 | 106 | 107 | def get_logger(name: Optional[str] = None) -> logging.Logger: 108 | """ 109 | Return a logger with the specified name. 110 | 111 | This function is not supposed to be directly accessed unless you are writing a custom muse module. 112 | """ 113 | 114 | if name is None: 115 | name = _get_library_name() 116 | 117 | _configure_library_root_logger() 118 | return logging.getLogger(name) 119 | 120 | 121 | def get_verbosity() -> int: 122 | """ 123 | Return the current level for the 🤗 muse' root logger as an int. 124 | 125 | Returns: 126 | `int`: The logging level. 
127 | 128 | 129 | 130 | 🤗 muse has following logging levels: 131 | 132 | - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL` 133 | - 40: `muse.logging.ERROR` 134 | - 30: `muse.logging.WARNING` or `muse.logging.WARN` 135 | - 20: `muse.logging.INFO` 136 | - 10: `muse.logging.DEBUG` 137 | 138 | """ 139 | 140 | _configure_library_root_logger() 141 | return _get_library_root_logger().getEffectiveLevel() 142 | 143 | 144 | def set_verbosity(verbosity: int) -> None: 145 | """ 146 | Set the verbosity level for the 🤗 muse' root logger. 147 | 148 | Args: 149 | verbosity (`int`): 150 | Logging level, e.g., one of: 151 | 152 | - `muse.logging.CRITICAL` or `muse.logging.FATAL` 153 | - `muse.logging.ERROR` 154 | - `muse.logging.WARNING` or `muse.logging.WARN` 155 | - `muse.logging.INFO` 156 | - `muse.logging.DEBUG` 157 | """ 158 | 159 | _configure_library_root_logger() 160 | _get_library_root_logger().setLevel(verbosity) 161 | 162 | 163 | def set_verbosity_info(): 164 | """Set the verbosity to the `INFO` level.""" 165 | return set_verbosity(INFO) 166 | 167 | 168 | def set_verbosity_warning(): 169 | """Set the verbosity to the `WARNING` level.""" 170 | return set_verbosity(WARNING) 171 | 172 | 173 | def set_verbosity_debug(): 174 | """Set the verbosity to the `DEBUG` level.""" 175 | return set_verbosity(DEBUG) 176 | 177 | 178 | def set_verbosity_error(): 179 | """Set the verbosity to the `ERROR` level.""" 180 | return set_verbosity(ERROR) 181 | 182 | 183 | def disable_default_handler() -> None: 184 | """Disable the default handler of the HuggingFace muse' root logger.""" 185 | 186 | _configure_library_root_logger() 187 | 188 | assert _default_handler is not None 189 | _get_library_root_logger().removeHandler(_default_handler) 190 | 191 | 192 | def enable_default_handler() -> None: 193 | """Enable the default handler of the HuggingFace muse' root logger.""" 194 | 195 | _configure_library_root_logger() 196 | 197 | assert _default_handler is not None 198 | _get_library_root_logger().addHandler(_default_handler) 199 | 200 | 201 | def add_handler(handler: logging.Handler) -> None: 202 | """adds a handler to the HuggingFace muse' root logger.""" 203 | 204 | _configure_library_root_logger() 205 | 206 | assert handler is not None 207 | _get_library_root_logger().addHandler(handler) 208 | 209 | 210 | def remove_handler(handler: logging.Handler) -> None: 211 | """removes given handler from the HuggingFace muse' root logger.""" 212 | 213 | _configure_library_root_logger() 214 | 215 | assert handler is not None and handler not in _get_library_root_logger().handlers 216 | _get_library_root_logger().removeHandler(handler) 217 | 218 | 219 | def disable_propagation() -> None: 220 | """ 221 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 222 | """ 223 | 224 | _configure_library_root_logger() 225 | _get_library_root_logger().propagate = False 226 | 227 | 228 | def enable_propagation() -> None: 229 | """ 230 | Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent 231 | double logging if the root logger has been configured. 232 | """ 233 | 234 | _configure_library_root_logger() 235 | _get_library_root_logger().propagate = True 236 | 237 | 238 | def enable_explicit_format() -> None: 239 | """ 240 | Enable explicit formatting for every HuggingFace muse' logger. 
The explicit formatter is as follows: 241 | ``` 242 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 243 | ``` 244 | All handlers currently bound to the root logger are affected by this method. 245 | """ 246 | handlers = _get_library_root_logger().handlers 247 | 248 | for handler in handlers: 249 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 250 | handler.setFormatter(formatter) 251 | 252 | 253 | def reset_format() -> None: 254 | """ 255 | Resets the formatting for HuggingFace muse' loggers. 256 | 257 | All handlers currently bound to the root logger are affected by this method. 258 | """ 259 | handlers = _get_library_root_logger().handlers 260 | 261 | for handler in handlers: 262 | handler.setFormatter(None) 263 | 264 | 265 | def warning_advice(self, *args, **kwargs): 266 | """ 267 | This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this 268 | warning will not be printed 269 | """ 270 | no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False) 271 | if no_advisory_warnings: 272 | return 273 | self.warning(*args, **kwargs) 274 | 275 | 276 | logging.Logger.warning_advice = warning_advice 277 | 278 | 279 | class EmptyTqdm: 280 | """Dummy tqdm which doesn't do anything.""" 281 | 282 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument 283 | self._iterator = args[0] if args else None 284 | 285 | def __iter__(self): 286 | return iter(self._iterator) 287 | 288 | def __getattr__(self, _): 289 | """Return empty function.""" 290 | 291 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument 292 | return 293 | 294 | return empty_fn 295 | 296 | def __enter__(self): 297 | return self 298 | 299 | def __exit__(self, type_, value, traceback): 300 | return 301 | 302 | 303 | class _tqdm_cls: 304 | def __call__(self, *args, **kwargs): 305 | if _tqdm_active: 306 | return tqdm_lib.tqdm(*args, **kwargs) 307 | else: 308 | return EmptyTqdm(*args, **kwargs) 309 | 310 | def set_lock(self, *args, **kwargs): 311 | self._lock = None 312 | if _tqdm_active: 313 | return tqdm_lib.tqdm.set_lock(*args, **kwargs) 314 | 315 | def get_lock(self): 316 | if _tqdm_active: 317 | return tqdm_lib.tqdm.get_lock() 318 | 319 | 320 | tqdm = _tqdm_cls() 321 | 322 | 323 | def is_progress_bar_enabled() -> bool: 324 | """Return a boolean indicating whether tqdm progress bars are enabled.""" 325 | global _tqdm_active 326 | return bool(_tqdm_active) 327 | 328 | 329 | def enable_progress_bar(): 330 | """Enable tqdm progress bar.""" 331 | global _tqdm_active 332 | _tqdm_active = True 333 | 334 | 335 | def disable_progress_bar(): 336 | """Disable tqdm progress bar.""" 337 | global _tqdm_active 338 | _tqdm_active = False 339 | -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import torch 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Dict, 7 | Iterable, 8 | List, 9 | NamedTuple, 10 | NewType, 11 | Optional, 12 | Sized, 13 | Tuple, 14 | Type, 15 | TypeVar, 16 | Union, 17 | ) 18 | try: 19 | from typing import Literal 20 | except ImportError: 21 | from typing_extensions import Literal 22 | 23 | # Tensor dtype 24 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 25 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 26 | 27 | # Config type 28 | from omegaconf import 
DictConfig 29 | 30 | # PyTorch Tensor type 31 | from torch import Tensor 32 | 33 | # Runtime type checking decorator 34 | from typeguard import typechecked as typechecker 35 | 36 | 37 | def broadcast(tensor, src=0): 38 | if not _distributed_available(): 39 | return tensor 40 | else: 41 | torch.distributed.broadcast(tensor, src=src) 42 | return tensor 43 | 44 | def _distributed_available(): 45 | return torch.distributed.is_available() and torch.distributed.is_initialized() 46 | 47 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 48 | # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword 49 | if '--local-rank' in cfg: 50 | del cfg['--local-rank'] 51 | # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword 52 | scfg = OmegaConf.structured(fields(**cfg)) 53 | return scfg -------------------------------------------------------------------------------- /models/sampling.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/lucidrains/muse-maskgit-pytorch 2 | 3 | import math 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def log(t, eps=1e-20): 11 | return torch.log(t.clamp(min=eps)) 12 | 13 | 14 | def gumbel_noise(t, generator=None): 15 | noise = torch.zeros_like(t).uniform_(0, 1, generator=generator) 16 | return -log(-log(noise)) 17 | 18 | 19 | def gumbel_sample(t, temperature=1.0, dim=-1, generator=None): 20 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim) 21 | 22 | 23 | def top_k(logits, thres=0.9): 24 | k = math.ceil((1 - thres) * logits.shape[-1]) 25 | val, ind = logits.topk(k, dim=-1) 26 | probs = torch.full_like(logits, float("-inf")) 27 | probs.scatter_(2, ind, val) 28 | return probs 29 | 30 | 31 | def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): 32 | confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator) 33 | sorted_confidence = torch.sort(confidence, dim=-1).values 34 | cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) 35 | masking = confidence < cut_off 36 | return masking 37 | 38 | 39 | def cosine_schedule(t): 40 | return torch.cos(t * math.pi * 0.5) 41 | 42 | 43 | def linear_schedule(t): 44 | mask_ratio = 1 - t 45 | mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) 46 | return mask_ratio 47 | 48 | 49 | def pow(t, method): 50 | exponent = float(method.replace("pow", "")) 51 | mask_ratio = 1.0 - t**exponent 52 | mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) 53 | return mask_ratio 54 | 55 | 56 | def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6): 57 | for item in [t, start, end, tau]: 58 | item = torch.tensor(item) if not torch.is_tensor(item) else item 59 | 60 | # A gamma function based on sigmoid function. 
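    # Concretely: a sigmoid evaluated over [start, end] (sharpness set by tau) is
    # rescaled so that t = 0 yields a mask ratio of ~1 (nearly everything masked)
    # and t = 1 yields ~clip_min (nearly nothing masked), clipped to [clip_min, 1.0].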
61 | v_start = torch.sigmoid(torch.tensor(start / tau)) 62 | v_end = torch.sigmoid(torch.tensor(end / tau)) 63 | output = torch.sigmoid((t * (end - start) + start) / tau) 64 | output = (v_end - output) / (v_end - v_start) 65 | return torch.clip(output, clip_min, 1.0) 66 | 67 | 68 | def get_mask_chedule(method, **schedule_kwargs): 69 | if method == "cosine": 70 | return cosine_schedule 71 | elif method == "linear": 72 | return linear_schedule 73 | elif "pow" in method: 74 | return partial(pow, method=method) 75 | elif method == "sigmoid": 76 | return partial(sigmoid_schedule, **schedule_kwargs) 77 | else: 78 | raise ValueError("Unknown schedule method: {}".format(method)) 79 | 80 | def top_k_top_p_filtering( 81 | logits: torch.Tensor, 82 | top_k: int = 0, 83 | top_p: float = 1.0, 84 | filter_value: float = -float("Inf"), 85 | min_tokens_to_keep: int = 1, 86 | ) -> torch.Tensor: 87 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 88 | Args: 89 | logits: logits distribution shape (batch size, vocabulary size) 90 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 91 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 92 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 93 | Make sure we keep at least min_tokens_to_keep per batch example in the output 94 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 95 | """ 96 | if top_k > 0: 97 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 98 | # Remove all tokens with a probability less than the last token of the top-k 99 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 100 | logits[indices_to_remove] = filter_value 101 | 102 | if top_p < 1.0: 103 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 104 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 105 | 106 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 107 | sorted_indices_to_remove = cumulative_probs > top_p 108 | if min_tokens_to_keep > 1: 109 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 110 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 111 | # Shift the indices to the right to keep also the first token above the threshold 112 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 113 | sorted_indices_to_remove[..., 0] = 0 114 | 115 | # scatter sorted tensors to original indexing 116 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 117 | logits[indices_to_remove] = filter_value 118 | return logits 119 | -------------------------------------------------------------------------------- /parquet/__init__.py: -------------------------------------------------------------------------------- 1 | from .refinedweb_dataset import RefinedWebDataset 2 | -------------------------------------------------------------------------------- /parquet/refinedweb_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 NUS Show Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import collections 17 | import random 18 | 19 | import torch 20 | from parquet.parquet_dataset import CruiseParquetDataset 21 | 22 | 23 | class RefinedWebDataset(CruiseParquetDataset): 24 | def __init__(self, 25 | data_path, 26 | rank: int = 0, 27 | world_size: int = 1, 28 | shuffle=True, 29 | repeat=True, 30 | buffer_size=1000, 31 | max_length=8000, 32 | num_workers=1, 33 | **kwargs 34 | ): 35 | super().__init__(data_path, rank, world_size, shuffle, repeat, verbose=False, buffer_size=buffer_size, meta_data_path=None, state_path=None, num_workers=num_workers) 36 | self.max_length = max_length 37 | 38 | def __iter__(self): 39 | for example in self.generate(): 40 | try: 41 | data, current_worker_hash, data_idx, seed = example 42 | text = data['content'].replace('\n', '') 43 | if len(text) > self.max_length: 44 | start_index = random.randint(0, len(text) - self.max_length - 1) 45 | selected_text = text[start_index:start_index + self.max_length] 46 | else: 47 | selected_text = text 48 | ret = {'input_ids': selected_text} 49 | yield ret 50 | 51 | except Exception as e: 52 | # print('internal dataset iter error', e) 53 | continue 54 | 55 | def collate_fn(self, batch): 56 | batched = collections.defaultdict(list) 57 | for data in batch: 58 | for k, v in data.items(): 59 | batched[k].append(v) 60 | for k, v in batched.items(): 61 | if k not in ('key', 'input_ids', 'similarity'): 62 | batched[k] = torch.stack(v, dim=0) 63 | 64 | return batched 65 | 66 | if __name__ == '__main__': 67 | 68 | dataset = RefinedWebDataset('/mnt/bn/vgfm2/test_mlx/xavier/data/falcon-refinedweb/data/*.parquet', num_workers=10) 69 | from torch.utils.data import DataLoader 70 | train_dataloader = DataLoader(dataset, batch_size=10, 71 | sampler=None, collate_fn=dataset.collate_fn, 72 | num_workers=10) 73 | # num_workers=0) 74 | for i, batch in enumerate(train_dataloader): 75 | print(len(batch['input_ids'][0])) 76 | import ipdb; ipdb.set_trace() 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | albumentations==0.3.2 5 | annotated-types==0.7.0 6 | antlr4-python3-runtime==4.9.3 7 | anykeystore==0.2 8 | asn1crypto==1.5.1 9 | asttokens==2.4.1 10 | async-timeout==4.0.3 11 | attrs==21.2.0 12 | bidict==0.23.1 13 | blessed==1.20.0 14 | boto3==1.34.113 15 | botocore==1.34.113 16 | braceexpand==0.1.7 17 | cachetools==5.3.3 18 | certifi==2024.2.2 19 | cffi==1.16.0 20 | chardet==5.2.0 21 | charset-normalizer==3.3.2 22 | click==8.1.7 23 | clip==0.2.0 24 | clip-openai==1.0.post20230121 25 | cmake==3.29.3 26 | cramjam==2.8.3 27 | crcmod==1.7 28 | cryptacular==1.6.2 29 | cryptography==39.0.2 30 | cycler==0.12.1 31 | datasets 32 | diffusers==0.30.1 33 | decorator==5.1.1 34 | decord==0.6.0 35 | deepspeed==0.14.2 36 | defusedxml==0.7.1 37 | Deprecated==1.2.14 38 | descartes==1.1.0 39 | dill==0.3.8 40 | distlib==0.3.8 41 | distro-info==1.0 42 | dnspython==2.6.1 43 | docker-pycreds==0.4.0 44 | 
docstring_parser==0.16 45 | ecdsa==0.19.0 46 | einops==0.6.0 47 | exceptiongroup==1.2.1 48 | executing==2.0.1 49 | fairscale==0.4.13 50 | fastparquet==2024.5.0 51 | ffmpegcv==0.3.13 52 | filelock==3.14.0 53 | fire==0.6.0 54 | fonttools==4.51.0 55 | frozenlist==1.4.1 56 | fsspec==2023.6.0 57 | ftfy==6.2.0 58 | gitdb==4.0.11 59 | GitPython==3.1.43 60 | gpustat==1.1.1 61 | greenlet==3.0.3 62 | grpcio==1.64.0 63 | h11==0.14.0 64 | hjson==3.1.0 65 | huggingface-hub==0.23.2 66 | hupper==1.12.1 67 | idna==3.7 68 | imageio==2.34.1 69 | imgaug==0.2.6 70 | iniconfig==2.0.0 71 | ipaddress==1.0.23 72 | ipdb==0.13.13 73 | ipython==8.18.1 74 | jaxtyping==0.2.28 75 | jedi==0.19.1 76 | Jinja2==3.1.4 77 | jmespath==1.0.1 78 | joblib==1.4.2 79 | jsonargparse==4.14.1 80 | jsonlines==4.0.0 81 | kiwisolver==1.4.5 82 | kornia==0.7.2 83 | kornia_rs==0.1.3 84 | lazy_loader==0.4 85 | lightning==2.2.3 86 | lightning-utilities==0.11.2 87 | lit==18.1.6 88 | MarkupSafe==2.1.5 89 | matplotlib==3.5.3 90 | matplotlib-inline==0.1.7 91 | miscreant==0.3.0 92 | mpmath==1.3.0 93 | msgpack==1.0.8 94 | multidict==6.0.5 95 | multiprocess==0.70.16 96 | natsort==8.4.0 97 | networkx==3.2.1 98 | ninja==1.11.1.1 99 | numpy==1.24.4 100 | nuscenes-devkit==1.1.11 101 | oauthlib==3.2.2 102 | omegaconf==2.3.0 103 | open-clip-torch==2.24.0 104 | openai-clip 105 | opencv-python==4.9.0.80 106 | opencv-python-headless==3.4.18.65 107 | packaging==22.0 108 | pandas==1.5.3 109 | parquet==1.3.1 110 | parso==0.8.4 111 | PasteDeploy==3.1.0 112 | pathlib2==2.3.7.post1 113 | pathtools==0.1.2 114 | pbkdf2==1.3 115 | pexpect==4.9.0 116 | pillow==10.3.0 117 | plaster==1.1.2 118 | plaster-pastedeploy==1.0.1 119 | platformdirs==4.2.2 120 | plotly==5.22.0 121 | pluggy==1.5.0 122 | ply==3.11 123 | promise==2.3 124 | prompt-toolkit==3.0.43 125 | protobuf==3.20.3 126 | psutil==5.9.8 127 | ptyprocess==0.7.0 128 | pure-eval==0.2.2 129 | py==1.11.0 130 | py-cpuinfo==9.0.0 131 | py-spy==0.3.14 132 | pyarrow==11.0.0 133 | pyarrow-hotfix==0.6 134 | pyasn1==0.6.0 135 | pycocotools==2.0.7 136 | pycparser==2.22 137 | pycryptodomex==3.20.0 138 | pycurl==7.43.0.6 139 | pydantic==1.10.15 140 | pydantic_core==2.18.3 141 | Pygments==2.18.0 142 | PyJWT==2.8.0 143 | pynvml==11.5.0 144 | pyope==0.2.2 145 | pyOpenSSL==23.2.0 146 | pyparsing==3.1.2 147 | pyquaternion==0.9.9 148 | pyramid==2.0.2 149 | pyramid-mailer==0.15.1 150 | pytest==6.2.5 151 | python-consul==1.1.0 152 | python-dateutil==2.9.0.post0 153 | python-engineio==4.9.1 154 | python-etcd==0.4.5 155 | python-jose==3.3.0 156 | python-socketio==5.11.2 157 | python3-openid==3.2.0 158 | pytorch-extension==0.2 159 | pytorch-lightning==2.2.3 160 | pytz==2024.1 161 | PyYAML==6.0.1 162 | regex==2024.5.15 163 | repoze.sendmail==4.4.1 164 | requests==2.31.0 165 | requests-oauthlib==2.0.0 166 | rsa==4.9 167 | s3transfer==0.10.1 168 | safetensors==0.4.3 169 | schedule==1.2.2 170 | scikit-image==0.22.0 171 | scikit-learn==1.5.0 172 | scipy==1.13.1 173 | sentencepiece==0.2.0 174 | sentry-sdk==2.3.1 175 | setproctitle==1.3.3 176 | Shapely==1.8.5.post1 177 | shortuuid==1.0.13 178 | simple-websocket==1.0.0 179 | six==1.16.0 180 | smmap==5.0.1 181 | SQLAlchemy==2.0.30 182 | stack-data==0.6.3 183 | sympy==1.12 184 | taming-transformers-rom1504==0.0.6 185 | tenacity==8.3.0 186 | tensorboardX==2.6.2.2 187 | termcolor==2.4.0 188 | threadpoolctl==3.5.0 189 | thriftpy2==0.5.0 190 | tifffile==2024.5.22 191 | timm==1.0.3 192 | tokenizers==0.19.1 193 | toml==0.10.2 194 | tomli==2.0.1 195 | torch==2.2.1 196 | torch-fidelity==0.3.0 197 | 
torchmetrics==1.4.0.post0 198 | torchvision==0.17.1 199 | tox==3.28.0 200 | tqdm==4.66.4 201 | traitlets==5.14.3 202 | transaction==4.0 203 | transformers==4.41.1 204 | translationstring==1.4 205 | triton==2.2.0 206 | typeguard==2.13.3 207 | typing_extensions==4.12.0 208 | tzdata==2024.1 209 | urllib3==1.26.18 210 | velruse==1.1.1 211 | venusian==3.1.0 212 | virtualenv==20.26.2 213 | wandb==0.17.0 214 | watchdog==4.0.1 215 | wcwidth==0.2.13 216 | webdataset==0.2.86 217 | WebOb==1.8.7 218 | websocket-client==1.8.0 219 | wrapt==1.16.0 220 | wsproto==1.2.0 221 | WTForms==3.1.2 222 | wtforms-recaptcha==0.3.2 223 | xformers==0.0.25 224 | xxhash==3.4.1 225 | yarl==1.9.4 226 | zope.deprecation==5.0 227 | zope.interface==6.4.post2 228 | zope.sqlalchemy==3.1 229 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Show-o/c890d34ddc2d8994bad408c254f7dc4689b93287/training/__init__.py -------------------------------------------------------------------------------- /training/imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 NUS Show Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
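# This module wraps torchvision's DatasetFolder so that each sample is returned as a dict
# containing the transformed image, the human-readable class name (stored under 'input_ids'
# and later used as the text prompt), and the integer class id.
#
# Illustrative usage (a sketch; the ImageNet root path below is a placeholder):
#
#   from torch.utils.data import DataLoader
#   dataset = ImageNetDataset('/path/to/imagenet/train', image_size=256)
#   loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))
#   # batch['images']:    float tensor of shape (8, 3, 256, 256), normalized to [-1, 1]
#   # batch['input_ids']: list of 8 class-name strings
#   # batch['class_ids']: long tensor of shape (8,)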
15 | 16 | import collections 17 | from typing import Any, Callable, Optional 18 | 19 | import torch 20 | from torchvision.datasets.folder import DatasetFolder, default_loader 21 | from training.utils import image_transform 22 | 23 | 24 | class ImageNetDataset(DatasetFolder): 25 | def __init__( 26 | self, 27 | root: str, 28 | loader: Callable[[str], Any] = default_loader, 29 | is_valid_file: Optional[Callable[[str], bool]] = None, 30 | image_size=256, 31 | ): 32 | IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") 33 | 34 | self.transform = image_transform 35 | self.image_size = image_size 36 | 37 | super().__init__( 38 | root, 39 | loader, 40 | IMG_EXTENSIONS if is_valid_file is None else None, 41 | transform=self.transform, 42 | target_transform=None, 43 | is_valid_file=is_valid_file, 44 | ) 45 | 46 | with open('./training/imagenet_label_mapping', 'r') as f: 47 | self.labels = {} 48 | for l in f: 49 | num, description = l.split(":") 50 | self.labels[int(num)] = description.strip() 51 | 52 | print("ImageNet dataset loaded.") 53 | 54 | def __getitem__(self, idx): 55 | 56 | try: 57 | path, target = self.samples[idx] 58 | image = self.loader(path) 59 | image = self.transform(image, resolution=self.image_size) 60 | input_ids = "{}".format(self.labels[target]) 61 | class_ids = torch.tensor(target) 62 | 63 | return {'images': image, 'input_ids': input_ids, 'class_ids': class_ids} 64 | 65 | except Exception as e: 66 | print(e) 67 | return self.__getitem__((idx + 1) % len(self.samples)) 68 | 69 | def collate_fn(self, batch): 70 | batched = collections.defaultdict(list) 71 | for data in batch: 72 | for k, v in data.items(): 73 | batched[k].append(v) 74 | for k, v in batched.items(): 75 | if k not in ('input_ids',): 76 | batched[k] = torch.stack(v, dim=0) 77 | 78 | return batched 79 | 80 | 81 | if __name__ == '__main__': 82 | pass 83 | -------------------------------------------------------------------------------- /training/omni_attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 NUS Show Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | import torch 17 | torch.set_default_device('cuda') 18 | from torch.nn.attention.flex_attention import create_block_mask, flex_attention 19 | flex_attention = torch.compile(flex_attention, dynamic=False) 20 | 21 | # Class for Omni-Attention Mechanism based on FlexAttention (torch >= 2.5) 22 | class OmniAttentionMechanism(torch.nn.Module): 23 | # def __init__(self, batch_size_t2i, batch_size_lm, batch_size_mmu, S, image_begin_ends=[(128 + 1, 128 + 1 + 258)], device='cuda'): 24 | def __init__(self, batch_size_t2i, batch_size_lm, batch_size_mmu, S, t2i_image_begin_end=[(128, 1152)], mmu_end=1027, right_padding=[(1024, 1280)], device='cuda'): 25 | # def __init__(self, batch_size_t2i, batch_size_lm, batch_size_mmu, S, t2i_image_begin_end=[(15, 20)], mmu_end=1027, right_padding=[(1024, 1280)], device='cuda'): 26 | super().__init__() 27 | 28 | self.batch_size_t2i = batch_size_t2i 29 | self.batch_size_lm = batch_size_lm 30 | self.batch_size_mmu = batch_size_mmu 31 | 32 | self.t2i_image_begin_end = t2i_image_begin_end 33 | self.t2i_full_begin = torch.arange(S, device=device) 34 | self.t2i_full_end = torch.arange(S, device=device) 35 | for image_begin, image_end in t2i_image_begin_end: 36 | self.t2i_full_begin[image_begin:image_end] = image_begin 37 | self.t2i_full_end[image_begin:image_end] = image_end 38 | 39 | self.mmu_end = mmu_end 40 | 41 | # used if we add padding at the right-most end of the sequence 42 | # self.right_pad_begins = torch.arange(S, device=device) 43 | # self.right_pad_ends = torch.arange(S, device=device) 44 | # for image_begin, image_end in right_padding: 45 | # self.right_pad_begins[image_begin:image_end] = image_begin 46 | # self.right_pad_ends[image_begin:image_end] = image_end 47 | 48 | def causal_mask(self, b, h, q_idx, kv_idx): 49 | # right_pad_mask = ~((kv_idx < self.right_pad_ends[q_idx]) & (kv_idx >= self.right_pad_begins[q_idx])) 50 | return (q_idx >= kv_idx) #& right_pad_mask 51 | 52 | def t2i_mask(self, b, h, q_idx, kv_idx): 53 | """ 54 | (batch_size, seq_len) 55 | t2i sequence = [ 56 | [pad][pad][t2i][sot][text][text][eot][soi][image][image][eoi] 57 | [pad][t2i][sot][text][text][text][eot][soi][image][image][eoi] 58 | ] 59 | left padding for the text 60 | right padding to meet the FlexAttention requirement (sequence length must be a multiple of 128) 61 | """ 62 | # causal mask that excludes padding regions 63 | # eye_mask = (q_idx == kv_idx) to avoid the NaN issue 64 | causal_mask = ~((kv_idx < self.pad_ends[b, kv_idx])) & ((q_idx >= kv_idx)) | (q_idx == kv_idx) 65 | full_mask = (kv_idx < self.t2i_full_end[q_idx]) & (kv_idx >= self.t2i_full_begin[q_idx]) 66 | # remove attention to the right padding (because we add some padding at the end of the sequence to meet the length required by FlexAttention) 67 | # right_pad_mask = ~((kv_idx < self.right_pad_ends[q_idx]) & (kv_idx >= self.right_pad_begins[q_idx])) 68 | 69 | return (causal_mask | full_mask) #& right_pad_mask 70 | 71 | # TODO: check the boundary. 72 | def mmu_mask(self, b, h, q_idx, kv_idx): 73 | # right_pad_mask = ~((kv_idx < self.right_pad_ends[q_idx]) & (kv_idx >= self.right_pad_begins[q_idx])) 74 | return (q_idx >= kv_idx) | (kv_idx < self.mmu_end) #& right_pad_mask 75 | 76 | # TODO: check the boundary.
77 | def mmu_vit_mask(self, b, h, q_idx, kv_idx, system_prompt_len=28, num_clip_vit_feat=576): 78 | index = 1 + system_prompt_len + 1 + num_clip_vit_feat 79 | return (q_idx >= kv_idx) | ((kv_idx >= (1 + system_prompt_len + 1)) & (kv_idx < index)) 80 | 81 | def mixed_mask(self, b, h, q_idx, kv_idx, num_clip_vit_feat=576): 82 | # causal mask that excludes padding regions 83 | # to avoid the NaN issue 84 | # eye_mask = (q_idx == kv_idx) 85 | causal_mask = ~(kv_idx < self.pad_ends[b, kv_idx]) & (q_idx >= kv_idx) | (q_idx == kv_idx) 86 | full_mask = (kv_idx < self.t2i_full_end[q_idx]) & (kv_idx >= self.t2i_full_begin[q_idx]) 87 | # right_pad_mask = ~((kv_idx < self.right_pad_ends[q_idx]) & (kv_idx >= self.right_pad_begins[q_idx])) 88 | t2i_mask = (causal_mask | full_mask) #& right_pad_mask 89 | 90 | lm_mask = (q_idx >= kv_idx) #& right_pad_mask 91 | # mmu_mask = (q_idx >= kv_idx) | (kv_idx <= num_clip_vit_feat + 3) #& right_pad_mask 92 | mmu_mask = (q_idx >= kv_idx) | (kv_idx < self.mmu_end) 93 | 94 | return (((b < self.batch_size_t2i) & t2i_mask) 95 | ^ ((b >= self.batch_size_t2i) & (b < (self.batch_size_t2i + self.batch_size_lm)) & lm_mask) 96 | ^ ((b >= (self.batch_size_t2i + self.batch_size_lm)) & mmu_mask)) 97 | 98 | def create_block_mask(self, sequence, pad_begin_ends=[(0, 80), (0, 100), (0, 110), (0, 0)], type="t2i"): 99 | # def create_block_mask(self, sequence, pad_begin_ends=[(0, 10), (0, 5), (0, 2), (0, 0)], type="t2i"): 100 | B, S = sequence.shape 101 | self.pad_begins = torch.arange(S, device='cuda').repeat(B, 1) 102 | self.pad_ends = torch.arange(S, device='cuda').repeat(B, 1) 103 | 104 | cnt = 0 105 | for pb, pe in pad_begin_ends: 106 | self.pad_begins[cnt, pb:pe] = pb 107 | self.pad_ends[cnt, pb:pe] = pe 108 | cnt += 1 109 | 110 | if type == "t2i": 111 | block_mask = create_block_mask(self.t2i_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=True) 112 | elif type == "mmu": 113 | block_mask = create_block_mask(self.mmu_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=True) 114 | elif type == "mmu_vit": 115 | block_mask = create_block_mask(self.mmu_vit_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=True) 116 | elif type == "causal": 117 | block_mask = create_block_mask(self.causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=True) 118 | elif type == "mixed-t2i-lm-mmu": 119 | block_mask = create_block_mask(self.mixed_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=True) 120 | else: 121 | raise ValueError("Unknown type") 122 | 123 | return block_mask 124 | 125 | def test(self): 126 | attn_mask = torch.zeros(4, 1, 21, 21) 127 | for b in range(4): 128 | for h in range(1): 129 | for q_idx in range(21): 130 | for kv_idx in range(21): 131 | attn_mask[b, h, q_idx, kv_idx] = self.t2i_mask(b, h, q_idx, kv_idx) 132 | # import ipdb 133 | # ipdb.set_trace() 134 | # print() 135 | return attn_mask 136 | 137 | def create_attention_mask_for_mmu_vit( 138 | sequence, 139 | return_inverse_mask=False, 140 | system_prompt_len=0 141 | ): 142 | N, L = sequence.shape 143 | causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device) 144 | index = 1 + system_prompt_len + 1 + 576 145 | 146 | causal_mask[:, :, :, :index] = 1 147 | 148 | causal_mask[0:4, :, :, :] = 1 149 | 150 | if return_inverse_mask: 151 | inverted_mask = 1.0 - causal_mask.type(torch.int64) 152 | inverted_mask = inverted_mask.masked_fill( 153 | inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min 154 | ) 155 | return inverted_mask.to(dtype=torch.bool) 156 | else: 157 | return causal_mask 158 | 159 | if __name__ 
== '__main__': 160 | 161 | from triton.testing import do_bench 162 | 163 | B = 12 164 | S = 1152 # must be the multiple of 128 165 | H = 8 166 | D = 64 167 | q, k, v = [torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3)] 168 | 169 | OAM = OmniAttentionMechanism(4, 4, 4, S) 170 | 171 | sequence = torch.randn((B, S), device='cuda') 172 | block_mask = OAM.create_block_mask(sequence, type='t2i') 173 | print(block_mask) 174 | 175 | flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask) 176 | print("t2i flexattention: ", do_bench(flex_attn)) 177 | 178 | mask = OAM.test() 179 | import ipdb 180 | 181 | ipdb.set_trace() 182 | 183 | sequence = torch.randn((B, S), device='cuda') 184 | block_mask = OAM.create_block_mask(sequence, type='causal') 185 | print(block_mask) 186 | 187 | flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask) 188 | print("lm flexattention: ", do_bench(flex_attn)) 189 | 190 | sequence = torch.randn((B, S), device='cuda') 191 | import time 192 | s = time.time() 193 | block_mask = OAM.create_block_mask(sequence, type='mmu') 194 | print(block_mask) 195 | print(time.time() - s, 'create mmu mask') 196 | 197 | flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask) 198 | print("mmu flexattention: ", do_bench(flex_attn)) 199 | 200 | sequence = torch.randn((B, S), device='cuda') 201 | block_mask = OAM.create_block_mask(sequence, type='mmu_vit') 202 | print(block_mask.shape) 203 | 204 | flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask) 205 | print("mmu vit flexattention: ", do_bench(flex_attn)) 206 | 207 | sequence = torch.randn((B, S), device='cuda') 208 | import time 209 | s = time.time() 210 | block_mask = OAM.create_block_mask(sequence, type='mixed-t2i-lm-mmu') 211 | print(block_mask.shape) 212 | print(time.time()-s, 'create mixed mask') 213 | 214 | flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask) 215 | print("mixed-t2i-lm-mmu flexattention: ", do_bench(flex_attn)) 216 | 217 | 218 | 219 | import torch.nn.functional as F 220 | from torch.backends.cuda import sdp_kernel, SDPBackend 221 | 222 | # Helpful arg mapper 223 | backend_map = { 224 | SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False}, 225 | SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False}, 226 | SDPBackend.EFFICIENT_ATTENTION: { 227 | "enable_math": False, "enable_flash": False, "enable_mem_efficient": True} 228 | } 229 | 230 | with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): 231 | # with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): 232 | # with sdp_kernel(**backend_map[SDPBackend.MATH]): 233 | q, k, v = [torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3)] 234 | sequence = torch.randn(B, S) 235 | s = time.time() 236 | mask = create_attention_mask_for_mmu_vit(sequence) 237 | print(time.time() - s, 'create mmu vit mask') 238 | xformer_attn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask) 239 | print("xformer: ", do_bench(xformer_attn)) 240 | 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /training/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """PyTorch implementation of the Lion optimizer.""" 16 | import torch 17 | from torch.optim.optimizer import Optimizer 18 | 19 | 20 | class Lion(Optimizer): 21 | r"""Implements Lion algorithm.""" 22 | 23 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, **kwargs): 24 | """Initialize the hyperparameters. 25 | Args: 26 | params (iterable): iterable of parameters to optimize or dicts defining 27 | parameter groups 28 | lr (float, optional): learning rate (default: 1e-4) 29 | betas (Tuple[float, float], optional): coefficients used for computing 30 | the update direction and the running average of the gradient (default: (0.9, 0.99)) 31 | weight_decay (float, optional): weight decay coefficient (default: 0) 32 | """ 33 | 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 41 | super().__init__(params, defaults) 42 | 43 | @torch.no_grad() 44 | def step(self, closure=None): 45 | """Performs a single optimization step. 46 | Args: 47 | closure (callable, optional): A closure that reevaluates the model 48 | and returns the loss. 49 | Returns: 50 | the loss.
51 | """ 52 | loss = None 53 | if closure is not None: 54 | with torch.enable_grad(): 55 | loss = closure() 56 | 57 | for group in self.param_groups: 58 | for p in group["params"]: 59 | if p.grad is None: 60 | continue 61 | 62 | # Perform stepweight decay 63 | p.data.mul_(1 - group["lr"] * group["weight_decay"]) 64 | 65 | grad = p.grad 66 | state = self.state[p] 67 | # State initialization 68 | if len(state) == 0: 69 | # Exponential moving average of gradient values 70 | state["exp_avg"] = torch.zeros_like(p) 71 | 72 | exp_avg = state["exp_avg"] 73 | beta1, beta2 = group["betas"] 74 | 75 | # Weight update 76 | update = exp_avg * beta1 + grad * (1 - beta1) 77 | p.add_(torch.sign(update), alpha=-group["lr"]) 78 | # Decay the momentum running average coefficient 79 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 80 | 81 | return loss 82 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig, ListConfig, OmegaConf 6 | from typing import Any, List, Tuple, Union 7 | 8 | 9 | ################################################## 10 | # config utils 11 | ################################################## 12 | def get_config(): 13 | cli_conf = OmegaConf.from_cli() 14 | yaml_conf = OmegaConf.load(cli_conf.config) 15 | conf = OmegaConf.merge(yaml_conf, cli_conf) 16 | 17 | return conf 18 | 19 | 20 | def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]: 21 | ret = [] 22 | 23 | def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: 24 | return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)] 25 | 26 | def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: 27 | return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)] 28 | 29 | if isinstance(cfg, DictConfig): 30 | for k, v in cfg.items_ex(resolve=resolve): 31 | if isinstance(v, DictConfig): 32 | ret.extend(handle_dict(k, v, resolve=resolve)) 33 | elif isinstance(v, ListConfig): 34 | ret.extend(handle_list(k, v, resolve=resolve)) 35 | else: 36 | ret.append((str(k), v)) 37 | elif isinstance(cfg, ListConfig): 38 | for idx, v in enumerate(cfg._iter_ex(resolve=resolve)): 39 | if isinstance(v, DictConfig): 40 | ret.extend(handle_dict(idx, v, resolve=resolve)) 41 | elif isinstance(v, ListConfig): 42 | ret.extend(handle_list(idx, v, resolve=resolve)) 43 | else: 44 | ret.append((str(idx), v)) 45 | else: 46 | assert False 47 | 48 | return ret 49 | 50 | 51 | ################################################## 52 | # training utils 53 | ################################################## 54 | def soft_target_cross_entropy(logits, targets, soft_targets): 55 | # ignore the first token from logits and targets (class id token) 56 | logits = logits[:, 1:] 57 | targets = targets[:, 1:] 58 | 59 | logits = logits[..., : soft_targets.shape[-1]] 60 | 61 | log_probs = F.log_softmax(logits, dim=-1) 62 | padding_mask = targets.eq(-100) 63 | 64 | loss = torch.sum(-soft_targets * log_probs, dim=-1) 65 | loss.masked_fill_(padding_mask, 0.0) 66 | 67 | # Take the mean over the label dimensions, then divide by the number of active elements (i.e. 
not-padded): 68 | num_active_elements = padding_mask.numel() - padding_mask.long().sum() 69 | loss = loss.sum() / num_active_elements 70 | return loss 71 | 72 | 73 | def get_loss_weight(t, mask, min_val=0.3): 74 | return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None] 75 | 76 | 77 | def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True): 78 | batch_size, seq_len = image_tokens.shape 79 | 80 | if not is_train and config.training.get("eval_mask_ratios", None): 81 | mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size) 82 | mask_prob = torch.tensor(mask_prob, device=image_tokens.device) 83 | else: 84 | # Sample a random timestep for each image 85 | timesteps = torch.rand(batch_size, device=image_tokens.device) 86 | # Sample a random mask probability for each image using the timestep and the mask schedule 87 | mask_prob = mask_schedule(timesteps) 88 | mask_prob = mask_prob.clip(config.training.min_masking_rate) 89 | 90 | # create a random mask for each image 91 | num_token_masked = (seq_len * mask_prob).round().clamp(min=1) 92 | 93 | mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None) 94 | 95 | if mask_contiguous_region_prob is None: 96 | mask_contiguous_region = False 97 | else: 98 | mask_contiguous_region = random.random() < mask_contiguous_region_prob 99 | 100 | if not mask_contiguous_region: 101 | batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1) 102 | mask = batch_randperm < num_token_masked.unsqueeze(-1) 103 | else: 104 | resolution = int(seq_len ** 0.5) 105 | mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device) 106 | 107 | # TODO - would be nice to vectorize 108 | for batch_idx, num_token_masked_ in enumerate(num_token_masked): 109 | num_token_masked_ = int(num_token_masked_.item()) 110 | 111 | # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_ 112 | num_token_masked_height = random.randint( 113 | math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_) 114 | ) 115 | num_token_masked_height = min(num_token_masked_height, resolution) 116 | 117 | num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height) 118 | num_token_masked_width = min(num_token_masked_width, resolution) 119 | 120 | start_idx_height = random.randint(0, resolution - num_token_masked_height) 121 | start_idx_width = random.randint(0, resolution - num_token_masked_width) 122 | 123 | mask[ 124 | batch_idx, 125 | start_idx_height: start_idx_height + num_token_masked_height, 126 | start_idx_width: start_idx_width + num_token_masked_width, 127 | ] = 1 128 | 129 | mask = mask.reshape(batch_size, seq_len) 130 | mask = mask.to(torch.bool) 131 | 132 | # mask images and create input and labels 133 | if config.training.get("noise_type", "mask") == "mask": 134 | input_ids = torch.where(mask, mask_id, image_tokens) 135 | elif config.training.get("noise_type", "mask") == "random_replace": 136 | # sample random tokens from the vocabulary 137 | random_tokens = torch.randint_like( 138 | image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device 139 | ) 140 | input_ids = torch.where(mask, random_tokens, image_tokens) 141 | else: 142 | raise ValueError(f"noise_type {config.training.noise_type} not supported") 143 | 144 | if ( 145 | config.training.get("predict_all_tokens", False) 146 | or config.training.get("noise_type", "mask") == "random_replace" 147 | ): 148 | labels = image_tokens 149 | 
loss_weight = get_loss_weight(mask_prob, mask.long()) 150 | else: 151 | labels = torch.where(mask, image_tokens, -100) 152 | loss_weight = None 153 | 154 | return input_ids, labels, loss_weight, mask_prob 155 | 156 | 157 | ################################################## 158 | # misc 159 | ################################################## 160 | class AverageMeter(object): 161 | """Computes and stores the average and current value""" 162 | 163 | def __init__(self): 164 | self.reset() 165 | 166 | def reset(self): 167 | self.val = 0 168 | self.avg = 0 169 | self.sum = 0 170 | self.count = 0 171 | 172 | def update(self, val, n=1): 173 | self.val = val 174 | self.sum += val * n 175 | self.count += n 176 | self.avg = self.sum / self.count 177 | 178 | from torchvision import transforms 179 | def image_transform(image, resolution=256, normalize=True): 180 | image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) 181 | image = transforms.CenterCrop((resolution, resolution))(image) 182 | image = transforms.ToTensor()(image) 183 | if normalize: 184 | image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) 185 | return image -------------------------------------------------------------------------------- /validation_prompts/imagenet_prompts.txt: -------------------------------------------------------------------------------- 1 | golden retriever 2 | tiger 3 | wall clock 4 | bicycle-built-for-two 5 | coffee mug 6 | laptop 7 | banana 8 | broccoli 9 | pizza 10 | garbage truck -------------------------------------------------------------------------------- /validation_prompts/showoprompts.txt: -------------------------------------------------------------------------------- 1 | A 3D render of a futuristic car made of glass, driving through a city of mirrors. 2 | A photo-realistic image of a garden with pink and blue flowers. There are pink poppies in the foreground, with their petals gently curved. The background features purple cosmos flowers. The flowers have water droplets on their petals, which glisten in the natural light. The green leaves are lush and healthy. The background is blurred, with a few trees and buildings visible. The overall image has a high resolution and is hyper-realistic, as if taken by a skilled photographer. 3 | an egg and a bird made of wheat bread. 4 | An armchair in the shape of an avocado 5 | The image features a stylized stained glass illustration of a hummingbird with vibrant colors, set against a backdrop of swirling patterns and a large sun. The composition includes floral elements and intricate details, creating a vivid and dynamic scene that emphasizes the beauty of the bird. The colors range from greens to reds, enhancing the lively and artistic aesthetic of the piece. 6 | A 3D render of a surreal explosion scene on the shore of a beautiful white sand beach with crystal clear water. The explosion has a spatter of oil paint with pastel colors and a thick consistency. The explosion is in a quiet and serene environment. A beautiful Japanese woman with a dress compacted to the sea is seen. There are butterfly petals and flowers with an ethereal glow and bioluminescence. There are pink and blue roses, and the overall image has a surreal and dreamlike quality. 7 | A 3D render of a cute, round rice ball character with big, sparkling eyes that convey curiosity and joy. Its body is a soft, fluffy white with a slight sheen, resembling freshly cooked rice. 
Mochi has small, rosy cheeks that give it a warm, friendly expression. A tiny smile brightens its face, and it often sports a colorful ribbon tied around its "waist," adding a playful touch. Mochi's arms and feet are cartoonishly short, allowing it to bounce adorably around its surroundings. 8 | A hyper-realistic close-up photograph of a woman's face, focusing on the left side. The image is highly detailed and realistic, showing voluminous glossy lips slightly parted, a well-defined nose, and open eyes with long eyelashes that cast shadows on the skin. The eye color is crystal clear almond green. The skin texture is crisp, with incredible detail of natural, lush skin and pores and freckles, with subtle highlights and shadows that give a realistic, close-up appearance. 9 | A vibrant cartoon of a chameleon blending into a tie-dye pattern. 10 | A colorful cartoon of a tiger camouflaged in an abstract art painting, its stripes merging with the wild brushstrokes. 11 | A 3D render of a cute, round rice ball character named Mochi, with big, sparkling eyes that convey curiosity and joy. Its body is a soft, fluffy white with a slight sheen, resembling freshly cooked rice. Mochi has small, rosy cheeks that give it a warm, friendly expression. A tiny smile brightens its face, and it often sports a colorful ribbon tied around its "waist," adding a playful touch. Mochi's arms and feet are cartoonishly short, allowing it to bounce adorably around its surroundings. This time, Mochi is placed against a background that is a vibrant explosion of colors, with bright hues of fuchsia, turquoise, lemon yellow, and emerald green creating a canvas of vibrant contrasts and playful energy. The clashing colors make Mochi's soft white body and rosy cheeks stand out even more, inviting viewers into a world of cheerful exuberance and visual delight. 12 | The word 'mardefly' on a coffee mug. 13 | A group of seven people standing on a snow-covered slope, allwearing skis and posing for a picture. -------------------------------------------------------------------------------- /validation_prompts/text2image_prompts.txt: -------------------------------------------------------------------------------- 1 | a family of four is captured in a moment of joy 2 | a serene indoor setting. Dominating the foreground is a black diffuser from the brand "Aromatherapy Associates", as indicated by the white text on its silver lid. 3 | a scene from a laboratory setting. 4 | a t-shirt of an avocado and a llama 5 | there is a woman who is the main subject 6 | a man is standing on a stage, holding a microphone in his hand. 7 | a captivating scene of two fishing boats docked at a rocky shore 8 | a close-up of a woman's face, captured in what appears to be a mugshot setting --------------------------------------------------------------------------------
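Taken together, the training utilities above cover the masked-token objective: a mask schedule picks a masking ratio for a sampled timestep, `mask_or_random_replace_tokens` corrupts the image-token grid accordingly, and the resulting labels are ignored everywhere except at masked positions. A minimal sketch of how these pieces and the Lion optimizer might be wired together (the toy model, codebook size, and config values below are illustrative placeholders, not taken from the repository):

```python
import math

import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

# assumes the repository root is on PYTHONPATH
from training.optimizer import Lion
from training.utils import mask_or_random_replace_tokens

# Toy stand-ins: any model mapping (B, S) token ids to (B, S, vocab) logits, and a config
# exposing the fields read by mask_or_random_replace_tokens, would work here.
codebook_size = 8192
mask_id = codebook_size                                   # reserve one extra id as the [MASK] token
model = torch.nn.Sequential(
    torch.nn.Embedding(codebook_size + 1, 256),
    torch.nn.Linear(256, codebook_size + 1),
)
config = OmegaConf.create({
    "training": {"min_masking_rate": 0.0},
    "model": {"codebook_size": codebook_size},
})
mask_schedule = lambda t: torch.cos(t * math.pi * 0.5)    # cosine schedule: 1 at t=0, 0 at t=1

optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=0.01)

image_tokens = torch.randint(0, codebook_size, (2, 256))  # two 16x16 grids of VQ codes
input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens(
    image_tokens, mask_id, config, mask_schedule, is_train=True
)

logits = model(input_ids)                                 # (2, 256, codebook_size + 1)
loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), ignore_index=-100)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```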