├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md └── big_vision ├── __init__.py ├── configs ├── __init__.py ├── bit_i1k.py ├── bit_i21k.py ├── common.py ├── common_fewshot.py ├── load_and_eval.py ├── mlp_mixer_i1k.py ├── proj │ ├── cappa │ │ ├── README.md │ │ ├── cappa_architecture.png │ │ └── pretrain.py │ ├── clippo │ │ ├── README.md │ │ ├── clippo_colab.ipynb │ │ └── train_clippo.py │ ├── distill │ │ ├── README.md │ │ ├── bigsweep_flowers_pet.py │ │ ├── bigsweep_food_sun.py │ │ ├── bit_i1k.py │ │ └── common.py │ ├── flexivit │ │ ├── README.md │ │ ├── i1k_deit3_distill.py │ │ ├── i21k_distill.py │ │ ├── i21k_sup.py │ │ └── timing.py │ ├── givt │ │ ├── README.md │ │ ├── givt_coco_panoptic.py │ │ ├── givt_demo_colab.ipynb │ │ ├── givt_imagenet2012.py │ │ ├── givt_nyu_depth.py │ │ ├── givt_overview.png │ │ ├── vae_coco_panoptic.py │ │ └── vae_nyu_depth.py │ ├── gsam │ │ └── vit_i1k_gsam_no_aug.py │ ├── image_text │ │ ├── README.md │ │ ├── README_lit.md │ │ ├── README_siglip2.md │ │ ├── SigLIP2_demo.ipynb │ │ ├── SigLIP_demo.ipynb │ │ ├── common.py │ │ ├── lit.ipynb │ │ └── siglip_lit_coco.py │ ├── jet │ │ └── imagenet64.py │ ├── jetformer │ │ ├── README.md │ │ ├── jetformer_image_text.py │ │ ├── jetformer_imagenet2012.py │ │ └── jetformer_overview.png │ ├── paligemma │ │ ├── README.md │ │ ├── finetune_paligemma.ipynb │ │ ├── paligemma.png │ │ ├── paligemma2.png │ │ └── transfers │ │ │ ├── activitynet_cap.py │ │ │ ├── activitynet_qa.py │ │ │ ├── ai2d.py │ │ │ ├── aokvqa_da.py │ │ │ ├── aokvqa_mc.py │ │ │ ├── chartqa.py │ │ │ ├── coco35l.py │ │ │ ├── cococap.py │ │ │ ├── common.py │ │ │ ├── docvqa.py │ │ │ ├── forkme.py │ │ │ ├── gqa.py │ │ │ ├── infovqa.py │ │ │ ├── msrvtt_cap.py │ │ │ ├── msrvtt_qa.py │ │ │ ├── msvd_qa.py │ │ │ ├── nlvr2.py │ │ │ ├── ocrvqa.py │ │ │ ├── okvqa.py │ │ │ ├── pope.py │ │ │ ├── refcoco_seg.py │ │ │ ├── rsvqa_hr.py │ │ │ ├── rsvqa_lr.py │ │ │ ├── scicap.py │ │ │ ├── science_qa.py │ │ │ ├── screen2words.py │ │ │ ├── stvqa.py │ │ │ ├── tallyqa.py │ │ │ ├── textcaps.py │ │ │ ├── textvqa.py │ │ │ ├── vatex_cap.py │ │ │ ├── vertexai_l4.py │ │ │ ├── vizwizvqa.py │ │ │ ├── vqav2.py │ │ │ └── widgetcap.py │ ├── reward_tune │ │ └── detection_reward.py │ ├── scaling_laws │ │ └── train_vit_g.py │ └── uvim │ │ ├── README.md │ │ ├── train_coco_panoptic_pretrained.py │ │ ├── train_imagenet2012_colorization_pretrained.py │ │ ├── train_nyu_depth_pretrained.py │ │ ├── uvim_color_task.ipynb │ │ ├── uvim_depth_task.ipynb │ │ ├── uvim_panoptic_task.ipynb │ │ ├── vqvae_coco_panoptic.py │ │ ├── vqvae_imagenet2012_colorization.py │ │ └── vqvae_nyu_depth.py ├── transfer.py ├── vit_i1k.py ├── vit_i21k.py └── vit_s16_i1k.py ├── datasets ├── ai2d │ └── ai2d.py ├── aokvqa │ └── aokvqa.py ├── chartqa │ └── chartqa.py ├── coco35l │ └── coco35l.py ├── core.py ├── countbenchqa │ ├── countbenchqa.py │ └── data │ │ └── countbench_paired_questions.json ├── docvqa │ └── docvqa.py ├── gqa │ └── gqa.py ├── imagenet │ └── class_names.py ├── infovqa │ └── infovqa.py ├── jsonl.py ├── nocaps │ └── nocaps.py ├── okvqa │ └── okvqa.py ├── pope │ └── pope.py ├── refcoco │ └── refcoco.py ├── rsvqa_hr │ └── rsvqa_hr.py ├── rsvqa_lr │ └── rsvqa_lr.py ├── scicap │ └── scicap.py ├── science_qa │ └── science_qa.py ├── screen2words │ └── screen2words.py ├── sequence_packing.py ├── stvqa │ └── stvqa.py ├── tallyqa │ └── tallyqa.py ├── textcaps │ └── textcaps.py ├── textvqa │ └── textvqa.py ├── tfds.py ├── vizwizvqa │ └── vizwizvqa.py ├── vqa │ └── vqa.py ├── widgetcap │ └── widgetcap.py ├── xgqa │ └── 
xgqa.py └── xm3600 │ └── xm3600.py ├── evaluators ├── __init__.py ├── classification.py ├── common.py ├── fewshot_lsr.py ├── mean.py ├── proj │ ├── cappa │ │ ├── perplexity.py │ │ └── scoring_classifier.py │ ├── distill │ │ └── distance.py │ ├── givt │ │ ├── coco_panoptic.py │ │ ├── nyu_depth.py │ │ └── save_predictions.py │ ├── image_text │ │ ├── contrastive.py │ │ ├── discriminative_classifier.py │ │ ├── discriminative_classifier_test.py │ │ ├── image_text_retrieval.py │ │ ├── image_text_retrieval_test.py │ │ ├── prompt_engineering.py │ │ ├── prompt_engineering_constants.py │ │ ├── prompt_engineering_test.py │ │ ├── retrieval.py │ │ └── retrieval_test.py │ ├── paligemma │ │ ├── perplexity.py │ │ └── transfers │ │ │ ├── chartqa.py │ │ │ ├── coco_caption.py │ │ │ ├── pope.py │ │ │ ├── rsvqa.py │ │ │ ├── science_qa.py │ │ │ ├── segmentation.py │ │ │ ├── storepreds.py │ │ │ ├── tallyqa.py │ │ │ ├── vqa.py │ │ │ └── vqav2.py │ └── uvim │ │ ├── coco_panoptic.py │ │ ├── coltran_fid.py │ │ ├── coltran_fid_data │ │ ├── eval_file_names.txt │ │ └── reference_file_names.txt │ │ ├── common.py │ │ ├── compute_mean.py │ │ ├── nyu_depth.py │ │ ├── psnr.py │ │ └── save_predictions.py └── save.py ├── input_pipeline.py ├── models ├── __init__.py ├── bit.py ├── bit_paper.py ├── common.py ├── mlp_mixer.py ├── ppp │ ├── __init__.py │ └── gemma.py ├── proj │ ├── cappa │ │ └── cappa.py │ ├── clippo │ │ └── one_tower.py │ ├── flaxformer │ │ ├── bert.py │ │ ├── bert_test.py │ │ └── bert_test_util.py │ ├── flexi │ │ ├── vit.py │ │ └── vit_test.py │ ├── givt │ │ ├── adaptor.py │ │ ├── adaptor_test.py │ │ ├── cnn.py │ │ ├── decode.py │ │ ├── decode_test.py │ │ ├── givt.py │ │ ├── givt_test.py │ │ ├── parallel_decode.py │ │ ├── parallel_decode_test.py │ │ ├── vae.py │ │ └── vit.py │ ├── image_text │ │ ├── naflex_vit.py │ │ ├── text_transformer.py │ │ ├── two_towers.py │ │ └── utils.py │ ├── jet │ │ └── jet.py │ ├── jetformer │ │ ├── jetformer.py │ │ └── patch_pca.py │ ├── paligemma │ │ ├── gemma_bv.py │ │ └── paligemma.py │ └── uvim │ │ ├── decode.py │ │ ├── vit.py │ │ ├── vit_test.py │ │ ├── vtt.py │ │ └── vtt_test.py └── vit.py ├── optax.py ├── optax_test.py ├── pp ├── __init__.py ├── archive │ ├── __init__.py │ ├── autoaugment.py │ └── randaug.py ├── autoaugment.py ├── builder.py ├── builder_test.py ├── ops_general.py ├── ops_general_test.py ├── ops_image.py ├── ops_image_test.py ├── ops_text.py ├── ops_text_test.py ├── proj │ ├── clippo │ │ ├── download_unifont.sh │ │ └── pp_ops.py │ ├── flaxformer │ │ ├── bert_ops.py │ │ └── bert_ops_test.py │ ├── givt │ │ └── pp_ops.py │ ├── image_text │ │ ├── ops_naflex.py │ │ └── ops_naflex_test.py │ ├── paligemma │ │ ├── ops.py │ │ ├── robustness.py │ │ ├── sciqa_ops.py │ │ ├── segmentation.py │ │ ├── video.py │ │ └── widgetcap.py │ └── uvim │ │ ├── pp_ops.py │ │ └── pp_ops_test.py ├── registry.py ├── registry_test.py ├── tokenizer.py ├── utils.py └── utils_test.py ├── requirements.txt ├── run_tpu.sh ├── sharding.py ├── tools ├── download_tfds_datasets.py ├── eval_only.py └── lit_demo │ ├── README.md │ ├── build.js │ ├── package.json │ └── src │ ├── app.ts │ ├── components │ ├── image-carousel.scss │ ├── image-carousel.ts │ ├── image-prompts.scss │ ├── image-prompts.ts │ ├── lit-demo-app.scss │ ├── lit-demo-app.ts │ ├── loading-animation.scss │ ├── loading-animation.ts │ ├── message-list.scss │ ├── message-list.ts │ ├── model-controls.scss │ └── model-controls.ts │ ├── exports.ts │ ├── index.html │ ├── lit_demo │ ├── app.ts │ ├── compute.ts │ ├── constants.ts │ ├── data.ts │ 
└── url_utils.ts │ ├── playground.html │ ├── style.scss │ ├── style │ ├── colors.scss │ └── mixins.scss │ ├── tokenizers │ ├── common.ts │ ├── index.ts │ ├── sentencepiece_bpe.ts │ ├── sentencepiece_bpe_test.ts │ ├── sentencepiece_unigram.ts │ ├── sentencepiece_unigram_test.ts │ └── trie.ts │ └── tsconfig.json ├── train.py ├── trainers └── proj │ ├── cappa │ ├── generative.py │ └── predict_fns.py │ ├── distill │ └── distill.py │ ├── flexi │ ├── common.py │ ├── distill.py │ └── train.py │ ├── givt │ ├── generative.py │ ├── utils.py │ └── vae.py │ ├── gsam │ ├── gsam.py │ └── train.py │ ├── image_text │ ├── _deprecated_contrastive.py │ └── siglip.py │ ├── jet │ └── train.py │ ├── jetformer │ ├── predict_fns.py │ └── train.py │ ├── paligemma │ ├── predict_fns.py │ ├── run.py │ └── train.py │ └── uvim │ ├── coco_utils.py │ ├── colorization_task.py │ ├── depth_task.py │ ├── panoptic_task.py │ ├── train.py │ └── vqvae.py ├── utils.py └── utils_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | At this time we do not plan to accept non-trivial contributions. The main 4 | purpose of this codebase is to allow the community to reproduce results from our 5 | publications. 6 | 7 | You are however free to start a fork of the project for your purposes as 8 | permitted by the license. 9 | 10 | ## Contributor License Agreement 11 | 12 | Contributions to this project must be accompanied by a Contributor License 13 | Agreement (CLA). You (or your employer) retain the copyright to your 14 | contribution; this simply gives us permission to use and redistribute your 15 | contributions as part of the project. Head over to 16 | to see your current agreements on file or 17 | to sign a new one. 18 | 19 | You generally only need to submit a CLA once, so if you've already submitted one 20 | (even if it was for a different project), you probably don't need to do it 21 | again. 22 | 23 | ## Community Guidelines 24 | 25 | This project follows 26 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 27 | -------------------------------------------------------------------------------- /big_vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/__init__.py -------------------------------------------------------------------------------- /big_vision/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/__init__.py -------------------------------------------------------------------------------- /big_vision/configs/bit_i21k.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long 16 | r"""A config for pre-training BiT on ImageNet-21k. 17 | 18 | This config relies on the Imagenet-21k tfds dataset, which is not yet 19 | available publicly in TFDS. We intend to add the dataset to public TFDS soon, 20 | and this config will then be runnable. 21 | """ 22 | 23 | from big_vision.configs.common_fewshot import get_fewshot_lsr 24 | import ml_collections as mlc 25 | 26 | 27 | def get_config(): 28 | """Config for training on imagenet-21k.""" 29 | config = mlc.ConfigDict() 30 | 31 | config.seed = 0 32 | config.total_epochs = 90 33 | config.num_classes = 21843 34 | config.init_head_bias = -10.0 35 | config.loss = 'sigmoid_xent' 36 | 37 | config.input = dict() 38 | config.input.data = dict( 39 | name='imagenet21k', 40 | split='full[51200:]', 41 | ) 42 | config.input.batch_size = 4096 43 | config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. 44 | 45 | pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' 46 | pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') 47 | pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') 48 | config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k 49 | pp_eval = 'decode|resize_small(256)|central_crop(224)' 50 | 51 | config.log_training_steps = 50 52 | config.ckpt_steps = 1000 53 | 54 | # Model section 55 | config.model_name = 'bit_paper' 56 | config.model = dict(depth=50, width=1.0) 57 | 58 | # Optimizer section 59 | config.optax_name = 'big_vision.momentum_hp' 60 | config.grad_clip_norm = 1.0 61 | 62 | # linear scaling rule. Don't forget to sweep if sweeping batch_size. 63 | config.lr = (0.03 / 256) * config.input.batch_size 64 | config.wd = (3e-5 / 256) * config.input.batch_size 65 | config.schedule = dict(decay_type='cosine', warmup_steps=5000) 66 | 67 | # Evaluations on i21k itself. 68 | def eval_i21k(split): 69 | return dict( 70 | type='classification', 71 | data={**config.input.data, 'split': split}, 72 | pp_fn=pp_eval + pp_common_i21k, 73 | loss_name=config.loss, 74 | log_steps=1000, # Very fast O(seconds) so it's fine to run it often. 75 | ) 76 | config.evals = {} 77 | config.evals.test = eval_i21k('full[:25_600]') 78 | config.evals.val = eval_i21k('full[25_600:51_200]') 79 | config.evals.train = eval_i21k('full[51_200:76_800]') 80 | 81 | # Few-shot evaluators 82 | config.evals.fewshot = get_fewshot_lsr() 83 | config.evals.fewshot.log_steps = 25_000 84 | 85 | return config -------------------------------------------------------------------------------- /big_vision/configs/common_fewshot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Most common few-shot eval configuration.""" 16 | 17 | import ml_collections as mlc 18 | 19 | 20 | def get_fewshot_lsr(target_resolution=224, resize_resolution=256, 21 | runlocal=False, pp=None, **kw): 22 | """Returns a standard-ish fewshot eval configuration.""" 23 | kw.setdefault('representation_layer', 'pre_logits') 24 | kw.setdefault('shots', (1, 5, 10, 25)) 25 | kw.setdefault('l2_reg', 2.0 ** 10) 26 | kw.setdefault('num_seeds', 3) 27 | kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/ 28 | 29 | # Backward-compatible default: 30 | if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long 31 | kw['log_steps'] = 25_000 32 | 33 | config = mlc.ConfigDict(kw) 34 | config.type = 'fewshot_lsr' 35 | config.datasets = { 36 | 'caltech': ('caltech101', 'train', 'test'), # copybara:srtip 37 | 'cars': ('cars196:2.1.0', 'train', 'test'), 38 | 'cifar100': ('cifar100', 'train', 'test'), 39 | 'dtd': ('dtd', 'train', 'test'), 40 | # The first 65000 ImageNet samples have at least 30 shots per any class. 41 | # Commented out by default because needs manual download. 42 | # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'), 43 | 'pets': ('oxford_iiit_pet', 'train', 'test'), 44 | 'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'), 45 | } if not runlocal else { 46 | 'pets': ('oxford_iiit_pet', 'train', 'test'), 47 | } 48 | 49 | pp = pp or '|'.join([ 50 | 'decode', 51 | f'resize({resize_resolution})', 52 | f'central_crop({target_resolution})', 53 | 'value_range(-1,1)' 54 | ]) 55 | pp += '|keep("image", "label")' 56 | config.pp_train = pp 57 | config.pp_eval = pp 58 | config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)] 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /big_vision/configs/proj/cappa/README.md: -------------------------------------------------------------------------------- 1 | # Image Captioners Are Scalable Vision Learners Too 2 | 3 | *by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915) 4 | 5 | ![CapPa Architecture](./cappa_architecture.png) 6 | 7 | This directory contains a config for training a CapPa model from scratch. 8 | Note that most models in the paper were trained on a proprietary dataset 9 | (WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/). 10 | 11 | By default, this config trains on COCO captions as this data set is readily 12 | available in [TFDS](https://www.tensorflow.org/datasets) without manual steps. 13 | This is not meant to produce a meaningful model, but 14 | provides a way for the user to run the config out of the box. Please update the 15 | config with with a TFDS-wrapped variant of your favorite image/text data set to 16 | train capable models. 
17 | 18 | After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows 19 | 20 | ``` 21 | python -m big_vision.trainers.proj.cappa.generative \ 22 | --config big_vision/configs/proj/cappa/pretrain.py \ 23 | --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` 24 | ``` 25 | 26 | To run the Cap baseline (autoregressive captioning without parallel prediction), 27 | set `config.model.masked_pred_prob = 0.0`. 28 | 29 | ### Citation 30 | ``` 31 | @inproceedings{tschannen2023image, 32 | title={Image Captioners Are Scalable Vision Learners Too}, 33 | author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas}, 34 | booktitle={Neural Information Processing Systems (NeurIPS)}, 35 | year={2023} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /big_vision/configs/proj/cappa/cappa_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/proj/cappa/cappa_architecture.png -------------------------------------------------------------------------------- /big_vision/configs/proj/distill/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge distillation: A good teacher is patient and consistent 2 | *by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov* 3 | 4 | ## Introduction 5 | We publish all teacher models, and configurations for the main experiments of 6 | the paper, as well as training logs and student models. 7 | 8 | Please read the main [big_vision README](/README.md) to learn how to run 9 | configs, and remember that each config file contains an example invocation in 10 | the top-level comment. 11 | 12 | ## Results 13 | 14 | We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing) 15 | of a few runs that we reproduced on Cloud. 16 | 17 | ### ImageNet-1k 18 | 19 | The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our 20 | distillation runs on ImageNet-1k reported in Figures 1 and 5(left) and the first 21 | row of Table1. 22 | 23 | We release both student and teacher models: 24 | 25 | | Model | Download link | Resolution | ImageNet top-1 acc. (paper) | 26 | | :--- | :---: | :---: | :---: | 27 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 | 28 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 | 29 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 | 30 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 | 31 | 32 | ### Flowers/Pet/Food/Sun 33 | 34 | The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and 35 | [bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the 36 | distillation runs on these datasets and shown in Figures 3,4,9-12, and Table4. 37 | 38 | While our open-source release does not currently support doing hyper-parameter 39 | sweeps, we still provide an example of the sweeps at the end of the configs 40 | for reference. 
41 | 42 | ### Teacher models 43 | Links to all teacher models we used can be found in [common.py](common.py). 44 | -------------------------------------------------------------------------------- /big_vision/configs/proj/distill/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Most common teachers for distillation.""" 16 | 17 | # pylint: disable=line-too-long 18 | inits = { # pylint: disable=duplicate-key Internally, we override some paths for convenience. 19 | 'BiT-M R152x2 imagenet2012 ic224': 'gs://bit_models/distill/R152x2_T_224.npz', 20 | 'BiT-M R152x2 imagenet2012 rc384': 'gs://bit_models/distill/R152x2_T_384.npz', 21 | 'BiT-M R152x2 flowers rc128': 'gs://bit_models/distill/R152x2_T_flowers128.npz', 22 | 'BiT-M R152x2 pet rc128': 'gs://bit_models/distill/R152x2_T_pet128.npz', 23 | 'BiT-M R152x2 food rc128': 'gs://bit_models/distill/R152x2_T_food128.npz', 24 | 'BiT-M R152x2 sun rc128': 'gs://bit_models/distill/R152x2_T_sun128.npz', 25 | 26 | } 27 | # pylint: enable=line-too-long 28 | -------------------------------------------------------------------------------- /big_vision/configs/proj/flexivit/timing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long,missing-function-docstring 16 | r"""A config to run timing for FlexiViT (only inference, no I/O etc.). 17 | 18 | big_vision.tools.eval_only \ 19 | --config big_vision/configs/proj/flexivit/timing.py \ 20 | --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ 21 | --config.total_epochs 90 22 | """ 23 | 24 | from ml_collections import ConfigDict 25 | 26 | 27 | def get_config(): 28 | c = ConfigDict() 29 | 30 | shape = (240, 240, 3) 31 | c.batch_size = 8 # swept 32 | c.init_shapes = [(1, *shape)] 33 | c.representation_layer = 'pre_logits' 34 | 35 | # Creating complete model using all params, the sweep will go over variants. 
36 | c.model_name = 'xp.flexivit.vit' 37 | c.model = dict( 38 | variant='B', 39 | pool_type='tok', 40 | patch_size=(10, 10), # Like deit@384 41 | seqhw=(24, 24), 42 | ) 43 | c.num_classes = 0 44 | 45 | c.evals = {} 46 | c.evals.timing = dict( 47 | type='timing', 48 | input_shapes=[shape], 49 | timing=True, 50 | pred_kw=dict(outputs=('pre_logits',)), 51 | ) 52 | 53 | return c -------------------------------------------------------------------------------- /big_vision/configs/proj/givt/givt_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/proj/givt/givt_overview.png -------------------------------------------------------------------------------- /big_vision/configs/proj/image_text/README.md: -------------------------------------------------------------------------------- 1 | # Image/text models 2 | 3 | This directory provides configs and Colabs for different projects on image/text multimodal learning. Please refer to the separate readmes for information on specific projects. 4 | 5 | **LiT: Zero-Shot Transfer with Locked-image text Tuning: [README_lit.md](README_lit.md)** 6 | 7 | **SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features: [README_siglip2.md](README_siglip2.md)** -------------------------------------------------------------------------------- /big_vision/configs/proj/image_text/README_lit.md: -------------------------------------------------------------------------------- 1 | # LiT: Zero-Shot Transfer with Locked-image text Tuning 2 | 3 | *by Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, Lucas Beyer* 4 | 5 | https://arxiv.org/abs/2111.07991 6 | 7 | ``` 8 | @article{zhai2022lit, 9 | title={LiT: Zero-Shot Transfer with Locked-image Text Tuning}, 10 | author={Zhai, Xiaohua and Wang, Xiao and Mustafa, Basil and Steiner, Andreas and Keysers, Daniel and Kolesnikov, Alexander and Beyer, Lucas}, 11 | journal={CVPR}, 12 | year={2022} 13 | } 14 | ``` 15 | 16 | Model card: 17 | https://github.com/google-research/vision_transformer/blob/main/model_cards/lit.md 18 | 19 | Colabs: 20 | 21 | - https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb 22 | - https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb 23 | 24 | ### Results 25 | 26 | | Model | Download link | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` | 27 | | :--- | :---: | :---: | :---: | :---: | :--- | 28 | | mixed_L16L | [link](https://storage.googleapis.com/vit_models/lit/LiT-L16L.npz) | 75.7 | 48.5 | 31.2 | `txt=bert_large,img=L/16` | 29 | | mixed_B16B | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 72.1 | 49.4 | 31.1 | `txt=bert_base,img=B/16,img_head` | 30 | | mixed_B16B_2 | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 73.9 | 51.5 | 31.8 | `txt=bert_base,img=B/16` | 31 | | coco_B16B | [link](https://storage.googleapis.com/vit_models/lit/big_vision/coco_B16B/checkpoint.npz) | 20.7 | 47.2 | 32.1 | `txt=bert_base,img=B/16` | 32 | 33 | The first three rows are the best available models trained on open source data, 34 | originally published in the [`google-research/vision_transformer`] repository. 
35 | These models were re-evaluated with this codebase using the following commands: 36 | 37 | ```bash 38 | big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16,img_head,init=gs://vit_models/lit/LiT-B16B.npz 39 | 40 | big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16_2,init=gs://vit_models/lit/LiT-B16B_2.npz 41 | 42 | big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_large,img=L/16,init=gs://vit_models/lit/LiT-L16L.npz 43 | ``` 44 | 45 | Unfortunately, the public multi-modal datasets [`CC12M`] and [`YFCC100M`] are 46 | not yet available in [`tfds`], so these models cannot be reproduced with the 47 | codebase. For this reason we provide the much weaker model `coco_B16B` in the 48 | third row, which was trained on the small `tfds` dataset [`coco_captions`], and 49 | can be used to verify correctness of the codebase 50 | ([workdir](https://console.cloud.google.com/storage/browser/vit_models/lit/big_vision/coco_B16B/)). 51 | 52 | [`google-research/vision_transformer`]: https://github.com/google-research/vision_transformer 53 | [`CC12M`]: https://arxiv.org/abs/2102.08981 54 | [`YFCC100M`]: https://arxiv.org/abs/1503.01817 55 | [`tfds`]: https://www.tensorflow.org/datasets/api_docs/python/tfds 56 | [`coco_captions`]: https://www.tensorflow.org/datasets/catalog/coco_captions 57 | 58 | 59 | ### Changelog 60 | 61 | - 2022-08-18: Added LiT-B16B_2 model that was trained for 60k steps 62 | (LiT_B16B: 30k) without linear head on the image side (LiT_B16B: 768) and has 63 | better performance. 64 | -------------------------------------------------------------------------------- /big_vision/configs/proj/jet/imagenet64.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pytype: disable=attribute-error,line-too-long 16 | r"""Jet config for imagenet64. 
17 | 18 | Expected values in imagenet64 (200 epochs): 19 | - 32 couplings and block depth 2: 3.72 bpd 20 | - 64 couplings and block depth 5: 3.66 bpd 21 | """ 22 | 23 | import big_vision.configs.common as bvcc 24 | 25 | 26 | def get_config(arg=None): 27 | """Config for training a Flow model.""" 28 | config = bvcc.parse_arg(arg, mode='') 29 | 30 | config.seed = 0 31 | config.total_epochs = 200 32 | 33 | config.input = dict() 34 | config.input.data = dict( 35 | name='downsampled_imagenet/64x64', 36 | split='train', 37 | ) 38 | config.input.batch_size = 1024 39 | config.input.shuffle_buffer_size = 250_000 40 | 41 | config.input.pp = 'decode|resize(64)|value_range(-1, 1)|keep("image")' 42 | pp_eval = 'decode|resize(64)|value_range(-1, 1)|keep("image")' 43 | 44 | config.log_training_steps = 50 45 | config.ckpt_steps = 5000 46 | 47 | # Model section 48 | config.model_name = 'proj.jet.jet' 49 | config.model = dict( 50 | depth=32, block_depth=2, emb_dim=512, num_heads=8, 51 | kinds=('channels', 'channels', 'channels', 'channels', 'spatial'), 52 | channels_coupling_projs=('random',), 53 | spatial_coupling_projs=('checkerboard', 'checkerboard-inv', 54 | 'vstripes', 'vstripes-inv', 55 | 'hstripes', 'hstripes-inv')) 56 | 57 | # Optimizer section 58 | config.optax_name = 'scale_by_adam' 59 | config.optax = dict(mu_dtype='bfloat16', b2=0.95) 60 | config.grad_clip_norm = 1.0 61 | 62 | # rsqrt schedule. 63 | config.lr = 3e-4 64 | config.wd = 1e-5 65 | config.wd_mults = ( 66 | ('.*', 1.0), 67 | ) 68 | config.schedule = [ 69 | ('.*FREEZE_ME.*', None), # Permutation matrices should be always frozen. 70 | ('.*', dict(decay_type='cosine', warmup_percent=0.1)), 71 | ] 72 | 73 | # config.mesh = [('replica', 16), ('fsdp', -1)] 74 | # config.sharding_strategy = [('.*', 'fsdp(axis="fsdp")')] 75 | # config.sharding_rules = [('act_batch', ('replica', 'fsdp'))] 76 | 77 | # Eval section 78 | config.evals = {} 79 | 80 | config.evals.minitrain_bits = dict( 81 | type='mean', 82 | pred='loss', 83 | data=dict(name=config.input.data.name, split='train[:4096]'), 84 | pp_fn=pp_eval, 85 | log_percent=0.05, 86 | ) 87 | 88 | config.evals.val_bits = dict( 89 | type='mean', 90 | pred='loss', 91 | data=dict(name=config.input.data.name, split='validation'), 92 | pp_fn=pp_eval, 93 | log_percent=0.05, 94 | ) 95 | 96 | if config.mode == 'runlocal': 97 | del config.total_epochs 98 | config.total_steps = 200 99 | config.input.shuffle_buffer_size = 10 100 | config.input.batch_size = 32 101 | config.model.depth = 1 102 | config.model.block_depth = 1 103 | 104 | config.evals.val_bits.data.split = 'validation[:16]' 105 | config.evals.minitrain_bits.data.split = 'train[:16]' 106 | 107 | return config 108 | -------------------------------------------------------------------------------- /big_vision/configs/proj/jetformer/README.md: -------------------------------------------------------------------------------- 1 | # JetFormer: An Autoregressive Generative Model of Raw Images and Text 2 | 3 | *by Michael Tschannen\*, André Susano Pinto\*, Alexander Kolesnikov\** [[arxiv]](https://arxiv.org/abs//2411.19722) 4 | 5 | ![JetFormer overview](jetformer_overview.png) 6 | 7 | ### Summary 8 | 9 | Removing modeling constraints and unifying architectures across domains has 10 | been a key driver of the recent progress in training large multimodal models. 11 | However, most of these models still rely on many separately trained components 12 | such as modality-specific encoders and decoders. 
In this work, we further 13 | streamline joint generative modeling of images and text. We propose an 14 | autoregressive decoder-only transformer - JetFormer - which is trained to 15 | directly maximize the likelihood of raw data, without relying on any separately 16 | pretrained components, and can understand and generate both text and images. 17 | Specifically, we leverage a normalizing flow model to obtain a soft-token image 18 | representation that is jointly trained with an autoregressive multimodal 19 | transformer. The normalizing flow model serves as both an image encoder for 20 | perception tasks and an image decoder for image generation tasks during 21 | inference. JetFormer achieves text-to-image generation quality competitive with 22 | recent VQ-VAE- and VAE-based baselines. These baselines rely on pretrained 23 | image autoencoders, which are trained with a complex mixture of losses, 24 | including perceptual ones. At the same time, JetFormer demonstrates robust image 25 | understanding capabilities. To the best of our knowledge, JetFormer is the 26 | first model that is capable of generating high-fidelity images and producing 27 | strong log-likelihood bounds. 28 | 29 | ### Training models 30 | 31 | Please see the [main README](https://github.com/google-research/big_vision) for 32 | how to set up the codebase and training data sets in your preferred environment. 33 | Use the commands in the config headers to train models. -------------------------------------------------------------------------------- /big_vision/configs/proj/jetformer/jetformer_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/proj/jetformer/jetformer_overview.png -------------------------------------------------------------------------------- /big_vision/configs/proj/paligemma/paligemma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/proj/paligemma/paligemma.png -------------------------------------------------------------------------------- /big_vision/configs/proj/paligemma/paligemma2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/configs/proj/paligemma/paligemma2.png -------------------------------------------------------------------------------- /big_vision/configs/proj/paligemma/transfers/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Common things across all transfer configs.""" 16 | 17 | 18 | TOKENIZER = 'gemma(tokensets=("loc", "seg"))' 19 | 20 | 21 | def tok(**kw): 22 | """Creates the tokenization preprocessing string.""" 23 | # Single entry point so that it's consistent everywhere and easier to switch. 24 | kw.setdefault('model', TOKENIZER) 25 | kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items()) 26 | return f'tok({kw})' 27 | 28 | 29 | def combine_and_keep_train(text_len, before=(), sep='\n'): 30 | return '|'.join([ 31 | *before, 32 | tok(key='prefix', bos='yes'), 33 | tok(key='suffix', eos='yes'), 34 | tok(key='septok', text=sep), 35 | # If masks confuse you, see (internal link) 36 | 'masked_concat(["prefix", "septok", "suffix"], outkey="text", mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])', # pylint: disable=line-too-long 37 | # For training, we +1 since the trainer removes EOS. 38 | f'tolen({text_len+1}, pad_value=0, key="text")', # Value doesn't matter. 39 | f'tolen({text_len+1}, pad_value=1, key="mask_ar")', 40 | f'tolen({text_len+1}, pad_value=0, key="mask_loss")', 41 | 'keep("image", "text", "mask_ar", "mask_loss")', 42 | ]) 43 | 44 | 45 | def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'): 46 | return '|'.join([ 47 | *before, 48 | # Same as training, except that suffix is now the empty string. 49 | # Meaning, we create text as [prefix separator pad], 50 | # and the mask accordingly as [0 0 1] (with repeats of respective lengths) 51 | tok(key='prefix', bos='yes'), 52 | tok(key='septok', text=sep), 53 | # At eval time, there can be also a suffix key in the data. If so it is 54 | # tokenized without EOS and decoding will continue from it. 55 | 'setdefault("suffix", "")', 56 | tok(key='suffix', eos='no'), 57 | # If masks confuse you, see (internal link) 58 | 'masked_concat(["prefix", "septok", "suffix"], outkey="text", mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long 59 | f'tolen({text_len}, pad_value=0, key="text")', # value doesn't matter. 60 | f'tolen({text_len}, pad_value=1, key="mask_ar")', 61 | f'tolen({text_len}, pad_value=0, key="mask_input")', 62 | # And we need to keep everything that makes our evaluator happy. 63 | 'keep(' + ', '.join(f'"{x}"' for x in ( 64 | 'image', 'text', 'mask_ar', 'mask_input') + tuple(keep)) + ')', 65 | ]) 66 | -------------------------------------------------------------------------------- /big_vision/configs/proj/scaling_laws/train_vit_g.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long 16 | r"""Pre-train ViT-g (1B params) on JFT-3B as in https://arxiv.org/abs/2106.04560 17 | 18 | To train ViT-G (2B params), simply update the following single line: 19 | `config.model.variant = 'G/14'` 20 | 21 | The code is released for reference purposes. 22 | One can test the code using public ImageNet-1k or ImageNet-21k dataset. 
23 | 24 | big_vision.train \ 25 | --config big_vision/configs/proj/scaling_laws/train_vit_g.py \ 26 | --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` 27 | 28 | """ 29 | from big_vision.configs.common_fewshot import get_fewshot_lsr 30 | import ml_collections as mlc 31 | 32 | 33 | def get_config(): 34 | """Rocket config.""" 35 | config = mlc.ConfigDict() 36 | 37 | config.dataset = 'jft_3b' 38 | config.val_split = 'val' 39 | config.train_split = 'train' 40 | config.num_classes = 29_593 41 | config.init_head_bias = -10.0 42 | 43 | # Fits 32 images per TPUv3 core with ViT-g/14. 44 | config.batch_size = 4096*4 45 | 46 | pp_common = '|value_range(-1, 1)' 47 | pp_common += f'|onehot({config.num_classes})' 48 | pp_common += '|keep("image", "labels")' 49 | config.pp_train = 'inception_crop(224)|flip_lr' + pp_common 50 | config.pp_eval = 'resize_small(256)|central_crop(224)' + pp_common 51 | config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. 52 | 53 | config.log_training_steps = 50 54 | config.log_eval_steps = 1000 55 | # NOTE: eval is very fast O(seconds) so it's fine to run it often. 56 | 57 | config.ckpt_steps = 1000 58 | config.keep_ckpt_steps = 10_000 59 | 60 | config.prefetch_to_device = 1 61 | config.trial = 0 62 | 63 | # Model section 64 | config.model_name = 'vit' 65 | config.model = mlc.ConfigDict() 66 | config.model.variant = 'g/14' 67 | config.model.pool_type = 'map' 68 | 69 | # Optimizer section 70 | config.optax_name = 'big_vision.scale_by_adafactor' 71 | config.grad_clip_norm = 1.0 72 | config.lr = 8e-4 73 | config.wd = 0.03 * 8e-4 74 | config.wd_mults = [ 75 | ('.*head/kernel', 100.0), 76 | ('.*/kernel', 1.0), 77 | ] 78 | config.schedule = dict( 79 | decay_type='rsqrt', timescale=10_000, warmup_steps=10_000, 80 | cooldown_steps=50_000) 81 | config.total_steps = 1_000_000 82 | 83 | # Few-shot eval section 84 | config.evals = {} 85 | config.evals.fewshot = dict(log_steps=10_000, **get_fewshot_lsr()) 86 | 87 | return config 88 | -------------------------------------------------------------------------------- /big_vision/configs/vit_s16_i1k.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long 16 | r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580. 17 | 18 | This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%, 19 | see the tech report for more details. 20 | 21 | Command to run: 22 | 23 | big_vision.train \ 24 | --config big_vision/configs/vit_s16_i1k.py \ 25 | --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` 26 | 27 | To run for 300ep, add `--config.total_epochs 300` to the command. 
28 | """ 29 | 30 | import ml_collections as mlc 31 | 32 | 33 | def get_config(): 34 | """Config for training.""" 35 | config = mlc.ConfigDict() 36 | 37 | config.seed = 0 38 | config.total_epochs = 90 39 | config.num_classes = 1000 40 | config.loss = 'softmax_xent' 41 | 42 | config.input = {} 43 | config.input.data = dict( 44 | name='imagenet2012', 45 | split='train[:99%]', 46 | ) 47 | config.input.batch_size = 1024 48 | config.input.cache_raw = True # Needs up to 120GB of RAM! 49 | config.input.shuffle_buffer_size = 250_000 50 | 51 | pp_common = ( 52 | '|value_range(-1, 1)' 53 | '|onehot(1000, key="{lbl}", key_result="labels")' 54 | '|keep("image", "labels")' 55 | ) 56 | config.input.pp = ( 57 | 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' + 58 | pp_common.format(lbl='label') 59 | ) 60 | pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common 61 | 62 | # To continue using the near-defunct randaug op. 63 | config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] 64 | 65 | config.log_training_steps = 50 66 | config.ckpt_steps = 1000 67 | 68 | # Model section 69 | config.model_name = 'vit' 70 | config.model = dict( 71 | variant='S/16', 72 | rep_size=True, 73 | pool_type='gap', 74 | posemb='sincos2d', 75 | ) 76 | 77 | # Optimizer section 78 | config.grad_clip_norm = 1.0 79 | config.optax_name = 'scale_by_adam' 80 | config.optax = dict(mu_dtype='bfloat16') 81 | 82 | config.lr = 0.001 83 | config.wd = 0.0001 84 | config.schedule = dict(warmup_steps=10_000, decay_type='cosine') 85 | 86 | config.mixup = dict(p=0.2, fold_in=None) 87 | 88 | # Eval section 89 | def get_eval(split, dataset='imagenet2012'): 90 | return dict( 91 | type='classification', 92 | data=dict(name=dataset, split=split), 93 | pp_fn=pp_eval.format(lbl='label'), 94 | loss_name=config.loss, 95 | log_steps=2500, # Very fast O(seconds) so it's fine to run it often. 96 | ) 97 | config.evals = {} 98 | config.evals.train = get_eval('train[:2%]') 99 | config.evals.minival = get_eval('train[99%:]') 100 | config.evals.val = get_eval('validation') 101 | config.evals.v2 = get_eval('test', dataset='imagenet_v2') 102 | config.evals.real = get_eval('validation', dataset='imagenet2012_real') 103 | config.evals.real.pp_fn = pp_eval.format(lbl='real_label') 104 | 105 | return config 106 | -------------------------------------------------------------------------------- /big_vision/datasets/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Core data functions, dispatch calls to the requested dataset.""" 16 | import importlib 17 | 18 | 19 | # Note: intentionally not using ABC to avoid forcing implementation of every 20 | # method, since one can imagine train-only datasets for example. 
21 | class DataSource: 22 | """The API that any data source should implement.""" 23 | 24 | def get_tfdata(self, ordered, *, process_split=True, allow_cache=True): 25 | """Creates this data object as a tf.data.Dataset. 26 | 27 | This will be called separately in each process, and it is up to the dataset 28 | implementation to shard it accordingly if desired! 29 | 30 | Args: 31 | ordered: if True, the dataset should use deterministic ordering, if False 32 | it may have undefined ordering. Think of True == val, False == train. 33 | process_split: if False then every process receives the entire dataset 34 | (e.g. for evaluators running in a single process). 35 | allow_cache: whether to allow caching the opened data or not. 36 | 37 | Returns: 38 | A tf.data.Dataset object. 39 | 40 | Raises: 41 | RuntimeError: if not implemented by the dataset, but called. 42 | """ 43 | raise RuntimeError("not implemented for {self.__class__.__name__}") 44 | 45 | @property 46 | def total_examples(self): 47 | """Returns number of examples in the dataset, regardless of sharding.""" 48 | raise RuntimeError("not implemented for {self.__class__.__name__}") 49 | 50 | def num_examples_per_process(self): 51 | """Returns a list of the numer of examples for each process. 52 | 53 | This is only needed for datasets that should go through make_for_inference. 54 | 55 | Returns: 56 | Returns a list of the numer of examples for each process. 57 | 58 | Ideally, this would always be `[total() / nprocess] * nprocess`, but in 59 | reality we can almost never perfectly shard a dataset across arbitrary 60 | number of processes. 61 | 62 | One alternative option that can work in some cases is to not even shard 63 | the dataset and thus return `[num_examples()] * nprocess. 64 | 65 | Raises: 66 | RuntimeError: if not implemented by the dataset, but called. 67 | """ 68 | raise RuntimeError("not implemented for {self.__class__.__name__}") 69 | 70 | 71 | def get(name, **kw): 72 | if name.startswith("bv:"): 73 | mod = importlib.import_module(f"big_vision.datasets.{name[3:]}") 74 | return mod.DataSource(**kw) 75 | else: 76 | mod = importlib.import_module("big_vision.datasets.tfds") 77 | return mod.DataSource(name, **kw) 78 | -------------------------------------------------------------------------------- /big_vision/datasets/sequence_packing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Packed Sequence Op.""" 16 | 17 | # Forked from 18 | # https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py. 
19 | 20 | 21 | from typing import Dict, Optional, List, Union 22 | 23 | from flax import traverse_util 24 | import tensorflow as tf 25 | 26 | AUTOTUNE = tf.data.experimental.AUTOTUNE 27 | FLATTEN_SEPARATOR = "<|sep|>" 28 | 29 | 30 | def pack_dataset( 31 | dataset: tf.data.Dataset, 32 | batch_size: int | None, 33 | key2length: Union[int, Dict[str, int]], 34 | keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset: 35 | """Creates a 'packed' version of a dataset on-the-fly. 36 | 37 | Wrap `tensorflow.grain` ops. 38 | 39 | This is meant to replace the irritation of having to create a separate 40 | "packed" version of a dataset to train efficiently on TPU. 41 | Each example in the output dataset represents several examples in the 42 | input dataset. 43 | 44 | For each key in the input dataset, two additional keys are created: 45 | _segment_ids: an int32 tensor identifying the parts 46 | representing the original example. 47 | _positions: an int32 tensor identifying the position within the original 48 | example. 49 | 50 | Example: 51 | Two input examples get combined to form an output example. 52 | The input examples are: 53 | {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} 54 | {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} 55 | The output example is: 56 | { 57 | "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] 58 | "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] 59 | "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] 60 | "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] 61 | "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] 62 | "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] 63 | } 64 | 0 represents padding in both the inputs and the outputs. 65 | Sequences in the incoming examples are truncated to length "length", and the 66 | sequences in the output examples all have fixed (padded) length "length". 67 | 68 | Args: 69 | dataset: A `tf.data.Dataset`. 70 | batch_size: Batch size of the packed dataset. 71 | key2length: An integer, or a dict from feature-key to integer. 72 | keys: A list of strings (e.g. ["inputs", "targets"]). 73 | 74 | Returns: 75 | A `tf.data.Dataset`. 76 | """ 77 | raise ValueError("Not implemented in OSS yet.") 78 | -------------------------------------------------------------------------------- /big_vision/datasets/tfds.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """TensorFlow Datasets as data source for big_vision.""" 16 | import functools 17 | 18 | import big_vision.datasets.core as ds_core 19 | import jax 20 | import numpy as np 21 | import overrides 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | 26 | class DataSource(ds_core.DataSource): 27 | """Use TFDS as a data source.""" 28 | 29 | def __init__(self, name, split, data_dir=None, skip_decode=("image",)): 30 | self.builder = _get_builder(name, data_dir) 31 | self.split = split 32 | # Each host is responsible for a fixed subset of data 33 | process_splits = tfds.even_splits(split, jax.process_count()) 34 | self.process_split = process_splits[jax.process_index()] 35 | self.skip_decode = skip_decode 36 | 37 | @overrides.overrides 38 | def get_tfdata( 39 | self, ordered=False, *, process_split=True, allow_cache=True, **kw): 40 | # The tf.data may use a lot of RAM, so we need to expose the option of not 41 | # keeping this in memory when we use lots of input pipelines, such as when 42 | # having many ephemeral evaluators. 43 | return (_cached_get_dataset if allow_cache else _get_dataset)( 44 | self.builder, self.skip_decode, 45 | split=self.process_split if process_split else self.split, 46 | shuffle_files=not ordered, 47 | **kw) 48 | 49 | @property 50 | @overrides.overrides 51 | def total_examples(self): 52 | return self.builder.info.splits[self.split].num_examples 53 | 54 | @overrides.overrides 55 | def num_examples_per_process(self): 56 | splits = tfds.even_splits(self.split, jax.process_count()) 57 | return [self.builder.info.splits[s].num_examples for s in splits] 58 | 59 | 60 | @functools.cache 61 | def _get_builder(dataset, data_dir): 62 | if dataset == "from_data_dir": 63 | return tfds.builder_from_directory(data_dir) 64 | else: 65 | return tfds.builder(dataset, data_dir=data_dir, try_gcs=True) 66 | 67 | 68 | # Cache as it may well take 1-2min on large datasets, and we may use the same 69 | # multiple times (eg various evaluators). 70 | def _get_dataset(builder, skip_decode, shuffle_files, split=None, **rckw): 71 | """Returns a tf.data to be used.""" 72 | ds = builder.as_dataset( 73 | split=split, shuffle_files=shuffle_files, 74 | read_config=tfds.ReadConfig( 75 | skip_prefetch=True, # We prefetch after pipeline. 76 | try_autocache=False, # We control this, esp. for few-shot. 77 | add_tfds_id=True, 78 | **rckw, 79 | ), 80 | decoders={ 81 | f: tfds.decode.SkipDecoding() 82 | for f in skip_decode if f in builder.info.features 83 | }) 84 | 85 | def _hash_tfds_id(example): 86 | id_ = tf.strings.to_hash_bucket_strong( 87 | example["tfds_id"], 88 | np.iinfo(np.uint32).max, # Max value 89 | [3714561454027272724, 8800639020734831960]) # Magic. 90 | example["_id"] = tf.bitcast(id_, tf.int32)[0] # good device dtype. 91 | return example 92 | 93 | return ds.map(_hash_tfds_id) 94 | _cached_get_dataset = functools.cache(_get_dataset) 95 | -------------------------------------------------------------------------------- /big_vision/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/evaluators/__init__.py -------------------------------------------------------------------------------- /big_vision/evaluators/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluator for the classfication task.""" 16 | # pylint: disable=consider-using-from-import 17 | 18 | import functools 19 | 20 | from big_vision.evaluators import common 21 | import big_vision.utils as u 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | 26 | # Temporary global flag to facilitate backwards compatability. Will be removed 27 | # by the end of year 2023. 28 | API = 'jit' 29 | 30 | 31 | # To avoid re-compiling the function for every new instance of the same 32 | # evaluator on a different dataset! 33 | @functools.cache 34 | def get_eval_fn(predict_fn, loss_name): 35 | """Produces eval function, also applies pmap.""" 36 | @jax.jit 37 | def _eval_fn(train_state, batch, labels, mask): 38 | logits, *_ = predict_fn(train_state, batch) 39 | 40 | # Ignore the entries with all zero labels for evaluation. 41 | mask *= labels.max(axis=1) 42 | 43 | loss = getattr(u, loss_name)( 44 | logits=logits, labels=labels, reduction=False) 45 | loss = jnp.sum(loss * mask) 46 | 47 | top1_idx = jnp.argmax(logits, axis=1) 48 | # Extracts the label at the highest logit index for each image. 49 | top1_correct = jnp.take_along_axis( 50 | labels, top1_idx[:, None], axis=1)[:, 0] 51 | ncorrect = jnp.sum(top1_correct * mask) 52 | nseen = jnp.sum(mask) 53 | return ncorrect, loss, nseen 54 | return _eval_fn 55 | 56 | 57 | class Evaluator: 58 | """Classification evaluator.""" 59 | 60 | def __init__(self, predict_fn, loss_name, label_key='labels', **kw): 61 | self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) 62 | self.eval_fn = get_eval_fn(predict_fn, loss_name) 63 | self.label_key = label_key 64 | 65 | def run(self, train_state): 66 | """Computes all metrics.""" 67 | ncorrect, loss, nseen = 0, 0, 0 68 | for _, batch in zip(range(self.steps), self.get_data_iter()): 69 | labels, mask = batch.pop(self.label_key), batch.pop('_mask') 70 | batch_ncorrect, batch_losses, batch_nseen = jax.device_get( 71 | self.eval_fn(train_state, batch, labels, mask)) 72 | ncorrect += batch_ncorrect 73 | loss += batch_losses 74 | nseen += batch_nseen 75 | yield ('prec@1', ncorrect / nseen) 76 | yield ('loss', loss / nseen) 77 | -------------------------------------------------------------------------------- /big_vision/evaluators/mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
--------------------------------------------------------------------------------
/big_vision/evaluators/mean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator for computing mean of per-example metrics.
16 | 
17 | This evaluator can be used in two ways:
18 | 1. Create a new evaluator with reduced boilerplate by inheriting from it.
19 | 2. For quick prototyping, use this with predict_fns which return the metrics.
20 | """
21 | from functools import partial
22 | from typing import Mapping
23 | 
24 | from big_vision.evaluators import common
25 | 
26 | import jax
27 | import jax.numpy as jnp
28 | import numpy as np
29 | 
30 | 
31 | # Temporary global flag to facilitate backwards compatibility. Will be removed
32 | # by the end of year 2023.
33 | API = 'jit'
34 | 
35 | 
36 | # Note: global to avoid jax re-compiling across different evaluator instances.
37 | @partial(jax.jit, static_argnums=0)
38 | def _run_predict_fn(predict_fn, train_state, batch):
39 |   """Sum per-example metrics weighted by `_mask`."""
40 |   metrics = predict_fn(train_state, batch)
41 |   mask = batch['_mask']
42 |   # Sanity check output format of predict_fn.
43 |   assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
44 |   for y in jax.tree.leaves(metrics):
45 |     if y.shape != mask.shape:
46 |       raise ValueError(
47 |           f'Expected per-example metrics of shape {mask.shape} found '
48 |           f'{jax.tree.map(lambda x: x.shape, metrics)}.')
49 |   metrics = {**metrics, '_mask': mask}
50 |   return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics)
51 | 
52 | 
53 | class Evaluator:
54 |   """Report the mean of per-example metrics computed by predict_fn.
55 | 
56 |   `predict_fn(train_state, batch)` must return a dict from metric name to
57 |   per-example metrics of shape [batch_size].
58 |   """
59 | 
60 |   def __init__(self, predict_fn, **kw):
61 |     self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
62 |     self.predict_fn = partial(_run_predict_fn, predict_fn)
63 | 
64 |   def run(self, train_state):
65 |     """Computes all metrics."""
66 |     metrics = []
67 | 
68 |     # Compute batch metrics without blocking.
69 |     for _, batch in zip(range(self.steps), self.get_data_iter()):
70 |       batch_metrics = self.predict_fn(train_state, batch)
71 |       metrics.append(batch_metrics)
72 | 
73 |     # Transfer metrics (blocking).
74 |     metrics = jax.device_get(metrics)
75 | 
76 |     # Accumulate metrics across batches.
77 |     metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics)
78 |     mask_sum = metrics_sum.pop('_mask')
79 |     for key, value_sum in metrics_sum.items():
80 |       yield (key, value_sum / mask_sum)
81 | 
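# A minimal sketch of a compatible predict_fn (`apply_fn` and the batch keys
# are assumptions for illustration): it must return a dict of per-example
# values of shape [batch_size], which the Evaluator above then averages.
import jax.numpy as jnp

def mse_metric_fn(train_state, batch):
  preds = train_state.apply_fn(train_state.params, batch['image'])  # [B, D]
  return {'mse': jnp.mean((preds - batch['target']) ** 2, axis=-1)}  # [B]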
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/cappa/perplexity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator for perplexity of a model."""
16 | from big_vision.evaluators import mean
17 | import big_vision.utils as u
18 | import jax.numpy as jnp
19 | 
20 | 
21 | # Temporary global flag to facilitate backwards compatibility. Will be removed
22 | # by the end of year 2023.
23 | API = 'jit'
24 | 
25 | 
26 | def perplexity(predict_fn, normalize_by_seqlen):
27 |   """Returns a function that computes perplexity."""
28 | 
29 |   def _perplexity_fn(train_state, batch, pad_token=0, **kw):
30 |     logits, _ = predict_fn(train_state, batch, **kw)
31 | 
32 |     # Ignore perplexity on the padding label.
33 |     weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32)
34 |     if batch.get('label_masks') is not None:
35 |       weights = weights * batch['label_masks']
36 | 
37 |     losses = u.weighted_softmax_xent(
38 |         logits=logits, labels=batch['labels'],
39 |         weights=weights, label_smoothing=0.0,
40 |         reduction=False, normalize=normalize_by_seqlen)
41 | 
42 |     return {'perplexity': losses}
43 |   return _perplexity_fn
44 | 
45 | 
46 | class Evaluator(mean.Evaluator):
47 |   """Perplexity evaluator."""
48 | 
49 |   def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
50 |     super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw)
51 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/cappa/scoring_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Scoring classifier.
16 | 
17 | This evaluator takes a generative perspective on image classification: we
18 | feed the image together with each tokenized label, compute the per-label
19 | perplexity, and select the label with minimum loss as the prediction.
20 | """
21 | import functools
22 | from big_vision.datasets.imagenet import class_names as imagenet_class_names
23 | from big_vision.evaluators import mean
24 | from big_vision.pp import builder as pp_builder
25 | import jax.numpy as jnp
26 | import numpy as np
27 | 
28 | # Temporary global flag to facilitate backwards compatibility. Will be removed
29 | # by the end of year 2023.
30 | API = "jit"
31 | 
32 | 
33 | CLASS_NAMES = {
34 |     "imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
35 | }
36 | 
37 | 
38 | # As a separate function to cache result across instances.
39 | @functools.lru_cache(maxsize=None)
40 | def get_classes(dataset_name, pp_txt):
41 |   """Load the class label strings and tokenize them using pp_txt."""
42 |   pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False)
43 |   return np.array([pp_fn({"label": name})["labels"]
44 |                    for name in CLASS_NAMES[dataset_name]])
45 | 
46 | 
47 | def scoring(predict_fn, tokenized_labels):
48 |   """Returns a fn that scores each class's tokenized label per example."""
49 |   def _scoring_fn(train_state, batch, *a, **kw):
50 |     batch = {"_label_tokens": tokenized_labels, **batch}
51 |     scores = predict_fn(train_state, batch, *a, **kw)
52 |     predictions = jnp.argmax(scores, axis=-1)
53 |     return {"prec@1": predictions == batch["label"]}
54 | 
55 |   return _scoring_fn
56 | 
57 | 
58 | class Evaluator(mean.Evaluator):
59 |   """Evaluator for classification accuracy based on scoring all classes."""
60 | 
61 |   def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw):
62 |     cls_tokens = get_classes(data["name"], pp_txt)
63 |     super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw)
64 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/image_text/image_text_retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluates image-text retrieval results."""
16 | from typing import List, Mapping
17 | 
18 | import numpy as np
19 | 
20 | RECALL_THRESHOLDS = (1, 5, 10)
21 | 
22 | 
23 | def text_to_image_retrieval_eval(
24 |     dist_matrix: np.ndarray,
25 |     text_image_correspondence: List[int]) -> Mapping[str, float]:
26 |   """Runs the text-to-image retrieval eval from the distance matrix.
27 | 
28 |   Args:
29 |     dist_matrix: Distance matrix between text and image embeddings (shape
30 |       N_IMAGES x N_TEXTS).
31 |     text_image_correspondence: Mapping between rows and columns of
32 |       `dist_matrix`, i.e. a list of N_TEXTS integers n_i indicating that the
33 |       text embedding in column i corresponds to the image embedding in row
34 |       n_i. Note that many texts can be assigned to the same image. For
35 |       instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4),
36 |       then `text_image_correspondence = [0, 0, 1, 1]` means that the first
37 |       two texts correspond to the first image and the last two to the second.
38 | 
39 |   Returns:
40 |     A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
41 |   """
42 |   per_text_ranks = dist_matrix.argsort(axis=0)
43 |   text_image_correspondence = np.array(text_image_correspondence)
44 | 
45 |   def recall_at(k):
46 |     wins = per_text_ranks[:k, :] == text_image_correspondence[None]
47 |     return wins.any(axis=0).mean()
48 | 
49 |   return {
50 |       f'Recall@{k}': recall_at(k)
51 |       for k in RECALL_THRESHOLDS
52 |   }
53 | 
54 | 
55 | def image_to_text_retrieval_eval(
56 |     dist_matrix: np.ndarray,
57 |     text_image_correspondence: List[int]) -> Mapping[str, float]:
58 |   """Runs the image-to-text retrieval eval from the distance matrix.
59 | 
60 |   Args:
61 |     dist_matrix: Distance matrix between text and image embeddings (shape
62 |       N_IMAGES x N_TEXTS).
63 |     text_image_correspondence: Mapping between rows and columns of
64 |       `dist_matrix`, i.e. a list of N_TEXTS integers n_i indicating that the
65 |       text embedding in column i corresponds to the image embedding in row
66 |       n_i. Note that many texts can be assigned to the same image. For
67 |       instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4),
68 |       then `text_image_correspondence = [0, 0, 1, 1]` means that the first
69 |       two texts correspond to the first image and the last two to the second.
70 | 
71 |   Returns:
72 |     A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
73 |   """
74 |   per_image_ranks = dist_matrix.argsort(axis=1)
75 |   text_image_correspondence = np.array(text_image_correspondence)
76 | 
77 |   def recall_at(k):
78 |     top_k_images = text_image_correspondence[per_image_ranks[:, :k]]
79 |     wins = top_k_images == np.arange(len(per_image_ranks))[:, None]
80 |     return wins.any(axis=1).mean()
81 | 
82 |   return {
83 |       f'Recall@{k}': recall_at(k)
84 |       for k in RECALL_THRESHOLDS
85 |   }
86 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
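# A worked sketch (made-up numbers) for `text_to_image_retrieval_eval` above,
# with a 2-image x 4-text distance matrix (rows: images, columns: texts):
import numpy as np
from big_vision.evaluators.proj.image_text import image_text_retrieval

dist = np.array([[0.1, 0.2, 0.9, 0.6],
                 [0.9, 0.8, 0.1, 0.7]])
# Texts 0 and 1 belong to image 0; texts 2 and 3 to image 1.
res = image_text_retrieval.text_to_image_retrieval_eval(dist, [0, 0, 1, 1])
# The nearest image per text is [0, 0, 1, 0]: text 3 misses its image, so
# res == {'Recall@1': 0.75, 'Recall@5': 1.0, 'Recall@10': 1.0}.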
14 | 15 | """Unit tests for image_text_retrieval.""" 16 | from typing import Mapping 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from big_vision.evaluators.proj.image_text import image_text_retrieval 21 | import numpy as np 22 | 23 | 24 | class ImTextRetrievalTest(parameterized.TestCase): 25 | 26 | @parameterized.parameters( 27 | (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], 28 | [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], 29 | [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3], 30 | [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), { 31 | 'Recall@1': 1.0, 32 | 'Recall@5': 1.0, 33 | 'Recall@10': 1.0 34 | }), # 35 | (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], 36 | [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], 37 | [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3], 38 | [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), { 39 | 'Recall@1': 0.5, 40 | 'Recall@5': 0.75, 41 | 'Recall@10': 1.0 42 | })) 43 | def test_image_to_text_retrieval_eval(self, dist_matrix: np.ndarray, 44 | expected: Mapping[str, float]): 45 | """Checks `image_to_text_retrieval_eval`. 46 | 47 | Args: 48 | dist_matrix: Distance matrix between image (rows) and text (columns). 49 | expected: Expected eval results. 50 | """ 51 | self.assertEqual( 52 | image_text_retrieval.image_to_text_retrieval_eval( 53 | dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected) 54 | 55 | @parameterized.parameters( 56 | (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], 57 | [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], 58 | [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3], 59 | [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), { 60 | 'Recall@1': 1.0, 61 | 'Recall@5': 1.0, 62 | 'Recall@10': 1.0 63 | }), # 64 | (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.1, 0.1], 65 | [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], 66 | [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3], 67 | [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), { 68 | 'Recall@1': 0.375, 69 | 'Recall@5': 1.0, 70 | 'Recall@10': 1.0 71 | })) 72 | def test_image_text_retrieval(self, dist_matrix: np.ndarray, 73 | expected: Mapping[str, float]): 74 | """Checks `text_to_image_retrieval_eval`. 75 | 76 | Args: 77 | dist_matrix: Distance matrix between image (rows) and text (columns). 78 | expected: Expected eval results. 79 | """ 80 | self.assertEqual( 81 | image_text_retrieval.text_to_image_retrieval_eval( 82 | dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected) 83 | 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /big_vision/evaluators/proj/image_text/prompt_engineering_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Tests for prompt_engineering."""
16 | 
17 | from absl.testing import absltest
18 | from big_vision.evaluators.proj.image_text import prompt_engineering
19 | 
20 | 
21 | class PromptEngineeringTest(absltest.TestCase):
22 | 
23 |   def test_canonicalize_text(self):
24 |     self.assertEqual(prompt_engineering.canonicalize_text("test_test"), "test test")
25 |     self.assertEqual(
26 |         prompt_engineering.canonicalize_text("test___test"), "test test")
27 |     self.assertEqual(prompt_engineering.canonicalize_text("test"), "test")
28 |     self.assertEqual(prompt_engineering.canonicalize_text("test."), "test")
29 |     self.assertEqual(prompt_engineering.canonicalize_text(" test "), "test")
30 |     self.assertEqual(
31 |         prompt_engineering.canonicalize_text("test\ntest"), "test test")
32 |     self.assertEqual(
33 |         prompt_engineering.canonicalize_text("test   test"), "test test")
34 |     self.assertEqual(prompt_engineering.canonicalize_text("test {}"), "test")
35 |     self.assertEqual(
36 |         prompt_engineering.canonicalize_text(
37 |             "test {}", keep_punctuation_exact_string="{}"), "test {}")
38 |     self.assertEqual(
39 |         prompt_engineering.canonicalize_text(
40 |             " test {}...", keep_punctuation_exact_string="{}"), "test {}")
41 |     self.assertEqual(
42 |         prompt_engineering.canonicalize_text(
43 |             "test {} {} {}", keep_punctuation_exact_string="{}"),
44 |         "test {} {} {}")
45 | 
46 | 
47 | if __name__ == "__main__":
48 |   absltest.main()
49 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/paligemma/perplexity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator for perplexity of a model."""
16 | import functools
17 | 
18 | from big_vision.evaluators import mean
19 | import big_vision.utils as u
20 | import jax.numpy as jnp
21 | 
22 | 
23 | # Temporary global flag to facilitate backwards compatibility. Will be removed
24 | # by the end of year 2023.
25 | API = 'jit'
26 | 
27 | 
28 | # Cache the function such that it won't always recompile (in mean evaluator).
29 | @functools.cache
30 | def perplexity(
31 |     predict_fn, key='labels', shift_labels=True, pad_token=None):
32 |   """Returns a function that computes perplexity."""
33 | 
34 |   def _perplexity_fn(train_state, batch, **kw):
35 |     logits, _ = predict_fn(train_state, batch, **kw)
36 | 
37 |     labels = batch[key]
38 |     weights = batch.get('mask_loss', jnp.ones_like(labels))
39 | 
40 |     if pad_token is not None:
41 |       weights = weights * (labels != pad_token).astype(jnp.float32)
42 | 
43 |     if shift_labels:
44 |       labels = labels[:, 1:]
45 |       weights = weights[:, 1:]
46 | 
47 |     losses = u.weighted_softmax_xent(
48 |         logits=logits, labels=labels, weights=weights,
49 |         reduction=False, normalize=False)
50 |     normalizer = jnp.clip(weights.sum(axis=1), 2e-38)
51 | 
52 |     return {'sum': losses, 'avg': losses / normalizer}
53 |   return _perplexity_fn
54 | 
55 | 
56 | class Evaluator(mean.Evaluator):
57 |   """Perplexity evaluator."""
58 | 
59 |   def __init__(self, predict_fn, *a,
60 |                key='labels', shift_labels=False, pad_token=None, **kw):
61 |     kw.setdefault('prefetch', 0)  # More memory-saving default.
62 |     super().__init__(
63 |         perplexity(predict_fn, key, shift_labels, pad_token), *a, **kw)
64 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/paligemma/transfers/storepreds.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator to run inference and store results."""
16 | import functools
17 | 
18 | import big_vision.evaluators.common as c
19 | import big_vision.input_pipeline
20 | import big_vision.pp.builder
21 | import big_vision.pp.tokenizer
22 | import big_vision.utils as u
23 | 
24 | import jax
25 | 
26 | # Temporary global flag to facilitate backwards compatibility. Will be removed
27 | # by the end of year 2023.
28 | API = "jit"
29 | 
30 | 
31 | class Evaluator:
32 |   """Evaluator to run inference and store results."""
33 | 
34 |   def __init__(
35 |       self, predict_fn, tokenizer=None,
36 |       preds_outfile="{workdir}/{name}_{split}_preds.json",
37 |       annot_outfile="{workdir}/{name}_{split}_annotations.json",
38 |       id_key="id",
39 |       *, data, devices, **kw
40 |   ):
41 |     self.id_key = id_key
42 |     self.get_data_iter, self.steps = c.eval_input_pipeline(
43 |         keep_on_cpu={id_key}, data=data, devices=devices, **kw)
44 | 
45 |     self.preds_outfile = c.resolve_outfile(
46 |         preds_outfile, name=data.get("name"), split=data.get("split", ""))
47 |     self.annot_outfile = c.resolve_outfile(
48 |         annot_outfile, name=data.get("name"), split=data.get("split", ""))
49 | 
50 |     self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
51 |     self.decode = functools.partial(
52 |         predict_fn, devices=devices, eos_token=self.tok.eos_token)
53 | 
54 |   def run(self, train_state):
55 |     """Run eval."""
56 |     res = []
57 | 
58 |     for _, batch in zip(range(self.steps), self.get_data_iter()):
59 |       # (batch, seqlen) array of decoded generated tokens.
60 |       tokens = self.decode(train_state, batch)
61 | 
62 |       # (local_batch,)
63 |       tokens = u.get_local_slice_from_fsarray(tokens)
64 |       ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])
65 | 
66 |       image_ids = batch[self.id_key][ex_masks]
67 |       pred_captions = self.tok.to_str(tokens[ex_masks])
68 | 
69 |       for image_id, caption in zip(image_ids, pred_captions):
70 |         res.append({self.id_key: str(image_id), "caption": caption})
71 | 
72 |     res = c.multiprocess_write_json(self.preds_outfile, res)
73 | 
74 |     if jax.process_index():  # Host0 gets all preds and does eval.
75 |       return
76 | 
77 |     yield "num_examples", len(res)
78 | 
--------------------------------------------------------------------------------
/big_vision/evaluators/proj/uvim/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Common utilities used in evaluators."""
16 | import math
17 | import jax
18 | import tensorflow as tf
19 | import tensorflow_datasets as tfds
20 | 
21 | 
22 | def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
23 |                             dataset_dir=None, cache=True, add_tfds_id=False):
24 |   """Returns the dataset to be processed by the current jax host.
25 | 
26 |   The dataset is sharded and padded with zeros such that all processes
27 |   have an equal number of batches. The first 2 dimensions of the dataset
28 |   elements are: [local_device_count, device_batch_size].
29 | 
30 |   Args:
31 |     dataset: dataset name.
32 |     split: dataset split.
33 |     global_batch_size: batch size to be processed per iteration over the dataset.
34 |     pp_fn: preprocessing function to apply per example.
35 |     dataset_dir: path for tfds to find the prepared data.
36 |     cache: whether to cache the dataset after batching.
37 |     add_tfds_id: whether to add the unique `tfds_id` string to each example.
38 | """ 39 | assert global_batch_size % jax.device_count() == 0 40 | total_examples = tfds.load( 41 | dataset, split=split, data_dir=dataset_dir).cardinality() 42 | num_batches = math.ceil(total_examples / global_batch_size) 43 | 44 | process_split = tfds.even_splits( 45 | split, n=jax.process_count(), drop_remainder=False)[jax.process_index()] 46 | data = tfds.load( 47 | dataset, 48 | split=process_split, 49 | data_dir=dataset_dir, 50 | read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn) 51 | pad_data = tf.data.Dataset.from_tensors( 52 | jax.tree_map(lambda x: tf.zeros(x.shape, x.dtype), data.element_spec) 53 | ).repeat() 54 | 55 | data = data.concatenate(pad_data) 56 | data = data.batch(global_batch_size // jax.device_count()) 57 | data = data.batch(jax.local_device_count()) 58 | data = data.take(num_batches) 59 | if cache: 60 | # Eval datasets are often used many times and caching the dataset after 61 | # batching allows one to have the buffers ready to be used and not have 62 | # to wait for preprocessing to be done over and over. 63 | data = data.cache() 64 | return data 65 | -------------------------------------------------------------------------------- /big_vision/evaluators/proj/uvim/compute_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluator for computing mean of per-example metrics.""" 16 | import functools 17 | from typing import Mapping 18 | 19 | from big_vision import input_pipeline 20 | from big_vision.datasets import core as ds_core 21 | from big_vision.pp import builder as pp_builder 22 | 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | # Note: global to avoid jax re-compiling across different evaluator instances. 29 | @functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch') 30 | def _run_predict_fn(predict_fn, params, batch): 31 | """Sum per-example metrics weighted by `_mask`.""" 32 | mask = batch['_mask'] 33 | metrics = predict_fn(params, batch) 34 | # Sanity check output format of predict_fn. 35 | assert isinstance(metrics, Mapping), 'predict_fn must return a dict' 36 | for y in jax.tree_leaves(metrics): 37 | if y.shape != mask.shape: 38 | raise ValueError( 39 | f'Expected per-example metrics of shape {mask.shape} found ' 40 | f'{jax.tree_map(lambda x: x.shape, metrics)}.') 41 | metrics = {**metrics, '_mask': mask} 42 | metrics = jax.tree_map(lambda x: jnp.inner(x, mask), metrics) 43 | return jax.lax.psum(metrics, axis_name='batch') 44 | 45 | 46 | class Evaluator: 47 | """Report the mean of per-example metrics computed by predict_fn. 48 | 49 | `predict_fn(params, batch)` must return a dict from metric name to 50 | per-example metrics of shape [batch_size]. 
51 | """ 52 | 53 | def __init__(self, predict_fn, data, pp_fn, batch_size, 54 | cache_final=True, cache_raw=False, prefetch=1): 55 | data = ds_core.get(**data) 56 | self.dataset, self.steps = input_pipeline.make_for_inference( 57 | data.get_tfdata(ordered=True), batch_size=batch_size, 58 | num_ex_per_process=data.num_examples_per_process(), 59 | preprocess_fn=pp_builder.get_preprocess_fn(pp_fn), 60 | cache_final=cache_final, cache_raw=cache_raw) 61 | self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch) 62 | self.predict_fn = predict_fn 63 | 64 | def run(self, params): 65 | """Computes all metrics.""" 66 | metrics = [] 67 | 68 | # Compute batch metrics without blocking. 69 | for _, batch in zip(range(self.steps), self.data_iter): 70 | batch_metrics = _run_predict_fn(self.predict_fn, params, batch) 71 | metrics.append(batch_metrics) 72 | 73 | # Transfer metrics from device 0 to host (blocking). 74 | metrics = jax.device_get(jax.tree_map(lambda x: x[0], metrics)) 75 | 76 | metrics_sum = jax.tree_map(lambda *x: np.sum(x), *metrics) 77 | mask_sum = metrics_sum.pop('_mask') 78 | for key, value_sum in metrics_sum.items(): 79 | yield (key, value_sum / mask_sum) 80 | -------------------------------------------------------------------------------- /big_vision/evaluators/proj/uvim/psnr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Compute PSNR, currently used for colorization and superresolution.""" 16 | 17 | import functools 18 | 19 | import big_vision.evaluators.proj.uvim.common as common 20 | import big_vision.pp.builder as pp_builder 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | class Evaluator: 28 | """PSNR evaluator. 29 | 30 | `predict_fn` accepts arbitrary dictionaries of parameters and data, where 31 | the data dictionary is produced by the `pp_fn` op. It is expected to output a 32 | single-key dict containing an RGB image with intensities in [-1,1]. 33 | """ 34 | 35 | def __init__(self, 36 | predict_fn, 37 | pp_fn, 38 | batch_size, 39 | dataset="imagenet2012", 40 | split="validation", 41 | predict_kwargs=None): 42 | 43 | def predict(params, batch): 44 | 45 | def _f(x): 46 | y = predict_fn(params, x, **(predict_kwargs or {})) 47 | # Assume image intensities are in [-1,1]. 48 | # Evaluator expects a dict with a single item. 49 | pred, = y.values() 50 | return _psnr(pred, x["labels"], 2.) 51 | return jax.lax.all_gather({ 52 | "mask": batch["mask"], 53 | "psnr": _f(batch["input"]), 54 | }, axis_name="data", axis=0) 55 | 56 | self.predict_fn = jax.pmap(predict, axis_name="data") 57 | 58 | # Prepare data for each process and pad with zeros so all processes have the 59 | # same number of batches. 
60 | def preprocess(example): 61 | return { 62 | "mask": tf.constant(1), 63 | "input": pp_builder.get_preprocess_fn(pp_fn)(example), 64 | } 65 | 66 | self.data = common.get_jax_process_dataset( 67 | dataset, 68 | split, 69 | global_batch_size=batch_size, 70 | add_tfds_id=True, 71 | pp_fn=preprocess) 72 | 73 | def run(self, params): 74 | """Run eval.""" 75 | psnrs = [] 76 | 77 | for batch in self.data.as_numpy_iterator(): 78 | # Outputs is a dict with values shaped (gather/same, devices, batch, ...) 79 | out = self.predict_fn(params, batch) 80 | 81 | if jax.process_index(): # Host0 gets all preds and does eval. 82 | continue 83 | 84 | # First, we remove the "gather" dim and transfer the result to host, 85 | # leading to numpy arrays of (devices, device_batch, ...) 86 | out = jax.tree_map(lambda x: jax.device_get(x[0]), out) 87 | mask = out["mask"] 88 | batch_psnrs = out["psnr"][mask != 0] 89 | psnrs.extend(batch_psnrs) 90 | 91 | if jax.process_index(): # Host0 gets all preds and does eval. 92 | return 93 | 94 | yield "PSNR", np.mean(psnrs) 95 | 96 | 97 | @functools.partial(jax.vmap, in_axes=[0, 0, None]) 98 | def _psnr(img0, img1, dynamic_range): 99 | mse = jnp.mean(jnp.power(img0 - img1, 2)) 100 | return 20. * jnp.log10(dynamic_range) - 10. * jnp.log10(mse) 101 | -------------------------------------------------------------------------------- /big_vision/evaluators/proj/uvim/save_predictions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluator to save predictions.""" 16 | # pylint: disable=consider-using-from-import 17 | import os 18 | 19 | from absl import flags 20 | from absl import logging 21 | import big_vision.evaluators.proj.uvim.common as common 22 | import big_vision.pp.builder as pp_builder 23 | import big_vision.utils as u 24 | import jax 25 | import numpy as np 26 | import tensorflow as tf 27 | 28 | 29 | class Evaluator: 30 | """Save predictions in "{FLAGS.workdir}/{outfile}". 31 | 32 | Results can then be easily inspected in a notebook such as: 33 | 34 | ``` 35 | results = utils.load_checkpoint(None, "") 36 | inputs, outputs = (results["inputs"], results["outputs"]) 37 | ``` 38 | """ 39 | 40 | def __init__(self, predict_fn, pp_fn, dataset, split, batch_size, outfile, 41 | predict_kwargs=None, dataset_dir=None): 42 | # Prepare to run predict on all processes and gather predictions on all 43 | # devices. Note: if needed consider only gather across processes. 44 | def predict(params, batch): 45 | y = predict_fn(params, batch['inputs'], **(predict_kwargs or {})) 46 | res = {'inputs': batch['inputs'], 'outputs': y, 'mask': batch['mask']} 47 | return jax.lax.all_gather(res, axis_name='data', axis=0, tiled=True) 48 | 49 | self.predict_fn = jax.pmap(predict, axis_name='data') 50 | 51 | # Prepare data for each process and pad with zeros so all processes have the 52 | # same number of batches. 
53 | def preprocess(example): 54 | return { 55 | 'mask': tf.constant(1), 56 | 'inputs': pp_builder.get_preprocess_fn(pp_fn)(example), 57 | } 58 | self.data = common.get_jax_process_dataset( 59 | dataset=dataset, split=split, 60 | dataset_dir=dataset_dir, 61 | global_batch_size=batch_size, 62 | pp_fn=preprocess) 63 | 64 | self.path = os.path.join(flags.FLAGS.workdir, outfile) 65 | 66 | def run(self, params): 67 | """Compute all predictions, gather in main host and save in outfile.""" 68 | count = 0 69 | outputs = [] 70 | for batch in self.data.as_numpy_iterator(): 71 | out = self.predict_fn(params, batch) 72 | if jax.process_index(): 73 | continue 74 | 75 | out = jax.device_get(jax.tree_map(lambda x: x[0], out)) 76 | out = jax.tree_map(lambda x: x[out['mask'] == 1], out) # pylint: disable=cell-var-from-loop 77 | count += out['mask'].shape[0] 78 | out.pop('mask') 79 | outputs.append(out) 80 | 81 | logging.log_every_n_seconds( 82 | logging.INFO, 'Save predictions: processed %i examples so far.', 30, 83 | count) 84 | 85 | if jax.process_index(): 86 | return 87 | 88 | logging.info('Save predictions: processed %d examples.', count) 89 | 90 | # Actually save in filesystem. 91 | outputs = jax.tree_map(lambda *x: np.concatenate(x, axis=0), *outputs) 92 | u.save_checkpoint(outputs, self.path, compressed=True) 93 | return 94 | 95 | yield None # pylint: disable=unreachable 96 | -------------------------------------------------------------------------------- /big_vision/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/models/__init__.py -------------------------------------------------------------------------------- /big_vision/models/ppp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/models/ppp/__init__.py -------------------------------------------------------------------------------- /big_vision/models/proj/clippo/one_tower.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
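# A numpy restatement (sketch, made-up values) of the `_psnr` helper in
# psnr.py above: with dynamic_range = 2 (intensities in [-1, 1]) and
# MSE = 0.01, PSNR = 20*log10(2) - 10*log10(0.01) = 6.02 + 20 = 26.02 dB.
import numpy as np

img0 = np.zeros((8, 8, 3))
img1 = np.full((8, 8, 3), 0.1)
mse = np.mean((img0 - img1) ** 2)  # 0.01
psnr = 20 * np.log10(2.0) - 10 * np.log10(mse)  # ~26.02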
14 | 15 | """Model definition to train a single ViT model with the contrastive trainer.""" 16 | 17 | import importlib 18 | from typing import Optional, Any 19 | 20 | from big_vision import utils 21 | import flax.linen as nn 22 | import jax.numpy as jnp 23 | 24 | ConfigDict = Any 25 | 26 | 27 | class Model(nn.Module): 28 | """Single ViT to encode regular images and text images.""" 29 | image: Optional[ConfigDict] = None 30 | image_model: str = "vit" 31 | out_dim: int = 768 32 | temperature_init: float = 10.0 33 | 34 | @nn.compact 35 | def __call__(self, image, text=None, **kw): 36 | """Returns (B, C) image and (B, C) text representations, and some extras.""" 37 | ztxt, zimg = None, None 38 | kw = kw or {} 39 | 40 | image_model = importlib.import_module( 41 | f"big_vision.models.{self.image_model}" 42 | ).Model(**{"num_classes": self.out_dim, **(self.image or {})}, name="img") # pylint: disable=not-a-mapping 43 | 44 | def _compute_embedding(input_image, prefix): 45 | zemb, out_emb = image_model(input_image, **kw) 46 | out = {f"{prefix}/{k}": v for k, v in out_emb.items()} 47 | 48 | # Normalize the embeddings. 49 | out[f"{prefix}/norm"] = jnp.linalg.norm(zemb, axis=1, keepdims=True) 50 | out[f"{prefix}/normalized"] = zemb = zemb / (out[f"{prefix}/norm"] + 1e-8) 51 | return zemb, out 52 | 53 | out = {} 54 | if image is not None: 55 | zimg, out_img = _compute_embedding(image, "img") 56 | out.update(out_img) 57 | 58 | if text is not None: 59 | ztxt, out_txt = _compute_embedding(text, "txt") 60 | out.update(out_txt) 61 | 62 | temp_init = jnp.log(self.temperature_init) 63 | t = self.param("t", 64 | lambda key, shape, dtype: temp_init*jnp.ones(shape, dtype), 65 | (1,), jnp.float32) 66 | out["t"] = jnp.exp(t) 67 | out["t/parameter"] = t 68 | 69 | return zimg, ztxt, out 70 | 71 | 72 | def load(init_params, init_files, model_cfg, img_load_kw={}): # pylint: disable=dangerous-default-value 73 | """Loads the ViT parameters - adapted from proj/image_text/two_towers.py.""" 74 | if isinstance(init_files, str): 75 | # A shortcut for a single file checkpoint of a two_towers model. 76 | init_files = {k: f"{init_files}:{k}" for k in ("img", "t")} 77 | else: 78 | init_files = {**init_files} # Shallow copy because we'll pop stuff off. 79 | 80 | restored_params = {**init_params} 81 | 82 | img_init = init_files.pop("image", init_files.pop("img", None)) 83 | if img_init: 84 | restored_params["img"] = importlib.import_module( 85 | f"big_vision.models.{model_cfg.image_model}" 86 | ).load(init_params["img"], img_init, model_cfg.image, **img_load_kw) 87 | 88 | t_init = init_files.pop("temperature", init_files.pop("t", None)) 89 | if t_init: 90 | restored_params["t"] = utils.load_params(None, t_init) 91 | 92 | assert not init_files, ( 93 | f"There's something unused left in `config.model_init`. You probably got " 94 | f"a typo. Here it is: {init_files}") 95 | 96 | return restored_params 97 | -------------------------------------------------------------------------------- /big_vision/models/proj/flaxformer/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
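# A standalone sketch (illustration only) of the learned temperature in the
# one-tower model above: `t` is stored as a log-temperature, so exp(t) is
# always positive and starts exactly at `temperature_init`.
import jax.numpy as jnp

temperature_init = 10.0
t = jnp.log(temperature_init) * jnp.ones((1,))  # The stored parameter.
assert jnp.allclose(jnp.exp(t), temperature_init)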
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """BERT encoder, optionally loading pre-trained checkpoints.""" 16 | 17 | import dataclasses 18 | from typing import Optional 19 | 20 | from absl import logging 21 | from big_vision import utils 22 | from big_vision.models import common 23 | import flax 24 | import flax.linen as nn 25 | import jax.numpy as jnp 26 | from tensorflow.io import gfile 27 | 28 | from flaxformer.architectures.bert import bert 29 | from flaxformer.architectures.bert import bert_checkpoint_converter 30 | from flaxformer.architectures.bert import configs 31 | 32 | 33 | class Model(nn.Module): 34 | """BERT encoder with linear projection on last layer CLS token.""" 35 | 36 | config: str 37 | num_classes: Optional[int] = None 38 | head_zeroinit: bool = True 39 | 40 | @nn.compact 41 | def __call__(self, text, *, train=False): 42 | out = {} 43 | 44 | batch_size, max_len = text.shape 45 | bert_model = bert.BertEncoder(**dataclasses.asdict({ 46 | "base": configs.BertBaseConfig(), 47 | "large": configs.BertLargeConfig(), 48 | }[self.config])) 49 | x = out["transformed"] = bert_model( 50 | token_ids=text, 51 | position_ids=jnp.tile( 52 | jnp.arange(0, max_len, dtype=jnp.int32), [batch_size, 1]), 53 | segment_ids=jnp.zeros([batch_size, max_len], dtype=jnp.int32), 54 | input_mask=text.astype(jnp.bool_).astype(jnp.int32), 55 | enable_dropout=train, 56 | ) 57 | 58 | x = out["pre_logits"] = x[:, 0] # CLS token 59 | 60 | if self.num_classes: 61 | kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} 62 | x = out["logits"] = nn.Dense(self.num_classes, name="head", **kw)(x) 63 | 64 | return x, out 65 | 66 | 67 | def load(params, path, model_cfg=None, dont_load=()): 68 | """Returns `params` with BERT weights replaced from checkpoint at `path`.""" 69 | del model_cfg 70 | 71 | checkpoint_path = f"{path}/bert_model.ckpt" 72 | if gfile.exists(f"{checkpoint_path}.index"): 73 | logging.info("Loading original BERT checkpoint from '%s'", checkpoint_path) 74 | params = flax.core.FrozenDict(params).unfreeze() # Recursive copy. 
75 | max_len = ( 76 | params["BertEncoder_0"]["embedder"]["embedders_position_ids"] 77 | ["embedding"].shape[0]) 78 | bert_params, pooler_params = ( 79 | bert_checkpoint_converter.load_params_from_tf_checkpoint( 80 | checkpoint_path=f"{path}/bert_model.ckpt")) 81 | del pooler_params 82 | if isinstance(bert_params, flax.core.FrozenDict): 83 | bert_params = bert_params.unfreeze() 84 | bert_params["embedder"]["embedders_position_ids"]["embedding"] = ( 85 | bert_params["embedder"]["embedders_position_ids"]["embedding"][:max_len] 86 | ) 87 | return common.merge_params( 88 | {"BertEncoder_0": bert_params}, params, dont_load) 89 | 90 | logging.info( 91 | "Could not find original BERT checkpoint path '%s', " 92 | "loading big_vision checkpoint '%s'", checkpoint_path, path) 93 | restored_params = utils.load_params(path) 94 | return common.merge_params(restored_params, params, dont_load) 95 | -------------------------------------------------------------------------------- /big_vision/models/proj/flaxformer/bert_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for bert.""" 16 | 17 | import tempfile 18 | 19 | from big_vision import input_pipeline 20 | from big_vision.models.proj.flaxformer import bert 21 | from big_vision.models.proj.flaxformer import bert_test_util 22 | import big_vision.pp.builder as pp_builder 23 | import big_vision.pp.ops_general # pylint: disable=unused-import 24 | import big_vision.pp.proj.flaxformer.bert_ops # pylint: disable=unused-import 25 | import flax 26 | import jax 27 | import jax.numpy as jnp 28 | import tensorflow as tf 29 | 30 | 31 | # BERT vocabulary for testing. 
32 | _BERT_VOCAB = [
33 |     "[PAD]",
34 |     "[UNK]",
35 |     "this",
36 |     "is",
37 |     "a",
38 |     "test",
39 |     "[CLS]",
40 |     "[SEP]",
41 | ]
42 | _TOKEN_LEN = 16
43 | 
44 | 
45 | class BertTest(tf.test.TestCase):
46 | 
47 |   def test_load_apply(self):
48 |     inkey = "text"
49 |     vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
50 |     with open(vocab_path, "w") as f:
51 |       f.write("\n".join(_BERT_VOCAB))
52 |     ds2, _ = input_pipeline.make_for_inference(
53 |         tf.data.Dataset.from_tensor_slices(
54 |             {inkey: tf.ragged.constant([["this is a test"]])}),
55 |         num_ex_per_process=[1],
56 |         preprocess_fn=pp_builder.get_preprocess_fn(
57 |             f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', "
58 |             f"max_len={_TOKEN_LEN})"
59 |             "|keep('labels')"),
60 |         batch_size=1,
61 |     )
62 |     text = jnp.array(next(iter(ds2))["labels"])
63 |     model = bert.Model(config="base")
64 |     variables = model.init(jax.random.PRNGKey(0), text)
65 |     params = bert.load(flax.core.unfreeze(variables)["params"],
66 |                        bert_test_util.create_base_checkpoint())
67 |     x, out = model.apply({"params": params}, text)
68 |     self.assertAllEqual(jax.tree_map(jnp.shape, x), (1, 768))
69 |     self.assertAllEqual(
70 |         jax.tree_map(jnp.shape, out), {
71 |             "transformed": (1, 16, 768),
72 |             "pre_logits": (1, 768),
73 |         })
74 | 
75 | 
76 | if __name__ == "__main__":
77 |   tf.test.main()
78 | 
--------------------------------------------------------------------------------
/big_vision/models/proj/givt/adaptor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for the IRevNet adaptor."""
16 | 
17 | from big_vision.models.proj.givt import adaptor
18 | import jax
19 | from jax import random
20 | import jax.numpy as jnp
21 | 
22 | from absl.testing import absltest
23 | 
24 | 
25 | class AdaptorTest(absltest.TestCase):
26 | 
27 |   def test_inversion(self):
28 |     num_channels = 8
29 |     input_shape = (1, 24, 24, num_channels)
30 | 
31 |     rng = random.PRNGKey(758493)
32 |     _, inp_rng, init_rng, data_rng = jax.random.split(rng, 4)
33 | 
34 |     dummy_x = random.normal(inp_rng, shape=input_shape)
35 |     real_x = jax.random.normal(data_rng, shape=input_shape)
36 | 
37 |     model = adaptor.IRevNet(
38 |         num_blocks=4,
39 |         num_channels=num_channels,
40 |         dropout_rate=0.0,
41 |     )
42 |     params = model.init(init_rng, dummy_x)
43 | 
44 |     real_y = model.apply(params, real_x, method=model.forward)
45 |     real_x_ = model.apply(params, real_y, method=model.inverse)
46 |     self.assertTrue(jnp.allclose(real_x, real_x_, atol=1e-5))
47 | 
48 | 
49 | if __name__ == "__main__":
50 |   absltest.main()
51 | 
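# The round-trip check above is the generic pattern for testing invertible
# maps: applying `inverse` after `forward` must recover the input. A
# self-contained analogue with a plain affine map (illustration only, not
# IRevNet):
import jax.numpy as jnp

scale, shift = 2.0, 0.5
forward = lambda x: x * scale + shift
inverse = lambda y: (y - shift) / scale
x = jnp.linspace(-1.0, 1.0, 5)
assert jnp.allclose(inverse(forward(x)), x, atol=1e-6)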
--------------------------------------------------------------------------------
/big_vision/models/proj/givt/vae.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Abstract VAE model class.
16 | 
17 | Gaussian encoder and decoder (the latter assumed to have constant variance).
18 | 
19 | Inspiration drawn from https://github.com/pytorch/examples/tree/main/vae.
20 | """
21 | 
22 | import abc
23 | from typing import Optional, Mapping
24 | 
25 | 
26 | import flax.linen as nn
27 | import jax
28 | import jax.numpy as jnp
29 | 
30 | 
31 | class Model(nn.Module, metaclass=abc.ABCMeta):
32 |   """Abstract VAE model class."""
33 | 
34 |   codeword_dim: Optional[int] = None
35 |   code_len: int = 256
36 |   code_dropout: str = "none"
37 | 
38 |   @abc.abstractmethod
39 |   def encode(
40 |       self,
41 |       x: jax.Array,
42 |       *,
43 |       train: bool = False,
44 |   ) -> tuple[jax.Array, jax.Array]:
45 |     ...
46 | 
47 |   def reparametrize(
48 |       self,
49 |       mu: jax.Array,
50 |       logvar: jax.Array,
51 |       rng: jax.Array | None = None,
52 |   ) -> jax.Array:
53 |     std = jnp.exp(0.5 * logvar)
54 |     if rng is None:
55 |       rng = self.make_rng("dropout")
56 |     eps = jax.random.normal(rng, shape=std.shape, dtype=std.dtype)
57 |     return mu + std * eps
58 | 
59 |   @abc.abstractmethod
60 |   def decode(
61 |       self, x: jax.Array,
62 |       train: bool = False,
63 |   ) -> jax.Array | Mapping[str, jax.Array]:
64 |     ...
65 | 
66 |   def code_dropout_fn(self, z: jax.Array, *, train: bool = False) -> jax.Array:
67 |     # "seq" drops out tokens later in the sequence with higher probability
68 |     # than tokens earlier in the sequence.
69 |     assert self.code_dropout in ["none", "seq", "random"]
70 |     if train and self.code_dropout != "none":
71 |       importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
72 |       thr = jax.random.uniform(self.make_rng("dropout"), z.shape[:1])
73 |       mask = importance[None, :] > thr[:, None]
74 |       if self.code_dropout == "random":
75 |         mask = jax.random.permutation(
76 |             self.make_rng("dropout"), mask, axis=-1, independent=True)
77 |       z = z * mask[:, :, None]
78 |     return z
79 | 
80 |   def __call__(
81 |       self,
82 |       x: jax.Array,
83 |       *,
84 |       train: bool = False,
85 |   ) -> tuple[jax.Array | Mapping[str, jax.Array], Mapping[str, jax.Array]]:
86 |     mu, logvar = self.encode(x, train=train)
87 |     # Only reparametrize when training for simplicity.
88 |     if train:
89 |       z = self.reparametrize(mu, logvar)
90 |     else:
91 |       z = mu
92 |     z = self.code_dropout_fn(z, train=train)
93 |     x = self.decode(z, train=train)
94 |     return x, {"mu": mu, "logvar": logvar, "z": z}
95 | 
--------------------------------------------------------------------------------
/big_vision/models/proj/image_text/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
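# A sketch (made-up shapes) of the reparametrization trick used in the VAE
# above: drawing z ~ N(mu, sigma^2) as mu + sigma * eps keeps the sample
# differentiable with respect to mu and logvar.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
mu = jnp.zeros((4, 8))
logvar = jnp.zeros((4, 8))  # So sigma = exp(0.5 * logvar) = 1.
eps = jax.random.normal(rng, mu.shape)
z = mu + jnp.exp(0.5 * logvar) * eps  # Differentiable w.r.t. mu and logvar.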
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions.""" 16 | 17 | import jax 18 | from jax.experimental import shard_map 19 | from jax.interpreters import pxla 20 | 21 | 22 | P = jax.sharding.PartitionSpec 23 | 24 | 25 | def batch_shmap(fn, *args, **kwargs): 26 | """Shard map to map along the data dimension w/o triggering communication.""" 27 | 28 | mesh = pxla.thread_resources.env.physical_mesh 29 | if not mesh.empty: 30 | devices_flat = mesh.devices.flatten() 31 | mesh_flat = jax.sharding.Mesh(devices_flat, ("data",)) 32 | fn = shard_map.shard_map( 33 | fn, 34 | mesh=mesh_flat, 35 | in_specs=P("data"), out_specs=P("data"), check_rep=True) 36 | return fn(*args, **kwargs) 37 | 38 | 39 | def subsample_batch(x, subsample: int): 40 | """Shard map to subsample the data dimension w/o triggering communication.""" 41 | fn = lambda x: jax.tree.map(lambda xx: xx[::subsample], x) 42 | return batch_shmap(fn, x) if subsample > 1 else x 43 | -------------------------------------------------------------------------------- /big_vision/models/proj/uvim/vit_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
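# Usage sketch (hypothetical shapes) for `subsample_batch` above: keep every
# 2nd example along the leading data axis. Under a non-empty mesh the slicing
# runs inside shard_map, so it triggers no cross-device communication; with an
# empty mesh it falls back to plain slicing.
import jax.numpy as jnp

batch = {'image': jnp.zeros((16, 224, 224, 3)), 'label': jnp.zeros((16,))}
half = subsample_batch(batch, subsample=2)  # Leading dim 16 -> 8.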
14 | 15 | """Tests for vit vqvae model.""" 16 | from absl.testing import absltest 17 | 18 | from big_vision.models.proj.uvim import vit 19 | import jax 20 | import jax.numpy as jnp 21 | import ml_collections 22 | 23 | 24 | class ViTVQVAEModelTest(absltest.TestCase): 25 | 26 | def test_model(self): 27 | model_config = ml_collections.ConfigDict({ 28 | "input_size": (32, 32), 29 | "code_len": 4, 30 | "width": 16, 31 | "mlp_dim": 64, 32 | "num_heads": 4, 33 | "enc_depth": 1, 34 | "dec_depth": 1, 35 | "with_encoder_ctx": True, 36 | "with_decoder_ctx": True, 37 | "statistics_axis_name": None, 38 | "inputs": { 39 | "in1": (10, 3), 40 | "in2": (25,), 41 | }, 42 | "outputs": { 43 | "out1": (5,), 44 | "out2": (20,), 45 | }, 46 | }) 47 | 48 | model = vit.Model(**model_config) 49 | batch_size = 4 50 | seq_len = (32 // 8) ** 2 51 | x = { 52 | "in1": jnp.zeros((batch_size, seq_len, 10, 3)), 53 | "in2": jnp.zeros((batch_size, seq_len, 25)), 54 | } 55 | ctx_image = jnp.zeros((batch_size,) + model_config.input_size + (3,)) 56 | init_rngs = { 57 | "params": jax.random.PRNGKey(0), 58 | "state": jax.random.PRNGKey(1), 59 | } 60 | params = model.init(init_rngs, x, ctx=ctx_image) 61 | self.assertEqual(params.keys(), set(["params", "state"])) 62 | 63 | apply_rngs = { 64 | "dropout": jax.random.PRNGKey(0), 65 | "vqvae": jax.random.PRNGKey(0), 66 | } 67 | (logits, _), params = model.apply( 68 | params, x, ctx=ctx_image, train=True, update_dict=True, 69 | rngs=apply_rngs, mutable=["state"]) 70 | self.assertEqual(logits.keys(), set(["out1", "out2"])) 71 | self.assertEqual(logits["out1"].shape, (batch_size, seq_len, 5)) 72 | self.assertEqual(logits["out2"].shape, (batch_size, seq_len, 20)) 73 | 74 | 75 | if __name__ == "__main__": 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /big_vision/models/proj/uvim/vtt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for vision-text-transformer.""" 16 | from absl.testing import absltest 17 | 18 | from big_vision.models.proj.uvim import vtt 19 | import jax 20 | import jax.numpy as jnp 21 | import ml_collections 22 | 23 | 24 | class VTTTest(absltest.TestCase): 25 | 26 | def test_vtt_with_1_step(self): 27 | model_config = ml_collections.ConfigDict(dict( 28 | input_size=(224, 224), 29 | patches={"size": (16, 16)}, 30 | num_heads=2, 31 | num_layers=2, 32 | mlp_dim=128, 33 | emb_dim=64, 34 | vocab_size=500)) 35 | batch_size, max_len = 8, 50 36 | image = jnp.ones((batch_size, 224, 224, 3)) 37 | text = jnp.ones((batch_size, max_len), dtype=jnp.int32) 38 | 39 | m = vtt.Model(**model_config) 40 | variables = m.init(jax.random.PRNGKey(42), image, text) 41 | self.assertCountEqual(variables.keys(), ["params"]) 42 | 43 | params = variables["params"] 44 | out = m.apply({"params": params}, image, text) 45 | expected_shape = (batch_size, max_len, model_config.vocab_size) 46 | self.assertEqual(out.shape, expected_shape) 47 | 48 | 49 | if __name__ == "__main__": 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /big_vision/pp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/pp/__init__.py -------------------------------------------------------------------------------- /big_vision/pp/archive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/big_vision/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/pp/archive/__init__.py -------------------------------------------------------------------------------- /big_vision/pp/archive/randaug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RandAug depends on deprecated tfa.image package, now defunct.""" 16 | 17 | from big_vision.pp import registry 18 | from big_vision.pp import utils 19 | from big_vision.pp.archive import autoaugment 20 | 21 | 22 | @registry.Registry.register("preprocess_ops.randaug") 23 | @utils.InKeyOutKey() 24 | def get_randaug(num_layers: int = 2, magnitude: int = 10): 25 | """Creates a function that applies RandAugment. 26 | 27 | RandAugment is from the paper https://arxiv.org/abs/1909.13719, 28 | 29 | Args: 30 | num_layers: Integer, the number of augmentation transformations to apply 31 | sequentially to an image. Represented as (N) in the paper. Usually best 32 | values will be in the range [1, 3]. 33 | magnitude: Integer, shared magnitude across all augmentation operations. 34 | Represented as (M) in the paper. Usually best values are in the range [5, 35 | 30]. 36 | 37 | Returns: 38 | a function that applies RandAugment. 
39 | """ 40 | 41 | def _randaug(image): 42 | return autoaugment.distort_image_with_randaugment( 43 | image, num_layers, magnitude 44 | ) 45 | 46 | return _randaug 47 | -------------------------------------------------------------------------------- /big_vision/pp/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessing builder.""" 16 | 17 | from absl import logging 18 | from big_vision.pp import registry 19 | import tensorflow as tf 20 | 21 | 22 | def get_preprocess_fn(pp_pipeline, log_data=True, log_steps=False): 23 | """Transform an input string into the preprocessing function. 24 | 25 | The minilanguage is as follows: 26 | 27 | fn1|fn2(arg, arg2,...)|... 28 | 29 | And describes the successive application of the various `fn`s to the input, 30 | where each function can optionally have one or more arguments, which are 31 | either positional or key/value, as dictated by the `fn`. 32 | 33 | The output preprocessing function expects a dictionary as input. This 34 | dictionary should have a key "image" that corresponds to a 3D tensor 35 | (height x width x channel). 36 | 37 | Args: 38 | pp_pipeline: A string describing the pre-processing pipeline. If empty or 39 | None, no preprocessing will be executed. 40 | log_data: Whether to log the data before and after preprocessing. Can also 41 | be a string to show in the log for debugging, for example dataset name. 42 | log_steps: Whether to log the steps of the preprocessing pipeline. 43 | 44 | Returns: 45 | preprocessing function. 46 | 47 | Raises: 48 | ValueError: if preprocessing function name is unknown 49 | """ 50 | 51 | names, ops, spec_strings = [], [], [] 52 | if pp_pipeline: 53 | for op_spec in pp_pipeline.split("|"): 54 | if not op_spec: continue # Skip empty section instead of error. 55 | try: 56 | ops.append(registry.Registry.lookup(f"preprocess_ops.{op_spec}")()) 57 | names.append(registry.parse_name(op_spec)[0]) 58 | spec_strings.append(op_spec) 59 | except SyntaxError as err: 60 | raise ValueError(f"Syntax error on: {op_spec}") from err 61 | 62 | def _preprocess_fn(data): 63 | """The preprocessing function that is returned.""" 64 | nonlocal log_data, log_steps 65 | 66 | # Apply all the individual steps in sequence. 67 | if log_data: 68 | logging.info("Data before pre-processing (%s):\n%s", log_data, data) 69 | for name, op, spec in zip(names, ops, spec_strings): 70 | if log_steps: 71 | logging.info("Pre-processing step (%s): %s\n%s", name, spec, data) 72 | with tf.name_scope(name): 73 | data = op(data) 74 | 75 | # Validate input 76 | if not isinstance(data, dict): 77 | raise ValueError("Argument `data` must be a dictionary, " 78 | "not %s" % str(type(data))) 79 | 80 | if log_data: 81 | logging.info("Data after pre-processing (%s):\n%s", log_data, data) 82 | log_data = False # For eager&pygrain: only log first one of each pipeline. 
83 | return data 84 | 85 | return _preprocess_fn 86 | -------------------------------------------------------------------------------- /big_vision/pp/builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for builder.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from big_vision.pp import builder 22 | from big_vision.pp import ops_general  # pylint: disable=unused-import 23 | from big_vision.pp import ops_image  # pylint: disable=unused-import 24 | import numpy as np 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | class BuilderTest(tf.test.TestCase): 29 | 30 | def testSingle(self): 31 | pp_fn = builder.get_preprocess_fn("resize(256)") 32 | x = np.random.randint(0, 256, [640, 480, 3]) 33 | image = pp_fn({"image": x})["image"] 34 | self.assertEqual(image.numpy().shape, (256, 256, 3)) 35 | 36 | def testEmpty(self): 37 | pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||") 38 | 39 | # Typical image input 40 | x = np.random.randint(0, 256, [640, 480, 3]) 41 | image = pp_fn({"image": x})["image"] 42 | self.assertEqual(image.numpy().shape, (256, 256, 3)) 43 | 44 | def testPreprocessingPipeline(self): 45 | pp_str = ("inception_crop|resize(256)|resize((256, 256))|" 46 | "central_crop((80, 120))|flip_lr|value_range(0,1)|" 47 | "value_range(-1,1)") 48 | pp_fn = builder.get_preprocess_fn(pp_str) 49 | 50 | # Typical image input 51 | x = np.random.randint(0, 256, [640, 480, 3]) 52 | image = pp_fn({"image": x})["image"] 53 | self.assertEqual(image.numpy().shape, (80, 120, 3)) 54 | self.assertLessEqual(np.max(image.numpy()), 1) 55 | self.assertGreaterEqual(np.min(image.numpy()), -1) 56 | 57 | def testNumArgsException(self): 58 | 59 | x = np.random.randint(0, 256, [640, 480, 3]) 60 | for pp_str in [ 61 | "inception_crop(1)", 62 | "resize()", 63 | "resize(1, 1, 1)", 64 | "flip_lr(1)", 65 | "central_crop()", 66 | ]: 67 | with self.assertRaises(BaseException): 68 | builder.get_preprocess_fn(pp_str)(x) 69 | 70 | 71 | if __name__ == "__main__": 72 | tf.test.main() 73 | -------------------------------------------------------------------------------- /big_vision/pp/ops_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ops_image.""" 16 | 17 | import copy 18 | import io 19 | 20 | import big_vision.pp.ops_image as pp 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | def get_image_data(): 27 | img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32) # Can't ask uint8!? 28 | return {"image": tf.cast(img, tf.uint8)} 29 | 30 | 31 | class PreprocessOpsTest(tf.test.TestCase): 32 | 33 | def tfrun(self, ppfn, data): 34 | # Run once as standalone, as could happen eg in colab. 35 | yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()} 36 | 37 | # And then once again as part of tfdata pipeline. 38 | # You'd be surprised how much these two differ! 39 | tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) 40 | for npdata in tfdata.map(ppfn).as_numpy_iterator(): 41 | yield npdata 42 | 43 | def test_resize(self): 44 | for data in self.tfrun(pp.get_resize([120, 80]), get_image_data()): 45 | self.assertEqual(data["image"].shape, (120, 80, 3)) 46 | 47 | def test_resize_small(self): 48 | for data in self.tfrun(pp.get_resize_small(240), get_image_data()): 49 | self.assertEqual(data["image"].shape, (320, 240, 3)) 50 | 51 | def test_resize_long(self): 52 | for data in self.tfrun(pp.get_resize_long(320), get_image_data()): 53 | self.assertEqual(data["image"].shape, (320, 240, 3)) 54 | 55 | def test_inception_crop(self): 56 | for data in self.tfrun(pp.get_inception_crop(), get_image_data()): 57 | self.assertEqual(data["image"].shape[-1], 3) 58 | 59 | def test_decode_jpeg_and_inception_crop(self): 60 | f = io.BytesIO() 61 | plt.imsave(f, get_image_data()["image"].numpy(), format="jpg") 62 | data = {"image": tf.cast(f.getvalue(), tf.string)} 63 | for data in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), data): 64 | self.assertEqual(data["image"].shape[-1], 3) 65 | 66 | def test_random_crop(self): 67 | for data in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()): 68 | self.assertEqual(data["image"].shape, (120, 80, 3)) 69 | 70 | def test_central_crop(self): 71 | for data in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()): 72 | self.assertEqual(data["image"].shape, (20, 80, 3)) 73 | 74 | def test_random_flip_lr(self): 75 | data_orig = get_image_data() 76 | for data in self.tfrun(pp.get_random_flip_lr(), data_orig): 77 | self.assertTrue( 78 | np.all(data_orig["image"].numpy() == data["image"]) or 79 | np.all(data_orig["image"].numpy() == data["image"][:, ::-1])) 80 | 81 | if __name__ == "__main__": 82 | tf.test.main() 83 | -------------------------------------------------------------------------------- /big_vision/pp/proj/clippo/download_unifont.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
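To make the pp minilanguage from `builder.py` above concrete, here is a minimal sketch mirroring what `builder_test.py` and `ops_image_test.py` exercise (op availability depends on which modules are imported, since importing a module registers its ops):

```python
from big_vision.pp import builder
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
import numpy as np

# "fn1|fn2(args)|...": ops are applied left to right to the data dict.
pp_fn = builder.get_preprocess_fn(
    "inception_crop|resize(224)|value_range(-1, 1)|keep('image')")

data = {"image": np.random.randint(0, 256, [640, 480, 3]), "label": 3}
out = pp_fn(data)  # out["image"]: (224, 224, 3) tensor with values in [-1, 1].
```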
14 | 15 | #!/bin/bash 16 | # This is intended to be run from the big_vision repository root: 17 | # 18 | # bash big_vision/pp/proj/clippo/download_unifont.sh 19 | wget https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont-9.0.06.hex.gz https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont_upper-9.0.06.hex.gz 20 | gunzip unifont-9.0.06.hex.gz unifont_upper-9.0.06.hex.gz 21 | mv unifont-9.0.06.hex unifont_upper-9.0.06.hex big_vision/pp/proj/clippo/ -------------------------------------------------------------------------------- /big_vision/pp/proj/flaxformer/bert_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """BERT-related preprocessing ops (using WordPiece tokenizer).""" 16 | 17 | from big_vision.pp import utils 18 | from big_vision.pp.registry import Registry 19 | import tensorflow as tf 20 | import tensorflow_text 21 | 22 | 23 | # Internally using 24 | # BasicTokenizer 25 | # https://github.com/tensorflow/text/blob/df5250d6cf1069990df4bf55154867391ab5381a/tensorflow_text/python/ops/bert_tokenizer.py#L67 26 | # WordpieceTokenizer 27 | # https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/wordpiece_tokenizer.py 28 | def _create_bert_tokenizer(vocab_path): 29 | """Returns cls_token id and tokenizer to use in a tf.Dataset.map function.""" 30 | # Create tokenizer inside a tf.init_scope so the vocab is only loaded from 31 | # disk once per dataset iterator (see: http://(internal link)). 32 | # TODO: Make a local copy of vocab if creating many iterators. 33 | with tf.init_scope(): 34 | tokenizer = tensorflow_text.BertTokenizer( 35 | vocab_path, 36 | token_out_type=tf.int32, 37 | lower_case=True, 38 | ) 39 | 40 | with tf.io.gfile.GFile(vocab_path) as f: 41 | vocab = f.read().split("\n") 42 | cls_token = vocab.index("[CLS]") 43 | 44 | return cls_token, tokenizer 45 | 46 | 47 | @Registry.register("preprocess_ops.bert_tokenize") 48 | @utils.InKeyOutKey(indefault=None, outdefault="labels") 49 | def get_pp_bert_tokenize(vocab_path, max_len, sample_if_multi=True): 50 | """Extracts tokens with tensorflow_text.BertTokenizer. 51 | 52 | Args: 53 | vocab_path: Path to a file containing the vocabulary for the WordPiece 54 | tokenizer. It's the "vocab.txt" file in the zip file downloaded from 55 | the original repo https://github.com/google-research/bert 56 | max_len: Number of tokens after tokenization. 57 | sample_if_multi: If `True`, a random text is sampled and tokenized; if 58 | `False`, the first text is always used. 59 | 60 | Returns: 61 | A preprocessing Op.
62 | """ 63 | 64 | cls_token, tokenizer = _create_bert_tokenizer(vocab_path) 65 | 66 | def _pp_bert_tokenize(labels): 67 | 68 | labels = tf.reshape(labels, (-1,)) 69 | labels = tf.concat([labels, [""]], axis=0) 70 | if sample_if_multi: 71 | num_texts = tf.maximum(tf.shape(labels)[0] - 1, 1) # Don't sample "". 72 | txt = labels[tf.random.uniform([], 0, num_texts, dtype=tf.int32)] 73 | else: 74 | txt = labels[0] # Always works, since we append "" earlier on. 75 | 76 | token_ids = tokenizer.tokenize(txt[None]) 77 | padded_token_ids, mask = tensorflow_text.pad_model_inputs( 78 | token_ids, max_len - 1) 79 | del mask # Recovered from zero padding in model. 80 | count = tf.shape(padded_token_ids)[0] 81 | padded_token_ids = tf.concat( 82 | [tf.fill([count, 1], cls_token), padded_token_ids], axis=1) 83 | return padded_token_ids[0] 84 | 85 | return _pp_bert_tokenize 86 | -------------------------------------------------------------------------------- /big_vision/pp/proj/flaxformer/bert_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for bert_ops.""" 16 | 17 | import tempfile 18 | 19 | from big_vision import input_pipeline 20 | import big_vision.pp.builder as pp_builder 21 | import big_vision.pp.ops_general # pylint: disable=unused-import 22 | from big_vision.pp.proj.flaxformer import bert_ops # pylint: disable=unused-import 23 | import tensorflow as tf 24 | 25 | 26 | # BERT vocabulary for testing. 27 | _BERT_VOCAB = [ 28 | "[PAD]", 29 | "[UNK]", 30 | "more", 31 | "than", 32 | "one", 33 | "[CLS]", 34 | "[SEP]", 35 | ] 36 | 37 | 38 | def _create_ds(pp_str, tensor_slices, num_examples): 39 | return input_pipeline.make_for_inference( 40 | tf.data.Dataset.from_tensor_slices(tensor_slices), 41 | num_ex_per_process=[num_examples], 42 | preprocess_fn=pp_builder.get_preprocess_fn(pp_str), 43 | batch_size=num_examples, 44 | )[0] 45 | 46 | 47 | class BertOpsTest(tf.test.TestCase): 48 | 49 | def test_tokenize(self): 50 | inkey = "texts" 51 | vocab_path = f"{tempfile.mkdtemp()}/vocab.txt" 52 | with open(vocab_path, "w") as f: 53 | f.write("\n".join(_BERT_VOCAB)) 54 | pp_str = ( 55 | f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', max_len=5)" 56 | f"|keep('labels')" 57 | ) 58 | tensor_slices = { 59 | inkey: tf.ragged.constant([["one more"], ["more than one"], [""]]) 60 | } 61 | ds = _create_ds(pp_str, tensor_slices, 3) 62 | self.assertAllEqual( 63 | next(iter(ds))["labels"], 64 | [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]], 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | tf.test.main() 70 | -------------------------------------------------------------------------------- /big_vision/pp/proj/givt/pp_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """GIVT-specific preprocessing ops.""" 16 | 17 | from big_vision.pp import registry 18 | from big_vision.pp import utils 19 | import tensorflow as tf 20 | 21 | 22 | @registry.Registry.register("preprocess_ops.bin_nyu_depth") 23 | @utils.InKeyOutKey(indefault="labels", outdefault="labels") 24 | def get_bin_nyu_depth(min_depth=0.001, max_depth=10.0, num_bins=256): 25 | """Binning of NYU depth for UViM in preprocessing rather than model.""" 26 | 27 | def _bin_depth(labels): # pylint: disable=missing-docstring 28 | labels = (labels - min_depth) / (max_depth - min_depth) 29 | labels *= num_bins 30 | labels = tf.cast(tf.floor(labels), tf.int32) 31 | labels = tf.minimum(labels, num_bins - 1) 32 | labels = tf.maximum(labels, 0) 33 | return labels 34 | 35 | return _bin_depth 36 | 37 | -------------------------------------------------------------------------------- /big_vision/pp/proj/image_text/ops_naflex_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for NaFlex preprocessing ops.""" 16 | 17 | import copy 18 | 19 | from absl.testing import parameterized 20 | from big_vision.pp.proj.image_text import ops_naflex as pp 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | 25 | def get_image_data(h, w): 26 | img = tf.random.uniform((h, w, 3), 0, 255, tf.int32) # Can't ask uint8!? 27 | return {"image": tf.cast(img, tf.uint8)} 28 | 29 | 30 | class NaflexTest(tf.test.TestCase, parameterized.TestCase): 31 | 32 | def tfrun(self, ppfn, data): 33 | # Run once as standalone, as could happen eg in colab. 34 | yield tf.nest.map_structure(np.array, ppfn(copy.deepcopy(data))) 35 | 36 | # And then once again as part of tfdata pipeline. 37 | # You'd be surprised how much these two differ! 
38 | tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) 39 | for npdata in tfdata.map(ppfn).as_numpy_iterator(): 40 | yield npdata 41 | 42 | @parameterized.parameters( 43 | (6, 8), 44 | (7, 9), 45 | (8, 10), 46 | ) 47 | def test_patchify_valid(self, h, w): 48 | """Tests the patchification op.""" 49 | op = pp.get_patchify((3, 4)) 50 | inputs = get_image_data(h, w) 51 | for data in self.tfrun(op, inputs): 52 | self.assertEqual(data["image"]["patches"].shape, (4, 3*4*3)) 53 | self.assertAllEqual( 54 | data["image"]["patches"][-1], 55 | np.array(inputs["image"])[3:6, 4:8, :].flatten()) 56 | self.assertAllEqual(data["image"]["yidx"], [0, 0, 1, 1]) 57 | self.assertAllEqual(data["image"]["xidx"], [0, 1, 0, 1]) 58 | 59 | @parameterized.named_parameters([ 60 | ("square_121_exact", (48, 48), 3, 121, (33, 33)), 61 | ("square_225_inexact", (112, 109), 7, 225, (105, 105)), 62 | ("square_64_exact", (176, 176), 11, 64, (88, 88)), 63 | ("rect_12_exact", (256, 64), 16, 12, (96, 32)), 64 | ("rect_15_exact_ps8", (256, 64), 8, 15, (56, 16)), 65 | ("rect_16_inexact", (63, 241), 16, 16, (32, 128)), 66 | ("rect_less_than_patch", (16, 512), 16, 16, (16, 256)), 67 | ]) 68 | def test_pp_resize_to_sequence( 69 | self, image_size, patch_size, seq_len, expected_image_size): 70 | """Tests the AR-preserving `resize_to_sequence` op.""" 71 | op = pp.get_resize_to_sequence(patch_size, seq_len) 72 | inputs = get_image_data(*image_size) 73 | for outputs in self.tfrun(op, inputs): 74 | self.assertAllEqual(outputs["image"].shape, expected_image_size + (3,)) 75 | 76 | if __name__ == "__main__": 77 | tf.test.main() 78 | -------------------------------------------------------------------------------- /big_vision/pp/proj/paligemma/robustness.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """pp ops.""" 16 | 17 | import math 18 | 19 | from big_vision.pp import utils 20 | from big_vision.pp.registry import Registry 21 | import tensorflow as tf 22 | 23 | 24 | @Registry.register("preprocess_ops.resize_r") 25 | @utils.InKeyOutKey() 26 | def get_resize_r(size): 27 | """Like standard `resize` but randomize some of its parameters.""" 28 | size = utils.maybe_repeat(size, 2) 29 | 30 | # Sadly TF won't let us pass symbolic arguments, so we need to pre-create all 31 | # variants of function calls we'd like to randomize over... 
32 | resize_fns = [ 33 | lambda x, m=m, a=a: tf.image.resize(x, size, method=m, antialias=a) 34 | for m in ["bilinear", "bicubic", "lanczos3", "area", "mitchellcubic"] 35 | for a in [True, False] 36 | ] 37 | 38 | def _resize_r(image): 39 | """Resizes image to a given size.""" 40 | dtype = image.dtype 41 | tf_dtype = tf.type_spec_from_value(image).dtype 42 | ifn = tf.random.uniform((), 0, len(resize_fns), tf.int32) 43 | image = tf.switch_case(ifn, [lambda fn=fn: fn(image) for fn in resize_fns]) 44 | return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype) 45 | 46 | return _resize_r 47 | 48 | 49 | @Registry.register("preprocess_ops.random_jpeg") 50 | @utils.InKeyOutKey() 51 | def get_random_jpeg(p): 52 | """With probability `p`, randomly encode-decode as jpeg.""" 53 | 54 | fns = [ 55 | lambda x: tf.image.adjust_jpeg_quality( 56 | x, dct_method="INTEGER_FAST", 57 | jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32), 58 | ), 59 | lambda x: tf.image.adjust_jpeg_quality( 60 | x, dct_method="INTEGER_ACCURATE", 61 | jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32), 62 | ), 63 | ] 64 | 65 | def _random_jpeg(image): 66 | """Randomly re-encodes the image as JPEG with probability `p`.""" 67 | funcs = [lambda: image] + [lambda fn=fn: fn(image) for fn in fns] 68 | logits = [math.log(prob) for prob in [1 - p] + [p / len(fns)] * len(fns)] 69 | fn_idx = tf.random.categorical([logits], 1, dtype=tf.int32)[0, 0] 70 | return tf.switch_case(fn_idx, funcs) 71 | 72 | return _random_jpeg 73 | -------------------------------------------------------------------------------- /big_vision/pp/proj/paligemma/sciqa_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """pp ops.""" 16 | 17 | from big_vision.pp.registry import Registry 18 | import tensorflow as tf 19 | 20 | 21 | @Registry.register('preprocess_ops.sci_qa_choices_shuffle') 22 | def sci_qa_choices_shuffle( 23 | choice_str_inkey='choices', 24 | ans_inkey='answer', 25 | indexed_choices_outkey='indexed_choices', 26 | indexed_answer_outkey='indexed_answer', 27 | ): 28 | """Randomly shuffles SciQA's answer choices on the fly. 29 | 30 | Args: 31 | choice_str_inkey: the original choice list from 32 | SciQA, e.g. ['apple', 'banana', ...] 33 | ans_inkey: the original answer index from SciQA, e.g.
1 34 | indexed_choices_outkey: output key for the shuffled choices, each prefixed 35 | with its letter index and joined into one string, e.g. "(A) banana, (B) apple". 36 | indexed_answer_outkey: output key for the answer's letter after shuffling, 37 | e.g. 1 (original) -> 2 (shuffled) -> 'B'. 38 | 39 | Returns: 40 | A pp function that writes the indexed choices/answer into the data dict. """ 41 | def _template(data): 42 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 43 | abc_tensor = tf.constant([f'({a})' for a in alphabet]) 44 | abcans_tensor = tf.constant([f'{a}' for a in alphabet]) 45 | choices = data[choice_str_inkey] 46 | indices = tf.range(len(choices)) 47 | # Shuffle the indices 48 | shuffled_indices = tf.random.shuffle(indices) 49 | # Use the shuffled indices to shuffle the tensor 50 | shuffled_tensor = tf.gather(choices, shuffled_indices) 51 | 52 | abc_tensor = tf.gather(abc_tensor, indices) 53 | 54 | data[indexed_choices_outkey] = tf.strings.reduce_join( 55 | tf.strings.join([abc_tensor, shuffled_tensor], separator=' '), 56 | separator=', ', 57 | ) 58 | 59 | answer_tensor = data[ans_inkey] 60 | new_ans_indice = tf.where(tf.equal(shuffled_indices, answer_tensor)) 61 | new_ans_indice = tf.gather(abcans_tensor, new_ans_indice) 62 | data[indexed_answer_outkey] = tf.strings.reduce_join(new_ans_indice) 63 | return data 64 | 65 | return _template 66 | -------------------------------------------------------------------------------- /big_vision/pp/proj/paligemma/video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessing for videos.""" 16 | 17 | from big_vision.pp import utils 18 | from big_vision.pp.registry import Registry 19 | 20 | import tensorflow as tf 21 | 22 | 23 | @Registry.register('preprocess_ops.video_decode') 24 | def video_decode(res): 25 | """Preprocessing.""" 26 | 27 | def _pp_per_image(img): 28 | # decode 29 | return tf.image.resize(tf.io.decode_jpeg(img), (res, res)) 30 | 31 | def _pp(data): 32 | images = data['episodic_images'] 33 | # resize 34 | images = tf.map_fn(_pp_per_image, images, fn_output_signature=tf.float32) 35 | # rescale 36 | images = 2 * (images / 255.) - 1.0 37 | data['image'] = images 38 | return data 39 | 40 | return _pp 41 | 42 | 43 | @Registry.register('preprocess_ops.video_ensure_shape') 44 | def video_ensure_shape(key, shape): 45 | """Preprocessing.""" 46 | def _video_ensure_shape(data): 47 | data[key] = tf.ensure_shape(data[key], shape) 48 | return data 49 | 50 | return _video_ensure_shape 51 | 52 | 53 | @Registry.register('preprocess_ops.video_replicate_img') 54 | def video_replicate_img(replicas, num_frames): 55 | """Ensure that for short videos, we have the correct number of frames. 56 | 57 | We replicate and select. 58 | 59 | Args: 60 | replicas: num_replicas before selection. Should be less than num_frames.
61 | num_frames: number of frames 62 | 63 | Returns: 64 | _replicate_img: preprocessing function 65 | """ 66 | 67 | def _replicate_img(data): 68 | # visual analogies + query 69 | image = data['image'] 70 | image = tf.tile(image, [replicas, 1, 1, 1]) 71 | data['image'] = image[:num_frames] 72 | return data 73 | 74 | return _replicate_img 75 | 76 | 77 | @Registry.register('preprocess_ops.video_choice') 78 | @utils.InKeyOutKey() 79 | def video_choice(empty_fallback=None): 80 | """Randomly takes one entry out of a tensor after flattening.""" 81 | 82 | def _choice(x): 83 | x = tf.reshape(x, (-1,)) # Ensure it's a 1D array 84 | 85 | # Append the fallback value so we gracefully handle empty cases. 86 | x0 = tf.zeros(1, x.dtype) if empty_fallback is None else [empty_fallback] 87 | x = tf.concat([x, x0], axis=0) 88 | 89 | num_choices = tf.maximum(tf.shape(x)[0] - 1, 1) # Don't sample x0. 90 | return x[tf.random.uniform([], 0, num_choices, dtype=tf.int32)] 91 | 92 | return _choice 93 | 94 | 95 | @Registry.register('preprocess_ops.stack_images') 96 | def stack_images(inkeys=(), outkey='image'): 97 | 98 | def _pp(data): 99 | images = tf.stack([data[inkey] for inkey in inkeys]) 100 | data[outkey] = images 101 | return data 102 | 103 | return _pp 104 | -------------------------------------------------------------------------------- /big_vision/pp/proj/paligemma/widgetcap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Widgetcap pp ops.""" 16 | 17 | from big_vision.pp.registry import Registry 18 | import tensorflow as tf 19 | 20 | 21 | @Registry.register("preprocess_ops.draw_bbox") 22 | def get_draw_bbox(image_key="image", bbox_key="bbox"): 23 | """Draw a single bounding box.""" 24 | 25 | def _draw_bbox(data): 26 | """Draw a single bounding box.""" 27 | image = tf.cast(data[image_key], tf.float32) 28 | image = tf.image.draw_bounding_boxes( 29 | tf.expand_dims(image, 0), 30 | tf.reshape(data[bbox_key], [1, 1, 4]), 31 | tf.constant([255, 0, 0], dtype=tf.float32, shape=[1, 3]), 32 | ) 33 | data[image_key] = tf.squeeze(image) 34 | return data 35 | 36 | return _draw_bbox 37 | -------------------------------------------------------------------------------- /big_vision/pp/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
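For reference, a sketch of calling the `draw_bbox` op from `widgetcap.py` above directly (the data values are made up; note that `tf.image.draw_bounding_boxes` expects coordinates normalized to [0, 1] in `[y_min, x_min, y_max, x_max]` order):

```python
import tensorflow as tf

draw = get_draw_bbox()  # Defaults: image_key="image", bbox_key="bbox".
data = {
    "image": tf.zeros([224, 224, 3], tf.uint8),
    "bbox": tf.constant([0.25, 0.25, 0.75, 0.75], tf.float32),
}
out = draw(data)  # Red box drawn into out["image"]; shape (224, 224, 3), float32.
```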
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """The tokenizer API for big_vision, and central registration place.""" 16 | import functools 17 | import importlib 18 | from typing import Protocol 19 | 20 | from absl import logging 21 | from big_vision.pp import registry 22 | import big_vision.utils as u 23 | import numpy as np 24 | 25 | 26 | class Tokenizer(Protocol): 27 | """Just to unify on the API as we now have many different ones.""" 28 | 29 | def to_int(self, text, *, bos=False, eos=False): 30 | """Tokenizes `text` into a list of integer tokens. 31 | 32 | Args: 33 | text: can be a single string, or a list of strings. 34 | bos: Whether a beginning-of-sentence token should be prepended. 35 | eos: Whether an end-of-sentence token should be appended. 36 | 37 | Returns: 38 | List or list-of-list of tokens. 39 | """ 40 | 41 | def to_int_tf_op(self, text, *, bos=False, eos=False): 42 | """Same as `to_int()`, but as TF ops to be used in pp.""" 43 | 44 | def to_str(self, tokens, *, stop_at_eos=True): 45 | """Inverse of `to_int()`. 46 | 47 | Args: 48 | tokens: list of tokens, or list of lists of tokens. 49 | stop_at_eos: remove everything that may come after the first EOS. 50 | 51 | Returns: 52 | A string (if `tokens` is a list of tokens), or a list of strings. 53 | Note that most tokenizers strip a select few control tokens like 54 | eos/bos/pad/unk from the output string. 55 | """ 56 | 57 | def to_str_tf_op(self, tokens, *, stop_at_eos=True): 58 | """Same as `to_str()`, but as TF ops to be used in pp.""" 59 | 60 | @property 61 | def pad_token(self): 62 | """Token id of padding token.""" 63 | 64 | @property 65 | def eos_token(self): 66 | """Token id of end-of-sentence token.""" 67 | 68 | @property 69 | def bos_token(self): 70 | """Token id of beginning-of-sentence token.""" 71 | 72 | @property 73 | def vocab_size(self): 74 | """Returns the size of the vocabulary.""" 75 | 76 | 77 | @functools.cache 78 | def get_tokenizer(name): 79 | with u.chrono.log_timing(f"z/secs/tokenizer/{name}"): 80 | if not registry.Registry.knows(f"tokenizers.{name}"): 81 | raw_name, *_ = registry.parse_name(name) 82 | logging.info("Tokenizer %s not registered, " 83 | "trying import big_vision.pp.%s", name, raw_name) 84 | importlib.import_module(f"big_vision.pp.{raw_name}") 85 | 86 | return registry.Registry.lookup(f"tokenizers.{name}")() 87 | 88 | 89 | def get_extra_tokens(tokensets): 90 | extra_tokens = [] 91 | for tokenset in tokensets: 92 | extra_tokens.extend(registry.Registry.lookup(f"tokensets.{tokenset}")()) 93 | return list(np.unique(extra_tokens)) # Preserves order. Dups make no sense. 94 | 95 | 96 | @registry.Registry.register("tokensets.loc") 97 | def _get_loc1024(n=1024): 98 | return [f"<loc{i:04d}>" for i in range(n)] 99 | 100 | 101 | @registry.Registry.register("tokensets.seg") 102 | def _get_seg(n=128): 103 | return [f"<seg{i:03d}>" for i in range(n)] 104 | -------------------------------------------------------------------------------- /big_vision/pp/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
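A usage sketch for the tokenizer API above (the tokenizer name is hypothetical; `get_tokenizer` falls back to importing `big_vision.pp.<name>` so registration can happen on first use):

```python
from big_vision.pp import tokenizer

tok = tokenizer.get_tokenizer("my_tokenizer")  # Hypothetical registered name.
ids = tok.to_int(["a photo of a cat"], bos=True, eos=True)
txt = tok.to_str(ids)  # Inverse mapping; strips control tokens like eos/pad.

# Extra control-token sets, e.g. the 1024 location tokens <loc0000>..<loc1023>:
loc_tokens = tokenizer.get_extra_tokens(["loc"])
```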
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessing utils.""" 16 | 17 | from collections import abc 18 | 19 | 20 | def maybe_repeat(arg, n_reps): 21 | if not isinstance(arg, abc.Sequence) or isinstance(arg, str): 22 | arg = (arg,) * n_reps 23 | return arg 24 | 25 | 26 | class InKeyOutKey(object): 27 | """Decorator for preprocessing ops, which adds `inkey` and `outkey` arguments. 28 | 29 | Note: Only supports single-input single-output ops. 30 | """ 31 | 32 | def __init__(self, indefault="image", outdefault="image", with_data=False): 33 | self.indefault = indefault 34 | self.outdefault = outdefault 35 | self.with_data = with_data 36 | 37 | def __call__(self, orig_get_pp_fn): 38 | 39 | def get_ikok_pp_fn(*args, key=None, 40 | inkey=self.indefault, outkey=self.outdefault, **kw): 41 | 42 | orig_pp_fn = orig_get_pp_fn(*args, **kw) 43 | def _ikok_pp_fn(data): 44 | # Optionally allow the function to get the full data dict as aux input. 45 | if self.with_data: 46 | data[key or outkey] = orig_pp_fn(data[key or inkey], data=data) 47 | else: 48 | data[key or outkey] = orig_pp_fn(data[key or inkey]) 49 | return data 50 | 51 | return _ikok_pp_fn 52 | 53 | return get_ikok_pp_fn 54 | -------------------------------------------------------------------------------- /big_vision/pp/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
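Putting `InKeyOutKey` together with the registry, a sketch of how a new single-input/single-output pp op is typically defined (the op name `scale_brightness` is made up for illustration):

```python
from big_vision.pp import registry
from big_vision.pp import utils
import tensorflow as tf


@registry.Registry.register("preprocess_ops.scale_brightness")  # Hypothetical op.
@utils.InKeyOutKey()  # Adds inkey=/outkey= kwargs, both defaulting to "image".
def get_scale_brightness(factor=1.0):
  def _scale_brightness(image):
    return tf.cast(image, tf.float32) * factor
  return _scale_brightness

# Then usable in a pipeline string, e.g.
# "decode|scale_brightness(0.5)|value_range(-1, 1)|keep('image')".
```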
14 | 15 | """Tests for preprocessing utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from big_vision.pp import utils 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class UtilsTest(tf.test.TestCase): 26 | 27 | def test_maybe_repeat(self): 28 | self.assertEqual((1, 1, 1), utils.maybe_repeat(1, 3)) 29 | self.assertEqual((1, 2), utils.maybe_repeat((1, 2), 2)) 30 | self.assertEqual([1, 2], utils.maybe_repeat([1, 2], 2)) 31 | 32 | def test_inkeyoutkey(self): 33 | @utils.InKeyOutKey() 34 | def get_pp_fn(shift, scale=0): 35 | def _pp_fn(x): 36 | return scale * x + shift 37 | return _pp_fn 38 | 39 | data = {"k_in": 2, "other": 3} 40 | ppfn = get_pp_fn(1, 2, inkey="k_in", outkey="k_out") # pylint: disable=unexpected-keyword-arg 41 | self.assertEqual({"k_in": 2, "k_out": 5, "other": 3}, ppfn(data)) 42 | 43 | data = {"k": 6, "other": 3} 44 | ppfn = get_pp_fn(1, inkey="k", outkey="k") # pylint: disable=unexpected-keyword-arg 45 | self.assertEqual({"k": 1, "other": 3}, ppfn(data)) 46 | 47 | data = {"other": 6, "image": 3} 48 | ppfn = get_pp_fn(5, 2) 49 | self.assertEqual({"other": 6, "image": 11}, ppfn(data)) 50 | 51 | 52 | if __name__ == "__main__": 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /big_vision/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.26 2 | absl-py 3 | git+https://github.com/google/CommonLoopUtils 4 | distrax 5 | editdistance 6 | einops 7 | flax 8 | optax 9 | git+https://github.com/google/flaxformer 10 | git+https://github.com/akolesnikoff/panopticapi.git@mute 11 | overrides 12 | protobuf 13 | sentencepiece 14 | tensorflow-cpu 15 | tfds-nightly 16 | tensorflow-text 17 | tensorflow-gan 18 | psutil 19 | pycocoevalcap 20 | -------------------------------------------------------------------------------- /big_vision/run_tpu.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | 17 | if [ ! -d "bv_venv" ] 18 | then 19 | sudo apt-get update 20 | sudo apt install -y python3-venv 21 | python3 -m venv bv_venv 22 | . bv_venv/bin/activate 23 | 24 | pip install -U pip # Yes, really needed. 25 | # NOTE: doesn't work when in requirements.txt -> cyclic dep 26 | pip install "jax[tpu]>=0.4.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 27 | pip install -r big_vision/requirements.txt 28 | else 29 | . 
bv_venv/bin/activate 30 | fi 31 | 32 | if [ $# -ne 0 ] 33 | then 34 | env TFDS_DATA_DIR=$TFDS_DATA_DIR BV_JAX_INIT=1 python3 -m "$@" 35 | fi 36 | -------------------------------------------------------------------------------- /big_vision/tools/download_tfds_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Download and prepare TFDS datasets for the big_vision codebase. 16 | 17 | This python script covers cifar10, cifar100, oxford_iiit_pet 18 | and oxford_flowers102. 19 | 20 | If you want to integrate other public or custom datasets, please follow: 21 | https://www.tensorflow.org/datasets/catalog/overview 22 | """ 23 | 24 | from absl import app 25 | import tensorflow_datasets as tfds 26 | 27 | 28 | def main(argv): 29 | if len(argv) > 1 and "download_tfds_datasets.py" in argv[0]: 30 | datasets = argv[1:] 31 | else: 32 | datasets = [ 33 | "cifar10", 34 | "cifar100", 35 | "oxford_iiit_pet", 36 | "oxford_flowers102", 37 | "imagenet_v2", 38 | ] 39 | for d in datasets: 40 | tfds.load(name=d, download=True) 41 | 42 | 43 | if __name__ == "__main__": 44 | app.run(main) 45 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/README.md: -------------------------------------------------------------------------------- 1 | # LiT-Demo 2 | 3 | See https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html 4 | 5 | Demo originally appeared on Twitter 6 | https://twitter.com/AndreasPSteiner/status/1514722383818543106 7 | 8 | App published at 9 | https://google-research.github.io/vision_transformer/lit 10 | 11 | ## Build 12 | 13 | Install packages (tested with node v16.17.0 and yarn 1.22.19) 14 | 15 | ```bash 16 | yarn 17 | ``` 18 | 19 | 20 | ## Run 21 | 22 | The web app will appear on http://localhost:8000 23 | 24 | ``` 25 | node build.js 26 | ``` 27 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/build.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
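The TFDS preparation above can equally be driven from Python; a small sketch (the data dir is a placeholder; in practice it should match the `TFDS_DATA_DIR` used by `run_tpu.sh`):

```python
import tensorflow_datasets as tfds

# Equivalent to: python3 -m big_vision.tools.download_tfds_datasets cifar10
tfds.load(name="cifar10", download=True, data_dir="/tmp/tfds")  # Placeholder dir.
```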
16 | */ 17 | 18 | const sassPlugin = require('esbuild-sass-plugin').sassPlugin; 19 | 20 | require('esbuild').serve({ 21 | servedir: 'src', 22 | port: 8000, 23 | }, { 24 | entryPoints: ['src/app.ts'], 25 | bundle: true, 26 | outfile: 'src/index.js', 27 | plugins: [ 28 | sassPlugin({ 29 | filter: /style.scss$/, 30 | type: 'style' 31 | }), 32 | sassPlugin({ 33 | type: 'lit-css', 34 | }), 35 | ], 36 | sourcemap: true, 37 | }).then(() => { 38 | console.log('Serving on port 8000'); 39 | }).catch(() => process.exit(1)); 40 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lit-demo", 3 | "version": "0.0.2", 4 | "description": "", 5 | "main": "src/app.ts", 6 | "license": "Apache-2.0", 7 | "private": true, 8 | "engines": { 9 | "node": ">=8.9.0" 10 | }, 11 | "scripts": { 12 | "serve": "node build.js", 13 | "test": "ts-node --skip-ignore --project tsconfig.test.json run_tests.ts" 14 | }, 15 | "devDependencies": { 16 | "@babel/core": "^7.7.5", 17 | "@babel/plugin-transform-runtime": "^7.7.6", 18 | "@babel/polyfill": "^7.10.4", 19 | "@babel/preset-env": "^7.7.6", 20 | "@tensorflow/tfjs-backend-cpu": "^3.15.0", 21 | "@tensorflow/tfjs-backend-webgl": "^3.15.0", 22 | "@tensorflow/tfjs-converter": "3.20.0", 23 | "@tensorflow/tfjs-core": "3.20.0", 24 | "babel-preset-env": "^1.7.0", 25 | "esbuild": "^0.15.5", 26 | "esbuild-sass-plugin": "^2.3.2", 27 | "jasmine": "^3.3.1", 28 | "lit": "^2.3.1", 29 | "naughty-words": "^1.2.0", 30 | "sass": "^1.50.0", 31 | "ts-node": "~5.0.0", 32 | "typescript": "4.1.3" 33 | }, 34 | "resolutions": { 35 | "is-svg": "4.3.1" 36 | }, 37 | "eslintConfig": { 38 | "extends": "google", 39 | "rules": { 40 | "require-jsdoc": 0, 41 | "valid-jsdoc": 0 42 | }, 43 | "env": { 44 | "es6": true 45 | }, 46 | "parserOptions": { 47 | "ecmaVersion": 8, 48 | "sourceType": "module" 49 | } 50 | }, 51 | "eslintIgnore": [ 52 | "dist/" 53 | ] 54 | } 55 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/app.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | import {LitDemoApp} from './components/lit-demo-app'; 19 | import './style.scss'; 20 | 21 | // tslint:disable-next-line:no-any 22 | (window as any).LitDemoApp = LitDemoApp; 23 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/image-carousel.scss: -------------------------------------------------------------------------------- 1 | @import '../style/mixins'; 2 | 3 | .selector { 4 | overflow: scroll; 5 | padding-bottom: 10px; // OS X scroll bar 6 | 7 | .inner { 8 | white-space: nowrap; 9 | 10 | .thumb { 11 | display: inline-block; 12 | 13 | img { 14 | cursor: pointer; 15 | 16 | width: 20vmin; 17 | height: 20vmin; 18 | max-width: 200px; 19 | max-height: 200px; 20 | 21 | @include phone-portrait { 22 | width: 33vmin; 23 | height: 33vmin; 24 | } 25 | 26 | margin: 10px; 27 | 28 | box-shadow: 0 0 10px #888; 29 | } 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/image-carousel.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Carousel of images. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {app} from '../lit_demo/app'; 25 | import {getImageUrl} from '../lit_demo/constants'; 26 | import {ImageRow} from '../lit_demo/data'; 27 | 28 | import {customElement} from 'lit/decorators.js'; 29 | import styles from './image-carousel.scss'; 30 | 31 | /** 32 | * Shows multiple images in a horizontal carousel. 33 | * 34 | * Dispatches `'image-select'` event when an image is clicked/tapped. 35 | */ 36 | @customElement('image-carousel') 37 | export class ImageCarousel extends LitElement { 38 | static override styles = [styles]; 39 | 40 | onClick(id: string) { 41 | const event = 42 | new CustomEvent('image-select', {composed: true, detail: {id}}); 43 | this.dispatchEvent(event); 44 | } 45 | 46 | override render() { 47 | const images = app.imageData.rows.map( 48 | (row: ImageRow) => html` 49 |
50 |           <div class="thumb"><img @click=${() => {
51 |             this.onClick(row.id);
52 |           }} data-id=${row.id} src="${getImageUrl(row.id)}">
53 |           </div>
54 |         `);
55 |     return html`
56 |       <div class="selector">
57 |         <div class="inner">
58 |           ${images}
59 |         </div>
60 |       </div>
61 |       Select an image 👆 to get started.
62 | `; 63 | } 64 | } 65 | 66 | declare global { 67 | interface HTMLElementTagNameMap { 68 | 'image-carousel': ImageCarousel; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/image-prompts.scss: -------------------------------------------------------------------------------- 1 | @import '../style/mixins'; 2 | 3 | .image-prompt { 4 | display: flex; 5 | gap: 1.5em; 6 | align-items: flex-start; 7 | margin-top: 2rem; 8 | 9 | @include phone-portrait { 10 | align-items: center; 11 | flex-direction: column; 12 | gap: 0; 13 | margin-bottom: 5rem; 14 | } 15 | 16 | .left { 17 | display: flex; 18 | flex-direction: column; 19 | 20 | .wrapper { 21 | position: relative; 22 | 23 | .src { 24 | position: absolute; 25 | right: 2rem; 26 | bottom: 2rem; 27 | color: white; 28 | font-size: 1.5rem; 29 | text-shadow: 2px 2px black; 30 | text-decoration: none; 31 | } 32 | } 33 | 34 | .animation { 35 | position: relative; 36 | width: 224px; 37 | height: 15px; 38 | opacity: 0; 39 | 40 | .computing { 41 | text-align: center; 42 | } 43 | } 44 | } 45 | 46 | .right { 47 | display: flex; 48 | flex-grow: 1; 49 | flex-direction: column; 50 | gap: 0.5em; 51 | 52 | .top { 53 | text-align: right; 54 | height: 30px; 55 | } 56 | 57 | .buttons { 58 | display: flex; 59 | flex-wrap: wrap; 60 | justify-content: flex-end; 61 | gap: 1em; 62 | align-items: center; 63 | } 64 | 65 | .item { 66 | position: relative; 67 | display: flex; 68 | 69 | .pct { 70 | display: inline-block; 71 | margin-right: 1em; 72 | width: 3.5em; 73 | text-align: right; 74 | opacity: 0; 75 | transition: opacity 0.5s; 76 | } 77 | 78 | input { 79 | flex-grow: 1; 80 | max-width: 70vw; 81 | border-radius: 0; 82 | background: transparent; 83 | border: 0; 84 | border-bottom: 1px solid var(--text-fg); 85 | color: var(--text-fg); 86 | outline: none; 87 | 88 | &.toolong { 89 | border-bottom: 1px solid var(--text-red); 90 | color: var(--text-red); 91 | } 92 | } 93 | 94 | .bar { 95 | position: absolute; 96 | display: inline-block; 97 | top: 5%; 98 | left: 0; 99 | z-index: -1; 100 | background: var(--bar-col); 101 | height: 90%; 102 | width: 0; 103 | transition: width 0.5s; 104 | } 105 | } 106 | 107 | .bottom { 108 | display: flex; 109 | flex-wrap: wrap; 110 | justify-content: flex-end; 111 | gap: 1em; 112 | align-items: center; 113 | opacity: 0; 114 | 115 | .tweet { 116 | background: rgb(18, 150, 223); 117 | color: white; 118 | text-decoration: none; 119 | padding: 0px 15px; 120 | border-radius: 16px; 121 | } 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/lit-demo-app.scss: -------------------------------------------------------------------------------- 1 | .loading-container { 2 | text-align: center; 3 | } 4 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/loading-animation.scss: -------------------------------------------------------------------------------- 1 | // CC0 from https://loading.io/css/ 2 | 3 | @import '../style/colors'; 4 | 5 | .lds-ellipsis { 6 | display: inline-block; 7 | position: relative; 8 | width: 80px; 9 | height: 80px; 10 | 11 | div { 12 | position: absolute; 13 | top: 33px; 14 | width: 13px; 15 | height: 13px; 16 | border-radius: 50%; 17 | background: var(--text-fg); 18 | animation-timing-function: cubic-bezier(0, 1, 1, 0); 19 | } 20 | 21 | div:nth-child(1) { 22 | left: 8px; 23 | 
animation: lds-ellipsis1 0.6s infinite; 24 | } 25 | 26 | div:nth-child(2) { 27 | left: 8px; 28 | animation: lds-ellipsis2 0.6s infinite; 29 | } 30 | 31 | div:nth-child(3) { 32 | left: 32px; 33 | animation: lds-ellipsis2 0.6s infinite; 34 | } 35 | 36 | div:nth-child(4) { 37 | left: 56px; 38 | animation: lds-ellipsis3 0.6s infinite; 39 | } 40 | } 41 | 42 | @keyframes lds-ellipsis1 { 43 | 0% { 44 | transform: scale(0); 45 | } 46 | 100% { 47 | transform: scale(1); 48 | } 49 | } 50 | @keyframes lds-ellipsis3 { 51 | 0% { 52 | transform: scale(1); 53 | } 54 | 100% { 55 | transform: scale(0); 56 | } 57 | } 58 | @keyframes lds-ellipsis2 { 59 | 0% { 60 | transform: translate(0, 0); 61 | } 62 | 100% { 63 | transform: translate(24px, 0); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/loading-animation.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Loading animation. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {customElement} from 'lit/decorators.js'; 25 | import styles from './loading-animation.scss'; 26 | 27 | /** 28 | * Shows an animated loading indicator. 29 | */ 30 | @customElement('loading-animation') 31 | export class LoadingAnimation extends LitElement { 32 | 33 | static override styles = [styles]; 34 | 35 | override render() { 36 | return html` 37 |       <div class="lds-ellipsis">
38 |         <div></div>
39 |         <div></div>
40 |         <div></div>
41 |         <div></div>
42 |       </div>
43 | `; 44 | } 45 | } 46 | 47 | declare global { 48 | interface HTMLElementTagNameMap { 49 | 'loading-animation': LoadingAnimation; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/message-list.scss: -------------------------------------------------------------------------------- 1 | @import '../style/colors'; 2 | 3 | .message { 4 | padding: 0.1rem 0.5rem; 5 | margin-bottom: 1rem; 6 | } 7 | 8 | .warning { 9 | background: var(--warn-bg); 10 | color: var(--warn-fg); 11 | } 12 | 13 | .error { 14 | background: var(--error-bg); 15 | color: var(--error-fg); 16 | } 17 | 18 | .info { 19 | background: var(--note-bg); 20 | color: var(--note-fg); 21 | } 22 | 23 | .close { 24 | float: right; 25 | cursor: pointer; 26 | } 27 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/message-list.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview A list of dismissable info/warning/error messages. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {unsafeHTML} from 'lit/directives/unsafe-html.js'; 25 | 26 | import {customElement} from 'lit/decorators.js'; 27 | import styles from './message-list.scss'; 28 | 29 | enum MessageType { 30 | INFO = 'info', 31 | WARNING = 'warning', 32 | ERROR = 'error', 33 | } 34 | 35 | interface Message { 36 | message: string; 37 | type: MessageType; 38 | rawHtml: boolean; 39 | } 40 | 41 | 42 | /** 43 | * Shows info/warning/error messages that remain until closed by user. 44 | */ 45 | @customElement('message-list') 46 | export class MessageList extends LitElement { 47 | static override styles = [styles]; 48 | 49 | messages: Message[] = []; 50 | 51 | addMessage(message: Message) { 52 | this.messages.push(message); 53 | this.requestUpdate(); 54 | } 55 | 56 | info(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 57 | this.addMessage({message, type: MessageType.INFO, rawHtml}); 58 | } 59 | 60 | warning(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 61 | this.addMessage({message, type: MessageType.WARNING, rawHtml}); 62 | } 63 | 64 | error(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 65 | this.addMessage({message, type: MessageType.ERROR, rawHtml}); 66 | } 67 | 68 | removeMessage(event: Event, idx: number) { 69 | this.messages.splice(idx, 1); 70 | (event.target! as HTMLElement).closest('.message')!.remove(); 71 | } 72 | 73 | clear() { 74 | this.messages = []; 75 | while (this.firstChild) this.firstChild.remove(); 76 | } 77 | 78 | override render() { 79 | return this.messages.map( 80 | (message: Message, idx: number) => html` 81 |
<div class="message ${message.type}">
82 |             ${
83 |             message.rawHtml ? unsafeHTML(message.message) :
84 |                               message.message}
85 |             <span @click=${(e: Event) => {
86 |               this.removeMessage(e, idx);
87 |             }} class="close">✖</span>
88 |           </div>
89 | `); 90 | } 91 | } 92 | 93 | declare global { 94 | interface HTMLElementTagNameMap { 95 | 'message-list': MessageList; 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/model-controls.scss: -------------------------------------------------------------------------------- 1 | .controls { 2 | margin: 1em 0; 3 | display: flex; 4 | 5 | select { 6 | margin-left: 0.5em; 7 | } 8 | 9 | progress { 10 | margin: 0 1em; 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/components/model-controls.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Controls to choose model. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {getModels} from '../lit_demo/constants'; 25 | import {app} from '../lit_demo/app'; 26 | 27 | import {customElement, property} from 'lit/decorators.js'; 28 | import styles from './model-controls.scss'; 29 | 30 | /** 31 | * Shows controls for model selection, progress bar, and status text. 32 | */ 33 | @customElement('model-controls') 34 | export class ModelControls extends LitElement { 35 | 36 | static override styles = [styles]; 37 | 38 | @property({attribute: false}) 39 | progress: number = 0; 40 | 41 | @property({attribute: false}) 42 | status: string = 'Initializing...'; 43 | 44 | constructor() { 45 | super(); 46 | app.models.addListener(this.onModelUpdate.bind(this)); 47 | app.models.load(getModels()[0]); 48 | } 49 | 50 | onModelUpdate(progress: number, message?: string) { 51 | this.progress = progress; 52 | if (message) this.status = message; 53 | } 54 | 55 | onModelChange(event: Event) { 56 | const target = event.target as HTMLSelectElement; 57 | const name = target.value; 58 | app.models.load(name).catch((error) => { 59 | this.status = `ERROR loading model "${name}": ${error}`; 60 | }); 61 | } 62 | 63 | async setModel(model: string) { 64 | if (getModels().indexOf(model) === -1) { 65 | throw new Error(`Model "${model}" not found!`); 66 | } 67 | await this.updateComplete; 68 | const dropdown = this.shadowRoot!.querySelector('#model_dropdown') as HTMLSelectElement; 69 | dropdown.value = model; 70 | dropdown.dispatchEvent(new Event('change')); 71 | } 72 | 73 | override render() { 74 | const options = getModels().map((model: string) => 75 | html`<option value=${model}>${model}</option>`); 76 | return html` 77 |       <div class="controls">
78 |         <select id="model_dropdown" @change=${this.onModelChange}>
79 |           ${options}
80 |         </select>
81 |         <progress value=${this.progress}></progress>
82 |
83 |         <div class="status">${this.status}</div>
84 |       </div>
85 | `; 86 | } 87 | } 88 | 89 | declare global { 90 | interface HTMLElementTagNameMap { 91 | 'model-controls': ModelControls; 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/exports.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview some useful exports to play around with the models & 20 | * tokenizers. 21 | * 22 | * Simple usage (see ./playground.html for more complete usage example): 23 | * 24 | * model = lit.Model('tiny'); 25 | * model.load(progress => console.log('loading...', progress)); 26 | * console.log(model.computeProbabilities(['a dog', 'a cat'], '0')); 27 | */ 28 | 29 | import {Model} from './lit_demo/compute'; 30 | import {getImageUrl, setBaseUrl} from './lit_demo/constants'; 31 | import {ImageData} from './lit_demo/data'; 32 | import * as tf from '@tensorflow/tfjs-core'; 33 | 34 | // tslint:disable-next-line:no-any Export symbols into global namespace. 35 | (window as any).lit = { Model, getImageUrl, ImageData, setBaseUrl }; 36 | // tslint:disable-next-line:no-any Export symbols into global namespace. 37 | // tslint:disable-next-line:ban-module-namespace-object-escape Export all of TF. 38 | (window as any).tf = tf; 39 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/index.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | Lit Demo App 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |

LiT: Zero-Shot Transfer with Locked-image Tuning

31 | 32 |

33 | This page is an interactive demo of the Google AI blog post 34 | LiT: adding language understanding to image models 36 | – please refer to that page for a detailed explanation of how a LiT model works. 37 | If you're interested in how this demo makes a JAX model run on device in your 38 | browser, check out our other blog post 39 | JAX on the Web with TensorFlow.js. 41 |

42 | 43 |

44 | Below you can choose an image from a selection and then write free-form 45 | text prompts that are matched to the image. Once you hit return on your 46 | keyboard or press the "compute" button, a text encoder implemented in 47 | TensorFlow.js 48 | will compute embeddings for the provided text on your local device, and the 49 | similarity of these text embeddings to the image embedding will be displayed. 50 |

51 | 52 |

53 | The prompts can be used to classify an image into multiple categories, listing 54 | each category individually with a prompt "an image of a X". But you can also 55 | probe the model interactively with more detailed prompts, comparing the 56 | different results when small details change in the text. 57 |

58 | 59 |

60 | Please use this demo responsibly. The models will always compare the image to 61 | the prompts you provide, and it is therefore trivial to construct situations 62 | where the model picks from a bunch of bad options. 63 |

64 | 65 |

66 | Note: 67 | The models available in this interactive demo are not those from the 68 | paper. 70 | We had to train much smaller text towers and tokenizers to avoid 71 | overloading your browser. Please see 72 | our GitHub repository 74 | for the models from the paper pre-trained on public datasets. 75 | Multilingual models coming soon. 76 |

77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/lit_demo/app.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Global app state. 20 | */ 21 | 22 | import {ImageData} from './data'; 23 | import {Models} from './compute'; 24 | 25 | /** 26 | * Container class holding image data and models. 27 | * 28 | * The main application component would typically call `load()` and then show 29 | * the components depending on this class asynchronously. 30 | */ 31 | export class App { 32 | 33 | imageData = new ImageData(); 34 | models = new Models(); 35 | 36 | ready: boolean = false; 37 | 38 | async load() { 39 | await this.imageData.load(); 40 | this.ready = true; 41 | } 42 | } 43 | 44 | /** 45 | * Global app state. 46 | */ 47 | export const app = new App(); 48 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/lit_demo/constants.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Project-wide constants. 20 | */ 21 | 22 | // Can be overwritten with setBaseUrl() below. 23 | // let baseUrl = 'https://google-research.github.io/vision_transformer/lit'; 24 | let baseUrl = 'https://figur.li/jax2tfjs'; 25 | // Can be overwritten with setModels() below. 26 | let models = ['tiny', 'small']; 27 | 28 | /** Allows to set a new base URL, on which all other URLs are based. */ 29 | export const setBaseUrl = (newBaseUrl: string) => { 30 | baseUrl = newBaseUrl; 31 | }; 32 | 33 | /** Retrieves URL for a model-specific file (vocabulary, embeddings, ...). */ 34 | export const getModelFileUrl = (name: string, relativePath: string) => ( 35 | `${baseUrl}/data/models/${name}/${relativePath}` 36 | ); 37 | 38 | /** Retrieves the URL for images information JSON file. */ 39 | export const getImagesInfoUrl = () => `${baseUrl}/data/images/info.json`; 40 | 41 | /** Retrieves the URL for an image. */ 42 | export const getImageUrl = (id: string) => `${baseUrl}/data/images/${id}.jpg`; 43 | /** Returns names of available models.
*/ 45 | export const getModels = () => models; 46 | 47 | /** Sets names of available models. */ 48 | export const setModels = (newModels: string[]) => { 49 | models = newModels; 50 | }; 51 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/lit_demo/data.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Accessing additional data. 20 | */ 21 | 22 | import {getImagesInfoUrl} from './constants'; 23 | 24 | /** 25 | * Information about a single image. 26 | */ 27 | export interface ImageRow { 28 | /** Stable ID of the image. */ 29 | id: string; 30 | /** Set of example prompts for this image. */ 31 | prompts: string; 32 | /** License of the image. */ 33 | license: string; 34 | /** Where the image was originally downloaded from. */ 35 | source: string; 36 | /** Short description of image. */ 37 | description: string; 38 | } 39 | /** 40 | * Contains information about all images. 41 | */ 42 | export class ImageData { 43 | 44 | rows: ImageRow[] = []; 45 | /** Will be set to `true` when `load()` finishes. */ 46 | ready = false; 47 | 48 | /** 49 | * Gets an image by ID. Throws an error if image is not found, data is not 50 | * loaded, or ID is not unique. 51 | */ 52 | get(id: string): ImageRow { 53 | if (!this.ready) { 54 | throw new Error('ImageData not loaded!'); 55 | } 56 | const matching = this.rows.filter(row => row.id === id); 57 | if (matching.length !== 1) { 58 | throw new Error(`Got unexpected ${matching.length} matches for id="${id}"`); 59 | } 60 | return matching[0]; 61 | } 62 | 63 | /** 64 | * Loads image data asynchronously. 65 | */ 66 | async load() { 67 | this.rows = ( 68 | await fetch(getImagesInfoUrl()) 69 | .then(response => { 70 | console.log('response', response); 71 | return response.json(); 72 | }) 73 | ); 74 | this.ready = true; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/lit_demo/url_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview (De)serialize state from/to URL. 
20 | */ 21 | 22 | // Should be updated whenever URLs are not compatible anymore 23 | // (e.g. adding new images) 24 | export const VERSION = 'v2'; 25 | // version history: 26 | // v1 used row number instead of image id 27 | 28 | const V1_IMAGE_IDS = [ 29 | '1', '48', '43', '22', '2', '3', '4', '5', '6', '7', '8', '9', 30 | '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 31 | '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', 32 | '35', '36', '37', '38', '39', '40', '41', '42', '44', '45', '46', '47', 33 | '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60' 34 | ]; 35 | 36 | /** 37 | * State that can be stored in the URL. 38 | */ 39 | export interface State { 40 | /** Name of the model. */ 41 | modelName: string; 42 | /** ID of the image. */ 43 | imageId: string; 44 | /** List of text prompts. */ 45 | prompts: string[]; 46 | } 47 | 48 | /** 49 | * Returns a URL for provided model/image/prompts. 50 | */ 51 | export const getUrl = 52 | (modelName: string, imageId: string, prompts: string[]): string => { 53 | let href = window.location.href; 54 | if (href.indexOf('#') !== -1) { 55 | href = href.substring(0, href.indexOf('#')); 56 | } 57 | const parts = [ 58 | VERSION, 59 | modelName, 60 | imageId, 61 | ...prompts, 62 | ]; 63 | return href + '#' + parts.map(encodeURIComponent).join('|'); 64 | }; 65 | 66 | /** 67 | * Parses a URL and returns a `State`, or undefined if no state is specified. 68 | * 69 | * Throws an error if the URL cannot be parsed. 70 | */ 71 | export const parseUrl = (): State|undefined => { 72 | const hash = window.location.hash.substring(1); 73 | if (!hash) return; 74 | const parts = hash.split(/\|/g); 75 | if (parts.length < 4) { 76 | throw new Error(`Invalid URL: "${hash}"`); 77 | } 78 | let [version, modelName, imageId, ...texts] = parts; 79 | if (version === VERSION) { 80 | } else if (version === 'v1') { 81 | const idx = Number(imageId); 82 | if (isNaN(idx)) throw new Error(`Expected idx="${idx}" to be numerical!`); 83 | imageId = V1_IMAGE_IDS[idx]; 84 | } else { 85 | throw new Error(`Incompatible version: ${version} (supported: ${VERSION})`); 86 | } 87 | return { 88 | modelName, 89 | imageId, 90 | prompts: texts.map(decodeURIComponent), 91 | }; 92 | }; 93 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/playground.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 |

23 | A simple demonstration of how to use LiT models in a JS application using global exports. 24 | See the source code of this file for API usage.
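<!-- A short sketch of that API, mirroring the usage notes in exports.ts:
       model = lit.Model('tiny');
       model.load(progress => console.log('loading...', progress));
       console.log(model.computeProbabilities(['a dog', 'a cat'], '0'));
  -->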

26 | 27 |

28 |     
29 | 
30 | 
31 | 
32 | 33 | 82 | 83 | 93 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/style.scss: -------------------------------------------------------------------------------- 1 | // General styles for the page. 2 | 3 | @import './style/colors'; 4 | @import './style/mixins'; 5 | 6 | html { 7 | font-size: 14px; 8 | line-height: 1.6em; 9 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, 10 | Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, 11 | sans-serif; 12 | text-size-adjust: 100%; 13 | -ms-text-size-adjust: 100%; 14 | -webkit-text-size-adjust: 100%; 15 | 16 | @media (min-width: 1200px) { 17 | width: 1024px; 18 | margin: 0 auto; 19 | } 20 | @media (min-width: 768px) { 21 | font-size: 16px; 22 | } 23 | 24 | color: var(--text-fg); 25 | background: var(--text-bg); 26 | 27 | body { 28 | margin: 0; 29 | padding: 0rem 1rem 10rem; 30 | } 31 | } 32 | 33 | a, 34 | a:visited { 35 | color: var(--link-col); 36 | } 37 | 38 | h1 { 39 | font-weight: 700; 40 | font-size: 2rem; 41 | line-height: 1.3em; 42 | } 43 | 44 | p { 45 | font-size: 1.06rem; 46 | line-height: 1.3em; 47 | } 48 | 49 | input { 50 | font-size: 1rem; 51 | 52 | &::placeholder { 53 | color: var(--placeholder-col); 54 | } 55 | } 56 | 57 | .note { 58 | font-style: normal; 59 | border: none; 60 | border-radius: 2px; 61 | margin-left: auto; 62 | margin-right: auto; 63 | 64 | padding: 0.5rem 0.5rem 0.5rem 2rem; 65 | width: 90%; 66 | 67 | @include phone-portrait { 68 | width: 100%; 69 | padding: 0.5rem; 70 | box-sizing: border-box; 71 | } 72 | 73 | background-color: var(--note-bg); 74 | color: var(--note-fg); 75 | 76 | &.warning { 77 | background-color: var(--warn-bg); 78 | color: var(--warn-fg); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/style/colors.scss: -------------------------------------------------------------------------------- 1 | // Dark and light mode colors. 2 | 3 | :root { 4 | --text-bg: hsl(0, 0%, 97%); 5 | --gray-border: hsla(0, 0%, 0%, 0.1); 6 | --gray: rgba(0, 0, 0, 0.6); 7 | --border-radius: 5px; 8 | --orange: hsl(24, 100%, 50%); 9 | --distill-blue: hsl(200, 50%, 25%); 10 | --blue: #337699; 11 | --green: #3db867; 12 | --text-fg: rgb(15, 15, 15); 13 | --text-red: rgb(220, 0, 0); 14 | --bar-col: rgb(171, 199, 227); 15 | --link-col: rgb(0, 0, 238); 16 | --placeholder-col: rgb(166, 166, 166); 17 | --note-bg: #e1f5fe; 18 | --note-fg: #1a6ebb; 19 | --warn-bg: #ffe1aa; 20 | --warn-fg: #a16800; 21 | --error-bg: #850000; 22 | --error-fg: white; 23 | 24 | @media (prefers-color-scheme: dark) { 25 | --text-bg: rgb(56, 56, 56); 26 | --text-fg: rgb(213, 213, 213); 27 | --bar-col: rgb(20, 109, 163); 28 | --link-col: rgb(66, 165, 245); 29 | 30 | --note-fg: rgb(121 157 190); 31 | --note-bg: rgb(2 59 85); 32 | --warn-bg: #784e00; 33 | --warn-fg: #edbe68; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/style/mixins.scss: -------------------------------------------------------------------------------- 1 | // Useful mixins. 2 | 3 | // To wrap styles that should only trigger for phones in portrait mode. 
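// Usage example (as in image-prompts.scss):
//   .image-prompt { @include phone-portrait { flex-direction: column; } }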
4 | @mixin phone-portrait { 5 | @media only screen and (max-device-width: 800px) and (orientation: portrait) { 6 | @content; 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/common.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Utility code shared between tokenizers. 20 | */ 21 | 22 | /** 23 | * A vocabulary consists of a list of tokens, and an optional numerical value. 24 | * The numerical value is used by the unigram algorithm to find the best 25 | * tokenization, and is ignored by the BPE algorithm. 26 | */ 27 | export type Vocabulary = Array<[string, number]>; 28 | 29 | /** 30 | * Converts a string to a sequence of tokens. 31 | */ 32 | export interface Tokenizer { 33 | encode(input: string): number[]; 34 | } 35 | 36 | /** 37 | * Factory for new `Tokenizer`. 38 | */ 39 | export interface TokenizerConstructor { 40 | new (vocabulary: Vocabulary): Tokenizer; 41 | } 42 | 43 | /** 44 | * Unicode-aware character iteration of strings. 45 | */ 46 | export const stringToChars = (input: string): string[] => { 47 | const symbols = []; 48 | for (const symbol of input) { 49 | symbols.push(symbol); 50 | } 51 | return symbols; 52 | }; 53 | 54 | /** 55 | * Special separator character used to delimit sub-word tokens. 56 | */ 57 | export const TOKEN_SEPARATOR = 58 | '\u2581'; // This is the unicode character 'lower one eighth block'. 59 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Tokenizers and tokenizer mappings.
20 | */ 21 | 22 | import {Tokenizer, TokenizerConstructor, Vocabulary} from './common'; 23 | import * as sentencepieceBpe from './sentencepiece_bpe'; 24 | import * as sentencepieceUnigram from './sentencepiece_unigram'; 25 | 26 | export {Tokenizer, Vocabulary} from './common'; 27 | 28 | const TOKENIZERS = new Map([ 29 | ['BPE', sentencepieceBpe.Tokenizer], 30 | ['UNIGRAM', sentencepieceUnigram.Tokenizer], 31 | ]); 32 | 33 | /** 34 | * Returns a tokenizer of type `name` using `vocabulary`. 35 | */ 36 | export const getTokenizer = (name: string, vocabulary: Vocabulary): Tokenizer => { 37 | const ctor = TOKENIZERS.get(name); 38 | if (!ctor) throw new Error(`Unknown tokenizer: ${name}`); 39 | return new ctor(vocabulary); 40 | }; 41 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common'; 19 | 20 | interface Candidate { 21 | piece: string; 22 | pos: number; 23 | score: number; 24 | } 25 | 26 | const scoreDesc = (a: Candidate, b: Candidate) => b.score - a.score; 27 | 28 | function processInput(str: string): string { 29 | const normalized = str.normalize('NFKC'); 30 | return normalized.length > 0 ? 31 | TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) : 32 | normalized; 33 | } 34 | 35 | /** 36 | * Sentencepiece tokenizer implementing the BPE algorithm. 
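 *
 * Encoding greedily fuses the highest-scoring adjacent pair until no fusable
 * pair remains. E.g. with the vocabulary from sentencepiece_bpe_test.ts,
 * "a test" is first split into [▁, a, ▁, t, e, s, t], then fused via
 * te (-1), st (-2) and te+st -> test (-3), yielding the ids [0, 1, 0, 7].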
37 | */ 38 | export class Tokenizer implements TokenizerInterface { 39 | 40 | // piece -> [score, index] 41 | private readonly map: Map; 42 | 43 | constructor(vocabulary: Vocabulary) { 44 | this.map = new Map(); 45 | vocabulary.forEach(([piece, score], idx) => { 46 | if (this.map.has(piece)) { 47 | throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`); 48 | } 49 | this.map.set(piece, [score, idx]); 50 | }); 51 | } 52 | 53 | encode(input: string): number[] { 54 | const processed: string = processInput(input); 55 | let pieces: string[] = stringToChars(processed); 56 | 57 | while (true) { 58 | const candidates: Candidate[] = []; 59 | for (let i = 0; i < pieces.length - 1; i++) { 60 | const fused = pieces[i] + pieces[i + 1]; 61 | const el = this.map.get(fused); 62 | if (el) { 63 | candidates.push({ piece: fused, pos: i, score: el[0] }); 64 | } 65 | } 66 | if (candidates.length === 0) { 67 | break; 68 | } 69 | candidates.sort(scoreDesc); 70 | const best = candidates[0]; 71 | pieces = [ 72 | ...pieces.slice(0, best.pos), 73 | best.piece, 74 | ...pieces.slice(best.pos + 2) 75 | ]; 76 | } 77 | 78 | return pieces.map(piece => this.map.get(piece)![1]); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | import 'jasmine'; 19 | 20 | describe('sentencepiece bpe test', () => { 21 | it('computes a thing when asked', () => {}); 22 | }); 23 | 24 | import * as bpe from './sentencepiece_bpe'; 25 | import {TOKEN_SEPARATOR, Vocabulary} from './common'; 26 | 27 | const vocab: Vocabulary = [ 28 | [TOKEN_SEPARATOR, 0], // 0 29 | ['a', 0], // 1 30 | ['e', 0], // 2 31 | ['s', 0], // 3 32 | ['t', 0], // 4 33 | ['te', -1], // 5 34 | ['st', -2], // 6 35 | ['test', -3], // 7 36 | ['tes', -4], // 8 37 | ]; 38 | 39 | describe('BPE Tokenizer', () => { 40 | let tokenizer: bpe.Tokenizer; 41 | beforeAll(() => { 42 | tokenizer = new bpe.Tokenizer(vocab); 43 | }); 44 | 45 | it('should tokenize correctly', () => { 46 | expect(tokenizer.encode('a test')).toEqual([0, 1, 0, 7]); 47 | }); 48 | }); 49 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/sentencepiece_unigram_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | import {Tokenizer} from './sentencepiece_unigram'; 19 | 20 | const stubbedTokenizerVocab = [ 21 | ['�', 0], 22 | ['', 0], 23 | ['', 0], 24 | ['extra_token_id_1', 0], 25 | ['extra_token_id_2', 0], 26 | ['extra_token_id_3', 0], 27 | ['▁', -2], 28 | ['▁a', -1], 29 | ['▁ç', -2], 30 | ['a', -3], 31 | ['.', -1], 32 | ['▁I', -1], 33 | ['▁like', -1], 34 | ['▁it', -1], 35 | ['I', -2], 36 | ['like', -2], 37 | ['it', -2], 38 | ['l', -3], 39 | ['i', -3], 40 | ['k', -3], 41 | ['e', -3], 42 | ['i', -3], 43 | ['t', -3] 44 | ]; 45 | 46 | describe('Universal Sentence Encoder tokenizer', () => { 47 | let tokenizer: Tokenizer; 48 | beforeAll(() => { 49 | tokenizer = new Tokenizer(stubbedTokenizerVocab as Array<[string, number]>); 50 | }); 51 | 52 | it('basic usage', () => { 53 | expect(tokenizer.encode('Ilikeit.')).toEqual([11, 15, 16, 10]); 54 | }); 55 | 56 | it('handles whitespace', () => { 57 | expect(tokenizer.encode('I like it.')).toEqual([11, 12, 13, 10]); 58 | }); 59 | 60 | it('should normalize inputs', () => { 61 | expect(tokenizer.encode('ça')).toEqual(tokenizer.encode('c\u0327a')); 62 | }); 63 | 64 | it('should handle unknown inputs', () => { 65 | expect(() => tokenizer.encode('😹')).not.toThrow(); 66 | }); 67 | 68 | it('should treat consecutive unknown inputs as a single word', () => { 69 | expect(tokenizer.encode('a😹😹')).toEqual([7, 0]); 70 | }); 71 | }); 72 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tokenizers/trie.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | // Copied from 19 | // https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts 20 | 21 | import {stringToChars} from './common'; 22 | 23 | // [token, score, index] 24 | type OutputNode = [string[], number, number]; 25 | 26 | class TrieNode { 27 | parent: TrieNode|null; 28 | end: boolean; 29 | children: {[firstSymbol: string]: TrieNode}; 30 | word: OutputNode; 31 | 32 | constructor() { 33 | this.parent = null; 34 | this.children = {}; 35 | this.end = false; 36 | this.word = [[], 0, 0]; 37 | } 38 | } 39 | 40 | /** 41 | * Simple Trie datastructure. 42 | */ 43 | export class Trie { 44 | root: TrieNode; 45 | 46 | constructor() { 47 | this.root = new TrieNode(); 48 | } 49 | 50 | /** 51 | * Inserts a token into the trie. 
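 *
 * For example, insert('te', -1, 5) creates the node chain 't' -> 'te' and
 * marks the final node as a word end carrying score -1 and vocabulary
 * index 5, which commonPrefixSearch() later reports for matching prefixes.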
52 | */ 53 | insert(word: string, score: number, index: number) { 54 | let node = this.root; 55 | 56 | const symbols = stringToChars(word); 57 | 58 | for (let i = 0; i < symbols.length; i++) { 59 | if (!node.children[symbols[i]]) { 60 | node.children[symbols[i]] = new TrieNode(); 61 | node.children[symbols[i]].parent = node; 62 | node.children[symbols[i]].word[0] = node.word[0].concat(symbols[i]); 63 | } 64 | 65 | node = node.children[symbols[i]]; 66 | if (i === symbols.length - 1) { 67 | node.end = true; 68 | node.word[1] = score; 69 | node.word[2] = index; 70 | } 71 | } 72 | } 73 | 74 | /** 75 | * Returns an array of all tokens starting with ss. 76 | * 77 | * @param ss The prefix to match on. 78 | */ 79 | commonPrefixSearch(ss: string[]): OutputNode[] { 80 | const output: OutputNode[] = []; 81 | let node = this.root.children[ss[0]]; 82 | 83 | for (let i = 0; i < ss.length && node; i++) { 84 | if (node.end) { 85 | output.push(node.word); 86 | } 87 | node = node.children[ss[i + 1]]; 88 | } 89 | 90 | if (!output.length) { 91 | output.push([[ss[0]], 0, 0]); 92 | } 93 | 94 | return output; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /big_vision/tools/lit_demo/src/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "dist", 4 | "target": "es6", 5 | "module": "commonjs", 6 | "lib": ["dom", "DOM.Iterable", "es2019", "es2020.string"], 7 | "types": ["node", "jasmine", "resize-observer-browser"], 8 | "moduleResolution": "node", 9 | "allowJs": false, 10 | "pretty": true, 11 | "resolveJsonModule": true, 12 | "sourceMap": false, 13 | "skipLibCheck": true, 14 | "removeComments": true, 15 | "esModuleInterop": true, 16 | "importsNotUsedAsValues": "preserve", 17 | "downlevelIteration": true, 18 | "skipDefaultLibCheck": true, 19 | "preserveConstEnums": false, 20 | "experimentalDecorators": true, 21 | "emitDecoratorMetadata": true, 22 | "noErrorTruncation": false, 23 | "noEmitOnError": false, 24 | "declaration": false, 25 | "stripInternal": true, 26 | "inlineSourceMap": true, 27 | "inlineSources": true, 28 | "importHelpers": true, 29 | "allowUnreachableCode": false, 30 | "noFallthroughCasesInSwitch": true, 31 | "noImplicitAny": true, 32 | "noImplicitReturns": false, 33 | "noImplicitThis": true, 34 | "strictBindCallApply": true, 35 | "strictFunctionTypes": true, 36 | "strictNullChecks": false, 37 | "strictPropertyInitialization": false 38 | }, 39 | "include": ["./client", "./examples"], 40 | "compileOnSave": false 41 | } 42 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/flexi/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Few common utils used in both/all flexi-trainers.""" 16 | import functools 17 | import itertools 18 | import numpy as np 19 | 20 | 21 | def mkrng(xid, wid, step): 22 | # Need to cap at 0, for example localruns use -1. 23 | rng_key = (max(xid, 0), max(wid, 0), max(step, 0)) 24 | return np.random.default_rng(rng_key) 25 | 26 | 27 | def mkprob(x): 28 | if x is None: 29 | return x 30 | return np.array(x) / np.sum(x) 31 | 32 | 33 | def choice(values, ratios, rng=None): 34 | rng = rng or np.random.default_rng() 35 | return rng.choice(values, p=mkprob(ratios)) 36 | 37 | 38 | def mkpredictfns(predict_fn, config, template="predict_{x}"): 39 | # If we have two flexi args a=[1,2], b=[10,20], then we create a 40 | # predict_fn for all possible combinations, named "predict_a=1_b=10" etc. 41 | all_combinations = [dict(comb) for comb in itertools.product( 42 | *[[(arg, val) for val in config[arg].v] for arg in config] 43 | )] 44 | return { 45 | template.format(x="_".join(f"{k}={v}" for k, v in kw.items())): 46 | functools.partial(predict_fn, **kw) 47 | for kw in all_combinations} 48 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/givt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for GIVT stage I and II trainers.""" 16 | 17 | from typing import Any 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | def unbin_depth( 24 | depth: jax.Array, 25 | *, 26 | min_depth: float, 27 | max_depth: float, 28 | num_bins: int, 29 | ) -> jax.Array: 30 | """Transform a depth map with binned values into a float-valued depth map. 31 | 32 | Args: 33 | depth: Depth map whose binned values are encoded in one-hot fashion along 34 | the last dimension. 35 | min_depth: Minimum binned depth value. 36 | max_depth: Maximum value of binned depth. 37 | num_bins: Number of depth bins. 38 | 39 | Returns: 40 | Float-valued depth map. 41 | """ 42 | depth = jnp.argmax(depth, axis=-1) 43 | depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation. 44 | depth /= num_bins 45 | return depth * (max_depth - min_depth) + min_depth 46 | 47 | 48 | def get_local_rng( 49 | seed: int | jax.Array, 50 | batch: Any, 51 | ) -> jax.Array: 52 | """Generate a per-image seed based on the image id or the image values. 53 | 54 | Args: 55 | seed: Random seed from which per-image seeds should be derived. 56 | batch: Pytree containing a batch of images (key "image") and optionally 57 | image ids (key "image/id"). 58 | 59 | Returns: 60 | Array containing per-image ids. 
61 | """ 62 | fake_id = None 63 | if "image" in batch: 64 | fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32) 65 | return jax.lax.scan( 66 | lambda k, x: (jax.random.fold_in(k, x), None), 67 | jax.random.PRNGKey(seed), 68 | batch.get("image/id", fake_id), 69 | )[0] 70 | 71 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/uvim/coco_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities to inspect coco data and predictions in notebooks.""" 16 | # pylint: disable=consider-using-from-import 17 | import functools 18 | import json 19 | 20 | import numpy as np 21 | from panopticapi import utils as pycoco_utils 22 | from skimage import segmentation 23 | 24 | import tensorflow.io.gfile as gfile 25 | 26 | 27 | import os 28 | ROOT = os.environ.get('COCO_DATA_DIR', '.') 29 | 30 | 31 | PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json' 32 | 33 | 34 | @functools.lru_cache(maxsize=None) 35 | def _coco_panoptic_categories(): 36 | with gfile.GFile(PANOPTIC_COCO_CATS_FILE, 'r') as f: 37 | categories_list = json.load(f) 38 | return tuple(categories_list) 39 | 40 | 41 | def rgb_panoptic_from_twochannels(twochannels, boundaries: bool = False): 42 | """Makes a RGB panoptic output and segments_info from a twochannels view.""" 43 | semantics = twochannels[..., 0] 44 | instances = twochannels[..., 1] 45 | max_instances = np.max(instances) + 1 46 | merged = semantics * max_instances + instances 47 | merged = np.where(semantics < 0, semantics, merged) 48 | 49 | categories_list = _coco_panoptic_categories() 50 | categories = {category['id']: category for category in categories_list} 51 | id_generator = pycoco_utils.IdGenerator(categories) 52 | segments_info = {} 53 | rgb = np.zeros((*instances.shape[:2], 3), dtype=np.uint8) 54 | 55 | for merged_id in np.unique(merged): 56 | if merged_id // max_instances > 0: 57 | category = categories_list[int(merged_id // max_instances) - 1] 58 | segment_id, color = id_generator.get_id_and_color(category['id']) 59 | else: 60 | category = {'id': -1, 'name': 'void', 'isthing': False} 61 | segment_id, color = -1, np.array([0, 0, 0]) 62 | segments_info[segment_id] = { 63 | 'id': segment_id, 64 | 'color': color, 65 | 'category_id': category['id'], 66 | 'name': category['name'], 67 | 'isthing': category['isthing'], 68 | } 69 | rgb[merged == merged_id] = color 70 | 71 | if boundaries: 72 | boundaries = segmentation.find_boundaries( 73 | pycoco_utils.rgb2id(rgb), mode='thick') 74 | rgb[boundaries] = 0 75 | return rgb, segments_info 76 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/uvim/colorization_task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Inputs, outputs and losses for colorization task.""" 16 | import einops 17 | import jax.numpy as jnp 18 | import numpy as np 19 | 20 | ONE_HOT_AXIS = -2 21 | 22 | 23 | def input_pp(batch, config): 24 | """Make inputs for colorization task.""" 25 | if "labels" not in batch: 26 | # During predict of phase2 there is no 'labels' field. 27 | x = None 28 | else: 29 | hp, wp = config.model.patch_size 30 | x = { 31 | "color": batch["labels"], 32 | } 33 | # Convert labels from (B, H, W) to (B, num_patches, C, patch_size) 34 | x["color"] = einops.rearrange( 35 | x["color"], "b (hn hp) (wn wp) c -> b (hn wn) c (hp wp)", hp=hp, wp=wp) 36 | ctx = batch.get("image_ctx", batch.get("image", None)) 37 | return {"ctx": ctx, "x": x} 38 | 39 | 40 | def loss_fn(logits, batch, config): 41 | """Compute loss for colorization task.""" 42 | labels = input_pp(batch, config)["x"] 43 | error = logits["color"] - labels["color"] 44 | loss = jnp.square(error) 45 | return loss, {"loss_color": loss} 46 | 47 | 48 | def predict_outputs(logits, config): 49 | """Make outputs for colorization task.""" 50 | # Map logits to (height, width, channels). 51 | hp, wp = config.model.patch_size 52 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) 53 | assert ONE_HOT_AXIS == -2, "Rearrange below depends on this." 54 | output = einops.rearrange( 55 | logits["color"], 56 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", 57 | hn=hn, 58 | wn=wn, 59 | hp=hp, 60 | wp=wp) 61 | output = jnp.clip(output, -1., 1.) 62 | return {"color": output} 63 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/uvim/depth_task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Inputs, outputs and losses for depth prediction task.""" 16 | import big_vision.utils as u 17 | import einops 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | 22 | 23 | ONE_HOT_AXIS = -2 24 | 25 | 26 | def input_pp(batch, config): 27 | """Makes inputs for depth prediction task.""" 28 | if "labels" not in batch: 29 | x = None 30 | else: 31 | hp, wp = config.model.patch_size 32 | depth = batch["labels"][..., 0] 33 | 34 | # Discretize to [0, ..., bins - 1]. 
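    # A worked example with hypothetical config values: for min_depth=0.0,
    # max_depth=10.0 and nbins=256, a raw depth of 3.7 is scaled to
    # (3.7 - 0.0) / 10.0 * 256 = 94.72, then floored and clipped to bin 94.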
35 | nbins = config.model.inputs.depth[ONE_HOT_AXIS] 36 | mind = config.min_depth 37 | maxd = config.max_depth 38 | depth = (depth - mind) / (maxd - mind) 39 | depth *= nbins 40 | depth = jnp.floor(depth).astype(jnp.int32) 41 | depth = jnp.minimum(depth, nbins - 1) 42 | depth = jnp.maximum(depth, 0) 43 | 44 | # Converts labels from (B, H, W, c) to (B, num_patches, c, patch_size). 45 | depth = jax.nn.one_hot( 46 | einops.rearrange( 47 | depth, "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp), 48 | num_classes=config.model.inputs.depth[ONE_HOT_AXIS], 49 | axis=ONE_HOT_AXIS) 50 | x = {"depth": depth} 51 | ctx = batch.get("image_ctx", batch.get("image", None)) 52 | return {"ctx": ctx, "x": x} 53 | 54 | 55 | def loss_fn(predictions, batch, config): 56 | """Computes loss for depth prediction task.""" 57 | labels = input_pp(batch, config)["x"] 58 | losses = {} 59 | loss = u.softmax_xent( 60 | logits=predictions["depth"], labels=labels["depth"], reduction=False, 61 | axis=ONE_HOT_AXIS) 62 | # Do not train on the closest class; usually regions of the image with 63 | # depth==0, which is the default for regions with no depth signal. 64 | # TODO: Encode depth==0 as class==-1. 65 | mask = jnp.argmax(labels["depth"], ONE_HOT_AXIS) != 0 66 | loss = loss * mask 67 | losses["loss_depth"] = loss 68 | return sum(losses.values()), losses 69 | 70 | 71 | def predict_outputs(predictions, config): 72 | """Makes outputs for depth prediction tasks.""" 73 | # Maps predictions to (height, width, channels). 74 | hp, wp = config.model.patch_size 75 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) 76 | depth = einops.rearrange( 77 | predictions["depth"], 78 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", 79 | hn=hn, wn=wn, hp=hp, wp=wp) 80 | 81 | depth = jnp.argmax(depth, axis=-1) # [B, H, W] 82 | 83 | # Revert discretization. 84 | nbins = config.model.inputs.depth[ONE_HOT_AXIS] 85 | mind = config.min_depth 86 | maxd = config.max_depth 87 | depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation. 88 | depth /= nbins 89 | depth = depth * (maxd - mind) + mind 90 | 91 | return {"depth": depth} 92 | -------------------------------------------------------------------------------- /big_vision/trainers/proj/uvim/panoptic_task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Inputs, outputs and losses for panoptic task.""" 16 | import big_vision.utils as u 17 | import einops 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | 22 | ONE_HOT_AXIS = -2 23 | 24 | 25 | def input_pp(batch, config): 26 | """Make inputs for panoptic segmentation task.""" 27 | if "labels" not in batch: 28 | # During predict of phase2 there is no 'labels' field.
29 | x = None 30 | else: 31 | hp, wp = config.model.patch_size 32 | x = { 33 | "semantics": batch["labels"][..., 0], 34 | "instances": batch["labels"][..., 1], 35 | } 36 | # Convert labels from (B, H, W) to (B, num_patches, num_classes, patch_size) 37 | for key in ["semantics", "instances"]: 38 | x[key] = jax.nn.one_hot( 39 | einops.rearrange( 40 | x[key], "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp), 41 | num_classes=config.model.inputs[key][ONE_HOT_AXIS], axis=ONE_HOT_AXIS) 42 | ctx = batch.get("image_ctx", batch.get("image", None)) 43 | return {"ctx": ctx, "x": x} 44 | 45 | 46 | def loss_fn(logits, batch, config): 47 | """Compute loss for panoptic task.""" 48 | labels = input_pp(batch, config)["x"] 49 | losses = {} 50 | for key in ["semantics", "instances"]: 51 | losses[f"loss_{key}"] = u.softmax_xent( 52 | logits=logits[key], labels=labels[key], reduction=False, 53 | axis=ONE_HOT_AXIS) 54 | return sum(losses.values()), losses 55 | 56 | 57 | def predict_outputs(logits, config, min_fraction=0.0): 58 | """Make outputs for panoptic segmentation task.""" 59 | # Map logits to (height, width, channels). 60 | hp, wp = config.model.patch_size 61 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) 62 | outputs = {} 63 | for key in ["semantics", "instances"]: 64 | assert ONE_HOT_AXIS == -2, "Rearrange below depends on this." 65 | outputs[key] = einops.rearrange( 66 | logits[key], 67 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", 68 | hn=hn, wn=wn, hp=hp, wp=wp) 69 | return panoptic_predictions_from_logits( 70 | **outputs, min_fraction=min_fraction) 71 | 72 | 73 | def panoptic_predictions_from_logits(semantics, instances, min_fraction=0.0): 74 | """Make panoptic prediction from logits.""" 75 | ins = jnp.argmax(instances, axis=-1) 76 | # Note: Make sure each instance has all pixels annotated with same label. 77 | # Otherwise they are further split into more instances and greatly affect 78 | # the number of unmatched predicted segments (FP) and RQ. 79 | masks = jax.nn.one_hot(ins, instances.shape[-1], dtype=jnp.int32) 80 | label = jnp.argmax(jnp.einsum("bhwk,bhwn->bnk", semantics, masks), axis=-1) 81 | sem = jnp.einsum("bhwn,bn->bhw", masks, label) 82 | out = jnp.stack([sem, ins], axis=-1) 83 | # Filter out small objects 84 | fraction = jnp.sum(masks, axis=(1, 2), keepdims=True)/np.prod(ins.shape[1:3]) 85 | mask_big = (fraction > min_fraction).astype("int32") 86 | mask_big_spatial = jnp.sum(masks * mask_big, axis=-1, keepdims=True) > 0 87 | return out * mask_big_spatial.astype("int32") 88 | --------------------------------------------------------------------------------