├── LICENSE ├── README.md ├── TextHarmony ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── configs │ └── release │ │ ├── 896-moe-eval.yaml │ │ ├── 896-moe-inference.yaml │ │ ├── deepspeed_zero1.json │ │ ├── deepspeed_zero2.json │ │ ├── edit_annt.json │ │ ├── edit_gt.json │ │ └── example_inference.yaml ├── custom_datasets │ ├── SegToImageDataset.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── SegToImageDataset.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── ade20k.cpython-39.pyc │ │ ├── all_mixed_dataset.cpython-39.pyc │ │ ├── caption_datasets.cpython-39.pyc │ │ ├── clip_itp.cpython-39.pyc │ │ ├── collator.cpython-39.pyc │ │ ├── collator_sft.cpython-39.pyc │ │ ├── flintstones.cpython-39.pyc │ │ ├── grounding_datasets.cpython-39.pyc │ │ ├── image2paragraph.cpython-39.pyc │ │ ├── laion_dataset.cpython-39.pyc │ │ ├── laion_wds.cpython-39.pyc │ │ ├── lncoco.cpython-39.pyc │ │ ├── loader.cpython-39.pyc │ │ ├── mix_dataset.cpython-39.pyc │ │ ├── mmc4_wds.cpython-39.pyc │ │ ├── mscoco.cpython-39.pyc │ │ ├── mscoco_karpathy.cpython-39.pyc │ │ ├── pororo.cpython-39.pyc │ │ ├── segment_dataset.cpython-39.pyc │ │ ├── sft_datasets.cpython-39.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── visdial_dense.cpython-39.pyc │ │ ├── vist.cpython-39.pyc │ │ ├── vqa_datasets.cpython-39.pyc │ │ └── wds_utils.cpython-39.pyc │ ├── ade20k.py │ ├── ade20k_preparation.py │ ├── all_mixed_dataset.py │ ├── caption_datasets.py │ ├── clip_itp.py │ ├── collator.py │ ├── collator_sft.py │ ├── flintstones.py │ ├── image2paragraph.py │ ├── laion_dataset.py │ ├── laion_wds.py │ ├── lncoco.py │ ├── loader.py │ ├── mix_dataset.py │ ├── mmc4_wds.py │ ├── mscoco.py │ ├── mscoco_karpathy.py │ ├── pororo.py │ ├── segment_dataset.py │ ├── sft_datasets.py │ ├── utils.py │ ├── visdial_dense.py │ ├── vist.py │ ├── vqa_datasets.py │ └── wds_utils.py ├── engine │ ├── __pycache__ │ │ └── lmm_trainer.cpython-39.pyc │ └── lmm_trainer.py ├── models │ ├── TextHarmony.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── mm_interleaved.cpython-310.pyc │ │ ├── mm_interleaved.cpython-38.pyc │ │ └── mm_interleaved.cpython-39.pyc │ ├── decoders │ │ ├── __pycache__ │ │ │ ├── decoder_image.cpython-38.pyc │ │ │ ├── decoder_image.cpython-39.pyc │ │ │ ├── decoder_text.cpython-38.pyc │ │ │ ├── decoder_text.cpython-39.pyc │ │ │ ├── modeling_llama_mmfs.cpython-38.pyc │ │ │ ├── modeling_llama_mmfs.cpython-39.pyc │ │ │ ├── perceiver.cpython-310.pyc │ │ │ ├── perceiver.cpython-38.pyc │ │ │ ├── perceiver.cpython-39.pyc │ │ │ ├── sd.cpython-38.pyc │ │ │ ├── sd.cpython-39.pyc │ │ │ └── sd_mmfs.cpython-39.pyc │ │ ├── decoder_image.py │ │ ├── decoder_text.py │ │ ├── modeling_llama_mmfs.py │ │ ├── perceiver.py │ │ ├── sd.py │ │ └── sd_mmfs.py │ ├── encoders │ │ ├── __pycache__ │ │ │ ├── visual_tokenizer.cpython-310.pyc │ │ │ ├── visual_tokenizer.cpython-38.pyc │ │ │ └── visual_tokenizer.cpython-39.pyc │ │ ├── visual_tokenizer.py │ │ └── vit_adapter │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── adapter_modules.cpython-310.pyc │ │ │ ├── adapter_modules.cpython-38.pyc │ │ │ ├── adapter_modules.cpython-39.pyc │ │ │ ├── clip_vit_hf.cpython-310.pyc │ │ │ ├── clip_vit_hf.cpython-38.pyc │ │ │ ├── clip_vit_hf.cpython-39.pyc │ │ │ ├── vit_adapter_hf.cpython-310.pyc │ │ │ ├── 
vit_adapter_hf.cpython-38.pyc │ │ │ ├── vit_adapter_hf.cpython-39.pyc │ │ │ ├── xattn.cpython-310.pyc │ │ │ ├── xattn.cpython-38.pyc │ │ │ └── xattn.cpython-39.pyc │ │ │ ├── adapter_modules.py │ │ │ ├── clip_vit_hf.py │ │ │ ├── ops │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ │ ├── ms_deform_attn_func.cpython-310.pyc │ │ │ │ │ ├── ms_deform_attn_func.cpython-38.pyc │ │ │ │ │ └── ms_deform_attn_func.cpython-39.pyc │ │ │ │ └── ms_deform_attn_func.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── ms_deform_attn.cpython-310.pyc │ │ │ │ ├── ms_deform_attn.cpython-38.pyc │ │ │ │ └── ms_deform_attn.cpython-39.pyc │ │ │ │ └── ms_deform_attn.py │ │ │ ├── vit_adapter_hf.py │ │ │ └── xattn.py │ ├── moe │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── layer.cpython-39.pyc │ │ │ └── moe_lora.cpython-39.pyc │ │ ├── layer.py │ │ └── moe_lora.py │ └── utils │ │ ├── __pycache__ │ │ ├── causal_lm_cascade.cpython-39.pyc │ │ ├── pos_embed.cpython-310.pyc │ │ ├── pos_embed.cpython-38.pyc │ │ └── pos_embed.cpython-39.pyc │ │ ├── causal_lm_cascade.py │ │ ├── monkey_patch │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── beam_search_monkey_patch.cpython-39.pyc │ │ │ ├── blip2_qknorm_monkey_patch.cpython-39.pyc │ │ │ ├── llama_flash_attn_train_monkey_patch.cpython-39.pyc │ │ │ ├── sd_pipeline_monkey_patch.cpython-39.pyc │ │ │ └── sd_unet_forward_monkey_patch.cpython-39.pyc │ │ ├── beam_search_monkey_patch.py │ │ ├── blip2_qknorm_monkey_patch.py │ │ ├── llama_flash_attn_train_monkey_patch.py │ │ ├── sd_pipeline_monkey_patch.py │ │ └── sd_unet_forward_monkey_patch.py │ │ ├── ops │ │ ├── MultiScaleDeformableAttention.egg-info │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ └── top_level.txt │ │ ├── build │ │ │ ├── lib.linux-x86_64-cpython-39 │ │ │ │ ├── MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so │ │ │ │ ├── functions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ └── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── mmfs.py │ │ │ └── temp.linux-x86_64-cpython-39 │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ └── mnt │ │ │ │ └── bn │ │ │ │ └── zz-nas │ │ │ │ └── MM-Interleaved │ │ │ │ └── mm_interleaved │ │ │ │ └── models │ │ │ │ └── utils │ │ │ │ └── ops │ │ │ │ └── src │ │ │ │ ├── cpu │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ │ ├── cuda │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ │ └── vision.o │ │ ├── dist │ │ │ └── MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg │ │ ├── forward_backward_error.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ └── ms_deform_attn_func.cpython-39.pyc │ │ │ └── ms_deform_attn_func.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ └── mmfs.cpython-39.pyc │ │ │ └── mmfs.py │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── compare_with_data.py │ │ │ ├── 
create_data.py │ │ │ ├── forward_backward_error.py │ │ │ ├── skip_forward_error.py │ │ │ └── speed_test.py │ │ └── pos_embed.py ├── scripts │ └── download_hf_models.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── caption_collect.cpython-39.pyc │ ├── clip_sim_score.cpython-39.pyc │ ├── coco_cap_score.cpython-39.pyc │ ├── fid_score.cpython-39.pyc │ ├── grounding_score.cpython-39.pyc │ ├── inception.cpython-39.pyc │ ├── misc.cpython-39.pyc │ ├── parse_args.cpython-39.pyc │ ├── segm_eval.cpython-39.pyc │ ├── visdial_metrics.cpython-39.pyc │ ├── vqa_collect.cpython-39.pyc │ └── vqa_score.cpython-39.pyc │ ├── caption_collect.py │ ├── clip_sim_score.py │ ├── coco_cap_score.py │ ├── fid_score.py │ ├── grounding_score.py │ ├── inception.py │ ├── misc.py │ ├── parse_args.py │ ├── segm_eval.py │ ├── visdial_metrics.py │ ├── vizwiz_metrics_src │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── vqa.cpython-39.pyc │ │ └── vqaEval.cpython-39.pyc │ ├── vqa.py │ └── vqaEval.py │ ├── vqa_collect.py │ ├── vqa_score.py │ └── vqav2_metrics_src │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── vqa.cpython-39.pyc │ └── vqaEval.cpython-39.pyc │ ├── vqa.py │ └── vqaEval.py ├── docs └── examples │ ├── all_zero.jpg │ ├── book.jpg │ └── example.json ├── evaluate.py ├── evaluate_utils.py ├── image_eval ├── __pycache__ │ ├── anytext_singleGPU.cpython-310.pyc │ ├── anytext_singleGPU.cpython-38.pyc │ ├── dataset_util.cpython-310.pyc │ ├── dataset_util.cpython-38.pyc │ ├── t3_dataset.cpython-310.pyc │ └── t3_dataset.cpython-38.pyc ├── anytext_multiGPUs.py ├── anytext_singleGPU.py ├── cldm │ ├── __pycache__ │ │ ├── cldm.cpython-310.pyc │ │ ├── ddim_hacked.cpython-310.pyc │ │ ├── ddim_hacked.cpython-38.pyc │ │ ├── embedding_manager.cpython-310.pyc │ │ ├── model.cpython-310.pyc │ │ ├── model.cpython-38.pyc │ │ ├── recognizer.cpython-310.pyc │ │ └── recognizer.cpython-38.pyc │ ├── cldm.py │ ├── ddim_hacked.py │ ├── embedding_manager.py │ ├── hack.py │ ├── logger.py │ ├── model.py │ └── recognizer.py ├── controlnet_multiGPUs.py ├── controlnet_singleGPU.py ├── dataset_util.py ├── eval_dgocr.py ├── eval_fid.sh ├── eval_ocr.sh ├── gen_glyph.sh ├── gen_imgs_anytext.sh ├── gen_imgs_controlnet_canny.sh ├── gen_imgs_glyphcontrol.sh ├── gen_imgs_textdiffuser.sh ├── glyphcontrol_multiGPUs.py ├── glyphcontrol_singleGPU.py ├── ldm │ ├── __pycache__ │ │ ├── util.cpython-310.pyc │ │ └── util.cpython-38.pyc │ ├── data │ │ ├── __init__.py │ │ └── util.py │ ├── models │ │ ├── __pycache__ │ │ │ └── autoencoder.cpython-310.pyc │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── ddim.cpython-310.pyc │ │ │ └── ddpm.cpython-310.pyc │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ ├── plms.py │ │ │ └── sampling_util.py │ ├── modules │ │ ├── __pycache__ │ │ │ ├── attention.cpython-310.pyc │ │ │ └── ema.cpython-310.pyc │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── model.cpython-310.pyc │ │ │ │ ├── openaimodel.cpython-310.pyc │ │ │ │ ├── util.cpython-310.pyc │ │ │ │ └── util.cpython-38.pyc │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── upscaling.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── 
distributions.cpython-310.pyc │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── modules.cpython-310.pyc │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ └── midas │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net.py │ │ │ ├── midas_net_custom.py │ │ │ ├── transforms.py │ │ │ └── vit.py │ │ │ └── utils.py │ └── util.py ├── ocr_recog │ ├── RNN.py │ ├── RecCTCHead.py │ ├── RecModel.py │ ├── RecMv1_enhance.py │ ├── RecSVTR.py │ ├── __pycache__ │ │ ├── RNN.cpython-310.pyc │ │ ├── RNN.cpython-38.pyc │ │ ├── RecCTCHead.cpython-310.pyc │ │ ├── RecCTCHead.cpython-38.pyc │ │ ├── RecModel.cpython-310.pyc │ │ ├── RecModel.cpython-38.pyc │ │ ├── RecMv1_enhance.cpython-310.pyc │ │ ├── RecMv1_enhance.cpython-38.pyc │ │ ├── RecSVTR.cpython-310.pyc │ │ ├── RecSVTR.cpython-38.pyc │ │ ├── common.cpython-310.pyc │ │ └── common.cpython-38.pyc │ ├── common.py │ ├── en_dict.txt │ └── ppocr_keys_v1.txt ├── ocr_weights │ ├── en_dict.txt │ ├── ppocr_keys_v1.txt │ ├── ppv3_rec.pth │ └── ppv3_rec_en.pth ├── render_glyph_imgs.py ├── t3_dataset.py ├── textdiffuser_multiGPUs.py └── textdiffuser_singleGPU.py ├── inference.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Harmonizing Visual Text Comprehension and Generation 2 | 3 | ## Environment 4 | 5 | **step 1**: set up the environment 6 | 7 | ``` 8 | git clone https://github.com/bytedance/TextHarmony 9 | cd TextHarmony 10 | pip install -r requirements.txt 11 | # install `MultiScaleDeformableAttention` module 12 | cd TextHarmony/models/utils/ops 13 | python setup.py install 14 | ``` 15 | some of the packages like mmcv and flash_attn in requirements.txt may need to be installed manually. 16 | 17 | **step 2**: download pretraining weights 18 | 19 | ``` 20 | cd TextHarmony 21 | python TextHarmony/scripts/download_hf_models.py 22 | ``` 23 | 24 | **step 3**: download the model weight of [TextHarmony](https://huggingface.co/jingqun/textharmony) 25 | 26 | ``` 27 | # concatenate the model files 28 | cat pytorch_model.binaa pytorch_model.binab pytorch_model.binac > pytorch_model.bin 29 | ``` 30 | 31 | ## Inference 32 | 33 | **step1**: modify 'load_from', 'llm_model_path', 'encoder_model_path' and 'pretrained_model_name_or_path' in example_inference.yaml 34 | 35 | **step 2**: run the following command: 36 | 37 | ``` 38 | torchrun --nproc_per_node 1 --nnodes 1 --master_port 2333 inference.py --config_file=TextHarmony/TextHarmony/configs/release/example_inference.yaml 39 | ``` 40 | 41 | ## Evaluation 42 | 43 | ### image comprehension 44 | 45 | **step1**: modify 'data_root' and 'data_path' in 896-moe-eval.yaml. 
The structure of 'data_path' should be as follows: 46 | 47 | ``` 48 | [ 49 | { 50 | "image": image_path, 51 | "question": question, 52 | "answer": answer 53 | }, 54 | ] 55 | ``` 56 | 57 | **step 2**: run the following command 58 | 59 | ``` 60 | torchrun --nproc_per_node 1 --nnodes 1 --master_port 2333 evaluate.py --config_file=TextHarmony/TextHarmony/configs/release/896-moe-eval.yaml 61 | ``` 62 | 63 | ### image generation 64 | 65 | **step 1**: download [AnyText-Benchmark](https://github.com/tyxsspa/AnyText?tab=readme-ov-file) 66 | 67 | **step 2**: generate the target images 68 | 69 | ``` 70 | torchrun --nproc_per_node 1 --nnodes 1 --master_port 2333 inference.py --config_file=TextHarmony/TextHarmony/configs/release/896-moe-inference.yaml 71 | ``` 72 | 73 | **step 3**: calculate the results 74 | 75 | ``` 76 | python TextHarmony/image_eval/eval_dgocr.py 77 | ``` 78 | 79 | ## Training 80 | 81 | * **TODO** 82 | 83 | ## Acknowledgment 84 | 85 | We thank the great work of [MM-Interleaved](https://github.com/OpenGVLab/MM-Interleaved), [TextDiffuser](https://github.com/microsoft/unilm/tree/master/textdiffuser-2), [AnyText](https://github.com/tyxsspa/AnyText) and [LoRAMoE](https://github.com/Ablustrund/LoRAMoE) 86 | 87 | ## Citation 88 | 89 | ``` 90 | @article{zhao2024harmonizing, 91 | title={Harmonizing Visual Text Comprehension and Generation}, 92 | author={Zhao, Zhen and Tang, Jingqun and Wu, Binghong and Lin, Chunhui and Wei, Shu and Liu, Hao and Tan, Xin and Zhang, Zhizhong and Huang, Can and Xie, Yuan}, 93 | journal={arXiv preprint arXiv:2407.16364}, 94 | year={2024} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /TextHarmony/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/__init__.py -------------------------------------------------------------------------------- /TextHarmony/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/configs/release/896-moe-eval.yaml: -------------------------------------------------------------------------------- 1 | 2 | load_from: /ATA_checkpoints/896-moe-3/training_end/pytorch_model.bin # checkpoint 3 | 4 | data_root: /mllm-doc/InternLM-XComposer # image folder 5 | 6 | data_path: /mllm-doc/InternLM-XComposer/data/infographicVQA/infovqa_test.jsonl # annotation file 7 | 8 | image_upscale: 1 9 | 10 | data_seed: &data_seed 0 11 | seed: 32 12 | use_lora: True 
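# (comment added for clarity; not part of the original config) `data_path` above is the
# evaluation annotation file: per the README, each record carries an "image" path
# (presumably resolved under `data_root`), a "question", and an "answer".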
13 | 14 | # MODEL 15 | 16 | model: 17 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 18 | num_img_token: &img_len 512 19 | 20 | visual_tokenizer_config: 21 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 22 | image_size: 896 23 | perceiver_config: 24 | num_queries: 512 25 | hidden_size: 768 26 | encoder_hidden_size: 1024 27 | cross_attention_frequency: 2 28 | num_hidden_layers: 12 29 | num_attention_heads: 12 30 | qk_normalization: True 31 | image_decoder_config: 32 | pretrained_model_name_or_path: ./assets/stabilityai/stable-diffusion-2-base 33 | sd_base_seed: 42 34 | perceiver_config: 35 | num_queries: 77 36 | hidden_size: 1024 37 | encoder_hidden_size: 5120 38 | cross_attention_frequency: 1 39 | num_hidden_layers: 1 40 | num_attention_heads: 16 41 | hidden_dropout_prob: 0. 42 | attention_probs_dropout_prob: 0. 43 | moe_config: 44 | moe_finetuning: True 45 | vit_lora: True 46 | llm_lora: True 47 | peft_type: moe_lora 48 | lora_r: 32 49 | lora_alpha: 32 50 | lora_dropout: 0.1 51 | lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 52 | moe_lora_num_experts: 3 53 | moe_gate_mode: top2_gate 54 | task_num: 3 # (image generation; text generation; shared) 55 | 56 | deepspeed: './mm_interleaved/configs/release/deepspeed_zero2.json' 57 | 58 | # INFERENCE 59 | 60 | inference: 61 | tokenizer_path: *tokenizer_path 62 | num_img_token: *img_len 63 | generate_mode: generate_texts 64 | force_gen_image_next: False 65 | force_replace_gen_text: False 66 | auto_end: False 67 | num_iter: 2 68 | 69 | transform: 70 | aug_type: numpy 71 | resolution: 896 72 | 73 | generation_kwargs: 74 | max_length: 30 75 | min_length: 1 76 | num_beams: 5 77 | use_nucleus_sampling: True 78 | repetition_penalty: 1.3 79 | guidance_scale: 7.5 80 | num_inference_steps: 30 81 | num_validation_images: 1 82 | 83 | -------------------------------------------------------------------------------- /TextHarmony/configs/release/896-moe-inference.yaml: -------------------------------------------------------------------------------- 1 | 2 | load_from: /ATA_checkpoints/896-moe-3/training_end/pytorch_model.bin 3 | 4 | annt_path: TextHarmony/TextHarmony/configs/release/edit_annt.json 5 | output_dir: TextHarmony/save_image 6 | 7 | image_upscale: 1 8 | 9 | data_seed: &data_seed 0 10 | seed: 32 11 | use_lora: True 12 | 13 | # MODEL 14 | 15 | model: 16 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 17 | num_img_token: &img_len 512 18 | 19 | visual_tokenizer_config: 20 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 21 | image_size: 896 22 | perceiver_config: 23 | num_queries: 512 24 | hidden_size: 768 25 | encoder_hidden_size: 1024 26 | cross_attention_frequency: 2 27 | num_hidden_layers: 12 28 | num_attention_heads: 12 29 | qk_normalization: True 30 | image_decoder_config: 31 | pretrained_model_name_or_path: ./assets/stabilityai/stable-diffusion-2-base 32 | sd_base_seed: 42 33 | perceiver_config: 34 | num_queries: 77 35 | hidden_size: 1024 36 | encoder_hidden_size: 5120 37 | cross_attention_frequency: 1 38 | num_hidden_layers: 1 39 | num_attention_heads: 16 40 | hidden_dropout_prob: 0. 41 | attention_probs_dropout_prob: 0. 
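# (comment added for clarity; not part of the original config) the moe_config block below
# attaches MoE-LoRA adapters (r=32, alpha=32, dropout 0.1) to the q/k/v/o attention
# projections, with 3 experts routed by a top-2 gate; task_num distinguishes image
# generation, text generation, and shared experts.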
42 | moe_config: 43 | moe_finetuning: True 44 | vit_lora: True 45 | llm_lora: True 46 | peft_type: moe_lora 47 | lora_r: 32 48 | lora_alpha: 32 49 | lora_dropout: 0.1 50 | lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 51 | moe_lora_num_experts: 3 52 | moe_gate_mode: top2_gate 53 | task_num: 3 # (image generation; text generation; shared) 54 | 55 | # INFERENCE 56 | 57 | inference: 58 | tokenizer_path: *tokenizer_path 59 | num_img_token: *img_len 60 | generate_mode: generate_texts 61 | force_gen_image_next: False 62 | force_replace_gen_text: False 63 | auto_end: False 64 | num_iter: 2 65 | 66 | transform: 67 | aug_type: numpy 68 | resolution: 896 69 | 70 | generation_kwargs: 71 | max_length: 128 72 | min_length: 1 73 | num_beams: 5 74 | use_nucleus_sampling: True 75 | repetition_penalty: 1.3 76 | guidance_scale: 7.5 77 | num_inference_steps: 30 78 | num_validation_images: 1 79 | 80 | -------------------------------------------------------------------------------- /TextHarmony/configs/release/deepspeed_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "reduce_scatter": false, 6 | "allgather_bucket_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "overlap_comm": true, 9 | "contiguous_gradients": true, 10 | "ignore_unused_parameters": true 11 | }, 12 | "fp16": { 13 | "enabled": "auto", 14 | "auto_cast": true, 15 | "loss_scale": 0.0, 16 | "initial_scale_power": 16, 17 | "loss_scale_window": 250, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "train_batch_size": "auto", 24 | "train_micro_batch_size_per_gpu": "auto", 25 | "wall_clock_breakdown": false, 26 | "gradient_clipping": "auto", 27 | "prescale_gradients": true 28 | } -------------------------------------------------------------------------------- /TextHarmony/configs/release/deepspeed_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 2, 16 | "offload_optimizer": { 17 | "device": "cpu", 18 | "pin_memory": true 19 | }, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 2e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 2e8, 25 | "contiguous_gradients": true 26 | }, 27 | 28 | "gradient_accumulation_steps": "auto", 29 | "gradient_clipping": "auto", 30 | "steps_per_print": 100, 31 | "train_batch_size": "auto", 32 | "train_micro_batch_size_per_gpu": "auto", 33 | "wall_clock_breakdown": false 34 | } 35 | -------------------------------------------------------------------------------- /TextHarmony/configs/release/example_inference.yaml: -------------------------------------------------------------------------------- 1 | 2 | load_from: /ATA_checkpoints/896-moe-3/training_end/pytorch_model.bin 3 | 4 | annt_path: TextHarmony/docs/examples/example.json 5 | output_dir: TextHarmony/save_image 6 | 7 | image_upscale: 1 8 | 9 | data_seed: &data_seed 0 10 | seed: 32 11 | use_lora: True 12 | 13 | # MODEL 14 | 15 | model: 16 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 17 | num_img_token: &img_len 512 18 | 19 | visual_tokenizer_config: 20 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 21 | 
image_size: 896 22 | perceiver_config: 23 | num_queries: 512 24 | hidden_size: 768 25 | encoder_hidden_size: 1024 26 | cross_attention_frequency: 2 27 | num_hidden_layers: 12 28 | num_attention_heads: 12 29 | qk_normalization: True 30 | image_decoder_config: 31 | pretrained_model_name_or_path: ./assets/stabilityai/stable-diffusion-2-base 32 | sd_base_seed: 42 33 | perceiver_config: 34 | num_queries: 77 35 | hidden_size: 1024 36 | encoder_hidden_size: 5120 37 | cross_attention_frequency: 1 38 | num_hidden_layers: 1 39 | num_attention_heads: 16 40 | hidden_dropout_prob: 0. 41 | attention_probs_dropout_prob: 0. 42 | moe_config: 43 | moe_finetuning: True 44 | vit_lora: True 45 | llm_lora: True 46 | peft_type: moe_lora 47 | lora_r: 32 48 | lora_alpha: 32 49 | lora_dropout: 0.1 50 | lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 51 | moe_lora_num_experts: 3 52 | moe_gate_mode: top2_gate 53 | task_num: 3 # (image generation; text generation; shared) 54 | 55 | # INFERENCE 56 | 57 | inference: 58 | tokenizer_path: *tokenizer_path 59 | num_img_token: *img_len 60 | generate_mode: generate_texts 61 | force_gen_image_next: False 62 | force_replace_gen_text: False 63 | auto_end: False 64 | num_iter: 2 65 | 66 | transform: 67 | aug_type: numpy 68 | resolution: 896 69 | 70 | generation_kwargs: 71 | max_length: 128 72 | min_length: 1 73 | num_beams: 5 74 | use_nucleus_sampling: True 75 | repetition_penalty: 1.3 76 | guidance_scale: 7.5 77 | num_inference_steps: 30 78 | num_validation_images: 1 79 | 80 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import build_dataset 2 | from .wds_utils import init_tokenizer -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/SegToImageDataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/SegToImageDataset.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/ade20k.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/ade20k.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/all_mixed_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/all_mixed_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/caption_datasets.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/caption_datasets.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/clip_itp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/clip_itp.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/collator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/collator.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/collator_sft.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/collator_sft.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/flintstones.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/flintstones.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/grounding_datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/grounding_datasets.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/image2paragraph.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/image2paragraph.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/laion_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/laion_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/laion_wds.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/laion_wds.cpython-39.pyc -------------------------------------------------------------------------------- 
/TextHarmony/custom_datasets/__pycache__/lncoco.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/lncoco.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/loader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/loader.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/mix_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/mix_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/mmc4_wds.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/mmc4_wds.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/mscoco.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/mscoco.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/mscoco_karpathy.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/mscoco_karpathy.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/pororo.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/pororo.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/segment_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/segment_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/sft_datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/sft_datasets.cpython-39.pyc -------------------------------------------------------------------------------- 
/TextHarmony/custom_datasets/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/visdial_dense.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/visdial_dense.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/vist.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/vist.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/vqa_datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/vqa_datasets.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/__pycache__/wds_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/custom_datasets/__pycache__/wds_utils.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class NoCapsDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_file, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode='generate_texts', 17 | add_eos=None, 18 | ) -> None: 19 | super().__init__() 20 | self.collate_mode = collate_mode 21 | self.transform = transform 22 | self.data_root = data_root 23 | self.image_only = image_only 24 | self.annts = self.load_annotations(annt_file) 25 | self.annt_file = annt_file 26 | if self.image_only: 27 | self.dedeup_image() 28 | if total_length is not None: 29 | self.annts = self.annts[:total_length] 30 | self.add_eos = add_eos 31 | print(f"length of the dataset is {len(self.annts)}") 32 | 33 | def load_annotations(self, annt_file): 34 | meta_info = json.load(open(annt_file, "r")) 35 | images = meta_info['images'] 36 | annotations = meta_info['annotations'] 37 | 38 | image_info = {} 39 | for image in images: 40 | image_info[image['id']] = image 41 | 42 | processed_annotations = [] 43 | for ann in annotations: 44 | image_id = ann['image_id'] 45 | file_name = image_info[image_id]['file_name'] 46 | caption = ann['caption'] 47 | 48 | processed_annotations.append({ 49 | 'image': file_name, 50 | 'caption': caption, 51 | 'image_id': image_id, 52 | }) 53 | 54 | return processed_annotations 55 | 56 | def dedeup_image(self): 57 | annts = {} 58 | for annt in 
self.annts: 59 | image_idx = annt["image_id"] 60 | if image_idx in annts: 61 | continue 62 | annts[image_idx] = annt 63 | self.annts = list(annts.values()) 64 | 65 | def __repr__(self) -> str: 66 | return "Nocaps Dataset" 67 | 68 | def __len__(self): 69 | return len(self.annts) 70 | 71 | def __getitem__(self, index): 72 | item = self.annts[index] 73 | caption = item["caption"] 74 | if isinstance(caption, list): # TODO, random choose one caption from the image 75 | caption = random.choice(caption) 76 | caption = caption.lower() 77 | if self.add_eos is not None: 78 | caption = caption + self.add_eos 79 | image_idx_int = item["image_id"] 80 | image_path = os.path.join(self.data_root, item["image"]) 81 | 82 | try: 83 | image = self.loader(image_path).convert("RGB") 84 | 85 | image = self.transform(image) 86 | except: 87 | print(image_path) 88 | index = random.randint(0, len(self) - 1) 89 | return self.__getitem__(index) 90 | 91 | return image, caption, image_idx_int 92 | 93 | 94 | class Flickr30KDataset(NoCapsDataset): 95 | def __repr__(self) -> str: 96 | return "Flickr30K Dataset" 97 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/clip_itp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import CLIPProcessor 3 | 4 | from .loader import BaseDataset 5 | 6 | 7 | class CLIPImageTextPairDataset(BaseDataset): 8 | def __init__( 9 | self, 10 | image_root, 11 | caption_list, 12 | model_name="openai/clip-vit-large-patch14", 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.model_name = model_name 17 | self.image_root = image_root 18 | self.caption_list = caption_list 19 | 20 | self.clip_processor = CLIPProcessor.from_pretrained(model_name) 21 | 22 | print(f"length of the dataset is {len(self.caption_list)}") 23 | 24 | def __repr__(self) -> str: 25 | return ( 26 | f"CLIPImageTextPair Dataset total_length={len(self)}\n" 27 | f"image_root={self.image_root}\nprocessor={self.clip_processor}" 28 | ) 29 | 30 | def __len__(self): 31 | return len(self.caption_list) 32 | 33 | def __getitem__(self, index): 34 | caption = self.caption_list[str(index)]["caption"] 35 | image_path = os.path.join(self.image_root, f"{index:05d}.png") 36 | 37 | image = self.loader(image_path).convert("RGB") 38 | data = self.clip_processor( 39 | images=image, 40 | text=caption, 41 | return_tensors="pt", 42 | padding="max_length", 43 | max_length=77, 44 | ) 45 | 46 | return data.pixel_values[0], data.input_ids[0], index 47 | 48 | 49 | class CLIPImagePairDataset(BaseDataset): 50 | def __init__( 51 | self, 52 | image_pair_list, 53 | model_name="openai/clip-vit-large-patch14", 54 | ) -> None: 55 | 56 | super().__init__() 57 | 58 | self.model_name = model_name 59 | self.image_pair_list = image_pair_list 60 | 61 | self.clip_processor = CLIPProcessor.from_pretrained(model_name) 62 | 63 | print(f"length of the dataset is {len(self.image_pair_list)}") 64 | 65 | def __repr__(self) -> str: 66 | return ( 67 | f"CLIPImagePairDataset total_length={len(self)}\n" 68 | f"processor={self.clip_processor}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.image_pair_list) 73 | 74 | def __getitem__(self, index): 75 | image_path = self.image_pair_list[index]["image_path"] 76 | image = self.loader(image_path).convert("RGB") 77 | 78 | image = self.clip_processor( 79 | images=image, 80 | text=None, 81 | return_tensors="pt", 82 | ).pixel_values[0] 83 | 84 | image_path_gt = self.image_pair_list[index]["image_gt_path"] 
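# (comment added for clarity) the ground-truth image is loaded and preprocessed with the same CLIPProcessor as the generated image above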
85 | image_gt = self.loader(image_path_gt).convert("RGB") 86 | 87 | image_gt = self.clip_processor( 88 | images=image_gt, 89 | text=None, 90 | return_tensors="pt", 91 | ).pixel_values[0] 92 | 93 | return image, image_gt, index 94 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/image2paragraph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class Image2ParagraphDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_root, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode="generate_texts", 17 | phase="train", 18 | add_eos=None, 19 | ) -> None: 20 | super().__init__() 21 | self.collate_mode = collate_mode 22 | self.transform = transform 23 | self.data_root = data_root 24 | self.annt_root = annt_root 25 | self.phase = phase 26 | self.image_only = image_only 27 | 28 | annt_file = os.path.join(annt_root, "annotations", f"paragraphs_coco.json") 29 | with open(annt_file, "r") as rf: 30 | data = json.load(rf) 31 | annts = {d["image_id"]: d for d in data["annotations"]} 32 | 33 | split_file = os.path.join(annt_root, "annotations", f"{phase}_split.json") 34 | with open(split_file, "r") as rf: 35 | split_idxs = set(json.load(rf)) 36 | annts = [v for k, v in annts.items() if k in split_idxs] 37 | 38 | self.annts = annts 39 | self.annt_file = annt_file 40 | if total_length is not None: 41 | self.annts = self.annts[:total_length] 42 | self.add_eos = add_eos 43 | print(f"length of the dataset is {len(self.annts)}") 44 | 45 | def __repr__(self) -> str: 46 | return ( 47 | f"Image2Paragraph Dataset phase={self.phase}\n" 48 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 49 | f"transform={self.transform}" 50 | ) 51 | 52 | def __len__(self): 53 | return len(self.annts) 54 | 55 | def __getitem__(self, index): 56 | item = self.annts[index] 57 | caption = item["caption"] 58 | # caption = caption.lower() 59 | if self.add_eos is not None: 60 | caption = caption + self.add_eos 61 | 62 | image_idx_int = item["image_id"] 63 | image_subpaths = item["url"].split("/")[-2:] 64 | image_path = os.path.join(self.data_root, *image_subpaths) 65 | 66 | try: 67 | image = self.loader(image_path).convert("RGB") 68 | 69 | image = self.transform(image) 70 | except: 71 | print(image_path) 72 | index = random.randint(0, len(self) - 1) 73 | return self.__getitem__(index) 74 | 75 | return image, caption, image_idx_int 76 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/lncoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | from collections import Counter 6 | 7 | from .loader import BaseDataset 8 | 9 | 10 | class LNCOCODataset(BaseDataset): 11 | def __init__( 12 | self, 13 | data_root, 14 | annt_root, 15 | transform, 16 | image_only=False, 17 | total_length=None, 18 | collate_mode="generate_images", 19 | phase="val", 20 | add_eos=None, 21 | ) -> None: 22 | super().__init__() 23 | assert phase == "val" and collate_mode in ["generate_images"] 24 | self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.image_only = image_only 30 | 31 | annt_file = os.path.join(annt_root, 
"coco_val_captions.jsonl") 32 | with open(annt_file, "r") as rf: 33 | data = rf.readlines() 34 | self.annts = [json.loads(s) for s in data] 35 | self.annt_file = annt_file 36 | if self.image_only: 37 | self.dedeup_image() 38 | if total_length is not None: 39 | if total_length <= len(self.annts): 40 | self.annts = self.annts[:total_length] 41 | else: 42 | # over sampling 43 | cnter_image = Counter([a["image_id"] for a in self.annts]) 44 | annts_weight = [1./cnter_image[a["image_id"]] for a in self.annts] 45 | annts_weight = [w / sum(annts_weight) for w in annts_weight] 46 | annts_n = np.random.choice(self.annts, total_length - len(self.annts), p=annts_weight) 47 | self.annts += list(annts_n) 48 | self.add_eos = add_eos 49 | print(f"length of the dataset is {len(self.annts)}") 50 | 51 | def dedeup_image(self): 52 | annts = {} 53 | for annt in self.annts: 54 | image_idx = annt["image_id"] 55 | if image_idx in annts: 56 | continue 57 | annts[image_idx] = annt 58 | self.annts = list(annts.values()) 59 | 60 | def image_id_to_path(self, image_id): 61 | # coco-2017 62 | return os.path.join(self.data_root, "val2017", f"{image_id:012d}.jpg") 63 | 64 | def __repr__(self) -> str: 65 | return ( 66 | f"LNCOCO Dataset phase={self.phase}\n" 67 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 68 | f"transform={self.transform}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.annts) 73 | 74 | def __getitem__(self, index): 75 | item = self.annts[index] 76 | caption = item["caption"] 77 | # caption = caption.lower() 78 | if self.add_eos is not None: 79 | caption = caption + self.add_eos 80 | 81 | image_idx_int = int(item["image_id"]) 82 | image_path = os.path.join(self.data_root, "val2017", f"{image_idx_int:012d}.jpg") 83 | 84 | try: 85 | image = self.loader(image_path).convert("RGB") 86 | 87 | image = self.transform(image) 88 | except: 89 | print(image_path) 90 | index = random.randint(0, len(self) - 1) 91 | return self.__getitem__(index) 92 | 93 | return image, caption, image_idx_int 94 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/loader.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | import cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset, IterableDataset 6 | 7 | import logging 8 | import os 9 | 10 | LOG_LOADER = os.environ.get("LOG_LOADER", False) 11 | 12 | 13 | def pil_loader(img_str): 14 | buff = io.BytesIO(img_str) 15 | return Image.open(buff) 16 | 17 | 18 | def cv2_loader(img_bytes): 19 | # assert(img_bytes is not None) 20 | img_mem_view = memoryview(img_bytes) 21 | img_array = np.frombuffer(img_mem_view, np.uint8) 22 | imgcv2 = cv2.imdecode(img_array, cv2.IMREAD_COLOR) 23 | imgcv2 = cv2.cvtColor(imgcv2, cv2.COLOR_BGR2RGB) 24 | return Image.fromarray(imgcv2) 25 | 26 | 27 | class LocalClient(): 28 | def __init__(self, **kwargs) -> None: 29 | pass 30 | 31 | def get(self, url): 32 | with open(url, "rb") as rf: 33 | data = rf.read() 34 | return data 35 | 36 | 37 | class BaseLoader(object): 38 | def __init__(self): 39 | self.client = LocalClient() 40 | 41 | def __call__(self, fn): 42 | try: 43 | if self.client is not None: 44 | img_value_str = self.client.get(fn) 45 | img = pil_loader(img_value_str) 46 | else: 47 | img = Image.open(fn) 48 | except: 49 | try: 50 | img = cv2_loader(img_value_str) 51 | except Exception as exn: 52 | exn.args = exn.args + (fn,) 53 | if LOG_LOADER: 54 | logging.warning(f"Handling 
BaseLoader image reading error ({repr(exn)}). Ignoring.") 55 | # print('Read image failed ({})'.format(fn)) 56 | return None 57 | else: 58 | return img 59 | else: 60 | return img 61 | 62 | 63 | class BaseDataset(Dataset): 64 | def __init__(self) -> None: 65 | super().__init__() 66 | self.loader = BaseLoader() 67 | self.client = self.loader.client 68 | 69 | def __getitem__(self, index): 70 | raise NotImplementedError 71 | 72 | 73 | class IterableBaseDataset(IterableDataset): 74 | def __init__(self) -> None: 75 | super().__init__() 76 | self.loader = BaseLoader() 77 | self.client = self.loader.client 78 | 79 | def __iter__(self): 80 | raise NotImplementedError 81 | 82 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/mscoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | 6 | from .loader import BaseDataset 7 | 8 | 9 | class CocoCaptionDataset(BaseDataset): 10 | def __init__( 11 | self, 12 | data_root, 13 | annt_root, 14 | transform, 15 | image_only=False, 16 | total_length=None, 17 | collate_mode="generate_images", 18 | shuffle=False, 19 | rerank_by_clip=False, 20 | phase="train", 21 | year="2014", 22 | ) -> None: 23 | super().__init__() 24 | self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.year = year 30 | self.image_only = image_only 31 | self.rerank_by_clip = rerank_by_clip 32 | 33 | annt_file = os.path.join( 34 | annt_root, "annotations", f"captions_{phase}{year}.json" 35 | ) 36 | self.annt_file = annt_file 37 | self.annts = json.load(open(annt_file, "r"))["annotations"] 38 | if self.image_only: 39 | self.dedeup_image() 40 | if shuffle: 41 | np.random.shuffle(self.annts) 42 | if total_length is not None: 43 | self.annts = self.annts[:total_length] 44 | print(f"length of the dataset is {len(self.annts)}") 45 | 46 | def dedeup_image(self): 47 | annts = {} 48 | for annt in self.annts: 49 | image_idx = str(annt["image_id"]).zfill(12) 50 | if image_idx in annts: 51 | continue 52 | annts[image_idx] = annt 53 | self.annts = list(annts.values()) 54 | 55 | def image_id_to_path(self, image_id): 56 | # coco-2014 57 | image_idx = str(image_id).zfill(12) 58 | image_name = f"COCO_{self.phase}{self.year}_{image_idx}.jpg" 59 | image_path = os.path.join( 60 | self.data_root, f"{self.phase}{self.year}", image_name 61 | ) 62 | return image_path 63 | 64 | def __repr__(self) -> str: 65 | return ( 66 | f"MSCOCO-Caption Dataset year={self.year} phase={self.phase}\n" 67 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 68 | f"transform={self.transform}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.annts) 73 | 74 | def __getitem__(self, index): 75 | item = self.annts[index] 76 | caption = item["caption"].lower() 77 | 78 | image_idx = str(item["image_id"]).zfill(12) 79 | image_name = f"COCO_{self.phase}{self.year}_{image_idx}.jpg" 80 | image_path = os.path.join( 81 | self.data_root, f"{self.phase}{self.year}", image_name 82 | ) 83 | try: 84 | image = self.loader(image_path).convert("RGB") 85 | 86 | image = self.transform(image) 87 | except: 88 | print(image_path) 89 | index = random.randint(0, len(self) - 1) 90 | return self.__getitem__(index) 91 | 92 | return image, caption, item["image_id"] 93 | -------------------------------------------------------------------------------- 
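A minimal usage sketch for the `CocoCaptionDataset` defined in `custom_datasets/mscoco.py` above. The paths, transform, and settings are illustrative assumptions rather than values from this repository; it only assumes a standard COCO-2014 layout and that the repository root is on `PYTHONPATH`.

```
# Illustrative sketch only -- the paths below are assumptions, not part of the repo.
from torchvision import transforms
from TextHarmony.custom_datasets.mscoco import CocoCaptionDataset

transform = transforms.Compose([
    transforms.Resize((896, 896)),
    transforms.ToTensor(),
])

dataset = CocoCaptionDataset(
    data_root="/data/coco",   # assumed folder containing val2014/COCO_val2014_*.jpg
    annt_root="/data/coco",   # assumed folder containing annotations/captions_val2014.json
    transform=transform,
    phase="val",
    year="2014",
    image_only=True,          # keep a single caption entry per image
)

image, caption, image_id = dataset[0]  # (transformed image tensor, lower-cased caption string, COCO image id)
```

The Karpathy-split variant defined right below exposes the same interface, with `collate_mode="generate_texts"` as its default.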
/TextHarmony/custom_datasets/mscoco_karpathy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class CocoCaptionKarpathyDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_root, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode="generate_texts", 17 | phase="train", 18 | year="2014", 19 | add_eos=None, 20 | use_1st_sentence_only=True, 21 | rerank_by_clip=False, 22 | ) -> None: 23 | super().__init__() 24 | self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.year = year 30 | self.image_only = image_only 31 | annt_file = os.path.join( 32 | annt_root, "annotations", f"coco_karpathy_{phase}.json" 33 | ) 34 | self.annts = json.load(open(annt_file, "r")) 35 | self.annt_file = annt_file 36 | if self.image_only: 37 | self.dedeup_image() 38 | if total_length is not None: 39 | self.annts = self.annts[:total_length] 40 | self.add_eos = add_eos 41 | self.use_1st_sentence_only = use_1st_sentence_only 42 | self.rerank_by_clip = rerank_by_clip 43 | print(f"length of the dataset is {len(self.annts)}") 44 | 45 | def dedeup_image(self): 46 | annts = {} 47 | for annt in self.annts: 48 | image_idx = annt["image"].split("_")[-1][ 49 | :-4 50 | ] # 'val2014/COCO_val2014_000000391895.jpg' 51 | if image_idx in annts: 52 | continue 53 | annts[image_idx] = annt 54 | self.annts = list(annts.values()) 55 | 56 | def image_id_to_path(self, image_id): 57 | phase = "val" if self.phase == "test" else self.phase 58 | # coco-2014 59 | image_idx = str(image_id).zfill(12) 60 | image_name = f"COCO_{phase}{self.year}_{image_idx}.jpg" 61 | image_path = os.path.join( 62 | self.data_root, f"{phase}{self.year}", image_name 63 | ) 64 | return image_path 65 | 66 | def __repr__(self) -> str: 67 | return ( 68 | f"MSCOCO-Caption Karpathy Dataset year={self.year} phase={self.phase}\n" 69 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 70 | f"transform={self.transform}" 71 | ) 72 | 73 | def __len__(self): 74 | return len(self.annts) 75 | 76 | def __getitem__(self, index): 77 | item = self.annts[index] 78 | caption = item["caption"] 79 | if isinstance(caption, list): 80 | caption = random.choice(caption) 81 | caption = caption.lower() 82 | if self.add_eos is not None: 83 | caption = caption + self.add_eos 84 | image_idx_int = int(item["image"].split("_")[-1][:-4]) 85 | image_name = item["image"] 86 | image_path = os.path.join(self.data_root, f"{image_name}") 87 | 88 | try: 89 | image = self.loader(image_path).convert("RGB") 90 | 91 | image = self.transform(image) 92 | except: 93 | print(image_path) 94 | index = random.randint(0, len(self) - 1) 95 | return self.__getitem__(index) 96 | 97 | return image, caption, image_idx_int 98 | -------------------------------------------------------------------------------- /TextHarmony/custom_datasets/sft_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 7 | 8 | from .loader import BaseDataset 9 | 10 | 11 | class LLaVADataset(BaseDataset): 12 | def __init__( 13 | self, 14 | annt_root=[], 15 | data_root=[], 16 | transform=None, 17 | ): 18 | super().__init__() 19 | self.ann_path = [annt_root] if 
isinstance(annt_root, str) else annt_root 20 | self.data_root = [data_root] if isinstance(data_root, str) else data_root 21 | self.transform = transform 22 | 23 | self.ann = [] 24 | print("Formatting inputs...Skip in lazy mode") 25 | for index, p in enumerate(self.ann_path): 26 | if p.endswith('json'): 27 | with open(p, 'r') as file: 28 | data = json.load(file) 29 | for item in data: 30 | try: 31 | item['image'] = os.path.join(self.data_root[index], item['image']) 32 | self.ann.append(item) 33 | except: 34 | pass 35 | elif p.endswith('.jsonl'): 36 | for line in open(p, 'r'): 37 | data = json.loads(line) 38 | try: 39 | data['image'] = os.path.join(self.data_root[index], data['image']) 40 | self.ann.append(data) 41 | except: 42 | pass 43 | 44 | # split multi-round dialogues to single-round dialogue 45 | max_conv_num = 2 # 1 round 46 | print(f"data length before split: {len(self.ann)}") 47 | new_ann = [] 48 | for item in self.ann: 49 | conversations = item["conversations"] 50 | conversations = [conversations[i:i + max_conv_num] for i in range(0, len(conversations), max_conv_num)] 51 | for conv in conversations: 52 | new_item = item.copy() 53 | if "" not in conv[0]['value']: 54 | conv[0]['value'] = "\n" + conv[0]['value'] 55 | new_item["conversations"] = conv 56 | new_ann.append(new_item) 57 | self.ann = new_ann 58 | print(f"data length after split: {len(self.ann)}") 59 | 60 | def __getitem__(self, index): 61 | while True: 62 | try: 63 | data = self.ann[index] 64 | 65 | assert len(data['conversations']) == 2 66 | 67 | query = data['conversations'][0]['value'].replace('\n', '') 68 | query = query.replace('\n', '') 69 | query = query.replace('', '') 70 | 71 | image_id = data['id'] 72 | image = self.loader(data['image']).convert('RGB') 73 | label = data['conversations'][1]['value'] 74 | break 75 | except Exception as e: 76 | print(e) 77 | print('Error loading data:', data['image']) 78 | index = random.randint(0, len(self.ann) - 1) 79 | 80 | return self.transform(image), query, label, image_id 81 | 82 | def __len__(self): 83 | return len(self.ann) 84 | 85 | 86 | class WeightedConcatDataset(ConcatDataset): 87 | def __init__(self, datasets, weights): 88 | super().__init__(datasets) 89 | self.weights = torch.DoubleTensor(weights) 90 | self.total_size = sum(len(d) for d in datasets) 91 | self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) 92 | 93 | def __iter__(self): 94 | return iter(self.sampler) 95 | 96 | def __len__(self): 97 | return self.total_size 98 | -------------------------------------------------------------------------------- /TextHarmony/engine/__pycache__/lmm_trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/engine/__pycache__/lmm_trainer.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .TextHarmony import TextHarmony 2 | -------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/__init__.cpython-310.pyc 
-------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/mm_interleaved.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/mm_interleaved.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/mm_interleaved.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/mm_interleaved.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/__pycache__/mm_interleaved.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/__pycache__/mm_interleaved.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/decoder_image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/decoder_image.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/decoder_image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/decoder_image.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/decoder_text.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/decoder_text.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/decoder_text.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/decoder_text.cpython-39.pyc 
-------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/modeling_llama_mmfs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/modeling_llama_mmfs.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/modeling_llama_mmfs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/modeling_llama_mmfs.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/perceiver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/perceiver.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/perceiver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/perceiver.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/perceiver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/perceiver.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/sd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/sd.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/sd.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/sd.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/__pycache__/sd_mmfs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/decoders/__pycache__/sd_mmfs.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/decoders/perceiver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import Blip2QFormerModel, Blip2QFormerConfig 5 | 6 | 7 | class PerceiverResampler(nn.Module): 8 | def __init__( 9 | self, 10 | num_queries=32, 
11 | hidden_size=768, 12 | qk_normalization=False, 13 | gradient_checkpointing=True, 14 | **kwargs 15 | ) -> None: 16 | super().__init__() 17 | 18 | config = Blip2QFormerConfig(hidden_size=hidden_size, **kwargs) 19 | config.qk_normalization = qk_normalization 20 | self.blip2qformer = Blip2QFormerModel(config) 21 | 22 | self.queries = nn.Parameter(torch.zeros(1, num_queries, hidden_size)) 23 | self.queries.data.normal_(0, config.initializer_range) 24 | if gradient_checkpointing: 25 | self.blip2qformer.gradient_checkpointing_enable() 26 | 27 | def forward(self, **kwargs): 28 | query_embeds = kwargs.pop("query_embeds", self.queries) 29 | 30 | return self.blip2qformer(query_embeds=query_embeds, **kwargs) 31 | -------------------------------------------------------------------------------- /TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/__pycache__/visual_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_vit_hf import CLIPVisionTransformer, CLIPVisionModel 2 | from .vit_adapter_hf import CLIPVisionTransformerAdapter, CLIPVisionAdapterModel 3 | from .vit_adapter_hf import clip_vit_adapter_hf 4 | 5 | __all__ = ["CLIPVisionTransformer", "CLIPVisionModel", 'clip_vit_adapter_hf', 6 | "CLIPVisionTransformerAdapter", "CLIPVisionAdapterModel"] 7 | -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/adapter_modules.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/clip_vit_hf.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/vit_adapter_hf.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/__pycache__/xattn.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 10 | -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # 
Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from torch.cuda.amp import custom_bwd, custom_fwd 18 | 19 | try: 20 | import MultiScaleDeformableAttention as MSDA 21 | except: 22 | print("MultiScaleDeformableAttention is not installed") 23 | 24 | 25 | class MSDeformAttnFunction(Function): 26 | @staticmethod 27 | @custom_fwd 28 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 29 | im2col_step): 30 | ctx.im2col_step = im2col_step 31 | output = MSDA.ms_deform_attn_forward( 32 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 33 | ctx.im2col_step) 34 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, 35 | attention_weights) 36 | return output 37 | 38 | @staticmethod 39 | @once_differentiable 40 | @custom_bwd 41 | def backward(ctx, grad_output): 42 | grad_output = grad_output.contiguous() 43 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 44 | grad_value, grad_sampling_loc, grad_attn_weight = \ 45 | MSDA.ms_deform_attn_backward( 46 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 47 | grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- 
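For reference, a self-contained sketch of calling ms_deform_attn_core_pytorch above with random tensors, following the shape conventions in its comments. This is the pure-PyTorch debug fallback, not the CUDA kernel; the import path is one possibility and assumes the repository root is on PYTHONPATH (the function can also be imported directly from the ops/functions directory).

import torch
from TextHarmony.models.encoders.vit_adapter.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 2, 8, 32                      # batch, attention heads, channels per head
shapes = [(16, 16), (8, 8)]             # (H, W) of each feature level
S = sum(h * w for h, w in shapes)       # total number of keys across all levels
Lq, L, P = 100, len(shapes), 4          # queries, levels, sampling points per level

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.softmax(torch.rand(N, Lq, M, L, P).flatten(-2), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # expected: torch.Size([2, 100, 256]), i.e. (N, Lq, M * D)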
/TextHarmony/models/encoders/vit_adapter/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import * 10 | -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/encoders/vit_adapter/ops/modules/__pycache__/ms_deform_attn.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/moe/__init__.py: 
-------------------------------------------------------------------------------- 1 | import enum 2 | import peft 3 | from peft import PEFT_TYPE_TO_CONFIG_MAPPING 4 | from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING 5 | 6 | 7 | # register MoE LoRA 8 | class PeftType(str, enum.Enum): 9 | PROMPT_TUNING = "PROMPT_TUNING" 10 | P_TUNING = "P_TUNING" 11 | PREFIX_TUNING = "PREFIX_TUNING" 12 | LORA = "LORA" 13 | ADALORA = "ADALORA" 14 | ADAPTION_PROMPT = "ADAPTION_PROMPT" 15 | IA3 = "IA3" 16 | MOE_LORA = 'MOE_LORA' 17 | 18 | peft.PeftType = PeftType 19 | 20 | from .moe_lora import MoeLoraConfig, MoeLoraModel 21 | PEFT_TYPE_TO_CONFIG_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraConfig 22 | PEFT_TYPE_TO_MODEL_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraModel 23 | 24 | 25 | __all__ = [ 26 | 'MoeLoraConfig', 27 | ] 28 | -------------------------------------------------------------------------------- /TextHarmony/models/moe/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/moe/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/moe/__pycache__/layer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/moe/__pycache__/layer.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/moe/__pycache__/moe_lora.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/moe/__pycache__/moe_lora.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/__pycache__/causal_lm_cascade.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/__pycache__/causal_lm_cascade.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/__pycache__/pos_embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/__pycache__/pos_embed.cpython-310.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/__pycache__/pos_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/__pycache__/pos_embed.cpython-38.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/__pycache__/pos_embed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/__pycache__/pos_embed.cpython-39.pyc 
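The moe/__init__.py above extends peft by replacing its PeftType enum and registering the MOE_LORA config/model classes as an import side effect. A small sketch of what that registration exposes, assuming the peft version pinned by this repo (older releases that still export PEFT_TYPE_TO_MODEL_MAPPING from peft.peft_model); the constructor arguments of MoeLoraConfig live in moe_lora.py and are not shown here.

import peft
from peft import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING

import TextHarmony.models.moe as moe  # importing the package runs the registration above

assert peft.PeftType.MOE_LORA == "MOE_LORA"  # PeftType is a str-valued enum
assert PEFT_TYPE_TO_CONFIG_MAPPING[peft.PeftType.MOE_LORA] is moe.MoeLoraConfig
assert PEFT_TYPE_TO_MODEL_MAPPING[peft.PeftType.MOE_LORA] is moe.MoeLoraModel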
-------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_flash_attn_train_monkey_patch import replace_llama_attn_with_flash_attn 2 | from .blip2_qknorm_monkey_patch import replace_blip2_attn_with_qknorm_attn 3 | from .beam_search_monkey_patch import replace_beam_search 4 | from .sd_pipeline_monkey_patch import replace_stable_diffusion_pipeline_call 5 | from .sd_unet_forward_monkey_patch import replace_stable_diffusion_unet_forward 6 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/beam_search_monkey_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/beam_search_monkey_patch.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/blip2_qknorm_monkey_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/blip2_qknorm_monkey_patch.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/llama_flash_attn_train_monkey_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/llama_flash_attn_train_monkey_patch.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/sd_pipeline_monkey_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/sd_pipeline_monkey_patch.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/monkey_patch/__pycache__/sd_unet_forward_monkey_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/monkey_patch/__pycache__/sd_unet_forward_monkey_patch.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: 
MultiScaleDeformableAttention 3 | Version: 1.0 4 | Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention 5 | Home-page: https://github.com/fundamentalvision/Deformable-DETR 6 | Author: Weijie Su 7 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.cpp 3 | MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp 4 | MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu 5 | MultiScaleDeformableAttention.egg-info/PKG-INFO 6 | MultiScaleDeformableAttention.egg-info/SOURCES.txt 7 | MultiScaleDeformableAttention.egg-info/dependency_links.txt 8 | MultiScaleDeformableAttention.egg-info/top_level.txt 9 | functions/__init__.py 10 | functions/ms_deform_attn_func.py 11 | modules/__init__.py 12 | modules/mmfs.py -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/MultiScaleDeformableAttention.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | MultiScaleDeformableAttention 2 | functions 3 | modules 4 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/lib.linux-x86_64-cpython-39/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/build/lib.linux-x86_64-cpython-39/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/lib.linux-x86_64-cpython-39/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 10 | 11 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/lib.linux-x86_64-cpython-39/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from torch.cuda.amp import custom_bwd, custom_fwd 18 | try: 19 | import MultiScaleDeformableAttention as MSDA 20 | except: 21 | print("MultiScaleDeformableAttention is not installed") 22 | 23 | 24 | class MSDeformAttnFunction(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 28 | ctx.im2col_step = im2col_step 29 | output = MSDA.ms_deform_attn_forward( 30 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 31 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 32 | return output 33 | 34 | @staticmethod 35 | @once_differentiable 36 | @custom_bwd 37 | def backward(ctx, grad_output): 38 | grad_output = grad_output.contiguous() 39 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 40 | grad_value, grad_sampling_loc, grad_attn_weight = \ 41 | MSDA.ms_deform_attn_backward( 42 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 43 | 44 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 45 | 46 | 47 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 48 | # for debug and test only, 49 | # need to use cuda version instead 50 | N_, S_, M_, D_ = value.shape 51 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 52 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 53 | sampling_grids = 2 * sampling_locations - 1 54 | sampling_value_list = [] 55 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 56 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 57 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 58 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 59 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 60 | # N_*M_, D_, Lq_, P_ 61 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 62 | mode='bilinear', padding_mode='zeros', align_corners=False) 63 | sampling_value_list.append(sampling_value_l_) 64 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 65 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 66 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 67 | return output.transpose(1, 2).contiguous() 68 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/lib.linux-x86_64-cpython-39/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Multi-Image Multi-Scale Feature Synchronizer 3 | # Modifed from Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from .mmfs import MMFS 11 | 12 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/.ninja_deps -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 2 21356 1707319042849568000 MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.o 5b89eef9bbefdddf 3 | 3 35429 1707319056914530000 MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.o c18d90d942f432cf 4 | 2 87589 1707319109082569000 MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.o 4900055df687b016 5 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda/bin/nvcc 4 | 5 | cflags = -pthread -B anaconda3/envs/mmInter/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem anaconda3/envs/mmInter/include -Ianaconda3/envs/mmInter/include -fPIC -O2 -isystem anaconda3/envs/mmInter/include -fPIC -DWITH_CUDA -IMM-Interleaved/mm_interleaved/models/utils/ops/src -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/TH -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -Ianaconda3/envs/mmInter/include/python3.9 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 7 | cuda_cflags = -DWITH_CUDA -IMM-Interleaved/mm_interleaved/models/utils/ops/src -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/torch/csrc/api/include 
-Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/TH -Ianaconda3/envs/mmInter/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -Ianaconda3/envs/mmInter/include/python3.9 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 -std=c++17 9 | cuda_dlink_post_cflags = 10 | ldflags = 11 | 12 | rule compile 13 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 14 | depfile = $out.d 15 | deps = gcc 16 | 17 | rule cuda_compile 18 | depfile = $out.d 19 | deps = gcc 20 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 21 | 22 | 23 | 24 | 25 | 26 | build MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.o: compile MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp 27 | build MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu 28 | build MM-Interleaved/mm_interleaved/models/utils/ops/build/temp.linux-x86_64-cpython-39MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.o: compile MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.cpp 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.o -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.o -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/build/temp.linux-x86_64-cpython-39/mnt/bn/zz-nas/MM-Interleaved/mm_interleaved/models/utils/ops/src/vision.o 
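The build.ninja, .ninja_log and .o files above are by-products of compiling the MultiScaleDeformableAttention extension from the setup.py shown further below (typically built by running python setup.py build install inside TextHarmony/models/utils/ops). A quick way to check whether the compiled extension is importable, which is what the try/except guard in ms_deform_attn_func.py relies on:

import importlib.util

spec = importlib.util.find_spec("MultiScaleDeformableAttention")
if spec is None:
    # same situation the try/except in ms_deform_attn_func.py reports
    print("MultiScaleDeformableAttention is not installed")
else:
    import MultiScaleDeformableAttention as MSDA
    print("compiled extension found at:", MSDA.__file__)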
-------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 10 | 11 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/functions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/functions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from torch.cuda.amp import custom_bwd, custom_fwd 18 | try: 19 | import MultiScaleDeformableAttention as MSDA 20 | except: 21 | print("MultiScaleDeformableAttention is not installed") 22 | 23 | 24 | class MSDeformAttnFunction(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 28 | ctx.im2col_step = im2col_step 29 | output = MSDA.ms_deform_attn_forward( 30 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 31 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 32 | return output 33 | 34 | @staticmethod 35 | @once_differentiable 36 | @custom_bwd 37 | def backward(ctx, grad_output): 38 | grad_output = grad_output.contiguous() 39 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 40 | grad_value, grad_sampling_loc, grad_attn_weight = \ 41 | MSDA.ms_deform_attn_backward( 42 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 43 | 44 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 45 | 46 | 47 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 48 | # for debug and test only, 49 | # need to use cuda version instead 50 | N_, S_, M_, D_ = value.shape 51 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 52 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 53 | sampling_grids = 2 * sampling_locations - 1 54 | sampling_value_list = [] 55 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 56 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 57 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 58 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 59 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 60 | # N_*M_, D_, Lq_, P_ 61 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 62 | mode='bilinear', padding_mode='zeros', align_corners=False) 63 | sampling_value_list.append(sampling_value_l_) 64 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 65 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 66 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 67 | return output.transpose(1, 2).contiguous() 68 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Multi-Image Multi-Scale Feature Synchronizer 3 | # Modifed from Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from .mmfs import MMFS 11 | 12 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/modules/__pycache__/mmfs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/modules/__pycache__/mmfs.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | # "-DCUDA_HAS_FP16=1", 42 | # "-D__CUDA_NO_HALF_OPERATORS__", 43 | # "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | # "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | 13 | #include <ATen/ATen.h> 14 | #include <ATen/cuda/CUDAContext.h> 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implemented on the CPU"); 27 | } 28 | 29 | std::vector<at::Tensor> 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implemented on the CPU"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector<at::Tensor> 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector<at::Tensor> 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*!
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/models/utils/ops/tests/__init__.py -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/tests/create_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from collections import OrderedDict 6 | import pickle as pkl 7 | import copy 8 | 9 | from functions.ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | def generate_inputs(dtype, bs=1, n_levels=2, shapes=((6,4), (3,2)), n_query=2, n_points=2, n_heads=2, head_dim=4): 12 | assert len(shapes) == n_levels 13 | shapes = torch.as_tensor(list(shapes), dtype=torch.long) 14 | assert shapes.shape[1] == 2 15 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 16 | spatial_size = sum([(H*W).item() for H, W in shapes]) 17 | 18 | value = torch.rand(bs, spatial_size, n_heads, head_dim) 19 | sampling_locations = torch.rand(bs, n_query, n_heads, n_levels, n_points, 2) 20 | attention_weights = torch.rand(bs, n_query, n_heads, n_levels, n_points).cuda() + 1e-5 21 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 22 | im2col_step = 2 23 | return ( 24 | value.cuda().to(dtype).requires_grad_(True), 25 | shapes.cuda(), 26 | level_start_index.cuda(), 27 | sampling_locations.cuda().to(dtype).requires_grad_(True), 28 | attention_weights.cuda().to(dtype).requires_grad_(True), 29 | im2col_step 30 | ) 31 | 32 | def test_op(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step): 33 | return MSDeformAttnFunction.apply( 34 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step) 35 | 36 | def to_dtype(inputs, dtype): 37 | new_inputs = list(copy.deepcopy(inputs)) 38 | for i, x in enumerate(new_inputs): 39 | if isinstance(x, torch.Tensor) and torch.is_floating_point(x): 40 | new_inputs[i] = x.to(dtype).detach() 41 | new_inputs[i].requires_grad = x.requires_grad 42 | return tuple(new_inputs) 43 | 44 | 45 | if __name__ == '__main__': 46 | out_file = 'data/fp64_data.pkl' 47 | data_list = [] 48 | for i in range(20): 49 | inputs = generate_inputs(torch.float16, 64) 50 | inputs = to_dtype(inputs, torch.float64) 51 | inputs[0].requires_grad = True 52 | inputs[3].requires_grad = True 53 | 
inputs[4].requires_grad = True 54 | 55 | outs = test_op(*inputs) 56 | outs.sum().backward() 57 | 58 | grads = [x.grad if hasattr(x, 'grad') else None for x in inputs] 59 | 60 | data_list.append(OrderedDict({ 61 | 'inputs': inputs, 62 | 'grads': grads, 63 | 'outs': outs 64 | })) 65 | 66 | torch.save(data_list, out_file) 67 | 68 | 69 | -------------------------------------------------------------------------------- /TextHarmony/models/utils/ops/tests/speed_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import time 4 | import torch 5 | 6 | import copy 7 | 8 | from functions.ms_deform_attn_func import MSDeformAttnFunction 9 | 10 | from tests.create_data import generate_inputs 11 | from easydict import EasyDict as edict 12 | 13 | torch.manual_seed(0) 14 | 15 | def test_op(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step): 16 | return MSDeformAttnFunction.apply( 17 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step 18 | ) 19 | 20 | 21 | def to_dtype(inputs, dtype): 22 | new_inputs = list(copy.deepcopy(inputs)) 23 | for i, x in enumerate(new_inputs): 24 | if isinstance(x, torch.Tensor) and torch.is_floating_point(x): 25 | new_inputs[i] = x.to(dtype).detach() 26 | new_inputs[i].requires_grad = x.requires_grad 27 | return tuple(new_inputs) 28 | 29 | 30 | def run(module, args, name='Unknown'): 31 | inputs = generate_inputs(args.dtype, **args.data_args) 32 | 33 | # cudnn warmup 34 | for _ in range(50): 35 | if args.backward: 36 | module(*inputs).sum().backward() 37 | else: 38 | module(*inputs) 39 | 40 | torch.cuda.synchronize() 41 | t0 = time.time() 42 | 43 | for _ in range(args.num_iter): 44 | if args.backward: 45 | module(*inputs).sum().backward() 46 | else: 47 | module(*inputs) 48 | 49 | torch.cuda.synchronize() 50 | t1 = time.time() 51 | 52 | avg_time = (t1 - t0) * 1000 / args.num_iter 53 | print( 54 | f'>>> {name} finished {args.num_iter} running, avg_time: {avg_time:.6f} ms') 55 | return avg_time 56 | 57 | def info_memory(msg=None): 58 | if msg: 59 | print(msg) 60 | print(f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ 61 | Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ 62 | CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024),2)} GB \ 63 | Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024))} GB ") 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | 69 | data_args = edict() 70 | data_args.bs = 32 71 | data_args.n_levels = 2 72 | data_args.shapes=[(16,16), (8,8)] 73 | data_args.n_query = 128 74 | data_args.n_points = 64 75 | data_args.n_heads = 8 76 | data_args.head_dim = 128 77 | 78 | args = edict() 79 | args.num_iter = 200 80 | args.backward = True 81 | args.dtype = torch.float16 82 | args.data_args = data_args 83 | 84 | run(test_op, args, name='fp16') 85 | info_memory() 86 | args.dtype = torch.float32 87 | run(test_op, args, name='fp32') 88 | info_memory() -------------------------------------------------------------------------------- /TextHarmony/scripts/download_hf_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import CLIPModel, CLIPProcessor 5 | from transformers import LlamaTokenizer, LlamaForCausalLM 6 | from diffusers import StableDiffusionPipeline 7 | 8 | version = 'lmsys/vicuna-13b-v1.3' 9 | path = os.path.join('./assets', version) 10 | 
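# Snapshots are written under ./assets/<org>/<model> (here ./assets/lmsys/vicuna-13b-v1.3),
# so later runs can load the tokenizer and weights from a local path instead of pulling them
# from the Hugging Face hub again (the ./assets layout is an assumption based on this script).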
os.makedirs(path, exist_ok=True) 11 | llm_tokenizer:LlamaTokenizer = LlamaTokenizer.from_pretrained(version) 12 | llm_tokenizer.save_pretrained(path) 13 | llm_model = LlamaForCausalLM.from_pretrained(version, force_download=True, resume_download=False) 14 | llm_model.save_pretrained(path) 15 | 16 | version = "openai/clip-vit-large-patch14" 17 | clip_model = CLIPModel.from_pretrained(version) 18 | clip_processor = CLIPProcessor.from_pretrained(version) 19 | path = os.path.join('./assets', version) 20 | os.makedirs(path, exist_ok=True) 21 | clip_model.save_pretrained(path) 22 | clip_processor.save_pretrained(path) 23 | 24 | version = 'stabilityai/stable-diffusion-2-base' 25 | path = os.path.join('./assets', version) 26 | os.makedirs(path, exist_ok=True) 27 | pipe = StableDiffusionPipeline.from_pretrained(version, torch_dtype=torch.float32) 28 | pipe.save_pretrained(path) -------------------------------------------------------------------------------- /TextHarmony/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .parse_args import ArgumentParser, TrainingArguments 2 | from .misc import init_distributed_mode, load_model_weights 3 | from .caption_collect import collect_caption_result 4 | from .vqa_collect import collect_vqa_result 5 | -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/caption_collect.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/caption_collect.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/clip_sim_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/clip_sim_score.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/coco_cap_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/coco_cap_score.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/fid_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/fid_score.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/grounding_score.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/grounding_score.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/inception.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/inception.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/parse_args.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/parse_args.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/segm_eval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/segm_eval.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/visdial_metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/visdial_metrics.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/vqa_collect.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/vqa_collect.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/__pycache__/vqa_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/__pycache__/vqa_score.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/caption_collect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .misc import barrier, get_rank, get_world_size 5 | 6 | 7 | def collect_caption_result(result, result_dir, filename, remove_duplicate=''): 8 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, get_rank())) 9 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 10 | 11 | json.dump(result, open(result_file, 'w')) 12 | 13 | barrier() 14 | 15 | if get_rank() == 0: 16 | # combine results from all processes 17 | result = [] 18 | 19 | for rank in range(get_world_size()): 20 | result_file = os.path.join(result_dir, 
'%s_rank%d.json' % (filename, rank)) 21 | res = json.load(open(result_file, 'r')) 22 | result += res 23 | os.remove(result_file) 24 | 25 | if remove_duplicate: 26 | result_new = [] 27 | id_list = set() 28 | for res in result: 29 | if res[remove_duplicate] not in id_list: 30 | id_list.add(res[remove_duplicate]) 31 | result_new.append(res) 32 | result = result_new 33 | 34 | json.dump(result, open(final_result_file, 'w')) 35 | print('result file saved to %s' % final_result_file) 36 | 37 | return final_result_file -------------------------------------------------------------------------------- /TextHarmony/utils/coco_cap_score.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pycocoevalcap.eval import COCOEvalCap 3 | import os 4 | import json 5 | 6 | 7 | def coco_caption_eval( 8 | annotation_file, 9 | results_file, 10 | phase="test", 11 | use_1st_sentence_only=False, 12 | ): 13 | 14 | # we use the test dataset as the evaluation 15 | annotation_file = annotation_file.replace( 16 | f"coco_karpathy_{phase}.json", f"coco_karpathy_{phase}_gt.json" 17 | ) 18 | # create coco object and coco_result object 19 | coco = COCO(annotation_file) 20 | 21 | with open(results_file) as f: 22 | anns = json.load(f) 23 | if use_1st_sentence_only: 24 | for ann in anns: 25 | ann["caption"] = ann["caption"].split(".")[0] 26 | coco_result = coco.loadRes(anns) 27 | 28 | # create coco_eval object by taking coco and coco_result 29 | coco_eval = COCOEvalCap(coco, coco_result) 30 | 31 | # evaluate on a subset of images by setting 32 | # coco_eval.params['image_id'] = coco_result.getImgIds() 33 | # please remove this line when evaluating the full validation set 34 | coco_eval.params["image_id"] = coco_result.getImgIds() 35 | 36 | try: 37 | # evaluate results 38 | # SPICE will take a few minutes the first time, but speeds up due to caching 39 | coco_eval.evaluate() 40 | except Exception as exp: 41 | print(exp) 42 | return {} 43 | 44 | # print output evaluation scores 45 | return coco_eval.eval 46 | -------------------------------------------------------------------------------- /TextHarmony/utils/grounding_score.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | from torchvision.ops.boxes import box_area 5 | 6 | def box_iou(boxes1, boxes2): 7 | area1 = box_area(boxes1) 8 | area2 = box_area(boxes2) 9 | 10 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 11 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 12 | 13 | wh = (rb - lt).clamp(min=0) # [N,M,2] 14 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 15 | 16 | union = area1[:, None] + area2 - inter 17 | 18 | iou = inter / union 19 | return iou, union 20 | 21 | def parse_box(box_str): 22 | PATTERN = re.compile(r'\((.*?)\)\((.*?)\)') 23 | predict_bbox = re.findall(PATTERN, box_str) 24 | 25 | try: 26 | if ',' not in predict_bbox[0][0] or ',' not in predict_bbox[0][1]: 27 | predict_bbox = (0., 0., 0., 0.) 28 | else: 29 | x1, y1 = [ 30 | float(tmp) for tmp in predict_bbox[0][0].split(',') 31 | ] 32 | x2, y2 = [ 33 | float(tmp) for tmp in predict_bbox[0][1].split(',') 34 | ] 35 | predict_bbox = (x1, y1, x2, y2) 36 | except: 37 | predict_bbox = (0., 0., 0., 0.) 
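    # Predicted boxes are expected as "(x1,y1)(x2,y2)" strings with coordinates on a 0-999 grid
    # (they are divided by 999 and rescaled by the image size in grounding_eval below); any
    # unparsable string falls back to the all-zero box and simply counts as a miss.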
38 | 39 | return predict_bbox 40 | 41 | def grounding_eval(results_file): 42 | results = json.load(open(results_file)) 43 | 44 | total_cnt = 0 45 | correct = 0 46 | for item in results: 47 | gt_box = item['gt_box'] 48 | pred_box = item['pred_box'] 49 | h = item['height'] 50 | w = item['width'] 51 | 52 | pred_box = parse_box(pred_box) 53 | pred_box = torch.tensor(pred_box, dtype=torch.float32).view(-1, 4) / 999 54 | pred_box[:, 0::2] *= w 55 | pred_box[:, 1::2] *= h 56 | 57 | gt_box = torch.tensor(gt_box, dtype=torch.float32).view(-1, 4) / 999 58 | gt_box[:, 0::2] *= w 59 | gt_box[:, 1::2] *= h 60 | 61 | iou, _ = box_iou(pred_box, gt_box) 62 | iou = iou.item() 63 | total_cnt += 1 64 | if iou >= 0.5: 65 | correct += 1 66 | 67 | return {'accuracy': correct / total_cnt} 68 | -------------------------------------------------------------------------------- /TextHarmony/utils/parse_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | from dataclasses import dataclass, field, fields 4 | 5 | from mmcv import Config 6 | import transformers 7 | from transformers.hf_argparser import HfArgumentParser, DataClass 8 | 9 | from .misc import is_main_process 10 | 11 | 12 | @dataclass 13 | class TrainingArguments(transformers.TrainingArguments): 14 | config_file: Optional[str] = field(default="./configs/debug.yaml") 15 | resume: Optional[bool] = field(default=True) 16 | 17 | output_dir: Optional[str] = field(default="./OUTPUT/debug") 18 | remove_unused_columns: Optional[bool] = field( 19 | default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} 20 | ) 21 | 22 | lr_for_random_params: Optional[float] = field(default=1e-3) 23 | random_params: Optional[str] = field(default=None) 24 | lr_for_random_params_list: Optional[List[str]]= field(default_factory=lambda: None) 25 | wd_for_random_params_list: Optional[List[str]]= field(default_factory=lambda: None) 26 | random_params_list: Optional[List[str]] = field(default_factory=lambda: None) 27 | 28 | generate_mode: Optional[str] = field(default="generate_texts") 29 | use_1st_sentence_only: Optional[bool] = field(default=False) 30 | 31 | 32 | class ArgumentParser(HfArgumentParser): 33 | def parse_args_with_config_file_into_dataclasses( 34 | self, 35 | args=None, 36 | return_remaining_strings=False, 37 | ) -> Tuple[DataClass, ...]: 38 | """ 39 | 1. parse system arguments 40 | 2. load yaml config file 41 | 3. merge arguments from 2. into 1., 42 | note that if there exists same arguments in both 2. and 1., 43 | then the arguments in 1. will be overwritten by that in 2. 44 | 4. split into different dataclasses 45 | """ 46 | namespace, remaining_args = self.parse_known_args(args=args) 47 | config_file = getattr(namespace, "config_file", "./configs/debug.yaml") 48 | config_args = Config.fromfile(config_file) 49 | namespace.__dict__.update(config_args) 50 | if is_main_process(): 51 | Config.dump(Config(namespace.__dict__), file=os.path.join(namespace.output_dir, "config.yaml")) 52 | 53 | outputs = [] 54 | for dtype in self.dataclass_types: 55 | keys = {f.name for f in fields(dtype) if f.init} 56 | inputs = {k: v for k, v in vars(namespace).items() if k in keys} 57 | for k in keys: 58 | delattr(namespace, k) 59 | obj = dtype(**inputs) 60 | outputs.append(obj) 61 | if len(namespace.__dict__) > 0: 62 | # additional namespace. 
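            # i.e. config keys from the YAML file that do not map onto any dataclass field are
            # handed back as an extra namespace so the caller can still read them.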
63 | outputs.append(namespace) 64 | if return_remaining_strings: 65 | return (*outputs, remaining_args) 66 | else: 67 | if remaining_args: 68 | raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") 69 | 70 | return (*outputs,) 71 | -------------------------------------------------------------------------------- /TextHarmony/utils/segm_eval.py: -------------------------------------------------------------------------------- 1 | from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation 2 | from PIL import Image 3 | import numpy as np 4 | 5 | processor = None # OneFormerProcessor.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 6 | model = None # OneFormerForUniversalSegmentation.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 7 | 8 | 9 | def calculate_segm(image, gt_img): 10 | global processor 11 | global model 12 | if processor is None: 13 | processor = OneFormerProcessor.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 14 | if model is None: 15 | model = OneFormerForUniversalSegmentation.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 16 | 17 | semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt") 18 | semantic_outputs = model(**semantic_inputs) 19 | # pass through image_processor for postprocessing 20 | predicted_semantic_map = processor.post_process_semantic_segmentation(semantic_outputs, target_sizes=[gt_img.size[::-1]])[0] 21 | 22 | return predicted_semantic_map 23 | 24 | def intersectionAndUnion(imPred, imLab, numClass): 25 | imPred = np.asarray(imPred).copy() 26 | imLab = np.asarray(imLab).copy() 27 | 28 | # imPred += 1 29 | # imLab += 1 30 | # Remove classes from unlabeled pixels in gt image. 31 | # We should not penalize detections in unlabeled portions of the image. 
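    # Zeroing predictions wherever the ground-truth label is 0 drops those pixels from the
    # histograms below, since the bins only cover class ids 1..numClass.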
32 | imPred = imPred * (imLab > 0) 33 | 34 | # Compute area intersection: 35 | intersection = imPred * (imPred == imLab) 36 | (area_intersection, _) = np.histogram( 37 | intersection, bins=numClass, range=(1, numClass)) 38 | 39 | # Compute area union: 40 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 41 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 42 | area_union = area_pred + area_lab - area_intersection 43 | 44 | return (area_intersection, area_union) 45 | 46 | 47 | def calculate_miou_given_paths(paths, num_classes=150): 48 | 49 | all_intersection = None 50 | all_union = None 51 | 52 | for path1, path2 in zip(*paths): 53 | seg_label = np.array(Image.open(path1)) 54 | pred = np.array(Image.open(path2)) + 1 55 | 56 | intersection, union = intersectionAndUnion(pred, seg_label, num_classes) 57 | all_intersection = intersection if all_intersection is None else all_intersection + intersection 58 | all_union = union if all_union is None else all_union + union 59 | 60 | iou = all_intersection / (all_union + 1e-10) 61 | 62 | miou = iou.mean() 63 | 64 | return miou 65 | 66 | -------------------------------------------------------------------------------- /TextHarmony/utils/vizwiz_metrics_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vizwiz_metrics_src/__init__.py -------------------------------------------------------------------------------- /TextHarmony/utils/vizwiz_metrics_src/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vizwiz_metrics_src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/vizwiz_metrics_src/__pycache__/vqa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vizwiz_metrics_src/__pycache__/vqa.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/vizwiz_metrics_src/__pycache__/vqaEval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vizwiz_metrics_src/__pycache__/vqaEval.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/vizwiz_metrics_src/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'QingLi' 2 | __version__ = '1.0' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Qing Li for VizWiz Python API available at the following link: 7 | # (https://github.com/xxx) 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 
14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, annotation=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.imgToQA = {} 34 | if annotation is not None or annotation_file is not None: 35 | print('loading dataset into memory...') 36 | time_t = datetime.datetime.utcnow() 37 | dataset = json.load(open(annotation_file, 'r')) if annotation is None else annotation 38 | print(datetime.datetime.utcnow() - time_t) 39 | self.dataset = dataset 40 | self.imgToQA = {x['image']: x for x in dataset} 41 | 42 | def getImgs(self): 43 | return list(self.imgToQA.keys()) 44 | 45 | def getAnns(self, imgs=[], ansTypes=[]): 46 | """ 47 | Get annotations that satisfy given filter conditions. default skips that filter 48 | :param imgs (str array): get annotations for given image names 49 | ansTypes (str array) : get annotations for given answer types 50 | :return: annotations (dict array) : dict array of annotations 51 | """ 52 | anns = self.dataset 53 | 54 | imgs = imgs if type(imgs) == list else [imgs] 55 | if len(imgs) != 0: 56 | anns = [self.imgToQA[img] for img in imgs] 57 | 58 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 59 | if len(ansTypes) != 0: 60 | anns = [ann for ann in anns if ann['answer_type'] in ansTypes] 61 | return anns 62 | 63 | def showQA(self, anns): 64 | """ 65 | Display the specified annotations. 
66 | :param anns (array of object): annotations to display 67 | :return: None 68 | """ 69 | if len(anns) == 0: 70 | return 0 71 | for ann in anns: 72 | print("Question: %s" % ann['question']) 73 | print("Answer: ") 74 | print('\n'.join([x['answer'] for x in ann['answers']])) 75 | -------------------------------------------------------------------------------- /TextHarmony/utils/vqa_collect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .misc import barrier, get_rank, get_world_size 5 | 6 | 7 | def collect_vqa_result(result, result_dir, filename, is_vizwiz=False): 8 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, get_rank())) 9 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 10 | 11 | for item in result: 12 | image_id = item.pop("image_id") 13 | answer = item.pop("caption") 14 | 15 | if is_vizwiz: 16 | item['image'] = f'VizWiz_val_{image_id:08d}.jpg' 17 | else: 18 | item['question_id'] = image_id 19 | item['answer'] = answer 20 | 21 | json.dump(result, open(result_file, 'w')) 22 | 23 | barrier() 24 | if get_rank() == 0: 25 | # combine results from all processes 26 | result = [] 27 | 28 | for rank in range(get_world_size()): 29 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank)) 30 | res = json.load(open(result_file, 'r')) 31 | result += res 32 | os.remove(result_file) 33 | 34 | json.dump(result, open(final_result_file, 'w')) 35 | print('result file saved to %s' % final_result_file) 36 | 37 | return final_result_file -------------------------------------------------------------------------------- /TextHarmony/utils/vqa_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .vqav2_metrics_src.vqa import VQA as VQAV2_VQA 5 | from .vqav2_metrics_src.vqaEval import VQAEval as VQAV2_VQAEval 6 | from .vizwiz_metrics_src.vqa import VQA as Vizwiz_VQA 7 | from .vizwiz_metrics_src.vqaEval import VQAEval as Vizwiz_VQAEval 8 | 9 | def extract_answer(response): 10 | response = response.replace('\"', '') 11 | # response = response.strip().split('.')[0].split(',')[0].split('!')[0].lower() 12 | response = response.strip().split('\n')[0].split('.')[0].split(',')[0].split('!')[0].lower() 13 | 14 | if 'is ' in response: 15 | response = response.split('is ')[1] 16 | if 'are ' in response: 17 | response = response.split('are ')[1] 18 | if 'a ' in response: 19 | response = response.split('a ')[1] 20 | if 'an ' in response: 21 | response = response.split('an ')[1] 22 | if 'the ' in response: 23 | response = response.split('the ')[1] 24 | if ' of' in response: 25 | response = response.split(' of')[0] 26 | 27 | if ' or ' in response: 28 | response = response.split(' or ')[0] 29 | if ' and ' in response: 30 | response = response.split(' and ')[0] 31 | 32 | return response.strip() 33 | 34 | def vqa_eval( 35 | question_file, 36 | annotation_file, 37 | results_file, 38 | use_extract_answer=True, 39 | ): 40 | 41 | use_extract_answer=False # !!! 
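    # NOTE: the hard-coded line above disables answer post-processing, so the raw model
    # responses in the results file are scored as-is by the VQAv2 metric.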
42 | 43 | answers = json.load(open(results_file)) 44 | for item in answers: 45 | answer = item['answer'] 46 | 47 | if use_extract_answer: 48 | answer = extract_answer(answer) 49 | 50 | item['answer'] = answer 51 | 52 | if use_extract_answer: 53 | with open(results_file.replace('.json', '_processed.json'), 'w') as file: 54 | json.dump(answers, file) 55 | 56 | annotation_file = annotation_file 57 | question_file = question_file 58 | vqa = VQAV2_VQA(annotation_file, question_file) 59 | vqaRes = vqa.loadRes(answers, question_file) 60 | vqaEval = VQAV2_VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 61 | vqaEval.evaluate() 62 | 63 | return {'overall_accuracy': vqaEval.accuracy['overall']} 64 | 65 | def vizwiz_vqa_eval( 66 | annotation_file, 67 | results_file, 68 | use_extract_answer=True, 69 | ): 70 | answers = json.load(open(results_file)) 71 | for item in answers: 72 | answer = item['answer'] 73 | 74 | if use_extract_answer: 75 | answer = extract_answer(answer) 76 | 77 | item['answer'] = answer 78 | 79 | if use_extract_answer: 80 | with open(results_file.replace('.json', '_processed.json'), 'w') as file: 81 | json.dump(answers, file) 82 | 83 | vqa = Vizwiz_VQA(annotation_file) 84 | vqaRes = Vizwiz_VQA(annotation=answers) 85 | vqaEval = Vizwiz_VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 86 | vqaEval.evaluate() 87 | 88 | res = {'overall_accuracy': vqaEval.accuracy['overall']} 89 | res.update(vqaEval.caption_metric.items()) 90 | return res 91 | -------------------------------------------------------------------------------- /TextHarmony/utils/vqav2_metrics_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vqav2_metrics_src/__init__.py -------------------------------------------------------------------------------- /TextHarmony/utils/vqav2_metrics_src/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vqav2_metrics_src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/vqav2_metrics_src/__pycache__/vqa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vqav2_metrics_src/__pycache__/vqa.cpython-39.pyc -------------------------------------------------------------------------------- /TextHarmony/utils/vqav2_metrics_src/__pycache__/vqaEval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/TextHarmony/utils/vqav2_metrics_src/__pycache__/vqaEval.cpython-39.pyc -------------------------------------------------------------------------------- /docs/examples/all_zero.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/docs/examples/all_zero.jpg -------------------------------------------------------------------------------- 
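A minimal sketch of how the VQA scoring helpers above might be driven once inference has finished; the question, annotation, and result paths below are illustrative placeholders, not files shipped with this repository:

import json
from TextHarmony.utils.vqa_score import vqa_eval

# results_file is a JSON list of {"question_id": ..., "answer": ...} records,
# i.e. the format produced by collect_vqa_result in vqa_collect.py above.
results_file = "./OUTPUT/eval/vqav2_val.json"                         # placeholder path
scores = vqa_eval(
    question_file="./annotations/vqav2_val_questions.json",           # placeholder path
    annotation_file="./annotations/vqav2_val_annotations.json",       # placeholder path
    results_file=results_file,
)
print(json.dumps(scores, indent=2))  # e.g. {"overall_accuracy": ...}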
/docs/examples/book.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/docs/examples/book.jpg -------------------------------------------------------------------------------- /docs/examples/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "sentences": [ 4 | "Generate an image according to the caption. A cake of \"Good Time\". Outdoor, bright, sunny, happiness.", 5 | "" 6 | ], 7 | "images": [ 8 | "TextHarmony/docs/examples/all_zero.jpg", 9 | "TextHarmony/docs/examples/all_zero.jpg" 10 | ], 11 | "sentence_ixs": [0, 1], 12 | "image_first": [true, true], 13 | "generate_mode": "generate_images", 14 | "num_iter": 1 15 | }, 16 | 17 | { 18 | "sentences": [ 19 | "Generate an image according to the caption. A T-shirt of \"Keep Focused\". Outdoor, bright, sunny, happiness, 24k.", 20 | "" 21 | ], 22 | "images": [ 23 | "TextHarmony/docs/examples/all_zero.jpg", 24 | "TextHarmony/docs/examples/all_zero.jpg" 25 | ], 26 | "sentence_ixs": [0, 1], 27 | "image_first": [true, true], 28 | "generate_mode": "generate_images", 29 | "num_iter": 1 30 | }, 31 | 32 | { 33 | "sentences": [ 34 | "Generate an image according to the caption. Photo of A book cover of \"Summer Love\". Outdoor, bright, sunny, happiness, 24k.", 35 | "" 36 | ], 37 | "images": [ 38 | "TextHarmony/docs/examples/all_zero.jpg", 39 | "TextHarmony/docs/examples/all_zero.jpg" 40 | ], 41 | "sentence_ixs": [0, 1], 42 | "image_first": [true, true], 43 | "generate_mode": "generate_images", 44 | "num_iter": 1 45 | }, 46 | 47 | { 48 | "sentences": [ 49 | "Who wrote this book?" 50 | ], 51 | "images": [ 52 | "TextHarmony/docs/examples/book.jpg" 53 | ], 54 | "sentence_ixs": [0], 55 | "image_first": [true], 56 | "generate_mode": "generate_texts", 57 | "num_iter": 1 58 | } 59 | ] -------------------------------------------------------------------------------- /image_eval/__pycache__/anytext_singleGPU.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/anytext_singleGPU.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/__pycache__/anytext_singleGPU.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/anytext_singleGPU.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/__pycache__/dataset_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/dataset_util.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/__pycache__/dataset_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/dataset_util.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/__pycache__/t3_dataset.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/t3_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/__pycache__/t3_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/__pycache__/t3_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/anytext_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='models/anytext_v1.1.ckpt', 40 | help='path of model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./anytext_v1.1_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--json_path", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 58 | help="json path for evaluation dataset" 59 | ) 60 | args = parser.parse_args() 61 | return args 62 | 63 | 64 | if __name__ == "__main__": 65 | args = parse_args() 66 | ckpt_path = args.model_path 67 | gpus = args.gpus 68 | output_dir = args.output_dir 69 | json_path = args.json_path 70 | 71 | USING_DLC = False 72 | if USING_DLC: 73 | json_path = json_path.replace('/data/vdb', '/mnt/data', 1) 74 | output_dir = output_dir.replace('/data/vdb', '/mnt/data', 1) 75 | 76 | exec_path = './eval/anytext_singleGPU.py' 77 | continue_gen = True # if True, not clear output_dir, and generate rest images. 
78 | tmp_dir = './tmp_dir' 79 | if os.path.exists(tmp_dir): 80 | shutil.rmtree(tmp_dir) 81 | os.makedirs(tmp_dir) 82 | 83 | if not continue_gen: 84 | if os.path.exists(output_dir): 85 | shutil.rmtree(output_dir) 86 | os.makedirs(output_dir) 87 | else: 88 | if not os.path.exists(output_dir): 89 | os.makedirs(output_dir) 90 | 91 | os.system('sleep 1') 92 | 93 | gpu_ids = [int(i) for i in gpus.split(',')] 94 | nproc = len(gpu_ids) 95 | all_lines = load(json_path) 96 | split_file = [] 97 | length = len(all_lines['data_list']) // nproc 98 | cmds = [] 99 | for i in range(nproc): 100 | start, end = i*length, (i+1)*length 101 | if i == nproc - 1: 102 | end = len(all_lines['data_list']) 103 | temp_lines = copy.deepcopy(all_lines) 104 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 105 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 106 | save(temp_lines, tmp_file) 107 | os.system('sleep 1') 108 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --input_json {tmp_file} --output_dir {output_dir} --ckpt_path {ckpt_path} && echo proc-{i} done!'] 109 | cmds = ' & '.join(cmds) 110 | os.system(cmds) 111 | print('Done.') 112 | os.system('sleep 2') 113 | shutil.rmtree(tmp_dir) 114 | 115 | ''' 116 | command to kill the task after running: 117 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 118 | ''' 119 | -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/cldm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/cldm.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/ddim_hacked.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/ddim_hacked.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/ddim_hacked.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/ddim_hacked.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/embedding_manager.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/embedding_manager.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/recognizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/recognizer.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/cldm/__pycache__/recognizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/cldm/__pycache__/recognizer.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 
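        # log_images() returns tensors in [-1, 1]; log_local() rescales them to [0, 1]
        # before converting the grids to uint8 PNGs.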
64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /image_eval/cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | import sys 6 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 7 | from ldm.util import instantiate_from_config 8 | 9 | 10 | def get_state_dict(d): 11 | return d.get('state_dict', d) 12 | 13 | 14 | def load_state_dict(ckpt_path, location='cpu'): 15 | _, extension = os.path.splitext(ckpt_path) 16 | if extension.lower() == ".safetensors": 17 | import safetensors.torch 18 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 19 | else: 20 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 21 | state_dict = get_state_dict(state_dict) 22 | print(f'Loaded state_dict from [{ckpt_path}]') 23 | return state_dict 24 | 25 | 26 | def create_model(config_path, cond_stage_path=None, use_fp16=False): 27 | config = OmegaConf.load(config_path) 28 | if cond_stage_path: 29 | config.model.params.cond_stage_config.params.version = cond_stage_path # use pre-downloaded ckpts, in case blocked 30 | if use_fp16: 31 | config.model.params.use_fp16 = True 32 | config.model.params.control_stage_config.params.use_fp16 = True 33 | config.model.params.unet_config.params.use_fp16 = True 34 | model = instantiate_from_config(config.model).cpu() 35 | print(f'Loaded model config from [{config_path}]') 36 | return model 37 | -------------------------------------------------------------------------------- /image_eval/controlnet_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='/home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth', 40 | help='path of model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./controlnet_laion_generated/', 52 | 
help="output path" 53 | ) 54 | parser.add_argument( 55 | "--glyph_dir", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion', 58 | help="path of glyph images from anytext evaluation dataset" 59 | ) 60 | parser.add_argument( 61 | "--json_path", 62 | type=str, 63 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 64 | help="json path for evaluation dataset" 65 | ) 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | output_dir = args.output_dir 73 | 74 | tmp_dir = './tmp_dir' 75 | exec_path = './controlnet_singleGPU.py' 76 | continue_gen = True # if True, not clear output_dir, and generate rest images. 77 | 78 | if os.path.exists(tmp_dir): 79 | shutil.rmtree(tmp_dir) 80 | os.makedirs(tmp_dir) 81 | 82 | if not continue_gen: 83 | if os.path.exists(output_dir): 84 | shutil.rmtree(output_dir) 85 | os.makedirs(output_dir) 86 | else: 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | os.system('sleep 1') 91 | gpu_ids = [int(i) for i in args.gpus.split(',')] 92 | nproc = len(gpu_ids) 93 | all_lines = load(args.json_path) 94 | split_file = [] 95 | length = len(all_lines['data_list']) // nproc 96 | cmds = [] 97 | for i in range(nproc): 98 | start, end = i*length, (i+1)*length 99 | if i == nproc - 1: 100 | end = len(all_lines['data_list']) 101 | temp_lines = copy.deepcopy(all_lines) 102 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 103 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 104 | save(temp_lines, tmp_file) 105 | os.system('sleep 1') 106 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!'] 107 | cmds = ' & '.join(cmds) 108 | os.system(cmds) 109 | print('Done.') 110 | os.system('sleep 2') 111 | shutil.rmtree(tmp_dir) 112 | 113 | 114 | ''' 115 | command to kill the task after running: 116 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 117 | ''' 118 | -------------------------------------------------------------------------------- /image_eval/dataset_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | __all__ = ['load', 'save', 'show_bbox_on_image'] 5 | 6 | 7 | def load(file_path: str): 8 | file_path = pathlib.Path(file_path) 9 | func_dict = {'.txt': load_txt, '.json': load_json, '.list': load_txt} 10 | assert file_path.suffix in func_dict 11 | return func_dict[file_path.suffix](file_path) 12 | 13 | 14 | def load_txt(file_path: str): 15 | with open(file_path, 'r', encoding='utf8') as f: 16 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()] 17 | return content 18 | 19 | 20 | def load_json(file_path: str): 21 | with open(file_path, 'r', encoding='utf8') as f: 22 | content = json.load(f) 23 | return content 24 | 25 | 26 | def save(data, file_path): 27 | file_path = pathlib.Path(file_path) 28 | func_dict = {'.txt': save_txt, '.json': save_json} 29 | assert file_path.suffix in func_dict 30 | return func_dict[file_path.suffix](data, file_path) 31 | 32 | 33 | def save_txt(data, file_path): 34 | if not isinstance(data, list): 35 | data = [data] 36 | with open(file_path, mode='w', encoding='utf8') as f: 37 | f.write('\n'.join(data)) 38 | 39 | 40 | def save_json(data, file_path): 41 | with 
open(file_path, 'w', encoding='utf-8') as json_file: 42 | json.dump(data, json_file, ensure_ascii=False, indent=4) 43 | 44 | 45 | def show_bbox_on_image(image, polygons=None, txt=None, color=None, font_path='./font/Arial_Unicode.ttf'): 46 | from PIL import ImageDraw, ImageFont 47 | image = image.convert('RGB') 48 | draw = ImageDraw.Draw(image) 49 | if len(txt) == 0: 50 | txt = None 51 | if color is None: 52 | color = (255, 0, 0) 53 | if txt is not None: 54 | font = ImageFont.truetype(font_path, 20) 55 | for i, box in enumerate(polygons): 56 | box = box[0] 57 | if txt is not None: 58 | draw.text((int(box[0][0]) + 20, int(box[0][1]) - 20), str(txt[i]), fill='red', font=font) 59 | for j in range(len(box) - 1): 60 | draw.line((box[j][0], box[j][1], box[j + 1][0], box[j + 1][1]), fill=color, width=2) 61 | draw.line((box[-1][0], box[-1][1], box[0][0], box[0][1]), fill=color, width=2) 62 | return image 63 | 64 | 65 | def show_glyphs(glyphs, name): 66 | import numpy as np 67 | import cv2 68 | size = 64 69 | gap = 5 70 | n_char = 20 71 | canvas = np.ones((size, size*n_char + gap*(n_char-1), 1))*0.5 72 | x = 0 73 | for i in range(glyphs.shape[-1]): 74 | canvas[:, x:x + size, :] = glyphs[..., i:i+1] 75 | x += size+gap 76 | cv2.imwrite(name, canvas*255) 77 | -------------------------------------------------------------------------------- /image_eval/eval_fid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m pytorch_fid \ 3 | /ATA_Benchmark/image_benchmark/AnyText_Benchmark/AnyText-benchmark/benchmark/laion_word/generate_imgs/gen_img_202404261313 \ 4 | /ATA_Benchmark/image_benchmark/AnyText_Benchmark/AnyText-benchmark/benchmark/laion_word/generate_imgs/gen_img_202404261313 -------------------------------------------------------------------------------- /image_eval/eval_ocr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python eval/eval_dgocr.py \ 4 | --img_dir /data/vdc/yuxiang.tyx/AIGC/anytext_eval_imgs/controlnet_wukong_generated \ 5 | --input_json /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json -------------------------------------------------------------------------------- /image_eval/gen_glyph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval/render_glyph_imgs.py \ 3 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \ 4 | --output_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion 5 | -------------------------------------------------------------------------------- /image_eval/gen_imgs_anytext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval/anytext_multiGPUs.py \ 3 | --model_path models/anytext_v1.1.ckpt \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \ 5 | --output_dir ./anytext_laion_generated \ 6 | --gpus 0,1,2,3,4,5,6,7 7 | -------------------------------------------------------------------------------- /image_eval/gen_imgs_controlnet_canny.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python controlnet_multiGPUs.py \ 3 | --model_path /home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \ 6 | --output_dir 
./controlnet_wukong_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /image_eval/gen_imgs_glyphcontrol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python glyphcontrol_multiGPUs.py \ 3 | --model_path checkpoints/laion10M_epoch_6_model_ema_only.ckpt \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion \ 6 | --output_dir ./glyphcontrol_laion_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /image_eval/gen_imgs_textdiffuser.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python textdiffuser_multiGPUs.py \ 3 | --model_path textdiffuser-ckpt/diffusion_backbone \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \ 6 | --output_dir ./textdiffuser_wukong_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /image_eval/ldm/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/data/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /image_eval/ldm/models/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/models/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 
| x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /image_eval/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /image_eval/ldm/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc -------------------------------------------------------------------------------- 
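Note on the `append_dims` / `norm_thresholding` helpers in `image_eval/ldm/models/diffusion/sampling_util.py` above: `append_dims` pads trailing singleton dimensions so a per-sample statistic broadcasts against the full latent tensor. A minimal standalone sketch of the same pattern (the shapes and the 1.0 threshold are illustrative, not taken from the repo):

import torch

x0 = torch.randn(2, 3, 4, 4)                             # toy batch of 2 latents
s = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=1.0)   # per-sample RMS norm, clamped from below
s = s[(...,) + (None,) * (x0.ndim - s.ndim)]             # append_dims: (2,) -> (2, 1, 1, 1)
x0_thr = x0 * (1.0 / s)                                  # samples whose RMS exceeded 1.0 get rescaled
print(x0_thr.shape)                                      # torch.Size([2, 3, 4, 4])
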
/image_eval/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = 
torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | 
for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- 
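The `LitEma` module above (`image_eval/ldm/modules/ema.py`) keeps exponential-moving-average copies of all trainable parameters as buffers and exposes `forward`, `store`, `copy_to` and `restore`. A minimal usage sketch, assuming the package is importable as `ldm.modules.ema` in line with the repo's other imports (the toy model and decay value are illustrative):

import torch
from torch import nn
from ldm.modules.ema import LitEma   # import path assumed from the image_eval/ldm layout

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

# one toy training step; the optimizer is omitted for brevity
model(torch.randn(2, 4)).sum().backward()
ema(model)                        # update the shadow EMA weights after each optimizer step

# evaluate with EMA weights, then put the raw training weights back
ema.store(model.parameters())     # stash the current parameters
ema.copy_to(model)                # load the EMA weights into the model
# ... run validation here ...
ema.restore(model.parameters())   # restore the stashed parameters
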
/image_eval/ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /image_eval/ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | 
path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /image_eval/ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /image_eval/ocr_recog/RecCTCHead.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class CTCHead(nn.Module): 5 | def __init__(self, 6 | in_channels, 7 | out_channels=6625, 8 | fc_decay=0.0004, 9 | mid_channels=None, 10 | return_feats=False, 11 | **kwargs): 12 | super(CTCHead, self).__init__() 13 | if mid_channels is None: 14 | self.fc = nn.Linear( 15 | in_channels, 16 | out_channels, 17 | bias=True,) 18 | else: 19 | self.fc1 = nn.Linear( 20 | in_channels, 21 | mid_channels, 22 | bias=True, 23 | ) 24 | self.fc2 = nn.Linear( 25 | mid_channels, 26 | out_channels, 27 | bias=True, 28 | ) 29 | 30 | self.out_channels = out_channels 31 | self.mid_channels = mid_channels 32 | self.return_feats = return_feats 33 | 34 | def forward(self, x, labels=None): 35 | if self.mid_channels is None: 36 | predicts = self.fc(x) 37 | else: 38 | x = self.fc1(x) 39 | predicts = self.fc2(x) 40 | 41 | if self.return_feats: 42 | result = dict() 43 | result['ctc'] = predicts 44 | result['ctc_neck'] = x 45 | else: 46 | result = predicts 47 | 48 | return result 49 | -------------------------------------------------------------------------------- /image_eval/ocr_recog/RecModel.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .RNN import SequenceEncoder, Im2Seq, Im2Im 3 | from .RecMv1_enhance import MobileNetV1Enhance 4 | 5 | from .RecCTCHead import CTCHead 6 | 7 | backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance} 8 | neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} 9 | head_dict = {'CTCHead':CTCHead} 10 | 11 | 12 | class RecModel(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | assert 'in_channels' in config, 'in_channels must in model config' 16 | backbone_type = config.backbone.pop('type') 17 | assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}' 18 | self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) 19 | 20 | neck_type = config.neck.pop('type') 21 | assert neck_type in neck_dict, f'neck.type must in {neck_dict}' 22 | self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) 23 | 24 | head_type = config.head.pop('type') 25 | assert head_type in head_dict, f'head.type must in {head_dict}' 26 | self.head = head_dict[head_type](self.neck.out_channels, **config.head) 27 | 28 | self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' 29 | 30 | def load_3rd_state_dict(self, _3rd_name, _state): 31 | self.backbone.load_3rd_state_dict(_3rd_name, _state) 32 | 
self.neck.load_3rd_state_dict(_3rd_name, _state) 33 | self.head.load_3rd_state_dict(_3rd_name, _state) 34 | 35 | def forward(self, x): 36 | x = self.backbone(x) 37 | x = self.neck(x) 38 | x = self.head(x) 39 | return x 40 | 41 | def encode(self, x): 42 | x = self.backbone(x) 43 | x = self.neck(x) 44 | x = self.head.ctc_encoder(x) 45 | return x 46 | -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RNN.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RNN.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RNN.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecCTCHead.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecCTCHead.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecCTCHead.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecCTCHead.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecModel.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecModel.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecModel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecModel.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecMv1_enhance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecMv1_enhance.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecMv1_enhance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecMv1_enhance.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecSVTR.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecSVTR.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/RecSVTR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/RecSVTR.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_recog/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /image_eval/ocr_recog/common.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Hswish(nn.Module): 9 | def __init__(self, inplace=True): 10 | super(Hswish, self).__init__() 11 | self.inplace = inplace 12 | 13 | def forward(self, x): 14 | return x * F.relu6(x + 3., inplace=self.inplace) / 6. 15 | 16 | # out = max(0, min(1, slop*x+offset)) 17 | # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) 18 | class Hsigmoid(nn.Module): 19 | def __init__(self, inplace=True): 20 | super(Hsigmoid, self).__init__() 21 | self.inplace = inplace 22 | 23 | def forward(self, x): 24 | # torch: F.relu6(x + 3., inplace=self.inplace) / 6. 25 | # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. 26 | return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. 
27 | 28 | class GELU(nn.Module): 29 | def __init__(self, inplace=True): 30 | super(GELU, self).__init__() 31 | self.inplace = inplace 32 | 33 | def forward(self, x): 34 | return torch.nn.functional.gelu(x) 35 | 36 | 37 | class Swish(nn.Module): 38 | def __init__(self, inplace=True): 39 | super(Swish, self).__init__() 40 | self.inplace = inplace 41 | 42 | def forward(self, x): 43 | if self.inplace: 44 | x.mul_(torch.sigmoid(x)) 45 | return x 46 | else: 47 | return x*torch.sigmoid(x) 48 | 49 | 50 | class Activation(nn.Module): 51 | def __init__(self, act_type, inplace=True): 52 | super(Activation, self).__init__() 53 | act_type = act_type.lower() 54 | if act_type == 'relu': 55 | self.act = nn.ReLU(inplace=inplace) 56 | elif act_type == 'relu6': 57 | self.act = nn.ReLU6(inplace=inplace) 58 | elif act_type == 'sigmoid': 59 | raise NotImplementedError 60 | elif act_type == 'hard_sigmoid': 61 | self.act = Hsigmoid(inplace) 62 | elif act_type == 'hard_swish': 63 | self.act = Hswish(inplace=inplace) 64 | elif act_type == 'leakyrelu': 65 | self.act = nn.LeakyReLU(inplace=inplace) 66 | elif act_type == 'gelu': 67 | self.act = GELU(inplace=inplace) 68 | elif act_type == 'swish': 69 | self.act = Swish(inplace=inplace) 70 | else: 71 | raise NotImplementedError 72 | 73 | def forward(self, inputs): 74 | return self.act(inputs) -------------------------------------------------------------------------------- /image_eval/ocr_recog/en_dict.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 11 | : 12 | ; 13 | < 14 | = 15 | > 16 | ? 17 | @ 18 | A 19 | B 20 | C 21 | D 22 | E 23 | F 24 | G 25 | H 26 | I 27 | J 28 | K 29 | L 30 | M 31 | N 32 | O 33 | P 34 | Q 35 | R 36 | S 37 | T 38 | U 39 | V 40 | W 41 | X 42 | Y 43 | Z 44 | [ 45 | \ 46 | ] 47 | ^ 48 | _ 49 | ` 50 | a 51 | b 52 | c 53 | d 54 | e 55 | f 56 | g 57 | h 58 | i 59 | j 60 | k 61 | l 62 | m 63 | n 64 | o 65 | p 66 | q 67 | r 68 | s 69 | t 70 | u 71 | v 72 | w 73 | x 74 | y 75 | z 76 | { 77 | | 78 | } 79 | ~ 80 | ! 81 | " 82 | # 83 | $ 84 | % 85 | & 86 | ' 87 | ( 88 | ) 89 | * 90 | + 91 | , 92 | - 93 | . 94 | / 95 | 96 | -------------------------------------------------------------------------------- /image_eval/ocr_weights/en_dict.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 11 | : 12 | ; 13 | < 14 | = 15 | > 16 | ? 17 | @ 18 | A 19 | B 20 | C 21 | D 22 | E 23 | F 24 | G 25 | H 26 | I 27 | J 28 | K 29 | L 30 | M 31 | N 32 | O 33 | P 34 | Q 35 | R 36 | S 37 | T 38 | U 39 | V 40 | W 41 | X 42 | Y 43 | Z 44 | [ 45 | \ 46 | ] 47 | ^ 48 | _ 49 | ` 50 | a 51 | b 52 | c 53 | d 54 | e 55 | f 56 | g 57 | h 58 | i 59 | j 60 | k 61 | l 62 | m 63 | n 64 | o 65 | p 66 | q 67 | r 68 | s 69 | t 70 | u 71 | v 72 | w 73 | x 74 | y 75 | z 76 | { 77 | | 78 | } 79 | ~ 80 | ! 81 | " 82 | # 83 | $ 84 | % 85 | & 86 | ' 87 | ( 88 | ) 89 | * 90 | + 91 | , 92 | - 93 | . 
94 | / 95 | 96 | -------------------------------------------------------------------------------- /image_eval/ocr_weights/ppv3_rec.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_weights/ppv3_rec.pth -------------------------------------------------------------------------------- /image_eval/ocr_weights/ppv3_rec_en.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/TextHarmony/f98ac34b206fcd9698e1d6fe2d471dfaa7c96582/image_eval/ocr_weights/ppv3_rec_en.pth -------------------------------------------------------------------------------- /image_eval/render_glyph_imgs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | from tqdm import tqdm 5 | import shutil 6 | import numpy as np 7 | import cv2 8 | from PIL import Image, ImageFont 9 | from torch.utils.data import DataLoader 10 | from dataset_util import show_bbox_on_image 11 | import argparse 12 | from t3_dataset import T3DataSet 13 | max_lines = 20 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--json_path", 20 | type=str, 21 | default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json', 22 | help="json path for evaluation dataset", 23 | ) 24 | parser.add_argument( 25 | "--output_dir", 26 | type=str, 27 | default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong', 28 | help="output path, clear the folder if exist", 29 | ) 30 | parser.add_argument( 31 | "--img_count", 32 | type=int, 33 | default=1000, 34 | help="image count", 35 | ) 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | if os.path.exists(args.output_dir): 43 | shutil.rmtree(args.output_dir) 44 | os.makedirs(args.output_dir) 45 | dataset = T3DataSet(args.json_path, for_show=True, max_lines=max_lines, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0) 46 | train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 47 | pbar = tqdm(total=args.img_count) 48 | for i, data in enumerate(train_loader): 49 | if i == args.img_count: 50 | break 51 | all_glyphs = [] 52 | for k, glyphs in enumerate(data['glyphs']): 53 | all_glyphs += [glyphs[0].numpy().astype(np.int32)*255] 54 | glyph_img = cv2.resize(255.0-np.sum(all_glyphs, axis=0), (512, 512)) 55 | cv2.imwrite(os.path.join(args.output_dir, data['img_name'][0]), glyph_img) 56 | pbar.update(1) 57 | pbar.close() 58 | -------------------------------------------------------------------------------- /image_eval/textdiffuser_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert 
file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='textdiffuser-ckpt/diffusion_backbone', 40 | help='path to model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./textdiffuser_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--glyph_dir", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion', 58 | help="path of glyph images from anytext evaluation dataset" 59 | ) 60 | parser.add_argument( 61 | "--json_path", 62 | type=str, 63 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 64 | help="json path for evaluation dataset" 65 | ) 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | output_dir = args.output_dir 73 | 74 | tmp_dir = './tmp_dir' 75 | exec_path = './textdiffuser_singleGPU.py' 76 | continue_gen = True # if True, not clear output_dir, and generate rest images. 77 | 78 | if os.path.exists(tmp_dir): 79 | shutil.rmtree(tmp_dir) 80 | os.makedirs(tmp_dir) 81 | 82 | if not continue_gen: 83 | if os.path.exists(output_dir): 84 | shutil.rmtree(output_dir) 85 | os.makedirs(output_dir) 86 | else: 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | os.system('sleep 1') 91 | gpu_ids = [int(i) for i in args.gpus.split(',')] 92 | nproc = len(gpu_ids) 93 | all_lines = load(args.json_path) 94 | split_file = [] 95 | length = len(all_lines['data_list']) // nproc 96 | cmds = [] 97 | for i in range(nproc): 98 | start, end = i*length, (i+1)*length 99 | if i == nproc - 1: 100 | end = len(all_lines['data_list']) 101 | temp_lines = copy.deepcopy(all_lines) 102 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 103 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 104 | save(temp_lines, tmp_file) 105 | os.system('sleep 1') 106 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!'] 107 | cmds = ' & '.join(cmds) 108 | os.system(cmds) 109 | print('Done.') 110 | os.system('sleep 2') 111 | shutil.rmtree(tmp_dir) 112 | 113 | 114 | ''' 115 | command to kill the task after running: 116 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 117 | ''' 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.21.0 3 | addict==2.4.0 4 | adjustText==1.1.1 5 | aiohttp==3.9.3 6 | aiosignal==1.3.1 7 | albumentations==1.3.1 8 | antlr4-python3-runtime==4.9.3 9 | async-timeout==4.0.3 10 | attr==0.3.2 11 | attrs==23.2.0 12 | braceexpand==0.1.7 13 | cachetools==5.3.2 14 | certifi==2024.2.2 15 | charset-normalizer==3.3.2 16 | click==8.1.7 17 | clip==0.2.0 18 | cmake==3.28.1 19 | contourpy==1.2.0 20 | cycler==0.12.1 21 | datasets==2.14.0 22 
| deepspeed==0.10.0 23 | diffusers==0.20.0 24 | dill==0.3.7 25 | einops==0.6.1 26 | fairscale==0.4.13 27 | filelock==3.13.4 28 | flash-attn==2.0.4 29 | fonttools==4.48.1 30 | frozenlist==1.4.1 31 | fsspec==2024.2.0 32 | ftfy==6.2.0 33 | google-auth==2.27.0 34 | google-auth-oauthlib==1.2.0 35 | grpcio==1.60.1 36 | hjson==3.1.0 37 | huggingface-hub==0.20.3 38 | idna==3.6 39 | imageio==2.33.1 40 | importlib-metadata==7.0.1 41 | importlib-resources==6.1.1 42 | Jinja2==3.1.3 43 | joblib==1.3.2 44 | kiwisolver==1.4.5 45 | lazy_loader==0.3 46 | lit==17.0.6 47 | Markdown==3.5.2 48 | MarkupSafe==2.1.5 49 | matplotlib==3.8.4 50 | mmcv-full==1.7.0 51 | mpmath==1.3.0 52 | multidict==6.0.5 53 | multiprocess==0.70.15 54 | MultiScaleDeformableAttention==1.0 55 | mypy-extensions==1.0.0 56 | networkx==3.2.1 57 | ninja==1.11.1 58 | nltk==3.8.1 59 | numpy==1.26.4 60 | nvidia-cublas-cu11==11.10.3.66 61 | nvidia-cuda-cupti-cu11==11.7.101 62 | nvidia-cuda-nvrtc-cu11==11.7.99 63 | nvidia-cuda-runtime-cu11==11.7.99 64 | nvidia-cudnn-cu11==8.5.0.96 65 | nvidia-cufft-cu11==10.9.0.58 66 | nvidia-curand-cu11==10.2.10.91 67 | nvidia-cusolver-cu11==11.4.0.1 68 | nvidia-cusparse-cu11==11.7.4.91 69 | nvidia-nccl-cu11==2.14.3 70 | nvidia-nvtx-cu11==11.7.91 71 | oauthlib==3.2.2 72 | omegaconf==2.3.0 73 | openai-clip==1.0.1 74 | opencv-python==4.9.0.80 75 | opencv-python-headless==4.9.0.80 76 | packaging==23.2 77 | palettable==3.3.3 78 | pandas==2.2.0 79 | peft==0.3.0 80 | Pillow==10.0.0 81 | platformdirs==4.2.0 82 | protobuf==4.23.4 83 | psutil==5.9.8 84 | py-cpuinfo==9.0.0 85 | pyarrow==12.0.1 86 | pyasn1==0.5.1 87 | pyasn1-modules==0.3.0 88 | pycocoevalcap==1.2 89 | pycocotools==2.0.6 90 | pydantic==1.10.14 91 | pyparsing==3.1.1 92 | pyre-extensions==0.0.29 93 | python-dateutil==2.8.2 94 | pytorch-fid==0.3.0 95 | pytz==2024.1 96 | PyYAML==6.0.1 97 | qudida==0.0.4 98 | regex==2023.12.25 99 | requests==2.31.0 100 | requests-oauthlib==1.3.1 101 | rsa==4.9 102 | safetensors==0.4.2 103 | scikit-image==0.22.0 104 | scikit-learn==1.3.1 105 | scipy==1.11.1 106 | sentencepiece==0.1.99 107 | sentry-sdk==1.43.0 108 | six==1.16.0 109 | sympy==1.12 110 | tensorboard==2.15.1 111 | tensorboard-data-server==0.7.2 112 | threadpoolctl==3.2.0 113 | tifffile==2024.1.30 114 | timm==0.9.2 115 | tokenizers==0.13.3 116 | torch==2.0.1+cu117 117 | torchaudio==2.0.2 118 | torchvision==0.15.2 119 | tqdm==4.66.1 120 | transformers==4.31.0 121 | triton==2.0.0 122 | typing-inspect==0.9.0 123 | typing_extensions==4.9.0 124 | tzdata==2023.4 125 | urllib3==2.2.0 126 | wcwidth==0.2.13 127 | webdataset==0.2.48 128 | Werkzeug==3.0.1 129 | xformers==0.0.20 130 | xxhash==3.4.1 131 | yapf==0.40.2 132 | yarl==1.9.4 133 | zipp==3.17.0 134 | --------------------------------------------------------------------------------
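A quick sanity check that the pinned versions above (notably torch==2.0.1+cu117, transformers==4.31.0 and diffusers==0.20.0) are the ones actually installed, sketched as a small Python snippet (which packages to verify is a judgment call, not something the repo prescribes):

import torch
import transformers
import diffusers

print(torch.__version__)          # expected: 2.0.1+cu117 (see requirements.txt)
print(transformers.__version__)   # expected: 4.31.0
print(diffusers.__version__)      # expected: 0.20.0
print(torch.cuda.is_available())  # most of the training/eval stack here assumes a CUDA device
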