10 |
11 |
/*
 * usbreset: issue the USBDEVFS_RESET ioctl on the device node whose path is
 * given as the single command-line argument (presumably a usbfs node such as
 * /dev/bus/usb/BBB/DDD -- confirm with callers).
 *
 * Exit status: 0 on success; 1 on a usage error, an open() failure, or an
 * ioctl() failure. The descriptor is closed on every path after open().
 */
int main(int argc, char **argv)
{
    if (argc != 2) {
        fprintf(stderr, "Usage: usbreset device-filename\n");
        return 1;
    }
    const char *filename = argv[1];

    int fd = open(filename, O_WRONLY);
    if (fd < 0) {
        perror("Error opening output file");
        return 1;
    }

    printf("Resetting USB device %s\n", filename);
    int rc = ioctl(fd, USBDEVFS_RESET, 0);
    if (rc < 0) {
        /* Report first so errno still describes the ioctl failure,
         * then release the descriptor instead of leaking it. */
        perror("Error in ioctl");
        close(fd);
        return 1;
    }
    printf("Reset successful\n");

    close(fd);
    return 0;
}
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow
2 | protobuf
3 | funcsigs
4 | future
5 | imageio
6 | imageio-ffmpeg
7 | librosa
8 | matplotlib
9 | moviepy
10 | opencv-python
11 | numpy>=1.23.4
12 | pyquaternion
13 | scikit-learn
14 | scikit-image
15 | scipy
16 | six
17 | requests
18 | nvidia_smi
19 | rospkg
20 | modern_robotics==1.1.1
21 | gym
22 | tqdm
23 | transformations
24 | ipdb
25 | joblib
26 | pickle5
27 | h5py
28 | funcsigs
29 | git+https://github.com/rail-berkeley/oculus_reader.git
30 | adafruit-circuitpython-bno055
31 | sounddevice
32 | numpy_ringbuffer
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/go_to_neutral_pose.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
def main():
    """Create a StateReachingWidowX environment and send the arm to neutral."""
    # Imported inside the function so that merely importing this file does
    # not pull in widowx_envs.
    from widowx_envs.widowx_env import StateReachingWidowX

    StateReachingWidowX().move_to_neutral()


if __name__ == '__main__':
    main()
7 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/go_to_sleep_pose.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
def main():
    """Move the arm to neutral first, then command the sleep pose."""
    # Imported inside the function so that merely importing this file does
    # not pull in widowx_envs.
    from widowx_envs.widowx_env import StateReachingWidowX

    robot_env = StateReachingWidowX()
    robot_env.move_to_neutral()
    # The sleep pose is exposed on the controller's arm object, not the env.
    robot_env._controller.bot.arm.go_to_sleep_pose()


if __name__ == '__main__':
    main()
8 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run setup.sh (arm check + one-time workspace build), source the ROS/Python
# environments, then replace this shell with roslaunch for widowx_rs.launch
# with the RealSense camera enabled.

# Quote the directory so a checkout path containing spaces still works.
bash "$(dirname "$0")/setup.sh" || exit 1

source /opt/ros/noetic/setup.bash
source ~/interbotix_ws/devel/setup.bash
source ~/myenv/bin/activate

# using 'exec' here is very important because roslaunch needs to do some cleanup after it exits
# so when the container is killed the SIGTERM needs to be passed to roslaunch
#
# ${video_stream_provider_string} is intentionally unquoted: it is not set in
# this script, so unless it is present in the environment it expands to
# nothing (quoting it would hand roslaunch an empty argument).
# REALSENSE_SERIAL is quoted so it always stays a single roslaunch argument.
exec roslaunch widowx_controller widowx_rs.launch \
    ${video_stream_provider_string} camera_connector_chart:=/tmp/camera_connector_chart \
    "serial_no_camera1:=${REALSENSE_SERIAL}" \
    python_node:=false realsense:=true
18 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/scripts/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Pre-launch checks:
#   1. ROBONETV2_ARM must be set; otherwise print instructions and exit 1.
#   2. Build ~/interbotix_ws with catkin_make once; ~/.built is only touched
#      after a successful build and skips the build on later runs.
set -e

if [ -z "${ROBONETV2_ARM}" ]; then
  echo 'Env variable "ROBONETV2_ARM" is not set. Please define it based on https://github.com/Interbotix/interbotix_ros_manipulators/tree/main/interbotix_ros_xsarms'
  echo 'For instance in case of WidowX 250 Robot Arm 6DOF use:'
  echo 'echo "export ROBONETV2_ARM=wx250s" >> ~/.bashrc && source ~/.bashrc'
  exit 1
fi

# Work from $HOME so the relative marker path below resolves to ~/.built.
cd
[ -f ".built" ] || { cd ~/interbotix_ws && catkin_make && touch ~/.built; }
14 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
import os

# Read README.md next to this file rather than from the current working
# directory, and close the handle once it has been read.
_README_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README.md')
with open(_README_PATH, encoding='utf-8') as _readme:
    _long_description = _readme.read()

setuptools.setup(
    name='widowx_envs',
    version='0.0.1',
    packages=setuptools.find_packages(),
    license='MIT License',
    long_description=_long_description,
    entry_points={
        'console_scripts': [
            'widowx_env_service = widowx_envs.widowx_env_service:main',
        ],
    },
)
15 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(widowx_controller)
3 |
4 | ## Compile as C++11, supported in ROS Kinetic and newer
5 | # add_compile_options(-std=c++11)
6 |
7 | catkin_python_setup()
8 |
9 | ## Find catkin macros and libraries
10 | ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
11 | ## is used, also find other catkin packages
12 | find_package(catkin REQUIRED COMPONENTS
13 | rospy
14 | std_msgs
15 | message_generation
16 | )
17 |
18 | ## Generate services in the 'srv' folder
19 | add_service_files(
20 | FILES
21 | GotoNeutral.srv
22 | OpenGripper.srv
23 | MoveToEEP.srv
24 | MoveToState.srv
25 | GetGripperDesiredState.srv
26 | GetCartesianPose.srv
27 | GetState.srv
28 | GetVRButtons.srv
29 | EnableController.srv
30 | DisableController.srv
31 | SetGripperPosition.srv
32 | )
33 |
34 | ## Generate added messages and services with any dependencies listed here
35 | generate_messages(
36 | DEPENDENCIES
37 | std_msgs
38 | )
39 |
40 | catkin_package()
41 |
42 | include_directories(
43 | ${catkin_INCLUDE_DIRS}
44 | )
45 |
46 | catkin_install_python(PROGRAMS # add executable python files here
47 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
48 | )
49 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/launch/launch.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/launch/widowx_rs.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/package.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | widowx_controller
4 | 0.0.0
5 | Package for controlling WidowX robots
6 |
7 |
8 |
9 |
10 | jonathan
11 |
12 |
13 |
14 |
15 |
16 | TODO
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | message_generation
39 |
40 |
41 |
42 |
43 |
44 | message_runtime
45 |
46 |
47 |
48 |
49 | catkin
50 | rospy
51 | std_msgs
52 | rospy
53 | std_msgs
54 | rospy
55 | std_msgs
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from catkin_pkg.python_setup import generate_distutils_setup
3 |
4 | # fetch values from package.xml
# catkin_pkg fills in the remaining metadata from package.xml; only the Python
# package layout (sources under src/) is declared here.
setup_args = generate_distutils_setup(
    package_dir={'': 'src'},
    packages=['widowx_controller'],
)
setup(**setup_args)
10 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/src/widowx_controller/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/widowx_controller/src/widowx_controller/__init__.py
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/DisableController.srv:
--------------------------------------------------------------------------------
1 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/EnableController.srv:
--------------------------------------------------------------------------------
1 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/GetCartesianPose.srv:
--------------------------------------------------------------------------------
1 | ---
2 | float32[] eep
3 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/GetGripperDesiredState.srv:
--------------------------------------------------------------------------------
1 | ---
2 | float32 des_pos
3 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/GetState.srv:
--------------------------------------------------------------------------------
1 | ---
2 | float32[] joint_angles
3 | float32[] joint_velocities
4 | float32[] cartesian_pose
5 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/GetVRButtons.srv:
--------------------------------------------------------------------------------
1 | ---
2 | int32 handle
3 | int32 a
4 | int32 b
5 | int32 rj
6 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/GotoNeutral.srv:
--------------------------------------------------------------------------------
1 | float32 duration
2 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/MoveToEEP.srv:
--------------------------------------------------------------------------------
1 | float32[] des_eep
2 | float32 duration
3 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/MoveToState.srv:
--------------------------------------------------------------------------------
1 | float32[] target_xyz
2 | float32 target_zangle
3 | float32 duration
4 | ---
5 | int32 success
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/OpenGripper.srv:
--------------------------------------------------------------------------------
1 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_controller/srv/SetGripperPosition.srv:
--------------------------------------------------------------------------------
1 | float32 des_pos
2 | ---
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/__init__.py:
--------------------------------------------------------------------------------
# Package name exposed as a module-level string.
name = 'widowx_envs'
2 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/base/robot_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "wx250s": {
3 | "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
4 | },
5 | "wx250": {
6 | "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
7 | },
8 | "wx200": {
9 | "bound": [[0.19, -0.08, 0.05, -1.57, 0], [0.31, 0.08, 0.055, 1.57, 0]]
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/bridge_with_digit/widowx_envs/widowx_envs/policies/__init__.py
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/policies/policy.py:
--------------------------------------------------------------------------------
1 | """ This file defines the base class for the policy. """
2 | import abc, six
3 | import pickle as pkl
4 | import numpy as np
5 | import pdb
6 |
7 | from widowx_envs.utils.utils import AttrDict, Configurable
8 |
class Policy(Configurable):
    """Base class for policies.

    Concrete policies must override :meth:`act`; :meth:`reset` does nothing by
    default and :meth:`set_log_dir` just records the directory.
    """

    def _default_hparams(self):
        # GPU-related defaults, layered over the parent class's defaults.
        gpu_defaults = AttrDict(
            ngpu=1,
            gpu_id=0,
        )
        merged = super()._default_hparams()
        merged.update(gpu_defaults)
        return merged

    def act(self, *args):
        """Compute this time-step's output.

        Args:
            Whatever arguments the concrete policy requests (see the Agent code).

        Returns:
            A dict whose ``'actions'`` entry holds the action for this time-step.
        """
        raise NotImplementedError("Must be implemented in subclass.")

    def reset(self):
        """Per-episode reset hook; a no-op in the base class."""

    def set_log_dir(self, dir):
        """Store ``dir`` as ``self.traj_log_dir``."""
        self.traj_log_dir = dir
36 |
37 |
class DummyPolicy(Policy):
    """Placeholder policy whose ``act`` always reports ``actions=None``."""

    def __init__(self, ag_params, policyparams):
        """Accept the standard policy constructor arguments and ignore them."""

    def act(self, *args):
        return dict(actions=None)

    def reset(self):
        return None
48 |
49 |
class ReplayActions(Policy):
    """Policy that replays outputs recorded in ``<load_file>/policy_out.pkl``."""

    def __init__(self, ag_params, policyparams):
        """Load the recorded policy outputs.

        Args:
            ag_params: agent parameters; its ``env`` attribute is kept as ``self.env``.
            policyparams: overrides for :meth:`_default_hparams`; ``load_file``
                is the directory containing ``policy_out.pkl``.
        """
        self._hp = self._default_hparams()
        self._override_defaults(policyparams)
        # NOTE: pickle can execute arbitrary code while loading -- only point
        # load_file at trusted recordings. The context manager closes the file.
        with open(self._hp.load_file + '/policy_out.pkl', 'rb') as f:
            self.policy_out = pkl.load(f)
        self.env = ag_params.env

    def _default_hparams(self):
        replay_defaults = AttrDict(
            load_file="",
            type=None,
        )
        # super(Policy, self) starts from Policy's parent, so Policy's own
        # ngpu/gpu_id defaults are not included here.
        default_dict = super(Policy, self)._default_hparams()
        default_dict.update(replay_defaults)
        return default_dict

    def act(self, t):
        """Return the recorded output for time-step ``t``."""
        return self.policy_out[t]

    def reset(self):
        return None
72 |
73 |
class NullPolicy(Policy):
    """
    Returns 0 for all timesteps
    """

    def __init__(self, ag_params, policyparams):
        """
        Args:
            ag_params: agent parameters; ``ag_params['adim']`` is the action size.
            policyparams: overrides applied via ``_override_defaults``.
        """
        self._adim = ag_params['adim']
        self._hp = self._default_hparams()
        self._override_defaults(policyparams)

    def act(self, *args, **kwargs):
        """Return an all-zero action vector of length ``adim``.

        Any positional/keyword arguments are accepted and ignored, so this can
        be called like any other ``Policy.act`` (which takes ``*args``).
        """
        return {'actions': np.zeros(self._adim)}
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/utils/exceptions.py:
--------------------------------------------------------------------------------
class Bad_Traj_Exception(Exception):
    """Exception signalling a bad trajectory (conditions defined by the raiser).

    Like any ``Exception`` it accepts optional arguments such as a message;
    the former zero-argument ``__init__`` rejected them with a ``TypeError``.
    """
4 |
5 |
class Image_Exception(Exception):
    """Exception signalling an image problem (conditions defined by the raiser).

    Like any ``Exception`` it accepts optional arguments such as a message;
    the former zero-argument ``__init__`` rejected them with a ``TypeError``.
    """
9 |
10 |
class Environment_Exception(Exception):
    """Exception signalling an environment failure (conditions defined by the raiser).

    Like any ``Exception`` it accepts optional arguments such as a message;
    the former zero-argument ``__init__`` rejected them with a ``TypeError``.
    """
14 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/utils/grasp_utils.py:
--------------------------------------------------------------------------------
1 | from sklearn.linear_model import LinearRegression
2 | import numpy as np
3 | import time
4 | from sklearn.preprocessing import PolynomialFeatures
5 |
6 |
def compute_robot_transformation_matrix(a, b):
    """Fit an intercept-free ``LinearRegression`` from ``a`` to ``b``.

    Returns the transposed coefficient matrix, so for 2-D ``b`` the product
    ``a @ result`` approximates ``b``.
    """
    model = LinearRegression(fit_intercept=False)
    model.fit(a, b)
    return model.coef_.T
10 |
11 |
def convert_obs_to_image(obs, transpose=False):
    """Convert ``obs['image']`` into a uint8 image.

    The values (presumably in [0, 1]) are scaled by 255, reshaped to
    channel-first (3, 64, 64) and cast to uint8. With ``transpose=True`` the
    result is returned channel-last, i.e. (64, 64, 3).
    """
    print("taking picture...")
    scaled = obs['image'] * 255
    image = np.uint8(np.reshape(scaled, (3, 64, 64)))
    return np.transpose(image, (1, 2, 0)) if transpose else image
18 |
19 |
def rgb_to_robot_coords(rgb_coords, transmatrix):
    """Map image-space coordinates to robot coordinates.

    The coordinates are expanded with degree-2 polynomial features (which
    include a constant-1 bias column) and multiplied by ``transmatrix``.

    Args:
        rgb_coords: array of shape (d,) for a single point or (n, d) for n points.
        transmatrix: matrix applied to the polynomial features.

    Returns:
        ``features @ transmatrix`` with singleton dimensions squeezed out.

    Raises:
        ValueError: if ``transmatrix`` is None (this used to surface as an
            ``UnboundLocalError`` on the return statement).
    """
    assert len(rgb_coords.shape) <= 2
    if transmatrix is None:
        raise ValueError("rgb_to_robot_coords requires a transmatrix, got None")
    if len(rgb_coords.shape) == 1:
        # Promote a single point to a (1, d) batch.
        rgb_coords = np.array(rgb_coords[None])
    poly = PolynomialFeatures(2)
    rgb_coords = poly.fit_transform(rgb_coords)
    robot_coords = rgb_coords @ transmatrix
    return np.squeeze(robot_coords)
31 |
def get_image_obs(env, image_xyz=None, skip_move_to_neutral=False):
    """Move to an imaging pose, capture an observation, then move back.

    Args:
        env: environment exposing ``_controller`` (with ``get_joint_angles`` /
            ``set_joint_angles``), ``move_to_neutral``, ``move_to_state`` and
            ``current_obs``.
        image_xyz: if given, the arm moves there (zangle 0, 0.5 s) before the
            capture; otherwise it moves to neutral.
        skip_move_to_neutral: when ``image_xyz`` is None, capture from the
            current pose instead of moving to neutral.

    Returns:
        Whatever ``env.current_obs()`` returns.
    """
    saved_angles = env._controller.get_joint_angles()
    if image_xyz is not None:
        env.move_to_state(image_xyz, target_zangle=0, duration=0.5)
    elif not skip_move_to_neutral:
        env.move_to_neutral(0.5)
    time.sleep(0.2)  # give the camera time to catch up with the new pose
    obs = env.current_obs()
    # Return the arm to the joint configuration it had before the snapshot.
    env._controller.set_joint_angles(saved_angles, 0.5)
    return obs
45 |
46 |
def get_image(env, transpose=True, image_xyz=None, skip_move_to_neutral=False):
    """Capture an observation with get_image_obs and return it as a uint8 image."""
    return convert_obs_to_image(
        get_image_obs(env, image_xyz, skip_move_to_neutral),
        transpose=transpose,
    )
50 |
51 |
def execute_reach(env, reach_policy, reachpoint, noise=0.0, num_steps=6):
    """Roll out ``reach_policy`` toward ``reachpoint`` and return the last obs.

    Each step's action gets Gaussian noise on its first two dimensions and is
    then clipped to [-1, 1] before being sent to ``env.step``.

    Args:
        env: environment whose ``step(action)`` returns a 4-tuple starting
            with the observation.
        reach_policy: object with ``reset(reach_point=...)`` and
            ``get_action()`` returning ``(action, info)``.
        reachpoint: target forwarded to ``reach_policy.reset``.
        noise: standard deviation of the noise on the first two action dims.
        num_steps: number of environment steps (default 6, the previously
            hard-coded value).

    Returns:
        The observation from the last ``env.step``, or None when ``num_steps`` is 0.
    """
    reach_policy.reset(reach_point=reachpoint)
    noise_dims = 2
    obs = None
    for _step in range(num_steps):
        action, _info = reach_policy.get_action()
        # Only the first `noise_dims` entries are perturbed; the rest get std 0.
        noise_stds = [noise] * noise_dims + [0] * (len(action) - noise_dims)
        action = np.random.normal(loc=action, scale=noise_stds)
        action = np.clip(action, -1.0, 1.0)
        obs, _, _, _ = env.step(action)
    return obs
65 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import moviepy.editor as mpy
3 | import os
4 | import numpy as np
5 | from PIL import Image, ImageDraw
6 | from PIL import Image
7 |
def resize_store(t, target_array, input_array):
    """Copy every image of ``input_array`` into ``target_array[t]``.

    If the images' height/width (``input_array.shape[1:3]``) already equal
    ``target_array.shape[2:4]`` they are copied as-is; otherwise each one is
    resized with ``cv2.INTER_AREA`` before being stored.
    """
    height, width = target_array.shape[2:4]
    needs_resize = input_array.shape[1:3] != (height, width)
    for idx, img in enumerate(input_array):
        if needs_resize:
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
        target_array[t, idx] = img
18 |
19 |
def npy_to_gif(im_list, filename, fps=4):
    """Write ``im_list`` as an animated GIF to ``filename + '.gif'``.

    Missing parent directories are created first. A bare filename with no
    directory component is written to the current directory; previously this
    case crashed in ``os.makedirs('')``.
    """
    save_dir = os.path.dirname(filename)

    if save_dir and not os.path.exists(save_dir):
        print('creating directory: ', save_dir)
        # exist_ok tolerates the directory appearing between the check and here.
        os.makedirs(save_dir, exist_ok=True)

    clip = mpy.ImageSequenceClip(im_list, fps=fps)
    clip.write_gif(filename + '.gif')
29 |
30 |
def npy_to_mp4(im_list, filename, fps=4):
    """Write ``im_list`` as an MP4 video to ``filename + '.mp4'``.

    Missing parent directories (including nested ones, which ``os.mkdir``
    could not create) are created first. A bare filename with no directory
    component is written to the current directory; previously this case
    crashed in ``os.mkdir('')``.
    """
    save_dir = os.path.dirname(filename)

    if save_dir and not os.path.exists(save_dir):
        print('creating directory: ', save_dir)
        os.makedirs(save_dir, exist_ok=True)

    clip = mpy.ImageSequenceClip(im_list, fps=fps)
    clip.write_videofile(filename + '.mp4')
40 |
def draw_text_image(text, background_color=(255,255,255), image_size=(30, 64), dtype=np.float32):
    """Render ``text`` in black at offset (4, 0) on a solid-colour canvas.

    Args:
        text: string to draw; nothing is drawn when it is falsy.
        background_color: RGB fill colour of the canvas.
        image_size: (height, width) of the result; reversed for PIL, which
            expects (width, height).
        dtype: ``np.float32`` returns a float32 array divided by 255; any other
            value returns the array produced from the PIL image unchanged.
    """
    canvas = Image.new('RGB', image_size[::-1], background_color)
    pen = ImageDraw.Draw(canvas)
    if text:
        pen.text((4, 0), text, fill=(0, 0, 0))
    pixels = np.array(canvas)
    if dtype == np.float32:
        return pixels.astype(np.float32) / 255.
    return pixels
51 |
52 |
def draw_text_onimage(text, image, color=(255, 0, 0)):
    """Draw ``text`` at offset (4, 0) onto ``image``; return float32 scaled by 1/255.

    Args:
        text: string to draw.
        image: uint8 image, or a floating-point image presumably in [0, 1],
            which is multiplied by 255 and cast to uint8 first. Any floating
            dtype is accepted; previously only float32 was converted, so
            float64 input failed the uint8 assertion.
        color: fill colour for the text.

    Returns:
        ``np.array(drawn_image).astype(np.float32) / 255.``
    """
    if np.issubdtype(image.dtype, np.floating):
        image = (image * 255.).astype(np.uint8)
    assert image.dtype == np.uint8
    text_image = Image.fromarray(image)
    draw = ImageDraw.Draw(text_image)
    draw.text((4, 0), text, fill=color)
    return np.array(text_image).astype(np.float32) / 255.
61 |
--------------------------------------------------------------------------------
/bridge_with_digit/widowx_envs/widowx_envs/utils/sync.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Value, Lock
2 |
3 |
class SyncCounter:
    """Integer counter stored in a ``multiprocessing.Value('i', ...)``.

    Every read and update is done while holding ``self._lock``.
    """

    def __init__(self, base_value=0):
        self._lock = Lock()
        self._value = Value('i', base_value)

    @property
    def ret_increment(self):
        """Return the current value and then add one (post-increment), under the lock."""
        with self._lock:
            previous = self._value.value
            self._value.value = previous + 1
        return previous

    @property
    def value(self):
        """Current counter value, read under the lock."""
        with self._lock:
            return self._value.value
21 |
22 |
class ManagedSyncCounter(SyncCounter):
    """SyncCounter whose lock and value come from ``manager`` rather than being
    created directly (``manager.Lock()`` and ``manager.Value('i', base_value)``).
    """

    def __init__(self, manager, base_value=0):
        self._lock = manager.Lock()
        self._value = manager.Value('i', base_value)
26 |
--------------------------------------------------------------------------------
/media/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/media/teaser.jpg
--------------------------------------------------------------------------------
/octo_digit/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | exclude = .git
3 | max-line-length = 88
4 | select = E,F,W,C
5 | ignore=W503,
6 | E203,
7 | E731,
8 | E722,
9 | F841,
10 | E402,
11 | E741,
12 | E501,
13 | C406,
14 |
--------------------------------------------------------------------------------
/octo_digit/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: actions/setup-python@v3
14 | - uses: pre-commit/action@v3.0.0
15 |
--------------------------------------------------------------------------------
/octo_digit/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3.10
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v2.3.0
6 | hooks:
7 | - id: check-yaml
8 | - id: check-ast
9 | - id: check-added-large-files
10 | exclude: ^examples/
11 | - id: check-case-conflict
12 | - id: check-merge-conflict
13 | - id: end-of-file-fixer
14 | - id: trailing-whitespace
15 | - id: detect-private-key
16 | - id: debug-statements
17 | exclude: ^experiments/
18 | - repo: https://github.com/psf/black
19 | rev: 22.10.0
20 | hooks:
21 | - id: black
22 | exclude: ^experiments/
23 | - repo: https://github.com/PyCQA/flake8
24 | rev: 6.1.0
25 | hooks:
26 | - id: flake8
27 | exclude: ^experiments/
28 | - repo: https://github.com/pycqa/isort
29 | rev: 5.12.0
30 | hooks:
31 | - id: isort
32 | exclude: ^experiments/
33 | args: ["--profile", "black", "--src", "octo", "--src", "experiments"]
34 | - repo: https://github.com/srstevenson/nb-clean
35 | rev: 3.1.0
36 | hooks:
37 | - id: nb-clean
38 | args:
39 | - --remove-empty-cells
40 | - --preserve-cell-outputs
41 |
--------------------------------------------------------------------------------
/octo_digit/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/octo_digit/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Robotic AI & Learning Lab Berkeley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/octo_digit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/__init__.py
--------------------------------------------------------------------------------
/octo_digit/docs/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/docs/assets/teaser.jpg
--------------------------------------------------------------------------------
/octo_digit/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/eval/__init__.py
--------------------------------------------------------------------------------
/octo_digit/eval/decode_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
# Modality combinations to generate, each a tuple of modality names.
gen_modes = [("visual",), ("tactile",), ("visual", "tactile")]
# Alternative audio-only setting: gen_modes = [('audio',)]

# Comma-joined form of each combination, e.g. "visual,tactile".
csv_modes = [",".join(modes) for modes in gen_modes]
gen_mode_lang_names = ["all_lang_4"]

# Observation keys associated with each modality.
modality_obs_keys = {
    "visual": ["image_primary", "image_wrist"],
    "tactile": ["image_digit_right", "image_digit_left"],
    "audio": ["mel_spectro"],
}

# Union of all modality-specific observation keys.
modality_specific_keys = {
    key for obs_keys in modality_obs_keys.values() for key in obs_keys
}

includes = ["pad_mask_dict", "task_completed", "timestep", "timestep_pad_mask"]

WINDOW_SIZE = 2
# All-True boolean mask of length WINDOW_SIZE.
pad_mask = np.full(WINDOW_SIZE, True)
26 |
--------------------------------------------------------------------------------
/octo_digit/eval/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/eval/envs/__init__.py
--------------------------------------------------------------------------------
/octo_digit/eval/eval_requirements.txt:
--------------------------------------------------------------------------------
1 | funcsigs
2 | opencv-python
3 | pyquaternion
4 | librosa
5 | edgeml @ git+https://github.com/youliangtan/edgeml.git
6 |
--------------------------------------------------------------------------------
/octo_digit/eval/recursive_dict_print.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
MAX_KEY_LEN = 20
INDENT_SIZE = MAX_KEY_LEN + 4
INDENT = " " * INDENT_SIZE


def recursive_dict_print(dictionary: dict, prefix=""):
    """Print a (possibly nested) dict as an indented tree.

    Keys are truncated to MAX_KEY_LEN characters. A nested dict prints its key
    on its own line and its contents shifted right by INDENT. A leaf prints
    ``shape dtype`` when it has both attributes, otherwise its type.
    """
    for raw_key, val in dictionary.items():
        key = raw_key[:MAX_KEY_LEN]
        if not isinstance(val, dict):
            pad = " " * (INDENT_SIZE - len(key))
            try:
                line = f"{prefix}{key}:{pad}{val.shape} {val.dtype}"
            except AttributeError:
                line = f"{prefix}{key}:{pad} {type(val)}"
            print(line)
            continue
        print(f"{prefix}{key}")
        recursive_dict_print(val, prefix + INDENT)
21 |
--------------------------------------------------------------------------------
/octo_digit/eval/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
# Distribution "eval" installing the top-level package "envs".
setup(packages=["envs"], name="eval")
4 |
--------------------------------------------------------------------------------
/octo_digit/mem_lims.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Sweep batch size x window size with short (--o_steps=2) debug fine-tuning
# runs, presumably to probe memory limits.
# Usage: mem_lims.sh MODE
#   MODE is forwarded as --mode and also names the log file (MODE_log.txt).
if [ "$#" -lt 1 ]; then
    # Without a mode every run would get --mode="" and log to "_log.txt".
    echo "Usage: $0 MODE" >&2
    exit 1
fi
mode="$1"

window_sizes=(8 10 12 13 14 15 16)
batch_sizes=(32 64 128 200 256)
for batch_size in "${batch_sizes[@]}"; do
    for window_size in "${window_sizes[@]}"; do
        python scripts/finetune_josh.py --config=scripts/configs/josh_finetune_config.py:"None" --name=mem_test --o_window_size="${window_size}" --o_batch_size="${batch_size}" --o_steps=2 --debug=True --mode="${mode}" --log_file="${mode}_log.txt"
    done
done
11 |
--------------------------------------------------------------------------------
/octo_digit/octo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/data/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/data/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/data/utils/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/data/utils/goal_relabeling.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
3 | Each function should add entries to the "task" dict.
4 | """
5 |
6 | from typing import Optional
7 |
8 | import tensorflow as tf
9 |
10 | from octo.data.utils.data_utils import tree_merge
11 |
12 |
def uniform(traj: dict, max_goal_distance: Optional[int] = None) -> dict:
    """
    Relabels with a true uniform distribution over future states.
    Optionally caps goal distance.

    For each timestep i, a goal index is drawn uniformly from [i, traj_len),
    or from [i, min(i + max_goal_distance, traj_len)) when max_goal_distance
    is given. The observations at those indices are merged into traj["task"].

    Args:
        traj: trajectory dict with "observation" and "task" entries. The
            trajectory length is the leading dimension of the first leaf of
            traj["observation"].
        max_goal_distance: optional cap on how many steps ahead a goal may lie.

    Returns:
        The same traj dict, with traj["task"] replaced by the merge of the
        previous task and the sampled goal observations.
    """
    traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]

    # select a random future index for each transition i in the range [i, high),
    # where high is traj_len or, if capped, min(i + max_goal_distance, traj_len)
    # NOTE(review): in graph mode (e.g. inside tf.data map) TF may derive the
    # op-level seed of tf.random.uniform from graph state, so reordering the ops
    # in this function could change sampled goals under a fixed global seed.
    rand = tf.random.uniform([traj_len])
    low = tf.cast(tf.range(traj_len), tf.float32)
    if max_goal_distance is not None:
        high = tf.cast(
            tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32
        )
    else:
        high = tf.cast(traj_len, tf.float32)
    # Casting to int32 truncates the non-negative value, i.e. takes the floor.
    goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)

    # sometimes there are floating-point errors that cause an out-of-bounds
    # index, so clamp to the last valid timestep
    goal_idxs = tf.minimum(goal_idxs, traj_len - 1)

    # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from
    # "observation" and "task" properly)
    goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
    traj["task"] = tree_merge(traj["task"], goal)

    return traj
40 |
--------------------------------------------------------------------------------
/octo_digit/octo/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/model/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/model/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/model/components/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/model/components/base.py:
--------------------------------------------------------------------------------
1 | import flax
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | from octo.utils.typing import Sequence
6 |
7 |
@flax.struct.dataclass
class TokenGroup:
    """A group of tokens that have semantic meaning together (e.g. the tokens for a single observation)

    Attributes:
        tokens: jax.Array of shape (..., n_tokens, token_dim)
        mask: jax.Array of shape (..., n_tokens) indicating which tokens are valid (1) vs padding (0)
    """

    tokens: jax.typing.ArrayLike
    mask: jax.typing.ArrayLike

    @classmethod
    def create(
        cls, tokens: jax.typing.ArrayLike, mask: jax.typing.ArrayLike = None, **kwargs
    ):
        """Build a TokenGroup, defaulting to an all-ones (all tokens valid) mask.

        Args:
            tokens: array of shape (..., n_tokens, token_dim).
            mask: optional array of shape (..., n_tokens); ones when omitted.
            **kwargs: forwarded unchanged to the constructor.

        Raises:
            AssertionError: if ``mask`` does not have exactly one fewer dim than ``tokens``.
        """
        if mask is None:
            mask = jnp.ones(tokens.shape[:-1])
        assert mask.ndim == tokens.ndim - 1
        return cls(tokens, mask, **kwargs)

    @classmethod
    def concatenate(cls, group_list: Sequence["TokenGroup"], axis=-2):
        """Concatenate several groups along an axis of ``tokens``.

        Args:
            group_list: groups to join; their shapes must agree off ``axis``.
            axis: axis of ``tokens`` to concatenate along (default: n_tokens).
                It must not address the trailing token_dim axis.

        Returns:
            A new TokenGroup with concatenated tokens and masks.
        """
        data = jnp.concatenate([t.tokens for t in group_list], axis=axis)
        # ``mask`` lacks the trailing token_dim axis. A negative axis counts from
        # the end and therefore shifts by one for the mask. A non-negative axis
        # already names the same dimension in both arrays (the previous
        # unconditional ``axis + 1`` pointed past the mask's rank in that case).
        mask_axis = axis + 1 if axis < 0 else axis
        mask = jnp.concatenate([t.mask for t in group_list], axis=mask_axis)
        return cls(data, mask)
34 |
--------------------------------------------------------------------------------
/octo_digit/octo/model/components/film_conditioning_layer.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 |
class FilmConditioning(nn.Module):
    """Feature-wise linear modulation (FiLM) of a feature map by a conditioning vector."""

    @nn.compact
    def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray):
        """Applies FiLM conditioning to a convolutional feature map.

        Computes ``conv_filters * (1 + scale) + shift``. ``scale`` and ``shift`` are
        zero-initialised dense projections of ``conditioning``, so the layer starts
        out as the identity.

        Args:
            conv_filters: A tensor of shape [batch_size, ..., channels], typically
                [batch_size, height, width, channels].
            conditioning: A tensor of shape [batch_size, conditioning_size].

        Returns:
            A tensor with the same shape as ``conv_filters``.
        """
        channels = conv_filters.shape[-1]
        # Creation order is unchanged (multiplicative projection first -> Dense_0,
        # additive projection second -> Dense_1), so parameter names and values
        # match existing checkpoints.
        projected_cond_mult = nn.Dense(
            features=channels,
            kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.zeros,
        )(conditioning)
        projected_cond_add = nn.Dense(
            features=channels,
            kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.zeros,
        )(conditioning)

        # Insert one singleton axis per spatial dimension of conv_filters.
        # For 4-D inputs this equals ``[:, None, None, :]``. For other ranks it
        # avoids silently broadcasting the batch axis against a spatial axis.
        spatial_axes = tuple(range(1, conv_filters.ndim - 1))
        projected_cond_mult = jnp.expand_dims(projected_cond_mult, spatial_axes)
        projected_cond_add = jnp.expand_dims(projected_cond_add, spatial_axes)

        return conv_filters * (1 + projected_cond_mult) + projected_cond_add
34 |
--------------------------------------------------------------------------------
/octo_digit/octo/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

# setuptools' ``packages`` list does not recurse into sub-packages, so the
# previous ["data", "model", "utils"] omitted e.g. data.utils and
# model.components from regular (non-editable) installs. find_packages is
# restricted to the same three top-level packages plus their sub-packages.
setup(
    name="octo",
    packages=find_packages(
        include=["data", "data.*", "model", "model.*", "utils", "utils.*"]
    ),
)
--------------------------------------------------------------------------------
/octo_digit/octo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/octo_digit/octo/utils/__init__.py
--------------------------------------------------------------------------------
/octo_digit/octo/utils/fuse_constants.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jax
3 |
# Index order of all modality combinations. Entry '' (index 1) is skipped when
# building the loss index maps below.
modality_combination_order = ['simple', '', 'visual', 'tactile', 'audio', 'visual,tactile', 'visual,audio', 'tactile,audio', 'visual,tactile,audio']
_modality_combinations = [('visual',), ('tactile',), ('audio',), ('visual', 'tactile'), ('visual', 'audio'), ('tactile', 'audio'), ('visual', 'tactile', 'audio')]
modality_combinations = [','.join(combination) for combination in _modality_combinations]
# index -> comma-separated combination name, for every non-empty combination.
fuse_loss_modal_indices = {i: combination for i, combination in enumerate(modality_combination_order) if combination != ''}
contrastive_indices = fuse_loss_modal_indices
generative_indices = {k: v for k, v in contrastive_indices.items() if v != 'simple'}
name_to_index_generative = {v: k for k, v in generative_indices.items()}

modality_to_observation_keys = {
    'visual': ['image_primary', 'image_wrist'],
    'tactile': ['image_digit_right', 'image_digit_right_background', 'image_digit_left', 'image_digit_left_background'],
    'audio': ['mic', 'mel_spectro']
}
# Every modality-specific observation key, de-duplicated in declaration order.
# dict.fromkeys keeps first-seen order. Iterating a set of strings does not:
# that order can change between interpreter runs (hash randomization).
_ordered_specific_keys = list(dict.fromkeys(
    obs_key
    for obs_keys in modality_to_observation_keys.values()
    for obs_key in obs_keys
))
modality_specific_keys = set(_ordered_specific_keys)
nonspecific_keys = ['task_completed', 'timestep', 'modality_idx']

# 'simple' enables all modality-specific keys at once. It is built from the
# ordered list so its key order is reproducible.
modality_to_observation_keys['simple'] = list(_ordered_specific_keys)
24 |
25 |
def create_fuse_modal_masks(example_obs):
    """Build one pad-mask dict per FuSe modality combination.

    For each ``index -> "mod_a,mod_b"`` entry of ``fuse_loss_modal_indices``:
    - every key of ``example_obs['pad_mask_dict']`` that belongs to one of those
      modalities, or to ``nonspecific_keys``, is set to ones;
    - every remaining key is set to zeros.
    Shapes and dtypes are copied from the reference masks.

    Args:
        example_obs: observation dict containing a ``'pad_mask_dict'`` entry.

    Returns:
        dict mapping combination index -> {obs_key: mask array}. Each inner dict
        has exactly the keys of ``example_obs['pad_mask_dict']``.
    """
    reference = example_obs['pad_mask_dict']
    masks_by_index = {}
    for index, combination in fuse_loss_modal_indices.items():
        # Keys switched on for this combination, in insertion order.
        enabled = [
            obs_key
            for modality in combination.split(',')
            for obs_key in modality_to_observation_keys[modality]
        ]
        enabled.extend(nonspecific_keys)

        mask = {}
        for obs_key in enabled:
            if obs_key in reference:
                mask[obs_key] = jnp.ones_like(reference[obs_key])
        # Everything not explicitly enabled is masked out.
        for obs_key, template in reference.items():
            if obs_key not in mask:
                mask[obs_key] = jnp.zeros_like(template)

        masks_by_index[index] = mask
        assert masks_by_index[index].keys() == reference.keys()
    return masks_by_index
44 |
45 |
def create_batch(batch, observation_masks, fuse_modal_masks, modality_combination_index: int):
    """Replace a batch's observation pad masks with a FuSe modality mask.

    Args:
        batch: batch dict with an ``'observation'`` dict; it is modified in place.
        observation_masks: the batch's own pad-mask tree, or None. When given, the
            result is the element-wise AND of it and the modality mask (so real
            padding stays masked); both trees must share the same structure.
        fuse_modal_masks: mapping from combination index to a pad-mask tree.
        modality_combination_index: key into ``fuse_modal_masks``.

    Returns:
        The same ``batch`` object, with ``batch['observation']['pad_mask_dict']``
        replaced.
    """
    modal_mask = fuse_modal_masks[modality_combination_index]
    if observation_masks is not None:
        # jax.tree_util.tree_map is the stable API; the top-level jax.tree_map
        # alias is deprecated in newer JAX releases.
        modal_mask = jax.tree_util.tree_map(
            jnp.logical_and, observation_masks, modal_mask
        )
    batch['observation']['pad_mask_dict'] = modal_mask
    return batch
--------------------------------------------------------------------------------
/octo_digit/octo/utils/logging_utils.py:
--------------------------------------------------------------------------------
MAX_KEY_LEN = 15  # keys longer than this are truncated when printed
INDENT_SIZE = MAX_KEY_LEN + 4  # width reserved for a key plus its padding
INDENT = ' ' * INDENT_SIZE  # one nesting level of indentation
HEADING_SEPARATOR = "############################################"
5 |
def print_separator(log_func=print):
    """Emit the HEADING_SEPARATOR line through ``log_func`` (``print`` by default)."""
    separator_line = HEADING_SEPARATOR
    log_func(separator_line)
8 |
9 |
def pretty_print_dict(dictionary, prefix="", log_func=print, pad_with_newlines=True):
    """Log a (possibly nested) dict as aligned ``key:<padding>value`` lines in one call.

    Each key is converted to ``str`` and truncated to MAX_KEY_LEN characters.
    Leaf lines are padded with ``INDENT_SIZE - len(key)`` spaces after the colon
    so values line up. A nested dict prints its key on its own line, and its
    contents are prefixed with one extra INDENT.

    Args:
        dictionary: mapping to print.
        prefix: string prepended to every top-level line.
        log_func: callable that receives the whole rendered block as one string.
        pad_with_newlines: if True, surround the block with blank lines.
    """
    lines_to_output = []

    def _pretty_print(mapping, line_prefix):
        for key, val in mapping.items():
            # str() first: slicing a non-string key (int, tuple, enum) would
            # otherwise raise a TypeError or produce a non-string.
            key = str(key)[:MAX_KEY_LEN]
            if isinstance(val, dict):
                lines_to_output.append(f'{line_prefix}{key}')
                _pretty_print(val, line_prefix + INDENT)
            else:
                indent = ' ' * (INDENT_SIZE - len(key))
                lines_to_output.append(f'{line_prefix}{key}:{indent}{val}')

    _pretty_print(dictionary, prefix)
    if pad_with_newlines:
        lines_to_output = [''] + lines_to_output + ['']
    log_func('\n'.join(lines_to_output))
25 |
26 |
def append_identity_to_metrics(metrics: dict, identity_suffix: str) -> dict:
    """Return a new dict whose keys are ``f"{key}_{identity_suffix}"``; values are unchanged."""
    return {f'{name}_{identity_suffix}': value for name, value in metrics.items()}
--------------------------------------------------------------------------------
/octo_digit/octo/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping, Sequence, Union
2 |
3 | import jax
4 |
# Plain array / random-key aliases.
PRNGKey = jax.Array
JaxArray = jax.typing.ArrayLike
Dtype = jax.typing.DTypeLike
Shape = Sequence[int]

# Recursive tree aliases; the string forward references name the alias itself.
PyTree = Union[jax.typing.ArrayLike, Mapping[str, "PyTree"]]
Config = Union[Any, Mapping[str, "Config"]]

# String-keyed containers of PyTrees.
Params = Mapping[str, PyTree]
Perturbations = Mapping[str, PyTree]
Data = Mapping[str, PyTree]
14 |
--------------------------------------------------------------------------------
/octo_digit/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "octo_digit"
3 | version = "0.0.1"
4 | description = ""
5 | readme = "README.md"
6 | requires-python = "==3.10.*"
7 | dependencies = [
8 | "gym>=0.26",
9 | "numpy==1.24.3",
10 | "ml_dtypes==0.2.0",
11 | "chex==0.1.85",
12 | "optax==0.1.5",
13 | "tensorflow_probability==0.23.0",
14 | "tensorflow==2.15.0",
15 | "jax==0.4.20",
16 | "distrax==0.1.5",
17 | "flax==0.7.5",
18 | "ml_collections>=0.1.0",
19 | "tqdm>=4.60.0",
20 | "absl-py>=0.12.0",
21 | "wandb>=0.12.14",
22 | "einops>=0.6.1",
23 | "imageio>=2.31.1",
24 | "moviepy>=1.0.3",
25 | "pre-commit==3.3.3",
26 | "transformers>=4.34.1",
27 | "tensorflow_hub>=0.14.0",
28 | "tensorflow_text>=2.13.0",
29 | "tensorflow_datasets>=4.9.0",
30 | "tensorflow_graphics==2021.12.3",
31 | "dlimp@git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972",
32 | "plotly>=5.16.1",
33 | "matplotlib",
34 | "scipy==1.12.0",
35 | "funcsigs",
36 | "opencv-python",
37 | "pyquaternion",
38 | "librosa",
39 | "edgeml @ git+https://github.com/youliangtan/edgeml.git",
40 | "octo",
41 | "eval"
42 | ]
43 |
44 | [project.optional-dependencies]
45 | tpu = [
46 | "jax[tpu]==0.4.20",
47 | "libtpu-nightly"
48 | ]
49 | gpu = [
50 | "jax[cuda11_pip]==0.4.20",
51 | ]
52 |
53 | [tool.uv]
54 | find-links = [
55 | "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
56 | "https://storage.googleapis.com/jax-releases/libtpu_releases.html",
57 | ]
58 | override-dependencies = ["scipy==1.12.0"]
59 | prerelease = "allow"
60 | conflicts = [
61 | [
62 | { extra = "tpu" },
63 | { extra = "gpu" },
64 | ],
65 | ]
66 |
67 |
68 | [tool.uv.sources]
69 | octo = { path = "./octo", editable = true }
70 | eval = { path = "./eval", editable = true }
71 |
72 | [tool.black]
73 | # https://github.com/psf/black
74 | line-length = 88
75 | target-version = ["py310"]
76 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
77 |
78 | [tool.isort]
79 | profile = "black"
80 | line_length = 88
81 | force_sort_within_sections = "True"
82 | order_by_type = "False"
83 |
--------------------------------------------------------------------------------
/octo_digit/requirements.txt:
--------------------------------------------------------------------------------
1 | gym >= 0.26
2 | numpy == 1.24.3
3 | ml_dtypes == 0.2.0
4 | chex == 0.1.85
5 | optax == 0.1.5
6 | tensorflow_probability == 0.23.0
7 | tensorflow == 2.15.0
8 | jax == 0.4.20
9 | distrax == 0.1.5
10 | flax == 0.7.5
11 | ml_collections >= 0.1.0
12 | tqdm >= 4.60.0
13 | absl-py >= 0.12.0
14 | scipy >= 1.6.0
15 | wandb >= 0.12.14
16 | einops >= 0.6.1
17 | imageio >= 2.31.1
18 | moviepy >= 1.0.3
19 | pre-commit == 3.3.3
20 | transformers >= 4.34.1
21 | tensorflow_hub >= 0.14.0
22 | tensorflow_text >= 2.13.0
23 | tensorflow_datasets == 4.9.2
24 | tensorflow_graphics == 2021.12.3
25 | dlimp @ git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972
26 | plotly >= 5.16.1
27 | matplotlib
28 | scipy==1.12.0
29 |
--------------------------------------------------------------------------------
/octo_digit/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

# Top-level distribution bundling the `octo` and `eval` packages.
setup(
    name="octo_digit",
    packages=["octo", "eval"],
)
4 |
--------------------------------------------------------------------------------
/palivla_digit/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .venv
3 | notebooks
4 | **/*.mp4
5 | *.ipynb
6 | wandb/
7 | *.egg-info
8 |
9 | /trained_tokenizers
10 | /checkpoints
11 | models
12 |
--------------------------------------------------------------------------------
/palivla_digit/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
--------------------------------------------------------------------------------
/palivla_digit/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | At this time we do not plan to accept non-trivial contributions. The main
4 | purpose of this codebase is to allow the community to reproduce results from our
5 | publications.
6 |
7 | You are however free to start a fork of the project for your purposes as
8 | permitted by the license.
9 |
10 | ## Contributor License Agreement
11 |
12 | Contributions to this project must be accompanied by a Contributor License
13 | Agreement (CLA). You (or your employer) retain the copyright to your
14 | contribution; this simply gives us permission to use and redistribute your
15 | contributions as part of the project. Head over to
16 | <https://cla.developers.google.com/> to see your current agreements on file or
17 | to sign a new one.
18 |
19 | You generally only need to submit a CLA once, so if you've already submitted one
20 | (even if it was for a different project), you probably don't need to do it
21 | again.
22 |
23 | ## Community Guidelines
24 |
25 | This project follows
26 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
27 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/common_fewshot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Most common few-shot eval configuration."""
16 |
17 | import ml_collections as mlc
18 |
19 |
def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
                    runlocal=False, **kw):
  """Returns a standard-ish few-shot linear-regression (`fewshot_lsr`) config.

  Args:
    target_resolution: side length of the central crop fed to the model.
    resize_resolution: side length passed to `resize` before cropping.
    runlocal: if True, only the small `pets` dataset is evaluated.
    **kw: evaluator options; explicit values override the defaults below.

  Returns:
    An `ml_collections.ConfigDict` describing the evaluator.
  """
  defaults = (
      ('representation_layer', 'pre_logits'),
      ('shots', (1, 5, 10, 25)),
      ('l2_reg', 2.0 ** 10),
      ('num_seeds', 3),
      ('prefix', ''),  # No prefix as we already use a/ z/ and zz/
  )
  for option, value in defaults:
    kw.setdefault(option, value)

  # Backward-compatible default: log every 25k steps unless the caller chose
  # any logging cadence explicitly.
  cadence_keys = ('log_steps', 'log_percent', 'log_examples', 'log_epochs')
  if all(key not in kw for key in cadence_keys):
    kw['log_steps'] = 25_000

  config = mlc.ConfigDict(kw)
  config.type = 'fewshot_lsr'
  if runlocal:
    config.datasets = {
        'pets': ('oxford_iiit_pet', 'train', 'test'),
    }
  else:
    config.datasets = {
        'caltech': ('caltech101', 'train', 'test'),  # copybara:srtip
        'cars': ('cars196:2.1.0', 'train', 'test'),
        'cifar100': ('cifar100', 'train', 'test'),
        'dtd': ('dtd', 'train', 'test'),
        # The first 65000 ImageNet samples have at least 30 shots per any class.
        # Commented out by default because needs manual download.
        # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
        'pets': ('oxford_iiit_pet', 'train', 'test'),
        'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
    }

  # Train and eval share the same deterministic preprocessing.
  preprocess = (f'decode|resize({resize_resolution})|'
                f'central_crop({target_resolution})|'
                f'value_range(-1,1)|keep("image", "label")')
  config.pp_train = preprocess
  config.pp_eval = preprocess
  # NOTE: 'imagenet' is only among the datasets above if its entry is enabled.
  config.display_first = [('pets', 10)] if runlocal else [('imagenet', 10)]

  return config
57 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/cappa/README.md:
--------------------------------------------------------------------------------
1 | # Image Captioners Are Scalable Vision Learners Too
2 |
3 | *by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915)
4 |
5 | 
6 |
7 | This directory contains a config for training a CapPa model from scratch.
8 | Note that most models in the paper were trained on a proprietary dataset
9 | (WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/).
10 |
11 | By default, this config trains on COCO captions as this data set is readily
12 | available in [TFDS](https://www.tensorflow.org/datasets) without manual steps.
13 | This is not meant to produce a meaningful model, but
14 | provides a way for the user to run the config out of the box. Please update the
15 | config with a TFDS-wrapped variant of your favorite image/text data set to
16 | train capable models.
17 |
18 | After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows:
19 |
20 | ```
21 | python -m big_vision.trainers.proj.cappa.generative \
22 | --config big_vision/configs/proj/cappa/pretrain.py \
23 | --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
24 | ```
25 |
26 | To run the Cap baseline (autoregressive captioning without parallel prediction),
27 | set `config.model.masked_pred_prob = 0.0`.
28 |
29 | ### Citation
30 | ```
31 | @inproceedings{tschannen2023image,
32 | title={Image Captioners Are Scalable Vision Learners Too},
33 | author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas},
34 | booktitle={Neural Information Processing Systems (NeurIPS)},
35 | year={2023}
36 | }
37 | ```
38 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/cappa/cappa_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/cappa/cappa_architecture.png
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/distill/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge distillation: A good teacher is patient and consistent
2 | *by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov*
3 |
4 | ## Introduction
5 | We publish all teacher models, and configurations for the main experiments of
6 | the paper, as well as training logs and student models.
7 |
8 | Please read the main [big_vision README](/README.md) to learn how to run
9 | configs, and remember that each config file contains an example invocation in
10 | the top-level comment.
11 |
12 | ## Results
13 |
14 | We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing)
15 | of a few runs that we reproduced on Cloud.
16 |
17 | ### ImageNet-1k
18 |
19 | The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our
20 | distillation runs on ImageNet-1k reported in Figures 1 and 5 (left) and the first
21 | row of Table 1.
22 |
23 | We release both student and teacher models:
24 |
25 | | Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
26 | | :--- | :---: | :---: | :---: |
27 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 |
28 | | BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 |
29 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 |
30 | | BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 |
31 |
32 | ### Flowers/Pet/Food/Sun
33 |
34 | The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and
35 | [bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the
36 | distillation runs on these datasets, as shown in Figures 3, 4, 9-12, and Table 4.
37 |
38 | While our open-source release does not currently support doing hyper-parameter
39 | sweeps, we still provide an example of the sweeps at the end of the configs
40 | for reference.
41 |
42 | ### Teacher models
43 | Links to all teacher models we used can be found in [common.py](common.py).
44 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/distill/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Most common teachers for distillation."""
16 |
17 | # pylint: disable=line-too-long
# Teacher name -> public GCS checkpoint path.
inits = dict([  # pylint: disable=duplicate-key Internally, we override some paths for convenience.
    ('BiT-M R152x2 imagenet2012 ic224', 'gs://bit_models/distill/R152x2_T_224.npz'),
    ('BiT-M R152x2 imagenet2012 rc384', 'gs://bit_models/distill/R152x2_T_384.npz'),
    ('BiT-M R152x2 flowers rc128', 'gs://bit_models/distill/R152x2_T_flowers128.npz'),
    ('BiT-M R152x2 pet rc128', 'gs://bit_models/distill/R152x2_T_pet128.npz'),
    ('BiT-M R152x2 food rc128', 'gs://bit_models/distill/R152x2_T_food128.npz'),
    ('BiT-M R152x2 sun rc128', 'gs://bit_models/distill/R152x2_T_sun128.npz'),
])
27 | # pylint: enable=line-too-long
28 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/flexivit/timing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint: disable=line-too-long,missing-function-docstring
16 | r"""A config to run timing for FlexiViT (only inference, no I/O etc.).
17 |
18 | python -m big_vision.tools.eval_only \
19 | --config big_vision/configs/proj/flexivit/timing.py \
20 | --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
21 | --config.total_epochs 90
22 | """
23 |
24 | from ml_collections import ConfigDict
25 |
26 |
def get_config():
  """Timing-only config: FlexiViT-B inference on 240x240x3 inputs."""
  image_shape = (240, 240, 3)

  c = ConfigDict()
  c.batch_size = 8  # swept
  c.init_shapes = [(1,) + image_shape]
  c.representation_layer = 'pre_logits'

  # Creating complete model using all params, the sweep will go over variants.
  c.model_name = 'xp.flexivit.vit'
  c.model = {
      'variant': 'B',
      'pool_type': 'tok',
      'patch_size': (10, 10),  # Like deit@384
      'seqhw': (24, 24),
  }
  c.num_classes = 0

  # A single evaluator that only measures timing of the pre_logits output.
  c.evals = {
      'timing': {
          'type': 'timing',
          'input_shapes': [image_shape],
          'timing': True,
          'pred_kw': {'outputs': ('pre_logits',)},
      },
  }

  return c
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/givt/givt_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/givt/givt_overview.png
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/paligemma/paligemma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/configs/proj/paligemma/paligemma.png
--------------------------------------------------------------------------------
/palivla_digit/big_vision/configs/proj/paligemma/transfers/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Common things across all transfer configs."""
16 |
17 |
TOKENIZER = 'gemma(tokensets=("loc", "seg"))'


def tok(**kw):
  """Creates the `tok(...)` preprocessing op string.

  Single entry point so tokenization is consistent everywhere and easy to
  switch. Keyword arguments are rendered as `name=repr(value)` in call order.
  `model=TOKENIZER` is appended last unless a model is given explicitly.
  """
  options = dict(kw)
  if 'model' not in options:
    options['model'] = TOKENIZER
  rendered = ', '.join(f'{name}={value!r}' for name, value in options.items())
  return f'tok({rendered})'
27 |
28 |
def combine_and_keep_train(text_len, before=(), sep='\n'):
  """Builds the training preprocessing string for prefix/suffix text pairs.

  The steps are:
  - tokenize `prefix` (with BOS), `suffix` (with EOS) and the separator `sep`;
  - concatenate them as [prefix, septok, suffix], with `mask_ar` and
    `mask_loss` set only on the suffix;
  - run `tolen` on text and masks with length `text_len + 1`;
  - keep the image/text/mask features.

  Args:
    text_len: target text length; one extra slot is added because the trainer
      removes EOS.
    before: preprocessing ops placed at the start of the pipeline.
    sep: separator text inserted between prefix and suffix.

  Returns:
    A '|'-joined preprocessing string.
  """
  # For training, we +1 since the trainer removes EOS.
  padded_len = text_len + 1
  steps = list(before)
  steps += [
      tok(key='prefix', bos='yes'),
      tok(key='suffix', eos='yes'),
      tok(key='septok', text=sep),
      # If masks confuse you, see (internal link)
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])',  # pylint: disable=line-too-long
      f'tolen({padded_len}, pad_value=0, key="text")',  # Value doesn't matter.
      f'tolen({padded_len}, pad_value=1, key="mask_ar")',
      f'tolen({padded_len}, pad_value=0, key="mask_loss")',
      'keep("image", "text", "mask_ar", "mask_loss")',
  ]
  return '|'.join(steps)
43 |
44 |
def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'):
  """Builds the eval-time preprocessing string (prefix + separator [+ suffix]).

  Same layout as training, except the suffix defaults to the empty string and
  is tokenized without EOS, so decoding continues from it. Text is built as
  [prefix separator pad] with mask_ar [0 0 1].

  Args:
    text_len: target length handed to the `tolen` ops.
    keep: extra feature names to keep besides image/text/mask_ar/mask_input.
    before: preprocessing ops placed at the start of the pipeline.
    sep: separator text inserted between prefix and suffix.

  Returns:
    A '|'-joined preprocessing string.
  """
  steps = list(before)
  steps.append(tok(key='prefix', bos='yes'))
  steps.append(tok(key='septok', text=sep))
  # At eval time the data may also contain a suffix key. If so it is
  # tokenized without EOS and decoding will continue from it.
  steps.append('setdefault("suffix", "")')
  steps.append(tok(key='suffix', eos='no'))
  # If masks confuse you, see (internal link)
  steps.append('masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])')  # pylint: disable=line-too-long
  steps.append(f'tolen({text_len}, pad_value=0, key="text")')  # value doesn't matter.
  steps.append(f'tolen({text_len}, pad_value=1, key="mask_ar")')
  steps.append(f'tolen({text_len}, pad_value=0, key="mask_input")')
  # And we need to keep everything that makes our evaluator happy.
  kept = ('image', 'text', 'mask_ar', 'mask_input') + tuple(keep)
  steps.append('keep(' + ', '.join(f'"{name}"' for name in kept) + ')')
  return '|'.join(steps)
66 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/evaluators/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/classification.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluator for the classification task."""
16 | # pylint: disable=consider-using-from-import
17 |
18 | import functools
19 |
20 | from big_vision.evaluators import common
21 | import big_vision.utils as u
22 | import jax
23 | import jax.numpy as jnp
24 |
25 |
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = 'jit'  # evaluator API flavour flag
29 |
30 |
31 | # To avoid re-compiling the function for every new instance of the same
32 | # evaluator on a different dataset!
@functools.cache
def get_eval_fn(predict_fn, loss_name):
  """Returns a jitted per-batch classification eval function.

  Cached on `(predict_fn, loss_name)` so several evaluator instances (e.g. the
  same evaluator on different datasets) reuse one function instead of
  re-compiling it.

  The returned function maps `(train_state, batch, labels, mask)` to
  `(ncorrect, summed_loss, nseen)`. Each example is weighted by
  `mask * labels.max(axis=1)`, so rows whose labels are all zero are ignored.
  The loss is `big_vision.utils.<loss_name>`, called with `reduction=False`.
  """
  @jax.jit
  def _eval_fn(train_state, batch, labels, mask):
    logits, *_ = predict_fn(train_state, batch)

    # Ignore the entries with all zero labels for evaluation.
    weights = mask * labels.max(axis=1)

    per_example_loss = getattr(u, loss_name)(
        logits=logits, labels=labels, reduction=False)

    # Label value at the arg-max logit of each example.
    predicted = jnp.argmax(logits, axis=1)
    top1_hit = jnp.take_along_axis(labels, predicted[:, None], axis=1)[:, 0]

    return (jnp.sum(top1_hit * weights),
            jnp.sum(per_example_loss * weights),
            jnp.sum(weights))

  return _eval_fn
55 |
56 |
class Evaluator:
  """Classification evaluator reporting top-1 precision and mean loss."""

  def __init__(self, predict_fn, loss_name, label_key='labels', **kw):
    self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
    self.eval_fn = get_eval_fn(predict_fn, loss_name)
    self.label_key = label_key

  def run(self, train_state):
    """Evaluates up to `self.steps` batches and yields (name, value) pairs.

    Yields:
      ('prec@1', correct / seen) followed by ('loss', total_loss / seen), where
      all three totals are the masked sums returned by the eval function.
    """
    correct_total, loss_total, seen_total = 0, 0, 0
    batches = zip(range(self.steps), self.get_data_iter())
    for _, batch in batches:
      labels = batch.pop(self.label_key)
      mask = batch.pop('_mask')
      outputs = self.eval_fn(train_state, batch, labels, mask)
      batch_correct, batch_loss, batch_seen = jax.device_get(outputs)
      correct_total += batch_correct
      loss_total += batch_loss
      seen_total += batch_seen
    yield ('prec@1', correct_total / seen_total)
    yield ('loss', loss_total / seen_total)
77 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/cappa/perplexity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluator for perplexity of a model."""
16 | from big_vision.evaluators import mean
17 | import big_vision.utils as u
18 | import jax.numpy as jnp
19 |
20 |
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): the stated removal date has passed; presumably this flag is
# read by the evaluator framework to select the jit-based evaluator API —
# confirm with callers before deleting.
API = 'jit'
24 |
25 |
def perplexity(predict_fn, normalize_by_seqlen):
  """Builds a metric function returning softmax cross-entropy losses.

  Args:
    predict_fn: called as `predict_fn(train_state, batch, **kw)`; its output is
      unpacked as `(logits, _)`.
    normalize_by_seqlen: forwarded as `normalize` to
      `u.weighted_softmax_xent`.

  Returns:
    A function `(train_state, batch, pad_token=0, **kw)` returning
    `{'perplexity': losses}`, where `losses` are the unreduced
    (`reduction=False`) losses from `u.weighted_softmax_xent`.
  """

  def _perplexity_fn(train_state, batch, pad_token=0, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)
    labels = batch['labels']

    # Positions holding `pad_token` get zero weight.
    is_token = labels != pad_token
    weights = jnp.where(is_token, 1, 0).astype(jnp.float32)

    label_masks = batch.get('label_masks')
    if label_masks is not None:
      weights = weights * label_masks

    unreduced = u.weighted_softmax_xent(
        logits=logits,
        labels=labels,
        weights=weights,
        label_smoothing=0.0,
        reduction=False,
        normalize=normalize_by_seqlen,
    )
    return {'perplexity': unreduced}

  return _perplexity_fn
44 |
45 |
class Evaluator(mean.Evaluator):
  """Perplexity evaluator built on top of `mean.Evaluator`."""

  def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
    metric_fn = perplexity(predict_fn, normalize_by_seqlen)
    super().__init__(metric_fn, *a, **kw)
51 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/cappa/scoring_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Scoring classifier.
16 |
17 | This one is based on a generative perspective for image classification.
18 | Here we input the image as well as all the tokenized labels to compute their
19 | perplexity and select the one with minimum loss as the prediction.
20 | """
21 | import functools
22 | from big_vision.datasets.imagenet import class_names as imagenet_class_names
23 | from big_vision.evaluators import mean
24 | from big_vision.pp import builder as pp_builder
25 | import jax.numpy as jnp
26 | import numpy as np
27 |
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): the stated removal date has passed; presumably this flag is
# read by the evaluator framework to select the jit-based evaluator API —
# confirm with callers before deleting.
API = "jit"
31 |
32 |
# Dataset name -> class names that `get_classes` tokenizes and scores.
CLASS_NAMES = dict(
    imagenet2012=imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
)
36 |
37 |
# Module-level (rather than a method) so the cached result is shared by all
# evaluator instances.
@functools.lru_cache(maxsize=None)
def get_classes(dataset_name, pp_txt):
  """Tokenizes every class name of `dataset_name` with the `pp_txt` pipeline.

  Returns:
    A numpy array built from the "labels" output of the preprocessing function
    applied to `{"label": name}` for each name in `CLASS_NAMES[dataset_name]`.
  """
  preprocess = pp_builder.get_preprocess_fn(pp_txt, log_data=False)
  class_names = CLASS_NAMES[dataset_name]
  tokenized = [preprocess({"label": class_name})["labels"]
               for class_name in class_names]
  return np.array(tokenized)
45 |
46 |
def scoring(predict_fn, tokenized_labels):
  """Wraps `predict_fn` into a metric that picks the best-scoring class.

  The returned function adds `tokenized_labels` to the batch under
  "_label_tokens" (a key already present in the batch wins), takes the argmax
  of `predict_fn`'s scores over the last axis and compares it elementwise with
  `batch["label"]`, returning `{"prec@1": <bool array>}`.
  """

  def _scoring_fn(train_state, batch, *a, **kw):
    augmented = dict(_label_tokens=tokenized_labels)
    augmented.update(batch)
    class_scores = predict_fn(train_state, augmented, *a, **kw)
    best_class = jnp.argmax(class_scores, axis=-1)
    return {"prec@1": best_class == augmented["label"]}

  return _scoring_fn
56 |
57 |
class Evaluator(mean.Evaluator):
  """Classification accuracy evaluator that scores every class name."""

  def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw):
    class_tokens = get_classes(data["name"], pp_txt)
    metric_fn = scoring(predict_fn, class_tokens)
    super().__init__(metric_fn, data, pp_fn, *a, **kw)
64 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/image_text/prompt_engineering_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for prompt_engineering."""
16 |
17 | from absl.testing import absltest
18 | from big_vision.evaluators.proj.image_text import prompt_engineering
19 |
20 |
class PromptEngineeringTest(absltest.TestCase):

  def test_canonicalize_text(self):
    """Checks canonicalize_text on a table of (input, kwargs, expected)."""
    canonicalize = prompt_engineering.canonicalize_text
    keep_braces = {"keep_punctuation_exact_string": "{}"}
    cases = [
        ("test_test", {}, "test test"),
        ("test___test", {}, "test test"),
        ("test", {}, "test"),
        ("test.", {}, "test"),
        (" test ", {}, "test"),
        ("test\ntest", {}, "test test"),
        ("test test", {}, "test test"),
        ("test {}", {}, "test"),
        ("test {}", keep_braces, "test {}"),
        (" test {}...", keep_braces, "test {}"),
        ("test {} {} {}", keep_braces, "test {} {} {}"),
    ]
    for text, kwargs, expected in cases:
      self.assertEqual(canonicalize(text, **kwargs), expected)


if __name__ == "__main__":
  absltest.main()
49 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/paligemma/perplexity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluator for perplexity of a model."""
16 | import functools
17 |
18 | from big_vision.evaluators import mean
19 | import big_vision.utils as u
20 | import jax.numpy as jnp
21 |
22 |
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): the stated removal date has passed; presumably this flag is
# read by the evaluator framework to select the jit-based evaluator API —
# confirm with callers before deleting.
API = 'jit'
26 |
27 |
# Cached so repeated evaluator instances get the same function object, which
# avoids recompiling it inside the mean evaluator.
@functools.cache
def perplexity(
    predict_fn, key='labels', shift_labels=True):
  """Builds a metric returning summed and length-averaged cross-entropy.

  Args:
    predict_fn: called as `predict_fn(train_state, batch, **kw)`; its output is
      unpacked as `(logits, _)`.
    key: batch key holding the target token ids.
    shift_labels: if True, drop the first position of targets and weights.

  Returns:
    A function returning `{'sum': losses, 'avg': losses / token_count}`, where
    `token_count` is the per-row sum of the weights (`batch['mask_loss']`,
    or all ones when absent), clipped below at 2e-38.
  """

  def _perplexity_fn(train_state, batch, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)

    targets = batch[key]
    weights = batch.get('mask_loss', jnp.ones_like(targets))

    if shift_labels:
      # Presumably aligns targets with next-token logits — confirm with the
      # predict_fn in use.
      targets, weights = targets[:, 1:], weights[:, 1:]

    summed = u.weighted_softmax_xent(
        logits=logits, labels=targets, weights=weights,
        reduction=False, normalize=False)
    # Lower bound prevents division by zero for rows with no weighted tokens.
    token_count = jnp.clip(weights.sum(axis=1), 2e-38)

    return {'sum': summed, 'avg': summed / token_count}

  return _perplexity_fn
51 |
52 |
class Evaluator(mean.Evaluator):
  """Perplexity evaluator reporting the 'sum' and 'avg' losses."""

  def __init__(self, predict_fn, *a, key='labels', shift_labels=False, **kw):
    # Default to no prefetching to save memory, unless the caller overrides it.
    kw.setdefault('prefetch', 0)
    metric_fn = perplexity(predict_fn, key, shift_labels)
    super().__init__(metric_fn, *a, **kw)
59 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/paligemma/transfers/storepreds.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluator to run inference and store results."""
16 | import functools
17 |
18 | import big_vision.evaluators.common as c
19 | import big_vision.input_pipeline
20 | import big_vision.pp.builder
21 | import big_vision.pp.tokenizer
22 | import big_vision.utils as u
23 |
24 | import jax
25 |
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): the stated removal date has passed; presumably this flag is
# read by the evaluator framework to select the jit-based evaluator API —
# confirm with callers before deleting.
API = "jit"
29 |
30 |
class Evaluator:
  """Decodes predictions with `predict_fn` and writes them to a JSON file."""

  def __init__(
      self, predict_fn, tokenizer=None,
      preds_outfile="{workdir}/{name}_{split}_preds.json",
      annot_outfile="{workdir}/{name}_{split}_annotations.json",
      id_key="id",
      *, data, devices, **kw
  ):
    self.id_key = id_key
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={id_key}, data=data, devices=devices, **kw)

    dataset_name = data.get("name")
    dataset_split = data.get("split", "")
    self.preds_outfile = c.resolve_outfile(
        preds_outfile, name=dataset_name, split=dataset_split)
    # NOTE(review): resolved here but never written by this class — confirm
    # whether annotations are expected to be dumped somewhere.
    self.annot_outfile = c.resolve_outfile(
        annot_outfile, name=dataset_name, split=dataset_split)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Decodes the eval data, writes predictions and yields their count.

    Only process 0 yields ("num_examples", n); other processes yield nothing.
    """
    predictions = []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      decoded = self.decode(train_state, batch)

      # Keep this host's slice, then only the rows whose `_mask` is set.
      local_tokens = u.get_local_slice_from_fsarray(decoded)
      local_mask = u.get_local_slice_from_fsarray(batch["_mask"])

      example_ids = batch[self.id_key][local_mask]
      captions = self.tok.to_str(local_tokens[local_mask])

      predictions.extend(
          {self.id_key: str(example_id), "caption": text}
          for example_id, text in zip(example_ids, captions))

    predictions = c.multiprocess_write_json(self.preds_outfile, predictions)

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "num_examples", len(predictions)
78 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/evaluators/proj/uvim/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Common utilities used in evaluators."""
16 | import math
17 | import jax
18 | import tensorflow as tf
19 | import tensorflow_datasets as tfds
20 |
21 |
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns dataset to be processed by current jax host.

  The dataset is sharded and padded with zeros such that all processes
  have equal number of batches. The first 2 dimensions of the dataset
  elements are: [local_device_count, device_batch_size].

  Args:
    dataset: dataset name.
    split: dataset split.
    global_batch_size: batch size to be processed per iteration on the
      dataset; must be divisible by `jax.device_count()`.
    pp_fn: preprocessing function to apply per example.
    dataset_dir: path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.

  Returns:
    A `tf.data.Dataset` with exactly `ceil(num_examples / global_batch_size)`
    elements for this process. Slots beyond this process's shard are filled
    with all-zero examples.

  Raises:
    AssertionError: if `global_batch_size` is not divisible by the number of
      devices.
    ValueError: if the number of examples in the full split is not known.
  """
  num_devices = jax.device_count()
  assert global_batch_size % num_devices == 0, (
      f"global_batch_size={global_batch_size} is not divisible by "
      f"jax.device_count()={num_devices}")

  # Size of the whole split (across all processes); fixes the batch count so
  # every process runs the same number of steps.
  total_examples = int(tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality())
  if total_examples < 0:
    # tf.data reports INFINITE (-1) / UNKNOWN (-2) cardinality as negatives;
    # the take() below could then yield an empty or unbounded padded dataset.
    raise ValueError(
        f"Cannot determine the number of examples in {dataset}:{split} "
        f"(cardinality={total_examples}).")
  num_batches = math.ceil(total_examples / global_batch_size)

  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)
  # Endless stream of all-zero examples matching the preprocessed structure.
  # (`jax.tree_map` is deprecated in favor of `jax.tree_util.tree_map`.)
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_util.tree_map(
          lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  data = data.batch(global_batch_size // num_devices)
  data = data.batch(jax.local_device_count())
  data = data.take(num_batches)
  if cache:
    # Eval datasets are often used many times and caching the dataset after
    # batching allows one to have the buffers ready to be used and not have
    # to wait for preprocessing to be done over and over.
    data = data.cache()
  return data
65 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/models/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/ppp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/models/ppp/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/flaxformer/bert_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for bert."""
16 |
17 | import tempfile
18 |
19 | from big_vision import input_pipeline
20 | from big_vision.models.proj.flaxformer import bert, bert_test_util
21 | import big_vision.pp.builder as pp_builder
22 | import big_vision.pp.ops_general # pylint: disable=unused-import
23 | import big_vision.pp.proj.flaxformer.bert_ops # pylint: disable=unused-import
24 | import flax
25 | import jax
26 | import jax.numpy as jnp
27 | import tensorflow as tf
28 |
29 | # BERT vocabulary for testing.
30 | _BERT_VOCAB = [
31 | "[PAD]",
32 | "[UNK]",
33 | "this",
34 | "is",
35 | "a",
36 | "test",
37 | "[CLS]",
38 | "[SEP]",
39 | ]
40 | _TOKEN_LEN = 16
41 |
42 |
class BertTest(tf.test.TestCase):

  def test_load_apply(self):
    """Loads a base BERT checkpoint into a fresh model and checks shapes."""
    inkey = "text"
    vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
    with open(vocab_path, "w") as f:
      f.write("\n".join(_BERT_VOCAB))
    ds2, _ = input_pipeline.make_for_inference(
        tf.data.Dataset.from_tensor_slices(
            {inkey: tf.ragged.constant([["this is a test"]])}
        ),
        num_ex_per_process=[1],
        preprocess_fn=pp_builder.get_preprocess_fn(
            f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', "
            f"max_len={_TOKEN_LEN})"
            "|keep('labels')"
        ),
        batch_size=1,
    )
    text = jnp.array(next(iter(ds2))["labels"])
    model = bert.Model(config="base")
    variables = model.init(jax.random.PRNGKey(0), text)
    params = bert.load(
        flax.core.unfreeze(variables)["params"],
        bert_test_util.create_base_checkpoint(),
    )
    x, out = model.apply({"params": params}, text)
    # `jax.tree_map` is deprecated in favor of `jax.tree_util.tree_map`.
    self.assertAllEqual(jax.tree_util.tree_map(jnp.shape, x), (1, 768))
    self.assertAllEqual(
        jax.tree_util.tree_map(jnp.shape, out),
        {
            # Sequence length follows the tokenizer's max_len.
            "transformed": (1, _TOKEN_LEN, 768),
            "pre_logits": (1, 768),
        },
    )


if __name__ == "__main__":
  tf.test.main()
81 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/givt/adaptor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the IRevNet adaptor."""
16 |
17 | from absl.testing import absltest
18 | from big_vision.models.proj.givt import adaptor
19 | import jax
20 | from jax import random
21 | import jax.numpy as jnp
22 |
23 |
class AdaptorTest(absltest.TestCase):
  """Tests for `adaptor.IRevNet`."""

  def test_inversion(self):
    """`inverse` must undo `forward` up to numerical tolerance."""
    num_channels = 8
    input_shape = (1, 24, 24, num_channels)

    rng = random.PRNGKey(758493)
    _, inp_rng, init_rng, data_rng = random.split(rng, 4)

    # Separate inputs for initialization and for the round-trip check.
    dummy_x = random.normal(inp_rng, shape=input_shape)
    real_x = random.normal(data_rng, shape=input_shape)

    model = adaptor.IRevNet(
        num_blocks=4,
        num_channels=num_channels,
        dropout_rate=0.0,
    )
    params = model.init(init_rng, dummy_x)

    real_y = model.apply(params, real_x, method=model.forward)
    real_x_ = model.apply(params, real_y, method=model.inverse)
    self.assertTrue(jnp.allclose(real_x, real_x_, atol=1e-5))


if __name__ == "__main__":
  # `googletest` was never imported in this module; use the imported absltest.
  absltest.main()
49 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/uvim/vit_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for vit vqvae model."""
16 | from absl.testing import absltest
17 | from big_vision.models.proj.uvim import vit
18 | import jax
19 | import jax.numpy as jnp
20 | import ml_collections
21 |
22 |
class ViTVQVAEModelTest(absltest.TestCase):

  def test_model(self):
    """Builds a tiny ViT VQ-VAE and checks variable collections and logits."""
    config = ml_collections.ConfigDict(dict(
        input_size=(32, 32),
        code_len=4,
        width=16,
        mlp_dim=64,
        num_heads=4,
        enc_depth=1,
        dec_depth=1,
        with_encoder_ctx=True,
        with_decoder_ctx=True,
        statistics_axis_name=None,
        inputs={"in1": (10, 3), "in2": (25,)},
        outputs={"out1": (5,), "out2": (20,)},
    ))

    model = vit.Model(**config)
    batch = 4
    seq_len = (32 // 8) ** 2
    inputs = dict(
        in1=jnp.zeros((batch, seq_len, 10, 3)),
        in2=jnp.zeros((batch, seq_len, 25)),
    )
    ctx_image = jnp.zeros((batch,) + config.input_size + (3,))

    variables = model.init(
        {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
        inputs,
        ctx=ctx_image,
    )
    self.assertEqual(variables.keys(), {"params", "state"})

    (logits, _), variables = model.apply(
        variables,
        inputs,
        ctx=ctx_image,
        train=True,
        update_dict=True,
        rngs={"dropout": jax.random.PRNGKey(0),
              "vqvae": jax.random.PRNGKey(0)},
        mutable=["state"],
    )
    self.assertEqual(logits.keys(), {"out1", "out2"})
    self.assertEqual(logits["out1"].shape, (batch, seq_len, 5))
    self.assertEqual(logits["out2"].shape, (batch, seq_len, 20))


if __name__ == "__main__":
  absltest.main()
83 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/models/proj/uvim/vtt_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for vision-text-transformer."""
16 | from absl.testing import absltest
17 | from big_vision.models.proj.uvim import vtt
18 | import jax
19 | import jax.numpy as jnp
20 | import ml_collections
21 |
22 |
class VTTTest(absltest.TestCase):

  def test_vtt_with_1_step(self):
    """Runs one forward pass of a tiny VTT and checks the output shape."""
    config = ml_collections.ConfigDict({
        "input_size": (224, 224),
        "patches": {"size": (16, 16)},
        "num_heads": 2,
        "num_layers": 2,
        "mlp_dim": 128,
        "emb_dim": 64,
        "vocab_size": 500,
    })
    batch, max_len = 8, 50
    images = jnp.ones((batch, 224, 224, 3))
    tokens = jnp.ones((batch, max_len), dtype=jnp.int32)

    model = vtt.Model(**config)
    variables = model.init(jax.random.PRNGKey(42), images, tokens)
    self.assertCountEqual(variables.keys(), ["params"])

    logits = model.apply({"params": variables["params"]}, images, tokens)
    self.assertEqual(logits.shape, (batch, max_len, config.vocab_size))


if __name__ == "__main__":
  absltest.main()
52 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/pp/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/archive/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/big_vision/pp/archive/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/archive/randaug.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """RandAug depends on deprecated tfa.image package, now defunct."""
16 |
17 | from big_vision.pp import registry
18 | from big_vision.pp import utils
19 | from big_vision.pp.archive import autoaugment
20 |
21 |
@registry.Registry.register("preprocess_ops.randaug")
@utils.InKeyOutKey()
def get_randaug(num_layers: int = 2, magnitude: int = 10):
  """Builds a RandAugment op (https://arxiv.org/abs/1909.13719).

  Args:
    num_layers: number of augmentation transformations applied one after
      another to the image ("N" in the paper); values in [1, 3] usually work
      best.
    magnitude: magnitude shared by all augmentation operations ("M" in the
      paper); values in [5, 30] usually work best.

  Returns:
    A function that applies
    `autoaugment.distort_image_with_randaugment` to an image.
  """

  def _distort(image):
    return autoaugment.distort_image_with_randaugment(
        image, num_layers, magnitude)

  return _distort
47 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for builder."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from big_vision.pp import builder
22 | from big_vision.pp import ops_general # pylint: disable=unused-import
23 | from big_vision.pp import ops_image # pylint: disable=unused-import
24 | import numpy as np
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class BuilderTest(tf.test.TestCase):

  def testSingle(self):
    pp_fn = builder.get_preprocess_fn("resize(256)")
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (256, 256, 3))

  def testEmpty(self):
    # Empty segments around "|" separators are expected to be ignored.
    pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||")

    # Typical image input
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (256, 256, 3))

  def testPreprocessingPipeline(self):
    pp_str = ("inception_crop|resize(256)|resize((256, 256))|"
              "central_crop((80, 120))|flip_lr|value_range(0,1)|"
              "value_range(-1,1)")
    pp_fn = builder.get_preprocess_fn(pp_str)

    # Typical image input
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (80, 120, 3))
    self.assertLessEqual(np.max(image.numpy()), 1)
    self.assertGreaterEqual(np.min(image.numpy()), -1)

  def testNumArgsException(self):

    x = np.random.randint(0, 256, [640, 480, 3])
    # Every entry needs a trailing comma: adjacent string literals are
    # concatenated implicitly and would silently merge two cases.
    for pp_str in [
        "inception_crop(1)",
        "resize()",
        "resize(1, 1, 1)",
        "flip_lr(1)",
        "central_crop()",
    ]:
      with self.subTest(pp_str=pp_str):
        # NOTE(review): unlike the other tests, `x` is passed directly instead
        # of as {"image": x}, so the call may raise for reasons unrelated to
        # the argument count — confirm this is intended.
        with self.assertRaises(BaseException):
          builder.get_preprocess_fn(pp_str)(x)


if __name__ == "__main__":
  tf.test.main()
73 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/clippo/download_unifont.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Downloads the Unifont 9.0.06 .hex font files, decompresses them and moves
# them into big_vision/pp/proj/clippo/.
#
# This is intended to be run from the big_vision repository root:
#
#   bash big_vision/pp/proj/clippo/download_unifont.sh

# Abort on the first failing command (e.g. a failed download) instead of
# continuing into gunzip/mv with missing files.
set -euo pipefail

readonly VERSION="9.0.06"
readonly BASE_URL="https://unifoundry.com/pub/unifont/unifont-${VERSION}/font-builds"
readonly FONT_FILES=("unifont-${VERSION}.hex" "unifont_upper-${VERSION}.hex")
readonly DEST_DIR="big_vision/pp/proj/clippo/"

for font in "${FONT_FILES[@]}"; do
  wget "${BASE_URL}/${font}.gz"
done
for font in "${FONT_FILES[@]}"; do
  gunzip "${font}.gz"
done
mv "${FONT_FILES[@]}" "${DEST_DIR}"
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/flaxformer/bert_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for bert_ops."""
16 |
17 | import tempfile
18 |
19 | from big_vision import input_pipeline
20 | import big_vision.pp.builder as pp_builder
21 | import big_vision.pp.ops_general # pylint: disable=unused-import
22 | from big_vision.pp.proj.flaxformer import bert_ops # pylint: disable=unused-import
23 | import tensorflow as tf
24 |
25 |
26 | # BERT vocabulary for testing.
27 | _BERT_VOCAB = [
28 | "[PAD]",
29 | "[UNK]",
30 | "more",
31 | "than",
32 | "one",
33 | "[CLS]",
34 | "[SEP]",
35 | ]
36 |
37 |
def _create_ds(pp_str, tensor_slices, num_examples):
  """Builds an inference pipeline over in-memory tensors.

  Args:
    pp_str: preprocessing spec string passed to the pp builder.
    tensor_slices: structure of tensors to slice into individual examples.
    num_examples: number of examples; also used as the batch size.

  Returns:
    The first element of `input_pipeline.make_for_inference`'s result.
  """
  source = tf.data.Dataset.from_tensor_slices(tensor_slices)
  preprocess = pp_builder.get_preprocess_fn(pp_str)
  outputs = input_pipeline.make_for_inference(
      source,
      num_ex_per_process=[num_examples],
      preprocess_fn=preprocess,
      batch_size=num_examples,
  )
  return outputs[0]
45 |
46 |
class BertOpsTest(tf.test.TestCase):
  """Tests for the `bert_tokenize` preprocessing op."""

  def test_tokenize(self):
    inkey = "texts"
    tmp_dir = tempfile.mkdtemp()
    vocab_path = f"{tmp_dir}/vocab.txt"
    with open(vocab_path, "w") as vocab_file:
      vocab_file.write("\n".join(_BERT_VOCAB))

    tokenize_op = (
        f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', max_len=5)"
    )
    pp_str = tokenize_op + "|keep('labels')"
    texts = tf.ragged.constant([["one more"], ["more than one"], [""]])
    ds = _create_ds(pp_str, {inkey: texts}, 3)

    # Ids are positions in _BERT_VOCAB: [CLS]=5, "more"=2, "than"=3,
    # "one"=4, and [PAD]=0 fills the remainder up to max_len.
    expected = [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]]
    first_batch = next(iter(ds))
    self.assertAllEqual(first_batch["labels"], expected)


if __name__ == "__main__":
  tf.test.main()
70 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/givt/pp_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """GIVT-specific preprocessing ops."""
16 |
17 | from big_vision.pp import registry
18 | from big_vision.pp import utils
19 | import tensorflow as tf
20 |
21 |
@registry.Registry.register("preprocess_ops.bin_nyu_depth")
@utils.InKeyOutKey(indefault="labels", outdefault="labels")
def get_bin_nyu_depth(min_depth=0.001, max_depth=10.0, num_bins=256):
  """Binning of NYU depth for UViM in preprocessing rather than model.

  Depths are mapped linearly from [min_depth, max_depth] onto [0, num_bins),
  floored to int32 bin indices and clamped into [0, num_bins - 1].
  """
  depth_span = max_depth - min_depth

  def _bin_depth(labels):  # pylint: disable=missing-docstring
    scaled = (labels - min_depth) / depth_span * num_bins
    bins = tf.cast(tf.floor(scaled), tf.int32)
    # Clamp out-of-range depths into the first/last bin.
    return tf.maximum(tf.minimum(bins, num_bins - 1), 0)

  return _bin_depth
36 |
37 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/paligemma/robustness.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """pp ops."""
16 |
17 | import math
18 |
19 | from big_vision.pp import utils
20 | from big_vision.pp.registry import Registry
21 | import tensorflow as tf
22 |
23 |
@Registry.register("preprocess_ops.resize_r")
@utils.InKeyOutKey()
def get_resize_r(size):
  """Like the standard `resize` op, but with randomized resize parameters.

  On every call, one (method, antialias) combination is drawn uniformly at
  random from a fixed set of variants and used to resize the image.

  Args:
    size: target size; a scalar is repeated for both spatial dimensions.

  Returns:
    A function mapping an image tensor to the resized image, cast back to the
    input dtype.
  """
  size = utils.maybe_repeat(size, 2)

  # TF does not accept symbolic `method`/`antialias` arguments, so every
  # variant is pre-built as its own closure and picked with tf.switch_case.
  variants = []
  for method in ["bilinear", "bicubic", "lanczos3", "area", "mitchellcubic"]:
    for antialias in [True, False]:
      variants.append(
          lambda img, m=method, a=antialias: tf.image.resize(
              img, size, method=m, antialias=a))

  def _resize_r(image):
    """Resizes `image` to `size` using a randomly chosen variant."""
    orig_dtype = image.dtype
    limits = tf.type_spec_from_value(image).dtype
    choice = tf.random.uniform((), 0, len(variants), tf.int32)
    branches = [lambda v=v: v(image) for v in variants]
    resized = tf.switch_case(choice, branches)
    # Clip into the input dtype's range before casting back so that overshoot
    # from e.g. bicubic/lanczos filtering does not wrap around.
    clipped = tf.clip_by_value(resized, limits.min, limits.max)
    return tf.cast(clipped, orig_dtype)

  return _resize_r
47 |
48 |
@Registry.register("preprocess_ops.random_jpeg")
@utils.InKeyOutKey()
def get_random_jpeg(p):
  """With probability `p`, randomly encode-decode the image as JPEG.

  Args:
    p: probability in [0, 1] of re-encoding. When re-encoding, one of the two
      DCT methods below is picked uniformly at random and the JPEG quality is
      drawn uniformly from [75, 95].

  Returns:
    A function mapping an image tensor to the (possibly) JPEG-degraded image.

  Raises:
    ValueError: if `p` is not in [0, 1].
  """
  if not 0 <= p <= 1:
    raise ValueError(f"random_jpeg: `p` must be in [0, 1], got {p}.")

  fns = [
      lambda x: tf.image.adjust_jpeg_quality(
          x, dct_method="INTEGER_FAST",
          jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
      ),
      lambda x: tf.image.adjust_jpeg_quality(
          x, dct_method="INTEGER_ACCURATE",
          jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
      ),
  ]

  # Branch 0 is the identity; the JPEG variants share probability `p` equally.
  # math.log(0) raises, so a zero-probability branch (the identity when
  # p == 1) gets a -inf logit instead, i.e. it is never sampled.
  probs = [1 - p] + [p / len(fns)] * len(fns)
  logits = [math.log(prob) if prob > 0 else -math.inf for prob in probs]

  def _random_jpeg(image):
    """Randomly JPEG encode-decodes `image`; identity with prob. `1 - p`."""
    if p == 0:
      # The identity is the only possible outcome; skip sampling entirely.
      return image
    funcs = [lambda: image] + [lambda fn=fn: fn(image) for fn in fns]
    fn_idx = tf.random.categorical([logits], 1, dtype=tf.int32)[0, 0]
    return tf.switch_case(fn_idx, funcs)

  return _random_jpeg
73 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/paligemma/sciqa_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """pp ops."""
16 |
17 | from big_vision.pp.registry import Registry
18 | import tensorflow as tf
19 |
20 |
@Registry.register('preprocess_ops.sci_qa_choices_shuffle')
def sci_qa_choices_shuffle(
    choice_str_inkey='choices',
    ans_inkey='answer',
    indexed_choices_outkey='indexed_choices',
    indexed_answer_outkey='indexed_answer',
):
  """Randomly shuffles the ScienceQA answer choices on the fly.

  Args:
    choice_str_inkey: key of the original 1-D list of choices, e.g.
      ['apple', 'banana', ...].
    ans_inkey: key of the original (0-based) answer index, e.g. 1.
    indexed_choices_outkey: output key for the shuffled choices joined into a
      single string with letter prefixes, e.g. "(A) banana, (B) apple".
    indexed_answer_outkey: output key for the letter of the correct answer
      after shuffling, e.g. 1 (original) -> 0 (shuffled) -> 'A'.

  Returns:
    A function that takes a data dict, writes the two output keys into it and
    returns it. At most 26 choices are supported (one per letter A-Z).
  """
  alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

  def _template(data):
    abc_tensor = tf.constant([f'({a})' for a in alphabet])
    abcans_tensor = tf.constant(list(alphabet))
    choices = data[choice_str_inkey]
    # tf.shape works for dynamically-shaped tensors in graph mode (e.g. inside
    # tf.data map functions), where Python len() is not guaranteed to.
    indices = tf.range(tf.shape(choices)[0])
    shuffled_indices = tf.random.shuffle(indices)
    shuffled_tensor = tf.gather(choices, shuffled_indices)

    # Letter prefixes stay in order; only the choice texts are shuffled.
    abc_tensor = tf.gather(abc_tensor, indices)
    data[indexed_choices_outkey] = tf.strings.reduce_join(
        tf.strings.join([abc_tensor, shuffled_tensor], separator=' '),
        separator=', ',
    )

    # Position in the shuffled order where the original answer now sits. The
    # cast keeps tf.equal from failing if the answer is stored as e.g. int64.
    answer_tensor = tf.cast(data[ans_inkey], shuffled_indices.dtype)
    new_ans_indice = tf.where(tf.equal(shuffled_indices, answer_tensor))
    new_ans_indice = tf.gather(abcans_tensor, new_ans_indice)
    data[indexed_answer_outkey] = tf.strings.reduce_join(new_ans_indice)
    return data

  return _template
66 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/proj/paligemma/widgetcap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Widgetcap pp ops."""
16 |
17 | from big_vision.pp.registry import Registry
18 | import tensorflow as tf
19 |
20 |
@Registry.register("preprocess_ops.draw_bbox")
def get_draw_bbox(image_key="image", bbox_key="bbox"):
  """Draw a single bounding box onto an image.

  Args:
    image_key: key of the image in the data dict; it is overwritten with the
      annotated image as float32.
    bbox_key: key of the box; its 4 values are reshaped to [1, 1, 4] and passed
      to `tf.image.draw_bounding_boxes`, which expects normalized
      [y_min, x_min, y_max, x_max] coordinates.

  Returns:
    A function that draws the box in red and returns the updated data dict.
  """

  def _draw_bbox(data):
    """Draw a single red bounding box on `data[image_key]`."""
    image = tf.cast(data[image_key], tf.float32)
    image = tf.image.draw_bounding_boxes(
        tf.expand_dims(image, 0),
        tf.reshape(data[bbox_key], [1, 1, 4]),
        tf.constant([255, 0, 0], dtype=tf.float32, shape=[1, 3]),
    )
    # Remove only the batch dimension added above; a bare tf.squeeze would
    # also drop any other size-1 axis (e.g. the channel axis of a single
    # channel image, or a 1-pixel-wide image).
    data[image_key] = tf.squeeze(image, axis=0)
    return data

  return _draw_bbox
37 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Preprocessing utils."""
16 |
17 | from collections import abc
18 |
19 |
def maybe_repeat(arg, n_reps):
  """Returns non-string sequences unchanged; repeats anything else.

  A scalar (or a string, which is treated as a scalar here) becomes a tuple
  holding `n_reps` copies of it. Lists, tuples and other sequences are
  returned as-is.
  """
  is_plain_sequence = isinstance(arg, abc.Sequence) and not isinstance(arg, str)
  if is_plain_sequence:
    return arg
  return (arg,) * n_reps
24 |
25 |
class InKeyOutKey(object):
  """Decorator for preprocessing ops, which adds `inkey` and `outkey` arguments.

  The decorated op getter gains three keyword-only arguments:
    inkey: data key the op reads its input from (default: `indefault`).
    outkey: data key the op writes its output to (default: `outdefault`).
    key: if given, used for both reading and writing instead of inkey/outkey.

  Note: Only supports single-input single-output ops.
  """

  def __init__(self, indefault="image", outdefault="image", with_data=False):
    self.indefault = indefault
    self.outdefault = outdefault
    # When True, the wrapped op is also given the full data dict as `data=`.
    self.with_data = with_data

  def __call__(self, orig_get_pp_fn):

    def get_ikok_pp_fn(*args, key=None,
                       inkey=self.indefault, outkey=self.outdefault, **kw):
      src_key = key or inkey
      dst_key = key or outkey
      pp_fn = orig_get_pp_fn(*args, **kw)

      def _ikok_pp_fn(data):
        extra_kwargs = {"data": data} if self.with_data else {}
        data[dst_key] = pp_fn(data[src_key], **extra_kwargs)
        return data

      return _ikok_pp_fn

    return get_ikok_pp_fn
54 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/pp/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for preprocessing utils."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from big_vision.pp import utils
22 | import tensorflow.compat.v1 as tf
23 |
24 |
class UtilsTest(tf.test.TestCase):
  """Tests for `maybe_repeat` and the `InKeyOutKey` decorator."""

  def test_maybe_repeat(self):
    self.assertEqual((1, 1, 1), utils.maybe_repeat(1, 3))
    self.assertEqual((1, 2), utils.maybe_repeat((1, 2), 2))
    self.assertEqual([1, 2], utils.maybe_repeat([1, 2], 2))
    # Strings are sequences but must be repeated like scalars.
    self.assertEqual(("ab", "ab"), utils.maybe_repeat("ab", 2))

  def test_inkeyoutkey(self):
    @utils.InKeyOutKey()
    def get_pp_fn(shift, scale=0):
      def _pp_fn(x):
        return scale * x + shift
      return _pp_fn

    data = {"k_in": 2, "other": 3}
    ppfn = get_pp_fn(1, 2, inkey="k_in", outkey="k_out")  # pylint: disable=unexpected-keyword-arg
    self.assertEqual({"k_in": 2, "k_out": 5, "other": 3}, ppfn(data))

    data = {"k": 6, "other": 3}
    ppfn = get_pp_fn(1, inkey="k", outkey="k")  # pylint: disable=unexpected-keyword-arg
    self.assertEqual({"k": 1, "other": 3}, ppfn(data))

    data = {"other": 6, "image": 3}
    ppfn = get_pp_fn(5, 2)
    self.assertEqual({"other": 6, "image": 11}, ppfn(data))

    # `key` overrides both the input and the output key.
    data = {"k": 6, "other": 3}
    ppfn = get_pp_fn(1, 2, key="k")  # pylint: disable=unexpected-keyword-arg
    self.assertEqual({"k": 13, "other": 3}, ppfn(data))

  def test_inkeyoutkey_with_data(self):
    @utils.InKeyOutKey(with_data=True)
    def get_pp_fn(shift):
      def _pp_fn(x, data):
        return x + data["other"] + shift
      return _pp_fn

    data = {"image": 1, "other": 10}
    self.assertEqual({"image": 12, "other": 10}, get_pp_fn(1)(data))


if __name__ == "__main__":
  tf.test.main()
54 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.26
2 | absl-py
3 | git+https://github.com/google/CommonLoopUtils
4 | distrax
5 | editdistance
6 | einops
7 | flax
8 | optax
9 | git+https://github.com/google/flaxformer
10 | git+https://github.com/akolesnikoff/panopticapi.git@mute
11 | overrides
12 | protobuf
13 | sentencepiece
14 | tensorflow-cpu
15 | tfds-nightly
16 | tensorflow-text
17 | tensorflow-gan
18 | pycocoevalcap
19 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/run_tpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Creates (on first run) and activates the ./bv_venv virtualenv with
# big_vision's dependencies, then runs `python3 -m "$@"` if any arguments were
# given. Paths are relative, so run it from the directory containing
# big_vision/.

if [ ! -d "bv_venv" ]
then
  sudo apt-get update
  # apt-get (rather than apt) is the CLI with a stable interface for scripts.
  sudo apt-get install -y python3-venv
  python3 -m venv bv_venv
  . bv_venv/bin/activate

  pip install -U pip  # Yes, really needed.
  # NOTE: doesn't work when in requirements.txt -> cyclic dep
  pip install "jax[tpu]>=0.4.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  pip install -r big_vision/requirements.txt
else
  . bv_venv/bin/activate
fi

if [ $# -ne 0 ]
then
  # Quoted so a data directory path containing spaces is passed intact.
  env TFDS_DATA_DIR="$TFDS_DATA_DIR" BV_JAX_INIT=1 python3 -m "$@"
fi
36 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/download_tfds_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Download and prepare TFDS datasets for the big_vision codebase.
16 |
By default, this Python script downloads cifar10, cifar100, oxford_iiit_pet,
oxford_flowers102 and imagenet_v2; dataset names passed as command-line
arguments are downloaded instead.
19 |
20 | If you want to integrate other public or custom datasets, please follow:
21 | https://www.tensorflow.org/datasets/catalog/overview
22 | """
23 |
24 | from absl import app
25 | import tensorflow_datasets as tfds
26 |
27 |
# Datasets fetched when no dataset names are given on the command line.
_DEFAULT_DATASETS = (
    "cifar10",
    "cifar100",
    "oxford_iiit_pet",
    "oxford_flowers102",
    "imagenet_v2",
)


def main(argv):
  """Downloads and prepares each requested TFDS dataset.

  Args:
    argv: command-line arguments from `app.run`. Arguments after the program
      name are used as dataset names, but only when argv[0] names
      download_tfds_datasets.py; otherwise `_DEFAULT_DATASETS` is used.
  """
  from_cli = len(argv) > 1 and "download_tfds_datasets.py" in argv[0]
  for name in (argv[1:] if from_cli else _DEFAULT_DATASETS):
    # download=True lets tfds.load download and prepare the data first.
    tfds.load(name=name, download=True)


if __name__ == "__main__":
  app.run(main)
45 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/README.md:
--------------------------------------------------------------------------------
1 | # LiT-Demo
2 |
3 | See https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
4 |
5 | Demo originally appeared on Twitter
6 | https://twitter.com/AndreasPSteiner/status/1514722383818543106
7 |
8 | App published at
9 | https://google-research.github.io/vision_transformer/lit
10 |
11 | ## Build
12 |
Install packages (tested with Node v16.17.0 and Yarn 1.22.19):
14 |
15 | ```bash
16 | yarn
17 | ```
18 |
19 |
20 | ## Run
21 |
The web app will be served at http://localhost:8000.

```bash
25 | node build.js
26 | ```
27 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/build.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | const sassPlugin = require('esbuild-sass-plugin').sassPlugin;
19 |
/** Port the development server listens on. */
const PORT = 8000;

// Bundles src/app.ts (plus SCSS via esbuild-sass-plugin) and serves `src/`.
// `esbuild.serve` is the pre-0.17 esbuild API; package.json requests ^0.15.5.
require('esbuild').serve({
  servedir: 'src',
  port: PORT,
}, {
  entryPoints: ['src/app.ts'],
  bundle: true,
  outfile: 'src/index.js',
  plugins: [
    // The global stylesheet is injected as a regular style.
    sassPlugin({
      filter: /style.scss$/,
      type: 'style'
    }),
    // Remaining .scss files are imported as Lit CSS (component `styles`).
    sassPlugin({
      type: 'lit-css',
    }),
  ],
  sourcemap: true,
}).then(() => {
  console.log(`Serving on port ${PORT}`);
}).catch((err) => {
  // Report why startup failed instead of exiting silently.
  console.error(err);
  process.exit(1);
});
40 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lit-demo",
3 | "version": "0.0.2",
4 | "description": "",
5 | "main": "src/app.ts",
6 | "license": "Apache-2.0",
7 | "private": true,
8 | "engines": {
9 | "node": ">=8.9.0"
10 | },
11 | "scripts": {
12 | "serve": "node build.js",
13 | "test": "ts-node --skip-ignore --project tsconfig.test.json run_tests.ts"
14 | },
15 | "devDependencies": {
16 | "@babel/core": "^7.7.5",
17 | "@babel/plugin-transform-runtime": "^7.7.6",
18 | "@babel/polyfill": "^7.10.4",
19 | "@babel/preset-env": "^7.7.6",
20 | "@tensorflow/tfjs-backend-cpu": "^3.15.0",
21 | "@tensorflow/tfjs-backend-webgl": "^3.15.0",
22 | "@tensorflow/tfjs-converter": "3.20.0",
23 | "@tensorflow/tfjs-core": "3.20.0",
24 | "babel-preset-env": "^1.7.0",
25 | "esbuild": "^0.15.5",
26 | "esbuild-sass-plugin": "^2.3.2",
27 | "jasmine": "^3.3.1",
28 | "lit": "^2.3.1",
29 | "naughty-words": "^1.2.0",
30 | "sass": "^1.50.0",
31 | "ts-node": "~5.0.0",
32 | "typescript": "4.1.3"
33 | },
34 | "resolutions": {
35 | "is-svg": "4.3.1"
36 | },
37 | "eslintConfig": {
38 | "extends": "google",
39 | "rules": {
40 | "require-jsdoc": 0,
41 | "valid-jsdoc": 0
42 | },
43 | "env": {
44 | "es6": true
45 | },
46 | "parserOptions": {
47 | "ecmaVersion": 8,
48 | "sourceType": "module"
49 | }
50 | },
51 | "eslintIgnore": [
52 | "dist/"
53 | ]
54 | }
55 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/app.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import {LitDemoApp} from './components/lit-demo-app';
19 | import './style.scss';
20 |
// Expose the app class as a browser global.
// NOTE(review): presumably so the hosting page can instantiate it — confirm.
Object.assign(window, {LitDemoApp});
23 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/image-carousel.scss:
--------------------------------------------------------------------------------
1 | @import '../style/mixins';
2 |
// Horizontally scrolling strip of image thumbnails.
.selector {
  overflow: scroll;
  padding-bottom: 10px; // OS X scroll bar

  .inner {
    // Keep all thumbnails on one line so the strip scrolls horizontally.
    white-space: nowrap;

    .thumb {
      display: inline-block;

      img {
        cursor: pointer;

        width: 20vmin;
        height: 20vmin;
        max-width: 200px;
        max-height: 200px;
        margin: 10px;
        box-shadow: 0 0 10px #888;

        // Kept after all plain declarations: Dart Sass deprecates declarations
        // that follow nested rules ("mixed-decls"). phone-portrait presumably
        // emits an @media block — see ../style/mixins.
        @include phone-portrait {
          width: 33vmin;
          height: 33vmin;
        }
      }
    }
  }
}
33 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/image-carousel.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Carousel of images.
20 | */
21 |
22 | import {html, LitElement} from 'lit';
23 |
24 | import {app} from '../lit_demo/app';
25 | import {getImageUrl} from '../lit_demo/constants';
26 | import {ImageRow} from '../lit_demo/data';
27 |
28 | import {customElement} from 'lit/decorators.js';
29 | import styles from './image-carousel.scss';
30 |
/**
 * Shows multiple images in a horizontal carousel.
 *
 * Dispatches `'image-select'` event when an image is clicked/tapped.
 */
@customElement('image-carousel')
export class ImageCarousel extends LitElement {
  static override styles = [styles];

  /** Dispatches an `image-select` event whose `detail` carries the image id. */
  onClick(id: string) {
    const event =
        new CustomEvent('image-select', {composed: true, detail: {id}});
    this.dispatchEvent(event);
  }

  override render() {
    // Markup mirrors the nesting styled in image-carousel.scss:
    // .selector > .inner > .thumb > img.
    const images = app.imageData.rows.map(
        (row: ImageRow) => html`
          <div class="thumb">
            <img @click=${() => {
              this.onClick(row.id);
            }} data-id=${row.id} src="${getImageUrl(row.id)}">
          </div>
        `);
    return html`
      <div class="selector">
        <div class="inner">
          ${images}
        </div>
      </div>
      <p>Select an image 👆 to get started.</p>
    `;
  }
}

declare global {
  interface HTMLElementTagNameMap {
    'image-carousel': ImageCarousel;
  }
}
71 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/image-prompts.scss:
--------------------------------------------------------------------------------
1 | @import '../style/mixins';
2 |
// Two-column flex layout (.left / .right) that switches to a single centered
// column under the phone-portrait mixin (defined in ../style/mixins).
.image-prompt {
  display: flex;
  gap: 1.5em;
  align-items: flex-start;
  margin-top: 2rem;

  @include phone-portrait {
    align-items: center;
    flex-direction: column;
    gap: 0;
    margin-bottom: 5rem;
  }

  .left {
    display: flex;
    flex-direction: column;

    // Positioning context for the absolutely positioned .src overlay.
    .wrapper {
      position: relative;

      // Text overlaid near the bottom-right corner of .wrapper.
      .src {
        position: absolute;
        right: 2rem;
        bottom: 2rem;
        color: white;
        font-size: 1.5rem;
        text-shadow: 2px 2px black;
        text-decoration: none;
      }
    }

    // Starts fully transparent.
    // NOTE(review): presumably made visible from the component script while a
    // computation runs — confirm.
    .animation {
      position: relative;
      width: 224px;
      height: 15px;
      opacity: 0;

      .computing {
        text-align: center;
      }
    }
  }

  // Takes all remaining horizontal space.
  .right {
    display: flex;
    flex-grow: 1;
    flex-direction: column;
    gap: 0.5em;

    .top {
      text-align: right;
      height: 30px;
    }

    .buttons {
      display: flex;
      flex-wrap: wrap;
      justify-content: flex-end;
      gap: 1em;
      align-items: center;
    }

    // One row; also the positioning context for its .bar.
    .item {
      position: relative;
      display: flex;

      // Hidden until its opacity is changed; the change fades over 0.5s.
      .pct {
        display: inline-block;
        margin-right: 1em;
        width: 3.5em;
        text-align: right;
        opacity: 0;
        transition: opacity 0.5s;
      }

      // Underline-only text input.
      input {
        flex-grow: 1;
        max-width: 70vw;
        border-radius: 0;
        background: transparent;
        border: 0;
        border-bottom: 1px solid var(--text-fg);
        color: var(--text-fg);
        outline: none;

        // Red styling when the input carries the `toolong` class.
        &.toolong {
          border-bottom: 1px solid var(--text-red);
          color: var(--text-red);
        }
      }

      // Bar drawn behind the row content (z-index: -1), starting at zero
      // width; width changes animate over 0.5s.
      .bar {
        position: absolute;
        display: inline-block;
        top: 5%;
        left: 0;
        z-index: -1;
        background: var(--bar-col);
        height: 90%;
        width: 0;
        transition: width 0.5s;
      }
    }

    // Starts fully transparent (opacity: 0).
    .bottom {
      display: flex;
      flex-wrap: wrap;
      justify-content: flex-end;
      gap: 1em;
      align-items: center;
      opacity: 0;

      // Pill-shaped link button.
      .tweet {
        background: rgb(18, 150, 223);
        color: white;
        text-decoration: none;
        padding: 0px 15px;
        border-radius: 16px;
      }
    }
  }
}
125 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/lit-demo-app.scss:
--------------------------------------------------------------------------------
// Horizontally center the inline content of the loading container.
.loading-container { text-align: center; }
4 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/loading-animation.scss:
--------------------------------------------------------------------------------
1 | // CC0 from https://loading.io/css/
2 |
3 | @import '../style/colors';
4 |
// Four dots: the first scales in, the middle two slide right by one slot
// (24px) and the last scales out, producing a continuous ellipsis loop.
.lds-ellipsis {
  display: inline-block;
  position: relative;
  width: 80px;
  height: 80px;

  div {
    position: absolute;
    top: 33px;
    width: 13px;
    height: 13px;
    border-radius: 50%;
    background: var(--text-fg);
    // NOTE(review): the `animation` shorthands below reset the timing
    // function, so this declaration is effectively overridden.
    animation-timing-function: cubic-bezier(0, 1, 1, 0);

    &:nth-child(1) {
      left: 8px;
      animation: lds-ellipsis1 0.6s infinite;
    }

    &:nth-child(2) {
      left: 8px;
      animation: lds-ellipsis2 0.6s infinite;
    }

    &:nth-child(3) {
      left: 32px;
      animation: lds-ellipsis2 0.6s infinite;
    }

    &:nth-child(4) {
      left: 56px;
      animation: lds-ellipsis3 0.6s infinite;
    }
  }
}

// Grow in.
@keyframes lds-ellipsis1 {
  from { transform: scale(0); }
  to { transform: scale(1); }
}

// Shrink out.
@keyframes lds-ellipsis3 {
  from { transform: scale(1); }
  to { transform: scale(0); }
}

// Slide one dot slot to the right.
@keyframes lds-ellipsis2 {
  from { transform: translate(0, 0); }
  to { transform: translate(24px, 0); }
}
66 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/loading-animation.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Animated loading indicator.
20 | */
21 |
22 | import {html, LitElement} from 'lit';
23 |
24 | import {customElement} from 'lit/decorators.js';
25 | import styles from './loading-animation.scss';
26 |
27 | /**
28 | * Shows an animated loading animation.
29 | */
@customElement('loading-animation')
export class LoadingAnimation extends LitElement {

  static override styles = [styles];

  /**
   * Renders the ellipsis markup. loading-animation.scss styles `.lds-ellipsis`
   * and exactly four child `div`s (nth-child 1..4), so all four are required
   * for the animation to show anything.
   */
  override render() {
    return html`
      <div class="lds-ellipsis"><div></div><div></div><div></div><div></div></div>
    `;
  }
}

declare global {
  interface HTMLElementTagNameMap {
    'loading-animation': LoadingAnimation;
  }
}
52 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/message-list.scss:
--------------------------------------------------------------------------------
1 | @import '../style/colors';
2 |
// Box shared by every message; the type classes below add the colors.
.message {
  margin-bottom: 1rem;
  padding: 0.1rem 0.5rem;
}

.warning { background: var(--warn-bg); color: var(--warn-fg); }
.error { background: var(--error-bg); color: var(--error-fg); }
.info { background: var(--note-bg); color: var(--note-fg); }

// Dismiss button, pinned to the right edge of its message.
.close {
  cursor: pointer;
  float: right;
}
27 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/message-list.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview A list of dismissable info/warning/error messages.
20 | */
21 |
22 | import {html, LitElement} from 'lit';
23 |
24 | import {unsafeHTML} from 'lit/directives/unsafe-html.js';
25 |
26 | import {customElement} from 'lit/decorators.js';
27 | import styles from './message-list.scss';
28 |
enum MessageType {
  INFO = 'info',
  WARNING = 'warning',
  ERROR = 'error',
}

interface Message {
  message: string;
  type: MessageType;
  /** When true, `message` is rendered as HTML via `unsafeHTML`. */
  rawHtml: boolean;
}


/**
 * Shows info/warning/error messages that remain until closed by user.
 *
 * `messages` is the single source of truth: every mutation updates the array
 * and then requests a re-render, rather than editing DOM that Lit manages.
 */
@customElement('message-list')
export class MessageList extends LitElement {
  static override styles = [styles];

  messages: Message[] = [];

  /** Appends `message` and schedules a re-render. */
  addMessage(message: Message) {
    this.messages.push(message);
    this.requestUpdate();
  }

  /**
   * Adds an info message. With `rawHtml: true` the text is inserted as HTML,
   * so only pass trusted content in that case.
   */
  info(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) {
    this.addMessage({message, type: MessageType.INFO, rawHtml});
  }

  /** Adds a warning message (see `info` for `rawHtml`). */
  warning(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) {
    this.addMessage({message, type: MessageType.WARNING, rawHtml});
  }

  /** Adds an error message (see `info` for `rawHtml`). */
  error(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) {
    this.addMessage({message, type: MessageType.ERROR, rawHtml});
  }

  /**
   * Removes the message at `idx` and re-renders.
   *
   * @param event The click event of the close button (kept for API
   *     compatibility; no longer needed).
   * @param idx Index into `messages`.
   */
  removeMessage(event: Event, idx: number) {
    // Re-rendering instead of detaching the `.message` node keeps Lit's
    // bookkeeping intact and re-binds the `idx` captured by every remaining
    // close handler; after a bare splice those indices would be stale.
    this.messages.splice(idx, 1);
    this.requestUpdate();
  }

  /** Removes all messages. */
  clear() {
    // Messages are rendered into the shadow root, so removing light-DOM
    // children did not hide them; re-render from the empty list instead.
    this.messages = [];
    this.requestUpdate();
  }

  override render() {
    return this.messages.map(
        (message: Message, idx: number) => html`
          <div class="message ${message.type}">
            ${message.rawHtml ? unsafeHTML(message.message) : message.message}
            <span @click=${(e: Event) => {
              this.removeMessage(e, idx);
            }} class="close">✖</span>
          </div>
        `);
  }
}

declare global {
  interface HTMLElementTagNameMap {
    'message-list': MessageList;
  }
}
98 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/model-controls.scss:
--------------------------------------------------------------------------------
// Model selection row (flex container).
.controls {
  display: flex;
  margin: 1em 0;

  select { margin-left: 0.5em; }
  progress { margin: 0 1em; }
}
13 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/components/model-controls.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Controls to choose model.
20 | */
21 |
22 | import {html, LitElement} from 'lit';
23 |
24 | import {getModels} from '../lit_demo/constants';
25 | import {app} from '../lit_demo/app';
26 |
27 | import {customElement, property} from 'lit/decorators.js';
28 | import styles from './model-controls.scss';
29 |
30 | /**
31 | * Shows controls for model selection, progress bar, and status text.
32 | */
@customElement('model-controls')
export class ModelControls extends LitElement {

  static override styles = [styles];

  /** Latest progress value reported by `app.models` listeners. */
  @property({attribute: false})
  progress: number = 0;

  /** Status text shown next to the progress bar. */
  @property({attribute: false})
  status: string = 'Initializing...';

  constructor() {
    super();
    app.models.addListener(this.onModelUpdate.bind(this));
    // Load the first model eagerly; the dropdown selects it by default.
    this.loadModel(getModels()[0]);
  }

  /** Listener for `app.models`: updates progress and, if given, status. */
  onModelUpdate(progress: number, message?: string) {
    this.progress = progress;
    if (message) this.status = message;
  }

  /** `change` handler of the model dropdown. */
  onModelChange(event: Event) {
    const target = event.target as HTMLSelectElement;
    this.loadModel(target.value);
  }

  /**
   * Selects `model` in the dropdown and triggers loading it.
   *
   * @throws Error if `model` is not one of `getModels()`.
   */
  async setModel(model: string) {
    if (getModels().indexOf(model) === -1) {
      throw new Error(`Model "${model}" not found!`);
    }
    await this.updateComplete;
    const dropdown = this.shadowRoot!.querySelector('#model_dropdown') as HTMLSelectElement;
    dropdown.value = model;
    dropdown.dispatchEvent(new Event('change'));
  }

  /**
   * Starts loading `name`. Failures are shown in the status line; previously
   * the initial load in the constructor had no handler and surfaced as an
   * unhandled promise rejection.
   */
  private loadModel(name: string) {
    app.models.load(name).catch((error) => {
      this.status = `ERROR loading model "${name}": ${error}`;
    });
  }

  override render() {
    const options = getModels().map((model: string) =>
        html`<option value=${model}>${model}</option>`);
    // NOTE(review): <progress> defaults to max=1; confirm that app.models
    // reports progress in [0, 1].
    return html`
      <div class="controls">
        <label for="model_dropdown">Model:</label>
        <select id="model_dropdown" @change=${this.onModelChange}>
          ${options}
        </select>
        <progress value=${this.progress}></progress>
        <div>${this.status}</div>
      </div>
    `;
  }
}

declare global {
  interface HTMLElementTagNameMap {
    'model-controls': ModelControls;
  }
}
94 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/exports.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Some useful exports for playing around with the models and
20 | * tokenizers.
21 | *
22 | * Simple usage (see ./playground.html for more complete usage example):
23 | *
24 | * model = new lit.Model('tiny');
25 | * model.load(progress => console.log('loading...', progress));
26 | * console.log(model.computeProbabilities(['a dog', 'a cat'], '0'));
27 | */
28 |
29 | import {Model} from './lit_demo/compute';
30 | import {getImageUrl, setBaseUrl} from './lit_demo/constants';
31 | import {ImageData} from './lit_demo/data';
32 | import * as tf from '@tensorflow/tfjs-core';
33 |
// Publish the helpers on `window` so they can be used from plain scripts and
// the browser console.
// tslint:disable-next-line:no-any Export symbols into global namespace.
const globalScope = window as any;
globalScope.lit = {Model, getImageUrl, ImageData, setBaseUrl};
// tslint:disable-next-line:ban-module-namespace-object-escape Export all of TF.
globalScope.tf = tf;
39 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/lit_demo/app.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Global app state.
20 | */
21 |
22 | import {ImageData} from './data';
23 | import {Models} from './compute';
24 |
25 | /**
26 | * Container class holding image data and models.
27 | *
28 | * The main application component would typically call `load()` and then show
29 | * the components depending on this class asynchronously.
30 | */
export class App {

  /** Image metadata shared by all components (filled by `load()`). */
  imageData = new ImageData();
  /** The `Models` instance shared by all components. */
  models = new Models();

  /** Becomes `true` once `load()` has finished. */
  ready: boolean = false;

  /** Loads the image data, then marks the app as ready. */
  async load() {
    const {imageData} = this;
    await imageData.load();
    this.ready = true;
  }
}

/**
 * Global app state.
 */
export const app: App = new App();
48 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/lit_demo/constants.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Project-wide constants.
20 | */
21 |
/** Root URL for all data files; can be overwritten with setBaseUrl(). */
// Upstream default: 'https://google-research.github.io/vision_transformer/lit'.
let baseUrl = 'https://figur.li/jax2tfjs';
/** Names of available models; can be overwritten with setModels(). */
let models = ['tiny', 'small'];

/**
 * Sets the base URL under which all other data files (models, images) are
 * resolved. Trailing slashes are stripped so generated URLs never contain
 * `//data/...`.
 */
export const setBaseUrl = (newBaseUrl: string) => {
  baseUrl = newBaseUrl.replace(/\/+$/, '');
};

/** Retrieves URL for a model-specific file (vocabulary, embeddings, ...). */
export const getModelFileUrl = (name: string, relativePath: string) => (
  `${baseUrl}/data/models/${name}/${relativePath}`
);

/** Retrieves the URL for images information JSON file. */
export const getImagesInfoUrl = () => `${baseUrl}/data/images/info.json`;

/** Retrieves the URL for an image. */
export const getImageUrl = (id: string) => `${baseUrl}/data/images/${id}.jpg`;

/** Returns names of available models. */
export const getModels = () => models;

/** Sets names of available models. */
export const setModels = (newModels: string[]) => {
  // Copy so that later mutations of the caller's array don't leak in.
  models = [...newModels];
};
51 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/lit_demo/data.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Accessing additional data.
20 | */
21 |
22 | import {getImagesInfoUrl} from './constants';
23 |
24 | /**
25 | * Information about a single image.
26 | */
export interface ImageRow {
  /** Identifier of the image; stable across data releases. */
  id: string;
  /** Example prompts that go with this image. */
  prompts: string;
  /** The image's license. */
  license: string;
  /** Original download location of the image. */
  source: string;
  /** A short, human-readable description. */
  description: string;
}
39 | /**
40 | * Contains information about all images.
41 | */
export class ImageData {

  rows: ImageRow[] = [];
  /** Will be set to `true` when `load()` finishes. */
  ready = false;

  /**
   * Gets an image by ID. Throws an error if image is not found, data is not
   * loaded, or ID is not unique.
   */
  get(id: string): ImageRow {
    if (!this.ready) {
      throw new Error('ImageData not loaded!');
    }
    const matching = this.rows.filter(row => row.id === id);
    if (matching.length !== 1) {
      throw new Error(`Got unexpected ${matching.length} matches for id="${id}"`);
    }
    return matching[0];
  }

  /**
   * Loads image data asynchronously from `getImagesInfoUrl()`.
   *
   * @throws Error if the HTTP request is not successful (e.g. 404). Without
   *     this check a failed request showed up as an opaque JSON parse error.
   */
  async load() {
    const url = getImagesInfoUrl();
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(
          `Failed to load image info from ${url}: ` +
          `${response.status} ${response.statusText}`);
    }
    this.rows = await response.json();
    this.ready = true;
  }
}
77 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/style.scss:
--------------------------------------------------------------------------------
1 | // General styles for the page.
2 |
3 | @import './style/colors';
4 | @import './style/mixins';
5 |
html {
  font-size: 14px;
  line-height: 1.6em;
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen,
    Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial,
    sans-serif;
  // Vendor-prefixed forms first, so the standard property wins where supported.
  -ms-text-size-adjust: 100%;
  -webkit-text-size-adjust: 100%;
  text-size-adjust: 100%;
  // Kept before the nested rules: Dart Sass deprecates declarations that
  // follow nested rules ("mixed-decls"); the compiled CSS is unchanged.
  color: var(--text-fg);
  background: var(--text-bg);

  @media (min-width: 1200px) {
    width: 1024px;
    margin: 0 auto;
  }
  @media (min-width: 768px) {
    font-size: 16px;
  }

  body {
    margin: 0;
    padding: 0rem 1rem 10rem;
  }
}
32 |
a,
a:visited {
  color: var(--link-col);
}

h1 {
  font-weight: 700;
  font-size: 2rem;
  line-height: 1.3em;
}

p {
  font-size: 1.06rem;
  line-height: 1.3em;
}

input {
  font-size: 1rem;

  &::placeholder {
    color: var(--placeholder-col);
  }
}

.note {
  font-style: normal;
  border: none;
  border-radius: 2px;
  margin-left: auto;
  margin-right: auto;

  padding: 0.5rem 0.5rem 0.5rem 2rem;
  width: 90%;

  // Declared before the nested rules below: Dart Sass deprecates declarations
  // after nested rules ("mixed-decls"); the compiled CSS is unchanged.
  background-color: var(--note-bg);
  color: var(--note-fg);

  @include phone-portrait {
    width: 100%;
    padding: 0.5rem;
    box-sizing: border-box;
  }

  &.warning {
    background-color: var(--warn-bg);
    color: var(--warn-fg);
  }
}
81 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/style/colors.scss:
--------------------------------------------------------------------------------
1 | // Dark and light mode colors.
2 |
// Light values are the defaults; the dark-scheme block overrides a subset.
:root {
  // Page text and background.
  --text-fg: rgb(15, 15, 15);
  --text-bg: hsl(0, 0%, 97%);
  --text-red: rgb(220, 0, 0);
  --placeholder-col: rgb(166, 166, 166);
  --link-col: rgb(0, 0, 238);

  // Greys, accents and shape.
  --gray: rgba(0, 0, 0, 0.6);
  --gray-border: hsla(0, 0%, 0%, 0.1);
  --orange: hsl(24, 100%, 50%);
  --distill-blue: hsl(200, 50%, 25%);
  --blue: #337699;
  --green: #3db867;
  --bar-col: rgb(171, 199, 227);
  --border-radius: 5px;

  // Note / warning / error banners.
  --note-fg: #1a6ebb;
  --note-bg: #e1f5fe;
  --warn-fg: #a16800;
  --warn-bg: #ffe1aa;
  --error-fg: white;
  --error-bg: #850000;

  // Only these tokens change in dark mode; all others keep their light value.
  @media (prefers-color-scheme: dark) {
    --text-fg: rgb(213, 213, 213);
    --text-bg: rgb(56, 56, 56);
    --link-col: rgb(66, 165, 245);
    --bar-col: rgb(20, 109, 163);
    --note-fg: rgb(121 157 190);
    --note-bg: rgb(2 59 85);
    --warn-fg: #edbe68;
    --warn-bg: #784e00;
  }
}
36 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/style/mixins.scss:
--------------------------------------------------------------------------------
1 | // Useful mixins.
2 |
3 | // To wrap styles that should only trigger for phones in portrait mode.
@mixin phone-portrait {
  // `max-device-width` is deprecated (Media Queries Level 4); `max-width`
  // tests the viewport, which is what layout rules should react to.
  @media only screen and (max-width: 800px) and (orientation: portrait) {
    @content;
  }
}
9 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/common.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Utility code shared between tokenizers.
20 | */
21 |
22 | /**
23 | * A vocabulary consists of a list of tokens, each paired with a numerical score.
24 | * The unigram algorithm uses the scores to find the best tokenization; the BPE
25 | * algorithm uses them to choose which adjacent pair of pieces to merge next.
26 | */
export type Vocabulary = [string, number][];

/**
 * Turns text into a sequence of vocabulary indices.
 */
export interface Tokenizer {
  encode(input: string): number[];
}

/**
 * Constructor signature shared by all tokenizer implementations, so a
 * tokenizer can be created from a `Vocabulary` without knowing its class.
 */
export interface TokenizerConstructor {
  new (vocabulary: Vocabulary): Tokenizer;
}
42 |
43 | /**
44 | * Unicode-aware character iteration of strings.
45 | */
export const stringToChars = (input: string): string[] =>
    // Array.from uses the string iterator, which yields whole code points, so
    // surrogate pairs (e.g. emoji) stay together as one element.
    Array.from(input);

/**
 * Special separator character used to delimit sub-word tokens.
 */
export const TOKEN_SEPARATOR =
    '\u2581';  // U+2581 LOWER ONE EIGHTH BLOCK.
59 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Tokenizers and tokenizer mappings.
20 | */
21 |
22 | import {Tokenizer, TokenizerConstructor, Vocabulary} from './common';
23 | import * as sentencepieceBpe from './sentencepiece_bpe';
24 | import * as sentencepieceUnigram from './sentencepiece_unigram';
25 |
26 | export {Tokenizer, Vocabulary} from './common';
27 |
/** Registry of tokenizer implementations, keyed by algorithm name. */
const TOKENIZERS = new Map<string, TokenizerConstructor>([
  ['BPE', sentencepieceBpe.Tokenizer],
  ['UNIGRAM', sentencepieceUnigram.Tokenizer],
]);

/**
 * Returns a tokenizer of type `name` using `vocabulary`.
 *
 * @throws Error if `name` is not a registered tokenizer; the message lists the
 *     known names.
 */
export const getTokenizer = (name: string, vocabulary: Vocabulary): Tokenizer => {
  const ctor = TOKENIZERS.get(name);
  if (!ctor) {
    const known = Array.from(TOKENIZERS.keys()).join(', ');
    throw new Error(`Unknown tokenizer: ${name} (expected one of: ${known})`);
  }
  return new ctor(vocabulary);
};
41 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common';
19 |
/** A possible merge of the adjacent pieces at `pos` and `pos + 1`. */
interface Candidate {
  piece: string;
  pos: number;
  score: number;
}

/**
 * NFKC-normalizes `str`, replaces every space with TOKEN_SEPARATOR and
 * prepends one separator to non-empty input.
 */
function processInput(str: string): string {
  const normalized = str.normalize('NFKC');
  return normalized.length > 0 ?
      TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) :
      normalized;
}

/**
 * Sentencepiece tokenizer implementing the BPE algorithm.
 *
 * Starting from single characters, repeatedly merges the adjacent pair whose
 * concatenation is in the vocabulary with the highest score (leftmost pair on
 * ties) until no such pair remains.
 */
export class Tokenizer implements TokenizerInterface {

  // piece -> [score, index in vocabulary]
  private readonly map: Map<string, [number, number]>;

  /** @throws Error if a piece occurs more than once in `vocabulary`. */
  constructor(vocabulary: Vocabulary) {
    this.map = new Map<string, [number, number]>();
    vocabulary.forEach(([piece, score], idx) => {
      if (this.map.has(piece)) {
        throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`);
      }
      this.map.set(piece, [score, idx]);
    });
  }

  /**
   * Encodes `input` into vocabulary indices.
   *
   * @throws Error if a piece remaining after all merges (e.g. a character the
   *     vocabulary does not contain) has no vocabulary entry.
   */
  encode(input: string): number[] {
    let pieces: string[] = stringToChars(processInput(input));

    for (let best = this.findBestMerge(pieces); best !== null;
         best = this.findBestMerge(pieces)) {
      pieces = [
        ...pieces.slice(0, best.pos),
        best.piece,
        ...pieces.slice(best.pos + 2),
      ];
    }

    return pieces.map(piece => {
      const entry = this.map.get(piece);
      if (entry === undefined) {
        throw new Error(`Piece "${piece}" not found in vocabulary`);
      }
      return entry[1];
    });
  }

  /**
   * Returns the highest-scoring adjacent merge present in the vocabulary, or
   * null if none exists. A linear scan replaces the former full sort per step;
   * the strict `>` keeps the leftmost pair on ties, matching the stable
   * descending sort used before.
   */
  private findBestMerge(pieces: string[]): Candidate|null {
    let best: Candidate|null = null;
    for (let i = 0; i + 1 < pieces.length; i++) {
      const piece = pieces[i] + pieces[i + 1];
      const entry = this.map.get(piece);
      if (entry !== undefined && (best === null || entry[0] > best.score)) {
        best = {piece, pos: i, score: entry[0]};
      }
    }
    return best;
  }
}
81 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import 'jasmine';
19 |
import * as bpe from './sentencepiece_bpe';
import {TOKEN_SEPARATOR, Vocabulary} from './common';

const vocab: Vocabulary = [
  [TOKEN_SEPARATOR, 0],  // 0
  ['a', 0],              // 1
  ['e', 0],              // 2
  ['s', 0],              // 3
  ['t', 0],              // 4
  ['te', -1],            // 5
  ['st', -2],            // 6
  ['test', -3],          // 7
  ['tes', -4],           // 8
];

describe('BPE Tokenizer', () => {
  let tokenizer: bpe.Tokenizer;
  beforeAll(() => {
    tokenizer = new bpe.Tokenizer(vocab);
  });

  it('should tokenize correctly', () => {
    expect(tokenizer.encode('a test')).toEqual([0, 1, 0, 7]);
  });

  it('should merge the highest-scoring pair first', () => {
    // '▁tes': 'te' (-1) merges first, then 'tes' (-4) is the only option.
    expect(tokenizer.encode('tes')).toEqual([0, 8]);
  });

  it('should return no tokens for empty input', () => {
    expect(tokenizer.encode('')).toEqual([]);
  });

  it('should reject vocabularies with duplicate pieces', () => {
    expect(() => new bpe.Tokenizer([['a', 0], ['a', -1]]))
        .toThrowError(/occurs multiple times/);
  });
});
48 | });
49 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_unigram_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import {Tokenizer} from './sentencepiece_unigram';
19 |
const stubbedTokenizerVocab = [
  // Special tokens: unknown, begin- and end-of-sequence.
  ['�', 0],
  ['<s>', 0],
  ['</s>', 0],
  ['extra_token_id_1', 0],
  ['extra_token_id_2', 0],
  ['extra_token_id_3', 0],
  ['▁', -2],
  ['▁a', -1],
  ['▁ç', -2],
  ['a', -3],
  ['.', -1],
  ['▁I', -1],
  ['▁like', -1],
  ['▁it', -1],
  ['I', -2],
  ['like', -2],
  ['it', -2],
  ['l', -3],
  ['i', -3],
  ['k', -3],
  ['e', -3],
  // NOTE(review): duplicates 'i' (index 18); kept as-is, confirm whether a
  // different piece was intended here.
  ['i', -3],
  ['t', -3]
];

describe('Universal Sentence Encoder tokenizer', () => {
  let tokenizer: Tokenizer;
  beforeAll(() => {
    tokenizer = new Tokenizer(stubbedTokenizerVocab as Array<[string, number]>);
  });

  it('basic usage', () => {
    expect(tokenizer.encode('Ilikeit.')).toEqual([11, 15, 16, 10]);
  });

  it('handles whitespace', () => {
    expect(tokenizer.encode('I like it.')).toEqual([11, 12, 13, 10]);
  });

  it('should normalize inputs', () => {
    expect(tokenizer.encode('ça')).toEqual(tokenizer.encode('c\u0327a'));
  });

  it('should handle unknown inputs', () => {
    expect(() => tokenizer.encode('😹')).not.toThrow();
  });

  it('should treat consecutive unknown inputs as a single word', () => {
    expect(tokenizer.encode('a😹😹')).toEqual([7, 0]);
  });
});
72 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tokenizers/trie.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // Copied from
19 | // https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
20 |
21 | import {stringToChars} from './common';
22 |
// [token symbols, score, vocabulary index]
type OutputNode = [string[], number, number];

/** One trie node. `end` is set on nodes where an inserted token finishes. */
class TrieNode {
  parent: TrieNode|null = null;
  children: {[firstSymbol: string]: TrieNode} = {};
  end = false;
  word: OutputNode = [[], 0, 0];
}
39 |
40 | /**
41 | * Simple Trie datastructure.
42 | */
export class Trie {
  root = new TrieNode();

  /**
   * Inserts `word` (split into code points) with its `score` and vocabulary
   * `index`. Nodes on the path are created as needed, and each one records the
   * symbols leading to it.
   */
  insert(word: string, score: number, index: number) {
    const symbols = stringToChars(word);
    const lastPosition = symbols.length - 1;
    let node = this.root;

    symbols.forEach((symbol, position) => {
      let child = node.children[symbol];
      if (!child) {
        child = new TrieNode();
        node.children[symbol] = child;
        child.parent = node;
        child.word[0] = node.word[0].concat(symbol);
      }
      node = child;
      if (position === lastPosition) {
        node.end = true;
        node.word[1] = score;
        node.word[2] = index;
      }
    });
  }

  /**
   * Returns every inserted token that is a prefix of `ss`, shortest first.
   * If none matches, returns a single placeholder entry `[[ss[0]], 0, 0]`.
   *
   * @param ss The symbols to match against.
   */
  commonPrefixSearch(ss: string[]): OutputNode[] {
    const matches: OutputNode[] = [];
    let depth = 0;
    let current: TrieNode|undefined = this.root.children[ss[0]];

    while (current && depth < ss.length) {
      if (current.end) {
        matches.push(current.word);
      }
      depth++;
      current = current.children[ss[depth]];
    }

    return matches.length > 0 ? matches : [[[ss[0]], 0, 0]];
  }
}
97 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/tools/lit_demo/src/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "outDir": "dist",
4 | "target": "es6",
5 | "module": "commonjs",
6 | "lib": ["dom", "DOM.Iterable", "es2019", "es2020.string"],
7 | "types": ["node", "jasmine", "resize-observer-browser"],
8 | "moduleResolution": "node",
9 | "allowJs": false,
10 | "pretty": true,
11 | "resolveJsonModule": true,
12 | "sourceMap": false,
13 | "skipLibCheck": true,
14 | "removeComments": true,
15 | "esModuleInterop": true,
16 | "importsNotUsedAsValues": "preserve",
17 | "downlevelIteration": true,
18 | "skipDefaultLibCheck": true,
19 | "preserveConstEnums": false,
20 | "experimentalDecorators": true,
21 | "emitDecoratorMetadata": true,
22 | "noErrorTruncation": false,
23 | "noEmitOnError": false,
24 | "declaration": false,
25 | "stripInternal": true,
26 | "inlineSourceMap": true,
27 | "inlineSources": true,
28 | "importHelpers": true,
29 | "allowUnreachableCode": false,
30 | "noFallthroughCasesInSwitch": true,
31 | "noImplicitAny": true,
32 | "noImplicitReturns": false,
33 | "noImplicitThis": true,
34 | "strictBindCallApply": true,
35 | "strictFunctionTypes": true,
36 | "strictNullChecks": false,
37 | "strictPropertyInitialization": false
38 | },
39 | "include": ["./client", "./examples"],
40 | "compileOnSave": false
41 | }
42 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/trainers/proj/flexi/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Few common utils used in both/all flexi-trainers."""
16 | import functools
17 | import itertools
18 | import numpy as np
19 |
20 |
def mkrng(xid, wid, step):
  """Returns a numpy Generator seeded from the (xid, wid, step) triple.

  Each component is clamped to be >= 0 first (local runs, for example, use -1
  for the ids), since numpy seeding only accepts non-negative integers.
  """
  seed_parts = tuple(max(part, 0) for part in (xid, wid, step))
  return np.random.default_rng(seed_parts)
25 |
26 |
def mkprob(x):
  """Normalizes the weights `x` so they sum to one; None is passed through."""
  if x is not None:
    weights = np.array(x)
    return weights / np.sum(x)
  return None
31 |
32 |
def choice(values, ratios, rng=None):
  """Draws one element of `values`, weighted by `ratios` (uniform if None).

  A fresh, unseeded Generator is used when `rng` is not given.
  """
  if not rng:
    rng = np.random.default_rng()
  probabilities = mkprob(ratios)
  return rng.choice(values, p=probabilities)
36 |
37 |
def mkpredictfns(predict_fn, config, template="predict_{x}"):
  """Creates one partial of `predict_fn` per combination of flexi arguments.

  With two flexi args a=[1,2] and b=[10,20], this yields four functions named
  "predict_a=1_b=10", "predict_a=1_b=20", ... (using `template`), each with the
  corresponding keyword arguments bound.

  Args:
    predict_fn: Function to specialize.
    config: Mapping from argument name to an object whose `.v` attribute lists
      the candidate values for that argument.
    template: Name template; `{x}` is replaced by "k1=v1_k2=v2...".

  Returns:
    Dict from generated name to `functools.partial(predict_fn, **kwargs)`.
  """
  arg_names = list(config)
  candidate_values = [config[name].v for name in arg_names]

  predict_fns = {}
  for values in itertools.product(*candidate_values):
    kwargs = dict(zip(arg_names, values))
    suffix = "_".join(f"{k}={v}" for k, v in kwargs.items())
    predict_fns[template.format(x=suffix)] = functools.partial(
        predict_fn, **kwargs)
  return predict_fns
48 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/trainers/proj/givt/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for GIVT stage I and II trainers."""
16 |
17 | from typing import Any
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
def unbin_depth(
    depth: jax.Array,
    *,
    min_depth: float,
    max_depth: float,
    num_bins: int,
) -> jax.Array:
  """Converts a one-hot binned depth map back into float depth values.

  Each pixel's bin is taken as the argmax over the last axis and mapped to the
  center of that bin within [min_depth, max_depth].

  Args:
    depth: Depth map whose binned values are one-hot encoded along the last
      dimension.
    min_depth: Depth value at the lower edge of the first bin.
    max_depth: Depth value at the upper edge of the last bin.
    num_bins: Number of depth bins.

  Returns:
    Float32 depth map with the last (bin) axis removed.
  """
  bin_index = jnp.argmax(depth, axis=-1).astype(jnp.float32)
  # Using the bin center (+0.5) undoes the binning floor in expectation.
  normalized = (bin_index + 0.5) / num_bins
  return normalized * (max_depth - min_depth) + min_depth
46 |
47 |
48 | def get_local_rng(
49 | seed: int | jax.Array,
50 | batch: Any,
51 | ) -> jax.Array:
52 | """Generate a per-image seed based on the image id or the image values.
53 |
54 | Args:
55 | seed: Random seed from which per-image seeds should be derived.
56 | batch: Pytree containing a batch of images (key "image") and optionally
57 | image ids (key "image/id").
58 |
59 | Returns:
60 | Array containing per-image ids.
61 | """
62 | fake_id = None
63 | if "image" in batch:
64 | fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32)
65 | return jax.lax.scan(
66 | lambda k, x: (jax.random.fold_in(k, x), None),
67 | jax.random.PRNGKey(seed),
68 | batch.get("image/id", fake_id),
69 | )[0]
70 |
71 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/trainers/proj/uvim/coco_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities to inspect coco data and predictions in notebooks."""
16 | # pylint: disable=consider-using-from-import
17 | import functools
18 | import json
19 |
20 | import numpy as np
21 | from panopticapi import utils as pycoco_utils
22 | from skimage import segmentation
23 |
24 | import tensorflow.io.gfile as gfile
25 |
26 |
27 | import os
# Directory containing the COCO panoptic metadata files. Set $COCO_DATA_DIR to
# override; defaults to the current working directory.
ROOT = os.environ.get('COCO_DATA_DIR', '.')


# JSON list of panoptic category dicts, read by _coco_panoptic_categories().
PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
32 |
33 |
@functools.lru_cache(maxsize=None)
def _coco_panoptic_categories():
  """Loads PANOPTIC_COCO_CATS_FILE once and caches it as an immutable tuple."""
  with gfile.GFile(PANOPTIC_COCO_CATS_FILE, 'r') as fp:
    return tuple(json.load(fp))
39 |
40 |
def rgb_panoptic_from_twochannels(twochannels, boundaries: bool = False):
  """Makes a RGB panoptic output and segments_info from a twochannels view.

  Args:
    twochannels: Integer array whose last axis holds (semantic id, instance id)
      per pixel. Semantic id s >= 1 selects categories_list[s - 1]; semantic
      ids <= 0 are rendered as void (black). Assumes non-negative instance
      ids — TODO confirm.
    boundaries: If True, thick segment boundaries are painted black.

  Returns:
    (rgb, segments_info): a uint8 (H, W, 3) image, and a dict from segment id
    to {'id', 'color', 'category_id', 'name', 'isthing'}. All void pixels
    share segment id -1.
  """
  semantic_ids = twochannels[..., 0]
  instance_ids = twochannels[..., 1]
  stride = np.max(instance_ids) + 1
  # One integer per (semantic, instance) pair; negative semantics stay
  # negative so they decode to void below.
  combined = np.where(semantic_ids < 0, semantic_ids,
                      semantic_ids * stride + instance_ids)

  categories_list = _coco_panoptic_categories()
  id_generator = pycoco_utils.IdGenerator(
      {cat['id']: cat for cat in categories_list})
  segments_info = {}
  rgb = np.zeros((*instance_ids.shape[:2], 3), dtype=np.uint8)

  for segment_value in np.unique(combined):
    class_index = segment_value // stride
    if class_index > 0:
      category = categories_list[int(class_index) - 1]
      segment_id, color = id_generator.get_id_and_color(category['id'])
    else:
      category = {'id': -1, 'name': 'void', 'isthing': False}
      segment_id, color = -1, np.array([0, 0, 0])
    segments_info[segment_id] = {
        'id': segment_id,
        'color': color,
        'category_id': category['id'],
        'name': category['name'],
        'isthing': category['isthing'],
    }
    rgb[combined == segment_value] = color

  if boundaries:
    edge_mask = segmentation.find_boundaries(
        pycoco_utils.rgb2id(rgb), mode='thick')
    rgb[edge_mask] = 0
  return rgb, segments_info
76 |
--------------------------------------------------------------------------------
/palivla_digit/big_vision/trainers/proj/uvim/colorization_task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Inputs, outputs and losses for colorization task."""
16 | import einops
17 | import jax.numpy as jnp
18 | import numpy as np
19 |
# Axis of the per-patch color tensors holding the channels ("c" in this
# module's "b (hn wn) c (hp wp)" layout); predict_outputs asserts it is -2.
ONE_HOT_AXIS = -2
21 |
22 |
def input_pp(batch, config):
  """Make inputs for colorization task.

  Args:
    batch: Dict with optional "labels" of shape (B, H, W, C), plus the
      conditioning image under "image_ctx" (preferred) or "image".
    config: Config providing `config.model.patch_size` as (hp, wp).

  Returns:
    {"ctx": conditioning image or None, "x": None when there are no labels,
    else {"color": labels reshaped to (B, num_patches, C, hp * wp)}}.
  """
  ctx = batch.get("image_ctx", batch.get("image", None))
  if "labels" not in batch:
    # During predict of phase2 there is no 'labels' field.
    return {"ctx": ctx, "x": None}

  hp, wp = config.model.patch_size
  # Convert labels from (B, H, W, C) to (B, num_patches, C, hp * wp): each
  # patch becomes, per channel, a flat vector of its hp * wp pixels.
  color = einops.rearrange(
      batch["labels"], "b (hn hp) (wn wp) c -> b (hn wn) c (hp wp)",
      hp=hp, wp=wp)
  return {"ctx": ctx, "x": {"color": color}}
38 |
39 |
def loss_fn(logits, batch, config):
  """Compute loss for colorization task.

  Returns the element-wise squared error between predicted and target colors
  (patchified like `input_pp`), both as the loss and as the "loss_color"
  metric.
  """
  target = input_pp(batch, config)["x"]["color"]
  squared_error = jnp.square(logits["color"] - target)
  return squared_error, {"loss_color": squared_error}
46 |
47 |
def predict_outputs(logits, config):
  """Make outputs for colorization task.

  Un-patchifies logits["color"] from (B, hn * wn, C, hp * wp) back to an
  image of shape (B, H, W, C) and clips it to [-1, 1].
  """
  patch_h, patch_w = config.model.patch_size
  grid_h, grid_w = (
      np.array(config.model.input_size) // np.array((patch_h, patch_w)))
  assert ONE_HOT_AXIS == -2, "Rearrange below depends on this."
  image = einops.rearrange(
      logits["color"], "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
      hn=grid_h, wn=grid_w, hp=patch_h, wp=patch_w)
  return {"color": jnp.clip(image, -1., 1.)}
63 |
--------------------------------------------------------------------------------
/palivla_digit/palivla/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuse-model/FuSe/db701c0df2445c4ef559c40fb6aa94919168a198/palivla_digit/palivla/__init__.py
--------------------------------------------------------------------------------
/palivla_digit/palivla/modality_embedder.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import flax.linen as nn
3 |
class ModalityEmbedder(nn.Module):
    """Thin wrapper around `nn.Embed` that takes its dtype by name.

    Storing the dtype as a string (e.g. 'float32') instead of a jnp dtype
    keeps the module's attributes JSON-serializable.
    """
    num_embeddings: int
    embedding_dim: int
    dtype_str: str = 'float32'

    @nn.compact
    def __call__(self, x):
        compute_dtype = getattr(jnp, self.dtype_str)
        embed = nn.Embed(
            num_embeddings=self.num_embeddings,
            features=self.embedding_dim,
            dtype=compute_dtype,
        )
        return embed(x)
--------------------------------------------------------------------------------
/palivla_digit/palivla/types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Sequence, Mapping, Union
2 |
3 | import jax
4 | from flax.typing import Collection, VariableDict
5 | from flax.struct import dataclass
6 | import chex
7 |
# Generic array type (chex's alias).
Array = chex.Array
# Recursive pytree of arrays: an array, a str-keyed mapping of ArrayTrees, or
# a sequence of ArrayTrees.
ArrayTree = Union[chex.Array, Mapping[str, "ArrayTree"], Sequence["ArrayTree"]]
# Aliases over flax's typing helpers.
Params = Collection
Variables = VariableDict
# Updates (presumably optimizer updates) and Data are plain ArrayTree aliases.
Updates = ArrayTree
Data = ArrayTree
# String-keyed dict with arbitrary values (e.g. logged info).
Info = Dict[str, Any]
15 |
16 |
@dataclass
class TrainingBatch:
    """One batch of training inputs.

    Decorated with `flax.struct.dataclass`, so instances are frozen and are
    JAX pytrees (usable directly under jit / tree_util). Fields from
    `language_validity` onward are optional and default to None.
    """
    # Per-sensor arrays keyed by sensor name (NOTE(review): confirm which
    # sensors the data pipeline provides).
    sensors: Dict[str, jax.Array]
    sensors_mask: jax.Array
    actions_mask: jax.Array
    actions: jax.Array
    # Token ids plus, presumably, big_vision-style per-token masks
    # (autoregressive / loss / padding) — TODO confirm against the tokenizer.
    tokens: jax.Array
    tokens_ar: jax.Array
    tokens_loss: jax.Array
    tokens_mask: jax.Array
    language_validity: jax.Array | None = None
    # NOTE(review): the *_fuse fields appear to be alternate ar/loss masks for
    # an auxiliary objective — confirm with the model code.
    tokens_ar_fuse: jax.Array | None = None
    tokens_loss_fuse: jax.Array | None = None
    gen_start: jax.Array | None = None
    modality_idx: jax.Array | None = None
    modal_mask: jax.Array | None = None
33 |
34 |
@dataclass
class RolloutBatch:
    """Inputs for a rollout / inference step.

    Decorated with `flax.struct.dataclass`, so instances are frozen JAX
    pytrees. All fields are required.
    """
    # Sensor arrays and their masks, keyed by sensor name (presumably the same
    # keys in both dicts — TODO confirm).
    sensor_data: Dict[str, jax.Array]
    sensor_masks: Dict[str, jax.Array]
    # Prompt tokens with their padding and autoregressive masks
    # (NOTE(review): mask semantics inferred from names — confirm).
    prompt: jax.Array
    prompt_mask: jax.Array
    prompt_ar: jax.Array
42 |
--------------------------------------------------------------------------------
/palivla_digit/palivla/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import jax
3 | from jax.experimental import multihost_utils
4 | import numpy as np
5 | import tensorflow as tf
6 | import flax
7 | from palivla.types import Params
8 |
def freeze_structure(structure):
    """Recursively replaces every list inside a pytree with a tuple.

    Non-list leaves are returned unchanged; lists nested inside lists are
    converted as well.
    """
    def _freeze(node):
        if not isinstance(node, list):
            return node
        return tuple(freeze_structure(item) for item in node)

    return jax.tree_util.tree_map(
        _freeze,
        structure,
        is_leaf=lambda node: isinstance(node, list),
    )
15 |
def key_string(path, separator="/") -> str:
    """Joins a JAX key path (e.g. from `tree_flatten_with_path`) into a string.

    Sequence keys contribute their index, dict and flattened-index keys their
    key, attribute keys their name; anything else is passed through `str`.
    """
    tree_util = jax.tree_util

    def _part(entry) -> str:
        if isinstance(entry, tree_util.SequenceKey):
            return str(entry.idx)
        if isinstance(entry, (tree_util.DictKey, tree_util.FlattenedIndexKey)):
            return str(entry.key)
        if isinstance(entry, tree_util.GetAttrKey):
            return str(entry.name)
        return str(entry)

    return separator.join(map(_part, path))
29 |
30 |
31 | def host_broadcast_str(x: str | None) -> str:
32 | """
33 | Broadcast_one_to_all, but with a string.
34 |
35 | Works by padding the string to the length of the longest string and then
36 | broadcasting the result, then stripping the padding.
37 |
38 | Note: this will remove the padding from the end of the string.
39 | """
40 | if x is None:
41 | x = ""
42 |
43 | max_len = multihost_utils.broadcast_one_to_all(len(x))
44 | padded = x.ljust(max_len)
45 |
46 | encoded = np.array([ord(c) for c in padded], dtype=np.uint8)[:max_len]
47 | encoded = multihost_utils.broadcast_one_to_all(encoded)
48 | decoded = "".join([chr(u) for u in encoded])
49 |
50 | return decoded.rstrip()
51 |
52 |
def load_tvl_weights(pretrained_path: str) -> dict[tuple, np.ndarray]:
    """Loads an .npz checkpoint whose keys are '|'-separated parameter paths.

    Args:
        pretrained_path: Path (any scheme tf.io.gfile supports) to the archive.

    Returns:
        Dict mapping each key split on '|' into a tuple to its array. An empty
        archive yields an empty dict.
    """
    with tf.io.gfile.GFile(pretrained_path, 'rb') as f:
        # NpzFile reads members lazily from `f`, so materialize every array
        # while the file is still open, and close the archive afterwards.
        with np.load(f, allow_pickle=False) as ckpt_dict:
            return {tuple(k.split('|')): ckpt_dict[k] for k in ckpt_dict.files}
58 |
59 |
def merge_params(init_params: Params, pretrained_params: Params) -> Params:
    """Overlays pretrained parameters onto freshly initialized ones.

    The result has exactly the structure of `init_params`: each leaf is taken
    from `pretrained_params` when that tree has a non-None value at the same
    path, otherwise from `init_params`. Paths present only in
    `pretrained_params` are dropped.
    """
    flat_init = flax.traverse_util.flatten_dict(init_params)
    flat_pretrained = flax.traverse_util.flatten_dict(pretrained_params)

    merged = {}
    for path, init_value in flat_init.items():
        pretrained_value = flat_pretrained.get(path, None)
        merged[path] = init_value if pretrained_value is None else pretrained_value
    return flax.traverse_util.unflatten_dict(merged)
--------------------------------------------------------------------------------
/palivla_digit/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "big-vision"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.11"
7 | dependencies = [
8 | "big-vision",
9 | "tqdm>=4.60.0",
10 | "absl-py>=0.12.0",
11 | "imageio>=2.31.1",
12 | "moviepy>=1.0.3",
13 | "chex>=0.1.86",
14 | "distrax>=0.1.5",
15 | "dlimp@git+https://github.com/kvablack/dlimp.git",
16 | "einops>=0.8.0",
17 | "flax>=0.9.0",
18 | "ipykernel",
19 | "jax>=0.4.34",
20 | "matplotlib>=3.9.2",
21 | "ml-collections>=0.1.1",
22 | "numpy<2.0.0",
23 | "optax>=0.2.3",
24 | "orbax-checkpoint>=0.7.0",
25 | "overrides>=7.7.0",
26 | "pip",
27 | "scalax>=0.2.4",
28 | "scikit-learn>=1.5.2",
29 | "scipy>=1.14.1",
30 | "tensorflow-probability>=0.24.0",
31 | "plotly>=5.16.1",
32 | "tfds-nightly>=4.9.0",
33 | "tf-nightly>=2.15.0",
34 | "tensorflow-text-nightly>=2.15.0",
35 | "tensorflow_hub>=0.14.0",
36 | "tensorflow_graphics",
37 | "wandb>=0.18.3",
38 | "protobuf>=3.20",
39 | "huggingface-hub>=0.27.0",
40 | "transformers>=4.47.1",
41 | "prettytable>=3.12.0",
42 | "funcsigs",
43 | "opencv-python",
44 | "pyquaternion",
45 | "librosa",
46 | "edgeml @ git+https://github.com/youliangtan/edgeml.git",
47 | "gym>=0.26",
48 | "jax-smi",
49 | "octo",
50 | "eval",
51 | ]
52 |
53 | [project.optional-dependencies]
54 | tpu = [
55 | "jax[tpu]>=0.4.34",
56 | "libtpu-nightly",
57 | ]
58 | gpu = [
59 | "jax[cuda12]==0.4.34"
60 | ]
61 |
62 | [tool.uv]
63 | find-links = ["https://storage.googleapis.com/jax-releases/libtpu_releases.html", "https://pypi.org/simple/tf-nightly/"]
64 |
65 | prerelease = "allow"
66 | conflicts = [
67 | [
68 | { extra = "tpu" },
69 | { extra = "gpu" },
70 | ],
71 | ]
72 | override-dependencies = [
73 | # Always use tf-nightly and tfds-nightly instead of tensorflow and tensorflow_datasets
74 | "tensorflow ; sys_platform == 'never'",
75 | "tensorflow_datasets ; sys_platform == 'never'",
76 | "scipy>=1.14.1",
77 | "jax>=0.4.34",
78 | ]
79 |
80 | [build-system]
81 | requires = ["hatchling"]
82 | build-backend = "hatchling.build"
83 |
84 | [dependency-groups]
85 | dev = [
86 | "ipywidgets>=8.1.5",
87 | "isort>=6.0.0b2",
88 | "ruff>=0.8.4",
89 | ]
90 |
91 | [tool.hatch.metadata]
92 | allow-direct-references = true
93 |
94 | [tool.uv.sources]
95 | big-vision = { workspace = true }
96 | octo = { path = "../octo_digit/octo", editable = true }
97 | eval = { path = "../octo_digit/eval", editable = true }
--------------------------------------------------------------------------------
/palivla_digit/run.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Sync this directory to the storage used by a TPU VM, launch pod_config.py
# with `tpc`, and open ssh sessions to every worker via ssh_pod.sh.
#
# Usage: run.bash <tpu_vm_name>

# Check if a TPU VM name is provided
if [ $# -eq 0 ]; then
    echo "Usage: $0 <tpu_vm_name>" >&2
    exit 1
fi

TPU_VM_NAME=$1
PROJECT="rail-tpus"

# rsync copies "." and reads .gitignore, and ssh_pod.sh is referenced
# relatively, so always run from the directory containing this script.
cd "$(dirname "$0")" || exit 1

# Cache file for TPU name/zone mapping: one "name:zone:num_workers" line per TPU.
CACHE_FILE="$HOME/.cache/tpus"
mkdir -p "$(dirname "$CACHE_FILE")"

# Start empty so values from the caller's environment are never picked up.
CACHED_ZONE=""
NUM_WORKERS=""

# Check if the TPU info is already cached (first matching line wins).
if [ -f "$CACHE_FILE" ]; then
    CACHED_INFO=$(grep -m1 "^${TPU_VM_NAME}:" "$CACHE_FILE")
    if [ -n "$CACHED_INFO" ]; then
        CACHED_ZONE=$(echo "$CACHED_INFO" | cut -d':' -f2)
        NUM_WORKERS=$(echo "$CACHED_INFO" | cut -d':' -f3)
    fi
fi

if [ -n "$CACHED_ZONE" ]; then
    ZONE=$CACHED_ZONE
else
    ZONE=""
    # Probe each known zone until gcloud finds the TPU.
    for MAYBE_ZONE in us-central1-a us-central2-b europe-west4-b; do
        if TPU_INFO=$(gcloud compute tpus tpu-vm describe "$TPU_VM_NAME" \
                --project="$PROJECT" --zone="$MAYBE_ZONE" --format=json 2>/dev/null); then
            # Cache the successful name/zone mapping and number of workers
            ZONE=$MAYBE_ZONE
            NUM_WORKERS=$(echo "$TPU_INFO" | jq '.networkEndpoints | length')
            echo "$TPU_VM_NAME:$ZONE:$NUM_WORKERS" >> "$CACHE_FILE"
            break
        fi
    done
fi

if [ -z "$ZONE" ]; then
    echo "Could not find TPU VM '$TPU_VM_NAME' in project $PROJECT" >&2
    exit 1
fi

# Set the destination directory based on the zone
if [[ $ZONE == "europe-west4-"* ]]; then
    DEST_DIR="$TPU_VM_NAME:/nfs/nfs3/users/kstachowicz/big_vision_multimodal"
elif [[ $ZONE == "us-central2-"* ]]; then
    DEST_DIR="data-machine:/nfs/nfs2/users/kstachowicz/big_vision_multimodal"
else
    echo "Unsupported zone: $ZONE" >&2
    exit 1
fi

echo "TPU_VM_NAME: $TPU_VM_NAME"
echo "ZONE: $ZONE"
echo "DEST_DIR: $DEST_DIR"
echo "Number of workers: $NUM_WORKERS"

# Copy the source directory to the TPU VM
rsync -avzL --exclude .git --exclude-from=.gitignore . "$DEST_DIR"

# Launch the pod configuration
POD_NAME="$TPU_VM_NAME" tpc launch pod_config.py

# Connect to the pod
bash ssh_pod.sh "$TPU_VM_NAME"
64 |
65 |
--------------------------------------------------------------------------------
/palivla_digit/setup.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Create the palivla_digit environment with uv (GPU extra) and install the
# sibling octo_digit and widowx_envs packages in editable mode.
set -euo pipefail

# The relative ../ paths below assume we are in this script's directory.
cd "$(dirname "$0")"

uv sync --extra gpu
uv pip install -e ../octo_digit --no-deps
uv pip install -e ../bridge_with_digit/widowx_envs
--------------------------------------------------------------------------------
/palivla_digit/setup_pod.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Per-worker setup: mount the shared NFS volume at /nfs/nfs3, set user
# kstachowicz's uid/gid to 3210, link its shell config from NFS, and run
# `conda init bash` from the NFS miniforge install.

sudo apt -y update && sudo apt -y install nfs-common || {
    echo "Failed to install nfs-common" >&2
    exit 1
}
sudo mkdir -p -m 777 /nfs/nfs3
# Skip the mount when re-running on a worker where it is already mounted; stop
# on failure, since everything below links into the NFS volume.
if ! mountpoint -q /nfs/nfs3; then
    sudo mount -o rw,intr 10.105.46.66:/nfs3 /nfs/nfs3 || {
        echo "Failed to mount NFS at /nfs/nfs3" >&2
        exit 1
    }
fi

sudo usermod -u 3210 kstachowicz
sudo groupmod -g 3210 kstachowicz
sudo chown -R kstachowicz:kstachowicz /home/kstachowicz

# Quoted delimiter: the commands below are passed verbatim (no expansion) to a
# login shell of kstachowicz, which starts in that user's home directory.
sudo -i -u kstachowicz bash << 'EOF'

git config --global --add safe.directory '*'

# -f replaces existing files so re-running the script is safe.
ln -sfn /nfs/nfs3/users/kstachowicz/.bashrc .bashrc
ln -sfn /nfs/nfs3/users/kstachowicz/.netrc .netrc
EOF

/nfs/nfs3/users/kstachowicz/miniforge3/bin/conda init bash
--------------------------------------------------------------------------------
/palivla_digit/ssh_pod.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Open a local tmux session "tpc_<name>" with one window per TPU worker; each
# window ssh'es into its worker and attaches to the remote "tpc_<name>" tmux
# session.
#
# Usage: ssh_pod.sh <tpu_vm_name>

if [ -z "$1" ]; then
    echo "Usage: ssh_pod.sh <tpu_vm_name>" >&2
    exit 1
fi

TPU_VM_NAME=$1
SESSION="tpc_${TPU_VM_NAME}"

# Cache file for TPU name/zone mapping (written by run.bash).
CACHE_FILE="$HOME/.cache/tpus"

# Check if the TPU info is already cached (first matching line wins).
if [ -f "$CACHE_FILE" ]; then
    CACHED_INFO=$(grep -m1 "^${TPU_VM_NAME}:" "$CACHE_FILE")
    if [ -n "$CACHED_INFO" ]; then
        ZONE=$(echo "$CACHED_INFO" | cut -d':' -f2)
        N_WORKERS=$(echo "$CACHED_INFO" | cut -d':' -f3)
    fi
fi

# Use default values if not found in cache
ZONE=${ZONE:-europe-west4-b}
N_WORKERS=${N_WORKERS:-16}

echo "Connecting to $TPU_VM_NAME with $N_WORKERS workers in zone $ZONE..."

tmux kill-session -t "$SESSION" 2>/dev/null || true
tmux new -d -s "$SESSION"
for i in $(seq 0 $((N_WORKERS - 1))); do
    # -k replaces window 0 created by `tmux new`; later indices are new windows.
    tmux new-window -k -t "$SESSION:$i"
    tmux send-keys -t "$SESSION:$i" \
        "gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_VM_NAME --worker=$i -- -t tmux a -t $SESSION" Enter
done
# Attach from a plain shell, or switch clients when already inside tmux.
tmux a -t "$SESSION" || tmux switch -t "$SESSION"
39 |
--------------------------------------------------------------------------------