├── .deepsource.toml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── 280b.yaml ├── deep.yaml ├── shampoo.yaml └── small.yaml ├── main.py ├── requirements.txt └── src ├── dataclass.py ├── dataset.py ├── executable ├── inference.py ├── preprocess.py ├── profile.py └── train.py ├── model.py ├── optimizers ├── build.py ├── shampoo.py └── shampoo_utils.py └── utils ├── __init__.py ├── formatting.py ├── matrix_functions.py └── setup.py /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["*.py"] 4 | 5 | [[analyzers]] 6 | name = "python" 7 | enabled = true 8 | dependency_file_paths = ["requirements.txt"] 9 | 10 | [analyzers.meta] 11 | runtime_version = "3.x.x" 12 | max_line_length = 120 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Generated files 14 | .idea/**/contentModel.xml 15 | 16 | # Sensitive or high-churn files 17 | .idea/**/dataSources/ 18 | .idea/**/dataSources.ids 19 | .idea/**/dataSources.local.xml 20 | .idea/**/sqlDataSources.xml 21 | .idea/**/dynamic.xml 22 | .idea/**/uiDesigner.xml 23 | .idea/**/dbnavigator.xml 24 | 25 | # Gradle 26 | .idea/**/gradle.xml 27 | .idea/**/libraries 28 | 29 | # Gradle and Maven with auto-import 30 | # When using Gradle or Maven with auto-import, you should exclude module files, 31 | # since they will be recreated, and may cause churn. Uncomment if using 32 | # auto-import. 33 | # .idea/artifacts 34 | # .idea/compiler.xml 35 | # .idea/jarRepositories.xml 36 | # .idea/modules.xml 37 | # .idea/*.iml 38 | # .idea/modules 39 | # *.iml 40 | # *.ipr 41 | 42 | # CMake 43 | cmake-build-*/ 44 | 45 | # Mongo Explorer plugin 46 | .idea/**/mongoSettings.xml 47 | 48 | # File-based project format 49 | *.iws 50 | 51 | # IntelliJ 52 | out/ 53 | 54 | # mpeltonen/sbt-idea plugin 55 | .idea_modules/ 56 | 57 | # JIRA plugin 58 | atlassian-ide-plugin.xml 59 | 60 | # Cursive Clojure plugin 61 | .idea/replstate.xml 62 | 63 | # Crashlytics plugin (for Android Studio and IntelliJ) 64 | com_crashlytics_export_strings.xml 65 | crashlytics.properties 66 | crashlytics-build.properties 67 | fabric.properties 68 | 69 | # Editor-based Rest Client 70 | .idea/httpRequests 71 | 72 | # Android studio 3.1+ serialized cache file 73 | .idea/caches/build_file_checksums.ser 74 | 75 | ### C template 76 | # Prerequisites 77 | *.d 78 | 79 | # Object files 80 | *.o 81 | *.ko 82 | *.obj 83 | *.elf 84 | 85 | # Linker output 86 | *.ilk 87 | *.map 88 | *.exp 89 | 90 | # Precompiled Headers 91 | *.gch 92 | *.pch 93 | 94 | # Libraries 95 | *.lib 96 | *.a 97 | *.la 98 | *.lo 99 | 100 | # Shared objects (inc. 
Windows DLLs) 101 | *.dll 102 | *.so 103 | *.so.* 104 | *.dylib 105 | 106 | # Executables 107 | *.exe 108 | *.out 109 | *.app 110 | *.i*86 111 | *.x86_64 112 | *.hex 113 | 114 | # Debug files 115 | *.dSYM/ 116 | *.su 117 | *.idb 118 | *.pdb 119 | 120 | # Kernel Module Compile Results 121 | *.mod* 122 | *.cmd 123 | .tmp_versions/ 124 | modules.order 125 | Module.symvers 126 | Mkfile.old 127 | dkms.conf 128 | 129 | ### Python template 130 | # Byte-compiled / optimized / DLL files 131 | __pycache__/ 132 | *.py[cod] 133 | *$py.class 134 | 135 | # C extensions 136 | *.so 137 | 138 | # Distribution / packaging 139 | .Python 140 | build/ 141 | develop-eggs/ 142 | dist/ 143 | downloads/ 144 | eggs/ 145 | .eggs/ 146 | lib/ 147 | lib64/ 148 | parts/ 149 | sdist/ 150 | var/ 151 | wheels/ 152 | pip-wheel-metadata/ 153 | share/python-wheels/ 154 | *.egg-info/ 155 | .installed.cfg 156 | *.egg 157 | MANIFEST 158 | 159 | # PyInstaller 160 | # Usually these files are written by a python script from a template 161 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 162 | *.manifest 163 | *.spec 164 | 165 | # Installer logs 166 | pip-log.txt 167 | pip-delete-this-directory.txt 168 | 169 | # Unit test / coverage reports 170 | htmlcov/ 171 | .tox/ 172 | .nox/ 173 | .coverage 174 | .coverage.* 175 | .cache 176 | nosetests.xml 177 | coverage.xml 178 | *.cover 179 | *.py,cover 180 | .hypothesis/ 181 | .pytest_cache/ 182 | cover/ 183 | 184 | # Translations 185 | *.mo 186 | *.pot 187 | 188 | # Django stuff: 189 | *.log 190 | local_settings.py 191 | db.sqlite3 192 | db.sqlite3-journal 193 | 194 | # Flask stuff: 195 | instance/ 196 | .webassets-cache 197 | 198 | # Scrapy stuff: 199 | .scrapy 200 | 201 | # Sphinx documentation 202 | docs/_build/ 203 | 204 | # PyBuilder 205 | .pybuilder/ 206 | target/ 207 | 208 | # Jupyter Notebook 209 | .ipynb_checkpoints 210 | 211 | # IPython 212 | profile_default/ 213 | ipython_config.py 214 | 215 | # pyenv 216 | # For a library or package, you might want to ignore these files since the code is 217 | # intended to run in multiple environments; otherwise, check them in: 218 | # .python-version 219 | 220 | # pipenv 221 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 222 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 223 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 224 | # install all needed dependencies. 225 | #Pipfile.lock 226 | 227 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 228 | __pypackages__/ 229 | 230 | # Celery stuff 231 | celerybeat-schedule 232 | celerybeat.pid 233 | 234 | # SageMath parsed files 235 | *.sage.py 236 | 237 | # Environments 238 | .env 239 | .venv 240 | env/ 241 | venv/ 242 | ENV/ 243 | env.bak/ 244 | venv.bak/ 245 | 246 | # Spyder project settings 247 | .spyderproject 248 | .spyproject 249 | 250 | # Rope project settings 251 | .ropeproject 252 | 253 | # mkdocs documentation 254 | /site 255 | 256 | # mypy 257 | .mypy_cache/ 258 | .dmypy.json 259 | dmypy.json 260 | 261 | # Pyre type checker 262 | .pyre/ 263 | 264 | # pytype static type analyzer 265 | .pytype/ 266 | 267 | # Cython debug symbols 268 | cython_debug/ 269 | 270 | ### VirtualEnv template 271 | # Virtualenv 272 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 273 | .Python 274 | [Bb]in 275 | [Ii]nclude 276 | [Ll]ib 277 | [Ll]ib64 278 | [Ll]ocal 279 | [Ss]cripts 280 | pyvenv.cfg 281 | .venv 282 | pip-selfcheck.json 283 | 284 | .idea/ 285 | 286 | ### Project Specific Exclusions 287 | # Training Data 288 | out.tensor 289 | dataset.txt 290 | 291 | # Weights and biases autogenerated 292 | wandb/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, The HomebrewNLP Developers 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HomebrewNLP 2 | 3 | ## Overview 4 | 5 | A case study of efficient training of large language models using commodity hardware. 
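`main.py` exposes three subcommands: `preprocess`, `train`, and `inference` (the training invocation is shown under Example Command below). As a rough sketch of the surrounding workflow, with file names and sampling settings chosen for illustration rather than taken from the repo:

```BASH
# Convert raw text into the tensor file consumed by the data loader
python3 main.py preprocess --in_path data.txt --out_path out.tensor

# Sample interactively from a trained checkpoint (see `checkpoint_path` in the model config)
python3 main.py inference --config_path configs/small.yaml --generated_tokens 64 --temp 0.4
```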
6 | 7 | ## Example Command 8 | 9 | ```BASH 10 | python3 main.py train --config_path configs/small.yaml 11 | ``` 12 | 13 | --- 14 | [![DeepSource](https://deepsource.io/gh/HomebrewNLP/HomebrewNLP.svg/?label=active+issues&show_trend=true&token=sAQ42SRyNPilkjj82sQd88ea)](https://deepsource.io/gh/HomebrewNLP/HomebrewNLP/?ref=repository-badge) 15 | | [Discord](https://discord.gg/JSGG6Abcyx) 16 | | [WandB](https://wandb.ai/homebrewnlp/gpt) 17 | 18 | ## Datasets 19 | * [Book Dataset](https://drive.google.com/file/u/1/d/1aoW3KI2E3nK7B28RE6I6_oDtNidTvoc2/view?usp=sharing) 20 | * [200MB Slice](https://drive.google.com/file/d/1QTbRYe-BOq2kw8foWB16NGPthQjZr7yn/view?usp=sharing) of [ThePile](https://github.com/EleutherAI/the-pile) 21 | 22 | 23 | 24 | 25 | ## Citing 26 | 27 | ### BibTeX 28 | 29 | ```bibtex 30 | @misc{nestler2021homebrewnlp, 31 | title = {{HomebrewNLP}}, 32 | author = {Nestler, Lucas and Gill, David}, 33 | year = {2021}, 34 | publisher = {GitHub}, 35 | journal = {GitHub repository}, 36 | doi = {10.5281/zenodo.5553247}, 37 | howpublished = {\url{https://github.com/HomebrewNLP/HomebrewNLP}} 38 | } 39 | ``` 40 | 41 | ### Latest DOI 42 | 43 | [![DOI](https://zenodo.org/badge/279888521.svg)](https://zenodo.org/badge/latestdoi/279888521) 44 | 45 | 46 | -------------------------------------------------------------------------------- /configs/280b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | steps_per_checkpoint: 100000 3 | print_on_init: false 4 | depth: 1024 5 | offloading: true 6 | conv_kernel_size: 11 7 | weight_shared_blocks: 1 8 | batch_size: 4 9 | float16: yes 10 | feed_forward_intermediate_factor: 0.125 11 | features: 16384 12 | moe: 13 | use_in_input: false 14 | use_in_output: false 15 | num_experts: 8 16 | optimizer: 17 | beta2: 0.95 18 | gradient_accumulation_steps: 1 19 | one_cycle: 20 | cycle_first_step_size: 8192 21 | cycle_second_step_size: null 22 | cycle_min_lr: 0.0002 23 | cycle_max_lr: 0.002 24 | cycle_min_mom: 0.6 25 | cycle_max_mom: 0.9 26 | log: 27 | loss_steps_per_print: 1 28 | dataset: 29 | num_workers: 12 30 | -------------------------------------------------------------------------------- /configs/deep.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | steps_per_checkpoint: 32 3 | offloading: true 4 | print_on_init: false 5 | depth: 128 6 | conv_kernel_size: 11 7 | weight_shared_blocks: 1 8 | batch_size: 1 9 | feed_forward_intermediate_factor: 0.125 10 | features: 512 11 | moe: 12 | use_in_input: false 13 | use_in_output: true 14 | num_experts: 128 15 | optimizer: 16 | beta2: 0.95 17 | gradient_accumulation_steps: 1 18 | one_cycle: 19 | cycle_first_step_size: 8192 20 | cycle_second_step_size: null 21 | cycle_min_lr: 0.0001 22 | cycle_max_lr: 0.001 23 | cycle_min_mom: 0.6 24 | cycle_max_mom: 0.9 25 | log: 26 | loss_steps_per_print: 1 27 | -------------------------------------------------------------------------------- /configs/shampoo.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | depth: 32 3 | conv_kernel_size: 11 4 | weight_shared_blocks: 1 5 | batch_size: 32 6 | feed_forward_intermediate_factor: 0.125 7 | optimizer: 8 | type: Shampoo 9 | beta2: 0.90 10 | gradient_accumulation_steps: 1 11 | preconditioning_compute_steps: 4 12 | weight_decay: 0 13 | sharpness_aware_minimization: 14 | enabled: True 15 | one_cycle: 16 | cycle_first_step_size: 8192 17 | cycle_second_step_size: null 18 | cycle_momentum: False 19 | 
cycle_max_lr: 0.002 20 | log: 21 | loss_steps_per_print: 8 22 | dataset: 23 | num_workers: 4 24 | -------------------------------------------------------------------------------- /configs/small.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | depth: 32 3 | conv_kernel_size: 11 4 | weight_shared_blocks: 1 5 | batch_size: 1024 6 | feed_forward_intermediate_factor: 0.125 7 | optimizer: 8 | beta2: 0.95 9 | gradient_accumulation_steps: 1 10 | one_cycle: 11 | cycle_first_step_size: 8192 12 | cycle_second_step_size: null 13 | cycle_max_lr: 0.01 14 | log: 15 | loss_steps_per_print: 8 16 | dataset: 17 | num_workers: 12 18 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import typing 3 | 4 | import argh 5 | import yaml 6 | 7 | from src.dataclass import Context 8 | from src.executable.inference import inference_cli 9 | from src.executable.preprocess import preprocess_data 10 | from src.executable.train import train_model 11 | from src.utils.formatting import syntax_print 12 | from src.utils.setup import setup_torch 13 | 14 | 15 | def get_context(config_path: typing.Optional[str] = None) -> Context: 16 | ''' 17 | Loads context from provided config. Otherwise loads default. 18 | ''' 19 | if config_path is not None: 20 | config = pathlib.Path(config_path) 21 | assert config.suffix == '.yaml', 'Expected a .yaml file for config_path' 22 | ctx = Context(config_path=config) 23 | else: 24 | ctx = Context() 25 | return ctx 26 | 27 | 28 | @argh.arg('-i', '--in_path', default='data.txt', help='Path for data to be preprocessed') 29 | @argh.arg('-o', '--out_path', default='out.tensor', help='Path for data to be preprocessed') 30 | def preprocess(in_path: str = 'data.txt', out_path: str = "out.tensor"): 31 | ''' 32 | Processing original data into `out.tensor` 33 | ''' 34 | preprocess_data(in_path, out_path) 35 | 36 | 37 | @argh.arg('-c', '--config_path', default='configs/small.yaml', help='Path for the config file') 38 | @argh.arg('-s', '--steps', default=0, help='Number of steps to take. 0 = infinite') 39 | @argh.arg('-l', '--load_model', default=False, help='Whether to load an existing model checkpoint') 40 | def train(config_path: typing.Optional[str] = None, steps: int = 0, load_model: bool = False): 41 | ''' 42 | Trains a model given the config file. 43 | ''' 44 | ctx = get_context(config_path) 45 | setup_torch(0) 46 | 47 | dump = yaml.dump(ctx.serialize(), indent=4) 48 | syntax_print(dump, "yaml", title="Config") 49 | 50 | train_model(ctx, steps, load_model) 51 | 52 | 53 | @argh.arg('-g', '--generated_tokens', default='20', help='Number of tokens to be generated after prompt') 54 | @argh.arg('-t', '--temp', default='0.2', help='Temperature of the model.\nlower = consistency\nhigher = "creativity"') 55 | @argh.arg('-c', '--config_path', help='Path for the config file') 56 | def inference(generated_tokens: int = 20, temp: float = 0.2, config_path: str = None): 57 | ''' 58 | Runs inference of input data on desired model 59 | ''' 60 | assert config_path is not None, "Expected Config file!" 
61 | 62 | ctx = get_context(config_path) 63 | 64 | inference_cli(ctx, float(temp), int(generated_tokens)) 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argh.ArghParser() 69 | parser.add_commands([preprocess, train, inference]) 70 | parser.dispatch() 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy>=6.0.1 2 | mpi4py>=3.1.1 3 | revlib>=1.1.0 4 | deepspeed>=0.5.1 5 | rich>=10.9.0 6 | argh >=0.26.2 7 | numpy>=1.21.2 8 | wandb>=0.10.28 9 | PyYAML>=5.4.1 10 | -------------------------------------------------------------------------------- /src/dataclass.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import typing 3 | 4 | import torch 5 | import yaml 6 | 7 | 8 | class DataClass: 9 | def serialize(self): 10 | return serialize(self) 11 | 12 | 13 | def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): 14 | if isinstance(instance, DataClass): 15 | attributes = {key: getattr(instance, key) for key in dir(instance) 16 | if not key.startswith('_') and not key.endswith('_')} 17 | return serialize({key: value for key, value in attributes.items() if not isinstance(value, typing.Callable)}) 18 | return {k: serialize(v) if isinstance(v, DataClass) else v for k, v in instance.items()} 19 | 20 | 21 | class Model(DataClass): 22 | weight_sharing: bool = False 23 | checkpoint_path: str = "checkpoint.torch" 24 | steps_per_checkpoint: int = 0 # 0 -> disabled 25 | print_on_init: bool = True 26 | features: int = 256 27 | momentumnet_beta: float = 0.99 # The higher this is, the more numerically stable. BUT also lower impact per layer 28 | depth: int = 64 29 | batch_size: int = 128 30 | sequence_length: int = 256 31 | activation_std: float = 0.5893595616022745 # std(relu(torch.randn((inf,)))) == 0.5893595616022745 32 | input_embedding_std: float = 1. 33 | position_embedding_std: float = 1. 34 | float16: bool = False 35 | device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu' 36 | conv_kernel_size: int = 7 37 | feature_shuffle: bool = False 38 | feed_forward_intermediate_factor: float = 2. 39 | norm_power: int = 2 # 1 = mean(abs(x)), 2 = std, ... 40 | bottleneck_group: int = 1 # not all group counts are possible. 
it has to be divide self.features without residual 41 | offloading: bool = False 42 | input_groups: int = 1 43 | output_groups: int = 1 44 | experts_in_input: int = 0 # 0 to disable MoE 45 | experts_in_output: int = 0 46 | moe_jitter_epsilon: float = 0.02 47 | expert_chunks: int = 1 # Increase it if not all MoE parameters fit onto the GPU 48 | 49 | 50 | class Dataset(DataClass): 51 | file_name: str = "out.tensor" 52 | classes: int = 256 53 | num_workers: int = 4 54 | pin_memory: bool = False 55 | prefetch_factor: int = 256 # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM 56 | 57 | 58 | class WandB(DataClass): 59 | project: str = 'gpt' 60 | entity: str = 'homebrewnlp' 61 | model_log_type: typing.Optional[str] = None # One of "gradients", "parameters", "all", or None 62 | log_frequency: int = 1000 # log gradients and parameters every N batches 63 | 64 | 65 | class Log(DataClass): 66 | loss_steps_per_print: int = 32 # 0 -> off 67 | wandb: WandB = WandB() 68 | 69 | 70 | class Offload(DataClass): 71 | device: str = "cpu" 72 | pin_memory: bool = True 73 | 74 | 75 | class Zero(DataClass): 76 | cpu_offload: bool = True 77 | contiguous_gradients: bool = False 78 | overlap_comm: bool = True 79 | offload_param: Offload = Offload() 80 | offload_optimizer: Offload = Offload() 81 | stage3_max_live_parameters: float = 1 82 | stage3_max_reuse_distance: float = 1 83 | stage3_prefetch_bucket_size: float = 1 84 | stage3_param_persistence_threshold: float = 1 85 | 86 | 87 | class OneCycle(DataClass): 88 | cycle_min_lr: float = 3e-4 # Base learning rate used at the start and end of cycle. 89 | cycle_max_lr: float = 1e-3 # Learning rate used in the middle of the cycle. Can be smaller than cycle_min_lr 90 | decay_lr_rate: float = 1e-4 # Decay rate for learning rate. 91 | cycle_first_step_size: int = 2048 # Number of training iterations in the increasing half of a cycle. 92 | cycle_second_step_size: typing.Optional[int] = None # steps in second phase. None -> cycle_first_step_size 93 | cycle_first_stair_count: int = 0 # Number of stairs in first phase. 0 means staircase disabled 94 | cycle_second_stair_count: typing.Optional[int] = None # Number of stairs in second phase 95 | decay_step_size: int = 2 # Every how many steps to decay lr. 0 -> no decay 96 | cycle_momentum: bool = True # Whether to cycle `momentum` inversely to learning rate. 97 | cycle_min_mom: float = 0.8 # Initial momentum which is the lower boundary in the cycle for each parameter group. 98 | cycle_max_mom: float = 0.9 # Upper momentum boundaries in the cycle for each parameter group. 99 | decay_mom_rate: float = 0 # Decay rate for momentum 100 | last_batch_iteration: int = -1 # The index of the last batch. This parameter is used when resuming a training job. 
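# How configs map onto these dataclasses: `init_class` (defined further below) copies YAML keys onto
# matching attributes recursively, so e.g. configs/small.yaml's
#   optimizer:
#     one_cycle:
#       cycle_max_lr: 0.01
# ends up as ctx.optimizer.one_cycle.cycle_max_lr == 0.01 on the Context object.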
101 | 102 | 103 | class AdaptiveGradientClipping(DataClass): 104 | gradient_clipping: float = 0.01 105 | zero_division_eps: float = 1e-6 106 | eps: float = 1e-3 107 | 108 | 109 | class SharpnessAwareMinimization(DataClass): 110 | enabled: bool = True 111 | step_size: bool = 0.05 112 | adaptive: bool = True 113 | 114 | 115 | class Optimizer(DataClass): 116 | type: str = "AdamW" 117 | gradient_accumulation_steps: int = 1 118 | one_cycle: OneCycle = OneCycle() 119 | beta2: float = 0.95 # beta1 is controlled by one_cycle 120 | eps: float = 1e-8 121 | weight_decay: float = 0.01 122 | zero: Zero = Zero() 123 | agc = AdaptiveGradientClipping() 124 | sharpness_aware_minimization: SharpnessAwareMinimization = SharpnessAwareMinimization() 125 | 126 | # Shampoo hyper-params 127 | diagonal_eps: float = 1e-6 128 | matrix_eps: float = 1e-12 129 | inverse_exponent_override: int = 0 130 | start_preconditioning_step: int = 16 131 | preconditioning_compute_steps: int = 1 132 | statistics_compute_steps: int = 1 133 | block_size: int = 128 134 | best_effort_shape_interpretation: bool = True 135 | graft_type: str = 'adagrad' # 'Adagrad' or 'SGD' 136 | nesterov: bool = True 137 | no_preconditioning_for_layers_with_dim_gt: int = 8192 138 | 139 | 140 | class Eval(DataClass): 141 | cache: bool = False 142 | 143 | 144 | def init_class(instance: DataClass, config: typing.Dict[str, typing.Any]): 145 | for name in dir(instance): 146 | if name.startswith("_") or name.endswith("_") or name not in config: 147 | continue 148 | attr = getattr(instance, name) 149 | if isinstance(attr, DataClass): 150 | init_class(attr, config[name]) 151 | continue 152 | setattr(instance, name, config[name]) 153 | 154 | 155 | class Context(DataClass): 156 | def __init__(self, config: typing.Optional[typing.Dict[str, typing.Any]] = None, 157 | config_path: typing.Optional[pathlib.Path] = None): 158 | self.log = Log() 159 | self.optimizer = Optimizer() 160 | self.dataset = Dataset() 161 | self.model = Model() 162 | self.eval = Eval() 163 | self.wandb = WandB() 164 | 165 | if config_path is not None: 166 | config = yaml.safe_load(config_path.read_text()) 167 | 168 | if config is not None: 169 | init_class(self, config) 170 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | import torch.utils.data 5 | 6 | from src.dataclass import Context 7 | 8 | 9 | @torch.jit.script 10 | def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: 11 | dat = data[batch_index + idx] 12 | dat = dat.to(dtype=torch.long, non_blocking=True) 13 | return dat[:, :-1], dat[:, 1:] 14 | 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | def __init__(self, ctx: Context): 18 | self.data = torch.load(ctx.dataset.file_name) 19 | batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) 20 | item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) 21 | self.batch_index = batch_index + item_index 22 | self.length = self.data.size(0) - ctx.model.batch_size * ctx.model.sequence_length 23 | 24 | def __len__(self): 25 | return self.length 26 | 27 | def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: 28 | return get_sample(self.data, self.batch_index, idx) 29 | 30 | 31 | def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: 32 | if ctx.dataset.prefetch_factor < ctx.dataset.num_workers: 33 | 
print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})." 34 | f"Some workers will be idle at all times. Reducing num_workers ({ctx.dataset.num_workers}) to " 35 | f"prefetch_factor ({ctx.dataset.prefetch_factor}).") 36 | return torch.utils.data.DataLoader(Dataset(ctx), ctx.optimizer.gradient_accumulation_steps, True, 37 | num_workers=min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor), 38 | pin_memory=ctx.dataset.pin_memory, prefetch_factor=ctx.dataset.prefetch_factor) 39 | -------------------------------------------------------------------------------- /src/executable/inference.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | 5 | from src.dataclass import Context 6 | from src.model import LinearAttention 7 | from src.utils.setup import encode, decode, get_model 8 | 9 | 10 | def complete_batch(ctx: Context, model: LinearAttention, prompt: torch.Tensor, temperature: float, 11 | generated_tokens: int) -> typing.List[str]: 12 | batch, prompt_size = prompt.size() 13 | out = prompt 14 | for _ in range(prompt_size, prompt_size + generated_tokens): 15 | tmp = model(prompt)[:, :, -1] 16 | tmp += torch.rand_like(tmp).clamp(min=1e-9).log().neg().log() * (-temperature) 17 | new_item = torch.argmax(tmp, -1).view(batch, -1) 18 | out = prompt = torch.cat([out, new_item], -1) 19 | if ctx.eval.cache: 20 | prompt = new_item 21 | model.reset_cache() 22 | return [decode(o) for o in out.unbind(0)] 23 | 24 | 25 | def complete(ctx: Context, model: LinearAttention, prompt: str, temperature: float, generated_tokens: int) -> str: 26 | return complete_batch(ctx, model, encode(prompt).to(dtype=torch.long, device=ctx.model.device).view(1, -1), 27 | temperature, generated_tokens)[0] 28 | 29 | 30 | @torch.no_grad() 31 | def inference_cli(ctx: Context, temperature: float, generated_tokens: int): 32 | mod = get_model(ctx, True).model 33 | mod.eval() 34 | while True: 35 | try: 36 | prompt = input("Prompt: ") 37 | except KeyboardInterrupt: 38 | break 39 | print(complete(ctx, mod, prompt, temperature, generated_tokens)) 40 | -------------------------------------------------------------------------------- /src/executable/preprocess.py: -------------------------------------------------------------------------------- 1 | import ftfy 2 | 3 | import torch 4 | from src.utils.setup import encode 5 | 6 | 7 | def preprocess_data(in_path: str, out_path: str): 8 | # Todo: convert to pathlib and confirm paths existance 9 | with open(in_path, 'r', errors="ignore") as f: 10 | dat = f.read() 11 | dat = ftfy.fix_text(dat) 12 | torch.save(encode(dat), out_path) 13 | -------------------------------------------------------------------------------- /src/executable/profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.dataclass import Context 4 | from src.executable.train import train_model 5 | 6 | 7 | def main(ctx: Context, chrome_trace_path: str = "torch_trace", steps: int = 128): 8 | with torch.autograd.profiler.profile(use_cuda=True, use_cpu=False, use_kineto=True) as prof: 9 | train_model(ctx, steps) 10 | print(prof.key_averages()) 11 | if chrome_trace_path: 12 | prof.export_chrome_trace(chrome_trace_path) 13 | -------------------------------------------------------------------------------- /src/executable/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | 
4 | from src.dataclass import Context 5 | from src.dataset import get_dataset 6 | from src.utils.formatting import WandbLog 7 | from src.utils.setup import get_model 8 | 9 | 10 | def train_model(ctx: Context, steps=None, load_model: bool = False): 11 | wandb.init(project=ctx.log.wandb.project, entity=ctx.log.wandb.entity, config=ctx.serialize()) 12 | ctx = Context(wandb.config) 13 | 14 | data = get_dataset(ctx) 15 | data_len = len(data) 16 | data = iter(data) 17 | mod = get_model(ctx, load_model, next(data)[0]) 18 | wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency) 19 | 20 | log = WandbLog(ctx, data_len) 21 | mean_loss = torch.zeros([], device=ctx.model.device, dtype=torch.float16 if ctx.model.float16 else torch.float) 22 | mean_max_loss = mean_loss.clone() 23 | 24 | i = 0 25 | while True: 26 | i += 1 27 | 28 | mean_loss += mod.accumulated_step(next(data)) 29 | if ctx.optimizer.sharpness_aware_minimization.enabled: 30 | with torch.no_grad(): 31 | for p in mod.gradients(): 32 | if ctx.optimizer.sharpness_aware_minimization.adaptive: 33 | p.grad *= p.square() 34 | p.grad *= ctx.optimizer.sharpness_aware_minimization.step_size 35 | p.add_(p.grad) 36 | p.prev_step = p.grad 37 | p.grad = None 38 | mean_max_loss += mod.accumulated_step(next(data)) 39 | mod.optimizer.step() 40 | if ctx.optimizer.sharpness_aware_minimization.enabled: 41 | with torch.no_grad(): 42 | for p in mod.gradients(): 43 | p.sub_(p.prev_step) 44 | p.prev_step = None 45 | p.grad = None 46 | else: 47 | mod.zero_grad() 48 | mod.scheduler.step() 49 | for p in mod.optimizer.param_groups: # OneCycle resets beta2 to 0.990 50 | p['betas'] = p['betas'][0], mod.ctx.optimizer.beta2 51 | with torch.no_grad(): 52 | if mod.ctx.log.loss_steps_per_print and i % mod.ctx.log.loss_steps_per_print == 0: 53 | log(mean_loss, mean_max_loss, 54 | mod.optimizer.param_groups[0]['lr'], mod.optimizer.param_groups[0]['betas']) 55 | mean_loss.zero_() 56 | mean_max_loss.zero_() 57 | if mod.ctx.model.steps_per_checkpoint and i % mod.ctx.model.steps_per_checkpoint == 0: 58 | mod.save() 59 | if steps and i > steps: 60 | return 61 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import typing 3 | 4 | import numpy as np 5 | import revlib 6 | import torch 7 | import torch.utils.data 8 | from deepspeed.runtime import lr_schedules 9 | from torch.nn import functional as F 10 | 11 | from src.dataclass import Context 12 | from src.optimizers.build import build_optimizer 13 | 14 | QUAD_TENSOR = typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 15 | 16 | 17 | def orthonormal(inp: typing.Union[torch.Tensor, torch.nn.Parameter, typing.List[int]], gain: float): 18 | original_input = inp 19 | if isinstance(inp, list): 20 | inp = torch.zeros(inp) 21 | if isinstance(inp, torch.nn.Parameter): 22 | inp = inp.data 23 | flat_shape = (inp.shape[0], np.prod(inp.shape[1:])) 24 | a = torch.rand(flat_shape) 25 | u, _, v = torch.linalg.svd(a, full_matrices=False) 26 | inp.copy_((u if u.shape == flat_shape else v).reshape(inp.shape).mul(gain).to(device=inp.device, dtype=inp.dtype)) 27 | if isinstance(original_input, list): 28 | return torch.nn.Parameter(inp) 29 | return original_input 30 | 31 | 32 | class TripleNorm(torch.autograd.Function): 33 | @staticmethod 34 | def forward(ctx, scale0: torch.Tensor, scale1: torch.Tensor, shift: torch.Tensor, norm_power: int): 35 | scale0_relu 
= scale0.relu() 36 | inp = scale0_relu.pow(3) * scale1 + shift 37 | inp = inp - inp.mean(1, True) 38 | rstd = inp.size(1) ** (1 / norm_power) / inp.norm(norm_power, 1, True) 39 | inp *= rstd 40 | if scale1.requires_grad: 41 | ctx.save_for_backward(scale0_relu, scale1, inp, rstd) 42 | return inp 43 | 44 | @staticmethod 45 | def backward(ctx, dout: torch.Tensor): 46 | if not ctx.saved_tensors: 47 | return None, None, None, None 48 | scale0_relu, scale1, out, rstd = ctx.saved_tensors 49 | dout = dout * rstd 50 | dout -= (dout * out).mean(1, True) * out 51 | dout -= dout.mean(1, True) 52 | d_scale = dout * scale0_relu.square() 53 | return d_scale * scale1 * 3, d_scale * scale0_relu, dout, None 54 | 55 | 56 | def conv(inp: torch.Tensor, weight: torch.Tensor, groups: int, use_pad: bool) -> torch.Tensor: 57 | if use_pad and weight.size()[-1] - 1 > 0: 58 | inp = F.pad(inp, (weight.size()[-1] - 1, 0)) 59 | return F.conv1d(inp, weight, groups=groups) 60 | 61 | 62 | def expert_matmul(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 63 | return torch.einsum("bgf,gfo->bgo", inp, weight) 64 | 65 | 66 | class AuxLoss(torch.autograd.Function): 67 | @staticmethod 68 | def forward(ctx, inp: torch.Tensor): 69 | ctx.save_for_backward(inp) 70 | return inp 71 | 72 | @staticmethod 73 | def backward(ctx, grad_outputs: torch.Tensor): 74 | inp, = ctx.saved_tensors 75 | inp.mean().backward() 76 | 77 | 78 | def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: bool, 79 | jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor: 80 | *expert_weights, gate = expert_weights 81 | batch, features, sequence = inp.size() 82 | tokens = batch * sequence 83 | capacity = tokens // experts 84 | 85 | # get gates 86 | if gate.dtype != torch.float32: 87 | gate = gate.float() 88 | inp = inp.transpose(1, 2).reshape(tokens, features) 89 | input_fp32 = inp.float() 90 | if training: 91 | input_fp32 = input_fp32 * (torch.rand_like(input_fp32) * jitter_epsilon + 1) 92 | logits = input_fp32.mm(gate) 93 | gates = F.softmax(logits, dim=1) 94 | 95 | # calculate permutation 96 | with torch.no_grad(): 97 | mask = torch.ones_like(gates[:, 0]) 98 | out = [] 99 | for g in gates.unbind(1): 100 | _, idx = torch.topk(g * mask, capacity, 0) 101 | out.append(idx) 102 | mask[idx] = 0 103 | expert_permutation = torch.stack(out, 1) 104 | expert_permutation = expert_permutation.view(-1, 1).long() 105 | permutation_inverse = torch.argsort(expert_permutation, 0).view(-1, 1) 106 | expert_index = permutation_inverse // capacity 107 | 108 | # apply loss 109 | AuxLoss(gates.sum() / tokens) 110 | inp = inp * gates.gather(1, expert_index) 111 | 112 | # permute 113 | inp = inp.gather(0, expert_permutation.expand_as(inp)) 114 | 115 | if feature_shuffle is not None: 116 | inp = inp.gather(1, feature_shuffle.view(1, -1).expand_as(inp)) 117 | inp = inp.view(tokens // experts, experts * groups, features // groups) 118 | if len(expert_weights) == 1: 119 | inp = expert_matmul(inp, expert_weights[0]) 120 | else: 121 | inp = torch.cat([expert_matmul(c, w) for c, w in zip(inp.chunk(len(expert_weights), 1), expert_weights)], -1) 122 | inp = inp.reshape(tokens, -1) 123 | inp = inp.gather(0, permutation_inverse.view(-1, 1).expand_as(inp)) 124 | inp = inp.view(batch, sequence, -1).transpose(1, 2) 125 | return inp 126 | 127 | 128 | def moe_check(inp: torch.Tensor, w: torch.nn.ParameterList, training: bool, 129 | jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor: 
130 | if experts > 0: 131 | return moe(inp, w, training, jitter_epsilon, feature_shuffle, groups, experts) 132 | return conv(inp, w[0], groups, False) 133 | 134 | 135 | def linear_attention(inp: torch.Tensor, divisor: torch.Tensor, 136 | w0: torch.nn.ParameterList, 137 | feature_shuffle0: typing.Optional[torch.Tensor], groups0: int, experts0: int, 138 | w1: torch.Tensor, 139 | w2: torch.nn.ParameterList, 140 | feature_shuffle2: typing.Optional[torch.Tensor], groups2: int, experts2: int, 141 | input_cache: torch.Tensor, cumsum_cache: torch.Tensor, bottleneck_group: int, training: bool, 142 | caching: bool, idx: int, norm_power: int, jitter_epsilon: float 143 | ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 144 | kernel_size = w1.size(2) 145 | pad = True 146 | if not training and caching: 147 | if idx - 1 > kernel_size and inp.size(2) == 1: 148 | pad = False 149 | inp = torch.cat([input_cache, inp], -1) 150 | input_cache = inp[:, :, -kernel_size + 1:].detach() 151 | inp = moe_check(inp, w0, training, jitter_epsilon, feature_shuffle0, groups0, experts0) 152 | depth, scale, shift = inp.chunk(3, 1) 153 | cum = depth.cumsum(-1) 154 | if not training and caching: 155 | cum = cum + cumsum_cache 156 | scale = scale[:, :, -1:] 157 | shift = shift[:, :, -1:] 158 | cum = cum[:, :, -1:] 159 | if idx - 1 > kernel_size: 160 | cumsum_cache = cum.detach() 161 | inp = TripleNorm.apply(cum / divisor, scale, shift, norm_power) 162 | inp = conv(inp, w1, bottleneck_group, pad) 163 | inp = TripleNorm.apply(*inp.chunk(3, 1), norm_power) 164 | inp = moe_check(inp, w2, training, jitter_epsilon, feature_shuffle2, groups2, experts2) 165 | return input_cache, cumsum_cache, inp 166 | 167 | 168 | def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: int, std: float): 169 | return orthonormal(torch.nn.Conv1d(in_features, out_features, (kernel_size,), groups=groups).weight, 1 / std) 170 | 171 | 172 | class Trainer(torch.nn.Module): 173 | def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[torch.Tensor]): 174 | super(Trainer, self).__init__() 175 | self.ctx = ctx 176 | self.model = torch.jit.trace(model, data) if data else model 177 | self.optimizer = build_optimizer(ctx, self.model.parameters()) 178 | self.scheduler = lr_schedules.OneCycle(self.optimizer, 179 | ctx.optimizer.one_cycle.cycle_min_lr, 180 | ctx.optimizer.one_cycle.cycle_max_lr, 181 | ctx.optimizer.one_cycle.decay_lr_rate, 182 | ctx.optimizer.one_cycle.cycle_first_step_size, 183 | ctx.optimizer.one_cycle.cycle_second_step_size, 184 | ctx.optimizer.one_cycle.cycle_first_stair_count, 185 | ctx.optimizer.one_cycle.cycle_second_stair_count, 186 | ctx.optimizer.one_cycle.decay_step_size, 187 | ctx.optimizer.one_cycle.cycle_momentum, 188 | ctx.optimizer.one_cycle.cycle_min_mom, 189 | ctx.optimizer.one_cycle.cycle_max_mom, 190 | ctx.optimizer.one_cycle.decay_mom_rate, 191 | ctx.optimizer.one_cycle.last_batch_iteration) 192 | 193 | @torch.no_grad() 194 | def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: 195 | return inp.to(device=self.ctx.model.device, non_blocking=True).detach() 196 | 197 | def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: 198 | loss = F.cross_entropy(self.model(self._to_device_detach(src)), self._to_device_detach(tgt)) 199 | loss.backward() 200 | return loss.detach() 201 | 202 | @torch.no_grad() 203 | def _clip_gradient(self): 204 | for p in self.gradients(): 205 | g_norm = p.grad.norm(2, 0, 
True).clamp(min=self.ctx.optimizer.agc.zero_division_eps) 206 | p_norm = p.norm(2, 0, True).clamp(min=self.ctx.optimizer.agc.eps) 207 | grad_scale = (p_norm / g_norm * self.ctx.optimizer.agc.gradient_clipping).clamp(max=1) 208 | p.grad.data.copy_(p.grad * grad_scale) 209 | 210 | def accumulated_step(self, data: torch.Tensor) -> torch.Tensor: 211 | loss = sum(self._forward_backward(s, t) for s, t in zip(*data)) 212 | self._clip_gradient() 213 | return loss 214 | 215 | @torch.no_grad() 216 | def zero_grad(self): 217 | for p in self.model.parameters(): 218 | p.grad = None 219 | 220 | @torch.no_grad() 221 | def gradients(self) -> torch.nn.Parameter: 222 | for p in self.model.parameters(): 223 | if p.grad is None: 224 | continue 225 | yield p 226 | 227 | def save(self): 228 | torch.save(self.state_dict(), self.ctx.model.checkpoint_path) 229 | 230 | def load(self): 231 | wrong_keys = self.load_state_dict(torch.load(self.ctx.model.checkpoint_path), strict=False) 232 | for key in wrong_keys.missing_keys + wrong_keys.unexpected_keys: 233 | if not any(k.startswith('_') for k in key.split('.')): 234 | if key in wrong_keys.missing_keys: 235 | raise ValueError(f"{key} is missing in checkpoint but exists in model") 236 | if key in wrong_keys.unexpected_keys: 237 | raise ValueError(f"{key} is missing in model but exists in checkpoint") 238 | 239 | 240 | class LinearAttention(torch.nn.Module): 241 | def __init__(self, ctx: Context): 242 | super(LinearAttention, self).__init__() 243 | self.embedding = torch.nn.Embedding(ctx.dataset.classes, ctx.model.features * 2).to(ctx.model.device) 244 | orthonormal(self.embedding.weight, ctx.model.input_embedding_std * 2 ** -0.5) 245 | 246 | pos_embd = torch.arange(0, ctx.model.sequence_length).unsqueeze(0) + 1 247 | self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device)) 248 | 249 | cell = LinearAttentionCell(self, ctx, 1) 250 | self.stem = revlib.utils.momentum_net(*[copy.deepcopy(cell) for _ in range(ctx.model.depth)], 251 | target_device=ctx.model.device) 252 | self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) 253 | torch.nn.init.zeros_(self.output.weight.data) 254 | 255 | def forward(self, inp: torch.Tensor): 256 | return self.output(self.stem(self.embedding(inp).transpose(1, 2))) 257 | 258 | def reset_cache(self): 259 | for mod in self.stem.modules(): 260 | if isinstance(mod, LinearAttentionCell): 261 | mod.reset_cache() 262 | 263 | 264 | class ParameterStore(torch.nn.Module): 265 | """ 266 | Something (likely deepspeed) changes all parameters in a ParameterList to [1] even though standalone parameters 267 | work. That's why a torch.nn.ModuleList of ParameterStores needs to be initialized. 
268 | """ 269 | 270 | def __init__(self, param: torch.Tensor): 271 | super(ParameterStore, self).__init__() 272 | self.param = torch.nn.Parameter(param) 273 | 274 | def __repr__(self): 275 | return (f'{self.__class__.__name__}(shape={str(list(self.param.size()))}, device={self.param.device}, ' 276 | f'dtype={self.param.dtype})') 277 | 278 | 279 | def get_moe_param(in_features: int, out_features: int, groups: int, experts: int, expert_chunks: int, std: float 280 | ) -> typing.List[torch.nn.Parameter]: 281 | if experts: 282 | experts = groups if experts < 0 else experts 283 | out = orthonormal([in_features // groups, out_features // groups], std).view(1, in_features // groups, -1) 284 | out = out.repeat(experts // expert_chunks * groups, 1, 1).detach() 285 | gate = [orthonormal([in_features, experts], 1)] 286 | return [torch.nn.Parameter(copy.deepcopy(out)) for _ in range(expert_chunks)] + gate 287 | return [torch.nn.Parameter(conv_weight(in_features, out_features, 1, groups, std))] 288 | 289 | 290 | class LinearAttentionCell(torch.nn.Module): 291 | def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): 292 | super(LinearAttentionCell, self).__init__() 293 | self.divisor = lambda: base.divisor 294 | self.init_scale = init_scale 295 | self.caching = ctx.eval.cache 296 | self.kernel_size = ctx.model.conv_kernel_size 297 | self.bottleneck_group = ctx.model.bottleneck_group 298 | self.norm_power = ctx.model.norm_power 299 | self.groups0 = ctx.model.input_groups 300 | self.groups2 = ctx.model.output_groups 301 | self.experts0 = ctx.model.experts_in_input 302 | self.experts2 = ctx.model.experts_in_output 303 | self.jitter_epsilon = ctx.model.moe_jitter_epsilon 304 | self.expert_chunks = ctx.model.expert_chunks 305 | intermediate = int(ctx.model.features * ctx.model.feed_forward_intermediate_factor) 306 | self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features, intermediate * 3, self.groups0, 307 | self.experts0, self.expert_chunks, ctx.model.activation_std)) 308 | self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size, ctx.model.bottleneck_group, 309 | ctx.model.activation_std) 310 | self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features, self.groups2, 311 | self.experts2, self.expert_chunks, 1)) 312 | self.idx: int = 0 313 | self._input_cache = torch.zeros([]) 314 | self._cumsum_cache = torch.zeros([]) 315 | if ctx.model.feature_shuffle: 316 | self.register_buffer("feature_shuffle0", torch.argsort(torch.randn(ctx.model.features)).view(1, -1, 1)) 317 | self.register_buffer("feature_shuffle2", torch.argsort(torch.randn(intermediate)).view(1, -1, 1)) 318 | else: 319 | self.feature_shuffle0 = None 320 | self.feature_shuffle2 = None 321 | 322 | def reset_cache(self): 323 | self._cumsum_cache = torch.zeros([]) 324 | self._input_cache = torch.zeros([]) 325 | self.idx = 0 326 | 327 | def forward(self, inp: torch.Tensor) -> torch.Tensor: 328 | if self.training: 329 | div = self.divisor() 330 | elif self.caching: 331 | self.idx += inp.size(2) 332 | div = torch.LongTensor([self.idx]).to(inp.device) 333 | else: 334 | self.idx = inp.size(2) 335 | div = torch.arange(self.idx, device=inp.device).view(1, 1, -1) + 1 336 | self._input_cache, self._cumsum_cache, out = linear_attention(inp, div, 337 | self.w0, self.feature_shuffle0, self.groups0, 338 | self.experts0, 339 | self.w1, 340 | self.w2, self.feature_shuffle2, self.groups2, 341 | self.experts2, self._input_cache, 342 | self._cumsum_cache, self.bottleneck_group, 343 | 
self.training, self.caching, self.idx, 344 | self.norm_power, self.jitter_epsilon 345 | ) 346 | out = out * self.init_scale 347 | return out 348 | -------------------------------------------------------------------------------- /src/optimizers/build.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | 4 | import deepspeed.ops.adam 5 | import torch 6 | 7 | from src.dataclass import Context 8 | from src.optimizers import shampoo 9 | 10 | OWN_OPTIMIZER = {'Shampoo': shampoo.Shampoo} 11 | LIB_OPTIMIZER = {'DeepSpeedCPUAdam': deepspeed.ops.adam.DeepSpeedCPUAdam} 12 | 13 | 14 | def build_optimizer(ctx: Context, parameters: typing.Iterable[torch.nn.Parameter]): 15 | opt_type = ctx.optimizer.type 16 | if opt_type in OWN_OPTIMIZER: 17 | return OWN_OPTIMIZER[opt_type](parameters, ctx.optimizer) 18 | opt = LIB_OPTIMIZER[opt_type] if opt_type in LIB_OPTIMIZER else getattr(torch.optim, opt_type) 19 | if torch.optim.Optimizer not in inspect.getmro(opt): 20 | raise ValueError("Optimizer must inherit from 'torch.optim.Optimizer'.") 21 | params = {key: getattr(ctx.optimizer, key) for key in inspect.signature(opt).parameters.keys() 22 | if key in ctx.optimizer.serialize()} 23 | return opt(parameters, **params) 24 | -------------------------------------------------------------------------------- /src/optimizers/shampoo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Modifications and additional copyright (c) 2021, The HomebrewNLP Developers, 17 | # under the BSD 2-Clause License. Full license available at: 18 | # 19 | # https://raw.githubusercontent.com/HomebrewNLP/HomebrewNLP/master/LICENSE 20 | # 21 | """Pytorch implementation of Shampoo.""" 22 | 23 | from __future__ import print_function 24 | 25 | import itertools 26 | 27 | import numpy as np 28 | import torch 29 | import torch.optim as optim 30 | 31 | from src.dataclass import Optimizer 32 | from src.utils.matrix_functions import ComputePower 33 | 34 | 35 | class BlockPartitioner: 36 | """Partitions a tensor into smaller tensors for preconditioning. 37 | 38 | For example, if a variable has shape (4096, 512), we might split the 39 | 4096 into 4 blocks, so we effectively have 4 variables of size 40 | (1024, 512) each. 41 | """ 42 | 43 | def __init__(self, var, hps): 44 | self._shape = var.shape 45 | self._splits = [] 46 | self._split_sizes = [] 47 | split_sizes = [] 48 | # We split var into smaller blocks. Here we store the metadata to make 49 | # that split. 50 | for i, d in enumerate(var.shape): 51 | if hps.block_size > 0 and d > hps.block_size: 52 | # d-1, otherwise split appends a 0-size array. 
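# e.g. d = 300 with block_size = 128 gives nsplit = 2, indices = [128, 256], sizes = [128, 128, 44]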
53 | nsplit = (d - 1) // hps.block_size 54 | indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size 55 | sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size 56 | sizes[-1] = d - indices[-1] 57 | self._splits.append((i, indices)) 58 | self._split_sizes.append((i, sizes)) 59 | split_sizes.append(sizes) 60 | else: 61 | split_sizes.append(np.array([d], dtype=np.int32)) 62 | self._num_splits = len(split_sizes) 63 | self._preconditioner_shapes = [] 64 | for t in itertools.product(*split_sizes): 65 | self._preconditioner_shapes.extend([[d, d] for d in t]) 66 | 67 | def shapes_for_preconditioners(self): 68 | return self._preconditioner_shapes 69 | 70 | def num_splits(self): 71 | return self._num_splits 72 | 73 | def partition(self, tensor): 74 | """Partition tensor into blocks.""" 75 | 76 | if tensor.shape != self._shape: 77 | raise ValueError('Grad shape != var shape. X has shape \ 78 | of {}; Y shape is {}.'.format(str(tensor.shape), str(self._shape))) 79 | tensors = [tensor] 80 | for (i, sizes) in self._split_sizes: 81 | tensors_local = [] 82 | for t in tensors: 83 | tensors_local.extend( 84 | torch.split(t, tuple(sizes), dim=i)) 85 | tensors = tensors_local 86 | return tensors 87 | 88 | def merge_partitions(self, partitions): 89 | """Merge partitions back to original shape.""" 90 | 91 | for (i, indices) in reversed(self._splits): 92 | n = len(indices) + 1 93 | partial_merged_tensors = [] 94 | ind = 0 95 | while ind < len(partitions): 96 | partial_merged_tensors.append( 97 | torch.cat(partitions[ind:ind + n], axis=i)) 98 | ind += n 99 | partitions = partial_merged_tensors 100 | if len(partitions) != 1: 101 | raise ValueError('Partition merged failed.') 102 | return partitions[0] 103 | 104 | 105 | def _merge_small_dims(shape_to_merge, max_dim): 106 | """Merge small dimensions. 107 | 108 | If there are some small dimensions, we collapse them: 109 | e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 110 | [1, 2, 768, 1, 2048] --> [2, 768, 2048] 111 | 112 | Args: 113 | shape_to_merge: Shape to merge small dimensions. 114 | max_dim: Maximal dimension of output shape used in merging. 115 | 116 | Returns: 117 | Merged shape. 
118 | """ 119 | resulting_shape = [] 120 | product = 1 121 | for d in shape_to_merge: 122 | if product * d <= max_dim: 123 | product *= d 124 | else: 125 | if product > 1: 126 | resulting_shape.append(product) 127 | product = d 128 | if product > 1: 129 | resulting_shape.append(product) 130 | return resulting_shape 131 | 132 | 133 | class Preconditioner: 134 | """Compute statistics/shape from gradients for preconditioning.""" 135 | 136 | def __init__(self, var, hps): 137 | self._hps = hps 138 | self._original_shape = var.shape 139 | self._transformed_shape = var.shape 140 | 141 | if hps.best_effort_shape_interpretation: 142 | self._transformed_shape = _merge_small_dims( 143 | self._original_shape, hps.block_size) 144 | 145 | reshaped_var = torch.reshape(var, self._transformed_shape) 146 | self._partitioner = BlockPartitioner(reshaped_var, hps) 147 | shapes = self._partitioner.shapes_for_preconditioners() 148 | rank = len(self._transformed_shape) 149 | device = var.get_device() 150 | if rank <= 1: 151 | self.statistics = [] 152 | self.preconditioners = [] 153 | else: 154 | eps = self._hps.matrix_eps 155 | self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes] 156 | self.preconditioners = [torch.eye(s[0], device=device) for s in shapes] 157 | 158 | def add_statistics(self, beta2, gradient): 159 | """Compute statistics from gradients and add to the correct state entries. 160 | 161 | Args: 162 | grad: Gradient to compute statistics from. 163 | """ 164 | if not self.statistics: 165 | return 166 | reshaped_grad = torch.reshape(gradient, self._transformed_shape) 167 | partitioned_grads = self._partitioner.partition(reshaped_grad) 168 | w2 = 1.0 if beta2 == 1.0 else (1.0 - beta2) 169 | rank = len(self._transformed_shape) 170 | for j, grad in enumerate(partitioned_grads): 171 | for i in range(rank): 172 | axes = list(range(i)) + list(range(i + 1, rank)) 173 | stat = torch.tensordot(grad, grad, [axes, axes]) 174 | self.statistics[j * rank + i].mul_(beta2).add_(stat, alpha=w2) 175 | 176 | def exponent_for_preconditioner(self): 177 | """Returns exponent to use for inverse-pth root M^{-1/p}.""" 178 | if self._hps.inverse_exponent_override > 0: 179 | return self._hps.inverse_exponent_override 180 | return 2 * len(self._transformed_shape) 181 | 182 | def compute_preconditioners(self): 183 | """Compute L^{-1/exp} for each stats matrix L.""" 184 | exp = self.exponent_for_preconditioner() 185 | eps = self._hps.matrix_eps 186 | for i, stat in enumerate(self.statistics): 187 | self.preconditioners[i] = ComputePower( 188 | stat, exp, ridge_epsilon=eps) 189 | 190 | def preconditioned_grad(self, gradient): 191 | """Precondition the gradient. 192 | 193 | Args: 194 | grad: A gradient tensor to precondition. 195 | 196 | Returns: 197 | A preconditioned gradient. 
198 | """ 199 | if not self.preconditioners: 200 | return gradient 201 | reshaped_grad = torch.reshape(gradient, self._transformed_shape) 202 | partitioned_grads = self._partitioner.partition(reshaped_grad) 203 | preconditioned_partitioned_grads = [] 204 | num_splits = self._partitioner.num_splits() 205 | for i, grad in enumerate(partitioned_grads): 206 | preconditioners_for_grad = self.preconditioners[ 207 | i * num_splits: (i + 1) * num_splits] 208 | rank = len(grad.shape) 209 | precond_grad = grad 210 | for j in range(rank): 211 | preconditioner = preconditioners_for_grad[j] 212 | precond_grad = torch.tensordot( 213 | precond_grad, preconditioner, [[0], [0]]) 214 | preconditioned_partitioned_grads.append(precond_grad) 215 | merged_grad = self._partitioner.merge_partitions( 216 | preconditioned_partitioned_grads) 217 | return torch.reshape(merged_grad, self._original_shape) 218 | 219 | 220 | STEP = 'step' 221 | MOMENTUM = 'momentum' 222 | PRECONDITIONER = 'preconditioner' 223 | GRAFT = 'graft' 224 | 225 | 226 | class Shampoo(optim.Optimizer): 227 | """The Shampoo optimizer, configured for use in the 228 | HomebrewNLP linear attention model. This class is 229 | passed model parameters and an input context loaded 230 | from the model config which controls all relevant 231 | hyper-parameters. See "/configs/shampoo.yaml" for an 232 | example Shampoo config. 233 | 234 | Configurable Hyperparameters and default values: 235 | 236 | beta2: float = 0.99 - Parameter for exponential 237 | moving average of Shampoo second 238 | moment statistics. If set == 1.0, 239 | then sums statistics instead of 240 | moving average. 241 | diagonal_eps: float = 1e-6 - Only set if using 242 | Layerwise grafting mode to adagrad. 243 | This is the epsilon for adagrad 244 | updates. 245 | matrix_eps: float = 1e-12 - Epsilon to add to 246 | statistics before computing inverse 247 | pth root. Max of 1e-6 for float32 248 | weight_decay: float = 0.0 249 | inverse_exponent_override: int = 0 - fixed exponent 250 | for preconditioner, if >0 251 | start_preconditioning_step - Performance tuning params 252 | for controlling memory & compute. 253 | preconditioning_compute_steps - How often to compute 254 | preconditioner. 255 | statistics_compute_steps: int = 1 - How often to 256 | compute statistics. 257 | block_size: int = 128 - Block size for large 258 | layers (if > 0). Block size = 1 259 | is equivalent to Adagrad (but is 260 | extremely inefficient!) 261 | Block size should be as large as 262 | feasible under memory/time 263 | constraints. 264 | no_preconditioning_for_layers_with_dim_gt: int = 8192 265 | Avoids preconditioning large 266 | layers to reduce overall 267 | memory usage. 268 | best_effort_shape_interpretation: bool = True - 269 | Automatic shape interpretation 270 | (for eg: [4, 3, 1024, 512] would 271 | result in 12 x [1024, 512] L 272 | and R statistics. 273 | graft_type: str = 'Adagrad' - 274 | Type of grafting (SGD or AdaGrad). 
275 | nesterov: bool = True 276 | """ 277 | 278 | def __init__(self, params, ctx: Optimizer): 279 | self.hps = ctx 280 | super(Shampoo, self).__init__(params, {"betas": [0, self.hps.beta2], 281 | 'lr': 1, 282 | 'weight_decay': ctx.weight_decay, 283 | 'eps': ctx.eps}) 284 | 285 | def _use_preconditioner(self, var): 286 | return len(var.shape) > 0 and all(s <= self.hps.no_preconditioning_for_layers_with_dim_gt for s in var.shape) 287 | 288 | @torch.no_grad() 289 | def step(self, closure=None): 290 | hps = self.hps 291 | for group in self.param_groups: 292 | lr = group['lr'] 293 | for p in group['params']: 294 | if p.grad is None: 295 | continue 296 | grad = p.grad.data 297 | if grad.is_sparse: 298 | raise RuntimeError('Shampoo does not support sparse yet') 299 | state = self.state[p] 300 | if not state: 301 | state[STEP] = 0 302 | state[MOMENTUM] = torch.zeros_like(p.data, device=p.get_device()) 303 | state[GRAFT] = torch.zeros_like(p.data, device=p.get_device()) 304 | if self._use_preconditioner(p): 305 | state[PRECONDITIONER] = Preconditioner(p, self.hps) 306 | state[STEP] += 1 307 | 308 | # Gather statistics, compute preconditioners 309 | 310 | # Precondition gradients 311 | shampoo_grad = grad 312 | if self.hps.graft_type == 'adagrad': 313 | state[GRAFT].add_(grad.square()) 314 | if self._use_preconditioner(p): 315 | preconditioner = state[PRECONDITIONER] 316 | if state[STEP] % hps.statistics_compute_steps == 0: 317 | preconditioner.add_statistics(group['betas'][1], grad) 318 | if state[STEP] % hps.preconditioning_compute_steps == 0: 319 | preconditioner.compute_preconditioners() 320 | if state[STEP] >= self.hps.start_preconditioning_step: 321 | shampoo_grad = preconditioner.preconditioned_grad(grad) 322 | 323 | # Grafting 324 | graft_grad = grad 325 | if self.hps.graft_type == 'adagrad': 326 | graft_grad = grad / (torch.sqrt(state[GRAFT]) + self.hps.diagonal_eps) 327 | graft_norm = torch.norm(graft_grad) 328 | shampoo_norm = torch.norm(shampoo_grad) 329 | shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16)) 330 | 331 | # Momentum and Nesterov momentum, if needed 332 | state[MOMENTUM].mul_(group['betas'][0]).add_(shampoo_grad) 333 | graft_momentum = grad 334 | if self.hps.graft_type == 'sgd': 335 | graft_momentum = state[GRAFT].mul_(group['betas'][0]).add_(grad) 336 | 337 | momentum_update = graft_momentum 338 | wd_update = graft_grad 339 | if state[STEP] >= self.hps.start_preconditioning_step and self._use_preconditioner(p): 340 | momentum_update = state[MOMENTUM] 341 | wd_update = shampoo_grad 342 | 343 | if hps.nesterov: 344 | momentum_update.mul_(group['betas'][0]).add_(wd_update) 345 | 346 | # Final update 347 | momentum_update.add_(p, alpha=group['weight_decay']) 348 | p.data.add_(momentum_update, alpha=-lr) 349 | -------------------------------------------------------------------------------- /src/optimizers/shampoo_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Additional material copyright (c) 2021, The HomebrewNLP Developers, 17 | # licensed under the BSD 2-Clause License. Full license available at: 18 | # 19 | # https://raw.githubusercontent.com/HomebrewNLP/HomebrewNLP/master/LICENSE 20 | # 21 | # Unless required by applicable law or agreed to in writing, software 22 | # distributed under the License is distributed on an "AS IS" BASIS, 23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | # See the License for the specific language governing permissions and 25 | # limitations under the License. 26 | 27 | """Some utility numpy/pytorch functions to reshape and split variables.""" 28 | 29 | from __future__ import print_function 30 | 31 | import math 32 | 33 | import numpy as np 34 | import torch 35 | 36 | VAR_SHAPE = 'var_shape' 37 | VAR_SPLITS = 'var_splits' 38 | 39 | 40 | def merge_small_dims(var_shape, reshape_size): 41 | """Computes the shape of the variable for preconditioning. 42 | 43 | If the variable has several small dimensions, we can reshape it so 44 | that there are fewer big ones. e.g for a convolution (512, 5, 5, 1024) 45 | we could reshape it into (512, 25, 1024). 46 | 47 | Args: 48 | var_shape: the shape of the variable 49 | reshape_size: maximum size of a reshaped dimension 50 | Returns: 51 | shape: a list of integers. Product(shape) = number of elements in var. 52 | """ 53 | shape = [] 54 | product = 1 55 | for d in var_shape: 56 | if product * d <= reshape_size: 57 | product *= d 58 | else: 59 | if product > 1: 60 | shape.append(product) 61 | product = d 62 | if product > 1: 63 | shape.append(product) 64 | return shape 65 | 66 | 67 | def compute_splits(var_shape, block_size): 68 | """Splits larger dimensions into smaller ones, for preconditioning. 69 | 70 | For example, if a variable has shape (4096, 512), we might split the 71 | 4096 into 4 blocks, so we effectively have 4 variables of size 72 | (1024, 512) each. 73 | 74 | Args: 75 | var_shape: list of integers, the shape to be split 76 | block_size: the maximum dimension of each block 77 | Returns: 78 | splits: set of tuples (i, split) if the i-th dimension should be split 79 | split_sizes: an array of tuples, one per dimension, each indicating how 80 | to split that dimension. 
81 | """
82 | splits = []
83 | split_sizes = []
84 | for i, d in enumerate(var_shape):
85 | if d > block_size > 0:
86 | nsplit = math.ceil(d / block_size)
87 | sizes = np.ones(nsplit, dtype=np.int32) * block_size
88 | if d % block_size > 0:
89 | sizes[-1] = d % block_size
90 | splits.append((i, tuple(sizes)))
91 | split_sizes.append(sizes)
92 | else:
93 | split_sizes.append(np.array([d], dtype=np.int32))
94 | return splits, split_sizes
95 | 
96 | 
97 | def split_grad(state, gradient):
98 | """Split up the gradient according to the blocking strategy."""
99 | if len(state[VAR_SHAPE]) < len(list(gradient.shape)):
100 | gradient = torch.reshape(gradient, state[VAR_SHAPE])  # reshape to the merged shape before splitting
101 | grads = [gradient]
102 | for i, split_sizes in state[VAR_SPLITS]:
103 | split_grads = []
104 | for grad in grads:
105 | split_grads.extend(torch.split(grad, split_sizes, dim=i))
106 | grads = split_grads
107 | return grads
108 | 
109 | 
110 | def merge_grads(state, grads):
111 | """Merge the split gradients back into a single array."""
112 | for i, split_sizes in reversed(state[VAR_SPLITS]):
113 | n = len(split_sizes)
114 | conc_grads = []
115 | ind = 0
116 | while ind < len(grads):
117 | conc_grads.append(torch.cat(grads[ind:ind + n], axis=i))
118 | ind += n
119 | grads = conc_grads
120 | if len(grads) != 1:
121 | raise ValueError('Grad merge failed.')
122 | return grads[0]
123 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HomebrewML/HomebrewNLP-torch/0fd6e5a0d204df85dcea1516595eb2c2c3521bc3/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/formatting.py:
--------------------------------------------------------------------------------
1 | import time
2 | import typing
3 | from typing import Optional
4 | 
5 | import torch
6 | import wandb
7 | from rich import print as rich_print
8 | from rich.console import Console
9 | from rich.syntax import Syntax
10 | 
11 | from src.dataclass import Context
12 | 
13 | # Color coded tracebacks
14 | # install(show_locals=False, extra_lines=0)
15 | console = Console()
16 | 
17 | 
18 | # TODO: Allow for users to choose theme
19 | def syntax_print(string: str, language: Optional[str] = "python", theme: Optional[str] = "monokai",
20 | title: Optional[str] = None) -> None:
21 | if title is not None:
22 | console.rule(title)
23 | syntax = Syntax(string, language, theme=theme, line_numbers=True)
24 | console.print(syntax)
25 | 
26 | 
27 | def pretty_print(*data):
28 | rich_print(*data)
29 | 
30 | 
31 | def log(*data, log_locals: bool = False):
32 | console.log(*data, log_locals=log_locals)
33 | 
34 | 
35 | class WandbLog:
36 | def __init__(self, ctx: Context, steps: int):
37 | self.mean_loss = 0
38 | self.mean_max_loss = 0
39 | self.start_time = time.time()
40 | self.ctx = ctx
41 | self.idx = 0
42 | self.prev = 0
43 | self.steps = steps
44 | 
45 | def __call__(self, current_loss: torch.Tensor, max_loss: torch.Tensor, learning_rate: float,
46 | betas: typing.Tuple[float, float]):
47 | grad_accum = self.ctx.optimizer.gradient_accumulation_steps
48 | curr_loss = current_loss.item() / self.ctx.log.loss_steps_per_print / grad_accum
49 | curr_max_loss = max_loss.item() / self.ctx.log.loss_steps_per_print / grad_accum
50 | self.idx += 1
51 | self.mean_loss = (self.mean_loss * self.prev + curr_loss * self.idx) / (self.prev + self.idx) # LWMA
52 | mean_max = self.mean_max_loss = 
(self.mean_max_loss * self.prev + max_loss * self.idx) / (self.prev + self.idx) 53 | self.prev += self.idx 54 | 55 | rate = self.ctx.log.loss_steps_per_print * self.idx / (time.time() - self.start_time) 56 | tokens_per_day = grad_accum * 3600 * 24 * rate * self.ctx.model.batch_size * self.ctx.model.sequence_length 57 | 58 | pretty_print(f"[{self.idx * self.ctx.log.loss_steps_per_print:{len(str(self.steps))}d}/{self.steps}]", 59 | f"Loss: {curr_loss:7.4f} -", 60 | f"Mean: {self.mean_loss:7.4f} |", 61 | f"LR: {learning_rate:.6f} -", 62 | f"Beta1: {betas[0]:.3f} -", 63 | f"Beta2: {betas[1]:.3f} |", 64 | f"Batch/s: {rate:6.3f} -", 65 | f"Tokens/day: {tokens_per_day:11,.0f}") 66 | 67 | if not self.ctx.optimizer.sharpness_aware_minimization.enabled: 68 | curr_max_loss = None 69 | mean_max = None 70 | wandb.log({"Loss/Current": curr_loss, 71 | "Loss/Mean": self.mean_loss, 72 | "Loss/Current Max": curr_max_loss, 73 | "Loss/Mean Max": mean_max, 74 | "Speed/Batches per Second": rate, 75 | "Speed/Tokens per Day": tokens_per_day, 76 | "Optimizer/Learning Rate": learning_rate, 77 | "Optimizer/Beta1": betas[0], 78 | "Optimizer/Beta2": betas[1]}, 79 | step=self.idx * self.ctx.log.loss_steps_per_print) 80 | -------------------------------------------------------------------------------- /src/utils/matrix_functions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Additional material copyright (c) 2021, The HomebrewNLP Developers, 17 | # licensed under the BSD 2-Clause License. Full license available at: 18 | # 19 | # https://raw.githubusercontent.com/HomebrewNLP/HomebrewNLP/master/LICENSE 20 | # 21 | # Unless required by applicable law or agreed to in writing, software 22 | # distributed under the License is distributed on an "AS IS" BASIS, 23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | # See the License for the specific language governing permissions and 25 | # limitations under the License. 26 | from __future__ import print_function 27 | import torch 28 | 29 | 30 | @torch.no_grad() 31 | def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100): 32 | """Power iteration. 33 | 34 | Compute the maximum eigenvalue of mat, for scaling. 35 | v is a random vector with values in (-1, 1) 36 | 37 | Args: 38 | mat_g: the symmetric PSD matrix. 39 | error_tolerance: Iterative exit condition. 40 | num_iters: Number of iterations. 
41 | 
42 | Returns:
43 | eigenvalue, eigenvector, num_iters
44 | """
45 | v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
46 | error = 1
47 | iters = 0
48 | singular_val = 0
49 | while error > error_tolerance and iters < num_iters:
50 | v = v / torch.norm(v)
51 | mat_v = torch.mv(mat_g, v)
52 | s_v = torch.dot(v, mat_v)
53 | error = torch.abs(s_v - singular_val)
54 | v = mat_v
55 | singular_val = s_v
56 | iters += 1
57 | return singular_val, v / torch.norm(v), iters
58 | 
59 | 
60 | @torch.no_grad()
61 | def MatPower(mat_m, p):
62 | """Computes mat_m^p, for p a positive integer.
63 | 
64 | Args:
65 | mat_m: a square matrix
66 | p: a positive integer
67 | 
68 | Returns:
69 | mat_m^p
70 | """
71 | if p in [1, 2, 4, 8, 16, 32]:
72 | p_done = 1
73 | res = mat_m
74 | while p_done < p:
75 | res = torch.matmul(res, res)
76 | p_done *= 2
77 | return res
78 | 
79 | power = None
80 | while p > 0:
81 | if p % 2 == 1:
82 | power = torch.matmul(mat_m, power) if power is not None else mat_m
83 | p //= 2
84 | mat_m = torch.matmul(mat_m, mat_m)
85 | return power
86 | 
87 | 
88 | @torch.no_grad()
89 | def ComputePower(mat_g, p,
90 | iter_count=100,
91 | error_tolerance=1e-6,
92 | ridge_epsilon=1e-6):
93 | """A method to compute G^{-1/p} using a coupled Newton iteration.
94 | 
95 | See for example equation 3.2 on page 9 of:
96 | A Schur-Newton Method for the Matrix p-th Root and its Inverse
97 | by Chun-Hua Guo and Nicholas J. Higham
98 | SIAM Journal on Matrix Analysis and Applications,
99 | 2006, Vol. 28, No. 3 : pp. 788-804
100 | https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
101 | 
102 | Args:
103 | mat_g: A square positive semidefinite matrix
104 | p: a positive integer
105 | iter_count: Stop iterating after this many rounds.
106 | error_tolerance: Threshold for stopping iteration.
107 | ridge_epsilon: We add this times I to G, to make it positive definite.
108 | For scaling, we multiply it by the largest eigenvalue of G.
109 | Returns:
110 | (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
111 | """
112 | shape = list(mat_g.shape)
113 | if len(shape) == 1:
114 | return torch.pow(mat_g + ridge_epsilon, -1/p)
115 | identity = torch.eye(shape[0], device=mat_g.get_device())
116 | if shape[0] == 1:
117 | return identity
118 | alpha = -1.0/p
119 | max_ev, _, _ = PowerIter(mat_g)
120 | ridge_epsilon *= max_ev
121 | mat_g += ridge_epsilon * identity
122 | z = (1 + p) / (2 * torch.norm(mat_g))
123 | # The best value for z is
124 | # (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
125 | # (c_max^{1+1/p} - c_min^{1+1/p})
126 | # where c_max and c_min are the largest and smallest singular values of
127 | # mat_g.
128 | # The above estimate assumes that c_max > c_min * 2^p
129 | # Can replace above line by the one below, but it is less accurate,
130 | # hence needs more iterations to converge.
131 | # z = (1 + p) / torch.trace(mat_g)
132 | # If we want the method to always converge, use z = 1 / norm(mat_g)
133 | # or z = 1 / torch.trace(mat_g), but these can result in many
134 | # extra iterations.
135 | 136 | mat_root = identity * torch.pow(z, 1.0/p) 137 | mat_m = mat_g * z 138 | error = torch.max(torch.abs(mat_m - identity)) 139 | count = 0 140 | while error > error_tolerance and count < iter_count: 141 | tmp_mat_m = (1 - alpha) * identity + alpha * mat_m 142 | new_mat_root = torch.matmul(mat_root, tmp_mat_m) 143 | mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m) 144 | new_error = torch.max(torch.abs(mat_m - identity)) 145 | if new_error > error * 1.2: 146 | break 147 | mat_root = new_mat_root 148 | error = new_error 149 | count += 1 150 | return mat_root 151 | -------------------------------------------------------------------------------- /src/utils/setup.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import typing 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data.dataloader 8 | 9 | from src.dataclass import Context 10 | from src.model import LinearAttention, Trainer 11 | from src.utils.formatting import pretty_print 12 | 13 | DataLoaderIter = torch.utils.data.dataloader._BaseDataLoaderIter 14 | 15 | 16 | def setup_torch(seed: int): 17 | torch._C._debug_set_autodiff_subgraph_inlining(False) # skipcq: PYL-W0212 18 | torch._C._set_graph_executor_optimize(True) # skipcq: PYL-W0212 19 | torch._C._set_backcompat_broadcast_warn(False) # skipcq: PYL-W0212 20 | torch._C._set_backcompat_keepdim_warn(False) # skipcq: PYL-W0212 21 | torch._C._set_cudnn_enabled(True) # skipcq: PYL-W0212 22 | torch._C._set_mkldnn_enabled(True) # skipcq: PYL-W0212 23 | torch._C._set_mkldnn_enabled(True) # skipcq: PYL-W0212 24 | torch._C._set_cudnn_benchmark(True) # skipcq: PYL-W0212 25 | torch._C._set_cudnn_deterministic(False) # skipcq: PYL-W0212 26 | torch._C._set_cudnn_allow_tf32(True) # skipcq: PYL-W0212 27 | torch._C._set_cublas_allow_tf32(True) # skipcq: PYL-W0212 28 | torch._C._jit_set_inline_everything_mode(True) # skipcq: PYL-W0212 29 | 30 | torch._C._jit_set_profiling_executor(True) # skipcq: PYL-W0212 31 | torch._C._jit_set_profiling_mode(True) # skipcq: PYL-W0212 32 | torch._C._jit_override_can_fuse_on_cpu(False) # skipcq: PYL-W0212 33 | torch._C._jit_override_can_fuse_on_gpu(True) # skipcq: PYL-W0212 34 | torch._C._jit_set_texpr_fuser_enabled(True) # skipcq: PYL-W0212 35 | torch._C._jit_set_nvfuser_enabled(False) # skipcq: PYL-W0212 36 | 37 | random.seed(seed) 38 | np.random.seed(seed) 39 | torch.manual_seed(seed) 40 | 41 | 42 | def get_model(ctx: Context, load_model: bool, data: typing.Optional[torch.Tensor] = None) -> Trainer: 43 | mod = Trainer(ctx, LinearAttention(ctx).to(dtype=torch.float16 if ctx.model.float16 else torch.float), 44 | data if data is None else None) 45 | 46 | if ctx.model.print_on_init: 47 | pretty_print(str(mod)) 48 | 49 | parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, mod.parameters())) 50 | base = int(math.log10(parameters) / 3) 51 | pretty_print(f'Parameters: {parameters / (1000 ** base):.1f}{" kMBT"[base]}') 52 | if load_model: 53 | mod.load() 54 | if not ctx.model.offloading: 55 | mod = mod.to(ctx.model.device) 56 | return mod 57 | 58 | 59 | def encode(prompt: str) -> torch.Tensor: 60 | return torch.as_tensor(np.frombuffer(prompt.encode('UTF-8'), np.uint8)) 61 | 62 | 63 | def decode(output: torch.LongTensor) -> str: 64 | return ''.join(chr(c) for c in output.view(-1).unbind(0)) 65 | --------------------------------------------------------------------------------
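The shape handling used throughout the Shampoo code rests on the two helpers in src/optimizers/shampoo_utils.py: merge_small_dims collapses several small dimensions into fewer large ones, and compute_splits then cuts any remaining oversized dimension into blocks of at most block_size. The following is a minimal sketch, not part of the repository, that only exercises those two functions as defined above; the shapes and the block/reshape sizes are illustrative values borrowed from their docstrings, not defaults taken from any config.

from src.optimizers.shampoo_utils import compute_splits, merge_small_dims

# Merge small dimensions: a (512, 5, 5, 1024) convolution collapses to [512, 25, 1024]
# when no merged dimension may exceed 1024 elements.
print(merge_small_dims([512, 5, 5, 1024], reshape_size=1024))  # [512, 25, 1024]

# Split oversized dimensions: with block_size=1024, the 4096-wide dimension of a
# (4096, 512) weight is cut into four blocks of 1024; the 512-wide dimension is untouched.
splits, split_sizes = compute_splits([4096, 512], block_size=1024)
print(splits)       # one entry for dimension 0: four blocks of size 1024 each
print(split_sizes)  # per-dimension block sizes: [1024, 1024, 1024, 1024] and [512]

The merge step is what the best_effort_shape_interpretation option in the Shampoo docstring refers to, and the split step is the blocking that keeps the L and R statistics at most block_size x block_size per block.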