├── README.md ├── figure_1.png ├── fit_bnsl_and_extrapolate__4_digit_addition__dataset_size_x-axis.py ├── fit_bnsl_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.py ├── make_figure_1__decomposition_of_bnsl_into_power_law_segments.py ├── plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis.png └── plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.png /README.md: -------------------------------------------------------------------------------- 1 | # Code Release for "Broken Neural Scaling Laws" (BNSL) paper ([arxiv.org/abs/2210.14891](https://arxiv.org/abs/2210.14891)) 2 | 3 | Read Appendix A.6 of the arXiv version of this paper for more details on how to use this code. 4 | 5 | To reproduce the Fitting and Extrapolation of BNSL on 4 Digit Addition from Figure 5 (left), run 6 | 7 | ```python fit_bnsl_and_extrapolate__4_digit_addition__dataset_size_x-axis.py``` 8 | 9 | 10 | To reproduce the Fitting and Extrapolation of BNSL on a noiseless simulation of the scaling behavior of 4 Digit Addition from Figure 5 (right), run 11 | 12 | ```python fit_bnsl_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.py``` 13 | 14 | 15 | 16 | 17 | To reproduce the Decomposition of BNSL into Power Law Segments from Figure 1, run 18 | 19 | ```python make_figure_1__decomposition_of_bnsl_into_power_law_segments.py ``` 20 | 21 | 22 | # Note: 23 | 24 | 🚨🚨🚨 25 | 26 | **When you fit a BNSL to your own scaling data, you may need to adjust the grid search range and resolution (the `p_grid` slices in the fitting scripts) to get a good fit.** 27 | 28 | 🚨🚨🚨 29 | 30 | # Here is some BibTeX to use for citation: 31 | 32 | ``` 33 | @inproceedings{ 34 | caballero2023broken, 35 | title={Broken Neural Scaling Laws}, 36 | author={Ethan Caballero and Kshitij Gupta and Irina Rish and David Krueger}, 37 | booktitle={The Eleventh International Conference on Learning Representations }, 38 | year={2023}, 39 | url={https://arxiv.org/abs/2210.14891} 40 |
"""
Code to reproduce Figure 5 Left of arxiv.org/abs/2210.14891

Fits a broken neural scaling law (BNSL) with one break to the first
``split_point`` (training dataset size, test cross-entropy) points of 4 digit
addition, then extrapolates the fitted curve to the held-out larger sizes.

The fit is two-stage: a coarse ``scipy.optimize.brute`` grid search over the
mean squared error of log(y + 1), followed by ``scipy.optimize.curve_fit``
refinement started from the best grid point.

matplotlib is imported only inside ``main()``, so the model functions and the
data below can be imported without matplotlib and without changing the global
plotting style.
"""
import time
import math

import numpy as np
import scipy.optimize


def bnsl_with_1_break(_x, a, b, c0, c1, d1, f1):
    """Evaluate a broken neural scaling law with a single break.

        y = a + b * x**(-c0) * (1 + (x / d1)**(1 / f1))**(-c1 * f1)

    Args:
        _x: scale value(s), scalar or numpy array.
        a: additive floor ("unimprovable performance").
        b: offset (on a log-log plot).
        c0: slope before the break (on a log-log plot).
        c1: change in slope at the break.
        d1: x location of the break.
        f1: sharpness of the break; smaller values are sharper.

    Returns:
        The BNSL value(s), broadcast like ``_x``.
    """
    y = a + b * _x**(-c0) * (1 + (_x/d1)**(1/f1))**(-c1 * f1)
    return y


def bnsl_with_1_break__log(_x, a, b, c0, c1, d1, f1):
    """Return log(BNSL(x) + 1), the transformed target fitted by curve_fit."""
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return np.log(y+1)


def _decode_grid_param(p):
    """Map a grid coordinate to a strictly positive value: 1.25**p - 1 + 1e-8.

    b and d1 are grid-searched in this exponential coordinate so that an evenly
    spaced grid spans several orders of magnitude; the 1e-8 keeps the value
    positive at p = 0.
    """
    return 1.25**p - 1 + 1e-8


def bnsl_with_1_break__msle_optim(p, _x, _y):
    """Grid-search objective: mean squared error of log(pred + 1) vs log(_y + 1).

    Args:
        p: (a, b, c0, c1, d1, f1), where b and d1 are grid coordinates decoded
            by ``_decode_grid_param``.
        _x, _y: the training points.

    Returns:
        The scalar mean squared log(1 + y) error.
    """
    a, b, c0, c1, d1, f1 = p
    b = _decode_grid_param(b)
    d1 = _decode_grid_param(d1)
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return np.mean((np.log(y+1)-np.log(_y+1))**2)


def bnsl_with_1_break__sle(p, _x, _y):
    """Per-point squared log error (log(pred) - log(_y))**2, used to report RMSLE.

    ``p`` holds parameters in their native (not grid-coded) form. Unlike the
    fitting objectives, this metric uses log(y) rather than log(y + 1).
    """
    a, b, c0, c1, d1, f1 = p
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return (np.log(y)-np.log(_y))**2


# training dataset size (x) and test cross-entropy (y) for 4 digit addition
x = np.array([160, 192, 256, 320, 384,
              448, 480, 512, 544, 576,
              608, 640, 672, 736, 800,
              864, 928])

y = np.array([2.13809046, 2.11813418, 2.08955508, 2.06988398, 2.05404987,
              2.03837089, 2.02814281, 2.00496872, 1.95576149, 1.86313841,
              1.70891537, 1.50637664, 1.29754721, 0.96559684, 0.75856477,
              0.64768338, 0.55695445])


def use_whitegrid_style(plt):
    """Apply the seaborn "whitegrid" style under whichever name matplotlib ships.

    matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8-*``
    and 3.8 removed the old names, so ``plt.style.use('seaborn-whitegrid')``
    raises on current releases. If neither name exists, the default style is
    kept.

    Args:
        plt: the ``matplotlib.pyplot`` module (or anything exposing
            ``style.available`` and ``style.use``).

    Returns:
        The style name that was applied, or None if none was available.
    """
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            return style_name
    return None


def fit_bnsl_with_1_break(x_fit, y_fit, p_grid, workers=-1):
    """Fit a 1-break BNSL: brute-force grid search, then curve_fit refinement.

    Args:
        x_fit, y_fit: training points.
        p_grid: one ``slice`` per parameter (a, b, c0, c1, d1, f1) for
            ``scipy.optimize.brute``; the b and d1 slices are in the
            ``_decode_grid_param`` coordinate.
        workers: forwarded to ``scipy.optimize.brute`` (-1 uses all cores).

    Returns:
        (a, b, c0, c1, d1, f1) in the model's native parameterization.
    """
    res = scipy.optimize.brute(
        bnsl_with_1_break__msle_optim, p_grid, args=(x_fit, y_fit),
        full_output=False, finish=None, Ns=1, workers=workers)
    a, b, c0, c1, d1, f1 = res
    b = _decode_grid_param(b)
    d1 = _decode_grid_param(d1)
    y_log = np.log(y_fit+1)
    popt, _ = scipy.optimize.curve_fit(
        bnsl_with_1_break__log, x_fit, y_log,
        p0=[a, b, c0, c1, d1, f1], maxfev=100000000)
    return tuple(popt)


def main():
    """Fit on the first 14 points, extrapolate to the rest, report RMSLE, and plot."""
    import matplotlib.pyplot as plt

    use_whitegrid_style(plt)

    print("x ground_truth: ", x)
    print("y ground_truth: ", y)

    split_point = 14

    x1 = x[:split_point]
    y1 = y[:split_point]

    x2 = x[split_point:]
    y2 = y[split_point:]

    # held-out (extrapolation) points in green, fitted points in black
    plt.plot(x2, y2, 'o', color=[0.0, 0.925, 0.0])
    plt.plot(x1, y1, 'o', color='black')

    # grid search range and resolution
    p_grid = (slice(0.0, 1., .1), slice(0, 40, 2.5), slice(0, 1, 0.25),
              slice(0, 1, 0.25), slice(0, 40, 2.5), slice(0, 1, 0.25))

    start = time.time()
    a, b, c0, c1, d1, f1 = fit_bnsl_with_1_break(x1, y1, p_grid)
    total_time = time.time() - start
    print("time: ", total_time)

    # geometric evaluation grid 1, 1.01, 1.01**2, ... for drawing a smooth curve
    points = 4096
    x_tile = np.array([1.01**i for i in range(points)], dtype=float)

    print("a =", a)
    print("b =", b)
    print("c0 =", c0)
    print("c1 =", c1)
    print("d1 =", d1)
    print("f1 =", f1)

    pred = bnsl_with_1_break(x_tile, a, b, c0, c1, d1, f1)
    plt.plot(x_tile, pred, color=[1.0, 0.125, 0.125], linewidth=2.5)

    sle = bnsl_with_1_break__sle((a, b, c0, c1, d1, f1), x, y)

    print("rmsle train: ", np.sqrt(np.mean(sle[:split_point])))
    print("rmsle extrapolate: ", np.sqrt(np.mean(sle[split_point:])))

    plt.title("4 Digit Addition")
    plt.xlabel("Training Dataset Size")
    plt.ylabel("Test Cross-Entropy")

    # For log-log axes, enable:
    # plt.xscale('log')
    # plt.yscale('log')

    plt.xlim(140, 983)
    plt.ylim(0, 2.5)
    plt.savefig('plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis.png', bbox_inches='tight')
    plt.show()

    plt.close()


if __name__ == '__main__':
    main()
"""
Code to reproduce Figure 5 Right of arxiv.org/abs/2210.14891

Noiseless simulation of the scaling behavior of 4 digit addition: samples a
1-break BNSL with fixed ground-truth parameters at x = 1, 2, ..., 4095, fits
a BNSL to the first ``split_point`` samples, and extrapolates the fit to the
remaining samples.

The fit is two-stage: a coarse ``scipy.optimize.brute`` grid search over the
mean squared error of log(y + 1), followed by ``scipy.optimize.curve_fit``
refinement started from the best grid point.

matplotlib is imported only inside ``main()``, so the model functions and the
simulated data can be imported without matplotlib and without changing the
global plotting style.
"""
import time
import math

import numpy as np
import scipy.optimize


def bnsl_with_1_break(_x, a, b, c0, c1, d1, f1):
    """Evaluate a broken neural scaling law with a single break.

        y = a + b * x**(-c0) * (1 + (x / d1)**(1 / f1))**(-c1 * f1)

    Args:
        _x: scale value(s), scalar or numpy array.
        a: additive floor ("unimprovable performance").
        b: offset (on a log-log plot).
        c0: slope before the break (on a log-log plot).
        c1: change in slope at the break.
        d1: x location of the break.
        f1: sharpness of the break; smaller values are sharper.

    Returns:
        The BNSL value(s), broadcast like ``_x``.
    """
    y = a + b * _x**(-c0) * (1 + (_x/d1)**(1/f1))**(-c1 * f1)
    return y


def bnsl_with_1_break__log(_x, a, b, c0, c1, d1, f1):
    """Return log(BNSL(x) + 1), the transformed target fitted by curve_fit."""
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return np.log(y+1)


def _decode_grid_param(p):
    """Map a grid coordinate to a strictly positive value: 1.25**p - 1 + 1e-8.

    b and d1 are grid-searched in this exponential coordinate so that an evenly
    spaced grid spans several orders of magnitude; the 1e-8 keeps the value
    positive at p = 0.
    """
    return 1.25**p - 1 + 1e-8


def bnsl_with_1_break__msle_optim(p, _x, _y):
    """Grid-search objective: mean squared error of log(pred + 1) vs log(_y + 1).

    Args:
        p: (a, b, c0, c1, d1, f1), where b and d1 are grid coordinates decoded
            by ``_decode_grid_param``.
        _x, _y: the training points.

    Returns:
        The scalar mean squared log(1 + y) error.
    """
    a, b, c0, c1, d1, f1 = p
    b = _decode_grid_param(b)
    d1 = _decode_grid_param(d1)
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return np.mean((np.log(y+1)-np.log(_y+1))**2)


def bnsl_with_1_break__sle(p, _x, _y):
    """Per-point squared log error (log(pred) - log(_y))**2, used to report RMSLE.

    ``p`` holds parameters in their native (not grid-coded) form. Unlike the
    fitting objectives, this metric uses log(y) rather than log(y + 1).
    """
    a, b, c0, c1, d1, f1 = p
    y = bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)
    return (np.log(y)-np.log(_y))**2


# ground_truth
a_gt = 0.41388453071629455
b_gt = 2.2772722897737556
c0_gt = 0.055077348404973955
c1_gt = 5.662903816010331
d1_gt = 612.5836172918001
f1_gt = 0.059193036393742314

# noiseless samples at x = 1, 2, ..., x_points - 1
x_points = 4096
x = np.arange(1, x_points, dtype=float)
y = bnsl_with_1_break(x, a_gt, b_gt, c0_gt, c1_gt, d1_gt, f1_gt)


def use_whitegrid_style(plt):
    """Apply the seaborn "whitegrid" style under whichever name matplotlib ships.

    matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8-*``
    and 3.8 removed the old names, so ``plt.style.use('seaborn-whitegrid')``
    raises on current releases. If neither name exists, the default style is
    kept.

    Args:
        plt: the ``matplotlib.pyplot`` module (or anything exposing
            ``style.available`` and ``style.use``).

    Returns:
        The style name that was applied, or None if none was available.
    """
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            return style_name
    return None


def fit_bnsl_with_1_break(x_fit, y_fit, p_grid, workers=-1):
    """Fit a 1-break BNSL: brute-force grid search, then curve_fit refinement.

    Args:
        x_fit, y_fit: training points.
        p_grid: one ``slice`` per parameter (a, b, c0, c1, d1, f1) for
            ``scipy.optimize.brute``; the b and d1 slices are in the
            ``_decode_grid_param`` coordinate.
        workers: forwarded to ``scipy.optimize.brute`` (-1 uses all cores).

    Returns:
        (a, b, c0, c1, d1, f1) in the model's native parameterization.
    """
    res = scipy.optimize.brute(
        bnsl_with_1_break__msle_optim, p_grid, args=(x_fit, y_fit),
        full_output=False, finish=None, Ns=1, workers=workers)
    a, b, c0, c1, d1, f1 = res
    b = _decode_grid_param(b)
    d1 = _decode_grid_param(d1)
    y_log = np.log(y_fit+1)
    popt, _ = scipy.optimize.curve_fit(
        bnsl_with_1_break__log, x_fit, y_log,
        p0=[a, b, c0, c1, d1, f1], maxfev=100000000)
    return tuple(popt)


def main():
    """Fit on the first 405 samples, extrapolate to the rest, report RMSLE, and plot."""
    import matplotlib.pyplot as plt

    use_whitegrid_style(plt)

    print("x ground_truth: ", x)
    print("y ground_truth: ", y)

    split_point = 405

    # set split_point to 390 to see what failure looks like
    # split_point = 390

    x1 = x[:split_point]
    y1 = y[:split_point]

    x2 = x[split_point:]
    y2 = y[split_point:]

    # held-out (extrapolation) points in green, fitted points in black
    plt.plot(x2, y2, 'o', color=[0.0, 0.925, 0.0])
    plt.plot(x1, y1, 'o', color='black')

    # grid search range and resolution.
    # this range can be made as wide as you want and extrapolation will be the same, but grid search will run slower on a laptop if made wider.
    p_grid = (slice(0.0, 2.5, .1), slice(0, 5, .25), slice(0, .2, 0.05),
              slice(0, 8, 0.5), slice(0, 35, 2.5), slice(0, .2, 0.05))

    start = time.time()
    a, b, c0, c1, d1, f1 = fit_bnsl_with_1_break(x1, y1, p_grid)
    total_time = time.time() - start
    print("time: ", total_time)

    # geometric evaluation grid 1, 1.01, 1.01**2, ... for drawing a smooth curve
    points = 4096
    x_tile = np.array([1.01**i for i in range(points)], dtype=float)

    print("a =", a)
    print("b =", b)
    print("c0 =", c0)
    print("c1 =", c1)
    print("d1 =", d1)
    print("f1 =", f1)

    pred = bnsl_with_1_break(x_tile, a, b, c0, c1, d1, f1)
    plt.plot(x_tile, pred, color=[1.0, 0.125, 0.125], linewidth=2.5)

    sle = bnsl_with_1_break__sle((a, b, c0, c1, d1, f1), x, y)

    print("rmsle train: ", np.sqrt(np.mean(sle[:split_point])))
    print("rmsle extrapolate: ", np.sqrt(np.mean(sle[split_point:])))

    plt.title("4 Digit Addition")
    plt.xlabel("Training Dataset Size")
    plt.ylabel("Test Cross-Entropy")

    # For log-log axes, enable:
    # plt.xscale('log')
    # plt.yscale('log')

    plt.xlim(140, 983)
    plt.ylim(0, 2.5)
    plt.savefig('plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.png', bbox_inches='tight')
    plt.show()

    plt.close()


if __name__ == '__main__':
    main()
"""
Make Figure 1 of arxiv.org/abs/2210.14891: decomposition of a broken neural
scaling law (BNSL) with 3 breaks into its individual power-law segments.

Running this file saves ``figure_1.png`` and shows the plot. The curve and
segment arrays are computed at import time with numpy only; matplotlib is
imported inside ``main()``, so importing this module does not plot.
"""
import numpy as np


def bnsl(_x, a, b, c0, c, d, f):
    """Evaluate a BNSL with ``len(c)`` breaks.

        y = a + b * x**(-c0) * prod_i (1 + (x / d[i])**(1 / f[i]))**(-c[i] * f[i])

    Decomposition based on combination of insights from:
    https://en.wikipedia.org/wiki/Power_law#Broken_power_law ,
    https://docs.astropy.org/en/stable/api/astropy.modeling.powerlaws.SmoothlyBrokenPowerLaw1D.html ,
    paragraph below equation 1.1 of http://wallpaintings.at/geminga/Multiply_broken_power-law_densities_as_survival_functions.pdf#page=2

    Args:
        _x: scale value(s), scalar or numpy array.
        a: unimprovable performance (additive floor).
        b: offset (on log-log plot).
        c0: initial slope (on log-log plot).
        c: change in slope at each break.
        d: location of each break.
        f: sharpness of each break; smaller (nonnegative) values are sharper.

    Returns:
        The BNSL value(s), broadcast like ``_x``. With no breaks this is the
        plain power law a + b * x**(-c0).

    Raises:
        ValueError: if ``c``, ``d`` and ``f`` differ in length.
    """
    if not (len(c) == len(d) == len(f)):
        raise ValueError("c, d and f must have one entry per break")
    y = b * _x**(-c0)
    # multiply in one smoothly-broken factor per break, left to right
    for c_i, d_i, f_i in zip(c, d, f):
        y = y * (1. + (_x / d_i)**(1. / f_i))**(-c_i * f_i)
    return a + y


x_min = -1
x_max = 26
x = np.logspace(x_min, x_max, 8192)

# unimprovable performance
a = 1e-3

# offset (on log-log plot)
b = 4.66e3

# changes in slope (on log-log plot)
c0 = 0.05
c1 = .49
c2 = -0.811
c3 = 2.1

# where the breaks happen
d1 = 3e6
d2 = 1e14
d3 = 2e20

# sharpness of transitions during breaks; smaller (nonnegative) values are sharper; larger values are less sharp
f1 = 1.1
f2 = 1.425
f3 = 0.05

# broken neural scaling law (bnsl) with 3 breaks
y = bnsl(x, a, b, c0, (c1, c2, c3), (d1, d2, d3), (f1, f2, f3))

# x range drawn for each segment; each one extends somewhat past the break(s) bounding it
x1 = x[(x <= d1*25)]
x2 = x[(x >= d1*.04) & (x <= d2*25)]
x3 = x[(x >= d2*.04) & (x <= d3*25)]
x4 = x[(x >= d3*.32)]

# individual power law segments within the bnsl; this decomposition is usable when values of f are not too large and a is subtracted out from y-axis
# (adding a to segment4 gives the variant that includes the limit)
segment1 = b * (x1)**(-c0)
segment2 = b * (d1)**(-c0) * (x2/d1)**(-(c1+c0))
segment3 = b * (d1)**(-c0) * (d2/d1)**(-(c1+c0)) * (x3/d2)**(-(c2+c1+c0))
segment4 = b * (d1)**(-c0) * (d2/d1)**(-(c1+c0)) * (d3/d2)**(-(c2+c1+c0)) * (x4/d3)**(-(c3+c2+c1+c0))


def main():
    """Plot the BNSL, its breaks, its power-law segments and its limit; save figure_1.png."""
    import matplotlib.pyplot as plt

    linewidth = 2.0
    plt.figure(figsize=(6.4, 4))

    plt.title("Decomposition of BNSL into Power Law Segments")

    plt.plot(x, y, color='black', label='BNSL', linewidth=3.5)

    plt.axvline(x=d1, linestyle=':', color=[0.8, 0.0, 0.8], label='Break 1', linewidth=linewidth)
    plt.axvline(x=d2, linestyle=':', color=[0.6, 0.0, 0.6], label='Break 2', linewidth=linewidth)
    plt.axvline(x=d3, linestyle=':', color=[0.4, 0.0, 0.4], label='Break 3', linewidth=linewidth)

    plt.plot(x1, segment1, '--', label='Segment 1', color=[0.8, 0.8, 0.0], linewidth=linewidth)
    plt.plot(x2, segment2, '--', label='Segment 2', color=[0.0, 0.9, 0.9], linewidth=linewidth)
    plt.plot(x3, segment3, '--', label='Segment 3', color=[1.0, 0.45, 0.45], linewidth=linewidth)
    plt.plot(x4, segment4, '--', label='Segment 4', color=[0.2, 0.925, 0.2], linewidth=linewidth)

    plt.axhline(a, linestyle='-.', color='silver', label='Limit', linewidth=linewidth*.89)

    plt.xlabel("Quantity Being Scaled")
    plt.ylabel("Performance Evaluation Metric")

    plt.xscale('log')
    plt.yscale('log')

    plt.xlim(1.1*(10**x_min), .9*(10**x_max))
    plt.ylim(y.min()/1.5, y.max()*1.5)

    plt.legend(loc='lower left')

    plt.savefig('figure_1.png', bbox_inches='tight')
    plt.show()

    plt.close()


if __name__ == '__main__':
    main()
/plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethancaballero/broken_neural_scaling_laws/32307b81e385c7d3a95e562ed28bff8cabb44857/plot__bnsl__fit_and_extrapolate__4_digit_addition__dataset_size_x-axis__noiseless_simulation.png --------------------------------------------------------------------------------