├── .gitignore
├── .idea
├── .gitignore
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── other.xml
├── rwkv-decon.iml
└── vcs.xml
├── README.md
├── archive
├── gpt_configs.py
├── gpt_decon.py
├── notebook.py
└── pico_utils.py
├── copy_init
├── copy_rwkv.py
├── test_weights.py
└── weights.py
├── custom_dataset.py
├── custom_dataset_str.py
├── gpt_recon
├── LICENSE
├── attn.html
├── bpe_encoder.py
├── chart4.svg
├── clean_frame.py
├── clean_frame_utils.py
├── gpt.py
├── lines.svg
├── load_model.py
├── run_gpt.py
├── train_gpt.py
├── train_notebook.py
├── tune_gpt.py
├── view_attn.py
├── view_notebook.py
└── view_vec.py
├── labels.py
├── lra_arena
├── lra_utils.py
├── rwkv_lra.py
└── use_lra_demo.py
├── mypy.ini
├── nlp_utils.py
├── pico_lru
├── pico_lru.py
├── pico_lru_parallel.py
└── pico_rnn.py
├── pico_rwkv
├── associative_scan_toolbox.py
├── close_to_original
│ ├── pico_rwkv.py
│ └── run_rwkv.py
├── jax_load.py
├── pico_rwkv_parallel.py
├── pico_rwkv_parallel_alternatives.py
├── pico_rwkv_rnn.py
├── pico_rwkv_weights.py
├── pth_to_safet.py
├── run_rwkv_parallel.py
├── rwkv_torch.py
├── rwkv_weight_profile.py
└── train_rwkv.py
├── pico_s5
└── pico_s5.py
├── picojax
├── jax_utils.py
├── random_utils.py
└── train_utils.py
├── python_utils.py
└── saves
├── view_vec.npy
├── view_vec2_dict
├── view_vec2_dict_jit
├── view_vec2_dict_new
├── view_vec2_step1.npy
└── view_vecs_dict
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # wandb/jax
132 | jax-trace/*
133 | wandb/*
134 | wandb
135 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/rwkv-decon.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # rwkv-decon
2 | Trying to deconstruct RWKV in understandable terms
3 |
--------------------------------------------------------------------------------
/archive/gpt_configs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from functools import cache
5 | from typing import NamedTuple, Optional, cast, TypedDict, List, Literal, Union
6 |
7 | from chex import assert_shape
8 | from einops import rearrange
9 | from jax import numpy as jnp
10 | from jax.nn import softmax
11 | from jax.experimental.maps import xmap
12 | from safetensors.flax import save_file
13 |
14 | from clean_frame import Linear, for_all_T, gelu, LN
15 | from clean_frame_utils import check_config, WeightConfigDict, PartsDict, WeightsTree, config_weights_check, Arr, \
16 | WeightConfig
17 |
18 |
class GptMha:
    """Config-driven multi-head attention holder: builds the QKV and output projections."""

    class Config(NamedTuple):
        """Hyper-parameters for `GptMha`; the nested linear configs are sized by `fill`."""
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        inf_mask: float = -1e10
        linear: Linear.Config = Linear.Config()
        QKV_linear: Linear.Config = Linear.Config()

        name: str = "mha"

        @property
        @cache
        def T(self) -> Optional[int]:
            """Sequence length, or None when `n_seq` is the string 'dynamic'."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        def fill(self) -> GptMha.Config:
            """Return a copy whose linear configs have their in/out sizes derived from `n_channels`."""
            assert self.n_channels is not None, 'n_channels must be set'
            width = self.n_channels
            out_proj = self.linear._replace(n_in=width, n_out=width)
            # The fused projection emits three vectors of `width` channels each.
            qkv_proj = self.QKV_linear._replace(n_in=width, n_out=3 * width)
            filled = self._replace(linear=out_proj, QKV_linear=qkv_proj)
            check_config(filled)
            return filled

        def make(self) -> GptMha:
            """Instantiate a `GptMha` from the filled configuration."""
            return GptMha(self.fill())

        @property
        def dim_heads(self) -> int:
            """Channels per head (`n_channels // n_heads`)."""
            assert self.n_channels is not None
            assert self.n_heads is not None
            return self.n_channels // self.n_heads

        @property
        def weights(self) -> WeightConfigDict:
            """This module declares no weights of its own; they all live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """Filled sub-configs, keyed by part name."""
            filled = self.fill()
            return {
                'linear': filled.linear,
                'QKV_linear': filled.QKV_linear,
            }

        def weights_check(self, w: WeightsTree) -> GptMha.Weights:
            """Run `config_weights_check` on `w` and cast the result to `GptMha.Weights`."""
            checked = config_weights_check(self, w)
            return cast(GptMha.Weights, checked)

    class Weights(TypedDict):
        QKV_linear: Linear.Weights
        linear: Linear.Weights

    def __init__(self, config: Config):
        """Copy the sizes out of `config` and build both projections."""
        assert config.n_channels is not None
        assert config.n_heads is not None
        assert config.dim_heads is not None
        self.n_channels = config.n_channels
        self.n_heads = config.n_heads
        self.T = config.T
        self.dim_heads = config.dim_heads
        self.config = config

        self.linear = config.linear.make()
        self.QKV_linear = config.QKV_linear.make()
        self.scale = math.sqrt(self.dim_heads)
        # `for_all_T` is imported from clean_frame; presumably it maps a
        # per-vector function over the sequence axis - TODO confirm.
        self.linearf = for_all_T(self.linear.f)
        self.QKV_linearf = for_all_T(self.QKV_linear.f)
        assert self.n_channels % self.n_heads == 0, 'n_channels must be divisible by n_heads'

    def get_mask(self, t: int) -> Arr:
        """Return a (t, t) additive mask: 0 on and below the diagonal, `inf_mask` above it."""
        strictly_upper = 1 - jnp.tri(t)
        return strictly_upper * self.config.inf_mask
92 |
93 |
class GptFfn:
    """Config-driven feed-forward holder made of two linear layers."""

    class Config(NamedTuple):
        """Hyper-parameters for `GptFfn`; the two linear configs are sized by `fill`."""
        n_channels: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        linear1: Linear.Config = Linear.Config()
        linear2: Linear.Config = Linear.Config()
        name: str = "ffn"

        @property
        @cache
        def T(self) -> Optional[int]:
            """Sequence length, or None when `n_seq` is the string 'dynamic'."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        def fill(self) -> GptFfn.Config:
            """Return a copy whose layers map n_channels -> 4 * n_channels -> n_channels."""
            assert self.n_channels is not None
            width = self.n_channels
            hidden = width * 4
            expand = self.linear1._replace(n_in=width, n_out=hidden)
            contract = self.linear2._replace(n_in=hidden, n_out=width)
            filled = self._replace(linear1=expand, linear2=contract)
            check_config(filled)
            return filled

        def make(self) -> GptFfn:
            """Instantiate a `GptFfn` from the filled configuration."""
            return GptFfn(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """This module declares no weights of its own; they all live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """Filled sub-configs, keyed by part name."""
            filled = self.fill()
            return {
                'linear1': filled.linear1,
                'linear2': filled.linear2,
            }

        def weights_check(self, w: WeightsTree) -> GptFfn.Weights:
            """Run `config_weights_check` on `w` and cast the result to `GptFfn.Weights`."""
            checked = config_weights_check(self, w)
            return cast(GptFfn.Weights, checked)

    class Weights(TypedDict):
        linear1: Linear.Weights
        linear2: Linear.Weights

    def __init__(self, config: Config):
        """Record the channel count and build both linear layers."""
        self.n_channels = config.n_channels
        self.linear1 = config.linear1.make()
        self.linear2 = config.linear2.make()
144 |
145 |
class GptBlock:
    """Config-driven transformer block holder: attention, feed-forward and two layer norms."""

    class Config(NamedTuple):
        """Hyper-parameters for `GptBlock`; `fill` pushes the shared ones into each part."""
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        mha: GptMha.Config = GptMha.Config()
        ffn: GptFfn.Config = GptFfn.Config()
        ln1: LN.Config = LN.Config()
        ln2: LN.Config = LN.Config()
        name: str = "gpt_block"

        @property
        def T(self) -> Optional[int]:
            """Sequence length, or None when `n_seq` is the string 'dynamic'."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        @property
        def x_shape(self) -> tuple[Optional[int], ...]:
            """Input shape as `(T, n_channels)`."""
            assert self.n_channels is not None
            return self.T, self.n_channels

        def fill(self) -> GptBlock.Config:
            """Return a copy with eps / n_channels / n_heads / n_seq propagated to every part."""
            # Evaluation order (mha, ffn, then the layer norms) matches the
            # original keyword-argument order so assertions fire identically.
            attn = self.mha._replace(n_channels=self.n_channels, n_seq=self.n_seq, n_heads=self.n_heads).fill()
            feed = self.ffn._replace(n_channels=self.n_channels, n_seq=self.n_seq).fill()
            norm1 = self.ln1._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape)
            norm2 = self.ln2._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape)
            filled = self._replace(mha=attn, ffn=feed, ln1=norm1, ln2=norm2)
            check_config(filled)
            return filled

        def make(self) -> GptBlock:
            """Instantiate a `GptBlock` from the filled configuration."""
            return GptBlock(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """This module declares no weights of its own; they all live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """Filled sub-configs, keyed by part name."""
            filled = self.fill()
            return {
                'mha': filled.mha,
                'ffn': filled.ffn,
                'ln1': filled.ln1,
                'ln2': filled.ln2,
            }

        def weights_check(self, w: WeightsTree) -> GptBlock.Weights:
            """Run `config_weights_check` on `w` and cast the result to `GptBlock.Weights`."""
            checked = config_weights_check(self, w)
            return cast(GptBlock.Weights, checked)

    class Weights(TypedDict):
        mha: GptMha.Weights
        ffn: GptFfn.Weights
        ln1: LN.Weights
        ln2: LN.Weights

    def __init__(self, config: Config):
        """Build every part from `config` and cache the callables used per step."""
        self.T = config.T
        self.n_channels = config.n_channels
        self.mha = config.mha.make()
        self.ffn = config.ffn.make()
        self.ln1 = config.ln1.make()
        self.ln2 = config.ln2.make()

        # NOTE(review): the GptFfn class visible in this file defines no `f`
        # method, so this attribute access looks like it would fail - confirm.
        self.ffnf = for_all_T(self.ffn.f)
        self.ln1f = self.ln1.f
        self.ln2f = self.ln2.f
216 |
217 |
218 |
class GptDecoder:
    """Config-driven stack of `n_blocks` `GptBlock`s that share one block configuration."""

    class Config(NamedTuple):
        """Hyper-parameters for `GptDecoder`; `fill` forwards the shared ones to the block config."""
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        blocks: GptBlock.Config = GptBlock.Config()

        name: str = 'gpt_decoder'

        @property
        @cache
        def T(self) -> Optional[int]:
            """Sequence length, or None when `n_seq` is the string 'dynamic'."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        def fill(self) -> GptDecoder.Config:
            """Return a copy whose block config carries this decoder's eps / sizes."""
            block_cfg = self.blocks._replace(eps=self.eps, n_channels=self.n_channels,
                                             n_heads=self.n_heads, n_seq=self.n_seq).fill()
            filled = self._replace(blocks=block_cfg)
            check_config(filled)
            return filled

        def make(self) -> GptDecoder:
            """Instantiate a `GptDecoder` from the filled configuration."""
            return GptDecoder(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """This module declares no weights of its own; they all live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """One entry, 'blocks': the same filled block config repeated `n_blocks` times."""
            filled = self.fill()
            assert filled.blocks is not None
            assert filled.n_blocks is not None
            repeated = [filled.blocks] * filled.n_blocks
            return {'blocks': repeated}

        def weights_check(self, w: WeightsTree) -> GptDecoder.Weights:
            """Run `config_weights_check` on `w` and cast the result to `GptDecoder.Weights`."""
            checked = config_weights_check(self, w)
            return cast(GptDecoder.Weights, checked)

    class Weights(TypedDict):
        blocks: List[GptBlock.Weights]

    def __init__(self, config: Config):
        """Build `n_blocks` independent block instances from the shared block config."""
        assert config.n_blocks is not None
        self.T = config.T
        self.n_channels = config.n_channels
        stack = []
        for _ in range(config.n_blocks):
            stack.append(config.blocks.make())
        self.blocks = stack
271 |
272 |
273 |
class Gpt:
    """Config-driven top-level GPT holder: embeddings config, decoder stack and final layer norm."""

    class Config(NamedTuple):
        """Hyper-parameters for `Gpt`; `fill` propagates the shared ones to decoder and ln."""
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        n_tokens: Optional[int] = None
        max_seq_len: Optional[int] = None

        te_name: str = 'te'
        te_init: Literal['normal'] = 'normal'
        te_scale: float = 0.02

        decoder: GptDecoder.Config = GptDecoder.Config()
        ln: LN.Config = LN.Config()

        pe_name: str = 'pe'

        name: str = 'gpt'

        @property
        def T(self) -> Optional[int]:
            """Sequence length, or None when `n_seq` is the string 'dynamic'."""
            if self.n_seq == 'dynamic':
                return None
            else:
                return self.n_seq

        def fill(self) -> Gpt.Config:
            """Return a copy with the shared hyper-parameters pushed into `decoder` and `ln`."""
            assert self.n_channels is not None, 'n_channels must be specified'
            new = self._replace(decoder=self.decoder._replace(eps=self.eps, n_channels=self.n_channels,
                                                              n_heads=self.n_heads, n_seq=self.n_seq,
                                                              n_blocks=self.n_blocks).fill(),
                                ln=self.ln._replace(eps=self.eps, norm_dims=(0,), x_shape=(self.T, self.n_channels)))

            check_config(new)
            return new

        def make(self) -> Gpt:
            """Instantiate a `Gpt` from the filled configuration."""
            return Gpt(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """Weight configs for the token embedding and the positional encoding."""
            filled = self.fill()
            assert filled.max_seq_len is not None
            assert filled.n_tokens is not None
            assert filled.n_channels is not None
            return dict(
                token_embedding=WeightConfig(name=filled.te_name,
                                             init=filled.te_init,
                                             shape=(filled.n_tokens, filled.n_channels),
                                             scale=filled.te_scale),

                positional_encoding=WeightConfig(name=filled.pe_name,
                                                 shape=(filled.max_seq_len, filled.n_channels)),
            )

        @property
        def parts(self) -> PartsDict:
            """Filled sub-configs, keyed by part name."""
            filled = self.fill()
            assert filled.decoder is not None
            assert filled.ln is not None
            assert filled.te_name is not None
            return dict(
                decoder=filled.decoder,
                ln=filled.ln,
            )

        def weights_check(self, w: WeightsTree) -> Gpt.Weights:
            """Run `config_weights_check` on `w` and cast the result to `Gpt.Weights`."""
            return cast(Gpt.Weights, config_weights_check(self, w))

    class Weights(TypedDict):
        # Previously missing: `weights_check` evaluates `Gpt.Weights` at runtime
        # inside `cast(...)`, which raised AttributeError without this class.
        # Field names mirror the keys returned by `Config.weights` and
        # `Config.parts`. NOTE(review): confirm they match the tree that
        # `config_weights_check` actually produces.
        token_embedding: Arr
        positional_encoding: Arr
        decoder: GptDecoder.Weights
        ln: LN.Weights

    def __init__(self, config: Config):
        """Copy the sizes out of `config` and build the decoder and final layer norm."""
        assert config.n_blocks is not None
        assert config.n_tokens is not None
        assert config.n_channels is not None
        self.T = config.T
        self.n_channels = config.n_channels
        self.n_tokens = config.n_tokens
        self.eps = config.eps
        self.decoder = config.decoder.make()
        self.ln = config.ln.make()
355 |
356 |
--------------------------------------------------------------------------------
/archive/gpt_decon.py:
--------------------------------------------------------------------------------
1 | #%%
2 | from __future__ import annotations
3 |
4 | import math
5 | from typing import Callable
6 |
7 | import jax
8 | from jax.experimental.maps import xmap
9 | from optax import softmax_cross_entropy_with_integer_labels
10 | from simple_pytree import Pytree, static_field
11 | from jax.numpy import mean, var, sqrt, tanh, pi
12 | from jax.nn import softmax
13 | from jax.lax import rsqrt
14 | import jax.numpy as jnp
15 |
16 | Arr = jax.Array
17 |
18 |
def gelu(x: Arr) -> Arr:
    """Tanh-based GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    inner = sqrt(2 / pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1 + tanh(inner))
21 |
22 |
class LN(Pytree):
    """Layer normalisation of a single vector with scale `w` and shift `b`."""
    eps: float = static_field()

    def __init__(self, eps: float, w: Arr, b: Arr):
        """Store the stabilising epsilon and the scale/shift parameters."""
        self.eps = eps
        self.w = w
        self.b = b

    # x: (n_channels,)
    def f(self, x: Arr) -> Arr:
        """Subtract the mean of `x`, scale by `w / sqrt(var(x) + eps)`, then add `b`."""
        centred = x - mean(x)
        gain = self.w * rsqrt(var(x) + self.eps)
        return centred * gain + self.b
36 |
37 |
class Linear(Pytree):
    """Affine map `w @ x + b`."""

    def __init__(self, w: Arr, b: Arr):
        """Store the weight matrix `w` and the bias `b`."""
        self.w, self.b = w, b

    # x: (n_channels,)
    def f(self, x: Arr) -> Arr:
        """Apply the affine map to one vector."""
        projected = self.w @ x
        return projected + self.b
46 |
47 |
class GptMha(Pytree):
    """Multi-head causal self-attention over an input laid out as (n_channels, T).

    Fixes relative to the archived version:
    * `scale` is the square root of the per-head width, not of `n_channels`
      (standard scaled dot-product attention).
    * `dim_head` and `T` were declared as static fields but never assigned.
    * The mask was `tril(ones)`, which adds 1 to the visible scores and masks
      nothing. It is now 0 on/below the diagonal and -1e10 above it.
    * The weighted sum is `v @ weights.T`, so the result has the
      `(dim_head, T)` shape the surrounding comments describe. The previous
      `softmax(...) @ v` multiplied a (T, T) matrix by a (dim_head, T) one.
    """
    n_heads: int = static_field()
    scale: float = static_field()
    dim_head: int = static_field()
    T: int = static_field()
    mask: Arr = static_field()
    linearf: Callable[[Arr], Arr] = static_field()

    # q, k, v: (dim_head, T)
    def causal_dot_attention(self, q: Arr, k: Arr, v: Arr) -> Arr:
        """Single-head causal attention; returns an array of shape (dim_head, T)."""
        # scores[t, s]: how much query position t attends to key position s.
        scores = (q.T @ k) / self.scale + self.mask
        weights = softmax(scores)  # normalised over key positions (last axis)
        return v @ weights.T

    # QKV: (3, n_heads, dim_head, n_channels)
    def __init__(self, n_channels: int, n_heads: int, T: int, QKV: Arr, linear: Linear):
        """Store the projections and precompute the scale and the causal mask for length `T`."""
        self.n_heads = n_heads
        self.dim_head = n_channels // n_heads
        self.T = T
        self.scale = math.sqrt(self.dim_head)
        self.QKV = QKV
        self.linear = linear
        # Additive mask: 0 where key position <= query position, -1e10 otherwise.
        self.mask = (1 - jnp.tri(T)) * -1e10
        self.linearf = xmap(linear.f, [None, 'T'], 'T')

    # x: (n_channels, T)
    def f(self, x: Arr) -> Arr:
        """Project to per-head q/k/v, attend per head, concatenate heads, apply the output map."""
        q, k, v = self.QKV @ x
        xmap_in = [['n_head', ...]] * 3
        attended = xmap(self.causal_dot_attention, xmap_in, 'n_head')(q, k, v)
        # attended: (n_heads, dim_head, T)

        return self.linearf(jnp.concatenate(attended))
77 |
78 |
class GptFfn(Pytree):
    """Feed-forward sub-layer: linear1, then GELU, then linear2."""

    def __init__(self, linear1: Linear, linear2: Linear):
        """Store the two linear layers."""
        self.linear1, self.linear2 = linear1, linear2

    # x: (n_channels,)
    def f(self, x: Arr) -> Arr:
        """Apply linear1, GELU and linear2 to the transpose of `x`, then transpose the result back."""
        hidden = self.linear1.f(x.T)
        activated = gelu(hidden)
        return self.linear2.f(activated).T
87 |
88 |
class GptBlock(Pytree):
    """Pre-norm transformer block: x += mha(ln1(x)); x += ffn(ln2(x))."""
    ln1f: Callable[[Arr], Arr] = static_field()
    ln2f: Callable[[Arr], Arr] = static_field()
    ffnf: Callable[[Arr], Arr] = static_field()

    def __init__(self, mha: GptMha, ffn: GptFfn, ln1: LN, ln2: LN):
        """Store the sub-modules and wrap the per-vector functions with `xmap` over the 'T' axis."""
        self.mha = mha
        self.ffn = ffn
        self.ln1 = ln1
        self.ln2 = ln2
        per_position_in, per_position_out = [None, 'T'], 'T'
        self.ln1f = xmap(ln1.f, per_position_in, per_position_out)
        self.ln2f = xmap(ln2.f, per_position_in, per_position_out)
        self.ffnf = xmap(ffn.f, per_position_in, per_position_out)

    # x: (n_channels, T)
    def f(self, x: Arr) -> Arr:
        """Run the attention and feed-forward residual branches in sequence."""
        attn_branch = self.mha.f(self.ln1f(x))
        x += attn_branch
        ffn_branch = self.ffnf(self.ln2f(x))
        x += ffn_branch
        return x
108 |
109 |
class GptDecoder(Pytree):
    """Sequential stack of `GptBlock`s."""

    def __init__(self, blocks: list[GptBlock]):
        """Store the blocks in application order."""
        self.blocks = blocks

    # x: (n_channels, T)
    def f(self, x: Arr) -> Arr:
        """Feed `x` through every block, first to last."""
        hidden = x
        for layer in self.blocks:
            hidden = layer.f(hidden)
        return hidden
119 |
120 |
class Gpt(Pytree):
    """Top-level model: token + position embedding, decoder stack, final layer norm."""
    T: int = static_field()

    def __init__(self, T: int, decoder: GptDecoder, token_embed: Arr, position_embed: Arr, ln: LN):
        """Store the decoder, final layer norm, both embedding tables and the length `T`."""
        self.decoder = decoder
        self.ln = ln
        self.te = token_embed
        self.pe = position_embed
        self.T = T

    # x: integer token ids (it is used to index `self.te`)
    def f(self, x: Arr) -> Arr:
        """Embed the token ids, add one positional row per position, decode and normalise.

        The previous code added `self.pe[self.T]`, a single row of the
        positional table, so every position received the same offset. Each
        position now gets its own row.
        """
        positions = jnp.arange(x.shape[0])
        x = self.te[x] + self.pe[positions]
        # NOTE(review): `self.te[x]` is (T, n_channels) while the decoder
        # comments describe a (n_channels, T) layout, and `ln.f` is written
        # for a single vector - confirm the intended layout before reuse.
        return self.ln.f(self.decoder.f(x))
135 |
136 |
def gpt_loss(gpt: Gpt, inputs: list[int], labels: list[int]) -> Arr:
    """Cross-entropy between the model output for `inputs` and the integer `labels`.

    Returns whatever `softmax_cross_entropy_with_integer_labels` returns for
    the logits; no reduction is applied here.
    """
    logits = gpt.f(jnp.array(inputs))
    # The optax loss indexes `labels` as an array, so a plain Python list
    # must be converted first, exactly as is already done for `inputs`.
    return softmax_cross_entropy_with_integer_labels(logits, jnp.array(labels))
140 |
141 |
142 |
--------------------------------------------------------------------------------
/archive/notebook.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | """
6 | mk o i g u K R V mv mv
7 | fk fi fK fR fv
8 | ln0b ln0w ln1b ln1w ln2b ln2w
9 | """
10 |
# Short aliases used by the formulas below.
sigma = jax.nn.sigmoid
maxi = jnp.maximum
e = jnp.exp
relu = jax.nn.relu
15 |
16 |
def mix(_in, _state, _mix):
    """Blend `_in` and `_state`: `_mix` weights the input, `1 - _mix` weights the state."""
    from_input = _in * _mix
    from_state = _state * (1 - _mix)
    return from_input + from_state
19 |
20 |
21 | """
22 | x: 1024
23 | s: 1024
24 |
25 | """
26 |
27 |
def ln(x, w, b):
    """Layer-normalise `x` (eps = 1e-5), then scale by `w` and shift by `b`."""
    centred = x - jnp.mean(x)
    gain = w * jax.lax.rsqrt(jnp.var(x) + 1e-5)
    return centred * gain + b
34 |
35 |
def mix_t(x, s, K, V, R, mk, o, mr, g, u, mv, aa, bb, pp):
    """Time-mixing step: compute the WKV output for `x` and update the running state.

    `aa`, `bb` and `pp` are the running numerator, denominator and exponent
    offset. `u` is added to the current key for the output term and `g` is
    added to `pp` for the state update. State objects are modified through
    their `update_` method, which is not defined in the visible source
    (presumably an in-place write - TODO confirm).

    Fix: the state update previously reused `ww = u + k` from the output
    computation. The carried-over state must instead be weighted by the
    decayed offset `pp + g`, which is the quantity the second `maxi`
    already uses.
    """
    xk, xv, xr = mix(x, s, mk), mix(x, s, mv), mix(x, s, mr)
    s.update_(x)

    k = xk @ K
    v = xv @ V

    # Output term: the current token is weighted by exp(u + k); `p` keeps
    # the exponentials in a safe range.
    ww = u + k
    p = maxi(pp, ww)
    wkv = (e(pp - p) * aa + e(ww - p) * v) / (e(pp - p) * bb + e(ww - p))

    # State update: the old state is weighted by exp(pp + g), the current
    # token by exp(k).
    ww = pp + g
    p = maxi(ww, k)
    aa.update_(e(ww - p) * aa + e(k - p) * v)
    bb.update_(e(ww - p) * bb + e(k - p))
    pp.update_(p)

    return (sigma(xr @ R) * wkv) @ o
53 |
54 |
def mix_h(x, ss, K, R, mk, mr, V):
    """Channel-mixing step: sigmoid gate times a squared-ReLU feed-forward; stores `x` into `ss`."""
    xk = mix(x, ss, mk)
    xr = mix(x, ss, mr)
    ss.update_(x)
    gate = sigma(xr @ R)
    # `**` binds tighter than `@`: the activation is squared before the V matmul.
    value = relu(xk @ K) ** 2 @ V
    return gate * value
59 |
60 |
def cell(x, s, K, V, R, mk, o, mr, g, u, mv, aa, bb, pp, ln1b, ln1w, ln2b, ln2w, fK, fR, mfk, mfr, fV):
    """One layer: layer norm, time mixing (residual), layer norm, channel mixing (residual)."""
    normed = ln(x, ln1w, ln1b)
    x += mix_t(normed, s, K, V, R, mk, o, mr, g, u, mv, aa, bb, pp)
    x = ln(x, ln2w, ln2b)
    # NOTE(review): the same state `s` is handed to both mixers - confirm intended.
    channel_out = mix_h(x, s, fK, fR, mfk, mfr, fV)
    return x + channel_out
67 |
--------------------------------------------------------------------------------
/archive/pico_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 |
5 | import numpy as np
6 | import requests
7 | import tensorflow as tf
8 | from tqdm import tqdm
9 |
10 | from bpe_encoder import get_encoder
11 |
12 |
def download_gpt2_files(model_size, model_dir):
    """Download the GPT-2 checkpoint files for `model_size` into `model_dir`.

    Args:
        model_size: one of "124M", "355M", "774M", "1558M".
        model_dir: existing directory the files are written to.

    Raises:
        AssertionError: if `model_size` is not one of the known sizes.
        requests.HTTPError: if any download returns an error status.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]
    url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
    chunk_size = 1000
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        # The context manager releases the streamed connection even if
        # writing fails part-way through.
        with requests.get(f"{url}/{model_size}/{filename}", stream=True) as r:
            r.raise_for_status()

            with open(os.path.join(model_dir, filename), "wb") as f:
                file_size = int(r.headers["content-length"])
                with tqdm(
                    ncols=100,
                    desc="Fetching " + filename,
                    total=file_size,
                    unit_scale=True,
                    unit="b",
                ) as pbar:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # Advance by the bytes actually received: the final
                        # chunk is usually shorter than `chunk_size`.
                        pbar.update(len(chunk))
42 |
43 |
def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
    """Read every variable of a TF checkpoint into a nested dict of squeezed arrays.

    Variables named `model/h<N>/...` go into `params["blocks"][N]`; all other
    variables are nested under `params` itself, following the `/`-separated
    remainder of their name.
    """

    def assign_path(tree, path, value):
        # Walk or create the intermediate dicts, then store the leaf.
        node = tree
        for part in path[:-1]:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[path[-1]] = value

    prefix_len = len("model/")
    params = {"blocks": [{} for _ in range(hparams["n_layer"])]}
    for var_name, _ in tf.train.list_variables(tf_ckpt_path):
        array = np.squeeze(tf.train.load_variable(tf_ckpt_path, var_name))
        short_name = var_name[prefix_len:]
        if short_name.startswith("h"):
            match = re.match(r"h([0-9]+)/(.*)", short_name)
            layer_idx = int(match[1])
            assign_path(params["blocks"][layer_idx], match[2].split("/"), array)
        else:
            assign_path(params, short_name.split("/"), array)

    return params
66 |
67 |
def load_encoder_hparams_and_params(model_size, models_dir):
    """Return `(encoder, hparams, params)` for a GPT-2 model, downloading it if needed.

    Args:
        model_size: one of "124M", "355M", "774M", "1558M".
        models_dir: parent directory; files live in `models_dir/model_size`.

    Raises:
        AssertionError: if `model_size` is not one of the known sizes.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if not tf_ckpt_path:  # download files if necessary
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)
        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)

    encoder = get_encoder(model_size, models_dir)
    # Use a context manager so the file handle is closed deterministically;
    # the previous `json.load(open(...))` left that to the garbage collector.
    with open(os.path.join(model_dir, "hparams.json")) as hparams_file:
        hparams = json.load(hparams_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

    return encoder, hparams, params
--------------------------------------------------------------------------------
/copy_init/copy_rwkv.py:
--------------------------------------------------------------------------------
1 | #%%
2 | from __future__ import annotations
3 |
4 | from pathlib import Path
5 |
6 | from copy_init.weights import get_normal_weights_config_, init
7 | from picojax.random_utils import infinite_safe_keys
8 |
# Hard-coded location and name of the checkpoint this cell works on.
path = Path("/Data/lm_models/rwkv")
model_name = 'RWKV-4-Pile-430M-20220808-8066'

# Per-tensor weight configs for the checkpoint, built by copy_init.weights.
weight_infos = get_normal_weights_config_(path, model_name)

# Take the first key from the generator (presumably seeded with 0 - TODO confirm).
keygen = infinite_safe_keys(0)
key = next(keygen)

# Draw a fresh set of weights from the configs.
w = init(weight_infos, rng_key=key)

print(w)
--------------------------------------------------------------------------------
/copy_init/test_weights.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from __future__ import annotations
3 |
4 | from pathlib import Path
5 |
6 | import jax.numpy as np
7 | from jax.tree_util import tree_flatten
8 |
9 | from copy_init.weights import get_normal_weights_config_, init, save_pytree_, load_pytree_
10 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight
11 | from picojax.jax_utils import WeightsTree
12 | from picojax.random_utils import infinite_safe_keys
13 |
# Hard-coded location and name of the checkpoint this round-trip check uses.
model_path = Path("/Data/lm_models/rwkv")
# model_name = 'RWKV-4-Pile-430M-20220808-8066'
model_name = 'RWKV-4-Pile-169M-20220807-8023'

# Draw freshly initialised weights from the checkpoint's weight configs.
weight_infos = get_normal_weights_config_(model_path, model_name)
keygen = infinite_safe_keys(0)
key = next(keygen)
init_weights_raw = init(weight_infos, rng_key=key)
# Convert the flat name -> array mapping into a weights tree.
init_weights_: WeightsTree = parse_rwkv_weight(init_weights_raw.keys(), init_weights_raw.__getitem__, trim=True)
# Keep the tree structure; load_pytree_ takes it to rebuild the tree.
_, tree_struct = tree_flatten(init_weights_)
# Save to the current directory, then load back for comparison below.
f = save_pytree_(init_weights_, ".", model_name)
print(f)
ww = load_pytree_(tree_struct, ".", model_name)
27 |
28 |
def nested_dict_equal(a, b):
    """
    Check if two nested dictionaries of arrays and lists are equal.

    Dicts must have identical key sets and lists identical lengths; leaves
    that both expose a `shape` are compared with `array_equal`, everything
    else with `==`. When a dict entry differs, its key and the shapes of the
    two values (None for values without a shape) are printed before
    returning False.
    """
    if isinstance(a, dict) and isinstance(b, dict):
        if a.keys() != b.keys():
            return False
        for key in a.keys():
            if not nested_dict_equal(a[key], b[key]):
                print("kd", key)
                # The differing values may be dicts, lists or scalars, none
                # of which have `.shape`; the old unconditional attribute
                # access crashed here instead of reporting the mismatch.
                print(getattr(a[key], "shape", None), getattr(b[key], "shape", None))
                return False
        return True
    elif isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            return False
        return all(nested_dict_equal(left, right) for left, right in zip(a, b))
    elif hasattr(a, "shape") and hasattr(b, "shape"):
        # Covers array-like pairs of any backend and yields a plain bool
        # rather than a zero-dimensional array.
        return bool(np.array_equal(a, b))
    else:
        return a == b
53 |
54 |
55 | print(nested_dict_equal(init_weights_, ww))
56 |
--------------------------------------------------------------------------------
/copy_init/weights.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from __future__ import annotations
3 |
4 | import json
5 | import os
6 | from pathlib import Path
7 | from typing import Protocol, Literal, Optional, cast
8 |
9 | import jax.numpy as np
10 | from jax import random
11 | from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
12 | from pydantic import ValidationError
13 | from pydantic.dataclasses import dataclass
14 | from pydantic.json import pydantic_encoder
15 | from safetensors import safe_open
16 | from safetensors.flax import save_file
17 |
18 | from labels import Labels, L
19 | from picojax.jax_utils import Arr, WeightsTree
20 | from picojax.random_utils import SafeKey
21 |
22 | WeightConfigType = Literal['normal', 'zero']
23 |
24 |
class WeightConfig(Protocol):
    """Structural interface shared by the weight-config classes in this module."""
    name: str
    tags: Labels
    shape: tuple[int, ...]

    def init(self, rng_key: SafeKey) -> Arr:
        """Produce an array for this weight using `rng_key`."""
        ...

    @classmethod
    def from_arr(cls, name: str, tags: Labels, arr: Arr) -> WeightConfig:
        """Build a config for weight `name` with `tags` from the existing array `arr`."""
        ...
36 |
37 |
@dataclass
class NormalWeight:
    """Weight initialised from a normal distribution with the given mean and scale."""
    name: str
    tags: Labels
    shape: tuple[int, ...]
    mean: float = 0
    scale: float = 1

    def init(self, rng_key: SafeKey) -> Arr:
        """Draw standard-normal samples of `shape`, then multiply by `scale` and add `mean`."""
        standard = random.normal(rng_key.get(), self.shape)
        return standard * self.scale + self.mean

    @classmethod
    def from_arr(cls, name: str, tags: Labels, arr: Arr) -> NormalWeight:
        """Build a config whose shape, mean and scale are the shape, mean and std of `arr`."""
        mu = arr.mean().item()
        sigma = arr.std().item()
        return cls(name=name, tags=tags, shape=arr.shape, mean=mu, scale=sigma)
52 |
53 |
@dataclass
class ZeroWeight:
    """Weight initialised to all zeros."""
    name: str
    tags: Labels
    shape: tuple[int, ...]

    def init(self, rng_key: SafeKey) -> Arr:
        """Return zeros of `shape`; `rng_key` is accepted but not used."""
        zeros = np.zeros(self.shape)
        return zeros

    @classmethod
    def from_arr(cls, name: str, tags: Labels, arr: Arr) -> ZeroWeight:
        """Build a config that keeps only the shape of `arr`."""
        arr_shape = arr.shape
        return cls(name=name, tags=tags, shape=arr_shape)
66 |
67 |
68 |
def get_weight_config(weight_config_type: WeightConfigType, key: str, arr: Arr, tags: Labels) -> WeightConfig:
    """Build the weight config of the requested kind ('normal' or 'zero') from `arr`.

    Raises:
        ValueError: if `weight_config_type` is neither 'normal' nor 'zero'.
    """
    if weight_config_type == 'normal':
        return NormalWeight.from_arr(name=key, tags=tags, arr=arr)
    if weight_config_type == 'zero':
        return ZeroWeight.from_arr(name=key, tags=tags, arr=arr)
    raise ValueError(f"weight_config_type must be 'normal' or 'zero', not {weight_config_type}")
76 |
77 |
def get_normal_weights_config_(path: Path, model_name: str,
                               non_normal_weight_tags: Optional[dict[str, WeightConfigType]] = None) -> dict[
    str, WeightConfig]:
    """Return per-tensor weight configs for `model_name`, building and caching them if needed.

    Configs are first loaded from `<path>/<model_name>.json`. If that file
    does not exist, they are derived from every tensor in
    `<path>/<model_name>.safetensors` and written to the JSON file. Tensors
    are 'normal' by default; `non_normal_weight_tags` maps a tag to the
    config type that tensors carrying the tag should use instead.
    """
    try:
        return load_weight_configs_(path, model_name)
    except FileNotFoundError:
        configs: dict[str, WeightConfig] = {}
        with safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as tensors_file:
            for key in tensors_file.keys():
                tensor = tensors_file.get_tensor(key)
                tags = L(*key.split('.'))
                chosen = 'normal'
                if non_normal_weight_tags is not None:
                    # Later matching entries override earlier ones.
                    for tag, _type in non_normal_weight_tags.items():
                        if (tag in tags.tags) and (_type in ('normal', 'zero')):
                            chosen = _type
                chosen = cast(WeightConfigType, chosen)
                configs[key] = get_weight_config(chosen, key, tensor, tags)

        with open(path / f"{model_name}.json", 'w') as json_file:
            json.dump(configs, json_file, indent=2, default=pydantic_encoder)
        return configs
101 |
102 |
def load_weight_configs_(path: Path, model_name: str) -> dict[str, WeightConfig]:
    """Load the cached weight configs from ``{path}/{model_name}.json``.

    Each JSON entry is turned back into a config object. Entries that carry a
    ``mean`` or ``scale`` key are rebuilt as :class:`NormalWeight`; entries
    with neither are rebuilt as :class:`ZeroWeight`. Previously every entry was
    rebuilt as ``NormalWeight``, so a cached zero weight silently came back as
    a normal weight with the default mean/scale.
    NOTE(review): this relies on the JSON having been produced from the
    dataclass fields (as ``get_normal_weights_config_`` does) - confirm no
    hand-written files omit mean/scale for normal weights.

    Entries that fail validation are printed and skipped (best effort).

    Raises:
        FileNotFoundError: if the JSON file does not exist.
    """
    with open(path / f"{model_name}.json") as f:
        weight_info_dict = json.load(f)
    weight_infos: dict[str, WeightConfig] = {}
    for key, value in weight_info_dict.items():
        # A serialised ZeroWeight only has name/tags/shape.
        is_normal = ('mean' in value) or ('scale' in value)
        config_cls = NormalWeight if is_normal else ZeroWeight
        try:
            weight_infos[key] = config_cls(**value)
        except ValidationError as e:
            print(e)
    return weight_infos
114 |
115 |
def init(w_infos: dict[str, WeightConfig], rng_key: SafeKey) -> dict[str, Arr]:
    """Initialise every weight in ``w_infos``, giving each its own split of ``rng_key``."""
    per_weight_keys = rng_key.split(len(w_infos))
    initialised = {}
    for (weight_name, w_info), sub_key in zip(w_infos.items(), per_weight_keys):
        initialised[weight_name] = w_info.init(sub_key)
    return initialised
120 |
121 |
def get_weights_mask(whitelist: list[Labels], w_infos: dict[str, WeightConfig]) -> dict[str, bool]:
    """Map each weight name to True when its tags cover at least one whitelist entry."""
    mask = {}
    for key, w_info in w_infos.items():
        mask[key] = any(w_info.tags.covers(tag) for tag in whitelist)
    return mask
130 |
131 |
def save_pytree_(w: WeightsTree, checkpoint_path: str, model_name: str) -> str:
    """Flatten ``w`` and save its leaves to ``{checkpoint_path}/{model_name}.safetensors``.

    Leaves are stored under their flatten index, zero-padded to a common
    width. Returns the path of the written file.
    """
    leaves, _ = tree_flatten(w)
    index_width = len(str(len(leaves)))
    named_leaves = {}
    for position, leaf in enumerate(leaves):
        named_leaves[str(position).zfill(index_width)] = leaf
    file_path = os.path.join(checkpoint_path, f"{model_name}.safetensors")
    save_file(named_leaves, file_path)
    return file_path
138 |
139 |
def load_pytree_(t: PyTreeDef, checkpoint_path: str, model_name: str) -> WeightsTree:
    """Load ``{checkpoint_path}/{model_name}.safetensors`` and rebuild a pytree of structure ``t``.

    The tensors are read in sorted key order. ``save_pytree_`` stores leaves
    under zero-padded indices, so sorted order is the original flatten order
    regardless of the order in which the file reports its keys. All tensors
    are materialised while the file is still open, instead of being pulled
    lazily from a generator that may outlive the file handle.
    """
    file_path = os.path.join(checkpoint_path, f"{model_name}.safetensors")
    with safe_open(file_path, framework="flax") as f:
        leaves = [f.get_tensor(k) for k in sorted(f.keys())]
    return tree_unflatten(t, leaves)
145 |
--------------------------------------------------------------------------------
/custom_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import Counter
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import NamedTuple, Callable
6 |
7 | import jax.numpy as xp
8 |
9 | # dataset = "play"
10 | base_path = Path("/Data/nlp/")
11 |
def load(dataset: str = "english"):
    """Read every ``.txt`` file of a dataset and build a character-level vocabulary.

    Files are looked up in ``base_path / dataset``. Each file is decoded with
    the first encoding in a fixed list that succeeds. The texts are joined
    (each prefixed with its file name) into one string.

    Returns:
        ``(text, encode, decode, vocab_size)`` where ``encode`` maps a string to
        a list of character ids, ``decode`` maps ids back to a string, and ids
        are assigned in order of decreasing character frequency.

    Raises:
        Exception: if a file cannot be decoded with any of the encodings.
    """
    path = base_path / dataset
    books = [f for f in os.listdir(path) if f.endswith('.txt')]
    # Same order as the original nested fallbacks.
    encodings = ('utf-8', 'gb2312', 'gbk', 'big5', 'utf-16', 'gb18030')

    def read_book_file(filename: str):
        print(f"reading {filename}...")
        for encoding in encodings:
            try:
                with open(path / filename, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeError:
                # UnicodeError (the parent of UnicodeDecodeError) is caught on
                # purpose: the utf-16 decoder can raise a plain UnicodeError for
                # input without a BOM, which used to abort the fallback chain
                # before gb18030 was tried.
                continue
        raise Exception(f"Failed to read {filename} with many encodings")

    text = "\n\n".join(f"{book_name}\n\n {read_book_file(book_name)}" for book_name in books)
    chars = [ch for ch, c in Counter(text).most_common()]
    vocab_size = len(chars)
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for i, c in enumerate(chars)}

    def encode(_text: str):
        return [stoi[c] for c in _text]

    def decode(_encoded: list):
        return "".join(itos[i] for i in _encoded)

    return text, encode, decode, vocab_size
57 |
58 |
def load_jax_cached(dataset: str = "english"):
    """Load a dataset and its encoded form, caching the encoded array on disk.

    The cache lives at ``base_path / dataset / 'encoded_jax.npy'``. When it is
    missing, the text is encoded, stored as int16 and written to the cache.

    Returns:
        ``(encoded_jax, encode, decode, vocab_size)``.
    """
    text, encode, decode, vocab_size = load(dataset)
    cache_path = base_path / dataset / 'encoded_jax.npy'
    try:
        with open(cache_path, 'rb') as cache_in:
            encoded_jax = xp.load(cache_in)
    except FileNotFoundError:
        token_ids = encode(text)
        encoded_jax = xp.array(token_ids, dtype=xp.int16)
        print(encoded_jax.shape, encoded_jax.dtype)
        with open(cache_path, 'wb') as cache_out:
            xp.save(cache_out, encoded_jax)
    return encoded_jax, encode, decode, vocab_size
72 |
73 |
@dataclass
class Tokens:
    """Container for the token ids produced by a tokenizer's ``encode``."""
    # Token ids in sequence order.
    ids: list[int]
77 |
78 |
class Tokenizer(NamedTuple):
    """Bundle of a vocabulary size with encode/decode callables."""
    vocab_size: int
    encode_: Callable[[str], Tokens]
    decode_: Callable[[list[int]], str]

    def encode(self, text: str) -> Tokens:
        """Encode ``text`` with the stored ``encode_`` callable."""
        encoder = self.encode_
        return encoder(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode ``tokens`` with the stored ``decode_`` callable."""
        decoder = self.decode_
        return decoder(tokens)

    def get_vocab_size(self) -> int:
        """Return the stored vocabulary size."""
        size = self.vocab_size
        return size
92 |
93 |
def load_jax_cached_tokenizer(base_path: Path, dataset: str = "english") -> tuple[xp.ndarray, Tokenizer]:
    """Load a dataset's encoded array (cached on disk) together with a ``Tokenizer``.

    The cache lives at ``base_path / dataset / 'encoded_jax.npy'``. When it is
    missing, the text is encoded, stored as int16 and written to the cache.

    NOTE(review): ``base_path`` is only used for the cache location; the text
    itself is read by ``load(dataset)``, which does not receive it - confirm
    this is intended.
    """
    text, encode, decode, vocab_size = load(dataset)
    cache_path = base_path / dataset / 'encoded_jax.npy'
    try:
        with open(cache_path, 'rb') as cache_in:
            encoded_jax = xp.load(cache_in)
    except FileNotFoundError:
        token_ids = encode(text)
        encoded_jax = xp.array(token_ids, dtype=xp.int16)
        print(encoded_jax.shape, encoded_jax.dtype)
        with open(cache_path, 'wb') as cache_out:
            xp.save(cache_out, encoded_jax)

    def encode_to_tokens(x):
        # Wrap the raw id list in the Tokens container.
        return Tokens(encode(x))

    tokenizer = Tokenizer(vocab_size, encode_to_tokens, decode)
    return encoded_jax, tokenizer
107 |
--------------------------------------------------------------------------------
/custom_dataset_str.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 |
def load(base_path: Path, dataset: str = "english"):
    """Read every ``.txt`` file in ``base_path / dataset`` and join them into one string.

    Each file is decoded with the first encoding in a fixed list that
    succeeds; each text is prefixed with its file name.

    Raises:
        Exception: if a file cannot be decoded with any of the encodings.
    """
    path = base_path / dataset
    books = [f for f in os.listdir(path) if f.endswith('.txt')]

    def read_book_file(filename: str):
        print(f"reading {filename}...")
        encodings = ['utf-8', 'gb18030', 'gbk', 'gb2312', 'big5', 'utf-16']
        for encoding in encodings:
            try:
                with open(path / filename, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeError:
                # UnicodeError (parent of UnicodeDecodeError) is caught on
                # purpose: the utf-16 decoder can raise a plain UnicodeError for
                # input without a BOM, which would otherwise escape instead of
                # producing the "all available encodings" failure below.
                print(f"Failed to read {filename} with encoding: {encoding}")
                continue
        raise Exception(f"Failed to read {filename} with all available encodings")

    text = "\n\n".join(f"{book_name}\n\n {read_book_file(book_name)}" for book_name in books)
    return text
23 |
--------------------------------------------------------------------------------
/gpt_recon/bpe_encoder.py:
--------------------------------------------------------------------------------
1 | """Byte pair encoding utilities.
2 |
3 | Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py.
4 | """
5 | import json
6 | import os
7 | from functools import lru_cache
8 |
9 | import regex as re
10 |
11 |
@lru_cache()
def bytes_to_unicode():
    """
    Build the lookup table from byte values (0-255) to single unicode characters.

    Bytes in the three listed printable ranges map to the character with the
    same code point. Every other byte is mapped, in increasing byte order, to
    consecutive code points starting at 256, so that no byte maps to a
    whitespace/control character.
    """
    printable = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    printable_set = set(printable)
    remaining = [b for b in range(2 ** 8) if b not in printable_set]
    byte_values = printable + remaining
    code_points = printable + [2 ** 8 + offset for offset in range(len(remaining))]
    return {byte: chr(code_point) for byte, code_point in zip(byte_values, code_points)}
33 |
34 |
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).

    Each pair is two adjacent symbols ``(word[i], word[i + 1])``. A word with
    fewer than two symbols (including an empty word, which previously raised
    ``IndexError``) yields an empty set.
    """
    return set(zip(word, word[1:]))
45 |
46 |
class Encoder:
    """Byte-level BPE encoder/decoder.

    ``encoder`` maps BPE token strings to ids; ``bpe_merges`` is the ordered
    list of merge pairs (earlier entries have higher priority); ``errors`` is
    passed to ``bytes.decode`` when decoding ids back to text.
    """

    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the merge rules to ``token`` and return its symbols joined by spaces.

        Results are memoised in ``self.cache``. A token with no symbol pairs
        (a single character) is returned unchanged and is not cached.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the pair with the best (lowest) merge rank; unknown pairs rank as infinity.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # tuple.index raises ValueError when `first` does not occur
                    # again; copy the rest unchanged. (Was a bare `except:`,
                    # which would also swallow KeyboardInterrupt and the like.)
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Split ``text`` with the tokenisation pattern and return the list of BPE token ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to the printable stand-in characters.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Turn a sequence of token ids back into text (UTF-8, using ``self.errors``)."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
112 |
113 |
def get_encoder(model_name, models_dir, encoder_file="encoder.json", vocab_file="vocab.bpe"):
    """Build an ``Encoder`` from the files in ``models_dir/model_name``.

    ``encoder_file`` is a JSON mapping of token string to id. ``vocab_file``
    holds one merge pair per line; its first and last lines are skipped.
    Both files are read as UTF-8 (the JSON file previously used the platform
    default encoding).
    """
    with open(os.path.join(models_dir, model_name, encoder_file), "r", encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, model_name, vocab_file), "r", encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)
--------------------------------------------------------------------------------
/gpt_recon/clean_frame.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from __future__ import annotations
3 | from typing import NamedTuple, TypeVar, Callable, TypedDict, Literal, cast, \
4 | Optional
5 |
6 | from chex import assert_shape
7 | from jax import vmap
8 | from jax.experimental.maps import xmap
9 | from jax.lax import rsqrt
10 | from jax.numpy import mean, var, sqrt, tanh, pi
11 |
12 | from picojax.jax_utils import jit_f, WeightsTree
13 | from gpt_recon.clean_frame_utils import Arr, PartsDict, WeightConfig, WeightConfigDict, check_config, config_weights_check
14 |
15 | C = TypeVar('C')
16 | W = TypeVar('W')
17 |
18 |
19 | # TODO: unify init_param
20 | # TODO: load real GPT weights
21 |
22 |
def no_w(d: C) -> tuple[list, C]:
    """Pair ``d`` with a ``[...]`` axis spec (used for the weights argument that is not mapped)."""
    unmapped_axes = [...]
    return unmapped_axes, d
25 |
26 |
def batch_ops_x(f: Callable[[W, Arr], Arr], label: str, add_behind: bool, skip_w: bool) -> Callable[[W, Arr], Arr]:
    """Wrap ``f`` with ``xmap`` over a named axis ``label``.

    ``add_behind`` places the named axis last instead of first. With
    ``skip_w`` the first argument (the weights) gets a ``[...]`` spec instead
    of the named axis.
    """
    extension = [..., label] if add_behind else [label, ...]
    in_axes = no_w(extension) if skip_w else extension
    return xmap(f, in_axes, extension)
36 |
37 |
def batch_ops(f: Callable[[W, Arr], Arr], label: str, add_behind: bool, skip_w: bool) -> Callable[[W, Arr], Arr]:
    """Wrap ``f`` with ``vmap`` over the first (or, with ``add_behind``, last) axis.

    With ``skip_w`` the first argument (the weights) is not mapped. ``label``
    is accepted for signature parity with ``batch_ops_x`` and is not used.
    """
    mapped_axis = -1 if add_behind else 0
    in_axes = (None, mapped_axis) if skip_w else mapped_axis
    return vmap(f, in_axes, mapped_axis)
47 |
48 |
def for_all_T(f: Callable[[W, Arr], Arr]) -> Callable[[W, Arr], Arr]:
    """Map ``f`` over the leading axis of its input (label 'T'), sharing the weights."""
    return batch_ops(f, label='T', add_behind=False, skip_w=True)
51 |
52 |
def batch_fy(f: Callable[[W, Arr], Arr]) -> Callable[[W, Arr], Arr]:
    """Map ``f`` over the leading axis of its input (label 'batch'), sharing the weights."""
    return batch_ops(f, label='batch', add_behind=False, skip_w=True)
55 |
56 |
def gelu(x: Arr) -> Arr:
    """GELU activation computed with the tanh approximation."""
    inner = sqrt(2 / pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1 + tanh(inner))
59 |
60 |
class Linear:
    """Affine layer ``w.T @ x + b`` applied to a single input vector.

    ``Linear.Config`` describes the layer (sizes, save names, initialisers);
    ``Linear.Weights`` is the dict layout of its parameters.
    """

    class Weights(TypedDict):
        """Parameter dict of the layer: matrix ``w`` and bias ``b``."""
        w: Arr
        b: Arr

    @jit_f
    def f(self, w: Weights, x: Arr) -> Arr:
        """Apply the layer to ``x`` of shape ``(n_in,)``; the result has shape ``(n_out,)``."""
        assert_shape(x, (self.n_in,))
        result = w['w'].T @ x + w['b']
        assert_shape(result, (self.n_out,))
        return result

    class Config(NamedTuple):
        """Configuration of a ``Linear`` layer; ``n_in`` and ``n_out`` must be set before ``make``."""
        n_in: Optional[int] = None
        n_out: Optional[int] = None
        w_save_name: str = "w"
        b_save_name: str = "b"
        w_init: Literal["kaiming"] = "kaiming"
        b_init: Literal[0] = 0
        save_name: str = "linear"

        def make(self) -> Linear:
            """Validate that no field is None (via ``check_config``) and build the layer."""
            check_config(self)
            return Linear(self)

        @property
        def weights(self) -> WeightConfigDict:
            """Weight configs: ``w`` of shape ``(n_in, n_out)`` and ``b`` of shape ``(n_out,)``."""
            assert self.n_in is not None
            assert self.n_out is not None
            return dict(
                w=WeightConfig(
                    save_name=self.w_save_name,
                    shape=(self.n_in, self.n_out),
                    init=self.w_init
                ),
                b=WeightConfig(
                    save_name=self.b_save_name,
                    shape=(self.n_out,),
                    init=self.b_init)
            )

        @property
        def parts(self) -> PartsDict:
            """This layer has no sub-modules."""
            return {}

        def fill(self) -> Linear.Config:
            """Nothing to propagate to sub-configs; returns ``self``."""
            return self

        def weights_check(self, w: WeightsTree) -> Linear.Weights:
            """Check ``w`` against this config with ``config_weights_check`` and return it typed as ``Linear.Weights``."""
            return cast(Linear.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        """Store ``config`` and its (required) ``n_in`` / ``n_out``."""
        self.config = config
        assert config.n_in is not None
        assert config.n_out is not None
        self.n_in = config.n_in
        self.n_out = config.n_out
118 |
119 |
class LN:
    """Layer normalisation followed by an elementwise scale ``w`` and shift ``b``."""

    class Weights(TypedDict):
        """Parameter dict of the layer: scale ``w`` and shift ``b``."""
        w: Arr
        b: Arr

    # x: self.shape
    @jit_f
    def f(self, w: LN.Weights, x: Arr) -> Arr:
        """Normalise ``x`` and apply ``w`` and ``b``.

        Mean and variance are reduced over ``config.non_norm_dims`` (kept as
        size-1 axes), and ``eps`` is added to the variance before ``rsqrt``.
        """
        o = x - mean(x, axis=self.config.non_norm_dims, keepdims=True)
        i = o * rsqrt(var(x, axis=self.config.non_norm_dims, keepdims=True) + self.eps)
        return w['w'] * i + w['b']

    class Config(NamedTuple):
        """Configuration of an ``LN`` layer; ``eps``, ``norm_dims`` and ``x_shape`` must be set before ``make``."""
        eps: Optional[float] = None
        # all the other dimensions are normalized
        norm_dims: Optional[tuple[int, ...]] = None
        x_shape: Optional[tuple[Optional[int], ...]] = None
        w_save_name: str = "w"
        b_save_name: str = "b"
        # NOTE(review): the scale `w` defaults to a 0 initialiser, same as `b` - confirm this is intended
        w_init: Literal[0] = 0
        b_init: Literal[0] = 0
        save_name: str = "ln"
        norm_dim_name: str = "norm_dim"

        def make(self) -> LN:
            """Validate that no field is None (via ``check_config``) and build the layer."""
            check_config(self)
            return LN(self)

        @property
        def non_norm_dims(self) -> tuple[int, ...]:
            """Axis indices of ``x_shape`` that are NOT listed in ``norm_dims``."""
            assert self.norm_dims is not None
            assert self.x_shape is not None
            return tuple(i for i in range(len(self.x_shape)) if i not in self.norm_dims)

        @property
        def weights(self) -> WeightConfigDict:
            """Weight configs for ``w`` and ``b``.

            Both have the shape formed by the ``x_shape`` entries at the axes
            not listed in ``norm_dims``.
            """
            assert self.norm_dims is not None, 'norm_dims must be specified'
            assert self.x_shape is not None, 'x_shape must be specified'
            assert self.eps is not None, 'eps must be specified'
            non_norm_shape = tuple(self.x_shape[i] for i in range(len(self.x_shape)) if i not in self.norm_dims)
            non_norm_shape = cast(tuple[int, ...], non_norm_shape)
            return dict(
                w=WeightConfig(
                    save_name=self.w_save_name,
                    shape=non_norm_shape,
                    init=self.w_init
                ),
                b=WeightConfig(
                    save_name=self.b_save_name,
                    shape=non_norm_shape,
                    init=self.b_init
                )
            )

        @property
        def parts(self) -> PartsDict:
            """This layer has no sub-modules."""
            return {}

        def fill(self) -> LN.Config:
            """Nothing to propagate to sub-configs; returns ``self``."""
            return self

        def weights_check(self, w: WeightsTree) -> LN.Weights:
            """Check ``w`` against this config with ``config_weights_check`` and return it typed as ``LN.Weights``."""
            return cast(LN.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        """Store ``config`` and its (required) ``eps``."""
        assert config.eps is not None
        self.config = config
        self.eps = config.eps
188 |
--------------------------------------------------------------------------------
/gpt_recon/clean_frame_utils.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from __future__ import annotations
3 |
4 | from abc import abstractmethod
5 | from typing import NamedTuple, TypeVar, List, Union, Protocol, Mapping, cast, Callable
6 |
7 | import jax.numpy as jnp
8 | from typing_extensions import runtime
9 |
10 | from picojax.jax_utils import Arr, WeightsTree
11 | from picojax.random_utils import kaiming_init, embedding_init, normal_init, dropout_gen, SafeKey, ArrayGen
12 |
13 |
class WeightConfig(NamedTuple):
    """Description of a single weight: its save name, shape and initialiser."""
    # from in to out
    save_name: str
    shape: tuple[int, ...]
    init: Union[ArrayGen, int, float] = "kaiming"
    scale: float = 1

    def make(self, rng_key: SafeKey) -> Arr:
        """Create the weight array.

        A numeric ``init`` fills the array with that constant. Otherwise
        ``init`` names one of the generators ('kaiming', 'embedding', 'normal',
        'dropout'), which is called with ``(rng_key, scale, shape)``.

        Raises:
            NotImplementedError: for any other ``init`` value.
        """
        spec = self.init
        if isinstance(spec, (int, float)):
            return jnp.full(self.shape, float(spec))
        if spec == 'kaiming':
            generator = kaiming_init
        elif spec == 'embedding':
            generator = embedding_init
        elif spec == 'normal':
            generator = normal_init
        elif spec == 'dropout':
            generator = dropout_gen
        else:
            raise NotImplementedError("unsupported init type")
        return generator(rng_key, self.scale, self.shape)
34 |
35 |
36 | W_co = TypeVar('W_co', covariant=True)
37 |
38 |
@runtime
class ModuleConfig(Protocol[W_co]):
    """Structural interface every module configuration implements.

    ``W_co`` is the type of the weights structure returned by ``weights_check``.
    """

    @property
    @abstractmethod
    def save_name(self) -> str:
        """Name of this module; used as a path component when loading weights and in error messages."""
        ...

    @property
    @abstractmethod
    def weights(self) -> WeightConfigDict:
        """Configs of the weights owned directly by this module."""
        ...

    @property
    @abstractmethod
    def parts(self) -> PartsDict:
        """Sub-module configs (single config or list of configs) keyed by name."""
        ...

    def fill(self) -> ModuleConfig:
        """Return a config with derived fields filled in (implementations in this file set sub-config sizes or return self)."""
        ...

    def make(self) -> Module:
        """Build the module described by this config."""
        ...

    def weights_check(self, weights: WeightsTree) -> W_co:
        """Validate ``weights`` against this config and return them typed as ``W_co``."""
        ...
64 |
65 |
66 | W = TypeVar('W', contravariant=True)
67 |
68 |
@runtime
class Module(Protocol[W]):
    """Structural interface of a built module: a function of weights ``w`` and input ``x``."""

    @abstractmethod
    def f(self, w: W, x: Arr) -> Arr:
        """Apply the module to ``x`` using weights ``w``."""
        ...
75 |
76 |
77 | T = TypeVar('T')
78 |
79 |
def check_config(config: NamedTuple) -> None:
    """Ensure no field of ``config`` is None.

    Raises:
        ValueError: naming the first field (in declaration order) whose value is None.
    """
    unset = [field for field, value in config._asdict().items() if value is None]
    if unset:
        k = unset[0]
        raise ValueError(f"Missing config '{k}' in {config.__class__}")
84 |
85 |
86 | WeightConfigTree = Mapping[str, Union[WeightConfig, "WeightConfigTree", list["WeightConfigTree"]]]
87 |
88 | WeightConfigDict = dict[str, WeightConfig]
89 | PartsDict = dict[str, Union[ModuleConfig, List[ModuleConfig]]]
90 |
91 |
def config_weights_check(config: ModuleConfig, weights: WeightsTree) -> WeightsTree:
    """Recursively validate ``weights`` against ``config`` and return the checked tree.

    For every entry of ``config.parts`` the matching entry of ``weights`` is
    checked recursively (a single sub-module expects a dict, a list of
    sub-modules expects a list of the same length). The module's own weights
    are then checked with ``weights_check`` and merged into the result.

    Raises:
        ValueError: wrapping the underlying ValueError, with the module's
            ``save_name`` in the message, when anything is missing or mismatched.
    """
    try:
        checked_w: WeightsTree = {}
        assert isinstance(weights, dict), f"weights for {config.save_name} module is not a dict: {type(weights)}"
        for name, part_config in config.parts.items():
            if name not in weights:
                raise ValueError(f"Missing weight {name}")
            w = weights[name]
            if isinstance(part_config, ModuleConfig):
                if not isinstance(w, dict):
                    raise ValueError(f"weights for {config.save_name} module is not a dict: {type(w)}")
                checked_w[name] = config_weights_check(part_config, w)
            else:
                err_msg = f"Config {config.save_name} should be a ModuleConfig or a non-empty list of ModuleConfigs, got {part_config}"
                assert isinstance(part_config, list), err_msg
                assert len(part_config) > 0, err_msg
                assert isinstance(part_config[0], ModuleConfig), err_msg
                # Verify the container type before using len(): previously a
                # non-list value was reported as a length mismatch, or raised an
                # unwrapped TypeError when it had no length.
                if not isinstance(w, list):
                    raise ValueError(f"weights for {name} module is not a list: {type(w)}")
                if len(part_config) != len(w):
                    raise ValueError(
                        f"Wrong number of weights in list {name}: weight {len(w)} != config {len(part_config)}")
                checked_w[name] = [config_weights_check(part, cast(WeightsTree, w[i]))
                                   for i, part in enumerate(part_config)]
        checked_w.update(weights_check(config.weights, weights))
        return checked_w
    except ValueError as e:
        raise ValueError(f"Config Weights check failed for {config.save_name}") from e
121 |
122 |
def weights_check(weights_config: WeightConfigDict, weights: WeightsTree) -> WeightsTree:
    """Check every weight named in ``weights_config`` with ``weight_check`` and return the checked dict.

    Raises:
        ValueError: wrapping the underlying ValueError when a weight is
            missing, not an array, or has a mismatching shape.
    """
    try:
        checked: WeightsTree = {}
        for weight_name in weights_config:
            single_config = weights_config[weight_name]
            assert isinstance(single_config, WeightConfig)
            checked[weight_name] = weight_check(single_config, weight_name, weights)
        return checked
    except ValueError as e:
        raise ValueError(f"Weights check failed for {weights_config}") from e
132 |
133 |
def squeeze_side(s: tuple[int, ...]) -> tuple[int, ...]:
    """Strip size-1 dimensions from both ends of a shape, always keeping at least one dimension.

    Leading 1s are removed first, then trailing 1s. Interior 1s are kept.
    """
    assert len(s) > 0, "Empty shape cannot be squeezed"
    trimmed = s
    while len(trimmed) > 1:
        if trimmed[0] == 1:
            trimmed = trimmed[1:]
        elif trimmed[-1] == 1:
            trimmed = trimmed[:-1]
        else:
            break
    return trimmed
143 |
144 |
def weight_check(config: WeightConfig, name: str, weights: WeightsTree) -> Arr:
    """Fetch ``weights[name]``, verify it matches ``config`` and return it reshaped to ``config.shape``.

    Shapes are compared after ``squeeze_side``, so size-1 dimensions at either
    end are ignored.

    Raises:
        ValueError: if the weight is missing, is not an array, or the shapes differ.
    """
    if name not in weights:
        raise ValueError(f"Missing weight {name}")
    w = weights[name]
    if not isinstance(w, Arr):
        raise ValueError(f"weight {name} is not an array: {type(w)}")
    shapes_agree = squeeze_side(w.shape) == squeeze_side(config.shape)
    if shapes_agree:
        return w.reshape(config.shape)
    raise ValueError(f"Shape for weight {name} does not match: {w.shape} != {config.shape}")
155 |
156 |
def load_config(config: ModuleConfig, weights_getter: Callable[[str], Arr], prefix: str = "") -> WeightsTree:
    """Recursively fetch all weights of ``config`` through ``weights_getter``.

    Weight names are dotted paths: ``prefix`` plus the module's ``save_name``
    (skipped when empty) and the weight's ``save_name``. Members of a list of
    sub-modules additionally get their list index as a path component.
    """
    name_prefix = prefix if len(config.save_name) == 0 else f"{prefix}{config.save_name}."
    loaded: WeightsTree = {}
    for name, part_config in config.parts.items():
        if isinstance(part_config, ModuleConfig):
            loaded[name] = load_config(part_config, weights_getter, name_prefix)
            continue
        err_msg = f"Config {config.save_name} should be a ModuleConfig or a non-empty list of ModuleConfigs"
        assert isinstance(part_config, list), err_msg
        assert len(part_config) > 0, err_msg
        assert isinstance(part_config[0], ModuleConfig), err_msg
        loaded[name] = [load_config(part, weights_getter, f"{name_prefix}{i}.")
                        for i, part in enumerate(part_config)]
    own_weights = load_weights(config.weights, weights_getter, name_prefix)
    loaded.update(own_weights)
    return loaded
175 |
176 |
def load_weights(weight_configs: WeightConfigDict, weights_getter: Callable[[str], Arr],
                 prefix: str = "") -> WeightsTree:
    """Fetch each configured weight via ``weights_getter(prefix + save_name)``."""
    fetched: WeightsTree = {}
    for name in weight_configs:
        weight_config = weight_configs[name]
        assert isinstance(weight_config,
                          WeightConfig), f"weight_config {name} is not a WeightConfig, but {type(weight_config)}"
        full_name = f"{prefix}{weight_config.save_name}"
        fetched[name] = weights_getter(full_name)
    return fetched
185 |
186 |
def init_weights(weights_config: WeightConfigDict, rng_key: SafeKey) -> WeightsTree:
    """Create every configured weight, giving each its own split of ``rng_key``."""
    per_weight_keys = rng_key.split(len(weights_config))
    created: WeightsTree = {}
    for (name, weight_config), sub_key in zip(weights_config.items(), per_weight_keys):
        assert isinstance(weight_config, WeightConfig), f"Config {name} should be a WeightConfig, got {weight_config}"
        created[name] = weight_config.make(sub_key)
    return created
194 |
195 |
def init_weight_module(module: ModuleConfig, rng_key: SafeKey) -> WeightsTree:
    """Initialise all weights of ``module``: its sub-modules first, then its own weights.

    ``rng_key`` is split in two, one key for the parts and one for the
    module's own weights.

    Raises:
        Exception: wrapping any failure, with the module's ``save_name`` in the message.
    """
    try:
        parts_key, weights_key = rng_key.split(2)
        tree = init_weight_parts(module.parts, parts_key)
        own_weights = init_weights(module.weights, weights_key)
        tree.update(own_weights)
        return tree
    except Exception as e:
        raise Exception(f"Failed to initialize module {module.save_name}.") from e
204 |
205 |
def init_weight_parts(parts: PartsDict, rng_key: SafeKey) -> WeightsTree:
    """Initialise the weights of every sub-module in ``parts``.

    ``rng_key`` is split once per part; a part that is a list of modules
    splits its key again, once per list member.
    """
    tree: WeightsTree = {}
    part_keys = rng_key.split(len(parts))
    for (name, part_config), key in zip(parts.items(), part_keys):
        if isinstance(part_config, ModuleConfig):
            tree[name] = init_weight_module(part_config, key)
            continue
        err_msg = f"Config {name} should be a ModuleConfig or a non-empty list of ModuleConfigs, got {part_config}"
        assert isinstance(part_config, list), err_msg
        assert len(part_config) > 0, err_msg
        assert isinstance(part_config[0], ModuleConfig), err_msg
        member_keys = key.split(len(part_config))
        tree[name] = [init_weight_module(member, member_key)
                      for member_key, member in zip(member_keys, part_config)]
    return tree
221 |
--------------------------------------------------------------------------------
/gpt_recon/gpt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from functools import cache
5 | from typing import NamedTuple, Optional, cast, TypedDict, List, Literal, Union
6 |
7 | from chex import assert_shape
8 | from einops import rearrange
9 | from jax import numpy as jnp, vmap
10 | from jax.nn import softmax
11 |
12 | from picojax.jax_utils import jit_f, Arr, WeightsTree
13 | from gpt_recon.clean_frame import Linear, for_all_T, gelu, LN
14 | from gpt_recon.clean_frame_utils import check_config, WeightConfigDict, PartsDict, config_weights_check, \
15 | WeightConfig
16 |
17 |
class GptMha:
    """Causal multi-head self-attention built from two ``Linear`` layers.

    ``QKV_linear`` projects the input to queries, keys and values for all
    heads at once; ``linear`` projects the concatenated head outputs back to
    ``n_channels``.
    """

    class Weights(TypedDict):
        """Weights of the two ``Linear`` sub-layers."""
        QKV_linear: Linear.Weights
        linear: Linear.Weights

    @jit_f
    def causal_dot_attention(self, q: Arr, k: Arr, v: Arr) -> Arr:
        """Scaled dot-product attention for one head with the causal mask added to the scores.

        ``q``, ``k`` and ``v`` each have shape ``(T, dim_heads)``; so does the result.
        """
        assert_shape([q, k, v], (self.T, self.dim_heads))
        mask = self.get_mask(q.shape[0])
        result = softmax((q @ k.T) / self.scale + mask) @ v
        assert_shape(result, (self.T, self.dim_heads))
        return result

    @jit_f
    def f(self, w: GptMha.Weights, x: Arr) -> Arr:
        """Apply attention to ``x`` of shape ``(T, n_channels)``; the result has the same shape."""
        assert_shape(x, (self.T, self.n_channels))

        # Project to QKV and split into three (n_heads, T, dim_heads) arrays.
        q, k, v = rearrange(self.QKV_linearf(w['QKV_linear'], x),
                            'T (qkv n_heads dim_heads) -> qkv n_heads T dim_heads',
                            qkv=3,
                            n_heads=self.n_heads,
                            dim_heads=self.dim_heads)
        # extension_shape = ['n_head', ...]
        # attended = xmap(self.causal_dot_attention, [extension_shape] * 3, extension_shape)(q, k, v)
        attended = vmap(self.causal_dot_attention, (0, 0, 0), 0)(q, k, v)
        assert_shape(attended, (self.n_heads, self.T, self.dim_heads))
        # Join the per-head outputs along the last axis.
        concatenated = jnp.concatenate(attended, -1)
        assert_shape(concatenated, (self.T, self.n_channels))

        result = self.linearf(w['linear'], concatenated)

        assert_shape(result, (self.T, self.n_channels))
        return result

    def f_debug(self, w: GptMha.Weights, x: Arr) -> dict[str, Arr]:
        """Non-jitted variant of ``f`` that also returns the attention maps.

        Returns a dict with ``attn`` (softmaxed maps, stacked over heads),
        ``attn_raw`` (unscaled, unmasked ``q @ k.T``, stacked over heads) and
        ``x_after_mha`` (the layer output).
        """
        assert_shape(x, (self.T, self.n_channels))

        qs, ks, vs = rearrange(self.QKV_linearf(w['QKV_linear'], x),
                               'T (qkv n_heads dim_heads) -> qkv n_heads T dim_heads',
                               qkv=3,
                               n_heads=self.n_heads,
                               dim_heads=self.dim_heads)
        mask = self.get_mask(qs.shape[1])
        attended_list = []
        attn_maps = []
        attn_maps_raw = []
        # One iteration per head.
        for q, k, v in zip(qs, ks, vs):
            attn_map_raw = q @ k.T
            attn_maps_raw.append(attn_map_raw)
            attn_map = softmax(attn_map_raw / self.scale + mask)
            attn_maps.append(attn_map)
            attended_list.append(attn_map @ v)
        attended = jnp.stack(attended_list)
        assert_shape(attended, (self.n_heads, self.T, self.dim_heads))
        concatenated = jnp.concatenate(attended, -1)
        assert_shape(concatenated, (self.T, self.n_channels))

        result = self.linearf(w['linear'], concatenated)

        assert_shape(result, (self.T, self.n_channels))
        return dict(
            attn=jnp.stack(attn_maps),
            attn_raw=jnp.stack(attn_maps_raw),
            x_after_mha=result)

    class Config(NamedTuple):
        """Configuration of ``GptMha``; ``n_channels``, ``n_heads`` and ``n_seq`` must be set."""
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        # Value multiplied into the masked (future) positions of the attention scores.
        inf_mask: float = -1e10
        linear: Linear.Config = Linear.Config()
        QKV_linear: Linear.Config = Linear.Config()

        save_name: str = "mha"

        @property
        @cache
        def T(self) -> Optional[int]:
            """Sequence length used in shape assertions; None when ``n_seq`` is 'dynamic'."""
            if self.n_seq == 'dynamic':
                return None
            else:
                return self.n_seq

        def fill(self) -> GptMha.Config:
            """Return a copy whose sub-layer configs have their sizes set from ``n_channels``.

            ``linear`` maps n_channels -> n_channels and ``QKV_linear`` maps
            n_channels -> 3 * n_channels. The copy is validated with ``check_config``.
            """
            assert self.n_channels is not None, 'n_channels must be set'
            new = self._replace(linear=self.linear._replace(n_in=self.n_channels, n_out=self.n_channels),
                                QKV_linear=self.QKV_linear._replace(n_in=self.n_channels, n_out=3 * self.n_channels))
            check_config(new)
            return new

        def make(self) -> GptMha:
            """Build the module from the filled config."""
            return GptMha(self.fill())

        @property
        def dim_heads(self) -> int:
            """Channels per head: ``n_channels // n_heads``."""
            assert self.n_channels is not None
            assert self.n_heads is not None
            return self.n_channels // self.n_heads

        @property
        def weights(self) -> WeightConfigDict:
            """This module owns no weights directly; they live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """The two filled ``Linear`` sub-configs, keyed ``linear`` and ``QKV_linear``."""
            filled = self.fill()
            return dict(
                linear=filled.linear,
                QKV_linear=filled.QKV_linear
            )

        def weights_check(self, w: WeightsTree) -> GptMha.Weights:
            """Check ``w`` against this config with ``config_weights_check`` and return it typed as ``GptMha.Weights``."""
            return cast(GptMha.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        """Build the sub-layers and precompute the score scale ``sqrt(dim_heads)``."""
        assert config.n_channels is not None
        assert config.n_heads is not None
        assert config.dim_heads is not None
        self.n_channels = config.n_channels
        self.n_heads = config.n_heads
        self.T = config.T
        self.dim_heads = config.dim_heads
        self.config = config

        self.linear = config.linear.make()
        self.QKV_linear = config.QKV_linear.make()
        self.scale = math.sqrt(self.dim_heads)
        # Versions of the linear layers mapped over the leading (T) axis.
        self.linearf = for_all_T(self.linear.f)
        self.QKV_linearf = for_all_T(self.QKV_linear.f)
        assert self.n_channels % self.n_heads == 0, 'n_channels must be divisible by n_heads'

    def get_mask(self, t: int) -> Arr:
        """Return a ``(t, t)`` additive mask: 0 on and below the diagonal, ``inf_mask`` above it."""
        return (1 - jnp.tri(t)) * self.config.inf_mask
151 |
152 |
class GptFfn:
    """Feed-forward block: ``linear1`` -> GELU -> ``linear2`` on a single channel vector."""

    class Weights(TypedDict):
        """Weights of the two ``Linear`` sub-layers."""
        linear1: Linear.Weights
        linear2: Linear.Weights

    @jit_f
    def f(self, w: GptFfn.Weights, x: Arr) -> Arr:
        """Apply the block to ``x`` of shape ``(n_channels,)``; the result has the same shape."""
        assert_shape(x, (self.n_channels,))
        result = self.linear2.f(w['linear2'], gelu(self.linear1.f(w['linear1'], x)))
        assert_shape(result, (self.n_channels,))
        return result

    class Config(NamedTuple):
        """Configuration of ``GptFfn``; ``n_channels`` and ``n_seq`` must be set."""
        n_channels: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        linear1: Linear.Config = Linear.Config()
        linear2: Linear.Config = Linear.Config()
        save_name: str = "ffn"

        @property
        @cache
        def T(self) -> Optional[int]:
            """Sequence length; None when ``n_seq`` is 'dynamic'."""
            if self.n_seq == 'dynamic':
                return None
            else:
                return self.n_seq

        def fill(self) -> GptFfn.Config:
            """Return a copy whose sub-layer configs have their sizes set from ``n_channels``.

            ``linear1`` maps n_channels -> 4 * n_channels and ``linear2`` maps
            4 * n_channels -> n_channels. The copy is validated with ``check_config``.
            """
            assert self.n_channels is not None
            new = self._replace(
                linear1=self.linear1._replace(n_in=self.n_channels, n_out=self.n_channels * 4),
                linear2=self.linear2._replace(n_in=self.n_channels * 4, n_out=self.n_channels))
            check_config(new)
            return new

        def make(self) -> GptFfn:
            """Build the module from the filled config."""
            return GptFfn(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """This module owns no weights directly; they live in its parts."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """The two filled ``Linear`` sub-configs, keyed ``linear1`` and ``linear2``."""
            filled = self.fill()
            return dict(
                linear1=filled.linear1,
                linear2=filled.linear2
            )

        def weights_check(self, w: WeightsTree) -> GptFfn.Weights:
            """Check ``w`` against this config with ``config_weights_check`` and return it typed as ``GptFfn.Weights``."""
            return cast(GptFfn.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        """Store ``n_channels`` and build the two linear sub-layers."""
        self.n_channels = config.n_channels
        self.linear1 = config.linear1.make()
        self.linear2 = config.linear2.make()
210 |
211 |
class GptBlock:
    """One pre-norm block: ``x + mha(ln1(x))`` followed by ``x + ffn(ln2(x))``."""

    class Weights(TypedDict):
        mha: GptMha.Weights
        ffn: GptFfn.Weights
        ln1: LN.Weights
        ln2: LN.Weights

    @jit_f
    def f(self, w: GptBlock.Weights, x: Arr) -> Arr:
        """Run the block on ``x`` of shape ``(T, n_channels)`` and return the same shape."""
        assert_shape(x, (self.T, self.n_channels))
        attn_out = self.mha.f(w['mha'], self.ln1f(w['ln1'], x))
        x += attn_out
        ffn_out = self.ffnf(w['ffn'], self.ln2f(w['ln2'], x))
        x += ffn_out
        assert_shape(x, (self.T, self.n_channels))
        return x

    def f_debug(self, w: GptBlock.Weights, x: Arr) -> dict[str, Arr]:
        """Same computation as ``f`` but returns the intermediate tensors.

        The dict holds ``x0``, ``x_before_mha``, everything returned by
        ``mha.f_debug``, ``x_before_ffn``, ``x_after_ffn`` and the block
        output ``x``, inserted in that order.
        """
        trace: dict[str, Arr] = {'x0': x}
        normed_for_attn = self.ln1f(w['ln1'], x)
        trace['x_before_mha'] = normed_for_attn
        attn_trace = self.mha.f_debug(w['mha'], normed_for_attn)
        trace.update(attn_trace)
        # first residual connection
        after_attn = x + attn_trace['x_after_mha']
        normed_for_ffn = self.ln2f(w['ln2'], after_attn)
        trace['x_before_ffn'] = normed_for_ffn
        ffn_out = self.ffnf(w['ffn'], normed_for_ffn)
        trace['x_after_ffn'] = ffn_out
        # second residual connection
        trace['x'] = after_attn + ffn_out
        return trace

    class Config(NamedTuple):
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        mha: GptMha.Config = GptMha.Config()
        ffn: GptFfn.Config = GptFfn.Config()
        ln1: LN.Config = LN.Config()
        ln2: LN.Config = LN.Config()
        save_name: str = "gpt_block"

        @property
        def T(self) -> Optional[int]:
            """Static sequence length, or ``None`` when ``n_seq`` is ``'dynamic'``."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        @property
        def x_shape(self) -> tuple[Optional[int], ...]:
            """Shape ``(T, n_channels)`` of the block's input and output."""
            assert self.n_channels is not None
            return self.T, self.n_channels

        def fill(self) -> GptBlock.Config:
            """Return a copy whose sub-configs inherit this block's sizes and ``eps``."""
            mha_cfg = self.mha._replace(n_channels=self.n_channels, n_seq=self.n_seq,
                                        n_heads=self.n_heads).fill()
            ffn_cfg = self.ffn._replace(n_channels=self.n_channels, n_seq=self.n_seq).fill()
            ln1_cfg = self.ln1._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape)
            ln2_cfg = self.ln2._replace(eps=self.eps, norm_dims=(0,), x_shape=self.x_shape)
            filled = self._replace(mha=mha_cfg, ffn=ffn_cfg, ln1=ln1_cfg, ln2=ln2_cfg)
            check_config(filled)
            return filled

        def make(self) -> GptBlock:
            """Build a ``GptBlock`` from the filled config."""
            return GptBlock(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """The block owns no weights directly."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """Filled sub-module configs keyed by their ``Weights`` field name."""
            filled = self.fill()
            return {
                'mha': filled.mha,
                'ffn': filled.ffn,
                'ln1': filled.ln1,
                'ln2': filled.ln2,
            }

        def weights_check(self, w: WeightsTree) -> GptBlock.Weights:
            """Run ``config_weights_check`` on ``w`` and cast the result to ``GptBlock.Weights``."""
            checked = config_weights_check(self, w)
            return cast(GptBlock.Weights, checked)

    def __init__(self, config: Config):
        self.T = config.T
        self.n_channels = config.n_channels
        self.mha = config.mha.make()
        self.ffn = config.ffn.make()
        self.ln1 = config.ln1.make()
        self.ln2 = config.ln2.make()

        # The FFN is defined per position, so it is lifted over the T axis;
        # the layer norms are called on the whole (T, n_channels) input.
        self.ffnf = for_all_T(self.ffn.f)
        self.ln1f = self.ln1.f
        self.ln2f = self.ln2.f
309 |
310 |
class GptDecoder:
    """A stack of ``n_blocks`` ``GptBlock`` modules applied one after another."""

    class Weights(TypedDict):
        # One weight tree per block, in application order.
        blocks: List[GptBlock.Weights]

    @jit_f
    def f(self, w: GptDecoder.Weights, x: Arr) -> Arr:
        """Feed ``x`` of shape ``(T, n_channels)`` through every block in order.

        :raises AssertionError: if the number of weight trees differs from the
            number of blocks (``zip`` would otherwise silently skip layers).
        """
        assert_shape(x, (self.T, self.n_channels))
        assert len(w['blocks']) == len(self.blocks), \
            f"expected weights for {len(self.blocks)} blocks, got {len(w['blocks'])}"
        for blk, blk_w in zip(self.blocks, w['blocks']):
            x = blk.f(blk_w, x)
        assert_shape(x, (self.T, self.n_channels))
        return x

    def f_debug(self, w: GptDecoder.Weights, x: Arr) -> tuple[Arr, dict[str, Arr]]:
        """Like ``f`` but also returns the per-block intermediate tensors.

        :return: ``(output, vecs)`` where each entry of ``vecs`` stacks the
            tensor of that name from every block along a new leading axis.
            ``vecs`` is empty when the decoder has no blocks.
        :raises AssertionError: if the number of weight trees differs from the
            number of blocks.
        """
        assert_shape(x, (self.T, self.n_channels))
        assert len(w['blocks']) == len(self.blocks), \
            f"expected weights for {len(self.blocks)} blocks, got {len(w['blocks'])}"
        view_vec = []
        for blk, blk_w in zip(self.blocks, w['blocks']):
            return_dict = blk.f_debug(blk_w, x)
            view_vec.append(return_dict)
            x = return_dict['x']
        assert_shape(x, (self.T, self.n_channels))
        if not view_vec:
            # No blocks ran, so there is no first entry to take the keys from.
            return x, {}
        vecs = {key: jnp.stack([dict_item[key] for dict_item in view_vec]) for key in view_vec[0].keys()}
        return x, vecs

    class Config(NamedTuple):
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        blocks: GptBlock.Config = GptBlock.Config()

        save_name: str = 'gpt_decoder'

        @property
        def T(self) -> Optional[int]:
            """Static sequence length, or ``None`` when ``n_seq`` is ``'dynamic'``.

            Deliberately a plain property, like ``GptBlock.Config.T`` and
            ``Gpt.Config.T``: memoising it would hash the whole nested config
            on every access and keep every config alive in a global cache.
            """
            if self.n_seq == 'dynamic':
                return None
            return self.n_seq

        def fill(self) -> GptDecoder.Config:
            """Return a copy whose block config inherits this decoder's sizes and ``eps``."""
            new = self._replace(blocks=self.blocks._replace(eps=self.eps, n_channels=self.n_channels,
                                                            n_heads=self.n_heads, n_seq=self.n_seq).fill())
            check_config(new)
            return new

        def make(self) -> GptDecoder:
            """Build a ``GptDecoder`` from the filled config."""
            return GptDecoder(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """The decoder owns no weights directly."""
            return {}

        @property
        def parts(self) -> PartsDict:
            """The filled block config repeated ``n_blocks`` times under the key ``blocks``."""
            filled = self.fill()
            assert filled.blocks is not None
            assert filled.n_blocks is not None
            return dict(
                blocks=[filled.blocks] * filled.n_blocks
            )

        def weights_check(self, w: WeightsTree) -> GptDecoder.Weights:
            """Run ``config_weights_check`` on ``w`` and cast the result to ``GptDecoder.Weights``."""
            return cast(GptDecoder.Weights, config_weights_check(self, w))

    def __init__(self, config: Config):
        assert config.n_blocks is not None
        self.T = config.T
        self.n_channels = config.n_channels
        self.blocks = [config.blocks.make() for _ in range(config.n_blocks)]
382 |
383 |
class Gpt:
    """Whole model: token + positional embedding, decoder, final layer norm,
    and logits obtained by multiplying with the transposed token embedding."""

    class Weights(TypedDict):
        token_embedding: Arr
        positional_encoding: Arr
        decoder: GptDecoder.Weights
        ln: LN.Weights

    @jit_f
    def f(self, w: Gpt.Weights, x: Arr) -> Arr:
        """Map token ids of shape ``(T,)`` to logits of shape ``(T, n_tokens)``."""
        assert_shape(x, (self.T,))
        # With a static T the slice length is a constant; only the dynamic
        # configuration reads the length off the input itself.
        n_positions = x.shape[0] if self.T is None else self.T
        embedded = w['token_embedding'][x, :] + w['positional_encoding'][:n_positions, :]
        assert_shape(embedded, (self.T, self.n_channels))
        decoded = self.decoder.f(w['decoder'], embedded)
        result = self.ln.f(w['ln'], decoded)
        assert_shape(result, (self.T, self.n_channels))
        return result @ w['token_embedding'].T

    def f_debug(self, w: Gpt.Weights, x: Arr) -> tuple[Arr, dict[str, Arr]]:
        """Like ``f`` but also returns the intermediates collected by ``decoder.f_debug``."""
        assert_shape(x, (self.T,))
        token_part = w['token_embedding'][x, :]
        position_part = w['positional_encoding'][:x.shape[0], :]
        embedded = token_part + position_part
        assert_shape(embedded, (self.T, self.n_channels))
        decoder_out, save_vecs = self.decoder.f_debug(w['decoder'], embedded)
        result = self.ln.f(w['ln'], decoder_out)
        assert_shape(result, (self.T, self.n_channels))
        logits = result @ w['token_embedding'].T
        return logits, save_vecs

    class Config(NamedTuple):
        eps: Optional[float] = None
        n_channels: Optional[int] = None
        n_heads: Optional[int] = None
        n_seq: Optional[Union[int, Literal['dynamic']]] = None
        n_blocks: Optional[int] = None
        n_tokens: Optional[int] = None
        max_seq_len: Optional[int] = None

        token_embedding_save_name: str = 'te'
        token_embedding_init: Literal['normal'] = 'normal'
        token_embedding_scale: float = 0.02

        decoder: GptDecoder.Config = GptDecoder.Config()
        ln: LN.Config = LN.Config()

        positional_embedding_save_name: str = 'pe'

        save_name: str = 'gpt'

        @property
        def T(self) -> Optional[int]:
            """Static sequence length, or ``None`` when ``n_seq`` is ``'dynamic'``."""
            return None if self.n_seq == 'dynamic' else self.n_seq

        def fill(self) -> Gpt.Config:
            """Return a copy whose decoder and final LN configs inherit the top-level sizes."""
            assert self.n_channels is not None, 'n_channels must be specified'
            decoder_cfg = self.decoder._replace(eps=self.eps, n_channels=self.n_channels,
                                                n_heads=self.n_heads, n_seq=self.n_seq,
                                                n_blocks=self.n_blocks).fill()
            ln_cfg = self.ln._replace(eps=self.eps, norm_dims=(0,), x_shape=(self.T, self.n_channels))
            filled = self._replace(decoder=decoder_cfg, ln=ln_cfg)

            check_config(filled)
            return filled

        def make(self) -> Gpt:
            """Build a ``Gpt`` from the filled config."""
            return Gpt(self.fill())

        @property
        def weights(self) -> WeightConfigDict:
            """Configs of the two embedding tables owned directly by the model."""
            filled = self.fill()
            assert filled.max_seq_len is not None
            assert filled.n_tokens is not None
            assert filled.n_channels is not None
            token_embedding = WeightConfig(save_name=filled.token_embedding_save_name,
                                           init=filled.token_embedding_init,
                                           shape=(filled.n_tokens, filled.n_channels),
                                           scale=filled.token_embedding_scale)
            positional_encoding = WeightConfig(save_name=filled.positional_embedding_save_name,
                                               shape=(filled.max_seq_len, filled.n_channels))
            return {
                'token_embedding': token_embedding,
                'positional_encoding': positional_encoding,
            }

        @property
        def parts(self) -> PartsDict:
            """Filled sub-module configs keyed by their ``Weights`` field name."""
            filled = self.fill()
            assert filled.decoder is not None
            assert filled.ln is not None
            assert filled.token_embedding_save_name is not None
            return {
                'decoder': filled.decoder,
                'ln': filled.ln,
            }

        def weights_check(self, w: WeightsTree) -> Gpt.Weights:
            """Run ``config_weights_check`` on ``w`` and cast the result to ``Gpt.Weights``."""
            checked = config_weights_check(self, w)
            return cast(Gpt.Weights, checked)

    def __init__(self, config: Config):
        assert config.n_blocks is not None
        assert config.n_tokens is not None
        assert config.n_channels is not None
        self.config = config
        self.T = config.T
        self.n_channels = config.n_channels
        self.n_tokens = config.n_tokens
        self.eps = config.eps
        self.decoder = config.decoder.make()
        self.ln = config.ln.make()
494 |
495 |
496 |
497 |
def get_positional_encoding(max_len: int, d_model: int):
    """Build a sinusoidal positional-encoding table.

    Even columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``, with
    ``div_term = exp(-2i * log(10000) / d_model)``.

    :param max_len: number of positions (rows).
    :param d_model: number of channels (columns); may be odd.
    :return: array of shape ``(max_len, d_model)``.
    """
    pe = jnp.zeros((max_len, d_model))
    position = jnp.expand_dims(jnp.arange(0, max_len), 1)
    div_term = jnp.exp(
        jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model)
    )
    angles = position * div_term
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half must drop the last frequency to match the target slice.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, :d_model // 2]))
    return pe
507 |
508 |
509 |
--------------------------------------------------------------------------------
/gpt_recon/load_model.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from safetensors import safe_open
3 | from pathlib import Path
4 |
# %%
# Inspect a GPT-2 safetensors checkpoint: print every tensor name with its
# shape, then one specific tensor and the array type the loader returns.
path = Path("/Data/lm_models/gpt2")
with safe_open(path / "model.safetensors", framework="flax", device="cpu") as f:
    # Initialised up front so the type print below cannot hit a NameError
    # when the checkpoint contains no tensors.
    t = None
    for key in f.keys():
        t = f.get_tensor(key)
        print(key, t.shape)

    print(f.get_tensor("h.3.attn.bias"))
    print(type(t))
14 |
--------------------------------------------------------------------------------
/gpt_recon/run_gpt.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from __future__ import annotations
3 |
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import jax.numpy as jnp
8 | from jax import jit
9 | from safetensors import safe_open
10 | from safetensors.flax import save_file
11 |
12 | from bpe_encoder import get_encoder
13 | from gpt_recon.clean_frame import LN, Linear
14 | from gpt_recon.clean_frame_utils import Arr, load_config
15 | from gpt_recon.gpt import GptMha, GptFfn, GptBlock, GptDecoder, Gpt
16 |
# Model configuration. `.fill()` pushes the top-level sizes down into the
# nested decoder/block/mha/ffn/ln configs.
# NOTE(review): the `save_name` values presumably mirror the tensor names in
# the safetensors checkpoint loaded below — confirm against `load_config`.
gpt_config = Gpt.Config(eps=1e-5,
                        n_channels=768,
                        n_heads=12,
                        n_seq='dynamic',
                        max_seq_len=1024,
                        n_blocks=12,
                        n_tokens=50257,
                        token_embedding_save_name='wte.weight',
                        positional_embedding_save_name='wpe.weight',
                        ln=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_f'),
                        save_name="",
                        decoder=GptDecoder.Config(
                            save_name='h',
                            blocks=GptBlock.Config(
                                save_name="",
                                mha=GptMha.Config(
                                    save_name='attn',
                                    QKV_linear=Linear.Config(save_name='c_attn', w_save_name='weight', b_save_name='bias'),
                                    linear=Linear.Config(save_name="c_proj", w_save_name='weight', b_save_name='bias'),
                                ),
                                ln1=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_1'),
                                ffn=GptFfn.Config(
                                    save_name='mlp',
                                    linear1=Linear.Config(w_save_name='weight', b_save_name='bias', save_name='c_fc'),
                                    linear2=Linear.Config(w_save_name='weight', b_save_name='bias', save_name='c_proj'),
                                ),
                                ln2=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_2')
                            )
                        )).fill()

print(gpt_config)
print(gpt_config.weights)
print(gpt_config.parts)

# Dry run of the loader: looking up a missing key in a defaultdict inserts it,
# so `z` ends up holding every argument `load_config` passed to the callback.
z = defaultdict(int)
ww = load_config(gpt_config, lambda i: z[i])
print(z.keys())

# Real load: the callback is the checkpoint's `get_tensor`.
path = Path("/Data/lm_models/gpt2")
with safe_open(path / "model.safetensors", framework="flax", device="cpu") as f:
    weights_tree_ = load_config(gpt_config, f.get_tensor)
print(weights_tree_.keys())

# Build the model and validate the loaded tree against the config.
gpt_ = gpt_config.make()
checked_weights = gpt_config.weights_check(weights_tree_)
62 |
63 |
64 | # r = gpt_.f(checked_weights, jnp.ones((1024,), dtype=jnp.int32))
65 | # print(r)
66 |
67 |
def run(inputs):
    """Run the module-level model on ``inputs`` with the loaded weights."""
    token_ids = jnp.array(inputs)
    return gpt_.f(checked_weights, token_ids)
70 |
71 |
def debug(inputs) -> dict[str, Arr]:
    """Run the debug forward pass, print the predicted token and save the intermediates.

    The intermediate tensors returned by ``gpt_.f_debug`` are written to
    ``../saves/view_vec2_dict_jit`` and also returned, so the result matches
    the declared return type (``save_file`` itself returns ``None``).

    :param inputs: token ids accepted by ``jnp.array``.
    :return: the dict of intermediate tensors that was saved.
    """
    logits, to_save = jit(gpt_.f_debug)(checked_weights, jnp.array(inputs))
    # Decode the arg-max token of the last row of logits.
    out = encoder.decode([int(jnp.argmax(logits[-1]))])
    print(out)
    save_dir = "../saves"
    save_file(to_save, f'{save_dir}/view_vec2_dict_jit')
    return to_save
78 |
79 |
# Encoder built from local vocabulary/merges files; `debug` reads this
# module-level name when it is called below.
encoder = get_encoder("gpt2", "/Data/lm_models/", "vocab.json", "merges.txt")

# prompt = "Alan Turing theorized that computers would one day become"
prompt = "Time flies like an arrow, fruit flies like a banana. Time flies like an"
# prompt = "Time flies like an arrow; fruit flies like a banana."
# prompt = "The number of pixels used to render an image is set by the Axes size and the dpi of the figure. This can lead to aliasing artifacts when the image is resampled because the displayed image size will usually not match the size of X (see Image antialiasing)."
input_ids = encoder.encode(prompt)
# Show the encoder's per-token entries for the prompt.
print([encoder.decoder[t] for t in input_ids])

# Run the debug forward pass; it prints the predicted token and saves the intermediates.
debug(input_ids)
90 |
91 | # output_ids = generate(run, input_ids, 8)
92 | # output_text = encoder.decode(output_ids)
93 | # print(output_text)
94 |
95 | # TODO:
96 | # [done] implement weight check to parse a weight tree to proper weights
97 | # make shape a dict
98 | # [done] move out init_params
99 | # input/output shapes in config
100 | # [done] fix layernorm issue: clever loading
101 | # [done] make dyn a config instead of separate function
102 | # [done] compare with pico to fix stuff
103 | # change name to save_name
104 | # [half] investigate ffn residual (make good visualizations)
105 | # [done] complete the debug_chain
106 | # make view panel
107 | # [half] better folder structure
108 |
--------------------------------------------------------------------------------
/gpt_recon/train_gpt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from functools import partial
5 | from typing import Callable
6 |
7 | import jax.numpy as jnp
8 | import optax
9 | import wandb
10 | from optax import softmax_cross_entropy_with_integer_labels
11 |
12 | import gpt_recon.gpt as gpt
13 | import nlp_utils
14 | from gpt_recon.clean_frame import batch_fy, Linear
15 | from gpt_recon.clean_frame_utils import Arr, init_weight_module
16 | from custom_dataset import load_jax_cached
17 | from gpt_recon.gpt import Gpt
18 | from picojax.random_utils import infinite_safe_keys
19 | from picojax.train_utils import LMBatchConfig, TrainConfig, TrainState, BatchType
20 |
21 | # need to install the updated version for optax:
22 | # pip install git+https://github.com/deepmind/optax.git
23 |
24 | os.environ['JAX_LOG_COMPILES'] = '1'
25 |
26 |
def gpt_loss_batch(forward: Callable[[Gpt.Weights, Arr], Arr], w: Gpt.Weights, batch: BatchType) -> Arr:
    """Mean softmax cross-entropy of the batched ``forward`` output against the batch's integer labels."""
    token_ids, targets = batch
    batched_forward = batch_fy(forward)
    logits = batched_forward(w, jnp.array(token_ids))
    per_position = softmax_cross_entropy_with_integer_labels(logits, jnp.array(targets))
    return per_position.mean()
31 |
32 |
# Load the encoded corpus and split it 90% / 10% into train and validation.
dataset = "play"
encoded_jax, encode, decode, vocab_size_ = load_jax_cached(dataset=dataset)
n = int(len(encoded_jax) * 0.9)
train_data = encoded_jax[:n]
valid_data = encoded_jax[n:]

# Key source seeded with 0; keys are drawn from it with next().
key_gen = infinite_safe_keys(0)

# Optimizer hyper-parameters. Only the set referenced from `train_params`
# is used by the optimizer selection further down.
adam_params = {
    'learning_rate': 1e-4,
    'beta1': 0.9,
    'beta2': 0.999,
    'eps': 1e-8,
}
lion_params = {
    'learning_rate': 1e-4,
    'beta1': 0.95,
    'beta2': 0.98,
    'weight_decay': 0.01
}
# 'optimizer' names the active optimizer; its settings must be present under
# the key of the same name.
train_params = {
    'eval_iters': 200,
    'eval_interval': 2000,
    'max_iters': 1000000,
    # 'adam': adam_params,
    'lion': lion_params,
    # 'adamw': adam_params,
    'optimizer': 'lion',
}

# Everything in this dict is also passed to wandb as the run config.
experimental_params = {
    'eps': 1e-5,
    'n_tokens': vocab_size_,
    'n_channels': 768,
    'n_heads': 12,
    'n_blocks': 12,

    'batch_size': 8,
    'block_size': 128,
    'train': train_params
}

max_iters = experimental_params['train']['max_iters']
eval_interval = experimental_params['train']['eval_interval']
eval_iters = experimental_params['train']['eval_iters']
batch_config_ = LMBatchConfig(block_size=experimental_params['block_size'],
                              batch_size=experimental_params['batch_size'])
80 |
81 | # dataset = "english"
82 |
# Model config with a static sequence length equal to the batch block size.
gpt_config_ = Gpt.Config(eps=experimental_params['eps'],
                         n_channels=experimental_params['n_channels'],
                         n_heads=experimental_params['n_heads'],
                         # n_seq='dynamic',
                         n_seq=batch_config_.block_size,
                         max_seq_len=batch_config_.block_size,
                         n_blocks=experimental_params['n_blocks'],
                         n_tokens=vocab_size_,
                         decoder=gpt.GptDecoder.Config(
                             blocks=gpt.GptBlock.Config(
                                 mha=gpt.GptMha.Config(
                                     linear=Linear.Config(
                                     ),
                                     QKV_linear=Linear.Config(
                                     )
                                 ),
                                 ffn=gpt.GptFfn.Config(
                                 )
                             )
                         )).fill()

# Initialise weights from the config, then overwrite the positional table
# with the sinusoidal encoding before validating the tree.
raw_weights = init_weight_module(gpt_config_, next(key_gen))
raw_weights['positional_encoding'] = gpt.get_positional_encoding(experimental_params['block_size'],
                                                                 experimental_params['n_channels'])
init_weights_ = gpt_config_.weights_check(raw_weights)

# Pick the optimizer named in the config; each branch reads its settings from
# the entry of the same name, so that entry must exist in `train_params`.
if experimental_params['train']['optimizer'] == 'adam':
    adam_config = experimental_params['train']['adam']
    optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'],
                            b1=adam_config['beta1'],
                            b2=adam_config['beta2'],
                            eps=adam_config['eps'])
elif experimental_params['train']['optimizer'] == 'lion':
    lion_config = experimental_params['train']['lion']
    optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'],
                            b1=lion_config['beta1'],
                            b2=lion_config['beta2'],
                            weight_decay=lion_config['weight_decay'])
elif experimental_params['train']['optimizer'] == 'adamw':
    adamw_config = experimental_params['train']['adamw']
    optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'],
                             b1=adamw_config['beta1'],
                             b2=adamw_config['beta2'],
                             eps=adamw_config['eps'])
else:
    raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported")

# Bind the model's forward function into the loss so it only takes (weights, batch).
model = gpt_config_.make()
loss_f = partial(gpt_loss_batch, model.f)
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_config_ = TrainConfig(loss_fn=loss_f,
                            optimiser=optimizer_)
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_,
                                                   opt_state=optimizer_.init(init_weights_))
140 |
# Training loop: one optimizer update per step; every `eval_interval` steps
# the batch loss is printed before and after the update, train/validation
# losses are estimated, a sample is generated, and everything is logged.
wandb.init(
    project="inside-transformer",
    config=experimental_params,
)
# One key per step, split up front from a single key.
keys_ = next(key_gen).split(max_iters)
for step in range(max_iters):
    batch_ = batch_config_.sample(train_data, keys_[step])
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"===step {step} is an eval step===")
        print(f"before step {step}, batch loss {loss}")

    train_state_ = train_config_.train1(train_state_, batch_)
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"after step {step}, batch loss {loss}")
        results = train_config_.estimate_loss(eval_iters, key_gen, train_state_,
                                              {'train': partial(batch_config_.sample, train_data),
                                               'val': partial(batch_config_.sample, valid_data)})
        # generate_f = jax.jit(partial(dynamic_model_f, train_state_.weights))
        # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10,
        #                          max_len=batch_config_.block_size)
        generate_f = partial(model.f, train_state_.weights)
        # TODO fix generation/ add temperature
        generated = nlp_utils.generate_static(generate_f, [0],
                                              n_tokens_to_generate=batch_config_.block_size - 1,
                                              max_len=batch_config_.block_size)
        # This runs after the update for `step`, so `step + 1` batches have
        # been consumed; counting `step` batches would under-report by one.
        wandb.log({"train_loss": results['train'],
                   "validation_loss": results['val'],
                   "batch_loss": loss,
                   "n_tokens_trained": (step + 1) * batch_config_.batch_size * batch_config_.block_size,
                   'generated': wandb.Html(f"{decode(generated)}")})
        print(decode(generated), flush=True)

wandb.finish()
176 | # TODO: add trainable weights
177 | # TODO: add weight decay
178 |
--------------------------------------------------------------------------------
/gpt_recon/train_notebook.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from functools import partial
5 | from pprint import pprint
6 | from typing import Callable
7 |
8 | import jax.numpy as jnp
9 | import optax
10 | from jax import random
11 | from optax import softmax_cross_entropy_with_integer_labels
12 |
13 | import gpt_recon.gpt as gpt
14 | import nlp_utils
15 | from gpt_recon.clean_frame import batch_fy
16 | from gpt_recon.clean_frame_utils import Arr, config_weights_check, init_weight_module, ModuleConfig
17 | from custom_dataset import load_jax_cached
18 | from gpt_recon.gpt import Gpt, GptMha
19 | from picojax.random_utils import SafeKey, infinite_safe_keys
20 | from picojax.train_utils import LMBatchConfig, TrainConfig, TrainState, BatchType
21 |
22 | os.environ['JAX_LOG_COMPILES'] = '1'
23 |
24 |
def go(c: ModuleConfig, x: Arr, key: SafeKey) -> Arr:
    """Initialise weights for config ``c`` from ``key`` and run its module once on ``x``."""
    fresh_weights = init_weight_module(c, key)
    module = c.make()
    return module.f(fresh_weights, x)
28 |
29 |
# Smoke test of a small multi-head attention module with dynamic sequence length.
gpt_mha_config_ = GptMha.Config(n_channels=9,
                                n_heads=3,
                                n_seq='dynamic').fill()
pprint(gpt_mha_config_.parts)

# Initialise weights and validate them against the config.
w = init_weight_module(gpt_mha_config_, SafeKey(random.PRNGKey(0)))
# print(w)
checked = config_weights_check(gpt_mha_config_, w)
# print(checked)
# Run once on a (5, 9) input of ones and print the output shape.
print(go(gpt_mha_config_, jnp.ones((5, 9)), SafeKey(random.PRNGKey(0))).shape)
40 |
41 |
42 | # @partial(jax.jit, static_argnums=(0,))
def gpt_loss_batch(forward: Callable[[Gpt.Weights, Arr], Arr], w: Gpt.Weights, batch: BatchType) -> Arr:
    """Mean softmax cross-entropy of the batched ``forward`` output against the batch's integer labels."""
    token_ids, targets = batch
    batched_forward = batch_fy(forward)
    logits = batched_forward(w, jnp.array(token_ids))
    per_position = softmax_cross_entropy_with_integer_labels(logits, jnp.array(targets))
    return per_position.mean()
47 |
48 |
49 | # %%
50 |
51 |
# Run settings.
key_gen = infinite_safe_keys(0)
max_iters = 100000
eval_interval = 1000
learning_rate_ = 1e-4
eval_iters = 100
batch_config_ = LMBatchConfig(block_size=128, batch_size=4)

# Load the encoded corpus and split it 90% / 10% into train and validation.
# dataset = "english"
dataset = "play"
encoded_jax, encode, decode, vocab_size_ = load_jax_cached(dataset=dataset)
n = int(len(encoded_jax) * 0.9)
train_data = encoded_jax[:n]
valid_data = encoded_jax[n:]

# Static-length model config, plus a dynamic-length variant of the same config.
gpt_config_ = Gpt.Config(eps=1e-5,
                         n_channels=768,
                         n_heads=12,
                         # n_seq='dynamic',
                         n_seq=batch_config_.block_size,
                         max_seq_len=batch_config_.block_size,
                         n_blocks=12,
                         n_tokens=vocab_size_).fill()
gpt_dynamic_config_ = gpt_config_._replace(n_seq='dynamic')
dynamic_model = gpt_dynamic_config_.make()
# result = go(gpt_config_, jnp.ones((1024,), dtype=jnp.int32), next(key_gen))
# print(result.shape)

# Fresh validated weights and an Adam optimizer.
init_weights_ = gpt_config_.weights_check(init_weight_module(gpt_config_, next(key_gen)))
optimizer_ = optax.adam(learning_rate=learning_rate_)

# NOTE(review): this passes `model=` / `loss_fn_in=` to TrainConfig, whereas
# train_gpt.py passes `loss_fn=` — confirm which matches picojax.train_utils.
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_config_ = TrainConfig(model=gpt_config_.make(),
                            loss_fn_in=gpt_loss_batch,
                            optimiser=optimizer_)
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_,
                                                   opt_state=optimizer_.init(init_weights_))
91 |
92 | # %%
93 | # precompile for better profiling
94 | # batch_ = batch_config_.sample(train_data, next(key_gen))
95 | # train_state_ = train_config_.train1(train_state_, batch_)
96 | # batch_ = batch_config_.sample(train_data, next(key_gen))
97 | # train_state_ = train_config_.train1(train_state_, batch_)
98 | # batch_ = batch_config_.sample(train_data, next(key_gen))
99 | # train_state_ = train_config_.train1(train_state_, batch_)
100 | # # print([x.device() for x in tree_flatten(train_state_.weights)[0]])
101 | # train_config_.estimate_loss(eval_iters,
102 | # key_gen,
103 | # train_state_,
104 | # batch_config_,
105 | # {'train': train_data, 'val': valid_data})
106 | # train_config_.estimate_loss(eval_iters,
107 | # key_gen,
108 | # train_state_,
109 | # batch_config_,
110 | # {'train': train_data, 'val': valid_data})
111 |
112 | # %%
113 |
# Training loop: one update per step; every `eval_interval` steps the batch
# loss and the estimated train/validation losses are computed before and
# after the update, and a sample is generated and printed.
# One key per step, split up front from a single key.
keys_ = next(key_gen).split(max_iters)
#with jax.profiler.trace("./jax-trace", create_perfetto_link=True):
for step in range(max_iters):
    batch_ = batch_config_.sample(train_data, keys_[step])
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"===step {step} is an eval step===")
        print(f"before step {step}, batch loss {loss}")
        train_config_.estimate_loss(eval_iters,
                                    key_gen,
                                    train_state_,
                                    batch_config_,
                                    {'train': train_data, 'val': valid_data})

    # print(make_jaxpr(partial(jax_calc_updates, train_state_.optimiser, train_state_.loss_fn))(train_state_.weights, batch_, train_state_.opt_state))
    train_state_ = train_config_.train1(train_state_, batch_)
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"after step {step}, batch loss {loss}")
        train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_,
                                    {'train': train_data, 'val': valid_data})
        # generate_f = jax.jit(partial(dynamic_model.f, train_state_.weights))
        # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10,
        #                          max_len=batch_config_.block_size)
        # Generation uses the static-length model held by the train config.
        generate_f = partial(train_config_.model.f, train_state_.weights)
        # TODO fix generation/ add temperature
        generated = nlp_utils.generate_static(generate_f, [0],
                                              n_tokens_to_generate=batch_config_.block_size - 1,
                                              max_len=batch_config_.block_size)
        print(decode(generated), flush=True)
144 |
145 |
--------------------------------------------------------------------------------
/gpt_recon/tune_gpt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import Callable
7 |
8 | import jax.numpy as jnp
9 | import optax
10 | # import wandb
11 | from optax import softmax_cross_entropy_with_integer_labels
12 | from safetensors import safe_open
13 |
14 | import gpt_recon.gpt as gpt
15 | import nlp_utils
16 | from gpt_recon.clean_frame import batch_fy, LN, Linear
17 | from gpt_recon.clean_frame_utils import Arr, load_config
18 | from custom_dataset import load_jax_cached
19 | from gpt_recon.gpt import Gpt, GptMha, GptFfn, GptBlock, GptDecoder
20 | from picojax.random_utils import infinite_safe_keys
21 | from picojax.train_utils import LMBatchConfig, TrainConfig, TrainState, BatchType
22 |
23 | # need to install the updated version for optax:
24 | # pip install git+https://github.com/deepmind/optax.git
25 |
26 | os.environ['JAX_LOG_COMPILES'] = '1'
27 |
28 |
def gpt_loss_batch(forward: Callable[[Gpt.Weights, Arr], Arr], w: Gpt.Weights, batch: BatchType) -> Arr:
    """Mean softmax cross-entropy of the batched ``forward`` output against the batch's integer labels."""
    token_ids, targets = batch
    batched_forward = batch_fy(forward)
    logits = batched_forward(w, jnp.array(token_ids))
    per_position = softmax_cross_entropy_with_integer_labels(logits, jnp.array(targets))
    return per_position.mean()
33 |
34 |
# Load the encoded corpus and split it 90% / 10% into train and validation.
dataset = "play"
encoded_jax, encode, decode, vocab_size_ = load_jax_cached(dataset=dataset)
n = int(len(encoded_jax) * 0.9)
train_data = encoded_jax[:n]
valid_data = encoded_jax[n:]

# Key source seeded with 0; keys are drawn from it with next().
key_gen = infinite_safe_keys(0)

# Optimizer hyper-parameters. `adam_params` is reused as the 'adamw' settings below.
adam_params = {
    'learning_rate': 1e-4,
    'beta1': 0.9,
    'beta2': 0.999,
    'eps': 1e-8,
}
lion_params = {
    'learning_rate': 1e-4,
    'beta1': 0.9,
    'beta2': 0.99,
}
# 'optimizer' names the active optimizer; its settings must be present under
# the key of the same name.
train_params = {
    'eval_iters': 200,
    'eval_interval': 2000,
    'max_iters': 1000000,
    # 'adam': adam_params,
    # 'lion': lion_params,
    'optimizer': 'adamw',
    'adamw': adam_params,
}

experimental_params = {
    'eps': 1e-5,
    'n_tokens': vocab_size_,
    'n_channels': 768,
    'n_heads': 12,
    'n_blocks': 12,

    'batch_size': 8,
    'block_size': 128,
    'train': train_params
}

# wandb.init(
#     project="inside-transformer",
#     config=experimental_params,
# )

max_iters = experimental_params['train']['max_iters']
eval_interval = experimental_params['train']['eval_interval']
eval_iters = experimental_params['train']['eval_iters']
batch_config_ = LMBatchConfig(block_size=experimental_params['block_size'],
                              batch_size=experimental_params['batch_size'])
86 |
# Model config with a static sequence length equal to the batch block size.
# NOTE(review): `max_seq_len` and `n_tokens` come from the local batch/dataset
# settings while the weights come from a GPT-2 checkpoint — confirm that the
# shapes expected by `weights_check` agree with the checkpoint tensors.
gpt_config_ = Gpt.Config(eps=experimental_params['eps'],
                         n_channels=experimental_params['n_channels'],
                         n_heads=experimental_params['n_heads'],
                         # n_seq='dynamic',
                         n_seq=batch_config_.block_size,
                         max_seq_len=batch_config_.block_size,
                         n_blocks=experimental_params['n_blocks'],
                         n_tokens=vocab_size_,
                         token_embedding_save_name='wte.weight',
                         positional_embedding_save_name='wpe.weight',
                         ln=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_f'),
                         save_name="",
                         decoder=GptDecoder.Config(
                             save_name='h',
                             blocks=GptBlock.Config(
                                 save_name="",
                                 mha=GptMha.Config(
                                     save_name='attn',
                                     QKV_linear=Linear.Config(save_name='c_attn', w_save_name='weight',
                                                              b_save_name='bias'),
                                     linear=Linear.Config(save_name="c_proj", w_save_name='weight', b_save_name='bias'),
                                 ),
                                 ln1=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_1'),
                                 ffn=GptFfn.Config(
                                     save_name='mlp',
                                     linear1=Linear.Config(w_save_name='weight', b_save_name='bias', save_name='c_fc'),
                                     linear2=Linear.Config(w_save_name='weight', b_save_name='bias',
                                                           save_name='c_proj'),
                                 ),
                                 ln2=LN.Config(w_save_name='weight', b_save_name='bias', save_name='ln_2')
                             )
                         )).fill()

# dataset = "english"


# Load the checkpoint tensors through the config, using `get_tensor` as the callback.
path = Path("/Data/lm_models/gpt2")
with safe_open(path / "model.safetensors", framework="flax", device="cpu") as f:
    weights_tree_ = load_config(gpt_config_, f.get_tensor)
print(weights_tree_.keys())

# Build the model and validate the loaded tree; it becomes the starting weights.
gpt_ = gpt_config_.make()
init_weights_ = gpt_config_.weights_check(weights_tree_)
130 |
# Pick the optimizer named in the config; each branch reads its settings from
# the entry of the same name, so that entry must exist in `train_params`.
if experimental_params['train']['optimizer'] == 'adam':
    adam_config = experimental_params['train']['adam']
    optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'],
                            b1=adam_config['beta1'],
                            b2=adam_config['beta2'],
                            eps=adam_config['eps'])
elif experimental_params['train']['optimizer'] == 'lion':
    lion_config = experimental_params['train']['lion']
    optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'],
                            b1=lion_config['beta1'],
                            b2=lion_config['beta2'])
elif experimental_params['train']['optimizer'] == 'adamw':
    adamw_config = experimental_params['train']['adamw']
    optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'],
                             b1=adamw_config['beta1'],
                             b2=adamw_config['beta2'],
                             eps=adamw_config['eps'])
else:
    raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported")

# NOTE(review): a second model instance is built here although `gpt_` was
# already created from the same config; also, train_gpt.py constructs
# TrainConfig with `loss_fn=` — confirm which signature is current.
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_config_ = TrainConfig(model=gpt_config_.make(),
                            loss_fn_in=gpt_loss_batch,
                            optimiser=optimizer_)
# noinspection PyArgumentList
# cuz it's a NamedTuple
train_state_: TrainState[Gpt.Weights] = TrainState(weights=init_weights_,
                                                   opt_state=optimizer_.init(init_weights_))
160 |
# One key per training step, split up front from a single key; step i samples its batch with keys_[i].
keys_ = next(key_gen).split(max_iters)
for step in range(max_iters):
    batch_ = batch_config_.sample(train_data, keys_[step])
    # On eval steps, report the loss on this batch before the update ...
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"===step {step} is an eval step===")
        print(f"before step {step}, batch loss {loss}")

    train_state_ = train_config_.train1(train_state_, batch_)
    # ... and again after it, followed by a loss estimate on the train/val data and a sample generation.
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"after step {step}, batch loss {loss}")
        results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_,
                                              {'train': train_data, 'val': valid_data})
        # generate_f = jax.jit(partial(dynamic_model.f, train_state_.weights))
        # generated = gpt.generate(generate_f, [0], n_tokens_to_generate=10,
        #                          max_len=batch_config_.block_size)
        # Bind the current weights so the generator only has to pass tokens.
        generate_f = partial(train_config_.model.f, train_state_.weights)
        # TODO fix generation/ add temperature
        # Start from the single token id 0 and fill the rest of one block.
        generated = nlp_utils.generate_static(generate_f, [0],
                                              n_tokens_to_generate=batch_config_.block_size - 1,
                                              max_len=batch_config_.block_size)
        # wandb.log({"train_loss": results['train'],
        #            "validation_loss": results['val'],
        #            "batch_loss": loss,
        #            "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size,
        #            'generated': wandb.Html(f"{decode(generated)}")})
        print(decode(generated), flush=True)
189 |
190 | # wandb.finish()
191 |
--------------------------------------------------------------------------------
/gpt_recon/view_attn.py:
--------------------------------------------------------------------------------
1 | import altair as alt
2 | import numpy as np
3 | import pandas as pd
4 | import safetensors
5 | from numpy.typing import NDArray
6 |
7 |
def heatmap(mat: NDArray) -> alt.Chart:
    """Build an Altair rectangle heatmap of a 2-D matrix.

    Every cell ``mat[r, c]`` becomes one rectangle at ordinal position
    ``x = c`` (column index), ``y = r`` (row index), coloured by its value.

    Args:
        mat: 2-D array; it does not need to be square.

    Returns:
        The (unsaved, undisplayed) Altair chart.

    Raises:
        ValueError: if ``mat`` is not 2-D.
    """
    n_rows, n_cols = mat.shape  # also rejects anything that is not 2-D
    # np.indices yields coordinates in the same row-major order as mat.flatten() for any
    # shape. The earlier meshgrid(arange(n_rows), arange(n_cols)) produced (n_cols, n_rows)
    # grids, which only lined up with the flattened values when the matrix was square.
    row_idx, col_idx = np.indices((n_rows, n_cols))
    source = pd.DataFrame({
        'x': col_idx.flatten(),
        'y': row_idx.flatten(),
        'z': mat.flatten()
    })

    return alt.Chart(source).mark_rect().encode(
        x='x:O',
        y='y:O',
        color='z:Q'
    )
22 |
23 |
# Load the saved raw attention tensor and lay it out as a grid of heatmaps in attn.html.
view_vec_dict = safetensors.safe_open('saves/view_vec2_dict', 'flax')
attn_result = view_vec_dict.get_tensor('attn_raw')
# head, layer, token, shape
# Outer chart stacks rows vertically; every sub-chart keeps its own colour scale.
chart = alt.vconcat().resolve_scale(
    color='independent'
)
# NOTE(review): both 12s are hard-coded — presumably the number of layers and heads of the saved model; confirm.
for i in range(12):
    row = alt.hconcat().resolve_scale(color='independent')
    for j in range(12):
        # j indexes the first axis and i the second axis of attn_result.
        row |= heatmap(attn_result[j, i, :, :])
    chart &= row
chart.save('attn.html')
36 |
--------------------------------------------------------------------------------
/gpt_recon/view_notebook.py:
--------------------------------------------------------------------------------
1 | import altair as alt
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 | import safetensors.flax
6 | from numpy.typing import NDArray
7 | from sklearn.cluster import AgglomerativeClustering
8 | from sklearn.manifold import TSNE
9 |
10 | alt.renderers.enable('svg')
11 |
12 |
def heatmap_at(mat: NDArray, display: str):
    """Save a rectangle heatmap of a 2-D matrix to ``<display>.html``.

    Every cell ``mat[r, c]`` becomes one rectangle at ordinal position
    ``x = c`` (column index), ``y = r`` (row index), coloured by its value.

    Args:
        mat: 2-D array; it does not need to be square.
        display: output file stem; the chart is written to ``f'{display}.html'``.

    Raises:
        ValueError: if ``mat`` is not 2-D.
    """
    n_rows, n_cols = mat.shape  # also rejects anything that is not 2-D
    # np.indices yields coordinates in the same row-major order as mat.flatten() for any
    # shape. The earlier meshgrid(arange(n_rows), arange(n_cols)) produced (n_cols, n_rows)
    # grids, which scrambled the cells of every non-square matrix.
    row_idx, col_idx = np.indices((n_rows, n_cols))
    source = pd.DataFrame({
        'x': col_idx.flatten(),
        'y': row_idx.flatten(),
        'z': mat.flatten()
    })

    alt.Chart(source).mark_rect().encode(
        x='x:O',
        y='y:O',
        color='z:Q'
    ).save(f'{display}.html')
27 |
28 |
def heatmap(x):
    """Show ``x`` as an image in a new matplotlib figure."""
    plt.figure()
    plt.imshow(x)
    plt.show()
33 |
34 |
def plot1(x):
    """Print the shape of ``x`` and line-plot it in a new matplotlib figure."""
    print(x.shape)
    plt.figure()
    plt.plot(x)
    plt.show()
40 |
41 |
# %%
# Cell: load the saved activations and compare one channel before/after the FFN.
# view_vec = np.load('view_vec.npy')
view_vec_dict = safetensors.safe_open('view_vec2_dict', 'flax')
view_vec1 = view_vec_dict.get_tensor('x_before_ffn')
view_vec2 = view_vec_dict.get_tensor('x_after_ffn')

plot1(view_vec1[:, 0, 0])
plot1(view_vec2[:, 0, 0])

# %%
# Cell: pick the tensor to analyse and label each prompt token with its zero-padded position.
bf_layer = view_vec_dict.get_tensor('x_before_mha')
aft_layer = view_vec_dict.get_tensor('x')
prompt = "Time flies like an arrow ; fruit flies like a banana . Time files like an"
token_list = prompt.split(" ")
token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)]
# view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768))
view_vec = aft_layer

# layers, tokens, channels
# One slice along each of the three axes.
plot1(view_vec[:, 0, 0])
plot1(view_vec[0, :, 0])
plot1(view_vec[0, 0, :])

# %%
# Cell: correlation matrices between layers (averaged over tokens) and between tokens (averaged over layers).
# heatmap_at(np.median(view_vec, axis=-1))
n_layers, n_tokens, n_channels = view_vec.shape
corr_layers = np.array([np.corrcoef(view_vec[:, i, :]) for i in range(n_tokens)]).mean(axis=0)
heatmap_at(corr_layers, "chart1")
corr_tokens = np.array([np.corrcoef(view_vec[i, :, :]) for i in range(n_layers)]).mean(axis=0)
heatmap_at(corr_tokens, "chart2")

# %%
# Cell: mean of each correlation matrix, plotted per token, per layer and (twice, with and without transpose) per channel.
plot1(np.array([np.corrcoef(view_vec[:, i, :]) for i in range(n_tokens)]).mean(axis=(1, 2)))
plot1(np.array([np.corrcoef(view_vec[i, :, :]) for i in range(n_layers)]).mean(axis=(1, 2)))
plot1(np.array([np.corrcoef(view_vec[:, :, i]) for i in range(n_channels)]).mean(axis=(1, 2)))
plot1(np.array([np.corrcoef(view_vec[:, :, i].T) for i in range(n_channels)]).mean(axis=(1, 2)))

# %%
# Cell: cluster every (layer, token) vector by cosine distance and show the cluster id per position.
all_vecs = view_vec.reshape(-1, n_channels)
# clustering = AgglomerativeClustering(n_clusters=3).fit(all_vecs)
clustering = AgglomerativeClustering(n_clusters=10, linkage="average", metric="cosine").fit(all_vecs)
c = clustering.labels_.reshape(n_layers, n_tokens)
heatmap_at(c, "chart3")
85 |
# %%
# TSNE with first layer
# Cell: reload the activations, embed every (layer, token) vector in 2-D and save a scatter plot.
view_vec_dict = safetensors.safe_open('view_vec2_dict', 'flax')
bf_layer = view_vec_dict.get_tensor('x_before_mha')
aft_layer = view_vec_dict.get_tensor('x')
prompt = "Time flies like an arrow ; fruit flies like a banana . Time files like an"
token_list = prompt.split(" ")
token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)]

# Stack the two tensors and merge the leading axes; 768 is the hard-coded channel count.
view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768))

n_layers, n_tokens, n_channels = view_vec.shape
assert n_tokens == len(token_list)
all_vecs = view_vec.reshape(-1, n_channels)
# all_vecs = view_vec[:-1, :, :].reshape(-1, n_channels)
# n_layers = n_layers - 1
new = TSNE(n_components=2, learning_rate='auto',
           init='random', perplexity=3, metric='cosine').fit_transform(all_vecs)
# Append a layer-index column and a token-index column, in the same row order as all_vecs.
new = np.concatenate([new, np.array([np.repeat(i, n_tokens) for i in range(n_layers)]).reshape(-1, 1),
                      np.array([np.arange(n_tokens) for i in range(n_layers)]).reshape(-1, 1)], axis=1)

source = pd.DataFrame(new, columns=['x', 'y', 'layer', 'token'])
# The concatenated array is float, so cast the indices back before using them.
source['token'] = source['token'].apply(lambda x: token_list[int(x)])
source['layer'] = source['layer'].apply(lambda x: int(x))
# Dot size encodes the layer, colour encodes the token.
alt.Chart(source).mark_circle().encode(
    x='x',
    y='y',
    size=alt.Size('layer', legend=alt.Legend(type="symbol", symbolLimit=0), type='ordinal'),
    # rainbow
    color=alt.Color('token', scale=alt.Scale(scheme='rainbow'), legend=alt.Legend(type="symbol", symbolLimit=0)),
).properties(width=1000, height=1000).save('chart4.svg')
117 |
118 | # TSNE without first layer
119 |
--------------------------------------------------------------------------------
/gpt_recon/view_vec.py:
--------------------------------------------------------------------------------
1 | import altair as alt
2 | import numpy as np
3 | import pandas as pd
4 | import safetensors.flax
5 | from sklearn.manifold import TSNE
6 | from pacmap import PaCMAP
7 | from umap import UMAP
8 | from trimap import TRIMAP
9 |
# Select Altair's 'svg' renderer for the charts produced below.
alt.renderers.enable('svg')


# Load the saved per-layer activations.
view_vec_dict = safetensors.safe_open('saves/view_vec2_dict_jit', 'flax')
bf_layer = view_vec_dict.get_tensor('x0')
# mid_layer = view_vec_dict.get_tensor('x_after_mha')
mid_layer = view_vec_dict.get_tensor('x_before_ffn')
aft_layer = view_vec_dict.get_tensor('x')
# Sanity check: entry k+1 of 'x0' must equal entry k of 'x' along the first axis.
assert (bf_layer[1:, :, :] == aft_layer[:-1, :, :]).all()

# %%
# Label each prompt token with its zero-padded position so repeated words stay distinct.
prompt = "Time flies like an arrow ; fruit flies like a banana . Time files like an"
token_list = prompt.split(" ")
token_list = [str(i).zfill(2) + "_" + token for i, token in enumerate(token_list)]

# view_vec = np.stack([bf_layer, aft_layer]).reshape((-1, len(token_list), 768), order='F')
# view_vec = bf_layer
# Prepend the first 'x0' entry to all 'x' entries along the first axis.
view_vec = np.concatenate((bf_layer[0, :, :][np.newaxis, :, :], aft_layer), axis=0)
# view_vec = np.stack((bf_layer[0, :, :], aft_layer[-1, :, :]), axis=0)

# view_vec = np.stack([mid_layer, aft_layer]).reshape((-1, len(token_list), 768), order='F')
# view_vec = np.concatenate((bf_layer[0, :, :][np.newaxis, :, :], view_vec), axis=0)

n_layers, n_tokens, n_channels = view_vec.shape
assert n_tokens == len(token_list)

#%%


# Embed every (layer, token) vector in 2-D; the alternative reducers are kept commented out.
all_vecs = view_vec.reshape(-1, n_channels)
# all_vecs = view_vec[:-1, :, :].reshape(-1, n_channels)
# n_layers = n_layers - 1
# embedding = PaCMAP(distance='angular', n_components=2, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0)
embedding = UMAP(n_neighbors=10, metric='cosine')
# embedding = TRIMAP(distance='cosine')
# embedding = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3, metric='cosine')
new = embedding.fit_transform(all_vecs)
# Append a layer-index column and a token-index column, in the same row order as all_vecs.
new = np.concatenate([new, np.array([np.repeat(i, n_tokens) for i in range(n_layers)]).reshape(-1, 1),
                      np.array([np.arange(n_tokens) for i in range(n_layers)]).reshape(-1, 1)], axis=1)

source = pd.DataFrame(new, columns=['x', 'y', 'layer', 'token'])
# The concatenated array is float, so cast the indices back before using them.
source['token'] = source['token'].apply(lambda x: token_list[int(x)])
source['layer'] = source['layer'].apply(lambda x: int(x))

# One grey polyline per token, connecting its points in layer order.
lines_all = []
for i in range(n_tokens):
    t = source[source['token'] == token_list[i]]
    lines = alt.Chart(t).mark_line().encode(
        x='x',
        y='y',
        order='layer',
        color=alt.value("#AAAAAA")
    )
    lines_all.append(lines)

# Dot size encodes the layer, colour encodes the token.
dots = alt.Chart(source).mark_circle().encode(
    x='x',
    y='y',
    size=alt.Size('layer', legend=alt.Legend(type="symbol", symbolLimit=0), type='ordinal'),
    # rainbow
    color=alt.Color('token', scale=alt.Scale(scheme='rainbow'), legend=alt.Legend(type="symbol", symbolLimit=0)),
).properties(width=1000, height=1000)

# Layer all polylines (starting from an empty LayerChart) and draw the dots on top.
final = sum(lines_all, alt.LayerChart()) + dots
final.save('lines.svg')
75 |
76 | # TSNE without first layer
77 |
--------------------------------------------------------------------------------
/labels.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import Protocol, runtime_checkable, Union
5 |
6 |
@dataclass(frozen=True)
class Labels:
    """Immutable, hashable set of string tags.

    Supports membership tests, ``len``, superset checks via ``covers`` and
    union via ``+`` with either a single tag string or another ``Labels``.
    """
    tags: frozenset[str]

    def fmt(self) -> str:
        """Return the tags joined by ``', '``.

        Tags are sorted first: a frozenset of strings iterates in hash order,
        which varies between interpreter runs, so the unsorted join was not
        reproducible.
        """
        return ', '.join(sorted(self.tags))

    def check(self) -> None:
        """Raise ``ValueError`` unless ``tags`` really is a ``frozenset``."""
        if not isinstance(self.tags, frozenset):
            raise ValueError(f"tags must be a frozenset, not {type(self.tags)}")

    def covers(self, other: Labels) -> bool:
        """Return True when every tag of ``other`` is also in this set."""
        return self.tags >= other.tags

    def __contains__(self, item: str) -> bool:
        return item in self.tags

    def __len__(self) -> int:
        return len(self.tags)

    def __add__(self, other: Union[str, Labels]) -> Labels:
        """Return a new ``Labels`` with one extra tag (str) or the union with another ``Labels``."""
        if isinstance(other, str):
            return Labels(tags=self.tags | {other})
        else:
            return Labels(tags=self.tags | other.tags)

    def __radd__(self, other):
        # Union is commutative, so ``'tag' + labels`` is the same as ``labels + 'tag'``.
        return self + other

    @classmethod
    def empty(cls):
        """Return a ``Labels`` with no tags."""
        return cls(frozenset())

    @classmethod
    def from_strs(cls, *tags: str) -> Labels:
        """Build a ``Labels`` from positional tag strings; duplicates collapse."""
        return cls(frozenset(tags))



# Short alias so call sites can write L('a', 'b').
L = Labels.from_strs
47 |
48 |
@runtime_checkable
class WithLabels(Protocol):
    """Structural type for objects that carry a ``labels`` attribute and can drop it."""
    labels: Labels

    def clear_labels(self) -> WithLabels:
        """Return a ``WithLabels`` object; implementations define what clearing means."""
        ...
55 |
--------------------------------------------------------------------------------
/lra_arena/lra_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import NamedTuple, Sequence, Callable, Generator, Iterator
3 |
4 | import jax.numpy as np
5 | from jax import random
6 | from s5.dataloading import Datasets, DataLoader
7 |
8 | from picojax.random_utils import SafeKey
9 | from picojax.train_utils import MixedLenBatchType
10 |
11 |
def _cycle_batches(loader: DataLoader) -> Iterator[MixedLenBatchType]:
    """Yield ``(x, y, lengths)`` batches from ``loader`` forever, starting a new pass when one ends.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing (avoids spinning forever).
    """
    while True:
        produced_any = False
        # A for-loop ends cleanly on exhaustion. Calling next() inside a generator would
        # turn the loader's StopIteration into RuntimeError (PEP 479).
        for x, y, aux in loader:
            produced_any = True
            yield np.array(x), np.array(y), np.array(aux['lengths'])
        if not produced_any:
            raise ValueError("dataloader produced no batches")


class LRABatchConfig(NamedTuple):
    """Batch settings and data loaders for one dataset created through ``s5.dataloading``."""
    block_size: int
    batch_size: int
    # Mapping 'train' / 'val' / 'test' -> loader, as built by from_s5.
    s5_dataloaders: dict[str, DataLoader]
    train_size: int
    n_classes_in: int
    n_classes_out: int

    @classmethod
    def from_s5(cls, batch_size: int, cache_path: Path, dataset_name: str, seed: int = 0):
        """Create the named dataset via ``Datasets[dataset_name]`` and wrap what it returns.

        ``block_size`` is the sequence length reported by the creator,
        ``n_classes_in`` its input dimension and ``n_classes_out`` its class count.
        The auxiliary loaders returned by the creator are not kept.
        """
        create_dataset_fn = Datasets[dataset_name]
        trainloader, valloader, testloader, aux_dataloaders, n_classes, seq_len, in_dim, train_size = create_dataset_fn(
            cache_path, seed=seed, bsz=batch_size)
        return cls(block_size=seq_len, batch_size=batch_size,
                   s5_dataloaders={'train': trainloader, 'val': valloader, 'test': testloader}, train_size=train_size,
                   n_classes_in=in_dim, n_classes_out=n_classes)

    @property
    def samplers(self) -> dict[str, Callable[[SafeKey], MixedLenBatchType]]:
        """Per-split callables that return the next batch each time they are called.

        The key argument is accepted for interface compatibility but not used.
        Batches come from iterating the loader itself, the same source that
        ``dataloaders`` unpacks as ``(x, y, aux)``, and iteration restarts when a
        pass is exhausted. Every property access builds fresh samplers.
        """
        def get_sampler(loader: DataLoader) -> Callable[[SafeKey], MixedLenBatchType]:
            batches = _cycle_batches(loader)

            def sampler(key: SafeKey):
                return next(batches)

            return sampler

        return {k: get_sampler(v) for k, v in self.s5_dataloaders.items()}

    @property
    def dataloaders(self) -> dict[str, Iterator[MixedLenBatchType]]:
        """Per-split endless iterators over ``(x, y, lengths)`` batches.

        Every property access builds fresh iterators.
        """
        return {k: _cycle_batches(v) for k, v in self.s5_dataloaders.items()}
52 |
--------------------------------------------------------------------------------
/lra_arena/rwkv_lra.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | import jax.numpy as np
9 | import optax
10 | import wandb
11 | from jax.tree_util import tree_flatten
12 |
13 | from copy_init.weights import get_normal_weights_config_, init, save_pytree_
14 | from labels import Labels
15 | from lra_arena.lra_utils import LRABatchConfig
16 | from pico_rwkv.pico_rwkv_parallel import rwkv_net_parallel
17 | from pico_rwkv.pico_rwkv_rnn import rwkv_net_rnn
18 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight
19 | from picojax.jax_utils import WeightsTree, Arr
20 | from picojax.random_utils import infinite_safe_keys
21 | from picojax.train_utils import TrainConfig, TrainState, get_classification_loss
22 | from python_utils import num_short_form
23 |
# NOTE(review): this is assigned after the jax imports above — confirm JAX still picks the flag up at this point.
os.environ['JAX_LOG_COMPILES'] = '1'
model_path = Path("/Data/lm_models/rwkv")
# model_name = 'RWKV-4-Pile-430M-20220808-8066'
model_name = 'RWKV-4-Pile-169M-20220807-8023'

# Presumably per-weight initialisation specs for the chosen model — confirm in copy_init.weights.
weight_infos = get_normal_weights_config_(model_path, model_name)

keygen = infinite_safe_keys(0)
key = next(keygen)

init_weights_raw = init(weight_infos, rng_key=key)
all_weight_names = list(init_weights_raw.keys())

# randomly initialize weights
init_weights_: WeightsTree = parse_rwkv_weight(init_weights_raw.keys(), init_weights_raw.__getitem__, trim=True)
## load weights instead of randomly initialize
# with safe_open(model_path / f"{model_name}.safetensors", framework="flax", device="cpu") as f:
#     init_weights_ = parse_rwkv_weight(f.keys(), f.get_tensor)

# Channel count is the second dimension of the embedding matrix.
n_channels = init_weights_['emb']['weight'].shape[1]  # type: ignore
_, tree_struct = tree_flatten(init_weights_)

# No tag filter and no weight mask are configured in this script.
train_tags: Optional[list[Labels]] = None
weight_mask = None

# NOTE(review): second key generator with the same seed (0) as keygen above.
key_gen = infinite_safe_keys(0)

adam_params = {
    'learning_rate': 1e-4,
    'beta1': 0.9,
    'beta2': 0.999,
    'eps': 1e-8,
}
lion_params = {
    'learning_rate': 1e-4,
    'beta1': 0.95,
    'beta2': 0.98,
    'weight_decay': 0.01
}
# Only the 'lion' hyper-parameters are registered; the 'adam' / 'adamw' entries are commented out.
train_params = {
    'eval_iters': 200,
    'eval_interval': 2000,
    'save_interval': 10000,
    'max_iters': 1000000,
    # 'adam': adam_params,
    'lion': lion_params,
    # 'adamw': adam_params,
    'optimizer': 'lion',
}
batch_size = 3

batch_config_ = LRABatchConfig.from_s5(
    batch_size=batch_size,
    cache_path=Path("/Data/datasets/pile"),
    dataset_name='listops-classification'
)

# reduce embedding to vocab size
# Keep only the first n_classes_in embedding rows and the first n_classes_out head rows.
init_weights_['emb']['weight'] = init_weights_['emb']['weight'][:batch_config_.n_classes_in, :]  # type: ignore
init_weights_['head']['weight'] = init_weights_['head']['weight'][:batch_config_.n_classes_out, :]  # type: ignore

# Everything below is also passed to wandb.init as the run config.
experimental_params: dict = {
    'eps': 1e-5,
    'n_tokens': batch_config_.n_classes_in,
    'n_channels': n_channels,
    'n_blocks': len(init_weights_['blocks']),

    'train_tags': [l.fmt() for l in train_tags] if train_tags is not None else None,

    'batch_size': batch_config_.batch_size,
    'block_size': batch_config_.block_size,
    'train': train_params,
    'model': "rwkv"
}

max_iters = experimental_params['train']['max_iters']
eval_interval = experimental_params['train']['eval_interval']
save_interval = experimental_params['train']['save_interval']
eval_iters = experimental_params['train']['eval_iters']
103 |
# Build the optax optimizer named by experimental_params['train']['optimizer'].
# Each branch reads its hyper-parameters from the sub-dict with the same name, so that
# sub-dict has to be present in train_params for the selected optimizer.
if experimental_params['train']['optimizer'] == 'adam':
    adam_config = experimental_params['train']['adam']
    optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'],
                            b1=adam_config['beta1'],
                            b2=adam_config['beta2'],
                            eps=adam_config['eps'])
elif experimental_params['train']['optimizer'] == 'lion':
    lion_config = experimental_params['train']['lion']
    optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'],
                            b1=lion_config['beta1'],
                            b2=lion_config['beta2'],
                            weight_decay=lion_config['weight_decay'])
elif experimental_params['train']['optimizer'] == 'adamw':
    adamw_config = experimental_params['train']['adamw']
    optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'],
                             b1=adamw_config['beta1'],
                             b2=adamw_config['beta2'],
                             eps=adamw_config['eps'])
else:
    # Fail loudly on an unknown optimizer name instead of leaving optimizer_ undefined.
    raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported")
124 |
125 |
def rwkv_f(w: WeightsTree, token_array: Arr) -> Arr:
    """Run the parallel RWKV network on ``token_array``.

    The sequence length handed to the network is ``len(token_array)`` and the
    weight tree ``w`` is expanded into keyword arguments.
    """
    seq_len = len(token_array)
    return rwkv_net_parallel(seq_len, token_array, **w)
128 |
129 |
def rwkv_rnn(w: WeightsTree, token_array: Arr, state: Arr) -> tuple[Arr, Arr]:
    """Run the recurrent RWKV network on ``token_array`` with ``state``.

    The weight tree ``w`` is expanded into keyword arguments; the pair returned
    by ``rwkv_net_rnn`` is passed through unchanged.
    """
    result = rwkv_net_rnn(token_array, state, **w)
    return result
132 |
133 |
# noinspection PyArgumentList
# cuz it's a NamedTuple
# The loss is get_classification_loss with rwkv_f bound as its first argument.
train_config_ = TrainConfig(loss_fn=partial(get_classification_loss, rwkv_f),
                            optimiser=optimizer_)
# noinspection PyArgumentList
# cuz it's a NamedTuple
# The optimizer state is initialised from the same weights the training state starts with.
train_state_: TrainState = TrainState(weights=init_weights_,
                                      train_mask=weight_mask,
                                      opt_state=optimizer_.init(init_weights_))

# Recurrent state of shape (n_blocks, 5, n_channels): all zeros except slot 4 of every block.
# NOTE(review): rnn_init_state is not used again in the visible part of this script.
rnn_init_state = np.zeros((experimental_params['n_blocks'], 5, experimental_params['n_channels']))
for i in range(experimental_params['n_blocks']):
    # to jax state[5 * i + 4] = -1e30
    rnn_init_state = rnn_init_state.at[i, 4].set(-1e30)

run = wandb.init(
    project="inside-transformer",
    config=experimental_params,
)
assert isinstance(run, wandb.sdk.wandb_run.Run)

# NOTE(review): keys_ is not indexed anywhere in the loop below; batches come from the loader iterator.
keys_ = next(key_gen).split(max_iters)
train_batch_iter = iter(batch_config_.dataloaders['train'])
for step in range(max_iters):
    batch_ = next(train_batch_iter)

    # On eval steps, report the loss on this batch before the update ...
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"\n===[ step {step} is an eval step ]==========")
        print(f"before step {step}, batch loss {loss}")

    train_state_ = train_config_.train1(train_state_, batch_)
    # ... and again after it, followed by a loss estimate over the samplers, logged to wandb.
    if step % eval_interval == 0:
        loss = train_config_.loss_fn(train_state_.weights, batch_)
        print(f"after step {step}, batch loss {loss}")
        results = train_config_.estimate_loss(eval_iters, key_gen, train_state_, batch_config_.samplers)

        wandb.log({"train_loss": results['train'],
                   "validation_loss": results['val'],
                   "batch_loss": loss,
                   "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size})
    # Checkpoint the weights; the file name carries the token count in short form.
    if step % save_interval == 0:  # and step > 0:
        n_tokens_trained = step * batch_config_.batch_size * batch_config_.block_size
        n_tokens_trained_str = num_short_form(n_tokens_trained)
        wandb.save(save_pytree_(train_state_.weights, run.dir, f"{model_name}_{n_tokens_trained_str}"), run.dir)

wandb.finish()
181 |
182 | # [Done] add trainable weights (via gradient mask)
183 | # [TODO] add trainable weights via weight name mask
184 | # [Done] add checkpointing
185 | # TODO: add weight decay
186 |
--------------------------------------------------------------------------------
/lra_arena/use_lra_demo.py:
--------------------------------------------------------------------------------
1 | # to download data: https://github.com/google-research/long-range-arena
2 | # https://storage.googleapis.com/long-range-arena/lra_release.gz unzip
3 | from pathlib import Path
4 |
5 | # load data in /Data/LRA/lra_release/lra_release/
6 | # clone https://github.com/lindermanlab/S5/tree/main repo and `pip install -e .` in the repo
7 | from s5.dataloading import Datasets
# List the dataset names known to s5, then build the ListOps loaders and print their basic facts.
print(Datasets.keys())
create_dataset_fn = Datasets['listops-classification']
batch_size = 8
seed = 0
data_path = Path("/Data")
# Pass the seed variable defined above rather than repeating the literal.
(trainloader, valloader, testloader, aux_dataloaders,
 n_classes, seq_len, in_dim, train_size) = create_dataset_fn(data_path, seed=seed, bsz=batch_size)
print(aux_dataloaders)
print(in_dim)
print(seq_len)
print(n_classes)
# Fetch one batch and report both shapes from it, rather than building a second
# loader iterator just for the second print.
first_batch = next(iter(trainloader))
print(first_batch[0].shape)
print(first_batch[1].shape)
20 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | plugins = pydantic.mypy
--------------------------------------------------------------------------------
/nlp_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable, Iterator, Protocol
4 |
5 | from jax import numpy as jnp, random, numpy as np
6 | from jax._src.lax.control_flow import scan
7 | from jax._src.nn.functions import softmax
8 | from tqdm import tqdm
9 |
10 | from picojax.jax_utils import Arr
11 | from picojax.random_utils import SafeKey
12 |
13 |
class Tokens(Protocol):
    """Structural type for an encoding result that exposes its token ids."""
    ids: list[int]
16 |
17 |
class Tokenizer(Protocol):
    """Structural type for the tokenizer interface used by the generation helpers."""

    def encode(self, text: str) -> Tokens:
        """Encode ``text``; the result exposes the token ids as ``.ids``."""
        ...

    def decode(self, ids: list[int]) -> str:
        """Turn a list of token ids back into text."""
        ...

    def get_vocab_size(self) -> int:
        """Return the vocabulary size."""
        ...
27 |
28 |
def generate(get_logits: Callable[[Arr], Arr], inputs: list[int], n_tokens_to_generate: int, max_len: int):
    """Greedy auto-regressive decoding over a sliding window of at most ``max_len`` tokens.

    Args:
        get_logits: maps an array of token ids to per-position logits; the last row is
            used to predict the next token.
        inputs: prompt token ids. The list is not modified.
        n_tokens_to_generate: number of tokens to append.
        max_len: maximum number of trailing tokens fed to ``get_logits`` per step.

    Returns:
        The list of newly generated token ids (prompt excluded).
    """
    # Work on a copy so the caller's prompt list is not extended as a side effect.
    tokens = list(inputs)
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        # Crop on every step, including the first, so a prompt longer than max_len
        # never reaches the model uncropped.
        input_window = tokens[-max_len:]
        logits = get_logits(jnp.array(input_window))
        next_id = jnp.argmax(logits[-1])  # greedy sampling
        tokens.append(int(next_id))  # append prediction to the running sequence

    return tokens[len(tokens) - n_tokens_to_generate:]  # only return generated ids
38 |
39 |
def generate_static(get_logits: Callable[[Arr], Arr], inputs: list[int], n_tokens_to_generate: int, max_len: int):
    """Greedy auto-regressive decoding that always feeds the model exactly ``max_len`` tokens.

    While the sequence is shorter than ``max_len`` it is right-padded with token id 0
    and the logits are read at the last real position; afterwards the trailing
    ``max_len`` tokens are used and the logits are read at the last window position.

    Args:
        get_logits: maps an array of ``max_len`` token ids to per-position logits.
        inputs: prompt token ids. The list is not modified.
            NOTE(review): an empty prompt reads index -1 of a fully padded window.
        n_tokens_to_generate: number of tokens to append.
        max_len: fixed input length of the model.

    Returns:
        The list of newly generated token ids (prompt excluded).
    """
    # Work on a copy so the caller's prompt list is not extended as a side effect.
    tokens = list(inputs)
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        if len(tokens) >= max_len:
            input_window = tokens[-max_len:]  # sliding window over the newest tokens
        else:
            input_window = tokens + [0] * (max_len - len(tokens))
        # Index of the last real token inside the window. The earlier len(inputs) - 1
        # pointed past the window once the sequence grew beyond max_len.
        output_index = min(len(tokens), max_len) - 1
        logits = get_logits(jnp.array(input_window))
        next_id = jnp.argmax(logits[output_index])  # greedy sampling
        tokens.append(int(next_id))  # append prediction to the running sequence

    return tokens[len(tokens) - n_tokens_to_generate:]  # only return generated ids
52 |
53 |
# adapted from https://github.com/cgarciae/nanoGPT-jax/blob/master/model.py
def generate_static_inplace(get_logits: Callable[[Arr], Arr],
                            key: SafeKey,
                            inputs: list[int],
                            n_tokens_to_generate: int,
                            max_len: int,
                            temperature=1.0,
                            top_k=None):
    """Sample ``n_tokens_to_generate`` tokens with a ``scan`` over a fixed-size token buffer.

    The buffer holds the prompt followed by zero padding; step ``i`` reads the logits
    at position ``i - 1`` and writes the sampled token at position ``i``.

    Args:
        get_logits: maps the full token buffer (prompt + padding) to per-position logits.
        key: random key wrapper; ``key.get()`` is folded with the step index.
        inputs: prompt token ids.
            NOTE(review): an empty prompt makes the first step read position -1.
        n_tokens_to_generate: number of tokens to sample.
        max_len: accepted for interface compatibility; not used by this implementation.
        temperature: divisor applied to the logits before sampling.
        top_k: if given, sample only among the ``top_k`` highest logits.

    Returns:
        The whole buffer (prompt followed by the generated tokens) as a Python list.
    """
    # Local import: the `top_k` argument shadows any function of that name, and the
    # original code ended up calling the integer. lax.top_k is the intended operation.
    from jax import lax

    input_len = len(inputs)
    # Explicit dtype keeps the buffer integer-typed throughout the scan.
    input_tokens = jnp.array(inputs, dtype=jnp.int32)
    padding = jnp.zeros(n_tokens_to_generate, dtype=jnp.int32)
    tokens = jnp.concatenate([input_tokens, padding], axis=-1)
    indexes = jnp.arange(input_len, input_len + n_tokens_to_generate)

    # tokens index -> tokens None
    def scan_f(tokens, i):
        # l:   x y
        # t: a b - -
        # i: 0 1 2 3
        step_key = random.fold_in(key.get(), i)
        # forward the model to get the logits for every position of the buffer
        logits = get_logits(tokens)
        # pluck the logits of the last filled position and scale by desired temperature
        logits = logits[i - 1] / temperature
        if top_k is not None:
            # optionally restrict sampling to the k largest logits (k capped at the vocab size)
            k = min(top_k, logits.shape[-1])
            top_logits, top_tokens = lax.top_k(logits, k)
            choice = random.categorical(step_key, top_logits, axis=-1)
            # logits is 1-D here, so the choice indexes top_tokens directly
            next_token = top_tokens[choice]
        else:
            next_token = random.categorical(step_key, logits, axis=-1)
        # write the sampled token into the buffer and continue
        tokens = tokens.at[i].set(next_token)

        return tokens, None

    tokens, _ = scan(scan_f, tokens, indexes)

    return tokens.tolist()
97 |
98 |
def rnn_generate(get_logits_rnn: Callable[[Arr, Arr], tuple[Arr, Arr]],
                 context: str,
                 init_state: Arr,
                 tokenizer: Tokenizer,
                 key_gen: Iterator[SafeKey],
                 argmax: bool = False,
                 length_per_trial: int = 100, n_trials: int = 1, temperature: float = 1.0, top_p: float = 0.85) -> str:
    """Generate text with a recurrent model, printing it as it is produced.

    The context is first fed token by token to build the starting state; every trial
    then continues from a copy of that state.

    Args:
        get_logits_rnn: ``(token, state) -> (logits, new_state)``.
        context: prompt text; must encode to at least one token.
        init_state: initial recurrent state (copied, not modified).
        tokenizer: provides ``encode(...).ids`` and ``decode``.
        key_gen: source of random keys, one per sampled token (unused when ``argmax``).
        argmax: pick the most likely token instead of sampling.
        length_per_trial: tokens generated per trial.
        n_trials: number of independent continuations.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Returns:
        Everything that was printed for all trials (headers, context, generated text).

    Raises:
        ValueError: if ``context`` encodes to no tokens, since there would be no
            logits to start generating from.
    """
    init_state = init_state.copy()
    context_ids = tokenizer.encode(context).ids
    if len(context_ids) == 0:
        # Previously this surfaced later as an unbound-variable error on init_out.
        raise ValueError("context must encode to at least one token")
    for token in context_ids:
        init_out, init_state = get_logits_rnn(token, init_state)
    # print(init_state[1, :, 1])

    def sample_logits(logits, key, temperature=1.0, top_p=0.8):
        # Nucleus sampling: zero every probability below the cutoff at which the
        # descending cumulative sum first exceeds top_p, then apply the temperature
        # and renormalise.
        probs = softmax(logits, axis=-1)
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs = np.where(probs < cutoff, 0, probs)
        if temperature != 1.0:
            probs = np.power(probs, 1.0 / temperature)
        probs = probs / np.sum(probs)

        out = random.choice(key.get(), a=len(probs), p=probs)
        return out

    out_str = ""
    for t in range(n_trials):
        to_print = f'--[ Trial {t} ]-----------------\n{context}'
        print(to_print, end="")
        out_str += to_print
        all_tokens = []
        out_last = 0  # index of the first token that has not been printed yet
        out, state_ = init_out, init_state.copy()
        for i in range(length_per_trial):
            if argmax:
                token = np.argmax(out)
            else:
                token = sample_logits(out, next(key_gen), temperature, top_p)
            all_tokens.append(token.item())
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_str += tmp
                out_last = i + 1
            out, state_ = get_logits_rnn(token, state_)
        print(flush=True)
    return out_str
146 |
--------------------------------------------------------------------------------
/pico_lru/pico_lru.py:
--------------------------------------------------------------------------------
1 | from jax import numpy as np
2 | from jax import vmap
3 | from jax import lax
4 |
5 |
def binary_operator_diag(element_i, element_j):
    """Associative operator for the diagonal linear recurrence ``x_k = a_k * x_{k-1} + b_k``.

    Each element is a pair ``(a, b)``; composing element ``i`` followed by ``j``
    gives ``(a_j * a_i, a_j * b_i + b_j)``.
    """
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j


def forward(input_sequence, nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log):
    """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H).

    The state recurrence ``x_k = Lambda * x_{k-1} + B_norm @ u_k`` (zero initial
    state) is evaluated for all steps at once with an associative scan, and the
    output is ``y_k = Re(C @ x_k) + D * u_k``.
    """

    # Materializing the diagonal of Lambda and projections
    Lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    B_norm = (B_re + 1j * B_im) * np.expand_dims(np.exp(gamma_log), axis=-1)
    C = C_re + 1j * C_im

    # Running the LRU + output projection
    # For details on parallel scan, check discussion in Smith et al (2022).
    Lambda_elements = np.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
    Bu_elements = vmap(lambda u: B_norm @ u)(input_sequence)
    elements = (Lambda_elements, Bu_elements)
    # binary_operator_diag was referenced here without being defined or imported in this module.
    _, inner_states = lax.associative_scan(binary_operator_diag, elements)  # all x_k
    y = vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)

    return y
23 |
--------------------------------------------------------------------------------
/pico_lru/pico_lru_parallel.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/pico_lru/pico_lru_parallel.py
--------------------------------------------------------------------------------
/pico_lru/pico_rnn.py:
--------------------------------------------------------------------------------
1 | from jax import numpy as np
2 |
3 |
def rnn_step(x, u, a, b, c, d):
    """Advance a linear recurrence by one step.

    Computes the next state ``a * x + b * u`` and the output ``c * next + d * u``,
    all with ``*`` (element-wise) products.

    NOTE(review): init_params_ in this module builds 2-D parameter arrays — confirm
    whether element-wise multiplication, not a matrix product, is intended here.

    Returns:
        ``(next_state, output)``.
    """
    next_state = a * x + b * u
    output = c * next_state + d * u
    return next_state, output
8 |
9 |
def init_params_(state_size: int, input_size: int, output_size: int):
    """Return zero-filled parameter arrays ``a``, ``b``, ``c``, ``d`` for the recurrence.

    Shapes: ``a`` (state, state), ``b`` (state, input), ``c`` (output, state),
    ``d`` (output, input).
    """
    shapes = {
        'a': (state_size, state_size),
        'b': (state_size, input_size),
        'c': (output_size, state_size),
        'd': (output_size, input_size),
    }
    return {name: np.zeros(shape) for name, shape in shapes.items()}
17 |
--------------------------------------------------------------------------------
/pico_rwkv/associative_scan_toolbox.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import reduce
3 | from typing import NamedTuple, Callable, Sequence, Protocol
4 |
5 | import sympy
6 | from sympy import symbols, init_printing, Eq, sin, cos, Idx, IndexedBase, pprint, Function, Piecewise, Tuple, Mul, Add, \
7 | lambdify, simplify, expand, Sum, Indexed, Expr, exp
8 |
# Plain scalar symbols; `l` is displayed as λ and is used below as a constant recurrence weight.
a, b, c, d, l = symbols('a, b, c, d, λ')

# Integer symbol that bounds the indices and gives the shape of the indexed bases below.
t = symbols('t', integer=True)
i = Idx('i', t)
j = Idx('j', t)
k = Idx('k', t)

# Indexed families z, y and λ, each declared with shape (t,).
Z = IndexedBase('z', shape=(t,))
Y = IndexedBase('y', shape=(t,))
L = IndexedBase('λ', shape=(t,))

init_printing(use_unicode=True)

# The recurrence studied in this file: y[t + 1] = y[t] * λ + z[t].
gen_next = Eq(Y[t + 1], Y[t] * l + Z[t])
pprint(gen_next, use_unicode=True)
24 |
25 |
def is_associative(op, a, b, c) -> bool:
    """Check that ``op`` gives the same result for both groupings of ``a``, ``b``, ``c``.

    The two results are compared with the ``.equals`` method of the left-grouped result.
    """
    left_grouped = op(op(a, b), c)
    right_grouped = op(a, op(b, c))
    return left_grouped.equals(right_grouped)
28 |
29 |
def is_distributive(mul_like, add_like, a, b, c) -> bool:
    """Check that ``mul_like`` distributes over ``add_like`` from the left.

    Compares ``mul(a, add(b, c))`` with ``add(mul(a, b), mul(a, c))`` using ``.equals``.
    """
    factored = mul_like(a, add_like(b, c))
    distributed = add_like(mul_like(a, b), mul_like(a, c))
    return factored.equals(distributed)
32 |
class IndexedBaseG(Protocol):
    """Structural type for anything that yields an expression when indexed."""

    def __getitem__(self, item) -> Expr:
        ...
36 |
37 |
class RecPair(NamedTuple):
    """A (weight, bias) pair: one element of the linear recurrence / scan."""
    weight: Expr
    bias: Expr

    def equals(self, other: 'RecPair') -> bool:
        """Compare both components with ``.equals``; the bias is only checked if the weights match."""
        weights_match = self.weight.equals(other.weight)
        return weights_match and self.bias.equals(other.bias)
44 |
45 |
class IndexedPair(NamedTuple):
    """Two indexable families (weight, bias) that are indexed together."""
    weight: IndexedBaseG
    bias: IndexedBaseG

    def __getitem__(self, item) -> RecPair:
        """Index both families with ``item`` and return the results as a RecPair."""
        weight_at = self.weight[item]
        bias_at = self.bias[item]
        return RecPair(weight_at, bias_at)
52 |
53 |
class RecFormula(NamedTuple):
    """A linear recurrence ``y <- op_add(op_mul(y, weight_i), bias_i)`` and its scan form."""
    x: IndexedBase
    # Per-step (weight, bias) pairs, looked up by index.
    params: IndexedPair
    # Multiplication-like and addition-like operators of the recurrence.
    op_mul: Function = Mul
    op_add: Function = Add

    def test(self):
        """Assert, on the module-level symbols a, b, c, that both operators are
        associative and that op_mul distributes over op_add."""
        mul_like, add_like = self.op_mul, self.op_add
        assert is_associative(mul_like, a, b, c)
        assert is_associative(add_like, a, b, c)
        assert is_distributive(mul_like, add_like, a, b, c)

    @property
    def rec_base(self) -> Indexed:
        """Starting value of the recurrence: the bias at index 0."""
        return self.params[0].bias

    def rec_step(self, n: int):
        """Unroll the recurrence sequentially over indices 1 .. n-1, starting from rec_base."""
        acc = self.rec_base
        for step in range(1, n):
            pair = self.params[step]
            acc = self.op_add(self.op_mul(acc, pair.weight), pair.bias)
        return acc

    def pscan_i(self, i: int) -> RecPair:
        """Scan element at index ``i``."""
        return self.params[i]

    @property
    def pscan_base(self) -> RecPair:
        """Scan element at index 0."""
        return self.pscan_i(0)

    def pscan_step(self, left: RecPair, right: RecPair) -> RecPair:
        """Combine two scan elements (left applied first) into one simplified element."""
        # The operator properties are re-checked on every call.
        self.test()
        merged_weight = simplify(self.op_mul(left.weight, right.weight))
        carried_bias = self.op_add(self.op_mul(left.bias, right.weight), right.bias)
        merged_bias = simplify(expand(carried_bias))
        return RecPair(merged_weight, merged_bias)
88 |
89 |
@dataclass
class IndexedExpr:
    """An indexable expression built from several indexable variables."""
    # Variables that are all indexed with the same item.
    base_vars: list[IndexedBaseG]
    # Builds the resulting expression from the list of indexed variables.
    expr: Callable[[Sequence[Expr]], Expr]

    def __getitem__(self, item) -> Expr:
        """Index every base variable with ``item`` and feed the results to ``expr``."""
        indexed_vars = [base[item] for base in self.base_vars]
        return self.expr(indexed_vars)
97 |
98 |
@dataclass
class FakeIndex:
    """Wraps a constant so it can be indexed like an indexed base."""
    # The value returned for every index.
    const: Expr

    def __getitem__(self, item) -> Indexed:
        """Return the wrapped constant; ``item`` is ignored."""
        return self.const
105 |
# print(build_rec_formula(Y, t, Z, L))
# Recurrence with the constant weight λ (wrapped so it can be indexed) and bias z[i].
f = RecFormula(Y, IndexedPair(FakeIndex(l), Z))
# pprint(f.rec_equation, use_unicode=True)
pprint(f.pscan_i(i), use_unicode=True)
pprint(f.pscan_base, use_unicode=True)
# Combine elements 0 and 1, then fold in element 2, printing each result.
ne = f.pscan_step(f.pscan_base, f.pscan_i(1))
pprint(ne, use_unicode=True)
pprint(f.pscan_step(ne, f.pscan_i(2)), use_unicode=True)
assert is_associative(Mul, a, b, c)
# The scan combiner itself must be associative on the symbolic elements i, j, k.
assert is_associative(f.pscan_step, f.pscan_i(i), f.pscan_i(j), f.pscan_i(k))

# Left-fold the combiner over elements 0..9 and compare its bias with the unrolled recurrence.
base_ = f.pscan_base
for i in range(1, 10):  # note: rebinds the module-level index `i` to an int
    base_ = f.pscan_step(base_, f.pscan_i(i))
pprint(base_, use_unicode=True)
rec10 = f.rec_step(10)
pprint(rec10, use_unicode=True)
assert base_[1].equals(f.rec_step(10))
print(simplify(expand(rec10)))
print(simplify(expand(base_[1])))
126 |
127 |
128 | #%%
K = IndexedBase('k', shape=(t,))
V = IndexedBase('v', shape=(t,))
w = symbols('w')

# Constant weight exp(w) and per-index bias exp(k[i]) * v[i].
lc = exp(w)
lci = FakeIndex(lc)
expKV = IndexedExpr([K, V], lambda kv: exp(kv[0]) * kv[1])
f = RecFormula(Y, IndexedPair(lci, expKV), op_mul=Mul, op_add=Add)
rec10 = f.rec_step(3)  # despite the name, this is the 3-step unrolling
print(rec10)
# Fold the combiner over elements 0..2 and compare its bias with the unrolled recurrence.
base_ = f.pscan_base
for i in range(1, 3):
    base_ = f.pscan_step(base_, f.pscan_i(i))
print(base_[1])
print(base_[0])
assert base_[1].equals(f.rec_step(3))
print(simplify(expand(f.rec_step(5))))

# Restore `i` as a symbolic index (the loops above rebound it to an int)
# and combine two neighbouring symbolic elements.
i = Idx('i', t)
dummy_l = RecPair(lc, expKV[i])
dummy_r = RecPair(lc, expKV[i + 1])
ne = f.pscan_step(dummy_l, dummy_r)
print(ne[0])
print(ne[1])
153 |
154 | # This can be alternative associative scan function of rwkv
155 | # version 2
def pscan_step_alt(left: RecPair, right: RecPair) -> RecPair:
    """Combine two scan elements whose weights are kept as exponents.

    The weights are added, and the left bias is scaled by ``exp(right.weight)``
    before the right bias is added.
    """
    summed_weight = simplify(left.weight + right.weight)
    decayed_bias = left.bias * exp(right.weight) + right.bias
    return RecPair(summed_weight, simplify(expand(decayed_bias)))
159 |
#i = Idx('i', t)
# Fold the alternative combiner left-to-right over elements 1..4 ...
res = reduce(pscan_step_alt, [RecPair(w, expKV[j]) for j in range(1, 5)])
print(res)
# ... and over the reversed list with swapped arguments.
res_rev = reduce(lambda a, b: pscan_step_alt(b, a), [RecPair(w, expKV[j]) for j in reversed(range(1, 5))])
print(res_rev)
# Printed for comparison: the sequentially unrolled recurrence.
print(simplify(expand(f.rec_step(5))))
166 |
167 |
--------------------------------------------------------------------------------
/pico_rwkv/close_to_original/pico_rwkv.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax.numpy as np
4 | from jax import jit
5 | from jax.lax import rsqrt
6 | from jax.nn import sigmoid, relu
7 |
8 |
def layer_norm(x, weight, bias, eps: float = 1e-5):
    """Normalise ``x`` over its last axis, then scale by ``weight`` and shift by ``bias``.

    ``eps`` is added to the variance before the reciprocal square root.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    inv_std = rsqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return weight * centered * inv_std + bias
13 |
14 |
def time_mix(x, old_state, mix):
    """Blend the current input with the previous one: ``mix`` weights ``x``, ``1 - mix`` weights ``old_state``."""
    keep_old = 1 - mix
    return x * mix + old_state * keep_old
17 |
18 |
def exp_mix(max_coef, k, v_s, v, base):
    """
    Add a new term with exponent ``k`` to a running numerator ``v_s`` and
    denominator ``base`` that are stored relative to the exponent ``max_coef``.

    Both are rescaled to the new reference exponent ``max(max_coef, k)``:

    1. If `max_coef >= k`:
        mix_v = v_s + e^(k - max_coef) * v
        base_new = base + e^(k - max_coef)

    2. If `max_coef < k`:
        mix_v = e^(max_coef - k) * v_s + v
        base_new = e^(max_coef - k) * base + 1

    :return: (mix_v, base_new, new reference exponent)
    """
    reference = np.maximum(max_coef, k)
    old_scale = np.exp(max_coef - reference)
    new_scale = np.exp(k - reference)
    numerator = old_scale * v_s + new_scale * v
    denominator = old_scale * base + new_scale
    return numerator, denominator, reference
35 |
36 |
@partial(jit, static_argnames='i')
def rwkv(r, ow, k, v, state, i: int, time_first, time_decay, debug=False):
    r"""
    Compute the WKV output for layer ``i`` and advance that layer's WKV state.

    The docstring is a raw string so the LaTeX backslashes below are kept
    literally (in a normal string ``\f`` would turn into a form feed).

    the original form of the equation is:
    $\omega = \frac{a_1 + e^{(t_1 + k) - p} \cdot v}{b_1 + e^{(t_1 + k) - p}}$
    for numerical stability, we rewrite it as:
    $\omega = \frac{e^{p - \max(p, t_1 + k)} \cdot a_1 + e^{(t_1 + k) - \max(p, t_1 + k)} \cdot v}{e^{p - \max(p, t_1 + k)} \cdot b_1 + e^{(t_1 + k) - \max(p, t_1 + k)}}$
    where
        $a_1$ is wkv_top (state[i, 2])
        $b_1$ is wkv_bot (state[i, 3])
        $p$ is max_coef (state[i, 4])
        $\omega$ is wkv
        $t_1$ is time_first
        $t_2$ is time_decay

    :param r: [1024, ]
    :param ow: output projection, applied as ``ow @ (r * wkv)``
    :param k: [1024, ]
    :param v: [1024, ]
    :param state: indexed as ``state[i, slot]``; slots 2, 3 and 4 of layer ``i`` are read and rewritten
    :param i: int, layer index (static under jit)
    :param time_first: [1024, ]
    :param time_decay: added to max_coef before the state update
    :param debug: not used by this function
    :return: (output vector, updated state)
    """
    v_state, base_state, max_coef = state[i, 2], state[i, 3], state[i, 4]

    # Output for the current token: the new (k, v) enters with the extra time_first bonus.
    wkv_top, wkv_bot, _ = exp_mix(max_coef, k + time_first, v_state, v, base_state)
    wkv = wkv_top / wkv_bot
    # State carried to the next token: shift the old reference by time_decay, then add (k, v).
    wkv_top_new, wkv_bot_new, max_coef = exp_mix(max_coef + time_decay, k, v_state, v, base_state)

    state = state.at[i, 2].set(wkv_top_new)
    state = state.at[i, 3].set(wkv_bot_new)
    state = state.at[i, 4].set(max_coef)
    return ow @ (r * wkv), state
73 |
74 |
@partial(jit, static_argnames='i')
def token_mixing(x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow, debug=False):
    """Time-mixing step for layer ``i``.

    Blends ``x`` with the previous input held in ``state[i, 1]`` (separately for
    key, value and receptance), stores ``x`` in that slot, projects, and passes
    the result to ``rwkv``.

    :return: whatever ``rwkv`` returns, i.e. (output, updated state)
    """
    previous_x = state[i, 1]
    xk, xv, xr = (time_mix(x, previous_x, coef) for coef in (time_mix_k, time_mix_v, time_mix_r))

    state = state.at[i, 1].set(x)

    receptance = sigmoid(rw @ xr)
    key = kw @ xk
    value = vw @ xv
    return rwkv(receptance, ow, key, value, state, i, time_first, time_decay, debug=debug)
87 |
88 |
@partial(jit, static_argnames='i')
def channel_mixing(x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
    """Channel-mixing step for layer ``i``.

    Blends ``x`` with the previous input held in ``state[i, 0]``, stores ``x`` in
    that slot, and returns the gated squared-ReLU projection.

    :return: (output, updated state)
    """
    previous_x = state[i, 0]
    xk = time_mix(x, previous_x, time_mix_k)
    xr = time_mix(x, previous_x, time_mix_r)
    state = state.at[i, 0].set(x)
    gate = sigmoid(rw @ xr)
    hidden = np.square(relu(kw @ xk))  # square relu, primer paper
    return gate * (vw @ hidden), state
97 |
98 |
@partial(jit, static_argnames='i')
def block(x, state, i: int, att, ffn, ln1, ln2, **kwargs):
    """One residual block for layer ``i``: layer norm + token mixing, then layer norm + channel mixing.

    :param att, ffn: weight dicts for the two mixers
    :param ln1, ln2: keyword arguments for ``layer_norm`` before each mixer
    :param kwargs: must contain ``'debug'``; any other entries are ignored
    :return: (updated x, updated state)
    """
    att_out, state = token_mixing(
        layer_norm(x, **ln1), state, i,
        att['time_mix_k'],
        att['time_mix_v'],
        att['time_mix_r'],
        att['time_first'],
        att['time_decay'],
        att['key']['weight'],
        att['value']['weight'],
        att['receptance']['weight'],
        att['output']['weight'],
        debug=kwargs['debug'],
    )
    x = x + att_out

    ffn_out, state = channel_mixing(
        layer_norm(x, **ln2), state, i,
        ffn['time_mix_k'], ffn['time_mix_r'],
        ffn['key']['weight'], ffn['value']['weight'],
        ffn['receptance']['weight'],
    )
    x = x + ffn_out
    return x, state
120 |
121 |
def rwkv_net(token, state, ln_out, blocks, head, emb):
    """Run the network on one token.

    Embeds the token, applies the ``ln0`` norm stored with block 0, runs every
    block in index order, then the output norm and the head projection.

    :return: (head output, updated state)
    """
    hidden = layer_norm(emb['weight'][token], **blocks[0]['ln0'])
    debug_flag = token == 49
    # `blocks` is looked up by integer position, so it may be a list or an int-keyed mapping.
    for layer in range(len(blocks)):
        hidden, state = block(hidden, state, layer, debug=debug_flag, **blocks[layer])
    logits = head['weight'] @ layer_norm(hidden, **ln_out)
    return logits, state
133 |
134 |
@jit
def rwkv_net_w(token, state, w):
    """Jitted wrapper around ``rwkv_net`` that takes the whole weight tree ``w``."""
    return rwkv_net(
        token,
        state,
        ln_out=w['ln_out'],
        blocks=w['blocks'],
        head=w['head'],
        emb=w['emb'],
    )
138 |
139 |
--------------------------------------------------------------------------------
/pico_rwkv/close_to_original/run_rwkv.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import jax.numpy as np
4 | from safetensors import safe_open
5 | from tokenizers import Tokenizer
6 |
7 | from nlp_utils import rnn_generate
8 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight
9 | from picojax.random_utils import infinite_safe_keys
10 | from pico_rwkv.close_to_original.pico_rwkv import rwkv_net_w
11 |
path = Path("/Data/lm_models/rwkv")
# model_name = 'RWKV-4-Pile-430M-20220808-8066'
model_name = 'RWKV-4-Pile-169M-20220807-8023'
# jax.config.update('jax_platform_name', 'cpu')


# Read every tensor from the checkpoint and arrange them into the nested weight tree.
with safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as f:
    w = parse_rwkv_weight(f.keys(), f.get_tensor)

tokenizer = Tokenizer.from_file(str(path / "20B_tokenizer.json"))


n_layers = len(w['blocks'])
n_channels = w['emb']['weight'].shape[1]

context = ("\nPumas are large, cat-like animals found in America. When reports came into London Zoo that "
           "a wild puma had been spotted forty-five miles south of London, they were not taken seriously."
           " However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate,"
           " for the descriptions given by people who claimed to have seen the puma were extraordinarily similar."
           "\nThe hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat'"
           " only five yards away from her. It")

# context = "\nPumas are large, "


def sample_rwkv_rnn(token_arr, state):
    """Single-token step bound to the loaded weights."""
    return rwkv_net_w(token_arr, state, w)


# Five state slots per layer; slot 4 of every layer starts at -1e30, the rest at zero.
state = np.zeros((n_layers, 5, n_channels)).at[:, 4].set(-1e30)


keygen = infinite_safe_keys(0)
rnn_generate(
    sample_rwkv_rnn, context, state, tokenizer, keygen,
    n_trials=1,
    argmax=False,
    temperature=1.0,
    top_p=0.85,
    length_per_trial=100,
)
54 |
55 |
--------------------------------------------------------------------------------
/pico_rwkv/jax_load.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from safetensors import safe_open
4 |
path = Path("/Data/lm_models/rwkv")
model_name = 'RWKV-4-Pile-430M-20220808-8066'
# Print the name and shape of every tensor stored in the checkpoint.
with safe_open(path / f"{model_name}.safetensors", framework="torch", device="cpu") as f:
    for key in f.keys():
        t = f.get_tensor(key)
        print(key, t.shape)
11 |
12 |
--------------------------------------------------------------------------------
/pico_rwkv/pico_rwkv_parallel.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax.numpy as np
4 | from jax import jit, vmap, lax
5 |
6 | from pico_rwkv.pico_rwkv_rnn import channel_mixing, token_mixing, layer_norm
7 |
8 |
def exp_mix(v1, v2, p1, p2):
    """
    Add two numbers that are each stored as (value, exponent), i.e.
        x1 = v1 * exp(p1),  x2 = v2 * exp(p2),
    and return the sum in the same form, normalised so that the larger
    exponent becomes the reference:
        x1 + x2 --> (v1 * exp(p1 - p) + v2 * exp(p2 - p), p),  p = max(p1, p2)
    """
    reference = np.maximum(p1, p2)
    rescaled_first = v1 * np.exp(p1 - reference)
    rescaled_second = v2 * np.exp(p2 - reference)
    return (rescaled_first + rescaled_second), reference
21 |
22 |
def token_mixing_parallel(x, x_prev1, time_mix_k, time_mix_v, time_mix_r, kw, vw, rw, **kwargs):
    """Apply ``token_mixing`` to every row of ``x`` at once.

    Row ``t`` is paired with row ``t - 1`` as its previous input; row 0 is paired
    with ``x_prev1``. The mixing coefficients and weights are shared by all rows.
    Extra keyword arguments are ignored.
    """
    shifted = np.vstack((x_prev1, x[:-1, :]))
    shared_args = (None,) * 6
    batched = vmap(token_mixing, in_axes=(0, 0) + shared_args, out_axes=(0, 0, 0))
    return batched(x, shifted, time_mix_k, time_mix_v, time_mix_r, kw, vw, rw)
27 |
28 |
def channel_mixing_parallel(x, x_prev1, time_mix_k, time_mix_r, kw, vw, rw, **kwargs):
    """Apply ``channel_mixing`` to every row of ``x`` at once.

    Row ``t`` is paired with row ``t - 1`` as its previous input; row 0 is paired
    with ``x_prev1``. Extra keyword arguments are ignored.
    """
    shifted = np.vstack((x_prev1, x[:-1, :]))
    shared_args = (None,) * 5
    batched = vmap(channel_mixing, in_axes=(0, 0) + shared_args, out_axes=0)
    return batched(x, shifted, time_mix_k, time_mix_r, kw, vw, rw)
33 |
34 |
def lru_parallel_scannable_normalized(left, right):
    """Combine two scan elements ``(value, weight, exponent)``, left applied first.

    The left exponent is shifted by the right weight, the larger of the two
    exponents becomes the new reference, both values are rescaled to it and
    summed, and the weights are added.

    :return: (combined value, summed weight, new reference exponent)
    """
    l_value, l_weight, l_exponent = left
    r_value, r_weight, r_exponent = right
    shifted_left = l_exponent + r_weight
    reference = np.maximum(shifted_left, r_exponent)
    combined = l_value * np.exp(shifted_left - reference) + r_value * np.exp(r_exponent - reference)
    return combined, l_weight + r_weight, reference
39 |
40 |
def rwkv_parallel_scan_stable(r, k, v, ow, time_first, time_decay):
    """WKV for a whole sequence via two normalised associative scans.

    One scan accumulates the values ``v`` (numerator), the other accumulates ones
    (denominator); both use ``time_decay`` as the per-step weight and ``k`` as the
    exponent. The current token is then mixed in with exponent
    ``time_first + time_decay + k`` before dividing.

    :return: ``(r * wkv) @ ow.T``
    """
    decay_per_step = np.repeat(time_decay[np.newaxis, :], v.shape[0], axis=0)
    unit_values = np.ones_like(k)

    numerator, _, reference = lax.associative_scan(
        lru_parallel_scannable_normalized, (v, decay_per_step, k))
    denominator, _, _ = lax.associative_scan(
        lru_parallel_scannable_normalized, (unit_values, decay_per_step, k))

    current_exponent = time_first + time_decay + k
    top, _ = exp_mix(numerator, v, reference, current_exponent)
    bottom, _ = exp_mix(denominator, unit_values, reference, current_exponent)

    wkv = top / bottom
    return (r * wkv) @ ow.T
54 |
55 |
@partial(jit, static_argnums=(0,))
def rwkv_net_parallel(seq_len: int, tokens, ln_out, blocks, head, emb):
    """
    Run the network over a whole token sequence at once, starting from zero
    "previous input" vectors in every layer.

    :param seq_len: int, static under jit; must be at least 2
    :param tokens: int32[seq_len]
    :param ln_out: keyword arguments for the final layer norm
    :param blocks: per-layer weights, looked up by integer index
    :param head: {'weight': ...} output projection
    :param emb: {'weight': float32[n_vocab, n_channels]}
    :return: logits, one row per input position
    """
    assert seq_len >= 2
    no_previous = np.zeros_like(emb['weight'][0, :])

    x = layer_norm(emb['weight'][tokens, :], **blocks[0]['ln0'])
    for layer in range(len(blocks)):
        weights = blocks[layer]
        att = weights['att']

        normed_for_att = layer_norm(x, **weights['ln1'])
        r, k, v = token_mixing_parallel(normed_for_att, no_previous, **att)
        x = x + rwkv_parallel_scan_stable(r, k, v, att['output']['weight'],
                                          att['time_first'], att['time_decay'])

        normed_for_ffn = layer_norm(x, **weights['ln2'])
        x = x + channel_mixing_parallel(normed_for_ffn, no_previous, **weights['ffn'])

    normed_out = layer_norm(x, **ln_out)
    # parallel version of `logits = head['weight'] @ xn`, t is time, c is channel, v is vocab
    return np.einsum('tc,vc->tv', normed_out, head['weight'])
90 |
91 | # [Done] load params
92 | # [Done] jit
93 | # [TODO] associative scan
94 | # - trying to formulate according to jianlin's blog (is cumsum in rwkv viable)?
95 | # - not trying to use states, only use it for training and fixed-context inference
96 | # - [Done] make it work
97 | # - [Done] check numerical issues with help of yilun
98 | # [Done] make a scan version and compare with
99 | # - [Done] scan batch of tokens
100 | # - [TODO] maintain state using normalized version of scan
101 | # - [TODO] merge upper and lower scan
102 | # [Abandoned] pico-plus
103 | # [Done] training
104 |
--------------------------------------------------------------------------------
/pico_rwkv/pico_rwkv_parallel_alternatives.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import lax
3 |
4 | from pico_rwkv.pico_rwkv_parallel import token_mixing_parallel, channel_mixing_parallel
5 | from pico_rwkv.pico_rwkv_rnn import layer_norm, rwkv, rwkv_state_flow
6 |
7 |
def rwkv_scan(state, rs, ks, vs, ow, time_first, time_decay):
    """Sequential WKV over a sequence with ``lax.scan``, carrying the WKV state.

    :param state: initial WKV state, unpacked by ``rwkv`` / ``rwkv_state_flow``
    :param rs, ks, vs: per-step receptance, key and value, scanned along axis 0
    :return: (final state, stacked per-step outputs)
    """
    def step(carry, rkv):
        r, k, v = rkv
        output = rwkv(carry, r, k, v, ow, time_first)
        next_carry = rwkv_state_flow(time_decay, carry, k, v)
        return np.stack(next_carry), output

    return lax.scan(step, state, (rs, ks, vs))
18 |
19 |
def lru_parallel_scannable(left, right):
    """Combine two scan elements ``(value, weight)``, left applied first.

    The left value is scaled by ``exp(right weight)`` and added to the right
    value; the weights are added.
    """
    l_value, l_weight = left
    r_value, r_weight = right
    combined_value = l_value * np.exp(r_weight) + r_value
    return combined_value, l_weight + r_weight
23 |
24 |
def rwkv_parallel_scan(seq_len: int, r, k, v, ow, time_first, time_decay):
    """WKV for a whole sequence via un-normalised associative scans.

    The scans accumulate ``exp(k) * v`` and ``exp(k)`` with ``time_decay`` as the
    per-step weight. Each position then uses the accumulated state of the
    previous position (zero for position 0) plus its own term scaled by
    ``exp(time_first)``.

    :return: ``(r * wkv) @ ow.T``
    """
    decay_per_step = np.repeat(time_decay[np.newaxis, :], seq_len, axis=0)

    exp_k = np.exp(k)
    numerator, _ = lax.associative_scan(lru_parallel_scannable, (exp_k * v, decay_per_step))
    denominator, _ = lax.associative_scan(lru_parallel_scannable, (exp_k, decay_per_step))
    current_weight = np.exp(time_first) * exp_k

    def previous_rows(arr):
        # Prepend a zero row and drop the last row, so row t holds the old row t - 1.
        padded = np.pad(arr, ((1, 0), (0, 0)), mode='constant', constant_values=0)
        return padded[:-1, :]

    numerator = previous_rows(numerator) + current_weight * v
    denominator = previous_rows(denominator) + current_weight

    wkv = numerator / denominator
    return (r * wkv) @ ow.T
42 |
43 |
def rwkv_parallel_scan_alt(seq_len: int, r, k, v, ow, time_first, time_decay):
    """Variant of the un-normalised parallel WKV that does not shift the scan output.

    The scan result already contains each position's own ``exp(k)`` term, so that
    term is corrected in place by adding ``exp(k) * (exp(time_first + time_decay) - 1)``.
    A constant ``10e-6`` is added to the denominator before dividing.

    :return: ``(r * wkv) @ ow.T``
    """
    decay_per_step = np.repeat(time_decay[np.newaxis, :], seq_len, axis=0)

    exp_k = np.exp(k)
    numerator, _ = lax.associative_scan(lru_parallel_scannable, (exp_k * v, decay_per_step))
    denominator, _ = lax.associative_scan(lru_parallel_scannable, (exp_k, decay_per_step))
    correction = exp_k * (np.exp(time_first + time_decay) - 1)

    numerator = numerator + correction * v
    denominator = denominator + correction
    denominator = denominator + 10e-6

    wkv = numerator / denominator
    return (r * wkv) @ ow.T
59 |
60 |
def rwkv_net_scan(seq_len: int, tokens, states, ln_out, blocks, head, emb):
    """Run the network over a token sequence while reading and producing recurrent state.

    Per layer, ``states`` holds the previous channel-mixing input in slot 0, the
    previous token-mixing input in slot 1 and the WKV state in slots 2 onward.

    :param seq_len: must be at least 2
    :return: (logits with one row per position, new states)
    """
    assert seq_len >= 2
    prev_channel_inputs = states[:, 0, :]
    prev_token_inputs = states[:, 1, :]
    prev_wkv = states[:, 2:, :]

    x = layer_norm(emb['weight'][tokens, :], **blocks[0]['ln0'])
    new_states = np.empty_like(states)
    # `blocks` is looked up by integer position, so it may be a list or an int-keyed mapping.
    for layer in range(len(blocks)):
        weights = blocks[layer]
        att = weights['att']

        xn_token = layer_norm(x, **weights['ln1'])
        r, k, v = token_mixing_parallel(xn_token, prev_token_inputs[layer], **att)

        wkv_state, att_out = rwkv_scan(prev_wkv[layer], r, k, v, att['output']['weight'],
                                       att['time_first'], att['time_decay'])
        x = x + att_out

        xn_channel = layer_norm(x, **weights['ln2'])
        x = x + channel_mixing_parallel(xn_channel, prev_channel_inputs[layer], **weights['ffn'])

        # The last normalised inputs become the "previous input" of the next call.
        layer_state = np.stack([xn_channel[-1], xn_token[-1], *wkv_state])
        new_states = new_states.at[layer, :].set(layer_state)

    normed_out = layer_norm(x, **ln_out)
    # parallel version of `logits = head['weight'] @ xn`, t is time, c is channel, v is vocab
    logits = np.einsum('tc,vc->tv', normed_out, head['weight'])
    return logits, new_states
89 |
--------------------------------------------------------------------------------
/pico_rwkv/pico_rwkv_rnn.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import lax, jit
3 | from jax.lax import rsqrt
4 | from jax.nn import sigmoid, relu
5 |
6 |
def layer_norm(x, weight, bias, eps: float = 1e-5):
    """Normalise ``x`` over its last axis, then scale by ``weight`` and shift by ``bias``.

    ``eps`` is added to the variance before the reciprocal square root.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    inv_std = rsqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return weight * centered * inv_std + bias
11 |
12 |
def time_mix(x, x_prev, mix):
    """Blend the current input with the previous one: ``mix`` weights ``x``, ``1 - mix`` weights ``x_prev``."""
    keep_prev = 1 - mix
    return x * mix + x_prev * keep_prev
15 |
16 |
def exp_mix_both(max_coef, k, v_s, v, base):
    """
    Add a new term with exponent ``k`` to a running numerator ``v_s`` and
    denominator ``base`` that are stored relative to the exponent ``max_coef``.

    Both are rescaled to the new reference exponent ``max(max_coef, k)``:

    1. If `max_coef >= k`:
        mix_v = v_s + e^(k - max_coef) * v
        base_new = base + e^(k - max_coef)

    2. If `max_coef < k`:
        mix_v = e^(max_coef - k) * v_s + v
        base_new = e^(max_coef - k) * base + 1

    :return: (mix_v, base_new, new reference exponent)
    """
    reference = np.maximum(max_coef, k)
    old_scale = np.exp(max_coef - reference)
    new_scale = np.exp(k - reference)
    numerator = old_scale * v_s + new_scale * v
    denominator = old_scale * base + new_scale
    return numerator, denominator, reference
33 |
34 |
def rwkv(state_wkv, r, k, v, ow, time_first, debug=False):
    """
    WKV output for the current token; the state itself is not modified here.

    state_wkv: (v_state, wkv_bot, max_coef) / state[i, 2:].T
    The current (k, v) is mixed in with exponent ``k + time_first``.
    ``debug`` is not used.
    """
    v_state, base_state, max_coef = state_wkv

    numerator, denominator, _ = exp_mix_both(max_coef, k + time_first, v_state, v, base_state)
    return ow @ (r * (numerator / denominator))
44 |
45 |
def rwkv_state_flow(time_decay, state_wkv, k, v, **kwargs):
    """WKV state for the next token.

    The reference exponent is shifted by ``time_decay`` and the current (k, v)
    is mixed in. Extra keyword arguments are ignored.

    :return: (new v_state, new base_state, new max_coef)
    """
    v_state, base_state, max_coef = state_wkv
    decayed_coef = max_coef + time_decay
    new_v_state, new_base_state, new_max_coef = exp_mix_both(decayed_coef, k, v_state, v, base_state)
    return new_v_state, new_base_state, new_max_coef
50 |
51 |
def token_mixing(x, x_prev, time_mix_k, time_mix_v, time_mix_r, kw, vw, rw, **kwargs):
    """Produce receptance, key and value for the time-mixing layer.

    ``x`` is blended with ``x_prev`` separately for each of the three
    projections; the receptance goes through a sigmoid. Extra keyword
    arguments are ignored.

    :return: (r, k, v)
    """
    mixed_k, mixed_v, mixed_r = (
        time_mix(x, x_prev, coef) for coef in (time_mix_k, time_mix_v, time_mix_r)
    )

    receptance = sigmoid(rw @ mixed_r)
    key = kw @ mixed_k
    value = vw @ mixed_v
    return receptance, key, value
61 |
62 |
def channel_mixing(x, x_prev, time_mix_k, time_mix_r, kw, vw, rw, **kwargs):
    """Channel-mixing layer: a sigmoid gate times the projection of a squared-ReLU hidden vector.

    Extra keyword arguments are ignored.
    """
    mixed_k = time_mix(x, x_prev, time_mix_k)
    mixed_r = time_mix(x, x_prev, time_mix_r)
    gate = sigmoid(rw @ mixed_r)
    hidden = np.square(relu(kw @ mixed_k))  # square relu, primer paper
    return gate * (vw @ hidden)
69 |
70 |
@jit
def rwkv_net_rnn(token, state, ln_out, blocks, head, emb):
    """Run the network on one token in recurrent mode.

    Per layer, ``state`` holds the previous channel-mixing input in slot 0, the
    previous token-mixing input in slot 1 and the WKV state in slots 2 onward.

    :return: (logits, new state with the same layout as ``state``)
    """
    x = layer_norm(emb['weight'][token], **blocks[0]['ln0'])
    new_states = np.empty_like(state)
    # `blocks` is looked up by integer position, so it may be a list or an int-keyed mapping.
    for layer in range(len(blocks)):
        weights = blocks[layer]
        att = weights['att']

        xn_token = layer_norm(x, **weights['ln1'])
        r, k, v = token_mixing(xn_token, state[layer, 1], **att)

        wkv_state = state[layer, 2:]

        x = x + rwkv(wkv_state, r, k, v, att['output']['weight'], att['time_first'], debug=False)

        xn_channel = layer_norm(x, **weights['ln2'])
        x = x + channel_mixing(xn_channel, state[layer, 0], **weights['ffn'])

        next_wkv = rwkv_state_flow(att['time_decay'], wkv_state, k, v, debug=False)
        layer_state = np.vstack([xn_channel, xn_token, *next_wkv])
        new_states = new_states.at[layer].set(layer_state)

    logits = head['weight'] @ layer_norm(x, **ln_out)
    return logits, new_states
99 |
--------------------------------------------------------------------------------
/pico_rwkv/pico_rwkv_weights.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, TypeVar, Iterable, Optional
2 |
3 | from copy_init.weights import WeightConfig, get_weights_mask
4 | from labels import Labels
5 | from picojax.jax_utils import Arr, WeightsTree
6 | from jax import numpy as np
7 |
8 |
def parse_rwkv_weight(keys: Iterable[str], get_tensor: Callable[[str], Arr], trim: bool=False) -> WeightsTree:
    """Turn flat dotted keys into a nested weight tree.

    Each key is split on ``'.'``; all-digit parts become ``int`` keys. Afterwards,
    in every block's ``att`` and ``ffn`` dicts, the ``weight`` of ``key`` /
    ``value`` / ``receptance`` is also exposed under the short names ``kw`` /
    ``vw`` / ``rw``.

    :param keys: flat dotted names
    :param get_tensor: returns the tensor for a flat name
    :param trim: if true, delete the original ``weight`` entry once it has been aliased
    :return: the nested tree; it must contain a top-level ``'blocks'`` entry
    """
    tree = {}
    for full_key in keys:
        *scopes, leaf = full_key.split('.')
        node = tree
        for scope in scopes:
            name = int(scope) if scope.isdigit() else scope
            node = node.setdefault(name, {})
        node[leaf] = get_tensor(full_key)

    for blk in tree['blocks'].values():
        for mixer in (blk['att'], blk['ffn']):
            for proj in ('key', 'value', 'receptance'):
                if proj not in mixer:
                    continue
                mixer[proj[0] + 'w'] = mixer[proj]['weight']
                if trim:
                    del mixer[proj]['weight']
    return tree
34 |
35 |
def get_masks_to_train(train_tags: Optional[list[Labels]], info: dict[str, WeightConfig], trim:bool=False) -> WeightsTree:
    """Build a boolean mask tree with the same nesting as the weights.

    Every weight named in ``info`` gets an array of its shape, all ``True`` when
    ``get_weights_mask`` marks it as trainable and all ``False`` otherwise; the
    flat masks are then nested with ``parse_rwkv_weight``.
    """
    to_train = get_weights_mask(train_tags, info)
    mask_raw = {
        name: (np.ones if trainable else np.zeros)(info[name].shape, dtype=bool)
        for name, trainable in to_train.items()
    }
    return parse_rwkv_weight(mask_raw.keys(), mask_raw.__getitem__, trim)
47 |
--------------------------------------------------------------------------------
/pico_rwkv/pth_to_safet.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import jax
5 | from safetensors.flax import save_file
6 |
path = Path("/Data/lm_models/rwkv")
# model_name = 'RWKV-4-Pile-430M-20220808-8066'
model_name = 'RWKV-4-Pile-169M-20220807-8023'

w_raw = torch.load(path/f'{model_name}.pth', map_location='cpu')

# Rewrite every entry in place (values only; the key set does not change).
for k in w_raw.keys():
    tensor = w_raw[k]
    if '.time_' in k:
        tensor = tensor.squeeze()
    if '.time_decay' in k:
        tensor = -torch.exp(tensor)  # the real time decay is like e^{-e^x}
    # convert to f32 type, then to a jax array
    w_raw[k] = jax.numpy.array(tensor.float())

save_file(w_raw, path/f'{model_name}.safetensors')
23 |
--------------------------------------------------------------------------------
/pico_rwkv/run_rwkv_parallel.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import jax.numpy as np
4 | from safetensors import safe_open
5 | from tokenizers import Tokenizer
6 |
7 | from nlp_utils import rnn_generate
8 | from pico_rwkv.pico_rwkv_parallel import rwkv_net_parallel
9 | from pico_rwkv.pico_rwkv_parallel_alternatives import rwkv_net_scan
10 | from pico_rwkv.pico_rwkv_rnn import rwkv_net_rnn
11 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight
12 | from picojax.jax_utils import Arr
13 | from picojax.random_utils import infinite_safe_keys
14 |
path = Path("/Data/lm_models/rwkv")
model_name = 'RWKV-4-Pile-430M-20220808-8066'
# jax.config.update('jax_platform_name', 'cpu')

# Read every tensor from the checkpoint and arrange them into the nested weight tree.
with safe_open(path / f"{model_name}.safetensors", framework="flax", device="cpu") as f:
    w = parse_rwkv_weight(f.keys(), f.get_tensor)

tokenizer = Tokenizer.from_file(str(path / "20B_tokenizer.json"))

# Model dimensions are read off the loaded weights.
n_channels = w['emb']['weight'].shape[1]
n_layers = len(w['blocks'])

key_gen = infinite_safe_keys(0)
28 |
29 | # context = """Where must the puma have come from?
30 | #
31 | # Pumas are large, cat-like animals which are found in America. When reports came into London Zoo that a wild puma had been spotted forty-five miles south of London, they were not taken seriously. However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate, for the descriptions given by people who claimed to have seen the puma were extraordinarily similar.
32 | # The hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat' only five yards away from her. It immediately ran away when she saw it, and experts confirmed that a puma will not attack a human being unless it is cornered. The search proved difficult, for the puma was often observed at one place in the morning and at another place twenty miles away in the evening. Wherever it went, it left behind it a trail of dead deer and small animals like rabbits. Paw prints were seen in a number of places and puma fur was found clinging to bushes. Several people complained of "cat-like noises' at night and a businessman on a fishing trip saw the puma up a tree. The experts were now fully convinced that the animal was a puma, but where had it come from? As no pumas had been reported missing from any zoo in the country, this one must have been in the possession of a private collector and somehow managed to escape. The hunt went on for several weeks, but the puma was not caught. It is disturbing to think that a dangerous wild animal is still at large in the quiet countryside.
33 | #
34 | # In not more than 80 words describe how experts came to the conclusion that the animal seen by many people really was a puma. Do not include anything that is not in the passage.
35 | # Answer these questions in note form to get your points:
36 | # 1 What sort of reports were received by London Zoo?
37 | # 2 Were the reports similar in nature or not?
38 | # 3 Who saw it first?
39 | # 4 Did it stay in one place,or did it move from place to place?
40 | # 5 What did it leave behind it?
41 | # 6 Were paw prints and puma fur found as well or not?
42 | # 7 What was heard at night?
43 | # 8 Was the animal seen up a tree or not?
44 | # 9 Were experts now sure that the animal really was a puma or not?
45 | #
46 | # """
47 |
48 |
# Prompt that all modes below start from.
context = ("\nPumas are large, cat-like animals found in America. When reports came into London Zoo that "
           "a wild puma had been spotted forty-five miles south of London, they were not taken seriously."
           " However, as the evidence began to accumulate, experts from the Zoo felt obliged to investigate,"
           " for the descriptions given by people who claimed to have seen the puma were extraordinarily similar."
           "\nThe hunt for the puma began in a small village where a woman picking blackberries saw 'a large cat'"
           " only five yards away from her. It")
# context = "The quick brown fox jumps over the lazy"
# context = "Once upon a "

def sample_rwkv_rnn(token_arr: Arr, state: Arr) -> tuple[Arr, Arr]:
    """Single-token recurrent step bound to the loaded weights; returns what ``rwkv_net_rnn`` returns."""
    return rwkv_net_rnn(token_arr, state, **w)
60 |
# Selects which of the three code paths below runs.
mode = "parallel"
# mode = "rnn"
# mode = "parallel"
if mode == "parallel":
    # Greedy decoding with the parallel network: the whole window is re-run for every new token.
    token_array = np.array(tokenizer.encode(context).ids)
    # init_out = rwkv_net_parallel(len(token_array), token_array, **w)
    # init_out = init_out[-1, :]
    print("token length: ", len(token_array))
    print(context, end="")

    for i in range(100):
        init_out = rwkv_net_parallel(len(token_array), token_array, **w)
        out = np.argmax(init_out[-1, :], axis=-1)
        print(tokenizer.decode([out]), end="", flush=True)
        # Slide the window: drop the oldest token and append the new one, keeping the length constant.
        token_array = np.append(token_array[1:], out)
    # Ends the script by raising.
    raise Exception("Done")


elif mode == "rnn":
    # Five state slots per layer; the last slot of every layer starts at -1e30.
    state = np.zeros((n_layers, 5, n_channels))
    for i in range(n_layers):
        state = state.at[i, -1].set(-1e30)
    key_gen = infinite_safe_keys(0)
    _ = rnn_generate(sample_rwkv_rnn, context, state, tokenizer, key_gen,
                     argmax=True,
                     length_per_trial=100, n_trials=3, temperature=1.0, top_p=0.85)

elif mode == "scan":
    max_len_ = 10
    state = np.zeros((n_layers, 5, n_channels))
    for i in range(n_layers):
        state = state.at[i, -1].set(-1e30)


    def process_tokens(token_objs, max_len):
        """Split tokens into full chunks of ``max_len`` plus a remainder.

        Chunks are taken while more than ``max_len`` tokens remain, so the
        remainder holds between 1 and ``max_len`` tokens for non-empty input.
        """
        # l_token_array = np.array_split(full_token_array, len(full_token_array) // max_len)
        # if len(l_token_array[-1]) < max_len:
        #     return l_token_array[:-1], l_token_array[-1]
        # else:
        #     return l_token_array, []
        curr_len = len(token_objs)
        batch = []
        while curr_len > max_len:
            batch.append(np.array(token_objs[:max_len]))
            token_objs = token_objs[max_len:]
            curr_len = len(token_objs)
        return batch, np.array(token_objs)


    token_array_batch, left_over = process_tokens(tokenizer.encode(context).ids, max_len_)
    print(token_array_batch)
    print(left_over)
    # Full chunks go through the scan network, carrying the state forward.
    if len(token_array_batch) > 0:
        for token_array in token_array_batch:
            init_out, state = rwkv_net_scan(max_len_, token_array, state, **w)
    # Remaining tokens are fed one at a time through the recurrent network.
    if len(left_over) > 0:
        for token in left_over:
            init_out, state = rwkv_net_rnn(token, state, **w)
    else:
        print("scan_out", tokenizer.decode(np.argmax(init_out)))
        init_out = init_out[-1, :]
    key_gen = infinite_safe_keys(0)
    _ = rnn_generate(sample_rwkv_rnn, context, state, tokenizer, key_gen,
                     argmax=True,
                     length_per_trial=100, n_trials=3, temperature=1.0, top_p=0.85)
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/pico_rwkv/rwkv_torch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
4 | import types, torch
5 | from torch.nn import functional as F
6 | from tokenizers import Tokenizer
7 |
8 | tokenizer = Tokenizer.from_file("/Data/lm_models/rwkv/20B_tokenizer.json")
9 |
10 | args = types.SimpleNamespace()
11 | args.MODEL_NAME = '/Data/lm_models/rwkv/RWKV-4-Pile-430M-20220808-8066'
12 | args.n_layer = 24
13 | args.n_embd = 1024
14 |
15 | context = "\nPumas are large, cat-like animals"
16 | NUM_TRIALS = 3
17 | LENGTH_PER_TRIAL = 100
18 | TEMPERATURE = 1.0
19 | TOP_P = 0.85
20 |
21 |
22 | ########################################################################################################
23 |
class RWKV_RNN(torch.jit.ScriptModule):
    """RNN-mode RWKV model evaluated one token at a time on CPU.

    Weights are read from ``args.MODEL_NAME + '.pth'`` and exposed as a nested
    namespace ``self.w`` (``self.w.blocks`` is a dict keyed by block index).

    The recurrent state is a ``(n_layer * 5, n_embd)`` tensor; for block ``i``:
      * row ``5*i + 0``: previous input of channel mixing
      * row ``5*i + 1``: previous input of time mixing
      * rows ``5*i + 2`` / ``5*i + 3``: running numerator / denominator of wkv
      * row ``5*i + 4``: running maximum exponent used to keep ``exp`` stable
    """

    def __init__(self, args):
        """Load the checkpoint named by ``args`` and build ``self.w``.

        :param args: namespace with ``MODEL_NAME``, ``n_layer`` and ``n_embd``
        """
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                w[k] = -torch.exp(w[k].float())  # the real time decay is like e^{-e^x}
            else:
                w[k] = w[k].float()  # convert to f32 type

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        for k in w.keys():  # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric path components index into the blocks dict
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Layer-normalise ``x`` over ``n_embd`` using ``w.weight`` and ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
        # Blend the current input with the previous one stored in the state.
        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
        state[5 * i + 0] = x  # remember this input for the next step (in-place)
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay,
                    kw, vw, rw, ow):
        # Blend the current input with the previous one stored in the state.
        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
        state[5 * i + 1] = x  # remember this input for the next step (in-place)
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        # Output for the current token: exponents are shifted by their maximum
        # (qq) before exp() so the weighted average stays finite.
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        # State update for the next token: decay the accumulated terms and add
        # the current key/value, again relative to the new maximum exponent.
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        """Process one token.

        :param token: index into the embedding table
        :param state: recurrent state tensor, or ``None`` to start from a fresh state
        :return: ``(logits, state)``; ``state`` is the tensor that was updated in place
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer):
                    state[5 * i + 4] = -1e30  # -infinity

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                xp = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                                      att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                                      att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                x += xp
                ffn = self.w.blocks[i].ffn
                xp = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                                         ffn.time_mix_k, ffn.time_mix_r,
                                         ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
                x += xp

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
122 |
123 |
124 | ##########################################################################################################
125 |
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Draw a token index from ``out`` using nucleus (top-p) sampling.

    :param out: 1-D torch tensor of logits
    :param temperature: exponent ``1 / temperature`` is applied to the truncated
        probabilities when it differs from 1.0
    :param top_p: cumulative-probability threshold; tokens less likely than the
        one at which the sorted cumulative sum first exceeds ``top_p`` are dropped
    :return: the sampled index as returned by ``np.random.choice``
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here, which has no .pow method; use np.power.
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    return np.random.choice(a=len(probs), p=probs)
137 |
138 |
139 | ########################################################################################################
140 |
141 | print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
142 | model = RWKV_RNN(args)
143 |
144 | print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
145 | init_state = None
146 | for token in tokenizer.encode(context).ids:
147 | init_out, init_state = model.forward(token, init_state)
148 |
149 | for TRIAL in range(NUM_TRIALS):
150 | print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
151 | all_tokens = []
152 | out_last = 0
153 | out, state = init_out.clone(), init_state.clone()
154 | for i in range(LENGTH_PER_TRIAL):
155 | token = np.argmax(out)
156 | # token = sample_logits(out, TEMPERATURE, TOP_P)
157 | all_tokens += [token]
158 | tmp = tokenizer.decode(all_tokens[out_last:])
159 | if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
160 | print(tmp, end="", flush=True)
161 | out_last = i + 1
162 | out, state = model.forward(token, state)
163 | print('\n')
--------------------------------------------------------------------------------
/pico_rwkv/rwkv_weight_profile.py:
--------------------------------------------------------------------------------
1 | from copy_init.weights import WeightConfig, WeightConfigType, NormalWeight, ZeroWeight
2 | from labels import L
3 |
4 |
def make_weight_config(name: str, shape: tuple[int, ...], wtype: WeightConfigType) -> WeightConfig:
    """Create the weight config object matching ``wtype``.

    The dotted ``name`` is split on ``'.'`` and its parts become the config's tags.

    :param name: dotted weight name, e.g. ``"blocks.att.key.weight"``
    :param shape: shape of the weight
    :param wtype: ``'normal'`` for a ``NormalWeight`` or ``'zero'`` for a ``ZeroWeight``
    :raises ValueError: for any other ``wtype``
    """
    tags = L(*name.split('.'))
    if wtype not in ('normal', 'zero'):
        raise ValueError(f"weight_config_type must be 'normal' or 'zero', not {wtype}")
    config_cls = NormalWeight if wtype == 'normal' else ZeroWeight
    return config_cls(name=name, shape=shape, tags=tags)
13 |
14 |
def make_rwkv_weight_configs(n_blocks: int, n_embd: int, ffn_hidden_multiplier: int) -> dict[str, WeightConfig]:
    """Build weight configs for the per-block RWKV parameters, stacked over blocks.

    Every shape has ``n_blocks`` as its leading dimension. All entries are
    currently created with the ``'zero'`` weight type.

    :param n_blocks: number of blocks (leading dimension of every weight)
    :param n_embd: embedding width
    :param ffn_hidden_multiplier: ratio of the ffn hidden width to ``n_embd``
    :return: mapping from weight name to its ``WeightConfig``
    """
    properties = [
        ("blocks.att.output.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.value.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.key.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.receptance.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.att.time_decay", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_first", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_k", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_r", (n_blocks, n_embd), 'zero'),
        ("blocks.att.time_mix_v", (n_blocks, n_embd), 'zero'),
        ("blocks.ffn.key.weight", (n_blocks, ffn_hidden_multiplier * n_embd, n_embd), 'zero'),
        ("blocks.ffn.value.weight", (n_blocks, n_embd, n_embd * ffn_hidden_multiplier), 'zero'),
        ("blocks.ffn.receptance.weight", (n_blocks, n_embd, n_embd), 'zero'),
        ("blocks.ffn.time_mix_k", (n_blocks, n_embd), 'zero'),
        ("blocks.ffn.time_mix_r", (n_blocks, n_embd), 'zero'),
        # NOTE(review): channel mixing elsewhere in this repo only takes
        # time_mix_k / time_mix_r -- confirm whether ffn.time_mix_v is wanted.
        ("blocks.ffn.time_mix_v", (n_blocks, n_embd), 'zero'),
        ("blocks.ln0.weight", (n_blocks, n_embd), 'zero'),
        ("blocks.ln0.bias", (n_blocks, n_embd), 'zero'),
    ]
    # The original built the list but fell off the end, returning None despite
    # the declared dict return type.
    return {name: make_weight_config(name, shape, wtype) for name, shape, wtype in properties}
37 |
--------------------------------------------------------------------------------
/pico_rwkv/train_rwkv.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | import jax.numpy as np
9 | import optax
10 | import wandb
11 | from jax.tree_util import tree_flatten
12 | from tokenizers import Tokenizer
13 |
14 | import custom_dataset
15 | import custom_dataset_str
16 | import nlp_utils
17 | from copy_init.weights import get_normal_weights_config_, init, save_pytree_
18 | from labels import Labels
19 | from pico_rwkv.pico_rwkv_parallel import rwkv_net_parallel
20 | from pico_rwkv.pico_rwkv_rnn import rwkv_net_rnn
21 | from pico_rwkv.pico_rwkv_weights import parse_rwkv_weight
22 | from picojax.jax_utils import WeightsTree, Arr
23 | from picojax.random_utils import infinite_safe_keys
24 | from picojax.train_utils import LMBatchConfig, TrainConfig, TrainState, get_lm_loss
25 | from python_utils import num_short_form
26 |
27 | os.environ['JAX_LOG_COMPILES'] = '1'
28 | model_path = Path("/Data/lm_models/rwkv")
29 | # model_name = 'RWKV-4-Pile-430M-20220808-8066'
30 | model_name = 'RWKV-4-Pile-169M-20220807-8023'
31 |
32 | weight_infos = get_normal_weights_config_(model_path, model_name)
33 |
34 | keygen = infinite_safe_keys(0)
35 | key = next(keygen)
36 |
37 | init_weights_raw = init(weight_infos, rng_key=key)
38 | all_weight_names = list(init_weights_raw.keys())
39 |
40 | # randomly initialize weights
41 | init_weights_: WeightsTree = parse_rwkv_weight(init_weights_raw.keys(), init_weights_raw.__getitem__, trim=True)
42 | ## load weights instead of randomly initialize
43 | # with safe_open(model_path / f"{model_name}.safetensors", framework="flax", device="cpu") as f:
44 | # init_weights_ = parse_rwkv_weight(f.keys(), f.get_tensor)
45 |
46 | n_channels = init_weights_['emb']['weight'].shape[1] # type: ignore
47 | _, tree_struct = tree_flatten(init_weights_)
48 |
49 | train_tags: Optional[list[Labels]] = None
50 | weight_mask = None
51 |
52 | # train_tags = [Labels.from_strs('ffn', 'key'), Labels.from_strs('ffn', 'value')]
53 | # weight_mask = get_masks_to_train(train_tags, weight_infos, trim=True)
54 |
55 |
56 | # %%
57 |
58 | data_path = Path("/Data/nlp/")
59 |
60 | str_sampling = False # turn this on if the data is too large to fit in memory, may be a bit slower
61 | custom_vocab = True # turn this off if you want to use the default tokenizer, str_sampling == True does not support custom_token
62 |
63 | dataset = "play"
64 | # dataset = "poem"
65 | # dataset = "english"
66 | # dataset = "russell"
67 | # dataset = "duilian"
68 | # dataset = "MQnovel"
69 |
70 | if str_sampling:
71 | input_data = custom_dataset_str.load(data_path, dataset)
72 | tokenizer = Tokenizer.from_file(str(model_path / "20B_tokenizer.json"))
73 | else:
74 | if custom_vocab:
75 | input_data, tokenizer = custom_dataset.load_jax_cached_tokenizer(data_path, dataset)
76 | else:
77 | input_data_str = custom_dataset_str.load(data_path, dataset)
78 | tokenizer = Tokenizer.from_file(str(model_path / "20B_tokenizer.json"))
79 | input_data = np.array(tokenizer.encode(input_data_str).ids)
80 | if custom_vocab:
81 | # reduce embedding to vocab size
82 | init_weights_['emb']['weight'] = init_weights_['emb']['weight'][:tokenizer.get_vocab_size(), :] # type: ignore
83 | init_weights_['head']['weight'] = init_weights_['head']['weight'][:tokenizer.get_vocab_size(), :] # type: ignore
84 |
85 | n = int(len(input_data) * 0.9)
86 | train_data = input_data[:n]
87 | valid_data = input_data[n:]
88 |
89 | key_gen = infinite_safe_keys(0)
90 |
91 | adam_params = {
92 | 'learning_rate': 1e-4,
93 | 'beta1': 0.9,
94 | 'beta2': 0.999,
95 | 'eps': 1e-8,
96 | }
97 | lion_params = {
98 | 'learning_rate': 1e-4,
99 | 'beta1': 0.95,
100 | 'beta2': 0.98,
101 | 'weight_decay': 0.01
102 | }
103 | train_params = {
104 | 'eval_iters': 200,
105 | 'eval_interval': 2000,
106 | 'save_interval': 10000,
107 | 'max_iters': 1000000,
108 | # 'adam': adam_params,
109 | 'lion': lion_params,
110 | # 'adamw': adam_params,
111 | 'optimizer': 'lion',
112 | }
113 |
114 | experimental_params: dict = {
115 | 'eps': 1e-5,
116 | 'n_tokens': tokenizer.get_vocab_size(),
117 | 'n_channels': n_channels,
118 | 'n_blocks': len(init_weights_['blocks']),
119 |
120 | 'train_tags': [l.fmt() for l in train_tags] if train_tags is not None else None,
121 |
122 | 'batch_size': 4,
123 | 'block_size': 128,
124 | 'train': train_params,
125 | 'model': "rwkv"
126 | }
127 |
128 | max_iters = experimental_params['train']['max_iters']
129 | eval_interval = experimental_params['train']['eval_interval']
130 | save_interval = experimental_params['train']['save_interval']
131 | eval_iters = experimental_params['train']['eval_iters']
132 | batch_config_ = LMBatchConfig(block_size=experimental_params['block_size'],
133 | batch_size=experimental_params['batch_size'],
134 | tokenizer=tokenizer,
135 | str_sampling=str_sampling)
136 |
137 | if experimental_params['train']['optimizer'] == 'adam':
138 | adam_config = experimental_params['train']['adam']
139 | optimizer_ = optax.adam(learning_rate=adam_config['learning_rate'],
140 | b1=adam_config['beta1'],
141 | b2=adam_config['beta2'],
142 | eps=adam_config['eps'])
143 | elif experimental_params['train']['optimizer'] == 'lion':
144 | lion_config = experimental_params['train']['lion']
145 | optimizer_ = optax.lion(learning_rate=lion_config['learning_rate'],
146 | b1=lion_config['beta1'],
147 | b2=lion_config['beta2'],
148 | weight_decay=lion_config['weight_decay'])
149 | elif experimental_params['train']['optimizer'] == 'adamw':
150 | adamw_config = experimental_params['train']['adamw']
151 | optimizer_ = optax.adamw(learning_rate=adamw_config['learning_rate'],
152 | b1=adamw_config['beta1'],
153 | b2=adamw_config['beta2'],
154 | eps=adamw_config['eps'])
155 | else:
156 | raise ValueError(f"optimizer {experimental_params['train']['optimizer']} not supported")
157 |
158 |
def rwkv_f(w: WeightsTree, token_array: Arr) -> Arr:
    """Run the parallel RWKV network over a whole token sequence.

    :param w: weights tree, unpacked as keyword arguments of ``rwkv_net_parallel``
    :param token_array: token sequence; its length is passed as the first argument
    :return: the result of ``rwkv_net_parallel``
    """
    seq_len = len(token_array)
    return rwkv_net_parallel(seq_len, token_array, **w)
161 |
162 |
def rwkv_rnn(w: WeightsTree, token_array: Arr, state: Arr) -> tuple[Arr, Arr]:
    """Run the RNN-form RWKV network for one call.

    :param w: weights tree, unpacked as keyword arguments of ``rwkv_net_rnn``
    :param token_array: token input forwarded to ``rwkv_net_rnn``
    :param state: recurrent state forwarded to ``rwkv_net_rnn``
    :return: the result of ``rwkv_net_rnn`` (annotated as an ``(output, state)`` pair)
    """
    step_result = rwkv_net_rnn(token_array, state, **w)
    return step_result
165 |
166 |
167 | # noinspection PyArgumentList
168 | # cuz it's a NamedTuple
169 | train_config_ = TrainConfig(loss_fn=partial(get_lm_loss, rwkv_f),
170 | optimiser=optimizer_)
171 | # noinspection PyArgumentList
172 | # cuz it's a NamedTuple
173 | train_state_: TrainState = TrainState(weights=init_weights_,
174 | train_mask=weight_mask,
175 | opt_state=optimizer_.init(init_weights_))
176 |
177 | rnn_init_state = np.zeros((experimental_params['n_blocks'], 5, experimental_params['n_channels']))
178 | for i in range(experimental_params['n_blocks']):
179 | # to jax state[5 * i + 4] = -1e30
180 | rnn_init_state = rnn_init_state.at[i, 4].set(-1e30)
181 |
182 | run = wandb.init(
183 | project="inside-transformer",
184 | config=experimental_params,
185 | )
186 | assert isinstance(run, wandb.sdk.wandb_run.Run)
187 |
188 | keys_ = next(key_gen).split(max_iters)
189 | for step in range(max_iters):
190 | batch_ = batch_config_.sample(train_data, keys_[step])
191 |
192 | if step % eval_interval == 0:
193 | loss = train_config_.loss_fn(train_state_.weights, batch_)
194 | print(f"\n===[ step {step} is an eval step ]==========")
195 | print(f"before step {step}, batch loss {loss}")
196 |
197 | train_state_ = train_config_.train1(train_state_, batch_)
198 | if step % eval_interval == 0:
199 | loss = train_config_.loss_fn(train_state_.weights, batch_)
200 | print(f"after step {step}, batch loss {loss}")
201 | results = train_config_.estimate_loss(eval_iters, key_gen, train_state_,
202 | {'train': partial(batch_config_.sample, train_data),
203 | 'val': partial(batch_config_.sample, valid_data)})
204 | generate_f = partial(rwkv_rnn, train_state_.weights)
205 | generated = nlp_utils.rnn_generate(generate_f,
206 | "\n",
207 | tokenizer=batch_config_.tokenizer,
208 | argmax=True,
209 | key_gen=key_gen,
210 | init_state=rnn_init_state,
211 | length_per_trial=batch_config_.block_size - 1)
212 |
213 | wandb.log({"train_loss": results['train'],
214 | "validation_loss": results['val'],
215 | "batch_loss": loss,
216 | "n_tokens_trained": step * batch_config_.batch_size * batch_config_.block_size,
217 | 'generated': wandb.Html(f"{generated}")})
218 | if step % save_interval == 0: # and step > 0:
219 | n_tokens_trained = step * batch_config_.batch_size * batch_config_.block_size
220 | n_tokens_trained_str = num_short_form(n_tokens_trained)
221 | wandb.save(save_pytree_(train_state_.weights, run.dir, f"{model_name}_{n_tokens_trained_str}"), run.dir)
222 |
223 | wandb.finish()
224 |
225 | # [Done] add trainable weights (via gradient mask)
226 | # [TODO] add trainable weights via weight name mask
227 | # [Done] add checkpointing
228 | # TODO: add weight decay
229 |
--------------------------------------------------------------------------------
/pico_s5/pico_s5.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/pico_s5/pico_s5.py
--------------------------------------------------------------------------------
/picojax/jax_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import jax
4 | from typing import Callable, Union, TypeVar
5 |
6 | Arr = jax.Array
7 |
8 |
def jit_f(f: Callable) -> Callable:
    """Wrap ``f`` with ``jax.jit``, treating its first positional argument as static.

    The compiled function is also marked ``inline=True``.
    """
    jit_options = {"static_argnums": (0,), "inline": True}
    return jax.jit(f, **jit_options)
11 |
12 |
13 | WeightsTree = dict[str, 'WeightsType']
14 | WeightsType = Union[Arr, WeightsTree, list[Union[WeightsTree, Arr]]]
15 |
16 |
--------------------------------------------------------------------------------
/picojax/random_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Tuple, Literal, Iterator
4 |
5 | import jax
6 | from jax import Array
7 | from jax import numpy as xp, random
8 | from jax.random import PRNGKeyArray
9 |
10 | ArrayGen = Literal['kaiming', 'dropout', 'embedding', 'normal']
11 |
12 |
class SafeKey:
    """Safety wrapper for PRNG keys.

    A wrapped key may be consumed exactly once, by ``get``, ``split`` or
    ``duplicate``; any later attempt raises ``RuntimeError``.
    """

    def __init__(self, key: PRNGKeyArray):
        self._key = key
        self._used = False

    def _assert_not_used(self) -> None:
        """Raise ``RuntimeError`` if this key has already been consumed."""
        if self._used:
            raise RuntimeError('Random key has been used previously.')

    def get(self) -> PRNGKeyArray:
        """Consume the key and return the raw JAX key."""
        self._assert_not_used()
        self._used = True
        return self._key

    def split(self, num_keys=2) -> Tuple['SafeKey', ...]:
        """Consume the key and return ``num_keys`` fresh, independent ``SafeKey``s."""
        self._assert_not_used()
        self._used = True
        new_keys = jax.random.split(self._key, num_keys)
        # Plain iteration replaces jax.tree_map, which is deprecated/removed in
        # recent JAX releases; the result is the same tuple of wrapped sub-keys.
        return tuple(SafeKey(sub_key) for sub_key in new_keys)

    def duplicate(self, num_keys=2) -> Tuple['SafeKey', ...]:
        """Consume the key and return ``num_keys`` wrappers around the *same* raw key."""
        self._assert_not_used()
        self._used = True
        return tuple(SafeKey(self._key) for _ in range(num_keys))
39 |
40 |
def infinite_safe_keys(seed: int) -> Iterator[SafeKey]:
    """Yield an endless stream of ``SafeKey``s derived from ``seed``.

    Each step splits the carried key in two: one half is kept for the next
    step, the other is wrapped and yielded.
    """
    carry = jax.random.PRNGKey(seed)
    while True:
        carry, fresh = jax.random.split(carry)
        yield SafeKey(fresh)
46 |
47 |
def dropout_gen(rng_key: SafeKey, keep_rate: float, shape: Tuple[int, ...]):
    """Sample a Bernoulli mask of the given ``shape`` with probability ``keep_rate``.

    Consumes ``rng_key``.
    """
    raw_key = rng_key.get()
    return random.bernoulli(raw_key, keep_rate, shape)
50 |
51 |
def kaiming_init(rng_key: SafeKey, sd: float, shape: Tuple[int, ...]) -> Array:
    """
    Kaiming-style initialization: normal samples with standard deviation ``sd``
    (drawn via ``normal_init``) scaled by ``sqrt(2 / n_in)``.

    Arguments:
        :param rng_key: random number generator key (consumed)
        :param sd: standard deviation passed to ``normal_init``
        :param shape: shape of the result; ``shape[0]`` is used as ``n_in``

    Returns:
        array of the requested ``shape``
    """
    fan_in = shape[0]
    gain = xp.sqrt(2 / fan_in)
    samples = normal_init(rng_key, sd, shape)
    return gain * samples
70 |
71 |
def embedding_init(rng_key: SafeKey, scale: float, shape: Tuple[int, ...]) -> Array:
    """
    Standard-normal samples multiplied by ``sqrt(dim_model)`` and then by ``scale``.

    Arguments:
        :param rng_key: random number generator key (consumed)
        :param scale: multiplier applied after the ``sqrt(dim_model)`` factor
        :param shape: shape of the result; ``shape[-1]`` is used as ``dim_model``

    Returns:
        array of the requested ``shape``
    """
    dim_model = shape[-1]
    samples = random.normal(rng_key.get(), shape)
    widened = samples * xp.sqrt(dim_model)
    return widened * scale
85 |
86 |
def normal_init(rng_key: SafeKey, sd: float, shape: Tuple[int, ...]) -> Array:
    """Standard-normal samples of the given ``shape`` multiplied by ``sd``. Consumes ``rng_key``."""
    unit_samples = random.normal(rng_key.get(), shape)
    return unit_samples * sd
89 |
--------------------------------------------------------------------------------
/picojax/train_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from typing import Callable, TypeVar, Optional, Union, Protocol, Collection, Sequence
5 | from typing import Iterable, Generic, Iterator
6 |
7 | import jax
8 | import optax
9 | from jax import random, numpy as jnp, vmap, numpy as np
10 | from optax import GradientTransformation, OptState, softmax_cross_entropy_with_integer_labels
11 | from typing_extensions import NamedTuple
12 |
13 | from nlp_utils import Tokenizer
14 | from .jax_utils import Arr, WeightsTree
15 | from .random_utils import SafeKey
16 |
17 | Weights = TypeVar('Weights')
18 | Batch = TypeVar('Batch')
19 |
20 |
@partial(jax.jit, static_argnums=(0, 1), inline=True)
def jax_calc_updates(
        optimizer: optax.GradientTransformation,
        loss_fn: Callable[[Weights, Batch], Arr],
        weights: Weights,
        train_mask: Optional[WeightsTree],
        batch: Batch,
        opt_state: optax.OptState
) -> tuple[Weights, optax.OptState]:
    """Compute one optimizer step and return the updated weights and optimizer state.

    ``optimizer`` and ``loss_fn`` are static arguments of the jitted function.

    :param optimizer: optax gradient transformation
    :param loss_fn: called as ``loss_fn(weights, batch)``; differentiated w.r.t. ``weights``
    :param weights: current weights pytree
    :param train_mask: optional pytree matching the gradients; each gradient leaf
        is multiplied elementwise by the corresponding mask leaf
    :param batch: batch passed to ``loss_fn``
    :param opt_state: current optimizer state
    :return: ``(new_weights, new_opt_state)``
    """
    grads = jax.grad(loss_fn)(weights, batch)
    if train_mask is not None:
        # jax.tree_util.tree_map is the stable spelling; the top-level
        # jax.tree_map alias is deprecated/removed in recent JAX releases.
        grads = jax.tree_util.tree_map(lambda grad, mask: grad * mask, grads, train_mask)
    updates, opt_state = optimizer.update(grads, opt_state, weights)
    return optax.apply_updates(weights, updates), opt_state
35 |
36 |
class LMBatchConfig(NamedTuple):
    """Settings and samplers for drawing language-model (inputs, targets) batches."""
    block_size: int
    batch_size: int
    tokenizer: Tokenizer
    str_sampling: bool = False

    def sample_jax(self, data: Arr, rng_key: SafeKey) -> tuple[Iterable[int], Iterable[int]]:
        """Draw ``batch_size`` random windows of ``block_size`` items from an array.

        Targets are the same windows shifted right by one position.
        """
        starts = random.randint(key=rng_key.get(), minval=0, maxval=len(data) - self.block_size,
                                shape=(self.batch_size,))
        input_idx = starts[:, jnp.newaxis] + jnp.arange(self.block_size)[jnp.newaxis, :]
        target_idx = starts[:, jnp.newaxis] + jnp.arange(1, self.block_size + 1)[jnp.newaxis, :]
        return data[input_idx], data[target_idx]

    def sample_from_str(self, data: str, rng_key: SafeKey) -> tuple[Iterable[int], Iterable[int]]:
        """Draw ``batch_size`` random substrings of ``block_size`` characters and tokenize them.

        Inputs are all tokens but the last of each encoded substring; targets are
        all tokens but the first.
        """
        starts = random.randint(key=rng_key.get(), minval=0, maxval=len(data) - self.block_size,
                                shape=(self.batch_size,))
        encoded_rows = [
            np.array(self.tokenizer.encode(data[begin:begin + self.block_size]).ids).astype(np.int32)
            for begin in starts
        ]
        stacked = np.stack(encoded_rows)
        return stacked[:, :-1], stacked[:, 1:]

    def sample(self, data: Union[str, Arr], rng_key: SafeKey) -> tuple[Iterable[int], Iterable[int]]:
        """Dispatch to ``sample_from_str`` or ``sample_jax`` according to ``str_sampling``."""
        if not self.str_sampling:
            assert not isinstance(data, str), 'data must be an array when str_sampling is False'
            return self.sample_jax(data, rng_key)
        assert isinstance(data, str), 'data must be a string when str_sampling is True'
        return self.sample_from_str(data, rng_key)
63 |
64 |
65 | W = TypeVar('W')
66 |
67 |
class TrainConfig(NamedTuple, Generic[W]):
    """Pairs a loss function with an optimiser and provides evaluation / single-step helpers."""
    loss_fn: Callable[[W, BatchType], Arr]
    optimiser: GradientTransformation

    def estimate_loss(self,
                      eval_iters: int,
                      rng_key_gen: Iterator[SafeKey],
                      train_state: TrainState,
                      samplers: dict[str, Callable[[SafeKey], BatchType]]) -> dict[str, float]:
        """Average the loss over ``eval_iters`` sampled batches for every sampler.

        One key is drawn from ``rng_key_gen`` per sampler and split into
        ``eval_iters`` sub-keys. The mean for each split is printed and returned.
        """
        averages = {}
        for split_name, draw_batch in samplers.items():
            running_total = 0
            eval_keys = next(rng_key_gen).split(eval_iters)
            for eval_key in eval_keys:
                batch_loss = self.loss_fn(train_state.weights, draw_batch(eval_key))
                running_total += batch_loss.item()
            averages[split_name] = running_total / eval_iters
            print(f"Estimated {split_name} loss: {running_total / eval_iters}")
        return averages

    def train1(self, state: TrainState, batch: BatchType) -> TrainState:
        """Apply one optimiser step on ``batch`` and return the updated train state."""
        new_weights, new_opt_state = jax_calc_updates(self.optimiser,
                                                      self.loss_fn,
                                                      state.weights,
                                                      state.train_mask,
                                                      batch,
                                                      state.opt_state)
        return state.update(weights=new_weights, opt_state=new_opt_state)
96 |
97 |
class TrainState(Generic[W], NamedTuple):
    """Immutable bundle of weights, optimizer state and an optional training mask."""
    weights: W
    opt_state: OptState
    train_mask: Optional[WeightsTree] = None

    def update(self, **kwargs):
        """Return a copy of this state with the given fields replaced."""
        replaced = self._replace(**kwargs)
        return replaced
105 |
106 |
107 | FixedLenBatchType = tuple[Iterable[int], Iterable[int]]
108 | MixedLenBatchType = tuple[Iterable[int], Iterable[int], Iterable[int]]
109 | BatchType = Union[FixedLenBatchType, MixedLenBatchType]
110 |
111 |
def get_lm_loss(f: Callable[[WeightsTree, Arr], Arr], w: WeightsTree, batch: FixedLenBatchType) -> Arr:
    """Mean integer-label softmax cross-entropy of ``f`` over a batch.

    ``f`` is vmapped over the leading axis of the inputs with the weights shared.

    :param f: model function called as ``f(w, inputs_row)``
    :param w: weights passed unchanged to every call of ``f``
    :param batch: ``(inputs, labels)`` pair
    :return: scalar mean loss
    """
    inputs, labels = batch
    batched_f = vmap(f, in_axes=(None, 0), out_axes=0)
    logits = batched_f(w, np.array(inputs))
    losses = softmax_cross_entropy_with_integer_labels(logits, np.array(labels))
    return losses.mean()
116 |
117 |
def get_classification_loss(f: Callable[[WeightsTree, Arr], Arr], w: WeightsTree, batch: MixedLenBatchType) -> Arr:
    """Mean integer-label softmax cross-entropy using one selected position per sequence.

    ``f`` is vmapped over the leading axis of the inputs; for each row the logits
    at index ``seq_len`` along the second axis are compared against the label.

    :param f: model function called as ``f(w, inputs_row)``
    :param w: weights passed unchanged to every call of ``f``
    :param batch: ``(inputs, labels, seq_len)`` triple
    :return: scalar mean loss
    """
    inputs, labels, seq_len = batch
    batched_f = vmap(f, in_axes=(None, 0), out_axes=0)
    logits = batched_f(w, np.array(inputs))
    row_ids = np.arange(logits.shape[0])
    selected = logits[row_ids, seq_len]
    return softmax_cross_entropy_with_integer_labels(selected, np.array(labels)).mean()
122 |
--------------------------------------------------------------------------------
/python_utils.py:
--------------------------------------------------------------------------------
def format_num(num, unit):
    """Format ``num`` with one decimal, ``_`` as the decimal mark, and ``unit`` appended.

    A trailing ``_0`` is dropped, so ``1.0`` becomes ``"1"`` and ``1.5`` becomes ``"1_5"``.
    """
    one_decimal = f"{num:.1f}"
    underscored = one_decimal.replace(".", "_")
    trimmed = underscored.rstrip("0").rstrip("_")
    return trimmed + unit


def num_short_form(num):
    """Abbreviate a number with a K / M / B suffix.

    Zero gives ``"0"``; magnitudes below 1000 are returned via ``str(num)``;
    larger magnitudes are divided by 1e3, 1e6 or 1e9 and formatted by
    ``format_num``, with a leading ``-`` for negative input.
    """
    if num == 0:
        return "0"
    magnitude = abs(num)
    if magnitude < 1000:
        return str(num)
    prefix = "-" if num < 0 else ""
    for divisor, unit in ((1000, "K"), (1000000, "M")):
        if magnitude < divisor * 1000:
            return prefix + format_num(magnitude / divisor, unit)
    return prefix + format_num(magnitude / 1000000000, "B")
20 |
21 |
22 | if __name__ == "__main__":
23 | def test_convert_num_to_str():
24 | assert num_short_form(0) == "0", "Error: expected '0', but got '{}'".format(num_short_form(0))
25 | assert num_short_form(1) == "1", "Error: expected '1', but got '{}'".format(num_short_form(1))
26 | assert num_short_form(10) == "10", "Error: expected '10', but got '{}'".format(num_short_form(10))
27 | assert num_short_form(100) == "100", "Error: expected '100', but got '{}'".format(num_short_form(100))
28 | assert num_short_form(999) == "999", "Error: expected '999', but got '{}'".format(num_short_form(999))
29 | assert num_short_form(1000) == "1K", "Error: expected '1K', but got '{}'".format(num_short_form(1000))
30 | assert num_short_form(1500) == "1_5K", "Error: expected '1_5K', but got '{}'".format(
31 | num_short_form(1500))
32 | assert num_short_form(1999) == "2K", "Error: expected '2K', but got '{}'".format(num_short_form(1999))
33 | assert num_short_form(1000000) == "1M", "Error: expected '1M', but got '{}'".format(
34 | num_short_form(1000000))
35 | assert num_short_form(1500000) == "1_5M", "Error: expected '1_5M', but got '{}'".format(
36 | num_short_form(1500000))
37 | assert num_short_form(999999999) == "1000M", "Error: expected '1000M', but got '{}'".format(
38 | num_short_form(999999999))
39 | assert num_short_form(1000000000) == "1B", "Error: expected '1B', but got '{}'".format(
40 | num_short_form(1000000000))
41 | assert num_short_form(1500000000) == "1_5B", "Error: expected '1_5B', but got '{}'".format(
42 | num_short_form(1500000000))
43 | assert num_short_form(-1000) == "-1K", "Error: expected '-1K', but got '{}'".format(
44 | num_short_form(-1000))
45 | assert num_short_form(-1500) == "-1_5K", "Error: expected '-1_5K', but got '{}'".format(
46 | num_short_form(-1500))
47 | assert num_short_form(-1999) == "-2K", "Error: expected '-2K', but got '{}'".format(
48 | num_short_form(-1999))
49 | assert num_short_form(-1000000) == "-1M", "Error: expected '-1M', but got '{}'".format(
50 | num_short_form(-1000000))
51 | assert num_short_form(-1500000) == "-1_5M", "Error: expected '-1_5M', but got '{}'".format(
52 | num_short_form(-1500000))
53 | assert num_short_form(-999999999) == "-1000M", "Error: expected '-1000M', but got '{}'".format(
54 | num_short_form(-999999999))
55 | assert num_short_form(-1000000000) == "-1B", "Error: expected '-1B', but got '{}'".format(
56 | num_short_form(-1000000000))
57 | assert num_short_form(-1500000000) == "-1_5B", "Error: expected '-1_5B', but got '{}'".format(
58 | num_short_form(-1500000000))
59 | print("All tests passed")
60 |
61 |
62 | test_convert_num_to_str()
63 |
--------------------------------------------------------------------------------
/saves/view_vec.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec.npy
--------------------------------------------------------------------------------
/saves/view_vec2_dict:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict
--------------------------------------------------------------------------------
/saves/view_vec2_dict_jit:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict_jit
--------------------------------------------------------------------------------
/saves/view_vec2_dict_new:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_dict_new
--------------------------------------------------------------------------------
/saves/view_vec2_step1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vec2_step1.npy
--------------------------------------------------------------------------------
/saves/view_vecs_dict:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwhy/rwkv-decon/edb98117140f6a8e99a8fe250a1882a9e07a5650/saves/view_vecs_dict
--------------------------------------------------------------------------------