├── LICENSE ├── README.md ├── bridge_with_digit ├── .gitignore ├── README.md ├── docker-compose.yml ├── generate_usb_config.sh ├── host_install.sh ├── host_install_scripts │ ├── docker_compose_install.sh │ ├── docker_install.sh │ ├── nvidia_docker_install.sh │ ├── ros_install.sh │ └── widowx_install.sh ├── requirements.txt ├── test_utils │ ├── count_progress.py │ ├── plot-mic.py │ ├── read-ip.py │ ├── test-cam.py │ └── test-server.py └── widowx_envs │ ├── Dockerfile │ ├── LICENSE │ ├── README.md │ ├── curr_ip.py │ ├── experiments │ └── bridge_data_v2 │ │ ├── collection_metadata.json │ │ └── conf.py │ ├── multicam_server │ ├── CMakeLists.txt │ ├── launch │ │ ├── cameras.launch │ │ ├── imu_streamer.launch │ │ ├── mic_streamer.launch │ │ ├── overhead_streamer.launch │ │ ├── streamer.launch │ │ └── video_streamer.launch │ ├── package.xml │ ├── setup.py │ └── src │ │ ├── __init__.py │ │ ├── imu_socket_server.py │ │ ├── imu_streamer.py │ │ ├── mic_streamer.py │ │ ├── multicam_server │ │ ├── __init__.py │ │ ├── angle_streamer.py │ │ ├── camera_recorder.py │ │ ├── sensor_recorder.py │ │ └── topic_utils.py │ │ ├── overhead_streamer.py │ │ ├── socket_config.py │ │ ├── start_streamers.py │ │ ├── streamer.py │ │ ├── usb_connector_chart_example.yml │ │ ├── usbreset.c │ │ └── video_socket_streamer.py │ ├── requirements.txt │ ├── scripts │ ├── go_to_neutral_pose.py │ ├── go_to_sleep_pose.py │ ├── run.sh │ └── setup.sh │ ├── setup.py │ ├── widowx_controller │ ├── CMakeLists.txt │ ├── launch │ │ ├── launch.launch │ │ └── widowx_rs.launch │ ├── package.xml │ ├── setup.py │ ├── src │ │ └── widowx_controller │ │ │ ├── __init__.py │ │ │ ├── controller_base.py │ │ │ ├── custom_gripper_controller.py │ │ │ ├── velocity_controller.py │ │ │ ├── vr_controller_client.py │ │ │ ├── vr_controller_server.py │ │ │ └── widowx_controller.py │ └── srv │ │ ├── DisableController.srv │ │ ├── EnableController.srv │ │ ├── GetCartesianPose.srv │ │ ├── GetGripperDesiredState.srv │ │ ├── GetState.srv │ │ ├── GetVRButtons.srv │ │ ├── GotoNeutral.srv │ │ ├── MoveToEEP.srv │ │ ├── MoveToState.srv │ │ ├── OpenGripper.srv │ │ └── SetGripperPosition.srv │ └── widowx_envs │ ├── __init__.py │ ├── base │ ├── base_env.py │ ├── robot_base_env.py │ └── robot_configs.json │ ├── control_loops.py │ ├── policies │ ├── __init__.py │ ├── policy.py │ └── vr_teleop_policy.py │ ├── run_data_collection.py │ ├── teleop.py │ ├── trajectory_collector.py │ ├── utils │ ├── __init__.py │ ├── exceptions.py │ ├── grasp_utils.py │ ├── image_utils.py │ ├── metadata_helper.py │ ├── raw_saver.py │ ├── sync.py │ ├── transformation_utils.py │ └── utils.py │ ├── widowx_env.py │ └── widowx_env_service.py ├── media └── teaser.jpg ├── octo_digit ├── .flake8 ├── .github │ └── workflows │ │ └── pre-commit.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── LICENSE ├── README.md ├── __init__.py ├── docs │ └── assets │ │ └── teaser.jpg ├── eval │ ├── __init__.py │ ├── compositional_task.py │ ├── decode_config.py │ ├── decode_language.py │ ├── envs │ │ ├── __init__.py │ │ ├── gym_wrappers.py │ │ └── widowx_env.py │ ├── eval_config.py │ ├── eval_requirements.txt │ ├── fuse_compositional_task.py │ ├── fuse_eval.py │ ├── read_eval_logs.py │ ├── read_logs_to_csv.py │ ├── recursive_dict_print.py │ ├── setup.py │ └── utils.py ├── mem_lims.sh ├── octo │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── obs_transforms.py │ │ ├── oxe │ │ │ ├── __init__.py │ │ │ ├── oxe_dataset_configs.py │ │ │ ├── oxe_dataset_mixes.py │ │ │ └── 
oxe_standardization_transforms.py │ │ ├── traj_transforms.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── data_utils.py │ │ │ ├── goal_relabeling.py │ │ │ ├── task_augmentation.py │ │ │ └── text_processing.py │ ├── model │ │ ├── __init__.py │ │ ├── bcz_model.py │ │ ├── bcz_module.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── action_heads.py │ │ │ ├── base.py │ │ │ ├── bcz_action_heads.py │ │ │ ├── bcz_language_reconstruction_heads.py │ │ │ ├── block_transformer.py │ │ │ ├── diffusion.py │ │ │ ├── film_conditioning_layer.py │ │ │ ├── language_reconstruction_heads.py │ │ │ ├── mae_encoder.py │ │ │ ├── t3_vit.py │ │ │ ├── tokenizers.py │ │ │ ├── transformer.py │ │ │ ├── tvl_vit.py │ │ │ ├── unet.py │ │ │ ├── value_heads.py │ │ │ └── vit_encoders.py │ │ ├── octo_model.py │ │ ├── octo_module.py │ │ └── resnet_model.py │ ├── setup.py │ └── utils │ │ ├── __init__.py │ │ ├── fuse_constants.py │ │ ├── fuse_utils.py │ │ ├── gradcam.py │ │ ├── gym_wrappers.py │ │ ├── jax_utils.py │ │ ├── logging_utils.py │ │ ├── spec.py │ │ ├── train_callbacks.py │ │ ├── train_utils.py │ │ ├── typing.py │ │ └── visualization_lib.py ├── pyproject.toml ├── requirements.txt ├── scripts │ ├── configs │ │ ├── config.py │ │ ├── finetune_config.py │ │ ├── fuse_config.py │ │ └── octo_pretrain_config.py │ ├── finetune.py │ ├── finetune_fuse.py │ ├── finetune_fuse_pods.py │ ├── finetune_resnet.py │ ├── save_viz.py │ ├── test_mem.py │ ├── train.py │ └── viz.ipynb └── setup.py └── palivla_digit ├── .gitignore ├── .python-version ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── big_vision ├── __init__.py ├── configs │ ├── __init__.py │ ├── bit_i1k.py │ ├── bit_i21k.py │ ├── common.py │ ├── common_fewshot.py │ ├── load_and_eval.py │ ├── mlp_mixer_i1k.py │ ├── proj │ │ ├── cappa │ │ │ ├── README.md │ │ │ ├── cappa_architecture.png │ │ │ └── pretrain.py │ │ ├── clippo │ │ │ ├── README.md │ │ │ └── train_clippo.py │ │ ├── distill │ │ │ ├── README.md │ │ │ ├── bigsweep_flowers_pet.py │ │ │ ├── bigsweep_food_sun.py │ │ │ ├── bit_i1k.py │ │ │ └── common.py │ │ ├── flexivit │ │ │ ├── README.md │ │ │ ├── i1k_deit3_distill.py │ │ │ ├── i21k_distill.py │ │ │ ├── i21k_sup.py │ │ │ └── timing.py │ │ ├── givt │ │ │ ├── README.md │ │ │ ├── givt_coco_panoptic.py │ │ │ ├── givt_imagenet2012.py │ │ │ ├── givt_nyu_depth.py │ │ │ ├── givt_overview.png │ │ │ ├── vae_coco_panoptic.py │ │ │ └── vae_nyu_depth.py │ │ ├── gsam │ │ │ └── vit_i1k_gsam_no_aug.py │ │ ├── image_text │ │ │ ├── README.md │ │ │ ├── common.py │ │ │ └── siglip_lit_coco.py │ │ ├── paligemma │ │ │ ├── README.md │ │ │ ├── paligemma.png │ │ │ └── transfers │ │ │ │ ├── activitynet_cap.py │ │ │ │ ├── activitynet_qa.py │ │ │ │ ├── ai2d.py │ │ │ │ ├── aokvqa_da.py │ │ │ │ ├── aokvqa_mc.py │ │ │ │ ├── chartqa.py │ │ │ │ ├── coco35l.py │ │ │ │ ├── cococap.py │ │ │ │ ├── common.py │ │ │ │ ├── docvqa.py │ │ │ │ ├── forkme.py │ │ │ │ ├── gqa.py │ │ │ │ ├── infovqa.py │ │ │ │ ├── msrvtt_cap.py │ │ │ │ ├── msrvtt_qa.py │ │ │ │ ├── msvd_qa.py │ │ │ │ ├── nlvr2.py │ │ │ │ ├── ocrvqa.py │ │ │ │ ├── okvqa.py │ │ │ │ ├── pope.py │ │ │ │ ├── refcoco_seg.py │ │ │ │ ├── rsvqa_hr.py │ │ │ │ ├── rsvqa_lr.py │ │ │ │ ├── scicap.py │ │ │ │ ├── science_qa.py │ │ │ │ ├── screen2words.py │ │ │ │ ├── stvqa.py │ │ │ │ ├── tallyqa.py │ │ │ │ ├── textcaps.py │ │ │ │ ├── textvqa.py │ │ │ │ ├── vatex_cap.py │ │ │ │ ├── vertexai_l4.py │ │ │ │ ├── vizwizvqa.py │ │ │ │ ├── vqav2.py │ │ │ │ └── widgetcap.py │ │ ├── reward_tune │ │ │ └── detection_reward.py │ │ ├── scaling_laws │ │ │ └── train_vit_g.py │ │ └── 
uvim │ │ │ ├── README.md │ │ │ ├── train_coco_panoptic_pretrained.py │ │ │ ├── train_imagenet2012_colorization_pretrained.py │ │ │ ├── train_nyu_depth_pretrained.py │ │ │ ├── vqvae_coco_panoptic.py │ │ │ ├── vqvae_imagenet2012_colorization.py │ │ │ └── vqvae_nyu_depth.py │ ├── transfer.py │ ├── vit_i1k.py │ ├── vit_i21k.py │ └── vit_s16_i1k.py ├── datasets │ ├── ai2d │ │ └── ai2d.py │ ├── aokvqa │ │ └── aokvqa.py │ ├── chartqa │ │ └── chartqa.py │ ├── coco35l │ │ └── coco35l.py │ ├── core.py │ ├── countbenchqa │ │ ├── countbenchqa.py │ │ └── data │ │ │ └── countbench_paired_questions.json │ ├── docvqa │ │ └── docvqa.py │ ├── gqa │ │ └── gqa.py │ ├── imagenet │ │ └── class_names.py │ ├── infovqa │ │ └── infovqa.py │ ├── jsonl.py │ ├── nocaps │ │ └── nocaps.py │ ├── okvqa │ │ └── okvqa.py │ ├── pope │ │ └── pope.py │ ├── refcoco │ │ └── refcoco.py │ ├── rsvqa_hr │ │ └── rsvqa_hr.py │ ├── rsvqa_lr │ │ └── rsvqa_lr.py │ ├── scicap │ │ └── scicap.py │ ├── science_qa │ │ └── science_qa.py │ ├── screen2words │ │ └── screen2words.py │ ├── sequence_packing.py │ ├── stvqa │ │ └── stvqa.py │ ├── tallyqa │ │ └── tallyqa.py │ ├── textcaps │ │ └── textcaps.py │ ├── textvqa │ │ └── textvqa.py │ ├── tfds.py │ ├── vizwizvqa │ │ └── vizwizvqa.py │ ├── vqa │ │ └── vqa.py │ ├── widgetcap │ │ └── widgetcap.py │ ├── xgqa │ │ └── xgqa.py │ └── xm3600 │ │ └── xm3600.py ├── evaluators │ ├── __init__.py │ ├── classification.py │ ├── common.py │ ├── fewshot_lsr.py │ ├── mean.py │ ├── proj │ │ ├── cappa │ │ │ ├── perplexity.py │ │ │ └── scoring_classifier.py │ │ ├── distill │ │ │ └── distance.py │ │ ├── givt │ │ │ ├── coco_panoptic.py │ │ │ ├── nyu_depth.py │ │ │ └── save_predictions.py │ │ ├── image_text │ │ │ ├── contrastive.py │ │ │ ├── discriminative_classifier.py │ │ │ ├── discriminative_classifier_test.py │ │ │ ├── image_text_retrieval.py │ │ │ ├── image_text_retrieval_test.py │ │ │ ├── prompt_engineering.py │ │ │ ├── prompt_engineering_constants.py │ │ │ ├── prompt_engineering_test.py │ │ │ ├── retrieval.py │ │ │ └── retrieval_test.py │ │ ├── paligemma │ │ │ ├── perplexity.py │ │ │ └── transfers │ │ │ │ ├── chartqa.py │ │ │ │ ├── coco_caption.py │ │ │ │ ├── pope.py │ │ │ │ ├── rsvqa.py │ │ │ │ ├── science_qa.py │ │ │ │ ├── segmentation.py │ │ │ │ ├── storepreds.py │ │ │ │ ├── tallyqa.py │ │ │ │ ├── vqa.py │ │ │ │ └── vqav2.py │ │ └── uvim │ │ │ ├── coco_panoptic.py │ │ │ ├── coltran_fid.py │ │ │ ├── coltran_fid_data │ │ │ ├── eval_file_names.txt │ │ │ └── reference_file_names.txt │ │ │ ├── common.py │ │ │ ├── compute_mean.py │ │ │ ├── nyu_depth.py │ │ │ ├── psnr.py │ │ │ └── save_predictions.py │ └── save.py ├── input_pipeline.py ├── models │ ├── __init__.py │ ├── bit.py │ ├── bit_paper.py │ ├── common.py │ ├── mlp_mixer.py │ ├── ppp │ │ ├── __init__.py │ │ └── gemma.py │ ├── proj │ │ ├── cappa │ │ │ └── cappa.py │ │ ├── clippo │ │ │ └── one_tower.py │ │ ├── flaxformer │ │ │ ├── bert.py │ │ │ ├── bert_test.py │ │ │ └── bert_test_util.py │ │ ├── flexi │ │ │ ├── vit.py │ │ │ └── vit_test.py │ │ ├── givt │ │ │ ├── adaptor.py │ │ │ ├── adaptor_test.py │ │ │ ├── cnn.py │ │ │ ├── decode.py │ │ │ ├── decode_test.py │ │ │ ├── givt.py │ │ │ ├── givt_test.py │ │ │ ├── parallel_decode.py │ │ │ ├── parallel_decode_test.py │ │ │ ├── vae.py │ │ │ └── vit.py │ │ ├── image_text │ │ │ ├── text_transformer.py │ │ │ └── two_towers.py │ │ ├── paligemma │ │ │ ├── gemma_bv.py │ │ │ └── paligemma.py │ │ └── uvim │ │ │ ├── decode.py │ │ │ ├── vit.py │ │ │ ├── vit_test.py │ │ │ ├── vtt.py │ │ │ └── vtt_test.py │ └── vit.py ├── 
optax.py ├── optax_test.py ├── pp │ ├── __init__.py │ ├── archive │ │ ├── __init__.py │ │ ├── autoaugment.py │ │ └── randaug.py │ ├── autoaugment.py │ ├── builder.py │ ├── builder_test.py │ ├── ops_general.py │ ├── ops_general_test.py │ ├── ops_image.py │ ├── ops_image_test.py │ ├── ops_text.py │ ├── ops_text_test.py │ ├── proj │ │ ├── clippo │ │ │ ├── download_unifont.sh │ │ │ └── pp_ops.py │ │ ├── flaxformer │ │ │ ├── bert_ops.py │ │ │ └── bert_ops_test.py │ │ ├── givt │ │ │ └── pp_ops.py │ │ ├── paligemma │ │ │ ├── ops.py │ │ │ ├── robustness.py │ │ │ ├── sciqa_ops.py │ │ │ ├── segmentation.py │ │ │ ├── video.py │ │ │ └── widgetcap.py │ │ └── uvim │ │ │ ├── pp_ops.py │ │ │ └── pp_ops_test.py │ ├── registry.py │ ├── registry_test.py │ ├── tokenizer.py │ ├── utils.py │ └── utils_test.py ├── requirements.txt ├── run_tpu.sh ├── sharding.py ├── tools │ ├── download_tfds_datasets.py │ ├── eval_only.py │ └── lit_demo │ │ ├── README.md │ │ ├── build.js │ │ ├── package.json │ │ └── src │ │ ├── app.ts │ │ ├── components │ │ ├── image-carousel.scss │ │ ├── image-carousel.ts │ │ ├── image-prompts.scss │ │ ├── image-prompts.ts │ │ ├── lit-demo-app.scss │ │ ├── lit-demo-app.ts │ │ ├── loading-animation.scss │ │ ├── loading-animation.ts │ │ ├── message-list.scss │ │ ├── message-list.ts │ │ ├── model-controls.scss │ │ └── model-controls.ts │ │ ├── exports.ts │ │ ├── index.html │ │ ├── lit_demo │ │ ├── app.ts │ │ ├── compute.ts │ │ ├── constants.ts │ │ ├── data.ts │ │ └── url_utils.ts │ │ ├── playground.html │ │ ├── style.scss │ │ ├── style │ │ ├── colors.scss │ │ └── mixins.scss │ │ ├── tokenizers │ │ ├── common.ts │ │ ├── index.ts │ │ ├── sentencepiece_bpe.ts │ │ ├── sentencepiece_bpe_test.ts │ │ ├── sentencepiece_unigram.ts │ │ ├── sentencepiece_unigram_test.ts │ │ └── trie.ts │ │ └── tsconfig.json ├── train.py ├── trainers │ └── proj │ │ ├── cappa │ │ ├── generative.py │ │ └── predict_fns.py │ │ ├── distill │ │ └── distill.py │ │ ├── flexi │ │ ├── common.py │ │ ├── distill.py │ │ └── train.py │ │ ├── givt │ │ ├── generative.py │ │ ├── utils.py │ │ └── vae.py │ │ ├── gsam │ │ ├── gsam.py │ │ └── train.py │ │ ├── image_text │ │ ├── _deprecated_contrastive.py │ │ └── siglip.py │ │ ├── paligemma │ │ ├── predict_fns.py │ │ ├── run.py │ │ └── train.py │ │ └── uvim │ │ ├── coco_utils.py │ │ ├── colorization_task.py │ │ ├── depth_task.py │ │ ├── panoptic_task.py │ │ ├── train.py │ │ └── vqvae.py ├── utils.py └── utils_test.py ├── eval_palivla.py ├── palivla ├── __init__.py ├── configs │ ├── aloha_config.py │ ├── aloha_tokenizer_config.py │ ├── bridge_config.py │ ├── debug_config.py │ ├── fuse_config.py │ └── oxe_config.py ├── dataset.py ├── eval_step.py ├── fuse_eval.py ├── inference.py ├── learned_tokenizer.py ├── load_model.py ├── modality_embedder.py ├── model.py ├── predict_fns.py ├── sentencepiece_model_pb2.py ├── spec.py ├── tactile_encoder_pooled.py ├── tokenizer.py ├── train_fuse.py ├── train_state.py ├── train_step.py ├── train_tokenizer.py ├── types.py └── utils.py ├── pod_config_fuse.py ├── pyproject.toml ├── run.bash ├── setup.bash ├── setup_pod.sh └── ssh_pod.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Robotic AI & Learning Lab Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bridge_with_digit/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | services: 3 | robonet-digit: 4 | image: robonet-base-digit 5 | container_name: robonet-digit 6 | build: 7 | context: widowx_envs 8 | entrypoint: ./widowx_envs/scripts/run.sh 9 | environment: 10 | - DISPLAY=:0 11 | user: robonet:1002 12 | stdin_open: true 13 | tty: true 14 | device_cgroup_rules: 15 | - 'c *:* rmw' 16 | network_mode: host 17 | volumes: 18 | - ${USB_CONNECTOR_CHART:-/dev/null}:/tmp/camera_connector_chart 19 | # if you want to be able to edit widowx_envs code without rebuilding the image, this 20 | # overwrites the widowx_envs directory that was copied in at build time 21 | - ./experiments:/home/robonet/experiments 22 | - ./data:/home/robonet/trainingdata # TODO (YL): better way to mount a local directory 23 | # $HOME/widowx_data:/home/robonet/trainingdata 24 | - ./widowx_envs:/home/robonet/widowx_envs 25 | - /dev:/dev # for host tty access 26 | 27 | bridge_data_v2: 28 | image: robonet-digit-bridge-data-v2 29 | container_name: robonet_digit_bridge_data_v2 30 | user: robonet:1002 31 | build: 32 | context: code/bridge_data_v2 33 | network_mode: host 34 | environment: 35 | - WANDB_API_KEY=putwandbkeyhere 36 | volumes: 37 | - ./widowx_envs:/home/robonet/widowx_envs 38 | - ./code/bridge_data_v2:/home/robonet/code/bridge_data_v2 39 | - ./experiments:/home/robonet/experiments 40 | - /media/harddrive/:/home/robonet/trainingdata 41 | -------------------------------------------------------------------------------- /bridge_with_digit/generate_usb_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Initialize the YAML variables to empty strings 4 | blue="" 5 | yellow="" 6 | wrist="" 7 | d435="" 8 | 9 | # Run v4l2-ctl to fetch devices and parse them line by line 10 | while IFS= read -r line; do 11 | # Check for the device identifiers and store them accordingly 12 | if [[ $line == *"Piwebcam: UVC Camera"* ]]; then 13 | wrist=$(echo "$line" | awk -F '(' '{print $2}' | awk -F ')' '{print $1}') 14 | elif [[ $line == *"HD Pro Webcam C920"* ]] && [ -z "$blue" ]; then 15 | blue=$(echo "$line" | awk -F '(' '{print $2}' | awk -F ')' '{print $1}') 16 | elif [[ $line == *"HD Pro Webcam C920"* ]]; then 17 | yellow=$(echo "$line" | awk -F '(' '{print $2}' | awk -F ')' '{print $1}') 18 | elif [[ $line == *"Intel(R) RealSense(TM) Depth Ca"* ]]; then 19 | d435=$(echo "$line" | awk -F 'Ca ' '{print $2}' | awk -F ')' '{print $1}') 20 | fi 21 | done < <(v4l2-ctl --list-devices) 22 | 23 | # Print 
the generated YAML format 24 | cat << EOF > usb_connector_chart.yml 25 | blue: '$blue' 26 | yellow: '$yellow' 27 | wrist: '$wrist' 28 | D435: '$d435' 29 | EOF 30 | -------------------------------------------------------------------------------- /bridge_with_digit/host_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | read -r -p "Install ROS? [Y/n] " response 5 | case "$response" in 6 | [nN][oO]|[nN]) 7 | install_ros=false 8 | ;; 9 | *) 10 | install_ros=true 11 | ;; 12 | esac 13 | 14 | read -r -p "Install WidowX drivers? [Y/n] " response 15 | case "$response" in 16 | [nN][oO]|[nN]) 17 | install_widowx=false 18 | ;; 19 | *) 20 | install_widowx=true 21 | ;; 22 | esac 23 | 24 | read -r -p "Install Docker? [Y/n] " response 25 | case "$response" in 26 | [nN][oO]|[nN]) 27 | install_docker=false 28 | ;; 29 | *) 30 | install_docker=true 31 | read -r -p "Do you want to use Docker as a non-root user (without sudo)? [Y/n] " response 32 | case "$response" in 33 | [nN][oO]|[nN]) 34 | docker_without_sudo=false 35 | ;; 36 | *) 37 | docker_without_sudo=true 38 | echo "Remember to log out and back in for this to take effect after the script completes!" 39 | if [[ $EUID -eq 0 ]]; then 40 | echo "If you want to set up Docker as a non-root user run this script $0 as non-root user." 41 | exit 1 42 | fi 43 | ;; 44 | esac 45 | ;; 46 | esac 47 | 48 | 49 | read -r -p "Install nvidia-docker? [Y/n] " response 50 | case "$response" in 51 | [nN][oO]|[nN]) 52 | install_nv_docker=false 53 | ;; 54 | *) 55 | install_nv_docker=true 56 | echo "If you want to use Nvidia GPU in docker, please make sure that nvidia drivers are already installed. https://github.com/NVIDIA/nvidia-docker/wiki/Frequently-Asked-Questions#how-do-i-install-the-nvidia-driver" 57 | read -r -p "Do you want to continue? [Y/n] " response 58 | case "$response" in 59 | [nN][oO]|[nN]) 60 | exit 0 61 | esac 62 | ;; 63 | esac 64 | 65 | read -r -p "Install docker-compose? [Y/n] " response 66 | case "$response" in 67 | [nN][oO]|[nN]) 68 | install_docker_compose=false 69 | ;; 70 | *) 71 | install_docker_compose=true 72 | ;; 73 | esac 74 | 75 | full_path=$(realpath $0) 76 | dir_path=$(dirname $full_path)/host_install_scripts 77 | 78 | if [ "$install_ros" = "true" ]; then 79 | $dir_path/ros_install.sh 80 | fi 81 | if [ "$install_widowx" = "true" ]; then 82 | $dir_path/widowx_install.sh 83 | fi 84 | if [ "$install_docker" = "true" ]; then 85 | $dir_path/docker_install.sh $docker_without_sudo 86 | fi 87 | if [ "$install_nv_docker" = "true" ]; then 88 | $dir_path/nvidia_docker_install.sh 89 | fi 90 | if [ "$install_docker_compose" = "true" ]; then 91 | $dir_path/docker_compose_install.sh 92 | fi 93 | 94 | echo "All done!" 
95 |
--------------------------------------------------------------------------------
/bridge_with_digit/host_install_scripts/docker_compose_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sudo curl -L "https://github.com/docker/compose/releases/download/1.29.0/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
4 | sudo chmod +x /usr/local/bin/docker-compose
--------------------------------------------------------------------------------
/bridge_with_digit/host_install_scripts/docker_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker_without_sudo=$1
4 | # Group commands ({ ...; }) rather than subshells (...) so that 'exit 1' aborts this script
5 | curl -fsSL https://get.docker.com -o /tmp/get-docker.sh || { echo "get-docker.sh download failed"; exit 1; }
6 | sudo sh /tmp/get-docker.sh || { echo "Docker install failed"; exit 1; }
7 | sudo systemctl --now enable docker || { echo "Docker startup failed"; exit 1; }
8 |
9 | if [ "$docker_without_sudo" = "true" ]; then
10 |     if [[ $EUID -eq 0 ]]; then
11 |         echo "If you want to set up Docker as a non-root user run this script $0 as non-root user."
12 |         exit 1
13 |     fi
14 |     sudo usermod -aG docker $USER
15 | fi
--------------------------------------------------------------------------------
/bridge_with_digit/host_install_scripts/nvidia_docker_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
4 |     && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \
5 |     && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
6 | sudo apt-get update
7 | sudo apt-get install -y nvidia-docker2
8 | sudo systemctl restart docker
9 | sudo docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi > /dev/null || { echo "nvidia-docker test failed"; exit 1; }
--------------------------------------------------------------------------------
/bridge_with_digit/host_install_scripts/widowx_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget -q https://raw.githubusercontent.com/Interbotix/interbotix_ros_core/main/interbotix_ros_xseries/interbotix_xs_sdk/99-interbotix-udev.rules -O /tmp/99-interbotix-udev.rules || { echo "interbotix-udev.rules download failed"; exit 1; }
4 | sudo cp /tmp/99-interbotix-udev.rules /etc/udev/rules.d/
5 | sudo udevadm control --reload-rules && sudo udevadm trigger
--------------------------------------------------------------------------------
/bridge_with_digit/requirements.txt:
--------------------------------------------------------------------------------
1 | gym >= 0.26
2 | numpy==1.24.3
3 | jax==0.4.13
4 | distrax==0.1.2
5 | flax==0.7.0
6 | ml_collections >= 0.1.0
7 | tqdm >= 4.60.0
8 | chex==0.1.6
9 | optax==0.1.5
10 | absl-py >= 0.12.0
11 | scipy >= 1.6.0
12 | wandb >= 0.12.14
13 | tensorflow==2.13.0
14 | tensorflow_probability==0.21
15 | tensorflow_hub
16 | tensorflow_text
17 | einops >= 0.6.1
18 | imageio >= 2.31.1
19 | moviepy >= 1.0.3
20 | orbax
21 | matplotlib
22 | pyquaternion
23 | opencv-python
24 | opencv-contrib-python
25 | funcsigs
26 | adafruit-circuitpython-bno055
27 | sounddevice
--------------------------------------------------------------------------------
/bridge_with_digit/test_utils/count_progress.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 |
5 | data_dir = '/home/paulzhou/bridge_with_digit/v3_dnw'
6 |
7 |
8 | num_collected_trajectories = 0
9 | total_len = 0
10 | for dated_dir in os.listdir(data_dir):
11 |     if dated_dir[:4] != "2024" or 'raw' not in os.listdir(os.path.join(data_dir, dated_dir)):
12 |         continue
13 |     dated_dir = os.path.join(data_dir, dated_dir, 'raw', 'traj_group0')
14 |     for trajdir in os.listdir(dated_dir):
15 |         if trajdir[:4] != 'traj':
16 |             continue
17 |         assert os.path.exists(os.path.join(dated_dir, trajdir, 'images0', 'im_0.jpg'))
18 |         num_collected_trajectories += 1
19 |
20 |         total_len += len(os.listdir(os.path.join(dated_dir, trajdir, 'images0')))
21 |
22 |
23 |
24 | print(f'Collected {num_collected_trajectories} trajectories so far!')
25 | print(f'Which corresponds to {total_len * 0.2} seconds')  # 0.2 s per step, matching move_duration in conf.py
26 | print(f'Or {total_len * 0.2 / 60} minutes')
27 | print(f'Or {total_len * 0.2 / 3600.0} hours')
--------------------------------------------------------------------------------
/bridge_with_digit/test_utils/plot-mic.py:
--------------------------------------------------------------------------------
1 | # import sounddevice as sd
2 | # import numpy as np
3 | # import matplotlib.pyplot as plt
4 |
5 | # # Settings
6 | # duration = 3  # Duration of the recording in seconds
7 | # fs = 44100  # Sampling frequency
8 |
9 | # # Record audio
10 | # print("Recording...")
11 | # audio_data = sd.rec(int(duration * fs), samplerate=fs, channels=1, dtype='float32')
12 | # sd.wait()  # Wait for recording to finish
13 |
14 | # # Plot the waveform of the recorded audio
15 | # plt.figure(figsize=(10, 6))
16 | # time_axis = np.linspace(0, duration, len(audio_data))
17 | # plt.plot(time_axis, audio_data)
18 | # plt.xlabel("Time (seconds)")
19 | # plt.ylabel("Amplitude")
20 | # plt.title("Recorded Audio Waveform")
21 | # plt.show()
22 |
23 |
24 |
25 | import sounddevice as sd
26 | import numpy as np
27 | import matplotlib.pyplot as plt
28 | import pickle
29 | import time
30 |
31 | # Settings
32 | fs = 44100  # Sampling frequency
33 |
34 | # Load a previously recorded audio clip
35 | path = '/home/paulzhou/bridge_with_digit/data/audio11.pkl'
36 | # path = '/home/paulzhou/bridge_with_digit/widowx_envs/audio_data.pkl'
37 | with open(path, 'rb') as file:
38 |     audio_data = pickle.load(file)
39 |
40 | # Perform FFT (Fast Fourier Transform) to analyze frequency content
41 | n = len(audio_data)
42 | duration = n / fs  # derive the clip length from the data instead of assuming 3 s
43 | frequencies = np.fft.rfftfreq(n, d=1/fs)  # Real FFT frequencies
44 | magnitude_spectrum = np.abs(np.fft.rfft(audio_data, axis=0))
45 |
46 | # with open('./audio_data.pkl', 'wb') as file:
47 | #     pickle.dump(audio_data, file)
48 |
49 | # Plot the waveform and frequency spectrum of the recorded audio
50 | plt.figure(figsize=(12, 6))
51 |
52 | # Plot waveform
53 | plt.subplot(2, 1, 1)
54 | time_axis = np.linspace(0, duration, n)
55 | plt.plot(time_axis, audio_data)
56 | plt.xlabel("Time (seconds)")
57 | plt.ylabel("Amplitude")
58 | plt.title("Recorded Audio Waveform")
59 |
60 | # Plot frequency spectrum
61 | plt.subplot(2, 1, 2)
62 | plt.plot(frequencies, magnitude_spectrum)
63 | plt.xlabel("Frequency (Hz)")
64 | plt.ylabel("Magnitude")
65 | plt.title("Frequency Spectrum")
66 | plt.tight_layout()
67 |
68 | # Save before show: with some backends plt.show() clears the figure, leaving savefig blank
69 | plt.savefig(f'./mic_test_{time.time()}.png')
70 | plt.show()
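# A minimal follow-on sketch, assuming the `frequencies` and `magnitude_spectrum`
# arrays computed above: report the dominant frequency in the loaded clip,
# e.g. to sanity-check the mic against a known test tone.
peak_idx = int(np.argmax(magnitude_spectrum))
print(f"Dominant frequency: {frequencies[peak_idx]:.1f} Hz")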
--------------------------------------------------------------------------------
/bridge_with_digit/test_utils/read-ip.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 |
4 |
5 |
6 | if __name__ == '__main__':
7 |     if len(sys.argv) == 1:
8 |         dev_name = 'enp114s0'
9 |     else:
10 |         dev_name = sys.argv[1]
11 |     # Parse the 'inet <addr>' field for dev_name out of the ifconfig output
12 |     ifconfig_out = str(subprocess.check_output(['ifconfig']))
13 |     dev_only = ifconfig_out[ifconfig_out.index(dev_name):]
14 |     inet_str = 'inet '
15 |     dev_only = dev_only[dev_only.index(inet_str) + len(inet_str):]
16 |     dev_only = dev_only[:dev_only.index(' ')]
17 |
18 |     full_str = f"CURR_IP = '{dev_only}'"
19 |     write_path = '/home/paulzhou/bridge_with_digit/widowx_envs/curr_ip.py'
20 |     with open(write_path, 'w') as file:
21 |         file.write(full_str)
--------------------------------------------------------------------------------
/bridge_with_digit/test_utils/test-cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import sys
3 | import time
4 | import numpy as np
5 | from PIL import Image
6 |
7 | if len(sys.argv) < 2:
8 |     channel = 2
9 | else:
10 |     channel = int(sys.argv[1])
11 |
12 | cam_file = f"/dev/video{channel}"
13 |
14 | DURATION = 5
15 | NUM_SAVE = 5
16 |
17 | cam = cv2.VideoCapture(cam_file)
18 | print(cam)
19 |
20 | # let exposure/white balance settle before grabbing the test frame
21 | for _ in range(100):
22 |     cam.read()
23 |
24 | ret, frame = cam.read()
25 | cam.release()
26 | print(ret)
27 | save_img = Image.fromarray(frame[..., ::-1])  # BGR -> RGB
28 | save_img.save('./test.jpeg')
29 | exit(0)
30 |
31 | # cam.set()
32 |
33 |
34 | # imgs = []
35 | # start_time = time.time()
36 | # prev_time = start_time
37 | # while True:
38 | #     curr_time = time.time()
39 | #     if curr_time - prev_time < 1.0/30.0:
40 | #         continue
41 | #     # if curr_time - prev_time < 1.0:
42 | #     #     continue
43 | #     prev_time = curr_time
44 | #     if curr_time - start_time > DURATION:
45 | #         break
46 | #     ret, frame = cam.read()
47 | #     if not ret:
48 | #         print(len(imgs), "failed")
49 | #         break
50 | #     imgs.append(frame)
51 |
52 | # save_freq = len(imgs) // NUM_SAVE
53 | # print(len(imgs))
54 | # for i, frame in enumerate(imgs):
55 | #     if i % save_freq:
56 | #         continue
57 | #     save_img = Image.fromarray(np.uint8(frame))
58 | #     save_img.save(f'./{i}.jpeg')
59 |
60 |
61 | # print(cam.get(cv2.CAP_PROP_FRAME_HEIGHT))
62 | # print(cam.get(cv2.CAP_PROP_FPS))
63 | # ret, frame = cam.read()
64 | # if ret:
65 | #     cv2.imwrite("test_img.jpg", frame)
--------------------------------------------------------------------------------
/bridge_with_digit/test_utils/test-server.py:
--------------------------------------------------------------------------------
1 | import socket
2 |
3 | HOST = "128.32.175.252"
4 | PORT = 54321
5 |
6 | # use a dedicated name so the socket object does not shadow the socket module
7 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
8 | server_socket.bind((HOST, PORT))
9 | server_socket.listen()
10 | conn, addr = server_socket.accept()
11 | print(addr)
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Robotics and AI Lab @ BAIR
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the
Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/README.md:
--------------------------------------------------------------------------------
1 | # widowx_envs
2 |
3 | ## Directory Structure
4 |
5 | - experiments: Contains the configuration files for the experiments
6 | - multicam_server: ROS package that handles the multi-camera streaming setup
7 | - scripts: High-level scripts/entrypoints for running widowx_envs
8 | - widowx_controller: Low-level scripts that handle communication with the robot and the VR controller
9 | - widowx_envs: Contains the gym environment for the robot arm
10 |
11 | ## Supported robots
12 | (https://github.com/Interbotix/interbotix_ros_manipulators/tree/main/interbotix_ros_xsarms):
13 | - PincherX 100 Robot Arm
14 | - PincherX 150 Robot Arm
15 | - ReactorX 150 Robot Arm
16 | - ReactorX 200 Robot Arm
17 | - WidowX 200 Robot Arm
18 | - WidowX 250 Robot Arm
19 | - WidowX 250 Robot Arm 6DOF
20 | - ViperX 250 Robot Arm
21 | - ViperX 300 Robot Arm
22 | - ViperX 300 Robot Arm 6DOF
23 | - PincherX 100 Mobile Robot Arm
24 | - WidowX 200 Mobile Robot Arm
25 | - WidowX 250 Mobile Robot Arm 6DOF
26 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/curr_ip.py:
--------------------------------------------------------------------------------
1 | CURR_IP = '128.32.175.252'
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/experiments/bridge_data_v2/collection_metadata.json:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/experiments/bridge_data_v2/conf.py:
--------------------------------------------------------------------------------
1 | """ Hyperparameters for Large Scale Data Collection """
2 | import os.path
3 |
4 | BASE_DIR = '/'.join(str.split(__file__, '/')[:-1])
5 | current_dir = os.path.dirname(os.path.realpath(__file__))
6 |
7 | from multicam_server.topic_utils import IMTopic
8 | from widowx_envs.widowx_env import VR_WidowX
9 | from widowx_envs.control_loops import TimedLoop
10 | from widowx_envs.policies.vr_teleop_policy import VRTeleopPolicy
11 |
12 | env_params = {
13 |     'camera_topics': [#IMTopic('/D435/color/image_raw'),
14 |                       #IMTopic('/yellow/image_raw'),
15 |                       IMTopic('/blue/image_raw'),
16 |                       IMTopic('/wrist/image_raw', is_python_node=True)
17 |                       ],
18 |     'depth_camera_topics': [],  # [IMTopic('/D435/depth/image_rect_raw', dtype='16UC1')],
19 |     'digit_topics': [
20 |         IMTopic('/digit_left/image_raw', width=320, height=240, is_python_node=True),
21 |         IMTopic('/digit_right/image_raw', width=320,
height=240, is_python_node=True), 22 | ], 23 | 'imu_topic': '/imu/imu_raw', 24 | 'mic_topic': '/mic/mic_raw', 25 | 26 | 'gripper_attached': 'custom', 27 | 'skip_move_to_neutral': True, 28 | 'move_to_rand_start_freq': -1, 29 | 'fix_zangle': 0.1, 30 | 'move_duration': 0.2, 31 | 'adaptive_wait': True, 32 | 'action_clipping': None, 33 | 'record_img': True, 34 | 'stream_angles': True, 35 | } 36 | 37 | agent = { 38 | 'type': TimedLoop, 39 | 'env': (VR_WidowX, env_params), 40 | 'recreate_env': (False, 1), 41 | 'T': 500, 42 | 'image_height': 480, 43 | 'image_width': 640, 44 | 'make_final_gif': False, 45 | 'video_format': 'mp4', 46 | } 47 | 48 | policy = { 49 | 'type': VRTeleopPolicy, 50 | } 51 | 52 | config = { 53 | 'current_dir' : current_dir, 54 | 'collection_metadata' : current_dir + '/collection_metadata.json', 55 | 'start_index':0, 56 | 'end_index': 500, 57 | 'agent': agent, 58 | 'policy': policy, 59 | 'save_format': ['raw'], 60 | 'make_diagnostics': False, 61 | 'record_floor_height': False 62 | } 63 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(multicam_server) 3 | 4 | catkin_python_setup() 5 | 6 | ## Find catkin macros and libraries 7 | ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) 8 | ## is used, also find other catkin packages 9 | find_package(catkin REQUIRED COMPONENTS 10 | rospy 11 | std_msgs 12 | message_generation 13 | ) 14 | 15 | catkin_package() 16 | 17 | ## Specify additional locations of header files 18 | ## Your package locations should be listed before other locations 19 | include_directories( 20 | # include 21 | ${catkin_INCLUDE_DIRS} 22 | ) 23 | 24 | catkin_install_python(PROGRAMS # add executable python scripts here 25 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 26 | ) 27 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/launch/cameras.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/launch/imu_streamer.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/launch/mic_streamer.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/launch/overhead_streamer.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- 
/bridge_with_digit/widowx_envs/multicam_server/launch/video_streamer.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | multicam_server 4 | 0.0.0 5 | The multicam_server package 6 | 7 | 8 | 9 | 10 | jonathan 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | message_generation 41 | 42 | 43 | 44 | 45 | 46 | message_runtime 47 | video_stream_opencv 48 | 49 | 50 | 51 | 52 | catkin 53 | rospy 54 | std_msgs 55 | rospy 56 | std_msgs 57 | rospy 58 | std_msgs 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from catkin_pkg.python_setup import generate_distutils_setup 3 | 4 | # fetch values from package.xml 5 | setup_args = generate_distutils_setup( 6 | packages=['multicam_server'], 7 | package_dir={'': 'src'}) 8 | 9 | setup(**setup_args) 10 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/multicam_server/src/__init__.py -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/src/multicam_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/multicam_server/src/multicam_server/__init__.py -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/multicam_server/src/multicam_server/angle_streamer.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3
2 | # best working version
3 | import threading
4 | import numpy as np
5 | import rospy
6 | from std_msgs.msg import Float32MultiArray
7 |
8 |
9 | ##############################################################################
10 |
11 |
12 | class AngleStreamer:
13 |     def __init__(self, env, publish_freq=40., read_freq=40.):
14 |         self.env = env
15 |         self._publish_freq = publish_freq
16 |         self._read_freq = read_freq
17 |         self._obs_dim = 3
18 |
19 |         self._most_recent_reading = None
20 |         self._lock = threading.Lock()
21 |         self._angle_lock = threading.Lock()
22 |         self._angle_sema = threading.Semaphore(0)
23 |         self.start_capture()
24 |         self.publisher = rospy.Publisher('/joints/rpy', Float32MultiArray, queue_size=1)
25 |         self.start_publishing()
26 |
27 |     def start_capture(self):
28 |         self._capture_thread = threading.Thread(target=self.capture)
29 |         self._capture_thread.start()
30 |
31 |     def capture(self):
32 |         rate = rospy.Rate(self._read_freq)
33 |         while not rospy.is_shutdown():
34 |             full_state = self.env.get_full_state()
35 |             rpy = full_state[3:6]
36 |             with self._lock:
37 |                 self._most_recent_reading = np.concatenate((rpy, np.array([rospy.get_time()])))
38 |             rate.sleep()
39 |
40 |     def publish_reading(self, data):
41 |         msg = Float32MultiArray(data=data)
42 |         self.publisher.publish(msg)
43 |
44 |     def start_publishing(self):
45 |         self._publishing_thread = threading.Thread(target=self.publishing)
46 |         self._publishing_thread.start()
47 |
48 |     def publishing(self):
49 |         rate = rospy.Rate(self._publish_freq)
50 |         while not rospy.is_shutdown():
51 |             with self._lock:
52 |                 reading = self._most_recent_reading
53 |             if reading is not None:
54 |                 self.publish_reading(reading)
55 |             rate.sleep()
56 |
57 |
58 | def main(env):
59 |     # AngleStreamer needs the robot env (anything exposing get_full_state())
60 |     # rospy.init_node('Angle_streamer')
61 |     streamer = AngleStreamer(env)
62 |     # rospy.spin()
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/multicam_server/src/multicam_server/topic_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import cv2
4 | from dataclasses import dataclass, asdict
5 | from typing import Any, Type
6 |
7 |
8 | @dataclass
9 | class IMTopic:
10 |     """
11 |     Configuration for an image topic.
12 |     """
13 |     name: str
14 |     width: int = 640
15 |     height: int = 480
16 |     top: int = 0
17 |     bot: int = 0
18 |     right: int = 0
19 |     left: int = 0
20 |     dtype: str = "bgr8"
21 |     flip: bool = False
22 |     info_name: str = None
23 |     is_python_node: bool = False
24 |
25 |     def process_image(self, img):
26 |         # Check for overcrop conditions
27 |         assert self.bot + self.top <= img.shape[0], "Overcrop! bot + top crop > image height!"
28 |         assert self.right + self.left <= img.shape[1], "Overcrop! right + left crop > image width!"
29 |
30 |         # If bot or right is zero or negative, use a large negative sentinel so the slice below keeps the full extent on that side (no crop)
31 |         bot = self.bot if self.bot > 0 else -(img.shape[0] + 10)
32 |         right = self.right if self.right > 0 else -(img.shape[1] + 10)
33 |
34 |         # Crop image
35 |         img = img[self.top:-bot, self.left:-right]
36 |
37 |         # Flip image if necessary
38 |         if self.flip:
39 |             img = img[::-1, ::-1]
40 |
41 |         # Resize image if necessary
42 |         if (self.height, self.width) != img.shape[:2]:
43 |             return cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
44 |         return img
45 |
46 |     @classmethod
47 |     def from_dict(cls: Type[Any], data: dict) -> Any:
48 |         return cls(**data)
49 |
50 |     def to_dict(self):
51 |         return asdict(self)
52 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/multicam_server/src/socket_config.py:
--------------------------------------------------------------------------------
1 | from curr_ip import CURR_IP
2 | HOST = CURR_IP
3 |
4 | CHECK_PORT = 48005  # for streaming the overhead camera
5 | IMU_PORT = 54398  # for streaming the IMU measurements
6 | VIDEO_PORT = 50387  # for streaming the wrist camera
7 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/multicam_server/src/usb_connector_chart_example.yml:
--------------------------------------------------------------------------------
1 | # rename this file to "usb_connector_chart_user.yml"
2 | cam0: 'usb-0000:00:14.0-10.2'
3 | cam1: 'usb-0000:00:14.0-10.4'
4 | cam2: 'usb-0000:00:14.0-9.3'
5 | cam3: 'usb-0000:08:00.0-2'
6 | cam4: "usb-0000:00:14.0-9.2"
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/multicam_server/src/usbreset.c:
--------------------------------------------------------------------------------
1 | /* usbreset -- send a USB port reset to a USB device */
2 |
3 | #include <stdio.h>
4 | #include <unistd.h>
5 | #include <fcntl.h>
6 | #include <errno.h>
7 | #include <sys/ioctl.h>
8 |
9 | #include <linux/usbdevfs.h>
10 |
11 |
12 | int main(int argc, char **argv)
13 | {
14 |     const char *filename;
15 |     int fd;
16 |     int rc;
17 |
18 |     if (argc != 2) {
19 |         fprintf(stderr, "Usage: usbreset device-filename\n");
20 |         return 1;
21 |     }
22 |     filename = argv[1];
23 |
24 |     fd = open(filename, O_WRONLY);
25 |     if (fd < 0) {
26 |         perror("Error opening output file");
27 |         return 1;
28 |     }
29 |
30 |     printf("Resetting USB device %s\n", filename);
31 |     rc = ioctl(fd, USBDEVFS_RESET, 0);
32 |     if (rc < 0) {
33 |         perror("Error in ioctl");
34 |         return 1;
35 |     }
36 |     printf("Reset successful\n");
37 |
38 |     close(fd);
39 |     return 0;
40 | }
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow
2 | protobuf
3 | funcsigs
4 | future
5 | imageio
6 | imageio-ffmpeg
7 | librosa
8 | matplotlib
9 | moviepy
10 | opencv-python
11 | numpy>=1.23.4
12 | pyquaternion
13 | scikit-learn
14 | scikit-image
15 | scipy
16 | six
17 | requests
18 | nvidia_smi
19 | rospkg
20 | modern_robotics==1.1.1
21 | gym
22 | tqdm
23 | transformations
24 | ipdb
25 | joblib
26 | pickle5
27 | h5py
28 | git+https://github.com/rail-berkeley/oculus_reader.git
29 | adafruit-circuitpython-bno055
30 | sounddevice
31 | numpy_ringbuffer
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/go_to_neutral_pose.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | if __name__ == '__main__':
4 |     from widowx_envs.widowx_env import StateReachingWidowX
5 |     env = StateReachingWidowX()
6 |     env.move_to_neutral()
7 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/go_to_sleep_pose.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | if __name__ == '__main__':
4 |     from widowx_envs.widowx_env import StateReachingWidowX
5 |     env = StateReachingWidowX()
6 |     env.move_to_neutral()
7 |     env._controller.bot.arm.go_to_sleep_pose()
8 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | bash $(dirname "$0")/setup.sh || exit 1
4 |
5 | source /opt/ros/noetic/setup.bash
6 | source ~/interbotix_ws/devel/setup.bash
7 | source ~/myenv/bin/activate
8 |
9 | # using 'exec' here is very important because roslaunch needs to do some cleanup after it exits
10 | # so when the container is killed the SIGTERM needs to be passed to roslaunch
11 | exec roslaunch widowx_controller widowx_rs.launch \
12 |     camera_connector_chart:=/tmp/camera_connector_chart \
13 |     serial_no_camera1:=${REALSENSE_SERIAL} \
14 |     python_node:=false realsense:=true
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | if [[ -z "${ROBONETV2_ARM}" ]]; then
4 |     echo 'Env variable "ROBONETV2_ARM" is not set. Please define it based on https://github.com/Interbotix/interbotix_ros_manipulators/tree/main/interbotix_ros_xsarms'
5 |     echo 'For instance in case of WidowX 250 Robot Arm 6DOF use:'
6 |     echo 'echo "export ROBONETV2_ARM=wx250s" >> ~/.bashrc && source ~/.bashrc'
7 |     exit 1
8 | fi
9 |
10 | cd
11 | if [ !
-f ".built" ]; then 12 | cd ~/interbotix_ws && catkin_make && touch ~/.built 13 | fi 14 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='widowx_envs', 5 | version='0.0.1', 6 | packages=setuptools.find_packages(), 7 | license='MIT License', 8 | long_description=open('README.md').read(), 9 | entry_points={ 10 | 'console_scripts': [ 11 | 'widowx_env_service = widowx_envs.widowx_env_service:main', 12 | ], 13 | }, 14 | ) 15 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(widowx_controller) 3 | 4 | ## Compile as C++11, supported in ROS Kinetic and newer 5 | # add_compile_options(-std=c++11) 6 | 7 | catkin_python_setup() 8 | 9 | ## Find catkin macros and libraries 10 | ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) 11 | ## is used, also find other catkin packages 12 | find_package(catkin REQUIRED COMPONENTS 13 | rospy 14 | std_msgs 15 | message_generation 16 | ) 17 | 18 | ## Generate services in the 'srv' folder 19 | add_service_files( 20 | FILES 21 | GotoNeutral.srv 22 | OpenGripper.srv 23 | MoveToEEP.srv 24 | MoveToState.srv 25 | GetGripperDesiredState.srv 26 | GetCartesianPose.srv 27 | GetState.srv 28 | GetVRButtons.srv 29 | EnableController.srv 30 | DisableController.srv 31 | SetGripperPosition.srv 32 | ) 33 | 34 | ## Generate added messages and services with any dependencies listed here 35 | generate_messages( 36 | DEPENDENCIES 37 | std_msgs 38 | ) 39 | 40 | catkin_package() 41 | 42 | include_directories( 43 | ${catkin_INCLUDE_DIRS} 44 | ) 45 | 46 | catkin_install_python(PROGRAMS # add executable python files here 47 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 48 | ) 49 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/launch/launch.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/launch/widowx_rs.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | widowx_controller 4 | 0.0.0 5 | Package for controlling WidowX robots 6 | 7 | 8 | 9 | 10 | jonathan 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | message_generation 39 | 40 | 41 | 42 | 43 | 44 | message_runtime 45 | 46 | 47 | 48 | 49 | catkin 50 | rospy 51 | std_msgs 52 | rospy 53 | std_msgs 54 | rospy 55 | std_msgs 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- 
/bridge_with_digit/widowx_envs/widowx_controller/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from catkin_pkg.python_setup import generate_distutils_setup 3 | 4 | # fetch values from package.xml 5 | setup_args = generate_distutils_setup( 6 | packages=['widowx_controller'], 7 | package_dir={'': 'src'}) 8 | 9 | setup(**setup_args) 10 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/src/widowx_controller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/widowx_controller/src/widowx_controller/__init__.py -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/DisableController.srv: -------------------------------------------------------------------------------- 1 | --- -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/EnableController.srv: -------------------------------------------------------------------------------- 1 | --- -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/GetCartesianPose.srv: -------------------------------------------------------------------------------- 1 | --- 2 | float32[] eep 3 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/GetGripperDesiredState.srv: -------------------------------------------------------------------------------- 1 | --- 2 | float32 des_pos 3 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/GetState.srv: -------------------------------------------------------------------------------- 1 | --- 2 | float32[] joint_angles 3 | float32[] joint_velocities 4 | float32[] cartesian_pose 5 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/GetVRButtons.srv: -------------------------------------------------------------------------------- 1 | --- 2 | int32 handle 3 | int32 a 4 | int32 b 5 | int32 rj 6 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/GotoNeutral.srv: -------------------------------------------------------------------------------- 1 | float32 duration 2 | --- -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/MoveToEEP.srv: -------------------------------------------------------------------------------- 1 | float32[] des_eep 2 | float32 duration 3 | --- -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/MoveToState.srv: -------------------------------------------------------------------------------- 1 | float32[] target_xyz 2 | float32 target_zangle 3 | float32 duration 4 | --- 5 | int32 success -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_controller/srv/OpenGripper.srv: 
--------------------------------------------------------------------------------
1 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/SetGripperPosition.srv:
--------------------------------------------------------------------------------
1 | float32 des_pos
2 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/__init__.py:
--------------------------------------------------------------------------------
1 | name = "widowx_envs"
2 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/base/robot_configs.json:
--------------------------------------------------------------------------------
1 | {
2 |     "wx250s": {
3 |         "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
4 |     },
5 |     "wx250": {
6 |         "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
7 |     },
8 |     "wx200": {
9 |         "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
10 |     }
11 | }
12 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/widowx_envs/policies/__init__.py
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/policies/policy.py:
--------------------------------------------------------------------------------
1 | """ This file defines the base class for the policy. """
2 | import pickle as pkl
3 | import numpy as np
4 |
5 | from widowx_envs.utils.utils import AttrDict, Configurable
6 |
7 |
8 |
9 | class Policy(Configurable):
10 |     """Abstract class for policy."""
11 |     def _default_hparams(self):
12 |         dict = AttrDict(
13 |             ngpu=1,
14 |             gpu_id=0,
15 |         )
16 |         default_dict = super()._default_hparams()
17 |         default_dict.update(dict)
18 |         return default_dict
19 |
20 |     def act(self, *args):
21 |         """
22 |         Args:
23 |             Request the necessary arguments in the subclass definition
24 |             (see the Agent code)
25 |         Returns:
26 |             A dict of outputs D
27 |             - the key 'actions' must map to the action for this time-step
28 |         """
29 |         raise NotImplementedError("Must be implemented in subclass.")
30 |
31 |     def reset(self):
32 |         pass
33 |
34 |     def set_log_dir(self, dir):
35 |         self.traj_log_dir = dir
36 |
37 |
38 | class DummyPolicy(Policy):
39 |     def __init__(self, ag_params, policyparams):
40 |         """ Computes actions from states/observations. """
41 |         pass
42 |
43 |     def act(self, *args):
44 |         return {'actions': None}
45 |
46 |     def reset(self):
47 |         return None
48 |
49 |
50 | class ReplayActions(Policy):
51 |     def __init__(self, ag_params, policyparams):
52 |         """ Computes actions from states/observations.
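        Replays a pre-recorded trajectory: act(t) returns the t-th entry of the policy outputs loaded from <load_file>/policy_out.pkl.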
""" 53 | self._hp = self._default_hparams() 54 | self._override_defaults(policyparams) 55 | self.policy_out = pkl.load(open(self._hp.load_file + '/policy_out.pkl', 'rb')) 56 | self.env = ag_params.env 57 | 58 | def _default_hparams(self): 59 | dict = AttrDict( 60 | load_file="", 61 | type=None, 62 | ) 63 | default_dict = super(Policy, self)._default_hparams() 64 | default_dict.update(dict) 65 | return default_dict 66 | 67 | def act(self, t): 68 | return self.policy_out[t] 69 | 70 | def reset(self): 71 | return None 72 | 73 | 74 | class NullPolicy(Policy): 75 | """ 76 | Returns 0 for all timesteps 77 | """ 78 | def __init__(self, ag_params, policyparams): 79 | self._adim = ag_params['adim'] 80 | self._hp = self._default_hparams() 81 | self._override_defaults(policyparams) 82 | 83 | # def _default_hparams(self): 84 | # default_dict = { 85 | # 'wait_for_user': False 86 | # } 87 | # parent_params = super(NullPolicy, self)._default_hparams() 88 | # for k in default_dict.keys(): 89 | # parent_params.add_hparam(k, default_dict[k]) 90 | # return parent_params 91 | 92 | def act(self): 93 | return {'actions': np.zeros(self._adim)} -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_envs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_envs/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | class Bad_Traj_Exception(Exception): 2 | def __init__(self): 3 | pass 4 | 5 | 6 | class Image_Exception(Exception): 7 | def __init__(self): 8 | pass 9 | 10 | 11 | class Environment_Exception(Exception): 12 | def __init__(self): 13 | pass 14 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_envs/utils/grasp_utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import LinearRegression 2 | import numpy as np 3 | import time 4 | from sklearn.preprocessing import PolynomialFeatures 5 | 6 | 7 | def compute_robot_transformation_matrix(a, b): 8 | lr = LinearRegression(fit_intercept=False).fit(a, b) 9 | return lr.coef_.T 10 | 11 | 12 | def convert_obs_to_image(obs, transpose=False): 13 | print("taking picture...") 14 | image = np.uint8(np.reshape(obs['image'] * 255, (3, 64, 64))) 15 | if transpose: image = np.transpose(image, (1, 2, 0)) 16 | # print("image.shape", image.shape) 17 | return image 18 | 19 | 20 | def rgb_to_robot_coords(rgb_coords, transmatrix): 21 | # add vector of 1s as feature to the pc_coords. 
22 | assert len(rgb_coords.shape) <= 2 23 | if len(rgb_coords.shape) == 1: 24 | rgb_coords = np.array(rgb_coords[None]) 25 | poly = PolynomialFeatures(2) 26 | rgb_coords = poly.fit_transform(rgb_coords) 27 | 28 | if transmatrix is not None: 29 | robot_coords = rgb_coords @ transmatrix 30 | return np.squeeze(robot_coords) 31 | 32 | def get_image_obs(env, image_xyz=None, skip_move_to_neutral=False): 33 | joint_angles = env._controller.get_joint_angles() 34 | if image_xyz is None: 35 | if not skip_move_to_neutral: 36 | env.move_to_neutral(0.5) 37 | # else: 38 | # env.reset() 39 | else: 40 | env.move_to_state(image_xyz, target_zangle=0, duration=0.5) 41 | time.sleep(0.2) # wait for camera to catch up 42 | obs = env.current_obs() 43 | env._controller.set_joint_angles(joint_angles, 0.5) 44 | return obs 45 | 46 | 47 | def get_image(env, transpose=True, image_xyz=None, skip_move_to_neutral=False): 48 | obs = get_image_obs(env, image_xyz, skip_move_to_neutral) 49 | return convert_obs_to_image(obs, transpose=transpose) 50 | 51 | 52 | def execute_reach(env, reach_policy, reachpoint, noise=0.0): 53 | reach_policy.reset(reach_point=reachpoint) 54 | for i in range(6): 55 | action, _ = reach_policy.get_action() 56 | 57 | # noise 58 | noise_dims = 2 59 | noise_stds = [noise] * noise_dims + [0] * (len(action) - noise_dims) 60 | action = np.random.normal(loc=action, scale=noise_stds) 61 | action = np.clip(action, -1.0, 1.0) 62 | # import ipdb; ipdb.set_trace() 63 | obs, _, _, _ = env.step(action) 64 | return obs 65 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_envs/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import moviepy.editor as mpy 3 | import os 4 | import numpy as np 5 | from PIL import Image, ImageDraw 6 | from PIL import Image 7 | 8 | def resize_store(t, target_array, input_array): 9 | target_img_height, target_img_width = target_array.shape[2:4] 10 | 11 | if (target_img_height, target_img_width) == input_array.shape[1:3]: 12 | for i in range(input_array.shape[0]): 13 | target_array[t, i] = input_array[i] 14 | else: 15 | for i in range(input_array.shape[0]): 16 | target_array[t, i] = cv2.resize(input_array[i], (target_img_width, target_img_height), 17 | interpolation=cv2.INTER_AREA) 18 | 19 | 20 | def npy_to_gif(im_list, filename, fps=4): 21 | save_dir = '/'.join(str.split(filename, '/')[:-1]) 22 | 23 | if not os.path.exists(save_dir): 24 | print('creating directory: ', save_dir) 25 | os.makedirs(save_dir) 26 | 27 | clip = mpy.ImageSequenceClip(im_list, fps=fps) 28 | clip.write_gif(filename + '.gif') 29 | 30 | 31 | def npy_to_mp4(im_list, filename, fps=4): 32 | save_dir = '/'.join(str.split(filename, '/')[:-1]) 33 | 34 | if not os.path.exists(save_dir): 35 | print('creating directory: ', save_dir) 36 | os.makedirs(save_dir) 37 | 38 | clip = mpy.ImageSequenceClip(im_list, fps=fps) 39 | clip.write_videofile(filename + '.mp4') 40 | 41 | def draw_text_image(text, background_color=(255,255,255), image_size=(30, 64), dtype=np.float32): 42 | 43 | text_image = Image.new('RGB', image_size[::-1], background_color) 44 | draw = ImageDraw.Draw(text_image) 45 | if text: 46 | draw.text((4, 0), text, fill=(0, 0, 0)) 47 | if dtype == np.float32: 48 | return np.array(text_image).astype(np.float32)/255.
49 | else: 50 | return np.array(text_image) 51 | 52 | 53 | def draw_text_onimage(text, image, color=(255, 0, 0)): 54 | if image.dtype == np.float32: 55 | image = (image*255.).astype(np.uint8) 56 | assert image.dtype == np.uint8 57 | text_image = Image.fromarray(image) 58 | draw = ImageDraw.Draw(text_image) 59 | draw.text((4, 0), text, fill=color) 60 | return np.array(text_image).astype(np.float32)/255. 61 | -------------------------------------------------------------------------------- /bridge_with_digit/widowx_envs/widowx_envs/utils/sync.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Value, Lock 2 | 3 | 4 | class SyncCounter: 5 | def __init__(self, base_value=0): 6 | self._lock = Lock() 7 | self._value = Value('i', base_value) 8 | 9 | @property 10 | def ret_increment(self): 11 | with self._lock: 12 | ret_val = self._value.value 13 | self._value.value += 1 14 | return ret_val 15 | 16 | @property 17 | def value(self): 18 | with self._lock: 19 | ret_val = self._value.value 20 | return ret_val 21 | 22 | 23 | class ManagedSyncCounter(SyncCounter): 24 | def __init__(self, manager, base_value=0): 25 | self._lock, self._value = manager.Lock(), manager.Value('i', base_value) 26 | -------------------------------------------------------------------------------- /media/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/media/teaser.jpg -------------------------------------------------------------------------------- /octo_digit/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git 3 | max-line-length = 88 4 | select = E,F,W,C 5 | ignore=W503, 6 | E203, 7 | E731, 8 | E722, 9 | F841, 10 | E402, 11 | E741, 12 | E501, 13 | C406, 14 | -------------------------------------------------------------------------------- /octo_digit/.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /octo_digit/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.3.0 6 | hooks: 7 | - id: check-yaml 8 | - id: check-ast 9 | - id: check-added-large-files 10 | exclude: ^examples/ 11 | - id: check-case-conflict 12 | - id: check-merge-conflict 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | - id: detect-private-key 16 | - id: debug-statements 17 | exclude: ^experiments/ 18 | - repo: https://github.com/psf/black 19 | rev: 22.10.0 20 | hooks: 21 | - id: black 22 | exclude: ^experiments/ 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 6.1.0 25 | hooks: 26 | - id: flake8 27 | exclude: ^experiments/ 28 | - repo: https://github.com/pycqa/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | exclude: ^experiments/ 33 | args: ["--profile", "black", "--src", "octo", "--src", "experiments"] 34 | - repo: https://github.com/srstevenson/nb-clean 35 | rev: 3.1.0 36 | 
hooks: 37 | - id: nb-clean 38 | args: 39 | - --remove-empty-cells 40 | - --preserve-cell-outputs 41 | -------------------------------------------------------------------------------- /octo_digit/.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /octo_digit/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Robotic AI & Learning Lab Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /octo_digit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/__init__.py -------------------------------------------------------------------------------- /octo_digit/docs/assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/docs/assets/teaser.jpg -------------------------------------------------------------------------------- /octo_digit/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/eval/__init__.py -------------------------------------------------------------------------------- /octo_digit/eval/decode_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | gen_modes = [("visual",), ("tactile",), ("visual", "tactile")] 4 | # gen_modes = [('audio',)] 5 | csv_modes = [",".join(modality_tuple) for modality_tuple in gen_modes] 6 | gen_mode_lang_names = [ 7 | "all_lang_4", 8 | ] 9 | modality_obs_keys = { 10 | "visual": ["image_primary", "image_wrist"], 11 | "tactile": ["image_digit_right", "image_digit_left"], 12 | "audio": [ 13 | "mel_spectro", 14 | ], 15 | } 16 | 17 | modality_specific_keys = [] 18 | for v in modality_obs_keys.values(): 19 | modality_specific_keys.extend(v) 20 | modality_specific_keys = set(modality_specific_keys) 21 | 22 | includes = ["pad_mask_dict", "task_completed", "timestep", "timestep_pad_mask"] 23 | 24 | WINDOW_SIZE = 2 25 | pad_mask = 
np.array([[True for _ in range(WINDOW_SIZE)]])[0] 26 | -------------------------------------------------------------------------------- /octo_digit/eval/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/eval/envs/__init__.py -------------------------------------------------------------------------------- /octo_digit/eval/eval_requirements.txt: -------------------------------------------------------------------------------- 1 | funcsigs 2 | opencv-python 3 | pyquaternion 4 | librosa 5 | edgeml @ git+https://github.com/youliangtan/edgeml.git 6 | -------------------------------------------------------------------------------- /octo_digit/eval/recursive_dict_print.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | MAX_KEY_LEN = 20 4 | INDENT_SIZE = MAX_KEY_LEN + 4 5 | INDENT = "".join([" " for _ in range(INDENT_SIZE)]) 6 | 7 | 8 | def recursive_dict_print(dictionary: dict, prefix=""): 9 | for key, val in dictionary.items(): 10 | key = key[:MAX_KEY_LEN] 11 | if isinstance(val, dict): 12 | print(f"{prefix}{key}") 13 | new_prefix = prefix + INDENT 14 | recursive_dict_print(val, new_prefix) 15 | else: 16 | indent = "".join([" " for _ in range(INDENT_SIZE - len(key))]) 17 | try: 18 | print(f"{prefix}{key}:{indent}{val.shape} {val.dtype}") 19 | except AttributeError: 20 | print(f"{prefix}{key}:{indent} {type(val)}") 21 | -------------------------------------------------------------------------------- /octo_digit/eval/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="eval", packages=["envs"]) 4 | -------------------------------------------------------------------------------- /octo_digit/mem_lims.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | window_sizes=(8 10 12 13 14 15 16) 3 | batch_sizes=( 32 64 128 200 256 ) 4 | for i in "${batch_sizes[@]}" 5 | do 6 | for j in "${window_sizes[@]}" 7 | do 8 | python scripts/finetune_josh.py --config=scripts/configs/josh_finetune_config.py:"None" --name=mem_test --o_window_size="${j}" --o_batch_size="${i}" --o_steps=2 --debug=True --mode="${1}" --log_file="${1}_log.txt" 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /octo_digit/octo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/data/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/data/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/data/utils/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/data/utils/goal_relabeling.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 3 | Each function should add entries to the "task" dict. 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import tensorflow as tf 9 | 10 | from octo.data.utils.data_utils import tree_merge 11 | 12 | 13 | def uniform(traj: dict, max_goal_distance: Optional[int] = None) -> dict: 14 | """ 15 | Relabels with a true uniform distribution over future states. 16 | Optionally caps goal distance. 17 | """ 18 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 19 | 20 | # select a random future index for each transition i in the range [i, traj_len) 21 | rand = tf.random.uniform([traj_len]) 22 | low = tf.cast(tf.range(traj_len), tf.float32) 23 | if max_goal_distance is not None: 24 | high = tf.cast( 25 | tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32 26 | ) 27 | else: 28 | high = tf.cast(traj_len, tf.float32) 29 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 30 | 31 | # sometimes there are floating-point errors that cause an out-of-bounds 32 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 33 | 34 | # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from 35 | # "observation" and "task" properly) 36 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 37 | traj["task"] = tree_merge(traj["task"], goal) 38 | 39 | return traj 40 | -------------------------------------------------------------------------------- /octo_digit/octo/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/model/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/model/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/model/components/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/model/components/base.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | from octo.utils.typing import Sequence 6 | 7 | 8 | @flax.struct.dataclass 9 | class TokenGroup: 10 | """A group of tokens that have semantic meaning together (e.g. 
the tokens for a single observation) 11 | 12 | Attributes: 13 | tokens: jax.Array of shape (..., n_tokens, token_dim) 14 | mask: jax.Array of shape (..., n_tokens) indicating which tokens are valid (1) vs padding (0) 15 | """ 16 | 17 | tokens: jax.typing.ArrayLike 18 | mask: jax.typing.ArrayLike 19 | 20 | @classmethod 21 | def create( 22 | cls, tokens: jax.typing.ArrayLike, mask: jax.typing.ArrayLike = None, **kwargs 23 | ): 24 | if mask is None: 25 | mask = jnp.ones(tokens.shape[:-1]) 26 | assert mask.ndim == tokens.ndim - 1 27 | return cls(tokens, mask, **kwargs) 28 | 29 | @classmethod 30 | def concatenate(cls, group_list: Sequence["TokenGroup"], axis=-2): 31 | data = jnp.concatenate([t.tokens for t in group_list], axis=axis) 32 | mask = jnp.concatenate([t.mask for t in group_list], axis=axis + 1) 33 | return cls(data, mask) 34 | -------------------------------------------------------------------------------- /octo_digit/octo/model/components/film_conditioning_layer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | 7 | class FilmConditioning(nn.Module): 8 | @nn.compact 9 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray): 10 | """Applies FiLM conditioning to a convolutional feature map. 11 | 12 | Args: 13 | conv_filters: A tensor of shape [batch_size, height, width, channels]. 14 | conditioning: A tensor of shape [batch_size, conditioning_size]. 15 | 16 | Returns: 17 | A tensor of shape [batch_size, height, width, channels]. 18 | """ 19 | projected_cond_add = nn.Dense( 20 | features=conv_filters.shape[-1], 21 | kernel_init=nn.initializers.zeros, 22 | bias_init=nn.initializers.zeros, 23 | )(conditioning) 24 | projected_cond_mult = nn.Dense( 25 | features=conv_filters.shape[-1], 26 | kernel_init=nn.initializers.zeros, 27 | bias_init=nn.initializers.zeros, 28 | )(conditioning) 29 | 30 | projected_cond_add = projected_cond_add[:, None, None, :] 31 | projected_cond_mult = projected_cond_mult[:, None, None, :] 32 | 33 | return conv_filters * (1 + projected_cond_add) + projected_cond_mult 34 | -------------------------------------------------------------------------------- /octo_digit/octo/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="octo", packages=["data", "model", "utils"]) -------------------------------------------------------------------------------- /octo_digit/octo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/utils/__init__.py -------------------------------------------------------------------------------- /octo_digit/octo/utils/fuse_constants.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | 4 | modality_combination_order = ['simple', '', 'visual', 'tactile', 'audio', 'visual,tactile', 'visual,audio', 'tactile,audio', 'visual,tactile,audio'] 5 | _modality_combinations = [('visual',), ('tactile',), ('audio',), ('visual', 'tactile'), ('visual', 'audio'), ('tactile', 'audio'), ('visual', 'tactile', 'audio')] 6 | modality_combinations = [','.join(combination) for combination in 
_modality_combinations] 7 | fuse_loss_modal_indices = {i: combination for i, combination in enumerate(modality_combination_order) if combination != ''} 8 | contrastive_indices = fuse_loss_modal_indices 9 | generative_indices = {k: v for k, v in contrastive_indices.items() if v != 'simple'} 10 | name_to_index_generative = {v: k for k, v in generative_indices.items()} 11 | 12 | modality_to_observation_keys = { 13 | 'visual': ['image_primary', 'image_wrist'], 14 | 'tactile': ['image_digit_right', 'image_digit_right_background', 'image_digit_left', 'image_digit_left_background'], 15 | 'audio': ['mic', 'mel_spectro'] 16 | } 17 | modality_specific_keys = [] 18 | for v in modality_to_observation_keys.values(): 19 | modality_specific_keys.extend(v) 20 | modality_specific_keys = set(modality_specific_keys) 21 | nonspecific_keys= ['task_completed', 'timestep', 'modality_idx'] 22 | 23 | modality_to_observation_keys['simple'] = list(modality_specific_keys) 24 | 25 | 26 | def create_fuse_modal_masks(example_obs): 27 | modal_masks = {} 28 | pad_mask_dict = example_obs['pad_mask_dict'] 29 | for i, combination in fuse_loss_modal_indices.items(): 30 | combination_mask = {} 31 | for modality in combination.split(','): 32 | for obs_key in modality_to_observation_keys[modality]: 33 | if obs_key in pad_mask_dict: 34 | combination_mask[obs_key] = jnp.ones_like(pad_mask_dict[obs_key]) 35 | for obs_key in nonspecific_keys: 36 | if obs_key in pad_mask_dict: 37 | combination_mask[obs_key] = jnp.ones_like(pad_mask_dict[obs_key]) 38 | for obs_key in pad_mask_dict: 39 | if obs_key not in combination_mask: 40 | combination_mask[obs_key] = jnp.zeros_like(pad_mask_dict[obs_key]) 41 | modal_masks[i] = combination_mask 42 | assert modal_masks[i].keys() == pad_mask_dict.keys() 43 | return modal_masks 44 | 45 | 46 | def create_batch(batch, observation_masks, fuse_modal_masks, modality_combination_index: int): 47 | if observation_masks is None: 48 | batch['observation']['pad_mask_dict'] = fuse_modal_masks[modality_combination_index] 49 | else: 50 | batch['observation']['pad_mask_dict'] = jax.tree_map( 51 | lambda true_mask, fuse_mask: jnp.logical_and(true_mask, fuse_mask), 52 | observation_masks, 53 | fuse_modal_masks[modality_combination_index], 54 | ) 55 | return batch -------------------------------------------------------------------------------- /octo_digit/octo/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | MAX_KEY_LEN = 15 2 | INDENT_SIZE = MAX_KEY_LEN + 4 3 | INDENT = ''.join([' ' for _ in range(INDENT_SIZE)]) 4 | HEADING_SEPARATOR = "############################################" 5 | 6 | def print_separator(log_func=print): 7 | log_func(HEADING_SEPARATOR) 8 | 9 | 10 | def pretty_print_dict(dictionary, prefix="", log_func=print, pad_with_newlines=True): 11 | lines_to_output = [] 12 | def _pretty_print(dictionary, prefix=""): 13 | for key, val in dictionary.items(): 14 | key = key[:MAX_KEY_LEN] 15 | if isinstance(val, dict): 16 | lines_to_output.append(f'{prefix}{key}') 17 | _pretty_print(val, prefix + INDENT) 18 | else: 19 | indent = ' ' * (INDENT_SIZE - len(key)) 20 | lines_to_output.append(f'{prefix}{key}:{indent}{val}') 21 | _pretty_print(dictionary, prefix) 22 | if pad_with_newlines: 23 | lines_to_output = [''] + lines_to_output + [''] 24 | log_func('\n'.join(lines_to_output)) 25 | 26 | 27 | def append_identity_to_metrics(metrics: dict, identity_suffix: str) -> dict: 28 | processed_metrics = {} 29 | for key, val in metrics.items(): 30 | 
processed_metrics[f'{key}_{identity_suffix}'] = val 31 | return processed_metrics -------------------------------------------------------------------------------- /octo_digit/octo/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Sequence, Union 2 | 3 | import jax 4 | 5 | PRNGKey = jax.Array 6 | PyTree = Union[jax.typing.ArrayLike, Mapping[str, "PyTree"]] 7 | Config = Union[Any, Mapping[str, "Config"]] 8 | Params = Mapping[str, PyTree] 9 | Perturbations = Mapping[str, PyTree] 10 | JaxArray = jax.typing.ArrayLike 11 | Data = Mapping[str, PyTree] 12 | Shape = Sequence[int] 13 | Dtype = jax.typing.DTypeLike 14 | -------------------------------------------------------------------------------- /octo_digit/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "octo_digit" 3 | version = "0.0.1" 4 | description = "" 5 | readme = "README.md" 6 | requires-python = "==3.10.*" 7 | dependencies = [ 8 | "gym>=0.26", 9 | "numpy==1.24.3", 10 | "ml_dtypes==0.2.0", 11 | "chex==0.1.85", 12 | "optax==0.1.5", 13 | "tensorflow_probability==0.23.0", 14 | "tensorflow==2.15.0", 15 | "jax==0.4.20", 16 | "distrax==0.1.5", 17 | "flax==0.7.5", 18 | "ml_collections>=0.1.0", 19 | "tqdm>=4.60.0", 20 | "absl-py>=0.12.0", 21 | "wandb>=0.12.14", 22 | "einops>=0.6.1", 23 | "imageio>=2.31.1", 24 | "moviepy>=1.0.3", 25 | "pre-commit==3.3.3", 26 | "transformers>=4.34.1", 27 | "tensorflow_hub>=0.14.0", 28 | "tensorflow_text>=2.13.0", 29 | "tensorflow_datasets>=4.9.0", 30 | "tensorflow_graphics==2021.12.3", 31 | "dlimp@git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972", 32 | "plotly>=5.16.1", 33 | "matplotlib", 34 | "scipy==1.12.0", 35 | "funcsigs", 36 | "opencv-python", 37 | "pyquaternion", 38 | "librosa", 39 | "edgeml @ git+https://github.com/youliangtan/edgeml.git", 40 | "octo", 41 | "eval" 42 | ] 43 | 44 | [project.optional-dependencies] 45 | tpu = [ 46 | "jax[tpu]==0.4.20", 47 | "libtpu-nightly" 48 | ] 49 | gpu = [ 50 | "jax[cuda11_pip]==0.4.20", 51 | ] 52 | 53 | [tool.uv] 54 | find-links = [ 55 | "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html", 56 | "https://storage.googleapis.com/jax-releases/libtpu_releases.html", 57 | ] 58 | override-dependencies = ["scipy==1.12.0"] 59 | prerelease = "allow" 60 | conflicts = [ 61 | [ 62 | { extra = "tpu" }, 63 | { extra = "gpu" }, 64 | ], 65 | ] 66 | 67 | 68 | [tool.uv.sources] 69 | octo = { path = "./octo", editable = true } 70 | eval = { path = "./eval", editable = true } 71 | 72 | [tool.black] 73 | # https://github.com/psf/black 74 | line-length = 88 75 | target-version = ["py310"] 76 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" 77 | 78 | [tool.isort] 79 | profile = "black" 80 | line_length = 88 81 | force_sort_within_sections = "True" 82 | order_by_type = "False" 83 | -------------------------------------------------------------------------------- /octo_digit/requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy == 1.24.3 3 | ml_dtypes == 0.2.0 4 | chex == 0.1.85 5 | optax == 0.1.5 6 | tensorflow_probability == 0.23.0 7 | tensorflow == 2.15.0 8 | jax == 0.4.20 9 | distrax == 0.1.5 10 | flax == 0.7.5 11 | ml_collections >= 0.1.0 12 | tqdm >= 4.60.0 13 | absl-py >= 0.12.0 14 | scipy >= 1.6.0 15 | wandb >= 0.12.14 16 | einops >= 0.6.1 17 | imageio >= 2.31.1 18 | moviepy
>= 1.0.3 19 | pre-commit == 3.3.3 20 | transformers >= 4.34.1 21 | tensorflow_hub >= 0.14.0 22 | tensorflow_text >= 2.13.0 23 | tensorflow_datasets == 4.9.2 24 | tensorflow_graphics == 2021.12.3 25 | dlimp @ git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972 26 | plotly >= 5.16.1 27 | matplotlib 28 | scipy==1.12.0 29 | -------------------------------------------------------------------------------- /octo_digit/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="octo_digit", packages=["octo", "eval"]) 4 | -------------------------------------------------------------------------------- /palivla_digit/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .venv 3 | notebooks 4 | **/*.mp4 5 | *.ipynb 6 | wandb/ 7 | *.egg-info 8 | 9 | /trained_tokenizers 10 | /checkpoints 11 | models 12 | -------------------------------------------------------------------------------- /palivla_digit/.python-version: -------------------------------------------------------------------------------- 1 | 3.11 -------------------------------------------------------------------------------- /palivla_digit/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | At this time we do not plan to accept non-trivial contributions. The main 4 | purpose of this codebase is to allow the community to reproduce results from our 5 | publications. 6 | 7 | You are however free to start a fork of the project for your purposes as 8 | permitted by the license. 9 | 10 | ## Contributor License Agreement 11 | 12 | Contributions to this project must be accompanied by a Contributor License 13 | Agreement (CLA). You (or your employer) retain the copyright to your 14 | contribution; this simply gives us permission to use and redistribute your 15 | contributions as part of the project. Head over to 16 | <https://cla.developers.google.com/> to see your current agreements on file or 17 | to sign a new one. 18 | 19 | You generally only need to submit a CLA once, so if you've already submitted one 20 | (even if it was for a different project), you probably don't need to do it 21 | again. 22 | 23 | ## Community Guidelines 24 | 25 | This project follows 26 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 27 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/common_fewshot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Most common few-shot eval configuration.""" 16 | 17 | import ml_collections as mlc 18 | 19 | 20 | def get_fewshot_lsr(target_resolution=224, resize_resolution=256, 21 | runlocal=False, **kw): 22 | """Returns a standard-ish fewshot eval configuration.""" 23 | kw.setdefault('representation_layer', 'pre_logits') 24 | kw.setdefault('shots', (1, 5, 10, 25)) 25 | kw.setdefault('l2_reg', 2.0 ** 10) 26 | kw.setdefault('num_seeds', 3) 27 | kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/ 28 | 29 | # Backward-compatible default: 30 | if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long 31 | kw['log_steps'] = 25_000 32 | 33 | config = mlc.ConfigDict(kw) 34 | config.type = 'fewshot_lsr' 35 | config.datasets = { 36 | 'caltech': ('caltech101', 'train', 'test'), # copybara:strip 37 | 'cars': ('cars196:2.1.0', 'train', 'test'), 38 | 'cifar100': ('cifar100', 'train', 'test'), 39 | 'dtd': ('dtd', 'train', 'test'), 40 | # The first 65000 ImageNet samples have at least 30 shots per any class. 41 | # Commented out by default because needs manual download. 42 | # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'), 43 | 'pets': ('oxford_iiit_pet', 'train', 'test'), 44 | 'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'), 45 | } if not runlocal else { 46 | 'pets': ('oxford_iiit_pet', 'train', 'test'), 47 | } 48 | config.pp_train = (f'decode|resize({resize_resolution})|' 49 | f'central_crop({target_resolution})|' 50 | f'value_range(-1,1)|keep("image", "label")') 51 | config.pp_eval = (f'decode|resize({resize_resolution})|' 52 | f'central_crop({target_resolution})|' 53 | f'value_range(-1,1)|keep("image", "label")') 54 | config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)] 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/cappa/README.md: -------------------------------------------------------------------------------- 1 | # Image Captioners Are Scalable Vision Learners Too 2 | 3 | *by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915) 4 | 5 | ![CapPa Architecture](./cappa_architecture.png) 6 | 7 | This directory contains a config for training a CapPa model from scratch. 8 | Note that most models in the paper were trained on a proprietary dataset 9 | (WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/). 10 | 11 | By default, this config trains on COCO captions as this data set is readily 12 | available in [TFDS](https://www.tensorflow.org/datasets) without manual steps. 13 | This is not meant to produce a meaningful model, but 14 | provides a way for the user to run the config out of the box. Please update the 15 | config with a TFDS-wrapped variant of your favorite image/text data set to 16 | train capable models.
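As a quick sanity check before editing the config, one can peek at a candidate TFDS image/text dataset to confirm its feature names (a minimal sketch using `coco_captions`, which ships with TFDS; feature names vary across datasets):

```
import tensorflow_datasets as tfds

# Load one example and print its feature names; the config's input
# section must reference the dataset's actual image/text keys.
ds = tfds.load("coco_captions", split="train")
for ex in ds.take(1):
    print(ex.keys())  # e.g. dict_keys(['image', 'captions', ...])
```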
17 | 18 | After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows 19 | 20 | ``` 21 | python -m big_vision.trainers.proj.cappa.generative \ 22 | --config big_vision/configs/proj/cappa/pretrain.py \ 23 | --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` 24 | ``` 25 | 26 | To run the Cap baseline (autoregressive captioning without parallel prediction), 27 | set `config.model.masked_pred_prob = 0.0`. 28 | 29 | ### Citation 30 | ``` 31 | @inproceedings{tschannen2023image, 32 | title={Image Captioners Are Scalable Vision Learners Too}, 33 | author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas}, 34 | booktitle={Neural Information Processing Systems (NeurIPS)}, 35 | year={2023} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/cappa/cappa_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/cappa/cappa_architecture.png -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/distill/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge distillation: A good teacher is patient and consistent 2 | *by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov* 3 | 4 | ## Introduction 5 | We publish all teacher models, and configurations for the main experiments of 6 | the paper, as well as training logs and student models. 7 | 8 | Please read the main [big_vision README](/README.md) to learn how to run 9 | configs, and remember that each config file contains an example invocation in 10 | the top-level comment. 11 | 12 | ## Results 13 | 14 | We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing) 15 | of a few runs that we reproduced on Cloud. 16 | 17 | ### ImageNet-1k 18 | 19 | The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our 20 | distillation runs on ImageNet-1k reported in Figures 1 and 5(left) and the first 21 | row of Table1. 22 | 23 | We release both student and teacher models: 24 | 25 | | Model | Download link | Resolution | ImageNet top-1 acc. (paper) | 26 | | :--- | :---: | :---: | :---: | 27 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 | 28 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 | 29 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 | 30 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 | 31 | 32 | ### Flowers/Pet/Food/Sun 33 | 34 | The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and 35 | [bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the 36 | distillation runs on these datasets and shown in Figures 3,4,9-12, and Table4. 37 | 38 | While our open-source release does not currently support doing hyper-parameter 39 | sweeps, we still provide an example of the sweeps at the end of the configs 40 | for reference. 
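The released checkpoints linked above are plain NumPy `.npz` archives, so they can be inspected without any big_vision code; a minimal sketch (assumes network access; parameter names depend on the BiT model definition):

```
import io
import urllib.request

import numpy as np

# Download a released student checkpoint and list a few parameter shapes.
url = "https://storage.googleapis.com/bit_models/distill/R50x1_160.npz"
with urllib.request.urlopen(url) as f:
    params = np.load(io.BytesIO(f.read()))
for name in list(params.files)[:5]:
    print(name, params[name].shape)
```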
41 | 42 | ### Teacher models 43 | Links to all teacher models we used can be found in [common.py](common.py). 44 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/distill/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Most common teachers for distillation.""" 16 | 17 | # pylint: disable=line-too-long 18 | inits = { # pylint: disable=duplicate-key Internally, we override some paths for convenience. 19 | 'BiT-M R152x2 imagenet2012 ic224': 'gs://bit_models/distill/R152x2_T_224.npz', 20 | 'BiT-M R152x2 imagenet2012 rc384': 'gs://bit_models/distill/R152x2_T_384.npz', 21 | 'BiT-M R152x2 flowers rc128': 'gs://bit_models/distill/R152x2_T_flowers128.npz', 22 | 'BiT-M R152x2 pet rc128': 'gs://bit_models/distill/R152x2_T_pet128.npz', 23 | 'BiT-M R152x2 food rc128': 'gs://bit_models/distill/R152x2_T_food128.npz', 24 | 'BiT-M R152x2 sun rc128': 'gs://bit_models/distill/R152x2_T_sun128.npz', 25 | 26 | } 27 | # pylint: enable=line-too-long 28 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/flexivit/timing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long,missing-function-docstring 16 | r"""A config to run timing for FlexiViT (only inference, no I/O etc.). 17 | 18 | big_vision.tools.eval_only \ 19 | --config big_vision/configs/proj/flexivit/timing.py \ 20 | --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ 21 | --config.total_epochs 90 22 | """ 23 | 24 | from ml_collections import ConfigDict 25 | 26 | 27 | def get_config(): 28 | c = ConfigDict() 29 | 30 | shape = (240, 240, 3) 31 | c.batch_size = 8 # swept 32 | c.init_shapes = [(1, *shape)] 33 | c.representation_layer = 'pre_logits' 34 | 35 | # Creating complete model using all params, the sweep will go over variants. 
36 | c.model_name = 'xp.flexivit.vit' 37 | c.model = dict( 38 | variant='B', 39 | pool_type='tok', 40 | patch_size=(10, 10), # Like deit@384 41 | seqhw=(24, 24), 42 | ) 43 | c.num_classes = 0 44 | 45 | c.evals = {} 46 | c.evals.timing = dict( 47 | type='timing', 48 | input_shapes=[shape], 49 | timing=True, 50 | pred_kw=dict(outputs=('pre_logits',)), 51 | ) 52 | 53 | return c -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/givt/givt_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/givt/givt_overview.png -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/paligemma/paligemma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/paligemma/paligemma.png -------------------------------------------------------------------------------- /palivla_digit/big_vision/configs/proj/paligemma/transfers/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Common things across all transfer configs.""" 16 | 17 | 18 | TOKENIZER = 'gemma(tokensets=("loc", "seg"))' 19 | 20 | 21 | def tok(**kw): 22 | """Creates the tokenization preprocessing string.""" 23 | # Single entry point so that it's consistent everywhere and easier to switch. 24 | kw.setdefault('model', TOKENIZER) 25 | kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items()) 26 | return f'tok({kw})' 27 | 28 | 29 | def combine_and_keep_train(text_len, before=(), sep='\n'): 30 | return '|'.join([ 31 | *before, 32 | tok(key='prefix', bos='yes'), 33 | tok(key='suffix', eos='yes'), 34 | tok(key='septok', text=sep), 35 | # If masks confuse you, see (internal link) 36 | 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])', # pylint: disable=line-too-long 37 | # For training, we +1 since the trainer removes EOS. 38 | f'tolen({text_len+1}, pad_value=0, key="text")', # Value doesn't matter. 39 | f'tolen({text_len+1}, pad_value=1, key="mask_ar")', 40 | f'tolen({text_len+1}, pad_value=0, key="mask_loss")', 41 | 'keep("image", "text", "mask_ar", "mask_loss")', 42 | ]) 43 | 44 | 45 | def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'): 46 | return '|'.join([ 47 | *before, 48 | # Same as training, except that suffix is now the empty string. 
49 | # Meaning, we create text as [prefix separator pad], 50 | # and the mask accordingly as [0 0 1] (with repeats of respective lengths) 51 | tok(key='prefix', bos='yes'), 52 | tok(key='septok', text=sep), 53 | # At eval time, there can be also a suffix key in the data. If so it is 54 | # tokenized without EOS and decoding will continue from it. 55 | 'setdefault("suffix", "")', 56 | tok(key='suffix', eos='no'), 57 | # If masks confuse you, see (internal link) 58 | 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long 59 | f'tolen({text_len}, pad_value=0, key="text")', # value doesn't matter. 60 | f'tolen({text_len}, pad_value=1, key="mask_ar")', 61 | f'tolen({text_len}, pad_value=0, key="mask_input")', 62 | # And we need to keep everything that makes our evaluator happy. 63 | 'keep(' + ', '.join(f'"{x}"' for x in ( 64 | 'image', 'text', 'mask_ar', 'mask_input') + tuple(keep)) + ')', 65 | ]) 66 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/evaluators/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/evaluators/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluator for the classification task.""" 16 | # pylint: disable=consider-using-from-import 17 | 18 | import functools 19 | 20 | from big_vision.evaluators import common 21 | import big_vision.utils as u 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | 26 | # Temporary global flag to facilitate backwards compatibility. Will be removed 27 | # by the end of year 2023. 28 | API = 'jit' 29 | 30 | 31 | # To avoid re-compiling the function for every new instance of the same 32 | # evaluator on a different dataset! 33 | @functools.cache 34 | def get_eval_fn(predict_fn, loss_name): 35 | """Produces eval function, also applies jit.""" 36 | @jax.jit 37 | def _eval_fn(train_state, batch, labels, mask): 38 | logits, *_ = predict_fn(train_state, batch) 39 | 40 | # Ignore the entries with all zero labels for evaluation. 41 | mask *= labels.max(axis=1) 42 | 43 | loss = getattr(u, loss_name)( 44 | logits=logits, labels=labels, reduction=False) 45 | loss = jnp.sum(loss * mask) 46 | 47 | top1_idx = jnp.argmax(logits, axis=1) 48 | # Extracts the label at the highest logit index for each image.
49 | top1_correct = jnp.take_along_axis( 50 | labels, top1_idx[:, None], axis=1)[:, 0] 51 | ncorrect = jnp.sum(top1_correct * mask) 52 | nseen = jnp.sum(mask) 53 | return ncorrect, loss, nseen 54 | return _eval_fn 55 | 56 | 57 | class Evaluator: 58 | """Classification evaluator.""" 59 | 60 | def __init__(self, predict_fn, loss_name, label_key='labels', **kw): 61 | self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) 62 | self.eval_fn = get_eval_fn(predict_fn, loss_name) 63 | self.label_key = label_key 64 | 65 | def run(self, train_state): 66 | """Computes all metrics.""" 67 | ncorrect, loss, nseen = 0, 0, 0 68 | for _, batch in zip(range(self.steps), self.get_data_iter()): 69 | labels, mask = batch.pop(self.label_key), batch.pop('_mask') 70 | batch_ncorrect, batch_losses, batch_nseen = jax.device_get( 71 | self.eval_fn(train_state, batch, labels, mask)) 72 | ncorrect += batch_ncorrect 73 | loss += batch_losses 74 | nseen += batch_nseen 75 | yield ('prec@1', ncorrect / nseen) 76 | yield ('loss', loss / nseen) 77 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/evaluators/proj/cappa/perplexity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluator for perplexity of a model.""" 16 | from big_vision.evaluators import mean 17 | import big_vision.utils as u 18 | import jax.numpy as jnp 19 | 20 | 21 | # Temporary global flag to facilitate backwards compatibility. Will be removed 22 | # by the end of year 2023. 23 | API = 'jit' 24 | 25 | 26 | def perplexity(predict_fn, normalize_by_seqlen): 27 | """Returns a function that computes perplexity.""" 28 | 29 | def _perplexity_fn(train_state, batch, pad_token=0, **kw): 30 | logits, _ = predict_fn(train_state, batch, **kw) 31 | 32 | # Ignore perplexity on the padding label. 33 | weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32) 34 | if batch.get('label_masks') is not None: 35 | weights = weights * batch['label_masks'] 36 | 37 | losses = u.weighted_softmax_xent( 38 | logits=logits, labels=batch['labels'], 39 | weights=weights, label_smoothing=0.0, 40 | reduction=False, normalize=normalize_by_seqlen) 41 | 42 | return {'perplexity': losses} 43 | return _perplexity_fn 44 | 45 | 46 | class Evaluator(mean.Evaluator): 47 | """Perplexity evaluator.""" 48 | 49 | def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw): 50 | super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw) 51 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/evaluators/proj/cappa/scoring_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Big Vision Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Scoring classifier. 16 | 17 | This one is based on a generative perspective for image classification. 18 | Here we input the image as well as all the tokenized labels to compute their 19 | perplexity and select the one with minimum loss as the prediction. 20 | """ 21 | import functools 22 | from big_vision.datasets.imagenet import class_names as imagenet_class_names 23 | from big_vision.evaluators import mean 24 | from big_vision.pp import builder as pp_builder 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | # Temporary global flag to facilitate backwards compatibility. Will be removed 29 | # by the end of year 2023. 30 | API = "jit" 31 | 32 | 33 | CLASS_NAMES = { 34 | "imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, 35 | } 36 | 37 | 38 | # As a separate function to cache result across instances. 39 | @functools.lru_cache(maxsize=None) 40 | def get_classes(dataset_name, pp_txt): 41 | """Load the class label strings and tokenize them using pp_txt.""" 42 | pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False) 43 | return np.array([pp_fn({"label": name})["labels"] 44 | for name in CLASS_NAMES[dataset_name]]) 45 | 46 | 47 | def scoring(predict_fn, tokenized_labels): 48 | 49 | def _scoring_fn(train_state, batch, *a, **kw): 50 | batch = {"_label_tokens": tokenized_labels, **batch} 51 | scores = predict_fn(train_state, batch, *a, **kw) 52 | predictions = jnp.argmax(scores, axis=-1) 53 | return {"prec@1": predictions == batch["label"]} 54 | 55 | return _scoring_fn 56 | 57 | 58 | class Evaluator(mean.Evaluator): 59 | """Evaluator for classification accuracy based on scoring all classes.""" 60 | 61 | def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw): 62 | cls_tokens = get_classes(data["name"], pp_txt) 63 | super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw) 64 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/evaluators/proj/image_text/prompt_engineering_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 
15 | """Tests for prompt_engineering."""
16 | 
17 | from absl.testing import absltest
18 | from big_vision.evaluators.proj.image_text import prompt_engineering
19 | 
20 | 
21 | class PromptEngineeringTest(absltest.TestCase):
22 | 
23 |   def test_canonicalize_text(self):
24 |     self.assertEqual(prompt_engineering.canonicalize_text("test_test"), "test test")
25 |     self.assertEqual(
26 |         prompt_engineering.canonicalize_text("test___test"), "test test")
27 |     self.assertEqual(prompt_engineering.canonicalize_text("test"), "test")
28 |     self.assertEqual(prompt_engineering.canonicalize_text("test."), "test")
29 |     self.assertEqual(prompt_engineering.canonicalize_text(" test "), "test")
30 |     self.assertEqual(
31 |         prompt_engineering.canonicalize_text("test\ntest"), "test test")
32 |     self.assertEqual(
33 |         prompt_engineering.canonicalize_text("test test"), "test test")
34 |     self.assertEqual(prompt_engineering.canonicalize_text("test {}"), "test")
35 |     self.assertEqual(
36 |         prompt_engineering.canonicalize_text(
37 |             "test {}", keep_punctuation_exact_string="{}"), "test {}")
38 |     self.assertEqual(
39 |         prompt_engineering.canonicalize_text(
40 |             " test {}...", keep_punctuation_exact_string="{}"), "test {}")
41 |     self.assertEqual(
42 |         prompt_engineering.canonicalize_text(
43 |             "test {} {} {}", keep_punctuation_exact_string="{}"),
44 |         "test {} {} {}")
45 | 
46 | 
47 | if __name__ == "__main__":
48 |   absltest.main()
49 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/paligemma/perplexity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator for perplexity of a model."""
16 | import functools
17 | 
18 | from big_vision.evaluators import mean
19 | import big_vision.utils as u
20 | import jax.numpy as jnp
21 | 
22 | 
23 | # Temporary global flag to facilitate backwards compatibility. Will be removed
24 | # by the end of year 2023.
25 | API = 'jit'
26 | 
27 | 
28 | # Cache the function such that it won't always recompile (in mean evaluator).
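# (functools.cache keys on the call arguments, so this helps only when the
# evaluator is re-built with the same predict_fn object and options; all three
# arguments are hashable here. A rough sketch of the effect, with a
# hypothetical predict_fn:
#   fn_a = perplexity(predict_fn, 'labels', True)
#   fn_b = perplexity(predict_fn, 'labels', True)
#   assert fn_a is fn_b  # cache hit: same closure, so jit won't retrace.)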
29 | @functools.cache
30 | def perplexity(
31 |     predict_fn, key='labels', shift_labels=True):
32 |   """Returns a function that computes perplexity."""
33 | 
34 |   def _perplexity_fn(train_state, batch, **kw):
35 |     logits, _ = predict_fn(train_state, batch, **kw)
36 | 
37 |     labels = batch[key]
38 |     weights = batch.get('mask_loss', jnp.ones_like(labels))
39 | 
40 |     if shift_labels:
41 |       labels = labels[:, 1:]
42 |       weights = weights[:, 1:]
43 | 
44 |     losses = u.weighted_softmax_xent(
45 |         logits=logits, labels=labels, weights=weights,
46 |         reduction=False, normalize=False)
47 |     normalizer = jnp.clip(weights.sum(axis=1), 2e-38)
48 | 
49 |     return {'sum': losses, 'avg': losses / normalizer}
50 |   return _perplexity_fn
51 | 
52 | 
53 | class Evaluator(mean.Evaluator):
54 |   """Perplexity evaluator."""
55 | 
56 |   def __init__(self, predict_fn, *a, key='labels', shift_labels=False, **kw):
57 |     kw.setdefault('prefetch', 0)  # More memory-saving default.
58 |     super().__init__(perplexity(predict_fn, key, shift_labels), *a, **kw)
59 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/paligemma/transfers/storepreds.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Evaluator to run inference and store results."""
16 | import functools
17 | 
18 | import big_vision.evaluators.common as c
19 | import big_vision.input_pipeline
20 | import big_vision.pp.builder
21 | import big_vision.pp.tokenizer
22 | import big_vision.utils as u
23 | 
24 | import jax
25 | 
26 | # Temporary global flag to facilitate backwards compatibility. Will be removed
27 | # by the end of year 2023.
28 | API = "jit"
29 | 
30 | 
31 | class Evaluator:
32 |   """Evaluator to run inference and store results."""
33 | 
34 |   def __init__(
35 |       self, predict_fn, tokenizer=None,
36 |       preds_outfile="{workdir}/{name}_{split}_preds.json",
37 |       annot_outfile="{workdir}/{name}_{split}_annotations.json",
38 |       id_key="id",
39 |       *, data, devices, **kw
40 |   ):
41 |     self.id_key = id_key
42 |     self.get_data_iter, self.steps = c.eval_input_pipeline(
43 |         keep_on_cpu={id_key}, data=data, devices=devices, **kw)
44 | 
45 |     self.preds_outfile = c.resolve_outfile(
46 |         preds_outfile, name=data.get("name"), split=data.get("split", ""))
47 |     self.annot_outfile = c.resolve_outfile(
48 |         annot_outfile, name=data.get("name"), split=data.get("split", ""))
49 | 
50 |     self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
51 |     self.decode = functools.partial(
52 |         predict_fn, devices=devices, eos_token=self.tok.eos_token)
53 | 
54 |   def run(self, train_state):
55 |     """Run eval."""
56 |     res = []
57 | 
58 |     for _, batch in zip(range(self.steps), self.get_data_iter()):
59 |       # (batch, seqlen) array of decoded generated tokens.
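      # (Illustrative shape flow, assuming a global batch of 8 over 2 hosts;
      # the numbers are hypothetical, not from any config:
      #   decode      -> fsarray of shape (8, seqlen)
      #   local slice -> the (4, seqlen) rows owned by this host
      #   ex_masks    -> (4,) bools marking real, non-padding examples.)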
60 |       tokens = self.decode(train_state, batch)
61 | 
62 |       # (local_batch,)
63 |       tokens = u.get_local_slice_from_fsarray(tokens)
64 |       ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])
65 | 
66 |       image_ids = batch[self.id_key][ex_masks]
67 |       pred_captions = self.tok.to_str(tokens[ex_masks])
68 | 
69 |       for image_id, caption in zip(image_ids, pred_captions):
70 |         res.append({self.id_key: str(image_id), "caption": caption})
71 | 
72 |     res = c.multiprocess_write_json(self.preds_outfile, res)
73 | 
74 |     if jax.process_index():  # Host0 gets all preds and does eval.
75 |       return
76 | 
77 |     yield "num_examples", len(res)
78 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/uvim/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Common utilities used in evaluators."""
16 | import math
17 | import jax
18 | import tensorflow as tf
19 | import tensorflow_datasets as tfds
20 | 
21 | 
22 | def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
23 |                             dataset_dir=None, cache=True, add_tfds_id=False):
24 |   """Returns the dataset to be processed by the current jax host.
25 | 
26 |   The dataset is sharded and padded with zeros such that all processes
27 |   have an equal number of batches. The first 2 dimensions of the dataset
28 |   elements are: [local_device_count, device_batch_size].
29 | 
30 |   Args:
31 |     dataset: dataset name.
32 |     split: dataset split.
33 |     global_batch_size: batch size to be processed per iteration on the dataset.
34 |     pp_fn: preprocessing function to apply per example.
35 |     dataset_dir: path for tfds to find the prepared data.
36 |     cache: whether to cache the dataset after batching.
37 |     add_tfds_id: whether to add the unique `tfds_id` string to each example.
38 |   """
39 |   assert global_batch_size % jax.device_count() == 0
40 |   total_examples = tfds.load(
41 |       dataset, split=split, data_dir=dataset_dir).cardinality()
42 |   num_batches = math.ceil(total_examples / global_batch_size)
43 | 
44 |   process_split = tfds.even_splits(
45 |       split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
46 |   data = tfds.load(
47 |       dataset,
48 |       split=process_split,
49 |       data_dir=dataset_dir,
50 |       read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)
51 |   pad_data = tf.data.Dataset.from_tensors(
52 |       jax.tree_map(lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
53 |   ).repeat()
54 | 
55 |   data = data.concatenate(pad_data)
56 |   data = data.batch(global_batch_size // jax.device_count())
57 |   data = data.batch(jax.local_device_count())
58 |   data = data.take(num_batches)
59 |   if cache:
60 |     # Eval datasets are often used many times and caching the dataset after
61 |     # batching allows one to have the buffers ready to be used and not have
62 |     # to wait for preprocessing to be done over and over.
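    # (Note: this keeps the full preprocessed eval split in host memory; for
    # very large splits, passing cache=False may be the safer trade-off.)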
63 | data = data.cache() 64 | return data 65 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/models/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/models/ppp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/models/ppp/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/models/proj/flaxformer/bert_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for bert.""" 16 | 17 | import tempfile 18 | 19 | from big_vision import input_pipeline 20 | from big_vision.models.proj.flaxformer import bert, bert_test_util 21 | import big_vision.pp.builder as pp_builder 22 | import big_vision.pp.ops_general # pylint: disable=unused-import 23 | import big_vision.pp.proj.flaxformer.bert_ops # pylint: disable=unused-import 24 | import flax 25 | import jax 26 | import jax.numpy as jnp 27 | import tensorflow as tf 28 | 29 | # BERT vocabulary for testing. 
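# (With this toy vocab, ids are just list positions: [PAD]=0, [UNK]=1, ...,
# [CLS]=6, [SEP]=7, so "this is a test" should tokenize to [CLS]-prefixed ids
# [6, 2, 3, 4, 5] padded with [PAD] up to _TOKEN_LEN; an illustrative
# expectation mirroring the explicit ids asserted in bert_ops_test.py.)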
30 | _BERT_VOCAB = [
31 |     "[PAD]",
32 |     "[UNK]",
33 |     "this",
34 |     "is",
35 |     "a",
36 |     "test",
37 |     "[CLS]",
38 |     "[SEP]",
39 | ]
40 | _TOKEN_LEN = 16
41 | 
42 | 
43 | class BertTest(tf.test.TestCase):
44 |     def test_load_apply(self):
45 |         inkey = "text"
46 |         vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
47 |         with open(vocab_path, "w") as f:
48 |             f.write("\n".join(_BERT_VOCAB))
49 |         ds2, _ = input_pipeline.make_for_inference(
50 |             tf.data.Dataset.from_tensor_slices(
51 |                 {inkey: tf.ragged.constant([["this is a test"]])}
52 |             ),
53 |             num_ex_per_process=[1],
54 |             preprocess_fn=pp_builder.get_preprocess_fn(
55 |                 f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', "
56 |                 f"max_len={_TOKEN_LEN})"
57 |                 "|keep('labels')"
58 |             ),
59 |             batch_size=1,
60 |         )
61 |         text = jnp.array(next(iter(ds2))["labels"])
62 |         model = bert.Model(config="base")
63 |         variables = model.init(jax.random.PRNGKey(0), text)
64 |         params = bert.load(
65 |             flax.core.unfreeze(variables)["params"],
66 |             bert_test_util.create_base_checkpoint(),
67 |         )
68 |         x, out = model.apply({"params": params}, text)
69 |         self.assertAllEqual(jax.tree_map(jnp.shape, x), (1, 768))
70 |         self.assertAllEqual(
71 |             jax.tree_map(jnp.shape, out),
72 |             {
73 |                 "transformed": (1, 16, 768),
74 |                 "pre_logits": (1, 768),
75 |             },
76 |         )
77 | 
78 | 
79 | if __name__ == "__main__":
80 |     tf.test.main()
81 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/givt/adaptor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for the IRevNet adaptor."""
16 | 
17 | from absl.testing import absltest
18 | from big_vision.models.proj.givt import adaptor
19 | import jax
20 | from jax import random
21 | import jax.numpy as jnp
22 | 
23 | 
24 | class AdaptorTest(absltest.TestCase):
25 |     def test_inversion(self):
26 |         num_channels = 8
27 |         input_shape = (1, 24, 24, num_channels)
28 | 
29 |         rng = random.PRNGKey(758493)
30 |         _, inp_rng, init_rng, data_rng = jax.random.split(rng, 4)
31 | 
32 |         dummy_x = random.normal(inp_rng, shape=input_shape)
33 |         real_x = jax.random.normal(data_rng, shape=input_shape)
34 | 
35 |         model = adaptor.IRevNet(
36 |             num_blocks=4,
37 |             num_channels=num_channels,
38 |             dropout_rate=0.0,
39 |         )
40 |         params = model.init(init_rng, dummy_x)
41 | 
42 |         real_y = model.apply(params, real_x, method=model.forward)
43 |         real_x_ = model.apply(params, real_y, method=model.inverse)
44 |         self.assertTrue(jnp.allclose(real_x, real_x_, atol=1e-5))
45 | 
46 | 
47 | if __name__ == "__main__":
48 |     absltest.main()
49 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/uvim/vit_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for vit vqvae model.""" 16 | from absl.testing import absltest 17 | from big_vision.models.proj.uvim import vit 18 | import jax 19 | import jax.numpy as jnp 20 | import ml_collections 21 | 22 | 23 | class ViTVQVAEModelTest(absltest.TestCase): 24 | def test_model(self): 25 | model_config = ml_collections.ConfigDict( 26 | { 27 | "input_size": (32, 32), 28 | "code_len": 4, 29 | "width": 16, 30 | "mlp_dim": 64, 31 | "num_heads": 4, 32 | "enc_depth": 1, 33 | "dec_depth": 1, 34 | "with_encoder_ctx": True, 35 | "with_decoder_ctx": True, 36 | "statistics_axis_name": None, 37 | "inputs": { 38 | "in1": (10, 3), 39 | "in2": (25,), 40 | }, 41 | "outputs": { 42 | "out1": (5,), 43 | "out2": (20,), 44 | }, 45 | } 46 | ) 47 | 48 | model = vit.Model(**model_config) 49 | batch_size = 4 50 | seq_len = (32 // 8) ** 2 51 | x = { 52 | "in1": jnp.zeros((batch_size, seq_len, 10, 3)), 53 | "in2": jnp.zeros((batch_size, seq_len, 25)), 54 | } 55 | ctx_image = jnp.zeros((batch_size,) + model_config.input_size + (3,)) 56 | init_rngs = { 57 | "params": jax.random.PRNGKey(0), 58 | "state": jax.random.PRNGKey(1), 59 | } 60 | params = model.init(init_rngs, x, ctx=ctx_image) 61 | self.assertEqual(params.keys(), set(["params", "state"])) 62 | 63 | apply_rngs = { 64 | "dropout": jax.random.PRNGKey(0), 65 | "vqvae": jax.random.PRNGKey(0), 66 | } 67 | (logits, _), params = model.apply( 68 | params, 69 | x, 70 | ctx=ctx_image, 71 | train=True, 72 | update_dict=True, 73 | rngs=apply_rngs, 74 | mutable=["state"], 75 | ) 76 | self.assertEqual(logits.keys(), set(["out1", "out2"])) 77 | self.assertEqual(logits["out1"].shape, (batch_size, seq_len, 5)) 78 | self.assertEqual(logits["out2"].shape, (batch_size, seq_len, 20)) 79 | 80 | 81 | if __name__ == "__main__": 82 | absltest.main() 83 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/models/proj/uvim/vtt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for vision-text-transformer.""" 16 | from absl.testing import absltest 17 | from big_vision.models.proj.uvim import vtt 18 | import jax 19 | import jax.numpy as jnp 20 | import ml_collections 21 | 22 | 23 | class VTTTest(absltest.TestCase): 24 | def test_vtt_with_1_step(self): 25 | model_config = ml_collections.ConfigDict( 26 | dict( 27 | input_size=(224, 224), 28 | patches={"size": (16, 16)}, 29 | num_heads=2, 30 | num_layers=2, 31 | mlp_dim=128, 32 | emb_dim=64, 33 | vocab_size=500, 34 | ) 35 | ) 36 | batch_size, max_len = 8, 50 37 | image = jnp.ones((batch_size, 224, 224, 3)) 38 | text = jnp.ones((batch_size, max_len), dtype=jnp.int32) 39 | 40 | m = vtt.Model(**model_config) 41 | variables = m.init(jax.random.PRNGKey(42), image, text) 42 | self.assertCountEqual(variables.keys(), ["params"]) 43 | 44 | params = variables["params"] 45 | out = m.apply({"params": params}, image, text) 46 | expected_shape = (batch_size, max_len, model_config.vocab_size) 47 | self.assertEqual(out.shape, expected_shape) 48 | 49 | 50 | if __name__ == "__main__": 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/pp/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/archive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/pp/archive/__init__.py -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/archive/randaug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RandAug depends on deprecated tfa.image package, now defunct.""" 16 | 17 | from big_vision.pp import registry 18 | from big_vision.pp import utils 19 | from big_vision.pp.archive import autoaugment 20 | 21 | 22 | @registry.Registry.register("preprocess_ops.randaug") 23 | @utils.InKeyOutKey() 24 | def get_randaug(num_layers: int = 2, magnitude: int = 10): 25 | """Creates a function that applies RandAugment. 26 | 27 | RandAugment is from the paper https://arxiv.org/abs/1909.13719, 28 | 29 | Args: 30 | num_layers: Integer, the number of augmentation transformations to apply 31 | sequentially to an image. Represented as (N) in the paper. Usually best 32 | values will be in the range [1, 3]. 33 | magnitude: Integer, shared magnitude across all augmentation operations. 34 | Represented as (M) in the paper. Usually best values are in the range [5, 35 | 30]. 
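      (For instance, the defaults below, num_layers=2 and magnitude=10, apply
      two randomly chosen ops per image at moderate strength.)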
36 | 
37 |   Returns:
38 |     a function that applies RandAugment.
39 |   """
40 | 
41 |   def _randaug(image):
42 |     return autoaugment.distort_image_with_randaugment(
43 |         image, num_layers, magnitude
44 |     )
45 | 
46 |   return _randaug
47 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for builder."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from big_vision.pp import builder
22 | from big_vision.pp import ops_general  # pylint: disable=unused-import
23 | from big_vision.pp import ops_image  # pylint: disable=unused-import
24 | import numpy as np
25 | import tensorflow.compat.v1 as tf
26 | 
27 | 
28 | class BuilderTest(tf.test.TestCase):
29 | 
30 |   def testSingle(self):
31 |     pp_fn = builder.get_preprocess_fn("resize(256)")
32 |     x = np.random.randint(0, 256, [640, 480, 3])
33 |     image = pp_fn({"image": x})["image"]
34 |     self.assertEqual(image.numpy().shape, (256, 256, 3))
35 | 
36 |   def testEmpty(self):
37 |     pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||")
38 | 
39 |     # Typical image input
40 |     x = np.random.randint(0, 256, [640, 480, 3])
41 |     image = pp_fn({"image": x})["image"]
42 |     self.assertEqual(image.numpy().shape, (256, 256, 3))
43 | 
44 |   def testPreprocessingPipeline(self):
45 |     pp_str = ("inception_crop|resize(256)|resize((256, 256))|"
46 |               "central_crop((80, 120))|flip_lr|value_range(0,1)|"
47 |               "value_range(-1,1)")
48 |     pp_fn = builder.get_preprocess_fn(pp_str)
49 | 
50 |     # Typical image input
51 |     x = np.random.randint(0, 256, [640, 480, 3])
52 |     image = pp_fn({"image": x})["image"]
53 |     self.assertEqual(image.numpy().shape, (80, 120, 3))
54 |     self.assertLessEqual(np.max(image.numpy()), 1)
55 |     self.assertGreaterEqual(np.min(image.numpy()), -1)
56 | 
57 |   def testNumArgsException(self):
58 | 
59 |     x = np.random.randint(0, 256, [640, 480, 3])
60 |     for pp_str in [
61 |         "inception_crop(1)",
62 |         "resize()",
63 |         "resize(1, 1, 1)",
64 |         "flip_lr(1)",
65 |         "central_crop()",
66 |     ]:
67 |       with self.assertRaises(BaseException):
68 |         builder.get_preprocess_fn(pp_str)(x)
69 | 
70 | 
71 | if __name__ == "__main__":
72 |   tf.test.main()
73 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/clippo/download_unifont.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | # This is intended to be run from the big_vision repository root: 17 | # 18 | # bash big_vision/pp/proj/clippo/download_unifont.sh 19 | wget https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont-9.0.06.hex.gz https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont_upper-9.0.06.hex.gz 20 | gunzip unifont-9.0.06.hex.gz unifont_upper-9.0.06.hex.gz 21 | mv unifont-9.0.06.hex unifont_upper-9.0.06.hex big_vision/pp/proj/clippo/ -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/proj/flaxformer/bert_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for bert_ops.""" 16 | 17 | import tempfile 18 | 19 | from big_vision import input_pipeline 20 | import big_vision.pp.builder as pp_builder 21 | import big_vision.pp.ops_general # pylint: disable=unused-import 22 | from big_vision.pp.proj.flaxformer import bert_ops # pylint: disable=unused-import 23 | import tensorflow as tf 24 | 25 | 26 | # BERT vocabulary for testing. 
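# (Ids follow list position: [PAD]=0, [UNK]=1, "more"=2, "than"=3, "one"=4,
# [CLS]=5, [SEP]=6, which is why "one more" is expected to yield
# [5, 4, 2, 0, 0] in test_tokenize below: [CLS], "one", "more", then padding.)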
27 | _BERT_VOCAB = [ 28 | "[PAD]", 29 | "[UNK]", 30 | "more", 31 | "than", 32 | "one", 33 | "[CLS]", 34 | "[SEP]", 35 | ] 36 | 37 | 38 | def _create_ds(pp_str, tensor_slices, num_examples): 39 | return input_pipeline.make_for_inference( 40 | tf.data.Dataset.from_tensor_slices(tensor_slices), 41 | num_ex_per_process=[num_examples], 42 | preprocess_fn=pp_builder.get_preprocess_fn(pp_str), 43 | batch_size=num_examples, 44 | )[0] 45 | 46 | 47 | class BertOpsTest(tf.test.TestCase): 48 | 49 | def test_tokenize(self): 50 | inkey = "texts" 51 | vocab_path = f"{tempfile.mkdtemp()}/vocab.txt" 52 | with open(vocab_path, "w") as f: 53 | f.write("\n".join(_BERT_VOCAB)) 54 | pp_str = ( 55 | f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', max_len=5)" 56 | f"|keep('labels')" 57 | ) 58 | tensor_slices = { 59 | inkey: tf.ragged.constant([["one more"], ["more than one"], [""]]) 60 | } 61 | ds = _create_ds(pp_str, tensor_slices, 3) 62 | self.assertAllEqual( 63 | next(iter(ds))["labels"], 64 | [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]], 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | tf.test.main() 70 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/proj/givt/pp_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """GIVT-specific preprocessing ops.""" 16 | 17 | from big_vision.pp import registry 18 | from big_vision.pp import utils 19 | import tensorflow as tf 20 | 21 | 22 | @registry.Registry.register("preprocess_ops.bin_nyu_depth") 23 | @utils.InKeyOutKey(indefault="labels", outdefault="labels") 24 | def get_bin_nyu_depth(min_depth=0.001, max_depth=10.0, num_bins=256): 25 | """Binning of NYU depth for UViM in preprocessing rather than model.""" 26 | 27 | def _bin_depth(labels): # pylint: disable=missing-docstring 28 | labels = (labels - min_depth) / (max_depth - min_depth) 29 | labels *= num_bins 30 | labels = tf.cast(tf.floor(labels), tf.int32) 31 | labels = tf.minimum(labels, num_bins - 1) 32 | labels = tf.maximum(labels, 0) 33 | return labels 34 | 35 | return _bin_depth 36 | 37 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/proj/paligemma/robustness.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """pp ops."""
16 | 
17 | import math
18 | 
19 | from big_vision.pp import utils
20 | from big_vision.pp.registry import Registry
21 | import tensorflow as tf
22 | 
23 | 
24 | @Registry.register("preprocess_ops.resize_r")
25 | @utils.InKeyOutKey()
26 | def get_resize_r(size):
27 |   """Like standard `resize` but randomize some of its parameters."""
28 |   size = utils.maybe_repeat(size, 2)
29 | 
30 |   # Sadly TF won't let us pass symbolic arguments, so we need to pre-create all
31 |   # variants of function calls we'd like to randomize over...
32 |   resize_fns = [
33 |       lambda x, m=m, a=a: tf.image.resize(x, size, method=m, antialias=a)
34 |       for m in ["bilinear", "bicubic", "lanczos3", "area", "mitchellcubic"]
35 |       for a in [True, False]
36 |   ]
37 | 
38 |   def _resize_r(image):
39 |     """Resizes image to a given size."""
40 |     dtype = image.dtype
41 |     tf_dtype = tf.type_spec_from_value(image).dtype
42 |     ifn = tf.random.uniform((), 0, len(resize_fns), tf.int32)
43 |     image = tf.switch_case(ifn, [lambda fn=fn: fn(image) for fn in resize_fns])
44 |     return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)
45 | 
46 |   return _resize_r
47 | 
48 | 
49 | @Registry.register("preprocess_ops.random_jpeg")
50 | @utils.InKeyOutKey()
51 | def get_random_jpeg(p):
52 |   """With probability `p`, randomly encode-decode as jpeg."""
53 | 
54 |   fns = [
55 |       lambda x: tf.image.adjust_jpeg_quality(
56 |           x, dct_method="INTEGER_FAST",
57 |           jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
58 |       ),
59 |       lambda x: tf.image.adjust_jpeg_quality(
60 |           x, dct_method="INTEGER_ACCURATE",
61 |           jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
62 |       ),
63 |   ]
64 | 
65 |   def _random_jpeg(image):
66 |     """With probability `p`, re-encodes the image as JPEG at a random quality."""
67 |     funcs = [lambda: image] + [lambda fn=fn: fn(image) for fn in fns]
68 |     logits = [math.log(prob) for prob in [1 - p] + [p / len(fns)] * len(fns)]
69 |     fn_idx = tf.random.categorical([logits], 1, dtype=tf.int32)[0, 0]
70 |     return tf.switch_case(fn_idx, funcs)
71 | 
72 |   return _random_jpeg
73 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/paligemma/sciqa_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """pp ops."""
16 | 
17 | from big_vision.pp.registry import Registry
18 | import tensorflow as tf
19 | 
20 | 
21 | @Registry.register('preprocess_ops.sci_qa_choices_shuffle')
22 | def sci_qa_choices_shuffle(
23 |     choice_str_inkey='choices',
24 |     ans_inkey='answer',
25 |     indexed_choices_outkey='indexed_choices',
26 |     indexed_answer_outkey='indexed_answer',
27 | ):
28 |   """Randomly shuffles SciQA's choices on the fly.
29 | 
30 |   Args:
31 |     choice_str_inkey: the original choice list from SciQA,
32 |       e.g. ['apple', 'banana', ...]
33 |     ans_inkey: the original answer index from SciQA, e.g. 1.
34 |     indexed_choices_outkey: the shuffled choices joined into one string with
35 |       letter indices, e.g. "(A) banana, (B) apple".
36 |     indexed_answer_outkey: the shuffled answer as a letter index, e.g.
37 |       1 (original) -> 2 (shuffled) -> 'B' (alphabet index).
38 | 
39 |   Returns: a pp function that writes the two indexed keys into the data dict.
40 |   """
41 |   def _template(data):
42 |     alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
43 |     abc_tensor = tf.constant([f'({a})' for a in alphabet])
44 |     abcans_tensor = tf.constant([f'{a}' for a in alphabet])
45 |     choices = data[choice_str_inkey]
46 |     indices = tf.range(len(choices))
47 |     # Shuffle the indices
48 |     shuffled_indices = tf.random.shuffle(indices)
49 |     # Use the shuffled indices to shuffle the tensor
50 |     shuffled_tensor = tf.gather(choices, shuffled_indices)
51 | 
52 |     abc_tensor = tf.gather(abc_tensor, indices)
53 | 
54 |     data[indexed_choices_outkey] = tf.strings.reduce_join(
55 |         tf.strings.join([abc_tensor, shuffled_tensor], separator=' '),
56 |         separator=', ',
57 |     )
58 | 
59 |     answer_tensor = data[ans_inkey]
60 |     new_ans_indice = tf.where(tf.equal(shuffled_indices, answer_tensor))
61 |     new_ans_indice = tf.gather(abcans_tensor, new_ans_indice)
62 |     data[indexed_answer_outkey] = tf.strings.reduce_join(new_ans_indice)
63 |     return data
64 | 
65 |   return _template
66 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/paligemma/widgetcap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Widgetcap pp ops."""
16 | 
17 | from big_vision.pp.registry import Registry
18 | import tensorflow as tf
19 | 
20 | 
21 | @Registry.register("preprocess_ops.draw_bbox")
22 | def get_draw_bbox(image_key="image", bbox_key="bbox"):
23 |   """Draw a single bounding box."""
24 | 
25 |   def _draw_bbox(data):
26 |     """Draw a single bounding box."""
27 |     image = tf.cast(data[image_key], tf.float32)
28 |     image = tf.image.draw_bounding_boxes(
29 |         tf.expand_dims(image, 0),
30 |         tf.reshape(data[bbox_key], [1, 1, 4]),
31 |         tf.constant([255, 0, 0], dtype=tf.float32, shape=[1, 3]),
32 |     )
33 |     data[image_key] = tf.squeeze(image)
34 |     return data
35 | 
36 |   return _draw_bbox
37 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessing utils.""" 16 | 17 | from collections import abc 18 | 19 | 20 | def maybe_repeat(arg, n_reps): 21 | if not isinstance(arg, abc.Sequence) or isinstance(arg, str): 22 | arg = (arg,) * n_reps 23 | return arg 24 | 25 | 26 | class InKeyOutKey(object): 27 | """Decorator for preprocessing ops, which adds `inkey` and `outkey` arguments. 28 | 29 | Note: Only supports single-input single-output ops. 30 | """ 31 | 32 | def __init__(self, indefault="image", outdefault="image", with_data=False): 33 | self.indefault = indefault 34 | self.outdefault = outdefault 35 | self.with_data = with_data 36 | 37 | def __call__(self, orig_get_pp_fn): 38 | 39 | def get_ikok_pp_fn(*args, key=None, 40 | inkey=self.indefault, outkey=self.outdefault, **kw): 41 | 42 | orig_pp_fn = orig_get_pp_fn(*args, **kw) 43 | def _ikok_pp_fn(data): 44 | # Optionally allow the function to get the full data dict as aux input. 45 | if self.with_data: 46 | data[key or outkey] = orig_pp_fn(data[key or inkey], data=data) 47 | else: 48 | data[key or outkey] = orig_pp_fn(data[key or inkey]) 49 | return data 50 | 51 | return _ikok_pp_fn 52 | 53 | return get_ikok_pp_fn 54 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/pp/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for preprocessing utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from big_vision.pp import utils 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class UtilsTest(tf.test.TestCase): 26 | 27 | def test_maybe_repeat(self): 28 | self.assertEqual((1, 1, 1), utils.maybe_repeat(1, 3)) 29 | self.assertEqual((1, 2), utils.maybe_repeat((1, 2), 2)) 30 | self.assertEqual([1, 2], utils.maybe_repeat([1, 2], 2)) 31 | 32 | def test_inkeyoutkey(self): 33 | @utils.InKeyOutKey() 34 | def get_pp_fn(shift, scale=0): 35 | def _pp_fn(x): 36 | return scale * x + shift 37 | return _pp_fn 38 | 39 | data = {"k_in": 2, "other": 3} 40 | ppfn = get_pp_fn(1, 2, inkey="k_in", outkey="k_out") # pylint: disable=unexpected-keyword-arg 41 | self.assertEqual({"k_in": 2, "k_out": 5, "other": 3}, ppfn(data)) 42 | 43 | data = {"k": 6, "other": 3} 44 | ppfn = get_pp_fn(1, inkey="k", outkey="k") # pylint: disable=unexpected-keyword-arg 45 | self.assertEqual({"k": 1, "other": 3}, ppfn(data)) 46 | 47 | data = {"other": 6, "image": 3} 48 | ppfn = get_pp_fn(5, 2) 49 | self.assertEqual({"other": 6, "image": 11}, ppfn(data)) 50 | 51 | 52 | if __name__ == "__main__": 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.26 2 | absl-py 3 | git+https://github.com/google/CommonLoopUtils 4 | distrax 5 | editdistance 6 | einops 7 | flax 8 | optax 9 | git+https://github.com/google/flaxformer 10 | git+https://github.com/akolesnikoff/panopticapi.git@mute 11 | overrides 12 | protobuf 13 | sentencepiece 14 | tensorflow-cpu 15 | tfds-nightly 16 | tensorflow-text 17 | tensorflow-gan 18 | pycocoevalcap 19 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/run_tpu.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | 17 | if [ ! -d "bv_venv" ] 18 | then 19 | sudo apt-get update 20 | sudo apt install -y python3-venv 21 | python3 -m venv bv_venv 22 | . bv_venv/bin/activate 23 | 24 | pip install -U pip # Yes, really needed. 25 | # NOTE: doesn't work when in requirements.txt -> cyclic dep 26 | pip install "jax[tpu]>=0.4.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 27 | pip install -r big_vision/requirements.txt 28 | else 29 | . 
bv_venv/bin/activate
30 | fi
31 | 
32 | if [ $# -ne 0 ]
33 | then
34 |   env TFDS_DATA_DIR=$TFDS_DATA_DIR BV_JAX_INIT=1 python3 -m "$@"
35 | fi
36 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/download_tfds_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Download and prepare TFDS datasets for the big_vision codebase.
16 | 
17 | This python script covers cifar10, cifar100, oxford_iiit_pet,
18 | oxford_flowers102 and imagenet_v2.
19 | 
20 | If you want to integrate other public or custom datasets, please follow:
21 | https://www.tensorflow.org/datasets/catalog/overview
22 | """
23 | 
24 | from absl import app
25 | import tensorflow_datasets as tfds
26 | 
27 | 
28 | def main(argv):
29 |   if len(argv) > 1 and "download_tfds_datasets.py" in argv[0]:
30 |     datasets = argv[1:]
31 |   else:
32 |     datasets = [
33 |         "cifar10",
34 |         "cifar100",
35 |         "oxford_iiit_pet",
36 |         "oxford_flowers102",
37 |         "imagenet_v2",
38 |     ]
39 |   for d in datasets:
40 |     tfds.load(name=d, download=True)
41 | 
42 | 
43 | if __name__ == "__main__":
44 |   app.run(main)
45 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/README.md:
--------------------------------------------------------------------------------
1 | # LiT-Demo
2 | 
3 | See https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
4 | 
5 | Demo originally appeared on Twitter
6 | https://twitter.com/AndreasPSteiner/status/1514722383818543106
7 | 
8 | App published at
9 | https://google-research.github.io/vision_transformer/lit
10 | 
11 | ## Build
12 | 
13 | Install packages (tested with node v16.17.0 and yarn 1.22.19)
14 | 
15 | ```bash
16 | yarn
17 | ```
18 | 
19 | 
20 | ## Run
21 | 
22 | The web app will appear on http://localhost:8000
23 | 
24 | ```
25 | node build.js
26 | ```
27 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/build.js:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright Big Vision Authors
4 |  *
5 |  * Licensed under the Apache License, Version 2.0 (the "License");
6 |  * you may not use this file except in compliance with the License.
7 |  * You may obtain a copy of the License at
8 |  *
9 |  * http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 | */ 17 | 18 | const sassPlugin = require('esbuild-sass-plugin').sassPlugin; 19 | 20 | require('esbuild').serve({ 21 | servedir: 'src', 22 | port: 8000, 23 | }, { 24 | entryPoints: ['src/app.ts'], 25 | bundle: true, 26 | outfile: 'src/index.js', 27 | plugins: [ 28 | sassPlugin({ 29 | filter: /style.scss$/, 30 | type: 'style' 31 | }), 32 | sassPlugin({ 33 | type: 'lit-css', 34 | }), 35 | ], 36 | sourcemap: true, 37 | }).then(() => { 38 | console.log('Serving on port 8000'); 39 | }).catch(() => process.exit(1)); 40 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lit-demo", 3 | "version": "0.0.2", 4 | "description": "", 5 | "main": "src/app.ts", 6 | "license": "Apache-2.0", 7 | "private": true, 8 | "engines": { 9 | "node": ">=8.9.0" 10 | }, 11 | "scripts": { 12 | "serve": "node build.js", 13 | "test": "ts-node --skip-ignore --project tsconfig.test.json run_tests.ts" 14 | }, 15 | "devDependencies": { 16 | "@babel/core": "^7.7.5", 17 | "@babel/plugin-transform-runtime": "^7.7.6", 18 | "@babel/polyfill": "^7.10.4", 19 | "@babel/preset-env": "^7.7.6", 20 | "@tensorflow/tfjs-backend-cpu": "^3.15.0", 21 | "@tensorflow/tfjs-backend-webgl": "^3.15.0", 22 | "@tensorflow/tfjs-converter": "3.20.0", 23 | "@tensorflow/tfjs-core": "3.20.0", 24 | "babel-preset-env": "^1.7.0", 25 | "esbuild": "^0.15.5", 26 | "esbuild-sass-plugin": "^2.3.2", 27 | "jasmine": "^3.3.1", 28 | "lit": "^2.3.1", 29 | "naughty-words": "^1.2.0", 30 | "sass": "^1.50.0", 31 | "ts-node": "~5.0.0", 32 | "typescript": "4.1.3" 33 | }, 34 | "resolutions": { 35 | "is-svg": "4.3.1" 36 | }, 37 | "eslintConfig": { 38 | "extends": "google", 39 | "rules": { 40 | "require-jsdoc": 0, 41 | "valid-jsdoc": 0 42 | }, 43 | "env": { 44 | "es6": true 45 | }, 46 | "parserOptions": { 47 | "ecmaVersion": 8, 48 | "sourceType": "module" 49 | } 50 | }, 51 | "eslintIgnore": [ 52 | "dist/" 53 | ] 54 | } 55 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/app.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 |  */
17 | 
18 | import {LitDemoApp} from './components/lit-demo-app';
19 | import './style.scss';
20 | 
21 | // tslint:disable-next-line:no-any
22 | (window as any).LitDemoApp = LitDemoApp;
23 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/image-carousel.scss:
--------------------------------------------------------------------------------
1 | @import '../style/mixins';
2 | 
3 | .selector {
4 |   overflow: scroll;
5 |   padding-bottom: 10px; // OS X scroll bar
6 | 
7 |   .inner {
8 |     white-space: nowrap;
9 | 
10 |     .thumb {
11 |       display: inline-block;
12 | 
13 |       img {
14 |         cursor: pointer;
15 | 
16 |         width: 20vmin;
17 |         height: 20vmin;
18 |         max-width: 200px;
19 |         max-height: 200px;
20 | 
21 |         @include phone-portrait {
22 |           width: 33vmin;
23 |           height: 33vmin;
24 |         }
25 | 
26 |         margin: 10px;
27 | 
28 |         box-shadow: 0 0 10px #888;
29 |       }
30 |     }
31 |   }
32 | }
33 | 
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/image-carousel.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright Big Vision Authors
4 |  *
5 |  * Licensed under the Apache License, Version 2.0 (the "License");
6 |  * you may not use this file except in compliance with the License.
7 |  * You may obtain a copy of the License at
8 |  *
9 |  * http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | /**
19 |  * @fileoverview Carousel of images.
20 |  */
21 | 
22 | import {html, LitElement} from 'lit';
23 | 
24 | import {app} from '../lit_demo/app';
25 | import {getImageUrl} from '../lit_demo/constants';
26 | import {ImageRow} from '../lit_demo/data';
27 | 
28 | import {customElement} from 'lit/decorators.js';
29 | import styles from './image-carousel.scss';
30 | 
31 | /**
32 |  * Shows multiple images in a horizontal carousel.
33 |  *
34 |  * Dispatches `'image-select'` event when an image is clicked/tapped.
35 |  */
36 | @customElement('image-carousel')
37 | export class ImageCarousel extends LitElement {
38 |   static override styles = [styles];
39 | 
40 |   onClick(id: string) {
41 |     const event =
42 |         new CustomEvent('image-select', {composed: true, detail: {id}});
43 |     this.dispatchEvent(event);
44 |   }
45 | 
46 |   override render() {
47 |     const images = app.imageData.rows.map(
48 |         (row: ImageRow) => html`
49 |           <div class="thumb">
50 |             <img @click=${() => {
51 |               this.onClick(row.id);
52 |             }} data-id=${row.id} src="${getImageUrl(row.id)}">
53 |           </div>
54 |     `);
55 |     return html`
56 |       <div class="selector">
57 |         <div class="inner">
58 |           ${images}
59 |         </div>
60 |       </div>
61 |       <div>Select an image 👆 to get started.</div>
62 | `; 63 | } 64 | } 65 | 66 | declare global { 67 | interface HTMLElementTagNameMap { 68 | 'image-carousel': ImageCarousel; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/image-prompts.scss: -------------------------------------------------------------------------------- 1 | @import '../style/mixins'; 2 | 3 | .image-prompt { 4 | display: flex; 5 | gap: 1.5em; 6 | align-items: flex-start; 7 | margin-top: 2rem; 8 | 9 | @include phone-portrait { 10 | align-items: center; 11 | flex-direction: column; 12 | gap: 0; 13 | margin-bottom: 5rem; 14 | } 15 | 16 | .left { 17 | display: flex; 18 | flex-direction: column; 19 | 20 | .wrapper { 21 | position: relative; 22 | 23 | .src { 24 | position: absolute; 25 | right: 2rem; 26 | bottom: 2rem; 27 | color: white; 28 | font-size: 1.5rem; 29 | text-shadow: 2px 2px black; 30 | text-decoration: none; 31 | } 32 | } 33 | 34 | .animation { 35 | position: relative; 36 | width: 224px; 37 | height: 15px; 38 | opacity: 0; 39 | 40 | .computing { 41 | text-align: center; 42 | } 43 | } 44 | } 45 | 46 | .right { 47 | display: flex; 48 | flex-grow: 1; 49 | flex-direction: column; 50 | gap: 0.5em; 51 | 52 | .top { 53 | text-align: right; 54 | height: 30px; 55 | } 56 | 57 | .buttons { 58 | display: flex; 59 | flex-wrap: wrap; 60 | justify-content: flex-end; 61 | gap: 1em; 62 | align-items: center; 63 | } 64 | 65 | .item { 66 | position: relative; 67 | display: flex; 68 | 69 | .pct { 70 | display: inline-block; 71 | margin-right: 1em; 72 | width: 3.5em; 73 | text-align: right; 74 | opacity: 0; 75 | transition: opacity 0.5s; 76 | } 77 | 78 | input { 79 | flex-grow: 1; 80 | max-width: 70vw; 81 | border-radius: 0; 82 | background: transparent; 83 | border: 0; 84 | border-bottom: 1px solid var(--text-fg); 85 | color: var(--text-fg); 86 | outline: none; 87 | 88 | &.toolong { 89 | border-bottom: 1px solid var(--text-red); 90 | color: var(--text-red); 91 | } 92 | } 93 | 94 | .bar { 95 | position: absolute; 96 | display: inline-block; 97 | top: 5%; 98 | left: 0; 99 | z-index: -1; 100 | background: var(--bar-col); 101 | height: 90%; 102 | width: 0; 103 | transition: width 0.5s; 104 | } 105 | } 106 | 107 | .bottom { 108 | display: flex; 109 | flex-wrap: wrap; 110 | justify-content: flex-end; 111 | gap: 1em; 112 | align-items: center; 113 | opacity: 0; 114 | 115 | .tweet { 116 | background: rgb(18, 150, 223); 117 | color: white; 118 | text-decoration: none; 119 | padding: 0px 15px; 120 | border-radius: 16px; 121 | } 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/lit-demo-app.scss: -------------------------------------------------------------------------------- 1 | .loading-container { 2 | text-align: center; 3 | } 4 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/loading-animation.scss: -------------------------------------------------------------------------------- 1 | // CC0 from https://loading.io/css/ 2 | 3 | @import '../style/colors'; 4 | 5 | .lds-ellipsis { 6 | display: inline-block; 7 | position: relative; 8 | width: 80px; 9 | height: 80px; 10 | 11 | div { 12 | position: absolute; 13 | top: 33px; 14 | width: 13px; 15 | height: 13px; 16 | border-radius: 50%; 17 | background: var(--text-fg); 18 | animation-timing-function: cubic-bezier(0, 1, 1, 0); 19 | } 20 | 21 
| div:nth-child(1) { 22 | left: 8px; 23 | animation: lds-ellipsis1 0.6s infinite; 24 | } 25 | 26 | div:nth-child(2) { 27 | left: 8px; 28 | animation: lds-ellipsis2 0.6s infinite; 29 | } 30 | 31 | div:nth-child(3) { 32 | left: 32px; 33 | animation: lds-ellipsis2 0.6s infinite; 34 | } 35 | 36 | div:nth-child(4) { 37 | left: 56px; 38 | animation: lds-ellipsis3 0.6s infinite; 39 | } 40 | } 41 | 42 | @keyframes lds-ellipsis1 { 43 | 0% { 44 | transform: scale(0); 45 | } 46 | 100% { 47 | transform: scale(1); 48 | } 49 | } 50 | @keyframes lds-ellipsis3 { 51 | 0% { 52 | transform: scale(1); 53 | } 54 | 100% { 55 | transform: scale(0); 56 | } 57 | } 58 | @keyframes lds-ellipsis2 { 59 | 0% { 60 | transform: translate(0, 0); 61 | } 62 | 100% { 63 | transform: translate(24px, 0); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/loading-animation.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Carousel of images. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {customElement} from 'lit/decorators.js'; 25 | import styles from './loading-animation.scss'; 26 | 27 | /** 28 | * Shows an animated loading animation. 29 | */ 30 | @customElement('loading-animation') 31 | export class LoadingAnimation extends LitElement { 32 | 33 | static override styles = [styles]; 34 | 35 | override render() { 36 | return html` 37 |
38 | <div></div>
39 | <div></div>
40 | <div></div>
41 | <div></div>
42 | </div>
43 | `; 44 | } 45 | } 46 | 47 | declare global { 48 | interface HTMLElementTagNameMap { 49 | 'loading-animation': LoadingAnimation; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/message-list.scss: -------------------------------------------------------------------------------- 1 | @import '../style/colors'; 2 | 3 | .message { 4 | padding: 0.1rem 0.5rem; 5 | margin-bottom: 1rem; 6 | } 7 | 8 | .warning { 9 | background: var(--warn-bg); 10 | color: var(--warn-fg); 11 | } 12 | 13 | .error { 14 | background: var(--error-bg); 15 | color: var(--error-fg); 16 | } 17 | 18 | .info { 19 | background: var(--note-bg); 20 | color: var(--note-fg); 21 | } 22 | 23 | .close { 24 | float: right; 25 | cursor: pointer; 26 | } 27 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/message-list.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview A list of dismissable info/warning/error messages. 20 | */ 21 | 22 | import {html, LitElement} from 'lit'; 23 | 24 | import {unsafeHTML} from 'lit/directives/unsafe-html.js'; 25 | 26 | import {customElement} from 'lit/decorators.js'; 27 | import styles from './message-list.scss'; 28 | 29 | enum MessageType { 30 | INFO = 'info', 31 | WARNING = 'warning', 32 | ERROR = 'error', 33 | } 34 | 35 | interface Message { 36 | message: string; 37 | type: MessageType; 38 | rawHtml: boolean; 39 | } 40 | 41 | 42 | /** 43 | * Shows info/warning/error messages that remain until closed by user. 44 | */ 45 | @customElement('message-list') 46 | export class MessageList extends LitElement { 47 | static override styles = [styles]; 48 | 49 | messages: Message[] = []; 50 | 51 | addMessage(message: Message) { 52 | this.messages.push(message); 53 | this.requestUpdate(); 54 | } 55 | 56 | info(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 57 | this.addMessage({message, type: MessageType.INFO, rawHtml}); 58 | } 59 | 60 | warning(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 61 | this.addMessage({message, type: MessageType.WARNING, rawHtml}); 62 | } 63 | 64 | error(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { 65 | this.addMessage({message, type: MessageType.ERROR, rawHtml}); 66 | } 67 | 68 | removeMessage(event: Event, idx: number) { 69 | this.messages.splice(idx, 1); 70 | (event.target! as HTMLElement).closest('.message')!.remove(); 71 | } 72 | 73 | clear() { 74 | this.messages = []; 75 | while (this.firstChild) this.firstChild.remove(); 76 | } 77 | 78 | override render() { 79 | return this.messages.map( 80 | (message: Message, idx: number) => html` 81 |
<div class="message ${message.type}">
82 | ${
83 | message.rawHtml ? unsafeHTML(message.message) :
84 | message.message}
85 | <span @click=${(e: Event) => {
86 | this.removeMessage(e, idx);
87 | }} class="close">✖</span>
88 | </div>
89 | `);
90 | }
91 | }
92 | 
93 | declare global {
94 | interface HTMLElementTagNameMap {
95 | 'message-list': MessageList;
96 | }
97 | }
98 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/model-controls.scss: -------------------------------------------------------------------------------- 1 | .controls {
2 | margin: 1em 0;
3 | display: flex;
4 | 
5 | select {
6 | margin-left: 0.5em;
7 | }
8 | 
9 | progress {
10 | margin: 0 1em;
11 | }
12 | }
13 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/components/model-controls.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | /**
19 | * @fileoverview Controls to choose model.
20 | */
21 | 
22 | import {html, LitElement} from 'lit';
23 | 
24 | import {getModels} from '../lit_demo/constants';
25 | import {app} from '../lit_demo/app';
26 | 
27 | import {customElement, property} from 'lit/decorators.js';
28 | import styles from './model-controls.scss';
29 | 
30 | /**
31 | * Shows controls for model selection, progress bar, and status text.
32 | */
33 | @customElement('model-controls')
34 | export class ModelControls extends LitElement {
35 | 
36 | static override styles = [styles];
37 | 
38 | @property({attribute: false})
39 | progress: number = 0;
40 | 
41 | @property({attribute: false})
42 | status: string = 'Initializing...';
43 | 
44 | constructor() {
45 | super();
46 | app.models.addListener(this.onModelUpdate.bind(this));
47 | app.models.load(getModels()[0]);
48 | }
49 | 
50 | onModelUpdate(progress: number, message?: string) {
51 | this.progress = progress;
52 | if (message) this.status = message;
53 | }
54 | 
55 | onModelChange(event: Event) {
56 | const target = event.target as HTMLSelectElement;
57 | const name = target.value;
58 | app.models.load(name).catch((error) => {
59 | this.status = `ERROR loading model "${name}": ${error}`;
60 | });
61 | }
62 | 
63 | async setModel(model: string) {
64 | if (getModels().indexOf(model) === -1) {
65 | throw new Error(`Model "${model}" not found!`);
66 | }
67 | await this.updateComplete;
68 | const dropdown = this.shadowRoot!.querySelector('#model_dropdown') as HTMLSelectElement;
69 | dropdown.value = model;
70 | dropdown.dispatchEvent(new Event('change'));
71 | }
72 | 
73 | override render() {
74 | const options = getModels().map((model: string) =>
75 | html`<option value="${model}">${model}</option>`);
76 | return html`
77 | <div class="controls">
78 | <select id="model_dropdown" @change=${this.onModelChange}>
79 | ${options}
80 | </select>
81 | <progress value=${this.progress} max="1"></progress>
82 | 
83 | <div class="status">${this.status}</div>
84 | </div>
85 | `; 86 | } 87 | } 88 | 89 | declare global { 90 | interface HTMLElementTagNameMap { 91 | 'model-controls': ModelControls; 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/exports.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview some useful exports to play around with the models & 20 | * tokenizers. 21 | * 22 | * Simple usage (see ./playground.html for more complete usage example): 23 | * 24 | * model = lit.Model('tiny'); 25 | * model.load(progress => console.log('loading...', progress)); 26 | * console.log(model.computeProbabilities(['a dog', 'a cat'], '0')); 27 | */ 28 | 29 | import {Model} from './lit_demo/compute'; 30 | import {getImageUrl, setBaseUrl} from './lit_demo/constants'; 31 | import {ImageData} from './lit_demo/data'; 32 | import * as tf from '@tensorflow/tfjs-core'; 33 | 34 | // tslint:disable-next-line:no-any Export symbols into global namespace. 35 | (window as any).lit = { Model, getImageUrl, ImageData, setBaseUrl }; 36 | // tslint:disable-next-line:no-any Export symbols into global namespace. 37 | // tslint:disable-next-line:ban-module-namespace-object-escape Export all of TF. 38 | (window as any).tf = tf; 39 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/lit_demo/app.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Global app state. 20 | */ 21 | 22 | import {ImageData} from './data'; 23 | import {Models} from './compute'; 24 | 25 | /** 26 | * Container class holding image data and models. 27 | * 28 | * The main application component would typically call `load()` and then show 29 | * the components depending on this class asynchronously. 30 | */ 31 | export class App { 32 | 33 | imageData = new ImageData(); 34 | models = new Models(); 35 | 36 | ready: boolean = false; 37 | 38 | async load() { 39 | await this.imageData.load(); 40 | this.ready = true; 41 | } 42 | } 43 | 44 | /** 45 | * Global app state. 
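 * (Usage sketch, not from the source: call `await app.load()` once at startup; components then read `app.imageData` and `app.models`.)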
46 | */
47 | export const app = new App();
48 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/lit_demo/constants.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | /**
19 | * @fileoverview Project-wide constants.
20 | */
21 | 
22 | // Can be overwritten with setBaseUrl() below.
23 | // let baseUrl = 'https://google-research.github.io/vision_transformer/lit';
24 | let baseUrl = 'https://figur.li/jax2tfjs';
25 | // Can be overwritten with setModels() below.
26 | let models = ['tiny', 'small'];
27 | 
28 | /** Allows to set a new base URL, on which all other URLs are based. */
29 | export const setBaseUrl = (newBaseUrl: string) => {
30 | baseUrl = newBaseUrl;
31 | };
32 | 
33 | /** Retrieves URL for a model-specific file (vocabulary, embeddings, ...). */
34 | export const getModelFileUrl = (name: string, relativePath: string) => (
35 | `${baseUrl}/data/models/${name}/${relativePath}`
36 | );
37 | 
38 | /** Retrieves the URL for the images information JSON file. */
39 | export const getImagesInfoUrl = () => `${baseUrl}/data/images/info.json`;
40 | 
41 | /** Retrieves the URL for an image. */
42 | export const getImageUrl = (id: string) => `${baseUrl}/data/images/${id}.jpg`;
43 | 
44 | /** Returns names of available models. */
45 | export const getModels = () => models;
46 | 
47 | /** Sets names of available models. */
48 | export const setModels = (newModels: string[]) => {
49 | models = newModels;
50 | };
51 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/lit_demo/data.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | /**
19 | * @fileoverview Accessing additional data.
20 | */
21 | 
22 | import {getImagesInfoUrl} from './constants';
23 | 
24 | /**
25 | * Information about a single image.
26 | */
27 | export interface ImageRow {
28 | /** Stable ID of the image. */
29 | id: string;
30 | /** Set of example prompts for this image. */
31 | prompts: string;
32 | /** License of the image. */
33 | license: string;
34 | /** Where the image was originally downloaded from.
*/ 35 | source: string; 36 | /** Short description of image. */ 37 | description: string; 38 | } 39 | /** 40 | * Contains information about all images. 41 | */ 42 | export class ImageData { 43 | 44 | rows: ImageRow[] = []; 45 | /** Will be set to `true` when `load()` finishes. */ 46 | ready = false; 47 | 48 | /** 49 | * Gets an image by ID. Throws an error if image is not found, data is not 50 | * loaded, or ID is not unique. 51 | */ 52 | get(id: string): ImageRow { 53 | if (!this.ready) { 54 | throw new Error('ImageData not loaded!'); 55 | } 56 | const matching = this.rows.filter(row => row.id === id); 57 | if (matching.length !== 1) { 58 | throw new Error(`Got unexpected ${matching.length} matches for id="${id}"`); 59 | } 60 | return matching[0]; 61 | } 62 | 63 | /** 64 | * Loads image data asynchronously. 65 | */ 66 | async load() { 67 | this.rows = ( 68 | await fetch(getImagesInfoUrl()) 69 | .then(response => { 70 | console.log('response', response); 71 | return response.json(); 72 | }) 73 | ); 74 | this.ready = true; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/style.scss: -------------------------------------------------------------------------------- 1 | // General styles for the page. 2 | 3 | @import './style/colors'; 4 | @import './style/mixins'; 5 | 6 | html { 7 | font-size: 14px; 8 | line-height: 1.6em; 9 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, 10 | Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, 11 | sans-serif; 12 | text-size-adjust: 100%; 13 | -ms-text-size-adjust: 100%; 14 | -webkit-text-size-adjust: 100%; 15 | 16 | @media (min-width: 1200px) { 17 | width: 1024px; 18 | margin: 0 auto; 19 | } 20 | @media (min-width: 768px) { 21 | font-size: 16px; 22 | } 23 | 24 | color: var(--text-fg); 25 | background: var(--text-bg); 26 | 27 | body { 28 | margin: 0; 29 | padding: 0rem 1rem 10rem; 30 | } 31 | } 32 | 33 | a, 34 | a:visited { 35 | color: var(--link-col); 36 | } 37 | 38 | h1 { 39 | font-weight: 700; 40 | font-size: 2rem; 41 | line-height: 1.3em; 42 | } 43 | 44 | p { 45 | font-size: 1.06rem; 46 | line-height: 1.3em; 47 | } 48 | 49 | input { 50 | font-size: 1rem; 51 | 52 | &::placeholder { 53 | color: var(--placeholder-col); 54 | } 55 | } 56 | 57 | .note { 58 | font-style: normal; 59 | border: none; 60 | border-radius: 2px; 61 | margin-left: auto; 62 | margin-right: auto; 63 | 64 | padding: 0.5rem 0.5rem 0.5rem 2rem; 65 | width: 90%; 66 | 67 | @include phone-portrait { 68 | width: 100%; 69 | padding: 0.5rem; 70 | box-sizing: border-box; 71 | } 72 | 73 | background-color: var(--note-bg); 74 | color: var(--note-fg); 75 | 76 | &.warning { 77 | background-color: var(--warn-bg); 78 | color: var(--warn-fg); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/style/colors.scss: -------------------------------------------------------------------------------- 1 | // Dark and light mode colors. 
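// (Note: all colors below are CSS custom properties on :root, with dark-mode
// overrides inside a prefers-color-scheme media query, so components restyle
// automatically through their var(...) lookups.)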
2 | 
3 | :root {
4 | --text-bg: hsl(0, 0%, 97%);
5 | --gray-border: hsla(0, 0%, 0%, 0.1);
6 | --gray: rgba(0, 0, 0, 0.6);
7 | --border-radius: 5px;
8 | --orange: hsl(24, 100%, 50%);
9 | --distill-blue: hsl(200, 50%, 25%);
10 | --blue: #337699;
11 | --green: #3db867;
12 | --text-fg: rgb(15, 15, 15);
13 | --text-red: rgb(220, 0, 0);
14 | --bar-col: rgb(171, 199, 227);
15 | --link-col: rgb(0, 0, 238);
16 | --placeholder-col: rgb(166, 166, 166);
17 | --note-bg: #e1f5fe;
18 | --note-fg: #1a6ebb;
19 | --warn-bg: #ffe1aa;
20 | --warn-fg: #a16800;
21 | --error-bg: #850000;
22 | --error-fg: white;
23 | 
24 | @media (prefers-color-scheme: dark) {
25 | --text-bg: rgb(56, 56, 56);
26 | --text-fg: rgb(213, 213, 213);
27 | --bar-col: rgb(20, 109, 163);
28 | --link-col: rgb(66, 165, 245);
29 | 
30 | --note-fg: rgb(121 157 190);
31 | --note-bg: rgb(2 59 85);
32 | --warn-bg: #784e00;
33 | --warn-fg: #edbe68;
34 | }
35 | }
36 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/style/mixins.scss: -------------------------------------------------------------------------------- 1 | // Useful mixins.
2 | 
3 | // To wrap styles that should only trigger for phones in portrait mode.
4 | @mixin phone-portrait {
5 | @media only screen and (max-device-width: 800px) and (orientation: portrait) {
6 | @content;
7 | }
8 | }
9 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/common.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | /**
19 | * @fileoverview Utility code shared between tokenizers.
20 | */
21 | 
22 | /**
23 | * A vocabulary consists of a list of tokens and an optional numerical value.
24 | * The numerical value is used by the unigram algorithm to find the best
25 | * tokenization, and is ignored by the BPE algorithm.
26 | */
27 | export type Vocabulary = Array<[string, number]>;
28 | 
29 | /**
30 | * Converts a string to a sequence of tokens.
31 | */
32 | export interface Tokenizer {
33 | encode(input: string): number[];
34 | }
35 | 
36 | /**
37 | * Factory for new `Tokenizer`.
38 | */
39 | export interface TokenizerConstructor {
40 | new (vocabulary: Vocabulary): Tokenizer;
41 | }
42 | 
43 | /**
44 | * Unicode-aware character iteration of strings.
45 | */
46 | export const stringToChars = (input: string): string[] => {
47 | const symbols = [];
48 | for (const symbol of input) {
49 | symbols.push(symbol);
50 | }
51 | return symbols;
52 | };
53 | 
54 | /**
55 | * Special separator character used to delimit sub-word tokens.
56 | */
57 | export const TOKEN_SEPARATOR =
58 | '\u2581'; // This is the unicode character 'lower one eighth block'.
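// (Illustration of the code-point iteration above: stringToChars('a😹')
// yields ['a', '😹'], whereas 'a😹'.split('') would split the emoji's
// surrogate pair into two invalid halves.)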
59 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * @fileoverview Tokenizers and tokenizer mappings. 20 | */ 21 | 22 | import {Tokenizer, TokenizerConstructor, Vocabulary} from './common'; 23 | import * as sentencepieceBpe from './sentencepiece_bpe'; 24 | import * as sentencepieceUnigram from './sentencepiece_unigram'; 25 | 26 | export {Tokenizer, Vocabulary} from './common'; 27 | 28 | const TOKENIZERS = new Map([ 29 | ['BPE', sentencepieceBpe.Tokenizer], 30 | ['UNIGRAM', sentencepieceUnigram.Tokenizer], 31 | ]); 32 | 33 | /** 34 | * Returns a tokenizer of type `name` using `vocabulary`. 35 | */ 36 | export const getTokenizer = (name: string, vocabulary: Vocabulary): Tokenizer => { 37 | const ctor = TOKENIZERS.get(name); 38 | if (!ctor) throw new Error(`Unknown tokenizer: ${name}`); 39 | return new ctor(vocabulary); 40 | }; 41 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright Big Vision Authors 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common'; 19 | 20 | interface Candidate { 21 | piece: string; 22 | pos: number; 23 | score: number; 24 | } 25 | 26 | const scoreDesc = (a: Candidate, b: Candidate) => b.score - a.score; 27 | 28 | function processInput(str: string): string { 29 | const normalized = str.normalize('NFKC'); 30 | return normalized.length > 0 ? 31 | TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) : 32 | normalized; 33 | } 34 | 35 | /** 36 | * Sentencepiece tokenizer implementing the BPE algorithm. 
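 * (Summary of the merge loop in encode() below: the input is NFKC-normalized
 * and spaces become TOKEN_SEPARATOR; then the highest-scoring adjacent pair
 * of pieces present in the vocabulary is repeatedly fused until no fusable
 * pair remains, and the final pieces are mapped to their vocabulary indices.)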
37 | */
38 | export class Tokenizer implements TokenizerInterface {
39 | 
40 | // piece -> [score, index]
41 | private readonly map: Map<string, [number, number]>;
42 | 
43 | constructor(vocabulary: Vocabulary) {
44 | this.map = new Map();
45 | vocabulary.forEach(([piece, score], idx) => {
46 | if (this.map.has(piece)) {
47 | throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`);
48 | }
49 | this.map.set(piece, [score, idx]);
50 | });
51 | }
52 | 
53 | encode(input: string): number[] {
54 | const processed: string = processInput(input);
55 | let pieces: string[] = stringToChars(processed);
56 | 
57 | while (true) {
58 | const candidates: Candidate[] = [];
59 | for (let i = 0; i < pieces.length - 1; i++) {
60 | const fused = pieces[i] + pieces[i + 1];
61 | const el = this.map.get(fused);
62 | if (el) {
63 | candidates.push({ piece: fused, pos: i, score: el[0] });
64 | }
65 | }
66 | if (candidates.length === 0) {
67 | break;
68 | }
69 | candidates.sort(scoreDesc);
70 | const best = candidates[0];
71 | pieces = [
72 | ...pieces.slice(0, best.pos),
73 | best.piece,
74 | ...pieces.slice(best.pos + 2)
75 | ];
76 | }
77 | 
78 | return pieces.map(piece => this.map.get(piece)![1]);
79 | }
80 | }
81 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe_test.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | import 'jasmine';
19 | 
20 | describe('sentencepiece bpe test', () => {
21 | it('computes a thing when asked', () => {});
22 | });
23 | 
24 | import * as bpe from './sentencepiece_bpe';
25 | import {TOKEN_SEPARATOR, Vocabulary} from './common';
26 | 
27 | const vocab: Vocabulary = [
28 | [TOKEN_SEPARATOR, 0], // 0
29 | ['a', 0], // 1
30 | ['e', 0], // 2
31 | ['s', 0], // 3
32 | ['t', 0], // 4
33 | ['te', -1], // 5
34 | ['st', -2], // 6
35 | ['test', -3], // 7
36 | ['tes', -4], // 8
37 | ];
38 | 
39 | describe('BPE Tokenizer', () => {
40 | let tokenizer: bpe.Tokenizer;
41 | beforeAll(() => {
42 | tokenizer = new bpe.Tokenizer(vocab);
43 | });
44 | 
45 | it('should tokenize correctly', () => {
46 | expect(tokenizer.encode('a test')).toEqual([0, 1, 0, 7]);
47 | });
48 | });
49 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_unigram_test.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | import {Tokenizer} from './sentencepiece_unigram';
19 | 
20 | const stubbedTokenizerVocab = [
21 | ['�', 0],
22 | ['<s>', 0],
23 | ['</s>', 0],
24 | ['extra_token_id_1', 0],
25 | ['extra_token_id_2', 0],
26 | ['extra_token_id_3', 0],
27 | ['▁', -2],
28 | ['▁a', -1],
29 | ['▁ç', -2],
30 | ['a', -3],
31 | ['.', -1],
32 | ['▁I', -1],
33 | ['▁like', -1],
34 | ['▁it', -1],
35 | ['I', -2],
36 | ['like', -2],
37 | ['it', -2],
38 | ['l', -3],
39 | ['i', -3],
40 | ['k', -3],
41 | ['e', -3],
42 | ['i', -3],
43 | ['t', -3]
44 | ];
45 | 
46 | describe('Universal Sentence Encoder tokenizer', () => {
47 | let tokenizer: Tokenizer;
48 | beforeAll(() => {
49 | tokenizer = new Tokenizer(stubbedTokenizerVocab as Array<[string, number]>);
50 | });
51 | 
52 | it('basic usage', () => {
53 | expect(tokenizer.encode('Ilikeit.')).toEqual([11, 15, 16, 10]);
54 | });
55 | 
56 | it('handles whitespace', () => {
57 | expect(tokenizer.encode('I like it.')).toEqual([11, 12, 13, 10]);
58 | });
59 | 
60 | it('should normalize inputs', () => {
61 | expect(tokenizer.encode('ça')).toEqual(tokenizer.encode('c\u0327a'));
62 | });
63 | 
64 | it('should handle unknown inputs', () => {
65 | expect(() => tokenizer.encode('😹')).not.toThrow();
66 | });
67 | 
68 | it('should treat consecutive unknown inputs as a single word', () => {
69 | expect(tokenizer.encode('a😹😹')).toEqual([7, 0]);
70 | });
71 | });
72 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tokenizers/trie.ts: -------------------------------------------------------------------------------- 1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | // Copied from
19 | // https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
20 | 
21 | import {stringToChars} from './common';
22 | 
23 | // [token, score, index]
24 | type OutputNode = [string[], number, number];
25 | 
26 | class TrieNode {
27 | parent: TrieNode|null;
28 | end: boolean;
29 | children: {[firstSymbol: string]: TrieNode};
30 | word: OutputNode;
31 | 
32 | constructor() {
33 | this.parent = null;
34 | this.children = {};
35 | this.end = false;
36 | this.word = [[], 0, 0];
37 | }
38 | }
39 | 
40 | /**
41 | * Simple trie data structure.
42 | */
43 | export class Trie {
44 | root: TrieNode;
45 | 
46 | constructor() {
47 | this.root = new TrieNode();
48 | }
49 | 
50 | /**
51 | * Inserts a token into the trie.
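 * (Note: walks the trie one unicode character at a time via stringToChars,
 * creating child nodes as needed; the node of the final character is marked
 * as a word end and records the token's score and vocabulary index.)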
52 | */ 53 | insert(word: string, score: number, index: number) { 54 | let node = this.root; 55 | 56 | const symbols = stringToChars(word); 57 | 58 | for (let i = 0; i < symbols.length; i++) { 59 | if (!node.children[symbols[i]]) { 60 | node.children[symbols[i]] = new TrieNode(); 61 | node.children[symbols[i]].parent = node; 62 | node.children[symbols[i]].word[0] = node.word[0].concat(symbols[i]); 63 | } 64 | 65 | node = node.children[symbols[i]]; 66 | if (i === symbols.length - 1) { 67 | node.end = true; 68 | node.word[1] = score; 69 | node.word[2] = index; 70 | } 71 | } 72 | } 73 | 74 | /** 75 | * Returns an array of all tokens starting with ss. 76 | * 77 | * @param ss The prefix to match on. 78 | */ 79 | commonPrefixSearch(ss: string[]): OutputNode[] { 80 | const output: OutputNode[] = []; 81 | let node = this.root.children[ss[0]]; 82 | 83 | for (let i = 0; i < ss.length && node; i++) { 84 | if (node.end) { 85 | output.push(node.word); 86 | } 87 | node = node.children[ss[i + 1]]; 88 | } 89 | 90 | if (!output.length) { 91 | output.push([[ss[0]], 0, 0]); 92 | } 93 | 94 | return output; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/tools/lit_demo/src/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "dist", 4 | "target": "es6", 5 | "module": "commonjs", 6 | "lib": ["dom", "DOM.Iterable", "es2019", "es2020.string"], 7 | "types": ["node", "jasmine", "resize-observer-browser"], 8 | "moduleResolution": "node", 9 | "allowJs": false, 10 | "pretty": true, 11 | "resolveJsonModule": true, 12 | "sourceMap": false, 13 | "skipLibCheck": true, 14 | "removeComments": true, 15 | "esModuleInterop": true, 16 | "importsNotUsedAsValues": "preserve", 17 | "downlevelIteration": true, 18 | "skipDefaultLibCheck": true, 19 | "preserveConstEnums": false, 20 | "experimentalDecorators": true, 21 | "emitDecoratorMetadata": true, 22 | "noErrorTruncation": false, 23 | "noEmitOnError": false, 24 | "declaration": false, 25 | "stripInternal": true, 26 | "inlineSourceMap": true, 27 | "inlineSources": true, 28 | "importHelpers": true, 29 | "allowUnreachableCode": false, 30 | "noFallthroughCasesInSwitch": true, 31 | "noImplicitAny": true, 32 | "noImplicitReturns": false, 33 | "noImplicitThis": true, 34 | "strictBindCallApply": true, 35 | "strictFunctionTypes": true, 36 | "strictNullChecks": false, 37 | "strictPropertyInitialization": false 38 | }, 39 | "include": ["./client", "./examples"], 40 | "compileOnSave": false 41 | } 42 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/trainers/proj/flexi/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Few common utils used in both/all flexi-trainers.""" 16 | import functools 17 | import itertools 18 | import numpy as np 19 | 20 | 21 | def mkrng(xid, wid, step): 22 | # Need to cap at 0, for example localruns use -1. 23 | rng_key = (max(xid, 0), max(wid, 0), max(step, 0)) 24 | return np.random.default_rng(rng_key) 25 | 26 | 27 | def mkprob(x): 28 | if x is None: 29 | return x 30 | return np.array(x) / np.sum(x) 31 | 32 | 33 | def choice(values, ratios, rng=None): 34 | rng = rng or np.random.default_rng() 35 | return rng.choice(values, p=mkprob(ratios)) 36 | 37 | 38 | def mkpredictfns(predict_fn, config, template="predict_{x}"): 39 | # If we have two flexi args a=[1,2], b=[10,20], then we create a 40 | # predict_fn for all possible combinations, named "predict_a=1_b=10" etc. 41 | all_combinations = [dict(comb) for comb in itertools.product( 42 | *[[(arg, val) for val in config[arg].v] for arg in config] 43 | )] 44 | return { 45 | template.format(x="_".join(f"{k}={v}" for k, v in kw.items())): 46 | functools.partial(predict_fn, **kw) 47 | for kw in all_combinations} 48 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/trainers/proj/givt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for GIVT stage I and II trainers.""" 16 | 17 | from typing import Any 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | def unbin_depth( 24 | depth: jax.Array, 25 | *, 26 | min_depth: float, 27 | max_depth: float, 28 | num_bins: int, 29 | ) -> jax.Array: 30 | """Transform a depth map with binned values into a float-valued depth map. 31 | 32 | Args: 33 | depth: Depth map whose binned values are encoded in one-hot fashion along 34 | the last dimension. 35 | min_depth: Minimum binned depth value. 36 | max_depth: Maximum value of binned depth. 37 | num_bins: Number of depth bins. 38 | 39 | Returns: 40 | Float-valued depth map. 41 | """ 42 | depth = jnp.argmax(depth, axis=-1) 43 | depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation. 44 | depth /= num_bins 45 | return depth * (max_depth - min_depth) + min_depth 46 | 47 | 48 | def get_local_rng( 49 | seed: int | jax.Array, 50 | batch: Any, 51 | ) -> jax.Array: 52 | """Generate a per-image seed based on the image id or the image values. 53 | 54 | Args: 55 | seed: Random seed from which per-image seeds should be derived. 56 | batch: Pytree containing a batch of images (key "image") and optionally 57 | image ids (key "image/id"). 58 | 59 | Returns: 60 | Array containing per-image ids. 
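  (Note: the ids, or the value-derived fake ids, are folded into
  jax.random.PRNGKey(seed) with jax.random.fold_in inside the scan, so the
  result is deterministic for a fixed seed and batch contents.)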
61 | """ 62 | fake_id = None 63 | if "image" in batch: 64 | fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32) 65 | return jax.lax.scan( 66 | lambda k, x: (jax.random.fold_in(k, x), None), 67 | jax.random.PRNGKey(seed), 68 | batch.get("image/id", fake_id), 69 | )[0] 70 | 71 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/trainers/proj/uvim/coco_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities to inspect coco data and predictions in notebooks.""" 16 | # pylint: disable=consider-using-from-import 17 | import functools 18 | import json 19 | 20 | import numpy as np 21 | from panopticapi import utils as pycoco_utils 22 | from skimage import segmentation 23 | 24 | import tensorflow.io.gfile as gfile 25 | 26 | 27 | import os 28 | ROOT = os.environ.get('COCO_DATA_DIR', '.') 29 | 30 | 31 | PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json' 32 | 33 | 34 | @functools.lru_cache(maxsize=None) 35 | def _coco_panoptic_categories(): 36 | with gfile.GFile(PANOPTIC_COCO_CATS_FILE, 'r') as f: 37 | categories_list = json.load(f) 38 | return tuple(categories_list) 39 | 40 | 41 | def rgb_panoptic_from_twochannels(twochannels, boundaries: bool = False): 42 | """Makes a RGB panoptic output and segments_info from a twochannels view.""" 43 | semantics = twochannels[..., 0] 44 | instances = twochannels[..., 1] 45 | max_instances = np.max(instances) + 1 46 | merged = semantics * max_instances + instances 47 | merged = np.where(semantics < 0, semantics, merged) 48 | 49 | categories_list = _coco_panoptic_categories() 50 | categories = {category['id']: category for category in categories_list} 51 | id_generator = pycoco_utils.IdGenerator(categories) 52 | segments_info = {} 53 | rgb = np.zeros((*instances.shape[:2], 3), dtype=np.uint8) 54 | 55 | for merged_id in np.unique(merged): 56 | if merged_id // max_instances > 0: 57 | category = categories_list[int(merged_id // max_instances) - 1] 58 | segment_id, color = id_generator.get_id_and_color(category['id']) 59 | else: 60 | category = {'id': -1, 'name': 'void', 'isthing': False} 61 | segment_id, color = -1, np.array([0, 0, 0]) 62 | segments_info[segment_id] = { 63 | 'id': segment_id, 64 | 'color': color, 65 | 'category_id': category['id'], 66 | 'name': category['name'], 67 | 'isthing': category['isthing'], 68 | } 69 | rgb[merged == merged_id] = color 70 | 71 | if boundaries: 72 | boundaries = segmentation.find_boundaries( 73 | pycoco_utils.rgb2id(rgb), mode='thick') 74 | rgb[boundaries] = 0 75 | return rgb, segments_info 76 | -------------------------------------------------------------------------------- /palivla_digit/big_vision/trainers/proj/uvim/colorization_task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision 
Authors. 2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Inputs, outputs and losses for colorization task."""
16 | import einops
17 | import jax.numpy as jnp
18 | import numpy as np
19 | 
20 | ONE_HOT_AXIS = -2
21 | 
22 | 
23 | def input_pp(batch, config):
24 | """Make inputs for colorization task."""
25 | if "labels" not in batch:
26 | # During predict of phase2 there is no 'labels' field.
27 | x = None
28 | else:
29 | hp, wp = config.model.patch_size
30 | x = {
31 | "color": batch["labels"],
32 | }
33 | # Convert labels from (B, H, W, C) to (B, num_patches, C, patch_size)
34 | x["color"] = einops.rearrange(
35 | x["color"], "b (hn hp) (wn wp) c -> b (hn wn) c (hp wp)", hp=hp, wp=wp)
36 | ctx = batch.get("image_ctx", batch.get("image", None))
37 | return {"ctx": ctx, "x": x}
38 | 
39 | 
40 | def loss_fn(logits, batch, config):
41 | """Compute loss for colorization task."""
42 | labels = input_pp(batch, config)["x"]
43 | error = logits["color"] - labels["color"]
44 | loss = jnp.square(error)
45 | return loss, {"loss_color": loss}
46 | 
47 | 
48 | def predict_outputs(logits, config):
49 | """Make outputs for colorization task."""
50 | # Map logits to (height, width, channels).
51 | hp, wp = config.model.patch_size
52 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp))
53 | assert ONE_HOT_AXIS == -2, "Rearrange below depends on this."
54 | output = einops.rearrange(
55 | logits["color"],
56 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
57 | hn=hn,
58 | wn=wn,
59 | hp=hp,
60 | wp=wp)
61 | output = jnp.clip(output, -1., 1.)
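# Note: the clip above keeps regressed colors inside the normalized
# [-1, 1] range of the labels.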
62 | return {"color": output} 63 | -------------------------------------------------------------------------------- /palivla_digit/palivla/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/palivla/__init__.py -------------------------------------------------------------------------------- /palivla_digit/palivla/modality_embedder.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax.linen as nn 3 | 4 | # quick hack to make embedding w/ data type param json-serializable 5 | class ModalityEmbedder(nn.Module): 6 | num_embeddings: int 7 | embedding_dim: int 8 | dtype_str: str = 'float32' 9 | 10 | @nn.compact 11 | def __call__(self, x): 12 | return nn.Embed( 13 | num_embeddings=self.num_embeddings, 14 | features=self.embedding_dim, 15 | dtype=getattr(jnp, self.dtype_str), 16 | )(x) -------------------------------------------------------------------------------- /palivla_digit/palivla/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Sequence, Mapping, Union 2 | 3 | import jax 4 | from flax.typing import Collection, VariableDict 5 | from flax.struct import dataclass 6 | import chex 7 | 8 | Array = chex.Array 9 | ArrayTree = Union[chex.Array, Mapping[str, "ArrayTree"], Sequence["ArrayTree"]] 10 | Params = Collection 11 | Variables = VariableDict 12 | Updates = ArrayTree 13 | Data = ArrayTree 14 | Info = Dict[str, Any] 15 | 16 | 17 | @dataclass 18 | class TrainingBatch: 19 | sensors: Dict[str, jax.Array] 20 | sensors_mask: jax.Array 21 | actions_mask: jax.Array 22 | actions: jax.Array 23 | tokens: jax.Array 24 | tokens_ar: jax.Array 25 | tokens_loss: jax.Array 26 | tokens_mask: jax.Array 27 | language_validity: jax.Array | None = None 28 | tokens_ar_fuse: jax.Array | None = None 29 | tokens_loss_fuse: jax.Array | None = None 30 | gen_start: jax.Array | None = None 31 | modality_idx: jax.Array | None = None 32 | modal_mask: jax.Array | None = None 33 | 34 | 35 | @dataclass 36 | class RolloutBatch: 37 | sensor_data: Dict[str, jax.Array] 38 | sensor_masks: Dict[str, jax.Array] 39 | prompt: jax.Array 40 | prompt_mask: jax.Array 41 | prompt_ar: jax.Array 42 | -------------------------------------------------------------------------------- /palivla_digit/palivla/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import jax 3 | from jax.experimental import multihost_utils 4 | import numpy as np 5 | import tensorflow as tf 6 | import flax 7 | from palivla.types import Params 8 | 9 | def freeze_structure(structure): 10 | return jax.tree_util.tree_map( 11 | lambda x: tuple(freeze_structure(y) for y in x) if isinstance(x, list) else x, 12 | structure, 13 | is_leaf=lambda x: isinstance(x, list), 14 | ) 15 | 16 | def key_string(path, separator="/") -> str: 17 | def _component_to_string(component) -> str: 18 | if isinstance(component, jax.tree_util.SequenceKey): 19 | return str(component.idx) 20 | elif isinstance(component, jax.tree_util.DictKey): 21 | return str(component.key) 22 | elif isinstance(component, jax.tree_util.GetAttrKey): 23 | return str(component.name) 24 | elif isinstance(component, jax.tree_util.FlattenedIndexKey): 25 | return str(component.key) 26 | else: 27 | return str(component) 28 | return separator.join(_component_to_string(component) for 
component in path) 29 | 30 | 31 | def host_broadcast_str(x: str | None) -> str: 32 | """ 33 | Broadcast_one_to_all, but with a string. 34 | 35 | Works by padding the string to the length of the longest string and then 36 | broadcasting the result, then stripping the padding. 37 | 38 | Note: this will remove the padding from the end of the string. 39 | """ 40 | if x is None: 41 | x = "" 42 | 43 | max_len = multihost_utils.broadcast_one_to_all(len(x)) 44 | padded = x.ljust(max_len) 45 | 46 | encoded = np.array([ord(c) for c in padded], dtype=np.uint8)[:max_len] 47 | encoded = multihost_utils.broadcast_one_to_all(encoded) 48 | decoded = "".join([chr(u) for u in encoded]) 49 | 50 | return decoded.rstrip() 51 | 52 | 53 | def load_tvl_weights(pretrained_path: str) -> dict[tuple, np.ndarray]: 54 | with tf.io.gfile.GFile(pretrained_path, 'rb') as f: 55 | ckpt_dict = np.load(f, allow_pickle=False) 56 | keys, values = zip(*list(ckpt_dict.items())) 57 | return {tuple(k.split('|')): v for k, v in zip(keys, values)} 58 | 59 | 60 | def merge_params(init_params: Params, pretrained_params: Params) -> Params: 61 | def _merge(possible_param1, possible_param2): 62 | if possible_param2 is not None: 63 | return possible_param2 64 | return possible_param1 65 | flat_init_params = flax.traverse_util.flatten_dict(init_params) 66 | flat_pretrained_params = flax.traverse_util.flatten_dict(pretrained_params) 67 | params = {k: _merge(v_init, flat_pretrained_params.get(k, None)) for k, v_init in flat_init_params.items()} 68 | return flax.traverse_util.unflatten_dict(params) -------------------------------------------------------------------------------- /palivla_digit/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "big-vision" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "big-vision", 9 | "tqdm>=4.60.0", 10 | "absl-py>=0.12.0", 11 | "imageio>=2.31.1", 12 | "moviepy>=1.0.3", 13 | "chex>=0.1.86", 14 | "distrax>=0.1.5", 15 | "dlimp@git+https://github.com/kvablack/dlimp.git", 16 | "einops>=0.8.0", 17 | "flax>=0.9.0", 18 | "ipykernel", 19 | "jax>=0.4.34", 20 | "matplotlib>=3.9.2", 21 | "ml-collections>=0.1.1", 22 | "numpy<2.0.0", 23 | "optax>=0.2.3", 24 | "orbax-checkpoint>=0.7.0", 25 | "overrides>=7.7.0", 26 | "pip", 27 | "scalax>=0.2.4", 28 | "scikit-learn>=1.5.2", 29 | "scipy>=1.14.1", 30 | "tensorflow-probability>=0.24.0", 31 | "plotly>=5.16.1", 32 | "tfds-nightly>=4.9.0", 33 | "tf-nightly>=2.15.0", 34 | "tensorflow-text-nightly>=2.15.0", 35 | "tensorflow_hub>=0.14.0", 36 | "tensorflow_graphics", 37 | "wandb>=0.18.3", 38 | "protobuf>=3.20", 39 | "huggingface-hub>=0.27.0", 40 | "transformers>=4.47.1", 41 | "prettytable>=3.12.0", 42 | "funcsigs", 43 | "opencv-python", 44 | "pyquaternion", 45 | "librosa", 46 | "edgeml @ git+https://github.com/youliangtan/edgeml.git", 47 | "gym>=0.26", 48 | "jax-smi", 49 | "octo", 50 | "eval", 51 | ] 52 | 53 | [project.optional-dependencies] 54 | tpu = [ 55 | "jax[tpu]>=0.4.34", 56 | "libtpu-nightly", 57 | ] 58 | gpu = [ 59 | "jax[cuda12]==0.4.34" 60 | ] 61 | 62 | [tool.uv] 63 | find-links = ["https://storage.googleapis.com/jax-releases/libtpu_releases.html", "https://pypi.org/simple/tf-nightly/"] 64 | 65 | prerelease = "allow" 66 | conflicts = [ 67 | [ 68 | { extra = "tpu" }, 69 | { extra = "gpu" }, 70 | ], 71 | ] 72 | override-dependencies = [ 73 | # Always use tf-nightly and tfds-nightly instead of tensorflow 
and tensorflow_datasets
74 | "tensorflow ; sys_platform == 'never'",
75 | "tensorflow_datasets ; sys_platform == 'never'",
76 | "scipy>=1.14.1",
77 | "jax>=0.4.34",
78 | ]
79 | 
80 | [build-system]
81 | requires = ["hatchling"]
82 | build-backend = "hatchling.build"
83 | 
84 | [dependency-groups]
85 | dev = [
86 | "ipywidgets>=8.1.5",
87 | "isort>=6.0.0b2",
88 | "ruff>=0.8.4",
89 | ]
90 | 
91 | [tool.hatch.metadata]
92 | allow-direct-references = true
93 | 
94 | [tool.uv.sources]
95 | big-vision = { workspace = true }
96 | octo = { path = "../octo_digit/octo", editable = true }
97 | eval = { path = "../octo_digit/eval", editable = true } -------------------------------------------------------------------------------- /palivla_digit/run.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash
2 | 
3 | # Check if a TPU VM name is provided
4 | if [ $# -eq 0 ]; then
5 | echo "Usage: $0 <tpu_vm_name>"
6 | exit 1
7 | fi
8 | 
9 | TPU_VM_NAME=$1
10 | PROJECT="rail-tpus"
11 | 
12 | # Cache file for TPU name/zone mapping
13 | CACHE_FILE="$HOME/.cache/tpus"
14 | mkdir -p "$(dirname "$CACHE_FILE")"
15 | 
16 | # Check if the TPU info is already cached
17 | if [ -f "$CACHE_FILE" ]; then
18 | CACHED_INFO=$(grep "^$TPU_VM_NAME:" "$CACHE_FILE")
19 | if [ -n "$CACHED_INFO" ]; then
20 | CACHED_ZONE=$(echo "$CACHED_INFO" | cut -d':' -f2)
21 | NUM_WORKERS=$(echo "$CACHED_INFO" | cut -d':' -f3)
22 | fi
23 | fi
24 | 
25 | if [ -n "$CACHED_ZONE" ]; then
26 | ZONE=$CACHED_ZONE
27 | else
28 | # Get the TPU information
29 | for MAYBE_ZONE in us-central1-a us-central2-b europe-west4-b; do
30 | TPU_INFO=$(gcloud compute tpus tpu-vm describe $TPU_VM_NAME --project=$PROJECT --zone=$MAYBE_ZONE --format=json 2>/dev/null)
31 | if [ $? -eq 0 ]; then
32 | # Cache the successful name/zone mapping and number of workers
33 | ZONE=$MAYBE_ZONE
34 | NUM_WORKERS=$(echo "$TPU_INFO" | jq '.networkEndpoints | length')
35 | echo "$TPU_VM_NAME:$ZONE:$NUM_WORKERS" >> "$CACHE_FILE"
36 | break
37 | fi
38 | done
39 | fi
40 | 
41 | # Set the source and destination directories based on the zone
42 | if [[ $ZONE == "europe-west4-"* ]]; then
43 | DEST_DIR="$TPU_VM_NAME:/nfs/nfs3/users/kstachowicz/big_vision_multimodal"
44 | elif [[ $ZONE == "us-central2-"* ]]; then
45 | DEST_DIR="data-machine:/nfs/nfs2/users/kstachowicz/big_vision_multimodal"
46 | else
47 | echo "Unsupported zone: $ZONE"
48 | exit 1
49 | fi
50 | 
51 | echo "TPU_VM_NAME: $TPU_VM_NAME"
52 | echo "ZONE: $ZONE"
53 | echo "DEST_DIR: $DEST_DIR"
54 | echo "Number of workers: $NUM_WORKERS"
55 | 
56 | # Copy the source directory to the TPU VM
57 | rsync -avzL --exclude .git --exclude-from=.gitignore .
$DEST_DIR 58 | 
59 | # Launch the pod configuration
60 | POD_NAME=$TPU_VM_NAME tpc launch pod_config.py
61 | 
62 | # Connect to the pod
63 | bash ssh_pod.sh $TPU_VM_NAME
64 | 
65 | -------------------------------------------------------------------------------- /palivla_digit/setup.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash
2 | uv sync --extra gpu
3 | uv pip install -e ../octo_digit --no-deps
4 | uv pip install -e ../bridge_with_digit/widowx_envs -------------------------------------------------------------------------------- /palivla_digit/setup_pod.sh: -------------------------------------------------------------------------------- 1 | sudo apt -y update && sudo apt -y install nfs-common
2 | sudo mkdir -p -m 777 /nfs/nfs3
3 | sudo mount -o rw,intr 10.105.46.66:/nfs3 /nfs/nfs3
4 | 
5 | sudo usermod -u 3210 kstachowicz
6 | sudo groupmod -g 3210 kstachowicz
7 | sudo chown -R kstachowicz:kstachowicz /home/kstachowicz
8 | 
9 | sudo -i -u kstachowicz bash << EOF
10 | 
11 | git config --global --add safe.directory '*'
12 | 
13 | rm -f .bashrc
14 | ln -s /nfs/nfs3/users/kstachowicz/.bashrc .bashrc
15 | ln -s /nfs/nfs3/users/kstachowicz/.netrc .netrc
16 | EOF
17 | 
18 | /nfs/nfs3/users/kstachowicz/miniforge3/bin/conda init bash -------------------------------------------------------------------------------- /palivla_digit/ssh_pod.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash
2 | 
3 | if [ -z "$1" ]; then
4 | echo "Usage: ssh_pod.sh <tpu_vm_name>"
5 | exit 1
6 | fi
7 | 
8 | TPU_VM_NAME=$1
9 | 
10 | # Cache file for TPU name/zone mapping
11 | CACHE_FILE="$HOME/.cache/tpus"
12 | 
13 | # Check if the TPU info is already cached
14 | if [ -f "$CACHE_FILE" ]; then
15 | CACHED_INFO=$(grep "^$TPU_VM_NAME:" "$CACHE_FILE")
16 | if [ -n "$CACHED_INFO" ]; then
17 | ZONE=$(echo "$CACHED_INFO" | cut -d':' -f2)
18 | N_WORKERS=$(echo "$CACHED_INFO" | cut -d':' -f3)
19 | fi
20 | fi
21 | 
22 | # Use default values if not found in cache
23 | ZONE=${ZONE:-europe-west4-b}
24 | N_WORKERS=${N_WORKERS:-16}
25 | 
26 | echo "Connecting to $TPU_VM_NAME with $N_WORKERS workers in zone $ZONE..."
27 | 
28 | tmux kill-session -t tpc_${TPU_VM_NAME} || true
29 | tmux new -d -s tpc_${TPU_VM_NAME}
30 | for i in $(seq 0 $(($N_WORKERS - 1))); do
31 | TMUX_HEIGHT=$(tmux display-message -p '#{window_height}')
32 | TMUX_WIDTH=$(tmux display-message -p '#{window_width}')
33 | 
34 | tmux new-window -t tpc_${TPU_VM_NAME}:$i -k
35 | INNER_TMUX_COMMAND="tmux a -t tpc_${TPU_VM_NAME}"
36 | tmux send-keys -t tpc_${TPU_VM_NAME} "gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_VM_NAME --worker=$i -- -t $INNER_TMUX_COMMAND" Enter
37 | done
38 | tmux a -t tpc_${TPU_VM_NAME} || tmux switch -t tpc_${TPU_VM_NAME}
39 | --------------------------------------------------------------------------------