├── README.md
├── better_scaling_law_llama2.py
└── combined_llama2.csv

/README.md:
--------------------------------------------------------------------------------
# pysr_scaling_laws
You should use PySR to find scaling laws. Here's an example: [better_scaling_law_llama2.py](https://github.com/MilesCranmer/pysr_scaling_laws/blob/master/better_scaling_law_llama2.py)


Source for the scaling dataset: https://arxiv.org/abs/2309.16039

Their law:

![image](https://github.com/MilesCranmer/pysr_scaling_laws/assets/7593028/f0fe5ea6-75e9-4756-b121-49c319a50d17)

The laws PySR can find automatically (a simultaneous fit over both the number of parameters and the context size):

![image](https://github.com/MilesCranmer/pysr_scaling_laws/assets/7593028/a941e87a-fce5-4f91-86f4-fee334752bee)

--------------------------------------------------------------------------------
/better_scaling_law_llama2.py:
--------------------------------------------------------------------------------
import pandas as pd
from pysr import PySRRegressor
from multiprocessing import cpu_count

# -
# Load the scaling data: loss as a function of context size and parameter count.
fname = "combined_llama2.csv"
df = pd.read_csv(fname)

# -

model = PySRRegressor(
    model_selection="best",
    populations=30,
    population_size=100,
    niterations=100,
    maxsize=20,  # Maximum equation complexity
    ncyclesperiteration=10000,
    weight_optimize=0.001,
    adaptive_parsimony_scaling=1000,
    binary_operators=["+", "-", "*", "/", "^"],
    # Cap the complexity of the exponent of `^` at 3 and the denominator
    # of `/` at 5; -1 leaves the other argument unconstrained.
    constraints={"^": (-1, 3), "/": (-1, 5)},
    multithreading=False,  # Multiprocessing instead
    procs=cpu_count(),
    turbo=True,
)

# -

# Simultaneously fit loss against context size and parameter count.
model.fit(df[["context", "params"]], df["loss"])

# -

latex_output = "table.tex"
print(f"I will save a LaTeX table to {latex_output}")
s = model.latex_table()
with open(latex_output, "w") as f:
    f.write(s)

--------------------------------------------------------------------------------
/combined_llama2.csv:
--------------------------------------------------------------------------------
,context,loss,params
0,2.135068261712696,5.67145762818818,7
1,3.396975674854875,4.703963891061202,7
2,5.404718875966568,3.9727780576482625,7
3,8.599115485122592,3.34514120539374,7
4,13.576040010664132,2.7494968353465348,7
5,21.600001510063358,2.4739856531117845,7
6,34.36643269084734,2.253107019185151,7
7,54.67831543179723,2.058148169308335,7
8,86.32457597552673,1.8519130400109565,7
9,137.34571863093225,1.7806900974507098,7
10,218.52231781127077,1.8077534851235078,7
11,347.67740747657774,1.748748306281974,7
12,548.9032450921546,1.6613241167223403,7
13,873.326162382842,1.651330582296661,7
14,1389.4954943731361,1.6364528834535221,7
15,2210.7407427431235,1.5830389662084277,7
0,2.135068261712696,5.387927639918842,70
1,3.396975674854875,4.441919630986654,70
2,5.404718875966568,3.650980098708243,70
3,8.532817105836056,2.964882889084697,70
4,13.576040010664132,2.364528674549281,70
5,21.600001510063358,2.058148169308335,70
6,34.36643269084734,1.8857391883332701,70
7,54.67831543179723,1.701906715025687,70
8,86.99530164663707,1.5782704754529915,70
9,137.34571863093225,1.503898990395478,70
10,218.52231781127077,1.5359952662433745,70
11,347.67740747657774,1.5130002972412333,70
12,553.1681197617227,1.4158433721949366,70
13,873.326162382842,1.403087300499544,70
14,1389.4954943731361,1.4073265034280422,70
15,2210.7407427431235,1.3696301622216445,70
0,8.532817105836056,3.0649221254266004,13
1,13.576040010664132,2.488957737477002,13
2,21.600001510063358,2.186150682125697,13
3,34.36643269084734,2.0151409460246823,13
4,54.256750713714545,1.8463346364468831,13
5,86.99530164663707,1.696780166273288,13
6,137.34571863093225,1.61682425076669,13
7,218.52231781127077,1.6563198124307354,13
8,347.67740747657774,1.6364528834535221,13
9,548.9032450921546,1.5313684823272344,13
10,873.326162382842,1.5084427796536233,13
11,1389.4954943731361,1.5313684823272344,13
12,2210.7407427431235,1.4724734008094114,13
0,2210.7407427431235,1.4417044827731889,34
1,1389.4954943731361,1.4680379598952795,34
2,873.326162382842,1.454811639925331,34
3,553.1681197617227,1.4680379598952795,34
4,347.67740747657774,1.5687765421676607,34
5,218.52231781127077,1.5974310560964657,34
6,137.34571863093225,1.554642615746889,34
7,86.99530164663707,1.6413971630100512,34
8,54.67831543179723,1.7646469312386575,34
9,34.36643269084734,1.9434946713988266,34
10,21.600001510063358,2.1147945967772817,34
11,13.576040010664132,2.422289057613663,34
--------------------------------------------------------------------------------
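
A minimal sketch (not part of the repo) of how the fitted model could be inspected once `model.fit(...)` in `better_scaling_law_llama2.py` has run; `model` refers to the fitted `PySRRegressor` from the script, and `equations_`, `sympy()`, and `predict()` are standard `PySRRegressor` accessors:

```python
# Assumes `model` is the fitted PySRRegressor from better_scaling_law_llama2.py.
import pandas as pd

# Pareto front of discovered equations: one row per complexity level.
print(model.equations_[["complexity", "loss", "equation"]])

# The selected equation (per model_selection="best") as a SymPy expression.
print(model.sympy())

# Evaluate the discovered law at a new point, e.g. params=13 at context=512
# (same units as the columns of combined_llama2.csv).
X_new = pd.DataFrame({"context": [512.0], "params": [13.0]})
print(model.predict(X_new))
```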