├── .gitignore ├── setup.py └── tux ├── __init__.py ├── checkpoint.py ├── config.py ├── distributed.py ├── jax_utils.py ├── loss.py ├── misc.py ├── optimizers.py ├── stats.py ├── utils.py └── wandb.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Development utils 141 | logs/ 142 | backup/ 143 | temp/ 144 | local/ 145 | .vscode/ 146 | *.json 147 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='tux', 5 | version='0.0.3', 6 | license='MIT', 7 | description='Tools and Utils for JAX/Flax', 8 | url='https://github.com/forhaoliu/tux', 9 | packages=find_packages(include=['tux']), 10 | python_requires=">=3.7", 11 | install_requires=[ 12 | 'absl-py', 13 | 'ml-collections', 14 | 'wandb', 15 | 'gcsfs', 16 | 'cloudpickle', 17 | 'numpy', 18 | 'transformers', 19 | 'jax', 20 | 'flax', 21 | 'optax', 22 | ], 23 | classifiers=[ 24 | 'Development Status :: 3 - Alpha', 25 | 'Intended Audience :: Developers', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Programming Language :: Python :: 3 :: Only', 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /tux/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import StreamingCheckpointer 2 | from .config import (config_dict, define_flags_with_default, 3 | flatten_config_dict, function_args_to_config, 4 | get_user_flags, print_flags, update_config_dict, 5 | user_flags_to_config_dict) 6 | from .distributed import (FlaxTemperatureLogitsWarper, JaxDistributedConfig, 7 | get_jax_mesh, get_names_from_parition_spec, 8 | make_shard_and_gather_fns, match_partition_rules, 9 | names_in_current_mesh, with_sharding_constraint) 10 | from .jax_utils import (JaxRNG, collect_metrics, flatten_tree, 11 | get_pytree_shape_info, init_rng, named_tree_map, 12 | next_rng, set_random_seed, tree_apply, 13 | wrap_function_with_rng) 14 | from .loss import cross_entropy_loss, cross_entropy_loss_and_accuracy, mse_loss 15 | from .misc import (float_tensor_to_dtype, float_to_dtype, 16 | get_float_dtype_by_name, get_gradient_checkpoint_policy) 17 | from .optimizers import (AdamWOptimizerFactory, get_weight_decay_mask, get_mask, optax_add_scheduled_weight_decay, 18 | OptimizerFactory, OptaxScheduledWeightDecayState, PalmOptimizerFactory) 19 | from .stats import average_metrics, global_norm 20 | from .utils import (Timer, array_to_text, load_pickle, open_file, save_pickle, 21 | text_to_array, check_exists) 22 | from .wandb import WandBLogger 23 | -------------------------------------------------------------------------------- /tux/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | 4 | import flax 5 | import jax 6 | import jax.numpy as jnp 7 | import msgpack 8 | import numpy as np 9 | from flax.serialization 
import (from_bytes, from_state_dict, to_bytes, 10 | to_state_dict) 11 | from flax.traverse_util import empty_node, flatten_dict, unflatten_dict 12 | from ml_collections import ConfigDict 13 | 14 | from .jax_utils import tree_apply 15 | from .misc import float_tensor_to_dtype 16 | from .utils import open_file, save_pickle 17 | 18 | 19 | class StreamingCheckpointer(object): 20 | """ Custom msgpack checkpointer that saves large train states by serializing 21 | and saving tensors one by one in a streaming fashion. Avoids running 22 | out of memory or local TPU disk with default flax checkpointer. 23 | """ 24 | 25 | @staticmethod 26 | def get_default_config(updates=None): 27 | config = ConfigDict() 28 | config.float_dtype = 'bf16' 29 | config.save_optimizer_state = False 30 | 31 | if updates is not None: 32 | config.update(ConfigDict(updates).copy_and_resolve_references()) 33 | return config 34 | 35 | def __init__(self, config, checkpoint_dir, enable=True): 36 | self.config = self.get_default_config(config) 37 | self.checkpoint_dir = checkpoint_dir 38 | self.enable = enable 39 | 40 | def save_checkpoint(self, train_state, filename, gather_fns=None): 41 | if self.enable: 42 | path = os.path.join(self.checkpoint_dir, filename) 43 | else: 44 | path = '/dev/null' 45 | self.save_train_state_to_file( 46 | train_state, path, gather_fns, self.config.float_dtype 47 | ) 48 | 49 | @staticmethod 50 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None): 51 | train_state = to_state_dict(train_state) 52 | packer = msgpack.Packer() 53 | flattend_train_state = flatten_dict(train_state) 54 | 55 | if gather_fns is None: 56 | with open_file(path, "wb") as fout: 57 | for key, value in flattend_train_state.items(): 58 | value = float_tensor_to_dtype(value, float_dtype) 59 | fout.write(packer.pack((key, to_bytes(value)))) 60 | return 61 | 62 | gather_fns = flatten_dict(to_state_dict(gather_fns)) 63 | lowered_fns = dict() 64 | for key, value in flattend_train_state.items(): 65 | lowered_fns[key] = gather_fns[key].lower(value) 66 | 67 | compiled_fns = dict() 68 | with ThreadPoolExecutor() as executor: 69 | for key in flattend_train_state.keys(): 70 | compiled_fns[key] = executor.submit(lowered_fns[key].compile) 71 | compiled_fns = {k: f.result() for k, f in compiled_fns.items()} 72 | 73 | with open_file(path, "wb") as fout: 74 | for key, value in flattend_train_state.items(): 75 | value = jax.device_get(compiled_fns[key](value)) 76 | value = float_tensor_to_dtype(value, float_dtype) 77 | fout.write(packer.pack((key, to_bytes(value)))) 78 | 79 | def save_pickle(self, obj, filename): 80 | if self.enable: 81 | path = os.path.join(self.checkpoint_dir, filename) 82 | else: 83 | path = '/dev/null' 84 | save_pickle(obj, path) 85 | 86 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False): 87 | step = int(jax.device_get(train_state.step)) 88 | if self.config.save_optimizer_state: 89 | checkpoint_state = train_state 90 | checkpoint_name = 'streaming_train_state' 91 | checkpoint_gather_fns = gather_fns 92 | else: 93 | checkpoint_state = train_state.params['params'] 94 | checkpoint_name = 'streaming_params' 95 | checkpoint_gather_fns = gather_fns.params['params'] 96 | 97 | if milestone: 98 | # Save a milestone checkpoint that will not be overwritten 99 | self.save_pickle(metadata, f'metadata_{step}.pkl') 100 | self.save_pickle(dataset, f'dataset_{step}.pkl') 101 | self.save_checkpoint( 102 | checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns 103 
| ) 104 | # Additionally save a checkpoint that can be overwritten for automatic resuming 105 | self.save_pickle(metadata, 'metadata.pkl') 106 | self.save_pickle(dataset, 'dataset.pkl') 107 | self.save_checkpoint( 108 | checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns 109 | ) 110 | else: 111 | # Save a normal checkpoint that can be overwritten 112 | self.save_pickle(metadata, 'metadata.pkl') 113 | self.save_pickle(dataset, 'dataset.pkl') 114 | self.save_checkpoint( 115 | checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns 116 | ) 117 | 118 | @staticmethod 119 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None, max_buffer_size=0): 120 | if shard_fns is not None: 121 | shard_fns = flatten_dict( 122 | to_state_dict(shard_fns) 123 | ) 124 | if remove_dict_prefix is not None: 125 | remove_dict_prefix = tuple(remove_dict_prefix) 126 | flattend_train_state = {} 127 | with open_file(path) as fin: 128 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS 129 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=max_buffer_size) 130 | for key, value in unpacker: 131 | key = tuple(key) 132 | if remove_dict_prefix is not None: 133 | if key[:len(remove_dict_prefix)] == remove_dict_prefix: 134 | key = key[len(remove_dict_prefix):] 135 | else: 136 | continue 137 | 138 | tensor = from_bytes(None, value) 139 | if shard_fns is not None: 140 | tensor = shard_fns[key](tensor) 141 | flattend_train_state[key] = tensor 142 | 143 | if target is not None: 144 | flattened_target = flatten_dict( 145 | to_state_dict(target), keep_empty_nodes=True 146 | ) 147 | for key, value in flattened_target.items(): 148 | if key not in flattend_train_state and value == empty_node: 149 | flattend_train_state[key] = value 150 | 151 | train_state = unflatten_dict(flattend_train_state) 152 | if target is None: 153 | return train_state 154 | 155 | return from_state_dict(target, train_state) 156 | 157 | @staticmethod 158 | def load_flax_checkpoint(path, target=None, shard_fns=None): 159 | """ Load a standard flax checkpoint that's not saved with the 160 | msgpack streaming format. 161 | """ 162 | with open_file(path, "rb") as fin: 163 | encoded_bytes = fin.read() 164 | 165 | state_dict = flax.serialization.msgpack_restore(encoded_bytes) 166 | if shard_fns is not None: 167 | shard_fns = to_state_dict(shard_fns) 168 | state_dict = tree_apply(shard_fns, state_dict) 169 | 170 | if target is None: 171 | return state_dict 172 | return from_state_dict(target, state_dict) 173 | 174 | @classmethod 175 | def load_trainstate_checkpoint(cls, load_from, trainstate_target=None, 176 | trainstate_shard_fns=None, 177 | disallow_trainstate=False, 178 | max_buffer_size=0): 179 | if trainstate_target is not None: 180 | params_target = trainstate_target.params['params'] 181 | else: 182 | params_target = None 183 | 184 | if trainstate_shard_fns is not None: 185 | params_shard_fns = trainstate_shard_fns.params['params'] 186 | else: 187 | params_shard_fns = None 188 | 189 | load_type, load_path = load_from.split('::', 1) 190 | if disallow_trainstate: 191 | assert load_type != 'trainstate', 'Loading full trainstate is not allowed!' 
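        # `load_from` has the form '<load_type>::<path>'. Illustrative values
        # (the GCS paths below are hypothetical placeholders, not from this repo):
        #   'trainstate::gs://my-bucket/run/streaming_train_state'
        #   'trainstate_params::gs://my-bucket/run/streaming_train_state'
        #   'params::gs://my-bucket/run/streaming_params'
        #   'flax_params::gs://my-bucket/run/flax_model.msgpack'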
192 | train_state = None 193 | restored_params = None 194 | if load_type == 'trainstate': 195 | # Load the entire train state in the streaming format 196 | train_state = cls.load_checkpoint( 197 | path=load_path, 198 | target=trainstate_target, 199 | shard_fns=trainstate_shard_fns, 200 | max_buffer_size=max_buffer_size, 201 | ) 202 | elif load_type == 'trainstate_params': 203 | # Load the params part of the train state in the streaming format 204 | restored_params = cls.load_checkpoint( 205 | path=load_path, 206 | target=params_target, 207 | shard_fns=params_shard_fns, 208 | remove_dict_prefix=('params', 'params'), 209 | max_buffer_size=max_buffer_size, 210 | ) 211 | restored_params = flax.core.frozen_dict.freeze( 212 | {'params': restored_params} 213 | ) 214 | elif load_type == 'params': 215 | # Load the params in the streaming format 216 | restored_params = cls.load_checkpoint( 217 | path=load_path, 218 | target=params_target, 219 | shard_fns=params_shard_fns, 220 | max_buffer_size=max_buffer_size 221 | ) 222 | restored_params = flax.core.frozen_dict.freeze( 223 | {'params': restored_params} 224 | ) 225 | elif load_type == 'flax_params': 226 | # Load the params in the standard flax format (non-streaming) 227 | # This requires the entire params to fit in memory 228 | restored_params = cls.load_flax_checkpoint( 229 | path=load_path, 230 | target=params_target, 231 | shard_fns=params_shard_fns 232 | ) 233 | restored_params = flax.core.frozen_dict.freeze( 234 | {'params': restored_params} 235 | ) 236 | else: 237 | raise ValueError(f'Invalid load_from type: {load_type}') 238 | 239 | return train_state, restored_params 240 | -------------------------------------------------------------------------------- /tux/config.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | import pprint 4 | import random 5 | import tempfile 6 | import time 7 | from copy import deepcopy 8 | 9 | import absl.flags 10 | from absl import logging 11 | from ml_collections import ConfigDict 12 | from ml_collections.config_dict.config_dict import \ 13 | placeholder as config_placeholder 14 | from ml_collections.config_flags import config_flags 15 | 16 | 17 | def config_dict(*args, **kwargs): 18 | return ConfigDict(dict(*args, **kwargs)) 19 | 20 | 21 | def define_flags_with_default(**kwargs): 22 | for key, val in kwargs.items(): 23 | if isinstance(val, tuple): 24 | val, help_str = val 25 | else: 26 | help_str = "" 27 | 28 | if isinstance(val, ConfigDict): 29 | config_flags.DEFINE_config_dict(key, val) 30 | elif isinstance(val, bool): 31 | # Note that True and False are instances of int. 
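        # (isinstance(True, int) evaluates to True, so the bool branch must come before the int branch below.)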
32 | absl.flags.DEFINE_bool(key, val, help_str) 33 | elif isinstance(val, int): 34 | absl.flags.DEFINE_integer(key, val, help_str) 35 | elif isinstance(val, float): 36 | absl.flags.DEFINE_float(key, val, help_str) 37 | elif isinstance(val, str): 38 | absl.flags.DEFINE_string(key, val, help_str) 39 | else: 40 | raise ValueError("Incorrect value type") 41 | return absl.flags.FLAGS, kwargs 42 | 43 | 44 | def print_flags(flags, flags_def): 45 | flag_srings = [ 46 | "{}: {}".format(key, val) 47 | for key, val in get_user_flags(flags, flags_def).items() 48 | ] 49 | logging.info( 50 | "Hyperparameter configs: \n{}".format( 51 | pprint.pformat(flag_srings) 52 | ) 53 | ) 54 | 55 | 56 | def get_user_flags(flags, flags_def): 57 | output = {} 58 | for key in flags_def: 59 | val = getattr(flags, key) 60 | if isinstance(val, ConfigDict): 61 | output.update(flatten_config_dict(val, prefix=key)) 62 | else: 63 | output[key] = val 64 | 65 | return output 66 | 67 | 68 | def user_flags_to_config_dict(flags, flags_def): 69 | output = ConfigDict() 70 | for key in flags_def: 71 | output[key] = getattr(flags, key) 72 | 73 | return output 74 | 75 | 76 | def update_config_dict(config, updates=None): 77 | updated_config = deepcopy(config) 78 | if updates is not None: 79 | updated_config.update(ConfigDict(updates).copy_and_resolve_references()) 80 | return updated_config 81 | 82 | 83 | def flatten_config_dict(config, prefix=None): 84 | output = {} 85 | for key, val in config.items(): 86 | if isinstance(val, ConfigDict) or isinstance(val, dict): 87 | output.update(flatten_config_dict(val, prefix=key)) 88 | else: 89 | if prefix is not None: 90 | output["{}.{}".format(prefix, key)] = val 91 | else: 92 | output[key] = val 93 | return output 94 | 95 | 96 | def function_args_to_config(fn, none_arg_types=None, exclude_args=None, override_args=None): 97 | config = ConfigDict() 98 | arg_spec = inspect.getfullargspec(fn) 99 | n_args = len(arg_spec.defaults) 100 | arg_names = arg_spec.args[-n_args:] 101 | default_values = arg_spec.defaults 102 | for name, value in zip(arg_names, default_values): 103 | if exclude_args is not None and name in exclude_args: 104 | continue 105 | elif override_args is not None and name in override_args: 106 | config[name] = override_args[name] 107 | elif none_arg_types is not None and value is None and name in none_arg_types: 108 | config[name] = config_placeholder(none_arg_types[name]) 109 | else: 110 | config[name] = value 111 | 112 | return config 113 | -------------------------------------------------------------------------------- /tux/distributed.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import math 3 | import os 4 | import random 5 | import re 6 | from functools import partial 7 | from typing import Any, Mapping, NamedTuple, Text, Tuple, Union 8 | 9 | import flax 10 | import jax 11 | import jax.numpy as jnp 12 | import numpy as np 13 | from jax.experimental import mesh_utils 14 | from jax.experimental.pjit import pjit 15 | try: 16 | from jax.lax import with_sharding_constraint as _with_sharding_constraint 17 | except: 18 | from jax.experimental.pjit import \ 19 | with_sharding_constraint as _with_sharding_constraint 20 | from jax.interpreters import pxla 21 | from jax.sharding import Mesh 22 | from jax.sharding import PartitionSpec as PS 23 | from ml_collections import ConfigDict 24 | from ml_collections.config_dict.config_dict import placeholder 25 | from transformers import FlaxLogitsWarper 26 | 27 | from .jax_utils 
import named_tree_map 28 | 29 | 30 | class JaxDistributedConfig(object): 31 | """ Utility class for initializing JAX distributed. """ 32 | 33 | @staticmethod 34 | def get_default_config(updates=None): 35 | config = ConfigDict() 36 | config.initialize_jax_distributed = False 37 | config.coordinator_address = placeholder(str) 38 | config.num_processes = placeholder(int) 39 | config.process_id = placeholder(int) 40 | config.local_device_ids = placeholder(str) 41 | 42 | if updates is not None: 43 | config.update(ConfigDict(updates).copy_and_resolve_references()) 44 | return config 45 | 46 | @classmethod 47 | def initialize(cls, config): 48 | config = cls.get_default_config(config) 49 | if config.initialize_jax_distributed: 50 | if config.local_device_ids is not None: 51 | local_device_ids = [int(x) for x in config.local_device_ids.split(',')] 52 | else: 53 | local_device_ids = None 54 | 55 | jax.distributed.initialize( 56 | coordinator_address=config.coordinator_address, 57 | num_processes=config.num_processes, 58 | process_id=config.process_id, 59 | local_device_ids=local_device_ids, 60 | ) 61 | 62 | 63 | class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): 64 | """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling.""" 65 | def __init__(self, temperature): 66 | self.temperature = temperature 67 | 68 | def __call__(self, input_ids, scores, cur_len): 69 | return scores / jnp.clip(self.temperature, a_min=1e-8) 70 | 71 | 72 | def make_shard_and_gather_fns(partition_specs, dtype_specs=None): 73 | """ Create pytree of sharding and gathering functions from pytree of 74 | partition specs. 75 | """ 76 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 77 | 78 | def make_to_dtype_fn(dtype_spec): 79 | def to_dtype(tensor): 80 | if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes: 81 | # Convert all float tensors to the same dtype 82 | return tensor.astype(dtype_specs) 83 | elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'): 84 | return tensor.astype(dtype_spec.dtype) 85 | return tensor 86 | return to_dtype 87 | 88 | def make_shard_fn(partition_spec, dtype_spec=None): 89 | jax_shard_function = pjit( 90 | make_to_dtype_fn(dtype_spec), 91 | in_shardings=None, 92 | out_shardings=partition_spec 93 | ) 94 | def shard_fn(tensor): 95 | return jax_shard_function(tensor).block_until_ready() 96 | return shard_fn 97 | 98 | def make_gather_fn(partition_spec, dtype_spec=None): 99 | jax_gather_fn = pjit( 100 | make_to_dtype_fn(dtype_spec), 101 | in_shardings=partition_spec, 102 | out_shardings=None 103 | ) 104 | return jax_gather_fn 105 | 106 | if dtype_specs is None or dtype_specs in float_dtypes: 107 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs) 108 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs) 109 | else: 110 | shard_fns = jax.tree_util.tree_map( 111 | make_shard_fn, partition_specs, dtype_specs 112 | ) 113 | gather_fns = jax.tree_util.tree_map( 114 | make_gather_fn, partition_specs, dtype_specs 115 | ) 116 | return shard_fns, gather_fns 117 | 118 | 119 | def get_jax_mesh(axis_dims, names): 120 | if axis_dims.startswith('!'): 121 | # Allow splitting a physical mesh axis if needed 122 | mesh_axis_splitting = True 123 | axis_dims = axis_dims[1:] 124 | else: 125 | mesh_axis_splitting = False 126 | 127 | if ':' in axis_dims: 128 | dims = [] 129 | dim_names = [] 130 | for axis in axis_dims.split(','): 131 | name, dim = axis.split(':') 132 | assert name in names 133 | 
dims.append(int(dim)) 134 | dim_names.append(name) 135 | assert(set(dim_names) == set(names)) 136 | else: 137 | dims = [int(x) for x in axis_dims.split(',')] 138 | dim_names = names 139 | assert len(dims) == len(names) 140 | mesh_shape = np.arange(jax.device_count()).reshape(dims).shape 141 | if mesh_axis_splitting: 142 | physical_mesh = np.array(jax.devices()).reshape(mesh_shape) 143 | else: 144 | physical_mesh = mesh_utils.create_device_mesh(mesh_shape) 145 | return Mesh(physical_mesh, dim_names) 146 | 147 | 148 | def names_in_current_mesh(*names): 149 | """ Check if current mesh axes contain these names. """ 150 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names 151 | return set(names) <= set(mesh_axis_names) 152 | 153 | 154 | def get_names_from_parition_spec(partition_specs): 155 | """ Return axis names from partition specs. """ 156 | names = set() 157 | if isinstance(partition_specs, dict): 158 | partition_specs = partition_specs.values() 159 | for item in partition_specs: 160 | if item is None: 161 | continue 162 | elif isinstance(item, str): 163 | names.add(item) 164 | else: 165 | names.update(get_names_from_parition_spec(item)) 166 | 167 | return list(names) 168 | 169 | 170 | def with_sharding_constraint(x, partition_specs): 171 | """ A smarter version of with_sharding_constraint that only applies the 172 | constraint if the current mesh contains the axes in the partition specs. 173 | """ 174 | axis_names = get_names_from_parition_spec(partition_specs) 175 | if names_in_current_mesh(*axis_names): 176 | x = _with_sharding_constraint(x, partition_specs) 177 | return x 178 | 179 | 180 | def match_partition_rules(rules, params): 181 | """ Returns a pytree of PartitionSpec according to rules. Supports handling 182 | Flax TrainState and Optax optimizer state. 183 | """ 184 | def get_partition_spec(name, leaf): 185 | if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1: 186 | """ Don't partition scalar values. """ 187 | return PS() 188 | for rule, ps in rules: 189 | if re.search(rule, name) is not None: 190 | return ps 191 | raise ValueError(f'Partition rule not found for param: {name}') 192 | return named_tree_map(get_partition_spec, params, sep='/') 193 | -------------------------------------------------------------------------------- /tux/jax_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import random 3 | 4 | import flax 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from flax.core import FrozenDict 9 | from flax.training.train_state import TrainState 10 | 11 | 12 | class JaxRNG(object): 13 | """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside 14 | pure function. 15 | """ 16 | 17 | @classmethod 18 | def from_seed(cls, seed): 19 | return cls(jax.random.PRNGKey(seed)) 20 | 21 | def __init__(self, rng): 22 | self.rng = rng 23 | 24 | def __call__(self, keys=None): 25 | if keys is None: 26 | self.rng, split_rng = jax.random.split(self.rng) 27 | return split_rng 28 | elif isinstance(keys, int): 29 | split_rngs = jax.random.split(self.rng, num=keys + 1) 30 | self.rng = split_rngs[0] 31 | return tuple(split_rngs[1:]) 32 | else: 33 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 34 | self.rng = split_rngs[0] 35 | return {key: val for key, val in zip(keys, split_rngs[1:])} 36 | 37 | 38 | def wrap_function_with_rng(rng): 39 | """ To be used as decorator, automatically bookkeep a RNG for the wrapped function. 
""" 40 | def wrap_function(function): 41 | def wrapped(*args, **kwargs): 42 | nonlocal rng 43 | rng, split_rng = jax.random.split(rng) 44 | return function(split_rng, *args, **kwargs) 45 | return wrapped 46 | return wrap_function 47 | 48 | 49 | def init_rng(seed): 50 | global jax_utils_rng 51 | jax_utils_rng = JaxRNG.from_seed(seed) 52 | 53 | 54 | def next_rng(*args, **kwargs): 55 | global jax_utils_rng 56 | return jax_utils_rng(*args, **kwargs) 57 | 58 | 59 | def flatten_tree(xs, is_leaf=None, sep=None): 60 | """ A stronger version of flax.traverse_util.flatten_dict, supports 61 | dict, tuple, list and TrainState. Tuple and list indices will be 62 | converted to strings. 63 | """ 64 | tree_node_classes = (FrozenDict, dict, tuple, list, TrainState) 65 | if not isinstance(xs, tree_node_classes): 66 | ValueError('fUnsupported node type: {type(xs)}') 67 | 68 | def _is_leaf(prefix, fx): 69 | if is_leaf is not None: 70 | return is_leaf(prefix, xs) 71 | return False 72 | 73 | def _key(path): 74 | if sep is None: 75 | return path 76 | return sep.join(path) 77 | 78 | def _convert_to_dict(xs): 79 | if isinstance(xs, (FrozenDict, dict)): 80 | return xs 81 | elif isinstance(xs, (tuple, list)): 82 | return {f'{i}': v for i, v in enumerate(xs)} 83 | elif isinstance(xs, TrainState): 84 | output = {} 85 | for field in dataclasses.fields(xs): 86 | if 'pytree_node' not in field.metadata or field.metadata['pytree_node']: 87 | output[field.name] = getattr(xs, field.name) 88 | return output 89 | else: 90 | raise ValueError('fUnsupported node type: {type(xs)}') 91 | 92 | def _flatten(xs, prefix): 93 | if not isinstance(xs, tree_node_classes) or _is_leaf(prefix, xs): 94 | return {_key(prefix): xs} 95 | 96 | result = {} 97 | is_empty = True 98 | for (key, value) in _convert_to_dict(xs).items(): 99 | is_empty = False 100 | path = prefix + (key, ) 101 | result.update(_flatten(value, path)) 102 | return result 103 | 104 | return _flatten(xs, ()) 105 | 106 | 107 | def named_tree_map(f, tree, is_leaf=None, sep=None): 108 | """ An extended version of jax.tree_util.tree_map, where the mapped function 109 | f takes both the name (path) and the tree leaf as input. 110 | """ 111 | flattened_tree = flatten_tree(tree, is_leaf=is_leaf, sep=sep) 112 | id_to_name = {id(val): key for key, val in flattened_tree.items()} 113 | def map_fn(leaf): 114 | name = id_to_name[id(leaf)] 115 | return f(name, leaf) 116 | return jax.tree_util.tree_map(map_fn, tree) 117 | 118 | 119 | def get_pytree_shape_info(tree): 120 | flattend_tree = flatten_tree(tree, sep='/') 121 | shapes = [] 122 | for key in sorted(list(flattend_tree.keys())): 123 | val = flattend_tree[key] 124 | shapes.append(f'{key}: {val.dtype}, {val.shape}') 125 | return '\n'.join(shapes) 126 | 127 | 128 | def collect_metrics(metrics, names, prefix=None): 129 | collected = {} 130 | for name in names: 131 | if name in metrics: 132 | collected[name] = jnp.mean(metrics[name]) 133 | if prefix is not None: 134 | collected = { 135 | '{}/{}'.format(prefix, key): value for key, value in collected.items() 136 | } 137 | return collected 138 | 139 | 140 | def set_random_seed(seed): 141 | np.random.seed(seed) 142 | random.seed(seed) 143 | init_rng(seed) 144 | 145 | 146 | def tree_apply(fns, tree): 147 | """ Apply a pytree of functions to the pytree. 
""" 148 | return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree) 149 | -------------------------------------------------------------------------------- /tux/loss.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def mse_loss(val, target, valid=None): 6 | if valid is None: 7 | valid = jnp.ones((*target.shape[:2], 1)) 8 | valid = valid.astype(jnp.float32) 9 | loss = jnp.mean( 10 | jnp.where( 11 | valid > 0.0, 12 | jnp.square(val - target), 13 | 0.0 14 | ) 15 | ) 16 | return loss 17 | 18 | 19 | def cross_entropy_loss(logits, labels, smoothing_factor=0.): 20 | num_classes = logits.shape[-1] 21 | if labels.dtype == jnp.int32 or labels.dtype == jnp.int64: 22 | labels = jax.nn.one_hot(labels, num_classes) 23 | if smoothing_factor > 0.: 24 | labels = labels * (1. - smoothing_factor) + smoothing_factor / num_classes 25 | logp = jax.nn.log_softmax(logits, axis=-1) 26 | return -jnp.mean(jnp.sum(logp * labels, axis=-1)) 27 | 28 | 29 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None, bias=None): 30 | if valid is None: 31 | valid = jnp.ones(tokens.shape[:2]) 32 | valid = valid.astype(jnp.float32) 33 | valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) 34 | logits = logits.astype(jnp.float32) # for numerical stability 35 | token_log_prob = jnp.squeeze( 36 | jnp.take_along_axis( 37 | jax.nn.log_softmax(logits, axis=-1), 38 | jnp.expand_dims(tokens, -1), 39 | axis=-1, 40 | ), 41 | -1, 42 | ) 43 | token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) 44 | if bias is None: 45 | bias = 0.0 46 | loss = -jnp.mean((jnp.sum(token_log_prob, axis=-1) + bias) / valid_text_length) 47 | correct = jnp.where( 48 | valid > 0.0, 49 | jnp.argmax(logits, axis=-1) == tokens, 50 | jnp.array(False) 51 | ) 52 | accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) 53 | return loss, accuracy 54 | -------------------------------------------------------------------------------- /tux/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def get_gradient_checkpoint_policy(name): 8 | return { 9 | 'everything_saveable': jax.checkpoint_policies.everything_saveable, 10 | 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, 11 | 'dots_saveable': jax.checkpoint_policies.dots_saveable, 12 | 'dots_with_no_batch_dims_saveable': jax.checkpoint_policies.dots_with_no_batch_dims_saveable, 13 | }[name] 14 | 15 | 16 | def get_float_dtype_by_name(dtype): 17 | return { 18 | 'bf16': jnp.bfloat16, 19 | 'bfloat16': jnp.bfloat16, 20 | 'fp16': jnp.float16, 21 | 'float16': jnp.float16, 22 | 'fp32': jnp.float32, 23 | 'float32': jnp.float32, 24 | 'fp64': jnp.float64, 25 | 'float64': jnp.float64, 26 | }[dtype] 27 | 28 | 29 | def float_tensor_to_dtype(tensor, dtype): 30 | if dtype is None or dtype == '': 31 | return tensor 32 | if isinstance(dtype, str): 33 | dtype = get_float_dtype_by_name(dtype) 34 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 35 | if getattr(tensor, 'dtype', None) in float_dtypes: 36 | tensor = tensor.astype(dtype) 37 | return tensor 38 | 39 | 40 | def float_to_dtype(tree, dtype): 41 | return jax.tree_util.tree_map( 42 | partial(float_tensor_to_dtype, dtype=dtype), tree 43 | ) 44 | -------------------------------------------------------------------------------- /tux/optimizers.py: 
-------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import random 4 | import re 5 | import time 6 | from functools import partial 7 | from typing import Any, Mapping, NamedTuple, Text, Tuple, Union 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import optax 13 | from absl import logging 14 | from ml_collections import ConfigDict 15 | from ml_collections.config_dict import config_dict 16 | 17 | from .jax_utils import tree_apply, named_tree_map 18 | from .misc import float_tensor_to_dtype 19 | 20 | 21 | class OptimizerFactory(object): 22 | """ Configurable optax optimizer factory. """ 23 | 24 | def __init__(self): 25 | raise NotImplementedError 26 | 27 | @staticmethod 28 | def get_default_config(updates=None): 29 | config = ConfigDict() 30 | config.accumulate_gradient_steps = 1 31 | config.type = 'adamw' 32 | config.palm_optimizer = PalmOptimizerFactory.get_default_config() 33 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config() 34 | 35 | if updates is not None: 36 | config.update(ConfigDict(updates).copy_and_resolve_references()) 37 | return config 38 | 39 | @classmethod 40 | def get_optimizer(cls, config, weight_decay_mask=None, frozen_param_mask=None): 41 | prepend_to_chain = [] 42 | if frozen_param_mask is not None: 43 | prepend_to_chain.append(optax.masked(optax.set_to_zero(), frozen_param_mask)) 44 | 45 | config = cls.get_default_config(config) 46 | if config.type == 'palm': 47 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer( 48 | config.palm_optimizer, weight_decay_mask, prepend_to_chain 49 | ) 50 | elif config.type == 'adamw': 51 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer( 52 | config.adamw_optimizer, weight_decay_mask, prepend_to_chain 53 | ) 54 | else: 55 | raise ValueError(f'Unknown optimizer type: {config.type}') 56 | 57 | if config.accumulate_gradient_steps > 1: 58 | optimizer = optax.MultiSteps( 59 | optimizer, config.accumulate_gradient_steps 60 | ) 61 | 62 | return optimizer, optimizer_info 63 | 64 | 65 | class PalmOptimizerFactory(object): 66 | """ PaLM optimizer factory. 
This optimizer implements the optimizer 67 | described in the PaLM paper: https://arxiv.org/abs/2204.02311 68 | """ 69 | 70 | def __init__(self): 71 | raise NotImplementedError 72 | 73 | @staticmethod 74 | def get_default_config(updates=None): 75 | config = ConfigDict() 76 | config.lr = 0.01 77 | config.lr_warmup_steps = 10000 78 | config.b1 = 0.9 79 | config.b2 = 0.99 80 | config.clip_gradient = 1.0 81 | config.weight_decay = 1e-4 82 | config.bf16_momentum = False 83 | 84 | if updates is not None: 85 | config.update(ConfigDict(updates).copy_and_resolve_references()) 86 | return config 87 | 88 | @classmethod 89 | def get_optimizer(cls, config, weight_decay_mask=None, prepend_to_chain=tuple()): 90 | config = cls.get_default_config(config) 91 | 92 | def learning_rate_schedule(step): 93 | multiplier = config.lr / 0.01 94 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps)) 95 | 96 | def weight_decay_schedule(step): 97 | multiplier = config.weight_decay / 1e-4 98 | return -multiplier * jnp.square(learning_rate_schedule(step)) 99 | 100 | optimizer_info = dict( 101 | learning_rate_schedule=learning_rate_schedule, 102 | weight_decay_schedule=weight_decay_schedule, 103 | ) 104 | 105 | optimizer = optax.chain(*prepend_to_chain, 106 | optax.clip_by_global_norm(config.clip_gradient), 107 | optax.adafactor( 108 | learning_rate=learning_rate_schedule, 109 | multiply_by_parameter_scale=True, 110 | momentum=config.b1, 111 | decay_rate=config.b2, 112 | factored=False, 113 | clipping_threshold=None, 114 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 115 | ), 116 | optax_add_scheduled_weight_decay( 117 | weight_decay_schedule, weight_decay_mask 118 | ) 119 | ) 120 | return optimizer, optimizer_info 121 | 122 | 123 | class AdamWOptimizerFactory(object): 124 | """ AdamW optimizer with cosine schedule.
""" 125 | 126 | def __init__(self): 127 | raise NotImplementedError 128 | 129 | @staticmethod 130 | def get_default_config(updates=None): 131 | config = ConfigDict() 132 | config.init_lr = 0.0 133 | config.end_lr = 0.001 134 | config.lr = 0.01 135 | config.lr_warmup_steps = 2000 136 | config.lr_decay_steps = 500000 137 | config.b1 = 0.9 138 | config.b2 = 0.95 139 | config.clip_gradient = 1.0 140 | config.weight_decay = 1e-4 141 | config.bf16_momentum = False 142 | config.multiply_by_parameter_scale = False 143 | 144 | if updates is not None: 145 | config.update(ConfigDict(updates).copy_and_resolve_references()) 146 | return config 147 | 148 | @classmethod 149 | def get_optimizer(cls, config, weight_decay_mask=None, prepend_to_chain=tuple()): 150 | config = cls.get_default_config(config) 151 | 152 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 153 | init_value=config.init_lr, 154 | peak_value=config.lr, 155 | warmup_steps=config.lr_warmup_steps, 156 | decay_steps=config.lr_decay_steps, 157 | end_value=config.end_lr, 158 | ) 159 | 160 | optimizer_info = dict( 161 | learning_rate_schedule=learning_rate_schedule, 162 | ) 163 | 164 | if config.multiply_by_parameter_scale: 165 | optimizer = optax.chain( 166 | *prepend_to_chain, 167 | optax.clip_by_global_norm(config.clip_gradient), 168 | optax.adafactor( 169 | learning_rate=learning_rate_schedule, 170 | multiply_by_parameter_scale=True, 171 | momentum=config.b1, 172 | decay_rate=config.b2, 173 | factored=False, 174 | clipping_threshold=None, 175 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 176 | ), 177 | optax_add_scheduled_weight_decay( 178 | lambda step: -learning_rate_schedule(step) * config.weight_decay, 179 | weight_decay_mask 180 | ) 181 | ) 182 | else: 183 | optimizer = optax.chain( 184 | *prepend_to_chain, 185 | optax.clip_by_global_norm(config.clip_gradient), 186 | optax.adamw( 187 | learning_rate=learning_rate_schedule, 188 | weight_decay=config.weight_decay, 189 | b1=config.b1, 190 | b2=config.b2, 191 | mask=weight_decay_mask, 192 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 193 | ), 194 | ) 195 | 196 | return optimizer, optimizer_info 197 | 198 | 199 | class OptaxScheduledWeightDecayState(NamedTuple): 200 | count: jnp.array 201 | 202 | 203 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None): 204 | """ Apply weight decay with schedule. """ 205 | 206 | def init_fn(params): 207 | del params 208 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32)) 209 | 210 | def update_fn(updates, state, params): 211 | if params is None: 212 | raise ValueError('Params cannot be None for weight decay!') 213 | 214 | weight_decay = schedule_fn(state.count) 215 | updates = jax.tree_util.tree_map( 216 | lambda g, p: g + weight_decay * p, updates, params 217 | ) 218 | return updates, OptaxScheduledWeightDecayState( 219 | count=optax.safe_int32_increment(state.count) 220 | ) 221 | 222 | if mask is not None: 223 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask) 224 | return optax.GradientTransformation(init_fn, update_fn) 225 | 226 | 227 | def get_mask(exclusions, tf_map=None): 228 | """ Return a mask function that computes the pytree masks 229 | according to the given exclusion rules. 
230 | """ 231 | if tf_map is None: 232 | tf_map = {True: True, False: False} 233 | else: 234 | assert len(tf_map) == 2 and True in tf_map and False in tf_map 235 | 236 | def to_keep(name, _): 237 | for rule in exclusions: 238 | if re.search(rule, name) is not None: 239 | return False 240 | return True 241 | 242 | def mask_fn(params): 243 | return named_tree_map(lambda *args: tf_map[to_keep(*args)], params, sep='/') 244 | 245 | return mask_fn 246 | 247 | 248 | # For backwards compatibility 249 | get_weight_decay_mask = get_mask -------------------------------------------------------------------------------- /tux/stats.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def global_norm(tree): 6 | """ Return the global L2 norm of a pytree. """ 7 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) 8 | flattened, _ = jax.flatten_util.ravel_pytree(squared) 9 | return jnp.sqrt(jnp.sum(flattened)) 10 | 11 | 12 | def average_metrics(metrics): 13 | return jax.tree_map( 14 | lambda *args: jnp.mean(jnp.stack(args)), 15 | *metrics 16 | ) 17 | -------------------------------------------------------------------------------- /tux/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import tempfile 5 | import time 6 | import uuid 7 | from copy import copy 8 | from io import BytesIO 9 | from socket import gethostname 10 | 11 | import cloudpickle as pickle 12 | import gcsfs 13 | import numpy as np 14 | 15 | 16 | class Timer(object): 17 | def __init__(self): 18 | self._time = None 19 | 20 | def __enter__(self): 21 | self._start_time = time.time() 22 | return self 23 | 24 | def __exit__(self, exc_type, exc_value, exc_tb): 25 | self._time = time.time() - self._start_time 26 | 27 | def __call__(self): 28 | return self._time 29 | 30 | 31 | def open_file(path, mode='rb', block_size=None, cache_type='readahead'): 32 | if path.startswith("gs://"): 33 | logging.getLogger("fsspec").setLevel(logging.WARNING) 34 | return gcsfs.GCSFileSystem().open(path, mode, block_size=block_size, cache_type=cache_type) 35 | else: 36 | return open(path, mode) 37 | 38 | 39 | def save_pickle(obj, path): 40 | with open_file(path, 'wb') as fout: 41 | pickle.dump(obj, fout) 42 | 43 | 44 | def load_pickle(path): 45 | with open_file(path, 'rb') as fin: 46 | data = pickle.load(fin) 47 | return data 48 | 49 | 50 | def text_to_array(text, encoding='utf-8'): 51 | return np.frombuffer(text.encode(encoding), dtype='uint8') 52 | 53 | 54 | def array_to_text(array, encoding='utf-8'): 55 | with BytesIO(array) as fin: 56 | text = fin.read().decode(encoding) 57 | return text 58 | 59 | 60 | def check_exists(path): 61 | if path.startswith("gs://"): 62 | return gcsfs.GCSFileSystem().exists(path) 63 | else: 64 | return os.path.exists(path) 65 | -------------------------------------------------------------------------------- /tux/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import uuid 4 | from socket import gethostname 5 | 6 | import wandb 7 | from ml_collections import ConfigDict 8 | from ml_collections.config_dict.config_dict import placeholder 9 | 10 | from .config import flatten_config_dict, update_config_dict 11 | from .utils import save_pickle 12 | 13 | 14 | class WandBLogger(object): 15 | @staticmethod 16 | def get_default_config(updates=None): 17 | config = 
ConfigDict() 18 | config.project_id = "" 19 | config.project_entity = placeholder(str) 20 | config.experiment_id = placeholder(str) 21 | config.experiment_group = placeholder(str) 22 | config.append_uuid = True 23 | config.experiment_note = placeholder(str) 24 | 25 | config.output_dir = "/tmp/" 26 | config.wandb_dir = "" 27 | config.profile_dir = "" 28 | 29 | config.online = False 30 | 31 | config.reinit = False 32 | config.resume = "allow" 33 | 34 | return update_config_dict(config, updates) 35 | 36 | def __init__(self, config, variant, enable=True): 37 | self.enable = enable 38 | self.config = self.get_default_config(config) 39 | 40 | if self.config.experiment_id is None or self.config.experiment_id == "": 41 | self.config.experiment_id = uuid.uuid4().hex 42 | else: 43 | if self.config.append_uuid: 44 | self.config.experiment_id = ( 45 | str(self.config.experiment_id) + "_" + uuid.uuid4().hex 46 | ) 47 | else: 48 | self.config.experiment_id = str(self.config.experiment_id) 49 | 50 | if self.enable: 51 | if self.config.output_dir == "": 52 | self.config.output_dir = tempfile.mkdtemp() 53 | else: 54 | self.config.output_dir = os.path.join( 55 | self.config.output_dir, self.config.experiment_id 56 | ) 57 | if not self.config.output_dir.startswith("gs://"): 58 | os.makedirs(self.config.output_dir, exist_ok=True) 59 | 60 | if self.config.wandb_dir == "": 61 | if not self.config.output_dir.startswith("gs://"): 62 | # Use the same directory as output_dir if it is not a GCS path. 63 | self.config.wandb_dir = self.config.output_dir 64 | else: 65 | # Otherwise, use a temporary directory. 66 | self.config.wandb_dir = tempfile.mkdtemp() 67 | else: 68 | # Join the wandb_dir with the experiment_id. 69 | self.config.wandb_dir = os.path.join( 70 | self.config.wandb_dir, self.config.experiment_id 71 | ) 72 | os.makedirs(self.config.wandb_dir, exist_ok=True) 73 | 74 | if self.config.profile_dir == "": 75 | if not self.config.output_dir.startswith("gs://"): 76 | # Use the same directory as output_dir if it is not a GCS path. 77 | self.config.profile_dir = self.config.output_dir 78 | else: 79 | # Otherwise, use a temporary directory. 80 | self.config.profile_dir = tempfile.mkdtemp() 81 | else: 82 | # Join the profile_dir with the experiment_id. 
83 | self.config.profile_dir = os.path.join( 84 | self.config.profile_dir, self.config.experiment_id 85 | ) 86 | os.makedirs(self.config.profile_dir, exist_ok=True) 87 | 88 | self._variant = flatten_config_dict(variant) 89 | 90 | if "hostname" not in self._variant: 91 | self._variant["hostname"] = gethostname() 92 | 93 | if self.enable: 94 | self.run = wandb.init( 95 | reinit=self.config.reinit, 96 | config=self._variant, 97 | project=self.config.project_id, 98 | dir=self.config.wandb_dir, 99 | id=self.config.experiment_id, 100 | group=self.config.experiment_group, 101 | resume=self.config.resume, 102 | notes=self.config.experiment_note, 103 | entity=self.config.project_entity, 104 | settings=wandb.Settings( 105 | start_method="thread", 106 | _disable_stats=True, 107 | ), 108 | mode="online" if self.config.online else "offline", 109 | ) 110 | else: 111 | self.run = None 112 | 113 | def log(self, *args, **kwargs): 114 | if self.enable: 115 | self.run.log(*args, **kwargs) 116 | 117 | def save_pickle(self, obj, filename): 118 | if self.enable: 119 | save_pickle(obj, os.path.join(self.config.output_dir, filename)) 120 | 121 | @property 122 | def experiment_id(self): 123 | return self.config.experiment_id 124 | 125 | @property 126 | def experiment_group(self): 127 | return self.config.experiment_group 128 | 129 | @property 130 | def variant(self): 131 | return self._variant 132 | 133 | @property 134 | def output_dir(self): 135 | return self.config.output_dir 136 | 137 | @property 138 | def wandb_dir(self): 139 | return self.config.wandb_dir 140 | 141 | @property 142 | def profile_dir(self): 143 | return self.config.profile_dir 144 | --------------------------------------------------------------------------------
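
Example usage: a minimal sketch of how the utilities above fit together. This script is not part of the repository; the script name, flag set, and the GCS path in the comment are hypothetical placeholders, and only APIs defined in the files above are called.

# example_train.py (hypothetical) -- wiring tux flags, logging, optimizer and checkpointing.
from absl import app

import tux

FLAGS, FLAGS_DEF = tux.define_flags_with_default(
    seed=42,
    load_checkpoint='',  # e.g. 'params::gs://my-bucket/run/streaming_params' (placeholder)
    optimizer=tux.OptimizerFactory.get_default_config(),
    checkpointer=tux.StreamingCheckpointer.get_default_config(),
    logger=tux.WandBLogger.get_default_config(),
    jax_distributed=tux.JaxDistributedConfig.get_default_config(),
)


def main(argv):
    # Optionally initialize multi-host JAX, then seed numpy/random/tux RNG.
    tux.JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    tux.set_random_seed(FLAGS.seed)
    tux.print_flags(FLAGS, FLAGS_DEF)

    # W&B logger; the flattened flags are logged as the run's config.
    logger = tux.WandBLogger(
        FLAGS.logger, variant=tux.get_user_flags(FLAGS, FLAGS_DEF)
    )

    # Optax optimizer plus its schedule info (e.g. the learning rate schedule).
    optimizer, optimizer_info = tux.OptimizerFactory.get_optimizer(FLAGS.optimizer)

    # Streaming checkpointer writing into the logger's output directory.
    checkpointer = tux.StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir, enable=True
    )

    # Build a train state with `optimizer`, then inside the training loop:
    #   logger.log({'loss': loss,
    #               'lr': optimizer_info['learning_rate_schedule'](step)})
    #   checkpointer.save_all(train_state, gather_fns,
    #                         metadata=dict(step=step), milestone=False)
    # and to resume from FLAGS.load_checkpoint:
    #   train_state, restored_params = tux.StreamingCheckpointer.load_trainstate_checkpoint(
    #       FLAGS.load_checkpoint, trainstate_target, trainstate_shard_fns)


if __name__ == '__main__':
    app.run(main)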