├── .gitignore ├── LICENSE.txt ├── README.md ├── UserGuide.md ├── __init__.py ├── assets ├── images │ ├── girl.png │ ├── snake.png │ ├── test.jpg │ ├── test2.jpg │ └── test3.jpg ├── masks │ ├── test.png │ └── test2.png ├── materials │ ├── gr_infer_demo.jpg │ ├── gr_pre_demo.jpg │ ├── tasks.png │ └── teaser.jpg └── videos │ ├── test.mp4 │ └── test2.mp4 ├── benchmarks └── .gitkeep ├── models └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── requirements ├── annotator.txt └── framework.txt ├── run_vace_ltx.sh ├── run_vace_pipeline.sh ├── run_vace_preproccess.sh ├── run_vace_wan.sh ├── tests └── test_annotators.py └── vace ├── __init__.py ├── annotators ├── __init__.py ├── canvas.py ├── common.py ├── composition.py ├── depth.py ├── depth_anything_v2 │ ├── __init__.py │ ├── dinov2.py │ ├── dpt.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ └── util │ │ ├── __init__.py │ │ ├── blocks.py │ │ └── transform.py ├── dwpose │ ├── __init__.py │ ├── onnxdet.py │ ├── onnxpose.py │ ├── util.py │ └── wholebody.py ├── face.py ├── flow.py ├── frameref.py ├── gdino.py ├── gray.py ├── inpainting.py ├── layout.py ├── mask.py ├── maskaug.py ├── midas │ ├── __init__.py │ ├── api.py │ ├── base_model.py │ ├── blocks.py │ ├── dpt_depth.py │ ├── midas_net.py │ ├── midas_net_custom.py │ ├── transforms.py │ ├── utils.py │ └── vit.py ├── outpainting.py ├── pose.py ├── prompt_extend.py ├── ram.py ├── salient.py ├── sam.py ├── sam2.py ├── scribble.py ├── subject.py └── utils.py ├── configs ├── __init__.py ├── common_preproccess.py ├── composition_preprocess.py ├── image_preproccess.py ├── prompt_preprocess.py └── video_preproccess.py ├── gradios ├── __init__.py ├── vace_ltx_demo.py ├── vace_preprocess_demo.py └── vace_wan_demo.py ├── models ├── __init__.py ├── ltx │ ├── __init__.py │ ├── ltx_vace.py │ ├── models │ │ ├── __init__.py │ │ └── transformers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ └── transformer3d.py │ └── pipelines │ │ ├── __init__.py │ │ └── pipeline_ltx_video.py ├── utils │ ├── __init__.py │ └── preprocessor.py └── wan │ ├── __init__.py │ ├── configs │ ├── __init__.py │ ├── shared_config.py │ ├── wan_t2v_14B.py │ └── wan_t2v_1_3B.py │ ├── distributed │ ├── __init__.py │ └── xdit_context_parallel.py │ ├── modules │ ├── __init__.py │ └── model.py │ └── wan_vace.py ├── vace_ltx_inference.py ├── vace_pipeline.py ├── vace_preproccess.py └── vace_wan_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pth 3 | *.pt 4 | *.pkl 5 | *.ckpt 6 | *.DS_Store 7 | *__pycache__* 8 | *.cache* 9 | *.bin 10 | *.idea 11 | *.csv 12 | cache 13 | build 14 | dist 15 | dev 16 | vace.egg-info 17 | .readthedocs.yml 18 | *resources 19 | *.ipynb_checkpoints* 20 | *.vscode -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/__init__.py -------------------------------------------------------------------------------- /assets/images/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/images/girl.png -------------------------------------------------------------------------------- /assets/images/snake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/images/snake.png -------------------------------------------------------------------------------- /assets/images/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/images/test.jpg -------------------------------------------------------------------------------- /assets/images/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/images/test2.jpg -------------------------------------------------------------------------------- /assets/images/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/images/test3.jpg -------------------------------------------------------------------------------- /assets/masks/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/masks/test.png -------------------------------------------------------------------------------- /assets/masks/test2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/masks/test2.png -------------------------------------------------------------------------------- /assets/materials/gr_infer_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/materials/gr_infer_demo.jpg -------------------------------------------------------------------------------- /assets/materials/gr_pre_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/materials/gr_pre_demo.jpg -------------------------------------------------------------------------------- /assets/materials/tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/materials/tasks.png -------------------------------------------------------------------------------- /assets/materials/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/materials/teaser.jpg -------------------------------------------------------------------------------- /assets/videos/test.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/videos/test.mp4 -------------------------------------------------------------------------------- /assets/videos/test2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/assets/videos/test2.mp4 -------------------------------------------------------------------------------- /benchmarks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/benchmarks/.gitkeep -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/models/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vace" 7 | version = "1.1.0" 8 | description = "VACE: All-in-One Video Creation and Editing" 9 | authors = [ 10 | { name = "VACE Team", email = "wan.ai@alibabacloud.com" } 11 | ] 12 | requires-python = ">=3.10,<4.0" 13 | readme = "README.md" 14 | dependencies = [ 15 | "torch>=2.5.1", 16 | "torchvision>=0.20.1", 17 | "opencv-python>=4.9.0.80", 18 | "diffusers>=0.31.0", 19 | "transformers>=4.49.0", 20 | "tokenizers>=0.20.3", 21 | "accelerate>=1.1.1", 22 | "gradio>=5.0.0", 23 | "numpy>=1.23.5,<2", 24 | "tqdm", 25 | "imageio", 26 | "easydict", 27 | "ftfy", 28 | "dashscope", 29 | "imageio-ffmpeg", 30 | "flash_attn", 31 | "decord", 32 | "einops", 33 | "scikit-image", 34 | "scikit-learn", 35 | "pycocotools", 36 | "timm", 37 | "onnxruntime-gpu", 38 | "BeautifulSoup4" 39 | ] 40 | 41 | [project.optional-dependencies] 42 | ltx = [ 43 | "ltx-video@git+https://github.com/Lightricks/LTX-Video@ltx-video-0.9.1" 44 | ] 45 | wan = [ 46 | "wan@git+https://github.com/Wan-Video/Wan2.1" 47 | ] 48 | annotator = [ 49 | "insightface", 50 | "sam-2@git+https://github.com/facebookresearch/sam2.git", 51 | "segment-anything@git+https://github.com/facebookresearch/segment-anything.git", 52 | "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO.git", 53 | "ram@git+https://github.com/xinyu1205/recognize-anything.git", 54 | "raft@git+https://github.com/martin-chobanyan-sdc/RAFT.git" 55 | ] 56 | 57 | [project.urls] 58 | homepage = "https://ali-vilab.github.io/VACE-Page/" 59 | documentation = "https://ali-vilab.github.io/VACE-Page/" 60 | repository = "https://github.com/ali-vilab/VACE" 61 | hfmodel = "https://huggingface.co/collections/ali-vilab/vace-67eca186ff3e3564726aff38" 62 | msmodel = "https://modelscope.cn/collections/VACE-8fa5fcfd386e43" 63 | paper = "https://arxiv.org/abs/2503.07598" 64 | 65 | [tool.setuptools] 66 | packages = { find = {} } 67 | 68 | [tool.black] 69 | line-length = 88 70 | 71 | [tool.isort] 72 | profile = "black" 73 | 74 | [tool.mypy] 75 | strict = true -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/framework.txt -------------------------------------------------------------------------------- /requirements/annotator.txt: -------------------------------------------------------------------------------- 1 | insightface 2 | git+https://github.com/facebookresearch/sam2.git 3 | git+https://github.com/facebookresearch/segment-anything.git 4 | git+https://github.com/IDEA-Research/GroundingDINO.git 5 | git+https://github.com/xinyu1205/recognize-anything.git 6 | git+https://github.com/martin-chobanyan-sdc/RAFT.git -------------------------------------------------------------------------------- /requirements/framework.txt: -------------------------------------------------------------------------------- 1 | torch>=2.5.1 2 | torchvision>=0.20.1 3 | opencv-python>=4.9.0.80 4 | diffusers>=0.31.0 5 | transformers>=4.49.0 6 | tokenizers>=0.20.3 7 | accelerate>=1.1.1 8 | gradio>=5.0.0 9 | numpy>=1.23.5,<2 10 | tqdm 11 | imageio 12 | easydict 13 | ftfy 14 | dashscope 15 | imageio-ffmpeg 16 | flash_attn 17 | decord 18 | einops 19 | scikit-image 20 | scikit-learn 21 | pycocotools 22 | timm 23 | onnxruntime-gpu 24 | BeautifulSoup4 25 | #ltx-video@git+https://github.com/Lightricks/LTX-Video@ltx-video-0.9.1 26 | #wan@git+https://github.com/Wan-Video/Wan2.1 -------------------------------------------------------------------------------- /run_vace_pipeline.sh: -------------------------------------------------------------------------------- 1 | #------------------------ Pipeline ------------------------# 2 | # extension firstframe 3 | python vace/vace_pipeline.py --base wan --task frameref --mode firstframe --image "benchmarks/VACE-Benchmark/assets/examples/firstframe/ori_image_1.png" --prompt "纪实摄影风格,前景是一位中国越野爱好者坐在越野车上,手持车载电台正在进行通联。他五官清晰,表情专注,眼神坚定地望向前方。越野车停在户外,车身略显脏污,显示出经历过的艰难路况。镜头从车外缓缓拉近,最后定格在人物的面部特写上,展现出他的坚定与热情。中景到近景,动态镜头运镜。" 4 | 5 | # repainting inpainting 6 | python vace/vace_pipeline.py --base wan --task inpainting --mode salientmasktrack --maskaug_mode original_expand --maskaug_ratio 0.5 --video "benchmarks/VACE-Benchmark/assets/examples/inpainting/ori_video.mp4" --prompt "一只巨大的金色凤凰从繁华的城市上空展翅飞过,羽毛如火焰般璀璨,闪烁着温暖的光辉,翅膀雄伟地展开。凤凰高昂着头,目光炯炯,轻轻扇动翅膀,散发出淡淡的光芒。下方是熙熙攘攘的市中心,人群惊叹,车水马龙,红蓝两色的霓虹灯在夜空下闪烁。镜头俯视城市街道,捕捉这一壮丽的景象,营造出既神秘又辉煌的氛围。" 7 | 8 | # repainting outpainting 9 | python vace/vace_pipeline.py --base wan --task outpainting --direction 'up,down,left,right' --expand_ratio 0.3 --video "benchmarks/VACE-Benchmark/assets/examples/outpainting/ori_video.mp4" --prompt "赛博朋克风格,无人机俯瞰视角下的现代西安城墙,镜头穿过永宁门时泛起金色涟漪,城墙砖块化作数据流重组为唐代长安城。周围的街道上流动的人群和飞驰的机械交通工具交织在一起,现代与古代的交融,城墙上的灯光闪烁,形成时空隧道的效果。全息投影技术展现历史变迁,粒子重组特效细腻逼真。大远景逐渐过渡到特写,聚焦于城门特效。" 10 | 11 | # control depth 12 | python vace/vace_pipeline.py --base wan --task depth --video "benchmarks/VACE-Benchmark/assets/examples/depth/ori_video.mp4" --prompt "一群年轻人在天空之城拍摄集体照。画面中,一对年轻情侣手牵手,轻声细语,相视而笑,周围是飞翔的彩色热气球和闪烁的星星,营造出浪漫的氛围。天空中,暖阳透过飘浮的云朵,洒下斑驳的光影。镜头以近景特写开始,随着情侣间的亲密互动,缓缓拉远。" 13 | 14 | # control flow 15 | python vace/vace_pipeline.py --base wan --task flow --video "benchmarks/VACE-Benchmark/assets/examples/flow/ori_video.mp4" --prompt "纪实摄影风格,一颗鲜红的小番茄缓缓落入盛着牛奶的玻璃杯中,溅起晶莹的水花。画面以慢镜头捕捉这一瞬间,水花在空中绽放,形成美丽的弧线。玻璃杯中的牛奶纯白,番茄的鲜红与之形成鲜明对比。背景简洁,突出主体。近景特写,垂直俯视视角,展现细节之美。" 16 | 17 | # control gray 18 | python vace/vace_pipeline.py --base wan --task gray --video "benchmarks/VACE-Benchmark/assets/examples/gray/ori_video.mp4" --prompt "镜头缓缓向右平移,身穿淡黄色坎肩长裙的长发女孩面对镜头露出灿烂的漏齿微笑。她的长发随风轻扬,眼神明亮而充满活力。背景是秋天红色和黄色的树叶,阳光透过树叶的缝隙洒下斑驳光影,营造出温馨自然的氛围。画面风格清新自然,仿佛夏日午后的一抹清凉。中景人像,强调自然光效和细腻的皮肤质感。" 19 | 20 | # control pose 21 | python vace/vace_pipeline.py --base wan --task pose --video "benchmarks/VACE-Benchmark/assets/examples/pose/ori_video.mp4" --prompt "在一个热带的庆祝派对上,一家人围坐在椰子树下的长桌旁。桌上摆满了异国风味的美食。长辈们愉悦地交谈,年轻人兴奋地举杯碰撞,孩子们在沙滩上欢乐奔跑。背景中是湛蓝的海洋和明亮的阳光,营造出轻松的气氛。镜头以动态中景捕捉每个开心的瞬间,温暖的阳光映照着他们幸福的面庞。" 22 | 23 | # control scribble 24 | python vace/vace_pipeline.py --base wan --task scribble --video "benchmarks/VACE-Benchmark/assets/examples/scribble/ori_video.mp4" --prompt "画面中荧光色彩的无人机从极低空高速掠过超现实主义风格的西安古城墙,尘埃反射着阳光。镜头快速切换至城墙上的砖石特写,阳光温暖地洒落,勾勒出每一块砖块的细腻纹理。整体画质清晰华丽,运镜流畅如水。" 25 | 26 | # control layout 27 | python vace/vace_pipeline.py --base wan --task layout_track --mode bboxtrack --bbox '54,200,614,448' --maskaug_mode bbox_expand --maskaug_ratio 0.2 --label 'bird' --video "benchmarks/VACE-Benchmark/assets/examples/layout/ori_video.mp4" --prompt "视频展示了一只成鸟在树枝上的巢中喂养它的幼鸟。成鸟在喂食的过程中,幼鸟张开嘴巴等待食物。随后,成鸟飞走,幼鸟继续等待。成鸟再次飞回,带回食物喂养幼鸟。整个视频的拍摄角度固定,聚焦于巢穴和鸟类的互动,背景是模糊的绿色植被,强调了鸟类的自然行为和生态环境。" 28 | -------------------------------------------------------------------------------- /run_vace_preproccess.sh: -------------------------------------------------------------------------------- 1 | #------------------------ Gadio ------------------------# 2 | python vace/gradios/vace_preproccess_demo.py 3 | 4 | #------------------------ Video ------------------------# 5 | python vace/vace_preproccess.py --task depth --video assets/videos/test.mp4 6 | python vace/vace_preproccess.py --task flow --video assets/videos/test.mp4 7 | python vace/vace_preproccess.py --task gray --video assets/videos/test.mp4 8 | python vace/vace_preproccess.py --task pose --video assets/videos/test.mp4 9 | python vace/vace_preproccess.py --task scribble --video assets/videos/test.mp4 10 | python vace/vace_preproccess.py --task frameref --mode firstframe --image assets/images/test.jpg 11 | python vace/vace_preproccess.py --task frameref --mode lastframe --expand_num 55 --image assets/images/test.jpg 12 | python vace/vace_preproccess.py --task frameref --mode firstlastframe --image assets/images/test.jpg,assets/images/test2.jpg 13 | python vace/vace_preproccess.py --task clipref --mode firstclip --expand_num 66 --video assets/videos/test.mp4 14 | python vace/vace_preproccess.py --task clipref --mode lastclip --expand_num 55 --video assets/videos/test.mp4 15 | python vace/vace_preproccess.py --task clipref --mode firstlastclip --video assets/videos/test.mp4,assets/videos/test2.mp4 16 | python vace/vace_preproccess.py --task inpainting --mode salient --video assets/videos/test.mp4 17 | python vace/vace_preproccess.py --task inpainting --mode mask --mask assets/masks/test.png --video assets/videos/test.mp4 18 | python vace/vace_preproccess.py --task inpainting --mode bbox --bbox 50,50,550,700 --video assets/videos/test.mp4 19 | python vace/vace_preproccess.py --task inpainting --mode salientmasktrack --video assets/videos/test.mp4 20 | python vace/vace_preproccess.py --task inpainting --mode salientbboxtrack --video assets/videos/test.mp4 21 | python vace/vace_preproccess.py --task inpainting --mode masktrack --mask assets/masks/test.png --video assets/videos/test.mp4 22 | python vace/vace_preproccess.py --task inpainting --mode bboxtrack --bbox 50,50,550,700 --video assets/videos/test.mp4 23 | python vace/vace_preproccess.py --task inpainting --mode label --label cat --video assets/videos/test.mp4 24 | python vace/vace_preproccess.py --task inpainting --mode caption --caption 'boxing glove' --video assets/videos/test.mp4 25 | python vace/vace_preproccess.py --task outpainting --video assets/videos/test.mp4 26 | python vace/vace_preproccess.py --task outpainting --direction 'up,down,left,right' --expand_ratio 0.5 --video assets/videos/test.mp4 27 | python vace/vace_preproccess.py --task layout_bbox --bbox '50,50,550,700 500,150,750,700' --label 'person' 28 | python vace/vace_preproccess.py --task layout_track --mode masktrack --mask assets/masks/test.png --label 'cat' --video assets/videos/test.mp4 29 | python vace/vace_preproccess.py --task layout_track --mode bboxtrack --bbox '50,50,550,700' --label 'cat' --video assets/videos/test.mp4 30 | python vace/vace_preproccess.py --task layout_track --mode label --label 'cat' --maskaug_mode hull_expand --maskaug_ratio 0.1 --video assets/videos/test.mp4 31 | python vace/vace_preproccess.py --task layout_track --mode caption --caption 'boxing glove' --maskaug_mode bbox --video assets/videos/test.mp4 --label 'glove' 32 | 33 | #------------------------ Image ------------------------# 34 | python vace/vace_preproccess.py --task image_face --image assets/images/test3.jpg 35 | python vace/vace_preproccess.py --task image_salient --image assets/images/test.jpg 36 | python vace/vace_preproccess.py --task image_inpainting --mode 'salientbboxtrack' --image assets/images/test2.jpg 37 | python vace/vace_preproccess.py --task image_inpainting --mode 'salientmasktrack' --maskaug_mode hull_expand --maskaug_ratio 0.3 --image assets/images/test2.jpg 38 | python vace/vace_preproccess.py --task image_reference --mode plain --image assets/images/test.jpg 39 | python vace/vace_preproccess.py --task image_reference --mode salient --image assets/images/test.jpg 40 | python vace/vace_preproccess.py --task image_reference --mode mask --mask assets/masks/test2.png --image assets/images/test.jpg 41 | python vace/vace_preproccess.py --task image_reference --mode bbox --bbox 0,264,338,636 --image assets/images/test.jpg 42 | python vace/vace_preproccess.py --task image_reference --mode salientmasktrack --image assets/images/test.jpg # easyway, recommend 43 | python vace/vace_preproccess.py --task image_reference --mode salientbboxtrack --bbox 0,264,338,636 --maskaug_mode original_expand --maskaug_ratio 0.2 --image assets/images/test.jpg 44 | python vace/vace_preproccess.py --task image_reference --mode masktrack --mask assets/masks/test2.png --image assets/images/test.jpg 45 | python vace/vace_preproccess.py --task image_reference --mode bboxtrack --bbox 0,264,338,636 --image assets/images/test.jpg 46 | python vace/vace_preproccess.py --task image_reference --mode label --label 'cat' --image assets/images/test.jpg 47 | python vace/vace_preproccess.py --task image_reference --mode caption --caption 'flower' --maskaug_mode bbox --maskaug_ratio 0.3 --image assets/images/test.jpg 48 | 49 | #------------------------ Composition ------------------------# 50 | python vace/vace_preproccess.py --task reference_anything --mode salientmasktrack --image assets/images/test.jpg 51 | python vace/vace_preproccess.py --task reference_anything --mode salientbboxtrack --image assets/images/test.jpg,assets/images/test2.jpg 52 | python vace/vace_preproccess.py --task animate_anything --mode salientbboxtrack --video assets/videos/test.mp4 --image assets/images/test.jpg 53 | python vace/vace_preproccess.py --task swap_anything --mode salientmasktrack --video assets/videos/test.mp4 --image assets/images/test.jpg 54 | python vace/vace_preproccess.py --task swap_anything --mode label,salientbboxtrack --label 'cat' --maskaug_mode bbox --maskaug_ratio 0.3 --video assets/videos/test.mp4 --image assets/images/test.jpg 55 | python vace/vace_preproccess.py --task swap_anything --mode label,plain --label 'cat' --maskaug_mode bbox --maskaug_ratio 0.3 --video assets/videos/test.mp4 --image assets/images/test.jpg 56 | python vace/vace_preproccess.py --task expand_anything --mode salientbboxtrack --direction 'left,right' --expand_ratio 0.5 --expand_num 80 --image assets/images/test.jpg,assets/images/test2.jpg 57 | python vace/vace_preproccess.py --task expand_anything --mode firstframe,plain --direction 'left,right' --expand_ratio 0.5 --expand_num 80 --image assets/images/test.jpg,assets/images/test2.jpg 58 | python vace/vace_preproccess.py --task move_anything --bbox '0,264,338,636 400,264,538,636' --expand_num 80 --label 'cat' --image assets/images/test.jpg 59 | -------------------------------------------------------------------------------- /run_vace_wan.sh: -------------------------------------------------------------------------------- 1 | #------------------------ Gadio ------------------------# 2 | python vace/gradios/vace_wan_demo.py 3 | 4 | #------------------------ CLI ------------------------# 5 | # txt2vid txt2vid 6 | python vace/vace_wan_inference.py --prompt "狂风巨浪的大海,镜头缓缓推进,一艘渺小的帆船在汹涌的波涛中挣扎漂荡。海面上白沫翻滚,帆船时隐时现,仿佛随时可能被巨浪吞噬。天空乌云密布,雷声轰鸣,海鸥在空中盘旋尖叫。帆船上的人们紧紧抓住缆绳,努力保持平衡。画面风格写实,充满紧张和动感。近景特写,强调风浪的冲击力和帆船的摇晃" 7 | 8 | # extension firstframe 9 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/firstframe/src_video.mp4" --src_mask "benchmarks/VACE-Benchmark/assets/examples/firstframe/src_mask.mp4" --prompt "纪实摄影风格,前景是一位中国越野爱好者坐在越野车上,手持车载电台正在进行通联。他五官清晰,表情专注,眼神坚定地望向前方。越野车停在户外,车身略显脏污,显示出经历过的艰难路况。镜头从车外缓缓拉近,最后定格在人物的面部特写上,展现出他的,动态镜头运镜。" 10 | 11 | # repainting inpainting 12 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/inpainting/src_video.mp4" --src_mask "benchmarks/VACE-Benchmark/assets/examples/inpainting/src_mask.mp4" --prompt "一只巨大的金色凤凰从繁华的城市上空展翅飞过,羽毛如火焰般璀璨,闪烁着温暖的光辉,翅膀雄伟地展开。凤凰高昂着头,目光炯炯,轻轻扇动翅膀,散发出淡淡的光芒。下方是熙熙攘攘的市中心,人群惊叹,车水马龙,红蓝两色的霓虹灯在夜空下闪烁。镜头俯视城市街道,捕捉这一壮丽的景象,营造出既神秘又辉煌的氛围。" 13 | 14 | # repainting outpainting 15 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/outpainting/src_video.mp4" --src_mask "benchmarks/VACE-Benchmark/assets/examples/outpainting/src_mask.mp4" --prompt "赛博朋克风格,无人机俯瞰视角下的现代西安城墙,镜头穿过永宁门时泛起金色涟漪,城墙砖块化作数据流重组为唐代长安城。周围的街道上流动的人群和飞驰的机械交通工具交织在一起,现代与古代的交融,城墙上的灯光闪烁,形成时空隧道的效果。全息投影技术展现历史变迁,粒子重组特效细腻逼真。大远景逐渐过渡到特写,聚焦于城门特效。" 16 | 17 | # control depth 18 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/depth/src_video.mp4" --prompt "一群年轻人在天空之城拍摄集体照。画面中,一对年轻情侣手牵手,轻声细语,相视而笑,周围是飞翔的彩色热气球和闪烁的星星,营造出浪漫的氛围。天空中,暖阳透过飘浮的云朵,洒下斑驳的光影。镜头以近景特写开始,随着情侣间的亲密互动,缓缓拉远。" 19 | 20 | # control flow 21 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/flow/src_video.mp4" --prompt "纪实摄影风格,一颗鲜红的小番茄缓缓落入盛着牛奶的玻璃杯中,溅起晶莹的水花。画面以慢镜头捕捉这一瞬间,水花在空中绽放,形成美丽的弧线。玻璃杯中的牛奶纯白,番茄的鲜红与之形成鲜明对比。背景简洁,突出主体。近景特写,垂直俯视视角,展现细节之美。" 22 | 23 | # control gray 24 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/gray/src_video.mp4" --prompt "镜头缓缓向右平移,身穿淡黄色坎肩长裙的长发女孩面对镜头露出灿烂的漏齿微笑。她的长发随风轻扬,眼神明亮而充满活力。背景是秋天红色和黄色的树叶,阳光透过树叶的缝隙洒下斑驳光影,营造出温馨自然的氛围。画面风格清新自然,仿佛夏日午后的一抹清凉。中景人像,强调自然光效和细腻的皮肤质感。" 25 | 26 | # control pose 27 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/pose/src_video.mp4" --prompt "在一个热带的庆祝派对上,一家人围坐在椰子树下的长桌旁。桌上摆满了异国风味的美食。长辈们愉悦地交谈,年轻人兴奋地举杯碰撞,孩子们在沙滩上欢乐奔跑。背景中是湛蓝的海洋和明亮的阳光,营造出轻松的气氛。镜头以动态中景捕捉每个开心的瞬间,温暖的阳光映照着他们幸福的面庞。" 28 | 29 | # control scribble 30 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/scribble/src_video.mp4" --prompt "画面中荧光色彩的无人机从极低空高速掠过超现实主义风格的西安古城墙,尘埃反射着阳光。镜头快速切换至城墙上的砖石特写,阳光温暖地洒落,勾勒出每一块砖块的细腻纹理。整体画质清晰华丽,运镜流畅如水。" 31 | 32 | # control layout 33 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/layout/src_video.mp4" --prompt "视频展示了一只成鸟在树枝上的巢中喂养它的幼鸟。成鸟在喂食的过程中,幼鸟张开嘴巴等待食物。随后,成鸟飞走,幼鸟继续等待。成鸟再次飞回,带回食物喂养幼鸟。整个视频的拍摄角度固定,聚焦于巢穴和鸟类的互动,背景是模糊的绿色植被,强调了鸟类的自然行为和生态环境。" 34 | 35 | # reference face 36 | python vace/vace_wan_inference.py --src_ref_images "benchmarks/VACE-Benchmark/assets/examples/face/src_ref_image_1.png" --prompt "视频展示了一位长着尖耳朵的老人,他有一头银白色的长发和小胡子,穿着一件色彩斑斓的长袍,内搭金色衬衫,散发出神秘与智慧的气息。背景为一个华丽宫殿的内部,金碧辉煌。灯光明亮,照亮他脸上的神采奕奕。摄像机旋转动态拍摄,捕捉老人轻松挥手的动作。" 37 | 38 | # reference object 39 | python vace/vace_wan_inference.py --src_ref_images "benchmarks/VACE-Benchmark/assets/examples/object/src_ref_image_1.png" --prompt "经典游戏角色马里奥在绿松石色水下世界中,四周环绕着珊瑚和各种各样的热带鱼。马里奥兴奋地向上跳起,摆出经典的欢快姿势,身穿鲜明的蓝色潜水服,红色的潜水面罩上印有“M”标志,脚上是一双潜水靴。背景中,水泡随波逐流,浮现出一个巨大而友好的海星。摄像机从水底向上快速移动,捕捉他跃出水面的瞬间,灯光明亮而流动。该场景融合了动画与幻想元素,令人惊叹。" 40 | 41 | # composition reference_anything 42 | python vace/vace_wan_inference.py --src_ref_images "benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_1.png,benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_2.png" --prompt "一名打扮成超人的男子自信地站着,面对镜头,肩头有一只充满活力的毛绒黄色鸭子。他留着整齐的短发和浅色胡须,鸭子有橙色的喙和脚,它的翅膀稍微展开,脚分开以保持稳定。他的表情严肃而坚定。他穿着标志性的蓝红超人服装,胸前有黄色“S”标志。斗篷在他身后飘逸。背景有行人。相机位于视线水平,捕捉角色的整个上半身。灯光均匀明亮。" 43 | 44 | # composition swap_anything 45 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_video.mp4" --src_mask "benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_mask.mp4" --src_ref_images "benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_ref_image_1.png" --prompt "视频展示了一个人在宽阔的草原上骑马。他有淡紫色长发,穿着传统服饰白上衣黑裤子,动画建模画风,看起来像是在进行某种户外活动或者是在进行某种表演。背景是壮观的山脉云的天空,给人一种宁静而广阔的感觉。整个视频的拍摄角度是固定的,重点展示了骑手和他的马。" 46 | 47 | # composition expand_anything 48 | python vace/vace_wan_inference.py --src_video "benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_video.mp4" --src_mask "benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_mask.mp4" --src_ref_images "benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_ref_image_1.png" --prompt "古典油画风格,背景是一条河边,画面中央一位成熟优雅的女人,穿着长裙坐在椅子上。她双手从怀里取出打开的红色心形墨镜戴上。固定机位。" 49 | -------------------------------------------------------------------------------- /vace/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from . import annotators 4 | from . import configs 5 | from . import models 6 | from . import gradios -------------------------------------------------------------------------------- /vace/annotators/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .depth import DepthAnnotator, DepthVideoAnnotator, DepthV2VideoAnnotator 4 | from .flow import FlowAnnotator, FlowVisAnnotator 5 | from .frameref import FrameRefExtractAnnotator, FrameRefExpandAnnotator 6 | from .gdino import GDINOAnnotator, GDINORAMAnnotator 7 | from .gray import GrayAnnotator, GrayVideoAnnotator 8 | from .inpainting import InpaintingAnnotator, InpaintingVideoAnnotator 9 | from .layout import LayoutBboxAnnotator, LayoutMaskAnnotator, LayoutTrackAnnotator 10 | from .maskaug import MaskAugAnnotator 11 | from .outpainting import OutpaintingAnnotator, OutpaintingInnerAnnotator, OutpaintingVideoAnnotator, OutpaintingInnerVideoAnnotator 12 | from .pose import PoseBodyFaceAnnotator, PoseBodyFaceVideoAnnotator, PoseAnnotator, PoseBodyVideoAnnotator, PoseBodyAnnotator 13 | from .ram import RAMAnnotator 14 | from .salient import SalientAnnotator, SalientVideoAnnotator 15 | from .sam import SAMImageAnnotator 16 | from .sam2 import SAM2ImageAnnotator, SAM2VideoAnnotator, SAM2SalientVideoAnnotator, SAM2GDINOVideoAnnotator 17 | from .scribble import ScribbleAnnotator, ScribbleVideoAnnotator 18 | from .face import FaceAnnotator 19 | from .subject import SubjectAnnotator 20 | from .common import PlainImageAnnotator, PlainMaskAnnotator, PlainMaskAugAnnotator, PlainMaskVideoAnnotator, PlainVideoAnnotator, PlainMaskAugVideoAnnotator, PlainMaskAugInvertAnnotator, PlainMaskAugInvertVideoAnnotator, ExpandMaskVideoAnnotator 21 | from .prompt_extend import PromptExtendAnnotator 22 | from .composition import CompositionAnnotator, ReferenceAnythingAnnotator, AnimateAnythingAnnotator, SwapAnythingAnnotator, ExpandAnythingAnnotator, MoveAnythingAnnotator 23 | from .mask import MaskDrawAnnotator 24 | from .canvas import RegionCanvasAnnotator -------------------------------------------------------------------------------- /vace/annotators/canvas.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from .utils import convert_to_numpy 9 | 10 | 11 | class RegionCanvasAnnotator: 12 | def __init__(self, cfg, device=None): 13 | self.scale_range = cfg.get('SCALE_RANGE', [0.75, 1.0]) 14 | self.canvas_value = cfg.get('CANVAS_VALUE', 255) 15 | self.use_resize = cfg.get('USE_RESIZE', True) 16 | self.use_canvas = cfg.get('USE_CANVAS', True) 17 | self.use_aug = cfg.get('USE_AUG', False) 18 | if self.use_aug: 19 | from .maskaug import MaskAugAnnotator 20 | self.maskaug_anno = MaskAugAnnotator(cfg={}) 21 | 22 | def forward(self, image, mask, mask_cfg=None): 23 | 24 | image = convert_to_numpy(image) 25 | mask = convert_to_numpy(mask) 26 | image_h, image_w = image.shape[:2] 27 | 28 | if self.use_aug: 29 | mask = self.maskaug_anno.forward(mask, mask_cfg) 30 | 31 | # get region with white bg 32 | image[np.array(mask) == 0] = self.canvas_value 33 | x, y, w, h = cv2.boundingRect(mask) 34 | region_crop = image[y:y + h, x:x + w] 35 | 36 | if self.use_resize: 37 | # resize region 38 | scale_min, scale_max = self.scale_range 39 | scale_factor = random.uniform(scale_min, scale_max) 40 | new_w, new_h = int(image_w * scale_factor), int(image_h * scale_factor) 41 | obj_scale_factor = min(new_w/w, new_h/h) 42 | 43 | new_w = int(w * obj_scale_factor) 44 | new_h = int(h * obj_scale_factor) 45 | region_crop_resized = cv2.resize(region_crop, (new_w, new_h), interpolation=cv2.INTER_AREA) 46 | else: 47 | region_crop_resized = region_crop 48 | 49 | if self.use_canvas: 50 | # plot region into canvas 51 | new_canvas = np.ones_like(image) * self.canvas_value 52 | max_x = max(0, image_w - new_w) 53 | max_y = max(0, image_h - new_h) 54 | new_x = random.randint(0, max_x) 55 | new_y = random.randint(0, max_y) 56 | 57 | new_canvas[new_y:new_y + new_h, new_x:new_x + new_w] = region_crop_resized 58 | else: 59 | new_canvas = region_crop_resized 60 | return new_canvas -------------------------------------------------------------------------------- /vace/annotators/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | class PlainImageAnnotator: 5 | def __init__(self, cfg): 6 | pass 7 | def forward(self, image): 8 | return image 9 | 10 | class PlainVideoAnnotator: 11 | def __init__(self, cfg): 12 | pass 13 | def forward(self, frames): 14 | return frames 15 | 16 | class PlainMaskAnnotator: 17 | def __init__(self, cfg): 18 | pass 19 | def forward(self, mask): 20 | return mask 21 | 22 | class PlainMaskAugInvertAnnotator: 23 | def __init__(self, cfg): 24 | pass 25 | def forward(self, mask): 26 | return 255 - mask 27 | 28 | class PlainMaskAugAnnotator: 29 | def __init__(self, cfg): 30 | pass 31 | def forward(self, mask): 32 | return mask 33 | 34 | class PlainMaskVideoAnnotator: 35 | def __init__(self, cfg): 36 | pass 37 | def forward(self, mask): 38 | return mask 39 | 40 | class PlainMaskAugVideoAnnotator: 41 | def __init__(self, cfg): 42 | pass 43 | def forward(self, masks): 44 | return masks 45 | 46 | class PlainMaskAugInvertVideoAnnotator: 47 | def __init__(self, cfg): 48 | pass 49 | def forward(self, masks): 50 | return [255 - mask for mask in masks] 51 | 52 | class ExpandMaskVideoAnnotator: 53 | def __init__(self, cfg): 54 | pass 55 | def forward(self, mask, expand_num): 56 | return [mask] * expand_num 57 | 58 | class PlainPromptAnnotator: 59 | def __init__(self, cfg): 60 | pass 61 | def forward(self, prompt): 62 | return prompt -------------------------------------------------------------------------------- /vace/annotators/composition.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import numpy as np 4 | 5 | class CompositionAnnotator: 6 | def __init__(self, cfg): 7 | self.process_types = ["repaint", "extension", "control"] 8 | self.process_map = { 9 | "repaint": "repaint", 10 | "extension": "extension", 11 | "control": "control", 12 | "inpainting": "repaint", 13 | "outpainting": "repaint", 14 | "frameref": "extension", 15 | "clipref": "extension", 16 | "depth": "control", 17 | "flow": "control", 18 | "gray": "control", 19 | "pose": "control", 20 | "scribble": "control", 21 | "layout": "control" 22 | } 23 | 24 | def forward(self, process_type_1, process_type_2, frames_1, frames_2, masks_1, masks_2): 25 | total_frames = min(len(frames_1), len(frames_2), len(masks_1), len(masks_2)) 26 | combine_type = (self.process_map[process_type_1], self.process_map[process_type_2]) 27 | if combine_type in [("extension", "repaint"), ("extension", "control"), ("extension", "extension")]: 28 | output_video = [frames_2[i] * masks_1[i] + frames_1[i] * (1 - masks_1[i]) for i in range(total_frames)] 29 | output_mask = [masks_1[i] * masks_2[i] * 255 for i in range(total_frames)] 30 | elif combine_type in [("repaint", "extension"), ("control", "extension"), ("repaint", "repaint")]: 31 | output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)] 32 | output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)] 33 | elif combine_type in [("repaint", "control"), ("control", "repaint")]: 34 | if combine_type in [("control", "repaint")]: 35 | frames_1, frames_2, masks_1, masks_2 = frames_2, frames_1, masks_2, masks_1 36 | output_video = [frames_1[i] * (1 - masks_1[i]) + frames_2[i] * masks_1[i] for i in range(total_frames)] 37 | output_mask = [masks_1[i] * 255 for i in range(total_frames)] 38 | elif combine_type in [("control", "control")]: # apply masks_2 39 | output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)] 40 | output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)] 41 | else: 42 | raise Exception("Unknown combine type") 43 | return output_video, output_mask 44 | 45 | 46 | class ReferenceAnythingAnnotator: 47 | def __init__(self, cfg): 48 | from .subject import SubjectAnnotator 49 | self.sbjref_ins = SubjectAnnotator(cfg['SUBJECT'] if 'SUBJECT' in cfg else cfg) 50 | self.key_map = { 51 | "image": "images", 52 | "mask": "masks" 53 | } 54 | def forward(self, images, mode=None, return_mask=None, mask_cfg=None): 55 | ret_data = {} 56 | for image in images: 57 | ret_one_data = self.sbjref_ins.forward(image=image, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg) 58 | if isinstance(ret_one_data, dict): 59 | for key, val in ret_one_data.items(): 60 | if key in self.key_map: 61 | new_key = self.key_map[key] 62 | else: 63 | continue 64 | if new_key in ret_data: 65 | ret_data[new_key].append(val) 66 | else: 67 | ret_data[new_key] = [val] 68 | else: 69 | if 'images' in ret_data: 70 | ret_data['images'].append(ret_data) 71 | else: 72 | ret_data['images'] = [ret_data] 73 | return ret_data 74 | 75 | 76 | class AnimateAnythingAnnotator: 77 | def __init__(self, cfg): 78 | from .pose import PoseBodyFaceVideoAnnotator 79 | self.pose_ins = PoseBodyFaceVideoAnnotator(cfg['POSE']) 80 | self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE']) 81 | 82 | def forward(self, frames=None, images=None, mode=None, return_mask=None, mask_cfg=None): 83 | ret_data = {} 84 | ret_pose_data = self.pose_ins.forward(frames=frames) 85 | ret_data.update({"frames": ret_pose_data}) 86 | 87 | ret_ref_data = self.ref_ins.forward(images=images, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg) 88 | ret_data.update({"images": ret_ref_data['images']}) 89 | 90 | return ret_data 91 | 92 | 93 | class SwapAnythingAnnotator: 94 | def __init__(self, cfg): 95 | from .inpainting import InpaintingVideoAnnotator 96 | self.inp_ins = InpaintingVideoAnnotator(cfg['INPAINTING']) 97 | self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE']) 98 | 99 | def forward(self, video=None, frames=None, images=None, mode=None, mask=None, bbox=None, label=None, caption=None, return_mask=None, mask_cfg=None): 100 | ret_data = {} 101 | mode = mode.split(',') if ',' in mode else [mode, mode] 102 | 103 | ret_inp_data = self.inp_ins.forward(video=video, frames=frames, mode=mode[0], mask=mask, bbox=bbox, label=label, caption=caption, mask_cfg=mask_cfg) 104 | ret_data.update(ret_inp_data) 105 | 106 | ret_ref_data = self.ref_ins.forward(images=images, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg) 107 | ret_data.update({"images": ret_ref_data['images']}) 108 | 109 | return ret_data 110 | 111 | 112 | class ExpandAnythingAnnotator: 113 | def __init__(self, cfg): 114 | from .outpainting import OutpaintingAnnotator 115 | from .frameref import FrameRefExpandAnnotator 116 | self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE']) 117 | self.frameref_ins = FrameRefExpandAnnotator(cfg['FRAMEREF']) 118 | self.outpainting_ins = OutpaintingAnnotator(cfg['OUTPAINTING']) 119 | 120 | def forward(self, images=None, mode=None, return_mask=None, mask_cfg=None, direction=None, expand_ratio=None, expand_num=None): 121 | ret_data = {} 122 | expand_image, reference_image= images[0], images[1:] 123 | mode = mode.split(',') if ',' in mode else ['firstframe', mode] 124 | 125 | outpainting_data = self.outpainting_ins.forward(expand_image,expand_ratio=expand_ratio, direction=direction) 126 | outpainting_image, outpainting_mask = outpainting_data['image'], outpainting_data['mask'] 127 | 128 | frameref_data = self.frameref_ins.forward(outpainting_image, mode=mode[0], expand_num=expand_num) 129 | frames, masks = frameref_data['frames'], frameref_data['masks'] 130 | masks[0] = outpainting_mask 131 | ret_data.update({"frames": frames, "masks": masks}) 132 | 133 | ret_ref_data = self.ref_ins.forward(images=reference_image, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg) 134 | ret_data.update({"images": ret_ref_data['images']}) 135 | 136 | return ret_data 137 | 138 | 139 | class MoveAnythingAnnotator: 140 | def __init__(self, cfg): 141 | from .layout import LayoutBboxAnnotator 142 | self.layout_bbox_ins = LayoutBboxAnnotator(cfg['LAYOUTBBOX']) 143 | 144 | def forward(self, image=None, bbox=None, label=None, expand_num=None): 145 | frame_size = image.shape[:2] # [H, W] 146 | ret_layout_data = self.layout_bbox_ins.forward(bbox, frame_size=frame_size, num_frames=expand_num, label=label) 147 | 148 | out_frames = [image] + ret_layout_data 149 | out_mask = [np.zeros(frame_size, dtype=np.uint8)] + [np.ones(frame_size, dtype=np.uint8) * 255] * len(ret_layout_data) 150 | 151 | ret_data = { 152 | "frames": out_frames, 153 | "masks": out_mask 154 | } 155 | return ret_data -------------------------------------------------------------------------------- /vace/annotators/depth.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import numpy as np 4 | import torch 5 | from einops import rearrange 6 | 7 | from .utils import convert_to_numpy, resize_image, resize_image_ori 8 | 9 | class DepthAnnotator: 10 | def __init__(self, cfg, device=None): 11 | from .midas.api import MiDaSInference 12 | pretrained_model = cfg['PRETRAINED_MODEL'] 13 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 14 | self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device) 15 | self.a = cfg.get('A', np.pi * 2.0) 16 | self.bg_th = cfg.get('BG_TH', 0.1) 17 | 18 | @torch.no_grad() 19 | @torch.inference_mode() 20 | @torch.autocast('cuda', enabled=False) 21 | def forward(self, image): 22 | image = convert_to_numpy(image) 23 | image_depth = image 24 | h, w, c = image.shape 25 | image_depth, k = resize_image(image_depth, 26 | 1024 if min(h, w) > 1024 else min(h, w)) 27 | image_depth = torch.from_numpy(image_depth).float().to(self.device) 28 | image_depth = image_depth / 127.5 - 1.0 29 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 30 | depth = self.model(image_depth)[0] 31 | 32 | depth_pt = depth.clone() 33 | depth_pt -= torch.min(depth_pt) 34 | depth_pt /= torch.max(depth_pt) 35 | depth_pt = depth_pt.cpu().numpy() 36 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 37 | depth_image = depth_image[..., None].repeat(3, 2) 38 | 39 | depth_image = resize_image_ori(h, w, depth_image, k) 40 | return depth_image 41 | 42 | 43 | class DepthVideoAnnotator(DepthAnnotator): 44 | def forward(self, frames): 45 | ret_frames = [] 46 | for frame in frames: 47 | anno_frame = super().forward(np.array(frame)) 48 | ret_frames.append(anno_frame) 49 | return ret_frames 50 | 51 | 52 | class DepthV2Annotator: 53 | def __init__(self, cfg, device=None): 54 | from .depth_anything_v2.dpt import DepthAnythingV2 55 | pretrained_model = cfg['PRETRAINED_MODEL'] 56 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 57 | self.model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(self.device) 58 | self.model.load_state_dict( 59 | torch.load( 60 | pretrained_model, 61 | map_location=self.device 62 | ) 63 | ) 64 | self.model.eval() 65 | 66 | @torch.inference_mode() 67 | @torch.autocast('cuda', enabled=False) 68 | def forward(self, image): 69 | image = convert_to_numpy(image) 70 | depth = self.model.infer_image(image) 71 | 72 | depth_pt = depth.copy() 73 | depth_pt -= np.min(depth_pt) 74 | depth_pt /= np.max(depth_pt) 75 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 76 | 77 | depth_image = depth_image[..., np.newaxis] 78 | depth_image = np.repeat(depth_image, 3, axis=2) 79 | return depth_image 80 | 81 | 82 | class DepthV2VideoAnnotator(DepthV2Annotator): 83 | def forward(self, frames): 84 | ret_frames = [] 85 | for frame in frames: 86 | anno_frame = super().forward(np.array(frame)) 87 | ret_frames.append(anno_frame) 88 | return ret_frames 89 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/vace/annotators/depth_anything_v2/__init__.py -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/dpt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.transforms import Compose 8 | 9 | from .dinov2 import DINOv2 10 | from .util.blocks import FeatureFusionBlock, _make_scratch 11 | from .util.transform import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | class DepthAnythingV2(nn.Module): 15 | def __init__( 16 | self, 17 | encoder='vitl', 18 | features=256, 19 | out_channels=[256, 512, 1024, 1024], 20 | use_bn=False, 21 | use_clstoken=False 22 | ): 23 | super(DepthAnythingV2, self).__init__() 24 | 25 | self.intermediate_layer_idx = { 26 | 'vits': [2, 5, 8, 11], 27 | 'vitb': [2, 5, 8, 11], 28 | 'vitl': [4, 11, 17, 23], 29 | 'vitg': [9, 19, 29, 39] 30 | } 31 | 32 | self.encoder = encoder 33 | self.pretrained = DINOv2(model_name=encoder) 34 | 35 | self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, 36 | use_clstoken=use_clstoken) 37 | 38 | def forward(self, x): 39 | patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 40 | 41 | features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], 42 | return_class_token=True) 43 | 44 | depth = self.depth_head(features, patch_h, patch_w) 45 | depth = F.relu(depth) 46 | 47 | return depth.squeeze(1) 48 | 49 | @torch.no_grad() 50 | def infer_image(self, raw_image, input_size=518): 51 | image, (h, w) = self.image2tensor(raw_image, input_size) 52 | 53 | depth = self.forward(image) 54 | depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] 55 | 56 | return depth.cpu().numpy() 57 | 58 | def image2tensor(self, raw_image, input_size=518): 59 | transform = Compose([ 60 | Resize( 61 | width=input_size, 62 | height=input_size, 63 | resize_target=False, 64 | keep_aspect_ratio=True, 65 | ensure_multiple_of=14, 66 | resize_method='lower_bound', 67 | image_interpolation_method=cv2.INTER_CUBIC, 68 | ), 69 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 70 | PrepareForNet(), 71 | ]) 72 | 73 | h, w = raw_image.shape[:2] 74 | 75 | image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 76 | 77 | image = transform({'image': image})['image'] 78 | image = torch.from_numpy(image).unsqueeze(0) 79 | 80 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 81 | image = image.to(DEVICE) 82 | 83 | return image, (h, w) 84 | 85 | 86 | class DPTHead(nn.Module): 87 | def __init__( 88 | self, 89 | in_channels, 90 | features=256, 91 | use_bn=False, 92 | out_channels=[256, 512, 1024, 1024], 93 | use_clstoken=False 94 | ): 95 | super(DPTHead, self).__init__() 96 | 97 | self.use_clstoken = use_clstoken 98 | 99 | self.projects = nn.ModuleList([ 100 | nn.Conv2d( 101 | in_channels=in_channels, 102 | out_channels=out_channel, 103 | kernel_size=1, 104 | stride=1, 105 | padding=0, 106 | ) for out_channel in out_channels 107 | ]) 108 | 109 | self.resize_layers = nn.ModuleList([ 110 | nn.ConvTranspose2d( 111 | in_channels=out_channels[0], 112 | out_channels=out_channels[0], 113 | kernel_size=4, 114 | stride=4, 115 | padding=0), 116 | nn.ConvTranspose2d( 117 | in_channels=out_channels[1], 118 | out_channels=out_channels[1], 119 | kernel_size=2, 120 | stride=2, 121 | padding=0), 122 | nn.Identity(), 123 | nn.Conv2d( 124 | in_channels=out_channels[3], 125 | out_channels=out_channels[3], 126 | kernel_size=3, 127 | stride=2, 128 | padding=1) 129 | ]) 130 | 131 | if use_clstoken: 132 | self.readout_projects = nn.ModuleList() 133 | for _ in range(len(self.projects)): 134 | self.readout_projects.append( 135 | nn.Sequential( 136 | nn.Linear(2 * in_channels, in_channels), 137 | nn.GELU())) 138 | 139 | self.scratch = _make_scratch( 140 | out_channels, 141 | features, 142 | groups=1, 143 | expand=False, 144 | ) 145 | 146 | self.scratch.stem_transpose = None 147 | 148 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 149 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 150 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 151 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 152 | 153 | head_features_1 = features 154 | head_features_2 = 32 155 | 156 | self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) 157 | self.scratch.output_conv2 = nn.Sequential( 158 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), 159 | nn.ReLU(True), 160 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), 161 | nn.ReLU(True), 162 | nn.Identity(), 163 | ) 164 | 165 | def forward(self, out_features, patch_h, patch_w): 166 | out = [] 167 | for i, x in enumerate(out_features): 168 | if self.use_clstoken: 169 | x, cls_token = x[0], x[1] 170 | readout = cls_token.unsqueeze(1).expand_as(x) 171 | x = self.readout_projects[i](torch.cat((x, readout), -1)) 172 | else: 173 | x = x[0] 174 | 175 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) 176 | 177 | x = self.projects[i](x) 178 | x = self.resize_layers[i](x) 179 | 180 | out.append(x) 181 | 182 | layer_1, layer_2, layer_3, layer_4 = out 183 | 184 | layer_1_rn = self.scratch.layer1_rn(layer_1) 185 | layer_2_rn = self.scratch.layer2_rn(layer_2) 186 | layer_3_rn = self.scratch.layer3_rn(layer_3) 187 | layer_4_rn = self.scratch.layer4_rn(layer_4) 188 | 189 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) 190 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) 191 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) 192 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 193 | 194 | out = self.scratch.output_conv1(path_1) 195 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) 196 | out = self.scratch.output_conv2(out) 197 | 198 | return out 199 | 200 | 201 | def _make_fusion_block(features, use_bn, size=None): 202 | return FeatureFusionBlock( 203 | features, 204 | nn.ReLU(False), 205 | deconv=False, 206 | bn=use_bn, 207 | expand=False, 208 | align_corners=True, 209 | size=size, 210 | ) 211 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | try: 19 | from xformers.ops import memory_efficient_attention, unbind, fmha 20 | 21 | XFORMERS_AVAILABLE = True 22 | except ImportError: 23 | logger.warning("xFormers not available") 24 | XFORMERS_AVAILABLE = False 25 | 26 | 27 | class Attention(nn.Module): 28 | def __init__( 29 | self, 30 | dim: int, 31 | num_heads: int = 8, 32 | qkv_bias: bool = False, 33 | proj_bias: bool = True, 34 | attn_drop: float = 0.0, 35 | proj_drop: float = 0.0, 36 | ) -> None: 37 | super().__init__() 38 | self.num_heads = num_heads 39 | head_dim = dim // num_heads 40 | self.scale = head_dim ** -0.5 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.attn_drop = nn.Dropout(attn_drop) 44 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 45 | self.proj_drop = nn.Dropout(proj_drop) 46 | 47 | def forward(self, x: Tensor) -> Tensor: 48 | B, N, C = x.shape 49 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | 51 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 52 | attn = q @ k.transpose(-2, -1) 53 | 54 | attn = attn.softmax(dim=-1) 55 | attn = self.attn_drop(attn) 56 | 57 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 58 | x = self.proj(x) 59 | x = self.proj_drop(x) 60 | return x 61 | 62 | 63 | class MemEffAttention(Attention): 64 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 65 | if not XFORMERS_AVAILABLE: 66 | assert attn_bias is None, "xFormers is required for nested tensors usage" 67 | return super().forward(x) 68 | 69 | B, N, C = x.shape 70 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 71 | 72 | q, k, v = unbind(qkv, 2) 73 | 74 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 75 | x = x.reshape([B, N, C]) 76 | 77 | x = self.proj(x) 78 | x = self.proj_drop(x) 79 | return x 80 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 7 | 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | from typing import Callable, Optional 12 | from torch import Tensor, nn 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__( 17 | self, 18 | in_features: int, 19 | hidden_features: Optional[int] = None, 20 | out_features: Optional[int] = None, 21 | act_layer: Callable[..., nn.Module] = nn.GELU, 22 | drop: float = 0.0, 23 | bias: bool = True, 24 | ) -> None: 25 | super().__init__() 26 | out_features = out_features or in_features 27 | hidden_features = hidden_features or in_features 28 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 29 | self.act = act_layer() 30 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 31 | self.drop = nn.Dropout(drop) 32 | 33 | def forward(self, x: Tensor) -> Tensor: 34 | x = self.fc1(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # References: 9 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 10 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 11 | 12 | from typing import Callable, Optional, Tuple, Union 13 | 14 | from torch import Tensor 15 | import torch.nn as nn 16 | 17 | 18 | def make_2tuple(x): 19 | if isinstance(x, tuple): 20 | assert len(x) == 2 21 | return x 22 | 23 | assert isinstance(x, int) 24 | return (x, x) 25 | 26 | 27 | class PatchEmbed(nn.Module): 28 | """ 29 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 30 | 31 | Args: 32 | img_size: Image size. 33 | patch_size: Patch token size. 34 | in_chans: Number of input image channels. 35 | embed_dim: Number of linear projection output channels. 36 | norm_layer: Normalization layer. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | img_size: Union[int, Tuple[int, int]] = 224, 42 | patch_size: Union[int, Tuple[int, int]] = 16, 43 | in_chans: int = 3, 44 | embed_dim: int = 768, 45 | norm_layer: Optional[Callable] = None, 46 | flatten_embedding: bool = True, 47 | ) -> None: 48 | super().__init__() 49 | 50 | image_HW = make_2tuple(img_size) 51 | patch_HW = make_2tuple(patch_size) 52 | patch_grid_size = ( 53 | image_HW[0] // patch_HW[0], 54 | image_HW[1] // patch_HW[1], 55 | ) 56 | 57 | self.img_size = image_HW 58 | self.patch_size = patch_HW 59 | self.patches_resolution = patch_grid_size 60 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 61 | 62 | self.in_chans = in_chans 63 | self.embed_dim = embed_dim 64 | 65 | self.flatten_embedding = flatten_embedding 66 | 67 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 68 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | _, _, H, W = x.shape 72 | patch_H, patch_W = self.patch_size 73 | 74 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 75 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 76 | 77 | x = self.proj(x) # B C H W 78 | H, W = x.size(2), x.size(3) 79 | x = x.flatten(2).transpose(1, 2) # B HW C 80 | x = self.norm(x) 81 | if not self.flatten_embedding: 82 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 83 | return x 84 | 85 | def flops(self) -> float: 86 | Ho, Wo = self.patches_resolution 87 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 88 | if self.norm is not None: 89 | flops += Ho * Wo * self.embed_dim 90 | return flops 91 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from typing import Callable, Optional 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | try: 38 | from xformers.ops import SwiGLU 39 | 40 | XFORMERS_AVAILABLE = True 41 | except ImportError: 42 | SwiGLU = SwiGLUFFN 43 | XFORMERS_AVAILABLE = False 44 | 45 | 46 | class SwiGLUFFNFused(SwiGLU): 47 | def __init__( 48 | self, 49 | in_features: int, 50 | hidden_features: Optional[int] = None, 51 | out_features: Optional[int] = None, 52 | act_layer: Callable[..., nn.Module] = None, 53 | drop: float = 0.0, 54 | bias: bool = True, 55 | ) -> None: 56 | out_features = out_features or in_features 57 | hidden_features = hidden_features or in_features 58 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 59 | super().__init__( 60 | in_features=in_features, 61 | hidden_features=hidden_features, 62 | out_features=out_features, 63 | bias=bias, 64 | ) 65 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/VACE/0897c6d055d7d9ea9e191dce763006664d9780f8/vace/annotators/depth_anything_v2/util/__init__.py -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/util/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 5 | scratch = nn.Module() 6 | 7 | out_shape1 = out_shape 8 | out_shape2 = out_shape 9 | out_shape3 = out_shape 10 | if len(in_shape) >= 4: 11 | out_shape4 = out_shape 12 | 13 | if expand: 14 | out_shape1 = out_shape 15 | out_shape2 = out_shape * 2 16 | out_shape3 = out_shape * 4 17 | if len(in_shape) >= 4: 18 | out_shape4 = out_shape * 8 19 | 20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, 21 | groups=groups) 22 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, 23 | groups=groups) 24 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, 25 | groups=groups) 26 | if len(in_shape) >= 4: 27 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, 28 | groups=groups) 29 | 30 | return scratch 31 | 32 | 33 | class ResidualConvUnit(nn.Module): 34 | """Residual convolution module. 35 | """ 36 | 37 | def __init__(self, features, activation, bn): 38 | """Init. 39 | 40 | Args: 41 | features (int): number of features 42 | """ 43 | super().__init__() 44 | 45 | self.bn = bn 46 | 47 | self.groups = 1 48 | 49 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 50 | 51 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 52 | 53 | if self.bn == True: 54 | self.bn1 = nn.BatchNorm2d(features) 55 | self.bn2 = nn.BatchNorm2d(features) 56 | 57 | self.activation = activation 58 | 59 | self.skip_add = nn.quantized.FloatFunctional() 60 | 61 | def forward(self, x): 62 | """Forward pass. 63 | 64 | Args: 65 | x (tensor): input 66 | 67 | Returns: 68 | tensor: output 69 | """ 70 | 71 | out = self.activation(x) 72 | out = self.conv1(out) 73 | if self.bn == True: 74 | out = self.bn1(out) 75 | 76 | out = self.activation(out) 77 | out = self.conv2(out) 78 | if self.bn == True: 79 | out = self.bn2(out) 80 | 81 | if self.groups > 1: 82 | out = self.conv_merge(out) 83 | 84 | return self.skip_add.add(out, x) 85 | 86 | 87 | class FeatureFusionBlock(nn.Module): 88 | """Feature fusion block. 89 | """ 90 | 91 | def __init__( 92 | self, 93 | features, 94 | activation, 95 | deconv=False, 96 | bn=False, 97 | expand=False, 98 | align_corners=True, 99 | size=None 100 | ): 101 | """Init. 102 | 103 | Args: 104 | features (int): number of features 105 | """ 106 | super(FeatureFusionBlock, self).__init__() 107 | 108 | self.deconv = deconv 109 | self.align_corners = align_corners 110 | 111 | self.groups = 1 112 | 113 | self.expand = expand 114 | out_features = features 115 | if self.expand == True: 116 | out_features = features // 2 117 | 118 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 119 | 120 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 121 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 122 | 123 | self.skip_add = nn.quantized.FloatFunctional() 124 | 125 | self.size = size 126 | 127 | def forward(self, *xs, size=None): 128 | """Forward pass. 129 | 130 | Returns: 131 | tensor: output 132 | """ 133 | output = xs[0] 134 | 135 | if len(xs) == 2: 136 | res = self.resConfUnit1(xs[1]) 137 | output = self.skip_add.add(output, res) 138 | 139 | output = self.resConfUnit2(output) 140 | 141 | if (size is None) and (self.size is None): 142 | modifier = {"scale_factor": 2} 143 | elif size is None: 144 | modifier = {"size": self.size} 145 | else: 146 | modifier = {"size": size} 147 | 148 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 149 | output = self.out_conv(output) 150 | 151 | return output 152 | -------------------------------------------------------------------------------- /vace/annotators/depth_anything_v2/util/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class Resize(object): 6 | """Resize sample to given size (width, height). 7 | """ 8 | 9 | def __init__( 10 | self, 11 | width, 12 | height, 13 | resize_target=True, 14 | keep_aspect_ratio=False, 15 | ensure_multiple_of=1, 16 | resize_method="lower_bound", 17 | image_interpolation_method=cv2.INTER_AREA, 18 | ): 19 | """Init. 20 | 21 | Args: 22 | width (int): desired output width 23 | height (int): desired output height 24 | resize_target (bool, optional): 25 | True: Resize the full sample (image, mask, target). 26 | False: Resize image only. 27 | Defaults to True. 28 | keep_aspect_ratio (bool, optional): 29 | True: Keep the aspect ratio of the input sample. 30 | Output sample might not have the given width and height, and 31 | resize behaviour depends on the parameter 'resize_method'. 32 | Defaults to False. 33 | ensure_multiple_of (int, optional): 34 | Output width and height is constrained to be multiple of this parameter. 35 | Defaults to 1. 36 | resize_method (str, optional): 37 | "lower_bound": Output will be at least as large as the given size. 38 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 39 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 40 | Defaults to "lower_bound". 41 | """ 42 | self.__width = width 43 | self.__height = height 44 | 45 | self.__resize_target = resize_target 46 | self.__keep_aspect_ratio = keep_aspect_ratio 47 | self.__multiple_of = ensure_multiple_of 48 | self.__resize_method = resize_method 49 | self.__image_interpolation_method = image_interpolation_method 50 | 51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 53 | 54 | if max_val is not None and y > max_val: 55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 56 | 57 | if y < min_val: 58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 59 | 60 | return y 61 | 62 | def get_size(self, width, height): 63 | # determine new height and width 64 | scale_height = self.__height / height 65 | scale_width = self.__width / width 66 | 67 | if self.__keep_aspect_ratio: 68 | if self.__resize_method == "lower_bound": 69 | # scale such that output size is lower bound 70 | if scale_width > scale_height: 71 | # fit width 72 | scale_height = scale_width 73 | else: 74 | # fit height 75 | scale_width = scale_height 76 | elif self.__resize_method == "upper_bound": 77 | # scale such that output size is upper bound 78 | if scale_width < scale_height: 79 | # fit width 80 | scale_height = scale_width 81 | else: 82 | # fit height 83 | scale_width = scale_height 84 | elif self.__resize_method == "minimal": 85 | # scale as least as possbile 86 | if abs(1 - scale_width) < abs(1 - scale_height): 87 | # fit width 88 | scale_height = scale_width 89 | else: 90 | # fit height 91 | scale_width = scale_height 92 | else: 93 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 94 | 95 | if self.__resize_method == "lower_bound": 96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) 97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) 98 | elif self.__resize_method == "upper_bound": 99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) 100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) 101 | elif self.__resize_method == "minimal": 102 | new_height = self.constrain_to_multiple_of(scale_height * height) 103 | new_width = self.constrain_to_multiple_of(scale_width * width) 104 | else: 105 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 106 | 107 | return (new_width, new_height) 108 | 109 | def __call__(self, sample): 110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) 111 | 112 | # resize sample 113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) 114 | 115 | if self.__resize_target: 116 | if "depth" in sample: 117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) 118 | 119 | if "mask" in sample: 120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), 121 | interpolation=cv2.INTER_NEAREST) 122 | 123 | return sample 124 | 125 | 126 | class NormalizeImage(object): 127 | """Normlize image by given mean and std. 128 | """ 129 | 130 | def __init__(self, mean, std): 131 | self.__mean = mean 132 | self.__std = std 133 | 134 | def __call__(self, sample): 135 | sample["image"] = (sample["image"] - self.__mean) / self.__std 136 | 137 | return sample 138 | 139 | 140 | class PrepareForNet(object): 141 | """Prepare sample for usage as network input. 142 | """ 143 | 144 | def __init__(self): 145 | pass 146 | 147 | def __call__(self, sample): 148 | image = np.transpose(sample["image"], (2, 0, 1)) 149 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 150 | 151 | if "depth" in sample: 152 | depth = sample["depth"].astype(np.float32) 153 | sample["depth"] = np.ascontiguousarray(depth) 154 | 155 | if "mask" in sample: 156 | sample["mask"] = sample["mask"].astype(np.float32) 157 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 158 | 159 | return sample 160 | -------------------------------------------------------------------------------- /vace/annotators/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | -------------------------------------------------------------------------------- /vace/annotators/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import cv2 4 | import numpy as np 5 | 6 | import onnxruntime 7 | 8 | def nms(boxes, scores, nms_thr): 9 | """Single class NMS implemented in Numpy.""" 10 | x1 = boxes[:, 0] 11 | y1 = boxes[:, 1] 12 | x2 = boxes[:, 2] 13 | y2 = boxes[:, 3] 14 | 15 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 16 | order = scores.argsort()[::-1] 17 | 18 | keep = [] 19 | while order.size > 0: 20 | i = order[0] 21 | keep.append(i) 22 | xx1 = np.maximum(x1[i], x1[order[1:]]) 23 | yy1 = np.maximum(y1[i], y1[order[1:]]) 24 | xx2 = np.minimum(x2[i], x2[order[1:]]) 25 | yy2 = np.minimum(y2[i], y2[order[1:]]) 26 | 27 | w = np.maximum(0.0, xx2 - xx1 + 1) 28 | h = np.maximum(0.0, yy2 - yy1 + 1) 29 | inter = w * h 30 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 31 | 32 | inds = np.where(ovr <= nms_thr)[0] 33 | order = order[inds + 1] 34 | 35 | return keep 36 | 37 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 38 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 39 | final_dets = [] 40 | num_classes = scores.shape[1] 41 | for cls_ind in range(num_classes): 42 | cls_scores = scores[:, cls_ind] 43 | valid_score_mask = cls_scores > score_thr 44 | if valid_score_mask.sum() == 0: 45 | continue 46 | else: 47 | valid_scores = cls_scores[valid_score_mask] 48 | valid_boxes = boxes[valid_score_mask] 49 | keep = nms(valid_boxes, valid_scores, nms_thr) 50 | if len(keep) > 0: 51 | cls_inds = np.ones((len(keep), 1)) * cls_ind 52 | dets = np.concatenate( 53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 54 | ) 55 | final_dets.append(dets) 56 | if len(final_dets) == 0: 57 | return None 58 | return np.concatenate(final_dets, 0) 59 | 60 | def demo_postprocess(outputs, img_size, p6=False): 61 | grids = [] 62 | expanded_strides = [] 63 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 64 | 65 | hsizes = [img_size[0] // stride for stride in strides] 66 | wsizes = [img_size[1] // stride for stride in strides] 67 | 68 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 69 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 70 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 71 | grids.append(grid) 72 | shape = grid.shape[:2] 73 | expanded_strides.append(np.full((*shape, 1), stride)) 74 | 75 | grids = np.concatenate(grids, 1) 76 | expanded_strides = np.concatenate(expanded_strides, 1) 77 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 78 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 79 | 80 | return outputs 81 | 82 | def preprocess(img, input_size, swap=(2, 0, 1)): 83 | if len(img.shape) == 3: 84 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 85 | else: 86 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 87 | 88 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 89 | resized_img = cv2.resize( 90 | img, 91 | (int(img.shape[1] * r), int(img.shape[0] * r)), 92 | interpolation=cv2.INTER_LINEAR, 93 | ).astype(np.uint8) 94 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 95 | 96 | padded_img = padded_img.transpose(swap) 97 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 98 | return padded_img, r 99 | 100 | def inference_detector(session, oriImg): 101 | input_shape = (640,640) 102 | img, ratio = preprocess(oriImg, input_shape) 103 | 104 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 105 | output = session.run(None, ort_inputs) 106 | predictions = demo_postprocess(output[0], input_shape)[0] 107 | 108 | boxes = predictions[:, :4] 109 | scores = predictions[:, 4:5] * predictions[:, 5:] 110 | 111 | boxes_xyxy = np.ones_like(boxes) 112 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 113 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 114 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 115 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 116 | boxes_xyxy /= ratio 117 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 118 | if dets is not None: 119 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 120 | isscore = final_scores>0.3 121 | iscat = final_cls_inds == 0 122 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 123 | final_boxes = final_boxes[isbbox] 124 | else: 125 | final_boxes = np.array([]) 126 | 127 | return final_boxes 128 | -------------------------------------------------------------------------------- /vace/annotators/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | from .onnxdet import inference_detector 7 | from .onnxpose import inference_pose 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | 40 | class Wholebody: 41 | def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'): 42 | 43 | providers = ['CPUExecutionProvider' 44 | ] if device == 'cpu' else ['CUDAExecutionProvider'] 45 | # onnx_det = 'annotator/ckpts/yolox_l.onnx' 46 | # onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx' 47 | 48 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) 49 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) 50 | 51 | def __call__(self, ori_img): 52 | det_result = inference_detector(self.session_det, ori_img) 53 | keypoints, scores = inference_pose(self.session_pose, det_result, ori_img) 54 | 55 | keypoints_info = np.concatenate( 56 | (keypoints, scores[..., None]), axis=-1) 57 | # compute neck joint 58 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 59 | # neck score when visualizing pred 60 | neck[:, 2:4] = np.logical_and( 61 | keypoints_info[:, 5, 2:4] > 0.3, 62 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 63 | new_keypoints_info = np.insert( 64 | keypoints_info, 17, neck, axis=1) 65 | mmpose_idx = [ 66 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 67 | ] 68 | openpose_idx = [ 69 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 70 | ] 71 | new_keypoints_info[:, openpose_idx] = \ 72 | new_keypoints_info[:, mmpose_idx] 73 | keypoints_info = new_keypoints_info 74 | 75 | keypoints, scores = keypoints_info[ 76 | ..., :2], keypoints_info[..., 2] 77 | 78 | return keypoints, scores, det_result 79 | 80 | 81 | -------------------------------------------------------------------------------- /vace/annotators/face.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .utils import convert_to_numpy 8 | 9 | 10 | class FaceAnnotator: 11 | def __init__(self, cfg, device=None): 12 | from insightface.app import FaceAnalysis 13 | self.return_raw = cfg.get('RETURN_RAW', True) 14 | self.return_mask = cfg.get('RETURN_MASK', False) 15 | self.return_dict = cfg.get('RETURN_DICT', False) 16 | self.multi_face = cfg.get('MULTI_FACE', True) 17 | pretrained_model = cfg['PRETRAINED_MODEL'] 18 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 19 | self.device_id = self.device.index if self.device.type == 'cuda' else None 20 | ctx_id = self.device_id if self.device_id is not None else 0 21 | self.model = FaceAnalysis(name=cfg.MODEL_NAME, root=pretrained_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 22 | self.model.prepare(ctx_id=ctx_id, det_size=(640, 640)) 23 | 24 | def forward(self, image=None, return_mask=None, return_dict=None): 25 | return_mask = return_mask if return_mask is not None else self.return_mask 26 | return_dict = return_dict if return_dict is not None else self.return_dict 27 | image = convert_to_numpy(image) 28 | # [dict_keys(['bbox', 'kps', 'det_score', 'landmark_3d_68', 'pose', 'landmark_2d_106', 'gender', 'age', 'embedding'])] 29 | faces = self.model.get(image) 30 | if self.return_raw: 31 | return faces 32 | else: 33 | crop_face_list, mask_list = [], [] 34 | if len(faces) > 0: 35 | if not self.multi_face: 36 | faces = faces[:1] 37 | for face in faces: 38 | x_min, y_min, x_max, y_max = face['bbox'].tolist() 39 | crop_face = image[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1] 40 | crop_face_list.append(crop_face) 41 | mask = np.zeros_like(image[:, :, 0]) 42 | mask[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1] = 255 43 | mask_list.append(mask) 44 | if not self.multi_face: 45 | crop_face_list = crop_face_list[0] 46 | mask_list = mask_list[0] 47 | if return_mask: 48 | if return_dict: 49 | return {'image': crop_face_list, 'mask': mask_list} 50 | else: 51 | return crop_face_list, mask_list 52 | else: 53 | return crop_face_list 54 | else: 55 | return None 56 | -------------------------------------------------------------------------------- /vace/annotators/flow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | import numpy as np 5 | import argparse 6 | 7 | from .utils import convert_to_numpy 8 | 9 | class FlowAnnotator: 10 | def __init__(self, cfg, device=None): 11 | try: 12 | from raft import RAFT 13 | from raft.utils.utils import InputPadder 14 | from raft.utils import flow_viz 15 | except: 16 | import warnings 17 | warnings.warn( 18 | "ignore raft import, please pip install raft package. you can refer to models/VACE-Annotators/flow/raft-1.0.0-py3-none-any.whl") 19 | 20 | params = { 21 | "small": False, 22 | "mixed_precision": False, 23 | "alternate_corr": False 24 | } 25 | params = argparse.Namespace(**params) 26 | pretrained_model = cfg['PRETRAINED_MODEL'] 27 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 28 | self.model = RAFT(params) 29 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()}) 30 | self.model = self.model.to(self.device).eval() 31 | self.InputPadder = InputPadder 32 | self.flow_viz = flow_viz 33 | 34 | def forward(self, frames): 35 | # frames / RGB 36 | frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames] 37 | flow_up_list, flow_up_vis_list = [], [] 38 | with torch.no_grad(): 39 | for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])): 40 | padder = self.InputPadder(image1.shape) 41 | image1, image2 = padder.pad(image1, image2) 42 | flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True) 43 | flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy() 44 | flow_up_vis = self.flow_viz.flow_to_image(flow_up) 45 | flow_up_list.append(flow_up) 46 | flow_up_vis_list.append(flow_up_vis) 47 | return flow_up_list, flow_up_vis_list # RGB 48 | 49 | 50 | class FlowVisAnnotator(FlowAnnotator): 51 | def forward(self, frames): 52 | flow_up_list, flow_up_vis_list = super().forward(frames) 53 | return flow_up_vis_list[:1] + flow_up_vis_list -------------------------------------------------------------------------------- /vace/annotators/frameref.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import random 4 | import numpy as np 5 | from .utils import align_frames 6 | 7 | 8 | class FrameRefExtractAnnotator: 9 | para_dict = {} 10 | 11 | def __init__(self, cfg, device=None): 12 | # first / last / firstlast / random 13 | self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1}, 14 | {"mode": "last", "proba": 0.1}, 15 | {"mode": "firstlast", "proba": 0.1}, 16 | {"mode": "random", "proba": 0.1}]) 17 | self.ref_num = cfg.get('REF_NUM', 1) 18 | self.ref_color = cfg.get('REF_COLOR', 127.5) 19 | self.return_dict = cfg.get('RETURN_DICT', True) 20 | self.return_mask = cfg.get('RETURN_MASK', True) 21 | 22 | 23 | def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None): 24 | return_mask = return_mask if return_mask is not None else self.return_mask 25 | return_dict = return_dict if return_dict is not None else self.return_dict 26 | ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg 27 | ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg 28 | probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg] 29 | sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0] 30 | mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original' 31 | ref_num = int(ref_num) if ref_num is not None else self.ref_num 32 | 33 | frame_num = len(frames) 34 | frame_num_range = list(range(frame_num)) 35 | if mode == "first": 36 | sel_idx = frame_num_range[:ref_num] 37 | elif mode == "last": 38 | sel_idx = frame_num_range[-ref_num:] 39 | elif mode == "firstlast": 40 | sel_idx = frame_num_range[:ref_num] + frame_num_range[-ref_num:] 41 | elif mode == "random": 42 | sel_idx = random.sample(frame_num_range, ref_num) 43 | else: 44 | raise NotImplementedError 45 | 46 | out_frames, out_masks = [], [] 47 | for i in range(frame_num): 48 | if i in sel_idx: 49 | out_frame = frames[i] 50 | out_mask = np.zeros_like(frames[i][:, :, 0]) 51 | else: 52 | out_frame = np.ones_like(frames[i]) * self.ref_color 53 | out_mask = np.ones_like(frames[i][:, :, 0]) * 255 54 | out_frames.append(out_frame) 55 | out_masks.append(out_mask) 56 | 57 | if return_dict: 58 | ret_data = {"frames": out_frames} 59 | if return_mask: 60 | ret_data['masks'] = out_masks 61 | return ret_data 62 | else: 63 | if return_mask: 64 | return out_frames, out_masks 65 | else: 66 | return out_frames 67 | 68 | 69 | 70 | class FrameRefExpandAnnotator: 71 | para_dict = {} 72 | 73 | def __init__(self, cfg, device=None): 74 | # first / last / firstlast 75 | self.ref_color = cfg.get('REF_COLOR', 127.5) 76 | self.return_mask = cfg.get('RETURN_MASK', True) 77 | self.return_dict = cfg.get('RETURN_DICT', True) 78 | self.mode = cfg.get('MODE', "firstframe") 79 | assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"] 80 | 81 | def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None): 82 | mode = mode if mode is not None else self.mode 83 | return_mask = return_mask if return_mask is not None else self.return_mask 84 | return_dict = return_dict if return_dict is not None else self.return_dict 85 | 86 | if 'frame' in mode: 87 | frames = [image] if image is not None and not isinstance(frames, list) else image 88 | frames_2 = [image_2] if image_2 is not None and not isinstance(image_2, list) else image_2 89 | 90 | expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num 91 | expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num 92 | source_frames = frames 93 | source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames) 94 | 95 | if mode in ["firstframe", "firstclip"]: 96 | out_frames = source_frames + expand_frames 97 | out_masks = source_masks + expand_masks 98 | elif mode in ["lastframe", "lastclip"]: 99 | out_frames = expand_frames + source_frames 100 | out_masks = expand_masks + source_masks 101 | elif mode in ["firstlastframe", "firstlastclip"]: 102 | source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2] 103 | source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2) 104 | out_frames = source_frames + expand_frames + source_frames_2 105 | out_masks = source_masks + expand_masks + source_masks_2 106 | else: 107 | raise NotImplementedError 108 | 109 | if return_dict: 110 | ret_data = {"frames": out_frames} 111 | if return_mask: 112 | ret_data['masks'] = out_masks 113 | return ret_data 114 | else: 115 | if return_mask: 116 | return out_frames, out_masks 117 | else: 118 | return out_frames 119 | -------------------------------------------------------------------------------- /vace/annotators/gdino.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | import torchvision 8 | from .utils import convert_to_numpy 9 | 10 | 11 | class GDINOAnnotator: 12 | def __init__(self, cfg, device=None): 13 | try: 14 | from groundingdino.util.inference import Model, load_model, load_image, predict 15 | except: 16 | import warnings 17 | warnings.warn("please pip install groundingdino package, or you can refer to models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl") 18 | 19 | grounding_dino_config_path = cfg['CONFIG_PATH'] 20 | grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL'] 21 | grounding_dino_tokenizer_path = cfg['TOKENIZER_PATH'] # TODO 22 | self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25) 23 | self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2) 24 | self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5) 25 | self.use_nms = cfg.get('USE_NMS', True) 26 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 27 | self.model = Model(model_config_path=grounding_dino_config_path, 28 | model_checkpoint_path=grounding_dino_checkpoint_path, 29 | device=self.device) 30 | 31 | def forward(self, image, classes=None, caption=None): 32 | image_bgr = convert_to_numpy(image)[..., ::-1] # bgr 33 | 34 | if classes is not None: 35 | classes = [classes] if isinstance(classes, str) else classes 36 | detections = self.model.predict_with_classes( 37 | image=image_bgr, 38 | classes=classes, 39 | box_threshold=self.box_threshold, 40 | text_threshold=self.text_threshold 41 | ) 42 | elif caption is not None: 43 | detections, phrases = self.model.predict_with_caption( 44 | image=image_bgr, 45 | caption=caption, 46 | box_threshold=self.box_threshold, 47 | text_threshold=self.text_threshold 48 | ) 49 | else: 50 | raise NotImplementedError() 51 | 52 | if self.use_nms: 53 | nms_idx = torchvision.ops.nms( 54 | torch.from_numpy(detections.xyxy), 55 | torch.from_numpy(detections.confidence), 56 | self.iou_threshold 57 | ).numpy().tolist() 58 | detections.xyxy = detections.xyxy[nms_idx] 59 | detections.confidence = detections.confidence[nms_idx] 60 | detections.class_id = detections.class_id[nms_idx] if detections.class_id is not None else None 61 | 62 | boxes = detections.xyxy 63 | confidences = detections.confidence 64 | class_ids = detections.class_id 65 | class_names = [classes[_id] for _id in class_ids] if classes is not None else phrases 66 | 67 | ret_data = { 68 | "boxes": boxes.tolist() if boxes is not None else None, 69 | "confidences": confidences.tolist() if confidences is not None else None, 70 | "class_ids": class_ids.tolist() if class_ids is not None else None, 71 | "class_names": class_names if class_names is not None else None, 72 | } 73 | return ret_data 74 | 75 | 76 | class GDINORAMAnnotator: 77 | def __init__(self, cfg, device=None): 78 | from .ram import RAMAnnotator 79 | from .gdino import GDINOAnnotator 80 | self.ram_model = RAMAnnotator(cfg['RAM'], device=device) 81 | self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device) 82 | 83 | def forward(self, image): 84 | ram_res = self.ram_model.forward(image) 85 | classes = ram_res['tag_e'] if isinstance(ram_res, dict) else ram_res 86 | gdino_res = self.gdino_model.forward(image, classes=classes) 87 | return gdino_res 88 | 89 | -------------------------------------------------------------------------------- /vace/annotators/gray.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import cv2 5 | import numpy as np 6 | from .utils import convert_to_numpy 7 | 8 | 9 | class GrayAnnotator: 10 | def __init__(self, cfg): 11 | pass 12 | def forward(self, image): 13 | image = convert_to_numpy(image) 14 | gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 15 | return gray_map[..., None].repeat(3, axis=2) 16 | 17 | 18 | class GrayVideoAnnotator(GrayAnnotator): 19 | def forward(self, frames): 20 | ret_frames = [] 21 | for frame in frames: 22 | anno_frame = super().forward(np.array(frame)) 23 | ret_frames.append(anno_frame) 24 | return ret_frames 25 | -------------------------------------------------------------------------------- /vace/annotators/layout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from .utils import convert_to_numpy 8 | 9 | 10 | class LayoutBboxAnnotator: 11 | def __init__(self, cfg, device=None): 12 | self.bg_color = cfg.get('BG_COLOR', [255, 255, 255]) 13 | self.box_color = cfg.get('BOX_COLOR', [0, 0, 0]) 14 | self.frame_size = cfg.get('FRAME_SIZE', [720, 1280]) # [H, W] 15 | self.num_frames = cfg.get('NUM_FRAMES', 81) 16 | ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None) 17 | self.color_dict = {'default': tuple(self.box_color)} 18 | if ram_tag_color_path is not None: 19 | lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()] 20 | self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines}) 21 | 22 | def forward(self, bbox, frame_size=None, num_frames=None, label=None, color=None): 23 | frame_size = frame_size if frame_size is not None else self.frame_size 24 | num_frames = num_frames if num_frames is not None else self.num_frames 25 | assert len(bbox) == 2, 'bbox should be a list of two elements (start_bbox & end_bbox)' 26 | # frame_size = [H, W] 27 | # bbox = [x1, y1, x2, y2] 28 | label = label[0] if label is not None and isinstance(label, list) else label 29 | if label is not None and label in self.color_dict: 30 | box_color = self.color_dict[label] 31 | elif color is not None: 32 | box_color = color 33 | else: 34 | box_color = self.color_dict['default'] 35 | start_bbox, end_bbox = bbox 36 | start_bbox = [start_bbox[0], start_bbox[1], start_bbox[2] - start_bbox[0], start_bbox[3] - start_bbox[1]] 37 | start_bbox = np.array(start_bbox, dtype=np.float32) 38 | end_bbox = [end_bbox[0], end_bbox[1], end_bbox[2] - end_bbox[0], end_bbox[3] - end_bbox[1]] 39 | end_bbox = np.array(end_bbox, dtype=np.float32) 40 | bbox_increment = (end_bbox - start_bbox) / num_frames 41 | ret_frames = [] 42 | for frame_idx in range(num_frames): 43 | frame = np.zeros((frame_size[0], frame_size[1], 3), dtype=np.uint8) 44 | frame[:] = self.bg_color 45 | current_bbox = start_bbox + bbox_increment * frame_idx 46 | current_bbox = current_bbox.astype(int) 47 | x, y, w, h = current_bbox 48 | cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2) 49 | ret_frames.append(frame[..., ::-1]) 50 | return ret_frames 51 | 52 | 53 | 54 | 55 | class LayoutMaskAnnotator: 56 | def __init__(self, cfg, device=None): 57 | self.use_aug = cfg.get('USE_AUG', False) 58 | self.bg_color = cfg.get('BG_COLOR', [255, 255, 255]) 59 | self.box_color = cfg.get('BOX_COLOR', [0, 0, 0]) 60 | ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None) 61 | self.color_dict = {'default': tuple(self.box_color)} 62 | if ram_tag_color_path is not None: 63 | lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()] 64 | self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines}) 65 | if self.use_aug: 66 | from .maskaug import MaskAugAnnotator 67 | self.maskaug_anno = MaskAugAnnotator(cfg={}) 68 | 69 | 70 | def find_contours(self, mask): 71 | contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 72 | return contours 73 | 74 | def draw_contours(self, canvas, contour, color): 75 | canvas = np.ascontiguousarray(canvas, dtype=np.uint8) 76 | canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3) 77 | return canvas 78 | 79 | def forward(self, mask=None, color=None, label=None, mask_cfg=None): 80 | if not isinstance(mask, list): 81 | is_batch = False 82 | mask = [mask] 83 | else: 84 | is_batch = True 85 | 86 | if label is not None and label in self.color_dict: 87 | color = self.color_dict[label] 88 | elif color is not None: 89 | color = color 90 | else: 91 | color = self.color_dict['default'] 92 | 93 | ret_data = [] 94 | for sub_mask in mask: 95 | sub_mask = convert_to_numpy(sub_mask) 96 | if self.use_aug: 97 | sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg) 98 | canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255 99 | contour = self.find_contours(sub_mask) 100 | frame = self.draw_contours(canvas, contour, color) 101 | ret_data.append(frame) 102 | 103 | if is_batch: 104 | return ret_data 105 | else: 106 | return ret_data[0] 107 | 108 | 109 | 110 | 111 | class LayoutTrackAnnotator: 112 | def __init__(self, cfg, device=None): 113 | self.use_aug = cfg.get('USE_AUG', False) 114 | self.bg_color = cfg.get('BG_COLOR', [255, 255, 255]) 115 | self.box_color = cfg.get('BOX_COLOR', [0, 0, 0]) 116 | ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None) 117 | self.color_dict = {'default': tuple(self.box_color)} 118 | if ram_tag_color_path is not None: 119 | lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()] 120 | self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines}) 121 | if self.use_aug: 122 | from .maskaug import MaskAugAnnotator 123 | self.maskaug_anno = MaskAugAnnotator(cfg={}) 124 | from .inpainting import InpaintingVideoAnnotator 125 | self.inpainting_anno = InpaintingVideoAnnotator(cfg=cfg['INPAINTING']) 126 | 127 | def find_contours(self, mask): 128 | contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 129 | return contours 130 | 131 | def draw_contours(self, canvas, contour, color): 132 | canvas = np.ascontiguousarray(canvas, dtype=np.uint8) 133 | canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3) 134 | return canvas 135 | 136 | def forward(self, color=None, mask_cfg=None, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None): 137 | inp_data = self.inpainting_anno.forward(frames, video, mask, bbox, label, caption, mode) 138 | inp_masks = inp_data['masks'] 139 | 140 | label = label[0] if label is not None and isinstance(label, list) else label 141 | if label is not None and label in self.color_dict: 142 | color = self.color_dict[label] 143 | elif color is not None: 144 | color = color 145 | else: 146 | color = self.color_dict['default'] 147 | 148 | num_frames = len(inp_masks) 149 | ret_data = [] 150 | for i in range(num_frames): 151 | sub_mask = inp_masks[i] 152 | if self.use_aug and mask_cfg is not None: 153 | sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg) 154 | canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255 155 | contour = self.find_contours(sub_mask) 156 | frame = self.draw_contours(canvas, contour, color) 157 | ret_data.append(frame) 158 | 159 | return ret_data 160 | 161 | 162 | -------------------------------------------------------------------------------- /vace/annotators/mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import numpy as np 5 | from scipy.spatial import ConvexHull 6 | from skimage.draw import polygon 7 | from scipy import ndimage 8 | 9 | from .utils import convert_to_numpy 10 | 11 | 12 | class MaskDrawAnnotator: 13 | def __init__(self, cfg, device=None): 14 | self.mode = cfg.get('MODE', 'maskpoint') 15 | self.return_dict = cfg.get('RETURN_DICT', True) 16 | assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox'] 17 | 18 | def forward(self, 19 | mask=None, 20 | image=None, 21 | bbox=None, 22 | mode=None, 23 | return_dict=None): 24 | mode = mode if mode is not None else self.mode 25 | return_dict = return_dict if return_dict is not None else self.return_dict 26 | 27 | mask = convert_to_numpy(mask) if mask is not None else None 28 | image = convert_to_numpy(image) if image is not None else None 29 | 30 | mask_shape = mask.shape 31 | if mode == 'maskpoint': 32 | scribble = mask.transpose(1, 0) 33 | labeled_array, num_features = ndimage.label(scribble >= 255) 34 | centers = ndimage.center_of_mass(scribble, labeled_array, 35 | range(1, num_features + 1)) 36 | centers = np.array(centers) 37 | out_mask = np.zeros(mask_shape, dtype=np.uint8) 38 | hull = ConvexHull(centers) 39 | hull_vertices = centers[hull.vertices] 40 | rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape) 41 | out_mask[rr, cc] = 255 42 | elif mode == 'maskbbox': 43 | scribble = mask.transpose(1, 0) 44 | labeled_array, num_features = ndimage.label(scribble >= 255) 45 | centers = ndimage.center_of_mass(scribble, labeled_array, 46 | range(1, num_features + 1)) 47 | centers = np.array(centers) 48 | # (x1, y1, x2, y2) 49 | x_min = centers[:, 0].min() 50 | x_max = centers[:, 0].max() 51 | y_min = centers[:, 1].min() 52 | y_max = centers[:, 1].max() 53 | out_mask = np.zeros(mask_shape, dtype=np.uint8) 54 | out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255 55 | if image is not None: 56 | out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] 57 | elif mode == 'bbox': 58 | if isinstance(bbox, list): 59 | bbox = np.array(bbox) 60 | x_min, y_min, x_max, y_max = bbox 61 | out_mask = np.zeros(mask_shape, dtype=np.uint8) 62 | out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255 63 | if image is not None: 64 | out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] 65 | elif mode == 'mask': 66 | out_mask = mask 67 | else: 68 | raise NotImplementedError 69 | 70 | if return_dict: 71 | if image is not None: 72 | return {"image": out_image, "mask": out_mask} 73 | else: 74 | return {"mask": out_mask} 75 | else: 76 | if image is not None: 77 | return out_image, out_mask 78 | else: 79 | return out_mask -------------------------------------------------------------------------------- /vace/annotators/maskaug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | 5 | import random 6 | from functools import partial 7 | 8 | import cv2 9 | import numpy as np 10 | from PIL import Image, ImageDraw 11 | 12 | from .utils import convert_to_numpy 13 | 14 | 15 | 16 | class MaskAugAnnotator: 17 | def __init__(self, cfg, device=None): 18 | # original / original_expand / hull / hull_expand / bbox / bbox_expand 19 | self.mask_cfg = cfg.get('MASK_CFG', [{"mode": "original", "proba": 0.1}, 20 | {"mode": "original_expand", "proba": 0.1}, 21 | {"mode": "hull", "proba": 0.1}, 22 | {"mode": "hull_expand", "proba":0.1, "kwargs": {"expand_ratio": 0.2}}, 23 | {"mode": "bbox", "proba": 0.1}, 24 | {"mode": "bbox_expand", "proba": 0.1, "kwargs": {"min_expand_ratio": 0.2, "max_expand_ratio": 0.5}}]) 25 | 26 | def forward(self, mask, mask_cfg=None): 27 | mask_cfg = mask_cfg if mask_cfg is not None else self.mask_cfg 28 | if not isinstance(mask, list): 29 | is_batch = False 30 | masks = [mask] 31 | else: 32 | is_batch = True 33 | masks = mask 34 | 35 | mask_func = self.get_mask_func(mask_cfg) 36 | # print(mask_func) 37 | aug_masks = [] 38 | for submask in masks: 39 | mask = convert_to_numpy(submask) 40 | valid, large, h, w, bbox = self.get_mask_info(mask) 41 | # print(valid, large, h, w, bbox) 42 | if valid: 43 | mask = mask_func(mask, bbox, h, w) 44 | else: 45 | mask = mask.astype(np.uint8) 46 | aug_masks.append(mask) 47 | return aug_masks if is_batch else aug_masks[0] 48 | 49 | def get_mask_info(self, mask): 50 | h, w = mask.shape 51 | locs = mask.nonzero() 52 | valid = True 53 | if len(locs) < 1 or locs[0].shape[0] < 1 or locs[1].shape[0] < 1: 54 | valid = False 55 | return valid, False, h, w, [0, 0, 0, 0] 56 | 57 | left, right = np.min(locs[1]), np.max(locs[1]) 58 | top, bottom = np.min(locs[0]), np.max(locs[0]) 59 | bbox = [left, top, right, bottom] 60 | 61 | large = False 62 | if (right - left + 1) * (bottom - top + 1) > 0.9 * h * w: 63 | large = True 64 | return valid, large, h, w, bbox 65 | 66 | def get_expand_params(self, mask_kwargs): 67 | if 'expand_ratio' in mask_kwargs: 68 | expand_ratio = mask_kwargs['expand_ratio'] 69 | elif 'min_expand_ratio' in mask_kwargs and 'max_expand_ratio' in mask_kwargs: 70 | expand_ratio = random.uniform(mask_kwargs['min_expand_ratio'], mask_kwargs['max_expand_ratio']) 71 | else: 72 | expand_ratio = 0.3 73 | 74 | if 'expand_iters' in mask_kwargs: 75 | expand_iters = mask_kwargs['expand_iters'] 76 | else: 77 | expand_iters = random.randint(1, 10) 78 | 79 | if 'expand_lrtp' in mask_kwargs: 80 | expand_lrtp = mask_kwargs['expand_lrtp'] 81 | else: 82 | expand_lrtp = [random.random(), random.random(), random.random(), random.random()] 83 | 84 | return expand_ratio, expand_iters, expand_lrtp 85 | 86 | def get_mask_func(self, mask_cfg): 87 | if not isinstance(mask_cfg, list): 88 | mask_cfg = [mask_cfg] 89 | probas = [item['proba'] if 'proba' in item else 1.0 / len(mask_cfg) for item in mask_cfg] 90 | sel_mask_cfg = random.choices(mask_cfg, weights=probas, k=1)[0] 91 | mode = sel_mask_cfg['mode'] if 'mode' in sel_mask_cfg else 'original' 92 | mask_kwargs = sel_mask_cfg['kwargs'] if 'kwargs' in sel_mask_cfg else {} 93 | 94 | if mode == 'random': 95 | mode = random.choice(['original', 'original_expand', 'hull', 'hull_expand', 'bbox', 'bbox_expand']) 96 | if mode == 'original': 97 | mask_func = partial(self.generate_mask) 98 | elif mode == 'original_expand': 99 | expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs) 100 | mask_func = partial(self.generate_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp) 101 | elif mode == 'hull': 102 | clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise'] 103 | mask_func = partial(self.generate_hull_mask, clockwise=clockwise) 104 | elif mode == 'hull_expand': 105 | expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs) 106 | clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise'] 107 | mask_func = partial(self.generate_hull_mask, clockwise=clockwise, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp) 108 | elif mode == 'bbox': 109 | mask_func = partial(self.generate_bbox_mask) 110 | elif mode == 'bbox_expand': 111 | expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs) 112 | mask_func = partial(self.generate_bbox_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp) 113 | else: 114 | raise NotImplementedError 115 | return mask_func 116 | 117 | 118 | def generate_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None): 119 | bin_mask = mask.astype(np.uint8) 120 | if expand_ratio: 121 | bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp) 122 | return bin_mask 123 | 124 | 125 | @staticmethod 126 | def rand_expand_mask(mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None): 127 | expand_ratio = 0.3 if expand_ratio is None else expand_ratio 128 | expand_iters = random.randint(1, 10) if expand_iters is None else expand_iters 129 | expand_lrtp = [random.random(), random.random(), random.random(), random.random()] if expand_lrtp is None else expand_lrtp 130 | # print('iters', expand_iters, 'expand_ratio', expand_ratio, 'expand_lrtp', expand_lrtp) 131 | # mask = np.squeeze(mask) 132 | left, top, right, bottom = bbox 133 | # mask expansion 134 | box_w = (right - left + 1) * expand_ratio 135 | box_h = (bottom - top + 1) * expand_ratio 136 | left_, right_ = int(expand_lrtp[0] * min(box_w, left / 2) / expand_iters), int( 137 | expand_lrtp[1] * min(box_w, (w - right) / 2) / expand_iters) 138 | top_, bottom_ = int(expand_lrtp[2] * min(box_h, top / 2) / expand_iters), int( 139 | expand_lrtp[3] * min(box_h, (h - bottom) / 2) / expand_iters) 140 | kernel_size = max(left_, right_, top_, bottom_) 141 | if kernel_size > 0: 142 | kernel = np.zeros((kernel_size * 2, kernel_size * 2), dtype=np.uint8) 143 | new_left, new_right = kernel_size - right_, kernel_size + left_ 144 | new_top, new_bottom = kernel_size - bottom_, kernel_size + top_ 145 | kernel[new_top:new_bottom + 1, new_left:new_right + 1] = 1 146 | mask = mask.astype(np.uint8) 147 | mask = cv2.dilate(mask, kernel, iterations=expand_iters).astype(np.uint8) 148 | # mask = new_mask - (mask / 2).astype(np.uint8) 149 | # mask = np.expand_dims(mask, axis=-1) 150 | return mask 151 | 152 | 153 | @staticmethod 154 | def _convexhull(image, clockwise): 155 | contours, hierarchy = cv2.findContours(image, 2, 1) 156 | cnt = np.concatenate(contours) # merge all regions 157 | hull = cv2.convexHull(cnt, clockwise=clockwise) 158 | hull = np.squeeze(hull, axis=1).astype(np.float32).tolist() 159 | hull = [tuple(x) for x in hull] 160 | return hull # b, 1, 2 161 | 162 | def generate_hull_mask(self, mask, bbox, h, w, clockwise=None, expand_ratio=None, expand_iters=None, expand_lrtp=None): 163 | clockwise = random.choice([True, False]) if clockwise is None else clockwise 164 | hull = self._convexhull(mask, clockwise) 165 | mask_img = Image.new('L', (w, h), 0) 166 | pt_list = hull 167 | mask_img_draw = ImageDraw.Draw(mask_img) 168 | mask_img_draw.polygon(pt_list, fill=255) 169 | bin_mask = np.array(mask_img).astype(np.uint8) 170 | if expand_ratio: 171 | bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp) 172 | return bin_mask 173 | 174 | 175 | def generate_bbox_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None): 176 | left, top, right, bottom = bbox 177 | bin_mask = np.zeros((h, w), dtype=np.uint8) 178 | bin_mask[top:bottom + 1, left:right + 1] = 255 179 | if expand_ratio: 180 | bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp) 181 | return bin_mask -------------------------------------------------------------------------------- /vace/annotators/midas/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | -------------------------------------------------------------------------------- /vace/annotators/midas/api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | # based on https://github.com/isl-org/MiDaS 4 | 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from torchvision.transforms import Compose 9 | 10 | from .dpt_depth import DPTDepthModel 11 | from .midas_net import MidasNet 12 | from .midas_net_custom import MidasNet_small 13 | from .transforms import NormalizeImage, PrepareForNet, Resize 14 | 15 | # ISL_PATHS = { 16 | # "dpt_large": "dpt_large-midas-2f21e586.pt", 17 | # "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt", 18 | # "midas_v21": "", 19 | # "midas_v21_small": "", 20 | # } 21 | 22 | # remote_model_path = 23 | # "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" 24 | 25 | 26 | def disabled_train(self, mode=True): 27 | """Overwrite model.train with this function to make sure train/eval mode 28 | does not change anymore.""" 29 | return self 30 | 31 | 32 | def load_midas_transform(model_type): 33 | # https://github.com/isl-org/MiDaS/blob/master/run.py 34 | # load transform only 35 | if model_type == 'dpt_large': # DPT-Large 36 | net_w, net_h = 384, 384 37 | resize_mode = 'minimal' 38 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], 39 | std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == 'dpt_hybrid': # DPT-Hybrid 42 | net_w, net_h = 384, 384 43 | resize_mode = 'minimal' 44 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], 45 | std=[0.5, 0.5, 0.5]) 46 | 47 | elif model_type == 'midas_v21': 48 | net_w, net_h = 384, 384 49 | resize_mode = 'upper_bound' 50 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]) 52 | 53 | elif model_type == 'midas_v21_small': 54 | net_w, net_h = 256, 256 55 | resize_mode = 'upper_bound' 56 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], 57 | std=[0.229, 0.224, 0.225]) 58 | 59 | else: 60 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 61 | 62 | transform = Compose([ 63 | Resize( 64 | net_w, 65 | net_h, 66 | resize_target=None, 67 | keep_aspect_ratio=True, 68 | ensure_multiple_of=32, 69 | resize_method=resize_mode, 70 | image_interpolation_method=cv2.INTER_CUBIC, 71 | ), 72 | normalization, 73 | PrepareForNet(), 74 | ]) 75 | 76 | return transform 77 | 78 | 79 | def load_model(model_type, model_path): 80 | # https://github.com/isl-org/MiDaS/blob/master/run.py 81 | # load network 82 | # model_path = ISL_PATHS[model_type] 83 | if model_type == 'dpt_large': # DPT-Large 84 | model = DPTDepthModel( 85 | path=model_path, 86 | backbone='vitl16_384', 87 | non_negative=True, 88 | ) 89 | net_w, net_h = 384, 384 90 | resize_mode = 'minimal' 91 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], 92 | std=[0.5, 0.5, 0.5]) 93 | 94 | elif model_type == 'dpt_hybrid': # DPT-Hybrid 95 | model = DPTDepthModel( 96 | path=model_path, 97 | backbone='vitb_rn50_384', 98 | non_negative=True, 99 | ) 100 | net_w, net_h = 384, 384 101 | resize_mode = 'minimal' 102 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], 103 | std=[0.5, 0.5, 0.5]) 104 | 105 | elif model_type == 'midas_v21': 106 | model = MidasNet(model_path, non_negative=True) 107 | net_w, net_h = 384, 384 108 | resize_mode = 'upper_bound' 109 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], 110 | std=[0.229, 0.224, 0.225]) 111 | 112 | elif model_type == 'midas_v21_small': 113 | model = MidasNet_small(model_path, 114 | features=64, 115 | backbone='efficientnet_lite3', 116 | exportable=True, 117 | non_negative=True, 118 | blocks={'expand': True}) 119 | net_w, net_h = 256, 256 120 | resize_mode = 'upper_bound' 121 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], 122 | std=[0.229, 0.224, 0.225]) 123 | 124 | else: 125 | print( 126 | f"model_type '{model_type}' not implemented, use: --model_type large" 127 | ) 128 | assert False 129 | 130 | transform = Compose([ 131 | Resize( 132 | net_w, 133 | net_h, 134 | resize_target=None, 135 | keep_aspect_ratio=True, 136 | ensure_multiple_of=32, 137 | resize_method=resize_mode, 138 | image_interpolation_method=cv2.INTER_CUBIC, 139 | ), 140 | normalization, 141 | PrepareForNet(), 142 | ]) 143 | 144 | return model.eval(), transform 145 | 146 | 147 | class MiDaSInference(nn.Module): 148 | MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small'] 149 | MODEL_TYPES_ISL = [ 150 | 'dpt_large', 151 | 'dpt_hybrid', 152 | 'midas_v21', 153 | 'midas_v21_small', 154 | ] 155 | 156 | def __init__(self, model_type, model_path): 157 | super().__init__() 158 | assert (model_type in self.MODEL_TYPES_ISL) 159 | model, _ = load_model(model_type, model_path) 160 | self.model = model 161 | self.model.train = disabled_train 162 | 163 | def forward(self, x): 164 | with torch.no_grad(): 165 | prediction = self.model(x) 166 | return prediction 167 | -------------------------------------------------------------------------------- /vace/annotators/midas/base_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | 5 | 6 | class BaseModel(torch.nn.Module): 7 | def load(self, path): 8 | """Load model from file. 9 | 10 | Args: 11 | path (str): file path 12 | """ 13 | parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True) 14 | 15 | if 'optimizer' in parameters: 16 | parameters = parameters['model'] 17 | 18 | self.load_state_dict(parameters) 19 | -------------------------------------------------------------------------------- /vace/annotators/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .base_model import BaseModel 7 | from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder 8 | from .vit import forward_vit 9 | 10 | 11 | def _make_fusion_block(features, use_bn): 12 | return FeatureFusionBlock_custom( 13 | features, 14 | nn.ReLU(False), 15 | deconv=False, 16 | bn=use_bn, 17 | expand=False, 18 | align_corners=True, 19 | ) 20 | 21 | 22 | class DPT(BaseModel): 23 | def __init__( 24 | self, 25 | head, 26 | features=256, 27 | backbone='vitb_rn50_384', 28 | readout='project', 29 | channels_last=False, 30 | use_bn=False, 31 | ): 32 | 33 | super(DPT, self).__init__() 34 | 35 | self.channels_last = channels_last 36 | 37 | hooks = { 38 | 'vitb_rn50_384': [0, 1, 8, 11], 39 | 'vitb16_384': [2, 5, 8, 11], 40 | 'vitl16_384': [5, 11, 17, 23], 41 | } 42 | 43 | # Instantiate backbone and reassemble blocks 44 | self.pretrained, self.scratch = _make_encoder( 45 | backbone, 46 | features, 47 | False, # Set to true of you want to train from scratch, uses ImageNet weights 48 | groups=1, 49 | expand=False, 50 | exportable=False, 51 | hooks=hooks[backbone], 52 | use_readout=readout, 53 | ) 54 | 55 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 56 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 57 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 58 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 59 | 60 | self.scratch.output_conv = head 61 | 62 | def forward(self, x): 63 | if self.channels_last is True: 64 | x.contiguous(memory_format=torch.channels_last) 65 | 66 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 67 | 68 | layer_1_rn = self.scratch.layer1_rn(layer_1) 69 | layer_2_rn = self.scratch.layer2_rn(layer_2) 70 | layer_3_rn = self.scratch.layer3_rn(layer_3) 71 | layer_4_rn = self.scratch.layer4_rn(layer_4) 72 | 73 | path_4 = self.scratch.refinenet4(layer_4_rn) 74 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 75 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 76 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 77 | 78 | out = self.scratch.output_conv(path_1) 79 | 80 | return out 81 | 82 | 83 | class DPTDepthModel(DPT): 84 | def __init__(self, path=None, non_negative=True, **kwargs): 85 | features = kwargs['features'] if 'features' in kwargs else 256 86 | 87 | head = nn.Sequential( 88 | nn.Conv2d(features, 89 | features // 2, 90 | kernel_size=3, 91 | stride=1, 92 | padding=1), 93 | Interpolate(scale_factor=2, mode='bilinear', align_corners=True), 94 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 95 | nn.ReLU(True), 96 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 97 | nn.ReLU(True) if non_negative else nn.Identity(), 98 | nn.Identity(), 99 | ) 100 | 101 | super().__init__(head, **kwargs) 102 | 103 | if path is not None: 104 | self.load(path) 105 | 106 | def forward(self, x): 107 | return super().forward(x).squeeze(dim=1) 108 | -------------------------------------------------------------------------------- /vace/annotators/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 4 | This file contains code that is adapted from 5 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .base_model import BaseModel 11 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 12 | 13 | 14 | class MidasNet(BaseModel): 15 | """Network for monocular depth estimation. 16 | """ 17 | def __init__(self, path=None, features=256, non_negative=True): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print('Loading weights: ', path) 26 | 27 | super(MidasNet, self).__init__() 28 | 29 | use_pretrained = False if path is None else True 30 | 31 | self.pretrained, self.scratch = _make_encoder( 32 | backbone='resnext101_wsl', 33 | features=features, 34 | use_pretrained=use_pretrained) 35 | 36 | self.scratch.refinenet4 = FeatureFusionBlock(features) 37 | self.scratch.refinenet3 = FeatureFusionBlock(features) 38 | self.scratch.refinenet2 = FeatureFusionBlock(features) 39 | self.scratch.refinenet1 = FeatureFusionBlock(features) 40 | 41 | self.scratch.output_conv = nn.Sequential( 42 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 43 | Interpolate(scale_factor=2, mode='bilinear'), 44 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 45 | nn.ReLU(True), 46 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 47 | nn.ReLU(True) if non_negative else nn.Identity(), 48 | ) 49 | 50 | if path: 51 | self.load(path) 52 | 53 | def forward(self, x): 54 | """Forward pass. 55 | 56 | Args: 57 | x (tensor): input data (image) 58 | 59 | Returns: 60 | tensor: depth 61 | """ 62 | 63 | layer_1 = self.pretrained.layer1(x) 64 | layer_2 = self.pretrained.layer2(layer_1) 65 | layer_3 = self.pretrained.layer3(layer_2) 66 | layer_4 = self.pretrained.layer4(layer_3) 67 | 68 | layer_1_rn = self.scratch.layer1_rn(layer_1) 69 | layer_2_rn = self.scratch.layer2_rn(layer_2) 70 | layer_3_rn = self.scratch.layer3_rn(layer_3) 71 | layer_4_rn = self.scratch.layer4_rn(layer_4) 72 | 73 | path_4 = self.scratch.refinenet4(layer_4_rn) 74 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 75 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 76 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 77 | 78 | out = self.scratch.output_conv(path_1) 79 | 80 | return torch.squeeze(out, dim=1) 81 | -------------------------------------------------------------------------------- /vace/annotators/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 4 | This file contains code that is adapted from 5 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .base_model import BaseModel 11 | from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder 12 | 13 | 14 | class MidasNet_small(BaseModel): 15 | """Network for monocular depth estimation. 16 | """ 17 | def __init__(self, 18 | path=None, 19 | features=64, 20 | backbone='efficientnet_lite3', 21 | non_negative=True, 22 | exportable=True, 23 | channels_last=False, 24 | align_corners=True, 25 | blocks={'expand': True}): 26 | """Init. 27 | 28 | Args: 29 | path (str, optional): Path to saved model. Defaults to None. 30 | features (int, optional): Number of features. Defaults to 256. 31 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 32 | """ 33 | print('Loading weights: ', path) 34 | 35 | super(MidasNet_small, self).__init__() 36 | 37 | use_pretrained = False if path else True 38 | 39 | self.channels_last = channels_last 40 | self.blocks = blocks 41 | self.backbone = backbone 42 | 43 | self.groups = 1 44 | 45 | features1 = features 46 | features2 = features 47 | features3 = features 48 | features4 = features 49 | self.expand = False 50 | if 'expand' in self.blocks and self.blocks['expand'] is True: 51 | self.expand = True 52 | features1 = features 53 | features2 = features * 2 54 | features3 = features * 4 55 | features4 = features * 8 56 | 57 | self.pretrained, self.scratch = _make_encoder(self.backbone, 58 | features, 59 | use_pretrained, 60 | groups=self.groups, 61 | expand=self.expand, 62 | exportable=exportable) 63 | 64 | self.scratch.activation = nn.ReLU(False) 65 | 66 | self.scratch.refinenet4 = FeatureFusionBlock_custom( 67 | features4, 68 | self.scratch.activation, 69 | deconv=False, 70 | bn=False, 71 | expand=self.expand, 72 | align_corners=align_corners) 73 | self.scratch.refinenet3 = FeatureFusionBlock_custom( 74 | features3, 75 | self.scratch.activation, 76 | deconv=False, 77 | bn=False, 78 | expand=self.expand, 79 | align_corners=align_corners) 80 | self.scratch.refinenet2 = FeatureFusionBlock_custom( 81 | features2, 82 | self.scratch.activation, 83 | deconv=False, 84 | bn=False, 85 | expand=self.expand, 86 | align_corners=align_corners) 87 | self.scratch.refinenet1 = FeatureFusionBlock_custom( 88 | features1, 89 | self.scratch.activation, 90 | deconv=False, 91 | bn=False, 92 | align_corners=align_corners) 93 | 94 | self.scratch.output_conv = nn.Sequential( 95 | nn.Conv2d(features, 96 | features // 2, 97 | kernel_size=3, 98 | stride=1, 99 | padding=1, 100 | groups=self.groups), 101 | Interpolate(scale_factor=2, mode='bilinear'), 102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 103 | self.scratch.activation, 104 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 105 | nn.ReLU(True) if non_negative else nn.Identity(), 106 | nn.Identity(), 107 | ) 108 | 109 | if path: 110 | self.load(path) 111 | 112 | def forward(self, x): 113 | """Forward pass. 114 | 115 | Args: 116 | x (tensor): input data (image) 117 | 118 | Returns: 119 | tensor: depth 120 | """ 121 | if self.channels_last is True: 122 | print('self.channels_last = ', self.channels_last) 123 | x.contiguous(memory_format=torch.channels_last) 124 | 125 | layer_1 = self.pretrained.layer1(x) 126 | layer_2 = self.pretrained.layer2(layer_1) 127 | layer_3 = self.pretrained.layer3(layer_2) 128 | layer_4 = self.pretrained.layer4(layer_3) 129 | 130 | layer_1_rn = self.scratch.layer1_rn(layer_1) 131 | layer_2_rn = self.scratch.layer2_rn(layer_2) 132 | layer_3_rn = self.scratch.layer3_rn(layer_3) 133 | layer_4_rn = self.scratch.layer4_rn(layer_4) 134 | 135 | path_4 = self.scratch.refinenet4(layer_4_rn) 136 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 137 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 138 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 139 | 140 | out = self.scratch.output_conv(path_1) 141 | 142 | return torch.squeeze(out, dim=1) 143 | 144 | 145 | def fuse_model(m): 146 | prev_previous_type = nn.Identity() 147 | prev_previous_name = '' 148 | previous_type = nn.Identity() 149 | previous_name = '' 150 | for name, module in m.named_modules(): 151 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type( 152 | module) == nn.ReLU: 153 | # print("FUSED ", prev_previous_name, previous_name, name) 154 | torch.quantization.fuse_modules( 155 | m, [prev_previous_name, previous_name, name], inplace=True) 156 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 157 | # print("FUSED ", prev_previous_name, previous_name) 158 | torch.quantization.fuse_modules( 159 | m, [prev_previous_name, previous_name], inplace=True) 160 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 161 | # print("FUSED ", previous_name, name) 162 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 163 | 164 | prev_previous_type = previous_type 165 | prev_previous_name = previous_name 166 | previous_type = type(module) 167 | previous_name = name 168 | -------------------------------------------------------------------------------- /vace/annotators/midas/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | """Utils for monoDepth.""" 4 | import re 5 | import sys 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def read_pfm(path): 13 | """Read pfm file. 14 | 15 | Args: 16 | path (str): path to file 17 | 18 | Returns: 19 | tuple: (data, scale) 20 | """ 21 | with open(path, 'rb') as file: 22 | 23 | color = None 24 | width = None 25 | height = None 26 | scale = None 27 | endian = None 28 | 29 | header = file.readline().rstrip() 30 | if header.decode('ascii') == 'PF': 31 | color = True 32 | elif header.decode('ascii') == 'Pf': 33 | color = False 34 | else: 35 | raise Exception('Not a PFM file: ' + path) 36 | 37 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', 38 | file.readline().decode('ascii')) 39 | if dim_match: 40 | width, height = list(map(int, dim_match.groups())) 41 | else: 42 | raise Exception('Malformed PFM header.') 43 | 44 | scale = float(file.readline().decode('ascii').rstrip()) 45 | if scale < 0: 46 | # little-endian 47 | endian = '<' 48 | scale = -scale 49 | else: 50 | # big-endian 51 | endian = '>' 52 | 53 | data = np.fromfile(file, endian + 'f') 54 | shape = (height, width, 3) if color else (height, width) 55 | 56 | data = np.reshape(data, shape) 57 | data = np.flipud(data) 58 | 59 | return data, scale 60 | 61 | 62 | def write_pfm(path, image, scale=1): 63 | """Write pfm file. 64 | 65 | Args: 66 | path (str): pathto file 67 | image (array): data 68 | scale (int, optional): Scale. Defaults to 1. 69 | """ 70 | 71 | with open(path, 'wb') as file: 72 | color = None 73 | 74 | if image.dtype.name != 'float32': 75 | raise Exception('Image dtype must be float32.') 76 | 77 | image = np.flipud(image) 78 | 79 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 80 | color = True 81 | elif (len(image.shape) == 2 82 | or len(image.shape) == 3 and image.shape[2] == 1): # greyscale 83 | color = False 84 | else: 85 | raise Exception( 86 | 'Image must have H x W x 3, H x W x 1 or H x W dimensions.') 87 | 88 | file.write('PF\n' if color else 'Pf\n'.encode()) 89 | file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) 90 | 91 | endian = image.dtype.byteorder 92 | 93 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 94 | scale = -scale 95 | 96 | file.write('%f\n'.encode() % scale) 97 | 98 | image.tofile(file) 99 | 100 | 101 | def read_image(path): 102 | """Read image and output RGB image (0-1). 103 | 104 | Args: 105 | path (str): path to file 106 | 107 | Returns: 108 | array: RGB image (0-1) 109 | """ 110 | img = cv2.imread(path) 111 | 112 | if img.ndim == 2: 113 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 114 | 115 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 116 | 117 | return img 118 | 119 | 120 | def resize_image(img): 121 | """Resize image and make it fit for network. 122 | 123 | Args: 124 | img (array): image 125 | 126 | Returns: 127 | tensor: data ready for network 128 | """ 129 | height_orig = img.shape[0] 130 | width_orig = img.shape[1] 131 | 132 | if width_orig > height_orig: 133 | scale = width_orig / 384 134 | else: 135 | scale = height_orig / 384 136 | 137 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 138 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 139 | 140 | img_resized = cv2.resize(img, (width, height), 141 | interpolation=cv2.INTER_AREA) 142 | 143 | img_resized = (torch.from_numpy(np.transpose( 144 | img_resized, (2, 0, 1))).contiguous().float()) 145 | img_resized = img_resized.unsqueeze(0) 146 | 147 | return img_resized 148 | 149 | 150 | def resize_depth(depth, width, height): 151 | """Resize depth map and bring to CPU (numpy). 152 | 153 | Args: 154 | depth (tensor): depth 155 | width (int): image width 156 | height (int): image height 157 | 158 | Returns: 159 | array: processed depth 160 | """ 161 | depth = torch.squeeze(depth[0, :, :, :]).to('cpu') 162 | 163 | depth_resized = cv2.resize(depth.numpy(), (width, height), 164 | interpolation=cv2.INTER_CUBIC) 165 | 166 | return depth_resized 167 | 168 | 169 | def write_depth(path, depth, bits=1): 170 | """Write depth map to pfm and png file. 171 | 172 | Args: 173 | path (str): filepath without extension 174 | depth (array): depth 175 | """ 176 | write_pfm(path + '.pfm', depth.astype(np.float32)) 177 | 178 | depth_min = depth.min() 179 | depth_max = depth.max() 180 | 181 | max_val = (2**(8 * bits)) - 1 182 | 183 | if depth_max - depth_min > np.finfo('float').eps: 184 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 185 | else: 186 | out = np.zeros(depth.shape, dtype=depth.type) 187 | 188 | if bits == 1: 189 | cv2.imwrite(path + '.png', out.astype('uint8')) 190 | elif bits == 2: 191 | cv2.imwrite(path + '.png', out.astype('uint16')) 192 | 193 | return 194 | -------------------------------------------------------------------------------- /vace/annotators/pose.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import os 5 | 6 | import cv2 7 | import torch 8 | import numpy as np 9 | from .dwpose import util 10 | from .dwpose.wholebody import Wholebody, HWC3, resize_image 11 | from .utils import convert_to_numpy 12 | 13 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 14 | 15 | 16 | 17 | def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False): 18 | bodies = pose['bodies'] 19 | faces = pose['faces'] 20 | hands = pose['hands'] 21 | candidate = bodies['candidate'] 22 | subset = bodies['subset'] 23 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 24 | 25 | if use_body: 26 | canvas = util.draw_bodypose(canvas, candidate, subset) 27 | if use_hand: 28 | canvas = util.draw_handpose(canvas, hands) 29 | if use_face: 30 | canvas = util.draw_facepose(canvas, faces) 31 | 32 | return canvas 33 | 34 | 35 | class PoseAnnotator: 36 | def __init__(self, cfg, device=None): 37 | onnx_det = cfg['DETECTION_MODEL'] 38 | onnx_pose = cfg['POSE_MODEL'] 39 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 40 | self.pose_estimation = Wholebody(onnx_det, onnx_pose, device=self.device) 41 | self.resize_size = cfg.get("RESIZE_SIZE", 1024) 42 | self.use_body = cfg.get('USE_BODY', True) 43 | self.use_face = cfg.get('USE_FACE', True) 44 | self.use_hand = cfg.get('USE_HAND', True) 45 | 46 | @torch.no_grad() 47 | @torch.inference_mode 48 | def forward(self, image): 49 | image = convert_to_numpy(image) 50 | input_image = HWC3(image[..., ::-1]) 51 | return self.process(resize_image(input_image, self.resize_size), image.shape[:2]) 52 | 53 | def process(self, ori_img, ori_shape): 54 | ori_h, ori_w = ori_shape 55 | ori_img = ori_img.copy() 56 | H, W, C = ori_img.shape 57 | with torch.no_grad(): 58 | candidate, subset, det_result = self.pose_estimation(ori_img) 59 | nums, keys, locs = candidate.shape 60 | candidate[..., 0] /= float(W) 61 | candidate[..., 1] /= float(H) 62 | body = candidate[:, :18].copy() 63 | body = body.reshape(nums * 18, locs) 64 | score = subset[:, :18] 65 | for i in range(len(score)): 66 | for j in range(len(score[i])): 67 | if score[i][j] > 0.3: 68 | score[i][j] = int(18 * i + j) 69 | else: 70 | score[i][j] = -1 71 | 72 | un_visible = subset < 0.3 73 | candidate[un_visible] = -1 74 | 75 | foot = candidate[:, 18:24] 76 | 77 | faces = candidate[:, 24:92] 78 | 79 | hands = candidate[:, 92:113] 80 | hands = np.vstack([hands, candidate[:, 113:]]) 81 | 82 | bodies = dict(candidate=body, subset=score) 83 | pose = dict(bodies=bodies, hands=hands, faces=faces) 84 | 85 | ret_data = {} 86 | if self.use_body: 87 | detected_map_body = draw_pose(pose, H, W, use_body=True) 88 | detected_map_body = cv2.resize(detected_map_body[..., ::-1], (ori_w, ori_h), 89 | interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA) 90 | ret_data["detected_map_body"] = detected_map_body 91 | 92 | if self.use_face: 93 | detected_map_face = draw_pose(pose, H, W, use_face=True) 94 | detected_map_face = cv2.resize(detected_map_face[..., ::-1], (ori_w, ori_h), 95 | interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA) 96 | ret_data["detected_map_face"] = detected_map_face 97 | 98 | if self.use_body and self.use_face: 99 | detected_map_bodyface = draw_pose(pose, H, W, use_body=True, use_face=True) 100 | detected_map_bodyface = cv2.resize(detected_map_bodyface[..., ::-1], (ori_w, ori_h), 101 | interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA) 102 | ret_data["detected_map_bodyface"] = detected_map_bodyface 103 | 104 | if self.use_hand and self.use_body and self.use_face: 105 | detected_map_handbodyface = draw_pose(pose, H, W, use_hand=True, use_body=True, use_face=True) 106 | detected_map_handbodyface = cv2.resize(detected_map_handbodyface[..., ::-1], (ori_w, ori_h), 107 | interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA) 108 | ret_data["detected_map_handbodyface"] = detected_map_handbodyface 109 | 110 | # convert_size 111 | if det_result.shape[0] > 0: 112 | w_ratio, h_ratio = ori_w / W, ori_h / H 113 | det_result[..., ::2] *= h_ratio 114 | det_result[..., 1::2] *= w_ratio 115 | det_result = det_result.astype(np.int32) 116 | return ret_data, det_result 117 | 118 | 119 | class PoseBodyFaceAnnotator(PoseAnnotator): 120 | def __init__(self, cfg, device=None): 121 | super().__init__(cfg, device) 122 | self.use_body, self.use_face, self.use_hand = True, True, False 123 | @torch.no_grad() 124 | @torch.inference_mode 125 | def forward(self, image): 126 | ret_data, det_result = super().forward(image) 127 | return ret_data['detected_map_bodyface'] 128 | 129 | 130 | class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator): 131 | def forward(self, frames): 132 | ret_frames = [] 133 | for frame in frames: 134 | anno_frame = super().forward(np.array(frame)) 135 | ret_frames.append(anno_frame) 136 | return ret_frames 137 | 138 | class PoseBodyAnnotator(PoseAnnotator): 139 | def __init__(self, cfg, device=None): 140 | super().__init__(cfg, device) 141 | self.use_body, self.use_face, self.use_hand = True, False, False 142 | @torch.no_grad() 143 | @torch.inference_mode 144 | def forward(self, image): 145 | ret_data, det_result = super().forward(image) 146 | return ret_data['detected_map_body'] 147 | 148 | 149 | class PoseBodyVideoAnnotator(PoseBodyAnnotator): 150 | def forward(self, frames): 151 | ret_frames = [] 152 | for frame in frames: 153 | anno_frame = super().forward(np.array(frame)) 154 | ret_frames.append(anno_frame) 155 | return ret_frames -------------------------------------------------------------------------------- /vace/annotators/prompt_extend.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import torch 5 | 6 | class PromptExtendAnnotator: 7 | def __init__(self, cfg, device=None): 8 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 9 | self.mode = cfg.get('MODE', "local_qwen") 10 | self.model_name = cfg.get('MODEL_NAME', "Qwen2.5_3B") 11 | self.is_vl = cfg.get('IS_VL', False) 12 | self.system_prompt = cfg.get('SYSTEM_PROMPT', None) 13 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 14 | self.device_id = self.device.index if self.device.type == 'cuda' else None 15 | rank = self.device_id if self.device_id is not None else 0 16 | if self.mode == "dashscope": 17 | self.prompt_expander = DashScopePromptExpander( 18 | model_name=self.model_name, is_vl=self.is_vl) 19 | elif self.mode == "local_qwen": 20 | self.prompt_expander = QwenPromptExpander( 21 | model_name=self.model_name, 22 | is_vl=self.is_vl, 23 | device=rank) 24 | else: 25 | raise NotImplementedError(f"Unsupport prompt_extend_method: {self.mode}") 26 | 27 | 28 | def forward(self, prompt, system_prompt=None, seed=-1): 29 | system_prompt = system_prompt if system_prompt is not None else self.system_prompt 30 | output = self.prompt_expander(prompt, system_prompt=system_prompt, seed=seed) 31 | if output.status == False: 32 | print(f"Extending prompt failed: {output.message}") 33 | output_prompt = prompt 34 | else: 35 | output_prompt = output.prompt 36 | return output_prompt 37 | -------------------------------------------------------------------------------- /vace/annotators/ram.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor 8 | from .utils import convert_to_pil 9 | 10 | class RAMAnnotator: 11 | def __init__(self, cfg, device=None): 12 | try: 13 | from ram.models import ram_plus, ram, tag2text 14 | from ram import inference_ram 15 | except: 16 | import warnings 17 | warnings.warn("please pip install ram package, or you can refer to models/VACE-Annotators/ram/ram-0.0.1-py3-none-any.whl") 18 | 19 | delete_tag_index = [] 20 | image_size = cfg.get('IMAGE_SIZE', 384) 21 | ram_tokenizer_path = cfg['TOKENIZER_PATH'] 22 | ram_checkpoint_path = cfg['PRETRAINED_MODEL'] 23 | ram_type = cfg.get('RAM_TYPE', 'swin_l') 24 | self.return_lang = cfg.get('RETURN_LANG', ['en']) # ['en', 'zh'] 25 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 26 | self.model = ram_plus(pretrained=ram_checkpoint_path, image_size=image_size, vit=ram_type, 27 | text_encoder_type=ram_tokenizer_path, delete_tag_index=delete_tag_index).eval().to(self.device) 28 | self.ram_transform = Compose([ 29 | Resize((image_size, image_size)), 30 | ToTensor(), 31 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 32 | ]) 33 | self.inference_ram = inference_ram 34 | 35 | def forward(self, image): 36 | image = convert_to_pil(image) 37 | image_ann_trans = self.ram_transform(image).unsqueeze(0).to(self.device) 38 | tags_e, tags_c = self.inference_ram(image_ann_trans, self.model) 39 | tags_e_list = [tag.strip() for tag in tags_e.strip().split("|")] 40 | tags_c_list = [tag.strip() for tag in tags_c.strip().split("|")] 41 | if len(self.return_lang) == 1 and 'en' in self.return_lang: 42 | return tags_e_list 43 | elif len(self.return_lang) == 1 and 'zh' in self.return_lang: 44 | return tags_c_list 45 | else: 46 | return { 47 | "tags_e": tags_e_list, 48 | "tags_c": tags_c_list 49 | } 50 | -------------------------------------------------------------------------------- /vace/annotators/sam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import numpy as np 4 | import torch 5 | from scipy import ndimage 6 | 7 | from .utils import convert_to_numpy 8 | 9 | 10 | class SAMImageAnnotator: 11 | def __init__(self, cfg, device=None): 12 | try: 13 | from segment_anything import sam_model_registry, SamPredictor 14 | from segment_anything.utils.transforms import ResizeLongestSide 15 | except: 16 | import warnings 17 | warnings.warn("please pip install sam package, or you can refer to models/VACE-Annotators/sam/segment_anything-1.0-py3-none-any.whl") 18 | self.task_type = cfg.get('TASK_TYPE', 'input_box') 19 | self.return_mask = cfg.get('RETURN_MASK', False) 20 | self.transform = ResizeLongestSide(1024) 21 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 22 | seg_model = sam_model_registry[cfg.get('MODEL_NAME', 'vit_b')](checkpoint=cfg['PRETRAINED_MODEL']).eval().to(self.device) 23 | self.predictor = SamPredictor(seg_model) 24 | 25 | def forward(self, 26 | image, 27 | input_box=None, 28 | mask=None, 29 | task_type=None, 30 | return_mask=None): 31 | task_type = task_type if task_type is not None else self.task_type 32 | return_mask = return_mask if return_mask is not None else self.return_mask 33 | mask = convert_to_numpy(mask) if mask is not None else None 34 | 35 | if task_type == 'mask_point': 36 | if len(mask.shape) == 3: 37 | scribble = mask.transpose(2, 1, 0)[0] 38 | else: 39 | scribble = mask.transpose(1, 0) # (H, W) -> (W, H) 40 | labeled_array, num_features = ndimage.label(scribble >= 255) 41 | centers = ndimage.center_of_mass(scribble, labeled_array, 42 | range(1, num_features + 1)) 43 | point_coords = np.array(centers) 44 | point_labels = np.array([1] * len(centers)) 45 | sample = { 46 | 'point_coords': point_coords, 47 | 'point_labels': point_labels 48 | } 49 | elif task_type == 'mask_box': 50 | if len(mask.shape) == 3: 51 | scribble = mask.transpose(2, 1, 0)[0] 52 | else: 53 | scribble = mask.transpose(1, 0) # (H, W) -> (W, H) 54 | labeled_array, num_features = ndimage.label(scribble >= 255) 55 | centers = ndimage.center_of_mass(scribble, labeled_array, 56 | range(1, num_features + 1)) 57 | centers = np.array(centers) 58 | # (x1, y1, x2, y2) 59 | x_min = centers[:, 0].min() 60 | x_max = centers[:, 0].max() 61 | y_min = centers[:, 1].min() 62 | y_max = centers[:, 1].max() 63 | bbox = np.array([x_min, y_min, x_max, y_max]) 64 | sample = {'box': bbox} 65 | elif task_type == 'input_box': 66 | if isinstance(input_box, list): 67 | input_box = np.array(input_box) 68 | sample = {'box': input_box} 69 | elif task_type == 'mask': 70 | sample = {'mask_input': mask[None, :, :]} 71 | else: 72 | raise NotImplementedError 73 | 74 | self.predictor.set_image(image) 75 | masks, scores, logits = self.predictor.predict( 76 | multimask_output=False, 77 | **sample 78 | ) 79 | sorted_ind = np.argsort(scores)[::-1] 80 | masks = masks[sorted_ind] 81 | scores = scores[sorted_ind] 82 | logits = logits[sorted_ind] 83 | 84 | if return_mask: 85 | return masks[0] 86 | else: 87 | ret_data = { 88 | "masks": masks, 89 | "scores": scores, 90 | "logits": logits 91 | } 92 | return ret_data -------------------------------------------------------------------------------- /vace/annotators/scribble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | 8 | from .utils import convert_to_torch 9 | 10 | norm_layer = nn.InstanceNorm2d 11 | 12 | 13 | class ResidualBlock(nn.Module): 14 | def __init__(self, in_features): 15 | super(ResidualBlock, self).__init__() 16 | 17 | conv_block = [ 18 | nn.ReflectionPad2d(1), 19 | nn.Conv2d(in_features, in_features, 3), 20 | norm_layer(in_features), 21 | nn.ReLU(inplace=True), 22 | nn.ReflectionPad2d(1), 23 | nn.Conv2d(in_features, in_features, 3), 24 | norm_layer(in_features) 25 | ] 26 | 27 | self.conv_block = nn.Sequential(*conv_block) 28 | 29 | def forward(self, x): 30 | return x + self.conv_block(x) 31 | 32 | 33 | class ContourInference(nn.Module): 34 | def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): 35 | super(ContourInference, self).__init__() 36 | 37 | # Initial convolution block 38 | model0 = [ 39 | nn.ReflectionPad2d(3), 40 | nn.Conv2d(input_nc, 64, 7), 41 | norm_layer(64), 42 | nn.ReLU(inplace=True) 43 | ] 44 | self.model0 = nn.Sequential(*model0) 45 | 46 | # Downsampling 47 | model1 = [] 48 | in_features = 64 49 | out_features = in_features * 2 50 | for _ in range(2): 51 | model1 += [ 52 | nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), 53 | norm_layer(out_features), 54 | nn.ReLU(inplace=True) 55 | ] 56 | in_features = out_features 57 | out_features = in_features * 2 58 | self.model1 = nn.Sequential(*model1) 59 | 60 | model2 = [] 61 | # Residual blocks 62 | for _ in range(n_residual_blocks): 63 | model2 += [ResidualBlock(in_features)] 64 | self.model2 = nn.Sequential(*model2) 65 | 66 | # Upsampling 67 | model3 = [] 68 | out_features = in_features // 2 69 | for _ in range(2): 70 | model3 += [ 71 | nn.ConvTranspose2d(in_features, 72 | out_features, 73 | 3, 74 | stride=2, 75 | padding=1, 76 | output_padding=1), 77 | norm_layer(out_features), 78 | nn.ReLU(inplace=True) 79 | ] 80 | in_features = out_features 81 | out_features = in_features // 2 82 | self.model3 = nn.Sequential(*model3) 83 | 84 | # Output layer 85 | model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)] 86 | if sigmoid: 87 | model4 += [nn.Sigmoid()] 88 | 89 | self.model4 = nn.Sequential(*model4) 90 | 91 | def forward(self, x, cond=None): 92 | out = self.model0(x) 93 | out = self.model1(out) 94 | out = self.model2(out) 95 | out = self.model3(out) 96 | out = self.model4(out) 97 | 98 | return out 99 | 100 | 101 | class ScribbleAnnotator: 102 | def __init__(self, cfg, device=None): 103 | input_nc = cfg.get('INPUT_NC', 3) 104 | output_nc = cfg.get('OUTPUT_NC', 1) 105 | n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3) 106 | sigmoid = cfg.get('SIGMOID', True) 107 | pretrained_model = cfg['PRETRAINED_MODEL'] 108 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 109 | self.model = ContourInference(input_nc, output_nc, n_residual_blocks, 110 | sigmoid) 111 | self.model.load_state_dict(torch.load(pretrained_model, weights_only=True)) 112 | self.model = self.model.eval().requires_grad_(False).to(self.device) 113 | 114 | @torch.no_grad() 115 | @torch.inference_mode() 116 | @torch.autocast('cuda', enabled=False) 117 | def forward(self, image): 118 | is_batch = False if len(image.shape) == 3 else True 119 | image = convert_to_torch(image) 120 | if len(image.shape) == 3: 121 | image = rearrange(image, 'h w c -> 1 c h w') 122 | image = image.float().div(255).to(self.device) 123 | contour_map = self.model(image) 124 | contour_map = (contour_map.squeeze(dim=1) * 255.0).clip( 125 | 0, 255).cpu().numpy().astype(np.uint8) 126 | contour_map = contour_map[..., None].repeat(3, -1) 127 | if not is_batch: 128 | contour_map = contour_map.squeeze() 129 | return contour_map 130 | 131 | 132 | class ScribbleVideoAnnotator(ScribbleAnnotator): 133 | def forward(self, frames): 134 | ret_frames = [] 135 | for frame in frames: 136 | anno_frame = super().forward(np.array(frame)) 137 | ret_frames.append(anno_frame) 138 | return ret_frames -------------------------------------------------------------------------------- /vace/annotators/subject.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from .utils import convert_to_numpy 8 | 9 | 10 | class SubjectAnnotator: 11 | def __init__(self, cfg, device=None): 12 | self.mode = cfg.get('MODE', "salientmasktrack") 13 | self.use_aug = cfg.get('USE_AUG', False) 14 | self.use_crop = cfg.get('USE_CROP', False) 15 | self.roi_only = cfg.get('ROI_ONLY', False) 16 | self.return_mask = cfg.get('RETURN_MASK', True) 17 | 18 | from .inpainting import InpaintingAnnotator 19 | self.inp_anno = InpaintingAnnotator(cfg['INPAINTING'], device=device) 20 | if self.use_aug: 21 | from .maskaug import MaskAugAnnotator 22 | self.maskaug_anno = MaskAugAnnotator(cfg={}) 23 | assert self.mode in ["plain", "salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "masktrack", 24 | "bboxtrack", "label", "caption", "all"] 25 | 26 | def forward(self, image=None, mode=None, return_mask=None, mask_cfg=None, mask=None, bbox=None, label=None, caption=None): 27 | return_mask = return_mask if return_mask is not None else self.return_mask 28 | 29 | if mode == "plain": 30 | return {"image": image, "mask": None} if return_mask else image 31 | 32 | inp_res = self.inp_anno.forward(image, mask=mask, bbox=bbox, label=label, caption=caption, mode=mode, return_mask=True, return_source=True) 33 | src_image = inp_res['src_image'] 34 | mask = inp_res['mask'] 35 | 36 | if self.use_aug and mask_cfg is not None: 37 | mask = self.maskaug_anno.forward(mask, mask_cfg) 38 | 39 | _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY) 40 | if (binary_mask is None or binary_mask.size == 0 or cv2.countNonZero(binary_mask) == 0): 41 | x, y, w, h = 0, 0, binary_mask.shape[1], binary_mask.shape[0] 42 | else: 43 | x, y, w, h = cv2.boundingRect(binary_mask) 44 | 45 | ret_mask = mask.copy() 46 | ret_image = src_image.copy() 47 | 48 | if self.roi_only: 49 | ret_image[mask == 0] = 255 50 | 51 | if self.use_crop: 52 | ret_image = ret_image[y:y + h, x:x + w] 53 | ret_mask = ret_mask[y:y + h, x:x + w] 54 | 55 | if return_mask: 56 | return {"image": ret_image, "mask": ret_mask} 57 | else: 58 | return ret_image 59 | 60 | 61 | -------------------------------------------------------------------------------- /vace/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | from .video_preproccess import video_depth_anno, video_depthv2_anno, video_flow_anno, video_gray_anno, video_pose_anno, video_pose_body_anno, video_scribble_anno 5 | from .video_preproccess import video_framerefext_anno, video_firstframeref_anno, video_lastframeref_anno, video_firstlastframeref_anno, video_firstclipref_anno, video_lastclipref_anno, video_firstlastclipref_anno, video_framerefexp_anno, video_cliprefexp_anno 6 | from .video_preproccess import video_inpainting_mask_anno, video_inpainting_bbox_anno, video_inpainting_masktrack_anno, video_inpainting_bboxtrack_anno, video_inpainting_label_anno, video_inpainting_caption_anno, video_inpainting_anno 7 | from .video_preproccess import video_outpainting_anno, video_outpainting_inner_anno 8 | from .video_preproccess import video_layout_bbox_anno, video_layout_track_anno 9 | from .image_preproccess import image_face_anno, image_salient_anno, image_subject_anno, image_face_mask_anno 10 | from .image_preproccess import image_inpainting_anno, image_outpainting_anno 11 | from .image_preproccess import image_depth_anno, image_gray_anno, image_pose_anno, image_scribble_anno 12 | from .common_preproccess import image_plain_anno, image_mask_plain_anno, image_maskaug_plain_anno, image_maskaug_invert_anno, image_maskaug_anno, video_mask_plain_anno, video_maskaug_plain_anno, video_plain_anno, video_maskaug_invert_anno, video_mask_expand_anno, prompt_plain_anno, video_maskaug_anno, video_maskaug_layout_anno, image_mask_draw_anno, image_maskaug_region_random_anno, image_maskaug_region_crop_anno 13 | from .prompt_preprocess import prompt_extend_ltx_en_anno, prompt_extend_wan_zh_anno, prompt_extend_wan_en_anno, prompt_extend_wan_zh_ds_anno, prompt_extend_wan_en_ds_anno, prompt_extend_ltx_en_ds_anno 14 | from .composition_preprocess import comp_anno, comp_refany_anno, comp_aniany_anno, comp_swapany_anno, comp_expany_anno, comp_moveany_anno 15 | 16 | VACE_IMAGE_PREPROCCESS_CONFIGS = { 17 | 'image_plain': image_plain_anno, 18 | 'image_face': image_face_anno, 19 | 'image_salient': image_salient_anno, 20 | 'image_inpainting': image_inpainting_anno, 21 | 'image_reference': image_subject_anno, 22 | 'image_outpainting': image_outpainting_anno, 23 | 'image_depth': image_depth_anno, 24 | 'image_gray': image_gray_anno, 25 | 'image_pose': image_pose_anno, 26 | 'image_scribble': image_scribble_anno 27 | } 28 | 29 | VACE_IMAGE_MASK_PREPROCCESS_CONFIGS = { 30 | 'image_mask_plain': image_mask_plain_anno, 31 | 'image_mask_seg': image_inpainting_anno, 32 | 'image_mask_draw': image_mask_draw_anno, 33 | 'image_mask_face': image_face_mask_anno 34 | } 35 | 36 | VACE_IMAGE_MASKAUG_PREPROCCESS_CONFIGS = { 37 | 'image_maskaug_plain': image_maskaug_plain_anno, 38 | 'image_maskaug_invert': image_maskaug_invert_anno, 39 | 'image_maskaug': image_maskaug_anno, 40 | 'image_maskaug_region_random': image_maskaug_region_random_anno, 41 | 'image_maskaug_region_crop': image_maskaug_region_crop_anno 42 | } 43 | 44 | 45 | VACE_VIDEO_PREPROCCESS_CONFIGS = { 46 | 'plain': video_plain_anno, 47 | 'depth': video_depth_anno, 48 | 'depthv2': video_depthv2_anno, 49 | 'flow': video_flow_anno, 50 | 'gray': video_gray_anno, 51 | 'pose': video_pose_anno, 52 | 'pose_body': video_pose_body_anno, 53 | 'scribble': video_scribble_anno, 54 | 'framerefext': video_framerefext_anno, 55 | 'frameref': video_framerefexp_anno, 56 | 'clipref': video_cliprefexp_anno, 57 | 'firstframe': video_firstframeref_anno, 58 | 'lastframe': video_lastframeref_anno, 59 | "firstlastframe": video_firstlastframeref_anno, 60 | 'firstclip': video_firstclipref_anno, 61 | 'lastclip': video_lastclipref_anno, 62 | 'firstlastclip': video_firstlastclipref_anno, 63 | 'inpainting': video_inpainting_anno, 64 | 'inpainting_mask': video_inpainting_mask_anno, 65 | 'inpainting_bbox': video_inpainting_bbox_anno, 66 | 'inpainting_masktrack': video_inpainting_masktrack_anno, 67 | 'inpainting_bboxtrack': video_inpainting_bboxtrack_anno, 68 | 'inpainting_label': video_inpainting_label_anno, 69 | 'inpainting_caption': video_inpainting_caption_anno, 70 | 'outpainting': video_outpainting_anno, 71 | 'outpainting_inner': video_outpainting_inner_anno, 72 | 'layout_bbox': video_layout_bbox_anno, 73 | 'layout_track': video_layout_track_anno, 74 | } 75 | 76 | VACE_VIDEO_MASK_PREPROCCESS_CONFIGS = { 77 | # 'mask_plain': video_mask_plain_anno, 78 | 'mask_expand': video_mask_expand_anno, 79 | 'mask_seg': video_inpainting_anno, 80 | } 81 | 82 | VACE_VIDEO_MASKAUG_PREPROCCESS_CONFIGS = { 83 | 'maskaug_plain': video_maskaug_plain_anno, 84 | 'maskaug_invert': video_maskaug_invert_anno, 85 | 'maskaug': video_maskaug_anno, 86 | 'maskaug_layout': video_maskaug_layout_anno 87 | } 88 | 89 | VACE_COMPOSITION_PREPROCCESS_CONFIGS = { 90 | 'composition': comp_anno, 91 | 'reference_anything': comp_refany_anno, 92 | 'animate_anything': comp_aniany_anno, 93 | 'swap_anything': comp_swapany_anno, 94 | 'expand_anything': comp_expany_anno, 95 | 'move_anything': comp_moveany_anno 96 | } 97 | 98 | 99 | VACE_PREPROCCESS_CONFIGS = {**VACE_IMAGE_PREPROCCESS_CONFIGS, **VACE_VIDEO_PREPROCCESS_CONFIGS, **VACE_COMPOSITION_PREPROCCESS_CONFIGS} 100 | 101 | VACE_PROMPT_CONFIGS = { 102 | 'plain': prompt_plain_anno, 103 | 'wan_zh': prompt_extend_wan_zh_anno, 104 | 'wan_en': prompt_extend_wan_en_anno, 105 | 'wan_zh_ds': prompt_extend_wan_zh_ds_anno, 106 | 'wan_en_ds': prompt_extend_wan_en_ds_anno, 107 | 'ltx_en': prompt_extend_ltx_en_anno, 108 | 'ltx_en_ds': prompt_extend_ltx_en_ds_anno 109 | } 110 | 111 | 112 | VACE_CONFIGS = { 113 | "prompt": VACE_PROMPT_CONFIGS, 114 | "image": VACE_IMAGE_PREPROCCESS_CONFIGS, 115 | "image_mask": VACE_IMAGE_MASK_PREPROCCESS_CONFIGS, 116 | "image_maskaug": VACE_IMAGE_MASKAUG_PREPROCCESS_CONFIGS, 117 | "video": VACE_VIDEO_PREPROCCESS_CONFIGS, 118 | "video_mask": VACE_VIDEO_MASK_PREPROCCESS_CONFIGS, 119 | "video_maskaug": VACE_VIDEO_MASKAUG_PREPROCCESS_CONFIGS, 120 | } -------------------------------------------------------------------------------- /vace/configs/common_preproccess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | from easydict import EasyDict 5 | 6 | ######################### Common ######################### 7 | #------------------------ image ------------------------# 8 | image_plain_anno = EasyDict() 9 | image_plain_anno.NAME = "PlainImageAnnotator" 10 | image_plain_anno.INPUTS = {"image": None} 11 | image_plain_anno.OUTPUTS = {"image": None} 12 | 13 | image_mask_plain_anno = EasyDict() 14 | image_mask_plain_anno.NAME = "PlainMaskAnnotator" 15 | image_mask_plain_anno.INPUTS = {"mask": None} 16 | image_mask_plain_anno.OUTPUTS = {"mask": None} 17 | 18 | image_maskaug_plain_anno = EasyDict() 19 | image_maskaug_plain_anno.NAME = "PlainMaskAugAnnotator" 20 | image_maskaug_plain_anno.INPUTS = {"mask": None} 21 | image_maskaug_plain_anno.OUTPUTS = {"mask": None} 22 | 23 | image_maskaug_invert_anno = EasyDict() 24 | image_maskaug_invert_anno.NAME = "PlainMaskAugInvertAnnotator" 25 | image_maskaug_invert_anno.INPUTS = {"mask": None} 26 | image_maskaug_invert_anno.OUTPUTS = {"mask": None} 27 | 28 | image_maskaug_anno = EasyDict() 29 | image_maskaug_anno.NAME = "MaskAugAnnotator" 30 | image_maskaug_anno.INPUTS = {"mask": None, 'mask_cfg': None} 31 | image_maskaug_anno.OUTPUTS = {"mask": None} 32 | 33 | image_mask_draw_anno = EasyDict() 34 | image_mask_draw_anno.NAME = "MaskDrawAnnotator" 35 | image_mask_draw_anno.INPUTS = {"mask": None, 'image': None, 'bbox': None, 'mode': None} 36 | image_mask_draw_anno.OUTPUTS = {"mask": None} 37 | 38 | image_maskaug_region_random_anno = EasyDict() 39 | image_maskaug_region_random_anno.NAME = "RegionCanvasAnnotator" 40 | image_maskaug_region_random_anno.SCALE_RANGE = [ 0.5, 1.0 ] 41 | image_maskaug_region_random_anno.USE_AUG = True 42 | image_maskaug_region_random_anno.INPUTS = {"mask": None, 'image': None, 'bbox': None, 'mode': None} 43 | image_maskaug_region_random_anno.OUTPUTS = {"mask": None} 44 | 45 | image_maskaug_region_crop_anno = EasyDict() 46 | image_maskaug_region_crop_anno.NAME = "RegionCanvasAnnotator" 47 | image_maskaug_region_crop_anno.SCALE_RANGE = [ 0.5, 1.0 ] 48 | image_maskaug_region_crop_anno.USE_AUG = True 49 | image_maskaug_region_crop_anno.USE_RESIZE = False 50 | image_maskaug_region_crop_anno.USE_CANVAS = False 51 | image_maskaug_region_crop_anno.INPUTS = {"mask": None, 'image': None, 'bbox': None, 'mode': None} 52 | image_maskaug_region_crop_anno.OUTPUTS = {"mask": None} 53 | 54 | 55 | #------------------------ video ------------------------# 56 | video_plain_anno = EasyDict() 57 | video_plain_anno.NAME = "PlainVideoAnnotator" 58 | video_plain_anno.INPUTS = {"frames": None} 59 | video_plain_anno.OUTPUTS = {"frames": None} 60 | 61 | video_mask_plain_anno = EasyDict() 62 | video_mask_plain_anno.NAME = "PlainMaskVideoAnnotator" 63 | video_mask_plain_anno.INPUTS = {"masks": None} 64 | video_mask_plain_anno.OUTPUTS = {"masks": None} 65 | 66 | video_maskaug_plain_anno = EasyDict() 67 | video_maskaug_plain_anno.NAME = "PlainMaskAugVideoAnnotator" 68 | video_maskaug_plain_anno.INPUTS = {"masks": None} 69 | video_maskaug_plain_anno.OUTPUTS = {"masks": None} 70 | 71 | video_maskaug_invert_anno = EasyDict() 72 | video_maskaug_invert_anno.NAME = "PlainMaskAugInvertVideoAnnotator" 73 | video_maskaug_invert_anno.INPUTS = {"masks": None} 74 | video_maskaug_invert_anno.OUTPUTS = {"masks": None} 75 | 76 | video_mask_expand_anno = EasyDict() 77 | video_mask_expand_anno.NAME = "ExpandMaskVideoAnnotator" 78 | video_mask_expand_anno.INPUTS = {"masks": None} 79 | video_mask_expand_anno.OUTPUTS = {"masks": None} 80 | 81 | video_maskaug_anno = EasyDict() 82 | video_maskaug_anno.NAME = "MaskAugAnnotator" 83 | video_maskaug_anno.INPUTS = {"mask": None, 'mask_cfg': None} 84 | video_maskaug_anno.OUTPUTS = {"mask": None} 85 | 86 | video_maskaug_layout_anno = EasyDict() 87 | video_maskaug_layout_anno.NAME = "LayoutMaskAnnotator" 88 | video_maskaug_layout_anno.RAM_TAG_COLOR_PATH = "models/VACE-Annotators/layout/ram_tag_color_list.txt" 89 | video_maskaug_layout_anno.USE_AUG = True 90 | video_maskaug_layout_anno.INPUTS = {"mask": None, 'mask_cfg': None} 91 | video_maskaug_layout_anno.OUTPUTS = {"mask": None} 92 | 93 | 94 | #------------------------ prompt ------------------------# 95 | prompt_plain_anno = EasyDict() 96 | prompt_plain_anno.NAME = "PlainPromptAnnotator" 97 | prompt_plain_anno.INPUTS = {"prompt": None} 98 | prompt_plain_anno.OUTPUTS = {"prompt": None} 99 | -------------------------------------------------------------------------------- /vace/configs/composition_preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | from easydict import EasyDict 5 | 6 | #------------------------ CompositionBase ------------------------# 7 | comp_anno = EasyDict() 8 | comp_anno.NAME = "CompositionAnnotator" 9 | comp_anno.INPUTS = {"process_type_1": None, "process_type_2": None, "frames_1": None, "frames_2": None, "masks_1": None, "masks_2": None} 10 | comp_anno.OUTPUTS = {"frames": None, "masks": None} 11 | 12 | #------------------------ ReferenceAnything ------------------------# 13 | comp_refany_anno = EasyDict() 14 | comp_refany_anno.NAME = "ReferenceAnythingAnnotator" 15 | comp_refany_anno.SUBJECT = {"MODE": "all", "USE_AUG": True, "USE_CROP": True, "ROI_ONLY": True, 16 | "INPAINTING": {"MODE": "all", 17 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 18 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 19 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 20 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 21 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 22 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}}} 23 | comp_refany_anno.INPUTS = {"images": None, "mode": None, "mask_cfg": None} 24 | comp_refany_anno.OUTPUTS = {"images": None} 25 | 26 | 27 | #------------------------ AnimateAnything ------------------------# 28 | comp_aniany_anno = EasyDict() 29 | comp_aniany_anno.NAME = "AnimateAnythingAnnotator" 30 | comp_aniany_anno.POSE = {"DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx", 31 | "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx"} 32 | comp_aniany_anno.REFERENCE = {"MODE": "all", "USE_AUG": True, "USE_CROP": True, "ROI_ONLY": True, 33 | "INPAINTING": {"MODE": "all", 34 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 35 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 36 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 37 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 38 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 39 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}}} 40 | comp_aniany_anno.INPUTS = {"frames": None, "images": None, "mode": None, "mask_cfg": None} 41 | comp_aniany_anno.OUTPUTS = {"frames": None, "images": None} 42 | 43 | 44 | #------------------------ SwapAnything ------------------------# 45 | comp_swapany_anno = EasyDict() 46 | comp_swapany_anno.NAME = "SwapAnythingAnnotator" 47 | comp_swapany_anno.REFERENCE = {"MODE": "all", "USE_AUG": True, "USE_CROP": True, "ROI_ONLY": True, 48 | "INPAINTING": {"MODE": "all", 49 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 50 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 51 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 52 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 53 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 54 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}}} 55 | comp_swapany_anno.INPAINTING = {"MODE": "all", 56 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 57 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 58 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 59 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 60 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 61 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}} 62 | comp_swapany_anno.INPUTS = {"frames": None, "video": None, "images": None, "mask": None, "bbox": None, "label": None, "caption": None, "mode": None, "mask_cfg": None} 63 | comp_swapany_anno.OUTPUTS = {"frames": None, "images": None, "masks": None} 64 | 65 | 66 | 67 | #------------------------ ExpandAnything ------------------------# 68 | comp_expany_anno = EasyDict() 69 | comp_expany_anno.NAME = "ExpandAnythingAnnotator" 70 | comp_expany_anno.REFERENCE = {"MODE": "all", "USE_AUG": True, "USE_CROP": True, "ROI_ONLY": True, 71 | "INPAINTING": {"MODE": "all", 72 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 73 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 74 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 75 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 76 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 77 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}}} 78 | comp_expany_anno.OUTPAINTING = {"RETURN_MASK": True, "KEEP_PADDING_RATIO": 1, "MASK_COLOR": "gray"} 79 | comp_expany_anno.FRAMEREF = {} 80 | comp_expany_anno.INPUTS = {"images": None, "mode": None, "mask_cfg": None, "direction": None, "expand_ratio": None, "expand_num": None} 81 | comp_expany_anno.OUTPUTS = {"frames": None, "images": None, "masks": None} 82 | 83 | 84 | #------------------------ MoveAnything ------------------------# 85 | comp_moveany_anno = EasyDict() 86 | comp_moveany_anno.NAME = "MoveAnythingAnnotator" 87 | comp_moveany_anno.LAYOUTBBOX = {"RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt"} 88 | comp_moveany_anno.INPUTS = {"image": None, "bbox": None, "label": None, "expand_num": None} 89 | comp_moveany_anno.OUTPUTS = {"frames": None, "masks": None} 90 | -------------------------------------------------------------------------------- /vace/configs/image_preproccess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | from easydict import EasyDict 5 | 6 | ######################### Control ######################### 7 | #------------------------ Depth ------------------------# 8 | image_depth_anno = EasyDict() 9 | image_depth_anno.NAME = "DepthAnnotator" 10 | image_depth_anno.PRETRAINED_MODEL = "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt" 11 | image_depth_anno.INPUTS = {"image": None} 12 | image_depth_anno.OUTPUTS = {"image": None} 13 | 14 | #------------------------ Depth ------------------------# 15 | image_depthv2_anno = EasyDict() 16 | image_depthv2_anno.NAME = "DepthV2Annotator" 17 | image_depthv2_anno.PRETRAINED_MODEL = "models/VACE-Annotators/depth/depth_anything_v2_vitl.pth" 18 | image_depthv2_anno.INPUTS = {"image": None} 19 | image_depthv2_anno.OUTPUTS = {"image": None} 20 | 21 | #------------------------ Gray ------------------------# 22 | image_gray_anno = EasyDict() 23 | image_gray_anno.NAME = "GrayAnnotator" 24 | image_gray_anno.INPUTS = {"image": None} 25 | image_gray_anno.OUTPUTS = {"image": None} 26 | 27 | #------------------------ Pose ------------------------# 28 | image_pose_anno = EasyDict() 29 | image_pose_anno.NAME = "PoseBodyFaceAnnotator" 30 | image_pose_anno.DETECTION_MODEL = "models/VACE-Annotators/pose/yolox_l.onnx" 31 | image_pose_anno.POSE_MODEL = "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx" 32 | image_pose_anno.INPUTS = {"image": None} 33 | image_pose_anno.OUTPUTS = {"image": None} 34 | 35 | #------------------------ Scribble ------------------------# 36 | image_scribble_anno = EasyDict() 37 | image_scribble_anno.NAME = "ScribbleAnnotator" 38 | image_scribble_anno.PRETRAINED_MODEL = "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth" 39 | image_scribble_anno.INPUTS = {"image": None} 40 | image_scribble_anno.OUTPUTS = {"image": None} 41 | 42 | #------------------------ Outpainting ------------------------# 43 | image_outpainting_anno = EasyDict() 44 | image_outpainting_anno.NAME = "OutpaintingAnnotator" 45 | image_outpainting_anno.RETURN_MASK = True 46 | image_outpainting_anno.KEEP_PADDING_RATIO = 1 47 | image_outpainting_anno.MASK_COLOR = 'gray' 48 | image_outpainting_anno.INPUTS = {"image": None, "direction": ['left', 'right'], 'expand_ratio': 0.25} 49 | image_outpainting_anno.OUTPUTS = {"image": None, "mask": None} 50 | 51 | 52 | 53 | 54 | ######################### R2V - Subject ######################### 55 | #------------------------ Face ------------------------# 56 | image_face_anno = EasyDict() 57 | image_face_anno.NAME = "FaceAnnotator" 58 | image_face_anno.MODEL_NAME = "antelopev2" 59 | image_face_anno.PRETRAINED_MODEL = "models/VACE-Annotators/face/" 60 | image_face_anno.RETURN_RAW = False 61 | image_face_anno.MULTI_FACE = False 62 | image_face_anno.INPUTS = {"image": None} 63 | image_face_anno.OUTPUTS = {"image": None} 64 | 65 | #------------------------ FaceMask ------------------------# 66 | image_face_mask_anno = EasyDict() 67 | image_face_mask_anno.NAME = "FaceAnnotator" 68 | image_face_mask_anno.MODEL_NAME = "antelopev2" 69 | image_face_mask_anno.PRETRAINED_MODEL = "models/VACE-Annotators/face/" 70 | image_face_mask_anno.MULTI_FACE = False 71 | image_face_mask_anno.RETURN_RAW = False 72 | image_face_mask_anno.RETURN_DICT = True 73 | image_face_mask_anno.RETURN_MASK = True 74 | image_face_mask_anno.INPUTS = {"image": None} 75 | image_face_mask_anno.OUTPUTS = {"image": None, "mask": None} 76 | 77 | #------------------------ Salient ------------------------# 78 | image_salient_anno = EasyDict() 79 | image_salient_anno.NAME = "SalientAnnotator" 80 | image_salient_anno.NORM_SIZE = [320, 320] 81 | image_salient_anno.RETURN_IMAGE = True 82 | image_salient_anno.USE_CROP = True 83 | image_salient_anno.PRETRAINED_MODEL = "models/VACE-Annotators/salient/u2net.pt" 84 | image_salient_anno.INPUTS = {"image": None} 85 | image_salient_anno.OUTPUTS = {"image": None} 86 | 87 | #------------------------ Inpainting ------------------------# 88 | image_inpainting_anno = EasyDict() 89 | image_inpainting_anno.NAME = "InpaintingAnnotator" 90 | image_inpainting_anno.MODE = "all" 91 | image_inpainting_anno.USE_AUG = True 92 | image_inpainting_anno.SALIENT = {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"} 93 | image_inpainting_anno.GDINO = {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 94 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 95 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"} 96 | image_inpainting_anno.SAM2 = {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 97 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'} 98 | # image_inpainting_anno.INPUTS = {"image": None, "mode": "salient"} 99 | # image_inpainting_anno.INPUTS = {"image": None, "mask": None, "mode": "mask"} 100 | # image_inpainting_anno.INPUTS = {"image": None, "bbox": None, "mode": "bbox"} 101 | image_inpainting_anno.INPUTS = {"image": None, "mode": "salientmasktrack", "mask_cfg": None} 102 | # image_inpainting_anno.INPUTS = {"image": None, "mode": "salientbboxtrack"} 103 | # image_inpainting_anno.INPUTS = {"image": None, "mask": None, "mode": "masktrack"} 104 | # image_inpainting_anno.INPUTS = {"image": None, "bbox": None, "mode": "bboxtrack"} 105 | # image_inpainting_anno.INPUTS = {"image": None, "label": None, "mode": "label"} 106 | # image_inpainting_anno.INPUTS = {"image": None, "caption": None, "mode": "caption"} 107 | image_inpainting_anno.OUTPUTS = {"image": None, "mask": None} 108 | 109 | 110 | #------------------------ Subject ------------------------# 111 | image_subject_anno = EasyDict() 112 | image_subject_anno.NAME = "SubjectAnnotator" 113 | image_subject_anno.MODE = "all" 114 | image_subject_anno.USE_AUG = True 115 | image_subject_anno.USE_CROP = True 116 | image_subject_anno.ROI_ONLY = True 117 | image_subject_anno.INPAINTING = {"MODE": "all", 118 | "SALIENT": {"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"}, 119 | "GDINO": {"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased", 120 | "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py", 121 | "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth"}, 122 | "SAM2": {"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', 123 | "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'}} 124 | # image_subject_anno.INPUTS = {"image": None, "mode": "salient"} 125 | # image_subject_anno.INPUTS = {"image": None, "mask": None, "mode": "mask"} 126 | # image_subject_anno.INPUTS = {"image": None, "bbox": None, "mode": "bbox"} 127 | # image_subject_anno.INPUTS = {"image": None, "mode": "salientmasktrack"} 128 | # image_subject_anno.INPUTS = {"image": None, "mode": "salientbboxtrack"} 129 | # image_subject_anno.INPUTS = {"image": None, "mask": None, "mode": "masktrack"} 130 | # image_subject_anno.INPUTS = {"image": None, "bbox": None, "mode": "bboxtrack"} 131 | # image_subject_anno.INPUTS = {"image": None, "label": None, "mode": "label"} 132 | # image_subject_anno.INPUTS = {"image": None, "caption": None, "mode": "caption"} 133 | image_subject_anno.INPUTS = {"image": None, "mode": None, "mask": None, "bbox": None, "label": None, "caption": None, "mask_cfg": None} 134 | image_subject_anno.OUTPUTS = {"image": None, "mask": None} 135 | -------------------------------------------------------------------------------- /vace/gradios/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. -------------------------------------------------------------------------------- /vace/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from . import utils 4 | 5 | try: 6 | from . import ltx 7 | except ImportError as e: 8 | print("Warning: failed to importing 'ltx'. Please install its dependencies with:") 9 | print("pip install ltx-video@git+https://github.com/Lightricks/LTX-Video@ltx-video-0.9.1 sentencepiece --no-deps") 10 | 11 | try: 12 | from . import wan 13 | except ImportError as e: 14 | print("Warning: failed to importing 'wan'. Please install its dependencies with:") 15 | print("pip install wan@git+https://github.com/Wan-Video/Wan2.1") 16 | -------------------------------------------------------------------------------- /vace/models/ltx/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from . import models 4 | from . import pipelines -------------------------------------------------------------------------------- /vace/models/ltx/ltx_vace.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from pathlib import Path 4 | 5 | import torch 6 | from transformers import T5EncoderModel, T5Tokenizer 7 | 8 | from ltx_video.models.autoencoders.causal_video_autoencoder import ( 9 | CausalVideoAutoencoder, 10 | ) 11 | from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier 12 | from ltx_video.schedulers.rf import RectifiedFlowScheduler 13 | from ltx_video.utils.conditioning_method import ConditioningMethod 14 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy 15 | 16 | from .models.transformers.transformer3d import VaceTransformer3DModel 17 | from .pipelines.pipeline_ltx_video import VaceLTXVideoPipeline 18 | from ..utils.preprocessor import VaceImageProcessor, VaceVideoProcessor 19 | 20 | 21 | 22 | class LTXVace(): 23 | def __init__(self, ckpt_path, text_encoder_path, precision='bfloat16', stg_skip_layers="19", stg_mode="stg_a", offload_to_cpu=False): 24 | self.precision = precision 25 | self.offload_to_cpu = offload_to_cpu 26 | ckpt_path = Path(ckpt_path) 27 | vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) 28 | transformer = VaceTransformer3DModel.from_pretrained(ckpt_path) 29 | scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path) 30 | 31 | text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder") 32 | patchifier = SymmetricPatchifier(patch_size=1) 33 | tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer") 34 | 35 | if torch.cuda.is_available(): 36 | transformer = transformer.cuda() 37 | vae = vae.cuda() 38 | text_encoder = text_encoder.cuda() 39 | 40 | vae = vae.to(torch.bfloat16) 41 | if precision == "bfloat16" and transformer.dtype != torch.bfloat16: 42 | transformer = transformer.to(torch.bfloat16) 43 | text_encoder = text_encoder.to(torch.bfloat16) 44 | 45 | # Set spatiotemporal guidance 46 | self.skip_block_list = [int(x.strip()) for x in stg_skip_layers.split(",")] 47 | self.skip_layer_strategy = ( 48 | SkipLayerStrategy.Attention 49 | if stg_mode.lower() == "stg_a" 50 | else SkipLayerStrategy.Residual 51 | ) 52 | 53 | # Use submodels for the pipeline 54 | submodel_dict = { 55 | "transformer": transformer, 56 | "patchifier": patchifier, 57 | "text_encoder": text_encoder, 58 | "tokenizer": tokenizer, 59 | "scheduler": scheduler, 60 | "vae": vae, 61 | } 62 | 63 | self.pipeline = VaceLTXVideoPipeline(**submodel_dict) 64 | if torch.cuda.is_available(): 65 | self.pipeline = self.pipeline.to("cuda") 66 | 67 | self.img_proc = VaceImageProcessor(downsample=[8,32,32], seq_len=384) 68 | 69 | self.vid_proc = VaceVideoProcessor(downsample=[8,32,32], 70 | min_area=512*768, 71 | max_area=512*768, 72 | min_fps=25, 73 | max_fps=25, 74 | seq_len=4992, 75 | zero_start=True, 76 | keep_last=True) 77 | 78 | 79 | def generate(self, src_video=None, src_mask=None, src_ref_images=[], prompt="", negative_prompt="", seed=42, 80 | num_inference_steps=40, num_images_per_prompt=1, context_scale=1.0, guidance_scale=3, stg_scale=1, stg_rescale=0.7, 81 | frame_rate=25, image_cond_noise_scale=0.15, decode_timestep=0.05, decode_noise_scale=0.025, 82 | output_height=512, output_width=768, num_frames=97): 83 | # src_video: [c, t, h, w] / norm [-1, 1] 84 | # src_mask : [c, t, h, w] / norm [0, 1] 85 | # src_ref_images : [[c, h, w], [c, h, w], ...] / norm [-1, 1] 86 | # image_size: (H, W) 87 | if (src_video is not None and src_video != "") and (src_mask is not None and src_mask != ""): 88 | src_video, src_mask, frame_ids, image_size, frame_rate = self.vid_proc.load_video_batch(src_video, src_mask) 89 | if torch.all(src_mask > 0): 90 | src_mask = torch.ones_like(src_video[:1, :, :, :]) 91 | else: 92 | # bool_mask = src_mask > 0 93 | # bool_mask = bool_mask.expand_as(src_video) 94 | # src_video[bool_mask] = 0 95 | src_mask = src_mask[:1, :, :, :] 96 | src_mask = torch.clamp((src_mask + 1) / 2, min=0, max=1) 97 | elif (src_video is not None and src_video != "") and (src_mask is None or src_mask == ""): 98 | src_video, frame_ids, image_size, frame_rate = self.vid_proc.load_video_batch(src_video) 99 | src_mask = torch.ones_like(src_video[:1, :, :, :]) 100 | else: 101 | output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames) 102 | frame_ids = list(range(num_frames)) 103 | image_size = (output_height, output_width) 104 | src_video = torch.zeros((3, num_frames, output_height, output_width)) 105 | src_mask = torch.ones((1, num_frames, output_height, output_width)) 106 | 107 | src_ref_images_prelist = src_ref_images 108 | src_ref_images = [] 109 | for ref_image in src_ref_images_prelist: 110 | if ref_image != "" and ref_image is not None: 111 | src_ref_images.append(self.img_proc.load_image(ref_image)[0]) 112 | 113 | 114 | # Prepare input for the pipeline 115 | num_frames = len(frame_ids) 116 | sample = { 117 | "src_video": [src_video], 118 | "src_mask": [src_mask], 119 | "src_ref_images": [src_ref_images], 120 | "prompt": [prompt], 121 | "prompt_attention_mask": None, 122 | "negative_prompt": [negative_prompt], 123 | "negative_prompt_attention_mask": None, 124 | } 125 | 126 | generator = torch.Generator( 127 | device="cuda" if torch.cuda.is_available() else "cpu" 128 | ).manual_seed(seed) 129 | 130 | output = self.pipeline( 131 | num_inference_steps=num_inference_steps, 132 | num_images_per_prompt=num_images_per_prompt, 133 | context_scale=context_scale, 134 | guidance_scale=guidance_scale, 135 | skip_layer_strategy=self.skip_layer_strategy, 136 | skip_block_list=self.skip_block_list, 137 | stg_scale=stg_scale, 138 | do_rescaling=stg_rescale != 1, 139 | rescaling_scale=stg_rescale, 140 | generator=generator, 141 | output_type="pt", 142 | callback_on_step_end=None, 143 | height=image_size[0], 144 | width=image_size[1], 145 | num_frames=num_frames, 146 | frame_rate=frame_rate, 147 | **sample, 148 | is_video=True, 149 | vae_per_channel_normalize=True, 150 | conditioning_method=ConditioningMethod.UNCONDITIONAL, 151 | image_cond_noise_scale=image_cond_noise_scale, 152 | decode_timestep=decode_timestep, 153 | decode_noise_scale=decode_noise_scale, 154 | mixed_precision=(self.precision in "mixed_precision"), 155 | offload_to_cpu=self.offload_to_cpu, 156 | ) 157 | gen_video = output.images[0] 158 | gen_video = gen_video.to(torch.float32) if gen_video.dtype == torch.bfloat16 else gen_video 159 | info = output.info 160 | 161 | ret_data = { 162 | "out_video": gen_video, 163 | "src_video": src_video, 164 | "src_mask": src_mask, 165 | "src_ref_images": src_ref_images, 166 | "info": info 167 | } 168 | return ret_data -------------------------------------------------------------------------------- /vace/models/ltx/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from . import transformers -------------------------------------------------------------------------------- /vace/models/ltx/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .attention import BasicTransformerMainBlock, BasicTransformerBypassBlock 4 | from .transformer3d import VaceTransformer3DModel -------------------------------------------------------------------------------- /vace/models/ltx/models/transformers/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.utils.torch_utils import maybe_allow_in_graph 7 | 8 | from ltx_video.models.transformers.attention import BasicTransformerBlock 9 | 10 | 11 | @maybe_allow_in_graph 12 | class BasicTransformerMainBlock(BasicTransformerBlock): 13 | def __init__(self, *args, **kwargs): 14 | self.block_id = kwargs.pop('block_id') 15 | super().__init__(*args, **kwargs) 16 | 17 | def forward(self, *args, **kwargs) -> torch.FloatTensor: 18 | context_hints = kwargs.pop('context_hints') 19 | context_scale = kwargs.pop('context_scale') 20 | hidden_states = super().forward(*args, **kwargs) 21 | if self.block_id < len(context_hints) and context_hints[self.block_id] is not None: 22 | hidden_states = hidden_states + context_hints[self.block_id] * context_scale 23 | return hidden_states 24 | 25 | 26 | @maybe_allow_in_graph 27 | class BasicTransformerBypassBlock(BasicTransformerBlock): 28 | def __init__(self, *args, **kwargs): 29 | self.dim = args[0] 30 | self.block_id = kwargs.pop('block_id') 31 | super().__init__(*args, **kwargs) 32 | if self.block_id == 0: 33 | self.before_proj = nn.Linear(self.dim, self.dim) 34 | nn.init.zeros_(self.before_proj.weight) 35 | nn.init.zeros_(self.before_proj.bias) 36 | self.after_proj = nn.Linear(self.dim, self.dim) 37 | nn.init.zeros_(self.after_proj.weight) 38 | nn.init.zeros_(self.after_proj.bias) 39 | 40 | def forward(self, *args, **kwargs): 41 | hidden_states = kwargs.pop('hidden_states') 42 | context_hidden_states = kwargs.pop('context_hidden_states') 43 | if self.block_id == 0: 44 | context_hidden_states = self.before_proj(context_hidden_states) + hidden_states 45 | 46 | kwargs['hidden_states'] = context_hidden_states 47 | bypass_context_hidden_states = super().forward(*args, **kwargs) 48 | main_context_hidden_states = self.after_proj(bypass_context_hidden_states) 49 | return (main_context_hidden_states, bypass_context_hidden_states) 50 | -------------------------------------------------------------------------------- /vace/models/ltx/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .pipeline_ltx_video import VaceLTXVideoPipeline -------------------------------------------------------------------------------- /vace/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .preprocessor import VaceVideoProcessor -------------------------------------------------------------------------------- /vace/models/wan/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from . import modules 4 | from .wan_vace import WanVace 5 | -------------------------------------------------------------------------------- /vace/models/wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_t2v_1_3B import t2v_1_3B 8 | from .wan_t2v_14B import t2v_14B 9 | 10 | WAN_CONFIGS = { 11 | 'vace-1.3B': t2v_1_3B, 12 | 'vace-14B': t2v_14B, 13 | } 14 | 15 | SIZE_CONFIGS = { 16 | '720*1280': (720, 1280), 17 | '1280*720': (1280, 720), 18 | '480*832': (480, 832), 19 | '832*480': (832, 480), 20 | '1024*1024': (1024, 1024), 21 | '720p': (1280, 720), 22 | '480p': (480, 832) 23 | } 24 | 25 | MAX_AREA_CONFIGS = { 26 | '720*1280': 720 * 1280, 27 | '1280*720': 1280 * 720, 28 | '480*832': 480 * 832, 29 | '832*480': 832 * 480, 30 | '720p': 1280 * 720, 31 | '480p': 480 * 832 32 | } 33 | 34 | SUPPORTED_SIZES = { 35 | 'vace-1.3B': ('480*832', '832*480', '480p'), 36 | 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480', '480p', '720p') 37 | } 38 | -------------------------------------------------------------------------------- /vace/models/wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | from easydict import EasyDict 5 | 6 | #------------------------ Wan shared config ------------------------# 7 | wan_shared_cfg = EasyDict() 8 | 9 | # t5 10 | wan_shared_cfg.t5_model = 'umt5_xxl' 11 | wan_shared_cfg.t5_dtype = torch.bfloat16 12 | wan_shared_cfg.text_len = 512 13 | 14 | # transformer 15 | wan_shared_cfg.param_dtype = torch.bfloat16 16 | 17 | # inference 18 | wan_shared_cfg.num_train_timesteps = 1000 19 | wan_shared_cfg.sample_fps = 16 20 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 21 | -------------------------------------------------------------------------------- /vace/models/wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /vace/models/wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan T2V 1.3B ------------------------# 8 | 9 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 10 | t2v_1_3B.update(wan_shared_cfg) 11 | 12 | # t5 13 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 14 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 15 | 16 | # vae 17 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 18 | t2v_1_3B.vae_stride = (4, 8, 8) 19 | 20 | # transformer 21 | t2v_1_3B.patch_size = (1, 2, 2) 22 | t2v_1_3B.dim = 1536 23 | t2v_1_3B.ffn_dim = 8960 24 | t2v_1_3B.freq_dim = 256 25 | t2v_1_3B.num_heads = 12 26 | t2v_1_3B.num_layers = 30 27 | t2v_1_3B.window_size = (-1, -1) 28 | t2v_1_3B.qk_norm = True 29 | t2v_1_3B.cross_attn_norm = True 30 | t2v_1_3B.eps = 1e-6 31 | -------------------------------------------------------------------------------- /vace/models/wan/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .xdit_context_parallel import pad_freqs, rope_apply, usp_dit_forward_vace, usp_dit_forward, usp_attn_forward -------------------------------------------------------------------------------- /vace/models/wan/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | import torch.cuda.amp as amp 5 | from xfuser.core.distributed import (get_sequence_parallel_rank, 6 | get_sequence_parallel_world_size, 7 | get_sp_group) 8 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 9 | 10 | from ..modules.model import sinusoidal_embedding_1d 11 | 12 | 13 | def pad_freqs(original_tensor, target_len): 14 | seq_len, s1, s2 = original_tensor.shape 15 | pad_size = target_len - seq_len 16 | padding_tensor = torch.ones( 17 | pad_size, 18 | s1, 19 | s2, 20 | dtype=original_tensor.dtype, 21 | device=original_tensor.device) 22 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 23 | return padded_tensor 24 | 25 | 26 | @amp.autocast(enabled=False) 27 | def rope_apply(x, grid_sizes, freqs): 28 | """ 29 | x: [B, L, N, C]. 30 | grid_sizes: [B, 3]. 31 | freqs: [M, C // 2]. 32 | """ 33 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 34 | # split freqs 35 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 36 | 37 | # loop over samples 38 | output = [] 39 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 40 | seq_len = f * h * w 41 | 42 | # precompute multipliers 43 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( 44 | s, n, -1, 2)) 45 | freqs_i = torch.cat([ 46 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 47 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 48 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 49 | ], 50 | dim=-1).reshape(seq_len, 1, -1) 51 | 52 | # apply rotary embedding 53 | sp_size = get_sequence_parallel_world_size() 54 | sp_rank = get_sequence_parallel_rank() 55 | freqs_i = pad_freqs(freqs_i, s * sp_size) 56 | s_per_rank = s 57 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * 58 | s_per_rank), :, :] 59 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) 60 | x_i = torch.cat([x_i, x[i, s:]]) 61 | 62 | # append to collection 63 | output.append(x_i) 64 | return torch.stack(output).float() 65 | 66 | 67 | def usp_dit_forward_vace( 68 | self, 69 | x, 70 | vace_context, 71 | seq_len, 72 | kwargs 73 | ): 74 | # embeddings 75 | c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] 76 | c = [u.flatten(2).transpose(1, 2) for u in c] 77 | c = torch.cat([ 78 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 79 | dim=1) for u in c 80 | ]) 81 | 82 | # arguments 83 | new_kwargs = dict(x=x) 84 | new_kwargs.update(kwargs) 85 | 86 | # Context Parallel 87 | c = torch.chunk( 88 | c, get_sequence_parallel_world_size(), 89 | dim=1)[get_sequence_parallel_rank()] 90 | 91 | for block in self.vace_blocks: 92 | c = block(c, **new_kwargs) 93 | hints = torch.unbind(c)[:-1] 94 | return hints 95 | 96 | 97 | def usp_dit_forward( 98 | self, 99 | x, 100 | t, 101 | vace_context, 102 | context, 103 | seq_len, 104 | vace_context_scale=1.0, 105 | clip_fea=None, 106 | y=None, 107 | ): 108 | """ 109 | x: A list of videos each with shape [C, T, H, W]. 110 | t: [B]. 111 | context: A list of text embeddings each with shape [L, C]. 112 | """ 113 | # params 114 | device = self.patch_embedding.weight.device 115 | if self.freqs.device != device: 116 | self.freqs = self.freqs.to(device) 117 | 118 | # if y is not None: 119 | # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 120 | 121 | # embeddings 122 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 123 | grid_sizes = torch.stack( 124 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 125 | x = [u.flatten(2).transpose(1, 2) for u in x] 126 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 127 | assert seq_lens.max() <= seq_len 128 | x = torch.cat([ 129 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 130 | for u in x 131 | ]) 132 | 133 | # time embeddings 134 | with amp.autocast(dtype=torch.float32): 135 | e = self.time_embedding( 136 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 137 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 138 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 139 | 140 | # context 141 | context_lens = None 142 | context = self.text_embedding( 143 | torch.stack([ 144 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 145 | for u in context 146 | ])) 147 | 148 | # if clip_fea is not None: 149 | # context_clip = self.img_emb(clip_fea) # bs x 257 x dim 150 | # context = torch.concat([context_clip, context], dim=1) 151 | 152 | # arguments 153 | kwargs = dict( 154 | e=e0, 155 | seq_lens=seq_lens, 156 | grid_sizes=grid_sizes, 157 | freqs=self.freqs, 158 | context=context, 159 | context_lens=context_lens) 160 | 161 | # Context Parallel 162 | x = torch.chunk( 163 | x, get_sequence_parallel_world_size(), 164 | dim=1)[get_sequence_parallel_rank()] 165 | 166 | hints = self.forward_vace(x, vace_context, seq_len, kwargs) 167 | kwargs['hints'] = hints 168 | kwargs['context_scale'] = vace_context_scale 169 | 170 | for block in self.blocks: 171 | x = block(x, **kwargs) 172 | 173 | # head 174 | x = self.head(x, e) 175 | 176 | # Context Parallel 177 | x = get_sp_group().all_gather(x, dim=1) 178 | 179 | # unpatchify 180 | x = self.unpatchify(x, grid_sizes) 181 | return [u.float() for u in x] 182 | 183 | 184 | def usp_attn_forward(self, 185 | x, 186 | seq_lens, 187 | grid_sizes, 188 | freqs, 189 | dtype=torch.bfloat16): 190 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 191 | half_dtypes = (torch.float16, torch.bfloat16) 192 | 193 | def half(x): 194 | return x if x.dtype in half_dtypes else x.to(dtype) 195 | 196 | # query, key, value function 197 | def qkv_fn(x): 198 | q = self.norm_q(self.q(x)).view(b, s, n, d) 199 | k = self.norm_k(self.k(x)).view(b, s, n, d) 200 | v = self.v(x).view(b, s, n, d) 201 | return q, k, v 202 | 203 | q, k, v = qkv_fn(x) 204 | q = rope_apply(q, grid_sizes, freqs) 205 | k = rope_apply(k, grid_sizes, freqs) 206 | 207 | # TODO: We should use unpaded q,k,v for attention. 208 | # k_lens = seq_lens // get_sequence_parallel_world_size() 209 | # if k_lens is not None: 210 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) 211 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) 212 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) 213 | 214 | x = xFuserLongContextAttention()( 215 | None, 216 | query=half(q), 217 | key=half(k), 218 | value=half(v), 219 | window_size=self.window_size) 220 | 221 | # TODO: padding after attention. 222 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) 223 | 224 | # output 225 | x = x.flatten(2) 226 | x = self.o(x) 227 | return x 228 | -------------------------------------------------------------------------------- /vace/models/wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | from .model import VaceWanAttentionBlock, BaseWanAttentionBlock, VaceWanModel -------------------------------------------------------------------------------- /vace/models/wan/modules/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | import torch 4 | import torch.cuda.amp as amp 5 | import torch.nn as nn 6 | from diffusers.configuration_utils import register_to_config 7 | from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d 8 | 9 | 10 | class VaceWanAttentionBlock(WanAttentionBlock): 11 | def __init__( 12 | self, 13 | cross_attn_type, 14 | dim, 15 | ffn_dim, 16 | num_heads, 17 | window_size=(-1, -1), 18 | qk_norm=True, 19 | cross_attn_norm=False, 20 | eps=1e-6, 21 | block_id=0 22 | ): 23 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) 24 | self.block_id = block_id 25 | if block_id == 0: 26 | self.before_proj = nn.Linear(self.dim, self.dim) 27 | nn.init.zeros_(self.before_proj.weight) 28 | nn.init.zeros_(self.before_proj.bias) 29 | self.after_proj = nn.Linear(self.dim, self.dim) 30 | nn.init.zeros_(self.after_proj.weight) 31 | nn.init.zeros_(self.after_proj.bias) 32 | 33 | def forward(self, c, x, **kwargs): 34 | if self.block_id == 0: 35 | c = self.before_proj(c) + x 36 | all_c = [] 37 | else: 38 | all_c = list(torch.unbind(c)) 39 | c = all_c.pop(-1) 40 | c = super().forward(c, **kwargs) 41 | c_skip = self.after_proj(c) 42 | all_c += [c_skip, c] 43 | c = torch.stack(all_c) 44 | return c 45 | 46 | 47 | class BaseWanAttentionBlock(WanAttentionBlock): 48 | def __init__( 49 | self, 50 | cross_attn_type, 51 | dim, 52 | ffn_dim, 53 | num_heads, 54 | window_size=(-1, -1), 55 | qk_norm=True, 56 | cross_attn_norm=False, 57 | eps=1e-6, 58 | block_id=None 59 | ): 60 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) 61 | self.block_id = block_id 62 | 63 | def forward(self, x, hints, context_scale=1.0, **kwargs): 64 | x = super().forward(x, **kwargs) 65 | if self.block_id is not None: 66 | x = x + hints[self.block_id] * context_scale 67 | return x 68 | 69 | 70 | class VaceWanModel(WanModel): 71 | @register_to_config 72 | def __init__(self, 73 | vace_layers=None, 74 | vace_in_dim=None, 75 | model_type='t2v', 76 | patch_size=(1, 2, 2), 77 | text_len=512, 78 | in_dim=16, 79 | dim=2048, 80 | ffn_dim=8192, 81 | freq_dim=256, 82 | text_dim=4096, 83 | out_dim=16, 84 | num_heads=16, 85 | num_layers=32, 86 | window_size=(-1, -1), 87 | qk_norm=True, 88 | cross_attn_norm=True, 89 | eps=1e-6): 90 | model_type = "t2v" # TODO: Hard code for both preview and official versions. 91 | super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, 92 | num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) 93 | 94 | self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers 95 | self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim 96 | 97 | assert 0 in self.vace_layers 98 | self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} 99 | 100 | # blocks 101 | self.blocks = nn.ModuleList([ 102 | BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, 103 | self.cross_attn_norm, self.eps, 104 | block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) 105 | for i in range(self.num_layers) 106 | ]) 107 | 108 | # vace blocks 109 | self.vace_blocks = nn.ModuleList([ 110 | VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, 111 | self.cross_attn_norm, self.eps, block_id=i) 112 | for i in self.vace_layers 113 | ]) 114 | 115 | # vace patch embeddings 116 | self.vace_patch_embedding = nn.Conv3d( 117 | self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size 118 | ) 119 | 120 | def forward_vace( 121 | self, 122 | x, 123 | vace_context, 124 | seq_len, 125 | kwargs 126 | ): 127 | # embeddings 128 | c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] 129 | c = [u.flatten(2).transpose(1, 2) for u in c] 130 | c = torch.cat([ 131 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 132 | dim=1) for u in c 133 | ]) 134 | 135 | # arguments 136 | new_kwargs = dict(x=x) 137 | new_kwargs.update(kwargs) 138 | 139 | for block in self.vace_blocks: 140 | c = block(c, **new_kwargs) 141 | hints = torch.unbind(c)[:-1] 142 | return hints 143 | 144 | def forward( 145 | self, 146 | x, 147 | t, 148 | vace_context, 149 | context, 150 | seq_len, 151 | vace_context_scale=1.0, 152 | clip_fea=None, 153 | y=None, 154 | ): 155 | r""" 156 | Forward pass through the diffusion model 157 | 158 | Args: 159 | x (List[Tensor]): 160 | List of input video tensors, each with shape [C_in, F, H, W] 161 | t (Tensor): 162 | Diffusion timesteps tensor of shape [B] 163 | context (List[Tensor]): 164 | List of text embeddings each with shape [L, C] 165 | seq_len (`int`): 166 | Maximum sequence length for positional encoding 167 | clip_fea (Tensor, *optional*): 168 | CLIP image features for image-to-video mode 169 | y (List[Tensor], *optional*): 170 | Conditional video inputs for image-to-video mode, same shape as x 171 | 172 | Returns: 173 | List[Tensor]: 174 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 175 | """ 176 | # if self.model_type == 'i2v': 177 | # assert clip_fea is not None and y is not None 178 | # params 179 | device = self.patch_embedding.weight.device 180 | if self.freqs.device != device: 181 | self.freqs = self.freqs.to(device) 182 | 183 | # if y is not None: 184 | # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 185 | 186 | # embeddings 187 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 188 | grid_sizes = torch.stack( 189 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 190 | x = [u.flatten(2).transpose(1, 2) for u in x] 191 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 192 | assert seq_lens.max() <= seq_len 193 | x = torch.cat([ 194 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 195 | dim=1) for u in x 196 | ]) 197 | 198 | # time embeddings 199 | with amp.autocast(dtype=torch.float32): 200 | e = self.time_embedding( 201 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 202 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 203 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 204 | 205 | # context 206 | context_lens = None 207 | context = self.text_embedding( 208 | torch.stack([ 209 | torch.cat( 210 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 211 | for u in context 212 | ])) 213 | 214 | # if clip_fea is not None: 215 | # context_clip = self.img_emb(clip_fea) # bs x 257 x dim 216 | # context = torch.concat([context_clip, context], dim=1) 217 | 218 | # arguments 219 | kwargs = dict( 220 | e=e0, 221 | seq_lens=seq_lens, 222 | grid_sizes=grid_sizes, 223 | freqs=self.freqs, 224 | context=context, 225 | context_lens=context_lens) 226 | 227 | hints = self.forward_vace(x, vace_context, seq_len, kwargs) 228 | kwargs['hints'] = hints 229 | kwargs['context_scale'] = vace_context_scale 230 | 231 | for block in self.blocks: 232 | x = block(x, **kwargs) 233 | 234 | # head 235 | x = self.head(x, e) 236 | 237 | # unpatchify 238 | x = self.unpatchify(x, grid_sizes) 239 | return [u.float() for u in x] -------------------------------------------------------------------------------- /vace/vace_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import importlib 4 | from typing import Dict, Any 5 | 6 | def load_parser(module_name: str) -> argparse.ArgumentParser: 7 | module = importlib.import_module(module_name) 8 | if not hasattr(module, "get_parser"): 9 | raise ValueError(f"{module_name} undefined get_parser()") 10 | return module.get_parser() 11 | 12 | def filter_args(args: Dict[str, Any], parser: argparse.ArgumentParser) -> Dict[str, Any]: 13 | known_args = set() 14 | for action in parser._actions: 15 | if action.dest and action.dest != "help": 16 | known_args.add(action.dest) 17 | return {k: v for k, v in args.items() if k in known_args} 18 | 19 | def main(): 20 | 21 | main_parser = argparse.ArgumentParser() 22 | main_parser.add_argument("--base", type=str, default='ltx', choices=['ltx', 'wan']) 23 | pipeline_args, _ = main_parser.parse_known_args() 24 | 25 | if pipeline_args.base in ["ltx"]: 26 | preproccess_name, inference_name = "vace_preproccess", "vace_ltx_inference" 27 | else: 28 | preproccess_name, inference_name = "vace_preproccess", "vace_wan_inference" 29 | 30 | preprocess_parser = load_parser(preproccess_name) 31 | inference_parser = load_parser(inference_name) 32 | 33 | for parser in [preprocess_parser, inference_parser]: 34 | for action in parser._actions: 35 | if action.dest != "help": 36 | main_parser._add_action(action) 37 | 38 | cli_args = main_parser.parse_args() 39 | args_dict = vars(cli_args) 40 | 41 | # run preprocess 42 | preprocess_args = filter_args(args_dict, preprocess_parser) 43 | preprocesser = importlib.import_module(preproccess_name) 44 | preprocess_output = preprocesser.main(preprocess_args) 45 | print("preprocess_output:", preprocess_output) 46 | 47 | del preprocesser 48 | torch.cuda.empty_cache() 49 | 50 | # run inference 51 | inference_args = filter_args(args_dict, inference_parser) 52 | inference_args.update(preprocess_output) 53 | preprocess_output = importlib.import_module(inference_name).main(inference_args) 54 | print("preprocess_output:", preprocess_output) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() --------------------------------------------------------------------------------