├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── architectures
│   ├── convnet_blocks.py
│   ├── resnet.py
│   ├── tape.py
│   ├── transformers.py
│   └── transunet.py
├── colabs
│   ├── pix2seq_finetuning_object_detection.ipynb
│   ├── pix2seq_inference_multitask.ipynb
│   └── pix2seq_inference_object_detection.ipynb
├── configs
│   ├── config_base.py
│   ├── config_det_finetune.py
│   ├── config_diffusion_base.py
│   ├── config_diffusion_cifar10.py
│   ├── config_diffusion_cifar10d.py
│   ├── config_diffusion_imagenet64.py
│   ├── config_diffusion_panoptic_base.py
│   ├── config_diffusion_panoptic_image.py
│   ├── config_diffusion_panoptic_video.py
│   ├── config_multi_task.py
│   ├── dataset_configs.py
│   └── transform_configs.py
├── data
│   ├── cityscapes.py
│   ├── coco.py
│   ├── data_utils.py
│   ├── dataset.py
│   ├── datasets.py
│   ├── decode_utils.py
│   ├── obj365.py
│   ├── recognition.py
│   ├── scripts
│   │   ├── create_coco_tfrecord.py
│   │   ├── create_davis_tfrecord.py
│   │   ├── create_kittistep_tfrecord.py
│   │   ├── merge_coco_json_tfrecord.py
│   │   └── tfrecord_lib.py
│   ├── text.py
│   ├── tokenizer.py
│   ├── transforms.py
│   └── video.py
├── metrics
│   ├── coco_metrics.py
│   ├── fid.py
│   ├── fvd.py
│   ├── metric_registry.py
│   ├── metric_utils.py
│   ├── segmentation_and_tracking_quality.py
│   ├── text_metrics.py
│   ├── vos_metrics.py
│   └── vps_metrics.py
├── models
│   ├── ar_model.py
│   ├── diffusion_utils.py
│   ├── image_ar_model.py
│   ├── image_diffusion_model.py
│   ├── image_discrete_diffusion_model.py
│   ├── model.py
│   ├── model_utils.py
│   ├── panoptic_diffusion.py
│   └── video_diffusion_model.py
├── pix2seq.gif
├── pix2seq.png
├── registry.py
├── requirements.txt
├── run.py
├── tasks
│   ├── captioning.py
│   ├── image_generation.py
│   ├── instance_segmentation.py
│   ├── keypoint_detection.py
│   ├── object_detection.py
│   ├── panoptic_segmentation.py
│   ├── recognition.py
│   ├── task.py
│   ├── task_utils.py
│   ├── video_generation.py
│   ├── video_panoptic_segmentation.py
│   └── visualization
│       ├── shape_utils.py
│       ├── standard_fields.py
│       ├── static_shape.py
│       └── vis_utils.py
├── utils.py
└── vocab.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | Pix2Seq needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
6 | to re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 | 
9 | ## Contributor License Agreement
10 | 
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 | 
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 | 
21 | ## Code reviews
22 | 
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 | 
28 | ## Community Guidelines
29 | 
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
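As a rough usage sketch (not part of the repo): the task and diffusion config modules below (`config_det_finetune.py`, `config_diffusion_*.py`, ...) each expose a `get_config(config_str=None)` that returns an `ml_collections.ConfigDict` built from the helpers in `config_base.py`. The snippet assumes the standard `ml_collections` config-flags pattern; the flag name `config` and this toy launcher are illustrative, and the repo's own `run.py` may define its flags differently.

```python
# Minimal sketch (assumption: standard ml_collections config-flags usage;
# not the actual run.py of this repo).
from absl import app
from absl import flags
from ml_collections import config_flags

config_flags.DEFINE_config_file(
    'config', 'configs/config_det_finetune.py',
    'Path to a config module exposing get_config().')
FLAGS = flags.FLAGS


def main(_):
  config = FLAGS.config           # ConfigDict returned by get_config().
  config.train.batch_size = 16    # Fields can be overridden in code, or via CLI:
  # --config=configs/config_det_finetune.py --config.train.batch_size=16
  print(config.model.name, config.task.name)


if __name__ == '__main__':
  app.run(main)
```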
32 | -------------------------------------------------------------------------------- /configs/config_base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Common / shared settings among multiple configs.""" 17 | 18 | import ml_collections 19 | 20 | 21 | def D(**kwargs): 22 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 23 | 24 | 25 | architecture_config_map = { 26 | 'vit-b': D( 27 | resnet_variant='c1', 28 | num_encoder_layers=12, 29 | dim_att=768, 30 | dim_mlp=3072, 31 | num_heads=12, 32 | num_decoder_layers=6, 33 | dim_att_dec=512, 34 | dim_mlp_dec=2048, 35 | num_heads_dec=16, 36 | ), 37 | 'vit-l': D( 38 | resnet_variant='c1', 39 | num_encoder_layers=24, 40 | dim_att=1024, 41 | dim_mlp=4096, 42 | num_heads=16, 43 | num_decoder_layers=8, 44 | dim_att_dec=512, 45 | dim_mlp_dec=2048, 46 | num_heads_dec=16, 47 | ), 48 | 'resnet': D( 49 | resnet_variant='standard', 50 | resnet_depth=50, 51 | resnet_sk_ratio=0., 52 | resnet_width_multiplier=1, 53 | num_encoder_layers=6, 54 | dim_att=256, 55 | dim_mlp=1024, 56 | num_heads=8, 57 | num_decoder_layers=6, 58 | dim_att_dec=256, 59 | dim_mlp_dec=1024, 60 | num_heads_dec=8, 61 | ), 62 | 'resnet-c': D( 63 | resnet_variant='c4', 64 | resnet_depth=50, 65 | resnet_sk_ratio=0., 66 | resnet_width_multiplier=1, 67 | num_encoder_layers=12, 68 | dim_att=512, 69 | dim_mlp=2048, 70 | num_heads=16, 71 | num_decoder_layers=8, 72 | dim_att_dec=512, 73 | dim_mlp_dec=2048, 74 | num_heads_dec=16, 75 | ), 76 | } 77 | -------------------------------------------------------------------------------- /configs/config_det_finetune.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Config file for object detection fine-tuning and evaluation.""" 17 | 18 | import copy 19 | 20 | from configs import dataset_configs 21 | from configs import transform_configs 22 | from configs.config_base import architecture_config_map 23 | from configs.config_base import D 24 | 25 | # pylint: disable=invalid-name,line-too-long,missing-docstring 26 | 27 | 28 | def get_config(config_str=None): 29 | """config_str is either empty or contains task,architecture variants.""" 30 | 31 | task_variant = 'object_detection@coco/2017_object_detection' 32 | encoder_variant = 'vit-b' # Set model architecture. 33 | image_size = (640, 640) # Set image size. 34 | 35 | 36 | tasks_and_datasets = [] 37 | for task_and_ds in task_variant.split('+'): 38 | tasks_and_datasets.append(task_and_ds.split('@')) 39 | 40 | max_instances_per_image = 100 41 | max_instances_per_image_test = 100 42 | 43 | task_config_map = { 44 | 'object_detection': D( 45 | name='object_detection', 46 | vocab_id=10, 47 | image_size=image_size, 48 | quantization_bins=1000, 49 | max_instances_per_image=max_instances_per_image, 50 | max_instances_per_image_test=max_instances_per_image_test, 51 | train_transforms=transform_configs.get_object_detection_train_transforms( 52 | image_size, max_instances_per_image), 53 | eval_transforms=transform_configs.get_object_detection_eval_transforms( 54 | image_size, max_instances_per_image_test), 55 | # Train on both ground-truth and (augmented) noisy objects. 56 | noise_bbox_weight=1.0, 57 | eos_token_weight=0.1, 58 | # Train on just ground-truth objects (with an ending token). 59 | # noise_bbox_weight=0.0, 60 | # eos_token_weight=0.1, 61 | class_label_corruption='rand_n_fake_cls', 62 | top_k=0, 63 | top_p=0.4, 64 | temperature=1.0, 65 | weight=1.0, 66 | metric=D(name='coco_object_detection',), 67 | ), 68 | } 69 | 70 | task_d_list = [] 71 | dataset_list = [] 72 | for tv, ds_name in tasks_and_datasets: 73 | task_d_list.append(task_config_map[tv]) 74 | dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name]) 75 | dataset_list.append(dataset_config) 76 | 77 | config = D( 78 | dataset=dataset_list[0], 79 | datasets=dataset_list, 80 | 81 | task=task_d_list[0], 82 | tasks=task_d_list, 83 | 84 | model=D( 85 | name='encoder_ar_decoder', 86 | image_size=image_size, 87 | max_seq_len=512, 88 | vocab_size=3000, # Note: should be large enough for 100 + num_classes + quantization_bins + (optional) text 89 | coord_vocab_shift=1000, # Note: make sure num_class <= coord_vocab_shift - 100 90 | text_vocab_shift=3000, # Note: make sure coord_vocab_shift + quantization_bins <= text_vocab_shift 91 | use_cls_token=False, 92 | shared_decoder_embedding=True, 93 | decoder_output_bias=True, 94 | patch_size=16, 95 | drop_path=0.1, 96 | drop_units=0.1, 97 | drop_att=0.0, 98 | dec_proj_mode='mlp', 99 | pos_encoding='sin_cos', 100 | pos_encoding_dec='learned', 101 | pretrained_ckpt=get_obj365_pretrained_checkpoint(encoder_variant), 102 | ), 103 | 104 | optimization=D( 105 | optimizer='adamw', 106 | learning_rate=3e-5, 107 | end_lr_factor=0.01, 108 | warmup_epochs=2, 109 | warmup_steps=0, # set to >0 to override warmup_epochs. 110 | weight_decay=0.05, 111 | global_clipnorm=-1, 112 | beta1=0.9, 113 | beta2=0.95, 114 | eps=1e-8, 115 | learning_rate_schedule='linear', 116 | learning_rate_scaling='none', 117 | ), 118 | 119 | train=D( 120 | batch_size=32, 121 | epochs=40, 122 | steps=0, # set to >0 to override epochs. 
123 | checkpoint_epochs=1, 124 | checkpoint_steps=0, # set to >0 to override checkpoint_epochs. 125 | keep_checkpoint_max=5, 126 | loss_type='xent', 127 | ), 128 | 129 | eval=D( 130 | tag='eval', 131 | checkpoint_dir='', # checkpoint_dir will be model_dir if not set. 132 | # checkpoint_dir=get_coco_finetuned_checkpoint(encoder_variant, image_size[0]), 133 | batch_size=8, # needs to be divisible by total eval examples. 134 | steps=0, # 0 means eval over full validation set. 135 | ), 136 | ) 137 | 138 | # Update model with architecture variant. 139 | for key, value in architecture_config_map[encoder_variant].items(): 140 | config.model[key] = value 141 | 142 | return config 143 | 144 | CKPT_PREFIX = 'gs://pix2seq' 145 | 146 | 147 | def get_obj365_pretrained_checkpoint(encoder_variant): 148 | if encoder_variant == 'resnet': 149 | return f'{CKPT_PREFIX}/obj365_pretrain/resnet_640x640_b256_s400k' 150 | elif encoder_variant == 'resnet-c': 151 | return f'{CKPT_PREFIX}/obj365_pretrain/resnetc_640x640_b256_s400k' 152 | elif encoder_variant == 'vit-b': 153 | return f'{CKPT_PREFIX}/obj365_pretrain/vit_b_640x640_b256_s400k' 154 | elif encoder_variant == 'vit-l': 155 | return f'{CKPT_PREFIX}/obj365_pretrain/vit_l_640x640_b256_s400k' 156 | else: 157 | raise ValueError('Unknown encoder_variant {}'.format(encoder_variant)) 158 | 159 | 160 | def get_coco_finetuned_checkpoint(encoder_variant, image_size): 161 | if encoder_variant == 'resnet': 162 | return f'{CKPT_PREFIX}/coco_det_finetune/resnet_{image_size}x{image_size}' 163 | elif encoder_variant == 'resnet-c': 164 | return f'{CKPT_PREFIX}/coco_det_finetune/resnetc_{image_size}x{image_size}' 165 | elif encoder_variant == 'vit-b': 166 | return f'{CKPT_PREFIX}/coco_det_finetune/vit_b_{image_size}x{image_size}' 167 | elif encoder_variant == 'vit-l': 168 | return f'{CKPT_PREFIX}/coco_det_finetune/vit_l_{image_size}x{image_size}' 169 | else: 170 | raise ValueError('Unknown encoder_variant {}'.format(encoder_variant)) 171 | -------------------------------------------------------------------------------- /configs/config_diffusion_base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """A config.""" 17 | 18 | # pylint: disable=invalid-name,line-too-long 19 | 20 | from configs.config_base import architecture_config_map 21 | from configs.config_base import D 22 | 23 | 24 | def get_config(config_str=None): 25 | """config_str can be none or something that specifies meta-hyperparam.""" 26 | if config_str: 27 | data_name, arch_variant = config_str.split(',') 28 | else: 29 | data_name = 'cifar10' 30 | arch_variant = 'transunet' 31 | arch_variant = 'tape' 32 | 33 | architecture_config_map.update({ 34 | 'transunet': D( 35 | arch_name='transunet', 36 | resnet_variant='c1', 37 | dim=128, 38 | in_strides=1, 39 | in_kernel_size=3, 40 | out_kernel_size=3, 41 | kernel_sizes='3', 42 | n_mlp_blocks=0, 43 | udrop=0.0, 44 | n_res_blocks='3,3,3', 45 | ch_multipliers='1,2,2', 46 | mhsa_resolutions='16,8', 47 | per_head_dim=64, 48 | transformer_dim=128, 49 | transformer_strides=2, 50 | transformer_blocks=0, 51 | conditioning=False, 52 | u_pos_encoding='sin_cos', 53 | outp_softmax_groups=0, 54 | ), 55 | 'tape': D( 56 | arch_name='tape', 57 | num_layers='4,4,4,4', 58 | latent_slots=32, 59 | latent_dim=256, 60 | latent_mlp_ratio=4, 61 | latent_num_heads=4, 62 | tape_dim=128, 63 | tape_mlp_ratio=2, 64 | rw_num_heads=1, 65 | conv_kernel_size=0, 66 | conv_drop_units=0., 67 | drop_path=0., 68 | drop_units=0., 69 | drop_att=0., 70 | drop_sc=0., 71 | time_on_latent=False, 72 | cond_on_latent=False, 73 | cond_tape_writable=False, 74 | latent_pos_encoding='learned', 75 | tape_pos_encoding='learned', 76 | xattn_enc_ln=False, 77 | cond_dim=0, 78 | cond_proj=True, 79 | cond_decoupled_read=False, 80 | ), 81 | }) 82 | 83 | dataset_config_map = { 84 | 'mnist': D( 85 | name='object_recognition', 86 | tfds_name='mnist', 87 | train_split='train', 88 | eval_split='test', 89 | num_classes=10, 90 | image_size=28, 91 | batch_duplicates=1, 92 | cache_dataset=True, 93 | cropping='none', 94 | flipping='none', 95 | ), 96 | 'cifar10': D( 97 | name='object_recognition', 98 | tfds_name='cifar10', 99 | train_split='train', 100 | eval_split='test', 101 | num_classes=10, 102 | image_size=32, 103 | batch_duplicates=1, 104 | cache_dataset=True, 105 | cropping='none', 106 | flipping='left_right', 107 | ), 108 | 'imagenet2012': D( 109 | name='object_recognition', 110 | tfds_name='imagenet2012', 111 | train_split='train', 112 | eval_split='validation', 113 | num_classes=1000, 114 | image_size=64, 115 | batch_duplicates=1, 116 | cache_dataset=True, 117 | cropping='center', 118 | flipping='left_right', 119 | ), 120 | 'ucf101': D( 121 | name='tfds_video', 122 | tfds_name='ucf101', 123 | train_split=['train', 'test'], 124 | eval_split=['train', 'test'], 125 | num_classes=101, 126 | image_size=64, 127 | batch_duplicates=1, 128 | cache_dataset=False, 129 | cropping='center', 130 | flipping='none', 131 | seq_len=16, 132 | ), 133 | 'kinetics600': D( 134 | name='kinetics600', 135 | tfds_name='kinetics600', # for eval config 136 | train_split='train', 137 | eval_split='test', 138 | num_classes=600, 139 | image_size=64, 140 | batch_duplicates=1, 141 | cache_dataset=False, 142 | cropping='none', 143 | flipping='none', 144 | seq_len=16, 145 | ), 146 | } 147 | 148 | task = D( 149 | name='image_generation', 150 | weight=1., 151 | ) 152 | task_d_list = [task] 153 | dataset_list = [dataset_config_map[data_name]] 154 | 155 | config = D( 156 | dataset=dataset_list[0], 157 | datasets=dataset_list, 158 | 159 | task=task_d_list[0], 160 | tasks=task_d_list, 161 
| 162 | model=D( 163 | name='image_diffusion_model', 164 | train_schedule='cosine', 165 | infer_schedule='cosine', 166 | pred_type='eps', 167 | loss_type='eps', 168 | infer_iterations=100, 169 | td=0., 170 | x0_clip='auto', 171 | b_scale=1.0, 172 | normalize_noisy_input=False, 173 | time_scaling=1000, 174 | pretrained_ckpt='', 175 | sampler_name='ddpm', 176 | conditional='class', 177 | self_cond='latent', 178 | b_type='uint8', 179 | flip_rate=0., 180 | self_cond_rate=0.5, 181 | self_cond_by_masking=False, 182 | cond_dropout=0., 183 | guidance=0., 184 | 185 | # architecture extra 186 | use_cls_token=False, 187 | pos_encoding='sin_cos', 188 | patch_size=8, 189 | drop_path=0., 190 | drop_units=0., 191 | drop_att=0., 192 | ), 193 | 194 | optimization=D( 195 | optimizer='lamb', 196 | learning_rate=1e-4, 197 | warmup_epochs=0, 198 | warmup_steps=5000, 199 | tail_steps=0, 200 | weight_decay=0., 201 | global_clipnorm=1.0, 202 | momentum=0.9, 203 | beta1=0.9, 204 | beta2=0.999, 205 | eps=1e-8, 206 | learning_rate_schedule='none', 207 | learning_rate_scaling='none', 208 | end_lr_factor=0.0, 209 | ema_decay=0.9999, 210 | ema_name_exact_match=True, 211 | exclude_from_weight_decay='bias,beta,gamma', 212 | ), 213 | 214 | train=D( 215 | batch_size=512, 216 | epochs=100, 217 | steps=0, 218 | checkpoint_epochs=40, 219 | checkpoint_steps=0, 220 | keep_checkpoint_max=20, 221 | label_smoothing=0., 222 | ), 223 | 224 | eval=D( 225 | tag='eval', 226 | checkpoint_dir='', 227 | batch_size=64, 228 | steps=100, # this is an approximation. 229 | ), 230 | ) 231 | 232 | # Update model with architecture variant. 233 | for key, value in architecture_config_map[arch_variant].items(): 234 | config.model[key] = value 235 | 236 | return config 237 | -------------------------------------------------------------------------------- /configs/config_diffusion_cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | """A config."""
17 | 
18 | # pylint: disable=invalid-name,line-too-long
19 | 
20 | from configs import config_diffusion_base as config_base
21 | 
22 | DATA_NAME = 'cifar10'
23 | ARCH_VARIANT = 'tape'
24 | # ARCH_VARIANT = 'transunet'
25 | 
26 | 
27 | def get_config(config_str=None):
28 |   """Returns config."""
29 |   del config_str
30 |   config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
31 |   config.model.name = 'image_diffusion_model'
32 |   config.model.b_scale = 1.0
33 |   config.model.pred_type = 'eps'
34 |   config.model.conditional = 'class'
35 |   config.optimization.ema_decay = 0.9999
36 |   config.eval.batch_size = 80
37 |   config.eval.steps = 625
38 |   return config
39 | 
40 | 
41 | def get_sweep(h):
42 |   """Get the hyperparameter sweep."""
43 |   if ARCH_VARIANT == 'transunet':
44 |     return h.chainit([
45 |         h.product([
46 |             h.sweep('config.train.steps', [1_000_000]),
47 |             h.sweep('config.train.checkpoint_steps', [10000]),
48 |             h.sweep('config.train.keep_checkpoint_max', [100]),
49 |             h.sweep('config.train.batch_size', [128*1]),
50 |             h.sweep('config.optimization.learning_rate', [1e-4]),
51 |             h.sweep('config.optimization.warmup_steps', [10000]),
52 |             h.sweep('config.model.self_cond', ['none', 'x', 'eps']),
53 |             h.sweep('config.model.udrop', [0.3]),
54 |             h.sweep('config.model.dim', [256]),
55 |             h.sweep('config.model.n_res_blocks', ['3,3,3']),
56 |             h.sweep('config.model.ch_multipliers', ['1,1,1']),
57 |         ]),
58 |     ])
59 |   elif ARCH_VARIANT == 'tape':
60 |     return h.chainit([
61 |         h.product([
62 |             h.sweep('config.train.steps', [150_000]),
63 |             h.sweep('config.train.checkpoint_steps', [10000]),
64 |             h.sweep('config.train.keep_checkpoint_max', [10]),
65 |             h.sweep('config.train.batch_size', [256]),
66 |             h.sweep('config.optimization.optimizer', ['lamb']),
67 |             h.sweep('config.optimization.learning_rate', [3e-3]),
68 |             h.sweep('config.optimization.learning_rate_schedule', ['cosine@0.8']),
69 |             h.sweep('config.optimization.end_lr_factor', [0.0]),
70 |             h.sweep('config.optimization.warmup_steps', [10000]),
71 |             h.sweep('config.optimization.beta2', [0.999]),
72 |             h.sweep('config.optimization.weight_decay', [1e-2]),
73 |             h.sweep('config.model.train_schedule', ['sigmoid@-3,3,0.9',
74 |                                                     'simple_linear']),
75 |             h.sweep('config.model.self_cond', ['latent']),
76 |             h.sweep('config.model.self_cond_by_masking', [True]),
77 |             h.sweep('config.model.self_cond_rate', [0.9]),
78 | 
79 |             h.sweep('config.model.patch_size', [2]),
80 |             h.sweep('config.model.time_on_latent', [True]),
81 |             h.sweep('config.model.cond_on_latent', [True]),
82 |             h.sweep('config.model.cond_tape_writable', [False]),
83 |             h.sweep('config.model.latent_pos_encoding', ['learned']),
84 |             h.sweep('config.model.tape_pos_encoding', ['learned']),
85 |             h.sweep('config.model.num_layers', ['2,2,2']),  # '4,4',
86 |             h.sweep('config.model.latent_slots', [128]),
87 |             h.sweep('config.model.latent_dim', [512]),
88 |             h.sweep('config.model.latent_num_heads', [16]),
89 |             h.sweep('config.model.latent_mlp_ratio', [4]),
90 |             h.sweep('config.model.tape_dim', [256]),
91 |             h.sweep('config.model.tape_mlp_ratio', [2]),
92 |             h.sweep('config.model.rw_num_heads', [8]),
93 |             h.sweep('config.model.drop_units', [0.1]),
94 |             h.sweep('config.model.drop_path', [0.1]),
95 |         ]),
96 |     ])
97 | 
98 | 
99 | def get_eval_args_and_tags(config, args, unused_config_flag):
100 |   """Return eval args and tags."""
101 |   args_and_tags = []
102 |   for eval_split in [config.dataset.train_split]:
103 |     for sampler_name in ['ddpm']:
104 |       for infer_schedule in ['cosine']:
105 |         for infer_iterations in [400, 1000]:
106 |           # if sampler_name == 'ddim' and infer_iterations > 250: continue
107 |           # if sampler_name == 'ddpm' and infer_iterations <= 250: continue
108 |           eval_args = args.copy()
109 |           sampler_name_s = sampler_name.replace('@', '')
110 |           infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c')
111 |           eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}'
112 |           eval_args.update({
113 |               'config.eval.tag': eval_tag,
114 |               'config.dataset.eval_split': eval_split,
115 |               'config.model.sampler_name': sampler_name,
116 |               'config.model.infer_schedule': infer_schedule,
117 |               'config.model.infer_iterations': infer_iterations,
118 |           })
119 |           args_and_tags.append((eval_args, eval_tag, None))
120 |   return args_and_tags
121 | 
--------------------------------------------------------------------------------
/configs/config_diffusion_cifar10d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """A config."""
17 | 
18 | # pylint: disable=invalid-name,line-too-long
19 | 
20 | from configs import config_diffusion_base as config_base
21 | 
22 | DATA_NAME = 'cifar10'
23 | ARCH_VARIANT = 'transunet'
24 | 
25 | 
26 | def get_config(config_str=None):
27 |   """Returns config."""
28 |   del config_str
29 |   config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
30 |   config.model.name = 'image_discrete_diffusion_model'
31 |   config.model.b_scale = 1.0
32 |   config.model.pred_type = 'x'
33 |   config.model.conditional = 'class'
34 |   config.optimization.ema_decay = 0.9999
35 |   config.eval.batch_size = 80
36 |   config.eval.steps = 625
37 |   return config
38 | 
39 | 
40 | def get_sweep(h):
41 |   """Get the hyperparameter sweep."""
42 |   if ARCH_VARIANT == 'transunet':
43 |     return h.chainit([
44 |         h.product([
45 |             h.sweep('config.train.steps', [1_500_000]),
46 |             h.sweep('config.train.checkpoint_steps', [10000]),
47 |             h.sweep('config.train.keep_checkpoint_max', [100]),
48 |             h.sweep('config.train.batch_size', [128*1]),
49 |             h.sweep('config.optimization.learning_rate', [1e-4]),
50 |             h.sweep('config.optimization.warmup_steps', [10000]),
51 |             h.sweep('config.optimization.optimizer', ['adamw']),
52 |             h.sweep('config.optimization.exclude_from_weight_decay', ['']),
53 |             h.sweep('config.model.self_cond', ['x']),
54 |             h.sweep('config.model.self_cond_by_masking', [False]),
55 |             h.sweep('config.model.self_cond_rate', [0.5]),
56 |             h.sweep('config.model.b_scale', [0.5]),
57 |             h.sweep('config.model.udrop', [0.]),
58 |             h.sweep('config.model.dim', [256]),
59 |             h.sweep('config.model.n_res_blocks', ['3,3,3']),
60 |             h.sweep('config.model.ch_multipliers', ['1,1,1']),
61 |             h.sweep('config.model.total_time_steps',
[1000]), 62 | h.sweep('config.model.pred_type', ['x_softmax_xent']), 63 | h.zipit([ 64 | h.sweep('config.model.b_type', ['uint8', 'uint8_s']), 65 | h.sweep('config.model.outp_softmax_groups', [0, 3]), 66 | ]), 67 | ]), 68 | ]) 69 | 70 | 71 | def get_eval_args_and_tags(config, args, unused_config_flag): 72 | """Return eval args and tags.""" 73 | args_and_tags = [] 74 | for eval_split in [config.dataset.train_split]: 75 | for sampler_name in ['ddpm']: 76 | for infer_schedule in ['cosine']: 77 | for infer_iterations in [100, 250, 400]: 78 | # if sampler_name == 'ddim' and infer_iterations > 250: continue 79 | # if sampler_name == 'ddpm' and infer_iterations <= 250: continue 80 | eval_args = args.copy() 81 | sampler_name_s = sampler_name.replace('@', '') 82 | infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c') 83 | eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}' 84 | eval_args.update({ 85 | 'config.eval.tag': eval_tag, 86 | 'config.dataset.eval_split': eval_split, 87 | 'config.model.sampler_name': sampler_name, 88 | 'config.model.infer_schedule': infer_schedule, 89 | 'config.model.infer_iterations': infer_iterations, 90 | }) 91 | args_and_tags.append((eval_args, eval_tag, None)) 92 | return args_and_tags 93 | -------------------------------------------------------------------------------- /configs/config_diffusion_imagenet64.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | """A config."""
17 | 
18 | # pylint: disable=invalid-name,line-too-long
19 | 
20 | from configs import config_diffusion_base as config_base
21 | 
22 | DATA_NAME = 'imagenet2012'
23 | ARCH_VARIANT = 'tape'
24 | # ARCH_VARIANT = 'transunet'
25 | IMAGE_SIZE = 64 * 1
26 | 
27 | 
28 | def get_config(config_str=None):
29 |   """Returns config."""
30 |   del config_str
31 |   config = config_base.get_config(f'{DATA_NAME},{ARCH_VARIANT}')
32 |   config.dataset.image_size = IMAGE_SIZE
33 |   config.model.name = 'image_diffusion_model'
34 |   config.model.b_scale = 1.0
35 |   config.model.pred_type = 'eps'
36 |   config.model.conditional = 'class'
37 |   config.model.time_on_latent = True
38 |   config.model.cond_on_latent = True
39 |   config.model.cond_tape_writable = False
40 |   config.optimization.ema_decay = 0.9999
41 |   config.eval.batch_size = 80
42 |   config.eval.steps = 625
43 |   return config
44 | 
45 | 
46 | def get_sweep(h):
47 |   """Get the hyperparameter sweep."""
48 |   if ARCH_VARIANT == 'transunet':
49 |     return h.chainit([
50 |         h.product([
51 |             h.sweep('config.train.steps', [250_000]),
52 |             h.sweep('config.train.checkpoint_steps', [5000]),
53 |             h.sweep('config.train.keep_checkpoint_max', [100]),
54 |             h.sweep('config.train.batch_size', [1024*1]),
55 |             h.sweep('config.optimization.optimizer', ['adamw']),
56 |             h.sweep('config.optimization.exclude_from_weight_decay', ['']),
57 |             h.sweep('config.optimization.learning_rate', [2e-4]),
58 |             h.sweep('config.optimization.warmup_steps', [10000]),
59 |             h.sweep('config.optimization.weight_decay', [0.]),
60 |             h.sweep('config.model.self_cond', ['none', 'eps']),
61 |             h.sweep('config.model.total_time_steps', [1, 1000]),
62 |             h.sweep('config.model.udrop', [0.]),
63 |             # h.sweep('config.model.dim', [256]),
64 |             # h.sweep('config.model.n_res_blocks', ['3,3,3,3']),
65 |             # h.sweep('config.model.ch_multipliers', ['1,2,2,3']),
66 |             h.sweep('config.model.dim', [192]),
67 |             h.sweep('config.model.n_res_blocks', ['3,3,3,3']),
68 |             h.sweep('config.model.ch_multipliers', ['1,2,3,4']),
69 |             h.sweep('config.model.self_cond_by_masking', [False]),
70 |             h.sweep('config.model.self_cond_rate', [0.5]),
71 |         ]),
72 |     ])
73 |   else:
74 |     return h.chainit([
75 |         h.product([
76 |             h.sweep('config.train.steps', [150_000]),
77 |             h.sweep('config.train.checkpoint_steps', [10_000]),
78 |             h.sweep('config.train.keep_checkpoint_max', [20]),
79 |             h.sweep('config.train.batch_size', [1024]),
80 |             h.sweep('config.optimization.optimizer', ['lamb']),
81 |             h.sweep('config.optimization.learning_rate', [2e-3]),
82 |             h.sweep('config.optimization.learning_rate_schedule', ['cosine@0.7']),
83 |             h.sweep('config.optimization.end_lr_factor', [0.]),
84 |             h.sweep('config.optimization.warmup_steps', [10000]),
85 |             h.sweep('config.optimization.weight_decay', [1e-2]),
86 |             h.sweep('config.optimization.beta2', [0.999]),
87 |             h.sweep('config.model.train_schedule', ['simple_linear']),
88 |             h.sweep('config.model.pred_type', ['eps']),
89 |             h.sweep('config.model.self_cond', ['latent']),
90 |             h.sweep('config.model.self_cond_by_masking', [True]),
91 |             h.sweep('config.model.self_cond_rate', [0.9]),
92 |             h.sweep('config.model.total_time_steps', [1000]),
93 | 
94 |             h.sweep('config.model.patch_size', [4*2]),  # 4
95 |             h.sweep('config.model.latent_pos_encoding', ['learned']),
96 |             h.sweep('config.model.tape_pos_encoding', ['learned']),
97 |             h.sweep('config.model.num_layers', ['4,4,4,4']),  # '6,6,6,6'
98 |             h.sweep('config.model.latent_slots',
[128]), 99 | h.sweep('config.model.latent_dim', [768]), 100 | h.sweep('config.model.latent_mlp_ratio', [4]), 101 | h.sweep('config.model.latent_num_heads', [16]), 102 | h.sweep('config.model.tape_dim', [512]), 103 | h.sweep('config.model.tape_mlp_ratio', [4]), 104 | h.sweep('config.model.rw_num_heads', [16]), 105 | h.sweep('config.model.drop_units', [0.]), 106 | h.sweep('config.model.drop_path', [0.]), 107 | ]), 108 | ]) 109 | 110 | 111 | def get_eval_args_and_tags(config, args, unused_config_flag): 112 | """Return eval args and tags.""" 113 | args_and_tags = [] 114 | for eval_split in [config.dataset.train_split]: 115 | for sampler_name in ['ddpm']: 116 | for infer_schedule in ['cosine']: 117 | for infer_iterations in [100, 250, 1000]: 118 | eval_args = args.copy() 119 | sampler_name_s = sampler_name.replace('@', '') 120 | infer_schedule_s = infer_schedule.replace('@', '').replace(',', 'c') 121 | eval_tag = f'ev_{eval_split}_{sampler_name_s}_{infer_schedule_s}_i{infer_iterations}' 122 | eval_args.update({ 123 | 'config.eval.tag': eval_tag, 124 | 'config.dataset.eval_split': eval_split, 125 | 'config.model.sampler_name': sampler_name, 126 | 'config.model.infer_schedule': infer_schedule, 127 | 'config.model.infer_iterations': infer_iterations, 128 | }) 129 | args_and_tags.append((eval_args, eval_tag, None)) 130 | return args_and_tags 131 | -------------------------------------------------------------------------------- /configs/config_diffusion_panoptic_base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """A config.""" 17 | 18 | import copy 19 | # pylint: disable=invalid-name,line-too-long,missing-docstring 20 | from configs import dataset_configs 21 | from configs import transform_configs 22 | from configs.config_base import architecture_config_map 23 | from configs.config_base import D 24 | 25 | 26 | def get_config(config_str=None): 27 | """config_str is either empty or contains task,architecture variants.""" 28 | 29 | if config_str: 30 | config_lists = config_str.split(',') 31 | task_variant, encoder_variant, decoder_variant = config_lists[:3] 32 | image_dim, mdim = config_lists[3], config_lists[4] 33 | # The `if` conditions here are for backwards compatibility with existing 34 | # scripts that pass in a single dim for image size and mask size. 35 | # We should eventually remove this and require callers to pass in the full 36 | # res. 
37 | if 'x' in image_dim: 38 | image_size = [int(d) for d in image_dim.split('x')] 39 | else: 40 | image_size = (int(image_dim), int(image_dim)) 41 | if 'x' in mdim: 42 | msize = [int(d) for d in mdim.split('x')] 43 | else: 44 | msize = (int(mdim), int(mdim)) 45 | else: 46 | task_variant = 'panoptic_segmentation@coco/2017_panoptic_segmentation' 47 | encoder_variant = 'resnet-c' 48 | decoder_variant = 'transunet' 49 | image_size = (256, 256) 50 | msize = (64, 64) 51 | 52 | tasks_and_datasets = [] 53 | for task_and_ds in task_variant.split('+'): 54 | tasks_and_datasets.append(task_and_ds.split('@')) 55 | 56 | decoder_config_map = { 57 | 'transformer': D( 58 | arch_name='transformer', 59 | num_layers=2, 60 | dim=512, 61 | dim_mlp=2048, 62 | num_heads=16, 63 | pos_encoding='learned', 64 | drop_path=0.0, 65 | drop_units=0.1, 66 | drop_att=0.0, 67 | ), 68 | 'transunet': D( 69 | arch_name='transunet', 70 | dim=32, 71 | in_strides=1, 72 | in_kernel_size=1, 73 | out_kernel_size=1, 74 | udrop=0.1, 75 | n_mlp_blocks=0, 76 | n_res_blocks='1,1,1,1', 77 | kernel_sizes='3', 78 | ch_multipliers='1,2,3,4', 79 | transformer_dim=512, 80 | transformer_strides=1, 81 | transformer_blocks=1, 82 | mhsa_resolutions='16,8', 83 | per_head_dim=64, 84 | u_pos_encoding='sin_cos', 85 | outp_softmax_groups=16, 86 | ), 87 | 'tape': D( 88 | arch_name='tape', 89 | num_layers='1,1', 90 | latent_slots=128, 91 | latent_dim=256*2, 92 | latent_mlp_ratio=4, 93 | latent_num_heads=16, 94 | tape_dim=256*2, 95 | tape_mlp_ratio=4, 96 | rw_num_heads=16, 97 | conv_kernel_size=0, 98 | conv_drop_units=0., 99 | drop_path=0., 100 | drop_units=0., 101 | drop_att=0., 102 | patch_size=8, 103 | outp_softmax_groups=2, 104 | pos_encoding='sin_cos', 105 | patch_scales='1', 106 | patch_scales_w='1', 107 | latent_pos_encoding='learned', 108 | tape_pos_encoding='learned', 109 | ), 110 | } 111 | 112 | task_config_map = { 113 | 'panoptic_segmentation': D( 114 | name='panoptic_segmentation', 115 | vocab_id=16, 116 | image_size=image_size, 117 | n_bits_label=16, 118 | max_instances_per_image=101, 119 | object_order='random', 120 | color_jitter_strength=0., 121 | jitter_scale_min=0.3, 122 | jitter_scale_max=1.0, 123 | min_pixels=40, 124 | weight=1.0, 125 | metric=D( 126 | name='coco_panoptic_segmentation', 127 | results_dir='', 128 | ), 129 | ), 130 | 'video_panoptic_segmentation': D( 131 | name='video_panoptic_segmentation', 132 | vocab_id=18, 133 | image_size=image_size, 134 | n_bits_label=16, 135 | max_instances_per_image=256, # including id 0. 
136 | object_order='shuffle', 137 | color_jitter_strength=0., 138 | jitter_scale_min=1.0, 139 | jitter_scale_max=1.0, 140 | weight=1.0, 141 | proceeding_frames='-2,-1', 142 | eval_single_frames=False, 143 | eval_use_gt_cond_frames=False, 144 | frames_dropout=0.1, 145 | max_num_frames=100, 146 | min_pixels=40, 147 | metric=D( 148 | name='segmentation_and_tracking_quality', 149 | results_dir='' 150 | ), 151 | ), 152 | } 153 | 154 | task_d_list = [] 155 | dataset_list = [] 156 | for task_name, ds_name in tasks_and_datasets: 157 | task_d_list.append(task_config_map[task_name]) 158 | dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name]) 159 | dataset_list.append(dataset_config) 160 | 161 | config = D( 162 | dataset=dataset_list[0], 163 | datasets=dataset_list, 164 | 165 | task=task_d_list[0], 166 | tasks=task_d_list, 167 | 168 | encoder=D(), 169 | 170 | decoder=D(), 171 | 172 | model=D( 173 | name='panoptic_diffusion', 174 | train_schedule='cosine', 175 | infer_schedule='cosine', 176 | train_noise='normal', 177 | infer_noise='normal', 178 | noise_std=1.0, 179 | noise_truncation=False, 180 | pred_type='x_softmax_xent', 181 | iter_start=0, 182 | step_bias=0., 183 | td=0., 184 | td_schedule='constant', 185 | x0_clip='auto', 186 | b_scale=0.1, 187 | normalize_noisy_input=False, 188 | total_time_steps=1000., 189 | pretrained_ckpt='', 190 | self_cond='none', 191 | conditional='cat', 192 | iterations=100, 193 | iterations_2=100, # only used in video inference where less iterations can be used for the 2nd frame onwards. 194 | sampler='ddim', 195 | l_tile_factors=1, 196 | msize=msize, 197 | mask_weight_p=0., 198 | self_cond_rate=0.5, 199 | self_cond_by_masking=False, 200 | 201 | # extra architecture 202 | image_size=image_size, 203 | decoder_variant=decoder_variant, 204 | use_cls_token=False, 205 | patch_size=16, 206 | drop_path=0.1, 207 | drop_units=0.1, 208 | drop_att=0.0, 209 | pos_encoding='sin_cos', 210 | dec_proj_mode='mlp', 211 | enc_drop=0., 212 | enc_fuse='pyramid_merge', 213 | enc_fuse_upsample='nearest', 214 | enc_fuse_dim=256, 215 | frozen_backbone=False, 216 | ), 217 | 218 | optimization=D( 219 | optimizer='adamw', 220 | learning_rate=1e-3, 221 | end_lr_factor=0.01, 222 | warmup_epochs=10, 223 | warmup_steps=0, # set to >0 to override warmup_epochs. 224 | weight_decay=0.05, 225 | global_clipnorm=-1., 226 | beta1=0.9, 227 | beta2=0.95, 228 | eps=1e-8, 229 | ema_decay=0.9999, 230 | ema_name_exact_match=True, 231 | learning_rate_schedule='linear', 232 | learning_rate_scaling='none', 233 | ), 234 | 235 | train=D( 236 | batch_size=32, 237 | epochs=40, 238 | steps=0, # set to >0 to override epochs. 239 | checkpoint_epochs=1, 240 | checkpoint_steps=0, # set to >0 to override checkpoint_epochs. 241 | keep_checkpoint_max=10, 242 | loss_type='xent', 243 | ), 244 | 245 | eval=D( 246 | tag='eval', 247 | checkpoint_dir='', # checkpoint_dir will be model_dir if not set. 248 | batch_size=4, # needs to be divisible by total eval examples. 249 | steps=0, # 0 means eval over full validation set. 250 | ), 251 | ) 252 | 253 | # Update model with architecture variant. 254 | for key, value in architecture_config_map[encoder_variant].items(): 255 | config.model[key] = value 256 | 257 | # Update decoder architecture variant. 
258 | for key, value in decoder_config_map[decoder_variant].items(): 259 | config.decoder[key] = value 260 | 261 | # on-the-fly gt gathering for metric computation 262 | # config.dataset.coco_annotations_dir_for_metrics = '' 263 | 264 | config.task.train_transforms = transform_configs.get_panoptic_segmentation_train_transforms( 265 | image_size, msize, 1.0, 1.0, 0.) 266 | config.task.eval_transforms = transform_configs.get_panoptic_segmentation_eval_transforms( 267 | image_size) 268 | 269 | return config 270 | -------------------------------------------------------------------------------- /configs/config_diffusion_panoptic_image.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Config for conditional mask modeling.""" 17 | 18 | # pylint: disable=invalid-name,line-too-long 19 | 20 | from configs import config_diffusion_panoptic_base as config_base 21 | from configs import transform_configs 22 | 23 | IMAGE_SIZE = '1024x1024' 24 | MASK_SIZE = '512x512' 25 | MODE = 'train_high_res' # one of ['train_high_res', 'train_low_res'] 26 | 27 | 28 | def get_config(config_str=None): 29 | """Returns config.""" 30 | if config_str: 31 | task_variant = config_str 32 | else: 33 | task_variant = 'panoptic_segmentation@coco/2017_panoptic_segmentation' 34 | encoder_variant = 'resnet-c' 35 | decoder_variant = 'transunet' 36 | config = config_base.get_config( 37 | f'{task_variant},{encoder_variant},{decoder_variant},{IMAGE_SIZE},{MASK_SIZE}') 38 | image_size = [int(x) for x in IMAGE_SIZE.split('x')] 39 | mask_size = [int(x) for x in MASK_SIZE.split('x')] 40 | config.task.train_transforms = transform_configs.get_panoptic_segmentation_train_transforms( 41 | image_size, mask_size, 1.0, 1.0, 0.) 42 | config.task.eval_transforms = transform_configs.get_panoptic_segmentation_eval_transforms( 43 | image_size) 44 | config.model.name = 'panoptic_diffusion' 45 | config.model.train_schedule = 'cosine' 46 | config.model.l_tile_factors = 1 47 | config.model.frozen_backbone = False 48 | config.model.enc_drop = 0. 49 | config.model.enc_fuse = 'pyramid_merge' 50 | config.model.enc_fuse_upsample = 'nearest' 51 | config.model.enc_fuse_dim = 256 52 | config.model.total_time_steps = 1.0 # for legacy compability. 53 | config.decoder.mhsa_resolutions = '0' 54 | config.decoder.n_mlp_blocks = 0 55 | config.decoder.outp_softmax_groups = 0 56 | config.decoder.in_kernel_size = 1 57 | config.decoder.out_kernel_size = 1 58 | config.optimization.learning_rate_schedule = 'linear' 59 | config.optimization.end_lr_factor = 0.02 60 | config.optimization.weight_decay = 0.05 61 | config.optimization.beta2 = 0.999 62 | config.optimization.warmup_epochs = 5 63 | config.optimization.global_clipnorm = 1. 
64 | return config 65 | 66 | 67 | def get_sweep(h): 68 | """Get the hyperparamater sweep.""" 69 | if MODE == 'train_low_res': 70 | return h.chainit([ 71 | h.product([ 72 | h.sweep('config.train.epochs', [800]), 73 | h.sweep('config.train.batch_size', [512]), 74 | h.sweep('config.train.checkpoint_epochs', [10]), 75 | h.sweep('config.train.keep_checkpoint_max', [10]), 76 | h.sweep('config.optimization.learning_rate', [1e-4]), 77 | h.sweep('config.optimization.end_lr_factor', [0.1]), 78 | h.sweep('config.optimization.warmup_epochs', [5]), 79 | h.sweep('config.optimization.ema_decay', [0.999]), 80 | h.sweep('config.model.b_scale', [0.1]), 81 | h.sweep('config.model.pred_type', ['x_softmax_xent']), 82 | h.sweep('config.model.self_cond', ['none']), 83 | h.sweep('config.model.conditional', ['cat+attn']), 84 | h.sweep('config.model.mask_weight_p', [0.2]), 85 | h.sweep('config.task.train_transforms[1].min_scale', [1.0]), # jitter_scale 86 | h.sweep('config.task.train_transforms[1].max_scale', [3.0]), 87 | h.sweep('config.task.train_transforms[4].color_jitter_strength', [1.0]), 88 | h.sweep('config.decoder.dim', [128]), 89 | h.sweep('config.decoder.udrop', [0.]), 90 | h.sweep('config.decoder.n_res_blocks', ['1,1,1,1']), 91 | h.sweep('config.decoder.ch_multipliers', ['1,1,2,2']), 92 | h.sweep('config.decoder.transformer_strides', [1]), 93 | h.sweep('config.decoder.transformer_dim', [512]), 94 | h.sweep('config.decoder.transformer_blocks', [6]), 95 | h.sweep('config.decoder.outp_softmax_groups', [2]), 96 | ]), 97 | ]) 98 | elif MODE == 'train_high_res': 99 | return h.chainit([ 100 | h.product([ 101 | h.sweep('config.train.epochs', [15]), 102 | h.sweep('config.train.batch_size', [16]), 103 | h.sweep('config.train.checkpoint_epochs', [1]), 104 | h.sweep('config.optimization.learning_rate', [1e-5]), 105 | h.sweep('config.optimization.end_lr_factor', [0.1]), 106 | h.sweep('config.optimization.warmup_epochs', [0]), 107 | h.sweep('config.optimization.ema_decay', [0.999]), 108 | h.sweep('config.model.b_scale', [0.1]), 109 | h.sweep('config.model.pred_type', ['x_softmax_xent']), 110 | h.sweep('config.model.self_cond', ['none']), 111 | h.sweep('config.model.conditional', ['cat+attn']), 112 | h.sweep('config.model.mask_weight_p', [0.2]), 113 | h.sweep('config.task.train_transforms[1].min_scale', [1.0]), 114 | h.sweep('config.task.train_transforms[1].max_scale', [1.0]), 115 | h.sweep('config.decoder.dim', [128]), 116 | h.sweep('config.decoder.udrop', [0.]), 117 | h.sweep('config.decoder.n_res_blocks', ['1,1,1,1']), 118 | h.sweep('config.decoder.ch_multipliers', ['1,1,2,2']), 119 | h.sweep('config.decoder.transformer_strides', [1]), 120 | h.sweep('config.decoder.transformer_dim', [512]), 121 | h.sweep('config.decoder.transformer_blocks', [6]), 122 | h.sweep('config.decoder.outp_softmax_groups', [2]), 123 | ]), 124 | ]) 125 | 126 | 127 | def get_eval_args_and_tags(config, args, unused_config_flag): 128 | """Return eval args and tags.""" 129 | args_and_tags = [] 130 | for eval_split in [config.dataset.eval_split]: 131 | for sampler in ['ddim']: 132 | for iterations in [20]: 133 | for td in [1.0, 2.0]: 134 | for min_pixels in [40]: 135 | eval_args = args.copy() 136 | eval_tag = f'ev_{eval_split}_{sampler}_i{iterations}_td{td}_p{min_pixels}' 137 | results_dir = eval_args['model_dir'] + '/' + eval_tag # pylint: disable=unused-variable 138 | eval_args.update({ 139 | 'config.eval.tag': eval_tag, 140 | 'config.eval.batch_size': 8, 141 | 'config.eval.steps': 0, 142 | 'config.model.sampler': sampler, 143 | 
'config.model.iterations': iterations, 144 | 'config.model.td': td, 145 | 'config.task.min_pixels': min_pixels, 146 | # 'config.task.metric.results_dir': results_dir, 147 | }) 148 | if eval_split == 'train': 149 | eval_args.update({ 150 | 'config.dataset.eval_split': 'train', 151 | 'config.eval.steps': 100, 152 | }) 153 | args_and_tags.append((eval_args, eval_tag, None)) 154 | return args_and_tags 155 | 156 | 157 | -------------------------------------------------------------------------------- /configs/config_diffusion_panoptic_video.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Video panoptic segmentation config.""" 17 | 18 | # pylint: disable=invalid-name,line-too-long 19 | 20 | from configs import config_diffusion_panoptic_base as config_base 21 | from configs import transform_configs 22 | 23 | 24 | def get_config(config_str=None): 25 | """Returns config.""" 26 | if config_str: 27 | task_variant = config_str 28 | else: 29 | task_variant = 'video_panoptic_segmentation@kittistep_vps' 30 | 31 | encoder_variant = 'resnet-c' 32 | decoder_variant = 'transunet' 33 | 34 | if 'kittistep' in task_variant: 35 | imgsize = '384x1248' 36 | msize = '192x624' 37 | elif 'davis' in task_variant: 38 | imgsize = '512x1024' 39 | msize = '256x512' 40 | else: 41 | imgsize = '256x256' 42 | msize = '128x128' 43 | 44 | config = config_base.get_config( 45 | f'{task_variant},{encoder_variant},{decoder_variant},{imgsize},{msize}') 46 | 47 | image_size = [int(x) for x in imgsize.split('x')] 48 | mask_size = [int(x) for x in msize.split('x')] 49 | config.task.train_transforms = transform_configs.get_video_panoptic_segmentation_train_transforms( 50 | image_size, mask_size, 1.0, 1.0, 0.) 51 | config.task.eval_transforms = transform_configs.get_video_panoptic_segmentation_eval_transforms( 52 | image_size, mask_size, 100) 53 | 54 | config.model.name = 'panoptic_diffusion' 55 | config.model.train_schedule = 'cosine' 56 | config.model.l_tile_factors = 1 57 | config.model.frozen_backbone = False 58 | config.model.enc_drop = 0. 59 | config.model.enc_fuse = 'pyramid_merge' 60 | config.model.enc_fuse_upsample = 'nearest' 61 | config.model.enc_fuse_dim = 256 62 | config.model.b_scale = 0.1 63 | config.model.pred_type = 'x_softmax_xent' 64 | config.model.self_cond = 'none' 65 | config.model.conditional = 'cat+attn' 66 | config.model.mask_weight_p = 0.2 67 | 68 | config.decoder.mhsa_resolutions = '0' 69 | config.decoder.n_mlp_blocks = 0 70 | config.decoder.in_kernel_size = 1 71 | config.decoder.out_kernel_size = 1 72 | config.decoder.output_residual = False 73 | config.decoder.input_scaling = False 74 | config.decoder.dim = 128 75 | config.decoder.udrop = 0. 
76 | config.decoder.n_res_blocks = '1,1,1,1' 77 | config.decoder.ch_multipliers = '1,1,2,2' 78 | config.decoder.transformer_strides = 1 79 | config.decoder.transformer_dim = 512 80 | config.decoder.transformer_blocks = 6 81 | config.decoder.outp_softmax_groups = 2 82 | 83 | config.optimization.learning_rate_schedule = 'linear' 84 | config.optimization.end_lr_factor = 0.02 85 | config.optimization.weight_decay = 0.05 86 | config.optimization.beta2 = 0.999 87 | config.optimization.warmup_epochs = 0 88 | config.optimization.global_clipnorm = 1. 89 | 90 | config.task.proceeding_frames = '-2,-1' 91 | config.task.eval_single_frames = False 92 | config.task.eval_use_gt_cond_frames = False 93 | 94 | if 'davis' in task_variant: 95 | config.task.max_instances_per_image = 16 96 | config.task.max_num_frames = 105 97 | config.task.eval_transforms[4].max_num_frames = 105 98 | config.task.metric.name = 'davis_video_object_segmentation' 99 | 100 | config.eval.batch_size = 2 101 | return config 102 | 103 | 104 | def get_sweep(h): 105 | """Get the hyperparamater sweep.""" 106 | 107 | return h.chainit([ 108 | h.product([ 109 | h.sweep('config.train.steps', [50000]), 110 | h.sweep('config.train.batch_size', [32]), 111 | h.sweep('config.train.checkpoint_steps', [1000]), 112 | h.sweep('config.optimization.learning_rate', [1e-5, 3e-5]), 113 | h.sweep('config.optimization.end_lr_factor', [1.]), 114 | h.sweep('config.optimization.warmup_epochs', [0]), 115 | h.sweep('config.optimization.ema_decay', [0.99]), 116 | h.zipit([ 117 | h.sweep('config.task.train_transforms[0].min_scale', [1.0]), 118 | h.sweep('config.task.train_transforms[0].max_scale', [1.0]), 119 | ]), 120 | h.sweep('config.task.train_transforms[3].color_jitter_strength', [0.]), 121 | h.sweep('config.task.object_order', ['shuffle']), 122 | h.sweep('config.task.frames_dropout', [0.2]), 123 | ]), 124 | ]) 125 | 126 | 127 | -------------------------------------------------------------------------------- /configs/dataset_configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Dataset configs.""" 17 | import os 18 | from configs.config_base import D 19 | 20 | 21 | _shared_dataset_config = D( 22 | batch_duplicates=1, 23 | cache_dataset=True, 24 | ) 25 | 26 | # Generate tfrecords for the dataset using data/scripts/create_coco_tfrecord.py 27 | # and add paths here. 
28 | COCO_TRAIN_TFRECORD_PATTERN = 'gs://pix2seq/multi_task/data/coco/tfrecord/train*' 29 | COCO_VAL_TFRECORD_PATTERN = 'gs://pix2seq/multi_task/data/coco/tfrecord/val*' 30 | 31 | # Download from gs://pix2seq/multi_task/data/coco/json 32 | COCO_ANNOTATIONS_DIR = '/tmp/coco_annotations' 33 | 34 | _shared_coco_dataset_config = D( 35 | train_file_pattern=COCO_TRAIN_TFRECORD_PATTERN, 36 | val_file_pattern=COCO_VAL_TFRECORD_PATTERN, 37 | train_num_examples=118287, 38 | eval_num_examples=5000, 39 | train_split='train', 40 | eval_split='validation', 41 | # Directory of annotations used by the metrics. 42 | # Also need to set train_filename_for_metrics and val_filename_for_metrics. 43 | # If unset, groundtruth annotations should be specified via 44 | # record_groundtruth. 45 | coco_annotations_dir_for_metrics=COCO_ANNOTATIONS_DIR, 46 | label_shift=0, 47 | **_shared_dataset_config 48 | ) 49 | 50 | dataset_configs = { 51 | 'coco/2017_object_detection': 52 | D( 53 | name='coco/2017_object_detection', 54 | train_filename_for_metrics='instances_train2017.json', 55 | val_filename_for_metrics='instances_val2017.json', 56 | category_names_path=os.path.join( 57 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'], 58 | 'instances_val2017.json'), 59 | **_shared_coco_dataset_config 60 | ), 61 | 'coco/2017_instance_segmentation': 62 | D( 63 | name='coco/2017_instance_segmentation', 64 | train_filename_for_metrics='instances_train2017.json', 65 | val_filename_for_metrics='instances_val2017.json', 66 | category_names_path=os.path.join( 67 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'], 68 | 'instances_val2017.json'), 69 | **_shared_coco_dataset_config 70 | ), 71 | 'coco/2017_keypoint_detection': 72 | D( 73 | name='coco/2017_keypoint_detection', 74 | train_filename_for_metrics='person_keypoints_train2017.json', 75 | val_filename_for_metrics='person_keypoints_val2017.json', 76 | category_names_path=os.path.join( 77 | _shared_coco_dataset_config['coco_annotations_dir_for_metrics'], 78 | 'person_keypoints_val2017.json'), 79 | **_shared_coco_dataset_config 80 | ), 81 | 'coco/2017_captioning': 82 | D(name='coco/2017_captioning', 83 | train_filename_for_metrics='captions_train2017_eval_compatible.json', 84 | val_filename_for_metrics='captions_val2017_eval_compatible.json', 85 | **_shared_coco_dataset_config), 86 | } 87 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Cityscapes dataset.""" 17 | from data import dataset as dataset_lib 18 | import tensorflow as tf 19 | 20 | 21 | @dataset_lib.DatasetRegistry.register('cityscapes_panoptic') 22 | class CityscapesPanopticDataset(dataset_lib.TFRecordDataset): 23 | """Cityscapes panoptic dataset.""" 24 | 25 | def get_feature_map(self, training): 26 | """Returns feature map for parsing the TFExample.""" 27 | del training 28 | return { 29 | 'image/encoded': 30 | tf.io.FixedLenFeature([], tf.string), 31 | 'image/segmentation/class/encoded': 32 | tf.io.FixedLenFeature([], tf.string), 33 | 'image/height': 34 | tf.io.FixedLenFeature([], tf.int64), 35 | 'image/width': 36 | tf.io.FixedLenFeature([], tf.int64), 37 | 'image/filename': 38 | tf.io.FixedLenFeature([], tf.string), 39 | } 40 | 41 | def extract(self, example, training): 42 | """Extracts needed features & annotations into a flat dictionary. 43 | 44 | Note: 45 | - label starts at 1 instead of 0, as 0 is reserved for special use 46 | (such as padding). 47 | - coordinates (e.g. bbox) are (normalized to be) in [0, 1]. 48 | 49 | Args: 50 | example: `dict` of raw features. 51 | training: `bool` of training vs eval mode. 52 | 53 | Returns: 54 | example: `dict` of relevant features and labels. 55 | """ 56 | # Decode image and label. 57 | image = tf.io.decode_image(example['image/encoded'], channels=3) 58 | image.set_shape([1024, 2048, 3]) 59 | label = example['image/segmentation/class/encoded'] 60 | label = tf.io.decode_raw( 61 | example['image/segmentation/class/encoded'], out_type=tf.int32) 62 | label_shape = tf.stack([1024, 2048]) 63 | label = tf.reshape(label, label_shape) 64 | 65 | # Map instance ids to range(1, num_instance + 1) 66 | unique_instance_ids, _ = tf.unique(tf.reshape(label, [-1])) 67 | num_instances = tf.size(unique_instance_ids) 68 | new_instance_ids = tf.random.shuffle(tf.range(1, num_instances + 1)) 69 | def map_ids(x, src_ids, tgt_ids): 70 | """Convert object ids into semantic classes.""" 71 | x = tf.equal(x[:, :, tf.newaxis], src_ids[tf.newaxis, tf.newaxis, :]) 72 | x = tf.reduce_sum(tf.cast(x, tgt_ids.dtype) * 73 | tgt_ids[tf.newaxis, tf.newaxis, :], -1) 74 | return x 75 | identity = map_ids(label, unique_instance_ids, new_instance_ids) 76 | 77 | # label = class * max_instances_per_class + per_class_instance_id 78 | semantic = label // self.config.max_instances_per_class 79 | 80 | ignore_mask = tf.logical_not(tf.logical_and( 81 | tf.greater_equal(semantic, 0), 82 | tf.less(semantic, self.config.num_classes - 83 | 1))) # num_classes includes padding class. 84 | # 0 is reserved for background and labels which are to be ignored. 85 | semantic = tf.where(ignore_mask, tf.zeros_like(semantic), semantic + 1) 86 | identity = tf.where(ignore_mask, tf.zeros_like(identity), identity) 87 | 88 | return { 89 | 'image': 90 | tf.image.convert_image_dtype(image, tf.float32), 91 | # TODO(srbs): Find another hashing strategy that does not have 92 | # collisions possibly by leveraging the structure of the filename which 93 | # is _123456_123456. 94 | # Coco metrics would fail if there are duplicate image ids in preds or 95 | # gt. 
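# A collision-free alternative (untested sketch, assuming filenames of the
# form '<city>_123456_123456' as noted in the TODO above): parse the two
# numeric fields and combine them into a single integer id, e.g.
#   parts = tf.strings.split(example['image/filename'], '_')
#   image_id = (tf.strings.to_number(parts[-2], tf.int64) * 1000000 +
#               tf.strings.to_number(parts[-1], tf.int64))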
96 | 'image/id': 97 | tf.strings.to_hash_bucket( 98 | example['image/filename'], num_buckets=1000000000), 99 | 'label_map': tf.stack([semantic, identity], -1) 100 | } 101 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Dataset base class.""" 17 | 18 | import abc 19 | import functools 20 | import operator 21 | from typing import Callable 22 | import ml_collections 23 | 24 | import registry 25 | import tensorflow as tf 26 | import tensorflow_datasets as tfds 27 | 28 | 29 | DatasetRegistry = registry.Registry() 30 | 31 | 32 | def mix_datasets(input_fns, weights): 33 | """Mix multiple datasets according to weights. 34 | 35 | Args: 36 | input_fns: a list of input_fn's. Each input_fn takes in an input_context and 37 | produces a tf.data.Dataset instance. 38 | weights: a list of floats where weights[i] represents the probability to 39 | sample from input_fns[i]. 40 | 41 | Returns: 42 | a tf.data.Dataset instance. 43 | """ 44 | def input_fn(input_context): 45 | dses = [] 46 | for ifn in input_fns: 47 | dses.append(ifn(input_context)) 48 | mixed_ds = tf.data.Dataset.sample_from_datasets(dses, weights) 49 | return mixed_ds 50 | return tf.distribute.get_strategy().distribute_datasets_from_function( 51 | input_fn) 52 | 53 | 54 | class Dataset(abc.ABC): 55 | """A dataset that handles creating a tf.data.Dataset.""" 56 | 57 | def __init__(self, config: ml_collections.ConfigDict): 58 | """Constructs the dataset.""" 59 | self.config = config.dataset 60 | self.task_config = config.task 61 | 62 | @abc.abstractmethod 63 | def extract(self, example, training): 64 | """Extracts needed features & annotations into a flat dictionary. 65 | 66 | Note: be conscious about 0 in label, which should probably be reserved for 67 | special use (such as padding). 68 | 69 | Args: 70 | example: `dict` of raw features. 71 | training: `bool` of training vs eval mode. 72 | 73 | Returns: 74 | example: `dict` of relevant features and labels 75 | """ 76 | 77 | @abc.abstractmethod 78 | def load_dataset(self, input_context, training): 79 | """Load tf.data.Dataset from sources such as TFDS or TFRecord files.""" 80 | 81 | def parse_example(self, example, training): 82 | del training 83 | return example 84 | 85 | def filter_example(self, unused_example, unused_training): 86 | return True 87 | 88 | def pipeline(self, 89 | process_single_example: Callable[[tf.data.Dataset, int, bool], 90 | tf.data.Dataset], 91 | global_batch_size: int, training: bool): 92 | """Data pipeline from name to preprocessed examples. 93 | 94 | Args: 95 | process_single_example: a function that takes single example dataset and 96 | returns processed example dataset.
97 | global_batch_size: global batch size. 98 | training: training vs eval mode. 99 | 100 | Returns: 101 | An input_fn which generates a tf.data.Dataset instance. 102 | """ 103 | config = self.config 104 | def input_fn(input_context): 105 | dataset = self.load_dataset(input_context, training) 106 | if config.cache_dataset: 107 | dataset = dataset.cache() 108 | 109 | if input_context: 110 | batch_size = input_context.get_per_replica_batch_size(global_batch_size) 111 | # Sharding is not necessary for TFDS given read_config above. 112 | # dataset = dataset.shard(input_context.num_input_pipelines, 113 | # input_context.input_pipeline_id) 114 | else: 115 | batch_size = global_batch_size 116 | 117 | if training: 118 | options = tf.data.Options() 119 | options.deterministic = False 120 | options.experimental_slack = True 121 | dataset = dataset.with_options(options) 122 | buffer_size = config.get('buffer_size', 0) 123 | if buffer_size <= 0: 124 | buffer_size = 10 * batch_size 125 | dataset = dataset.shuffle(buffer_size) 126 | dataset = dataset.repeat() 127 | 128 | dataset = dataset.map( 129 | lambda x: self.parse_example(x, training), 130 | num_parallel_calls=tf.data.experimental.AUTOTUNE 131 | ).filter( 132 | lambda x: self.filter_example(x, training) 133 | ).map( 134 | lambda x: self.extract(x, training), 135 | num_parallel_calls=tf.data.experimental.AUTOTUNE 136 | ) 137 | if process_single_example: 138 | dataset = process_single_example( 139 | dataset, config.batch_duplicates, training) 140 | 141 | # TODO(b/181662974): Revert this and support non-even batch sizes. 142 | # dataset = dataset.batch(batch_size, drop_remainder=training) 143 | dataset = dataset.padded_batch(batch_size, drop_remainder=True) 144 | if config.batch_duplicates > 1 and training: 145 | dataset = dataset.map(self._flatten_dims, 146 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 147 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 148 | return dataset 149 | 150 | return input_fn 151 | 152 | def _flatten_dims(self, example): 153 | """Flatten first 2 dims when batch is independently duplicated.""" 154 | 155 | def flatten_first_2_dims(t): 156 | """Merge first 2 dims.""" 157 | shape_list = t.shape.as_list() 158 | new_bsz = functools.reduce(operator.mul, shape_list[:2]) 159 | out_shape = [new_bsz] + shape_list[2:] 160 | return tf.reshape(t, out_shape) 161 | 162 | return tf.nest.map_structure(flatten_first_2_dims, example) 163 | 164 | @property 165 | @abc.abstractmethod 166 | def num_train_examples(self): 167 | """Number of training examples.""" 168 | 169 | @property 170 | @abc.abstractmethod 171 | def num_eval_examples(self): 172 | """Number of eval examples.""" 173 | 174 | 175 | class TFDSDataset(Dataset): 176 | """A dataset created from a TFDS dataset. 177 | 178 | Each example is a dictionary, but the fields may be different for each 179 | dataset. 180 | 181 | Each task would have a list of required fields (e.g. bounding boxes for 182 | object detection). When a dataset is used for a specific task, it should 183 | contain all the fields required by that task.
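For example, an object detection task expects fields such as 'image', 'bbox', 'label', 'is_crowd' and 'area' (see e.g. data/obj365.py in this directory).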
184 | """ 185 | 186 | def __init__(self, config: ml_collections.ConfigDict): 187 | """Constructs the dataset.""" 188 | super().__init__(config) 189 | self.builder = tfds.builder(self.config.tfds_name, 190 | data_dir=self.config.get('data_dir', None)) 191 | self.builder.download_and_prepare() 192 | self.allowed_tasks = [] 193 | 194 | def load_dataset(self, input_context, training): 195 | """Load tf.data.Dataset from TFDS.""" 196 | split = self.config.train_split if training else self.config.eval_split 197 | # For TFDS, pass input_context using read_config to make TFDS read 198 | # different parts of the dataset on different workers. 199 | read_config = tfds.ReadConfig(input_context=input_context) 200 | if isinstance(split, list): 201 | dataset = self.builder.as_dataset( 202 | split=split[0], shuffle_files=training, read_config=read_config) 203 | for i in range(1, len(split)): 204 | dataset = dataset.concatenate(self.builder.as_dataset( 205 | split=split[i], shuffle_files=training, read_config=read_config)) 206 | else: 207 | dataset = self.builder.as_dataset( 208 | split=split, shuffle_files=training, read_config=read_config) 209 | return dataset 210 | 211 | @property 212 | def num_train_examples(self): 213 | return self.builder.info.splits[self.config.train_split].num_examples 214 | 215 | @property 216 | def num_eval_examples(self): 217 | return self.builder.info.splits[ 218 | self.config.eval_split].num_examples if not self.task_config.get( 219 | 'unbatch', False) else None 220 | 221 | 222 | class TFRecordDataset(Dataset): 223 | """A dataset created from tfrecord files.""" 224 | 225 | def __init__(self, config: ml_collections.ConfigDict): 226 | """Constructs the dataset.""" 227 | super().__init__(config) 228 | self.dataset_cls = tf.data.TFRecordDataset 229 | 230 | def load_dataset(self, input_context, training): 231 | """Load tf.data.Dataset from TFRecord files.""" 232 | if training or self.config.eval_split == 'train': 233 | file_pattern = self.config.train_file_pattern 234 | else: 235 | file_pattern = self.config.val_file_pattern 236 | dataset = tf.data.Dataset.list_files(file_pattern, shuffle=training) 237 | dataset = dataset.interleave( 238 | self.dataset_cls, cycle_length=32, deterministic=not training, 239 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 240 | return dataset 241 | 242 | @abc.abstractmethod 243 | def get_feature_map(self, training): 244 | """Returns feature map(s) for parsing the TFExample. 245 | 246 | Returns a single feature map (a dict) to parse a TFExample. 247 | Returns a tuple of (context feature map, sequence feature map) to parse a 248 | TFSequenceExample. Context features are non-sequence features, i.e. 249 | independent of time/frame. Sequence features have time/frame dimension. 250 | 251 | Args: 252 | training: `bool` of training vs eval mode. 253 | """ 254 | 255 | def parse_example(self, example, training): 256 | """Parse the serialized example into a dictionary of tensors. 257 | 258 | Args: 259 | example: the serialized tf.train.Example or tf.train.SequenceExample. 260 | training: `bool` of training vs eval mode. 261 | 262 | Returns: 263 | a dictionary of feature name to tensors.
264 | """ 265 | feature_map = self.get_feature_map(training) 266 | if isinstance(feature_map, dict): 267 | example = tf.io.parse_single_example(example, feature_map) 268 | else: 269 | context_features, sequence_features = feature_map 270 | example, sequence = tf.io.parse_single_sequence_example( 271 | example, context_features, sequence_features) 272 | example.update(sequence) 273 | 274 | for k in example: 275 | if isinstance(example[k], tf.SparseTensor): 276 | if example[k].dtype == tf.string: 277 | example[k] = tf.sparse.to_dense(example[k], default_value='') 278 | else: 279 | example[k] = tf.sparse.to_dense(example[k], default_value=0) 280 | return example 281 | 282 | @property 283 | def num_train_examples(self): 284 | return self.config.train_num_examples 285 | 286 | @property 287 | def num_eval_examples(self): 288 | return self.config.eval_num_examples if not self.task_config.get( 289 | 'unbatch', False) else None 290 | 291 | 292 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """All registered datasets.""" 17 | 18 | from data import coco # pylint: disable=unused-import 19 | -------------------------------------------------------------------------------- /data/decode_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Utility functions for decoding example into features and labels.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def get_feature_map_for_image(): 22 | return { 23 | 'image/encoded': tf.io.FixedLenFeature((), tf.string), 24 | 'image/source_id': tf.io.FixedLenFeature((), tf.string, ''), 25 | 'image/height': tf.io.FixedLenFeature((), tf.int64, -1), 26 | 'image/width': tf.io.FixedLenFeature((), tf.int64, -1), 27 | 'image/filename': tf.io.FixedLenFeature((), tf.string, ''), 28 | } 29 | 30 | 31 | def get_feature_map_for_object_detection(): 32 | return { 33 | 'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32), 34 | 'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32), 35 | 'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32), 36 | 'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32), 37 | 'image/object/class/label': tf.io.VarLenFeature(tf.int64), 38 | 'image/object/area': tf.io.VarLenFeature(tf.float32), 39 | 'image/object/is_crowd': tf.io.VarLenFeature(tf.int64), 40 | 'image/object/score': tf.io.VarLenFeature(tf.float32), 41 | } 42 | 43 | 44 | def get_feature_map_for_instance_segmentation(): 45 | return { 46 | 'image/object/segmentation': 47 | tf.io.RaggedFeature( 48 | value_key='image/object/segmentation_v', 49 | dtype=tf.float32, 50 | partitions=[ 51 | tf.io.RaggedFeature.RowSplits('image/object/segmentation_sep') # pytype: disable=attribute-error 52 | ]), 53 | } 54 | 55 | 56 | def get_feature_map_for_keypoint_detection(): 57 | return { 58 | 'image/object/keypoints': 59 | tf.io.RaggedFeature( 60 | value_key='image/object/keypoints_v', 61 | dtype=tf.float32, 62 | partitions=[ 63 | tf.io.RaggedFeature.RowSplits('image/object/keypoints_sep') # pytype: disable=attribute-error 64 | ]), 65 | 'image/object/num_keypoints': 66 | tf.io.VarLenFeature(tf.int64), 67 | } 68 | 69 | 70 | def get_feature_map_for_captioning(): 71 | return { 72 | 'image/caption': tf.io.VarLenFeature(tf.string), 73 | } 74 | 75 | 76 | def decode_image(example): 77 | """Decodes the image and set its static shape.""" 78 | image = tf.io.decode_image(example['image/encoded'], channels=3) 79 | image.set_shape([None, None, 3]) 80 | image = tf.image.convert_image_dtype(image, tf.float32) 81 | return image 82 | 83 | 84 | def decode_boxes(example): 85 | """Concat box coordinates in the format of [ymin, xmin, ymax, xmax].""" 86 | xmin = example['image/object/bbox/xmin'] 87 | xmax = example['image/object/bbox/xmax'] 88 | ymin = example['image/object/bbox/ymin'] 89 | ymax = example['image/object/bbox/ymax'] 90 | return tf.stack([ymin, xmin, ymax, xmax], axis=-1) 91 | 92 | 93 | def decode_areas(example): 94 | xmin = example['image/object/bbox/xmin'] 95 | xmax = example['image/object/bbox/xmax'] 96 | ymin = example['image/object/bbox/ymin'] 97 | ymax = example['image/object/bbox/ymax'] 98 | height = tf.cast(example['image/height'], dtype=tf.float32) 99 | width = tf.cast(example['image/width'], dtype=tf.float32) 100 | return tf.cond( 101 | tf.greater(tf.shape(example['image/object/area'])[0], 0), 102 | lambda: example['image/object/area'], 103 | lambda: (xmax - xmin) * (ymax - ymin) * height * width) 104 | 105 | 106 | def decode_is_crowd(example): 107 | return tf.cond( 108 | tf.greater(tf.shape(example['image/object/is_crowd'])[0], 0), 109 | lambda: tf.cast(example['image/object/is_crowd'], dtype=tf.bool), 110 | lambda: tf.zeros_like(example['image/object/class/label'], dtype=tf.bool) 111 | ) 112 | 113 | 114 | def 
decode_scores(example): 115 | return tf.cond( 116 | tf.greater(tf.shape(example['image/object/score'])[0], 0), 117 | lambda: example['image/object/score'], 118 | lambda: tf.ones_like(example['image/object/class/label'], # pylint: disable=g-long-lambda 119 | dtype=tf.float32) 120 | ) 121 | -------------------------------------------------------------------------------- /data/obj365.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Object365 dataset class.""" 17 | 18 | import ml_collections 19 | from data import dataset as dataset_lib 20 | from data import decode_utils 21 | import tensorflow as tf 22 | 23 | 24 | @dataset_lib.DatasetRegistry.register('obj365') 25 | class Obj365Dataset(dataset_lib.TFRecordDataset): 26 | """Dataset for Obj365 tasks.""" 27 | 28 | def __init__(self, config: ml_collections.ConfigDict): 29 | """Constructs the dataset. 30 | 31 | Args: 32 | config: the model config. 33 | """ 34 | super().__init__(config) 35 | 36 | if 'label_shift' in config.dataset: 37 | self.label_shift = config.dataset.label_shift 38 | else: 39 | self.label_shift = 0 40 | 41 | def _get_source_id(self, example): 42 | def _generate_source_id(): 43 | return tf.strings.as_string( 44 | tf.strings.to_hash_bucket_fast(example['image/encoded'], 2**63 - 1)) 45 | 46 | if self.config.get('regenerate_source_id', False): 47 | source_id = _generate_source_id() 48 | else: 49 | source_id = tf.cond( 50 | tf.greater(tf.strings.length(example['image/source_id']), 51 | 0), lambda: example['image/source_id'], 52 | _generate_source_id) 53 | return source_id 54 | 55 | def _decode_masks(self, example): 56 | """Decode a set of PNG masks to the tf.float32 tensors.""" 57 | def _decode_png_mask(png_bytes): 58 | mask = tf.squeeze( 59 | tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1) 60 | mask = tf.cast(mask, dtype=tf.float32) 61 | mask.set_shape([None, None]) 62 | return mask 63 | 64 | height = example['image/height'] 65 | width = example['image/width'] 66 | masks = example['image/object/mask'] 67 | return tf.cond( 68 | tf.greater(tf.size(masks), 0), 69 | lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32), 70 | lambda: tf.zeros([0, height, width], dtype=tf.float32)) 71 | 72 | def extract(self, example, training): 73 | """Extracts needed features & annotations into a flat dictionary. 74 | 75 | Note: 76 | - label starts at 1 instead of 0, as 0 is reserved for special use 77 | (such as padding). 78 | - coordinates (e.g. bbox) are (normalized to be) in [0, 1]. 79 | 80 | Args: 81 | example: `dict` of raw features. 82 | training: `bool` of training vs eval mode. 83 | 84 | Returns: 85 | example: `dict` of relevant features and labels. 
86 | """ 87 | bbox = decode_utils.decode_boxes(example) 88 | example = { 89 | 'image': decode_utils.decode_image(example), 90 | 'image/id': self._get_source_id(example), 91 | 'bbox': bbox, 92 | 'is_crowd': decode_utils.decode_is_crowd(example), 93 | 'label': example['image/object/class/label'] + self.label_shift, 94 | 'area': decode_utils.decode_areas(example), 95 | } 96 | return example 97 | 98 | @property 99 | def num_train_examples(self): 100 | return { 101 | 'obj365': 1662289, 102 | 'obj365v1': 608606, 103 | 'obj365v2': 1662289 104 | }[self.config.dataset_name] 105 | 106 | @property 107 | def num_eval_examples(self): 108 | return { 109 | 'obj365': 80000, 110 | 'obj365v1': 30000, 111 | 'obj365v2': 80000 112 | }[self.config.dataset_name] 113 | 114 | @property 115 | def num_classes(self): 116 | return 365 117 | 118 | def get_feature_map(self, training): 119 | """Returns feature map for parsing the TFExample.""" 120 | del training 121 | image_feature_map = decode_utils.get_feature_map_for_image() 122 | detection_feature_map = decode_utils.get_feature_map_for_object_detection() 123 | feature_map = {**image_feature_map, **detection_feature_map} 124 | if self.config.get('include_mask', False): 125 | feature_map.update({'image/object/mask': tf.io.VarLenFeature(tf.string)}) 126 | return feature_map 127 | -------------------------------------------------------------------------------- /data/recognition.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Image classification datasets.""" 17 | 18 | from data import dataset as dataset_lib 19 | import tensorflow as tf 20 | 21 | 22 | @dataset_lib.DatasetRegistry.register('object_recognition') 23 | class ImageDataset(dataset_lib.TFDSDataset): 24 | """Dataset for image classification datasets.""" 25 | 26 | def extract(self, example, training): 27 | """Extracts needed features & annotations into a flat dictionary. 28 | 29 | Args: 30 | example: `dict` of raw features. 31 | training: `bool` of training vs eval mode. 32 | 33 | Returns: 34 | example: `dict` of relevant features and labels. 35 | """ 36 | image = example['image'] 37 | if image.shape.rank == 2 or image.shape[-1] == 1: 38 | image = tf.image.grayscale_to_rgb(image) 39 | if 'label' in example: 40 | label = example['label'] 41 | else: 42 | label = tf.zeros([], dtype=tf.int32) 43 | return {'image': image, 44 | 'label': label} 45 | 46 | @property 47 | def num_classes(self): 48 | return self.builder.info.features['label'].num_classes 49 | 50 | 51 | -------------------------------------------------------------------------------- /data/scripts/create_davis_tfrecord.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Generate tfrecord for DAVIS2017.""" 17 | 18 | import collections 19 | import os 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | import numpy as np 25 | from PIL import Image 26 | import tensorflow as tf 27 | 28 | flags.DEFINE_string('split', 'train', 'train or val') 29 | flags.DEFINE_integer('shards', 25, '') 30 | flags.DEFINE_integer('num_frames', 3, '') 31 | flags.DEFINE_string('data_dir', '', '') 32 | flags.DEFINE_string('output_dir', '', '') 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | 37 | # There is one video ('tennis' in train) that labels one object as 255. 38 | def handle_mislabel_255(arr): 39 | maxm = np.max(arr[arr != 255]) 40 | arr[arr == 255] = maxm + 1 41 | return arr 42 | 43 | 44 | def generate_tf_sequence_example( 45 | image_list, segmentation_list, filename, video_name): 46 | """Create tf sequence example.""" 47 | num_frames = len(image_list) 48 | assert len(image_list) == len(segmentation_list) 49 | height, width, c = image_list[0].shape 50 | assert c == 3 51 | frame_id = filename.split('.')[0] 52 | 53 | example_proto = tf.train.SequenceExample( 54 | context=tf.train.Features( 55 | feature={ 56 | 'image/format': 57 | tf.train.Feature( 58 | bytes_list=tf.train.BytesList( 59 | value=[bytes('png', 'utf-8')])), 60 | 'image/channels': 61 | tf.train.Feature(int64_list=tf.train.Int64List(value=[c])), 62 | 'image/height': 63 | tf.train.Feature( 64 | int64_list=tf.train.Int64List(value=[height])), 65 | 'image/width': 66 | tf.train.Feature( 67 | int64_list=tf.train.Int64List(value=[width])), 68 | 'video/name': 69 | tf.train.Feature( 70 | bytes_list=tf.train.BytesList( 71 | value=[bytes(video_name, 'utf-8')])), 72 | 'video/frame_id': 73 | tf.train.Feature( 74 | bytes_list=tf.train.BytesList( 75 | value=[bytes(frame_id, 'utf-8')])), 76 | 'video/num_frames': 77 | tf.train.Feature( 78 | int64_list=tf.train.Int64List(value=[num_frames])), 79 | }), 80 | feature_lists=tf.train.FeatureLists( 81 | feature_list={ 82 | 'video/frames': 83 | tf.train.FeatureList(feature=[ 84 | tf.train.Feature( 85 | bytes_list=tf.train.BytesList( 86 | value=[tf.io.encode_png(image).numpy()])) 87 | for image in image_list 88 | ]), 89 | 'video/segmentations': 90 | tf.train.FeatureList(feature=[ 91 | tf.train.Feature( 92 | bytes_list=tf.train.BytesList(value=[ 93 | tf.io.encode_png(tf.expand_dims(seg, -1)).numpy() 94 | ])) for seg in segmentation_list 95 | ]), 96 | })) 97 | 98 | return example_proto.SerializeToString() 99 | 100 | 101 | def main(unused_argv): 102 | split = FLAGS.split 103 | data_dir = FLAGS.data_dir 104 | images_dir = os.path.join(data_dir, 'JPEGImages/480p') 105 | annotation_dir = os.path.join(data_dir, 'Annotations_unsupervised/480p/') 106 | video_names = [ 107 | s.strip() for s in tf.io.gfile.GFile( 108 | os.path.join(data_dir, f'ImageSets/2017/{split}.txt')).readlines()] 109 | 
110 | num_frames = FLAGS.num_frames 111 | shards = FLAGS.shards 112 | output_dir = FLAGS.output_dir 113 | writers = [ 114 | tf.io.TFRecordWriter(os.path.join( 115 | output_dir, 116 | f'{split}_{num_frames}-{i:05d}-of-{shards:05d}.tfrecord')) 117 | for i in range(shards) 118 | ] 119 | 120 | k = 0 121 | for i, video_name in enumerate(video_names): 122 | image_filenames = tf.io.gfile.listdir( 123 | os.path.join(images_dir, video_name)) 124 | ann_filenames = tf.io.gfile.listdir( 125 | os.path.join(annotation_dir, video_name)) 126 | 127 | all_images = collections.deque( 128 | maxlen=num_frames if num_frames > 0 else None) 129 | all_segs = collections.deque( 130 | maxlen=num_frames if num_frames > 0 else None) 131 | for j, (image_f, ann_f) in enumerate(zip(image_filenames, ann_filenames)): 132 | logging.info('%s, %s', video_name, image_f) 133 | 134 | # load the image. 135 | image = np.asarray( 136 | Image.open( 137 | tf.io.gfile.GFile( 138 | os.path.join(images_dir, video_name, image_f), 'rb'))) 139 | all_images.append(image) 140 | 141 | # load the segmentations. 142 | data = np.array( 143 | Image.open( 144 | tf.io.gfile.GFile( 145 | os.path.join(annotation_dir, video_name, ann_f), 'rb'))) 146 | data = handle_mislabel_255(data) 147 | all_segs.append(data) 148 | 149 | if j >= num_frames - 1 and num_frames > 0: 150 | serialized_example = generate_tf_sequence_example( 151 | all_images, all_segs, image_f, video_name) 152 | writers[k % shards].write(serialized_example) 153 | k += 1 154 | 155 | # Write all frames out in the same example. 156 | if num_frames <= 0: 157 | serialized_example = generate_tf_sequence_example( 158 | all_images, all_segs, image_filenames[-1], video_name) 159 | writers[k % shards].write(serialized_example) 160 | k += 1 161 | 162 | for writer in writers: 163 | writer.close() 164 | 165 | 166 | if __name__ == '__main__': 167 | app.run(main) 168 | -------------------------------------------------------------------------------- /data/scripts/create_kittistep_tfrecord.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Generate tfrecord for kitti-step.""" 17 | 18 | import collections 19 | import os 20 | from typing import Optional 21 | 22 | from absl import app 23 | from absl import flags 24 | from absl import logging 25 | import numpy as np 26 | from PIL import Image 27 | import tensorflow as tf 28 | 29 | flags.DEFINE_string('split', 'train', '') 30 | flags.DEFINE_integer('shards', 25, '') 31 | flags.DEFINE_integer('num_frames', 3, '') 32 | flags.DEFINE_string('raw_image_dir', '', '') 33 | flags.DEFINE_string('raw_ann_dir', '', '') 34 | flags.DEFINE_string('output_dir', '', '') 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | _INSTANCE_LABEL_DIVISOR = 1000 39 | _ENCODED_INSTANCE_LABEL_DIVISOR = 256 40 | 41 | 42 | def _decode_panoptic_map(panoptic_map_path: str) -> Optional[str]: 43 | """Decodes the panoptic map from encoded image file. 44 | 45 | 46 | Args: 47 | panoptic_map_path: Path to the panoptic map image file. 48 | 49 | Returns: 50 | Panoptic map as an encoded int32 numpy array bytes or None if not existing. 51 | """ 52 | if not tf.io.gfile.exists(panoptic_map_path): 53 | return None 54 | with tf.io.gfile.GFile(panoptic_map_path, 'rb') as f: 55 | panoptic_map = np.array(Image.open(f)).astype(np.int32) 56 | semantic_map = panoptic_map[:, :, 0] 57 | instance_map = ( 58 | panoptic_map[:, :, 1] * _ENCODED_INSTANCE_LABEL_DIVISOR + 59 | panoptic_map[:, :, 2]) 60 | panoptic_map = semantic_map * _INSTANCE_LABEL_DIVISOR + instance_map 61 | return panoptic_map.tobytes() 62 | 63 | 64 | def generate_tf_sequence_example( 65 | image_list, panoptic_map_list, filename, video_name): 66 | """Create tf sequence example.""" 67 | assert len(image_list) == len(panoptic_map_list) 68 | height, width, c = image_list[0].shape 69 | assert c == 3 70 | frame_id = filename.split('.')[0] 71 | 72 | example_proto = tf.train.SequenceExample( 73 | context=tf.train.Features( 74 | feature={ 75 | 'image/filename': 76 | tf.train.Feature( 77 | bytes_list=tf.train.BytesList( 78 | value=[bytes(filename, 'utf-8')])), 79 | 'image/format': 80 | tf.train.Feature( 81 | bytes_list=tf.train.BytesList( 82 | value=[bytes('png', 'utf-8')])), 83 | 'image/channels': 84 | tf.train.Feature(int64_list=tf.train.Int64List(value=[c])), 85 | 'image/height': 86 | tf.train.Feature( 87 | int64_list=tf.train.Int64List(value=[height])), 88 | 'image/width': 89 | tf.train.Feature( 90 | int64_list=tf.train.Int64List(value=[width])), 91 | 'image/segmentation/class/format': 92 | tf.train.Feature( 93 | bytes_list=tf.train.BytesList( 94 | value=[bytes('raw', 'utf-8')])), 95 | 'video/sequence_id': 96 | tf.train.Feature( 97 | bytes_list=tf.train.BytesList( 98 | value=[bytes(video_name, 'utf-8')])), 99 | 'video/frame_id': 100 | tf.train.Feature( 101 | bytes_list=tf.train.BytesList( 102 | value=[bytes(frame_id, 'utf-8')])), 103 | }), 104 | feature_lists=tf.train.FeatureLists( 105 | feature_list={ 106 | 'image/encoded_list': 107 | tf.train.FeatureList(feature=[ 108 | tf.train.Feature( 109 | bytes_list=tf.train.BytesList( 110 | value=[tf.io.encode_png(image).numpy()])) 111 | for image in image_list 112 | ]), 113 | 'image/segmentation/class/encoded_list': 114 | tf.train.FeatureList(feature=[ 115 | tf.train.Feature( 116 | bytes_list=tf.train.BytesList(value=[seg])) 117 | for seg in panoptic_map_list 118 | ]), 119 | })) 120 | 121 | return example_proto.SerializeToString() 122 | 123 | 124 | def main(unused_argv): 125 | split = FLAGS.split 126 | raw_image_dir = FLAGS.raw_image_dir 127 | 
raw_ann_dir = FLAGS.raw_ann_dir 128 | num_frames = FLAGS.num_frames 129 | video_names = tf.io.gfile.listdir(os.path.join(raw_image_dir, split)) 130 | if num_frames <= 0: 131 | assert FLAGS.shards <= 0 132 | shards = FLAGS.shards if FLAGS.shards > 0 else len(video_names) 133 | tf.io.gfile.makedirs(FLAGS.output_dir) 134 | writers = [ 135 | tf.io.TFRecordWriter(os.path.join( 136 | FLAGS.output_dir, 137 | f'{split}_{num_frames}-{i:05d}-of-{shards:05d}.tfrecord')) 138 | for i in range(shards) 139 | ] 140 | 141 | k = 0 142 | for i, video_name in enumerate(video_names): 143 | frame_filenames = tf.io.gfile.listdir( 144 | os.path.join(raw_image_dir, split, video_name)) 145 | 146 | # If larger than 3 frames, we only save one example per video, so only 147 | # getting the first n frames of that video. 148 | if num_frames > 3: 149 | frame_filenames = frame_filenames[:num_frames] 150 | 151 | all_images = collections.deque( 152 | maxlen=num_frames if num_frames > 0 else None) 153 | all_panoptic_maps = collections.deque( 154 | maxlen=num_frames if num_frames > 0 else None) 155 | for j, fn in enumerate(frame_filenames): 156 | logging.info('%s, %s', video_name, fn) 157 | 158 | # load the image. 159 | image = np.asarray( 160 | Image.open( 161 | tf.io.gfile.GFile( 162 | os.path.join(raw_image_dir, split, video_name, fn), 'rb'))) 163 | all_images.append(image) 164 | 165 | # load and decode the panoptic map. 166 | panoptic_map = _decode_panoptic_map( 167 | os.path.join(raw_ann_dir, split, video_name, fn)) 168 | all_panoptic_maps.append(panoptic_map) 169 | 170 | if j >= num_frames - 1 and num_frames > 0: 171 | serialized_example = generate_tf_sequence_example( 172 | all_images, all_panoptic_maps, fn, video_name) 173 | writers[k % shards].write(serialized_example) 174 | k += 1 175 | 176 | # Write all frames out in the same example. 177 | if num_frames <= 0: 178 | serialized_example = generate_tf_sequence_example( 179 | all_images, all_panoptic_maps, frame_filenames[-1], video_name) 180 | writers[k % shards].write(serialized_example) 181 | k += 1 182 | 183 | for writer in writers: 184 | writer.close() 185 | 186 | 187 | if __name__ == '__main__': 188 | app.run(main) 189 | -------------------------------------------------------------------------------- /data/scripts/merge_coco_json_tfrecord.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Merge COCO annotation json file with tfrecord files. 17 | 18 | This is used to merge the generated object detection annotations with the 19 | groundtruth tfrecord files, to generate new tfrecord files that can be used 20 | for downstream task inference such as instance segmentation and keypoint 21 | detection. 
22 | 23 | When merging, the image features (image/encoded, image/source_id, etc) are kept 24 | the same. Object features related to detection (bbox, is_crowd, label) are 25 | populated from the json annotation file. Downstream task features ( 26 | segmentation, keypoint) are padded. 27 | """ 28 | 29 | import collections 30 | import json 31 | import os 32 | 33 | from absl import app 34 | from absl import flags 35 | from absl import logging 36 | from data.scripts import tfrecord_lib 37 | import tensorflow as tf 38 | 39 | flags.DEFINE_string('tfrecord_path', '', 'Tfrecord file pattern.') 40 | flags.DEFINE_string('annotation_path', '', 'JSON annotation file path.') 41 | flags.DEFINE_string('output_dir', None, 'Output directory') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | COPY_FEATURE_LIST = ['image/encoded', 'image/source_id', 'image/format', 47 | 'image/filename', 'image/height', 'image/width', 48 | 'image/key/sha256', 'image/caption'] 49 | 50 | 51 | def load_instance_annotations(annotation_path): 52 | """Load instance annotations. 53 | 54 | Args: 55 | annotation_path: str. Path to the annotation file. 56 | 57 | Returns: 58 | category_id_to_name_map: dict of category ids to category names. 59 | img_to_ann: a dict of image_id to annotation. 60 | """ 61 | with tf.io.gfile.GFile(annotation_path, 'r') as f: 62 | annotations = json.load(f) 63 | 64 | img_to_ann = collections.defaultdict(list) 65 | for ann in annotations['annotations']: 66 | image_id = ann['image_id'] 67 | img_to_ann[image_id].append(ann) 68 | 69 | category_id_to_name_map = dict( 70 | (element['id'], element['name']) for element in annotations['categories']) 71 | 72 | return category_id_to_name_map, img_to_ann 73 | 74 | 75 | def coco_annotations_to_lists(obj_annotations, id_to_name_map): 76 | """Converts COCO annotations to feature lists. 77 | 78 | Args: 79 | obj_annotations: a list of object annotations. 80 | id_to_name_map: category id to category name map. 81 | 82 | Returns: 83 | a dict of list features. 84 | """ 85 | 86 | data = dict((k, list()) for k in [ 87 | 'xmin', 'xmax', 'ymin', 'ymax', 'is_crowd', 'category_id', 88 | 'category_names', 'area', 'score']) 89 | 90 | for ann in obj_annotations: 91 | (x, y, width, height) = tuple(ann['bbox']) 92 | if width > 0. and height > 0.: # Only keep valid boxes. 93 | data['xmin'].append(float(x)) 94 | data['xmax'].append(float(x + width)) 95 | data['ymin'].append(float(y)) 96 | data['ymax'].append(float(y + height)) 97 | data['is_crowd'].append(ann['iscrowd']) 98 | category_id = int(ann['category_id']) 99 | data['category_id'].append(category_id) 100 | data['category_names'].append(id_to_name_map[category_id].encode('utf8')) 101 | data['area'].append(float(height * width)) 102 | data['score'].append(ann['score']) 103 | 104 | return data 105 | 106 | 107 | def obj_annotations_to_feature_dict(obj_annotations, id_to_name_map): 108 | """Convert COCO annotations to an encoded feature dict. 109 | 110 | Args: 111 | obj_annotations: a list of object annotations. 112 | id_to_name_map: category id to category name map. 113 | 114 | Returns: 115 | a dict of tf features, and the number of instances. 
116 | """ 117 | 118 | data = coco_annotations_to_lists(obj_annotations, id_to_name_map) 119 | feature_dict = { 120 | 'image/object/bbox/xmin': 121 | tfrecord_lib.convert_to_feature( 122 | data['xmin'], value_type='float_list'), 123 | 'image/object/bbox/xmax': 124 | tfrecord_lib.convert_to_feature( 125 | data['xmax'], value_type='float_list'), 126 | 'image/object/bbox/ymin': 127 | tfrecord_lib.convert_to_feature( 128 | data['ymin'], value_type='float_list'), 129 | 'image/object/bbox/ymax': 130 | tfrecord_lib.convert_to_feature( 131 | data['ymax'], value_type='float_list'), 132 | 'image/object/class/text': 133 | tfrecord_lib.convert_to_feature( 134 | data['category_names'], value_type='bytes_list'), 135 | 'image/object/class/label': 136 | tfrecord_lib.convert_to_feature( 137 | data['category_id'], value_type='int64_list'), 138 | 'image/object/is_crowd': 139 | tfrecord_lib.convert_to_feature( 140 | data['is_crowd'], value_type='int64_list'), 141 | 'image/object/area': 142 | tfrecord_lib.convert_to_feature( 143 | data['area'], value_type='float_list'), 144 | 'image/object/score': 145 | tfrecord_lib.convert_to_feature( 146 | data['score'], value_type='float_list'), 147 | } 148 | return feature_dict, len(data['xmin']) 149 | 150 | 151 | def update_tfrecord_file(tfrecord_path, image_to_anns, category_id_to_name_map, 152 | output_path): 153 | """Merge one tfrecord file with annotations. 154 | 155 | Args: 156 | tfrecord_path: string, the input tfrecord path. 157 | image_to_anns: a dict of image_id to annotation. 158 | category_id_to_name_map: dict of category ids to category names. 159 | output_path: string, the output tfrecord file path. 160 | """ 161 | dataset = tf.data.TFRecordDataset(tfrecord_path) 162 | 163 | with tf.io.TFRecordWriter(output_path) as writer: 164 | for serialized_ex in dataset: 165 | ex = tf.train.Example() 166 | ex.ParseFromString(serialized_ex.numpy()) 167 | 168 | image_id = int(ex.features.feature['image/source_id'].bytes_list.value[0]) 169 | anns = image_to_anns[image_id] 170 | 171 | # Copy the following features from current tf example. 172 | feature_dict = {} 173 | for f in COPY_FEATURE_LIST: 174 | feature_dict[f] = ex.features.feature[f] 175 | 176 | # Populate the object detection features from json annotations. 177 | det_features, num_bbox = obj_annotations_to_feature_dict( 178 | anns, category_id_to_name_map) 179 | feature_dict.update(det_features) 180 | 181 | # Pad the segmentation and keypoint features. 
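# The '*_sep' entries below are read back as RaggedFeature row splits (see
# data/decode_utils.py); a row-splits vector for N rows has N + 1 entries, so
# [0] * (num_bbox + 1) encodes num_bbox empty segmentation/keypoint rows.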
182 | feature_dict.update({ 183 | 'image/object/segmentation_v': 184 | tfrecord_lib.convert_to_feature([], value_type='float_list'), 185 | 'image/object/segmentation_sep': 186 | tfrecord_lib.convert_to_feature( 187 | [0] * (num_bbox + 1), value_type='int64_list'), 188 | 'image/object/keypoints_v': 189 | tfrecord_lib.convert_to_feature([], value_type='float_list'), 190 | 'image/object/keypoints_sep': 191 | tfrecord_lib.convert_to_feature( 192 | [0] * (num_bbox + 1), value_type='int64_list'), 193 | 'image/object/num_keypoints': 194 | tfrecord_lib.convert_to_feature( 195 | [0] * num_bbox, value_type='int64_list'), 196 | }) 197 | 198 | new_ex = tf.train.Example( 199 | features=tf.train.Features(feature=feature_dict)).SerializeToString() 200 | writer.write(new_ex) 201 | 202 | 203 | def main(unused_argv): 204 | category_id_to_name_map, image_to_anns = load_instance_annotations( 205 | FLAGS.annotation_path) 206 | 207 | tfrecord_lib.check_and_make_dir(FLAGS.output_dir) 208 | 209 | tfrecord_paths = tf.io.gfile.glob(FLAGS.tfrecord_path) 210 | for tfrecord_path in tfrecord_paths: 211 | output_path = os.path.join(FLAGS.output_dir, 212 | os.path.basename(tfrecord_path)) 213 | update_tfrecord_file(tfrecord_path, image_to_anns, category_id_to_name_map, 214 | output_path) 215 | logging.info('Finished writing file %s', output_path) 216 | 217 | 218 | if __name__ == '__main__': 219 | app.run(main) 220 | -------------------------------------------------------------------------------- /data/scripts/tfrecord_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Helper functions for creating TFRecord datasets.""" 17 | 18 | import hashlib 19 | import io 20 | import itertools 21 | # copybara:insert import multiprocessing 22 | 23 | from absl import logging 24 | import numpy as np 25 | from PIL import Image 26 | import tensorflow as tf 27 | 28 | 29 | 30 | LOG_EVERY = 100 31 | 32 | 33 | def convert_to_feature(value, value_type=None): 34 | """Converts the given python object to a tf.train.Feature. 35 | 36 | Args: 37 | value: int, float, bytes or a list of them. 38 | value_type: optional, if specified, forces the feature to be of the given 39 | type. Otherwise, type is inferred automatically. Can be one of 40 | ['bytes', 'int64', 'float', 'bytes_list', 'int64_list', 'float_list'] 41 | 42 | Returns: 43 | feature: A tf.train.Feature object. 44 | """ 45 | 46 | if value_type is None: 47 | 48 | element = value[0] if isinstance(value, list) else value 49 | 50 | if isinstance(element, bytes): 51 | value_type = 'bytes' 52 | 53 | elif isinstance(element, (int, np.integer)): 54 | value_type = 'int64' 55 | 56 | elif isinstance(element, (float, np.floating)): 57 | value_type = 'float' 58 | 59 | else: 60 | raise ValueError('Cannot convert type {} to feature'. 
61 | format(type(element))) 62 | 63 | if isinstance(value, list): 64 | value_type = value_type + '_list' 65 | 66 | if value_type == 'int64': 67 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 68 | 69 | elif value_type == 'int64_list': 70 | value = np.asarray(value).astype(np.int64).reshape(-1) 71 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 72 | 73 | elif value_type == 'float': 74 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 75 | 76 | elif value_type == 'float_list': 77 | value = np.asarray(value).astype(np.float32).reshape(-1) 78 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 79 | 80 | elif value_type == 'bytes': 81 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 82 | 83 | elif value_type == 'bytes_list': 84 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 85 | 86 | else: 87 | raise ValueError('Unknown value_type parameter - {}'.format(value_type)) 88 | 89 | 90 | def image_info_to_feature_dict(height, width, filename, image_id, 91 | encoded_str, encoded_format): 92 | """Convert image information to a dict of features.""" 93 | 94 | key = hashlib.sha256(encoded_str).hexdigest() 95 | 96 | return { 97 | 'image/height': convert_to_feature(height), 98 | 'image/width': convert_to_feature(width), 99 | 'image/filename': convert_to_feature(filename.encode('utf8')), 100 | 'image/source_id': convert_to_feature(str(image_id).encode('utf8')), 101 | 'image/key/sha256': convert_to_feature(key.encode('utf8')), 102 | 'image/encoded': convert_to_feature(encoded_str), 103 | 'image/format': convert_to_feature(encoded_format.encode('utf8')), 104 | } 105 | 106 | 107 | def read_image(image_path): 108 | pil_image = Image.open(image_path) 109 | return np.asarray(pil_image) 110 | 111 | 112 | def encode_mask_as_png(mask): 113 | pil_image = Image.fromarray(mask) 114 | output_io = io.BytesIO() 115 | pil_image.save(output_io, format='PNG') 116 | return output_io.getvalue() 117 | 118 | 119 | def write_tf_record_dataset(output_path, annotation_iterator, 120 | process_func, num_shards, 121 | multiple_processes=None, unpack_arguments=True): 122 | """Iterates over annotations, processes them and writes into TFRecords. 123 | 124 | Args: 125 | output_path: The prefix path to create TF record files. 126 | annotation_iterator: An iterator of tuples containing details about the 127 | dataset. 128 | process_func: A function which takes the elements from the tuples of 129 | annotation_iterator as arguments and returns a tuple of (tf.train.Example, 130 | int). The integer indicates the number of annotations that were skipped. 131 | num_shards: int, the number of shards to write for the dataset. 132 | multiple_processes: integer, the number of parallel processes to 133 | use. If None, uses multi-processing with number of processes equal to 134 | `os.cpu_count()`, which is Python's default behavior. If set to 0, 135 | multi-processing is disabled. 136 | 137 | unpack_arguments: 138 | Whether to unpack the tuples from annotation_iterator as individual 139 | arguments to the process func or to pass the returned value as it is. 140 | 141 | Returns: 142 | num_skipped: The total number of skipped annotations.
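  Example (illustrative sketch; `annotations` and `create_tf_example` stand in
  for a caller-provided iterator and process function):

    num_skipped = write_tf_record_dataset(
        output_path='/tmp/coco/tfrecord/train',
        annotation_iterator=annotations,
        process_func=create_tf_example,
        num_shards=32)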
143 | """ 144 | 145 | writers = [ 146 | tf.io.TFRecordWriter( 147 | output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards)) 148 | for i in range(num_shards) 149 | ] 150 | 151 | total_num_annotations_skipped = 0 152 | 153 | if multiple_processes is None or multiple_processes > 0: 154 | 155 | pool = multiprocessing.Pool(processes=multiple_processes) 156 | if unpack_arguments: 157 | tf_example_iterator = pool.starmap(process_func, annotation_iterator) 158 | else: 159 | tf_example_iterator = pool.imap(process_func, annotation_iterator) 160 | else: 161 | if unpack_arguments: 162 | tf_example_iterator = itertools.starmap(process_func, annotation_iterator) 163 | else: 164 | tf_example_iterator = map(process_func, annotation_iterator) 165 | 166 | for idx, (tf_example, num_annotations_skipped) in enumerate( 167 | tf_example_iterator): 168 | if idx % LOG_EVERY == 0: 169 | logging.info('On image %d', idx) 170 | 171 | total_num_annotations_skipped += num_annotations_skipped 172 | writers[idx % num_shards].write(tf_example.SerializeToString()) 173 | 174 | if multiple_processes is None or multiple_processes > 0: 175 | pool.close() 176 | pool.join() 177 | 178 | for writer in writers: 179 | writer.close() 180 | 181 | logging.info('Finished writing, skipped %d annotations.', 182 | total_num_annotations_skipped) 183 | return total_num_annotations_skipped 184 | 185 | 186 | def check_and_make_dir(directory): 187 | """Creates the directory if it doesn't exist.""" 188 | if not tf.io.gfile.isdir(directory): 189 | tf.io.gfile.makedirs(directory) 190 | -------------------------------------------------------------------------------- /data/text.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Text datasets.""" 17 | 18 | from data import dataset as dataset_lib 19 | 20 | 21 | @dataset_lib.DatasetRegistry.register('text') 22 | class TextDataset(dataset_lib.TFDSDataset): 23 | """Dataset.""" 24 | 25 | def extract(self, example, training): 26 | """Extracts needed features & annotations into a flat dictionary. 27 | 28 | Args: 29 | example: `dict` of raw features. 30 | training: `bool` of training vs eval mode. 
31 | 32 | Returns: 33 | a sequence 34 | """ 35 | if self.config.tfds_name.startswith('wikipedia'): 36 | text = example['title'] + '\n\n' + example['text'] 37 | elif self.config.tfds_name.startswith('wmt'): 38 | src, dst = self.config.tfds_name.split('/')[1].split('-') 39 | if training: 40 | text = '[src] ' + example[src] + ' [dst] ' + example[dst] 41 | else: 42 | text = '[src] ' + example[src] + ' [dst] ' 43 | else: 44 | text = example['text'] 45 | return {'text': text} 46 | -------------------------------------------------------------------------------- /data/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tokenizer library.""" 17 | 18 | import abc 19 | 20 | import tensorflow as tf 21 | import tensorflow_text as tf_text 22 | 23 | 24 | class Tokenizer(abc.ABC): 25 | """Tokenizer base class.""" 26 | 27 | def __init__(self): 28 | pass 29 | 30 | @property 31 | @abc.abstractmethod 32 | def vocab_size(self): 33 | """Vocab size.""" 34 | 35 | @abc.abstractmethod 36 | def string_to_ids(self, string): 37 | """Tokenize a single string.""" 38 | 39 | @abc.abstractmethod 40 | def strings_to_ids(self, strings): 41 | """Tokenize a batch of strings.""" 42 | 43 | @abc.abstractmethod 44 | def ids_to_strings(self, ids, ids_len): 45 | """Detokenize a batch of ids.""" 46 | 47 | 48 | class SPTokenizer(Tokenizer): 49 | """Sentence Piece Tokenizer.""" 50 | 51 | def __init__(self, model_path, add_bos=False, add_eos=False): 52 | super(SPTokenizer, self).__init__() 53 | self.model_path = model_path 54 | with tf.io.gfile.GFile(model_path, "rb") as f: 55 | model = f.read() 56 | self.tokenizer = tf_text.SentencepieceTokenizer(model, 57 | out_type=tf.string, 58 | add_bos=add_bos, 59 | add_eos=add_eos) 60 | 61 | @property 62 | def vocab_size(self): 63 | return int(self.tokenizer.vocab_size().numpy()) 64 | 65 | def string_to_ids(self, string): 66 | tokens = self.tokenizer.tokenize(string) 67 | pieces = self.tokenizer.string_to_id(tokens) 68 | return tf.cast(pieces, tf.int64) 69 | 70 | def strings_to_ids(self, strings): 71 | return self.string_to_ids(strings) 72 | 73 | def ids_to_strings(self, ids, ids_len): 74 | return self.tokenizer.detokenize(ids) 75 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Evaluation metrics for FID score."""
17 | 
18 | import dataclasses
19 | 
20 | from absl import logging
21 | import numpy as np
22 | import tensorflow as tf
23 | import tensorflow_gan as tfgan
24 | 
25 | 
26 | def get_stats_for_fid(act):
27 |   """Get mean and std statistics from activations for FID computation."""
28 |   if act.ndim != 2:
29 |     raise ValueError("Expected input to have 2 axes")
30 |   act = np.asarray(act, dtype=np.float64)
31 |   mu = np.mean(act, axis=0)
32 |   sigma = np.cov(act, rowvar=False)
33 |   return mu, sigma
34 | 
35 | 
36 | def _symmetric_matrix_square_root(mat, eps=1e-10):
37 |   """Compute square root of a symmetric matrix."""
38 |   u, s, vt = np.linalg.svd(mat, hermitian=True)
39 |   si = np.where(s < eps, s, np.sqrt(s))
40 |   return u.dot(np.diag(si)).dot(vt)
41 | 
42 | 
43 | def _trace_sqrt_product(sigma, sigma_v):
44 |   """Find the trace of the positive sqrt of product of covariance matrices."""
45 |   sqrt_sigma = _symmetric_matrix_square_root(sigma)
46 |   sqrt_a_sigmav_a = sqrt_sigma.dot(sigma_v).dot(sqrt_sigma)
47 |   return _symmetric_matrix_square_root(sqrt_a_sigmav_a).trace()
48 | 
49 | 
50 | def get_fid_score(mu1, sigma1, mu2, sigma2):
51 |   """FID score."""
52 |   if mu1.shape != mu2.shape:
53 |     raise ValueError("means should have the same shape")
54 |   dim, = mu1.shape
55 |   if not sigma1.shape == sigma2.shape == (dim, dim):
56 |     raise ValueError("covariance matrices should be the same shape (d, d)")
57 |   mu1 = np.asarray(mu1, dtype=np.float64)
58 |   mu2 = np.asarray(mu2, dtype=np.float64)
59 |   sigma1 = np.asarray(sigma1, dtype=np.float64)
60 |   sigma2 = np.asarray(sigma2, dtype=np.float64)
61 |   return (np.square(mu1 - mu2).sum() + sigma1.trace() + sigma2.trace() -
62 |           2 * _trace_sqrt_product(sigma1, sigma2))
63 | 
64 | 
65 | @dataclasses.dataclass
66 | class TFGANMetricEvaluator:
67 |   """A wrapper class for tensorflow-gan evaluation."""
68 |   dataset_name: str = "cifar10"
69 |   image_size: int = -1
70 |   inceptionv3_input_size: int = 299
71 |   activations_key: str = "pool_3"
72 |   resize_method: str = "bilinear"
73 |   antialias: bool = False
74 | 
75 |   def __post_init__(self):
76 |     self.all_logits_real = []
77 |     self.all_pool3_real = []
78 |     self.all_logits_gen = []
79 |     self.all_pool3_gen = []
80 |     self.dataset_stats_mean, self.dataset_stats_cov = self.load_fid_stats()
81 | 
82 |   def load_fid_stats(self, stats_path=None):
83 |     """Load the pre-computed dataset statistics."""
84 |     logging.info("loading FID stats for datasets %s", self.dataset_name)
85 |     # TODO(iamtingchen): provide stat path via config dict.
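    # The branches below expect different formats: for cifar10 the path points
    # at a raw .npy of Inception activations, which falls through to the end of
    # this function and is reduced with get_stats_for_fid, while the other
    # datasets load .npz archives that already hold "mu" plus "sigma" or "cov"
    # arrays. As the TODO above notes, stats_path still has to be supplied.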
86 | if self.dataset_name == "cifar10": 87 | filename = "{}/cifar10_stats_real.npy".format(stats_path) 88 | elif self.dataset_name == "downsampled_imagenet/64x64": 89 | filename = "{}/imagenet64_stats_real.npz".format(stats_path) 90 | with tf.io.gfile.GFile(filename, "rb") as fin: 91 | stats_real = np.load(fin) 92 | return stats_real["mu"], stats_real["sigma"] 93 | elif self.dataset_name == "imagenet2012": 94 | assert self.image_size in [ 95 | 32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] 96 | filename = "{}/imagenet_man_{}_stats_real.npz".format( 97 | stats_path, self.image_size) 98 | with tf.io.gfile.GFile(filename, "rb") as fin: 99 | stats_real = np.load(fin) 100 | return stats_real["mu"], stats_real["cov"] 101 | elif self.dataset_name == "coco": 102 | filename = "{}/coco_stats_real.npz".format(stats_path) 103 | with tf.io.gfile.GFile(filename, "rb") as fin: 104 | stats_real = np.load(fin) 105 | return stats_real["mu"], stats_real["cov"] 106 | else: 107 | logging.warn("Dataset %s stats not found!", self.dataset_name) 108 | return None, None 109 | 110 | with tf.io.gfile.GFile(filename, "rb") as fin: 111 | stats_real = np.load(fin) 112 | logging.info("FID stats loading done! Number of examples %d", 113 | stats_real.shape[0]) 114 | return get_stats_for_fid(stats_real) 115 | 116 | def preprocess_inputs(self, inputs, is_n1p1=False): 117 | """Resize images and shift/clip pixels to [-1, 1].""" 118 | if isinstance(inputs, list): 119 | all_inputs = tf.concat(inputs, 0) 120 | all_inputs = self.preprocess_inputs(all_inputs, is_n1p1=is_n1p1) 121 | return tf.split(all_inputs, len(inputs)) 122 | if is_n1p1: 123 | inputs = tf.clip_by_value(inputs, -1.0, 1.0) 124 | inputs = (inputs + 1.0) / 2.0 125 | 126 | inputs = tf.image.resize( 127 | inputs, [self.inceptionv3_input_size, self.inceptionv3_input_size], 128 | self.resize_method, 129 | antialias=self.antialias) 130 | inputs = tf.clip_by_value(inputs, 0.0, 1.0) 131 | # transform inputs to [-1, 1] 132 | inputs = inputs * 2 - 1.0 133 | return inputs 134 | 135 | def get_inception_stats(self, inputs): 136 | if isinstance(inputs, list): 137 | return [self.get_inception_stats(x) for x in inputs] 138 | stats = tfgan.eval.run_inception(inputs) 139 | return stats["logits"], stats["pool_3"] 140 | 141 | def update_stats(self, logits_real, pool3_real, logits_gen, pool3_gen): 142 | self.all_logits_real.append(logits_real) 143 | self.all_pool3_real.append(pool3_real) 144 | self.all_logits_gen.append(logits_gen) 145 | self.all_pool3_gen.append(pool3_gen) 146 | 147 | def reset(self): 148 | self.all_logits_real.clear() 149 | self.all_pool3_real.clear() 150 | self.all_logits_gen.clear() 151 | self.all_pool3_gen.clear() 152 | return 153 | 154 | def compute_fid_score(self): 155 | """Return a dict of metrics.""" 156 | metrics = {} 157 | logging.info("Computing Inception score.") 158 | all_logits_gen = np.concatenate(self.all_logits_gen, axis=0) 159 | logging.info("IS number of gen samples: %d, number of classes: %d", 160 | all_logits_gen.shape[0], all_logits_gen.shape[1]) 161 | is_score = tfgan.eval.classifier_score_from_logits(all_logits_gen) 162 | metrics.update({"inception_score": is_score}) 163 | logging.info("Computing FID score.") 164 | all_stats_real = np.concatenate(self.all_pool3_real, axis=0) 165 | all_stats_gen = np.concatenate(self.all_pool3_gen, axis=0) 166 | logging.info("FID number of real samples: %d", all_stats_real.shape[0]) 167 | logging.info("FID number of generated samples: %d", all_stats_gen.shape[0]) 168 | gen_mean, gen_cov = 
get_stats_for_fid(all_stats_gen) 169 | ref_mean, ref_cov = get_stats_for_fid(all_stats_real) 170 | metrics.update({ 171 | "fid_batch": get_fid_score(gen_mean, gen_cov, ref_mean, ref_cov), 172 | }) 173 | if self.dataset_stats_mean is not None: 174 | metrics.update({ 175 | "fid_full": 176 | get_fid_score(gen_mean, gen_cov, self.dataset_stats_mean, 177 | self.dataset_stats_cov), 178 | "fid_batch_vs_full": 179 | get_fid_score(ref_mean, ref_cov, self.dataset_stats_mean, 180 | self.dataset_stats_cov), 181 | }) 182 | return metrics 183 | -------------------------------------------------------------------------------- /metrics/fvd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Evaluation metrics for FVD/IS scores with I3D / C3D nets.""" 17 | 18 | import dataclasses 19 | from functools import partial 20 | 21 | from absl import logging 22 | from jax.experimental import jax2tf 23 | import numpy as np 24 | from metrics.fid import get_fid_score 25 | from metrics.fid import get_stats_for_fid 26 | from metrics.fid import TFGANMetricEvaluator 27 | 28 | import tensorflow as tf 29 | import tensorflow_gan as tfgan 30 | from universal_diffusion.metrics import c3d 31 | from universal_diffusion.metrics import i3d 32 | 33 | 34 | @dataclasses.dataclass 35 | class FVDMetricEvaluator(TFGANMetricEvaluator): 36 | """A wrapper class for tensorflow-gan evaluation extended for FVD.""" 37 | dataset_name: str 38 | image_size: int = -1 39 | activations_key: str = 'pool_3' 40 | 41 | def __post_init__(self): 42 | self.all_logits_real = [] 43 | self.all_pool3_real = [] 44 | self.all_logits_gen = [] 45 | self.all_pool3_gen = [] 46 | self.dataset_stats_mean, self.dataset_stats_cov = self.load_fid_stats() 47 | 48 | if self.dataset_name == 'ucf101': 49 | self.model = jax2tf.convert(partial(c3d.run_model, c3d.load_params())) 50 | elif self.dataset_name == 'kinetics600': 51 | self.model = jax2tf.convert(partial(i3d.run_model, i3d.load_params())) 52 | else: 53 | assert False, 'dataset not supported for FVD' 54 | 55 | def load_fid_stats(self, stat_path=None): 56 | """Load the pre-computed dataset statistics.""" 57 | logging.info('loading FID stats for datasets %s', self.dataset_name) 58 | if self.dataset_name in {'ucf101', 'kinetics600'} and False: 59 | assert self.image_size in [64, 128] 60 | filename = '{}/{}_{}_stats_real.npz'.format( 61 | stat_path, self.dataset_name, self.image_size) 62 | with tf.io.gfile.GFile(filename, 'rb') as fin: 63 | stats_real = np.load(fin) 64 | logging.info('FID stats loading done! 
Number of examples %d', 65 | stats_real['mu'].shape[0]) 66 | return stats_real['mu'], stats_real['cov'] 67 | else: 68 | logging.warn('Dataset %s stats not found!', self.dataset_name) 69 | return None, None 70 | 71 | def preprocess_inputs(self, inputs, is_n1p1=False): 72 | if isinstance(inputs, list): 73 | all_inputs = tf.concat(inputs, 0) 74 | all_inputs = self.preprocess_inputs(all_inputs, is_n1p1=is_n1p1) 75 | return tf.split(all_inputs, len(inputs)) 76 | if is_n1p1: 77 | inputs = tf.clip_by_value(inputs, -1.0, 1.0) 78 | inputs = (inputs + 1.0) / 2.0 79 | inputs = inputs * 255.0 80 | return inputs 81 | 82 | def get_inception_stats(self, inputs): 83 | if isinstance(inputs, list): 84 | return [self.get_inception_stats(x) for x in inputs] 85 | stats = self.model(inputs) 86 | if 'features' in stats: 87 | return stats['logits'], stats['features'] 88 | else: 89 | return stats['logits_mean'], stats['pool'] 90 | 91 | def compute_fid_score(self): 92 | """Return a dict of metrics.""" 93 | metrics = {} 94 | logging.info('Computing Inception score.') 95 | all_logits_gen = np.concatenate(self.all_logits_gen, axis=0) 96 | all_logits_real = np.concatenate(self.all_logits_real, axis=0) 97 | logging.info('IS number of gen samples: %d, number of classes: %d', 98 | all_logits_gen.shape[0], all_logits_gen.shape[1]) 99 | is_score = tfgan.eval.classifier_score_from_logits(all_logits_gen) 100 | metrics.update({'inception_score': is_score}) 101 | 102 | logging.info('Computing FVD score.') 103 | all_stats_real = np.concatenate(self.all_pool3_real, axis=0) 104 | all_stats_gen = np.concatenate(self.all_pool3_gen, axis=0) 105 | logging.info('FVD number of real samples: %d', all_stats_real.shape[0]) 106 | logging.info('FVD number of generated samples: %d', all_stats_gen.shape[0]) 107 | gen_mean, gen_cov = get_stats_for_fid(all_stats_gen) 108 | ref_mean, ref_cov = get_stats_for_fid(all_stats_real) 109 | 110 | gen_logits_mean, gen_logits_cov = get_stats_for_fid(all_logits_gen) 111 | ref_logits_mean, ref_logits_cov = get_stats_for_fid(all_logits_real) 112 | 113 | metrics.update({ 114 | 'fvd_pool_batch': get_fid_score(gen_mean, gen_cov, ref_mean, ref_cov), 115 | 'fvd_batch': get_fid_score(gen_logits_mean, gen_logits_cov, 116 | ref_logits_mean, ref_logits_cov), 117 | }) 118 | if self.dataset_stats_mean is not None: 119 | metrics.update({ 120 | 'fvd_pool_full': 121 | get_fid_score(gen_mean, gen_cov, self.dataset_stats_mean, 122 | self.dataset_stats_cov), 123 | 'fvd_pool_batch_vs_full': 124 | get_fid_score(ref_mean, ref_cov, self.dataset_stats_mean, 125 | self.dataset_stats_cov), 126 | }) 127 | return metrics 128 | -------------------------------------------------------------------------------- /metrics/metric_registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | import registry 17 | 18 | MetricRegistry = registry.Registry() -------------------------------------------------------------------------------- /metrics/metric_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | def yxyx_to_xywh(box): 17 | ymin, xmin, ymax, xmax = box 18 | w = xmax - xmin 19 | h = ymax - ymin 20 | return [xmin, ymin, w, h] 21 | -------------------------------------------------------------------------------- /metrics/text_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Metrics for text tasks.""" 17 | 18 | from metrics import metric_registry 19 | import sacrebleu 20 | 21 | 22 | @metric_registry.MetricRegistry.register('text_sacrebleu') 23 | class BleuMetric(): 24 | """BLEU metric for text.""" 25 | 26 | def __init__(self, config): 27 | self._config = config 28 | self.reset_states() 29 | 30 | def reset_states(self): 31 | self._metric_values = None 32 | self._targets = [] 33 | self._predictions = [] 34 | 35 | def record_prediction(self, predictions, targets): 36 | """Records predictions. 37 | 38 | If multiple references are present, then each example need to have the same 39 | number of references. 40 | 41 | Args: 42 | predictions: list of strings. Has len batch_size. 43 | targets: list of strings, or list of list of strings if multiple 44 | references are present. Has len batch_size. In the format of 45 | [ex1_ref, ex2_ref, ...] or 46 | [[ex1_ref1, ex2_ref1, ...], [ex1_ref2, ex2_ref2, ...], ...]. 47 | """ 48 | self._predictions.extend(predictions) 49 | 50 | # Turn targets into lists. 51 | if not isinstance(targets[0], list): 52 | targets = [targets] 53 | if self._targets: 54 | assert len(self._targets) == len(targets) 55 | for i in range(len(targets)): 56 | self._targets[i].extend(targets[i]) 57 | else: 58 | self._targets = targets 59 | 60 | def _evaluate(self): 61 | """Evaluates with predictions for all examples. 
62 | 63 | Call this function from `self.result`. 64 | 65 | Returns: 66 | dict from metric name to float value. 67 | """ 68 | tokenizer = self._config.get('tokenizer', 'intl') 69 | bleu_score = sacrebleu.corpus_bleu(self._predictions, self._targets, 70 | smooth_method='exp', 71 | smooth_value=0.0, 72 | force=False, 73 | lowercase=False, 74 | tokenize=tokenizer, 75 | use_effective_order=False) 76 | return {'bleu': bleu_score.score} 77 | 78 | def result(self): 79 | """Return the metric values (and compute it if needed).""" 80 | if self._metric_values is None: 81 | self._metric_values = self._evaluate() 82 | return self._metric_values 83 | -------------------------------------------------------------------------------- /metrics/vos_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Video object segmentation metrics.""" 17 | 18 | import io 19 | import os 20 | import tempfile 21 | 22 | from absl import logging 23 | from davis2017.evaluation import DAVISEvaluation 24 | import numpy as np 25 | import pandas as pd 26 | import PIL 27 | import utils 28 | from metrics import metric_registry 29 | import tensorflow as tf 30 | 31 | _PALETTE = [ 32 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 33 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 34 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128 35 | ] 36 | 37 | 38 | @metric_registry.MetricRegistry.register('davis_video_object_segmentation') 39 | class DavisVideoObjectSegmentationMetric(): 40 | """Video object segmentation metric for DAVIS.""" 41 | 42 | def __init__(self, config): 43 | self.config = config 44 | self.metric_names = [ 45 | 'J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 46 | 'F-Decay' 47 | ] 48 | self.results_dir = config.task.metric.get('results_dir') 49 | self.dataset_eval = DAVISEvaluation( 50 | davis_root=config.dataset.annotations_dir, 51 | task=config.dataset.vos_task, 52 | gt_set=('train' if config.dataset.eval_split == 'train' else 'val'), 53 | use_tfds=False) 54 | self.reset_states() 55 | 56 | def reset_states(self): 57 | self.metric_values = None 58 | self._local_pred_dir_obj = tempfile.TemporaryDirectory() 59 | 60 | def record_prediction(self, predictions, video_name, frame_ids, step): 61 | """Records predictions. 62 | 63 | Args: 64 | predictions: uint8 of shape (num_frames, h, w, channel), where channel 65 | could be 1 or >1, but only the last channel is used as instance id. 66 | video_name: str. Video name. 67 | frame_ids: list of int, or 1-d np.array. Frame ids of predictions. 68 | step: int. The checkpoint step, used to name sub-directories. 69 | """ 70 | predictions = predictions[..., -1] # Last channel is instance id. 
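    # After dropping the leading channels, predictions has shape
    # (num_frames, h, w) with integer instance ids; each frame is written out
    # below as a paletted 8-bit PNG, so ids are expected to fit in uint8.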
71 | subdir = os.path.join(self._local_pred_dir_obj.name, str(step), video_name) 72 | if not tf.io.gfile.exists(subdir): 73 | tf.io.gfile.makedirs(subdir) 74 | 75 | for frame_id in frame_ids: 76 | filename = f'{frame_id:05}.png' 77 | filepath = os.path.join(subdir, filename) 78 | pred_image = PIL.Image.fromarray(predictions[frame_id], mode='L') 79 | pred_image.putpalette(_PALETTE) 80 | with io.BytesIO() as out: 81 | pred_image.save(out, format='PNG') 82 | with tf.io.gfile.GFile(filepath, 'wb') as f: 83 | f.write(out.getvalue()) 84 | logging.info('Done writing out pngs for %s', video_name) 85 | 86 | if self.results_dir is not None: 87 | # Copy images to results dir. 88 | results_dir = os.path.join(self.results_dir, str(step), video_name) 89 | utils.copy_dir(subdir, results_dir) 90 | 91 | def _evaluate(self, step): 92 | """Evaluates with predictions for all images. 93 | 94 | Call this function from `self.result`. 95 | 96 | Args: 97 | step: int. The checkpoint step being evaluated. 98 | 99 | Returns: 100 | dict from metric name to float value. 101 | """ 102 | result_path = os.path.join(self._local_pred_dir_obj.name, str(step)) 103 | metrics_res = self.dataset_eval.evaluate(result_path) 104 | J, F = metrics_res['J'], metrics_res['F'] # pylint: disable=invalid-name 105 | g_measures = [ 106 | 'J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 107 | 'F-Decay' 108 | ] 109 | g_res = np.array([ 110 | (np.mean(J['M']) + np.mean(F['M'])) / 2., 111 | np.mean(J['M']), 112 | np.mean(J['R']), 113 | np.mean(J['D']), 114 | np.mean(F['M']), 115 | np.mean(F['R']), 116 | np.mean(F['D']) 117 | ]) 118 | 119 | if self.results_dir is not None: 120 | # Write metrics to result_dir. 121 | result_dir = os.path.join(self.results_dir, str(step)) 122 | csv_name_global_path = os.path.join(result_dir, 'global_results.csv') 123 | csv_name_per_sequence_path = os.path.join(result_dir, 124 | 'per_sequence_results.csv') 125 | 126 | # Global results. 127 | g_res_ = np.reshape(g_res, [1, len(g_res)]) 128 | table_g = pd.DataFrame(data=g_res_, columns=g_measures) 129 | with tf.io.gfile.GFile(csv_name_global_path, 'w') as f: 130 | table_g.to_csv(f, index=False, float_format='%.3f') 131 | logging.info('Global results saved in %s', csv_name_global_path) 132 | 133 | # Per sequence results. 134 | assert isinstance(J['M_per_object'], dict) 135 | seq_names = list(J['M_per_object'].keys()) 136 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 137 | j_per_object = [J['M_per_object'][x] for x in seq_names] 138 | f_per_object = [F['M_per_object'][x] for x in seq_names] 139 | table_seq = pd.DataFrame( 140 | data=list(zip(seq_names, j_per_object, f_per_object)), 141 | columns=seq_measures) 142 | with tf.io.gfile.GFile(csv_name_per_sequence_path, 'w') as f: 143 | table_seq.to_csv(f, index=False, float_format='%.3f') 144 | logging.info('Per-sequence results saved in %s', 145 | csv_name_per_sequence_path) 146 | 147 | return {name: v for name, v in zip(g_measures, g_res)} 148 | 149 | def result(self, step): 150 | """Return the metric values (and compute it if needed).""" 151 | if self.metric_values is None: 152 | self.metric_values = self._evaluate(step) 153 | return self.metric_values 154 | -------------------------------------------------------------------------------- /models/image_ar_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The image autoregressive decoder model.""" 17 | 18 | import einops 19 | import ml_collections 20 | 21 | import utils 22 | from architectures.transformers import AutoregressiveDecoder 23 | from architectures.transformers import FITAR 24 | from models import model as model_lib 25 | from models import model_utils 26 | import tensorflow as tf 27 | 28 | 29 | @model_lib.ModelRegistry.register('image_ar_decoder') 30 | class Model(tf.keras.models.Model): 31 | """Inputs images and returns activations.""" 32 | 33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs): 34 | # vocab_size and max_seq_len don't include start token, which is only used 35 | # inside this class. 36 | super().__init__(**kwargs) 37 | image_size = config.dataset.image_size 38 | self.loss_type = config.train.loss_type 39 | config = config.model 40 | self.config = config 41 | if config.arch_name == 'base': 42 | mlp_ratio_dec = config.dim_mlp_dec // config.dim_att_dec 43 | self.decoder = AutoregressiveDecoder( 44 | config.vocab_size, config.max_seq_len, config.num_decoder_layers, 45 | config.dim_att_dec, mlp_ratio_dec, config.num_heads_dec, 46 | config.drop_path, config.drop_units, config.drop_att, 47 | config.pos_encoding_dec, config.shared_decoder_embedding, 48 | config.decoder_output_bias, cross_attention=False, name='ar_decoder') 49 | else: 50 | self.decoder = FITAR( 51 | layers=config.layers, 52 | x_size=image_size**2*3, 53 | num_groups=(image_size//config.patch_size)**2, 54 | latents_per_group=config.latents_per_group, 55 | x_dim=config.dim_att, 56 | latent_dim=config.dim_latent, 57 | x_num_heads=config.num_heads, 58 | latent_num_heads=config.num_heads, 59 | mlp_ratio=config.dim_mlp//config.dim_att, 60 | vocab_size=config.vocab_size, 61 | shared_embedding=config.shared_decoder_embedding, 62 | output_bias=config.decoder_output_bias, 63 | drop_path=config.drop_path, 64 | drop_units=config.drop_units, 65 | drop_att=config.drop_att, 66 | x_pos_encoding=config.pos_encoding, 67 | latent_pos_encoding=config.latent_pos_encoding) 68 | 69 | def call(self, images, labels=None, training=True): 70 | """Model function call for *training*.""" 71 | with tf.name_scope(''): # for other functions to have the same name scope. 72 | config = self.config 73 | input_seq, target_seq = image2seqs( 74 | images, config.arch_name, config.patch_size, config.patch_ordering) 75 | logits = self.decoder(input_seq, None, training=training) 76 | losses = model_utils.get_loss(logits, target_seq, self.loss_type) 77 | loss = tf.reduce_mean(losses) / tf.math.log(2.0) 78 | return loss, logits, target_seq 79 | 80 | def sample(self, **kwargs): 81 | """Sampling.""" 82 | # TODO(iamtingchen): add sampling. 
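    # Until sampling is implemented, a rough sketch of what it could look like
    # (hypothetical names, not part of the original code):
    #   seq = tf.zeros([batch_size, 1], tf.int32)  # start token
    #   for _ in range(max_seq_len):
    #     logits = self.decoder(seq, None, training=False)
    #     nxt = tf.random.categorical(logits[:, -1], 1, dtype=tf.int32)
    #     seq = tf.concat([seq, nxt], axis=1)
    # For now this method simply reports the eval loss on the provided images.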
83 | loss, _, _ = self.call(kwargs['images'], kwargs['labels'], training=False) 84 | return kwargs['images'], loss 85 | 86 | 87 | @model_lib.TrainerRegistry.register('image_ar_decoder') 88 | class ARTrainer(model_lib.Trainer): 89 | """A trainer for AR model.""" 90 | 91 | def __init__(self, config: ml_collections.ConfigDict, **kwargs): 92 | """Init and setup basic training elements under strategy scope. 93 | 94 | Note: the trainer needs to be created under `strategy.scope()`. 95 | 96 | Args: 97 | config: object for holding hyperparameters and other configurations. 98 | **kwargs: other neccesary configurations to pass for training setup. 99 | """ 100 | super().__init__(config, **kwargs) 101 | self._metrics.update({ 102 | 'loss': tf.keras.metrics.Mean('loss'), 103 | 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy( 104 | 'accuracy'), 105 | }) 106 | 107 | def compute_loss(self, preprocess_outputs): 108 | """Compute loss based on model outputs and targets.""" 109 | images, labels = preprocess_outputs 110 | loss, logits, target_seq = self.model(images, labels, training=True) 111 | 112 | # update metrics 113 | self._metrics['loss'].update_state(loss) 114 | self._metrics['accuracy'].update_state(target_seq, logits) 115 | 116 | return loss 117 | 118 | 119 | def image2seqs(images, arch_name, patch_size, patch_ordering='snake'): 120 | """Turn images into input and target sequences.""" 121 | if arch_name == 'base': 122 | images = einops.rearrange(images, 'b h w c -> b (h w c)') 123 | target_seq = tf.cast(images * 255., tf.int32) # (bsz, seqlen) 124 | input_seq = tf.concat( 125 | [tf.zeros_like(target_seq[:, :1]), target_seq[:, :-1]], 1) 126 | else: 127 | images = utils.extract_patches( 128 | images, [patch_size, patch_size], patch_ordering=patch_ordering) 129 | target_seq = tf.cast(images * 255., tf.int32) # (bsz, groups, seqlen) 130 | flat_seq = einops.rearrange(target_seq, 'b n m -> b (n m)') 131 | input_seq = tf.concat( 132 | [tf.zeros_like(flat_seq[:, :1]), flat_seq[:, :-1]], 1) 133 | input_seq = tf.reshape(input_seq, tf.shape(target_seq)) 134 | return input_seq, target_seq 135 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Abstract model file.""" 17 | 18 | import abc 19 | from absl import logging 20 | import ml_collections 21 | import registry 22 | import utils 23 | from models import model_utils 24 | import tensorflow as tf 25 | 26 | ModelRegistry = registry.Registry() 27 | TrainerRegistry = registry.Registry() 28 | 29 | 30 | class Trainer(abc.ABC): 31 | """A base trainer.""" 32 | 33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs): 34 | """Init and setup basic training elements under strategy scope. 
35 | 36 | Note: the trainer needs to be created under `strategy.scope()`. 37 | 38 | Args: 39 | config: object for holding hyperparameters and other configurations. 40 | **kwargs: other neccesary configurations to pass for training setup. 41 | """ 42 | self._config = config 43 | 44 | # Setup learning rate and optimizer. 45 | num_train_examples = kwargs['num_train_examples'] 46 | train_steps = kwargs['train_steps'] 47 | c_opt = config.optimization 48 | batch_size = config.train.batch_size 49 | tail_steps = c_opt.get('tail_steps', 0) 50 | end_lr_factor = c_opt.get('end_lr_factor', 0.) 51 | warmup_steps = c_opt.warmup_steps or int( 52 | round(c_opt.warmup_epochs * num_train_examples // batch_size)) 53 | self._learning_rate = learning_rate = model_utils.WarmUpAndDecay( 54 | c_opt.learning_rate, c_opt.learning_rate_scaling, batch_size, 55 | c_opt.learning_rate_schedule, warmup_steps, train_steps, 56 | tail_steps=tail_steps, end_lr_factor=end_lr_factor) 57 | self._optimizer = optimizer = model_utils.build_optimizer( 58 | config.optimization, learning_rate) 59 | 60 | # Setup model and checkpoints. 61 | self._model = model = ModelRegistry.lookup(config.model.name)(config) 62 | model_dir = kwargs['model_dir'] 63 | latest_ckpt, ckpt, self._verify_restored = utils.restore_from_checkpoint( 64 | model_dir, False, 65 | model=model, global_step=optimizer.iterations, optimizer=optimizer) 66 | self._verify_restored_p = None 67 | if not latest_ckpt: 68 | if config.model.pretrained_ckpt: 69 | _, _, self._verify_restored_p = utils.restore_from_checkpoint( 70 | config.model.pretrained_ckpt, True, model=model) 71 | self._checkpoint_manager = tf.train.CheckpointManager( 72 | ckpt, model_dir, config.train.keep_checkpoint_max) 73 | 74 | # Setup metrics. 75 | self._metrics = { 76 | 'total_num_params': tf.keras.metrics.Mean('total_num_params'), 77 | 'grad_global_norm': tf.keras.metrics.Mean('grad_global_norm'), 78 | 'weight_linf_norm': tf.keras.metrics.Mean('weight_linf_norm'), 79 | 'loss': tf.keras.metrics.Mean('loss'), 80 | } 81 | self._metrics.update({ 82 | f'loss_{t.name}': tf.keras.metrics.Mean(f'loss_{t.name}') 83 | for t in config.tasks}) 84 | self._print_params = True 85 | 86 | def train_step(self, examples, tasks, strategy): 87 | """Defines a single training step for model update given examples and tasks. 88 | 89 | Args: 90 | examples: a list of data examples to be fed into the paired task class for 91 | preprocessing. 92 | tasks: a list of tasks that provide preprocessing and postprocessing for 93 | specific task. 94 | strategy: tensorflow strategy such as `TPUStrategy` or `MirroredStrategy`. 95 | """ 96 | logging.info('train_step begins...') 97 | preprocessed_outputs = [ 98 | t.preprocess_batched(e, training=True) for e, t in zip(examples, tasks)] 99 | 100 | task_loss_metrics = {} 101 | loss = 0 102 | grads = [] 103 | for i, (o, task) in enumerate(zip(preprocessed_outputs, tasks)): 104 | with tf.GradientTape() as tape: 105 | loss_t = self.compute_loss(o) 106 | task_loss_metrics[f'loss_{task.config.task.name}'] = loss_t 107 | loss += loss_t * task.config.task.weight 108 | trainable_variables = self._model.trainable_variables 109 | grads_t = tape.gradient( # div by num_replicas_in_sync for mean grad. 110 | loss_t * task.config.task.weight / strategy.num_replicas_in_sync, 111 | trainable_variables) 112 | grads = grads_t if i == 0 else [ 113 | g + gt for g, gt in zip(grads, grads_t)] 114 | self._optimizer.apply_gradients(zip(grads, trainable_variables)) 115 | 116 | # Update metrics. 
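    # Per-task gradients were divided by num_replicas_in_sync above, so the
    # grad_global_norm metric below scales them back by the same multiplier to
    # report the effective (unscaled) gradient norm.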
117 |     self._metrics['loss'].update_state(loss)
118 |     for k, v in task_loss_metrics.items():
119 |       self._metrics[k].update_state(v)
120 |     wmx = [tf.reduce_max(tf.math.abs(m)) for m in trainable_variables]
121 |     self._metrics['weight_linf_norm'].update_state(tf.reduce_max(wmx))
122 |     multiplier = strategy.num_replicas_in_sync
123 |     self._metrics['grad_global_norm'].update_state(tf.linalg.global_norm(
124 |         [tf.math.scalar_mul(multiplier, g) for g in grads if g is not None]))
125 |     self._metrics['total_num_params'].update_state(
126 |         utils.count_params(self._model, verbose=self._print_params))
127 |     self._print_params = False
128 |     logging.info('train_step ends...')
129 | 
130 |   @abc.abstractmethod
131 |   def compute_loss(self, preprocessed_outputs):
132 |     """Compute loss based on model outputs and targets."""
133 | 
134 |   def check_checkpoint_restored(self):
135 |     """Check if the checkpoints are correctly restored."""
136 |     (verify_restored,), (verify_restored_p,) = (
137 |         utils.check_checkpoint_restored(
138 |             [self._verify_restored], [self._verify_restored_p]))
139 |     self._verify_restored = verify_restored
140 |     self._verify_restored_p = verify_restored_p
141 | 
142 |   def reset(self):
143 |     """Resets the metrics and/or other state accumulators."""
144 |     for k, _ in self._metrics.items():
145 |       self._metrics[k].reset_states()
146 | 
147 |   @property
148 |   def model(self):
149 |     """Returns model instance."""
150 |     return self._model
151 | 
152 |   @property
153 |   def optimizer(self):
154 |     """Returns optimizer instance."""
155 |     return self._optimizer
156 | 
157 |   @property
158 |   def learning_rate(self):
159 |     """Returns learning rate scheduling instance."""
160 |     return self._learning_rate
161 | 
162 |   @property
163 |   def metrics(self):
164 |     """Returns metrics instance."""
165 |     return self._metrics
166 | 
167 |   @property
168 |   def config(self):
169 |     """Returns config instance."""
170 |     return self._config
171 | 
172 |   @property
173 |   def checkpoint_manager(self):
174 |     """Returns checkpoint_manager instance."""
175 |     return self._checkpoint_manager
176 | 
--------------------------------------------------------------------------------
/models/video_diffusion_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================== 16 | """The model.""" 17 | 18 | import functools 19 | import ml_collections 20 | 21 | from architectures.tape import VideoTapeDenoiser 22 | from models import diffusion_utils 23 | from models import image_diffusion_model 24 | from models import model as model_lib 25 | from models.diffusion_utils import Scheduler 26 | import tensorflow as tf 27 | 28 | 29 | @model_lib.ModelRegistry.register('video_diffusion_model') 30 | class Model(image_diffusion_model.Model): 31 | """A model with tape for video prediction.""" 32 | 33 | def __init__(self, config: ml_collections.ConfigDict, **kwargs): 34 | super(image_diffusion_model.Model, self).__init__(**kwargs) 35 | image_size = config.dataset.image_size 36 | self.image_size = image_size 37 | self.num_classes = config.dataset.num_classes 38 | self.seq_len = config.dataset.seq_len 39 | config = config.model 40 | self.config = config 41 | self.scheduler = Scheduler(config.train_schedule) 42 | if config.x0_clip == 'auto': 43 | self.x0_clip = '{},{}'.format(-config.b_scale, config.b_scale) 44 | else: 45 | self.x0_clip = config.x0_clip 46 | self.x_channels = 3 47 | 48 | model_fn = functools.partial( 49 | VideoTapeDenoiser, 50 | num_layers=config.num_layers, 51 | latent_slots=config.latent_slots, 52 | latent_dim=config.latent_dim, 53 | latent_mlp_ratio=config.latent_mlp_ratio, 54 | latent_num_heads=config.latent_num_heads, 55 | tape_dim=config.tape_dim, 56 | tape_mlp_ratio=config.tape_mlp_ratio, 57 | rw_num_heads=config.rw_num_heads, 58 | image_height=image_size, 59 | image_width=image_size, 60 | image_channels=self.x_channels, 61 | patch_size=config.patch_size, 62 | seq_len=self.seq_len, 63 | seq_stride=config.seq_stride, 64 | seq_cond=self.seq_len - self.sample_shape[0], 65 | latent_pos_encoding=config.latent_pos_encoding, 66 | tape_pos_encoding=config.tape_pos_encoding, 67 | drop_path=config.drop_path, 68 | drop_units=config.drop_units, 69 | drop_att=config.drop_att, 70 | time_scaling=config.time_scaling, 71 | self_cond=config.self_cond, 72 | time_on_latent=config.time_on_latent, 73 | cond_on_latent_n=1 if config.cond_on_latent else 0, 74 | cond_tape_writable=config.cond_tape_writable, 75 | xattn_enc_ln=config.xattn_enc_ln, 76 | name=config.arch_name) 77 | self.denoiser = model_fn(name='denoiser') 78 | self.denoiser_ema = model_fn(name='denoiser', trainable=False) 79 | self.hidden_shapes = self.denoiser.hidden_shapes 80 | 81 | @property 82 | def sample_shape(self): 83 | if self.seq_len > 1: 84 | seq_cond = self.config.get('conditional', 'seq@0').split('@') 85 | seq_cond = int(seq_cond[-1]) if len(seq_cond) > 1 else 0 86 | return [self.seq_len-seq_cond, self.image_size, self.image_size, 87 | self.x_channels] 88 | else: 89 | return [self.image_size, self.image_size, self.x_channels] 90 | 91 | # override for conditional data 92 | def get_cond_denoise(self, labels, cond=None): 93 | config = self.config 94 | def cond_denoise(x, gamma, training): 95 | gamma = tf.reshape(gamma, [-1]) 96 | if config.conditional == 'class': 97 | gamma = tf.concat([gamma[..., tf.newaxis], labels], -1) 98 | elif config.conditional != 'none' and 'seq@' not in config.conditional: 99 | raise ValueError(f'Unknown conditional {config.conditional}') 100 | return self.denoise(x, gamma, cond, training) 101 | return cond_denoise 102 | 103 | # override to pass x_cond 104 | def sample(self, num_samples=100, iterations=100, method='ddim', **kwargs): 105 | config = self.config 106 | samples_shape = [num_samples, 
*self.sample_shape] 107 | if config.conditional == 'class': 108 | labels = tf.random.uniform( 109 | [num_samples], 0, self.num_classes, dtype=tf.int32) 110 | labels = tf.one_hot(labels, self.num_classes) 111 | else: 112 | labels = None 113 | x_cond = kwargs['images'][:, :-self.sample_shape[0]] 114 | x_cond = (x_cond * 2. - 1.) * config.b_scale # convert 0,1 -> -s,s 115 | samples = self.scheduler.generate( 116 | self.get_cond_denoise(labels, cond=x_cond), 117 | iterations, 118 | samples_shape, 119 | hidden_shapes=self.hidden_shapes, 120 | pred_type=config.pred_type, 121 | schedule=config.infer_schedule, 122 | td=config.td, 123 | noise_std=config.noise_std, 124 | x0_clip=self.x0_clip, 125 | self_cond=config.self_cond, 126 | sampler_name=config.sampler_name) 127 | if x_cond.shape[1] > 0: 128 | samples = tf.concat([x_cond, samples], axis=1) 129 | samples = (samples / config.b_scale / 2. + 0.5) # convert -s,s -> 0,1 130 | return samples 131 | 132 | # override to allow for more cond data 133 | def noise_denoise(self, images, labels, time_step=None, training=True): 134 | config = self.config 135 | images = (images * 2. - 1.) * config.b_scale # convert 0,1 -> -s,s 136 | seq_len = self.sample_shape[0] 137 | cond_images, images = images[:, :-seq_len], images[:, -seq_len:] 138 | images_noised, noise, _, gamma = self.scheduler.add_noise( 139 | images, time_step=time_step) 140 | images_noised_ori = images_noised 141 | if config.self_cond != 'none': 142 | sc_rate = config.get('self_cond_rate', 0.5) 143 | self_cond_by_masking = config.get('self_cond_by_masking', False) 144 | if self_cond_by_masking: 145 | sc_drop_rate = 1. - sc_rate 146 | num_sc_examples = tf.shape(images)[0] 147 | else: 148 | sc_drop_rate = 0. 149 | num_sc_examples = tf.cast( 150 | tf.cast(tf.shape(images)[0], tf.float32) * sc_rate, tf.int32) 151 | cond_denoise = self.get_cond_denoise( 152 | labels[:num_sc_examples], 153 | cond_images[:num_sc_examples] if cond_images is not None else None) 154 | if self.hidden_shapes is None: # data self-cond, return is a tensor. 155 | denoise_inputs = diffusion_utils.add_self_cond_estimate( 156 | images_noised, gamma, cond_denoise, config.pred_type, 157 | config.self_cond, self.x0_clip, num_sc_examples, 158 | drop_rate=sc_drop_rate, training=training) 159 | else: # latent self-cond, return is a tuple. 
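        # Latent self-conditioning: unlike the data branch above, which returns
        # a single tensor, this helper builds denoise inputs that also carry
        # estimates of the denoiser's hidden latents, hence the tuple return.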
160 | denoise_inputs = diffusion_utils.add_self_cond_hidden( 161 | images_noised, gamma, cond_denoise, num_sc_examples, 162 | self.hidden_shapes, drop_rate=sc_drop_rate, training=training) 163 | else: 164 | denoise_inputs = images_noised 165 | cond_denoise = self.get_cond_denoise(labels, cond_images) 166 | denoise_out = cond_denoise(denoise_inputs, gamma, training) 167 | if isinstance(denoise_out, tuple): denoise_out = denoise_out[0] 168 | 169 | return images, noise, images_noised_ori, denoise_out 170 | 171 | 172 | @model_lib.TrainerRegistry.register('video_diffusion_model') 173 | class Trainer(image_diffusion_model.Trainer): 174 | """A trainer.""" 175 | 176 | def train_step(self, examples, tasks, strategy): 177 | super().train_step(examples, tasks, strategy) 178 | -------------------------------------------------------------------------------- /pix2seq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/pix2seq/eabd4e98f65f7627a0e727d11a4b9cbba916283d/pix2seq.gif -------------------------------------------------------------------------------- /pix2seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/pix2seq/eabd4e98f65f7627a0e727d11a4b9cbba916283d/pix2seq.png -------------------------------------------------------------------------------- /registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """General Registry.""" 17 | 18 | from typing import Any 19 | from typing import Callable 20 | 21 | 22 | class Registry(object): 23 | """Registry.""" 24 | 25 | def __init__(self): 26 | self._registry = {} 27 | 28 | def register(self, key: str) -> Callable[[Any], None]: 29 | """Returns callable to register value for key.""" 30 | def r(item): 31 | if key in self._registry: 32 | raise ValueError("%s already registered!" 
% key)
33 |       self._registry[key] = item
34 |       return item
35 |     return r
36 | 
37 |   def lookup(self, key: str) -> Any:
38 |     """Looks up value for key."""
39 |     if key not in self._registry:
40 |       valid_keys = "\n".join(self._registry.keys())
41 |       raise ValueError(
42 |           "%s not registered!\n\n"
43 |           "Valid keys:%s\n\n" %
44 |           (key, valid_keys))
45 |     return self._registry[key]
46 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow
2 | ml-collections
3 | matplotlib
4 | tensorflow-datasets
5 | tensorflow-addons
6 | tensorflow-text
7 | pycocotools
8 | scikit-image
--------------------------------------------------------------------------------
/tasks/captioning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pix2Seq Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Image captioning task via COCO metric evaluation."""
17 | 
18 | import ml_collections
19 | 
20 | import utils
21 | import vocab
22 | from data import tokenizer as tokenizer_lib
23 | from metrics import metric_registry
24 | from tasks import task as task_lib
25 | from tasks import task_utils
26 | import tensorflow as tf
27 | 
28 | 
29 | @task_lib.TaskRegistry.register('captioning')
30 | class TaskCaptioning(task_lib.Task):
31 |   """Image captioning with coco evaluation."""
32 | 
33 |   def __init__(self,
34 |                config: ml_collections.ConfigDict):
35 |     super().__init__(config)
36 |     metric_config = config.task.get('metric')
37 |     if metric_config and metric_config.get('name'):
38 |       self._coco_metrics = metric_registry.MetricRegistry.lookup(
39 |           metric_config.name)(config)
40 |     else:
41 |       self._coco_metrics = None
42 |     self._tokenizer = tokenizer_lib.SPTokenizer(
43 |         config.tokenizer.sentencepiece_model,
44 |         add_bos=config.tokenizer.add_bos,
45 |         add_eos=config.tokenizer.add_eos)
46 | 
47 |   def preprocess_single(self, dataset, batch_duplicates, training):
48 |     """Task-specific preprocessing of individual example in the dataset.
49 | 
50 |     Args:
51 |       dataset: A tf.data.Dataset.
52 |       batch_duplicates: `int`, enlarge a batch by augmenting it multiple times
53 |         (as specified) and concatenating the augmented examples.
54 |       training: bool.
55 | 
56 |     Returns:
57 |       A dataset.
58 |     """
59 |     if batch_duplicates > 1:
60 |       raise NotImplementedError('Not supporting batch_duplicate=%d > 1 for '
61 |                                 'caption as of now.'
% batch_duplicates)
62 | 
63 |     def _preprocess_single_example(example):
64 |       config = self.config.task
65 |       mconfig = self.config.model
66 |       if training:
67 |         captions = []
68 |         for i in range(config.captions_per_image):
69 |           caption = (self._tokenizer.string_to_ids(example['captions'][i]) +
70 |                      mconfig.text_vocab_shift)
71 |           captions.append(utils.pad_to_max_len(caption, config.max_seq_len, -1))
72 |         captions = tf.stack(captions)
73 | 
74 |         for t in self.train_transforms:
75 |           example = t.process_example(example)
76 |         example['captions'] = captions
77 |       else:
78 |         for t in self.eval_transforms:
79 |           example = t.process_example(example)
80 | 
81 |         # Use the first caption. This won't be used in eval.
82 |         caption = (self._tokenizer.string_to_ids(example['captions'][0]) +
83 |                    mconfig.text_vocab_shift)
84 |         caption = utils.pad_to_max_len(caption, config.max_seq_len, -1)
85 |         example['captions'] = caption
86 |       return example
87 | 
88 |     dataset = dataset.map(_preprocess_single_example,
89 |                           num_parallel_calls=tf.data.experimental.AUTOTUNE)
90 |     return dataset
91 | 
92 |   def preprocess_batched(self, batched_examples, training):
93 |     """Batched preprocessing & sequence construction.
94 | 
95 |     Args:
96 |       batched_examples: a tuple of features (`dict`) and labels (`dict`),
97 |         containing images and labels.
98 |       training: `bool` indicating training or inference mode.
99 | 
100 |     Returns:
101 |       images: `float` of shape (bsz, h, w, c)
102 |       input_seq: `int` of shape (bsz, seqlen), or (bsz, instances, seqlen)
103 |         for multiple instances as in keypoint for example.
104 |       target_seq: `int` of shape (bsz, seqlen'), or (bsz, instances, seqlen')
105 |         for multiple instances as in keypoint for example.
106 |     """
107 |     config = self.config.task
108 |     mconfig = self.config.model
109 | 
110 |     if training:
111 |       response_seq = batched_examples['captions']  # (bsz, num_cap, max_seq_len)
112 |       prompt_seq = task_utils.build_prompt_seq_from_task_id(
113 |           self.task_vocab_id, response_seq)  # (bsz, 1)
114 |       label_seq = tf.concat([prompt_seq, response_seq], -1)
115 |       token_weights = tf.where(
116 |           response_seq == 1 + mconfig.text_vocab_shift,  # eos token
117 |           config.eos_token_weight, 1.0)
118 |       input_seq, target_seq = label_seq[..., :-1], label_seq[..., 1:]
119 | 
120 |       if config.input_seq_drop_rate > 0:
121 |         input_seq = tf.where(
122 |             tf.random.uniform(tf.shape(input_seq)) > config.input_seq_drop_rate,
123 |             input_seq, vocab.FAKE_TEXT_TOKEN)
124 | 
125 |       return batched_examples['image'], input_seq, target_seq, token_weights
126 |     else:
127 |       return (batched_examples['image'], batched_examples['captions'],
128 |               batched_examples)
129 | 
130 |   def infer(self, model, preprocessed_outputs):
131 |     """Perform inference given the model and preprocessed outputs."""
132 |     config = self.config.task
133 |     image, _, examples = preprocessed_outputs  # response_seq unused by default
134 |     bsz = tf.shape(image)[0]
135 |     prompt_seq = task_utils.build_prompt_seq_from_task_id(
136 |         self.task_vocab_id, prompt_shape=(bsz, 1))
137 |     pred_seq, logits, _ = model.infer(
138 |         image, prompt_seq, encoded=None,
139 |         temperature=config.temperature, top_k=config.top_k, top_p=config.top_p)
140 |     # if True:  # Sanity check by using gt response_seq as pred_seq.
141 |     #   pred_seq = response_seq
142 |     #   logits = tf.one_hot(pred_seq, self.vocab_size)
143 |     return examples, pred_seq, logits
144 | 
145 |   def postprocess_tpu(self, batched_examples, pred_seq, logits, training=False):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
146 |     """Organizing results after fitting the batched examples in graph.
147 | 
148 |     Such as updating metrics, putting together results for computing metrics in
149 |     CPU/numpy mode.
150 | 
151 |     Args:
152 |       batched_examples: a tuple of features (`dict`) and labels (`dict`),
153 |         containing images and labels.
154 |       pred_seq: `int` sequence of shape (bsz * instances, seqlen').
155 |       logits: `float` sequence of shape (bsz * instances, seqlen', vocab_size).
156 |       training: `bool` indicating training or inference mode.
157 | 
158 |     Returns:
159 |       results for passing to `postprocess_cpu` which runs in CPU mode.
160 |     """
161 |     return (batched_examples['image'], batched_examples['image/id'],
162 |             batched_examples['captions'], pred_seq)
163 | 
164 |   def postprocess_cpu(self, outputs, train_step,  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
165 |                       eval_step=None, training=False, summary_tag='eval',
166 |                       ret_results=False):
167 |     """CPU post-processing of outputs.
168 | 
169 |     Such as computing the metrics and logging image summaries.
170 | 
171 |     Note: current implementation only supports eval mode where gt are given in
172 |       metrics as they are not given here in outputs.
173 | 
174 |     Args:
175 |       outputs: a tuple of tensor passed from `postprocess_tpu`.
176 |       train_step: `int` scalar indicating training step of current model or
177 |         the checkpoint.
178 |       eval_step: `int` scalar indicating eval step for the given checkpoint.
179 |       training: `bool` indicating training or inference mode.
180 |       summary_tag: `string` of name scope for result summary.
181 |       ret_results: whether to return visualization images/captions.
182 | 
183 |     Returns:
184 |       A dict of visualization images/caption if ret_results, else None.
185 | """ 186 | config = self.config.task 187 | mconfig = self.config.model 188 | del summary_tag 189 | if not training: 190 | images, image_ids, gt_seq, pred_seq = outputs 191 | batch_size = tf.shape(image_ids)[0] 192 | pred_seq = tf.where(pred_seq == 0, 0, 193 | pred_seq - mconfig.text_vocab_shift) 194 | original_pred_seq = pred_seq.numpy() 195 | pred_seq = tf.minimum( 196 | tf.maximum(pred_seq, 0), self._tokenizer.vocab_size - 1) 197 | pred_seq = tf.cast(pred_seq, tf.int32) 198 | clipped_pred_seq = pred_seq.numpy() 199 | output_text = self._tokenizer.ids_to_strings( 200 | pred_seq, 201 | tf.ones([batch_size], tf.int32) * config.max_seq_len).numpy() 202 | output_text = [o.decode('utf-8') for o in output_text] 203 | if self._coco_metrics: 204 | self._coco_metrics.record_prediction(image_ids.numpy(), output_text, 205 | original_pred_seq, 206 | clipped_pred_seq) 207 | 208 | if ret_results: 209 | gt_seq = tf.where(gt_seq == 0, 0, 210 | gt_seq - mconfig.text_vocab_shift) 211 | gt_text = self._tokenizer.ids_to_strings( 212 | tf.cast(gt_seq, tf.int32), 213 | tf.ones([batch_size], tf.int32) * config.max_seq_len).numpy() 214 | gt_text = [s.decode('utf-8') for s in gt_text] 215 | return {'gt_images': images, 216 | 'gt_captions': gt_text, 217 | 'pred_captions': output_text} 218 | 219 | def compute_scalar_metrics(self, step): 220 | """Returns a dict containing scalar metrics to log.""" 221 | if self._coco_metrics: 222 | return self._coco_metrics.result(step) 223 | else: 224 | return {} 225 | 226 | def reset_metrics(self): 227 | """Reset states of metrics accumulators.""" 228 | if self._coco_metrics: 229 | self._coco_metrics.reset_states() 230 | -------------------------------------------------------------------------------- /tasks/image_generation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Task for image generation.""" 17 | 18 | import os 19 | 20 | from absl import logging 21 | import ml_collections 22 | import utils 23 | from data import data_utils 24 | from metrics.fid import TFGANMetricEvaluator 25 | from tasks import task as task_lib 26 | import tensorflow as tf 27 | 28 | 29 | @task_lib.TaskRegistry.register('image_generation') 30 | class TaskImageGeneration(task_lib.Task): # pytype: disable=base-class-error 31 | """TaskImageGeneration.""" 32 | 33 | def __init__(self, config: ml_collections.ConfigDict): 34 | super().__init__(config) 35 | self._metrics = {} 36 | self._metrics['eval_loss'] = tf.keras.metrics.Mean('eval_loss') 37 | self._tfgan_evaluator = TFGANMetricEvaluator( 38 | dataset_name=config.dataset.tfds_name, 39 | image_size=config.dataset.image_size) 40 | 41 | self._write_images_to_file = config.eval.get('write_images_to_file', False) 42 | if self._write_images_to_file: 43 | self._tfrecord_dir = os.path.join(config.model_dir, 'images', 44 | config.eval.tag) 45 | if not tf.io.gfile.exists(self._tfrecord_dir): 46 | tf.io.gfile.makedirs(self._tfrecord_dir) 47 | self._tfrecord_writer = None 48 | 49 | def preprocess_single(self, dataset, batch_duplicates, training): 50 | """Task-specific preprocessing of individual example in the dataset. 51 | 52 | Args: 53 | dataset: A tf.data.Dataset. 54 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times 55 | (as specified) and concating the augmented examples. 56 | training: bool. 57 | 58 | Returns: 59 | A dataset. 60 | """ 61 | image_size = self.config.dataset.image_size 62 | 63 | def _preprocess_single_example(examples): 64 | examples_list = [] 65 | for _ in range(batch_duplicates if training else 1): 66 | image_ = preprocess_image( 67 | examples['image'], 68 | height=image_size, 69 | width=image_size, 70 | cropping=self.config.dataset.cropping, 71 | flipping=self.config.dataset.flipping, 72 | training=training) 73 | if examples['label'].shape.ndims == 0: 74 | label_ = tf.one_hot(examples['label'], 75 | self.config.dataset.num_classes) 76 | else: 77 | label_ = examples['label'] 78 | examples_list.append({'image': image_, 'label': label_}) 79 | examples = utils.merge_list_of_dict(examples_list) 80 | return examples 81 | 82 | dataset = dataset.map(_preprocess_single_example, 83 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 84 | return dataset 85 | 86 | def preprocess_batched(self, examples, training): 87 | """Task-specific preprocessing of batched examples on accelerators (TPUs). 88 | 89 | Args: 90 | examples: `dict` of images and labels. 91 | training: bool. 
92 | 93 | Returns: 94 | images: `float` of shape (bsz, h, w, c) 95 | labels: `int` of shape (bsz) 96 | """ 97 | if training: 98 | return examples['image'], examples['label'] 99 | else: 100 | return examples['image'], examples['label'], examples 101 | 102 | def infer(self, model, preprocessed_outputs): 103 | """Perform inference given the model and preprocessed outputs.""" 104 | images, labels, examples = preprocessed_outputs 105 | outputs = model.sample( 106 | num_samples=tf.shape(images)[0], 107 | iterations=self.config.model.infer_iterations, 108 | method=self.config.model.sampler_name, 109 | images=images, 110 | labels=labels) 111 | if isinstance(outputs, tuple) or isinstance(outputs, list): 112 | samples, eval_loss = outputs 113 | self._metrics['eval_loss'].update_state(eval_loss) 114 | else: 115 | samples = outputs 116 | return examples, samples 117 | 118 | def postprocess_tpu(self, 119 | examples, 120 | samples, 121 | training=False): 122 | """Organizing results after fitting the batched examples in graph. 123 | 124 | Such as updating metrics, putting together results for computing metrics in 125 | CPU/numpy mode. 126 | 127 | Note: current implementation only support eval mode where gt are given in 128 | metrics as they are not constructed here from input_seq/target_seq. 129 | 130 | Args: 131 | examples: `dict` of images and labels. 132 | samples: `float` predicted image tensor of (bsz, h, w, c). 133 | training: `bool` indicating training or inference mode. 134 | 135 | Returns: 136 | results for passing to `postprocess_cpu` which runs in CPU mode. 137 | """ 138 | logging.info('Start postprocess_tpu.') 139 | # Get ground-truth images. 140 | if 'original_image' in examples: 141 | images = examples['original_image'] 142 | elif examples['image'].shape[-1] == 3: 143 | images = examples['image'] 144 | else: # ground-truth images not available in the dataset. 145 | images = samples 146 | 147 | # FID 148 | data_real, data_gen = self._tfgan_evaluator.preprocess_inputs( 149 | [images, samples], is_n1p1=False) 150 | (logits_real, pool3_real), (logits_gen, pool3_gen) = ( 151 | self._tfgan_evaluator.get_inception_stats([data_real, data_gen])) 152 | 153 | logging.info('postprocess_tpu done.') 154 | return (images, samples, logits_real, pool3_real, logits_gen, pool3_gen) 155 | 156 | def postprocess_cpu(self, 157 | outputs, 158 | train_step, 159 | eval_step=None, 160 | training=False, 161 | summary_tag='eval', 162 | ret_results=False): 163 | """CPU post-processing of outputs. 164 | 165 | Args: 166 | outputs: a tuple of tensor passed from `postprocess_tpu`. 167 | train_step: `int` scalar indicating training step of current model or the 168 | checkpoint. 169 | eval_step: `int` scalar indicating eval step for the given checkpoint. 170 | training: `bool` indicating training or inference mode. 171 | summary_tag: `string` of name scope for result summary. 172 | ret_results: whether to return visualization images. 173 | 174 | Returns: 175 | A dict of visualization images if ret_results, else None. 176 | """ 177 | logging.info('Start postprocess_cpu') 178 | images, samples, logits_real, pool3_real, logits_gen, pool3_gen = outputs 179 | 180 | # FID update. 181 | self._tfgan_evaluator.update_stats( 182 | logits_real, pool3_real, logits_gen, pool3_gen) 183 | 184 | # Image summary. 185 | bsz, h, w, c = utils.shape_as_list(samples) 186 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32) 187 | b = a 188 | vis_samples = samples[:a * a, ...] 
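# NOTE (illustrative): the lines above and below tile the first a*a samples into a single square grid for the image summary. E.g. with bsz=16 and 32x32x3 samples, a = b = 4 and the reshape/transpose yield one [1, 128, 128, 3] grid image in which sample k sits at grid row k // 4, column k % 4 (row-major). Samples beyond a*a are dropped, so the eval batch size is ideally a perfect square.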
189 | vis_samples = tf.reshape(vis_samples, [a, b, h, w, c]) 190 | vis_samples = tf.transpose(vis_samples, [0, 2, 1, 3, 4]) 191 | images_sum = tf.reshape(vis_samples, [1, a * h, b * w, c]) 192 | if eval_step < 2: 193 | tf.summary.image( 194 | f'{summary_tag}/samples_{eval_step}', images_sum, step=train_step) 195 | 196 | # Write gt and samples to tfrecord. 197 | if self._write_images_to_file: 198 | if self._tfrecord_writer is None: 199 | self._tfrecord_writer = tf.io.TFRecordWriter( 200 | os.path.join(self._tfrecord_dir, f'step-{train_step}.tfrecord')) 201 | for i in range(bsz): 202 | ex_str = create_tf_example(images[i], samples[i]) 203 | self._tfrecord_writer.write(ex_str) 204 | 205 | logging.info('postprocess_cpu done.') 206 | if ret_results: 207 | return {'gt': images, 'pred': samples} 208 | 209 | def compute_scalar_metrics(self, step): 210 | """Returns a dict containing scalar metrics to log.""" 211 | result = {} 212 | for metric in self._metrics.values(): 213 | result[metric.name] = metric.result().numpy() 214 | 215 | # FID 216 | result.update(self._tfgan_evaluator.compute_fid_score()) 217 | return result 218 | 219 | def reset_metrics(self): 220 | """Reset states of metrics accumulators.""" 221 | for metric in self._metrics.values(): 222 | metric.reset_states() 223 | 224 | self._tfgan_evaluator.reset() 225 | if self._write_images_to_file: 226 | self._tfrecord_writer.close() 227 | self._tfrecord_writer = None 228 | 229 | 230 | 231 | 232 | def create_tf_example(ref, hyp): 233 | """Creates a serialized tf example that stores the images. 234 | 235 | Args: 236 | ref: Tensor of [h, w, c], the ground-truth image. 237 | hyp: Tensor of [h, w, c], the generated image. 238 | 239 | Returns: 240 | the serialized tf example. 241 | """ 242 | ref_bytes = tf.io.encode_png( 243 | tf.image.convert_image_dtype(ref, tf.uint8)).numpy() 244 | hyp_bytes = tf.io.encode_png( 245 | tf.image.convert_image_dtype(hyp, tf.uint8)).numpy() 246 | feature = { 247 | 'hyp': tf.train.Feature(bytes_list=tf.train.BytesList(value=[hyp_bytes])), 248 | 'ref': tf.train.Feature(bytes_list=tf.train.BytesList(value=[ref_bytes])), 249 | } 250 | return tf.train.Example( 251 | features=tf.train.Features(feature=feature)).SerializeToString() 252 | 253 | 254 | def preprocess_image(image, 255 | height, 256 | width, 257 | cropping='none', 258 | flipping='none', 259 | training=False): 260 | """Preprocesses the given image. 261 | 262 | Args: 263 | image: `Tensor` representing an image of arbitrary size. 264 | height: Height of output image. 265 | width: Width of output image. 266 | cropping: which cropping to apply to the image. 267 | flipping: which flipping to apply to the image. 268 | training: `bool` for whether the preprocessing is for training. 269 | 270 | Returns: 271 | A preprocessed image `Tensor` of range [0, 1]. 272 | """ 273 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 274 | if cropping == 'center': 275 | image = data_utils.largest_center_square_crop(image) 276 | image = tf.image.resize( 277 | image, 278 | size=(height, width), 279 | method='bicubic', 280 | preserve_aspect_ratio=False, 281 | antialias=True) 282 | elif cropping != 'none': 283 | raise ValueError(f'Unknown cropping method {cropping}') 284 | if training: 285 | if flipping == 'left_right': 286 | image = tf.image.random_flip_left_right(image) 287 | elif flipping != 'none': 288 | raise ValueError(f'Unknown flipping method {flipping}') 289 | image = tf.ensure_shape(image, [height, width, 3]) # Let arch knows shape. 
290 | return image 291 | -------------------------------------------------------------------------------- /tasks/recognition.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Object recognition / image classification task.""" 17 | 18 | import ml_collections 19 | import utils 20 | from tasks import task as task_lib 21 | from simclr.tf2 import data_util as simclr_data_util 22 | import tensorflow as tf 23 | 24 | 25 | @task_lib.TaskRegistry.register('object_recognition') 26 | class TaskObjectRecognition(task_lib.Task): 27 | """Object recognition.""" 28 | 29 | def __init__(self, config: ml_collections.ConfigDict): 30 | super().__init__(config) 31 | self._metrics = { 32 | 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy('accuracy') 33 | } 34 | if 'linear_eval_all_layers' in config.model and ( 35 | config.model.linear_eval_all_layers): 36 | for i in range(config.model.num_encoder_layers+2): 37 | self._metrics['accuracy_%d'%i] = ( 38 | tf.keras.metrics.SparseCategoricalAccuracy('accuracy_%d'%i)) 39 | 40 | def preprocess_single(self, dataset, batch_duplicates, training): 41 | """Task-specific preprocessing of individual example in the dataset. 42 | 43 | Args: 44 | dataset: A tf.data.Dataset. 45 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times 46 | (as specified) and concating the augmented examples. 47 | training: bool. 48 | 49 | Returns: 50 | A dataset. 51 | """ 52 | config = self.config.task 53 | image_size = self.config.dataset.image_size 54 | 55 | def _preprocess_single_example(examples): 56 | test_crop = False if image_size <= 32 else True 57 | 58 | examples_list = [] 59 | for _ in range(batch_duplicates if training else 1): 60 | image_ = preprocess_image( 61 | examples['image'], 62 | height=image_size, 63 | width=image_size, 64 | training=training, 65 | color_jitter_strength=config.color_jitter_strength, 66 | test_crop=test_crop, 67 | train_crop=config.train_crop) 68 | if config.get('set_pixel_range_minus_one_to_one'): 69 | image_ = image_ * 2 - 1 70 | if examples['label'].shape.ndims == 0: 71 | label_ = tf.one_hot(examples['label'], 72 | self.config.dataset.num_classes) 73 | else: 74 | label_ = examples['label'] 75 | examples_list.append({'image': image_, 'label': label_}) 76 | examples = utils.merge_list_of_dict(examples_list) 77 | return examples 78 | 79 | dataset = dataset.map(_preprocess_single_example, 80 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 81 | return dataset 82 | 83 | def preprocess_batched(self, examples, training): 84 | """Task-specific preprocessing of batched examples on accelerators (TPUs). 85 | 86 | Args: 87 | examples: `dict` of images and labels. 88 | training: bool. 
89 | 90 | Returns: 91 | images: `float` of shape (bsz, h, w, c) 92 | labels: `int` of shape (bsz) 93 | """ 94 | if training: 95 | return examples['image'], examples['label'] 96 | else: 97 | return examples['image'], examples['label'], examples 98 | 99 | def infer(self, model, preprocessed_outputs): 100 | """Perform inference given the model and preprocessed outputs.""" 101 | image, _, examples = preprocessed_outputs # response_seq unused by default 102 | outputs = model(image, training=False) 103 | logits = outputs[0] 104 | if hasattr(model, 'encode_decode'): 105 | outputs = model.encode_decode(image, training=False) 106 | images_recon = outputs[0] 107 | elif hasattr(model, 'sample'): 108 | outputs = model.sample(num_samples=tf.shape(image)[0]) 109 | images_recon = outputs 110 | else: 111 | images_recon = image 112 | return examples, logits, images_recon 113 | 114 | def postprocess_tpu(self, examples, logits, images_recon, # pytype: disable=signature-mismatch # overriding-parameter-count-checks 115 | training=False): 116 | """Organizing results after fitting the batched examples in graph. 117 | 118 | Such as updating metrics, putting together results for computing metrics in 119 | CPU/numpy mode. 120 | 121 | Note: current implementation only support eval mode where gt are given in 122 | metrics as they are not constructed here from input_seq/target_seq. 123 | 124 | Args: 125 | examples: `dict` containing images and labels. 126 | logits: `float` sequence of shape (bsz, seqlen', vocab_size). 127 | images_recon: `float` predicted image tensor of (bsz, h, w, c). 128 | training: `bool` indicating training or inference mode. 129 | 130 | Returns: 131 | results for passing to `postprocess_cpu` which runs in CPU mode. 132 | """ 133 | images = examples['image'] 134 | labels = tf.argmax(examples['label'], -1) 135 | if isinstance(logits, list): 136 | self._metrics['accuracy'].update_state(labels, logits[-1]) 137 | if len(logits) > 1: 138 | assert self.config.model.linear_eval_all_layers 139 | for i, logits_ in enumerate(logits): 140 | self._metrics['accuracy_%d'%i].update_state(labels, logits_) 141 | else: 142 | self._metrics['accuracy'].update_state(labels, logits) 143 | return (images, images_recon) 144 | 145 | def postprocess_cpu(self, outputs, train_step, # pytype: disable=signature-mismatch # overriding-parameter-count-checks 146 | eval_step=None, training=False, summary_tag='eval', 147 | ret_results=False): 148 | """CPU post-processing of outputs. 149 | 150 | Args: 151 | outputs: a tuple of tensor passed from `postprocess_tpu`. 152 | train_step: `int` scalar indicating training step of current model or 153 | the checkpoint. 154 | eval_step: `int` scalar indicating eval step for the given checkpoint. 155 | training: `bool` indicating training or inference mode. 156 | summary_tag: `string` of name scope for result summary. 157 | ret_results: whether to return visualization images. 158 | 159 | Returns: 160 | A dict of visualization images if ret_results, else None. 
161 | """ 162 | images, images_recon = outputs 163 | if self.config.task.get('image_gen_sum'): 164 | bsz, h, w, c = utils.shape_as_list(images_recon) 165 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32) 166 | b = bsz // a 167 | images_recon = tf.reshape(images_recon, [a, b, h, w, c]) 168 | images_recon = tf.transpose(images_recon, [0, 2, 1, 3, 4]) 169 | images_sum = tf.reshape(images_recon, [1, a * h, b * w, c]) 170 | else: 171 | images_sum = tf.concat([images, images_recon], 2) 172 | if self.config.task.get('set_pixel_range_minus_one_to_one'): 173 | norm = lambda x: (x-tf.reduce_min(x))/(tf.reduce_max(x)-tf.reduce_min(x)) 174 | images_sum = norm(images_sum) 175 | if eval_step <= 5: 176 | tf.summary.image(summary_tag + '/gt_pred', images_sum, step=train_step) 177 | if ret_results: 178 | return {'gt': images, 'pred': images_recon} 179 | 180 | def compute_scalar_metrics(self, step): 181 | """Returns a dict containing scalar metrics to log.""" 182 | result = {} 183 | for metric in self._metrics.values(): 184 | result[metric.name] = metric.result().numpy() 185 | return result 186 | 187 | def reset_metrics(self): 188 | """Reset states of metrics accumulators.""" 189 | for metric in self._metrics.values(): 190 | metric.reset_states() 191 | 192 | 193 | def preprocess_image(image, height, width, training=False, 194 | color_jitter_strength=0., test_crop=True, train_crop=True): 195 | """Preprocesses the given image. 196 | 197 | Args: 198 | image: `Tensor` representing an image of arbitrary size. 199 | height: Height of output image. 200 | width: Width of output image. 201 | training: `bool` for whether the preprocessing is for training. 202 | color_jitter_strength: `float` between 0 and 1 indicating the color 203 | distortion strength, disable color distortion if not bigger than 0. 204 | test_crop: whether or not to extract a central crop of the images 205 | (as for standard ImageNet evaluation) during the evaluation. 206 | train_crop: whether or not to apply random crop during training. 207 | 208 | Returns: 209 | A preprocessed image `Tensor` of range [0, 1]. 210 | """ 211 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 212 | if image.shape[-1] == 1: 213 | image = tf.image.grayscale_to_rgb(image) 214 | if training: 215 | return simclr_data_util.preprocess_for_train( 216 | image, height, width, color_jitter_strength, train_crop) 217 | else: 218 | return simclr_data_util.preprocess_for_eval(image, height, width, test_crop) 219 | -------------------------------------------------------------------------------- /tasks/task.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Task base class.""" 17 | 18 | import abc 19 | import copy 20 | from absl import logging 21 | import ml_collections 22 | import registry 23 | import utils 24 | from data import transforms 25 | import tensorflow as tf 26 | 27 | TaskRegistry = registry.Registry() 28 | 29 | 30 | class Task(abc.ABC): 31 | """Task class. 32 | 33 | Providing: 34 | - Preprocessing functions for a specific task that turn raw features into 35 | inputs in a common interface. 36 | - Post-processing functions for a specific task that decode the model's 37 | outputs in a common interface. 38 | - Evaluation for a specific task. 39 | - Important task properties, such as vocab size, max seq len. 40 | """ 41 | 42 | def __init__(self, 43 | config: ml_collections.ConfigDict): 44 | self.config = config 45 | 46 | train_transforms = config.task.get('train_transforms', []) 47 | eval_transforms = config.task.get('eval_transforms', []) 48 | self.train_transforms = [ 49 | transforms.TransformRegistry.lookup(t.name)(t) 50 | for t in train_transforms] 51 | self.eval_transforms = [ 52 | transforms.TransformRegistry.lookup(t.name)(t) for t in eval_transforms] 53 | 54 | @property 55 | def task_vocab_id(self): 56 | return self.config.task.vocab_id 57 | 58 | @abc.abstractmethod 59 | def preprocess_single(self, dataset: tf.data.Dataset, batch_duplicates: int, 60 | training: bool): 61 | """Task-specific preprocessing of individual example in the dataset. 62 | 63 | Args: 64 | dataset: A tf.data.Dataset. 65 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times 66 | (as specified) and concatenating the augmented examples. 67 | training: bool. 68 | 69 | Returns: 70 | A dataset. 71 | """ 72 | 73 | def preprocess_single_example(self, example, training, batch_duplicates=1): 74 | """Preprocessing of a single example. 75 | 76 | This should be called in preprocess_single. 77 | 78 | Args: 79 | example: A dict of name to Tensor. 80 | training: bool. 81 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times 82 | (as specified) and concatenating the augmented examples. 83 | 84 | Returns: 85 | A dict of name to Tensor. 86 | """ 87 | if training: 88 | example_list = [] 89 | for _ in range(batch_duplicates): 90 | example_ = copy.copy(example) 91 | for t in self.train_transforms: 92 | example_ = t.process_example(example_) 93 | example_list.append(example_) 94 | example = utils.merge_list_of_dict(example_list) 95 | else: 96 | for t in self.eval_transforms: 97 | example = t.process_example(example) 98 | return example 99 | 100 | @abc.abstractmethod 101 | def preprocess_batched(self, batched_examples, training): 102 | """Task-specific preprocessing of batched examples on accelerators (TPUs). 103 | 104 | Args: 105 | batched_examples: preprocessed and batched examples. 106 | training: bool. 107 | 108 | Returns batched inputs in a common interface for modeling. 109 | """ 110 | 111 | @abc.abstractmethod 112 | def postprocess_tpu(self): 113 | """Task-specific post processing on accelerators (TPUs). 114 | 115 | This is intended for inference / evaluation time only. 116 | 117 | Returns a list of tensors for `postprocess_cpu` to further process. 118 | """ 119 | 120 | @abc.abstractmethod 121 | def postprocess_cpu(self): 122 | """Task-specific post processing on CPUs. 123 | 124 | This is intended for inference / evaluation time only. 125 | 126 | It receives outputs from `postprocess_tpu`, further processes, and updates 127 | internal states (e.g.
_metrics). 128 | """ 129 | 130 | def _log_metrics(self, metrics_dict, step): 131 | for key, value in metrics_dict.items(): 132 | logging.info('Step: [%d] %s = %f', step, key, value) 133 | tf.summary.scalar(key, value, step) 134 | 135 | def evaluate(self, summary_writer, step, eval_tag): 136 | """Evaluate results on accumulated outputs (after multiple infer steps). 137 | 138 | Args: 139 | summary_writer: the summary writer. 140 | step: current step. 141 | eval_tag: `string` name scope for eval result summary. 142 | 143 | Returns: 144 | result as a `dict`. 145 | """ 146 | metrics = self.compute_scalar_metrics(step) 147 | with summary_writer.as_default(): 148 | with tf.name_scope(eval_tag): 149 | self._log_metrics(metrics, step) 150 | summary_writer.flush() 151 | self.reset_metrics() 152 | return metrics 153 | 154 | @abc.abstractmethod 155 | def compute_scalar_metrics(self, step): 156 | """Returns a dict containing scalar metrics to log.""" 157 | 158 | @abc.abstractmethod 159 | def reset_metrics(self): 160 | """Reset states of metrics accumulators.""" 161 | -------------------------------------------------------------------------------- /tasks/video_generation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Task for video generation.""" 17 | from absl import logging 18 | import ml_collections 19 | import utils 20 | from data import data_utils 21 | from metrics import fvd 22 | from tasks import task as task_lib 23 | import tensorflow as tf 24 | 25 | 26 | @task_lib.TaskRegistry.register('video_generation') 27 | class TaskVideoGeneration(task_lib.Task): # pytype: disable=base-class-error 28 | """Task for video generation.""" 29 | 30 | def __init__(self, config: ml_collections.ConfigDict): 31 | super().__init__(config) 32 | self._metrics = {} 33 | self._tfgan_evaluator = fvd.FVDMetricEvaluator( 34 | dataset_name=config.dataset.tfds_name, 35 | image_size=config.dataset.image_size) 36 | 37 | def preprocess_single(self, dataset, batch_duplicates, training): 38 | """Task-specific preprocessing of individual example in the dataset. 39 | 40 | Args: 41 | dataset: A tf.data.Dataset. 42 | batch_duplicates: `int`, enlarge a batch by augmenting it multiple times 43 | (as specified) and concating the augmented examples. 44 | training: bool. 45 | 46 | Returns: 47 | A dataset. 
48 | """ 49 | image_size = self.config.dataset.image_size 50 | 51 | def _preprocess_single_example(features, labels): 52 | features_list, labels_list = [], [] 53 | for _ in range(batch_duplicates if training else 1): 54 | video = preprocess_video( 55 | features['video'], 56 | height=image_size, 57 | width=image_size, 58 | seq_len=self.config.dataset.get('seq_len'), 59 | cropping=self.config.dataset.cropping, 60 | flipping=self.config.dataset.flipping, 61 | training=training) 62 | label = tf.one_hot(labels['label'], self.config.dataset.num_classes) 63 | features_list.append({'video': video}) 64 | labels_list.append({'label': label}) 65 | features = utils.merge_list_of_dict(features_list) 66 | labels = utils.merge_list_of_dict(labels_list) 67 | return features, labels 68 | 69 | dataset = dataset.map(_preprocess_single_example, 70 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 71 | return dataset 72 | 73 | def preprocess_batched(self, batched_examples, training): 74 | """Task-specific preprocessing of batched examples on accelerators (TPUs). 75 | 76 | Args: 77 | batched_examples: tuples of feature and label tensors that are 78 | preprocessed, batched, and stored with `dict`. 79 | training: bool. 80 | 81 | Returns: 82 | videos: `float` of shape (bsz, t, h, w, c) 83 | labels: `int` of shape (bsz) 84 | """ 85 | features, labels = batched_examples 86 | if training: 87 | return features['video'], labels['label'] 88 | else: 89 | return features['video'], labels['label'], batched_examples 90 | 91 | def infer(self, model, preprocessed_outputs, **kwargs): 92 | """Perform inference given the model and preprocessed outputs.""" 93 | videos, labels, examples = preprocessed_outputs 94 | samples = model.sample( 95 | num_samples=tf.shape(videos)[0], 96 | iterations=self.config.model.infer_iterations, 97 | method=self.config.model.sampler_name, 98 | images=videos, 99 | labels=labels, 100 | **kwargs) 101 | return examples, samples 102 | 103 | def postprocess_tpu(self, 104 | batched_examples, 105 | samples, 106 | training=False): 107 | """Organizes results after running the batched examples through the graph. 108 | 109 | Such as updating metrics, putting together results for computing metrics in 110 | CPU/numpy mode. 111 | 112 | Note: the current implementation only supports eval mode, where gt is 113 | provided to the metrics rather than constructed here from input_seq/target_seq. 114 | 115 | Args: 116 | batched_examples: a tuple of features (`dict`) and labels (`dict`), 117 | containing videos and labels. 118 | samples: `float` predicted video tensor of (bsz, t, h, w, c). 119 | training: `bool` indicating training or inference mode. 120 | 121 | Returns: 122 | results for passing to `postprocess_cpu` which runs in CPU mode. 123 | """ 124 | logging.info('Start postprocess_tpu.') 125 | features, labels = batched_examples 126 | videos = features['video'] 127 | labels = tf.argmax(labels['label'], -1) 128 | 129 | # FID 130 | data_real, data_gen = self._tfgan_evaluator.preprocess_inputs( 131 | [videos, samples], is_n1p1=False) 132 | (logits_real, pool3_real), (logits_gen, pool3_gen) = ( 133 | self._tfgan_evaluator.get_inception_stats([data_real, data_gen])) 134 | 135 | logging.info('postprocess_tpu done.') 136 | return (videos, samples, logits_real, pool3_real, logits_gen, pool3_gen) 137 | 138 | def postprocess_cpu(self, 139 | outputs, 140 | train_step, 141 | eval_step=None, 142 | training=False, 143 | summary_tag='eval', 144 | ret_results=False): 145 | """CPU post-processing of outputs.
146 | 147 | Args: 148 | outputs: a tuple of tensor passed from `postprocess_tpu`. 149 | train_step: `int` scalar indicating training step of current model or the 150 | checkpoint. 151 | eval_step: `int` scalar indicating eval step for the given checkpoint. 152 | training: `bool` indicating training or inference mode. 153 | summary_tag: `string` of name scope for result summary. 154 | ret_results: whether to return visualization videos. 155 | 156 | Returns: 157 | A dict of visualization videos if ret_results, else None. 158 | """ 159 | logging.info('Start postprocess_cpu') 160 | videos, samples, logits_real, pool3_real, logits_gen, pool3_gen = outputs 161 | 162 | # FID update. 163 | self._tfgan_evaluator.update_stats( 164 | logits_real, pool3_real, logits_gen, pool3_gen) 165 | 166 | # videos summary. 167 | bsz, t, h, w, c = utils.shape_as_list(samples) 168 | a = tf.cast(tf.math.sqrt(tf.cast(bsz, tf.float32)), tf.int32) 169 | b = a 170 | vis_samples = samples[:a * a, ...] 171 | vis_samples = tf.reshape(vis_samples, [a, b, t, h, w, c]) 172 | vis_samples = tf.transpose(vis_samples, [0, 3, 1, 2, 4, 5]) 173 | videos_sum = tf.reshape(vis_samples, [1, a * h, b * t * w, c]) 174 | if eval_step < 2: 175 | tf.summary.image( 176 | f'{summary_tag}/samples_{eval_step}', videos_sum, step=train_step) 177 | 178 | logging.info('postprocess_cpu done.') 179 | if ret_results: 180 | return {'gt': videos, 'pred': samples} 181 | 182 | def compute_scalar_metrics(self, step): 183 | """Returns a dict containing scalar metrics to log.""" 184 | result = {} 185 | for metric in self._metrics.values(): 186 | result[metric.name] = metric.result().numpy() 187 | 188 | # FID 189 | result.update(self._tfgan_evaluator.compute_fid_score()) 190 | return result 191 | 192 | def reset_metrics(self): 193 | """Reset states of metrics accumulators.""" 194 | for metric in self._metrics.values(): 195 | metric.reset_states() 196 | 197 | self._tfgan_evaluator.reset() 198 | 199 | 200 | def preprocess_video(video, 201 | height, 202 | width, 203 | seq_len=None, 204 | cropping='none', 205 | flipping='none', 206 | training=False): 207 | """Preprocesses the given video. 208 | 209 | Args: 210 | video: `Tensor` representing an video of arbitrary size (B x H x W x C). 211 | height: Height of output video. 212 | width: Width of output video. 213 | seq_len: Length of sequence to crop. 214 | cropping: which cropping to apply to the video. 215 | flipping: which flipping to apply to the video. 216 | training: `bool` for whether the preprocessing is for training. 217 | 218 | Returns: 219 | A preprocessed video `Tensor` of range [0, 1]. 
220 | """ 221 | seq_len = seq_len if seq_len is not None else video.shape[0] 222 | video = tf.image.convert_image_dtype(video, dtype=tf.float32) 223 | if cropping == 'center': 224 | video = data_utils.largest_center_square(video) 225 | video = tf.image.resize( 226 | video, 227 | size=(height, width), 228 | method='area', 229 | preserve_aspect_ratio=False, 230 | antialias=True) 231 | elif cropping == 'random': 232 | video = data_utils.crop_video( 233 | frames=video, 234 | height=height, 235 | width=width, 236 | seq_len=seq_len, 237 | random=True) 238 | elif cropping == 'random_resize': 239 | video = tf.image.resize( 240 | video, 241 | size=(int(height*1.25), int(width*1.25)), # TODO ajabri: make general 242 | method='area', 243 | preserve_aspect_ratio=False, 244 | antialias=True) 245 | video = data_utils.crop_video( 246 | frames=video, 247 | height=height, 248 | width=width, 249 | seq_len=seq_len, 250 | random=True) 251 | elif cropping != 'none': 252 | raise ValueError(f'Unknown cropping method {cropping}') 253 | if training: 254 | if flipping == 'left_right': 255 | video = tf.image.random_flip_left_right(video) 256 | elif flipping != 'none': 257 | raise ValueError(f'Unknown flipping method {flipping}') 258 | video = tf.reshape(video, [seq_len, height, width, 3]) # let arch know shape 259 | 260 | return video 261 | -------------------------------------------------------------------------------- /tasks/visualization/static_shape.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # Source: https://github.com/google/automl/tree/master/efficientdet/visualize/static_shape.py 17 | # Copyright 2020 Google Research. All Rights Reserved. 18 | # 19 | # Licensed under the Apache License, Version 2.0 (the "License"); 20 | # you may not use this file except in compliance with the License. 21 | # You may obtain a copy of the License at 22 | # 23 | # http://www.apache.org/licenses/LICENSE-2.0 24 | # 25 | # Unless required by applicable law or agreed to in writing, software 26 | # distributed under the License is distributed on an "AS IS" BASIS, 27 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28 | # See the License for the specific language governing permissions and 29 | # limitations under the License. 30 | # ============================================================================== 31 | """Helper functions to access TensorShape values. 32 | 33 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 34 | """ 35 | 36 | 37 | def get_dim_as_int(dim): 38 | """Utility to get v1 or v2 TensorShape dim as an int. 39 | 40 | Args: 41 | dim: The TensorShape dimension to get as an int 42 | 43 | Returns: 44 | None or an int. 
45 | """ 46 | try: 47 | return dim.value 48 | except AttributeError: 49 | return dim 50 | 51 | 52 | def get_batch_size(tensor_shape): 53 | """Returns batch size from the tensor shape. 54 | 55 | Args: 56 | tensor_shape: A rank 4 TensorShape. 57 | 58 | Returns: 59 | An integer representing the batch size of the tensor. 60 | """ 61 | tensor_shape.assert_has_rank(rank=4) 62 | return get_dim_as_int(tensor_shape[0]) 63 | 64 | 65 | def get_height(tensor_shape): 66 | """Returns height from the tensor shape. 67 | 68 | Args: 69 | tensor_shape: A rank 4 TensorShape. 70 | 71 | Returns: 72 | An integer representing the height of the tensor. 73 | """ 74 | tensor_shape.assert_has_rank(rank=4) 75 | return get_dim_as_int(tensor_shape[1]) 76 | 77 | 78 | def get_width(tensor_shape): 79 | """Returns width from the tensor shape. 80 | 81 | Args: 82 | tensor_shape: A rank 4 TensorShape. 83 | 84 | Returns: 85 | An integer representing the width of the tensor. 86 | """ 87 | tensor_shape.assert_has_rank(rank=4) 88 | return get_dim_as_int(tensor_shape[2]) 89 | 90 | 91 | def get_depth(tensor_shape): 92 | """Returns depth from the tensor shape. 93 | 94 | Args: 95 | tensor_shape: A rank 4 TensorShape. 96 | 97 | Returns: 98 | An integer representing the depth of the tensor. 99 | """ 100 | tensor_shape.assert_has_rank(rank=4) 101 | return get_dim_as_int(tensor_shape[3]) 102 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pix2Seq Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Vocab.""" 17 | 18 | # A shared vocab among tasks and its structure - 19 | # Special tokens: [0, 99). 20 | # Class tokens: [100, coord_vocab_shift). Total coord_vocab_shift - 100 classes. 21 | # Coordinate tokens: [coord_vocab_shift, text_vocab_shift). 22 | # Text tokens: [text_vocab_shift, ...]. 23 | 24 | PADDING_TOKEN = 0 25 | 26 | # 10-29 reserved for task id. 27 | 28 | FAKE_CLASS_TOKEN = 30 29 | FAKE_TEXT_TOKEN = 30 # Same token to represent fake class and fake text. 30 | SEPARATOR_TOKEN = 40 31 | INVISIBLE_TOKEN = 41 32 | 33 | BASE_VOCAB_SHIFT = 100 34 | 35 | # Floats used to represent padding and separator in the flat list of polygon 36 | # coords, and invisibility in the key points. 37 | PADDING_FLOAT = -1. 38 | SEPARATOR_FLOAT = -2. 39 | INVISIBLE_FLOAT = -3. 40 | FLOATS = [PADDING_FLOAT, SEPARATOR_FLOAT, INVISIBLE_FLOAT] 41 | TOKENS = [PADDING_TOKEN, SEPARATOR_TOKEN, INVISIBLE_TOKEN] 42 | FLOAT_TO_TOKEN = dict(zip(FLOATS, TOKENS)) 43 | TOKEN_TO_FLOAT = dict(zip(TOKENS, FLOATS)) 44 | --------------------------------------------------------------------------------
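The shared vocab above partitions token ids into special, class, coordinate and text ranges. As a minimal sketch of how such a layout can be used (illustrative only: the helper `floats_to_tokens` below is hypothetical, and `coord_vocab_shift` / `quantization_bins` are example values rather than the config-defined ones; the repository's own conversion utilities are not shown here), special floats map to their reserved tokens via FLOAT_TO_TOKEN while normalized coordinates are quantized and shifted into the coordinate range:

import vocab  # assumes this sketch is run from the repository root

def floats_to_tokens(values, coord_vocab_shift=1000, quantization_bins=1000):
  """Maps special floats and normalized coords in [0, 1] to token ids."""
  tokens = []
  for v in values:
    if v in vocab.FLOAT_TO_TOKEN:  # padding / separator / invisible markers
      tokens.append(vocab.FLOAT_TO_TOKEN[v])
    else:  # quantize a coordinate into [coord_vocab_shift, coord_vocab_shift + bins)
      bin_id = min(int(v * quantization_bins), quantization_bins - 1)
      tokens.append(bin_id + coord_vocab_shift)
  return tokens

# Prints [0, 1250, 40, 1900] with the example shift/bins above.
print(floats_to_tokens([vocab.PADDING_FLOAT, 0.25, vocab.SEPARATOR_FLOAT, 0.9]))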