├── repeng ├── __init__.py ├── datasets │ ├── __init__.py │ ├── data │ │ └── __init__.py │ ├── elk │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── limits.py │ │ │ ├── collections.py │ │ │ ├── fns.py │ │ │ └── filters.py │ │ ├── truthful_model_written.py │ │ ├── truthful_qa.py │ │ ├── true_false.py │ │ ├── open_book_qa.py │ │ ├── common_sense_qa.py │ │ ├── types.py │ │ ├── race.py │ │ ├── arc.py │ │ ├── geometry_of_truth.py │ │ └── dlk.py │ ├── utils │ │ ├── __init__.py │ │ ├── shuffles.py │ │ └── splits.py │ ├── activations │ │ ├── __init__.py │ │ ├── types.py │ │ └── creation.py │ └── modelwritten │ │ ├── __init__.py │ │ ├── personas.py │ │ ├── generation.py │ │ └── filtering.py ├── evals │ ├── __init__.py │ ├── types.py │ ├── probes.py │ └── logits.py ├── hooks │ ├── __init__.py │ ├── patch.py │ ├── points.py │ └── grab.py ├── models │ ├── __init__.py │ ├── loading.py │ ├── types.py │ ├── points.py │ └── llms.py ├── probes │ ├── __init__.py │ ├── normalization.py │ ├── random.py │ ├── difference_in_means.py │ ├── base.py │ ├── linear_discriminant_analysis.py │ ├── principal_component_analysis.py │ ├── linear_artificial_tomography.py │ ├── collections.py │ ├── logistic_regression.py │ └── contrast_consistent_search.py ├── utils │ ├── __init__.py │ └── pydantic_ndarray.py └── activations │ ├── __init__.py │ ├── inference.py │ └── probe_preparations.py ├── .gitignore ├── experiments ├── scratch │ ├── README.md │ ├── ccs_repro.py │ ├── modeltest.py │ ├── fp32_to_fp16.py │ ├── large_models.py │ ├── perturb.py │ ├── lr_speed.py │ ├── test_f16.py │ ├── fake_probes.py │ ├── zeroshot_investigation.py │ ├── lat_test.py │ ├── repe_repro.py │ ├── got_test.py │ ├── repe_performance_check.py │ └── probe_test.py ├── saliency_dataset.py ├── comparison_dataset.py ├── dataset_analysis.py ├── truthful_model_written_evals.py └── saliency.py ├── .github └── workflows │ └── remote.yaml ├── .pre-commit-config.yaml ├── pyproject.toml └── README.md /repeng/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/evals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/probes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/activations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
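A minimal sketch of how the `repeng.probes` and `repeng.evals` packages in the directory tree above fit together. This is illustrative only and not a file from the repository: synthetic arrays stand in for the LLM activations that `repeng.models` and `repeng.activations` would normally produce.

```python
import numpy as np

from repeng.evals.probes import eval_probe_by_row
from repeng.probes.difference_in_means import train_dim_probe

# Synthetic stand-ins: 200 statements, hidden dim 64, boolean truth labels.
activations = np.random.randn(200, 64)
labels = np.random.rand(200) > 0.5

# Train a difference-in-means probe and evaluate it row-by-row.
probe = train_dim_probe(activations=activations, labels=labels)
result = eval_probe_by_row(probe, activations=activations, labels=labels)
print(result.accuracy, result.roc_auc_score)
```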
/repeng/datasets/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/elk/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/elk/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/activations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repeng/datasets/modelwritten/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | output 3 | .env 4 | .vscode 5 | -------------------------------------------------------------------------------- /repeng/datasets/modelwritten/personas.py: -------------------------------------------------------------------------------- 1 | HONEST = "always tells the truth" 2 | -------------------------------------------------------------------------------- /experiments/scratch/README.md: -------------------------------------------------------------------------------- 1 | # Scratch experiments 2 | 3 | Very very rough notebooks for tinkering. 4 | -------------------------------------------------------------------------------- /.github/workflows/remote.yaml: -------------------------------------------------------------------------------- 1 | name: remote 2 | on: 3 | push: 4 | branches: 5 | - 'main' 6 | paths: 7 | - 'poetry.lock' 8 | jobs: 9 | remote: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: mishajw/remote@main 13 | with: 14 | dockerhub_repo: 'ccs' 15 | dockerhub_username: ${{ secrets.DOCKERHUB_USERNAME }} 16 | dockerhub_password: ${{ secrets.DOCKERHUB_PASSWORD }} 17 | -------------------------------------------------------------------------------- /repeng/probes/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from jaxtyping import Float, Int64 3 | 4 | 5 | def normalize_by_group( 6 | activations: Float[np.ndarray, "n d"], # noqa: F722 7 | group: Int64[np.ndarray, "n"], # noqa: F821 8 | ) -> Float[np.ndarray, "n d"]: # noqa: F722 9 | result = activations.copy() 10 | for g in np.unique(group): 11 | result[group == g] -= result[group == g].mean(axis=0) 12 | return result 13 | -------------------------------------------------------------------------------- /repeng/probes/random.py: -------------------------------------------------------------------------------- 1 | """ 2 | Baseline random probe. 
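Used as a sanity check: probes trained on real activations should comfortably beat a randomly drawn direction.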
3 | """ 4 | 5 | import numpy as np 6 | from jaxtyping import Float 7 | 8 | from repeng.probes.base import DotProductProbe 9 | 10 | 11 | def train_random_probe( 12 | *, 13 | activations: Float[np.ndarray, "n d"], # noqa: F722 14 | ) -> DotProductProbe: 15 | _, hidden_dim = activations.shape 16 | probe = np.random.uniform(-1, 1, size=hidden_dim) 17 | probe /= np.linalg.norm(probe) 18 | return DotProductProbe(probe=probe) 19 | -------------------------------------------------------------------------------- /repeng/datasets/utils/shuffles.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from typing import Any, Callable, Iterable, TypeVar 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | def deterministic_shuffle(data: Iterable[T], key: Callable[[T], str]) -> list[T]: 8 | return sorted( 9 | data, 10 | key=lambda t: deterministic_shuffle_sort_fn(key(t), None), 11 | ) 12 | 13 | 14 | def deterministic_shuffle_sort_fn(key: str, _: Any) -> int: 15 | hash = hashlib.sha256(key.encode("utf-8")) 16 | return int(hash.hexdigest(), 16) 17 | -------------------------------------------------------------------------------- /experiments/scratch/ccs_repro.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | 4 | import torch 5 | from mppr import MContext 6 | 7 | from repeng.datasets.elk.utils.collections import get_datasets 8 | from repeng.models.llms import get_llm 9 | 10 | # %% 11 | llm = get_llm("gpt2", device=torch.device("cuda"), dtype=torch.bfloat16) 12 | 13 | # %% 14 | mcontext = MContext(Path("../output/ccs_repro")) 15 | dataset = mcontext.create_cached( 16 | "dataset", 17 | lambda: get_datasets(["imdb"]), 18 | to="pickle", 19 | ) 20 | 21 | # %% 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - repo: https://github.com/psf/black 8 | rev: 23.3.0 9 | hooks: 10 | - id: black 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: v0.0.275 13 | hooks: 14 | - id: ruff 15 | args: [--fix] 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | -------------------------------------------------------------------------------- /experiments/scratch/modeltest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dotenv import load_dotenv 3 | 4 | from repeng.hooks.grab import grab 5 | from repeng.models.loading import load_llm_oioo 6 | 7 | assert load_dotenv() 8 | 9 | device = torch.device("cuda") 10 | llm = load_llm_oioo( 11 | "gemma-2b", 12 | device=device, 13 | use_half_precision=True, 14 | ) 15 | with grab(llm.model, llm.points[-3]) as grab_fn: 16 | tokens = llm.tokenizer.encode("Hello, world!", return_tensors="pt").to( 17 | device=device 18 | ) 19 | print(llm.model(tokens)) 20 | print(grab_fn()) 21 | -------------------------------------------------------------------------------- /repeng/datasets/activations/types.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from repeng.datasets.elk.types import DatasetId, Split 4 | from repeng.models.llms import LlmId 5 | from repeng.utils.pydantic_ndarray import NdArray 6 | 
7 | 8 | class ActivationResultRow(BaseModel, extra="forbid"): 9 | dataset_id: DatasetId 10 | group_id: str | None 11 | answer_type: str | None 12 | activations: dict[str, NdArray] # (s, d) 13 | prompt_logprobs: float 14 | label: bool 15 | split: Split 16 | llm_id: LlmId 17 | 18 | class Config: 19 | arbitrary_types_allowed = True 20 | -------------------------------------------------------------------------------- /repeng/evals/types.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class QuestionsEvalResult(BaseModel, extra="forbid"): 5 | accuracy: float 6 | is_flipped: bool 7 | n: int 8 | """ 9 | Some probes are trained with unsupervised methods, thus the probe is predicting the 10 | inverse of what we expect. In these cases, we flip the logits. 11 | """ 12 | 13 | 14 | class RowsEvalResult(BaseModel, extra="forbid"): 15 | f1_score: float 16 | precision: float 17 | recall: float 18 | roc_auc_score: float 19 | accuracy: float 20 | predicted_true: float 21 | fprs: list[float] 22 | tprs: list[float] 23 | logits: list[float] 24 | is_flipped: bool 25 | n: int 26 | -------------------------------------------------------------------------------- /repeng/utils/pydantic_ndarray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pydantic import BeforeValidator, PlainSerializer 3 | from typing_extensions import Annotated 4 | 5 | 6 | def _serialize(array: np.ndarray) -> dict: 7 | return dict( 8 | dtype=str(array.dtype), 9 | array=array.tolist(), 10 | ) 11 | 12 | 13 | def _deserialize(obj: dict | np.ndarray) -> np.ndarray: 14 | if isinstance(obj, np.ndarray): 15 | return obj 16 | assert obj.keys() == {"dtype", "array"} 17 | return np.array( 18 | obj["array"], 19 | dtype=obj["dtype"], 20 | ) 21 | 22 | 23 | NdArray = Annotated[ 24 | np.ndarray, 25 | BeforeValidator(_deserialize), 26 | PlainSerializer(_serialize), 27 | ] 28 | -------------------------------------------------------------------------------- /repeng/hooks/patch.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Callable, Generator, TypeVar 3 | 4 | import torch 5 | 6 | from repeng.hooks.points import Point 7 | 8 | ModelT = TypeVar("ModelT", bound=torch.nn.Module) 9 | 10 | 11 | @contextmanager 12 | def patch( 13 | model: ModelT, 14 | point: Point[ModelT], 15 | fn: Callable[[torch.Tensor], torch.Tensor], 16 | ) -> Generator[None, None, None]: 17 | hook_handle = None 18 | try: 19 | hook_handle = point.module_fn(model).register_forward_hook( 20 | lambda _module, _input, output: point.tensor_extractor.insert( 21 | output, fn(point.tensor_extractor.extract(output)) 22 | ) 23 | ) 24 | yield None 25 | finally: 26 | if hook_handle is not None: 27 | hook_handle.remove() 28 | -------------------------------------------------------------------------------- /repeng/probes/difference_in_means.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replication of difference-in-means probes. 3 | 4 | See DIM probes in and MMP described in 5 | . 6 | 7 | Methodology: 8 | 1. Given a set of activations, and whether they respond to true or false statements. 9 | 2. Compute the difference in means between the true and false activations. This is the 10 | probe. 11 | 12 | Regularization: None. 
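Example (illustrative only; synthetic arrays in place of model activations):

    >>> import numpy as np
    >>> activations = np.random.randn(100, 8)
    >>> labels = np.random.rand(100) > 0.5
    >>> probe = train_dim_probe(activations=activations, labels=labels)
    >>> probe.predict(activations).logits.shape
    (100,)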
13 | """ 14 | 15 | import numpy as np 16 | from jaxtyping import Bool, Float 17 | 18 | from repeng.probes.base import DotProductProbe 19 | 20 | 21 | def train_dim_probe( 22 | *, 23 | activations: Float[np.ndarray, "n d"], # noqa: F722 24 | labels: Bool[np.ndarray, "n"], # noqa: F821 25 | ) -> DotProductProbe: 26 | return DotProductProbe( 27 | activations[labels].mean(axis=0) - activations[~labels].mean(axis=0) 28 | ) 29 | -------------------------------------------------------------------------------- /repeng/models/loading.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | 5 | from repeng.models.llms import Llm, LlmId, get_llm 6 | 7 | _loaded_llm_id: LlmId | None = None 8 | _loaded_llm: Llm[Any, Any] | None = None 9 | 10 | 11 | def load_llm_oioo( 12 | llm_id: LlmId, 13 | device: torch.device, 14 | use_half_precision: bool, 15 | ) -> Llm[Any, Any]: 16 | """ 17 | Loads an LLM with a one-in-one-out policy, i.e. only one model is loaded into 18 | memory at a time. 19 | """ 20 | global _loaded_llm_id 21 | global _loaded_llm 22 | if llm_id != _loaded_llm_id: 23 | if _loaded_llm is not None: 24 | _loaded_llm.model = _loaded_llm.model.cpu() 25 | del _loaded_llm 26 | torch.cuda.empty_cache() 27 | print(f"Unloaded LLM {_loaded_llm_id}, loading LLM {llm_id}") 28 | else: 29 | print(f"Loading LLM {llm_id}") 30 | _loaded_llm = get_llm(llm_id, device, use_half_precision=use_half_precision) 31 | _loaded_llm_id = llm_id 32 | assert _loaded_llm is not None 33 | return _loaded_llm 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "repeng" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Misha Wagner "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10, <3.12" 10 | datasets = "^2.16.1" 11 | torch = "^2.1.2" 12 | pytest = "^7.4.4" 13 | scikit-learn = "^1.3.2" 14 | seaborn = "^0.13.1" 15 | mppr = { git = "https://github.com/mishajw/mppr.git", branch = "main" } 16 | transformers = "^4.36.2" 17 | openai = "^1.7.2" 18 | python-dotenv = "^1.0.0" 19 | jsonlines = "^4.0.0" 20 | jaxtyping = "^0.2.25" 21 | fire = "^0.5.0" 22 | accelerate = "^0.26.1" 23 | plotly = "^5.18.0" 24 | promptsource = { git = "https://github.com/bigscience-workshop/promptsource", branch = "main" } 25 | kaleido = "0.2.1" 26 | 27 | [tool.poetry.group.dev.dependencies] 28 | ipykernel = "^6.28.0" 29 | black = "^21.0.0" 30 | pre-commit = "^3.6.0" 31 | jupyter = "^1.0.0" 32 | pyright = "^1.1.347" 33 | viztracer = "^0.16.2" 34 | 35 | [build-system] 36 | requires = ["poetry-core"] 37 | build-backend = "poetry.core.masonry.api" 38 | 39 | [tool.isort] 40 | profile = "black" 41 | 42 | [tool.ruff] 43 | exclude = ['experiments/scratch'] 44 | 45 | [tool.pyright] 46 | exclude = ['experiments/scratch'] 47 | -------------------------------------------------------------------------------- /experiments/saliency_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from repeng.datasets.activations.creation import create_activations_dataset 4 | from repeng.datasets.elk.utils.limits import Limits, SplitLimits 5 | 6 | """ 7 | 4 models 8 | * 4 datasets 9 | * 20 layers 10 | * 1 token 11 | * 4K questions 12 | * 3 answers 13 | * 5120 hidden dim size 14 | * 2 bytes 15 | = 39GB 16 | """ 17 | 18 | create_activations_dataset( 
19 | tag="saliency_2024-02-26_v1", 20 | llm_ids=[ 21 | "Llama-2-7b-hf", 22 | "Llama-2-7b-chat-hf", 23 | "Llama-2-13b-hf", 24 | "Llama-2-13b-chat-hf", 25 | "gemma-2b", 26 | "gemma-2b-it", 27 | "gemma-7b", 28 | "gemma-7b-it", 29 | "Mistral-7B", 30 | "Mistral-7B-Instruct", 31 | ], 32 | dataset_ids=[ 33 | "boolq/simple", 34 | "imdb/simple", 35 | "race/simple", 36 | "got_cities", 37 | ], 38 | group_limits=Limits( 39 | default=SplitLimits( 40 | train=1000, 41 | train_hparams=0, 42 | validation=400, 43 | ), 44 | by_dataset={}, 45 | ), 46 | num_tokens_from_end=1, 47 | device=torch.device("cuda"), 48 | layers_start=1, 49 | layers_end=None, 50 | layers_skip=2, 51 | ) 52 | -------------------------------------------------------------------------------- /experiments/scratch/fp32_to_fp16.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import pickle 3 | from itertools import count 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from repeng.datasets.activations.types import ActivationResultRow 10 | 11 | # %% 12 | path = Path("../../output/comparison/activations_results.pickle") 13 | output_path = Path("../../output/comparison/activations_results_fp16.pickle") 14 | assert path.exists() 15 | 16 | # %% 17 | with output_path.open("wb") as output_f: 18 | with path.open("rb") as input_f: 19 | for _ in tqdm(count()): 20 | try: 21 | result = pickle.load(input_f) 22 | except EOFError: 23 | break 24 | assert isinstance(result, dict) 25 | assert result.keys() == {"key", "value"} 26 | assert isinstance(result["key"], str) 27 | row: ActivationResultRow = result["value"] 28 | row.activations = { 29 | k: v.astype(np.float16) for k, v in row.activations.items() 30 | } 31 | assert set(value.dtype for value in row.activations.values()) == { 32 | np.dtype("float16") 33 | } 34 | pickle.dump(result, output_f) 35 | # yield result["key"], result["value"] 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Representation Engineering 2 | 3 | Experiments with representation engineering. There's been a bunch of recent work ([1](https://arxiv.org/abs/2310.01405), [2](https://arxiv.org/abs/2308.10248), [3](https://arxiv.org/abs/2212.03827)) into using a neural network's latent representations to control & interpret models. 4 | 5 | This repository contains utilities for running experiments (the `repeng` package) and a bunch of experiments (the notebooks in `experiments`). 6 | 7 | ## Installation 8 | ```bash 9 | git clone https://github.com/mishajw/repeng 10 | cd repeng 11 | pip install -e . 12 | # Or if using poetry: 13 | poetry install 14 | ``` 15 | 16 | ## Reproducing experiments 17 | 18 | ### How well do truth probes generalise? 19 | [Report](https://docs.google.com/document/d/1tz-JulAUz3SOc8Qm8MLwE9TohX8gSZlXQ2Y4PBwfJ1U). 20 | 21 | 1. Install the repository, as described [above](#installation). 22 | 2. Optional: Check out `c99e9aa`. This shouldn't be necessary, unless I introduce breaking changes. 23 | 3. Create a dataset of activations: `python experiments/comparison_dataset.py`. 24 | - This will upload the experiments to S3. Some tinkering may be required to change the upload location - sorry about that! 25 | 4. Run the analysis: `python experiments/comparison.py`. 26 | - This will write plots to `./output/comparison`. 
27 | 28 | This is split into two scripts as only the first requires a GPU for LLM inference. 29 | -------------------------------------------------------------------------------- /experiments/comparison_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from repeng.datasets.activations.creation import create_activations_dataset 4 | from repeng.datasets.elk.utils.collections import ( 5 | DatasetCollectionId, 6 | resolve_dataset_ids, 7 | ) 8 | from repeng.datasets.elk.utils.limits import Limits, SplitLimits 9 | 10 | """ 11 | 18 datasets 12 | * 20 layers 13 | * 1 token 14 | * 4400 (400 + 2000 + 2000) questions 15 | * 3 answers 16 | * 5120 hidden dim size 17 | * 2 bytes 18 | = 49GB 19 | """ 20 | 21 | collections: list[DatasetCollectionId] = ["dlk", "repe", "got"] 22 | create_activations_dataset( 23 | tag="datasets_2024-02-23_truthfulqa_v1", 24 | llm_ids=["Llama-2-13b-chat-hf"], 25 | dataset_ids=[ 26 | *[ 27 | dataset_id 28 | for collection in collections 29 | for dataset_id in resolve_dataset_ids(collection) 30 | ], 31 | "truthful_qa", 32 | ], 33 | group_limits=Limits( 34 | default=SplitLimits( 35 | train=400, 36 | train_hparams=2000, 37 | validation=2000, 38 | ), 39 | by_dataset={ 40 | "truthful_qa": SplitLimits( 41 | train=0, 42 | train_hparams=0, 43 | validation=2000, 44 | ) 45 | }, 46 | ), 47 | num_tokens_from_end=1, 48 | device=torch.device("cuda"), 49 | layers_start=1, 50 | layers_end=None, 51 | layers_skip=2, 52 | ) 53 | -------------------------------------------------------------------------------- /repeng/probes/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | from jaxtyping import Float, Int64 6 | from typing_extensions import override 7 | 8 | 9 | class BaseProbe(ABC): 10 | @abstractmethod 11 | def predict( 12 | self, 13 | activations: Float[np.ndarray, "n d"], # noqa: F722 14 | ) -> "PredictResult": 15 | """ 16 | Predicts the probability of the label being true for each row. 17 | """ 18 | ... 19 | 20 | 21 | class BaseGroupedProbe(BaseProbe, ABC): 22 | @abstractmethod 23 | def predict_grouped( 24 | self, 25 | activations: Float[np.ndarray, "n d"], # noqa: F722 26 | pairs: Int64[np.ndarray, "n"], # noqa: F821 27 | ) -> "PredictResult": 28 | """ 29 | Predicts the probability of the label being true for each row. 30 | 31 | Activations are grouped into pairs, and the pair information is used for 32 | predictions. 33 | """ 34 | ... 
35 | 36 | 37 | @dataclass 38 | class PredictResult: 39 | logits: Float[np.ndarray, "n"] # noqa: F821 40 | 41 | 42 | @dataclass 43 | class DotProductProbe(BaseProbe): 44 | probe: Float[np.ndarray, "d"] # noqa: F821 45 | 46 | @override 47 | def predict( 48 | self, 49 | activations: Float[np.ndarray, "n d"], # noqa: F722 50 | ) -> PredictResult: 51 | logits = activations @ self.probe 52 | return PredictResult(logits=logits) 53 | -------------------------------------------------------------------------------- /repeng/datasets/elk/truthful_model_written.py: -------------------------------------------------------------------------------- 1 | from importlib import resources 2 | 3 | import jsonlines 4 | 5 | from repeng.datasets import data 6 | from repeng.datasets.elk.types import BinaryRow 7 | from repeng.datasets.utils.shuffles import deterministic_shuffle 8 | from repeng.datasets.utils.splits import split_to_all 9 | 10 | _DATASET_ID = "truthful_model_written" 11 | _TEMPLATE = ( 12 | "Is the following something you would say?\n" '"{statement}"\n' "Answer: {answer}" 13 | ) 14 | 15 | 16 | def get_truthful_model_written() -> dict[str, BinaryRow]: 17 | truthful_json = resources.files(data) / "truthful.jsonl" 18 | with jsonlines.open(str(truthful_json)) as reader: 19 | results = {} 20 | for row in deterministic_shuffle(reader, lambda row: row["key"]): 21 | key = row["key"] 22 | for answer in True, False: 23 | answer_str = "Yes" if answer else "No" 24 | format_args = dict( 25 | statement=row["value"]["statement"], answer=answer_str 26 | ) 27 | text = _TEMPLATE.format(**format_args) 28 | results[f"{key}-{answer}"] = BinaryRow( 29 | dataset_id=_DATASET_ID, 30 | group_id=row["key"], 31 | split=split_to_all(_DATASET_ID, row["key"]), 32 | text=text, 33 | label=(row["value"]["honest"] and answer) 34 | or (not row["value"]["honest"] and not answer), 35 | format_args=dict(), 36 | ) 37 | return results 38 | -------------------------------------------------------------------------------- /repeng/datasets/elk/truthful_qa.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import datasets 4 | 5 | from repeng.datasets.elk.types import BinaryRow 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_validation 8 | 9 | _DATASET_ID = "truthful_qa" 10 | # Taken from https://arxiv.org/abs/2310.01405 D.1.1. 
11 | _TEMPLATE = ( 12 | "Consider the amount of truthfulness in the following answer:\n" 13 | "Question: {question}\n" 14 | "Answer: {answer}\n" 15 | "The amount of truthfulness in the answer is" 16 | ) 17 | 18 | 19 | def get_truthful_qa() -> dict[str, BinaryRow]: 20 | dataset: Any = datasets.load_dataset("truthful_qa", "multiple_choice") 21 | results = {} 22 | for group_id, row in deterministic_shuffle( 23 | enumerate(dataset["validation"]), lambda row: str(row[0]) 24 | ): 25 | for answer_idx, (answer, is_correct) in enumerate( 26 | zip( 27 | row["mc1_targets"]["choices"], 28 | row["mc1_targets"]["labels"], 29 | ) 30 | ): 31 | format_args = dict(question=row["question"], answer=answer) 32 | results[f"{_DATASET_ID}-{group_id}-{answer_idx}"] = BinaryRow( 33 | dataset_id=_DATASET_ID, 34 | split=split_validation(seed=_DATASET_ID, row_id=str(group_id)), 35 | group_id=str(group_id), 36 | text=_TEMPLATE.format(**format_args), 37 | label=is_correct == 1, 38 | format_args=format_args, 39 | ) 40 | return results 41 | -------------------------------------------------------------------------------- /repeng/probes/linear_discriminant_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replication of linear discriminant analysis (LDA) probes. 3 | 4 | See LDA probes in and MMP-IID described in 5 | . 6 | 7 | Methodology: 8 | 1. Given a set of activations, and whether they respond to true or false statements. 9 | 2. Train a linear discriminant analysis model to separate the true and false 10 | activations. 11 | 12 | See 13 | for more details. 14 | 15 | Intuitively, this takes the truth direction and then accounts for interference from 16 | other features. 17 | 18 | Regularization: None. 19 | """ 20 | 21 | from dataclasses import dataclass 22 | 23 | import numpy as np 24 | from jaxtyping import Bool, Float 25 | from overrides import override 26 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 27 | 28 | from repeng.probes.base import BaseProbe, PredictResult 29 | 30 | 31 | @dataclass 32 | class LdaProbe(BaseProbe): 33 | model: LinearDiscriminantAnalysis 34 | 35 | @override 36 | def predict( 37 | self, 38 | activations: Float[np.ndarray, "n d"], # noqa: F722 39 | ) -> PredictResult: 40 | logits = self.model.decision_function(activations) 41 | return PredictResult(logits=logits) 42 | 43 | 44 | def train_lda_probe( 45 | *, 46 | activations: Float[np.ndarray, "n d"], # noqa: F722 47 | labels: Bool[np.ndarray, "n"], # noqa: F821 48 | ) -> LdaProbe: 49 | lda = LinearDiscriminantAnalysis() 50 | lda.fit(activations, labels) 51 | return LdaProbe(lda) 52 | -------------------------------------------------------------------------------- /experiments/scratch/large_models.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from datetime import datetime 3 | 4 | import accelerate 5 | import torch 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, GPTNeoXForCausalLM, PreTrainedTokenizerFast 8 | 9 | # %% 10 | model_name = "EleutherAI/pythia-12b" 11 | device = torch.device("cuda") 12 | dtype = torch.bfloat16 13 | model = GPTNeoXForCausalLM.from_pretrained( 14 | model_name, 15 | device_map=device, 16 | torch_dtype=dtype, 17 | ) 18 | tokenizer = AutoTokenizer.from_pretrained(model_name) 19 | assert isinstance(model, GPTNeoXForCausalLM) 20 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 21 | torch.cuda.empty_cache() 22 | 23 | # %% 24 | n_samples = 50 25 | start_time = 
datetime.now() 26 | for _ in tqdm(range(n_samples)): 27 | tokens = tokenizer.encode("Hello world" * 100, return_tensors="pt").to(device) 28 | model(tokens) 29 | time = datetime.now() - start_time 30 | print(time.total_seconds() / n_samples, "seconds per sample") 31 | 32 | # %% 33 | accelerate.cpu_offload(model=model) 34 | 35 | # %% 36 | layers = list(range(35, 36)) 37 | print("offloading", layers) 38 | for layer in layers: 39 | try: 40 | accelerate.cpu_offload(model.gpt_neox.layers[layer]) 41 | print("offloaded", layer) 42 | except NotImplementedError: 43 | pass 44 | 45 | # %% 46 | torch.cuda.empty_cache() 47 | 48 | # %% 49 | try: 50 | model.to(device=torch.device("cpu")) 51 | except NotImplementedError: 52 | pass 53 | del model 54 | torch.cuda.empty_cache() 55 | 56 | # %% 57 | model 58 | 59 | # %% 60 | # Experiment results: 61 | # - pythia-12b, no offloading: 0.05s 23.75GiB 62 | # - pythia-12b, offloading 35: 0.16s (3x) 23.14GiB 63 | # - pythia-12b, offloading 34-35: 0.28s (6x) 22.00GiB 64 | -------------------------------------------------------------------------------- /repeng/datasets/utils/splits.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | from repeng.datasets.elk.types import Split 4 | 5 | _TRAIN_WEIGHT = 0.6 6 | _TRAIN_HPARAMS_WEIGHT = 0.2 7 | # unused: _VALIDATION_WEIGHT = 0.2 8 | 9 | 10 | def split_to_all(seed: str, row_id: str) -> Split: 11 | """ 12 | Splits into train (60%), train-hparams (20%), and validation (20%). 13 | """ 14 | prob = _get_prob(seed=seed, row_id=row_id) 15 | if prob < _TRAIN_WEIGHT: 16 | return "train" 17 | elif prob < _TRAIN_WEIGHT + _TRAIN_HPARAMS_WEIGHT: 18 | return "train-hparams" 19 | else: 20 | return "validation" 21 | 22 | 23 | def split_train(split: Split, *, seed: str, row_id: str) -> Split: 24 | """ 25 | Splits: 26 | - train -> train (75%), train-hparams (25%) 27 | - validation -> validation 28 | """ 29 | if split == "validation": 30 | return "validation" 31 | prob = _get_prob(seed=seed, row_id=row_id) 32 | train_weight = _TRAIN_WEIGHT / (_TRAIN_WEIGHT + _TRAIN_HPARAMS_WEIGHT) 33 | if prob < train_weight: 34 | return "train" 35 | else: 36 | return "train-hparams" 37 | 38 | 39 | def split_validation(*, seed: str, row_id: str) -> Split: 40 | """ 41 | Splits into validation (80%) and train-hparams (20%). 42 | 43 | Used for datasets that are never trained on, for example TruthfulQA. 44 | """ 45 | prob = _get_prob(seed=seed, row_id=row_id) 46 | validation_weight = 0.8 47 | if prob < validation_weight: 48 | return "validation" 49 | else: 50 | return "train-hparams" 51 | 52 | 53 | def _get_prob(*, seed: str, row_id: str) -> float: 54 | hash = hashlib.sha256(f"{seed}-{row_id}".encode("utf-8")) 55 | hash = int(hash.hexdigest(), 16) 56 | return (hash % 1000) / 1000 57 | -------------------------------------------------------------------------------- /repeng/datasets/elk/true_false.py: -------------------------------------------------------------------------------- 1 | import io 2 | import zipfile 3 | from typing import cast 4 | 5 | import pandas as pd 6 | import requests 7 | 8 | from repeng.datasets.elk.types import BinaryRow 9 | from repeng.datasets.utils.shuffles import deterministic_shuffle 10 | from repeng.datasets.utils.splits import split_to_all 11 | 12 | _DATASET_ID = "true_false" 13 | 14 | 15 | # TODO: Use the prompt template from the paper. I'm not entirely sure how to set it up, 16 | # it seems that they only use true statements, but also say "an honest" vs. 
"a 17 | # dishonest" person in the prompt. I should look at the code I guess? 18 | def get_true_false_dataset() -> dict[str, BinaryRow]: 19 | result = {} 20 | dfs = _download_dataframes() 21 | for csv_name, df in deterministic_shuffle(dfs.items(), lambda row: row[0]): 22 | for index, row in deterministic_shuffle(df.iterrows(), lambda row: str(row[0])): 23 | assert isinstance(index, int) 24 | result[f"{csv_name}-{index}"] = BinaryRow( 25 | dataset_id=_DATASET_ID, 26 | split=split_to_all(_DATASET_ID, f"{csv_name}-{index}"), 27 | text=cast(str, row["statement"]), 28 | label=row["label"] == 1, 29 | format_args=dict(), 30 | ) 31 | return result 32 | 33 | 34 | def _download_dataframes() -> dict[str, pd.DataFrame]: 35 | response = requests.get( 36 | "http://azariaa.com/Content/Datasets/true-false-dataset.zip" 37 | ) 38 | response.raise_for_status() 39 | file_stream = io.BytesIO(response.content) 40 | dataframes = {} 41 | with zipfile.ZipFile(file_stream, "r") as zip_ref: 42 | for file_name in zip_ref.namelist(): 43 | with zip_ref.open(file_name) as file: 44 | dataframes[file_name] = pd.read_csv(file) 45 | return dataframes 46 | -------------------------------------------------------------------------------- /repeng/models/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Generic, Literal, TypeVar, get_args 3 | 4 | from transformers import PreTrainedModel, PreTrainedTokenizerFast 5 | 6 | from repeng.hooks.points import Point 7 | 8 | _ModelT = TypeVar("_ModelT", bound=PreTrainedModel) 9 | _TokenizerT = TypeVar("_TokenizerT", bound=PreTrainedTokenizerFast) 10 | 11 | PythiaId = Literal[ 12 | "pythia-70m", 13 | "pythia-160m", 14 | "pythia-410m", 15 | "pythia-1b", 16 | "pythia-1.4b", 17 | "pythia-2.8b", 18 | "pythia-6.9b", 19 | "pythia-12b", 20 | ] 21 | PythiaDpoId = Literal[ 22 | "pythia-dpo-1b", 23 | "pythia-dpo-1.4b", 24 | "pythia-sft-1b", 25 | "pythia-sft-1.4b", 26 | ] 27 | Gpt2Id = Literal["gpt2"] 28 | Llama2Id = Literal[ 29 | "Llama-2-7b-hf", 30 | "Llama-2-13b-hf", 31 | "Llama-2-70b-hf", 32 | "Llama-2-7b-chat-hf", 33 | "Llama-2-13b-chat-hf", 34 | "Llama-2-70b-chat-hf", 35 | ] 36 | MistralId = Literal[ 37 | "Mistral-7B", 38 | "Mistral-7B-Instruct", 39 | ] 40 | GemmaId = Literal[ 41 | "gemma-2b", 42 | "gemma-7b", 43 | "gemma-2b-it", 44 | "gemma-7b-it", 45 | ] 46 | 47 | LlmId = PythiaId | PythiaDpoId | Gpt2Id | Llama2Id | MistralId | GemmaId 48 | 49 | 50 | PYTHIA_DPO_TO_PYTHIA: dict[PythiaDpoId, PythiaId] = { 51 | "pythia-dpo-1b": "pythia-1b", 52 | "pythia-dpo-1.4b": "pythia-1.4b", 53 | "pythia-sft-1b": "pythia-1b", 54 | "pythia-sft-1.4b": "pythia-1.4b", 55 | } 56 | 57 | 58 | @dataclass 59 | class Llm(Generic[_ModelT, _TokenizerT]): 60 | model: _ModelT 61 | tokenizer: _TokenizerT 62 | points: list[Point[_ModelT]] 63 | 64 | 65 | def is_llm_id(llm_id: str) -> bool: 66 | return llm_id in { 67 | *get_args(PythiaId), 68 | *get_args(PythiaDpoId), 69 | *get_args(Gpt2Id), 70 | *get_args(Llama2Id), 71 | } 72 | -------------------------------------------------------------------------------- /repeng/hooks/points.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Callable, Generic, TypeVar 4 | 5 | import torch 6 | from pyparsing import Any 7 | 8 | ModelT = TypeVar("ModelT", bound=torch.nn.Module) 9 | 10 | 11 | class TensorExtractor(ABC): 12 | @abstractmethod 13 | def 
extract(self, output: Any) -> torch.Tensor: 14 | ... 15 | 16 | @abstractmethod 17 | def insert(self, output: Any, tensor: torch.Tensor) -> Any: 18 | ... 19 | 20 | 21 | class IdentityTensorExtractor(TensorExtractor): 22 | def extract(self, output: Any) -> torch.Tensor: 23 | assert isinstance( 24 | output, torch.Tensor 25 | ), f"Expected tensor, instead found: {type(output)}" 26 | return output 27 | 28 | def insert(self, output: Any, tensor: torch.Tensor) -> Any: 29 | return tensor 30 | 31 | 32 | @dataclass 33 | class TupleTensorExtractor(TensorExtractor): 34 | index: int 35 | 36 | def extract(self, output: Any) -> torch.Tensor: 37 | assert isinstance( 38 | output, tuple 39 | ), f"Expected tuple, instead found: {type(output)}" 40 | assert isinstance( 41 | output[self.index], torch.Tensor 42 | ), f"Expected tensor, instead found: {type(output[self.index])}" 43 | return output[self.index] 44 | 45 | def insert(self, output: Any, tensor: torch.Tensor) -> Any: 46 | assert isinstance( 47 | output, tuple 48 | ), f"Expected tuple, instead found: {type(output)}" 49 | return (*output[: self.index], tensor, *output[self.index + 1 :]) 50 | 51 | 52 | @dataclass 53 | class Point(Generic[ModelT]): 54 | name: str 55 | module_fn: Callable[[ModelT], torch.nn.Module] 56 | tensor_extractor: "TensorExtractor" = field(default_factory=IdentityTensorExtractor) 57 | -------------------------------------------------------------------------------- /repeng/evals/probes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from jaxtyping import Bool, Float, Int64 3 | 4 | from repeng.evals.logits import eval_logits_by_question, eval_logits_by_row 5 | from repeng.evals.types import QuestionsEvalResult, RowsEvalResult 6 | from repeng.probes.base import BaseProbe 7 | 8 | 9 | def eval_probe_by_row( 10 | probe: BaseProbe, 11 | *, 12 | activations: Float[np.ndarray, "n d"], # noqa: F722 13 | labels: Bool[np.ndarray, "n"], # noqa: F821 14 | ) -> RowsEvalResult: 15 | result = probe.predict(activations) 16 | eval_result = eval_logits_by_row( 17 | logits=result.logits, 18 | labels=labels, 19 | ) 20 | eval_result_flipped = eval_logits_by_row( 21 | logits=-result.logits, 22 | labels=labels, 23 | ) 24 | if eval_result.roc_auc_score < eval_result_flipped.roc_auc_score: 25 | return eval_result_flipped.model_copy( 26 | update=dict(is_flipped=True), 27 | ) 28 | else: 29 | return eval_result 30 | 31 | 32 | def eval_probe_by_question( 33 | probe: BaseProbe, 34 | *, 35 | activations: Float[np.ndarray, "n d"], # noqa: F722 36 | groups: Int64[np.ndarray, "n d"], # noqa: F722 37 | labels: Bool[np.ndarray, "n"], # noqa: F821 38 | ) -> QuestionsEvalResult: 39 | result = probe.predict(activations) 40 | eval_result = eval_logits_by_question( 41 | logits=result.logits, 42 | labels=labels, 43 | groups=groups, 44 | ) 45 | eval_result_flipped = eval_logits_by_question( 46 | logits=-result.logits, 47 | labels=labels, 48 | groups=groups, 49 | ) 50 | if eval_result.accuracy < eval_result_flipped.accuracy: 51 | return QuestionsEvalResult( 52 | accuracy=eval_result_flipped.accuracy, 53 | is_flipped=True, 54 | n=eval_result.n, 55 | ) 56 | else: 57 | return eval_result 58 | -------------------------------------------------------------------------------- /repeng/probes/principal_component_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of PCA based probes. 3 | 4 | Grouped PCA probe is equivalent to CRC-TPC described in 5 | . 
6 | 7 | Methodology for ungrouped PCA probes: 8 | 1. Given a set of activations. 9 | 2. Subtract the mean activation from each activation. 10 | 3. Take the first principle component of the normalized activations. This results in the 11 | probe. 12 | 13 | Methodology for grouped PCA probes: 14 | 1. Given a set of activations. 15 | 2. Subtract the questions' mean activation from each activation. 16 | 3. Take the first principle component of the normalized activations. This results in the 17 | probe. 18 | 19 | Regularization: None. 20 | """ 21 | 22 | import numpy as np 23 | from jaxtyping import Float, Int64 24 | from sklearn.decomposition import PCA 25 | 26 | from repeng.probes.base import DotProductProbe 27 | from repeng.probes.normalization import normalize_by_group 28 | 29 | 30 | def train_pca_probe( 31 | *, 32 | activations: Float[np.ndarray, "n d"], # noqa: F722 33 | answer_types: Int64[np.ndarray, "n d"] | None, # noqa: F722 34 | ) -> DotProductProbe: 35 | if answer_types is not None: 36 | activations = normalize_by_group(activations, answer_types) 37 | activations = activations - activations.mean(axis=0) 38 | pca = PCA(n_components=1) 39 | pca.fit_transform(activations) 40 | probe = pca.components_.squeeze(0) 41 | probe = probe / np.linalg.norm(probe) 42 | return DotProductProbe(probe=probe) 43 | 44 | 45 | def train_grouped_pca_probe( 46 | *, 47 | activations: Float[np.ndarray, "n d"], # noqa: F722 48 | groups: Float[np.ndarray, "n"], # noqa: F821 49 | answer_types: Int64[np.ndarray, "n d"] | None, # noqa: F722 50 | ) -> DotProductProbe: 51 | activations = normalize_by_group(activations, groups) 52 | return train_pca_probe(activations=activations, answer_types=answer_types) 53 | -------------------------------------------------------------------------------- /repeng/hooks/grab.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Any, Callable, Generator, TypeVar 3 | 4 | import torch 5 | 6 | from repeng.hooks.points import Point 7 | 8 | ModelT = TypeVar("ModelT", bound=torch.nn.Module) 9 | 10 | 11 | @contextmanager 12 | def grab( 13 | model: ModelT, point: Point[ModelT] 14 | ) -> Generator[Callable[[], torch.Tensor], None, None]: 15 | result: torch.Tensor | None = None 16 | 17 | def hook( 18 | _module: torch.nn.Module, 19 | _input: Any, 20 | output: Any, 21 | ) -> None: 22 | nonlocal result 23 | assert result is None, f"Hook called multiple times for point {point.name}" 24 | result = point.tensor_extractor.extract(output) 25 | assert isinstance( 26 | result, torch.Tensor 27 | ), f"Hook returned non-tensor for point {point.name}, type={type(result)}" 28 | 29 | def get_result() -> torch.Tensor: 30 | assert result is not None, f"Hook not called for point {point.name}" 31 | return result 32 | 33 | hook_handle = None 34 | try: 35 | hook_handle = point.module_fn(model).register_forward_hook(hook) 36 | yield get_result 37 | finally: 38 | if hook_handle is not None: 39 | hook_handle.remove() 40 | 41 | 42 | @contextmanager 43 | def grab_many( 44 | model: ModelT, points: list[Point[ModelT]] 45 | ) -> Generator[Callable[[], dict[str, torch.Tensor]], None, None]: 46 | context_managers = None 47 | try: 48 | context_managers = {point.name: grab(model, point) for point in points} 49 | result_fns = { 50 | name: context_manager.__enter__() 51 | for name, context_manager in context_managers.items() 52 | } 53 | yield lambda: {name: result_fn() for name, result_fn in result_fns.items()} 54 | finally: 55 | if 
context_managers is not None: 56 | for context_manager in context_managers.values(): 57 | context_manager.__exit__(None, None, None) 58 | -------------------------------------------------------------------------------- /repeng/datasets/elk/open_book_qa.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import datasets 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId, Split, TemplateType 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_train 8 | 9 | _DATASET_IDS: dict[TemplateType, DatasetId] = { 10 | "repe": "open_book_qa", 11 | "simple": "open_book_qa/simple", 12 | } 13 | _TEMPLATES: dict[TemplateType, str] = { 14 | # Taken from https://arxiv.org/abs/2310.01405 D.1.5. 15 | "repe": ( 16 | "Consider the correctness of the following fact:\n" 17 | "Fact: {question_stem} {choice}\n" 18 | "The probability of the fact being correct is" 19 | ), 20 | "simple": "{question_stem} {choice}", 21 | } 22 | 23 | 24 | def get_open_book_qa(template_type: TemplateType) -> dict[str, BinaryRow]: 25 | return { 26 | **_get_open_book_qa_split("train", template_type=template_type), 27 | **_get_open_book_qa_split("validation", template_type=template_type), 28 | } 29 | 30 | 31 | def _get_open_book_qa_split( 32 | split: Split, template_type: TemplateType 33 | ) -> dict[str, BinaryRow]: 34 | dataset_id = _DATASET_IDS[template_type] 35 | template = _TEMPLATES[template_type] 36 | dataset: Any = datasets.load_dataset("openbookqa") 37 | results = {} 38 | for row in deterministic_shuffle(dataset[split], lambda row: row["id"]): 39 | group_id = row["id"] 40 | for choice, choice_label in zip( 41 | row["choices"]["text"], row["choices"]["label"], strict=True 42 | ): 43 | format_args = dict(question_stem=row["question_stem"], choice=choice) 44 | results[f"{dataset_id}-{group_id}-{choice_label}"] = BinaryRow( 45 | dataset_id=dataset_id, 46 | split=split_train(split, seed="open_book_qa", row_id=group_id), 47 | group_id=group_id, 48 | text=template.format(**format_args), 49 | label=row["answerKey"] == choice_label, 50 | format_args=format_args, 51 | ) 52 | return results 53 | -------------------------------------------------------------------------------- /repeng/datasets/elk/utils/limits.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId, Split 6 | 7 | 8 | @dataclass 9 | class Limits: 10 | default: "SplitLimits" 11 | by_dataset: dict[DatasetId, "SplitLimits"] 12 | 13 | 14 | @dataclass 15 | class SplitLimits: 16 | train: int | None 17 | train_hparams: int | None 18 | validation: int | None 19 | 20 | 21 | @dataclass(frozen=True) 22 | class _DatasetAndSplit: 23 | dataset_id: DatasetId 24 | split: Split 25 | template_name: str | None = None 26 | 27 | 28 | @dataclass 29 | class _GroupCount: 30 | groups: set[str] 31 | num_nones: int 32 | 33 | def add(self, group_id: str | None) -> None: 34 | if group_id is None: 35 | self.num_nones += 1 36 | else: 37 | self.groups.add(group_id) 38 | 39 | def count(self) -> int: 40 | return len(self.groups) + self.num_nones 41 | 42 | 43 | def limit_groups(limits: Limits) -> Callable[[str, BinaryRow], bool]: 44 | group_counts: dict[_DatasetAndSplit, _GroupCount] = defaultdict( 45 | lambda: _GroupCount(set(), 0) 46 | ) 47 | 48 | def fn(_, row: 
BinaryRow) -> bool: 49 | dataset_and_split = _DatasetAndSplit( 50 | dataset_id=row.dataset_id, 51 | split=row.split, 52 | ) 53 | 54 | if row.dataset_id not in limits.by_dataset: 55 | split_limits = limits.default 56 | else: 57 | split_limits = limits.by_dataset[row.dataset_id] 58 | 59 | if row.split == "train": 60 | limit = split_limits.train 61 | elif row.split == "train-hparams": 62 | limit = split_limits.train_hparams 63 | elif row.split == "validation": 64 | limit = split_limits.validation 65 | else: 66 | raise ValueError() 67 | 68 | if limit is None: 69 | return True 70 | group_counts[dataset_and_split].add(row.group_id) 71 | return group_counts[dataset_and_split].count() <= limit 72 | 73 | return fn 74 | -------------------------------------------------------------------------------- /experiments/scratch/perturb.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from itertools import islice 3 | from typing import Any, cast 4 | 5 | import torch 6 | from datasets import load_dataset 7 | from nnsight import LanguageModel 8 | from tqdm import tqdm 9 | 10 | from repeng.models.llms import pythia 11 | 12 | # %% 13 | model = LanguageModel("EleutherAI/pythia-70m") 14 | dataset: Any = load_dataset("NeelNanda/pile-10k") 15 | 16 | # %% 17 | N_HIDDEN = 512 18 | NUM_ROWS = 200 19 | SEQ = 1024 20 | BATCH_SIZE = 8 21 | seqs = [] 22 | for row in islice(dataset["train"], 0, NUM_ROWS): 23 | prompt = row["text"] 24 | tokens = cast(torch.Tensor, model.tokenizer.encode(prompt, return_tensors="pt")) 25 | tokens = tokens.squeeze(0) 26 | for i in range(0, len(tokens), SEQ): 27 | seqs.append( 28 | torch.nn.functional.pad( 29 | tokens[i : i + SEQ], (0, SEQ - len(tokens[i : i + SEQ])), value=0 30 | ) 31 | ) 32 | tokens = torch.stack(seqs) 33 | tokens = tokens[: BATCH_SIZE * (len(tokens) // BATCH_SIZE)] 34 | tokens = tokens.reshape(-1, BATCH_SIZE, SEQ) 35 | tokens.shape 36 | 37 | # %% 38 | activations = [] 39 | for batch in tqdm(tokens[:2]): 40 | with model.invoke(tokens[0]) as invoker: 41 | batch_activations = model.gpt_neox.layers[3].mlp.input[0][0].save() 42 | activations.append(batch_activations) 43 | activations = torch.concat(activations, dim=0).reshape(-1, N_HIDDEN) 44 | 45 | # %% 46 | layer = model.meta_model.gpt_neox.layers[3].mlp 47 | new_direction = torch.nn.Parameter( 48 | torch.rand((N_HIDDEN), requires_grad=True), 49 | ) 50 | optimizer = torch.optim.Adam([new_direction]) 51 | 52 | pbar = tqdm(range(10_000)) 53 | for _ in pbar: 54 | optimizer.zero_grad() 55 | new_direction_norm = new_direction / new_direction.norm() 56 | 57 | output = model(activations) # (N_SAMPLES, N_HIDDEN) 58 | output_mod = model(activations + new_direction_norm) # (N_SAMPLES, N_HIDDEN) 59 | similarities = output @ output_mod.T 60 | loss = -similarities.mean() 61 | loss.backward() 62 | optimizer.step() 63 | pbar.set_postfix( 64 | loss=loss.item(), 65 | min=similarities.min().item(), 66 | max=similarities.max().item(), 67 | ) 68 | -------------------------------------------------------------------------------- /repeng/probes/linear_artificial_tomography.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replication of LAT probes described in . See appendix 3 | C.1. 4 | 5 | Methodology: 6 | 1. Given a set of activations, randomly sample pairs without replacement. 7 | 2. Compute the difference between each pair. 8 | 3. Normalize the differences by subtracting the mean difference. 9 | 4. 
Take the first principle component of the normalized differences. This results in the 10 | probe. 11 | 12 | Regularization: None. 13 | """ 14 | 15 | import random 16 | from dataclasses import dataclass 17 | 18 | import numpy as np 19 | from jaxtyping import Float, Int64 20 | from sklearn.decomposition import PCA 21 | from typing_extensions import override 22 | 23 | from repeng.probes.base import DotProductProbe, PredictResult 24 | from repeng.probes.normalization import normalize_by_group 25 | 26 | 27 | @dataclass 28 | class CentredDotProductProbe(DotProductProbe): 29 | center: Float[np.ndarray, "d"] # noqa: F821 30 | 31 | @override 32 | def predict( 33 | self, 34 | activations: Float[np.ndarray, "n d"], # noqa: F722 35 | ) -> PredictResult: 36 | return super().predict(activations - self.center) 37 | 38 | 39 | def train_lat_probe( 40 | *, 41 | activations: Float[np.ndarray, "n d"], # noqa: F722 42 | answer_types: Int64[np.ndarray, "n d"] | None, # noqa: F722 43 | ) -> DotProductProbe: 44 | if answer_types is not None: 45 | activations = normalize_by_group(activations, answer_types) 46 | indices = list(range(len(activations))) 47 | random.shuffle(indices) # TODO: Double check if shuffling breaks things. 48 | indices = np.array(indices)[: len(indices) // 2 * 2] 49 | indices_1, indices_2 = indices.reshape(2, -1) 50 | 51 | activation_diffs = activations[indices_1] - activations[indices_2] 52 | activations_center = np.mean(activation_diffs, axis=0) 53 | activation_diffs_norm = activation_diffs - activations_center 54 | pca = PCA(n_components=1) 55 | pca.fit_transform(activation_diffs_norm) 56 | probe = pca.components_.squeeze(0) 57 | probe = probe / np.linalg.norm(probe) 58 | return CentredDotProductProbe(probe=probe, center=activations_center) 59 | -------------------------------------------------------------------------------- /repeng/datasets/elk/common_sense_qa.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import datasets 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId, Split, TemplateType 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_train 8 | 9 | _DATASET_IDS: dict[TemplateType, DatasetId] = { 10 | "repe": "common_sense_qa", 11 | "simple": "common_sense_qa/simple", 12 | } 13 | # Taken from https://arxiv.org/abs/2310.01405 D.1.6. 
14 | _TEMPLATES: dict[TemplateType, str] = { 15 | "repe": ( 16 | "Based on commonsense reasoning, consider the plausibility of the\n" 17 | "answer to the following question:\n" 18 | "Question: {question}\n" 19 | "Answer: {answer}\n" 20 | "The probability of the answer being plausible is" 21 | ), 22 | "simple": "Question: {question}\nAnswer: {answer}", 23 | } 24 | 25 | 26 | def get_common_sense_qa(template_type: TemplateType) -> dict[str, BinaryRow]: 27 | return { 28 | **_get_common_sense_qa_split("train", template_type=template_type), 29 | **_get_common_sense_qa_split("validation", template_type=template_type), 30 | } 31 | 32 | 33 | def _get_common_sense_qa_split( 34 | split: Split, template_type: TemplateType 35 | ) -> dict[str, BinaryRow]: 36 | dataset_id = _DATASET_IDS[template_type] 37 | template = _TEMPLATES[template_type] 38 | dataset: Any = datasets.load_dataset("commonsense_qa") 39 | results = {} 40 | for row in deterministic_shuffle(dataset[split], lambda row: row["id"]): 41 | group_id = row["id"] 42 | for choice, choice_label in zip( 43 | row["choices"]["text"], row["choices"]["label"], strict=True 44 | ): 45 | format_args = dict(question=row["question"], answer=choice) 46 | results[f"{dataset_id}-{group_id}-{choice_label}"] = BinaryRow( 47 | dataset_id=dataset_id, 48 | split=split_train(split, seed="common_sense_qa", row_id=group_id), 49 | group_id=group_id, 50 | text=template.format(**format_args), 51 | label=row["answerKey"] == choice_label, 52 | format_args=format_args, 53 | ) 54 | return results 55 | -------------------------------------------------------------------------------- /experiments/scratch/lr_speed.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import random 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | from mppr import MContext 10 | 11 | from repeng.activations.probe_preparations import ActivationArrayDataset 12 | from repeng.datasets.activations.types import ActivationResultRow 13 | from repeng.datasets.elk.utils.filters import DatasetIdFilter 14 | from repeng.models.points import get_points 15 | from repeng.probes.logistic_regression import train_lr_probe 16 | 17 | # %% 18 | activations = np.random.normal(size=(800, 4096)) 19 | labels = np.random.binomial(1, 0.25, size=(800)).astype(bool) 20 | 21 | # %% 22 | mcontext = MContext(Path("../../output/comparison")) 23 | activation_results_nonchat: list[ActivationResultRow] = mcontext.load( 24 | "activations_results_nonchat", 25 | to="pickle", 26 | ).get() 27 | 28 | # %% 29 | results = [] 30 | points = get_points("Llama-2-7b-hf") 31 | points = points[::3] 32 | random.shuffle(points) 33 | for point in points: 34 | dataset = ActivationArrayDataset(activation_results_nonchat) 35 | arrays = dataset.get( 36 | llm_id="Llama-2-7b-hf", 37 | dataset_filter=DatasetIdFilter("arc_easy"), 38 | split="train", 39 | point_name=point.name, 40 | token_idx=0, 41 | limit=None, 42 | ) 43 | activations = arrays.activations 44 | labels = arrays.labels 45 | for solver in ["lbfgs", "liblinear", "newton-cg"]: 46 | start = datetime.now() 47 | probe = train_lr_probe(activations=activations, labels=labels, solver=solver) 48 | end = datetime.now() 49 | results.append( 50 | dict( 51 | solver=solver, 52 | time=end - start, 53 | point=point.name, 54 | n_iters=probe.model.n_iter_, 55 | ) 56 | ) 57 | print(results[-1]) 58 | 59 | # %% 60 | df = pd.DataFrame(results) 61 | df["point"] = 
df["point"].apply(lambda a: int(a.lstrip("h"))) 62 | df = df.sort_values("point") 63 | df["n_iters"] = df["n_iters"].apply(lambda a: a[0]) 64 | px.line(df, x="point", y="time", color="solver").show() 65 | px.line(df, x="point", y="n_iters", color="solver").show() 66 | df 67 | -------------------------------------------------------------------------------- /repeng/datasets/elk/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, get_args 2 | 3 | from pydantic import BaseModel 4 | 5 | Split = Literal["train", "train-hparams", "validation"] 6 | 7 | DatasetId = Literal[ 8 | "arc_challenge", 9 | "arc_easy", 10 | "got_cities", 11 | "got_sp_en_trans", 12 | "got_larger_than", 13 | "got_cities_cities_conj", 14 | "got_cities_cities_disj", 15 | "common_sense_qa", 16 | "open_book_qa", 17 | "race", 18 | "truthful_qa", 19 | "truthful_model_written", 20 | "true_false", 21 | "imdb", 22 | "imdb/simple", 23 | "amazon_polarity", 24 | "ag_news", 25 | "dbpedia_14", 26 | "rte", 27 | "copa", 28 | "boolq", 29 | "boolq/simple", 30 | "piqa", 31 | "open_book_qa/simple", 32 | "race/simple", 33 | "arc_challenge/simple", 34 | "arc_easy/simple", 35 | "common_sense_qa/simple", 36 | ] 37 | 38 | DlkDatasetId = Literal[ 39 | "imdb", 40 | "imdb/simple", 41 | "amazon_polarity", 42 | "ag_news", 43 | "dbpedia_14", 44 | "rte", 45 | "copa", 46 | "boolq", 47 | "boolq/simple", 48 | "piqa", 49 | ] 50 | 51 | GroupedDatasetId = Literal[ 52 | "arc_challenge", 53 | "arc_easy", 54 | "common_sense_qa", 55 | "open_book_qa", 56 | "race", 57 | "truthful_qa", 58 | "truthful_model_written", 59 | "true_false", 60 | ] 61 | 62 | TemplateType = Literal["repe", "dlk", "simple"] 63 | 64 | 65 | class BinaryRow(BaseModel, extra="forbid"): 66 | dataset_id: DatasetId 67 | split: Split 68 | text: str 69 | label: bool 70 | format_args: dict[str, str] 71 | group_id: str | None = None 72 | """ 73 | Rows are grouped, for example by question, in order to allow for probes that take 74 | into account intra-group relationships. 75 | """ 76 | answer_type: str | None = None 77 | """ 78 | For example, 'true' and 'false' for answers to true/false questions, or 'A', 'B', 79 | 'C', or 'D' for multiple choice questions. 80 | If not set, the prompt template doesn't include any consistent answer templates 81 | (e.g. it's just question-answer). 82 | """ 83 | 84 | 85 | def is_dataset_grouped(dataset_id: DatasetId) -> bool: 86 | return dataset_id in get_args(GroupedDatasetId) 87 | -------------------------------------------------------------------------------- /repeng/datasets/elk/race.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import datasets 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId, Split, TemplateType 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_train 8 | 9 | _DATASET_IDS: dict[TemplateType, DatasetId] = { 10 | "repe": "race", 11 | "simple": "race/simple", 12 | } 13 | _TEMPLATES: dict[TemplateType, str] = { 14 | # Taken from https://arxiv.org/abs/2310.01405 D.1.7. 
15 | "repe": ( 16 | "Consider the correctness of the answer to the following question based on the " 17 | "article:\n" 18 | "\n" 19 | "Article: {article}\n" 20 | "\n" 21 | "Question: {question}\n" 22 | "Answer: {answer}\n" 23 | "The probability of the answer being correct is" 24 | ), 25 | "simple": "Article: {article}\n\nQuestion: {question}\nAnswer: {answer}", 26 | } 27 | _ANSWER_TO_INDEX = {"A": 0, "B": 1, "C": 2, "D": 3} 28 | 29 | 30 | def get_race(template_id: TemplateType) -> dict[str, BinaryRow]: 31 | return { 32 | **_get_race_split("train", template_id=template_id), 33 | **_get_race_split("validation", template_id=template_id), 34 | } 35 | 36 | 37 | def _get_race_split(split: Split, template_id: TemplateType) -> dict[str, BinaryRow]: 38 | dataset_id = _DATASET_IDS[template_id] 39 | template = _TEMPLATES[template_id] 40 | dataset: Any = datasets.load_dataset("race", "all") 41 | results = {} 42 | for row in deterministic_shuffle(dataset[split], lambda row: row["example_id"]): 43 | group_id = row["example_id"] 44 | for option_idx, option in enumerate(row["options"]): 45 | format_args = dict( 46 | article=row["article"], question=row["question"], answer=option 47 | ) 48 | results[f"{dataset_id}-{group_id}-{option_idx}"] = BinaryRow( 49 | dataset_id=dataset_id, 50 | split=split_train(split, seed="race", row_id=group_id), 51 | group_id=group_id, 52 | text=template.format(**format_args), 53 | label=_ANSWER_TO_INDEX[row["answer"]] == option_idx, 54 | format_args=format_args, 55 | ) 56 | return results 57 | -------------------------------------------------------------------------------- /repeng/datasets/elk/arc.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | import datasets 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId, Split, TemplateType 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_train 8 | 9 | ArcSubset = Literal["easy", "challenge"] 10 | 11 | _DATASET_IDS: dict[tuple[ArcSubset, TemplateType], DatasetId] = { 12 | ("easy", "repe"): "arc_easy", 13 | ("challenge", "repe"): "arc_challenge", 14 | ("easy", "simple"): "arc_easy/simple", 15 | ("challenge", "simple"): "arc_challenge/simple", 16 | } 17 | _TEMPLATE: dict[TemplateType, str] = { 18 | # Taken from https://arxiv.org/abs/2310.01405 D.1.4. 
19 | "repe": ( 20 | "Consider the correctness of the answer to the following question:\n" 21 | "\n" 22 | "Question: {question}\n" 23 | "Answer: {answer}\n" 24 | "The probability of the answer being correct is" 25 | ), 26 | "simple": "Question: {question}\nAnswer: {answer}", 27 | } 28 | _SUBSET_TO_NAME: dict[ArcSubset, str] = { 29 | "easy": "ARC-Easy", 30 | "challenge": "ARC-Challenge", 31 | } 32 | 33 | 34 | def get_arc(subset: ArcSubset, template_type: TemplateType) -> dict[str, BinaryRow]: 35 | return { 36 | **_get_arc_split(subset=subset, split="train", template_type=template_type), 37 | **_get_arc_split( 38 | subset=subset, split="validation", template_type=template_type 39 | ), 40 | } 41 | 42 | 43 | def _get_arc_split( 44 | *, 45 | subset: ArcSubset, 46 | split: Split, 47 | template_type: TemplateType, 48 | ) -> dict[str, BinaryRow]: 49 | dataset_id = _DATASET_IDS[(subset, template_type)] 50 | template = _TEMPLATE[template_type] 51 | dataset: Any = datasets.load_dataset("ai2_arc", _SUBSET_TO_NAME[subset]) 52 | results = {} 53 | for row in deterministic_shuffle(dataset[split], lambda row: row["id"]): 54 | group_id = row["id"] 55 | for choice, choice_label in zip( 56 | row["choices"]["text"], row["choices"]["label"], strict=True 57 | ): 58 | format_args = dict(question=row["question"], answer=choice) 59 | results[f"{dataset_id}-{group_id}-{choice_label}"] = BinaryRow( 60 | dataset_id=dataset_id, 61 | split=split_train(split, seed="arc" + subset, row_id=group_id), 62 | group_id=group_id, 63 | text=template.format(**format_args), 64 | label=row["answerKey"] == choice_label, 65 | format_args=format_args, 66 | ) 67 | return results 68 | -------------------------------------------------------------------------------- /repeng/activations/inference.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import numpy as np 4 | import torch 5 | from pydantic import BaseModel 6 | from transformers import PreTrainedModel, PreTrainedTokenizerFast 7 | 8 | from repeng.hooks.grab import grab_many 9 | from repeng.models.llms import Llm 10 | from repeng.utils.pydantic_ndarray import NdArray 11 | 12 | ModelT = TypeVar("ModelT", bound=PreTrainedModel) 13 | 14 | 15 | class ActivationRow(BaseModel, extra="forbid"): 16 | text: str 17 | text_tokenized: list[str] 18 | activations: dict[str, NdArray] 19 | token_logprobs: NdArray 20 | 21 | class Config: 22 | arbitrary_types_allowed = True 23 | 24 | 25 | @torch.inference_mode() 26 | def get_model_activations( 27 | llm: Llm[ModelT, PreTrainedTokenizerFast], 28 | *, 29 | text: str, 30 | last_n_tokens: int | None, 31 | points_start: int | None, 32 | points_end: int | None, 33 | points_skip: int | None, 34 | ) -> ActivationRow: 35 | assert last_n_tokens is None or last_n_tokens > 0, last_n_tokens 36 | 37 | tokens = llm.tokenizer.encode(text, return_tensors="pt") 38 | assert isinstance(tokens, torch.Tensor) 39 | tokens = tokens.to(next(llm.model.parameters()).device) 40 | tokens_str = llm.tokenizer.convert_ids_to_tokens(tokens.squeeze().tolist()) 41 | assert isinstance(tokens_str, list) 42 | 43 | points = llm.points[points_start:points_end:points_skip] 44 | with grab_many(llm.model, points) as activation_fn: 45 | output = llm.model.forward(tokens) 46 | logits: torch.Tensor = output.logits 47 | layer_activations = activation_fn() 48 | 49 | logprobs = logits.log_softmax(dim=-1) 50 | logprobs_shifted = logprobs[0, :-1] 51 | tokens_shifted = tokens[0, 1:, None] 52 | token_logprobs = ( 53 | 
logprobs_shifted.gather(dim=-1, index=tokens_shifted).squeeze(0).detach() 54 | ) 55 | 56 | def get_activation(activations: torch.Tensor) -> np.ndarray: 57 | activations = activations.squeeze(0) 58 | if last_n_tokens is not None: 59 | activations = activations[-last_n_tokens:] 60 | activations = activations.detach() 61 | # bfloat16 is not supported by numpy. We lose some precision by converting to 62 | # float16, but it significantly saves space an empirically makes little 63 | # difference. 64 | activations = activations.to(dtype=torch.float16) 65 | return activations.cpu().numpy() 66 | 67 | return ActivationRow( 68 | text=text, 69 | text_tokenized=tokens_str, 70 | activations={ 71 | name: get_activation(activations) 72 | for name, activations in layer_activations.items() 73 | }, 74 | token_logprobs=token_logprobs.float().cpu().numpy(), 75 | ) 76 | -------------------------------------------------------------------------------- /repeng/datasets/modelwritten/generation.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import openai 4 | from openai.types.chat import ( 5 | ChatCompletionSystemMessageParam, 6 | ChatCompletionUserMessageParam, 7 | ) 8 | from openai.types.chat.completion_create_params import ResponseFormat 9 | from pydantic import BaseModel 10 | 11 | from repeng.datasets.modelwritten import personas 12 | 13 | _GENERATION_TEMPLATE = ( 14 | "Suppose there is a person who {description}. " 15 | "Please write a list of statements (stated in the first person) " 16 | "that they would {agree_or_disagree} with, " 17 | "but others would {disagree_or_agree} with:\n" 18 | ) 19 | 20 | 21 | class GenerationConfig(BaseModel, extra="forbid"): 22 | model: Literal["gpt-3.5-turbo-1106", "gpt-4-1106-preview"] = "gpt-3.5-turbo-1106" 23 | temperature: float = 1.4 24 | top_p: float = 0.975 25 | 26 | 27 | class Statements(BaseModel, extra="forbid"): 28 | statements: list[str] 29 | agrees: bool 30 | 31 | 32 | class _StatementsSchema(BaseModel, extra="forbid"): 33 | statements: list[str] 34 | 35 | 36 | async def generate_statements( 37 | client: openai.AsyncOpenAI, 38 | config: GenerationConfig, 39 | *, 40 | agrees: bool, 41 | ) -> Statements: 42 | response = await client.chat.completions.create( 43 | messages=[ 44 | ChatCompletionSystemMessageParam( 45 | role="system", 46 | content=( 47 | "Give your response in JSON format, using the schema: " 48 | f"{_StatementsSchema.model_json_schema()}" 49 | ), 50 | ), 51 | ChatCompletionUserMessageParam( 52 | role="user", 53 | content=_format_generation_prompt(agrees=agrees), 54 | ), 55 | ], 56 | model=config.model, 57 | temperature=config.temperature, 58 | top_p=config.top_p, 59 | response_format=ResponseFormat(type="json_object"), 60 | ) 61 | assert len(response.choices) == 1, response 62 | message = response.choices[0].message 63 | assert message.content is not None, message 64 | statements = _StatementsSchema.model_validate_json(message.content).statements 65 | return Statements(statements=statements, agrees=agrees) 66 | 67 | 68 | def _format_generation_prompt(*, agrees: bool) -> str: 69 | description = personas.HONEST 70 | if agrees: 71 | agree_or_disagree = "agree" 72 | disagree_or_agree = "disagree" 73 | else: 74 | agree_or_disagree = "disagree" 75 | disagree_or_agree = "agree" 76 | return _GENERATION_TEMPLATE.format( 77 | description=description, 78 | agree_or_disagree=agree_or_disagree, 79 | disagree_or_agree=disagree_or_agree, 80 | ) 81 | 
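# --- Editor's illustrative sketch; not part of the original module. ---
# A minimal example of how `generate_statements` above might be driven to collect a
# small pool of statements for the honesty persona. The batch size of 3 per polarity
# and the use of asyncio.gather are arbitrary illustrative choices; an OpenAI API key
# is assumed to be configured in the environment.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> list[Statements]:
        client = openai.AsyncOpenAI()
        config = GenerationConfig()
        return list(
            await asyncio.gather(
                *(
                    generate_statements(client, config, agrees=agrees)
                    for agrees in (True, False)
                    for _ in range(3)
                )
            )
        )

    for statements in asyncio.run(_demo()):
        print(statements.agrees, statements.statements[:2])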
-------------------------------------------------------------------------------- /repeng/datasets/elk/utils/collections.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, cast, get_args 3 | 4 | from overrides import override 5 | 6 | from repeng.datasets.elk.types import BinaryRow, DatasetId 7 | from repeng.datasets.elk.utils.filters import DatasetCollectionFilter, DatasetFilter 8 | from repeng.datasets.elk.utils.fns import get_datasets 9 | 10 | DatasetCollectionId = Literal[ 11 | "all", 12 | "dlk", 13 | "repe", 14 | "got", 15 | "repe-simple", 16 | ] 17 | 18 | _DATASET_COLLECTIONS: dict[DatasetCollectionId, DatasetCollectionFilter] = { 19 | "all": DatasetCollectionFilter( 20 | "all", cast(list[DatasetId], list(get_args(DatasetId))) 21 | ), 22 | "dlk": DatasetCollectionFilter( 23 | "dlk", 24 | [ 25 | "imdb", 26 | "amazon_polarity", 27 | "ag_news", 28 | "dbpedia_14", 29 | "rte", 30 | "copa", 31 | "piqa", 32 | "boolq", 33 | ], 34 | ), 35 | "repe": DatasetCollectionFilter( 36 | "repe", 37 | [ 38 | "open_book_qa", 39 | "race", 40 | "arc_challenge", 41 | "arc_easy", 42 | "common_sense_qa", 43 | ], 44 | ), 45 | "repe-simple": DatasetCollectionFilter( 46 | "repe-simple", 47 | [ 48 | "open_book_qa/simple", 49 | "race/simple", 50 | "arc_challenge/simple", 51 | "arc_easy/simple", 52 | "common_sense_qa/simple", 53 | ], 54 | ), 55 | "got": DatasetCollectionFilter( 56 | "got", 57 | [ 58 | "got_cities", 59 | "got_sp_en_trans", 60 | "got_cities_cities_conj", 61 | "got_cities_cities_disj", 62 | "got_larger_than", 63 | ], 64 | ), 65 | } 66 | 67 | 68 | @dataclass 69 | class DatasetCollectionIdFilter(DatasetFilter): 70 | collection: DatasetCollectionId 71 | 72 | @override 73 | def get_name(self) -> str: 74 | return self.collection 75 | 76 | @override 77 | def filter(self, dataset_id: DatasetId, answer_type: str | None) -> bool: 78 | return dataset_id in resolve_dataset_ids(self.collection) 79 | 80 | 81 | def get_dataset_collection( 82 | dataset_collection_id: DatasetCollectionId, 83 | ) -> dict[str, BinaryRow]: 84 | return get_datasets(resolve_dataset_ids(dataset_collection_id)) 85 | 86 | 87 | def resolve_dataset_ids( 88 | id: DatasetId | DatasetCollectionId, 89 | ) -> list[DatasetId]: 90 | if id in get_args(DatasetId): 91 | return [cast(DatasetId, id)] 92 | elif id in get_args(DatasetCollectionId): 93 | return _DATASET_COLLECTIONS[cast(DatasetCollectionId, id)].datasets 94 | else: 95 | raise ValueError(f"Unknown ID: {id}") 96 | -------------------------------------------------------------------------------- /repeng/datasets/elk/utils/fns.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from tqdm import tqdm 4 | 5 | from repeng.datasets.elk.arc import get_arc 6 | from repeng.datasets.elk.common_sense_qa import get_common_sense_qa 7 | from repeng.datasets.elk.dlk import get_dlk_dataset 8 | from repeng.datasets.elk.geometry_of_truth import get_geometry_of_truth 9 | from repeng.datasets.elk.open_book_qa import get_open_book_qa 10 | from repeng.datasets.elk.race import get_race 11 | from repeng.datasets.elk.true_false import get_true_false_dataset 12 | from repeng.datasets.elk.truthful_model_written import get_truthful_model_written 13 | from repeng.datasets.elk.truthful_qa import get_truthful_qa 14 | from repeng.datasets.elk.types import BinaryRow, DatasetId 15 | 16 | _DATASET_FNS: dict[DatasetId, Callable[[], dict[str, BinaryRow]]] = { 
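    # Each value is a zero-argument factory rather than a materialised dataset, so a
    # dataset is only loaded (and, where necessary, downloaded) when get_dataset() /
    # get_datasets() below asks for it.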
17 | "got_cities": lambda: get_geometry_of_truth("cities"), 18 | "got_sp_en_trans": lambda: get_geometry_of_truth("sp_en_trans"), 19 | "got_larger_than": lambda: get_geometry_of_truth("larger_than"), 20 | "got_cities_cities_conj": lambda: get_geometry_of_truth("cities_cities_conj"), 21 | "got_cities_cities_disj": lambda: get_geometry_of_truth("cities_cities_disj"), 22 | "arc_challenge": lambda: get_arc("challenge", "repe"), 23 | "arc_easy": lambda: get_arc("easy", "repe"), 24 | "common_sense_qa": lambda: get_common_sense_qa("repe"), 25 | "open_book_qa": lambda: get_open_book_qa("repe"), 26 | "race": lambda: get_race("repe"), 27 | "arc_challenge/simple": lambda: get_arc("challenge", "simple"), 28 | "arc_easy/simple": lambda: get_arc("easy", "simple"), 29 | "common_sense_qa/simple": lambda: get_common_sense_qa("simple"), 30 | "open_book_qa/simple": lambda: get_open_book_qa("simple"), 31 | "race/simple": lambda: get_race("simple"), 32 | "truthful_qa": lambda: get_truthful_qa(), 33 | "truthful_model_written": lambda: get_truthful_model_written(), 34 | "true_false": get_true_false_dataset, 35 | "imdb": lambda: get_dlk_dataset("imdb"), 36 | "imdb/simple": lambda: get_dlk_dataset("imdb/simple"), 37 | "amazon_polarity": lambda: get_dlk_dataset("amazon_polarity"), 38 | "ag_news": lambda: get_dlk_dataset("ag_news"), 39 | "dbpedia_14": lambda: get_dlk_dataset("dbpedia_14"), 40 | "rte": lambda: get_dlk_dataset("rte"), 41 | "copa": lambda: get_dlk_dataset("copa"), 42 | "boolq": lambda: get_dlk_dataset("boolq"), 43 | "boolq/simple": lambda: get_dlk_dataset("boolq/simple"), 44 | "piqa": lambda: get_dlk_dataset("piqa"), 45 | } 46 | 47 | 48 | def get_dataset(dataset_id: DatasetId) -> dict[str, BinaryRow]: 49 | return _DATASET_FNS[dataset_id]() 50 | 51 | 52 | def get_datasets(dataset_ids: list[DatasetId]) -> dict[str, BinaryRow]: 53 | result = {} 54 | pbar = tqdm(dataset_ids, desc="loading datasets") 55 | for dataset_id in pbar: 56 | pbar.set_postfix(dataset=dataset_id) 57 | result.update(get_dataset(dataset_id)) 58 | return result 59 | -------------------------------------------------------------------------------- /repeng/datasets/modelwritten/filtering.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Literal 3 | 4 | import openai 5 | from openai.types.chat import ( 6 | ChatCompletionSystemMessageParam, 7 | ChatCompletionTokenLogprob, 8 | ChatCompletionUserMessageParam, 9 | ) 10 | from pydantic import BaseModel 11 | 12 | from repeng.datasets.modelwritten import personas 13 | 14 | _FILTERING_TEMPLATE = ( 15 | "Suppose there is a person who {description}. " 16 | "Would they agree or disagree with the following statement? 
" 17 | '"{statement}"' 18 | ) 19 | 20 | 21 | class FilteringConfig(BaseModel, extra="forbid"): 22 | model: Literal["gpt-3.5-turbo-1106", "gpt-4-1106-preview"] = "gpt-3.5-turbo-1106" 23 | 24 | 25 | class StatementLikelihood(BaseModel, extra="forbid"): 26 | statement: str 27 | likelihood: float 28 | yes_logprobs: float | None 29 | no_logprobs: float | None 30 | 31 | 32 | async def get_statement_likelihood( 33 | client: openai.AsyncClient, 34 | config: FilteringConfig, 35 | *, 36 | statement: str, 37 | ) -> StatementLikelihood: 38 | response = await client.chat.completions.create( 39 | messages=[ 40 | ChatCompletionSystemMessageParam( 41 | role="system", 42 | content="Only reply 'Yes', or 'No' to the user's question.", 43 | ), 44 | ChatCompletionUserMessageParam( 45 | role="user", 46 | content=_FILTERING_TEMPLATE.format( 47 | description=personas.HONEST, 48 | statement=statement, 49 | ), 50 | ), 51 | ], 52 | model=config.model, 53 | temperature=0, 54 | logprobs=True, 55 | top_logprobs=5, 56 | max_tokens=1, 57 | ) 58 | assert len(response.choices) == 1, response 59 | 60 | choice = response.choices[0] 61 | assert choice.message.content is not None, choice 62 | assert choice.message.content in ["Yes", "No"], choice 63 | 64 | assert choice.logprobs is not None, choice 65 | assert choice.logprobs.content is not None, choice 66 | assert len(choice.logprobs.content) == 1, choice 67 | yes_logprobs = _get_logprobs("Yes", choice.logprobs.content[0]) 68 | no_logprobs = _get_logprobs("No", choice.logprobs.content[0]) 69 | yes_probs = math.exp(yes_logprobs) if yes_logprobs is not None else 0 70 | no_probs = math.exp(no_logprobs) if no_logprobs is not None else 0 71 | likelihood = yes_probs / (yes_probs + no_probs) 72 | 73 | return StatementLikelihood( 74 | statement=statement, 75 | likelihood=likelihood, 76 | yes_logprobs=yes_logprobs, 77 | no_logprobs=no_logprobs, 78 | ) 79 | 80 | 81 | def _get_logprobs( 82 | token_str: str, token_logprobs: ChatCompletionTokenLogprob 83 | ) -> float | None: 84 | for token in token_logprobs.top_logprobs: 85 | if token.token == token_str: 86 | return token.logprob 87 | return None 88 | -------------------------------------------------------------------------------- /repeng/evals/logits.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics 3 | from jaxtyping import Bool, Float, Int64 4 | 5 | from repeng.evals.types import QuestionsEvalResult, RowsEvalResult 6 | 7 | 8 | def eval_logits_by_question( 9 | *, 10 | logits: Float[np.ndarray, "n"], # noqa: F821 11 | labels: Bool[np.ndarray, "n"], # noqa: F821 12 | groups: Int64[np.ndarray, "n"], # noqa: F821 13 | ) -> QuestionsEvalResult: 14 | correct: list[bool] = [] 15 | for group in np.unique(groups): 16 | group_labels = labels[groups == group].copy() 17 | group_logits = logits[groups == group].copy() 18 | max_logit = np.max(group_logits) 19 | # Instead of just doing argmax, we randomly select among the max logits. This is 20 | # because some of the datasets always have the correct answer first, so 21 | # outputting identical logits for all answers would have 100% accuracy. 22 | max_idx = np.random.choice(np.where(group_logits == max_logit)[0]) 23 | # N.B.: Some datasets have multiple correct answers per question. This is fine, 24 | # we just check that one of the correct answers was chosen. 
25 | correct.append(group_labels[max_idx]) # type: ignore 26 | accuracy = sum(correct) / len(correct) 27 | return QuestionsEvalResult(accuracy=accuracy, is_flipped=False, n=len(correct)) 28 | 29 | 30 | def eval_logits_by_row( 31 | *, 32 | logits: Float[np.ndarray, "n"], # noqa: F821 33 | labels: Bool[np.ndarray, "n"], # noqa: F821 34 | ) -> RowsEvalResult: 35 | fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, logits) 36 | best_idx = np.argmin(np.sqrt((1 - tpr) ** 2 + fpr**2)) 37 | threshold = thresholds[best_idx] 38 | 39 | labels_pred = logits > threshold 40 | f1_score = sklearn.metrics.f1_score( 41 | labels, 42 | labels_pred, 43 | zero_division=0, # type: ignore 44 | ) 45 | precision = sklearn.metrics.precision_score( 46 | labels, 47 | labels_pred, 48 | zero_division=0, # type: ignore 49 | ) 50 | recall = sklearn.metrics.recall_score( 51 | labels, 52 | labels_pred, 53 | zero_division=0, # type: ignore 54 | ) 55 | accuracy = sklearn.metrics.accuracy_score(labels, labels_pred) 56 | predicted_true = labels_pred.mean().item() 57 | roc_auc_score = sklearn.metrics.roc_auc_score(labels, logits) 58 | assert ( 59 | isinstance(f1_score, float) 60 | and isinstance(precision, float) 61 | and isinstance(recall, float) 62 | and isinstance(roc_auc_score, float) 63 | and isinstance(accuracy, float) 64 | ) 65 | return RowsEvalResult( 66 | f1_score=f1_score, 67 | precision=precision, 68 | recall=recall, 69 | roc_auc_score=roc_auc_score, 70 | accuracy=accuracy, 71 | predicted_true=predicted_true, 72 | fprs=fpr.tolist(), 73 | tprs=tpr.tolist(), 74 | logits=logits.tolist(), 75 | is_flipped=False, 76 | n=len(labels), 77 | ) 78 | -------------------------------------------------------------------------------- /experiments/scratch/test_f16.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import plotly.express as px 6 | from mppr import MContext 7 | 8 | from repeng.activations.probe_preparations import ActivationArrayDataset 9 | from repeng.datasets.activations.types import ActivationResultRow 10 | from repeng.evals.logits import eval_logits_by_question 11 | from repeng.evals.probes import eval_probe_by_question 12 | from repeng.probes.logistic_regression import train_grouped_lr_probe 13 | 14 | # %% 15 | mcontext = MContext(Path("../output/comparison")) 16 | activations_dataset: list[ActivationResultRow] = mcontext.download_cached( 17 | "activations_dataset", 18 | path=( 19 | "s3://repeng/datasets/activations/" 20 | # "datasets_2024-02-02_tokensandlayers_v1.pickle" 21 | "datasets_2024-02-03_v1.pickle" 22 | ), 23 | to="pickle", 24 | ).get() 25 | print(set(row.llm_id for row in activations_dataset)) 26 | print(set(row.dataset_id for row in activations_dataset)) 27 | print(set(row.split for row in activations_dataset)) 28 | dataset = ActivationArrayDataset(activations_dataset) 29 | 30 | # %% 31 | arrays = dataset.get( 32 | llm_id="pythia-12b", 33 | dataset_filter_id="arc_easy", 34 | split="train", 35 | point_name="h34", 36 | token_idx=-1, 37 | limit=None, 38 | ) 39 | assert arrays.groups is not None 40 | probe = train_grouped_lr_probe( 41 | activations=arrays.activations, 42 | labels=arrays.labels, 43 | groups=arrays.groups, 44 | ) 45 | 46 | arrays_val = dataset.get( 47 | llm_id="pythia-12b", 48 | dataset_filter_id="arc_easy", 49 | split="validation", 50 | point_name="h34", 51 | token_idx=-1, 52 | limit=None, 53 | ) 54 | assert arrays_val.groups is not None 55 | print( 56 | eval_probe_by_question( 57 | 
probe, 58 | activations=arrays_val.activations, 59 | labels=arrays_val.labels, 60 | groups=arrays_val.groups, 61 | ) 62 | ) 63 | 64 | # %% 65 | arrays = dataset.get( 66 | llm_id="pythia-12b", 67 | dataset_filter_id="arc_easy", 68 | split="train", 69 | point_name="logprobs", 70 | token_idx=-1, 71 | limit=None, 72 | ) 73 | assert arrays.groups is not None 74 | print( 75 | eval_logits_by_question( 76 | logits=arrays.activations, 77 | labels=arrays.labels, 78 | groups=arrays.groups, 79 | ) 80 | ) 81 | 82 | # %% 83 | activations = arrays.activations.copy() 84 | for group in np.unique(arrays.groups): 85 | activations[arrays.groups == group] -= activations[arrays.groups == group].mean( 86 | axis=0 87 | ) 88 | 89 | (idxs0,) = np.where(arrays.labels == 0) 90 | (idxs1,) = np.where(arrays.labels == 1) 91 | idxs = np.concatenate([idxs0[::3], idxs1]) 92 | px.histogram( 93 | x=activations[idxs], 94 | color=arrays.labels[idxs], 95 | nbins=50, 96 | opacity=0.5, 97 | barmode="overlay", 98 | ) 99 | 100 | # %% 101 | activations_dataset[0].activations["h0"].dtype 102 | -------------------------------------------------------------------------------- /repeng/activations/probe_preparations.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from jaxtyping import Bool, Float, Int64 7 | 8 | from repeng.datasets.activations.types import ActivationResultRow 9 | from repeng.datasets.elk.types import Split 10 | from repeng.datasets.elk.utils.filters import DatasetFilter 11 | from repeng.models.types import LlmId 12 | 13 | 14 | @dataclass 15 | class ActivationArrays: 16 | activations: Float[np.ndarray, "n d"] # noqa: F722 17 | labels: Bool[np.ndarray, "n"] # noqa: F821 18 | groups: Int64[np.ndarray, "n"] | None # noqa: F821 19 | answer_types: Int64[np.ndarray, "n"] | None # noqa: F821 20 | 21 | 22 | @dataclass 23 | class ActivationArrayDataset: 24 | rows: list[ActivationResultRow] 25 | 26 | def get( 27 | self, 28 | *, 29 | llm_id: LlmId, 30 | dataset_filter: DatasetFilter, 31 | split: Split, 32 | point_name: str | Literal["logprobs"], 33 | token_idx: int, 34 | limit: int | None, 35 | ) -> ActivationArrays: 36 | df = pd.DataFrame( 37 | [ 38 | dict( 39 | label=row.label, 40 | group_id=row.group_id, 41 | answer_type=row.answer_type, 42 | activations=( 43 | row.activations[point_name][token_idx].copy() 44 | if point_name != "logprobs" 45 | else np.array(row.prompt_logprobs) 46 | ), 47 | ) 48 | for row in self.rows 49 | if dataset_filter.filter( 50 | dataset_id=row.dataset_id, 51 | answer_type=row.answer_type, 52 | ) 53 | and row.split == split 54 | and row.llm_id == llm_id 55 | ][:limit] 56 | ) 57 | assert not df.empty, (llm_id, dataset_filter, split, point_name, token_idx) 58 | 59 | group_counts = df["group_id"].value_counts().rename("group_count") 60 | df = df.join(group_counts, on="group_id") 61 | df.loc[df["group_count"] <= 1, "group_id"] = np.nan 62 | if df["group_id"].isna().all(): # type: ignore 63 | groups = None 64 | else: 65 | df = df[df["group_id"].notna()] 66 | groups = ( 67 | df["group_id"].astype("category").cat.codes.to_numpy() # type: ignore 68 | ) 69 | 70 | if df["answer_type"].isna().any(): # type: ignore 71 | answer_types = None 72 | else: 73 | answer_types = ( 74 | df["answer_type"] 75 | .astype("category") 76 | .cat.codes.to_numpy() # type: ignore 77 | ) 78 | 79 | return ActivationArrays( 80 | 
activations=np.stack(df["activations"].tolist()).astype(np.float32), 81 | labels=df["label"].to_numpy(), # type: ignore 82 | groups=groups, 83 | answer_types=answer_types, 84 | ) 85 | -------------------------------------------------------------------------------- /repeng/datasets/elk/utils/filters.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Literal 4 | 5 | from overrides import override 6 | 7 | from repeng.datasets.elk.types import DatasetId 8 | 9 | 10 | class DatasetFilter(ABC): 11 | @abstractmethod 12 | def get_name(self) -> str: 13 | """ 14 | Gets the name of the filter, for use in reporting and plotting. 15 | """ 16 | ... 17 | 18 | @abstractmethod 19 | def filter(self, dataset_id: DatasetId, answer_type: str | None) -> bool: 20 | """ 21 | Filters a row of a dataset. 22 | """ 23 | ... 24 | 25 | 26 | @dataclass 27 | class DatasetIdFilter(DatasetFilter): 28 | dataset_id: DatasetId 29 | 30 | @override 31 | def get_name(self) -> str: 32 | return self.dataset_id 33 | 34 | @override 35 | def filter(self, dataset_id: DatasetId, answer_type: str | None) -> bool: 36 | return dataset_id == self.dataset_id 37 | 38 | 39 | @dataclass 40 | class ExactMatchFilter(DatasetFilter): 41 | name: str 42 | dataset_id: DatasetId 43 | answer_type: str | None 44 | 45 | @override 46 | def get_name(self) -> str: 47 | return self.name 48 | 49 | @override 50 | def filter(self, dataset_id: DatasetId, answer_type: str | None) -> bool: 51 | return dataset_id == self.dataset_id and answer_type == self.answer_type 52 | 53 | 54 | @dataclass 55 | class DatasetCollectionFilter(DatasetFilter): 56 | name: str 57 | datasets: list[DatasetId] 58 | 59 | @override 60 | def get_name(self) -> str: 61 | return self.name 62 | 63 | @override 64 | def filter(self, dataset_id: DatasetId, answer_type: str | None) -> bool: 65 | return dataset_id in self.datasets 66 | 67 | 68 | DatasetFilterId = Literal[ 69 | "got_cities/pos", 70 | "got_cities/neg", 71 | "got_sp_en_trans/pos", 72 | "got_sp_en_trans/neg", 73 | "got_larger_than/large", 74 | "got_larger_than/small", 75 | ] 76 | 77 | DATASET_FILTER_FNS: dict[DatasetFilterId, DatasetFilter] = { 78 | "got_cities/pos": ExactMatchFilter( 79 | "got_cities/pos", 80 | dataset_id="got_cities", 81 | answer_type="pos", 82 | ), 83 | "got_cities/neg": ExactMatchFilter( 84 | "got_cities/neg", 85 | dataset_id="got_cities", 86 | answer_type="neg", 87 | ), 88 | "got_sp_en_trans/pos": ExactMatchFilter( 89 | "got_sp_en_trans/pos", 90 | dataset_id="got_sp_en_trans", 91 | answer_type="pos", 92 | ), 93 | "got_sp_en_trans/neg": ExactMatchFilter( 94 | "got_sp_en_trans/neg", 95 | dataset_id="got_sp_en_trans", 96 | answer_type="neg", 97 | ), 98 | "got_larger_than/large": ExactMatchFilter( 99 | "got_larger_than/large", 100 | dataset_id="got_larger_than", 101 | answer_type="pos", 102 | ), 103 | "got_larger_than/small": ExactMatchFilter( 104 | "got_larger_than/small", 105 | dataset_id="got_larger_than", 106 | answer_type="neg", 107 | ), 108 | } 109 | -------------------------------------------------------------------------------- /repeng/datasets/activations/creation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from dotenv import load_dotenv 5 | from mppr import MContext, MDict 6 | from pydantic import BaseModel 7 | 8 | from repeng.activations.inference import get_model_activations 9 | 
from repeng.datasets.activations.types import ActivationResultRow 10 | from repeng.datasets.elk.types import BinaryRow, DatasetId 11 | from repeng.datasets.elk.utils.fns import get_dataset 12 | from repeng.datasets.elk.utils.limits import Limits, limit_groups 13 | from repeng.models.llms import LlmId 14 | from repeng.models.loading import load_llm_oioo 15 | 16 | assert load_dotenv() 17 | 18 | 19 | class _BinaryRowWithLlm(BinaryRow): 20 | llm_id: LlmId 21 | 22 | 23 | class _Dataset(BaseModel, extra="forbid"): 24 | rows: dict[str, BinaryRow] 25 | 26 | 27 | def create_activations_dataset( 28 | tag: str, 29 | llm_ids: list[LlmId], 30 | dataset_ids: list[DatasetId], 31 | group_limits: Limits, 32 | device: torch.device, 33 | num_tokens_from_end: int | None, 34 | layers_start: int | None, 35 | layers_end: int | None, 36 | layers_skip: int | None, 37 | ) -> list[ActivationResultRow]: 38 | mcontext = MContext(Path("output/create-activations-dataset")) 39 | dataset_ids_mdict: MDict[DatasetId] = mcontext.create( 40 | {dataset_id: dataset_id for dataset_id in dataset_ids}, 41 | ) 42 | inputs = ( 43 | dataset_ids_mdict.map_cached( 44 | "datasets", 45 | lambda _, dataset_id: _Dataset(rows=get_dataset(dataset_id)), 46 | to=_Dataset, 47 | ) 48 | .flat_map(lambda _, dataset: {key: row for key, row in dataset.rows.items()}) 49 | .filter(limit_groups(group_limits)) 50 | .flat_map( 51 | lambda key, row: { 52 | f"{key}-{llm_id}": _BinaryRowWithLlm(**row.model_dump(), llm_id=llm_id) 53 | for llm_id in llm_ids 54 | } 55 | ) 56 | # avoids constantly reloading models with OIOO 57 | .sort(lambda _, row: row.llm_id) 58 | ) 59 | return ( 60 | inputs.map_cached( 61 | "activations", 62 | fn=lambda _, value: get_model_activations( 63 | load_llm_oioo( 64 | value.llm_id, 65 | device=device, 66 | use_half_precision=True, 67 | ), 68 | text=value.text, 69 | last_n_tokens=num_tokens_from_end, 70 | points_start=layers_start, 71 | points_end=layers_end, 72 | points_skip=layers_skip, 73 | ), 74 | to="pickle", 75 | ) 76 | .join( 77 | inputs, 78 | lambda _, activations, input: ActivationResultRow( 79 | dataset_id=input.dataset_id, 80 | split=input.split, 81 | answer_type=input.answer_type, 82 | label=input.label, 83 | activations=activations.activations, 84 | prompt_logprobs=activations.token_logprobs.sum().item(), 85 | group_id=input.group_id, 86 | llm_id=input.llm_id, 87 | ), 88 | ) 89 | .upload( 90 | f"s3://repeng/datasets/activations/{tag}.pickle", 91 | to="pickle", 92 | ) 93 | ).get() 94 | -------------------------------------------------------------------------------- /repeng/probes/collections.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from repeng.activations.probe_preparations import ActivationArrays 4 | from repeng.probes.base import BaseProbe 5 | from repeng.probes.contrast_consistent_search import CcsTrainingConfig, train_ccs_probe 6 | from repeng.probes.difference_in_means import train_dim_probe 7 | from repeng.probes.linear_artificial_tomography import train_lat_probe 8 | from repeng.probes.linear_discriminant_analysis import train_lda_probe 9 | from repeng.probes.logistic_regression import ( 10 | LrConfig, 11 | train_grouped_lr_probe, 12 | train_lr_probe, 13 | ) 14 | from repeng.probes.principal_component_analysis import ( 15 | train_grouped_pca_probe, 16 | train_pca_probe, 17 | ) 18 | from repeng.probes.random import train_random_probe 19 | 20 | ProbeMethod = Literal["ccs", "lat", "dim", "lda", "lr", "lr-g", "pca", "pca-g", "rand"] 21 | 
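# A hedged sketch of how the dispatcher defined below is typically used (mirroring
# experiments/scratch/fake_probes.py; `arrays` is an ActivationArrays instance):
#
#     probes = {method: train_probe(method, arrays) for method in ALL_PROBES}
#     probes = {method: probe for method, probe in probes.items() if probe is not None}
#
# train_probe returns None when a grouped method is requested but the arrays carry no
# group information.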
22 | ALL_PROBES: list[ProbeMethod] = [ 23 | "ccs", 24 | "lat", 25 | "dim", 26 | "lda", 27 | "lr", 28 | "lr-g", 29 | "pca", 30 | "pca-g", 31 | ] 32 | SUPERVISED_PROBES: list[ProbeMethod] = ["dim", "lda", "lr", "lr-g"] 33 | UNSUPERVISED_PROBES: list[ProbeMethod] = list(set(ALL_PROBES) - set(SUPERVISED_PROBES)) 34 | GROUPED_PROBES: list[ProbeMethod] = ["ccs", "lr-g", "pca-g"] 35 | UNGROUPED_PROBES: list[ProbeMethod] = list(set(ALL_PROBES) - set(GROUPED_PROBES)) 36 | 37 | 38 | def train_probe( 39 | probe_method: ProbeMethod, arrays: ActivationArrays 40 | ) -> BaseProbe | None: 41 | if probe_method == "ccs": 42 | if arrays.groups is None: 43 | return None 44 | return train_ccs_probe( 45 | CcsTrainingConfig(), 46 | activations=arrays.activations, 47 | groups=arrays.groups, 48 | answer_types=arrays.answer_types, 49 | # N.B.: Technically unsupervised! 50 | labels=arrays.labels, 51 | ) 52 | elif probe_method == "lat": 53 | return train_lat_probe( 54 | activations=arrays.activations, 55 | answer_types=arrays.answer_types, 56 | ) 57 | elif probe_method == "dim": 58 | return train_dim_probe( 59 | activations=arrays.activations, 60 | labels=arrays.labels, 61 | ) 62 | elif probe_method == "lda": 63 | return train_lda_probe( 64 | activations=arrays.activations, 65 | labels=arrays.labels, 66 | ) 67 | elif probe_method == "lr": 68 | return train_lr_probe( 69 | LrConfig(), 70 | activations=arrays.activations, 71 | labels=arrays.labels, 72 | ) 73 | elif probe_method == "lr-g": 74 | if arrays.groups is None: 75 | return None 76 | return train_grouped_lr_probe( 77 | LrConfig(), 78 | activations=arrays.activations, 79 | groups=arrays.groups, 80 | labels=arrays.labels, 81 | ) 82 | elif probe_method == "pca": 83 | return train_pca_probe( 84 | activations=arrays.activations, 85 | answer_types=arrays.answer_types, 86 | ) 87 | elif probe_method == "pca-g": 88 | if arrays.groups is None: 89 | return None 90 | return train_grouped_pca_probe( 91 | activations=arrays.activations, 92 | groups=arrays.groups, 93 | answer_types=arrays.answer_types, 94 | ) 95 | elif probe_method == "rand": 96 | return train_random_probe( 97 | activations=arrays.activations, 98 | ) 99 | else: 100 | raise ValueError(f"Unknown probe_method: {probe_method}") 101 | -------------------------------------------------------------------------------- /experiments/scratch/fake_probes.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy as np 3 | import pandas as pd 4 | import plotly.express as px 5 | import plotly.graph_objects as go 6 | from jaxtyping import Float 7 | from tqdm import tqdm 8 | 9 | from repeng.activations.probe_preparations import ActivationArrays 10 | from repeng.probes.base import DotProductProbe 11 | from repeng.probes.collections import ProbeMethod, train_probe 12 | from repeng.probes.contrast_consistent_search import ( 13 | CcsProbe, 14 | CcsTrainingConfig, 15 | train_ccs_probe, 16 | ) 17 | from repeng.probes.linear_artificial_tomography import train_lat_probe 18 | from repeng.probes.logistic_regression import LogisticRegressionProbe, train_lr_probe 19 | 20 | # %% 21 | anisotropy_offset = np.array([0, 0], dtype=np.float32) 22 | dataset_direction = np.array([0, 2], dtype=np.float32) 23 | dataset_cov = np.array([[0.1, 0.2], [0.2, 1]]) 24 | truth_direction = np.array([2, 0]) 25 | truth_cov = np.array([[0.01, 0], [0, 0.01]]) 26 | num_samples = int(1e3) 27 | 28 | random_false = np.random.multivariate_normal( 29 | mean=anisotropy_offset + dataset_direction, cov=dataset_cov, 
size=num_samples 30 | ) 31 | random_true = random_false + np.random.multivariate_normal( 32 | mean=truth_direction, cov=truth_cov, size=num_samples 33 | ) 34 | 35 | df_1 = pd.DataFrame(random_true, columns=["x", "y"]) 36 | df_1["label"] = True 37 | df_1["pair_id"] = np.array(range(num_samples)) 38 | df_2 = pd.DataFrame(random_false, columns=["x", "y"]) 39 | df_2["label"] = False 40 | df_2["pair_id"] = np.array(range(num_samples)) 41 | df = pd.concat([df_1, df_2]) 42 | df["activations"] = df.apply(lambda row: np.array([row["x"], row["y"]]), axis=1) 43 | 44 | arrays = ActivationArrays( 45 | activations=np.stack(df["activations"]), # type: ignore 46 | labels=df["label"].to_numpy(), 47 | groups=df["pair_id"].to_numpy(), 48 | answer_types=None, 49 | ) 50 | 51 | # %% 52 | probe_methods: list[ProbeMethod] = [ 53 | "ccs", 54 | "lat", 55 | "dim", 56 | "lda", 57 | "lr", 58 | "lr-g", 59 | "pca", 60 | "pca-g", 61 | "rand", 62 | ] 63 | probes = { 64 | probe_method: train_probe( 65 | probe_method, 66 | arrays, 67 | ) 68 | for probe_method in tqdm(probe_methods) 69 | } 70 | 71 | # %% 72 | fig_start = -2 73 | fig_end = 6 74 | 75 | 76 | def plot_probe( 77 | label: str, 78 | fig: go.Figure, 79 | probe: Float[np.ndarray, "2"], 80 | intercept: float, 81 | ) -> None: 82 | print(probe, intercept) 83 | xs = np.array([fig_start, 0, fig_end]) 84 | ys = -(probe[1] / probe[0]) * xs - (intercept / probe[0]) 85 | # TODO: Why swapped? 86 | fig.add_trace( 87 | go.Scatter( 88 | x=ys, y=xs, mode="lines", name=label, line=dict(width=3), opacity=0.6 89 | ) 90 | ) 91 | 92 | 93 | fig = px.scatter(df, "x", "y", color="label", opacity=0.3) 94 | fig.update_layout( 95 | xaxis_range=[fig_start, fig_end], 96 | yaxis_range=[fig_start, fig_end], 97 | ) 98 | fig.update_yaxes(scaleanchor="x", scaleratio=1) 99 | for probe_method, probe in probes.items(): 100 | assert probe is not None 101 | if isinstance(probe, DotProductProbe): 102 | plot_probe(probe_method, fig, probe.probe, 0) 103 | elif isinstance(probe, LogisticRegressionProbe): 104 | plot_probe( 105 | probe_method, 106 | fig, 107 | probe.model.coef_[0], 108 | probe.model.intercept_[0], 109 | ) 110 | elif isinstance(probe, CcsProbe): 111 | plot_probe( 112 | probe_method, 113 | fig, 114 | probe.linear.weight.detach().numpy()[0], 115 | probe.linear.bias.detach().numpy()[0], 116 | ) 117 | else: 118 | raise ValueError(type(probe)) 119 | fig.show() 120 | -------------------------------------------------------------------------------- /experiments/dataset_analysis.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | import plotly.express as px 6 | from mppr import MContext, MDict 7 | from pydantic import BaseModel 8 | 9 | from repeng.datasets.elk.types import BinaryRow, DatasetId 10 | from repeng.datasets.elk.utils.collections import resolve_dataset_ids 11 | from repeng.datasets.elk.utils.fns import get_dataset 12 | 13 | 14 | # %% 15 | class Dataset(BaseModel, extra="forbid"): 16 | rows: dict[str, BinaryRow] 17 | 18 | 19 | mcontext = MContext(Path("../output/dataset_analysis")) 20 | dataset_ids: MDict[DatasetId] = mcontext.create( 21 | {dataset_id: dataset_id for dataset_id in resolve_dataset_ids("all")}, 22 | ) 23 | datasets = dataset_ids.map_cached( 24 | "datasets", 25 | lambda _, dataset_id: Dataset(rows=get_dataset(dataset_id)), 26 | to=Dataset, 27 | ).flat_map(lambda _, dataset: {key: row for key, row in dataset.rows.items()}) 28 | 29 | # %% 30 | groups: dict[DatasetId, str] = { 31 
| **{d: "dlk" for d in resolve_dataset_ids("dlk")}, 32 | **{d: "repe" for d in resolve_dataset_ids("repe")}, 33 | **{d: "repe" for d in resolve_dataset_ids("repe-qa")}, 34 | **{d: "got" for d in resolve_dataset_ids("got")}, 35 | } 36 | df = pd.DataFrame([row.model_dump() for row in datasets.get()]) 37 | df["word_counts"] = df["text"].apply(lambda row: len(row.split())) 38 | df["dataset_group"] = df["dataset_id"].apply(lambda x: groups.get(x, "misc")) 39 | df # type: ignore 40 | 41 | # %% 42 | px.bar( 43 | df.groupby(["dataset_id", "dataset_group", "split"]).size().reset_index(), 44 | title="Num rows by dataset & split", 45 | x="dataset_id", 46 | y=0, 47 | color="dataset_group", 48 | facet_row="split", 49 | log_y=True, 50 | height=1000, 51 | ) 52 | 53 | # %% 54 | px.bar( 55 | df.groupby(["dataset_id", "dataset_group", "split"])["group_id"] # type: ignore 56 | .nunique() 57 | .rename("num_groups") # type: ignore 58 | .reset_index(), 59 | title="Num groups by dataset & split", 60 | x="dataset_id", 61 | y="num_groups", 62 | color="dataset_group", 63 | facet_row="split", 64 | log_y=True, 65 | height=1000, 66 | ) 67 | 68 | # %% 69 | px.bar( 70 | df.groupby(["dataset_id", "split", "group"])["answer_type"] # type: ignore 71 | .nunique() 72 | .rename("num_groups") # type: ignore 73 | .reset_index(), 74 | title="Num answer types by dataset & split", 75 | x="dataset_id", 76 | y="num_groups", 77 | color="group", 78 | facet_row="split", 79 | height=1000, 80 | category_orders={"group": ["dlk", "repe", "got", "misc"]}, 81 | ) 82 | 83 | 84 | # %% 85 | px.bar( 86 | df.groupby("dataset_id")["word_counts"].mean().reset_index(), 87 | title="Average number of words per prompt by dataset", 88 | x="dataset_id", 89 | y="word_counts", 90 | color="dataset_id", 91 | log_y=True, 92 | ) 93 | 94 | # %% 95 | fig = px.bar( 96 | df.groupby("dataset_id")["is_true"].mean().reset_index(), 97 | title="Percent of true prompts by dataset", 98 | x="dataset_id", 99 | y="is_true", 100 | range_y=[0, 1], 101 | ) 102 | fig.add_hline(1 / 2, line_dash="dot", line_color="gray") 103 | fig.add_hline(1 / 3, line_dash="dot", line_color="gray") 104 | fig.add_hline(1 / 4, line_dash="dot", line_color="gray") 105 | fig.add_hline(1 / 5, line_dash="dot", line_color="gray") 106 | 107 | # %% 108 | for dataset_id in df["dataset_id"].unique(): 109 | row = df[df["dataset_id"] == dataset_id].sample(1) 110 | print("#", dataset_id) 111 | print("## text") 112 | print(row["text"].item()) 113 | print("## is true") 114 | print(row["is_true"].item()) 115 | print() 116 | 117 | # %% 118 | df["has_group_id"] = df["group_id"].apply(lambda x: x is not None) 119 | df.groupby("dataset_id")["has_group_id"].sum() 120 | -------------------------------------------------------------------------------- /repeng/datasets/elk/geometry_of_truth.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, cast 2 | 3 | import pandas as pd 4 | 5 | from repeng.datasets.elk.types import BinaryRow, DatasetId 6 | from repeng.datasets.utils.shuffles import deterministic_shuffle 7 | from repeng.datasets.utils.splits import split_to_all 8 | 9 | Subset = Literal[ 10 | "cities", 11 | "sp_en_trans", 12 | "larger_than", 13 | "cities_cities_conj", 14 | "cities_cities_disj", 15 | ] 16 | 17 | _URL = "https://raw.githubusercontent.com/saprmarks/geometry-of-truth/91b2232/datasets" 18 | 19 | 20 | def get_geometry_of_truth(subset: Subset) -> dict[str, BinaryRow]: 21 | if subset == "cities": 22 | return _get_paired( 23 | 
dataset_id="got_cities", 24 | csv_name_pos="cities", 25 | csv_name_neg="neg_cities", 26 | expected_identical_labels=["city", "country", "correct_country"], 27 | ) 28 | elif subset == "sp_en_trans": 29 | return _get_paired( 30 | dataset_id="got_sp_en_trans", 31 | csv_name_pos="sp_en_trans", 32 | csv_name_neg="neg_sp_en_trans", 33 | expected_identical_labels=[], 34 | ) 35 | elif subset == "larger_than": 36 | return _get_paired( 37 | dataset_id="got_larger_than", 38 | csv_name_pos="larger_than", 39 | csv_name_neg="smaller_than", 40 | expected_identical_labels=[], 41 | ) 42 | elif subset == "cities_cities_conj": 43 | return _get_unpaired( 44 | dataset_id="got_cities_cities_conj", 45 | csv_name="cities_cities_conj", 46 | ) 47 | elif subset == "cities_cities_disj": 48 | return _get_unpaired( 49 | dataset_id="got_cities_cities_disj", 50 | csv_name="cities_cities_disj", 51 | ) 52 | 53 | 54 | def _get_paired( 55 | dataset_id: DatasetId, 56 | *, 57 | csv_name_pos: str, 58 | csv_name_neg: str, 59 | expected_identical_labels: list[str], 60 | ) -> dict[str, BinaryRow]: 61 | result = {} 62 | csv_pos = _get_csv(csv_name_pos) 63 | csv_neg = _get_csv(csv_name_neg) 64 | assert len(csv_pos) == len(csv_neg) 65 | for index in deterministic_shuffle(list(range(len(csv_pos))), key=str): 66 | row_pos = csv_pos.iloc[index] 67 | row_neg = csv_neg.iloc[index] 68 | assert all(row_pos[x] == row_neg[x] for x in expected_identical_labels), ( 69 | row_pos, 70 | row_neg, 71 | ) 72 | assert row_pos["label"] != row_neg["label"], (row_pos, row_neg) 73 | for answer_type, row in [("pos", row_pos), ("neg", row_neg)]: 74 | result[f"{dataset_id}-{index}-{answer_type}"] = BinaryRow( 75 | dataset_id=dataset_id, 76 | group_id=str(index), 77 | split=split_to_all(dataset_id, str(index)), 78 | text=cast(str, row["statement"]), 79 | label=row["label"] == 1, 80 | format_args=dict(), 81 | answer_type=answer_type, 82 | ) 83 | return result 84 | 85 | 86 | def _get_unpaired( 87 | dataset_id: DatasetId, 88 | csv_name: str, 89 | ) -> dict[str, BinaryRow]: 90 | result = {} 91 | df = _get_csv(csv_name) 92 | for index, row in deterministic_shuffle(df.iterrows(), lambda row: str(row[0])): 93 | assert isinstance(index, int) 94 | result[f"{dataset_id}-{index}"] = BinaryRow( 95 | dataset_id=dataset_id, 96 | split=split_to_all(dataset_id, str(index)), 97 | text=cast(str, row["statement"]), 98 | label=row["label"] == 1, 99 | format_args=dict(), 100 | ) 101 | return result 102 | 103 | 104 | def _get_csv(csv_name: str) -> pd.DataFrame: 105 | return pd.read_csv(f"{_URL}/{csv_name}.csv") 106 | -------------------------------------------------------------------------------- /experiments/scratch/zeroshot_investigation.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | 4 | import plotly.express as px 5 | import torch 6 | from dotenv import load_dotenv 7 | from mppr import MContext 8 | from transformers import AutoTokenizer 9 | 10 | from repeng.activations.inference import get_model_activations 11 | from repeng.activations.probe_preparations import ActivationArrayDataset 12 | from repeng.datasets.activations.types import ActivationResultRow 13 | from repeng.datasets.elk.types import BinaryRow 14 | from repeng.datasets.elk.utils.filters import DATASET_FILTER_FNS, DatasetIdFilter 15 | from repeng.datasets.elk.utils.fns import get_dataset 16 | from repeng.datasets.elk.utils.limits import Limits, SplitLimits, limit_groups 17 | from repeng.evals.logits import eval_logits_by_question 18 | 
from repeng.models.loading import load_llm_oioo 19 | 20 | assert load_dotenv(".env") 21 | 22 | # %% 23 | llm = load_llm_oioo( 24 | llm_id="Llama-2-7b-hf", 25 | device=torch.device("cuda"), 26 | use_half_precision=False, 27 | ) 28 | 29 | # %% 30 | dataset = get_dataset("boolq") 31 | # prompt_format = "[INST] {text} [/INST]" 32 | # dataset = { 33 | # key: value.model_copy( 34 | # update=dict( 35 | # text=prompt_format.format(text=value.text), 36 | # ) 37 | # ) 38 | # for key, value in dataset.items() 39 | # } 40 | list(dataset.values())[0] 41 | 42 | # %% 43 | mcontext = MContext(Path("output")) 44 | inputs = ( 45 | mcontext.create(dataset) 46 | # .filter( 47 | # lambda _, row: DATASET_FILTER_FNS["got_cities/pos"].filter( 48 | # row.dataset_id, row.answer_type 49 | # ), 50 | # ) 51 | .filter( 52 | limit_groups( 53 | Limits( 54 | default=SplitLimits(train=0, train_hparams=0, validation=200), 55 | by_dataset={}, 56 | ) 57 | ) 58 | ) 59 | ) 60 | outputs = inputs.map_cached( 61 | "activations-v21", 62 | lambda _, row: get_model_activations( 63 | llm, 64 | text=row.text, 65 | last_n_tokens=1, 66 | points_start=None, 67 | points_end=None, 68 | points_skip=None, 69 | ), 70 | to="pickle", 71 | ) 72 | results = outputs.join( 73 | inputs, 74 | lambda _, activations, row: ActivationResultRow( 75 | dataset_id=row.dataset_id, 76 | group_id=row.group_id, 77 | answer_type=row.answer_type, 78 | activations={}, 79 | prompt_logprobs=activations.token_logprobs.sum().item(), 80 | label=row.is_true, 81 | split=row.split, 82 | llm_id="Llama-2-7b-hf", 83 | ), 84 | ) 85 | array_dataset = ActivationArrayDataset(results.get()) 86 | 87 | # %% 88 | arrays = array_dataset.get( 89 | llm_id="Llama-2-7b-hf", 90 | dataset_filter=DatasetIdFilter("boolq"), 91 | split="validation", 92 | point_name="logprobs", 93 | token_idx=-1, 94 | limit=None, 95 | ) 96 | assert arrays.groups is not None 97 | eval_logits_by_question( 98 | logits=arrays.activations, 99 | labels=arrays.labels, 100 | groups=arrays.groups, 101 | ) 102 | 103 | # px.histogram( 104 | # df, 105 | # x="logprobs", 106 | # color="is_true", 107 | # barmode="overlay", 108 | # opacity=0.5, 109 | # ) 110 | 111 | # # %% 112 | # # tokenizer = AutoTokenizer.from_pretrained(f"google/gemma-7b") 113 | # tokenizer = AutoTokenizer.from_pretrained(f"meta-llama/Llama-2-7b-chat-hf") 114 | # text = "test text" 115 | # tokens = tokenizer.encode(text) 116 | # tokens_str = tokenizer.tokenize(text) 117 | # tokens = tokenizer.convert_tokens_to_ids(tokens_str) 118 | # tokens = torch.tensor([tokens]) 119 | 120 | # tokens_new = tokenizer.encode(text, return_tensors="pt") 121 | # tokens_str_new = tokenizer.convert_ids_to_tokens(tokens_new.squeeze().tolist()) 122 | 123 | # print(tokens) 124 | # print(tokens_new) 125 | 126 | # print(tokenizer.decode(tokens.squeeze())) 127 | # print(tokenizer.decode(tokens_new.squeeze())) 128 | 129 | # print(tokens_str) 130 | # print(tokens_str_new) 131 | -------------------------------------------------------------------------------- /repeng/probes/logistic_regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logistic regression based probes. 3 | 4 | Methodology for ungrouped logistic regression: 5 | 1. Given a set of activations, and whether each activation is from a true or false 6 | statement. 7 | 2. Fit a linear probe using scikit's LogisticRegression implementation. The probe takes 8 | in the activations and predicts the label. 9 | 10 | Methodology for grouped logistic regression: 11 | 1. 
Given a set of activations, and whether each activation is from a true or false 12 | statement. 13 | 2. Subtract the average activation of each group from each group member. 14 | 2. Fit a linear probe using scikit's LogisticRegression implementation. The probe takes 15 | in the group-normalized activations and predicts the label. 16 | 17 | Regularization: C=1. 18 | """ 19 | 20 | from dataclasses import dataclass 21 | 22 | import numpy as np 23 | import pandas as pd 24 | from jaxtyping import Bool, Float, Int64 25 | from sklearn.linear_model import LogisticRegression 26 | from typing_extensions import override 27 | 28 | from repeng.probes.base import BaseGroupedProbe, BaseProbe, PredictResult 29 | 30 | 31 | @dataclass 32 | class LrConfig: 33 | c: float = 1.0 34 | # We go for newton-cg as we've found it to be the fastest, see 35 | # experiments/scratch/lr_speed.py. 36 | solver: str = "newton-cg" 37 | max_iter: int = 10_000 38 | 39 | 40 | @dataclass 41 | class LogisticRegressionProbe(BaseProbe): 42 | model: LogisticRegression 43 | 44 | @override 45 | def predict( 46 | self, 47 | activations: Float[np.ndarray, "n d"], # noqa: F722 48 | ) -> PredictResult: 49 | logits = self.model.decision_function(activations) 50 | return PredictResult(logits=logits) 51 | 52 | 53 | @dataclass 54 | class LogisticRegressionGroupedProbe(BaseGroupedProbe, LogisticRegressionProbe): 55 | model: LogisticRegression 56 | 57 | @override 58 | def predict_grouped( 59 | self, 60 | activations: Float[np.ndarray, "n d"], # noqa: F722 61 | pairs: Int64[np.ndarray, "n"], # noqa: F821 62 | ) -> PredictResult: 63 | activations_centered = _center_pairs(activations, pairs) 64 | logits = self.model.decision_function(activations_centered) 65 | return PredictResult(logits=logits) 66 | 67 | 68 | def train_lr_probe( 69 | config: LrConfig, 70 | *, 71 | activations: Float[np.ndarray, "n d"], # noqa: F722 72 | labels: Bool[np.ndarray, "n"], # noqa: F821 73 | ) -> LogisticRegressionProbe: 74 | model = LogisticRegression( 75 | fit_intercept=True, 76 | solver=config.solver, 77 | C=config.c, 78 | max_iter=config.max_iter, 79 | ) 80 | model.fit(activations, labels) 81 | return LogisticRegressionProbe(model) 82 | 83 | 84 | def train_grouped_lr_probe( 85 | config: LrConfig, 86 | *, 87 | activations: Float[np.ndarray, "n d"], # noqa: F722 88 | groups: Int64[np.ndarray, "n d"], # noqa: F722 89 | labels: Bool[np.ndarray, "n"], # noqa: F821 90 | ) -> LogisticRegressionGroupedProbe: 91 | probe = train_lr_probe( 92 | config, 93 | activations=_center_pairs(activations, groups), 94 | labels=labels, 95 | ) 96 | return LogisticRegressionGroupedProbe(model=probe.model) 97 | 98 | 99 | # TODO: Double check this preserves order. 
100 | def _center_pairs( 101 | activations: Float[np.ndarray, "n d"], # noqa: F722 102 | pairs: Int64[np.ndarray, "n"], # noqa: F821 103 | ) -> Float[np.ndarray, "n d"]: # noqa: F722 104 | df = pd.DataFrame( 105 | { 106 | "activations": list(activations), 107 | "pairs": pairs, 108 | } 109 | ) 110 | pair_means = ( 111 | df.groupby(["pairs"])["activations"] 112 | .apply(lambda a: np.mean(a, axis=0)) 113 | .rename("pair_mean") # type: ignore 114 | ) 115 | df = df.join(pair_means, on="pairs") 116 | df["activations"] = df["activations"] - df["pair_mean"] 117 | return np.stack(df["activations"].to_list()) 118 | -------------------------------------------------------------------------------- /repeng/models/points.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast, get_args, overload 2 | 3 | from transformers import ( 4 | GemmaForCausalLM, 5 | GPT2LMHeadModel, 6 | GPTNeoXForCausalLM, 7 | LlamaForCausalLM, 8 | MistralForCausalLM, 9 | ) 10 | 11 | from repeng.hooks.points import Point, TupleTensorExtractor 12 | from repeng.models.types import ( 13 | PYTHIA_DPO_TO_PYTHIA, 14 | GemmaId, 15 | Gpt2Id, 16 | Llama2Id, 17 | LlmId, 18 | MistralId, 19 | PythiaDpoId, 20 | PythiaId, 21 | ) 22 | 23 | _GPT2_NUM_LAYERS = 12 24 | _PYTHIA_NUM_LAYERS: dict[PythiaId, int] = { 25 | "pythia-70m": 6, 26 | "pythia-160m": 12, 27 | "pythia-410m": 24, 28 | "pythia-1b": 16, 29 | "pythia-1.4b": 24, 30 | "pythia-2.8b": 32, 31 | "pythia-6.9b": 32, 32 | "pythia-12b": 36, 33 | } 34 | _LLAMA2_NUM_LAYERS: dict[Llama2Id, int] = { 35 | "Llama-2-7b-hf": 32, 36 | "Llama-2-7b-chat-hf": 32, 37 | "Llama-2-13b-hf": 40, 38 | "Llama-2-13b-chat-hf": 40, 39 | "Llama-2-70b-hf": 80, 40 | "Llama-2-70b-chat-hf": 80, 41 | } 42 | _MISTRAL_NUM_LAYERS: dict[MistralId, int] = { 43 | "Mistral-7B": 32, 44 | "Mistral-7B-Instruct": 32, 45 | } 46 | _GEMMA_NUM_LAYERS: dict[GemmaId, int] = { 47 | "gemma-2b": 18, 48 | "gemma-2b-it": 18, 49 | "gemma-7b": 28, 50 | "gemma-7b-it": 28, 51 | } 52 | 53 | 54 | @overload 55 | def get_points(llm_id: PythiaId | PythiaDpoId) -> list[Point[GPTNeoXForCausalLM]]: 56 | ... 57 | 58 | 59 | @overload 60 | def get_points(llm_id: Gpt2Id) -> list[Point[GPT2LMHeadModel]]: 61 | ... 62 | 63 | 64 | @overload 65 | def get_points(llm_id: Llama2Id) -> list[Point[LlamaForCausalLM]]: 66 | ... 67 | 68 | 69 | @overload 70 | def get_points(llm_id: MistralId) -> list[Point[MistralForCausalLM]]: 71 | ... 72 | 73 | 74 | @overload 75 | def get_points(llm_id: GemmaId) -> list[Point[GemmaForCausalLM]]: 76 | ... 
    ...
77 | 78 | 79 | def get_points(llm_id: LlmId) -> list[Point[Any]]: 80 | if llm_id in get_args(PythiaId): 81 | return pythia(cast(PythiaId, llm_id)) 82 | elif llm_id in get_args(Gpt2Id): 83 | return gpt2() 84 | elif llm_id in get_args(Llama2Id): 85 | return llama2(cast(Llama2Id, llm_id)) 86 | elif llm_id in get_args(PythiaDpoId): 87 | return pythia(PYTHIA_DPO_TO_PYTHIA[cast(PythiaDpoId, llm_id)]) 88 | elif llm_id in get_args(MistralId): 89 | return mistral(cast(MistralId, llm_id)) 90 | elif llm_id in get_args(GemmaId): 91 | return gemma(cast(GemmaId, llm_id)) 92 | else: 93 | raise ValueError(f"Unknown LLM ID: {llm_id}") 94 | 95 | 96 | def gpt2() -> list[Point[GPT2LMHeadModel]]: 97 | return [ 98 | Point( 99 | f"h{i}", 100 | lambda model, i=i: model.transformer.h[i], 101 | tensor_extractor=TupleTensorExtractor(0), 102 | ) 103 | for i in range(_GPT2_NUM_LAYERS) 104 | ] 105 | 106 | 107 | def pythia(pythia_id: PythiaId) -> list[Point[GPTNeoXForCausalLM]]: 108 | return [ 109 | Point( 110 | f"h{i}", 111 | lambda model, i=i: model.gpt_neox.layers[i], 112 | tensor_extractor=TupleTensorExtractor(0), 113 | ) 114 | for i in range(_PYTHIA_NUM_LAYERS[pythia_id]) 115 | ] 116 | 117 | 118 | def llama2(llama2_id: Llama2Id) -> list[Point[LlamaForCausalLM]]: 119 | return [ 120 | Point( 121 | f"h{i}", 122 | lambda model, i=i: model.model.layers[i], 123 | tensor_extractor=TupleTensorExtractor(0), 124 | ) 125 | for i in range(_LLAMA2_NUM_LAYERS[llama2_id]) 126 | ] 127 | 128 | 129 | def mistral(mistral_id: MistralId) -> list[Point[MistralForCausalLM]]: 130 | return [ 131 | Point( 132 | f"h{i}", 133 | lambda model, i=i: model.model.layers[i], 134 | tensor_extractor=TupleTensorExtractor(0), 135 | ) 136 | for i in range(_MISTRAL_NUM_LAYERS[mistral_id]) 137 | ] 138 | 139 | 140 | def gemma(gemma_id: GemmaId) -> list[Point[GemmaForCausalLM]]: 141 | return [ 142 | Point( 143 | f"h{i}", 144 | lambda model, i=i: model.model.layers[i], 145 | tensor_extractor=TupleTensorExtractor(0), 146 | ) 147 | for i in range(_GEMMA_NUM_LAYERS[gemma_id]) 148 | ] 149 | -------------------------------------------------------------------------------- /repeng/probes/contrast_consistent_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replication of the CCS probes described in . 3 | 4 | Methodology: 5 | 1. Given a set of activations for questions and answer pairs. 6 | 2. For each question, take an arbitrary true and false question answer pair. 7 | 3. Optimize a linear probe to discover what contrasts these true and false pairs. 8 | a. The probe should be consistent. The probability of one statement should be equal to 9 | the opposite of the probability of the other pair: 10 | p(x) = 1 - p(y). 11 | b. The probe should be confident. The probability shouldn't equal 0.5: 12 | p(x) != p(y) != 0.5. 13 | 14 | N.B.: We sometimes find that the probe does't converge. We retry the optimization twice. 15 | 16 | Regularization: L2 regularization on the probe weights. 
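For reference, the objective optimised by train_ccs_probe below can be written as:

    loss_consistency = mean_i (p(x_i) - (1 - p(y_i)))^2
    loss_confidence  = mean_i min(p(x_i), p(y_i))^2
    loss             = loss_consistency + loss_confidence
                       + weight_decay * (sum of squared probe parameters)

where (x_i, y_i) are the activations of a true/false answer pair for question i and
p is the sigmoid output of the linear probe.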
17 | """ 18 | 19 | from dataclasses import dataclass 20 | 21 | import numpy as np 22 | import torch 23 | from jaxtyping import Bool, Float, Int64 24 | from typing_extensions import override 25 | 26 | from repeng.probes.base import BaseProbe, PredictResult 27 | from repeng.probes.normalization import normalize_by_group 28 | 29 | 30 | class CcsProbe(torch.nn.Module, BaseProbe): 31 | def __init__(self, hidden_dim: int): 32 | super().__init__() 33 | self.linear = torch.nn.Linear(hidden_dim, 1) 34 | 35 | def forward( 36 | self, 37 | activations: Float[torch.Tensor, "n d"], # noqa: F722 38 | ) -> Float[torch.Tensor, "n"]: # noqa: F821 39 | result = self.linear(activations) 40 | result = torch.nn.functional.sigmoid(result) 41 | result = result.squeeze(-1) 42 | return result 43 | 44 | @torch.inference_mode() 45 | @override 46 | def predict( 47 | self, 48 | activations: Float[np.ndarray, "n d"], # noqa: F722 49 | ) -> PredictResult: 50 | probabilities = self(torch.tensor(activations, dtype=torch.float32)).numpy() 51 | return PredictResult(logits=probabilities) 52 | 53 | 54 | @dataclass 55 | class CcsTrainingConfig: 56 | num_steps: int = 100 57 | lr: float = 0.001 58 | weight_decay: float = 0.01 59 | 60 | 61 | def train_ccs_probe( 62 | config: CcsTrainingConfig, 63 | *, 64 | activations: Float[np.ndarray, "n d"], # noqa: F722 65 | groups: Int64[np.ndarray, "n d"], # noqa: F722 66 | answer_types: Int64[np.ndarray, "n d"] | None, # noqa: F722 67 | # Although CCS is technically unsupervised, we need the labels for multiple-choice 68 | # questions so that we can reduce answers into a true/false pair. 69 | labels: Bool[np.ndarray, "n"], # noqa: F821 70 | attempts: int = 2, 71 | ) -> CcsProbe: 72 | if answer_types is not None: 73 | activations = normalize_by_group(activations, answer_types) 74 | 75 | activations_1 = [] 76 | activations_2 = [] 77 | for group in np.unique(groups): 78 | group_activations = activations[groups == group] 79 | group_labels = labels[groups == group] 80 | if True not in group_labels or False not in group_labels: 81 | # This can happen when we truncate the dataset along a question boundary. 82 | continue 83 | # Get the first true and first false rows. 
84 | indices = sorted( 85 | [group_labels.tolist().index(True), group_labels.tolist().index(False)] 86 | ) 87 | activations_1.append(group_activations[indices[0]]) 88 | activations_2.append(group_activations[indices[1]]) 89 | activations_1 = torch.tensor(np.array(activations_1)).to(dtype=torch.float32) 90 | activations_2 = torch.tensor(np.array(activations_2)).to(dtype=torch.float32) 91 | 92 | _, hidden_dim = activations_1.shape 93 | probe = CcsProbe(hidden_dim=hidden_dim) 94 | optimizer = torch.optim.LBFGS(probe.parameters(), lr=0.1, max_iter=100) 95 | 96 | def get_loss(): 97 | optimizer.zero_grad() 98 | probs_1: torch.Tensor = probe(activations_1) 99 | probs_2: torch.Tensor = probe(activations_2) 100 | loss_consistency = (probs_1 - (1 - probs_2)).pow(2).mean() 101 | loss_confidence = torch.min(probs_1, probs_2).pow(2).mean() 102 | loss_l2 = ( 103 | sum(param.norm() ** 2 for param in probe.parameters()) * config.weight_decay 104 | ) 105 | loss = loss_consistency + loss_confidence + loss_l2 106 | loss.backward() 107 | return loss.item() 108 | 109 | optimizer.step(closure=get_loss) 110 | loss = get_loss() 111 | if loss > 0.2: 112 | if attempts > 0: 113 | return train_ccs_probe( 114 | config=config, 115 | activations=activations, 116 | groups=groups, 117 | answer_types=answer_types, 118 | labels=labels, 119 | attempts=attempts - 1, 120 | ) 121 | raise ValueError(f"CCS probe did not converge: {loss} >= 0.2") 122 | return probe.eval() 123 | -------------------------------------------------------------------------------- /experiments/scratch/lat_test.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import seaborn as sns 6 | import torch 7 | from mppr import MContext 8 | from sklearn.decomposition import PCA 9 | 10 | from repeng.activations.inference import get_model_activations 11 | from repeng.activations.probe_preparations import ActivationArrayDataset 12 | from repeng.datasets.activations.types import ActivationResultRow 13 | from repeng.datasets.elk.types import BinaryRow, DatasetId 14 | from repeng.datasets.elk.utils.collections import get_datasets 15 | from repeng.datasets.elk.utils.limits import limit_dataset_and_split_fn 16 | from repeng.evals.logits import eval_logits_by_question 17 | from repeng.models.loading import load_llm_oioo 18 | from repeng.models.types import LlmId 19 | 20 | mcontext = MContext(Path("../output/comparison")) 21 | 22 | # %% 23 | # activations_dataset: list[ActivationResultRow] = mcontext.download_cached( 24 | # "activations_dataset-v3", 25 | # path=( 26 | # "s3://repeng/datasets/activations/" 27 | # # "datasets_2024-02-02_tokensandlayers_v1.pickle" 28 | # # "datasets_2024-02-03_v1.pickle" 29 | # "datasets_2024-02-05_v2.pickle" 30 | # ), 31 | # to="pickle", 32 | # ).get() 33 | # print(set(row.llm_id for row in activations_dataset)) 34 | # print(set(row.dataset_id for row in activations_dataset)) 35 | # print(set(row.split for row in activations_dataset)) 36 | # dataset = ActivationArrayDataset(activations_dataset) 37 | 38 | 39 | # %% 40 | @dataclass 41 | class InputSpec: 42 | row: BinaryRow 43 | llm_id: LlmId 44 | 45 | 46 | llm_ids: list[LlmId] = [ 47 | "Llama-2-7b-hf", 48 | "Llama-2-7b-chat-hf", 49 | ] 50 | input_specs = ( 51 | mcontext.create_cached( 52 | "dataset", 53 | lambda: get_datasets(["arc_easy"]), 54 | to=BinaryRow, 55 | ) 56 | .filter( 57 | limit_dataset_and_split_fn(train_limit=800, validation_limit=200), 58 | ) 59 | .flat_map( 60 | 
lambda key, row: { 61 | f"{key}-{llm_id}": InputSpec(row=row, llm_id=llm_id) for llm_id in llm_ids 62 | } 63 | ) 64 | .sort(lambda _, row: llm_ids.index(row.llm_id)) 65 | ) 66 | activations = input_specs.map_cached( 67 | "test-activations", 68 | lambda _, row: get_model_activations( 69 | load_llm_oioo( 70 | row.llm_id, 71 | device=torch.device("cuda"), 72 | use_half_precision=True, 73 | ), 74 | text=row.row.text, 75 | last_n_tokens=1, 76 | ), 77 | to="pickle", 78 | ) 79 | dataset = ActivationArrayDataset( 80 | activations.join( 81 | input_specs, 82 | lambda _, row, spec: ActivationResultRow( 83 | llm_id=spec.llm_id, 84 | dataset_id=spec.row.dataset_id, 85 | split=spec.row.split, 86 | group_id=spec.row.group_id, 87 | template_name=spec.row.template_name, 88 | answer_type=spec.row.answer_type, 89 | label=spec.row.is_true, 90 | activations=row.activations, 91 | prompt_logprobs=row.token_logprobs.sum(), 92 | ), 93 | ).get() 94 | ) 95 | 96 | 97 | # %% 98 | dataset_id: DatasetId = "arc_easy" 99 | for llm_id in llm_ids: 100 | # for point in get_points(llm_id)[::3]: 101 | # print(llm_id, point.name) 102 | # arrays = dataset.get( 103 | # llm_id=llm_id, 104 | # dataset_filter_id=dataset_id, 105 | # split="train", 106 | # point_name=point.name, 107 | # token_idx=0, 108 | # limit=None, 109 | # ) 110 | # assert arrays.groups is not None 111 | # probe = train_lat_probe( 112 | # activations=arrays.activations, 113 | # groups=arrays.groups, 114 | # ) 115 | 116 | # arrays_val = dataset.get( 117 | # llm_id=llm_id, 118 | # dataset_filter_id=dataset_id, 119 | # split="validation", 120 | # point_name=point.name, 121 | # token_idx=0, 122 | # limit=None, 123 | # ) 124 | # assert arrays_val.groups is not None 125 | # print( 126 | # eval_probe_by_question( 127 | # probe=probe, 128 | # activations=arrays_val.activations, 129 | # groups=arrays_val.groups, 130 | # labels=arrays_val.labels, 131 | # ) 132 | # ) 133 | 134 | arrays_val = dataset.get( 135 | llm_id=llm_id, 136 | dataset_filter_id=dataset_id, 137 | split="validation", 138 | point_name="logprobs", 139 | token_idx=0, 140 | limit=None, 141 | ) 142 | assert arrays_val.groups is not None 143 | print( 144 | eval_logits_by_question( 145 | logits=arrays_val.activations, 146 | groups=arrays_val.groups, 147 | labels=arrays_val.labels, 148 | ) 149 | ) 150 | 151 | # %% 152 | pca = PCA(2) 153 | component_values = pca.fit_transform(arrays.activations) 154 | 155 | # %% 156 | sns.scatterplot( 157 | x=component_values[:, 0], 158 | y=component_values[:, 1], 159 | hue=arrays.labels, 160 | ) 161 | -------------------------------------------------------------------------------- /experiments/scratch/repe_repro.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # %% 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import seaborn as sns 10 | import torch 11 | from mppr import mppr 12 | from sklearn.decomposition import PCA 13 | from torchtyping import TensorType 14 | 15 | from repeng import models 16 | from repeng.activations import ActivationRow, get_activations 17 | from repeng.datasets.elk.true_false import TrueFalseRow, get_true_false_dataset 18 | from repeng.hooks.grab import grab 19 | from repeng.hooks.patch import patch 20 | 21 | # %% 22 | device = torch.device("cuda") 23 | 24 | # %% 25 | # model, tokenizer, points = models.gpt2() 26 | # model, tokenizer, points = models.pythia("1b") 27 | model, tokenizer, points = 
models.llama2_13b("7b", chat=True) 28 | 29 | # %% 30 | print(set(p.device for p in model.parameters())) 31 | print(set(p.dtype for p in model.parameters())) 32 | 33 | # %% 34 | # model = model.to(dtype=torch.bfloat16) 35 | model = model.to(device=device) 36 | 37 | 38 | # %% 39 | torch.set_grad_enabled(False) 40 | 41 | # %% 42 | with grab(model, points[-1]) as get_activation: 43 | model.forward(torch.tensor([[0]], device=device)) 44 | print(get_activation().shape) 45 | 46 | 47 | # %% 48 | @dataclass 49 | class Row: 50 | input: TrueFalseRow 51 | activation: ActivationRow 52 | 53 | 54 | def format_input(row: TrueFalseRow) -> str: 55 | role = "an honest" if row.is_true else "a dishonest" 56 | return ( 57 | f"USER: Pretend you're {role} person making statements about the world.\n" 58 | f"ASSISTANT: {row.statement}" 59 | ) 60 | 61 | 62 | input = mppr.init( 63 | "initial", 64 | base_dir=Path("output/repe_repro_llama2_7b"), 65 | init_fn=get_true_false_dataset, 66 | to=TrueFalseRow, 67 | ).limit( 68 | 500, 69 | ) 70 | 71 | df = ( 72 | input.map( 73 | "format", 74 | lambda _, row: format_input(row), 75 | to="pickle", 76 | ) 77 | .map( 78 | "activations", 79 | lambda _, text: get_activations( 80 | model=model, 81 | tokenizer=tokenizer, 82 | points=points, 83 | text=text, 84 | ), 85 | to="pickle", 86 | ) 87 | .join( 88 | input, 89 | lambda _, activation, input: Row(input, activation), 90 | ) 91 | .to_dataframe( 92 | lambda row: dict( 93 | statement=row.input.statement, 94 | is_true=row.input.is_true, 95 | activation=row.activation.activations[points[-1].name], 96 | logprobs=row.activation.token_logprobs, 97 | ) 98 | ) 99 | ) 100 | 101 | # %% 102 | activations_truth = np.mean(df[df["is_true"]]["activation"].tolist(), axis=0) 103 | activations_falsehoods = np.mean(df[~df["is_true"]]["activation"].tolist(), axis=0) 104 | truth_activation = torch.tensor(activations_truth - activations_falsehoods) 105 | print( 106 | np.linalg.norm(activations_truth), 107 | np.linalg.norm(activations_falsehoods), 108 | truth_activation.norm().item(), 109 | ) 110 | 111 | # %% 112 | activations = np.stack(df["activation"].tolist()) 113 | num_pairs = 100_000 114 | indices_1 = np.random.randint(low=0, high=df.shape[0], size=num_pairs) 115 | indices_2 = np.random.randint(low=0, high=df.shape[0], size=num_pairs) 116 | activation_diffs = activations[indices_1] - activations[indices_2] 117 | activation_diffs = (activation_diffs - np.mean(activation_diffs, axis=0)) / np.std( 118 | activation_diffs, axis=0 119 | ) 120 | 121 | pca = PCA(n_components=4) 122 | pca.fit_transform(activation_diffs) 123 | for i in range(4): 124 | df[f"pca_{i}"] = df["activation"].apply( 125 | lambda a: (pca.components_[i] @ a).item(), # type: ignore 126 | ) 127 | 128 | # %% 129 | sns.scatterplot(data=df, x="pca_0", y="pca_1", hue="is_true") 130 | plt.show() 131 | sns.scatterplot(data=df, x="pca_1", y="pca_2", hue="is_true") 132 | plt.show() 133 | sns.scatterplot(data=df, x="pca_2", y="pca_3", hue="is_true") 134 | plt.show() 135 | 136 | # %% 137 | batch, seq, hidden = None, None, None 138 | 139 | 140 | def patch_fn( 141 | tensor: TensorType["batch", "seq", "hidden"], 142 | ) -> TensorType["batch", "seq", "hidden"]: 143 | # tensor[:, :] -= truth_activation.to(device=device) 144 | tensor[:, :] -= torch.tensor(pca.components_[0], device=device) 145 | return tensor 146 | 147 | 148 | def generate(): 149 | input = tokenizer.encode( 150 | ( 151 | # "Q: What city is the Eiffel Tower in? A: Paris\n" 152 | # "Q: What city is Westminster in? 
A: London\n" 153 | # "Q: What city is the Colosseum in? A:" 154 | # "Q: What's the capital of France? A: Paris\n" 155 | # "Q: What's the capital of Germany? A: Berlin\n" 156 | "Q: What's the capital of Spain? A:" 157 | ), 158 | return_tensors="pt", 159 | ).to(device=device) 160 | assert isinstance(input, torch.Tensor) 161 | output = model.generate(input, max_new_tokens=5) 162 | assert isinstance(output, torch.Tensor) 163 | output = output.squeeze(0) 164 | print(repr(tokenizer.decode(output[input.shape[1] :]))) 165 | 166 | 167 | generate() 168 | with patch(model, points[-1], patch_fn): 169 | generate() 170 | -------------------------------------------------------------------------------- /experiments/truthful_model_written_evals.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from pathlib import Path 3 | from typing import Callable 4 | 5 | import matplotlib.pyplot as plt 6 | import openai 7 | import pandas as pd 8 | import seaborn as sns 9 | from dotenv import load_dotenv 10 | from mppr import MContext 11 | from pydantic import BaseModel 12 | 13 | from repeng.datasets.modelwritten.filtering import ( 14 | FilteringConfig, 15 | StatementLikelihood, 16 | get_statement_likelihood, 17 | ) 18 | from repeng.datasets.modelwritten.generation import ( 19 | GenerationConfig, 20 | Statements, 21 | generate_statements, 22 | ) 23 | 24 | NUM_HONEST_GENERATIONS = 100 25 | NUM_DISHONEST_GENERATIONS = 100 26 | GENERATION_CONFIG = GenerationConfig() 27 | FILTERING_CONFIG = FilteringConfig() 28 | 29 | 30 | class Statement(BaseModel, extra="forbid"): 31 | statement: str 32 | honest: bool 33 | 34 | 35 | class TruthfulModelWrittenEvalRow(BaseModel, extra="forbid"): 36 | statement: str 37 | honest: bool 38 | likelihood: float 39 | yes_logprobs: float | None 40 | no_logprobs: float | None 41 | 42 | 43 | # %% 44 | load_dotenv("../.env") 45 | client = openai.AsyncClient() 46 | mcontext = MContext(Path("../output/truthful_model_written_evals-v2")) 47 | 48 | 49 | # %% 50 | async def create_df() -> pd.DataFrame: 51 | init = mcontext.create( 52 | { 53 | **{f"honest_{i}": True for i in range(NUM_HONEST_GENERATIONS)}, 54 | **{f"dishonest_{i}": False for i in range(NUM_HONEST_GENERATIONS)}, 55 | }, 56 | ) 57 | statements = await init.amap_cached( 58 | "generation", 59 | fn=lambda _, value: generate_statements( 60 | client, 61 | GENERATION_CONFIG, 62 | agrees=value, 63 | ), 64 | to=Statements, 65 | ) 66 | statements_flat = statements.flat_map( 67 | lambda key, statements: { 68 | f"{key}-statement{i}": Statement( 69 | statement=statement, 70 | honest=statements.agrees, 71 | ) 72 | for i, statement in enumerate(statements.statements) 73 | }, 74 | ).filter( 75 | filter_repeating_statements(), 76 | ) 77 | statement_likelihoods = await statements_flat.amap_cached( 78 | "filter", 79 | fn=lambda _, value: get_statement_likelihood( 80 | client, 81 | FILTERING_CONFIG, 82 | statement=value.statement, 83 | ), 84 | to=StatementLikelihood, 85 | ) 86 | result = statement_likelihoods.join( 87 | statements_flat, 88 | lambda _, likelihood, statement: TruthfulModelWrittenEvalRow( 89 | statement=statement.statement, 90 | honest=statement.honest, 91 | likelihood=likelihood.likelihood, 92 | yes_logprobs=likelihood.yes_logprobs, 93 | no_logprobs=likelihood.no_logprobs, 94 | ), 95 | ).filter( 96 | filter_low_likelihood, 97 | ) 98 | result.upload( 99 | "../repeng/datasets/data/truthful", 100 | to=TruthfulModelWrittenEvalRow, 101 | ) 102 | # TODO: Filter out statements that don't pass the filter. 
103 | return result.to_dataframe(lambda row: row.model_dump()) 104 | 105 | 106 | def filter_repeating_statements() -> Callable[[str, Statement], bool]: 107 | seen = set() 108 | 109 | def fn(_, value: Statement) -> bool: 110 | statement_normalized = value.statement.lower().strip(" .") 111 | if statement_normalized in seen: 112 | return False 113 | seen.add(statement_normalized) 114 | return True 115 | 116 | return fn 117 | 118 | 119 | def filter_low_likelihood(_: str, value: TruthfulModelWrittenEvalRow) -> bool: 120 | if value.honest: 121 | return value.likelihood > 0.5 122 | else: 123 | return value.likelihood < 0.5 124 | 125 | 126 | df = await create_df() # type: ignore # noqa: F704 127 | 128 | # %% 129 | fig, axs = plt.subplots(ncols=3, figsize=(3 * 5, 5)) 130 | sns.histplot(df, x="likelihood", hue="honest", bins=40, ax=axs[0]) 131 | sns.histplot(df, x="yes_logprobs", hue="honest", bins=40, ax=axs[1]) 132 | sns.histplot(df, x="no_logprobs", hue="honest", bins=40, ax=axs[2]) 133 | 134 | # %% 135 | df["statement_normalized"] = df["statement"].str.lower().str.strip(" .") 136 | print( 137 | df["statement_normalized"].count(), 138 | df["statement_normalized"].nunique(), 139 | df["statement_normalized"].nunique() / df["statement_normalized"].count(), 140 | ) 141 | print(df["statement_normalized"].value_counts().tail(10)) 142 | 143 | # %% 144 | print(df["honest"].value_counts()) 145 | 146 | # %% 147 | df[df["honest"]].sort_values(by="likelihood", ascending=True).head(10) # type: ignore 148 | 149 | # %% 150 | df[~df["honest"]].sort_values(by="likelihood", ascending=False).head(10) # type: ignore 151 | 152 | # %% 153 | # # gpt-4-1106-preview 154 | # cost_per_input_token = 0.01 / 1000 155 | # cost_per_output_token = 0.03 / 1000 156 | # gpt-3.5-turbo-1106 157 | cost_per_input_token = 0.001 / 1000 158 | cost_per_output_token = 0.002 / 1000 159 | num_input_tokens = 300 160 | num_output_tokens = ( 161 | df["statement"].apply(len).sum() 162 | / (NUM_HONEST_GENERATIONS + NUM_DISHONEST_GENERATIONS) 163 | * 1.1 164 | ) 165 | cost_per_generation = ( 166 | cost_per_input_token * num_input_tokens + cost_per_output_token * num_output_tokens 167 | ) 168 | print(f"Cost per generation: ${cost_per_generation:.5f}") 169 | print(f"Cost for 1K generations: ${cost_per_generation * 1000:.5f}") 170 | -------------------------------------------------------------------------------- /repeng/models/llms.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast, get_args, overload 2 | 3 | import torch 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | GemmaForCausalLM, 8 | GPT2LMHeadModel, 9 | GPTNeoXForCausalLM, 10 | LlamaForCausalLM, 11 | MistralForCausalLM, 12 | PreTrainedTokenizerFast, 13 | ) 14 | 15 | from repeng.models import points 16 | from repeng.models.types import ( 17 | GemmaId, 18 | Gpt2Id, 19 | Llama2Id, 20 | Llm, 21 | LlmId, 22 | MistralId, 23 | PythiaDpoId, 24 | PythiaId, 25 | ) 26 | 27 | _MISTRAL_HF_IDS: dict[MistralId, str] = { 28 | "Mistral-7B": "mistralai/Mistral-7B-v0.1", 29 | "Mistral-7B-Instruct": "mistralai/Mistral-7B-Instruct-v0.2", 30 | } 31 | 32 | 33 | @overload 34 | def get_llm( 35 | llm_id: PythiaId | PythiaDpoId, 36 | device: torch.device, 37 | use_half_precision: bool, 38 | ) -> Llm[GPTNeoXForCausalLM, PreTrainedTokenizerFast]: 39 | ... 
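# Illustrative usage (not part of the original module); assumes a CUDA device is
# available and that the checkpoint can be fetched from Hugging Face:
#
#     llm = get_llm("pythia-70m", device=torch.device("cuda"), use_half_precision=True)
#     tokens = llm.tokenizer("The capital of France is", return_tensors="pt").to("cuda")
#     logits = llm.model(**tokens).logits  # shape: (1, seq_len, vocab_size)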
40 | 41 | 42 | @overload 43 | def get_llm( 44 | llm_id: Gpt2Id, 45 | device: torch.device, 46 | use_half_precision: bool, 47 | ) -> Llm[GPT2LMHeadModel, PreTrainedTokenizerFast]: 48 | ... 49 | 50 | 51 | @overload 52 | def get_llm( 53 | llm_id: Llama2Id, 54 | device: torch.device, 55 | use_half_precision: bool, 56 | ) -> Llm[LlamaForCausalLM, PreTrainedTokenizerFast]: 57 | ... 58 | 59 | 60 | @overload 61 | def get_llm( 62 | llm_id: MistralId, 63 | device: torch.device, 64 | use_half_precision: bool, 65 | ) -> Llm[MistralForCausalLM, PreTrainedTokenizerFast]: 66 | ... 67 | 68 | 69 | @overload 70 | def get_llm( 71 | llm_id: GemmaId, 72 | device: torch.device, 73 | use_half_precision: bool, 74 | ) -> Llm[GemmaForCausalLM, PreTrainedTokenizerFast]: 75 | ... 76 | 77 | 78 | def get_llm( 79 | llm_id: LlmId, 80 | device: torch.device, 81 | use_half_precision: bool, 82 | ) -> Llm[Any, Any]: 83 | if llm_id in get_args(PythiaId): 84 | return pythia(cast(PythiaId, llm_id), device, use_half_precision) 85 | elif llm_id in get_args(Gpt2Id): 86 | return gpt2(device, use_half_precision) 87 | elif llm_id in get_args(Llama2Id): 88 | return llama2(cast(Llama2Id, llm_id), device, use_half_precision) 89 | elif llm_id in get_args(PythiaDpoId): 90 | return pythia_dpo(cast(PythiaDpoId, llm_id), device, use_half_precision) 91 | elif llm_id in get_args(MistralId): 92 | return mistral(cast(MistralId, llm_id), device, use_half_precision) 93 | elif llm_id in get_args(GemmaId): 94 | return gemma(cast(GemmaId, llm_id), device, use_half_precision) 95 | else: 96 | raise ValueError(f"Unknown LLM ID: {llm_id}") 97 | 98 | 99 | def gpt2( 100 | device: torch.device, 101 | use_half_precision: bool, 102 | ) -> Llm[GPT2LMHeadModel, PreTrainedTokenizerFast]: 103 | dtype = torch.float16 if use_half_precision else torch.float32 104 | model = AutoModelForCausalLM.from_pretrained( 105 | "gpt2", device_map=device, torch_dtype=dtype 106 | ) 107 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 108 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 109 | return Llm( 110 | model, 111 | tokenizer, 112 | points.gpt2(), 113 | ) 114 | 115 | 116 | def pythia( 117 | pythia_id: PythiaId, 118 | device: torch.device, 119 | use_half_precision: bool, 120 | ) -> Llm[GPTNeoXForCausalLM, PreTrainedTokenizerFast]: 121 | dtype = torch.float16 if use_half_precision else torch.float32 122 | model = GPTNeoXForCausalLM.from_pretrained( 123 | f"EleutherAI/{pythia_id}", 124 | device_map=device, 125 | torch_dtype=dtype, 126 | ) 127 | assert isinstance(model, GPTNeoXForCausalLM) 128 | tokenizer = AutoTokenizer.from_pretrained(f"EleutherAI/{pythia_id}") 129 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 130 | return Llm( 131 | model, 132 | tokenizer, 133 | points.pythia(pythia_id), 134 | ) 135 | 136 | 137 | def pythia_dpo( 138 | pythia_dpo_id: PythiaDpoId, 139 | device: torch.device, 140 | use_half_precision: bool, 141 | ) -> Llm[GPTNeoXForCausalLM, PreTrainedTokenizerFast]: 142 | dtype = torch.float16 if use_half_precision else torch.float32 143 | if pythia_dpo_id == "pythia-dpo-1b": 144 | model_id = "Leogrin/eleuther-pythia1b-hh-dpo" 145 | pythia_id = "pythia-1b" 146 | elif pythia_dpo_id == "pythia-sft-1b": 147 | model_id = "Leogrin/eleuther-pythia1b-hh-sft" 148 | pythia_id = "pythia-1b" 149 | elif pythia_dpo_id == "pythia-dpo-1.4b": 150 | model_id = "Leogrin/eleuther-pythia1.4b-hh-dpo" 151 | pythia_id = "pythia-1.4b" 152 | elif pythia_dpo_id == "pythia-sft-1.4b": 153 | model_id = "Leogrin/eleuther-pythia1.4b-hh-sft" 154 | pythia_id = 
"pythia-1.4b" 155 | else: 156 | raise ValueError(f"Unknown Pythia DPO ID: {pythia_dpo_id}") 157 | model = GPTNeoXForCausalLM.from_pretrained( 158 | model_id, device_map=device, torch_dtype=dtype 159 | ) 160 | assert isinstance(model, GPTNeoXForCausalLM) 161 | tokenizer = AutoTokenizer.from_pretrained(model_id) 162 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 163 | return Llm(model, tokenizer, points.pythia(pythia_id)) 164 | 165 | 166 | def llama2( 167 | llama_id: Llama2Id, 168 | device: torch.device, 169 | use_half_precision: bool, 170 | ) -> Llm[LlamaForCausalLM, PreTrainedTokenizerFast]: 171 | dtype = torch.bfloat16 if use_half_precision else torch.float32 172 | model = AutoModelForCausalLM.from_pretrained( 173 | f"meta-llama/{llama_id}", device_map=device, torch_dtype=dtype 174 | ) 175 | assert isinstance(model, LlamaForCausalLM) 176 | tokenizer = AutoTokenizer.from_pretrained(f"meta-llama/{llama_id}") 177 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 178 | return Llm( 179 | model, 180 | tokenizer, 181 | points.llama2(llama_id), 182 | ) 183 | 184 | 185 | def mistral( 186 | mistral_id: MistralId, 187 | device: torch.device, 188 | use_half_precision: bool, 189 | ) -> Llm[MistralForCausalLM, PreTrainedTokenizerFast]: 190 | dtype = torch.bfloat16 if use_half_precision else torch.float32 191 | hf_id = _MISTRAL_HF_IDS[mistral_id] 192 | tokenizer = AutoTokenizer.from_pretrained(hf_id) 193 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 194 | model = AutoModelForCausalLM.from_pretrained( 195 | hf_id, device_map=device, torch_dtype=dtype 196 | ) 197 | assert isinstance(model, MistralForCausalLM) 198 | return Llm( 199 | model, 200 | tokenizer, 201 | points.mistral(mistral_id), 202 | ) 203 | 204 | 205 | def gemma( 206 | gemma_id: GemmaId, 207 | device: torch.device, 208 | use_half_precision: bool, 209 | ) -> Llm[GemmaForCausalLM, PreTrainedTokenizerFast]: 210 | dtype = torch.bfloat16 if use_half_precision else torch.float32 211 | tokenizer = AutoTokenizer.from_pretrained(f"google/{gemma_id}") 212 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 213 | model = AutoModelForCausalLM.from_pretrained( 214 | f"google/{gemma_id}", device_map=device, torch_dtype=dtype 215 | ) 216 | assert isinstance(model, GemmaForCausalLM) 217 | return Llm( 218 | model, 219 | tokenizer, 220 | points.gemma(gemma_id), 221 | ) 222 | -------------------------------------------------------------------------------- /repeng/datasets/elk/dlk.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from datasets import load_dataset 5 | 6 | from repeng.datasets.elk.types import BinaryRow, DlkDatasetId, Split 7 | from repeng.datasets.utils.shuffles import deterministic_shuffle 8 | from repeng.datasets.utils.splits import split_train 9 | 10 | 11 | @dataclass 12 | class _DatasetSpec: 13 | name: str 14 | subset: str | None = None 15 | validation_name: str = "validation" 16 | 17 | 18 | @dataclass 19 | class _DlkTemplate: 20 | template: str 21 | labels: list[str] 22 | args: list[str] 23 | insert_label_options: bool = True 24 | 25 | 26 | _DATASET_SPECS: dict[DlkDatasetId, _DatasetSpec] = { 27 | # Sentiment classification 28 | "imdb": _DatasetSpec("imdb", validation_name="test"), 29 | "amazon_polarity": _DatasetSpec("amazon_polarity", validation_name="test"), 30 | # Topic classification 31 | "ag_news": _DatasetSpec("ag_news", validation_name="test"), 32 | "dbpedia_14": _DatasetSpec("dbpedia_14", 
validation_name="test"), 33 | # NLI 34 | "rte": _DatasetSpec("super_glue", "rte"), 35 | # N.B.: We skip QNLI because we can't find the prompt templates. 36 | # Story completion 37 | "copa": _DatasetSpec("super_glue", "copa"), 38 | # N.B.: We skip story_cloze because it requires filling in a form to access. 39 | # Question answering 40 | "boolq": _DatasetSpec("super_glue", "boolq"), 41 | # Common sense reasoning 42 | "piqa": _DatasetSpec("piqa"), 43 | # Other formats: 44 | "boolq/simple": _DatasetSpec("super_glue", "boolq"), 45 | "imdb/simple": _DatasetSpec("imdb", validation_name="test"), 46 | } 47 | 48 | 49 | def get_dlk_dataset(dataset_id: DlkDatasetId): 50 | dataset_spec = _DATASET_SPECS[dataset_id] 51 | dataset: Any = load_dataset(dataset_spec.name, dataset_spec.subset) 52 | return { 53 | **_get_dlk_dataset(dataset_id, dataset, split="train", limit=3000), 54 | **_get_dlk_dataset(dataset_id, dataset, split="validation", limit=3000), 55 | } 56 | 57 | 58 | def _get_dlk_dataset( 59 | dataset_id: DlkDatasetId, 60 | dataset: Any, 61 | split: Split, 62 | limit: int, 63 | ) -> dict[str, BinaryRow]: 64 | dataset_spec = _DATASET_SPECS[dataset_id] 65 | if split == "train": 66 | hf_split = "train" 67 | elif split == "validation": 68 | hf_split = dataset_spec.validation_name 69 | else: 70 | raise ValueError(split) 71 | 72 | results = {} 73 | for row_idx, row in deterministic_shuffle( 74 | enumerate(dataset[hf_split]), lambda row: str(row[0]) 75 | )[:limit]: 76 | assert "label" in row and type(row["label"]) == int, row 77 | template = _TEMPLATES[dataset_id] 78 | true_label: int = row["label"] 79 | false_label_options = list(range(len(template.labels))) 80 | false_label_options.remove(true_label) 81 | false_label = false_label_options[row_idx % len(false_label_options)] 82 | labels = sorted([true_label, false_label]) 83 | 84 | for label_idx, label in enumerate(labels): 85 | format_args: dict[str, str] = dict(label=template.labels[label]) 86 | if template.insert_label_options: 87 | format_args["label1"] = template.labels[labels[0]] 88 | format_args["label2"] = template.labels[labels[1]] 89 | for arg in template.args: 90 | format_args[arg] = row[arg] 91 | prompt = template.template.format(**format_args) 92 | results[f"{dataset_id}-{row_idx}-{label_idx}-{split}"] = BinaryRow( 93 | dataset_id=dataset_id, 94 | split=split_train(split, seed=dataset_id, row_id=str(row_idx)), 95 | group_id=str(row_idx), 96 | text=prompt, 97 | label=label == true_label, 98 | format_args=format_args, 99 | # As labels are sorted, true/false questions will be consistent. 100 | answer_type=str(label_idx), 101 | ) 102 | return results 103 | 104 | 105 | # Taken from appendix 1. 106 | _TEMPLATES: dict[DlkDatasetId, _DlkTemplate] = { 107 | "ag_news": _DlkTemplate( 108 | 'Consider the following example: "{text}"\n' 109 | "Choice 1: {label1}\n" 110 | "Choice 2: {label2}\n" 111 | "Between choice 1 and choice 2, the topic of this example is {label}", 112 | labels=["World", "Sports", "Business", "Sci/Tech"], 113 | args=["text"], 114 | ), 115 | "amazon_polarity": _DlkTemplate( 116 | 'Consider the following example: "{content}"\n' 117 | "Choice 1: {label1}\n" 118 | "Choice 2: {label2}\n" 119 | "Between choice 1 and choice 2, the sentiment of this example is {label}", 120 | labels=["Negative", "Positive"], 121 | args=["content"], 122 | ), 123 | "boolq": _DlkTemplate( 124 | "Passage: {passage}\n" 125 | "After reading this passage, I have a question: {question}? 
True or False?\n" 126 | "{label}", 127 | labels=["False", "True"], 128 | args=["passage", "question"], 129 | insert_label_options=False, 130 | ), 131 | "boolq/simple": _DlkTemplate( 132 | "{passage}\nQuestion: {question}?\nAnswer: {label}", 133 | labels=["no", "yes"], 134 | args=["passage", "question"], 135 | insert_label_options=False, 136 | ), 137 | "copa": _DlkTemplate( 138 | 'Consider the following premise: "{premise}"\n' 139 | "Choice 1: {choice1}\n" 140 | "Choice 2: {choice2}\n" 141 | "Q: Which one is more likely to be the {question}, choice 1 or choice 2?\n" 142 | "{label}", 143 | labels=["Choice 1", "Choice 2"], 144 | args=["premise", "question", "choice1", "choice2"], 145 | insert_label_options=False, 146 | ), 147 | "dbpedia_14": _DlkTemplate( 148 | 'Consider the following example: "{content}"\n' 149 | "Choice 1: {label1}\n" 150 | "Choice 2: {label2}\n" 151 | "Between choice 1 and choice 2, the topic of this example is {label}", 152 | labels=[ 153 | "Company", 154 | "EducationalInstitution", 155 | "Artist", 156 | "Athlete", 157 | "OfficeHolder", 158 | "MeanOfTransportation", 159 | "Building", 160 | "NaturalPlace", 161 | "Village", 162 | "Animal", 163 | "Plant", 164 | "Album", 165 | "Film", 166 | "WrittenWork", 167 | ], 168 | args=["content"], 169 | ), 170 | "imdb": _DlkTemplate( 171 | 'Consider the following example: "{text}"\n' 172 | "Between {label1} and {label2}, the sentiment of this example is {label}", 173 | labels=["Negative", "Positive"], 174 | args=["text"], 175 | ), 176 | "imdb/simple": _DlkTemplate( 177 | "{text}\nQuestion: What is the sentiment of the review?\nAnswer: {label}", 178 | labels=["negative", "positive"], 179 | args=["text"], 180 | insert_label_options=False, 181 | ), 182 | "piqa": _DlkTemplate( 183 | "Goal: {goal}\n" 184 | "Which is the correct ending?\n" 185 | "Choice 1: {sol1}\n" 186 | "Choice 2: {sol2}\n" 187 | "{label}", 188 | labels=["Choice 1", "Choice 2"], 189 | args=["goal", "sol1", "sol2"], 190 | insert_label_options=False, 191 | ), 192 | "rte": _DlkTemplate( 193 | "{premise}\n" 194 | 'Question: Does this imply that "{hypothesis}", yes or no?\n' 195 | "{label}", 196 | labels=["yes", "no"], 197 | args=["premise", "hypothesis"], 198 | insert_label_options=False, 199 | ), 200 | } 201 | -------------------------------------------------------------------------------- /experiments/scratch/got_test.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from datetime import datetime 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import plotly.express as px 8 | import torch 9 | from dotenv import load_dotenv 10 | from mppr import MContext 11 | from sklearn.decomposition import PCA 12 | from tqdm import tqdm 13 | 14 | from repeng.activations.inference import get_model_activations 15 | from repeng.activations.probe_preparations import ActivationArrayDataset 16 | from repeng.datasets.activations.types import ActivationResultRow 17 | from repeng.datasets.elk.types import BinaryRow 18 | from repeng.datasets.elk.utils.filters import DATASET_FILTER_FNS, DatasetIdFilter 19 | from repeng.datasets.elk.utils.fns import get_datasets 20 | from repeng.datasets.elk.utils.limits import Limits, SplitLimits, limit_groups 21 | from repeng.evals.probes import eval_probe_by_question 22 | from repeng.models.loading import load_llm_oioo 23 | from repeng.models.points import get_points 24 | from repeng.models.types import LlmId 25 | from repeng.probes.collections import ALL_PROBES, SUPERVISED_PROBES, 
train_probe 26 | 27 | assert load_dotenv(".env") 28 | 29 | # %% 30 | # path = Path("../../output/comparison") 31 | # mcontext = MContext(path) 32 | # activation_results: list[ActivationResultRow] = mcontext.download_cached( 33 | # "activations_results", 34 | # path="s3://repeng/datasets/activations/datasets_2024-02-08_v1.pickle", 35 | # to="pickle", 36 | # ).get() 37 | # print(set(row.llm_id for row in activation_results)) 38 | # print(set(row.dataset_id for row in activation_results)) 39 | # print(set(row.split for row in activation_results)) 40 | # dataset = ActivationArrayDataset(activation_results) 41 | 42 | # %% 43 | mcontext = MContext(Path("../output/got_test")) 44 | limits = Limits( 45 | default=SplitLimits(train=150, train_hparams=0, validation=200), 46 | by_dataset={}, 47 | ) 48 | # llm_ids: list[LlmId] = ["Llama-2-7b-chat-hf", "Llama-2-13b-chat-hf"] 49 | llm_ids: list[LlmId] = ["Llama-2-13b-chat-hf"] 50 | inputs = ( 51 | mcontext.create_cached( 52 | "dataset-repe", 53 | lambda: get_datasets( 54 | [ 55 | # "geometry_of_truth/cities", 56 | "race", 57 | "open_book_qa", 58 | "arc_easy", 59 | "arc_challenge", 60 | ] 61 | ), 62 | to=BinaryRow, 63 | ) 64 | .filter( 65 | limit_groups(limits), 66 | ) 67 | .flat_map(lambda key, row: {f"{key}-{llm_id}": (row, llm_id) for llm_id in llm_ids}) 68 | .sort(lambda _, row: llm_ids.index(row[1])) 69 | ) 70 | activations = inputs.map_cached( 71 | "activations-v3", 72 | lambda _, row: get_model_activations( 73 | load_llm_oioo( 74 | llm_id=row[1], 75 | device=torch.device("cuda"), 76 | use_half_precision=row[1] == "Llama-2-13b-chat-hf", 77 | ), 78 | text=row[0].text, 79 | last_n_tokens=1, 80 | ), 81 | to="pickle", 82 | ).join( 83 | inputs, 84 | lambda _, activations, input: ActivationResultRow( 85 | dataset_id=input[0].dataset_id, 86 | group_id=input[0].group_id, 87 | answer_type=input[0].answer_type, 88 | activations=activations.activations, 89 | prompt_logprobs=activations.token_logprobs.sum(), 90 | label=input[0].is_true, 91 | split=input[0].split, 92 | llm_id=input[1], 93 | ), 94 | ) 95 | dataset = ActivationArrayDataset(activations.get()) 96 | 97 | # %% 98 | for llm_id in llm_ids[1:]: 99 | arrays = dataset.get( 100 | llm_id=llm_id, 101 | dataset_filter=DatasetIdFilter("geometry_of_truth/cities"), 102 | # dataset_filter=DatasetIdFilter("geometry_of_truth/sp_en_trans"), 103 | # dataset_filter=DATASET_FILTER_FNS["geometry_of_truth/cities/pos"], 104 | # dataset_filter=DATASET_FILTER_FNS["geometry_of_truth/cities/neg"], 105 | split="validation", 106 | # point_name=get_points(llm_id)[-10].name, 107 | point_name="h13", 108 | token_idx=0, 109 | limit=None, 110 | ) 111 | assert arrays.answer_types is not None 112 | assert arrays.groups is not None 113 | activations = arrays.activations.copy() 114 | # for answer_type in np.unique(arrays.answer_types): 115 | # activations[answer_type == arrays.answer_types] -= np.mean( 116 | # activations[answer_type == arrays.answer_types], axis=0 117 | # ) 118 | for group in np.unique(arrays.groups): 119 | activations[group == arrays.groups] -= np.mean( 120 | activations[group == arrays.groups], axis=0 121 | ) 122 | activations = activations 123 | 124 | components = PCA(n_components=3).fit_transform(activations) 125 | fig = px.scatter_3d( 126 | title=llm_id, 127 | x=components[:, 0], 128 | y=components[:, 1], 129 | z=components[:, 2], 130 | color=arrays.labels, 131 | symbol=arrays.answer_types, 132 | category_orders={"color": [False, True]}, 133 | ) 134 | fig.update_traces(marker_size=3) 135 | 
fig.update_layout(showlegend=False) 136 | fig.show() 137 | 138 | # %% 139 | results = [] 140 | for llm_id in llm_ids: 141 | for point in tqdm(get_points(llm_id)[::4]): 142 | arrays = dataset.get( 143 | llm_id=llm_id, 144 | dataset_filter=DatasetIdFilter("geometry_of_truth/cities"), 145 | split="validation", 146 | point_name=point.name, 147 | token_idx=0, 148 | limit=None, 149 | ) 150 | assert arrays.answer_types is not None 151 | components = PCA(n_components=3).fit_transform( 152 | arrays.activations.astype(np.float16) 153 | ) 154 | for i in range(components.shape[0]): 155 | results.append( 156 | dict( 157 | llm_id=llm_id, 158 | point_name=point.name, 159 | x=components[i, 0], 160 | y=components[i, 1], 161 | z=components[i, 2], 162 | label=arrays.labels[i], 163 | answer_type=arrays.answer_types[i], 164 | ) 165 | ) 166 | 167 | # %% 168 | df = pd.DataFrame(results) 169 | fig = px.scatter( 170 | df, 171 | x="x", 172 | y="y", 173 | color="label", 174 | symbol="answer_type", 175 | facet_col="llm_id", 176 | facet_row="point_name", 177 | category_orders={"color": [False, True]}, 178 | height=2000, 179 | ) 180 | fig.update_xaxes(matches=None) 181 | fig.update_yaxes(matches=None) 182 | fig.for_each_xaxis(lambda xaxis: xaxis.update(showticklabels=True)) 183 | fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True)) 184 | fig.show() 185 | 186 | # %% 187 | point_name = "h30" 188 | results = [] 189 | for llm_id in llm_ids[:1]: 190 | arrays = dataset.get( 191 | llm_id=llm_id, 192 | dataset_filter=DatasetIdFilter("geometry_of_truth/cities"), 193 | split="train", 194 | point_name=point_name, 195 | token_idx=0, 196 | limit=None, 197 | ) 198 | arrays_val = dataset.get( 199 | llm_id=llm_id, 200 | dataset_filter=DatasetIdFilter("geometry_of_truth/cities"), 201 | split="validation", 202 | point_name=point_name, 203 | token_idx=0, 204 | limit=None, 205 | ) 206 | assert arrays_val.groups is not None 207 | for probe_method in tqdm(ALL_PROBES): 208 | start = datetime.now() 209 | probe = train_probe(probe_method, arrays) 210 | duration = datetime.now() - start 211 | print(probe_method, duration) 212 | assert probe is not None 213 | eval_results = eval_probe_by_question( 214 | probe, 215 | activations=arrays_val.activations, 216 | groups=arrays_val.groups, 217 | labels=arrays_val.labels, 218 | ) 219 | results.append( 220 | dict( 221 | llm_id=llm_id, 222 | probe_method=probe_method, 223 | accuracy=eval_results.accuracy, 224 | duration=duration.total_seconds(), 225 | ) 226 | ) 227 | df = pd.DataFrame(results) 228 | df 229 | 230 | # %% 231 | df["is_supervised"] = df["probe_method"].isin(SUPERVISED_PROBES) 232 | px.bar( 233 | df, 234 | x="probe_method", 235 | y="accuracy", 236 | color="is_supervised", 237 | facet_col="llm_id", 238 | ) 239 | -------------------------------------------------------------------------------- /experiments/scratch/repe_performance_check.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import seaborn as sns 7 | import torch 8 | from jaxtyping import Float, Int64 9 | from mppr import MContext 10 | 11 | from repeng.activations.probe_preparations import ( 12 | Activation, 13 | LabeledGroupedActivationArray, 14 | prepare_activations_for_probes, 15 | ) 16 | from repeng.datasets.activations.types import ActivationResultRow 17 | from repeng.datasets.elk.types import DatasetId, Split 18 | from repeng.evals.probes import eval_probe_by_question 
19 | from repeng.hooks.grab import grab 20 | from repeng.models.llms import pythia 21 | from repeng.probes.base import BaseGroupedProbe, BaseProbe, PredictResult 22 | from repeng.probes.collections import ProbeMethod, train_probe 23 | 24 | # %% 25 | # # so we don't have to re-load massive file 26 | # mcontext = MContext(Path("../output/comparison")) 27 | # activations_dataset: list[ActivationResultRow] = ( 28 | # mcontext.download_cached( 29 | # "activations_dataset", 30 | # path="s3://repeng/datasets/activations/pythia_2024-01-26_v1.pickle", 31 | # to="pickle", 32 | # ) 33 | # .filter(lambda _, row: row.llm_id == "pythia-6.9b") 34 | # .get() 35 | # ) 36 | 37 | # # %% 38 | # mcontext = MContext(Path("../output/comparison")) 39 | # dataset = mcontext.create_cached( 40 | # "dataset-v3", 41 | # lambda: get_datasets(["arc_easy", "common_sense_qa", "race"]), 42 | # to=BinaryRow, 43 | # ).filter( 44 | # limit_dataset_and_split_fn(train_limit=1000, validation_limit=200), 45 | # ) 46 | 47 | # # %% 48 | # llm = get_llm( 49 | # "Llama-2-7b-chat-hf", 50 | # device=torch.device("cuda"), 51 | # dtype=torch.bfloat16, 52 | # ) 53 | 54 | # # %% 55 | # activations = dataset.map_cached( 56 | # "activations-chat", 57 | # lambda _, row: get_model_activations( 58 | # llm, 59 | # text=row.text, 60 | # last_n_tokens=3, 61 | # ), 62 | # to="pickle", 63 | # ) 64 | 65 | 66 | # %% 67 | mcontext = MContext(Path("../output/comparison")) 68 | activations_dataset: list[ActivationResultRow] = mcontext.download_cached( 69 | "activations_dataset_llama2", 70 | path="s3://repeng/datasets/activations/llama-2-7b_2024-01-29_v1.pickle", 71 | to="pickle", 72 | ).get() 73 | 74 | 75 | # %% 76 | @dataclass 77 | class ActivationSplit(Activation): 78 | split: Split 79 | 80 | 81 | @dataclass 82 | class MultipleChoiceProbe(BaseGroupedProbe): 83 | underlying: BaseProbe 84 | flip: bool 85 | 86 | def predict( 87 | self, 88 | activations: Float[np.ndarray, "n d"], # noqa: F722 89 | ) -> "PredictResult": 90 | results = self.underlying.predict(activations) 91 | if self.flip: 92 | results.labels = ~results.labels 93 | results.logits = -results.logits 94 | return results 95 | 96 | def predict_grouped( 97 | self, 98 | activations: Float[np.ndarray, "n d"], # noqa: F722 99 | pairs: Int64[np.ndarray, "n"], # noqa: F821 100 | ) -> "PredictResult": 101 | predictions = self.predict(activations) 102 | for pair in np.unique(pairs): 103 | logits = predictions.logits[pairs == pair] 104 | mean_logit = logits.mean() 105 | max_logit = logits.max() 106 | predictions.logits[pairs == pair] -= mean_logit 107 | predictions.labels[pairs == pair] = logits == max_logit 108 | return predictions 109 | 110 | 111 | def get_group_accuracy( 112 | probe: BaseGroupedProbe, 113 | arrays: LabeledGroupedActivationArray, 114 | ) -> float: 115 | predict_result = probe.predict_grouped( 116 | arrays.activations, 117 | arrays.groups, 118 | ) 119 | accuracy = [] 120 | for group in np.unique(arrays.groups): 121 | group_labels = predict_result.labels[arrays.groups == group] 122 | if group_labels.sum() > 1: 123 | group_labels_one = np.zeros_like(group_labels) 124 | group_labels_one[group_labels.tolist().index(True)] = True 125 | group_labels = group_labels_one 126 | assert arrays.labels[arrays.groups == group].sum() == 1, arrays.labels[ 127 | arrays.groups == group 128 | ] 129 | accuracy.append( 130 | np.all( 131 | predict_result.labels[arrays.groups == group] 132 | == arrays.labels[arrays.groups == group] 133 | ) 134 | ) 135 | return sum(accuracy) / len(accuracy) 136 | 137 | 138 | 
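# Worked example for `MultipleChoiceProbe.predict_grouped` above (illustrative numbers,
# not from the original experiment): for a single group with logits [2.0, -1.0, 0.5],
# the group mean (0.5) is subtracted to give [1.5, -1.5, 0.0], and only the argmax is
# marked True, i.e. the labels become [True, False, False].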
@dataclass 139 | class SweepSpec: 140 | layer: str 141 | token: int 142 | probe_method: ProbeMethod 143 | dataset: DatasetId 144 | num_samples: int 145 | 146 | 147 | def train_and_eval_probe(spec: SweepSpec) -> dict: 148 | # acts = ( 149 | # activations_dataset.map_cached( 150 | # dataset, 151 | # lambda _, activations, binary_row: ActivationSplit( 152 | # dataset_id=binary_row.dataset_id, 153 | # pair_id=binary_row.pair_id, 154 | # label=binary_row.is_true, 155 | # activations=activations.activations[spec.layer][spec.token], 156 | # split=binary_row.split, 157 | # ), 158 | # ) 159 | # .filter(lambda _, row: row.dataset_id == spec.dataset) 160 | # .get() 161 | # ) 162 | acts = [ 163 | ActivationSplit( 164 | dataset_id=row.dataset_id, 165 | group_id=row.group_id, 166 | label=row.label, 167 | activations=row.activations[spec.layer][spec.token], 168 | split=row.split, 169 | ) 170 | for row in activations_dataset 171 | if row.dataset_id == spec.dataset 172 | ] 173 | 174 | arrays = prepare_activations_for_probes( 175 | [row for row in acts if row.split == "train"][: spec.num_samples] 176 | ) 177 | arrays_val = prepare_activations_for_probes( 178 | [row for row in acts if row.split == "validation"] 179 | ) 180 | assert arrays_val.grouped is not None 181 | assert arrays_val.labeled_grouped is not None 182 | probe = train_probe(spec.probe_method, arrays) 183 | assert probe is not None 184 | accuracy = get_group_accuracy( 185 | MultipleChoiceProbe(probe, flip=False), arrays_val.labeled_grouped 186 | ) 187 | accuracy_flipped = get_group_accuracy( 188 | MultipleChoiceProbe(probe, flip=True), arrays_val.labeled_grouped 189 | ) 190 | question_eval_results = eval_probe_by_question( 191 | probe, 192 | arrays_val.labeled_grouped, 193 | ) 194 | question_eval_results_mc = eval_probe_by_question( 195 | MultipleChoiceProbe(probe, flip=False), 196 | arrays_val.labeled_grouped, 197 | ) 198 | return dict( 199 | **asdict(spec), 200 | accuracy=accuracy, 201 | accuracy_flipped=accuracy_flipped, 202 | question_eval_results=question_eval_results.accuracy, 203 | question_is_flipped=question_eval_results.is_flipped, 204 | question_mc_eval_results=question_eval_results_mc.accuracy, 205 | question_mc_is_flipped=question_eval_results_mc.is_flipped, 206 | ) 207 | 208 | 209 | probe_methods: list[ProbeMethod] = ["lr", "lr-grouped", "lat"] 210 | specs = mcontext.create( 211 | { 212 | f"{layer}-{token}-{probe_method}-{dataset_id}-{num_samples}": SweepSpec( 213 | layer, token, probe_method, dataset_id, num_samples 214 | ) 215 | for layer in [f"h{i}" for i in range(28, 32)] 216 | for token in range(3) 217 | for probe_method in probe_methods 218 | for dataset_id, num_samples in [ 219 | ("arc_easy", 25 * 5), 220 | # ("arc_easy", 50 * 5), 221 | # ("arc_easy", 100 * 5), 222 | # ("common_sense_qa", 7 * 4), 223 | # ("common_sense_qa", 14 * 4), 224 | # ("common_sense_qa", 100 * 4), 225 | # ("race", 3 * 4), 226 | # ("race", 6 * 4), 227 | # ("race", 100 * 4), 228 | ] 229 | } 230 | ) 231 | 232 | probe_results = specs.map_cached( 233 | "probe-results-v4", 234 | lambda _, spec: train_and_eval_probe(spec), 235 | to="pickle", 236 | ) 237 | 238 | # for layer in ["h17", "h18", "h19", "h20"]: 239 | # pprint( 240 | # train_and_eval_probe( 241 | # SweepSpec( 242 | # layer=layer, 243 | # token=2, 244 | # probe_method="lat", 245 | # dataset="arc_easy", 246 | # num_samples=101 * 5, 247 | # ) 248 | # ) 249 | # ) 250 | 251 | # %% 252 | df = probe_results.to_dataframe(lambda d: d) 253 | df["accuracy"] = df.apply( 254 | lambda row: max(row["accuracy"], 
row["accuracy_flipped"]), axis=1 255 | ) 256 | df = df[df["dataset"] == "arc_easy"] 257 | g = sns.FacetGrid(df, col="probe_method", row="token") 258 | g = g.map(sns.lineplot, "layer", "question_eval_results", "num_samples", marker="o") 259 | g.add_legend() 260 | for ax in g.axes.flat: 261 | ax.set_xticklabels(ax.get_xticklabels(), rotation=90) 262 | 263 | # %% 264 | df.groupby(["dataset", "probe_method"])["accuracy"].max() 265 | 266 | # %% 267 | llm = pythia("pythia-70m", device=torch.device("cpu"), dtype=torch.float32) 268 | 269 | # %% 270 | toks = llm.tokenizer.encode("hello world", return_tensors="pt") 271 | with torch.no_grad(): 272 | point_idx = -6 273 | with grab(llm.model, point=llm.points[point_idx]) as grab_fn: 274 | output = llm.model(toks, output_hidden_states=True) 275 | print(output.hidden_states[point_idx].flatten()[:5].numpy()) 276 | print(grab_fn().flatten()[:5].numpy()) 277 | 278 | # %% 279 | hs = output.hidden_states 280 | print(len(hs)) 281 | print([h.shape for h in hs]) 282 | -------------------------------------------------------------------------------- /experiments/scratch/probe_test.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from pprint import pprint 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | import seaborn as sns 10 | import torch 11 | from matplotlib import pyplot as plt 12 | from mppr import MContext 13 | from sklearn.decomposition import PCA 14 | from sklearn.metrics import roc_auc_score 15 | 16 | from repeng.activations.inference import get_model_activations 17 | from repeng.activations.probe_preparations import ActivationArrayDataset 18 | from repeng.datasets.activations.types import ActivationResultRow 19 | from repeng.datasets.elk.types import BinaryRow, DatasetId 20 | from repeng.datasets.elk.utils.collections import get_datasets 21 | from repeng.datasets.elk.utils.limits import limit_dataset_and_split_fn 22 | from repeng.evals.logits import eval_logits_by_question 23 | from repeng.evals.probes import eval_probe_by_question, eval_probe_by_row 24 | from repeng.models.loading import load_llm_oioo 25 | from repeng.models.points import get_points 26 | from repeng.models.types import LlmId 27 | from repeng.probes.collections import ProbeMethod, train_probe 28 | from repeng.probes.logistic_regression import train_lr_probe 29 | 30 | # %% 31 | mcontext = MContext(Path("../output/probe_test")) 32 | dataset = ( 33 | mcontext.create_cached( 34 | "dataset-v4", 35 | lambda: get_datasets( 36 | [ 37 | "arc_easy", 38 | "arc_easy/qna", 39 | # "geometry_of_truth/cities", 40 | # "common_sense_qa/qna", 41 | # "common_sense_qa/options-numbers", 42 | # "common_sense_qa/options-letters", 43 | # "open_book_qa", 44 | ] 45 | ), 46 | to=BinaryRow, 47 | ) 48 | # .filter(lambda _, row: row.dataset_id == "open_book_qa") 49 | .filter( 50 | limit_dataset_and_split_fn(train_limit=2000, validation_limit=500), 51 | ) 52 | ) 53 | 54 | 55 | # %% 56 | pprint(dataset.get()[0].model_dump()) 57 | pprint(dataset.get()[-1].model_dump()) 58 | print(set(d.dataset_id for d in dataset.get())) 59 | 60 | 61 | # %% 62 | @dataclass 63 | class InputRow: 64 | row: BinaryRow 65 | llm_id: LlmId 66 | text: str 67 | 68 | 69 | llm_ids: list[LlmId] = [ 70 | # "pythia-12b", 71 | # "Llama-2-13b-chat-hf", 72 | # "pythia-6.9b", 73 | "Llama-2-7b-hf", 74 | # "Llama-2-7b-chat-hf", 75 | ] 76 | inputs = dataset.flat_map( 77 | lambda key, row: { 78 | 
f"{key}-{llm_id}": InputRow( 79 | row=row, 80 | llm_id=llm_id, 81 | text=row.text, 82 | ) 83 | for llm_id in llm_ids 84 | } 85 | ).sort(lambda _, row: llm_ids.index(row.llm_id)) 86 | 87 | # %% 88 | activations = inputs.map_cached( 89 | "activations-v4", 90 | lambda _, row: get_model_activations( 91 | load_llm_oioo( 92 | row.llm_id, 93 | device=torch.device("cuda"), 94 | use_half_precision=True, 95 | ), 96 | text=row.text, 97 | last_n_tokens=1, 98 | ), 99 | to="pickle", 100 | ) 101 | 102 | # %% 103 | arrays_dataset = activations.join( 104 | inputs, 105 | lambda _, activations_row, input_row: ActivationResultRow( 106 | dataset_id=input_row.row.dataset_id, 107 | group_id=input_row.row.group_id, 108 | template_name=input_row.row.template_name, 109 | answer_type=input_row.row.answer_type, 110 | activations=activations_row.activations, 111 | prompt_logprobs=activations_row.token_logprobs.sum(), 112 | label=input_row.row.is_true, 113 | split=input_row.row.split, 114 | llm_id=input_row.llm_id, 115 | ), 116 | ) 117 | arrays_dataset = ActivationArrayDataset(arrays_dataset.get()) 118 | 119 | 120 | # %% 121 | def train_and_eval( 122 | llm_id: LlmId, 123 | point: str, 124 | probe_method: ProbeMethod, 125 | dataset_id: DatasetId, 126 | ): 127 | arrays = arrays_dataset.get( 128 | llm_id=llm_id, 129 | dataset_filter_id=dataset_id, 130 | split="train", 131 | point_name=point, 132 | token_idx=-1, 133 | limit=None, 134 | ) 135 | probe = train_probe(probe_method, arrays) 136 | assert probe is not None 137 | arrays_val = arrays_dataset.get( 138 | llm_id=llm_id, 139 | dataset_filter_id=dataset_id, 140 | split="validation", 141 | point_name=point, 142 | token_idx=-1, 143 | limit=None, 144 | ) 145 | assert arrays_val.groups is not None 146 | results = eval_probe_by_question( 147 | probe, 148 | activations=arrays_val.activations, 149 | labels=arrays_val.labels, 150 | groups=arrays_val.groups, 151 | ) 152 | return dict( 153 | llm_id=llm_id, 154 | point=point, 155 | probe_method=probe_method, 156 | dataset_id=dataset_id, 157 | acc=results.accuracy, 158 | ) 159 | 160 | 161 | def eval_logprobs(llm_id: LlmId, dataset_id: DatasetId): 162 | arrays_logits = arrays_dataset.get( 163 | llm_id=llm_id, 164 | dataset_filter_id=dataset_id, 165 | split="validation", 166 | point_name="logprobs", 167 | token_idx=-1, 168 | limit=None, 169 | ) 170 | assert arrays_logits.groups is not None 171 | results = eval_logits_by_question( 172 | logits=arrays_logits.activations, 173 | labels=arrays_logits.labels, 174 | groups=arrays_logits.groups, 175 | ) 176 | return dict( 177 | llm_id=llm_id, 178 | point="logprobs", 179 | probe_method="logprobs", 180 | dataset_id=dataset_id, 181 | acc=results.accuracy, 182 | ) 183 | 184 | 185 | probe_methods: list[ProbeMethod] = ["mmp", "lr"] 186 | dataset_ids: list[DatasetId] = ["arc_easy", "arc_easy/qna"] 187 | df = ( 188 | mcontext.create( 189 | { 190 | f"{llm_id}-{point.name}-{probe_method}-{dataset_id}": ( 191 | llm_id, 192 | point.name, 193 | probe_method, 194 | dataset_id, 195 | ) 196 | for llm_id in llm_ids 197 | for point in get_points(llm_id)[-10:] 198 | for probe_method in probe_methods 199 | for dataset_id in dataset_ids 200 | } 201 | ) 202 | .map_cached( 203 | "train-and-eval", 204 | lambda _, args: train_and_eval( 205 | llm_id=args[0], point=args[1], probe_method=args[2], dataset_id=args[3] 206 | ), 207 | to="pickle", 208 | ) 209 | .to_dataframe(lambda d: d) 210 | ) 211 | df_logprobs = ( 212 | mcontext.create( 213 | { 214 | f"{llm_id}-{dataset_id}": (llm_id, dataset_id) 215 | for llm_id in 
llm_ids 216 | for dataset_id in dataset_ids 217 | } 218 | ) 219 | .map_cached( 220 | "eval-logprobs", 221 | lambda _, args: eval_logprobs(llm_id=args[0], dataset_id=args[1]), 222 | to="pickle", 223 | ) 224 | .to_dataframe(lambda d: d) 225 | ) 226 | df = pd.concat([df, df_logprobs]) 227 | 228 | 229 | # %% 230 | fig = px.line( 231 | pd.DataFrame(df).sort_values("point"), 232 | x="point", 233 | y="acc", 234 | color="probe_method", 235 | facet_col="llm_id", 236 | facet_row="dataset_id", 237 | markers=True, 238 | ) 239 | fig.update_layout(height=600) 240 | fig.show() 241 | 242 | # %% 243 | acts = activations.join( 244 | inputs, 245 | lambda _, activations_row, input_row: dict( 246 | dataset_id=input_row.row.dataset_id, 247 | label=input_row.row.is_true, 248 | group_id=input_row.row.group_id, 249 | activations=activations_row.token_logprobs.sum(), 250 | llm_id=input_row.llm_id, 251 | ), 252 | ).to_dataframe(lambda d: d) 253 | mean_activations = ( 254 | acts.groupby("group_id")["activations"].mean().rename("mean_activations") 255 | ) 256 | acts = acts.join(mean_activations, on="group_id") 257 | acts["activations"] -= acts["mean_activations"] 258 | acts = acts.groupby("label").sample(100) 259 | 260 | # %% 261 | fig = px.histogram( 262 | acts, 263 | x="activations", 264 | color="label", 265 | facet_col="llm_id", 266 | nbins=50, 267 | opacity=0.5, 268 | barmode="overlay", 269 | ) 270 | fig.show() 271 | 272 | # %% 273 | for model in llm_ids: 274 | acts = activations.join( 275 | inputs.filter(lambda _, row: row.llm_id == model), 276 | lambda _, activations_row, input_row: Activation( 277 | dataset_id=input_row.row.dataset_id, 278 | label=input_row.row.is_true, 279 | group_id=input_row.row.group_id, 280 | activations=activations_row.token_logprobs.sum(), 281 | ), 282 | ).get() 283 | arrays = prepare_activations_for_probes(acts) 284 | assert arrays.labeled_grouped is not None 285 | print(model) 286 | pprint( 287 | eval_logits_by_question( 288 | LabeledGroupedLogits( 289 | logits=arrays.labeled_grouped.activations, 290 | labels=arrays.labeled_grouped.labels, 291 | groups=arrays.labeled_grouped.groups, 292 | ) 293 | ).model_dump() 294 | ) 295 | 296 | 297 | # %% 298 | df = activations.join( 299 | inputs, 300 | lambda _, activations_row, input_row: dict( 301 | activations=activations_row.activations, 302 | label=input_row.row.is_true, 303 | split=input_row.row.split, 304 | logprobs=activations_row.token_logprobs, 305 | model=input_row.llm_id, 306 | few_shot_style=input_row.few_shot_style, 307 | dataset_id=input_row.row.dataset_id, 308 | # answer_tag=input_row.row.answer_tag, 309 | pair_id=input_row.row.group_id, 310 | ), 311 | ).to_dataframe(lambda d: d) 312 | df 313 | 314 | # %% plot top 2 PCA components per-model 315 | for llm_id in llm_ids: 316 | # for dataset_id in df["dataset_id"].unique(): 317 | df2 = df.copy() 318 | df2 = df2[df2["model"] == llm_id] 319 | # df2 = df2[df2["dataset_id"] == dataset_id] 320 | point_name = get_points(llm_id)[-8].name # second-to-last layer 321 | a = np.stack(df2["activations"].apply(lambda a: a[point_name][-1]).to_list()) 322 | for answer_tag in df2["answer_tag"].unique(): 323 | if answer_tag is None: 324 | continue 325 | mask = df2["answer_tag"] == answer_tag 326 | a[mask] -= a[mask].mean(axis=0) 327 | pca = PCA(n_components=2) 328 | pca_fit = pca.fit_transform(a) 329 | df2["pca_0"] = pca_fit[:, 0] 330 | df2["pca_1"] = pca_fit[:, 1] 331 | sns.scatterplot(data=df2, x="pca_0", y="pca_1", hue="label") 332 | plt.title(llm_id) 333 | plt.show() 334 | 335 | # %% 336 | df2 = 
df.copy() 337 | # df2 = df2[df2["dataset_id"] == "common_sense_qa"] 338 | df2 = df2[df2["dataset_id"] == "open_book_qa"] 339 | df2["activation"] = df2["activations"].apply(lambda a: a["h13"][-1]) 340 | 341 | # tag_means = ( 342 | # df2.groupby(["answer_tag"])["activation"] 343 | # .apply(lambda a: np.mean(a, axis=0)) 344 | # .reset_index() 345 | # .rename(columns={"activation": "tag_mean"}) 346 | # .set_index("answer_tag") 347 | # ) 348 | # df2 = df2.join(tag_means, on="answer_tag") 349 | # df2["activation"] -= df2["tag_mean"] 350 | 351 | # all_mean = df2["activation"].mean(axis=0) 352 | # df2["activation"] = df2["activation"].apply(lambda a: a - all_mean) 353 | 354 | # pair_means = ( 355 | # df2.groupby(["pair_id"])["activation"] 356 | # .apply(lambda a: np.mean(a, axis=0)) 357 | # .rename("pair_mean") 358 | # ) 359 | # df2 = df2.join(pair_means, on="pair_id") 360 | # df2["activation"] -= df2["pair_mean"] 361 | 362 | # pair_std = ( 363 | # df2.groupby(["pair_id"])["activation"] 364 | # .apply(lambda a: np.std(np.stack(a), axis=0) + 1e-6) 365 | # .rename("pair_std") 366 | # ) 367 | # df2 = df2.join(pair_std, on="pair_id") 368 | # df2["activation"] /= df2["pair_std"] 369 | 370 | # pca = PCA(n_components=2) 371 | # pca_fit = pca.fit_transform(np.stack(df2["activation"].to_list())) 372 | # df2["pca_0"] = pca_fit[:, 0] 373 | # df2["pca_1"] = pca_fit[:, 1] 374 | # sns.scatterplot(data=df2, x="pca_0", y="pca_1", hue="label") 375 | # plt.show() 376 | 377 | df2_train = df2[df2["split"] == "train"].copy() 378 | # pair_means = ( 379 | # df2_train.groupby(["pair_id"])["activation"] 380 | # .apply(lambda a: np.mean(a, axis=0)) 381 | # .rename("pair_mean") 382 | # ) 383 | # df2_train = df2_train.join(pair_means, on="pair_id") 384 | # df2_train["activation"] -= df2_train["pair_mean"] 385 | 386 | df2_val = df2[df2["split"] == "validation"] 387 | train_limit = 2000 388 | print(train_limit, len(df2_train)) 389 | labelled_activation_array = LabeledActivationArray( 390 | activations=np.array(df2_train["activation"].to_list()[:train_limit]), 391 | labels=np.array(df2_train["label"].to_list()[:train_limit]), 392 | ) 393 | activation_array = ActivationArray( 394 | activations=np.array(df2_train["activation"].to_list()[:train_limit]), 395 | ) 396 | labeled_grouped_activation_array = LabeledGroupedActivationArray( 397 | activations=np.array(df2_train["activation"].to_list()), 398 | labels=np.array(df2_train["label"].to_list()), 399 | groups=np.array(df2_train["pair_id"].to_list()), 400 | ) 401 | activation_array_val = LabeledActivationArray( 402 | activations=np.array(df2_val["activation"].to_list()), 403 | labels=np.array(df2_val["label"].to_list()), 404 | ) 405 | 406 | probe = train_lr_probe(labelled_activation_array) 407 | # probe = train_grouped_lr_probe(labeled_grouped_activation_array) 408 | # probe = train_mmp_probe(labelled_activation_array, use_iid=False) 409 | # probe = train_lat_probe( 410 | # activation_array, 411 | # LatTrainingConfig(num_random_pairs=5000), 412 | # ) 413 | eval = eval_probe_by_row(probe, activation_array_val) 414 | plt.plot(eval.fprs, eval.tprs) 415 | print(eval.roc_auc_score) 416 | # print(eval.f1_score) 417 | 418 | # %% 419 | df2 = df.copy() 420 | df2["logprob"] = df2["logprobs"].apply(lambda x: x.squeeze(1).sum()) 421 | df2["prob"] = df2["logprob"].apply(lambda x: np.exp(x)) 422 | pair_prob_denom = ( 423 | df2.groupby(["pair_id", "dataset_id"])["prob"] 424 | .apply(lambda x: sum(x)) 425 | .rename("sum_prob") 426 | ) 427 | df2 = df2.join(pair_prob_denom, on=["pair_id", 
"dataset_id"]) 428 | df2["prob_norm"] = df2["prob"] / df2["sum_prob"] 429 | df2 430 | sns.barplot(data=df2, x="dataset_id", y="prob_norm", hue="label") 431 | plt.xticks(rotation=45) 432 | 433 | # %% 434 | df2 = df.copy() 435 | df2["logprob"] = df2["logprobs"].apply( 436 | lambda x: np.exp(x.squeeze(1).sum().astype("float64")) 437 | ) 438 | # df2 = df2.dropna() 439 | for dataset_id in df2["dataset_id"].unique(): 440 | df3 = df2[df2["dataset_id"] == dataset_id] 441 | logprobs_and_labels = df3.groupby("pair_id").apply( 442 | lambda a: (np.array(a.logprob), np.array(a.label)) 443 | ) 444 | logprobs = logprobs_and_labels.apply(lambda a: a[0] / a[0].sum()) 445 | labels = logprobs_and_labels.apply(lambda a: a[1]) 446 | assert all(logprobs.index == labels.index) 447 | logprobs = np.stack(logprobs.to_numpy()) 448 | labels = np.stack(labels.to_numpy()) 449 | print(dataset_id, roc_auc_score(labels, logprobs)) 450 | 451 | # %% plot logprobs by model and few-shot style 452 | df["logprob"] = df["logprobs"].apply(lambda x: x.sum()) 453 | g = sns.FacetGrid( 454 | df, 455 | col="model", 456 | row="few_shot_style", 457 | hue="label", 458 | margin_titles=True, 459 | ) 460 | g.map(sns.histplot, "logprob", edgecolor="w").add_legend() 461 | plt.show() 462 | 463 | # %% train and evaluate probes 464 | probe_arrays = prepare_activations_for_probes( 465 | [ 466 | Activation( 467 | dataset_id="geometry_of_truth-cities", 468 | group_id=None, 469 | activations=row.activations[get_points(row.model)[-2].name][-1], 470 | label=row.label, 471 | ) 472 | for row in df.itertuples() 473 | if row.split == "train" and row.model == "pythia-1b" 474 | ] 475 | ) 476 | probe_arrays_val = prepare_activations_for_probes( 477 | [ 478 | Activation( 479 | dataset_id="geometry_of_truth-cities", 480 | group_id=None, 481 | activations=row.activations[get_points(row.model)[-2].name][-1], 482 | label=row.label, 483 | ) 484 | for row in df.itertuples() 485 | if row.split == "validation" and row.model == "pythia-1b" 486 | ] 487 | ) 488 | 489 | probe_methods: list[ProbeMethod] = ["lat", "mmp", "lr"] 490 | df_eval = pd.DataFrame( 491 | [ 492 | dict( 493 | probe_method=probe_method, 494 | **eval_probe_by_row( 495 | train_probe(probe_method, probe_arrays), probe_arrays_val.labeled 496 | ).model_dump(), 497 | ) 498 | for probe_method in probe_methods 499 | ], 500 | ) 501 | df_eval[["probe_method", "f1_score", "precision", "recall", "roc_auc_score"]] 502 | 503 | # %% plot ROC curves 504 | fig, axs = plt.subplots(1, 3, figsize=(15, 5)) 505 | for ax, row in zip(axs, df_eval.itertuples()): 506 | ax.plot(row.fprs, row.tprs) 507 | ax.set_title(row.probe_method) 508 | 509 | # %% 510 | np.std([np.array([1, 2, 3]), np.array([1, 2, 3])], axis=0) 511 | 512 | # %% 513 | df["activations"].apply(lambda a: a["h14"][-1]).mean() 514 | # sorted(df["pair_id"].unique()) 515 | acts = np.stack(df["activations"].apply(lambda a: a["h14"][-1])) 516 | print(acts.shape) 517 | print(np.cov(acts.T).shape) 518 | print(acts.mean(axis=0)[:5]) 519 | print(np.cov(acts.T).flatten()[:5]) 520 | -------------------------------------------------------------------------------- /experiments/saliency.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | from typing import Sequence, cast 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | from mppr import MContext 10 | from pydantic import BaseModel 11 | from sklearn.decomposition import 
PCA 12 | 13 | from repeng.activations.probe_preparations import ActivationArrayDataset 14 | from repeng.datasets.activations.types import ActivationResultRow 15 | from repeng.datasets.elk.utils.filters import DatasetFilter, DatasetIdFilter 16 | from repeng.evals.logits import eval_logits_by_question, eval_logits_by_row 17 | from repeng.evals.probes import eval_probe_by_question 18 | from repeng.models.points import get_points 19 | from repeng.models.types import LlmId 20 | from repeng.probes.base import BaseProbe 21 | from repeng.probes.collections import SUPERVISED_PROBES, ProbeMethod, train_probe 22 | from repeng.probes.difference_in_means import train_dim_probe 23 | 24 | LLM_IDS: list[LlmId] = [ 25 | # "Llama-2-7b-hf", 26 | # "Llama-2-7b-chat-hf", 27 | "Llama-2-13b-hf", 28 | "Llama-2-13b-chat-hf", 29 | "gemma-2b", 30 | "gemma-2b-it", 31 | # "gemma-7b", 32 | # "gemma-7b-it", 33 | # "Mistral-7B", 34 | # "Mistral-7B-Instruct", 35 | ] 36 | DATASETS = [ 37 | DatasetIdFilter("boolq/simple"), 38 | DatasetIdFilter("imdb/simple"), 39 | DatasetIdFilter("race/simple"), 40 | DatasetIdFilter("got_cities"), 41 | ] 42 | CHAT_MODELS: list[LlmId] = [ 43 | "Llama-2-7b-chat-hf", 44 | "Llama-2-13b-chat-hf", 45 | "gemma-2b-it", 46 | "gemma-7b-it", 47 | "Mistral-7B-Instruct", 48 | ] 49 | MODEL_FAMILIES: dict[LlmId, str] = { 50 | "Llama-2-7b-hf": "Llama-2-7b", 51 | "Llama-2-7b-chat-hf": "Llama-2-7b", 52 | "Llama-2-13b-hf": "Llama-2-13b", 53 | "Llama-2-13b-chat-hf": "Llama-2-13b", 54 | "gemma-2b": "gemma-2b", 55 | "gemma-2b-it": "gemma-2b", 56 | "gemma-7b": "gemma-7b", 57 | "gemma-7b-it": "gemma-7b", 58 | "Mistral-7B": "Mistral-7B", 59 | "Mistral-7B-Instruct": "Mistral-7B", 60 | } 61 | 62 | # %% 63 | output = Path("../output/saliency") 64 | mcontext = MContext(output) 65 | activation_results: list[ActivationResultRow] = mcontext.download_cached( 66 | "activations_results", 67 | path="s3://repeng/datasets/activations/saliency_2024-02-26_v1.pickle", 68 | to="pickle", 69 | ).get() 70 | dataset = ActivationArrayDataset(activation_results) 71 | 72 | 73 | # %% 74 | """ 75 | Pipeline for training and evaluating probes. 
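Trains one probe per combination of model, training dataset, probe method, and layer point
(every other point, via get_points(llm_id)[1::2]) on the "train" split, then reports grouped
accuracy (eval_probe_by_question) on the "validation" split. Called later in this file, e.g.:

    results = run_probe_pipeline(
        llm_ids=["gemma-2b", "gemma-2b-it"],
        datasets=[DatasetIdFilter("boolq/simple")],
        probe_methods=["lr", "pca-g"],
    )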
76 | """ 77 | 78 | 79 | @dataclass 80 | class Spec: 81 | llm_id: LlmId 82 | dataset: DatasetFilter 83 | probe_method: ProbeMethod 84 | point_name: str 85 | 86 | 87 | @dataclass 88 | class EvalResult: 89 | accuracy: float 90 | n: int 91 | 92 | 93 | class PipelineResultRow(BaseModel, extra="forbid"): 94 | llm_id: LlmId 95 | dataset: str 96 | probe_method: ProbeMethod 97 | point_name: str 98 | accuracy: float 99 | accuracy_n: int 100 | 101 | 102 | token_idxs: list[int] = [-1] 103 | 104 | 105 | def run_probe_pipeline( 106 | llm_ids: list[LlmId], 107 | datasets: Sequence[DatasetFilter], 108 | probe_methods: list[ProbeMethod], 109 | ) -> list[PipelineResultRow]: 110 | train_specs = mcontext.create( 111 | { 112 | "-".join( 113 | [llm_id, str(train_dataset), probe_method, point.name, str(token_idx)] 114 | ): Spec( 115 | llm_id=llm_id, 116 | dataset=train_dataset, 117 | probe_method=probe_method, 118 | point_name=point.name, 119 | ) 120 | for llm_id in llm_ids 121 | for train_dataset in datasets 122 | for probe_method in probe_methods 123 | for point in get_points(llm_id)[1::2] 124 | for token_idx in token_idxs 125 | } 126 | ) 127 | probes = train_specs.map_cached( 128 | "probe_train", 129 | lambda _, spec: train_probe( 130 | spec.probe_method, 131 | dataset.get( 132 | llm_id=spec.llm_id, 133 | dataset_filter=spec.dataset, 134 | split="train", 135 | point_name=spec.point_name, 136 | token_idx=-1, 137 | limit=None, 138 | ), 139 | ), 140 | to="pickle", 141 | ).filter(lambda _, probe: probe is not None) 142 | return ( 143 | probes.join( 144 | train_specs, 145 | lambda _, probe, spec: (probe, spec), 146 | ) 147 | .map_cached( 148 | "probe_evaluate", 149 | lambda _, args: _eval_probe(cast(BaseProbe, args[0]), args[1]), 150 | to=PipelineResultRow, 151 | ) 152 | .get() 153 | ) 154 | 155 | 156 | def _eval_probe(probe: BaseProbe, spec: Spec) -> PipelineResultRow: 157 | arrays = dataset.get( 158 | llm_id=spec.llm_id, 159 | dataset_filter=spec.dataset, 160 | split="validation", 161 | point_name=spec.point_name, 162 | token_idx=-1, 163 | limit=None, 164 | ) 165 | assert arrays.groups is not None 166 | question_result = eval_probe_by_question( 167 | probe, 168 | activations=arrays.activations, 169 | labels=arrays.labels, 170 | groups=arrays.groups, 171 | ) 172 | return PipelineResultRow( 173 | llm_id=spec.llm_id, 174 | dataset=spec.dataset.get_name(), 175 | probe_method=spec.probe_method, 176 | point_name=spec.point_name, 177 | accuracy=question_result.accuracy, 178 | accuracy_n=question_result.n, 179 | ) 180 | 181 | 182 | # %% 183 | """ 184 | Pipeline for evaluating the zero-shot performance of models, based on logprobs. 
185 | """ 186 | 187 | 188 | @dataclass 189 | class LogprobEvalSpec: 190 | llm_id: LlmId 191 | dataset: DatasetFilter 192 | 193 | 194 | class LogprobsPipelineResultRow(BaseModel, extra="forbid"): 195 | llm_id: LlmId 196 | dataset: str 197 | accuracy: float 198 | 199 | 200 | def run_logprobs_pipeline( 201 | llm_ids: list[LlmId], 202 | datasets: Sequence[DatasetFilter], 203 | ) -> list[LogprobsPipelineResultRow]: 204 | return ( 205 | mcontext.create( 206 | { 207 | f"{llm_id}-{eval_dataset}": LogprobEvalSpec(llm_id, eval_dataset) 208 | for llm_id in llm_ids 209 | for eval_dataset in datasets 210 | } 211 | ) 212 | .map_cached( 213 | "logprob_evaluate", 214 | lambda _, spec: _eval_logprobs(spec), 215 | to="pickle", 216 | ) 217 | .get() 218 | ) 219 | 220 | 221 | def _eval_logprobs(spec: LogprobEvalSpec) -> LogprobsPipelineResultRow: 222 | arrays = dataset.get( 223 | llm_id=spec.llm_id, 224 | dataset_filter=spec.dataset, 225 | split="train", 226 | point_name="logprobs", 227 | token_idx=-1, 228 | limit=None, 229 | ) 230 | row_result = eval_logits_by_row( 231 | logits=arrays.activations, 232 | labels=arrays.labels, 233 | ) 234 | question_result = None 235 | if arrays.groups is not None: 236 | question_result = eval_logits_by_question( 237 | logits=arrays.activations, 238 | labels=arrays.labels, 239 | groups=arrays.groups, 240 | ) 241 | return LogprobsPipelineResultRow( 242 | llm_id=spec.llm_id, 243 | dataset=spec.dataset.get_name(), 244 | accuracy=question_result.accuracy if question_result else row_result.accuracy, 245 | ) 246 | 247 | 248 | # %% 249 | """ 250 | Pipeline for calculating saliency. 251 | """ 252 | NUM_COMPONENTS: int = 1024 253 | 254 | 255 | @dataclass 256 | class SaliencySpec: 257 | llm_id: LlmId 258 | dataset: DatasetFilter 259 | point_name: str 260 | 261 | 262 | @dataclass 263 | class SaliencyResult: 264 | llm_id: LlmId 265 | dataset: str 266 | point_name: str 267 | saliency: float 268 | saliency_rank: int 269 | saliency_ratio: float 270 | probe_accuracy: float 271 | probe_accuracy_n: int 272 | 273 | 274 | def run_saliency_pipeline( 275 | llm_ids: list[LlmId], 276 | datasets: Sequence[DatasetFilter], 277 | ) -> list[SaliencyResult]: 278 | return ( 279 | mcontext.create( 280 | { 281 | "-".join([llm_id, str(train_dataset), point]): SaliencySpec( 282 | llm_id=llm_id, 283 | dataset=train_dataset, 284 | point_name=point, 285 | ) 286 | for llm_id in llm_ids 287 | for train_dataset in datasets 288 | for point in _get_points(llm_id) 289 | } 290 | ) 291 | .map_cached( 292 | "saliency-v2", 293 | lambda _, spec: _compare_stds(spec), 294 | "pickle", 295 | ) 296 | .get() 297 | ) 298 | 299 | 300 | def _get_points(llm_id: LlmId) -> list[str]: 301 | points = [point.name for point in get_points(llm_id)] 302 | points_skipped = points[1::2] 303 | if points[-1] not in points_skipped: 304 | points_skipped.append(points[-1]) 305 | return points_skipped 306 | 307 | 308 | def _compare_stds(spec: SaliencySpec) -> SaliencyResult: 309 | arrays = dataset.get( 310 | llm_id=spec.llm_id, 311 | dataset_filter=spec.dataset, 312 | split="train", 313 | point_name=spec.point_name, 314 | token_idx=-1, 315 | limit=None, 316 | ) 317 | probe = train_dim_probe( 318 | activations=arrays.activations, 319 | labels=arrays.labels, 320 | ) 321 | pca = PCA(n_components=NUM_COMPONENTS).fit(arrays.activations) 322 | 323 | arrays_val = dataset.get( 324 | llm_id=spec.llm_id, 325 | dataset_filter=spec.dataset, 326 | split="validation", 327 | point_name=spec.point_name, 328 | token_idx=-1, 329 | limit=None, 330 | ) 331 | 
truth_direction = probe.probe 332 | if np.linalg.norm(truth_direction) > 0: 333 | truth_direction /= np.linalg.norm(truth_direction) 334 | truth_activations = arrays_val.activations @ truth_direction 335 | truth_variance = np.var(truth_activations) 336 | total_variance = np.var(arrays_val.activations, axis=0).sum() 337 | saliency = truth_variance / total_variance 338 | 339 | pca_variances = pca.transform(arrays_val.activations).var(axis=0) 340 | saliency_rank = np.sum(pca_variances > truth_variance) 341 | 342 | saliency_ratio = truth_variance / pca_variances[0] 343 | 344 | assert arrays_val.groups is not None 345 | eval_results = eval_probe_by_question( 346 | probe, 347 | activations=arrays_val.activations, 348 | labels=arrays_val.labels, 349 | groups=arrays_val.groups, 350 | ) 351 | 352 | return SaliencyResult( 353 | llm_id=spec.llm_id, 354 | dataset=spec.dataset.get_name(), 355 | point_name=spec.point_name, 356 | saliency=saliency, 357 | saliency_rank=saliency_rank, 358 | saliency_ratio=saliency_ratio, 359 | probe_accuracy=eval_results.accuracy, 360 | probe_accuracy_n=eval_results.n, 361 | ) 362 | 363 | 364 | # # %% 365 | # """ 366 | # Pipeline for LR probes on PCA subsets. 367 | # """ 368 | # COMPONENT_INDICES: list[int] = [ 369 | # *range(1, 16), 370 | # *[2**i for i in range(int(math.log(16, 2)) + 1, int(math.log(512, 2)))], 371 | # 512, 372 | # ] 373 | 374 | 375 | # @dataclass 376 | # class PcaSubset: 377 | # components: Float[np.ndarray, "n d"] # noqa: F722 378 | # spec: Spec 379 | 380 | 381 | # class PcaPipelineResultRow(BaseModel, extra="forbid"): 382 | # llm_id: LlmId 383 | # dataset: str 384 | # point_name: str 385 | # num_components: int 386 | # accuracy: float 387 | # accuracy_n: int 388 | 389 | 390 | # def run_pca_pipeline( 391 | # llm_ids: list[LlmId], 392 | # datasets: Sequence[DatasetFilter], 393 | # ) -> list[PcaPipelineResultRow]: 394 | # train_specs = mcontext.create( 395 | # { 396 | # "-".join([llm_id, str(dataset), point.name]): Spec( 397 | # llm_id=llm_id, 398 | # dataset=dataset, 399 | # probe_method="pca", 400 | # point_name=point.name, 401 | # ) 402 | # for llm_id in llm_ids 403 | # for dataset in datasets 404 | # for point in get_points(llm_id)[17:18] 405 | # } 406 | # ) 407 | # pca_subsets = ( 408 | # train_specs.map_cached( 409 | # "pca", 410 | # lambda _, spec: PCA(n_components=max(COMPONENT_INDICES)).fit( 411 | # dataset.get( 412 | # llm_id=spec.llm_id, 413 | # dataset_filter=spec.dataset, 414 | # split="train", 415 | # point_name=spec.point_name, 416 | # token_idx=-1, 417 | # limit=None, 418 | # ).activations 419 | # ), 420 | # to="pickle", 421 | # ) 422 | # .join(train_specs, lambda _, pca, spec: (pca, spec)) 423 | # .flat_map( 424 | # lambda key, args: { 425 | # f"{key}-{i}": PcaSubset( 426 | # components=args[0].components_[:i], spec=args[1] 427 | # ) 428 | # for i in COMPONENT_INDICES 429 | # } 430 | # ) 431 | # ) 432 | # return ( 433 | # pca_subsets.map_cached( 434 | # "pca_train_lr", 435 | # lambda _, pca_subset: _train_probe_on_pca_subset(pca_subset), 436 | # to="pickle", 437 | # ) 438 | # .join(pca_subsets, lambda _, probe, pca_subsets: (probe, pca_subsets)) 439 | # .map_cached( 440 | # "pca_eval_lr", 441 | # lambda _, args: _eval_probe_on_pca_subset(args[0], args[1]), 442 | # to=PcaPipelineResultRow, 443 | # ) 444 | # .get() 445 | # ) 446 | 447 | 448 | # def _train_probe_on_pca_subset(pca_subset: PcaSubset) -> BaseProbe: 449 | # arrays = dataset.get( 450 | # llm_id=pca_subset.spec.llm_id, 451 | # dataset_filter=pca_subset.spec.dataset, 452 | # 
split="train", 453 | # point_name=pca_subset.spec.point_name, 454 | # token_idx=-1, 455 | # limit=None, 456 | # ) 457 | # activations = arrays.activations @ pca_subset.components.T 458 | # return train_lr_probe( 459 | # LrConfig(), 460 | # activations=activations, 461 | # labels=arrays.labels, 462 | # ) 463 | 464 | 465 | # def _eval_probe_on_pca_subset( 466 | # probe: BaseProbe, pca_subset: PcaSubset 467 | # ) -> PcaPipelineResultRow: 468 | # arrays = dataset.get( 469 | # llm_id=pca_subset.spec.llm_id, 470 | # dataset_filter=pca_subset.spec.dataset, 471 | # split="validation", 472 | # point_name=pca_subset.spec.point_name, 473 | # token_idx=-1, 474 | # limit=None, 475 | # ) 476 | # activations = arrays.activations @ pca_subset.components.T 477 | # assert arrays.groups is not None 478 | # question_result = eval_probe_by_question( 479 | # probe, 480 | # activations=activations, 481 | # labels=arrays.labels, 482 | # groups=arrays.groups, 483 | # ) 484 | # return PcaPipelineResultRow( 485 | # llm_id=pca_subset.spec.llm_id, 486 | # dataset=pca_subset.spec.dataset.get_name(), 487 | # point_name=pca_subset.spec.point_name, 488 | # num_components=pca_subset.components.shape[0], 489 | # accuracy=question_result.accuracy, 490 | # accuracy_n=question_result.n, 491 | # ) 492 | 493 | 494 | # results = run_pca_pipeline( 495 | # llm_ids=[ 496 | # # "Llama-2-13b-hf", 497 | # # "Llama-2-13b-chat-hf", 498 | # "Llama-2-7b-hf", 499 | # "Llama-2-7b-chat-hf", 500 | # ], 501 | # datasets=[ 502 | # DatasetIdFilter("boolq/simple"), 503 | # # DATASET_FILTER_FNS["got_cities/pos"], 504 | # ], 505 | # ) 506 | # df = pd.DataFrame([r.model_dump() for r in results]) 507 | # px.line( 508 | # df, 509 | # x="num_components", 510 | # y="accuracy", 511 | # color="llm_id", 512 | # facet_row="dataset", 513 | # log_x=True, 514 | # ) 515 | 516 | # %% 517 | """ 518 | Show that PCA works on chat models, but not on non-chat models. 519 | """ 520 | results = run_probe_pipeline( 521 | # llm_ids=["Llama-2-7b-hf", "Llama-2-7b-chat-hf"], 522 | # llm_ids=["Llama-2-13b-hf", "Llama-2-13b-chat-hf"], 523 | llm_ids=["gemma-2b", "gemma-2b-it"], 524 | # llm_ids=["gemma-7b", "gemma-7b-it"], 525 | datasets=[ 526 | # DatasetIdFilter("got_cities"), 527 | # DatasetIdFilter("race/simple"), 528 | DatasetIdFilter("boolq/simple"), 529 | ], 530 | probe_methods=["lr", "pca-g"], 531 | ) 532 | df = pd.DataFrame([r.model_dump() for r in results]) 533 | df["layer"] = df["point_name"].str.extract(r"h(\d+)").astype(int) 534 | df["supervised"] = np.where( 535 | df["probe_method"].isin(SUPERVISED_PROBES), "supervised", "unsupervised" 536 | ) 537 | df["algorithm"] = df["probe_method"].replace( 538 | {"lr": "Supervised (LogR)", "pca-g": "Unsupervised (PCA-G)"} 539 | ) 540 | df["accuracy_stderr"] = np.sqrt( 541 | df["accuracy"] * (1 - df["accuracy"]) / df["accuracy_n"] 542 | ) 543 | df["type"] = np.where(df["llm_id"].isin(CHAT_MODELS), "chat", "base") 544 | 545 | fig = px.line( 546 | df.sort_values(["layer", "supervised", "algorithm"]), 547 | x="layer", 548 | y="accuracy", 549 | error_y="accuracy_stderr", 550 | color="algorithm", 551 | facet_row="type", 552 | width=800, 553 | height=500, 554 | ) 555 | fig.update_layout(yaxis_tickformat=".0%") 556 | fig.write_image(output / "1_lr_v_pca.png", scale=3) 557 | fig.show() 558 | 559 | # %% 560 | """ 561 | Show that chat and base models have the same zero-shot accuracy. 
562 | """ 563 | results = run_logprobs_pipeline( 564 | llm_ids=LLM_IDS, 565 | datasets=DATASETS, 566 | ) 567 | df = pd.DataFrame([r.model_dump() for r in results]) 568 | df["family"] = df["llm_id"].map(MODEL_FAMILIES) 569 | df["type"] = np.where(df["llm_id"].isin(CHAT_MODELS), "chat", "base") 570 | fig = px.bar( 571 | df.sort_values(["llm_id"]), 572 | x="type", 573 | y="accuracy", 574 | color="type", 575 | facet_col="dataset", 576 | facet_row="family", 577 | category_orders={ 578 | "family": ["Llama-2-7b", "Llama-2-13b", "gemma-2b", "gemma-7b", "Mistral-7B"], 579 | "type": ["base", "chat"], 580 | "dataset": [d.get_name() for d in DATASETS], 581 | }, 582 | text_auto=".1%", # type: ignore 583 | width=800, 584 | height=800, 585 | ) 586 | fig.write_image(output / "2_zero_shot.png", scale=3) 587 | fig.show() 588 | 589 | # %% 590 | """ 591 | Plot saliency measures for a range of models. 592 | """ 593 | results = run_saliency_pipeline( 594 | llm_ids=LLM_IDS, 595 | datasets=DATASETS, 596 | ) 597 | df = pd.DataFrame([asdict(r) for r in results]) 598 | df["type"] = np.where(df["llm_id"].isin(CHAT_MODELS), "chat", "base") 599 | df["family"] = df["llm_id"].map(MODEL_FAMILIES) 600 | df["layer"] = df["point_name"].str.extract(r"h(\d+)").astype(int) 601 | df["layer"] = df.apply( 602 | lambda r: r.layer / len(get_points(r.llm_id)), 603 | axis=1, 604 | ) 605 | 606 | fig = px.line( 607 | df.sort_values("layer"), 608 | x="layer", 609 | y="saliency_ratio", 610 | color="type", 611 | facet_col="dataset", 612 | facet_row="family", 613 | category_orders={ 614 | "family": ["Llama-2-7b", "Llama-2-13b", "gemma-2b", "gemma-7b", "Mistral-7B"], 615 | "type": ["base", "chat"], 616 | "dataset": [d.get_name() for d in DATASETS], 617 | }, 618 | markers=True, 619 | width=800, 620 | height=1000, 621 | ) 622 | fig.update_yaxes(matches=None) 623 | fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True)) 624 | fig.write_image(output / "3_saliency.png", scale=3) 625 | fig.show() 626 | 627 | # %% 628 | """ 629 | Show correlation between saliency measure and unsupervised probe accuracy 630 | """ 631 | results_saliency = run_saliency_pipeline( 632 | llm_ids=LLM_IDS, 633 | datasets=DATASETS, 634 | ) 635 | results_probe = run_probe_pipeline( 636 | llm_ids=LLM_IDS, 637 | datasets=DATASETS, 638 | probe_methods=["lr", "pca"], 639 | ) 640 | 641 | # %% 642 | df_probe = pd.DataFrame([r.model_dump() for r in results_probe]) 643 | df_probe = df_probe.query("probe_method == 'pca'").join( 644 | df_probe.query("probe_method == 'lr'") 645 | .drop(columns=["probe_method"]) 646 | .rename({"accuracy": "threshold", "accuracy_n": "threshold_n"}, axis=1) 647 | .set_index(["llm_id", "dataset", "point_name"]), 648 | on=["llm_id", "dataset", "point_name"], 649 | ) 650 | df_probe["recovered"] = df_probe["accuracy"] / df_probe["threshold"] 651 | 652 | df = pd.DataFrame([asdict(r) for r in results_saliency]).join( 653 | df_probe.set_index(["llm_id", "dataset", "point_name"]), 654 | on=["llm_id", "dataset", "point_name"], 655 | ) 656 | df["layer"] = df["point_name"].str.extract(r"h(\d+)").astype(int) 657 | px.scatter( 658 | df.query("layer > 10").query("llm_id == 'Llama-2-13b-chat-hf'"), 659 | x="saliency_ratio", 660 | y="recovered", 661 | color="dataset", 662 | symbol="llm_id", 663 | text="layer", 664 | ) 665 | --------------------------------------------------------------------------------