├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── architectures
│   ├── convnet_blocks.py
│   ├── resnet.py
│   ├── tape.py
│   ├── transformers.py
│   └── transunet.py
├── colabs
│   ├── pix2seq_finetuning_object_detection.ipynb
│   ├── pix2seq_inference_multitask.ipynb
│   └── pix2seq_inference_object_detection.ipynb
├── configs
│   ├── config_base.py
│   ├── config_det_finetune.py
│   ├── config_diffusion_base.py
│   ├── config_diffusion_cifar10.py
│   ├── config_diffusion_cifar10d.py
│   ├── config_diffusion_imagenet64.py
│   ├── config_diffusion_panoptic_base.py
│   ├── config_diffusion_panoptic_image.py
│   ├── config_diffusion_panoptic_video.py
│   ├── config_multi_task.py
│   ├── dataset_configs.py
│   └── transform_configs.py
├── data
│   ├── cityscapes.py
│   ├── coco.py
│   ├── data_utils.py
│   ├── dataset.py
│   ├── datasets.py
│   ├── decode_utils.py
│   ├── obj365.py
│   ├── recognition.py
│   ├── scripts
│   │   ├── create_coco_tfrecord.py
│   │   ├── create_davis_tfrecord.py
│   │   ├── create_kittistep_tfrecord.py
│   │   ├── merge_coco_json_tfrecord.py
│   │   └── tfrecord_lib.py
│   ├── text.py
│   ├── tokenizer.py
│   ├── transforms.py
│   └── video.py
├── metrics
│   ├── coco_metrics.py
│   ├── fid.py
│   ├── fvd.py
│   ├── metric_registry.py
│   ├── metric_utils.py
│   ├── segmentation_and_tracking_quality.py
│   ├── text_metrics.py
│   ├── vos_metrics.py
│   └── vps_metrics.py
├── models
│   ├── ar_model.py
│   ├── diffusion_utils.py
│   ├── image_ar_model.py
│   ├── image_diffusion_model.py
│   ├── image_discrete_diffusion_model.py
│   ├── model.py
│   ├── model_utils.py
│   ├── panoptic_diffusion.py
│   └── video_diffusion_model.py
├── pix2seq.gif
├── pix2seq.png
├── registry.py
├── requirements.txt
├── run.py
├── tasks
│   ├── captioning.py
│   ├── image_generation.py
│   ├── instance_segmentation.py
│   ├── keypoint_detection.py
│   ├── object_detection.py
│   ├── panoptic_segmentation.py
│   ├── recognition.py
│   ├── task.py
│   ├── task_utils.py
│   ├── video_generation.py
│   ├── video_panoptic_segmentation.py
│   └── visualization
│       ├── shape_utils.py
│       ├── standard_fields.py
│       ├── static_shape.py
│       └── vis_utils.py
├── utils.py
└── vocab.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
 3 | Pix2Seq needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
 6 | to re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/configs/config_base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Common / shared settings among multiple configs."""
17 |
18 | import ml_collections
19 |
20 |
21 | def D(**kwargs):
22 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
23 |
24 |
25 | architecture_config_map = {
26 | 'vit-b': D(
27 | resnet_variant='c1',
28 | num_encoder_layers=12,
29 | dim_att=768,
30 | dim_mlp=3072,
31 | num_heads=12,
32 | num_decoder_layers=6,
33 | dim_att_dec=512,
34 | dim_mlp_dec=2048,
35 | num_heads_dec=16,
36 | ),
37 | 'vit-l': D(
38 | resnet_variant='c1',
39 | num_encoder_layers=24,
40 | dim_att=1024,
41 | dim_mlp=4096,
42 | num_heads=16,
43 | num_decoder_layers=8,
44 | dim_att_dec=512,
45 | dim_mlp_dec=2048,
46 | num_heads_dec=16,
47 | ),
48 | 'resnet': D(
49 | resnet_variant='standard',
50 | resnet_depth=50,
51 | resnet_sk_ratio=0.,
52 | resnet_width_multiplier=1,
53 | num_encoder_layers=6,
54 | dim_att=256,
55 | dim_mlp=1024,
56 | num_heads=8,
57 | num_decoder_layers=6,
58 | dim_att_dec=256,
59 | dim_mlp_dec=1024,
60 | num_heads_dec=8,
61 | ),
62 | 'resnet-c': D(
63 | resnet_variant='c4',
64 | resnet_depth=50,
65 | resnet_sk_ratio=0.,
66 | resnet_width_multiplier=1,
67 | num_encoder_layers=12,
68 | dim_att=512,
69 | dim_mlp=2048,
70 | num_heads=16,
71 | num_decoder_layers=8,
72 | dim_att_dec=512,
73 | dim_mlp_dec=2048,
74 | num_heads_dec=16,
75 | ),
76 | }
77 |
--------------------------------------------------------------------------------
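The D helper above is a thin wrapper around ml_collections.ConfigDict, and architecture_config_map is meant to be merged into config.model by the task configs. A minimal sketch of that pattern (the merge loop mirrors config_det_finetune.py below; the surrounding values are illustrative only):

from configs.config_base import D, architecture_config_map

# Build a nested ConfigDict the same way the config files do.
config = D(model=D(name='encoder_ar_decoder', patch_size=16))

# Copy an architecture variant's settings into config.model, overriding or
# adding fields (the same loop appears in config_det_finetune.py).
for key, value in architecture_config_map['vit-b'].items():
  config.model[key] = value

print(config.model.num_encoder_layers)  # 12 for 'vit-b'
print(config.model.dim_att)             # 768 for 'vit-b'
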
/configs/config_det_finetune.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Config file for object detection fine-tuning and evaluation."""
17 |
18 | import copy
19 |
20 | from configs import dataset_configs
21 | from configs import transform_configs
22 | from configs.config_base import architecture_config_map
23 | from configs.config_base import D
24 |
25 | # pylint: disable=invalid-name,line-too-long,missing-docstring
26 |
27 |
28 | def get_config(config_str=None):
29 | """config_str is either empty or contains task,architecture variants."""
30 |
31 | task_variant = 'object_detection@coco/2017_object_detection'
32 | encoder_variant = 'vit-b' # Set model architecture.
33 | image_size = (640, 640) # Set image size.
34 |
35 |
36 | tasks_and_datasets = []
37 | for task_and_ds in task_variant.split('+'):
38 | tasks_and_datasets.append(task_and_ds.split('@'))
39 |
40 | max_instances_per_image = 100
41 | max_instances_per_image_test = 100
42 |
43 | task_config_map = {
44 | 'object_detection': D(
45 | name='object_detection',
46 | vocab_id=10,
47 | image_size=image_size,
48 | quantization_bins=1000,
49 | max_instances_per_image=max_instances_per_image,
50 | max_instances_per_image_test=max_instances_per_image_test,
51 | train_transforms=transform_configs.get_object_detection_train_transforms(
52 | image_size, max_instances_per_image),
53 | eval_transforms=transform_configs.get_object_detection_eval_transforms(
54 | image_size, max_instances_per_image_test),
55 | # Train on both ground-truth and (augmented) noisy objects.
56 | noise_bbox_weight=1.0,
57 | eos_token_weight=0.1,
58 | # Train on just ground-truth objects (with an ending token).
59 | # noise_bbox_weight=0.0,
60 | # eos_token_weight=0.1,
61 | class_label_corruption='rand_n_fake_cls',
62 | top_k=0,
63 | top_p=0.4,
64 | temperature=1.0,
65 | weight=1.0,
66 | metric=D(name='coco_object_detection',),
67 | ),
68 | }
69 |
70 | task_d_list = []
71 | dataset_list = []
72 | for tv, ds_name in tasks_and_datasets:
73 | task_d_list.append(task_config_map[tv])
74 | dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name])
75 | dataset_list.append(dataset_config)
76 |
77 | config = D(
78 | dataset=dataset_list[0],
79 | datasets=dataset_list,
80 |
81 | task=task_d_list[0],
82 | tasks=task_d_list,
83 |
84 | model=D(
85 | name='encoder_ar_decoder',
86 | image_size=image_size,
87 | max_seq_len=512,
88 | vocab_size=3000, # Note: should be large enough for 100 + num_classes + quantization_bins + (optional) text
89 | coord_vocab_shift=1000, # Note: make sure num_class <= coord_vocab_shift - 100
90 | text_vocab_shift=3000, # Note: make sure coord_vocab_shift + quantization_bins <= text_vocab_shift
91 | use_cls_token=False,
92 | shared_decoder_embedding=True,
93 | decoder_output_bias=True,
94 | patch_size=16,
95 | drop_path=0.1,
96 | drop_units=0.1,
97 | drop_att=0.0,
98 | dec_proj_mode='mlp',
99 | pos_encoding='sin_cos',
100 | pos_encoding_dec='learned',
101 | pretrained_ckpt=get_obj365_pretrained_checkpoint(encoder_variant),
102 | ),
103 |
104 | optimization=D(
105 | optimizer='adamw',
106 | learning_rate=3e-5,
107 | end_lr_factor=0.01,
108 | warmup_epochs=2,
109 | warmup_steps=0, # set to >0 to override warmup_epochs.
110 | weight_decay=0.05,
111 | global_clipnorm=-1,
112 | beta1=0.9,
113 | beta2=0.95,
114 | eps=1e-8,
115 | learning_rate_schedule='linear',
116 | learning_rate_scaling='none',
117 | ),
118 |
119 | train=D(
120 | batch_size=32,
121 | epochs=40,
122 | steps=0, # set to >0 to override epochs.
123 | checkpoint_epochs=1,
124 | checkpoint_steps=0, # set to >0 to override checkpoint_epochs.
125 | keep_checkpoint_max=5,
126 | loss_type='xent',
127 | ),
128 |
129 | eval=D(
130 | tag='eval',
131 | checkpoint_dir='', # checkpoint_dir will be model_dir if not set.
132 | # checkpoint_dir=get_coco_finetuned_checkpoint(encoder_variant, image_size[0]),
133 |           batch_size=8,  # should evenly divide the total number of eval examples.
134 | steps=0, # 0 means eval over full validation set.
135 | ),
136 | )
137 |
138 | # Update model with architecture variant.
139 | for key, value in architecture_config_map[encoder_variant].items():
140 | config.model[key] = value
141 |
142 | return config
143 |
144 | CKPT_PREFIX = 'gs://pix2seq'
145 |
146 |
147 | def get_obj365_pretrained_checkpoint(encoder_variant):
148 | if encoder_variant == 'resnet':
149 | return f'{CKPT_PREFIX}/obj365_pretrain/resnet_640x640_b256_s400k'
150 | elif encoder_variant == 'resnet-c':
151 | return f'{CKPT_PREFIX}/obj365_pretrain/resnetc_640x640_b256_s400k'
152 | elif encoder_variant == 'vit-b':
153 | return f'{CKPT_PREFIX}/obj365_pretrain/vit_b_640x640_b256_s400k'
154 | elif encoder_variant == 'vit-l':
155 | return f'{CKPT_PREFIX}/obj365_pretrain/vit_l_640x640_b256_s400k'
156 | else:
157 | raise ValueError('Unknown encoder_variant {}'.format(encoder_variant))
158 |
159 |
160 | def get_coco_finetuned_checkpoint(encoder_variant, image_size):
161 | if encoder_variant == 'resnet':
162 | return f'{CKPT_PREFIX}/coco_det_finetune/resnet_{image_size}x{image_size}'
163 | elif encoder_variant == 'resnet-c':
164 | return f'{CKPT_PREFIX}/coco_det_finetune/resnetc_{image_size}x{image_size}'
165 | elif encoder_variant == 'vit-b':
166 | return f'{CKPT_PREFIX}/coco_det_finetune/vit_b_{image_size}x{image_size}'
167 | elif encoder_variant == 'vit-l':
168 | return f'{CKPT_PREFIX}/coco_det_finetune/vit_l_{image_size}x{image_size}'
169 | else:
170 | raise ValueError('Unknown encoder_variant {}'.format(encoder_variant))
171 |
--------------------------------------------------------------------------------
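The vocab_size / coord_vocab_shift / text_vocab_shift comments above describe a single token-id layout: reserved tokens first, then class labels, then quantized coordinates, then (optionally) text. A rough sketch of that arithmetic, assuming the usual Pix2Seq convention of mapping a normalized coordinate in [0, 1] to one of quantization_bins bins; the quantize_coord helper below is hypothetical and only illustrates the constraints:

# Hypothetical illustration of the token-id layout implied by the comments
# in config_det_finetune.py; this is not the repo's tokenizer code.
NUM_RESERVED = 100        # special/control tokens
NUM_CLASSES = 80          # e.g. COCO; must satisfy num_classes <= coord_vocab_shift - 100
QUANTIZATION_BINS = 1000
COORD_VOCAB_SHIFT = 1000
TEXT_VOCAB_SHIFT = 3000   # coord_vocab_shift + quantization_bins <= text_vocab_shift
VOCAB_SIZE = 3000         # >= 100 + num_classes + quantization_bins (+ text tokens)

def quantize_coord(coord):
  """Map a normalized coordinate in [0, 1] to a coordinate token id."""
  bin_id = min(int(coord * QUANTIZATION_BINS), QUANTIZATION_BINS - 1)
  return COORD_VOCAB_SHIFT + bin_id

assert NUM_CLASSES <= COORD_VOCAB_SHIFT - NUM_RESERVED
assert COORD_VOCAB_SHIFT + QUANTIZATION_BINS <= TEXT_VOCAB_SHIFT
print(quantize_coord(0.0), quantize_coord(0.5), quantize_coord(1.0))  # 1000 1500 1999
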
/configs/config_diffusion_base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
20 | from configs.config_base import architecture_config_map
21 | from configs.config_base import D
22 |
23 |
24 | def get_config(config_str=None):
 25 |   """config_str can be None or a 'data_name,arch_variant' string."""
26 | if config_str:
27 | data_name, arch_variant = config_str.split(',')
28 | else:
29 | data_name = 'cifar10'
 30 |     # arch_variant = 'transunet'
 31 |     arch_variant = 'tape'
32 |
33 | architecture_config_map.update({
34 | 'transunet': D(
35 | arch_name='transunet',
36 | resnet_variant='c1',
37 | dim=128,
38 | in_strides=1,
39 | in_kernel_size=3,
40 | out_kernel_size=3,
41 | kernel_sizes='3',
42 | n_mlp_blocks=0,
43 | udrop=0.0,
44 | n_res_blocks='3,3,3',
45 | ch_multipliers='1,2,2',
46 | mhsa_resolutions='16,8',
47 | per_head_dim=64,
48 | transformer_dim=128,
49 | transformer_strides=2,
50 | transformer_blocks=0,
51 | conditioning=False,
52 | u_pos_encoding='sin_cos',
53 | outp_softmax_groups=0,
54 | ),
55 | 'tape': D(
56 | arch_name='tape',
57 | num_layers='4,4,4,4',
58 | latent_slots=32,
59 | latent_dim=256,
60 | latent_mlp_ratio=4,
61 | latent_num_heads=4,
62 | tape_dim=128,
63 | tape_mlp_ratio=2,
64 | rw_num_heads=1,
65 | conv_kernel_size=0,
66 | conv_drop_units=0.,
67 | drop_path=0.,
68 | drop_units=0.,
69 | drop_att=0.,
70 | drop_sc=0.,
71 | time_on_latent=False,
72 | cond_on_latent=False,
73 | cond_tape_writable=False,
74 | latent_pos_encoding='learned',
75 | tape_pos_encoding='learned',
76 | xattn_enc_ln=False,
77 | cond_dim=0,
78 | cond_proj=True,
79 | cond_decoupled_read=False,
80 | ),
81 | })
82 |
83 | dataset_config_map = {
84 | 'mnist': D(
85 | name='object_recognition',
86 | tfds_name='mnist',
87 | train_split='train',
88 | eval_split='test',
89 | num_classes=10,
90 | image_size=28,
91 | batch_duplicates=1,
92 | cache_dataset=True,
93 | cropping='none',
94 | flipping='none',
95 | ),
96 | 'cifar10': D(
97 | name='object_recognition',
98 | tfds_name='cifar10',
99 | train_split='train',
100 | eval_split='test',
101 | num_classes=10,
102 | image_size=32,
103 | batch_duplicates=1,
104 | cache_dataset=True,
105 | cropping='none',
106 | flipping='left_right',
107 | ),
108 | 'imagenet2012': D(
109 | name='object_recognition',
110 | tfds_name='imagenet2012',
111 | train_split='train',
112 | eval_split='validation',
113 | num_classes=1000,
114 | image_size=64,
115 | batch_duplicates=1,
116 | cache_dataset=True,
117 | cropping='center',
118 | flipping='left_right',
119 | ),
120 | 'ucf101': D(
121 | name='tfds_video',
122 | tfds_name='ucf101',
123 | train_split=['train', 'test'],
124 | eval_split=['train', 'test'],
125 | num_classes=101,
126 | image_size=64,
127 | batch_duplicates=1,
128 | cache_dataset=False,
129 | cropping='center',
130 | flipping='none',
131 | seq_len=16,
132 | ),
133 | 'kinetics600': D(
134 | name='kinetics600',
135 | tfds_name='kinetics600', # for eval config
136 | train_split='train',
137 | eval_split='test',
138 | num_classes=600,
139 | image_size=64,
140 | batch_duplicates=1,
141 | cache_dataset=False,
142 | cropping='none',
143 | flipping='none',
144 | seq_len=16,
145 | ),
146 | }
147 |
148 | task = D(
149 | name='image_generation',
150 | weight=1.,
151 | )
152 | task_d_list = [task]
153 | dataset_list = [dataset_config_map[data_name]]
154 |
155 | config = D(
156 | dataset=dataset_list[0],
157 | datasets=dataset_list,
158 |
159 | task=task_d_list[0],
160 | tasks=task_d_list,
161 |
162 | model=D(
163 | name='image_diffusion_model',
164 | train_schedule='cosine',
165 | infer_schedule='cosine',
166 | pred_type='eps',
167 | loss_type='eps',
168 | infer_iterations=100,
169 | td=0.,
170 | x0_clip='auto',
171 | b_scale=1.0,
172 | normalize_noisy_input=False,
173 | time_scaling=1000,
174 | pretrained_ckpt='',
175 | sampler_name='ddpm',
176 | conditional='class',
177 | self_cond='latent',
178 | b_type='uint8',
179 | flip_rate=0.,
180 | self_cond_rate=0.5,
181 | self_cond_by_masking=False,
182 | cond_dropout=0.,
183 | guidance=0.,
184 |
185 | # architecture extra
186 | use_cls_token=False,
187 | pos_encoding='sin_cos',
188 | patch_size=8,
189 | drop_path=0.,
190 | drop_units=0.,
191 | drop_att=0.,
192 | ),
193 |
194 | optimization=D(
195 | optimizer='lamb',
196 | learning_rate=1e-4,
197 | warmup_epochs=0,
198 | warmup_steps=5000,
199 | tail_steps=0,
200 | weight_decay=0.,
201 | global_clipnorm=1.0,
202 | momentum=0.9,
203 | beta1=0.9,
204 | beta2=0.999,
205 | eps=1e-8,
206 | learning_rate_schedule='none',
207 | learning_rate_scaling='none',
208 | end_lr_factor=0.0,
209 | ema_decay=0.9999,
210 | ema_name_exact_match=True,
211 | exclude_from_weight_decay='bias,beta,gamma',
212 | ),
213 |
214 | train=D(
215 | batch_size=512,
216 | epochs=100,
217 | steps=0,
218 | checkpoint_epochs=40,
219 | checkpoint_steps=0,
220 | keep_checkpoint_max=20,
221 | label_smoothing=0.,
222 | ),
223 |
224 | eval=D(
225 | tag='eval',
226 | checkpoint_dir='',
227 | batch_size=64,
228 | steps=100, # this is an approximation.
229 | ),
230 | )
231 |
232 | # Update model with architecture variant.
233 | for key, value in architecture_config_map[arch_variant].items():
234 | config.model[key] = value
235 |
236 | return config
237 |
--------------------------------------------------------------------------------
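get_config above keys off a 'data_name,arch_variant' string, and the dataset-specific configs and sweeps then override individual fields. A small sketch of that flow (the override values below are only examples, taken from config_diffusion_cifar10.py):

from configs import config_diffusion_base

config = config_diffusion_base.get_config('cifar10,tape')
# Override a few fields, as config_diffusion_cifar10.py and its sweep do.
config.model.self_cond = 'latent'
config.optimization.learning_rate = 3e-3
print(config.dataset.tfds_name, config.model.arch_name)  # cifar10 tape
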
/configs/config_diffusion_cifar10.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
 20 | from configs import config_diffusion_base as config_base
21 |
22 | DATA_NAME = 'cifar10'
23 | ARCH_VARIANT = 'tape'
24 | # ARCH_VARIANT = 'transunet'
25 |
26 |
27 | def get_config(config_str=None):
28 | """Returns config."""
29 | del config_str
30 | config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
31 | config.model.name = 'image_diffusion_model'
32 | config.model.b_scale = 1.0
33 | config.model.pred_type = 'eps'
34 | config.model.conditional = 'class'
35 | config.optimization.ema_decay = 0.9999
36 | config.eval.batch_size = 80
37 | config.eval.steps = 625
38 | return config
39 |
40 |
41 | def get_sweep(h):
 42 |   """Get the hyperparameter sweep."""
43 | if ARCH_VARIANT == 'transunet':
44 | return h.chainit([
45 | h.product([
46 | h.sweep('config.train.steps', [1_000_000]),
47 | h.sweep('config.train.checkpoint_steps', [10000]),
48 | h.sweep('config.train.keep_checkpoint_max', [100]),
49 | h.sweep('config.train.batch_size', [128*1]),
50 | h.sweep('config.optimization.learning_rate', [1e-4]),
51 | h.sweep('config.optimization.warmup_steps', [10000]),
52 | h.sweep('config.model.self_cond', ['none', 'x', 'eps']),
53 | h.sweep('config.model.udrop', [0.3]),
54 | h.sweep('config.model.dim', [256]),
55 | h.sweep('config.model.n_res_blocks', ['3,3,3']),
56 | h.sweep('config.model.ch_multipliers', ['1,1,1']),
57 | ]),
58 | ])
59 | elif ARCH_VARIANT == 'tape':
60 | return h.chainit([
61 | h.product([
62 | h.sweep('config.train.steps', [150_000]),
63 | h.sweep('config.train.checkpoint_steps', [10000]),
64 | h.sweep('config.train.keep_checkpoint_max', [10]),
65 | h.sweep('config.train.batch_size', [256]),
66 | h.sweep('config.optimization.optimizer', ['lamb']),
67 | h.sweep('config.optimization.learning_rate', [3e-3]),
68 | h.sweep('config.optimization.learning_rate_schedule', ['cosine@0.8']),
69 | h.sweep('config.optimization.end_lr_factor', [0.0]),
70 | h.sweep('config.optimization.warmup_steps', [10000]),
71 | h.sweep('config.optimization.beta2', [0.999]),
72 | h.sweep('config.optimization.weight_decay', [1e-2]),
73 | h.sweep('config.model.train_schedule', ['sigmoid@-3,3,0.9',
74 | 'simple_linear']),
75 | h.sweep('config.model.self_cond', ['latent']),
76 | h.sweep('config.model.self_cond_by_masking', [True]),
77 | h.sweep('config.model.self_cond_rate', [0.9]),
78 |
79 | h.sweep('config.model.patch_size', [2]),
80 | h.sweep('config.model.time_on_latent', [True]),
81 | h.sweep('config.model.cond_on_latent', [True]),
82 | h.sweep('config.model.cond_tape_writable', [False]),
83 | h.sweep('config.model.latent_pos_encoding', ['learned']),
84 | h.sweep('config.model.tape_pos_encoding', ['learned']),
85 | h.sweep('config.model.num_layers', ['2,2,2']), # '4,4',
86 | h.sweep('config.model.latent_slots', [128]),
87 | h.sweep('config.model.latent_dim', [512]),
88 | h.sweep('config.model.latent_num_heads', [16]),
89 | h.sweep('config.model.latent_mlp_ratio', [4]),
90 | h.sweep('config.model.tape_dim', [256]),
91 | h.sweep('config.model.tape_mlp_ratio', [2]),
92 | h.sweep('config.model.rw_num_heads', [8]),
93 | h.sweep('config.model.drop_units', [0.1]),
94 | h.sweep('config.model.drop_path', [0.1]),
95 | ]),
96 | ])
97 |
98 |
99 | def get_eval_args_and_tags(config, args, unused_config_flag):
100 | """Return eval args and tags."""
101 | args_and_tags = []
102 | for eval_split in [config.dataset.train_split]:
103 | for sampler_name in ['ddpm']:
104 | for infer_schedule in ['cosine']:
105 | for infer_iterations in [400, 1000]:
106 | # if sampler_name == 'ddim' and infer_iterations > 250: continue
107 | # if sampler_name == 'ddpm' and infer_iterations <= 250: continue
108 | eval_args = args.copy()
109 | sampler_name_s = sampler_name.replace('@', '')
110 | infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c')
111 | eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}'
112 | eval_args.update({
113 | 'config.eval.tag': eval_tag,
114 | 'config.dataset.eval_split': eval_split,
115 | 'config.model.sampler_name': sampler_name,
116 | 'config.model.infer_schedule': infer_schedule,
117 | 'config.model.infer_iterations': infer_iterations,
118 | })
119 | args_and_tags.append((eval_args, eval_tag, None))
120 | return args_and_tags
121 |
--------------------------------------------------------------------------------
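The h object passed to get_sweep is not part of this repository. Judging from its use here, sweep defines one axis of flag values, product takes a Cartesian product of axes, zipit pairs axes element-wise, and chainit concatenates sweeps. A hypothetical stand-in with those assumed semantics, just to make the sweep structure concrete:

import itertools

class H:
  """Hypothetical sweep helper; the real `h` is an internal tool."""

  def sweep(self, key, values):
    # One axis: a list of single-entry {flag: value} dicts.
    return [{key: v} for v in values]

  def product(self, axes):
    # Cartesian product of axes, merging the dicts of each combination.
    return [dict(kv for d in combo for kv in d.items())
            for combo in itertools.product(*axes)]

  def zipit(self, axes):
    # Pair equal-length axes element-wise instead of taking a product.
    return [dict(kv for d in combo for kv in d.items()) for combo in zip(*axes)]

  def chainit(self, sweeps):
    # Concatenate several sweeps into one list of hyperparameter points.
    return list(itertools.chain(*sweeps))

h = H()
points = h.product([h.sweep('config.model.self_cond', ['none', 'x', 'eps']),
                    h.sweep('config.model.udrop', [0.3])])
print(len(points))  # 3 points, each a flat {flag: value} dict
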
/configs/config_diffusion_cifar10d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
20 | from configs import config_diffusion_base as config_base
21 |
22 | DATA_NAME = 'cifar10'
23 | ARCH_VARIANT = 'transunet'
24 |
25 |
26 | def get_config(config_str=None):
27 | """Returns config."""
28 | del config_str
29 | config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
30 | config.model.name = 'image_discrete_diffusion_model'
31 | config.model.b_scale = 1.0
32 | config.model.pred_type = 'x'
33 | config.model.conditional = 'class'
34 | config.optimization.ema_decay = 0.9999
35 | config.eval.batch_size = 80
36 | config.eval.steps = 625
37 | return config
38 |
39 |
40 | def get_sweep(h):
41 |   """Get the hyperparameter sweep."""
42 | if ARCH_VARIANT == 'transunet':
43 | return h.chainit([
44 | h.product([
45 | h.sweep('config.train.steps', [1_500_000]),
46 | h.sweep('config.train.checkpoint_steps', [10000]),
47 | h.sweep('config.train.keep_checkpoint_max', [100]),
48 | h.sweep('config.train.batch_size', [128*1]),
49 | h.sweep('config.optimization.learning_rate', [1e-4]),
50 | h.sweep('config.optimization.warmup_steps', [10000]),
51 | h.sweep('config.optimization.optimizer', ['adamw']),
52 | h.sweep('config.optimization.exclude_from_weight_decay', ['']),
53 | h.sweep('config.model.self_cond', ['x']),
54 | h.sweep('config.model.self_cond_by_masking', [False]),
55 | h.sweep('config.model.self_cond_rate', [0.5]),
56 | h.sweep('config.model.b_scale', [0.5]),
57 | h.sweep('config.model.udrop', [0.]),
58 | h.sweep('config.model.dim', [256]),
59 | h.sweep('config.model.n_res_blocks', ['3,3,3']),
60 | h.sweep('config.model.ch_multipliers', ['1,1,1']),
61 | h.sweep('config.model.total_time_steps', [1000]),
62 | h.sweep('config.model.pred_type', ['x_softmax_xent']),
63 | h.zipit([
64 | h.sweep('config.model.b_type', ['uint8', 'uint8_s']),
65 | h.sweep('config.model.outp_softmax_groups', [0, 3]),
66 | ]),
67 | ]),
68 | ])
69 |
70 |
71 | def get_eval_args_and_tags(config, args, unused_config_flag):
72 | """Return eval args and tags."""
73 | args_and_tags = []
74 | for eval_split in [config.dataset.train_split]:
75 | for sampler_name in ['ddpm']:
76 | for infer_schedule in ['cosine']:
77 | for infer_iterations in [100, 250, 400]:
78 | # if sampler_name == 'ddim' and infer_iterations > 250: continue
79 | # if sampler_name == 'ddpm' and infer_iterations <= 250: continue
80 | eval_args = args.copy()
81 | sampler_name_s = sampler_name.replace('@', '')
82 | infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c')
83 | eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}'
84 | eval_args.update({
85 | 'config.eval.tag': eval_tag,
86 | 'config.dataset.eval_split': eval_split,
87 | 'config.model.sampler_name': sampler_name,
88 | 'config.model.infer_schedule': infer_schedule,
89 | 'config.model.infer_iterations': infer_iterations,
90 | })
91 | args_and_tags.append((eval_args, eval_tag, None))
92 | return args_and_tags
93 |
--------------------------------------------------------------------------------
/configs/config_diffusion_imagenet64.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
 20 | from configs import config_diffusion_base as config_base
21 |
22 | DATA_NAME = 'imagenet2012'
23 | ARCH_VARIANT = 'tape'
24 | # ARCH_VARIANT = 'transunet'
25 | IMAGE_SIZE = 64 * 1
26 |
27 |
28 | def get_config(config_str=None):
29 | """Returns config."""
30 | del config_str
31 | config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
32 | config.dataset.image_size = IMAGE_SIZE
33 | config.model.name = 'image_diffusion_model'
34 | config.model.b_scale = 1.0
35 | config.model.pred_type = 'eps'
36 | config.model.conditional = 'class'
37 | config.model.time_on_latent = True
38 | config.model.cond_on_latent = True
39 | config.model.cond_tape_writable = False
40 | config.optimization.ema_decay = 0.9999
41 | config.eval.batch_size = 80
42 | config.eval.steps = 625
43 | return config
44 |
45 |
46 | def get_sweep(h):
 47 |   """Get the hyperparameter sweep."""
48 | if ARCH_VARIANT == 'transunet':
49 | return h.chainit([
50 | h.product([
51 | h.sweep('config.train.steps', [250_000]),
52 | h.sweep('config.train.checkpoint_steps', [5000]),
53 | h.sweep('config.train.keep_checkpoint_max', [100]),
54 | h.sweep('config.train.batch_size', [1024*1]),
55 | h.sweep('config.optimization.optimizer', ['adamw']),
56 | h.sweep('config.optimization.exclude_from_weight_decay', ['']),
57 | h.sweep('config.optimization.learning_rate', [2e-4]),
58 | h.sweep('config.optimization.warmup_steps', [10000]),
59 | h.sweep('config.optimization.weight_decay', [0.]),
60 | h.sweep('config.model.self_cond', ['none', 'eps']),
61 | h.sweep('config.model.total_time_steps', [1, 1000]),
62 | h.sweep('config.model.udrop', [0.]),
63 | # h.sweep('config.model.dim', [256]),
64 | # h.sweep('config.model.n_res_blocks', ['3,3,3,3']),
65 | # h.sweep('config.model.ch_multipliers', ['1,2,2,3']),
66 | h.sweep('config.model.dim', [192]),
67 | h.sweep('config.model.n_res_blocks', ['3,3,3,3']),
68 | h.sweep('config.model.ch_multipliers', ['1,2,3,4']),
69 | h.sweep('config.model.self_cond_by_masking', [False]),
70 | h.sweep('config.model.self_cond_rate', [0.5]),
71 | ]),
72 | ])
73 | else:
74 | return h.chainit([
75 | h.product([
76 | h.sweep('config.train.steps', [150_000]),
77 | h.sweep('config.train.checkpoint_steps', [10_000]),
78 | h.sweep('config.train.keep_checkpoint_max', [20]),
79 | h.sweep('config.train.batch_size', [1024]),
80 | h.sweep('config.optimization.optimizer', ['lamb']),
81 | h.sweep('config.optimization.learning_rate', [2e-3]),
82 | h.sweep('config.optimization.learning_rate_schedule', ['cosine@0.7']),
83 | h.sweep('config.optimization.end_lr_factor', [0.]),
84 | h.sweep('config.optimization.warmup_steps', [10000]),
85 | h.sweep('config.optimization.weight_decay', [1e-2]),
86 | h.sweep('config.optimization.beta2', [0.999]),
87 | h.sweep('config.model.train_schedule', ['simple_linear']),
88 | h.sweep('config.model.pred_type', ['eps']),
89 | h.sweep('config.model.self_cond', ['latent']),
90 | h.sweep('config.model.self_cond_by_masking', [True]),
91 | h.sweep('config.model.self_cond_rate', [0.9]),
92 | h.sweep('config.model.total_time_steps', [1000]),
93 |
94 | h.sweep('config.model.patch_size', [4*2]), # 4
95 | h.sweep('config.model.latent_pos_encoding', ['learned']),
96 | h.sweep('config.model.tape_pos_encoding', ['learned']),
97 | h.sweep('config.model.num_layers', ['4,4,4,4']), # '6,6,6,6'
98 | h.sweep('config.model.latent_slots', [128]),
99 | h.sweep('config.model.latent_dim', [768]),
100 | h.sweep('config.model.latent_mlp_ratio', [4]),
101 | h.sweep('config.model.latent_num_heads', [16]),
102 | h.sweep('config.model.tape_dim', [512]),
103 | h.sweep('config.model.tape_mlp_ratio', [4]),
104 | h.sweep('config.model.rw_num_heads', [16]),
105 | h.sweep('config.model.drop_units', [0.]),
106 | h.sweep('config.model.drop_path', [0.]),
107 | ]),
108 | ])
109 |
110 |
111 | def get_eval_args_and_tags(config, args, unused_config_flag):
112 | """Return eval args and tags."""
113 | args_and_tags = []
114 | for eval_split in [config.dataset.train_split]:
115 | for sampler_name in ['ddpm']:
116 | for infer_schedule in ['cosine']:
117 | for infer_iterations in [100, 250, 1000]:
118 | eval_args = args.copy()
119 | sampler_name_s = sampler_name.replace('@', '')
120 | infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c')
121 | eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}'
122 | eval_args.update({
123 | 'config.eval.tag': eval_tag,
124 | 'config.dataset.eval_split': eval_split,
125 | 'config.model.sampler_name': sampler_name,
126 | 'config.model.infer_schedule': infer_schedule,
127 | 'config.model.infer_iterations': infer_iterations,
128 | })
129 | args_and_tags.append((eval_args, eval_tag, None))
130 | return args_and_tags
131 |
--------------------------------------------------------------------------------
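get_eval_args_and_tags (here and in the other diffusion configs) expands one trained model into several eval jobs, one per (split, sampler, schedule, iterations) combination, each distinguished by its tag. A worked example of the tag format for the first combination above:

eval_split, sampler_name = 'train', 'ddpm'
infer_schedule, infer_iterations = 'cosine', 100
sampler_name_s = sampler_name.replace('@', '')
infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c')
print(f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}')
# -> ev_train_ddpm_cosine_i100
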
/configs/config_diffusion_panoptic_base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 |
18 | import copy
19 | # pylint: disable=invalid-name,line-too-long,missing-docstring
20 | from configs import dataset_configs
21 | from configs import transform_configs
22 | from configs.config_base import architecture_config_map
23 | from configs.config_base import D
24 |
25 |
26 | def get_config(config_str=None):
 27 |   """config_str is empty or 'task@dataset,encoder,decoder,image_size,mask_size'."""
28 |
29 | if config_str:
30 | config_lists = config_str.split(',')
31 | task_variant, encoder_variant, decoder_variant = config_lists[:3]
32 | image_dim, mdim = config_lists[3], config_lists[4]
33 | # The `if` conditions here are for backwards compatibility with existing
34 | # scripts that pass in a single dim for image size and mask size.
35 | # We should eventually remove this and require callers to pass in the full
36 | # res.
37 | if 'x' in image_dim:
38 | image_size = [int(d) for d in image_dim.split('x')]
39 | else:
40 | image_size = (int(image_dim), int(image_dim))
41 | if 'x' in mdim:
42 | msize = [int(d) for d in mdim.split('x')]
43 | else:
44 | msize = (int(mdim), int(mdim))
45 | else:
46 | task_variant = 'panoptic_segmentation@coco/2017_panoptic_segmentation'
47 | encoder_variant = 'resnet-c'
48 | decoder_variant = 'transunet'
49 | image_size = (256, 256)
50 | msize = (64, 64)
51 |
52 | tasks_and_datasets = []
53 | for task_and_ds in task_variant.split('+'):
54 | tasks_and_datasets.append(task_and_ds.split('@'))
55 |
56 | decoder_config_map = {
57 | 'transformer': D(
58 | arch_name='transformer',
59 | num_layers=2,
60 | dim=512,
61 | dim_mlp=2048,
62 | num_heads=16,
63 | pos_encoding='learned',
64 | drop_path=0.0,
65 | drop_units=0.1,
66 | drop_att=0.0,
67 | ),
68 | 'transunet': D(
69 | arch_name='transunet',
70 | dim=32,
71 | in_strides=1,
72 | in_kernel_size=1,
73 | out_kernel_size=1,
74 | udrop=0.1,
75 | n_mlp_blocks=0,
76 | n_res_blocks='1,1,1,1',
77 | kernel_sizes='3',
78 | ch_multipliers='1,2,3,4',
79 | transformer_dim=512,
80 | transformer_strides=1,
81 | transformer_blocks=1,
82 | mhsa_resolutions='16,8',
83 | per_head_dim=64,
84 | u_pos_encoding='sin_cos',
85 | outp_softmax_groups=16,
86 | ),
87 | 'tape': D(
88 | arch_name='tape',
89 | num_layers='1,1',
90 | latent_slots=128,
91 | latent_dim=256*2,
92 | latent_mlp_ratio=4,
93 | latent_num_heads=16,
94 | tape_dim=256*2,
95 | tape_mlp_ratio=4,
96 | rw_num_heads=16,
97 | conv_kernel_size=0,
98 | conv_drop_units=0.,
99 | drop_path=0.,
100 | drop_units=0.,
101 | drop_att=0.,
102 | patch_size=8,
103 | outp_softmax_groups=2,
104 | pos_encoding='sin_cos',
105 | patch_scales='1',
106 | patch_scales_w='1',
107 | latent_pos_encoding='learned',
108 | tape_pos_encoding='learned',
109 | ),
110 | }
111 |
112 | task_config_map = {
113 | 'panoptic_segmentation': D(
114 | name='panoptic_segmentation',
115 | vocab_id=16,
116 | image_size=image_size,
117 | n_bits_label=16,
118 | max_instances_per_image=101,
119 | object_order='random',
120 | color_jitter_strength=0.,
121 | jitter_scale_min=0.3,
122 | jitter_scale_max=1.0,
123 | min_pixels=40,
124 | weight=1.0,
125 | metric=D(
126 | name='coco_panoptic_segmentation',
127 | results_dir='',
128 | ),
129 | ),
130 | 'video_panoptic_segmentation': D(
131 | name='video_panoptic_segmentation',
132 | vocab_id=18,
133 | image_size=image_size,
134 | n_bits_label=16,
135 | max_instances_per_image=256, # including id 0.
136 | object_order='shuffle',
137 | color_jitter_strength=0.,
138 | jitter_scale_min=1.0,
139 | jitter_scale_max=1.0,
140 | weight=1.0,
141 | proceeding_frames='-2,-1',
142 | eval_single_frames=False,
143 | eval_use_gt_cond_frames=False,
144 | frames_dropout=0.1,
145 | max_num_frames=100,
146 | min_pixels=40,
147 | metric=D(
148 | name='segmentation_and_tracking_quality',
149 | results_dir=''
150 | ),
151 | ),
152 | }
153 |
154 | task_d_list = []
155 | dataset_list = []
156 | for task_name, ds_name in tasks_and_datasets:
157 | task_d_list.append(task_config_map[task_name])
158 | dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name])
159 | dataset_list.append(dataset_config)
160 |
161 | config = D(
162 | dataset=dataset_list[0],
163 | datasets=dataset_list,
164 |
165 | task=task_d_list[0],
166 | tasks=task_d_list,
167 |
168 | encoder=D(),
169 |
170 | decoder=D(),
171 |
172 | model=D(
173 | name='panoptic_diffusion',
174 | train_schedule='cosine',
175 | infer_schedule='cosine',
176 | train_noise='normal',
177 | infer_noise='normal',
178 | noise_std=1.0,
179 | noise_truncation=False,
180 | pred_type='x_softmax_xent',
181 | iter_start=0,
182 | step_bias=0.,
183 | td=0.,
184 | td_schedule='constant',
185 | x0_clip='auto',
186 | b_scale=0.1,
187 | normalize_noisy_input=False,
188 | total_time_steps=1000.,
189 | pretrained_ckpt='',
190 | self_cond='none',
191 | conditional='cat',
192 | iterations=100,
193 |           iterations_2=100,  # only used in video inference, where fewer iterations can be used from the 2nd frame onwards.
194 | sampler='ddim',
195 | l_tile_factors=1,
196 | msize=msize,
197 | mask_weight_p=0.,
198 | self_cond_rate=0.5,
199 | self_cond_by_masking=False,
200 |
201 | # extra architecture
202 | image_size=image_size,
203 | decoder_variant=decoder_variant,
204 | use_cls_token=False,
205 | patch_size=16,
206 | drop_path=0.1,
207 | drop_units=0.1,
208 | drop_att=0.0,
209 | pos_encoding='sin_cos',
210 | dec_proj_mode='mlp',
211 | enc_drop=0.,
212 | enc_fuse='pyramid_merge',
213 | enc_fuse_upsample='nearest',
214 | enc_fuse_dim=256,
215 | frozen_backbone=False,
216 | ),
217 |
218 | optimization=D(
219 | optimizer='adamw',
220 | learning_rate=1e-3,
221 | end_lr_factor=0.01,
222 | warmup_epochs=10,
223 | warmup_steps=0, # set to >0 to override warmup_epochs.
224 | weight_decay=0.05,
225 | global_clipnorm=-1.,
226 | beta1=0.9,
227 | beta2=0.95,
228 | eps=1e-8,
229 | ema_decay=0.9999,
230 | ema_name_exact_match=True,
231 | learning_rate_schedule='linear',
232 | learning_rate_scaling='none',
233 | ),
234 |
235 | train=D(
236 | batch_size=32,
237 | epochs=40,
238 | steps=0, # set to >0 to override epochs.
239 | checkpoint_epochs=1,
240 | checkpoint_steps=0, # set to >0 to override checkpoint_epochs.
241 | keep_checkpoint_max=10,
242 | loss_type='xent',
243 | ),
244 |
245 | eval=D(
246 | tag='eval',
247 | checkpoint_dir='', # checkpoint_dir will be model_dir if not set.
248 |           batch_size=4,  # should evenly divide the total number of eval examples.
249 | steps=0, # 0 means eval over full validation set.
250 | ),
251 | )
252 |
253 | # Update model with architecture variant.
254 | for key, value in architecture_config_map[encoder_variant].items():
255 | config.model[key] = value
256 |
257 | # Update decoder architecture variant.
258 | for key, value in decoder_config_map[decoder_variant].items():
259 | config.decoder[key] = value
260 |
261 | # on-the-fly gt gathering for metric computation
262 | # config.dataset.coco_annotations_dir_for_metrics = ''
263 |
264 | config.task.train_transforms = transform_configs.get_panoptic_segmentation_train_transforms(
265 | image_size, msize, 1.0, 1.0, 0.)
266 | config.task.eval_transforms = transform_configs.get_panoptic_segmentation_eval_transforms(
267 | image_size)
268 |
269 | return config
270 |
--------------------------------------------------------------------------------
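This base config is parameterized by a five-field string: task@dataset, encoder variant, decoder variant, image size and mask size, where sizes may be 'HxW' or a single dimension. A short sketch of a call matching that parsing (it mirrors how config_diffusion_panoptic_image.py builds its config string, and assumes the corresponding dataset entry exists in configs/dataset_configs.py):

from configs import config_diffusion_panoptic_base

config = config_diffusion_panoptic_base.get_config(
    'panoptic_segmentation@coco/2017_panoptic_segmentation,'
    'resnet-c,transunet,1024x1024,512x512')
print(config.task.image_size)    # [1024, 1024]
print(config.model.msize)        # [512, 512]
print(config.decoder.arch_name)  # transunet
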
/configs/config_diffusion_panoptic_image.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Config for conditional mask modeling."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
20 | from configs import config_diffusion_panoptic_base as config_base
21 | from configs import transform_configs
22 |
23 | IMAGE_SIZE = '1024x1024'
24 | MASK_SIZE = '512x512'
25 | MODE = 'train_high_res' # one of ['train_high_res', 'train_low_res']
26 |
27 |
28 | def get_config(config_str=None):
29 | """Returns config."""
30 | if config_str:
31 | task_variant = config_str
32 | else:
33 | task_variant = 'panoptic_segmentation@coco/2017_panoptic_segmentation'
34 | encoder_variant = 'resnet-c'
35 | decoder_variant = 'transunet'
36 | config = config_base.get_config(
37 | f'{task_variant},{encoder_variant},{decoder_variant},{IMAGE_SIZE},{MASK_SIZE}')
38 | image_size = [int(x) for x in IMAGE_SIZE.split('x')]
39 | mask_size = [int(x) for x in MASK_SIZE.split('x')]
40 | config.task.train_transforms = transform_configs.get_panoptic_segmentation_train_transforms(
41 | image_size, mask_size, 1.0, 1.0, 0.)
42 | config.task.eval_transforms = transform_configs.get_panoptic_segmentation_eval_transforms(
43 | image_size)
44 | config.model.name = 'panoptic_diffusion'
45 | config.model.train_schedule = 'cosine'
46 | config.model.l_tile_factors = 1
47 | config.model.frozen_backbone = False
48 | config.model.enc_drop = 0.
49 | config.model.enc_fuse = 'pyramid_merge'
50 | config.model.enc_fuse_upsample = 'nearest'
51 | config.model.enc_fuse_dim = 256
 52 |   config.model.total_time_steps = 1.0  # for legacy compatibility.
53 | config.decoder.mhsa_resolutions = '0'
54 | config.decoder.n_mlp_blocks = 0
55 | config.decoder.outp_softmax_groups = 0
56 | config.decoder.in_kernel_size = 1
57 | config.decoder.out_kernel_size = 1
58 | config.optimization.learning_rate_schedule = 'linear'
59 | config.optimization.end_lr_factor = 0.02
60 | config.optimization.weight_decay = 0.05
61 | config.optimization.beta2 = 0.999
62 | config.optimization.warmup_epochs = 5
63 | config.optimization.global_clipnorm = 1.
64 | return config
65 |
66 |
67 | def get_sweep(h):
 68 |   """Get the hyperparameter sweep."""
69 | if MODE == 'train_low_res':
70 | return h.chainit([
71 | h.product([
72 | h.sweep('config.train.epochs', [800]),
73 | h.sweep('config.train.batch_size', [512]),
74 | h.sweep('config.train.checkpoint_epochs', [10]),
75 | h.sweep('config.train.keep_checkpoint_max', [10]),
76 | h.sweep('config.optimization.learning_rate', [1e-4]),
77 | h.sweep('config.optimization.end_lr_factor', [0.1]),
78 | h.sweep('config.optimization.warmup_epochs', [5]),
79 | h.sweep('config.optimization.ema_decay', [0.999]),
80 | h.sweep('config.model.b_scale', [0.1]),
81 | h.sweep('config.model.pred_type', ['x_softmax_xent']),
82 | h.sweep('config.model.self_cond', ['none']),
83 | h.sweep('config.model.conditional', ['cat+attn']),
84 | h.sweep('config.model.mask_weight_p', [0.2]),
85 | h.sweep('config.task.train_transforms[1].min_scale', [1.0]), # jitter_scale
86 | h.sweep('config.task.train_transforms[1].max_scale', [3.0]),
87 | h.sweep('config.task.train_transforms[4].color_jitter_strength', [1.0]),
88 | h.sweep('config.decoder.dim', [128]),
89 | h.sweep('config.decoder.udrop', [0.]),
90 | h.sweep('config.decoder.n_res_blocks', ['1,1,1,1']),
91 | h.sweep('config.decoder.ch_multipliers', ['1,1,2,2']),
92 | h.sweep('config.decoder.transformer_strides', [1]),
93 | h.sweep('config.decoder.transformer_dim', [512]),
94 | h.sweep('config.decoder.transformer_blocks', [6]),
95 | h.sweep('config.decoder.outp_softmax_groups', [2]),
96 | ]),
97 | ])
98 | elif MODE == 'train_high_res':
99 | return h.chainit([
100 | h.product([
101 | h.sweep('config.train.epochs', [15]),
102 | h.sweep('config.train.batch_size', [16]),
103 | h.sweep('config.train.checkpoint_epochs', [1]),
104 | h.sweep('config.optimization.learning_rate', [1e-5]),
105 | h.sweep('config.optimization.end_lr_factor', [0.1]),
106 | h.sweep('config.optimization.warmup_epochs', [0]),
107 | h.sweep('config.optimization.ema_decay', [0.999]),
108 | h.sweep('config.model.b_scale', [0.1]),
109 | h.sweep('config.model.pred_type', ['x_softmax_xent']),
110 | h.sweep('config.model.self_cond', ['none']),
111 | h.sweep('config.model.conditional', ['cat+attn']),
112 | h.sweep('config.model.mask_weight_p', [0.2]),
113 | h.sweep('config.task.train_transforms[1].min_scale', [1.0]),
114 | h.sweep('config.task.train_transforms[1].max_scale', [1.0]),
115 | h.sweep('config.decoder.dim', [128]),
116 | h.sweep('config.decoder.udrop', [0.]),
117 | h.sweep('config.decoder.n_res_blocks', ['1,1,1,1']),
118 | h.sweep('config.decoder.ch_multipliers', ['1,1,2,2']),
119 | h.sweep('config.decoder.transformer_strides', [1]),
120 | h.sweep('config.decoder.transformer_dim', [512]),
121 | h.sweep('config.decoder.transformer_blocks', [6]),
122 | h.sweep('config.decoder.outp_softmax_groups', [2]),
123 | ]),
124 | ])
125 |
126 |
127 | def get_eval_args_and_tags(config, args, unused_config_flag):
128 | """Return eval args and tags."""
129 | args_and_tags = []
130 | for eval_split in [config.dataset.eval_split]:
131 | for sampler in ['ddim']:
132 | for iterations in [20]:
133 | for td in [1.0, 2.0]:
134 | for min_pixels in [40]:
135 | eval_args = args.copy()
136 | eval_tag = f'ev_{eval_split}_{sampler}_i{iterations}_td{td}_p{min_pixels}'
137 | results_dir = eval_args['model_dir'] + '/' + eval_tag # pylint: disable=unused-variable
138 | eval_args.update({
139 | 'config.eval.tag': eval_tag,
140 | 'config.eval.batch_size': 8,
141 | 'config.eval.steps': 0,
142 | 'config.model.sampler': sampler,
143 | 'config.model.iterations': iterations,
144 | 'config.model.td': td,
145 | 'config.task.min_pixels': min_pixels,
146 | # 'config.task.metric.results_dir': results_dir,
147 | })
148 | if eval_split == 'train':
149 | eval_args.update({
150 | 'config.dataset.eval_split': 'train',
151 | 'config.eval.steps': 100,
152 | })
153 | args_and_tags.append((eval_args, eval_tag, None))
154 | return args_and_tags
155 |
156 |
157 |
--------------------------------------------------------------------------------
/configs/config_diffusion_panoptic_video.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Video panoptic segmentation config."""
17 |
18 | # pylint: disable=invalid-name,line-too-long
19 |
20 | from configs import config_diffusion_panoptic_base as config_base
21 | from configs import transform_configs
22 |
23 |
24 | def get_config(config_str=None):
25 | """Returns config."""
26 | if config_str:
27 | task_variant = config_str
28 | else:
29 | task_variant = 'video_panoptic_segmentation@kittistep_vps'
30 |
31 | encoder_variant = 'resnet-c'
32 | decoder_variant = 'transunet'
33 |
34 | if 'kittistep' in task_variant:
35 | imgsize = '384x1248'
36 | msize = '192x624'
37 | elif 'davis' in task_variant:
38 | imgsize = '512x1024'
39 | msize = '256x512'
40 | else:
41 | imgsize = '256x256'
42 | msize = '128x128'
43 |
44 | config = config_base.get_config(
45 | f'{task_variant},{encoder_variant},{decoder_variant},{imgsize},{msize}')
46 |
47 | image_size = [int(x) for x in imgsize.split('x')]
48 | mask_size = [int(x) for x in msize.split('x')]
49 | config.task.train_transforms = transform_configs.get_video_panoptic_segmentation_train_transforms(
50 | image_size, mask_size, 1.0, 1.0, 0.)
51 | config.task.eval_transforms = transform_configs.get_video_panoptic_segmentation_eval_transforms(
52 | image_size, mask_size, 100)
53 |
54 | config.model.name = 'panoptic_diffusion'
55 | config.model.train_schedule = 'cosine'
56 | config.model.l_tile_factors = 1
57 | config.model.frozen_backbone = False
58 | config.model.enc_drop = 0.
59 | config.model.enc_fuse = 'pyramid_merge'
60 | config.model.enc_fuse_upsample = 'nearest'
61 | config.model.enc_fuse_dim = 256
62 | config.model.b_scale = 0.1
63 | config.model.pred_type = 'x_softmax_xent'
64 | config.model.self_cond = 'none'
65 | config.model.conditional = 'cat+attn'
66 | config.model.mask_weight_p = 0.2
67 |
68 | config.decoder.mhsa_resolutions = '0'
69 | config.decoder.n_mlp_blocks = 0
70 | config.decoder.in_kernel_size = 1
71 | config.decoder.out_kernel_size = 1
72 | config.decoder.output_residual = False
73 | config.decoder.input_scaling = False
74 | config.decoder.dim = 128
75 | config.decoder.udrop = 0.
76 | config.decoder.n_res_blocks = '1,1,1,1'
77 | config.decoder.ch_multipliers = '1,1,2,2'
78 | config.decoder.transformer_strides = 1
79 | config.decoder.transformer_dim = 512
80 | config.decoder.transformer_blocks = 6
81 | config.decoder.outp_softmax_groups = 2
82 |
83 | config.optimization.learning_rate_schedule = 'linear'
84 | config.optimization.end_lr_factor = 0.02
85 | config.optimization.weight_decay = 0.05
86 | config.optimization.beta2 = 0.999
87 | config.optimization.warmup_epochs = 0
88 | config.optimization.global_clipnorm = 1.
89 |
90 | config.task.proceeding_frames = '-2,-1'
91 | config.task.eval_single_frames = False
92 | config.task.eval_use_gt_cond_frames = False
93 |
94 | if 'davis' in task_variant:
95 | config.task.max_instances_per_image = 16
96 | config.task.max_num_frames = 105
97 | config.task.eval_transforms[4].max_num_frames = 105
98 | config.task.metric.name = 'davis_video_object_segmentation'
99 |
100 | config.eval.batch_size = 2
101 | return config
102 |
103 |
104 | def get_sweep(h):
105 |   """Get the hyperparameter sweep."""
106 |
107 | return h.chainit([
108 | h.product([
109 | h.sweep('config.train.steps', [50000]),
110 | h.sweep('config.train.batch_size', [32]),
111 | h.sweep('config.train.checkpoint_steps', [1000]),
112 | h.sweep('config.optimization.learning_rate', [1e-5, 3e-5]),
113 | h.sweep('config.optimization.end_lr_factor', [1.]),
114 | h.sweep('config.optimization.warmup_epochs', [0]),
115 | h.sweep('config.optimization.ema_decay', [0.99]),
116 | h.zipit([
117 | h.sweep('config.task.train_transforms[0].min_scale', [1.0]),
118 | h.sweep('config.task.train_transforms[0].max_scale', [1.0]),
119 | ]),
120 | h.sweep('config.task.train_transforms[3].color_jitter_strength', [0.]),
121 | h.sweep('config.task.object_order', ['shuffle']),
122 | h.sweep('config.task.frames_dropout', [0.2]),
123 | ]),
124 | ])
125 |
126 |
127 |
--------------------------------------------------------------------------------
/configs/dataset_configs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Dataset configs."""
17 | import os
18 | from configs.config_base import D
19 |
20 |
21 | _shared_dataset_config = D(
22 | batch_duplicates=1,
23 | cache_dataset=True,
24 | )
25 |
26 | # Generate tfrecords for the dataset using data/scripts/create_coco_tfrecord.py
27 | # and add paths here.
28 | COCO_TRAIN_TFRECORD_PATTERN = 'gs://pix2seq/multi_task/data/coco/tfrecord/train*'
29 | COCO_VAL_TFRECORD_PATTERN = 'gs://pix2seq/multi_task/data/coco/tfrecord/val*'
30 |
31 | # Download from gs://pix2seq/multi_task/data/coco/json
32 | COCO_ANNOTATIONS_DIR = '/tmp/coco_annotations'
33 |
34 | _shared_coco_dataset_config = D(
35 | train_file_pattern=COCO_TRAIN_TFRECORD_PATTERN,
36 | val_file_pattern=COCO_VAL_TFRECORD_PATTERN,
37 | train_num_examples=118287,
38 | eval_num_examples=5000,
39 | train_split='train',
40 | eval_split='validation',
41 | # Directory of annotations used by the metrics.
42 | # Also need to set train_filename_for_metrics and val_filename_for_metrics.
43 | # If unset, groundtruth annotations should be specified via
44 | # record_groundtruth.
45 | coco_annotations_dir_for_metrics=COCO_ANNOTATIONS_DIR,
46 | label_shift=0,
47 | **_shared_dataset_config
48 | )
49 |
50 | dataset_configs = {
51 | 'coco/2017_object_detection':
52 | D(
53 | name='coco/2017_object_detection',
54 | train_filename_for_metrics='instances_train2017.json',
55 | val_filename_for_metrics='instances_val2017.json',
56 | category_names_path=os.path.join(
57 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'],
58 | 'instances_val2017.json'),
59 | **_shared_coco_dataset_config
60 | ),
61 | 'coco/2017_instance_segmentation':
62 | D(
63 | name='coco/2017_instance_segmentation',
64 | train_filename_for_metrics='instances_train2017.json',
65 | val_filename_for_metrics='instances_val2017.json',
66 | category_names_path=os.path.join(
67 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'],
68 | 'instances_val2017.json'),
69 | **_shared_coco_dataset_config
70 | ),
71 | 'coco/2017_keypoint_detection':
72 | D(
73 | name='coco/2017_keypoint_detection',
74 | train_filename_for_metrics='person_keypoints_train2017.json',
75 | val_filename_for_metrics='person_keypoints_val2017.json',
76 | category_names_path=os.path.join(
77 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'],
78 | 'person_keypoints_val2017.json'),
79 | **_shared_coco_dataset_config
80 | ),
81 | 'coco/2017_captioning':
82 | D(name='coco/2017_captioning',
83 | train_filename_for_metrics='captions_train2017_eval_compatible.json',
84 | val_filename_for_metrics='captions_val2017_eval_compatible.json',
85 | **_shared_coco_dataset_config),
86 | }
87 |
--------------------------------------------------------------------------------
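The tfrecord patterns and annotations directory above point at the public gs://pix2seq bucket. For local training they are typically repointed after generating tfrecords with data/scripts/create_coco_tfrecord.py (its flags are not shown here). A sketch of the config-side override, with hypothetical local paths; only the field names are taken from this file:

from configs import dataset_configs

ds = dataset_configs.dataset_configs['coco/2017_object_detection']
# Hypothetical local paths produced by the tfrecord scripts and annotation download.
ds.train_file_pattern = '/data/coco/tfrecord/train*'
ds.val_file_pattern = '/data/coco/tfrecord/val*'
ds.coco_annotations_dir_for_metrics = '/data/coco/json'
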
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Cityscapes dataset."""
17 | from data import dataset as dataset_lib
18 | import tensorflow as tf
19 |
20 |
21 | @dataset_lib.DatasetRegistry.register('cityscapes_panoptic')
22 | class CityscapesPanopticDataset(dataset_lib.TFRecordDataset):
23 | """Cityscapes panoptic dataset."""
24 |
25 | def get_feature_map(self, training):
26 | """Returns feature map for parsing the TFExample."""
27 | del training
28 | return {
29 | 'image/encoded':
30 | tf.io.FixedLenFeature([], tf.string),
31 | 'image/segmentation/class/encoded':
32 | tf.io.FixedLenFeature([], tf.string),
33 | 'image/height':
34 | tf.io.FixedLenFeature([], tf.int64),
35 | 'image/width':
36 | tf.io.FixedLenFeature([], tf.int64),
37 | 'image/filename':
38 | tf.io.FixedLenFeature([], tf.string),
39 | }
40 |
41 | def extract(self, example, training):
42 | """Extracts needed features & annotations into a flat dictionary.
43 |
44 | Note:
45 | - label starts at 1 instead of 0, as 0 is reserved for special use
46 | (such as padding).
47 | - coordinates (e.g. bbox) are (normalized to be) in [0, 1].
48 |
49 | Args:
50 | example: `dict` of raw features.
51 | training: `bool` of training vs eval mode.
52 |
53 | Returns:
54 | example: `dict` of relevant features and labels.
55 | """
56 | # Decode image and label.
57 | image = tf.io.decode_image(example['image/encoded'], channels=3)
58 | image.set_shape([1024, 2048, 3])
59 | label = example['image/segmentation/class/encoded']
60 | label = tf.io.decode_raw(
61 | example['image/segmentation/class/encoded'], out_type=tf.int32)
62 | label_shape = tf.stack([1024, 2048])
63 | label = tf.reshape(label, label_shape)
64 |
65 |     # Map instance ids to a shuffled range(1, num_instances + 1).
66 | unique_instance_ids, _ = tf.unique(tf.reshape(label, [-1]))
67 | num_instances = tf.size(unique_instance_ids)
68 | new_instance_ids = tf.random.shuffle(tf.range(1, num_instances + 1))
69 | def map_ids(x, src_ids, tgt_ids):
70 | """Convert object ids into semantic classes."""
71 | x = tf.equal(x[:, :, tf.newaxis], src_ids[tf.newaxis, tf.newaxis, :])
72 | x = tf.reduce_sum(tf.cast(x, tgt_ids.dtype) *
73 | tgt_ids[tf.newaxis, tf.newaxis, :], -1)
74 | return x
75 | identity = map_ids(label, unique_instance_ids, new_instance_ids)
76 |
77 | # label = class * max_instances_per_class + per_class_instance_id
78 | semantic = label // self.config.max_instances_per_class
79 |
80 | ignore_mask = tf.logical_not(tf.logical_and(
81 | tf.greater_equal(semantic, 0),
82 | tf.less(semantic, self.config.num_classes -
83 | 1))) # num_classes includes padding class.
84 | # 0 is reserved for background and labels which are to be ignored.
85 | semantic = tf.where(ignore_mask, tf.zeros_like(semantic), semantic + 1)
86 | identity = tf.where(ignore_mask, tf.zeros_like(identity), identity)
87 |
88 | return {
89 | 'image':
90 | tf.image.convert_image_dtype(image, tf.float32),
91 | # TODO(srbs): Find another hashing strategy that does not have
92 | # collisions possibly by leveraging the structure of the filename which
93 | # is _123456_123456.
94 | # Coco metrics would fail if there are duplicate image ids in preds or
95 | # gt.
96 | 'image/id':
97 | tf.strings.to_hash_bucket(
98 | example['image/filename'], num_buckets=1000000000),
99 | 'label_map': tf.stack([semantic, identity], -1)
100 | }
101 |
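# Illustrative sketch (not from this repository): the raw Cityscapes label used
# above packs `class * max_instances_per_class + per_class_instance_id` into a
# single integer, so it can be split back into its two parts. The divisor value
# of 1000 below is a hypothetical config setting.
def split_cityscapes_label(label, max_instances_per_class=1000):
  semantic = label // max_instances_per_class   # semantic class id
  instance = label % max_instances_per_class    # per-class instance id
  return semantic, instance

# e.g. split_cityscapes_label(26004) -> (26, 4) with the divisor above.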
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Dataset base class."""
17 |
18 | import abc
19 | import functools
20 | import operator
21 | from typing import Callable
22 | import ml_collections
23 |
24 | import registry
25 | import tensorflow as tf
26 | import tensorflow_datasets as tfds
27 |
28 |
29 | DatasetRegistry = registry.Registry()
30 |
31 |
32 | def mix_datasets(input_fns, weights):
33 | """Mix multiple datasets according to weights.
34 |
35 | Args:
36 | input_fns: a list of input_fn's. Each input_fn takes in an input_context and
37 | produces a tf.data.Dataset instance.
38 | weights: a list of floats where weights[i] represents the probability to
39 | sample from input_fns[i].
40 |
41 | Returns:
42 | a tf.data.Dataset instance.
43 | """
44 | def input_fn(input_context):
45 | dses = []
46 | for ifn in input_fns:
47 | dses.append(ifn(input_context))
48 | mixed_ds = tf.data.Dataset.sample_from_datasets(dses, weights)
49 | return mixed_ds
50 | return tf.distribute.get_strategy().distribute_datasets_from_function(
51 | input_fn)
52 |
53 |
54 | class Dataset(abc.ABC):
55 | """A dataset that handles creating a tf.data.Dataset."""
56 |
57 | def __init__(self, config: ml_collections.ConfigDict):
58 | """Constructs the dataset."""
59 | self.config = config.dataset
60 | self.task_config = config.task
61 |
62 | @abc.abstractmethod
63 | def extract(self, example, training):
64 | """Extracts needed features & annotations into a flat dictionary.
65 |
66 |     Note: be cautious about 0 in label, which should probably be reserved for
67 |     special use (such as padding).
68 |
69 | Args:
70 | example: `dict` of raw features.
71 | training: `bool` of training vs eval mode.
72 |
73 | Returns:
74 | example: `dict` of relevant features and labels
75 | """
76 |
77 | @abc.abstractmethod
78 | def load_dataset(self, input_context, training):
79 | """Load tf.data.Dataset from sources such as TFDS or TFRecord files."""
80 |
81 | def parse_example(self, example, training):
82 | del training
83 | return example
84 |
85 | def filter_example(self, unused_example, unused_training):
86 | return True
87 |
88 | def pipeline(self,
89 | process_single_example: Callable[[tf.data.Dataset, int, bool],
90 | tf.data.Dataset],
91 | global_batch_size: int, training: bool):
92 | """Data pipeline from name to preprocessed examples.
93 |
94 | Args:
95 |       process_single_example: a function that takes a dataset of single
96 |         examples and returns a dataset of processed examples.
97 | global_batch_size: global batch size.
98 | training: training vs eval mode.
99 |
100 | Returns:
101 | An input_fn which generates a tf.data.Dataset instance.
102 | """
103 | config = self.config
104 | def input_fn(input_context):
105 | dataset = self.load_dataset(input_context, training)
106 | if config.cache_dataset:
107 | dataset = dataset.cache()
108 |
109 | if input_context:
110 | batch_size = input_context.get_per_replica_batch_size(global_batch_size)
111 |         # Sharding is not necessary for TFDS given read_config above.
112 | # dataset = dataset.shard(input_context.num_input_pipelines,
113 | # input_context.input_pipeline_id)
114 | else:
115 | batch_size = global_batch_size
116 |
117 | if training:
118 | options = tf.data.Options()
119 | options.deterministic = False
120 | options.experimental_slack = True
121 | dataset = dataset.with_options(options)
122 | buffer_size = config.get('buffer_size', 0)
123 | if buffer_size <= 0:
124 | buffer_size = 10 * batch_size
125 | dataset = dataset.shuffle(buffer_size)
126 | dataset = dataset.repeat()
127 |
128 | dataset = dataset.map(
129 | lambda x: self.parse_example(x, training),
130 | num_parallel_calls=tf.data.experimental.AUTOTUNE
131 | ).filter(
132 | lambda x: self.filter_example(x, training)
133 | ).map(
134 | lambda x: self.extract(x, training),
135 | num_parallel_calls=tf.data.experimental.AUTOTUNE
136 | )
137 | if process_single_example:
138 | dataset = process_single_example(
139 | dataset, config.batch_duplicates, training)
140 |
141 | # TODO(b/181662974): Revert this and support non-even batch sizes.
142 | # dataset = dataset.batch(batch_size, drop_remainder=training)
143 | dataset = dataset.padded_batch(batch_size, drop_remainder=True)
144 | if config.batch_duplicates > 1 and training:
145 | dataset = dataset.map(self._flatten_dims,
146 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
147 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
148 | return dataset
149 |
150 | return input_fn
151 |
152 | def _flatten_dims(self, example):
153 | """Flatten first 2 dims when batch is independently duplicated."""
154 |
155 | def flatten_first_2_dims(t):
156 | """Merge first 2 dims."""
157 | shape_list = t.shape.as_list()
158 | new_bsz = functools.reduce(operator.mul, shape_list[:2])
159 | out_shape = [new_bsz] + shape_list[2:]
160 | return tf.reshape(t, out_shape)
161 |
162 | return tf.nest.map_structure(flatten_first_2_dims, example)
163 |
164 | @property
165 | @abc.abstractmethod
166 | def num_train_examples(self):
167 | """Number of training examples."""
168 |
169 | @property
170 | @abc.abstractmethod
171 | def num_eval_examples(self):
172 | """Number of eval examples."""
173 |
174 |
175 | class TFDSDataset(Dataset):
176 | """A dataset created from a TFDS dataset.
177 |
178 | Each example is a dictionary, but the fields may be different for each
179 | dataset.
180 |
181 | Each task would have a list of required fields (e.g. bounding boxes for
182 | object detection). When a dataset is used for a specific task, it should
183 | contain all the fields required by that task.
184 | """
185 |
186 | def __init__(self, config: ml_collections.ConfigDict):
187 | """Constructs the dataset."""
188 | super().__init__(config)
189 | self.builder = tfds.builder(self.config.tfds_name,
190 | data_dir=self.config.get('data_dir', None))
191 | self.builder.download_and_prepare()
192 | self.allowed_tasks = []
193 |
194 | def load_dataset(self, input_context, training):
195 | """Load tf.data.Dataset from TFDS."""
196 | split = self.config.train_split if training else self.config.eval_split
197 | # For TFDS, pass input_context using read_config to make TFDS read
198 | # different parts of the dataset on different workers.
199 | read_config = tfds.ReadConfig(input_context=input_context)
200 | if isinstance(split, list):
201 | dataset = self.builder.as_dataset(
202 | split=split[0], shuffle_files=training, read_config=read_config)
203 | for i in range(1, len(split)):
204 |         dataset = dataset.concatenate(self.builder.as_dataset(
205 |             split=split[i], shuffle_files=training, read_config=read_config))
206 | else:
207 | dataset = self.builder.as_dataset(
208 | split=split, shuffle_files=training, read_config=read_config)
209 | return dataset
210 |
211 | @property
212 | def num_train_examples(self):
213 | return self.builder.info.splits[self.config.train_split].num_examples
214 |
215 | @property
216 | def num_eval_examples(self):
217 | return self.builder.info.splits[
218 | self.config.eval_split].num_examples if not self.task_config.get(
219 | 'unbatch', False) else None
220 |
221 |
222 | class TFRecordDataset(Dataset):
223 | """A dataset created from tfrecord files."""
224 |
225 | def __init__(self, config: ml_collections.ConfigDict):
226 | """Constructs the dataset."""
227 | super().__init__(config)
228 | self.dataset_cls = tf.data.TFRecordDataset
229 |
230 | def load_dataset(self, input_context, training):
231 | """Load tf.data.Dataset from TFRecord files."""
232 | if training or self.config.eval_split == 'train':
233 | file_pattern = self.config.train_file_pattern
234 | else:
235 | file_pattern = self.config.val_file_pattern
236 | dataset = tf.data.Dataset.list_files(file_pattern, shuffle=training)
237 | dataset = dataset.interleave(
238 | self.dataset_cls, cycle_length=32, deterministic=not training,
239 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
240 | return dataset
241 |
242 | @abc.abstractmethod
243 | def get_feature_map(self, training):
244 | """Returns feature map(s) for parsing the TFExample.
245 |
246 |     Returns a single feature map (a dict) to parse a TFExample.
247 | Returns a tuple of (context feature map, sequence feature map) to parse a
248 | TFSequenceExample. Context features are non-sequence features, i.e.
249 | independent of time/frame. Sequence features have time/frame dimension.
250 |
251 | Args:
252 | training: `bool` of training vs eval mode.
253 | """
254 |
255 | def parse_example(self, example, training):
256 | """Parse the serialized example into a dictionary of tensors.
257 |
258 | Args:
259 | example: the serialized tf.train.Example or tf.train.SequenceExample.
260 | training: `bool` of training vs eval mode.
261 |
262 | Returns:
263 | a dictionary of feature name to tensors.
264 | """
265 | feature_map = self.get_feature_map(training)
266 | if isinstance(feature_map, dict):
267 | example = tf.io.parse_single_example(example, feature_map)
268 | else:
269 | context_features, sequence_features = feature_map
270 | example, sequence = tf.io.parse_single_sequence_example(
271 | example, context_features, sequence_features)
272 | example.update(sequence)
273 |
274 | for k in example:
275 | if isinstance(example[k], tf.SparseTensor):
276 | if example[k].dtype == tf.string:
277 | example[k] = tf.sparse.to_dense(example[k], default_value='')
278 | else:
279 | example[k] = tf.sparse.to_dense(example[k], default_value=0)
280 | return example
281 |
282 | @property
283 | def num_train_examples(self):
284 | return self.config.train_num_examples
285 |
286 | @property
287 | def num_eval_examples(self):
288 | return self.config.eval_num_examples if not self.task_config.get(
289 | 'unbatch', False) else None
290 |
291 |
292 |
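# Illustrative sketch (not from this repository): combining two already
# constructed Dataset instances into a single mixed training stream via
# `pipeline` and `mix_datasets`. The 0.7 / 0.3 sampling weights are
# hypothetical.
def build_mixed_train_dataset_example(dataset_a, dataset_b, global_batch_size):
  input_fns = [
      dataset_a.pipeline(process_single_example=None,
                         global_batch_size=global_batch_size, training=True),
      dataset_b.pipeline(process_single_example=None,
                         global_batch_size=global_batch_size, training=True),
  ]
  # Each (already batched) element is drawn from dataset_a with probability
  # ~0.7 and from dataset_b with probability ~0.3.
  return mix_datasets(input_fns, weights=[0.7, 0.3])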
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """All registered datasets."""
17 |
18 | from data import coco # pylint: disable=unused-import
19 |
--------------------------------------------------------------------------------
/data/decode_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Utility functions for decoding example into features and labels."""
17 |
18 | import tensorflow as tf
19 |
20 |
21 | def get_feature_map_for_image():
22 | return {
23 | 'image/encoded': tf.io.FixedLenFeature((), tf.string),
24 | 'image/source_id': tf.io.FixedLenFeature((), tf.string, ''),
25 | 'image/height': tf.io.FixedLenFeature((), tf.int64, -1),
26 | 'image/width': tf.io.FixedLenFeature((), tf.int64, -1),
27 | 'image/filename': tf.io.FixedLenFeature((), tf.string, ''),
28 | }
29 |
30 |
31 | def get_feature_map_for_object_detection():
32 | return {
33 | 'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
34 | 'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
35 | 'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
36 | 'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
37 | 'image/object/class/label': tf.io.VarLenFeature(tf.int64),
38 | 'image/object/area': tf.io.VarLenFeature(tf.float32),
39 | 'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
40 | 'image/object/score': tf.io.VarLenFeature(tf.float32),
41 | }
42 |
43 |
44 | def get_feature_map_for_instance_segmentation():
45 | return {
46 | 'image/object/segmentation':
47 | tf.io.RaggedFeature(
48 | value_key='image/object/segmentation_v',
49 | dtype=tf.float32,
50 | partitions=[
51 | tf.io.RaggedFeature.RowSplits('image/object/segmentation_sep') # pytype: disable=attribute-error
52 | ]),
53 | }
54 |
55 |
56 | def get_feature_map_for_keypoint_detection():
57 | return {
58 | 'image/object/keypoints':
59 | tf.io.RaggedFeature(
60 | value_key='image/object/keypoints_v',
61 | dtype=tf.float32,
62 | partitions=[
63 | tf.io.RaggedFeature.RowSplits('image/object/keypoints_sep') # pytype: disable=attribute-error
64 | ]),
65 | 'image/object/num_keypoints':
66 | tf.io.VarLenFeature(tf.int64),
67 | }
68 |
69 |
70 | def get_feature_map_for_captioning():
71 | return {
72 | 'image/caption': tf.io.VarLenFeature(tf.string),
73 | }
74 |
75 |
76 | def decode_image(example):
77 | """Decodes the image and set its static shape."""
78 | image = tf.io.decode_image(example['image/encoded'], channels=3)
79 | image.set_shape([None, None, 3])
80 | image = tf.image.convert_image_dtype(image, tf.float32)
81 | return image
82 |
83 |
84 | def decode_boxes(example):
85 | """Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
86 | xmin = example['image/object/bbox/xmin']
87 | xmax = example['image/object/bbox/xmax']
88 | ymin = example['image/object/bbox/ymin']
89 | ymax = example['image/object/bbox/ymax']
90 | return tf.stack([ymin, xmin, ymax, xmax], axis=-1)
91 |
92 |
93 | def decode_areas(example):
94 | xmin = example['image/object/bbox/xmin']
95 | xmax = example['image/object/bbox/xmax']
96 | ymin = example['image/object/bbox/ymin']
97 | ymax = example['image/object/bbox/ymax']
98 | height = tf.cast(example['image/height'], dtype=tf.float32)
99 | width = tf.cast(example['image/width'], dtype=tf.float32)
100 | return tf.cond(
101 | tf.greater(tf.shape(example['image/object/area'])[0], 0),
102 | lambda: example['image/object/area'],
103 | lambda: (xmax - xmin) * (ymax - ymin) * height * width)
104 |
105 |
106 | def decode_is_crowd(example):
107 | return tf.cond(
108 | tf.greater(tf.shape(example['image/object/is_crowd'])[0], 0),
109 | lambda: tf.cast(example['image/object/is_crowd'], dtype=tf.bool),
110 | lambda: tf.zeros_like(example['image/object/class/label'], dtype=tf.bool)
111 | )
112 |
113 |
114 | def decode_scores(example):
115 | return tf.cond(
116 | tf.greater(tf.shape(example['image/object/score'])[0], 0),
117 | lambda: example['image/object/score'],
118 | lambda: tf.ones_like(example['image/object/class/label'], # pylint: disable=g-long-lambda
119 | dtype=tf.float32)
120 | )
121 |
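# Illustrative sketch (not from this repository): composing the feature maps
# and decode helpers above into a single-example parser for a detection
# TFRecord. The densify loop mirrors what TFRecordDataset.parse_example does;
# the function name is hypothetical.
def parse_detection_example(serialized):
  feature_map = {
      **get_feature_map_for_image(),
      **get_feature_map_for_object_detection(),
  }
  example = tf.io.parse_single_example(serialized, feature_map)
  # VarLen features come back sparse; densify before using the helpers.
  for k, v in example.items():
    if isinstance(v, tf.SparseTensor):
      default = '' if v.dtype == tf.string else 0
      example[k] = tf.sparse.to_dense(v, default_value=default)
  return {
      'image': decode_image(example),
      'bbox': decode_boxes(example),
      'area': decode_areas(example),
      'is_crowd': decode_is_crowd(example),
      'label': example['image/object/class/label'],
  }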
--------------------------------------------------------------------------------
/data/obj365.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Object365 dataset class."""
17 |
18 | import ml_collections
19 | from data import dataset as dataset_lib
20 | from data import decode_utils
21 | import tensorflow as tf
22 |
23 |
24 | @dataset_lib.DatasetRegistry.register('obj365')
25 | class Obj365Dataset(dataset_lib.TFRecordDataset):
26 | """Dataset for Obj365 tasks."""
27 |
28 | def __init__(self, config: ml_collections.ConfigDict):
29 | """Constructs the dataset.
30 |
31 | Args:
32 | config: the model config.
33 | """
34 | super().__init__(config)
35 |
36 | if 'label_shift' in config.dataset:
37 | self.label_shift = config.dataset.label_shift
38 | else:
39 | self.label_shift = 0
40 |
41 | def _get_source_id(self, example):
42 | def _generate_source_id():
43 | return tf.strings.as_string(
44 | tf.strings.to_hash_bucket_fast(example['image/encoded'], 2**63 - 1))
45 |
46 | if self.config.get('regenerate_source_id', False):
47 | source_id = _generate_source_id()
48 | else:
49 | source_id = tf.cond(
50 | tf.greater(tf.strings.length(example['image/source_id']),
51 | 0), lambda: example['image/source_id'],
52 | _generate_source_id)
53 | return source_id
54 |
55 | def _decode_masks(self, example):
56 | """Decode a set of PNG masks to the tf.float32 tensors."""
57 | def _decode_png_mask(png_bytes):
58 | mask = tf.squeeze(
59 | tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
60 | mask = tf.cast(mask, dtype=tf.float32)
61 | mask.set_shape([None, None])
62 | return mask
63 |
64 | height = example['image/height']
65 | width = example['image/width']
66 | masks = example['image/object/mask']
67 | return tf.cond(
68 | tf.greater(tf.size(masks), 0),
69 | lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
70 | lambda: tf.zeros([0, height, width], dtype=tf.float32))
71 |
72 | def extract(self, example, training):
73 | """Extracts needed features & annotations into a flat dictionary.
74 |
75 | Note:
76 | - label starts at 1 instead of 0, as 0 is reserved for special use
77 | (such as padding).
78 | - coordinates (e.g. bbox) are (normalized to be) in [0, 1].
79 |
80 | Args:
81 | example: `dict` of raw features.
82 | training: `bool` of training vs eval mode.
83 |
84 | Returns:
85 | example: `dict` of relevant features and labels.
86 | """
87 | bbox = decode_utils.decode_boxes(example)
88 | example = {
89 | 'image': decode_utils.decode_image(example),
90 | 'image/id': self._get_source_id(example),
91 | 'bbox': bbox,
92 | 'is_crowd': decode_utils.decode_is_crowd(example),
93 | 'label': example['image/object/class/label'] + self.label_shift,
94 | 'area': decode_utils.decode_areas(example),
95 | }
96 | return example
97 |
98 | @property
99 | def num_train_examples(self):
100 | return {
101 | 'obj365': 1662289,
102 | 'obj365v1': 608606,
103 | 'obj365v2': 1662289
104 | }[self.config.dataset_name]
105 |
106 | @property
107 | def num_eval_examples(self):
108 | return {
109 | 'obj365': 80000,
110 | 'obj365v1': 30000,
111 | 'obj365v2': 80000
112 | }[self.config.dataset_name]
113 |
114 | @property
115 | def num_classes(self):
116 | return 365
117 |
118 | def get_feature_map(self, training):
119 | """Returns feature map for parsing the TFExample."""
120 | del training
121 | image_feature_map = decode_utils.get_feature_map_for_image()
122 | detection_feature_map = decode_utils.get_feature_map_for_object_detection()
123 | feature_map = {**image_feature_map, **detection_feature_map}
124 | if self.config.get('include_mask', False):
125 | feature_map.update({'image/object/mask': tf.io.VarLenFeature(tf.string)})
126 | return feature_map
127 |
--------------------------------------------------------------------------------
/data/recognition.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Image classification datasets."""
17 |
18 | from data import dataset as dataset_lib
19 | import tensorflow as tf
20 |
21 |
22 | @dataset_lib.DatasetRegistry.register('object_recognition')
23 | class ImageDataset(dataset_lib.TFDSDataset):
24 | """Dataset for image classification datasets."""
25 |
26 | def extract(self, example, training):
27 | """Extracts needed features & annotations into a flat dictionary.
28 |
29 | Args:
30 | example: `dict` of raw features.
31 | training: `bool` of training vs eval mode.
32 |
33 | Returns:
34 | example: `dict` of relevant features and labels.
35 | """
36 | image = example['image']
37 | if image.shape.rank == 2 or image.shape[-1] == 1:
38 | image = tf.image.grayscale_to_rgb(image)
39 | if 'label' in example:
40 | label = example['label']
41 | else:
42 | label = tf.zeros([], dtype=tf.int32)
43 | return {'image': image,
44 | 'label': label}
45 |
46 | @property
47 | def num_classes(self):
48 | return self.builder.info.features['label'].num_classes
49 |
50 |
51 |
--------------------------------------------------------------------------------
/data/scripts/create_davis_tfrecord.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Generate tfrecord for DAVIS2017."""
17 |
18 | import collections
19 | import os
20 |
21 | from absl import app
22 | from absl import flags
23 | from absl import logging
24 | import numpy as np
25 | from PIL import Image
26 | import tensorflow as tf
27 |
28 | flags.DEFINE_string('split', 'train', 'train or val')
29 | flags.DEFINE_integer('shards', 25, 'Number of output tfrecord shards.')
30 | flags.DEFINE_integer('num_frames', 3, 'Frames per example; <= 0 writes the whole video.')
31 | flags.DEFINE_string('data_dir', '', 'Root directory of the DAVIS dataset.')
32 | flags.DEFINE_string('output_dir', '', 'Directory to write tfrecord files to.')
33 |
34 | FLAGS = flags.FLAGS
35 |
36 |
37 | # There is one video ('tennis' in train) that labels one object as 255.
38 | def handle_mislabel_255(arr):
39 | maxm = np.max(arr[arr != 255])
40 | arr[arr == 255] = maxm + 1
41 | return arr
42 |
43 |
44 | def generate_tf_sequence_example(
45 | image_list, segmentation_list, filename, video_name):
46 | """Create tf sequence example."""
47 | num_frames = len(image_list)
48 | assert len(image_list) == len(segmentation_list)
49 | height, width, c = image_list[0].shape
50 | assert c == 3
51 | frame_id = filename.split('.')[0]
52 |
53 | example_proto = tf.train.SequenceExample(
54 | context=tf.train.Features(
55 | feature={
56 | 'image/format':
57 | tf.train.Feature(
58 | bytes_list=tf.train.BytesList(
59 | value=[bytes('png', 'utf-8')])),
60 | 'image/channels':
61 | tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
62 | 'image/height':
63 | tf.train.Feature(
64 | int64_list=tf.train.Int64List(value=[height])),
65 | 'image/width':
66 | tf.train.Feature(
67 | int64_list=tf.train.Int64List(value=[width])),
68 | 'video/name':
69 | tf.train.Feature(
70 | bytes_list=tf.train.BytesList(
71 | value=[bytes(video_name, 'utf-8')])),
72 | 'video/frame_id':
73 | tf.train.Feature(
74 | bytes_list=tf.train.BytesList(
75 | value=[bytes(frame_id, 'utf-8')])),
76 | 'video/num_frames':
77 | tf.train.Feature(
78 | int64_list=tf.train.Int64List(value=[num_frames])),
79 | }),
80 | feature_lists=tf.train.FeatureLists(
81 | feature_list={
82 | 'video/frames':
83 | tf.train.FeatureList(feature=[
84 | tf.train.Feature(
85 | bytes_list=tf.train.BytesList(
86 | value=[tf.io.encode_png(image).numpy()]))
87 | for image in image_list
88 | ]),
89 | 'video/segmentations':
90 | tf.train.FeatureList(feature=[
91 | tf.train.Feature(
92 | bytes_list=tf.train.BytesList(value=[
93 | tf.io.encode_png(tf.expand_dims(seg, -1)).numpy()
94 | ])) for seg in segmentation_list
95 | ]),
96 | }))
97 |
98 | return example_proto.SerializeToString()
99 |
100 |
101 | def main(unused_argv):
102 | split = FLAGS.split
103 | data_dir = FLAGS.data_dir
104 | images_dir = os.path.join(data_dir, 'JPEGImages/480p')
105 | annotation_dir = os.path.join(data_dir, 'Annotations_unsupervised/480p/')
106 | video_names = [
107 | s.strip() for s in tf.io.gfile.GFile(
108 | os.path.join(data_dir, f'ImageSets/2017/{split}.txt')).readlines()]
109 |
110 | num_frames = FLAGS.num_frames
111 | shards = FLAGS.shards
112 | output_dir = FLAGS.output_dir
113 | writers = [
114 | tf.io.TFRecordWriter(os.path.join(
115 | output_dir,
116 | f'{split}_{num_frames}-{i:05d}-of-{shards:05d}.tfrecord'))
117 | for i in range(shards)
118 | ]
119 |
120 | k = 0
121 | for i, video_name in enumerate(video_names):
122 | image_filenames = tf.io.gfile.listdir(
123 | os.path.join(images_dir, video_name))
124 | ann_filenames = tf.io.gfile.listdir(
125 | os.path.join(annotation_dir, video_name))
126 |
127 | all_images = collections.deque(
128 | maxlen=num_frames if num_frames > 0 else None)
129 | all_segs = collections.deque(
130 | maxlen=num_frames if num_frames > 0 else None)
131 | for j, (image_f, ann_f) in enumerate(zip(image_filenames, ann_filenames)):
132 | logging.info('%s, %s', video_name, image_f)
133 |
134 | # load the image.
135 | image = np.asarray(
136 | Image.open(
137 | tf.io.gfile.GFile(
138 | os.path.join(images_dir, video_name, image_f), 'rb')))
139 | all_images.append(image)
140 |
141 | # load the segmentations.
142 | data = np.array(
143 | Image.open(
144 | tf.io.gfile.GFile(
145 | os.path.join(annotation_dir, video_name, ann_f), 'rb')))
146 | data = handle_mislabel_255(data)
147 | all_segs.append(data)
148 |
149 | if j >= num_frames - 1 and num_frames > 0:
150 | serialized_example = generate_tf_sequence_example(
151 | all_images, all_segs, image_f, video_name)
152 | writers[k % shards].write(serialized_example)
153 | k += 1
154 |
155 | # Write all frames out in the same example.
156 | if num_frames <= 0:
157 | serialized_example = generate_tf_sequence_example(
158 | all_images, all_segs, image_filenames[-1], video_name)
159 | writers[k % shards].write(serialized_example)
160 | k += 1
161 |
162 | for writer in writers:
163 | writer.close()
164 |
165 |
166 | if __name__ == '__main__':
167 | app.run(main)
168 |
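# Illustrative sketch (not from this repository): parsing one of the
# SequenceExamples written above. The feature keys follow
# generate_tf_sequence_example; the function name is hypothetical.
def parse_davis_sequence_example(serialized):
  context_features = {
      'video/name': tf.io.FixedLenFeature([], tf.string),
      'video/num_frames': tf.io.FixedLenFeature([], tf.int64),
      'image/height': tf.io.FixedLenFeature([], tf.int64),
      'image/width': tf.io.FixedLenFeature([], tf.int64),
  }
  sequence_features = {
      'video/frames': tf.io.FixedLenSequenceFeature([], tf.string),
      'video/segmentations': tf.io.FixedLenSequenceFeature([], tf.string),
  }
  context, sequence = tf.io.parse_single_sequence_example(
      serialized, context_features, sequence_features)
  frames = tf.map_fn(
      lambda png: tf.io.decode_png(png, channels=3),
      sequence['video/frames'], fn_output_signature=tf.uint8)
  segmentations = tf.map_fn(
      lambda png: tf.squeeze(tf.io.decode_png(png, channels=1), -1),
      sequence['video/segmentations'], fn_output_signature=tf.uint8)
  return context, frames, segmentations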
--------------------------------------------------------------------------------
/data/scripts/create_kittistep_tfrecord.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Generate tfrecord for kitti-step."""
17 |
18 | import collections
19 | import os
20 | from typing import Optional
21 |
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 | import numpy as np
26 | from PIL import Image
27 | import tensorflow as tf
28 |
29 | flags.DEFINE_string('split', 'train', 'Dataset split to process, e.g. train or val.')
30 | flags.DEFINE_integer('shards', 25, 'Number of output tfrecord shards; <= 0 uses one shard per video.')
31 | flags.DEFINE_integer('num_frames', 3, 'Frames per example; <= 0 writes the whole video.')
32 | flags.DEFINE_string('raw_image_dir', '', 'Directory with raw KITTI-STEP images.')
33 | flags.DEFINE_string('raw_ann_dir', '', 'Directory with KITTI-STEP panoptic annotations.')
34 | flags.DEFINE_string('output_dir', '', 'Directory to write tfrecord files to.')
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | _INSTANCE_LABEL_DIVISOR = 1000
39 | _ENCODED_INSTANCE_LABEL_DIVISOR = 256
40 |
41 |
42 | def _decode_panoptic_map(panoptic_map_path: str) -> Optional[bytes]:
43 | """Decodes the panoptic map from encoded image file.
44 |
45 |
46 | Args:
47 | panoptic_map_path: Path to the panoptic map image file.
48 |
49 | Returns:
50 |     Panoptic map as int32 numpy array bytes, or None if the file does not exist.
51 | """
52 | if not tf.io.gfile.exists(panoptic_map_path):
53 | return None
54 | with tf.io.gfile.GFile(panoptic_map_path, 'rb') as f:
55 | panoptic_map = np.array(Image.open(f)).astype(np.int32)
56 | semantic_map = panoptic_map[:, :, 0]
57 | instance_map = (
58 | panoptic_map[:, :, 1] * _ENCODED_INSTANCE_LABEL_DIVISOR +
59 | panoptic_map[:, :, 2])
60 | panoptic_map = semantic_map * _INSTANCE_LABEL_DIVISOR + instance_map
61 | return panoptic_map.tobytes()
62 |
63 |
64 | def generate_tf_sequence_example(
65 | image_list, panoptic_map_list, filename, video_name):
66 | """Create tf sequence example."""
67 | assert len(image_list) == len(panoptic_map_list)
68 | height, width, c = image_list[0].shape
69 | assert c == 3
70 | frame_id = filename.split('.')[0]
71 |
72 | example_proto = tf.train.SequenceExample(
73 | context=tf.train.Features(
74 | feature={
75 | 'image/filename':
76 | tf.train.Feature(
77 | bytes_list=tf.train.BytesList(
78 | value=[bytes(filename, 'utf-8')])),
79 | 'image/format':
80 | tf.train.Feature(
81 | bytes_list=tf.train.BytesList(
82 | value=[bytes('png', 'utf-8')])),
83 | 'image/channels':
84 | tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
85 | 'image/height':
86 | tf.train.Feature(
87 | int64_list=tf.train.Int64List(value=[height])),
88 | 'image/width':
89 | tf.train.Feature(
90 | int64_list=tf.train.Int64List(value=[width])),
91 | 'image/segmentation/class/format':
92 | tf.train.Feature(
93 | bytes_list=tf.train.BytesList(
94 | value=[bytes('raw', 'utf-8')])),
95 | 'video/sequence_id':
96 | tf.train.Feature(
97 | bytes_list=tf.train.BytesList(
98 | value=[bytes(video_name, 'utf-8')])),
99 | 'video/frame_id':
100 | tf.train.Feature(
101 | bytes_list=tf.train.BytesList(
102 | value=[bytes(frame_id, 'utf-8')])),
103 | }),
104 | feature_lists=tf.train.FeatureLists(
105 | feature_list={
106 | 'image/encoded_list':
107 | tf.train.FeatureList(feature=[
108 | tf.train.Feature(
109 | bytes_list=tf.train.BytesList(
110 | value=[tf.io.encode_png(image).numpy()]))
111 | for image in image_list
112 | ]),
113 | 'image/segmentation/class/encoded_list':
114 | tf.train.FeatureList(feature=[
115 | tf.train.Feature(
116 | bytes_list=tf.train.BytesList(value=[seg]))
117 | for seg in panoptic_map_list
118 | ]),
119 | }))
120 |
121 | return example_proto.SerializeToString()
122 |
123 |
124 | def main(unused_argv):
125 | split = FLAGS.split
126 | raw_image_dir = FLAGS.raw_image_dir
127 | raw_ann_dir = FLAGS.raw_ann_dir
128 | num_frames = FLAGS.num_frames
129 | video_names = tf.io.gfile.listdir(os.path.join(raw_image_dir, split))
130 | if num_frames <= 0:
131 | assert FLAGS.shards <= 0
132 | shards = FLAGS.shards if FLAGS.shards > 0 else len(video_names)
133 | tf.io.gfile.makedirs(FLAGS.output_dir)
134 | writers = [
135 | tf.io.TFRecordWriter(os.path.join(
136 | FLAGS.output_dir,
137 | f'{split}_{num_frames}-{i:05d}-of-{shards:05d}.tfrecord'))
138 | for i in range(shards)
139 | ]
140 |
141 | k = 0
142 | for i, video_name in enumerate(video_names):
143 | frame_filenames = tf.io.gfile.listdir(
144 | os.path.join(raw_image_dir, split, video_name))
145 |
146 |     # If num_frames is larger than 3, we only save one example per video, so
147 |     # we only keep the first num_frames frames of that video.
148 | if num_frames > 3:
149 | frame_filenames = frame_filenames[:num_frames]
150 |
151 | all_images = collections.deque(
152 | maxlen=num_frames if num_frames > 0 else None)
153 | all_panoptic_maps = collections.deque(
154 | maxlen=num_frames if num_frames > 0 else None)
155 | for j, fn in enumerate(frame_filenames):
156 | logging.info('%s, %s', video_name, fn)
157 |
158 | # load the image.
159 | image = np.asarray(
160 | Image.open(
161 | tf.io.gfile.GFile(
162 | os.path.join(raw_image_dir, split, video_name, fn), 'rb')))
163 | all_images.append(image)
164 |
165 | # load and decode the panoptic map.
166 | panoptic_map = _decode_panoptic_map(
167 | os.path.join(raw_ann_dir, split, video_name, fn))
168 | all_panoptic_maps.append(panoptic_map)
169 |
170 | if j >= num_frames - 1 and num_frames > 0:
171 | serialized_example = generate_tf_sequence_example(
172 | all_images, all_panoptic_maps, fn, video_name)
173 | writers[k % shards].write(serialized_example)
174 | k += 1
175 |
176 | # Write all frames out in the same example.
177 | if num_frames <= 0:
178 | serialized_example = generate_tf_sequence_example(
179 | all_images, all_panoptic_maps, frame_filenames[-1], video_name)
180 | writers[k % shards].write(serialized_example)
181 | k += 1
182 |
183 | for writer in writers:
184 | writer.close()
185 |
186 |
187 | if __name__ == '__main__':
188 | app.run(main)
189 |
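# Illustrative sketch (not from this repository): inverting the encoding built
# in _decode_panoptic_map above. Each int32 panoptic id stores
# `semantic * _INSTANCE_LABEL_DIVISOR + instance`, where the instance id itself
# came from the PNG as `G * _ENCODED_INSTANCE_LABEL_DIVISOR + B`.
def split_panoptic_id(panoptic_id):
  semantic = panoptic_id // _INSTANCE_LABEL_DIVISOR  # semantic class id
  instance = panoptic_id % _INSTANCE_LABEL_DIVISOR   # per-class instance id
  return semantic, instance

# e.g. split_panoptic_id(13002) -> (13, 2).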
--------------------------------------------------------------------------------
/data/scripts/merge_coco_json_tfrecord.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Merge COCO annotation json file with tfrecord files.
17 |
18 | This is used to merge the generated object detection annotations with the
19 | groundtruth tfrecord files, to generate new tfrecord files that can be used
20 | for downstream task inference such as instance segmentation and keypoint
21 | detection.
22 |
23 | When merging, the image features (image/encoded, image/source_id, etc) are kept
24 | the same. Object features related to detection (bbox, is_crowd, label) are
25 | populated from the json annotation file. Downstream task features (
26 | segmentation, keypoint) are padded.
27 | """
28 |
29 | import collections
30 | import json
31 | import os
32 |
33 | from absl import app
34 | from absl import flags
35 | from absl import logging
36 | from data.scripts import tfrecord_lib
37 | import tensorflow as tf
38 |
39 | flags.DEFINE_string('tfrecord_path', '', 'Tfrecord file pattern.')
40 | flags.DEFINE_string('annotation_path', '', 'JSON annotation file path.')
41 | flags.DEFINE_string('output_dir', None, 'Output directory')
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | COPY_FEATURE_LIST = ['image/encoded', 'image/source_id', 'image/format',
47 | 'image/filename', 'image/height', 'image/width',
48 | 'image/key/sha256', 'image/caption']
49 |
50 |
51 | def load_instance_annotations(annotation_path):
52 | """Load instance annotations.
53 |
54 | Args:
55 | annotation_path: str. Path to the annotation file.
56 |
57 | Returns:
58 | category_id_to_name_map: dict of category ids to category names.
59 | img_to_ann: a dict of image_id to annotation.
60 | """
61 | with tf.io.gfile.GFile(annotation_path, 'r') as f:
62 | annotations = json.load(f)
63 |
64 | img_to_ann = collections.defaultdict(list)
65 | for ann in annotations['annotations']:
66 | image_id = ann['image_id']
67 | img_to_ann[image_id].append(ann)
68 |
69 | category_id_to_name_map = dict(
70 | (element['id'], element['name']) for element in annotations['categories'])
71 |
72 | return category_id_to_name_map, img_to_ann
73 |
74 |
75 | def coco_annotations_to_lists(obj_annotations, id_to_name_map):
76 | """Converts COCO annotations to feature lists.
77 |
78 | Args:
79 | obj_annotations: a list of object annotations.
80 | id_to_name_map: category id to category name map.
81 |
82 | Returns:
83 | a dict of list features.
84 | """
85 |
86 | data = dict((k, list()) for k in [
87 | 'xmin', 'xmax', 'ymin', 'ymax', 'is_crowd', 'category_id',
88 | 'category_names', 'area', 'score'])
89 |
90 | for ann in obj_annotations:
91 | (x, y, width, height) = tuple(ann['bbox'])
92 | if width > 0. and height > 0.: # Only keep valid boxes.
93 | data['xmin'].append(float(x))
94 | data['xmax'].append(float(x + width))
95 | data['ymin'].append(float(y))
96 | data['ymax'].append(float(y + height))
97 | data['is_crowd'].append(ann['iscrowd'])
98 | category_id = int(ann['category_id'])
99 | data['category_id'].append(category_id)
100 | data['category_names'].append(id_to_name_map[category_id].encode('utf8'))
101 | data['area'].append(float(height * width))
102 | data['score'].append(ann['score'])
103 |
104 | return data
105 |
106 |
107 | def obj_annotations_to_feature_dict(obj_annotations, id_to_name_map):
108 | """Convert COCO annotations to an encoded feature dict.
109 |
110 | Args:
111 | obj_annotations: a list of object annotations.
112 | id_to_name_map: category id to category name map.
113 |
114 | Returns:
115 | a dict of tf features, and the number of instances.
116 | """
117 |
118 | data = coco_annotations_to_lists(obj_annotations, id_to_name_map)
119 | feature_dict = {
120 | 'image/object/bbox/xmin':
121 | tfrecord_lib.convert_to_feature(
122 | data['xmin'], value_type='float_list'),
123 | 'image/object/bbox/xmax':
124 | tfrecord_lib.convert_to_feature(
125 | data['xmax'], value_type='float_list'),
126 | 'image/object/bbox/ymin':
127 | tfrecord_lib.convert_to_feature(
128 | data['ymin'], value_type='float_list'),
129 | 'image/object/bbox/ymax':
130 | tfrecord_lib.convert_to_feature(
131 | data['ymax'], value_type='float_list'),
132 | 'image/object/class/text':
133 | tfrecord_lib.convert_to_feature(
134 | data['category_names'], value_type='bytes_list'),
135 | 'image/object/class/label':
136 | tfrecord_lib.convert_to_feature(
137 | data['category_id'], value_type='int64_list'),
138 | 'image/object/is_crowd':
139 | tfrecord_lib.convert_to_feature(
140 | data['is_crowd'], value_type='int64_list'),
141 | 'image/object/area':
142 | tfrecord_lib.convert_to_feature(
143 | data['area'], value_type='float_list'),
144 | 'image/object/score':
145 | tfrecord_lib.convert_to_feature(
146 | data['score'], value_type='float_list'),
147 | }
148 | return feature_dict, len(data['xmin'])
149 |
150 |
151 | def update_tfrecord_file(tfrecord_path, image_to_anns, category_id_to_name_map,
152 | output_path):
153 | """Merge one tfrecord file with annotations.
154 |
155 | Args:
156 | tfrecord_path: string, the input tfrecord path.
157 | image_to_anns: a dict of image_id to annotation.
158 | category_id_to_name_map: dict of category ids to category names.
159 | output_path: string, the output tfrecord file path.
160 | """
161 | dataset = tf.data.TFRecordDataset(tfrecord_path)
162 |
163 | with tf.io.TFRecordWriter(output_path) as writer:
164 | for serialized_ex in dataset:
165 | ex = tf.train.Example()
166 | ex.ParseFromString(serialized_ex.numpy())
167 |
168 | image_id = int(ex.features.feature['image/source_id'].bytes_list.value[0])
169 | anns = image_to_anns[image_id]
170 |
171 | # Copy the following features from current tf example.
172 | feature_dict = {}
173 | for f in COPY_FEATURE_LIST:
174 | feature_dict[f] = ex.features.feature[f]
175 |
176 | # Populate the object detection features from json annotations.
177 | det_features, num_bbox = obj_annotations_to_feature_dict(
178 | anns, category_id_to_name_map)
179 | feature_dict.update(det_features)
180 |
181 | # Pad the segmentation and keypoint features.
182 | feature_dict.update({
183 | 'image/object/segmentation_v':
184 | tfrecord_lib.convert_to_feature([], value_type='float_list'),
185 | 'image/object/segmentation_sep':
186 | tfrecord_lib.convert_to_feature(
187 | [0] * (num_bbox + 1), value_type='int64_list'),
188 | 'image/object/keypoints_v':
189 | tfrecord_lib.convert_to_feature([], value_type='float_list'),
190 | 'image/object/keypoints_sep':
191 | tfrecord_lib.convert_to_feature(
192 | [0] * (num_bbox + 1), value_type='int64_list'),
193 | 'image/object/num_keypoints':
194 | tfrecord_lib.convert_to_feature(
195 | [0] * num_bbox, value_type='int64_list'),
196 | })
197 |
198 | new_ex = tf.train.Example(
199 | features=tf.train.Features(feature=feature_dict)).SerializeToString()
200 | writer.write(new_ex)
201 |
202 |
203 | def main(unused_argv):
204 | category_id_to_name_map, image_to_anns = load_instance_annotations(
205 | FLAGS.annotation_path)
206 |
207 | tfrecord_lib.check_and_make_dir(FLAGS.output_dir)
208 |
209 | tfrecord_paths = tf.io.gfile.glob(FLAGS.tfrecord_path)
210 | for tfrecord_path in tfrecord_paths:
211 | output_path = os.path.join(FLAGS.output_dir,
212 | os.path.basename(tfrecord_path))
213 | update_tfrecord_file(tfrecord_path, image_to_anns, category_id_to_name_map,
214 | output_path)
215 | logging.info('Finished writing file %s', output_path)
216 |
217 |
218 | if __name__ == '__main__':
219 | app.run(main)
220 |
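# Illustrative example invocation (paths below are hypothetical):
#   python3 data/scripts/merge_coco_json_tfrecord.py \
#     --tfrecord_path=/tmp/coco/tfrecord/val* \
#     --annotation_path=/tmp/predictions/instances_val2017_pred.json \
#     --output_dir=/tmp/coco/tfrecord_merged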
--------------------------------------------------------------------------------
/data/scripts/tfrecord_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Helper functions for creating TFRecord datasets."""
17 |
18 | import hashlib
19 | import io
20 | import itertools
21 | import multiprocessing
22 |
23 | from absl import logging
24 | import numpy as np
25 | from PIL import Image
26 | import tensorflow as tf
27 |
28 |
29 |
30 | LOG_EVERY = 100
31 |
32 |
33 | def convert_to_feature(value, value_type=None):
34 | """Converts the given python object to a tf.train.Feature.
35 |
36 | Args:
37 | value: int, float, bytes or a list of them.
38 | value_type: optional, if specified, forces the feature to be of the given
39 | type. Otherwise, type is inferred automatically. Can be one of
40 | ['bytes', 'int64', 'float', 'bytes_list', 'int64_list', 'float_list']
41 |
42 | Returns:
43 | feature: A tf.train.Feature object.
44 | """
45 |
46 | if value_type is None:
47 |
48 | element = value[0] if isinstance(value, list) else value
49 |
50 | if isinstance(element, bytes):
51 | value_type = 'bytes'
52 |
53 | elif isinstance(element, (int, np.integer)):
54 | value_type = 'int64'
55 |
56 | elif isinstance(element, (float, np.floating)):
57 | value_type = 'float'
58 |
59 | else:
60 | raise ValueError('Cannot convert type {} to feature'.
61 | format(type(element)))
62 |
63 | if isinstance(value, list):
64 | value_type = value_type + '_list'
65 |
66 | if value_type == 'int64':
67 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
68 |
69 | elif value_type == 'int64_list':
70 | value = np.asarray(value).astype(np.int64).reshape(-1)
71 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
72 |
73 | elif value_type == 'float':
74 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
75 |
76 | elif value_type == 'float_list':
77 | value = np.asarray(value).astype(np.float32).reshape(-1)
78 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
79 |
80 | elif value_type == 'bytes':
81 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
82 |
83 | elif value_type == 'bytes_list':
84 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
85 |
86 | else:
87 | raise ValueError('Unknown value_type parameter - {}'.format(value_type))
88 |
89 |
90 | def image_info_to_feature_dict(height, width, filename, image_id,
91 | encoded_str, encoded_format):
92 | """Convert image information to a dict of features."""
93 |
94 | key = hashlib.sha256(encoded_str).hexdigest()
95 |
96 | return {
97 | 'image/height': convert_to_feature(height),
98 | 'image/width': convert_to_feature(width),
99 | 'image/filename': convert_to_feature(filename.encode('utf8')),
100 | 'image/source_id': convert_to_feature(str(image_id).encode('utf8')),
101 | 'image/key/sha256': convert_to_feature(key.encode('utf8')),
102 | 'image/encoded': convert_to_feature(encoded_str),
103 | 'image/format': convert_to_feature(encoded_format.encode('utf8')),
104 | }
105 |
106 |
107 | def read_image(image_path):
108 | pil_image = Image.open(image_path)
109 | return np.asarray(pil_image)
110 |
111 |
112 | def encode_mask_as_png(mask):
113 | pil_image = Image.fromarray(mask)
114 | output_io = io.BytesIO()
115 | pil_image.save(output_io, format='PNG')
116 | return output_io.getvalue()
117 |
118 |
119 | def write_tf_record_dataset(output_path, annotation_iterator,
120 | process_func, num_shards,
121 | multiple_processes=None, unpack_arguments=True):
122 | """Iterates over annotations, processes them and writes into TFRecords.
123 |
124 | Args:
125 | output_path: The prefix path to create TF record files.
126 | annotation_iterator: An iterator of tuples containing details about the
127 | dataset.
128 | process_func: A function which takes the elements from the tuples of
129 | annotation_iterator as arguments and returns a tuple of (tf.train.Example,
130 | int). The integer indicates the number of annotations that were skipped.
131 | num_shards: int, the number of shards to write for the dataset.
132 |     multiple_processes: integer, the number of parallel processes to use for
133 |       writing TF Records. If None, uses multi-processing with the number of
134 |       processes equal to `os.cpu_count()`, which is Python's default
135 |       behavior. If set to 0, multi-processing is disabled and records are
136 |       written sequentially in the current process.
137 |     unpack_arguments:
138 |       Whether to unpack the tuples from annotation_iterator as individual
139 |       arguments to process_func, or to pass each tuple as a single argument.
140 |
141 | Returns:
142 | num_skipped: The total number of skipped annotations.
143 | """
144 |
145 | writers = [
146 | tf.io.TFRecordWriter(
147 | output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))
148 | for i in range(num_shards)
149 | ]
150 |
151 | total_num_annotations_skipped = 0
152 |
153 | if multiple_processes is None or multiple_processes > 0:
154 |
155 | pool = multiprocessing.Pool(processes=multiple_processes)
156 | if unpack_arguments:
157 | tf_example_iterator = pool.starmap(process_func, annotation_iterator)
158 | else:
159 | tf_example_iterator = pool.imap(process_func, annotation_iterator)
160 | else:
161 | if unpack_arguments:
162 | tf_example_iterator = itertools.starmap(process_func, annotation_iterator)
163 | else:
164 | tf_example_iterator = map(process_func, annotation_iterator)
165 |
166 | for idx, (tf_example, num_annotations_skipped) in enumerate(
167 | tf_example_iterator):
168 | if idx % LOG_EVERY == 0:
169 | logging.info('On image %d', idx)
170 |
171 | total_num_annotations_skipped += num_annotations_skipped
172 | writers[idx % num_shards].write(tf_example.SerializeToString())
173 |
174 | if multiple_processes is None or multiple_processes > 0:
175 | pool.close()
176 | pool.join()
177 |
178 | for writer in writers:
179 | writer.close()
180 |
181 | logging.info('Finished writing, skipped %d annotations.',
182 | total_num_annotations_skipped)
183 | return total_num_annotations_skipped
184 |
185 |
186 | def check_and_make_dir(directory):
187 | """Creates the directory if it doesn't exist."""
188 | if not tf.io.gfile.isdir(directory):
189 | tf.io.gfile.makedirs(directory)
190 |
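# Illustrative sketch (not from this repository): a minimal process_func and a
# call to write_tf_record_dataset. The annotation tuples, output path, and
# function names are hypothetical.
def _example_process_func(image_path, image_id):
  with tf.io.gfile.GFile(image_path, 'rb') as f:
    encoded = f.read()
  height, width = read_image(image_path).shape[:2]
  feature_dict = image_info_to_feature_dict(
      height, width, image_path, image_id, encoded, 'jpeg')
  example = tf.train.Example(
      features=tf.train.Features(feature=feature_dict))
  return example, 0  # (tf.train.Example, number of skipped annotations)

def _example_write(annotations):
  # `annotations` is an iterable of (image_path, image_id) tuples.
  return write_tf_record_dataset(
      output_path='/tmp/example_dataset/train',
      annotation_iterator=annotations,
      process_func=_example_process_func,
      num_shards=4,
      multiple_processes=0)  # disable multi-processing for this sketch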
--------------------------------------------------------------------------------
/data/text.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Text datasets."""
17 |
18 | from data import dataset as dataset_lib
19 |
20 |
21 | @dataset_lib.DatasetRegistry.register('text')
22 | class TextDataset(dataset_lib.TFDSDataset):
23 | """Dataset."""
24 |
25 | def extract(self, example, training):
26 | """Extracts needed features & annotations into a flat dictionary.
27 |
28 | Args:
29 | example: `dict` of raw features.
30 | training: `bool` of training vs eval mode.
31 |
32 | Returns:
33 |       example: `dict` containing the text sequence under the 'text' key.
34 | """
35 | if self.config.tfds_name.startswith('wikipedia'):
36 | text = example['title'] + '\n\n' + example['text']
37 | elif self.config.tfds_name.startswith('wmt'):
38 | src, dst = self.config.tfds_name.split('/')[1].split('-')
39 | if training:
40 | text = '[src] ' + example[src] + ' [dst] ' + example[dst]
41 | else:
42 | text = '[src] ' + example[src] + ' [dst] '
43 | else:
44 | text = example['text']
45 | return {'text': text}
46 |
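# Illustrative example (not from this repository): for a hypothetical tfds_name
# 'wmt14_translate/de-en', a training example is formatted as
#   '[src] <German sentence> [dst] <English sentence>'
# while at eval time the text ends after '[dst] ' so the target can be decoded.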
--------------------------------------------------------------------------------
/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tokenizer library."""
17 |
18 | import abc
19 |
20 | import tensorflow as tf
21 | import tensorflow_text as tf_text
22 |
23 |
24 | class Tokenizer(abc.ABC):
25 | """Tokenizer base class."""
26 |
27 | def __init__(self):
28 | pass
29 |
30 | @property
31 | @abc.abstractmethod
32 | def vocab_size(self):
33 | """Vocab size."""
34 |
35 | @abc.abstractmethod
36 | def string_to_ids(self, string):
37 | """Tokenize a single string."""
38 |
39 | @abc.abstractmethod
40 | def strings_to_ids(self, strings):
41 | """Tokenize a batch of strings."""
42 |
43 | @abc.abstractmethod
44 | def ids_to_strings(self, ids, ids_len):
45 | """Detokenize a batch of ids."""
46 |
47 |
48 | class SPTokenizer(Tokenizer):
49 | """Sentence Piece Tokenizer."""
50 |
51 | def __init__(self, model_path, add_bos=False, add_eos=False):
52 | super(SPTokenizer, self).__init__()
53 | self.model_path = model_path
54 | with tf.io.gfile.GFile(model_path, "rb") as f:
55 | model = f.read()
56 | self.tokenizer = tf_text.SentencepieceTokenizer(model,
57 | out_type=tf.string,
58 | add_bos=add_bos,
59 | add_eos=add_eos)
60 |
61 | @property
62 | def vocab_size(self):
63 | return int(self.tokenizer.vocab_size().numpy())
64 |
65 | def string_to_ids(self, string):
66 | tokens = self.tokenizer.tokenize(string)
67 | pieces = self.tokenizer.string_to_id(tokens)
68 | return tf.cast(pieces, tf.int64)
69 |
70 | def strings_to_ids(self, strings):
71 | return self.string_to_ids(strings)
72 |
73 | def ids_to_strings(self, ids, ids_len):
74 | return self.tokenizer.detokenize(ids)
75 |
--------------------------------------------------------------------------------
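A minimal usage sketch for SPTokenizer, assuming a trained SentencePiece model file exists at the hypothetical path below; token ids and vocab size depend entirely on that model.

import tensorflow as tf
from data import tokenizer as tokenizer_lib

sp = tokenizer_lib.SPTokenizer('/tmp/spm.model', add_bos=True, add_eos=True)
ids = sp.string_to_ids('a cat sitting on a mat')       # 1-D int64 id tensor
batch_ids = ids[tf.newaxis, :]                         # batch of one sequence
text = sp.ids_to_strings(batch_ids, tf.shape(batch_ids)[1])
print(int(sp.vocab_size), text.numpy())
--------------------------------------------------------------------------------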
/metrics/fid.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Evaluation metrics for FID score."""
17 |
18 | import dataclasses
19 |
20 | from absl import logging
21 | import numpy as np
22 | import tensorflow as tf
23 | import tensorflow_gan as tfgan
24 |
25 |
26 | def get_stats_for_fid(act):
27 | """Get mean and std statistics from activations for FID computation."""
28 | if act.ndim != 2:
29 | raise ValueError("Expected input to have 2 axes")
30 | act = np.asarray(act, dtype=np.float64)
31 | mu = np.mean(act, axis=0)
32 | sigma = np.cov(act, rowvar=False)
33 | return mu, sigma
34 |
35 |
36 | def _symmetric_matrix_square_root(mat, eps=1e-10):
37 | """Compute square root of a symmetric matrix."""
38 | u, s, vt = np.linalg.svd(mat, hermitian=True)
39 | si = np.where(s < eps, s, np.sqrt(s))
40 | return u.dot(np.diag(si)).dot(vt)
41 |
42 |
43 | def _trace_sqrt_product(sigma, sigma_v):
44 | """Find the trace of the positive sqrt of product of covariance matrices."""
45 | sqrt_sigma = _symmetric_matrix_square_root(sigma)
46 | sqrt_a_sigmav_a = sqrt_sigma.dot(sigma_v).dot(sqrt_sigma)
47 | return _symmetric_matrix_square_root(sqrt_a_sigmav_a).trace()
48 |
49 |
50 | def get_fid_score(mu1, sigma1, mu2, sigma2):
51 | """FID score."""
52 | if mu1.shape != mu2.shape:
53 | raise ValueError("means should have the same shape")
54 | dim, = mu1.shape
55 | if not sigma1.shape == sigma2.shape == (dim, dim):
56 | raise ValueError("covariance matrices should be the same shape (d, d)")
57 | mu1 = np.asarray(mu1, dtype=np.float64)
58 | mu2 = np.asarray(mu2, dtype=np.float64)
59 | sigma1 = np.asarray(sigma1, dtype=np.float64)
60 | sigma2 = np.asarray(sigma2, dtype=np.float64)
61 | return (np.square(mu1 - mu2).sum() + sigma1.trace() + sigma2.trace() -
62 | 2 * _trace_sqrt_product(sigma1, sigma2))
63 |
64 |
65 | @dataclasses.dataclass
66 | class TFGANMetricEvaluator:
67 | """A wrappner class for tensorflow-gan evaluation."""
68 | dataset_name: str = "cifar10"
69 | image_size: int = -1
70 | inceptionv3_input_size: int = 299
71 | activations_key: str = "pool_3"
72 | resize_method: str = "bilinear"
73 | antialias: bool = False
74 |
75 | def __post_init__(self):
76 | self.all_logits_real = []
77 | self.all_pool3_real = []
78 | self.all_logits_gen = []
79 | self.all_pool3_gen = []
80 | self.dataset_stats_mean, self.dataset_stats_cov = self.load_fid_stats()
81 |
82 | def load_fid_stats(self, stats_path=None):
83 | """Load the pre-computed dataset statistics."""
84 | logging.info("loading FID stats for datasets %s", self.dataset_name)
85 | # TODO(iamtingchen): provide stat path via config dict.
86 | if self.dataset_name == "cifar10":
87 | filename = "{}/cifar10_stats_real.npy".format(stats_path)
88 | elif self.dataset_name == "downsampled_imagenet/64x64":
89 | filename = "{}/imagenet64_stats_real.npz".format(stats_path)
90 | with tf.io.gfile.GFile(filename, "rb") as fin:
91 | stats_real = np.load(fin)
92 | return stats_real["mu"], stats_real["sigma"]
93 | elif self.dataset_name == "imagenet2012":
94 | assert self.image_size in [
95 | 32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048]
96 | filename = "{}/imagenet_man_{}_stats_real.npz".format(
97 | stats_path, self.image_size)
98 | with tf.io.gfile.GFile(filename, "rb") as fin:
99 | stats_real = np.load(fin)
100 | return stats_real["mu"], stats_real["cov"]
101 | elif self.dataset_name == "coco":
102 | filename = "{}/coco_stats_real.npz".format(stats_path)
103 | with tf.io.gfile.GFile(filename, "rb") as fin:
104 | stats_real = np.load(fin)
105 | return stats_real["mu"], stats_real["cov"]
106 | else:
107 |       logging.warning("Dataset %s stats not found!", self.dataset_name)
108 | return None, None
109 |
110 | with tf.io.gfile.GFile(filename, "rb") as fin:
111 | stats_real = np.load(fin)
112 | logging.info("FID stats loading done! Number of examples %d",
113 | stats_real.shape[0])
114 | return get_stats_for_fid(stats_real)
115 |
116 | def preprocess_inputs(self, inputs, is_n1p1=False):
117 | """Resize images and shift/clip pixels to [-1, 1]."""
118 | if isinstance(inputs, list):
119 | all_inputs = tf.concat(inputs, 0)
120 | all_inputs = self.preprocess_inputs(all_inputs, is_n1p1=is_n1p1)
121 | return tf.split(all_inputs, len(inputs))
122 | if is_n1p1:
123 | inputs = tf.clip_by_value(inputs, -1.0, 1.0)
124 | inputs = (inputs + 1.0) / 2.0
125 |
126 | inputs = tf.image.resize(
127 | inputs, [self.inceptionv3_input_size, self.inceptionv3_input_size],
128 | self.resize_method,
129 | antialias=self.antialias)
130 | inputs = tf.clip_by_value(inputs, 0.0, 1.0)
131 | # transform inputs to [-1, 1]
132 | inputs = inputs * 2 - 1.0
133 | return inputs
134 |
135 | def get_inception_stats(self, inputs):
136 | if isinstance(inputs, list):
137 | return [self.get_inception_stats(x) for x in inputs]
138 | stats = tfgan.eval.run_inception(inputs)
139 | return stats["logits"], stats["pool_3"]
140 |
141 | def update_stats(self, logits_real, pool3_real, logits_gen, pool3_gen):
142 | self.all_logits_real.append(logits_real)
143 | self.all_pool3_real.append(pool3_real)
144 | self.all_logits_gen.append(logits_gen)
145 | self.all_pool3_gen.append(pool3_gen)
146 |
147 | def reset(self):
148 | self.all_logits_real.clear()
149 | self.all_pool3_real.clear()
150 | self.all_logits_gen.clear()
151 | self.all_pool3_gen.clear()
152 | return
153 |
154 | def compute_fid_score(self):
155 | """Return a dict of metrics."""
156 | metrics = {}
157 | logging.info("Computing Inception score.")
158 | all_logits_gen = np.concatenate(self.all_logits_gen, axis=0)
159 | logging.info("IS number of gen samples: %d, number of classes: %d",
160 | all_logits_gen.shape[0], all_logits_gen.shape[1])
161 | is_score = tfgan.eval.classifier_score_from_logits(all_logits_gen)
162 | metrics.update({"inception_score": is_score})
163 | logging.info("Computing FID score.")
164 | all_stats_real = np.concatenate(self.all_pool3_real, axis=0)
165 | all_stats_gen = np.concatenate(self.all_pool3_gen, axis=0)
166 | logging.info("FID number of real samples: %d", all_stats_real.shape[0])
167 | logging.info("FID number of generated samples: %d", all_stats_gen.shape[0])
168 | gen_mean, gen_cov = get_stats_for_fid(all_stats_gen)
169 | ref_mean, ref_cov = get_stats_for_fid(all_stats_real)
170 | metrics.update({
171 | "fid_batch": get_fid_score(gen_mean, gen_cov, ref_mean, ref_cov),
172 | })
173 | if self.dataset_stats_mean is not None:
174 | metrics.update({
175 | "fid_full":
176 | get_fid_score(gen_mean, gen_cov, self.dataset_stats_mean,
177 | self.dataset_stats_cov),
178 | "fid_batch_vs_full":
179 | get_fid_score(ref_mean, ref_cov, self.dataset_stats_mean,
180 | self.dataset_stats_cov),
181 | })
182 | return metrics
183 |
--------------------------------------------------------------------------------
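A self-contained sketch of the FID plumbing above, using random arrays in place of real Inception pool_3 activations; the resulting number is meaningless and only illustrates how get_stats_for_fid and get_fid_score fit together.

import numpy as np
from metrics.fid import get_fid_score, get_stats_for_fid

rng = np.random.default_rng(0)
act_real = rng.normal(size=(500, 64))            # stand-in for real features
act_gen = rng.normal(loc=0.1, size=(500, 64))    # stand-in for generated ones

mu_r, sigma_r = get_stats_for_fid(act_real)
mu_g, sigma_g = get_stats_for_fid(act_gen)
print('FID:', get_fid_score(mu_g, sigma_g, mu_r, sigma_r))
--------------------------------------------------------------------------------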
/metrics/fvd.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Evaluation metrics for FVD/IS scores with I3D / C3D nets."""
17 |
18 | import dataclasses
19 | from functools import partial
20 |
21 | from absl import logging
22 | from jax.experimental import jax2tf
23 | import numpy as np
24 | from metrics.fid import get_fid_score
25 | from metrics.fid import get_stats_for_fid
26 | from metrics.fid import TFGANMetricEvaluator
27 |
28 | import tensorflow as tf
29 | import tensorflow_gan as tfgan
30 | from universal_diffusion.metrics import c3d
31 | from universal_diffusion.metrics import i3d
32 |
33 |
34 | @dataclasses.dataclass
35 | class FVDMetricEvaluator(TFGANMetricEvaluator):
36 | """A wrapper class for tensorflow-gan evaluation extended for FVD."""
37 | dataset_name: str
38 | image_size: int = -1
39 | activations_key: str = 'pool_3'
40 |
41 | def __post_init__(self):
42 | self.all_logits_real = []
43 | self.all_pool3_real = []
44 | self.all_logits_gen = []
45 | self.all_pool3_gen = []
46 | self.dataset_stats_mean, self.dataset_stats_cov = self.load_fid_stats()
47 |
48 | if self.dataset_name == 'ucf101':
49 | self.model = jax2tf.convert(partial(c3d.run_model, c3d.load_params()))
50 | elif self.dataset_name == 'kinetics600':
51 | self.model = jax2tf.convert(partial(i3d.run_model, i3d.load_params()))
52 | else:
53 | assert False, 'dataset not supported for FVD'
54 |
55 | def load_fid_stats(self, stat_path=None):
56 | """Load the pre-computed dataset statistics."""
57 | logging.info('loading FID stats for datasets %s', self.dataset_name)
58 | if self.dataset_name in {'ucf101', 'kinetics600'} and False:
59 | assert self.image_size in [64, 128]
60 | filename = '{}/{}_{}_stats_real.npz'.format(
61 | stat_path, self.dataset_name, self.image_size)
62 | with tf.io.gfile.GFile(filename, 'rb') as fin:
63 | stats_real = np.load(fin)
64 | logging.info('FID stats loading done! Number of examples %d',
65 | stats_real['mu'].shape[0])
66 | return stats_real['mu'], stats_real['cov']
67 | else:
68 |       logging.warning('Dataset %s stats not found!', self.dataset_name)
69 | return None, None
70 |
71 | def preprocess_inputs(self, inputs, is_n1p1=False):
72 | if isinstance(inputs, list):
73 | all_inputs = tf.concat(inputs, 0)
74 | all_inputs = self.preprocess_inputs(all_inputs, is_n1p1=is_n1p1)
75 | return tf.split(all_inputs, len(inputs))
76 | if is_n1p1:
77 | inputs = tf.clip_by_value(inputs, -1.0, 1.0)
78 | inputs = (inputs + 1.0) / 2.0
79 | inputs = inputs * 255.0
80 | return inputs
81 |
82 | def get_inception_stats(self, inputs):
83 | if isinstance(inputs, list):
84 | return [self.get_inception_stats(x) for x in inputs]
85 | stats = self.model(inputs)
86 | if 'features' in stats:
87 | return stats['logits'], stats['features']
88 | else:
89 | return stats['logits_mean'], stats['pool']
90 |
91 | def compute_fid_score(self):
92 | """Return a dict of metrics."""
93 | metrics = {}
94 | logging.info('Computing Inception score.')
95 | all_logits_gen = np.concatenate(self.all_logits_gen, axis=0)
96 | all_logits_real = np.concatenate(self.all_logits_real, axis=0)
97 | logging.info('IS number of gen samples: %d, number of classes: %d',
98 | all_logits_gen.shape[0], all_logits_gen.shape[1])
99 | is_score = tfgan.eval.classifier_score_from_logits(all_logits_gen)
100 | metrics.update({'inception_score': is_score})
101 |
102 | logging.info('Computing FVD score.')
103 | all_stats_real = np.concatenate(self.all_pool3_real, axis=0)
104 | all_stats_gen = np.concatenate(self.all_pool3_gen, axis=0)
105 | logging.info('FVD number of real samples: %d', all_stats_real.shape[0])
106 | logging.info('FVD number of generated samples: %d', all_stats_gen.shape[0])
107 | gen_mean, gen_cov = get_stats_for_fid(all_stats_gen)
108 | ref_mean, ref_cov = get_stats_for_fid(all_stats_real)
109 |
110 | gen_logits_mean, gen_logits_cov = get_stats_for_fid(all_logits_gen)
111 | ref_logits_mean, ref_logits_cov = get_stats_for_fid(all_logits_real)
112 |
113 | metrics.update({
114 | 'fvd_pool_batch': get_fid_score(gen_mean, gen_cov, ref_mean, ref_cov),
115 | 'fvd_batch': get_fid_score(gen_logits_mean, gen_logits_cov,
116 | ref_logits_mean, ref_logits_cov),
117 | })
118 | if self.dataset_stats_mean is not None:
119 | metrics.update({
120 | 'fvd_pool_full':
121 | get_fid_score(gen_mean, gen_cov, self.dataset_stats_mean,
122 | self.dataset_stats_cov),
123 | 'fvd_pool_batch_vs_full':
124 | get_fid_score(ref_mean, ref_cov, self.dataset_stats_mean,
125 | self.dataset_stats_cov),
126 | })
127 | return metrics
128 |
--------------------------------------------------------------------------------
/metrics/metric_registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import registry
17 |
18 | MetricRegistry = registry.Registry()
--------------------------------------------------------------------------------
/metrics/metric_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | def yxyx_to_xywh(box):
17 | ymin, xmin, ymax, xmax = box
18 | w = xmax - xmin
19 | h = ymax - ymin
20 | return [xmin, ymin, w, h]
21 |
--------------------------------------------------------------------------------
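A quick example of the box conversion above: COCO-style evaluation expects [x, y, width, height], while boxes elsewhere in the codebase are stored as [ymin, xmin, ymax, xmax].

from metrics.metric_utils import yxyx_to_xywh

# [ymin, xmin, ymax, xmax] -> [xmin, ymin, width, height]
print(yxyx_to_xywh([10.0, 20.0, 110.0, 70.0]))  # [20.0, 10.0, 50.0, 100.0]
--------------------------------------------------------------------------------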
/metrics/text_metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Metrics for text tasks."""
17 |
18 | from metrics import metric_registry
19 | import sacrebleu
20 |
21 |
22 | @metric_registry.MetricRegistry.register('text_sacrebleu')
23 | class BleuMetric():
24 | """BLEU metric for text."""
25 |
26 | def __init__(self, config):
27 | self._config = config
28 | self.reset_states()
29 |
30 | def reset_states(self):
31 | self._metric_values = None
32 | self._targets = []
33 | self._predictions = []
34 |
35 | def record_prediction(self, predictions, targets):
36 | """Records predictions.
37 |
38 |     If multiple references are present, then each example needs to have the same
39 | number of references.
40 |
41 | Args:
42 | predictions: list of strings. Has len batch_size.
43 | targets: list of strings, or list of list of strings if multiple
44 | references are present. Has len batch_size. In the format of
45 | [ex1_ref, ex2_ref, ...] or
46 | [[ex1_ref1, ex2_ref1, ...], [ex1_ref2, ex2_ref2, ...], ...].
47 | """
48 | self._predictions.extend(predictions)
49 |
50 | # Turn targets into lists.
51 | if not isinstance(targets[0], list):
52 | targets = [targets]
53 | if self._targets:
54 | assert len(self._targets) == len(targets)
55 | for i in range(len(targets)):
56 | self._targets[i].extend(targets[i])
57 | else:
58 | self._targets = targets
59 |
60 | def _evaluate(self):
61 | """Evaluates with predictions for all examples.
62 |
63 | Call this function from `self.result`.
64 |
65 | Returns:
66 | dict from metric name to float value.
67 | """
68 | tokenizer = self._config.get('tokenizer', 'intl')
69 | bleu_score = sacrebleu.corpus_bleu(self._predictions, self._targets,
70 | smooth_method='exp',
71 | smooth_value=0.0,
72 | force=False,
73 | lowercase=False,
74 | tokenize=tokenizer,
75 | use_effective_order=False)
76 | return {'bleu': bleu_score.score}
77 |
78 | def result(self):
79 | """Return the metric values (and compute it if needed)."""
80 | if self._metric_values is None:
81 | self._metric_values = self._evaluate()
82 | return self._metric_values
83 |
--------------------------------------------------------------------------------
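A minimal sketch of driving BleuMetric directly with a single reference per example, assuming an ml_collections.ConfigDict that may carry a 'tokenizer' entry; the hypothesis and reference strings are made up.

import ml_collections
from metrics.text_metrics import BleuMetric

metric = BleuMetric(ml_collections.ConfigDict({'tokenizer': 'intl'}))
metric.record_prediction(
    predictions=['the cat sat on the mat'],
    targets=['a cat sat on the mat'])
print(metric.result())  # {'bleu': ...}
--------------------------------------------------------------------------------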
/metrics/vos_metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Video object segmentation metrics."""
17 |
18 | import io
19 | import os
20 | import tempfile
21 |
22 | from absl import logging
23 | from davis2017.evaluation import DAVISEvaluation
24 | import numpy as np
25 | import pandas as pd
26 | import PIL
27 | import utils
28 | from metrics import metric_registry
29 | import tensorflow as tf
30 |
31 | _PALETTE = [
32 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128,
33 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0,
34 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128
35 | ]
36 |
37 |
38 | @metric_registry.MetricRegistry.register('davis_video_object_segmentation')
39 | class DavisVideoObjectSegmentationMetric():
40 | """Video object segmentation metric for DAVIS."""
41 |
42 | def __init__(self, config):
43 | self.config = config
44 | self.metric_names = [
45 | 'J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall',
46 | 'F-Decay'
47 | ]
48 | self.results_dir = config.task.metric.get('results_dir')
49 | self.dataset_eval = DAVISEvaluation(
50 | davis_root=config.dataset.annotations_dir,
51 | task=config.dataset.vos_task,
52 | gt_set=('train' if config.dataset.eval_split == 'train' else 'val'),
53 | use_tfds=False)
54 | self.reset_states()
55 |
56 | def reset_states(self):
57 | self.metric_values = None
58 | self._local_pred_dir_obj = tempfile.TemporaryDirectory()
59 |
60 | def record_prediction(self, predictions, video_name, frame_ids, step):
61 | """Records predictions.
62 |
63 | Args:
64 | predictions: uint8 of shape (num_frames, h, w, channel), where channel
65 | could be 1 or >1, but only the last channel is used as instance id.
66 | video_name: str. Video name.
67 | frame_ids: list of int, or 1-d np.array. Frame ids of predictions.
68 | step: int. The checkpoint step, used to name sub-directories.
69 | """
70 | predictions = predictions[..., -1] # Last channel is instance id.
71 | subdir = os.path.join(self._local_pred_dir_obj.name, str(step), video_name)
72 | if not tf.io.gfile.exists(subdir):
73 | tf.io.gfile.makedirs(subdir)
74 |
75 | for frame_id in frame_ids:
76 | filename = f'{frame_id:05}.png'
77 | filepath = os.path.join(subdir, filename)
78 | pred_image = PIL.Image.fromarray(predictions[frame_id], mode='L')
79 | pred_image.putpalette(_PALETTE)
80 | with io.BytesIO() as out:
81 | pred_image.save(out, format='PNG')
82 | with tf.io.gfile.GFile(filepath, 'wb') as f:
83 | f.write(out.getvalue())
84 | logging.info('Done writing out pngs for %s', video_name)
85 |
86 | if self.results_dir is not None:
87 | # Copy images to results dir.
88 | results_dir = os.path.join(self.results_dir, str(step), video_name)
89 | utils.copy_dir(subdir, results_dir)
90 |
91 | def _evaluate(self, step):
92 | """Evaluates with predictions for all images.
93 |
94 | Call this function from `self.result`.
95 |
96 | Args:
97 | step: int. The checkpoint step being evaluated.
98 |
99 | Returns:
100 | dict from metric name to float value.
101 | """
102 | result_path = os.path.join(self._local_pred_dir_obj.name, str(step))
103 | metrics_res = self.dataset_eval.evaluate(result_path)
104 | J, F = metrics_res['J'], metrics_res['F'] # pylint: disable=invalid-name
105 | g_measures = [
106 | 'J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall',
107 | 'F-Decay'
108 | ]
109 | g_res = np.array([
110 | (np.mean(J['M']) + np.mean(F['M'])) / 2.,
111 | np.mean(J['M']),
112 | np.mean(J['R']),
113 | np.mean(J['D']),
114 | np.mean(F['M']),
115 | np.mean(F['R']),
116 | np.mean(F['D'])
117 | ])
118 |
119 | if self.results_dir is not None:
120 | # Write metrics to result_dir.
121 | result_dir = os.path.join(self.results_dir, str(step))
122 | csv_name_global_path = os.path.join(result_dir, 'global_results.csv')
123 | csv_name_per_sequence_path = os.path.join(result_dir,
124 | 'per_sequence_results.csv')
125 |
126 | # Global results.
127 | g_res_ = np.reshape(g_res, [1, len(g_res)])
128 | table_g = pd.DataFrame(data=g_res_, columns=g_measures)
129 | with tf.io.gfile.GFile(csv_name_global_path, 'w') as f:
130 | table_g.to_csv(f, index=False, float_format='%.3f')
131 | logging.info('Global results saved in %s', csv_name_global_path)
132 |
133 | # Per sequence results.
134 | assert isinstance(J['M_per_object'], dict)
135 | seq_names = list(J['M_per_object'].keys())
136 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean']
137 | j_per_object = [J['M_per_object'][x] for x in seq_names]
138 | f_per_object = [F['M_per_object'][x] for x in seq_names]
139 | table_seq = pd.DataFrame(
140 | data=list(zip(seq_names, j_per_object, f_per_object)),
141 | columns=seq_measures)
142 | with tf.io.gfile.GFile(csv_name_per_sequence_path, 'w') as f:
143 | table_seq.to_csv(f, index=False, float_format='%.3f')
144 | logging.info('Per-sequence results saved in %s',
145 | csv_name_per_sequence_path)
146 |
147 | return {name: v for name, v in zip(g_measures, g_res)}
148 |
149 | def result(self, step):
150 | """Return the metric values (and compute it if needed)."""
151 | if self.metric_values is None:
152 | self.metric_values = self._evaluate(step)
153 | return self.metric_values
154 |
--------------------------------------------------------------------------------
/models/image_ar_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The image autoregressive decoder model."""
17 |
18 | import einops
19 | import ml_collections
20 |
21 | import utils
22 | from architectures.transformers import AutoregressiveDecoder
23 | from architectures.transformers import FITAR
24 | from models import model as model_lib
25 | from models import model_utils
26 | import tensorflow as tf
27 |
28 |
29 | @model_lib.ModelRegistry.register('image_ar_decoder')
30 | class Model(tf.keras.models.Model):
31 | """Inputs images and returns activations."""
32 |
33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs):
34 | # vocab_size and max_seq_len don't include start token, which is only used
35 | # inside this class.
36 | super().__init__(**kwargs)
37 | image_size = config.dataset.image_size
38 | self.loss_type = config.train.loss_type
39 | config = config.model
40 | self.config = config
41 | if config.arch_name == 'base':
42 | mlp_ratio_dec = config.dim_mlp_dec // config.dim_att_dec
43 | self.decoder = AutoregressiveDecoder(
44 | config.vocab_size, config.max_seq_len, config.num_decoder_layers,
45 | config.dim_att_dec, mlp_ratio_dec, config.num_heads_dec,
46 | config.drop_path, config.drop_units, config.drop_att,
47 | config.pos_encoding_dec, config.shared_decoder_embedding,
48 | config.decoder_output_bias, cross_attention=False, name='ar_decoder')
49 | else:
50 | self.decoder = FITAR(
51 | layers=config.layers,
52 | x_size=image_size**2*3,
53 | num_groups=(image_size//config.patch_size)**2,
54 | latents_per_group=config.latents_per_group,
55 | x_dim=config.dim_att,
56 | latent_dim=config.dim_latent,
57 | x_num_heads=config.num_heads,
58 | latent_num_heads=config.num_heads,
59 | mlp_ratio=config.dim_mlp//config.dim_att,
60 | vocab_size=config.vocab_size,
61 | shared_embedding=config.shared_decoder_embedding,
62 | output_bias=config.decoder_output_bias,
63 | drop_path=config.drop_path,
64 | drop_units=config.drop_units,
65 | drop_att=config.drop_att,
66 | x_pos_encoding=config.pos_encoding,
67 | latent_pos_encoding=config.latent_pos_encoding)
68 |
69 | def call(self, images, labels=None, training=True):
70 | """Model function call for *training*."""
71 | with tf.name_scope(''): # for other functions to have the same name scope.
72 | config = self.config
73 | input_seq, target_seq = image2seqs(
74 | images, config.arch_name, config.patch_size, config.patch_ordering)
75 | logits = self.decoder(input_seq, None, training=training)
76 | losses = model_utils.get_loss(logits, target_seq, self.loss_type)
77 | loss = tf.reduce_mean(losses) / tf.math.log(2.0)
78 | return loss, logits, target_seq
79 |
80 | def sample(self, **kwargs):
81 | """Sampling."""
82 | # TODO(iamtingchen): add sampling.
83 | loss, _, _ = self.call(kwargs['images'], kwargs['labels'], training=False)
84 | return kwargs['images'], loss
85 |
86 |
87 | @model_lib.TrainerRegistry.register('image_ar_decoder')
88 | class ARTrainer(model_lib.Trainer):
89 | """A trainer for AR model."""
90 |
91 | def __init__(self, config: ml_collections.ConfigDict, **kwargs):
92 | """Init and setup basic training elements under strategy scope.
93 |
94 | Note: the trainer needs to be created under `strategy.scope()`.
95 |
96 | Args:
97 | config: object for holding hyperparameters and other configurations.
98 |       **kwargs: other necessary configurations to pass for training setup.
99 | """
100 | super().__init__(config, **kwargs)
101 | self._metrics.update({
102 | 'loss': tf.keras.metrics.Mean('loss'),
103 | 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(
104 | 'accuracy'),
105 | })
106 |
107 | def compute_loss(self, preprocess_outputs):
108 | """Compute loss based on model outputs and targets."""
109 | images, labels = preprocess_outputs
110 | loss, logits, target_seq = self.model(images, labels, training=True)
111 |
112 | # update metrics
113 | self._metrics['loss'].update_state(loss)
114 | self._metrics['accuracy'].update_state(target_seq, logits)
115 |
116 | return loss
117 |
118 |
119 | def image2seqs(images, arch_name, patch_size, patch_ordering='snake'):
120 | """Turn images into input and target sequences."""
121 | if arch_name == 'base':
122 | images = einops.rearrange(images, 'b h w c -> b (h w c)')
123 | target_seq = tf.cast(images * 255., tf.int32) # (bsz, seqlen)
124 | input_seq = tf.concat(
125 | [tf.zeros_like(target_seq[:, :1]), target_seq[:, :-1]], 1)
126 | else:
127 | images = utils.extract_patches(
128 | images, [patch_size, patch_size], patch_ordering=patch_ordering)
129 | target_seq = tf.cast(images * 255., tf.int32) # (bsz, groups, seqlen)
130 | flat_seq = einops.rearrange(target_seq, 'b n m -> b (n m)')
131 | input_seq = tf.concat(
132 | [tf.zeros_like(flat_seq[:, :1]), flat_seq[:, :-1]], 1)
133 | input_seq = tf.reshape(input_seq, tf.shape(target_seq))
134 | return input_seq, target_seq
135 |
--------------------------------------------------------------------------------
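A tiny sketch of what image2seqs does for the 'base' architecture: pixels are flattened into one token sequence, quantized to integers in [0, 255], and the input is the target shifted right by one with a zero start token. Shapes here are small and arbitrary.

import einops
import tensorflow as tf

images = tf.random.uniform([2, 4, 4, 3])                 # (bsz, h, w, c) in [0, 1)
flat = einops.rearrange(images, 'b h w c -> b (h w c)')
target_seq = tf.cast(flat * 255., tf.int32)              # (bsz, 48) pixel tokens
input_seq = tf.concat(
    [tf.zeros_like(target_seq[:, :1]), target_seq[:, :-1]], 1)
print(input_seq.shape, target_seq.shape)                 # (2, 48) (2, 48)
--------------------------------------------------------------------------------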
/models/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Abstract model file."""
17 |
18 | import abc
19 | from absl import logging
20 | import ml_collections
21 | import registry
22 | import utils
23 | from models import model_utils
24 | import tensorflow as tf
25 |
26 | ModelRegistry = registry.Registry()
27 | TrainerRegistry = registry.Registry()
28 |
29 |
30 | class Trainer(abc.ABC):
31 | """A base trainer."""
32 |
33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs):
34 | """Init and setup basic training elements under strategy scope.
35 |
36 | Note: the trainer needs to be created under `strategy.scope()`.
37 |
38 | Args:
39 | config: object for holding hyperparameters and other configurations.
40 |       **kwargs: other necessary configurations to pass for training setup.
41 | """
42 | self._config = config
43 |
44 | # Setup learning rate and optimizer.
45 | num_train_examples = kwargs['num_train_examples']
46 | train_steps = kwargs['train_steps']
47 | c_opt = config.optimization
48 | batch_size = config.train.batch_size
49 | tail_steps = c_opt.get('tail_steps', 0)
50 | end_lr_factor = c_opt.get('end_lr_factor', 0.)
51 | warmup_steps = c_opt.warmup_steps or int(
52 | round(c_opt.warmup_epochs * num_train_examples // batch_size))
53 | self._learning_rate = learning_rate = model_utils.WarmUpAndDecay(
54 | c_opt.learning_rate, c_opt.learning_rate_scaling, batch_size,
55 | c_opt.learning_rate_schedule, warmup_steps, train_steps,
56 | tail_steps=tail_steps, end_lr_factor=end_lr_factor)
57 | self._optimizer = optimizer = model_utils.build_optimizer(
58 | config.optimization, learning_rate)
59 |
60 | # Setup model and checkpoints.
61 | self._model = model = ModelRegistry.lookup(config.model.name)(config)
62 | model_dir = kwargs['model_dir']
63 | latest_ckpt, ckpt, self._verify_restored = utils.restore_from_checkpoint(
64 | model_dir, False,
65 | model=model, global_step=optimizer.iterations, optimizer=optimizer)
66 | self._verify_restored_p = None
67 | if not latest_ckpt:
68 | if config.model.pretrained_ckpt:
69 | _, _, self._verify_restored_p = utils.restore_from_checkpoint(
70 | config.model.pretrained_ckpt, True, model=model)
71 | self._checkpoint_manager = tf.train.CheckpointManager(
72 | ckpt, model_dir, config.train.keep_checkpoint_max)
73 |
74 | # Setup metrics.
75 | self._metrics = {
76 | 'total_num_params': tf.keras.metrics.Mean('total_num_params'),
77 | 'grad_global_norm': tf.keras.metrics.Mean('grad_global_norm'),
78 | 'weight_linf_norm': tf.keras.metrics.Mean('weight_linf_norm'),
79 | 'loss': tf.keras.metrics.Mean('loss'),
80 | }
81 | self._metrics.update({
82 | f'loss_{t.name}': tf.keras.metrics.Mean(f'loss_{t.name}')
83 | for t in config.tasks})
84 | self._print_params = True
85 |
86 | def train_step(self, examples, tasks, strategy):
87 | """Defines a single training step for model update given examples and tasks.
88 |
89 | Args:
90 | examples: a list of data examples to be fed into the paired task class for
91 | preprocessing.
92 | tasks: a list of tasks that provide preprocessing and postprocessing for
93 | specific task.
94 | strategy: tensorflow strategy such as `TPUStrategy` or `MirroredStrategy`.
95 | """
96 | logging.info('train_step begins...')
97 | preprocessed_outputs = [
98 | t.preprocess_batched(e, training=True) for e, t in zip(examples, tasks)]
99 |
100 | task_loss_metrics = {}
101 | loss = 0
102 | grads = []
103 | for i, (o, task) in enumerate(zip(preprocessed_outputs, tasks)):
104 | with tf.GradientTape() as tape:
105 | loss_t = self.compute_loss(o)
106 | task_loss_metrics[f'loss_{task.config.task.name}'] = loss_t
107 | loss += loss_t * task.config.task.weight
108 | trainable_variables = self._model.trainable_variables
109 | grads_t = tape.gradient( # div by num_replicas_in_sync for mean grad.
110 | loss_t * task.config.task.weight / strategy.num_replicas_in_sync,
111 | trainable_variables)
112 | grads = grads_t if i == 0 else [
113 | g + gt for g, gt in zip(grads, grads_t)]
114 | self._optimizer.apply_gradients(zip(grads, trainable_variables))
115 |
116 | # Update metrics.
117 | self._metrics['loss'].update_state(loss)
118 | for k, v in task_loss_metrics.items():
119 | self._metrics[k].update_state(v)
120 | wmx = [tf.reduce_max(tf.math.abs(m)) for m in trainable_variables]
121 | self._metrics['weight_linf_norm'].update_state(tf.reduce_max(wmx))
122 | multiplier = strategy.num_replicas_in_sync
123 | self._metrics['grad_global_norm'].update_state(tf.linalg.global_norm(
124 | [tf.math.scalar_mul(multiplier, g) for g in grads if g is not None]))
125 | self._metrics['total_num_params'].update_state(
126 | utils.count_params(self._model, verbose=self._print_params))
127 | self._print_params = False
128 | logging.info('train_step ends...')
129 |
130 | @abc.abstractmethod
131 | def compute_loss(self, preprocessed_outputs):
132 | """Compute loss based on model outputs and targets."""
133 |
134 | def check_checkpoint_restored(self):
135 | """Check if the checkpoints are correctely restored."""
136 | (verify_restored,), (verify_restored_p,) = (
137 | utils.check_checkpoint_restored(
138 | [self._verify_restored], [self._verify_restored_p]))
139 | self._verify_restored = verify_restored
140 | self._verify_restored_p = verify_restored_p
141 |
142 | def reset(self):
143 | """Reseting the metrics and/or other state accumulators."""
144 | for k, _ in self._metrics.items():
145 | self._metrics[k].reset_states()
146 |
147 | @property
148 | def model(self):
149 | """Returns model instance."""
150 | return self._model
151 |
152 | @property
153 | def optimizer(self):
154 | """Returns optimizer instance."""
155 | return self._optimizer
156 |
157 | @property
158 | def learning_rate(self):
159 | """Returns learning rate scheduling instance."""
160 | return self._learning_rate
161 |
162 | @property
163 | def metrics(self):
164 | """Returns metrics instance."""
165 | return self._metrics
166 |
167 | @property
168 | def config(self):
169 | """Returns config instance."""
170 | return self._config
171 |
172 | @property
173 | def checkpoint_manager(self):
174 | """Returns checkpoint_manager instance."""
175 | return self._checkpoint_manager
176 |
--------------------------------------------------------------------------------
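A stripped-down sketch of the per-task gradient accumulation in Trainer.train_step above: each task's weighted loss contributes gradients that are summed before a single optimizer update. The toy model, loss, and task weights are placeholders, and the division by num_replicas_in_sync is omitted since no distribution strategy is involved.

import tensorflow as tf

model = tf.keras.layers.Dense(1)
optimizer = tf.keras.optimizers.SGD(0.1)
x = tf.random.normal([8, 4])
task_weights = [1.0, 0.5]                         # one weight per task
grads = None
for i, w in enumerate(task_weights):
  with tf.GradientTape() as tape:
    loss_t = tf.reduce_mean(tf.square(model(x)))  # stand-in per-task loss
  grads_t = tape.gradient(loss_t * w, model.trainable_variables)
  grads = grads_t if i == 0 else [g + gt for g, gt in zip(grads, grads_t)]
optimizer.apply_gradients(zip(grads, model.trainable_variables))
--------------------------------------------------------------------------------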
/models/video_diffusion_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The model."""
17 |
18 | import functools
19 | import ml_collections
20 |
21 | from architectures.tape import VideoTapeDenoiser
22 | from models import diffusion_utils
23 | from models import image_diffusion_model
24 | from models import model as model_lib
25 | from models.diffusion_utils import Scheduler
26 | import tensorflow as tf
27 |
28 |
29 | @model_lib.ModelRegistry.register('video_diffusion_model')
30 | class Model(image_diffusion_model.Model):
31 | """A model with tape for video prediction."""
32 |
33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs):
34 | super(image_diffusion_model.Model, self).__init__(**kwargs)
35 | image_size = config.dataset.image_size
36 | self.image_size = image_size
37 | self.num_classes = config.dataset.num_classes
38 | self.seq_len = config.dataset.seq_len
39 | config = config.model
40 | self.config = config
41 | self.scheduler = Scheduler(config.train_schedule)
42 | if config.x0_clip == 'auto':
43 | self.x0_clip = '{},{}'.format(-config.b_scale, config.b_scale)
44 | else:
45 | self.x0_clip = config.x0_clip
46 | self.x_channels = 3
47 |
48 | model_fn = functools.partial(
49 | VideoTapeDenoiser,
50 | num_layers=config.num_layers,
51 | latent_slots=config.latent_slots,
52 | latent_dim=config.latent_dim,
53 | latent_mlp_ratio=config.latent_mlp_ratio,
54 | latent_num_heads=config.latent_num_heads,
55 | tape_dim=config.tape_dim,
56 | tape_mlp_ratio=config.tape_mlp_ratio,
57 | rw_num_heads=config.rw_num_heads,
58 | image_height=image_size,
59 | image_width=image_size,
60 | image_channels=self.x_channels,
61 | patch_size=config.patch_size,
62 | seq_len=self.seq_len,
63 | seq_stride=config.seq_stride,
64 | seq_cond=self.seq_len - self.sample_shape[0],
65 | latent_pos_encoding=config.latent_pos_encoding,
66 | tape_pos_encoding=config.tape_pos_encoding,
67 | drop_path=config.drop_path,
68 | drop_units=config.drop_units,
69 | drop_att=config.drop_att,
70 | time_scaling=config.time_scaling,
71 | self_cond=config.self_cond,
72 | time_on_latent=config.time_on_latent,
73 | cond_on_latent_n=1 if config.cond_on_latent else 0,
74 | cond_tape_writable=config.cond_tape_writable,
75 | xattn_enc_ln=config.xattn_enc_ln,
76 | name=config.arch_name)
77 | self.denoiser = model_fn(name='denoiser')
78 | self.denoiser_ema = model_fn(name='denoiser', trainable=False)
79 | self.hidden_shapes = self.denoiser.hidden_shapes
80 |
81 | @property
82 | def sample_shape(self):
83 | if self.seq_len > 1:
84 | seq_cond = self.config.get('conditional', 'seq@0').split('@')
85 | seq_cond = int(seq_cond[-1]) if len(seq_cond) > 1 else 0
86 | return [self.seq_len-seq_cond, self.image_size, self.image_size,
87 | self.x_channels]
88 | else:
89 | return [self.image_size, self.image_size, self.x_channels]
90 |
91 | # override for conditional data
92 | def get_cond_denoise(self, labels, cond=None):
93 | config = self.config
94 | def cond_denoise(x, gamma, training):
95 | gamma = tf.reshape(gamma, [-1])
96 | if config.conditional == 'class':
97 | gamma = tf.concat([gamma[..., tf.newaxis], labels], -1)
98 | elif config.conditional != 'none' and 'seq@' not in config.conditional:
99 | raise ValueError(f'Unknown conditional {config.conditional}')
100 | return self.denoise(x, gamma, cond, training)
101 | return cond_denoise
102 |
103 | # override to pass x_cond
104 | def sample(self, num_samples=100, iterations=100, method='ddim', **kwargs):
105 | config = self.config
106 | samples_shape = [num_samples, *self.sample_shape]
107 | if config.conditional == 'class':
108 | labels = tf.random.uniform(
109 | [num_samples], 0, self.num_classes, dtype=tf.int32)
110 | labels = tf.one_hot(labels, self.num_classes)
111 | else:
112 | labels = None
113 | x_cond = kwargs['images'][:, :-self.sample_shape[0]]
114 | x_cond = (x_cond * 2. - 1.) * config.b_scale # convert 0,1 -> -s,s
115 | samples = self.scheduler.generate(
116 | self.get_cond_denoise(labels, cond=x_cond),
117 | iterations,
118 | samples_shape,
119 | hidden_shapes=self.hidden_shapes,
120 | pred_type=config.pred_type,
121 | schedule=config.infer_schedule,
122 | td=config.td,
123 | noise_std=config.noise_std,
124 | x0_clip=self.x0_clip,
125 | self_cond=config.self_cond,
126 | sampler_name=config.sampler_name)
127 | if x_cond.shape[1] > 0:
128 | samples = tf.concat([x_cond, samples], axis=1)
129 | samples = (samples / config.b_scale / 2. + 0.5) # convert -s,s -> 0,1
130 | return samples
131 |
132 | # override to allow for more cond data
133 | def noise_denoise(self, images, labels, time_step=None, training=True):
134 | config = self.config
135 | images = (images * 2. - 1.) * config.b_scale # convert 0,1 -> -s,s
136 | seq_len = self.sample_shape[0]
137 | cond_images, images = images[:, :-seq_len], images[:, -seq_len:]
138 | images_noised, noise, _, gamma = self.scheduler.add_noise(
139 | images, time_step=time_step)
140 | images_noised_ori = images_noised
141 | if config.self_cond != 'none':
142 | sc_rate = config.get('self_cond_rate', 0.5)
143 | self_cond_by_masking = config.get('self_cond_by_masking', False)
144 | if self_cond_by_masking:
145 | sc_drop_rate = 1. - sc_rate
146 | num_sc_examples = tf.shape(images)[0]
147 | else:
148 | sc_drop_rate = 0.
149 | num_sc_examples = tf.cast(
150 | tf.cast(tf.shape(images)[0], tf.float32) * sc_rate, tf.int32)
151 | cond_denoise = self.get_cond_denoise(
152 | labels[:num_sc_examples],
153 | cond_images[:num_sc_examples] if cond_images is not None else None)
154 | if self.hidden_shapes is None: # data self-cond, return is a tensor.
155 | denoise_inputs = diffusion_utils.add_self_cond_estimate(
156 | images_noised, gamma, cond_denoise, config.pred_type,
157 | config.self_cond, self.x0_clip, num_sc_examples,
158 | drop_rate=sc_drop_rate, training=training)
159 | else: # latent self-cond, return is a tuple.
160 | denoise_inputs = diffusion_utils.add_self_cond_hidden(
161 | images_noised, gamma, cond_denoise, num_sc_examples,
162 | self.hidden_shapes, drop_rate=sc_drop_rate, training=training)
163 | else:
164 | denoise_inputs = images_noised
165 | cond_denoise = self.get_cond_denoise(labels, cond_images)
166 | denoise_out = cond_denoise(denoise_inputs, gamma, training)
167 | if isinstance(denoise_out, tuple): denoise_out = denoise_out[0]
168 |
169 | return images, noise, images_noised_ori, denoise_out
170 |
171 |
172 | @model_lib.TrainerRegistry.register('video_diffusion_model')
173 | class Trainer(image_diffusion_model.Trainer):
174 | """A trainer."""
175 |
176 | def train_step(self, examples, tasks, strategy):
177 | super().train_step(examples, tasks, strategy)
178 |
--------------------------------------------------------------------------------
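A small sketch of how the conditional='seq@k' setting is interpreted in sample_shape above: the suffix after '@' is the number of conditioning frames, and only the remaining seq_len - k frames are generated. The values below are illustrative.

def frames_to_generate(conditional, seq_len):
  # 'seq@2' -> condition on 2 frames; 'class' or 'none' -> condition on 0.
  parts = conditional.split('@')
  seq_cond = int(parts[-1]) if len(parts) > 1 else 0
  return seq_len - seq_cond

print(frames_to_generate('seq@2', seq_len=16))  # 14
print(frames_to_generate('class', seq_len=16))  # 16
--------------------------------------------------------------------------------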
/pix2seq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/pix2seq/eabd4e98f65f7627a0e727d11a4b9cbba916283d/pix2seq.gif
--------------------------------------------------------------------------------
/pix2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/pix2seq/eabd4e98f65f7627a0e727d11a4b9cbba916283d/pix2seq.png
--------------------------------------------------------------------------------
/registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """General Registry."""
17 |
18 | from typing import Any
19 | from typing import Callable
20 |
21 |
22 | class Registry(object):
23 | """Registry."""
24 |
25 | def __init__(self):
26 | self._registry = {}
27 |
28 | def register(self, key: str) -> Callable[[Any], None]:
29 | """Returns callable to register value for key."""
30 | def r(item):
31 | if key in self._registry:
32 | raise ValueError("%s already registered!" % key)
33 | self._registry[key] = item
34 | return item
35 | return r
36 |
37 | def lookup(self, key: str) -> Any:
38 | """Looks up value for key."""
39 | if key not in self._registry:
40 | valid_keys = "\n".join(self._registry.keys())
41 | raise ValueError(
42 | "%s not registered!\n\n"
43 | "Valid keys:%s\n\n" %
44 | (key, valid_keys))
45 | return self._registry[key]
46 |
--------------------------------------------------------------------------------
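A minimal sketch of how Registry is used throughout the codebase: a module-level registry, a decorator that registers a class under a string key, and a lookup at construction time. ToyModel and 'toy_model' are hypothetical names.

import registry

MyRegistry = registry.Registry()

@MyRegistry.register('toy_model')
class ToyModel:
  def __init__(self, config):
    self.config = config

model_cls = MyRegistry.lookup('toy_model')
model = model_cls(config={'dim': 8})
--------------------------------------------------------------------------------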
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow
2 | ml-collections
3 | matplotlib
4 | tensorflow-datasets
5 | tensorflow-addons
6 | tensorflow-text
7 | pycocotools
8 | scikit-image
--------------------------------------------------------------------------------
/tasks/captioning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Instance segmentation task via COCO metric evaluation."""
17 |
18 | import ml_collections
19 |
20 | import utils
21 | import vocab
22 | from data import tokenizer as tokenizer_lib
23 | from metrics import metric_registry
24 | from tasks import task as task_lib
25 | from tasks import task_utils
26 | import tensorflow as tf
27 |
28 |
29 | @task_lib.TaskRegistry.register('captioning')
30 | class TaskCaptioning(task_lib.Task):
31 | """Image captioning with coco evaluation."""
32 |
33 | def __init__(self,
34 | config: ml_collections.ConfigDict):
35 | super().__init__(config)
36 | metric_config = config.task.get('metric')
37 | if metric_config and metric_config.get('name'):
38 | self._coco_metrics = metric_registry.MetricRegistry.lookup(
39 | metric_config.name)(config)
40 | else:
41 | self._coco_metrics = None
42 | self._tokenizer = tokenizer_lib.SPTokenizer(
43 | config.tokenizer.sentencepiece_model,
44 | add_bos=config.tokenizer.add_bos,
45 | add_eos=config.tokenizer.add_eos)
46 |
47 | def preprocess_single(self, dataset, batch_duplicates, training):
48 | """Task-specific preprocessing of individual example in the dataset.
49 |
50 | Args:
51 | dataset: A tf.data.Dataset.
52 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
53 |         (as specified) and concatenating the augmented examples.
54 | training: bool.
55 |
56 | Returns:
57 | A dataset.
58 | """
59 | if batch_duplicates > 1:
60 | raise NotImplementedError('Not supporting batch_duplicate=%d > 1 for '
61 | 'caption as of now.' % batch_duplicates)
62 |
63 | def _preprocess_single_example(example):
64 | config = self.config.task
65 | mconfig = self.config.model
66 | if training:
67 | captions = []
68 | for i in range(config.captions_per_image):
69 | caption = (self._tokenizer.string_to_ids(example['captions'][i]) +
70 | mconfig.text_vocab_shift)
71 | captions.append(utils.pad_to_max_len(caption, config.max_seq_len, -1))
72 | captions = tf.stack(captions)
73 |
74 | for t in self.train_transforms:
75 | example = t.process_example(example)
76 | example['captions'] = captions
77 | else:
78 | for t in self.eval_transforms:
79 | example = t.process_example(example)
80 |
81 | # Use the first caption. This won't be used in eval.
82 | caption = (self._tokenizer.string_to_ids(example['captions'][0]) +
83 | mconfig.text_vocab_shift)
84 | caption = utils.pad_to_max_len(caption, config.max_seq_len, -1)
85 | example['captions'] = caption
86 | return example
87 |
88 | dataset = dataset.map(_preprocess_single_example,
89 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
90 | return dataset
91 |
92 | def preprocess_batched(self, batched_examples, training):
93 | """Batched preprocessing & sequence construction.
94 |
95 | Args:
96 |       batched_examples: a tuple of features (`dict`) and labels (`dict`),
97 | containing images and labels.
98 | training: `bool` indicating training or inference mode.
99 |
100 | Returns:
101 | images: `float` of shape (bsz, h, w, c)
102 |       input_seq: `int` of shape (bsz, seqlen), or (bsz, instances, seqlen)
103 |         for multiple instances, as in keypoint detection for example.
104 |       target_seq: `int` of shape (bsz, seqlen'), or (bsz, instances, seqlen')
105 |         for multiple instances, as in keypoint detection for example.
106 | """
107 | config = self.config.task
108 | mconfig = self.config.model
109 |
110 | if training:
111 | response_seq = batched_examples['captions'] # (bsz, num_cap, max_seq_len)
112 | prompt_seq = task_utils.build_prompt_seq_from_task_id(
113 | self.task_vocab_id, response_seq) # (bsz, 1)
114 | label_seq = tf.concat([prompt_seq, response_seq], -1)
115 | token_weights = tf.where(
116 | response_seq == 1 + mconfig.text_vocab_shift, # eos token
117 | config.eos_token_weight, 1.0)
118 | input_seq, target_seq = label_seq[..., :-1], label_seq[..., 1:]
119 |
120 | if config.input_seq_drop_rate > 0:
121 | input_seq = tf.where(
122 | tf.random.uniform(tf.shape(input_seq)) > config.input_seq_drop_rate,
123 | input_seq, vocab.FAKE_TEXT_TOKEN)
124 |
125 | return batched_examples['image'], input_seq, target_seq, token_weights
126 | else:
127 | return (batched_examples['image'], batched_examples['captions'],
128 | batched_examples)
129 |
130 | def infer(self, model, preprocessed_outputs):
131 | """Perform inference given the model and preprocessed outputs."""
132 | config = self.config.task
133 | image, _, examples = preprocessed_outputs # response_seq unused by default
134 | bsz = tf.shape(image)[0]
135 | prompt_seq = task_utils.build_prompt_seq_from_task_id(
136 | self.task_vocab_id, prompt_shape=(bsz, 1))
137 | pred_seq, logits, _ = model.infer(
138 | image, prompt_seq, encoded=None,
139 | temperature=config.temperature, top_k=config.top_k, top_p=config.top_p)
140 | # if True: # Sanity check by using gt response_seq as pred_seq.
141 | # pred_seq = response_seq
142 | # logits = tf.one_hot(pred_seq, self.vocab_size)
143 | return examples, pred_seq, logits
144 |
145 | def postprocess_tpu(self, batched_examples, pred_seq, logits, training=False): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
146 | """Organizing results after fitting the batched examples in graph.
147 |
148 | Such as updating metrics, putting together results for computing metrics in
149 | CPU/numpy mode.
150 |
151 | Args:
152 |       batched_examples: a tuple of features (`dict`) and labels (`dict`),
153 | containing images and labels.
154 | pred_seq: `int` sequence of shape (bsz * instances, seqlen').
155 | logits: `float` sequence of shape (bsz * instances, seqlen', vocab_size).
156 | training: `bool` indicating training or inference mode.
157 |
158 | Returns:
159 | results for passing to `postprocess_cpu` which runs in CPU mode.
160 | """
161 | return (batched_examples['image'], batched_examples['image/id'],
162 | batched_examples['captions'], pred_seq)
163 |
164 | def postprocess_cpu(self, outputs, train_step, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
165 | eval_step=None, training=False, summary_tag='eval',
166 | ret_results=False):
167 | """CPU post-processing of outputs.
168 |
169 | Such as computing the metrics, log image summary.
170 |
171 | Note: current implementation only support eval mode where gt are given in
172 | metrics as they are not given here in outputs.
173 |
174 | Args:
175 | outputs: a tuple of tensor passed from `postprocess_tpu`.
176 | train_step: `int` scalar indicating training step of current model or
177 | the checkpoint.
178 | eval_step: `int` scalar indicating eval step for the given checkpoint.
179 | training: `bool` indicating training or inference mode.
180 | summary_tag: `string` of name scope for result summary.
181 | ret_results: whether to return visualization images/captions.
182 |
183 | Returns:
184 | A dict of visualization images/caption if ret_results, else None.
185 | """
186 | config = self.config.task
187 | mconfig = self.config.model
188 | del summary_tag
189 | if not training:
190 | images, image_ids, gt_seq, pred_seq = outputs
191 | batch_size = tf.shape(image_ids)[0]
192 | pred_seq = tf.where(pred_seq == 0, 0,
193 | pred_seq - mconfig.text_vocab_shift)
194 | original_pred_seq = pred_seq.numpy()
195 | pred_seq = tf.minimum(
196 | tf.maximum(pred_seq, 0), self._tokenizer.vocab_size - 1)
197 | pred_seq = tf.cast(pred_seq, tf.int32)
198 | clipped_pred_seq = pred_seq.numpy()
199 | output_text = self._tokenizer.ids_to_strings(
200 | pred_seq,
201 | tf.ones([batch_size], tf.int32) * config.max_seq_len).numpy()
202 | output_text = [o.decode('utf-8') for o in output_text]
203 | if self._coco_metrics:
204 | self._coco_metrics.record_prediction(image_ids.numpy(), output_text,
205 | original_pred_seq,
206 | clipped_pred_seq)
207 |
208 | if ret_results:
209 | gt_seq = tf.where(gt_seq == 0, 0,
210 | gt_seq - mconfig.text_vocab_shift)
211 | gt_text = self._tokenizer.ids_to_strings(
212 | tf.cast(gt_seq, tf.int32),
213 | tf.ones([batch_size], tf.int32) * config.max_seq_len).numpy()
214 | gt_text = [s.decode('utf-8') for s in gt_text]
215 | return {'gt_images': images,
216 | 'gt_captions': gt_text,
217 | 'pred_captions': output_text}
218 |
219 | def compute_scalar_metrics(self, step):
220 | """Returns a dict containing scalar metrics to log."""
221 | if self._coco_metrics:
222 | return self._coco_metrics.result(step)
223 | else:
224 | return {}
225 |
226 | def reset_metrics(self):
227 | """Reset states of metrics accumulators."""
228 | if self._coco_metrics:
229 | self._coco_metrics.reset_states()
230 |
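# --- Illustrative sketch (not part of the original file) ---
# A minimal, hedged example of the teacher-forcing shift and random input-token
# dropout performed in `preprocess_batched` above. The toy token values, the
# helper name and the 0.5 drop rate are made up; vocab.FAKE_TEXT_TOKEN is the
# real constant (30). Relies on the module's existing `tf` and `vocab` imports.

def _toy_caption_seq_example():
  """Shows how input/target sequences are derived from a padded caption."""
  label_seq = tf.constant([[105, 106, 107, 108, 0, 0]], tf.int32)  # toy caption
  # Shift by one token: inputs condition the decoder, targets are predictions.
  input_seq, target_seq = label_seq[..., :-1], label_seq[..., 1:]
  # Randomly replace input tokens with the fake text token (token dropout).
  keep = tf.random.uniform(tf.shape(input_seq)) > 0.5  # input_seq_drop_rate
  input_seq = tf.where(keep, input_seq, vocab.FAKE_TEXT_TOKEN)
  return input_seq, target_seq
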
--------------------------------------------------------------------------------
/tasks/image_generation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Task for image generation."""
17 |
18 | import os
19 |
20 | from absl import logging
21 | import ml_collections
22 | import utils
23 | from data import data_utils
24 | from metrics.fid import TFGANMetricEvaluator
25 | from tasks import task as task_lib
26 | import tensorflow as tf
27 |
28 |
29 | @task_lib.TaskRegistry.register('image_generation')
30 | class TaskImageGeneration(task_lib.Task): # pytype: disable=base-class-error
31 | """TaskImageGeneration."""
32 |
33 | def __init__(self, config: ml_collections.ConfigDict):
34 | super().__init__(config)
35 | self._metrics = {}
36 | self._metrics['eval_loss'] = tf.keras.metrics.Mean('eval_loss')
37 | self._tfgan_evaluator = TFGANMetricEvaluator(
38 | dataset_name=config.dataset.tfds_name,
39 | image_size=config.dataset.image_size)
40 |
41 | self._write_images_to_file = config.eval.get('write_images_to_file', False)
42 | if self._write_images_to_file:
43 | self._tfrecord_dir = os.path.join(config.model_dir, 'images',
44 | config.eval.tag)
45 | if not tf.io.gfile.exists(self._tfrecord_dir):
46 | tf.io.gfile.makedirs(self._tfrecord_dir)
47 | self._tfrecord_writer = None
48 |
49 | def preprocess_single(self, dataset, batch_duplicates, training):
50 | """Task-specific preprocessing of individual example in the dataset.
51 |
52 | Args:
53 | dataset: A tf.data.Dataset.
54 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
55 |         (as specified) and concatenating the augmented examples.
56 | training: bool.
57 |
58 | Returns:
59 | A dataset.
60 | """
61 | image_size = self.config.dataset.image_size
62 |
63 | def _preprocess_single_example(examples):
64 | examples_list = []
65 | for _ in range(batch_duplicates if training else 1):
66 | image_ = preprocess_image(
67 | examples['image'],
68 | height=image_size,
69 | width=image_size,
70 | cropping=self.config.dataset.cropping,
71 | flipping=self.config.dataset.flipping,
72 | training=training)
73 | if examples['label'].shape.ndims == 0:
74 | label_ = tf.one_hot(examples['label'],
75 | self.config.dataset.num_classes)
76 | else:
77 | label_ = examples['label']
78 | examples_list.append({'image': image_, 'label': label_})
79 | examples = utils.merge_list_of_dict(examples_list)
80 | return examples
81 |
82 | dataset = dataset.map(_preprocess_single_example,
83 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
84 | return dataset
85 |
86 | def preprocess_batched(self, examples, training):
87 | """Task-specific preprocessing of batched examples on accelerators (TPUs).
88 |
89 | Args:
90 | examples: `dict` of images and labels.
91 | training: bool.
92 |
93 | Returns:
94 | images: `float` of shape (bsz, h, w, c)
95 |       labels: `float` one-hot labels of shape (bsz, num_classes)
96 | """
97 | if training:
98 | return examples['image'], examples['label']
99 | else:
100 | return examples['image'], examples['label'], examples
101 |
102 | def infer(self, model, preprocessed_outputs):
103 | """Perform inference given the model and preprocessed outputs."""
104 | images, labels, examples = preprocessed_outputs
105 | outputs = model.sample(
106 | num_samples=tf.shape(images)[0],
107 | iterations=self.config.model.infer_iterations,
108 | method=self.config.model.sampler_name,
109 | images=images,
110 | labels=labels)
111 | if isinstance(outputs, tuple) or isinstance(outputs, list):
112 | samples, eval_loss = outputs
113 | self._metrics['eval_loss'].update_state(eval_loss)
114 | else:
115 | samples = outputs
116 | return examples, samples
117 |
118 | def postprocess_tpu(self,
119 | examples,
120 | samples,
121 | training=False):
122 | """Organizing results after fitting the batched examples in graph.
123 |
124 | Such as updating metrics, putting together results for computing metrics in
125 | CPU/numpy mode.
126 |
127 |     Note: the current implementation only supports eval mode, where ground
128 |     truths are provided to the metrics as they are not constructed here.
129 |
130 | Args:
131 | examples: `dict` of images and labels.
132 | samples: `float` predicted image tensor of (bsz, h, w, c).
133 | training: `bool` indicating training or inference mode.
134 |
135 | Returns:
136 | results for passing to `postprocess_cpu` which runs in CPU mode.
137 | """
138 | logging.info('Start postprocess_tpu.')
139 | # Get ground-truth images.
140 | if 'original_image' in examples:
141 | images = examples['original_image']
142 | elif examples['image'].shape[-1] == 3:
143 | images = examples['image']
144 | else: # ground-truth images not available in the dataset.
145 | images = samples
146 |
147 | # FID
148 | data_real, data_gen = self._tfgan_evaluator.preprocess_inputs(
149 | [images, samples], is_n1p1=False)
150 | (logits_real, pool3_real), (logits_gen, pool3_gen) = (
151 | self._tfgan_evaluator.get_inception_stats([data_real, data_gen]))
152 |
153 | logging.info('postprocess_tpu done.')
154 | return (images, samples, logits_real, pool3_real, logits_gen, pool3_gen)
155 |
156 | def postprocess_cpu(self,
157 | outputs,
158 | train_step,
159 | eval_step=None,
160 | training=False,
161 | summary_tag='eval',
162 | ret_results=False):
163 | """CPU post-processing of outputs.
164 |
165 | Args:
166 | outputs: a tuple of tensor passed from `postprocess_tpu`.
167 | train_step: `int` scalar indicating training step of current model or the
168 | checkpoint.
169 | eval_step: `int` scalar indicating eval step for the given checkpoint.
170 | training: `bool` indicating training or inference mode.
171 | summary_tag: `string` of name scope for result summary.
172 | ret_results: whether to return visualization images.
173 |
174 | Returns:
175 | A dict of visualization images if ret_results, else None.
176 | """
177 | logging.info('Start postprocess_cpu')
178 | images, samples, logits_real, pool3_real, logits_gen, pool3_gen = outputs
179 |
180 | # FID update.
181 | self._tfgan_evaluator.update_stats(
182 | logits_real, pool3_real, logits_gen, pool3_gen)
183 |
184 | # Image summary.
185 | bsz, h, w, c = utils.shape_as_list(samples)
186 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32)
187 | b = a
188 | vis_samples = samples[:a * a, ...]
189 | vis_samples = tf.reshape(vis_samples, [a, b, h, w, c])
190 | vis_samples = tf.transpose(vis_samples, [0, 2, 1, 3, 4])
191 | images_sum = tf.reshape(vis_samples, [1, a * h, b * w, c])
192 | if eval_step < 2:
193 | tf.summary.image(
194 | f'{summary_tag}/samples_{eval_step}', images_sum, step=train_step)
195 |
196 | # Write gt and samples to tfrecord.
197 | if self._write_images_to_file:
198 | if self._tfrecord_writer is None:
199 | self._tfrecord_writer = tf.io.TFRecordWriter(
200 | os.path.join(self._tfrecord_dir, f'step-{train_step}.tfrecord'))
201 | for i in range(bsz):
202 | ex_str = create_tf_example(images[i], samples[i])
203 | self._tfrecord_writer.write(ex_str)
204 |
205 | logging.info('postprocess_cpu done.')
206 | if ret_results:
207 | return {'gt': images, 'pred': samples}
208 |
209 | def compute_scalar_metrics(self, step):
210 | """Returns a dict containing scalar metrics to log."""
211 | result = {}
212 | for metric in self._metrics.values():
213 | result[metric.name] = metric.result().numpy()
214 |
215 | # FID
216 | result.update(self._tfgan_evaluator.compute_fid_score())
217 | return result
218 |
219 | def reset_metrics(self):
220 | """Reset states of metrics accumulators."""
221 | for metric in self._metrics.values():
222 | metric.reset_states()
223 |
224 | self._tfgan_evaluator.reset()
225 |     if self._write_images_to_file and self._tfrecord_writer is not None:
226 | self._tfrecord_writer.close()
227 | self._tfrecord_writer = None
228 |
229 |
230 |
231 |
232 | def create_tf_example(ref, hyp):
233 | """Creates a serialized tf example that stores the images.
234 |
235 | Args:
236 | ref: Tensor of [h, w, c], the ground-truth image.
237 | hyp: Tensor of [h, w, c], the generated image.
238 |
239 | Returns:
240 | the serialized tf example.
241 | """
242 | ref_bytes = tf.io.encode_png(
243 | tf.image.convert_image_dtype(ref, tf.uint8)).numpy()
244 | hyp_bytes = tf.io.encode_png(
245 | tf.image.convert_image_dtype(hyp, tf.uint8)).numpy()
246 | feature = {
247 | 'hyp': tf.train.Feature(bytes_list=tf.train.BytesList(value=[hyp_bytes])),
248 | 'ref': tf.train.Feature(bytes_list=tf.train.BytesList(value=[ref_bytes])),
249 | }
250 | return tf.train.Example(
251 | features=tf.train.Features(feature=feature)).SerializeToString()
252 |
253 |
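# --- Illustrative sketch (not part of the original file) ---
# Hedged example of reading back the records written by `create_tf_example`
# above. The feature keys ('ref', 'hyp') match that function; the helper name
# and the file path in the usage comment are hypothetical. Uses the module's
# `tf` import.
def _parse_image_pair_example(serialized):
  """Decodes one serialized example back into (ref, hyp) uint8 image tensors."""
  features = tf.io.parse_single_example(
      serialized, {
          'ref': tf.io.FixedLenFeature([], tf.string),
          'hyp': tf.io.FixedLenFeature([], tf.string),
      })
  return tf.io.decode_png(features['ref']), tf.io.decode_png(features['hyp'])

# Possible usage (path is hypothetical):
#   ds = tf.data.TFRecordDataset('<model_dir>/images/<tag>/step-10000.tfrecord')
#   ds = ds.map(_parse_image_pair_example)
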
254 | def preprocess_image(image,
255 | height,
256 | width,
257 | cropping='none',
258 | flipping='none',
259 | training=False):
260 | """Preprocesses the given image.
261 |
262 | Args:
263 | image: `Tensor` representing an image of arbitrary size.
264 | height: Height of output image.
265 | width: Width of output image.
266 | cropping: which cropping to apply to the image.
267 | flipping: which flipping to apply to the image.
268 | training: `bool` for whether the preprocessing is for training.
269 |
270 | Returns:
271 | A preprocessed image `Tensor` of range [0, 1].
272 | """
273 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
274 | if cropping == 'center':
275 | image = data_utils.largest_center_square_crop(image)
276 | image = tf.image.resize(
277 | image,
278 | size=(height, width),
279 | method='bicubic',
280 | preserve_aspect_ratio=False,
281 | antialias=True)
282 | elif cropping != 'none':
283 | raise ValueError(f'Unknown cropping method {cropping}')
284 | if training:
285 | if flipping == 'left_right':
286 | image = tf.image.random_flip_left_right(image)
287 | elif flipping != 'none':
288 | raise ValueError(f'Unknown flipping method {flipping}')
289 |   image = tf.ensure_shape(image, [height, width, 3])  # Let arch know the shape.
290 | return image
291 |
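# --- Illustrative usage sketch (not part of the original file) ---
# Hedged example of calling `preprocess_image` above. The helper name, the
# 64x64 dummy image and cropping='none' are arbitrary choices that avoid any
# dataset utilities.
def _toy_preprocess_image_example():
  """Returns a preprocessed 64x64 dummy image (float32, values in [0, 1])."""
  dummy = tf.zeros([64, 64, 3], tf.uint8)
  return preprocess_image(dummy, height=64, width=64,
                          cropping='none', flipping='left_right', training=True)
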
--------------------------------------------------------------------------------
/tasks/recognition.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Object recognition / image classification task."""
17 |
18 | import ml_collections
19 | import utils
20 | from tasks import task as task_lib
21 | from simclr.tf2 import data_util as simclr_data_util
22 | import tensorflow as tf
23 |
24 |
25 | @task_lib.TaskRegistry.register('object_recognition')
26 | class TaskObjectRecognition(task_lib.Task):
27 | """Object recognition."""
28 |
29 | def __init__(self, config: ml_collections.ConfigDict):
30 | super().__init__(config)
31 | self._metrics = {
32 | 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
33 | }
34 | if 'linear_eval_all_layers' in config.model and (
35 | config.model.linear_eval_all_layers):
36 | for i in range(config.model.num_encoder_layers+2):
37 | self._metrics['accuracy_%d'%i] = (
38 | tf.keras.metrics.SparseCategoricalAccuracy('accuracy_%d'%i))
39 |
40 | def preprocess_single(self, dataset, batch_duplicates, training):
41 | """Task-specific preprocessing of individual example in the dataset.
42 |
43 | Args:
44 | dataset: A tf.data.Dataset.
45 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
46 |         (as specified) and concatenating the augmented examples.
47 | training: bool.
48 |
49 | Returns:
50 | A dataset.
51 | """
52 | config = self.config.task
53 | image_size = self.config.dataset.image_size
54 |
55 | def _preprocess_single_example(examples):
56 | test_crop = False if image_size <= 32 else True
57 |
58 | examples_list = []
59 | for _ in range(batch_duplicates if training else 1):
60 | image_ = preprocess_image(
61 | examples['image'],
62 | height=image_size,
63 | width=image_size,
64 | training=training,
65 | color_jitter_strength=config.color_jitter_strength,
66 | test_crop=test_crop,
67 | train_crop=config.train_crop)
68 | if config.get('set_pixel_range_minus_one_to_one'):
69 | image_ = image_ * 2 - 1
70 | if examples['label'].shape.ndims == 0:
71 | label_ = tf.one_hot(examples['label'],
72 | self.config.dataset.num_classes)
73 | else:
74 | label_ = examples['label']
75 | examples_list.append({'image': image_, 'label': label_})
76 | examples = utils.merge_list_of_dict(examples_list)
77 | return examples
78 |
79 | dataset = dataset.map(_preprocess_single_example,
80 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
81 | return dataset
82 |
83 | def preprocess_batched(self, examples, training):
84 | """Task-specific preprocessing of batched examples on accelerators (TPUs).
85 |
86 | Args:
87 | examples: `dict` of images and labels.
88 | training: bool.
89 |
90 | Returns:
91 | images: `float` of shape (bsz, h, w, c)
92 |       labels: `float` one-hot labels of shape (bsz, num_classes)
93 | """
94 | if training:
95 | return examples['image'], examples['label']
96 | else:
97 | return examples['image'], examples['label'], examples
98 |
99 | def infer(self, model, preprocessed_outputs):
100 | """Perform inference given the model and preprocessed outputs."""
101 | image, _, examples = preprocessed_outputs # response_seq unused by default
102 | outputs = model(image, training=False)
103 | logits = outputs[0]
104 | if hasattr(model, 'encode_decode'):
105 | outputs = model.encode_decode(image, training=False)
106 | images_recon = outputs[0]
107 | elif hasattr(model, 'sample'):
108 | outputs = model.sample(num_samples=tf.shape(image)[0])
109 | images_recon = outputs
110 | else:
111 | images_recon = image
112 | return examples, logits, images_recon
113 |
114 | def postprocess_tpu(self, examples, logits, images_recon, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
115 | training=False):
116 | """Organizing results after fitting the batched examples in graph.
117 |
118 | Such as updating metrics, putting together results for computing metrics in
119 | CPU/numpy mode.
120 |
121 |     Note: the current implementation only supports eval mode, where ground
122 |     truths are provided to the metrics as they are not constructed here.
123 |
124 | Args:
125 | examples: `dict` containing images and labels.
126 | logits: `float` sequence of shape (bsz, seqlen', vocab_size).
127 | images_recon: `float` predicted image tensor of (bsz, h, w, c).
128 | training: `bool` indicating training or inference mode.
129 |
130 | Returns:
131 | results for passing to `postprocess_cpu` which runs in CPU mode.
132 | """
133 | images = examples['image']
134 | labels = tf.argmax(examples['label'], -1)
135 | if isinstance(logits, list):
136 | self._metrics['accuracy'].update_state(labels, logits[-1])
137 | if len(logits) > 1:
138 | assert self.config.model.linear_eval_all_layers
139 | for i, logits_ in enumerate(logits):
140 | self._metrics['accuracy_%d'%i].update_state(labels, logits_)
141 | else:
142 | self._metrics['accuracy'].update_state(labels, logits)
143 | return (images, images_recon)
144 |
145 | def postprocess_cpu(self, outputs, train_step, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
146 | eval_step=None, training=False, summary_tag='eval',
147 | ret_results=False):
148 | """CPU post-processing of outputs.
149 |
150 | Args:
151 | outputs: a tuple of tensor passed from `postprocess_tpu`.
152 | train_step: `int` scalar indicating training step of current model or
153 | the checkpoint.
154 | eval_step: `int` scalar indicating eval step for the given checkpoint.
155 | training: `bool` indicating training or inference mode.
156 | summary_tag: `string` of name scope for result summary.
157 | ret_results: whether to return visualization images.
158 |
159 | Returns:
160 | A dict of visualization images if ret_results, else None.
161 | """
162 | images, images_recon = outputs
163 | if self.config.task.get('image_gen_sum'):
164 | bsz, h, w, c = utils.shape_as_list(images_recon)
165 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32)
166 | b = bsz // a
167 | images_recon = tf.reshape(images_recon, [a, b, h, w, c])
168 | images_recon = tf.transpose(images_recon, [0, 2, 1, 3, 4])
169 | images_sum = tf.reshape(images_recon, [1, a * h, b * w, c])
170 | else:
171 | images_sum = tf.concat([images, images_recon], 2)
172 | if self.config.task.get('set_pixel_range_minus_one_to_one'):
173 | norm = lambda x: (x-tf.reduce_min(x))/(tf.reduce_max(x)-tf.reduce_min(x))
174 | images_sum = norm(images_sum)
175 | if eval_step <= 5:
176 | tf.summary.image(summary_tag + '/gt_pred', images_sum, step=train_step)
177 | if ret_results:
178 | return {'gt': images, 'pred': images_recon}
179 |
180 | def compute_scalar_metrics(self, step):
181 | """Returns a dict containing scalar metrics to log."""
182 | result = {}
183 | for metric in self._metrics.values():
184 | result[metric.name] = metric.result().numpy()
185 | return result
186 |
187 | def reset_metrics(self):
188 | """Reset states of metrics accumulators."""
189 | for metric in self._metrics.values():
190 | metric.reset_states()
191 |
192 |
193 | def preprocess_image(image, height, width, training=False,
194 | color_jitter_strength=0., test_crop=True, train_crop=True):
195 | """Preprocesses the given image.
196 |
197 | Args:
198 | image: `Tensor` representing an image of arbitrary size.
199 | height: Height of output image.
200 | width: Width of output image.
201 | training: `bool` for whether the preprocessing is for training.
202 | color_jitter_strength: `float` between 0 and 1 indicating the color
203 |         distortion strength; color distortion is disabled if not greater than 0.
204 | test_crop: whether or not to extract a central crop of the images
205 | (as for standard ImageNet evaluation) during the evaluation.
206 | train_crop: whether or not to apply random crop during training.
207 |
208 | Returns:
209 | A preprocessed image `Tensor` of range [0, 1].
210 | """
211 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
212 | if image.shape[-1] == 1:
213 | image = tf.image.grayscale_to_rgb(image)
214 | if training:
215 | return simclr_data_util.preprocess_for_train(
216 | image, height, width, color_jitter_strength, train_crop)
217 | else:
218 | return simclr_data_util.preprocess_for_eval(image, height, width, test_crop)
219 |
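# --- Illustrative sketch (not part of the original file) ---
# Hedged example of the accuracy update in `postprocess_tpu` above: one-hot
# labels are converted back to class ids before feeding
# tf.keras.metrics.SparseCategoricalAccuracy. The helper name and toy values
# are made up.
def _toy_accuracy_update_example():
  """Returns 0.5: the first toy prediction is correct, the second is not."""
  labels_onehot = tf.constant([[0., 1., 0.], [1., 0., 0.]])
  logits = tf.constant([[0.1, 0.8, 0.1], [0.2, 0.3, 0.5]])
  acc = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  acc.update_state(tf.argmax(labels_onehot, -1), logits)
  return acc.result()  # == 0.5
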
--------------------------------------------------------------------------------
/tasks/task.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Task base class."""
17 |
18 | import abc
19 | import copy
20 | from absl import logging
21 | import ml_collections
22 | import registry
23 | import utils
24 | from data import transforms
25 | import tensorflow as tf
26 |
27 | TaskRegistry = registry.Registry()
28 |
29 |
30 | class Task(abc.ABC):
31 | """Task class.
32 |
33 | Providing:
34 |   - Preprocessing functions for a specific task that turn raw features into
35 | inputs in a common interface.
36 | - Post-processing functions for a specific task that decode the model's
37 | outputs in a common interface.
38 | - Evaluation for a specific task.
39 | - Important task properties, such as vocab size, max seq len.
40 | """
41 |
42 | def __init__(self,
43 | config: ml_collections.ConfigDict):
44 | self.config = config
45 |
46 | train_transforms = config.task.get('train_transforms', [])
47 | eval_transforms = config.task.get('eval_transforms', [])
48 | self.train_transforms = [
49 | transforms.TransformRegistry.lookup(t.name)(t)
50 | for t in train_transforms]
51 | self.eval_transforms = [
52 | transforms.TransformRegistry.lookup(t.name)(t) for t in eval_transforms]
53 |
54 | @property
55 | def task_vocab_id(self):
56 | return self.config.task.vocab_id
57 |
58 | @abc.abstractmethod
59 | def preprocess_single(self, dataset: tf.data.Dataset, batch_duplicates: int,
60 | training: bool):
61 | """Task-specific preprocessing of individual example in the dataset.
62 |
63 | Args:
64 | dataset: A tf.data.Dataset.
65 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
66 |         (as specified) and concatenating the augmented examples.
67 | training: bool.
68 |
69 | Returns:
70 | A dataset.
71 | """
72 |
73 | def preprocess_single_example(self, example, training, batch_duplicates=1):
74 | """Preprocessing of a single example.
75 |
76 | This should be called in preprocess_single.
77 |
78 | Args:
79 | example: A dict of name to Tensor.
80 | training: bool.
81 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
82 | (as specified) and concating the augmented examples.
83 |
84 | Returns:
85 | A dict of name to Tensor.
86 | """
87 | if training:
88 | example_list = []
89 | for _ in range(batch_duplicates):
90 | example_ = copy.copy(example)
91 | for t in self.train_transforms:
92 | example_ = t.process_example(example_)
93 | example_list.append(example_)
94 | example = utils.merge_list_of_dict(example_list)
95 | else:
96 | for t in self.eval_transforms:
97 | example = t.process_example(example)
98 | return example
99 |
100 | @abc.abstractmethod
101 | def preprocess_batched(self, batched_examples, training):
102 | """Task-specific preprocessing of batched examples on accelerators (TPUs).
103 |
104 | Args:
105 | batched_examples: preprocessed and batched examples.
106 | training: bool.
107 |
108 |     Returns batched inputs in a common interface for modeling.
109 | """
110 |
111 | @abc.abstractmethod
112 | def postprocess_tpu(self):
113 | """Task-specific post processing on accelerators (TPUs).
114 |
115 | This is intended for inference / evaluation time only.
116 |
117 | Returns a list of tensors for `postprocess_cpu` to further process.
118 | """
119 |
120 | @abc.abstractmethod
121 | def postprocess_cpu(self):
122 | """Task-specific post processing on CPUs.
123 |
124 | This is intended for inference / evaluation time only.
125 |
126 |     It receives outputs from `postprocess_tpu`, processes them further, and
127 |     updates internal states (e.g. _metrics).
128 | """
129 |
130 | def _log_metrics(self, metrics_dict, step):
131 | for key, value in metrics_dict.items():
132 | logging.info('Step: [%d] %s = %f', step, key, value)
133 | tf.summary.scalar(key, value, step)
134 |
135 | def evaluate(self, summary_writer, step, eval_tag):
136 | """Evaluate results on accumulated outputs (after multiple infer steps).
137 |
138 | Args:
139 | summary_writer: the summary writer.
140 | step: current step.
141 | eval_tag: `string` name scope for eval result summary.
142 |
143 | Returns:
144 | result as a `dict`.
145 | """
146 | metrics = self.compute_scalar_metrics(step)
147 | with summary_writer.as_default():
148 | with tf.name_scope(eval_tag):
149 | self._log_metrics(metrics, step)
150 | summary_writer.flush()
151 | self.reset_metrics()
152 | return metrics
153 |
154 | @abc.abstractmethod
155 | def compute_scalar_metrics(self, step):
156 | """Returns a dict containing scalar metrics to log."""
157 |
158 | @abc.abstractmethod
159 | def reset_metrics(self):
160 | """Reset states of metrics accumulators."""
161 |
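# --- Illustrative sketch (not part of the original file) ---
# A minimal, hedged example of registering a new task with TaskRegistry. The
# 'toy_task' name and the pass-through behavior are made up; real tasks (e.g.
# tasks/captioning.py, tasks/image_generation.py) implement task-specific
# preprocessing, post-processing and metrics.
@TaskRegistry.register('toy_task')
class _ToyTask(Task):
  """Pass-through task that only tracks a mean eval loss."""

  def __init__(self, config: ml_collections.ConfigDict):
    super().__init__(config)
    self._metrics = {'eval_loss': tf.keras.metrics.Mean('eval_loss')}

  def preprocess_single(self, dataset, batch_duplicates, training):
    # Apply the shared per-example transforms defined on the base class.
    return dataset.map(
        lambda ex: self.preprocess_single_example(ex, training, batch_duplicates),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def preprocess_batched(self, batched_examples, training):
    return batched_examples

  def postprocess_tpu(self, *outputs, training=False):
    return outputs

  def postprocess_cpu(self, outputs, train_step, eval_step=None, training=False,
                      summary_tag='eval', ret_results=False):
    del outputs, train_step, eval_step, training, summary_tag, ret_results

  def compute_scalar_metrics(self, step):
    return {name: m.result().numpy() for name, m in self._metrics.items()}

  def reset_metrics(self):
    for m in self._metrics.values():
      m.reset_states()
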
--------------------------------------------------------------------------------
/tasks/video_generation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Task for video generation."""
17 | from absl import logging
18 | import ml_collections
19 | import utils
20 | from data import data_utils
21 | from metrics import fvd
22 | from tasks import task as task_lib
23 | import tensorflow as tf
24 |
25 |
26 | @task_lib.TaskRegistry.register('video_generation')
27 | class TaskVideoGeneration(task_lib.Task): # pytype: disable=base-class-error
28 | """Task for video generation."""
29 |
30 | def __init__(self, config: ml_collections.ConfigDict):
31 | super().__init__(config)
32 | self._metrics = {}
33 | self._tfgan_evaluator = fvd.FVDMetricEvaluator(
34 | dataset_name=config.dataset.tfds_name,
35 | image_size=config.dataset.image_size)
36 |
37 | def preprocess_single(self, dataset, batch_duplicates, training):
38 | """Task-specific preprocessing of individual example in the dataset.
39 |
40 | Args:
41 | dataset: A tf.data.Dataset.
42 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
43 |         (as specified) and concatenating the augmented examples.
44 | training: bool.
45 |
46 | Returns:
47 | A dataset.
48 | """
49 | image_size = self.config.dataset.image_size
50 |
51 | def _preprocess_single_example(features, labels):
52 | features_list, labels_list = [], []
53 | for _ in range(batch_duplicates if training else 1):
54 | video = preprocess_video(
55 | features['video'],
56 | height=image_size,
57 | width=image_size,
58 | seq_len=self.config.dataset.get('seq_len'),
59 | cropping=self.config.dataset.cropping,
60 | flipping=self.config.dataset.flipping,
61 | training=training)
62 | label = tf.one_hot(labels['label'], self.config.dataset.num_classes)
63 | features_list.append({'video': video})
64 | labels_list.append({'label': label})
65 | features = utils.merge_list_of_dict(features_list)
66 | labels = utils.merge_list_of_dict(labels_list)
67 | return features, labels
68 |
69 | dataset = dataset.map(_preprocess_single_example,
70 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
71 | return dataset
72 |
73 | def preprocess_batched(self, batched_examples, training):
74 | """Task-specific preprocessing of batched examples on accelerators (TPUs).
75 |
76 | Args:
77 | batched_examples: tuples of feature and label tensors that are
78 | preprocessed, batched, and stored with `dict`.
79 | training: bool.
80 |
81 | Returns:
82 | videos: `float` of shape (bsz, t, h, w, c)
83 |       labels: `float` one-hot labels of shape (bsz, num_classes)
84 | """
85 | features, labels = batched_examples
86 | if training:
87 | return features['video'], labels['label']
88 | else:
89 | return features['video'], labels['label'], batched_examples
90 |
91 | def infer(self, model, preprocessed_outputs, **kwargs):
92 | """Perform inference given the model and preprocessed outputs."""
93 | videos, labels, examples = preprocessed_outputs
94 | samples = model.sample(
95 | num_samples=tf.shape(videos)[0],
96 | iterations=self.config.model.infer_iterations,
97 | method=self.config.model.sampler_name,
98 | images=videos,
99 | labels=labels,
100 | **kwargs)
101 | return examples, samples
102 |
103 | def postprocess_tpu(self,
104 | batched_examples,
105 | samples,
106 | training=False):
107 | """Organizing results after fitting the batched examples in graph.
108 |
109 | Such as updating metrics, putting together results for computing metrics in
110 | CPU/numpy mode.
111 |
112 |     Note: the current implementation only supports eval mode, where ground
113 |     truths are provided to the metrics as they are not constructed here.
114 |
115 | Args:
116 |       batched_examples: a tuple of features (`dict`) and labels (`dict`),
117 | containing videos and labels.
118 | samples: `float` predicted video tensor of (bsz, t, h, w, c).
119 | training: `bool` indicating training or inference mode.
120 |
121 | Returns:
122 | results for passing to `postprocess_cpu` which runs in CPU mode.
123 | """
124 | logging.info('Start postprocess_tpu.')
125 | features, labels = batched_examples
126 | videos = features['video']
127 | labels = tf.argmax(labels['label'], -1)
128 |
129 |     # FVD
130 | data_real, data_gen = self._tfgan_evaluator.preprocess_inputs(
131 | [videos, samples], is_n1p1=False)
132 | (logits_real, pool3_real), (logits_gen, pool3_gen) = (
133 | self._tfgan_evaluator.get_inception_stats([data_real, data_gen]))
134 |
135 | logging.info('postprocess_tpu done.')
136 | return (videos, samples, logits_real, pool3_real, logits_gen, pool3_gen)
137 |
138 | def postprocess_cpu(self,
139 | outputs,
140 | train_step,
141 | eval_step=None,
142 | training=False,
143 | summary_tag='eval',
144 | ret_results=False):
145 | """CPU post-processing of outputs.
146 |
147 | Args:
148 | outputs: a tuple of tensor passed from `postprocess_tpu`.
149 | train_step: `int` scalar indicating training step of current model or the
150 | checkpoint.
151 | eval_step: `int` scalar indicating eval step for the given checkpoint.
152 | training: `bool` indicating training or inference mode.
153 | summary_tag: `string` of name scope for result summary.
154 | ret_results: whether to return visualization videos.
155 |
156 | Returns:
157 | A dict of visualization videos if ret_results, else None.
158 | """
159 | logging.info('Start postprocess_cpu')
160 | videos, samples, logits_real, pool3_real, logits_gen, pool3_gen = outputs
161 |
162 |     # FVD update.
163 | self._tfgan_evaluator.update_stats(
164 | logits_real, pool3_real, logits_gen, pool3_gen)
165 |
166 | # videos summary.
167 | bsz, t, h, w, c = utils.shape_as_list(samples)
168 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32)
169 | b = a
170 | vis_samples = samples[:a * a, ...]
171 | vis_samples = tf.reshape(vis_samples, [a, b, t, h, w, c])
172 | vis_samples = tf.transpose(vis_samples, [0, 3, 1, 2, 4, 5])
173 | videos_sum = tf.reshape(vis_samples, [1, a * h, b * t * w, c])
174 | if eval_step < 2:
175 | tf.summary.image(
176 | f'{summary_tag}/samples_{eval_step}', videos_sum, step=train_step)
177 |
178 | logging.info('postprocess_cpu done.')
179 | if ret_results:
180 | return {'gt': videos, 'pred': samples}
181 |
182 | def compute_scalar_metrics(self, step):
183 | """Returns a dict containing scalar metrics to log."""
184 | result = {}
185 | for metric in self._metrics.values():
186 | result[metric.name] = metric.result().numpy()
187 |
188 |     # FVD
189 | result.update(self._tfgan_evaluator.compute_fid_score())
190 | return result
191 |
192 | def reset_metrics(self):
193 | """Reset states of metrics accumulators."""
194 | for metric in self._metrics.values():
195 | metric.reset_states()
196 |
197 | self._tfgan_evaluator.reset()
198 |
199 |
200 | def preprocess_video(video,
201 | height,
202 | width,
203 | seq_len=None,
204 | cropping='none',
205 | flipping='none',
206 | training=False):
207 | """Preprocesses the given video.
208 |
209 | Args:
210 |     video: `Tensor` representing a video of arbitrary size (T x H x W x C).
211 | height: Height of output video.
212 | width: Width of output video.
213 | seq_len: Length of sequence to crop.
214 | cropping: which cropping to apply to the video.
215 | flipping: which flipping to apply to the video.
216 | training: `bool` for whether the preprocessing is for training.
217 |
218 | Returns:
219 | A preprocessed video `Tensor` of range [0, 1].
220 | """
221 | seq_len = seq_len if seq_len is not None else video.shape[0]
222 | video = tf.image.convert_image_dtype(video, dtype=tf.float32)
223 | if cropping == 'center':
224 | video = data_utils.largest_center_square(video)
225 | video = tf.image.resize(
226 | video,
227 | size=(height, width),
228 | method='area',
229 | preserve_aspect_ratio=False,
230 | antialias=True)
231 | elif cropping == 'random':
232 | video = data_utils.crop_video(
233 | frames=video,
234 | height=height,
235 | width=width,
236 | seq_len=seq_len,
237 | random=True)
238 | elif cropping == 'random_resize':
239 | video = tf.image.resize(
240 | video,
241 | size=(int(height*1.25), int(width*1.25)), # TODO ajabri: make general
242 | method='area',
243 | preserve_aspect_ratio=False,
244 | antialias=True)
245 | video = data_utils.crop_video(
246 | frames=video,
247 | height=height,
248 | width=width,
249 | seq_len=seq_len,
250 | random=True)
251 | elif cropping != 'none':
252 | raise ValueError(f'Unknown cropping method {cropping}')
253 | if training:
254 | if flipping == 'left_right':
255 | video = tf.image.random_flip_left_right(video)
256 | elif flipping != 'none':
257 | raise ValueError(f'Unknown flipping method {flipping}')
258 | video = tf.reshape(video, [seq_len, height, width, 3]) # let arch know shape
259 |
260 | return video
261 |
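# --- Illustrative usage sketch (not part of the original file) ---
# Hedged example of calling `preprocess_video` above. The helper name, the
# 8-frame 64x64 dummy clip and cropping='none' are arbitrary choices that avoid
# the data_utils cropping helpers.
def _toy_preprocess_video_example():
  """Returns a preprocessed clip of shape (8, 64, 64, 3), values in [0, 1]."""
  dummy = tf.zeros([8, 64, 64, 3], tf.uint8)
  return preprocess_video(dummy, height=64, width=64, seq_len=8,
                          cropping='none', flipping='left_right', training=True)
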
--------------------------------------------------------------------------------
/tasks/visualization/static_shape.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # Source: https://github.com/google/automl/tree/master/efficientdet/visualize/static_shape.py
17 | # Copyright 2020 Google Research. All Rights Reserved.
18 | #
19 | # Licensed under the Apache License, Version 2.0 (the "License");
20 | # you may not use this file except in compliance with the License.
21 | # You may obtain a copy of the License at
22 | #
23 | # http://www.apache.org/licenses/LICENSE-2.0
24 | #
25 | # Unless required by applicable law or agreed to in writing, software
26 | # distributed under the License is distributed on an "AS IS" BASIS,
27 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28 | # See the License for the specific language governing permissions and
29 | # limitations under the License.
30 | # ==============================================================================
31 | """Helper functions to access TensorShape values.
32 |
33 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
34 | """
35 |
36 |
37 | def get_dim_as_int(dim):
38 | """Utility to get v1 or v2 TensorShape dim as an int.
39 |
40 | Args:
41 | dim: The TensorShape dimension to get as an int
42 |
43 | Returns:
44 | None or an int.
45 | """
46 | try:
47 | return dim.value
48 | except AttributeError:
49 | return dim
50 |
51 |
52 | def get_batch_size(tensor_shape):
53 | """Returns batch size from the tensor shape.
54 |
55 | Args:
56 | tensor_shape: A rank 4 TensorShape.
57 |
58 | Returns:
59 | An integer representing the batch size of the tensor.
60 | """
61 | tensor_shape.assert_has_rank(rank=4)
62 | return get_dim_as_int(tensor_shape[0])
63 |
64 |
65 | def get_height(tensor_shape):
66 | """Returns height from the tensor shape.
67 |
68 | Args:
69 | tensor_shape: A rank 4 TensorShape.
70 |
71 | Returns:
72 | An integer representing the height of the tensor.
73 | """
74 | tensor_shape.assert_has_rank(rank=4)
75 | return get_dim_as_int(tensor_shape[1])
76 |
77 |
78 | def get_width(tensor_shape):
79 | """Returns width from the tensor shape.
80 |
81 | Args:
82 | tensor_shape: A rank 4 TensorShape.
83 |
84 | Returns:
85 | An integer representing the width of the tensor.
86 | """
87 | tensor_shape.assert_has_rank(rank=4)
88 | return get_dim_as_int(tensor_shape[2])
89 |
90 |
91 | def get_depth(tensor_shape):
92 | """Returns depth from the tensor shape.
93 |
94 | Args:
95 | tensor_shape: A rank 4 TensorShape.
96 |
97 | Returns:
98 | An integer representing the depth of the tensor.
99 | """
100 | tensor_shape.assert_has_rank(rank=4)
101 | return get_dim_as_int(tensor_shape[3])
102 |
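# --- Illustrative usage sketch (not part of the original file) ---
# Hedged example of reading static dimensions off a rank-4 shape; the helper
# name and concrete sizes are made up. TensorFlow is imported locally since
# this module has no other imports.
def _toy_static_shape_example():
  """Returns (batch, height, width, depth) == (8, 224, 224, 3)."""
  import tensorflow as tf  # Local import to keep the sketch self-contained.
  shape = tf.TensorShape([8, 224, 224, 3])
  return (get_batch_size(shape), get_height(shape),
          get_width(shape), get_depth(shape))
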
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Vocab."""
17 |
18 | # A shared vocab among tasks and its structure -
19 | # Special tokens: [0, 99).
20 | # Class tokens: [100, coord_vocab_shift). Total coord_vocab_shift - 100 classes.
21 | # Coordinate tokens: [coord_vocab_shift, text_vocab_shift).
22 | # Text tokens: [text_vocab_shift, ...].
23 |
24 | PADDING_TOKEN = 0
25 |
26 | # 10-29 reserved for task id.
27 |
28 | FAKE_CLASS_TOKEN = 30
29 | FAKE_TEXT_TOKEN = 30 # Same token to represent fake class and fake text.
30 | SEPARATOR_TOKEN = 40
31 | INVISIBLE_TOKEN = 41
32 |
33 | BASE_VOCAB_SHIFT = 100
34 |
35 | # Floats used to represent padding and separator in the flat list of polygon
36 | # coords, and invisibility in the key points.
37 | PADDING_FLOAT = -1.
38 | SEPARATOR_FLOAT = -2.
39 | INVISIBLE_FLOAT = -3.
40 | FLOATS = [PADDING_FLOAT, SEPARATOR_FLOAT, INVISIBLE_FLOAT]
41 | TOKENS = [PADDING_TOKEN, SEPARATOR_TOKEN, INVISIBLE_TOKEN]
42 | FLOAT_TO_TOKEN = dict(zip(FLOATS, TOKENS))
43 | TOKEN_TO_FLOAT = dict(zip(TOKENS, FLOATS))
44 |
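# --- Illustrative sketch (not part of the original file) ---
# Hedged check of how the float sentinels used in flat coordinate lists map to
# discrete tokens and back; all values come from the constants defined above.
assert FLOAT_TO_TOKEN[PADDING_FLOAT] == PADDING_TOKEN      # -1. -> 0
assert FLOAT_TO_TOKEN[SEPARATOR_FLOAT] == SEPARATOR_TOKEN  # -2. -> 40
assert TOKEN_TO_FLOAT[INVISIBLE_TOKEN] == INVISIBLE_FLOAT  # 41  -> -3.
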
--------------------------------------------------------------------------------