├── .gitignore
├── LICENSE.md
├── README.md
├── notebooks
│   ├── Plotting.ipynb
│   ├── Plotting_old.ipynb
│   ├── amazon_polarity.png
│   ├── anthropic_hh.png
│   ├── boolq.png
│   ├── cosmos_qa.png
│   └── sciq.png
├── pyproject.toml
├── setup.py
├── sweep.py
├── train_simple.py
├── train_weak_to_strong.py
├── vision
│   ├── README.md
│   ├── data.py
│   ├── models.py
│   └── run_weak_strong.py
├── weak-to-strong-setup.png
└── weak_to_strong
    ├── __init__.py
    ├── common.py
    ├── datasets.py
    ├── eval.py
    ├── logger.py
    ├── loss.py
    ├── model.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dump
2 | *.pyc
3 | *.swp
4 | *.swo
5 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright 2023 OpenAI
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below.
2 |
3 | # Weak-to-strong generalization
4 |
5 | 
6 |
7 | This project contains code accompanying our [paper on weak-to-strong generalization](https://cdn.openai.com/papers/weak-to-strong-generalization.pdf).
8 |
9 | The primary codebase is a re-implementation of our weak-to-strong learning setup for binary classification tasks. It includes code for fine-tuning pretrained language models and for training them against labels produced by another language model. We also support several of the losses described in the paper, such as the confidence auxiliary loss.
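
For intuition, the confidence auxiliary loss trades off fitting the weak labels against fitting the strong student's own (hardened) predictions. The snippet below is a minimal sketch only: the actual `logconf_loss_fn` in `weak_to_strong/loss.py` differs in details such as how predictions are hardened and how the auxiliary weight is scheduled, and the names `weak_probs` and `alpha` are illustrative rather than the repo's API.

```python
import torch
import torch.nn.functional as F


def confidence_aux_loss(logits: torch.Tensor, weak_probs: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    """Cross-entropy against a mixture of weak labels and the student's own hardened predictions (sketch)."""
    n_classes = logits.shape[-1]
    # Harden the student's current predictions; detach so the target acts as a fixed label.
    hardened = F.one_hot(logits.argmax(dim=-1), n_classes).float().detach()
    target = alpha * hardened + (1.0 - alpha) * weak_probs
    return -(target * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
```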
10 |
11 | The `vision` directory contains stand-alone code for weak-to-strong in the vision models setting (AlexNet -> DINO on ImageNet).
12 |
13 | ### Getting Started
14 |
15 | These instructions will get you a copy of the project up and running on your local machine for development and testing purposes.
16 |
17 | #### Installation
18 |
19 | You need to have Python installed on your machine. The project uses `pyproject.toml` to manage dependencies. To install the dependencies, you can use a package manager like `pip`:
20 |
21 | ```
22 | pip install .
23 | ```
24 |
25 | #### Running the Script
26 |
27 | The main script of the project is `sweep.py`. It can be run from the command line using the following command:
28 | ```
29 | python sweep.py --model_sizes=gpt2,gpt2-medium
30 | ```
31 |
32 | In addition to `--model_sizes`, `sweep.py` takes in almost all of the arguments that `train_simple.py` takes (e.g.
33 | `--batch_size`, `--n_docs`, `--n_test_docs` etc., see `train_simple.py` for a full list). These arguments are simply
34 | forwarded to `train_simple.py`.
35 |
36 | `sweep.py` calls `train_simple.py` in the following way:
37 | 1. First, it calls `train_simple.py` for each model size to train the ground truth models
38 | 2. Then, for each pair of weak and strong models in `model_sizes` (where a model can be the strong model in the pair
39 | only if its index in the `model_sizes` list is >= the index of the weak model), it calls `train_simple.py` with a
40 | `--weak_model_size` argument so that the strong model is trained with the labels of the weak model.
41 |
42 | E.g. the example above will run gpt2 (ground truth), gpt2-medium (ground truth), gpt2 -> gpt2, gpt2 -> gpt2-medium, and
43 | gpt2-medium -> gpt2-medium.
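
In Python terms, the pairing logic amounts to the loop below (an illustrative snippet mirroring `sweep.py`, not additional functionality):

```python
model_sizes = ["gpt2", "gpt2-medium"]

# Enumerate the (weak, strong) transfer pairs: the strong model must appear at
# the same index as the weak model or later in the list. (sweep.py also runs
# one ground-truth training per model size before these transfer runs.)
for i, weak in enumerate(model_sizes):
    for strong in model_sizes[i:]:
        print(f"{weak} -> {strong}")
# gpt2 -> gpt2, gpt2 -> gpt2-medium, gpt2-medium -> gpt2-medium
```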
44 |
45 | If needed, you can also run `train_simple.py` directly.
46 |
47 | Note that `sweep.py` will not accept the arguments `--weak_model_size`, `--weak_labels_path` or `--model_size` (as opposed
48 | to `--model_sizes`, with an "s") as choosing their values automatically is precisely the point of `sweep.py`.
49 |
50 | An example Jupyter notebook for plotting results can be found in `notebooks/Plotting.ipynb`.
51 |
52 | At the time of release, the main script was called `train_weak_to_strong.py`, but it was less usable than
53 | `sweep.py` and `train_simple.py`. It is preserved here and the old instructions are given at the end of the document.
54 |
55 | #### Expected results
56 |
57 |
58 | See the plots in the `notebooks/` directory (e.g. `notebooks/sciq.png` and `notebooks/anthropic_hh.png`), generated with `notebooks/Plotting.ipynb`, for the results you should expect.
59 |
67 | ### Authors
68 |
69 | - Adrien Ecoffet
70 | - Manas Joglekar
71 | - Jeffrey Wu
72 | - Jan Hendrik Kirchner
73 | - Pavel Izmailov (vision)
74 |
75 | ### License
76 |
77 | This project is licensed under the MIT License - see the LICENSE.md file for details.
78 |
79 | ### Acknowledgments
80 |
81 | - Hugging Face for their open-source transformer models
82 |
83 | ### Original single run script
84 |
85 | You can run the original training script using:
86 | ```
87 | python train_weak_to_strong.py
88 | ```
89 |
90 | The script accepts several command-line arguments to customize the training process. Here are some examples:
91 |
92 | ```
93 | python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
94 | ```
95 |
96 | The notebook `notebooks/Plotting_old.ipynb` preserves the plotting code corresponding to this old training style.
97 |
98 | The key difference between this style and the new `sweep.py` style is that `train_weak_to_strong.py` will always
99 | train three models: a weak model, a transfer model, and a strong model. `sweep.py` optimizes this by training
100 | a series of ground truth models (which will serve as weak and strong models) as well as a series of transfer models
101 | all in one go. This reduces training duplication and is arguably simpler. The files generated by `train_simple.py`
102 | and `sweep.py` are also simpler to use.
--------------------------------------------------------------------------------
/notebooks/Plotting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "eb9a4b5a",
6 | "metadata": {},
7 | "source": [
8 | "# Simple Plotting\n"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "88c7ff9f",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "RESULTS_PATH = \"../../your_sweep_path/default\"\n",
19 | "\n",
20 | "PLOT_ALL_SEEDS = False\n",
21 | "# Full sweep\n",
22 | "MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n",
23 | "# Minimal sweep\n",
24 | "# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "00ca073c",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import numpy as np\n",
35 | "import pandas as pd\n",
36 | "import matplotlib.pyplot as plt\n",
37 | "import seaborn as sns\n",
38 | "sns.set_style('whitegrid')\n",
39 | "\n",
40 | "from IPython.display import display\n",
41 | "\n",
42 | "import os\n",
43 | "import glob\n",
44 | "import json"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "e5caa051",
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "records = []\n",
55 | "for result_filename in glob.glob(os.path.join(RESULTS_PATH, \"**/results_summary.json\"), recursive=True):\n",
56 | " config_file = os.path.join(\"/\".join(result_filename.split(\"/\")[:-1]), \"config.json\")\n",
57 | " config = json.load(open(config_file, \"r\"))\n",
58 | " if config[\"model_size\"] not in MODELS_TO_PLOT:\n",
59 | " continue\n",
60 | " if 'seed' not in config:\n",
61 | " config['seed'] = 0\n",
62 | " record = config.copy()\n",
63 | " if 'weak_model' in config:\n",
64 | " for k in record['weak_model']:\n",
65 | " if k == 'model_size':\n",
66 | " assert record['weak_model'][k] == record['weak_model_size']\n",
67 | " record['weak_' + k] = record['weak_model'][k]\n",
68 | " del record['weak_model']\n",
69 | " record.update(json.load(open(result_filename)))\n",
70 | " records.append(record)\n",
71 | "\n",
72 | "df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'model_size'])"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "2f628577",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "datasets = df.ds_name.unique()\n",
83 | "for dataset in datasets:\n",
84 | " cur_df = df[(df.ds_name == dataset)].copy()\n",
85 | " base_accuracies = cur_df[cur_df['weak_model_size'].isna()].groupby('model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
86 | " base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
87 | " base_accuracies = base_accuracies.reset_index()\n",
88 | "\n",
89 | " cur_df['strong_model_accuracy'] = cur_df['model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
90 | " cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_accuracy'] = cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
91 | "\n",
92 | " # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
93 | " valid_pgr_index = (\n",
94 | " (~cur_df['weak_model_size'].isna()) & \n",
95 | " (cur_df['weak_model_size'] != cur_df['model_size']) & \n",
96 | " (cur_df['strong_model_accuracy'] > cur_df['weak_model_accuracy'])\n",
97 | " )\n",
98 | " cur_df.loc[valid_pgr_index, 'pgr'] = (cur_df.loc[valid_pgr_index, 'accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy']) / (cur_df.loc[valid_pgr_index, 'strong_model_accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy'])\n",
99 | "\n",
100 | " cur_df.loc[cur_df['weak_model_size'].isna(), \"weak_model_size\"] = \"ground truth\"\n",
101 | "\n",
102 | " for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
103 | " plot_df = cur_df.copy().sort_values(['strong_model_accuracy']).sort_values(['loss'], ascending=False)\n",
104 | " if seed is not None:\n",
105 | " plot_df = plot_df[plot_df['seed'] == seed]\n",
106 | "\n",
107 | " print(f\"Dataset: {dataset} (seed: {seed})\")\n",
108 | "\n",
109 | " pgr_results = plot_df[~plot_df['pgr'].isna()].groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
110 | "\n",
111 | " palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
112 | " color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
113 | "\n",
114 | " sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
115 | " pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
116 | " plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['model_size']], rotation=90)\n",
117 | " plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
118 | " plt.legend(loc='upper left')\n",
119 | " suffix = \"\"\n",
120 | " if seed is not None:\n",
121 | " suffix = f\"_{seed}\"\n",
122 | " plt.savefig(f\"{dataset.replace('/', '-')}{suffix}.png\", dpi=300, bbox_inches='tight')\n",
123 | " plt.show()"
124 | ]
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "openai",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.11.5"
144 | }
145 | },
146 | "nbformat": 4,
147 | "nbformat_minor": 5
148 | }
149 |
--------------------------------------------------------------------------------
/notebooks/Plotting_old.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "eb9a4b5a",
6 | "metadata": {},
7 | "source": [
8 | "# Simple Plotting\n"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "88c7ff9f",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "RESULTS_PATH = \"../../your_sweep_results_path\"\n",
19 | "\n",
20 | "PLOT_ALL_SEEDS = False\n",
21 | "# Full sweep\n",
22 | "MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n",
23 | "# Minimal sweep\n",
24 | "# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "00ca073c",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import numpy as np\n",
35 | "import pandas as pd\n",
36 | "import matplotlib.pyplot as plt\n",
37 | "import seaborn as sns\n",
38 | "sns.set_style('whitegrid')\n",
39 | "\n",
40 | "from IPython.display import display\n",
41 | "\n",
42 | "import os\n",
43 | "import glob\n",
44 | "import json"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "e5caa051",
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "records = []\n",
55 | "all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n",
56 | "for result_folder in set(all_results_folders):\n",
57 | " config_file = os.path.join(result_folder, \"config.json\")\n",
58 | " config = json.load(open(config_file, \"r\"))\n",
59 | " if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n",
60 | " continue\n",
61 | " if 'seed' not in config:\n",
62 | " config['seed'] = 0\n",
63 | " result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n",
64 | " record = config.copy()\n",
65 | " record.update(json.load(open(config_file.replace('config.json', result_filename))))\n",
66 | " records.append(record)\n",
67 | "\n",
68 | "df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "id": "2f628577",
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "datasets = df.ds_name.unique()\n",
79 | "for dataset in datasets:\n",
80 | " cur_df = df[(df.ds_name == dataset)]\n",
81 | " base_df = pd.concat([\n",
82 | " pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n",
83 | " pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), \"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n",
84 | " ])\n",
85 | " base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
86 | " base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
87 | " base_accuracies = base_accuracies.reset_index()\n",
88 | " base_df.reset_index(inplace=True)\n",
89 | " base_df['weak_model_size'] = 'ground truth'\n",
90 | " base_df['loss'] = 'xent'\n",
91 | " base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
92 | "\n",
93 | " weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n",
94 | " weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n",
95 | " weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n",
96 | " weak_to_strong.reset_index(inplace=True)\n",
97 | " weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n",
98 | " weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
99 | "\n",
100 | " # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
101 | " pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n",
102 | " pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n",
103 | " pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n",
104 | " pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n",
105 | " pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n",
106 | "\n",
107 | " for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
108 | " plot_df = pd.concat([base_df, weak_to_strong])\n",
109 | " seed_pgr_df = pgr_df\n",
110 | " if seed is not None:\n",
111 | " plot_df = plot_df[plot_df['seed'] == seed]\n",
112 | " # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n",
113 | " # have slight differences. We want to average these out when filtering by seed.\n",
114 | "\n",
115 | " seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n",
116 | "\n",
117 | " if seed is not None or cur_df['seed'].nunique() == 1:\n",
118 | " plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n",
119 | "\n",
120 | " print(f\"Dataset: {dataset} (seed: {seed})\")\n",
121 | "\n",
122 | " pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
123 | " display(pgr_results)\n",
124 | "\n",
125 | " palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
126 | " color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
127 | "\n",
128 | " sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
129 | " pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
130 | " plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n",
131 | " plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
132 | " plt.legend(loc='upper left')\n",
133 | " plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n",
134 | " plt.show()"
135 | ]
136 | }
137 | ],
138 | "metadata": {
139 | "kernelspec": {
140 | "display_name": "openai",
141 | "language": "python",
142 | "name": "python3"
143 | },
144 | "language_info": {
145 | "codemirror_mode": {
146 | "name": "ipython",
147 | "version": 3
148 | },
149 | "file_extension": ".py",
150 | "mimetype": "text/x-python",
151 | "name": "python",
152 | "nbconvert_exporter": "python",
153 | "pygments_lexer": "ipython3",
154 | "version": "3.11.5"
155 | }
156 | },
157 | "nbformat": 4,
158 | "nbformat_minor": 5
159 | }
160 |
--------------------------------------------------------------------------------
/notebooks/amazon_polarity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/amazon_polarity.png
--------------------------------------------------------------------------------
/notebooks/anthropic_hh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/anthropic_hh.png
--------------------------------------------------------------------------------
/notebooks/boolq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/boolq.png
--------------------------------------------------------------------------------
/notebooks/cosmos_qa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/cosmos_qa.png
--------------------------------------------------------------------------------
/notebooks/sciq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/notebooks/sciq.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "weak_to_strong"
7 | version = "0.0.1"
8 | authors = [
9 | { name="OpenAI", email="generalization@openai.com" },
10 | ]
11 | description = "Weak-to-strong generalization"
12 | readme = "README.md"
13 | requires-python = ">=3.7"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 | dependencies=[
20 | "torch ~= 2.1",
21 | "numpy ~= 1.24",
22 | "transformers ~= 4.36",
23 | "datasets ~= 2.14",
24 | "fire ~= 0.4",
25 | "accelerate ~= 0.25",
26 | "transformers-stream-generator ~= 0.0.4",
27 | "torch_optimizer ~= 0.3",
28 | "wandb ~= 0.16.1"
29 | ]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="weak_to_strong",
5 | version="0.1",
6 | description="Weak-to-strong generalization",
7 | url="#",
8 | author="OpenAI",
9 | author_email="generalization@openai.com",
10 | packages=setuptools.find_packages(),
11 | zip_safe=False,
12 | )
13 |
--------------------------------------------------------------------------------
/sweep.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | from typing import List, Union
5 |
6 | import fire
7 |
8 |
9 | def main(model_sizes: Union[List[str], str], **kwargs):
10 | if isinstance(model_sizes, str):
11 | model_sizes = model_sizes.split(",")
12 | assert (
13 | "weak_model_size" not in kwargs
14 | and "model_size" not in kwargs
15 | and "weak_labels_path" not in kwargs
16 | ), "Need to use model_sizes when using sweep.py"
17 | basic_args = [sys.executable, os.path.join(os.path.dirname(__file__), "train_simple.py")]
18 | for key, value in kwargs.items():
19 | basic_args.extend([f"--{key}", str(value)])
20 |
21 | print("Running ground truth models")
22 | for model_size in model_sizes:
23 | subprocess.run(basic_args + ["--model_size", model_size], check=True)
24 |
25 | print("Running transfer models")
26 | for i in range(len(model_sizes)):
27 | for j in range(i, len(model_sizes)):
28 | weak_model_size = model_sizes[i]
29 | strong_model_size = model_sizes[j]
30 | print(f"Running weak {weak_model_size} to strong {strong_model_size}")
31 | subprocess.run(
32 | basic_args
33 | + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
34 | check=True,
35 | )
36 |
37 |
38 | if __name__ == "__main__":
39 | fire.Fire(main)
40 |
--------------------------------------------------------------------------------
/train_simple.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import subprocess
5 | from typing import Dict, List, Optional
6 |
7 | import fire
8 | import numpy as np
9 | import torch
10 | from datasets import load_from_disk
11 |
12 | import weak_to_strong.logger as logger
13 | from weak_to_strong.common import get_tokenizer
14 | from weak_to_strong.datasets import (VALID_DATASETS, load_dataset,
15 | tokenize_dataset)
16 | from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
17 | from weak_to_strong.train import ModelConfig, train_and_save_model
18 |
19 | # NOTE learning rates are not particularly tuned, work somewhat reasonably at train batch size 32
20 | MODEL_CONFIGS = [
21 | ModelConfig(
22 | name="gpt2",
23 | default_lr=5e-5,
24 | eval_batch_size=32,
25 | ),
26 | ModelConfig(
27 | name="gpt2-medium",
28 | default_lr=5e-5,
29 | eval_batch_size=32,
30 | ),
31 | ModelConfig(
32 | name="gpt2-large",
33 | default_lr=1e-5,
34 | eval_batch_size=32,
35 | ),
36 | ModelConfig(
37 | name="gpt2-xl",
38 | default_lr=1e-5,
39 | eval_batch_size=2,
40 | gradient_checkpointing=True,
41 | # Should use model_parallel on V100s (note: ironically if you have a single V100 it should run,
42 | # but if you have multiple it won't run without model_parallel because of the overhead of data
43 | # parallel training).
44 | model_parallel=(
45 | torch.cuda.get_device_properties(0).total_memory < 35e9
46 | and torch.cuda.device_count() > 1
47 | ),
48 | ),
49 | ModelConfig(
50 | name="Qwen/Qwen-1_8B",
51 | default_lr=1e-5,
52 | eval_batch_size=2,
53 | gradient_checkpointing=True,
54 | model_parallel=(
55 | torch.cuda.get_device_properties(0).total_memory < 35e9
56 | and torch.cuda.device_count() > 1
57 | ),
58 | custom_kwargs={
59 | "trust_remote_code": True,
60 | "bf16": torch.cuda.is_bf16_supported(),
61 | "fp32": not torch.cuda.is_bf16_supported(),
62 | "revision": "5fde88dff770a7d036847211f5d9d9705f0caa69",
63 | },
64 | ),
65 | ModelConfig(
66 | name="Qwen/Qwen-7B",
67 | default_lr=1e-5,
68 | eval_batch_size=2,
69 | gradient_checkpointing=True,
70 | model_parallel=True,
71 | # note: you will probably not be able to run this without many gpus
72 | custom_kwargs={
73 | "trust_remote_code": True,
74 | "bf16": torch.cuda.is_bf16_supported(),
75 | "fp32": not torch.cuda.is_bf16_supported(),
76 | "revision": "d4efd21e866b9cb3466cb65b963933f5e98016d1",
77 | },
78 | ),
79 | ModelConfig(
80 | name="Qwen/Qwen-14B",
81 | default_lr=1e-5,
82 | eval_batch_size=2,
83 | gradient_checkpointing=True,
84 | model_parallel=True,
85 |         # note: you will probably not be able to run this without bf16 support and many gpus
86 | custom_kwargs={
87 | "trust_remote_code": True,
88 | "bf16": torch.cuda.is_bf16_supported(),
89 | "fp32": not torch.cuda.is_bf16_supported(),
90 | "revision": "8be2854218fea9054331e217fd26a06f3fd02004",
91 | },
92 | ),
93 | ModelConfig(
94 | name="Qwen/Qwen-72B",
95 | default_lr=1e-5,
96 | eval_batch_size=1,
97 | gradient_checkpointing=True,
98 | model_parallel=True,
99 | # note: you will probably not be able to run this without bf16 support and many gpus
100 | custom_kwargs={
101 | "trust_remote_code": True,
102 | "bf16": torch.cuda.is_bf16_supported(),
103 | "fp32": not torch.cuda.is_bf16_supported(),
104 | "revision": "fec78c0e3b3b10dd9f0ce775c34a686a3255a7d1",
105 | },
106 | # This model is really big, save space by using adafactor.
107 | # Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
108 | default_optimizer="adafactor",
109 | ),
110 | ]
111 | MODELS_DICT: Dict[str, ModelConfig] = {
112 | model_config.name: model_config for model_config in MODEL_CONFIGS
113 | }
114 |
115 |
116 | loss_dict = {
117 | "logconf": logconf_loss_fn(),
118 | "product": product_loss_fn(),
119 | "xent": xent_loss(),
120 | }
121 |
122 | VALID_LOSSES: List[str] = list(loss_dict.keys())
123 |
124 |
125 | def get_config_foldername(config: dict) -> str:
126 | def shorten_key(key: str) -> str:
127 | return "".join(word[0] for word in key.split("_"))
128 |
129 | def shorten_value(value) -> str:
130 | if isinstance(value, bool):
131 | return "1" if value else "0"
132 | elif isinstance(value, str):
133 | value = value.split("/")[-1]
134 | if "_" in value:
135 | return "_".join(word[:4] for word in value.split("_"))
136 | else:
137 | return value
138 | else:
139 | return str(value)
140 |
141 | return "-".join(f"{shorten_key(k)}={shorten_value(v)}" for k, v in sorted(config.items()))
142 |
143 |
144 | def main(
145 | batch_size: int = 32,
146 | max_ctx: int = 1024,
147 | ds_name: str = "sciq",
148 | loss: str = "xent",
149 | n_docs: int = 20000,
150 | n_test_docs: int = 10000,
151 | model_size: str = "gpt2",
152 | lr: Optional[float] = None,
153 | optim: Optional[str] = None,
154 | epochs: int = 2,
155 | force_retrain: bool = False,
156 | seed: int = 0,
157 | minibatch_size_per_device: Optional[float] = None,
158 | train_with_dropout: bool = False,
159 | results_folder: str = "/tmp/results",
160 | linear_probe: bool = False,
161 | lr_schedule: str = "cosine_anneal",
162 | # Note: you can pass either weak_model_size or weak_labels_path. If you pass
163 | # weak_model_size, we will guess the path to the weak labels based on the weak
164 | # model. If you pass weak_labels_path, we will use that path instead.
165 | # If you pass neither, we will train on ground truth.
166 | weak_model_size: Optional[str] = None,
167 | weak_labels_path: Optional[str] = None,
168 | sweep_subfolder: str = "default",
169 | # Set to a very large value so that by default we don't do any intermediate evals but
170 | # still do final evals (which requires eval_every to be set to a non-zero, non-None value)
171 | eval_every: int = 1000000,
172 | sync_command: Optional[str] = None,
173 | ):
174 | # this is per device!
175 | if minibatch_size_per_device is None:
176 | minibatch_size_per_device = 1
177 | assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}"
178 | assert (
179 | weak_model_size is None or weak_labels_path is None
180 | ), "Can't pass both weak_model_size and weak_labels_path"
181 | model_config = MODELS_DICT[model_size]
182 |
183 | use_default_lr = False
184 | if lr is None:
185 | assert (
186 | batch_size == 32
187 | ), "Learning rates were tuned on batch size 32, you probably want to sweep LR if you are tuning batch size"
188 | lr = model_config.default_lr
189 | use_default_lr = True
190 |
191 | if optim is None:
192 | optim = model_config.default_optimizer
193 |
194 | # The commented out terms are the ones that should not change final results
195 | config = {
196 | "batch_size": batch_size,
197 | "max_ctx": max_ctx,
198 | "ds_name": ds_name,
199 | "loss": loss,
200 | "n_docs": n_docs,
201 | "n_test_docs": n_test_docs,
202 | "model_size": model_size,
203 | "lr": lr,
204 | "optim": optim,
205 | "epochs": epochs,
206 | # "force_retrain": force_retrain,
207 | "seed": seed,
208 | # "minibatch_size_per_device": minibatch_size_per_device,
209 | "train_with_dropout": train_with_dropout,
210 | # "results_folder": results_folder,
211 | "linear_probe": linear_probe,
212 | "lr_schedule": lr_schedule,
213 | "eval_every": eval_every,
214 | # "sweep_subfolder": sweep_subfolder,
215 | }
216 |
217 | if weak_model_size is not None:
218 | weak_model_config = config.copy()
219 | weak_model_config["model_size"] = weak_model_size
220 | weak_model_config["loss"] = "xent"
221 | if use_default_lr:
222 | weak_model_config["lr"] = MODELS_DICT[weak_model_size].default_lr
223 |
224 | weak_model_config_name = get_config_foldername(weak_model_config)
225 |
226 | weak_labels_path = (
227 | results_folder + "/" + sweep_subfolder + "/" + weak_model_config_name + "/weak_labels"
228 | )
229 |
230 | eval_batch_size = model_config.eval_batch_size
231 | random.seed(seed)
232 |
233 | # Load dataset
234 | dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs))
235 |
236 | # Split the training dataset in half
237 | train_dataset, test_ds = dataset["train"], dataset["test"]
238 |
239 | if weak_labels_path is None:
240 | split_data = train_dataset.train_test_split(test_size=0.5, seed=seed)
241 | train1_ds, train2_ds = split_data["train"], split_data["test"]
242 | print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds))
243 | config_name = get_config_foldername(config)
244 | else:
245 | if not weak_labels_path.endswith("weak_labels"):
246 | weak_labels_path = weak_labels_path + "/weak_labels"
247 | if sync_command is not None:
248 | sync_command_list = sync_command.split(" ")
249 | sync_command_list.extend(
250 | ["download", weak_labels_path.replace("/weak_labels", ""), results_folder]
251 | )
252 | print(f"Running sync command: {' '.join(sync_command_list)}")
253 | result = subprocess.run(sync_command_list, check=True)
254 | if result.returncode != 0:
255 | raise RuntimeError(f"Sync command failed with return code {result.returncode}")
256 | train1_ds = load_from_disk(weak_labels_path)
257 | train2_ds = None
258 |
259 | weak_model_config = json.load(open(weak_labels_path.replace("weak_labels", "config.json")))
260 | config["weak_model_size"] = weak_model_config["model_size"]
261 | config_name = get_config_foldername(config)
262 | config["weak_model"] = weak_model_config
263 |
264 | save_path = os.path.join(results_folder, sweep_subfolder, config_name)
265 | logger.configure(
266 | name="{sweep_subfolder}_{config_name}_{datetime_now}",
267 | save_path=save_path,
268 | sweep_subfolder=sweep_subfolder,
269 | config_name=config_name,
270 | )
271 | # Tokenize datasets
272 | tokenizer = get_tokenizer(model_config.name)
273 | train1_ds = tokenize_dataset(train1_ds, tokenizer, max_ctx)
274 | test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx)
275 | if train2_ds:
276 | train2_ds = tokenize_dataset(train2_ds, tokenizer, max_ctx)
277 |
278 | loss_fn = loss_dict[loss]
279 |     print(f"Training model, size {model_size}")
280 | test_results, weak_ds = train_and_save_model(
281 | model_config,
282 | train1_ds,
283 | test_ds,
284 | inference_ds=train2_ds,
285 | batch_size=batch_size,
286 | save_path=save_path,
287 | loss_fn=loss_fn,
288 | lr=lr,
289 | epochs=epochs,
290 | force_retrain=force_retrain,
291 | eval_batch_size=eval_batch_size,
292 | minibatch_size_per_device=minibatch_size_per_device,
293 | train_with_dropout=train_with_dropout,
294 | linear_probe=linear_probe,
295 | lr_schedule=lr_schedule,
296 | optimizer_name=optim,
297 | eval_every=eval_every,
298 | )
299 |
300 | if weak_ds is not None:
301 | weak_ds.save_to_disk(save_path + "/" + "weak_labels")
302 |
303 | acc = np.mean([x["acc"] for x in test_results])
304 | res_dict = {"accuracy": acc}
305 | print("accuracy:", acc)
306 |
307 |     with open(os.path.join(save_path, "config.json"), "w") as f:
308 | json.dump(config, f, indent=2)
309 |
310 |     with open(os.path.join(save_path, "results_summary.json"), "w") as f:
311 | json.dump(res_dict, f, indent=2)
312 |
313 | if sync_command is not None:
314 | print("Syncing results to remote storage...")
315 | try:
316 | sync_command_list = sync_command.split(" ")
317 | sync_command_list.extend(["upload", save_path, results_folder])
318 | print(f"Running sync command: {' '.join(sync_command_list)}")
319 | result = subprocess.run(sync_command_list, check=True)
320 | if result.returncode != 0:
321 | raise RuntimeError(f"Sync command failed with return code {result.returncode}")
322 | except Exception as e:
323 | raise RuntimeError("Failed to sync results to remote storage.") from e
324 |
325 |
326 | if __name__ == "__main__":
327 | fire.Fire(main)
328 |
--------------------------------------------------------------------------------
/train_weak_to_strong.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Dict, List, Optional, Sequence, Union
4 |
5 | import fire
6 | import numpy as np
7 | import torch
8 |
9 | import weak_to_strong.logger as logger
10 | from weak_to_strong.common import get_tokenizer
11 | from weak_to_strong.datasets import (VALID_DATASETS, load_dataset,
12 | tokenize_dataset)
13 | from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
14 | from weak_to_strong.train import ModelConfig, train_and_save_model
15 |
16 | # NOTE learning rates are not particularly tuned, work somewhat reasonably at train batch size 32
17 | MODEL_CONFIGS = [
18 | ModelConfig(
19 | name="gpt2",
20 | default_lr=5e-5,
21 | eval_batch_size=32,
22 | custom_kwargs={
23 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
24 | },
25 | ),
26 | ModelConfig(
27 | name="gpt2-medium",
28 | default_lr=5e-5,
29 | eval_batch_size=32,
30 | custom_kwargs={
31 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
32 | },
33 | ),
34 | ModelConfig(
35 | name="gpt2-large",
36 | default_lr=1e-5,
37 | eval_batch_size=32,
38 | custom_kwargs={
39 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
40 | },
41 | ),
42 | ModelConfig(
43 | name="gpt2-xl",
44 | default_lr=1e-5,
45 | eval_batch_size=2,
46 | gradient_checkpointing=True,
47 | model_parallel=True,
48 | custom_kwargs={
49 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
50 | },
51 | ),
52 | ModelConfig(
53 | name="Qwen/Qwen-1_8B",
54 | default_lr=1e-5,
55 | eval_batch_size=2,
56 | gradient_checkpointing=True,
57 | model_parallel=True,
58 | custom_kwargs={
59 | "trust_remote_code": True,
60 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
61 | },
62 | ),
63 | ModelConfig(
64 | name="Qwen/Qwen-7B",
65 | default_lr=1e-5,
66 | eval_batch_size=2,
67 | gradient_checkpointing=True,
68 | model_parallel=True,
69 | # note: you will probably not be able to run this without many gpus
70 | custom_kwargs={
71 | "trust_remote_code": True,
72 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
73 | },
74 | ),
75 | ModelConfig(
76 | name="Qwen/Qwen-14B",
77 | default_lr=1e-5,
78 | eval_batch_size=2,
79 | gradient_checkpointing=True,
80 | model_parallel=True,
81 | # note: you will probably not be able to run this without bf16 support and many gpus
82 | custom_kwargs={
83 | "trust_remote_code": True,
84 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
85 | },
86 | ),
87 | ModelConfig(
88 | name="Qwen/Qwen-72B",
89 | default_lr=1e-5,
90 | eval_batch_size=1,
91 | gradient_checkpointing=True,
92 | model_parallel=True,
93 | # note: you will probably not be able to run this without bf16 support and many gpus
94 | custom_kwargs={
95 | "trust_remote_code": True,
96 | "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
97 | },
98 | # This model is really big, save space by using adafactor.
99 | # Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
100 | default_optimizer="adafactor",
101 | ),
102 | ]
103 | MODELS_DICT: Dict[str, ModelConfig] = {
104 | model_config.name: model_config for model_config in MODEL_CONFIGS
105 | }
106 |
107 |
108 | loss_dict = {
109 | "logconf": logconf_loss_fn(),
110 | "product": product_loss_fn(),
111 | "xent": xent_loss(),
112 | }
113 |
114 | VALID_LOSSES: List[str] = list(loss_dict.keys())
115 |
116 |
117 | def main(
118 | batch_size: int = 32,
119 | max_ctx: int = 1024,
120 | ds_name: str = "sciq",
121 | transfer_loss: Union[str, Sequence[str]] = "xent,logconf",
122 | n_docs: int = 10000,
123 | n_test_docs: int = 200,
124 | weak_model_size: str = "gpt2",
125 | weak_lr: Optional[float] = None,
126 | strong_model_size: str = "gpt2-xl",
127 | strong_lr: Optional[float] = None,
128 | # Defaults to strong_lr
129 | transfer_lr: Optional[float] = None,
130 | # Optims default to default_optimizer in the model definitions
131 | weak_optim: Optional[str] = None,
132 | strong_optim: Optional[str] = None,
133 | transfer_optim: Optional[str] = None,
134 | gt_epochs: int = 2,
135 | # defaults to gt_epochs
136 | transfer_epochs: Optional[int] = None,
137 | force_retrain: bool = False,
138 | seed: int = 0,
139 | minibatch_size_per_device: Optional[int] = None,
140 | train_with_dropout: bool = False,
141 | results_folder: str = "/tmp/results",
142 | linear_probe: bool = False,
143 | lr_schedule: str = "cosine_anneal",
144 | log_prefix: str = "",
145 | # Set to an absurdly high value so we don't do intermediate evals by default.
146 | eval_every: int = 100000000,
147 | ):
148 | # this is per device!
149 | if minibatch_size_per_device is None:
150 | minibatch_size_per_device = 1
151 | assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}"
152 | if isinstance(transfer_loss, str):
153 | transfer_losses = transfer_loss.split(",")
154 | else:
155 | transfer_losses = transfer_loss
156 | del transfer_loss
157 | for tloss in transfer_losses:
158 | assert tloss in VALID_LOSSES, f"Unknown loss {tloss} not in {VALID_LOSSES}"
159 | assert (
160 | weak_model_size in MODELS_DICT
161 | ), f"Unknown model size {weak_model_size} not in {MODELS_DICT}"
162 | weak_model_config = MODELS_DICT[weak_model_size]
163 | assert (
164 | strong_model_size in MODELS_DICT
165 | ), f"Unknown model size {strong_model_size} not in {MODELS_DICT}"
166 | strong_model_config = MODELS_DICT[strong_model_size]
167 |
168 | if weak_lr is None:
169 | assert batch_size == 32
170 | weak_lr = weak_model_config.default_lr
171 | if strong_lr is None:
172 | assert batch_size == 32
173 | strong_lr = strong_model_config.default_lr
174 | if transfer_lr is None:
175 | transfer_lr = strong_lr
176 | if transfer_epochs is None:
177 | transfer_epochs = gt_epochs
178 |
179 | if weak_optim is None:
180 | weak_optim = weak_model_config.default_optimizer
181 | if strong_optim is None:
182 | strong_optim = strong_model_config.default_optimizer
183 | if transfer_optim is None:
184 | transfer_optim = strong_optim
185 |
186 | weak_eval_batch_size = weak_model_config.eval_batch_size
187 | strong_eval_batch_size = strong_model_config.eval_batch_size
188 |
189 | # Load dataset
190 | dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs))
191 |
192 | # Split the training dataset in half
193 | train_dataset, test_ds = dataset["train"], dataset["test"]
194 |
195 | split_data = train_dataset.train_test_split(test_size=0.5, seed=seed)
196 | train1_ds, train2_ds = split_data["train"], split_data["test"]
197 | print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds))
198 |
199 | def train_model(
200 | model_config: ModelConfig,
201 | train_ds: torch.utils.data.Dataset,
202 | test_ds: torch.utils.data.Dataset,
203 | *,
204 | loss_type: str,
205 | label: str,
206 | subpath,
207 | lr,
208 | eval_batch_size,
209 | epochs=1,
210 | inference_ds: Optional[torch.utils.data.Dataset] = None,
211 | linear_probe: bool = False,
212 | optimizer_name: str = "adam",
213 | ):
214 | save_path = os.path.join(results_folder, subpath)
215 | linprobe_str = "_linprobe" if linear_probe else ""
216 | logger.configure(
217 | name="{log_prefix}{label}_{base_model_name}_{ds_name}_{loss_type}_{optimizer_name}_{lr}_{lr_schedule}{linprobe_str}_{datetime_now}",
218 | label=label,
219 | ds_name=ds_name,
220 | truncation_max_len=n_docs or "none",
221 | loss_type=loss_type,
222 | lr=lr,
223 | batch_size=batch_size,
224 | eval_batch_size=eval_batch_size,
225 | minibatch_size_per_device=minibatch_size_per_device,
226 | save_path=save_path,
227 | base_model_name=model_config.name,
228 | epochs=epochs,
229 | linprobe_str=linprobe_str,
230 | lr_schedule=lr_schedule,
231 | log_prefix=log_prefix,
232 | optimizer_name=optimizer_name,
233 | )
234 | # Tokenize datasets
235 | tokenizer = get_tokenizer(model_config.name)
236 | train_ds = tokenize_dataset(train_ds, tokenizer, max_ctx)
237 | test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx)
238 | if inference_ds:
239 | inference_ds = tokenize_dataset(inference_ds, tokenizer, max_ctx)
240 |
241 | loss_fn = loss_dict[loss_type]
242 | return train_and_save_model(
243 | model_config,
244 | train_ds,
245 | test_ds,
246 | inference_ds=inference_ds,
247 | batch_size=batch_size,
248 | save_path=save_path,
249 | loss_fn=loss_fn,
250 | lr=lr,
251 | epochs=epochs,
252 | force_retrain=force_retrain,
253 | eval_batch_size=eval_batch_size,
254 | minibatch_size_per_device=minibatch_size_per_device,
255 | train_with_dropout=train_with_dropout,
256 | linear_probe=linear_probe,
257 | lr_schedule=lr_schedule,
258 | optimizer_name=optimizer_name,
259 | eval_every=eval_every,
260 | )
261 |
262 | # Train the weak model on the first half of the training data
263 | print(f"Training weak model, size {weak_model_size}")
264 | weak_test_results, weak_ds = train_model(
265 | weak_model_config,
266 | train1_ds,
267 | test_ds,
268 | loss_type="xent",
269 | label="weak",
270 | subpath=os.path.join("weak_model_gt", weak_model_size.replace("/", "_")),
271 | lr=weak_lr,
272 | eval_batch_size=weak_eval_batch_size,
273 | inference_ds=train2_ds,
274 | epochs=gt_epochs,
275 | linear_probe=linear_probe,
276 | optimizer_name=weak_optim,
277 | )
278 |
279 | # Train the strong model on the second half of the training data
280 | print(f"Training strong model, size {strong_model_size}")
281 | strong_test_results, _ = train_model(
282 | strong_model_config,
283 | train2_ds,
284 | test_ds,
285 | loss_type="xent",
286 | label="strong",
287 | subpath=os.path.join("strong_model_gt", strong_model_size.replace("/", "_")),
288 | lr=strong_lr,
289 | eval_batch_size=strong_eval_batch_size,
290 | epochs=gt_epochs,
291 | linear_probe=linear_probe,
292 | optimizer_name=strong_optim,
293 | )
294 |
295 | # Train the strong model on the second half of the training data with labels generated by the weak model
296 | all_transfer_test_results = {}
297 | for tloss in transfer_losses:
298 | print(
299 | f"Training transfer model, size {strong_model_size} on labels from {weak_model_size}, with loss {tloss}"
300 | )
301 | transfer_test_results, _ = train_model(
302 | strong_model_config,
303 | weak_ds,
304 | test_ds,
305 | loss_type=tloss,
306 | label="weak2strong",
307 | subpath=os.path.join(
308 | "strong_model_transfer",
309 | f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}_{tloss}",
310 | ),
311 | lr=transfer_lr,
312 | eval_batch_size=strong_eval_batch_size,
313 | epochs=transfer_epochs,
314 | linear_probe=linear_probe,
315 | optimizer_name=transfer_optim,
316 | )
317 | all_transfer_test_results[tloss] = transfer_test_results
318 | del transfer_test_results
319 |
320 | weak_acc = np.mean([x["acc"] for x in weak_test_results])
321 | strong_acc = np.mean([x["acc"] for x in strong_test_results])
322 | res_dict = {
323 | "weak_acc": weak_acc,
324 | "strong_acc": strong_acc,
325 | }
326 | print("weak acc:", weak_acc)
327 | print("strong acc:", strong_acc)
328 | for tloss, transfer_test_results in all_transfer_test_results.items():
329 | transfer_acc = np.mean([x["acc"] for x in transfer_test_results])
330 | res_dict[f"transfer_acc_{tloss}"] = transfer_acc
331 | print(f"transfer acc ({tloss}):", transfer_acc)
332 |
333 | with open(
334 | os.path.join(
335 | results_folder,
336 | f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}.results_summary.json",
337 | ),
338 | "w",
339 | ) as f:
340 | json.dump(
341 | res_dict,
342 | f,
343 | )
344 |
345 |
346 | # python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --transfer_loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
347 | if __name__ == "__main__":
348 | fire.Fire(main)
349 |
--------------------------------------------------------------------------------
/vision/README.md:
--------------------------------------------------------------------------------
1 | # A Simple Weak-to-Strong Experiment on ImageNet
2 |
3 | We provide code for a simple weak-to-strong experiment on ImageNet.
4 | We generate the weak labels using an [AlexNet](https://pytorch.org/vision/main/models/generated/torchvision.models.alexnet.html) model pretrained on ImageNet and we use linear probes on top of [DINO](https://github.com/facebookresearch/dino) models
5 | as the strong students.
6 |
7 | The full training command:
8 |
9 | ```bash
10 | python3 run_weak_strong.py \
11 |     --data_path DATA_PATH \
12 |     --weak_model_name WEAK_MODEL \
13 |     --strong_model_name STRONG_MODEL \
14 |     --batch_size BATCH_SIZE \
15 |     --seed SEED \
16 |     --n_epochs EPOCHS \
17 |     --lr LR \
18 |     --n_train N_TRAIN
19 | ```
20 | Parameters:
21 |
22 | * ```DATA_PATH``` — path to the base directory containing ImageNet data, see [torchvision page](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageNet.html) for instructions; should contain files `ILSVRC2012_devkit_t12.tar.gz` and `ILSVRC2012_img_val.tar`
23 | * ```WEAK_MODEL``` — weak model name:
24 |   - `"alexnet"` (default; currently the only implemented option)
25 | * ```STRONG_MODEL``` — strong model name:
26 | - `"resnet50_dino"` (default)
27 | - `"vitb8_dino"`
28 | * ```BATCH_SIZE``` — batch size for weak label generation and embedding extraction (default: `128`)
29 | * ```SEED``` — random seed for dataset shuffling (default: `0`)
30 | * ```EPOCHS``` — number of training epochs (default: `10`)
31 | * ```LR``` — initial learning rate (default: `1e-3`)
32 | * ```N_TRAIN``` — number of datapoints used to train the linear probe; `50000 - N_TRAIN` datapoints are used as test (default: `40000`)
33 |
34 |
35 |
36 | Example commands:
37 |
38 | ```bash
39 | # AlexNet → ResNet50 (DINO):
40 | python3 run_weak_strong.py --strong_model_name resnet50_dino --n_epochs 20
41 |
42 | # AlexNet → ViT-B/8 (DINO):
43 | python3 run_weak_strong.py --strong_model_name vitb8_dino --n_epochs 5
44 | ```
45 |
46 | With the commands above we get the following results (note that the results may not reproduce exactly due to randomness):
47 |
48 | | Model | Top-1 Accuracy |
49 | |-------------------------|----------------|
50 | | AlexNet | 56.6 |
51 | | Dino ResNet50 | 64.5 |
52 | | Dino ViT-B/8 | 74.0 |
53 | | AlexNet → DINO ResNet50 | 61.9 |
54 | | AlexNet → DINO ViT-B/8 | 66.6 |
55 |
56 | You can add new custom models to `models.py` and new datasets to `data.py`.
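
For example, another published DINO backbone could be added to `models.py` by following the same pattern as `resnet50_dino` and `vitb8_dino`. The sketch below assumes the `dino_vits16` checkpoint from the `facebookresearch/dino` torch.hub repo; you would also need to add a matching branch to `get_model` in `run_weak_strong.py`.

```python
import torch


def vits16_dino():
    # Load a pretrained DINO ViT-S/16 backbone from torch.hub and use it,
    # like the other strong models here, as a frozen embedding extractor.
    model = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
    return model
```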
57 |
--------------------------------------------------------------------------------
/vision/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | RESIZE, CROP = 256, 224
5 | TRANSFORM = torchvision.transforms.Compose(
6 | [
7 | torchvision.transforms.Resize(RESIZE),
8 | torchvision.transforms.CenterCrop(CROP),
9 | torchvision.transforms.ToTensor(),
10 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
11 | ]
12 | )
13 |
14 |
15 | def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM):
16 | ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform)
17 | loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size)
18 | return ds, loader
19 |
--------------------------------------------------------------------------------
/vision/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 |
5 | class HeadAndEmbedding(torch.nn.Module):
6 | def __init__(self, head):
7 | super(HeadAndEmbedding, self).__init__()
8 | self.head = head
9 |
10 | def forward(self, x):
11 |         return x, self.head(x)  # return the pre-head features (embedding) together with the head output (logits)
12 |
13 |
14 | def _alexnet_replace_fc(model):
15 | model.classifier = HeadAndEmbedding(model.classifier)
16 | return model
17 |
18 |
19 | def resnet50_dino():
20 | model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50")
21 | return model
22 |
23 |
24 | def vitb8_dino():
25 | model = torch.hub.load("facebookresearch/dino:main", "dino_vitb8")
26 | return model
27 |
28 |
29 | def alexnet():
30 | model = torchvision.models.alexnet(pretrained=True)
31 | return _alexnet_replace_fc(model)
32 |
--------------------------------------------------------------------------------
/vision/run_weak_strong.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import numpy as np
3 | import torch
4 | import tqdm
5 | from data import get_imagenet
6 | from models import alexnet, resnet50_dino, vitb8_dino
7 | from torch import nn
8 |
9 |
10 | def get_model(name):
11 | if name == "alexnet":
12 | model = alexnet()
13 | elif name == "resnet50_dino":
14 | model = resnet50_dino()
15 | elif name == "vitb8_dino":
16 | model = vitb8_dino()
17 | else:
18 | raise ValueError(f"Unknown model {name}")
19 | model.cuda()
20 | model.eval()
21 | model = nn.DataParallel(model)
22 | return model
23 |
24 |
25 | def get_embeddings(model, loader):
26 | all_embeddings, all_y, all_probs = [], [], []
27 |
28 | for x, y in tqdm.tqdm(loader):
29 | output = model(x.cuda())
30 |         if len(output) == 2:  # wrapped weak model returns (embeddings, logits); plain backbones return only embeddings
31 | embeddings, logits = output
32 | probs = torch.nn.functional.softmax(logits, dim=-1).detach().cpu()
33 | all_probs.append(probs)
34 | else:
35 | embeddings = output
36 |
37 | all_embeddings.append(embeddings.detach().cpu())
38 | all_y.append(y)
39 |
40 | all_embeddings = torch.cat(all_embeddings, axis=0)
41 | all_y = torch.cat(all_y, axis=0)
42 | if len(all_probs) > 0:
43 | all_probs = torch.cat(all_probs, axis=0)
44 | acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean()
45 | else:
46 | all_probs = None
47 | acc = None
48 | return all_embeddings, all_y, all_probs, acc
49 |
50 |
51 | def train_logreg(
52 | x_train,
53 | y_train,
54 | eval_datasets,
55 | n_epochs=10,
56 | weight_decay=0.0,
57 | lr=1.0e-3,
58 | batch_size=100,
59 | n_classes=1000,
60 | ):
61 | x_train = x_train.float()
62 | train_ds = torch.utils.data.TensorDataset(x_train, y_train)
63 | train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size)
64 |
65 | d = x_train.shape[1]
66 | model = torch.nn.Linear(d, n_classes).cuda()
67 | criterion = torch.nn.CrossEntropyLoss()
68 | optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr)
69 | n_batches = len(train_loader)
70 | n_iter = n_batches * n_epochs
71 | schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter)
72 |
73 | results = {f"{key}_all": [] for key in eval_datasets.keys()}
74 | for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")):
75 | correct, total = 0, 0
76 | for x, y in train_loader:
77 | x, y = x.cuda(), y.cuda()
78 | optimizer.zero_grad()
79 | pred = model(x)
80 | loss = criterion(pred, y)
81 | loss.backward()
82 | optimizer.step()
83 | schedule.step()
84 | if len(y.shape) > 1:
85 | y = torch.argmax(y, dim=1)
86 | correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
87 | total += len(y)
88 | pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}")
89 |
90 | for key, (x_test, y_test) in eval_datasets.items():
91 | x_test = x_test.float().cuda()
92 | pred = torch.argmax(model(x_test), axis=-1).detach().cpu()
93 | acc = (pred == y_test).float().mean()
94 | results[f"{key}_all"].append(acc)
95 |
96 | for key in eval_datasets.keys():
97 | results[key] = results[f"{key}_all"][-1]
98 | return results
99 |
100 |
101 | def main(
102 | batch_size: int = 128,
103 | weak_model_name: str = "alexnet",
104 | strong_model_name: str = "resnet50_dino",
105 | n_train: int = 40000,
106 | seed: int = 0,
107 | data_path: str = "/root/",
108 | n_epochs: int = 10,
109 | lr: float = 1e-3,
110 | ):
111 | weak_model = get_model(weak_model_name)
112 | strong_model = get_model(strong_model_name)
113 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False)
114 | print("Getting weak labels...")
115 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader)
116 | print(f"Weak label accuracy: {weak_acc:.3f}")
117 | print("Getting strong embeddings...")
118 | embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader)
119 | assert torch.all(gt_labels == strong_gt_labels)
120 | del strong_gt_labels
121 |
122 | order = np.arange(len(embeddings))
123 | rng = np.random.default_rng(seed)
124 | rng.shuffle(order)
125 | x = embeddings[order]
126 | y = gt_labels[order]
127 | yw = weak_labels[order]
128 | x_train, x_test = x[:n_train], x[n_train:]
129 | y_train, y_test = y[:n_train], y[n_train:]
130 | yw_train, yw_test = yw[:n_train], yw[n_train:]
131 | yw_test = torch.argmax(yw_test, dim=1)
132 | eval_datasets = {"test": (x_test, y_test), "test_weak": (x_test, yw_test)}
133 |
134 | print("Training logreg on weak labels...")
135 | results_weak = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr)
136 | print(f"Final accuracy: {results_weak['test']:.3f}")
137 | print(f"Final supervisor-student agreement: {results_weak['test_weak']:.3f}")
138 | print(f"Accuracy by epoch: {[acc.item() for acc in results_weak['test_all']]}")
139 | print(
140 | f"Supervisor-student agreement by epoch: {[acc.item() for acc in results_weak['test_weak_all']]}"
141 | )
142 |
143 | print("Training logreg on ground truth labels...")
144 | results_gt = train_logreg(x_train, y_train, eval_datasets, n_epochs=n_epochs, lr=lr)
145 | print(f"Final accuracy: {results_gt['test']:.3f}")
146 | print(f"Accuracy by epoch: {[acc.item() for acc in results_gt['test_all']]}")
147 |
148 | print("\n\n" + "=" * 100)
149 | print(f"Weak label accuracy: {weak_acc:.3f}")
150 | print(f"Weak→Strong accuracy: {results_weak['test']:.3f}")
151 | print(f"Strong accuracy: {results_gt['test']:.3f}")
152 | print("=" * 100)
153 |
154 |
155 | if __name__ == "__main__":
156 | fire.Fire(main)
157 |
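158 | # Illustrative usage sketch (not part of the original script). It assumes at least one GPU and
159 | # an ImageNet validation set under `data_path` in the layout expected by data.py; the path below
160 | # is a placeholder. Run from the vision/ directory so the local imports resolve.
161 | """
162 | from run_weak_strong import main
163 | 
164 | # Fit linear probes on frozen resnet50_dino embeddings twice: once against AlexNet's predicted
165 | # (weak) labels and once against ground truth, then print the weak, weak-to-strong, and strong accuracies.
166 | main(
167 |     batch_size=128,
168 |     weak_model_name="alexnet",
169 |     strong_model_name="resnet50_dino",
170 |     n_train=40000,
171 |     data_path="/path/to/imagenet/",
172 |     n_epochs=10,
173 |     lr=1e-3,
174 | )
175 | """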
--------------------------------------------------------------------------------
/weak-to-strong-setup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/weak-to-strong-setup.png
--------------------------------------------------------------------------------
/weak_to_strong/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/weak-to-strong/6b450f2cee3714d6f886e5e1910bac73633bf69a/weak_to_strong/__init__.py
--------------------------------------------------------------------------------
/weak_to_strong/common.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import torch
4 | from transformers import AutoTokenizer
5 |
6 |
7 | def get_tokenizer(model_name: str):
8 | """
9 | This function returns a tokenizer based on the model name.
10 |
11 | Parameters:
12 | model_name: The name of the model for which the tokenizer is needed.
13 |
14 | Returns:
15 | A tokenizer for the specified model.
16 | """
17 | return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
18 |
19 |
20 | def clear_mem(verbose: bool = False):
21 | """
22 |     This function is used to clear memory held by PyTorch.
23 |     It does so by running the Python garbage collector and then emptying the CUDA cache.
24 |     After clearing, it prints the amount of GPU memory still allocated by PyTorch (post-clean).
25 |
26 | Parameters:
27 | verbose (bool): Whether to print additional information.
28 | """
29 |
30 | gc.collect()
31 | torch.cuda.empty_cache()
32 | print(
33 | f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f}GB"
34 | )
35 |
36 | if verbose:
37 |
38 | def try_attr(x, a):
39 | try:
40 | return getattr(x, a)
41 |             except Exception:
42 |                 # getattr on arbitrary gc-tracked objects can raise a surprising range of errors, e.g.
43 |                 # (AttributeError, OSError, AssertionError, RuntimeError, ModuleNotFoundError)
44 | return None
45 |
46 | for obj in gc.get_objects():
47 | if torch.is_tensor(obj) or torch.is_tensor(try_attr(obj, "data")):
48 | print(type(obj), obj.size(), obj.dtype)
49 |
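50 | # Illustrative usage sketch (not part of the original module); "gpt2" is a stand-in for whichever
51 | # HuggingFace model is being fine-tuned, and clear_mem assumes a CUDA device is available.
52 | """
53 | from weak_to_strong.common import get_tokenizer, clear_mem
54 | 
55 | tokenizer = get_tokenizer("gpt2")
56 | print(tokenizer("hello world")["input_ids"])
57 | 
58 | # After a training run, drop references to the model, then release cached GPU memory.
59 | clear_mem()
60 | """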
--------------------------------------------------------------------------------
/weak_to_strong/datasets.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from dataclasses import dataclass
3 | from random import Random
4 | from typing import Any, Callable, Optional
5 |
6 | from datasets import Dataset as HfDataset
7 | from datasets import load_dataset as hf_load_dataset
8 |
9 |
10 | @dataclass
11 | class DatasetConfig:
12 | # split -> unshuffled dataset of items
13 | loader: Callable[[str], HfDataset]
14 | # formats items to have keys 'txt' and 'hard_label', takes a random.Random rng
15 | formatter: Callable[[Any], Any]
16 |
17 |
18 | # mapping from dataset name to load function and format function
19 | _REGISTRY: dict[str, DatasetConfig] = {}
20 |
21 |
22 | def register_dataset(name: str, config: DatasetConfig):
23 | _REGISTRY[name] = config
24 |
25 |
26 | def load_dataset(ds_name: str, seed: int = 0, split_sizes: Optional[dict] = None):
27 | if split_sizes is None:
28 | split_sizes = dict(train=None, test=None)
29 |
30 | if ds_name not in _REGISTRY:
31 | raise ValueError(f"Unknown dataset {ds_name}, please register")
32 | cfg = _REGISTRY[ds_name]
33 | results = {}
34 | for split, n_docs in split_sizes.items():
35 | ds = cfg.loader(split)
36 | try:
37 |             ds = ds.select(range(n_docs)) if n_docs is not None else ds
38 |         except IndexError as e:
39 |             print(f"Warning: {ds_name} has fewer than {n_docs} docs, using all: {e}")
40 | ds = ds.map(functools.partial(cfg.formatter, rng=Random(seed)))
41 | ds = ds.map(
42 | lambda ex: {"soft_label": [1 - float(ex["hard_label"]), float(ex["hard_label"])]}
43 | )
44 |         ds = ds.shuffle(seed=seed)  # shuffling is a bit pointless for the test set, but harmless
45 | results[split] = ds
46 | return results
47 |
48 |
49 | def tokenize_dataset(
50 | raw_ds: HfDataset,
51 | tokenizer: Callable,
52 | max_ctx: int,
53 | ):
54 | """
55 |     This function prepares the dataset for training. It takes the raw dataset, a tokenizer,
56 |     and a maximum context length.
57 |
58 | Parameters:
59 | raw_ds: The raw dataset to be processed.
60 | tokenizer: The tokenizer to be used on the formatted dataset.
61 | max_ctx: The maximum context length for the tokenizer.
62 |
63 | Returns:
64 |     ds: The tokenized dataset, filtered to examples shorter than max_ctx tokens.
65 | """
66 |
67 | def process_function(res):
68 | toks = tokenizer(res["txt"])
69 | return dict(
70 | input_ids=toks["input_ids"],
71 | )
72 |
73 | ds = raw_ds.map(process_function, batched=False).filter(lambda x: len(x["input_ids"]) < max_ctx)
74 | return ds
75 |
76 |
77 | def hf_loader(*hf_name, split_names=None):
78 | if split_names is None:
79 | split_names = dict()
80 | return lambda split: hf_load_dataset(*hf_name, split=split_names.get(split, split))
81 |
82 |
83 | ##########
84 | # ACTUAL DATASETS
85 | ##########
86 |
87 |
88 | def format_amazon_polarity(ex, rng):
89 | return dict(txt=f"{ex['title']} {ex['content']}", hard_label=ex["label"])
90 |
91 |
92 | register_dataset(
93 | "amazon_polarity",
94 | DatasetConfig(loader=hf_loader("amazon_polarity"), formatter=format_amazon_polarity),
95 | )
96 |
97 |
98 | def format_sciq(ex, rng):
99 | hard_label = int(rng.random() < 0.5)
100 | if hard_label:
101 | ans = ex["correct_answer"]
102 | else:
103 | ans = rng.choice([ex["distractor1"], ex["distractor2"], ex["distractor3"]])
104 | txt = f"Q: {ex['question']} A: {ans}"
105 | return dict(txt=txt, hard_label=hard_label)
106 |
107 |
108 | register_dataset(
109 | "sciq",
110 | DatasetConfig(loader=hf_loader("sciq"), formatter=format_sciq),
111 | )
112 |
113 |
114 | def format_anthropic_hh(ex, rng):
115 | hard_label = int(rng.random() < 0.5)
116 | txt = ex["chosen"] if hard_label else ex["rejected"]
117 | return dict(txt=txt, hard_label=hard_label)
118 |
119 |
120 | register_dataset(
121 | "anthropic_hh",
122 | DatasetConfig(loader=hf_loader("Anthropic/hh-rlhf"), formatter=format_anthropic_hh),
123 | )
124 |
125 |
126 | def format_cosmosqa(ex, rng):
127 | true_answer = ex["answer" + str(ex["label"])]
128 | if "None of the above choices ." in true_answer:
129 | hard_label = 0
130 | else:
131 | assert "None of the above choices" not in true_answer, true_answer
132 | hard_label = int(rng.random() < 0.5)
133 | if hard_label:
134 | answer = true_answer
135 | else:
136 | candidate_answers = [ex["answer" + str(i)] for i in range(4)]
137 | answer = rng.choice([x for x in candidate_answers if x != true_answer])
138 | txt = f"Context: {ex['context']}\nQuestion: {ex['question']}\nAnswer: {answer}"
139 | return dict(txt=txt, hard_label=hard_label)
140 |
141 |
142 | register_dataset(
143 | "cosmos_qa",
144 | DatasetConfig(
145 | loader=hf_loader("cosmos_qa", split_names=dict(test="validation")),
146 | formatter=format_cosmosqa,
147 | ),
148 | )
149 |
150 |
151 | def format_boolq(ex, rng):
152 | hard_label = int(ex["answer"])
153 | txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}"
154 | return dict(txt=txt, hard_label=hard_label)
155 |
156 |
157 | register_dataset(
158 | "boolq",
159 | DatasetConfig(
160 | loader=hf_loader("boolq", split_names=dict(test="validation")), formatter=format_boolq
161 | ),
162 | )
163 |
164 |
165 | VALID_DATASETS: list[str] = list(_REGISTRY.keys())
166 |
167 | """
168 | from datasets import disable_caching
169 | disable_caching()
170 |
171 | from weak_to_strong.datasets import load_dataset, VALID_DATASETS
172 | import numpy as np
173 |
174 | ds_name = "boolq"
175 | print(VALID_DATASETS)
176 |
177 | ds = load_dataset(ds_name, split_sizes=dict(train=500, test=10))
178 | train = list(ds['train'])
179 | test = list(ds['test'])
180 | print(test[0])
181 | print(np.mean([x['hard_label'] for x in train]))
182 | """
183 |
--------------------------------------------------------------------------------
/weak_to_strong/eval.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 |
6 |
7 | def to_batch(x, batch_size):
8 | for i in range(0, len(x), batch_size):
9 | yield x[i : i + batch_size]
10 |
11 |
12 | def unpack(x):
13 | assert isinstance(x, torch.Tensor), type(x)
14 | return x.detach().float().cpu().numpy().tolist()
15 |
16 |
17 | def eval_model_acc(model: nn.Module, ds: datasets.Dataset, eval_batch_size: int = 16) -> datasets.Dataset:
18 | """
19 | This function evaluates the accuracy of a given model on a given dataset.
20 |
21 | Parameters:
22 | model (nn.Module): The model to be evaluated.
23 | ds (datasets.Dataset): The dataset on which the model is to be evaluated.
24 |     eval_batch_size (int): The batch size used for the forward passes.
25 |
26 |     Returns:
27 |     results (datasets.Dataset): A dataset with txt, input_ids, gt_label, hard_label, acc, logits and soft_label per example.
28 | """
29 |
30 | model.eval()
31 |
32 | with torch.no_grad():
33 | results = []
34 |         # iterate over the dataset in batches of eval_batch_size
35 | for batch in to_batch(ds, eval_batch_size):
36 | # pad input_ids to common length
37 | input_ids = torch.nn.utils.rnn.pad_sequence(
38 | [torch.tensor(ex) for ex in batch["input_ids"]], batch_first=True
39 | ).to(model.device if hasattr(model, "device") else "cpu")
40 | labels = batch["soft_label"]
41 | # run forward pass
42 | raw_logits = model(input_ids)
43 |
44 | probs = unpack(torch.nn.functional.softmax(raw_logits, dim=-1))
45 | logits = unpack(raw_logits)
46 |
47 | preds = np.argmax(probs, axis=-1)
48 | labels = np.argmax(labels, axis=-1)
49 |
50 | results.extend(
51 | [
52 | dict(
53 | txt=txt,
54 | input_ids=input_id,
55 | gt_label=label,
56 | hard_label=pred,
57 | acc=label == pred,
58 | logits=logit,
59 | soft_label=prob,
60 | )
61 | for input_id, txt, label, pred, prob, logit in zip(
62 | batch["input_ids"], batch["txt"], labels, preds, probs, logits
63 | )
64 | ]
65 | )
66 | accs = [r["acc"] for r in results]
67 | print("Accuracy:", np.mean(accs), "+/-", np.std(accs) / np.sqrt(len(accs)))
68 |
69 | return datasets.Dataset.from_list(results)
70 |
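71 | # Illustrative usage sketch (not part of the original module). It assumes a GPU plus the sibling
72 | # modules in this package; "gpt2" and the tiny split size are stand-ins to keep it cheap.
73 | """
74 | from weak_to_strong.common import get_tokenizer
75 | from weak_to_strong.datasets import load_dataset, tokenize_dataset
76 | from weak_to_strong.eval import eval_model_acc
77 | from weak_to_strong.model import TransformerWithHead
78 | 
79 | model_name = "gpt2"
80 | model = TransformerWithHead.from_pretrained(model_name, num_labels=2).cuda()
81 | tokenizer = get_tokenizer(model_name)
82 | raw_test = load_dataset("sciq", split_sizes=dict(test=32))["test"]
83 | test_ds = tokenize_dataset(raw_test, tokenizer, max_ctx=512)
84 | # Prints mean accuracy +/- its standard error and returns a per-example Dataset.
85 | results = eval_model_acc(model, test_ds, eval_batch_size=8)
86 | print(results[0]["soft_label"])
87 | """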
--------------------------------------------------------------------------------
/weak_to_strong/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from datetime import datetime
4 |
5 | import wandb
6 |
7 |
8 | def append_to_jsonl(path: str, data: dict):
9 | with open(path, "a") as f:
10 | f.write(json.dumps(data) + "\n")
11 |
12 |
13 | class WandbLogger(object):
14 | CURRENT = None
15 |
16 | log_path = None
17 |
18 | def __init__(
19 | self,
20 | **kwargs,
21 | ):
22 | project = os.environ.get("WANDB_PROJECT")
23 | self.use_wandb = project is not None
24 | if self.use_wandb:
25 | wandb.init(
26 | config=kwargs,
27 | project=project,
28 | name=kwargs["name"].format(
29 | **kwargs, datetime_now=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
30 | )
31 | if "name" in kwargs
32 | else None,
33 | )
34 | if "save_path" in kwargs:
35 | self.log_path = os.path.join(kwargs["save_path"], "log.jsonl")
36 | if not os.path.exists(kwargs["save_path"]):
37 | os.makedirs(kwargs["save_path"])
38 | self._log_dict = {}
39 |
40 | def logkv(self, key, value):
41 | self._log_dict[key] = value
42 |
43 | def logkvs(self, d):
44 | self._log_dict.update(d)
45 |
46 | def dumpkvs(self):
47 | if self.use_wandb:
48 | wandb.log(self._log_dict)
49 | if self.log_path is not None:
50 | append_to_jsonl(self.log_path, self._log_dict)
51 | self._log_dict = {}
52 |
53 | def shutdown(self):
54 | if self.use_wandb:
55 | wandb.finish()
56 |
57 |
58 | def is_configured():
59 | return WandbLogger.CURRENT is not None
60 |
61 |
62 | def get_current():
63 | assert is_configured(), "WandbLogger is not configured"
64 | return WandbLogger.CURRENT
65 |
66 |
67 | def configure(**kwargs):
68 | if is_configured():
69 | WandbLogger.CURRENT.shutdown()
70 | WandbLogger.CURRENT = WandbLogger(**kwargs)
71 | return WandbLogger.CURRENT
72 |
73 |
74 | def logkv(key, value):
75 | assert is_configured(), "WandbLogger is not configured"
76 | WandbLogger.CURRENT.logkv(key, value)
77 |
78 |
79 | def logkvs(d):
80 | assert is_configured(), "WandbLogger is not configured"
81 | WandbLogger.CURRENT.logkvs(d)
82 |
83 |
84 | def dumpkvs():
85 | assert is_configured(), "WandbLogger is not configured"
86 | WandbLogger.CURRENT.dumpkvs()
87 |
88 |
89 | def shutdown():
90 | assert is_configured(), "WandbLogger is not configured"
91 | WandbLogger.CURRENT.shutdown()
92 | WandbLogger.CURRENT = None
93 |
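94 | # Illustrative usage sketch (not part of the original module). If WANDB_PROJECT is set in the
95 | # environment, metrics also go to wandb; either way they are appended to <save_path>/log.jsonl.
96 | # The save path below is a placeholder.
97 | """
98 | import weak_to_strong.logger as logger
99 | 
100 | logger.configure(name="demo_{datetime_now}", save_path="/tmp/w2s_demo")
101 | for step in range(3):
102 |     logger.logkvs({"step": step, "loss": 1.0 / (step + 1)})
103 |     logger.dumpkvs()
104 | logger.shutdown()
105 | """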
--------------------------------------------------------------------------------
/weak_to_strong/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LossFnBase:
5 | def __call__(
6 | self,
7 | logits: torch.Tensor,
8 | labels: torch.Tensor,
9 | **kwargs,
10 | ) -> torch.Tensor:
11 | """
12 | This function calculates the loss between logits and labels.
13 | """
14 | raise NotImplementedError
15 |
16 |
17 | # Custom loss function
18 | class xent_loss(LossFnBase):
19 | def __call__(
20 | self, logits: torch.Tensor, labels: torch.Tensor, step_frac: float
21 | ) -> torch.Tensor:
22 | """
23 | This function calculates the cross entropy loss between logits and labels.
24 |
25 | Parameters:
26 | logits: The predicted values.
27 | labels: The actual values.
28 | step_frac: The fraction of total training steps completed.
29 |
30 | Returns:
31 | The mean of the cross entropy loss.
32 | """
33 | loss = torch.nn.functional.cross_entropy(logits, labels)
34 | return loss.mean()
35 |
36 |
37 | class product_loss_fn(LossFnBase):
38 | """
39 |     This class defines a custom loss function based on the product of the model's predictions and the (weak) labels.
40 |
41 | Attributes:
42 | alpha: A float indicating how much to weigh the weak model.
43 | beta: A float indicating how much to weigh the strong model.
44 | warmup_frac: A float indicating the fraction of total training steps for warmup.
45 | """
46 |
47 | def __init__(
48 | self,
49 | alpha: float = 1.0, # how much to weigh the weak model
50 | beta: float = 1.0, # how much to weigh the strong model
51 | warmup_frac: float = 0.1, # in terms of fraction of total training steps
52 | ):
53 | self.alpha = alpha
54 | self.beta = beta
55 | self.warmup_frac = warmup_frac
56 |
57 | def __call__(
58 | self,
59 | logits: torch.Tensor,
60 | labels: torch.Tensor,
61 | step_frac: float,
62 | ) -> torch.Tensor:
63 | preds = torch.softmax(logits, dim=-1)
64 | target = torch.pow(preds, self.beta) * torch.pow(labels, self.alpha)
65 | target /= target.sum(dim=-1, keepdim=True)
66 | target = target.detach()
67 | loss = torch.nn.functional.cross_entropy(logits, target, reduction="none")
68 | return loss.mean()
69 |
70 |
71 | class logconf_loss_fn(LossFnBase):
72 | """
73 |     This class defines the confidence auxiliary loss: the target mixes the weak labels with the model's own thresholded predictions.
74 |
75 | Attributes:
76 | aux_coef: A float indicating the auxiliary coefficient.
77 | warmup_frac: A float indicating the fraction of total training steps for warmup.
78 | """
79 |
80 | def __init__(
81 | self,
82 | aux_coef: float = 0.5,
83 | warmup_frac: float = 0.1, # in terms of fraction of total training steps
84 | ):
85 | self.aux_coef = aux_coef
86 | self.warmup_frac = warmup_frac
87 |
88 | def __call__(
89 | self,
90 | logits: torch.Tensor,
91 | labels: torch.Tensor,
92 | step_frac: float,
93 | ) -> torch.Tensor:
94 | logits = logits.float()
95 | labels = labels.float()
96 | coef = 1.0 if step_frac > self.warmup_frac else step_frac
97 | coef = coef * self.aux_coef
98 | preds = torch.softmax(logits, dim=-1)
99 | mean_weak = torch.mean(labels, dim=0)
100 | assert mean_weak.shape == (2,)
101 | threshold = torch.quantile(preds[:, 0], mean_weak[1])
102 | strong_preds = torch.cat(
103 | [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
104 | dim=1,
105 | )
106 | target = labels * (1 - coef) + strong_preds.detach() * coef
107 | loss = torch.nn.functional.cross_entropy(logits, target, reduction="none")
108 | return loss.mean()
109 |
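110 | # Illustrative usage sketch (not part of the original module): the shapes mirror how train.py
111 | # calls these losses, i.e. 2-way logits and soft weak labels per example.
112 | """
113 | import torch
114 | from weak_to_strong.loss import xent_loss, logconf_loss_fn
115 | 
116 | logits = torch.randn(8, 2)                                    # strong model logits
117 | weak_soft_labels = torch.softmax(torch.randn(8, 2), dim=-1)   # weak supervisor's soft labels
118 | 
119 | print(xent_loss()(logits, weak_soft_labels, step_frac=0.5))
120 | # After warmup, the confidence auxiliary loss mixes the weak labels with the strong model's
121 | # own thresholded predictions, weighted by aux_coef.
122 | print(logconf_loss_fn(aux_coef=0.5)(logits, weak_soft_labels, step_frac=0.5))
123 | """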
--------------------------------------------------------------------------------
/weak_to_strong/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
5 |
6 |
7 | @dataclass
8 | class HeadOutput:
9 | logits: torch.FloatTensor
10 |
11 |
12 | class TransformerWithHead(PreTrainedModel):
13 | """
14 |     A causal LM wrapped with a linear classification head; the head weights are initialized to zeros.
15 | """
16 |
17 | def __init__(self, name, linear_probe=False, **kwargs):
18 | config = AutoConfig.from_pretrained(name, **kwargs)
19 | super().__init__(config)
20 | self.num_labels = config.num_labels
21 | lm = AutoModelForCausalLM.from_pretrained(name, **kwargs)
22 | self.lm = lm
23 | self.transformer = lm.transformer
24 | hidden_size = getattr(config, "n_embd", getattr(config, "hidden_size", None))
25 | self.score = torch.nn.Linear(hidden_size, self.num_labels, bias=False).to(
26 | lm.lm_head.weight.dtype
27 | )
28 | torch.nn.init.normal_(self.score.weight, std=0.0)
29 | self.linear_probe = linear_probe
30 |
31 | @classmethod
32 | def from_pretrained(cls, name, **kwargs):
33 | return cls(name, **kwargs)
34 |
35 | def gradient_checkpointing_enable(self):
36 | model = self.transformer
37 | (
38 | model if hasattr(model, "save_pretrained") else model.module
39 | ).gradient_checkpointing_enable()
40 |
41 | def forward(self, input_ids: torch.LongTensor):
42 | """
43 | Forward pass of the model with a linear head.
44 |
45 | Parameters:
46 | input_ids (torch.LongTensor): Input tensor containing the token ids.
47 |
48 | Returns:
49 | HeadOutput: Output dataclass containing the logits.
50 | """
51 | input_lens = (input_ids != 0).sum(dim=-1)
52 | transformer_outputs = self.transformer(input_ids)
53 | hidden_states = torch.stack(
54 | [transformer_outputs[0][i, input_lens[i] - 1, :] for i in range(len(input_lens))]
55 | )
56 | self.score.to(hidden_states.device)
57 | if self.linear_probe:
58 | hidden_states = hidden_states.detach()
59 | logits = self.score(hidden_states)
60 | return logits
61 |
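62 | # Illustrative usage sketch (not part of the original module); "gpt2" is a stand-in for any
63 | # causal LM that exposes `.transformer` and `.lm_head`, as this wrapper assumes.
64 | """
65 | import torch
66 | from weak_to_strong.common import get_tokenizer
67 | from weak_to_strong.model import TransformerWithHead
68 | 
69 | tokenizer = get_tokenizer("gpt2")
70 | model = TransformerWithHead.from_pretrained("gpt2", num_labels=2)
71 | input_ids = torch.tensor([tokenizer("Q: Is the sky blue? A: yes")["input_ids"]])
72 | logits = model(input_ids)  # shape (1, 2); the head reads the last non-padding position
73 | print(torch.softmax(logits, dim=-1))
74 | """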
--------------------------------------------------------------------------------
/weak_to_strong/train.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import pickle
4 | import time
5 | from dataclasses import dataclass
6 | from typing import Callable, Optional
7 |
8 | import datasets
9 | import numpy as np
10 | import torch
11 | import torch_optimizer as toptim
12 | from transformers.modeling_utils import load_sharded_checkpoint
13 |
14 | import weak_to_strong.logger as logger
15 | from weak_to_strong.common import clear_mem
16 | from weak_to_strong.eval import eval_model_acc
17 | from weak_to_strong.loss import xent_loss
18 | from weak_to_strong.model import TransformerWithHead
19 |
20 |
21 | @dataclass
22 | class ModelConfig:
23 | name: str
24 | default_lr: float
25 | eval_batch_size: int
26 | custom_kwargs: Optional[dict] = None
27 | gradient_checkpointing: bool = False
28 | model_parallel: bool = False
29 | default_optimizer: str = "adam"
30 |
31 |
32 | def train_model(
33 | model: torch.nn.Module,
34 | ds: datasets.Dataset,
35 | batch_size: int,
36 | lr: float = 1e-5,
37 |     loss_fn: Callable = xent_loss(),
38 | log_every: int = 10,
39 | eval_every: int = 100,
40 | eval_batch_size: int = 256,
41 | minibatch_size: int = 8,
42 | eval_ds: Optional[datasets.Dataset] = None,
43 | gradient_checkpointing: bool = False,
44 | train_with_dropout: bool = False,
45 | epochs: int = 1,
46 | lr_schedule: str = "cosine_anneal",
47 | optimizer_name: str = "adam",
48 | ):
49 | print("LR", lr, "batch_size", batch_size, "minibatch_size", minibatch_size)
50 | assert batch_size % minibatch_size == 0, "batch size must be divisible by minibatch size"
51 | # we purposefully turn off dropout, for determinism
52 | # this seems to help for 1 epoch finetuning anyways
53 | if train_with_dropout:
54 | model.train()
55 | else:
56 | model.eval()
57 | if gradient_checkpointing:
58 | (
59 | model if hasattr(model, "gradient_checkpointing_enable") else model.module
60 | ).gradient_checkpointing_enable()
61 |
62 | nsteps = len(ds) * epochs // batch_size
63 |
64 | def lr_schedule_fn(step):
65 | if lr_schedule == "constant":
66 | return 1
67 | else:
68 | assert False, f"invalid lr schedule, {lr_schedule}, must be constant or cosine_anneal"
69 |
70 | if optimizer_name.lower() == "adam":
71 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
72 | elif optimizer_name.lower() == "adafactor":
73 | optimizer = toptim.Adafactor(model.parameters(), lr=lr)
74 | else:
75 | assert False, f"invalid optimizer {optimizer_name}, must be adam or adafactor"
76 | if lr_schedule == "cosine_anneal":
77 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, nsteps)
78 | else:
79 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule_fn)
80 | step = 0
81 | it = itertools.chain.from_iterable(itertools.repeat(ds, epochs))
82 | losses = []
83 | accuracies = []
84 | eval_acc_dict = {}
85 |
86 | # If the model is wrapped by DataParallel, it doesn't have a device. In this case,
87 | # we use GPU 0 as the output device. This sadly means that this device will store
88 | # a bit more data than other ones, but hopefully should not be too big of a deal.
89 | io_device = model.device if hasattr(model, "device") else 0
90 |
91 | while step < nsteps:
92 | loss_tot = 0
93 | if eval_every and (step + 1) % eval_every == 0:
94 | eval_results = eval_model_acc(model, eval_ds, eval_batch_size)
95 | if gradient_checkpointing:
96 | (
97 | model if hasattr(model, "gradient_checkpointing_enable") else model.module
98 | ).gradient_checkpointing_enable()
99 | if train_with_dropout:
100 | model.train()
101 | eval_accs = np.mean([r["acc"] for r in eval_results])
102 | eval_acc_dict[step] = eval_accs
103 | logger.logkv("eval_accuracy", eval_accs)
104 | all_logits = []
105 | all_labels = []
106 | for i in range(batch_size // minibatch_size):
107 | try:
108 | mbatch = [next(it) for _ in range(minibatch_size)]
109 | except StopIteration:
110 | break
111 | input_ids = (
112 | torch.nn.utils.rnn.pad_sequence([torch.tensor(ex["input_ids"]) for ex in mbatch])
113 | .transpose(
114 | 0,
115 | 1,
116 | )
117 | .to(io_device)
118 | )
119 | labels = torch.tensor([ex["soft_label"] for ex in mbatch]).to(io_device)
120 |
121 | logits = model(input_ids)
122 |
123 | all_logits.extend(logits.to(io_device))
124 | all_labels.extend(labels)
125 | all_logits = torch.stack(all_logits)
126 | all_labels = torch.stack(all_labels)
127 | loss = loss_fn(all_logits, all_labels, step_frac=step / nsteps)
128 | loss_tot += loss.item()
129 | loss.backward()
130 | losses.append(loss_tot)
131 | accuracies.append(
132 | torch.mean(
133 | (torch.argmax(all_logits, dim=1) == torch.argmax(all_labels, dim=1)).to(
134 | torch.float32
135 | )
136 | ).item()
137 | )
138 | logger.logkvs(
139 | {
140 | "step": step,
141 | "progress": step / nsteps,
142 | "loss": loss_tot,
143 | "train_accuracy": accuracies[-1],
144 | "lr": lr_scheduler.get_last_lr()[0],
145 | }
146 | )
147 | optimizer.step()
148 | optimizer.zero_grad()
149 | lr_scheduler.step()
150 | if log_every and step % log_every == 0:
151 | print(
152 |                 f"Step: {step}/{nsteps} mean recent loss: {np.mean(losses)}, mean recent acc: {np.mean(accuracies)} (last {len(losses)} steps)"
153 | )
154 | losses = []
155 | accuracies = []
156 | step += 1
157 | logger.dumpkvs()
158 | final_eval_results = None
159 | if eval_every:
160 | print("Final evaluation:")
161 | final_eval_results = eval_model_acc(model, eval_ds, eval_batch_size)
162 | logger.logkv("eval_accuracy", np.mean([r["acc"] for r in final_eval_results]))
163 | logger.dumpkvs()
164 | return final_eval_results
165 |
166 |
167 | def train_and_save_model(
168 | model_config: ModelConfig,
169 | train_ds: datasets.Dataset,
170 | test_ds: datasets.Dataset,
171 | inference_ds: Optional[datasets.Dataset] = None,
172 | *,
173 | batch_size: int,
174 | lr: float,
175 | epochs: int,
176 | eval_batch_size: Optional[int] = None,
177 | minibatch_size_per_device: Optional[int] = None,
178 | save_path: Optional[str] = None,
179 |     loss_fn: Callable = xent_loss(),
180 | label: str = "default",
181 | force_retrain: bool = False,
182 | train_with_dropout: bool = False,
183 | linear_probe: bool = False,
184 | lr_schedule: str = "constant",
185 | optimizer_name: str = "adam",
186 | eval_every: Optional[int] = None,
187 | ):
188 | if eval_batch_size is None:
189 | eval_batch_size = batch_size
190 |
191 | if minibatch_size_per_device is None:
192 | minibatch_size_per_device = 1
193 |
194 | gradient_checkpointing = model_config.gradient_checkpointing
195 | custom_kwargs = model_config.custom_kwargs or {}
196 |
197 | def maybe_load_model(model):
198 | if os.path.exists(os.path.join(save_path, "results.pkl")) and not force_retrain:
199 | print("loading from", save_path)
200 | checkpoint_path = os.path.join(save_path, "pytorch_model.bin")
201 | if not os.path.exists(checkpoint_path):
202 |                 # Assume this means we have a sharded checkpoint saved under save_path, and load it from there
203 |                 load_sharded_checkpoint(model, save_path)
204 | else:
205 | state_dict = torch.load(os.path.join(save_path, "pytorch_model.bin"))
206 | state_dict = {
207 | k.replace("transformer.module", "transformer"): v
208 | for (k, v) in state_dict.items()
209 | }
210 | custom_kwargs["state_dict"] = state_dict
211 | return True
212 | return False
213 |
214 | already_trained = False
215 | # Load the model
216 | if model_config.model_parallel:
217 | assert torch.cuda.device_count() > 1, f"you might want more gpus for {model_config.name}"
218 | model = TransformerWithHead.from_pretrained(
219 | model_config.name,
220 | num_labels=2,
221 | device_map="auto",
222 | linear_probe=linear_probe,
223 | **custom_kwargs,
224 | )
225 | already_trained = maybe_load_model(model)
226 | # slight misnomer, more like minibatch_size_per_dp_replica
227 | minibatch_size = minibatch_size_per_device
228 | else:
229 | model = TransformerWithHead.from_pretrained(
230 | model_config.name, num_labels=2, linear_probe=linear_probe, **custom_kwargs
231 | ).to("cuda")
232 | already_trained = maybe_load_model(model)
233 | # data parallel: currently not supported with model parallel
234 |
235 | minibatch_size = min(minibatch_size_per_device * torch.cuda.device_count(), batch_size)
236 |
237 | if torch.cuda.device_count() > 1:
238 | model = torch.nn.DataParallel(model, output_device=0)
239 | print(
240 | "Using",
241 | torch.cuda.device_count(),
242 | "GPUs, setting minibatch_size to",
243 | minibatch_size,
244 | )
245 | else:
246 | minibatch_size = minibatch_size_per_device
247 |
248 | if already_trained:
249 | test_results = eval_model_acc(model, test_ds, eval_batch_size)
250 | else:
251 | start = time.time()
252 | test_results = train_model(
253 | model,
254 | train_ds,
255 | batch_size,
256 | lr=lr,
257 | epochs=epochs,
258 | eval_ds=test_ds,
259 | gradient_checkpointing=gradient_checkpointing,
260 | loss_fn=loss_fn,
261 | eval_batch_size=eval_batch_size,
262 | eval_every=eval_every,
263 | minibatch_size=minibatch_size,
264 | train_with_dropout=train_with_dropout,
265 | lr_schedule=lr_schedule,
266 | optimizer_name=optimizer_name,
267 | )
268 | print("Model training took", time.time() - start, "seconds")
269 | if save_path:
270 | # Note: If the model is wrapped by DataParallel, we need to unwrap it before saving
271 | (model if hasattr(model, "save_pretrained") else model.module).save_pretrained(
272 | save_path
273 | )
274 | print("saved", save_path)
275 |
276 | inference_results = None
277 | if inference_ds:
278 | inference_results = eval_model_acc(model, inference_ds, eval_batch_size)
279 | logger.logkv("inference_accuracy", np.mean([r["acc"] for r in inference_results]))
280 |
281 | if save_path:
282 | with open(os.path.join(save_path, "results.pkl"), "wb") as f:
283 | pickle.dump(
284 | {
285 | "avg_acc_test": float(np.mean([r["acc"] for r in test_results])),
286 | "avg_acc_inference": float(
287 | np.mean([r["acc"] for r in inference_results] if inference_results else [])
288 | ),
289 | "test_results": test_results,
290 | "inference_results": inference_results if inference_results else [],
291 | },
292 | f,
293 | )
294 | # try to clean up memory
295 | clear_mem()
296 | logger.shutdown()
297 |
298 | return test_results, inference_results
299 |
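300 | # Illustrative usage sketch (not part of the original module): a tiny end-to-end run with
301 | # ground-truth labels. It assumes at least one GPU; "gpt2", the split sizes, and the save path
302 | # are placeholders chosen to keep the run cheap.
303 | """
304 | import weak_to_strong.logger as logger
305 | from weak_to_strong.common import get_tokenizer
306 | from weak_to_strong.datasets import load_dataset, tokenize_dataset
307 | from weak_to_strong.loss import xent_loss
308 | from weak_to_strong.train import ModelConfig, train_and_save_model
309 | 
310 | model_config = ModelConfig(name="gpt2", default_lr=5e-5, eval_batch_size=32)
311 | tokenizer = get_tokenizer(model_config.name)
312 | ds = load_dataset("sciq", split_sizes=dict(train=256, test=128))
313 | train_ds = tokenize_dataset(ds["train"], tokenizer, max_ctx=512)
314 | test_ds = tokenize_dataset(ds["test"], tokenizer, max_ctx=512)
315 | 
316 | logger.configure(name="gpt2_sciq_demo", save_path="/tmp/w2s_gpt2_sciq")
317 | test_results, _ = train_and_save_model(
318 |     model_config,
319 |     train_ds,
320 |     test_ds,
321 |     batch_size=32,
322 |     lr=model_config.default_lr,
323 |     epochs=1,
324 |     loss_fn=xent_loss(),
325 |     save_path="/tmp/w2s_gpt2_sciq",
326 |     eval_every=100,
327 | )
328 | """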
--------------------------------------------------------------------------------