12 |
13 | Additionally, small batch sizes are much more robust to hyperparameter misspecification, meaning that when the tuning budget is limited, small batch sizes perform better in expectation.
14 |
15 |
16 |
17 | We hope that our results can be useful for memory-constrained practitioners, since small batch sizes allow the use of simple optimizers. For example, instead of using LoRA for fine-tuning, it might be preferable to do full fine-tuning with a small batch size and a memory-efficient optimizer like Adafactor, matching the performance of Adam while maintaining a similar memory footprint to LoRA.
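For illustration, here is a minimal sketch of this setup using `optax.adafactor`; the learning rate is a placeholder, and `params`, `loss_fn`, and `microbatch` stand in for your own model and data (this is not the exact configuration used in the paper):

```python
import jax
import optax

# Adafactor keeps factored second-moment statistics instead of Adam's two
# full-sized moment buffers, so optimizer state stays small.
tx = optax.adafactor(
    learning_rate=1e-5,       # placeholder value; tune for your model
    momentum=None,            # no first moment -> no extra full-sized state
    clipping_threshold=1.0,
)
opt_state = tx.init(params)

# one small-batch step (no gradient accumulation)
grads = jax.grad(loss_fn)(params, microbatch)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```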
18 |
19 |
20 |
21 | ## Code structure
22 |
23 | We implemented all of our experiments in JAX from scratch, using a mix of data, tensor, and sequence parallelism. We used two independent codebases for [pretraining](pretraining) and [fine-tuning](finetuning). Please refer to either codebase for more details on running experiments.
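The parallelism setup follows the pattern in `pretraining/train.py`: devices are arranged in a 2D mesh with a `data` axis (data parallelism / FSDP) and a `model` axis (tensor parallelism), roughly:

```python
import jax

num_tp_devices = 1  # tensor-parallel degree; set per experiment
num_fsdp_devices = jax.device_count() // num_tp_devices
mesh = jax.make_mesh((num_fsdp_devices, num_tp_devices), ('data', 'model'))
jax.set_mesh(mesh)  # parameters and activations are then sharded over these axes
```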
24 |
25 | All of our visualizations were done using Jupyter Notebooks found in the [utils](utils) directory.
26 |
27 | ## Citation
28 |
29 | ```bibtex
30 | @misc{smallbatch,
31 | title={Small Batch Size Training for Language Models: When Vanilla SGD Works, and Why Gradient Accumulation Is Wasteful},
32 | author={Martin Marek and Sanae Lotfi and Aditya Somasundaram and Andrew Gordon Wilson and Micah Goldblum},
33 | year={2025},
34 | eprint={2507.07101},
35 | archivePrefix={arXiv},
36 | primaryClass={cs.LG}
37 | }
38 | ```
39 |
--------------------------------------------------------------------------------
/pretraining/download_fineweb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import fire
3 | import numpy as np
4 | from pathlib import Path
5 | from tqdm.auto import tqdm
6 | from huggingface_hub import hf_hub_download
7 | from huggingface_hub.utils import disable_progress_bars; disable_progress_bars()
8 | from typing import Optional, Literal
9 |
10 |
11 | def load_data_shard(file):
12 | # https://github.com/KellerJordan/modded-nanogpt/blob/a202a3a0ca99d69bb7f847e5337c7c6e0890fd92/train_gpt.py#L411
13 | header = np.fromfile(file, dtype=np.int32, count=256) # header is 256 int32
14 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
15 | assert header[1] == 1, "unsupported version"
16 | num_tokens = int(header[2]) # number of tokens (claimed)
17 | with Path(file).open("rb", buffering=0) as f:
18 | tokens = np.empty(num_tokens, dtype=np.uint16) # avoid pin_memory copy by @YouJiacheng
19 | f.seek(256 * 4)
20 | nbytes = f.readinto(tokens) # avoid bytes->array copy by @YouJiacheng
21 | assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
22 | return tokens
23 |
24 |
25 | def download_dataset(
26 | dataset: Literal['fineweb', 'finewebedu'] = 'fineweb',
27 | num_chunks: Optional[int] = None,
28 | ):
29 | """download dataset, save it as a np.memmap binary file"""
30 |
31 | # get num. chunks
32 |     # by default, download all chunks (10B tokens)
33 | # each chunk is 100M tokens
34 | if num_chunks is None:
35 | if dataset == 'fineweb': num_chunks = 103
36 | if dataset == 'finewebedu': num_chunks = 99
37 |
38 | # load chunks into memory
39 | print('downloading...')
40 | shards = []
41 | for i in tqdm(range(1, num_chunks+1)):
42 | shard_path = hf_hub_download(repo_id=f'kjj0/{dataset}10B-gpt2', filename=f'{dataset}_train_{i:06}.bin', repo_type="dataset")
43 | shards += [load_data_shard(shard_path)]
44 |
45 | # save to disk
46 | print('saving...')
47 | out_dir = os.path.expanduser('~/datasets')
48 | out_path = f'{out_dir}/{dataset}_gpt2.bin'
49 | os.makedirs(out_dir, exist_ok=True)
50 | n_tokens = sum(map(len, shards))
51 | out = np.memmap(out_path, dtype=np.uint16, mode='w+', shape=[n_tokens])
52 | i = 0
53 | for shard in tqdm(shards):
54 | out[i:i+len(shard)] = shard
55 | i += len(shard)
56 | out.flush()
57 |
58 |
59 | if __name__ == '__main__':
60 | fire.Fire(download_dataset)
61 |
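# Illustrative invocation (added for clarity; the flags map to the arguments of
# `download_dataset` via python-fire):
#   python download_fineweb.py --dataset=fineweb --num_chunks=10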
--------------------------------------------------------------------------------
/pretraining/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from flax import nnx
4 | from collections.abc import Mapping
5 |
6 |
7 | def flatten_dict(d, prefix=None, sep='.'):
8 | if isinstance(d, Mapping):
9 | out = {}
10 | for k, v in d.items():
11 | nested_prefix = k if prefix is None else f'{prefix}{sep}{k}'
12 | out |= flatten_dict(v, nested_prefix, sep)
13 | return out
14 | else:
15 | return {prefix: d}
16 |
17 |
18 | def get_num_model_params(model: nnx.Module):
19 | graphdef, params = nnx.split(model, nnx.Param)
20 | n_params = jax.tree.reduce(lambda x, y: x + jnp.size(y), params, 0)
21 | return n_params
22 |
23 |
24 | def halflife_to_decay(t_token, n_batch=1):
25 | """
26 | notation:
27 | - t_token: halflife measured in number of tokens
28 | - t_steps: halflife measured in number of steps
29 | - n_batch: number of tokens per batch
30 | - d: decay coefficient
31 | """
32 | t_steps = t_token / n_batch # halflife (measured in number of steps)
33 | d = (1/2)**(1/t_steps)
34 | return d
35 |
36 |
37 | def decay_to_halflife(d, n_batch=1):
38 | """
39 | notation:
40 | - t_token: halflife measured in number of tokens
41 | - t_steps: halflife measured in number of steps
42 | - n_batch: number of tokens per batch
43 | - d: decay coefficient
44 | """
45 | # note: d**t_steps = 1/2
46 | t_steps = jnp.log(1/2) / jnp.log(d)
47 | t_token = t_steps * n_batch
48 | return t_token
49 |
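# Worked example (added for clarity, not part of the original file):
# with a halflife of 10M tokens and 500k tokens per batch, t_steps = 20 and
# halflife_to_decay(10e6, 500e3) = 0.5 ** (1/20) ≈ 0.966;
# decay_to_halflife(0.966, 500e3) recovers roughly 10M tokens.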
50 |
51 | @jax.jit
52 | def to_bf16_stochastic(key, source):
53 | """
54 | performs (float32 -> bfloat16) stochastic rounding
55 | based on https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905
56 | """
57 |     # ensure the source array is float32, since the bitwise logic below depends on it
58 | source = source.astype(jnp.float32)
59 |
60 |     # reinterpret float32 source as uint32 to allow bitwise operations
61 | source_uint32 = jax.lax.bitcast_convert_type(source, jnp.uint32)
62 |
63 | # randomly flip lower 16 bits of the float32 source
64 | # these are the bits that get truncated when converting to bf16
65 | random_int = jax.random.randint(
66 | key,
67 | shape=source.shape,
68 | minval=0,
69 | maxval=(1 << 16),
70 | dtype=jnp.uint32
71 | )
72 | result_uint32 = source_uint32 + random_int
73 |
74 | # mask off lower 16 bits, keep top 16 bits (corresponding to bf16 format)
75 | mask = jnp.uint32(0xFFFF0000)
76 | result_uint32 = jax.lax.bitwise_and(result_uint32, mask)
77 |
78 | # cast result to bf16
79 | result_fp32 = jax.lax.bitcast_convert_type(result_uint32, jnp.float32)
80 | result_bf16 = result_fp32.astype(jnp.bfloat16)
81 |
82 | return result_bf16
83 |
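# Illustrative check (added for clarity, not part of the original file):
#   key = jax.random.key(0)
#   x = jnp.full([10_000], 1.0001, dtype=jnp.float32)   # lies between two bf16 values
#   to_bf16_stochastic(key, x).astype(jnp.float32).mean()
# averages close to 1.0001, whereas x.astype(jnp.bfloat16) deterministically rounds
# every element down to 1.0 (the bf16 spacing near 1.0 is 2**-7 ≈ 0.0078).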
--------------------------------------------------------------------------------
/finetuning/optimizer.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import optax
4 | from optax import tree_utils as otu
5 | from flax import nnx
6 | import factorized, utils
7 | from typing import Optional
8 |
9 |
10 | class ModelAndOptimizer(nnx.Optimizer):
11 | """
12 | Extends nnx.ModelAndOptimizer (v0.12.0) with stochastic rounding.
13 | """
14 | def __init__(self, model, tx, wrt=nnx.Param, stochastic_round=False):
15 | super().__init__(model, tx, wrt=wrt)
16 | self.model = model
17 | self.stochastic_round = stochastic_round # <- CHANGED: added stochastic_round support
18 |
19 | def update(self, key, grads, **kwargs):
20 | param_arrays = nnx.to_arrays(nnx.pure(nnx.state(self.model, self.wrt)))
21 | grad_arrays = nnx.to_arrays(nnx.pure(nnx.state(grads)))
22 | opt_state_arrays = nnx.to_arrays(nnx.pure(self.opt_state))
23 | kwargs_arrays = nnx.to_arrays(nnx.pure(kwargs))
24 |
25 | updates, new_opt_state = self.tx.update(grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays)
26 | new_params = apply_updates(key, param_arrays, updates, self.stochastic_round) # <- CHANGED: added stochastic_round support
27 |
28 | nnx.update(self.model, new_params)
29 | nnx.update(self.opt_state, nnx.state(new_opt_state))
30 | self.step[...] += 1
31 |
32 |
33 | def apply_updates(
34 | key: jax.Array,
35 | params: optax.Params,
36 | updates: optax.Updates,
37 | stochastic_round = False
38 | ) -> optax.Params:
39 | """Extends optax.apply_updates with stochastic rounding."""
40 | keys = otu.tree_split_key_like(key, params)
41 | def leaf_update(p, u, key):
42 | if p is None: return None
43 | param_dtype = jnp.asarray(p).dtype
44 | if stochastic_round:
45 | p = p.astype(jnp.float32) + u
46 | p = utils.to_bf16_stochastic(key, p)
47 | else:
48 | p += u
49 | return p.astype(param_dtype)
50 | return jax.tree.map(leaf_update, params, updates, keys, is_leaf=lambda x: x is None)
51 |
52 |
53 | def adafactor(
54 | learning_rate: optax.ScalarOrSchedule,
55 | decay_rate: float = 0.8,
56 | clipping_threshold: Optional[float] = 1.0,
57 | min_dim_size_to_factor: int = 128,
58 | ) -> optax.GradientTransformation:
59 | """
60 | Adafactor reimplemented to use float32 state, regardless of param dtype.
61 | https://github.com/google-deepmind/optax/blob/8973bb3c77b07850737246815f1c028b53fffbe0/optax/_src/alias.py#L225#L327
62 | """
63 | return optax.chain(
64 | factorized.scale_by_factored_rms(decay_rate=decay_rate, min_dim_size_to_factor=min_dim_size_to_factor),
65 | optax.clip_by_block_rms(clipping_threshold) if clipping_threshold is not None else optax.identity(),
66 | optax.scale_by_learning_rate(learning_rate),
67 | optax.scale_by_param_block_rms(),
68 | )
69 |
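# Illustrative usage (added for clarity; `model`, `grads`, `key`, and the learning
# rate are placeholders):
#   tx = adafactor(learning_rate=1e-5)
#   optimizer = ModelAndOptimizer(model, tx, stochastic_round=True)
#   key, key_opt = jax.random.split(key)
#   optimizer.update(key_opt, grads)   # bf16 params updated with stochastic rounding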
--------------------------------------------------------------------------------
/finetuning/sampler.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax
4 | from flax import nnx
5 | from functools import partial
6 |
7 |
8 | @flax.struct.dataclass
9 | class SamplingState:
10 | key: jax.Array
11 | step: jnp.int32
12 | tokens: jnp.ndarray # [B, T]
13 | kv_cache: dict
14 | done: jnp.ndarray # [B]
15 |
16 |
17 | def _sample_top_p(key, probs, p=0.95):
18 | """Sample a token using top-p sampling.
19 | https://github.com/google/flax/blob/cca78723892c539b42c261d2625168d39b61c495/examples/gemma/sampler.py#L38"""
20 | probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
21 | cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
22 | mask = cumsum_probs - probs_sorted > p
23 | probs_sorted = jnp.where(mask, 0.0, probs_sorted)
24 | probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)
25 | next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))
26 | next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
27 | next_token = jnp.squeeze(next_token, axis=-1)
28 | return next_token
29 |
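# Illustrative example (added for clarity): with sorted probs [0.5, 0.3, 0.15, 0.05]
# and p = 0.7 (the default is 0.95), cumsum_probs - probs_sorted = [0.0, 0.5, 0.8, 0.95],
# so the mask zeroes the last two entries; the remaining [0.5, 0.3] renormalize to
# [0.625, 0.375] before jax.random.categorical draws the token.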
30 |
31 | def _sample_step(state, model_graphdef, model_state, pad_id, eos_id, temperature=1):
32 | model = nnx.merge(model_graphdef, model_state)
33 |
34 | # sample next token
35 | key, key_sampling = jax.random.split(state.key)
36 | input_token = state.tokens[:, state.step, None] # [B, 1]
37 | logits, kv_cache = model(input_token, state.kv_cache) # [B, 1, V]
38 | if temperature == 0:
39 | sampled_token = logits[:, 0, :].argmax(1) # [B]
40 | else:
41 | probs = jax.nn.softmax(logits[:, 0, :] / temperature, axis=-1) # [B, V]
42 | sampled_token = _sample_top_p(key_sampling, probs)
43 |
44 | # update buffer
45 | next_token = state.tokens[:, state.step+1]
46 | update_token = jnp.where((~state.done) & (next_token==pad_id), sampled_token, next_token)
47 | tokens = state.tokens.at[:, state.step+1].set(update_token)
48 |
49 | # check if sampling is done
50 | done = state.done | ((next_token==pad_id) & (sampled_token==eos_id))
51 |
52 | return SamplingState(key, state.step+1, tokens, kv_cache, done)
53 |
54 |
55 | @partial(jax.jit, static_argnames=('model_graphdef', 'temperature'))
56 | def sample(key, model_graphdef, model_state, tokens, temperature=1, pad_id=0, eos_id=1):
57 | model = nnx.merge(model_graphdef, model_state)
58 | B, T = tokens.shape
59 |
60 | # initialize state
61 | state = SamplingState(
62 | key=key,
63 | step=0,
64 | tokens=tokens,
65 | kv_cache=model.init_kv_cache(B, T),
66 | done=jnp.zeros([B], dtype=jnp.bool_),
67 | )
68 |
69 | # sample next token inside a while loop
70 | step_fn = lambda state: _sample_step(state, *nnx.split(model), pad_id, eos_id, temperature)
71 | cond_fn = lambda state: (state.step < T) & jnp.any(~state.done)
72 | state = jax.lax.while_loop(cond_fn, step_fn, state)
73 |
74 | return state.tokens
75 |
--------------------------------------------------------------------------------
/pretraining/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import jax
3 | import jax.numpy as jnp
4 | import optax
5 | import wandb
6 | from functools import partial
7 | from flax import nnx
8 | from optax import tree_utils as otu
9 | from tqdm.auto import tqdm
10 | from omegaconf.dictconfig import DictConfig
11 | import data, utils
12 | import model as model_lib
13 | import optimizer as optimizer_lib
14 |
15 |
16 | @partial(jax.jit, static_argnames=('model_graphdef', 'pad'))
17 | def loss_fn(model_state, model_graphdef, x, pad=False): # [B, T]
18 | model = nnx.merge(model_graphdef, model_state)
19 |     y = jnp.roll(x, -1, axis=1) # labels are the inputs shifted left by one token
20 |     loss_mask = data.pad_mask(x) if pad else jnp.ones(x.shape, dtype=bool)
21 |     loss_mask = loss_mask.at[:, -1].set(False) # the roll wraps around, so the last position has no valid label
22 | logits = model(x) # [B, T, V]
23 | losses = optax.softmax_cross_entropy_with_integer_labels(logits.astype(jnp.float32), y) # [B, T]
24 | return (losses * loss_mask).sum() / loss_mask.sum()
25 |
26 |
27 | @partial(jax.jit, static_argnames=('opt_graphdef', 'model_graphdef'), donate_argnames=('opt_state'))
28 | def train_step(key, opt_state, opt_graphdef, model_graphdef, batch):
29 | key, key_opt = jax.random.split(key)
30 |
31 | # compute grads from a single micro-batch
32 | if batch.ndim == 2:
33 | loss, grads = jax.value_and_grad(loss_fn)(opt_state.model, model_graphdef, batch)
34 |
35 | # compute grads from multiple micro-batches (using gradient accumulation)
36 | if batch.ndim == 3:
37 | loss = 0
38 | grads = otu.tree_zeros_like(opt_state.model, dtype=jnp.float32)
39 |         def step_fn(i, args):
40 | loss, grads = args
41 | batch_loss, batch_grads = jax.value_and_grad(loss_fn)(opt_state.model, model_graphdef, batch[i])
42 | loss = (i*loss + batch_loss) / (i+1)
43 | grads = jax.tree.map(lambda m, g: (i*m + g) / (i+1), grads, batch_grads)
44 | return loss, grads
45 | loss, grads = jax.lax.fori_loop(0, len(batch), step_fn, (loss, grads))
46 |
47 | # optimizer step
48 | optimizer = nnx.merge(opt_graphdef, opt_state)
49 | optimizer.update(key_opt, grads)
50 | opt_state = nnx.state(optimizer)
51 | return key, opt_state, loss
52 |
53 |
54 | def eval_step(model_state, model_graphdef, dataset, pad=False):
55 | loss = 0
56 | for batch in dataset:
57 | loss += loss_fn(model_state, model_graphdef, batch, pad)
58 | return loss / len(dataset)
59 |
60 |
61 | def train_and_evaluate(c: DictConfig):
62 |
63 | # get model and dataset rng seed
64 | key = jax.random.key(c.seed)
65 | key, key_model, key_dataset = jax.random.split(key, 3)
66 |
67 | # sharding
68 | num_fsdp_devices = jax.device_count() // c.num_tp_devices
69 | mesh = jax.make_mesh((num_fsdp_devices, c.num_tp_devices), ('data', 'model'))
70 | jax.set_mesh(mesh)
71 | print('sharding mesh:', ', '.join(f'{k}={v}' for k, v in mesh.shape.items()))
72 |
73 | # model
74 | print('initializing model...')
75 | c.model.V = int(math.ceil(c.model.V / jax.device_count()) * jax.device_count()) # round V up to enable sharding
76 | model = model_lib.create_sharded_model(c.model, key_model)
77 | model_graphdef = nnx.graphdef(model)
78 |
79 | # get num. model parameters
80 | n_params = {
81 | 'n_param_nonembed': 12 * c.model.L * c.model.D**2,
82 | 'n_param_embed': c.model.D * c.model.V,
83 | 'n_param_actual': utils.get_num_model_params(model),
84 | }
85 | for k, v in n_params.items():
86 | print(f'{k}={v:_}')
87 |
88 | # dataset
89 | if (c.num_tokens_train is None) and (c.tokens_params_ratio is not None):
90 | c.num_tokens_train = c.tokens_params_ratio * (n_params['n_param_nonembed'] + n_params['n_param_embed'])
91 | ds_train, ds_valid = data.load_ds(key_dataset, mesh, c.ds_path, c.model.T, c.opt.microbatch_size, c.num_tokens_valid, c.num_tokens_train)
92 | if (c.num_tokens_train is None): c.num_tokens_train = ds_train.size
93 |
94 | # optimizer
95 | num_opt_steps = len(ds_train) // c.opt.grad_acc_steps
96 | tokens_per_opt_step = c.opt.batch_size * c.model.T
97 | tx = optimizer_lib.get_optimizer(c.opt, num_opt_steps, tokens_per_opt_step)
98 | optimizer = optimizer_lib.ModelAndOptimizer(model, tx, stochastic_round=c.opt.stochastic_round)
99 | opt_graphdef, opt_state = nnx.split(optimizer)
100 |
101 | # start wandb
102 | if jax.process_index() == 0:
103 | wandb.init(project=c.wandb_project, config=utils.flatten_dict(c), mode=c.wandb_mode, name=c.run_name)
104 | wandb.summary.update(n_params)
105 |
106 | # training loop
107 | train_loss_sum, train_loss_num = jnp.zeros([]), 0
108 | pbar = range(num_opt_steps)
109 | if jax.process_index() == 0: pbar = tqdm(pbar)
110 | for step in pbar:
111 |
112 | # get batch
113 | if c.opt.grad_acc_steps == 1:
114 | batch = ds_train[step] # [batch_size, T]
115 | if c.opt.grad_acc_steps > 1:
116 | batch = ds_train[step*c.opt.grad_acc_steps:(step+1)*c.opt.grad_acc_steps] # [grad_acc_steps, micro_batch_size, T]
117 |
118 | # training step
119 | key, opt_state, batch_loss = train_step(key, opt_state, opt_graphdef, model_graphdef, batch)
120 |
121 | # logging
122 | train_loss_sum += batch_loss
123 | train_loss_num += 1
124 | if train_loss_num * tokens_per_opt_step >= c.log_every_tokens:
125 | metrics = {}
126 | metrics['train_loss'] = train_loss_sum / train_loss_num
127 | metrics['train_tokens_seen'] = (step+1) * tokens_per_opt_step
128 | if jax.process_index() == 0:
129 | wandb.log(metrics, step)
130 | pbar.set_postfix_str(f'loss={metrics["train_loss"]:.2f}')
131 | train_loss_sum, train_loss_num = jnp.zeros([]), 0
132 |
133 | # eval at end of training
134 | eval_loss = eval_step(opt_state.model, model_graphdef, ds_valid, c.pad_eval)
135 | if jax.process_index() == 0:
136 | wandb.log({'eval_loss': eval_loss}, step)
137 | wandb.finish()
138 |
--------------------------------------------------------------------------------
/finetuning/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jax
3 | import jax.numpy as jnp
4 | from flax import nnx
5 | import numpy as np
6 | import datasets
7 | from jax.sharding import NamedSharding, PartitionSpec as P
8 | from math_verify import parse, verify
9 | from sampler import sample
10 | from tqdm.auto import tqdm
11 |
12 |
13 | def load_datasets(vocab, seq_len=1024):
14 | pad_id = vocab.pad_id()
15 | bos_id = vocab.bos_id()
16 | eos_id = vocab.eos_id()
17 |
18 | # load MATH dataset
19 | print('loading datasets...')
20 | ds_name = 'EleutherAI/hendrycks_math'
21 | configs = datasets.get_dataset_config_names(ds_name) # ['algebra', 'counting_and_probability', 'geometry', 'intermediate_algebra', 'number_theory', 'prealgebra', 'precalculus']
22 | ds_train = datasets.concatenate_datasets([datasets.load_dataset(ds_name, config, split='train') for config in configs]) # ['problem', 'solution']
23 | ds_valid = datasets.concatenate_datasets([datasets.load_dataset(ds_name, config, split='test') for config in configs]) # ['problem', 'solution']
24 |
25 |     # tokenize training dataset
26 | print('tokenizing training dataset...')
27 | train_tokens = np.full([len(ds_train), seq_len], pad_id, dtype=np.int32)
28 | train_pos = np.zeros([len(ds_train), seq_len], dtype=np.int32)
29 | train_loss_mask = np.zeros([len(ds_train), seq_len], dtype=np.bool_)
30 | train_attn_mask = np.zeros([len(ds_train), seq_len, seq_len], dtype=np.bool_)
31 | seq_idx = 0
32 | tok_idx = 0
33 | skipped = 0
34 | for example in ds_train:
35 |
36 | # tokenize example
37 | prompt = f'Problem: {example["problem"]}\nSolution: '
38 | solution = f'{example["solution"]}'
39 | prompt_tokenized, solution_tokenized = vocab.EncodeAsIds([prompt, solution])
40 | example_tokenized = [bos_id] + prompt_tokenized + solution_tokenized + [eos_id]
41 |
42 | # if example is too long, skip it
43 | if len(example_tokenized) > seq_len:
44 | skipped += 1
45 | continue
46 |
47 | # if example doesn't fit in current sequence, start next sequence
48 | if tok_idx + len(example_tokenized) > seq_len:
49 | seq_idx += 1
50 | tok_idx = 0
51 |
52 | # store tokens
53 | train_tokens[seq_idx, tok_idx:tok_idx+len(example_tokenized)] = example_tokenized
54 | train_pos[seq_idx, tok_idx:tok_idx+len(example_tokenized)] = np.arange(len(example_tokenized))
55 | train_loss_mask[seq_idx, tok_idx+len(prompt_tokenized):tok_idx+len(example_tokenized)-1] = True
56 | train_attn_mask[seq_idx, tok_idx:tok_idx+len(example_tokenized), tok_idx:tok_idx+len(example_tokenized)] = True
57 | tok_idx += len(example_tokenized)
58 | train_attn_mask = np.tril(train_attn_mask)
59 | train_tokens = train_tokens[:seq_idx+1]
60 | train_pos = train_pos[:seq_idx+1]
61 | train_attn_mask = train_attn_mask[:seq_idx+1]
62 | train_loss_mask = train_loss_mask[:seq_idx+1]
63 | print(f'skipped train. seq.: {skipped / len(ds_train):.1%}')
64 |
65 | # tokenize eval dataset
66 | print('tokenizing eval dataset...')
67 | skipped = 0
68 | prompts_eval = []
69 | problems_eval = []
70 | solutions_eval = []
71 | tokens_eval = np.full([len(ds_valid), seq_len], pad_id, dtype=np.int32)
72 | for i, example in enumerate(ds_valid):
73 | problems_eval += [example['problem']]
74 | solutions_eval += [example['solution']]
75 | prompt = f'Problem: {example["problem"]}\nSolution: '
76 | prompt_tokenized = [bos_id] + vocab.EncodeAsIds(prompt)
77 | if len(prompt_tokenized) < seq_len:
78 | tokens_eval[i, :len(prompt_tokenized)] = prompt_tokenized
79 | else:
80 | skipped += 1
81 | problems_eval = np.array(problems_eval)
82 | solutions_eval = np.array(solutions_eval)
83 | print(f'skipped valid. seq.: {skipped / len(ds_valid):.1%}')
84 |
85 | return train_tokens, train_pos, train_attn_mask, train_loss_mask, tokens_eval, problems_eval, solutions_eval
86 |
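# Note on packing (added for clarity): multiple short examples are packed into each
# seq_len-long row; train_attn_mask gets one square block per packed example, and the
# np.tril above makes each block causal, so packed examples cannot attend to each other.
# train_pos restarts from 0 at the start of every packed example.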
87 |
88 | def benchmark_model(key, model, tokens, problems_eval, solutions_eval, vocab, batch_size, n_eval_samples=None, temperature=1, print_output=True):
89 | pad_id = vocab.pad_id()
90 | eos_id = vocab.eos_id()
91 | key_decoding, key_questions = jax.random.split(key)
92 | mesh = model.in_embed.embedding.value.sharding.mesh
93 | if n_eval_samples is None: n_eval_samples = len(tokens)
94 | n_batches = n_eval_samples // batch_size
95 | sample_idxs = jax.random.choice(key_questions, len(tokens), shape=[n_batches, batch_size], replace=False)
96 | lengths_list = []
97 | correct_list = []
98 | finished_list = []
99 | pbar = tqdm(sample_idxs, desc='Sampling') if (jax.process_index() == 0) else sample_idxs
100 | for batch_idx in pbar:
101 | # sample tokens
102 | input_tokens_batch = jax.device_put(tokens[batch_idx], NamedSharding(mesh, P('data', None)))
103 | output_tokens_batch = sample(key_decoding, *nnx.split(model), input_tokens_batch, temperature)
104 |
105 | # extract output sequences
106 | completions_tokens = []
107 | for in_seq, out_seq in zip(input_tokens_batch, output_tokens_batch):
108 | out_seq = out_seq[jnp.argmax(in_seq==pad_id):]
109 | if jnp.any(out_seq==pad_id): out_seq = out_seq[:jnp.argmax(out_seq==pad_id)]
110 | completions_tokens += [out_seq.tolist()]
111 |
112 | # eval completions
113 | completions_text = vocab.DecodeIds(completions_tokens)
114 | for sample_idx, completion_tokens, completion_text in zip(batch_idx, completions_tokens, completions_text):
115 | if sample_idx < len(problems_eval):
116 | problem = problems_eval[sample_idx]
117 | gold = solutions_eval[sample_idx]
118 | parsed = parse(completion_text)
119 | finished = eos_id in completion_tokens
120 | correct = verify(parse(gold), parsed)
121 | lengths_list += [len(completion_tokens)]
122 | finished_list += [finished]
123 | correct_list += [correct]
124 | if print_output:
125 | print('------------')
126 | print(f'PROMPT:\n{problem}\nCOMPLETION:\n{completion_text}\nPARSED: {parsed}\nGOLD: {gold}\nCORRECT: {correct}')
127 |
128 | return dict(length=np.mean(lengths_list), finished=np.mean(finished_list), accuracy=np.mean(correct_list))
129 |
--------------------------------------------------------------------------------
/pretraining/model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import jax
3 | import jax.numpy as jnp
4 | from functools import partial
5 | from flax import nnx
6 | from jax.sharding import PartitionSpec as P
7 | from jax.experimental.shard_map import shard_map
8 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask
9 | from omegaconf.dictconfig import DictConfig
10 | from rope import apply_rope
11 |
12 |
13 | class TransformerDecoder(nnx.Module):
14 | def __init__(self, c: DictConfig, rngs: nnx.Rngs):
15 | embed_in_init = sharded_init('embedding_in')
16 | embed_out_init = sharded_init('embedding_out')
17 | self.token_embed_in = nnx.Embed(num_embeddings=c.V, features=c.D, embedding_init=embed_in_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
18 | self.token_embed_out = nnx.Embed(num_embeddings=c.V, features=c.D, embedding_init=embed_out_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
19 | self.blocks = nnx.List(TransformerBlock(c, rngs) for _ in range(c.L))
20 | self.out_ln = nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
21 | self.remat = c.remat
22 |
23 | def __call__(self, x): # [B, S]
24 |
25 | # token embedding
26 | h = self.token_embed_in(x) # [B, T, D]
27 |
28 | # transformer blocks
29 | for block in self.blocks:
30 | h = jax.remat(block)(h) if self.remat else block(h)
31 |
32 | # project back to vocabulary
33 | h = self.out_ln(h)
34 | logits = self.token_embed_out.attend(h) # [B, T, V]
35 | return logits
36 |
37 |
38 | class TransformerBlock(nnx.Module):
39 | def __init__(self, c: DictConfig, rngs: nnx.Rngs):
40 | self.ln1 = nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
41 | self.ln2 = nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
42 | self.attn = MultiHeadAttention(c, rngs)
43 | self.mlp = MLP(c, rngs)
44 |
45 | def __call__(self, x): # [B, T, D]
46 | x = x + self.attn(self.ln1(x)) # attention block
47 | return x + self.mlp(self.ln2(x)) # MLP block
48 |
49 |
50 | class MultiHeadAttention(nnx.Module):
51 | """Causal attention layer."""
52 | def __init__(self, c: DictConfig, rngs: nnx.Rngs):
53 | qkv_proj_init = sharded_init('attn_qkv_proj')
54 | out_proj_init = sharded_init('attn_out_proj')
55 | self.qkv_proj = nnx.Einsum('BTd,SNdH->SBTNH', (3, c.N, c.D, c.H), kernel_init=qkv_proj_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
56 | self.out_proj = nnx.Einsum('BTnh,nhD->BTD', (c.N, c.H, c.D), kernel_init=out_proj_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
57 | self.query_norm = nnx.RMSNorm(c.H, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
58 | self.key_norm = nnx.RMSNorm(c.H, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
59 | if c.use_flash_attn and jax.devices()[0].platform == 'tpu' and (c.H % 128 != 0):
60 | warnings.warn('cannot use flash attention because `model.H` is not a multiple of 128.')
61 | c.use_flash_attn &= jax.devices()[0].platform == 'tpu'
62 | c.use_flash_attn &= (c.H % 128 == 0)
63 | self.attention = partial(tpu_causal_flash_attention) if c.use_flash_attn else partial(jax.nn.dot_product_attention, is_causal=True)
64 |
65 | def __call__(self, x): # [B, T, D]
66 | B, T, D = x.shape
67 |
68 | # input projection
69 | q, k, v = self.qkv_proj(x) # [B, T, N, H]
70 |
71 | # qk-norm
72 | q = self.query_norm(q)
73 | k = self.key_norm(k)
74 |
75 | # position embedding
76 | position = jnp.arange(T)
77 | q = apply_rope(q, position[None])
78 | k = apply_rope(k, position[None])
79 |
80 | # attention
81 | out = self.attention(q, k, v) # [B, T, N, H]
82 |
83 | # output projection followed by contraction back to original dims
84 | out = self.out_proj(out) # [B, T, D]
85 | return out
86 |
87 |
88 | def tpu_causal_flash_attention(q, k, v):
89 | """
90 | TPU Flash Attention.
91 | https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
92 | https://github.com/AI-Hypercomputer/maxtext/blob/9ea52118535e970096c164460dbbfa478d157066/MaxText/layers/attentions.py#L562
93 | """
94 | B, T, N, H = q.shape
95 |     assert H % 128 == 0, 'TPU flash attention requires head dim. to be a multiple of 128'
96 |
97 | # scale query
98 | q /= jnp.sqrt(H)
99 |
100 | # kernel block sizes
101 | # https://github.com/AI-Hypercomputer/maxtext/blob/afcdf47f8b7c1e1864fa81012a873590c5408122/MaxText/configs/base.yml#L644
102 | block_sizes = splash_attention_kernel.BlockSizes(
103 | block_q=512,
104 | block_kv=512,
105 | block_kv_compute=128,
106 | block_q_dkv=512,
107 | block_kv_dkv=512,
108 | block_kv_dkv_compute=128,
109 | block_q_dq=512,
110 | block_kv_dq=512,
111 | )
112 |
113 | mesh = jax.sharding.get_abstract_mesh()
114 | sharding = P('data', None, 'model', None)
115 | @partial(shard_map, mesh=mesh, in_specs=(sharding, sharding, sharding), out_specs=sharding, check_rep=False)
116 | def attention(q, k, v):
117 | _, _, n, _ = q.shape
118 | causal_mask = splash_attention_mask.CausalMask(shape=(T, T))
119 | multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(causal_mask,) * n)
120 | splash_kernel = splash_attention_kernel.make_splash_mha(mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes)
121 | out = jax.vmap(splash_kernel)(
122 | q.swapaxes(1, 2),
123 | k.swapaxes(1, 2),
124 | v.swapaxes(1, 2)
125 | ).swapaxes(1, 2) # [B, T, N, H]
126 | return out
127 |
128 | return attention(q, k, v)
129 |
130 |
131 | class MLP(nnx.Module):
132 | """Multilayer perceptron."""
133 | def __init__(self, c: DictConfig, rngs: nnx.Rngs):
134 | fc1_init = sharded_init('mlp_fc1')
135 | fc2_init = sharded_init('mlp_fc2')
136 | self.fc1 = nnx.Linear(in_features=c.D, out_features=c.F, kernel_init=fc1_init, use_bias=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
137 | self.fc2 = nnx.Linear(in_features=c.F, out_features=c.D, kernel_init=fc2_init, use_bias=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs)
138 |
139 | def __call__(self, x): # [B, T, D]
140 | h = jax.nn.gelu(self.fc1(x)) # [B, T, F]
141 | return self.fc2(h) # [B, T, D]
142 |
143 |
144 | def sharded_init(layer_type: str):
145 | """Initialize weights with optional sharding."""
146 | kernel_init = jax.nn.initializers.xavier_uniform()
147 | embed_init = jax.nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0)
148 | match layer_type:
149 | case 'embedding_in': # [V, D]
150 | return nnx.with_partitioning(embed_init, ('data', 'model'))
151 | case 'embedding_out': # [V, D]
152 | return nnx.with_partitioning(embed_init, ('model', 'data'))
153 | case 'attn_qkv_proj': # [3, N, D, H]
154 | return nnx.with_partitioning(kernel_init, (None, 'model', 'data', None))
155 | case 'attn_out_proj': # [N, H, D]
156 | return nnx.with_partitioning(kernel_init, ('model', None, 'data'))
157 | case 'mlp_fc1': # [D, F]
158 | return nnx.with_partitioning(kernel_init, ('data', 'model'))
159 | case 'mlp_fc2': # [F, D]
160 | return nnx.with_partitioning(kernel_init, ('model', 'data'))
161 | case _:
162 | raise ValueError(f'unrecognized layer type: {layer_type}')
163 |
164 |
165 | def create_sharded_model(c: DictConfig, key):
166 | """
167 | initialize sharded model without putting it on a single device
168 | https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html
169 | """
170 | seed = int(jax.random.randint(key, [1], 0, 1_000_000)[0])
171 |
172 | @nnx.jit
173 | def initialize_sharded_model():
174 | rngs = nnx.Rngs(seed)
175 | model = TransformerDecoder(c, rngs=rngs) # unsharded at this moment
176 | state = nnx.state(model) # the model's state, a pure pytree
177 | pspecs = nnx.get_partition_spec(state) # get annotations from state
178 | sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
179 | nnx.update(model, sharded_state) # the model is sharded now
180 | return model
181 |
182 | model = initialize_sharded_model()
183 |
184 | return model
--------------------------------------------------------------------------------
/finetuning/factorized.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """
16 | Factorized optimizers.
17 | Taken from https://github.com/google-deepmind/optax/blob/main/optax/_src/factorized.py
18 | Only 2 LOC were modified to hard-code scale_by_factored_rms state dtype to float32.
19 | """
20 |
21 | from collections.abc import Callable
22 | import dataclasses
23 | from typing import NamedTuple, Optional
24 |
25 | import chex
26 | import jax
27 | import jax.numpy as jnp
28 | import numpy as np
29 | from optax._src import base
30 | from optax._src import numerics
31 |
32 |
33 | def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array:
34 | """Second-order moment decay schedule."""
35 | t = jnp.array(i + 1, jnp.float32)
36 | return 1.0 - t ** (-exponent)
37 |
38 |
39 | def _factored_dims(
40 | shape: base.Shape, factored: bool, min_dim_size_to_factor: int
41 | ) -> Optional[tuple[int, int]]:
42 | """Whether to use a factored second moment estimator.
43 |
44 | This function returns a tuple with the two largest axes to reduce over.
45 | If no two dimensions have size >= min_dim_size_to_factor, return None.
46 |
47 | Args:
48 | shape: an input shape
49 | factored: whether to use factored second-moment estimator for 2d vars.
50 | min_dim_size_to_factor: only factor accumulator if two array dimensions have
51 | at least this size.
52 |
53 | Returns:
54 | None or a tuple of ints
55 | """
56 | if not factored or len(shape) < 2:
57 | return None
58 | sorted_dims = np.argsort(shape)
59 | if shape[sorted_dims[-2]] < min_dim_size_to_factor:
60 | return None
61 | return int(sorted_dims[-2]), int(sorted_dims[-1])
62 |
63 |
64 | @dataclasses.dataclass
65 | class _UpdateResult:
66 | """Opaque container that is not traversed by jax.tree.map."""
67 |
68 | update: chex.Array # the update to apply to params
69 | v_row: chex.Array # used for factored params.
70 | v_col: chex.Array # used for factored params.
71 | v: chex.Array # used for params where factoring is skipped.
72 |
73 |
74 | class FactoredState(NamedTuple):
75 | """Overall state of the gradient transformation."""
76 |
77 | count: chex.Array # number of update steps.
78 | v_row: chex.ArrayTree # Tree of factored params.
79 | v_col: chex.ArrayTree # Tree of factored params.
80 | v: chex.ArrayTree # Tree for params where factoring is skipped.
81 |
82 |
83 | def scale_by_factored_rms(
84 | factored: bool = True,
85 | decay_rate: float = 0.8,
86 | step_offset: int = 0,
87 | min_dim_size_to_factor: int = 128,
88 | epsilon: float = 1e-30,
89 | decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow,
90 | ):
91 | """Scaling by a factored estimate of the gradient rms (as in Adafactor).
92 |
93 |   This is a so-called "1+epsilon" scaling algorithm, which is extremely memory
94 | efficient compared to RMSProp/Adam, and has had wide success when applied to
95 | large-scale training of attention-based models.
96 |
97 | Args:
98 |     factored: boolean: whether to use factored second-moment estimates.
99 | decay_rate: float: controls second-moment exponential decay schedule.
100 | step_offset: for finetuning, one may set this to the starting step-number of
101 | the fine tuning phase.
102 | min_dim_size_to_factor: only factor accumulator if two array dimensions are
103 | at least this size.
104 | epsilon: Regularization constant for squared gradient.
105 | decay_rate_fn: A function that accepts the current step, the decay rate
106 | parameter and controls the schedule for the second momentum. Defaults to
107 | the original adafactor's power decay schedule. One potential shortcoming
108 | of the original schedule is the fact that second momentum converges to 1,
109 | which effectively freezes the second momentum. To prevent this the user
110 | can opt for a custom schedule that sets an upper bound for the second
111 | momentum, like in Zhai et al., 2021.
112 |
113 | Returns:
114 | The corresponding :class:`optax.GradientTransformation`.
115 |
116 | References:
117 | Shazeer et al, `Adafactor: Adaptive Learning Rates with Sublinear Memory
118 |     Cost <https://arxiv.org/abs/1804.04235>`_, 2018
119 |   """
--------------------------------------------------------------------------------