├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── configs
│   └── pixelflow_xl_c2i.yaml
├── imagenet_en_cn.py
├── pixelflow
│   ├── data_in1k.py
│   ├── model.py
│   ├── pipeline_pixelflow.py
│   ├── scheduling_pixelflow.py
│   ├── solver_ode_wrapper.py
│   └── utils
│       ├── config.py
│       ├── logger.py
│       └── misc.py
├── requirements.txt
├── sample_ddp.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /dataset/
2 | # output dir
3 | /output
4 | /ckpt_*
5 | /exp0*/
6 |
7 | *.diff
8 |
9 | # compilation and distribution
10 | __pycache__
11 | _ext
12 | *.pyc
13 | *.pyd
14 | *.so
15 | build/
16 | dist/
17 | wheels/
18 | *.egg-info/
19 |
20 | # pytorch/python/numpy formats
21 | *.pth
22 | *.pkl
23 | *.pt
24 |
25 |
26 | # Editor temporaries
27 | *.mp4
28 | *.swn
29 | *.swo
30 | *.swp
31 | *~
32 |
33 | # editor settings
34 | .idea
35 | .vscode
36 |
37 | # macOS system files
38 | .DS_Store
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Shoufa Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # PixelFlow: Pixel-Space Generative Models with Flow
4 |
5 | [arXiv](https://arxiv.org/abs/2504.07963)
6 | [🤗 Hugging Face Demo](https://huggingface.co/spaces/ShoufaChen/PixelFlow)
7 |
8 |
9 | 
10 |
11 |
12 |
13 |
14 |
15 |
16 | > [**PixelFlow: Pixel-Space Generative Models with Flow**](https://arxiv.org/abs/2504.07963)
17 | > [Shoufa Chen](https://www.shoufachen.com), [Chongjian Ge](https://chongjiange.github.io/), [Shilong Zhang](https://jshilong.github.io/), [Peize Sun](https://peizesun.github.io/), [Ping Luo](http://luoping.me/)
18 | > The University of Hong Kong, Adobe
19 |
20 | ## Introduction
21 | We present PixelFlow, a family of image generation models that operate directly in raw pixel space, in contrast to the predominant latent-space models. This approach simplifies the image generation process by eliminating the need for a pre-trained Variational Autoencoder (VAE) and making the whole model end-to-end trainable. Through efficient cascade flow modeling, PixelFlow keeps the computational cost of pixel-space generation affordable, achieving an FID of 1.98 on the 256x256 ImageNet class-conditional image generation benchmark. Qualitative text-to-image results demonstrate that PixelFlow excels in image quality, artistry, and semantic control. We hope this new paradigm will inspire and open up new opportunities for next-generation visual generation models.
22 |
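For intuition, each cascade stage is trained with a straight-line flow between a noisier start point and a cleaner end point. The minimal sketch below mirrors the target construction in `pixelflow/data_in1k.py`; the tensor names are illustrative, not part of the repository's API:

```python
import torch

# Stand-ins for the two endpoints of one cascade stage.
x_end = torch.randn(4, 3, 256, 256)    # cleaner endpoint (less noise)
x_start = torch.randn_like(x_end)      # noisier endpoint (more noise)
t = torch.rand(4, 1, 1, 1)             # per-sample interpolation time in [0, 1]

x_t = t * x_end + (1.0 - t) * x_start  # model input at time t
velocity_target = x_end - x_start      # regression target: constant velocity along the line
```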
23 |
24 | ## Model Zoo
25 |
26 | | Model | Task | Params | FID | Checkpoint |
27 | |:---------:|:--------------:|:------:|:----:|:----------:|
28 | | PixelFlow | class-to-image | 677M | 1.98 | [🤗](https://huggingface.co/ShoufaChen/PixelFlow-Class2Image) |
29 | | PixelFlow | text-to-image | 882M | N/A | [🤗](https://huggingface.co/ShoufaChen/PixelFlow-Text2Image) |
30 |
31 |
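The demo and evaluation scripts below expect `--checkpoint`/`--pretrained` to point at a local folder containing `config.yaml` and `model.pt` (see `app.py`). One way to fetch a checkpoint folder, assuming the `huggingface_hub` CLI is installed (the target directory name is arbitrary):

```bash
huggingface-cli download ShoufaChen/PixelFlow-Class2Image --local-dir ./ckpt_c2i
```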
32 | ## Setup
33 |
34 | ### 1. Create Environment
35 | ```bash
36 | conda create -n pixelflow python=3.12
37 | conda activate pixelflow
38 | ```
39 | ### 2. Install Dependencies
40 | * [PyTorch 2.6.0](https://pytorch.org/) — install it according to your system configuration (CUDA version, etc.).
41 | * [flash-attention v2.7.4.post1](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.7.4.post1): optional, required only for training.
42 | * Other packages: `pip3 install -r requirements.txt`
43 |
44 |
45 | ## Demo [🤗](https://huggingface.co/spaces/ShoufaChen/PixelFlow)
46 |
47 |
48 | We provide an online [Gradio demo](https://huggingface.co/spaces/ShoufaChen/PixelFlow) for class-to-image generation.
49 |
50 | You can also deploy both the class-to-image and text-to-image demos locally:
51 |
52 | ```bash
53 | python app.py --checkpoint /path/to/checkpoint --class_cond # for class-to-image
54 | ```
55 | or
56 | ```bash
57 | python app.py --checkpoint /path/to/checkpoint # for text-to-image
58 | ```
59 |
60 |
61 | ## Training
62 |
63 | ### 1. ImageNet Preparation
64 |
65 | - Download the ImageNet dataset from [http://www.image-net.org/](http://www.image-net.org/).
66 | - Use the [extract_ILSVRC.sh](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) script to extract and organize the training and validation images into labeled subfolders.
67 |
68 | ### 2. Training Command
69 |
70 | ```bash
71 | torchrun --nnodes=1 --nproc_per_node=8 train.py configs/pixelflow_xl_c2i.yaml
72 | ```
73 |
74 | ## Evaluation (FID, Inception Score, etc.)
75 |
76 | We provide a [sample_ddp.py](sample_ddp.py) script, adapted from [DiT](https://github.com/facebookresearch/DiT), for generating sample images and saving them both as a folder and as a .npz file. The .npz file is compatible with ADM's TensorFlow evaluation suite, allowing direct computation of FID, Inception Score, and other metrics.
77 |
78 |
79 | ```bash
80 | torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --pretrained /path/to/checkpoint
81 | ```
82 |
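The resulting `.npz` can then be scored with `evaluator.py` from [ADM's evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) against the matching reference batch, along the lines of (reference-batch filename as distributed there; the sample path is whatever `sample_ddp.py` wrote):

```bash
python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz
```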
83 |
84 | ## BibTeX
85 | ```bibtex
86 | @article{chen2025pixelflow,
87 | title={PixelFlow: Pixel-Space Generative Models with Flow},
88 | author={Chen, Shoufa and Ge, Chongjian and Zhang, Shilong and Sun, Peize and Luo, Ping},
89 | journal={arXiv preprint arXiv:2504.07963},
90 | year={2025}
91 | }
92 | ```
93 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from PIL import Image
4 | import gradio as gr
5 | from imagenet_en_cn import IMAGENET_1K_CLASSES
6 | from omegaconf import OmegaConf
7 |
8 | import torch
9 | from transformers import T5EncoderModel, AutoTokenizer
10 |
11 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler
12 | from pixelflow.pipeline_pixelflow import PixelFlowPipeline
13 | from pixelflow.utils import config as config_utils
14 | from pixelflow.utils.misc import seed_everything
15 |
16 |
17 | parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
18 | parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
19 | parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
20 | args = parser.parse_args()
21 |
22 | local_rank = 0
23 | device = torch.device(f"cuda:{local_rank}")
24 | torch.cuda.set_device(device)
25 |
26 | output_dir = args.checkpoint
27 | if args.class_cond:
28 |     config = OmegaConf.load(f"{output_dir}/config.yaml")
29 |     model = config_utils.instantiate_from_config(config.model).to(device)
30 |     print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
31 |     ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
32 |     text_encoder = None
33 |     tokenizer = None
34 |     resolution = 256
35 |     NUM_EXAMPLES = 4
36 | else:
37 |     config = OmegaConf.load(f"{output_dir}/config.yaml")
38 |     model = config_utils.instantiate_from_config(config.model).to(device)
39 |     print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
40 |     ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
41 |     text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device)
42 |     tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
43 |     resolution = 1024
44 |     NUM_EXAMPLES = 1
45 | model.load_state_dict(ckpt, strict=True)
46 | model.eval()
47 |
48 | scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)
49 |
50 | pipeline = PixelFlowPipeline(
51 |     scheduler,
52 |     model,
53 |     text_encoder=text_encoder,
54 |     tokenizer=tokenizer,
55 |     max_token_length=512,
56 | )
57 |
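# Editorial note: pixelflow/utils/config.py is not included in this dump.
# In similar codebases, `instantiate_from_config` resolves the YAML `target`
# dotted path and constructs the class with `params`, roughly:
#
#     import importlib
#
#     def instantiate_from_config(cfg):
#         module_name, cls_name = cfg["target"].rsplit(".", 1)
#         cls = getattr(importlib.import_module(module_name), cls_name)
#         return cls(**cfg.get("params", {}))
#
# Treat this as an assumption about the helper, not its actual implementation.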
58 | def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
59 |     seed_everything(seed)
60 |     with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
61 |         samples = pipeline(
62 |             prompt=[class_label] * NUM_EXAMPLES,
63 |             height=resolution,
64 |             width=resolution,
65 |             num_inference_steps=list(num_steps_per_stage),
66 |             guidance_scale=cfg_scale,  # guidance for the first stage; set to 7 for the 384p variant
67 |             device=device,
68 |             shift=noise_shift,
69 |             use_ode_dopri5=use_ode_dopri5,
70 |         )
71 |     samples = (samples * 255).round().astype("uint8")
72 |     samples = [Image.fromarray(sample) for sample in samples]
73 |     return samples
74 |
75 |
76 | with gr.Blocks() as demo:
77 |     gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
78 |
79 |     with gr.Tabs():
80 |         with gr.TabItem('Generate'):
81 |             with gr.Row():
82 |                 with gr.Column():
83 |                     with gr.Row():
84 |                         if args.class_cond:
85 |                             user_input = gr.Dropdown(
86 |                                 list(IMAGENET_1K_CLASSES.values()),
87 |                                 value='daisy [雏菊]',
88 |                                 type="index", label='ImageNet-1K Class'
89 |                             )
90 |                         else:
91 |                             # text input
92 |                             user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt")
93 |                     ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use Dopri5 ODE solver")
94 |                     noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
95 |                     cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
96 |                     num_steps_per_stage = []
97 |                     for stage_idx in range(config.scheduler.num_stages):
98 |                         num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
99 |                         num_steps_per_stage.append(num_steps)
100 |                     seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
101 |                     button = gr.Button("Generate", variant="primary")
102 |                 with gr.Column():
103 |                     output = gr.Gallery(label='Generated Images', height=700)
104 |     button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
105 | demo.queue()
106 | demo.launch(share=False, debug=True)
107 |
--------------------------------------------------------------------------------
/configs/pixelflow_xl_c2i.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: pixelflow.model.PixelFlowModel
3 | params:
4 | num_attention_heads: 16
5 | attention_head_dim: 72
6 | in_channels: 3
7 | out_channels: 3
8 | depth: 28
9 | num_classes: 1000
10 | patch_size: 4
11 | attention_bias: true
12 |
13 | scheduler:
14 | num_train_timesteps: 1000
15 | num_stages: 4
16 | pyramid_shift: false
17 |
18 | train:
19 | lr: 1e-4
20 | weight_decay: 0.0
21 | epochs: 10
22 |
23 | data:
24 | root: /public/datasets/ILSVRC2012/train
25 | center_crop: false
26 | resolution: 256
27 | expand_ratio: 1.125
28 | num_workers: 4
29 | batch_size: 4
30 |
31 | seed: 42
32 |
--------------------------------------------------------------------------------
/imagenet_en_cn.py:
--------------------------------------------------------------------------------
1 | IMAGENET_1K_CLASSES = {
2 | 0: 'tench, Tinca tinca [丁鲷]',
3 | 1: 'goldfish, Carassius auratus [金鱼]',
4 | 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]',
5 | 3: 'tiger shark, Galeocerdo cuvieri [虎鲨]',
6 | 4: 'hammerhead, hammerhead shark [锤头鲨]',
7 | 5: 'electric ray, crampfish, numbfish, torpedo [电鳐]',
8 | 6: 'stingray [黄貂鱼]',
9 | 7: 'cock [公鸡]',
10 | 8: 'hen [母鸡]',
11 | 9: 'ostrich, Struthio camelus [鸵鸟]',
12 | 10: 'brambling, Fringilla montifringilla [燕雀]',
13 | 11: 'goldfinch, Carduelis carduelis [金翅雀]',
14 | 12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]',
15 | 13: 'junco, snowbird [灯芯草雀]',
16 | 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]',
17 | 15: 'robin, American robin, Turdus migratorius [蓝鹀]',
18 | 16: 'bulbul [夜莺]',
19 | 17: 'jay [松鸦]',
20 | 18: 'magpie [喜鹊]',
21 | 19: 'chickadee [山雀]',
22 | 20: 'water ouzel, dipper [河鸟]',
23 | 21: 'kite [鸢(猛禽)]',
24 | 22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]',
25 | 23: 'vulture [秃鹫]',
26 | 24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]',
27 | 25: 'European fire salamander, Salamandra salamandra [欧洲火蝾螈]',
28 | 26: 'common newt, Triturus vulgaris [普通蝾螈]',
29 | 27: 'eft [水蜥]',
30 | 28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]',
31 | 29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]',
32 | 30: 'bullfrog, Rana catesbeiana [牛蛙]',
33 | 31: 'tree frog, tree-frog [树蛙]',
34 | 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]',
35 | 33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]',
36 | 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]',
37 | 35: 'mud turtle [泥龟]',
38 | 36: 'terrapin [淡水龟]',
39 | 37: 'box turtle, box tortoise [箱龟]',
40 | 38: 'banded gecko [带状壁虎]',
41 | 39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]',
42 | 40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]',
43 | 41: 'whiptail, whiptail lizard [鞭尾蜥蜴]',
44 | 42: 'agama [飞龙科蜥蜴]',
45 | 43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]',
46 | 44: 'alligator lizard [鳄鱼蜥蜴]',
47 | 45: 'Gila monster, Heloderma suspectum [毒蜥]',
48 | 46: 'green lizard, Lacerta viridis [绿蜥蜴]',
49 | 47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]',
50 | 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]',
51 | 49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]',
52 | 50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]',
53 | 51: 'triceratops [三角龙]',
54 | 52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]',
55 | 53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]',
56 | 54: 'hognose snake, puff adder, sand viper [希腊蛇]',
57 | 55: 'green snake, grass snake [绿蛇,草蛇]',
58 | 56: 'king snake, kingsnake [国王蛇]',
59 | 57: 'garter snake, grass snake [袜带蛇,草蛇]',
60 | 58: 'water snake [水蛇]',
61 | 59: 'vine snake [藤蛇]',
62 | 60: 'night snake, Hypsiglena torquata [夜蛇]',
63 | 61: 'boa constrictor, Constrictor constrictor [大蟒蛇]',
64 | 62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]',
65 | 63: 'Indian cobra, Naja naja [印度眼镜蛇]',
66 | 64: 'green mamba [绿曼巴]',
67 | 65: 'sea snake [海蛇]',
68 | 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]',
69 | 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]',
70 | 68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]',
71 | 69: 'trilobite [三叶虫]',
72 | 70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]',
73 | 71: 'scorpion [蝎子]',
74 | 72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]',
75 | 73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]',
76 | 74: 'garden spider, Aranea diademata [花园蜘蛛]',
77 | 75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]',
78 | 76: 'tarantula [狼蛛]',
79 | 77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]',
80 | 78: 'tick [壁虱]',
81 | 79: 'centipede [蜈蚣]',
82 | 80: 'black grouse [黑松鸡]',
83 | 81: 'ptarmigan [松鸡,雷鸟]',
84 | 82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]',
85 | 83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]',
86 | 84: 'peacock [孔雀]',
87 | 85: 'quail [鹌鹑]',
88 | 86: 'partridge [鹧鸪]',
89 | 87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]',
90 | 88: 'macaw [金刚鹦鹉]',
91 | 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]',
92 | 90: 'lorikeet [短尾鹦鹉]',
93 | 91: 'coucal [褐翅鸦鹃]',
94 | 92: 'bee eater [蜜蜂]',
95 | 93: 'hornbill [犀鸟]',
96 | 94: 'hummingbird [蜂鸟]',
97 | 95: 'jacamar [鹟䴕]',
98 | 96: 'toucan [犀鸟]',
99 | 97: 'drake [野鸭]',
100 | 98: 'red-breasted merganser, Mergus serrator [红胸秋沙鸭]',
101 | 99: 'goose [鹅]',
102 | 100: 'black swan, Cygnus atratus [黑天鹅]',
103 | 101: 'tusker [大象]',
104 | 102: 'echidna, spiny anteater, anteater [针鼹鼠]',
105 | 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]',
106 | 104: 'wallaby, brush kangaroo [沙袋鼠]',
107 | 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]',
108 | 106: 'wombat [袋熊]',
109 | 107: 'jellyfish [水母]',
110 | 108: 'sea anemone, anemone [海葵]',
111 | 109: 'brain coral [脑珊瑚]',
112 | 110: 'flatworm, platyhelminth [扁形虫扁虫]',
113 | 111: 'nematode, nematode worm, roundworm [线虫,蛔虫]',
114 | 112: 'conch [海螺]',
115 | 113: 'snail [蜗牛]',
116 | 114: 'slug [鼻涕虫]',
117 | 115: 'sea slug, nudibranch [海参]',
118 | 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]',
119 | 117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]',
120 | 118: 'Dungeness crab, Cancer magister [珍宝蟹]',
121 | 119: 'rock crab, Cancer irroratus [石蟹]',
122 | 120: 'fiddler crab [招潮蟹]',
123 | 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]',
124 | 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]',
125 | 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]',
126 | 124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]',
127 | 125: 'hermit crab [寄居蟹]',
128 | 126: 'isopod [等足目动物(明虾和螃蟹近亲)]',
129 | 127: 'white stork, Ciconia ciconia [白鹳]',
130 | 128: 'black stork, Ciconia nigra [黑鹳]',
131 | 129: 'spoonbill [鹭]',
132 | 130: 'flamingo [火烈鸟]',
133 | 131: 'little blue heron, Egretta caerulea [小蓝鹭]',
134 | 132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]',
135 | 133: 'bittern [麻鸦]',
136 | 134: 'crane [鹤]',
137 | 135: 'limpkin, Aramus pictus [秧鹤]',
138 | 136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]',
139 | 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]',
140 | 138: 'bustard [鸨]',
141 | 139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]',
142 | 140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]',
143 | 141: 'redshank, Tringa totanus [红脚鹬]',
144 | 142: 'dowitcher [半蹼鹬]',
145 | 143: 'oystercatcher, oyster catcher [蛎鹬]',
146 | 144: 'pelican [鹈鹕]',
147 | 145: 'king penguin, Aptenodytes patagonica [国王企鹅]',
148 | 146: 'albatross, mollymawk [信天翁,大海鸟]',
149 | 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]',
150 | 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]',
151 | 149: 'dugong, Dugong dugon [海牛]',
152 | 150: 'sea lion [海狮]',
153 | 151: 'Chihuahua [奇瓦瓦]',
154 | 152: 'Japanese spaniel [日本猎犬]',
155 | 153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]',
156 | 154: 'Pekinese, Pekingese, Peke [狮子狗]',
157 | 155: 'Shih-Tzu [西施犬]',
158 | 156: 'Blenheim spaniel [布莱尼姆猎犬]',
159 | 157: 'papillon [巴比狗]',
160 | 158: 'toy terrier [玩具犬]',
161 | 159: 'Rhodesian ridgeback [罗得西亚长背猎狗]',
162 | 160: 'Afghan hound, Afghan [阿富汗猎犬]',
163 | 161: 'basset, basset hound [猎犬]',
164 | 162: 'beagle [比格犬,猎兔犬]',
165 | 163: 'bloodhound, sleuthhound [侦探犬]',
166 | 164: 'bluetick [蓝色快狗]',
167 | 165: 'black-and-tan coonhound [黑褐猎浣熊犬]',
168 | 166: 'Walker hound, Walker foxhound [沃克猎犬]',
169 | 167: 'English foxhound [英国猎狐犬]',
170 | 168: 'redbone [美洲赤狗]',
171 | 169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]',
172 | 170: 'Irish wolfhound [爱尔兰猎狼犬]',
173 | 171: 'Italian greyhound [意大利灰狗]',
174 | 172: 'whippet [惠比特犬]',
175 | 173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]',
176 | 174: 'Norwegian elkhound, elkhound [挪威猎犬]',
177 | 175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]',
178 | 176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]',
179 | 177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]',
180 | 178: 'Weimaraner [威玛猎犬]',
181 | 179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]',
182 | 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]',
183 | 181: 'Bedlington terrier [贝德灵顿梗]',
184 | 182: 'Border terrier [边境梗]',
185 | 183: 'Kerry blue terrier [凯丽蓝梗]',
186 | 184: 'Irish terrier [爱尔兰梗]',
187 | 185: 'Norfolk terrier [诺福克梗]',
188 | 186: 'Norwich terrier [诺维奇梗]',
189 | 187: 'Yorkshire terrier [约克郡梗]',
190 | 188: 'wire-haired fox terrier [刚毛猎狐梗]',
191 | 189: 'Lakeland terrier [莱克兰梗]',
192 | 190: 'Sealyham terrier, Sealyham [锡利哈姆梗]',
193 | 191: 'Airedale, Airedale terrier [艾尔谷犬]',
194 | 192: 'cairn, cairn terrier [凯恩梗]',
195 | 193: 'Australian terrier [澳大利亚梗]',
196 | 194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]',
197 | 195: 'Boston bull, Boston terrier [波士顿梗]',
198 | 196: 'miniature schnauzer [迷你雪纳瑞犬]',
199 | 197: 'giant schnauzer [巨型雪纳瑞犬]',
200 | 198: 'standard schnauzer [标准雪纳瑞犬]',
201 | 199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]',
202 | 200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]',
203 | 201: 'silky terrier, Sydney silky [丝毛梗]',
204 | 202: 'soft-coated wheaten terrier [软毛麦色梗]',
205 | 203: 'West Highland white terrier [西高地白梗]',
206 | 204: 'Lhasa, Lhasa apso [拉萨阿普索犬]',
207 | 205: 'flat-coated retriever [平毛寻回犬]',
208 | 206: 'curly-coated retriever [卷毛寻回犬]',
209 | 207: 'golden retriever [金毛猎犬]',
210 | 208: 'Labrador retriever [拉布拉多猎犬]',
211 | 209: 'Chesapeake Bay retriever [乞沙比克猎犬]',
212 | 210: 'German short-haired pointer [德国短毛猎犬]',
213 | 211: 'vizsla, Hungarian pointer [维兹拉犬]',
214 | 212: 'English setter [英国谍犬]',
215 | 213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]',
216 | 214: 'Gordon setter [戈登雪达犬]',
217 | 215: 'Brittany spaniel [布列塔尼犬猎犬]',
218 | 216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]',
219 | 217: 'English springer, English springer spaniel [英国史宾格犬]',
220 | 218: 'Welsh springer spaniel [威尔士史宾格犬]',
221 | 219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]',
222 | 220: 'Sussex spaniel [萨塞克斯猎犬]',
223 | 221: 'Irish water spaniel [爱尔兰水猎犬]',
224 | 222: 'kuvasz [哥威斯犬]',
225 | 223: 'schipperke [舒柏奇犬]',
226 | 224: 'groenendael [比利时牧羊犬]',
227 | 225: 'malinois [马里努阿犬]',
228 | 226: 'briard [伯瑞犬]',
229 | 227: 'kelpie [凯尔皮犬]',
230 | 228: 'komondor [匈牙利牧羊犬]',
231 | 229: 'Old English sheepdog, bobtail [老英国牧羊犬]',
232 | 230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]',
233 | 231: 'collie [牧羊犬]',
234 | 232: 'Border collie [边境牧羊犬]',
235 | 233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]',
236 | 234: 'Rottweiler [罗特韦尔犬]',
237 | 235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]',
238 | 236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]',
239 | 237: 'miniature pinscher [迷你杜宾犬]',
240 | 238: 'Greater Swiss Mountain dog [大瑞士山地犬]',
241 | 239: 'Bernese mountain dog [伯恩山犬]',
242 | 240: 'Appenzeller [Appenzeller狗]',
243 | 241: 'EntleBucher [EntleBucher狗]',
244 | 242: 'boxer [拳师狗]',
245 | 243: 'bull mastiff [斗牛獒]',
246 | 244: 'Tibetan mastiff [藏獒]',
247 | 245: 'French bulldog [法国斗牛犬]',
248 | 246: 'Great Dane [大丹犬]',
249 | 247: 'Saint Bernard, St Bernard [圣伯纳德狗]',
250 | 248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]',
251 | 249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]',
252 | 250: 'Siberian husky [哈士奇]',
253 | 251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]',
254 | 252: 'affenpinscher, monkey pinscher, monkey dog [狮毛狗]',
255 | 253: 'basenji [巴辛吉狗]',
256 | 254: 'pug, pug-dog [哈巴狗,狮子狗]',
257 | 255: 'Leonberg [莱昂贝格狗]',
258 | 256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]',
259 | 257: 'Great Pyrenees [大白熊犬]',
260 | 258: 'Samoyed, Samoyede [萨摩耶犬]',
261 | 259: 'Pomeranian [博美犬]',
262 | 260: 'chow, chow chow [松狮,松狮]',
263 | 261: 'keeshond [荷兰卷尾狮毛狗]',
264 | 262: 'Brabancon griffon [布鲁塞尔格林芬犬]',
265 | 263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]',
266 | 264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]',
267 | 265: 'toy poodle [玩具贵宾犬]',
268 | 266: 'miniature poodle [迷你贵宾犬]',
269 | 267: 'standard poodle [标准贵宾犬]',
270 | 268: 'Mexican hairless [墨西哥无毛犬]',
271 | 269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]',
272 | 270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]',
273 | 271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]',
274 | 272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]',
275 | 273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]',
276 | 274: 'dhole, Cuon alpinus [豺]',
277 | 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]',
278 | 276: 'hyena, hyaena [鬣狗]',
279 | 277: 'red fox, Vulpes vulpes [红狐狸]',
280 | 278: 'kit fox, Vulpes macrotis [沙狐]',
281 | 279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]',
282 | 280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]',
283 | 281: 'tabby, tabby cat [虎斑猫]',
284 | 282: 'tiger cat [山猫,虎猫]',
285 | 283: 'Persian cat [波斯猫]',
286 | 284: 'Siamese cat, Siamese [暹罗暹罗猫,]',
287 | 285: 'Egyptian cat [埃及猫]',
288 | 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]',
289 | 287: 'lynx, catamount [猞猁,山猫]',
290 | 288: 'leopard, Panthera pardus [豹子]',
291 | 289: 'snow leopard, ounce, Panthera uncia [雪豹]',
292 | 290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]',
293 | 291: 'lion, king of beasts, Panthera leo [狮子]',
294 | 292: 'tiger, Panthera tigris [老虎]',
295 | 293: 'cheetah, chetah, Acinonyx jubatus [猎豹]',
296 | 294: 'brown bear, bruin, Ursus arctos [棕熊]',
297 | 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]',
298 | 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]',
299 | 297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]',
300 | 298: 'mongoose [猫鼬]',
301 | 299: 'meerkat, mierkat [猫鼬,海猫]',
302 | 300: 'tiger beetle [虎甲虫]',
303 | 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]',
304 | 302: 'ground beetle, carabid beetle [土鳖虫]',
305 | 303: 'long-horned beetle, longicorn, longicorn beetle [天牛]',
306 | 304: 'leaf beetle, chrysomelid [龟甲虫]',
307 | 305: 'dung beetle [粪甲虫]',
308 | 306: 'rhinoceros beetle [犀牛甲虫]',
309 | 307: 'weevil [象甲]',
310 | 308: 'fly [苍蝇]',
311 | 309: 'bee [蜜蜂]',
312 | 310: 'ant, emmet, pismire [蚂蚁]',
313 | 311: 'grasshopper, hopper [蚱蜢]',
314 | 312: 'cricket [蟋蟀]',
315 | 313: 'walking stick, walkingstick, stick insect [竹节虫]',
316 | 314: 'cockroach, roach [蟑螂]',
317 | 315: 'mantis, mantid [螳螂]',
318 | 316: 'cicada, cicala [蝉]',
319 | 317: 'leafhopper [叶蝉]',
320 | 318: 'lacewing, lacewing fly [草蜻蛉]',
321 | 319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]',
322 | 320: 'damselfly [豆娘,蜻蛉]',
323 | 321: 'admiral [优红蛱蝶]',
324 | 322: 'ringlet, ringlet butterfly [小环蝴蝶]',
325 | 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]',
326 | 324: 'cabbage butterfly [菜粉蝶]',
327 | 325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]',
328 | 326: 'lycaenid, lycaenid butterfly [灰蝶]',
329 | 327: 'starfish, sea star [海星]',
330 | 328: 'sea urchin [海胆]',
331 | 329: 'sea cucumber, holothurian [海参,海黄瓜]',
332 | 330: 'wood rabbit, cottontail, cottontail rabbit [野兔]',
333 | 331: 'hare [兔]',
334 | 332: 'Angora, Angora rabbit [安哥拉兔]',
335 | 333: 'hamster [仓鼠]',
336 | 334: 'porcupine, hedgehog [刺猬,豪猪,]',
337 | 335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]',
338 | 336: 'marmot [土拨鼠]',
339 | 337: 'beaver [海狸]',
340 | 338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]',
341 | 339: 'sorrel [栗色马]',
342 | 340: 'zebra [斑马]',
343 | 341: 'hog, pig, grunter, squealer, Sus scrofa [猪]',
344 | 342: 'wild boar, boar, Sus scrofa [野猪]',
345 | 343: 'warthog [疣猪]',
346 | 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]',
347 | 345: 'ox [牛]',
348 | 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]',
349 | 347: 'bison [野牛]',
350 | 348: 'ram, tup [公羊]',
351 | 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]',
352 | 350: 'ibex, Capra ibex [山羊]',
353 | 351: 'hartebeest [狷羚]',
354 | 352: 'impala, Aepyceros melampus [黑斑羚]',
355 | 353: 'gazelle [瞪羚]',
356 | 354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]',
357 | 355: 'llama [羊驼]',
358 | 356: 'weasel [黄鼠狼]',
359 | 357: 'mink [水貂]',
360 | 358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]',
361 | 359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]',
362 | 360: 'otter [水獭]',
363 | 361: 'skunk, polecat, wood pussy [臭鼬,木猫]',
364 | 362: 'badger [獾]',
365 | 363: 'armadillo [犰狳]',
366 | 364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]',
367 | 365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]',
368 | 366: 'gorilla, Gorilla gorilla [大猩猩]',
369 | 367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]',
370 | 368: 'gibbon, Hylobates lar [长臂猿]',
371 | 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]',
372 | 370: 'guenon, guenon monkey [长尾猴]',
373 | 371: 'patas, hussar monkey, Erythrocebus patas [赤猴]',
374 | 372: 'baboon [狒狒]',
375 | 373: 'macaque [恒河猴,猕猴]',
376 | 374: 'langur [白头叶猴]',
377 | 375: 'colobus, colobus monkey [疣猴]',
378 | 376: 'proboscis monkey, Nasalis larvatus [长鼻猴]',
379 | 377: 'marmoset [狨(美洲产小型长尾猴)]',
380 | 378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]',
381 | 379: 'howler monkey, howler [吼猴]',
382 | 380: 'titi, titi monkey [伶猴]',
383 | 381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]',
384 | 382: 'squirrel monkey, Saimiri sciureus [松鼠猴]',
385 | 383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]',
386 | 384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]',
387 | 385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]',
388 | 386: 'African elephant, Loxodonta africana [非洲象,非洲象]',
389 | 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]',
390 | 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]',
391 | 389: 'barracouta, snoek [杖鱼]',
392 | 390: 'eel [鳗鱼]',
393 | 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]',
394 | 392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]',
395 | 393: 'anemone fish [海葵鱼]',
396 | 394: 'sturgeon [鲟鱼]',
397 | 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]',
398 | 396: 'lionfish [狮子鱼]',
399 | 397: 'puffer, pufferfish, blowfish, globefish [河豚]',
400 | 398: 'abacus [算盘]',
401 | 399: 'abaya [长袍]',
402 | 400: 'academic gown, academic robe, judge robe [学位袍]',
403 | 401: 'accordion, piano accordion, squeeze box [手风琴]',
404 | 402: 'acoustic guitar [原声吉他]',
405 | 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]',
406 | 404: 'airliner [客机]',
407 | 405: 'airship, dirigible [飞艇]',
408 | 406: 'altar [祭坛]',
409 | 407: 'ambulance [救护车]',
410 | 408: 'amphibian, amphibious vehicle [水陆两用车]',
411 | 409: 'analog clock [模拟时钟]',
412 | 410: 'apiary, bee house [蜂房]',
413 | 411: 'apron [围裙]',
414 | 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]',
415 | 413: 'assault rifle, assault gun [攻击步枪,枪]',
416 | 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]',
417 | 415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]',
418 | 416: 'balance beam, beam [平衡木]',
419 | 417: 'balloon [热气球]',
420 | 418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]',
421 | 419: 'Band Aid [创可贴]',
422 | 420: 'banjo [班卓琴]',
423 | 421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]',
424 | 422: 'barbell [杠铃]',
425 | 423: 'barber chair [理发师的椅子]',
426 | 424: 'barbershop [理发店]',
427 | 425: 'barn [牲口棚]',
428 | 426: 'barometer [晴雨表]',
429 | 427: 'barrel, cask [圆筒]',
430 | 428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]',
431 | 429: 'baseball [棒球]',
432 | 430: 'basketball [篮球]',
433 | 431: 'bassinet [婴儿床]',
434 | 432: 'bassoon [巴松管,低音管]',
435 | 433: 'bathing cap, swimming cap [游泳帽]',
436 | 434: 'bath towel [沐浴毛巾]',
437 | 435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]',
438 | 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]',
439 | 437: 'beacon, lighthouse, beacon light, pharos [灯塔]',
440 | 438: 'beaker [高脚杯]',
441 | 439: 'bearskin, busby, shako [熊皮高帽]',
442 | 440: 'beer bottle [啤酒瓶]',
443 | 441: 'beer glass [啤酒杯]',
444 | 442: 'bell cote, bell cot [钟塔]',
445 | 443: 'bib [(小儿用的)围嘴]',
446 | 444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]',
447 | 445: 'bikini, two-piece [比基尼]',
448 | 446: 'binder, ring-binder [装订册]',
449 | 447: 'binoculars, field glasses, opera glasses [双筒望远镜]',
450 | 448: 'birdhouse [鸟舍]',
451 | 449: 'boathouse [船库]',
452 | 450: 'bobsled, bobsleigh, bob [雪橇]',
453 | 451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]',
454 | 452: 'bonnet, poke bonnet [阔边女帽]',
455 | 453: 'bookcase [书橱]',
456 | 454: 'bookshop, bookstore, bookstall [书店,书摊]',
457 | 455: 'bottlecap [瓶盖]',
458 | 456: 'bow [弓箭]',
459 | 457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]',
460 | 458: 'brass, memorial tablet, plaque [铜制牌位]',
461 | 459: 'brassiere, bra, bandeau [奶罩]',
462 | 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]',
463 | 461: 'breastplate, aegis, egis [铠甲]',
464 | 462: 'broom [扫帚]',
465 | 463: 'bucket, pail [桶]',
466 | 464: 'buckle [扣环]',
467 | 465: 'bulletproof vest [防弹背心]',
468 | 466: 'bullet train, bullet [动车,子弹头列车]',
469 | 467: 'butcher shop, meat market [肉铺,肉菜市场]',
470 | 468: 'cab, hack, taxi, taxicab [出租车]',
471 | 469: 'caldron, cauldron [大锅]',
472 | 470: 'candle, taper, wax light [蜡烛]',
473 | 471: 'cannon [大炮]',
474 | 472: 'canoe [独木舟]',
475 | 473: 'can opener, tin opener [开瓶器,开罐器]',
476 | 474: 'cardigan [开衫]',
477 | 475: 'car mirror [车镜]',
478 | 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]',
479 | 477: 'carpenters kit, tool kit [木匠的工具包,工具包]',
480 | 478: 'carton [纸箱]',
481 | 479: 'car wheel [车轮]',
482 | 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]',
483 | 481: 'cassette [盒式录音带]',
484 | 482: 'cassette player [卡带播放器]',
485 | 483: 'castle [城堡]',
486 | 484: 'catamaran [双体船]',
487 | 485: 'CD player [CD播放器]',
488 | 486: 'cello, violoncello [大提琴]',
489 | 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]',
490 | 488: 'chain [铁链]',
491 | 489: 'chainlink fence [围栏]',
492 | 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]',
493 | 491: 'chain saw, chainsaw [电锯,油锯]',
494 | 492: 'chest [箱子]',
495 | 493: 'chiffonier, commode [衣柜,洗脸台]',
496 | 494: 'chime, bell, gong [编钟,钟,锣]',
497 | 495: 'china cabinet, china closet [中国橱柜]',
498 | 496: 'Christmas stocking [圣诞袜]',
499 | 497: 'church, church building [教堂,教堂建筑]',
500 | 498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]',
501 | 499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]',
502 | 500: 'cliff dwelling [悬崖屋]',
503 | 501: 'cloak [斗篷]',
504 | 502: 'clog, geta, patten, sabot [木屐,木鞋]',
505 | 503: 'cocktail shaker [鸡尾酒调酒器]',
506 | 504: 'coffee mug [咖啡杯]',
507 | 505: 'coffeepot [咖啡壶]',
508 | 506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]',
509 | 507: 'combination lock [组合锁]',
510 | 508: 'computer keyboard, keypad [电脑键盘,键盘]',
511 | 509: 'confectionery, confectionary, candy store [糖果,糖果店]',
512 | 510: 'container ship, containership, container vessel [集装箱船]',
513 | 511: 'convertible [敞篷车]',
514 | 512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]',
515 | 513: 'cornet, horn, trumpet, trump [短号,喇叭]',
516 | 514: 'cowboy boot [牛仔靴]',
517 | 515: 'cowboy hat, ten-gallon hat [牛仔帽]',
518 | 516: 'cradle [摇篮]',
519 | 517: 'crane [起重机]',
520 | 518: 'crash helmet [头盔]',
521 | 519: 'crate [板条箱]',
522 | 520: 'crib, cot [小儿床]',
523 | 521: 'Crock Pot [砂锅]',
524 | 522: 'croquet ball [槌球]',
525 | 523: 'crutch [拐杖]',
526 | 524: 'cuirass [胸甲]',
527 | 525: 'dam, dike, dyke [大坝,堤防]',
528 | 526: 'desk [书桌]',
529 | 527: 'desktop computer [台式电脑]',
530 | 528: 'dial telephone, dial phone [有线电话]',
531 | 529: 'diaper, nappy, napkin [尿布湿]',
532 | 530: 'digital clock [数字时钟]',
533 | 531: 'digital watch [数字手表]',
534 | 532: 'dining table, board [餐桌板]',
535 | 533: 'dishrag, dishcloth [抹布]',
536 | 534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]',
537 | 535: 'disk brake, disc brake [盘式制动器]',
538 | 536: 'dock, dockage, docking facility [码头,船坞,码头设施]',
539 | 537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]',
540 | 538: 'dome [圆顶]',
541 | 539: 'doormat, welcome mat [门垫,垫子]',
542 | 540: 'drilling platform, offshore rig [钻井平台,海上钻井]',
543 | 541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]',
544 | 542: 'drumstick [鼓槌]',
545 | 543: 'dumbbell [哑铃]',
546 | 544: 'Dutch oven [荷兰烤箱]',
547 | 545: 'electric fan, blower [电风扇,鼓风机]',
548 | 546: 'electric guitar [电吉他]',
549 | 547: 'electric locomotive [电力机车]',
550 | 548: 'entertainment center [电视,电视柜]',
551 | 549: 'envelope [信封]',
552 | 550: 'espresso maker [浓缩咖啡机]',
553 | 551: 'face powder [扑面粉]',
554 | 552: 'feather boa, boa [女用长围巾]',
555 | 553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]',
556 | 554: 'fireboat [消防船]',
557 | 555: 'fire engine, fire truck [消防车]',
558 | 556: 'fire screen, fireguard [火炉栏]',
559 | 557: 'flagpole, flagstaff [旗杆]',
560 | 558: 'flute, transverse flute [长笛]',
561 | 559: 'folding chair [折叠椅]',
562 | 560: 'football helmet [橄榄球头盔]',
563 | 561: 'forklift [叉车]',
564 | 562: 'fountain [喷泉]',
565 | 563: 'fountain pen [钢笔]',
566 | 564: 'four-poster [有四根帷柱的床]',
567 | 565: 'freight car [运货车厢]',
568 | 566: 'French horn, horn [圆号,喇叭]',
569 | 567: 'frying pan, frypan, skillet [煎锅]',
570 | 568: 'fur coat [裘皮大衣]',
571 | 569: 'garbage truck, dustcart [垃圾车]',
572 | 570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]',
573 | 571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]',
574 | 572: 'goblet [高脚杯]',
575 | 573: 'go-kart [卡丁车]',
576 | 574: 'golf ball [高尔夫球]',
577 | 575: 'golfcart, golf cart [高尔夫球车]',
578 | 576: 'gondola [狭长小船]',
579 | 577: 'gong, tam-tam [锣]',
580 | 578: 'gown [礼服]',
581 | 579: 'grand piano, grand [钢琴]',
582 | 580: 'greenhouse, nursery, glasshouse [温室,苗圃]',
583 | 581: 'grille, radiator grille [散热器格栅]',
584 | 582: 'grocery store, grocery, food market, market [杂货店,食品市场]',
585 | 583: 'guillotine [断头台]',
586 | 584: 'hair slide [小发夹]',
587 | 585: 'hair spray [头发喷雾]',
588 | 586: 'half track [半履带装甲车]',
589 | 587: 'hammer [锤子]',
590 | 588: 'hamper [大篮子]',
591 | 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]',
592 | 590: 'hand-held computer, hand-held microcomputer [手提电脑]',
593 | 591: 'handkerchief, hankie, hanky, hankey [手帕]',
594 | 592: 'hard disc, hard disk, fixed disk [硬盘]',
595 | 593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]',
596 | 594: 'harp [竖琴]',
597 | 595: 'harvester, reaper [收割机]',
598 | 596: 'hatchet [斧头]',
599 | 597: 'holster [手枪皮套]',
600 | 598: 'home theater, home theatre [家庭影院]',
601 | 599: 'honeycomb [蜂窝]',
602 | 600: 'hook, claw [钩爪]',
603 | 601: 'hoopskirt, crinoline [衬裙]',
604 | 602: 'horizontal bar, high bar [单杠]',
605 | 603: 'horse cart, horse-cart [马车]',
606 | 604: 'hourglass [沙漏]',
607 | 605: 'iPod [手机,iPad]',
608 | 606: 'iron, smoothing iron [熨斗]',
609 | 607: 'jack-o-lantern [南瓜灯笼]',
610 | 608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]',
611 | 609: 'jeep, landrover [吉普车]',
612 | 610: 'jersey, T-shirt, tee shirt [运动衫,T恤]',
613 | 611: 'jigsaw puzzle [拼图]',
614 | 612: 'jinrikisha, ricksha, rickshaw [人力车]',
615 | 613: 'joystick [操纵杆]',
616 | 614: 'kimono [和服]',
617 | 615: 'knee pad [护膝]',
618 | 616: 'knot [蝴蝶结]',
619 | 617: 'lab coat, laboratory coat [大褂,实验室外套]',
620 | 618: 'ladle [长柄勺]',
621 | 619: 'lampshade, lamp shade [灯罩]',
622 | 620: 'laptop, laptop computer [笔记本电脑]',
623 | 621: 'lawn mower, mower [割草机]',
624 | 622: 'lens cap, lens cover [镜头盖]',
625 | 623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]',
626 | 624: 'library [图书馆]',
627 | 625: 'lifeboat [救生艇]',
628 | 626: 'lighter, light, igniter, ignitor [点火器,打火机]',
629 | 627: 'limousine, limo [豪华轿车]',
630 | 628: 'liner, ocean liner [远洋班轮]',
631 | 629: 'lipstick, lip rouge [唇膏,口红]',
632 | 630: 'Loafer [平底便鞋]',
633 | 631: 'lotion [洗剂]',
634 | 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]',
635 | 633: 'loupe, jewelers loupe [放大镜]',
636 | 634: 'lumbermill, sawmill [锯木厂]',
637 | 635: 'magnetic compass [磁罗盘]',
638 | 636: 'mailbag, postbag [邮袋]',
639 | 637: 'mailbox, letter box [信箱]',
640 | 638: 'maillot [女游泳衣]',
641 | 639: 'maillot, tank suit [有肩带浴衣]',
642 | 640: 'manhole cover [窨井盖]',
643 | 641: 'maraca [沙球(一种打击乐器)]',
644 | 642: 'marimba, xylophone [马林巴木琴]',
645 | 643: 'mask [面膜]',
646 | 644: 'matchstick [火柴]',
647 | 645: 'maypole [花柱]',
648 | 646: 'maze, labyrinth [迷宫]',
649 | 647: 'measuring cup [量杯]',
650 | 648: 'medicine chest, medicine cabinet [药箱]',
651 | 649: 'megalith, megalithic structure [巨石,巨石结构]',
652 | 650: 'microphone, mike [麦克风]',
653 | 651: 'microwave, microwave oven [微波炉]',
654 | 652: 'military uniform [军装]',
655 | 653: 'milk can [奶桶]',
656 | 654: 'minibus [迷你巴士]',
657 | 655: 'miniskirt, mini [迷你裙]',
658 | 656: 'minivan [面包车]',
659 | 657: 'missile [导弹]',
660 | 658: 'mitten [连指手套]',
661 | 659: 'mixing bowl [搅拌钵]',
662 | 660: 'mobile home, manufactured home [活动房屋(由汽车拖拉的)]',
663 | 661: 'Model T [T型发动机小汽车]',
664 | 662: 'modem [调制解调器]',
665 | 663: 'monastery [修道院]',
666 | 664: 'monitor [显示器]',
667 | 665: 'moped [电瓶车]',
668 | 666: 'mortar [砂浆]',
669 | 667: 'mortarboard [学士]',
670 | 668: 'mosque [清真寺]',
671 | 669: 'mosquito net [蚊帐]',
672 | 670: 'motor scooter, scooter [摩托车]',
673 | 671: 'mountain bike, all-terrain bike, off-roader [山地自行车]',
674 | 672: 'mountain tent [登山帐]',
675 | 673: 'mouse, computer mouse [鼠标,电脑鼠标]',
676 | 674: 'mousetrap [捕鼠器]',
677 | 675: 'moving van [搬家车]',
678 | 676: 'muzzle [口套]',
679 | 677: 'nail [钉子]',
680 | 678: 'neck brace [颈托]',
681 | 679: 'necklace [项链]',
682 | 680: 'nipple [乳头(瓶)]',
683 | 681: 'notebook, notebook computer [笔记本,笔记本电脑]',
684 | 682: 'obelisk [方尖碑]',
685 | 683: 'oboe, hautboy, hautbois [双簧管]',
686 | 684: 'ocarina, sweet potato [陶笛,卵形笛]',
687 | 685: 'odometer, hodometer, mileometer, milometer [里程表]',
688 | 686: 'oil filter [滤油器]',
689 | 687: 'organ, pipe organ [风琴,管风琴]',
690 | 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]',
691 | 689: 'overskirt [罩裙]',
692 | 690: 'oxcart [牛车]',
693 | 691: 'oxygen mask [氧气面罩]',
694 | 692: 'packet [包装]',
695 | 693: 'paddle, boat paddle [船桨]',
696 | 694: 'paddlewheel, paddle wheel [明轮,桨轮]',
697 | 695: 'padlock [挂锁,扣锁]',
698 | 696: 'paintbrush [画笔]',
699 | 697: 'pajama, pyjama, pjs, jammies [睡衣]',
700 | 698: 'palace [宫殿]',
701 | 699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]',
702 | 700: 'paper towel [纸巾]',
703 | 701: 'parachute, chute [降落伞]',
704 | 702: 'parallel bars, bars [双杠]',
705 | 703: 'park bench [公园长椅]',
706 | 704: 'parking meter [停车收费表,停车计时器]',
707 | 705: 'passenger car, coach, carriage [客车,教练车]',
708 | 706: 'patio, terrace [露台,阳台]',
709 | 707: 'pay-phone, pay-station [付费电话]',
710 | 708: 'pedestal, plinth, footstall [基座,基脚]',
711 | 709: 'pencil box, pencil case [铅笔盒]',
712 | 710: 'pencil sharpener [卷笔刀]',
713 | 711: 'perfume, essence [香水(瓶)]',
714 | 712: 'Petri dish [培养皿]',
715 | 713: 'photocopier [复印机]',
716 | 714: 'pick, plectrum, plectron [拨弦片,拨子]',
717 | 715: 'pickelhaube [尖顶头盔]',
718 | 716: 'picket fence, paling [栅栏,栅栏]',
719 | 717: 'pickup, pickup truck [皮卡,皮卡车]',
720 | 718: 'pier [桥墩]',
721 | 719: 'piggy bank, penny bank [存钱罐]',
722 | 720: 'pill bottle [药瓶]',
723 | 721: 'pillow [枕头]',
724 | 722: 'ping-pong ball [乒乓球]',
725 | 723: 'pinwheel [风车]',
726 | 724: 'pirate, pirate ship [海盗船]',
727 | 725: 'pitcher, ewer [水罐]',
728 | 726: 'plane, carpenters plane, woodworking plane [木工刨]',
729 | 727: 'planetarium [天文馆]',
730 | 728: 'plastic bag [塑料袋]',
731 | 729: 'plate rack [板架]',
732 | 730: 'plow, plough [犁型铲雪机]',
733 | 731: 'plunger, plumbers helper [手压皮碗泵]',
734 | 732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]',
735 | 733: 'pole [电线杆]',
736 | 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]',
737 | 735: 'poncho [雨披]',
738 | 736: 'pool table, billiard table, snooker table [台球桌]',
739 | 737: 'pop bottle, soda bottle [充气饮料瓶]',
740 | 738: 'pot, flowerpot [花盆]',
741 | 739: 'potters wheel [陶工旋盘]',
742 | 740: 'power drill [电钻]',
743 | 741: 'prayer rug, prayer mat [祈祷垫,地毯]',
744 | 742: 'printer [打印机]',
745 | 743: 'prison, prison house [监狱]',
746 | 744: 'projectile, missile [炮弹,导弹]',
747 | 745: 'projector [投影仪]',
748 | 746: 'puck, hockey puck [冰球]',
749 | 747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]',
750 | 748: 'purse [钱包]',
751 | 749: 'quill, quill pen [羽管笔]',
752 | 750: 'quilt, comforter, comfort, puff [被子]',
753 | 751: 'racer, race car, racing car [赛车]',
754 | 752: 'racket, racquet [球拍]',
755 | 753: 'radiator [散热器]',
756 | 754: 'radio, wireless [收音机]',
757 | 755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]',
758 | 756: 'rain barrel [雨桶]',
759 | 757: 'recreational vehicle, RV, R.V. [休闲车,房车]',
760 | 758: 'reel [卷轴,卷筒]',
761 | 759: 'reflex camera [反射式照相机]',
762 | 760: 'refrigerator, icebox [冰箱,冰柜]',
763 | 761: 'remote control, remote [遥控器]',
764 | 762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]',
765 | 763: 'revolver, six-gun, six-shooter [左轮手枪]',
766 | 764: 'rifle [步枪]',
767 | 765: 'rocking chair, rocker [摇椅]',
768 | 766: 'rotisserie [电转烤肉架]',
769 | 767: 'rubber eraser, rubber, pencil eraser [橡皮]',
770 | 768: 'rugby ball [橄榄球]',
771 | 769: 'rule, ruler [直尺]',
772 | 770: 'running shoe [跑步鞋]',
773 | 771: 'safe [保险柜]',
774 | 772: 'safety pin [安全别针]',
775 | 773: 'saltshaker, salt shaker [盐瓶(调味用)]',
776 | 774: 'sandal [凉鞋]',
777 | 775: 'sarong [纱笼,围裙]',
778 | 776: 'sax, saxophone [萨克斯管]',
779 | 777: 'scabbard [剑鞘]',
780 | 778: 'scale, weighing machine [秤,称重机]',
781 | 779: 'school bus [校车]',
782 | 780: 'schooner [帆船]',
783 | 781: 'scoreboard [记分牌]',
784 | 782: 'screen, CRT screen [屏幕]',
785 | 783: 'screw [螺丝]',
786 | 784: 'screwdriver [螺丝刀]',
787 | 785: 'seat belt, seatbelt [安全带]',
788 | 786: 'sewing machine [缝纫机]',
789 | 787: 'shield, buckler [盾牌,盾牌]',
790 | 788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]',
791 | 789: 'shoji [障子]',
792 | 790: 'shopping basket [购物篮]',
793 | 791: 'shopping cart [购物车]',
794 | 792: 'shovel [铁锹]',
795 | 793: 'shower cap [浴帽]',
796 | 794: 'shower curtain [浴帘]',
797 | 795: 'ski [滑雪板]',
798 | 796: 'ski mask [滑雪面罩]',
799 | 797: 'sleeping bag [睡袋]',
800 | 798: 'slide rule, slipstick [滑尺]',
801 | 799: 'sliding door [滑动门]',
802 | 800: 'slot, one-armed bandit [角子老虎机]',
803 | 801: 'snorkel [潜水通气管]',
804 | 802: 'snowmobile [雪橇]',
805 | 803: 'snowplow, snowplough [扫雪机,扫雪机]',
806 | 804: 'soap dispenser [皂液器]',
807 | 805: 'soccer ball [足球]',
808 | 806: 'sock [袜子]',
809 | 807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]',
810 | 808: 'sombrero [宽边帽]',
811 | 809: 'soup bowl [汤碗]',
812 | 810: 'space bar [空格键]',
813 | 811: 'space heater [空间加热器]',
814 | 812: 'space shuttle [航天飞机]',
815 | 813: 'spatula [铲(搅拌或涂敷用的)]',
816 | 814: 'speedboat [快艇]',
817 | 815: 'spider web, spiders web [蜘蛛网]',
818 | 816: 'spindle [纺锤,纱锭]',
819 | 817: 'sports car, sport car [跑车]',
820 | 818: 'spotlight, spot [聚光灯]',
821 | 819: 'stage [舞台]',
822 | 820: 'steam locomotive [蒸汽机车]',
823 | 821: 'steel arch bridge [钢拱桥]',
824 | 822: 'steel drum [钢滚筒]',
825 | 823: 'stethoscope [听诊器]',
826 | 824: 'stole [女用披肩]',
827 | 825: 'stone wall [石头墙]',
828 | 826: 'stopwatch, stop watch [秒表]',
829 | 827: 'stove [火炉]',
830 | 828: 'strainer [过滤器]',
831 | 829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]',
832 | 830: 'stretcher [担架]',
833 | 831: 'studio couch, day bed [沙发床]',
834 | 832: 'stupa, tope [佛塔]',
835 | 833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]',
836 | 834: 'suit, suit of clothes [套装,衣服]',
837 | 835: 'sundial [日晷]',
838 | 836: 'sunglass [太阳镜]',
839 | 837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]',
840 | 838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]',
841 | 839: 'suspension bridge [悬索桥]',
842 | 840: 'swab, swob, mop [拖把]',
843 | 841: 'sweatshirt [运动衫]',
844 | 842: 'swimming trunks, bathing trunks [游泳裤]',
845 | 843: 'swing [秋千]',
846 | 844: 'switch, electric switch, electrical switch [开关,电器开关]',
847 | 845: 'syringe [注射器]',
848 | 846: 'table lamp [台灯]',
849 | 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle [坦克,装甲战车,装甲战斗车辆]',
850 | 848: 'tape player [磁带播放器]',
851 | 849: 'teapot [茶壶]',
852 | 850: 'teddy, teddy bear [泰迪,泰迪熊]',
853 | 851: 'television, television system [电视]',
854 | 852: 'tennis ball [网球]',
855 | 853: 'thatch, thatched roof [茅草,茅草屋顶]',
856 | 854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]',
857 | 855: 'thimble [顶针]',
858 | 856: 'thresher, thrasher, threshing machine [脱粒机]',
859 | 857: 'throne [宝座]',
860 | 858: 'tile roof [瓦屋顶]',
861 | 859: 'toaster [烤面包机]',
862 | 860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]',
863 | 861: 'toilet seat [马桶]',
864 | 862: 'torch [火炬]',
865 | 863: 'totem pole [图腾柱]',
866 | 864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]',
867 | 865: 'toyshop [玩具店]',
868 | 866: 'tractor [拖拉机]',
869 | 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]',
870 | 868: 'tray [托盘]',
871 | 869: 'trench coat [风衣]',
872 | 870: 'tricycle, trike, velocipede [三轮车]',
873 | 871: 'trimaran [三体船]',
874 | 872: 'tripod [三脚架]',
875 | 873: 'triumphal arch [凯旋门]',
876 | 874: 'trolleybus, trolley coach, trackless trolley [无轨电车]',
877 | 875: 'trombone [长号]',
878 | 876: 'tub, vat [浴盆,浴缸]',
879 | 877: 'turnstile [旋转式栅门]',
880 | 878: 'typewriter keyboard [打字机键盘]',
881 | 879: 'umbrella [伞]',
882 | 880: 'unicycle, monocycle [独轮车]',
883 | 881: 'upright, upright piano [直立式钢琴]',
884 | 882: 'vacuum, vacuum cleaner [真空吸尘器]',
885 | 883: 'vase [花瓶]',
886 | 884: 'vault [拱顶]',
887 | 885: 'velvet [天鹅绒]',
888 | 886: 'vending machine [自动售货机]',
889 | 887: 'vestment [祭服]',
890 | 888: 'viaduct [高架桥]',
891 | 889: 'violin, fiddle [小提琴,小提琴]',
892 | 890: 'volleyball [排球]',
893 | 891: 'waffle iron [松饼机]',
894 | 892: 'wall clock [挂钟]',
895 | 893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]',
896 | 894: 'wardrobe, closet, press [衣柜,壁橱]',
897 | 895: 'warplane, military plane [军用飞机]',
898 | 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]',
899 | 897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]',
900 | 898: 'water bottle [水瓶]',
901 | 899: 'water jug [水壶]',
902 | 900: 'water tower [水塔]',
903 | 901: 'whiskey jug [威士忌壶]',
904 | 902: 'whistle [哨子]',
905 | 903: 'wig [假发]',
906 | 904: 'window screen [纱窗]',
907 | 905: 'window shade [百叶窗]',
908 | 906: 'Windsor tie [温莎领带]',
909 | 907: 'wine bottle [葡萄酒瓶]',
910 | 908: 'wing [飞机翅膀,飞机]',
911 | 909: 'wok [炒菜锅]',
912 | 910: 'wooden spoon [木制的勺子]',
913 | 911: 'wool, woolen, woollen [毛织品,羊绒]',
914 | 912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]',
915 | 913: 'wreck [沉船]',
916 | 914: 'yawl [双桅船]',
917 | 915: 'yurt [蒙古包]',
918 | 916: 'web site, website, internet site, site [网站,互联网网站]',
919 | 917: 'comic book [漫画]',
920 | 918: 'crossword puzzle, crossword [纵横字谜]',
921 | 919: 'street sign [路标]',
922 | 920: 'traffic light, traffic signal, stoplight [交通信号灯]',
923 | 921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]',
924 | 922: 'menu [菜单]',
925 | 923: 'plate [盘子]',
926 | 924: 'guacamole [鳄梨酱]',
927 | 925: 'consomme [清汤]',
928 | 926: 'hot pot, hotpot [罐焖土豆烧肉]',
929 | 927: 'trifle [蛋糕]',
930 | 928: 'ice cream, icecream [冰淇淋]',
931 | 929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]',
932 | 930: 'French loaf [法式面包]',
933 | 931: 'bagel, beigel [百吉饼]',
934 | 932: 'pretzel [椒盐脆饼]',
935 | 933: 'cheeseburger [芝士汉堡]',
936 | 934: 'hotdog, hot dog, red hot [热狗]',
937 | 935: 'mashed potato [土豆泥]',
938 | 936: 'head cabbage [结球甘蓝]',
939 | 937: 'broccoli [西兰花]',
940 | 938: 'cauliflower [菜花]',
941 | 939: 'zucchini, courgette [绿皮密生西葫芦]',
942 | 940: 'spaghetti squash [西葫芦]',
943 | 941: 'acorn squash [小青南瓜]',
944 | 942: 'butternut squash [南瓜]',
945 | 943: 'cucumber, cuke [黄瓜]',
946 | 944: 'artichoke, globe artichoke [朝鲜蓟]',
947 | 945: 'bell pepper [甜椒]',
948 | 946: 'cardoon [刺棘蓟]',
949 | 947: 'mushroom [蘑菇]',
950 | 948: 'Granny Smith [绿苹果]',
951 | 949: 'strawberry [草莓]',
952 | 950: 'orange [橘子]',
953 | 951: 'lemon [柠檬]',
954 | 952: 'fig [无花果]',
955 | 953: 'pineapple, ananas [菠萝]',
956 | 954: 'banana [香蕉]',
957 | 955: 'jackfruit, jak, jack [菠萝蜜]',
958 | 956: 'custard apple [蛋奶冻苹果]',
959 | 957: 'pomegranate [石榴]',
960 | 958: 'hay [干草]',
961 | 959: 'carbonara [烤面条加干酪沙司]',
962 | 960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]',
963 | 961: 'dough [面团]',
964 | 962: 'meat loaf, meatloaf [瑞士肉包,肉饼]',
965 | 963: 'pizza, pizza pie [披萨,披萨饼]',
966 | 964: 'potpie [馅饼]',
967 | 965: 'burrito [卷饼]',
968 | 966: 'red wine [红葡萄酒]',
969 | 967: 'espresso [意大利浓咖啡]',
970 | 968: 'cup [杯子]',
971 | 969: 'eggnog [蛋酒]',
972 | 970: 'alp [高山]',
973 | 971: 'bubble [泡泡]',
974 | 972: 'cliff, drop, drop-off [悬崖]',
975 | 973: 'coral reef [珊瑚礁]',
976 | 974: 'geyser [间歇泉]',
977 | 975: 'lakeside, lakeshore [湖边,湖岸]',
978 | 976: 'promontory, headland, head, foreland [海角]',
979 | 977: 'sandbar, sand bar [沙洲,沙坝]',
980 | 978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]',
981 | 979: 'valley, vale [峡谷]',
982 | 980: 'volcano [火山]',
983 | 981: 'ballplayer, baseball player [棒球,棒球运动员]',
984 | 982: 'groom, bridegroom [新郎]',
985 | 983: 'scuba diver [潜水员]',
986 | 984: 'rapeseed [油菜]',
987 | 985: 'daisy [雏菊]',
988 | 986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]',
989 | 987: 'corn [玉米]',
990 | 988: 'acorn [橡子]',
991 | 989: 'hip, rose hip, rosehip [玫瑰果]',
992 | 990: 'buckeye, horse chestnut, conker [七叶树果实]',
993 | 991: 'coral fungus [珊瑚菌]',
994 | 992: 'agaric [木耳]',
995 | 993: 'gyromitra [鹿花菌]',
996 | 994: 'stinkhorn, carrion fungus [鬼笔菌]',
997 | 995: 'earthstar [地星(菌类)]',
998 | 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]',
999 | 997: 'bolete [牛肝菌]',
1000 | 998: 'ear, spike, capitulum [玉米穗]',
1001 | 999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]',
1002 | }
1003 |
--------------------------------------------------------------------------------
/pixelflow/data_in1k.py:
--------------------------------------------------------------------------------
1 | # ImageNet-1K Dataset and DataLoader
2 |
3 | from einops import rearrange
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data.distributed import DistributedSampler
7 | from torchvision.datasets import ImageFolder
8 | from torchvision import transforms
9 | from PIL import Image
10 | import math
11 | from functools import partial
12 | import numpy as np
13 | import random
14 |
15 | from diffusers.models.embeddings import get_2d_rotary_pos_embed
16 |
17 | # https://github.com/facebookresearch/DiT/blob/main/train.py#L85
18 | def center_crop_arr(pil_image, image_size):
19 |     while min(*pil_image.size) >= 2 * image_size:
20 |         pil_image = pil_image.resize(
21 |             tuple(x // 2 for x in pil_image.size), resample=Image.BOX
22 |         )
23 |
24 |     scale = image_size / min(*pil_image.size)
25 |     pil_image = pil_image.resize(
26 |         tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
27 |     )
28 |
29 |     arr = np.array(pil_image)
30 |     crop_y = (arr.shape[0] - image_size) // 2
31 |     crop_x = (arr.shape[1] - image_size) // 2
32 |     return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
33 |
34 |
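# Editorial sketch: center_crop_arr is typically composed into the ImageFolder
# transform pipeline as in DiT's train.py; the exact composition used by this
# repository's dataloader is an assumption, since that code is not shown here:
#
#     transform = transforms.Compose([
#         transforms.Lambda(partial(center_crop_arr, image_size=256)),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
#     ])
#     dataset = ImageFolder("/path/to/ILSVRC2012/train", transform=transform)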
35 | def collate_fn(examples, config, noise_scheduler_copy):
36 | patch_size = config.model.params.patch_size
37 | pixel_values = torch.stack([eg[0] for eg in examples])
38 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
39 | input_ids = [eg[1] for eg in examples]
40 |
41 | batch_size = len(examples)
42 | stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
43 | stage_indices = stage_indices[:batch_size]
44 |
45 | random.shuffle(stage_indices)
46 | stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
47 | orig_height, orig_width = pixel_values.shape[-2:]
48 | timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))
49 |
50 | sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
51 | for stage_idx in range(config.scheduler.num_stages):
52 |         corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1  # map downsampling index to scheduler stage (stage 0 = noisiest / lowest resolution)
53 | stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
54 | Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
55 | batch_size_select = Timesteps.shape[0]
56 | pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
57 | input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]
58 |
59 | end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)
60 |
61 | ################ build model input ################
62 | start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]
63 |
64 | pixel_values_end = pixel_values_select
65 | pixel_values_start = pixel_values_select
66 | if stage_idx > 0:
67 | # pixel_values_end
68 | for downsample_idx in range(1, stage_idx + 1):
69 | pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
70 |
71 | # pixel_values_start
72 | for downsample_idx in range(1, stage_idx + 2):
73 | pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
74 | # upsample pixel_values_start
75 | pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")
76 |
77 | noise = torch.randn_like(pixel_values_end)
78 | pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
79 | pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
80 | target = pixel_values_end - pixel_values_start
81 |
82 | t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
83 | while len(t_select.shape) < pixel_values_start.ndim:
84 | t_select = t_select.unsqueeze(-1)
85 | xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
86 |
87 | target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
88 | xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
89 |
90 | pos_embed = get_2d_rotary_pos_embed(
91 | embed_dim=config.model.params.attention_head_dim,
92 | crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
93 | grid_size=(end_height // patch_size, end_width // patch_size),
94 | )
95 | seq_len = (end_height // patch_size) * (end_width // patch_size)
96 |         assert end_height == end_width, f"only square images are supported, got {end_height}x{end_width}; TODO: latent_size_list"
97 | sample_list.append(xt)
98 | target_list.append(target)
99 | pos_embed_list.extend([pos_embed] * batch_size_select)
100 | seq_len_list.extend([seq_len] * batch_size_select)
101 | timestep_list.append(Timesteps)
102 | input_ids_list.extend(input_ids_select)
103 |
104 | pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
105 | target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
106 | pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
107 | cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
108 | latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)
109 |
110 | return {
111 | "pixel_values": pixel_values,
112 | "input_ids": input_ids_list,
113 | "pos_embed": pos_embed,
114 | "cumsum_q_len": cumsum_q_len,
115 | "batch_latent_size": latent_size_list,
116 | "seqlen_list_q": seq_len_list,
117 | "cumsum_kv_len": None,
118 | "batch_kv_len": None,
119 | "timesteps": torch.cat(timestep_list, dim=0),
120 | "target_values": target_values,
121 | }
122 |
123 |
124 | def build_imagenet_loader(config, noise_scheduler_copy):
125 | if config.data.center_crop:
126 | transform = transforms.Compose([
127 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)),
128 | transforms.RandomHorizontalFlip(),
129 | transforms.ToTensor(),
130 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
131 | ])
132 | else:
133 | transform = transforms.Compose([
134 | transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
135 | transforms.RandomCrop(config.data.resolution),
136 | transforms.RandomHorizontalFlip(),
137 | transforms.ToTensor(),
138 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
139 | ])
140 | dataset = ImageFolder(config.data.root, transform=transform)
141 | sampler = DistributedSampler(
142 | dataset,
143 | num_replicas=torch.distributed.get_world_size(),
144 | rank=torch.distributed.get_rank(),
145 | shuffle=True,
146 | seed=config.seed,
147 | )
148 |
149 | loader = torch.utils.data.DataLoader(
150 | dataset,
151 | batch_size=config.data.batch_size,
152 | collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
153 | shuffle=False,
154 | sampler=sampler,
155 | num_workers=config.data.num_workers,
156 | drop_last=True,
157 | )
158 | return loader, sampler
159 |
--------------------------------------------------------------------------------
/pixelflow/model.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
7 | import warnings
8 |
9 | try:
10 | from flash_attn import flash_attn_varlen_func
11 | except ImportError:
12 |     warnings.warn("`flash-attn` is not installed; training (which relies on sequence packing) will not work without it.", UserWarning)
13 | flash_attn_varlen_func = None
14 |
15 |
16 | def apply_rotary_emb(
17 | x: torch.Tensor,
18 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
19 | ) -> torch.Tensor:
20 | cos, sin = freqs_cis.unbind(-1)
21 | cos = cos[None, None]
22 | sin = sin[None, None]
23 | cos, sin = cos.to(x.device), sin.to(x.device)
24 |
25 | x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
26 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
27 | out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
28 |
29 | return out
30 |
31 |
32 | class PatchEmbed(nn.Module):
33 | def __init__(self, patch_size, in_channels, embed_dim, bias=True):
34 | super().__init__()
35 | self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
36 |
37 | def forward_unfold(self, x):
38 | out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t())
39 | if self.proj.bias is not None:
40 | out_unfold += self.proj.bias.to(out_unfold.dtype)
41 | return out_unfold
42 |
43 | # force fp32 for strict numerical reproducibility (debug only)
44 | # @torch.autocast('cuda', enabled=False)
45 | def forward(self, x):
46 | if self.training:
47 | return self.forward_unfold(x)
48 | out = self.proj(x)
49 | out = out.flatten(2).transpose(1, 2) # BCHW -> BNC
50 |
51 | return out
52 |
53 | class AdaLayerNorm(nn.Module):
54 | def __init__(self, embedding_dim):
55 | super().__init__()
56 | self.embedding_dim = embedding_dim
57 | self.silu = nn.SiLU()
58 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
59 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
60 |
61 | def forward(self, x, timestep, seqlen_list=None):
62 | input_dtype = x.dtype
63 | emb = self.linear(self.silu(timestep))
64 |
65 | if seqlen_list is not None:
66 | # equivalent to `torch.repeat_interleave` but faster
67 | emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)])
68 | else:
69 | emb = emb.unsqueeze(1)
70 |
71 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1)
72 | x = self.norm(x).float() * (1 + scale_msa) + shift_msa
73 | return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
74 |
75 |
76 | class FeedForward(nn.Module):
77 | def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
78 | super().__init__()
79 | inner_dim = int(dim * mult) if inner_dim is None else inner_dim
80 | dim_out = dim_out if dim_out is not None else dim
81 | self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
82 | self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)
83 |
84 | def forward(self, hidden_states):
85 | hidden_states = self.fc1(hidden_states)
86 | hidden_states = F.gelu(hidden_states, approximate="tanh")
87 | hidden_states = self.fc2(hidden_states)
88 | return hidden_states
89 |
90 |
91 | class RMSNorm(nn.Module):
92 | def __init__(self, dim: int, eps=1e-6):
93 | super().__init__()
94 | self.weight = nn.Parameter(torch.ones(dim))
95 | self.eps = eps
96 |
97 | def forward(self, x):
98 | output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
99 | return (self.weight * output).to(x.dtype)
100 |
101 |
102 | class Attention(nn.Module):
103 | def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
104 | super().__init__()
105 | self.q_dim = q_dim
106 | self.kv_dim = kv_dim if kv_dim is not None else q_dim
107 | self.inner_dim = head_dim * heads
108 | self.dropout = dropout
109 | self.head_dim = head_dim
110 | self.num_heads = heads
111 |
112 | self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
113 | self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
114 | self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
115 |
116 | self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)
117 |
118 | self.q_norm = RMSNorm(self.inner_dim)
119 | self.k_norm = RMSNorm(self.inner_dim)
120 |
121 | def prepare_attention_mask(
122 | # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694
123 | self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
124 | ):
125 | head_size = self.num_heads
126 | if attention_mask is None:
127 | return attention_mask
128 |
129 | current_length: int = attention_mask.shape[-1]
130 | if current_length != target_length:
131 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
132 |
133 | if out_dim == 3:
134 | if attention_mask.shape[0] < batch_size * head_size:
135 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
136 | elif out_dim == 4:
137 | attention_mask = attention_mask.unsqueeze(1)
138 | attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
139 |
140 | return attention_mask
141 |
142 | def forward(
143 | self,
144 | inputs_q,
145 | inputs_kv,
146 | attention_mask=None,
147 | cross_attention=False,
148 | rope_pos_embed=None,
149 | cu_seqlens_q=None,
150 | cu_seqlens_k=None,
151 | max_seqlen_q=None,
152 | max_seqlen_k=None,
153 | ):
154 |
155 | inputs_kv = inputs_q if inputs_kv is None else inputs_kv
156 |
157 | query_states = self.q_proj(inputs_q)
158 | key_states = self.k_proj(inputs_kv)
159 | value_states = self.v_proj(inputs_kv)
160 |
161 | query_states = self.q_norm(query_states)
162 | key_states = self.k_norm(key_states)
163 |
164 | if max_seqlen_q is None:
165 | assert not self.training, "PixelFlow needs sequence packing for training"
166 |
167 | bsz, q_len, _ = inputs_q.shape
168 | _, kv_len, _ = inputs_kv.shape
169 |
170 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
171 | key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
172 | value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
173 |
174 | query_states = apply_rotary_emb(query_states, rope_pos_embed)
175 | if not cross_attention:
176 | key_states = apply_rotary_emb(key_states, rope_pos_embed)
177 |
178 | if attention_mask is not None:
179 | attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
180 | # scaled_dot_product_attention expects attention_mask shape to be
181 | # (batch, heads, source_length, target_length)
182 | attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])
183 |
184 | # with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]): # strict numerical reproducibility (debug only)
185 | attn_output = F.scaled_dot_product_attention(
186 | query_states,
187 | key_states,
188 | value_states,
189 | attn_mask=attention_mask,
190 | dropout_p=self.dropout if self.training else 0.0,
191 | is_causal=False,
192 | )
193 |
194 | attn_output = attn_output.transpose(1, 2).contiguous()
195 | attn_output = attn_output.view(bsz, q_len, self.inner_dim)
196 | attn_output = self.o_proj(attn_output)
197 | return attn_output
198 |
199 | else:
200 | # sequence packing mode
201 | query_states = query_states.view(-1, self.num_heads, self.head_dim)
202 | key_states = key_states.view(-1, self.num_heads, self.head_dim)
203 | value_states = value_states.view(-1, self.num_heads, self.head_dim)
204 |
205 | query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
206 | if not cross_attention:
207 | key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
208 |
209 | attn_output = flash_attn_varlen_func(
210 | query_states,
211 | key_states,
212 | value_states,
213 | cu_seqlens_q=cu_seqlens_q,
214 | cu_seqlens_k=cu_seqlens_k,
215 | max_seqlen_q=max_seqlen_q,
216 | max_seqlen_k=max_seqlen_k,
217 | )
218 |
219 | attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
220 | attn_output = self.o_proj(attn_output)
221 | return attn_output
222 |
223 |
224 | class TransformerBlock(nn.Module):
225 | def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0,
226 | cross_attention_dim=None, attention_bias=False,
227 | ):
228 | super().__init__()
229 | self.norm1 = AdaLayerNorm(dim)
230 |
231 | # Self Attention
232 | self.attn1 = Attention(q_dim=dim, kv_dim=None, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
233 |
234 | if cross_attention_dim is not None:
235 | # Cross Attention
236 | self.norm2 = RMSNorm(dim, eps=1e-6)
237 | self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
238 | else:
239 | self.attn2 = None
240 |
241 | self.norm3 = RMSNorm(dim, eps=1e-6)
242 | self.mlp = FeedForward(dim)
243 |
244 | def forward(
245 | self,
246 | hidden_states,
247 | encoder_hidden_states=None,
248 | encoder_attention_mask=None,
249 | timestep=None,
250 | rope_pos_embed=None,
251 | cu_seqlens_q=None,
252 | cu_seqlens_k=None,
253 | seqlen_list_q=None,
254 | seqlen_list_k=None,
255 | ):
256 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q)
257 |
258 | attn_output = self.attn1(
259 | inputs_q=norm_hidden_states,
260 | inputs_kv=None,
261 | attention_mask=None,
262 | cross_attention=False,
263 | rope_pos_embed=rope_pos_embed,
264 | cu_seqlens_q=cu_seqlens_q,
265 | cu_seqlens_k=cu_seqlens_q,
266 | max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
267 | max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None,
268 | )
269 |
270 | attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype)
271 | hidden_states = attn_output + hidden_states
272 |
273 | if self.attn2 is not None:
274 | norm_hidden_states = self.norm2(hidden_states)
275 | attn_output = self.attn2(
276 | inputs_q=norm_hidden_states,
277 | inputs_kv=encoder_hidden_states,
278 | attention_mask=encoder_attention_mask,
279 | cross_attention=True,
280 | rope_pos_embed=rope_pos_embed,
281 | cu_seqlens_q=cu_seqlens_q,
282 | cu_seqlens_k=cu_seqlens_k,
283 | max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
284 | max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None,
285 | )
286 | hidden_states = hidden_states + attn_output
287 |
288 | norm_hidden_states = self.norm3(hidden_states)
289 | norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype)
290 | ff_output = self.mlp(norm_hidden_states)
291 | ff_output = (gate_mlp * ff_output.float()).to(ff_output.dtype)
292 | hidden_states = ff_output + hidden_states
293 |
294 | return hidden_states
295 |
296 |
297 | class PixelFlowModel(torch.nn.Module):
298 | def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim,
299 | depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0,
300 | ):
301 | super().__init__()
302 | self.patch_size = patch_size
303 | self.attention_head_dim = attention_head_dim
304 | self.num_classes = num_classes
305 | self.out_channels = out_channels
306 |
307 | embed_dim = num_attention_heads * attention_head_dim
308 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
309 |
310 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
311 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
312 |
313 | # [stage] embedding
314 | self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
315 | if self.num_classes > 0:
316 | # class conditional
317 | self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)
318 |
319 | self.transformer_blocks = nn.ModuleList([
320 | TransformerBlock(embed_dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, attention_bias) for _ in range(depth)
321 | ])
322 |
323 | self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
324 | self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
325 | self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
326 |
327 | self.initialize_from_scratch()
328 |
329 | def initialize_from_scratch(self):
330 | print("Starting Initialization...")
331 | def _basic_init(module):
332 | if isinstance(module, nn.Linear):
333 | torch.nn.init.xavier_uniform_(module.weight)
334 | if module.bias is not None:
335 | nn.init.constant_(module.bias, 0)
336 | self.apply(_basic_init)
337 |
338 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
339 | w = self.patch_embed.proj.weight.data
340 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
341 | nn.init.constant_(self.patch_embed.proj.bias, 0)
342 |
343 | nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02)
344 | nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02)
345 |
346 | nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02)
347 | nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02)
348 |
349 | if self.num_classes > 0:
350 | nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)
351 |
352 | for block in self.transformer_blocks:
353 | nn.init.constant_(block.norm1.linear.weight, 0)
354 | nn.init.constant_(block.norm1.linear.bias, 0)
355 |
356 | nn.init.constant_(self.proj_out_1.weight, 0)
357 | nn.init.constant_(self.proj_out_1.bias, 0)
358 | nn.init.constant_(self.proj_out_2.weight, 0)
359 | nn.init.constant_(self.proj_out_2.bias, 0)
360 |
361 | def forward(
362 | self,
363 | hidden_states,
364 | encoder_hidden_states=None,
365 | class_labels=None,
366 | timestep=None,
367 | latent_size=None,
368 | encoder_attention_mask=None,
369 | pos_embed=None,
370 | cu_seqlens_q=None,
371 | cu_seqlens_k=None,
372 | seqlen_list_q=None,
373 | seqlen_list_k=None,
374 | ):
375 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
376 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
377 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
378 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
379 |
380 | orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1]
381 | hidden_states = hidden_states.to(torch.float32)
382 | hidden_states = self.patch_embed(hidden_states)
383 |
384 | # timestep, class_embed, latent_size_embed
385 | timesteps_proj = self.time_proj(timestep)
386 | conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
387 |
388 | if self.num_classes > 0:
389 | class_embed = self.class_embedder(class_labels)
390 | conditioning += class_embed
391 |
392 | latent_size_proj = self.time_proj(latent_size)
393 | latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype))
394 | conditioning += latent_size_embed
395 |
396 | for block in self.transformer_blocks:
397 | hidden_states = block(
398 | hidden_states,
399 | encoder_hidden_states=encoder_hidden_states,
400 | encoder_attention_mask=encoder_attention_mask,
401 | timestep=conditioning,
402 | rope_pos_embed=pos_embed,
403 | cu_seqlens_q=cu_seqlens_q,
404 | cu_seqlens_k=cu_seqlens_k,
405 | seqlen_list_q=seqlen_list_q,
406 | seqlen_list_k=seqlen_list_k,
407 | )
408 |
409 | shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
410 | if seqlen_list_q is None:
411 | shift = shift.unsqueeze(1)
412 | scale = scale.unsqueeze(1)
413 | else:
414 | shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)])
415 | scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)])
416 |
417 | hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype)
418 | hidden_states = self.proj_out_2(hidden_states)
419 | if self.training:
420 | hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels)
421 | hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1)
422 | return hidden_states
423 |
424 | height, width = orig_height // self.patch_size, orig_width // self.patch_size
425 | hidden_states = hidden_states.reshape(
426 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
427 | )
428 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
429 | output = hidden_states.reshape(
430 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
431 | )
432 |
433 | return output
434 |
435 | def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale):
436 | # used for evaluation with ODE ('dopri5') solver from torchdiffeq
437 | half = hidden_states[: len(hidden_states)//2]
438 | combined = torch.cat([half, half], dim=0)
439 | out = self.forward(
440 | hidden_states=combined,
441 | timestep=timestep,
442 | class_labels=class_labels,
443 | latent_size=latent_size,
444 | pos_embed=pos_embed,
445 | )
446 | uncond_out, cond_out = torch.split(out, len(out)//2, dim=0)
447 | half_output = uncond_out + cfg_scale * (cond_out - uncond_out)
448 | output = torch.cat([half_output, half_output], dim=0)
449 | return output
450 |
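451 |
452 | if __name__ == "__main__":
453 |     # Minimal inference smoke test (a sketch; the tiny hyperparameters below
454 |     # are illustrative, not the released C2I configuration).
455 |     from diffusers.models.embeddings import get_2d_rotary_pos_embed
456 |     model = PixelFlowModel(in_channels=3, out_channels=3, num_attention_heads=2,
457 |                            attention_head_dim=64, depth=2, patch_size=2, num_classes=10).eval()
458 |     x = torch.randn(1, 3, 32, 32)
459 |     grid = x.shape[-1] // model.patch_size
460 |     rope_pos = torch.stack(get_2d_rotary_pos_embed(
461 |         model.attention_head_dim, ((0, 0), (grid, grid)), (grid, grid)), -1)
462 |     out = model(x, class_labels=torch.tensor([3]), timestep=torch.tensor([500.0]),
463 |                 latent_size=torch.tensor([grid]), pos_embed=rope_pos)
464 |     print(out.shape)  # velocity prediction at full resolution: (1, 3, 32, 32)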
--------------------------------------------------------------------------------
/pixelflow/pipeline_pixelflow.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | import math
3 | from typing import List, Optional, Union
4 | import time
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from diffusers.models.embeddings import get_2d_rotary_pos_embed
10 |
11 |
12 | class PixelFlowPipeline:
13 | def __init__(
14 | self,
15 | scheduler,
16 | transformer,
17 | text_encoder=None,
18 | tokenizer=None,
19 | max_token_length=512,
20 | ):
21 | super().__init__()
22 | self.class_cond = text_encoder is None or tokenizer is None
23 | self.scheduler = scheduler
24 | self.transformer = transformer
25 | self.patch_size = transformer.patch_size
26 | self.head_dim = transformer.attention_head_dim
27 | self.num_stages = scheduler.num_stages
28 |
29 | self.text_encoder = text_encoder
30 | self.tokenizer = tokenizer
31 | self.max_token_length = max_token_length
32 |
33 | @torch.autocast("cuda", enabled=False)
34 | def encode_prompt(
35 | self,
36 | prompt: Union[str, List[str]],
37 | device: Optional[torch.device] = None,
38 | num_images_per_prompt: int = 1,
39 | do_classifier_free_guidance: bool = True,
40 | negative_prompt: Union[str, List[str]] = "",
41 | prompt_embeds: Optional[torch.FloatTensor] = None,
42 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
43 | prompt_attention_mask: Optional[torch.FloatTensor] = None,
44 | negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
45 | use_attention_mask: bool = False,
46 | max_length: int = 512,
47 | ):
48 | # Determine the batch size and normalize prompt input to a list
49 | if prompt is not None:
50 | if isinstance(prompt, str):
51 | prompt = [prompt]
52 | batch_size = len(prompt)
53 | else:
54 | batch_size = prompt_embeds.shape[0]
55 |
56 | # Process prompt embeddings if not provided
57 | if prompt_embeds is None:
58 | text_inputs = self.tokenizer(
59 | prompt,
60 | padding="max_length",
61 | max_length=max_length,
62 | truncation=True,
63 | add_special_tokens=True,
64 | return_tensors="pt",
65 | )
66 | text_input_ids = text_inputs.input_ids.to(device)
67 | prompt_attention_mask = text_inputs.attention_mask.to(device)
68 | prompt_embeds = self.text_encoder(
69 | text_input_ids,
70 | attention_mask=prompt_attention_mask if use_attention_mask else None
71 | )[0]
72 |
73 | # Determine dtype from available encoder
74 | if self.text_encoder is not None:
75 | dtype = self.text_encoder.dtype
76 | elif self.transformer is not None:
77 | dtype = self.transformer.dtype
78 | else:
79 | dtype = None
80 |
81 | # Move prompt embeddings to desired dtype and device
82 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
83 |
84 | bs_embed, seq_len, _ = prompt_embeds.shape
85 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
86 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
87 | prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
88 |
89 | # Handle classifier-free guidance for negative prompts
90 | if do_classifier_free_guidance and negative_prompt_embeds is None:
91 | # Normalize negative prompt to list and validate length
92 | if isinstance(negative_prompt, str):
93 | uncond_tokens = [negative_prompt] * batch_size
94 | elif isinstance(negative_prompt, list):
95 | if len(negative_prompt) != batch_size:
96 | raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
97 | uncond_tokens = negative_prompt
98 | else:
99 | raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")
100 |
101 | # Tokenize and encode negative prompts
102 | uncond_inputs = self.tokenizer(
103 | uncond_tokens,
104 | padding="max_length",
105 | max_length=prompt_embeds.shape[1],
106 | truncation=True,
107 | return_attention_mask=True,
108 | add_special_tokens=True,
109 | return_tensors="pt",
110 | )
111 | negative_input_ids = uncond_inputs.input_ids.to(device)
112 | negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
113 | negative_prompt_embeds = self.text_encoder(
114 | negative_input_ids,
115 | attention_mask=negative_prompt_attention_mask if use_attention_mask else None
116 | )[0]
117 |
118 | if do_classifier_free_guidance:
119 | # Duplicate negative prompt embeddings and attention mask for each generation
120 | seq_len_neg = negative_prompt_embeds.shape[1]
121 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
122 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
123 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
124 | negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
125 | else:
126 | negative_prompt_embeds = None
127 | negative_prompt_attention_mask = None
128 |
129 | # Concatenate negative and positive embeddings and their masks
130 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
131 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
132 |
133 | return prompt_embeds, prompt_attention_mask
134 |
135 | def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
136 | gamma = self.scheduler.gamma
137 | dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4))
138 | block_number = bs * ch * (height // 2) * (width // 2)
139 | noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
140 | noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)',b=bs,c=ch,h=height//2,w=width//2,p=2,q=2)
141 | return noise
142 |
143 | @torch.no_grad()
144 | def __call__(
145 | self,
146 | prompt,
147 | height,
148 | width,
149 | num_inference_steps=30,
150 | guidance_scale=4.0,
151 | num_images_per_prompt=1,
152 | device=None,
153 | shift=1.0,
154 | use_ode_dopri5=False,
155 | ):
156 | if isinstance(num_inference_steps, int):
157 | num_inference_steps = [num_inference_steps] * self.num_stages
158 |
159 | if use_ode_dopri5:
160 |             assert self.class_cond, "ODE (dopri5) sampling is currently only supported for class-conditional models"
161 | from pixelflow.solver_ode_wrapper import ODE
162 | sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
163 | else:
164 | # default Euler
165 | sample_fn = None
166 |
167 | self._guidance_scale = guidance_scale
168 | batch_size = len(prompt)
169 | if self.class_cond:
170 | prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
171 |             negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)  # index num_classes (=1000) is the learned null-class row used for CFG
172 | if self.do_classifier_free_guidance:
173 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
174 | else:
175 | prompt_embeds, prompt_attention_mask = self.encode_prompt(
176 | prompt,
177 | device,
178 | num_images_per_prompt,
179 | guidance_scale > 1,
180 | "",
181 | prompt_embeds=None,
182 | negative_prompt_embeds=None,
183 | use_attention_mask=True,
184 | max_length=self.max_token_length,
185 | )
186 |
187 | init_factor = 2 ** (self.num_stages - 1)
188 | height, width = height // init_factor, width // init_factor
189 | shape = (batch_size * num_images_per_prompt, 3, height, width)
190 | latents = randn_tensor(shape, device=device, dtype=torch.float32)
191 |
192 | for stage_idx in range(self.num_stages):
193 | stage_start = time.time()
194 | # Set the number of inference steps for the current stage
195 | self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
196 | Timesteps = self.scheduler.Timesteps
197 |
198 | if stage_idx > 0:
199 | height, width = height * 2, width * 2
200 | latents = F.interpolate(latents, size=(height, width), mode='nearest')
201 | original_start_t = self.scheduler.original_start_t[stage_idx]
202 | gamma = self.scheduler.gamma
203 | alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
204 | beta = alpha * (1 - original_start_t) / math.sqrt(- gamma)
205 |
206 | # bs, ch, height, width = latents.shape
207 | noise = self.sample_block_noise(*latents.shape)
208 | noise = noise.to(device=device, dtype=latents.dtype)
209 | latents = alpha * latents + beta * noise
210 |
211 | size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
212 | pos_embed = get_2d_rotary_pos_embed(
213 | embed_dim=self.head_dim,
214 | crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
215 | grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
216 | )
217 | rope_pos = torch.stack(pos_embed, -1)
218 |
219 | if sample_fn is not None:
220 | # dopri5
221 | model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos)
222 | if stage_idx == 0:
223 | latents = torch.cat([latents] * 2)
224 | stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
225 | stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
226 | latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
227 | if stage_idx == self.num_stages - 1:
228 | latents = latents[:latents.shape[0] // 2]
229 | else:
230 | # euler
231 | for T in Timesteps:
232 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
233 | timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
234 | if self.class_cond:
235 | noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
236 | else:
237 | encoder_hidden_states = prompt_embeds
238 | encoder_attention_mask = prompt_attention_mask
239 |
240 | noise_pred = self.transformer(
241 | latent_model_input,
242 | encoder_hidden_states=encoder_hidden_states,
243 | encoder_attention_mask=encoder_attention_mask,
244 | timestep=timestep,
245 | latent_size=size_tensor,
246 | pos_embed=rope_pos,
247 | )
248 |
249 | if self.do_classifier_free_guidance:
250 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
251 | noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
252 |
253 | latents = self.scheduler.step(model_output=noise_pred, sample=latents)
254 | stage_end = time.time()
255 |
256 | samples = (latents / 2 + 0.5).clamp(0, 1)
257 | samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
258 | return samples
259 |
260 | @property
261 | def device(self):
262 | return next(self.transformer.parameters()).device
263 |
264 | @property
265 | def dtype(self):
266 | return next(self.transformer.parameters()).dtype
267 |
268 | def guidance_scale(self, step=None, stage_idx=None):
269 | if not self.class_cond:
270 | return self._guidance_scale
271 |         scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1}  # per-stage guidance ramp; assumes num_stages == 4
272 | return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1
273 |
274 | @property
275 | def do_classifier_free_guidance(self):
276 | return self._guidance_scale > 0
277 |
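278 |
279 | # Worked example of the stage-wise guidance ramp above (class-conditional
280 | # case; the scale_dict assumes num_stages == 4): with _guidance_scale = 2.4,
281 | # the effective scales are 1.0, ~1.23, ~1.93, and 2.4 for stages 0..3, so
282 | # classifier-free guidance is effectively off at the coarsest stage and ramps
283 | # up to full strength at the final, highest-resolution stage.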
--------------------------------------------------------------------------------
/pixelflow/scheduling_pixelflow.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def cal_rectify_ratio(start_t, gamma):
7 | return 1 / (math.sqrt(1 - (1 / gamma)) * (1 - start_t) + start_t)
8 |
9 |
10 | class PixelFlowScheduler:
11 | def __init__(self, num_train_timesteps, num_stages, gamma=-1 / 3):
12 | assert num_stages > 0, f"num_stages must be positive, got {num_stages}"
13 | self.num_stages = num_stages
14 | self.gamma = gamma
15 |
16 | self.Timesteps = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=torch.float32)
17 |
18 | self.t = self.Timesteps / num_train_timesteps # normalized time in [0, 1]
19 |
20 | self.stage_range = [x / num_stages for x in range(num_stages + 1)]
21 |
22 | self.original_start_t = dict()
23 | self.start_t, self.end_t = dict(), dict()
24 | self.t_window_per_stage = dict()
25 | self.Timesteps_per_stage = dict()
26 | stage_distance = list()
27 |
28 | # stage_idx = 0: min t, min resolution, most noisy
29 | # stage_idx = num_stages - 1 : max t, max resolution, most clear
30 | for stage_idx in range(num_stages):
31 | start_idx = max(int(num_train_timesteps * self.stage_range[stage_idx]), 0)
32 | end_idx = min(int(num_train_timesteps * self.stage_range[stage_idx + 1]), num_train_timesteps)
33 |
34 | start_t = self.t[start_idx].item()
35 | end_t = self.t[end_idx].item() if end_idx < num_train_timesteps else 1.0
36 |
37 | self.original_start_t[stage_idx] = start_t
38 |
39 | if stage_idx > 0:
40 | start_t *= cal_rectify_ratio(start_t, gamma)
41 |
42 | self.start_t[stage_idx] = start_t
43 | self.end_t[stage_idx] = end_t
44 | stage_distance.append(end_t - start_t)
45 |
46 | total_stage_distance = sum(stage_distance)
47 | t_within_stage = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float64)[:-1]
48 |
49 | for stage_idx in range(num_stages):
50 | start_ratio = 0.0 if stage_idx == 0 else sum(stage_distance[:stage_idx]) / total_stage_distance
51 | end_ratio = 1.0 if stage_idx == num_stages - 1 else sum(stage_distance[:stage_idx + 1]) / total_stage_distance
52 |
53 | Timestep_start = self.Timesteps[int(num_train_timesteps * start_ratio)]
54 | Timestep_end = self.Timesteps[min(int(num_train_timesteps * end_ratio), num_train_timesteps - 1)]
55 |
56 | self.t_window_per_stage[stage_idx] = t_within_stage
57 |
58 | if stage_idx == num_stages - 1:
59 | self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps, dtype=torch.float64)
60 | else:
61 | self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps + 1, dtype=torch.float64)[:-1]
62 |
63 | @staticmethod
64 | def time_linear_to_Timesteps(t, t_start, t_end, T_start, T_end):
65 | """
66 | linearly map t to T: T = k * t + b
67 | """
68 | k = (T_end - T_start) / (t_end - t_start)
69 | b = T_start - t_start * k
70 | return k * t + b
71 |
72 | def set_timesteps(self, num_inference_steps, stage_index, device=None, shift=1.0):
73 | self.num_inference_steps = num_inference_steps
74 |
75 | stage_T_start = self.Timesteps_per_stage[stage_index][0].item()
76 | stage_T_end = self.Timesteps_per_stage[stage_index][-1].item()
77 |
78 | t_start = self.t_window_per_stage[stage_index][0].item()
79 | t_end = self.t_window_per_stage[stage_index][-1].item()
80 |
81 | t = np.linspace(t_start, t_end, num_inference_steps, dtype=np.float64)
82 | t = t / (shift + (1 - shift) * t)
83 |
84 | Timesteps = self.time_linear_to_Timesteps(t, t_start, t_end, stage_T_start, stage_T_end)
85 | self.Timesteps = torch.from_numpy(Timesteps).to(device=device)
86 |
87 | self.t = torch.from_numpy(np.append(t, 1.0)).to(device=device, dtype=torch.float64)
88 | self._step_index = None
89 |
90 | def step(self, model_output, sample):
91 | if self.step_index is None:
92 | self._step_index = 0
93 |
94 | sample = sample.to(torch.float32)
95 | t = self.t[self.step_index].float()
96 | t_next = self.t[self.step_index + 1].float()
97 |
98 | prev_sample = sample + (t_next - t) * model_output
99 | self._step_index += 1
100 |
101 | return prev_sample.to(model_output.dtype)
102 |
103 | @property
104 | def step_index(self):
105 | """Current step index for the scheduler."""
106 | return self._step_index
107 |
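108 |
109 | if __name__ == "__main__":
110 |     # Usage sketch: inspect the per-stage time windows and run one stage of
111 |     # Euler integration with a constant dummy velocity field (values assume
112 |     # the defaults: 1000 train timesteps, 4 stages, gamma = -1/3).
113 |     sched = PixelFlowScheduler(num_train_timesteps=1000, num_stages=4)
114 |     for s in range(4):
115 |         print(s, round(sched.start_t[s], 4), round(sched.end_t[s], 4))
116 |     sched.set_timesteps(num_inference_steps=10, stage_index=0)
117 |     x = torch.zeros(1, 3, 32, 32)
118 |     for _T in sched.Timesteps:  # one scheduler.step per discretized Timestep
119 |         x = sched.step(model_output=torch.ones_like(x), sample=x)
120 |     print(x.mean().item())  # ≈ 1.0: the Euler steps integrate dt across the stage window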
--------------------------------------------------------------------------------
/pixelflow/solver_ode_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchdiffeq import odeint
3 |
4 |
5 | # https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77
6 | class ODE:
7 | """ODE solver class"""
8 | def __init__(
9 | self,
10 | *,
11 | t0,
12 | t1,
13 | sampler_type,
14 | num_steps,
15 | atol,
16 | rtol,
17 | ):
18 | assert t0 < t1, "ODE sampler has to be in forward time"
19 |
20 | self.t = torch.linspace(t0, t1, num_steps)
21 | self.atol = atol
22 | self.rtol = rtol
23 | self.sampler_type = sampler_type
24 |
25 | def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end):
26 | # T = k * t + b
27 | k = (T_end - T_start) / (t_end - t_start)
28 | b = T_start - t_start * k
29 | return k * t + b
30 |
31 | def sample(self, x, model, T_start, T_end, **model_kwargs):
32 | device = x[0].device if isinstance(x, tuple) else x.device
33 | def _fn(t, x):
34 | t = torch.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else torch.ones(x.size(0)).to(device) * t
35 | model_output = model(x, self.time_linear_to_Timesteps(t, 0, 1, T_start, T_end), **model_kwargs)
36 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
37 | return model_output
38 |
39 | t = self.t.to(device)
40 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
41 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
42 | samples = odeint(
43 | _fn,
44 | x,
45 | t,
46 | method=self.sampler_type,
47 | atol=atol,
48 | rtol=rtol
49 | )
50 | return samples
51 |
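52 |
53 | if __name__ == "__main__":
54 |     # Sanity-check sketch: integrate dx/dt = -x from t0=0 to t1=1 with dopri5;
55 |     # the final state should approximate x0 * exp(-1).
56 |     ode = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=5, atol=1e-6, rtol=1e-3)
57 |     decay = lambda x, T: -x  # stands in for model(x, Timesteps); T is unused here
58 |     x0 = torch.ones(2, 3)
59 |     xs = ode.sample(x0, decay, T_start=0, T_end=1)
60 |     print(xs[-1] / x0)  # ≈ 0.3679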
--------------------------------------------------------------------------------
/pixelflow/utils/config.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 |
4 | def get_obj_from_str(string, reload=False):
5 | module, cls = string.rsplit(".", 1)
6 | if reload:
7 | module_imp = importlib.import_module(module)
8 | importlib.reload(module_imp)
9 | return getattr(importlib.import_module(module, package=None), cls)
10 |
11 |
12 | def instantiate_from_config(config):
13 |     if "target" not in config:
14 | raise KeyError("Expected key `target` to instantiate.")
15 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
16 |
17 |
18 | def instantiate_optimizer_from_config(config, params):
19 |     if "target" not in config:
20 | raise KeyError("Expected key `target` to instantiate.")
21 | return get_obj_from_str(config["target"])(params, **config.get("params", dict()))
22 |
23 |
24 | def instantiate_dataset_from_config(config, transform):
25 |     if "target" not in config:
26 | raise KeyError("Expected key `target` to instantiate.")
27 | return get_obj_from_str(config["target"])(transform=transform, **config.get("params", dict()))
28 |
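29 |
30 | if __name__ == "__main__":
31 |     # Usage sketch of the `target`/`params` convention used by the YAML
32 |     # configs: any importable path can be instantiated with keyword params.
33 |     cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
34 |     print(instantiate_from_config(cfg))  # -> Linear(in_features=4, out_features=2, bias=True)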
--------------------------------------------------------------------------------
/pixelflow/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | class PathSimplifierFormatter(logging.Formatter):
6 | def format(self, record):
7 | record.short_path = os.path.relpath(record.pathname)
8 | return super().format(record)
9 |
10 |
11 | def setup_logger(log_directory, experiment_name, process_rank, source_module=__name__):
12 | handlers = [logging.StreamHandler()]
13 |
14 | if process_rank == 0:
15 | log_file_path = os.path.join(log_directory, f"{experiment_name}.log")
16 | handlers.append(logging.FileHandler(log_file_path))
17 |
18 | log_formatter = PathSimplifierFormatter(
19 | fmt='[%(asctime)s %(short_path)s:%(lineno)d] %(message)s',
20 | datefmt='%Y-%m-%d %H:%M:%S'
21 | )
22 |
23 | for handler in handlers:
24 | handler.setFormatter(log_formatter)
25 |
26 | logging.basicConfig(level=logging.INFO, handlers=handlers)
27 | return logging.getLogger(source_module)
28 |
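29 |
30 | # Usage sketch: `logger = setup_logger("./output", "run1", process_rank=rank)`
31 | # logs to the console on every rank and additionally to ./output/run1.log on
32 | # rank 0; each record is prefixed with a repo-relative path and line number.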
--------------------------------------------------------------------------------
/pixelflow/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 | def seed_everything(seed=0, deterministic_ops=True, allow_tf32=False):
7 | """
8 | Sets the seed for reproducibility across various libraries and frameworks, and configures PyTorch backend settings.
9 |
10 | Args:
11 | seed (int): The seed value for random number generation. Default is 0.
12 | deterministic_ops (bool): Whether to enable deterministic operations in PyTorch.
13 | Enabling this can make results reproducible at the cost of potential performance degradation. Default is True.
14 | allow_tf32 (bool): Whether to allow TensorFloat-32 (TF32) precision in PyTorch operations. TF32 can improve performance but may affect reproducibility. Default is False.
15 |
16 | Effects:
17 | - Seeds Python's random module, NumPy, and PyTorch (CPU and GPU).
18 | - Sets the environment variable `PYTHONHASHSEED` to the specified seed.
19 | - Configures PyTorch to use deterministic algorithms if `deterministic_ops` is True.
20 | - Configures TensorFloat-32 precision based on `allow_tf32`.
21 | - Issues warnings if configurations may impact reproducibility.
22 |
23 | Notes:
24 | - Setting `torch.backends.cudnn.deterministic` to False allows nondeterministic operations, which may introduce variability.
25 | - Allowing TF32 (`allow_tf32=True`) may lead to non-reproducible results, especially in matrix operations.
26 | """
27 | # Seed standard random number generators
28 | random.seed(seed)
29 | np.random.seed(seed)
30 | os.environ['PYTHONHASHSEED'] = str(seed)
31 |
32 | # Seed PyTorch random number generators
33 | torch.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 |
36 | # Configure deterministic operations
37 | if deterministic_ops:
38 | torch.backends.cudnn.deterministic = True
39 | torch.use_deterministic_algorithms(True)
40 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
41 | else:
42 | torch.backends.cudnn.deterministic = False
43 | print("WARNING: torch.backends.cudnn.deterministic is set to False, reproducibility is not guaranteed.")
44 |
45 | # Configure TensorFloat-32 precision
46 | if allow_tf32:
47 | print("WARNING: TensorFloat-32 (TF32) is enabled; reproducibility is not guaranteed.")
48 |
49 | torch.backends.cudnn.allow_tf32 = allow_tf32 # Default True in PyTorch 2.6.0
50 | torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Default False in PyTorch 2.6.0
51 |
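52 | # Usage sketch: call once per process before building models/dataloaders, e.g.
53 | #   seed_everything(config.seed * world_size + rank, deterministic_ops=False)
54 | # as done in train.py, so each rank draws a distinct but reproducible stream.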
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | pyarrow
3 | omegaconf
4 | diffusers==0.32.2
5 | transformers==4.48.0
6 | torchdiffeq==0.2.4
7 | einops
8 | tqdm
7 |
--------------------------------------------------------------------------------
/sample_ddp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import os
5 | from omegaconf import OmegaConf
6 | from tqdm import tqdm
7 | from PIL import Image
8 | import math
9 |
10 | import torch
11 | torch.backends.cuda.matmul.allow_tf32 = True
12 | torch.backends.cudnn.allow_tf32 = True
13 | import torch.distributed as dist
14 |
15 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler
16 | from pixelflow.pipeline_pixelflow import PixelFlowPipeline
17 | from pixelflow.utils import config as config_utils
18 |
19 |
20 | def get_args_parser():
21 | parser = argparse.ArgumentParser(description='sample 50k images for FID evaluation', add_help=False)
22 | parser.add_argument('--pretrained', type=str, required=True, help='Pretrained model path')
23 |
24 | parser.add_argument("--sample-dir", type=str, default="evaluate_256pix_folder")
25 | parser.add_argument("--cfg", type=float, default=2.4)
26 | parser.add_argument("--num-steps-per-stage", type=int, default=30)
27 | parser.add_argument("--use-ode-dopri5", action="store_true")
28 | parser.add_argument("--local-batch-size", type=int, default=16)
29 | parser.add_argument("--num-fid-samples", type=int, default=50000)
30 | parser.add_argument("--num-classes", type=int, default=1000)
31 | parser.add_argument("--global-seed", type=int, default=10)
32 | return parser
33 |
34 |
35 | def create_npz_from_sample_folder(sample_dir, num=50_000):
36 | """
37 | Builds a single .npz file from a folder of .png samples.
38 | """
39 | samples = []
40 | for i in tqdm(range(num), desc="Building .npz file from samples"):
41 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
42 | sample_np = np.asarray(sample_pil).astype(np.uint8)
43 | samples.append(sample_np)
44 | samples = np.stack(samples)
45 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
46 | npz_path = f"{sample_dir}.npz"
47 | np.savez(npz_path, arr_0=samples)
48 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
49 | return npz_path
50 |
51 |
52 | def main(args):
53 |     assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU."
54 | torch.set_grad_enabled(False)
55 |
56 | # Setup DDP
57 | rank = int(os.environ["RANK"])
58 | local_rank = int(os.environ["LOCAL_RANK"])
59 | world_size = int(os.environ["WORLD_SIZE"])
60 |
61 | dist.init_process_group("nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(hours=1))
62 | device = torch.device(f"cuda:{local_rank}")
63 | torch.cuda.set_device(device)
64 |
65 | seed = args.global_seed * dist.get_world_size() + rank
66 | torch.manual_seed(seed)
67 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
68 | print(args)
69 |
70 | # create and load model
71 | config = OmegaConf.load(f"{args.pretrained}/config.yaml")
72 | model = config_utils.instantiate_from_config(config.model).to(device)
73 | ckpt = torch.load(f"{args.pretrained}/model.pt", map_location="cpu", weights_only=True)
74 | model.load_state_dict(ckpt, strict=True)
75 | model.eval()
76 |
77 | resolution = config.data.resolution
78 |
79 | scheduler = PixelFlowScheduler(
80 | config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3
81 | )
82 | num_steps_per_stage = [args.num_steps_per_stage for _ in range(config.scheduler.num_stages)]
83 | pipeline = PixelFlowPipeline(scheduler, model)
84 |
85 | # Create folder to save samples:
86 | sample_folder_dir = args.sample_dir
87 | if rank == 0:
88 | os.makedirs(sample_folder_dir, exist_ok=True)
89 | print(f"Saving .png samples at {sample_folder_dir}")
90 | dist.barrier()
91 |
92 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
93 | local_batch_size = args.local_batch_size
94 | global_batch_size = local_batch_size * dist.get_world_size()
95 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
96 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
97 | if rank == 0:
98 | print(f"Total number of images that will be sampled: {total_samples}")
99 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
100 |
101 | # Number of images per class is equal (and pad 0)
102 | class_list_global = torch.zeros((total_samples,), device=device)
103 | num_classes = args.num_classes
104 | class_list = torch.arange(0, num_classes).repeat(args.num_fid_samples // num_classes)
105 | class_list_global[:class_list.shape[0]] = class_list
106 |
107 | local_samples = int(total_samples // dist.get_world_size())
108 |     assert local_samples % local_batch_size == 0, "local_samples must be divisible by the per-GPU batch size"
109 | iterations = int(local_samples // local_batch_size)
110 | pbar = range(iterations)
111 | pbar = tqdm(pbar) if rank == 0 else pbar
112 | total = 0
113 | for _ in pbar:
114 | cur_index = torch.arange(local_batch_size) * dist.get_world_size() + rank + total
115 | cur_class_list = class_list_global[cur_index]
116 |
117 | with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
118 | samples = pipeline(
119 | prompt=cur_class_list,
120 | num_inference_steps=list(num_steps_per_stage),
121 | height=resolution,
122 | width=resolution,
123 | guidance_scale=args.cfg,
124 | device=device,
125 | use_ode_dopri5=args.use_ode_dopri5,
126 | )
127 | samples = (samples * 255).round().astype("uint8")
128 | image_list = [Image.fromarray(sample) for sample in samples]
129 |
130 | for img_ind, image in enumerate(image_list):
131 | index = img_ind * dist.get_world_size() + rank + total
132 | if index < args.num_fid_samples:
133 | image.save(f"{sample_folder_dir}/{index:06d}.png")
134 |
135 | total += global_batch_size
136 |
137 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
138 | dist.barrier()
139 | if rank == 0:
140 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
141 | print("Done.")
142 | dist.barrier()
143 | dist.destroy_process_group()
144 |
145 | if __name__ == "__main__":
146 | parser = get_args_parser()
147 | args = parser.parse_args()
148 | main(args)
149 |
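150 | # Typical launch (a sketch; assumes a torchrun-style launcher that sets the
151 | # RANK/LOCAL_RANK/WORLD_SIZE environment variables read above), e.g.:
152 | #   torchrun --nproc_per_node=8 sample_ddp.py --pretrained /path/to/checkpoint_dir --cfg 2.4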
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import argparse
4 | import copy
5 | from collections import OrderedDict
6 | from datetime import datetime
7 | from omegaconf import OmegaConf
8 | import torch
9 | import torch.distributed as dist
10 |
11 | from pixelflow.scheduling_pixelflow import PixelFlowScheduler
12 | from pixelflow.utils.logger import setup_logger
13 | from pixelflow.utils import config as config_utils
14 | from pixelflow.utils.misc import seed_everything
15 | from pixelflow.data_in1k import build_imagenet_loader
16 |
17 |
18 | def get_args_parser():
19 |     parser = argparse.ArgumentParser(description='PixelFlow training', add_help=False)
20 |     parser.add_argument('config', type=str, help='Path to the YAML config file')
21 | parser.add_argument('--output-dir', default='./exp0001', help='Output directory')
22 | parser.add_argument('--logging-steps', type=int, default=10, help='Logging steps')
23 | parser.add_argument('--checkpoint-steps', type=int, default=1000, help='Checkpoint steps')
24 | parser.add_argument('--pretrained-model', type=str, default=None, help='Pretrained model')
25 | parser.add_argument('--report-to', type=str, default=None, help='Report to, eg. wandb')
26 |
27 | return parser
28 |
29 |
30 | @torch.no_grad()
31 | def update_ema(ema_model, model, decay=0.9999):
32 | """
33 | Step the EMA model towards the current model.
34 | """
35 | ema_params = OrderedDict(ema_model.named_parameters())
36 | model_params = OrderedDict(model.named_parameters())
37 |
38 | for name, param in model_params.items():
39 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
40 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
41 |
42 |
43 | def main(args):
44 | rank = int(os.environ["RANK"])
45 | local_rank = int(os.environ["LOCAL_RANK"])
46 | world_size = int(os.environ["WORLD_SIZE"])
47 |
48 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
49 | device = torch.device(f"cuda:{local_rank}")
50 | torch.cuda.set_device(device)
51 |
52 | exp_name = "{}".format(datetime.now().strftime('%Y%m%d-%H%M%S'))
53 |
54 | # we print logs from all ranks to console, but only save to file from rank 0
55 | logger = setup_logger(args.output_dir, exp_name, rank, __name__)
56 |
57 | config = OmegaConf.load(args.config)
58 |
59 | rank_seed = config.seed * world_size + rank
60 | seed_everything(rank_seed, deterministic_ops=False, allow_tf32=False)
61 | logger.info(f"Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size} Seed: {rank_seed}")
62 |
63 | # save args and config to output_dir
64 | with open(os.path.join(args.output_dir, "args.txt"), "w") as f:
65 | f.write(str(args))
66 | with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
67 | f.write(OmegaConf.to_yaml(config))
68 |
69 | logger.info(f"Config: {config}")
70 | model = config_utils.instantiate_from_config(config.model).to(device)
71 | ema = copy.deepcopy(model).to(device)
72 | for param in ema.parameters():
73 | param.requires_grad = False
74 |
75 | logger.info(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
76 |
77 | noise_scheduler = PixelFlowScheduler(
78 | num_train_timesteps=config.scheduler.num_train_timesteps,
79 | num_stages=config.scheduler.num_stages,
80 | )
81 | noise_scheduler_copy = copy.deepcopy(noise_scheduler)
82 |
83 | if args.pretrained_model is not None:
84 | ckpt = torch.load(args.pretrained_model, map_location="cpu", weights_only=True)
85 | model.load_state_dict(ckpt, strict=True)
86 |
87 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
88 |
89 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.train.lr, weight_decay=config.train.weight_decay)
90 | data_loader, sampler = build_imagenet_loader(config, noise_scheduler_copy)
91 |
92 | logger.info("***** Running training *****")
93 | global_step = 0
94 | first_epoch = 0
95 |
96 | # Prepare models for training:
97 | update_ema(ema, model.module, decay=0)
98 | model.train()
99 | ema.eval()
100 |
101 | for epoch in range(first_epoch, config.train.epochs):
102 | sampler.set_epoch(epoch)
103 | for step, batch in enumerate(data_loader):
104 | optimizer.zero_grad()
105 | target = batch["target_values"].to(device)
106 | with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
107 | model_output = model(
108 | hidden_states=batch["pixel_values"].to(device),
109 | encoder_hidden_states=None,
110 | class_labels=torch.tensor(batch["input_ids"], device=device),
111 | timestep=batch["timesteps"].to(device),
112 | latent_size=batch["batch_latent_size"].to(device),
113 | pos_embed=batch["pos_embed"].to(device),
114 | cu_seqlens_q=batch["cumsum_q_len"].to(device),
115 | cu_seqlens_k=None,
116 | seqlen_list_q=batch["seqlen_list_q"],
117 | seqlen_list_k=None,
118 | )
119 |
120 | loss = (model_output.float() - target.float()) ** 2
121 | loss_split = torch.split(loss, batch["seqlen_list_q"], dim=0)
122 | loss_items = torch.stack([x.mean() for x in loss_split])
123 | if "padding_size" in batch and batch["padding_size"] is not None and batch["padding_size"] > 0:
124 | loss_items = loss_items[:-1]
125 |
126 | loss = loss_items.mean()
127 | loss.backward()
128 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
129 | optimizer.step()
130 | update_ema(ema, model.module)
131 |
132 | if global_step % args.logging_steps == 0:
133 | logger.info(f"Epoch: {epoch}, Step: {step}, Loss: {loss.item()}, Grad Norm: {grad_norm.item()}")
134 |
135 | if global_step % args.checkpoint_steps == 0 and global_step > 0:
136 | if rank == 0:
137 | torch.save(
138 | {
139 | "model": model.module.state_dict(),
140 | "ema": ema.state_dict(),
141 | "opt": optimizer.state_dict(),
142 | "args": args
143 | },
144 | os.path.join(args.output_dir, f"model_{global_step}.pt"))
145 | logger.info(f"Model saved at step {global_step}")
146 | dist.barrier()
147 |
148 | global_step += 1
149 |
150 |
151 | if __name__ == '__main__':
152 | parser = get_args_parser()
153 | args = parser.parse_args()
154 | os.makedirs(args.output_dir, exist_ok=True)
155 | main(args)
156 |
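157 | # Typical launch (a sketch; assumes a torchrun-style launcher that sets the
158 | # RANK/LOCAL_RANK/WORLD_SIZE environment variables read above), e.g.:
159 | #   torchrun --nproc_per_node=8 train.py configs/pixelflow_xl_c2i.yaml --output-dir ./exp0001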
--------------------------------------------------------------------------------