├── README.md
├── ablations
│   ├── finding_a.py
│   ├── finding_k.py
│   └── finding_tau.py
├── all_paths_128.py
├── estimate_best_pool.py
├── grid_search.py
├── objective.py
├── plotter.py
└── process_128_grid.py
/README.md:
--------------------------------------------------------------------------------
## Scaling Laws for Data Filtering

### Registering data buckets

The data buckets should be registered in the following file: `all_paths_128.py`.
This file contains the following information:
- `paths`: The path to the folder that holds the evaluation results for a model trained on that bucket.
- `samples_per_epoch_dict`: The number of samples per epoch for the corresponding bucket.
- `match_with_dict`: This tells us whether the evaluation is done at a fixed epoch interval or at a fixed sample interval.
- `subsample_every_dict`: Use this in case you want to take the average of every `k` evaluations. This is usually only useful when the evaluation is done at a fixed sample interval.

### Estimating data bucket parameters

This step estimates the scaling parameters for each bucket of interest; it is carried out by `process_128_grid.py` using the grid search described below.


### Grid search to find the bucket scaling parameters

Grid search is performed to find the best scaling parameters for each bucket, using the following file: `grid_search.py`. The objective minimized in the grid search is defined in `objective.py`. We chose grid search because of the instabilities observed in scipy-based optimization methods.


### Objective Functions

`objective.py` implements scaling laws based on FADU, and also those inspired by the work on Scaling Data Constrained Language Models.

- `func_effective_utility`: This is the function that uses the effective utility formulation as proposed in our work.
- `func_effective_data`: This is the function that uses the formulation of effective data from Scaling Data Constrained Language Models.

```
python process_128_grid.py --a_upper 0.02 --objective effective_utility --d 0.1
```
Here `a_upper` gives an upper limit to the grid search for `a`, and `d` is the irreducible loss. Refer to `ablations/finding_a.py` if you want to jointly minimize `a` across the pools.
Copy the obtained scaling parameters to the `results/parameter_values.jsonl` file, and give the entry an appropriate key name (its `id` field); an illustrative entry is shown right after this README.

### Finding best bucket combination
```
python estimate_best_pool.py --key given_key_name
```
This reads the parameters stored under `--key` and reports the estimated accuracy of training on the top 1-4 buckets at several sample budgets.
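The two sketches below illustrate the registration and parameter formats described above. They are assumptions for illustration only: every key name, path, and number in them is a placeholder rather than a value shipped with this repository.

Registering an additional bucket amounts to adding one entry to each dictionary in `all_paths_128.py` (the bucket key and path here are hypothetical):

```python
# Hypothetical extra bucket: the key name and path are placeholders.
paths["tmars_top40to50p"] = Path(f"{ROOT}/tmars_bucket/tmars_top40to50p")
alt_name["tmars_top40to50p"] = "Top 40%-50%"
samples_per_epoch_dict["tmars_top40to50p"] = 12_800_000  # samples in one pass over this bucket
match_with_dict["tmars_top40to50p"] = "epoch"            # evaluations stored at fixed epoch intervals
subsample_every_dict["tmars_top40to50p"] = 1             # average every k evaluations (1 = keep all)
```

`estimate_best_pool.py` expects `results/parameter_values.jsonl` to hold one JSON object per line, with the key name in an `id` field and the per-bucket parameters in `a_values`, `b_values`, and `c_values` (the script drops the last entry, which corresponds to the random-subset pool). A minimal sketch of writing such an entry, with placeholder numbers:

```python
# Illustrative only: the numbers below are placeholders, not fitted parameters.
import json

entry = {
    "id": "tmars_imagenet1k",  # key name later passed to estimate_best_pool.py via --key
    "a_values": [0.022, 0.022, 0.022, 0.022, 0.022],
    "b_values": [-0.09, -0.11, -0.13, -0.16, -0.12],
    "c_values": [6.5, 5.8, 5.1, 4.4, 5.6],
}
with open("results/parameter_values.jsonl", "a") as f:
    f.write(json.dumps(entry) + "\n")
```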
-------------------------------------------------------------------------------- /ablations/finding_a.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | ''' 4 | This file calls the main() function of process_128_grid.py multiple times with different values of args.a and of the irreducible loss d 5 | Then finds the mean loss over all the buckets 6 | Then reports the (a, d) pair with the minimum mean loss over all values tried 7 | ''' 8 | import argparse 9 | keys_of_interest = ["imagenet1k", "cifar10", "vtab/caltech101", "vtab/cifar100", "food101", "imagenet_sketch", "imagenetv2", "imagenet-a", "imagenet-o", "imagenet-r", "objectnet", "vtab/flowers", "vtab/pets", "voc2007", "vtab/resisc45", "cars", "retrieval/flickr_1k_test_image_text_retrieval", "retrieval/mscoco_2014_5k_test_image_text_retrieval"] 10 | if __name__ == "__main__": 11 | 12 | import os 13 | import sys 14 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 15 | from process_128_grid import main 16 | # run main(args) with different values of args.a. verbose 0, plot 0, metric imagenet1k, filtering tmars, a_upper None 17 | # a_values = 10 different values between 0.001 and 0.1 18 | # metric = "caltech" 19 | # filtering = "tmars" 20 | # # objective = "effective_utility_b_delta" 21 | # objective = "effective_data" 22 | 23 | import argparse 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--metric', type=str, default="imagenet1k") 26 | parser.add_argument('--filtering', type=str, default="tmars") 27 | parser.add_argument('--objective', type=str, default="effective_data") 28 | args = parser.parse_args() 29 | 30 | 31 | a_values_global = [0.01, 0.015, 0.018, 0.02, 0.022, 0.025, 0.03] 32 | 33 | d_vals = list(np.linspace(0.02, 0.15, 10)) 34 | 35 | loss_values = [] 36 | b_values = [] 37 | c_values = [] 38 | a_values = [] 39 | d_values = [] 40 | for d in d_vals: 41 | for a in a_values_global: 42 | print("Metric = ", args.metric, "a = ", a, "d = ", d, "filtering = ", args.filtering, "objective = ", args.objective) 43 | args = argparse.Namespace(a_upper=None, metric=args.metric, a=a, filtering=args.filtering, plot=0, verbose=0, objective=args.objective, tau = None, b = None, d=d) 44 | b_value, a_value, c_value, loss = main(args) 45 | loss_values.append(loss) 46 | b_values.append(b_value) 47 | c_values.append(c_value) 48 | a_values.append(a_value) 49 | d_values.append(d) 50 | print("loss_values = ", loss_values) 51 | print("a_values = ", a_values_global) 52 | print("d_values = ", d_values) 53 | 54 | index_min_loss = loss_values.index(min(loss_values)) 55 | print("best a = ", a_values[index_min_loss][0]) 56 | print("best loss = ", min(loss_values)) 57 | print("b_values = ", b_values[index_min_loss]) 58 | print("c_values = ", c_values[index_min_loss]) 59 | print("a_values = ", a_values[index_min_loss]) 60 | print("d_value = ", d_values[index_min_loss]) 61 | print (args) -------------------------------------------------------------------------------- /ablations/finding_k.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | ''' 4 | This file calls the main() function of process_128_grid.py
multiple times with different values of args.a 5 | Then finds the mean loss of all the buckets 6 | Then reports the a with the minimum mean loss over all the a values 7 | ''' 8 | import argparse 9 | keys_of_interest = ["imagenet1k", "cifar10", "vtab/caltech101", "vtab/cifar100", "food101", "imagenet_sketch", "imagenetv2", "imagenet-a", "imagenet-o", "imagenet-r", "objectnet", "vtab/flowers", "vtab/pets", "voc2007", "vtab/resisc45", "cars", "retrieval/flickr_1k_test_image_text_retrieval", "retrieval/mscoco_2014_5k_test_image_text_retrieval"] 10 | if __name__ == "__main__": 11 | 12 | import os 13 | import sys 14 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 15 | from process_128_grid import main 16 | # run main(args) with different values of args.a. verbose 0, plot 0, metric imagenet1k, filtering tmars, a_upper None 17 | # a_values = 10 different values between 0.001 and 0.1 18 | # metric = "caltech" 19 | # filtering = "tmars" 20 | # # objective = "effective_utility_b_delta" 21 | # objective = "effective_data" 22 | 23 | import argparse 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--metric', type=str, default="imagenet1k") 26 | parser.add_argument('--filtering', type=str, default="tmars") 27 | parser.add_argument('--objective', type=str, default="effective_data") 28 | args = parser.parse_args() 29 | 30 | 31 | if args.metric == "imagenet1k": 32 | # k_values_global = [0.001, 0.01, 0.015, 0.018, 0.019, 0.02, 0.021, 0.022, 0.025, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.1] 33 | k_values_global = list(np.linspace(0.1, 3, 20)) 34 | # k_values_global = #[5, 7, 10, 11, 12, 15, 20, 25]#[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] 35 | #make a grid of a with 20 values between 1 and 2 36 | # k_values_global = np.linspace(0.03, 0.05, 10) 37 | else: 38 | k_values_global = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08] 39 | # k_values_global = np.linspace(0.005, 0.02, 10) 40 | #convert to list 41 | k_values_global = list(k_values_global) 42 | # k_values_global = [0.008, 0.009, 0.01, 0.011, 0.012, 0.015] 43 | 44 | 45 | loss_values = [] 46 | b_values = [] 47 | c_values = [] 48 | a_values = [] 49 | d_values = [] 50 | for k in k_values_global: 51 | args = argparse.Namespace(a = 0.022, metric=args.metric, filtering=args.filtering, plot=0, verbose=0, objective=args.objective, tau = 3, b = -0.13106060606060607, d=0.1, k=k, a_upper=None) 52 | b_value, a_value, c_value, loss = main(args) 53 | loss_values.append(loss) 54 | print("Loss values: ", loss_values) 55 | print("K values: ", k_values_global) -------------------------------------------------------------------------------- /ablations/finding_tau.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file calls main() function of process_128_grid.py. multiple times with different values of args.a 3 | Then finds the mean loss of all the buckets 4 | Then reports the a with the minimum mean loss over all the a values 5 | ''' 6 | import argparse 7 | 8 | if __name__ == "__main__": 9 | import os 10 | import sys 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | from process_128_grid import main 13 | # run main(args) with different values of args.a. 
verbose 0, plot 0, metric imagenet1k, filtering tmars, a_upper None 14 | # a_values = 10 different values between 0.001 and 0.1 15 | metric = "imagenet1k" 16 | filtering = "tmars" 17 | objective = "effective_utility" 18 | 19 | if filtering == "clip": 20 | tau_values_global = [1,2,3,4,4.5,4.7,5,5.5,5.7,6,6.3,6.5,6.7,7,7.3,8,9,10,11,12,13,14,15,16,17,18,19,20] 21 | else: 22 | tau_values_global = [1,2,3,4,5,5.5,5.7,6,6.3,6.5,6.7,7,7.3,8,9,10,11,12,13,14,15,16,17,18,19,20] 23 | 24 | 25 | # retrieved by running finding_a.py 26 | a_dict = { 27 | "tmars": { 28 | "18tasks": 0.033, 29 | "imagenet1k": 0.022 30 | }, 31 | "clip": { 32 | "18tasks": 0.035, 33 | "imagenet1k": 0.021 34 | } 35 | } 36 | 37 | #based on results from finding_a.py 38 | a = a_dict[filtering][metric] 39 | 40 | loss_values = [] 41 | b_values = [] 42 | c_values = [] 43 | a_values = [] 44 | 45 | for tau in tau_values_global: 46 | args = argparse.Namespace(a_upper=None, metric=metric, a=a, filtering=filtering, plot=0, verbose=0, objective=objective, tau = tau) 47 | b_value, a_value, c_value, loss = main(args) 48 | loss_values.append(loss) 49 | b_values.append(b_value) 50 | c_values.append(c_value) 51 | a_values.append(a_value) 52 | print("loss_values = ", loss_values) 53 | print("tau_values = ", tau_values_global) 54 | 55 | index_min_loss = loss_values.index(min(loss_values)) 56 | print("best tau = ", tau_values_global[index_min_loss]) 57 | print("best loss = ", min(loss_values)) 58 | print("b_values = ", b_values[index_min_loss]) 59 | print("c_values = ", c_values[index_min_loss]) 60 | print("a_values = ", a_values[index_min_loss]) 61 | print (args) -------------------------------------------------------------------------------- /all_paths_128.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | ROOT="your_root_path" 4 | 5 | paths = { 6 | "tmars_top0to10p": Path(f"{ROOT}/tmars_bucket/tmars_top0to10p/"), 7 | "tmars_top10to20p": Path(f"{ROOT}/tmars_bucket/tmars_top10to20p"), 8 | "tmars_top20to30p": Path(f"{ROOT}/tmars_bucket/tmars_top20to30p"), 9 | "tmars_top30to40p": Path(f"{ROOT}/tmars_bucket/tmars_top30to40p"), 10 | "tmars_top0to40p_random25p": Path(f"{ROOT}/tmars_bucket/tmars_0%_to_40%_random25%"), 11 | } 12 | 13 | alt_name = { 14 | "tmars_top0to10p": "Top 10%", 15 | "tmars_top10to20p": "Top 10%-20%", 16 | "tmars_top20to30p": "Top 20%-30%", 17 | "tmars_top30to40p": "Top 30%-40%", 18 | "tmars_top0to40p_random25p": "Top 0%-40% Random 25% Data", 19 | 20 | } 21 | 22 | samples_per_epoch_dict = { 23 | "tmars_top0to10p": 12_800_000, 24 | "tmars_top10to20p": 12_800_000, 25 | "tmars_top20to30p": 12_800_000, 26 | "tmars_top30to40p": 12_800_000, 27 | "tmars_top0to40p_random25p": 12_800_000, 28 | } 29 | 30 | match_with_dict = { 31 | "tmars_top0to10p": "epoch", 32 | "tmars_top10to20p": "epoch", 33 | "tmars_top20to30p": "epoch", 34 | "tmars_top30to40p": "epoch", 35 | "tmars_top0to40p_random25p": "epoch", 36 | } 37 | 38 | subsample_every_dict = { 39 | "tmars_top0to10p": 1, 40 | "tmars_top10to20p": 1, 41 | "tmars_top20to30p": 1, 42 | "tmars_top30to40p": 1, 43 | "tmars_top0to40p_random25p": 1, 44 | } -------------------------------------------------------------------------------- /estimate_best_pool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import matplotlib.pyplot as plt 4 | import copy, json 5 | 6 | def power(x, power): 7 | if x==0: 8 | return 0 9 | return x**power 10 | 11 | from 
objective import func_effective_data_aggregation, func_effective_utility_aggregation 12 | 13 | 14 | def estimate_best_buckets(a_values, b_values, c_values, num_samples, pool_size, objective, d): 15 | d = d 16 | max_buckets = len(b_values) 17 | # we will store the error for each bucket length 18 | error = np.zeros(max_buckets) 19 | 20 | b_values = np.array(b_values) 21 | a_values = np.array(a_values) 22 | c_values = np.array(c_values) 23 | 24 | #sort largest based on b_values 25 | sorted_indices = np.argsort(b_values) 26 | b_values = b_values[sorted_indices] 27 | a_values = a_values[sorted_indices] 28 | c_values = c_values[sorted_indices] 29 | 30 | all_acc = [] 31 | for num_buckets in range(1, max_buckets + 1): 32 | # for num_buckets in range(2, 3): 33 | b_list = b_values[:num_buckets] 34 | a_list = a_values[:num_buckets] 35 | c_list = c_values[:num_buckets] 36 | 37 | total_data_size = pool_size*num_buckets 38 | 39 | if objective == "effective_data": 40 | c_list = c_list * num_buckets 41 | error[num_buckets - 1] = func_effective_data_aggregation(num_samples, [a_list, b_list, c_list, d], total_data_size) 42 | elif objective == "effective_utility": 43 | c_list = c_list * num_buckets 44 | error[num_buckets - 1] = func_effective_utility_aggregation(num_samples, [a_list, b_list, c_list, d], total_data_size) 45 | 46 | all_acc.append(1 - error[num_buckets - 1]) 47 | 48 | return error, np.argmin(error) + 1, all_acc 49 | print("best num buckets", np.argmin(error) + 1) 50 | 51 | 52 | 53 | if __name__ == "__main__": 54 | import argparse 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--pool_size", type=int, default=12_800_000) 57 | parser.add_argument("--path", type=str, default="results/parameter_values.jsonl") 58 | parser.add_argument("--key", type=str, default="tmars_imagenet1k") 59 | parser.add_argument("--objective", type=str, default="effective_utility") 60 | parser.add_argument("--d", type=float, default=0.1) 61 | args = parser.parse_args() 62 | 63 | found_key = 0 64 | with open(args.path, "r") as f: 65 | for line in f: 66 | data = json.loads(line) 67 | if data["id"] == args.key: 68 | found_key = 1 69 | break 70 | 71 | a_values = data["a_values"] 72 | b_values = data["b_values"] 73 | c_values = data["c_values"] 74 | 75 | assert found_key == 1, "key not found" 76 | 77 | b_values, a_values, c_values = b_values[:-1], a_values[:-1], c_values[:-1] 78 | 79 | all_accs = {"top10": [], "top20": [], "top30": [], "top40": []} 80 | for num_samples in [32_000_000, 64_000_000, 128_000_000, 640_000_000]: 81 | error, num_buckets, all_acc = estimate_best_buckets(a_values, b_values, c_values, num_samples, args.pool_size, args.objective, args.d) 82 | all_accs["top10"].append(all_acc[0]) 83 | all_accs["top20"].append(all_acc[1]) 84 | all_accs["top30"].append(all_acc[2]) 85 | all_accs["top40"].append(all_acc[3]) 86 | 87 | # print comman seperated values 88 | print(",".join([str(x*100) for x in all_accs["top10"]])) 89 | print(",".join([str(x*100) for x in all_accs["top20"]])) 90 | print(",".join([str(x*100) for x in all_accs["top30"]])) 91 | print(",".join([str(x*100) for x in all_accs["top40"]])) -------------------------------------------------------------------------------- /grid_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def grid_search(x_vals, samples_per_epoch_vals, error_vals, func, a_upper=0.2, a= None, tau = None, b = None, d= None): 7 | if a is 
None: 8 | a = np.linspace(0.0001, a_upper, 100) 9 | else: 10 | a = a 11 | 12 | if b is None: 13 | b_lim = np.linspace(-0.005, -0.5, 100) 14 | else: 15 | b_lim = b 16 | 17 | if tau is None: 18 | c_lim = np.linspace(1, 10, 100) 19 | else: 20 | c_lim = tau 21 | 22 | d = d 23 | 24 | # create a grid of all possible combinations 25 | grid = np.array(np.meshgrid(a, b_lim, c_lim, d)).T.reshape(-1, 4) 26 | #randomize the grid 27 | np.random.shuffle(grid) 28 | 29 | # get the best params by running func on all combinations 30 | # also store the loss grid to plot later of shape b_lim, c_lim 31 | # output_grid = np.zeros((len(b_lim), len(c_lim))) 32 | 33 | best_params = None 34 | best_loss = 10000 35 | pbar = tqdm(total=len(grid)) 36 | 37 | for params in grid: 38 | loss = 0 39 | func_values_list = [] 40 | 41 | for i in range(len(x_vals)): 42 | samples_per_epoch = samples_per_epoch_vals[i] 43 | samples_seen = x_vals[i] 44 | func_value = func(samples_seen, params, samples_per_epoch) 45 | func_values_list.append(func_value) 46 | true_value = error_vals[i] 47 | curr_loss = (func_value - true_value)**2 48 | loss += curr_loss 49 | 50 | param_1_idx = np.where(b_lim == params[1])[0][0] 51 | param_2_idx = np.where(c_lim == params[2])[0][0] 52 | 53 | if loss < best_loss: 54 | best_loss = loss 55 | best_params = params 56 | 57 | pbar.update(1) 58 | pbar.set_description("best loss: {} : best params : {}".format(best_loss, best_params)) 59 | if pbar.n == 100_000: 60 | break 61 | 62 | return best_loss, best_params 63 | 64 | 65 | def grid_search_from_dict(x_vals_dict, samples_per_epoch_vals_dict, error_vals_dict, func): 66 | ''' 67 | x_vals: dict of list of lists of x values 68 | samples_per_epoch_vals: dict of list of samples per epoch 69 | error_vals: dict of list of error values 70 | ''' 71 | a = 1 72 | b_lim_dict = {} 73 | for key in x_vals_dict.keys(): 74 | b_lim = np.linspace(-0.01, -0.2, 100) 75 | b_lim_dict[key] = b_lim 76 | 77 | c_lim = np.linspace(1, 100, 100) 78 | d = np.linspace(0.0, 0.4, 10) 79 | 80 | keys = list(x_vals_dict.keys()) 81 | 82 | grid = np.array(np.meshgrid(a, b_lim_dict[keys[0]],b_lim_dict[keys[1]], c_lim, d)).T.reshape(-1, 5) 83 | grid = np.array(np.meshgrid(a, b_lim_dict[keys[0]],1, c_lim, d)).T.reshape(-1, 5) 84 | 85 | np.random.shuffle(grid) 86 | 87 | output_grid = np.zeros((len(b_lim), len(c_lim))) 88 | best_params = None 89 | best_loss = 10000 90 | pbar = tqdm(total=len(grid)) 91 | 92 | for params in grid: 93 | loss = 0 94 | for key in keys: 95 | x_vals, samples_per_epoch_vals, error_vals = x_vals_dict[key], samples_per_epoch_vals_dict[key], error_vals_dict[key] 96 | b_index = keys.index(key) 97 | 98 | params_current = [params[0], params[1], params[3], params[4]] 99 | 100 | for i in range(len(x_vals)): 101 | samples_per_epoch = samples_per_epoch_vals[i] 102 | samples_seen = x_vals[i] 103 | func_value = func(samples_seen, params_current, samples_per_epoch) 104 | true_value = error_vals[i] 105 | curr_loss = (func_value - true_value)**2 106 | loss += curr_loss 107 | param_1_idx = np.where(b_lim == params[1])[0][0] 108 | param_2_idx = np.where(c_lim == params[3])[0][0] 109 | output_grid[param_1_idx, param_2_idx] = loss 110 | 111 | if loss < best_loss: 112 | best_loss = loss 113 | best_params = params 114 | 115 | pbar.update(1) 116 | pbar.set_description("best loss: {} : best params : {}".format(best_loss, best_params)) 117 | 118 | fig = plt.figure() 119 | plt.imshow(output_grid, cmap='hot', interpolation='nearest') 120 | 121 | plt.colorbar() 122 | plt.xlabel("c") 123 | plt.ylabel("b") 124 
| 125 | plt.xticks(np.arange(0, len(c_lim), 10), c_lim[::10]) 126 | plt.yticks(np.arange(0, len(b_lim), 10), b_lim[::10]) 127 | plt.clim(0, 0.2) 128 | 129 | plt.xticks(rotation=90) 130 | plt.yticks(rotation=0) 131 | plt.title("Loss grid") 132 | plt.savefig("grid_search.png", bbox_inches='tight') 133 | 134 | 135 | 136 | print("best params", best_params) 137 | return best_params -------------------------------------------------------------------------------- /objective.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | # d = 0.02 5 | 6 | def power(x, power): 7 | if x==0: 8 | return 0 9 | return x**power 10 | 11 | 12 | ############################################################################################################ 13 | # functions for effective utility 14 | ############################################################################################################ 15 | def func_effective_utility_aggregation(samples_seen_total, params, samples_per_epoch, normalizer=100_000, d=None): 16 | ''' 17 | \del y = (y * b * \delta) / n 18 | b_decayed = b * \delta 19 | \delta = base**(epochs/c) 20 | 21 | ''' 22 | a_list, b_list, c_list, d = params 23 | # assert all a values are the same 24 | assert np.all(a_list == a_list[0]) 25 | 26 | a = a_list[0] 27 | 28 | 29 | #get effective params for the original function 30 | base = 0.5 31 | num_epochs_full = samples_seen_total // samples_per_epoch 32 | # Creating arrays for each epoch and an additional one for partial epoch if exists 33 | epochs = np.arange(num_epochs_full + 1) 34 | #effective b value = b * \delta for every epoch 35 | all_b_decayed = [] 36 | 37 | for i in range(len(a_list)): 38 | b, c = b_list[i], c_list[i] 39 | delta_list = base**(epochs/c) 40 | b_decayed_list = b * delta_list 41 | all_b_decayed.append(b_decayed_list) 42 | 43 | b_decayed_list = np.mean(all_b_decayed, axis=0) 44 | 45 | # Normalizing the samples 46 | samples = a*np.minimum(samples_per_epoch * (epochs + 1), samples_seen_total) 47 | samples_1 = a*samples_per_epoch * epochs 48 | samples, samples_1 = samples / normalizer, samples_1 / normalizer 49 | samples_all = copy.deepcopy(samples) 50 | 51 | # Calculating the loss 52 | loss = d + (power(samples_all[0], b_decayed_list[0])) 53 | if len(samples_all) > 1: 54 | samples, samples_1, epochs = samples[1:], samples_1[1:], epochs[1:] 55 | ratio = (samples/samples_1)**b_decayed_list[1:] 56 | loss *= ratio.prod() 57 | return loss 58 | 59 | 60 | def func_effective_utility(samples_seen, params, samples_per_epoch, normalizer=100_000, d=None): 61 | ''' 62 | \del y = (y * b * \delta) / n 63 | b_decayed = b * \delta 64 | \delta = base**(epochs/c) 65 | 66 | ''' 67 | a, b, c, d = params 68 | base = 0.5 69 | num_epochs_full = samples_seen // samples_per_epoch 70 | # Creating arrays for each epoch and an additional one for partial epoch if exists 71 | epochs = np.arange(num_epochs_full + 1) 72 | 73 | 74 | #effective b value = b * \delta for every epoch 75 | delta_list = base**(epochs/c) 76 | b_decayed_list = b * delta_list 77 | 78 | # Normalizing the samples 79 | samples = a*np.minimum(samples_per_epoch * (epochs + 1), samples_seen) 80 | samples_1 = a*samples_per_epoch * epochs 81 | samples, samples_1 = samples / normalizer, samples_1 / normalizer 82 | samples_all = copy.deepcopy(samples) 83 | 84 | # Calculating the loss 85 | loss = (power(samples_all[0], b_decayed_list[0])) + d 86 | if len(samples_all) > 1: 87 | samples, samples_1, epochs = samples[1:], 
samples_1[1:], epochs[1:] 88 | ratio = (samples/samples_1)**b_decayed_list[1:] 89 | loss *= ratio.prod() 90 | return loss 91 | 92 | ############################################################################################################ 93 | # functions for effective data with changing utility 94 | ############################################################################################################ 95 | 96 | def get_effective_samples(samples_seen, params, samples_per_epoch, normalizer=100_000): 97 | num_epochs_full = samples_seen // samples_per_epoch 98 | a, b, c, d = params 99 | base = 0.5 100 | 101 | # Creating arrays for each epoch and an additional one for partial epoch if exists 102 | epochs = np.arange(num_epochs_full + 1) 103 | samples = a * np.minimum(samples_per_epoch * (epochs + 1), samples_seen) 104 | samples_1 = a * samples_per_epoch * epochs 105 | 106 | # Normalizing the samples 107 | samples, samples_1 = samples / normalizer, samples_1 / normalizer 108 | 109 | # Calculating the effective samples 110 | effective_samples = 0 111 | for i in range(len(samples)): 112 | effective_samples += (samples[i] - samples_1[i]) * base**(epochs[i]/c) 113 | 114 | return effective_samples 115 | 116 | 117 | def func_effective_data_aggregation(samples_seen_total, params, samples_per_epoch, normalizer=100_000): 118 | base = 0.5 119 | a_list, b_list, c_list, d = params 120 | samples_seen_per_bucket = samples_seen_total//len(a_list) 121 | 122 | # assert all a values are the same 123 | assert np.all(a_list == a_list[0]) 124 | a = a_list[0] 125 | 126 | num_epochs_full = int(samples_seen_total // samples_per_epoch) 127 | epochs = np.arange(num_epochs_full + 1) 128 | #effective b value = b * \delta for every epoch 129 | all_b_decayed = [] 130 | all_delta_list = [] 131 | 132 | for i in range(len(a_list)): 133 | b, c = b_list[i], c_list[i] 134 | delta_list = base**(epochs/c) 135 | b_decayed_list = b * delta_list 136 | all_b_decayed.append(b_decayed_list) 137 | all_delta_list.append(delta_list) 138 | 139 | # b = (b1\delta1 + b2\delta2 + b3\delta3 + b4\delta4) / (\delta1 + \delta2 + \delta3 + \delta4) 140 | b_effective_list = np.sum(all_b_decayed, axis=0) / np.sum(all_delta_list, axis=0) 141 | 142 | # Creating arrays for each epoch and an additional one for partial epoch if exists 143 | samples = a*np.minimum(samples_per_epoch * (epochs + 1), samples_seen_total) 144 | samples_1 = a*samples_per_epoch * epochs 145 | samples, samples_1 = samples / normalizer, samples_1 / normalizer 146 | samples_all = copy.deepcopy(samples) 147 | 148 | loss = (power(samples_all[0], b_effective_list[0])) 149 | 150 | samples_effective_prev = 0 151 | for epoch in range(num_epochs_full + 1): 152 | # get b_effective for this epoch 153 | b_effective = b_effective_list[epoch] 154 | # get samples for this epoch by iterating over each bucket and its corresponding \delta value 155 | samples_effective = samples_effective_prev 156 | samples_in_current_epoch = samples[epoch] - samples_1[epoch] 157 | samples_per_bucket_in_epoch = samples_in_current_epoch / len(a_list) 158 | for bucket_id in range(len(a_list)): 159 | delta = all_delta_list[bucket_id][epoch] 160 | samples_effective += samples_per_bucket_in_epoch * delta 161 | 162 | if epoch > 0: 163 | samples_ratio = samples_effective / samples_effective_prev 164 | loss *= (power(samples_ratio, b_effective)) 165 | samples_effective_prev = samples_effective 166 | 167 | return loss+d 168 | 169 | def func_effective_data(samples_seen, params, samples_per_epoch, normalizer=100_000, 
d=None): 170 | a, b, c, d = params 171 | effective_samples = get_effective_samples(samples_seen, params, samples_per_epoch, normalizer) 172 | loss = (power(effective_samples, b)) + d 173 | 174 | return loss -------------------------------------------------------------------------------- /plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | from matplotlib import pyplot as plt 3 | import numpy as np 4 | 5 | keys_of_interest = ["imagenet1k", "cifar10", "vtab/caltech101", "vtab/cifar100", "food101", "imagenet_sketch", "imagenetv2", "imagenet-a", "imagenet-o", "imagenet-r", "objectnet", "vtab/flowers", "vtab/pets", "voc2007", "vtab/resisc45", "cars", "retrieval/flickr_1k_test_image_text_retrieval", "retrieval/mscoco_2014_5k_test_image_text_retrieval"] 6 | # for these keys, make a clean names dict 7 | clean_dataset_names = { 8 | "imagenet1k": "ImageNet-1k", 9 | "cifar10": "CIFAR-10", 10 | "vtab/caltech101": "Caltech101", 11 | "vtab/cifar100": "CIFAR100", 12 | "food101": "Food101", 13 | "imagenet_sketch": "ImageNet-Sketch", 14 | "imagenetv2": "ImageNetV2", 15 | "imagenet-a": "ImageNet-A", 16 | "imagenet-o": "ImageNet-O", 17 | "imagenet-r": "ImageNet-R", 18 | "objectnet": "ObjectNet", 19 | "vtab/flowers": "Flowers", 20 | "vtab/pets": "Pets", 21 | "voc2007": "VOC2007", 22 | "vtab/resisc45": "RESISC45", 23 | "cars": "CARS", 24 | "retrieval/flickr_1k_test_image_text_retrieval": "Flickr1k", 25 | "retrieval/mscoco_2014_5k_test_image_text_retrieval": "MSCOCO", 26 | "18tasks": "Avg over 18 Tasks", 27 | } 28 | 29 | mpl.rcParams.update({ 30 | # 'text.usetex': True, # Use LaTeX for all text handling 31 | # 'font.family': 'serif', # Use serif font instead of sans-serif 32 | 'font.serif': 'Times', # Specific serif font (e.g., Times) 33 | 'axes.labelsize': 14, # Size of axis labels 34 | 'axes.titlesize': 16, # Size of title 35 | 'font.size': 14, # Size of general text 36 | 'legend.fontsize': 14, # Size of legend text 37 | 'xtick.labelsize': 14, # Size of x-tick labels 38 | 'ytick.labelsize': 12, # Size of y-tick labels 39 | 'figure.figsize': [6.4, 4.8], # Default figure size 40 | 'lines.linewidth': 1.5, # Width of lines 41 | 'lines.markersize': 6, # Size of markers 42 | 'axes.grid': True, # Enable grid by default 43 | 'grid.alpha': 0.5, # Transparency of grid 44 | 'grid.linestyle': '--', # Style of grid lines 45 | }) 46 | 47 | def plot_results(args, org_names, paths, x_vals_dict, y_vals_dict, error_vals_dict, fitted_vals_dict, a_values, b_values, c_values, d_values, samples_per_epoch): 48 | names = [org_names [i] + f"| b={b_values[i]:.2f} | $\\tau=${c_values[i]:.2f}" for i in range(len(org_names)-1)] 49 | names.append(org_names[-1] + f"\n| b={b_values[-1]:.2f} | $\\tau=${c_values[-1]:.2f}") 50 | # colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown', 'gray', 'pink', 'orange', 'purple', 'brown', 'gray'] 51 | # colors = ['#008066', '#80c066', '#ffff66', '#67000d', '#0000FF'] 52 | # colors = ['#004529', '#556b2f', '#8b0000', '#a50026', '#0000FF'] 53 | colors = ['darkgreen', 'limegreen', 'darkorange', 'peru','red', 'dodgerblue'] 54 | markers = ['o', 'x', '^', 's', 'p', '*', '+', 'D', 'v', 'h', '8', '1', '2', '3', '4', '5'] 55 | 56 | # make two subplots in the same figure 57 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4.5)) 58 | fig.subplots_adjust(wspace=0.3) 59 | 60 | # now everyting about first subplot below 61 | for i, key in enumerate(paths.keys()): 62 | data_name = key 63 | x_vals = x_vals_dict[data_name] 64 
| y_vals = y_vals_dict[data_name] 65 | error_vals = error_vals_dict[data_name] 66 | x_vals_millions = [i / 1e6 for i in x_vals] 67 | #increase line width to 3 68 | ax1.scatter(x_vals_millions, error_vals, label=names[i], color=colors[i], marker=markers[i], zorder = 10) 69 | fitted_deltas = fitted_vals_dict[data_name] 70 | fitted = fitted_deltas 71 | # make the lines ... dottes 72 | ax1.plot(x_vals_millions, fitted, color=colors[i], zorder = 1, linestyle='dotted', linewidth=3) 73 | if i == len(paths.keys()) - 1: 74 | # make thickness in legend 2 but not in plot 75 | # make legend colour green 76 | ax1.plot([], [], color="black", zorder = 0, linestyle='dotted', label="Fitted scaling curve for the pool") 77 | 78 | 79 | 80 | # ax1.ylabel("Imagenet Zero-Shot Error") 81 | data_name_clean = clean_dataset_names[args.metric] 82 | ax1.set_ylabel(f"{data_name_clean} Zero-Shot Error") 83 | # ax1.xlabel("Millions of Samples Seen") 84 | ax1.set_xlabel("Millions of Samples Seen") 85 | 86 | # make the legend location in the middle of third quadrant and and 87 | # plt.legend(title='Legend Title') 88 | legend_title = "CLIP score based data pools" 89 | if args.filtering == "tmars": 90 | legend_title = "TMARS based data pools" 91 | #columsn = 2 92 | legend = ax1.legend(bbox_to_anchor=(-0.12, 1.45), loc='upper left', borderaxespad=0., title=legend_title, ncol=2) 93 | # legend = ax1.legend(loc='upper left', borderaxespad=0.) 94 | for text in legend.get_texts(): 95 | if text.get_text() == 'Fitted scaling curve for the pool': 96 | text.set_color('black') 97 | 98 | # second subplot now 99 | for i, key in enumerate(paths.keys()): 100 | data_name = key 101 | initial_b = b_values[i] 102 | x_vals = np.arange(0, 10) 103 | y_vals = [-1*initial_b * (0.5**(j/c_values[i])) for j in x_vals] 104 | # make the lines ... dottes 105 | horizontal_line_key = "Top 10%|" 106 | if args.filtering == "clip": 107 | horizontal_line_key = "Top 10%-20%" 108 | if horizontal_line_key in names[i]: 109 | #draw two horizontal lines one at x_vals = 2, and one at x_vals = 4, but it should be from x = 0 to x = 2 or 4 only 110 | ax2.plot([0, 2], [y_vals[2], y_vals[2]], color="black", zorder = 0, linestyle='-', linewidth=1.5) 111 | ax2.plot([0, 4], [y_vals[4], y_vals[4]], color="black", zorder = 0, linestyle='-', linewidth=1.5) 112 | ax2.plot(x_vals, y_vals, color=colors[i], zorder = 1, linestyle='dotted', linewidth=2, marker=markers[i], label = names[i]) 113 | 114 | # legend = ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.) 
115 | 116 | 117 | # ax2.ylabel("Utility of b*\delta") #use latex 118 | ax2.set_ylabel("Data Utility ($b\\times\delta^{epoch}}$)") 119 | # ax2.xlabel("Number of Repetitions") 120 | ax2.set_xlabel("Number of Repetitions") 121 | plt.savefig(f"plots/{args.plot_name}.pdf", bbox_inches='tight') 122 | -------------------------------------------------------------------------------- /process_128_grid.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json, sys, os 3 | import re 4 | import numpy as np 5 | 6 | from grid_search import grid_search 7 | from plotter import plot_results 8 | import copy 9 | 10 | 11 | samples_per_step = 4096 12 | 13 | keys_of_interest = ["imagenet1k", "cifar10", "vtab/caltech101", "vtab/cifar100", "food101", "imagenet_sketch", "imagenetv2", "imagenet-a", "imagenet-o", "imagenet-r", "objectnet", "vtab/flowers", "vtab/pets", "voc2007", "vtab/resisc45", "cars", "retrieval/flickr_1k_test_image_text_retrieval", "retrieval/mscoco_2014_5k_test_image_text_retrieval"] 14 | 15 | def get_avg(json_file_path, keys_of_interest): 16 | main_metrics = [] 17 | keys_of_interest_local = copy.deepcopy(keys_of_interest) 18 | with open(json_file_path, 'r') as file: 19 | for line in file: 20 | data = json.loads(line) 21 | if data["key"] in keys_of_interest_local: 22 | main_metrics.append(data["metrics"]["main_metric"]) 23 | keys_of_interest_local.remove(data["key"]) 24 | 25 | # assert check len of main_metrics is 18 26 | assert len(main_metrics) == 18 27 | 28 | 29 | average_main_metric = sum(main_metrics) / len(main_metrics) 30 | return average_main_metric 31 | 32 | def get_accuracy_from_jsonl(args, jsonl_file): 33 | if args.metric == "18tasks": 34 | return get_avg(jsonl_file, keys_of_interest) 35 | 36 | # assert args.metric == "imagenet1k" 37 | with open(jsonl_file, 'r') as f: 38 | main_metric = 0 39 | for i,line in enumerate(f): 40 | data = json.loads(line) 41 | if data.get("key") == args.metric: 42 | main_metric = data["metrics"]["main_metric"] 43 | return main_metric 44 | 45 | def get_jsonl_files_tmars(folder_name): 46 | all_paths = [] 47 | for i in [2,4,6,8,10]: 48 | new_path_to_append = f"eval_results_epoch_{i}_step_-1.jsonl" 49 | jsonl_path = os.path.join(folder_name, new_path_to_append) 50 | all_paths.append(jsonl_path) 51 | return all_paths 52 | 53 | def get_jsonl_files_clip(folder_name): 54 | all_paths = [] 55 | for i in [2,4,6,8,10]: 56 | new_path_to_append = f"_{i}x/eval_results.jsonl" 57 | jsonl_path = str(folder_name) + new_path_to_append 58 | all_paths.append(jsonl_path) 59 | return all_paths 60 | 61 | def get_jsonl_files_random_clip_for_tau_analysis(folder_name): 62 | all_paths = [] 63 | for i in [1,2,4,8]: 64 | if "bucket1" in str(folder_name) and i == 1: 65 | continue 66 | new_path_to_append = f"_{i}x/eval_results.jsonl" 67 | jsonl_path = str(folder_name) + new_path_to_append 68 | if os.path.exists(jsonl_path): 69 | all_paths.append(jsonl_path) 70 | else: 71 | print("Path does not exist: ", jsonl_path) 72 | return all_paths 73 | 74 | def get_jsonl_files_debug(folder_name): 75 | all_paths = [] 76 | # find all jsonl files in the folder 77 | for file in os.listdir(folder_name): 78 | if file.endswith(".jsonl"): 79 | all_paths.append(os.path.join(folder_name, file)) 80 | return all_paths 81 | 82 | def get_jsonl_files(folder_name): 83 | if "tmars" in str(folder_name): 84 | return get_jsonl_files_tmars(folder_name) 85 | elif "random_clip_top50_expts/bucket" in str(folder_name): 86 | return 
get_jsonl_files_random_clip_for_tau_analysis(folder_name) 87 | elif "evals" in str(folder_name): 88 | return get_jsonl_files_debug(folder_name) 89 | else: 90 | return get_jsonl_files_clip(folder_name) 91 | 92 | def get_all_results_from_folder(args, data_name, paths, match_with_dict, samples_per_epoch_dict, subsample_every=None): 93 | folder_path = paths[data_name] 94 | match_with = match_with_dict[data_name] 95 | result_dict = {} 96 | for jsonl_file in get_jsonl_files(folder_path): 97 | if "tmars" in str(folder_path) or "evals" in str(folder_path): 98 | match = re.search(r'epoch_(\d+)_', str(jsonl_file)) 99 | else: 100 | match = re.search(r'_(\d+)x', str(jsonl_file)) 101 | if match: 102 | epoch_number = int(match.group(1)) 103 | step_number = epoch_number * samples_per_epoch_dict[data_name] / samples_per_step 104 | result_dict[step_number*samples_per_step] = get_accuracy_from_jsonl(args, jsonl_file) 105 | 106 | result_dict = {k: v for k, v in sorted(result_dict.items())} 107 | return result_dict 108 | 109 | def main(args): 110 | all_results = {} 111 | # if args.filtering == "tmars": 112 | from all_paths_128 import paths, alt_name, samples_per_epoch_dict, match_with_dict, subsample_every_dict 113 | 114 | 115 | #load objective function 116 | if args.objective == "effective_data": 117 | from objective import func_effective_data as func 118 | elif args.objective == "effective_utility": 119 | from objective import func_effective_utility as func 120 | else: 121 | print("Not implemented") 122 | 123 | for key in paths.keys(): 124 | res = (get_all_results_from_folder(args, key, paths, match_with_dict, samples_per_epoch_dict, subsample_every_dict[key])) 125 | all_results[key] = res 126 | 127 | x_vals_dict = {} 128 | error_vals_dict = {} 129 | y_vals_dict = {} 130 | 131 | for key in paths.keys(): 132 | x_vals = list(all_results[key].keys()) 133 | y_vals = list(all_results[key].values()) 134 | error_vals = [1 - y_vals[i] for i in range(len(y_vals))] 135 | x_vals_dict[key] = x_vals 136 | y_vals_dict[key] = y_vals 137 | error_vals_dict[key] = error_vals 138 | 139 | def get_params_from_data(data_name, a_upper = None, a = None, tau = None, b = None, d= None): 140 | error_vals = error_vals_dict[data_name] 141 | x_vals = x_vals_dict[data_name] 142 | samples = samples_per_epoch_dict[data_name] 143 | samples = [samples for i in range(len(x_vals))] 144 | loss, popt = grid_search(x_vals, samples, error_vals, func, a_upper = a_upper, a = a, tau = tau, b = b, d = d) 145 | return loss, popt 146 | 147 | b_values = [] 148 | a_values = [] 149 | c_values = [] 150 | 151 | 152 | names = list(alt_name.values()) 153 | fitted_vals_dict = {} 154 | loss_values = {} 155 | 156 | for i, key in enumerate(paths.keys()): 157 | if args.verbose: 158 | print("****** ", key) 159 | 160 | print(key, args.a_upper, args.a, args.tau) 161 | tau_for_bucket = args.tau 162 | loss, popt = get_params_from_data(key, a_upper = args.a_upper, a = args.a, tau = tau_for_bucket, b = args.b, d=args.d) 163 | loss_values[key] = loss 164 | 165 | b_values.append(popt[1]) 166 | a_values.append(popt[0]) 167 | c_values.append(popt[2]) 168 | x_vals = x_vals_dict[key] 169 | samples = samples_per_epoch_dict[key] 170 | 171 | fit = [] 172 | 173 | for i in range(len(x_vals)): 174 | fit.append(func(x_vals[i], popt, samples).item()) 175 | fitted_vals_dict[key] = fit 176 | 177 | print("b_values = ", b_values) 178 | print("a_values = ", a_values) 179 | print("c_values = ", c_values) 180 | print("Avergae Loss = ", np.mean(list(loss_values.values()))) 181 | 182 | 183 | if 
args.plot: 184 | plot_results(args, names, paths, x_vals_dict, y_vals_dict, error_vals_dict, fitted_vals_dict, a_values, b_values, c_values, 0.1, samples_per_step) 185 | 186 | return b_values, a_values, c_values, np.mean(list(loss_values.values())) 187 | 188 | if __name__ == "__main__": 189 | import argparse 190 | parser = argparse.ArgumentParser() 191 | parser.add_argument('--a_upper', type=float, default=None) 192 | parser.add_argument('--metric', type=str, default="imagenet1k") 193 | parser.add_argument('--a', type=float, default=None) 194 | parser.add_argument('--plot', type=int, default=0) 195 | parser.add_argument('--verbose', type=int, default=0) 196 | parser.add_argument('--objective', type=str, default="effective_utility") 197 | parser.add_argument('--plot_name', type=str, default=None) 198 | parser.add_argument('--tau', type=float, default=None) 199 | parser.add_argument('--b', type=float, default=None) 200 | parser.add_argument('--d', type=float, default=None) 201 | parser.add_argument('--k', type=float, default=1) 202 | 203 | 204 | args = parser.parse_args() 205 | main(args) --------------------------------------------------------------------------------
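A minimal usage sketch, assuming the repository root is on the Python path: `func_effective_utility` from `objective.py` takes the total number of samples seen, one bucket's `(a, b, c, d)` parameters (`c` is the repetition-decay constant referred to as tau, `d` the irreducible loss), and that bucket's samples per epoch, and returns the predicted zero-shot error. The parameter values below are placeholders, not fitted results from this repository.

```python
# Illustrative sketch: the parameter values are placeholders, not fitted results.
from objective import func_effective_utility

params = (0.022, -0.13, 6.5, 0.1)   # (a, b, c, d) in the order func_effective_utility unpacks them
samples_per_epoch = 12_800_000      # bucket size, as registered in all_paths_128.py

for samples_seen in [32_000_000, 128_000_000, 640_000_000]:
    predicted_error = func_effective_utility(samples_seen, params, samples_per_epoch)
    print(f"{samples_seen:,} samples -> predicted error {predicted_error:.4f}")
```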