├── .DS_Store
├── DATASET.md
├── EasyLM
│   ├── __init__.py
│   ├── bpt.py
│   ├── checkpoint.py
│   ├── data.py
│   ├── jax_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   └── llama
│   │       ├── convert_easylm_to_hf.py
│   │       ├── llama_model.py
│   │       ├── llama_serve.py
│   │       └── llama_train.py
│   ├── optimizers.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── convert_checkpoint.py
│   │   ├── diff_checkpoint.py
│   │   ├── lm_eval_harness.py
│   │   └── lm_eval_json.py
│   └── serving.py
├── LICENSE
├── README.md
├── evaluation
│   ├── .DS_Store
│   ├── EVAL.md
│   ├── LICENSE
│   ├── environment.yml
│   └── vqlm_demo
│       ├── .DS_Store
│       ├── __init__.py
│       ├── batch_generation.py
│       ├── eval_perplexity.py
│       ├── eval_video_perplexity.py
│       ├── eval_videos.py
│       ├── generate_videos.py
│       ├── inference.py
│       ├── torch_vqvae_model.py
│       ├── utils.py
│       ├── vqvae
│       │   ├── .DS_Store
│       │   ├── __init__.py
│       │   ├── logging.py
│       │   └── modeling_utils.py
│       └── vqvae_muse.py
├── images
│   └── visual_sentences.jpg
├── scripts
│   ├── gpu_environment.yml
│   ├── tpu_commands.sh
│   └── tpu_vm_setup.sh
└── tokenize_examples
    ├── detokenization_muse.py
    ├── map_color.py
    ├── tokenize_catogory_images_muse.py
    ├── tokenize_co3d_muse.py
    ├── tokenize_colorization_dataset_muse.py
    ├── tokenize_inpainting_dataset_muse.py
    ├── tokenize_multi_datasets_muse.py
    ├── tokenize_multi_seq_images_muse.py
    ├── tokenize_paired_dataset_muse.py
    ├── tokenize_seq_images_muse.py
    └── tokenize_video_muse.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/.DS_Store
--------------------------------------------------------------------------------
/DATASET.md:
--------------------------------------------------------------------------------
1 |
2 | # Dataset Preparation
3 |
4 | This section describes how to prepare your dataset by tokenizing visual data into visual sentences and then mixing and shuffling the datasets.
5 |
6 | ## Download & Prepare Dataset
 7 | We provide individual scripts in `./tokenize_examples` to handle the different kinds of visual sentences described in the paper. Brief descriptions and instructions for each are listed below.
8 |
9 | ### Pair Datasets
10 | For pair datasets, where the visual sentence is constructed as `[image, label, image, label, image, label...]`, we use `tokenize_examples/tokenize_paired_dataset_muse.py` to generate the visual tokens. For example, for depth maps, surface normals, edges, and segmentation, we first use [prismer](https://github.com/NVlabs/prismer) to generate pseudo labels, then use the script to generate visual sentences. Note that for segmentation, we use an additional color mapping after obtaining the pseudo labels from prismer, which can be done by running `tokenize_examples/map_color.py`.
11 |
12 | ### Video Datasets
13 | For video datasets, the visual sentences are constructed as `[frame1, frame2, frame3, ..., frameN]`. One can use `tokenize_examples/tokenize_video_muse.py` to generate the visual sentences. The hyperparameter `stride` controls the frame sampling rate when extracting frames from a video.
14 |
15 | ### Colorization Datasets
16 | For colorization datasets, the visual sentences are constructed as `[gray_image, colored_image, gray_image, colored_image, ...]`. One can use `tokenize_examples/tokenize_colorization_dataset_muse.py` to generate the visual sentences. The user only needs to prepare the colored images, and the script will generate the gray counterparts.
17 |
18 | ### Inpainting Datasets
19 | For inpainting datasets, the visual sentences are constructed as `[masked_image, image, masked_image, image, ...]`. One can use `tokenize_examples/tokenize_inpainting_dataset_muse.py` to generate the visual sentences. The user can control the masked ratio by changing `FLAGS.hole_mask_ratio`.
20 |
21 | ### Multi-Datasets
22 | For multi-datasets, the visual sentences are constructed as `[dataset1_image, dataset1_image, dataset2_image, dataset2_image, ...]`. One can use `tokenize_examples/tokenize_multi_datasets_muse.py` to generate the visual sentences.
23 |
24 | ### Category Datasets
25 | For category datasets, the visual sentences are constructed as `[cate1_image, cate1_image, cate2_image, cate2_image,...]`. One can use `tokenize_examples/tokenize_catogory_images_muse.py` to generate the visual sentences. Note that the user can use `images_per_shot` to customize the number of images for each category and `n_shots` to control the number of categories in the visual sentences.
26 |
27 | In general, visual sentences with different logic can be constructed by combining the basic patterns above. After generating your visual sentences, you can always use `tokenize_examples/detokenization_muse.py` as a sanity check that the original images can be recovered from the visual sentences.
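As a concrete illustration of the basic patterns above, here is a minimal Python sketch of how the raw image sequences could be assembled before tokenization (the `images`, `labels`, `frames`, and `categories` inputs are hypothetical placeholders for whatever your dataset provides):

```python
def make_pair_sentence(images, labels):
    """[image, label, image, label, ...] for pair datasets."""
    sequence = []
    for image, label in zip(images, labels):
        sequence.extend([image, label])
    return sequence


def make_video_sentence(frames, stride=4):
    """[frame1, frame2, ...], keeping every `stride`-th frame of a video."""
    return frames[::stride]


def make_category_sentence(categories, images_per_shot=4, n_shots=2):
    """[cate1_image, ..., cate2_image, ...] from a {category: [images]} dict."""
    sequence = []
    for category in list(categories)[:n_shots]:
        sequence.extend(categories[category][:images_per_shot])
    return sequence
```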
28 |
29 |
30 | ## Visual Sentences
31 |
32 | - Following the concept of Visual Sentences, each image is tokenized separately, and the tokens are concatenated to form a visual sentence (see the sketch below).
33 |
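A minimal sketch of this step, assuming a hypothetical `encode_image(image)` helper that returns the VQ token ids for a single image (in this repo, the tokenize scripts in `./tokenize_examples` wrap the actual VQ tokenizer):

```python
def make_visual_sentence(images, encode_image):
    """Tokenize each image separately and concatenate the token ids."""
    tokens = []
    for image in images:
        tokens.extend(encode_image(image))  # hypothetical per-image VQ encoder
    return tokens
```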
34 | ## Generating JSONL Files
35 |
36 | - For each dataset, generate the `dataset*.jsonl` files.
37 |
38 | ### Example Code
39 |
40 | - You can find example code for tokenizing images in the `./tokenize_examples` directory; a minimal sketch of the resulting JSONL format is shown below.
41 |
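For reference, a minimal sketch of writing visual sentences to a `dataset*.jsonl` file, one JSON object per line (the `tokens` field name here is an illustrative assumption; use whatever field name your data loader is configured to read):

```python
import json

def write_jsonl(visual_sentences, path):
    """Write one visual sentence (a flat list of token ids) per line."""
    with open(path, "w") as f:
        for tokens in visual_sentences:
            f.write(json.dumps({"tokens": tokens}) + "\n")
```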
42 | ## Mixing and Shuffling Datasets
43 |
44 | - After generating the JSONL files, the datasets need to be mixed and shuffled. Follow these steps:
45 |
46 | 1. Set the temporary directory and the memory budget (in GB) used by `terashuf`:
47 | ```shell
48 | export TMPDIR='/path/to/your/temp'
49 | export MEMORY='20'
50 | ```
51 |
52 | 2. Navigate to your data directory:
53 | ```shell
54 | cd /path/to/your/data/
55 | ```
56 |
57 | 3. Mix and shuffle the datasets:
58 | ```shell
59 | cat tokenized_tasks/*.jsonl | terashuf > mix_and_shuffled/dataset.jsonl
60 | ```
61 |
62 | - This will create a mixed and shuffled dataset file named `dataset.jsonl` in the `mix_and_shuffled` directory.
--------------------------------------------------------------------------------
/EasyLM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/__init__.py
--------------------------------------------------------------------------------
/EasyLM/bpt.py:
--------------------------------------------------------------------------------
1 | """An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
2 |
3 | Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
4 | """
5 |
6 | import functools
7 | from typing import NamedTuple
8 |
9 | import flax.linen as nn
10 | import jax
11 | import jax.lax as lax
12 | import jax.numpy as jnp
13 | from einops import rearrange
14 |
15 | '''
16 | Compute the FFN blockwise without materializing the large hidden tensor, enabling training on 4x longer sequences than the memory-efficient transformer.
17 | Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
18 | '''
19 | def blockwise_ffn(remat_ffn, inputs, chunk_size, deterministic):
20 | # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
21 | # inputs: (batch, seq_len, dim)
22 | # chunk_size: the chunk size to split the sequence
23 | inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
24 | def scan_ffn(remat_ffn, carry, hidden_states):
25 | outputs = remat_ffn(hidden_states, deterministic=deterministic)
26 | return carry, outputs
27 | scan_axis = inputs.ndim - 2
28 | _, res = nn.scan(
29 | scan_ffn,
30 | variable_broadcast="params",
31 | split_rngs={"params": False, "dropout": True},
32 | in_axes=scan_axis,
33 | out_axes=scan_axis,
34 | )(remat_ffn, None, inputs)
35 | res = rearrange(res, 'b c n d -> b (c n) d')
36 | return res
37 |
38 |
39 | '''
40 | Compute attention blockwise without materializing the full attention matrix, as initially proposed in the memory-efficient transformer https://arxiv.org/abs/2112.05682 (Rabe et al. 2021);
41 | flash attention https://arxiv.org/abs/2205.14135 (Dao et al. 2022) proposes a CUDA-efficient implementation;
42 | the blockwise parallel transformer https://arxiv.org/abs/2305.19370 (Liu et al. 2023) computes both attention and the FFN blockwise, enabling 4x longer sequences than memory-efficient/flash attention and fusing attention with the FFN.
43 | '''
44 | def blockwise_attn(query, key, value, bias, deterministic,
45 | dropout_rng, attn_pdrop, causal, query_chunk_size,
46 | key_chunk_size, dtype, policy, precision, float32_logits,
47 | prevent_cse):
48 | # query, key, value: (batch, seq_len, num_heads, dim_per_head)
49 | # bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
50 | # causal: whether to use causal mask
51 | # policy: one of jax.checkpoint_policies
52 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
53 | if float32_logits:
54 | query = query.astype(jnp.float32)
55 | key = key.astype(jnp.float32)
56 |
57 | batch, q_len, num_heads, dim_per_head = query.shape
58 | batch, kv_len, num_heads, dim_per_head = key.shape
59 | batch, kv_len, num_heads, dim_per_head = value.shape
60 |
61 | num_q = q_len // query_chunk_size
62 | num_kv = kv_len // key_chunk_size
63 | query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
64 | key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
65 | value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
66 |
67 | query = jnp.moveaxis(query, 1, 0)
68 | key = jnp.moveaxis(key, 1, 0)
69 | value = jnp.moveaxis(value, 1, 0)
70 |
71 | if bias is not None:
72 | for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
73 | assert bias_dim == 1 or bias_dim == broadcast_dim
74 | if not deterministic and attn_pdrop > 0.0:
75 | attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
76 | attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
77 | else:
78 | attn_dropout = None
79 |
80 | _chunk_bias_fn = functools.partial(
81 | _chunk_attention_bias,
82 | query_chunk_size, key_chunk_size, bias, deterministic,
83 | attn_dropout, attn_pdrop, causal, dtype)
84 |
85 | def scan_attention(args):
86 | query_chunk, query_chunk_idx = args
87 |
88 | @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
89 | def scan_kv_block(carry, args):
90 | key_chunk, value_chunk, key_chunk_idx = args
91 | (numerator, denominator, prev_max_score) = carry
92 | attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
93 | bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
94 | bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
95 | attn_weights = attn_weights + bias_chunk
96 |
97 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
98 | max_score = jnp.maximum(prev_max_score, max_score)
99 | max_score = jax.lax.stop_gradient(max_score)
100 | exp_weights = jnp.exp(attn_weights - max_score)
101 | exp_values = jnp.einsum(
102 | 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
103 | )
104 | correction = jnp.exp(prev_max_score - max_score)
105 | numerator = numerator * correction + exp_values
106 | denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
107 | return Carry(numerator, denominator, max_score), None
108 |
109 | def skip_upper_half(carry, args):
110 | key_chunk, value_chunk, key_chunk_idx = args
111 | skip_block = jnp.array(False)
112 | if causal:
113 | skip_block = query_chunk_idx < key_chunk_idx
114 | return jax.lax.cond(
115 | skip_block,
116 | lambda carry, args: (carry, None),
117 | scan_kv_block,
118 | carry,
119 | args,
120 | )
121 |
122 | init_carry = Carry(
123 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
124 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
125 | (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
126 | )
127 | (numerator, denominator, max_score), _ = lax.scan(
128 | skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
129 | )
130 | outputs = (numerator / denominator).astype(dtype)
131 | return outputs
132 |
133 | _, res = lax.scan(
134 | lambda _, x: ((), scan_attention(x)),
135 | (), xs=(query, jnp.arange(0, num_q))
136 | )
137 | res = rearrange(res, 'n b c h d -> b (n c) h d')
138 | return res
139 |
140 |
141 | class Carry(NamedTuple):
142 | numerator: jax.Array
143 | denominator: jax.Array
144 | max_so_far: jax.Array
145 |
146 |
147 | def _chunk_attention_bias(query_chunk_size, key_chunk_size,
148 | bias, deterministic, attn_dropout, attn_pdrop, causal,
149 | dtype, query_chunk_idx, key_chunk_idx):
150 | query_offset = query_chunk_idx * query_chunk_size
151 | key_offset = key_chunk_idx * key_chunk_size
152 | chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
153 | if bias is not None:
154 | chunk_bias = lax.dynamic_slice(
155 | bias,
156 | start_indices=(0, 0, query_offset, key_offset),
157 | slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
158 | )
159 |
160 | if causal:
161 | query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
162 | key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
163 | offset = query_offset - key_offset
164 | query_idx += offset
165 | causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
166 | chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
167 |
168 | if not deterministic and attn_pdrop > 0.0:
169 | attn_dropout_slice = lax.dynamic_slice(
170 | attn_dropout,
171 | start_indices=(0, 0, query_offset, key_offset),
172 | slice_sizes=(
173 | *attn_dropout.shape[:2],
174 | min(attn_dropout.shape[-2], query_chunk_size),
175 | min(attn_dropout.shape[-1], key_chunk_size),
176 | ),
177 | )
178 | chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
179 | return chunk_bias.astype(dtype)
180 |
181 |
182 | if __name__ == '__main__':
183 | # test
184 | def reference_attn(query, key, value, causal, dtype):
185 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
186 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
187 | if causal:
188 | mask_value = jnp.finfo(logits.dtype).min
189 | _, q_seq_len, _, _ = query.shape
190 | _, kv_seq_len, _, _ = key.shape
191 | mask_shape = (q_seq_len, kv_seq_len)
192 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
193 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
194 | causal_mask = (row_ids < col_ids)[None, None, :, :]
195 | logits = logits + jnp.where(causal_mask, mask_value, 0.0)
196 | weights = jax.nn.softmax(logits, axis=-1)
197 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
198 | return out
199 |
200 | # random inputs
201 | shape = (1, 32, 8, 64)
202 | query = jax.random.normal(jax.random.PRNGKey(0), shape)
203 | key = jax.random.normal(jax.random.PRNGKey(1), shape)
204 | value = jax.random.normal(jax.random.PRNGKey(2), shape)
205 |
206 | causal = True
207 | chunk_size = 4
208 | policy = jax.checkpoint_policies.nothing_saveable()
209 |
210 | blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
211 | reference = reference_attn(query, key, value, causal, 'float32')
212 |
213 | assert jnp.allclose(reference, blockwise, atol=1e-6)
214 |
--------------------------------------------------------------------------------
/EasyLM/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from ml_collections import ConfigDict
4 | import mlxu
5 | import jax
6 | import jax.numpy as jnp
7 | import flax
8 | from flax.serialization import (
9 | from_bytes, to_bytes, to_state_dict, from_state_dict
10 | )
11 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
12 | import msgpack
13 |
14 | from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
15 |
16 |
17 | class StreamingCheckpointer(object):
18 | """ Custom msgpack checkpointer that saves large train states by serializing
19 |     and saving tensors one by one in a streaming fashion. This avoids running
20 |     out of memory or local TPU disk space, which can happen with the default flax checkpointer.
21 | """
22 |
23 | @staticmethod
24 | def get_default_config(updates=None):
25 | config = ConfigDict()
26 | config.float_dtype = 'bf16'
27 | config.save_optimizer_state = False
28 |
29 | if updates is not None:
30 | config.update(ConfigDict(updates).copy_and_resolve_references())
31 | return config
32 |
33 | def __init__(self, config, checkpoint_dir, enable=True):
34 | self.config = self.get_default_config(config)
35 | self.checkpoint_dir = checkpoint_dir
36 | self.enable = enable
37 |
38 | def save_checkpoint(self, train_state, filename, gather_fns=None):
39 | if self.enable:
40 | path = os.path.join(self.checkpoint_dir, filename)
41 | else:
42 | path = '/dev/null'
43 | self.save_train_state_to_file(
44 | train_state, path, gather_fns, self.config.float_dtype
45 | )
46 |
47 | @staticmethod
48 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
49 | train_state = to_state_dict(train_state)
50 | packer = msgpack.Packer()
51 | flattend_train_state = flatten_dict(train_state)
52 | if gather_fns is not None:
53 | gather_fns = flatten_dict(to_state_dict(gather_fns))
54 |
55 | with mlxu.open_file(path, "wb") as fout:
56 | for key, value in flattend_train_state.items():
57 | if gather_fns is not None:
58 | value = gather_fns[key](value)
59 | value = float_tensor_to_dtype(value, float_dtype)
60 | fout.write(packer.pack((key, to_bytes(value))))
61 |
62 | def save_pickle(self, obj, filename):
63 | if self.enable:
64 | path = os.path.join(self.checkpoint_dir, filename)
65 | else:
66 | path = '/dev/null'
67 | mlxu.save_pickle(obj, path)
68 |
69 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
70 | step = int(jax.device_get(train_state.step))
71 | if self.config.save_optimizer_state:
72 | checkpoint_state = train_state
73 | checkpoint_name = 'streaming_train_state'
74 | checkpoint_gather_fns = gather_fns
75 | else:
76 | checkpoint_state = train_state.params['params']
77 | checkpoint_name = 'streaming_params'
78 | checkpoint_gather_fns = gather_fns.params['params']
79 |
80 | if milestone:
81 | # Save a milestone checkpoint that will not be overwritten
82 | self.save_pickle(metadata, f'metadata_{step}.pkl')
83 | self.save_pickle(dataset, f'dataset_{step}.pkl')
84 | self.save_checkpoint(
85 | checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
86 | )
87 | else:
88 | # Save a normal checkpoint that can be overwritten
89 | self.save_pickle(metadata, 'metadata.pkl')
90 | self.save_pickle(dataset, 'dataset.pkl')
91 | self.save_checkpoint(
92 | checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
93 | )
94 |
95 | @staticmethod
96 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
97 | if shard_fns is not None:
98 | shard_fns = flatten_dict(
99 | to_state_dict(shard_fns)
100 | )
101 | if remove_dict_prefix is not None:
102 | remove_dict_prefix = tuple(remove_dict_prefix)
103 | flattend_train_state = {}
104 | with mlxu.open_file(path) as fin:
105 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS
106 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
107 | for key, value in unpacker:
108 | key = tuple(key)
109 | if remove_dict_prefix is not None:
110 | if key[:len(remove_dict_prefix)] == remove_dict_prefix:
111 | key = key[len(remove_dict_prefix):]
112 | else:
113 | continue
114 |
115 | tensor = from_bytes(None, value)
116 | if shard_fns is not None:
117 | tensor = shard_fns[key](tensor)
118 | flattend_train_state[key] = tensor
119 |
120 | if target is not None:
121 | flattened_target = flatten_dict(
122 | to_state_dict(target), keep_empty_nodes=True
123 | )
124 | for key, value in flattened_target.items():
125 | if key not in flattend_train_state and value == empty_node:
126 | flattend_train_state[key] = value
127 |
128 | train_state = unflatten_dict(flattend_train_state)
129 | if target is None:
130 | return train_state
131 |
132 | return from_state_dict(target, train_state)
133 |
134 | @staticmethod
135 | def load_flax_checkpoint(path, target=None, shard_fns=None):
136 | """ Load a standard flax checkpoint that's not saved with the
137 | msgpack streaming format.
138 | """
139 | with mlxu.open_file(path, "rb") as fin:
140 | encoded_bytes = fin.read()
141 |
142 | state_dict = flax.serialization.msgpack_restore(encoded_bytes)
143 | if shard_fns is not None:
144 | shard_fns = to_state_dict(shard_fns)
145 | state_dict = tree_apply(shard_fns, state_dict)
146 |
147 | if target is None:
148 | return state_dict
149 | return from_state_dict(target, state_dict)
150 |
151 | @classmethod
152 | def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
153 | trainstate_shard_fns=None,
154 | disallow_trainstate=False):
155 | if trainstate_target is not None:
156 | params_target = trainstate_target.params['params']
157 | else:
158 | params_target = None
159 |
160 | if trainstate_shard_fns is not None:
161 | params_shard_fns = trainstate_shard_fns.params['params']
162 | else:
163 | params_shard_fns = None
164 |
165 | load_type, load_path = load_from.split('::', 1)
166 | if disallow_trainstate:
167 | assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
168 | train_state = None
169 | restored_params = None
170 | if load_type == 'trainstate':
171 | # Load the entire train state in the streaming format
172 | train_state = cls.load_checkpoint(
173 | path=load_path,
174 | target=trainstate_target,
175 | shard_fns=trainstate_shard_fns,
176 | )
177 | elif load_type == 'trainstate_params':
178 | # Load the params part of the train state in the streaming format
179 | restored_params = cls.load_checkpoint(
180 | path=load_path,
181 | target=params_target,
182 | shard_fns=params_shard_fns,
183 | remove_dict_prefix=('params', 'params'),
184 | )
185 | restored_params = flax.core.frozen_dict.freeze(
186 | {'params': restored_params}
187 | )
188 | elif load_type == 'params':
189 | # Load the params in the streaming format
190 | restored_params = cls.load_checkpoint(
191 | path=load_path,
192 | target=params_target,
193 | shard_fns=params_shard_fns,
194 | )
195 | restored_params = flax.core.frozen_dict.freeze(
196 | {'params': restored_params}
197 | )
198 | elif load_type == 'flax_params':
199 | # Load the params in the standard flax format (non-streaming)
200 | # This requires the entire params to fit in memory
201 | restored_params = cls.load_flax_checkpoint(
202 | path=load_path,
203 | target=params_target,
204 | shard_fns=params_shard_fns
205 | )
206 | restored_params = flax.core.frozen_dict.freeze(
207 | {'params': restored_params}
208 | )
209 | else:
210 | raise ValueError(f'Invalid load_from type: {load_type}')
211 |
212 | return train_state, restored_params
213 |
--------------------------------------------------------------------------------
/EasyLM/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/models/__init__.py
--------------------------------------------------------------------------------
/EasyLM/models/llama/convert_easylm_to_hf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2 | # Copyright 2023 Xinyang Geng
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This script converts a LLaMA model checkpoint trained by EasyLM to the
17 | # HuggingFace transformers LLaMA PyTorch format, which can then be loaded
18 | # by HuggingFace transformers.
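#
# Example invocation (a sketch; paths are placeholders):
#   python -m EasyLM.models.llama.convert_easylm_to_hf \
#       --load_checkpoint='params::/path/to/streaming_params' \
#       --model_size='vqlm_1b' \
#       --output_dir='/path/to/hf_output'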
19 |
20 | import gc
21 | import json
22 | import math
23 | import os
24 | import shutil
25 |
26 | import numpy as np
27 | import mlxu
28 | import jax
29 | import jax.numpy as jnp
30 | import flax
31 | from flax.traverse_util import flatten_dict
32 | import torch
33 | from transformers import LlamaConfig, LlamaForCausalLM
34 |
35 | from EasyLM.checkpoint import StreamingCheckpointer
36 | from EasyLM.jax_utils import float_tensor_to_dtype
37 | from EasyLM.models.llama.llama_model import LLAMA_STANDARD_CONFIGS
38 |
39 |
40 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
41 | load_checkpoint='',
42 | model_size='vqlm_1b',
43 | output_dir='',
44 | )
45 |
46 |
47 | def match_keywords(string, positives, negatives):
48 | for positive in positives:
49 | if positive not in string:
50 | return False
51 | for negative in negatives:
52 | if negative in string:
53 | return False
54 | return True
55 |
56 |
57 | def load_and_convert_checkpoint(path):
58 | _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
59 | flax_params = flatten_dict(flax_params['params'], sep='.')
60 | torch_params = {}
61 | for key, tensor in flax_params.items():
62 | if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
63 | tensor = tensor.T
64 | torch_params[key] = torch.tensor(
65 | float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
66 | )
67 | return torch_params
68 |
69 |
70 | def read_json(path):
71 | with open(path, "r") as f:
72 | return json.load(f)
73 |
74 |
75 | def write_json(text, path):
76 | with open(path, "w") as f:
77 | json.dump(text, f)
78 |
79 |
80 | def write_model(loaded, model_path, model_size):
81 | os.makedirs(model_path, exist_ok=True)
82 | tmp_model_path = os.path.join(model_path, "tmp")
83 | os.makedirs(tmp_model_path, exist_ok=True)
84 |
85 | params = LLAMA_STANDARD_CONFIGS[model_size]
86 |
87 | n_layers = params["num_hidden_layers"]
88 | n_heads = params["num_attention_heads"]
89 | dim = params["hidden_size"]
90 | dims_per_head = dim // n_heads
91 | base = 10000.0
92 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
93 |
94 | # permute for sliced rotary
95 | def permute(w):
96 | return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
97 |
98 |
99 | param_count = 0
100 | index_dict = {"weight_map": {}}
101 | for layer_i in range(n_layers):
102 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
103 | state_dict = {
104 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
105 | loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
106 | ),
107 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
108 | loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
109 | ),
110 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
111 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
112 |
113 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
114 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
115 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
116 |
117 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
118 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
119 |
120 | }
121 |
122 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
123 | for k, v in state_dict.items():
124 | index_dict["weight_map"][k] = filename
125 | param_count += v.numel()
126 | torch.save(state_dict, os.path.join(tmp_model_path, filename))
127 |
128 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
129 | # Unsharded
130 | state_dict = {
131 | "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
132 | "model.norm.weight": loaded["transformer.ln_f.kernel"],
133 | "lm_head.weight": loaded["lm_head.kernel"],
134 | }
135 |
136 | for k, v in state_dict.items():
137 | index_dict["weight_map"][k] = filename
138 | param_count += v.numel()
139 | torch.save(state_dict, os.path.join(tmp_model_path, filename))
140 |
141 | # Write configs
142 | index_dict["metadata"] = {"total_size": param_count * 2}
143 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
144 |
145 | config = LlamaConfig(
146 | vocab_size=params["vocab_size"],
147 | max_position_embeddings=params["max_sequence_length"],
148 | hidden_size=dim,
149 | intermediate_size=params["intermediate_size"],
150 | num_attention_heads=params["num_attention_heads"],
151 | num_hidden_layers=params["num_hidden_layers"],
152 | rms_norm_eps=params["rms_norm_eps"],
153 | )
154 | config.save_pretrained(tmp_model_path)
155 |
156 | # Make space so we can load the model properly now.
157 | del state_dict
158 | del loaded
159 | gc.collect()
160 |
161 | print("Loading the checkpoint in a Llama model.")
162 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
163 | # Avoid saving this as part of the config.
164 | del model.config._name_or_path
165 |
166 | print("Saving in the Transformers format.")
167 | model.save_pretrained(model_path)
168 | shutil.rmtree(tmp_model_path)
169 |
170 |
171 | def main(argv):
172 | assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""
173 | assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
174 |
175 | write_model(
176 | load_and_convert_checkpoint(FLAGS.load_checkpoint),
177 | model_path=FLAGS.output_dir,
178 | model_size=FLAGS.model_size,
179 | )
180 |
181 |
182 | if __name__ == "__main__":
183 | mlxu.run(main)
--------------------------------------------------------------------------------
/EasyLM/models/llama/llama_train.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | from functools import partial
3 |
4 | from tqdm import tqdm, trange
5 | import numpy as np
6 | import mlxu
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | from jax.experimental.pjit import pjit
11 | from jax.sharding import PartitionSpec as PS
12 | from flax.training.train_state import TrainState
13 |
14 | from EasyLM.data import DatasetFactory
15 | from EasyLM.checkpoint import StreamingCheckpointer
16 | from EasyLM.optimizers import OptimizerFactory
17 | from EasyLM.jax_utils import (
18 | JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
19 | cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20 | set_random_seed, average_metrics, get_weight_decay_mask,
21 | make_shard_and_gather_fns, with_sharding_constraint,
22 | )
23 | from EasyLM.models.llama.llama_model import (
24 | LLaMAConfig, FlaxLLaMAForCausalLMModule
25 | )
26 |
27 |
28 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
29 | seed=42,
30 | mesh_dim='1,-1,1',
31 | dtype='fp32',
32 | total_steps=10000,
33 | load_llama_config='',
34 | update_llama_config='',
35 | load_checkpoint='',
36 | load_dataset_state='',
37 | load_metadata='',
38 | log_freq=50,
39 | save_model_freq=0,
40 | save_milestone_freq=0,
41 | eval_steps=0,
42 | tokenizer=LLaMAConfig.get_tokenizer_config(),
43 | train_dataset=DatasetFactory.get_default_config(),
44 | eval_dataset=DatasetFactory.get_default_config(),
45 | optimizer=OptimizerFactory.get_default_config(),
46 | checkpointer=StreamingCheckpointer.get_default_config(),
47 | llama=LLaMAConfig.get_default_config(),
48 | logger=mlxu.WandBLogger.get_default_config(),
49 | log_all_worker=False,
50 | jax_distributed=JaxDistributedConfig.get_default_config(),
51 | )
52 |
53 |
54 | def main(argv):
55 | JaxDistributedConfig.initialize(FLAGS.jax_distributed)
56 | variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
57 | flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
58 | logger = mlxu.WandBLogger(
59 | config=FLAGS.logger,
60 | variant=variant,
61 | enable=FLAGS.log_all_worker or (jax.process_index() == 0),
62 | )
63 | set_random_seed(FLAGS.seed)
64 |
65 | tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
66 | dataset = DatasetFactory.load_dataset(
67 | FLAGS.train_dataset, tokenizer, device_count=jax.device_count()
68 | )
69 | if FLAGS.load_dataset_state != '':
70 | dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
71 |
72 | if FLAGS.eval_steps > 0:
73 | eval_dataset = DatasetFactory.load_dataset(
74 | FLAGS.eval_dataset, dataset.tokenizer
75 | )
76 | eval_iterator = iter(eval_dataset)
77 |
78 | seq_length = dataset.seq_length
79 |
80 | if FLAGS.load_llama_config != '':
81 | llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
82 | else:
83 | llama_config = LLaMAConfig(**FLAGS.llama)
84 |
85 | if FLAGS.update_llama_config != '':
86 | llama_config.update(dict(eval(FLAGS.update_llama_config)))
87 |
88 | # llama_config.update(dict(
89 | # bos_token_id=dataset.tokenizer.bos_token_id,
90 | # eos_token_id=dataset.tokenizer.eos_token_id,
91 | # ))
92 | # if llama_config.vocab_size < dataset.vocab_size:
93 | # llama_config.update(dict(vocab_size=dataset.vocab_size))
94 |
95 | model = FlaxLLaMAForCausalLMModule(
96 | llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
97 | )
98 |
99 | optimizer, optimizer_info = OptimizerFactory.get_optimizer(
100 | FLAGS.optimizer,
101 | get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
102 | )
103 |
104 | def create_trainstate_from_params(params):
105 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
106 |
107 | def init_fn(rng):
108 | rng_generator = JaxRNG(rng)
109 | params = model.init(
110 | input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
111 | position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
112 | attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
113 | rngs=rng_generator(llama_config.rng_keys()),
114 | )
115 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
116 |
117 | def train_step(train_state, rng, batch):
118 | rng_generator = JaxRNG(rng)
119 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
120 | def loss_and_accuracy(params):
121 | logits = model.apply(
122 | params, batch['input_tokens'], deterministic=False,
123 | rngs=rng_generator(llama_config.rng_keys()),
124 | ).logits
125 | return cross_entropy_loss_and_accuracy(
126 | logits, batch['target_tokens'], batch['loss_masks']
127 | )
128 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
129 | (loss, accuracy), grads = grad_fn(train_state.params)
130 | train_state = train_state.apply_gradients(grads=grads)
131 | metrics = dict(
132 | loss=loss,
133 | accuracy=accuracy,
134 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
135 | gradient_norm=global_norm(grads),
136 | param_norm=global_norm(train_state.params),
137 | )
138 | return train_state, rng_generator(), metrics
139 |
140 | def eval_step(train_state, rng, batch):
141 | rng_generator = JaxRNG(rng)
142 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
143 | logits = model.apply(
144 | train_state.params, batch['input_tokens'], deterministic=True,
145 | rngs=rng_generator(llama_config.rng_keys()),
146 | ).logits
147 | loss, accuracy = cross_entropy_loss_and_accuracy(
148 | logits, batch['target_tokens'], batch['loss_masks']
149 | )
150 | metrics = dict(
151 | eval_loss=loss,
152 | eval_accuracy=accuracy,
153 | )
154 | return rng_generator(), metrics
155 |
156 | train_state_shapes = jax.eval_shape(init_fn, next_rng())
157 | train_state_partition = match_partition_rules(
158 | LLaMAConfig.get_partition_rules(), train_state_shapes
159 | )
160 |
161 | shard_fns, gather_fns = make_shard_and_gather_fns(
162 | train_state_partition, train_state_shapes
163 | )
164 | checkpointer = StreamingCheckpointer(
165 | FLAGS.checkpointer, logger.output_dir,
166 | enable=jax.process_index() == 0,
167 | )
168 |
169 | sharded_init_fn = pjit(
170 | init_fn,
171 | in_shardings=PS(),
172 | out_shardings=train_state_partition
173 | )
174 |
175 | sharded_create_trainstate_from_params = pjit(
176 | create_trainstate_from_params,
177 | in_shardings=(train_state_partition.params, ),
178 | out_shardings=train_state_partition,
179 | donate_argnums=(0, ),
180 | )
181 |
182 | sharded_train_step = pjit(
183 | train_step,
184 | in_shardings=(train_state_partition, PS(), PS()),
185 | out_shardings=(train_state_partition, PS(), PS()),
186 | donate_argnums=(0, 1),
187 | )
188 |
189 | sharded_eval_step = pjit(
190 | eval_step,
191 | in_shardings=(train_state_partition, PS(), PS()),
192 | out_shardings=(PS(), PS()),
193 | donate_argnums=(1,),
194 | )
195 |
196 | def save_checkpoint(train_state, milestone=False):
197 | step = int(jax.device_get(train_state.step))
198 | metadata = dict(
199 | step=step,
200 | variant=variant,
201 | flags=flags_config_dict,
202 | llama_config=llama_config.to_dict(),
203 | )
204 | checkpointer.save_all(
205 | train_state=train_state,
206 | gather_fns=gather_fns,
207 | metadata=metadata,
208 | dataset=dataset.get_state_dict(),
209 | milestone=milestone,
210 | )
211 |
212 | mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
213 | with mesh:
214 | train_state, restored_params = None, None
215 | if FLAGS.load_checkpoint != '':
216 | train_state, restored_params = checkpointer.load_trainstate_checkpoint(
217 | FLAGS.load_checkpoint, train_state_shapes, shard_fns
218 | )
219 |
220 | if train_state is None and restored_params is None:
221 | # Initialize from scratch
222 | train_state = sharded_init_fn(next_rng())
223 | elif train_state is None and restored_params is not None:
224 | # Restore from params but initialize train_state
225 | train_state = sharded_create_trainstate_from_params(restored_params)
226 | del restored_params
227 | # Hack to get the correct step from the metadata. Otherwise, the `train_state.step` is initialized at 0
228 | if FLAGS.load_metadata != '':
229 | loaded_step = mlxu.load_pickle(FLAGS.load_metadata)["step"] + 1
230 | train_state = train_state.replace(step=loaded_step)
231 | print(f"Resuming at step {loaded_step}")
232 |
233 | start_step = int(jax.device_get(train_state.step))
234 |
235 | # if FLAGS.save_model_freq > 0:
236 | # save_checkpoint(train_state)
237 |
238 | sharded_rng = next_rng()
239 |
240 | step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
241 |
242 | for step, (batch, dataset_metrics) in zip(step_counter, dataset):
243 | train_state, sharded_rng, metrics = sharded_train_step(
244 | train_state, sharded_rng, batch
245 | )
246 |
247 | if step % FLAGS.log_freq == 0:
248 | if FLAGS.eval_steps > 0:
249 | eval_metric_list = []
250 | for _ in range(FLAGS.eval_steps):
251 | eval_batch, _ = next(eval_iterator)
252 | sharded_rng, eval_metrics = sharded_eval_step(
253 | train_state, sharded_rng, eval_batch
254 | )
255 | eval_metric_list.append(eval_metrics)
256 | metrics.update(average_metrics(eval_metric_list))
257 |
258 | log_metrics = {"step": step}
259 | log_metrics.update(metrics)
260 | log_metrics.update(dataset_metrics)
261 | log_metrics = jax.device_get(log_metrics)
262 | logger.log(log_metrics)
263 | if jax.process_index() == 0:
264 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
265 |
266 | if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
267 | save_checkpoint(train_state, milestone=True)
268 | elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
269 | save_checkpoint(train_state)
270 |
271 | if FLAGS.save_model_freq > 0:
272 | save_checkpoint(train_state)
273 |
274 |
275 | if __name__ == "__main__":
276 | mlxu.run(main)
277 |
--------------------------------------------------------------------------------
/EasyLM/optimizers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4 | from functools import partial
5 | import re
6 | import dataclasses
7 | import random
8 |
9 | from ml_collections.config_dict import config_dict
10 | from ml_collections import ConfigDict
11 | import jax
12 | import jax.numpy as jnp
13 | import numpy as np
14 | from absl import logging
15 | import optax
16 |
17 | from EasyLM.jax_utils import float_to_dtype
18 |
19 |
20 | class OptimizerFactory(object):
21 | """ Configurable optax optimizer factory. """
22 |
23 | def __init__(self):
24 | raise NotImplementedError
25 |
26 | @staticmethod
27 | def get_default_config(updates=None):
28 | config = ConfigDict()
29 | config.accumulate_gradient_steps = 1
30 | config.type = 'adamw'
31 | config.palm_optimizer = PalmOptimizerFactory.get_default_config()
32 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
33 |
34 | if updates is not None:
35 | config.update(ConfigDict(updates).copy_and_resolve_references())
36 | return config
37 |
38 | @classmethod
39 | def get_optimizer(cls, config, weight_decay_mask=None):
40 | config = cls.get_default_config(config)
41 | if config.type == 'palm':
42 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
43 | config.palm_optimizer, weight_decay_mask
44 | )
45 | elif config.type == 'adamw':
46 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
47 | config.adamw_optimizer, weight_decay_mask
48 | )
49 | else:
50 | raise ValueError(f'Unknown optimizer type: {config.type}')
51 |
52 | if config.accumulate_gradient_steps > 1:
53 | optimizer = optax.MultiSteps(
54 | optimizer, config.accumulate_gradient_steps
55 | )
56 |
57 | return optimizer, optimizer_info
58 |
59 |
60 | class PalmOptimizerFactory(object):
61 | """ PaLM optimizer factory. This optimizer implements the optimizer
62 | described in the PaLM paper: https://arxiv.org/abs/2204.02311
63 | """
64 |
65 | def __init__(self):
66 | raise NotImplementedError
67 |
68 | @staticmethod
69 | def get_default_config(updates=None):
70 | config = ConfigDict()
71 | config.lr = 0.01
72 | config.lr_warmup_steps = 10000
73 | config.b1 = 0.9
74 | config.b2 = 0.99
75 | config.clip_gradient = 1.0
76 | config.weight_decay = 1e-4
77 | config.bf16_momentum = False
78 |
79 | if updates is not None:
80 | config.update(ConfigDict(updates).copy_and_resolve_references())
81 | return config
82 |
83 | @classmethod
84 | def get_optimizer(cls, config, weight_decay_mask=None):
85 | config = cls.get_default_config(config)
86 |
87 | def learning_rate_schedule(step):
88 | multiplier = config.lr / 0.01
89 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
90 |
91 | def weight_decay_schedule(step):
92 | multiplier = config.weight_decay / 1e-4
93 | return -multiplier * jnp.square(learning_rate_schedule(step))
94 |
95 | optimizer_info = dict(
96 | learning_rate_schedule=learning_rate_schedule,
97 | weight_decay_schedule=weight_decay_schedule,
98 | )
99 |
100 | optimizer = optax.chain(
101 | optax.clip_by_global_norm(config.clip_gradient),
102 | optax.adafactor(
103 | learning_rate=learning_rate_schedule,
104 | multiply_by_parameter_scale=True,
105 | momentum=config.b1,
106 | decay_rate=config.b2,
107 | factored=False,
108 | clipping_threshold=None,
109 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
110 | ),
111 | optax_add_scheduled_weight_decay(
112 | weight_decay_schedule, weight_decay_mask
113 | )
114 | )
115 | return optimizer, optimizer_info
116 |
117 |
118 | class AdamWOptimizerFactory(object):
119 | """ AdamW optimizer with cosine schedule. """
120 |
121 | def __init__(self):
122 | raise NotImplementedError
123 |
124 | @staticmethod
125 | def get_default_config(updates=None):
126 | config = ConfigDict()
127 | config.init_lr = 0.0
128 | config.end_lr = 0.001
129 | config.lr = 0.01
130 | config.lr_warmup_steps = 2000
131 | config.lr_decay_steps = 500000
132 | config.b1 = 0.9
133 | config.b2 = 0.95
134 | config.clip_gradient = 1.0
135 | config.weight_decay = 1e-4
136 | config.bf16_momentum = False
137 | config.multiply_by_parameter_scale = False
138 |
139 | if updates is not None:
140 | config.update(ConfigDict(updates).copy_and_resolve_references())
141 | return config
142 |
143 | @classmethod
144 | def get_optimizer(cls, config, weight_decay_mask=None):
145 | config = cls.get_default_config(config)
146 |
147 | learning_rate_schedule = optax.warmup_cosine_decay_schedule(
148 | init_value=config.init_lr,
149 | peak_value=config.lr,
150 | warmup_steps=config.lr_warmup_steps,
151 | decay_steps=config.lr_decay_steps,
152 | end_value=config.end_lr,
153 | )
154 |
155 | optimizer_info = dict(
156 | learning_rate_schedule=learning_rate_schedule,
157 | )
158 |
159 | if config.multiply_by_parameter_scale:
160 | optimizer = optax.chain(
161 | optax.clip_by_global_norm(config.clip_gradient),
162 | optax.adafactor(
163 | learning_rate=learning_rate_schedule,
164 | multiply_by_parameter_scale=True,
165 | momentum=config.b1,
166 | decay_rate=config.b2,
167 | factored=False,
168 | clipping_threshold=None,
169 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
170 | ),
171 | optax_add_scheduled_weight_decay(
172 | lambda step: -learning_rate_schedule(step) * config.weight_decay,
173 | weight_decay_mask
174 | )
175 | )
176 | else:
177 | optimizer = optax.chain(
178 | optax.clip_by_global_norm(config.clip_gradient),
179 | optax.adamw(
180 | learning_rate=learning_rate_schedule,
181 | weight_decay=config.weight_decay,
182 | b1=config.b1,
183 | b2=config.b2,
184 | mask=weight_decay_mask,
185 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
186 | ),
187 | )
188 |
189 | return optimizer, optimizer_info
190 |
191 |
192 | class OptaxScheduledWeightDecayState(NamedTuple):
193 | count: jax.Array
194 |
195 |
196 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
197 | """ Apply weight decay with schedule. """
198 |
199 | def init_fn(params):
200 | del params
201 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
202 |
203 | def update_fn(updates, state, params):
204 | if params is None:
205 | raise ValueError('Params cannot be None for weight decay!')
206 |
207 | weight_decay = schedule_fn(state.count)
208 | updates = jax.tree_util.tree_map(
209 | lambda g, p: g + weight_decay * p, updates, params
210 | )
211 | return updates, OptaxScheduledWeightDecayState(
212 | count=optax.safe_int32_increment(state.count)
213 | )
214 |
215 | if mask is not None:
216 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
217 | return optax.GradientTransformation(init_fn, update_fn)
218 |
--------------------------------------------------------------------------------
/EasyLM/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/EasyLM/scripts/__init__.py
--------------------------------------------------------------------------------
/EasyLM/scripts/convert_checkpoint.py:
--------------------------------------------------------------------------------
 1 | # This script converts a model checkpoint trained by EasyLM to a standard
 2 | # msgpack checkpoint that can be loaded by huggingface transformers or
3 | # flax.serialization.msgpack_restore. Such conversion allows models to be
4 | # used by other frameworks that integrate with huggingface transformers.
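#
# Example invocation (a sketch; paths are placeholders):
#   python -m EasyLM.scripts.convert_checkpoint \
#       --load_checkpoint='params::/path/to/streaming_params' \
#       --output_file='/path/to/model.msgpack' \
#       --streaming=False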
5 |
6 | import pprint
7 | from functools import partial
8 | import os
9 | import numpy as np
10 | import mlxu
11 | import jax.numpy as jnp
12 | import flax.serialization
13 | from EasyLM.checkpoint import StreamingCheckpointer
14 | from EasyLM.jax_utils import float_to_dtype
15 |
16 |
17 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
18 | load_checkpoint='',
19 | output_file='',
20 | streaming=False,
21 | float_dtype='bf16',
22 | )
23 |
24 |
25 | def main(argv):
26 | assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
27 | params = StreamingCheckpointer.load_trainstate_checkpoint(
28 | FLAGS.load_checkpoint, disallow_trainstate=True
29 | )[1]['params']
30 |
31 | if FLAGS.streaming:
32 | StreamingCheckpointer.save_train_state_to_file(
33 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
34 | )
35 | else:
36 | params = float_to_dtype(params, FLAGS.float_dtype)
37 |         with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
38 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
39 |
40 |
41 | if __name__ == "__main__":
42 | mlxu.run(main)
43 |
--------------------------------------------------------------------------------
/EasyLM/scripts/diff_checkpoint.py:
--------------------------------------------------------------------------------
 1 | # This script computes the parameter difference between a base checkpoint and a
 2 | # target checkpoint (target - base), or recovers the target parameters from the
 3 | # base and a previously saved diff when --recover_diff=True. The result is saved
 4 | # either in the streaming format or as a standard msgpack checkpoint.
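#
# Example invocation (a sketch; paths are placeholders):
#   python -m EasyLM.scripts.diff_checkpoint \
#       --load_base_checkpoint='params::/path/to/base_params' \
#       --load_target_checkpoint='params::/path/to/target_params' \
#       --output_file='/path/to/params_diff' \
#       --streaming=True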
5 |
6 | import pprint
7 | from functools import partial
8 | import os
9 | import numpy as np
10 | import jax
11 | import jax.numpy as jnp
12 | import flax.serialization
13 | import mlxu
14 | from EasyLM.checkpoint import StreamingCheckpointer
15 | from EasyLM.jax_utils import float_to_dtype
16 |
17 |
18 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
19 | recover_diff=False,
20 | load_base_checkpoint='',
21 | load_target_checkpoint='',
22 | output_file='',
23 | streaming=True,
24 | float_dtype='bf16',
25 | )
26 |
27 |
28 | def main(argv):
29 | assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
30 | assert FLAGS.output_file != ''
31 | base_params = StreamingCheckpointer.load_trainstate_checkpoint(
32 | FLAGS.load_base_checkpoint, disallow_trainstate=True
33 | )[1]['params']
34 |
35 | target_params = StreamingCheckpointer.load_trainstate_checkpoint(
36 | FLAGS.load_target_checkpoint, disallow_trainstate=True
37 | )[1]['params']
38 |
39 | if FLAGS.recover_diff:
40 | params = jax.tree_util.tree_map(
41 | lambda b, t: b + t, base_params, target_params
42 | )
43 | else:
44 | params = jax.tree_util.tree_map(
45 | lambda b, t: t - b, base_params, target_params
46 | )
47 |
48 | if FLAGS.streaming:
49 | StreamingCheckpointer.save_train_state_to_file(
50 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
51 | )
52 | else:
53 | params = float_to_dtype(params, FLAGS.float_dtype)
54 |         with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
55 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
56 |
57 |
58 | if __name__ == "__main__":
59 | mlxu.run(main)
60 |
--------------------------------------------------------------------------------
/EasyLM/scripts/lm_eval_harness.py:
--------------------------------------------------------------------------------
1 | # This script runs lm_eval_harness evaluations against a served language model.
2 | # Typically, you need to run a language model server first, e.g.:
 3 | # python -m EasyLM.models.llama.llama_serve ...
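#
# Example invocation (a sketch; only flags defined below are shown):
#   python -m EasyLM.scripts.lm_eval_harness \
#       --tasks='piqa,winogrande' --shots=0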
4 |
5 | import dataclasses
6 | import pprint
7 | from functools import partial
8 | import os
9 | from tqdm import tqdm, trange
10 | import numpy as np
11 | import mlxu
12 |
13 | from flax.traverse_util import flatten_dict
14 | from lm_eval import evaluator, tasks
15 | from lm_eval.base import LM
16 |
17 | from EasyLM.serving import LMClient
18 |
19 |
20 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
21 | tasks='wsc,piqa,winogrande,openbookqa,logiqa',
22 | shots=0,
23 | limit=0,
24 | write_out=False,
25 | lm_client=LMClient.get_default_config(),
26 | logger=mlxu.WandBLogger.get_default_config(),
27 | )
28 |
29 |
30 | class LMEvalHarnessInterface(LM):
31 |
32 | def __init__(self, lm_client):
33 | self.lm_client = lm_client
34 |
35 | def greedy_until(self, inputs):
36 | prefix, until = zip(*inputs)
37 | return self.lm_client.greedy_until(prefix, until)
38 |
39 | def loglikelihood_rolling(self, inputs):
40 | loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
41 | return list(zip(loglikelihood, is_greedy))
42 |
43 | def loglikelihood(self, inputs):
44 | prefix, text = zip(*inputs)
45 | loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
46 | return list(zip(loglikelihood, is_greedy))
47 |
48 |
49 | def main(argv):
50 | logger = mlxu.WandBLogger(
51 | config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
52 | )
53 | model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
54 | task_list = FLAGS.tasks.split(',')
55 | results = evaluator.evaluate(
56 | model, tasks.get_task_dict(task_list), False, FLAGS.shots,
57 | limit=None if FLAGS.limit <= 0 else FLAGS.limit,
58 | write_out=FLAGS.write_out,
59 | )
60 | logger.log(flatten_dict(results['results'], sep='/'))
61 | pprint.pprint(results)
62 |
63 |
64 | if __name__ == "__main__":
65 | mlxu.run(main)
66 |
--------------------------------------------------------------------------------
/EasyLM/scripts/lm_eval_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import mlxu
3 | from EasyLM.serving import LMClient
4 |
5 |
6 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
7 | input_file='',
8 | output_file='',
9 | prefix_field='prefix',
10 | text_field='text',
11 | until_field='until',
12 | eval_type='loglikelihood',
13 | lm_client=LMClient.get_default_config(),
14 | )
15 |
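# The input file is a JSON object whose fields are named by the flags above and
# hold parallel lists. A sketch for --eval_type='loglikelihood':
#   {"prefix": ["The capital of France is", ...], "text": [" Paris", ...]}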
16 |
17 | def main(argv):
18 | lm_client = LMClient(FLAGS.lm_client)
19 | with mlxu.open_file(FLAGS.input_file, 'r') as fin:
20 | input_data = json.load(fin)
21 |
22 | if FLAGS.eval_type == 'loglikelihood':
23 | prefix = input_data[FLAGS.prefix_field]
24 | text = input_data[FLAGS.text_field]
25 | loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
26 | output_data = {
27 | 'loglikelihood': loglikelihoods,
28 | 'is_greedy': is_greedys,
29 | }
30 | elif FLAGS.eval_type == 'loglikelihood_rolling':
31 | text = input_data[FLAGS.text_field]
32 | loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
33 | output_data = {
34 | 'loglikelihood': loglikelihoods,
35 | 'is_greedy': is_greedys,
36 | }
37 | elif FLAGS.eval_type == 'greedy_until':
38 | prefix = input_data[FLAGS.prefix_field]
39 | until = input_data[FLAGS.until_field]
40 | output_data = {'output_text': lm_client.greedy_until(prefix, until)}
41 | elif FLAGS.eval_type == 'generate':
42 | prefix = input_data[FLAGS.prefix_field]
43 | output_data = {'output_text': lm_client.generate(prefix)}
44 | else:
45 | raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
46 |
47 | with mlxu.open_file(FLAGS.output_file, 'w') as fout:
48 | json.dump(output_data, fout)
49 |
50 |
51 | if __name__ == "__main__":
52 | mlxu.run(main)
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # LVM: Sequential Modeling Enables Scalable Learning for Large Vision Models
3 |
4 | [LVM](https://arxiv.org/abs/2312.00785) is a vision pretraining model that converts various kinds of visual data into visual sentences and performs next-token prediction autoregressively. It is compatible with both GPU and TPU.
5 |
6 | LVM is built on top of [OpenLLaMA](https://github.com/openlm-research/open_llama) (an autoregressive model) and [OpenMuse](https://github.com/huggingface/open-muse) (a VQGAN that converts images into visual tokens).
7 |
8 | This model was trained in collaboration with Hugging Face. Thanks to [Victor Sanh](https://github.com/VictorSanh) for the support on this project.
9 |
10 | ## Abstract:
11 |
12 | We introduce a novel sequential modeling approach which enables learning a Large Vision Model (LVM) without making use of any linguistic data.
13 | To do this, we define a common format, "visual sentences", in which we can represent raw images and videos as well as annotated data sources such as semantic segmentations and depth reconstructions without needing any meta-knowledge beyond the pixels. Once this wide variety of visual data (comprising 420 billion tokens) is represented as sequences, the model can be trained to minimize a cross-entropy loss for next token prediction. By training across various scales of model architecture and data diversity, we provide empirical evidence that our models scale effectively. Many different vision tasks can be solved by designing suitable visual prompts at test time.
14 |
15 | ## Visual Sentence
16 |
17 |
18 |

19 |
20 |
21 |
22 | ## Key Differences from the Original Paper Version
23 | 1. We are currently releasing the 7B model (previously 3B). Additional model size variants will be available later.
24 | 2. Deep filtering (including quality filters, deduplication, and known CSAM content removal) has been applied to the LAION dataset, reducing the dataset size from 1.5B to 1.2B images.
25 |
26 | 3. The tokenizer has been improved for better performance.
27 |
28 | ## License
29 | LVM is licensed under the Apache 2.0 License.
30 |
31 | ## Installation
32 | ```shell
33 | git clone https://github.com/ytongbai/LVM
34 | cd LVM
35 | export PYTHONPATH="${PWD}:$PYTHONPATH"
36 | ```
37 |
38 | ## Environment Setup
39 | ```shell
40 | conda env create -f scripts/gpu_environment.yml
41 | conda activate LVM
42 | ```
43 |
44 | ## Dataset Preparation
45 | Please refer to `DATASET.md` for detailed instructions on preparing the dataset.
46 |
47 | After preparing the dataset, you will get a pretokenized file `dataset.jsonl`.
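For reference, here is a minimal sketch of what a line of this pretokenized file could look like. It assumes each line is a JSON object whose `tokens` field holds the visual token ids of one visual sentence as a whitespace-separated string, matching the `--train_dataset.text_processor.fields=',{tokens},'` flag used in the training script below; the exact layout depends on how you ran the tokenization scripts.

```python
import json

# Hypothetical example: two visual sentences, each already tokenized by the
# VQGAN tokenizer into a flat list of integer token ids (placeholder values).
visual_sentences = [
    [103, 4021, 88, 907],
    [55, 671, 3090, 12, 844],
]

with open('dataset.jsonl', 'w') as f:
    for tokens in visual_sentences:
        # One JSON object per line; the training script reads the `tokens` field.
        f.write(json.dumps({'tokens': ' '.join(map(str, tokens))}) + '\n')
```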
48 |
49 | ## Training Script
50 |
51 | We provide an example training script for the 7B model. For more details about the distributed training setup, please refer to [EasyLM](https://github.com/young-geng/EasyLM).
52 |
53 | For other model sizes, we provide model definitions from 100M, 300M, 600M, 1B, 3B, 7B, 13B, and 20B up to 30B in `./EasyLM/models/llama/llama_model.py`. Set `--total_steps` according to the amount of training data.
54 |
55 |
56 | ```shell
57 | python -u -m EasyLM.models.llama.llama_train \
58 | --jax_distributed.initialize_jax_distributed=True \
59 | --jax_distributed.coordinator_address="$MASTER_ADDR:$MASTER_PORT" \
60 | --jax_distributed.local_device_ids='0,1,2,3,4,5,6,7' \
61 | --mesh_dim="$SLURM_NNODES,-1,1" \
62 | --dtype='bf16' \
63 | --total_steps=400000 \
64 | --log_freq=10 \
65 | --save_model_freq=1000 \
66 | --save_milestone_freq=2000 \
67 | --load_llama_config='vqlm_7b' \
68 | --optimizer.type='adamw' \
69 | --optimizer.adamw_optimizer.weight_decay=0.1 \
70 | --optimizer.adamw_optimizer.lr=1.5e-4 \
71 | --optimizer.adamw_optimizer.end_lr=3e-5 \
72 | --optimizer.adamw_optimizer.lr_warmup_steps=8000 \
73 | --optimizer.adamw_optimizer.lr_decay_steps=288000 \
74 | --optimizer.accumulate_gradient_steps=4 \
75 | --train_dataset.type='json' \
76 | --train_dataset.text_processor.fields=',{tokens},' \
77 | --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \
78 | --train_dataset.json_dataset.seq_length=4096 \
79 | --train_dataset.json_dataset.batch_size=32 \
80 | --train_dataset.json_dataset.tokenizer_processes=16 \
81 | --checkpointer.save_optimizer_state=True \
82 | --logger.online=True \
83 | --logger.output_dir="/path/to/checkpoint/$RUN_NAME" \
84 | --logger.wandb_dir='/path/to/wandb' \
85 | --logger.notes='' \
86 | --logger.experiment_id=$EXPERIMENT_ID
87 | ```
88 |
89 | ## Convert to Hugging Face Checkpoint
90 |
91 | ```shell
92 | python -m EasyLM.models.llama.convert_easylm_to_hf --load_checkpoint='trainstate_params::/path/to/checkpoint/streaming_train_state' --model_size='vqlm_7b' --output_dir='/path/to/output/checkpoint/'
93 | ```
94 |
95 | ## Demo & Inference
96 |
97 | Download the [few-shot examples dataset](https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/ybai20_jh_edu/Ei0xiLdFFqJPnwAlFWar29EBUAvB0O3CVaJykZl-f11KDQ?e=Bx9SXZ).
98 |
99 | There are mainly two types of visual prompting: sequential prompting and analogy prompting.
100 |
101 | ### Analogy Prompting:
102 | Describe the task with few-shot examples, i.e., pairs of (x, y) inputs where x is the input image and y is the "annotated" image, and then append one query image at the end. We provide more few-shot examples at [this link](https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/ybai20_jh_edu/Ei0xiLdFFqJPnwAlFWar29EBUAvB0O3CVaJykZl-f11KDQ?e=Bx9SXZ); simply change the query image at the end for testing.
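Below is a minimal sketch of analogy prompting in Python, using the `MultiProcessInferenceModel` wrapper from `evaluation/vqlm_demo` (the same interface the evaluation scripts use). The checkpoint path and image file names are placeholders, and it assumes the checkpoint has already been converted to the Hugging Face format as described above.

```python
import numpy as np
from PIL import Image

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import read_image_to_tensor

# Placeholder paths: (input, annotation) pairs followed by one query image.
prompt_paths = [
    'examples/input1.jpg', 'examples/label1.jpg',
    'examples/input2.jpg', 'examples/label2.jpg',
    'examples/query.jpg',
]

# Each image becomes a 256x256 float array in [0, 1]; stacking them gives one
# visual sentence, and a leading batch dimension gives shape (b, s, h, w, c).
batch = np.stack([read_image_to_tensor(p) for p in prompt_paths], axis=0)[None]

model = MultiProcessInferenceModel(checkpoint='/path/to/converted/hf/checkpoint')

# Ask for one new frame (the predicted annotation for the query image),
# sampling four candidates.
generated = np.array(
    model(batch, n_new_frames=1, n_candidates=4, temperature=1.0, top_p=1.0)
)

# The evaluation scripts treat the output as (batch, candidate, frame, h, w, c)
# with values in [0, 1]; save the first candidate's predicted frame.
Image.fromarray((generated[0, 0, 0] * 255).astype(np.uint8)).save('prediction.png')
```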
103 |
104 | ### Sequential Prompting:
105 | Input a sequence of consecutive frames and let the model generate the next one.
106 |
107 |
108 | Check out our demo and additional inference code on Hugging Face Spaces: [LVM Demo](https://huggingface.co/spaces/Emma02/LVM)
109 |
110 |
111 |
112 | ## Evaluation
113 |
114 | See `evaluation/EVAL.md` for evaluation instructions.
115 |
116 | ## Models
117 | - [LVM Checkpoints](https://huggingface.co/Emma02/LVM_ckpts)
118 | - [VQ-VAE Checkpoints](https://huggingface.co/Emma02/vqvae_ckpts)
119 |
120 |
121 | ## Finetuning
122 |
123 | LVM is a pretrained model without instruction tuning or other kinds of post-training. To adapt it to a specific task, we recommend organizing your data into the visual sentence format and then finetuning with a smaller learning rate using the training script we provide.
124 |
125 | ## Citation
126 | If you find LVM useful in your research or applications, please cite our work using the following BibTeX:
127 |
128 | ```bibtex
129 | @article{bai2023sequential,
130 | title={Sequential modeling enables scalable learning for large vision models},
131 | author={Bai, Yutong and Geng, Xinyang and Mangalam, Karttikeya and Bar, Amir and Yuille, Alan and Darrell, Trevor and Malik, Jitendra and Efros, Alexei A},
132 | journal={arXiv preprint arXiv:2312.00785},
133 | year={2023}
134 | }
135 |
136 | ```
137 |
--------------------------------------------------------------------------------
/evaluation/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/.DS_Store
--------------------------------------------------------------------------------
/evaluation/EVAL.md:
--------------------------------------------------------------------------------
1 | # Large Vision Model Evaluation
2 | This is an evaluation demo for the Large Vision Model (LVM) paper.
3 |
4 |
5 | ### Setting Up the Conda Environment
6 | You can set up the demo using a conda environment. First, create the conda
7 | environment and install the dependencies:
8 | ```bash
9 | conda env create -f environment.yml
10 | conda activate vqlm_demo
11 | export PYTHONPATH=/path/to/this/repo:$PYTHONPATH
12 |
13 | ```
14 |
15 | Then download the [VQ tokenizer checkpoint](https://huggingface.co/Emma02/vqvae_ckpts) and put it into `./vqvae_ckpts/`.
16 |
17 |
18 | ## Running the Perplexity Evaluation
19 | This repo also contains the perplexity evaluation script. You can run the following
20 | command to evaluate the perplexity of the model:
21 |
22 | ```bash
23 | python -m vqlm_demo.eval_perplexity \
24 | --input_file=path/to/input_jsonl_file \
25 | --input_base_dir=base/path/to/add/to/the/input \
26 | --checkpoint=path/to/checkpoint \
27 | --batch_size=4
28 | ```
29 |
30 | This script accepts a jsonl file as input. Each line of the jsonl file is a
31 | dictionary representing one example in the evaluation set. The dictionary
32 | should have two keys:
33 | * input: a list of paths to the input images used as **context to the model**. This list should include the few-shot examples.
34 | * target: a list of paths to the **target images** to evaluate perplexity on.
35 |
36 | Here's an example of the json format:
37 | ```json
38 | {"input": ["path/to/input1.jpg", "path/to/input2.jpg", "path/to/input3.jpg"],
39 |  "target": ["path/to/target1.jpg", "path/to/target2.jpg", "path/to/target3.jpg"]}
40 | ```
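If it helps, here is a small Python sketch (with placeholder paths) for assembling such a jsonl file; the key names match the format above.

```python
import json

# Placeholder records: each entry is one evaluation example, with its few-shot
# context images and the target images to score.
examples = [
    {
        'input': ['task_a/context1.jpg', 'task_a/context2.jpg', 'task_a/query.jpg'],
        'target': ['task_a/target1.jpg'],
    },
    {
        'input': ['task_b/context1.jpg', 'task_b/context2.jpg'],
        'target': ['task_b/target1.jpg', 'task_b/target2.jpg'],
    },
]

with open('perplexity_eval.jsonl', 'w') as f:
    for example in examples:
        f.write(json.dumps(example) + '\n')  # one JSON object per line
```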
41 |
42 | The script runs the model and reports the average perplexity over the
43 | evaluation set.
44 |
45 |
46 |
47 |
48 | ## Running the Batch Generation Evaluation
49 | This repo also contains a script to batch-generate images from the model. You
50 | can run the following command to generate images:
51 |
52 | ```bash
53 | python -m vqlm_demo.batch_generation \
54 | --checkpoint=path/to/checkpoint \
55 | --input_file=path/to/input_jsonl_file \
56 | --input_base_dir=base/path/to/add/to/input/path/in/jsonl \
57 | --output_base_dir=base/path/to/add/to/output/path/in/jsonl \
58 | --n_new_frames=1 \
59 | --n_candidates=4 \
60 | --resize_output='original'
61 | ```
62 |
63 | This script accepts a jsonl file as input. Each line of the jsonl file is a
64 | dictionary representing one example in the evaluation set. The dictionary
65 | should have two keys:
66 | * input: a list of paths to the input images used as **context to the model**. This list should include the few-shot examples.
67 | * output: a string giving the path where the generated image will be saved.
68 |
69 | Here's an example of the json format:
70 | ```json
71 | {"input": ["path/to/input1.jpg", "path/to/input2.jpg", "path/to/input3.jpg"],
72 |  "output": "path/to/output.jpg"}
73 | ```
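Each saved output is a single image grid: candidates are stacked vertically and frames horizontally (see the rearrange in `batch_generation.py`). The sketch below splits such a grid back into individual frames; it assumes the script was run without `--resize_output` and without `--include_input`, so every tile is 256x256 and only generated frames are present.

```python
import numpy as np
import einops
from PIL import Image

# Load a saved grid of shape (n_candidates * 256, n_frames * 256, 3).
grid = np.array(Image.open('path/to/output.jpg'))

# Undo the 'b n s h w c -> b (n h) (s w) c' tiling used when saving.
frames = einops.rearrange(grid, '(n h) (s w) c -> n s h w c', h=256, w=256)

# frames[i, j] is the j-th generated frame of the i-th candidate.
Image.fromarray(frames[0, -1]).save('candidate0_last_frame.png')
```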
--------------------------------------------------------------------------------
/evaluation/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/evaluation/environment.yml:
--------------------------------------------------------------------------------
1 | name: vqlm_demo
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3.10
6 | - pip
7 | - numpy
8 | - scipy
9 | - matplotlib
10 | - seaborn
11 | - jupyter
12 | - tqdm
13 | - pillow
14 | - pip:
15 | - --extra-index-url https://download.pytorch.org/whl/cu118
16 | - transformers==4.34.1
17 | - torch==2.0.1
18 | - einops
19 | - absl-py
20 | - ml_collections
21 | - requests
22 | - mlxu==0.1.11
23 | - pydantic
24 | - fastapi
25 | - uvicorn
26 | - gradio
29 | - opencv-python-headless
30 | - scikit-video
31 | - scikit-image
32 | - natsort
33 |
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/.DS_Store
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/__init__.py
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/batch_generation.py:
--------------------------------------------------------------------------------
1 | """
2 | Batch generation for sequnce of images. This script accept a jsonl file
3 | as input. Each line of the jsonl file representing a dictionary. Each line
4 | represents one example in the evaluation set. The dictionary should have two key:
5 |
6 | input: a list of paths to the input images as context to the model.
7 | output: a string representing the path to the output of generation to be saved.
8 |
9 | Ths script runs the mode to generate the output images, and concatenate the
10 | input and output images together and save them to the output path.
11 | """
12 |
13 | import os
14 | import json
15 | from PIL import Image
16 | import numpy as np
17 | import mlxu
18 | from tqdm import tqdm, trange
19 | from multiprocessing import Pool
20 | import einops
21 | import torch
22 |
23 | from .inference import MultiProcessInferenceModel
24 | from .utils import read_image_to_tensor, MultiProcessImageSaver
25 |
26 |
27 | FLAGS, _ = mlxu.define_flags_with_default(
28 | input_file='',
29 | checkpoint='',
30 | input_base_dir='',
31 | output_base_dir='',
32 | evaluate_mse=False,
33 | json_input_key='input',
34 | json_output_key='output',
35 | json_target_key='target',
36 | n_new_frames=1,
37 | n_candidates=2,
38 | context_frames=16,
39 | temperature=1.0,
40 | top_p=1.0,
41 | n_workers=8,
42 | dtype='float16',
43 | torch_devices='',
44 | batch_size_factor=4,
45 | max_examples=0,
46 | resize_output='',
47 | include_input=False,
48 | )
49 |
50 | # create this according to the json file.
51 | class MultiFrameDataset(torch.utils.data.Dataset):
52 | def __init__(self, input_files, output_files, target_files=None):
53 | assert len(input_files)
54 | self.input_files = input_files
55 | self.output_files = output_files
56 | self.target_files = target_files
57 |
58 | def __len__(self):
59 | return len(self.input_files)
60 |
61 | def __getitem__(self, idx):
62 | original_size = Image.open(self.input_files[idx][-1]).size
63 | input_images = np.stack(
64 | [read_image_to_tensor(f) for f in self.input_files[idx]],
65 | axis=0
66 | )
67 |
68 | if self.target_files is not None:
69 | target_images = np.stack(
70 | [read_image_to_tensor(f) for f in self.target_files[idx]],
71 | axis=0
72 | )
73 | else:
74 | target_images = None
75 | return input_images, target_images, self.output_files[idx], np.array(original_size)
76 |
77 |
78 | def main(_):
79 | assert FLAGS.checkpoint != ''
80 |
81 | print(f'Loading checkpoint from {FLAGS.checkpoint}')
82 | print(f'Evaluating input file from {FLAGS.input_file}')
83 |
84 | # build a model.
85 |
86 | model = MultiProcessInferenceModel(
87 | checkpoint=FLAGS.checkpoint,
88 | torch_devices=FLAGS.torch_devices,
89 | dtype=FLAGS.dtype,
90 | context_frames=FLAGS.context_frames,
91 | use_lock=True,
92 | )
93 |
94 | # input_files: the json file that needs to be generated by the other file.
95 | input_files = []
96 | output_files = []
97 |
98 | if FLAGS.evaluate_mse:
99 | target_files = []
100 | else:
101 | target_files = None
102 |
103 | with mlxu.open_file(FLAGS.input_file, 'r') as f:
104 | for line in f:
105 | record = json.loads(line)
106 | input_files.append(record[FLAGS.json_input_key])
107 | output_files.append(record[FLAGS.json_output_key])
108 | if FLAGS.evaluate_mse:
109 | target_files.append(record[FLAGS.json_target_key])
110 |
111 |
112 | if FLAGS.max_examples > 0:
113 | input_files = input_files[:FLAGS.max_examples]
114 | output_files = output_files[:FLAGS.max_examples]
115 | if FLAGS.evaluate_mse:
116 | target_files = target_files[:FLAGS.max_examples]
117 |
118 | if FLAGS.input_base_dir != '':
119 | input_files = [
120 | [os.path.join(FLAGS.input_base_dir, x) for x in y]
121 | for y in input_files
122 | ]
123 | if FLAGS.evaluate_mse:
124 | target_files = [
125 | [os.path.join(FLAGS.input_base_dir, x) for x in y]
126 | for y in target_files
127 | ]
128 |
129 | if FLAGS.output_base_dir != '':
130 | os.makedirs(FLAGS.output_base_dir, exist_ok=True)
131 | output_files = [
132 | os.path.join(FLAGS.output_base_dir, x)
133 | for x in output_files
134 | ]
135 |
136 | dataset = MultiFrameDataset(input_files, output_files, target_files)
137 |
138 | data_loader = torch.utils.data.DataLoader(
139 | dataset,
140 | batch_size=FLAGS.batch_size_factor * model.n_processes,
141 | shuffle=False,
142 | num_workers=FLAGS.n_workers,
143 | )
144 |
145 | image_saver = MultiProcessImageSaver(FLAGS.n_workers)
146 |
147 | mses = []
148 |
149 | for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
150 |
151 | # batch_images is input.
152 | batch_images = batch_images.numpy()
153 |
154 | #
155 | context_length = batch_images.shape[1]
156 |
157 |
158 | generated_images = model(
159 | batch_images,
160 | FLAGS.n_new_frames,
161 | FLAGS.n_candidates,
162 | temperature=FLAGS.temperature,
163 | top_p=FLAGS.top_p
164 | )
165 |
166 |
167 | repeated_batch = einops.repeat(
168 | batch_images,
169 | 'b s h w c -> b n s h w c',
170 | n=FLAGS.n_candidates,
171 | )
172 | generated_images = np.array(generated_images)
173 |
174 | if FLAGS.evaluate_mse:
175 | batch_targets = einops.repeat(
176 | batch_targets.numpy(),
177 | 'b s h w c -> b n s h w c', # batch, candidate, s
178 | n=FLAGS.n_candidates,
179 | )
180 | channels = batch_targets.shape[-1]
181 | # calculate mse loss.
182 | mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
183 |
184 | mses.append(mse * channels)
185 |
186 |
187 | if FLAGS.include_input:
188 | combined = einops.rearrange(
189 | np.concatenate([repeated_batch, generated_images], axis=2),
190 | 'b n s h w c -> b (n h) (s w) c'
191 | )
192 | else:
193 | combined = einops.rearrange(
194 | generated_images,
195 | 'b n s h w c -> b (n h) (s w) c'
196 | )
197 | combined = (combined * 255).astype(np.uint8)
198 |
199 | n_frames = FLAGS.n_new_frames
200 | if FLAGS.include_input:
201 | n_frames += context_length
202 |
203 | if FLAGS.resize_output == '':
204 | resizes = None
205 |
206 | elif FLAGS.resize_output == 'original':
207 | resizes = batch_sizes.numpy()
208 | resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
209 | else:
210 | resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
211 | resizes = np.array([resize] * len(batch_sizes))
212 | resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
213 |
214 | image_saver(combined, batch_output_files, resizes)
215 |
216 | if FLAGS.evaluate_mse:
217 | mses = np.concatenate(mses, axis=0)
218 | print(f'MSE: {np.mean(mses)}')
219 |
220 | image_saver.close()
221 |
222 | if __name__ == "__main__":
223 | mlxu.run(main)
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/eval_perplexity.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating the perplexity on few shot tasks. This script accept a jsonl file
3 | as input. Each line of the jsonl file representing a dictionary. Each line
4 | represents one example in the evaluation set. The dictionary should have two key:
5 |
6 | input: a list of paths to the input images as context to the model. This
7 | list should include the few shot examples.
8 | target: a list of paths to the target images to evaluate perplexity
9 |
10 | Ths script should run the model and compute the average perplexity on the
11 | evaluation set.
12 | """
13 |
14 | import os
15 | import json
16 | from PIL import Image
17 | import numpy as np
18 | import mlxu
19 | from tqdm import tqdm, trange
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | import einops
24 |
25 | from .inference import MultiProcessInferenceModel
26 |
27 |
28 | FLAGS, _ = mlxu.define_flags_with_default(
29 | input_file='',
30 | checkpoint='',
31 | input_base_dir='',
32 | batch_size=2,
33 | json_input_key='input',
34 | json_target_key='target',
35 | dtype='float16',
36 | torch_devices='',
37 | n_workers=4,
38 | max_examples=0,
39 | )
40 |
41 |
42 | def read_image_to_tensor(path):
43 | pil_im = Image.open(path).convert('RGB')
44 | input_img = pil_im.resize((256, 256))
45 | input_img = np.array(input_img) / 255.0
46 | input_img = input_img.astype(np.float32)
47 | return input_img
48 |
49 |
50 | class MultiFrameDataset(torch.utils.data.Dataset):
51 | def __init__(self, input_files, target_files):
52 | assert len(input_files) == len(target_files)
53 | self.input_files = input_files
54 | self.target_files = target_files
55 |
56 | def __len__(self):
57 | return len(self.input_files)
58 |
59 | def __getitem__(self, idx):
60 | input_list = np.stack(
61 | [read_image_to_tensor(f) for f in self.input_files[idx]],
62 | axis=0
63 | )
64 | target_list = np.stack(
65 | [read_image_to_tensor(f) for f in self.target_files[idx]],
66 | axis=0
67 | )
68 | return input_list, target_list
69 |
70 |
71 | def main(_):
72 | assert FLAGS.checkpoint != ''
73 |
74 | print(f'Loading checkpoint from {FLAGS.checkpoint}')
75 | print(f'Evaluating input file from {FLAGS.input_file}')
76 |
77 | model = MultiProcessInferenceModel(
78 | checkpoint=FLAGS.checkpoint,
79 | torch_devices=FLAGS.torch_devices,
80 | dtype=FLAGS.dtype,
81 | use_lock=True,
82 | perplexity_batch_size=FLAGS.batch_size,
83 | )
84 |
85 | input_files = []
86 | target_files = []
87 |
88 | with mlxu.open_file(FLAGS.input_file, 'r') as f:
89 | for line in f:
90 | record = json.loads(line)
91 | input_files.append(record[FLAGS.json_input_key])
92 | target_files.append(record[FLAGS.json_target_key])
93 |
94 | if FLAGS.input_base_dir != '':
95 | input_files = [
96 | [os.path.join(FLAGS.input_base_dir, x) for x in y]
97 | for y in input_files
98 | ]
99 | target_files = [
100 | [os.path.join(FLAGS.input_base_dir, x) for x in y]
101 | for y in target_files
102 | ]
103 |
104 | if FLAGS.max_examples > 0:
105 | input_files = input_files[:FLAGS.max_examples]
106 | target_files = target_files[:FLAGS.max_examples]
107 |
108 | dataset = MultiFrameDataset(input_files, target_files)
109 | data_loader = torch.utils.data.DataLoader(
110 | dataset,
111 | batch_size=FLAGS.batch_size * model.n_processes,
112 | shuffle=False,
113 | num_workers=FLAGS.n_workers
114 | )
115 |
116 | perplexities = []
117 |
118 | for input_images, target_images in tqdm(data_loader, ncols=0):
119 | perplexity = model.compute_perplexity(input_images, target_images)
120 | perplexities.append(perplexity)
121 |
122 | perplexities = np.concatenate(perplexities, axis=0)
123 | print(f'Perplexity: {np.mean(perplexities)}')
124 |
125 |
126 | if __name__ == "__main__":
127 | mlxu.run(main)
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/eval_video_perplexity.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | from functools import partial
5 | from tqdm import tqdm, trange
6 | from multiprocessing import Pool
7 | from PIL import Image
8 | import cv2
9 | import mlxu
10 | from natsort import natsorted
11 | import numpy as np
12 | import einops
13 | import torch
14 |
15 | from vqlm_demo.inference import MultiProcessInferenceModel
16 | from vqlm_demo.utils import (
17 | is_video, random_square_crop,
18 | read_frames_from_dir, read_frames_from_video
19 | )
20 |
21 |
22 | FLAGS, _ = mlxu.define_flags_with_default(
23 | checkpoint='',
24 | input_files='',
25 | frame_input=False,
26 | read_file_list='',
27 | center_crop=1.0,
28 | n_context_frames=15,
29 | n_target_frames=1,
30 | n_workers=8,
31 | stride=8,
32 | batch_size=2,
33 | torch_devices='',
34 | shuffle=False,
35 | random_start=True,
36 | max_examples=0,
37 | )
38 |
39 |
40 | class VideoDataset(torch.utils.data.Dataset):
41 |
42 | def __init__(self, videos, frame_input=False, n_context_frames=15,
43 | n_target_frames=1, stride=1):
44 | self.videos = videos
45 | self.frame_input = frame_input
46 | self.n_context_frames = n_context_frames
47 | self.n_target_frames = n_target_frames
48 | self.stride = stride
49 |
50 | def __getitem__(self, index):
51 | if self.frame_input:
52 | frames = read_frames_from_dir(
53 | self.videos[index],
54 | self.n_context_frames + self.n_target_frames,
55 | self.stride,
56 | center_crop=FLAGS.center_crop,
57 | random_start=FLAGS.random_start,
58 | )
59 | else:
60 | frames = read_frames_from_video(
61 | self.videos[index],
62 | self.n_context_frames + self.n_target_frames,
63 | self.stride,
64 | center_crop=FLAGS.center_crop,
65 | random_start=FLAGS.random_start,
66 | )
67 | if frames is None:
68 | return self[np.random.randint(0, len(self))]
69 | return frames[:self.n_context_frames], frames[self.n_context_frames:]
70 |
71 | def __len__(self):
72 | return len(self.videos)
73 |
74 |
75 |
76 | def main(_):
77 | assert FLAGS.checkpoint != ''
78 | assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
79 |
80 | model = MultiProcessInferenceModel(
81 | checkpoint=FLAGS.checkpoint,
82 | torch_devices=FLAGS.torch_devices,
83 | perplexity_batch_size=FLAGS.batch_size,
84 | )
85 |
86 | if FLAGS.read_file_list != '':
87 | with open(FLAGS.read_file_list, 'r') as f:
88 | videos = [x.strip() for x in f.readlines()]
89 | else:
90 | videos = glob.glob(FLAGS.input_files)
91 |
92 | if FLAGS.frame_input:
93 | videos = [x for x in videos if os.path.isdir(x)]
94 | else:
95 | videos = [x for x in videos if is_video(x)]
96 |
97 | if FLAGS.shuffle:
98 | np.random.shuffle(videos)
99 |
100 | if FLAGS.max_examples > 0:
101 | videos = videos[:FLAGS.max_examples]
102 |
103 | dataset = VideoDataset(
104 | videos,
105 | frame_input=FLAGS.frame_input,
106 | n_context_frames=FLAGS.n_context_frames,
107 | n_target_frames=FLAGS.n_target_frames,
108 | stride=FLAGS.stride
109 | )
110 | dataloader = torch.utils.data.DataLoader(
111 | dataset,
112 | batch_size=FLAGS.batch_size * model.n_processes * 4,
113 | shuffle=False,
114 | num_workers=FLAGS.n_workers,
115 | prefetch_factor=4,
116 | drop_last=True,
117 | )
118 |
119 | perplexities = []
120 |
121 |     for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
122 |         batch_context_frames = batch_context_frames.numpy()
123 |         batch_target_frames = batch_target_frames.numpy()
124 |         perplexity = model.compute_perplexity(
125 |             batch_context_frames, batch_target_frames
126 | )
127 | perplexities.append(perplexity)
128 |
129 | perplexities = np.concatenate(perplexities, axis=0)
130 | print(f'Perplexity: {np.mean(perplexities)}')
131 |
132 |
133 | if __name__ == '__main__':
134 | mlxu.run(main)
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/eval_videos.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from functools import partial
4 | from tqdm import tqdm, trange
5 | from multiprocessing import Pool
6 | from PIL import Image
7 | import cv2
8 | import mlxu
9 | from natsort import natsorted
10 | import numpy as np
11 | import einops
12 | import torch
13 |
14 | from vqlm_demo.inference import MultiProcessInferenceModel
15 | from vqlm_demo.utils import (
16 | is_video, random_square_crop,
17 | read_frames_from_dir, read_frames_from_video
18 | )
19 |
20 |
21 | FLAGS, _ = mlxu.define_flags_with_default(
22 | checkpoint='',
23 | input_files='',
24 | frame_input=False,
25 | read_file_list='',
26 | output_dir='',
27 | center_crop=1.0,
28 | n_context_frames=12,
29 | n_new_frames=4,
30 | n_candidates=8,
31 | temperature=1.0,
32 | top_p=1.0,
33 | n_workers=8,
34 | stride=8,
35 | batch_size=32,
36 | torch_devices='',
37 | shuffle=False,
38 | max_examples=0,
39 | )
40 |
41 |
42 | def save_image(args):
43 | image, filename = args
44 | base = FLAGS.input_files.split('*')[0]
45 | filename = filename[len(base):].replace('/', '_') + '.png'
46 | Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
47 |
48 |
49 | class VideoDataset(torch.utils.data.Dataset):
50 |
51 |     def __init__(self, videos, frame_input=False, n_frames=8, stride=1, new_frames=1):
52 | self.videos = videos
53 | self.frame_input = frame_input
54 | self.n_frames = n_frames
55 | self.stride = stride
56 | self.new_frames = new_frames
57 |
58 | def __getitem__(self, index):
59 | if self.frame_input:
60 | frames = read_frames_from_dir(
61 | self.videos[index], self.n_frames, self.stride,
62 | center_crop=FLAGS.center_crop,
63 | )
64 |
65 | else:
66 | # 's h w c'
67 | frames = read_frames_from_video(
68 | self.videos[index], self.n_frames, self.stride,
69 | center_crop=FLAGS.center_crop,
70 | )
71 |         if frames is None:
72 |             return self[np.random.randint(0, len(self))]
73 |
74 |         # The last `new_frames` frames of the clip are the prediction targets.
75 |         target_frames = frames[self.n_frames - self.new_frames:self.n_frames, :, :, :]
76 |
77 |         return frames, target_frames, self.videos[index]
78 |
79 | def __len__(self):
80 | return len(self.videos)
81 |
82 |
83 |
84 | def main(_):
85 | assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
86 | assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
87 | os.makedirs(FLAGS.output_dir, exist_ok=True)
88 |
89 | if FLAGS.read_file_list != '':
90 | with open(FLAGS.read_file_list, 'r') as f:
91 | videos = [x.strip() for x in f.readlines()]
92 | else:
93 | videos = glob.glob(FLAGS.input_files)
94 |
95 | if FLAGS.frame_input:
96 | videos = [x for x in videos if os.path.isdir(x)]
97 | else:
98 | videos = [x for x in videos if is_video(x)]
99 |
100 | if FLAGS.shuffle:
101 | np.random.shuffle(videos)
102 |
103 | if FLAGS.max_examples > 0:
104 | videos = videos[:FLAGS.max_examples]
105 |
106 | dataset = VideoDataset(
107 | videos,
108 | frame_input=FLAGS.frame_input,
109 | n_frames=FLAGS.n_context_frames,
110 | stride=FLAGS.stride
111 | )
112 | dataloader = torch.utils.data.DataLoader(
113 | dataset,
114 | batch_size=FLAGS.batch_size,
115 | shuffle=False,
116 | num_workers=FLAGS.n_workers,
117 | prefetch_factor=4,
118 | drop_last=True,
119 | )
120 |
121 | if FLAGS.torch_devices == '':
122 | torch_devices = None
123 | else:
124 | torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
125 |
126 | model = MultiProcessInferenceModel(
127 | checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
128 | )
129 |
130 | save_img_pool = Pool(FLAGS.n_workers)
131 |
132 |
133 |
134 |
135 | for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):
136 |
137 | batch = batch.numpy() # 'b s h w c '
138 |
139 |
140 |
141 | generated = model(
142 | batch,
143 | n_new_frames=FLAGS.n_new_frames,
144 | n_candidates=FLAGS.n_candidates,
145 | temperature=FLAGS.temperature,
146 | top_p=FLAGS.top_p,
147 | )
148 |
149 |
150 | generated = np.array(generated)
151 |
152 | batch_targets = einops.repeat(
153 | batch_targets.numpy(),
154 | 'b s h w c -> b n s h w c', # batch, candidate, sequence, h, w, c.
155 | n=FLAGS.n_candidates,
156 | )
157 |
158 |
159 | if __name__ == '__main__':
160 | mlxu.run(main)
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/generate_videos.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | from functools import partial
5 | from tqdm import tqdm, trange
6 | from multiprocessing import Pool
7 | from PIL import Image
8 | import cv2
9 | import mlxu
10 | from natsort import natsorted
11 | import numpy as np
12 | import einops
13 | import torch
14 |
15 | from vqlm_demo.inference import MultiProcessInferenceModel
16 | from vqlm_demo.utils import (
17 | is_video, random_square_crop,
18 | read_frames_from_dir, read_frames_from_video
19 | )
20 |
21 |
22 | FLAGS, _ = mlxu.define_flags_with_default(
23 | checkpoint='',
24 | input_files='',
25 | frame_input=False,
26 | read_file_list='',
27 | output_dir='',
28 | center_crop=1.0,
29 | n_context_frames=12,
30 | n_new_frames=4,
31 | n_candidates=8,
32 | temperature=1.0,
33 | top_p=1.0,
34 | n_workers=8,
35 | stride=8,
36 | batch_size=32,
37 | torch_devices='',
38 | shuffle=False,
39 | max_examples=0,
40 | )
41 |
42 |
43 | def save_image(args):
44 | image, filename = args
45 | base = FLAGS.input_files.split('*')[0]
46 | filename = filename[len(base):].replace('/', '_') + '.png'
47 | Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
48 |
49 |
50 | class VideoDataset(torch.utils.data.Dataset):
51 |
52 | def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
53 | self.videos = videos
54 | self.frame_input = frame_input
55 | self.n_frames = n_frames
56 | self.stride = stride
57 |
58 | def __getitem__(self, index):
59 | if self.frame_input:
60 | frames = read_frames_from_dir(
61 | self.videos[index], self.n_frames, self.stride,
62 | center_crop=FLAGS.center_crop,
63 | )
64 | else:
65 | frames = read_frames_from_video(
66 | self.videos[index], self.n_frames, self.stride,
67 | center_crop=FLAGS.center_crop,
68 | )
69 | if frames is None:
70 | return self[np.random.randint(0, len(self))]
71 | return frames, self.videos[index]
72 |
73 | def __len__(self):
74 | return len(self.videos)
75 |
76 |
77 |
78 | def main(_):
79 | assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
80 | assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
81 | os.makedirs(FLAGS.output_dir, exist_ok=True)
82 |
83 | if FLAGS.read_file_list != '':
84 | with open(FLAGS.read_file_list, 'r') as f:
85 | videos = [x.strip() for x in f.readlines()]
86 | else:
87 | videos = glob.glob(FLAGS.input_files)
88 |
89 | if FLAGS.frame_input:
90 | videos = [x for x in videos if os.path.isdir(x)]
91 | else:
92 | videos = [x for x in videos if is_video(x)]
93 |
94 | if FLAGS.shuffle:
95 | np.random.shuffle(videos)
96 |
97 | if FLAGS.max_examples > 0:
98 | videos = videos[:FLAGS.max_examples]
99 |
100 | dataset = VideoDataset(
101 | videos,
102 | frame_input=FLAGS.frame_input,
103 | n_frames=FLAGS.n_context_frames,
104 | stride=FLAGS.stride
105 | )
106 | dataloader = torch.utils.data.DataLoader(
107 | dataset,
108 | batch_size=FLAGS.batch_size,
109 | shuffle=False,
110 | num_workers=FLAGS.n_workers,
111 | prefetch_factor=4,
112 | drop_last=True,
113 | )
114 |
115 | if FLAGS.torch_devices == '':
116 | torch_devices = None
117 | else:
118 | torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
119 |
120 | model = MultiProcessInferenceModel(
121 | checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
122 | )
123 |
124 | save_img_pool = Pool(FLAGS.n_workers)
125 |
126 |
127 |
128 | for batch, filenames in tqdm(dataloader, ncols=0):
129 |
130 |
131 |
132 | batch = batch.numpy()
133 |
134 |
135 |
136 | generated = model(
137 | batch,
138 | n_new_frames=FLAGS.n_new_frames,
139 | n_candidates=FLAGS.n_candidates,
140 | temperature=FLAGS.temperature,
141 | top_p=FLAGS.top_p,
142 | )
143 |
144 |
145 | generated = np.array(generated)
146 |
147 |
148 |
149 |
150 | output_batch = einops.repeat(
151 | batch,
152 | 'b s h w c -> b n s h w c',
153 | n=FLAGS.n_candidates,
154 | )
155 |
156 |
157 | combined = einops.rearrange(
158 | np.concatenate([output_batch, generated], axis=2),
159 | 'b n s h w c -> b (n h) (s w) c'
160 | )
161 |
162 |
163 | combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
164 |         save_img_pool.imap(save_image, zip(combined, filenames))
165 |     save_img_pool.close()
166 |     save_img_pool.join()
167 | if __name__ == '__main__':
168 | mlxu.run(main)
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/inference.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from contextlib import nullcontext
3 | import time
4 | import os
5 | from functools import partial
6 | from copy import deepcopy
7 | from multiprocessing import Pool
8 | from threading import Lock
9 | from PIL import Image
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import einops
14 | from transformers import LlamaForCausalLM
15 |
16 | from .vqvae_muse import VQGANModel, get_tokenizer_muse
17 | from .torch_vqvae_model import get_tokenizer
18 |
19 |
20 | def get_torch_float_dtype(dtype):
21 | if dtype in (torch.float16, torch.bfloat16, torch.float32):
22 | return dtype
23 | return {
24 | 'float16': torch.float16,
25 | 'fp16': torch.float16,
26 | 'f16': torch.float16,
27 | 'bfloat16': torch.bfloat16,
28 | 'bf16': torch.bfloat16,
29 | 'float32': torch.float32,
30 | 'fp32': torch.float32,
31 | 'f32': torch.float32,
32 | }[dtype]
33 |
34 |
35 | def get_pid():
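   |     # the 1s sleep keeps a single fast worker from grabbing several of these probe tasks,
   |     # so the calls spread across the pool and each worker reports its own PID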
36 | time.sleep(1)
37 | return os.getpid()
38 |
39 |
40 | class InferenceModel(ABC):
41 |
42 | @abstractmethod
43 |     def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
44 | raise NotImplementedError()
45 |
46 |
47 | class LocalInferenceModel(InferenceModel):
48 |
49 | def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
50 | context_frames=16, use_lock=False):
51 | self.checkpoint = checkpoint
52 | self.dtype = dtype
53 | self.torch_device = torch_device
54 | self.context_frames = context_frames
55 |
56 | # old version of the tokenizer
57 | # self.tokenizer = get_tokenizer()
58 | # self.tokenizer.to(self.torch_device)
59 |
60 | # new tokenizer
61 | self.tokenizer = get_tokenizer_muse()
62 | self.tokenizer.to(self.torch_device)
63 |
64 | self.model = LlamaForCausalLM.from_pretrained(
65 | self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
66 | ).to(self.torch_device)
67 |
68 | if use_lock:
69 | self.lock = Lock()
70 | else:
71 | self.lock = nullcontext()
72 |
73 | @torch.no_grad()
74 | def compute_perplexity(self, input_images, target_images):
75 | input_images = np.array(input_images)
76 | target_images = np.array(target_images)
77 | assert len(input_images.shape) == 5 and len(target_images.shape) == 5 # [B, S, H, W, C]
78 | assert input_images.shape[0] == target_images.shape[0]
79 | batch_size = input_images.shape[0]
80 | with self.lock:
81 | input_images = torch.tensor(
82 | einops.rearrange(input_images, 'b s h w c -> b s c h w')
83 | ).to(self.torch_device)
84 | target_images = torch.tensor(
85 | einops.rearrange(target_images, 'b s h w c -> b s c h w')
86 | ).to(self.torch_device)
87 | input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
88 | target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
89 | all_ids = torch.cat([input_ids, target_ids], dim=1)
90 | logits = self.model(all_ids).logits
91 | log_probs = F.log_softmax(logits, dim=-1)
92 | target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
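   |             # logits at position t predict token t + 1, so the log-probs that score the target tokens
   |             # start at index len(input_ids) - 1 and the final position is dropped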
93 | target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
94 | perplexity = torch.exp(
95 | -torch.mean(
96 | torch.sum(target_log_probs * target_ids_onehot, dim=-1),
97 | dim=-1
98 | )
99 | )
100 | return perplexity.detach().cpu().numpy()
101 |
102 | @torch.no_grad()
103 | def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
104 | assert type(input_images) == np.ndarray
105 | with self.lock:
106 | input_images = np.array(input_images, dtype=np.float32)
107 | input_images = torch.tensor(
108 | einops.rearrange(input_images, 'b h w c -> b c h w')
109 | ).to(self.torch_device)
110 | 
113 | # old tokenizer
114 | # input_ids = self.tokenizer.tokenize(input_images).view(1, -1)
115 |
116 | # new tokenizer
117 | _, input_ids = self.tokenizer.encode(input_images)
118 | input_ids = input_ids.view(1, -1)
119 |
120 |
121 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
122 |
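   |             # The prompt was truncated above to at most (context_frames - 1) frames of 256 tokens each.
   |             # Generation proceeds in chunks: first fill the remaining context window, then slide the
   |             # window and emit one 256-token frame at a time. Token ids >= 8192 are suppressed so the
   |             # model can only produce VQ image codes.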
123 | new_tokens = []
124 | current_context_frames = input_ids.shape[1] // 256
125 |             first_generation_left = self.context_frames - current_context_frames
126 |             first_new_frames = min(first_generation_left, n_new_frames)
127 | input_ids = self.model.generate(
128 | input_ids=input_ids,
129 | attention_mask=torch.ones_like(input_ids),
130 | pad_token_id=8192,
131 | max_new_tokens=256 * first_new_frames,
132 | do_sample=True,
133 | top_p=top_p,
134 | temperature=temperature,
135 | suppress_tokens=list(range(8192, self.model.vocab_size)),
136 | )
137 | new_tokens.append(input_ids[:, -256 * first_new_frames:])
138 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
139 |
140 | for _ in range(max(0, n_new_frames - first_new_frames)):
141 | input_ids = self.model.generate(
142 | input_ids=input_ids,
143 | attention_mask=torch.ones_like(input_ids),
144 | pad_token_id=8192,
145 | max_new_tokens=256,
146 | do_sample=True,
147 | top_p=top_p,
148 | temperature=temperature,
149 | suppress_tokens=list(range(8192, self.model.vocab_size)),
150 | )
151 | new_tokens.append(input_ids[:, -256:])
152 | input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
153 |
154 | new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
155 | new_images = einops.rearrange(
156 | torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
157 | 'b c h w -> b h w c'
158 | ).detach().cpu().numpy()
159 | return new_images
160 |
161 | def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
162 | output = []
163 | for seq in input_images:
164 | output.append(
165 | [self.generate_once(seq, n_new_frames, temperature, top_p)
166 | for _ in range(n_candidates)]
167 | )
168 | return output
169 |
170 |
171 | class MultiProcessInferenceModel(InferenceModel):
172 |
173 | def __init__(self, checkpoint, torch_devices=None, dtype='float16',
174 | context_frames=16, use_lock=False, perplexity_batch_size=2):
175 | if torch_devices is None or torch_devices == '':
176 | torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
177 |
178 | self.torch_devices = torch_devices
179 | self.n_processes = len(torch_devices)
180 | print(f'Using {self.n_processes} processes for inference')
181 | self.worker_pool = Pool(self.n_processes)
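   |         # probe each worker's PID so every worker process can be pinned to its own CUDA device below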
182 | self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
183 | self.device_map = {
184 | pid: torch_device
185 | for pid, torch_device in zip(self.worker_pids, self.torch_devices)
186 | }
187 | self.worker_pool.starmap(
188 | self.initialize_worker,
189 | [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
190 | )
191 | self.perplexity_batch_size = perplexity_batch_size
192 | if use_lock:
193 | self.lock = Lock()
194 | else:
195 | self.lock = nullcontext()
196 |
197 | @staticmethod
198 | def initialize_worker(device_map, checkpoint, dtype, context_frames):
199 | global _current_process_backend
200 | torch_device = device_map[os.getpid()]
201 | _current_process_backend = LocalInferenceModel(
202 | checkpoint, dtype, torch_device, context_frames
203 | )
204 |
205 | @staticmethod
206 | def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
207 | return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)
208 |
209 | @staticmethod
210 | def compute_perplexity_once(input_images, target_images):
211 | return _current_process_backend.compute_perplexity(input_images, target_images)
212 |
213 | def compute_perplexity(self, input_images, target_images):
214 | with self.lock:
215 | map_args = []
216 | for i in range(0, len(input_images), self.perplexity_batch_size):
217 | map_args.append((
218 | input_images[i : i + self.perplexity_batch_size],
219 | target_images[i : i + self.perplexity_batch_size]
220 | ))
221 | outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
222 | return np.concatenate(outputs, axis=0)
223 |
224 | def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
225 | with self.lock:
226 | map_args = []
227 | for seq in input_images:
228 | for _ in range(n_candidates):
229 | map_args.append((seq, n_new_frames, temperature, top_p))
230 |
231 | outputs = self.worker_pool.starmap(self.generate_once, map_args)
232 | reshaped_output = []
233 | index = 0
234 | for _ in range(len(input_images)):
235 | candidates = []
236 | for _ in range(n_candidates):
237 | candidates.append(outputs[index])
238 | index += 1
239 | reshaped_output.append(candidates)
240 | return reshaped_output
241 |
242 |
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/torch_vqvae_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import einops
6 | from einops.layers.torch import Rearrange
7 |
8 |
9 | def normalize(in_channels):
10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
11 |
12 | def swish(x):
13 | return x*torch.sigmoid(x)
14 |
15 | class ResBlock(nn.Module):
16 | def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
17 | super(ResBlock, self).__init__()
18 | self.in_channels = in_channels
19 | self.out_channels = in_channels if out_channels is None else out_channels
20 | self.norm1 = normalize(in_channels)
21 |         self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
22 |         self.norm2 = normalize(self.out_channels)
23 |         self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
24 | if self.in_channels != self.out_channels:
25 | self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
26 | self.activation_fn = activation_fn
27 | if activation_fn=="relu":
28 | self.actn = nn.ReLU()
29 |
30 |
31 | def forward(self, x_in):
32 | x = x_in
33 | x = self.norm1(x)
34 | if self.activation_fn=="relu":
35 | x = self.actn(x)
36 | elif self.activation_fn=="swish":
37 | x = swish(x)
38 | x = self.conv1(x)
39 | x = self.norm2(x)
40 | if self.activation_fn=="relu":
41 | x = self.actn(x)
42 | elif self.activation_fn=="swish":
43 | x = swish(x)
44 | x = self.conv2(x)
45 | if self.in_channels != self.out_channels:
46 | x_in = self.conv_out(x_in)
47 |
48 | return x + x_in
49 |
50 | class Encoder(nn.Module):
51 | def __init__(self, ):
52 | super().__init__()
53 |
54 | self.filters = 128
55 | self.num_res_blocks = 2
56 | self.ch_mult = [1,1,2,2,4]
57 | self.in_ch_mult = (1,)+tuple(self.ch_mult)
58 | self.embedding_dim = 32
59 | self.conv_downsample = False
60 |
61 | self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
62 | blocks = []
63 | for i in range(len(self.ch_mult)):
64 | block_in_ch = self.filters * self.in_ch_mult[i]
65 | block_out_ch = self.filters * self.ch_mult[i]
66 | for _ in range(self.num_res_blocks):
67 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
68 | block_in_ch = block_out_ch
69 | for _ in range(self.num_res_blocks):
70 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
71 | self.norm1 = normalize(block_in_ch)
72 | self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
73 | self.blocks = nn.ModuleList(blocks)
74 |
75 | def forward(self, x):
76 | x = self.conv1(x)
77 | for i in range(len(self.ch_mult)):
78 | for j in range(self.num_res_blocks):
79 | x = self.blocks[i*2+j](x)
80 |
81 | if i < len(self.ch_mult) -1:
82 | x = torch.nn.functional.avg_pool2d(x, (2,2),(2,2))
83 |
84 | x = self.blocks[-2](x)
85 | x = self.blocks[-1](x)
86 |
87 | x = self.norm1(x)
88 | x = swish(x)
89 | x = self.conv2(x)
90 | return x
91 |
92 | class VectorQuantizer(nn.Module):
93 | def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
94 | super(VectorQuantizer, self).__init__()
95 | self.codebook_size = codebook_size # number of embeddings
96 | self.emb_dim = emb_dim # dimension of embedding
97 | self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
98 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
99 | self.beta=0.0
100 | self.z_dim = emb_dim
101 |
102 | def forward(self, z):
103 | # preprocess
104 |
105 | b, c, h, w = z.size()
106 | flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
107 | codebook = self.embedding.weight
108 | with torch.no_grad():
109 | tokens = torch.cdist(flatten, codebook).argmin(dim=1)
110 | quantized = F.embedding(tokens,
111 | codebook).view(b, h, w, c).permute(0, 3, 1, 2)
112 |
113 | # compute loss
114 | codebook_loss = F.mse_loss(quantized, z.detach())
115 | commitment_loss = F.mse_loss(quantized.detach(), z)
116 | loss = codebook_loss + self.beta * commitment_loss
117 |
118 | # perplexity
119 | counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
120 | # dist.all_reduce(counts)
121 | p = counts / counts.sum()
122 | perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
123 |
124 | # postprocess
125 | tokens = tokens.view(b, h, w)
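   |         # straight-through estimator: the forward pass uses the quantized values while gradients
   |         # flow back to the encoder output z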
126 | quantized = z + (quantized - z).detach()
127 |
128 | # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c))
129 |
130 | return quantized, tokens, loss, perplexity
131 |
132 |
133 | def get_codebook_feat(self, indices, shape=None):
134 | # input indices: batch*token_num -> (batch*token_num)*1
135 | # shape: batch, height, width, channel
136 | indices = indices.view(-1,1)
137 | min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
138 | min_encodings.scatter_(1, indices, 1)
139 | # get quantized latent vectors
140 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
141 |
142 | if shape is not None: # reshape back to match original input shape
143 | z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
144 |
145 | return z_q
146 |
147 |
148 | class Decoder(nn.Module):
149 | def __init__(self,):
150 | super().__init__()
151 | self.filters = 128
152 | self.num_res_blocks = 2
153 | self.ch_mult = [1,1,2,2,4]
154 | self.in_ch_mult = (1,)+tuple(self.ch_mult)
155 | self.embedding_dim =32
156 | self.out_channels = 3
157 | self.in_channels = self.embedding_dim
158 | self.conv_downsample = False
159 |
160 | self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1)
161 | blocks = []
162 | block_in_ch = self.filters * self.ch_mult[-1]
163 | block_out_ch = self.filters * self.ch_mult[-1]
164 | #blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
165 | for _ in range(self.num_res_blocks):
166 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
167 | upsample_conv_layers = []
168 | for i in reversed(range(len(self.ch_mult))):
169 | block_out_ch = self.filters * self.ch_mult[i]
170 | for _ in range(self.num_res_blocks):
171 | blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
172 | block_in_ch = block_out_ch
173 | if i > 0:
174 | upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch*4, kernel_size=3, stride=1, padding=1))
175 |
176 | self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
177 | self.norm1 = normalize(block_in_ch)
178 | # self.act_fn
179 | self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
180 | self.blocks = nn.ModuleList(blocks)
181 | self.up_convs = nn.ModuleList(upsample_conv_layers)
182 |
183 | def forward(self, x):
184 | x = self.conv1(x)
185 | x = self.blocks[0](x)
186 | x = self.blocks[1](x)
187 | for i in range(len(self.ch_mult)):
188 | for j in range(self.num_res_blocks):
189 | x = self.blocks[2+i*2+j](x)
190 | if i < len(self.ch_mult)-1:
191 | x = self.up_convs[i](x)
192 | #print("pre: x.size()",x.size())
193 | x = x.permute(0,2,3,1)
194 | x = self.upsample(x)
195 | x = x.permute(0,3,1,2)
196 | #print("post: x.size()", x.size())
197 | x = self.norm1(x)
198 | x = swish(x)
199 | x = self.conv6(x)
200 | return x
201 |
202 |
203 | class VQVAE(nn.Module):
204 | def __init__(self, ):
205 | super().__init__()
206 | self.encoder = Encoder()
207 | self.quantizer = VectorQuantizer()
208 | self.decoder = Decoder()
209 |
210 | def forward(self, x):
211 | x = self.encoder(x)
212 | quant,tokens, loss, perplexity = self.quantizer(x)
213 | x = self.decoder(quant)
214 | return x
215 |
216 | def tokenize(self, x):
217 | batch_shape = x.shape[:-3]
218 | x = x.reshape(-1, *x.shape[-3:])
219 | x = self.encoder(x)
220 | quant,tokens, loss, perplexity = self.quantizer(x)
221 | return tokens.reshape(*batch_shape, *tokens.shape[1:])
222 |
223 | def decode(self, tokens):
224 | tokens = einops.rearrange(tokens, 'b ... -> b (...)')
225 | b = tokens.shape[0]
226 | if tokens.shape[-1] == 256:
227 | hw = 16
228 | elif tokens.shape[-1] == 224:
229 | hw = 14
230 | else:
231 | raise ValueError("Invalid tokens shape")
232 | quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32))
233 | x = self.decoder(quant)
234 | return x
235 |
236 |
237 | class VAEDecoder(nn.Module):
238 | def __init__(self, ):
239 | super().__init__()
240 | self.quantizer = VectorQuantizer()
241 | self.decoder = Decoder()
242 |
243 | def forward(self, x):
244 | quant = self.quantizer.get_codebook_feat(x,(1,14,14,32))
245 | x = self.decoder(quant)
246 | return x
247 |
248 |
249 | def get_tokenizer():
250 | checkpoint_path = os.path.join(
251 | os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
252 | )
253 | torch_state_dict = torch.load(checkpoint_path)
254 | net = VQVAE()
255 | net.load_state_dict(torch_state_dict)
256 | return net
257 |
258 |
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from multiprocessing import Pool
3 | import numpy as np
4 | import random
5 | from PIL import Image
6 | import re
7 | import cv2
8 | import glob
9 | from natsort import natsorted
10 |
11 |
12 | class MultiProcessImageSaver(object):
13 |
14 | def __init__(self, n_workers=1):
15 | self.pool = Pool(n_workers)
16 |
17 | def __call__(self, images, output_files, resizes=None):
18 | if resizes is None:
19 | resizes = [None for _ in range(len(images))]
20 | return self.pool.imap(
21 | self.save_image,
22 | zip(images, output_files, resizes),
23 | )
24 |
25 | def close(self):
26 | self.pool.close()
27 | self.pool.join()
28 |
29 | @staticmethod
30 | def save_image(args):
31 | image, filename, resize = args
32 | image = Image.fromarray(image)
33 | if resize is not None:
34 | image = image.resize(tuple(resize))
35 | image.save(filename)
36 |
37 |
38 | def list_dir_with_full_path(path):
39 | return [os.path.join(path, f) for f in os.listdir(path)]
40 |
41 |
42 | def find_all_files_in_dir(path):
43 |     all_files = []
44 |     for root, _, files in os.walk(path):
45 |         for file in files:
46 |             all_files.append(os.path.join(root, file))
47 |     return all_files
48 |
49 |
50 | def is_image(path):
51 | return (
52 | path.endswith('.jpg')
53 | or path.endswith('.png')
54 | or path.endswith('.jpeg')
55 | or path.endswith('.JPG')
56 | or path.endswith('.PNG')
57 | or path.endswith('.JPEG')
58 | )
59 |
60 |
61 | def is_video(path):
62 | return (
63 | path.endswith('.mp4')
64 | or path.endswith('.avi')
65 | or path.endswith('.MP4')
66 | or path.endswith('.AVI')
67 | or path.endswith('.webm')
68 | or path.endswith('.WEBM')
69 | or path.endswith('.mkv')
70 |         or path.endswith('.MKV')
71 | )
72 |
73 |
74 | def random_square_crop(img, random_generator=None):
75 | # If no random generator is provided, use numpy's default
76 | if random_generator is None:
77 | random_generator = np.random.default_rng()
78 |
79 | # Get the width and height of the image
80 | width, height = img.size
81 |
82 | # Determine the shorter side
83 | min_size = min(width, height)
84 |
85 | # Randomly determine the starting x and y coordinates for the crop
86 | if width > height:
87 | left = random_generator.integers(0, width - min_size)
88 | upper = 0
89 | else:
90 | left = 0
91 | upper = random_generator.integers(0, height - min_size)
92 |
93 | # Calculate the ending x and y coordinates for the crop
94 | right = left + min_size
95 | lower = upper + min_size
96 |
97 | # Crop the image
98 | return img.crop((left, upper, right, lower))
99 |
100 |
101 | def read_image_to_tensor(path, center_crop=1.0):
102 | pil_im = Image.open(path).convert('RGB')
103 | if center_crop < 1.0:
104 | width, height = pil_im.size
105 |         pil_im = pil_im.crop((  # PIL crop takes (left, upper, right, lower)
106 |             int((1 - center_crop) * width / 2), int((1 - center_crop) * height / 2),
107 |             int((1 + center_crop) * width / 2), int((1 + center_crop) * height / 2),
108 |         ))
109 | input_img = pil_im.resize((256, 256))
110 | input_img = np.array(input_img) / 255.0
111 | input_img = input_img.astype(np.float32)
112 | return input_img
113 |
114 |
115 | def match_mulitple_path(root_dir, regex):
116 | videos = []
117 | for root, _, files in os.walk(root_dir):
118 | for file in files:
119 | videos.append(os.path.join(root, file))
120 |
121 | videos = [v for v in videos if not v.split('/')[-1].startswith('.')]
122 |
123 | grouped_path = {}
124 | for r in regex:
125 | r = re.compile(r)
126 | for v in videos:
127 | matched = r.findall(v)
128 | if len(matched) > 0:
129 | groups = matched[0]
130 | if groups not in grouped_path:
131 | grouped_path[groups] = []
132 | grouped_path[groups].append(v)
133 |
134 | grouped_path = {
135 | k: tuple(v) for k, v in grouped_path.items()
136 | if len(v) == len(regex)
137 | }
138 | return list(grouped_path.values())
139 |
140 |
141 | def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True):
142 | assert length >= n_frames
143 | max_stride = min(
144 | (length - 1) // (n_frames - 1),
145 | max_stride
146 | )
147 | stride = np.random.randint(1, max_stride + 1)
148 | if random_start:
149 | start = np.random.randint(0, length - (n_frames - 1) * stride)
150 | else:
151 | start = 0
152 | return np.arange(n_frames) * stride + start
153 |
154 |
155 | def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0):
156 | files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)]
157 | files = natsorted([x for x in files if is_image(x)])
158 |
159 | total_frames = len(files)
160 |
161 | if total_frames < n_frames:
162 | return None
163 |
164 | max_stride = (total_frames - 1) // (n_frames - 1)
165 | stride = min(max_stride, stride)
166 |
167 | if random_start:
168 | start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
169 | else:
170 | start = 0
171 | frame_indices = np.arange(n_frames) * stride + start
172 |
173 | frames = []
174 | for frame_index in sorted(frame_indices):
175 |         # stride was clamped above, so every frame_index is within range
176 | frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop))
177 | if len(frames) < n_frames:
178 | return None
179 | frames = np.stack(frames, axis=0)
180 | return frames
181 |
182 |
183 | def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0):
184 |
185 | frames = []
186 | cap = cv2.VideoCapture(video_path)
187 |
188 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
189 |
190 | if total_frames < n_frames:
191 | cap.release()
192 | return None
193 |
194 | max_stride = (total_frames - 1) // (n_frames - 1)
195 | stride = min(max_stride, stride)
196 |
197 | if random_start:
198 | start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
199 | else:
200 | start = 0
201 | frame_indices = np.arange(n_frames) * stride + start
202 |
203 | for frame_index in sorted(frame_indices):
204 | # Check if the frame_index is valid
205 | if 0 <= frame_index < total_frames:
206 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
207 | ret, frame = cap.read()
208 | if ret:
209 | if center_crop < 1.0:
210 | height, width, _ = frame.shape
211 | frame = frame[
212 | int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
213 | int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
214 | :
215 | ]
216 | frame = cv2.resize(frame, (256, 256))
217 |
218 | frames.append(frame)
219 |
220 | else:
221 | print(f"Frame index {frame_index} is out of bounds. Skipping...")
222 |
223 | cap.release()
224 | if len(frames) < n_frames:
225 | return None
226 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
227 |
228 | # From BGR to RGB
229 | return np.stack(
230 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
231 | )
232 |
233 |
234 | def read_all_frames_from_video(video_path, center_crop=1.0):
235 |
236 | frames = []
237 | cap = cv2.VideoCapture(video_path)
238 |
239 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
240 |
241 |
242 | for frame_index in range(total_frames):
243 | # Check if the frame_index is valid
244 | if 0 <= frame_index < total_frames:
245 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
246 | ret, frame = cap.read()
247 | if ret:
248 | if center_crop < 1.0:
249 | height, width, _ = frame.shape
250 | frame = frame[
251 | int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
252 | int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
253 | :
254 | ]
255 | frames.append(cv2.resize(frame, (256, 256)))
256 | else:
257 | print(f"Frame index {frame_index} is out of bounds. Skipping...")
258 |
259 | cap.release()
260 | if len(frames) == 0:
261 | return None
262 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
263 | # From BGR to RGB
264 | return np.stack(
265 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
266 | )
267 |
268 |
269 | def read_max_span_frames_from_video(video_path, n_frames):
270 | frames = []
271 | cap = cv2.VideoCapture(video_path)
272 |
273 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
274 | if total_frames < n_frames:
275 | cap.release()
276 | return None
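   |     # use the largest stride that still fits n_frames, so the sampled frames span the whole video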
277 | stride = (total_frames - 1) // (n_frames - 1)
278 | frame_indices = np.arange(n_frames) * stride
279 |
280 | frames = []
281 | for frame_index in frame_indices:
282 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
283 | ret, frame = cap.read()
284 | if ret:
285 | frames.append(cv2.resize(frame, (256, 256)))
286 |
287 | cap.release()
288 | if len(frames) < n_frames:
289 | return None
290 |
291 | frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
292 | # From BGR to RGB
293 | return np.stack(
294 | [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
295 | )
296 |
297 |
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/vqvae/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/evaluation/vqlm_demo/vqvae/.DS_Store
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | __version__ = "0.0.1"
17 |
18 | # from .modeling_ema import EMAModel
19 | # from .modeling_maskgit_vqgan import MaskGitVQGAN
20 | # from .modeling_movq import MOVQ
21 | # from .modeling_paella_vq import PaellaVQModel
22 | # from .modeling_utils import VQGANModel
23 | # from .modeling_transformer import MaskGitTransformer, MaskGiTUViT
24 | # from .pipeline_muse import PipelineMuse, PipelineMuseInpainting
25 | # from .sampling import get_mask_chedule
26 |
--------------------------------------------------------------------------------
/evaluation/vqlm_demo/vqvae/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Optuna, Hugging Face
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logging utilities."""
16 |
17 | import logging
18 | import os
19 | import sys
20 | import threading
21 | from logging import CRITICAL # NOQA
22 | from logging import DEBUG # NOQA
23 | from logging import ERROR # NOQA
24 | from logging import FATAL # NOQA
25 | from logging import INFO # NOQA
26 | from logging import NOTSET # NOQA
27 | from logging import WARN # NOQA
28 | from logging import WARNING # NOQA
29 | from typing import Optional
30 |
31 | from tqdm import auto as tqdm_lib
32 |
33 | _lock = threading.Lock()
34 | _default_handler: Optional[logging.Handler] = None
35 |
36 | log_levels = {
37 | "debug": logging.DEBUG,
38 | "info": logging.INFO,
39 | "warning": logging.WARNING,
40 | "error": logging.ERROR,
41 | "critical": logging.CRITICAL,
42 | }
43 |
44 | _default_log_level = logging.WARNING
45 |
46 | _tqdm_active = True
47 |
48 |
49 | def _get_default_logging_level():
50 | """
51 | If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
52 | not - fall back to `_default_log_level`
53 | """
54 | env_level_str = os.getenv("muse_VERBOSITY", None)
55 | if env_level_str:
56 | if env_level_str in log_levels:
57 | return log_levels[env_level_str]
58 | else:
59 | logging.getLogger().warning(
60 | f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
61 | )
62 | return _default_log_level
63 |
64 |
65 | def _get_library_name() -> str:
66 | return __name__.split(".")[0]
67 |
68 |
69 | def _get_library_root_logger() -> logging.Logger:
70 | return logging.getLogger(_get_library_name())
71 |
72 |
73 | def _configure_library_root_logger() -> None:
74 | global _default_handler
75 |
76 | with _lock:
77 | if _default_handler:
78 | # This library has already configured the library root logger.
79 | return
80 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
81 | _default_handler.flush = sys.stderr.flush
82 |
83 | # Apply our default configuration to the library root logger.
84 | library_root_logger = _get_library_root_logger()
85 | library_root_logger.addHandler(_default_handler)
86 | library_root_logger.setLevel(_get_default_logging_level())
87 | library_root_logger.propagate = False
88 |
89 |
90 | def _reset_library_root_logger() -> None:
91 | global _default_handler
92 |
93 | with _lock:
94 | if not _default_handler:
95 | return
96 |
97 | library_root_logger = _get_library_root_logger()
98 | library_root_logger.removeHandler(_default_handler)
99 | library_root_logger.setLevel(logging.NOTSET)
100 | _default_handler = None
101 |
102 |
103 | def get_log_levels_dict():
104 | return log_levels
105 |
106 |
107 | def get_logger(name: Optional[str] = None) -> logging.Logger:
108 | """
109 | Return a logger with the specified name.
110 |
111 | This function is not supposed to be directly accessed unless you are writing a custom muse module.
112 | """
113 |
114 | if name is None:
115 | name = _get_library_name()
116 |
117 | _configure_library_root_logger()
118 | return logging.getLogger(name)
119 |
120 |
121 | def get_verbosity() -> int:
122 | """
123 | Return the current level for the 🤗 muse' root logger as an int.
124 |
125 | Returns:
126 | `int`: The logging level.
127 |
128 |
129 |
130 |     🤗 muse has the following logging levels:
131 |
132 | - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
133 | - 40: `muse.logging.ERROR`
134 | - 30: `muse.logging.WARNING` or `muse.logging.WARN`
135 | - 20: `muse.logging.INFO`
136 | - 10: `muse.logging.DEBUG`
137 |
138 | """
139 |
140 | _configure_library_root_logger()
141 | return _get_library_root_logger().getEffectiveLevel()
142 |
143 |
144 | def set_verbosity(verbosity: int) -> None:
145 | """
146 | Set the verbosity level for the 🤗 muse' root logger.
147 |
148 | Args:
149 | verbosity (`int`):
150 | Logging level, e.g., one of:
151 |
152 | - `muse.logging.CRITICAL` or `muse.logging.FATAL`
153 | - `muse.logging.ERROR`
154 | - `muse.logging.WARNING` or `muse.logging.WARN`
155 | - `muse.logging.INFO`
156 | - `muse.logging.DEBUG`
157 | """
158 |
159 | _configure_library_root_logger()
160 | _get_library_root_logger().setLevel(verbosity)
161 |
162 |
163 | def set_verbosity_info():
164 | """Set the verbosity to the `INFO` level."""
165 | return set_verbosity(INFO)
166 |
167 |
168 | def set_verbosity_warning():
169 | """Set the verbosity to the `WARNING` level."""
170 | return set_verbosity(WARNING)
171 |
172 |
173 | def set_verbosity_debug():
174 | """Set the verbosity to the `DEBUG` level."""
175 | return set_verbosity(DEBUG)
176 |
177 |
178 | def set_verbosity_error():
179 | """Set the verbosity to the `ERROR` level."""
180 | return set_verbosity(ERROR)
181 |
182 |
183 | def disable_default_handler() -> None:
184 | """Disable the default handler of the HuggingFace muse' root logger."""
185 |
186 | _configure_library_root_logger()
187 |
188 | assert _default_handler is not None
189 | _get_library_root_logger().removeHandler(_default_handler)
190 |
191 |
192 | def enable_default_handler() -> None:
193 | """Enable the default handler of the HuggingFace muse' root logger."""
194 |
195 | _configure_library_root_logger()
196 |
197 | assert _default_handler is not None
198 | _get_library_root_logger().addHandler(_default_handler)
199 |
200 |
201 | def add_handler(handler: logging.Handler) -> None:
202 | """adds a handler to the HuggingFace muse' root logger."""
203 |
204 | _configure_library_root_logger()
205 |
206 | assert handler is not None
207 | _get_library_root_logger().addHandler(handler)
208 |
209 |
210 | def remove_handler(handler: logging.Handler) -> None:
211 | """removes given handler from the HuggingFace muse' root logger."""
212 |
213 | _configure_library_root_logger()
214 |
215 | assert handler is not None and handler not in _get_library_root_logger().handlers
216 | _get_library_root_logger().removeHandler(handler)
217 |
218 |
219 | def disable_propagation() -> None:
220 | """
221 | Disable propagation of the library log outputs. Note that log propagation is disabled by default.
222 | """
223 |
224 | _configure_library_root_logger()
225 | _get_library_root_logger().propagate = False
226 |
227 |
228 | def enable_propagation() -> None:
229 | """
230 | Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
231 | double logging if the root logger has been configured.
232 | """
233 |
234 | _configure_library_root_logger()
235 | _get_library_root_logger().propagate = True
236 |
237 |
238 | def enable_explicit_format() -> None:
239 | """
240 | Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
241 | ```
242 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
243 | ```
244 | All handlers currently bound to the root logger are affected by this method.
245 | """
246 | handlers = _get_library_root_logger().handlers
247 |
248 | for handler in handlers:
249 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
250 | handler.setFormatter(formatter)
251 |
252 |
253 | def reset_format() -> None:
254 | """
255 | Resets the formatting for HuggingFace muse' loggers.
256 |
257 | All handlers currently bound to the root logger are affected by this method.
258 | """
259 | handlers = _get_library_root_logger().handlers
260 |
261 | for handler in handlers:
262 | handler.setFormatter(None)
263 |
264 |
265 | def warning_advice(self, *args, **kwargs):
266 | """
267 | This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
268 | warning will not be printed
269 | """
270 | no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
271 | if no_advisory_warnings:
272 | return
273 | self.warning(*args, **kwargs)
274 |
275 |
276 | logging.Logger.warning_advice = warning_advice
277 |
278 |
279 | class EmptyTqdm:
280 | """Dummy tqdm which doesn't do anything."""
281 |
282 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
283 | self._iterator = args[0] if args else None
284 |
285 | def __iter__(self):
286 | return iter(self._iterator)
287 |
288 | def __getattr__(self, _):
289 | """Return empty function."""
290 |
291 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
292 | return
293 |
294 | return empty_fn
295 |
296 | def __enter__(self):
297 | return self
298 |
299 | def __exit__(self, type_, value, traceback):
300 | return
301 |
302 |
303 | class _tqdm_cls:
304 | def __call__(self, *args, **kwargs):
305 | if _tqdm_active:
306 | return tqdm_lib.tqdm(*args, **kwargs)
307 | else:
308 | return EmptyTqdm(*args, **kwargs)
309 |
310 | def set_lock(self, *args, **kwargs):
311 | self._lock = None
312 | if _tqdm_active:
313 | return tqdm_lib.tqdm.set_lock(*args, **kwargs)
314 |
315 | def get_lock(self):
316 | if _tqdm_active:
317 | return tqdm_lib.tqdm.get_lock()
318 |
319 |
320 | tqdm = _tqdm_cls()
321 |
322 |
323 | def is_progress_bar_enabled() -> bool:
324 | """Return a boolean indicating whether tqdm progress bars are enabled."""
325 | global _tqdm_active
326 | return bool(_tqdm_active)
327 |
328 |
329 | def enable_progress_bar():
330 | """Enable tqdm progress bar."""
331 | global _tqdm_active
332 | _tqdm_active = True
333 |
334 |
335 | def disable_progress_bar():
336 | """Disable tqdm progress bar."""
337 | global _tqdm_active
338 | _tqdm_active = False
339 |
--------------------------------------------------------------------------------
/images/visual_sentences.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ytongbai/LVM/b6de939ef0eb1ee6593445a7f5268145f338749b/images/visual_sentences.jpg
--------------------------------------------------------------------------------
/scripts/gpu_environment.yml:
--------------------------------------------------------------------------------
1 | name: LVM
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3.10
6 | - pip
7 | - numpy
8 | - scipy
9 | - matplotlib
10 | - seaborn
11 | - jupyter
12 | - tqdm
13 | - sentencepiece
14 | - pip:
15 | - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
16 | - --extra-index-url https://download.pytorch.org/whl/cu118
17 | - jax[cuda11_pip]==0.4.14
18 | - flax==0.7.0
19 | - optax==0.1.7
20 | - distrax==0.1.4
21 | - chex==0.1.82
22 | - transformers==4.31.0
23 | - torch==2.0.1
24 | - huggingface_hub==0.16.4
25 | - datasets==2.14.2
26 | - einops
27 | - tensorflow==2.11.1
28 | - dill
29 | - absl-py
30 | - wandb
31 | - ml_collections
32 | - gcsfs
33 | - requests
34 | - jupyter_http_over_ws
35 | - lm-eval
36 | - mlxu==0.1.11
37 | - pydantic
38 | - fastapi
39 | - uvicorn
40 | - gradio
41 | - s3fs
42 |
--------------------------------------------------------------------------------
/scripts/tpu_commands.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | function _tpu_ips {
4 | tpu_zone=$1
5 | tpu_project=$2
6 | tpu_name=$3
7 | gcloud alpha compute tpus tpu-vm describe $tpu_name --zone $tpu_zone --project $tpu_project | grep -oP 'externalIp: \K(.+)$'
8 |
9 | }
10 |
11 | function _tpu_create {
12 | tpu_zone=$1
13 | tpu_project=$2
14 | tpu_gen=$3
15 | tpu_cores=$4
16 | tpu_name=$5
17 | if [ "$tpu_gen" = "v3" ]; then
18 | software_version='tpu-vm-base'
19 | else
20 | software_version='tpu-vm-v4-base'
21 | fi
22 |
23 | if [[ $tpu_cores =~ ^[0-9]+$ ]]; then
24 | gcloud alpha compute tpus tpu-vm create \
25 | $tpu_name \
26 | --accelerator-type="$tpu_gen-$tpu_cores" \
27 | --version $software_version \
28 | --zone $tpu_zone \
29 | --project $tpu_project
30 | else
31 | gcloud alpha compute tpus tpu-vm create \
32 | $tpu_name \
33 | --type="$tpu_gen" \
34 | --topology="$tpu_cores" \
35 | --version $software_version \
36 | --zone $tpu_zone \
37 | --project $tpu_project
38 | fi
39 | }
40 |
41 | function _tpu_retry_create {
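   |     # keep retrying creation every 120 seconds (useful while waiting for capacity);
   |     # stop it manually once the TPU has been created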
42 | while true; do
43 | _tpu_create "$@"
44 | sleep 120s
45 | done
46 | }
47 |
48 | function _tpu_cp_ssh_key {
49 | tpu_zone=$1
50 | tpu_project=$2
51 | tpu_name=$3
52 |
53 | gcloud alpha compute tpus tpu-vm scp \
54 | $HOME/.ssh/authorized_keys \
55 | $tpu_name:/home/$USER/.ssh/ \
56 | --worker=all \
57 | --project $tpu_project \
58 | --zone $tpu_zone
59 | }
60 |
61 | function _tpu_setup {
62 | tpu_zone=$1
63 | tpu_project=$2
64 | tpu_name=$3
65 |
66 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
67 | for host in $tpu_ips; do
68 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/
69 | ssh $host '~/tpu_vm_setup.sh' &
70 | done
71 | wait &> /dev/null
72 |
73 | for host in $tpu_ips; do
74 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/
75 | ssh $host '~/tpu_vm_setup.sh' &
76 | done
77 | wait &> /dev/null
78 | }
79 |
80 | function _tpu_check {
81 | tpu_zone=$1
82 | tpu_project=$2
83 | tpu_name=$3
84 |
85 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
86 | for host in $tpu_ips; do
87 | echo "============== Checking host: $host =============="
88 | ssh $host 'tmux capture-pane -pt launch -S -2000'
89 | echo
90 | echo
91 | done
92 | }
93 |
94 | function _tpu_copy {
95 | tpu_zone=$1
96 | tpu_project=$2
97 | tpu_name=$3
98 |
99 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
100 | for host in $tpu_ips; do
101 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ &
102 | done
103 | wait &> /dev/null
104 | sleep 1s
105 |
106 | for host in $tpu_ips; do
107 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ &
108 | done
109 | wait &> /dev/null
110 | sleep 1s
111 | }
112 |
113 | function _tpu_stop {
114 | tpu_zone=$1
115 | tpu_project=$2
116 | tpu_name=$3
117 |
118 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
119 | for host in $tpu_ips; do
120 | ssh $host 'tmux kill-session -t launch ; pkill -9 python' &
121 | done
122 | wait &> /dev/null
123 | }
124 |
125 | function _tpu_launch {
126 | tpu_zone=$1
127 | tpu_project=$2
128 | tpu_name=$3
129 | command=$4
130 |
131 | if [ -z "$command" ]; then
132 | echo "Invalid syntax!"
133 | return 1
134 | fi
135 |
136 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
137 | for host in $tpu_ips; do
138 | ssh $host "tmux new -d -s launch ~/$PROJECT_NAME/launcher/$command" &
139 | done
140 | wait &> /dev/null
141 | }
142 |
143 | function _tpu_maintain {
144 | tpu_zone=$1
145 | tpu_project=$2
146 | tpu_name=$3
147 |
148 | gcloud alpha compute tpus tpu-vm simulate-maintenance-event $tpu_name \
149 | --project $tpu_project \
150 | --zone=$tpu_zone \
151 | --workers=all
152 | }
153 |
154 | function _tpu_ssh {
155 | tpu_zone=$1
156 | tpu_project=$2
157 | tpu_name=$3
158 | command="$4"
159 |
160 | if [ -z "$command" ]; then
161 | echo "Invalid syntax!"
162 | return 1
163 | fi
164 |
165 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
166 | for host in $tpu_ips; do
167 | ssh $host "$command" &
168 | done
169 | wait &> /dev/null
170 | }
171 |
172 | function _tpu_reboot {
173 | tpu_zone=$1
174 | tpu_project=$2
175 | tpu_name=$3
176 |
177 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name)
178 | for host in $tpu_ips; do
179 | ssh $host 'sudo reboot' &
180 | done
181 | wait &> /dev/null
182 | }
183 |
184 |
185 | function tpu {
186 | trap "trap - SIGINT SIGTERM; return 1;" SIGINT SIGTERM
187 |
188 |
189 | # =============== TPU Project Specific Definitions ===============
190 | export PROJECT_HOME='" ]; then
194 | tpu_project=''
195 | tpu_zone='us-east1-d'
196 | tpu_gen='v3'
197 | else
198 | echo "Invalid syntax!"
199 | trap - SIGINT SIGTERM
200 | return 1
201 | fi
202 | # =============== End of TPU Project Specific Definitions ===============
203 |
204 |
205 | if [ "$2" = "list" ]; then
206 | gcloud alpha compute tpus tpu-vm list --zone $tpu_zone --project $tpu_project
207 | elif [ "$2" = "describe" ]; then
208 | gcloud alpha compute tpus tpu-vm describe $3 --zone $tpu_zone --project $tpu_project
209 | elif [ "$2" = "ips" ]; then
210 | _tpu_ips $tpu_zone $tpu_project $3
211 | elif [ "$2" = "delete" ]; then
212 | gcloud alpha compute tpus tpu-vm delete $3 --zone $tpu_zone --project $tpu_project --quiet
213 | elif [ "$2" = "delete_queued" ]; then
214 | gcloud alpha compute tpus queued-resources delete $3 --project $tpu_project --zone $tpu_zone
215 | elif [ "$2" = "create" ]; then
216 | _tpu_create $tpu_zone $tpu_project $tpu_gen $3 $4
217 | elif [ "$2" = "cp_ssh_key" ]; then
218 | _tpu_cp_ssh_key $tpu_zone $tpu_project $3
219 | elif [ "$2" = "retry_create" ]; then
220 | _tpu_retry_create $tpu_zone $tpu_project $tpu_gen $3 $4
221 | elif [ "$2" = "cs" ]; then
222 | _tpu_create $tpu_zone $tpu_project $tpu_gen $3 $4
223 | sleep 90s
224 | _tpu_setup $tpu_zone $tpu_project $4
225 | elif [ "$2" = "check" ]; then
226 | _tpu_check $tpu_zone $tpu_project $3
227 | elif [ "$2" = "setup" ]; then
228 | _tpu_setup $tpu_zone $tpu_project $3
229 | elif [ "$2" = "copy" ]; then
230 | _tpu_copy $tpu_zone $tpu_project $3
231 | elif [ "$2" = "stop" ]; then
232 | _tpu_stop $tpu_zone $tpu_project $3
233 | elif [ "$2" = "launch" ]; then
234 | _tpu_launch $tpu_zone $tpu_project $3 $4
235 | elif [ "$2" = "cl" ]; then
236 | _tpu_copy $tpu_zone $tpu_project $3
237 | _tpu_launch $tpu_zone $tpu_project $3 $4
238 | elif [ "$2" = "maintain" ]; then
239 | _tpu_maintain $tpu_zone $tpu_project $3
240 | elif [ "$2" = "ssh" ]; then
241 | _tpu_ssh $tpu_zone $tpu_project $3 "$4"
242 | elif [ "$2" = "reboot" ]; then
243 | _tpu_reboot $tpu_zone $tpu_project $3
244 | elif [ "$2" = "df" ]; then
245 | _tpu_ssh $tpu_zone $tpu_project $3 'df -h | grep root'
246 | else
247 | echo "Invalid syntax!"
248 | trap - SIGINT SIGTERM
249 | return 1
250 | fi
251 | trap - SIGINT SIGTERM
252 | }
253 |
254 |
255 | export -f tpu _tpu_ips _tpu_create _tpu_setup _tpu_check _tpu_copy _tpu_stop _tpu_launch _tpu_maintain _tpu_ssh _tpu_reboot
--------------------------------------------------------------------------------
/scripts/tpu_vm_setup.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | sudo apt-get update && sudo apt-get install -y \
4 | build-essential \
5 | python-is-python3 \
6 | tmux \
7 | htop \
8 | git \
9 | nodejs \
10 | bmon \
11 | p7zip-full \
12 | nfs-common
13 |
14 |
15 | # Python dependencies
16 | cat > $HOME/tpu_requirements.txt <<- EndOfFile
17 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
18 | jax[tpu]==0.4.13
19 | tensorflow==2.11.0
20 | flax==0.7.0
21 | optax==0.1.7
22 | distrax==0.1.3
23 | chex==0.1.7
24 | einops
25 | --extra-index-url https://download.pytorch.org/whl/cpu
26 | torch==2.0.1
27 | transformers==4.31.0
28 | datasets==2.14.2
29 | huggingface_hub==0.16.4
30 | tqdm
31 | h5py
32 | ml_collections
33 | wandb==0.13.5
34 | gcsfs==2022.11.0
35 | requests
36 | typing-extensions
37 | lm-eval==0.3.0
38 | mlxu==0.1.11
39 | sentencepiece
40 | pydantic
41 | fastapi
42 | uvicorn
43 | gradio
44 | EndOfFile
45 |
46 | pip install --upgrade -r $HOME/tpu_requirements.txt
47 |
48 |
49 | # vim configurations
50 | cat > $HOME/.vimrc <<- EndOfFile
51 | set tabstop=4
52 | set shiftwidth=4
53 | set softtabstop=4
54 | set expandtab
55 | set backspace=indent,eol,start
56 | syntax on
57 | EndOfFile
58 |
59 | # tmux configurations
60 | cat > $HOME/.tmux.conf <<- EndOfFile
61 | bind r source-file ~/.tmux.conf
62 |
63 | set -g prefix C-a
64 |
65 | set -g set-titles on
66 | set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)'
67 |
68 | set -g default-terminal "screen-256color"
69 |
70 | # Status bar customization
71 | #set -g status-utf8 on
72 | set -g status-bg white
73 | set -g status-fg black
74 | set -g status-interval 5
75 | set -g status-left-length 90
76 | set -g status-right-length 60
77 |
78 | set -g status-justify left
79 |
80 | unbind-key C-o
81 | bind -n C-o prev
82 | unbind-key C-p
83 | bind -n C-p next
84 | unbind-key C-w
85 | bind -n C-w new-window
86 |
87 | unbind-key C-j
88 | bind -n C-j select-pane -D
89 | unbind-key C-k
90 | bind -n C-k select-pane -U
91 | unbind-key C-h
92 | bind -n C-h select-pane -L
93 | unbind-key C-l
94 | bind -n C-l select-pane -R
95 |
96 | unbind-key C-e
97 | bind -n C-e split-window -h
98 | unbind-key C-q
99 | bind -n C-q split-window -v
100 | unbind '"'
101 | unbind %
102 |
103 | unbind-key u
104 | bind-key u split-window -h
105 | unbind-key i
106 | bind-key i split-window -v
107 | EndOfFile
108 |
109 |
110 | # htop Configurations
111 | mkdir -p $HOME/.config/htop
112 | cat > $HOME/.config/htop/htoprc <<- EndOfFile
113 | # Beware! This file is rewritten by htop when settings are changed in the interface.
114 | # The parser is also very primitive, and not human-friendly.
115 | fields=0 48 17 18 38 39 40 2 46 47 49 1
116 | sort_key=46
117 | sort_direction=1
118 | hide_threads=0
119 | hide_kernel_threads=1
120 | hide_userland_threads=1
121 | shadow_other_users=0
122 | show_thread_names=0
123 | show_program_path=1
124 | highlight_base_name=0
125 | highlight_megabytes=1
126 | highlight_threads=1
127 | tree_view=0
128 | header_margin=1
129 | detailed_cpu_time=0
130 | cpu_count_from_zero=0
131 | update_process_names=0
132 | account_guest_in_cpu_meter=0
133 | color_scheme=0
134 | delay=15
135 | left_meters=CPU Memory Swap
136 | left_meter_modes=1 1 1
137 | right_meters=Tasks LoadAverage Uptime
138 | right_meter_modes=2 2 2
139 | EndOfFile
140 |
--------------------------------------------------------------------------------
/tokenize_examples/detokenization_muse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import einops
4 | import base64
5 | import shutil
6 | from torchvision import transforms
7 | from torch.utils.data import Dataset, DataLoader
8 | from PIL import Image
10 | from torchvision.utils import save_image
11 | import json
12 | import wandb
13 | from tqdm import tqdm
14 | import torch
17 | from muse import VQGANModel
18 | import torchvision.utils as vutils
19 | import matplotlib.pyplot as plt
20 | import argparse
21 |
22 |
23 | def save_temp_images(tensors, temp_dir='./temp_images'):
24 | """
25 | Save a batch of images stored in a PyTorch tensor to a temporary directory.
26 |
27 | Args:
28 | - tensors (torch.Tensor): A batch of images in the format (batch size, channels, height, width).
29 | - temp_dir (str): Path to the temporary directory where images will be saved.
30 | """
31 | # Create a temporary directory if it doesn't exist
32 | if not os.path.exists(temp_dir):
33 | os.makedirs(temp_dir)
34 |
35 | # grid_img = vutils.make_grid(tensors, n=16, padding=0)
36 | # vutils.save_image(grid_img, 'vqlm/muse/vis_reconstruction_tokens/concatenated_image.png')
37 |
38 | # Save each image in the batch to the temporary directory
39 | for i, tensor in enumerate(tensors):
40 | # Construct the file path
41 | file_path = os.path.join(temp_dir, f'image_{i}.png')
42 | # Save the image
43 | save_image(tensor, file_path)
44 |
45 |
46 | def delete_temp_images(temp_dir='./temp_images'):
47 | """
48 | Delete the temporary directory and its contents.
49 |
50 | Args:
51 | - temp_dir (str): Path to the temporary directory to be deleted.
52 | """
53 | if os.path.exists(temp_dir):
54 | shutil.rmtree(temp_dir)
55 |
56 |
57 | class CustomImageDataset(Dataset):
58 | def __init__(self, image_paths, transform=None):
59 | self.image_paths = image_paths
60 | self.transform = transform
61 |
62 | def __len__(self):
63 | return len(self.image_paths)
64 |
65 | def __getitem__(self, idx):
66 | image_path = self.image_paths[idx]
67 | image = Image.open(image_path).convert('RGB')
68 |
69 | if self.transform:
70 | image = self.transform(image)
71 |
72 | # Assuming you want to split or transform the image in a way that fits the '2 * h * w' description
73 | # This can be an operation like doubling the image, or applying some transformation.
74 | # For simplicity, we'll just duplicate the image in the dataset.
75 | # Adjust the transformation as needed.
76 |
77 | return (image,)
78 |
79 |
80 | import random
81 |
82 | def select_random_elements(your_list, num_elements=64, seed_value=42):
83 | """
84 | Randomly selects `num_elements` from `your_list` using `seed_value` for reproducibility.
85 |
86 | Parameters:
87 | - your_list: The list to select elements from.
88 | - num_elements: The number of elements to select. Default is 64.
89 | - seed_value: The seed value for the random number generator to ensure reproducibility. Default is 42.
90 |
91 | Returns:
92 | - A list containing the randomly selected elements.
93 | """
94 | random.seed(seed_value) # Set the seed for reproducibility
95 | selected_elements = random.sample(your_list, num_elements) # Randomly select elements
96 | return selected_elements
97 |
98 | def save_tokens(trecons_name, idx, re_constructed):
99 | tensor_img = re_constructed.cpu().numpy()
100 | images = (tensor_img * 255).transpose(0, 2, 3, 1).astype(np.uint8)
101 |
102 | # Concatenate images horizontally
103 | concatenated_image = np.hstack(images)
104 |
105 | # Convert the NumPy array to a PIL Image and save it
106 | os.makedirs(trecons_name, exist_ok=True)
107 | img = Image.fromarray(concatenated_image)
108 | img.save(os.path.join(trecons_name, '{}.png'.format(idx)))
109 |
110 |
111 | if __name__ == '__main__':
112 | # Create the parser
113 | parser = argparse.ArgumentParser(description='Process some integers.')
114 |
115 | # Add arguments
116 | parser.add_argument('--dataset', default='i1k_edge', type=str, help='An input name')
117 |
118 | # Parse the arguments
119 | args = parser.parse_args()
120 |
121 | # Load the pre-trained vq model from the hub
122 | net = VQGANModel.from_pretrained("vqlm/muse/ckpts/laion").cuda()
123 | net.eval()
124 |
125 | idx = 1
126 | dataset = args.dataset
127 | with open('./lvm/tokenized_muse/{}.jsonl'.format(dataset), 'r') as file:
128 | for line in file:
129 | try:
130 | json_obj = json.loads(line.strip())['tokens']
131 | # Process the JSON object
132 |             except Exception:  # skip malformed lines
133 | continue
134 |
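   |             # each record stores base64-encoded int32 token ids; every image is a 16x16 grid of
   |             # VQ codes, hence the reshape to (-1, 256)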
135 |             decoded_bytes = base64.b64decode(json_obj)
136 |             array_dtype = np.int32
137 |             array_shape = (-1, 256)
138 |             tokens = np.frombuffer(decoded_bytes, dtype=array_dtype).reshape(array_shape)
139 |             tokens = torch.tensor(tokens).cuda()
140 | 
141 | 
142 |             with torch.no_grad():
143 | 
144 |                 # detokenize: decode the VQ codes back into images
145 |                 re_constructed = net.decode_code(tokens)
146 | 
147 |             plt.figure(figsize=(12, 12))
148 |             for i in range(re_constructed.shape[0]):
149 |                 recon_img = torch.clamp(re_constructed[i],
150 |                     0.0, 1.0
151 |                 )
152 |                 plt.subplot(4, 4, i + 1)
153 |                 plt.imshow((recon_img.permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8))
154 |                 plt.grid(False)
155 |                 plt.axis('off')
156 |             save_root = './lvm/other_folder/vis_reconstruction_tokens_check_final/{}'.format(dataset)
157 |             os.makedirs(save_root, exist_ok=True)
158 |             plt.savefig('{}/{}.png'.format(save_root, idx))
   |             plt.close()  # free the figure before the next sample
159 | 
160 |             idx += 1
161 |             if idx >= 16:
162 |                 break
163 |
164 |
--------------------------------------------------------------------------------
/tokenize_examples/map_color.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import os
4 | from tqdm import tqdm
5 | import glob
6 |
7 | def generate_random_palette():
8 | # Generate a random palette with RGB values
9 | return np.array([np.random.choice(range(256), size=3) for _ in range(256)], dtype=np.uint8)
10 |
11 | # Function to apply the color mapping to a grayscale image
12 | def apply_color_palette(image_path, output_path, palette):
13 | # Load the image and convert it to a NumPy array
14 | grayscale_image = Image.open(image_path).convert('L')
15 | grayscale_array = np.array(grayscale_image, dtype=np.int32)
16 |
17 | # Apply the palette using advanced indexing
18 | color_mapped_array = palette[grayscale_array]
19 |
20 | # Convert back to an image and save or display
21 | color_mapped_image = Image.fromarray(color_mapped_array, 'RGB')
22 | color_mapped_image.save(output_path)
23 |
24 |
25 | if __name__ == '__main__':
26 | # Example usage
27 | # Generate a random color palette
28 | # palette = generate_random_palette()
29 | # np.save('./color_map.npy', palette)
30 |
31 | palette = np.load('./vqlm/muse/organize_dataset/i1k/color_map.npy')
32 |
33 | s = 1
34 |
35 | root = '/scratch/partial_datasets/lvm/dataset/prismer_i1k/new/seg_coco/train'
36 |
37 | images = glob.glob(root + '/*/*.png')
38 |
39 | for img in tqdm(images):
40 | cate_path = img.split('/')[-2]
41 | save_root = os.path.join('/scratch/partial_datasets/lvm/dataset/prismer_i1k/new_mapped_color/seg_coco_colored/train', cate_path)
42 | os.makedirs(save_root, exist_ok=True)
43 | save_path = os.path.join(save_root, os.path.basename(img))
44 | apply_color_palette(img, save_path, palette=palette)
45 |
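46 | # Usage notes: the paths above are specific to the original setup and should be edited.
47 | # color_map.npy is expected to hold a (256, 3) uint8 palette, e.g. one created with
48 | # generate_random_palette() and np.save as in the commented-out example above; every
49 | # grayscale segmentation id in [0, 255] is then mapped to a fixed RGB color so the
50 | # resulting maps can be tokenized like ordinary RGB images.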
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_catogory_images_muse.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from tempfile import NamedTemporaryFile
4 | import random
5 | import json
6 | from base64 import b64encode
7 | from tqdm import tqdm, trange
8 |
9 | import numpy as np
10 | import mlxu
11 |
12 | import torch
13 |
14 | import jax
15 | import jax.numpy as jnp
16 | import flax
17 | import einops
18 |
19 | from PIL import Image
20 | from utils import read_image_to_tensor
21 |
22 | FLAGS, _ = mlxu.define_flags_with_default(
23 | input_image_dir='/datasets/ilsvrc_2024-01-04_1601/train',
24 | output_file='/home/yutongbai/vqlm/muse/running_script/tokenized_muse/i1kcate_4.jsonl',
25 | images_per_shot=4,
26 | n_shots=4,
27 | n_epochs=1,
28 | n_workers=8,
29 | dtype='fp32',
30 | batch_size=1
31 | )
32 |
33 | # Each 256x256 image is encoded into 256 visual tokens
34 | tokens_per_image = 256
35 |
36 |
37 | class SubfolderImageDataset(torch.utils.data.Dataset):
38 |
39 | def __init__(self, subfolders, images_per_shot):
40 | self.subfolders = subfolders
41 | self.images_per_shot = images_per_shot
42 |
43 | def __getitem__(self, index):
44 | subfolder = self.subfolders[index]
45 | image_files = [os.path.join(subfolder, f) for f in os.listdir(subfolder) if
46 | f.endswith(('.png', '.jpg', '.JPEG'))]
47 | selected_images = np.random.choice(image_files, self.images_per_shot, replace=False)
48 | return [read_image_to_tensor(image_file) for image_file in selected_images]
49 |
50 | def __len__(self):
51 | return len(self.subfolders)
52 |
53 |
54 | def read_image_to_tensor(image_path):  # local definition; overrides the version imported from utils above
55 | with Image.open(image_path) as img:
56 | img = img.convert('RGB')
57 | img = img.resize((256, 256))
58 | img = np.array(img) / 255.0
59 | return torch.tensor(img).float()
60 |
61 |
62 | def custom_collate(batch):
63 | return [item for sublist in batch for item in sublist]
64 |
65 |
66 | def main(argv):
67 | assert FLAGS.input_image_dir != ''
68 | assert FLAGS.output_file != ''
69 |
70 | subfolders = [os.path.join(FLAGS.input_image_dir, d) for d in os.listdir(FLAGS.input_image_dir) if
71 | os.path.isdir(os.path.join(FLAGS.input_image_dir, d))]
72 |
73 | dataset = SubfolderImageDataset(subfolders, FLAGS.images_per_shot)
74 | dataloader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=FLAGS.batch_size,
77 | shuffle=True,
78 | num_workers=FLAGS.n_workers,
79 | drop_last=True,
80 | collate_fn=custom_collate
81 | )
82 |
83 | total_images = len(dataset) * FLAGS.images_per_shot
84 |
85 | checkpoint_path = os.path.join(
86 | os.path.dirname(os.path.realpath(__file__)),
87 | '/home/vqlm/eval_visualization/torch_ckpts/jax_xh_ckpt.pkl'
88 | )
89 | tokenize_fn = get_tokenizer_fn(checkpoint_path, FLAGS.dtype)
90 |
91 | n_devices = jax.device_count()
92 |
93 | with NamedTemporaryFile() as ntf:
94 |         # One row of tokens_per_image tokens per image
95 |         all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, tokens_per_image))
96 |
97 | index = 0
98 | for batch in tqdm(dataloader, ncols=0):
99 | batch_images = np.concatenate(batch, axis=0)
100 |
101 | # Reshape batch_images to (n_devices, batch_size/n_devices, height, width, channels)
102 | batch_images = batch_images.reshape(n_devices, -1, 256, 256, 3)
103 |
104 |             tokens = tokenize_fn(batch_images).reshape(-1, tokens_per_image)
105 |
106 |             # Each image in the batch yields one row of tokens_per_image tokens
107 |             n_new = tokens.shape[0]
108 |
109 |             # Assign this batch's rows to the all_tokens array
110 |             all_tokens[index:index + n_new] = tokens
111 |             index += n_new
112 |
113 | with open(FLAGS.output_file, 'w') as fout:
114 | for _ in trange(FLAGS.n_epochs, ncols=0):
115 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots * FLAGS.images_per_shot)
116 | for i in trange(indices.shape[0], ncols=0):
117 | tokens = all_tokens[indices[i], :].reshape(-1)
118 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8')}
119 | fout.write(json.dumps(data) + '\n')
120 |
121 |
122 | if __name__ == '__main__':
123 | mlxu.run(main)
124 |
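125 | # Usage (flag values and paths are illustrative):
126 | #   python tokenize_catogory_images_muse.py --input_image_dir=/path/to/imagenet/train \
127 | #     --output_file=./tokenized/i1kcate_4.jsonl --images_per_shot=4 --n_shots=4
128 | # Each dataset item samples images_per_shot images from one class subfolder, and the
129 | # writer packs n_shots * images_per_shot images (4 * 4 * 256 = 4096 tokens with the
130 | # defaults) into every jsonl line. Note that get_tokenizer_fn is referenced above but
131 | # not defined or imported in this file; it is assumed to come from the original codebase.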
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_co3d_muse.py:
--------------------------------------------------------------------------------
1 | """ Tokenize multiple related sequences of data """
2 |
3 |
4 |
5 | import os
6 | from copy import deepcopy
7 | from functools import partial
8 | from tempfile import NamedTemporaryFile
9 | import random
10 | import json
11 | from tqdm import tqdm, trange
12 | from muse import VQGANModel
13 | import numpy as np
14 | import mlxu
15 |
16 | import torch
17 |
18 |
19 | import einops
20 |
21 | from PIL import Image
22 |
23 | from utils import match_mulitple_path
24 |
25 | from utils import (
26 | randomly_subsample_frame_indices, list_dir_with_full_path,
27 | is_image, b64encode_tokens, read_image_to_tensor
28 | )
29 |
30 |
31 | FLAGS, _ = mlxu.define_flags_with_default(
32 | input_dir='',
33 | output_file='',
34 | batch_size=4,
35 | n_frames=4,
36 | max_stride=4,
37 | n_shots=4,
38 | n_epochs=1,
39 | n_workers=16,
40 | dtype='fp32',
41 | )
42 |
43 |
44 |
45 | class Co3DDataset(torch.utils.data.Dataset):
46 |
47 | def __init__(self, image_dirs, n_frames):
48 | self.image_dirs = image_dirs
49 | self.n_frames = n_frames
50 |
51 | def __getitem__(self, index):
52 | tasks = self.image_dirs[index]
53 | frames = []
54 | length = len(list_dir_with_full_path(tasks[0]))
55 | indices = randomly_subsample_frame_indices(
56 | length, self.n_frames, FLAGS.max_stride
57 | )
58 | for task in tasks:
59 | task_frames = []
60 | files = sorted(list_dir_with_full_path(task))
61 | for i in indices:
62 | if is_image(files[i]):
63 | task_frames.append(read_image_to_tensor(files[i]))
64 | frames.append(np.stack(task_frames, axis=0))
65 |
66 | return np.stack(frames, axis=0)
67 |
68 | def __len__(self):
69 | return len(self.image_dirs)
70 |
71 |
72 | def main(argv):
73 | assert FLAGS.input_dir != ''
74 | assert FLAGS.output_file != ''
75 |
76 | # Load the pre-trained vq model from the hub
77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78 |
79 | net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device)
80 | net.eval()
81 |
82 | dirs = []
83 | for d in list_dir_with_full_path(FLAGS.input_dir):
84 | if not os.path.isdir(d):
85 | continue
86 | for d2 in list_dir_with_full_path(d):
87 | if not os.path.isdir(d2):
88 | continue
89 | dirs.append(d2)
90 |
91 | image_dirs = []
92 | for d in dirs:
93 | image_dirs.append((
94 | os.path.join(d, 'images'),
95 | os.path.join(d, 'masks'),
96 | os.path.join(d, 'depth_masks')
97 | ))
98 |
99 | with open(FLAGS.output_file, 'w') as fout:
100 | with torch.no_grad():
101 | for _ in trange(FLAGS.n_epochs, ncols=0):
102 | print(image_dirs[0])
103 | dataset = Co3DDataset(image_dirs, FLAGS.n_frames)
104 | dataloader = torch.utils.data.DataLoader(
105 | dataset,
106 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
107 | shuffle=False,
108 | num_workers=FLAGS.n_workers,
109 | drop_last=True
110 | )
111 |
112 | for batch in tqdm(dataloader, ncols=0):
113 | batch_shape = batch.shape[:-3]
114 | batch = batch.reshape(-1, *batch.shape[-3:])
115 | batch = batch.permute(0,3,1,2)
116 | batch = batch.to(device)
117 |
118 | _, tokens = net.encode(batch)
119 | tokens = tokens.reshape(*batch_shape, tokens.shape[-1])
120 | # batch x task x frame x token
121 | tokens = einops.rearrange(
122 | tokens.cpu().numpy().astype(np.int32), '(b s) t f d -> b s t f d',
123 | s=FLAGS.n_shots
124 | )
125 |
126 | image_mask_tokens = np.concatenate(
127 | (tokens[:, :, 0, :, :], tokens[:, :, 1, :, :]), axis=-2
128 | )
129 | image_depth_tokens = np.concatenate(
130 | (tokens[:, :, 0, :, :], tokens[:, :, 2, :, :]), axis=-2
131 | )
132 | for i in range(tokens.shape[0]):
133 | data = {'tokens': b64encode_tokens(image_mask_tokens[i])}
134 | fout.write(json.dumps(data) + '\n')
135 | data = {'tokens': b64encode_tokens(image_depth_tokens[i])}
136 | fout.write(json.dumps(data) + '\n')
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 | if __name__ == '__main__':
145 | mlxu.run(main)
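146 |
147 | # Usage notes (paths are illustrative): input_dir is expected to follow the CO3D layout
148 | # <input_dir>/<category>/<sequence>/{images,masks,depth_masks}. For every sample the
149 | # script writes two visual sentences: the image frames followed by the corresponding
150 | # mask frames, and the image frames followed by the corresponding depth-mask frames.
151 | #   python tokenize_co3d_muse.py --input_dir=/path/to/co3d --output_file=./co3d.jsonl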
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_colorization_dataset_muse.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from functools import partial
4 | from tempfile import NamedTemporaryFile
5 | import random
6 | import json
7 | from base64 import b64encode
8 | from tqdm import tqdm, trange
9 | import glob
10 | import numpy as np
11 | import mlxu
12 |
13 | import torch
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import flax
18 | import einops
19 |
20 | from PIL import Image, UnidentifiedImageError
21 |
22 |
23 | from muse import VQGANModel
24 | from utils import read_image_to_tensor, is_image, list_dir_with_full_path
25 |
26 |
27 | FLAGS, _ = mlxu.define_flags_with_default(
28 | input_image_dir='/datasets/ilsvrc_2024-01-04_1601/train',
29 | output_file='./lvm/tokenized_muse/i1k_colorization.jsonl',
30 | batch_size=1,
31 | n_shots=8,
32 | n_epochs=2,
33 | n_workers=8,
34 | patch_size=32,
35 | hole_mask_ratio=0.3,
36 | dtype='fp32',
37 |     layer=1,
38 | )
39 |
40 |
41 | class PairedImageDataset(torch.utils.data.Dataset):
42 |
43 | def __init__(self, images):
44 | self.images = images
45 |
46 | def __getitem__(self, index):
47 | try:
48 | original_image = read_image_to_tensor(self.images[index])
49 | except UnidentifiedImageError:
50 | original_image = np.zeros((256, 256, 3), dtype=np.float32)
51 |
52 | # make gray images
53 | grayscale = np.dot(original_image[..., :3], [0.2989, 0.5870, 0.1140])
54 | gray_image = np.stack((grayscale,) * 3, axis=-1)
55 | # Stack the grayscale image across the third dimension
56 | gray_image = np.array(gray_image, dtype=np.float32)
57 | return gray_image, original_image
58 |
59 | def __len__(self):
60 | return len(self.images)
61 |
62 |
63 | def main(argv):
64 | assert FLAGS.input_image_dir != ''
65 | assert FLAGS.output_file != ''
66 |
67 | # Load the pre-trained vq model from the hub
68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69 |
70 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
71 | net.eval()
72 |
73 |
74 | input_images = glob.glob('{}{}/*.png'.format(FLAGS.input_image_dir, '/*'*FLAGS.layer))
75 | input_images += glob.glob('{}{}/*.jpg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
76 | input_images += glob.glob('{}{}/*.jpeg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
77 | input_images += glob.glob('{}{}/*.JPEG'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
78 |
79 | dataset = PairedImageDataset(input_images)
80 | dataloader = torch.utils.data.DataLoader(
81 | dataset,
82 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
83 | shuffle=False,
84 | num_workers=FLAGS.n_workers,
85 | drop_last=True
86 | )
87 |
88 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
89 | print(total_images)
90 | with NamedTemporaryFile() as ntf:
91 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512))
92 | all_tokens[:] = 0
93 |
94 | index = 0
95 | # debug_count = 0
96 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0):
97 | # if debug_count < 5243:
98 | # debug_count += 1
99 | # continue
100 |
101 | _, input_token_batch = net.encode(input_image_batch.permute(0, 3, 1, 2).to(device))
102 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device))
103 |
104 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate(
105 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)],
106 | axis=1
107 | )
108 | index += input_image_batch.shape[0]
109 |
110 | with open(FLAGS.output_file, 'w') as fout:
111 | for _ in trange(FLAGS.n_epochs, ncols=0):
112 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots)
113 | for i in trange(indices.shape[0], ncols=0):
114 | tokens = all_tokens[indices[i], :].reshape(-1)
115 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),}
116 | fout.write(json.dumps(data) + '\n')
117 |
118 |
119 | if __name__ == '__main__':
120 | mlxu.run(main)
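121 |
122 | # Usage (paths are illustrative; only colored images are required, the grayscale inputs
123 | # are generated on the fly in PairedImageDataset):
124 | #   python tokenize_colorization_dataset_muse.py --input_image_dir=/path/to/images \
125 | #     --output_file=./i1k_colorization.jsonl --layer=1
126 | # The layer flag sets how many directory levels sit between input_image_dir and the image
127 | # files; each jsonl line concatenates n_shots (gray, color) pairs, i.e. n_shots * 512 tokens.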
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_inpainting_dataset_muse.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from functools import partial
4 | from tempfile import NamedTemporaryFile
5 | import random
6 | import json
7 | from base64 import b64encode
8 | from tqdm import tqdm, trange
9 | import glob
10 | import numpy as np
11 | import mlxu
12 |
13 | import torch
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import flax
18 | import einops
19 |
20 | from PIL import Image, UnidentifiedImageError
21 |
22 |
23 | from muse import VQGANModel
24 | from utils import read_image_to_tensor, is_image, list_dir_with_full_path
25 |
26 |
27 | FLAGS, _ = mlxu.define_flags_with_default(
28 | input_image_dir='/datasets/imagenet_22k_2024-01-04_1601',
29 | output_file='./lvm/tokenized_muse/i22k_inpainting.jsonl',
30 | batch_size=1,
31 | n_shots=8,
32 | n_epochs=5,
33 | n_workers=8,
34 | patch_size=32,
35 | hole_mask_ratio=0.3,
36 | dtype='fp32',
37 |     layer=1,
38 | )
39 |
40 |
41 | class PairedImageDataset(torch.utils.data.Dataset):
42 |
43 | def __init__(self, images):
44 | self.images = images
45 |
46 | def __getitem__(self, index):
47 | try:
48 | original_image = read_image_to_tensor(self.images[index])
49 | except UnidentifiedImageError:
50 | original_image = np.zeros((256, 256, 3), dtype=np.float32)
51 |
52 | if np.random.random() < FLAGS.hole_mask_ratio:
53 | h, w, _ = original_image.shape
54 | masked_image = original_image.copy()
55 | min_dim = min(h, w)
56 | size = np.random.randint(int(0.1 * min_dim), int(0.5 * min_dim))
57 | top_left_y = np.random.randint(0, h - size + 1)
58 | top_left_x = np.random.randint(0, w - size + 1)
59 | masked_image[top_left_y:top_left_y+size, top_left_x:top_left_x+size, :] = 0
60 | else:
61 | patches = einops.rearrange(
62 | original_image,
63 | '(h p1) (w p2) c -> (h w) (p1 p2 c)',
64 | p1=FLAGS.patch_size, p2=FLAGS.patch_size
65 | )
66 | mask_ratio = np.random.random()
67 | mask = np.random.random((patches.shape[0], 1)) < mask_ratio
68 | masked_patches = patches * mask
69 | hw = 256 // FLAGS.patch_size
70 | masked_image = einops.rearrange(
71 | masked_patches,
72 | '(h w) (p1 p2 c) -> (h p1) (w p2) c',
73 | p1=FLAGS.patch_size, p2=FLAGS.patch_size,
74 | h=hw, w=hw
75 | )
76 | return masked_image, original_image
77 |
78 | def __len__(self):
79 | return len(self.images)
80 |
81 |
82 | def main(argv):
83 | assert FLAGS.input_image_dir != ''
84 | assert FLAGS.output_file != ''
85 |
86 | # Load the pre-trained vq model from the hub
87 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88 |
89 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
90 | net.eval()
91 |
92 |
93 | # input_images = glob.glob('{}{}/*.png'.format(FLAGS.input_image_dir, '/*'*FLAGS.layer))
94 | # input_images += glob.glob('{}{}/*.jpg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
95 | # input_images += glob.glob('{}{}/*.jpeg'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
96 | input_images = glob.glob('{}{}/*.JPEG'.format(FLAGS.input_image_dir, '/*' * FLAGS.layer))
97 |
98 |
99 | # input_images = input_images[:100]
100 | dataset = PairedImageDataset(input_images)
101 | dataloader = torch.utils.data.DataLoader(
102 | dataset,
103 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
104 | shuffle=False,
105 | num_workers=FLAGS.n_workers,
106 | drop_last=True
107 | )
108 |
109 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
110 | print(total_images)
111 | with NamedTemporaryFile() as ntf:
112 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512))
113 | all_tokens[:] = 0
114 |
115 | index = 0
116 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0):
117 | _, input_token_batch = net.encode(input_image_batch.permute(0, 3, 1, 2).to(device))
118 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device))
119 |
120 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate(
121 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)],
122 | axis=1
123 | )
124 | index += input_image_batch.shape[0]
125 |
126 | with open(FLAGS.output_file, 'w') as fout:
127 | for _ in trange(FLAGS.n_epochs, ncols=0):
128 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots)
129 | for i in trange(indices.shape[0], ncols=0):
130 | tokens = all_tokens[indices[i], :].reshape(-1)
131 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),}
132 | fout.write(json.dumps(data) + '\n')
133 |
134 |
135 | if __name__ == '__main__':
136 | mlxu.run(main)
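137 |
138 | # Usage (paths are illustrative):
139 | #   python tokenize_inpainting_dataset_muse.py --input_image_dir=/path/to/imagenet22k \
140 | #     --output_file=./i22k_inpainting.jsonl --hole_mask_ratio=0.3
141 | # With probability hole_mask_ratio a random square hole is zeroed out; otherwise each
142 | # patch_size x patch_size patch is kept with a random probability and the rest are zeroed.
143 | # Each jsonl line concatenates n_shots (masked, original) pairs, i.e. n_shots * 512 tokens.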
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_multi_datasets_muse.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from functools import partial
4 | from tempfile import NamedTemporaryFile
5 | import random
6 | import json
7 | from base64 import b64encode
8 | from tqdm import tqdm, trange
9 | import glob
10 | import numpy as np
11 | import mlxu
12 |
13 | import torch
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import flax
18 | import einops
19 |
20 | from PIL import Image, UnidentifiedImageError
21 | from muse import VQGANModel
22 | from utils import match_mulitple_path, read_image_to_tensor
23 |
24 |
25 | # FLAGS, _ = mlxu.define_flags_with_default(
26 | # input_dir='./lvm/dataset/prismer_i1k/new',
27 | # input_regex='/datasets/ilsvrc_2024-01-04_1601/train::./lvm/dataset/prismer_i1k/new/depth/train::./lvm/dataset/prismer_i1k/new/edge/train::./lvm/dataset/prismer_i1k/new/normal/train::./lvm/dataset/prismer_i1k/new_mapped_color/seg_coco_colored/train',
28 | # output_file='./lvm/other_folder/old/i1k_cot_uni-mapped_occupy_gpu.jsonl',
29 | # shuffle_tasks=True,
30 | # crop=False,
31 | # batch_size=16,
32 | # max_examples=0,
33 | # n_shots=3,
34 | # n_epochs=5,
35 | # n_workers=8,
36 | # dtype='fp32',
37 | # layer=2,
38 | # )
39 |
40 | FLAGS, _ = mlxu.define_flags_with_default(
41 | input_dir='./lvm/dataset/prismer_i1k/new',
42 | input_regex='/datasets/coco2017_2024-01-04_1601/train2017::/shared/yutongbai/labels/normal/helpers/images::/shared/yutongbai/labels/edge/helpers/images::/shared/yutongbai/labels/depth/helpers/images_coco::./lvm/dataset/prismer_coco/seg_coco_colored_mapped_color',
43 | output_file='./lvm/tokenized_muse/coco_mixed_uni-mapped.jsonl',
44 | shuffle_tasks=True,
45 | crop=False,
46 | batch_size=8,
47 | max_examples=0,
48 | n_shots=3,
49 | n_epochs=5,
50 | n_workers=8,
51 | dtype='fp32',
52 | layer=1,
53 | )
54 |
55 | # FLAGS, _ = mlxu.define_flags_with_default(
56 | # input_dir='./lvm/dataset/prismer_i1k/new',
57 | # input_regex='./lvm/dataset/kitti-cot_crop/image::./lvm/dataset/kitti-cot_crop/depth::./lvm/dataset/kitti-cot_crop/next_frame::./lvm/dataset/kitti-cot_crop/scene_flow::./lvm/dataset/kitti-cot_crop/sementic_seg::./lvm/dataset/kitti-cot_crop/sementic_seg_rbg::./lvm/dataset/kitti-cot_crop/stereo',
58 | # output_file='./lvm/tokenized_muse/kitti_new.jsonl',
59 | # shuffle_tasks=True,
60 | # crop=False,
61 | # batch_size=16,
62 | # max_examples=0,
63 | # n_shots=2,
64 | # n_epochs=5,
65 | # n_workers=8,
66 | # dtype='fp32',
67 | # layer=1,
68 | # )
69 |
70 |
71 | class MultipleImageDataset(torch.utils.data.Dataset):
72 |
73 | def __init__(self, input_images):
74 | self.input_images = input_images
75 |
76 | def __getitem__(self, index):
77 | try:
78 | if FLAGS.crop:
79 | crop_rng = np.random.default_rng(np.random.randint(0, 2 ** 32))
80 | else:
81 | crop_rng = None
82 | return tuple(
83 | read_image_to_tensor(x, crop=FLAGS.crop, crop_rng=deepcopy(crop_rng))
84 | for x in self.input_images[index]
85 | )
86 | except UnidentifiedImageError as e:
87 | print(f'Error: {e} for {self.input_images[index]}')
88 | return self[np.random.randint(0, len(self))]
89 |
90 | def __len__(self):
91 | return len(self.input_images)
92 |
93 | def match_mulitple_path_v2(root, regex):  # note: root is unused; images are located via the paths in regex
94 | groups = {}
95 | for modal in regex:
96 | images = glob.glob(modal + '{}.png'.format(FLAGS.layer*'/*'))
97 | images += glob.glob(modal + '{}.jpg'.format(FLAGS.layer*'/*'))
98 | images += glob.glob(modal + '{}.JPEG'.format(FLAGS.layer*'/*'))
99 | images += glob.glob(modal + '{}.jpeg'.format(FLAGS.layer*'/*'))
100 | for img in images:
101 | img_idx = img.split('/')[-1].split('.')[0]
102 | if not img_idx in groups:
103 | groups[img_idx] = []
104 | groups[img_idx].append(img)
105 |
106 |     groups = [groups[idx] for idx in groups if len(groups[idx]) == len(regex)]  # keep only ids present in every modality
107 |
108 | return groups
109 |
110 |
111 | def main(argv):
112 | assert FLAGS.input_dir != ''
113 | assert FLAGS.input_regex != ''
114 | assert FLAGS.output_file != ''
115 |
116 | # Load the pre-trained vq model from the hub
117 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
118 |
119 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
120 | net.eval()
121 |
122 | regex = FLAGS.input_regex.split('::')
123 | input_images = match_mulitple_path_v2(FLAGS.input_dir, regex)
124 |
125 | print(f'Found {len(input_images)} images')
126 | assert len(input_images) > 0, 'No images found'
127 |
128 | if FLAGS.max_examples > 0:
129 | input_images = input_images[:FLAGS.max_examples]
130 |
131 | random.shuffle(input_images)
132 |
133 | dataset = MultipleImageDataset(input_images)
134 | dataloader = torch.utils.data.DataLoader(
135 | dataset,
136 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
137 | shuffle=False,
138 | num_workers=FLAGS.n_workers,
139 | drop_last=True
140 | )
141 |
142 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
143 |
144 | with NamedTemporaryFile() as ntf:
145 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 256 * len(input_images[0])))
146 | all_tokens[:] = 0
147 |
148 | index = 0
149 | for batch in tqdm(dataloader, ncols=0):
150 | k = 0
151 | for image in batch:
152 | batch_size = image.shape[0]
153 | image = einops.rearrange(
154 | image.numpy(), 'b h w c -> b c h w'
155 | )
156 | image = torch.tensor(image).to(device)
157 | _, tokens = net.encode(image)
158 | tokens = einops.rearrange(
159 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size
160 | )
161 | all_tokens[index:index + image.shape[0], k:k + 256] = tokens
162 | k += 256
163 | index += batch[0].shape[0]
164 |
165 | with open(FLAGS.output_file, 'w') as fout:
166 | for _ in trange(FLAGS.n_epochs, ncols=0):
167 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots)
168 | for i in trange(indices.shape[0], ncols=0):
169 | tokens = deepcopy(all_tokens[indices[i], :])
170 | tokens = einops.rearrange(tokens, 'b (s t) -> b s t', t=256)
171 | if FLAGS.shuffle_tasks:
172 | permutations = np.random.permutation(tokens.shape[1])
173 | tokens = tokens[:, permutations, :]
174 | tokens = tokens.reshape(-1)
175 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),}
176 | fout.write(json.dumps(data) + '\n')
177 |
178 |
179 | if __name__ == '__main__':
180 | mlxu.run(main)
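181 |
182 | # Usage (paths are illustrative): input_regex is a '::'-separated list of modality roots
183 | # (e.g. rgb::depth::edge::normal::segmentation). Images are grouped by basename across
184 | # the modalities and only ids present in every modality are kept; with shuffle_tasks the
185 | # task order is randomly permuted inside each visual sentence before n_shots groups are
186 | # written per jsonl line.
187 | #   python tokenize_multi_datasets_muse.py --input_dir=/data --output_file=./mixed.jsonl \
188 | #     --input_regex=/data/rgb::/data/depth::/data/edge --n_shots=3 --layer=1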
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_multi_seq_images_muse.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | from functools import partial
5 | from tempfile import NamedTemporaryFile
6 | import random
7 | import json
8 | from base64 import b64encode
9 | from tqdm import tqdm, trange
10 |
11 | import numpy as np
12 | np.float = np.float64  # compatibility shim: some dependencies still use deprecated NumPy aliases
13 | np.int = np.int_
14 | import mlxu
15 | from natsort import natsorted
16 |
17 | import torch
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import flax
22 | import einops
23 |
24 | from PIL import Image
25 | from muse import VQGANModel
26 | from utils import (
27 | list_dir_with_full_path, is_image, read_image_to_tensor,
28 | randomly_subsample_frame_indices
29 | )
30 |
31 |
32 | FLAGS, _ = mlxu.define_flags_with_default(
33 | input_dirs='',
34 | output_file='',
35 | batch_size=1,
36 | n_frames=16,
37 | n_shots=2,
38 | n_epochs=1,
39 | n_workers=8,
40 | max_stride=4,
41 | dtype='fp32',
42 | )
43 |
44 |
45 | class MultiVideoDataset(torch.utils.data.Dataset):
46 |
47 | def __init__(self, videos, n_frames=8):
48 | self.videos = videos
49 | self.n_tasks = len(videos[0])
50 | self.n_frames = n_frames
51 |
52 | def __getitem__(self, index):
53 | n_frames = len([x for x in list_dir_with_full_path(self.videos[index][0]) if is_image(x)])
54 | for i in range(self.n_tasks):
55 | if len(
56 | [x for x in list_dir_with_full_path(self.videos[index][i])
57 | if is_image(x)]
58 | ) != n_frames:
59 | print('Inconsistent number of frames')
60 | return self[np.random.randint(0, len(self))]
61 | if n_frames < self.n_frames:
62 | print(n_frames)
63 | return self[np.random.randint(0, len(self))]
64 |
65 | indices = randomly_subsample_frame_indices(
66 | n_frames, self.n_frames, FLAGS.max_stride,
67 | random_start=True
68 | )
69 |
70 | all_frames = []
71 | for task_idx in range(self.n_tasks):
72 | frames = []
73 | all_files = [x for x in list_dir_with_full_path(self.videos[index][task_idx]) if is_image(x)]
74 | all_files = natsorted(all_files)
75 | for idx in indices:
76 | frames.append(read_image_to_tensor(all_files[idx]))
77 |
78 | all_frames.append(np.stack(frames, axis=0))
79 |
80 | return np.stack(all_frames, axis=0)
81 |
82 | def __len__(self):
83 | return len(self.videos)
84 |
85 |
86 | def main(argv):
87 | assert FLAGS.input_dirs != ''
88 | assert FLAGS.output_file != ''
89 |
90 | # Load the pre-trained vq model from the hub
91 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92 |
93 | net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device)
94 | net.eval()
95 |
96 | video_dirs = [sorted(glob.glob(x)) for x in FLAGS.input_dirs.split('::')]
97 | n_tasks = len(video_dirs)
98 |
99 | groups = {}
100 |
101 | for videos in [sorted(glob.glob(x)) for x in FLAGS.input_dirs.split('::')]:
102 | for video in videos:
103 | name = video.split('/')[-1]
104 | # name = video[:-len(video.split('/')[-1].split('_')[-1]) - 1]
105 | if name not in groups:
106 | groups[name] = []
107 | groups[name].append(video)
108 |
109 | video_dirs = []
110 | for name, videos in groups.items():
111 | if len(videos) == n_tasks:
112 | video_dirs.append(videos)
113 |
114 |
115 | with open(FLAGS.output_file, 'w') as fout:
116 | dataset = MultiVideoDataset(video_dirs, n_frames=FLAGS.n_frames)
117 |
118 | dataloader = torch.utils.data.DataLoader(
119 | dataset,
120 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
121 | shuffle=False,
122 | num_workers=FLAGS.n_workers,
123 | prefetch_factor=4,
124 | drop_last=True,
125 | )
126 | with torch.no_grad():
127 | for _ in trange(FLAGS.n_epochs, ncols=0):
128 |
129 | all_tokens = np.zeros(
130 | dtype='i4',
131 | shape=(len(dataloader) * FLAGS.batch_size * FLAGS.n_shots, n_tasks, FLAGS.n_frames, 256)
132 | )
133 | index = 0
134 |
135 | for batch in tqdm(dataloader, ncols=0):
136 | batch_size = batch.shape[0]
137 | batch = einops.rearrange(
138 | batch.numpy(), 'b t f h w c -> (b t f) c h w'
139 | )
140 | batch = torch.tensor(batch).to(device)
141 | _, tokens = net.encode(batch)
142 | tokens = einops.rearrange(
143 | tokens.cpu().numpy().astype(np.int32), '(b t f) d -> b t f d', b=batch_size, t=n_tasks, f=FLAGS.n_frames
144 | )
145 | all_tokens[index:index + batch_size, ...] = tokens
146 | index += batch_size
147 |
148 |
149 | random_indices = np.random.permutation(all_tokens.shape[0])
150 | all_tokens = all_tokens[random_indices, ...]
151 |
152 |
153 | tokens_sep = einops.rearrange(
154 | all_tokens, '(b x) t s d -> b (x t s d)',
155 | x=FLAGS.n_shots
156 | )
157 | tokens_interleave = einops.rearrange(
158 | all_tokens, '(b x) t s d -> b (x s t d)',
159 | x=FLAGS.n_shots
160 | )
161 |
162 | for i in range(tokens_sep.shape[0]):
163 | data = {'tokens': b64encode(tokens_sep[i].tobytes()).decode('utf-8')}
164 | fout.write(json.dumps(data) + '\n')
165 |
166 | data = {'tokens': b64encode(tokens_interleave[i].tobytes()).decode('utf-8')}
167 | fout.write(json.dumps(data) + '\n')
168 |
169 |
170 |
171 |
172 | if __name__ == '__main__':
173 | mlxu.run(main)
174 |
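175 | # Usage (paths are illustrative): input_dirs is a '::'-separated list of glob patterns,
176 | # one per task, whose matches are per-video frame directories sharing the same directory
177 | # name across tasks. For every group of n_shots videos two jsonl lines are written: one
178 | # with each shot's frames grouped task by task and one with the tasks interleaved frame
179 | # by frame.
180 | #   python tokenize_multi_seq_images_muse.py --output_file=./multi_seq.jsonl \
181 | #     --input_dirs='/data/rgb/*::/data/depth/*' --n_frames=16 --n_shots=2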
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_paired_dataset_muse.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from functools import partial
4 | from tempfile import NamedTemporaryFile
5 | import random
6 | import json
7 | from base64 import b64encode
8 | from tqdm import tqdm, trange
9 |
10 | import numpy as np
11 | import mlxu
12 |
13 | import torch
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import flax
18 | import einops
19 |
20 | from PIL import Image
21 | from muse import VQGANModel
22 | from utils import read_image_to_tensor
23 |
24 |
25 | FLAGS, _ = mlxu.define_flags_with_default(
26 | input_image_dir='./kitti-cot_crop/sementic_seg',
27 | output_image_dir='./kitti-cot_crop/sementic_seg',
28 | output_file='./test_kitti_semantic.jsonl',
29 | input_filter_key='',
30 | output_filter_key='',
31 | input_suffix='',
32 | output_suffix='',
33 | crop=False,
34 | batch_size=1,
35 | n_shots=8,
36 | n_epochs=5,
37 | n_workers=8,
38 | dtype='fp32',
39 | )
40 |
41 |
42 | class PairedImageDataset(torch.utils.data.Dataset):
43 |
44 | def __init__(self, input_images, output_images):
45 | self.input_images = input_images
46 | self.output_images = output_images
47 |
48 | def __getitem__(self, index):
49 | try:
50 | return (
51 | read_image_to_tensor(self.input_images[index], crop=FLAGS.crop),
52 | read_image_to_tensor(self.output_images[index], crop=FLAGS.crop)
53 | )
54 | except Exception as e:
55 | print(f'Error: {e} for {self.input_images[index]}')
56 | return self[np.random.randint(0, len(self))]
57 |
58 | def __len__(self):
59 | return len(self.input_images)
60 |
61 |
62 | def main(argv):
63 | assert FLAGS.input_image_dir != ''
64 | assert FLAGS.output_image_dir != ''
65 | assert FLAGS.output_file != ''
66 |
67 | # Load the pre-trained vq model from the hub
68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69 |
70 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
71 | net.eval()
72 |
73 |
74 | input_images = os.listdir(FLAGS.input_image_dir)
75 | input_images = [i for i in input_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')]
76 | input_images = [i for i in input_images if FLAGS.input_filter_key in i]
77 | input_images = sorted(input_images)
78 | output_images = os.listdir(FLAGS.output_image_dir)
79 | output_images = [i for i in output_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')]
80 | output_images = [i for i in output_images if FLAGS.output_filter_key in i]
81 | output_images = sorted(output_images)
82 |
83 | assert len(input_images) == len(output_images)
84 |
85 |
86 | input_images = [
87 | os.path.join(FLAGS.input_image_dir, s)
88 | for s in input_images
89 | ]
90 | output_images = [
91 | os.path.join(FLAGS.output_image_dir, s)
92 | for s in output_images
93 | ]
94 |
95 | dataset = PairedImageDataset(input_images, output_images)
96 | dataloader = torch.utils.data.DataLoader(
97 | dataset,
98 | batch_size=FLAGS.batch_size * FLAGS.n_shots,
99 | shuffle=False,
100 | num_workers=FLAGS.n_workers,
101 | drop_last=True
102 | )
103 |
104 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
105 |
106 | with torch.no_grad():
107 | with NamedTemporaryFile() as ntf:
108 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512))
109 | all_tokens[:] = 0
110 |
111 | index = 0
112 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0):
113 | _, input_token_batch = net.encode(input_image_batch.permute(0,3,1,2).to(device))
114 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device))
115 |
116 |
117 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate(
118 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)],
119 | axis=1
120 | )
121 | index += input_image_batch.shape[0]
122 |
123 | with open(FLAGS.output_file, 'w') as fout:
124 | for _ in trange(FLAGS.n_epochs, ncols=0):
125 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots)
126 | for i in trange(indices.shape[0], ncols=0):
127 | tokens = all_tokens[indices[i], :].reshape(-1)
128 | data = {'tokens': b64encode(tokens.tobytes()).decode('utf-8'),}
129 | fout.write(json.dumps(data) + '\n')
130 |
131 |
132 | if __name__ == '__main__':
133 | mlxu.run(main)
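134 |
135 | # Usage (paths are illustrative): input and output images are matched by sorted filename,
136 | # so the two directories must contain the same number of correspondingly named files.
137 | # Each jsonl line concatenates n_shots (input, output) pairs, i.e. n_shots * 512 tokens.
138 | #   python tokenize_paired_dataset_muse.py --input_image_dir=/data/images \
139 | #     --output_image_dir=/data/labels --output_file=./paired.jsonl --n_shots=8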
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_seq_images_muse.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | from functools import partial
5 | from tempfile import NamedTemporaryFile
6 | import random
7 | import json
8 | from base64 import b64encode
9 | from tqdm import tqdm, trange
10 | from muse import VQGANModel
11 |
12 | import numpy as np
13 | np.float = np.float64  # compatibility shim: some dependencies still use deprecated NumPy aliases
14 | np.int = np.int_
15 | import mlxu
16 |
17 | import torch
18 |
19 |
20 | import einops
21 |
22 | from PIL import Image
23 |
24 | from utils import (
25 | list_dir_with_full_path, is_image, read_image_to_tensor,
26 | randomly_subsample_frame_indices
27 | )
28 |
29 |
30 |
31 |
32 | FLAGS, _ = mlxu.define_flags_with_default(
33 | input_dirs='',
34 | output_file='',
35 | batch_size=1,
36 | n_frames=16,
37 | n_epochs=1,
38 | n_workers=8,
39 | max_stride=4,
40 | dtype='fp32',
41 | )
42 |
43 |
44 | class VideoDataset(torch.utils.data.Dataset):
45 |
46 | def __init__(self, videos, n_frames=8):
47 | self.videos = videos
48 | self.n_frames = n_frames
49 |
50 | def __getitem__(self, index):
51 | frames = []
52 | for file in sorted(list_dir_with_full_path(self.videos[index])):
53 | if is_image(file):
54 | frames.append(read_image_to_tensor(file))
55 | if len(frames) < self.n_frames:
56 | return self[np.random.randint(0, len(self))]
57 | indices = randomly_subsample_frame_indices(
58 | len(frames), self.n_frames, FLAGS.max_stride,
59 | random_start=True
60 | )
61 | frames = np.stack([frames[i] for i in indices], axis=0)
62 | return frames
63 |
64 | def __len__(self):
65 | return len(self.videos)
66 |
67 |
68 | def main(argv):
69 | assert FLAGS.input_dirs != ''
70 | assert FLAGS.output_file != ''
71 |
72 | # Load the pre-trained vq model from the hub
73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
74 |
75 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
76 | net.eval()
77 |
78 | # videos = list_dir_with_full_path(FLAGS.input_dir)
79 | videos = glob.glob(FLAGS.input_dirs)
80 |
81 | with torch.no_grad():
82 | with open(FLAGS.output_file, 'w') as fout:
83 | dataset = VideoDataset(videos, n_frames=FLAGS.n_frames)
84 | dataloader = torch.utils.data.DataLoader(
85 | dataset,
86 | batch_size=FLAGS.batch_size,
87 | shuffle=False,
88 | num_workers=FLAGS.n_workers,
89 | prefetch_factor=4,
90 | drop_last=True,
91 | )
92 | for _ in range(FLAGS.n_epochs):
93 | for batch in tqdm(dataloader, ncols=0):
94 | batch_size = batch.shape[0]
95 | batch = einops.rearrange(
96 | batch.numpy(), 'b t h w c -> (b t) c h w'
97 | )
98 | batch = torch.tensor(batch).to(device)
99 | _, tokens = net.encode(batch)
100 | tokens = einops.rearrange(
101 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size
102 | )
103 | for i in range(batch_size):
104 | data = {'tokens': b64encode(tokens[i].tobytes()).decode('utf-8')}
105 | fout.write(json.dumps(data) + '\n')
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | mlxu.run(main)
111 |
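112 | # Usage (paths are illustrative): input_dirs is a glob pattern whose matches are
113 | # directories of frame images. n_frames frames are subsampled from each directory with a
114 | # random stride of at most max_stride, giving n_frames * 256 tokens per jsonl line.
115 | #   python tokenize_seq_images_muse.py --input_dirs='/data/frames/*' \
116 | #     --output_file=./seq.jsonl --n_frames=16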
--------------------------------------------------------------------------------
/tokenize_examples/tokenize_video_muse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from base64 import b64encode
4 |
5 | import numpy as np
6 | np.float = np.float64  # compatibility shim: some dependencies still use deprecated NumPy aliases
7 | np.int = np.int_
8 |
9 | import einops
10 | import torch
11 | from torch.utils.data import Dataset, DataLoader
12 | from PIL import Image
13 | from tqdm import tqdm, trange
14 | import mlxu
15 |
16 | from muse import VQGANModel
17 | from utils import read_frames_from_video, is_video
18 |
19 |
20 |
21 |
22 | FLAGS, _ = mlxu.define_flags_with_default(
23 | input_dir='DAVIS/JPEGImages/480p',
24 | output_file='vqlm/muse/running_script/tokenized_muse/davis.jsonl',
25 | batch_size=32,
26 | n_frames=16,
27 | n_workers=32,
28 | strides='8',
29 | n_epochs=1,
30 | dtype='fp32',
31 | )
32 |
33 |
34 | class VideoDataset(torch.utils.data.Dataset):
35 |
36 | def __init__(self, videos, n_frames=8, stride=1):
37 | self.videos = videos
38 | self.n_frames = n_frames
39 | self.stride = stride
40 |
41 | def __getitem__(self, index):
42 | frames = read_frames_from_video(self.videos[index], self.n_frames, self.stride)
43 | if frames is None:
44 | return self[np.random.randint(0, len(self))]
45 | return frames
46 |
47 | def __len__(self):
48 | return len(self.videos)
49 |
50 |
51 | def main(argv):
52 | assert FLAGS.input_dir != ''
53 | assert FLAGS.output_file != ''
54 |
55 | # Load the pre-trained vq model from the hub
56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57 |
58 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
59 | net.eval()
60 |
61 | videos = []
62 | for root, _, files in os.walk(FLAGS.input_dir):
63 | for file in files:
64 | if is_video(file):
65 | videos.append(os.path.join(root, file))
66 |
67 | with open(FLAGS.output_file, 'w') as fout:
68 | with torch.no_grad():
69 | for epoch in trange(FLAGS.n_epochs, ncols=0):
70 | for stride in tqdm(FLAGS.strides.split(','), ncols=0):
71 | stride = int(stride)
72 | dataset = VideoDataset(videos, n_frames=FLAGS.n_frames, stride=stride)
73 | dataloader = torch.utils.data.DataLoader(
74 | dataset,
75 | batch_size=FLAGS.batch_size,
76 | shuffle=False,
77 | num_workers=FLAGS.n_workers,
78 | prefetch_factor=4,
79 | drop_last=True,
80 | )
81 | for batch in tqdm(dataloader, ncols=0):
82 | batch_size = batch.shape[0]
83 | batch = einops.rearrange(
84 | batch.numpy(), 'b t h w c -> (b t) c h w'
85 | )
86 | batch = torch.tensor(batch).to(device)
87 | _, tokens = net.encode(batch)
88 | tokens = einops.rearrange(
89 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size
90 | )
91 | for i in range(batch_size):
92 | data = {'tokens': b64encode(tokens[i].tobytes()).decode('utf-8'),}
93 | fout.write(json.dumps(data) + '\n')
94 |
95 |
96 |
97 | if __name__ == '__main__':
98 | mlxu.run(main)
99 |
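100 | # Usage (paths are illustrative): input_dir is walked recursively for video files, and
101 | # n_frames frames are read from each video at every stride in the comma-separated strides
102 | # flag; each video/stride combination becomes one jsonl line of n_frames * 256 tokens.
103 | #   python tokenize_video_muse.py --input_dir=/data/videos --output_file=./videos.jsonl \
104 | #     --n_frames=16 --strides=4,8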
--------------------------------------------------------------------------------