├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── configs
│   └── pixelflow_xl_c2i.yaml
├── imagenet_en_cn.py
├── pixelflow
│   ├── data_in1k.py
│   ├── model.py
│   ├── pipeline_pixelflow.py
│   ├── scheduling_pixelflow.py
│   ├── solver_ode_wrapper.py
│   └── utils
│       ├── config.py
│       ├── logger.py
│       └── misc.py
├── requirements.txt
├── sample_ddp.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | /dataset/ 2 | # output dir 3 | /output 4 | /ckpt_* 5 | /exp0*/ 6 | 7 | *.diff 8 | 9 | # compilation and distribution 10 | __pycache__ 11 | _ext 12 | *.pyc 13 | *.pyd 14 | *.so 15 | build/ 16 | dist/ 17 | wheels/ 18 | *.egg-info/ 19 | 20 | # pytorch/python/numpy formats 21 | *.pth 22 | *.pkl 23 | *.pt 24 | 25 | 26 | # Editor temporaries 27 | *.mp4 28 | *.swn 29 | *.swo 30 | *.swp 31 | *~ 32 | 33 | # editor settings 34 | .idea 35 | .vscode 36 | 37 | # macOS system files 38 | .DS_Store 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Shoufa Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # PixelFlow: Pixel-Space Generative Models with Flow
4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2504.07963-b31b1b.svg)](https://arxiv.org/abs/2504.07963)  6 | [![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/ShoufaChen/PixelFlow)  7 | 8 | 9 | ![pixelflow](https://github.com/user-attachments/assets/7e2e4db9-4b41-46ca-8d43-92f2b642a676) 10 | 11 |
12 | 13 | 14 | 15 | 16 | > [**PixelFlow: Pixel-Space Generative Models with Flow**](https://arxiv.org/abs/2504.07963)
17 | > [Shoufa Chen](https://www.shoufachen.com), [Chongjian Ge](https://chongjiange.github.io/), [Shilong Zhang](https://jshilong.github.io/), [Peize Sun](https://peizesun.github.io/), [Ping Luo](http://luoping.me/) 18 | > The University of Hong Kong, Adobe
19 | 20 | ## Introduction 21 | We present PixelFlow, a family of image generation models that operate directly in the raw pixel space, in contrast to the predominant latent-space models. This approach simplifies the image generation process by eliminating the need for a pre-trained Variational Autoencoder (VAE) and making the whole model trainable end to end. Through efficient cascade flow modeling, PixelFlow keeps computation in pixel space affordable, and achieves an FID of 1.98 on the 256x256 ImageNet class-conditional image generation benchmark. The qualitative text-to-image results demonstrate that PixelFlow excels in image quality, artistry, and semantic control. We hope this new paradigm will inspire and open up new opportunities for next-generation visual generation models. 22 | 23 | 24 | ## Model Zoo 25 | 26 | | Model | Task | Params | FID | Checkpoint | 27 | |:---------:|:--------------:|:------:|:----:|:----------:| 28 | | PixelFlow | class-to-image | 677M | 1.98 | [🤗](https://huggingface.co/ShoufaChen/PixelFlow-Class2Image) | 29 | | PixelFlow | text-to-image | 882M | N/A | [🤗](https://huggingface.co/ShoufaChen/PixelFlow-Text2Image) | 30 | 31 | 32 | ## Setup 33 | 34 | ### 1. Create Environment 35 | ```bash 36 | conda create -n pixelflow python=3.12 37 | conda activate pixelflow 38 | ``` 39 | ### 2. Install Dependencies 40 | * [PyTorch 2.6.0](https://pytorch.org/): install it according to your system configuration (CUDA version, etc.). 41 | * [flash-attention v2.7.4.post1](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.7.4.post1): optional, required only for training. 42 | * Other packages: `pip3 install -r requirements.txt` 43 | 44 | 45 | ## Demo [![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/ShoufaChen/PixelFlow) 46 | 47 | 48 | We provide an online [Gradio demo](https://huggingface.co/spaces/ShoufaChen/PixelFlow) for class-to-image generation. 49 | 50 | You can also deploy both the class-to-image and text-to-image demos locally: 51 | 52 | ```bash 53 | python app.py --checkpoint /path/to/checkpoint --class_cond # for class-to-image 54 | ``` 55 | or 56 | ```bash 57 | python app.py --checkpoint /path/to/checkpoint # for text-to-image 58 | ``` 59 | 60 | 61 | ## Training 62 | 63 | ### 1. ImageNet Preparation 64 | 65 | - Download the ImageNet dataset from [http://www.image-net.org/](http://www.image-net.org/). 66 | - Use the [extract_ILSVRC.sh](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) script to extract and organize the training and validation images into labeled subfolders (see the layout sketch below). 67 | 
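The training loader builds the dataset with `torchvision.datasets.ImageFolder` (see `pixelflow/data_in1k.py`), so the `data.root` directory from the config (e.g., `/public/datasets/ILSVRC2012/train`) is expected to contain one subfolder per class. A sketch of the expected layout (synset and file names are illustrative):

```
train/
├── n01440764/
│   ├── n01440764_10026.JPEG
│   └── ...
├── n01443537/
│   └── ...
└── ...
```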
68 | ### 2. Training Command 69 | 70 | ```bash 71 | torchrun --nnodes=1 --nproc_per_node=8 train.py configs/pixelflow_xl_c2i.yaml 72 | ``` 73 | 74 | ## Evaluation (FID, Inception Score, etc.) 75 | 76 | We provide a [sample_ddp.py](sample_ddp.py) script, adapted from [DiT](https://github.com/facebookresearch/DiT), for generating sample images and saving them both as a folder and as a `.npz` file. The `.npz` file is compatible with ADM's TensorFlow evaluation suite, allowing direct computation of FID, Inception Score, and other metrics (see the evaluator sketch after the command below). 77 | 78 | 79 | ```bash 80 | torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --pretrained /path/to/checkpoint 81 | ```
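For reference, a minimal sketch of scoring the resulting `.npz`, assuming you have cloned [openai/guided-diffusion](https://github.com/openai/guided-diffusion) and downloaded its pre-computed ImageNet 256x256 reference batch (the evaluator requires TensorFlow; file paths here are illustrative):

```bash
# from the guided-diffusion/evaluations folder: reference batch first, then your samples
python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz
```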
82 | 83 | 84 | ## BibTeX 85 | ```bibtex 86 | @article{chen2025pixelflow, 87 | title={PixelFlow: Pixel-Space Generative Models with Flow}, 88 | author={Chen, Shoufa and Ge, Chongjian and Zhang, Shilong and Sun, Peize and Luo, Ping}, 89 | journal={arXiv preprint arXiv:2504.07963}, 90 | year={2025} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from PIL import Image 4 | import gradio as gr 5 | from imagenet_en_cn import IMAGENET_1K_CLASSES 6 | from omegaconf import OmegaConf 7 | 8 | import torch 9 | from transformers import T5EncoderModel, AutoTokenizer 10 | 11 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler 12 | from pixelflow.pipeline_pixelflow import PixelFlowPipeline 13 | from pixelflow.utils import config as config_utils 14 | from pixelflow.utils.misc import seed_everything 15 | 16 | 17 | parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False) 18 | parser.add_argument('--checkpoint', type=str, help='checkpoint folder path') 19 | parser.add_argument('--class_cond', action='store_true', help='use class conditional generation') 20 | args = parser.parse_args() 21 | 22 | local_rank = 0 23 | device = torch.device(f"cuda:{local_rank}") 24 | torch.cuda.set_device(device) 25 | 26 | output_dir = args.checkpoint 27 | if args.class_cond: 28 | config = OmegaConf.load(f"{output_dir}/config.yaml") 29 | model = config_utils.instantiate_from_config(config.model).to(device) 30 | print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 31 | ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True) 32 | text_encoder = None 33 | tokenizer = None 34 | resolution = 256 35 | NUM_EXAMPLES = 4 36 | else: 37 | config = OmegaConf.load(f"{output_dir}/config.yaml") 38 | model = config_utils.instantiate_from_config(config.model).to(device) 39 | print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 40 | ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True) 41 | text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device) 42 | tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") 43 | resolution = 1024 44 | NUM_EXAMPLES = 1 45 | model.load_state_dict(ckpt, strict=True) 46 | model.eval() 47 | 48 | scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3) 49 | 50 | pipeline = PixelFlowPipeline( 51 | scheduler, 52 | model, 53 | text_encoder=text_encoder, 54 | tokenizer=tokenizer, 55 | max_token_length=512, 56 | ) 57 | 58 | def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage): 59 | seed_everything(seed) 60 | with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad(): 61 | samples = pipeline( 62 | prompt=[class_label] * NUM_EXAMPLES, 63 | height=resolution, 64 | width=resolution, 65 | num_inference_steps=list(num_steps_per_stage), 66 | guidance_scale=cfg_scale,  # classifier-free guidance scale 67 | device=device, 68 | shift=noise_shift, 69 | use_ode_dopri5=use_ode_dopri5, 70 | ) 71 | samples = (samples * 255).round().astype("uint8") 72 | samples = [Image.fromarray(sample) for sample in samples] 73 | return samples
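# Each cascade stage gets its own step-count slider in the UI below; `infer` receives the
# values through *num_steps_per_stage and forwards list(num_steps_per_stage) to the pipeline,
# so num_inference_steps carries one entry per stage.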
74 | 75 | 76 | with gr.Blocks() as demo: 77 | gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow
") 78 | 79 | with gr.Tabs(): 80 | with gr.TabItem('Generate'): 81 | with gr.Row(): 82 | with gr.Column(): 83 | with gr.Row(): 84 | if args.class_cond: 85 | user_input = gr.Dropdown( 86 | list(IMAGENET_1K_CLASSES.values()), 87 | value='daisy [雏菊]', 88 | type="index", label='ImageNet-1K Class' 89 | ) 90 | else: 91 | # text input 92 | user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt",) 93 | ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use Dopri5 ODE solver") 94 | noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift') 95 | cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale') 96 | num_steps_per_stage = [] 97 | for stage_idx in range(config.scheduler.num_stages): 98 | num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})') 99 | num_steps_per_stage.append(num_steps) 100 | seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed') 101 | button = gr.Button("Generate", variant="primary") 102 | with gr.Column(): 103 | output = gr.Gallery(label='Generated Images', height=700) 104 | button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output]) 105 | demo.queue() 106 | demo.launch(share=False, debug=True) 107 | -------------------------------------------------------------------------------- /configs/pixelflow_xl_c2i.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: pixelflow.model.PixelFlowModel 3 | params: 4 | num_attention_heads: 16 5 | attention_head_dim: 72 6 | in_channels: 3 7 | out_channels: 3 8 | depth: 28 9 | num_classes: 1000 10 | patch_size: 4 11 | attention_bias: true 12 | 13 | scheduler: 14 | num_train_timesteps: 1000 15 | num_stages: 4 16 | pyramid_shift: false 17 | 18 | train: 19 | lr: 1e-4 20 | weight_decay: 0.0 21 | epochs: 10 22 | 23 | data: 24 | root: /public/datasets/ILSVRC2012/train 25 | center_crop: false 26 | resolution: 256 27 | expand_ratio: 1.125 28 | num_workers: 4 29 | batch_size: 4 30 | 31 | seed: 42 32 | -------------------------------------------------------------------------------- /imagenet_en_cn.py: -------------------------------------------------------------------------------- 1 | IMAGENET_1K_CLASSES = { 2 | 0: 'tench, Tinca tinca [丁鲷]', 3 | 1: 'goldfish, Carassius auratus [金鱼]', 4 | 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]', 5 | 3: 'tiger shark, Galeocerdo cuvieri [虎鲨]', 6 | 4: 'hammerhead, hammerhead shark [锤头鲨]', 7 | 5: 'electric ray, crampfish, numbfish, torpedo [电鳐]', 8 | 6: 'stingray [黄貂鱼]', 9 | 7: 'cock [公鸡]', 10 | 8: 'hen [母鸡]', 11 | 9: 'ostrich, Struthio camelus [鸵鸟]', 12 | 10: 'brambling, Fringilla montifringilla [燕雀]', 13 | 11: 'goldfinch, Carduelis carduelis [金翅雀]', 14 | 12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]', 15 | 13: 'junco, snowbird [灯芯草雀]', 16 | 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]', 17 | 15: 'robin, American robin, Turdus migratorius [蓝鹀]', 18 | 16: 'bulbul [夜莺]', 19 | 17: 'jay [松鸦]', 20 | 18: 'magpie [喜鹊]', 21 | 19: 'chickadee [山雀]', 22 | 20: 'water ouzel, dipper [河鸟]', 23 | 21: 'kite [鸢(猛禽)]', 24 | 22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]', 25 | 23: 'vulture [秃鹫]', 26 | 24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]', 27 | 25: 'European fire salamander, Salamandra 
salamandra [欧洲火蝾螈]', 28 | 26: 'common newt, Triturus vulgaris [普通蝾螈]', 29 | 27: 'eft [水蜥]', 30 | 28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]', 31 | 29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]', 32 | 30: 'bullfrog, Rana catesbeiana [牛蛙]', 33 | 31: 'tree frog, tree-frog [树蛙]', 34 | 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]', 35 | 33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]', 36 | 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]', 37 | 35: 'mud turtle [泥龟]', 38 | 36: 'terrapin [淡水龟]', 39 | 37: 'box turtle, box tortoise [箱龟]', 40 | 38: 'banded gecko [带状壁虎]', 41 | 39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]', 42 | 40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]', 43 | 41: 'whiptail, whiptail lizard [鞭尾蜥蜴]', 44 | 42: 'agama [飞龙科蜥蜴]', 45 | 43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]', 46 | 44: 'alligator lizard [鳄鱼蜥蜴]', 47 | 45: 'Gila monster, Heloderma suspectum [毒蜥]', 48 | 46: 'green lizard, Lacerta viridis [绿蜥蜴]', 49 | 47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]', 50 | 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]', 51 | 49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]', 52 | 50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]', 53 | 51: 'triceratops [三角龙]', 54 | 52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]', 55 | 53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]', 56 | 54: 'hognose snake, puff adder, sand viper [希腊蛇]', 57 | 55: 'green snake, grass snake [绿蛇,草蛇]', 58 | 56: 'king snake, kingsnake [国王蛇]', 59 | 57: 'garter snake, grass snake [袜带蛇,草蛇]', 60 | 58: 'water snake [水蛇]', 61 | 59: 'vine snake [藤蛇]', 62 | 60: 'night snake, Hypsiglena torquata [夜蛇]', 63 | 61: 'boa constrictor, Constrictor constrictor [大蟒蛇]', 64 | 62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]', 65 | 63: 'Indian cobra, Naja naja [印度眼镜蛇]', 66 | 64: 'green mamba [绿曼巴]', 67 | 65: 'sea snake [海蛇]', 68 | 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]', 69 | 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]', 70 | 68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]', 71 | 69: 'trilobite [三叶虫]', 72 | 70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]', 73 | 71: 'scorpion [蝎子]', 74 | 72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]', 75 | 73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]', 76 | 74: 'garden spider, Aranea diademata [花园蜘蛛]', 77 | 75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]', 78 | 76: 'tarantula [狼蛛]', 79 | 77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]', 80 | 78: 'tick [壁虱]', 81 | 79: 'centipede [蜈蚣]', 82 | 80: 'black grouse [黑松鸡]', 83 | 81: 'ptarmigan [松鸡,雷鸟]', 84 | 82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]', 85 | 83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]', 86 | 84: 'peacock [孔雀]', 87 | 85: 'quail [鹌鹑]', 88 | 86: 'partridge [鹧鸪]', 89 | 87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]', 90 | 88: 'macaw [金刚鹦鹉]', 91 | 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]', 92 | 90: 'lorikeet [短尾鹦鹉]', 93 | 91: 'coucal [褐翅鸦鹃]', 94 | 92: 'bee eater [蜜蜂]', 95 | 93: 'hornbill [犀鸟]', 96 | 94: 'hummingbird [蜂鸟]', 97 | 95: 'jacamar [鹟䴕]', 98 | 96: 'toucan [犀鸟]', 99 | 97: 'drake [野鸭]', 100 | 98: 'red-breasted merganser, Mergus serrator [红胸秋沙鸭]', 101 | 99: 'goose [鹅]', 102 | 100: 'black swan, Cygnus atratus [黑天鹅]', 103 | 101: 'tusker [大象]', 104 | 
102: 'echidna, spiny anteater, anteater [针鼹鼠]', 105 | 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]', 106 | 104: 'wallaby, brush kangaroo [沙袋鼠]', 107 | 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]', 108 | 106: 'wombat [袋熊]', 109 | 107: 'jellyfish [水母]', 110 | 108: 'sea anemone, anemone [海葵]', 111 | 109: 'brain coral [脑珊瑚]', 112 | 110: 'flatworm, platyhelminth [扁形虫扁虫]', 113 | 111: 'nematode, nematode worm, roundworm [线虫,蛔虫]', 114 | 112: 'conch [海螺]', 115 | 113: 'snail [蜗牛]', 116 | 114: 'slug [鼻涕虫]', 117 | 115: 'sea slug, nudibranch [海参]', 118 | 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]', 119 | 117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]', 120 | 118: 'Dungeness crab, Cancer magister [珍宝蟹]', 121 | 119: 'rock crab, Cancer irroratus [石蟹]', 122 | 120: 'fiddler crab [招潮蟹]', 123 | 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]', 124 | 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]', 125 | 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]', 126 | 124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]', 127 | 125: 'hermit crab [寄居蟹]', 128 | 126: 'isopod [等足目动物(明虾和螃蟹近亲)]', 129 | 127: 'white stork, Ciconia ciconia [白鹳]', 130 | 128: 'black stork, Ciconia nigra [黑鹳]', 131 | 129: 'spoonbill [鹭]', 132 | 130: 'flamingo [火烈鸟]', 133 | 131: 'little blue heron, Egretta caerulea [小蓝鹭]', 134 | 132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]', 135 | 133: 'bittern [麻鸦]', 136 | 134: 'crane [鹤]', 137 | 135: 'limpkin, Aramus pictus [秧鹤]', 138 | 136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]', 139 | 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]', 140 | 138: 'bustard [鸨]', 141 | 139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]', 142 | 140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]', 143 | 141: 'redshank, Tringa totanus [红脚鹬]', 144 | 142: 'dowitcher [半蹼鹬]', 145 | 143: 'oystercatcher, oyster catcher [蛎鹬]', 146 | 144: 'pelican [鹈鹕]', 147 | 145: 'king penguin, Aptenodytes patagonica [国王企鹅]', 148 | 146: 'albatross, mollymawk [信天翁,大海鸟]', 149 | 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]', 150 | 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]', 151 | 149: 'dugong, Dugong dugon [海牛]', 152 | 150: 'sea lion [海狮]', 153 | 151: 'Chihuahua [奇瓦瓦]', 154 | 152: 'Japanese spaniel [日本猎犬]', 155 | 153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]', 156 | 154: 'Pekinese, Pekingese, Peke [狮子狗]', 157 | 155: 'Shih-Tzu [西施犬]', 158 | 156: 'Blenheim spaniel [布莱尼姆猎犬]', 159 | 157: 'papillon [巴比狗]', 160 | 158: 'toy terrier [玩具犬]', 161 | 159: 'Rhodesian ridgeback [罗得西亚长背猎狗]', 162 | 160: 'Afghan hound, Afghan [阿富汗猎犬]', 163 | 161: 'basset, basset hound [猎犬]', 164 | 162: 'beagle [比格犬,猎兔犬]', 165 | 163: 'bloodhound, sleuthhound [侦探犬]', 166 | 164: 'bluetick [蓝色快狗]', 167 | 165: 'black-and-tan coonhound [黑褐猎浣熊犬]', 168 | 166: 'Walker hound, Walker foxhound [沃克猎犬]', 169 | 167: 'English foxhound [英国猎狐犬]', 170 | 168: 'redbone [美洲赤狗]', 171 | 169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]', 172 | 170: 'Irish wolfhound [爱尔兰猎狼犬]', 173 | 171: 'Italian greyhound [意大利灰狗]', 174 | 172: 'whippet [惠比特犬]', 175 | 173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]', 176 | 174: 'Norwegian elkhound, elkhound [挪威猎犬]', 177 | 175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]', 178 | 
176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]', 179 | 177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]', 180 | 178: 'Weimaraner [威玛猎犬]', 181 | 179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]', 182 | 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]', 183 | 181: 'Bedlington terrier [贝德灵顿梗]', 184 | 182: 'Border terrier [边境梗]', 185 | 183: 'Kerry blue terrier [凯丽蓝梗]', 186 | 184: 'Irish terrier [爱尔兰梗]', 187 | 185: 'Norfolk terrier [诺福克梗]', 188 | 186: 'Norwich terrier [诺维奇梗]', 189 | 187: 'Yorkshire terrier [约克郡梗]', 190 | 188: 'wire-haired fox terrier [刚毛猎狐梗]', 191 | 189: 'Lakeland terrier [莱克兰梗]', 192 | 190: 'Sealyham terrier, Sealyham [锡利哈姆梗]', 193 | 191: 'Airedale, Airedale terrier [艾尔谷犬]', 194 | 192: 'cairn, cairn terrier [凯恩梗]', 195 | 193: 'Australian terrier [澳大利亚梗]', 196 | 194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]', 197 | 195: 'Boston bull, Boston terrier [波士顿梗]', 198 | 196: 'miniature schnauzer [迷你雪纳瑞犬]', 199 | 197: 'giant schnauzer [巨型雪纳瑞犬]', 200 | 198: 'standard schnauzer [标准雪纳瑞犬]', 201 | 199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]', 202 | 200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]', 203 | 201: 'silky terrier, Sydney silky [丝毛梗]', 204 | 202: 'soft-coated wheaten terrier [软毛麦色梗]', 205 | 203: 'West Highland white terrier [西高地白梗]', 206 | 204: 'Lhasa, Lhasa apso [拉萨阿普索犬]', 207 | 205: 'flat-coated retriever [平毛寻回犬]', 208 | 206: 'curly-coated retriever [卷毛寻回犬]', 209 | 207: 'golden retriever [金毛猎犬]', 210 | 208: 'Labrador retriever [拉布拉多猎犬]', 211 | 209: 'Chesapeake Bay retriever [乞沙比克猎犬]', 212 | 210: 'German short-haired pointer [德国短毛猎犬]', 213 | 211: 'vizsla, Hungarian pointer [维兹拉犬]', 214 | 212: 'English setter [英国谍犬]', 215 | 213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]', 216 | 214: 'Gordon setter [戈登雪达犬]', 217 | 215: 'Brittany spaniel [布列塔尼犬猎犬]', 218 | 216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]', 219 | 217: 'English springer, English springer spaniel [英国史宾格犬]', 220 | 218: 'Welsh springer spaniel [威尔士史宾格犬]', 221 | 219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]', 222 | 220: 'Sussex spaniel [萨塞克斯猎犬]', 223 | 221: 'Irish water spaniel [爱尔兰水猎犬]', 224 | 222: 'kuvasz [哥威斯犬]', 225 | 223: 'schipperke [舒柏奇犬]', 226 | 224: 'groenendael [比利时牧羊犬]', 227 | 225: 'malinois [马里努阿犬]', 228 | 226: 'briard [伯瑞犬]', 229 | 227: 'kelpie [凯尔皮犬]', 230 | 228: 'komondor [匈牙利牧羊犬]', 231 | 229: 'Old English sheepdog, bobtail [老英国牧羊犬]', 232 | 230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]', 233 | 231: 'collie [牧羊犬]', 234 | 232: 'Border collie [边境牧羊犬]', 235 | 233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]', 236 | 234: 'Rottweiler [罗特韦尔犬]', 237 | 235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]', 238 | 236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]', 239 | 237: 'miniature pinscher [迷你杜宾犬]', 240 | 238: 'Greater Swiss Mountain dog [大瑞士山地犬]', 241 | 239: 'Bernese mountain dog [伯恩山犬]', 242 | 240: 'Appenzeller [Appenzeller狗]', 243 | 241: 'EntleBucher [EntleBucher狗]', 244 | 242: 'boxer [拳师狗]', 245 | 243: 'bull mastiff [斗牛獒]', 246 | 244: 'Tibetan mastiff [藏獒]', 247 | 245: 'French bulldog [法国斗牛犬]', 248 | 246: 'Great Dane [大丹犬]', 249 | 247: 'Saint Bernard, St Bernard [圣伯纳德狗]', 250 | 248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]', 251 | 249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]', 252 | 250: 'Siberian husky [哈士奇]', 253 | 251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]', 254 | 252: 'affenpinscher, monkey pinscher, 
monkey dog [狮毛狗]', 255 | 253: 'basenji [巴辛吉狗]', 256 | 254: 'pug, pug-dog [哈巴狗,狮子狗]', 257 | 255: 'Leonberg [莱昂贝格狗]', 258 | 256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]', 259 | 257: 'Great Pyrenees [大白熊犬]', 260 | 258: 'Samoyed, Samoyede [萨摩耶犬]', 261 | 259: 'Pomeranian [博美犬]', 262 | 260: 'chow, chow chow [松狮,松狮]', 263 | 261: 'keeshond [荷兰卷尾狮毛狗]', 264 | 262: 'Brabancon griffon [布鲁塞尔格林芬犬]', 265 | 263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]', 266 | 264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]', 267 | 265: 'toy poodle [玩具贵宾犬]', 268 | 266: 'miniature poodle [迷你贵宾犬]', 269 | 267: 'standard poodle [标准贵宾犬]', 270 | 268: 'Mexican hairless [墨西哥无毛犬]', 271 | 269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]', 272 | 270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]', 273 | 271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]', 274 | 272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]', 275 | 273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]', 276 | 274: 'dhole, Cuon alpinus [豺]', 277 | 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]', 278 | 276: 'hyena, hyaena [鬣狗]', 279 | 277: 'red fox, Vulpes vulpes [红狐狸]', 280 | 278: 'kit fox, Vulpes macrotis [沙狐]', 281 | 279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]', 282 | 280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]', 283 | 281: 'tabby, tabby cat [虎斑猫]', 284 | 282: 'tiger cat [山猫,虎猫]', 285 | 283: 'Persian cat [波斯猫]', 286 | 284: 'Siamese cat, Siamese [暹罗暹罗猫,]', 287 | 285: 'Egyptian cat [埃及猫]', 288 | 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]', 289 | 287: 'lynx, catamount [猞猁,山猫]', 290 | 288: 'leopard, Panthera pardus [豹子]', 291 | 289: 'snow leopard, ounce, Panthera uncia [雪豹]', 292 | 290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]', 293 | 291: 'lion, king of beasts, Panthera leo [狮子]', 294 | 292: 'tiger, Panthera tigris [老虎]', 295 | 293: 'cheetah, chetah, Acinonyx jubatus [猎豹]', 296 | 294: 'brown bear, bruin, Ursus arctos [棕熊]', 297 | 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]', 298 | 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]', 299 | 297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]', 300 | 298: 'mongoose [猫鼬]', 301 | 299: 'meerkat, mierkat [猫鼬,海猫]', 302 | 300: 'tiger beetle [虎甲虫]', 303 | 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]', 304 | 302: 'ground beetle, carabid beetle [土鳖虫]', 305 | 303: 'long-horned beetle, longicorn, longicorn beetle [天牛]', 306 | 304: 'leaf beetle, chrysomelid [龟甲虫]', 307 | 305: 'dung beetle [粪甲虫]', 308 | 306: 'rhinoceros beetle [犀牛甲虫]', 309 | 307: 'weevil [象甲]', 310 | 308: 'fly [苍蝇]', 311 | 309: 'bee [蜜蜂]', 312 | 310: 'ant, emmet, pismire [蚂蚁]', 313 | 311: 'grasshopper, hopper [蚱蜢]', 314 | 312: 'cricket [蟋蟀]', 315 | 313: 'walking stick, walkingstick, stick insect [竹节虫]', 316 | 314: 'cockroach, roach [蟑螂]', 317 | 315: 'mantis, mantid [螳螂]', 318 | 316: 'cicada, cicala [蝉]', 319 | 317: 'leafhopper [叶蝉]', 320 | 318: 'lacewing, lacewing fly [草蜻蛉]', 321 | 319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]', 322 | 320: 'damselfly [豆娘,蜻蛉]', 323 | 321: 'admiral [优红蛱蝶]', 324 | 322: 'ringlet, ringlet butterfly [小环蝴蝶]', 325 | 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]', 326 | 324: 'cabbage butterfly [菜粉蝶]', 327 | 325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]', 
328 | 326: 'lycaenid, lycaenid butterfly [灰蝶]', 329 | 327: 'starfish, sea star [海星]', 330 | 328: 'sea urchin [海胆]', 331 | 329: 'sea cucumber, holothurian [海参,海黄瓜]', 332 | 330: 'wood rabbit, cottontail, cottontail rabbit [野兔]', 333 | 331: 'hare [兔]', 334 | 332: 'Angora, Angora rabbit [安哥拉兔]', 335 | 333: 'hamster [仓鼠]', 336 | 334: 'porcupine, hedgehog [刺猬,豪猪,]', 337 | 335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]', 338 | 336: 'marmot [土拨鼠]', 339 | 337: 'beaver [海狸]', 340 | 338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]', 341 | 339: 'sorrel [栗色马]', 342 | 340: 'zebra [斑马]', 343 | 341: 'hog, pig, grunter, squealer, Sus scrofa [猪]', 344 | 342: 'wild boar, boar, Sus scrofa [野猪]', 345 | 343: 'warthog [疣猪]', 346 | 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]', 347 | 345: 'ox [牛]', 348 | 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]', 349 | 347: 'bison [野牛]', 350 | 348: 'ram, tup [公羊]', 351 | 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]', 352 | 350: 'ibex, Capra ibex [山羊]', 353 | 351: 'hartebeest [狷羚]', 354 | 352: 'impala, Aepyceros melampus [黑斑羚]', 355 | 353: 'gazelle [瞪羚]', 356 | 354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]', 357 | 355: 'llama [羊驼]', 358 | 356: 'weasel [黄鼠狼]', 359 | 357: 'mink [水貂]', 360 | 358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]', 361 | 359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]', 362 | 360: 'otter [水獭]', 363 | 361: 'skunk, polecat, wood pussy [臭鼬,木猫]', 364 | 362: 'badger [獾]', 365 | 363: 'armadillo [犰狳]', 366 | 364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]', 367 | 365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]', 368 | 366: 'gorilla, Gorilla gorilla [大猩猩]', 369 | 367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]', 370 | 368: 'gibbon, Hylobates lar [长臂猿]', 371 | 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]', 372 | 370: 'guenon, guenon monkey [长尾猴]', 373 | 371: 'patas, hussar monkey, Erythrocebus patas [赤猴]', 374 | 372: 'baboon [狒狒]', 375 | 373: 'macaque [恒河猴,猕猴]', 376 | 374: 'langur [白头叶猴]', 377 | 375: 'colobus, colobus monkey [疣猴]', 378 | 376: 'proboscis monkey, Nasalis larvatus [长鼻猴]', 379 | 377: 'marmoset [狨(美洲产小型长尾猴)]', 380 | 378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]', 381 | 379: 'howler monkey, howler [吼猴]', 382 | 380: 'titi, titi monkey [伶猴]', 383 | 381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]', 384 | 382: 'squirrel monkey, Saimiri sciureus [松鼠猴]', 385 | 383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]', 386 | 384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]', 387 | 385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]', 388 | 386: 'African elephant, Loxodonta africana [非洲象,非洲象]', 389 | 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]', 390 | 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]', 391 | 389: 'barracouta, snoek [杖鱼]', 392 | 390: 'eel [鳗鱼]', 393 | 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]', 394 | 392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]', 395 | 393: 'anemone fish [海葵鱼]', 396 | 394: 'sturgeon [鲟鱼]', 397 | 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]', 398 | 396: 'lionfish [狮子鱼]', 399 | 397: 'puffer, pufferfish, blowfish, globefish [河豚]', 400 | 398: 'abacus [算盘]', 401 | 399: 'abaya [长袍]', 402 | 400: 'academic gown, academic robe, judge robe [学位袍]', 403 | 401: 
'accordion, piano accordion, squeeze box [手风琴]', 404 | 402: 'acoustic guitar [原声吉他]', 405 | 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]', 406 | 404: 'airliner [客机]', 407 | 405: 'airship, dirigible [飞艇]', 408 | 406: 'altar [祭坛]', 409 | 407: 'ambulance [救护车]', 410 | 408: 'amphibian, amphibious vehicle [水陆两用车]', 411 | 409: 'analog clock [模拟时钟]', 412 | 410: 'apiary, bee house [蜂房]', 413 | 411: 'apron [围裙]', 414 | 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]', 415 | 413: 'assault rifle, assault gun [攻击步枪,枪]', 416 | 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]', 417 | 415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]', 418 | 416: 'balance beam, beam [平衡木]', 419 | 417: 'balloon [热气球]', 420 | 418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]', 421 | 419: 'Band Aid [创可贴]', 422 | 420: 'banjo [班卓琴]', 423 | 421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]', 424 | 422: 'barbell [杠铃]', 425 | 423: 'barber chair [理发师的椅子]', 426 | 424: 'barbershop [理发店]', 427 | 425: 'barn [牲口棚]', 428 | 426: 'barometer [晴雨表]', 429 | 427: 'barrel, cask [圆筒]', 430 | 428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]', 431 | 429: 'baseball [棒球]', 432 | 430: 'basketball [篮球]', 433 | 431: 'bassinet [婴儿床]', 434 | 432: 'bassoon [巴松管,低音管]', 435 | 433: 'bathing cap, swimming cap [游泳帽]', 436 | 434: 'bath towel [沐浴毛巾]', 437 | 435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]', 438 | 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]', 439 | 437: 'beacon, lighthouse, beacon light, pharos [灯塔]', 440 | 438: 'beaker [高脚杯]', 441 | 439: 'bearskin, busby, shako [熊皮高帽]', 442 | 440: 'beer bottle [啤酒瓶]', 443 | 441: 'beer glass [啤酒杯]', 444 | 442: 'bell cote, bell cot [钟塔]', 445 | 443: 'bib [(小儿用的)围嘴]', 446 | 444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]', 447 | 445: 'bikini, two-piece [比基尼]', 448 | 446: 'binder, ring-binder [装订册]', 449 | 447: 'binoculars, field glasses, opera glasses [双筒望远镜]', 450 | 448: 'birdhouse [鸟舍]', 451 | 449: 'boathouse [船库]', 452 | 450: 'bobsled, bobsleigh, bob [雪橇]', 453 | 451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]', 454 | 452: 'bonnet, poke bonnet [阔边女帽]', 455 | 453: 'bookcase [书橱]', 456 | 454: 'bookshop, bookstore, bookstall [书店,书摊]', 457 | 455: 'bottlecap [瓶盖]', 458 | 456: 'bow [弓箭]', 459 | 457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]', 460 | 458: 'brass, memorial tablet, plaque [铜制牌位]', 461 | 459: 'brassiere, bra, bandeau [奶罩]', 462 | 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]', 463 | 461: 'breastplate, aegis, egis [铠甲]', 464 | 462: 'broom [扫帚]', 465 | 463: 'bucket, pail [桶]', 466 | 464: 'buckle [扣环]', 467 | 465: 'bulletproof vest [防弹背心]', 468 | 466: 'bullet train, bullet [动车,子弹头列车]', 469 | 467: 'butcher shop, meat market [肉铺,肉菜市场]', 470 | 468: 'cab, hack, taxi, taxicab [出租车]', 471 | 469: 'caldron, cauldron [大锅]', 472 | 470: 'candle, taper, wax light [蜡烛]', 473 | 471: 'cannon [大炮]', 474 | 472: 'canoe [独木舟]', 475 | 473: 'can opener, tin opener [开瓶器,开罐器]', 476 | 474: 'cardigan [开衫]', 477 | 475: 'car mirror [车镜]', 478 | 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]', 479 | 477: 'carpenters kit, tool kit [木匠的工具包,工具包]', 480 | 478: 'carton [纸箱]', 481 | 479: 'car wheel [车轮]', 482 | 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]', 483 | 481: 'cassette [盒式录音带]', 484 | 482: 
'cassette player [卡带播放器]', 485 | 483: 'castle [城堡]', 486 | 484: 'catamaran [双体船]', 487 | 485: 'CD player [CD播放器]', 488 | 486: 'cello, violoncello [大提琴]', 489 | 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]', 490 | 488: 'chain [铁链]', 491 | 489: 'chainlink fence [围栏]', 492 | 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]', 493 | 491: 'chain saw, chainsaw [电锯,油锯]', 494 | 492: 'chest [箱子]', 495 | 493: 'chiffonier, commode [衣柜,洗脸台]', 496 | 494: 'chime, bell, gong [编钟,钟,锣]', 497 | 495: 'china cabinet, china closet [中国橱柜]', 498 | 496: 'Christmas stocking [圣诞袜]', 499 | 497: 'church, church building [教堂,教堂建筑]', 500 | 498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]', 501 | 499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]', 502 | 500: 'cliff dwelling [悬崖屋]', 503 | 501: 'cloak [斗篷]', 504 | 502: 'clog, geta, patten, sabot [木屐,木鞋]', 505 | 503: 'cocktail shaker [鸡尾酒调酒器]', 506 | 504: 'coffee mug [咖啡杯]', 507 | 505: 'coffeepot [咖啡壶]', 508 | 506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]', 509 | 507: 'combination lock [组合锁]', 510 | 508: 'computer keyboard, keypad [电脑键盘,键盘]', 511 | 509: 'confectionery, confectionary, candy store [糖果,糖果店]', 512 | 510: 'container ship, containership, container vessel [集装箱船]', 513 | 511: 'convertible [敞篷车]', 514 | 512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]', 515 | 513: 'cornet, horn, trumpet, trump [短号,喇叭]', 516 | 514: 'cowboy boot [牛仔靴]', 517 | 515: 'cowboy hat, ten-gallon hat [牛仔帽]', 518 | 516: 'cradle [摇篮]', 519 | 517: 'crane [起重机]', 520 | 518: 'crash helmet [头盔]', 521 | 519: 'crate [板条箱]', 522 | 520: 'crib, cot [小儿床]', 523 | 521: 'Crock Pot [砂锅]', 524 | 522: 'croquet ball [槌球]', 525 | 523: 'crutch [拐杖]', 526 | 524: 'cuirass [胸甲]', 527 | 525: 'dam, dike, dyke [大坝,堤防]', 528 | 526: 'desk [书桌]', 529 | 527: 'desktop computer [台式电脑]', 530 | 528: 'dial telephone, dial phone [有线电话]', 531 | 529: 'diaper, nappy, napkin [尿布湿]', 532 | 530: 'digital clock [数字时钟]', 533 | 531: 'digital watch [数字手表]', 534 | 532: 'dining table, board [餐桌板]', 535 | 533: 'dishrag, dishcloth [抹布]', 536 | 534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]', 537 | 535: 'disk brake, disc brake [盘式制动器]', 538 | 536: 'dock, dockage, docking facility [码头,船坞,码头设施]', 539 | 537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]', 540 | 538: 'dome [圆顶]', 541 | 539: 'doormat, welcome mat [门垫,垫子]', 542 | 540: 'drilling platform, offshore rig [钻井平台,海上钻井]', 543 | 541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]', 544 | 542: 'drumstick [鼓槌]', 545 | 543: 'dumbbell [哑铃]', 546 | 544: 'Dutch oven [荷兰烤箱]', 547 | 545: 'electric fan, blower [电风扇,鼓风机]', 548 | 546: 'electric guitar [电吉他]', 549 | 547: 'electric locomotive [电力机车]', 550 | 548: 'entertainment center [电视,电视柜]', 551 | 549: 'envelope [信封]', 552 | 550: 'espresso maker [浓缩咖啡机]', 553 | 551: 'face powder [扑面粉]', 554 | 552: 'feather boa, boa [女用长围巾]', 555 | 553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]', 556 | 554: 'fireboat [消防船]', 557 | 555: 'fire engine, fire truck [消防车]', 558 | 556: 'fire screen, fireguard [火炉栏]', 559 | 557: 'flagpole, flagstaff [旗杆]', 560 | 558: 'flute, transverse flute [长笛]', 561 | 559: 'folding chair [折叠椅]', 562 | 560: 'football helmet [橄榄球头盔]', 563 | 561: 'forklift [叉车]', 564 | 562: 'fountain [喷泉]', 565 | 563: 'fountain pen [钢笔]', 566 | 564: 'four-poster [有四根帷柱的床]', 567 | 565: 'freight car [运货车厢]', 568 | 566: 'French horn, horn [圆号,喇叭]', 569 | 567: 'frying pan, frypan, skillet [煎锅]', 570 | 568: 'fur coat [裘皮大衣]', 571 | 569: 'garbage truck, 
dustcart [垃圾车]', 572 | 570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]', 573 | 571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]', 574 | 572: 'goblet [高脚杯]', 575 | 573: 'go-kart [卡丁车]', 576 | 574: 'golf ball [高尔夫球]', 577 | 575: 'golfcart, golf cart [高尔夫球车]', 578 | 576: 'gondola [狭长小船]', 579 | 577: 'gong, tam-tam [锣]', 580 | 578: 'gown [礼服]', 581 | 579: 'grand piano, grand [钢琴]', 582 | 580: 'greenhouse, nursery, glasshouse [温室,苗圃]', 583 | 581: 'grille, radiator grille [散热器格栅]', 584 | 582: 'grocery store, grocery, food market, market [杂货店,食品市场]', 585 | 583: 'guillotine [断头台]', 586 | 584: 'hair slide [小发夹]', 587 | 585: 'hair spray [头发喷雾]', 588 | 586: 'half track [半履带装甲车]', 589 | 587: 'hammer [锤子]', 590 | 588: 'hamper [大篮子]', 591 | 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]', 592 | 590: 'hand-held computer, hand-held microcomputer [手提电脑]', 593 | 591: 'handkerchief, hankie, hanky, hankey [手帕]', 594 | 592: 'hard disc, hard disk, fixed disk [硬盘]', 595 | 593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]', 596 | 594: 'harp [竖琴]', 597 | 595: 'harvester, reaper [收割机]', 598 | 596: 'hatchet [斧头]', 599 | 597: 'holster [手枪皮套]', 600 | 598: 'home theater, home theatre [家庭影院]', 601 | 599: 'honeycomb [蜂窝]', 602 | 600: 'hook, claw [钩爪]', 603 | 601: 'hoopskirt, crinoline [衬裙]', 604 | 602: 'horizontal bar, high bar [单杠]', 605 | 603: 'horse cart, horse-cart [马车]', 606 | 604: 'hourglass [沙漏]', 607 | 605: 'iPod [手机,iPad]', 608 | 606: 'iron, smoothing iron [熨斗]', 609 | 607: 'jack-o-lantern [南瓜灯笼]', 610 | 608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]', 611 | 609: 'jeep, landrover [吉普车]', 612 | 610: 'jersey, T-shirt, tee shirt [运动衫,T恤]', 613 | 611: 'jigsaw puzzle [拼图]', 614 | 612: 'jinrikisha, ricksha, rickshaw [人力车]', 615 | 613: 'joystick [操纵杆]', 616 | 614: 'kimono [和服]', 617 | 615: 'knee pad [护膝]', 618 | 616: 'knot [蝴蝶结]', 619 | 617: 'lab coat, laboratory coat [大褂,实验室外套]', 620 | 618: 'ladle [长柄勺]', 621 | 619: 'lampshade, lamp shade [灯罩]', 622 | 620: 'laptop, laptop computer [笔记本电脑]', 623 | 621: 'lawn mower, mower [割草机]', 624 | 622: 'lens cap, lens cover [镜头盖]', 625 | 623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]', 626 | 624: 'library [图书馆]', 627 | 625: 'lifeboat [救生艇]', 628 | 626: 'lighter, light, igniter, ignitor [点火器,打火机]', 629 | 627: 'limousine, limo [豪华轿车]', 630 | 628: 'liner, ocean liner [远洋班轮]', 631 | 629: 'lipstick, lip rouge [唇膏,口红]', 632 | 630: 'Loafer [平底便鞋]', 633 | 631: 'lotion [洗剂]', 634 | 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]', 635 | 633: 'loupe, jewelers loupe [放大镜]', 636 | 634: 'lumbermill, sawmill [锯木厂]', 637 | 635: 'magnetic compass [磁罗盘]', 638 | 636: 'mailbag, postbag [邮袋]', 639 | 637: 'mailbox, letter box [信箱]', 640 | 638: 'maillot [女游泳衣]', 641 | 639: 'maillot, tank suit [有肩带浴衣]', 642 | 640: 'manhole cover [窨井盖]', 643 | 641: 'maraca [沙球(一种打击乐器)]', 644 | 642: 'marimba, xylophone [马林巴木琴]', 645 | 643: 'mask [面膜]', 646 | 644: 'matchstick [火柴]', 647 | 645: 'maypole [花柱]', 648 | 646: 'maze, labyrinth [迷宫]', 649 | 647: 'measuring cup [量杯]', 650 | 648: 'medicine chest, medicine cabinet [药箱]', 651 | 649: 'megalith, megalithic structure [巨石,巨石结构]', 652 | 650: 'microphone, mike [麦克风]', 653 | 651: 'microwave, microwave oven [微波炉]', 654 | 652: 'military uniform [军装]', 655 | 653: 'milk can [奶桶]', 656 | 654: 'minibus [迷你巴士]', 657 | 655: 'miniskirt, mini [迷你裙]', 658 | 656: 'minivan [面包车]', 659 | 657: 'missile [导弹]', 660 | 658: 'mitten [连指手套]', 661 | 659: 'mixing bowl [搅拌钵]', 662 | 660: 'mobile home, 
manufactured home [活动房屋(由汽车拖拉的)]', 663 | 661: 'Model T [T型发动机小汽车]', 664 | 662: 'modem [调制解调器]', 665 | 663: 'monastery [修道院]', 666 | 664: 'monitor [显示器]', 667 | 665: 'moped [电瓶车]', 668 | 666: 'mortar [砂浆]', 669 | 667: 'mortarboard [学士]', 670 | 668: 'mosque [清真寺]', 671 | 669: 'mosquito net [蚊帐]', 672 | 670: 'motor scooter, scooter [摩托车]', 673 | 671: 'mountain bike, all-terrain bike, off-roader [山地自行车]', 674 | 672: 'mountain tent [登山帐]', 675 | 673: 'mouse, computer mouse [鼠标,电脑鼠标]', 676 | 674: 'mousetrap [捕鼠器]', 677 | 675: 'moving van [搬家车]', 678 | 676: 'muzzle [口套]', 679 | 677: 'nail [钉子]', 680 | 678: 'neck brace [颈托]', 681 | 679: 'necklace [项链]', 682 | 680: 'nipple [乳头(瓶)]', 683 | 681: 'notebook, notebook computer [笔记本,笔记本电脑]', 684 | 682: 'obelisk [方尖碑]', 685 | 683: 'oboe, hautboy, hautbois [双簧管]', 686 | 684: 'ocarina, sweet potato [陶笛,卵形笛]', 687 | 685: 'odometer, hodometer, mileometer, milometer [里程表]', 688 | 686: 'oil filter [滤油器]', 689 | 687: 'organ, pipe organ [风琴,管风琴]', 690 | 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]', 691 | 689: 'overskirt [罩裙]', 692 | 690: 'oxcart [牛车]', 693 | 691: 'oxygen mask [氧气面罩]', 694 | 692: 'packet [包装]', 695 | 693: 'paddle, boat paddle [船桨]', 696 | 694: 'paddlewheel, paddle wheel [明轮,桨轮]', 697 | 695: 'padlock [挂锁,扣锁]', 698 | 696: 'paintbrush [画笔]', 699 | 697: 'pajama, pyjama, pjs, jammies [睡衣]', 700 | 698: 'palace [宫殿]', 701 | 699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]', 702 | 700: 'paper towel [纸巾]', 703 | 701: 'parachute, chute [降落伞]', 704 | 702: 'parallel bars, bars [双杠]', 705 | 703: 'park bench [公园长椅]', 706 | 704: 'parking meter [停车收费表,停车计时器]', 707 | 705: 'passenger car, coach, carriage [客车,教练车]', 708 | 706: 'patio, terrace [露台,阳台]', 709 | 707: 'pay-phone, pay-station [付费电话]', 710 | 708: 'pedestal, plinth, footstall [基座,基脚]', 711 | 709: 'pencil box, pencil case [铅笔盒]', 712 | 710: 'pencil sharpener [卷笔刀]', 713 | 711: 'perfume, essence [香水(瓶)]', 714 | 712: 'Petri dish [培养皿]', 715 | 713: 'photocopier [复印机]', 716 | 714: 'pick, plectrum, plectron [拨弦片,拨子]', 717 | 715: 'pickelhaube [尖顶头盔]', 718 | 716: 'picket fence, paling [栅栏,栅栏]', 719 | 717: 'pickup, pickup truck [皮卡,皮卡车]', 720 | 718: 'pier [桥墩]', 721 | 719: 'piggy bank, penny bank [存钱罐]', 722 | 720: 'pill bottle [药瓶]', 723 | 721: 'pillow [枕头]', 724 | 722: 'ping-pong ball [乒乓球]', 725 | 723: 'pinwheel [风车]', 726 | 724: 'pirate, pirate ship [海盗船]', 727 | 725: 'pitcher, ewer [水罐]', 728 | 726: 'plane, carpenters plane, woodworking plane [木工刨]', 729 | 727: 'planetarium [天文馆]', 730 | 728: 'plastic bag [塑料袋]', 731 | 729: 'plate rack [板架]', 732 | 730: 'plow, plough [犁型铲雪机]', 733 | 731: 'plunger, plumbers helper [手压皮碗泵]', 734 | 732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]', 735 | 733: 'pole [电线杆]', 736 | 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]', 737 | 735: 'poncho [雨披]', 738 | 736: 'pool table, billiard table, snooker table [台球桌]', 739 | 737: 'pop bottle, soda bottle [充气饮料瓶]', 740 | 738: 'pot, flowerpot [花盆]', 741 | 739: 'potters wheel [陶工旋盘]', 742 | 740: 'power drill [电钻]', 743 | 741: 'prayer rug, prayer mat [祈祷垫,地毯]', 744 | 742: 'printer [打印机]', 745 | 743: 'prison, prison house [监狱]', 746 | 744: 'projectile, missile [炮弹,导弹]', 747 | 745: 'projector [投影仪]', 748 | 746: 'puck, hockey puck [冰球]', 749 | 747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]', 750 | 748: 'purse [钱包]', 751 | 749: 'quill, quill pen [羽管笔]', 752 | 750: 'quilt, comforter, comfort, puff [被子]', 753 | 751: 'racer, race car, racing car [赛车]', 754 | 752: 'racket, 
racquet [球拍]', 755 | 753: 'radiator [散热器]', 756 | 754: 'radio, wireless [收音机]', 757 | 755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]', 758 | 756: 'rain barrel [雨桶]', 759 | 757: 'recreational vehicle, RV, R.V. [休闲车,房车]', 760 | 758: 'reel [卷轴,卷筒]', 761 | 759: 'reflex camera [反射式照相机]', 762 | 760: 'refrigerator, icebox [冰箱,冰柜]', 763 | 761: 'remote control, remote [遥控器]', 764 | 762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]', 765 | 763: 'revolver, six-gun, six-shooter [左轮手枪]', 766 | 764: 'rifle [步枪]', 767 | 765: 'rocking chair, rocker [摇椅]', 768 | 766: 'rotisserie [电转烤肉架]', 769 | 767: 'rubber eraser, rubber, pencil eraser [橡皮]', 770 | 768: 'rugby ball [橄榄球]', 771 | 769: 'rule, ruler [直尺]', 772 | 770: 'running shoe [跑步鞋]', 773 | 771: 'safe [保险柜]', 774 | 772: 'safety pin [安全别针]', 775 | 773: 'saltshaker, salt shaker [盐瓶(调味用)]', 776 | 774: 'sandal [凉鞋]', 777 | 775: 'sarong [纱笼,围裙]', 778 | 776: 'sax, saxophone [萨克斯管]', 779 | 777: 'scabbard [剑鞘]', 780 | 778: 'scale, weighing machine [秤,称重机]', 781 | 779: 'school bus [校车]', 782 | 780: 'schooner [帆船]', 783 | 781: 'scoreboard [记分牌]', 784 | 782: 'screen, CRT screen [屏幕]', 785 | 783: 'screw [螺丝]', 786 | 784: 'screwdriver [螺丝刀]', 787 | 785: 'seat belt, seatbelt [安全带]', 788 | 786: 'sewing machine [缝纫机]', 789 | 787: 'shield, buckler [盾牌,盾牌]', 790 | 788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]', 791 | 789: 'shoji [障子]', 792 | 790: 'shopping basket [购物篮]', 793 | 791: 'shopping cart [购物车]', 794 | 792: 'shovel [铁锹]', 795 | 793: 'shower cap [浴帽]', 796 | 794: 'shower curtain [浴帘]', 797 | 795: 'ski [滑雪板]', 798 | 796: 'ski mask [滑雪面罩]', 799 | 797: 'sleeping bag [睡袋]', 800 | 798: 'slide rule, slipstick [滑尺]', 801 | 799: 'sliding door [滑动门]', 802 | 800: 'slot, one-armed bandit [角子老虎机]', 803 | 801: 'snorkel [潜水通气管]', 804 | 802: 'snowmobile [雪橇]', 805 | 803: 'snowplow, snowplough [扫雪机,扫雪机]', 806 | 804: 'soap dispenser [皂液器]', 807 | 805: 'soccer ball [足球]', 808 | 806: 'sock [袜子]', 809 | 807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]', 810 | 808: 'sombrero [宽边帽]', 811 | 809: 'soup bowl [汤碗]', 812 | 810: 'space bar [空格键]', 813 | 811: 'space heater [空间加热器]', 814 | 812: 'space shuttle [航天飞机]', 815 | 813: 'spatula [铲(搅拌或涂敷用的)]', 816 | 814: 'speedboat [快艇]', 817 | 815: 'spider web, spiders web [蜘蛛网]', 818 | 816: 'spindle [纺锤,纱锭]', 819 | 817: 'sports car, sport car [跑车]', 820 | 818: 'spotlight, spot [聚光灯]', 821 | 819: 'stage [舞台]', 822 | 820: 'steam locomotive [蒸汽机车]', 823 | 821: 'steel arch bridge [钢拱桥]', 824 | 822: 'steel drum [钢滚筒]', 825 | 823: 'stethoscope [听诊器]', 826 | 824: 'stole [女用披肩]', 827 | 825: 'stone wall [石头墙]', 828 | 826: 'stopwatch, stop watch [秒表]', 829 | 827: 'stove [火炉]', 830 | 828: 'strainer [过滤器]', 831 | 829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]', 832 | 830: 'stretcher [担架]', 833 | 831: 'studio couch, day bed [沙发床]', 834 | 832: 'stupa, tope [佛塔]', 835 | 833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]', 836 | 834: 'suit, suit of clothes [套装,衣服]', 837 | 835: 'sundial [日晷]', 838 | 836: 'sunglass [太阳镜]', 839 | 837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]', 840 | 838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]', 841 | 839: 'suspension bridge [悬索桥]', 842 | 840: 'swab, swob, mop [拖把]', 843 | 841: 'sweatshirt [运动衫]', 844 | 842: 'swimming trunks, bathing trunks [游泳裤]', 845 | 843: 'swing [秋千]', 846 | 844: 'switch, electric switch, electrical switch [开关,电器开关]', 847 | 845: 'syringe [注射器]', 848 | 846: 'table lamp [台灯]', 849 | 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle 
[坦克,装甲战车,装甲战斗车辆]', 850 | 848: 'tape player [磁带播放器]', 851 | 849: 'teapot [茶壶]', 852 | 850: 'teddy, teddy bear [泰迪,泰迪熊]', 853 | 851: 'television, television system [电视]', 854 | 852: 'tennis ball [网球]', 855 | 853: 'thatch, thatched roof [茅草,茅草屋顶]', 856 | 854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]', 857 | 855: 'thimble [顶针]', 858 | 856: 'thresher, thrasher, threshing machine [脱粒机]', 859 | 857: 'throne [宝座]', 860 | 858: 'tile roof [瓦屋顶]', 861 | 859: 'toaster [烤面包机]', 862 | 860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]', 863 | 861: 'toilet seat [马桶]', 864 | 862: 'torch [火炬]', 865 | 863: 'totem pole [图腾柱]', 866 | 864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]', 867 | 865: 'toyshop [玩具店]', 868 | 866: 'tractor [拖拉机]', 869 | 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]', 870 | 868: 'tray [托盘]', 871 | 869: 'trench coat [风衣]', 872 | 870: 'tricycle, trike, velocipede [三轮车]', 873 | 871: 'trimaran [三体船]', 874 | 872: 'tripod [三脚架]', 875 | 873: 'triumphal arch [凯旋门]', 876 | 874: 'trolleybus, trolley coach, trackless trolley [无轨电车]', 877 | 875: 'trombone [长号]', 878 | 876: 'tub, vat [浴盆,浴缸]', 879 | 877: 'turnstile [旋转式栅门]', 880 | 878: 'typewriter keyboard [打字机键盘]', 881 | 879: 'umbrella [伞]', 882 | 880: 'unicycle, monocycle [独轮车]', 883 | 881: 'upright, upright piano [直立式钢琴]', 884 | 882: 'vacuum, vacuum cleaner [真空吸尘器]', 885 | 883: 'vase [花瓶]', 886 | 884: 'vault [拱顶]', 887 | 885: 'velvet [天鹅绒]', 888 | 886: 'vending machine [自动售货机]', 889 | 887: 'vestment [祭服]', 890 | 888: 'viaduct [高架桥]', 891 | 889: 'violin, fiddle [小提琴,小提琴]', 892 | 890: 'volleyball [排球]', 893 | 891: 'waffle iron [松饼机]', 894 | 892: 'wall clock [挂钟]', 895 | 893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]', 896 | 894: 'wardrobe, closet, press [衣柜,壁橱]', 897 | 895: 'warplane, military plane [军用飞机]', 898 | 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]', 899 | 897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]', 900 | 898: 'water bottle [水瓶]', 901 | 899: 'water jug [水壶]', 902 | 900: 'water tower [水塔]', 903 | 901: 'whiskey jug [威士忌壶]', 904 | 902: 'whistle [哨子]', 905 | 903: 'wig [假发]', 906 | 904: 'window screen [纱窗]', 907 | 905: 'window shade [百叶窗]', 908 | 906: 'Windsor tie [温莎领带]', 909 | 907: 'wine bottle [葡萄酒瓶]', 910 | 908: 'wing [飞机翅膀,飞机]', 911 | 909: 'wok [炒菜锅]', 912 | 910: 'wooden spoon [木制的勺子]', 913 | 911: 'wool, woolen, woollen [毛织品,羊绒]', 914 | 912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]', 915 | 913: 'wreck [沉船]', 916 | 914: 'yawl [双桅船]', 917 | 915: 'yurt [蒙古包]', 918 | 916: 'web site, website, internet site, site [网站,互联网网站]', 919 | 917: 'comic book [漫画]', 920 | 918: 'crossword puzzle, crossword [纵横字谜]', 921 | 919: 'street sign [路标]', 922 | 920: 'traffic light, traffic signal, stoplight [交通信号灯]', 923 | 921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]', 924 | 922: 'menu [菜单]', 925 | 923: 'plate [盘子]', 926 | 924: 'guacamole [鳄梨酱]', 927 | 925: 'consomme [清汤]', 928 | 926: 'hot pot, hotpot [罐焖土豆烧肉]', 929 | 927: 'trifle [蛋糕]', 930 | 928: 'ice cream, icecream [冰淇淋]', 931 | 929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]', 932 | 930: 'French loaf [法式面包]', 933 | 931: 'bagel, beigel [百吉饼]', 934 | 932: 'pretzel [椒盐脆饼]', 935 | 933: 'cheeseburger [芝士汉堡]', 936 | 934: 'hotdog, hot dog, red hot [热狗]', 937 | 935: 'mashed potato [土豆泥]', 938 | 936: 'head cabbage [结球甘蓝]', 939 | 937: 'broccoli [西兰花]', 940 | 938: 'cauliflower [菜花]', 941 | 939: 'zucchini, courgette [绿皮密生西葫芦]', 942 | 940: 'spaghetti squash [西葫芦]', 
943 | 941: 'acorn squash [小青南瓜]', 944 | 942: 'butternut squash [南瓜]', 945 | 943: 'cucumber, cuke [黄瓜]', 946 | 944: 'artichoke, globe artichoke [朝鲜蓟]', 947 | 945: 'bell pepper [甜椒]', 948 | 946: 'cardoon [刺棘蓟]', 949 | 947: 'mushroom [蘑菇]', 950 | 948: 'Granny Smith [绿苹果]', 951 | 949: 'strawberry [草莓]', 952 | 950: 'orange [橘子]', 953 | 951: 'lemon [柠檬]', 954 | 952: 'fig [无花果]', 955 | 953: 'pineapple, ananas [菠萝]', 956 | 954: 'banana [香蕉]', 957 | 955: 'jackfruit, jak, jack [菠萝蜜]', 958 | 956: 'custard apple [蛋奶冻苹果]', 959 | 957: 'pomegranate [石榴]', 960 | 958: 'hay [干草]', 961 | 959: 'carbonara [烤面条加干酪沙司]', 962 | 960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]', 963 | 961: 'dough [面团]', 964 | 962: 'meat loaf, meatloaf [瑞士肉包,肉饼]', 965 | 963: 'pizza, pizza pie [披萨,披萨饼]', 966 | 964: 'potpie [馅饼]', 967 | 965: 'burrito [卷饼]', 968 | 966: 'red wine [红葡萄酒]', 969 | 967: 'espresso [意大利浓咖啡]', 970 | 968: 'cup [杯子]', 971 | 969: 'eggnog [蛋酒]', 972 | 970: 'alp [高山]', 973 | 971: 'bubble [泡泡]', 974 | 972: 'cliff, drop, drop-off [悬崖]', 975 | 973: 'coral reef [珊瑚礁]', 976 | 974: 'geyser [间歇泉]', 977 | 975: 'lakeside, lakeshore [湖边,湖岸]', 978 | 976: 'promontory, headland, head, foreland [海角]', 979 | 977: 'sandbar, sand bar [沙洲,沙坝]', 980 | 978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]', 981 | 979: 'valley, vale [峡谷]', 982 | 980: 'volcano [火山]', 983 | 981: 'ballplayer, baseball player [棒球,棒球运动员]', 984 | 982: 'groom, bridegroom [新郎]', 985 | 983: 'scuba diver [潜水员]', 986 | 984: 'rapeseed [油菜]', 987 | 985: 'daisy [雏菊]', 988 | 986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]', 989 | 987: 'corn [玉米]', 990 | 988: 'acorn [橡子]', 991 | 989: 'hip, rose hip, rosehip [玫瑰果]', 992 | 990: 'buckeye, horse chestnut, conker [七叶树果实]', 993 | 991: 'coral fungus [珊瑚菌]', 994 | 992: 'agaric [木耳]', 995 | 993: 'gyromitra [鹿花菌]', 996 | 994: 'stinkhorn, carrion fungus [鬼笔菌]', 997 | 995: 'earthstar [地星(菌类)]', 998 | 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]', 999 | 997: 'bolete [牛肝菌]', 1000 | 998: 'ear, spike, capitulum [玉米穗]', 1001 | 999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]', 1002 | } 1003 | -------------------------------------------------------------------------------- /pixelflow/data_in1k.py: -------------------------------------------------------------------------------- 1 | # ImageNet-1K Dataset and DataLoader 2 | 3 | from einops import rearrange 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torchvision.datasets import ImageFolder 8 | from torchvision import transforms 9 | from PIL import Image 10 | import math 11 | from functools import partial 12 | import numpy as np 13 | import random 14 | 15 | from diffusers.models.embeddings import get_2d_rotary_pos_embed 16 | 17 | # https://github.com/facebookresearch/DiT/blob/main/train.py#L85 18 | def center_crop_arr(pil_image, image_size): 19 | while min(*pil_image.size) >= 2 * image_size: 20 | pil_image = pil_image.resize( 21 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 22 | ) 23 | 24 | scale = image_size / min(*pil_image.size) 25 | pil_image = pil_image.resize( 26 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 27 | ) 28 | 29 | arr = np.array(pil_image) 30 | crop_y = (arr.shape[0] - image_size) // 2 31 | crop_x = (arr.shape[1] - image_size) // 2 32 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 33 | 34 | 35 | def 
collate_fn(examples, config, noise_scheduler_copy): 36 | patch_size = config.model.params.patch_size 37 | pixel_values = torch.stack([eg[0] for eg in examples]) 38 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 39 | input_ids = [eg[1] for eg in examples] 40 | 41 | batch_size = len(examples) 42 | stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1) 43 | stage_indices = stage_indices[:batch_size] 44 | 45 | random.shuffle(stage_indices) 46 | stage_indices = torch.tensor(stage_indices, dtype=torch.int32) 47 | orig_height, orig_width = pixel_values.shape[-2:] 48 | timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,)) 49 | 50 | sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], [] 51 | for stage_idx in range(config.scheduler.num_stages): 52 | corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1  # stage_idx counts 2x downsamplings; corrected_stage_idx is the scheduler stage (largest = full resolution) 53 | stage_select_indices = timesteps[stage_indices == corrected_stage_idx] 54 | Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float() 55 | batch_size_select = Timesteps.shape[0] 56 | pixel_values_select = pixel_values[stage_indices == corrected_stage_idx] 57 | input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx] 58 | 59 | end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx) 60 | 61 | ################ build model input ################ 62 | start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx] 63 | 64 | pixel_values_end = pixel_values_select 65 | pixel_values_start = pixel_values_select 66 | if stage_idx > 0: 67 | # pixel_values_end: the clean image downsampled to this stage's resolution 68 | for downsample_idx in range(1, stage_idx + 1): 69 | pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear") 70 | 71 | # pixel_values_start: one extra 2x downsample, a proxy for the previous stage's output 72 | for downsample_idx in range(1, stage_idx + 2): 73 | pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear") 74 | # upsample pixel_values_start back to this stage's resolution 75 | pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest") 76 | 77 | noise = torch.randn_like(pixel_values_end) 78 | pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise 79 | pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise 80 | target = pixel_values_end - pixel_values_start  # flow-matching velocity target for this stage 81 | 82 | t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten() 83 | while len(t_select.shape) < pixel_values_start.ndim: 84 | t_select = t_select.unsqueeze(-1) 85 | xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start 86 | 87 | target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size) 88 | xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size) 89 | 90 | pos_embed = get_2d_rotary_pos_embed( 91 | embed_dim=config.model.params.attention_head_dim, 92 | crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)), 93 | grid_size=(end_height // patch_size, end_width // patch_size), 94 | ) 95 | seq_len = (end_height // patch_size) * (end_width // patch_size) 96 | assert end_height == end_width, f"only square images are supported, got {end_height}x{end_width}; TODO: latent_size_list" 97 | sample_list.append(xt) 98 | target_list.append(target) 99 | pos_embed_list.extend([pos_embed] * batch_size_select) 100 | seq_len_list.extend([seq_len] * batch_size_select) 101 | timestep_list.append(Timesteps) 102 | input_ids_list.extend(input_ids_select) 103 | 104 | pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format) 105 | target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format) 106 | pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float() 107 | cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32) 108 | latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32) 109 | 110 | return { 111 | "pixel_values": pixel_values, 112 | "input_ids": input_ids_list, 113 | "pos_embed": pos_embed, 114 | "cumsum_q_len": cumsum_q_len, 115 | "batch_latent_size": latent_size_list, 116 | "seqlen_list_q": seq_len_list, 117 | "cumsum_kv_len": None, 118 | "batch_kv_len": None, 119 | "timesteps": torch.cat(timestep_list, dim=0), 120 | "target_values": target_values, 121 | } 122 | 123 | 124 | def build_imagenet_loader(config, noise_scheduler_copy): 125 | if config.data.center_crop: 126 | transform = transforms.Compose([ 127 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)), 128 | transforms.RandomHorizontalFlip(), 129 | transforms.ToTensor(), 130 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 131 | ]) 132 | else: 133 | transform = transforms.Compose([ 134 | transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS), 135 | transforms.RandomCrop(config.data.resolution), 136 | transforms.RandomHorizontalFlip(), 137 | transforms.ToTensor(), 138 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 139 | ]) 140 | dataset = ImageFolder(config.data.root, transform=transform) 141 | sampler = DistributedSampler( 142 | dataset, 143 | num_replicas=torch.distributed.get_world_size(), 144 | rank=torch.distributed.get_rank(), 145 | shuffle=True, 146 | seed=config.seed, 147 | ) 148 | 149 | loader = torch.utils.data.DataLoader( 150 | dataset, 151 | batch_size=config.data.batch_size, 152 | collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy), 153 | shuffle=False, 154 | sampler=sampler, 155 | num_workers=config.data.num_workers, 156 | drop_last=True, 157 | ) 158 | return loader, sampler 159 | -------------------------------------------------------------------------------- /pixelflow/model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding 7 | import warnings 8 | 9 | try: 10 | from flash_attn import flash_attn_varlen_func 11 | except ImportError: 12 | warnings.warn("`flash-attn` is not installed.
Training will not work without it; inference uses PyTorch SDPA and is unaffected.", UserWarning) 13 | flash_attn_varlen_func = None 14 | 15 | 16 | def apply_rotary_emb( 17 | x: torch.Tensor, 18 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], 19 | ) -> torch.Tensor: 20 | cos, sin = freqs_cis.unbind(-1) 21 | cos = cos[None, None] 22 | sin = sin[None, None] 23 | cos, sin = cos.to(x.device), sin.to(x.device) 24 | 25 | x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] 26 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)  # pair-wise rotation: (x1, x2) -> (-x2, x1) 27 | out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) 28 | 29 | return out 30 | 31 | 32 | class PatchEmbed(nn.Module): 33 | def __init__(self, patch_size, in_channels, embed_dim, bias=True): 34 | super().__init__() 35 | self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) 36 | 37 | def forward_unfold(self, x): 38 | out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t()) 39 | if self.proj.bias is not None: 40 | out_unfold += self.proj.bias.to(out_unfold.dtype) 41 | return out_unfold 42 | 43 | # force fp32 for strict numerical reproducibility (debug only) 44 | # @torch.autocast('cuda', enabled=False) 45 | def forward(self, x): 46 | if self.training: 47 | return self.forward_unfold(x)  # training inputs arrive pre-patchified as (tokens, c * p * p) 48 | out = self.proj(x) 49 | out = out.flatten(2).transpose(1, 2) # BCHW -> BNC 50 | 51 | return out 52 | 53 | class AdaLayerNorm(nn.Module): 54 | def __init__(self, embedding_dim): 55 | super().__init__() 56 | self.embedding_dim = embedding_dim 57 | self.silu = nn.SiLU() 58 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 59 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 60 | 61 | def forward(self, x, timestep, seqlen_list=None): 62 | input_dtype = x.dtype 63 | emb = self.linear(self.silu(timestep)) 64 | 65 | if seqlen_list is not None: 66 | # equivalent to `torch.repeat_interleave` but faster 67 | emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)]) 68 | else: 69 | emb = emb.unsqueeze(1) 70 | 71 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1) 72 | x = self.norm(x).float() * (1 + scale_msa) + shift_msa 73 | return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp 74 | 75 | 76 | class FeedForward(nn.Module): 77 | def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True): 78 | super().__init__() 79 | inner_dim = int(dim * mult) if inner_dim is None else inner_dim 80 | dim_out = dim_out if dim_out is not None else dim 81 | self.fc1 = nn.Linear(dim, inner_dim, bias=bias) 82 | self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias) 83 | 84 | def forward(self, hidden_states): 85 | hidden_states = self.fc1(hidden_states) 86 | hidden_states = F.gelu(hidden_states, approximate="tanh") 87 | hidden_states = self.fc2(hidden_states) 88 | return hidden_states 89 | 90 | 91 | class RMSNorm(nn.Module): 92 | def __init__(self, dim: int, eps=1e-6): 93 | super().__init__() 94 | self.weight = nn.Parameter(torch.ones(dim)) 95 | self.eps = eps 96 | 97 | def forward(self, x): 98 | output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) 99 | return (self.weight * output).to(x.dtype) 100 | 101 | 102 | class Attention(nn.Module): 103 | def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False): 104 | super().__init__() 105 | self.q_dim = q_dim 106 | self.kv_dim = kv_dim if kv_dim is not None else q_dim 107 |
self.inner_dim = head_dim * heads 108 | self.dropout = dropout 109 | self.head_dim = head_dim 110 | self.num_heads = heads 111 | 112 | self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias) 113 | self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias) 114 | self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias) 115 | 116 | self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias) 117 | 118 | self.q_norm = RMSNorm(self.inner_dim) 119 | self.k_norm = RMSNorm(self.inner_dim) 120 | 121 | def prepare_attention_mask( 122 | # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694 123 | self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 124 | ): 125 | head_size = self.num_heads 126 | if attention_mask is None: 127 | return attention_mask 128 | 129 | current_length: int = attention_mask.shape[-1] 130 | if current_length != target_length: 131 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 132 | 133 | if out_dim == 3: 134 | if attention_mask.shape[0] < batch_size * head_size: 135 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0) 136 | elif out_dim == 4: 137 | attention_mask = attention_mask.unsqueeze(1) 138 | attention_mask = attention_mask.repeat_interleave(head_size, dim=1) 139 | 140 | return attention_mask 141 | 142 | def forward( 143 | self, 144 | inputs_q, 145 | inputs_kv, 146 | attention_mask=None, 147 | cross_attention=False, 148 | rope_pos_embed=None, 149 | cu_seqlens_q=None, 150 | cu_seqlens_k=None, 151 | max_seqlen_q=None, 152 | max_seqlen_k=None, 153 | ): 154 | 155 | inputs_kv = inputs_q if inputs_kv is None else inputs_kv 156 | 157 | query_states = self.q_proj(inputs_q) 158 | key_states = self.k_proj(inputs_kv) 159 | value_states = self.v_proj(inputs_kv) 160 | 161 | query_states = self.q_norm(query_states) 162 | key_states = self.k_norm(key_states) 163 | 164 | if max_seqlen_q is None: 165 | assert not self.training, "PixelFlow needs sequence packing for training" 166 | 167 | bsz, q_len, _ = inputs_q.shape 168 | _, kv_len, _ = inputs_kv.shape 169 | 170 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 171 | key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) 172 | value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) 173 | 174 | query_states = apply_rotary_emb(query_states, rope_pos_embed) 175 | if not cross_attention: 176 | key_states = apply_rotary_emb(key_states, rope_pos_embed) 177 | 178 | if attention_mask is not None: 179 | attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz) 180 | # scaled_dot_product_attention expects attention_mask shape to be 181 | # (batch, heads, source_length, target_length) 182 | attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1]) 183 | 184 | # with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]): # strict numerical reproducibility (debug only) 185 | attn_output = F.scaled_dot_product_attention( 186 | query_states, 187 | key_states, 188 | value_states, 189 | attn_mask=attention_mask, 190 | dropout_p=self.dropout if self.training else 0.0, 191 | is_causal=False, 192 | ) 193 | 194 | attn_output = attn_output.transpose(1, 2).contiguous() 195 | attn_output = attn_output.view(bsz, q_len, self.inner_dim) 196 | attn_output = self.o_proj(attn_output) 197 | return attn_output 198 | 199 | else: 200 | # 
sequence packing mode 201 | query_states = query_states.view(-1, self.num_heads, self.head_dim) 202 | key_states = key_states.view(-1, self.num_heads, self.head_dim) 203 | value_states = value_states.view(-1, self.num_heads, self.head_dim) 204 | 205 | query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2) 206 | if not cross_attention: 207 | key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2) 208 | 209 | attn_output = flash_attn_varlen_func( 210 | query_states, 211 | key_states, 212 | value_states, 213 | cu_seqlens_q=cu_seqlens_q, 214 | cu_seqlens_k=cu_seqlens_k, 215 | max_seqlen_q=max_seqlen_q, 216 | max_seqlen_k=max_seqlen_k, 217 | ) 218 | 219 | attn_output = attn_output.view(-1, self.num_heads * self.head_dim) 220 | attn_output = self.o_proj(attn_output) 221 | return attn_output 222 | 223 | 224 | class TransformerBlock(nn.Module): 225 | def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0, 226 | cross_attention_dim=None, attention_bias=False, 227 | ): 228 | super().__init__() 229 | self.norm1 = AdaLayerNorm(dim) 230 | 231 | # Self Attention 232 | self.attn1 = Attention(q_dim=dim, kv_dim=None, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias) 233 | 234 | if cross_attention_dim is not None: 235 | # Cross Attention 236 | self.norm2 = RMSNorm(dim, eps=1e-6) 237 | self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias) 238 | else: 239 | self.attn2 = None 240 | 241 | self.norm3 = RMSNorm(dim, eps=1e-6) 242 | self.mlp = FeedForward(dim) 243 | 244 | def forward( 245 | self, 246 | hidden_states, 247 | encoder_hidden_states=None, 248 | encoder_attention_mask=None, 249 | timestep=None, 250 | rope_pos_embed=None, 251 | cu_seqlens_q=None, 252 | cu_seqlens_k=None, 253 | seqlen_list_q=None, 254 | seqlen_list_k=None, 255 | ): 256 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q) 257 | 258 | attn_output = self.attn1( 259 | inputs_q=norm_hidden_states, 260 | inputs_kv=None, 261 | attention_mask=None, 262 | cross_attention=False, 263 | rope_pos_embed=rope_pos_embed, 264 | cu_seqlens_q=cu_seqlens_q, 265 | cu_seqlens_k=cu_seqlens_q, 266 | max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None, 267 | max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None, 268 | ) 269 | 270 | attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype) 271 | hidden_states = attn_output + hidden_states 272 | 273 | if self.attn2 is not None: 274 | norm_hidden_states = self.norm2(hidden_states) 275 | attn_output = self.attn2( 276 | inputs_q=norm_hidden_states, 277 | inputs_kv=encoder_hidden_states, 278 | attention_mask=encoder_attention_mask, 279 | cross_attention=True, 280 | rope_pos_embed=rope_pos_embed, 281 | cu_seqlens_q=cu_seqlens_q, 282 | cu_seqlens_k=cu_seqlens_k, 283 | max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None, 284 | max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None, 285 | ) 286 | hidden_states = hidden_states + attn_output 287 | 288 | norm_hidden_states = self.norm3(hidden_states) 289 | norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype) 290 | ff_output = self.mlp(norm_hidden_states) 291 | ff_output = (gate_mlp * 
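# adaLN-Zero note (editorial comment; placed here inside the parentheses, where Python allows it):
# gate_mlp, like gate_msa above, is produced by norm1.linear, which is zero-initialized in
# PixelFlowModel.initialize_from_scratch, so every residual branch starts closed and each block begins as an identity map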
ff_output.float()).to(ff_output.dtype) 292 | hidden_states = ff_output + hidden_states 293 | 294 | return hidden_states 295 | 296 | 297 | class PixelFlowModel(torch.nn.Module): 298 | def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim, 299 | depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0, 300 | ): 301 | super().__init__() 302 | self.patch_size = patch_size 303 | self.attention_head_dim = attention_head_dim 304 | self.num_classes = num_classes 305 | self.out_channels = out_channels 306 | 307 | embed_dim = num_attention_heads * attention_head_dim 308 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim) 309 | 310 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) 311 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim) 312 | 313 | # [stage] embedding 314 | self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim) 315 | if self.num_classes > 0: 316 | # class conditional 317 | self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1) 318 | 319 | self.transformer_blocks = nn.ModuleList([ 320 | TransformerBlock(embed_dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, attention_bias) for _ in range(depth) 321 | ]) 322 | 323 | self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 324 | self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim) 325 | self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels) 326 | 327 | self.initialize_from_scratch() 328 | 329 | def initialize_from_scratch(self): 330 | print("Starting Initialization...") 331 | def _basic_init(module): 332 | if isinstance(module, nn.Linear): 333 | torch.nn.init.xavier_uniform_(module.weight) 334 | if module.bias is not None: 335 | nn.init.constant_(module.bias, 0) 336 | self.apply(_basic_init) 337 | 338 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 339 | w = self.patch_embed.proj.weight.data 340 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 341 | nn.init.constant_(self.patch_embed.proj.bias, 0) 342 | 343 | nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02) 344 | nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02) 345 | 346 | nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02) 347 | nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02) 348 | 349 | if self.num_classes > 0: 350 | nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02) 351 | 352 | for block in self.transformer_blocks: 353 | nn.init.constant_(block.norm1.linear.weight, 0) 354 | nn.init.constant_(block.norm1.linear.bias, 0) 355 | 356 | nn.init.constant_(self.proj_out_1.weight, 0) 357 | nn.init.constant_(self.proj_out_1.bias, 0) 358 | nn.init.constant_(self.proj_out_2.weight, 0) 359 | nn.init.constant_(self.proj_out_2.bias, 0) 360 | 361 | def forward( 362 | self, 363 | hidden_states, 364 | encoder_hidden_states=None, 365 | class_labels=None, 366 | timestep=None, 367 | latent_size=None, 368 | encoder_attention_mask=None, 369 | pos_embed=None, 370 | cu_seqlens_q=None, 371 | cu_seqlens_k=None, 372 | seqlen_list_q=None, 373 | seqlen_list_k=None, 374 | ): 375 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 376 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 377 | encoder_attention_mask = (1 - 
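# masked (0) positions become -10000.0, effectively -inf under the attention softmax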
encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 378 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 379 | 380 | orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1] 381 | hidden_states = hidden_states.to(torch.float32) 382 | hidden_states = self.patch_embed(hidden_states) 383 | 384 | # timestep, class_embed, latent_size_embed 385 | timesteps_proj = self.time_proj(timestep) 386 | conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) 387 | 388 | if self.num_classes > 0: 389 | class_embed = self.class_embedder(class_labels) 390 | conditioning += class_embed 391 | 392 | latent_size_proj = self.time_proj(latent_size) 393 | latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype)) 394 | conditioning += latent_size_embed 395 | 396 | for block in self.transformer_blocks: 397 | hidden_states = block( 398 | hidden_states, 399 | encoder_hidden_states=encoder_hidden_states, 400 | encoder_attention_mask=encoder_attention_mask, 401 | timestep=conditioning, 402 | rope_pos_embed=pos_embed, 403 | cu_seqlens_q=cu_seqlens_q, 404 | cu_seqlens_k=cu_seqlens_k, 405 | seqlen_list_q=seqlen_list_q, 406 | seqlen_list_k=seqlen_list_k, 407 | ) 408 | 409 | shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1) 410 | if seqlen_list_q is None: 411 | shift = shift.unsqueeze(1) 412 | scale = scale.unsqueeze(1) 413 | else: 414 | shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)]) 415 | scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)]) 416 | 417 | hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype) 418 | hidden_states = self.proj_out_2(hidden_states) 419 | if self.training: 420 | hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels) 421 | hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1) 422 | return hidden_states 423 | 424 | height, width = orig_height // self.patch_size, orig_width // self.patch_size 425 | hidden_states = hidden_states.reshape( 426 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 427 | ) 428 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 429 | output = hidden_states.reshape( 430 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 431 | ) 432 | 433 | return output 434 | 435 | def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale): 436 | # used for evaluation with ODE ('dopri5') solver from torchdiffeq 437 | half = hidden_states[: len(hidden_states)//2] 438 | combined = torch.cat([half, half], dim=0) 439 | out = self.forward( 440 | hidden_states=combined, 441 | timestep=timestep, 442 | class_labels=class_labels, 443 | latent_size=latent_size, 444 | pos_embed=pos_embed, 445 | ) 446 | uncond_out, cond_out = torch.split(out, len(out)//2, dim=0) 447 | half_output = uncond_out + cfg_scale * (cond_out - uncond_out) 448 | output = torch.cat([half_output, half_output], dim=0) 449 | return output 450 | -------------------------------------------------------------------------------- /pixelflow/pipeline_pixelflow.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | import math 3 | from typing import List, Optional, Union 4 | import time 5 | import torch 6 | import 
torch.nn.functional as F 7 | 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from diffusers.models.embeddings import get_2d_rotary_pos_embed 10 | 11 | 12 | class PixelFlowPipeline: 13 | def __init__( 14 | self, 15 | scheduler, 16 | transformer, 17 | text_encoder=None, 18 | tokenizer=None, 19 | max_token_length=512, 20 | ): 21 | super().__init__() 22 | self.class_cond = text_encoder is None or tokenizer is None 23 | self.scheduler = scheduler 24 | self.transformer = transformer 25 | self.patch_size = transformer.patch_size 26 | self.head_dim = transformer.attention_head_dim 27 | self.num_stages = scheduler.num_stages 28 | 29 | self.text_encoder = text_encoder 30 | self.tokenizer = tokenizer 31 | self.max_token_length = max_token_length 32 | 33 | @torch.autocast("cuda", enabled=False) 34 | def encode_prompt( 35 | self, 36 | prompt: Union[str, List[str]], 37 | device: Optional[torch.device] = None, 38 | num_images_per_prompt: int = 1, 39 | do_classifier_free_guidance: bool = True, 40 | negative_prompt: Union[str, List[str]] = "", 41 | prompt_embeds: Optional[torch.FloatTensor] = None, 42 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 43 | prompt_attention_mask: Optional[torch.FloatTensor] = None, 44 | negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, 45 | use_attention_mask: bool = False, 46 | max_length: int = 512, 47 | ): 48 | # Determine the batch size and normalize prompt input to a list 49 | if prompt is not None: 50 | if isinstance(prompt, str): 51 | prompt = [prompt] 52 | batch_size = len(prompt) 53 | else: 54 | batch_size = prompt_embeds.shape[0] 55 | 56 | # Process prompt embeddings if not provided 57 | if prompt_embeds is None: 58 | text_inputs = self.tokenizer( 59 | prompt, 60 | padding="max_length", 61 | max_length=max_length, 62 | truncation=True, 63 | add_special_tokens=True, 64 | return_tensors="pt", 65 | ) 66 | text_input_ids = text_inputs.input_ids.to(device) 67 | prompt_attention_mask = text_inputs.attention_mask.to(device) 68 | prompt_embeds = self.text_encoder( 69 | text_input_ids, 70 | attention_mask=prompt_attention_mask if use_attention_mask else None 71 | )[0] 72 | 73 | # Determine dtype from available encoder 74 | if self.text_encoder is not None: 75 | dtype = self.text_encoder.dtype 76 | elif self.transformer is not None: 77 | dtype = self.transformer.dtype 78 | else: 79 | dtype = None 80 | 81 | # Move prompt embeddings to desired dtype and device 82 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 83 | 84 | bs_embed, seq_len, _ = prompt_embeds.shape 85 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 86 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 87 | prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1) 88 | 89 | # Handle classifier-free guidance for negative prompts 90 | if do_classifier_free_guidance and negative_prompt_embeds is None: 91 | # Normalize negative prompt to list and validate length 92 | if isinstance(negative_prompt, str): 93 | uncond_tokens = [negative_prompt] * batch_size 94 | elif isinstance(negative_prompt, list): 95 | if len(negative_prompt) != batch_size: 96 | raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}") 97 | uncond_tokens = negative_prompt 98 | else: 99 | raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}") 100 | 101 | # Tokenize and 
encode negative prompts 102 | uncond_inputs = self.tokenizer( 103 | uncond_tokens, 104 | padding="max_length", 105 | max_length=prompt_embeds.shape[1], 106 | truncation=True, 107 | return_attention_mask=True, 108 | add_special_tokens=True, 109 | return_tensors="pt", 110 | ) 111 | negative_input_ids = uncond_inputs.input_ids.to(device) 112 | negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device) 113 | negative_prompt_embeds = self.text_encoder( 114 | negative_input_ids, 115 | attention_mask=negative_prompt_attention_mask if use_attention_mask else None 116 | )[0] 117 | 118 | if do_classifier_free_guidance: 119 | # Duplicate negative prompt embeddings and attention mask for each generation 120 | seq_len_neg = negative_prompt_embeds.shape[1] 121 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) 122 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 123 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1) 124 | negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1) 125 | else: 126 | negative_prompt_embeds = None 127 | negative_prompt_attention_mask = None 128 | 129 | # Concatenate negative and positive embeddings and their masks 130 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 131 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) 132 | 133 | return prompt_embeds, prompt_attention_mask 134 | 135 | def sample_block_noise(self, bs, ch, height, width, eps=1e-6): 136 | gamma = self.scheduler.gamma 137 | dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4)) 138 | block_number = bs * ch * (height // 2) * (width // 2) 139 | noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4] 140 | noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)',b=bs,c=ch,h=height//2,w=width//2,p=2,q=2) 141 | return noise 142 | 143 | @torch.no_grad() 144 | def __call__( 145 | self, 146 | prompt, 147 | height, 148 | width, 149 | num_inference_steps=30, 150 | guidance_scale=4.0, 151 | num_images_per_prompt=1, 152 | device=None, 153 | shift=1.0, 154 | use_ode_dopri5=False, 155 | ): 156 | if isinstance(num_inference_steps, int): 157 | num_inference_steps = [num_inference_steps] * self.num_stages 158 | 159 | if use_ode_dopri5: 160 | assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now" 161 | from pixelflow.solver_ode_wrapper import ODE 162 | sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample 163 | else: 164 | # default Euler 165 | sample_fn = None 166 | 167 | self._guidance_scale = guidance_scale 168 | batch_size = len(prompt) 169 | if self.class_cond: 170 | prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device) 171 | negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds) 172 | if self.do_classifier_free_guidance: 173 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 174 | else: 175 | prompt_embeds, prompt_attention_mask = self.encode_prompt( 176 | prompt, 177 | device, 178 | num_images_per_prompt, 179 | guidance_scale > 1, 180 | "", 181 | prompt_embeds=None, 182 | negative_prompt_embeds=None, 183 | use_attention_mask=True, 184 | max_length=self.max_token_length, 
185 | ) 186 | 187 | init_factor = 2 ** (self.num_stages - 1) 188 | height, width = height // init_factor, width // init_factor 189 | shape = (batch_size * num_images_per_prompt, 3, height, width) 190 | latents = randn_tensor(shape, device=device, dtype=torch.float32) 191 | 192 | for stage_idx in range(self.num_stages): 193 | stage_start = time.time() 194 | # Set the number of inference steps for the current stage 195 | self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift) 196 | Timesteps = self.scheduler.Timesteps 197 | 198 | if stage_idx > 0: 199 | height, width = height * 2, width * 2 200 | latents = F.interpolate(latents, size=(height, width), mode='nearest') 201 | original_start_t = self.scheduler.original_start_t[stage_idx] 202 | gamma = self.scheduler.gamma 203 | alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t) 204 | beta = alpha * (1 - original_start_t) / math.sqrt(- gamma) 205 | 206 | # bs, ch, height, width = latents.shape 207 | noise = self.sample_block_noise(*latents.shape) 208 | noise = noise.to(device=device, dtype=latents.dtype) 209 | latents = alpha * latents + beta * noise 210 | 211 | size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device) 212 | pos_embed = get_2d_rotary_pos_embed( 213 | embed_dim=self.head_dim, 214 | crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)), 215 | grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size), 216 | ) 217 | rope_pos = torch.stack(pos_embed, -1) 218 | 219 | if sample_fn is not None: 220 | # dopri5 221 | model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos) 222 | if stage_idx == 0: 223 | latents = torch.cat([latents] * 2) 224 | stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item() 225 | stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item() 226 | latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1] 227 | if stage_idx == self.num_stages - 1: 228 | latents = latents[:latents.shape[0] // 2] 229 | else: 230 | # euler 231 | for T in Timesteps: 232 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 233 | timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) 234 | if self.class_cond: 235 | noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos) 236 | else: 237 | encoder_hidden_states = prompt_embeds 238 | encoder_attention_mask = prompt_attention_mask 239 | 240 | noise_pred = self.transformer( 241 | latent_model_input, 242 | encoder_hidden_states=encoder_hidden_states, 243 | encoder_attention_mask=encoder_attention_mask, 244 | timestep=timestep, 245 | latent_size=size_tensor, 246 | pos_embed=rope_pos, 247 | ) 248 | 249 | if self.do_classifier_free_guidance: 250 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 251 | noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond) 252 | 253 | latents = self.scheduler.step(model_output=noise_pred, sample=latents) 254 | stage_end = time.time() 255 | 256 | samples = (latents / 2 + 0.5).clamp(0, 1) 257 | samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() 258 | return samples 259 | 260 | @property 
261 | def device(self): 262 | return next(self.transformer.parameters()).device 263 | 264 | @property 265 | def dtype(self): 266 | return next(self.transformer.parameters()).dtype 267 | 268 | def guidance_scale(self, step=None, stage_idx=None): 269 | if not self.class_cond: 270 | return self._guidance_scale 271 | scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1} 272 | return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1 273 | 274 | @property 275 | def do_classifier_free_guidance(self): 276 | return self._guidance_scale > 0 277 | -------------------------------------------------------------------------------- /pixelflow/scheduling_pixelflow.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def cal_rectify_ratio(start_t, gamma): 7 | return 1 / (math.sqrt(1 - (1 / gamma)) * (1 - start_t) + start_t) 8 | 9 | 10 | class PixelFlowScheduler: 11 | def __init__(self, num_train_timesteps, num_stages, gamma=-1 / 3): 12 | assert num_stages > 0, f"num_stages must be positive, got {num_stages}" 13 | self.num_stages = num_stages 14 | self.gamma = gamma 15 | 16 | self.Timesteps = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=torch.float32) 17 | 18 | self.t = self.Timesteps / num_train_timesteps # normalized time in [0, 1] 19 | 20 | self.stage_range = [x / num_stages for x in range(num_stages + 1)] 21 | 22 | self.original_start_t = dict() 23 | self.start_t, self.end_t = dict(), dict() 24 | self.t_window_per_stage = dict() 25 | self.Timesteps_per_stage = dict() 26 | stage_distance = list() 27 | 28 | # stage_idx = 0: min t, min resolution, most noisy 29 | # stage_idx = num_stages - 1 : max t, max resolution, most clear 30 | for stage_idx in range(num_stages): 31 | start_idx = max(int(num_train_timesteps * self.stage_range[stage_idx]), 0) 32 | end_idx = min(int(num_train_timesteps * self.stage_range[stage_idx + 1]), num_train_timesteps) 33 | 34 | start_t = self.t[start_idx].item() 35 | end_t = self.t[end_idx].item() if end_idx < num_train_timesteps else 1.0 36 | 37 | self.original_start_t[stage_idx] = start_t 38 | 39 | if stage_idx > 0: 40 | start_t *= cal_rectify_ratio(start_t, gamma) 41 | 42 | self.start_t[stage_idx] = start_t 43 | self.end_t[stage_idx] = end_t 44 | stage_distance.append(end_t - start_t) 45 | 46 | total_stage_distance = sum(stage_distance) 47 | t_within_stage = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float64)[:-1] 48 | 49 | for stage_idx in range(num_stages): 50 | start_ratio = 0.0 if stage_idx == 0 else sum(stage_distance[:stage_idx]) / total_stage_distance 51 | end_ratio = 1.0 if stage_idx == num_stages - 1 else sum(stage_distance[:stage_idx + 1]) / total_stage_distance 52 | 53 | Timestep_start = self.Timesteps[int(num_train_timesteps * start_ratio)] 54 | Timestep_end = self.Timesteps[min(int(num_train_timesteps * end_ratio), num_train_timesteps - 1)] 55 | 56 | self.t_window_per_stage[stage_idx] = t_within_stage 57 | 58 | if stage_idx == num_stages - 1: 59 | self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps, dtype=torch.float64) 60 | else: 61 | self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps + 1, dtype=torch.float64)[:-1] 62 | 63 | @staticmethod 64 | def time_linear_to_Timesteps(t, t_start, t_end, T_start, T_end): 65 | """ 66 | linearly map t to T: T = k * t + b 67 | """ 68 | k = (T_end - T_start) / (t_end - 
t_start) 69 | b = T_start - t_start * k 70 | return k * t + b 71 | 72 | def set_timesteps(self, num_inference_steps, stage_index, device=None, shift=1.0): 73 | self.num_inference_steps = num_inference_steps 74 | 75 | stage_T_start = self.Timesteps_per_stage[stage_index][0].item() 76 | stage_T_end = self.Timesteps_per_stage[stage_index][-1].item() 77 | 78 | t_start = self.t_window_per_stage[stage_index][0].item() 79 | t_end = self.t_window_per_stage[stage_index][-1].item() 80 | 81 | t = np.linspace(t_start, t_end, num_inference_steps, dtype=np.float64) 82 | t = t / (shift + (1 - shift) * t) 83 | 84 | Timesteps = self.time_linear_to_Timesteps(t, t_start, t_end, stage_T_start, stage_T_end) 85 | self.Timesteps = torch.from_numpy(Timesteps).to(device=device) 86 | 87 | self.t = torch.from_numpy(np.append(t, 1.0)).to(device=device, dtype=torch.float64) 88 | self._step_index = None 89 | 90 | def step(self, model_output, sample): 91 | if self.step_index is None: 92 | self._step_index = 0 93 | 94 | sample = sample.to(torch.float32) 95 | t = self.t[self.step_index].float() 96 | t_next = self.t[self.step_index + 1].float() 97 | 98 | prev_sample = sample + (t_next - t) * model_output 99 | self._step_index += 1 100 | 101 | return prev_sample.to(model_output.dtype) 102 | 103 | @property 104 | def step_index(self): 105 | """Current step index for the scheduler.""" 106 | return self._step_index 107 | -------------------------------------------------------------------------------- /pixelflow/solver_ode_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchdiffeq import odeint 3 | 4 | 5 | # https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77 6 | class ODE: 7 | """ODE solver class""" 8 | def __init__( 9 | self, 10 | *, 11 | t0, 12 | t1, 13 | sampler_type, 14 | num_steps, 15 | atol, 16 | rtol, 17 | ): 18 | assert t0 < t1, "ODE sampler has to be in forward time" 19 | 20 | self.t = torch.linspace(t0, t1, num_steps) 21 | self.atol = atol 22 | self.rtol = rtol 23 | self.sampler_type = sampler_type 24 | 25 | def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end): 26 | # T = k * t + b 27 | k = (T_end - T_start) / (t_end - t_start) 28 | b = T_start - t_start * k 29 | return k * t + b 30 | 31 | def sample(self, x, model, T_start, T_end, **model_kwargs): 32 | device = x[0].device if isinstance(x, tuple) else x.device 33 | def _fn(t, x): 34 | t = torch.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else torch.ones(x.size(0)).to(device) * t 35 | model_output = model(x, self.time_linear_to_Timesteps(t, 0, 1, T_start, T_end), **model_kwargs) 36 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" 37 | return model_output 38 | 39 | t = self.t.to(device) 40 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 41 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 42 | samples = odeint( 43 | _fn, 44 | x, 45 | t, 46 | method=self.sampler_type, 47 | atol=atol, 48 | rtol=rtol 49 | ) 50 | return samples 51 | -------------------------------------------------------------------------------- /pixelflow/utils/config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def get_obj_from_str(string, reload=False): 5 | module, cls = string.rsplit(".", 1) 6 | if reload: 7 | module_imp = importlib.import_module(module) 8 | importlib.reload(module_imp) 9 | return 
getattr(importlib.import_module(module, package=None), cls) 10 | 11 | 12 | def instantiate_from_config(config): 13 | if "target" not in config: 14 | raise KeyError("Expected key `target` to instantiate.") 15 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 16 | 17 | 18 | def instantiate_optimizer_from_config(config, params): 19 | if "target" not in config: 20 | raise KeyError("Expected key `target` to instantiate.") 21 | return get_obj_from_str(config["target"])(params, **config.get("params", dict())) 22 | 23 | 24 | def instantiate_dataset_from_config(config, transform): 25 | if "target" not in config: 26 | raise KeyError("Expected key `target` to instantiate.") 27 | return get_obj_from_str(config["target"])(transform=transform, **config.get("params", dict())) 28 | -------------------------------------------------------------------------------- /pixelflow/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | class PathSimplifierFormatter(logging.Formatter): 6 | def format(self, record): 7 | record.short_path = os.path.relpath(record.pathname) 8 | return super().format(record) 9 | 10 | 11 | def setup_logger(log_directory, experiment_name, process_rank, source_module=__name__): 12 | handlers = [logging.StreamHandler()] 13 | 14 | if process_rank == 0: 15 | log_file_path = os.path.join(log_directory, f"{experiment_name}.log") 16 | handlers.append(logging.FileHandler(log_file_path)) 17 | 18 | log_formatter = PathSimplifierFormatter( 19 | fmt='[%(asctime)s %(short_path)s:%(lineno)d] %(message)s', 20 | datefmt='%Y-%m-%d %H:%M:%S' 21 | ) 22 | 23 | for handler in handlers: 24 | handler.setFormatter(log_formatter) 25 | 26 | logging.basicConfig(level=logging.INFO, handlers=handlers) 27 | return logging.getLogger(source_module) 28 | -------------------------------------------------------------------------------- /pixelflow/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | 6 | def seed_everything(seed=0, deterministic_ops=True, allow_tf32=False): 7 | """ 8 | Sets the seed for reproducibility across various libraries and frameworks, and configures PyTorch backend settings. 9 | 10 | Args: 11 | seed (int): The seed value for random number generation. Default is 0. 12 | deterministic_ops (bool): Whether to enable deterministic operations in PyTorch. 13 | Enabling this can make results reproducible at the cost of potential performance degradation. Default is True. 14 | allow_tf32 (bool): Whether to allow TensorFloat-32 (TF32) precision in PyTorch operations. TF32 can improve performance but may affect reproducibility. Default is False. 15 | 16 | Effects: 17 | - Seeds Python's random module, NumPy, and PyTorch (CPU and GPU). 18 | - Sets the environment variable `PYTHONHASHSEED` to the specified seed. 19 | - Configures PyTorch to use deterministic algorithms if `deterministic_ops` is True. 20 | - Configures TensorFloat-32 precision based on `allow_tf32`. 21 | - Issues warnings if configurations may impact reproducibility. 22 | 23 | Notes: 24 | - Setting `torch.backends.cudnn.deterministic` to False allows nondeterministic operations, which may introduce variability. 25 | - Allowing TF32 (`allow_tf32=True`) may lead to non-reproducible results, especially in matrix operations.
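Example (illustrative only; the seed value is arbitrary):
    >>> seed_everything(0)   # defaults: deterministic ops on, TF32 off (bit-reproducible, slower)
    >>> seed_everything(0, deterministic_ops=False, allow_tf32=True)   # faster, not bit-reproducible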
26 | """ 27 | # Seed standard random number generators 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | 32 | # Seed PyTorch random number generators 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | # Configure deterministic operations 37 | if deterministic_ops: 38 | torch.backends.cudnn.deterministic = True 39 | torch.use_deterministic_algorithms(True) 40 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 41 | else: 42 | torch.backends.cudnn.deterministic = False 43 | print("WARNING: torch.backends.cudnn.deterministic is set to False, reproducibility is not guaranteed.") 44 | 45 | # Configure TensorFloat-32 precision 46 | if allow_tf32: 47 | print("WARNING: TensorFloat-32 (TF32) is enabled; reproducibility is not guaranteed.") 48 | 49 | torch.backends.cudnn.allow_tf32 = allow_tf32 # Default True in PyTorch 2.6.0 50 | torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Default False in PyTorch 2.6.0 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | pyarrow 3 | omegaconf 4 | diffusers==0.32.2 5 | transformers==4.48.0 6 | torchdiffeq==0.2.4 7 | -------------------------------------------------------------------------------- /sample_ddp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import os 5 | from omegaconf import OmegaConf 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | 10 | import torch 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | import torch.distributed as dist 14 | 15 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler 16 | from pixelflow.pipeline_pixelflow import PixelFlowPipeline 17 | from pixelflow.utils import config as config_utils 18 | 19 | 20 | def get_args_parser(): 21 | parser = argparse.ArgumentParser(description='sample 50k images for FID evaluation', add_help=False) 22 | parser.add_argument('--pretrained', type=str, required=True, help='Pretrained model path') 23 | 24 | parser.add_argument("--sample-dir", type=str, default="evaluate_256pix_folder") 25 | parser.add_argument("--cfg", type=float, default=2.4) 26 | parser.add_argument("--num-steps-per-stage", type=int, default=30) 27 | parser.add_argument("--use-ode-dopri5", action="store_true") 28 | parser.add_argument("--local-batch-size", type=int, default=16) 29 | parser.add_argument("--num-fid-samples", type=int, default=50000) 30 | parser.add_argument("--num-classes", type=int, default=1000) 31 | parser.add_argument("--global-seed", type=int, default=10) 32 | return parser 33 | 34 | 35 | def create_npz_from_sample_folder(sample_dir, num=50_000): 36 | """ 37 | Builds a single .npz file from a folder of .png samples. 
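Samples are expected to be named {i:06d}.png, as written by main() below. The output follows
the ADM evaluation-suite convention: a single uint8 array of shape (num, H, W, 3) stored under
the key "arr_0", so it can be re-loaded with, e.g., np.load(f"{sample_dir}.npz")["arr_0"].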
38 | """ 39 | samples = [] 40 | for i in tqdm(range(num), desc="Building .npz file from samples"): 41 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 42 | sample_np = np.asarray(sample_pil).astype(np.uint8) 43 | samples.append(sample_np) 44 | samples = np.stack(samples) 45 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 46 | npz_path = f"{sample_dir}.npz" 47 | np.savez(npz_path, arr_0=samples) 48 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 49 | return npz_path 50 | 51 | 52 | def main(args): 53 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU; this script has no CPU-only mode" 54 | torch.set_grad_enabled(False) 55 | 56 | # Setup DDP 57 | rank = int(os.environ["RANK"]) 58 | local_rank = int(os.environ["LOCAL_RANK"]) 59 | world_size = int(os.environ["WORLD_SIZE"]) 60 | 61 | dist.init_process_group("nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(hours=1)) 62 | device = torch.device(f"cuda:{local_rank}") 63 | torch.cuda.set_device(device) 64 | 65 | seed = args.global_seed * dist.get_world_size() + rank 66 | torch.manual_seed(seed) 67 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 68 | print(args) 69 | 70 | # create and load model 71 | config = OmegaConf.load(f"{args.pretrained}/config.yaml") 72 | model = config_utils.instantiate_from_config(config.model).to(device) 73 | ckpt = torch.load(f"{args.pretrained}/model.pt", map_location="cpu", weights_only=True) 74 | model.load_state_dict(ckpt, strict=True) 75 | model.eval() 76 | 77 | resolution = config.data.resolution 78 | 79 | scheduler = PixelFlowScheduler( 80 | config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3 81 | ) 82 | num_steps_per_stage = [args.num_steps_per_stage for _ in range(config.scheduler.num_stages)] 83 | pipeline = PixelFlowPipeline(scheduler, model) 84 | 85 | # Create folder to save samples: 86 | sample_folder_dir = args.sample_dir 87 | if rank == 0: 88 | os.makedirs(sample_folder_dir, exist_ok=True) 89 | print(f"Saving .png samples at {sample_folder_dir}") 90 | dist.barrier() 91 | 92 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 93 | local_batch_size = args.local_batch_size 94 | global_batch_size = local_batch_size * dist.get_world_size() 95 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 96 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 97 | if rank == 0: 98 | print(f"Total number of images that will be sampled: {total_samples}") 99 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 100 | 101 | # Distribute classes evenly; surplus slots are padded with class 0 and discarded after sampling 102 | class_list_global = torch.zeros((total_samples,), device=device) 103 | num_classes = args.num_classes 104 | class_list = torch.arange(0, num_classes).repeat(args.num_fid_samples // num_classes) 105 | class_list_global[:class_list.shape[0]] = class_list 106 | 107 | local_samples = int(total_samples // dist.get_world_size()) 108 | assert local_samples % local_batch_size == 0, "local_samples must be divisible by the per-GPU batch size" 109 | iterations = int(local_samples // local_batch_size) 110 | pbar = range(iterations) 111 | pbar = tqdm(pbar) if rank == 0 else pbar 112 | total = 0 113 | for _ in pbar: 114 | cur_index = torch.arange(local_batch_size) * dist.get_world_size() + rank
+ total 115 | cur_class_list = class_list_global[cur_index] 116 | 117 | with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad(): 118 | samples = pipeline( 119 | prompt=cur_class_list, 120 | num_inference_steps=list(num_steps_per_stage), 121 | height=resolution, 122 | width=resolution, 123 | guidance_scale=args.cfg, 124 | device=device, 125 | use_ode_dopri5=args.use_ode_dopri5, 126 | ) 127 | samples = (samples * 255).round().astype("uint8") 128 | image_list = [Image.fromarray(sample) for sample in samples] 129 | 130 | for img_ind, image in enumerate(image_list): 131 | index = img_ind * dist.get_world_size() + rank + total 132 | if index < args.num_fid_samples: 133 | image.save(f"{sample_folder_dir}/{index:06d}.png") 134 | 135 | total += global_batch_size 136 | 137 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 138 | dist.barrier() 139 | if rank == 0: 140 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 141 | print("Done.") 142 | dist.barrier() 143 | dist.destroy_process_group() 144 | 145 | if __name__ == "__main__": 146 | parser = get_args_parser() 147 | args = parser.parse_args() 148 | main(args) 149 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import argparse 4 | import copy 5 | from collections import OrderedDict 6 | from datetime import datetime 7 | from omegaconf import OmegaConf 8 | import torch 9 | import torch.distributed as dist 10 | 11 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler 12 | from pixelflow.utils.logger import setup_logger 13 | from pixelflow.utils import config as config_utils 14 | from pixelflow.utils.misc import seed_everything 15 | from pixelflow.data_in1k import build_imagenet_loader 16 | 17 | 18 | def get_args_parser(): 19 | parser = argparse.ArgumentParser(description='PixelFlow training', add_help=False) 20 | parser.add_argument('config', type=str, help='Path to the YAML config file') 21 | parser.add_argument('--output-dir', default='./exp0001', help='Output directory') 22 | parser.add_argument('--logging-steps', type=int, default=10, help='Logging steps') 23 | parser.add_argument('--checkpoint-steps', type=int, default=1000, help='Checkpoint steps') 24 | parser.add_argument('--pretrained-model', type=str, default=None, help='Pretrained model') 25 | parser.add_argument('--report-to', type=str, default=None, help='Reporting backend, e.g. wandb') 26 | 27 | return parser 28 | 29 | 30 | @torch.no_grad() 31 | def update_ema(ema_model, model, decay=0.9999): 32 | """ 33 | Step the EMA model towards the current model.
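Per-parameter exponential moving average computed by the loop below:
    ema_param = decay * ema_param + (1 - decay) * model_param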
34 | """ 35 | ema_params = OrderedDict(ema_model.named_parameters()) 36 | model_params = OrderedDict(model.named_parameters()) 37 | 38 | for name, param in model_params.items(): 39 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 40 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 41 | 42 | 43 | def main(args): 44 | rank = int(os.environ["RANK"]) 45 | local_rank = int(os.environ["LOCAL_RANK"]) 46 | world_size = int(os.environ["WORLD_SIZE"]) 47 | 48 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 49 | device = torch.device(f"cuda:{local_rank}") 50 | torch.cuda.set_device(device) 51 | 52 | exp_name = datetime.now().strftime('%Y%m%d-%H%M%S') 53 | 54 | # we print logs from all ranks to console, but only save to file from rank 0 55 | logger = setup_logger(args.output_dir, exp_name, rank, __name__) 56 | 57 | config = OmegaConf.load(args.config) 58 | 59 | rank_seed = config.seed * world_size + rank 60 | seed_everything(rank_seed, deterministic_ops=False, allow_tf32=False) 61 | logger.info(f"Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}, Seed: {rank_seed}") 62 | 63 | # save args and config to output_dir 64 | with open(os.path.join(args.output_dir, "args.txt"), "w") as f: 65 | f.write(str(args)) 66 | with open(os.path.join(args.output_dir, "config.yaml"), "w") as f: 67 | f.write(OmegaConf.to_yaml(config)) 68 | 69 | logger.info(f"Config: {config}") 70 | model = config_utils.instantiate_from_config(config.model).to(device) 71 | ema = copy.deepcopy(model).to(device) 72 | for param in ema.parameters(): 73 | param.requires_grad = False 74 | 75 | logger.info(f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 76 | 77 | noise_scheduler = PixelFlowScheduler( 78 | num_train_timesteps=config.scheduler.num_train_timesteps, 79 | num_stages=config.scheduler.num_stages, 80 | ) 81 | noise_scheduler_copy = copy.deepcopy(noise_scheduler) 82 | 83 | if args.pretrained_model is not None: 84 | ckpt = torch.load(args.pretrained_model, map_location="cpu", weights_only=True) 85 | model.load_state_dict(ckpt, strict=True) 86 | 87 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) 88 | 89 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.train.lr, weight_decay=config.train.weight_decay) 90 | data_loader, sampler = build_imagenet_loader(config, noise_scheduler_copy) 91 | 92 | logger.info("***** Running training *****") 93 | global_step = 0 94 | first_epoch = 0 95 | 96 | # Prepare models for training: 97 | update_ema(ema, model.module, decay=0)  # decay=0 copies the current weights into the EMA model 98 | model.train() 99 | ema.eval() 100 | 101 | for epoch in range(first_epoch, config.train.epochs): 102 | sampler.set_epoch(epoch) 103 | for step, batch in enumerate(data_loader): 104 | optimizer.zero_grad() 105 | target = batch["target_values"].to(device) 106 | with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16): 107 | model_output = model( 108 | hidden_states=batch["pixel_values"].to(device), 109 | encoder_hidden_states=None, 110 | class_labels=torch.tensor(batch["input_ids"], device=device), 111 | timestep=batch["timesteps"].to(device), 112 | latent_size=batch["batch_latent_size"].to(device), 113 | pos_embed=batch["pos_embed"].to(device), 114 | cu_seqlens_q=batch["cumsum_q_len"].to(device), 115 | cu_seqlens_k=None, 116 | seqlen_list_q=batch["seqlen_list_q"], 117 | seqlen_list_k=None, 118 | ) 119 | 120 | loss = (model_output.float() -
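# flow-matching loss: squared error against the per-patch velocity target
# (pixel_values_end - pixel_values_start) prepared in collate_fn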
target.float()) ** 2 121 | loss_split = torch.split(loss, batch["seqlen_list_q"], dim=0) 122 | loss_items = torch.stack([x.mean() for x in loss_split]) 123 | if "padding_size" in batch and batch["padding_size"] is not None and batch["padding_size"] > 0: 124 | loss_items = loss_items[:-1] 125 | 126 | loss = loss_items.mean() 127 | loss.backward() 128 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 129 | optimizer.step() 130 | update_ema(ema, model.module) 131 | 132 | if global_step % args.logging_steps == 0: 133 | logger.info(f"Epoch: {epoch}, Step: {step}, Loss: {loss.item()}, Grad Norm: {grad_norm.item()}") 134 | 135 | if global_step % args.checkpoint_steps == 0 and global_step > 0: 136 | if rank == 0: 137 | torch.save( 138 | { 139 | "model": model.module.state_dict(), 140 | "ema": ema.state_dict(), 141 | "opt": optimizer.state_dict(), 142 | "args": args 143 | }, 144 | os.path.join(args.output_dir, f"model_{global_step}.pt")) 145 | logger.info(f"Model saved at step {global_step}") 146 | dist.barrier() 147 | 148 | global_step += 1 149 | 150 | 151 | if __name__ == '__main__': 152 | parser = get_args_parser() 153 | args = parser.parse_args() 154 | os.makedirs(args.output_dir, exist_ok=True) 155 | main(args) 156 | --------------------------------------------------------------------------------