├── w2s
│   ├── __init__.py
│   ├── gpu_pool.py
│   ├── utils.py
│   ├── roc_auc.py
│   ├── sft_config.py
│   ├── topo.py
│   ├── probe.py
│   ├── loss.py
│   ├── model.py
│   ├── sft_utils.py
│   ├── logistic.py
│   ├── sft.py
│   └── ds_registry.py
├── .pre-commit-config.yaml
├── LICENSE
├── pyproject.toml
├── README.md
├── .gitignore
└── run.py
/w2s/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-added-large-files 10 | - repo: https://github.com/psf/black 11 | rev: 23.3.0 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/charliermarsh/ruff-pre-commit 15 | rev: 'v0.0.262' 16 | hooks: 17 | - id: ruff 18 | args: [--fix, --exit-non-zero-on-fix, --line-length=100] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "w2s" 7 | description = "Experimenting with weak-to-strong generalization in deep learning" 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | keywords = ["ai", "interpretability", "generalization"] 11 | license = {text = "MIT License"} 12 | dependencies = [ 13 | "datasets", 14 | "torch", 15 | "peft", 16 | "scipy", 17 | "simple-parsing", 18 | "fire ~= 0.4", 19 | "pynvml ~= 11.5", 20 | "scikit-learn ~= 1.3.2", 21 | # 4.0 introduced the breaking change of using return_dict=True by default 22 | "transformers>=4.0.0", 23 | "wandb", 24 | ] 25 | version = "0.0.1" 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | "pre-commit", 30 | ] 31 | 32 | [tool.pyright] 33 | include = ["w2s*"] 34 | reportPrivateImportUsage = false 35 | 36 | [tool.pytest.ini_options] 37 | testpaths = ["tests"] 38 | 39 | [tool.setuptools.packages.find] 40 | include = ["w2s*"] 41 | 42 | [tool.ruff] 43 | # Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes 44 | # See https://beta.ruff.rs/docs/rules/ for more possible rules 45 | select = ["E", "F", "I"] 46 | # Same as Black. 47 | line-length = 88 48 | # Avoid automatically removing unused imports in __init__.py files. 49 | # Such imports will be flagged with a dedicated message suggesting 50 | # that the import is either added to the module's __all__ symbol 51 | ignore-init-module-imports = true 52 | -------------------------------------------------------------------------------- /w2s/gpu_pool.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Manager, Process 2 | import subprocess 3 | from datetime import datetime 4 | 5 | # Function that runs the job on a GPU 6 | def run_on_gpu(gpu: int, job: str): 7 | print(f"Starting on GPU {gpu}: {job}") 8 | print("at time:", datetime.now()) 9 | command = f"CUDA_VISIBLE_DEVICES={gpu} {job}" 10 | try: 11 | subprocess.run(command, shell=True, check=True) 12 | except Exception as e: 13 | print(f"[WARN] Error on GPU {gpu}: {job}") 14 | print(e) 15 | else: 16 | print(f"Finished on GPU {gpu}: {job}") 17 | finally: 18 | print("at time:", datetime.now()) 19 | 20 | # Worker function that gets jobs and runs them on a specific GPU 21 | def worker(gpu, jobs, lock): 22 | while True: 23 | with lock: 24 | if not jobs: 25 | print(f"GPU {gpu} has no more jobs.") 26 | return # No more jobs to process 27 | job = jobs.pop(0) 28 | 29 | run_on_gpu(gpu, job) 30 | 31 | def gpu_map(gpus, jobs): 32 | # Create a shared job list and a lock 33 | manager = Manager() 34 | jobs = manager.list(jobs) 35 | lock = manager.Lock() 36 | 37 | # Create and start worker processes, each assigned to a specific GPU 38 | processes = [] 39 | for gpu in gpus: 40 | p = Process(target=worker, args=(gpu, jobs, lock)) 41 | processes.append(p) 42 | p.start() 43 | 44 | # Wait for all worker processes to finish 45 | for p in processes: 46 | p.join() 47 | 48 | print("All jobs finished.") 49 | -------------------------------------------------------------------------------- /w2s/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type, TypeVar, cast 2 | from simple_parsing import Serializable 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | def 
assert_type(typ: Type[T], obj: Any) -> T: 8 | """Assert that an object is of a given type at runtime and return it.""" 9 | if not isinstance(obj, typ): 10 | raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}") 11 | 12 | return cast(typ, obj) 13 | 14 | NICKNAMES = { 15 | "Qwen/Qwen1.5-0.5B": "Qw0.5", 16 | "meta-llama/Meta-Llama-3-8B": "Ll8", 17 | "./results": "rs", 18 | "cosine": "c", 19 | } 20 | 21 | def get_config_foldername(config: dict) -> str: 22 | def shorten_key(key: str) -> str: 23 | return "".join(word[0] for word in key.split("_")) 24 | 25 | def shorten_value(value) -> str: 26 | if isinstance(value, bool): 27 | return "1" if value else "0" 28 | elif isinstance(value, str): 29 | if value in NICKNAMES: 30 | return NICKNAMES[value] 31 | value = value.split("/")[-1] 32 | if "_" in value: 33 | return "_".join(word[:4] for word in value.split("_")) 34 | else: 35 | return value 36 | else: 37 | return str(value) 38 | 39 | config = flatten_dict(config) 40 | return "-".join( 41 | f"{shorten_key(k)}={shorten_value(v)}" for k, v in sorted(config.items()) 42 | ) 43 | 44 | 45 | def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict: 46 | items = [] 47 | for k, v in d.items(): 48 | new_key = parent_key + sep + k if parent_key else k 49 | if isinstance(v, dict): 50 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 51 | elif isinstance(v, Serializable): # can't use LossConfig, etc to avoid circular import 52 | items.extend(flatten_dict(v.to_dict(), new_key, sep=sep).items()) 53 | else: 54 | items.append((new_key, v)) 55 | return dict(items) 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weak-to-Strong Generalization 2 | 3 | Source code for experiments from [this blog post](https://blog.eleuther.ai/weak-to-strong/), based in part on [openai/weak-to-strong](https://github.com/openai/weak-to-strong). 4 | 5 | ## Installation 6 | 7 | `pip install -e .` 8 | 9 | If you run into problems, try installing inside a conda or venv environment. 10 | 11 | ## Running experiments 12 | 13 | Basic invocation: 14 | 15 | `python run.py --dataset sciq --run_name my_run` 16 | 17 | List of datasets: `from w2s.ds_registry import VALID_DATASETS` 18 | 19 | Additional args to reproduce blog post experiments: 20 | 21 | ```sh 22 | --loss xent 23 | --s2s_iters 2 24 | --probe_relabel --probe knn 25 | --probe_relabel --probe logreg 26 | --probe_filter --probe knn 27 | --probe_filter --probe logreg 28 | --probe_filter --probe topo 29 | --loss window --radius midweak 30 | --loss entropy 31 | ``` 32 | 33 | There is `--help` available via `simpleparsing`. For individual loss functions and probes, try e.g. `python run.py --probe topo --help`. 34 | 35 | Defaults are set in `sft_config.py`, `probe.py`, and `loss.py`. 36 | 37 | LoRA is on by default (rank 8). Pass `--disable_lora` to switch it off, although this is somewhat untested. For architectures other than Llama, Mistral, and Qwen, you will need to set `ModelConfig.lora_modules` in the arguments to `w2s.sft.train()`. 38 | 39 | ## Output and shared folders 40 | 41 | Strong student results are stored in `./results/[run_name]/[dataset]/`. (You can set a different `--run_name` per experiment so that they don't overwrite each other.) 42 | 43 | Basic metrics, like test AUC and accuracy, are in `w2s/results.json`. `wandb` is used for detailed logging if available. 
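For example, a minimal sketch for collecting the test metrics of every dataset in a run (assuming the default `./results` layout described above and a hypothetical run name `my_run`):

```python
import json
from pathlib import Path

run_dir = Path("./results/my_run")  # hypothetical run name
for results_file in sorted(run_dir.glob("*/w2s/results.json")):
    dataset = results_file.parent.parent.name
    metrics = json.loads(results_file.read_text())
    # metric keys follow the HF Trainer convention, e.g. "eval_accuracy" and "eval_auroc"
    print(dataset, {k: round(v, 4) for k, v in metrics.items() if isinstance(v, float)})
```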
44 | 45 | Floor and ceiling results, weak supervisor predictions, and activations are stored in a shared folder so that they can be reused across experiments. By default this is `./results/[shared_folder]/[dataset]/`; the default `--shared_folder` is `shared`. You should change this if you change the weak or strong model, or anything else about the weak model training setup. 46 | 47 | ## Troubleshooting 48 | 49 | Llama 3 is gated, see [here](https://huggingface.co/docs/hub/models-gated#access-gated-models-as-a-user) for details. 50 | 51 | Loss and probe parameters are set from the CLI via `simpleparsing` [subgroups](https://github.com/lebrice/SimpleParsing/blob/master/examples/subgroups/README.md). -------------------------------------------------------------------------------- /w2s/roc_auc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor: 6 | """Area under the receiver operating characteristic curve (ROC AUC). 7 | 8 | Unlike scikit-learn's implementation, this function supports batched inputs of 9 | shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples 10 | within each dataset. This is primarily useful for efficiently computing bootstrap 11 | confidence intervals. 12 | 13 | Args: 14 | y_true: Ground truth tensor of shape `(N,)` or `(N, n)`. 15 | y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`. 16 | 17 | Returns: 18 | Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D, 19 | a tensor of shape (N,) containing the ROC AUC for each dataset. 20 | """ 21 | if y_true.shape != y_pred.shape: 22 | raise ValueError( 23 | f"y_true and y_pred should have the same shape; " 24 | f"got {y_true.shape} and {y_pred.shape}" 25 | ) 26 | if y_true.dim() not in (1, 2): 27 | raise ValueError("y_true and y_pred should be 1D or 2D tensors") 28 | if not ((y_true == 1) | (y_true == 0)).all(): 29 | raise ValueError("y_true should contain only 0s and 1s") 30 | 31 | # Sort y_pred in descending order and get indices 32 | indices = y_pred.argsort(descending=True, dim=-1) 33 | 34 | # Reorder y_true based on sorted y_pred indices 35 | y_true_sorted = y_true.gather(-1, indices) 36 | 37 | # Calculate number of positive and negative samples 38 | num_positives = y_true.sum(dim=-1) 39 | num_negatives = y_true.shape[-1] - num_positives 40 | 41 | # Calculate cumulative sum of true positive counts (TPs) 42 | tps = torch.cumsum(y_true_sorted, dim=-1) 43 | 44 | # Calculate cumulative sum of false positive counts (FPs) 45 | fps = torch.cumsum(1 - y_true_sorted, dim=-1) 46 | 47 | # Calculate true positive rate (TPR) and false positive rate (FPR) 48 | tpr = tps / num_positives.view(-1, 1) 49 | fpr = fps / num_negatives.view(-1, 1) 50 | 51 | # Calculate differences between consecutive FPR values (widths of trapezoids) 52 | fpr_diffs = torch.cat( 53 | [fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1 54 | ) 55 | 56 | # Calculate area under the ROC curve for each dataset using trapezoidal rule 57 | return torch.sum(tpr * fpr_diffs, dim=-1).squeeze() 58 | -------------------------------------------------------------------------------- /w2s/sft_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union 3 | 4 | from simple_parsing import Serializable, field, subgroups 5 | from w2s.loss import 
LOSS_CONFIGS, LossConfig 6 | from w2s.probe import PROBE_CONFIGS, ProbeConfig 7 | from w2s.sft_utils import literal 8 | 9 | 10 | 11 | @dataclass 12 | class SFTConfig(Serializable): 13 | # name of the model to train 14 | weak_model_name: str = "Qwen/Qwen1.5-0.5B" 15 | strong_model_name: str = "meta-llama/Meta-Llama-3-8B" 16 | # name of the dataset to use 17 | dataset: str = "boolq" 18 | n_epochs: float = 3 19 | n_train: int = 10_000 20 | n_val: int = 1_000 21 | n_test: int = 5_000 22 | # when "train", it uses the training set to generate predictions 23 | # otherwise it uses n_predict held out examples 24 | n_predict: Union[literal("train"), int] = 0 25 | # examples per minibatch (small to fit in memory for long sequences) 26 | minibatch_size: int = 1 27 | # examples per update (gradient accumulated across minibatches) 28 | batch_size: int = 32 29 | results_folder: str = "./results" 30 | run_name: str = "default" 31 | shared_folder: str = "shared" 32 | disable_lora: bool = False 33 | lr_schedule: str = "cosine" 34 | n_warmup_steps: int = 40 # 2 / (1 - 0.95) = 40 35 | eval_every: int = 25 # steps 36 | save_every: int = 25 # steps 37 | save_total_limit: Optional[int] = 1 38 | weight_decay: float = 0.1 39 | weak_lr: float = 5e-4 40 | strong_lr: float = 8e-5 41 | load_best_model_at_end: bool = True 42 | metric_for_best_model: str = "val_auroc" 43 | 44 | loss: LossConfig = subgroups(LOSS_CONFIGS, default="logconf") 45 | probe: ProbeConfig = subgroups(PROBE_CONFIGS, default="knn") 46 | 47 | probe_layer: Optional[int] = None 48 | probe_relabel: bool = False 49 | probe_filter: bool = False 50 | contamination: float = 0.1 51 | 52 | greater_is_better: bool = field(init=False) 53 | loss_name: str = field(init=False) 54 | probe_name: str = field(init=False) 55 | 56 | s2s_iters: int = 0 57 | save_strong_acts: bool = False 58 | 59 | 60 | def __post_init__(self): 61 | if "loss" in self.metric_for_best_model: 62 | self.greater_is_better = False 63 | elif ( 64 | "auroc" in self.metric_for_best_model 65 | or "accuracy" in self.metric_for_best_model 66 | ): 67 | self.greater_is_better = True 68 | else: 69 | raise ValueError(f"Unknown metric {self.metric_for_best_model}") 70 | 71 | self.loss_name = {LOSS_CONFIGS[k]:k for k in LOSS_CONFIGS}[type(self.loss)] 72 | self.probe_name = {PROBE_CONFIGS[k]:k for k in PROBE_CONFIGS}[type(self.probe)] 73 | 74 | def to_dict(self): 75 | irrelevant_fields = ["results_folder", "run_name", "minibatch_size"] 76 | return {k: v for k, v in vars(self).items() if k not in irrelevant_fields} 77 | -------------------------------------------------------------------------------- /w2s/topo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from datasets import Dataset 4 | from scipy.sparse.csgraph import connected_components 5 | from tqdm.auto import tqdm 6 | from transformers import ( 7 | PretrainedConfig, 8 | PreTrainedModel, 9 | ) 10 | 11 | from .utils import assert_type 12 | 13 | 14 | def lcc_mask(adj: torch.Tensor): 15 | """Mask for membership in the largest connected component""" 16 | num_cmps, cmps = connected_components(adj.cpu(), connection="strong") 17 | cmp_sizes = np.bincount(cmps, minlength=num_cmps) 18 | return torch.from_numpy(cmps == cmp_sizes.argmax()).to(adj.device) 19 | 20 | 21 | def topo_cc(x: torch.Tensor, y: torch.Tensor, *, k: int = 5): 22 | """TopoCC label filtering algorithm.""" 23 | # All pairwise distances, leaving out the diagonal 24 | dists = torch.cdist(x, 
x).fill_diagonal_(torch.inf) 25 | 26 | # Find indices of `k` nearest neighbors 27 | indices = dists.topk(k, largest=False).indices 28 | 29 | # Create kNN adjacency matrix 30 | adj = indices.new_zeros(len(x), len(x), dtype=torch.bool) 31 | adj.scatter_(1, indices, True) 32 | 33 | cls_mask = y[:, None] > 0.5 34 | pos_mask = lcc_mask(adj & cls_mask) 35 | neg_mask = lcc_mask(adj & ~cls_mask) 36 | return neg_mask | pos_mask 37 | 38 | 39 | def topofilter( 40 | x: torch.Tensor, y: torch.Tensor, contamination: float = 0.1, 41 | *, k_cc: int = 5, k_zeta: int = 5, 42 | ): 43 | """Remove points whose labels are far the average of their neighbors' labels.""" 44 | 45 | C = topo_cc(x, y, k=k_cc) 46 | x_C, y_C = x[C], y[C] 47 | 48 | # Zeta filtering 49 | dists = torch.cdist(x_C, x_C).fill_diagonal_(torch.inf) 50 | indices = dists.topk(k_zeta, largest=False).indices 51 | 52 | # Compute how far each point is from its average neighbor 53 | knn_labels = y_C[indices].float().mean(1) 54 | dists = torch.abs(y_C - knn_labels) 55 | 56 | # Remove points that are furthest from their average neighbor 57 | cc_removed = len(x) - len(x_C) 58 | remaining = round(len(x) * contamination) - cc_removed 59 | n = max(remaining, 0) 60 | 61 | if n == 0: 62 | print("TopoCC overshot contamination. No additional points removed.") 63 | print(f"frac removed = {cc_removed / len(x)}") 64 | 65 | filtered = dists.topk(n).indices.cpu() 66 | C_indices = C.nonzero().squeeze(1).cpu() 67 | return np.delete(C_indices, filtered) 68 | 69 | 70 | def topolabel( 71 | x0: torch.Tensor, y0: torch.Tensor, x: torch.Tensor, 72 | *, k_cc: int = 5, k_zeta: int = 5, 73 | ): 74 | """Relabel points x to the average of their neighbors' labels within x0 CC.""" 75 | 76 | C = topo_cc(x0, y0, k=k_cc) 77 | x_C, y_C = x0[C], y0[C] 78 | 79 | # Zeta filtering 80 | dists = torch.cdist(x, x_C) 81 | indices = dists.topk(k_zeta, largest=False).indices 82 | 83 | # Compute how far each point is from its average neighbor 84 | knn_labels = y_C[indices].float().mean(1) 85 | return knn_labels -------------------------------------------------------------------------------- /w2s/probe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Optional, Union 4 | 5 | from w2s.sft_utils import literal 6 | from w2s.logistic import Classifier 7 | from w2s.topo import topofilter, topolabel 8 | 9 | from simple_parsing import Serializable 10 | 11 | 12 | @dataclass 13 | class ProbeConfig(Serializable): 14 | def to_dict(self): 15 | irrelevant_fields = [] 16 | return {k: v for k, v in vars(self).items() if k not in irrelevant_fields} 17 | 18 | @dataclass 19 | class KnnProbeConfig(ProbeConfig): 20 | k: int = 50 21 | 22 | @dataclass 23 | class LogisticProbeConfig(ProbeConfig): 24 | l2p: float = 1e-3 25 | 26 | @dataclass 27 | class TopoProbeConfig(ProbeConfig): 28 | k_cc: int = 50 29 | k_zeta: int = 50 30 | modified: bool = False 31 | 32 | 33 | PROBE_CONFIGS = { 34 | "knn": KnnProbeConfig, 35 | "logreg": LogisticProbeConfig, 36 | "topo": TopoProbeConfig, 37 | } 38 | 39 | 40 | class Probe: 41 | def __init__(self, config: ProbeConfig): 42 | self.config = config 43 | 44 | def fit(self, acts, labels): 45 | raise NotImplementedError 46 | 47 | def predict(self, acts): 48 | raise NotImplementedError 49 | 50 | def filter(self, acts, labels, contamination): 51 | preds = self.predict(acts) 52 | disagree = (preds - labels).abs() 53 | # return indices for bottom (1-contamination) of examples 54 | 
return disagree.argsort(descending=True)[int(contamination * len(disagree)):] 55 | 56 | 57 | class KnnProbe(Probe): 58 | def __init__(self, config: KnnProbeConfig): 59 | super().__init__(config) 60 | self.k = config.k 61 | 62 | def fit(self, acts, labels): 63 | self.acts = acts 64 | self.labels = labels 65 | 66 | def predict(self, acts): 67 | # compute cosine similarity 68 | dists = torch.cdist(acts, self.acts) 69 | # get top k 70 | topk = dists.topk(self.k, largest=False) 71 | # get labels 72 | labels = self.labels[topk.indices] 73 | # get majority vote 74 | pred = labels.float().mean(dim=-1) 75 | return pred 76 | 77 | 78 | class LogisticProbe(Probe): 79 | def __init__(self, config: LogisticProbeConfig): 80 | super().__init__(config) 81 | self.l2p = config.l2p 82 | 83 | def fit(self, acts, labels): 84 | acts = acts.to(torch.float32) 85 | self.clf = Classifier(acts.shape[1], num_classes=1, device=acts.device) 86 | self.clf.fit(acts, labels, l2_penalty=self.l2p) 87 | 88 | def predict(self, acts): 89 | acts = acts.to(torch.float32) 90 | preds = torch.sigmoid(self.clf(acts)) 91 | return preds 92 | 93 | 94 | class TopoProbe(Probe): 95 | def __init__(self, config: TopoProbeConfig): 96 | super().__init__(config) 97 | self.k_cc = config.k_cc 98 | self.k_zeta = config.k_zeta 99 | self.modified = config.modified 100 | 101 | def fit(self, acts, labels): 102 | self.acts = acts 103 | self.labels = labels 104 | 105 | def predict(self, acts): 106 | return topolabel(self.acts, self.labels, acts, k_cc=self.k_cc, k_zeta=self.k_zeta) 107 | 108 | def filter(self, acts, labels, contamination): 109 | if not self.config.modified: 110 | return topofilter(acts, labels, contamination, k_cc=self.k_cc, k_zeta=self.k_zeta) 111 | else: 112 | return super().filter(acts, labels, contamination) 113 | 114 | 115 | PROBES = { 116 | "knn": KnnProbe, 117 | "logreg": LogisticProbe, 118 | "topo": TopoProbe, 119 | } 120 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results*/ 2 | wandb/ 3 | plots/ 4 | scratch/ 5 | venv*/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /w2s/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from dataclasses import dataclass 4 | from typing import Optional, Union 5 | from w2s.sft_utils import literal 6 | 7 | from simple_parsing import Serializable 8 | 9 | 10 | @dataclass 11 | class LossConfig(Serializable): 12 | def to_dict(self): 13 | irrelevant_fields = [] 14 | return {k: v for k, v in vars(self).items() if k not in irrelevant_fields} 15 | 16 | @dataclass 17 | class LogConfidenceLossConfig(LossConfig): 18 | logconf_weight: float = 0.5 19 | logconf_warmup_steps: int = 100 20 | balance_batch: bool = False 21 | 22 | @dataclass 23 | class ConfidenceWindowLossConfig(LossConfig): 24 | radius: Union[float, literal("midweak")] = 0.15 25 | 26 | def to_dict(self): 27 | return {"radius": self.radius if isinstance(self.radius, float) else "midweak"} 28 | 29 | @dataclass 30 | class LogEntropyLossConfig(LogConfidenceLossConfig): 31 | pass 32 | 33 | @dataclass 34 | class CrossEntropyLossConfig(LossConfig): 35 | pass 36 | 37 | @dataclass 38 | class KLDivergenceLossConfig(LossConfig): 39 | pass 40 | 41 | LOSS_CONFIGS = { 42 | "logconf": LogConfidenceLossConfig, 43 | "window": ConfidenceWindowLossConfig, 44 | "entropy": LogEntropyLossConfig, 45 | "xent": CrossEntropyLossConfig, 46 | "kl": KLDivergenceLossConfig, 47 | } 48 | 49 | 50 | def confidence_window_loss( 51 | logits, 52 | labels, 53 | radius: float = 0.15, 54 | ): 55 | """ 56 | Use cross-entropy loss only for the examples where the model is uncertain. 57 | """ 58 | logits = logits.float() 59 | labels = labels.float() 60 | 61 | preds = torch.softmax(logits, dim=-1) 62 | 63 | uncertain = (preds.max(dim=-1).values < 0.5 + radius) 64 | 65 | target = torch.stack([1.0 - labels, labels], dim=1) 66 | 67 | loss = torch.nn.functional.cross_entropy( 68 | logits[uncertain], 69 | target[uncertain], 70 | reduction="sum" 71 | ) 72 | 73 | return loss / logits.shape[0] 74 | 75 | 76 | def cross_entropy_loss( 77 | logits, 78 | labels, 79 | ): 80 | logits = logits.float() 81 | labels = labels.float() 82 | 83 | target = torch.stack([1.0 - labels, labels], dim=1) 84 | return torch.nn.functional.cross_entropy(logits, target) 85 | 86 | 87 | def kl_divergence_loss( 88 | logits, 89 | labels, 90 | ): 91 | logits = logits.float() 92 | labels = labels.float() 93 | 94 | target = torch.stack([1.0 - labels, labels], dim=1) 95 | log_preds = torch.log_softmax(logits, dim=-1) 96 | 97 | return F.kl_div(log_preds, target, reduction="batchmean") 98 | 99 | 100 | def log_confidence_loss( 101 | logits, 102 | labels, 103 | step: int, 104 | warmup_steps: int = 200, 105 | aux_coef: float = 0.5, 106 | balance_batch: bool = False, 107 | harden: bool = True, 108 | buffer: list = None, 109 | buffer_size: int = 32, 110 | ): 111 | """ 112 | This is similar to the loss in Burns et al., except that it also optionally 113 | balances the labels by mean-subtracting in log-odds space. 
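When `harden` is True, the model's own predictions are binarized before being mixed in: the threshold is the `prior`-quantile of the last `buffer_size` class-0 probabilities (so roughly a `prior` fraction of recent examples are assigned the positive class), and the hardened predictions are blended with the provided labels using weight `aux_coef`, linearly warmed up over `warmup_steps`.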
114 | """ 115 | logits = logits.float() 116 | labels = labels.float() 117 | if balance_batch: 118 | logodds_labels = torch.log(labels + 1e-7) - torch.log(1 - labels + 1e-7) 119 | labels = torch.sigmoid(logodds_labels - logodds_labels.mean()) 120 | prior = 0.5 121 | else: 122 | prior = labels.mean() if labels.shape[0] > 1 else 0.5 123 | 124 | coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef 125 | preds = torch.softmax(logits, dim=-1) 126 | buffer += list(preds[:, 0].detach()) 127 | del buffer[:-buffer_size]  # trim in place so the caller's rolling buffer stays bounded 128 | 129 | if harden: 130 | threshold = torch.quantile(torch.stack(buffer), prior) 131 | target_preds = torch.cat( 132 | [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]], 133 | dim=1, 134 | ) 135 | else: 136 | target_preds = preds 137 | 138 | labels_binary = torch.stack([1.0 - labels, labels], dim=1) 139 | target = labels_binary * (1 - coef) + target_preds.detach() * coef 140 | return torch.nn.functional.cross_entropy(logits, target) -------------------------------------------------------------------------------- /w2s/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | from peft import LoraConfig, TaskType, get_peft_model 7 | from transformers import ( 8 | AutoModelForSequenceClassification, 9 | AutoTokenizer, 10 | ) 11 | 12 | from w2s.utils import assert_type 13 | 14 | # Works for Llama, Mistral, and Qwen architectures 15 | DEFAULT_LORA_MODULES = [ 16 | "gate_proj", 17 | "down_proj", 18 | "up_proj", 19 | "q_proj", 20 | "k_proj", 21 | "v_proj", 22 | "o_proj", 23 | ] 24 | 25 | 26 | @dataclass 27 | class PredictorConfig(ABC): 28 | @abstractmethod 29 | def to_dict(self) -> dict: 30 | ...
31 | 32 | 33 | @dataclass 34 | class ModelConfig(PredictorConfig): 35 | name: str 36 | enable_lora: bool 37 | lora_modules: Optional[List[str]] = None 38 | 39 | def to_dict(self): 40 | return vars(self) 41 | 42 | 43 | class AutoCastingScore(torch.nn.Module): 44 | def __init__( 45 | self, score: torch.nn.Linear, output_dtype: torch.dtype = torch.bfloat16 46 | ): 47 | super().__init__() 48 | # make a leaf tensor with the same data as score 49 | self.weight = torch.nn.Parameter(score.weight.to(torch.float32).data) 50 | self.output_dtype = output_dtype 51 | 52 | def forward(self, hiddens): 53 | return torch.nn.functional.linear( 54 | hiddens.to(self.weight.dtype), self.weight, None 55 | ).to(self.output_dtype) 56 | 57 | 58 | def init_tokenizer(cfg: ModelConfig) -> AutoTokenizer: 59 | tokenizer = AutoTokenizer.from_pretrained(cfg.name) 60 | if tokenizer.pad_token_id is None: 61 | tokenizer.pad_token = tokenizer.eos_token 62 | 63 | return tokenizer 64 | 65 | 66 | def init_model(tokenizer, cfg: ModelConfig): 67 | model = AutoModelForSequenceClassification.from_pretrained( 68 | cfg.name, torch_dtype="auto", device_map={"": "cuda"}, 69 | # force_download=True, 70 | ) 71 | 72 | if cfg.lora_modules is None and cfg.enable_lora: 73 | cfg.lora_modules = MODEL_REGISTRY.get(cfg.name, {}).get( 74 | "lora_modules", DEFAULT_LORA_MODULES 75 | ) 76 | 77 | model.config.pad_token_id = tokenizer.pad_token_id # type: ignore 78 | model.score.weight.data *= 0.01 79 | model.config.problem_type = "single_label_classification" 80 | 81 | if cfg.enable_lora: 82 | lora_cfg = LoraConfig( 83 | target_modules=cfg.lora_modules, task_type=TaskType.SEQ_CLS 84 | ) 85 | 86 | # NOTE: adding task_type causes dtype errors, but is necessary for proper module saving 87 | # and for making the lm head trainable, so we need to wrap it in an AutoCastingScore 88 | for attr in ["score", "classifier"]: 89 | if hasattr(model, attr): 90 | setattr( 91 | model, 92 | attr, 93 | AutoCastingScore(getattr(model, attr), output_dtype=model.dtype), 94 | ) 95 | break 96 | else: 97 | raise ValueError("Could not find classifier head in model.") 98 | model = get_peft_model(model, lora_cfg) 99 | 100 | # put all the trainable (e.g. 
LoRA) parameters in float32 101 | for p in model.parameters(): 102 | if p.requires_grad: 103 | p.data = p.data.float() 104 | 105 | return model 106 | 107 | 108 | def init_model_and_tokenizer(cfg: ModelConfig): 109 | tokenizer = init_tokenizer(cfg) 110 | model = init_model(tokenizer, cfg) 111 | 112 | return model, tokenizer 113 | 114 | 115 | # TODO: make a legitimate model registry 116 | # for now we just have a map from model name to learning rate and lora modules 117 | MODEL_REGISTRY = { 118 | "meta-llama/Meta-Llama-3-8B": { 119 | "lr": 8e-5, 120 | "lora_modules": DEFAULT_LORA_MODULES, 121 | }, 122 | "mistralai/Mistral-7B-v0.1": { 123 | "lr": 8e-5, 124 | "lora_modules": DEFAULT_LORA_MODULES, 125 | }, 126 | "gemma/gemma-7b": { 127 | "lr": 8e-5, 128 | "lora_modules": DEFAULT_LORA_MODULES, 129 | }, 130 | "Qwen/Qwen1.5-0.5B": { 131 | "lr": 5e-4, 132 | "lora_modules": DEFAULT_LORA_MODULES, 133 | }, 134 | } -------------------------------------------------------------------------------- /w2s/sft_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from pathlib import Path 3 | 4 | import pynvml 5 | import torch 6 | from datasets import Dataset 7 | from tqdm import tqdm 8 | from transformers import PretrainedConfig, Trainer 9 | 10 | from w2s.utils import assert_type 11 | 12 | 13 | # simple_parsing doesn't like typing.Literal (pre-3.12) so I rolled my own 14 | # note: parens, not brackets 15 | 16 | # Python 3.11 version: 17 | # literal = lambda *args: StrEnum("option", args) 18 | 19 | # Python 3.10 version: 20 | def ident_escape_char(c: str) -> str: 21 | if c.isalnum() or c == "_": 22 | return c 23 | return f"_{ord(c)}_" 24 | 25 | def ident_escape(s: str) -> str: 26 | return "".join(ident_escape_char(c) for c in s) 27 | 28 | def literal(s: str): 29 | return type('LiteralString_' + ident_escape(s), (LiteralString,), {"value": s}) 30 | 31 | class LiteralString(): 32 | value = "" 33 | 34 | def __init__(self, value): 35 | if value != self.value: 36 | raise ValueError(f"Invalid value {value!r} is not literally {self.value!r}") 37 | 38 | def __str__(self): 39 | return self.value 40 | 41 | def __eq__(self, other): 42 | return self.value == other 43 | 44 | 45 | @torch.no_grad() 46 | def gather_hiddens(model: torch.nn.Module, dataset: Dataset): 47 | dataset = dataset.with_format("torch", device="cuda") 48 | 49 | cfg = assert_type(PretrainedConfig, model.config) 50 | D = assert_type(int, cfg.hidden_size) 51 | L = assert_type(int, cfg.num_hidden_layers + 1) 52 | 53 | buffer = torch.empty(len(dataset), L, D, device=model.device, dtype=model.dtype) 54 | print(f"Allocated buffer of shape {buffer.shape}") 55 | for i, ex in enumerate(tqdm(dataset)): 56 | ex = assert_type(dict, ex) 57 | 58 | out = model(ex["input_ids"][None], output_hidden_states=True) 59 | act = torch.stack(out.hidden_states)[:, 0, -1] # Final token 60 | if act.shape != (L, D): 61 | raise ValueError(f"Unexpected shape {act.shape} for hidden states on example {i}") 62 | buffer[i] = act 63 | 64 | return buffer 65 | 66 | 67 | def move_best_ckpt(trainer: Trainer): 68 | if trainer.args.load_best_model_at_end: 69 | path = trainer.state.best_model_checkpoint 70 | perf = trainer.state.best_metric 71 | assert path is not None, "No best checkpoint found" 72 | assert perf is not None, "No best metric" 73 | 74 | src = Path(path) 75 | dest = src.parent / "best-ckpt" 76 | src.rename(dest) 77 | print(f"Best model (auroc {perf:.3f}) saved at: {dest}") 78 | 79 | 80 | def clear_mem(verbose: bool = False): 81 
| """ 82 | This function is used to clear the memory allocated by PyTorch. 83 | It does so by calling the garbage collector to release unused GPU memory. 84 | After clearing the memory, it prints the current amount of memory still 85 | allocated by PyTorch (post-clean). 86 | 87 | Parameters: 88 | verbose (bool): Whether to print additional information. 89 | """ 90 | 91 | gc.collect() 92 | torch.cuda.empty_cache() 93 | print( 94 | f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f}GB" 95 | ) 96 | 97 | if verbose: 98 | 99 | def try_attr(x, a): 100 | try: 101 | return getattr(x, a) 102 | except Exception: 103 | # amazing that this can cause... 104 | # (AttributeError, OSError, AssertionError, RuntimeError, ModuleNotFoundError) 105 | return None 106 | 107 | for obj in gc.get_objects(): 108 | if torch.is_tensor(obj) or torch.is_tensor(try_attr(obj, "data")): 109 | print(type(obj), obj.size(), obj.dtype) 110 | 111 | 112 | def get_gpu_mem_used() -> float: 113 | """returns proportion of used GPU memory averaged across all GPUs""" 114 | prop_sum = 0 115 | pynvml.nvmlInit() 116 | try: 117 | num_devices = pynvml.nvmlDeviceGetCount() 118 | for i in range(num_devices): 119 | handle = pynvml.nvmlDeviceGetHandleByIndex(i) 120 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) 121 | prop_sum += int(meminfo.used) / int(meminfo.total) 122 | finally: 123 | pynvml.nvmlShutdown() 124 | return prop_sum / num_devices 125 | 126 | 127 | def spotcheck_init(model: torch.nn.Module): 128 | spots = {} 129 | for name, param in model.named_parameters(): 130 | p = param.detach().clone() 131 | depth = param.dim() 132 | for _ in range(depth - 1): 133 | p = p[0] 134 | spots[name] = p 135 | 136 | return spots 137 | 138 | 139 | def spotcheck(model: torch.nn.Module, spots): 140 | for name, param in model.named_parameters(): 141 | p = param.detach().clone() 142 | depth = param.dim() 143 | for _ in range(depth - 1): 144 | p = p[0] 145 | if torch.allclose(p, spots[name], rtol=1e-2) or not param.requires_grad: 146 | print(f"[WARNING] Param {name} possibly unchanged.") 147 | print(f"change: {torch.norm(p - spots[name]) / torch.norm(p):.4e}, grad: {param.requires_grad}") 148 | 149 | def lshape(l): 150 | if not isinstance(l, list): 151 | return [] 152 | return [len(l)] + lshape(l[0]) -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | import torch 5 | from datasets import DatasetDict, load_from_disk 6 | from simple_parsing import parse 7 | from transformers import ( 8 | TrainingArguments, 9 | ) 10 | 11 | from w2s.ds_registry import load_and_process_dataset 12 | from w2s.model import ModelConfig 13 | from w2s.sft import train 14 | from w2s.sft_config import SFTConfig 15 | from w2s.utils import get_config_foldername 16 | 17 | 18 | def run_train(cfg: SFTConfig): 19 | print(f"Loading and processing dataset {cfg.dataset}") 20 | splits = load_and_process_dataset( 21 | cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict 22 | ) 23 | 24 | train_halves = splits["train"].train_test_split(test_size=0.5, seed=42) 25 | splits["weak_train"] = train_halves["train"] 26 | splits["strong_train"] = train_halves["test"] 27 | 28 | cols = ["hard_label", "txt"] 29 | splits = splits.select_columns(cols).rename_column("hard_label", "labels") 30 | for split in splits: 31 | splits[split] = splits[split].add_column("gt_labels", splits[split]["labels"]) 32 | 33 
| print( 34 | f"Example:\n\n{splits['strong_train'][0]['txt']}\n\nLabel: {splits['strong_train'][0]['labels']}" 35 | ) 36 | 37 | root = Path(cfg.results_folder) / cfg.run_name 38 | shared_root = Path(cfg.results_folder) / cfg.shared_folder 39 | cfg_name = cfg.dataset 40 | train_args: dict = dict( 41 | num_train_epochs=cfg.n_epochs, 42 | adam_beta2=0.95, 43 | gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size, 44 | eval_strategy="steps", 45 | label_names=["labels"], 46 | load_best_model_at_end=cfg.load_best_model_at_end, 47 | logging_steps=25, 48 | metric_for_best_model=cfg.metric_for_best_model, 49 | greater_is_better=cfg.greater_is_better, 50 | per_device_train_batch_size=cfg.minibatch_size, 51 | per_device_eval_batch_size=cfg.minibatch_size, 52 | save_strategy="steps", 53 | save_total_limit=cfg.save_total_limit, 54 | tf32=True, # Use Tensor Cores even for fp32 matmuls 55 | warmup_steps=cfg.n_warmup_steps, 56 | weight_decay=cfg.weight_decay, 57 | lr_scheduler_type=cfg.lr_schedule, 58 | eval_steps=cfg.eval_every, 59 | save_steps=cfg.save_every, 60 | ) 61 | 62 | def get_model_and_run_name(model_name, current_name): 63 | model_last = model_name.split("/")[-1] 64 | model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora) 65 | run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}" 66 | return model_cfg, run_name 67 | 68 | # train weak floor, get predictions 69 | print("\n\033[32m===== Training weak model =====\033[0m") 70 | model_cfg, run_name = get_model_and_run_name(cfg.weak_model_name, "weak") 71 | train_args["run_name"] = run_name 72 | train_args["output_dir"] = str(shared_root / cfg_name / "weak") 73 | train_args["learning_rate"] = cfg.weak_lr 74 | weak_ds_dict = DatasetDict( 75 | { 76 | "train": splits["weak_train"], 77 | "val": splits["val"], 78 | "test": splits["test"], 79 | } 80 | ) 81 | weak_predict_dict = {"train": splits["strong_train"], "val": splits["val"]} 82 | train( 83 | weak_ds_dict, 84 | model_cfg, 85 | TrainingArguments(**train_args), 86 | cfg.to_dict(), 87 | transfer=False, 88 | predict_dict=weak_predict_dict, 89 | ) 90 | 91 | # train strong ceil 92 | print("\n\033[32m===== Training strong model =====\033[0m") 93 | model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "strong") 94 | train_args["run_name"] = run_name 95 | train_args["output_dir"] = str(shared_root / cfg_name / "strong") 96 | train_args["learning_rate"] = cfg.strong_lr 97 | strong_ds_dict = DatasetDict( 98 | { 99 | "train": splits["strong_train"], 100 | "val": splits["val"], 101 | "test": splits["test"], 102 | } 103 | ) 104 | train( 105 | strong_ds_dict, 106 | model_cfg, 107 | TrainingArguments(**train_args), 108 | cfg.to_dict(), 109 | transfer=False, 110 | ) 111 | 112 | # load weak predictions 113 | weak_preds_root = shared_root / cfg_name / "weak" / "predictions" 114 | print(f"Loading weak predictions from {weak_preds_root}") 115 | weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train")) 116 | weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val")) 117 | 118 | # train w2s, get predictions 119 | print("\n\033[32m===== Training w2s model =====\033[0m") 120 | model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s") 121 | train_args["run_name"] = run_name 122 | train_args["output_dir"] = str(root / cfg_name / "w2s") 123 | train_args["learning_rate"] = cfg.strong_lr 124 | w2s_ds_dict = DatasetDict( 125 | { 126 | "train": ( 127 | splits["strong_train"] 128 | .remove_columns("labels") 129 | .add_column("labels", 
weak_train_preds_ds["soft_pred"]) # type: ignore 130 | ), 131 | "val": ( 132 | splits["val"] 133 | .remove_columns("labels") 134 | .add_column("labels", weak_val_preds_ds["soft_pred"]) 135 | ), # type: ignore 136 | "test": splits["test"], 137 | } 138 | ) 139 | # assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"]) 140 | # assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"]) 141 | w2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]} 142 | train( 143 | w2s_ds_dict, 144 | model_cfg, 145 | TrainingArguments(**train_args), 146 | cfg.to_dict(), 147 | transfer=True, 148 | predict_dict=w2s_predict_dict, 149 | save_activations=True, 150 | acts_dir=shared_root / cfg_name / "w2s" / "activations", 151 | ) 152 | 153 | prev = "w2s" 154 | 155 | # strong-to-strong iterations 156 | for s2s_iter in range(cfg.s2s_iters): 157 | 158 | # load prev predictions 159 | prev_preds_root = root / cfg_name / prev / "predictions" 160 | print(f"Loading {prev} predictions from {prev_preds_root}") 161 | prev_train_preds_ds = load_from_disk(str(prev_preds_root / "train")) 162 | prev_val_preds_ds = load_from_disk(str(prev_preds_root / "val")) 163 | 164 | # train s2s, get predictions 165 | print(f"\n\033[32m===== Training s2s model iteration {s2s_iter} =====\033[0m") 166 | model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, f"s2s-{s2s_iter}") 167 | train_args["run_name"] = run_name 168 | train_args["output_dir"] = str(root / cfg_name / f"s2s-{s2s_iter}") 169 | train_args["learning_rate"] = cfg.strong_lr 170 | s2s_ds_dict = DatasetDict( 171 | { 172 | "train": ( 173 | splits["strong_train"] 174 | .remove_columns("labels") 175 | .add_column("labels", prev_train_preds_ds["soft_pred"]) # type: ignore 176 | ), 177 | "val": ( 178 | splits["val"] 179 | .remove_columns("labels") 180 | .add_column("labels", prev_val_preds_ds["soft_pred"]) 181 | ), # type: ignore 182 | "test": splits["test"], 183 | } 184 | ) 185 | # assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"]) 186 | # assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"]) 187 | s2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]} 188 | train( 189 | s2s_ds_dict, 190 | model_cfg, 191 | TrainingArguments(**train_args), 192 | cfg.to_dict(), 193 | transfer=True, 194 | predict_dict=s2s_predict_dict, 195 | acts_dir=shared_root / cfg_name / f"s2s-{s2s_iter}" / "activations", 196 | ) 197 | 198 | prev = f"s2s-{s2s_iter}" 199 | 200 | if __name__ == "__main__": 201 | run_train(parse(SFTConfig)) 202 | -------------------------------------------------------------------------------- /w2s/logistic.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.functional import ( 6 | binary_cross_entropy_with_logits as bce_with_logits, 7 | ) 8 | from torch.nn.functional import ( 9 | cross_entropy, 10 | ) 11 | 12 | 13 | @dataclass 14 | class InlpResult: 15 | """Result of Iterative Nullspace Projection (NLP).""" 16 | 17 | losses: list[float] = field(default_factory=list) 18 | classifiers: list["Classifier"] = field(default_factory=list) 19 | 20 | 21 | @dataclass 22 | class RegularizationPath: 23 | """Result of cross-validation.""" 24 | 25 | penalties: list[float] 26 | losses: list[float] 27 | 28 | @property 29 | def best_penalty(self) -> float: 30 | """Returns the best L2 regularization penalty.""" 31 | return self.penalties[self.losses.index(self.best_loss)] 32 | 33 | 
@property 34 | def best_loss(self) -> float: 35 | """Returns the best loss.""" 36 | return min(self.losses) 37 | 38 | 39 | class Classifier(torch.nn.Module): 40 | """Linear classifier trained with supervised learning.""" 41 | 42 | def __init__( 43 | self, 44 | input_dim: int, 45 | num_classes: int = 2, 46 | device: str | torch.device | None = None, 47 | dtype: torch.dtype | None = None, 48 | ): 49 | super().__init__() 50 | 51 | self.linear = torch.nn.Linear( 52 | input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype 53 | ) 54 | self.linear.bias.data.zero_() 55 | self.linear.weight.data.zero_() 56 | 57 | def forward(self, x: Tensor) -> Tensor: 58 | return self.linear(x).squeeze(-1) 59 | 60 | @torch.enable_grad() 61 | def fit( 62 | self, 63 | x: Tensor, 64 | y: Tensor, 65 | *, 66 | l2_penalty: float = 0.001, 67 | max_iter: int = 10_000, 68 | ) -> float: 69 | """Fits the model to the input data using L-BFGS with L2 regularization. 70 | 71 | Args: 72 | x: Input tensor of shape (N, D), where N is the number of samples and D is 73 | the input dimension. 74 | y: Target tensor of shape (N,) for binary classification or (N, C) for 75 | multiclass classification, where C is the number of classes. 76 | l2_penalty: L2 regularization strength. 77 | max_iter: Maximum number of iterations for the L-BFGS optimizer. 78 | 79 | Returns: 80 | Final value of the loss function after optimization. 81 | """ 82 | optimizer = torch.optim.LBFGS( 83 | self.parameters(), 84 | line_search_fn="strong_wolfe", 85 | max_iter=max_iter, 86 | ) 87 | 88 | num_classes = self.linear.out_features 89 | loss_fn = bce_with_logits if num_classes == 1 else cross_entropy 90 | loss = torch.inf 91 | y = y.to( 92 | torch.get_default_dtype() if num_classes == 1 else torch.long, 93 | ) 94 | 95 | def closure(): 96 | nonlocal loss 97 | optimizer.zero_grad() 98 | 99 | # Calculate the loss function 100 | logits = self(x).squeeze(-1) 101 | loss = loss_fn(logits, y) 102 | if l2_penalty: 103 | reg_loss = loss + l2_penalty * self.linear.weight.square().sum() 104 | else: 105 | reg_loss = loss 106 | 107 | reg_loss.backward() 108 | return float(reg_loss) 109 | 110 | optimizer.step(closure) 111 | return float(loss) 112 | 113 | @torch.no_grad() 114 | def fit_cv( 115 | self, 116 | x: Tensor, 117 | y: Tensor, 118 | *, 119 | k: int = 5, 120 | max_iter: int = 10_000, 121 | num_penalties: int = 10, 122 | seed: int = 42, 123 | ) -> RegularizationPath: 124 | """Fit using k-fold cross-validation to select the best L2 penalty. 125 | 126 | Args: 127 | x: Input tensor of shape (N, D), where N is the number of samples and D is 128 | the input dimension. 129 | y: Target tensor of shape (N,) for binary classification or (N, C) for 130 | multiclass classification, where C is the number of classes. 131 | k: Number of folds for k-fold cross-validation. 132 | max_iter: Maximum number of iterations for the L-BFGS optimizer. 133 | num_penalties: Number of L2 regularization penalties to try. 134 | seed: Random seed for the k-fold cross-validation. 135 | 136 | Returns: 137 | `RegularizationPath` containing the penalties tried and the validation loss 138 | achieved using that penalty, averaged across the folds. 
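Example (a minimal illustrative sketch; `x` and `y` are the tensors described above):

    clf = Classifier(x.shape[-1], device=x.device)
    path = clf.fit_cv(x, y, k=5)
    print(path.best_penalty, path.best_loss)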
139 | """ 140 | num_samples = x.shape[0] 141 | if k < 3: 142 | raise ValueError("`k` must be at least 3") 143 | if k > num_samples: 144 | raise ValueError("`k` must be less than or equal to the number of samples") 145 | 146 | rng = torch.Generator(device=x.device) 147 | rng.manual_seed(seed) 148 | 149 | fold_size = num_samples // k 150 | indices = torch.randperm(num_samples, device=x.device, generator=rng) 151 | 152 | # Try a range of L2 penalties, including 0 153 | l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist() 154 | 155 | num_classes = self.linear.out_features 156 | loss_fn = bce_with_logits if num_classes == 1 else cross_entropy 157 | losses = x.new_zeros((k, num_penalties + 1)) 158 | y = y.to( 159 | torch.get_default_dtype() if num_classes == 1 else torch.long, 160 | ) 161 | 162 | for i in range(k): 163 | start, end = i * fold_size, (i + 1) * fold_size 164 | train_indices = torch.cat([indices[:start], indices[end:]]) 165 | val_indices = indices[start:end] 166 | 167 | train_x, train_y = x[train_indices], y[train_indices] 168 | val_x, val_y = x[val_indices], y[val_indices] 169 | 170 | # Regularization path with warm-starting 171 | for j, l2_penalty in enumerate(l2_penalties): 172 | self.fit(train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter) 173 | 174 | logits = self(val_x).squeeze(-1) 175 | loss = loss_fn(logits, val_y) 176 | losses[i, j] = loss 177 | 178 | mean_losses = losses.mean(dim=0) 179 | best_idx = mean_losses.argmin() 180 | 181 | # Refit with the best penalty 182 | best_penalty = l2_penalties[best_idx] 183 | self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter) 184 | return RegularizationPath(l2_penalties, mean_losses.tolist()) 185 | 186 | @classmethod 187 | def inlp( 188 | cls, x: Tensor, y: Tensor, max_iter: int | None = None, tol: float = 0.01 189 | ) -> InlpResult: 190 | """Iterative Nullspace Projection (INLP) . 191 | 192 | Args: 193 | x: Input tensor of shape (N, D), where N is the number of samples and D is 194 | the input dimension. 195 | y: Target tensor of shape (N,) for binary classification or (N, C) for 196 | multiclass classification, where C is the number of classes. 197 | max_iter: Maximum number of iterations to run. If `None`, run for the full 198 | dimension of the input. 199 | tol: Tolerance for the loss function. The algorithm will stop when the loss 200 | is within `tol` of the entropy of the labels. 201 | 202 | Returns: 203 | `InlpResult` containing the classifiers and losses achieved at each 204 | iteration. 
205 | """ 206 | 207 | y.shape[-1] if y.ndim > 1 else 2 208 | d = x.shape[-1] 209 | loss = 0.0 210 | 211 | # Compute entropy of the labels 212 | p = y.float().mean() 213 | H = -p * torch.log(p) - (1 - p) * torch.log(1 - p) 214 | 215 | if max_iter is not None: 216 | d = min(d, max_iter) 217 | 218 | # Iterate until the loss is within epsilon of the entropy 219 | result = InlpResult() 220 | for _ in range(d): 221 | clf = cls(d, device=x.device, dtype=x.dtype) 222 | loss = clf.fit(x, y) 223 | result.classifiers.append(clf) 224 | result.losses.append(loss) 225 | 226 | if loss >= (1.0 - tol) * H: 227 | break 228 | 229 | # Project the data onto the nullspace of the classifier 230 | x = clf.nullspace_project(x) 231 | 232 | return result 233 | 234 | def nullspace_project(self, x: Tensor) -> Tensor: 235 | """Project the given data onto the nullspace of the classifier.""" 236 | 237 | # https://en.wikipedia.org/wiki/Projection_(linear_algebra) 238 | A = self.linear.weight.data.T 239 | P = A @ torch.linalg.solve(A.mT @ A, A.mT) 240 | return x - x @ P 241 | -------------------------------------------------------------------------------- /w2s/sft.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Union, Optional 4 | import os 5 | 6 | import torch 7 | from datasets import DatasetDict 8 | from transformers import ( 9 | DataCollatorWithPadding, 10 | Trainer, 11 | TrainingArguments, 12 | ) 13 | 14 | import wandb 15 | from w2s.loss import ( 16 | log_confidence_loss, 17 | confidence_window_loss, 18 | cross_entropy_loss, 19 | kl_divergence_loss, 20 | ) 21 | from w2s.model import ModelConfig, init_tokenizer, init_model 22 | from w2s.roc_auc import roc_auc 23 | from w2s.sft_utils import ( 24 | clear_mem, 25 | get_gpu_mem_used, 26 | move_best_ckpt, 27 | gather_hiddens, 28 | spotcheck_init, 29 | spotcheck, 30 | lshape, 31 | ) 32 | from w2s.loss import LossConfig 33 | from w2s.probe import PROBES 34 | 35 | 36 | class CustomLossTrainer(Trainer): 37 | def __init__( 38 | self, 39 | loss_name: str, 40 | loss_cfg: LossConfig, 41 | transfer: bool, 42 | buffer_size: int, 43 | *args, 44 | **kwargs, 45 | ): 46 | super().__init__(*args, **kwargs) 47 | self.loss_name = loss_name 48 | self.loss_cfg = loss_cfg 49 | self.transfer = transfer 50 | if loss_name in ["logconf", "entropy"]: 51 | self.buffer = [] 52 | self.buffer_size = buffer_size 53 | 54 | 55 | def compute_loss(self, model, inputs, return_outputs=False): 56 | labels = inputs.pop("labels").float() 57 | 58 | outputs = model(**inputs) 59 | 60 | if self.loss_name == 'logconf': 61 | loss = log_confidence_loss( 62 | outputs.logits, 63 | labels, 64 | self.state.global_step, 65 | aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.), 66 | warmup_steps=self.loss_cfg.logconf_warmup_steps, 67 | balance_batch=self.loss_cfg.balance_batch, 68 | harden=True, 69 | buffer=self.buffer, 70 | buffer_size=self.buffer_size, 71 | ) 72 | elif self.loss_name == 'entropy': 73 | loss = log_confidence_loss( 74 | outputs.logits, 75 | labels, 76 | self.state.global_step, 77 | aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.), 78 | warmup_steps=self.loss_cfg.logconf_warmup_steps, 79 | balance_batch=self.loss_cfg.balance_batch, 80 | harden=False, 81 | buffer=self.buffer, 82 | buffer_size=self.buffer_size, 83 | ) 84 | elif self.loss_name == 'xent': 85 | loss = cross_entropy_loss( 86 | outputs.logits, 87 | labels, 88 | ) 89 | elif self.loss_name == 'kl': 90 | loss = 
kl_divergence_loss( 91 | outputs.logits, 92 | labels, 93 | ) 94 | elif self.loss_name == 'window': 95 | loss = confidence_window_loss( 96 | outputs.logits, 97 | labels, 98 | radius=(self.loss_cfg.radius if self.transfer else 0.51), 99 | ) 100 | else: 101 | raise ValueError(f"Unknown loss function: {self.loss_name}") 102 | 103 | return (loss, outputs) if return_outputs else loss 104 | 105 | 106 | def train( 107 | ds_dict: DatasetDict, 108 | model_cfg: ModelConfig, 109 | train_args: TrainingArguments, 110 | cfg: dict, 111 | transfer: bool, 112 | predict_dict: Union[DatasetDict, dict, None] = None, 113 | save_activations: bool = False, 114 | use_probe: bool = False, 115 | acts_dir: Optional[Path] = None, 116 | ): 117 | """ 118 | ds_dict: DatasetDict with splits for train, val, test, and (optionally) predict, 119 | with columns "txt" and "labels" 120 | model_cfg: ModelConfig with the model name and whether to enable LoRA 121 | train_args: TrainingArguments with the training hyperparameters 122 | cfg: a dictionary containing all the relevant details for reproducibility. 123 | This will be updated with your train_args and model_cfg before saving. 124 | logconf_weight: the weight for the log confidence loss 125 | logconf_warmup_steps: the number of steps to linearly increase the logconf_weight 126 | balance_batch: whether to balance the batch with the log confidence loss 127 | 128 | This function trains a model on ds_dict["train"], uses ds_dict["val"] for early stopping, 129 | and evaluates on ds_dict["test"]. 130 | It also optionally predicts on ds_dict["predict"] and saves the predictions. 131 | """ 132 | save_dir = Path(train_args.output_dir) 133 | results_path = save_dir / "results.json" 134 | 135 | os.makedirs(save_dir, exist_ok=True) 136 | 137 | clear_mem() 138 | 139 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use before toker init") 140 | tokenizer = init_tokenizer(model_cfg) 141 | model = None 142 | 143 | def process(examples): 144 | out = tokenizer(examples["txt"], truncation=True) 145 | return out 146 | 147 | ds_dict = ds_dict.map(process, batched=True) 148 | 149 | def compute_metrics_torch(predictions, labels): 150 | hard_labels = (labels > 0.5).long() 151 | return dict( 152 | accuracy=predictions.argmax(dim=1).eq(hard_labels).float().mean(), 153 | auroc=roc_auc(hard_labels, predictions[:, 1]), 154 | ) 155 | 156 | def detorch_metrics(metrics): 157 | return {k: v.detach().cpu().item() for k, v in metrics.items()} 158 | 159 | def compute_metrics(eval_pred): 160 | predictions, labels = map(torch.from_numpy, eval_pred) 161 | return compute_metrics_torch(predictions, labels) 162 | 163 | probe_required = transfer and (cfg['probe_relabel'] or cfg['probe_filter']) 164 | 165 | if save_activations or probe_required: 166 | if acts_dir.exists() and all((acts_dir / f"{name}.pt").exists() for name in ds_dict.keys()): 167 | print("Activations already exist at", acts_dir) 168 | else: 169 | print("Saving activations to", acts_dir) 170 | if model is None: 171 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use before training") 172 | model = init_model(tokenizer, model_cfg) 173 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after model init") 174 | acts_dir.mkdir(parents=True, exist_ok=True) 175 | for name, ds in ds_dict.items(): 176 | acts = gather_hiddens(model, ds) 177 | torch.save(acts, acts_dir / f"{name}.pt") 178 | 179 | if probe_required: 180 | print("Training probe") 181 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use 
before first act load") 182 | all_acts = torch.load(acts_dir / f"train.pt", map_location="cuda") 183 | probe_layer = cfg["probe_layer"] 184 | if probe_layer is None: 185 | probe_layer = all_acts.shape[1] // 2 186 | print(f"Using probe layer {probe_layer}") 187 | acts = all_acts[:, probe_layer] 188 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after first act load") 189 | probe = PROBES[cfg["probe_name"]](cfg["probe"]) 190 | probe.fit(acts, torch.tensor(ds_dict["train"]["labels"], device="cuda")) 191 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after probe training") 192 | for name, ds in ds_dict.items(): 193 | all_acts = torch.load(acts_dir / f"{name}.pt", map_location="cuda") 194 | acts = all_acts[:, probe_layer] 195 | preds = probe.predict(acts) 196 | # preds --> (1-preds, preds) 197 | preds = torch.stack([1 - preds, preds], dim=-1) 198 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after probe prediction on {name}") 199 | # print(f"preds shape: {preds.shape}") 200 | # print(f"labels shape: {torch.tensor(ds['labels']).shape}") 201 | agree_metrics = compute_metrics_torch(preds, torch.tensor(ds["labels"], device="cuda")) 202 | gt_metrics = compute_metrics_torch(preds, torch.tensor(ds["gt_labels"], device="cuda")) 203 | with open(save_dir / f"{name}_probe_metrics.json", "w", ) as f: 204 | json.dump({ 205 | "agree": detorch_metrics(agree_metrics), 206 | "gt": detorch_metrics(gt_metrics), 207 | }, f, indent=2) 208 | if name in ["train", "val"]: 209 | if cfg['probe_filter']: 210 | good_indices = probe.filter(acts, torch.tensor(ds["labels"], device="cuda"), cfg['contamination']) 211 | sizes = { 212 | "before": len(ds), 213 | "after": len(good_indices), 214 | "removed": len(ds) - len(good_indices), 215 | "contamination": int(cfg['contamination'] * len(ds)), 216 | } 217 | with open(save_dir / f"{name}_filter_sizes.json", "w") as f: 218 | json.dump(sizes, f, indent=2) 219 | ds = ds.select(good_indices) 220 | ds_dict[name] = ds 221 | if cfg['probe_relabel']: 222 | # print(lshape(ds["labels"])) 223 | ds = ds.remove_columns("labels").add_column("labels", preds[:, 1].detach().cpu().numpy()) 224 | ds_dict[name] = ds 225 | 226 | if results_path.exists(): 227 | print( 228 | f"Results already exist at {results_path}. Skipping training and evaluation." 229 | ) 230 | return 231 | 232 | print(f"No results found at {results_path}. 
Training model.") 233 | 234 | if model is None: 235 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use before training") 236 | model = init_model(tokenizer, model_cfg) 237 | print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after model init") 238 | 239 | if transfer and cfg["loss_name"] == "window" and cfg["loss"].radius == "midweak": 240 | confs = torch.abs(torch.tensor(ds_dict["train"]["labels"]) - 0.5) 241 | cfg["loss"].radius = confs.median().item() 242 | print(f"Setting radius to {cfg['loss'].radius:.2f} based on median confidence in train set") 243 | 244 | trainer = CustomLossTrainer( 245 | loss_name=cfg["loss_name"], 246 | loss_cfg=cfg["loss"], 247 | buffer_size=cfg["batch_size"], 248 | transfer=transfer, 249 | args=train_args, 250 | compute_metrics=compute_metrics, 251 | data_collator=DataCollatorWithPadding(tokenizer), 252 | eval_dataset={k: ds_dict[k] for k in ["val", "test"]}, 253 | model=model, 254 | tokenizer=tokenizer, 255 | train_dataset=ds_dict["train"], 256 | ) 257 | 258 | # spotcheck for untrained params 259 | # spots = spotcheck_init(model) 260 | 261 | # train 262 | trainer.train() 263 | 264 | # spotcheck 265 | # spotcheck(model, spots) 266 | 267 | # evaluate on test dataset 268 | eval_results = trainer.evaluate(ds_dict["test"]) # type: ignore 269 | move_best_ckpt(trainer) 270 | 271 | # save results 272 | with open(results_path, "w") as f: 273 | json.dump(eval_results, f, indent=2) 274 | 275 | # save config 276 | with open(save_dir / "config.json", "w") as f: 277 | cfg["model"] = model_cfg.to_dict() 278 | cfg["train_args"] = train_args.to_dict() 279 | cfg["transfer"] = transfer 280 | cfg["loss"] = cfg["loss"].to_dict() 281 | cfg["probe"] = cfg["probe"].to_dict() 282 | json.dump(cfg, f, indent=2) 283 | wandb.config.update(cfg) 284 | 285 | # save predictions 286 | if predict_dict is not None: 287 | for name, predict_ds in predict_dict.items(): 288 | predict_ds = predict_ds.map(process, batched=True) 289 | print("Gathering predictions for", name) 290 | pred_logits = torch.from_numpy(trainer.predict(predict_ds).predictions) 291 | preds = pred_logits.softmax(-1)[:, 1].cpu().float().numpy() 292 | pred_ds = predict_ds.add_column("soft_pred", preds) 293 | 294 | pred_ds.save_to_disk(str(save_dir / "predictions" / name)) 295 | 296 | wandb.finish() 297 | -------------------------------------------------------------------------------- /w2s/ds_registry.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | from collections import Counter 4 | from dataclasses import dataclass 5 | from random import Random 6 | from typing import Any, Callable, Literal, Union 7 | 8 | from datasets import Dataset as HfDataset 9 | from datasets import DatasetDict as HfDatasetDict 10 | from datasets import concatenate_datasets 11 | from datasets import load_dataset as hf_load_dataset 12 | 13 | 14 | @dataclass 15 | class DatasetConfig: 16 | # split -> unshuffled dataset of items 17 | loader: Callable[[str], HfDataset] 18 | # formats items to have keys 'txt' and 'hard_label', takes a random.Random rng 19 | # (or for generative tasks, 'ctx' and 'target', and no 'hard_label' key) 20 | # deprecated OAI legacy: 21 | # optionally also adds the key 'choices', a pair of strings, indicating to use the 22 | # lm head 23 | formatter: Callable[[Any], Any] 24 | # "classify" or "generate" 25 | task: str = "classify" 26 | 27 | 28 | # mapping from dataset name to load function and format function 29 | _REGISTRY: 
dict[str, DatasetConfig] = {} 30 | 31 | 32 | def register_dataset(name: str, config: DatasetConfig): 33 | _REGISTRY[name] = config 34 | 35 | 36 | def balance(ds: HfDataset, seed: int): 37 | """Undersample balance to 50/50""" 38 | 39 | label_counts = Counter(ds["hard_label"]) 40 | assert len(label_counts) == 2, f"Dataset must be binary {label_counts}" 41 | 42 | # undersample the majority class 43 | majority_label = max(label_counts, key=lambda k: label_counts[k]) 44 | minority_label = 1 - majority_label 45 | minority_count = label_counts[minority_label] 46 | minority_ds = ds.filter(lambda ex: ex["hard_label"] == minority_label) 47 | majority_ds = ( 48 | ds.filter(lambda ex: ex["hard_label"] == majority_label) 49 | .shuffle(seed=seed) 50 | .select(range(minority_count)) 51 | ) 52 | return concatenate_datasets([minority_ds, majority_ds]).shuffle(seed=seed) 53 | 54 | 55 | def load_and_process_train_test( 56 | ds_name: str, 57 | split_sizes: dict[str, int], 58 | seed: int = 0, 59 | take_test_from_train: bool = False, 60 | ): 61 | n_test = split_sizes.get("test", 0) 62 | if take_test_from_train: 63 | # in this case we gather excess documents from the train set, and 64 | # at the end redistribute them to the test set 65 | split_sizes["train"] += n_test 66 | del split_sizes["test"] 67 | 68 | if ds_name not in _REGISTRY: 69 | raise ValueError(f"Unknown dataset {ds_name}, please register") 70 | cfg = _REGISTRY[ds_name] 71 | results = {} 72 | for split, n_docs in split_sizes.items(): 73 | ds = cfg.loader(split).shuffle(seed=seed) 74 | ds = ds.map(functools.partial(cfg.formatter, rng=Random(seed))) # type: ignore 75 | 76 | if cfg.task == "generate": 77 | ds = ds.filter(lambda ex: ex["ctx"] != "") # remove empty texts 78 | ds = ds.filter(lambda ex: ex["target"] != "") 79 | else: 80 | ds = ds.filter(lambda ex: ex["txt"] != "") # remove empty texts 81 | ds = balance(ds, seed) # balance to 50/50 82 | 83 | try: 84 | ds = ds.select(range(n_docs)) 85 | except IndexError: 86 | print(f"{ds_name} has < {n_docs} docs after balancing, using all {len(ds)}") 87 | 88 | if cfg.task == "generate": 89 | ds = ds.map( 90 | lambda ex: { 91 | "id": hashlib.sha1(ex["ctx"].encode()).hexdigest()[:8], 92 | } 93 | ) 94 | else: 95 | ds = ds.map( 96 | lambda ex: { 97 | "id": hashlib.sha1(ex["txt"].encode()).hexdigest()[:8], 98 | "soft_label": [ 99 | 1 - float(ex["hard_label"]), 100 | float(ex["hard_label"]), 101 | ], 102 | } 103 | ) 104 | results[split] = ds 105 | 106 | if take_test_from_train: 107 | # take the first n_test examples from the training set as the test set 108 | results["test"] = results["train"].select(range(n_test)) 109 | results["train"] = results["train"].select(range(n_test, len(results["train"]))) 110 | return results 111 | 112 | 113 | def load_and_process_dataset( 114 | ds_name: str, 115 | n_train: int, 116 | n_val: int, 117 | n_test: int, 118 | n_predict: Union[Literal["train"], int], 119 | take_test_from_train: bool = False, 120 | seed=0, 121 | ) -> HfDatasetDict: 122 | """ 123 | Returns a dict with keys 'train', 'val', 'test', and optionally 'predict', and dataset values 124 | Examples in 'test' split can never appear in 'train', 'val', or 'predict' on any run. 
125 | """ 126 | split_sizes = dict(train=n_train + n_val, test=n_test) 127 | if n_predict != "train": 128 | assert n_predict >= 0 129 | split_sizes["train"] += n_predict 130 | 131 | results = load_and_process_train_test( 132 | ds_name, split_sizes, seed, take_test_from_train 133 | ) 134 | 135 | splits = dict( 136 | val=results["train"].select(range(n_val)), 137 | train=results["train"].select(range(n_val, len(results["train"]))), 138 | test=results["test"], 139 | ) 140 | if n_predict == "train": 141 | # simply use the training set for predictions 142 | splits["predict"] = splits["train"] 143 | elif n_predict > 0: 144 | # take the requested *fraction* of examples from the training set 145 | subsplits = splits["train"].train_test_split( 146 | test_size=n_predict / (n_train + n_predict) 147 | ) 148 | splits["train"], splits["predict"] = subsplits["train"], subsplits["test"] 149 | 150 | return HfDatasetDict(splits) 151 | 152 | 153 | warned_about_choices = set() 154 | 155 | 156 | def encode_choice(text, tokenizer): 157 | global warned_about_choices 158 | 159 | c_ids = tokenizer.encode(text, add_special_tokens=False) 160 | 161 | # some tokenizers split off the leading whitespace character 162 | if tokenizer.decode(c_ids[0]).strip() == "": 163 | c_ids = c_ids[1:] 164 | assert c_ids == tokenizer.encode(text.lstrip(), add_special_tokens=False) 165 | 166 | c_ids = tuple(c_ids) 167 | if len(c_ids) != 1 and c_ids not in warned_about_choices: 168 | assert c_ids[0] not in [ 169 | c[0] for c in warned_about_choices 170 | ], "Choice shares first token with another choice" 171 | warned_about_choices.add(c_ids) 172 | print( 173 | f'Warning: Only the first token of multitoken choice "{text}" will be used' 174 | ) 175 | return c_ids[0] 176 | 177 | 178 | def hf_loader(*hf_name, split_names=None, n_test=None): 179 | """ 180 | If `split_names` is provided, it maps from the requested 181 | split name to the actual name in the hugginface dataset. 182 | If `n_test` is provided, it will concatenate all splits together 183 | and then take a deterministic test set of size `n_test` from it. 184 | """ 185 | 186 | # this thunk avoids loading datasets at import time 187 | def thunk(split): 188 | nonlocal split_names 189 | if n_test is not None: 190 | assert split_names is None 191 | ds = hf_load_dataset(*hf_name) 192 | if isinstance(ds, HfDatasetDict): 193 | ds = concatenate_datasets(ds.values()) # type: ignore 194 | assert isinstance(ds, HfDataset) 195 | # the seed is fixed so that all runs use the same test pool 196 | splits = ds.train_test_split(test_size=n_test, seed=0) 197 | 198 | return splits[split] 199 | 200 | if split_names is None: 201 | split_names = dict() 202 | 203 | return hf_load_dataset(*hf_name, split=split_names.get(split, split)) 204 | 205 | return thunk 206 | 207 | 208 | ########## 209 | # ACTUAL DATASETS 210 | ########## 211 | 212 | 213 | def format_anli(ex, rng): 214 | txt = ( 215 | f"Premise: {ex['premise']}\nHypothesis: {ex['hypothesis']}\n\nDoes the premise" 216 | " entail the hypothesis?" 
217 | ) 218 | return dict(txt=txt, hard_label=ex["label"] == 0) 219 | 220 | 221 | register_dataset( 222 | "anli-r2", 223 | DatasetConfig( 224 | loader=hf_loader( 225 | "facebook/anli", split_names=dict(train="train_r2", test="test_r2") 226 | ), # type: ignore 227 | formatter=format_anli, # type: ignore 228 | ), 229 | ) 230 | 231 | 232 | def format_cola(ex, rng): 233 | return dict(txt=ex["sentence"], hard_label=ex["label"]) 234 | 235 | 236 | register_dataset( 237 | "cola", 238 | DatasetConfig( 239 | loader=hf_loader( 240 | "nyu-mll/glue", "cola", split_names=dict(test="validation") 241 | ), # type: ignore 242 | formatter=format_cola, # type: ignore 243 | ), 244 | ) 245 | 246 | 247 | def format_dream(ex, rng): 248 | hard_label = int(rng.random() < 0.5) 249 | if hard_label: 250 | ans = ex["answer"] 251 | else: 252 | distractors = ex["choice"].copy() 253 | distractors.remove(ex["answer"]) 254 | 255 | ans = rng.choice(distractors) 256 | 257 | joined = "\n".join(ex["dialogue"]) 258 | txt = f"{joined}\n\nQ: {ex['question']} A: {ans}" 259 | return dict(txt=txt, hard_label=hard_label) 260 | 261 | 262 | register_dataset( 263 | "dream", 264 | DatasetConfig( 265 | loader=hf_loader("dream"), # type: ignore 266 | formatter=format_dream, # type: ignore 267 | ), 268 | ) 269 | 270 | 271 | def format_ethics_deontology(ex, rng): 272 | "yes" if ex["label"] else "no" 273 | txt = ( 274 | "Q: Is the following a morally valid" 275 | f" excuse?\n\n{ex['scenario']}\n\n{ex['excuse']}\n\nA:" 276 | ) 277 | return dict(txt=txt, hard_label=ex["label"]) 278 | 279 | 280 | register_dataset( 281 | "ethics-deontology", 282 | DatasetConfig( 283 | loader=hf_loader("hendrycks/ethics", "deontology"), # type: ignore 284 | formatter=format_ethics_deontology, # type: ignore 285 | ), 286 | ) 287 | 288 | 289 | def format_ethics_justice(ex, rng): 290 | txt = f"Q: Does this statement exemplify justice?\n\n{ex['scenario']}\n\nA:" 291 | return dict(txt=txt, hard_label=ex["label"]) 292 | 293 | 294 | register_dataset( 295 | "ethics-justice", 296 | DatasetConfig( 297 | loader=hf_loader("hendrycks/ethics", "justice"), # type: ignore 298 | formatter=format_ethics_justice, # type: ignore 299 | ), 300 | ) 301 | 302 | 303 | def format_ethics_virtue(ex, rng): 304 | txt = ( 305 | "Q: Does this behavior match the adjective that" 306 | f" follows?\n\n{ex['scenario']}\n\nA:" 307 | ) 308 | return dict(txt=txt, hard_label=ex["label"]) 309 | 310 | 311 | register_dataset( 312 | "ethics-virtue", 313 | DatasetConfig( 314 | loader=hf_loader("hendrycks/ethics", "virtue"), # type: ignore 315 | formatter=format_ethics_virtue, # type: ignore 316 | ), 317 | ) 318 | 319 | 320 | def format_ethics_utilitarianism(ex, rng): 321 | hard_label = int(rng.random() < 0.5) 322 | 323 | choices = [ex["baseline"], ex["less_pleasant"]] 324 | rng.shuffle(choices) 325 | 326 | correct = choices.index(ex["baseline"]) 327 | response = correct if hard_label else 1 - correct 328 | 329 | txt = f"Which is more pleasant?\n1) {choices[0]}\n2) {choices[1]} A: {response + 1}" 330 | return dict(txt=txt, hard_label=hard_label) 331 | 332 | 333 | register_dataset( 334 | "ethics-utilitarianism", 335 | DatasetConfig( 336 | loader=hf_loader("hendrycks/ethics", "utilitarianism"), # type: ignore 337 | formatter=format_ethics_utilitarianism, # type: ignore 338 | ), 339 | ) 340 | 341 | 342 | def format_mc_taco(ex, rng): 343 | template = "{sentence}\n\nGiven the above, {question} Is the answer {answer}?" 
344 | return dict(txt=template.format(**ex), hard_label=ex["label"]) 345 | 346 | 347 | register_dataset( 348 | "mc_taco", 349 | DatasetConfig( # we switch train and test bc test is bigger 350 | loader=hf_loader( # type: ignore 351 | "mc_taco", split_names=dict(train="test", test="validation") 352 | ), 353 | formatter=format_mc_taco, # type: ignore 354 | ), 355 | ) 356 | 357 | 358 | def format_hellaswag(ex, rng): 359 | hard_label = int(rng.random() < 0.5) 360 | if hard_label: 361 | ans = ex["endings"][int(ex["label"])] 362 | else: 363 | ans = rng.choice( 364 | [e for i, e in enumerate(ex["endings"]) if i != int(ex["label"])] 365 | ) 366 | 367 | endings = "\n".join(ex["endings"]) 368 | txt = ( 369 | f'Context:\n{ex["ctx"]}\n\nContinuations:\n\n{endings}\n\nQ: Is "{ans}" the' 370 | " best continuation?" 371 | ) 372 | return dict(txt=txt, hard_label=hard_label) 373 | 374 | 375 | register_dataset( 376 | "hellaswag", 377 | DatasetConfig( 378 | loader=hf_loader( 379 | "Rowan/hellaswag", split_names=dict(test="validation") 380 | ), # type: ignore 381 | formatter=format_hellaswag, # type: ignore 382 | ), 383 | ) 384 | 385 | 386 | def format_multirc(ex, rng): 387 | template = 'Passage:\n\n{paragraph}\n\nQ: "{question}" Is the answer "{answer}"?' 388 | 389 | txt = template.format(**ex) 390 | return dict(txt=txt, hard_label=ex["label"]) 391 | 392 | 393 | register_dataset( 394 | "multirc", 395 | DatasetConfig( 396 | loader=hf_loader( 397 | "super_glue", "multirc", split_names=dict(test="validation") 398 | ), # type: ignore 399 | formatter=format_multirc, # type: ignore 400 | ), 401 | ) 402 | 403 | 404 | def format_openbookqa(ex, rng): 405 | hard_label = int(rng.random() < 0.5) 406 | if hard_label: 407 | ans = ex["answerKey"] 408 | else: 409 | letters = ex["choices"]["label"] 410 | 411 | distractors = ex["choices"]["text"].copy() 412 | del distractors[letters.index(ex["answerKey"])] 413 | ans = rng.choice(distractors) 414 | 415 | choices = [ 416 | f"{a}) {t}" for a, t in zip(ex["choices"]["label"], ex["choices"]["text"]) 417 | ] 418 | joined = "\n".join(choices) 419 | txt = f"Q: {ex['question_stem']}\n\nChoices:\n{joined}\n\nAnswer: {ans}" 420 | return dict(txt=txt, hard_label=hard_label) 421 | 422 | 423 | register_dataset( 424 | "openbookqa", 425 | DatasetConfig( 426 | loader=hf_loader("allenai/openbookqa"), # type: ignore 427 | formatter=format_openbookqa, # type: ignore 428 | ), 429 | ) 430 | 431 | 432 | def format_paws(ex, rng): 433 | template = ( 434 | "Sent 1: {sentence1}\nSent 2: {sentence2}\n\nQ: Are these sentences" 435 | " semantically equivalent?" 436 | ) 437 | return dict(txt=template.format(**ex), hard_label=ex["label"]) 438 | 439 | 440 | register_dataset( 441 | "paws", 442 | DatasetConfig( 443 | loader=hf_loader("paws", "labeled_final"), # type: ignore 444 | formatter=format_paws, # type: ignore 445 | ), 446 | ) 447 | 448 | 449 | def format_piqa(ex, rng): 450 | hard_label = int(rng.random() < 0.5) 451 | 452 | if hard_label: 453 | ans = ex["sol2"] if ex["label"] else ex["sol1"] 454 | else: 455 | ans = ex["sol1"] if ex["label"] else ex["sol2"] 456 | 457 | txt = f"Q: {ex['goal']} A: {ans}" 458 | return dict(txt=txt, hard_label=hard_label) 459 | 460 | 461 | register_dataset( 462 | "piqa", 463 | DatasetConfig( 464 | loader=hf_loader("piqa", split_names=dict(test="validation")), # type: ignore 465 | formatter=format_piqa, # type: ignore 466 | ), 467 | ) 468 | 469 | 470 | def format_quail(ex, rng): 471 | template = 'Passage:\n\n{context}\n\nQ: "{question}" Is the answer "{answer}"?' 
472 | hard_label = int(rng.random() < 0.5) 473 | 474 | correct_id = ex["correct_answer_id"] 475 | if hard_label: 476 | ans = ex["answers"][correct_id] 477 | else: 478 | ans = rng.choice([a for i, a in enumerate(ex["answers"]) if i != correct_id]) 479 | 480 | txt = template.format(**ex, answer=ans) 481 | return dict(txt=txt, hard_label=hard_label) 482 | 483 | 484 | register_dataset( 485 | "quail", 486 | DatasetConfig( 487 | loader=hf_loader("quail", split_names=dict(test="validation")), # type: ignore 488 | formatter=format_quail, # type: ignore 489 | ), 490 | ) 491 | 492 | 493 | def format_quartz(ex, rng): 494 | template = 'Passage:\n{para}\n\nQ: "{question}" Is the answer "{answer}"?' 495 | hard_label = int(rng.random() < 0.5) 496 | 497 | correct_id = ex["choices"]["label"].index(ex["answerKey"]) 498 | ans = ex["choices"]["text"][correct_id if hard_label else 1 - correct_id] 499 | 500 | txt = template.format(**ex, answer=ans) 501 | return dict(txt=txt, hard_label=hard_label) 502 | 503 | 504 | register_dataset( 505 | "quartz", 506 | DatasetConfig( 507 | loader=hf_loader("allenai/quartz"), # type: ignore 508 | formatter=format_quartz, # type: ignore 509 | ), 510 | ) 511 | 512 | 513 | def format_social_i_qa(ex, rng): 514 | template = ( 515 | "Context:\n{context}\n\nQuestion:" 516 | ' "{question}"\n\nChoices:\n{answerA}\n{answerB}\n{answerC}\n\nIs the answer' 517 | ' "{answer}"?' 518 | ) 519 | hard_label = int(rng.random() < 0.5) 520 | 521 | answers = [ex["answerA"], ex["answerB"], ex["answerC"]] 522 | correct_id = int(ex["label"]) - 1 523 | if hard_label: 524 | ans = answers[correct_id] 525 | else: 526 | ans = rng.choice([a for i, a in enumerate(answers) if i != correct_id]) 527 | 528 | txt = template.format(**ex, answer=ans) 529 | return dict(txt=txt, hard_label=hard_label) 530 | 531 | 532 | register_dataset( 533 | "social_i_qa", 534 | DatasetConfig( 535 | loader=hf_loader( 536 | "social_i_qa", split_names=dict(test="validation") 537 | ), # type: ignore 538 | formatter=format_social_i_qa, # type: ignore 539 | ), 540 | ) 541 | 542 | 543 | def format_sst2(ex, rng): 544 | return dict(txt=ex["sentence"], hard_label=ex["label"]) 545 | 546 | 547 | register_dataset( 548 | "sst2", 549 | DatasetConfig( 550 | loader=hf_loader( 551 | "stanfordnlp/sst2", split_names=dict(test="validation") 552 | ), # type: ignore 553 | formatter=format_sst2, # type: ignore 554 | ), 555 | ) 556 | 557 | 558 | def format_wic(ex, rng): 559 | template = ( 560 | 'Sentence 1:\n{sentence1}\n\nSentence 2:\n{sentence2}\n\nQ: Does "{word}" have' 561 | " the same meaning in the above sentences?" 
562 | ) 563 | return dict(txt=template.format(**ex), hard_label=ex["label"]) 564 | 565 | 566 | register_dataset( 567 | "wic", 568 | DatasetConfig( 569 | loader=hf_loader( 570 | "super_glue", "wic", split_names=dict(test="validation") 571 | ), # type: ignore 572 | formatter=format_wic, # type: ignore 573 | ), 574 | ) 575 | 576 | 577 | def format_twitter_sentiment(ex, rng): 578 | return dict(txt=ex["text"], hard_label=ex["label"]) 579 | 580 | 581 | register_dataset( 582 | "twitter-sentiment", 583 | DatasetConfig( 584 | loader=hf_loader("EleutherAI/twitter-sentiment"), # type: ignore 585 | formatter=format_twitter_sentiment, # type: ignore 586 | ), 587 | ) 588 | 589 | 590 | SCIQ_N_TEST = 3000 591 | 592 | 593 | def format_sciq(ex, rng): 594 | hard_label = int(rng.random() < 0.5) 595 | if hard_label: 596 | ans = ex["correct_answer"] 597 | else: 598 | ans = rng.choice([ex["distractor1"], ex["distractor2"], ex["distractor3"]]) 599 | 600 | txt = f"Q: {ex['question']} A: {ans}" 601 | return dict(txt=txt, hard_label=hard_label) 602 | 603 | 604 | register_dataset( 605 | "sciq", 606 | DatasetConfig( 607 | loader=hf_loader("sciq", n_test=SCIQ_N_TEST), # type: ignore 608 | formatter=format_sciq, # type: ignore 609 | ), 610 | ) 611 | 612 | 613 | def format_sciq_with_support(ex, rng): 614 | # from https://github.com/EleutherAI/elk-generalization 615 | template = ( 616 | 'Name: Bob\n\nPassage 1:\n{support}\n\nQ1: "{question}" Is the answer' 617 | ' "{answer}"?' 618 | ) 619 | hard_label = int(rng.random() < 0.5) 620 | if hard_label: 621 | ans = ex["correct_answer"] 622 | else: 623 | ans = rng.choice([ex["distractor1"], ex["distractor2"], ex["distractor3"]]) 624 | txt = template.format(support=ex["support"], question=ex["question"], answer=ans) 625 | return dict(txt=txt, hard_label=hard_label) 626 | 627 | 628 | register_dataset( 629 | "sciq_with_support", 630 | DatasetConfig( 631 | loader=hf_loader("sciq", n_test=SCIQ_N_TEST), # type: ignore 632 | formatter=format_sciq_with_support, # type: ignore 633 | ), 634 | ) 635 | 636 | 637 | def format_anthropic_hh(ex, rng): 638 | hard_label = int(rng.random() < 0.5) 639 | txt = ex["chosen"] if hard_label else ex["rejected"] 640 | return dict(txt=txt, hard_label=hard_label) 641 | 642 | 643 | register_dataset( 644 | "anthropic_hh", 645 | DatasetConfig( 646 | loader=hf_loader("Anthropic/hh-rlhf"), # type: ignore 647 | formatter=format_anthropic_hh, # type: ignore 648 | ), 649 | ) 650 | 651 | 652 | def format_cosmosqa(ex, rng): 653 | true_answer = ex["answer" + str(ex["label"])] 654 | if "None of the above choices ." 
in true_answer:
655 |         hard_label = 0
656 |     else:
657 |         assert "None of the above choices" not in true_answer, true_answer
658 |         hard_label = int(rng.random() < 0.5)
659 |     if hard_label:
660 |         answer = true_answer
661 |     else:
662 |         candidate_answers = [ex["answer" + str(i)] for i in range(4)]
663 |         answer = rng.choice([x for x in candidate_answers if x != true_answer])
664 |     txt = f"Context: {ex['context']}\nQuestion: {ex['question']}\nAnswer: {answer}"
665 |     return dict(txt=txt, hard_label=hard_label)
666 | 
667 | 
668 | register_dataset(
669 |     "cosmos_qa",
670 |     DatasetConfig(
671 |         loader=hf_loader(
672 |             "cosmos_qa", split_names=dict(test="validation")
673 |         ),  # type: ignore
674 |         formatter=format_cosmosqa,  # type: ignore
675 |     ),
676 | )
677 | 
678 | 
679 | def format_boolq(ex, rng):
680 |     hard_label = int(ex["answer"])
681 |     txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}"
682 |     return dict(txt=txt, hard_label=hard_label)
683 | 
684 | 
685 | register_dataset(
686 |     "boolq",
687 |     DatasetConfig(
688 |         loader=hf_loader("boolq", split_names=dict(test="validation")),  # type: ignore
689 |         formatter=format_boolq,  # type: ignore
690 |     ),
691 | )
692 | 
693 | 
694 | def format_amazon_polarity(ex, rng):
695 |     return dict(txt=f"{ex['title']} {ex['content']}", hard_label=ex["label"])
696 | 
697 | 
698 | register_dataset(
699 |     "amazon_polarity",
700 |     DatasetConfig(
701 |         loader=hf_loader("amazon_polarity"),  # type: ignore
702 |         formatter=format_amazon_polarity,  # type: ignore
703 |     ),
704 | )
705 | 
706 | 
707 | def format_underspecified_amazon_polarity(ex, rng, use_gt=True):
708 |     txt = f"{ex['title']} {ex['content']}\n\nDoes this review say the product was \"good\" or \"great\"?"  # noqa
709 |     words = {"good", "great"}
710 |     label = (
711 |         ex["label"]
712 |         if use_gt
713 |         else any(w in ex["content"].lower() or w in ex["title"].lower() for w in words)
714 |     )
715 |     return dict(txt=txt, hard_label=label)
716 | 
717 | 
718 | # register_dataset(
719 | #     "amazon_polarity_gt",
720 | #     DatasetConfig(
721 | #         loader=hf_loader("amazon_polarity"),  # type: ignore
722 | #         formatter=functools.partial(format_underspecified_amazon_polarity, use_gt=True),  # type: ignore # noqa
723 | #     ),
724 | # )
725 | 
726 | # register_dataset(
727 | #     "amazon_polarity_weak",
728 | #     DatasetConfig(
729 | #         loader=hf_loader("amazon_polarity"),  # type: ignore
730 | #         formatter=functools.partial(format_underspecified_amazon_polarity, use_gt=False),  # type: ignore # noqa
731 | #     ),
732 | # )
733 | 
734 | 
735 | VALID_DATASETS: list[str] = list(_REGISTRY.keys())
736 | 
737 | 
738 | """
739 | from datasets import disable_caching
740 | disable_caching()
741 | 
742 | from w2s.ds_registry import load_and_process_dataset, VALID_DATASETS
743 | import numpy as np
744 | 
745 | ds_name = "boolq"
746 | print(VALID_DATASETS)
747 | 
748 | ds = load_and_process_dataset(ds_name, n_train=500, n_val=50, n_test=10, n_predict=0)
749 | train = list(ds['train'])
750 | test = list(ds['test'])
751 | print(test[0])
752 | print(np.mean([x['hard_label'] for x in train]))
753 | """
754 | 
--------------------------------------------------------------------------------
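For reference, extending the registry above only requires a loader and a formatter that returns "txt" and a binary "hard_label", exactly as the built-in entries do. A minimal, untested sketch follows; the "imdb" dataset name and its "text"/"label" columns are illustrative assumptions, not part of this repo:

    def format_imdb(ex, rng):
        # formatters receive the raw example plus an rng; this task does not need the rng
        return dict(txt=ex["text"], hard_label=ex["label"])


    register_dataset(
        "imdb",
        DatasetConfig(
            loader=hf_loader("imdb"),  # assumed HF dataset with train/test splits
            formatter=format_imdb,
        ),
    )

If the registration lives in ds_registry.py above the VALID_DATASETS line, the new name is picked up there and can be passed to load_and_process_dataset like any built-in entry.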