├── .gitignore ├── LICENSE.md ├── README.md ├── notebooks ├── Plotting.ipynb ├── Plotting_old.ipynb ├── amazon_polarity.png ├── anthropic_hh.png ├── boolq.png ├── cosmos_qa.png └── sciq.png ├── pyproject.toml ├── setup.py ├── sweep.py ├── train_simple.py ├── train_weak_to_strong.py ├── vision ├── README.md ├── data.py ├── models.py └── run_weak_strong.py ├── weak-to-strong-setup.png └── weak_to_strong ├── __init__.py ├── common.py ├── datasets.py ├── eval.py ├── logger.py ├── loss.py ├── model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | dump 2 | *.pyc 3 | *.swp 4 | *.swo 5 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2023 OpenAI 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below. 2 | 3 | # Weak-to-strong generalization 4 | 5 | ![Our setup and how it relates to superhuman AI alignment](./weak-to-strong-setup.png) 6 | 7 | This project contains code for implementing our [paper on weak-to-strong generalization](https://cdn.openai.com/papers/weak-to-strong-generalization.pdf). 8 | 9 | The primary codebase contains a re-implementation of our weak-to-strong learning setup for binary classification tasks. The codebase contains code for fine-tuning pretrained language models, and also training against the labels from another language model. We support various losses described in the paper as well, such as the confidence auxiliary loss. 10 | 11 | The `vision` directory contains stand-alone code for weak-to-strong in the vision models setting (AlexNet -> DINO on ImageNet). 12 | 13 | ### Getting Started 14 | 15 | These instructions will get you a copy of the project up and running on your local machine for development and testing purposes. 16 | 17 | #### Installation 18 | 19 | You need to have Python installed on your machine. The project uses `pyproject.toml` to manage dependencies. 
To install the dependencies, you can use a package manager like `pip`: 20 | 21 | ``` 22 | pip install . 23 | ``` 24 | 25 | #### Running the Script 26 | 27 | The main script of the project is `sweep.py`. It can be run from the command line using the following command: 28 | ``` 29 | python sweep.py --model_sizes=gpt2,gpt2-medium 30 | ``` 31 | 32 | In addition to `--model_sizes`, `sweep.py` takes in almost all of the arguments that `train_simple.py` takes (e.g. 33 | `--batch_size`, `--n_docs`, `--n_test_docs`; see `train_simple.py` for the full list). These arguments are simply 34 | forwarded to `train_simple.py`. 35 | 36 | `sweep.py` calls `train_simple.py` in the following way: 37 | 1. First, it calls `train_simple.py` for each model size to train the ground truth models. 38 | 2. Then, for each pair of weak and strong models in `model_sizes` (where a model can be the strong model in the pair 39 | only if its index in the `model_sizes` list is >= the index of the weak model), it calls `train_simple.py` with a 40 | `--weak_model_size` argument so that the strong model is trained with the labels of the weak model. 41 | 42 | For example, the command above will run gpt2 (ground truth), gpt2-medium (ground truth), gpt2 -> gpt2, gpt2 -> gpt2-medium, and 43 | gpt2-medium -> gpt2-medium. 44 | 45 | If needed, you can also run `train_simple.py` directly. 46 | 47 | Note that `sweep.py` will not accept the arguments `--weak_model_size`, `--weak_labels_path` or `--model_size` (as opposed 48 | to `--model_sizes`, with an "s"), as choosing their values automatically is precisely the point of `sweep.py`. 49 | 50 | An example Jupyter notebook for plotting results can be found in `notebooks/Plotting.ipynb`. 51 | 52 | At the time of release, the main script was called `train_weak_to_strong.py`, but it was less usable than 53 | `sweep.py` and `train_simple.py`. It is preserved here and the old instructions are given at the end of the document. 54 | 55 | #### Expected results 56 | 57 | 58 |
59 | _Expected-results plots for two of the datasets; the corresponding images are in `notebooks/` (e.g. `amazon_polarity.png`, `anthropic_hh.png`)._ 60 | 61 | 62 | 63 | 64 |
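These plots are produced by `notebooks/Plotting.ipynb` from the `config.json` and `results_summary.json` files that `train_simple.py` writes. As a rough sketch of the layout the notebook expects (assuming the default `--results_folder /tmp/results` and `--sweep_subfolder default`; per-run folder names are generated by `get_config_foldername`, and only the files the notebook reads are shown):

```
/tmp/results/default/
├── <ground-truth run folder>/
│   ├── config.json             # run configuration, read back by the plotting notebook
│   ├── results_summary.json    # {"accuracy": ...}
│   └── weak_labels/            # saved predictions, reused as labels for weak-to-strong runs
└── <weak-to-strong run folder>/
    ├── config.json             # also records the weak model's config under "weak_model"
    └── results_summary.json
```

Set `RESULTS_PATH` in the notebook to the sweep folder (here `/tmp/results/default`) to regenerate the plots.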
65 | 66 | 67 | ### Authors 68 | 69 | - Adrien Ecoffet 70 | - Manas Joglekar 71 | - Jeffrey Wu 72 | - Jan Hendrik Kirchner 73 | - Pavel Izmailov (vision) 74 | 75 | ### License 76 | 77 | This project is licensed under the MIT License - see the LICENSE.md file for details. 78 | 79 | ### Acknowledgments 80 | 81 | - Hugging Face for their open-source transformer models 82 | 83 | ### Original single run script 84 | 85 | You can run the original training script using: 86 | ``` 87 | python train_weak_to_strong.py 88 | ``` 89 | 90 | The script accepts several command-line arguments to customize the training process. Here are some examples: 91 | 92 | ``` 93 | python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42 94 | ``` 95 | 96 | The notebook `notebooks/Plotting_old.ipynb` preserves the plotting notebook corresponding to old style training. 97 | 98 | The key difference between this style and the new `sweep.py` style is that `train_weak_to_strong.py` will always 99 | train three models: a weak model, a transfer model, and a strong model. `sweep.py` optimizes this by training 100 | a series of ground truth models (which will serve as weak and strong models) as well as a series of transfer models 101 | all in one go. This reduces training duplication and is arguably simpler. The files generated by `train_simple.py` 102 | and `sweep.py` are also simpler to use. -------------------------------------------------------------------------------- /notebooks/Plotting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "eb9a4b5a", 6 | "metadata": {}, 7 | "source": [ 8 | "# Simple Plotting\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "88c7ff9f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "RESULTS_PATH = \"../../your_sweep_path/default\"\n", 19 | "\n", 20 | "PLOT_ALL_SEEDS = False\n", 21 | "# Full sweep\n", 22 | "MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n", 23 | "# Minimal sweep\n", 24 | "# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "00ca073c", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import seaborn as sns\n", 38 | "sns.set_style('whitegrid')\n", 39 | "\n", 40 | "from IPython.display import display\n", 41 | "\n", 42 | "import os\n", 43 | "import glob\n", 44 | "import json" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "e5caa051", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "records = []\n", 55 | "for result_filename in glob.glob(os.path.join(RESULTS_PATH, \"**/results_summary.json\"), recursive=True):\n", 56 | " config_file = os.path.join(\"/\".join(result_filename.split(\"/\")[:-1]), \"config.json\")\n", 57 | " config = json.load(open(config_file, \"r\"))\n", 58 | " if config[\"model_size\"] not in MODELS_TO_PLOT:\n", 59 | " continue\n", 60 | " if 'seed' not in config:\n", 61 | " config['seed'] = 0\n", 62 | " record = config.copy()\n", 63 | " if 'weak_model' in config:\n", 64 | " for k in record['weak_model']:\n", 65 | " 
if k == 'model_size':\n", 66 | " assert record['weak_model'][k] == record['weak_model_size']\n", 67 | " record['weak_' + k] = record['weak_model'][k]\n", 68 | " del record['weak_model']\n", 69 | " record.update(json.load(open(result_filename)))\n", 70 | " records.append(record)\n", 71 | "\n", 72 | "df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'model_size'])" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "2f628577", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "datasets = df.ds_name.unique()\n", 83 | "for dataset in datasets:\n", 84 | " cur_df = df[(df.ds_name == dataset)].copy()\n", 85 | " base_accuracies = cur_df[cur_df['weak_model_size'].isna()].groupby('model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n", 86 | " base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n", 87 | " base_accuracies = base_accuracies.reset_index()\n", 88 | "\n", 89 | " cur_df['strong_model_accuracy'] = cur_df['model_size'].apply(lambda x: base_accuracy_lookup[x])\n", 90 | " cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_accuracy'] = cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_size'].apply(lambda x: base_accuracy_lookup[x])\n", 91 | "\n", 92 | " # Exclude cases where the weak model is better than the strong model from PGR calculation.\n", 93 | " valid_pgr_index = (\n", 94 | " (~cur_df['weak_model_size'].isna()) & \n", 95 | " (cur_df['weak_model_size'] != cur_df['model_size']) & \n", 96 | " (cur_df['strong_model_accuracy'] > cur_df['weak_model_accuracy'])\n", 97 | " )\n", 98 | " cur_df.loc[valid_pgr_index, 'pgr'] = (cur_df.loc[valid_pgr_index, 'accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy']) / (cur_df.loc[valid_pgr_index, 'strong_model_accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy'])\n", 99 | "\n", 100 | " cur_df.loc[cur_df['weak_model_size'].isna(), \"weak_model_size\"] = \"ground truth\"\n", 101 | "\n", 102 | " for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n", 103 | " plot_df = cur_df.copy().sort_values(['strong_model_accuracy']).sort_values(['loss'], ascending=False)\n", 104 | " if seed is not None:\n", 105 | " plot_df = plot_df[plot_df['seed'] == seed]\n", 106 | "\n", 107 | " print(f\"Dataset: {dataset} (seed: {seed})\")\n", 108 | "\n", 109 | " pgr_results = plot_df[~plot_df['pgr'].isna()].groupby(['loss']).aggregate({\"pgr\": \"median\"})\n", 110 | "\n", 111 | " palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n", 112 | " color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n", 113 | "\n", 114 | " sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n", 115 | " pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n", 116 | " plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['model_size']], rotation=90)\n", 117 | " plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n", 118 | " plt.legend(loc='upper left')\n", 119 | " suffix = \"\"\n", 120 | " if seed is not None:\n", 121 | " suffix = f\"_{seed}\"\n", 122 | " plt.savefig(f\"{dataset.replace('/', '-')}{suffix}.png\", dpi=300, bbox_inches='tight')\n", 123 | " plt.show()" 124 | ] 125 | } 126 | ], 127 
| "metadata": { 128 | "kernelspec": { 129 | "display_name": "openai", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.11.5" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 5 148 | } 149 | -------------------------------------------------------------------------------- /notebooks/Plotting_old.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "eb9a4b5a", 6 | "metadata": {}, 7 | "source": [ 8 | "# Simple Plotting\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "88c7ff9f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "RESULTS_PATH = \"../../your_sweep_results_path\"\n", 19 | "\n", 20 | "PLOT_ALL_SEEDS = False\n", 21 | "# Full sweep\n", 22 | "MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n", 23 | "# Minimal sweep\n", 24 | "# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "00ca073c", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import seaborn as sns\n", 38 | "sns.set_style('whitegrid')\n", 39 | "\n", 40 | "from IPython.display import display\n", 41 | "\n", 42 | "import os\n", 43 | "import glob\n", 44 | "import json" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "e5caa051", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "records = []\n", 55 | "all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n", 56 | "for result_folder in set(all_results_folders):\n", 57 | " config_file = os.path.join(result_folder, \"config.json\")\n", 58 | " config = json.load(open(config_file, \"r\"))\n", 59 | " if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n", 60 | " continue\n", 61 | " if 'seed' not in config:\n", 62 | " config['seed'] = 0\n", 63 | " result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n", 64 | " record = config.copy()\n", 65 | " record.update(json.load(open(config_file.replace('config.json', result_filename))))\n", 66 | " records.append(record)\n", 67 | "\n", 68 | "df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "2f628577", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "datasets = df.ds_name.unique()\n", 79 | "for dataset in datasets:\n", 80 | " cur_df = df[(df.ds_name == dataset)]\n", 81 | " base_df = pd.concat([\n", 82 | " pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n", 83 | " pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), 
\"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n", 84 | " ])\n", 85 | " base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n", 86 | " base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n", 87 | " base_accuracies = base_accuracies.reset_index()\n", 88 | " base_df.reset_index(inplace=True)\n", 89 | " base_df['weak_model_size'] = 'ground truth'\n", 90 | " base_df['loss'] = 'xent'\n", 91 | " base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n", 92 | "\n", 93 | " weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n", 94 | " weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n", 95 | " weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n", 96 | " weak_to_strong.reset_index(inplace=True)\n", 97 | " weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n", 98 | " weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n", 99 | "\n", 100 | " # Exclude cases where the weak model is better than the strong model from PGR calculation.\n", 101 | " pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n", 102 | " pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n", 103 | " pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n", 104 | " pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n", 105 | " pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n", 106 | "\n", 107 | " for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n", 108 | " plot_df = pd.concat([base_df, weak_to_strong])\n", 109 | " seed_pgr_df = pgr_df\n", 110 | " if seed is not None:\n", 111 | " plot_df = plot_df[plot_df['seed'] == seed]\n", 112 | " # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n", 113 | " # have slight differences. 
We want to average these out when filtering by seed.\n", 114 | "\n", 115 | " seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n", 116 | "\n", 117 | " if seed is not None or cur_df['seed'].nunique() == 1:\n", 118 | " plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n", 119 | "\n", 120 | " print(f\"Dataset: {dataset} (seed: {seed})\")\n", 121 | "\n", 122 | " pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n", 123 | " display(pgr_results)\n", 124 | "\n", 125 | " palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n", 126 | " color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n", 127 | "\n", 128 | " sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n", 129 | " pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n", 130 | " plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n", 131 | " plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n", 132 | " plt.legend(loc='upper left')\n", 133 | " plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n", 134 | " plt.show()" 135 | ] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "openai", 141 | "language": "python", 142 | "name": "python3" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.11.5" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 5 159 | } 160 | -------------------------------------------------------------------------------- /notebooks/amazon_polarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/amazon_polarity.png -------------------------------------------------------------------------------- /notebooks/anthropic_hh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/anthropic_hh.png -------------------------------------------------------------------------------- /notebooks/boolq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/boolq.png -------------------------------------------------------------------------------- /notebooks/cosmos_qa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/cosmos_qa.png -------------------------------------------------------------------------------- /notebooks/sciq.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/sciq.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "weak_to_strong" 7 | version = "0.0.1" 8 | authors = [ 9 | { name="OpenAI", email="generalization@openai.com" }, 10 | ] 11 | description = "Weak-to-strong generalization" 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies=[ 20 | "torch ~= 2.1", 21 | "numpy ~= 1.24", 22 | "transformers ~= 4.36", 23 | "datasets ~= 2.14", 24 | "fire ~= 0.4", 25 | "accelerate ~= 0.25", 26 | "transformers-stream-generator ~= 0.0.4", 27 | "torch_optimizer ~= 0.3", 28 | "wandb ~= 0.16.1" 29 | ] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="weak_to_strong", 5 | version="0.1", 6 | description="Weak-to-strong generalization", 7 | url="#", 8 | author="OpenAI", 9 | author_email="generalization@openai.com", 10 | packages=setuptools.find_packages(), 11 | zip_safe=False, 12 | ) 13 | -------------------------------------------------------------------------------- /sweep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from typing import List, Union 5 | 6 | import fire 7 | 8 | 9 | def main(model_sizes: Union[List[str], str], **kwargs): 10 | if isinstance(model_sizes, str): 11 | model_sizes = model_sizes.split(",") 12 | assert ( 13 | "weak_model_size" not in kwargs 14 | and "model_size" not in kwargs 15 | and "weak_labels_path" not in kwargs 16 | ), "Need to use model_sizes when using sweep.py" 17 | basic_args = [sys.executable, os.path.join(os.path.dirname(__file__), "train_simple.py")] 18 | for key, value in kwargs.items(): 19 | basic_args.extend([f"--{key}", str(value)]) 20 | 21 | print("Running ground truth models") 22 | for model_size in model_sizes: 23 | subprocess.run(basic_args + ["--model_size", model_size], check=True) 24 | 25 | print("Running transfer models") 26 | for i in range(len(model_sizes)): 27 | for j in range(i, len(model_sizes)): 28 | weak_model_size = model_sizes[i] 29 | strong_model_size = model_sizes[j] 30 | print(f"Running weak {weak_model_size} to strong {strong_model_size}") 31 | subprocess.run( 32 | basic_args 33 | + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size], 34 | check=True, 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | fire.Fire(main) 40 | -------------------------------------------------------------------------------- /train_simple.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import subprocess 5 | from typing import Dict, List, Optional 6 | 7 | import fire 8 | import numpy as np 9 | import torch 10 | from datasets import load_dataset, load_from_disk 11 | 12 | import weak_to_strong.logger as logger 13 | from weak_to_strong.common 
import get_tokenizer 14 | from weak_to_strong.datasets import (VALID_DATASETS, load_dataset, 15 | tokenize_dataset) 16 | from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss 17 | from weak_to_strong.train import ModelConfig, train_and_save_model 18 | 19 | # NOTE learning rates are not particularly tuned, work somewhat reasonably at train batch size 32 20 | MODEL_CONFIGS = [ 21 | ModelConfig( 22 | name="gpt2", 23 | default_lr=5e-5, 24 | eval_batch_size=32, 25 | ), 26 | ModelConfig( 27 | name="gpt2-medium", 28 | default_lr=5e-5, 29 | eval_batch_size=32, 30 | ), 31 | ModelConfig( 32 | name="gpt2-large", 33 | default_lr=1e-5, 34 | eval_batch_size=32, 35 | ), 36 | ModelConfig( 37 | name="gpt2-xl", 38 | default_lr=1e-5, 39 | eval_batch_size=2, 40 | gradient_checkpointing=True, 41 | # Should use model_parallel on V100s (note: ironically if you have a single V100 it should run, 42 | # but if you have multiple it won't run without model_parallel because of the overhead of data 43 | # parallel training). 44 | model_parallel=( 45 | torch.cuda.get_device_properties(0).total_memory < 35e9 46 | and torch.cuda.device_count() > 1 47 | ), 48 | ), 49 | ModelConfig( 50 | name="Qwen/Qwen-1_8B", 51 | default_lr=1e-5, 52 | eval_batch_size=2, 53 | gradient_checkpointing=True, 54 | model_parallel=( 55 | torch.cuda.get_device_properties(0).total_memory < 35e9 56 | and torch.cuda.device_count() > 1 57 | ), 58 | custom_kwargs={ 59 | "trust_remote_code": True, 60 | "bf16": torch.cuda.is_bf16_supported(), 61 | "fp32": not torch.cuda.is_bf16_supported(), 62 | "revision": "5fde88dff770a7d036847211f5d9d9705f0caa69", 63 | }, 64 | ), 65 | ModelConfig( 66 | name="Qwen/Qwen-7B", 67 | default_lr=1e-5, 68 | eval_batch_size=2, 69 | gradient_checkpointing=True, 70 | model_parallel=True, 71 | # note: you will probably not be able to run this without many gpus 72 | custom_kwargs={ 73 | "trust_remote_code": True, 74 | "bf16": torch.cuda.is_bf16_supported(), 75 | "fp32": not torch.cuda.is_bf16_supported(), 76 | "revision": "d4efd21e866b9cb3466cb65b963933f5e98016d1", 77 | }, 78 | ), 79 | ModelConfig( 80 | name="Qwen/Qwen-14B", 81 | default_lr=1e-5, 82 | eval_batch_size=2, 83 | gradient_checkpointing=True, 84 | model_parallel=True, 85 | # note: you will probably not be able to run this bf16 support and without many gpus 86 | custom_kwargs={ 87 | "trust_remote_code": True, 88 | "bf16": torch.cuda.is_bf16_supported(), 89 | "fp32": not torch.cuda.is_bf16_supported(), 90 | "revision": "8be2854218fea9054331e217fd26a06f3fd02004", 91 | }, 92 | ), 93 | ModelConfig( 94 | name="Qwen/Qwen-72B", 95 | default_lr=1e-5, 96 | eval_batch_size=1, 97 | gradient_checkpointing=True, 98 | model_parallel=True, 99 | # note: you will probably not be able to run this without bf16 support and many gpus 100 | custom_kwargs={ 101 | "trust_remote_code": True, 102 | "bf16": torch.cuda.is_bf16_supported(), 103 | "fp32": not torch.cuda.is_bf16_supported(), 104 | "revision": "fec78c0e3b3b10dd9f0ce775c34a686a3255a7d1", 105 | }, 106 | # This model is really big, save space by using adafactor. 107 | # Note that even then it will take up ~60GB per GPU on an 8-GPU machine. 
108 | default_optimizer="adafactor", 109 | ), 110 | ] 111 | MODELS_DICT: Dict[str, ModelConfig] = { 112 | model_config.name: model_config for model_config in MODEL_CONFIGS 113 | } 114 | 115 | 116 | loss_dict = { 117 | "logconf": logconf_loss_fn(), 118 | "product": product_loss_fn(), 119 | "xent": xent_loss(), 120 | } 121 | 122 | VALID_LOSSES: List[str] = list(loss_dict.keys()) 123 | 124 | 125 | def get_config_foldername(config: dict) -> str: 126 | def shorten_key(key: str) -> str: 127 | return "".join(word[0] for word in key.split("_")) 128 | 129 | def shorten_value(value) -> str: 130 | if isinstance(value, bool): 131 | return "1" if value else "0" 132 | elif isinstance(value, str): 133 | value = value.split("/")[-1] 134 | if "_" in value: 135 | return "_".join(word[:4] for word in value.split("_")) 136 | else: 137 | return value 138 | else: 139 | return str(value) 140 | 141 | return "-".join(f"{shorten_key(k)}={shorten_value(v)}" for k, v in sorted(config.items())) 142 | 143 | 144 | def main( 145 | batch_size: int = 32, 146 | max_ctx: int = 1024, 147 | ds_name: str = "sciq", 148 | loss: str = "xent", 149 | n_docs: int = 20000, 150 | n_test_docs: int = 10000, 151 | model_size: str = "gpt2", 152 | lr: Optional[float] = None, 153 | optim: Optional[str] = None, 154 | epochs: int = 2, 155 | force_retrain: bool = False, 156 | seed: int = 0, 157 | minibatch_size_per_device: Optional[float] = None, 158 | train_with_dropout: bool = False, 159 | results_folder: str = "/tmp/results", 160 | linear_probe: bool = False, 161 | lr_schedule: str = "cosine_anneal", 162 | # Note: you can pass either weak_model_size or weak_labels_path. If you pass 163 | # weak_model_size, we will guess the path to the weak labels based on the weak 164 | # model. If you pass weak_labels_path, we will use that path instead. 165 | # If you pass neither, we will train on ground truth. 166 | weak_model_size: Optional[str] = None, 167 | weak_labels_path: Optional[str] = None, 168 | sweep_subfolder: str = "default", 169 | # Set to a very large value so that by default we don't do any intermediate evals but 170 | # still do final evals (which requires eval_every to be set to a non-zero, non-None value) 171 | eval_every: int = 1000000, 172 | sync_command: Optional[str] = None, 173 | ): 174 | # this is per device! 
175 | if minibatch_size_per_device is None: 176 | minibatch_size_per_device = 1 177 | assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}" 178 | assert ( 179 | weak_model_size is None or weak_labels_path is None 180 | ), "Can't pass both weak_model_size and weak_labels_path" 181 | model_config = MODELS_DICT[model_size] 182 | 183 | use_default_lr = False 184 | if lr is None: 185 | assert ( 186 | batch_size == 32 187 | ), "Learning rates were tuned on batch size 32, you probably want to sweep LR if you are tuning batch size" 188 | lr = model_config.default_lr 189 | use_default_lr = True 190 | 191 | if optim is None: 192 | optim = model_config.default_optimizer 193 | 194 | # The commented out terms are the ones that should not change final results 195 | config = { 196 | "batch_size": batch_size, 197 | "max_ctx": max_ctx, 198 | "ds_name": ds_name, 199 | "loss": loss, 200 | "n_docs": n_docs, 201 | "n_test_docs": n_test_docs, 202 | "model_size": model_size, 203 | "lr": lr, 204 | "optim": optim, 205 | "epochs": epochs, 206 | # "force_retrain": force_retrain, 207 | "seed": seed, 208 | # "minibatch_size_per_device": minibatch_size_per_device, 209 | "train_with_dropout": train_with_dropout, 210 | # "results_folder": results_folder, 211 | "linear_probe": linear_probe, 212 | "lr_schedule": lr_schedule, 213 | "eval_every": eval_every, 214 | # "sweep_subfolder": sweep_subfolder, 215 | } 216 | 217 | if weak_model_size is not None: 218 | weak_model_config = config.copy() 219 | weak_model_config["model_size"] = weak_model_size 220 | weak_model_config["loss"] = "xent" 221 | if use_default_lr: 222 | weak_model_config["lr"] = MODELS_DICT[weak_model_size].default_lr 223 | 224 | weak_model_config_name = get_config_foldername(weak_model_config) 225 | 226 | weak_labels_path = ( 227 | results_folder + "/" + sweep_subfolder + "/" + weak_model_config_name + "/weak_labels" 228 | ) 229 | 230 | eval_batch_size = model_config.eval_batch_size 231 | random.seed(seed) 232 | 233 | # Load dataset 234 | dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs)) 235 | 236 | # Split the training dataset in half 237 | train_dataset, test_ds = dataset["train"], dataset["test"] 238 | 239 | if weak_labels_path is None: 240 | split_data = train_dataset.train_test_split(test_size=0.5, seed=seed) 241 | train1_ds, train2_ds = split_data["train"], split_data["test"] 242 | print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds)) 243 | config_name = get_config_foldername(config) 244 | else: 245 | if not weak_labels_path.endswith("weak_labels"): 246 | weak_labels_path = weak_labels_path + "/weak_labels" 247 | if sync_command is not None: 248 | sync_command_list = sync_command.split(" ") 249 | sync_command_list.extend( 250 | ["download", weak_labels_path.replace("/weak_labels", ""), results_folder] 251 | ) 252 | print(f"Running sync command: {' '.join(sync_command_list)}") 253 | result = subprocess.run(sync_command_list, check=True) 254 | if result.returncode != 0: 255 | raise RuntimeError(f"Sync command failed with return code {result.returncode}") 256 | train1_ds = load_from_disk(weak_labels_path) 257 | train2_ds = None 258 | 259 | weak_model_config = json.load(open(weak_labels_path.replace("weak_labels", "config.json"))) 260 | config["weak_model_size"] = weak_model_config["model_size"] 261 | config_name = get_config_foldername(config) 262 | config["weak_model"] = weak_model_config 263 | 264 | save_path = os.path.join(results_folder, sweep_subfolder, 
config_name) 265 | logger.configure( 266 | name="{sweep_subfolder}_{config_name}_{datetime_now}", 267 | save_path=save_path, 268 | sweep_subfolder=sweep_subfolder, 269 | config_name=config_name, 270 | ) 271 | # Tokenize datasets 272 | tokenizer = get_tokenizer(model_config.name) 273 | train1_ds = tokenize_dataset(train1_ds, tokenizer, max_ctx) 274 | test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx) 275 | if train2_ds: 276 | train2_ds = tokenize_dataset(train2_ds, tokenizer, max_ctx) 277 | 278 | loss_fn = loss_dict[loss] 279 | print(f"Training model model, size {model_size}") 280 | test_results, weak_ds = train_and_save_model( 281 | model_config, 282 | train1_ds, 283 | test_ds, 284 | inference_ds=train2_ds, 285 | batch_size=batch_size, 286 | save_path=save_path, 287 | loss_fn=loss_fn, 288 | lr=lr, 289 | epochs=epochs, 290 | force_retrain=force_retrain, 291 | eval_batch_size=eval_batch_size, 292 | minibatch_size_per_device=minibatch_size_per_device, 293 | train_with_dropout=train_with_dropout, 294 | linear_probe=linear_probe, 295 | lr_schedule=lr_schedule, 296 | optimizer_name=optim, 297 | eval_every=eval_every, 298 | ) 299 | 300 | if weak_ds is not None: 301 | weak_ds.save_to_disk(save_path + "/" + "weak_labels") 302 | 303 | acc = np.mean([x["acc"] for x in test_results]) 304 | res_dict = {"accuracy": acc} 305 | print("accuracy:", acc) 306 | 307 | with open(os.path.join(save_path, f"config.json"), "w") as f: 308 | json.dump(config, f, indent=2) 309 | 310 | with open(os.path.join(save_path, f"results_summary.json"), "w") as f: 311 | json.dump(res_dict, f, indent=2) 312 | 313 | if sync_command is not None: 314 | print("Syncing results to remote storage...") 315 | try: 316 | sync_command_list = sync_command.split(" ") 317 | sync_command_list.extend(["upload", save_path, results_folder]) 318 | print(f"Running sync command: {' '.join(sync_command_list)}") 319 | result = subprocess.run(sync_command_list, check=True) 320 | if result.returncode != 0: 321 | raise RuntimeError(f"Sync command failed with return code {result.returncode}") 322 | except Exception as e: 323 | raise RuntimeError("Failed to sync results to remote storage.") from e 324 | 325 | 326 | if __name__ == "__main__": 327 | fire.Fire(main) 328 | -------------------------------------------------------------------------------- /train_weak_to_strong.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, List, Optional, Sequence, Union 4 | 5 | import fire 6 | import numpy as np 7 | import torch 8 | 9 | import weak_to_strong.logger as logger 10 | from weak_to_strong.common import get_tokenizer 11 | from weak_to_strong.datasets import (VALID_DATASETS, load_dataset, 12 | tokenize_dataset) 13 | from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss 14 | from weak_to_strong.train import ModelConfig, train_and_save_model 15 | 16 | # NOTE learning rates are not particularly tuned, work somewhat reasonably at train batch size 32 17 | MODEL_CONFIGS = [ 18 | ModelConfig( 19 | name="gpt2", 20 | default_lr=5e-5, 21 | eval_batch_size=32, 22 | custom_kwargs={ 23 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 24 | }, 25 | ), 26 | ModelConfig( 27 | name="gpt2-medium", 28 | default_lr=5e-5, 29 | eval_batch_size=32, 30 | custom_kwargs={ 31 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 32 | }, 33 | ), 34 | ModelConfig( 35 | name="gpt2-large", 36 | default_lr=1e-5, 
37 | eval_batch_size=32, 38 | custom_kwargs={ 39 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 40 | }, 41 | ), 42 | ModelConfig( 43 | name="gpt2-xl", 44 | default_lr=1e-5, 45 | eval_batch_size=2, 46 | gradient_checkpointing=True, 47 | model_parallel=True, 48 | custom_kwargs={ 49 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 50 | }, 51 | ), 52 | ModelConfig( 53 | name="Qwen/Qwen-1_8B", 54 | default_lr=1e-5, 55 | eval_batch_size=2, 56 | gradient_checkpointing=True, 57 | model_parallel=True, 58 | custom_kwargs={ 59 | "trust_remote_code": True, 60 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 61 | }, 62 | ), 63 | ModelConfig( 64 | name="Qwen/Qwen-7B", 65 | default_lr=1e-5, 66 | eval_batch_size=2, 67 | gradient_checkpointing=True, 68 | model_parallel=True, 69 | # note: you will probably not be able to run this without many gpus 70 | custom_kwargs={ 71 | "trust_remote_code": True, 72 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 73 | }, 74 | ), 75 | ModelConfig( 76 | name="Qwen/Qwen-14B", 77 | default_lr=1e-5, 78 | eval_batch_size=2, 79 | gradient_checkpointing=True, 80 | model_parallel=True, 81 | # note: you will probably not be able to run this without bf16 support and many gpus 82 | custom_kwargs={ 83 | "trust_remote_code": True, 84 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 85 | }, 86 | ), 87 | ModelConfig( 88 | name="Qwen/Qwen-72B", 89 | default_lr=1e-5, 90 | eval_batch_size=1, 91 | gradient_checkpointing=True, 92 | model_parallel=True, 93 | # note: you will probably not be able to run this without bf16 support and many gpus 94 | custom_kwargs={ 95 | "trust_remote_code": True, 96 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32, 97 | }, 98 | # This model is really big, save space by using adafactor. 99 | # Note that even then it will take up ~60GB per GPU on an 8-GPU machine. 
100 | default_optimizer="adafactor", 101 | ), 102 | ] 103 | MODELS_DICT: Dict[str, ModelConfig] = { 104 | model_config.name: model_config for model_config in MODEL_CONFIGS 105 | } 106 | 107 | 108 | loss_dict = { 109 | "logconf": logconf_loss_fn(), 110 | "product": product_loss_fn(), 111 | "xent": xent_loss(), 112 | } 113 | 114 | VALID_LOSSES: List[str] = list(loss_dict.keys()) 115 | 116 | 117 | def main( 118 | batch_size: int = 32, 119 | max_ctx: int = 1024, 120 | ds_name: str = "sciq", 121 | transfer_loss: Union[str, Sequence[str]] = "xent,logconf", 122 | n_docs: int = 10000, 123 | n_test_docs: int = 200, 124 | weak_model_size: str = "gpt2", 125 | weak_lr: Optional[float] = None, 126 | strong_model_size: str = "gpt2-xl", 127 | strong_lr: Optional[float] = None, 128 | # Defaults to strong_lr 129 | transfer_lr: Optional[float] = None, 130 | # Optims default to default_optimizer in the model definitions 131 | weak_optim: Optional[str] = None, 132 | strong_optim: Optional[str] = None, 133 | transfer_optim: Optional[str] = None, 134 | gt_epochs: int = 2, 135 | # defaults to gt_epochs 136 | transfer_epochs: Optional[int] = None, 137 | force_retrain: bool = False, 138 | seed: int = 0, 139 | minibatch_size_per_device: Optional[int] = None, 140 | train_with_dropout: bool = False, 141 | results_folder: str = "/tmp/results", 142 | linear_probe: bool = False, 143 | lr_schedule: str = "cosine_anneal", 144 | log_prefix: str = "", 145 | # Set to an absurdly high value so we don't do intermediate evals by default. 146 | eval_every: int = 100000000, 147 | ): 148 | # this is per device! 149 | if minibatch_size_per_device is None: 150 | minibatch_size_per_device = 1 151 | assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}" 152 | if isinstance(transfer_loss, str): 153 | transfer_losses = transfer_loss.split(",") 154 | else: 155 | transfer_losses = transfer_loss 156 | del transfer_loss 157 | for tloss in transfer_losses: 158 | assert tloss in VALID_LOSSES, f"Unknown loss {tloss} not in {VALID_LOSSES}" 159 | assert ( 160 | weak_model_size in MODELS_DICT 161 | ), f"Unknown model size {weak_model_size} not in {MODELS_DICT}" 162 | weak_model_config = MODELS_DICT[weak_model_size] 163 | assert ( 164 | strong_model_size in MODELS_DICT 165 | ), f"Unknown model size {strong_model_size} not in {MODELS_DICT}" 166 | strong_model_config = MODELS_DICT[strong_model_size] 167 | 168 | if weak_lr is None: 169 | assert batch_size == 32 170 | weak_lr = weak_model_config.default_lr 171 | if strong_lr is None: 172 | assert batch_size == 32 173 | strong_lr = strong_model_config.default_lr 174 | if transfer_lr is None: 175 | transfer_lr = strong_lr 176 | if transfer_epochs is None: 177 | transfer_epochs = gt_epochs 178 | 179 | if weak_optim is None: 180 | weak_optim = weak_model_config.default_optimizer 181 | if strong_optim is None: 182 | strong_optim = strong_model_config.default_optimizer 183 | if transfer_optim is None: 184 | transfer_optim = strong_optim 185 | 186 | weak_eval_batch_size = weak_model_config.eval_batch_size 187 | strong_eval_batch_size = strong_model_config.eval_batch_size 188 | 189 | # Load dataset 190 | dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs)) 191 | 192 | # Split the training dataset in half 193 | train_dataset, test_ds = dataset["train"], dataset["test"] 194 | 195 | split_data = train_dataset.train_test_split(test_size=0.5, seed=seed) 196 | train1_ds, train2_ds = split_data["train"], split_data["test"] 197 | 
print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds)) 198 | 199 | def train_model( 200 | model_config: ModelConfig, 201 | train_ds: torch.utils.data.Dataset, 202 | test_ds: torch.utils.data.Dataset, 203 | *, 204 | loss_type: str, 205 | label: str, 206 | subpath, 207 | lr, 208 | eval_batch_size, 209 | epochs=1, 210 | inference_ds: Optional[torch.utils.data.Dataset] = None, 211 | linear_probe: bool = False, 212 | optimizer_name: str = "adam", 213 | ): 214 | save_path = os.path.join(results_folder, subpath) 215 | linprobe_str = "_linprobe" if linear_probe else "" 216 | logger.configure( 217 | name="{log_prefix}{label}_{base_model_name}_{ds_name}_{loss_type}_{optimizer_name}_{lr}_{lr_schedule}{linprobe_str}_{datetime_now}", 218 | label=label, 219 | ds_name=ds_name, 220 | truncation_max_len=n_docs or "none", 221 | loss_type=loss_type, 222 | lr=lr, 223 | batch_size=batch_size, 224 | eval_batch_size=eval_batch_size, 225 | minibatch_size_per_device=minibatch_size_per_device, 226 | save_path=save_path, 227 | base_model_name=model_config.name, 228 | epochs=epochs, 229 | linprobe_str=linprobe_str, 230 | lr_schedule=lr_schedule, 231 | log_prefix=log_prefix, 232 | optimizer_name=optimizer_name, 233 | ) 234 | # Tokenize datasets 235 | tokenizer = get_tokenizer(model_config.name) 236 | train_ds = tokenize_dataset(train_ds, tokenizer, max_ctx) 237 | test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx) 238 | if inference_ds: 239 | inference_ds = tokenize_dataset(inference_ds, tokenizer, max_ctx) 240 | 241 | loss_fn = loss_dict[loss_type] 242 | return train_and_save_model( 243 | model_config, 244 | train_ds, 245 | test_ds, 246 | inference_ds=inference_ds, 247 | batch_size=batch_size, 248 | save_path=save_path, 249 | loss_fn=loss_fn, 250 | lr=lr, 251 | epochs=epochs, 252 | force_retrain=force_retrain, 253 | eval_batch_size=eval_batch_size, 254 | minibatch_size_per_device=minibatch_size_per_device, 255 | train_with_dropout=train_with_dropout, 256 | linear_probe=linear_probe, 257 | lr_schedule=lr_schedule, 258 | optimizer_name=optimizer_name, 259 | eval_every=eval_every, 260 | ) 261 | 262 | # Train the weak model on the first half of the training data 263 | print(f"Training weak model, size {weak_model_size}") 264 | weak_test_results, weak_ds = train_model( 265 | weak_model_config, 266 | train1_ds, 267 | test_ds, 268 | loss_type="xent", 269 | label="weak", 270 | subpath=os.path.join("weak_model_gt", weak_model_size.replace("/", "_")), 271 | lr=weak_lr, 272 | eval_batch_size=weak_eval_batch_size, 273 | inference_ds=train2_ds, 274 | epochs=gt_epochs, 275 | linear_probe=linear_probe, 276 | optimizer_name=weak_optim, 277 | ) 278 | 279 | # Train the strong model on the second half of the training data 280 | print(f"Training strong model, size {strong_model_size}") 281 | strong_test_results, _ = train_model( 282 | strong_model_config, 283 | train2_ds, 284 | test_ds, 285 | loss_type="xent", 286 | label="strong", 287 | subpath=os.path.join("strong_model_gt", strong_model_size.replace("/", "_")), 288 | lr=strong_lr, 289 | eval_batch_size=strong_eval_batch_size, 290 | epochs=gt_epochs, 291 | linear_probe=linear_probe, 292 | optimizer_name=strong_optim, 293 | ) 294 | 295 | # Train the strong model on the second half of the training data with labels generated by the weak model 296 | all_transfer_test_results = {} 297 | for tloss in transfer_losses: 298 | print( 299 | f"Training transfer model, size {strong_model_size} on labels from {weak_model_size}, with loss {tloss}" 300 | ) 301 | 
transfer_test_results, _ = train_model( 302 | strong_model_config, 303 | weak_ds, 304 | test_ds, 305 | loss_type=tloss, 306 | label="weak2strong", 307 | subpath=os.path.join( 308 | "strong_model_transfer", 309 | f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}_{tloss}", 310 | ), 311 | lr=transfer_lr, 312 | eval_batch_size=strong_eval_batch_size, 313 | epochs=transfer_epochs, 314 | linear_probe=linear_probe, 315 | optimizer_name=transfer_optim, 316 | ) 317 | all_transfer_test_results[tloss] = transfer_test_results 318 | del transfer_test_results 319 | 320 | weak_acc = np.mean([x["acc"] for x in weak_test_results]) 321 | strong_acc = np.mean([x["acc"] for x in strong_test_results]) 322 | res_dict = { 323 | "weak_acc": weak_acc, 324 | "strong_acc": strong_acc, 325 | } 326 | print("weak acc:", weak_acc) 327 | print("strong acc:", strong_acc) 328 | for tloss, transfer_test_results in all_transfer_test_results.items(): 329 | transfer_acc = np.mean([x["acc"] for x in transfer_test_results]) 330 | res_dict[f"transfer_acc_{tloss}"] = transfer_acc 331 | print(f"transfer acc ({tloss}):", transfer_acc) 332 | 333 | with open( 334 | os.path.join( 335 | results_folder, 336 | f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}.results_summary.json", 337 | ), 338 | "w", 339 | ) as f: 340 | json.dump( 341 | res_dict, 342 | f, 343 | ) 344 | 345 | 346 | # python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --transfer_loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42 347 | if __name__ == "__main__": 348 | fire.Fire(main) 349 | -------------------------------------------------------------------------------- /vision/README.md: -------------------------------------------------------------------------------- 1 | # A Simple Weak-to-Strong Experiment on ImageNet 2 | 3 | We provide code for a simple weak-to-strong experiment on ImageNet. 4 | We generate the weak labels using an [AlexNet](https://pytorch.org/vision/main/models/generated/torchvision.models.alexnet.html) model pretrained on ImageNet and we use linear probes on top of [DINO](https://github.com/facebookresearch/dino) models 5 | as a strong student. 
6 | 7 | The full training command: 8 | 9 | ```bash 10 | python3 run_weak_strong.py \ 11 | --data_path <DATA_PATH> \ 12 | --weak_model_name <WEAK_MODEL> \ 13 | --strong_model_name <STRONG_MODEL> \ 14 | --batch_size <BATCH_SIZE> \ 15 | --seed <SEED> \ 16 | --n_epochs <EPOCHS> \ 17 | --lr <LR> \ 18 | --n_train <N_TRAIN> 19 | ``` 20 | Parameters: 21 | 22 | * ```DATA_PATH``` — path to the base directory containing ImageNet data, see the [torchvision page](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageNet.html) for instructions; should contain the files `ILSVRC2012_devkit_t12.tar.gz` and `ILSVRC2012_img_val.tar` 23 | * ```WEAK_MODEL``` — weak model name: 24 | - `"alexnet"` (default) is the only weak model currently implemented 25 | * ```STRONG_MODEL``` — strong model name: 26 | - `"resnet50_dino"` (default) 27 | - `"vitb8_dino"` 28 | * ```BATCH_SIZE``` — batch size for weak label generation and embedding extraction (default: `128`) 29 | * ```SEED``` — random seed for dataset shuffling (default: `0`) 30 | * ```EPOCHS``` — number of training epochs (default: `10`) 31 | * ```LR``` — initial learning rate (default: `1e-3`) 32 | * ```N_TRAIN``` — number of datapoints used to train the linear probe; the remaining `50000 - N_TRAIN` datapoints are used as the test set (default: `40000`) 33 | 34 | 35 | 36 | Example commands: 37 | 38 | ```bash 39 | # AlexNet → ResNet50 (DINO): 40 | python3 run_weak_strong.py --strong_model_name resnet50_dino --n_epochs 20 41 | 42 | # AlexNet → ViT-B/8 (DINO): 43 | python3 run_weak_strong.py --strong_model_name vitb8_dino --n_epochs 5 44 | ``` 45 | 46 | With the commands above we get the following results (note that the results may not reproduce exactly due to randomness): 47 | 48 | | Model | Top-1 Accuracy | 49 | |-------------------------|----------------| 50 | | AlexNet | 56.6 | 51 | | Dino ResNet50 | 64.5 | 52 | | Dino ViT-B/8 | 74.0 | 53 | | AlexNet → DINO ResNet50 | 61.9 | 54 | | AlexNet → DINO ViT-B/8 | 66.6 | 55 | 56 | You can add new custom models to `models.py` and new datasets to `data.py`.
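For example, another DINO backbone could be wired in roughly as follows. This is a minimal sketch, not part of the repository: `vits16_dino` is an illustrative name, and it assumes the `dino_vits16` entrypoint of the `facebookresearch/dino` torch hub repo. Any module whose `forward()` returns a single embedding tensor should work as a strong model, since only the weak model needs to return `(embedding, logits)` pairs for `get_embeddings()`.

```python
# vision/models.py — hypothetical extra strong model
import torch

def vits16_dino():
    # DINO ViT-S/16 backbone; forward() returns embeddings only,
    # just like the existing resnet50_dino / vitb8_dino helpers.
    return torch.hub.load("facebookresearch/dino:main", "dino_vits16")

# vision/run_weak_strong.py — extend the dispatch in get_model():
#     elif name == "vits16_dino":
#         model = vits16_dino()
```

It could then be selected with `python3 run_weak_strong.py --strong_model_name vits16_dino`.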
57 | -------------------------------------------------------------------------------- /vision/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | RESIZE, CROP = 256, 224 5 | TRANSFORM = torchvision.transforms.Compose( 6 | [ 7 | torchvision.transforms.Resize(RESIZE), 8 | torchvision.transforms.CenterCrop(CROP), 9 | torchvision.transforms.ToTensor(), 10 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 11 | ] 12 | ) 13 | 14 | 15 | def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM): 16 | ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform) 17 | loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size) 18 | return ds, loader 19 | -------------------------------------------------------------------------------- /vision/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | class HeadAndEmbedding(torch.nn.Module): 6 | def __init__(self, head): 7 | super(HeadAndEmbedding, self).__init__() 8 | self.head = head 9 | 10 | def forward(self, x): 11 | return x, self.head(x) 12 | 13 | 14 | def _alexnet_replace_fc(model): 15 | model.classifier = HeadAndEmbedding(model.classifier) 16 | return model 17 | 18 | 19 | def resnet50_dino(): 20 | model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50") 21 | return model 22 | 23 | 24 | def vitb8_dino(): 25 | model = torch.hub.load("facebookresearch/dino:main", "dino_vitb8") 26 | return model 27 | 28 | 29 | def alexnet(): 30 | model = torchvision.models.alexnet(pretrained=True) 31 | return _alexnet_replace_fc(model) 32 | -------------------------------------------------------------------------------- /vision/run_weak_strong.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from data import get_imagenet 6 | from models import alexnet, resnet50_dino, vitb8_dino 7 | from torch import nn 8 | 9 | 10 | def get_model(name): 11 | if name == "alexnet": 12 | model = alexnet() 13 | elif name == "resnet50_dino": 14 | model = resnet50_dino() 15 | elif name == "vitb8_dino": 16 | model = vitb8_dino() 17 | else: 18 | raise ValueError(f"Unknown model {name}") 19 | model.cuda() 20 | model.eval() 21 | model = nn.DataParallel(model) 22 | return model 23 | 24 | 25 | def get_embeddings(model, loader): 26 | all_embeddings, all_y, all_probs = [], [], [] 27 | 28 | for x, y in tqdm.tqdm(loader): 29 | output = model(x.cuda()) 30 | if len(output) == 2: 31 | embeddings, logits = output 32 | probs = torch.nn.functional.softmax(logits, dim=-1).detach().cpu() 33 | all_probs.append(probs) 34 | else: 35 | embeddings = output 36 | 37 | all_embeddings.append(embeddings.detach().cpu()) 38 | all_y.append(y) 39 | 40 | all_embeddings = torch.cat(all_embeddings, axis=0) 41 | all_y = torch.cat(all_y, axis=0) 42 | if len(all_probs) > 0: 43 | all_probs = torch.cat(all_probs, axis=0) 44 | acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean() 45 | else: 46 | all_probs = None 47 | acc = None 48 | return all_embeddings, all_y, all_probs, acc 49 | 50 | 51 | def train_logreg( 52 | x_train, 53 | y_train, 54 | eval_datasets, 55 | n_epochs=10, 56 | weight_decay=0.0, 57 | lr=1.0e-3, 58 | batch_size=100, 59 | n_classes=1000, 60 | ): 61 | x_train = x_train.float() 62 | train_ds = 
torch.utils.data.TensorDataset(x_train, y_train) 63 | train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size) 64 | 65 | d = x_train.shape[1] 66 | model = torch.nn.Linear(d, n_classes).cuda() 67 | criterion = torch.nn.CrossEntropyLoss() 68 | optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr) 69 | n_batches = len(train_loader) 70 | n_iter = n_batches * n_epochs 71 | schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter) 72 | 73 | results = {f"{key}_all": [] for key in eval_datasets.keys()} 74 | for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")): 75 | correct, total = 0, 0 76 | for x, y in train_loader: 77 | x, y = x.cuda(), y.cuda() 78 | optimizer.zero_grad() 79 | pred = model(x) 80 | loss = criterion(pred, y) 81 | loss.backward() 82 | optimizer.step() 83 | schedule.step() 84 | if len(y.shape) > 1: 85 | y = torch.argmax(y, dim=1) 86 | correct += (torch.argmax(pred, -1) == y).detach().float().sum().item() 87 | total += len(y) 88 | pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}") 89 | 90 | for key, (x_test, y_test) in eval_datasets.items(): 91 | x_test = x_test.float().cuda() 92 | pred = torch.argmax(model(x_test), axis=-1).detach().cpu() 93 | acc = (pred == y_test).float().mean() 94 | results[f"{key}_all"].append(acc) 95 | 96 | for key in eval_datasets.keys(): 97 | results[key] = results[f"{key}_all"][-1] 98 | return results 99 | 100 | 101 | def main( 102 | batch_size: int = 128, 103 | weak_model_name: str = "alexnet", 104 | strong_model_name: str = "resnet50_dino", 105 | n_train: int = 40000, 106 | seed: int = 0, 107 | data_path: str = "/root/", 108 | n_epochs: int = 10, 109 | lr: float = 1e-3, 110 | ): 111 | weak_model = get_model(weak_model_name) 112 | strong_model = get_model(strong_model_name) 113 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 114 | print("Getting weak labels...") 115 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 116 | print(f"Weak label accuracy: {weak_acc:.3f}") 117 | print("Getting strong embeddings...") 118 | embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader) 119 | assert torch.all(gt_labels == strong_gt_labels) 120 | del strong_gt_labels 121 | 122 | order = np.arange(len(embeddings)) 123 | rng = np.random.default_rng(seed) 124 | rng.shuffle(order) 125 | x = embeddings[order] 126 | y = gt_labels[order] 127 | yw = weak_labels[order] 128 | x_train, x_test = x[:n_train], x[n_train:] 129 | y_train, y_test = y[:n_train], y[n_train:] 130 | yw_train, yw_test = yw[:n_train], yw[n_train:] 131 | yw_test = torch.argmax(yw_test, dim=1) 132 | eval_datasets = {"test": (x_test, y_test), "test_weak": (x_test, yw_test)} 133 | 134 | print("Training logreg on weak labels...") 135 | results_weak = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr) 136 | print(f"Final accuracy: {results_weak['test']:.3f}") 137 | print(f"Final supervisor-student agreement: {results_weak['test_weak']:.3f}") 138 | print(f"Accuracy by epoch: {[acc.item() for acc in results_weak['test_all']]}") 139 | print( 140 | f"Supervisor-student agreement by epoch: {[acc.item() for acc in results_weak['test_weak_all']]}" 141 | ) 142 | 143 | print("Training logreg on ground truth labels...") 144 | results_gt = train_logreg(x_train, y_train, eval_datasets, n_epochs=n_epochs, lr=lr) 145 | print(f"Final accuracy: {results_gt['test']:.3f}") 146 | print(f"Accuracy by epoch: 
{[acc.item() for acc in results_gt['test_all']]}") 147 | 148 | print("\n\n" + "=" * 100) 149 | print(f"Weak label accuracy: {weak_acc:.3f}") 150 | print(f"Weak→Strong accuracy: {results_weak['test']:.3f}") 151 | print(f"Strong accuracy: {results_gt['test']:.3f}") 152 | print("=" * 100) 153 | 154 | 155 | if __name__ == "__main__": 156 | fire.Fire(main) 157 | -------------------------------------------------------------------------------- /weak-to-strong-setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/weak-to-strong-setup.png -------------------------------------------------------------------------------- /weak_to_strong/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/weak_to_strong/__init__.py -------------------------------------------------------------------------------- /weak_to_strong/common.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | from transformers import AutoTokenizer 5 | 6 | 7 | def get_tokenizer(model_name: str): 8 | """ 9 | This function returns a tokenizer based on the model name. 10 | 11 | Parameters: 12 | model_name: The name of the model for which the tokenizer is needed. 13 | 14 | Returns: 15 | A tokenizer for the specified model. 16 | """ 17 | return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 18 | 19 | 20 | def clear_mem(verbose: bool = False): 21 | """ 22 | This function is used to clear the memory allocated by PyTorch. 23 | It does so by calling the garbage collector to release unused GPU memory. 24 | After clearing the memory, it prints the current amount of memory still allocated by PyTorch (post-clean). 25 | 26 | Parameters: 27 | verbose (bool): Whether to print additional information. 28 | """ 29 | 30 | gc.collect() 31 | torch.cuda.empty_cache() 32 | print( 33 | f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f}GB" 34 | ) 35 | 36 | if verbose: 37 | 38 | def try_attr(x, a): 39 | try: 40 | return getattr(x, a) 41 | except: 42 | # amazing that this can cause... 
43 | # (AttributeError, OSError, AssertionError, RuntimeError, ModuleNotFoundError) 44 | return None 45 | 46 | for obj in gc.get_objects(): 47 | if torch.is_tensor(obj) or torch.is_tensor(try_attr(obj, "data")): 48 | print(type(obj), obj.size(), obj.dtype) 49 | -------------------------------------------------------------------------------- /weak_to_strong/datasets.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from random import Random 4 | from typing import Any, Callable, Optional 5 | 6 | from datasets import Dataset as HfDataset 7 | from datasets import load_dataset as hf_load_dataset 8 | 9 | 10 | @dataclass 11 | class DatasetConfig: 12 | # split -> unshuffled dataset of items 13 | loader: Callable[[str], HfDataset] 14 | # formats items to have keys 'txt' and 'hard_label', takes a random.Random rng 15 | formatter: Callable[[Any], Any] 16 | 17 | 18 | # mapping from dataset name to load function and format function 19 | _REGISTRY: dict[str, DatasetConfig] = {} 20 | 21 | 22 | def register_dataset(name: str, config: DatasetConfig): 23 | _REGISTRY[name] = config 24 | 25 | 26 | def load_dataset(ds_name: str, seed: int = 0, split_sizes: Optional[dict] = None): 27 | if split_sizes is None: 28 | split_sizes = dict(train=None, test=None) 29 | 30 | if ds_name not in _REGISTRY: 31 | raise ValueError(f"Unknown dataset {ds_name}, please register") 32 | cfg = _REGISTRY[ds_name] 33 | results = {} 34 | for split, n_docs in split_sizes.items(): 35 | ds = cfg.loader(split) 36 | try: 37 | ds = ds.select(range(n_docs)) if n_docs else ds  # n_docs=None keeps the full split 38 | except IndexError as e: 39 | print(f"Warning: {ds_name} has fewer than {n_docs} docs, using all: {e}") 40 | ds = ds.map(functools.partial(cfg.formatter, rng=Random(seed))) 41 | ds = ds.map( 42 | lambda ex: {"soft_label": [1 - float(ex["hard_label"]), float(ex["hard_label"])]} 43 | ) 44 | ds = ds.shuffle(seed=seed)  # shuffling is largely pointless for the test set, but harmless 45 | results[split] = ds 46 | return results 47 | 48 | 49 | def tokenize_dataset( 50 | raw_ds: HfDataset, 51 | tokenizer: Callable, 52 | max_ctx: int, 53 | ): 54 | """ 55 | This function prepares the dataset for training. It takes the raw dataset, a tokenizer, 56 | and a maximum context length, then tokenizes each example and drops those that do not fit in max_ctx tokens. 57 | 58 | Parameters: 59 | raw_ds: The raw dataset to be processed. 60 | tokenizer: The tokenizer to be used on the formatted dataset. 61 | max_ctx: The maximum context length for the tokenizer. 62 | 63 | Returns: 64 | ds: The tokenized dataset, filtered to examples shorter than max_ctx tokens.
65 | """ 66 | 67 | def process_function(res): 68 | toks = tokenizer(res["txt"]) 69 | return dict( 70 | input_ids=toks["input_ids"], 71 | ) 72 | 73 | ds = raw_ds.map(process_function, batched=False).filter(lambda x: len(x["input_ids"]) < max_ctx) 74 | return ds 75 | 76 | 77 | def hf_loader(*hf_name, split_names=None): 78 | if split_names is None: 79 | split_names = dict() 80 | return lambda split: hf_load_dataset(*hf_name, split=split_names.get(split, split)) 81 | 82 | 83 | ########## 84 | # ACTUAL DATASETS 85 | ########## 86 | 87 | 88 | def format_amazon_polarity(ex, rng): 89 | return dict(txt=f"{ex['title']} {ex['content']}", hard_label=ex["label"]) 90 | 91 | 92 | register_dataset( 93 | "amazon_polarity", 94 | DatasetConfig(loader=hf_loader("amazon_polarity"), formatter=format_amazon_polarity), 95 | ) 96 | 97 | 98 | def format_sciq(ex, rng): 99 | hard_label = int(rng.random() < 0.5) 100 | if hard_label: 101 | ans = ex["correct_answer"] 102 | else: 103 | ans = rng.choice([ex["distractor1"], ex["distractor2"], ex["distractor3"]]) 104 | txt = f"Q: {ex['question']} A: {ans}" 105 | return dict(txt=txt, hard_label=hard_label) 106 | 107 | 108 | register_dataset( 109 | "sciq", 110 | DatasetConfig(loader=hf_loader("sciq"), formatter=format_sciq), 111 | ) 112 | 113 | 114 | def format_anthropic_hh(ex, rng): 115 | hard_label = int(rng.random() < 0.5) 116 | txt = ex["chosen"] if hard_label else ex["rejected"] 117 | return dict(txt=txt, hard_label=hard_label) 118 | 119 | 120 | register_dataset( 121 | "anthropic_hh", 122 | DatasetConfig(loader=hf_loader("Anthropic/hh-rlhf"), formatter=format_anthropic_hh), 123 | ) 124 | 125 | 126 | def format_cosmosqa(ex, rng): 127 | true_answer = ex["answer" + str(ex["label"])] 128 | if "None of the above choices ." in true_answer: 129 | hard_label = 0 130 | else: 131 | assert "None of the above choices" not in true_answer, true_answer 132 | hard_label = int(rng.random() < 0.5) 133 | if hard_label: 134 | answer = true_answer 135 | else: 136 | candidate_answers = [ex["answer" + str(i)] for i in range(4)] 137 | answer = rng.choice([x for x in candidate_answers if x != true_answer]) 138 | txt = f"Context: {ex['context']}\nQuestion: {ex['question']}\nAnswer: {answer}" 139 | return dict(txt=txt, hard_label=hard_label) 140 | 141 | 142 | register_dataset( 143 | "cosmos_qa", 144 | DatasetConfig( 145 | loader=hf_loader("cosmos_qa", split_names=dict(test="validation")), 146 | formatter=format_cosmosqa, 147 | ), 148 | ) 149 | 150 | 151 | def format_boolq(ex, rng): 152 | hard_label = int(ex["answer"]) 153 | txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}" 154 | return dict(txt=txt, hard_label=hard_label) 155 | 156 | 157 | register_dataset( 158 | "boolq", 159 | DatasetConfig( 160 | loader=hf_loader("boolq", split_names=dict(test="validation")), formatter=format_boolq 161 | ), 162 | ) 163 | 164 | 165 | VALID_DATASETS: list[str] = list(_REGISTRY.keys()) 166 | 167 | """ 168 | from datasets import disable_caching 169 | disable_caching() 170 | 171 | from weak_to_strong.datasets import load_dataset, VALID_DATASETS 172 | import numpy as np 173 | 174 | ds_name = "boolq" 175 | print(VALID_DATASETS) 176 | 177 | ds = load_dataset(ds_name, split_sizes=dict(train=500, test=10)) 178 | train = list(ds['train']) 179 | test = list(ds['test']) 180 | print(test[0]) 181 | print(np.mean([x['hard_label'] for x in train])) 182 | """ 183 | -------------------------------------------------------------------------------- /weak_to_strong/eval.py: 
-------------------------------------------------------------------------------- 1 | import datasets 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def to_batch(x, batch_size): 8 | for i in range(0, len(x), batch_size): 9 | yield x[i : i + batch_size] 10 | 11 | 12 | def unpack(x): 13 | assert isinstance(x, torch.Tensor), type(x) 14 | return x.detach().float().cpu().numpy().tolist() 15 | 16 | 17 | def eval_model_acc(model: nn.Module, ds: datasets.Dataset, eval_batch_size: int = 16) -> datasets.Dataset: 18 | """ 19 | This function evaluates the accuracy of a given model on a given dataset. 20 | 21 | Parameters: 22 | model (nn.Module): The model to be evaluated. 23 | ds (datasets.Dataset): The dataset on which the model is to be evaluated. 24 | eval_batch_size (int): The batch size to use for the forward passes. 25 | Returns: 26 | results (datasets.Dataset): A dataset of dictionaries containing the txt, input_ids, ground truth label, predicted label, 27 | accuracy of prediction, logits and soft label for each example in the dataset. 28 | """ 29 | 30 | model.eval() 31 | 32 | with torch.no_grad(): 33 | results = [] 34 | # iterate over the dataset in minibatches 35 | for batch in to_batch(ds, eval_batch_size): 36 | # pad input_ids to common length 37 | input_ids = torch.nn.utils.rnn.pad_sequence( 38 | [torch.tensor(ex) for ex in batch["input_ids"]], batch_first=True 39 | ).to(model.device if hasattr(model, "device") else "cpu") 40 | labels = batch["soft_label"] 41 | # run forward pass 42 | raw_logits = model(input_ids) 43 | 44 | probs = unpack(torch.nn.functional.softmax(raw_logits, dim=-1)) 45 | logits = unpack(raw_logits) 46 | 47 | preds = np.argmax(probs, axis=-1) 48 | labels = np.argmax(labels, axis=-1) 49 | 50 | results.extend( 51 | [ 52 | dict( 53 | txt=txt, 54 | input_ids=input_id, 55 | gt_label=label, 56 | hard_label=pred, 57 | acc=label == pred, 58 | logits=logit, 59 | soft_label=prob, 60 | ) 61 | for input_id, txt, label, pred, prob, logit in zip( 62 | batch["input_ids"], batch["txt"], labels, preds, probs, logits 63 | ) 64 | ] 65 | ) 66 | accs = [r["acc"] for r in results] 67 | print("Accuracy:", np.mean(accs), "+/-", np.std(accs) / np.sqrt(len(accs))) 68 | 69 | return datasets.Dataset.from_list(results) 70 | -------------------------------------------------------------------------------- /weak_to_strong/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | 5 | import wandb 6 | 7 | 8 | def append_to_jsonl(path: str, data: dict): 9 | with open(path, "a") as f: 10 | f.write(json.dumps(data) + "\n") 11 | 12 | 13 | class WandbLogger(object): 14 | CURRENT = None 15 | 16 | log_path = None 17 | 18 | def __init__( 19 | self, 20 | **kwargs, 21 | ): 22 | project = os.environ.get("WANDB_PROJECT") 23 | self.use_wandb = project is not None 24 | if self.use_wandb: 25 | wandb.init( 26 | config=kwargs, 27 | project=project, 28 | name=kwargs["name"].format( 29 | **kwargs, datetime_now=datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 30 | ) 31 | if "name" in kwargs 32 | else None, 33 | ) 34 | if "save_path" in kwargs: 35 | self.log_path = os.path.join(kwargs["save_path"], "log.jsonl") 36 | if not os.path.exists(kwargs["save_path"]): 37 | os.makedirs(kwargs["save_path"]) 38 | self._log_dict = {} 39 | 40 | def logkv(self, key, value): 41 | self._log_dict[key] = value 42 | 43 | def logkvs(self, d): 44 | self._log_dict.update(d) 45 | 46 | def dumpkvs(self): 47 | if self.use_wandb: 48 | wandb.log(self._log_dict) 49 | if self.log_path is not None: 50 | append_to_jsonl(self.log_path, self._log_dict) 51 | 
self._log_dict = {} 52 | 53 | def shutdown(self): 54 | if self.use_wandb: 55 | wandb.finish() 56 | 57 | 58 | def is_configured(): 59 | return WandbLogger.CURRENT is not None 60 | 61 | 62 | def get_current(): 63 | assert is_configured(), "WandbLogger is not configured" 64 | return WandbLogger.CURRENT 65 | 66 | 67 | def configure(**kwargs): 68 | if is_configured(): 69 | WandbLogger.CURRENT.shutdown() 70 | WandbLogger.CURRENT = WandbLogger(**kwargs) 71 | return WandbLogger.CURRENT 72 | 73 | 74 | def logkv(key, value): 75 | assert is_configured(), "WandbLogger is not configured" 76 | WandbLogger.CURRENT.logkv(key, value) 77 | 78 | 79 | def logkvs(d): 80 | assert is_configured(), "WandbLogger is not configured" 81 | WandbLogger.CURRENT.logkvs(d) 82 | 83 | 84 | def dumpkvs(): 85 | assert is_configured(), "WandbLogger is not configured" 86 | WandbLogger.CURRENT.dumpkvs() 87 | 88 | 89 | def shutdown(): 90 | assert is_configured(), "WandbLogger is not configured" 91 | WandbLogger.CURRENT.shutdown() 92 | WandbLogger.CURRENT = None 93 | -------------------------------------------------------------------------------- /weak_to_strong/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LossFnBase: 5 | def __call__( 6 | self, 7 | logits: torch.Tensor, 8 | labels: torch.Tensor, 9 | **kwargs, 10 | ) -> torch.Tensor: 11 | """ 12 | This function calculates the loss between logits and labels. 13 | """ 14 | raise NotImplementedError 15 | 16 | 17 | # Custom loss function 18 | class xent_loss(LossFnBase): 19 | def __call__( 20 | self, logits: torch.Tensor, labels: torch.Tensor, step_frac: float 21 | ) -> torch.Tensor: 22 | """ 23 | This function calculates the cross entropy loss between logits and labels. 24 | 25 | Parameters: 26 | logits: The predicted values. 27 | labels: The actual values. 28 | step_frac: The fraction of total training steps completed. 29 | 30 | Returns: 31 | The mean of the cross entropy loss. 32 | """ 33 | loss = torch.nn.functional.cross_entropy(logits, labels) 34 | return loss.mean() 35 | 36 | 37 | class product_loss_fn(LossFnBase): 38 | """ 39 | This class defines a custom loss function for product of predictions and labels. 40 | 41 | Attributes: 42 | alpha: A float indicating how much to weigh the weak model. 43 | beta: A float indicating how much to weigh the strong model. 44 | warmup_frac: A float indicating the fraction of total training steps for warmup. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | alpha: float = 1.0, # how much to weigh the weak model 50 | beta: float = 1.0, # how much to weigh the strong model 51 | warmup_frac: float = 0.1, # in terms of fraction of total training steps 52 | ): 53 | self.alpha = alpha 54 | self.beta = beta 55 | self.warmup_frac = warmup_frac 56 | 57 | def __call__( 58 | self, 59 | logits: torch.Tensor, 60 | labels: torch.Tensor, 61 | step_frac: float, 62 | ) -> torch.Tensor: 63 | preds = torch.softmax(logits, dim=-1) 64 | target = torch.pow(preds, self.beta) * torch.pow(labels, self.alpha) 65 | target /= target.sum(dim=-1, keepdim=True) 66 | target = target.detach() 67 | loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") 68 | return loss.mean() 69 | 70 | 71 | class logconf_loss_fn(LossFnBase): 72 | """ 73 | This class defines a custom loss function for log confidence. 74 | 75 | Attributes: 76 | aux_coef: A float indicating the auxiliary coefficient. 77 | warmup_frac: A float indicating the fraction of total training steps for warmup. 
78 | """ 79 | 80 | def __init__( 81 | self, 82 | aux_coef: float = 0.5, 83 | warmup_frac: float = 0.1, # in terms of fraction of total training steps 84 | ): 85 | self.aux_coef = aux_coef 86 | self.warmup_frac = warmup_frac 87 | 88 | def __call__( 89 | self, 90 | logits: torch.Tensor, 91 | labels: torch.Tensor, 92 | step_frac: float, 93 | ) -> torch.Tensor: 94 | logits = logits.float() 95 | labels = labels.float() 96 | coef = 1.0 if step_frac > self.warmup_frac else step_frac 97 | coef = coef * self.aux_coef 98 | preds = torch.softmax(logits, dim=-1) 99 | mean_weak = torch.mean(labels, dim=0) 100 | assert mean_weak.shape == (2,) 101 | threshold = torch.quantile(preds[:, 0], mean_weak[1]) 102 | strong_preds = torch.cat( 103 | [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]], 104 | dim=1, 105 | ) 106 | target = labels * (1 - coef) + strong_preds.detach() * coef 107 | loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") 108 | return loss.mean() 109 | -------------------------------------------------------------------------------- /weak_to_strong/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel 5 | 6 | 7 | @dataclass 8 | class HeadOutput: 9 | logits: torch.FloatTensor 10 | 11 | 12 | class TransformerWithHead(PreTrainedModel): 13 | """ 14 | This class initializes the linear head to zeros 15 | """ 16 | 17 | def __init__(self, name, linear_probe=False, **kwargs): 18 | config = AutoConfig.from_pretrained(name, **kwargs) 19 | super().__init__(config) 20 | self.num_labels = config.num_labels 21 | lm = AutoModelForCausalLM.from_pretrained(name, **kwargs) 22 | self.lm = lm 23 | self.transformer = lm.transformer 24 | hidden_size = getattr(config, "n_embd", getattr(config, "hidden_size", None)) 25 | self.score = torch.nn.Linear(hidden_size, self.num_labels, bias=False).to( 26 | lm.lm_head.weight.dtype 27 | ) 28 | torch.nn.init.normal_(self.score.weight, std=0.0) 29 | self.linear_probe = linear_probe 30 | 31 | @classmethod 32 | def from_pretrained(cls, name, **kwargs): 33 | return cls(name, **kwargs) 34 | 35 | def gradient_checkpointing_enable(self): 36 | model = self.transformer 37 | ( 38 | model if hasattr(model, "save_pretrained") else model.module 39 | ).gradient_checkpointing_enable() 40 | 41 | def forward(self, input_ids: torch.LongTensor): 42 | """ 43 | Forward pass of the model with a linear head. 44 | 45 | Parameters: 46 | input_ids (torch.LongTensor): Input tensor containing the token ids. 47 | 48 | Returns: 49 | HeadOutput: Output dataclass containing the logits. 
50 | """ 51 | input_lens = (input_ids != 0).sum(dim=-1) 52 | transformer_outputs = self.transformer(input_ids) 53 | hidden_states = torch.stack( 54 | [transformer_outputs[0][i, input_lens[i] - 1, :] for i in range(len(input_lens))] 55 | ) 56 | self.score.to(hidden_states.device) 57 | if self.linear_probe: 58 | hidden_states = hidden_states.detach() 59 | logits = self.score(hidden_states) 60 | return logits 61 | -------------------------------------------------------------------------------- /weak_to_strong/train.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import pickle 4 | import time 5 | from dataclasses import dataclass 6 | from typing import Callable, Optional 7 | 8 | import datasets 9 | import numpy as np 10 | import torch 11 | import torch_optimizer as toptim 12 | from transformers.modeling_utils import load_sharded_checkpoint 13 | 14 | import weak_to_strong.logger as logger 15 | from weak_to_strong.common import clear_mem 16 | from weak_to_strong.eval import eval_model_acc 17 | from weak_to_strong.loss import xent_loss 18 | from weak_to_strong.model import TransformerWithHead 19 | 20 | 21 | @dataclass 22 | class ModelConfig: 23 | name: str 24 | default_lr: float 25 | eval_batch_size: int 26 | custom_kwargs: Optional[dict] = None 27 | gradient_checkpointing: bool = False 28 | model_parallel: bool = False 29 | default_optimizer: str = "adam" 30 | 31 | 32 | def train_model( 33 | model: torch.nn.Module, 34 | ds: datasets.Dataset, 35 | batch_size: int, 36 | lr: float = 1e-5, 37 | loss_fn: Callable = xent_loss, 38 | log_every: int = 10, 39 | eval_every: int = 100, 40 | eval_batch_size: int = 256, 41 | minibatch_size: int = 8, 42 | eval_ds: Optional[datasets.Dataset] = None, 43 | gradient_checkpointing: bool = False, 44 | train_with_dropout: bool = False, 45 | epochs: int = 1, 46 | lr_schedule: str = "cosine_anneal", 47 | optimizer_name: str = "adam", 48 | ): 49 | print("LR", lr, "batch_size", batch_size, "minibatch_size", minibatch_size) 50 | assert batch_size % minibatch_size == 0, "batch size must be divisible by minibatch size" 51 | # we purposefully turn off dropout, for determinism 52 | # this seems to help for 1 epoch finetuning anyways 53 | if train_with_dropout: 54 | model.train() 55 | else: 56 | model.eval() 57 | if gradient_checkpointing: 58 | ( 59 | model if hasattr(model, "gradient_checkpointing_enable") else model.module 60 | ).gradient_checkpointing_enable() 61 | 62 | nsteps = len(ds) * epochs // batch_size 63 | 64 | def lr_schedule_fn(step): 65 | if lr_schedule == "constant": 66 | return 1 67 | else: 68 | assert False, f"invalid lr schedule, {lr_schedule}, must be constant or cosine_anneal" 69 | 70 | if optimizer_name.lower() == "adam": 71 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 72 | elif optimizer_name.lower() == "adafactor": 73 | optimizer = toptim.Adafactor(model.parameters(), lr=lr) 74 | else: 75 | assert False, f"invalid optimizer {optimizer_name}, must be adam or adafactor" 76 | if lr_schedule == "cosine_anneal": 77 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, nsteps) 78 | else: 79 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule_fn) 80 | step = 0 81 | it = itertools.chain.from_iterable(itertools.repeat(ds, epochs)) 82 | losses = [] 83 | accuracies = [] 84 | eval_acc_dict = {} 85 | 86 | # If the model is wrapped by DataParallel, it doesn't have a device. In this case, 87 | # we use GPU 0 as the output device. 
This sadly means that this device will store 88 | # a bit more data than other ones, but hopefully should not be too big of a deal. 89 | io_device = model.device if hasattr(model, "device") else 0 90 | 91 | while step < nsteps: 92 | loss_tot = 0 93 | if eval_every and (step + 1) % eval_every == 0: 94 | eval_results = eval_model_acc(model, eval_ds, eval_batch_size) 95 | if gradient_checkpointing: 96 | ( 97 | model if hasattr(model, "gradient_checkpointing_enable") else model.module 98 | ).gradient_checkpointing_enable() 99 | if train_with_dropout: 100 | model.train() 101 | eval_accs = np.mean([r["acc"] for r in eval_results]) 102 | eval_acc_dict[step] = eval_accs 103 | logger.logkv("eval_accuracy", eval_accs) 104 | all_logits = [] 105 | all_labels = [] 106 | for i in range(batch_size // minibatch_size): 107 | try: 108 | mbatch = [next(it) for _ in range(minibatch_size)] 109 | except StopIteration: 110 | break 111 | input_ids = ( 112 | torch.nn.utils.rnn.pad_sequence([torch.tensor(ex["input_ids"]) for ex in mbatch]) 113 | .transpose( 114 | 0, 115 | 1, 116 | ) 117 | .to(io_device) 118 | ) 119 | labels = torch.tensor([ex["soft_label"] for ex in mbatch]).to(io_device) 120 | 121 | logits = model(input_ids) 122 | 123 | all_logits.extend(logits.to(io_device)) 124 | all_labels.extend(labels) 125 | all_logits = torch.stack(all_logits) 126 | all_labels = torch.stack(all_labels) 127 | loss = loss_fn(all_logits, all_labels, step_frac=step / nsteps) 128 | loss_tot += loss.item() 129 | loss.backward() 130 | losses.append(loss_tot) 131 | accuracies.append( 132 | torch.mean( 133 | (torch.argmax(all_logits, dim=1) == torch.argmax(all_labels, dim=1)).to( 134 | torch.float32 135 | ) 136 | ).item() 137 | ) 138 | logger.logkvs( 139 | { 140 | "step": step, 141 | "progress": step / nsteps, 142 | "loss": loss_tot, 143 | "train_accuracy": accuracies[-1], 144 | "lr": lr_scheduler.get_last_lr()[0], 145 | } 146 | ) 147 | optimizer.step() 148 | optimizer.zero_grad() 149 | lr_scheduler.step() 150 | if log_every and step % log_every == 0: 151 | print( 152 | f"Step: {step}/{nsteps} Recent losses: {np.mean(losses)} {np.mean(accuracies)} {len(losses)}" 153 | ) 154 | losses = [] 155 | accuracies = [] 156 | step += 1 157 | logger.dumpkvs() 158 | final_eval_results = None 159 | if eval_every: 160 | print("Final evaluation:") 161 | final_eval_results = eval_model_acc(model, eval_ds, eval_batch_size) 162 | logger.logkv("eval_accuracy", np.mean([r["acc"] for r in final_eval_results])) 163 | logger.dumpkvs() 164 | return final_eval_results 165 | 166 | 167 | def train_and_save_model( 168 | model_config: ModelConfig, 169 | train_ds: datasets.Dataset, 170 | test_ds: datasets.Dataset, 171 | inference_ds: Optional[datasets.Dataset] = None, 172 | *, 173 | batch_size: int, 174 | lr: float, 175 | epochs: int, 176 | eval_batch_size: Optional[int] = None, 177 | minibatch_size_per_device: Optional[int] = None, 178 | save_path: Optional[str] = None, 179 | loss_fn: Callable = xent_loss, 180 | label: str = "default", 181 | force_retrain: bool = False, 182 | train_with_dropout: bool = False, 183 | linear_probe: bool = False, 184 | lr_schedule: str = "constant", 185 | optimizer_name: str = "adam", 186 | eval_every: Optional[int] = None, 187 | ): 188 | if eval_batch_size is None: 189 | eval_batch_size = batch_size 190 | 191 | if minibatch_size_per_device is None: 192 | minibatch_size_per_device = 1 193 | 194 | gradient_checkpointing = model_config.gradient_checkpointing 195 | custom_kwargs = model_config.custom_kwargs or {} 196 | 197 | def 
maybe_load_model(model): 198 | if os.path.exists(os.path.join(save_path, "results.pkl")) and not force_retrain: 199 | print("loading from", save_path) 200 | checkpoint_path = os.path.join(save_path, "pytorch_model.bin") 201 | if not os.path.exists(checkpoint_path): 202 | # Assume this means we have a sharded checkpoint, and load it appropriately 203 | load_sharded_checkpoint(model, save_path)  # sharded checkpoints live in the save directory, not in the missing .bin file 204 | else: 205 | state_dict = torch.load(os.path.join(save_path, "pytorch_model.bin")) 206 | state_dict = { 207 | k.replace("transformer.module", "transformer"): v 208 | for (k, v) in state_dict.items() 209 | } 210 | custom_kwargs["state_dict"] = state_dict 211 | return True 212 | return False 213 | 214 | already_trained = False 215 | # Load the model 216 | if model_config.model_parallel: 217 | assert torch.cuda.device_count() > 1, f"you might want more gpus for {model_config.name}" 218 | model = TransformerWithHead.from_pretrained( 219 | model_config.name, 220 | num_labels=2, 221 | device_map="auto", 222 | linear_probe=linear_probe, 223 | **custom_kwargs, 224 | ) 225 | already_trained = maybe_load_model(model) 226 | # slight misnomer, more like minibatch_size_per_dp_replica 227 | minibatch_size = minibatch_size_per_device 228 | else: 229 | model = TransformerWithHead.from_pretrained( 230 | model_config.name, num_labels=2, linear_probe=linear_probe, **custom_kwargs 231 | ).to("cuda") 232 | already_trained = maybe_load_model(model) 233 | # data parallel: currently not supported with model parallel 234 | 235 | minibatch_size = min(minibatch_size_per_device * torch.cuda.device_count(), batch_size) 236 | 237 | if torch.cuda.device_count() > 1: 238 | model = torch.nn.DataParallel(model, output_device=0) 239 | print( 240 | "Using", 241 | torch.cuda.device_count(), 242 | "GPUs, setting minibatch_size to", 243 | minibatch_size, 244 | ) 245 | else: 246 | minibatch_size = minibatch_size_per_device 247 | 248 | if already_trained: 249 | test_results = eval_model_acc(model, test_ds, eval_batch_size) 250 | else: 251 | start = time.time() 252 | test_results = train_model( 253 | model, 254 | train_ds, 255 | batch_size, 256 | lr=lr, 257 | epochs=epochs, 258 | eval_ds=test_ds, 259 | gradient_checkpointing=gradient_checkpointing, 260 | loss_fn=loss_fn, 261 | eval_batch_size=eval_batch_size, 262 | eval_every=eval_every, 263 | minibatch_size=minibatch_size, 264 | train_with_dropout=train_with_dropout, 265 | lr_schedule=lr_schedule, 266 | optimizer_name=optimizer_name, 267 | ) 268 | print("Model training took", time.time() - start, "seconds") 269 | if save_path: 270 | # Note: If the model is wrapped by DataParallel, we need to unwrap it before saving 271 | (model if hasattr(model, "save_pretrained") else model.module).save_pretrained( 272 | save_path 273 | ) 274 | print("saved", save_path) 275 | 276 | inference_results = None 277 | if inference_ds: 278 | inference_results = eval_model_acc(model, inference_ds, eval_batch_size) 279 | logger.logkv("inference_accuracy", np.mean([r["acc"] for r in inference_results])) 280 | 281 | if save_path: 282 | with open(os.path.join(save_path, "results.pkl"), "wb") as f: 283 | pickle.dump( 284 | { 285 | "avg_acc_test": float(np.mean([r["acc"] for r in test_results])), 286 | "avg_acc_inference": float( 287 | np.mean([r["acc"] for r in inference_results] if inference_results else []) 288 | ), 289 | "test_results": test_results, 290 | "inference_results": inference_results if inference_results else [], 291 | }, 292 | f, 293 | ) 294 | # try to clean up memory 295 | clear_mem() 296 | 
logger.shutdown() 297 | 298 | return test_results, inference_results 299 | --------------------------------------------------------------------------------
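
The following is a minimal sketch of how the `weak_to_strong` modules above fit together when called directly from Python rather than through `train_simple.py` or `sweep.py`. It assumes a CUDA GPU and the package installed via `pip install .`; the dataset name, model size, save path, run name, and hyperparameters are illustrative placeholders, not the settings used in the paper.

```python
import weak_to_strong.logger as logger
from weak_to_strong.common import get_tokenizer
from weak_to_strong.datasets import load_dataset, tokenize_dataset
from weak_to_strong.loss import logconf_loss_fn, xent_loss
from weak_to_strong.train import ModelConfig, train_and_save_model

save_path = "/tmp/w2s_sciq_gpt2"  # illustrative output location
model_name = "gpt2"

# load_dataset adds 'txt', 'hard_label' and 'soft_label' columns; tokenize_dataset adds
# 'input_ids' and drops examples that do not fit in max_ctx tokens.
ds = load_dataset("sciq", seed=0, split_sizes=dict(train=2000, test=500))
tokenizer = get_tokenizer(model_name)
train_ds = tokenize_dataset(ds["train"], tokenizer, max_ctx=512)
test_ds = tokenize_dataset(ds["test"], tokenizer, max_ctx=512)

# train_and_save_model reports metrics through weak_to_strong.logger, so configure it first
# (it only talks to wandb if the WANDB_PROJECT environment variable is set).
logger.configure(name="sciq_gpt2_example", save_path=save_path)

config = ModelConfig(name=model_name, default_lr=5e-5, eval_batch_size=32)
test_results, _ = train_and_save_model(
    config,
    train_ds,
    test_ds,
    batch_size=32,
    lr=config.default_lr,
    epochs=1,
    eval_every=100,  # also enables the final evaluation whose results go into results.pkl
    save_path=save_path,  # results.pkl and the fine-tuned model are written here
    loss_fn=xent_loss(),  # swap in logconf_loss_fn() for the confidence auxiliary loss
)
logger.shutdown()
```

A weak-to-strong run is the same call with the strong model's config, except that the `soft_label` column of `train_ds` holds the weak supervisor's predictions instead of the ground-truth labels.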
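
Both pipelines supervise the student with the weak model's predicted probabilities rather than hard ground-truth classes: `train_logreg` in `vision/run_weak_strong.py` and `xent_loss` in `weak_to_strong/loss.py` each feed soft targets into a cross-entropy loss. The toy snippet below (synthetic data, CPU-only, PyTorch >= 1.10 so that `cross_entropy` accepts probability targets) is only meant to make that supervision signal concrete; it is not part of the codebase.

```python
import torch

torch.manual_seed(0)
n, d, n_classes = 64, 16, 3
x = torch.randn(n, d)                                      # stand-in for strong-model embeddings
weak_probs = torch.softmax(torch.randn(n, n_classes), -1)  # stand-in for weak-model predictions

student = torch.nn.Linear(d, n_classes)                    # the linear probe being trained
opt = torch.optim.Adam(student.parameters(), lr=1e-2)
for _ in range(200):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(student(x), weak_probs)  # soft targets
    loss.backward()
    opt.step()

# "supervisor-student agreement" in run_weak_strong.py is this kind of argmax match
agreement = (student(x).argmax(-1) == weak_probs.argmax(-1)).float().mean()
print(f"final loss {loss.item():.3f}, agreement with weak labels {agreement.item():.3f}")
```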