├── .DS_Store ├── DATASET.md ├── EasyLM ├── __init__.py ├── bpt.py ├── checkpoint.py ├── data.py ├── jax_utils.py ├── models │ ├── __init__.py │ └── llama │ │ ├── convert_easylm_to_hf.py │ │ ├── llama_model.py │ │ ├── llama_serve.py │ │ └── llama_train.py ├── optimizers.py ├── scripts │ ├── __init__.py │ ├── convert_checkpoint.py │ ├── diff_checkpoint.py │ ├── lm_eval_harness.py │ └── lm_eval_json.py └── serving.py ├── LICENSE ├── README.md ├── evaluation ├── .DS_Store ├── EVAL.md ├── LICENSE ├── environment.yml └── vqlm_demo │ ├── .DS_Store │ ├── __init__.py │ ├── batch_generation.py │ ├── eval_perplexity.py │ ├── eval_video_perplexity.py │ ├── eval_videos.py │ ├── generate_videos.py │ ├── inference.py │ ├── torch_vqvae_model.py │ ├── utils.py │ ├── vqvae │ ├── .DS_Store │ ├── __init__.py │ ├── logging.py │ └── modeling_utils.py │ └── vqvae_muse.py ├── images └── visual_sentences.jpg ├── scripts ├── gpu_environment.yml ├── tpu_commands.sh └── tpu_vm_setup.sh └── tokenize_examples ├── detokenization_muse.py ├── map_color.py ├── tokenize_catogory_images_muse.py ├── tokenize_co3d_muse.py ├── tokenize_colorization_dataset_muse.py ├── tokenize_inpainting_dataset_muse.py ├── tokenize_multi_datasets_muse.py ├── tokenize_multi_seq_images_muse.py ├── tokenize_paired_dataset_muse.py ├── tokenize_seq_images_muse.py └── tokenize_video_muse.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/.DS_Store -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | 2 | # Dataset Preparation 3 | 4 | This section describes how to prepare your dataset by tokenizing visual data into visual sentences and then mixing and shuffling the datasets. 5 | 6 | ## Download & Prepare Dataset 7 | We have individual scripts in `./tokenize_examples` to handle different kinds of visual sentences mentioned in the paper. Simple descriptions and instructions are listed as follows. 8 | 9 | ### Pair Datasets 10 | For pair datasets, where the visual sentence is constructed as `[image, label, image, label, image, label...]`, we use `tokenize_examples/tokenize_paired_dataset_muse.py` to generate the visual tokens. For example, for depth maps, surface normals, edges, and segmentation, we first use [prismer](https://github.com/NVlabs/prismer) to generate pseudo labels, then use the script to generate visual sentences. Note that for segmentation, we use an additional color mapping after obtaining the pseudo labels from prismer, which can be done by running `tokenize_examples/map_color.py`. 11 | 12 | ### Video Datasets 13 | For video datasets, the visual sentences are constructed as `[frame1, frame2, frame3, ... framex]`. One can use `tokenize_examples/tokenize_video_muse.py` to generate the visual sentences. The hyperparameter `stride` can be used to control the sampling rate of the extraction of frames from a video. 14 | 15 | ### Colorization Datasets 16 | For colorization datasets, the visual sentences are constructed as `[gray_image, colored_image, gray_image, colored_image, ...]`. One can use `tokenize_examples/tokenize_colorization_dataset_muse.py` to generate the visual sentences. The user only needs to prepare the colored images, and the script will generate the gray counterparts. 
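For illustration, the snippet below sketches the interleaving a colorization visual sentence needs before tokenization: each colored image is paired with a grayscale copy, giving `[gray_image, colored_image, gray_image, colored_image, ...]`. This is a minimal standalone example using Pillow, not the actual `tokenize_colorization_dataset_muse.py` pipeline; the function name and image paths are placeholders.

```python
from PIL import Image

def build_colorization_sentence(color_image_paths):
    """Interleave grayscale copies with the original colored images:
    [gray_1, color_1, gray_2, color_2, ...]."""
    sentence = []
    for path in color_image_paths:
        color = Image.open(path).convert("RGB")
        # Keep the gray counterpart in RGB so it can go through the same tokenizer.
        gray = color.convert("L").convert("RGB")
        sentence.extend([gray, color])
    return sentence

# Example: a 4-image visual sentence built from two colored images (placeholder paths).
images = build_colorization_sentence(["img_0001.jpg", "img_0002.jpg"])
```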
17 | 18 | ### Inpainting Datasets 19 | For inpainting datasets, the visual sentences are constructed as `[masked_image, image, masked_image, image, ...]`. One can use `tokenize_examples/tokenize_inpainting_dataset_muse.py` to generate the visual sentences. The user can control the masked ratio by changing `FLAGS.hole_mask_ratio`. 20 | 21 | ### Multi-Datasets 22 | For multi-datasets, the visual sentences are constructed as `[dataset1_image, dataset1_image, dataset2_image, dataset2_image, ...]`. One can use `tokenize_examples/tokenize_multi_datasets_muse.py` to generate the visual sentences. 23 | 24 | ### Category Datasets 25 | For category datasets, the visual sentences are constructed as `[cate1_image, cate1_image, cate2_image, cate2_image,...]`. One can use `tokenize_examples/tokenize_category_images_muse.py` to generate the visual sentences. Note that the user can use `images_per_shot` to customize the number of images for each category and `n_shots` to control the number of categories in the visual sentences. 26 | 27 | In general, visual sentences with different logic can be achieved by combining the basic logic provided above. After you generate your visual sentences, you can always use `tokenize_examples/detokenization_muse.py` to perform a sanity check for the recovery of the visual sentences. 28 | 29 | 30 | ## Visual Sentences 31 | 32 | - Following the concept of Visual Sentences, each image is tokenized separately, and the tokens are concatenated to form a visual sentence. 33 | 34 | ## Generating JSONL Files 35 | 36 | - For each dataset, generate the `dataset*.jsonl` files. 37 | 38 | ### Example Code 39 | 40 | - You can find example code for tokenizing images in the `./tokenize_examples` directory. 41 | 42 | ## Mixing and Shuffling Datasets 43 | 44 | - After generating the JSONL files, the datasets need to be mixed and shuffled. Follow these steps: 45 | 46 | 1. Set the temporary directory and memory allocation: 47 | ```shell 48 | export TMPDIR='/global/scratch/users/yutong/data/temp' 49 | export MEMORY='20' 50 | ``` 51 | 52 | 2. Navigate to your data directory: 53 | ```shell 54 | cd /path/to/your/data/ 55 | ``` 56 | 57 | 3. Mix and shuffle the datasets: 58 | ```shell 59 | cat tokenized_tasks/*.jsonl | terashuf > mix_and_shuffled/dataset.jsonl 60 | ``` 61 | 62 | - This will create a mixed and shuffled dataset file named `dataset.jsonl` in the `mix_and_shuffled` directory. -------------------------------------------------------------------------------- /EasyLM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/__init__.py -------------------------------------------------------------------------------- /EasyLM/bpt.py: -------------------------------------------------------------------------------- 1 | """An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 2 | 3 | Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682 4 | """ 5 | 6 | import functools 7 | from typing import NamedTuple 8 | 9 | import flax.linen as nn 10 | import jax 11 | import jax.lax as lax 12 | import jax.numpy as jnp 13 | from einops import rearrange 14 | 15 | ''' 16 | Computing ffn blockwise without materializing the large hidden tensor, training 4x longer sequences than the memory-efficient transformer. 17 | Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 
2023 18 | ''' 19 | def blockwise_ffn(remat_ffn, inputs, chunk_size, deterministic): 20 | # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable() 21 | # inputs: (batch, seq_len, dim) 22 | # chunk_size: the chunk size to split the sequence 23 | inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size) 24 | def scan_ffn(remat_ffn, carry, hidden_states): 25 | outputs = remat_ffn(hidden_states, deterministic=deterministic) 26 | return carry, outputs 27 | scan_axis = inputs.ndim - 2 28 | _, res = nn.scan( 29 | scan_ffn, 30 | variable_broadcast="params", 31 | split_rngs={"params": False, "dropout": True}, 32 | in_axes=scan_axis, 33 | out_axes=scan_axis, 34 | )(remat_ffn, None, inputs) 35 | res = rearrange(res, 'b c n d -> b (c n) d') 36 | return res 37 | 38 | 39 | ''' 40 | Compute attention blockwise without materializing the full attention matrix, initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021; 41 | flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA efficient implementation; 42 | blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x longer sequences than memory-efficient/flash-attention and fusion of attention and FFN. 43 | ''' 44 | def blockwise_attn(query, key, value, bias, deterministic, 45 | dropout_rng, attn_pdrop, causal, query_chunk_size, 46 | key_chunk_size, dtype, policy, precision, float32_logits, 47 | prevent_cse): 48 | # query, key, value: (batch, seq_len, num_heads, dim_per_head) 49 | # bias: (batch, seq_len) can be used to mask out attention (e.g. padding) 50 | # causal: whether to use causal mask 51 | # policy: one of jax.checkpoint_policies 52 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 53 | if float32_logits: 54 | query = query.astype(jnp.float32) 55 | key = key.astype(jnp.float32) 56 | 57 | batch, q_len, num_heads, dim_per_head = query.shape 58 | batch, kv_len, num_heads, dim_per_head = key.shape 59 | batch, kv_len, num_heads, dim_per_head = value.shape 60 | 61 | num_q = q_len // query_chunk_size 62 | num_kv = kv_len // key_chunk_size 63 | query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head)) 64 | key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 65 | value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 66 | 67 | query = jnp.moveaxis(query, 1, 0) 68 | key = jnp.moveaxis(key, 1, 0) 69 | value = jnp.moveaxis(value, 1, 0) 70 | 71 | if bias is not None: 72 | for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)): 73 | assert bias_dim == 1 or bias_dim == broadcast_dim 74 | if not deterministic and attn_pdrop > 0.0: 75 | attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng) 76 | attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)) 77 | else: 78 | attn_dropout = None 79 | 80 | _chunk_bias_fn = functools.partial( 81 | _chunk_attention_bias, 82 | query_chunk_size, key_chunk_size, bias, deterministic, 83 | attn_dropout, attn_pdrop, causal, dtype) 84 | 85 | def scan_attention(args): 86 | query_chunk, query_chunk_idx = args 87 | 88 | @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy) 89 | def scan_kv_block(carry, args): 90 | key_chunk, value_chunk, key_chunk_idx = args 91 | (numerator, denominator, prev_max_score) = carry 92 | attn_weights = jnp.einsum('bqhd,bkhd->bqhk', 
query_chunk, key_chunk, precision=precision) 93 | bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx) 94 | bias_chunk = jnp.moveaxis(bias_chunk, 1, 2) 95 | attn_weights = attn_weights + bias_chunk 96 | 97 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True) 98 | max_score = jnp.maximum(prev_max_score, max_score) 99 | max_score = jax.lax.stop_gradient(max_score) 100 | exp_weights = jnp.exp(attn_weights - max_score) 101 | exp_values = jnp.einsum( 102 | 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision 103 | ) 104 | correction = jnp.exp(prev_max_score - max_score) 105 | numerator = numerator * correction + exp_values 106 | denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True) 107 | return Carry(numerator, denominator, max_score), None 108 | 109 | def skip_upper_half(carry, args): 110 | key_chunk, value_chunk, key_chunk_idx = args 111 | skip_block = jnp.array(False) 112 | if causal: 113 | skip_block = query_chunk_idx < key_chunk_idx 114 | return jax.lax.cond( 115 | skip_block, 116 | lambda carry, args: (carry, None), 117 | scan_kv_block, 118 | carry, 119 | args, 120 | ) 121 | 122 | init_carry = Carry( 123 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 124 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 125 | (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype), 126 | ) 127 | (numerator, denominator, max_score), _ = lax.scan( 128 | skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv)) 129 | ) 130 | outputs = (numerator / denominator).astype(dtype) 131 | return outputs 132 | 133 | _, res = lax.scan( 134 | lambda _, x: ((), scan_attention(x)), 135 | (), xs=(query, jnp.arange(0, num_q)) 136 | ) 137 | res = rearrange(res, 'n b c h d -> b (n c) h d') 138 | return res 139 | 140 | 141 | class Carry(NamedTuple): 142 | numerator: jax.Array 143 | denominator: jax.Array 144 | max_so_far: jax.Array 145 | 146 | 147 | def _chunk_attention_bias(query_chunk_size, key_chunk_size, 148 | bias, deterministic, attn_dropout, attn_pdrop, causal, 149 | dtype, query_chunk_idx, key_chunk_idx): 150 | query_offset = query_chunk_idx * query_chunk_size 151 | key_offset = key_chunk_idx * key_chunk_size 152 | chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype) 153 | if bias is not None: 154 | chunk_bias = lax.dynamic_slice( 155 | bias, 156 | start_indices=(0, 0, query_offset, key_offset), 157 | slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)), 158 | ) 159 | 160 | if causal: 161 | query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0) 162 | key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1) 163 | offset = query_offset - key_offset 164 | query_idx += offset 165 | causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min 166 | chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape) 167 | 168 | if not deterministic and attn_pdrop > 0.0: 169 | attn_dropout_slice = lax.dynamic_slice( 170 | attn_dropout, 171 | start_indices=(0, 0, query_offset, key_offset), 172 | slice_sizes=( 173 | *attn_dropout.shape[:2], 174 | min(attn_dropout.shape[-2], query_chunk_size), 175 | min(attn_dropout.shape[-1], key_chunk_size), 176 | ), 177 | ) 178 | chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min 179 | return chunk_bias.astype(dtype) 180 | 181 | 182 | if __name__ == '__main__': 183 | # test 184 | def reference_attn(query, 
key, value, causal, dtype): 185 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 186 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key) 187 | if causal: 188 | mask_value = jnp.finfo(logits.dtype).min 189 | _, q_seq_len, _, _ = query.shape 190 | _, kv_seq_len, _, _ = key.shape 191 | mask_shape = (q_seq_len, kv_seq_len) 192 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) 193 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) 194 | causal_mask = (row_ids < col_ids)[None, None, :, :] 195 | logits = logits + jnp.where(causal_mask, mask_value, 0.0) 196 | weights = jax.nn.softmax(logits, axis=-1) 197 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value) 198 | return out 199 | 200 | # random inputs 201 | shape = (1, 32, 8, 64) 202 | query = jax.random.normal(jax.random.PRNGKey(0), shape) 203 | key = jax.random.normal(jax.random.PRNGKey(1), shape) 204 | value = jax.random.normal(jax.random.PRNGKey(2), shape) 205 | 206 | causal = True 207 | chunk_size = 4 208 | policy = jax.checkpoint_policies.nothing_saveable() 209 | 210 | blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False) 211 | reference = reference_attn(query, key, value, causal, 'float32') 212 | 213 | assert jnp.allclose(reference, blockwise, atol=1e-6) 214 | -------------------------------------------------------------------------------- /EasyLM/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from ml_collections import ConfigDict 4 | import mlxu 5 | import jax 6 | import jax.numpy as jnp 7 | import flax 8 | from flax.serialization import ( 9 | from_bytes, to_bytes, to_state_dict, from_state_dict 10 | ) 11 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node 12 | import msgpack 13 | 14 | from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype 15 | 16 | 17 | class StreamingCheckpointer(object): 18 | """ Custom msgpack checkpointer that saves large train states by serializing 19 | and saving tensors one by one in a streaming fashion. Avoids running 20 | out of memory or local TPU disk with default flax checkpointer. 
21 | """ 22 | 23 | @staticmethod 24 | def get_default_config(updates=None): 25 | config = ConfigDict() 26 | config.float_dtype = 'bf16' 27 | config.save_optimizer_state = False 28 | 29 | if updates is not None: 30 | config.update(ConfigDict(updates).copy_and_resolve_references()) 31 | return config 32 | 33 | def __init__(self, config, checkpoint_dir, enable=True): 34 | self.config = self.get_default_config(config) 35 | self.checkpoint_dir = checkpoint_dir 36 | self.enable = enable 37 | 38 | def save_checkpoint(self, train_state, filename, gather_fns=None): 39 | if self.enable: 40 | path = os.path.join(self.checkpoint_dir, filename) 41 | else: 42 | path = '/dev/null' 43 | self.save_train_state_to_file( 44 | train_state, path, gather_fns, self.config.float_dtype 45 | ) 46 | 47 | @staticmethod 48 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None): 49 | train_state = to_state_dict(train_state) 50 | packer = msgpack.Packer() 51 | flattend_train_state = flatten_dict(train_state) 52 | if gather_fns is not None: 53 | gather_fns = flatten_dict(to_state_dict(gather_fns)) 54 | 55 | with mlxu.open_file(path, "wb") as fout: 56 | for key, value in flattend_train_state.items(): 57 | if gather_fns is not None: 58 | value = gather_fns[key](value) 59 | value = float_tensor_to_dtype(value, float_dtype) 60 | fout.write(packer.pack((key, to_bytes(value)))) 61 | 62 | def save_pickle(self, obj, filename): 63 | if self.enable: 64 | path = os.path.join(self.checkpoint_dir, filename) 65 | else: 66 | path = '/dev/null' 67 | mlxu.save_pickle(obj, path) 68 | 69 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False): 70 | step = int(jax.device_get(train_state.step)) 71 | if self.config.save_optimizer_state: 72 | checkpoint_state = train_state 73 | checkpoint_name = 'streaming_train_state' 74 | checkpoint_gather_fns = gather_fns 75 | else: 76 | checkpoint_state = train_state.params['params'] 77 | checkpoint_name = 'streaming_params' 78 | checkpoint_gather_fns = gather_fns.params['params'] 79 | 80 | if milestone: 81 | # Save a milestone checkpoint that will not be overwritten 82 | self.save_pickle(metadata, f'metadata_{step}.pkl') 83 | self.save_pickle(dataset, f'dataset_{step}.pkl') 84 | self.save_checkpoint( 85 | checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns 86 | ) 87 | else: 88 | # Save a normal checkpoint that can be overwritten 89 | self.save_pickle(metadata, 'metadata.pkl') 90 | self.save_pickle(dataset, 'dataset.pkl') 91 | self.save_checkpoint( 92 | checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns 93 | ) 94 | 95 | @staticmethod 96 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None): 97 | if shard_fns is not None: 98 | shard_fns = flatten_dict( 99 | to_state_dict(shard_fns) 100 | ) 101 | if remove_dict_prefix is not None: 102 | remove_dict_prefix = tuple(remove_dict_prefix) 103 | flattend_train_state = {} 104 | with mlxu.open_file(path) as fin: 105 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS 106 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0) 107 | for key, value in unpacker: 108 | key = tuple(key) 109 | if remove_dict_prefix is not None: 110 | if key[:len(remove_dict_prefix)] == remove_dict_prefix: 111 | key = key[len(remove_dict_prefix):] 112 | else: 113 | continue 114 | 115 | tensor = from_bytes(None, value) 116 | if shard_fns is not None: 117 | tensor = shard_fns[key](tensor) 118 | flattend_train_state[key] = tensor 119 | 120 | if target is 
not None: 121 | flattened_target = flatten_dict( 122 | to_state_dict(target), keep_empty_nodes=True 123 | ) 124 | for key, value in flattened_target.items(): 125 | if key not in flattend_train_state and value == empty_node: 126 | flattend_train_state[key] = value 127 | 128 | train_state = unflatten_dict(flattend_train_state) 129 | if target is None: 130 | return train_state 131 | 132 | return from_state_dict(target, train_state) 133 | 134 | @staticmethod 135 | def load_flax_checkpoint(path, target=None, shard_fns=None): 136 | """ Load a standard flax checkpoint that's not saved with the 137 | msgpack streaming format. 138 | """ 139 | with mlxu.open_file(path, "rb") as fin: 140 | encoded_bytes = fin.read() 141 | 142 | state_dict = flax.serialization.msgpack_restore(encoded_bytes) 143 | if shard_fns is not None: 144 | shard_fns = to_state_dict(shard_fns) 145 | state_dict = tree_apply(shard_fns, state_dict) 146 | 147 | if target is None: 148 | return state_dict 149 | return from_state_dict(target, state_dict) 150 | 151 | @classmethod 152 | def load_trainstate_checkpoint(cls, load_from, trainstate_target=None, 153 | trainstate_shard_fns=None, 154 | disallow_trainstate=False): 155 | if trainstate_target is not None: 156 | params_target = trainstate_target.params['params'] 157 | else: 158 | params_target = None 159 | 160 | if trainstate_shard_fns is not None: 161 | params_shard_fns = trainstate_shard_fns.params['params'] 162 | else: 163 | params_shard_fns = None 164 | 165 | load_type, load_path = load_from.split('::', 1) 166 | if disallow_trainstate: 167 | assert load_type != 'trainstate', 'Loading full trainstate is not allowed!' 168 | train_state = None 169 | restored_params = None 170 | if load_type == 'trainstate': 171 | # Load the entire train state in the streaming format 172 | train_state = cls.load_checkpoint( 173 | path=load_path, 174 | target=trainstate_target, 175 | shard_fns=trainstate_shard_fns, 176 | ) 177 | elif load_type == 'trainstate_params': 178 | # Load the params part of the train state in the streaming format 179 | restored_params = cls.load_checkpoint( 180 | path=load_path, 181 | target=params_target, 182 | shard_fns=params_shard_fns, 183 | remove_dict_prefix=('params', 'params'), 184 | ) 185 | restored_params = flax.core.frozen_dict.freeze( 186 | {'params': restored_params} 187 | ) 188 | elif load_type == 'params': 189 | # Load the params in the streaming format 190 | restored_params = cls.load_checkpoint( 191 | path=load_path, 192 | target=params_target, 193 | shard_fns=params_shard_fns, 194 | ) 195 | restored_params = flax.core.frozen_dict.freeze( 196 | {'params': restored_params} 197 | ) 198 | elif load_type == 'flax_params': 199 | # Load the params in the standard flax format (non-streaming) 200 | # This requires the entire params to fit in memory 201 | restored_params = cls.load_flax_checkpoint( 202 | path=load_path, 203 | target=params_target, 204 | shard_fns=params_shard_fns 205 | ) 206 | restored_params = flax.core.frozen_dict.freeze( 207 | {'params': restored_params} 208 | ) 209 | else: 210 | raise ValueError(f'Invalid load_from type: {load_type}') 211 | 212 | return train_state, restored_params 213 | -------------------------------------------------------------------------------- /EasyLM/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/models/__init__.py 
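The `load_from` argument of `StreamingCheckpointer.load_trainstate_checkpoint` in `EasyLM/checkpoint.py` above packs the checkpoint flavor and its location into a single `'<type>::<path>'` string, where `<type>` is one of `trainstate`, `trainstate_params`, `params`, or `flax_params`. A minimal usage sketch (the checkpoint path is a placeholder):

```python
from EasyLM.checkpoint import StreamingCheckpointer

# Load only the model parameters from a streaming checkpoint; the returned
# train_state is None because no full train state was requested.
train_state, restored_params = StreamingCheckpointer.load_trainstate_checkpoint(
    'params::/path/to/streaming_params',  # placeholder path
)
```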
-------------------------------------------------------------------------------- /EasyLM/models/llama/convert_easylm_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # Copyright 2023 Xinyang Geng 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script converts LLaMA model checkpoint trained by EsayLM to the 17 | # HuggingFace transformers LLaMA PyTorch format, which can then be loaded 18 | # by HuggingFace transformers. 19 | 20 | import gc 21 | import json 22 | import math 23 | import os 24 | import shutil 25 | 26 | import numpy as np 27 | import mlxu 28 | import jax 29 | import jax.numpy as jnp 30 | import flax 31 | from flax.traverse_util import flatten_dict 32 | import torch 33 | from transformers import LlamaConfig, LlamaForCausalLM 34 | 35 | from EasyLM.checkpoint import StreamingCheckpointer 36 | from EasyLM.jax_utils import float_tensor_to_dtype 37 | from EasyLM.models.llama.llama_model import LLAMA_STANDARD_CONFIGS 38 | 39 | 40 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 41 | load_checkpoint='', 42 | model_size='vqlm_1b', 43 | output_dir='', 44 | ) 45 | 46 | 47 | def match_keywords(string, positives, negatives): 48 | for positive in positives: 49 | if positive not in string: 50 | return False 51 | for negative in negatives: 52 | if negative in string: 53 | return False 54 | return True 55 | 56 | 57 | def load_and_convert_checkpoint(path): 58 | _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path) 59 | flax_params = flatten_dict(flax_params['params'], sep='.') 60 | torch_params = {} 61 | for key, tensor in flax_params.items(): 62 | if match_keywords(key, ["kernel"], ["norm", 'ln_f']): 63 | tensor = tensor.T 64 | torch_params[key] = torch.tensor( 65 | float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16 66 | ) 67 | return torch_params 68 | 69 | 70 | def read_json(path): 71 | with open(path, "r") as f: 72 | return json.load(f) 73 | 74 | 75 | def write_json(text, path): 76 | with open(path, "w") as f: 77 | json.dump(text, f) 78 | 79 | 80 | def write_model(loaded, model_path, model_size): 81 | os.makedirs(model_path, exist_ok=True) 82 | tmp_model_path = os.path.join(model_path, "tmp") 83 | os.makedirs(tmp_model_path, exist_ok=True) 84 | 85 | params = LLAMA_STANDARD_CONFIGS[model_size] 86 | 87 | n_layers = params["num_hidden_layers"] 88 | n_heads = params["num_attention_heads"] 89 | dim = params["hidden_size"] 90 | dims_per_head = dim // n_heads 91 | base = 10000.0 92 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 93 | 94 | # permute for sliced rotary 95 | def permute(w): 96 | return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) 97 | 98 | 99 | param_count = 0 100 | index_dict = {"weight_map": {}} 101 | for layer_i in range(n_layers): 102 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 
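# The flax checkpoint stores layer weights under transformer.h.{i}.*; build the
# corresponding HuggingFace LLaMA state dict under model.layers.{i}.*. The wq/wk
# kernels are passed through permute() to match HF's sliced rotary embedding layout.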
103 | state_dict = { 104 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 105 | loaded[f"transformer.h.{layer_i}.attention.wq.kernel"] 106 | ), 107 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 108 | loaded[f"transformer.h.{layer_i}.attention.wk.kernel"] 109 | ), 110 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"], 111 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"], 112 | 113 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"], 114 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"], 115 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"], 116 | 117 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"], 118 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"], 119 | 120 | } 121 | 122 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 123 | for k, v in state_dict.items(): 124 | index_dict["weight_map"][k] = filename 125 | param_count += v.numel() 126 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 127 | 128 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 129 | # Unsharded 130 | state_dict = { 131 | "model.embed_tokens.weight": loaded["transformer.wte.embedding"], 132 | "model.norm.weight": loaded["transformer.ln_f.kernel"], 133 | "lm_head.weight": loaded["lm_head.kernel"], 134 | } 135 | 136 | for k, v in state_dict.items(): 137 | index_dict["weight_map"][k] = filename 138 | param_count += v.numel() 139 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 140 | 141 | # Write configs 142 | index_dict["metadata"] = {"total_size": param_count * 2} 143 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 144 | 145 | config = LlamaConfig( 146 | vocab_size=params["vocab_size"], 147 | max_position_embeddings=params["max_sequence_length"], 148 | hidden_size=dim, 149 | intermediate_size=params["intermediate_size"], 150 | num_attention_heads=params["num_attention_heads"], 151 | num_hidden_layers=params["num_hidden_layers"], 152 | rms_norm_eps=params["rms_norm_eps"], 153 | ) 154 | config.save_pretrained(tmp_model_path) 155 | 156 | # Make space so we can load the model properly now. 157 | del state_dict 158 | del loaded 159 | gc.collect() 160 | 161 | print("Loading the checkpoint in a Llama model.") 162 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16) 163 | # Avoid saving this as part of the config. 
164 | del model.config._name_or_path 165 | 166 | print("Saving in the Transformers format.") 167 | model.save_pretrained(model_path) 168 | shutil.rmtree(tmp_model_path) 169 | 170 | 171 | def main(argv): 172 | assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != "" 173 | assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS 174 | 175 | write_model( 176 | load_and_convert_checkpoint(FLAGS.load_checkpoint), 177 | model_path=FLAGS.output_dir, 178 | model_size=FLAGS.model_size, 179 | ) 180 | 181 | 182 | if __name__ == "__main__": 183 | mlxu.run(main) -------------------------------------------------------------------------------- /EasyLM/models/llama/llama_train.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from functools import partial 3 | 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | import mlxu 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from jax.experimental.pjit import pjit 11 | from jax.sharding import PartitionSpec as PS 12 | from flax.training.train_state import TrainState 13 | 14 | from EasyLM.data import DatasetFactory 15 | from EasyLM.checkpoint import StreamingCheckpointer 16 | from EasyLM.optimizers import OptimizerFactory 17 | from EasyLM.jax_utils import ( 18 | JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, 19 | cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name, 20 | set_random_seed, average_metrics, get_weight_decay_mask, 21 | make_shard_and_gather_fns, with_sharding_constraint, 22 | ) 23 | from EasyLM.models.llama.llama_model import ( 24 | LLaMAConfig, FlaxLLaMAForCausalLMModule 25 | ) 26 | 27 | 28 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 29 | seed=42, 30 | mesh_dim='1,-1,1', 31 | dtype='fp32', 32 | total_steps=10000, 33 | load_llama_config='', 34 | update_llama_config='', 35 | load_checkpoint='', 36 | load_dataset_state='', 37 | load_metadata='', 38 | log_freq=50, 39 | save_model_freq=0, 40 | save_milestone_freq=0, 41 | eval_steps=0, 42 | tokenizer=LLaMAConfig.get_tokenizer_config(), 43 | train_dataset=DatasetFactory.get_default_config(), 44 | eval_dataset=DatasetFactory.get_default_config(), 45 | optimizer=OptimizerFactory.get_default_config(), 46 | checkpointer=StreamingCheckpointer.get_default_config(), 47 | llama=LLaMAConfig.get_default_config(), 48 | logger=mlxu.WandBLogger.get_default_config(), 49 | log_all_worker=False, 50 | jax_distributed=JaxDistributedConfig.get_default_config(), 51 | ) 52 | 53 | 54 | def main(argv): 55 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) 56 | variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF) 57 | flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF) 58 | logger = mlxu.WandBLogger( 59 | config=FLAGS.logger, 60 | variant=variant, 61 | enable=FLAGS.log_all_worker or (jax.process_index() == 0), 62 | ) 63 | set_random_seed(FLAGS.seed) 64 | 65 | tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer) 66 | dataset = DatasetFactory.load_dataset( 67 | FLAGS.train_dataset, tokenizer, device_count=jax.device_count() 68 | ) 69 | if FLAGS.load_dataset_state != '': 70 | dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state)) 71 | 72 | if FLAGS.eval_steps > 0: 73 | eval_dataset = DatasetFactory.load_dataset( 74 | FLAGS.eval_dataset, dataset.tokenizer 75 | ) 76 | eval_iterator = iter(eval_dataset) 77 | 78 | seq_length = dataset.seq_length 79 | 80 | if FLAGS.load_llama_config != '': 81 | llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config) 82 | else: 83 | llama_config = 
LLaMAConfig(**FLAGS.llama) 84 | 85 | if FLAGS.update_llama_config != '': 86 | llama_config.update(dict(eval(FLAGS.update_llama_config))) 87 | 88 | # llama_config.update(dict( 89 | # bos_token_id=dataset.tokenizer.bos_token_id, 90 | # eos_token_id=dataset.tokenizer.eos_token_id, 91 | # )) 92 | # if llama_config.vocab_size < dataset.vocab_size: 93 | # llama_config.update(dict(vocab_size=dataset.vocab_size)) 94 | 95 | model = FlaxLLaMAForCausalLMModule( 96 | llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype) 97 | ) 98 | 99 | optimizer, optimizer_info = OptimizerFactory.get_optimizer( 100 | FLAGS.optimizer, 101 | get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions()) 102 | ) 103 | 104 | def create_trainstate_from_params(params): 105 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 106 | 107 | def init_fn(rng): 108 | rng_generator = JaxRNG(rng) 109 | params = model.init( 110 | input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32), 111 | position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32), 112 | attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32), 113 | rngs=rng_generator(llama_config.rng_keys()), 114 | ) 115 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 116 | 117 | def train_step(train_state, rng, batch): 118 | rng_generator = JaxRNG(rng) 119 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 120 | def loss_and_accuracy(params): 121 | logits = model.apply( 122 | params, batch['input_tokens'], deterministic=False, 123 | rngs=rng_generator(llama_config.rng_keys()), 124 | ).logits 125 | return cross_entropy_loss_and_accuracy( 126 | logits, batch['target_tokens'], batch['loss_masks'] 127 | ) 128 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 129 | (loss, accuracy), grads = grad_fn(train_state.params) 130 | train_state = train_state.apply_gradients(grads=grads) 131 | metrics = dict( 132 | loss=loss, 133 | accuracy=accuracy, 134 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step), 135 | gradient_norm=global_norm(grads), 136 | param_norm=global_norm(train_state.params), 137 | ) 138 | return train_state, rng_generator(), metrics 139 | 140 | def eval_step(train_state, rng, batch): 141 | rng_generator = JaxRNG(rng) 142 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 143 | logits = model.apply( 144 | train_state.params, batch['input_tokens'], deterministic=True, 145 | rngs=rng_generator(llama_config.rng_keys()), 146 | ).logits 147 | loss, accuracy = cross_entropy_loss_and_accuracy( 148 | logits, batch['target_tokens'], batch['loss_masks'] 149 | ) 150 | metrics = dict( 151 | eval_loss=loss, 152 | eval_accuracy=accuracy, 153 | ) 154 | return rng_generator(), metrics 155 | 156 | train_state_shapes = jax.eval_shape(init_fn, next_rng()) 157 | train_state_partition = match_partition_rules( 158 | LLaMAConfig.get_partition_rules(), train_state_shapes 159 | ) 160 | 161 | shard_fns, gather_fns = make_shard_and_gather_fns( 162 | train_state_partition, train_state_shapes 163 | ) 164 | checkpointer = StreamingCheckpointer( 165 | FLAGS.checkpointer, logger.output_dir, 166 | enable=jax.process_index() == 0, 167 | ) 168 | 169 | sharded_init_fn = pjit( 170 | init_fn, 171 | in_shardings=PS(), 172 | out_shardings=train_state_partition 173 | ) 174 | 175 | sharded_create_trainstate_from_params = pjit( 176 | create_trainstate_from_params, 177 | in_shardings=(train_state_partition.params, ), 178 | out_shardings=train_state_partition, 179 | donate_argnums=(0, ), 180 | ) 181 | 182 | 
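# pjit the train and eval steps: the train state keeps its named partition specs,
# while the RNG key and data batch come in with an empty PartitionSpec. For the
# train step, the previous train state and RNG are donated so XLA can reuse their buffers.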
sharded_train_step = pjit( 183 | train_step, 184 | in_shardings=(train_state_partition, PS(), PS()), 185 | out_shardings=(train_state_partition, PS(), PS()), 186 | donate_argnums=(0, 1), 187 | ) 188 | 189 | sharded_eval_step = pjit( 190 | eval_step, 191 | in_shardings=(train_state_partition, PS(), PS()), 192 | out_shardings=(PS(), PS()), 193 | donate_argnums=(1,), 194 | ) 195 | 196 | def save_checkpoint(train_state, milestone=False): 197 | step = int(jax.device_get(train_state.step)) 198 | metadata = dict( 199 | step=step, 200 | variant=variant, 201 | flags=flags_config_dict, 202 | llama_config=llama_config.to_dict(), 203 | ) 204 | checkpointer.save_all( 205 | train_state=train_state, 206 | gather_fns=gather_fns, 207 | metadata=metadata, 208 | dataset=dataset.get_state_dict(), 209 | milestone=milestone, 210 | ) 211 | 212 | mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) 213 | with mesh: 214 | train_state, restored_params = None, None 215 | if FLAGS.load_checkpoint != '': 216 | train_state, restored_params = checkpointer.load_trainstate_checkpoint( 217 | FLAGS.load_checkpoint, train_state_shapes, shard_fns 218 | ) 219 | 220 | if train_state is None and restored_params is None: 221 | # Initialize from scratch 222 | train_state = sharded_init_fn(next_rng()) 223 | elif train_state is None and restored_params is not None: 224 | # Restore from params but initialize train_state 225 | train_state = sharded_create_trainstate_from_params(restored_params) 226 | del restored_params 227 | # Hack to get the correct step from the metadata. Otherwise, the `train_state.step` is initialized at 0 228 | if FLAGS.load_metadata != '': 229 | loaded_step = mlxu.load_pickle(FLAGS.load_metadata)["step"] + 1 230 | train_state = train_state.replace(step=loaded_step) 231 | print(f"Resuming at step {loaded_step}") 232 | 233 | start_step = int(jax.device_get(train_state.step)) 234 | 235 | # if FLAGS.save_model_freq > 0: 236 | # save_checkpoint(train_state) 237 | 238 | sharded_rng = next_rng() 239 | 240 | step_counter = trange(start_step, FLAGS.total_steps, ncols=0) 241 | 242 | for step, (batch, dataset_metrics) in zip(step_counter, dataset): 243 | train_state, sharded_rng, metrics = sharded_train_step( 244 | train_state, sharded_rng, batch 245 | ) 246 | 247 | if step % FLAGS.log_freq == 0: 248 | if FLAGS.eval_steps > 0: 249 | eval_metric_list = [] 250 | for _ in range(FLAGS.eval_steps): 251 | eval_batch, _ = next(eval_iterator) 252 | sharded_rng, eval_metrics = sharded_eval_step( 253 | train_state, sharded_rng, eval_batch 254 | ) 255 | eval_metric_list.append(eval_metrics) 256 | metrics.update(average_metrics(eval_metric_list)) 257 | 258 | log_metrics = {"step": step} 259 | log_metrics.update(metrics) 260 | log_metrics.update(dataset_metrics) 261 | log_metrics = jax.device_get(log_metrics) 262 | logger.log(log_metrics) 263 | if jax.process_index() == 0: 264 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 265 | 266 | if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0: 267 | save_checkpoint(train_state, milestone=True) 268 | elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0: 269 | save_checkpoint(train_state) 270 | 271 | if FLAGS.save_model_freq > 0: 272 | save_checkpoint(train_state) 273 | 274 | 275 | if __name__ == "__main__": 276 | mlxu.run(main) 277 | -------------------------------------------------------------------------------- /EasyLM/optimizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import time 3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple 4 | from functools import partial 5 | import re 6 | import dataclasses 7 | import random 8 | 9 | from ml_collections.config_dict import config_dict 10 | from ml_collections import ConfigDict 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | from absl import logging 15 | import optax 16 | 17 | from EasyLM.jax_utils import float_to_dtype 18 | 19 | 20 | class OptimizerFactory(object): 21 | """ Configurable optax optimizer factory. """ 22 | 23 | def __init__(self): 24 | raise NotImplementedError 25 | 26 | @staticmethod 27 | def get_default_config(updates=None): 28 | config = ConfigDict() 29 | config.accumulate_gradient_steps = 1 30 | config.type = 'adamw' 31 | config.palm_optimizer = PalmOptimizerFactory.get_default_config() 32 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config() 33 | 34 | if updates is not None: 35 | config.update(ConfigDict(updates).copy_and_resolve_references()) 36 | return config 37 | 38 | @classmethod 39 | def get_optimizer(cls, config, weight_decay_mask=None): 40 | config = cls.get_default_config(config) 41 | if config.type == 'palm': 42 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer( 43 | config.palm_optimizer, weight_decay_mask 44 | ) 45 | elif config.type == 'adamw': 46 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer( 47 | config.adamw_optimizer, weight_decay_mask 48 | ) 49 | else: 50 | raise ValueError(f'Unknown optimizer type: {config.type}') 51 | 52 | if config.accumulate_gradient_steps > 1: 53 | optimizer = optax.MultiSteps( 54 | optimizer, config.accumulate_gradient_steps 55 | ) 56 | 57 | return optimizer, optimizer_info 58 | 59 | 60 | class PalmOptimizerFactory(object): 61 | """ PaLM optimizer factory. 
This optimizer implements the optimizer 62 | described in the PaLM paper: https://arxiv.org/abs/2204.02311 63 | """ 64 | 65 | def __init__(self): 66 | raise NotImplementedError 67 | 68 | @staticmethod 69 | def get_default_config(updates=None): 70 | config = ConfigDict() 71 | config.lr = 0.01 72 | config.lr_warmup_steps = 10000 73 | config.b1 = 0.9 74 | config.b2 = 0.99 75 | config.clip_gradient = 1.0 76 | config.weight_decay = 1e-4 77 | config.bf16_momentum = False 78 | 79 | if updates is not None: 80 | config.update(ConfigDict(updates).copy_and_resolve_references()) 81 | return config 82 | 83 | @classmethod 84 | def get_optimizer(cls, config, weight_decay_mask=None): 85 | config = cls.get_default_config(config) 86 | 87 | def learning_rate_schedule(step): 88 | multiplier = config.lr / 0.01 89 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps)) 90 | 91 | def weight_decay_schedule(step): 92 | multiplier = config.weight_decay / 1e-4 93 | return -multiplier * jnp.square(learning_rate_schedule(step)) 94 | 95 | optimizer_info = dict( 96 | learning_rate_schedule=learning_rate_schedule, 97 | weight_decay_schedule=weight_decay_schedule, 98 | ) 99 | 100 | optimizer = optax.chain( 101 | optax.clip_by_global_norm(config.clip_gradient), 102 | optax.adafactor( 103 | learning_rate=learning_rate_schedule, 104 | multiply_by_parameter_scale=True, 105 | momentum=config.b1, 106 | decay_rate=config.b2, 107 | factored=False, 108 | clipping_threshold=None, 109 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 110 | ), 111 | optax_add_scheduled_weight_decay( 112 | weight_decay_schedule, weight_decay_mask 113 | ) 114 | ) 115 | return optimizer, optimizer_info 116 | 117 | 118 | class AdamWOptimizerFactory(object): 119 | """ AdamW optimizer with cosine schedule. 
""" 120 | 121 | def __init__(self): 122 | raise NotImplementedError 123 | 124 | @staticmethod 125 | def get_default_config(updates=None): 126 | config = ConfigDict() 127 | config.init_lr = 0.0 128 | config.end_lr = 0.001 129 | config.lr = 0.01 130 | config.lr_warmup_steps = 2000 131 | config.lr_decay_steps = 500000 132 | config.b1 = 0.9 133 | config.b2 = 0.95 134 | config.clip_gradient = 1.0 135 | config.weight_decay = 1e-4 136 | config.bf16_momentum = False 137 | config.multiply_by_parameter_scale = False 138 | 139 | if updates is not None: 140 | config.update(ConfigDict(updates).copy_and_resolve_references()) 141 | return config 142 | 143 | @classmethod 144 | def get_optimizer(cls, config, weight_decay_mask=None): 145 | config = cls.get_default_config(config) 146 | 147 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 148 | init_value=config.init_lr, 149 | peak_value=config.lr, 150 | warmup_steps=config.lr_warmup_steps, 151 | decay_steps=config.lr_decay_steps, 152 | end_value=config.end_lr, 153 | ) 154 | 155 | optimizer_info = dict( 156 | learning_rate_schedule=learning_rate_schedule, 157 | ) 158 | 159 | if config.multiply_by_parameter_scale: 160 | optimizer = optax.chain( 161 | optax.clip_by_global_norm(config.clip_gradient), 162 | optax.adafactor( 163 | learning_rate=learning_rate_schedule, 164 | multiply_by_parameter_scale=True, 165 | momentum=config.b1, 166 | decay_rate=config.b2, 167 | factored=False, 168 | clipping_threshold=None, 169 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 170 | ), 171 | optax_add_scheduled_weight_decay( 172 | lambda step: -learning_rate_schedule(step) * config.weight_decay, 173 | weight_decay_mask 174 | ) 175 | ) 176 | else: 177 | optimizer = optax.chain( 178 | optax.clip_by_global_norm(config.clip_gradient), 179 | optax.adamw( 180 | learning_rate=learning_rate_schedule, 181 | weight_decay=config.weight_decay, 182 | b1=config.b1, 183 | b2=config.b2, 184 | mask=weight_decay_mask, 185 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 186 | ), 187 | ) 188 | 189 | return optimizer, optimizer_info 190 | 191 | 192 | class OptaxScheduledWeightDecayState(NamedTuple): 193 | count: jax.Array 194 | 195 | 196 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None): 197 | """ Apply weight decay with schedule. 
""" 198 | 199 | def init_fn(params): 200 | del params 201 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32)) 202 | 203 | def update_fn(updates, state, params): 204 | if params is None: 205 | raise ValueError('Params cannot be None for weight decay!') 206 | 207 | weight_decay = schedule_fn(state.count) 208 | updates = jax.tree_util.tree_map( 209 | lambda g, p: g + weight_decay * p, updates, params 210 | ) 211 | return updates, OptaxScheduledWeightDecayState( 212 | count=optax.safe_int32_increment(state.count) 213 | ) 214 | 215 | if mask is not None: 216 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask) 217 | return optax.GradientTransformation(init_fn, update_fn) 218 | -------------------------------------------------------------------------------- /EasyLM/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/scripts/__init__.py -------------------------------------------------------------------------------- /EasyLM/scripts/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | # This script converts model checkpoint trained by EsayLM to a standard 2 | # mspack checkpoint that can be loaded by huggingface transformers or 3 | # flax.serialization.msgpack_restore. Such conversion allows models to be 4 | # used by other frameworks that integrate with huggingface transformers. 5 | 6 | import pprint 7 | from functools import partial 8 | import os 9 | import numpy as np 10 | import mlxu 11 | import jax.numpy as jnp 12 | import flax.serialization 13 | from EasyLM.checkpoint import StreamingCheckpointer 14 | from EasyLM.jax_utils import float_to_dtype 15 | 16 | 17 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 18 | load_checkpoint='', 19 | output_file='', 20 | streaming=False, 21 | float_dtype='bf16', 22 | ) 23 | 24 | 25 | def main(argv): 26 | assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified' 27 | params = StreamingCheckpointer.load_trainstate_checkpoint( 28 | FLAGS.load_checkpoint, disallow_trainstate=True 29 | )[1]['params'] 30 | 31 | if FLAGS.streaming: 32 | StreamingCheckpointer.save_train_state_to_file( 33 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype 34 | ) 35 | else: 36 | params = float_to_dtype(params, FLAGS.float_dtype) 37 | with mlxu.open_file(FLAGS.output, 'wb') as fout: 38 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True)) 39 | 40 | 41 | if __name__ == "__main__": 42 | mlxu.run(main) 43 | -------------------------------------------------------------------------------- /EasyLM/scripts/diff_checkpoint.py: -------------------------------------------------------------------------------- 1 | # This script converts model checkpoint trained by EsayLM to a standard 2 | # mspack checkpoint that can be loaded by huggingface transformers or 3 | # flax.serialization.msgpack_restore. Such conversion allows models to be 4 | # used by other frameworks that integrate with huggingface transformers. 
5 | 6 | import pprint 7 | from functools import partial 8 | import os 9 | import numpy as np 10 | import jax 11 | import jax.numpy as jnp 12 | import flax.serialization 13 | import mlxu 14 | from EasyLM.checkpoint import StreamingCheckpointer 15 | from EasyLM.jax_utils import float_to_dtype 16 | 17 | 18 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 19 | recover_diff=False, 20 | load_base_checkpoint='', 21 | load_target_checkpoint='', 22 | output_file='', 23 | streaming=True, 24 | float_dtype='bf16', 25 | ) 26 | 27 | 28 | def main(argv): 29 | assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != '' 30 | assert FLAGS.output_file != '' 31 | base_params = StreamingCheckpointer.load_trainstate_checkpoint( 32 | FLAGS.load_base_checkpoint, disallow_trainstate=True 33 | )[1]['params'] 34 | 35 | target_params = StreamingCheckpointer.load_trainstate_checkpoint( 36 | FLAGS.load_target_checkpoint, disallow_trainstate=True 37 | )[1]['params'] 38 | 39 | if FLAGS.recover_diff: 40 | params = jax.tree_util.tree_map( 41 | lambda b, t: b + t, base_params, target_params 42 | ) 43 | else: 44 | params = jax.tree_util.tree_map( 45 | lambda b, t: t - b, base_params, target_params 46 | ) 47 | 48 | if FLAGS.streaming: 49 | StreamingCheckpointer.save_train_state_to_file( 50 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype 51 | ) 52 | else: 53 | params = float_to_dtype(params, FLAGS.float_dtype) 54 | with mlxu.open_file(FLAGS.output, 'wb') as fout: 55 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True)) 56 | 57 | 58 | if __name__ == "__main__": 59 | mlxu.run(main) 60 | -------------------------------------------------------------------------------- /EasyLM/scripts/lm_eval_harness.py: -------------------------------------------------------------------------------- 1 | # This script runs lm_eval_harness evaluations against a served language model. 2 | # Typically, you need to run a language model server first, e.g.: 3 | # python -m EasyLM.models.gptj.gptj_serve ... 
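# In this repository, the corresponding server entry point is EasyLM.models.llama.llama_serve.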
4 | 5 | import dataclasses 6 | import pprint 7 | from functools import partial 8 | import os 9 | from tqdm import tqdm, trange 10 | import numpy as np 11 | import mlxu 12 | 13 | from flax.traverse_util import flatten_dict 14 | from lm_eval import evaluator, tasks 15 | from lm_eval.base import LM 16 | 17 | from EasyLM.serving import LMClient 18 | 19 | 20 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 21 | tasks='wsc,piqa,winogrande,openbookqa,logiqa', 22 | shots=0, 23 | limit=0, 24 | write_out=False, 25 | lm_client=LMClient.get_default_config(), 26 | logger=mlxu.WandBLogger.get_default_config(), 27 | ) 28 | 29 | 30 | class LMEvalHarnessInterface(LM): 31 | 32 | def __init__(self, lm_client): 33 | self.lm_client = lm_client 34 | 35 | def greedy_until(self, inputs): 36 | prefix, until = zip(*inputs) 37 | return self.lm_client.greedy_until(prefix, until) 38 | 39 | def loglikelihood_rolling(self, inputs): 40 | loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs) 41 | return list(zip(loglikelihood, is_greedy)) 42 | 43 | def loglikelihood(self, inputs): 44 | prefix, text = zip(*inputs) 45 | loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text) 46 | return list(zip(loglikelihood, is_greedy)) 47 | 48 | 49 | def main(argv): 50 | logger = mlxu.WandBLogger( 51 | config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF) 52 | ) 53 | model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client)) 54 | task_list = FLAGS.tasks.split(',') 55 | results = evaluator.evaluate( 56 | model, tasks.get_task_dict(task_list), False, FLAGS.shots, 57 | limit=None if FLAGS.limit <= 0 else FLAGS.limit, 58 | write_out=FLAGS.write_out, 59 | ) 60 | logger.log(flatten_dict(results['results'], sep='/')) 61 | pprint.pprint(results) 62 | 63 | 64 | if __name__ == "__main__": 65 | mlxu.run(main) 66 | -------------------------------------------------------------------------------- /EasyLM/scripts/lm_eval_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mlxu 3 | from EasyLM.serving import LMClient 4 | 5 | 6 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 7 | input_file='', 8 | output_file='', 9 | prefix_field='prefix', 10 | text_field='text', 11 | until_field='until', 12 | eval_type='loglikelihood', 13 | lm_client=LMClient.get_default_config(), 14 | ) 15 | 16 | 17 | def main(argv): 18 | lm_client = LMClient(FLAGS.lm_client) 19 | with mlxu.open_file(FLAGS.input_file, 'r') as fin: 20 | input_data = json.load(fin) 21 | 22 | if FLAGS.eval_type == 'loglikelihood': 23 | prefix = input_data[FLAGS.prefix_field] 24 | text = input_data[FLAGS.text_field] 25 | loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text) 26 | output_data = { 27 | 'loglikelihood': loglikelihoods, 28 | 'is_greedy': is_greedys, 29 | } 30 | elif FLAGS.eval_type == 'loglikelihood_rolling': 31 | text = input_data[FLAGS.text_field] 32 | loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text) 33 | output_data = { 34 | 'loglikelihood': loglikelihoods, 35 | 'is_greedy': is_greedys, 36 | } 37 | elif FLAGS.eval_type == 'greedy_until': 38 | prefix = input_data[FLAGS.prefix_field] 39 | until = input_data[FLAGS.until_field] 40 | output_data = {'output_text': lm_client.greedy_until(prefix, until)} 41 | elif FLAGS.eval_type == 'generate': 42 | prefix = input_data[FLAGS.prefix_field] 43 | output_data = {'output_text': lm_client.generate(prefix)} 44 | else: 45 | raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}') 46 | 47 | with 
mlxu.open_file(FLAGS.output_file, 'w') as fout: 48 | json.dump(output_data, fout) 49 | 50 | 51 | if __name__ == "__main__": 52 | mlxu.run(main) 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # LVM: Sequential Modeling Enables Scalable Learning for Large Vision Models 3 | 4 | [LVM](https://arxiv.org/abs/2312.00785) is a vision pretraining model that converts various kinds of visual data into visual sentences and performs next-token prediction autoregressively. It is compatible with both GPU and TPU. 5 | 6 | LVM is built on top of [OpenLLaMA](https://github.com/openlm-research/open_llama) (an autoregressive model) and [OpenMuse](https://github.com/huggingface/open-muse) (a VQGAN that converts images into visual tokens). 7 | 8 | This was trained in collaboration with HuggingFace. Thanks [Victor Sanh](https://github.com/VictorSanh) for the support in this project. 9 | 10 | ## Abstract: 11 | 12 | We introduce a novel sequential modeling approach which enables learning a Large Vision Model (LVM) without making use of any linguistic data. 13 | To do this, we define a common format, ``visual sentences", in which we can represent raw images and videos as well as annotated data sources such as semantic segmentations and depth reconstructions without needing any meta-knowledge beyond the pixels. Once this wide variety of visual data (comprising 420 billion tokens) is represented as sequences, the model can be trained to minimize a cross-entropy loss for next token prediction. By training across various scales of model architecture and data diversity, we provide empirical evidence that our models scale effectively. Many different vision tasks can be solved by designing suitable visual prompts at test time. 14 | 15 | ## Visual Sentence 16 | 17 |
18 | ![Visual Sentences](images/visual_sentences.jpg) 19 |
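For intuition, the sketch below shows what a visual sentence looks like at the token level: each 256x256 image is encoded by the VQGAN tokenizer into 256 discrete tokens, and the per-image tokens are concatenated into one sequence. This is only an illustration using the tokenizer helper from `evaluation/vqlm_demo` with placeholder file names; it is not the training data pipeline (see `DATASET.md` and `tokenize_examples/` for that).

```python
# Illustration only: building one visual sentence from a few frames.
# Assumptions: the evaluation/ directory is on PYTHONPATH and the VQ-VAE tokenizer
# weights are available (see evaluation/EVAL.md); the frame paths are placeholders.
import numpy as np
import torch
import einops
from PIL import Image

from vqlm_demo.vqvae_muse import get_tokenizer_muse

frames = np.stack([
    np.array(Image.open(p).convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
    for p in ['frame_0.jpg', 'frame_1.jpg', 'frame_2.jpg']
], axis=0)  # (s, 256, 256, 3), values in [0, 1]

tokenizer = get_tokenizer_muse()  # VQGAN that maps images to discrete tokens
with torch.no_grad():
    pixels = torch.tensor(einops.rearrange(frames, 's h w c -> s c h w'))
    _, token_ids = tokenizer.encode(pixels)  # 256 tokens per 256x256 frame

visual_sentence = token_ids.reshape(-1)  # concatenated tokens; the autoregressive
                                         # backbone does next-token prediction on these
print(visual_sentence.shape)             # e.g. 3 frames -> 768 tokens
```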
20 | 21 | 22 | ## Key Differences from the Original Paper Version 23 | 1. We are currently releasing the 7B model (previously 3B). Additional model size variants will be available later. 24 | 2. Deep filtering (including quality filters, deduplication, and known CSAM content removal) has been applied to the LAION dataset, reducing the dataset size from 1.5B to 1.2B images. 25 | 26 | 3. The tokenizer has been improved for better performance. 27 | 28 | ## License 29 | LVM is licensed under the Apache 2.0 License. 30 | 31 | ## Installation 32 | ```shell 33 | git clone https://github.com/ytongbai/LVM 34 | cd LVM 35 | export PYTHONPATH="${PWD}:$PYTHONPATH" 36 | ``` 37 | 38 | ## Environment Setup 39 | ```shell 40 | conda env create -f scripts/gpu_environment.yml 41 | conda activate LVM 42 | ``` 43 | 44 | ## Dataset Preparation 45 | Please refer to `DATASET.md` for detailed instructions on preparing the dataset. 46 | 47 | After preparing the dataset, you will get a pretokenized file `dataset.jsonl`. 48 | 49 | ## Training Script 50 | 51 | We provide an example training script for the 7B model. For more details about the distributed training setup, please refer to [EasyLM](https://github.com/young-geng/EasyLM). 52 | 53 | For other model sizes, we provide model definitions from 100M, 300M, 600M, 1B, 3B, 7B, 13B, and 20B to 30B in `./EasyLM/models/llama/llama_model.py`. 54 | Set `--total_steps` (and the learning-rate schedule) according to the amount of training data. 55 | 56 | ```shell 57 | python -u -m EasyLM.models.llama.llama_train \ 58 | --jax_distributed.initialize_jax_distributed=True \ 59 | --jax_distributed.coordinator_address='$MASTER_ADDR:$MASTER_PORT' \ 60 | --jax_distributed.local_device_ids='0,1,2,3,4,5,6,7' \ 61 | --mesh_dim='$SLURM_NNODES,-1,1' \ 62 | --dtype='bf16' \ 63 | --total_steps=400000 \ 64 | --log_freq=10 \ 65 | --save_model_freq=1000 \ 66 | --save_milestone_freq=2000 \ 67 | --load_llama_config='vqlm_7b' \ 68 | --optimizer.type='adamw' \ 69 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 70 | --optimizer.adamw_optimizer.lr=1.5e-4 \ 71 | --optimizer.adamw_optimizer.end_lr=3e-5 \ 72 | --optimizer.adamw_optimizer.lr_warmup_steps=8000 \ 73 | --optimizer.adamw_optimizer.lr_decay_steps=288000 \ 74 | --optimizer.accumulate_gradient_steps=4 \ 75 | --train_dataset.type='json' \ 76 | --train_dataset.text_processor.fields=',{tokens},' \ 77 | --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \ 78 | --train_dataset.json_dataset.seq_length=4096 \ 79 | --train_dataset.json_dataset.batch_size=32 \ 80 | --train_dataset.json_dataset.tokenizer_processes=16 \ 81 | --checkpointer.save_optimizer_state=True \ 82 | --logger.online=True \ 83 | --logger.output_dir='/path/to/checkpoint/$RUN_NAME' \ 84 | --logger.wandb_dir='/path/to/wandb' \ 85 | --logger.notes='' \ 86 | --logger.experiment_id=$EXPERIMENT_ID 87 | ``` 88 | 89 | ## Convert to Huggingface checkpoint 90 | 91 | ```shell 92 | python -m EasyLM.models.llama.convert_easylm_to_hf --load_checkpoint='trainstate_params::/path/to/checkpoint/streaming_train_state' --model_size='vqlm_7b' --output_dir='/path/to/output/checkpoint/' 93 | ``` 94 | 95 | ## Demo & Inference 96 | 97 | Download the [few-shot examples dataset](https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/ybai20_jh_edu/Ei0xiLdFFqJPnwAlFWar29EBUAvB0O3CVaJykZl-f11KDQ?e=Bx9SXZ). 98 | 99 | There are mainly two types of visual prompting: sequential prompting and analogy prompting. A minimal inference sketch is given below, followed by details on each.
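As a rough local sketch of how prompting can be driven programmatically — assuming the converted Hugging Face checkpoint from the step above and the helper classes under `evaluation/vqlm_demo`; the image paths below are placeholders — one could do:

```python
# Minimal prompting sketch. Assumptions: a converted HF checkpoint directory and
# placeholder image paths; the VQ tokenizer weights must be available as described
# in evaluation/EVAL.md. Run with the evaluation/ directory on PYTHONPATH so that
# the vqlm_demo package is importable.
import numpy as np
from PIL import Image

from vqlm_demo.inference import LocalInferenceModel

def load_frame(path):
    # The evaluation scripts resize every prompt image to 256x256 and scale it to [0, 1].
    img = Image.open(path).convert('RGB').resize((256, 256))
    return np.array(img, dtype=np.float32) / 255.0

# Analogy prompting: (input, annotation) pairs followed by one query image.
# For sequential prompting, simply stack consecutive frames instead.
prompt = np.stack([load_frame(p) for p in [
    'examples/img1.jpg', 'examples/img1_label.jpg',
    'examples/img2.jpg', 'examples/img2_label.jpg',
    'examples/query.jpg',
]], axis=0)  # shape: (n_frames, 256, 256, 3)

model = LocalInferenceModel(checkpoint='/path/to/output/checkpoint/', torch_device='cuda')
# One input sequence, one generated frame, four sampled candidates.
candidates = model([prompt], n_new_frames=1, n_candidates=4, temperature=1.0, top_p=1.0)[0]
for i, frames in enumerate(candidates):
    Image.fromarray((frames[0] * 255).astype(np.uint8)).save(f'prediction_{i}.png')
```

Increasing `n_candidates` and tuning `temperature`/`top_p` trades diversity against fidelity; `evaluation/vqlm_demo/batch_generation.py` wraps the same interface for whole jsonl files.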
100 | 101 | ### Analogy Prompting: 102 | Describe the task with few-shot examples, i.e. pairs of (x, y) inputs where x is the input image and y is the "annotated" image, and append one query image at the end. We provide more few-shot examples at [this link](https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/ybai20_jh_edu/Ei0xiLdFFqJPnwAlFWar29EBUAvB0O3CVaJykZl-f11KDQ?e=Bx9SXZ); you can simply change the query image at the end for testing. 103 | 104 | ### Sequential Prompting: 105 | Input a sequence of continuous frames and let the model generate the next one. 106 | 107 | 108 | Check out our demo and additional inference code on HuggingFace Spaces: [LVM Demo](https://huggingface.co/spaces/Emma02/LVM) 109 | 110 | 111 | 112 | ## Evaluation 113 | 114 | See evaluation/EVAL.md for evaluation instructions. 115 | 116 | ## Models 117 | - [LVM Checkpoints](https://huggingface.co/Emma02/LVM_ckpts) 118 | - [VQ-VAE Checkpoints](https://huggingface.co/Emma02/vqvae_ckpts) 119 | 120 | 121 | ## Finetuning 122 | 123 | LVM is a pretrained model without instruction tuning or other kinds of post-training. If you want to adapt it to a specific task, we recommend organizing your data into the visual sentence format and then finetuning with a smaller learning rate using the training script we provide. 124 | 125 | ## Citation 126 | If you find LVM useful in your research or applications, please cite our work using the following BibTeX: 127 | 128 | ```bibtex 129 | @article{bai2023sequential, 130 | title={Sequential modeling enables scalable learning for large vision models}, 131 | author={Bai, Yutong and Geng, Xinyang and Mangalam, Karttikeya and Bar, Amir and Yuille, Alan and Darrell, Trevor and Malik, Jitendra and Efros, Alexei A}, 132 | journal={arXiv preprint arXiv:2312.00785}, 133 | year={2023} 134 | } 135 | 136 | ``` 137 | -------------------------------------------------------------------------------- /evaluation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/.DS_Store -------------------------------------------------------------------------------- /evaluation/EVAL.md: -------------------------------------------------------------------------------- 1 | # Large Vision Model Evaluation 2 | This is an evaluation demo for the Large Vision Model paper. 3 | 4 | 5 | ### Setting Up with Conda Environment 6 | You can set up the demo using a conda environment. First, you will need to 7 | create a conda environment and install the dependencies: 8 | ```bash 9 | conda env create -f environment.yml 10 | conda activate vqlm_demo 11 | export PYTHONPATH=/path/to/this/repo:$PYTHONPATH 12 | 13 | ``` 14 | 15 | Then you'll need to download the [VQ tokenizer checkpoint file](https://huggingface.co/Emma02/vqvae_ckpts) and put it into ./vqvae_ckpts/ 16 | 17 | 18 | ## Running the Perplexity Evaluation 19 | This repo also contains the perplexity evaluation script. You can run the following 20 | command to evaluate the perplexity of the model: 21 | 22 | ```bash 23 | python -m vqlm_demo.eval_perplexity \ 24 | --input_file=path/to/input_jsonl_file \ 25 | --input_base_dir=base/path/to/add/to/the/input \ 26 | --checkpoint=path/to/checkpoint \ 27 | --batch_size=4 28 | ``` 29 | 30 | This script accepts a jsonl file as input. Each line of the jsonl file 31 | represents a dictionary, corresponding to one example in the evaluation 32 | set.
The dictionary should have two keys: 33 | * input: a list of paths to the input images used as **context to the model**. This list should include the few-shot examples. 34 | * target: a list of paths to the **target images** to evaluate perplexity on. 35 | 36 | Here's an example of the jsonl format: 37 | ```json 38 | {"input": ["path/to/input1.jpg", "path/to/input2.jpg", "path/to/input3.jpg"], 39 | "target": ["path/to/target1.jpg", "path/to/target2.jpg", "path/to/target3.jpg"]} 40 | ``` 41 | 42 | When evaluating, the script runs the model and computes the average perplexity 43 | on the evaluation set. 44 | 45 | 46 | 47 | 48 | ## Running the batch generation evaluation 49 | This repo also contains the script to batch-generate images from the model. You 50 | can run the following command to generate images from the model: 51 | 52 | ```bash 53 | python -m vqlm_demo.batch_generation \ 54 | --checkpoint=path/to/checkpoint \ 55 | --input_file=path/to/input_jsonl_file \ 56 | --input_base_dir=base/path/to/add/to/input/path/in/jsonl \ 57 | --output_base_dir=base/path/to/add/to/output/path/in/jsonl \ 58 | --n_new_frames=1 \ 59 | --n_candidates=4 \ 60 | --resize_output='original' 61 | ``` 62 | 63 | This script accepts a jsonl file as input. Each line of the jsonl file 64 | represents a dictionary, corresponding to one example in the evaluation 65 | set. The dictionary should have two keys: 66 | * input: a list of paths to the input images used as **context to the model**. This list should include the few-shot examples. 67 | * output: a string giving the output path for the generated image. 68 | 69 | Here's an example of the jsonl format: 70 | ```json 71 | {"input": ["path/to/input1.jpg", "path/to/input2.jpg", "path/to/input3.jpg"], 72 | "output": "path/to/output.jpg"} 73 | ``` -------------------------------------------------------------------------------- /evaluation/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /evaluation/environment.yml: -------------------------------------------------------------------------------- 1 | name: vqlm_demo 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | - numpy 8 | - scipy 9 | - matplotlib 10 | - seaborn 11 | - jupyter 12 | - tqdm 13 | - pillow 14 | - pip: 15 | - --extra-index-url https://download.pytorch.org/whl/cu118 16 | - transformers==4.34.1 17 | - torch==2.0.1 18 | - einops 19 | - absl-py 20 | - ml_collections 21 | - requests 22 | - mlxu==0.1.11 23 | - pydantic 24 | - fastapi 25 | - uvicorn 26 | - gradio 27 | - opencv-python-headless 28 | - scikit-video 29 | - scikit-image 30 | - natsort 31 | -------------------------------------------------------------------------------- /evaluation/vqlm_demo/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/.DS_Store -------------------------------------------------------------------------------- /evaluation/vqlm_demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/__init__.py -------------------------------------------------------------------------------- /evaluation/vqlm_demo/batch_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Batch generation for sequences of images. This script accepts a jsonl file 3 | as input. Each line of the jsonl file represents a dictionary. Each line 4 | represents one example in the evaluation set. The dictionary should have two keys: 5 | 6 | input: a list of paths to the input images as context to the model. 7 | output: a string representing the path to the output of generation to be saved. 8 | 9 | This script runs the model to generate the output images, concatenates the 10 | input and output images together, and saves them to the output path. 11 | """ 12 | 13 | import os 14 | import json 15 | from PIL import Image 16 | import numpy as np 17 | import mlxu 18 | from tqdm import tqdm, trange 19 | from multiprocessing import Pool 20 | import einops 21 | import torch 22 | 23 | from .inference import MultiProcessInferenceModel 24 | from .utils import read_image_to_tensor, MultiProcessImageSaver 25 | 26 | 27 | FLAGS, _ = mlxu.define_flags_with_default( 28 | input_file='', 29 | checkpoint='', 30 | input_base_dir='', 31 | output_base_dir='', 32 | evaluate_mse=False, 33 | json_input_key='input', 34 | json_output_key='output', 35 | json_target_key='target', 36 | n_new_frames=1, 37 | n_candidates=2, 38 | context_frames=16, 39 | temperature=1.0, 40 | top_p=1.0, 41 | n_workers=8, 42 | dtype='float16', 43 | torch_devices='', 44 | batch_size_factor=4, 45 | max_examples=0, 46 | resize_output='', 47 | include_input=False, 48 | ) 49 | 50 | # Dataset over the jsonl records described above (input image paths and output path).
51 | class MultiFrameDataset(torch.utils.data.Dataset): 52 | def __init__(self, input_files, output_files, target_files=None): 53 | assert len(input_files) 54 | self.input_files = input_files 55 | self.output_files = output_files 56 | self.target_files = target_files 57 | 58 | def __len__(self): 59 | return len(self.input_files) 60 | 61 | def __getitem__(self, idx): 62 | original_size = Image.open(self.input_files[idx][-1]).size 63 | input_images = np.stack( 64 | [read_image_to_tensor(f) for f in self.input_files[idx]], 65 | axis=0 66 | ) 67 | 68 | if self.target_files is not None: 69 | target_images = np.stack( 70 | [read_image_to_tensor(f) for f in self.target_files[idx]], 71 | axis=0 72 | ) 73 | else: 74 | target_images = None 75 | return input_images, target_images, self.output_files[idx], np.array(original_size) 76 | 77 | 78 | def main(_): 79 | assert FLAGS.checkpoint != '' 80 | 81 | print(f'Loading checkpoint from {FLAGS.checkpoint}') 82 | print(f'Evaluating input file from {FLAGS.input_file}') 83 | 84 | # build a model. 85 | 86 | model = MultiProcessInferenceModel( 87 | checkpoint=FLAGS.checkpoint, 88 | torch_devices=FLAGS.torch_devices, 89 | dtype=FLAGS.dtype, 90 | context_frames=FLAGS.context_frames, 91 | use_lock=True, 92 | ) 93 | 94 | # input_files: the json file that needs to be generated by the other file. 95 | input_files = [] 96 | output_files = [] 97 | 98 | if FLAGS.evaluate_mse: 99 | target_files = [] 100 | else: 101 | target_files = None 102 | 103 | with mlxu.open_file(FLAGS.input_file, 'r') as f: 104 | for line in f: 105 | record = json.loads(line) 106 | input_files.append(record[FLAGS.json_input_key]) 107 | output_files.append(record[FLAGS.json_output_key]) 108 | if FLAGS.evaluate_mse: 109 | target_files.append(record[FLAGS.json_target_key]) 110 | 111 | 112 | if FLAGS.max_examples > 0: 113 | input_files = input_files[:FLAGS.max_examples] 114 | output_files = output_files[:FLAGS.max_examples] 115 | if FLAGS.evaluate_mse: 116 | target_files = target_files[:FLAGS.max_examples] 117 | 118 | if FLAGS.input_base_dir != '': 119 | input_files = [ 120 | [os.path.join(FLAGS.input_base_dir, x) for x in y] 121 | for y in input_files 122 | ] 123 | if FLAGS.evaluate_mse: 124 | target_files = [ 125 | [os.path.join(FLAGS.input_base_dir, x) for x in y] 126 | for y in target_files 127 | ] 128 | 129 | if FLAGS.output_base_dir != '': 130 | os.makedirs(FLAGS.output_base_dir, exist_ok=True) 131 | output_files = [ 132 | os.path.join(FLAGS.output_base_dir, x) 133 | for x in output_files 134 | ] 135 | 136 | dataset = MultiFrameDataset(input_files, output_files, target_files) 137 | 138 | data_loader = torch.utils.data.DataLoader( 139 | dataset, 140 | batch_size=FLAGS.batch_size_factor * model.n_processes, 141 | shuffle=False, 142 | num_workers=FLAGS.n_workers, 143 | ) 144 | 145 | image_saver = MultiProcessImageSaver(FLAGS.n_workers) 146 | 147 | mses = [] 148 | 149 | for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0): 150 | 151 | # batch_images is input. 
152 | batch_images = batch_images.numpy() 153 | 154 | # 155 | context_length = batch_images.shape[1] 156 | 157 | 158 | generated_images = model( 159 | batch_images, 160 | FLAGS.n_new_frames, 161 | FLAGS.n_candidates, 162 | temperature=FLAGS.temperature, 163 | top_p=FLAGS.top_p 164 | ) 165 | 166 | 167 | repeated_batch = einops.repeat( 168 | batch_images, 169 | 'b s h w c -> b n s h w c', 170 | n=FLAGS.n_candidates, 171 | ) 172 | generated_images = np.array(generated_images) 173 | 174 | if FLAGS.evaluate_mse: 175 | batch_targets = einops.repeat( 176 | batch_targets.numpy(), 177 | 'b s h w c -> b n s h w c', # batch, candidate, s 178 | n=FLAGS.n_candidates, 179 | ) 180 | channels = batch_targets.shape[-1] 181 | # calculate mse loss. 182 | mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5)) 183 | 184 | mses.append(mse * channels) 185 | 186 | 187 | if FLAGS.include_input: 188 | combined = einops.rearrange( 189 | np.concatenate([repeated_batch, generated_images], axis=2), 190 | 'b n s h w c -> b (n h) (s w) c' 191 | ) 192 | else: 193 | combined = einops.rearrange( 194 | generated_images, 195 | 'b n s h w c -> b (n h) (s w) c' 196 | ) 197 | combined = (combined * 255).astype(np.uint8) 198 | 199 | n_frames = FLAGS.n_new_frames 200 | if FLAGS.include_input: 201 | n_frames += context_length 202 | 203 | if FLAGS.resize_output == '': 204 | resizes = None 205 | 206 | elif FLAGS.resize_output == 'original': 207 | resizes = batch_sizes.numpy() 208 | resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]]) 209 | else: 210 | resize = tuple(int(x) for x in FLAGS.resize_output.split(',')) 211 | resizes = np.array([resize] * len(batch_sizes)) 212 | resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]]) 213 | 214 | image_saver(combined, batch_output_files, resizes) 215 | 216 | if FLAGS.evaluate_mse: 217 | mses = np.concatenate(mses, axis=0) 218 | print(f'MSE: {np.mean(mses)}') 219 | 220 | image_saver.close() 221 | 222 | if __name__ == "__main__": 223 | mlxu.run(main) -------------------------------------------------------------------------------- /evaluation/vqlm_demo/eval_perplexity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating the perplexity on few shot tasks. This script accept a jsonl file 3 | as input. Each line of the jsonl file representing a dictionary. Each line 4 | represents one example in the evaluation set. The dictionary should have two key: 5 | 6 | input: a list of paths to the input images as context to the model. This 7 | list should include the few shot examples. 8 | target: a list of paths to the target images to evaluate perplexity 9 | 10 | Ths script should run the model and compute the average perplexity on the 11 | evaluation set. 
12 | """ 13 | 14 | import os 15 | import json 16 | from PIL import Image 17 | import numpy as np 18 | import mlxu 19 | from tqdm import tqdm, trange 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import einops 24 | 25 | from .inference import MultiProcessInferenceModel 26 | 27 | 28 | FLAGS, _ = mlxu.define_flags_with_default( 29 | input_file='', 30 | checkpoint='', 31 | input_base_dir='', 32 | batch_size=2, 33 | json_input_key='input', 34 | json_target_key='target', 35 | dtype='float16', 36 | torch_devices='', 37 | n_workers=4, 38 | max_examples=0, 39 | ) 40 | 41 | 42 | def read_image_to_tensor(path): 43 | pil_im = Image.open(path).convert('RGB') 44 | input_img = pil_im.resize((256, 256)) 45 | input_img = np.array(input_img) / 255.0 46 | input_img = input_img.astype(np.float32) 47 | return input_img 48 | 49 | 50 | class MultiFrameDataset(torch.utils.data.Dataset): 51 | def __init__(self, input_files, target_files): 52 | assert len(input_files) == len(target_files) 53 | self.input_files = input_files 54 | self.target_files = target_files 55 | 56 | def __len__(self): 57 | return len(self.input_files) 58 | 59 | def __getitem__(self, idx): 60 | input_list = np.stack( 61 | [read_image_to_tensor(f) for f in self.input_files[idx]], 62 | axis=0 63 | ) 64 | target_list = np.stack( 65 | [read_image_to_tensor(f) for f in self.target_files[idx]], 66 | axis=0 67 | ) 68 | return input_list, target_list 69 | 70 | 71 | def main(_): 72 | assert FLAGS.checkpoint != '' 73 | 74 | print(f'Loading checkpoint from {FLAGS.checkpoint}') 75 | print(f'Evaluating input file from {FLAGS.input_file}') 76 | 77 | model = MultiProcessInferenceModel( 78 | checkpoint=FLAGS.checkpoint, 79 | torch_devices=FLAGS.torch_devices, 80 | dtype=FLAGS.dtype, 81 | use_lock=True, 82 | perplexity_batch_size=FLAGS.batch_size, 83 | ) 84 | 85 | input_files = [] 86 | target_files = [] 87 | 88 | with mlxu.open_file(FLAGS.input_file, 'r') as f: 89 | for line in f: 90 | record = json.loads(line) 91 | input_files.append(record[FLAGS.json_input_key]) 92 | target_files.append(record[FLAGS.json_target_key]) 93 | 94 | if FLAGS.input_base_dir != '': 95 | input_files = [ 96 | [os.path.join(FLAGS.input_base_dir, x) for x in y] 97 | for y in input_files 98 | ] 99 | target_files = [ 100 | [os.path.join(FLAGS.input_base_dir, x) for x in y] 101 | for y in target_files 102 | ] 103 | 104 | if FLAGS.max_examples > 0: 105 | input_files = input_files[:FLAGS.max_examples] 106 | target_files = target_files[:FLAGS.max_examples] 107 | 108 | dataset = MultiFrameDataset(input_files, target_files) 109 | data_loader = torch.utils.data.DataLoader( 110 | dataset, 111 | batch_size=FLAGS.batch_size * model.n_processes, 112 | shuffle=False, 113 | num_workers=FLAGS.n_workers 114 | ) 115 | 116 | perplexities = [] 117 | 118 | for input_images, target_images in tqdm(data_loader, ncols=0): 119 | perplexity = model.compute_perplexity(input_images, target_images) 120 | perplexities.append(perplexity) 121 | 122 | perplexities = np.concatenate(perplexities, axis=0) 123 | print(f'Perplexity: {np.mean(perplexities)}') 124 | 125 | 126 | if __name__ == "__main__": 127 | mlxu.run(main) -------------------------------------------------------------------------------- /evaluation/vqlm_demo/eval_video_perplexity.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | from functools import partial 5 | from tqdm import tqdm, trange 6 | from multiprocessing import Pool 7 | from PIL import 
Image 8 | import cv2 9 | import mlxu 10 | from natsort import natsorted 11 | import numpy as np 12 | import einops 13 | import torch 14 | 15 | from vqlm_demo.inference import MultiProcessInferenceModel 16 | from vqlm_demo.utils import ( 17 | is_video, random_square_crop, 18 | read_frames_from_dir, read_frames_from_video 19 | ) 20 | 21 | 22 | FLAGS, _ = mlxu.define_flags_with_default( 23 | checkpoint='', 24 | input_files='', 25 | frame_input=False, 26 | read_file_list='', 27 | center_crop=1.0, 28 | n_context_frames=15, 29 | n_target_frames=1, 30 | n_workers=8, 31 | stride=8, 32 | batch_size=2, 33 | torch_devices='', 34 | shuffle=False, 35 | random_start=True, 36 | max_examples=0, 37 | ) 38 | 39 | 40 | class VideoDataset(torch.utils.data.Dataset): 41 | 42 | def __init__(self, videos, frame_input=False, n_context_frames=15, 43 | n_target_frames=1, stride=1): 44 | self.videos = videos 45 | self.frame_input = frame_input 46 | self.n_context_frames = n_context_frames 47 | self.n_target_frames = n_target_frames 48 | self.stride = stride 49 | 50 | def __getitem__(self, index): 51 | if self.frame_input: 52 | frames = read_frames_from_dir( 53 | self.videos[index], 54 | self.n_context_frames + self.n_target_frames, 55 | self.stride, 56 | center_crop=FLAGS.center_crop, 57 | random_start=FLAGS.random_start, 58 | ) 59 | else: 60 | frames = read_frames_from_video( 61 | self.videos[index], 62 | self.n_context_frames + self.n_target_frames, 63 | self.stride, 64 | center_crop=FLAGS.center_crop, 65 | random_start=FLAGS.random_start, 66 | ) 67 | if frames is None: 68 | return self[np.random.randint(0, len(self))] 69 | return frames[:self.n_context_frames], frames[self.n_context_frames:] 70 | 71 | def __len__(self): 72 | return len(self.videos) 73 | 74 | 75 | 76 | def main(_): 77 | assert FLAGS.checkpoint != '' 78 | assert FLAGS.read_file_list != '' or FLAGS.input_files != '' 79 | 80 | model = MultiProcessInferenceModel( 81 | checkpoint=FLAGS.checkpoint, 82 | torch_devices=FLAGS.torch_devices, 83 | perplexity_batch_size=FLAGS.batch_size, 84 | ) 85 | 86 | if FLAGS.read_file_list != '': 87 | with open(FLAGS.read_file_list, 'r') as f: 88 | videos = [x.strip() for x in f.readlines()] 89 | else: 90 | videos = glob.glob(FLAGS.input_files) 91 | 92 | if FLAGS.frame_input: 93 | videos = [x for x in videos if os.path.isdir(x)] 94 | else: 95 | videos = [x for x in videos if is_video(x)] 96 | 97 | if FLAGS.shuffle: 98 | np.random.shuffle(videos) 99 | 100 | if FLAGS.max_examples > 0: 101 | videos = videos[:FLAGS.max_examples] 102 | 103 | dataset = VideoDataset( 104 | videos, 105 | frame_input=FLAGS.frame_input, 106 | n_context_frames=FLAGS.n_context_frames, 107 | n_target_frames=FLAGS.n_target_frames, 108 | stride=FLAGS.stride 109 | ) 110 | dataloader = torch.utils.data.DataLoader( 111 | dataset, 112 | batch_size=FLAGS.batch_size * model.n_processes * 4, 113 | shuffle=False, 114 | num_workers=FLAGS.n_workers, 115 | prefetch_factor=4, 116 | drop_last=True, 117 | ) 118 | 119 | perplexities = [] 120 | 121 | for batch_context_frames, batch_taret_frames in tqdm(dataloader, ncols=0): 122 | batch_context_frames = batch_context_frames.numpy() 123 | batch_taret_frames = batch_taret_frames.numpy() 124 | perplexity = model.compute_perplexity( 125 | batch_context_frames, batch_taret_frames 126 | ) 127 | perplexities.append(perplexity) 128 | 129 | perplexities = np.concatenate(perplexities, axis=0) 130 | print(f'Perplexity: {np.mean(perplexities)}') 131 | 132 | 133 | if __name__ == '__main__': 134 | mlxu.run(main) 
-------------------------------------------------------------------------------- /evaluation/vqlm_demo/eval_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from functools import partial 4 | from tqdm import tqdm, trange 5 | from multiprocessing import Pool 6 | from PIL import Image 7 | import cv2 8 | import mlxu 9 | from natsort import natsorted 10 | import numpy as np 11 | import einops 12 | import torch 13 | 14 | from vqlm_demo.inference import MultiProcessInferenceModel 15 | from vqlm_demo.utils import ( 16 | is_video, random_square_crop, 17 | read_frames_from_dir, read_frames_from_video 18 | ) 19 | 20 | 21 | FLAGS, _ = mlxu.define_flags_with_default( 22 | checkpoint='', 23 | input_files='', 24 | frame_input=False, 25 | read_file_list='', 26 | output_dir='', 27 | center_crop=1.0, 28 | n_context_frames=12, 29 | n_new_frames=4, 30 | n_candidates=8, 31 | temperature=1.0, 32 | top_p=1.0, 33 | n_workers=8, 34 | stride=8, 35 | batch_size=32, 36 | torch_devices='', 37 | shuffle=False, 38 | max_examples=0, 39 | ) 40 | 41 | 42 | def save_image(args): 43 | image, filename = args 44 | base = FLAGS.input_files.split('*')[0] 45 | filename = filename[len(base):].replace('/', '_') + '.png' 46 | Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename)) 47 | 48 | 49 | class VideoDataset(torch.utils.data.Dataset): 50 | 51 | def __init__(self, videos, frame_input=False, n_frames=8, stride=1, new_frames=1): 52 | self.videos = videos 53 | self.frame_input = frame_input 54 | self.n_frames = n_frames 55 | self.stride = stride 56 | self.new_frames = new_frames 57 | 58 | def __getitem__(self, index): 59 | if self.frame_input: 60 | frames = read_frames_from_dir( 61 | self.videos[index], self.n_frames, self.stride, 62 | center_crop=FLAGS.center_crop, 63 | ) 64 | else: 65 | # 's h w c' 66 | frames = read_frames_from_video( 67 | self.videos[index], self.n_frames, self.stride, 68 | center_crop=FLAGS.center_crop, 69 | ) 70 | if frames is None: 71 | return self[np.random.randint(0, len(self))] 72 | 73 | # the last new_frames frames serve as ground-truth targets for generation 74 | target_frames = frames[self.n_frames - self.new_frames:self.n_frames, :, :, :] 75 | 76 | 77 | return frames, target_frames, self.videos[index] 78 | 79 | def __len__(self): 80 | return len(self.videos) 81 | 82 | 83 | 84 | def main(_): 85 | assert FLAGS.checkpoint != '' and FLAGS.output_dir != '' 86 | assert FLAGS.read_file_list != '' or FLAGS.input_files != '' 87 | os.makedirs(FLAGS.output_dir, exist_ok=True) 88 | 89 | if FLAGS.read_file_list != '': 90 | with open(FLAGS.read_file_list, 'r') as f: 91 | videos = [x.strip() for x in f.readlines()] 92 | else: 93 | videos = glob.glob(FLAGS.input_files) 94 | 95 | if FLAGS.frame_input: 96 | videos = [x for x in videos if os.path.isdir(x)] 97 | else: 98 | videos = [x for x in videos if is_video(x)] 99 | 100 | if FLAGS.shuffle: 101 | np.random.shuffle(videos) 102 | 103 | if FLAGS.max_examples > 0: 104 | videos = videos[:FLAGS.max_examples] 105 | 106 | dataset = VideoDataset( 107 | videos, 108 | frame_input=FLAGS.frame_input, 109 | n_frames=FLAGS.n_context_frames, 110 | stride=FLAGS.stride 111 | ) 112 | dataloader = torch.utils.data.DataLoader( 113 | dataset, 114 | batch_size=FLAGS.batch_size, 115 | shuffle=False, 116 | num_workers=FLAGS.n_workers, 117 | prefetch_factor=4, 118 | drop_last=True, 119 | ) 120 | 121 | if FLAGS.torch_devices == '': 122 | torch_devices = None 123 | else: 124 | torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')] 125 | 126 | model = MultiProcessInferenceModel(
127 | checkpoint=FLAGS.checkpoint, torch_devices=torch_devices, 128 | ) 129 | 130 | save_img_pool = Pool(FLAGS.n_workers) 131 | 132 | 133 | fids 134 | 135 | for batch, batch_targets, filenames in tqdm(dataloader, ncols=0): 136 | 137 | batch = batch.numpy() # 'b s h w c ' 138 | 139 | 140 | 141 | generated = model( 142 | batch, 143 | n_new_frames=FLAGS.n_new_frames, 144 | n_candidates=FLAGS.n_candidates, 145 | temperature=FLAGS.temperature, 146 | top_p=FLAGS.top_p, 147 | ) 148 | 149 | 150 | generated = np.array(generated) 151 | 152 | batch_targets = einops.repeat( 153 | batch_targets.numpy(), 154 | 'b s h w c -> b n s h w c', # batch, candidate, sequence, h, w, c. 155 | n=FLAGS.n_candidates, 156 | ) 157 | 158 | 159 | if __name__ == '__main__': 160 | mlxu.run(main) -------------------------------------------------------------------------------- /evaluation/vqlm_demo/generate_videos.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | from functools import partial 5 | from tqdm import tqdm, trange 6 | from multiprocessing import Pool 7 | from PIL import Image 8 | import cv2 9 | import mlxu 10 | from natsort import natsorted 11 | import numpy as np 12 | import einops 13 | import torch 14 | 15 | from vqlm_demo.inference import MultiProcessInferenceModel 16 | from vqlm_demo.utils import ( 17 | is_video, random_square_crop, 18 | read_frames_from_dir, read_frames_from_video 19 | ) 20 | 21 | 22 | FLAGS, _ = mlxu.define_flags_with_default( 23 | checkpoint='', 24 | input_files='', 25 | frame_input=False, 26 | read_file_list='', 27 | output_dir='', 28 | center_crop=1.0, 29 | n_context_frames=12, 30 | n_new_frames=4, 31 | n_candidates=8, 32 | temperature=1.0, 33 | top_p=1.0, 34 | n_workers=8, 35 | stride=8, 36 | batch_size=32, 37 | torch_devices='', 38 | shuffle=False, 39 | max_examples=0, 40 | ) 41 | 42 | 43 | def save_image(args): 44 | image, filename = args 45 | base = FLAGS.input_files.split('*')[0] 46 | filename = filename[len(base):].replace('/', '_') + '.png' 47 | Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename)) 48 | 49 | 50 | class VideoDataset(torch.utils.data.Dataset): 51 | 52 | def __init__(self, videos, frame_input=False, n_frames=8, stride=1): 53 | self.videos = videos 54 | self.frame_input = frame_input 55 | self.n_frames = n_frames 56 | self.stride = stride 57 | 58 | def __getitem__(self, index): 59 | if self.frame_input: 60 | frames = read_frames_from_dir( 61 | self.videos[index], self.n_frames, self.stride, 62 | center_crop=FLAGS.center_crop, 63 | ) 64 | else: 65 | frames = read_frames_from_video( 66 | self.videos[index], self.n_frames, self.stride, 67 | center_crop=FLAGS.center_crop, 68 | ) 69 | if frames is None: 70 | return self[np.random.randint(0, len(self))] 71 | return frames, self.videos[index] 72 | 73 | def __len__(self): 74 | return len(self.videos) 75 | 76 | 77 | 78 | def main(_): 79 | assert FLAGS.checkpoint != '' and FLAGS.output_dir != '' 80 | assert FLAGS.read_file_list != '' or FLAGS.input_files != '' 81 | os.makedirs(FLAGS.output_dir, exist_ok=True) 82 | 83 | if FLAGS.read_file_list != '': 84 | with open(FLAGS.read_file_list, 'r') as f: 85 | videos = [x.strip() for x in f.readlines()] 86 | else: 87 | videos = glob.glob(FLAGS.input_files) 88 | 89 | if FLAGS.frame_input: 90 | videos = [x for x in videos if os.path.isdir(x)] 91 | else: 92 | videos = [x for x in videos if is_video(x)] 93 | 94 | if FLAGS.shuffle: 95 | np.random.shuffle(videos) 96 | 97 | if FLAGS.max_examples > 0: 98 | 
videos = videos[:FLAGS.max_examples] 99 | 100 | dataset = VideoDataset( 101 | videos, 102 | frame_input=FLAGS.frame_input, 103 | n_frames=FLAGS.n_context_frames, 104 | stride=FLAGS.stride 105 | ) 106 | dataloader = torch.utils.data.DataLoader( 107 | dataset, 108 | batch_size=FLAGS.batch_size, 109 | shuffle=False, 110 | num_workers=FLAGS.n_workers, 111 | prefetch_factor=4, 112 | drop_last=True, 113 | ) 114 | 115 | if FLAGS.torch_devices == '': 116 | torch_devices = None 117 | else: 118 | torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')] 119 | 120 | model = MultiProcessInferenceModel( 121 | checkpoint=FLAGS.checkpoint, torch_devices=torch_devices, 122 | ) 123 | 124 | save_img_pool = Pool(FLAGS.n_workers) 125 | 126 | 127 | 128 | for batch, filenames in tqdm(dataloader, ncols=0): 129 | 130 | 131 | 132 | batch = batch.numpy() 133 | 134 | 135 | 136 | generated = model( 137 | batch, 138 | n_new_frames=FLAGS.n_new_frames, 139 | n_candidates=FLAGS.n_candidates, 140 | temperature=FLAGS.temperature, 141 | top_p=FLAGS.top_p, 142 | ) 143 | 144 | 145 | generated = np.array(generated) 146 | 147 | 148 | 149 | 150 | output_batch = einops.repeat( 151 | batch, 152 | 'b s h w c -> b n s h w c', 153 | n=FLAGS.n_candidates, 154 | ) 155 | 156 | 157 | combined = einops.rearrange( 158 | np.concatenate([output_batch, generated], axis=2), 159 | 'b n s h w c -> b (n h) (s w) c' 160 | ) 161 | 162 | 163 | combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8) 164 | save_img_pool.imap(save_image, zip(combined, filenames)) 165 | 166 | 167 | if __name__ == '__main__': 168 | mlxu.run(main) -------------------------------------------------------------------------------- /evaluation/vqlm_demo/inference.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from contextlib import nullcontext 3 | import time 4 | import os 5 | from functools import partial 6 | from copy import deepcopy 7 | from multiprocessing import Pool 8 | from threading import Lock 9 | from PIL import Image 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import einops 14 | from transformers import LlamaForCausalLM 15 | 16 | from .vqvae_muse import VQGANModel, get_tokenizer_muse 17 | from .torch_vqvae_model import get_tokenizer 18 | 19 | 20 | def get_torch_float_dtype(dtype): 21 | if dtype in (torch.float16, torch.bfloat16, torch.float32): 22 | return dtype 23 | return { 24 | 'float16': torch.float16, 25 | 'fp16': torch.float16, 26 | 'f16': torch.float16, 27 | 'bfloat16': torch.bfloat16, 28 | 'bf16': torch.bfloat16, 29 | 'float32': torch.float32, 30 | 'fp32': torch.float32, 31 | 'f32': torch.float32, 32 | }[dtype] 33 | 34 | 35 | def get_pid(): 36 | time.sleep(1) 37 | return os.getpid() 38 | 39 | 40 | class InferenceModel(ABC): 41 | 42 | @abstractmethod 43 | def __call__(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): 44 | raise NotImplementedError() 45 | 46 | 47 | class LocalInferenceModel(InferenceModel): 48 | 49 | def __init__(self, checkpoint, dtype='float16', torch_device='cuda', 50 | context_frames=16, use_lock=False): 51 | self.checkpoint = checkpoint 52 | self.dtype = dtype 53 | self.torch_device = torch_device 54 | self.context_frames = context_frames 55 | 56 | # old version of the tokenizer 57 | # self.tokenizer = get_tokenizer() 58 | # self.tokenizer.to(self.torch_device) 59 | 60 | # new tokenizer 61 | self.tokenizer = get_tokenizer_muse() 62 | self.tokenizer.to(self.torch_device) 63 | 64 | 
self.model = LlamaForCausalLM.from_pretrained( 65 | self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype) 66 | ).to(self.torch_device) 67 | 68 | if use_lock: 69 | self.lock = Lock() 70 | else: 71 | self.lock = nullcontext() 72 | 73 | @torch.no_grad() 74 | def compute_perplexity(self, input_images, target_images): 75 | input_images = np.array(input_images) 76 | target_images = np.array(target_images) 77 | assert len(input_images.shape) == 5 and len(target_images.shape) == 5 # [B, S, H, W, C] 78 | assert input_images.shape[0] == target_images.shape[0] 79 | batch_size = input_images.shape[0] 80 | with self.lock: 81 | input_images = torch.tensor( 82 | einops.rearrange(input_images, 'b s h w c -> b s c h w') 83 | ).to(self.torch_device) 84 | target_images = torch.tensor( 85 | einops.rearrange(target_images, 'b s h w c -> b s c h w') 86 | ).to(self.torch_device) 87 | input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1) 88 | target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1) 89 | all_ids = torch.cat([input_ids, target_ids], dim=1) 90 | logits = self.model(all_ids).logits 91 | log_probs = F.log_softmax(logits, dim=-1) 92 | target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1]) 93 | target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1] 94 | perplexity = torch.exp( 95 | -torch.mean( 96 | torch.sum(target_log_probs * target_ids_onehot, dim=-1), 97 | dim=-1 98 | ) 99 | ) 100 | return perplexity.detach().cpu().numpy() 101 | 102 | @torch.no_grad() 103 | def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0): 104 | assert type(input_images) == np.ndarray 105 | with self.lock: 106 | input_images = np.array(input_images, dtype=np.float32) 107 | input_images = torch.tensor( 108 | einops.rearrange(input_images, 'b h w c -> b c h w') 109 | ).to(self.torch_device) 110 | 111 | print('here:', type(input_images)) 112 | 113 | # old tokenizer 114 | # input_ids = self.tokenizer.tokenize(input_images).view(1, -1) 115 | 116 | # new tokenizer 117 | _, input_ids = self.tokenizer.encode(input_images) 118 | input_ids = input_ids.view(1, -1) 119 | 120 | 121 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:] 122 | 123 | new_tokens = [] 124 | current_context_frames = input_ids.shape[1] // 256 125 | fisrt_generation_left = self.context_frames - current_context_frames 126 | first_new_frames = min(fisrt_generation_left, n_new_frames) 127 | input_ids = self.model.generate( 128 | input_ids=input_ids, 129 | attention_mask=torch.ones_like(input_ids), 130 | pad_token_id=8192, 131 | max_new_tokens=256 * first_new_frames, 132 | do_sample=True, 133 | top_p=top_p, 134 | temperature=temperature, 135 | suppress_tokens=list(range(8192, self.model.vocab_size)), 136 | ) 137 | new_tokens.append(input_ids[:, -256 * first_new_frames:]) 138 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:] 139 | 140 | for _ in range(max(0, n_new_frames - first_new_frames)): 141 | input_ids = self.model.generate( 142 | input_ids=input_ids, 143 | attention_mask=torch.ones_like(input_ids), 144 | pad_token_id=8192, 145 | max_new_tokens=256, 146 | do_sample=True, 147 | top_p=top_p, 148 | temperature=temperature, 149 | suppress_tokens=list(range(8192, self.model.vocab_size)), 150 | ) 151 | new_tokens.append(input_ids[:, -256:]) 152 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:] 153 | 154 | new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256) 155 | new_images = einops.rearrange( 156 | 
torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0), 157 | 'b c h w -> b h w c' 158 | ).detach().cpu().numpy() 159 | return new_images 160 | 161 | def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): 162 | output = [] 163 | for seq in input_images: 164 | output.append( 165 | [self.generate_once(seq, n_new_frames, temperature, top_p) 166 | for _ in range(n_candidates)] 167 | ) 168 | return output 169 | 170 | 171 | class MultiProcessInferenceModel(InferenceModel): 172 | 173 | def __init__(self, checkpoint, torch_devices=None, dtype='float16', 174 | context_frames=16, use_lock=False, perplexity_batch_size=2): 175 | if torch_devices is None or torch_devices == '': 176 | torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())] 177 | 178 | self.torch_devices = torch_devices 179 | self.n_processes = len(torch_devices) 180 | print(f'Using {self.n_processes} processes for inference') 181 | self.worker_pool = Pool(self.n_processes) 182 | self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)]) 183 | self.device_map = { 184 | pid: torch_device 185 | for pid, torch_device in zip(self.worker_pids, self.torch_devices) 186 | } 187 | self.worker_pool.starmap( 188 | self.initialize_worker, 189 | [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)] 190 | ) 191 | self.perplexity_batch_size = perplexity_batch_size 192 | if use_lock: 193 | self.lock = Lock() 194 | else: 195 | self.lock = nullcontext() 196 | 197 | @staticmethod 198 | def initialize_worker(device_map, checkpoint, dtype, context_frames): 199 | global _current_process_backend 200 | torch_device = device_map[os.getpid()] 201 | _current_process_backend = LocalInferenceModel( 202 | checkpoint, dtype, torch_device, context_frames 203 | ) 204 | 205 | @staticmethod 206 | def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0): 207 | return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p) 208 | 209 | @staticmethod 210 | def compute_perplexity_once(input_images, target_images): 211 | return _current_process_backend.compute_perplexity(input_images, target_images) 212 | 213 | def compute_perplexity(self, input_images, target_images): 214 | with self.lock: 215 | map_args = [] 216 | for i in range(0, len(input_images), self.perplexity_batch_size): 217 | map_args.append(( 218 | input_images[i : i + self.perplexity_batch_size], 219 | target_images[i : i + self.perplexity_batch_size] 220 | )) 221 | outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args) 222 | return np.concatenate(outputs, axis=0) 223 | 224 | def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): 225 | with self.lock: 226 | map_args = [] 227 | for seq in input_images: 228 | for _ in range(n_candidates): 229 | map_args.append((seq, n_new_frames, temperature, top_p)) 230 | 231 | outputs = self.worker_pool.starmap(self.generate_once, map_args) 232 | reshaped_output = [] 233 | index = 0 234 | for _ in range(len(input_images)): 235 | candidates = [] 236 | for _ in range(n_candidates): 237 | candidates.append(outputs[index]) 238 | index += 1 239 | reshaped_output.append(candidates) 240 | return reshaped_output 241 | 242 | -------------------------------------------------------------------------------- /evaluation/vqlm_demo/torch_vqvae_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 
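# Standalone PyTorch implementation of the VQ-VAE tokenizer used by the older
# `get_tokenizer()` code path: Encoder -> VectorQuantizer (8192-entry codebook,
# 32-dim embeddings) -> Decoder, mapping a 256x256 RGB image to a 16x16 grid of
# discrete tokens and back.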
import torch.nn as nn 4 | import torch.nn.functional as F 5 | import einops 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | def normalize(in_channels): 10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 11 | 12 | def swish(x): 13 | return x*torch.sigmoid(x) 14 | 15 | class ResBlock(nn.Module): 16 | def __init__(self, in_channels, out_channels=None, activation_fn="relu"): 17 | super(ResBlock, self).__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = in_channels if out_channels is None else out_channels 20 | self.norm1 = normalize(in_channels) 21 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.norm2 = normalize(out_channels) 23 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 24 | if self.in_channels != self.out_channels: 25 | self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 26 | self.activation_fn = activation_fn 27 | if activation_fn=="relu": 28 | self.actn = nn.ReLU() 29 | 30 | 31 | def forward(self, x_in): 32 | x = x_in 33 | x = self.norm1(x) 34 | if self.activation_fn=="relu": 35 | x = self.actn(x) 36 | elif self.activation_fn=="swish": 37 | x = swish(x) 38 | x = self.conv1(x) 39 | x = self.norm2(x) 40 | if self.activation_fn=="relu": 41 | x = self.actn(x) 42 | elif self.activation_fn=="swish": 43 | x = swish(x) 44 | x = self.conv2(x) 45 | if self.in_channels != self.out_channels: 46 | x_in = self.conv_out(x_in) 47 | 48 | return x + x_in 49 | 50 | class Encoder(nn.Module): 51 | def __init__(self, ): 52 | super().__init__() 53 | 54 | self.filters = 128 55 | self.num_res_blocks = 2 56 | self.ch_mult = [1,1,2,2,4] 57 | self.in_ch_mult = (1,)+tuple(self.ch_mult) 58 | self.embedding_dim = 32 59 | self.conv_downsample = False 60 | 61 | self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False) 62 | blocks = [] 63 | for i in range(len(self.ch_mult)): 64 | block_in_ch = self.filters * self.in_ch_mult[i] 65 | block_out_ch = self.filters * self.ch_mult[i] 66 | for _ in range(self.num_res_blocks): 67 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) 68 | block_in_ch = block_out_ch 69 | for _ in range(self.num_res_blocks): 70 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) 71 | self.norm1 = normalize(block_in_ch) 72 | self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0) 73 | self.blocks = nn.ModuleList(blocks) 74 | 75 | def forward(self, x): 76 | x = self.conv1(x) 77 | for i in range(len(self.ch_mult)): 78 | for j in range(self.num_res_blocks): 79 | x = self.blocks[i*2+j](x) 80 | 81 | if i < len(self.ch_mult) -1: 82 | x = torch.nn.functional.avg_pool2d(x, (2,2),(2,2)) 83 | 84 | x = self.blocks[-2](x) 85 | x = self.blocks[-1](x) 86 | 87 | x = self.norm1(x) 88 | x = swish(x) 89 | x = self.conv2(x) 90 | return x 91 | 92 | class VectorQuantizer(nn.Module): 93 | def __init__(self, codebook_size=8192, emb_dim=32, beta=None): 94 | super(VectorQuantizer, self).__init__() 95 | self.codebook_size = codebook_size # number of embeddings 96 | self.emb_dim = emb_dim # dimension of embedding 97 | self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) 98 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) 99 | self.beta=0.0 100 | self.z_dim = emb_dim 101 | 102 | def forward(self, z): 103 | # preprocess 104 | 105 | b, c, 
h, w = z.size() 106 | flatten = z.permute(0, 2, 3, 1).reshape(-1, c) 107 | codebook = self.embedding.weight 108 | with torch.no_grad(): 109 | tokens = torch.cdist(flatten, codebook).argmin(dim=1) 110 | quantized = F.embedding(tokens, 111 | codebook).view(b, h, w, c).permute(0, 3, 1, 2) 112 | 113 | # compute loss 114 | codebook_loss = F.mse_loss(quantized, z.detach()) 115 | commitment_loss = F.mse_loss(quantized.detach(), z) 116 | loss = codebook_loss + self.beta * commitment_loss 117 | 118 | # perplexity 119 | counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) 120 | # dist.all_reduce(counts) 121 | p = counts / counts.sum() 122 | perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) 123 | 124 | # postprocess 125 | tokens = tokens.view(b, h, w) 126 | quantized = z + (quantized - z).detach() 127 | 128 | # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c)) 129 | 130 | return quantized, tokens, loss, perplexity 131 | 132 | 133 | def get_codebook_feat(self, indices, shape=None): 134 | # input indices: batch*token_num -> (batch*token_num)*1 135 | # shape: batch, height, width, channel 136 | indices = indices.view(-1,1) 137 | min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) 138 | min_encodings.scatter_(1, indices, 1) 139 | # get quantized latent vectors 140 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 141 | 142 | if shape is not None: # reshape back to match original input shape 143 | z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() 144 | 145 | return z_q 146 | 147 | 148 | class Decoder(nn.Module): 149 | def __init__(self,): 150 | super().__init__() 151 | self.filters = 128 152 | self.num_res_blocks = 2 153 | self.ch_mult = [1,1,2,2,4] 154 | self.in_ch_mult = (1,)+tuple(self.ch_mult) 155 | self.embedding_dim =32 156 | self.out_channels = 3 157 | self.in_channels = self.embedding_dim 158 | self.conv_downsample = False 159 | 160 | self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1) 161 | blocks = [] 162 | block_in_ch = self.filters * self.ch_mult[-1] 163 | block_out_ch = self.filters * self.ch_mult[-1] 164 | #blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) 165 | for _ in range(self.num_res_blocks): 166 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) 167 | upsample_conv_layers = [] 168 | for i in reversed(range(len(self.ch_mult))): 169 | block_out_ch = self.filters * self.ch_mult[i] 170 | for _ in range(self.num_res_blocks): 171 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) 172 | block_in_ch = block_out_ch 173 | if i > 0: 174 | upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch*4, kernel_size=3, stride=1, padding=1)) 175 | 176 | self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2) 177 | self.norm1 = normalize(block_in_ch) 178 | # self.act_fn 179 | self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1) 180 | self.blocks = nn.ModuleList(blocks) 181 | self.up_convs = nn.ModuleList(upsample_conv_layers) 182 | 183 | def forward(self, x): 184 | x = self.conv1(x) 185 | x = self.blocks[0](x) 186 | x = self.blocks[1](x) 187 | for i in range(len(self.ch_mult)): 188 | for j in range(self.num_res_blocks): 189 | x = self.blocks[2+i*2+j](x) 190 | if i < len(self.ch_mult)-1: 191 | x = self.up_convs[i](x) 192 | #print("pre: x.size()",x.size()) 193 | x = x.permute(0,2,3,1) 194 | x = self.upsample(x) 195 | x = x.permute(0,3,1,2) 196 | 
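                # `self.up_convs[i]` has expanded the channel dim by 4x; the einops Rearrange in
                # `self.upsample` unpacks those channels into a 2x2 spatial block (pixel-shuffle
                # style upsampling), hence the NCHW -> NHWC permute before it and back afterwards.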
#print("post: x.size()", x.size()) 197 | x = self.norm1(x) 198 | x = swish(x) 199 | x = self.conv6(x) 200 | return x 201 | 202 | 203 | class VQVAE(nn.Module): 204 | def __init__(self, ): 205 | super().__init__() 206 | self.encoder = Encoder() 207 | self.quantizer = VectorQuantizer() 208 | self.decoder = Decoder() 209 | 210 | def forward(self, x): 211 | x = self.encoder(x) 212 | quant,tokens, loss, perplexity = self.quantizer(x) 213 | x = self.decoder(quant) 214 | return x 215 | 216 | def tokenize(self, x): 217 | batch_shape = x.shape[:-3] 218 | x = x.reshape(-1, *x.shape[-3:]) 219 | x = self.encoder(x) 220 | quant,tokens, loss, perplexity = self.quantizer(x) 221 | return tokens.reshape(*batch_shape, *tokens.shape[1:]) 222 | 223 | def decode(self, tokens): 224 | tokens = einops.rearrange(tokens, 'b ... -> b (...)') 225 | b = tokens.shape[0] 226 | if tokens.shape[-1] == 256: 227 | hw = 16 228 | elif tokens.shape[-1] == 224: 229 | hw = 14 230 | else: 231 | raise ValueError("Invalid tokens shape") 232 | quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32)) 233 | x = self.decoder(quant) 234 | return x 235 | 236 | 237 | class VAEDecoder(nn.Module): 238 | def __init__(self, ): 239 | super().__init__() 240 | self.quantizer = VectorQuantizer() 241 | self.decoder = Decoder() 242 | 243 | def forward(self, x): 244 | quant = self.quantizer.get_codebook_feat(x,(1,14,14,32)) 245 | x = self.decoder(quant) 246 | return x 247 | 248 | 249 | def get_tokenizer(): 250 | checkpoint_path = os.path.join( 251 | os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth" 252 | ) 253 | torch_state_dict = torch.load(checkpoint_path) 254 | net = VQVAE() 255 | net.load_state_dict(torch_state_dict) 256 | return net 257 | 258 | -------------------------------------------------------------------------------- /evaluation/vqlm_demo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool 3 | import numpy as np 4 | import random 5 | from PIL import Image 6 | import re 7 | import cv2 8 | import glob 9 | from natsort import natsorted 10 | 11 | 12 | class MultiProcessImageSaver(object): 13 | 14 | def __init__(self, n_workers=1): 15 | self.pool = Pool(n_workers) 16 | 17 | def __call__(self, images, output_files, resizes=None): 18 | if resizes is None: 19 | resizes = [None for _ in range(len(images))] 20 | return self.pool.imap( 21 | self.save_image, 22 | zip(images, output_files, resizes), 23 | ) 24 | 25 | def close(self): 26 | self.pool.close() 27 | self.pool.join() 28 | 29 | @staticmethod 30 | def save_image(args): 31 | image, filename, resize = args 32 | image = Image.fromarray(image) 33 | if resize is not None: 34 | image = image.resize(tuple(resize)) 35 | image.save(filename) 36 | 37 | 38 | def list_dir_with_full_path(path): 39 | return [os.path.join(path, f) for f in os.listdir(path)] 40 | 41 | 42 | def find_all_files_in_dir(path): 43 | files = [] 44 | for root, _, files in os.walk(path): 45 | for file in files: 46 | files.append(os.path.join(root, file)) 47 | return files 48 | 49 | 50 | def is_image(path): 51 | return ( 52 | path.endswith('.jpg') 53 | or path.endswith('.png') 54 | or path.endswith('.jpeg') 55 | or path.endswith('.JPG') 56 | or path.endswith('.PNG') 57 | or path.endswith('.JPEG') 58 | ) 59 | 60 | 61 | def is_video(path): 62 | return ( 63 | path.endswith('.mp4') 64 | or path.endswith('.avi') 65 | or path.endswith('.MP4') 66 | or path.endswith('.AVI') 67 | or path.endswith('.webm') 68 | or path.endswith('.WEBM') 69 
| or path.endswith('.mkv') 70 | or path.endswith('.MVK') 71 | ) 72 | 73 | 74 | def random_square_crop(img, random_generator=None): 75 | # If no random generator is provided, use numpy's default 76 | if random_generator is None: 77 | random_generator = np.random.default_rng() 78 | 79 | # Get the width and height of the image 80 | width, height = img.size 81 | 82 | # Determine the shorter side 83 | min_size = min(width, height) 84 | 85 | # Randomly determine the starting x and y coordinates for the crop 86 | if width > height: 87 | left = random_generator.integers(0, width - min_size) 88 | upper = 0 89 | else: 90 | left = 0 91 | upper = random_generator.integers(0, height - min_size) 92 | 93 | # Calculate the ending x and y coordinates for the crop 94 | right = left + min_size 95 | lower = upper + min_size 96 | 97 | # Crop the image 98 | return img.crop((left, upper, right, lower)) 99 | 100 | 101 | def read_image_to_tensor(path, center_crop=1.0): 102 | pil_im = Image.open(path).convert('RGB') 103 | if center_crop < 1.0: 104 | width, height = pil_im.size 105 | pil_im = pil_im.crop(( 106 | int((1 - center_crop) * height / 2), int((1 + center_crop) * height / 2), 107 | int((1 - center_crop) * width / 2), int((1 + center_crop) * width / 2), 108 | )) 109 | input_img = pil_im.resize((256, 256)) 110 | input_img = np.array(input_img) / 255.0 111 | input_img = input_img.astype(np.float32) 112 | return input_img 113 | 114 | 115 | def match_mulitple_path(root_dir, regex): 116 | videos = [] 117 | for root, _, files in os.walk(root_dir): 118 | for file in files: 119 | videos.append(os.path.join(root, file)) 120 | 121 | videos = [v for v in videos if not v.split('/')[-1].startswith('.')] 122 | 123 | grouped_path = {} 124 | for r in regex: 125 | r = re.compile(r) 126 | for v in videos: 127 | matched = r.findall(v) 128 | if len(matched) > 0: 129 | groups = matched[0] 130 | if groups not in grouped_path: 131 | grouped_path[groups] = [] 132 | grouped_path[groups].append(v) 133 | 134 | grouped_path = { 135 | k: tuple(v) for k, v in grouped_path.items() 136 | if len(v) == len(regex) 137 | } 138 | return list(grouped_path.values()) 139 | 140 | 141 | def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True): 142 | assert length >= n_frames 143 | max_stride = min( 144 | (length - 1) // (n_frames - 1), 145 | max_stride 146 | ) 147 | stride = np.random.randint(1, max_stride + 1) 148 | if random_start: 149 | start = np.random.randint(0, length - (n_frames - 1) * stride) 150 | else: 151 | start = 0 152 | return np.arange(n_frames) * stride + start 153 | 154 | 155 | def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0): 156 | files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)] 157 | files = natsorted([x for x in files if is_image(x)]) 158 | 159 | total_frames = len(files) 160 | 161 | if total_frames < n_frames: 162 | return None 163 | 164 | max_stride = (total_frames - 1) // (n_frames - 1) 165 | stride = min(max_stride, stride) 166 | 167 | if random_start: 168 | start = np.random.randint(0, total_frames - (n_frames - 1) * stride) 169 | else: 170 | start = 0 171 | frame_indices = np.arange(n_frames) * stride + start 172 | 173 | frames = [] 174 | for frame_index in sorted(frame_indices): 175 | # Check if the frame_index is valid 176 | frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop)) 177 | if len(frames) < n_frames: 178 | return None 179 | frames = np.stack(frames, axis=0) 180 | return frames 181 | 182 | 183 | 
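# read_frames_from_video below mirrors read_frames_from_dir, but pulls frames directly from a
# video file via OpenCV: it subsamples n_frames indices (capping the requested stride so the
# clip fits), optionally center-crops, resizes each frame to 256x256, rescales to float32 in
# [0, 1] and converts BGR -> RGB. It returns an array of shape (n_frames, 256, 256, 3), or
# None when the video does not contain enough frames.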
def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0): 184 | 185 | frames = [] 186 | cap = cv2.VideoCapture(video_path) 187 | 188 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 189 | 190 | if total_frames < n_frames: 191 | cap.release() 192 | return None 193 | 194 | max_stride = (total_frames - 1) // (n_frames - 1) 195 | stride = min(max_stride, stride) 196 | 197 | if random_start: 198 | start = np.random.randint(0, total_frames - (n_frames - 1) * stride) 199 | else: 200 | start = 0 201 | frame_indices = np.arange(n_frames) * stride + start 202 | 203 | for frame_index in sorted(frame_indices): 204 | # Check if the frame_index is valid 205 | if 0 <= frame_index < total_frames: 206 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 207 | ret, frame = cap.read() 208 | if ret: 209 | if center_crop < 1.0: 210 | height, width, _ = frame.shape 211 | frame = frame[ 212 | int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), 213 | int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), 214 | : 215 | ] 216 | frame = cv2.resize(frame, (256, 256)) 217 | 218 | frames.append(frame) 219 | 220 | else: 221 | print(f"Frame index {frame_index} is out of bounds. Skipping...") 222 | 223 | cap.release() 224 | if len(frames) < n_frames: 225 | return None 226 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 227 | 228 | # From BGR to RGB 229 | return np.stack( 230 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 231 | ) 232 | 233 | 234 | def read_all_frames_from_video(video_path, center_crop=1.0): 235 | 236 | frames = [] 237 | cap = cv2.VideoCapture(video_path) 238 | 239 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 240 | 241 | 242 | for frame_index in range(total_frames): 243 | # Check if the frame_index is valid 244 | if 0 <= frame_index < total_frames: 245 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 246 | ret, frame = cap.read() 247 | if ret: 248 | if center_crop < 1.0: 249 | height, width, _ = frame.shape 250 | frame = frame[ 251 | int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), 252 | int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), 253 | : 254 | ] 255 | frames.append(cv2.resize(frame, (256, 256))) 256 | else: 257 | print(f"Frame index {frame_index} is out of bounds. 
Skipping...") 258 | 259 | cap.release() 260 | if len(frames) == 0: 261 | return None 262 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 263 | # From BGR to RGB 264 | return np.stack( 265 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 266 | ) 267 | 268 | 269 | def read_max_span_frames_from_video(video_path, n_frames): 270 | frames = [] 271 | cap = cv2.VideoCapture(video_path) 272 | 273 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 274 | if total_frames < n_frames: 275 | cap.release() 276 | return None 277 | stride = (total_frames - 1) // (n_frames - 1) 278 | frame_indices = np.arange(n_frames) * stride 279 | 280 | frames = [] 281 | for frame_index in frame_indices: 282 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 283 | ret, frame = cap.read() 284 | if ret: 285 | frames.append(cv2.resize(frame, (256, 256))) 286 | 287 | cap.release() 288 | if len(frames) < n_frames: 289 | return None 290 | 291 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 292 | # From BGR to RGB 293 | return np.stack( 294 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 295 | ) 296 | 297 | -------------------------------------------------------------------------------- /evaluation/vqlm_demo/vqvae/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/vqvae/.DS_Store -------------------------------------------------------------------------------- /evaluation/vqlm_demo/vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | __version__ = "0.0.1" 17 | 18 | # from .modeling_ema import EMAModel 19 | # from .modeling_maskgit_vqgan import MaskGitVQGAN 20 | # from .modeling_movq import MOVQ 21 | # from .modeling_paella_vq import PaellaVQModel 22 | # from .modeling_utils import VQGANModel 23 | # from .modeling_transformer import MaskGitTransformer, MaskGiTUViT 24 | # from .pipeline_muse import PipelineMuse, PipelineMuseInpainting 25 | # from .sampling import get_mask_chedule 26 | -------------------------------------------------------------------------------- /evaluation/vqlm_demo/vqvae/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities.""" 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | from tqdm import auto as tqdm_lib 32 | 33 | _lock = threading.Lock() 34 | _default_handler: Optional[logging.Handler] = None 35 | 36 | log_levels = { 37 | "debug": logging.DEBUG, 38 | "info": logging.INFO, 39 | "warning": logging.WARNING, 40 | "error": logging.ERROR, 41 | "critical": logging.CRITICAL, 42 | } 43 | 44 | _default_log_level = logging.WARNING 45 | 46 | _tqdm_active = True 47 | 48 | 49 | def _get_default_logging_level(): 50 | """ 51 | If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is 52 | not - fall back to `_default_log_level` 53 | """ 54 | env_level_str = os.getenv("muse_VERBOSITY", None) 55 | if env_level_str: 56 | if env_level_str in log_levels: 57 | return log_levels[env_level_str] 58 | else: 59 | logging.getLogger().warning( 60 | f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" 61 | ) 62 | return _default_log_level 63 | 64 | 65 | def _get_library_name() -> str: 66 | return __name__.split(".")[0] 67 | 68 | 69 | def _get_library_root_logger() -> logging.Logger: 70 | return logging.getLogger(_get_library_name()) 71 | 72 | 73 | def _configure_library_root_logger() -> None: 74 | global _default_handler 75 | 76 | with _lock: 77 | if _default_handler: 78 | # This library has already configured the library root logger. 79 | return 80 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 81 | _default_handler.flush = sys.stderr.flush 82 | 83 | # Apply our default configuration to the library root logger. 84 | library_root_logger = _get_library_root_logger() 85 | library_root_logger.addHandler(_default_handler) 86 | library_root_logger.setLevel(_get_default_logging_level()) 87 | library_root_logger.propagate = False 88 | 89 | 90 | def _reset_library_root_logger() -> None: 91 | global _default_handler 92 | 93 | with _lock: 94 | if not _default_handler: 95 | return 96 | 97 | library_root_logger = _get_library_root_logger() 98 | library_root_logger.removeHandler(_default_handler) 99 | library_root_logger.setLevel(logging.NOTSET) 100 | _default_handler = None 101 | 102 | 103 | def get_log_levels_dict(): 104 | return log_levels 105 | 106 | 107 | def get_logger(name: Optional[str] = None) -> logging.Logger: 108 | """ 109 | Return a logger with the specified name. 110 | 111 | This function is not supposed to be directly accessed unless you are writing a custom muse module. 112 | """ 113 | 114 | if name is None: 115 | name = _get_library_name() 116 | 117 | _configure_library_root_logger() 118 | return logging.getLogger(name) 119 | 120 | 121 | def get_verbosity() -> int: 122 | """ 123 | Return the current level for the 🤗 muse' root logger as an int. 124 | 125 | Returns: 126 | `int`: The logging level. 
127 | 128 | 129 | 130 | 🤗 muse has following logging levels: 131 | 132 | - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL` 133 | - 40: `muse.logging.ERROR` 134 | - 30: `muse.logging.WARNING` or `muse.logging.WARN` 135 | - 20: `muse.logging.INFO` 136 | - 10: `muse.logging.DEBUG` 137 | 138 | """ 139 | 140 | _configure_library_root_logger() 141 | return _get_library_root_logger().getEffectiveLevel() 142 | 143 | 144 | def set_verbosity(verbosity: int) -> None: 145 | """ 146 | Set the verbosity level for the 🤗 muse' root logger. 147 | 148 | Args: 149 | verbosity (`int`): 150 | Logging level, e.g., one of: 151 | 152 | - `muse.logging.CRITICAL` or `muse.logging.FATAL` 153 | - `muse.logging.ERROR` 154 | - `muse.logging.WARNING` or `muse.logging.WARN` 155 | - `muse.logging.INFO` 156 | - `muse.logging.DEBUG` 157 | """ 158 | 159 | _configure_library_root_logger() 160 | _get_library_root_logger().setLevel(verbosity) 161 | 162 | 163 | def set_verbosity_info(): 164 | """Set the verbosity to the `INFO` level.""" 165 | return set_verbosity(INFO) 166 | 167 | 168 | def set_verbosity_warning(): 169 | """Set the verbosity to the `WARNING` level.""" 170 | return set_verbosity(WARNING) 171 | 172 | 173 | def set_verbosity_debug(): 174 | """Set the verbosity to the `DEBUG` level.""" 175 | return set_verbosity(DEBUG) 176 | 177 | 178 | def set_verbosity_error(): 179 | """Set the verbosity to the `ERROR` level.""" 180 | return set_verbosity(ERROR) 181 | 182 | 183 | def disable_default_handler() -> None: 184 | """Disable the default handler of the HuggingFace muse' root logger.""" 185 | 186 | _configure_library_root_logger() 187 | 188 | assert _default_handler is not None 189 | _get_library_root_logger().removeHandler(_default_handler) 190 | 191 | 192 | def enable_default_handler() -> None: 193 | """Enable the default handler of the HuggingFace muse' root logger.""" 194 | 195 | _configure_library_root_logger() 196 | 197 | assert _default_handler is not None 198 | _get_library_root_logger().addHandler(_default_handler) 199 | 200 | 201 | def add_handler(handler: logging.Handler) -> None: 202 | """adds a handler to the HuggingFace muse' root logger.""" 203 | 204 | _configure_library_root_logger() 205 | 206 | assert handler is not None 207 | _get_library_root_logger().addHandler(handler) 208 | 209 | 210 | def remove_handler(handler: logging.Handler) -> None: 211 | """removes given handler from the HuggingFace muse' root logger.""" 212 | 213 | _configure_library_root_logger() 214 | 215 | assert handler is not None and handler not in _get_library_root_logger().handlers 216 | _get_library_root_logger().removeHandler(handler) 217 | 218 | 219 | def disable_propagation() -> None: 220 | """ 221 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 222 | """ 223 | 224 | _configure_library_root_logger() 225 | _get_library_root_logger().propagate = False 226 | 227 | 228 | def enable_propagation() -> None: 229 | """ 230 | Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent 231 | double logging if the root logger has been configured. 232 | """ 233 | 234 | _configure_library_root_logger() 235 | _get_library_root_logger().propagate = True 236 | 237 | 238 | def enable_explicit_format() -> None: 239 | """ 240 | Enable explicit formatting for every HuggingFace muse' logger. 
The explicit formatter is as follows: 241 | ``` 242 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 243 | ``` 244 | All handlers currently bound to the root logger are affected by this method. 245 | """ 246 | handlers = _get_library_root_logger().handlers 247 | 248 | for handler in handlers: 249 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 250 | handler.setFormatter(formatter) 251 | 252 | 253 | def reset_format() -> None: 254 | """ 255 | Resets the formatting for HuggingFace muse' loggers. 256 | 257 | All handlers currently bound to the root logger are affected by this method. 258 | """ 259 | handlers = _get_library_root_logger().handlers 260 | 261 | for handler in handlers: 262 | handler.setFormatter(None) 263 | 264 | 265 | def warning_advice(self, *args, **kwargs): 266 | """ 267 | This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this 268 | warning will not be printed 269 | """ 270 | no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False) 271 | if no_advisory_warnings: 272 | return 273 | self.warning(*args, **kwargs) 274 | 275 | 276 | logging.Logger.warning_advice = warning_advice 277 | 278 | 279 | class EmptyTqdm: 280 | """Dummy tqdm which doesn't do anything.""" 281 | 282 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument 283 | self._iterator = args[0] if args else None 284 | 285 | def __iter__(self): 286 | return iter(self._iterator) 287 | 288 | def __getattr__(self, _): 289 | """Return empty function.""" 290 | 291 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument 292 | return 293 | 294 | return empty_fn 295 | 296 | def __enter__(self): 297 | return self 298 | 299 | def __exit__(self, type_, value, traceback): 300 | return 301 | 302 | 303 | class _tqdm_cls: 304 | def __call__(self, *args, **kwargs): 305 | if _tqdm_active: 306 | return tqdm_lib.tqdm(*args, **kwargs) 307 | else: 308 | return EmptyTqdm(*args, **kwargs) 309 | 310 | def set_lock(self, *args, **kwargs): 311 | self._lock = None 312 | if _tqdm_active: 313 | return tqdm_lib.tqdm.set_lock(*args, **kwargs) 314 | 315 | def get_lock(self): 316 | if _tqdm_active: 317 | return tqdm_lib.tqdm.get_lock() 318 | 319 | 320 | tqdm = _tqdm_cls() 321 | 322 | 323 | def is_progress_bar_enabled() -> bool: 324 | """Return a boolean indicating whether tqdm progress bars are enabled.""" 325 | global _tqdm_active 326 | return bool(_tqdm_active) 327 | 328 | 329 | def enable_progress_bar(): 330 | """Enable tqdm progress bar.""" 331 | global _tqdm_active 332 | _tqdm_active = True 333 | 334 | 335 | def disable_progress_bar(): 336 | """Disable tqdm progress bar.""" 337 | global _tqdm_active 338 | _tqdm_active = False 339 | -------------------------------------------------------------------------------- /images/visual_sentences.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/images/visual_sentences.jpg -------------------------------------------------------------------------------- /scripts/gpu_environment.yml: -------------------------------------------------------------------------------- 1 | name: LVM 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | - numpy 8 | - scipy 9 | - matplotlib 10 | - seaborn 11 | - jupyter 12 | - tqdm 13 | - sentencepiece 14 | - pip: 15 | - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 16 
| - --extra-index-url https://download.pytorch.org/whl/cu118 17 | - jax[cuda11_pip]==0.4.14 18 | - flax==0.7.0 19 | - optax==0.1.7 20 | - distrax==0.1.4 21 | - chex==0.1.82 22 | - transformers==4.31.0 23 | - torch==2.0.1 24 | - huggingface_hub==0.16.4 25 | - datasets==2.14.2 26 | - einops 27 | - tensorflow==2.11.1 28 | - dill 29 | - absl-py 30 | - wandb 31 | - ml_collections 32 | - gcsfs 33 | - requests 34 | - jupyter_http_over_ws 35 | - lm-eval 36 | - mlxu==0.1.11 37 | - pydantic 38 | - fastapi 39 | - uvicorn 40 | - gradio 41 | - s3fs 42 | -------------------------------------------------------------------------------- /scripts/tpu_commands.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | function _tpu_ips { 4 | tpu_zone=$1 5 | tpu_project=$2 6 | tpu_name=$3 7 | gcloud alpha compute tpus tpu-vm describe $tpu_name --zone $tpu_zone --project $tpu_project | grep -oP 'externalIp: \K(.+)$' 8 | 9 | } 10 | 11 | function _tpu_create { 12 | tpu_zone=$1 13 | tpu_project=$2 14 | tpu_gen=$3 15 | tpu_cores=$4 16 | tpu_name=$5 17 | if [ "$tpu_gen" = "v3" ]; then 18 | software_version='tpu-vm-base' 19 | else 20 | software_version='tpu-vm-v4-base' 21 | fi 22 | 23 | if [[ $tpu_cores =~ ^[0-9]+$ ]]; then 24 | gcloud alpha compute tpus tpu-vm create \ 25 | $tpu_name \ 26 | --accelerator-type="$tpu_gen-$tpu_cores" \ 27 | --version $software_version \ 28 | --zone $tpu_zone \ 29 | --project $tpu_project 30 | else 31 | gcloud alpha compute tpus tpu-vm create \ 32 | $tpu_name \ 33 | --type="$tpu_gen" \ 34 | --topology="$tpu_cores" \ 35 | --version $software_version \ 36 | --zone $tpu_zone \ 37 | --project $tpu_project 38 | fi 39 | } 40 | 41 | function _tpu_retry_create { 42 | while true; do 43 | _tpu_create "$@" 44 | sleep 120s 45 | done 46 | } 47 | 48 | function _tpu_cp_ssh_key { 49 | tpu_zone=$1 50 | tpu_project=$2 51 | tpu_name=$3 52 | 53 | gcloud alpha compute tpus tpu-vm scp \ 54 | $HOME/.ssh/authorized_keys \ 55 | $tpu_name:/home/$USER/.ssh/ \ 56 | --worker=all \ 57 | --project $tpu_project \ 58 | --zone $tpu_zone 59 | } 60 | 61 | function _tpu_setup { 62 | tpu_zone=$1 63 | tpu_project=$2 64 | tpu_name=$3 65 | 66 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 67 | for host in $tpu_ips; do 68 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 69 | ssh $host '~/tpu_vm_setup.sh' & 70 | done 71 | wait &> /dev/null 72 | 73 | for host in $tpu_ips; do 74 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 75 | ssh $host '~/tpu_vm_setup.sh' & 76 | done 77 | wait &> /dev/null 78 | } 79 | 80 | function _tpu_check { 81 | tpu_zone=$1 82 | tpu_project=$2 83 | tpu_name=$3 84 | 85 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 86 | for host in $tpu_ips; do 87 | echo "============== Checking host: $host ==============" 88 | ssh $host 'tmux capture-pane -pt launch -S -2000' 89 | echo 90 | echo 91 | done 92 | } 93 | 94 | function _tpu_copy { 95 | tpu_zone=$1 96 | tpu_project=$2 97 | tpu_name=$3 98 | 99 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 100 | for host in $tpu_ips; do 101 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ & 102 | done 103 | wait &> /dev/null 104 | sleep 1s 105 | 106 | for host in $tpu_ips; do 107 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ & 108 | done 109 | wait &> /dev/null 110 | sleep 1s 111 | } 112 | 113 | function _tpu_stop { 114 | tpu_zone=$1 115 | tpu_project=$2 116 
| tpu_name=$3 117 | 118 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 119 | for host in $tpu_ips; do 120 | ssh $host 'tmux kill-session -t launch ; pkill -9 python' & 121 | done 122 | wait &> /dev/null 123 | } 124 | 125 | function _tpu_launch { 126 | tpu_zone=$1 127 | tpu_project=$2 128 | tpu_name=$3 129 | command=$4 130 | 131 | if [ -z "$command" ]; then 132 | echo "Invalid syntax!" 133 | return 1 134 | fi 135 | 136 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 137 | for host in $tpu_ips; do 138 | ssh $host "tmux new -d -s launch ~/$PROJECT_NAME/launcher/$command" & 139 | done 140 | wait &> /dev/null 141 | } 142 | 143 | function _tpu_maintain { 144 | tpu_zone=$1 145 | tpu_project=$2 146 | tpu_name=$3 147 | 148 | gcloud alpha compute tpus tpu-vm simulate-maintenance-event $tpu_name \ 149 | --project $tpu_project \ 150 | --zone=$tpu_zone \ 151 | --workers=all 152 | } 153 | 154 | function _tpu_ssh { 155 | tpu_zone=$1 156 | tpu_project=$2 157 | tpu_name=$3 158 | command="$4" 159 | 160 | if [ -z "$command" ]; then 161 | echo "Invalid syntax!" 162 | return 1 163 | fi 164 | 165 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 166 | for host in $tpu_ips; do 167 | ssh $host "$command" & 168 | done 169 | wait &> /dev/null 170 | } 171 | 172 | function _tpu_reboot { 173 | tpu_zone=$1 174 | tpu_project=$2 175 | tpu_name=$3 176 | 177 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 178 | for host in $tpu_ips; do 179 | ssh $host 'sudo reboot' & 180 | done 181 | wait &> /dev/null 182 | } 183 | 184 | 185 | function tpu { 186 | trap "trap - SIGINT SIGTERM; return 1;" SIGINT SIGTERM 187 | 188 | 189 | # =============== TPU Project Specific Definitions =============== 190 | export PROJECT_HOME=' $HOME/tpu_requirements.txt <<- EndOfFile 17 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 18 | jax[tpu]==0.4.13 19 | tensorflow==2.11.0 20 | flax==0.7.0 21 | optax==0.1.7 22 | distrax==0.1.3 23 | chex==0.1.7 24 | einops 25 | --extra-index-url https://download.pytorch.org/whl/cpu 26 | torch==2.0.1 27 | transformers==4.31.0 28 | datasets==2.14.2 29 | huggingface_hub==0.16.4 30 | tqdm 31 | h5py 32 | ml_collections 33 | wandb==0.13.5 34 | gcsfs==2022.11.0 35 | requests 36 | typing-extensions 37 | lm-eval==0.3.0 38 | mlxu==0.1.11 39 | sentencepiece 40 | pydantic 41 | fastapi 42 | uvicorn 43 | gradio 44 | EndOfFile 45 | 46 | pip install --upgrade -r $HOME/tpu_requirements.txt 47 | 48 | 49 | # vim configurations 50 | cat > $HOME/.vimrc <<- EndOfFile 51 | set tabstop=4 52 | set shiftwidth=4 53 | set softtabstop=4 54 | set expandtab 55 | set backspace=indent,eol,start 56 | syntax on 57 | EndOfFile 58 | 59 | # tmux configurations 60 | cat > $HOME/.tmux.conf <<- EndOfFile 61 | bind r source-file ~/.tmux.conf 62 | 63 | set -g prefix C-a 64 | 65 | set -g set-titles on 66 | set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)' 67 | 68 | set -g default-terminal "screen-256color" 69 | 70 | # Status bar customization 71 | #set -g status-utf8 on 72 | set -g status-bg white 73 | set -g status-fg black 74 | set -g status-interval 5 75 | set -g status-left-length 90 76 | set -g status-right-length 60 77 | 78 | set -g status-justify left 79 | 80 | unbind-key C-o 81 | bind -n C-o prev 82 | unbind-key C-p 83 | bind -n C-p next 84 | unbind-key C-w 85 | bind -n C-w new-window 86 | 87 | unbind-key C-j 88 | bind -n C-j select-pane -D 89 | unbind-key C-k 90 | bind -n C-k select-pane -U 91 | unbind-key C-h 92 | bind -n C-h select-pane -L 93 | unbind-key C-l 94 | bind -n C-l 
select-pane -R 95 | 96 | unbind-key C-e 97 | bind -n C-e split-window -h 98 | unbind-key C-q 99 | bind -n C-q split-window -v 100 | unbind '"' 101 | unbind % 102 | 103 | unbind-key u 104 | bind-key u split-window -h 105 | unbind-key i 106 | bind-key i split-window -v 107 | EndOfFile 108 | 109 | 110 | # htop Configurations 111 | mkdir -p $HOME/.config/htop 112 | cat > $HOME/.config/htop/htoprc <<- EndOfFile 113 | # Beware! This file is rewritten by htop when settings are changed in the interface. 114 | # The parser is also very primitive, and not human-friendly. 115 | fields=0 48 17 18 38 39 40 2 46 47 49 1 116 | sort_key=46 117 | sort_direction=1 118 | hide_threads=0 119 | hide_kernel_threads=1 120 | hide_userland_threads=1 121 | shadow_other_users=0 122 | show_thread_names=0 123 | show_program_path=1 124 | highlight_base_name=0 125 | highlight_megabytes=1 126 | highlight_threads=1 127 | tree_view=0 128 | header_margin=1 129 | detailed_cpu_time=0 130 | cpu_count_from_zero=0 131 | update_process_names=0 132 | account_guest_in_cpu_meter=0 133 | color_scheme=0 134 | delay=15 135 | left_meters=CPU Memory Swap 136 | left_meter_modes=1 1 1 137 | right_meters=Tasks LoadAverage Uptime 138 | right_meter_modes=2 2 2 139 | EndOfFile 140 | -------------------------------------------------------------------------------- /tokenize_examples/detokenization_muse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import einops 4 | import base64 5 | import shutil 6 | from torchvision import transforms 7 | from torch.utils.data import Dataset, DataLoader 8 | from PIL import Image 9 | import numpy as np 10 | from torchvision.utils import save_image 11 | import json 12 | import wandb 13 | from tqdm import tqdm 14 | import torch 15 | from torchvision import transforms 16 | from PIL import Image 17 | from muse import VQGANModel 18 | import torchvision.utils as vutils 19 | import matplotlib.pyplot as plt 20 | import argparse 21 | 22 | 23 | def save_temp_images(tensors, temp_dir='./temp_images'): 24 | """ 25 | Save a batch of images stored in a PyTorch tensor to a temporary directory. 26 | 27 | Args: 28 | - tensors (torch.Tensor): A batch of images in the format (batch size, channels, height, width). 29 | - temp_dir (str): Path to the temporary directory where images will be saved. 30 | """ 31 | # Create a temporary directory if it doesn't exist 32 | if not os.path.exists(temp_dir): 33 | os.makedirs(temp_dir) 34 | 35 | # grid_img = vutils.make_grid(tensors, n=16, padding=0) 36 | # vutils.save_image(grid_img, 'vqlm/muse/vis_reconstruction_tokens/concatenated_image.png') 37 | 38 | # Save each image in the batch to the temporary directory 39 | for i, tensor in enumerate(tensors): 40 | # Construct the file path 41 | file_path = os.path.join(temp_dir, f'image_{i}.png') 42 | # Save the image 43 | save_image(tensor, file_path) 44 | 45 | 46 | def delete_temp_images(temp_dir='./temp_images'): 47 | """ 48 | Delete the temporary directory and its contents. 49 | 50 | Args: 51 | - temp_dir (str): Path to the temporary directory to be deleted. 
52 | """ 53 | if os.path.exists(temp_dir): 54 | shutil.rmtree(temp_dir) 55 | 56 | 57 | class CustomImageDataset(Dataset): 58 | def __init__(self, image_paths, transform=None): 59 | self.image_paths = image_paths 60 | self.transform = transform 61 | 62 | def __len__(self): 63 | return len(self.image_paths) 64 | 65 | def __getitem__(self, idx): 66 | image_path = self.image_paths[idx] 67 | image = Image.open(image_path).convert('RGB') 68 | 69 | if self.transform: 70 | image = self.transform(image) 71 | 72 | # Assuming you want to split or transform the image in a way that fits the '2 * h * w' description 73 | # This can be an operation like doubling the image, or applying some transformation. 74 | # For simplicity, we'll just duplicate the image in the dataset. 75 | # Adjust the transformation as needed. 76 | 77 | return (image,) 78 | 79 | 80 | import random 81 | 82 | def select_random_elements(your_list, num_elements=64, seed_value=42): 83 | """ 84 | Randomly selects `num_elements` from `your_list` using `seed_value` for reproducibility. 85 | 86 | Parameters: 87 | - your_list: The list to select elements from. 88 | - num_elements: The number of elements to select. Default is 64. 89 | - seed_value: The seed value for the random number generator to ensure reproducibility. Default is 42. 90 | 91 | Returns: 92 | - A list containing the randomly selected elements. 93 | """ 94 | random.seed(seed_value) # Set the seed for reproducibility 95 | selected_elements = random.sample(your_list, num_elements) # Randomly select elements 96 | return selected_elements 97 | 98 | def save_tokens(trecons_name, idx, re_constructed): 99 | tensor_img = re_constructed.cpu().numpy() 100 | images = (tensor_img * 255).transpose(0, 2, 3, 1).astype(np.uint8) 101 | 102 | # Concatenate images horizontally 103 | concatenated_image = np.hstack(images) 104 | 105 | # Convert the NumPy array to a PIL Image and save it 106 | os.makedirs(trecons_name, exist_ok=True) 107 | img = Image.fromarray(concatenated_image) 108 | img.save(os.path.join(trecons_name, '{}.png'.format(idx))) 109 | 110 | 111 | if __name__ == '__main__': 112 | # Create the parser 113 | parser = argparse.ArgumentParser(description='Process some integers.') 114 | 115 | # Add arguments 116 | parser.add_argument('--dataset', default='i1k_edge', type=str, help='An input name') 117 | 118 | # Parse the arguments 119 | args = parser.parse_args() 120 | 121 | # Load the pre-trained vq model from the hub 122 | net = VQGANModel.from_pretrained("vqlm/muse/ckpts/laion").cuda() 123 | net.eval() 124 | 125 | idx = 1 126 | dataset = args.dataset 127 | with open('./lvm/tokenized_muse/{}.jsonl'.format(dataset), 'r') as file: 128 | for line in file: 129 | try: 130 | json_obj = json.loads(line.strip())['tokens'] 131 | # Process the JSON object 132 | except: 133 | continue 134 | 135 | decoded_bytes = base64.b64decode(json_obj ) 136 | array_dtype = np.int32 137 | array_shape = (-1, 256) 138 | tokens = np.frombuffer(decoded_bytes, dtype=array_dtype).reshape(array_shape) 139 | tokens = torch.tensor(tokens).cuda() 140 | 141 | 142 | with torch.no_grad(): 143 | 144 | # detokenized 145 | re_constructed = net.decode_code(tokens) 146 | 147 | plt.figure(figsize=(12, 12)) 148 | for i in range(re_constructed.shape[0]): 149 | recon_img = torch.clamp(re_constructed[i], 150 | 0.0, 1.0 151 | ) 152 | plt.subplot(4, 4, i + 1) 153 | plt.imshow((((recon_img).permute(1, 2, 0).detach().cpu().numpy() * 255)).astype(np.int32)) 154 | plt.grid(False) 155 | plt.axis('off') 156 | save_root = 
'./lvm/other_folder/vis_reconstruction_tokens_check_final/{}'.format(dataset) 157 | os.makedirs(save_root, exist_ok=True) 158 | plt.savefig('{}/{}.png'.format(save_root, idx)) 159 | 160 | idx += 1 161 | if idx >= 16: 162 | break 163 | 164 | -------------------------------------------------------------------------------- /tokenize_examples/map_color.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import os 4 | from tqdm import tqdm 5 | import glob 6 | 7 | def generate_random_palette(): 8 | # Generate a random palette with RGB values 9 | return np.array([np.random.choice(range(256), size=3) for _ in range(256)], dtype=np.uint8) 10 | 11 | # Function to apply the color mapping to a grayscale image 12 | def apply_color_palette(image_path, output_path, palette): 13 | # Load the image and convert it to a NumPy array 14 | grayscale_image = Image.open(image_path).convert('L') 15 | grayscale_array = np.array(grayscale_image, dtype=np.int32) 16 | 17 | # Apply the palette using advanced indexing 18 | color_mapped_array = palette[grayscale_array] 19 | 20 | # Convert back to an image and save or display 21 | color_mapped_image = Image.fromarray(color_mapped_array, 'RGB') 22 | color_mapped_image.save(output_path) 23 | 24 | 25 | if __name__ == '__main__': 26 | # Example usage 27 | # Generate a random color palette 28 | # palette = generate_random_palette() 29 | # np.save('./color_map.npy', palette) 30 | 31 | palette = np.load('./vqlm/muse/organize_dataset/i1k/color_map.npy') 32 | 33 | s = 1 34 | 35 | root = '/scratch/partial_datasets/lvm/dataset/prismer_i1k/new/seg_coco/train' 36 | 37 | images = glob.glob(root + '/*/*.png') 38 | 39 | for img in tqdm(images): 40 | cate_path = img.split('/')[-2] 41 | save_root = os.path.join('/scratch/partial_datasets/lvm/dataset/prismer_i1k/new_mapped_color/seg_coco_colored/train', cate_path) 42 | os.makedirs(save_root, exist_ok=True) 43 | save_path = os.path.join(save_root, os.path.basename(img)) 44 | apply_color_palette(img, save_path, palette=palette) 45 | -------------------------------------------------------------------------------- /tokenize_examples/tokenize_catogory_images_muse.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from tempfile import NamedTemporaryFile 4 | import random 5 | import json 6 | from base64 import b64encode 7 | from tqdm import tqdm, trange 8 | 9 | import numpy as np 10 | import mlxu 11 | 12 | import torch 13 | 14 | import jax 15 | import jax.numpy as jnp 16 | import flax 17 | import einops 18 | 19 | from PIL import Image 20 | from utils import read_image_to_tensor 21 | 22 | FLAGS, _ = mlxu.define_flags_with_default( 23 | input_image_dir='/datasets/ilsvrc_2024-01-04_1601/train', 24 | output_file='/home/yutongbai/vqlm/muse/running_script/tokenized_muse/i1kcate_4.jsonl', 25 | images_per_shot=4, 26 | n_shots=4, 27 | n_epochs=1, 28 | n_workers=8, 29 | dtype='fp32', 30 | batch_size=1 31 | ) 32 | 33 | # Define the desired fixed token length 34 | fixed_token_length = 4096 35 | 36 | 37 | class SubfolderImageDataset(torch.utils.data.Dataset): 38 | 39 | def __init__(self, subfolders, images_per_shot): 40 | self.subfolders = subfolders 41 | self.images_per_shot = images_per_shot 42 | 43 | def __getitem__(self, index): 44 | subfolder = self.subfolders[index] 45 | image_files = [os.path.join(subfolder, f) for f in os.listdir(subfolder) if 46 | f.endswith(('.png', '.jpg', 
'.JPEG'))] 47 | selected_images = np.random.choice(image_files, self.images_per_shot, replace=False) 48 | return [read_image_to_tensor(image_file) for image_file in selected_images] 49 | 50 | def __len__(self): 51 | return len(self.subfolders) 52 | 53 | 54 | def read_image_to_tensor(image_path): 55 | with Image.open(image_path) as img: 56 | img = img.convert('RGB') 57 | img = img.resize((256, 256)) 58 | img = np.array(img) / 255.0 59 | return torch.tensor(img).float() 60 | 61 | 62 | def custom_collate(batch): 63 | return [item for sublist in batch for item in sublist] 64 | 65 | 66 | def main(argv): 67 | assert FLAGS.input_image_dir != '' 68 | assert FLAGS.output_file != '' 69 | 70 | subfolders = [os.path.join(FLAGS.input_image_dir, d) for d in os.listdir(FLAGS.input_image_dir) if 71 | os.path.isdir(os.path.join(FLAGS.input_image_dir, d))] 72 | 73 | dataset = SubfolderImageDataset(subfolders, FLAGS.images_per_shot) 74 | dataloader = torch.utils.data.DataLoader( 75 | dataset, 76 | batch_size=FLAGS.batch_size, 77 | shuffle=True, 78 | num_workers=FLAGS.n_workers, 79 | drop_last=True, 80 | collate_fn=custom_collate 81 | ) 82 | 83 | total_images = len(dataset) * FLAGS.images_per_shot 84 | 85 | checkpoint_path = os.path.join( 86 | os.path.dirname(os.path.realpath(__file__)), 87 | '/home/vqlm/eval_visualization/torch_ckpts/jax_xh_ckpt.pkl' 88 | ) 89 | tokenize_fn = get_tokenizer_fn(checkpoint_path, FLAGS.dtype) 90 | 91 | n_devices = jax.device_count() 92 | 93 | with NamedTemporaryFile() as ntf: 94 | # Adjust the shape of all_tokens to match the fixed token length 95 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, fixed_token_length)) 96 | 97 | index = 0 98 | for batch in tqdm(dataloader, ncols=0): 99 | batch_images = np.concatenate(batch, axis=0) 100 | 101 | # Reshape batch_images to (n_devices, batch_size/n_devices, height, width, channels) 102 | batch_images = batch_images.reshape(n_devices, -1, 256, 256, 3) 103 | 104 | tokens = tokenize_fn(batch_images).flatten() 105 | 106 | # Ensure tokens have a fixed length (fixed_token_length) 107 | tokens = tokens[:fixed_token_length] 108 | 109 | # Assign tokens to the all_tokens array 110 | all_tokens[index:index + len(tokens)] = tokens 111 | index += len(tokens) 112 | 113 | with open(FLAGS.output_file, 'w') as fout: 114 | for _ in trange(FLAGS.n_epochs, ncols=0): 115 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots * FLAGS.images_per_shot) 116 | for i in trange(indices.shape[0], ncols=0): 117 | tokens = all_tokens[indices[i], :].reshape(-1) 118 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8')} 119 | fout.write(json.dumps(data) + '\n') 120 | 121 | 122 | if __name__ == '__main__': 123 | mlxu.run(main) 124 | -------------------------------------------------------------------------------- /tokenize_examples/tokenize_co3d_muse.py: -------------------------------------------------------------------------------- 1 | """ Tokenize multiple related sequences of data """ 2 | 3 | 4 | 5 | import os 6 | from copy import deepcopy 7 | from functools import partial 8 | from tempfile import NamedTemporaryFile 9 | import random 10 | import json 11 | from tqdm import tqdm, trange 12 | from muse import VQGANModel 13 | import numpy as np 14 | import mlxu 15 | 16 | import torch 17 | 18 | 19 | import einops 20 | 21 | from PIL import Image 22 | 23 | from utils import match_mulitple_path 24 | 25 | from utils import ( 26 | randomly_subsample_frame_indices, list_dir_with_full_path, 27 | is_image, b64encode_tokens, 
read_image_to_tensor 28 | ) 29 | 30 | 31 | FLAGS, _ = mlxu.define_flags_with_default( 32 | input_dir='', 33 | output_file='', 34 | batch_size=4, 35 | n_frames=4, 36 | max_stride=4, 37 | n_shots=4, 38 | n_epochs=1, 39 | n_workers=16, 40 | dtype='fp32', 41 | ) 42 | 43 | 44 | 45 | class Co3DDataset(torch.utils.data.Dataset): 46 | 47 | def __init__(self, image_dirs, n_frames): 48 | self.image_dirs = image_dirs 49 | self.n_frames = n_frames 50 | 51 | def __getitem__(self, index): 52 | tasks = self.image_dirs[index] 53 | frames = [] 54 | length = len(list_dir_with_full_path(tasks[0])) 55 | indices = randomly_subsample_frame_indices( 56 | length, self.n_frames, FLAGS.max_stride 57 | ) 58 | for task in tasks: 59 | task_frames = [] 60 | files = sorted(list_dir_with_full_path(task)) 61 | for i in indices: 62 | if is_image(files[i]): 63 | task_frames.append(read_image_to_tensor(files[i])) 64 | frames.append(np.stack(task_frames, axis=0)) 65 | 66 | return np.stack(frames, axis=0) 67 | 68 | def __len__(self): 69 | return len(self.image_dirs) 70 | 71 | 72 | def main(argv): 73 | assert FLAGS.input_dir != '' 74 | assert FLAGS.output_file != '' 75 | 76 | # Load the pre-trained vq model from the hub 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | 79 | net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device) 80 | net.eval() 81 | 82 | dirs = [] 83 | for d in list_dir_with_full_path(FLAGS.input_dir): 84 | if not os.path.isdir(d): 85 | continue 86 | for d2 in list_dir_with_full_path(d): 87 | if not os.path.isdir(d2): 88 | continue 89 | dirs.append(d2) 90 | 91 | image_dirs = [] 92 | for d in dirs: 93 | image_dirs.append(( 94 | os.path.join(d, 'images'), 95 | os.path.join(d, 'masks'), 96 | os.path.join(d, 'depth_masks') 97 | )) 98 | 99 | with open(FLAGS.output_file, 'w') as fout: 100 | with torch.no_grad(): 101 | for _ in trange(FLAGS.n_epochs, ncols=0): 102 | print(image_dirs[0]) 103 | dataset = Co3DDataset(image_dirs, FLAGS.n_frames) 104 | dataloader = torch.utils.data.DataLoader( 105 | dataset, 106 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 107 | shuffle=False, 108 | num_workers=FLAGS.n_workers, 109 | drop_last=True 110 | ) 111 | 112 | for batch in tqdm(dataloader, ncols=0): 113 | batch_shape = batch.shape[:-3] 114 | batch = batch.reshape(-1, *batch.shape[-3:]) 115 | batch = batch.permute(0,3,1,2) 116 | batch = batch.to(device) 117 | 118 | _, tokens = net.encode(batch) 119 | tokens = tokens.reshape(*batch_shape, tokens.shape[-1]) 120 | # batch x task x frame x token 121 | tokens = einops.rearrange( 122 | tokens.cpu().numpy().astype(np.int32), '(b s) t f d -> b s t f d', 123 | s=FLAGS.n_shots 124 | ) 125 | 126 | image_mask_tokens = np.concatenate( 127 | (tokens[:, :, 0, :, :], tokens[:, :, 1, :, :]), axis=-2 128 | ) 129 | image_depth_tokens = np.concatenate( 130 | (tokens[:, :, 0, :, :], tokens[:, :, 2, :, :]), axis=-2 131 | ) 132 | for i in range(tokens.shape[0]): 133 | data = {'tokens': b64encode_tokens(image_mask_tokens[i])} 134 | fout.write(json.dumps(data) + '\n') 135 | data = {'tokens': b64encode_tokens(image_depth_tokens[i])} 136 | fout.write(json.dumps(data) + '\n') 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | if __name__ == '__main__': 145 | mlxu.run(main) -------------------------------------------------------------------------------- /tokenize_examples/tokenize_colorization_dataset_muse.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from functools import partial 4 | from tempfile 
import NamedTemporaryFile 5 | import random 6 | import json 7 | from base64 import b64encode 8 | from tqdm import tqdm, trange 9 | import glob 10 | import numpy as np 11 | import mlxu 12 | 13 | import torch 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import flax 18 | import einops 19 | 20 | from PIL import Image, UnidentifiedImageError 21 | 22 | from PIL import Image 23 | from muse import VQGANModel 24 | from utils import read_image_to_tensor, is_image, list_dir_with_full_path 25 | 26 | 27 | FLAGS, _ = mlxu.define_flags_with_default( 28 | input_image_dir='/datasets/ilsvrc_2024-01-04_1601/train', 29 | output_file='./lvm/tokenized_muse/i1k_colorization.jsonl', 30 | batch_size=1, 31 | n_shots=8, 32 | n_epochs=2, 33 | n_workers=8, 34 | patch_size=32, 35 | hole_mask_ratio=0.3, 36 | dtype='fp32', 37 | layer = 1 38 | ) 39 | 40 | 41 | class PairedImageDataset(torch.utils.data.Dataset): 42 | 43 | def __init__(self, images): 44 | self.images = images 45 | 46 | def __getitem__(self, index): 47 | try: 48 | original_image = read_image_to_tensor(self.images[index]) 49 | except UnidentifiedImageError: 50 | original_image = np.zeros((256, 256, 3), dtype=np.float32) 51 | 52 | # make gray images 53 | grayscale = np.dot(original_image[..., :3], [0.2989, 0.5870, 0.1140]) 54 | gray_image = np.stack((grayscale,) * 3, axis=-1) 55 | # Stack the grayscale image across the third dimension 56 | gray_image = np.array(gray_image, dtype=np.float32) 57 | return gray_image, original_image 58 | 59 | def __len__(self): 60 | return len(self.images) 61 | 62 | 63 | def main(argv): 64 | assert FLAGS.input_image_dir != '' 65 | assert FLAGS.output_file != '' 66 | 67 | # Load the pre-trained vq model from the hub 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | 70 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 71 | net.eval() 72 | 73 | 74 | input_images = glob.glob('{}{}/*.png'.format(FLAGS.input_image_dir, '/*'*FLAGS.layer)) 75 | input_images += glob.glob('{}{}/*.jpg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer)) 76 | input_images += glob.glob('{}{}/*.jpeg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer)) 77 | input_images += glob.glob('{}{}/*.JPEG'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer)) 78 | 79 | dataset = PairedImageDataset(input_images) 80 | dataloader = torch.utils.data.DataLoader( 81 | dataset, 82 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 83 | shuffle=False, 84 | num_workers=FLAGS.n_workers, 85 | drop_last=True 86 | ) 87 | 88 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) 89 | print(total_images) 90 | with NamedTemporaryFile() as ntf: 91 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512)) 92 | all_tokens[:] = 0 93 | 94 | index = 0 95 | # debug_count = 0 96 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0): 97 | # if debug_count < 5243: 98 | # debug_count += 1 99 | # continue 100 | 101 | _, input_token_batch = net.encode(input_image_batch.permute(0, 3, 1, 2).to(device)) 102 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device)) 103 | 104 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate( 105 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)], 106 | axis=1 107 | ) 108 | index += input_image_batch.shape[0] 109 | 110 | with open(FLAGS.output_file, 'w') as fout: 111 | for _ in trange(FLAGS.n_epochs, ncols=0): 112 | indices = 
np.random.permutation(total_images).reshape(-1, FLAGS.n_shots) 113 | for i in trange(indices.shape[0], ncols=0): 114 | tokens = all_tokens[indices[i], :].reshape(-1) 115 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),} 116 | fout.write(json.dumps(data) + '\n') 117 | 118 | 119 | if __name__ == '__main__': 120 | mlxu.run(main) -------------------------------------------------------------------------------- /tokenize_examples/tokenize_inpainting_dataset_muse.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from functools import partial 4 | from tempfile import NamedTemporaryFile 5 | import random 6 | import json 7 | from base64 import b64encode 8 | from tqdm import tqdm, trange 9 | import glob 10 | import numpy as np 11 | import mlxu 12 | 13 | import torch 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import flax 18 | import einops 19 | 20 | from PIL import Image, UnidentifiedImageError 21 | 22 | from PIL import Image 23 | from muse import VQGANModel 24 | from utils import read_image_to_tensor, is_image, list_dir_with_full_path 25 | 26 | 27 | FLAGS, _ = mlxu.define_flags_with_default( 28 | input_image_dir='/datasets/imagenet_22k_2024-01-04_1601', 29 | output_file='./lvm/tokenized_muse/i22k_inpainting.jsonl', 30 | batch_size=1, 31 | n_shots=8, 32 | n_epochs=5, 33 | n_workers=8, 34 | patch_size=32, 35 | hole_mask_ratio=0.3, 36 | dtype='fp32', 37 | layer = 1 38 | ) 39 | 40 | 41 | class PairedImageDataset(torch.utils.data.Dataset): 42 | 43 | def __init__(self, images): 44 | self.images = images 45 | 46 | def __getitem__(self, index): 47 | try: 48 | original_image = read_image_to_tensor(self.images[index]) 49 | except UnidentifiedImageError: 50 | original_image = np.zeros((256, 256, 3), dtype=np.float32) 51 | 52 | if np.random.random() < FLAGS.hole_mask_ratio: 53 | h, w, _ = original_image.shape 54 | masked_image = original_image.copy() 55 | min_dim = min(h, w) 56 | size = np.random.randint(int(0.1 * min_dim), int(0.5 * min_dim)) 57 | top_left_y = np.random.randint(0, h - size + 1) 58 | top_left_x = np.random.randint(0, w - size + 1) 59 | masked_image[top_left_y:top_left_y+size, top_left_x:top_left_x+size, :] = 0 60 | else: 61 | patches = einops.rearrange( 62 | original_image, 63 | '(h p1) (w p2) c -> (h w) (p1 p2 c)', 64 | p1=FLAGS.patch_size, p2=FLAGS.patch_size 65 | ) 66 | mask_ratio = np.random.random() 67 | mask = np.random.random((patches.shape[0], 1)) < mask_ratio 68 | masked_patches = patches * mask 69 | hw = 256 // FLAGS.patch_size 70 | masked_image = einops.rearrange( 71 | masked_patches, 72 | '(h w) (p1 p2 c) -> (h p1) (w p2) c', 73 | p1=FLAGS.patch_size, p2=FLAGS.patch_size, 74 | h=hw, w=hw 75 | ) 76 | return masked_image, original_image 77 | 78 | def __len__(self): 79 | return len(self.images) 80 | 81 | 82 | def main(argv): 83 | assert FLAGS.input_image_dir != '' 84 | assert FLAGS.output_file != '' 85 | 86 | # Load the pre-trained vq model from the hub 87 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 88 | 89 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 90 | net.eval() 91 | 92 | 93 | # input_images = glob.glob('{}{}/*.png'.format(FLAGS.input_image_dir, '/*'*FLAGS.layer)) 94 | # input_images += glob.glob('{}{}/*.jpg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer)) 95 | # input_images += glob.glob('{}{}/*.jpeg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer)) 96 | input_images = glob.glob('{}{}/*.JPEG'.format(FLAGS.input_image_dir, '/*' * 
FLAGS.layer)) 97 | 98 | 99 | # input_images = input_images[:100] 100 | dataset = PairedImageDataset(input_images) 101 | dataloader = torch.utils.data.DataLoader( 102 | dataset, 103 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 104 | shuffle=False, 105 | num_workers=FLAGS.n_workers, 106 | drop_last=True 107 | ) 108 | 109 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) 110 | print(total_images) 111 | with NamedTemporaryFile() as ntf: 112 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512)) 113 | all_tokens[:] = 0 114 | 115 | index = 0 116 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0): 117 | _, input_token_batch = net.encode(input_image_batch.permute(0, 3, 1, 2).to(device)) 118 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device)) 119 | 120 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate( 121 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)], 122 | axis=1 123 | ) 124 | index += input_image_batch.shape[0] 125 | 126 | with open(FLAGS.output_file, 'w') as fout: 127 | for _ in trange(FLAGS.n_epochs, ncols=0): 128 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots) 129 | for i in trange(indices.shape[0], ncols=0): 130 | tokens = all_tokens[indices[i], :].reshape(-1) 131 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),} 132 | fout.write(json.dumps(data) + '\n') 133 | 134 | 135 | if __name__ == '__main__': 136 | mlxu.run(main) -------------------------------------------------------------------------------- /tokenize_examples/tokenize_multi_datasets_muse.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from functools import partial 4 | from tempfile import NamedTemporaryFile 5 | import random 6 | import json 7 | from base64 import b64encode 8 | from tqdm import tqdm, trange 9 | import glob 10 | import numpy as np 11 | import mlxu 12 | 13 | import torch 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import flax 18 | import einops 19 | 20 | from PIL import Image, UnidentifiedImageError 21 | from muse import VQGANModel 22 | from utils import match_mulitple_path, read_image_to_tensor 23 | 24 | 25 | # FLAGS, _ = mlxu.define_flags_with_default( 26 | # input_dir='./lvm/dataset/prismer_i1k/new', 27 | # input_regex='/datasets/ilsvrc_2024-01-04_1601/train::./lvm/dataset/prismer_i1k/new/depth/train::./lvm/dataset/prismer_i1k/new/edge/train::./lvm/dataset/prismer_i1k/new/normal/train::./lvm/dataset/prismer_i1k/new_mapped_color/seg_coco_colored/train', 28 | # output_file='./lvm/other_folder/old/i1k_cot_uni-mapped_occupy_gpu.jsonl', 29 | # shuffle_tasks=True, 30 | # crop=False, 31 | # batch_size=16, 32 | # max_examples=0, 33 | # n_shots=3, 34 | # n_epochs=5, 35 | # n_workers=8, 36 | # dtype='fp32', 37 | # layer=2, 38 | # ) 39 | 40 | FLAGS, _ = mlxu.define_flags_with_default( 41 | input_dir='./lvm/dataset/prismer_i1k/new', 42 | input_regex='/datasets/coco2017_2024-01-04_1601/train2017::/shared/yutongbai/labels/normal/helpers/images::/shared/yutongbai/labels/edge/helpers/images::/shared/yutongbai/labels/depth/helpers/images_coco::./lvm/dataset/prismer_coco/seg_coco_colored_mapped_color', 43 | output_file='./lvm/tokenized_muse/coco_mixed_uni-mapped.jsonl', 44 | shuffle_tasks=True, 45 | crop=False, 46 | batch_size=8, 47 | max_examples=0, 48 | n_shots=3, 49 | n_epochs=5, 50 | 
n_workers=8, 51 | dtype='fp32', 52 | layer=1, 53 | ) 54 | 55 | # FLAGS, _ = mlxu.define_flags_with_default( 56 | # input_dir='./lvm/dataset/prismer_i1k/new', 57 | # input_regex='./lvm/dataset/kitti-cot_crop/image::./lvm/dataset/kitti-cot_crop/depth::./lvm/dataset/kitti-cot_crop/next_frame::./lvm/dataset/kitti-cot_crop/scene_flow::./lvm/dataset/kitti-cot_crop/sementic_seg::./lvm/dataset/kitti-cot_crop/sementic_seg_rbg::./lvm/dataset/kitti-cot_crop/stereo', 58 | # output_file='./lvm/tokenized_muse/kitti_new.jsonl', 59 | # shuffle_tasks=True, 60 | # crop=False, 61 | # batch_size=16, 62 | # max_examples=0, 63 | # n_shots=2, 64 | # n_epochs=5, 65 | # n_workers=8, 66 | # dtype='fp32', 67 | # layer=1, 68 | # ) 69 | 70 | 71 | class MultipleImageDataset(torch.utils.data.Dataset): 72 | 73 | def __init__(self, input_images): 74 | self.input_images = input_images 75 | 76 | def __getitem__(self, index): 77 | try: 78 | if FLAGS.crop: 79 | crop_rng = np.random.default_rng(np.random.randint(0, 2 ** 32)) 80 | else: 81 | crop_rng = None 82 | return tuple( 83 | read_image_to_tensor(x, crop=FLAGS.crop, crop_rng=deepcopy(crop_rng)) 84 | for x in self.input_images[index] 85 | ) 86 | except UnidentifiedImageError as e: 87 | print(f'Error: {e} for {self.input_images[index]}') 88 | return self[np.random.randint(0, len(self))] 89 | 90 | def __len__(self): 91 | return len(self.input_images) 92 | 93 | def match_mulitple_path_v2(root, regex): 94 | groups = {} 95 | for modal in regex: 96 | images = glob.glob(modal + '{}.png'.format(FLAGS.layer*'/*')) 97 | images += glob.glob(modal + '{}.jpg'.format(FLAGS.layer*'/*')) 98 | images += glob.glob(modal + '{}.JPEG'.format(FLAGS.layer*'/*')) 99 | images += glob.glob(modal + '{}.jpeg'.format(FLAGS.layer*'/*')) 100 | for img in images: 101 | img_idx = img.split('/')[-1].split('.')[0] 102 | if not img_idx in groups: 103 | groups[img_idx] = [] 104 | groups[img_idx].append(img) 105 | 106 | groups = [groups[idx] for idx in groups if len(groups[idx]) == 5] 107 | 108 | return groups 109 | 110 | 111 | def main(argv): 112 | assert FLAGS.input_dir != '' 113 | assert FLAGS.input_regex != '' 114 | assert FLAGS.output_file != '' 115 | 116 | # Load the pre-trained vq model from the hub 117 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 118 | 119 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 120 | net.eval() 121 | 122 | regex = FLAGS.input_regex.split('::') 123 | input_images = match_mulitple_path_v2(FLAGS.input_dir, regex) 124 | 125 | print(f'Found {len(input_images)} images') 126 | assert len(input_images) > 0, 'No images found' 127 | 128 | if FLAGS.max_examples > 0: 129 | input_images = input_images[:FLAGS.max_examples] 130 | 131 | random.shuffle(input_images) 132 | 133 | dataset = MultipleImageDataset(input_images) 134 | dataloader = torch.utils.data.DataLoader( 135 | dataset, 136 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 137 | shuffle=False, 138 | num_workers=FLAGS.n_workers, 139 | drop_last=True 140 | ) 141 | 142 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) 143 | 144 | with NamedTemporaryFile() as ntf: 145 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 256 * len(input_images[0]))) 146 | all_tokens[:] = 0 147 | 148 | index = 0 149 | for batch in tqdm(dataloader, ncols=0): 150 | k = 0 151 | for image in batch: 152 | batch_size = image.shape[0] 153 | image = einops.rearrange( 154 | image.numpy(), 'b h w c -> b c h w' 155 | ) 156 | image = 
torch.tensor(image).to(device) 157 | _, tokens = net.encode(image) 158 | tokens = einops.rearrange( 159 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size 160 | ) 161 | all_tokens[index:index + image.shape[0], k:k + 256] = tokens 162 | k += 256 163 | index += batch[0].shape[0] 164 | 165 | with open(FLAGS.output_file, 'w') as fout: 166 | for _ in trange(FLAGS.n_epochs, ncols=0): 167 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots) 168 | for i in trange(indices.shape[0], ncols=0): 169 | tokens = deepcopy(all_tokens[indices[i], :]) 170 | tokens = einops.rearrange(tokens, 'b (s t) -> b s t', t=256) 171 | if FLAGS.shuffle_tasks: 172 | permutations = np.random.permutation(tokens.shape[1]) 173 | tokens = tokens[:, permutations, :] 174 | tokens = tokens.reshape(-1) 175 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),} 176 | fout.write(json.dumps(data) + '\n') 177 | 178 | 179 | if __name__ == '__main__': 180 | mlxu.run(main) -------------------------------------------------------------------------------- /tokenize_examples/tokenize_multi_seq_images_muse.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | from functools import partial 5 | from tempfile import NamedTemporaryFile 6 | import random 7 | import json 8 | from base64 import b64encode 9 | from tqdm import tqdm, trange 10 | 11 | import numpy as np 12 | np.float = np.float64 13 | np.int = np.int_ 14 | import mlxu 15 | from natsort import natsorted 16 | 17 | import torch 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import flax 22 | import einops 23 | 24 | from PIL import Image 25 | from muse import VQGANModel 26 | from utils import ( 27 | list_dir_with_full_path, is_image, read_image_to_tensor, 28 | randomly_subsample_frame_indices 29 | ) 30 | 31 | 32 | FLAGS, _ = mlxu.define_flags_with_default( 33 | input_dirs='', 34 | output_file='', 35 | batch_size=1, 36 | n_frames=16, 37 | n_shots=2, 38 | n_epochs=1, 39 | n_workers=8, 40 | max_stride=4, 41 | dtype='fp32', 42 | ) 43 | 44 | 45 | class MultiVideoDataset(torch.utils.data.Dataset): 46 | 47 | def __init__(self, videos, n_frames=8): 48 | self.videos = videos 49 | self.n_tasks = len(videos[0]) 50 | self.n_frames = n_frames 51 | 52 | def __getitem__(self, index): 53 | n_frames = len([x for x in list_dir_with_full_path(self.videos[index][0]) if is_image(x)]) 54 | for i in range(self.n_tasks): 55 | if len( 56 | [x for x in list_dir_with_full_path(self.videos[index][i]) 57 | if is_image(x)] 58 | ) != n_frames: 59 | print('Inconsistent number of frames') 60 | return self[np.random.randint(0, len(self))] 61 | if n_frames < self.n_frames: 62 | print(n_frames) 63 | return self[np.random.randint(0, len(self))] 64 | 65 | indices = randomly_subsample_frame_indices( 66 | n_frames, self.n_frames, FLAGS.max_stride, 67 | random_start=True 68 | ) 69 | 70 | all_frames = [] 71 | for task_idx in range(self.n_tasks): 72 | frames = [] 73 | all_files = [x for x in list_dir_with_full_path(self.videos[index][task_idx]) if is_image(x)] 74 | all_files = natsorted(all_files) 75 | for idx in indices: 76 | frames.append(read_image_to_tensor(all_files[idx])) 77 | 78 | all_frames.append(np.stack(frames, axis=0)) 79 | 80 | return np.stack(all_frames, axis=0) 81 | 82 | def __len__(self): 83 | return len(self.videos) 84 | 85 | 86 | def main(argv): 87 | assert FLAGS.input_dirs != '' 88 | assert FLAGS.output_file != '' 89 | 90 | # Load the pre-trained vq model from the hub 91 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 92 | 93 | net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device) 94 | net.eval() 95 | 96 | video_dirs = [sorted(glob.glob(x)) for x in FLAGS.input_dirs.split('::')] 97 | n_tasks = len(video_dirs) 98 | 99 | groups = {} 100 | 101 | for videos in [sorted(glob.glob(x)) for x in FLAGS.input_dirs.split('::')]: 102 | for video in videos: 103 | name = video.split('/')[-1] 104 | # name = video[:-len(video.split('/')[-1].split('_')[-1]) - 1] 105 | if name not in groups: 106 | groups[name] = [] 107 | groups[name].append(video) 108 | 109 | video_dirs = [] 110 | for name, videos in groups.items(): 111 | if len(videos) == n_tasks: 112 | video_dirs.append(videos) 113 | 114 | 115 | with open(FLAGS.output_file, 'w') as fout: 116 | dataset = MultiVideoDataset(video_dirs, n_frames=FLAGS.n_frames) 117 | 118 | dataloader = torch.utils.data.DataLoader( 119 | dataset, 120 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 121 | shuffle=False, 122 | num_workers=FLAGS.n_workers, 123 | prefetch_factor=4, 124 | drop_last=True, 125 | ) 126 | with torch.no_grad(): 127 | for _ in trange(FLAGS.n_epochs, ncols=0): 128 | 129 | all_tokens = np.zeros( 130 | dtype='i4', 131 | shape=(len(dataloader) * FLAGS.batch_size * FLAGS.n_shots, n_tasks, FLAGS.n_frames, 256) 132 | ) 133 | index = 0 134 | 135 | for batch in tqdm(dataloader, ncols=0): 136 | batch_size = batch.shape[0] 137 | batch = einops.rearrange( 138 | batch.numpy(), 'b t f h w c -> (b t f) c h w' 139 | ) 140 | batch = torch.tensor(batch).to(device) 141 | _, tokens = net.encode(batch) 142 | tokens = einops.rearrange( 143 | tokens.cpu().numpy().astype(np.int32), '(b t f) d -> b t f d', b=batch_size, t=n_tasks, f=FLAGS.n_frames 144 | ) 145 | all_tokens[index:index + batch_size, ...] = tokens 146 | index += batch_size 147 | 148 | 149 | random_indices = np.random.permutation(all_tokens.shape[0]) 150 | all_tokens = all_tokens[random_indices, ...] 
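# Descriptive note on the two serializations built below: the shuffled
# (n_samples, n_tasks, n_frames, 256) token array is flattened into visual
# sentences in two layouts per group of n_shots samples:
#   tokens_sep:        for each shot, all frames of task 1, then all frames of task 2, ...
#   tokens_interleave: for each shot, frame 1 of every task, then frame 2 of every task, ...
# Both layouts are written to the output JSONL, one record each.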
151 | 152 | 153 | tokens_sep = einops.rearrange( 154 | all_tokens, '(b x) t s d -> b (x t s d)', 155 | x=FLAGS.n_shots 156 | ) 157 | tokens_interleave = einops.rearrange( 158 | all_tokens, '(b x) t s d -> b (x s t d)', 159 | x=FLAGS.n_shots 160 | ) 161 | 162 | for i in range(tokens_sep.shape[0]): 163 | data = {'tokens': b64encode(tokens_sep[i].tobytes()).decode('utf-8')} 164 | fout.write(json.dumps(data) + '\n') 165 | 166 | data = {'tokens': b64encode(tokens_interleave[i].tobytes()).decode('utf-8')} 167 | fout.write(json.dumps(data) + '\n') 168 | 169 | 170 | 171 | 172 | if __name__ == '__main__': 173 | mlxu.run(main) 174 | -------------------------------------------------------------------------------- /tokenize_examples/tokenize_paired_dataset_muse.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from functools import partial 4 | from tempfile import NamedTemporaryFile 5 | import random 6 | import json 7 | from base64 import b64encode 8 | from tqdm import tqdm, trange 9 | 10 | import numpy as np 11 | import mlxu 12 | 13 | import torch 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import flax 18 | import einops 19 | 20 | from PIL import Image 21 | from muse import VQGANModel 22 | from utils import read_image_to_tensor 23 | 24 | 25 | FLAGS, _ = mlxu.define_flags_with_default( 26 | input_image_dir='./kitti-cot_crop/sementic_seg', 27 | output_image_dir='./kitti-cot_crop/sementic_seg', 28 | output_file='./test_kitti_semantic.jsonl', 29 | input_filter_key='', 30 | output_filter_key='', 31 | input_suffix='', 32 | output_suffix='', 33 | crop=False, 34 | batch_size=1, 35 | n_shots=8, 36 | n_epochs=5, 37 | n_workers=8, 38 | dtype='fp32', 39 | ) 40 | 41 | 42 | class PairedImageDataset(torch.utils.data.Dataset): 43 | 44 | def __init__(self, input_images, output_images): 45 | self.input_images = input_images 46 | self.output_images = output_images 47 | 48 | def __getitem__(self, index): 49 | try: 50 | return ( 51 | read_image_to_tensor(self.input_images[index], crop=FLAGS.crop), 52 | read_image_to_tensor(self.output_images[index], crop=FLAGS.crop) 53 | ) 54 | except Exception as e: 55 | print(f'Error: {e} for {self.input_images[index]}') 56 | return self[np.random.randint(0, len(self))] 57 | 58 | def __len__(self): 59 | return len(self.input_images) 60 | 61 | 62 | def main(argv): 63 | assert FLAGS.input_image_dir != '' 64 | assert FLAGS.output_image_dir != '' 65 | assert FLAGS.output_file != '' 66 | 67 | # Load the pre-trained vq model from the hub 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | 70 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 71 | net.eval() 72 | 73 | 74 | input_images = os.listdir(FLAGS.input_image_dir) 75 | input_images = [i for i in input_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')] 76 | input_images = [i for i in input_images if FLAGS.input_filter_key in i] 77 | input_images = sorted(input_images) 78 | output_images = os.listdir(FLAGS.output_image_dir) 79 | output_images = [i for i in output_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')] 80 | output_images = [i for i in output_images if FLAGS.output_filter_key in i] 81 | output_images = sorted(output_images) 82 | 83 | assert len(input_images) == len(output_images) 84 | 85 | 86 | input_images = [ 87 | os.path.join(FLAGS.input_image_dir, s) 88 | for s in input_images 89 | ] 90 | output_images = [ 91 | os.path.join(FLAGS.output_image_dir, s) 92 | for s 
in output_images 93 | ] 94 | 95 | dataset = PairedImageDataset(input_images, output_images) 96 | dataloader = torch.utils.data.DataLoader( 97 | dataset, 98 | batch_size=FLAGS.batch_size * FLAGS.n_shots, 99 | shuffle=False, 100 | num_workers=FLAGS.n_workers, 101 | drop_last=True 102 | ) 103 | 104 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) 105 | 106 | with torch.no_grad(): 107 | with NamedTemporaryFile() as ntf: 108 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512)) 109 | all_tokens[:] = 0 110 | 111 | index = 0 112 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0): 113 | _, input_token_batch = net.encode(input_image_batch.permute(0,3,1,2).to(device)) 114 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device)) 115 | 116 | 117 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate( 118 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)], 119 | axis=1 120 | ) 121 | index += input_image_batch.shape[0] 122 | 123 | with open(FLAGS.output_file, 'w') as fout: 124 | for _ in trange(FLAGS.n_epochs, ncols=0): 125 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots) 126 | for i in trange(indices.shape[0], ncols=0): 127 | tokens = all_tokens[indices[i], :].reshape(-1) 128 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),} 129 | fout.write(json.dumps(data) + '\n') 130 | 131 | 132 | if __name__ == '__main__': 133 | mlxu.run(main) -------------------------------------------------------------------------------- /tokenize_examples/tokenize_seq_images_muse.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | from functools import partial 5 | from tempfile import NamedTemporaryFile 6 | import random 7 | import json 8 | from base64 import b64encode 9 | from tqdm import tqdm, trange 10 | from muse import VQGANModel 11 | 12 | import numpy as np 13 | np.float = np.float64 14 | np.int = np.int_ 15 | import mlxu 16 | 17 | import torch 18 | 19 | 20 | import einops 21 | 22 | from PIL import Image 23 | 24 | from utils import ( 25 | list_dir_with_full_path, is_image, read_image_to_tensor, 26 | randomly_subsample_frame_indices 27 | ) 28 | 29 | 30 | 31 | 32 | FLAGS, _ = mlxu.define_flags_with_default( 33 | input_dirs='', 34 | output_file='', 35 | batch_size=1, 36 | n_frames=16, 37 | n_epochs=1, 38 | n_workers=8, 39 | max_stride=4, 40 | dtype='fp32', 41 | ) 42 | 43 | 44 | class VideoDataset(torch.utils.data.Dataset): 45 | 46 | def __init__(self, videos, n_frames=8): 47 | self.videos = videos 48 | self.n_frames = n_frames 49 | 50 | def __getitem__(self, index): 51 | frames = [] 52 | for file in sorted(list_dir_with_full_path(self.videos[index])): 53 | if is_image(file): 54 | frames.append(read_image_to_tensor(file)) 55 | if len(frames) < self.n_frames: 56 | return self[np.random.randint(0, len(self))] 57 | indices = randomly_subsample_frame_indices( 58 | len(frames), self.n_frames, FLAGS.max_stride, 59 | random_start=True 60 | ) 61 | frames = np.stack([frames[i] for i in indices], axis=0) 62 | return frames 63 | 64 | def __len__(self): 65 | return len(self.videos) 66 | 67 | 68 | def main(argv): 69 | assert FLAGS.input_dirs != '' 70 | assert FLAGS.output_file != '' 71 | 72 | # Load the pre-trained vq model from the hub 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | 75 | net = 
VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 76 | net.eval() 77 | 78 | # videos = list_dir_with_full_path(FLAGS.input_dir) 79 | videos = glob.glob(FLAGS.input_dirs) 80 | 81 | with torch.no_grad(): 82 | with open(FLAGS.output_file, 'w') as fout: 83 | dataset = VideoDataset(videos, n_frames=FLAGS.n_frames) 84 | dataloader = torch.utils.data.DataLoader( 85 | dataset, 86 | batch_size=FLAGS.batch_size, 87 | shuffle=False, 88 | num_workers=FLAGS.n_workers, 89 | prefetch_factor=4, 90 | drop_last=True, 91 | ) 92 | for _ in range(FLAGS.n_epochs): 93 | for batch in tqdm(dataloader, ncols=0): 94 | batch_size = batch.shape[0] 95 | batch = einops.rearrange( 96 | batch.numpy(), 'b t h w c -> (b t) c h w' 97 | ) 98 | batch = torch.tensor(batch).to(device) 99 | _, tokens = net.encode(batch) 100 | tokens = einops.rearrange( 101 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size 102 | ) 103 | for i in range(batch_size): 104 | data = {'tokens': b64encode(tokens[i].tobytes()).decode('utf-8')} 105 | fout.write(json.dumps(data) + '\n') 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | mlxu.run(main) 111 | -------------------------------------------------------------------------------- /tokenize_examples/tokenize_video_muse.py: -------------------------------------------------------------------------------- 1 | from base64 import b64encode 2 | from tqdm import tqdm, trange 3 | import numpy as np 4 | np.float = np.float64 5 | np.int = np.int_ 6 | from utils import read_frames_from_video, is_video 7 | 8 | import einops 9 | from torch.utils.data import Dataset, DataLoader 10 | from PIL import Image 11 | import numpy as np 12 | from tqdm import tqdm 13 | import torch 14 | from muse import VQGANModel 15 | from base64 import b64encode 16 | import json 17 | import os 18 | import mlxu 19 | 20 | 21 | 22 | FLAGS, _ = mlxu.define_flags_with_default( 23 | input_dir='DAVIS/JPEGImages/480p', 24 | output_file='vqlm/muse/running_script/tokenized_muse/davis.jsonl', 25 | batch_size=32, 26 | n_frames=16, 27 | n_workers=32, 28 | strides='8', 29 | n_epochs=1, 30 | dtype='fp32', 31 | ) 32 | 33 | 34 | class VideoDataset(torch.utils.data.Dataset): 35 | 36 | def __init__(self, videos, n_frames=8, stride=1): 37 | self.videos = videos 38 | self.n_frames = n_frames 39 | self.stride = stride 40 | 41 | def __getitem__(self, index): 42 | frames = read_frames_from_video(self.videos[index], self.n_frames, self.stride) 43 | if frames is None: 44 | return self[np.random.randint(0, len(self))] 45 | return frames 46 | 47 | def __len__(self): 48 | return len(self.videos) 49 | 50 | 51 | def main(argv): 52 | assert FLAGS.input_dir != '' 53 | assert FLAGS.output_file != '' 54 | 55 | # Load the pre-trained vq model from the hub 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | 58 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) 59 | net.eval() 60 | 61 | videos = [] 62 | for root, _, files in os.walk(FLAGS.input_dir): 63 | for file in files: 64 | if is_video(file): 65 | videos.append(os.path.join(root, file)) 66 | 67 | with open(FLAGS.output_file, 'w') as fout: 68 | with torch.no_grad(): 69 | for epoch in trange(FLAGS.n_epochs, ncols=0): 70 | for stride in tqdm(FLAGS.strides.split(','), ncols=0): 71 | stride = int(stride) 72 | dataset = VideoDataset(videos, n_frames=FLAGS.n_frames, stride=stride) 73 | dataloader = torch.utils.data.DataLoader( 74 | dataset, 75 | batch_size=FLAGS.batch_size, 76 | shuffle=False, 77 | num_workers=FLAGS.n_workers, 78 | 
prefetch_factor=4, 79 | drop_last=True, 80 | ) 81 | for batch in tqdm(dataloader, ncols=0): 82 | batch_size = batch.shape[0] 83 | batch = einops.rearrange( 84 | batch.numpy(), 'b t h w c -> (b t) c h w' 85 | ) 86 | batch = torch.tensor(batch).to(device) 87 | _, tokens = net.encode(batch) 88 | tokens = einops.rearrange( 89 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size 90 | ) 91 | for i in range(batch_size): 92 | data = {'tokens': b64encode(tokens[i].tobytes()).decode('utf-8'),} 93 | fout.write(json.dumps(data) + '\n') 94 | 95 | 96 | 97 | if __name__ == '__main__': 98 | mlxu.run(main) 99 | --------------------------------------------------------------------------------
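Each of the tokenizer scripts above writes one JSON object per visual sentence, with the flattened int32 token stream base64-encoded under the `tokens` key. The snippet below is a minimal sketch of reading one such record back into per-image token grids, assuming the 256-tokens-per-image layout these scripts produce; the path shown is just the default output location used by `tokenize_video_muse.py` and is a placeholder for your own JSONL file.

```python
# Illustrative sanity-check sketch (assumptions: int32 tokens, base64-encoded,
# 256 VQGAN codes per 256x256 image, as produced by the scripts above).
import json
from base64 import b64decode

import numpy as np

jsonl_path = 'vqlm/muse/running_script/tokenized_muse/davis.jsonl'  # placeholder path

with open(jsonl_path) as fin:
    record = json.loads(next(fin))  # first visual sentence in the file

tokens = np.frombuffer(b64decode(record['tokens']), dtype=np.int32)
token_grids = tokens.reshape(-1, 256)  # one row of 256 codes per image in the sentence
print(token_grids.shape)
```

The recovered rows can then be fed back through the same VQGAN tokenizer's decoder to visually verify that the sentence was assembled in the intended order.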