├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml ├── rwkv-decon.iml └── vcs.xml ├── README.md ├── archive ├── gpt_configs.py ├── gpt_decon.py ├── notebook.py └── pico_utils.py ├── copy_init ├── copy_rwkv.py ├── test_weights.py └── weights.py ├── custom_dataset.py ├── custom_dataset_str.py ├── gpt_recon ├── LICENSE ├── attn.html ├── bpe_encoder.py ├── chart4.svg ├── clean_frame.py ├── clean_frame_utils.py ├── gpt.py ├── lines.svg ├── load_model.py ├── run_gpt.py ├── train_gpt.py ├── train_notebook.py ├── tune_gpt.py ├── view_attn.py ├── view_notebook.py └── view_vec.py ├── labels.py ├── lra_arena ├── lra_utils.py ├── rwkv_lra.py └── use_lra_demo.py ├── mypy.ini ├── nlp_utils.py ├── pico_lru ├── pico_lru.py ├── pico_lru_parallel.py └── pico_rnn.py ├── pico_rwkv ├── associative_scan_toolbox.py ├── close_to_original │ ├── pico_rwkv.py │ └── run_rwkv.py ├── jax_load.py ├── pico_rwkv_parallel.py ├── pico_rwkv_parallel_alternatives.py ├── pico_rwkv_rnn.py ├── pico_rwkv_weights.py ├── pth_to_safet.py ├── run_rwkv_parallel.py ├── rwkv_torch.py ├── rwkv_weight_profile.py └── train_rwkv.py ├── pico_s5 └── pico_s5.py ├── picojax ├── jax_utils.py ├── random_utils.py └── train_utils.py ├── python_utils.py └── saves ├── view_vec.npy ├── view_vec2_dict ├── view_vec2_dict_jit ├── view_vec2_dict_new ├── view_vec2_step1.npy └── view_vecs_dict /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # 
PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # wanb/jax 132 | jax-trace/* 133 | wandb/* 134 | wandb 135 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 93 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 10 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/rwkv-decon.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rwkv-decon 2 | Trying to deconstruct RWKV in understandable terms 3 | -------------------------------------------------------------------------------- /archive/gpt_configs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from functools import cache 5 | from typing import NamedTuple, Optional, cast, TypedDict, List, Literal, Union 6 | 7 | from chex import assert_shape 8 | from einops import rearrange 9 | from jax import numpy as jnp 10 | from jax.nn import softmax 11 | from jax.experimental.maps import xmap 12 | from safetensors.flax import save_file 13 | 14 | from clean_frame import Linear, for_all_T, gelu, LN 15 | from clean_frame_utils import check_config, WeightConfigDict, PartsDict, WeightsTree, config_weights_check, Arr, \ 16 | WeightConfig 17 | 18 | 19 | class GptMha: 20 | class Config(NamedTuple): 21 | n_channels: Optional[int] = None 22 | n_heads: Optional[int] = None 23 | n_seq: Optional[Union[int, Literal['dynamic']]] = None 24 | inf_mask: float = -1e10 25 | linear: Linear.Config = Linear.Config() 26 | QKV_linear: Linear.Config = 
Linear.Config() 27 | 28 | name: str = "mha" 29 | 30 | @property 31 | @cache 32 | def T(self) -> Optional[int]: 33 | if self.n_seq == 'dynamic': 34 | return None 35 | else: 36 | return self.n_seq 37 | 38 | def fill(self) -> GptMha.Config: 39 | assert self.n_channels is not None, 'n_channels must be set' 40 | new = self._replace(linear=self.linear._replace(n_in=self.n_channels, n_out=self.n_channels), 41 | QKV_linear=self.QKV_linear._replace(n_in=self.n_channels, n_out=3 * self.n_channels)) 42 | check_config(new) 43 | return new 44 | 45 | def make(self) -> GptMha: 46 | return GptMha(self.fill()) 47 | 48 | @property 49 | def dim_heads(self) -> int: 50 | assert self.n_channels is not None 51 | assert self.n_heads is not None 52 | return self.n_channels // self.n_heads 53 | 54 | @property 55 | def weights(self) -> WeightConfigDict: 56 | return {} 57 | 58 | @property 59 | def parts(self) -> PartsDict: 60 | filled = self.fill() 61 | return dict( 62 | linear=filled.linear, 63 | QKV_linear=filled.QKV_linear 64 | ) 65 | 66 | def weights_check(self, w: WeightsTree) -> GptMha.Weights: 67 | return cast(GptMha.Weights, config_weights_check(self, w)) 68 | 69 | class Weights(TypedDict): 70 | QKV_linear: Linear.Weights 71 | linear: Linear.Weights 72 | 73 | def __init__(self, config: Config): 74 | assert config.n_channels is not None 75 | assert config.n_heads is not None 76 | assert config.dim_heads is not None 77 | self.n_channels = config.n_channels 78 | self.n_heads = config.n_heads 79 | self.T = config.T 80 | self.dim_heads = config.dim_heads 81 | self.config = config 82 | 83 | self.linear = config.linear.make() 84 | self.QKV_linear = config.QKV_linear.make() 85 | self.scale = math.sqrt(self.dim_heads) 86 | self.linearf = for_all_T(self.linear.f) 87 | self.QKV_linearf = for_all_T(self.QKV_linear.f) 88 | assert self.n_channels % self.n_heads == 0, 'n_channels must be divisible by n_heads' 89 | 90 | def get_mask(self, t: int) -> Arr: 91 | return (1 - jnp.tri(t)) * 
class GptFfn:
    """Feed-forward part of a GPT block, assembled from two ``Linear`` layers."""

    class Config(NamedTuple):
        """Immutable configuration for :class:`GptFfn`.

        ``n_channels`` has to be set before ``fill``, ``make`` or ``parts`` is used;
        the two linear sub-configs are sized from it by ``fill``.
        """
        n_channels: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        linear1: Linear.Config = Linear.Config()
        linear2: Linear.Config = Linear.Config()
        name: str = "ffn"

        # Plain property, matching GptBlock.Config.T and Gpt.Config.T. The previous
        # functools.cache wrapper kept every Config instance alive in a global cache
        # and hashed the whole tuple on each access just to save one comparison.
        @property
        def T(self) -> Optional[int]:
            """Return ``n_seq``, or ``None`` when it is the string ``'dynamic'``."""
            if self.n_seq == 'dynamic':
                return None
            return self.n_seq

        def fill(self) -> GptFfn.Config:
            """Return a copy whose linear layers are sized from ``n_channels``.

            ``linear1`` maps ``n_channels -> 4 * n_channels`` and ``linear2`` maps
            back. The result is passed through ``check_config`` before being returned.
            """
            assert self.n_channels is not None, 'n_channels must be set'
            hidden = self.n_channels * 4
            new = self._replace(
                linear1=self.linear1._replace(n_in=self.n_channels, n_out=hidden),
                linear2=self.linear2._replace(n_in=hidden, n_out=self.n_channels))
            check_config(new)
            return new

        def make(self) -> GptFfn:
            """Build a :class:`GptFfn` from the filled configuration."""
            return GptFfn(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """This module declares no weights of its own; they live in ``parts``."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """Filled sub-module configs, keyed by the names used in ``Weights``."""
            filled = self.fill()
            return dict(
                linear1=filled.linear1,
                linear2=filled.linear2
            )

        def weights_check(self, w: WeightsTree) -> GptFfn.Weights:
            """Run ``config_weights_check`` on ``w`` and cast the result to ``GptFfn.Weights``."""
            return cast(GptFfn.Weights, config_weights_check(self, w))

    class Weights(TypedDict):
        linear1: Linear.Weights
        linear2: Linear.Weights

    def __init__(self, config: Config):
        self.n_channels = config.n_channels
        self.linear1 = config.linear1.make()
        self.linear2 = config.linear2.make()
| if self.n_seq == 'dynamic': 161 | return None 162 | else: 163 | return self.n_seq 164 | 165 | @property 166 | def x_shape(self) -> tuple[Optional[int], ...]: 167 | assert self.n_channels is not None 168 | return self.T, self.n_channels 169 | 170 | def fill(self) -> GptBlock.Config: 171 | new = self._replace( 172 | mha=self.mha._replace(n_channels=self.n_channels, n_seq=self.n_seq, n_heads=self.n_heads).fill(), 173 | ffn=self.ffn._replace(n_channels=self.n_channels, n_seq=self.n_seq).fill(), 174 | ln1=self.ln1._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape), 175 | ln2=self.ln2._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape)) 176 | check_config(new) 177 | return new 178 | 179 | def make(self) -> GptBlock: 180 | return GptBlock(self.fill()) 181 | 182 | @property 183 | def weights(self) -> WeightConfigDict: 184 | return {} 185 | 186 | @property 187 | def parts(self) -> PartsDict: 188 | filled = self.fill() 189 | return dict( 190 | mha=filled.mha, 191 | ffn=filled.ffn, 192 | ln1=filled.ln1, 193 | ln2=filled.ln2, 194 | ) 195 | 196 | def weights_check(self, w: WeightsTree) -> GptBlock.Weights: 197 | return cast(GptBlock.Weights, config_weights_check(self, w)) 198 | 199 | class Weights(TypedDict): 200 | mha: GptMha.Weights 201 | ffn: GptFfn.Weights 202 | ln1: LN.Weights 203 | ln2: LN.Weights 204 | 205 | def __init__(self, config: Config): 206 | self.T = config.T 207 | self.n_channels = config.n_channels 208 | self.mha = config.mha.make() 209 | self.ffn = config.ffn.make() 210 | self.ln1 = config.ln1.make() 211 | self.ln2 = config.ln2.make() 212 | 213 | self.ffnf = for_all_T(self.ffn.f) 214 | self.ln1f = self.ln1.f 215 | self.ln2f = self.ln2.f 216 | 217 | 218 | 219 | class GptDecoder: 220 | class Config(NamedTuple): 221 | eps: Optional[float] = None 222 | n_channels: Optional[int] = None 223 | n_heads: Optional[int] = None 224 | n_seq: Optional[Union[int, Literal['dynamic']]] = None 225 | n_blocks: Optional[int] = None 226 | blocks: 
class Gpt:
    """Top-level GPT model: embeddings, a decoder stack and a final layer norm."""

    class Config(NamedTuple):
        """Immutable configuration for :class:`Gpt`.

        ``fill`` pushes the shared hyper-parameters (``eps``, ``n_channels``,
        ``n_heads``, ``n_seq``, ``n_blocks``) down into the decoder and layer-norm
        sub-configs.
        """
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        n_tokens: Optional[int] = None
        max_seq_len: Optional[int] = None

        te_name: str = 'te'
        te_init: Literal['normal'] = 'normal'
        te_scale: float = 0.02

        decoder: GptDecoder.Config = GptDecoder.Config()
        ln: LN.Config = LN.Config()

        pe_name: str = 'pe'

        name: str = 'gpt'

        @property
        def T(self) -> Optional[int]:
            """Return ``n_seq``, or ``None`` when it is the string ``'dynamic'``."""
            if self.n_seq == 'dynamic':
                return None
            else:
                return self.n_seq

        def fill(self) -> Gpt.Config:
            """Return a copy with the decoder and final layer norm fully configured.

            The result is passed through ``check_config`` before being returned.
            """
            assert self.n_channels is not None, 'n_channels must be specified'
            new = self._replace(decoder=self.decoder._replace(eps=self.eps, n_channels=self.n_channels,
                                                              n_heads=self.n_heads, n_seq=self.n_seq,
                                                              n_blocks=self.n_blocks).fill(),
                                ln=self.ln._replace(eps=self.eps, norm_dims=(0,), x_shape=(self.T, self.n_channels)))

            check_config(new)
            return new

        def make(self) -> Gpt:
            """Build a :class:`Gpt` from the filled configuration."""
            return Gpt(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """Weight specs owned directly by the model: the two embedding tables."""
            filled = self.fill()
            assert filled.max_seq_len is not None
            assert filled.n_tokens is not None
            assert filled.n_channels is not None
            return dict(
                token_embedding=WeightConfig(name=filled.te_name,
                                             init=filled.te_init,
                                             shape=(filled.n_tokens, filled.n_channels),
                                             scale=filled.te_scale),

                positional_encoding=WeightConfig(name=filled.pe_name,
                                                 shape=(filled.max_seq_len, filled.n_channels)),
            )

        @property
        def parts(self) -> PartsDict:
            """Filled sub-module configs (decoder stack and final layer norm)."""
            filled = self.fill()
            assert filled.decoder is not None
            assert filled.ln is not None
            assert filled.te_name is not None
            return dict(
                decoder=filled.decoder,
                ln=filled.ln,
            )

        def weights_check(self, w: WeightsTree) -> Gpt.Weights:
            """Run ``config_weights_check`` on ``w`` and cast the result to ``Gpt.Weights``."""
            return cast(Gpt.Weights, config_weights_check(self, w))

    # Previously missing: ``weights_check`` evaluates ``Gpt.Weights`` at runtime inside
    # ``cast(...)``, which raised AttributeError. The keys mirror ``Config.weights`` and
    # ``Config.parts``, following the pattern of the other modules' ``Weights`` classes.
    # NOTE(review): confirm this matches the tree produced by config_weights_check.
    class Weights(TypedDict):
        token_embedding: Arr
        positional_encoding: Arr
        decoder: GptDecoder.Weights
        ln: LN.Weights

    def __init__(self, config: Config):
        assert config.n_blocks is not None
        assert config.n_tokens is not None
        assert config.n_channels is not None
        self.T = config.T
        self.n_channels = config.n_channels
        self.n_tokens = config.n_tokens
        self.eps = config.eps
        self.decoder = config.decoder.make()
        self.ln = config.ln.make()
def gelu(x: Arr) -> Arr:
    """Apply the tanh approximation of the GELU activation element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    inner = x + 0.044715 * x ** 3
    gate = 1 + tanh(sqrt(2 / pi) * inner)
    return 0.5 * x * gate
class GptDecoder(Pytree):
    """A stack of GPT blocks applied one after another."""

    def __init__(self, blocks: list[GptBlock]):
        self.blocks = blocks

    # x: (n_channels, T)
    def f(self, x: Arr) -> Arr:
        """Feed ``x`` through each block's ``f`` in order and return the last output."""
        hidden = x
        for layer in self.blocks:
            hidden = layer.f(hidden)
        return hidden
def ln(x, w, b, eps=1e-5):
    """Layer-normalise ``x`` over all of its elements, then scale and shift.

    Args:
        x: array to normalise; mean and variance are taken over every element.
        w: multiplicative gain applied after normalisation.
        b: additive bias applied last.
        eps: value added to the variance before the reciprocal square root to
            keep it finite for constant inputs. Defaults to the previously
            hard-coded ``1e-5``.

    Returns:
        ``(x - mean(x)) * w * rsqrt(var(x) + eps) + b``.
    """
    centered = x - jnp.mean(x)
    inv_std = jax.lax.rsqrt(jnp.var(x) + eps)
    return centered * (w * inv_std) + b
def download_gpt2_files(model_size, model_dir):
    """Download the GPT-2 checkpoint files for ``model_size`` into ``model_dir``.

    Each file is streamed to disk in small chunks with a tqdm progress bar.

    Args:
        model_size: one of ``"124M"``, ``"355M"``, ``"774M"``, ``"1558M"``.
        model_dir: existing directory the files are written into.

    Raises:
        AssertionError: if ``model_size`` is not one of the known sizes.
        requests.HTTPError: if the server answers with an error status.
        KeyError: if a response carries no ``content-length`` header.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]
    url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
    chunk_size = 1000
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        # Context-manage the streamed response so its connection is released
        # even when the status check or a write fails.
        with requests.get(f"{url}/{model_size}/{filename}", stream=True) as r:
            r.raise_for_status()
            # Read the size before creating the file, so a missing header does
            # not leave an empty file behind.
            file_size = int(r.headers["content-length"])

            with open(os.path.join(model_dir, filename), "wb") as f:
                with tqdm(
                    ncols=100,
                    desc="Fetching " + filename,
                    total=file_size,
                    unit_scale=True,
                    unit="b",
                ) as pbar:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # Advance by the bytes actually received: the final chunk
                        # is usually shorter than chunk_size.
                        pbar.update(len(chunk))
def load_encoder_hparams_and_params(model_size, models_dir):
    """Load the BPE encoder, hyper-parameters and weights for a GPT-2 model.

    Looks for a TensorFlow checkpoint in ``models_dir/model_size`` and downloads
    the model files first when none is found.

    Args:
        model_size: one of ``"124M"``, ``"355M"``, ``"774M"``, ``"1558M"``.
        models_dir: directory holding one sub-directory per model size.

    Returns:
        A tuple ``(encoder, hparams, params)``.

    Raises:
        AssertionError: if ``model_size`` is not one of the known sizes.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if not tf_ckpt_path:  # download files if necessary
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)
        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)

    encoder = get_encoder(model_size, models_dir)
    # Use a context manager so the file handle is closed deterministically
    # (the bare json.load(open(...)) left that to the garbage collector).
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as hparams_file:
        hparams = json.load(hparams_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

    return encoder, hparams, params
def nested_dict_equal(a, b):
    """
    Check if two nested dictionaries of arrays and lists are equal.

    Dicts are compared key by key and lists element by element, recursively.
    Array-like leaves (anything exposing ``.shape``) are compared with
    ``array_equal``; every other leaf is compared with ``==``.

    When a dict entry differs, the offending key and the shapes of both values
    (``None`` for values without a shape) are printed as a debugging aid.

    Returns:
        ``True`` when both structures match, ``False`` otherwise.
    """
    if isinstance(a, dict) and isinstance(b, dict):
        if a.keys() != b.keys():
            return False
        for key in a:
            if not nested_dict_equal(a[key], b[key]):
                print("kd", key)
                # Only arrays carry a shape. Nested dicts, lists and scalars
                # used to raise AttributeError here instead of returning False.
                print(getattr(a[key], "shape", None), getattr(b[key], "shape", None))
                return False
        return True
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            return False
        return all(nested_dict_equal(x, y) for x, y in zip(a, b))
    if hasattr(a, "shape") and hasattr(b, "shape"):
        # Collapse the 0-d boolean array to a plain bool so callers get a real
        # True/False back.
        return bool(np.array_equal(a, b))
    return a == b
@dataclass
class NormalWeight:
    """Weight spec whose values are drawn as ``normal * scale + mean``."""
    name: str
    tags: Labels
    shape: tuple[int, ...]
    mean: float = 0
    scale: float = 1

    def init(self, rng_key: SafeKey) -> Arr:
        """Sample an array of ``self.shape`` using the key from ``rng_key.get()``."""
        standard = random.normal(rng_key.get(), self.shape)
        return standard * self.scale + self.mean

    @classmethod
    def from_arr(cls, name: str, tags: Labels, arr: Arr) -> NormalWeight:
        """Build a spec that matches the shape, mean and standard deviation of ``arr``."""
        mu = arr.mean().item()
        sigma = arr.std().item()
        return cls(name=name, tags=tags, shape=arr.shape, mean=mu, scale=sigma)
59 | 60 | def init(self, rng_key: SafeKey) -> Arr: 61 | return np.zeros(self.shape) 62 | 63 | @classmethod 64 | def from_arr(cls, name: str, tags: Labels, arr: Arr) -> ZeroWeight: 65 | return cls(name=name, tags=tags, shape=arr.shape) 66 | 67 | 68 | 69 | def get_weight_config(weight_config_type: WeightConfigType, key: str, arr: Arr, tags: Labels) -> WeightConfig: 70 | if weight_config_type == 'normal': 71 | return NormalWeight.from_arr(name=key, tags=tags, arr=arr) 72 | elif weight_config_type == 'zero': 73 | return ZeroWeight.from_arr(name=key, tags=tags, arr=arr) 74 | else: 75 | raise ValueError(f"weight_config_type must be 'normal' or 'zero', not {weight_config_type}") 76 | 77 | 78 | def get_normal_weights_config_(path: Path, model_name: str, 79 | non_normal_weight_tags: Optional[dict[str, WeightConfigType]] = None) -> dict[ 80 | str, WeightConfig]: 81 | try: 82 | return load_weight_configs_(path, model_name) 83 | except FileNotFoundError: 84 | weight_infos: dict[str, WeightConfig] = {} 85 | with safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as f: 86 | for key in f.keys(): 87 | t = f.get_tensor(key) 88 | # weight_infos[key] = NormalWeight.from_tensor(name=key, tags=L(*key.split('.')), arr=t) 89 | tags = L(*key.split('.')) 90 | weight_config_type = 'normal' 91 | if non_normal_weight_tags is not None: 92 | for tag, _type in non_normal_weight_tags.items(): 93 | if (tag in tags.tags) and (_type in ['normal', 'zero']): 94 | weight_config_type = _type 95 | weight_config_type = cast(WeightConfigType, weight_config_type) 96 | weight_infos[key] = get_weight_config(weight_config_type, key, t, tags) 97 | 98 | with open(path / f"{model_name}.json", 'w') as f: 99 | json.dump(weight_infos, f, indent=2, default=pydantic_encoder) 100 | return weight_infos 101 | 102 | 103 | def load_weight_configs_(path: Path, model_name: str) -> dict[str, WeightConfig]: 104 | with open(path / f"{model_name}.json") as f: 105 | weight_info_dict = json.load(f) 106 | 
def load(dataset: str = "english"):
    """Read every ``.txt`` book of a dataset and build a character-level codec.

    The books are looked up in ``base_path / dataset`` (``base_path`` is the
    module-level constant).  Each book is prefixed with its file name and all
    books are joined into one string.  Character ids are assigned by
    descending frequency in that string.

    Args:
        dataset: name of the sub-directory of ``base_path`` holding the books.

    Returns:
        A tuple ``(text, encode, decode, vocab_size)`` where ``encode`` maps a
        string to a list of character ids, ``decode`` maps a list of ids back
        to a string, and ``vocab_size`` is the number of distinct characters.

    Raises:
        Exception: if a book cannot be decoded with any candidate encoding.
    """
    path = base_path / dataset
    # NOTE(review): os.listdir returns entries in arbitrary order, so the
    # concatenation order (and tie-breaking of character ids) follows the
    # filesystem; sorting would be more reproducible but could disagree with
    # caches that were built from the current order.
    books = [f for f in os.listdir(path) if f.endswith('.txt')]

    # Tried in this order; the first encoding that decodes the whole file wins.
    encodings = ('utf-8', 'gb2312', 'gbk', 'big5', 'utf-16', 'gb18030')

    def read_book_file(filename: str):
        print(f"reading {filename}...")
        last_error = None
        for encoding in encodings:
            try:
                with open(path / filename, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeError as e:
                # UnicodeError (the parent of UnicodeDecodeError) is caught on
                # purpose: the utf-16 decoder raises a plain UnicodeError for
                # input without a BOM, which previously escaped and made the
                # gb18030 fallback unreachable.
                last_error = e
        raise Exception(f"Failed to read {filename} with many encodings") from last_error

    text = "\n\n".join(f"{book_name}\n\n {read_book_file(book_name)}" for book_name in books)
    chars = [ch for ch, _ in Counter(text).most_common()]
    vocab_size = len(chars)
    stoi = {c: i for i, c in enumerate(chars)}
    itos = dict(enumerate(chars))

    def encode(_text: str):
        return [stoi[c] for c in _text]

    def decode(_encoded: list):
        return "".join(itos[i] for i in _encoded)

    return text, encode, decode, vocab_size
@dataclass
class Tokens:
    """Result of encoding a piece of text: the list of token ids."""
    ids: list[int]


class Tokenizer(NamedTuple):
    """Bundle of a vocabulary size with an encode and a decode callable.

    ``encode`` / ``decode`` / ``get_vocab_size`` simply forward to the stored
    fields, giving the tuple a method-style interface.
    """
    vocab_size: int
    encode_: Callable[[str], Tokens]
    decode_: Callable[[list[int]], str]

    def encode(self, text: str) -> Tokens:
        """Encode ``text`` with the stored ``encode_`` callable."""
        encoder = self.encode_
        return encoder(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode ``tokens`` with the stored ``decode_`` callable."""
        decoder = self.decode_
        return decoder(tokens)

    def get_vocab_size(self) -> int:
        """Return the stored vocabulary size."""
        size = self.vocab_size
        return size
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping every utf-8 byte value (0-255) to a unicode character.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    The table also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    # Bytes in these three ranges map to the character with the same code point.
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {b: chr(b) for b in printable}
    # Every remaining byte is shifted past 255, in ascending byte order.
    shifted = 0
    for b in range(2**8):
        if b not in table:
            table[b] = chr(2**8 + shifted)
            shifted += 1
    return table


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    A word with fewer than two symbols (including an empty one) has no pairs
    and yields an empty set.
    """
    return set(zip(word, word[1:]))


class Encoder:
    """Byte-level BPE encoder/decoder.

    Args:
        encoder: mapping from BPE token string to token id.
        bpe_merges: sequence of merge pairs; earlier pairs get a lower rank
            and are merged first.
        errors: error handling passed to ``bytes.decode`` when decoding.
    """

    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the ranked merges to ``token`` and return its symbols joined by spaces.

        Results are memoised in ``self.cache``; a token with no symbol pairs
        is returned unchanged (and not cached).
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the pair with the lowest merge rank; unranked pairs sort last.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the remainder as-is.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Split ``text`` with ``self.pat`` and return the list of BPE token ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, decoding the bytes as utf-8 with ``self.errors``."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text


def get_encoder(model_name, models_dir, encoder_file="encoder.json", vocab_file="vocab.bpe"):
    """Load the vocabulary and merge list of a model and build an ``Encoder``.

    Both files are read from ``models_dir/model_name`` as utf-8 text.
    """
    model_dir = os.path.join(models_dir, model_name)
    # Explicit utf-8 so the result does not depend on the platform's default encoding.
    with open(os.path.join(model_dir, encoder_file), "r", encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, vocab_file), "r", encoding="utf-8") as f:
        bpe_data = f.read()
    # The first and last split elements are skipped (presumably a header line
    # and the empty string after the final newline -- confirm against vocab.bpe).
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)
def batch_ops(f: Callable[[W, Arr], Arr], label: str, add_behind: bool, skip_w: bool) -> Callable[[W, Arr], Arr]:
    """Vectorise ``f(w, x)`` over one extra axis using ``vmap``.

    Args:
        f: function taking weights and an input array.
        label: not used by this implementation; the signature mirrors
            ``batch_ops_x``, which uses it as a named axis.
        add_behind: map over the last axis (``-1``) when true, otherwise over
            the leading axis (``0``); the same axis is used for the output.
        skip_w: when true the weights are passed to every call unmapped
            (``in_axes`` entry ``None``); otherwise weights and input are both
            mapped along the chosen axis.

    Returns:
        The vectorised function.
    """
    axis = -1 if add_behind else 0
    in_axes = (None, axis) if skip_w else axis
    return vmap(f, in_axes, axis)
-> Linear.Weights: 110 | return cast(Linear.Weights, config_weights_check(self, w)) 111 | 112 | def __init__(self, config: Config): 113 | self.config = config 114 | assert config.n_in is not None 115 | assert config.n_out is not None 116 | self.n_in = config.n_in 117 | self.n_out = config.n_out 118 | 119 | 120 | class LN: 121 | class Weights(TypedDict): 122 | w: Arr 123 | b: Arr 124 | 125 | # x: self.shape 126 | @jit_f 127 | def f(self, w: LN.Weights, x: Arr) -> Arr: 128 | o = x - mean(x, axis=self.config.non_norm_dims, keepdims=True) 129 | i = o * rsqrt(var(x, axis=self.config.non_norm_dims, keepdims=True) + self.eps) 130 | return w['w'] * i + w['b'] 131 | 132 | class Config(NamedTuple): 133 | eps: Optional[float] = None 134 | # all the other dimensions are normalized 135 | norm_dims: Optional[tuple[int, ...]] = None 136 | x_shape: Optional[tuple[Optional[int], ...]] = None 137 | w_save_name: str = "w" 138 | b_save_name: str = "b" 139 | w_init: Literal[0] = 0 140 | b_init: Literal[0] = 0 141 | save_name: str = "ln" 142 | norm_dim_name: str = "norm_dim" 143 | 144 | def make(self) -> LN: 145 | check_config(self) 146 | return LN(self) 147 | 148 | @property 149 | def non_norm_dims(self) -> tuple[int, ...]: 150 | assert self.norm_dims is not None 151 | assert self.x_shape is not None 152 | return tuple(i for i in range(len(self.x_shape)) if i not in self.norm_dims) 153 | 154 | @property 155 | def weights(self) -> WeightConfigDict: 156 | assert self.norm_dims is not None, 'norm_dims must be specified' 157 | assert self.x_shape is not None, 'x_shape must be specified' 158 | assert self.eps is not None, 'eps must be specified' 159 | non_norm_shape = tuple(self.x_shape[i] for i in range(len(self.x_shape)) if i not in self.norm_dims) 160 | non_norm_shape = cast(tuple[int, ...], non_norm_shape) 161 | return dict( 162 | w=WeightConfig( 163 | save_name=self.w_save_name, 164 | shape=non_norm_shape, 165 | init=self.w_init 166 | ), 167 | b=WeightConfig( 168 | 
class WeightConfig(NamedTuple):
    """Description of a single weight array: its save name, shape and initialiser.

    ``init`` is either a number (the array is filled with that constant) or
    one of the generator names handled in ``make``.
    """
    # NOTE(review): the original comment here reads "from in to out" --
    # presumably the axis order of matrix shapes, e.g. (n_in, n_out); confirm.
    save_name: str
    shape: tuple[int, ...]
    init: Union[ArrayGen, int, float] = "kaiming"
    scale: float = 1

    def make(self, rng_key: SafeKey) -> Arr:
        """Create the initial array for this weight.

        Numeric ``init`` values ignore ``rng_key`` and ``scale``; named
        initialisers receive ``rng_key``, ``scale`` and ``shape``.

        Raises:
            NotImplementedError: if ``init`` is not a number or a known name.
        """
        init = self.init
        if isinstance(init, (int, float)):
            return jnp.full(self.shape, float(init))
        if init == 'kaiming':
            return kaiming_init(rng_key, self.scale, self.shape)
        if init == 'embedding':
            return embedding_init(rng_key, self.scale, self.shape)
        if init == 'normal':
            return normal_init(rng_key, self.scale, self.shape)
        if init == 'dropout':
            return dropout_gen(rng_key, self.scale, self.shape)
        raise NotImplementedError("unsupported init type")
def check_config(config: NamedTuple) -> None:
    """Verify that no field of ``config`` is still ``None``.

    Fields are examined in declaration order and the first missing one is
    reported.

    Raises:
        ValueError: naming the first field whose value is ``None``.
    """
    missing = next((field for field, value in config._asdict().items() if value is None), None)
    if missing is not None:
        raise ValueError(f"Missing config '{missing}' in {config.__class__}")
def squeeze_side(s: tuple[int, ...]) -> tuple[int, ...]:
    """Strip size-1 dimensions from both ends of a shape.

    Leading ones are removed first, then trailing ones; interior ones are
    kept and at least one dimension always remains.

    Raises:
        AssertionError: if ``s`` is empty.
    """
    assert len(s) > 0, "Empty shape cannot be squeezed"
    start, stop = 0, len(s)
    # Drop leading ones while more than one dimension is left.
    while stop - start > 1 and s[start] == 1:
        start += 1
    # Then drop trailing ones under the same condition.
    while stop - start > 1 and s[stop - 1] == 1:
        stop -= 1
    return s[start:stop]
def load_weights(weight_configs: WeightConfigDict, weights_getter: Callable[[str], Arr],
                 prefix: str = "") -> WeightsTree:
    """Fetch every configured weight through ``weights_getter``.

    Args:
        weight_configs: mapping from weight name to its ``WeightConfig``.
        weights_getter: callable returning the array stored under a full name.
        prefix: string prepended to each config's ``save_name`` to form the
            full name passed to ``weights_getter``.

    Returns:
        A dict keyed by weight name (not save name) holding the fetched arrays.
    """

    def fetch(name: str, weight_config: WeightConfig) -> Arr:
        # Validate right before fetching so entries are processed one at a time, in order.
        assert isinstance(weight_config,
                          WeightConfig), f"weight_config {name} is not a WeightConfig, but {type(weight_config)}"
        return weights_getter(f"{prefix}{weight_config.save_name}")

    return {name: fetch(name, weight_config) for name, weight_config in weight_configs.items()}
def init_weight_parts(parts: PartsDict, rng_key: SafeKey) -> WeightsTree:
    """Initialise the weights of every sub-module listed in ``parts``.

    ``rng_key`` is split once per part.  A part that is a single
    ``ModuleConfig`` is initialised with its key; a part that is a list of
    ``ModuleConfig`` has its key split again, one sub-key per list element.

    Args:
        parts: mapping from part name to a module config or a non-empty list
            of module configs.
        rng_key: key to derive all per-part keys from.

    Returns:
        A dict with one entry per part: a weights tree, or a list of weights
        trees for list parts.
    """
    w: WeightsTree = {}
    all_keys = rng_key.split(len(parts))
    for key, (name, part_config) in zip(all_keys, parts.items()):
        if isinstance(part_config, ModuleConfig):
            w[name] = init_weight_module(part_config, key)
        else:
            err_msg = f"Config {name} should be a ModuleConfig or a non-empty list of ModuleConfigs, got {part_config}"
            assert isinstance(part_config, list), err_msg
            assert len(part_config) > 0, err_msg
            assert isinstance(part_config[0], ModuleConfig), err_msg
            subkeys = key.split(len(part_config))
            # The list index was never used, so the elements are paired with their sub-keys directly.
            w[name] = [init_weight_module(part, subkey)
                       for subkey, part in zip(subkeys, part_config)]
    return w
32 | def f(self, w: GptMha.Weights, x: Arr) -> Arr: 33 | assert_shape(x, (self.T, self.n_channels)) 34 | 35 | q, k, v = rearrange(self.QKV_linearf(w['QKV_linear'], x), 36 | 'T (qkv n_heads dim_heads) -> qkv n_heads T dim_heads', 37 | qkv=3, 38 | n_heads=self.n_heads, 39 | dim_heads=self.dim_heads) 40 | # extension_shape = ['n_head', ...] 41 | # attended = xmap(self.causal_dot_attention, [extension_shape] * 3, extension_shape)(q, k, v) 42 | attended = vmap(self.causal_dot_attention, (0, 0, 0), 0)(q, k, v) 43 | assert_shape(attended, (self.n_heads, self.T, self.dim_heads)) 44 | concatenated = jnp.concatenate(attended, -1) 45 | assert_shape(concatenated, (self.T, self.n_channels)) 46 | 47 | result = self.linearf(w['linear'], concatenated) 48 | 49 | assert_shape(result, (self.T, self.n_channels)) 50 | return result 51 | 52 | def f_debug(self, w: GptMha.Weights, x: Arr) -> dict[str, Arr]: 53 | assert_shape(x, (self.T, self.n_channels)) 54 | 55 | qs, ks, vs = rearrange(self.QKV_linearf(w['QKV_linear'], x), 56 | 'T (qkv n_heads dim_heads) -> qkv n_heads T dim_heads', 57 | qkv=3, 58 | n_heads=self.n_heads, 59 | dim_heads=self.dim_heads) 60 | mask = self.get_mask(qs.shape[1]) 61 | attended_list = [] 62 | attn_maps = [] 63 | attn_maps_raw = [] 64 | for q, k, v in zip(qs, ks, vs): 65 | attn_map_raw = q @ k.T 66 | attn_maps_raw.append(attn_map_raw) 67 | attn_map = softmax(attn_map_raw / self.scale + mask) 68 | attn_maps.append(attn_map) 69 | attended_list.append(attn_map @ v) 70 | attended = jnp.stack(attended_list) 71 | assert_shape(attended, (self.n_heads, self.T, self.dim_heads)) 72 | concatenated = jnp.concatenate(attended, -1) 73 | assert_shape(concatenated, (self.T, self.n_channels)) 74 | 75 | result = self.linearf(w['linear'], concatenated) 76 | 77 | assert_shape(result, (self.T, self.n_channels)) 78 | return dict( 79 | attn=jnp.stack(attn_maps), 80 | attn_raw=jnp.stack(attn_maps_raw), 81 | x_after_mha=result) 82 | 83 | class Config(NamedTuple): 84 | n_channels: 
Optional[int] = None 85 | n_heads: Optional[int] = None 86 | n_seq: Optional[Union[int, Literal['dynamic']]] = None 87 | inf_mask: float = -1e10 88 | linear: Linear.Config = Linear.Config() 89 | QKV_linear: Linear.Config = Linear.Config() 90 | 91 | save_name: str = "mha" 92 | 93 | @property 94 | @cache 95 | def T(self) -> Optional[int]: 96 | if self.n_seq == 'dynamic': 97 | return None 98 | else: 99 | return self.n_seq 100 | 101 | def fill(self) -> GptMha.Config: 102 | assert self.n_channels is not None, 'n_channels must be set' 103 | new = self._replace(linear=self.linear._replace(n_in=self.n_channels, n_out=self.n_channels), 104 | QKV_linear=self.QKV_linear._replace(n_in=self.n_channels, n_out=3 * self.n_channels)) 105 | check_config(new) 106 | return new 107 | 108 | def make(self) -> GptMha: 109 | return GptMha(self.fill()) 110 | 111 | @property 112 | def dim_heads(self) -> int: 113 | assert self.n_channels is not None 114 | assert self.n_heads is not None 115 | return self.n_channels // self.n_heads 116 | 117 | @property 118 | def weights(self) -> WeightConfigDict: 119 | return {} 120 | 121 | @property 122 | def parts(self) -> PartsDict: 123 | filled = self.fill() 124 | return dict( 125 | linear=filled.linear, 126 | QKV_linear=filled.QKV_linear 127 | ) 128 | 129 | def weights_check(self, w: WeightsTree) -> GptMha.Weights: 130 | return cast(GptMha.Weights, config_weights_check(self, w)) 131 | 132 | def __init__(self, config: Config): 133 | assert config.n_channels is not None 134 | assert config.n_heads is not None 135 | assert config.dim_heads is not None 136 | self.n_channels = config.n_channels 137 | self.n_heads = config.n_heads 138 | self.T = config.T 139 | self.dim_heads = config.dim_heads 140 | self.config = config 141 | 142 | self.linear = config.linear.make() 143 | self.QKV_linear = config.QKV_linear.make() 144 | self.scale = math.sqrt(self.dim_heads) 145 | self.linearf = for_all_T(self.linear.f) 146 | self.QKV_linearf = for_all_T(self.QKV_linear.f) 147 
class GptFfn:
    """Position-wise feed-forward block: ``linear2(gelu(linear1(x)))``."""

    class Weights(TypedDict):
        linear1: Linear.Weights
        linear2: Linear.Weights

    @jit_f
    def f(self, w: GptFfn.Weights, x: Arr) -> Arr:
        """Apply the feed-forward block to one vector of shape ``(n_channels,)``."""
        assert_shape(x, (self.n_channels,))
        result = self.linear2.f(w['linear2'], gelu(self.linear1.f(w['linear1'], x)))
        assert_shape(result, (self.n_channels,))
        return result

    class Config(NamedTuple):
        n_channels: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        linear1: Linear.Config = Linear.Config()
        linear2: Linear.Config = Linear.Config()
        save_name: str = "ffn"

        @property
        def T(self) -> Optional[int]:
            """Sequence length, or ``None`` when ``n_seq`` is ``'dynamic'``."""
            # No memoisation: the computation is trivial, and a cache keyed on
            # the config would keep every config instance alive.  This also
            # matches GptBlock.Config.T.
            if self.n_seq == 'dynamic':
                return None
            else:
                return self.n_seq

        def fill(self) -> GptFfn.Config:
            """Return a copy whose linear configs are sized from ``n_channels``.

            ``linear1`` maps ``n_channels -> 4 * n_channels`` and ``linear2``
            maps back; the filled config is validated with ``check_config``.
            """
            assert self.n_channels is not None
            new = self._replace(
                linear1=self.linear1._replace(n_in=self.n_channels, n_out=self.n_channels * 4),
                linear2=self.linear2._replace(n_in=self.n_channels * 4, n_out=self.n_channels))
            check_config(new)
            return new

        def make(self) -> GptFfn:
            """Build the module from the filled config."""
            return GptFfn(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            # The block has no weights of its own; all weights live in its parts.
            return {}

        @property
        def parts(self) -> PartsDict:
            filled = self.fill()
            return dict(
                linear1=filled.linear1,
                linear2=filled.linear2
            )

        def weights_check(self, w: WeightsTree) -> GptFfn.Weights:
            return cast(GptFfn.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        self.n_channels = config.n_channels
        self.linear1 = config.linear1.make()
        self.linear2 = config.linear2.make()
def f_debug(self, w: GptBlock.Weights, x: Arr) -> dict[str, Arr]:
    """Run the block while recording its intermediate activations.

    Mirrors the residual structure of ``f`` (layer norm -> attention ->
    residual add, then layer norm -> feed-forward -> residual add) and
    collects each stage under a fixed key.

    :param w: weights of this block.
    :param x: block input.
    :return: mapping holding ``'x0'`` (the input), ``'x_before_mha'``, every
        entry reported by ``mha.f_debug``, ``'x_before_ffn'``,
        ``'x_after_ffn'`` and ``'x'`` (the block output).
    """
    trace = {}
    trace['x0'] = x
    normed = self.ln1f(w['ln1'], x)
    trace['x_before_mha'] = normed
    mha_trace = self.mha.f_debug(w['mha'], normed)
    trace.update(mha_trace)
    # First residual connection: block input plus attention output.
    after_attention = x + mha_trace['x_after_mha']
    normed = self.ln2f(w['ln2'], after_attention)
    trace['x_before_ffn'] = normed
    ffn_out = self.ffnf(w['ffn'], normed)
    trace['x_after_ffn'] = ffn_out
    # Second residual connection: post-attention stream plus feed-forward output.
    trace['x'] = after_attention + ffn_out
    return trace
class GptDecoder:
    """A stack of ``n_blocks`` GPT blocks applied one after another."""

    class Weights(TypedDict):
        # One weight tree per block, in application order.
        blocks: List[GptBlock.Weights]

    @jit_f
    def f(self, w: GptDecoder.Weights, x: Arr) -> Arr:
        """Feed ``x`` through every block in order.

        :param w: decoder weights; ``w['blocks'][i]`` belongs to block ``i``.
        :param x: input of shape ``(T, n_channels)``.
        :return: output of shape ``(T, n_channels)``.
        """
        assert_shape(x, (self.T, self.n_channels))
        for blk, blk_w in zip(self.blocks, w['blocks']):
            x = blk.f(blk_w, x)
            assert_shape(x, (self.T, self.n_channels))
        return x

    def f_debug(self, w: GptDecoder.Weights, x: Arr) -> tuple[Arr, dict[str, Arr]]:
        """Feed ``x`` through every block and collect each block's debug trace.

        :param w: decoder weights; ``w['blocks'][i]`` belongs to block ``i``.
        :param x: input of shape ``(T, n_channels)``.
        :return: ``(output, vecs)`` where ``vecs[key]`` stacks the per-block
            values of ``key`` along a new leading axis (one entry per block).
            ``vecs`` is empty when the decoder has no blocks.
        """
        assert_shape(x, (self.T, self.n_channels))
        view_vec = []
        for blk, blk_w in zip(self.blocks, w['blocks']):
            return_dict = blk.f_debug(blk_w, x)
            view_vec.append(return_dict)
            x = return_dict['x']
            assert_shape(x, (self.T, self.n_channels))
        if not view_vec:
            # No block ran, so there is no trace to stack (view_vec[0] would raise).
            return x, {}
        vecs = {key: jnp.stack([dict_item[key] for dict_item in view_vec]) for key in view_vec[0]}
        return x, vecs

    class Config(NamedTuple):
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        blocks: GptBlock.Config = GptBlock.Config()

        save_name: str = 'gpt_decoder'

        @property
        def T(self) -> Optional[int]:
            """Static sequence length, or ``None`` when ``n_seq`` is ``'dynamic'``."""
            # Not cached: a single comparison is cheaper than hashing the config,
            # and a cache keyed on ``self`` would keep every config instance alive.
            if self.n_seq == 'dynamic':
                return None
            return self.n_seq

        def fill(self) -> GptDecoder.Config:
            """Return a copy whose block config inherits eps, channel, head and sequence settings."""
            new = self._replace(blocks=self.blocks._replace(eps=self.eps, n_channels=self.n_channels,
                                                            n_heads=self.n_heads, n_seq=self.n_seq).fill())
            check_config(new)
            return new

        def make(self) -> GptDecoder:
            """Build the module from the filled configuration."""
            return GptDecoder(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """The decoder owns no weights directly; they all live in ``parts``."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """The (shared, immutable) filled block config repeated ``n_blocks`` times."""
            filled = self.fill()
            assert filled.blocks is not None
            assert filled.n_blocks is not None
            return dict(
                blocks=[filled.blocks] * filled.n_blocks
            )

        def weights_check(self, w: WeightsTree) -> GptDecoder.Weights:
            """Validate a raw weight tree against this config and narrow its type."""
            return cast(GptDecoder.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        assert config.n_blocks is not None
        self.T = config.T
        self.n_channels = config.n_channels
        self.blocks = [config.blocks.make() for _ in range(config.n_blocks)]
def f_debug(self, w: Gpt.Weights, x: Arr) -> tuple[Arr, dict[str, Arr]]:
    """Compute logits like ``f`` while also returning the decoder's debug traces.

    :param w: model weights.
    :param x: token ids of shape ``(T,)``.
    :return: ``(logits, save_vecs)`` where ``logits`` is the final
        layer-normed activation multiplied by the transposed token embedding
        and ``save_vecs`` is whatever ``decoder.f_debug`` recorded.
    """
    assert_shape(x, (self.T,))
    n_positions = x.shape[0]
    token_vecs = w['token_embedding'][x, :]
    position_vecs = w['positional_encoding'][:n_positions, :]
    hidden = token_vecs + position_vecs
    assert_shape(hidden, (self.T, self.n_channels))
    decoder_out, save_vecs = self.decoder.f_debug(w['decoder'], hidden)
    normed = self.ln.f(w['ln'], decoder_out)
    assert_shape(normed, (self.T, self.n_channels))
    # The output projection reuses the token embedding matrix (transposed).
    logits = normed @ w['token_embedding'].T
    return logits, save_vecs
def get_positional_encoding(max_len: int, d_model: int):
    """Build a fixed sinusoidal positional-encoding table.

    Column ``2i`` holds ``sin(pos * w_i)`` and column ``2i + 1`` holds
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``.

    :param max_len: number of positions (rows) to generate.
    :param d_model: number of channels (columns); may be even or odd.
    :return: array of shape ``(max_len, d_model)``.
    """
    position = jnp.expand_dims(jnp.arange(0, max_len), 1)
    div_term = jnp.exp(
        jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model)
    )
    angles = position * div_term
    pe = jnp.zeros((max_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    # An odd d_model has one fewer odd-indexed column than even-indexed ones,
    # so drop the surplus frequency; for an even d_model the slice is a no-op.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, :d_model // 2]))
    return pe
# %%
from safetensors import safe_open
from pathlib import Path

# %%
# Inspection cell: list every tensor stored in the checkpoint with its shape,
# then print one tensor and the array type the "flax" framework hands back.
path = Path("/Data/lm_models/gpt2")
weights_file = path / "model.safetensors"
with safe_open(weights_file, framework="flax", device="cpu") as f:
    for key in f.keys():
        t = f.get_tensor(key)
        print(key, t.shape)

    # Still inside the ``with`` block: the handle must be open to read a tensor.
    print(f.get_tensor("h.3.attn.bias"))
    print(type(t))
def debug(inputs) -> dict[str, Arr]:
    """Run the debug forward pass, print the predicted next token and save the traces.

    :param inputs: sequence of token ids for the prompt.
    :return: the recorded intermediate tensors, which are also written to
        ``../saves/view_vec2_dict_jit`` (relative to the working directory).
    """
    logits, to_save = jit(gpt_.f_debug)(checked_weights, jnp.array(inputs))
    # Greedy prediction for the position after the last prompt token.
    out = encoder.decode([int(jnp.argmax(logits[-1]))])
    print(out)
    save_dir = "../saves"
    # save_file returns None, so return the tensors themselves to honour the
    # declared return type instead of forwarding its result.
    save_file(to_save, f'{save_dir}/view_vec2_dict_jit')
    return to_save
def gpt_loss_batch(forward: Callable[[Gpt.Weights, Arr], Arr], w: Gpt.Weights, batch: BatchType) -> Arr:
    """Mean softmax cross-entropy of ``forward`` on one batch.

    :param forward: single-example model function mapping ``(weights, inputs)`` to logits.
    :param w: model weights passed to ``forward``.
    :param batch: ``(inputs, labels)`` pair; labels are integer class ids.
    :return: scalar loss averaged over every element of the batch.
    """
    inputs, labels = batch
    # batch_fy presumably lifts ``forward`` over the leading batch axis — confirm in clean_frame.
    batched_forward = batch_fy(forward)
    logits = batched_forward(w, jnp.array(inputs))
    losses = softmax_cross_entropy_with_integer_labels(logits, jnp.array(labels))
    return losses.mean()
102 | )).fill() 103 | 104 | raw_weights = init_weight_module(gpt_config_, next(key_gen)) 105 | raw_weights['positional_encoding'] = gpt.get_positional_encoding(experimental_params['block_size'], 106 | experimental_params['n_channels']) 107 | init_weights_ = gpt_config_.weights_check(raw_weights) 108 | 109 | if experimental_params['train']['optimizer'] == 'adam': 110 | adam_config = experimental_params['train']['adam'] 111 | optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'], 112 | b1=adam_config['beta1'], 113 | b2=adam_config['beta2'], 114 | eps=adam_config['eps']) 115 | elif experimental_params['train']['optimizer'] == 'lion': 116 | lion_config = experimental_params['train']['lion'] 117 | optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'], 118 | b1=lion_config['beta1'], 119 | b2=lion_config['beta2'], 120 | weight_decay=lion_config['weight_decay']) 121 | elif experimental_params['train']['optimizer'] == 'adamw': 122 | adamw_config = experimental_params['train']['adamw'] 123 | optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'], 124 | b1=adamw_config['beta1'], 125 | b2=adamw_config['beta2'], 126 | eps=adamw_config['eps']) 127 | else: 128 | raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported") 129 | 130 | model = gpt_config_.make() 131 | loss_f = partial(gpt_loss_batch, model.f) 132 | # noinspection PyArgumentList 133 | # cuz it's a NamedTuple 134 | train_config_ = TrainConfig(loss_fn=loss_f, 135 | optimiser=optimizer_) 136 | # noinspection PyArgumentList 137 | # cuz it's a NamedTuple 138 | train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_, 139 | opt_state=optimizer_.init(init_weights_)) 140 | 141 | wandb.init( 142 | project="inside-transformer", 143 | config=experimental_params, 144 | ) 145 | keys_ = next(key_gen).split(max_iters) 146 | for step in range(max_iters): 147 | batch_ = batch_config_.sample(train_data, keys_[step]) 148 | if step % eval_interval == 0: 
149 | loss = train_config_.loss_fn(train_state_.weights, batch_) 150 | print(f"===step {step} is an eval step===") 151 | print(f"before step {step}, batch loss {loss}") 152 | 153 | train_state_ = train_config_.train1(train_state_, batch_) 154 | if step % eval_interval == 0: 155 | loss = train_config_.loss_fn(train_state_.weights, batch_) 156 | print(f"after step {step}, batch loss {loss}") 157 | results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, 158 | {'train': partial(batch_config_.sample, train_data), 159 | 'val': partial(batch_config_.sample, valid_data)}) 160 | # generate_f = jax.jit(partial(dynamic_model_f, train_state_.weights)) 161 | # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10, 162 | # max_len=batch_config_.block_size) 163 | generate_f = partial(model.f, train_state_.weights) 164 | # TODO fix generation/ add temperature 165 | generated = nlp_utils.generate_static(generate_f, [0], 166 | n_tokens_to_generate=batch_config_.block_size - 1, 167 | max_len=batch_config_.block_size) 168 | wandb.log({"train_loss": results['train'], 169 | "validation_loss": results['val'], 170 | "batch_loss": loss, 171 | "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size, 172 | 'generated': wandb.Html(f"{decode(generated)}")}) 173 | print(decode(generated), flush=True) 174 | 175 | wandb.finish() 176 | # TODO: add trainable weights 177 | # TODO: add weight decay 178 | -------------------------------------------------------------------------------- /gpt_recon/train_notebook.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from functools import partial 5 | from pprint import pprint 6 | from typing import Callable 7 | 8 | import jax.numpy as jnp 9 | import optax 10 | from jax import random 11 | from optax import softmax_cross_entropy_with_integer_labels 12 | 13 | import gpt_recon.gpt as gpt 14 | import nlp_utils 15 | 
def go(c: ModuleConfig, x: Arr, key: SafeKey) -> Arr:
    """Initialise fresh weights for ``c`` and run the module built from it once on ``x``.

    :param c: module configuration to initialise and instantiate.
    :param x: input handed to the module's ``f``.
    :param key: random key used for weight initialisation.
    :return: the module's output for ``x``.
    """
    fresh_weights = init_weight_module(c, key)
    module = c.make()
    return module.f(fresh_weights, x)
n_tokens=vocab_size_).fill() 74 | gpt_dynamic_config_ = gpt_config_._replace(n_seq='dynamic') 75 | dynamic_model = gpt_dynamic_config_.make() 76 | # result = go(gpt_config_, jnp.ones((1024,), dtype=jnp.int32), next(key_gen)) 77 | # print(result.shape) 78 | 79 | init_weights_ = gpt_config_.weights_check(init_weight_module(gpt_config_, next(key_gen))) 80 | optimizer_ = optax.adam(learning_rate=learning_rate_) 81 | 82 | # noinspection PyArgumentList 83 | # cuz it's a NamedTuple 84 | train_config_ = TrainConfig(model=gpt_config_.make(), 85 | loss_fn_in=gpt_loss_batch, 86 | optimiser=optimizer_) 87 | # noinspection PyArgumentList 88 | # cuz it's a NamedTuple 89 | train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_, 90 | opt_state=optimizer_.init(init_weights_)) 91 | 92 | # %% 93 | # precompile for better profiling 94 | # batch_ = batch_config_.sample(train_data, next(key_gen)) 95 | # train_state_ = train_config_.train1(train_state_, batch_) 96 | # batch_ = batch_config_.sample(train_data, next(key_gen)) 97 | # train_state_ = train_config_.train1(train_state_, batch_) 98 | # batch_ = batch_config_.sample(train_data, next(key_gen)) 99 | # train_state_ = train_config_.train1(train_state_, batch_) 100 | # # print([x.device() for x in tree_flatten(train_state_.weights)[0]]) 101 | # train_config_.estimate_loss(eval_iters, 102 | # key_gen, 103 | # train_state_, 104 | # batch_config_, 105 | # {'train': train_data, 'val': valid_data}) 106 | # train_config_.estimate_loss(eval_iters, 107 | # key_gen, 108 | # train_state_, 109 | # batch_config_, 110 | # {'train': train_data, 'val': valid_data}) 111 | 112 | # %% 113 | 114 | keys_ = next(key_gen).split(max_iters) 115 | #with jax.profiler.trace("./jax-trace", create_perfetto_link=True): 116 | for step in range(max_iters): 117 | batch_ = batch_config_.sample(train_data, keys_[step]) 118 | if step % eval_interval == 0: 119 | loss = train_config_.loss_fn(train_state_.weights, batch_) 120 | print(f"===step {step} is 
an eval step===") 121 | print(f"before step {step}, batch loss {loss}") 122 | train_config_.estimate_loss(eval_iters, 123 | key_gen, 124 | train_state_, 125 | batch_config_, 126 | {'train': train_data, 'val': valid_data}) 127 | 128 | # print(make_jaxpr(partial(jax_calc_updates, train_state_.optimiser, train_state_.loss_fn))(train_state_.weights, batch_, train_state_.opt_state)) 129 | train_state_ = train_config_.train1(train_state_, batch_) 130 | if step % eval_interval == 0: 131 | loss = train_config_.loss_fn(train_state_.weights, batch_) 132 | print(f"after step {step}, batch loss {loss}") 133 | train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_, 134 | {'train': train_data, 'val': valid_data}) 135 | # generate_f = jax.jit(partial(dynamic_model.f, train_state_.weights)) 136 | # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10, 137 | # max_len=batch_config_.block_size) 138 | generate_f = partial(train_config_.model.f, train_state_.weights) 139 | # TODO fix generation/ add temperature 140 | generated = nlp_utils.generate_static(generate_f, [0], 141 | n_tokens_to_generate=batch_config_.block_size - 1, 142 | max_len=batch_config_.block_size) 143 | print(decode(generated), flush=True) 144 | 145 | -------------------------------------------------------------------------------- /gpt_recon/tune_gpt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Callable 7 | 8 | import jax.numpy as jnp 9 | import optax 10 | # import wandb 11 | from optax import softmax_cross_entropy_with_integer_labels 12 | from safetensors import safe_open 13 | 14 | import gpt_recon.gpt as gpt 15 | import nlp_utils 16 | from gpt_recon.clean_frame import batch_fy, LN, Linear 17 | from gpt_recon.clean_frame_utils import Arr, load_config 18 | from custom_dataset import load_jax_cached 19 
def gpt_loss_batch(forward: Callable[[Gpt.Weights, Arr], Arr], w: Gpt.Weights, batch: BatchType) -> Arr:
    """Return the batch-averaged softmax cross-entropy between model logits and integer labels.

    :param forward: single-example model function mapping ``(weights, inputs)`` to logits.
    :param w: model weights passed to ``forward``.
    :param batch: ``(inputs, labels)`` pair.
    :return: scalar mean loss.
    """
    inputs, labels = batch
    input_arr = jnp.array(inputs)
    label_arr = jnp.array(labels)
    logits = batch_fy(forward)(w, input_arr)
    return softmax_cross_entropy_with_integer_labels(logits, label_arr).mean()
batch_size=experimental_params['batch_size']) 86 | 87 | gpt_config_ = Gpt.Config(eps=experimental_params['eps'], 88 | n_channels=experimental_params['n_channels'], 89 | n_heads=experimental_params['n_heads'], 90 | # n_seq='dynamic', 91 | n_seq=batch_config_.block_size, 92 | max_seq_len=batch_config_.block_size, 93 | n_blocks=experimental_params['n_blocks'], 94 | n_tokens=vocab_size_, 95 | token_embedding_save_name='wte.weight', 96 | positional_embedding_save_name='wpe.weight', 97 | ln=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_f'), 98 | save_name="", 99 | decoder=GptDecoder.Config( 100 | save_name='h', 101 | blocks=GptBlock.Config( 102 | save_name="", 103 | mha=GptMha.Config( 104 | save_name='attn', 105 | QKV_linear=Linear.Config(save_name='c_attn', w_save_name='weight', 106 | b_save_name='bias'), 107 | linear=Linear.Config(save_name="c_proj", w_save_name='weight', b_save_name='bias'), 108 | ), 109 | ln1=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_1'), 110 | ffn=GptFfn.Config( 111 | save_name='mlp', 112 | linear1=Linear.Config(w_save_name='weight', b_save_name='bias', save_name='c_fc'), 113 | linear2=Linear.Config(w_save_name='weight', b_save_name='bias', 114 | save_name='c_proj'), 115 | ), 116 | ln2=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_2') 117 | ) 118 | )).fill() 119 | 120 | # dataset = "english" 121 | 122 | 123 | path = Path("/Data/lm_models/gpt2") 124 | with safe_open(path / "model.safetensors", framework="flax", device="cpu") as f: 125 | weights_tree_ = load_config(gpt_config_, f.get_tensor) 126 | print(weights_tree_.keys()) 127 | 128 | gpt_ = gpt_config_.make() 129 | init_weights_ = gpt_config_.weights_check(weights_tree_) 130 | 131 | if experimental_params['train']['optimizer'] == 'adam': 132 | adam_config = experimental_params['train']['adam'] 133 | optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'], 134 | b1=adam_config['beta1'], 135 | b2=adam_config['beta2'], 136 | 
eps=adam_config['eps']) 137 | elif experimental_params['train']['optimizer'] == 'lion': 138 | lion_config = experimental_params['train']['lion'] 139 | optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'], 140 | b1=lion_config['beta1'], 141 | b2=lion_config['beta2']) 142 | elif experimental_params['train']['optimizer'] == 'adamw': 143 | adamw_config = experimental_params['train']['adamw'] 144 | optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'], 145 | b1=adamw_config['beta1'], 146 | b2=adamw_config['beta2'], 147 | eps=adamw_config['eps']) 148 | else: 149 | raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported") 150 | 151 | # noinspection PyArgumentList 152 | # cuz it's a NamedTuple 153 | train_config_ = TrainConfig(model=gpt_config_.make(), 154 | loss_fn_in=gpt_loss_batch, 155 | optimiser=optimizer_) 156 | # noinspection PyArgumentList 157 | # cuz it's a NamedTuple 158 | train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_, 159 | opt_state=optimizer_.init(init_weights_)) 160 | 161 | keys_ = next(key_gen).split(max_iters) 162 | for step in range(max_iters): 163 | batch_ = batch_config_.sample(train_data, keys_[step]) 164 | if step % eval_interval == 0: 165 | loss = train_config_.loss_fn(train_state_.weights, batch_) 166 | print(f"===step {step} is an eval step===") 167 | print(f"before step {step}, batch loss {loss}") 168 | 169 | train_state_ = train_config_.train1(train_state_, batch_) 170 | if step % eval_interval == 0: 171 | loss = train_config_.loss_fn(train_state_.weights, batch_) 172 | print(f"after step {step}, batch loss {loss}") 173 | results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_, 174 | {'train': train_data, 'val': valid_data}) 175 | # generate_f = jax.jit(partial(dynamic_model.f, train_state_.weights)) 176 | # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10, 177 | # max_len=batch_config_.block_size) 178 | generate_f 
def heatmap(mat: NDArray) -> alt.Chart:
    """Build an Altair heatmap for a 2-D matrix.

    Every cell ``mat[row, col]`` becomes one rectangle drawn at ``x=col``,
    ``y=row`` (both ordinal axes) whose colour encodes the cell value.

    :param mat: 2-D array; unpacking ``mat.shape`` into two values raises
        ``ValueError`` for any other rank.
    :return: the (unsaved) ``alt.Chart``.
    """
    n_rows, n_cols = mat.shape
    # np.indices enumerates the (row, col) index of every cell in the same
    # C order that mat.flatten() uses, so coordinates and values stay aligned
    # for rectangular matrices too.  The previous
    # meshgrid(arange(n_rows), arange(n_cols)) grids had shape
    # (n_cols, n_rows) and therefore only matched for square inputs.
    rows, cols = np.indices((n_rows, n_cols))
    source = pd.DataFrame({
        'x': cols.flatten(),
        'y': rows.flatten(),
        'z': mat.flatten()
    })

    return alt.Chart(source).mark_rect().encode(
        x='x:O',
        y='y:O',
        color='z:Q'
    )
def heatmap_at(mat: NDArray, display: str):
    """Save a 2-D matrix as an Altair heatmap in ``{display}.html``.

    Every cell ``mat[row, col]`` becomes one rectangle drawn at ``x=col``,
    ``y=row`` (both ordinal axes) whose colour encodes the cell value.

    :param mat: 2-D array; unpacking ``mat.shape`` into two values raises
        ``ValueError`` for any other rank.
    :param display: output file name without the ``.html`` suffix.
    """
    n_rows, n_cols = mat.shape
    # np.indices lists the (row, col) index of each cell in the same C order
    # as mat.flatten().  The previous meshgrid(arange(n_rows), arange(n_cols))
    # grids had shape (n_cols, n_rows), which scrambled the cells of
    # non-square inputs such as the (n_layers, n_tokens) cluster map.
    rows, cols = np.indices((n_rows, n_cols))
    source = pd.DataFrame({
        'x': cols.flatten(),
        'y': rows.flatten(),
        'z': mat.flatten()
    })

    alt.Chart(source).mark_rect().encode(
        x='x:O',
        y='y:O',
        color='z:Q'
    ).save(f'{display}.html')


def heatmap(x):
    """Show ``x`` as an image in a new matplotlib figure."""
    plt.figure()
    plt.imshow(x)
    plt.show()


def plot1(x):
    """Print ``x.shape`` and show ``x`` as a line plot in a new figure."""
    print(x.shape)
    plt.figure()
    plt.plot(x)
    plt.show()
Time files like an" 55 | token_list = prompt.split(" ") 56 | token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)] 57 | # view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768)) 58 | view_vec = aft_layer 59 | 60 | # layers, tokens, channels 61 | plot1(view_vec[:, 0, 0]) 62 | plot1(view_vec[0, :, 0]) 63 | plot1(view_vec[0, 0, :]) 64 | 65 | # %% 66 | # heatmap_at(np.median(view_vec, axis=-1)) 67 | n_layers, n_tokens, n_channels = view_vec.shape 68 | corr_layers = np.array([np.corrcoef(view_vec[:, i, :]) for i in range(n_tokens)]).mean(axis=0) 69 | heatmap_at(corr_layers, "chart1") 70 | corr_tokens = np.array([np.corrcoef(view_vec[i, :, :]) for i in range(n_layers)]).mean(axis=0) 71 | heatmap_at(corr_tokens, "chart2") 72 | 73 | # %% 74 | plot1(np.array([np.corrcoef(view_vec[:, i, :]) for i in range(n_tokens)]).mean(axis=(1, 2))) 75 | plot1(np.array([np.corrcoef(view_vec[i, :, :]) for i in range(n_layers)]).mean(axis=(1, 2))) 76 | plot1(np.array([np.corrcoef(view_vec[:, :, i]) for i in range(n_channels)]).mean(axis=(1, 2))) 77 | plot1(np.array([np.corrcoef(view_vec[:, :, i].T) for i in range(n_channels)]).mean(axis=(1, 2))) 78 | 79 | # %% 80 | all_vecs = view_vec.reshape(-1, n_channels) 81 | # clustering = AgglomerativeClustering(n_clusters=3).fit(all_vecs) 82 | clustering = AgglomerativeClustering(n_clusters=10, linkage="average", metric="cosine").fit(all_vecs) 83 | c = clustering.labels_.reshape(n_layers, n_tokens) 84 | heatmap_at(c, "chart3") 85 | 86 | # %% 87 | # TSNE with first layer 88 | view_vec_dict = safetensors.safe_open('view_vec2_dict', 'flax') 89 | bf_layer = view_vec_dict.get_tensor('x_before_mha') 90 | aft_layer = view_vec_dict.get_tensor('x') 91 | prompt = "Time flies like an arrow ; fruit flies like a banana . 
Time files like an" 92 | token_list = prompt.split(" ") 93 | token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)] 94 | 95 | view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768)) 96 | 97 | n_layers, n_tokens, n_channels = view_vec.shape 98 | assert n_tokens == len(token_list) 99 | all_vecs = view_vec.reshape(-1, n_channels) 100 | # all_vecs = view_vec[:-1, :, :].reshape(-1, n_channels) 101 | # n_layers = n_layers - 1 102 | new = TSNE(n_components=2, learning_rate='auto', 103 | init='random', perplexity=3, metric='cosine').fit_transform(all_vecs) 104 | new = np.concatenate([new, np.array([np.repeat(i, n_tokens) for i in range(n_layers)]).reshape(-1, 1), 105 | np.array([np.arange(n_tokens) for i in range(n_layers)]).reshape(-1, 1)], axis=1) 106 | 107 | source = pd.DataFrame(new, columns=['x', 'y', 'layer', 'token']) 108 | source['token'] = source['token'].apply(lambda x: token_list[int(x)]) 109 | source['layer'] = source['layer'].apply(lambda x: int(x)) 110 | alt.Chart(source).mark_circle().encode( 111 | x='x', 112 | y='y', 113 | size=alt.Size('layer', legend=alt.Legend(type="symbol", symbolLimit=0), type='ordinal'), 114 | # rainbow 115 | color=alt.Color('token', scale=alt.Scale(scheme='rainbow'), legend=alt.Legend(type="symbol", symbolLimit=0)), 116 | ).properties(width=1000, height=1000).save('chart4.svg') 117 | 118 | # TSNE without first layer 119 | -------------------------------------------------------------------------------- /gpt_recon/view_vec.py: -------------------------------------------------------------------------------- 1 | import altair as alt 2 | import numpy as np 3 | import pandas as pd 4 | import safetensors.flax 5 | from sklearn.manifold import TSNE 6 | from pacmap import PaCMAP 7 | from umap import UMAP 8 | from trimap import TRIMAP 9 | 10 | alt.renderers.enable('svg') 11 | 12 | 13 | view_vec_dict = safetensors.safe_open('saves/view_vec2_dict_jit', 'flax') 14 | bf_layer = 
view_vec_dict.get_tensor('x0') 15 | # mid_layer = view_vec_dict.get_tensor('x_after_mha') 16 | mid_layer = view_vec_dict.get_tensor('x_before_ffn') 17 | aft_layer = view_vec_dict.get_tensor('x') 18 | assert (bf_layer[1:, :, :] == aft_layer[:-1, :, :]).all() 19 | 20 | # %% 21 | prompt = "Time flies like an arrow ; fruit flies like a banana . Time files like an" 22 | token_list = prompt.split(" ") 23 | token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)] 24 | 25 | # view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768), order='F') 26 | # view_vec = bf_layer 27 | view_vec = np.concatenate((bf_layer[0, :, :][np.newaxis, :, :], aft_layer), axis=0) 28 | # view_vec = np.stack((bf_layer[0, :, :], aft_layer[-1, :, :]), axis=0) 29 | 30 | # view_vec = np.stack([mid_layer, aft_layer]).reshape((-1, len(token_list), 768), order='F') 31 | # view_vec = np.concatenate((bf_layer[0, :, :][np.newaxis, :, :], view_vec), axis=0) 32 | 33 | n_layers, n_tokens, n_channels = view_vec.shape 34 | assert n_tokens == len(token_list) 35 | 36 | #%% 37 | 38 | 39 | all_vecs = view_vec.reshape(-1, n_channels) 40 | # all_vecs = view_vec[:-1, :, :].reshape(-1, n_channels) 41 | # n_layers = n_layers - 1 42 | # embedding = PaCMAP(distance='angular', n_components=2, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0) 43 | embedding = UMAP(n_neighbors=10, metric='cosine') 44 | # embedding = TRIMAP(distance='cosine') 45 | # embedding = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3, metric='cosine') 46 | new = embedding.fit_transform(all_vecs) 47 | new = np.concatenate([new, np.array([np.repeat(i, n_tokens) for i in range(n_layers)]).reshape(-1, 1), 48 | np.array([np.arange(n_tokens) for i in range(n_layers)]).reshape(-1, 1)], axis=1) 49 | 50 | source = pd.DataFrame(new, columns=['x', 'y', 'layer', 'token']) 51 | source['token'] = source['token'].apply(lambda x: token_list[int(x)]) 52 | source['layer'] = source['layer'].apply(lambda x: 
int(x)) 53 | 54 | lines_all = [] 55 | for i in range(n_tokens): 56 | t = source[source['token'] == token_list[i]] 57 | lines = alt.Chart(t).mark_line().encode( 58 | x='x', 59 | y='y', 60 | order='layer', 61 | color=alt.value("#AAAAAA") 62 | ) 63 | lines_all.append(lines) 64 | 65 | dots = alt.Chart(source).mark_circle().encode( 66 | x='x', 67 | y='y', 68 | size=alt.Size('layer', legend=alt.Legend(type="symbol", symbolLimit=0), type='ordinal'), 69 | # rainbow 70 | color=alt.Color('token', scale=alt.Scale(scheme='rainbow'), legend=alt.Legend(type="symbol", symbolLimit=0)), 71 | ).properties(width=1000, height=1000) 72 | 73 | final = sum(lines_all, alt.LayerChart()) + dots 74 | final.save('lines.svg') 75 | 76 | # TSNE without first layer 77 | -------------------------------------------------------------------------------- /labels.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Protocol, runtime_checkable, Union 5 | 6 | 7 | @dataclass(frozen=True) 8 | class Labels: 9 | tags: frozenset[str] 10 | 11 | def fmt(self) -> str: 12 | return ', '.join(self.tags) 13 | 14 | def check(self) -> None: 15 | if not isinstance(self.tags, frozenset): 16 | raise ValueError(f"tags must be a frozenset, not {type(self.tags)}") 17 | 18 | def covers(self, other: Labels) -> bool: 19 | return self.tags >= other.tags 20 | 21 | def __contains__(self, item: str) -> bool: 22 | return item in self.tags 23 | 24 | def __len__(self) -> int: 25 | return len(self.tags) 26 | 27 | def __add__(self, other: Union[str, Labels]) -> Labels: 28 | if isinstance(other, str): 29 | return Labels(tags=self.tags | {other}) 30 | else: 31 | return Labels(tags=self.tags | other.tags) 32 | 33 | def __radd__(self, other): 34 | return self + other 35 | 36 | @classmethod 37 | def empty(cls): 38 | return cls(frozenset()) 39 | 40 | @classmethod 41 | def from_strs(cls, *tags: str) -> 
class LRABatchConfig(NamedTuple):
    """Batching configuration built around the S5 Long-Range-Arena data loaders."""
    block_size: int
    batch_size: int
    # Mapping split name ('train' / 'val' / 'test') -> loader, as built by from_s5.
    s5_dataloaders: dict[str, DataLoader]
    train_size: int
    n_classes_in: int
    n_classes_out: int

    @classmethod
    def from_s5(cls, batch_size: int, cache_path: Path, dataset_name: str, seed: int = 0):
        """Create the config from one of the dataset factories in ``Datasets``.

        :param batch_size: forwarded to the factory as ``bsz``.
        :param cache_path: first positional argument of the factory.
        :param dataset_name: key into ``Datasets``.
        :param seed: forwarded to the factory as ``seed``.
        """
        create_dataset_fn = Datasets[dataset_name]
        # aux_dataloaders is returned by the factory but not used here.
        trainloader, valloader, testloader, aux_dataloaders, n_classes, seq_len, in_dim, train_size = create_dataset_fn(
            cache_path, seed=seed, bsz=batch_size)
        return cls(block_size=seq_len, batch_size=batch_size,
                   s5_dataloaders={'train': trainloader, 'val': valloader, 'test': testloader}, train_size=train_size,
                   n_classes_in=in_dim, n_classes_out=n_classes)

    @property
    def samplers(self) -> dict[str, Callable[[SafeKey], MixedLenBatchType]]:
        """Per-split sampler callables returning ``(x, y, lengths)`` arrays.

        Each access builds fresh iterators.  The ``key`` argument of a sampler
        is accepted but ignored.  A sampler raises ``StopIteration`` once its
        underlying iterator is exhausted; it does not restart.
        """
        def get_sampler(loader: DataLoader) -> Callable[[SafeKey], MixedLenBatchType]:
            # NOTE(review): this walks `loader.batch_sampler`, not the loader
            # itself, and unpacks every item as (x, y, l).  Confirm that the
            # sampler really yields such triples; a stock PyTorch batch
            # sampler would yield lists of indices instead.
            loader_sampler = iter(loader.batch_sampler)

            def sampler(key: SafeKey):
                x, y, l = next(loader_sampler)
                return np.array(x), np.array(y), np.array(l['lengths'])

            return sampler

        return {k: get_sampler(v) for k, v in self.s5_dataloaders.items()}

    @property
    def dataloaders(self) -> dict[str, Iterator[MixedLenBatchType]]:
        """Per-split endless streams of ``(x, y, lengths)`` array batches.

        Each stream restarts its loader whenever a pass over it finishes.
        Previously ``next()`` on the exhausted iterator raised
        ``StopIteration`` inside the generator, which Python turns into a
        ``RuntimeError`` (PEP 479) at the end of the first pass.
        """
        def get_dataloader(loader: DataLoader) -> Iterator[MixedLenBatchType]:
            def data_generator() -> Iterator[MixedLenBatchType]:
                while True:
                    produced = False
                    for x, y, l in loader:
                        produced = True
                        yield np.array(x), np.array(y), np.array(l['lengths'])
                    if not produced:
                        # An empty loader would otherwise spin forever.
                        return

            return data_generator()

        return {k: get_dataloader(v) for k, v in self.s5_dataloaders.items()}
parse_rwkv_weight(init_weights_raw.keys(), init_weights_raw.__getitem__, trim=True) 39 | ## load weights instead of randomly initialize 40 | # with safe_open(model_path / f"{model_name}.safetensors", framework="flax", device="cpu") as f: 41 | # init_weights_ = parse_rwkv_weight(f.keys(), f.get_tensor) 42 | 43 | n_channels = init_weights_['emb']['weight'].shape[1] # type: ignore 44 | _, tree_struct = tree_flatten(init_weights_) 45 | 46 | train_tags: Optional[list[Labels]] = None 47 | weight_mask = None 48 | 49 | key_gen = infinite_safe_keys(0) 50 | 51 | adam_params = { 52 | 'learning_rate': 1e-4, 53 | 'beta1': 0.9, 54 | 'beta2': 0.999, 55 | 'eps': 1e-8, 56 | } 57 | lion_params = { 58 | 'learning_rate': 1e-4, 59 | 'beta1': 0.95, 60 | 'beta2': 0.98, 61 | 'weight_decay': 0.01 62 | } 63 | train_params = { 64 | 'eval_iters': 200, 65 | 'eval_interval': 2000, 66 | 'save_interval': 10000, 67 | 'max_iters': 1000000, 68 | # 'adam': adam_params, 69 | 'lion': lion_params, 70 | # 'adamw': adam_params, 71 | 'optimizer': 'lion', 72 | } 73 | batch_size = 3 74 | 75 | batch_config_ = LRABatchConfig.from_s5( 76 | batch_size=batch_size, 77 | cache_path=Path("/Data/datasets/pile"), 78 | dataset_name='listops-classification' 79 | ) 80 | 81 | # reduce embedding to vocab size 82 | init_weights_['emb']['weight'] = init_weights_['emb']['weight'][:batch_config_.n_classes_in, :] # type: ignore 83 | init_weights_['head']['weight'] = init_weights_['head']['weight'][:batch_config_.n_classes_out, :] # type: ignore 84 | 85 | experimental_params: dict = { 86 | 'eps': 1e-5, 87 | 'n_tokens': batch_config_.n_classes_in, 88 | 'n_channels': n_channels, 89 | 'n_blocks': len(init_weights_['blocks']), 90 | 91 | 'train_tags': [l.fmt() for l in train_tags] if train_tags is not None else None, 92 | 93 | 'batch_size': batch_config_.batch_size, 94 | 'block_size': batch_config_.block_size, 95 | 'train': train_params, 96 | 'model': "rwkv" 97 | } 98 | 99 | max_iters = experimental_params['train']['max_iters'] 100 | 
def rwkv_f(w: WeightsTree, token_array: Arr) -> Arr:
    """Apply the parallel RWKV network to ``token_array``.

    The sequence length (``len(token_array)``) is passed as the first
    positional argument and the entries of ``w`` are expanded into keyword
    arguments of ``rwkv_net_parallel``.
    """
    seq_len = len(token_array)
    return rwkv_net_parallel(seq_len, token_array, **w)


def rwkv_rnn(w: WeightsTree, token_array: Arr, state: Arr) -> tuple[Arr, Arr]:
    """Apply the recurrent RWKV network to ``token_array`` and ``state``.

    The entries of ``w`` are expanded into keyword arguments of
    ``rwkv_net_rnn``; its result is returned unchanged.
    """
    step_result = rwkv_net_rnn(token_array, state, **w)
    return step_result
in range(experimental_params['n_blocks']): 146 | # to jax state[5 * i + 4] = -1e30 147 | rnn_init_state = rnn_init_state.at[i, 4].set(-1e30) 148 | 149 | run = wandb.init( 150 | project="inside-transformer", 151 | config=experimental_params, 152 | ) 153 | assert isinstance(run, wandb.sdk.wandb_run.Run) 154 | 155 | keys_ = next(key_gen).split(max_iters) 156 | train_batch_iter = iter(batch_config_.dataloaders['train']) 157 | for step in range(max_iters): 158 | batch_ = next(train_batch_iter) 159 | 160 | if step % eval_interval == 0: 161 | loss = train_config_.loss_fn(train_state_.weights, batch_) 162 | print(f"\n===[ step {step} is an eval step ]==========") 163 | print(f"before step {step}, batch loss {loss}") 164 | 165 | train_state_ = train_config_.train1(train_state_, batch_) 166 | if step % eval_interval == 0: 167 | loss = train_config_.loss_fn(train_state_.weights, batch_) 168 | print(f"after step {step}, batch loss {loss}") 169 | results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_.samplers) 170 | 171 | wandb.log({"train_loss": results['train'], 172 | "validation_loss": results['val'], 173 | "batch_loss": loss, 174 | "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size}) 175 | if step % save_interval == 0: # and step > 0: 176 | n_tokens_trained = step * batch_config_.batch_size * batch_config_.block_size 177 | n_tokens_trained_str = num_short_form(n_tokens_trained) 178 | wandb.save(save_pytree_(train_state_.weights, run.dir, f"{model_name}_{n_tokens_trained_str}"), run.dir) 179 | 180 | wandb.finish() 181 | 182 | # [Done] add trainable weights (via gradient mask) 183 | # [TODO] add trainable weights via weight name mask 184 | # [Done] add checkpointing 185 | # TODO: add weight decay 186 | -------------------------------------------------------------------------------- /lra_arena/use_lra_demo.py: -------------------------------------------------------------------------------- 1 | # to download data: 
class Tokens(Protocol):
    """Structural type for the result of ``Tokenizer.encode``."""
    # Integer ids of the encoded tokens.
    ids: list[int]


class Tokenizer(Protocol):
    """Structural type for tokenizer objects used by the generation helpers."""

    def encode(self, text: str) -> Tokens:
        """Encode ``text``; the result exposes the token ids as ``.ids``."""

    def decode(self, ids: list[int]) -> str:
        """Turn a list of token ids back into text."""

    def get_vocab_size(self) -> int:
        """Return the vocabulary size."""
def generate(get_logits: Callable[[Arr], Arr], inputs: list[int], n_tokens_to_generate: int, max_len: int):
    """Greedy auto-regressive decoding with a sliding context window.

    :param get_logits: maps an array of token ids to per-position logits; the
        logits of the last position are used to pick the next token.
    :param inputs: prompt token ids.  NOTE: the generated ids are appended to
        this list in place, as before.
    :param n_tokens_to_generate: number of decoding steps.
    :param max_len: maximum number of most recent tokens given to the model.
    :return: only the newly generated token ids.
    """
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        # Crop on every step, including the first one: previously a prompt
        # longer than max_len was passed to the model uncropped on step 0.
        input_window = inputs[-max_len:]
        logits = get_logits(jnp.array(input_window))
        next_id = jnp.argmax(logits[-1])  # greedy sampling
        inputs.append(int(next_id))  # append prediction to input

    return inputs[len(inputs) - n_tokens_to_generate:]  # only return generated ids


def generate_static(get_logits: Callable[[Arr], Arr], inputs: list[int], n_tokens_to_generate: int, max_len: int):
    """Greedy decoding that always feeds the model a window of ``max_len`` ids.

    Short histories are right-padded with token id 0, so the model always
    sees the same input length; long histories are cropped to the most recent
    ``max_len`` ids.  The next token is read from the logits at the position
    of the last real (non-padding) token.

    :param get_logits: maps an array of ``max_len`` token ids to per-position
        logits.
    :param inputs: prompt token ids.  NOTE: the generated ids are appended to
        this list in place, as before.
    :param n_tokens_to_generate: number of decoding steps.
    :param max_len: fixed window length.
    :return: only the newly generated token ids.
    """
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        n_real = min(len(inputs), max_len)  # real tokens inside the window
        if len(inputs) >= max_len:
            input_window = inputs[-max_len:]  # update input window
        else:
            input_window = inputs + [0] * (max_len - len(inputs))
        # Index relative to the window.  The old `len(inputs) - 1` pointed
        # past the window once the history outgrew max_len and only worked
        # when the logits array clamps out-of-range indices.
        output_index = n_real - 1
        logits = get_logits(jnp.array(input_window))
        next_id = jnp.argmax(logits[output_index])  # greedy sampling
        inputs.append(int(next_id))  # append prediction to input

    return inputs[len(inputs) - n_tokens_to_generate:]  # only return generated ids
context is growing too long we must crop it at block_size 75 | # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 76 | # forward the model to get the logits for the index in the sequence 77 | logits = get_logits(tokens) 78 | # pluck the logits at the final step and scale by desired temperature 79 | logits = logits[i - 1] / temperature 80 | # optionally crop the logits to only the top k options 81 | # sample from the distribution 82 | if top_k is not None: 83 | top_logits, top_tokens = top_k(logits, min(top_k, logits.shape[-1])) 84 | token_idx = random.categorical(step_key, top_logits, axis=-1) 85 | next_token = jnp.take_along_axis(top_tokens, token_idx[:, None], axis=-1).squeeze(-1) 86 | else: 87 | next_token = random.categorical(step_key, logits, axis=-1) 88 | # logits = jnp.where(logits < v[:, -1:], float('-inf'), logits) 89 | # append sampled index to the running sequence and continue 90 | tokens = tokens.at[i].set(next_token) 91 | 92 | return tokens, None 93 | 94 | tokens, _ = scan(scan_f, tokens, indexes) 95 | 96 | return tokens.tolist() 97 | 98 | 99 | def rnn_generate(get_logits_rnn: Callable[[Arr, Arr], tuple[Arr, Arr]], 100 | context: str, 101 | init_state: Arr, 102 | tokenizer: Tokenizer, 103 | key_gen: Iterator[SafeKey], 104 | argmax: bool = False, 105 | length_per_trial: int = 100, n_trials: int = 1, temperature: float = 1.0, top_p: float = 0.85) -> str: 106 | init_state = init_state.copy() 107 | for token in tokenizer.encode(context).ids: 108 | init_out, init_state = get_logits_rnn(token, init_state) 109 | # print(init_state[1, :, 1]) 110 | 111 | def sample_logits(logits, key, temperature=1.0, top_p=0.8): 112 | probs = softmax(logits, axis=-1) 113 | sorted_probs = np.sort(probs)[::-1] 114 | cumulative_probs = np.cumsum(sorted_probs) 115 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 116 | probs = np.where(probs < cutoff, 0, probs) 117 | if temperature != 1.0: 118 | probs = np.power(probs, 
def binary_operator_diag(element_i, element_j):
    """Combine two consecutive steps of a diagonal linear recurrence.

    Each element is a pair ``(a, bu)`` standing for the map ``x -> a * x + bu``.
    Applying ``element_i`` first and ``element_j`` second gives
    ``x -> a_j * a_i * x + (a_j * bu_i + bu_j)``; this composition is
    associative, which is what ``lax.associative_scan`` requires.
    """
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j


def forward(input_sequence, nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log):
    """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H).

    Computes ``x_k = Lambda * x_{k-1} + B_norm @ u_k`` (starting from a zero
    state) for all steps with a parallel scan, then
    ``y_k = Re(C @ x_k) + D * u_k``.

    ``binary_operator_diag`` used to be referenced here without being defined
    or imported in this module, so the call failed with ``NameError``.

    Parameter shapes are not fixed by this function; presumably ``nu_log``,
    ``theta_log`` and ``gamma_log`` are (N,), ``B_*`` are (N, H), ``C_*`` are
    (H, N) and ``D`` is (H,) - TODO confirm against the caller.
    """

    # Materializing the diagonal of Lambda and projections
    Lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    B_norm = (B_re + 1j * B_im) * np.expand_dims(np.exp(gamma_log), axis=-1)
    C = C_re + 1j * C_im

    # Running the LRU + output projection
    # For details on parallel scan, check discussion in Smith et al (2022).
    Lambda_elements = np.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
    Bu_elements = vmap(lambda u: B_norm @ u)(input_sequence)
    elements = (Lambda_elements, Bu_elements)
    _, inner_states = lax.associative_scan(binary_operator_diag, elements)  # all x_k
    y = vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)

    return y
def is_associative(op, a, b, c) -> bool:
    """Check that grouping does not matter for ``op`` on the given operands.

    Builds ``op(op(a, b), c)`` and ``op(a, op(b, c))`` and compares them with
    the ``equals`` method of the first result.
    """
    left_grouped = op(op(a, b), c)
    right_grouped = op(a, op(b, c))
    return left_grouped.equals(right_grouped)


def is_distributive(mul_like, add_like, a, b, c) -> bool:
    """Check that ``mul_like`` distributes over ``add_like`` from the left.

    Compares ``mul_like(a, add_like(b, c))`` with
    ``add_like(mul_like(a, b), mul_like(a, c))`` using the ``equals`` method
    of the first result.
    """
    factored = mul_like(a, add_like(b, c))
    expanded = add_like(mul_like(a, b), mul_like(a, c))
    return factored.equals(expanded)
simplify(expand((self.op_add(self.op_mul(left.bias, right.weight), 87 | right.bias))))) 88 | 89 | 90 | @dataclass 91 | class IndexedExpr: 92 | base_vars: list[IndexedBaseG] 93 | expr: Callable[[Sequence[Expr]], Expr] 94 | 95 | def __getitem__(self, item) -> Expr: 96 | return self.expr([v[item] for v in self.base_vars]) 97 | 98 | 99 | @dataclass 100 | class FakeIndex: 101 | const: Expr 102 | 103 | def __getitem__(self, item) -> Indexed: 104 | return self.const 105 | 106 | # print(build_rec_formula(Y, t, Z, L)) 107 | f = RecFormula(Y, IndexedPair(FakeIndex(l), Z)) 108 | # pprint(f.rec_equation, use_unicode=True) 109 | pprint(f.pscan_i(i), use_unicode=True) 110 | pprint(f.pscan_base, use_unicode=True) 111 | ne = f.pscan_step(f.pscan_base, f.pscan_i(1)) 112 | pprint(ne, use_unicode=True) 113 | pprint(f.pscan_step(ne, f.pscan_i(2)), use_unicode=True) 114 | assert is_associative(Mul, a, b, c) 115 | assert is_associative(f.pscan_step, f.pscan_i(i), f.pscan_i(j), f.pscan_i(k)) 116 | 117 | base_ = f.pscan_base 118 | for i in range(1, 10): 119 | base_ = f.pscan_step(base_, f.pscan_i(i)) 120 | pprint(base_, use_unicode=True) 121 | rec10 = f.rec_step(10) 122 | pprint(rec10, use_unicode=True) 123 | assert base_[1].equals(f.rec_step(10)) 124 | print(simplify(expand(rec10))) 125 | print(simplify(expand(base_[1]))) 126 | 127 | 128 | #%% 129 | K = IndexedBase('k', shape=(t,)) 130 | V = IndexedBase('v', shape=(t,)) 131 | w = symbols('w') 132 | 133 | lc = exp(w) 134 | lci = FakeIndex(lc) 135 | expKV = IndexedExpr([K, V], lambda kv: exp(kv[0]) * kv[1]) 136 | f = RecFormula(Y, IndexedPair(lci, expKV), op_mul=Mul, op_add=Add) 137 | rec10 = f.rec_step(3) 138 | print(rec10) 139 | base_ = f.pscan_base 140 | for i in range(1, 3): 141 | base_ = f.pscan_step(base_, f.pscan_i(i)) 142 | print(base_[1]) 143 | print(base_[0]) 144 | assert base_[1].equals(f.rec_step(3)) 145 | print(simplify(expand(f.rec_step(5)))) 146 | 147 | i = Idx('i', t) 148 | dummy_l = RecPair(lc, expKV[i]) 149 | dummy_r = 
def layer_norm(x, weight, bias, eps: float = 1e-5):
    """
    Normalize ``x`` over its last axis, then scale and shift.

    :param x: input array; statistics are taken over the last axis
    :param weight: multiplicative scale applied after normalization
    :param bias: additive shift applied last
    :param eps: added to the variance before the reciprocal square root
    :return: array with the same shape as ``x``
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma_sq = np.var(x, axis=-1, keepdims=True)
    centered = x - mu
    inv_std = rsqrt(sigma_sq + eps)
    return weight * centered * inv_std + bias
@partial(jit, static_argnames='i')
def rwkv(r, ow, k, v, state, i: int, time_first, time_decay, debug=False):
    r"""
    Compute the WKV output of layer ``i`` for one token and advance that layer's state.

    The docstring is a raw string so the LaTeX backslashes below stay literal.

    The original form of the equation is:
    $\omega = \frac{a_1 + e^{(t_1 + k) - p} \cdot v}{b_1 + e^{(t_1 + k) - p}}$
    For numerical stability it is rewritten as:
    $\omega = \frac{e^{p - \max(p, t_1 + k)} \cdot a_1 + e^{(t_1 + k) - \max(p, t_1 + k)} \cdot v}{e^{p - \max(p, t_1 + k)} \cdot b_1 + e^{(t_1 + k) - \max(p, t_1 + k)}}$
    where
        $a_1$ is wkv_top (state[i, 2])
        $b_1$ is wkv_bot (state[i, 3])
        $p$ is max_coef (state[i, 4])
        $\omega$ is wkv
        $t_1$ is time_first
        $t_2$ is time_decay, added to $p$ before the current token is folded into the stored state

    :param r: receptance gate, multiplied element-wise with wkv
    :param ow: output projection, applied as ``ow @ (r * wkv)``
    :param k: key vector of the current token
    :param v: value vector of the current token
    :param state: recurrent state indexed as ``state[layer, slot]``; slots 2, 3 and 4 of layer ``i`` are read and replaced
    :param i: layer index (static under jit)
    :param time_first: offset added to ``k`` for the current token's contribution to the output only
    :param time_decay: offset added to max_coef when updating the stored state
    :param debug: unused
    :return: ``(ow @ (r * wkv), state)`` where ``state`` is the updated copy
    """
    v_state, base_state, max_coef = state[i, 2], state[i, 3], state[i, 4]

    # Output for the current token: the stored sums plus the current token boosted by time_first.
    wkv_top, wkv_bot, _ = exp_mix(max_coef, k + time_first, v_state, v, base_state)
    wkv = wkv_top / wkv_bot

    # State for the next token: decay the stored sums, then add the current token without the boost.
    wkv_top_new, wkv_bot_new, max_coef = exp_mix(max_coef + time_decay, k, v_state, v, base_state)

    state = state.at[i, 2].set(wkv_top_new)
    state = state.at[i, 3].set(wkv_bot_new)
    state = state.at[i, 4].set(max_coef)
    return ow @ (r * wkv), state
@partial(jit, static_argnames='i')
def block(x, state, i: int, att, ffn, ln1, ln2, **kwargs):
    """
    Run one residual block: token mixing followed by channel mixing.

    :param x: input activation of this block
    :param state: recurrent state, passed through and updated by both mixers
    :param i: layer index (static under jit)
    :param att: token-mixing weights (``time_mix_k/v/r``, ``time_first``, ``time_decay`` and the
        ``key``/``value``/``receptance``/``output`` projections, each holding a ``weight``)
    :param ffn: channel-mixing weights (``time_mix_k/r`` and the ``key``/``value``/``receptance`` projections)
    :param ln1: keyword arguments of the layer norm applied before token mixing
    :param ln2: keyword arguments of the layer norm applied before channel mixing
    :param kwargs: optional ``debug`` flag forwarded to ``token_mixing`` (defaults to ``False``);
        other entries are ignored, which lets callers pass a whole block dict
        (for example one that also carries ``ln0``)
    :return: ``(x, state)`` after both residual updates
    """
    # `debug` is optional: a plain kwargs['debug'] lookup would raise KeyError when it is omitted.
    debug = kwargs.get('debug', False)

    xn = layer_norm(x, **ln1)
    xp, state = token_mixing(xn, state, i,
                             att['time_mix_k'],
                             att['time_mix_v'],
                             att['time_mix_r'],
                             att['time_first'],
                             att['time_decay'],
                             att['key']['weight'],
                             att['value']['weight'],
                             att['receptance']['weight'],
                             att['output']['weight'], debug=debug)

    x += xp
    xn = layer_norm(x, **ln2)
    xp, state = channel_mixing(xn, state, i, ffn['time_mix_k'], ffn['time_mix_r'],
                               ffn['key']['weight'], ffn['value']['weight'],
                               ffn['receptance']['weight'])
    x += xp
    return x, state
-------------------------------------------------------------------------------- /pico_rwkv/close_to_original/run_rwkv.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import jax.numpy as np 4 | from safetensors import safe_open 5 | from tokenizers import Tokenizer 6 | 7 | from nlp_utils import rnn_generate 8 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight 9 | from picojax.random_utils import infinite_safe_keys 10 | from pico_rwkv.close_to_original.pico_rwkv import rwkv_net_w 11 | 12 | path = Path("/Data/lm_models/rwkv") 13 | # model_name = 'RWKV-4-Pile-430M-20220808-8066' 14 | model_name = 'RWKV-4-Pile-169M-20220807-8023' 15 | # jax.config.update('jax_platform_name', 'cpu') 16 | 17 | 18 | with safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as f: 19 | w = parse_rwkv_weight(f.keys(), f.get_tensor) 20 | 21 | tokenizer = Tokenizer.from_file(str(path / "20B_tokenizer.json")) 22 | 23 | 24 | n_channels = w['emb']['weight'].shape[1] 25 | n_layers = len(w['blocks']) 26 | 27 | context = ("\nPumas are large, cat-like animals found in America. When reports came into London Zoo that " 28 | "a wild puma had been spotted forty-five miles south of London, they were not taken seriously." 29 | " However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate," 30 | " for the descriptions given by people who claimed to have seen the puma were extraordinarily similar." 31 | "\nThe hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat'" 32 | " only five yards away from her. 
def exp_mix(v1, v2, p1, p2):
    """
    Add two numbers stored in (value, exponent) form and renormalize the sum.

    Each operand stands for ``v * exp(p)``. The result is expressed against
    the larger exponent ``p = max(p1, p2)``, so neither exponential computed
    here has a positive argument:

        x1 + x2 --> (v1 * exp(p1 - p) + v2 * exp(p2 - p), p)

    :return: ``(value, p)`` of the sum
    """
    shared_exponent = np.maximum(p1, p2)
    scaled_first = v1 * np.exp(p1 - shared_exponent)
    scaled_second = v2 * np.exp(p2 - shared_exponent)
    return scaled_first + scaled_second, shared_exponent
def lru_parallel_scannable_normalized(left, right):
    """
    Combine two scan elements of the form ``(value, decay, exponent)``.

    An element stands for ``value * exp(exponent)`` together with the decay
    accumulated over its span. Joining ``left`` with ``right`` shifts the left
    exponent by the right decay, then re-expresses both values against the
    larger of the two exponents so that no exponential has a positive argument.

    :param left: ``(value, decay, exponent)`` of the earlier span
    :param right: ``(value, decay, exponent)`` of the later span
    :return: ``(value, decay, exponent)`` of the joined span
    """
    left_val, left_decay, left_exp = left
    right_val, right_decay, right_exp = right

    shifted_left_exp = left_exp + right_decay
    joint_exp = np.maximum(shifted_left_exp, right_exp)

    left_term = left_val * np.exp(shifted_left_exp - joint_exp)
    right_term = right_val * np.exp(right_exp - joint_exp)
    return left_term + right_term, left_decay + right_decay, joint_exp
def rwkv_scan(state, rs, ks, vs, ow, time_first, time_decay):
    """
    Run the WKV recurrence over a token sequence with a sequential ``lax.scan``.

    :param state: initial WKV state of one layer, used as the scan carry
    :param rs: receptance values, scanned along the leading axis
    :param ks: key values, scanned along the leading axis
    :param vs: value values, scanned along the leading axis
    :param ow: output projection forwarded to ``rwkv``
    :param time_first: forwarded to ``rwkv`` for the per-token output
    :param time_decay: forwarded to ``rwkv_state_flow`` for the state update
    :return: ``(final_state, outputs)`` as returned by ``lax.scan``
    """
    def step(wkv_state, token_inputs):
        r_t, k_t, v_t = token_inputs
        # The output is computed from the state as it was before this token.
        out_t = rwkv(wkv_state, r_t, k_t, v_t, ow, time_first)
        # Stack the updated state components back into one array for the carry.
        next_state = np.stack(rwkv_state_flow(time_decay, wkv_state, k_t, v_t))
        return next_state, out_t

    return lax.scan(step, state, (rs, ks, vs))
def rwkv_parallel_scan(seq_len: int, r, k, v, ow, time_first, time_decay):
    """
    Compute the WKV output for a whole sequence with two associative scans.

    ``exp(k)`` is taken directly, without subtracting a running maximum.

    :param seq_len: number of time steps; the decay is repeated this many times
    :param r: receptance, one row per time step
    :param k: keys, one row per time step
    :param v: values, one row per time step
    :param ow: output projection, applied as ``(r * wkv) @ ow.T``
    :param time_first: offset applied to the current token's own contribution
    :param time_decay: per-channel decay, shared by every time step
    :return: projected output, one row per time step
    """
    def previous_step(rows):
        # Shift down by one row: row t becomes row t-1 and row 0 becomes zeros.
        padded = np.pad(rows, ((1, 0), (0, 0)), mode='constant', constant_values=0)
        return padded[:-1, :]

    decay_rows = np.repeat(time_decay[np.newaxis, :], seq_len, axis=0)
    key_weight = np.exp(k)

    # Running sums over the sequence: numerator (weighted values) and denominator (weights).
    numer_scan, _ = lax.associative_scan(lru_parallel_scannable, (key_weight * v, decay_rows))
    denom_scan, _ = lax.associative_scan(lru_parallel_scannable, (key_weight, decay_rows))

    # The current token enters with the time_first bonus on top of everything before it.
    current_weight = np.exp(time_first) * key_weight
    numerator = previous_step(numer_scan) + current_weight * v
    denominator = previous_step(denom_scan) + current_weight

    wkv = numerator / denominator
    return (r * wkv) @ ow.T
def exp_mix_both(max_coef, k, v_s, v, base):
    """
    Fold one (k, v) pair into a running pair of exponentially weighted sums.

    The stored sums ``v_s`` and ``base`` are expressed relative to ``max_coef``.
    Both are re-expressed relative to ``max(max_coef, k)`` before the new term
    is added, so every exponent evaluated here is non-positive:

    1. If ``max_coef >= k``:
        mix_v = v_s + e^(k - max_coef) * v
        base_new = base + e^(k - max_coef)

    2. If ``max_coef < k``:
        mix_v = e^(max_coef - k) * v_s + v
        base_new = e^(max_coef - k) * base + 1

    :return: ``(mix_v, base_new, new_max_coef)``
    """
    pivot = np.maximum(max_coef, k)
    state_scale = np.exp(max_coef - pivot)
    input_scale = np.exp(k - pivot)
    return (state_scale * v_s + input_scale * v,
            state_scale * base + input_scale,
            pivot)
def parse_rwkv_weight(keys: Iterable[str], get_tensor: Callable[[str], Arr], trim: bool=False) -> WeightsTree:
    """
    Build a nested weight tree from dotted parameter names.

    A name such as ``blocks.0.att.key.weight`` becomes
    ``tree['blocks'][0]['att']['key']['weight']``; purely numeric path
    components are stored under integer keys. Afterwards every ``att`` and
    ``ffn`` entry under ``blocks`` gains the shortcuts ``kw``, ``vw`` and
    ``rw`` for whichever of its ``key``/``value``/``receptance`` weights exist.

    :param keys: dotted parameter names
    :param get_tensor: called once per name to obtain the value stored at the leaf
    :param trim: when true, remove the original ``weight`` entry after creating its shortcut
    :return: the nested tree
    """
    tree = {}
    for full_name in keys:
        *scopes, leaf = full_name.split('.')
        node = tree
        for scope in scopes:
            slot = int(scope) if scope.isdigit() else scope
            node = node.setdefault(slot, {})
        node[leaf] = get_tensor(full_name)

    for layer in tree['blocks'].values():
        for mixer in (layer['att'], layer['ffn']):
            for proj in ('key', 'value', 'receptance'):
                if proj not in mixer:
                    continue
                # 'key' -> 'kw', 'value' -> 'vw', 'receptance' -> 'rw'
                mixer[proj[0] + 'w'] = mixer[proj]['weight']
                if trim:
                    del mixer[proj]['weight']
    return tree
# Convert an RWKV PyTorch checkpoint (.pth) into a safetensors file of JAX arrays.
path = Path("/Data/lm_models/rwkv")
# model_name = 'RWKV-4-Pile-430M-20220808-8066'
model_name = 'RWKV-4-Pile-169M-20220807-8023'

w_raw = torch.load(path/f'{model_name}.pth', map_location='cpu')

for k in w_raw.keys():
    # Cast to float32 first so the exponential below is evaluated in float32
    # whatever dtype the checkpoint was stored in.
    t = w_raw[k].float()
    if '.time_' in k:
        t = t.squeeze()
    if '.time_decay' in k:
        t = -torch.exp(t)  # the real time decay is like e^{-e^x}
    w_raw[k] = jax.numpy.array(t)

save_file(w_raw, path/f'{model_name}.safetensors')
safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as f: 20 | w = parse_rwkv_weight(f.keys(), f.get_tensor) 21 | 22 | tokenizer = Tokenizer.from_file(str(path / "20B_tokenizer.json")) 23 | 24 | n_channels = w['emb']['weight'].shape[1] 25 | n_layers = len(w['blocks']) 26 | 27 | key_gen = infinite_safe_keys(0) 28 | 29 | # context = """Where must the puma have come from? 30 | # 31 | # Pumas are large, cat-like animals which are found in America. When reports came into London Zoo that a wild puma had been spotted forty-five miles south of London, they were not taken seriously. However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate, for the descriptions given by people who claimed to have seen the puma were extraordinarily similar. 32 | # The hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat' only five yards away from her. It immediately ran away when she saw it, and experts confirmed that a puma will not attack a human being unless it is cornered. The search proved difficult, for the puma was often observed at one place in the morning and at another place twenty miles away in the evening. Wherever it went, it left behind it a trail of dead deer and small animals like rabbits. Paw prints were seen in a number of places and puma fur was found clinging to bushes. Several people complained of "cat-like noises' at night and a businessman on a fishing trip saw the puma up a tree. The experts were now fully convinced that the animal was a puma, but where had it come from? As no pumas had been reported missing from any zoo in the country, this one must have been in the possession of a private collector and somehow managed to escape. The hunt went on for several weeks, but the puma was not caught. It is disturbing to think that a dangerous wild animal is still at large in the quiet countryside. 
33 | # 34 | # In not more than 80 words describe how experts came to the conclusion that the animal seen by many people really was a puma. Do not include anything that is not in the passage. 35 | # Answer these questions in note form to get your points: 36 | # 1 What sort of reports were received by London Zoo? 37 | # 2 Were the reports similar in nature or not? 38 | # 3 Who saw it first? 39 | # 4 Did it stay in one place,or did it move from place to place? 40 | # 5 What did it leave behind it? 41 | # 6 Were paw prints and puma fur found as well or not? 42 | # 7 What was heard at night? 43 | # 8 Was the animal seen up a tree or not? 44 | # 9 Were experts now sure that the animal really was a puma or not? 45 | # 46 | # """ 47 | 48 | 49 | context = ("\nPumas are large, cat-like animals found in America. When reports came into London Zoo that " 50 | "a wild puma had been spotted forty-five miles south of London, they were not taken seriously." 51 | " However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate," 52 | " for the descriptions given by people who claimed to have seen the puma were extraordinarily similar." 53 | "\nThe hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat'" 54 | " only five yards away from her. 
def process_tokens(token_objs, max_len):
    """
    Split a token sequence into full-length chunks plus a remainder.

    Chunks of exactly ``max_len`` tokens are cut off the front while more than
    ``max_len`` tokens remain, so the remainder holds between 1 and ``max_len``
    tokens for a non-empty input (a final group of exactly ``max_len`` tokens
    stays in the remainder).

    :param token_objs: sliceable sequence of token ids
    :param max_len: chunk length
    :return: ``(chunks, remainder)`` - a list of arrays and one array
    """
    chunks = []
    remaining = token_objs
    while len(remaining) > max_len:
        head, remaining = remaining[:max_len], remaining[max_len:]
        chunks.append(np.array(head))
    return chunks, np.array(remaining)
max_len_) 111 | print(token_array_batch) 112 | print(left_over) 113 | if len(token_array_batch) > 0: 114 | for token_array in token_array_batch: 115 | init_out, state = rwkv_net_scan(max_len_, token_array, state, **w) 116 | if len(left_over) > 0: 117 | for token in left_over: 118 | init_out, state = rwkv_net_rnn(token, state, **w) 119 | else: 120 | print("scan_out", tokenizer.decode(np.argmax(init_out))) 121 | init_out = init_out[-1, :] 122 | key_gen = infinite_safe_keys(0) 123 | _ = rnn_generate(sample_rwkv_rnn, context, state, tokenizer, key_gen, 124 | argmax=True, 125 | length_per_trial=100, n_trials=3, temperature=1.0, top_p=0.85) 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /pico_rwkv/rwkv_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 4 | import types, torch 5 | from torch.nn import functional as F 6 | from tokenizers import Tokenizer 7 | 8 | tokenizer = Tokenizer.from_file("/Data/lm_models/rwkv/20B_tokenizer.json") 9 | 10 | args = types.SimpleNamespace() 11 | args.MODEL_NAME = '/Data/lm_models/rwkv/RWKV-4-Pile-430M-20220808-8066' 12 | args.n_layer = 24 13 | args.n_embd = 1024 14 | 15 | context = "\nPumas are large, cat-like animals" 16 | NUM_TRIALS = 3 17 | LENGTH_PER_TRIAL = 100 18 | TEMPERATURE = 1.0 19 | TOP_P = 0.85 20 | 21 | 22 | ######################################################################################################## 23 | 24 | class RWKV_RNN(torch.jit.ScriptModule): 25 | def __init__(self, args): 26 | super().__init__() 27 | self.args = args 28 | self.eval() # set torch to inference mode 29 | 30 | w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') 31 | for k in w.keys(): 32 | if '.time_' in k: w[k] = w[k].squeeze() 33 | if '.time_decay' in k: 34 | w[k] = -torch.exp(w[k].float()) # the real time decay is like 
# Channel-mixing (feed-forward) step of layer `i` for a single token.
# `state` is the flat recurrent state with five rows per layer; row 5 * i + 0
# holds the `x` that this method stored on its previous call for the same layer.
# The method updates that row in place and returns the gated feed-forward output.
@torch.jit.script_method
def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
    # Blend the current input with the stored previous input, using separate
    # mixing coefficients for the key path and the receptance path.
    xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
    xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
    # Store the current input only after both blends have read the old value.
    state[5 * i + 0] = x
    r = torch.sigmoid(rw @ xr)
    k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
    # The sigmoid receptance gates the projected hidden activation element-wise.
    return r * (vw @ k)
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one token index from logits using top-p filtering and temperature.

    :param out: torch tensor of logits; it is soft-maxed over its last axis and
        converted to a NumPy array (presumably 1-D, since the result is used as
        the ``p`` argument of ``np.random.choice``).
    :param temperature: the surviving probabilities are raised to
        ``1 / temperature`` before re-normalisation; ``1.0`` skips this step.
    :param top_p: probabilities below the one at which the descending
        cumulative sum first exceeds ``top_p`` are zeroed.
    :return: the sampled index, as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    # argmax over a boolean array gives the first position where the running
    # total exceeds top_p (or 0 if it never does).
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # `probs` is a NumPy array here, which has no `.pow` method (that is the
        # torch.Tensor spelling), so use np.power instead.
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
def make_weight_config(name: str, shape: tuple[int, ...], wtype: WeightConfigType) -> WeightConfig:
    """Build the weight config for a single named parameter.

    The dotted ``name`` is split on '.' and the pieces are passed to ``L`` to
    form the config's tags.

    :param name: dotted parameter name, e.g. "blocks.att.key.weight"
    :param shape: shape stored on the config
    :param wtype: 'normal' selects ``NormalWeight``, 'zero' selects ``ZeroWeight``
    :raises ValueError: if ``wtype`` is neither 'normal' nor 'zero'
    """
    tags = L(*name.split('.'))
    if wtype == 'normal':
        return NormalWeight(name=name, shape=shape, tags=tags)
    elif wtype == 'zero':
        return ZeroWeight(name=name, shape=shape, tags=tags)
    else:
        raise ValueError(f"weight_config_type must be 'normal' or 'zero', not {wtype}")


def make_rwkv_weight_configs(n_blocks: int, n_embd: int, ffn_hidden_multiplier: int) -> dict[str, WeightConfig]:
    """Build the weight configs for the listed RWKV parameters, keyed by name.

    Every shape carries ``n_blocks`` as its leading dimension. The ffn key and
    value matrices use a hidden width of ``ffn_hidden_multiplier * n_embd``.

    NOTE(review): the list only covers ``blocks.att.*``, ``blocks.ffn.*`` and
    ``blocks.ln0.*`` and marks everything as 'zero'; it looks unfinished
    (no ln1/ln2/emb/head entries) -- confirm the intended set before relying
    on it.

    :return: mapping from parameter name to its ``WeightConfig``
    """
    properties = [
        ("blocks.att.output.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.value.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.key.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.receptance.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.time_decay", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_first", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_k", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_r", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_v", (n_blocks, n_embd), 'zero'),
        ("blocks.ffn.key.weight", (n_blocks, ffn_hidden_multiplier * n_embd, n_embd), 'zero'),
        ("blocks.ffn.value.weight", (n_blocks, n_embd, n_embd * ffn_hidden_multiplier), 'zero'),
        ("blocks.ffn.receptance.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.ffn.time_mix_k", (n_blocks, n_embd), 'zero'),
        ("blocks.ffn.time_mix_r", (n_blocks, n_embd), 'zero'),
        ("blocks.ffn.time_mix_v", (n_blocks, n_embd), 'zero'),
        ("blocks.ln0.weight", (n_blocks, n_embd), 'zero'),
        ("blocks.ln0.bias", (n_blocks, n_embd), 'zero'),
    ]
    # The original body stopped after building the list and implicitly returned
    # None, contradicting the declared return type.
    return {name: make_weight_config(name, shape, wtype) for name, shape, wtype in properties}
get_lm_loss 25 | from python_utils import num_short_form 26 | 27 | os.environ['JAX_LOG_COMPILES'] = '1' 28 | model_path = Path("/Data/lm_models/rwkv") 29 | # model_name = 'RWKV-4-Pile-430M-20220808-8066' 30 | model_name = 'RWKV-4-Pile-169M-20220807-8023' 31 | 32 | weight_infos = get_normal_weights_config_(model_path, model_name) 33 | 34 | keygen = infinite_safe_keys(0) 35 | key = next(keygen) 36 | 37 | init_weights_raw = init(weight_infos, rng_key=key) 38 | all_weight_names = list(init_weights_raw.keys()) 39 | 40 | # randomly initialize weights 41 | init_weights_: WeightsTree = parse_rwkv_weight(init_weights_raw.keys(), init_weights_raw.__getitem__, trim=True) 42 | ## load weights instead of randomly initialize 43 | # with safe_open(model_path / f"{model_name}.safetensors", framework="flax", device="cpu") as f: 44 | # init_weights_ = parse_rwkv_weight(f.keys(), f.get_tensor) 45 | 46 | n_channels = init_weights_['emb']['weight'].shape[1] # type: ignore 47 | _, tree_struct = tree_flatten(init_weights_) 48 | 49 | train_tags: Optional[list[Labels]] = None 50 | weight_mask = None 51 | 52 | # train_tags = [Labels.from_strs('ffn', 'key'), Labels.from_strs('ffn', 'value')] 53 | # weight_mask = get_masks_to_train(train_tags, weight_infos, trim=True) 54 | 55 | 56 | # %% 57 | 58 | data_path = Path("/Data/nlp/") 59 | 60 | str_sampling = False # turn this on if the data is too large to fit in memory, may be a bit slower 61 | custom_vocab = True # turn this off if you want to use the default tokenizer, str_sampling == True does not support custom_token 62 | 63 | dataset = "play" 64 | # dataset = "poem" 65 | # dataset = "english" 66 | # dataset = "russell" 67 | # dataset = "duilian" 68 | # dataset = "MQnovel" 69 | 70 | if str_sampling: 71 | input_data = custom_dataset_str.load(data_path, dataset) 72 | tokenizer = Tokenizer.from_file(str(model_path / "20B_tokenizer.json")) 73 | else: 74 | if custom_vocab: 75 | input_data, tokenizer = 
custom_dataset.load_jax_cached_tokenizer(data_path, dataset) 76 | else: 77 | input_data_str = custom_dataset_str.load(data_path, dataset) 78 | tokenizer = Tokenizer.from_file(str(model_path / "20B_tokenizer.json")) 79 | input_data = np.array(tokenizer.encode(input_data_str).ids) 80 | if custom_vocab: 81 | # reduce embedding to vocab size 82 | init_weights_['emb']['weight'] = init_weights_['emb']['weight'][:tokenizer.get_vocab_size(), :] # type: ignore 83 | init_weights_['head']['weight'] = init_weights_['head']['weight'][:tokenizer.get_vocab_size(), :] # type: ignore 84 | 85 | n = int(len(input_data) * 0.9) 86 | train_data = input_data[:n] 87 | valid_data = input_data[n:] 88 | 89 | key_gen = infinite_safe_keys(0) 90 | 91 | adam_params = { 92 | 'learning_rate': 1e-4, 93 | 'beta1': 0.9, 94 | 'beta2': 0.999, 95 | 'eps': 1e-8, 96 | } 97 | lion_params = { 98 | 'learning_rate': 1e-4, 99 | 'beta1': 0.95, 100 | 'beta2': 0.98, 101 | 'weight_decay': 0.01 102 | } 103 | train_params = { 104 | 'eval_iters': 200, 105 | 'eval_interval': 2000, 106 | 'save_interval': 10000, 107 | 'max_iters': 1000000, 108 | # 'adam': adam_params, 109 | 'lion': lion_params, 110 | # 'adamw': adam_params, 111 | 'optimizer': 'lion', 112 | } 113 | 114 | experimental_params: dict = { 115 | 'eps': 1e-5, 116 | 'n_tokens': tokenizer.get_vocab_size(), 117 | 'n_channels': n_channels, 118 | 'n_blocks': len(init_weights_['blocks']), 119 | 120 | 'train_tags': [l.fmt() for l in train_tags] if train_tags is not None else None, 121 | 122 | 'batch_size': 4, 123 | 'block_size': 128, 124 | 'train': train_params, 125 | 'model': "rwkv" 126 | } 127 | 128 | max_iters = experimental_params['train']['max_iters'] 129 | eval_interval = experimental_params['train']['eval_interval'] 130 | save_interval = experimental_params['train']['save_interval'] 131 | eval_iters = experimental_params['train']['eval_iters'] 132 | batch_config_ = LMBatchConfig(block_size=experimental_params['block_size'], 133 | 
def rwkv_f(w: WeightsTree, token_array: Arr) -> Arr:
    """Run the parallel RWKV network on one token sequence.

    The sequence length is taken from ``len(token_array)`` and the entries of
    ``w`` are forwarded as keyword arguments.
    """
    seq_len = len(token_array)
    return rwkv_net_parallel(seq_len, token_array, **w)


def rwkv_rnn(w: WeightsTree, token_array: Arr, state: Arr) -> tuple[Arr, Arr]:
    """Run the recurrent RWKV network from ``state``.

    The entries of ``w`` are forwarded as keyword arguments; the result of
    ``rwkv_net_rnn`` is returned unchanged.
    """
    result = rwkv_net_rnn(token_array, state, **w)
    return result
rnn_init_state = rnn_init_state.at[i, 4].set(-1e30) 181 | 182 | run = wandb.init( 183 | project="inside-transformer", 184 | config=experimental_params, 185 | ) 186 | assert isinstance(run, wandb.sdk.wandb_run.Run) 187 | 188 | keys_ = next(key_gen).split(max_iters) 189 | for step in range(max_iters): 190 | batch_ = batch_config_.sample(train_data, keys_[step]) 191 | 192 | if step % eval_interval == 0: 193 | loss = train_config_.loss_fn(train_state_.weights, batch_) 194 | print(f"\n===[ step {step} is an eval step ]==========") 195 | print(f"before step {step}, batch loss {loss}") 196 | 197 | train_state_ = train_config_.train1(train_state_, batch_) 198 | if step % eval_interval == 0: 199 | loss = train_config_.loss_fn(train_state_.weights, batch_) 200 | print(f"after step {step}, batch loss {loss}") 201 | results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, 202 | {'train': partial(batch_config_.sample, train_data), 203 | 'val': partial(batch_config_.sample, valid_data)}) 204 | generate_f = partial(rwkv_rnn, train_state_.weights) 205 | generated = nlp_utils.rnn_generate(generate_f, 206 | "\n", 207 | tokenizer=batch_config_.tokenizer, 208 | argmax=True, 209 | key_gen=key_gen, 210 | init_state=rnn_init_state, 211 | length_per_trial=batch_config_.block_size - 1) 212 | 213 | wandb.log({"train_loss": results['train'], 214 | "validation_loss": results['val'], 215 | "batch_loss": loss, 216 | "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size, 217 | 'generated': wandb.Html(f"{generated}")}) 218 | if step % save_interval == 0: # and step > 0: 219 | n_tokens_trained = step * batch_config_.batch_size * batch_config_.block_size 220 | n_tokens_trained_str = num_short_form(n_tokens_trained) 221 | wandb.save(save_pytree_(train_state_.weights, run.dir, f"{model_name}_{n_tokens_trained_str}"), run.dir) 222 | 223 | wandb.finish() 224 | 225 | # [Done] add trainable weights (via gradient mask) 226 | # [TODO] add trainable weights via 
class SafeKey:
    """Safety wrapper for PRNG keys.

    A wrapped key may be consumed exactly once, by ``get``, ``split`` or
    ``duplicate``. Any later call to one of these raises ``RuntimeError``, so
    the same random key cannot silently be reused.
    """

    def __init__(self, key: 'PRNGKeyArray'):
        self._key = key
        self._used = False

    def _assert_not_used(self) -> None:
        """Raise ``RuntimeError`` if this key has already been consumed."""
        if self._used:
            raise RuntimeError('Random key has been used previously.')

    def get(self) -> 'PRNGKeyArray':
        """Consume the wrapper and return the raw key."""
        self._assert_not_used()
        self._used = True
        return self._key

    def split(self, num_keys=2) -> Tuple['SafeKey', ...]:
        """Consume the wrapper and return ``num_keys`` fresh wrapped keys."""
        self._assert_not_used()
        self._used = True
        new_keys = jax.random.split(self._key, num_keys)
        # Wrap each sub-key directly instead of going through `jax.tree_map`,
        # a top-level alias that newer JAX releases no longer provide.
        return tuple(SafeKey(sub_key) for sub_key in new_keys)

    def duplicate(self, num_keys=2) -> Tuple['SafeKey', ...]:
        """Consume the wrapper and return ``num_keys`` wrappers of the SAME key."""
        self._assert_not_used()
        self._used = True
        return tuple(SafeKey(self._key) for _ in range(num_keys))
@partial(jax.jit, static_argnums=(0, 1), inline=True)
def jax_calc_updates(
        optimizer: optax.GradientTransformation,
        loss_fn: Callable[[Weights, Batch], Arr],
        weights: Weights,
        train_mask: Optional[WeightsTree],
        batch: Batch,
        opt_state: optax.OptState
) -> tuple[Weights, optax.OptState]:
    """Compute one optimiser step and return the updated weights and state.

    The function is jit-compiled with ``optimizer`` and ``loss_fn`` as static
    arguments (positions 0 and 1), so they must be passed positionally.

    :param optimizer: optax gradient transformation providing ``update``
    :param loss_fn: called as ``loss_fn(weights, batch)``; differentiated with
        respect to its first argument
    :param weights: current weights pytree
    :param train_mask: optional pytree multiplied leaf-by-leaf into the
        gradients (presumably 0/1 entries to freeze weights -- confirm with
        the code that builds the mask); ``None`` leaves gradients untouched
    :param batch: batch forwarded to ``loss_fn``
    :param opt_state: current optimiser state
    :return: ``(new_weights, new_opt_state)``
    """
    grads = jax.grad(loss_fn)(weights, batch)
    if train_mask is not None:
        # Use the tree_util spelling: the top-level `jax.tree_map` alias is
        # not available in newer JAX releases.
        grads = jax.tree_util.tree_map(lambda g, m: g * m, grads, train_mask)
    updates, opt_state = optimizer.update(grads, opt_state, weights)
    return optax.apply_updates(weights, updates), opt_state
class TrainConfig(NamedTuple, Generic[W]):
    """Pairs a loss function with the optimiser used to minimise it."""
    # Called as loss_fn(weights, batch).
    loss_fn: Callable[[W, BatchType], Arr]
    # Handed to jax_calc_updates together with loss_fn.
    optimiser: GradientTransformation

    def estimate_loss(self,
                      eval_iters: int,
                      rng_key_gen: Iterator[SafeKey],
                      train_state: TrainState,
                      samplers: dict[str, Callable[[SafeKey], BatchType]]) -> dict[str, float]:
        """Average the loss over ``eval_iters`` sampled batches for every sampler.

        One key is drawn from ``rng_key_gen`` per sampler and split into
        ``eval_iters`` sub-keys, one per batch. Each average is printed and
        returned under the sampler's name.
        """
        averages = {}
        for split_name, draw_batch in samplers.items():
            batch_keys = next(rng_key_gen).split(eval_iters)
            running_total = 0
            for batch_key in batch_keys:
                batch_loss = self.loss_fn(train_state.weights, draw_batch(batch_key))
                running_total += batch_loss.item()
            mean_loss = running_total / eval_iters
            averages[split_name] = mean_loss
            print(f"Estimated {split_name} loss: {mean_loss}")
        return averages

    def train1(self, state: TrainState, batch: BatchType) -> TrainState:
        """Apply one optimiser step on ``batch`` and return the updated state."""
        # Arguments stay positional: the first two are static for the jitted helper.
        stepped = jax_calc_updates(self.optimiser,
                                   self.loss_fn,
                                   state.weights,
                                   state.train_mask,
                                   batch,
                                   state.opt_state)
        new_weights, new_opt_state = stepped
        return state.update(weights=new_weights, opt_state=new_opt_state)
def format_num(num, unit):
    """Format ``num`` with one decimal, '_' as the separator, followed by ``unit``.

    A zero decimal is dropped together with the separator, so 1.0 gives
    "1" + unit while 1.5 gives "1_5" + unit.
    """
    text = format(num, ".1f").replace(".", "_")
    text = text.rstrip("0")
    text = text.rstrip("_")
    return text + unit


def num_short_form(num):
    """Return a short, filename-friendly form of ``num`` using K / M / B units.

    Zero gives "0"; magnitudes below 1000 are returned via ``str(num)``;
    larger magnitudes are scaled by 1e3, 1e6 or 1e9 and formatted with
    ``format_num``, keeping a leading '-' for negative values.
    """
    if num == 0:
        return "0"

    magnitude = abs(num)
    if magnitude < 1000:
        return str(num)

    prefix = "-" if num < 0 else ""
    # (exclusive upper bound, divisor, unit), checked from smallest to largest.
    scales = ((1000000, 1000, "K"), (1000000000, 1000000, "M"))
    for upper_bound, divisor, unit in scales:
        if magnitude < upper_bound:
            return prefix + format_num(magnitude / divisor, unit)
    return prefix + format_num(magnitude / 1000000000, "B")
== "10", "Error: expected '10', but got '{}'".format(num_short_form(10)) 27 | assert num_short_form(100) == "100", "Error: expected '100', but got '{}'".format(num_short_form(100)) 28 | assert num_short_form(999) == "999", "Error: expected '999', but got '{}'".format(num_short_form(999)) 29 | assert num_short_form(1000) == "1K", "Error: expected '1K', but got '{}'".format(num_short_form(1000)) 30 | assert num_short_form(1500) == "1_5K", "Error: expected '1_5K', but got '{}'".format( 31 | num_short_form(1500)) 32 | assert num_short_form(1999) == "2K", "Error: expected '2K', but got '{}'".format(num_short_form(1999)) 33 | assert num_short_form(1000000) == "1M", "Error: expected '1M', but got '{}'".format( 34 | num_short_form(1000000)) 35 | assert num_short_form(1500000) == "1_5M", "Error: expected '1_5M', but got '{}'".format( 36 | num_short_form(1500000)) 37 | assert num_short_form(999999999) == "1000M", "Error: expected '1000M', but got '{}'".format( 38 | num_short_form(999999999)) 39 | assert num_short_form(1000000000) == "1B", "Error: expected '1B', but got '{}'".format( 40 | num_short_form(1000000000)) 41 | assert num_short_form(1500000000) == "1_5B", "Error: expected '1_5B', but got '{}'".format( 42 | num_short_form(1500000000)) 43 | assert num_short_form(-1000) == "-1K", "Error: expected '-1K', but got '{}'".format( 44 | num_short_form(-1000)) 45 | assert num_short_form(-1500) == "-1_5K", "Error: expected '-1_5K', but got '{}'".format( 46 | num_short_form(-1500)) 47 | assert num_short_form(-1999) == "-2K", "Error: expected '-2K', but got '{}'".format( 48 | num_short_form(-1999)) 49 | assert num_short_form(-1000000) == "-1M", "Error: expected '-1M', but got '{}'".format( 50 | num_short_form(-1000000)) 51 | assert num_short_form(-1500000) == "-1_5M", "Error: expected '-1_5M', but got '{}'".format( 52 | num_short_form(-1500000)) 53 | assert num_short_form(-999999999) == "-1000M", "Error: expected '-1000M', but got '{}'".format( 54 | num_short_form(-999999999)) 55 
| assert num_short_form(-1000000000) == "-1B", "Error: expected '-1B', but got '{}'".format( 56 | num_short_form(-1000000000)) 57 | assert num_short_form(-1500000000) == "-1_5B", "Error: expected '-1_5B', but got '{}'".format( 58 | num_short_form(-1500000000)) 59 | print("All tests passed") 60 | 61 | 62 | test_convert_num_to_str() 63 | -------------------------------------------------------------------------------- /saves/view_vec.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec.npy -------------------------------------------------------------------------------- /saves/view_vec2_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict -------------------------------------------------------------------------------- /saves/view_vec2_dict_jit: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict_jit -------------------------------------------------------------------------------- /saves/view_vec2_dict_new: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict_new -------------------------------------------------------------------------------- /saves/view_vec2_step1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_step1.npy -------------------------------------------------------------------------------- /saves/view_vecs_dict: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vecs_dict --------------------------------------------------------------------------------